mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Run single controller device_put via efficient reshard if the device_set of input and the sharding is the same. The transfer_guard in the test fails before this CL.
PiperOrigin-RevId: 648464214
This commit is contained in:
parent
9def0f1c00
commit
94ba6c3f98
@ -330,7 +330,7 @@ def _override_get_device_assignment(sharding, *args, **kwargs):
|
||||
def _identity_fn(x):
|
||||
return x
|
||||
|
||||
def _mcjax_reshard(x, target_sharding):
|
||||
def _different_device_order_reshard(x, target_sharding):
|
||||
from jax._src import api, array
|
||||
|
||||
inp_sharding = x.sharding
|
||||
@ -410,7 +410,14 @@ def _device_put_sharding_impl(x, aval, device):
|
||||
if (not s.is_fully_addressable and
|
||||
isinstance(x, array.ArrayImpl) and not x.is_fully_addressable):
|
||||
assert isinstance(s, Sharding)
|
||||
return _mcjax_reshard(x, s)
|
||||
return _different_device_order_reshard(x, s)
|
||||
|
||||
if (s.is_fully_addressable and isinstance(x, array.ArrayImpl) and
|
||||
x.is_fully_addressable and len(s.device_set) > 1 and
|
||||
s._internal_device_list != x.sharding._internal_device_list and # pytype: disable=attribute-error
|
||||
s.device_set == x.sharding.device_set):
|
||||
assert isinstance(s, Sharding)
|
||||
return _different_device_order_reshard(x, s)
|
||||
|
||||
if not s.is_fully_addressable:
|
||||
if ((isinstance(x, array.ArrayImpl) and not x._committed) or
|
||||
|
@ -39,7 +39,7 @@ from jax.errors import JAXTypeError
|
||||
from jax import lax
|
||||
from jax.lax import with_sharding_constraint
|
||||
from jax._src import prng
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax.sharding import PartitionSpec as P, Mesh
|
||||
from jax.experimental import multihost_utils
|
||||
from jax.experimental.custom_partitioning import custom_partitioning
|
||||
from jax._src import array
|
||||
@ -4220,6 +4220,26 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out2 = compiled2(jnp.arange(8))
|
||||
self.assertArraysEqual(out2, np.arange(8) * 2)
|
||||
|
||||
def test_device_put_efficient_reshard_single_host(self):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest('Requires >= 4 devices')
|
||||
|
||||
dev = jax.devices()
|
||||
mesh1 = Mesh(np.array([dev[0], dev[1], dev[2], dev[3]]).reshape(2, 2),
|
||||
('x', 'y'))
|
||||
mesh2 = Mesh(np.array([dev[3], dev[2], dev[1], dev[0]]).reshape(2, 2),
|
||||
('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s1 = NamedSharding(mesh1, P('x', 'y'))
|
||||
s2 = NamedSharding(mesh2, P('x'))
|
||||
|
||||
x_s1 = jax.device_put(np_inp, s1)
|
||||
|
||||
with jax.transfer_guard('disallow_explicit'):
|
||||
out = jax.device_put(x_s1, s2)
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
self.assertEqual(out.sharding, s2)
|
||||
|
||||
|
||||
def spec_regex(s):
|
||||
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
|
||||
|
Loading…
x
Reference in New Issue
Block a user