mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add a 2D test in memories_test.
PiperOrigin-RevId: 746295338
This commit is contained in:
parent
9f5f6edb85
commit
7b7d36a8e6
@ -791,6 +791,36 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
lowered_text = f.lower(jnp.arange(8)).as_text()
|
||||
self.assertIn('_xla_compute_type', lowered_text)
|
||||
|
||||
@functools.partial(jax.jit, out_shardings=out_s)
|
||||
def h(x):
|
||||
y = g(x)
|
||||
return y * 3
|
||||
|
||||
out2 = h(inp)
|
||||
self.assertArraysEqual(out2, inp * 6)
|
||||
self.assertEqual(out2.sharding.memory_kind, "pinned_host")
|
||||
|
||||
def test_compute_on_2d(self):
|
||||
out_s = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")
|
||||
|
||||
@compute_on("device_host")
|
||||
@jax.jit
|
||||
def g(x):
|
||||
return x * 2
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = g(x)
|
||||
return y * 3
|
||||
|
||||
inp = jnp.arange(9943.0)
|
||||
inp = jnp.reshape(inp, (61, 163))
|
||||
out = f(inp)
|
||||
self.assertArraysEqual(out, inp * 6)
|
||||
|
||||
lowered_text = f.lower(inp).as_text()
|
||||
self.assertIn("_xla_compute_type", lowered_text)
|
||||
|
||||
@functools.partial(jax.jit, out_shardings=out_s)
|
||||
def h(x):
|
||||
y = g(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user