From 84156f359fd1bf05ccb75d39a00b86ddd69f5475 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 28 Mar 2024 18:19:25 -0700 Subject: [PATCH] Add identity jit tests to go from pinned_host -> device and vice versa PiperOrigin-RevId: 620114420 --- jax/_src/interpreters/mlir.py | 3 +-- tests/memories_test.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 23edc58fd..173c8bf86 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1456,9 +1456,8 @@ def lower_jaxpr_to_fun( # Insert a custom call if output is on host because XLA needs that to do the # transfer. if ir_result_memory_kinds is not None: - # TODO: We should have a default memory kind which we can check against. flat_outputs = [ - o if mk is None or mk == 'device' else wrap_with_memory_kind(o, mk, o_aval) + o if mk is None else wrap_with_memory_kind(o, mk, o_aval) for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)] if ir_result_shardings is not None and name == "main": diff --git a/tests/memories_test.py b/tests/memories_test.py index c56a9730d..efa42fdef 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1117,6 +1117,26 @@ class DevicePutTest(jtu.JaxTestCase): self._check_device_put_addressable_shards( out2, np_inp * np_inp * 2, s_host, 'pinned_host') + def test_identity_jit_host_to_device_and_vice_versa(self): + mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + np_inp = np.arange(16).reshape(8, 2) + s_host = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') + s_dev = s_host.with_memory_kind('device') + arr_host = jax.device_put(np_inp, s_host) + arr_dev = jax.device_put(np_inp, s_dev) + + # pinned_host -> device + f = jax.jit(lambda x: x, out_shardings=s_dev) + out_dev = f(arr_host) + self.assertArraysEqual(out_dev, np_inp) + self.assertEqual(out_dev.sharding, s_dev) + + # device -> pinned_host + g = jax.jit(lambda x: x, out_shardings=s_host) + out_host = g(arr_dev) + self.assertArraysEqual(out_host, np_inp) + self.assertEqual(out_host.sharding, s_host) + class ActivationOffloadingTest(jtu.JaxTestCase):