mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add a fast path MPMD device_put test.
This checks that if the devices in a sharding are different, we still take `xc.copy_array_to_devices_with_sharding` path. This is to prevent changes to shard_arg handler of Array that checks for devices instead of indices. PiperOrigin-RevId: 609760758
This commit is contained in:
parent
c9eaca2282
commit
4e61c8856b
@ -49,6 +49,7 @@ from jax._src import linear_util as lu
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax._src import monitoring
|
||||
from jax._src import stages
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
|
||||
@ -224,6 +225,22 @@ def count_primitive_compiles():
|
||||
count[0] = dispatch.xla_primitive_callable.cache_info().misses
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_device_put_fast_path_hit():
|
||||
original_fn = xc.copy_array_to_devices_with_sharding
|
||||
count = [0]
|
||||
|
||||
def copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return original_fn(*args, **kwargs)
|
||||
|
||||
xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
xc.copy_array_to_devices_with_sharding = original_fn
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_pjit_cpp_cache_miss():
|
||||
original_pjit_lower = pjit_lib._pjit_lower
|
||||
|
@ -3816,6 +3816,27 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.sharding, s)
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
|
||||
def test_mpmd_device_put_fast_path(self):
|
||||
if jax.device_count() < 4:
|
||||
self.skipTest('Needs >= 4 devices')
|
||||
|
||||
dev_count = jax.device_count()
|
||||
mesh1 = jax.sharding.Mesh(jax.devices()[:dev_count//2], 'x')
|
||||
mesh2 = jax.sharding.Mesh(jax.devices()[dev_count//2:], 'x')
|
||||
inp = np.arange(8)
|
||||
arr1 = jax.device_put(inp, NamedSharding(mesh1, P('x')))
|
||||
|
||||
# This is to prevent changes to shard_arg_handler of Array which checks for
|
||||
# indices to take the fast path for resharding. Changes made to the handler
|
||||
# to check for shardings instead of indices will cause this test to fail and
|
||||
# that is expected.
|
||||
with jtu.count_device_put_fast_path_hit() as count:
|
||||
out = jax.device_put(arr1, NamedSharding(mesh2, P('x')))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertTupleEqual(out.sharding._device_assignment,
|
||||
mesh2._flat_devices_tuple)
|
||||
self.assertArraysEqual(out, inp)
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user