Improve the shape-incompatibility error message by including the pytree key path of the offending argument/result.

PiperOrigin-RevId: 529605855
This commit is contained in:
Yash Katariya 2023-05-04 21:49:28 -07:00 committed by jax authors
parent 36ad0d4459
commit a6254c75e0
4 changed files with 28 additions and 19 deletions

View File

@ -592,7 +592,7 @@ def bench_pjit_check_aval_sharding(state):
aval = jax.core.ShapedArray((8, 2), np.int32)
while state:
pjit_check_aval_sharding([s] * 100, [aval] * 100, 'benchmark', False)
pjit_check_aval_sharding([s] * 100, [aval] * 100, None, 'benchmark', False)
@google_benchmark.register

View File

@ -571,7 +571,7 @@ def _check_sharding(aval, s):
if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
pjit.pjit_check_aval_sharding(
(s,), (aval,), "device_put args", allow_uneven_sharding=False)
(s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
assert isinstance(aval, core.ShapedArray), aval
s.shard_shape(aval.shape) # should raise an Error if incompatible

View File

@ -466,7 +466,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
in_type = in_avals = tuple(avals)
canonicalized_in_shardings_flat = _process_in_axis_resources(
hashable_pytree(in_shardings), in_avals, in_tree, resource_env)
hashable_pytree(in_shardings), in_avals, in_tree, resource_env, dbg)
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
flat_fun, hashable_pytree(out_shardings), in_type, dbg,
@ -851,7 +851,7 @@ class PytreeLeaf:
@lru_cache(maxsize=4096)
def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
resource_env):
resource_env, debug_info):
orig_in_shardings = in_shardings_thunk()
# Only do this if original in_shardings are unspecified. If it is AUTO, go
# via flatten_axis_resources.
@ -864,6 +864,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
if not config.jax_dynamic_shapes:
pjit_check_aval_sharding(in_shardings_flat, in_avals,
None if debug_info is None else debug_info.arg_names,
"pjit arguments", allow_uneven_sharding=False)
canonicalized_shardings = tuple(
i if is_unspecified_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
@ -898,7 +899,7 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths):
@lru_cache(maxsize=4096)
def _check_and_canonicalize_out_shardings(
out_shardings_thunk, out_tree, out_type):
out_shardings_thunk, out_tree, out_type, debug_info):
orig_out_shardings = out_shardings_thunk()
# TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources
# instead. This condition exists because flatten_axis_resources passes in an
@ -913,8 +914,10 @@ def _check_and_canonicalize_out_shardings(
tupled_args=False)
if not config.jax_dynamic_shapes:
pjit_check_aval_sharding(out_shardings_flat, out_type, "pjit outputs",
allow_uneven_sharding=False)
pjit_check_aval_sharding(
out_shardings_flat, out_type,
None if debug_info is None else debug_info.result_paths,
"pjit outputs", allow_uneven_sharding=False)
canonicalized_out_shardings_flat = tuple(
o if is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim)
@ -928,16 +931,19 @@ def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info, out_tree,
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
fun, in_type, debug_info, result_paths)
canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings(
out_shardings_thunk, out_tree, tuple(out_type))
out_shardings_thunk, out_tree, tuple(out_type), jaxpr.jaxpr.debug_info)
# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
return jaxpr, final_consts, canonicalized_out_shardings_flat
def pjit_check_aval_sharding(
shardings, flat_avals, what_aval: str, allow_uneven_sharding: bool):
for aval, s in zip(flat_avals, shardings):
shardings, flat_avals, names: Optional[Tuple[str, ...]],
what_aval: str, allow_uneven_sharding: bool):
new_names = [''] * len(shardings) if names is None else names
for aval, s, name in zip(flat_avals, shardings, new_names):
if is_unspecified_or_auto(s):
continue
name_str = f' with pytree key path {name}' if name else ''
shape = aval.shape
try:
# Sharding interfaces can implement `is_compatible_aval` as an optional
@ -947,8 +953,9 @@ def pjit_check_aval_sharding(
else:
s._to_xla_op_sharding(len(shape))
except ValueError as e:
raise ValueError(f'One of {what_aval} is incompatible with its sharding '
f'annotation {s}: {str(e)}')
raise ValueError(
f'One of {what_aval}{name_str} is incompatible with its sharding '
f'annotation {s}: {str(e)}')
# Use the `OpSharding` proto to find out how many ways each dimension of
# the aval is sharded. This approach will work across all
# XLACompatibleSharding.
@ -958,11 +965,11 @@ def pjit_check_aval_sharding(
cast(xc.OpSharding, op_sharding))
for i, size in enumerate(num_ways_dim_sharded):
if not allow_uneven_sharding and shape[i] % size != 0:
raise ValueError(f"One of {what_aval} was given the sharding "
raise ValueError(f"One of {what_aval}{name_str} was given the sharding "
f"of {s}, which implies that "
f"the global size of its dimension {i} should be "
f"divisible by {size}, but it is equal to {shape[i]} "
f"(full shape: {shape}) ")
f"(full shape: {shape})")
# -------------------- pjit rules --------------------
@ -1797,8 +1804,9 @@ def with_sharding_constraint(x, shardings=UNSPECIFIED,
for s in shardings_flat]
del user_shardings_flat
pjit_check_aval_sharding(shardings_flat, x_flat, "with_sharding_constraint arguments",
allow_uneven_sharding=True)
pjit_check_aval_sharding(
shardings_flat, x_flat, None, "with_sharding_constraint arguments",
allow_uneven_sharding=True)
outs = [sharding_constraint_p.bind(xf, sharding=to_gspmd_sharding(i, xf.ndim),
resource_env=resource_env,

View File

@ -3378,7 +3378,7 @@ class PJitErrorTest(jtu.JaxTestCase):
spec = P(resources, None)
mesh_size = str(math.prod([dim[1] for dim in mesh]))
error = re.compile(
r"One of pjit arguments.*" + spec_regex(spec) + r".*"
r"One of pjit arguments with pytree key path x.*" + spec_regex(spec) + r".*"
r"implies that the global size of its dimension 0 should be "
r"divisible by " + mesh_size + r", but it is equal to 3 "
r"\(full shape: \(3, 2\)\)", re.M | re.S)
@ -3391,11 +3391,12 @@ class PJitErrorTest(jtu.JaxTestCase):
spec = P(resources, None)
mesh_size = str(math.prod([dim[1] for dim in mesh]))
error = re.compile(
r"One of pjit outputs.*" + spec_regex(spec) + r".*"
r"One of pjit outputs with pytree key path \['rrr'\].*" + spec_regex(spec) + r".*"
r"implies that the global size of its dimension 0 should be "
r"divisible by " + mesh_size + r", but it is equal to 3", re.M | re.S)
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x, in_shardings=None, out_shardings=P(resources, None))(x)
pjit(lambda x: {'rrr': x}, in_shardings=None,
out_shardings=P(resources, None))(x)
@check_1d_2d_mesh(set_mesh=False)
@jtu.with_mesh([('z', 1)])