[JAX] Keep CPU host callbacks alive via IFRT, rather than by attaching them to the Python object.

Callback objects must stay alive for as long as any executable that uses them is still running. It is possible to discard the Python data structures for an executable before the runtime has finished running that executable, which can lead to a use-after-free. Instead, make the runtime keep the host callbacks alive.

PiperOrigin-RevId: 571141106
This commit is contained in:
Peter Hawkins 2023-10-05 15:06:29 -07:00 committed by jax authors
parent 7c353c4b55
commit 15126504a7
7 changed files with 51 additions and 22 deletions

View File

@ -194,7 +194,7 @@ def pure_callback_lowering(
)
op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
result, _, keepalive = mlir.emit_python_callback(
result, _, _ = mlir.emit_python_callback(
ctx,
_callback,
None,
@ -204,7 +204,6 @@ def pure_callback_lowering(
False,
sharding=op_sharding,
)
ctx.module_context.add_keepalive(keepalive)
return result
@ -435,7 +434,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
if ordered:
token = ctx.tokens_in.get(_OrderedIOEffect)[0]
result, token, keepalive = mlir.emit_python_callback(
result, token, _ = mlir.emit_python_callback(
ctx,
_callback,
token,
@ -447,7 +446,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
)
ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: (token,)}))
else:
result, token, keepalive = mlir.emit_python_callback(
result, token, _ = mlir.emit_python_callback(
ctx,
_callback,
None,
@ -457,7 +456,6 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
True,
sharding=op_sharding,
)
ctx.module_context.add_keepalive(keepalive)
return result

View File

@ -514,14 +514,13 @@ def check_lowering_rule(ctx, *args, err_tree, debug):
if not config.jax_experimental_unsafe_xla_runtime_errors:
raise functionalization_error
out_op, _, keep_alive = mlir.emit_python_callback(
out_op, _, _ = mlir.emit_python_callback(
ctx, callback=functools.partial(python_err, err_tree),
token=None,
operands=args,
operand_avals=list(ctx.avals_in),
result_avals=list(ctx.avals_out),
has_side_effect=True)
ctx.module_context.add_keepalive(keep_alive)
return out_op
def check_lowering_rule_unsupported(*a, debug, **k):

View File

@ -153,14 +153,13 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
*flat_args, effect=effect, callback=callback, **params))
if effects.ordered_effects.contains(effect):
token = ctx.tokens_in.get(effect)[0]
result, token, keepalive = mlir.emit_python_callback(
result, token, _ = mlir.emit_python_callback(
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True)
ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
else:
result, token, keepalive = mlir.emit_python_callback(
result, token, _ = mlir.emit_python_callback(
ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True,
sharding=sharding)
ctx.module_context.add_keepalive(keepalive)
return result
mlir.register_lowering(debug_callback_p, debug_callback_lowering,
platform="cpu")

View File

@ -43,6 +43,7 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import dialects
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
@ -528,9 +529,17 @@ class ModuleContext:
def new_channel(self) -> int:
return next(self.channel_iterator)
# Adds an IFRT host callback object to the context. A reference to these
# callbacks will be provided to IFRT during compilation so it can do things
# like serialize them and keep them alive.
def add_host_callback(self, host_callback: Any) -> None:
self.host_callbacks.append(host_callback)
# Keeps a value alive as long as the Python executable is alive.
# TODO(phawkins): this feature is problematic, because you almost certainly
# want to keep values alive as long as the underlying runtime executable is
# still alive/executing. The Python executable object may have a shorter
# lifetime, so it's highly likely any caller of this method is buggy.
def add_keepalive(self, keepalive: Any) -> None:
self.keepalives.append(keepalive)
@ -2177,7 +2186,7 @@ def _emit_tpu_python_callback(
result_shapes: Sequence[xc.Shape],
*,
sharding: xc.OpSharding | None = None
) -> tuple[Sequence[ir.Value], Any, Any]:
) -> tuple[Sequence[ir.Value], Any]:
token = token or hlo.CreateTokenOp().result
_wrapped_callback = callback
@ -2215,11 +2224,11 @@ def _emit_tpu_python_callback(
callback.__name__, sharding=sharding)
outputs.append(out)
recv_channels.append(channel)
opaque = backend.make_python_callback_from_host_send_and_recv(
ifrt_callback = backend.make_python_callback_from_host_send_and_recv(
_wrapped_callback, operand_shapes, result_shapes, send_channels,
recv_channels, pickle_util.dumps) # type: ignore # pylint: disable=missing-parameter
ctx.module_context.add_host_callback(opaque)
return outputs, token, opaque
ctx.module_context.add_host_callback(ifrt_callback)
return outputs, token
def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
if minor_to_major is None:
@ -2296,7 +2305,7 @@ def emit_python_callback(
(aval, shape)
for aval, shape in zip(result_avals, result_shapes)
if not is_empty_shape(aval.shape)])
non_empty_outputs, token, keepalive = _emit_tpu_python_callback(
non_empty_outputs, token = _emit_tpu_python_callback(
backend, ctx, _wrapped_callback, token,
operands, operand_avals, operand_shapes,
non_empty_result_avals, non_empty_result_shapes,
@ -2306,7 +2315,7 @@ def emit_python_callback(
ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype))
if is_empty_shape(result_aval.shape) else next(non_empty_outputs_iter)
for result_aval in result_avals]
return outputs, token, keepalive
return outputs, token, None
result_types = util.flatten([aval_to_ir_types(aval) for aval in result_avals])
if token:
@ -2325,10 +2334,14 @@ def emit_python_callback(
result_types = [token_type()[0], *result_types]
operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts]
result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts]
callback_descriptor, keepalive = (
callback_descriptor, ifrt_callback = (
backend.get_emit_python_callback_descriptor(_wrapped_callback,
operand_shapes,
result_shapes))
if xla_extension_version >= 202:
ctx.module_context.add_host_callback(ifrt_callback)
else:
ctx.module_context.add_keepalive(ifrt_callback)
descriptor_operand = ir_constant(callback_descriptor)
callback_operands = [descriptor_operand, *operands]
if operand_mlir_layouts is not None:
@ -2358,7 +2371,7 @@ def emit_python_callback(
]
if token:
token, *results = results
return results, token, keepalive
return results, token, ifrt_callback
def build_xla_computation_helper(
closed_jaxpr: core.ClosedJaxpr, *, name: str, platform: str,

View File

@ -544,10 +544,9 @@ def _spsolve_cpu_lowering(ctx, data, indices, indptr, b, tol, reorder):
A = csr_matrix((data, indices, indptr), shape=(b.size, b.size))
return (linalg.spsolve(A, b).astype(b.dtype),)
result, _, keepalive = mlir.emit_python_callback(
result, _, _ = mlir.emit_python_callback(
ctx, _callback, None, args, ctx.avals_in, ctx.avals_out,
has_side_effect=False)
ctx.module_context.add_keepalive(keepalive)
return result

View File

@ -123,13 +123,12 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
if effects.ordered_effects.contains(effect):
token_in = ctx.tokens_in.get(effect)[0]
out_op, token_out, keep_alive = mlir.emit_python_callback(
out_op, token_out, _ = mlir.emit_python_callback(
ctx, callback, token_in, list(args), list(ctx.avals_in),
list(ctx.avals_out), True)
if token_out:
ctx.set_tokens_out(ctx.tokens_in.update_tokens(mlir.TokenSet({effect:
token_out})))
ctx.module_context.add_keepalive(keep_alive)
return out_op
mlir.register_lowering(callback_p, callback_effect_lowering)

View File

@ -30,6 +30,7 @@ from jax._src import test_util as jtu
from jax._src import util
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax.experimental import io_callback
from jax.experimental import maps
from jax.experimental import pjit
@ -772,6 +773,27 @@ class PureCallbackTest(jtu.JaxTestCase):
out,
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
@unittest.skipIf(xla_extension_version < 202, "Test requires jaxlib 0.4.18")
def test_callback_with_immediate_executable_destruction(self):
def loop_body(i, x):
del i
return jax.pure_callback(lambda y: y + np.ones(4, np.float32),
x, x)
class AClass:
def f(self, ys):
return lax.fori_loop(0, 10, loop_body, jnp.ones(4, np.float32))
num_devices = jax.local_device_count()
c = AClass()
out = jax.pmap(c.f)(np.ones((num_devices,), np.float32))
# c.f is an ephemeral bound method object, and it will be destroyed
# immediately. This test verifies that the execution itself keeps the
# callback alive.
np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32))
def test_callback_inside_xmap(self):
def _callback(x):