[sharding_in_types] Error out when PartitionSpec is passed to APIs that take out_sharding like einsum when context_mesh is unset.

This change is raising a better error because doing `NamedSharding(empty_mesh, P('x'))` will raise an error on construction but it is uglier than the current error added in this change.

PiperOrigin-RevId: 726253654
This commit is contained in:
Yash Katariya 2025-02-12 17:12:12 -08:00 committed by jax authors
parent 876668faa1
commit 15cd83ae00
2 changed files with 23 additions and 0 deletions

View File

@ -1783,6 +1783,12 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
cur_mesh = mesh_lib.get_abstract_mesh()
if isinstance(sharding, PartitionSpec):
if cur_mesh.empty:
raise ValueError(
'Using PartitionSpec when you are not under a mesh context via'
' `jax.sharding.use_mesh` is not allowed. Please pass a'
' NamedSharding instance or enter into a mesh context via'
f' `jax.sharding.use_mesh`. Got {sharding}')
sharding = NamedSharding(cur_mesh, sharding) # type: ignore
else:
if (check_mesh_consistency and not cur_mesh.empty and

View File

@ -6621,6 +6621,23 @@ class ShardingInTypesTest(jtu.JaxTestCase):
s = NamedSharding(mesh, P(P.UNCONSTRAINED))
jax.lax.with_sharding_constraint(np.arange(8), s)
@config.sharding_in_types(True)
def test_pspec_einsum_no_context_mesh(self):
mesh = jtu.create_mesh((1, 1), ('x', 'y'),
axis_types={AxisTypes.Explicit: ('x', 'y')})
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', None)))
@jax.jit
def f(x, y):
return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', 'y'))
with self.assertRaisesRegex(
ValueError,
"Using PartitionSpec when.*not under a mesh context.*is not allowed"):
f(arr, arr2)
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):