mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Attach keepalive to executable
This commit is contained in:
parent
c8230251ca
commit
ef982cfa8c
@ -815,7 +815,7 @@ def xla_computation(fun: Callable,
|
||||
out_parts_flat = tuple(flatten_axes(
|
||||
"xla_computation out_parts", out_tree(), out_parts))
|
||||
effects = list(jaxpr.effects)
|
||||
m = mlir.lower_jaxpr_to_module(
|
||||
m, _ = mlir.lower_jaxpr_to_module(
|
||||
f"xla_computation_{fun_name}",
|
||||
core.ClosedJaxpr(jaxpr, consts),
|
||||
effects=effects,
|
||||
|
@ -290,7 +290,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
|
||||
return XlaComputation(
|
||||
name, None, True, None, None, jaxpr=jaxpr, consts=consts, device=device,
|
||||
in_avals=abstract_args, out_avals=out_avals, effects=jaxpr.effects,
|
||||
kept_var_idx=kept_var_idx)
|
||||
kept_var_idx=kept_var_idx, keepalive=None)
|
||||
|
||||
if not _on_exit:
|
||||
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
||||
@ -326,14 +326,14 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
module_name = f"jit_{fun.__name__}"
|
||||
effects = [eff for eff in closed_jaxpr.effects if eff in core.ordered_effects]
|
||||
module = mlir.lower_jaxpr_to_module(
|
||||
module, keepalive = mlir.lower_jaxpr_to_module(
|
||||
module_name, closed_jaxpr, effects, backend.platform,
|
||||
mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
|
||||
return XlaComputation(
|
||||
name, module, False, donated_invars, which_explicit, nreps=nreps,
|
||||
device=device, backend=backend, tuple_args=tuple_args,
|
||||
in_avals=abstract_args, out_avals=out_avals, effects=effects,
|
||||
kept_var_idx=kept_var_idx)
|
||||
kept_var_idx=kept_var_idx, keepalive=keepalive)
|
||||
|
||||
|
||||
def _backend_supports_unbounded_dynamic_shapes(backend: Backend) -> bool:
|
||||
@ -761,11 +761,15 @@ def compile_or_get_cached(backend, computation, compile_options):
|
||||
|
||||
|
||||
class XlaCompiledComputation(stages.Executable):
|
||||
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call):
|
||||
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call,
|
||||
keepalive: Any):
|
||||
self._xla_executable = xla_executable
|
||||
self.in_avals = in_avals
|
||||
self._kept_var_idx = kept_var_idx
|
||||
self.unsafe_call = unsafe_call
|
||||
# Only the `unsafe_call` function is cached, so to avoid the `keepalive`
|
||||
# being garbage collected we attach it to `unsafe_call`.
|
||||
self.unsafe_call.keepalive = keepalive
|
||||
|
||||
@staticmethod
|
||||
def from_xla_computation(
|
||||
@ -779,7 +783,8 @@ class XlaCompiledComputation(stages.Executable):
|
||||
in_avals: Sequence[core.AbstractValue],
|
||||
out_avals: Sequence[core.AbstractValue],
|
||||
effects: List[core.Effect],
|
||||
kept_var_idx: Set[int]) -> XlaCompiledComputation:
|
||||
kept_var_idx: Set[int],
|
||||
keepalive: Optional[Any]) -> XlaCompiledComputation:
|
||||
sticky_device = device
|
||||
input_handler = _input_handler(backend, explicit_args, in_avals)
|
||||
result_handlers = map(partial(aval_to_result_handler, sticky_device),
|
||||
@ -800,7 +805,8 @@ class XlaCompiledComputation(stages.Executable):
|
||||
execute = _execute_compiled if nreps == 1 else _execute_replicated
|
||||
unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,
|
||||
result_handlers, effects, kept_var_idx)
|
||||
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call)
|
||||
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call,
|
||||
keepalive)
|
||||
|
||||
def is_trivial(self):
  """Report whether this computation has no XLA executable attached.

  Returns:
    True when `self._xla_executable` is `None` (the value that
    `from_trivial_jaxpr` passes to the constructor), False otherwise.
  """
  # Use an identity test: `== None` would dispatch to the executable's own
  # `__eq__`, which may be overloaded and need not return a plain bool.
  return self._xla_executable is None
|
||||
@ -814,11 +820,13 @@ class XlaCompiledComputation(stages.Executable):
|
||||
|
||||
@staticmethod
|
||||
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals, effects,
|
||||
kept_var_idx) -> XlaCompiledComputation:
|
||||
kept_var_idx, keepalive: Optional[Any]) -> XlaCompiledComputation:
|
||||
assert keepalive is None
|
||||
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
|
||||
unsafe_call = partial(_execute_trivial, jaxpr, device, consts,
|
||||
out_avals, result_handlers, effects, kept_var_idx)
|
||||
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call)
|
||||
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
|
||||
keepalive)
|
||||
|
||||
# -- stages.Executable protocol
|
||||
|
||||
|
@ -365,6 +365,7 @@ class ModuleContext:
|
||||
platform: str
|
||||
axis_context: AxisContext
|
||||
name_stack: NameStack
|
||||
keepalives: List[Any]
|
||||
|
||||
# Cached primitive lowerings.
|
||||
cached_primitive_lowerings: Dict[Any, func_dialect.FuncOp]
|
||||
@ -378,6 +379,7 @@ class ModuleContext:
|
||||
platform: str,
|
||||
axis_context: AxisContext,
|
||||
name_stack: NameStack,
|
||||
keepalives: List[Any],
|
||||
context: Optional[ir.Context] = None,
|
||||
module: Optional[ir.Module] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
@ -393,6 +395,10 @@ class ModuleContext:
|
||||
self.name_stack = name_stack
|
||||
self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
|
||||
else cached_primitive_lowerings)
|
||||
self.keepalives = keepalives
|
||||
|
||||
def add_keepalive(self, keepalive: Any) -> None:
  """Append `keepalive` to this context's `keepalives` list.

  The list is mutated in place, so any other holder of the same list object
  observes the new entry as well.
  """
  registry = self.keepalives
  registry.append(keepalive)
