Attach keepalive to executable

This commit is contained in:
Sharad Vikram 2022-04-14 14:18:31 -07:00
parent c8230251ca
commit ef982cfa8c
6 changed files with 36 additions and 19 deletions

View File

@ -815,7 +815,7 @@ def xla_computation(fun: Callable,
out_parts_flat = tuple(flatten_axes(
"xla_computation out_parts", out_tree(), out_parts))
effects = list(jaxpr.effects)
m = mlir.lower_jaxpr_to_module(
m, _ = mlir.lower_jaxpr_to_module(
f"xla_computation_{fun_name}",
core.ClosedJaxpr(jaxpr, consts),
effects=effects,

View File

@ -290,7 +290,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
return XlaComputation(
name, None, True, None, None, jaxpr=jaxpr, consts=consts, device=device,
in_avals=abstract_args, out_avals=out_avals, effects=jaxpr.effects,
kept_var_idx=kept_var_idx)
kept_var_idx=kept_var_idx, keepalive=None)
if not _on_exit:
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@ -326,14 +326,14 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module_name = f"jit_{fun.__name__}"
effects = [eff for eff in closed_jaxpr.effects if eff in core.ordered_effects]
module = mlir.lower_jaxpr_to_module(
module, keepalive = mlir.lower_jaxpr_to_module(
module_name, closed_jaxpr, effects, backend.platform,
mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
return XlaComputation(
name, module, False, donated_invars, which_explicit, nreps=nreps,
device=device, backend=backend, tuple_args=tuple_args,
in_avals=abstract_args, out_avals=out_avals, effects=effects,
kept_var_idx=kept_var_idx)
kept_var_idx=kept_var_idx, keepalive=keepalive)
def _backend_supports_unbounded_dynamic_shapes(backend: Backend) -> bool:
@ -761,11 +761,15 @@ def compile_or_get_cached(backend, computation, compile_options):
class XlaCompiledComputation(stages.Executable):
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call):
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call,
keepalive: Any):
self._xla_executable = xla_executable
self.in_avals = in_avals
self._kept_var_idx = kept_var_idx
self.unsafe_call = unsafe_call
# Only the `unsafe_call` function is cached, so to avoid the `keepalive`
# being garbage collected we attach it to `unsafe_call`.
self.unsafe_call.keepalive = keepalive
@staticmethod
def from_xla_computation(
@ -779,7 +783,8 @@ class XlaCompiledComputation(stages.Executable):
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue],
effects: List[core.Effect],
kept_var_idx: Set[int]) -> XlaCompiledComputation:
kept_var_idx: Set[int],
keepalive: Optional[Any]) -> XlaCompiledComputation:
sticky_device = device
input_handler = _input_handler(backend, explicit_args, in_avals)
result_handlers = map(partial(aval_to_result_handler, sticky_device),
@ -800,7 +805,8 @@ class XlaCompiledComputation(stages.Executable):
execute = _execute_compiled if nreps == 1 else _execute_replicated
unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,
result_handlers, effects, kept_var_idx)
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call,
keepalive)
def is_trivial(self):
return self._xla_executable == None
@ -814,11 +820,13 @@ class XlaCompiledComputation(stages.Executable):
@staticmethod
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals, effects,
kept_var_idx) -> XlaCompiledComputation:
kept_var_idx, keepalive: Optional[Any]) -> XlaCompiledComputation:
assert keepalive is None
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
out_avals, result_handlers, effects, kept_var_idx)
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call)
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
keepalive)
# -- stages.Executable protocol

View File

