mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add identity jit tests to go from pinned_host -> device and vice versa
PiperOrigin-RevId: 620114420
This commit is contained in:
parent
c846233089
commit
84156f359f
@ -1456,9 +1456,8 @@ def lower_jaxpr_to_fun(
|
||||
# Insert a custom call if output is on host because XLA needs that to do the
|
||||
# transfer.
|
||||
if ir_result_memory_kinds is not None:
|
||||
# TODO: We should have a default memory kind which we can check against.
|
||||
flat_outputs = [
|
||||
o if mk is None or mk == 'device' else wrap_with_memory_kind(o, mk, o_aval)
|
||||
o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
|
||||
for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]
|
||||
|
||||
if ir_result_shardings is not None and name == "main":
|
||||
|
@ -1117,6 +1117,26 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
self._check_device_put_addressable_shards(
|
||||
out2, np_inp * np_inp * 2, s_host, 'pinned_host')
|
||||
|
||||
def test_identity_jit_host_to_device_and_vice_versa(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s_host = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
|
||||
s_dev = s_host.with_memory_kind('device')
|
||||
arr_host = jax.device_put(np_inp, s_host)
|
||||
arr_dev = jax.device_put(np_inp, s_dev)
|
||||
|
||||
# pinned_host -> device
|
||||
f = jax.jit(lambda x: x, out_shardings=s_dev)
|
||||
out_dev = f(arr_host)
|
||||
self.assertArraysEqual(out_dev, np_inp)
|
||||
self.assertEqual(out_dev.sharding, s_dev)
|
||||
|
||||
# device -> pinned_host
|
||||
g = jax.jit(lambda x: x, out_shardings=s_host)
|
||||
out_host = g(arr_dev)
|
||||
self.assertArraysEqual(out_host, np_inp)
|
||||
self.assertEqual(out_host.sharding, s_host)
|
||||
|
||||
|
||||
class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user