Raise a better error message if None is passed to with_sharding_constraint.

PiperOrigin-RevId: 690672618
This commit is contained in:
Yash Katariya 2024-10-28 10:49:06 -07:00 committed by jax authors
parent bc03e5053a
commit 987dfaef1c
2 changed files with 14 additions and 1 deletions

View File

@ -2449,12 +2449,18 @@ def with_sharding_constraint(x, shardings):
shardings_flat = [_create_sharding_for_array(mesh, a, 'shardings',
'with_sharding_constraint')
for a in user_shardings_flat]
for s, u in zip(shardings_flat, user_shardings_flat):
if isinstance(s, (UnspecifiedValue, AUTO)):
raise ValueError(
f'One of with_sharding_constraint arguments got sharding {u} which is'
' not allowed. Please only pass `jax.sharding.Sharding` instances.')
del user_shardings_flat
# TODO(bartchr): remove `unconstrained_dims` after migrating to Shardy. It's
# already part of the shardings.
unconstrained_dims = [get_unconstrained_dims(s)
if isinstance(s, NamedSharding) else {}
for s in shardings_flat]
del user_shardings_flat
pjit_check_aval_sharding(
shardings_flat, x_flat, None, "with_sharding_constraint arguments",

View File

@ -3544,6 +3544,13 @@ class ArrayPjitTest(jtu.JaxTestCase):
out2 = pjit(identity)(arr2)
self.assertIsInstance(out2.sharding, PositionalSharding)
def test_wsc_error_on_none(self):
with self.assertRaisesRegex(
ValueError,
'One of with_sharding_constraint arguments got sharding None which is'
' not allowed'):
with_sharding_constraint(jnp.arange(8), None)
def test_sharding_preserved_aot(self):
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))