Merge pull request #15475 from mattjj:shmap-functools-partial-errors

PiperOrigin-RevId: 522699281
This commit is contained in:
jax authors 2023-04-07 15:37:44 -07:00
commit c27972d873
2 changed files with 23 additions and 8 deletions

View File

@ -164,16 +164,17 @@ def _check_specs_vs_args(
def _spec_rank_error(
error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs,
fails: List[Union[core.ShapedArray, NoFail]]) -> str:
fun_name = getattr(f, '__name__', str(f))
if error_type == SpecErrorType.input:
prefix, base = 'in', 'args'
ba = _try_infer_args(f, tree)
else:
prefix, base = 'out', f'{f.__name__}(*args)'
prefix, base = 'out', f'{fun_name}(*args)'
msgs = []
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
if error_type == SpecErrorType.input and ba is not None:
arg_key, *_ = fail_key
extra = (f", where {base}[{arg_key}] is bound to {f.__name__}'s "
extra = (f", where {base}[{arg_key}] is bound to {fun_name}'s "
f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
else:
extra = ""
@ -183,13 +184,13 @@ def _spec_rank_error(
f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, "
f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})")
assert msgs
msg = (f"shard_map applied to the function '{f.__name__}' was given an "
msg = (f"shard_map applied to the function '{fun_name}' was given an "
f"{prefix}_specs entry which is too long to be compatible with the "
f"corresponding {prefix}put value from the function:\n\n"
+ '\n\n'.join(msgs) + '\n\n' +
f"Entries in {prefix}_specs must be of length no greater than the "
f"number of axes in the corresponding {prefix}put value.\n\n"
f"Either revise the spec to be shorter, or modify '{f.__name__}' so "
f"Either revise the spec to be shorter, or modify '{fun_name}' so "
f"that its {prefix}puts have sufficient rank.")
return msg
@ -197,11 +198,12 @@ def _spec_divisibility_error(
f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
fails: List[Union[core.ShapedArray, NoFail]]) -> str:
ba = _try_infer_args(f, tree)
fun_name = getattr(f, '__name__', str(f))
msgs = []
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
if ba is not None:
arg_key, *_ = fail_key
extra = (f", where args[{arg_key}] is bound to {f.__name__}'s "
extra = (f", where args[{arg_key}] is bound to {fun_name}'s "
f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
names = _canonicalize_spec(spec)
for d, ns in names.items():
@ -216,7 +218,7 @@ def _spec_divisibility_error(
f"{axis} (of {total}size {sz}), but {sz} does not evenly divide "
f"{aval.shape[d]}")
assert msgs
msg = (f"shard_map applied to the function '{f.__name__}' was given argument "
msg = (f"shard_map applied to the function '{fun_name}' was given argument "
f"arrays with axis sizes that are not evenly divisible by the "
f"corresponding mesh axis sizes:\n\n"
f"The mesh given has shape {mesh.device_ids.shape} with corresponding "
@ -226,11 +228,12 @@ def _spec_divisibility_error(
f"axis or axes indicated by the corresponding elements of the "
f"argument's in_specs entry. Consider checking that in_specs are "
f"correct, and if so consider changing the mesh axis sizes or else "
f"padding the input and adapting '{f.__name__}' appropriately.")
f"padding the input and adapting '{fun_name}' appropriately.")
return msg
def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
fails: List[Union[Set, NoFail]]) -> str:
fun_name = getattr(f, '__name__', str(f))
msgs = []
for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails):
dst = _canonicalize_spec(spec)
@ -251,7 +254,7 @@ def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
f"corresponding output value is replicated across mesh axis "
f"'{need_rep_}', but could not infer replication over any axes")
assert msgs
msg = (f"shard_map applied to the function '{f.__name__}' was given "
msg = (f"shard_map applied to the function '{fun_name}' was given "
f"out_specs which require replication which can't be statically "
f"inferred given the mesh:\n\n"
f"The mesh given has shape {mesh.device_ids.shape} with corresponding "

View File

@ -569,6 +569,18 @@ class ShardMapTest(jtu.JaxTestCase):
g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x', None))
_ = g(sharded_rng) # don't crash!
def test_functools_partial_rank_error(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial
def f(x):
return x
g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x',))
x = jnp.arange(4)
with self.assertRaises(ValueError):
g(x)
class FunSpec(NamedTuple):
name: str