Make device_put resharding on single device array input work under use_mesh. Fixes https://github.com/jax-ml/jax/issues/26552

PiperOrigin-RevId: 728382461
This commit is contained in:
Yash Katariya 2025-02-18 15:22:06 -08:00 committed by jax authors
parent 00d8297071
commit 8bcbf585df
4 changed files with 46 additions and 3 deletions

View File

@@ -33,6 +33,7 @@ from jax._src import errors
from jax._src import profiler
from jax._src import util
from jax._src import xla_bridge
from jax._src.mesh import set_concrete_mesh
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
@@ -1158,7 +1159,10 @@ def shard_device_array(x, devices, indices, sharding):
if sharding.is_fully_replicated:
shards = [x] * len(devices)
else:
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
# TODO(yashkatariya): Maybe this should be set when we call the handler in
# InputsHandler.__call__?
with set_concrete_mesh(None):
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
aval = core.shaped_abstractify(x)
return pxla.batched_device_put(aval, sharding, shards, devices)

View File

@@ -550,7 +550,7 @@ def get_abstract_mesh():
@contextlib.contextmanager
def set_concrete_mesh(mesh: Mesh):
def set_concrete_mesh(mesh: Mesh | None):
prev_val = jax_config.device_context.swap_local(mesh)
try:
yield

View File

@@ -405,7 +405,6 @@ class LayoutTest(jtu.JaxTestCase):
self.assertArraysEqual(out, inp.T)
def test_device_put_user_concrete_layout(self):
shape = (8, 128)
np_inp = np.arange(math.prod(shape)).reshape(shape)
dll = DLL(major_to_minor=(1, 0))
@@ -416,6 +415,27 @@ class LayoutTest(jtu.JaxTestCase):
dll.major_to_minor)
self.assertArraysEqual(out, np_inp)
def test_device_put_user_concrete_layout_multi_device(self):
  """device_put honors a user-chosen concrete layout on a multi-device mesh.

  Exercises all four input/context combinations: an already-sharded array
  outside and inside `use_mesh`, plus committed (jnp) and host (np) inputs
  inside `use_mesh`. Every result must keep the requested major_to_minor
  order and round-trip the original values.
  """
  mesh = jtu.create_mesh((2, 2), ('x', 'y'))
  sharding = NamedSharding(mesh, P('x'))
  host_inp = np.arange(16 * 128).reshape((16, 128))
  dev_inp = jnp.arange(16 * 128).reshape((16, 128))
  sharded = jax.device_put(host_inp, sharding)

  # Layout with axes transposed relative to the default (0, 1) order.
  layout = Layout(DLL(major_to_minor=(0, 1)), sharding)
  results = [jax.device_put(sharded, layout)]
  with jax.sharding.use_mesh(mesh):
    results.append(jax.device_put(sharded, layout))
    results.append(jax.device_put(dev_inp, layout))
    results.append(jax.device_put(host_inp, layout))

  expected_m2m = layout.device_local_layout.major_to_minor
  for result in results:
    self.assertArraysEqual(result, host_inp)
    self.assertEqual(result.layout.device_local_layout.major_to_minor,
                     expected_m2m)
def test_concrete_layout_jit(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
shape = (16, 128)

View File

@@ -6670,6 +6670,25 @@ class ShardingInTypesTest(jtu.JaxTestCase):
"PartitionSpec cannot contain axis names that are.*Auto.*Manual"):
f(arr, arr2)
def test_device_put_under_use_mesh(self):
  """device_put works under `use_mesh` for every input kind.

  Regression test for https://github.com/jax-ml/jax/issues/26552: sharding a
  jax.Array, sharding a host numpy array, and resharding an already-sharded
  array to a different spec must all succeed inside a `use_mesh` context.
  """
  mesh = jtu.create_mesh((2, 2), ('x', 'y'))
  dev_zeros = jnp.zeros((4, 4), dtype=jnp.int32)
  host_zeros = np.zeros((4, 4), dtype=np.int32)
  sharding = NamedSharding(mesh, P('x', 'y'))
  with jax.sharding.use_mesh(mesh):
    # jax.Array input sharded onto the mesh.
    out = jax.device_put(dev_zeros, sharding)
    self.assertArraysEqual(out, dev_zeros)
    self.assertEqual(out.sharding, sharding)

    # Host numpy input sharded onto the mesh.
    out_np = jax.device_put(host_zeros, sharding)
    self.assertArraysEqual(out_np, host_zeros)
    self.assertEqual(out_np.sharding, sharding)

    # Resharding an already-sharded array to a different spec.
    row_sharding = NamedSharding(mesh, P('x'))
    resharded = jax.device_put(out, row_sharding)
    self.assertArraysEqual(resharded, dev_zeros)
    self.assertEqual(resharded.sharding, row_sharding)
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):