Add identity jit tests to go from pinned_host -> device and vice versa

PiperOrigin-RevId: 620114420
This commit is contained in:
Yash Katariya 2024-03-28 18:19:25 -07:00 committed by jax authors
parent c846233089
commit 84156f359f
2 changed files with 21 additions and 2 deletions

View File

@ -1456,9 +1456,8 @@ def lower_jaxpr_to_fun(
# Insert a custom call if output is on host because XLA needs that to do the
# transfer.
if ir_result_memory_kinds is not None:
# TODO: We should have a default memory kind which we can check against.
flat_outputs = [
o if mk is None or mk == 'device' else wrap_with_memory_kind(o, mk, o_aval)
o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]
if ir_result_shardings is not None and name == "main":

View File

@ -1117,6 +1117,26 @@ class DevicePutTest(jtu.JaxTestCase):
self._check_device_put_addressable_shards(
out2, np_inp * np_inp * 2, s_host, 'pinned_host')
def test_identity_jit_host_to_device_and_vice_versa(self):
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
np_inp = np.arange(16).reshape(8, 2)
s_host = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
s_dev = s_host.with_memory_kind('device')
arr_host = jax.device_put(np_inp, s_host)
arr_dev = jax.device_put(np_inp, s_dev)
# pinned_host -> device
f = jax.jit(lambda x: x, out_shardings=s_dev)
out_dev = f(arr_host)
self.assertArraysEqual(out_dev, np_inp)
self.assertEqual(out_dev.sharding, s_dev)
# device -> pinned_host
g = jax.jit(lambda x: x, out_shardings=s_host)
out_host = g(arr_dev)
self.assertArraysEqual(out_host, np_inp)
self.assertEqual(out_host.sharding, s_host)
class ActivationOffloadingTest(jtu.JaxTestCase):