diff --git a/jax/_src/callback.py b/jax/_src/callback.py index ef9981ee5..375a89a5b 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -194,7 +194,7 @@ def pure_callback_lowering( ) op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding) - result, _, keepalive = mlir.emit_python_callback( + result, _, _ = mlir.emit_python_callback( ctx, _callback, None, @@ -204,7 +204,6 @@ def pure_callback_lowering( False, sharding=op_sharding, ) - ctx.module_context.add_keepalive(keepalive) return result @@ -435,7 +434,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params): op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding) if ordered: token = ctx.tokens_in.get(_OrderedIOEffect)[0] - result, token, keepalive = mlir.emit_python_callback( + result, token, _ = mlir.emit_python_callback( ctx, _callback, token, @@ -447,7 +446,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params): ) ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: (token,)})) else: - result, token, keepalive = mlir.emit_python_callback( + result, token, _ = mlir.emit_python_callback( ctx, _callback, None, @@ -457,7 +456,6 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params): True, sharding=op_sharding, ) - ctx.module_context.add_keepalive(keepalive) return result diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index dfdbfd7e1..8495c012e 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -514,14 +514,13 @@ def check_lowering_rule(ctx, *args, err_tree, debug): if not config.jax_experimental_unsafe_xla_runtime_errors: raise functionalization_error - out_op, _, keep_alive = mlir.emit_python_callback( + out_op, _, _ = mlir.emit_python_callback( ctx, callback=functools.partial(python_err, err_tree), token=None, operands=args, operand_avals=list(ctx.avals_in), result_avals=list(ctx.avals_out), has_side_effect=True) - ctx.module_context.add_keepalive(keep_alive) return out_op def check_lowering_rule_unsupported(*a, debug, **k): diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index fce66521f..d6064c6c6 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -153,14 +153,13 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): *flat_args, effect=effect, callback=callback, **params)) if effects.ordered_effects.contains(effect): token = ctx.tokens_in.get(effect)[0] - result, token, keepalive = mlir.emit_python_callback( + result, token, _ = mlir.emit_python_callback( ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True) ctx.set_tokens_out(mlir.TokenSet({effect: (token,)})) else: - result, token, keepalive = mlir.emit_python_callback( + result, token, _ = mlir.emit_python_callback( ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True, sharding=sharding) - ctx.module_context.add_keepalive(keepalive) return result mlir.register_lowering(debug_callback_p, debug_callback_lowering, platform="cpu") diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index da998b09a..5a6fc8bc6 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -43,6 +43,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version from jax._src.lib.mlir import dialects from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect @@ -528,9 +529,17 @@ class ModuleContext: def new_channel(self) -> int: return next(self.channel_iterator) + # Adds an IFRT host callback object to the context. A reference to these + # callbacks will be provided to IFRT during compilation so it can do things + # like serialize them and keep them alive. def add_host_callback(self, host_callback: Any) -> None: self.host_callbacks.append(host_callback) + # Keeps a value alive as long as the Python executable is alive. + # TODO(phawkins): this feature is problematic, because you almost certainly + # want to keep alive values as long as the underlying runtime executable is + # still alive/executing. The Python executable object may have a shorter + # lifetime, so it's highly likely any caller of this method is buggy. def add_keepalive(self, keepalive: Any) -> None: self.keepalives.append(keepalive) @@ -2177,7 +2186,7 @@ def _emit_tpu_python_callback( result_shapes: Sequence[xc.Shape], *, sharding: xc.OpSharding | None = None -) -> tuple[Sequence[ir.Value], Any, Any]: +) -> tuple[Sequence[ir.Value], Any]: token = token or hlo.CreateTokenOp().result _wrapped_callback = callback @@ -2215,11 +2224,11 @@ def _emit_tpu_python_callback( callback.__name__, sharding=sharding) outputs.append(out) recv_channels.append(channel) - opaque = backend.make_python_callback_from_host_send_and_recv( + ifrt_callback = backend.make_python_callback_from_host_send_and_recv( _wrapped_callback, operand_shapes, result_shapes, send_channels, recv_channels, pickle_util.dumps) # type: ignore # pylint: disable=missing-parameter - ctx.module_context.add_host_callback(opaque) - return outputs, token, opaque + ctx.module_context.add_host_callback(ifrt_callback) + return outputs, token def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None): if minor_to_major is None: @@ -2296,7 +2305,7 @@ def emit_python_callback( (aval, shape) for aval, shape in zip(result_avals, result_shapes) if not is_empty_shape(aval.shape)]) - non_empty_outputs, token, keepalive = _emit_tpu_python_callback( + non_empty_outputs, token = _emit_tpu_python_callback( backend, ctx, _wrapped_callback, token, operands, operand_avals, operand_shapes, non_empty_result_avals, non_empty_result_shapes, @@ -2306,7 +2315,7 @@ def emit_python_callback( ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype)) if is_empty_shape(result_aval.shape) else next(non_empty_outputs_iter) for result_aval in result_avals] - return outputs, token, keepalive + return outputs, token, None result_types = util.flatten([aval_to_ir_types(aval) for aval in result_avals]) if token: @@ -2325,10 +2334,14 @@ def emit_python_callback( result_types = [token_type()[0], *result_types] operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] - callback_descriptor, keepalive = ( + callback_descriptor, ifrt_callback = ( backend.get_emit_python_callback_descriptor(_wrapped_callback, operand_shapes, result_shapes)) + if xla_extension_version >= 202: + ctx.module_context.add_host_callback(ifrt_callback) + else: + ctx.module_context.add_keepalive(ifrt_callback) descriptor_operand = ir_constant(callback_descriptor) callback_operands = [descriptor_operand, *operands] if operand_mlir_layouts is not None: @@ -2358,7 +2371,7 @@ def emit_python_callback( ] if token: token, *results = results - return results, token, keepalive + return results, token, ifrt_callback def build_xla_computation_helper( closed_jaxpr: core.ClosedJaxpr, *, name: str, platform: str, diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index 08aa88b30..047043651 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -544,10 +544,9 @@ def _spsolve_cpu_lowering(ctx, data, indices, indptr, b, tol, reorder): A = csr_matrix((data, indices, indptr), shape=(b.size, b.size)) return (linalg.spsolve(A, b).astype(b.dtype),) - result, _, keepalive = mlir.emit_python_callback( + result, _, _ = mlir.emit_python_callback( ctx, _callback, None, args, ctx.avals_in, ctx.avals_out, has_side_effect=False) - ctx.module_context.add_keepalive(keepalive) return result diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 4d45c46c3..51c264b13 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -123,13 +123,12 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out if effects.ordered_effects.contains(effect): token_in = ctx.tokens_in.get(effect)[0] - out_op, token_out, keep_alive = mlir.emit_python_callback( + out_op, token_out, _ = mlir.emit_python_callback( ctx, callback, token_in, list(args), list(ctx.avals_in), list(ctx.avals_out), True) if token_out: ctx.set_tokens_out(ctx.tokens_in.update_tokens(mlir.TokenSet({effect: token_out}))) - ctx.module_context.add_keepalive(keep_alive) return out_op mlir.register_lowering(callback_p, callback_effect_lowering) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 7aec1ae68..63e6dd7df 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -30,6 +30,7 @@ from jax._src import test_util as jtu from jax._src import util from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib import xla_extension_version from jax.experimental import io_callback from jax.experimental import maps from jax.experimental import pjit @@ -772,6 +773,27 @@ class PureCallbackTest(jtu.JaxTestCase): out, np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.) + @unittest.skipIf(xla_extension_version < 202, "Test requires jaxlib 0.4.18") + def test_callback_with_immediate_executable_destruction(self): + + def loop_body(i, x): + del i + return jax.pure_callback(lambda y: y + np.ones(4, np.float32), + x, x) + + class AClass: + def f(self, ys): + return lax.fori_loop(0, 10, loop_body, jnp.ones(4, np.float32)) + + num_devices = jax.local_device_count() + c = AClass() + out = jax.pmap(c.f)(np.ones((num_devices,), np.float32)) + # c.f is an ephemeral bound method object, and it will be destroyed + # immediately. This test verifies that the execution itself keeps the + # callback alive. + np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32)) + + def test_callback_inside_xmap(self): def _callback(x):