mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15475 from mattjj:shmap-functools-partial-errors
PiperOrigin-RevId: 522699281
This commit is contained in:
commit
c27972d873
@ -164,16 +164,17 @@ def _check_specs_vs_args(
|
||||
def _spec_rank_error(
|
||||
error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs,
|
||||
fails: List[Union[core.ShapedArray, NoFail]]) -> str:
|
||||
fun_name = getattr(f, '__name__', str(f))
|
||||
if error_type == SpecErrorType.input:
|
||||
prefix, base = 'in', 'args'
|
||||
ba = _try_infer_args(f, tree)
|
||||
else:
|
||||
prefix, base = 'out', f'{f.__name__}(*args)'
|
||||
prefix, base = 'out', f'{fun_name}(*args)'
|
||||
msgs = []
|
||||
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
|
||||
if error_type == SpecErrorType.input and ba is not None:
|
||||
arg_key, *_ = fail_key
|
||||
extra = (f", where {base}[{arg_key}] is bound to {f.__name__}'s "
|
||||
extra = (f", where {base}[{arg_key}] is bound to {fun_name}'s "
|
||||
f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
|
||||
else:
|
||||
extra = ""
|
||||
@ -183,13 +184,13 @@ def _spec_rank_error(
|
||||
f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, "
|
||||
f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})")
|
||||
assert msgs
|
||||
msg = (f"shard_map applied to the function '{f.__name__}' was given an "
|
||||
msg = (f"shard_map applied to the function '{fun_name}' was given an "
|
||||
f"{prefix}_specs entry which is too long to be compatible with the "
|
||||
f"corresponding {prefix}put value from the function:\n\n"
|
||||
+ '\n\n'.join(msgs) + '\n\n' +
|
||||
f"Entries in {prefix}_specs must be of length no greater than the "
|
||||
f"number of axes in the corresponding {prefix}put value.\n\n"
|
||||
f"Either revise the spec to be shorter, or modify '{f.__name__}' so "
|
||||
f"Either revise the spec to be shorter, or modify '{fun_name}' so "
|
||||
f"that its {prefix}puts have sufficient rank.")
|
||||
return msg
|
||||
|
||||
@ -197,11 +198,12 @@ def _spec_divisibility_error(
|
||||
f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
|
||||
fails: List[Union[core.ShapedArray, NoFail]]) -> str:
|
||||
ba = _try_infer_args(f, tree)
|
||||
fun_name = getattr(f, '__name__', str(f))
|
||||
msgs = []
|
||||
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
|
||||
if ba is not None:
|
||||
arg_key, *_ = fail_key
|
||||
extra = (f", where args[{arg_key}] is bound to {f.__name__}'s "
|
||||
extra = (f", where args[{arg_key}] is bound to {fun_name}'s "
|
||||
f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
|
||||
names = _canonicalize_spec(spec)
|
||||
for d, ns in names.items():
|
||||
@ -216,7 +218,7 @@ def _spec_divisibility_error(
|
||||
f"{axis} (of {total}size {sz}), but {sz} does not evenly divide "
|
||||
f"{aval.shape[d]}")
|
||||
assert msgs
|
||||
msg = (f"shard_map applied to the function '{f.__name__}' was given argument "
|
||||
msg = (f"shard_map applied to the function '{fun_name}' was given argument "
|
||||
f"arrays with axis sizes that are not evenly divisible by the "
|
||||
f"corresponding mesh axis sizes:\n\n"
|
||||
f"The mesh given has shape {mesh.device_ids.shape} with corresponding "
|
||||
@ -226,11 +228,12 @@ def _spec_divisibility_error(
|
||||
f"axis or axes indicated by the corresponding elements of the "
|
||||
f"argument's in_specs entry. Consider checking that in_specs are "
|
||||
f"correct, and if so consider changing the mesh axis sizes or else "
|
||||
f"padding the input and adapting '{f.__name__}' appropriately.")
|
||||
f"padding the input and adapting '{fun_name}' appropriately.")
|
||||
return msg
|
||||
|
||||
def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
|
||||
fails: List[Union[Set, NoFail]]) -> str:
|
||||
fun_name = getattr(f, '__name__', str(f))
|
||||
msgs = []
|
||||
for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails):
|
||||
dst = _canonicalize_spec(spec)
|
||||
@ -251,7 +254,7 @@ def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
|
||||
f"corresponding output value is replicated across mesh axis "
|
||||
f"'{need_rep_}', but could not infer replication over any axes")
|
||||
assert msgs
|
||||
msg = (f"shard_map applied to the function '{f.__name__}' was given "
|
||||
msg = (f"shard_map applied to the function '{fun_name}' was given "
|
||||
f"out_specs which require replication which can't be statically "
|
||||
f"inferred given the mesh:\n\n"
|
||||
f"The mesh given has shape {mesh.device_ids.shape} with corresponding "
|
||||
|
@ -569,6 +569,18 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x', None))
|
||||
_ = g(sharded_rng) # don't crash!
|
||||
|
||||
def test_functools_partial_rank_error(self):
|
||||
mesh = jtu.create_global_mesh((4,), ('x',))
|
||||
|
||||
@partial
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x',))
|
||||
x = jnp.arange(4)
|
||||
with self.assertRaises(ValueError):
|
||||
g(x)
|
||||
|
||||
|
||||
class FunSpec(NamedTuple):
|
||||
name: str
|
||||
|
Loading…
x
Reference in New Issue
Block a user