Improves error message in case of invalid sharding mesh

PiperOrigin-RevId: 661358450
This commit is contained in:
jax authors 2024-08-09 12:17:28 -07:00 committed by jax authors
parent aa334145b4
commit 3bd3597703
2 changed files with 14 additions and 15 deletions

View File

@ -53,18 +53,17 @@ class TransferToMemoryKind:
@util.cache(max_size=128, trace_context_in_key=False)
def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
try:
for p in parsed_pspec:
if p is not None:
for r in p:
mesh.shape[r]
if r in _manual_axes:
raise ValueError(
f"Axis: {r} of {parsed_pspec.get_partition_spec()} "
f"is also found in manual_axes: {_manual_axes}.") from None
except KeyError as e:
raise ValueError(f"Resource axis: {e.args[0]} of {parsed_pspec.user_spec} is "
"undefined.") from None
for p in parsed_pspec:
if p is not None:
for r in p:
if r not in mesh.shape:
raise ValueError(
f"Resource axis: {r} of {parsed_pspec.get_partition_spec()} "
f"is not found in mesh: {tuple(mesh.shape.keys())}.")
if r in _manual_axes:
raise ValueError(
f"Axis: {r} of {parsed_pspec.get_partition_spec()} "
f"is also found in manual_axes: {_manual_axes}.") from None
def hashed_index(x) -> int:

View File

@ -4398,7 +4398,7 @@ class PJitErrorTest(jtu.JaxTestCase):
spec = P(resources,)
with self.assertRaisesRegex(
ValueError,
r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"):
r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."):
pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x)
@check_1d_2d_mesh(set_mesh=False)
@ -4408,7 +4408,7 @@ class PJitErrorTest(jtu.JaxTestCase):
spec = P(resources,)
with self.assertRaisesRegex(
ValueError,
r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"):
r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."):
pjit(lambda x: x, in_shardings=None, out_shardings=spec)(x)
@check_1d_2d_mesh(set_mesh=False)
@ -4418,7 +4418,7 @@ class PJitErrorTest(jtu.JaxTestCase):
spec = P(resources,)
with self.assertRaisesRegex(
ValueError,
r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"):
r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."):
pjit(
lambda x: with_sharding_constraint(x, spec),
in_shardings=None,