Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 04:46:06 +00:00
Make device_put resharding on single device array input work under use_mesh. Fixes https://github.com/jax-ml/jax/issues/26552
PiperOrigin-RevId: 728382461
This commit is contained in:
parent 00d8297071
commit 8bcbf585df
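
For context, a minimal sketch of the user-facing behavior this commit fixes, mirroring the new test_device_put_under_use_mesh added below. The mesh shape, axis names, dtypes, and the assumption of at least four devices are illustrative, not part of the commit:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 2x2 mesh from the first four devices (illustrative).
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
x = jnp.zeros((4, 4), dtype=jnp.int32)  # a committed single-device array

with jax.sharding.use_mesh(mesh):
  # Before this change, resharding an existing array via device_put under
  # use_mesh could fail (https://github.com/jax-ml/jax/issues/26552).
  y = jax.device_put(x, s)
  z = jax.device_put(y, NamedSharding(mesh, P('x')))  # reshard again
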
@@ -33,6 +33,7 @@ from jax._src import errors
 from jax._src import profiler
 from jax._src import util
 from jax._src import xla_bridge
+from jax._src.mesh import set_concrete_mesh
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
@@ -1158,7 +1159,10 @@ def shard_device_array(x, devices, indices, sharding):
   if sharding.is_fully_replicated:
     shards = [x] * len(devices)
   else:
-    shards = x._multi_slice(start_indices, limit_indices, removed_dims)
+    # TODO(yashkatariya): Maybe this should be set when we call the handler in
+    # InputsHandler.__call__?
+    with set_concrete_mesh(None):
+      shards = x._multi_slice(start_indices, limit_indices, removed_dims)
   aval = core.shaped_abstractify(x)
   return pxla.batched_device_put(aval, sharding, shards, devices)
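
The core of the fix is the with set_concrete_mesh(None): wrapper shown above: the per-device slicing now runs with the ambient concrete mesh cleared. As a rough, numpy-level sketch of what this sharding step computes (illustrative only; the real code uses x._multi_slice and pxla.batched_device_put, and shard_like here is a hypothetical stand-in):

import numpy as np

def shard_like(x, devices, indices, fully_replicated):
  # One shard per device: either the whole array (replicated) or a slice.
  if fully_replicated:
    return [x] * len(devices)
  return [x[idx] for idx in indices]  # idx is a per-device tuple of slices

x = np.arange(8).reshape(4, 2)
halves = shard_like(x, devices=[0, 1],
                    indices=[(slice(0, 2),), (slice(2, 4),)],
                    fully_replicated=False)  # two row-blocks of x
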
@@ -550,7 +550,7 @@ def get_abstract_mesh():


 @contextlib.contextmanager
-def set_concrete_mesh(mesh: Mesh):
+def set_concrete_mesh(mesh: Mesh | None):
   prev_val = jax_config.device_context.swap_local(mesh)
   try:
     yield
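
The signature is widened from Mesh to Mesh | None so that callers like the one above can temporarily clear the concrete mesh. The body follows the usual swap-and-restore contextmanager pattern; below is a generic, self-contained sketch of that pattern (the restore step is not visible in the hunk and is assumed to live in a finally block; _DeviceContext and set_value are hypothetical stand-ins for jax's internal config state):

import contextlib

class _DeviceContext:
  """Hypothetical stand-in for a thread-local config value."""
  def __init__(self):
    self.value = None
  def swap_local(self, new):
    prev, self.value = self.value, new
    return prev

device_context = _DeviceContext()

@contextlib.contextmanager
def set_value(new):
  prev = device_context.swap_local(new)  # install the new value
  try:
    yield
  finally:
    device_context.swap_local(prev)      # restore, even on exception
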
@@ -405,7 +405,6 @@ class LayoutTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, inp.T)

   def test_device_put_user_concrete_layout(self):
     shape = (8, 128)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
     dll = DLL(major_to_minor=(1, 0))
@@ -416,6 +415,27 @@ class LayoutTest(jtu.JaxTestCase):
                      dll.major_to_minor)
     self.assertArraysEqual(out, np_inp)

+  def test_device_put_user_concrete_layout_multi_device(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    shape = (16, 128)
+    s = NamedSharding(mesh, P('x'))
+    np_inp = np.arange(math.prod(shape)).reshape(shape)
+    jnp_inp = jnp.arange(math.prod(shape)).reshape(shape)
+    arr = jax.device_put(np_inp, s)
+
+    custom_layout = Layout(DLL(major_to_minor=(0, 1)), s)
+    out1 = jax.device_put(arr, custom_layout)
+
+    with jax.sharding.use_mesh(mesh):
+      out2 = jax.device_put(arr, custom_layout)
+      out3 = jax.device_put(jnp_inp, custom_layout)
+      out4 = jax.device_put(np_inp, custom_layout)
+
+    for o in [out1, out2, out3, out4]:
+      self.assertArraysEqual(o, np_inp)
+      self.assertEqual(o.layout.device_local_layout.major_to_minor,
+                       custom_layout.device_local_layout.major_to_minor)
+
   def test_concrete_layout_jit(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     shape = (16, 128)
@@ -6670,6 +6670,25 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         "PartitionSpec cannot contain axis names that are.*Auto.*Manual"):
       f(arr, arr2)

+  def test_device_put_under_use_mesh(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    x = jnp.zeros((4, 4), dtype=jnp.int32)
+    x_np = np.zeros((4, 4), dtype=np.int32)
+    s = NamedSharding(mesh, P('x', 'y'))
+    with jax.sharding.use_mesh(mesh):
+      y = jax.device_put(x, s)
+      self.assertArraysEqual(y, x)
+      self.assertEqual(y.sharding, s)
+
+      y2 = jax.device_put(x_np, s)
+      self.assertArraysEqual(y2, x_np)
+      self.assertEqual(y2.sharding, s)
+
+      s2 = NamedSharding(mesh, P('x'))
+      z = jax.device_put(y, s2)
+      self.assertArraysEqual(z, x)
+      self.assertEqual(z.sharding, s2)
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):