Run single controller device_put via efficient reshard if the device_set of input and the sharding is the same. The transfer_guard in the test fails before this CL.

PiperOrigin-RevId: 648464214
This commit is contained in:
Yash Katariya 2024-07-01 13:13:53 -07:00 committed by jax authors
parent 9def0f1c00
commit 94ba6c3f98
2 changed files with 30 additions and 3 deletions

View File

@ -330,7 +330,7 @@ def _override_get_device_assignment(sharding, *args, **kwargs):
def _identity_fn(x):
return x
def _mcjax_reshard(x, target_sharding):
def _different_device_order_reshard(x, target_sharding):
from jax._src import api, array
inp_sharding = x.sharding
@ -410,7 +410,14 @@ def _device_put_sharding_impl(x, aval, device):
if (not s.is_fully_addressable and
isinstance(x, array.ArrayImpl) and not x.is_fully_addressable):
assert isinstance(s, Sharding)
return _mcjax_reshard(x, s)
return _different_device_order_reshard(x, s)
if (s.is_fully_addressable and isinstance(x, array.ArrayImpl) and
x.is_fully_addressable and len(s.device_set) > 1 and
s._internal_device_list != x.sharding._internal_device_list and # pytype: disable=attribute-error
s.device_set == x.sharding.device_set):
assert isinstance(s, Sharding)
return _different_device_order_reshard(x, s)
if not s.is_fully_addressable:
if ((isinstance(x, array.ArrayImpl) and not x._committed) or

View File

@ -39,7 +39,7 @@ from jax.errors import JAXTypeError
from jax import lax
from jax.lax import with_sharding_constraint
from jax._src import prng
from jax.sharding import PartitionSpec as P
from jax.sharding import PartitionSpec as P, Mesh
from jax.experimental import multihost_utils
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import array
@ -4220,6 +4220,26 @@ class ArrayPjitTest(jtu.JaxTestCase):
out2 = compiled2(jnp.arange(8))
self.assertArraysEqual(out2, np.arange(8) * 2)
def test_device_put_efficient_reshard_single_host(self):
if jax.device_count() < 4:
self.skipTest('Requires >= 4 devices')
dev = jax.devices()
mesh1 = Mesh(np.array([dev[0], dev[1], dev[2], dev[3]]).reshape(2, 2),
('x', 'y'))
mesh2 = Mesh(np.array([dev[3], dev[2], dev[1], dev[0]]).reshape(2, 2),
('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s1 = NamedSharding(mesh1, P('x', 'y'))
s2 = NamedSharding(mesh2, P('x'))
x_s1 = jax.device_put(np_inp, s1)
with jax.transfer_guard('disallow_explicit'):
out = jax.device_put(x_s1, s2)
self.assertArraysEqual(out, np_inp)
self.assertEqual(out.sharding, s2)
def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")