[callback] Enable 64-bit types and add tests.

This takes advantage of a recent fix in XLA:TPU to enable
64-bit host transfers.

PiperOrigin-RevId: 562890507
This commit is contained in:
George Necula 2023-09-05 14:22:44 -07:00 committed by jax authors
parent 7224c24521
commit f27816af30
2 changed files with 35 additions and 6 deletions

View File

@ -1935,9 +1935,6 @@ def send_to_host(channel: int, token: hlo.TokenType, operand: Any,
channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
send_op = hlo.SendOp([operand], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
dtype_str = _dtype_to_xla_type_string(aval.dtype)
if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
raise NotImplementedError("64-bit types not supported.")
send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
dict(
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
@ -1954,9 +1951,6 @@ def receive_from_host(channel: int, token: hlo.TokenType,
recv_op = hlo.RecvOp([aval_to_ir_type(out_aval),
hlo.TokenType.get()], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
dtype_str = _dtype_to_xla_type_string(out_aval.dtype)
if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
raise NotImplementedError("64-bit types not supported.")
recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
dict(
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),

View File

@ -90,6 +90,41 @@ class PythonCallbackTest(jtu.JaxTestCase):
out = f(0.)
self.assertEqual(out, 1.)
@parameterized.named_parameters(
dict(testcase_name=f"{flavor}_{dtype}",
dtype=dtype,
callback=dict(io_unordered=io_calback_unordered,
io_ordered=io_callback_ordered,
pure=jax.pure_callback)[flavor])
for flavor in ("io_unordered", "io_ordered", "pure")
for dtype in jtu.dtypes.all
)
def test_callback_works_with_all_types(self, *, callback, dtype):
def host_func(x):
if dtype == np.bool_:
return ~ x
else:
return x + x
_received = None
def _cb(x):
nonlocal _received
_received = x
return host_func(x)
if dtype == np.bool_:
x = np.array([True, False, True, True], dtype=np.bool_)
else:
x = np.arange(4, dtype=dtype)
@jax.jit
def f(x):
return callback(_cb,
core.ShapedArray(x.shape, x.dtype), x)
out = f(x)
self.assertAllClose(out, host_func(x))
jax.effects_barrier()
self.assertAllClose(_received, x)
@with_pure_and_io_callbacks
def test_callback_with_wrong_number_of_args(self, *, callback):