Add a fast path MPMD device_put test.

This checks that if the devices in a sharding are different, we still take `xc.copy_array_to_devices_with_sharding` path.

This is to prevent changes to shard_arg handler of Array that checks for devices instead of indices.

PiperOrigin-RevId: 609760758
This commit is contained in:
Yash Katariya 2024-02-23 10:02:19 -08:00 committed by jax authors
parent c9eaca2282
commit 4e61c8856b
2 changed files with 38 additions and 0 deletions

View File

@ -49,6 +49,7 @@ from jax._src import linear_util as lu
from jax._src import dtypes as _dtypes
from jax._src import monitoring
from jax._src import stages
from jax._src.lib import xla_client as xc
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
from jax._src.interpreters import pxla
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
@ -224,6 +225,22 @@ def count_primitive_compiles():
count[0] = dispatch.xla_primitive_callable.cache_info().misses
@contextmanager
def count_device_put_fast_path_hit():
original_fn = xc.copy_array_to_devices_with_sharding
count = [0]
def copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
count[0] += 1
return original_fn(*args, **kwargs)
xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count
try:
yield count
finally:
xc.copy_array_to_devices_with_sharding = original_fn
@contextmanager
def count_pjit_cpp_cache_miss():
original_pjit_lower = pjit_lib._pjit_lower

View File

@ -3816,6 +3816,27 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(out.sharding, s)
self.assertArraysEqual(out, np_inp)
def test_mpmd_device_put_fast_path(self):
if jax.device_count() < 4:
self.skipTest('Needs >= 4 devices')
dev_count = jax.device_count()
mesh1 = jax.sharding.Mesh(jax.devices()[:dev_count//2], 'x')
mesh2 = jax.sharding.Mesh(jax.devices()[dev_count//2:], 'x')
inp = np.arange(8)
arr1 = jax.device_put(inp, NamedSharding(mesh1, P('x')))
# This is to prevent changes to shard_arg_handler of Array which checks for
# indices to take the fast path for resharding. Changes made to the handler
# to check for shardings instead of indices will cause this test to fail and
# that is expected.
with jtu.count_device_put_fast_path_hit() as count:
out = jax.device_put(arr1, NamedSharding(mesh2, P('x')))
self.assertEqual(count[0], 1)
self.assertTupleEqual(out.sharding._device_assignment,
mesh2._flat_devices_tuple)
self.assertArraysEqual(out, inp)
class TempSharding(Sharding):