diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index cc226a678..1a0886192 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1935,9 +1935,6 @@ def send_to_host(channel: int, token: hlo.TokenType, operand: Any, channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE) send_op = hlo.SendOp([operand], token, channel_handle, is_host_transfer=ir.BoolAttr.get(True)) - dtype_str = _dtype_to_xla_type_string(aval.dtype) - if dtype_str in {"f64", "s64", "u64", "c64", "c128"}: - raise NotImplementedError("64-bit types not supported.") send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( dict( _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), @@ -1954,9 +1951,6 @@ def receive_from_host(channel: int, token: hlo.TokenType, recv_op = hlo.RecvOp([aval_to_ir_type(out_aval), hlo.TokenType.get()], token, channel_handle, is_host_transfer=ir.BoolAttr.get(True)) - dtype_str = _dtype_to_xla_type_string(out_aval.dtype) - if dtype_str in {"f64", "s64", "u64", "c64", "c128"}: - raise NotImplementedError("64-bit types not supported.") recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( dict( _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index eb377fa7d..15b1e9d7e 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -90,6 +90,41 @@ class PythonCallbackTest(jtu.JaxTestCase): out = f(0.) self.assertEqual(out, 1.) + @parameterized.named_parameters( + dict(testcase_name=f"{flavor}_{dtype}", + dtype=dtype, + callback=dict(io_unordered=io_calback_unordered, + io_ordered=io_callback_ordered, + pure=jax.pure_callback)[flavor]) + for flavor in ("io_unordered", "io_ordered", "pure") + for dtype in jtu.dtypes.all + ) + def test_callback_works_with_all_types(self, *, callback, dtype): + def host_func(x): + if dtype == np.bool_: + return ~ x + else: + return x + x + _received = None + def _cb(x): + nonlocal _received + _received = x + return host_func(x) + + if dtype == np.bool_: + x = np.array([True, False, True, True], dtype=np.bool_) + else: + x = np.arange(4, dtype=dtype) + @jax.jit + def f(x): + return callback(_cb, + core.ShapedArray(x.shape, x.dtype), x) + + out = f(x) + self.assertAllClose(out, host_func(x)) + jax.effects_barrier() + self.assertAllClose(_received, x) + @with_pure_and_io_callbacks def test_callback_with_wrong_number_of_args(self, *, callback):