mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
[mutable-arrays] support closed-over mutable arrays in jit
This commit is contained in:
parent
e7eb2075b8
commit
649cd50681
@ -1918,6 +1918,7 @@ class MutableArray:
|
||||
dtype = property(lambda self: self._aval.dtype)
|
||||
def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
|
||||
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
|
||||
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
|
||||
pytype_aval_mappings[MutableArray] = lambda x: x._aval
|
||||
|
||||
def mutable_array(init_val):
|
||||
|
@ -1005,11 +1005,9 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
|
||||
|
||||
@weakref_lru_cache
|
||||
def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
|
||||
"""Move n invars to constvars. Like an inverse of convert_constvars_Jaxpr."""
|
||||
"""Move n invars to constvars. Like an inverse of convert_constvars_jaxpr."""
|
||||
if n == 0:
|
||||
return jaxpr.replace() # 'return jaxpr' would create cache reference cycle
|
||||
if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
|
||||
raise NotImplementedError
|
||||
config.enable_checks.value and core.check_jaxpr(jaxpr)
|
||||
constvars, invars = split_list(jaxpr.invars, [n])
|
||||
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
|
||||
|
@ -67,14 +67,12 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
from jax._src.sharding_impls import (
|
||||
ArrayMapping, ArrayMappingOrAutoOrUnspecified,
|
||||
AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
|
||||
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
|
||||
UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
|
||||
is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
|
||||
SingleDeviceSharding, GSPMDSharding
|
||||
)
|
||||
from jax._src.util import (safe_map, safe_zip, partition_list,
|
||||
wrap_name, tuple_update, tuple_delete,
|
||||
distributed_debug_log,
|
||||
SingleDeviceSharding, GSPMDSharding)
|
||||
from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
|
||||
tuple_update, tuple_delete, distributed_debug_log,
|
||||
unzip2, HashableFunction, weakref_lru_cache)
|
||||
from jax._src.state.types import AbstractRef, RefEffect
|
||||
|
||||
@ -1153,14 +1151,14 @@ class ExecuteReplicated:
|
||||
__slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler',
|
||||
'has_unordered_effects', 'ordered_effects', 'keepalive',
|
||||
'has_host_callbacks', '_local_devices', 'kept_var_idx',
|
||||
'out_mut', '__weakref__']
|
||||
'mut', '__weakref__']
|
||||
|
||||
def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
|
||||
out_handler: ResultsHandler,
|
||||
unordered_effects: list[core.Effect],
|
||||
ordered_effects: list[core.Effect], keepalive: Any,
|
||||
has_host_callbacks: bool, kept_var_idx: set[int],
|
||||
out_mut: Sequence[int | None] | None):
|
||||
mut: MutationData | None):
|
||||
self.xla_executable = xla_executable
|
||||
self.name = name
|
||||
self.backend = backend
|
||||
@ -1172,7 +1170,7 @@ class ExecuteReplicated:
|
||||
self.keepalive = keepalive
|
||||
self.has_host_callbacks = has_host_callbacks
|
||||
self.kept_var_idx = kept_var_idx
|
||||
self.out_mut = out_mut
|
||||
self.mut = mut
|
||||
|
||||
def _add_tokens_to_inputs(self, input_bufs):
|
||||
if self.ordered_effects:
|
||||
@ -1195,6 +1193,8 @@ class ExecuteReplicated:
|
||||
@profiler.annotate_function
|
||||
def __call__(self, *args):
|
||||
args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
|
||||
if self.mut:
|
||||
args = [*args, *self.mut.in_mut]
|
||||
input_bufs = self.in_handler(args)
|
||||
if (self.ordered_effects or self.has_unordered_effects
|
||||
or self.has_host_callbacks):
|
||||
@ -1215,11 +1215,11 @@ class ExecuteReplicated:
|
||||
out = self.out_handler(out_arrays)
|
||||
else:
|
||||
out = results.consume_with_handlers(self.out_handler.handlers)
|
||||
if self.out_mut is None:
|
||||
if self.mut is None:
|
||||
return out
|
||||
else:
|
||||
out_ = []
|
||||
for i, o in zip(self.out_mut, out):
|
||||
for i, o in zip(self.mut.out_mut, out):
|
||||
if i is not None:
|
||||
args[i]._buf = o
|
||||
else:
|
||||
@ -1781,19 +1781,38 @@ def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
|
||||
return (closed_jaxpr, global_in_avals, tuple(global_out_avals), donated_invars,
|
||||
kept_var_idx, name_stack)
|
||||
|
||||
class MutationData(NamedTuple):
|
||||
in_mut: list[core.MutableArray]
|
||||
out_mut: list[int | None]
|
||||
|
||||
@weakref_lru_cache
|
||||
def _discharge_refs(
|
||||
jaxpr: core.ClosedJaxpr
|
||||
) -> tuple[core.ClosedJaxpr, Sequence[int | None], Sequence[int | None]]:
|
||||
) -> tuple[core.ClosedJaxpr, Sequence[int | None], MutationData]:
|
||||
from jax._src.state.discharge import discharge_state
|
||||
jaxpr, in_mut = _move_mutable_consts(jaxpr)
|
||||
new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
|
||||
count = it.count(len(jaxpr.out_avals)) # new outputs are appended to the end
|
||||
inout_map = {i: next(count) for i, a in enumerate(jaxpr.in_avals)
|
||||
if isinstance(a, AbstractRef)}
|
||||
outin_map = {j: i for i, j in inout_map.items()}
|
||||
inout_aliases = tuple(map(inout_map.get, range(len(new_jaxpr.in_avals))))
|
||||
out_mut = tuple(map(outin_map.get, range(len(new_jaxpr.out_avals))))
|
||||
return new_jaxpr, inout_aliases, out_mut
|
||||
out_mut = list(map(outin_map.get, range(len(new_jaxpr.out_avals))))
|
||||
return new_jaxpr, inout_aliases, MutationData(in_mut, out_mut)
|
||||
|
||||
@weakref_lru_cache
|
||||
def _move_mutable_consts(
|
||||
closed_jaxpr: core.ClosedJaxpr,
|
||||
) -> tuple[core.ClosedJaxpr, list[core.MutableArray]]:
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
hoist = [isinstance(c, core.MutableArray) for c in closed_jaxpr.consts]
|
||||
consts, in_mut = partition_list(hoist, closed_jaxpr.consts)
|
||||
constvars, mutvars = partition_list(hoist, jaxpr.constvars)
|
||||
invars = (*jaxpr.invars, *mutvars)
|
||||
effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns)
|
||||
jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns,
|
||||
effects, None)
|
||||
return core.ClosedJaxpr(jaxpr, consts), in_mut
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -2012,16 +2031,20 @@ def lower_sharding_computation(
|
||||
in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)
|
||||
|
||||
if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects):
|
||||
closed_jaxpr, inout_aliases, out_mut = _discharge_refs(closed_jaxpr)
|
||||
if out_mut:
|
||||
closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr)
|
||||
in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut.in_mut)
|
||||
in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut)
|
||||
donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut)
|
||||
out_layouts_ = iter(zip(out_shardings, out_layouts))
|
||||
out_shardings, out_layouts = unzip2(
|
||||
next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i])
|
||||
for i in out_mut)
|
||||
for i in mut.out_mut)
|
||||
assert next(out_layouts_, None) is None
|
||||
# TODO(yashkatariya): remove global_in_avals / global_out_avals
|
||||
global_in_avals = closed_jaxpr.in_avals
|
||||
global_out_avals = closed_jaxpr.out_avals
|
||||
else:
|
||||
inout_aliases = out_mut = None
|
||||
inout_aliases = mut = None
|
||||
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
|
||||
@ -2106,7 +2129,7 @@ def lower_sharding_computation(
|
||||
host_callbacks=host_callbacks,
|
||||
keepalive=keepalive,
|
||||
kept_var_idx=kept_var_idx,
|
||||
out_mut=out_mut,
|
||||
mut=mut,
|
||||
backend=backend,
|
||||
device_assignment=da_object,
|
||||
committed=committed,
|
||||
@ -2775,7 +2798,7 @@ class UnloadedMeshExecutable:
|
||||
keepalive: Sequence[Any]
|
||||
host_callbacks: Sequence[Any]
|
||||
kept_var_idx: set[int]
|
||||
out_mut: Sequence[None | int] | None
|
||||
mut: MutationData | None
|
||||
auto_spmd_lowering: bool
|
||||
in_layouts: Sequence[SpecifiedLayout | None]
|
||||
out_layouts: Sequence[SpecifiedLayout | None]
|
||||
@ -2795,7 +2818,7 @@ class UnloadedMeshExecutable:
|
||||
unsafe_call = ExecuteReplicated( # type: ignore # assignment
|
||||
self.xla_executable, self.name, self.backend, handle_args,
|
||||
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
|
||||
bool(self.host_callbacks), self.kept_var_idx, self.out_mut)
|
||||
bool(self.host_callbacks), self.kept_var_idx, self.mut)
|
||||
return unsafe_call
|
||||
|
||||
def load(self) -> MeshExecutable:
|
||||
@ -2829,7 +2852,7 @@ class UnloadedMeshExecutable:
|
||||
in_layouts: MaybeLayout,
|
||||
out_layouts: MaybeLayout,
|
||||
pmap_nreps: int = 1,
|
||||
out_mut: Sequence[None | int] | None = None,
|
||||
mut: MutationData | None = None,
|
||||
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
|
||||
all_default_mem_kind: bool = True,
|
||||
all_args_info: AllArgsInfo | None = None,
|
||||
@ -2922,7 +2945,7 @@ class UnloadedMeshExecutable:
|
||||
keepalive=keepalive,
|
||||
host_callbacks=host_callbacks,
|
||||
kept_var_idx=kept_var_idx,
|
||||
out_mut=out_mut,
|
||||
mut=mut,
|
||||
auto_spmd_lowering=auto_spmd_lowering,
|
||||
in_layouts=in_layouts, # type: ignore
|
||||
out_layouts=out_layouts, # type: ignore
|
||||
|
@ -1409,7 +1409,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
fastpath_data = _get_fastpath_data(
|
||||
compiled, tree_structure(out_flat), args, out_flat, [], set())
|
||||
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects)
|
||||
return out_flat, fastpath_data
|
||||
|
||||
f = _get_jaxpr_as_fun(
|
||||
@ -1561,6 +1561,14 @@ def pjit_staging_rule(trace, *args, **params):
|
||||
params['jaxpr'].effects, source_info)
|
||||
trace.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
elif any(isinstance(c, core.MutableArray) for c in params['jaxpr'].consts):
|
||||
jaxpr, consts = pxla._move_mutable_consts(params['jaxpr'])
|
||||
consts = map(trace.instantiate_const, consts)
|
||||
in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
|
||||
donated_invars = (*params['donated_invars'],) + (False,) * len(consts)
|
||||
new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings,
|
||||
donated_invars=donated_invars)
|
||||
return trace.default_process_primitive(pjit_p, (*args, *consts), new_params)
|
||||
else:
|
||||
return trace.default_process_primitive(pjit_p, args, params)
|
||||
pe.custom_staging_rules[pjit_p] = pjit_staging_rule
|
||||
|
@ -436,8 +436,8 @@ def export(fun_jax: Callable,
|
||||
mlir_module = lowering.stablehlo()
|
||||
|
||||
args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
|
||||
if "out_mut" in lowering.compile_args:
|
||||
if lowering.compile_args["out_mut"]: raise NotImplementedError
|
||||
if "mut" in lowering.compile_args:
|
||||
if lowering.compile_args["mut"]: raise NotImplementedError
|
||||
if "kept_var_idx" in lowering.compile_args:
|
||||
module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
|
||||
else:
|
||||
@ -747,7 +747,7 @@ def _check_lowering(lowering) -> None:
|
||||
allowed_compile_args = [
|
||||
"backend", "mesh", "global_in_avals",
|
||||
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
|
||||
"out_mut", "spmd_lowering", "auto_spmd_lowering",
|
||||
"mut", "spmd_lowering", "auto_spmd_lowering",
|
||||
"tuple_args", "ordered_effects", "unordered_effects",
|
||||
"keepalive", "host_callbacks", "pmap_nreps", "committed",
|
||||
"device_assignment", "jaxpr_debug_info", "shape_poly_state",
|
||||
|
@ -1544,6 +1544,51 @@ class MutableArrayTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False)
|
||||
self.assertAllClose(out2, y + w, check_dtypes=False)
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_closed_over_basic(self, jit):
|
||||
x_mut = core.mutable_array(jnp.zeros(3))
|
||||
def f():
|
||||
x_mut[...] += 1.
|
||||
x_mut[0] += 1
|
||||
x_mut[1] += 5
|
||||
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
f()
|
||||
|
||||
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
|
||||
check_dtypes=False)
|
||||
|
||||
jaxpr = jax.make_jaxpr(f)()
|
||||
self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_closed_over_nested(self, jit):
|
||||
x_mut = core.mutable_array(jnp.zeros(3))
|
||||
|
||||
@jax.jit
|
||||
def f(y_mut, z):
|
||||
x_mut[...] += 1.
|
||||
x_mut[0] += 1
|
||||
x_mut[1] += 5
|
||||
|
||||
y_mut[2] += 7
|
||||
return z + 9
|
||||
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
y_mut = core.mutable_array(np.zeros(3))
|
||||
|
||||
w = f(y_mut, 1)
|
||||
|
||||
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(y_mut[...], jnp.array([0., 0., 7.]),
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(w, 10, check_dtypes=False)
|
||||
|
||||
|
||||
if CAN_USE_HYPOTHESIS:
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user