mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improves error message in case of invalid sharding mesh
PiperOrigin-RevId: 661358450
This commit is contained in:
parent
aa334145b4
commit
3bd3597703
@ -53,18 +53,17 @@ class TransferToMemoryKind:
|
||||
|
||||
@util.cache(max_size=128, trace_context_in_key=False)
|
||||
def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
|
||||
try:
|
||||
for p in parsed_pspec:
|
||||
if p is not None:
|
||||
for r in p:
|
||||
mesh.shape[r]
|
||||
if r in _manual_axes:
|
||||
raise ValueError(
|
||||
f"Axis: {r} of {parsed_pspec.get_partition_spec()} "
|
||||
f"is also found in manual_axes: {_manual_axes}.") from None
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Resource axis: {e.args[0]} of {parsed_pspec.user_spec} is "
|
||||
"undefined.") from None
|
||||
for p in parsed_pspec:
|
||||
if p is not None:
|
||||
for r in p:
|
||||
if r not in mesh.shape:
|
||||
raise ValueError(
|
||||
f"Resource axis: {r} of {parsed_pspec.get_partition_spec()} "
|
||||
f"is not found in mesh: {tuple(mesh.shape.keys())}.")
|
||||
if r in _manual_axes:
|
||||
raise ValueError(
|
||||
f"Axis: {r} of {parsed_pspec.get_partition_spec()} "
|
||||
f"is also found in manual_axes: {_manual_axes}.") from None
|
||||
|
||||
|
||||
def hashed_index(x) -> int:
|
||||
|
@ -4398,7 +4398,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
spec = P(resources,)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"):
|
||||
r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."):
|
||||
pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x)
|
||||
|
||||
@check_1d_2d_mesh(set_mesh=False)
|
||||
@ -4408,7 +4408,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
spec = P(resources,)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"):
|
||||
r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."):
|
||||
pjit(lambda x: x, in_shardings=None, out_shardings=spec)(x)
|
||||
|
||||
@check_1d_2d_mesh(set_mesh=False)
|
||||
@ -4418,7 +4418,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
spec = P(resources,)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"):
|
||||
r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."):
|
||||
pjit(
|
||||
lambda x: with_sharding_constraint(x, spec),
|
||||
in_shardings=None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user