@ -365,6 +365,7 @@ class ModuleContext:
platform: str
axis_context: AxisContext
name_stack: NameStack
keepalives: List[Any]
# Cached primitive lowerings.
cached_primitive_lowerings: Dict[Any, func_dialect.FuncOp]
@ -378,6 +379,7 @@ class ModuleContext:
platform: str,
axis_context: AxisContext,
name_stack: NameStack,
keepalives: List[Any],
context: Optional[ir.Context] = None,
module: Optional[ir.Module] = None,
ip: Optional[ir.InsertionPoint] = None,
@ -393,6 +395,10 @@ class ModuleContext:
self.name_stack = name_stack
self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
else cached_primitive_lowerings)
self.keepalives = keepalives
def add_keepalive(self, keepalive: Any) -> None:
self.keepalives.append(keepalive)
def replace(self, **kw): return dataclasses.replace(self, **kw)
@ -483,7 +489,7 @@ def lower_jaxpr_to_module(
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None
) -> ir.Module:
) -> Tuple[ir.Module, Optional[Any]]:
"""Lowers a top-level jaxpr to an MHLO module.
Handles the quirks of the argument/return value passing conventions of the
@ -518,7 +524,9 @@ def lower_jaxpr_to_module(
msg = f"Donation is not implemented for {platform}.\n{msg}"
warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
ctx = ModuleContext(platform, axis_context, name_stack)
# Create a keepalives list that will be mutated during the lowering.
keepalives: List[Any] = []
ctx = ModuleContext(platform, axis_context, name_stack, keepalives)
with ctx.context, ir.Location.unknown(ctx.context):
# Remove module name characters that XLA would alter. This ensures that
# XLA computation preserves the module name.
@ -541,7 +549,7 @@ def lower_jaxpr_to_module(
input_output_aliases=input_output_aliases)
ctx.module.operation.verify()
return ctx.module
return ctx.module, ctx.keepalives
def module_to_string(module: ir.Module) -> str:
output = io.StringIO()

View File

@ -1054,11 +1054,13 @@ def lower_parallel_callable(
module_name = f"pmap_{fun.__name__}"
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
effects = list(closed_jaxpr.effects)
module = mlir.lower_jaxpr_to_module(
# TODO(sharadmv): attach keepalive to computation
module, keepalive = mlir.lower_jaxpr_to_module(
module_name, closed_jaxpr, effects, backend.platform, mlir.ReplicaAxisContext(axis_env),
name_stack, donated_invars, replicated_args=replicated_args,
arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
result_shardings=_shardings_to_mlir_shardings(parts.out_parts))
del keepalive
return PmapComputation(module, pci=pci, replicas=replicas, parts=parts,
shards=shards, tuple_args=tuple_args)
@ -2238,10 +2240,12 @@ def lower_mesh_computation(
module_name = f"{api_name}_{fun_name}"
with core.extend_axis_env_nd(mesh.shape.items()):
effects = list(closed_jaxpr.effects)
module = mlir.lower_jaxpr_to_module(
# TODO(sharadmv): attach keepalive to computation
module, keepalive = mlir.lower_jaxpr_to_module(
module_name, closed_jaxpr, effects, backend.platform, axis_ctx, name_stack,
donated_invars, replicated_args=replicated_args,
arg_shardings=in_partitions, result_shardings=out_partitions)
del keepalive
return MeshComputation(
str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,

View File

@ -141,7 +141,7 @@ def _sharded_callable(
axis_env = xla.AxisEnv(nrep, (), ())
effects = list(jaxpr.effects)
module = mlir.lower_jaxpr_to_module(
module, _ = mlir.lower_jaxpr_to_module(
"spjit_{}".format(fun.__name__),
core.ClosedJaxpr(jaxpr, consts),
effects,

View File

@ -88,9 +88,6 @@ def _(*avals, callback, out_avals, effect):
del avals, callback
return out_avals, {effect}
# TODO(sharadmv): Attach keep alive to executable
leak = [].append
def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out_avals, effect):
del out_avals
def _token_callback(token, *args):
@ -102,7 +99,7 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
ctx.module_context.platform, _token_callback,
[token_in, *args], [core.abstract_token, *ctx.avals_in],
[core.abstract_token, *ctx.avals_out], True)
leak(keep_alive)
ctx.module_context.add_keepalive(keep_alive)
ctx.set_tokens_out(ctx.tokens_in.update_tokens(mlir.TokenSet({effect:
token_out})))
return out_op