mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Raise a better error message if None
is passed to with_sharding_constraint.
PiperOrigin-RevId: 690672618
This commit is contained in:
parent
bc03e5053a
commit
987dfaef1c
@ -2449,12 +2449,18 @@ def with_sharding_constraint(x, shardings):
|
||||
shardings_flat = [_create_sharding_for_array(mesh, a, 'shardings',
|
||||
'with_sharding_constraint')
|
||||
for a in user_shardings_flat]
|
||||
for s, u in zip(shardings_flat, user_shardings_flat):
|
||||
if isinstance(s, (UnspecifiedValue, AUTO)):
|
||||
raise ValueError(
|
||||
f'One of with_sharding_constraint arguments got sharding {u} which is'
|
||||
' not allowed. Please only pass `jax.sharding.Sharding` instances.')
|
||||
del user_shardings_flat
|
||||
|
||||
# TODO(bartchr): remove `unconstrained_dims` after migrating to Shardy. It's
|
||||
# already part of the shardings.
|
||||
unconstrained_dims = [get_unconstrained_dims(s)
|
||||
if isinstance(s, NamedSharding) else {}
|
||||
for s in shardings_flat]
|
||||
del user_shardings_flat
|
||||
|
||||
pjit_check_aval_sharding(
|
||||
shardings_flat, x_flat, None, "with_sharding_constraint arguments",
|
||||
|
@ -3544,6 +3544,13 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out2 = pjit(identity)(arr2)
|
||||
self.assertIsInstance(out2.sharding, PositionalSharding)
|
||||
|
||||
def test_wsc_error_on_none(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'One of with_sharding_constraint arguments got sharding None which is'
|
||||
' not allowed'):
|
||||
with_sharding_constraint(jnp.arange(8), None)
|
||||
|
||||
def test_sharding_preserved_aot(self):
|
||||
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
ns = NamedSharding(mesh, P('x'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user