mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
[callback] Enable 64-bit types and add tests.
This takes advantage of a recent fix in XLA:TPU to enable 64-bit host transfers. PiperOrigin-RevId: 562890507
This commit is contained in:
parent
7224c24521
commit
f27816af30
@ -1935,9 +1935,6 @@ def send_to_host(channel: int, token: hlo.TokenType, operand: Any,
|
||||
channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
|
||||
send_op = hlo.SendOp([operand], token, channel_handle,
|
||||
is_host_transfer=ir.BoolAttr.get(True))
|
||||
dtype_str = _dtype_to_xla_type_string(aval.dtype)
|
||||
if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
|
||||
raise NotImplementedError("64-bit types not supported.")
|
||||
send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
|
||||
dict(
|
||||
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
|
||||
@ -1954,9 +1951,6 @@ def receive_from_host(channel: int, token: hlo.TokenType,
|
||||
recv_op = hlo.RecvOp([aval_to_ir_type(out_aval),
|
||||
hlo.TokenType.get()], token, channel_handle,
|
||||
is_host_transfer=ir.BoolAttr.get(True))
|
||||
dtype_str = _dtype_to_xla_type_string(out_aval.dtype)
|
||||
if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
|
||||
raise NotImplementedError("64-bit types not supported.")
|
||||
recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
|
||||
dict(
|
||||
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
|
||||
|
@ -90,6 +90,41 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
out = f(0.)
|
||||
self.assertEqual(out, 1.)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name=f"{flavor}_{dtype}",
|
||||
dtype=dtype,
|
||||
callback=dict(io_unordered=io_calback_unordered,
|
||||
io_ordered=io_callback_ordered,
|
||||
pure=jax.pure_callback)[flavor])
|
||||
for flavor in ("io_unordered", "io_ordered", "pure")
|
||||
for dtype in jtu.dtypes.all
|
||||
)
|
||||
def test_callback_works_with_all_types(self, *, callback, dtype):
|
||||
def host_func(x):
|
||||
if dtype == np.bool_:
|
||||
return ~ x
|
||||
else:
|
||||
return x + x
|
||||
_received = None
|
||||
def _cb(x):
|
||||
nonlocal _received
|
||||
_received = x
|
||||
return host_func(x)
|
||||
|
||||
if dtype == np.bool_:
|
||||
x = np.array([True, False, True, True], dtype=np.bool_)
|
||||
else:
|
||||
x = np.arange(4, dtype=dtype)
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return callback(_cb,
|
||||
core.ShapedArray(x.shape, x.dtype), x)
|
||||
|
||||
out = f(x)
|
||||
self.assertAllClose(out, host_func(x))
|
||||
jax.effects_barrier()
|
||||
self.assertAllClose(_received, x)
|
||||
|
||||
@with_pure_and_io_callbacks
|
||||
def test_callback_with_wrong_number_of_args(self, *, callback):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user