mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[JAX] Keep CPU host callbacks alive via IFRT, rather than by attaching them to the Python object.
We need to keep callback objects alive for as long as any executable that uses them is still running. The Python data structures for an executable can be discarded before the runtime has finished running that executable, which can lead to a use-after-free. Instead, make the runtime keep host callbacks alive. PiperOrigin-RevId: 571141106
This commit is contained in:
parent
7c353c4b55
commit
15126504a7
@ -194,7 +194,7 @@ def pure_callback_lowering(
|
||||
)
|
||||
|
||||
op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
|
||||
result, _, keepalive = mlir.emit_python_callback(
|
||||
result, _, _ = mlir.emit_python_callback(
|
||||
ctx,
|
||||
_callback,
|
||||
None,
|
||||
@ -204,7 +204,6 @@ def pure_callback_lowering(
|
||||
False,
|
||||
sharding=op_sharding,
|
||||
)
|
||||
ctx.module_context.add_keepalive(keepalive)
|
||||
return result
|
||||
|
||||
|
||||
@ -435,7 +434,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
|
||||
op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
|
||||
if ordered:
|
||||
token = ctx.tokens_in.get(_OrderedIOEffect)[0]
|
||||
result, token, keepalive = mlir.emit_python_callback(
|
||||
result, token, _ = mlir.emit_python_callback(
|
||||
ctx,
|
||||
_callback,
|
||||
token,
|
||||
@ -447,7 +446,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
|
||||
)
|
||||
ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: (token,)}))
|
||||
else:
|
||||
result, token, keepalive = mlir.emit_python_callback(
|
||||
result, token, _ = mlir.emit_python_callback(
|
||||
ctx,
|
||||
_callback,
|
||||
None,
|
||||
@ -457,7 +456,6 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
|
||||
True,
|
||||
sharding=op_sharding,
|
||||
)
|
||||
ctx.module_context.add_keepalive(keepalive)
|
||||
return result
|
||||
|
||||
|
||||
|
@ -514,14 +514,13 @@ def check_lowering_rule(ctx, *args, err_tree, debug):
|
||||
if not config.jax_experimental_unsafe_xla_runtime_errors:
|
||||
raise functionalization_error
|
||||
|
||||
out_op, _, keep_alive = mlir.emit_python_callback(
|
||||
out_op, _, _ = mlir.emit_python_callback(
|
||||
ctx, callback=functools.partial(python_err, err_tree),
|
||||
token=None,
|
||||
operands=args,
|
||||
operand_avals=list(ctx.avals_in),
|
||||
result_avals=list(ctx.avals_out),
|
||||
has_side_effect=True)
|
||||
ctx.module_context.add_keepalive(keep_alive)
|
||||
return out_op
|
||||
|
||||
def check_lowering_rule_unsupported(*a, debug, **k):
|
||||
|
@ -153,14 +153,13 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
|
||||
*flat_args, effect=effect, callback=callback, **params))
|
||||
if effects.ordered_effects.contains(effect):
|
||||
token = ctx.tokens_in.get(effect)[0]
|
||||
result, token, keepalive = mlir.emit_python_callback(
|
||||
result, token, _ = mlir.emit_python_callback(
|
||||
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True)
|
||||
ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
|
||||
else:
|
||||
result, token, keepalive = mlir.emit_python_callback(
|
||||
result, token, _ = mlir.emit_python_callback(
|
||||
ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, True,
|
||||
sharding=sharding)
|
||||
ctx.module_context.add_keepalive(keepalive)
|
||||
return result
|
||||
mlir.register_lowering(debug_callback_p, debug_callback_lowering,
|
||||
platform="cpu")
|
||||
|
@ -43,6 +43,7 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import dialects
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
@ -528,9 +529,17 @@ class ModuleContext:
|
||||
def new_channel(self) -> int:
|
||||
return next(self.channel_iterator)
|
||||
|
||||
# Adds an IFRT host callback object to the context. A reference to these
|
||||
# callbacks will be provided to IFRT during compilation so it can do things
|
||||
# like serialize them and keep them alive.
|
||||
def add_host_callback(self, host_callback: Any) -> None:
|
||||
self.host_callbacks.append(host_callback)
|
||||
|
||||
# Keeps a value alive as long as the Python executable is alive.
|
||||
# TODO(phawkins): this feature is problematic, because you almost certainly
|
||||
# want to keep values alive as long as the underlying runtime executable is
|
||||
# still alive/executing. The Python executable object may have a shorter
|
||||
# lifetime, so it's highly likely any caller of this method is buggy.
|
||||
def add_keepalive(self, keepalive: Any) -> None:
|
||||
self.keepalives.append(keepalive)
|
||||
|
||||
@ -2177,7 +2186,7 @@ def _emit_tpu_python_callback(
|
||||
result_shapes: Sequence[xc.Shape],
|
||||
*,
|
||||
sharding: xc.OpSharding | None = None
|
||||
) -> tuple[Sequence[ir.Value], Any, Any]:
|
||||
) -> tuple[Sequence[ir.Value], Any]:
|
||||
token = token or hlo.CreateTokenOp().result
|
||||
_wrapped_callback = callback
|
||||
|
||||
@ -2215,11 +2224,11 @@ def _emit_tpu_python_callback(
|
||||
callback.__name__, sharding=sharding)
|
||||
outputs.append(out)
|
||||
recv_channels.append(channel)
|
||||
opaque = backend.make_python_callback_from_host_send_and_recv(
|
||||
ifrt_callback = backend.make_python_callback_from_host_send_and_recv(
|
||||
_wrapped_callback, operand_shapes, result_shapes, send_channels,
|
||||
recv_channels, pickle_util.dumps) # type: ignore # pylint: disable=missing-parameter
|
||||
ctx.module_context.add_host_callback(opaque)
|
||||
return outputs, token, opaque
|
||||
ctx.module_context.add_host_callback(ifrt_callback)
|
||||
return outputs, token
|
||||
|
||||
def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
|
||||
if minor_to_major is None:
|
||||
@ -2296,7 +2305,7 @@ def emit_python_callback(
|
||||
(aval, shape)
|
||||
for aval, shape in zip(result_avals, result_shapes)
|
||||
if not is_empty_shape(aval.shape)])
|
||||
non_empty_outputs, token, keepalive = _emit_tpu_python_callback(
|
||||
non_empty_outputs, token = _emit_tpu_python_callback(
|
||||
backend, ctx, _wrapped_callback, token,
|
||||
operands, operand_avals, operand_shapes,
|
||||
non_empty_result_avals, non_empty_result_shapes,
|
||||
@ -2306,7 +2315,7 @@ def emit_python_callback(
|
||||
ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype))
|
||||
if is_empty_shape(result_aval.shape) else next(non_empty_outputs_iter)
|
||||
for result_aval in result_avals]
|
||||
return outputs, token, keepalive
|
||||
return outputs, token, None
|
||||
|
||||
result_types = util.flatten([aval_to_ir_types(aval) for aval in result_avals])
|
||||
if token:
|
||||
@ -2325,10 +2334,14 @@ def emit_python_callback(
|
||||
result_types = [token_type()[0], *result_types]
|
||||
operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts]
|
||||
result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts]
|
||||
callback_descriptor, keepalive = (
|
||||
callback_descriptor, ifrt_callback = (
|
||||
backend.get_emit_python_callback_descriptor(_wrapped_callback,
|
||||
operand_shapes,
|
||||
result_shapes))
|
||||
if xla_extension_version >= 202:
|
||||
ctx.module_context.add_host_callback(ifrt_callback)
|
||||
else:
|
||||
ctx.module_context.add_keepalive(ifrt_callback)
|
||||
descriptor_operand = ir_constant(callback_descriptor)
|
||||
callback_operands = [descriptor_operand, *operands]
|
||||
if operand_mlir_layouts is not None:
|
||||
@ -2358,7 +2371,7 @@ def emit_python_callback(
|
||||
]
|
||||
if token:
|
||||
token, *results = results
|
||||
return results, token, keepalive
|
||||
return results, token, ifrt_callback
|
||||
|
||||
def build_xla_computation_helper(
|
||||
closed_jaxpr: core.ClosedJaxpr, *, name: str, platform: str,
|
||||
|
@ -544,10 +544,9 @@ def _spsolve_cpu_lowering(ctx, data, indices, indptr, b, tol, reorder):
|
||||
A = csr_matrix((data, indices, indptr), shape=(b.size, b.size))
|
||||
return (linalg.spsolve(A, b).astype(b.dtype),)
|
||||
|
||||
result, _, keepalive = mlir.emit_python_callback(
|
||||
result, _, _ = mlir.emit_python_callback(
|
||||
ctx, _callback, None, args, ctx.avals_in, ctx.avals_out,
|
||||
has_side_effect=False)
|
||||
ctx.module_context.add_keepalive(keepalive)
|
||||
return result
|
||||
|
||||
|
||||
|
@ -123,13 +123,12 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
|
||||
if effects.ordered_effects.contains(effect):
|
||||
token_in = ctx.tokens_in.get(effect)[0]
|
||||
|
||||
out_op, token_out, keep_alive = mlir.emit_python_callback(
|
||||
out_op, token_out, _ = mlir.emit_python_callback(
|
||||
ctx, callback, token_in, list(args), list(ctx.avals_in),
|
||||
list(ctx.avals_out), True)
|
||||
if token_out:
|
||||
ctx.set_tokens_out(ctx.tokens_in.update_tokens(mlir.TokenSet({effect:
|
||||
token_out})))
|
||||
ctx.module_context.add_keepalive(keep_alive)
|
||||
return out_op
|
||||
|
||||
mlir.register_lowering(callback_p, callback_effect_lowering)
|
||||
|
@ -30,6 +30,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.experimental import io_callback
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
@ -772,6 +773,27 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
out,
|
||||
np.arange(2 * jax.local_device_count()).reshape([-1, 2]) + 1.)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 202, "Test requires jaxlib 0.4.18")
|
||||
def test_callback_with_immediate_executable_destruction(self):
|
||||
|
||||
def loop_body(i, x):
|
||||
del i
|
||||
return jax.pure_callback(lambda y: y + np.ones(4, np.float32),
|
||||
x, x)
|
||||
|
||||
class AClass:
|
||||
def f(self, ys):
|
||||
return lax.fori_loop(0, 10, loop_body, jnp.ones(4, np.float32))
|
||||
|
||||
num_devices = jax.local_device_count()
|
||||
c = AClass()
|
||||
out = jax.pmap(c.f)(np.ones((num_devices,), np.float32))
|
||||
# c.f is an ephemeral bound method object, and it will be destroyed
|
||||
# immediately. This test verifies that the execution itself keeps the
|
||||
# callback alive.
|
||||
np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32))
|
||||
|
||||
|
||||
def test_callback_inside_xmap(self):
|
||||
|
||||
def _callback(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user