From 7b7d36a8e6105d4bbc4e7cb0b86f171bbf2c884b Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 10 Apr 2025 21:32:07 -0700 Subject: [PATCH] Add a 2D test in memories_test. PiperOrigin-RevId: 746295338 --- tests/memories_test.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/memories_test.py b/tests/memories_test.py index 570b0c375..278044eab 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -791,6 +791,36 @@ class ComputeOffload(jtu.BufferDonationTestCase): lowered_text = f.lower(jnp.arange(8)).as_text() self.assertIn('_xla_compute_type', lowered_text) + @functools.partial(jax.jit, out_shardings=out_s) + def h(x): + y = g(x) + return y * 3 + + out2 = h(inp) + self.assertArraysEqual(out2, inp * 6) + self.assertEqual(out2.sharding.memory_kind, "pinned_host") + + def test_compute_on_2d(self): + out_s = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host") + + @compute_on("device_host") + @jax.jit + def g(x): + return x * 2 + + @jax.jit + def f(x): + y = g(x) + return y * 3 + + inp = jnp.arange(9943.0) + inp = jnp.reshape(inp, (61, 163)) + out = f(inp) + self.assertArraysEqual(out, inp * 6) + + lowered_text = f.lower(inp).as_text() + self.assertIn("_xla_compute_type", lowered_text) + @functools.partial(jax.jit, out_shardings=out_s) def h(x): y = g(x)