|
||||
|
||||
def replace(self, **kw):
  """Return a new instance with the fields named in `kw` overridden.

  Thin wrapper around `dataclasses.replace`; fields not mentioned in `kw`
  keep the values they have on `self`.
  """
  updated = dataclasses.replace(self, **kw)
  return updated
|
||||
|
||||
@ -483,7 +489,7 @@ def lower_jaxpr_to_module(
|
||||
replicated_args: Optional[Sequence[bool]] = None,
|
||||
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
|
||||
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None
|
||||
) -> ir.Module:
|
||||
) -> Tuple[ir.Module, Optional[Any]]:
|
||||
"""Lowers a top-level jaxpr to an MHLO module.
|
||||
|
||||
Handles the quirks of the argument/return value passing conventions of the
|
||||
@ -518,7 +524,9 @@ def lower_jaxpr_to_module(
|
||||
msg = f"Donation is not implemented for {platform}.\n{msg}"
|
||||
warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
|
||||
|
||||
ctx = ModuleContext(platform, axis_context, name_stack)
|
||||
# Create a keepalives list that will be mutated during the lowering.
|
||||
keepalives: List[Any] = []
|
||||
ctx = ModuleContext(platform, axis_context, name_stack, keepalives)
|
||||
with ctx.context, ir.Location.unknown(ctx.context):
|
||||
# Remove module name characters that XLA would alter. This ensures that
|
||||
# XLA computation preserves the module name.
|
||||
@ -541,7 +549,7 @@ def lower_jaxpr_to_module(
|
||||
input_output_aliases=input_output_aliases)
|
||||
|
||||
ctx.module.operation.verify()
|
||||
return ctx.module
|
||||
return ctx.module, ctx.keepalives
|
||||
|
||||
def module_to_string(module: ir.Module) -> str:
|
||||
output = io.StringIO()
|
||||
|
@ -1054,11 +1054,13 @@ def lower_parallel_callable(
|
||||
module_name = f"pmap_{fun.__name__}"
|
||||
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
||||
effects = list(closed_jaxpr.effects)
|
||||
module = mlir.lower_jaxpr_to_module(
|
||||
# TODO(sharadmv): attach keepalive to computation
|
||||
module, keepalive = mlir.lower_jaxpr_to_module(
|
||||
module_name, closed_jaxpr, effects, backend.platform, mlir.ReplicaAxisContext(axis_env),
|
||||
name_stack, donated_invars, replicated_args=replicated_args,
|
||||
arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
|
||||
result_shardings=_shardings_to_mlir_shardings(parts.out_parts))
|
||||
del keepalive
|
||||
return PmapComputation(module, pci=pci, replicas=replicas, parts=parts,
|
||||
shards=shards, tuple_args=tuple_args)
|
||||
|
||||
@ -2238,10 +2240,12 @@ def lower_mesh_computation(
|
||||
module_name = f"{api_name}_{fun_name}"
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
effects = list(closed_jaxpr.effects)
|
||||
module = mlir.lower_jaxpr_to_module(
|
||||
# TODO(sharadmv): attach keepalive to computation
|
||||
module, keepalive = mlir.lower_jaxpr_to_module(
|
||||
module_name, closed_jaxpr, effects, backend.platform, axis_ctx, name_stack,
|
||||
donated_invars, replicated_args=replicated_args,
|
||||
arg_shardings=in_partitions, result_shardings=out_partitions)
|
||||
del keepalive
|
||||
|
||||
return MeshComputation(
|
||||
str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,
|
||||
|
@ -141,7 +141,7 @@ def _sharded_callable(
|
||||
|
||||
axis_env = xla.AxisEnv(nrep, (), ())
|
||||
effects = list(jaxpr.effects)
|
||||
module = mlir.lower_jaxpr_to_module(
|
||||
module, _ = mlir.lower_jaxpr_to_module(
|
||||
"spjit_{}".format(fun.__name__),
|
||||
core.ClosedJaxpr(jaxpr, consts),
|
||||
effects,
|
||||
|
@ -88,9 +88,6 @@ def _(*avals, callback, out_avals, effect):
|
||||
del avals, callback
|
||||
return out_avals, {effect}
|
||||
|
||||
|
||||
# TODO(sharadmv): Attach keep alive to executable
|
||||
leak = [].append
|
||||
def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out_avals, effect):
|
||||
del out_avals
|
||||
def _token_callback(token, *args):
|
||||
@ -102,7 +99,7 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
|
||||
ctx.module_context.platform, _token_callback,
|
||||
[token_in, *args], [core.abstract_token, *ctx.avals_in],
|
||||
[core.abstract_token, *ctx.avals_out], True)
|
||||
leak(keep_alive)
|
||||
ctx.module_context.add_keepalive(keep_alive)
|
||||
ctx.set_tokens_out(ctx.tokens_in.update_tokens(mlir.TokenSet({effect:
|
||||
token_out})))
|
||||
return out_op
|
||||
|
Loading…
x
Reference in New Issue
Block a user