diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index d23feb93c..03f7318c7 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1783,6 +1783,12 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None, cur_mesh = mesh_lib.get_abstract_mesh() if isinstance(sharding, PartitionSpec): + if cur_mesh.empty: + raise ValueError( + 'Using PartitionSpec when you are not under a mesh context via' + ' `jax.sharding.use_mesh` is not allowed. Please pass a' + ' NamedSharding instance or enter into a mesh context via' + f' `jax.sharding.use_mesh`. Got {sharding}') sharding = NamedSharding(cur_mesh, sharding) # type: ignore else: if (check_mesh_consistency and not cur_mesh.empty and diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 1923d5ae4..f693f67ae 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -6621,6 +6621,23 @@ class ShardingInTypesTest(jtu.JaxTestCase): s = NamedSharding(mesh, P(P.UNCONSTRAINED)) jax.lax.with_sharding_constraint(np.arange(8), s) + @config.sharding_in_types(True) + def test_pspec_einsum_no_context_mesh(self): + mesh = jtu.create_mesh((1, 1), ('x', 'y'), + axis_types={AxisTypes.Explicit: ('x', 'y')}) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', None))) + + @jax.jit + def f(x, y): + return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', 'y')) + + with self.assertRaisesRegex( + ValueError, + "Using PartitionSpec when.*not under a mesh context.*is not allowed"): + f(arr, arr2) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase):