mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[sharding_in_types] Error out when PartitionSpec is passed to APIs that take out_sharding
like einsum when context_mesh is unset.
This change is raising a better error because doing `NamedSharding(empty_mesh, P('x'))` will raise an error on construction but it is uglier than the current error added in this change. PiperOrigin-RevId: 726253654
This commit is contained in:
parent
876668faa1
commit
15cd83ae00
@ -1783,6 +1783,12 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
|
||||
|
||||
cur_mesh = mesh_lib.get_abstract_mesh()
|
||||
if isinstance(sharding, PartitionSpec):
|
||||
if cur_mesh.empty:
|
||||
raise ValueError(
|
||||
'Using PartitionSpec when you are not under a mesh context via'
|
||||
' `jax.sharding.use_mesh` is not allowed. Please pass a'
|
||||
' NamedSharding instance or enter into a mesh context via'
|
||||
f' `jax.sharding.use_mesh`. Got {sharding}')
|
||||
sharding = NamedSharding(cur_mesh, sharding) # type: ignore
|
||||
else:
|
||||
if (check_mesh_consistency and not cur_mesh.empty and
|
||||
|
@ -6621,6 +6621,23 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
s = NamedSharding(mesh, P(P.UNCONSTRAINED))
|
||||
jax.lax.with_sharding_constraint(np.arange(8), s)
|
||||
|
||||
@config.sharding_in_types(True)
|
||||
def test_pspec_einsum_no_context_mesh(self):
|
||||
mesh = jtu.create_mesh((1, 1), ('x', 'y'),
|
||||
axis_types={AxisTypes.Explicit: ('x', 'y')})
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
|
||||
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', None)))
|
||||
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', 'y'))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Using PartitionSpec when.*not under a mesh context.*is not allowed"):
|
||||
f(arr, arr2)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PJitErrorTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user