Add a 2D test in memories_test.

PiperOrigin-RevId: 746295338
This commit is contained in:
jax authors 2025-04-10 21:32:07 -07:00
parent 9f5f6edb85
commit 7b7d36a8e6

View File

@ -791,6 +791,36 @@ class ComputeOffload(jtu.BufferDonationTestCase):
lowered_text = f.lower(jnp.arange(8)).as_text()
self.assertIn('_xla_compute_type', lowered_text)
@functools.partial(jax.jit, out_shardings=out_s)
def h(x):
y = g(x)
return y * 3
out2 = h(inp)
self.assertArraysEqual(out2, inp * 6)
self.assertEqual(out2.sharding.memory_kind, "pinned_host")
def test_compute_on_2d(self):
out_s = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")
@compute_on("device_host")
@jax.jit
def g(x):
return x * 2
@jax.jit
def f(x):
y = g(x)
return y * 3
inp = jnp.arange(9943.0)
inp = jnp.reshape(inp, (61, 163))
out = f(inp)
self.assertArraysEqual(out, inp * 6)
lowered_text = f.lower(inp).as_text()
self.assertIn("_xla_compute_type", lowered_text)
@functools.partial(jax.jit, out_shardings=out_s)
def h(x):
y = g(x)