From 4a8babb10179f6e3a91f6ee290266cce06793c63 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Thu, 25 Jan 2024 22:20:36 -0800
Subject: [PATCH] integrate attrs in jax.jit

Co-authored-by: Dougal Maclaurin
---
 jax/BUILD                                 |  1 +
 jax/_src/ad_checkpoint.py                 |  4 +-
 jax/_src/api.py                           |  2 +-
 jax/_src/checkify.py                      |  6 +-
 jax/_src/custom_batching.py               |  2 +-
 jax/_src/custom_derivatives.py            |  6 +-
 jax/_src/interpreters/ad.py               |  2 +-
 jax/_src/interpreters/batching.py         |  4 +-
 jax/_src/interpreters/mlir.py             |  2 +-
 jax/_src/interpreters/partial_eval.py     | 68 ++++++++++++--------
 jax/_src/lax/control_flow/common.py       |  4 +-
 jax/_src/lax/control_flow/for_loop.py     | 14 ++---
 jax/_src/lax/control_flow/loops.py        |  9 ++-
 jax/_src/lax/lax.py                       |  4 +-
 jax/_src/maps.py                          |  8 +--
 jax/_src/pallas/core.py                   |  2 +-
 jax/_src/pallas/mosaic/lowering.py        |  2 +-
 jax/_src/pallas/mosaic/primitives.py      |  2 +-
 jax/_src/pallas/pallas_call.py            | 17 +++--
 jax/_src/pallas/triton/lowering.py        | 10 +--
 jax/_src/pjit.py                          | 68 +++++++++++++-------
 jax/_src/state/discharge.py               | 12 ++--
 jax/_src/state/utils.py                   |  4 +-
 jax/experimental/attrs.py                 | 64 +++++++++++++++++++
 jax/experimental/custom_partitioning.py   |  2 +-
 jax/experimental/jet.py                   |  2 +-
 jax/experimental/key_reuse/_forwarding.py |  2 +-
 jax/experimental/key_reuse/_simple.py     |  2 +-
 jax/experimental/shard_map.py             | 12 ++--
 jax/experimental/sparse/transform.py      |  4 +-
 jax/interpreters/partial_eval.py          | 10 ++-
 tests/BUILD                               |  9 +++
 tests/attrs_test.py                       | 64 +++++++++++++++++++
 tests/core_test.py                        | 10 +--
 tests/pallas/pallas_call_tpu_test.py      |  2 +-
 tests/state_test.py                       | 76 +++++++++++------------
 36 files changed, 347 insertions(+), 165 deletions(-)
 create mode 100644 jax/experimental/attrs.py
 create mode 100644 tests/attrs_test.py

diff --git a/jax/BUILD b/jax/BUILD
index d7f163070..01cbdcdc3 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -241,6 +241,7 @@ py_library_providing_imports_info(
             "_src/internal_test_util/**",
         ],
     ) + [
+        "experimental/attrs.py",
         # until new parallelism APIs are moved out of experimental
         "experimental/maps.py",
         "experimental/pjit.py",
diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index 996a80fce..6baf89d72 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -372,7 +372,7 @@ def _trace_to_jaxpr(fun, in_tree, in_avals):
   flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
   debug = pe.debug_info(fun, in_tree, out_tree, True, "checkpoint")
   try:
-    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
   except core.ConcretizationTypeError as e:
     msg, = e.args
     if 'for checkpoint' not in msg:
@@ -620,7 +620,7 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros, reduce_axes):
     in_cts_nz, _ = partition_list(in_zeros, in_cts)
     return in_cts_nz

-  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
+  transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
   transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
   return transposed_jaxpr, cell.in_cts_zero  # type: ignore
diff --git a/jax/_src/api.py b/jax/_src/api.py
index edfe052c0..e9b502c84 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -560,7 +560,7 @@ def xla_computation(fun: Callable,
     with ExitStack() as stack:
       for axis_name, size in axis_env or []:
         stack.enter_context(core.extend_axis_env(axis_name, size, None))
-      jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
+      jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
       jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
       axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
       ordered_effects = list(
diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index e294ea42a..303d1d9af 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -746,7 +746,7 @@ def jaxpr_to_checkify_jaxpr(
   fun = lu.wrap_init(checkify_jaxpr_partial)
   fun, metadata = _flatten_and_get_error_metadata_thunk(fun)

-  new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
+  new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
   checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
   out_tree, error_effects = metadata()
   return checked_jaxpr, out_tree, error_effects
@@ -832,7 +832,7 @@ def checkify_while_body_jaxpr(
     return out
   new_body_f_ = lu.wrap_init(new_body_f)
   c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
-  jaxpr, _, () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
+  jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
                                                          *body_jaxpr.in_avals])
   closed_jaxpr = pe.close_jaxpr(jaxpr)
   err_vals, err_tree = jtu.tree_flatten(error)
@@ -1128,7 +1128,7 @@ def checkify(f: Callable[..., Out],
     # stage:
     fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
     debug = pe.debug_info(closed_f, in_tree, out_tree, False, 'checkify')
-    jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
+    jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
     jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
     # checkify:
     error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py
index c8a998712..fa74549c0 100644
--- a/jax/_src/custom_batching.py
+++ b/jax/_src/custom_batching.py
@@ -69,7 +69,7 @@ class custom_vmap:
     flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
     in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
     debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
-    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
     closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
     in_tree = treedef_tuple((tree_structure(consts), in_tree))
     assert self.vmap_rule is not None
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index c63adb5f6..8f6019194 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -66,7 +66,7 @@ def _resolve_kwargs(fun, args, kwargs):
     return ba.args

 def _initial_style_jaxpr(fun, in_avals):
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
   return jaxpr, consts

 def _close_jaxpr(jaxpr):
@@ -977,7 +977,7 @@ def custom_gradient(fun):
     ans_flat, out_tree = tree_flatten((ans,))
     rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
     ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
-    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
     return ans, Residuals(jaxpr, in_tree(), out_tree, consts)

   def bwd(res, cts):
@@ -1102,7 +1102,7 @@ def _maybe_perturbed(x: Any) -> bool:
 @cache()
 def _closure_convert_for_avals(fun, in_tree, in_avals):
   wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
-  jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
+  jaxpr, out_pvals, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
   out_tree = out_tree()
   (closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index 766586945..a10bf5a1d 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -697,7 +697,7 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
                   nonzeros)
   tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
   avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
-  jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
+  jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
   return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()

 @lu.transformation_with_aux
diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py
index 2675b9280..9bb4d63a4 100644
--- a/jax/_src/interpreters/batching.py
+++ b/jax/_src/interpreters/batching.py
@@ -789,7 +789,7 @@ def _batch_jaxpr2(
   avals_in2 = [core.unmapped_aval(axis_size, axis_name, b, aval)
                if b is not not_mapped else aval
                for aval, b in unsafe_zip(avals_in, in_axes2)]
-  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in2)
+  jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)
   return core.ClosedJaxpr(jaxpr_out, consts), out_axes()

 def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis,
@@ -834,7 +834,7 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
                       main_type)
   avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped
               else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
-  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
+  jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
   return core.ClosedJaxpr(jaxpr_out, consts), out_batched()

 @lu.transformation_with_aux
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 340c3aa01..c9745e258 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1797,7 +1797,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
       wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
       jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
     else:
-      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
+      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)

     # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
     out, tokens = jaxpr_subcomp(
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index f95720bc3..9bf74c042 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -663,7 +663,7 @@ def _closed_call_param_updater(params, _, __):
 call_param_updaters[core.closed_call_p] = _closed_call_param_updater

 def abstract_eval_fun(fun, *avals, debug_info=None, **params):
-  _, avals_out, _ = trace_to_jaxpr_dynamic(
+  _, avals_out, _, () = trace_to_jaxpr_dynamic(
       lu.wrap_init(fun, params), avals, debug_info)
   assert all(isinstance(aval, AbstractValue) for aval in avals_out)
   return avals_out
@@ -1113,7 +1113,7 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
     return [*known_vals_out, *residuals]

   known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]
-  jaxpr_known, _, consts_known = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
+  jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
   (out_unknowns, jaxpr_unknown, res_avals), = cell  # pytype: disable=bad-unpacking

   # check jaxpr_known and jaxpr_unknown in isolation
@@ -1754,6 +1754,9 @@ class JaxprStackFrame:
   eqns: list[JaxprEqn]
   invars: list[Var]
   effects: core.Effects
+  attrs_tracked: list[tuple[Any, str]]
+  attrs_inits: list
+  attrs_vars: list[Var]
   debug_info: DebugInfo | None

   def __init__(self):
@@ -1765,23 +1768,29 @@ class JaxprStackFrame:
     self.eqns = []      # cleared when we pop frame from main
     self.invars = []
     self.effects = set()
+    self.attrs_tracked = []
+    self.attrs_inits = []
+    self.attrs_vars = []
     self.debug_info = None

   def add_eqn(self, eqn: core.JaxprEqn):
     self.eqns.append(eqn)

-  def to_jaxpr(self, out_tracers: Sequence[Tracer]) -> tuple[Jaxpr, list[Any]]:
+  def to_jaxpr(self, out_tracers: Sequence[Tracer]
+               ) -> tuple[Jaxpr, list[Any], list[tuple[Any, str]]]:
     # It's not necessary, but we keep the tracer-to-var mapping injective:
     assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
-    outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
-    constvals: Sequence[Any]
+    invars = self.attrs_vars + self.invars
+    state_outvars = [self.tracer_to_var[id(t)] for t in get_states(self.attrs_tracked)]
+    explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
+    outvars = state_outvars + explicit_outvars
     constvars, constvals = unzip2(self.constvar_to_val.items())
-    jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars,
-                                       self.eqns)
-    jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects)
+    jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns)
+    jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects)
     jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
-    jaxpr, constvals = _inline_literals(jaxpr, constvals)
-    return jaxpr, list(constvals)
+    jaxpr, constvals = _inline_literals(jaxpr, constvals)  # type: ignore
+    set_states(self.attrs_tracked, self.attrs_inits)
+    return jaxpr, list(constvals), self.attrs_tracked

   def to_jaxpr2(self, out_tracers):
     # It's not necessary, but we keep the tracer-to-var mapping injective:
@@ -2064,7 +2073,7 @@ class DynamicJaxprTrace(core.Trace):
                 for a, in_axis in zip(in_avals, params['in_axes'])]
     with core.extend_axis_env(axis_name, params["global_axis_size"], None):  # type: ignore
       with core.new_sublevel():
-        jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
+        jaxpr, reduced_out_avals, consts, () = trace_to_subjaxpr_dynamic(
             f, self.main, reduced_in_avals,
             debug_info=debug_info_final(f, map_primitive.name))
     ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
@@ -2098,7 +2107,7 @@ class DynamicJaxprTrace(core.Trace):
   def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
     in_avals = [t.aval for t in tracers]
     with core.new_sublevel():
-      fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
+      fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
     closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())

     main_ = ref(self.main)
@@ -2108,7 +2117,7 @@ class DynamicJaxprTrace(core.Trace):
       nz_tangent_avals, zero_avals = partition_list(in_zeros, in_avals)
       jvp_, out_zeros = _jvp_jaxpr_zeros(jvp, in_zeros, tuple(zero_avals))
       in_avals_ = (*in_avals, *nz_tangent_avals)
-      jaxpr, _, out_consts = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_)
+      jaxpr, _, out_consts, () = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_)
       return jaxpr, out_consts, out_zeros()

     out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
@@ -2132,7 +2141,7 @@ class DynamicJaxprTrace(core.Trace):
                               symbolic_zeros):
     in_avals = [t.aval for t in tracers]
     with core.new_sublevel():
-      fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
+      fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
     closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())

     main_ = ref(self.main)
@@ -2170,7 +2179,7 @@ class DynamicJaxprTrace(core.Trace):
     in_avals_t = [*[t.aval for t in tracers_res], *out_types]

     with core.new_sublevel():
-      call_jaxpr, out_avals, call_consts = trace_to_subjaxpr_dynamic(
+      call_jaxpr, out_avals, call_consts, () = trace_to_subjaxpr_dynamic(
           call, self.main, in_avals_p)
     closed_call_jaxpr = core.ClosedJaxpr(
         convert_constvars_jaxpr(call_jaxpr), ())
@@ -2183,8 +2192,8 @@ class DynamicJaxprTrace(core.Trace):
     @_memoize
     def transpose_jaxpr_thunk():
       for store in transpose_flat.stores: store.reset()
-      jaxpr, _, consts = trace_to_subjaxpr_dynamic(transpose_flat, main_(),
-                                                   in_avals_t)
+      jaxpr, _, consts, () = trace_to_subjaxpr_dynamic(
+          transpose_flat, main_(), in_avals_t)
       return jaxpr, consts

     out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
@@ -2299,13 +2308,13 @@ def trace_to_jaxpr_dynamic(
     debug_info: DebugInfo | None = None,
     *,
     keep_inputs: list[bool] | None = None,
-) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
+) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]:
   with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
     main.jaxpr_stack = ()  # type: ignore
-    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
+    jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
       fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
     del main, fun
-  return jaxpr, out_avals, consts
+  return jaxpr, out_avals, consts, attrs_tracked


 def trace_to_subjaxpr_dynamic(
@@ -2315,7 +2324,7 @@ def trace_to_subjaxpr_dynamic(
     *,
     keep_inputs: Sequence[bool] | None = None,
     debug_info: DebugInfo | None = None,
-) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
+) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]:
   keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs

   frame = JaxprStackFrame()
@@ -2326,10 +2335,10 @@ def trace_to_subjaxpr_dynamic(
     in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
     ans = fun.call_wrapped(*in_tracers_)
     out_tracers = map(trace.full_raise, ans)
-    jaxpr, consts = frame.to_jaxpr(out_tracers)
+    jaxpr, consts, attrs_tracked = frame.to_jaxpr(out_tracers)
     del fun, main, trace, frame, in_tracers, out_tracers, ans
   config.enable_checks.value and core.check_jaxpr(jaxpr)
-  return jaxpr, [v.aval for v in jaxpr.outvars], consts
+  return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked


 @profiler.annotate_function
@@ -2380,7 +2389,7 @@ def trace_to_jaxpr_final(
   with core.new_base_main(DynamicJaxprTrace) as main:  # type: ignore
     main.jaxpr_stack = ()  # type: ignore
     with core.new_sublevel():
-      jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
+      jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(
         fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
   del fun, main
   return jaxpr, out_avals, consts
@@ -2404,6 +2413,15 @@ AbstractedAxesSpec = Union[
     tuple[AbstractedAxisName, ...],
 ]

+AttrsTracked = list[tuple[Any, str]]
+AttrStates = list
+
+def set_states(attrs_tracked: AttrsTracked, vals: AttrStates):
+  for ((obj, attr), val) in zip(attrs_tracked, vals):
+    setattr(obj, attr, val)
+
+def get_states(attrs_tracked: AttrsTracked):
+  return [getattr(obj, attr) for (obj, attr) in attrs_tracked]

 def infer_lambda_input_type(
     axes_specs: Sequence[AbstractedAxesSpec] | None,
@@ -2629,7 +2647,7 @@ def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
   in_avals = [substitute(v.aval) for v in jaxpr.invars]
   eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
-  padded_jaxpr, _, padded_consts = trace_to_jaxpr_dynamic(eval_padded, in_avals)
+  padded_jaxpr, _, padded_consts, () = trace_to_jaxpr_dynamic(eval_padded, in_avals)
   return padded_jaxpr, padded_consts

 class BoundedAxisSize(NamedTuple):
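
Note: `trace_to_jaxpr_dynamic` (and `trace_to_subjaxpr_dynamic`) now return a
4-tuple whose final element records the (object, attribute) pairs whose state
was read or written during tracing. A minimal sketch of the new calling
convention, using the post-patch internal API (the function `f` and its aval
here are illustrative, not part of the patch):

    import numpy as np
    from jax._src import core
    from jax._src import linear_util as lu
    from jax._src.interpreters import partial_eval as pe

    def f(x):
      return x * 2

    aval = core.ShapedArray((3,), np.dtype('float32'))
    jaxpr, out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f), [aval])
    # attrs_tracked is empty unless jax_getattr/jax_setattr ran during tracing,
    # which is why most call sites in this patch unpack it as a literal `()`,
    # asserting that no attr state was touched.
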
diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py
index 7dbc062f5..bb79bf981 100644
--- a/jax/_src/lax/control_flow/common.py
+++ b/jax/_src/lax/control_flow/common.py
@@ -58,7 +58,7 @@ def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
   wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
   debug = pe.debug_info(fun, in_tree, out_tree, False,
                         primitive_name or "")
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
   return jaxpr, consts, out_tree()

 @weakref_lru_cache
@@ -226,7 +226,7 @@ def _prune_zeros(ts):
   return [t for t in ts if type(t) is not ad_util.Zero]

 def _make_closed_jaxpr(traceable: lu.WrappedFun,
                        in_avals: Sequence[core.AbstractValue]):
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
   return core.ClosedJaxpr(jaxpr, consts)

 def _show_diff(array1, array2):
diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py
index 85404e99b..0bcfcc4e6 100644
--- a/jax/_src/lax/control_flow/for_loop.py
+++ b/jax/_src/lax/control_flow/for_loop.py
@@ -88,7 +88,7 @@ def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
     consts = map(lambda x: ref_get(x, ()), consts)
     all_consts = merge_lists(is_const_ref, consts, const_refs)
     return core.eval_jaxpr(jaxpr, all_consts, i, *args)
-  hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+  hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
       lu.wrap_init(_hoist), in_avals)
   assert not consts, "All consts should have been converted to refs"
   return hoisted_jaxpr
@@ -98,7 +98,7 @@ def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
                               ) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
   f, out_tree_thunk = flatten_fun_nokwargs(
       lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
       f, state_avals)
   return jaxpr, consts, out_tree_thunk()
@@ -607,7 +607,7 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
     return for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps,
                       reverse=reverse, which_linear=jaxpr_known_which_linear,
                       unroll=unroll)
-  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
+  call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
       known, [v.aval for v in known_invars])
   call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
   eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars],
@@ -629,7 +629,7 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
     _, ans = split_list(out_flat, [num_res])
     _, ans = partition_list(out_inst, ans)
     return ans
-  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
+  call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
       staged, [v.aval for v in [*resvars, *eqn.invars]])
   assert len(jaxpr_staged.invars) - 1 == len(call_jaxpr_.invars)
   call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
@@ -667,7 +667,7 @@ def _convert_outputs_to_writes(
       AbstractRef(core.ShapedArray((nsteps, *v.aval.shape),  # pytype: disable=attribute-error
                                    v.aval.dtype))  # pytype: disable=attribute-error
       for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)]
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
       eval_jaxpr, [*in_avals, *res_ref_avals])
   assert not consts
   return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals]  # pytype: disable=attribute-error
@@ -693,7 +693,7 @@ def _convert_inputs_to_reads(
                                    aval.dtype))  # pytype: disable=attribute-error
       for aval, loop_invar in zip(res_val_avals, loop_invar_res)]

-  jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
+  jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
       eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])
   return jaxpr
@@ -720,7 +720,7 @@ def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr:
     ad.backward_pass(
         tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
     return []
-  jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
+  jaxpr_trans, _, _, () = pe.trace_to_jaxpr_dynamic(
       lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
   return jaxpr_trans
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index 80975d457..3a937704a 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -985,7 +985,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
     intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res])
     out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known)
     return [*intensive_res, *out_loop]
-  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
+  call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
       known, [v.aval for v in ins_known])
   call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
   eqn_known = pe.new_jaxpr_eqn(
@@ -1108,8 +1108,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
   new_in_avals = [*remaining_const_avals, *[a.inner_aval for a in in_ref_avals],
                   *carry_avals,
                   *[core.mapped_aval(length, 0, a) for a in xs_avals]]
-  new_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped),
-                                               new_in_avals)
+  new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), new_in_avals)
   new_linear = (*remaining_consts_linear, *in_refs_linear,
                 *carry_linear, *xs_linear)
   all_out = scan_p.bind(*remaining_consts, *in_refs, *carry, *xs,
@@ -1768,7 +1767,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
                                 *carry)
     carry, refs_out = split_list(carry_refs, [num_carry])
     return [*refs_out, *carry]
-  new_body_jaxpr, _, new_body_consts = pe.trace_to_jaxpr_dynamic(
+  new_body_jaxpr, _, new_body_consts, () = pe.trace_to_jaxpr_dynamic(
       lu.wrap_init(new_body), [*remaining_body_const_avals,
                                *[a.inner_aval for a in ref_avals],
                                *carry_avals])
@@ -1782,7 +1781,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
         consts_refs_carry, [cond_nconsts, num_refs])
     del refs  # We don't use them here!
    return core.eval_jaxpr(cond_jaxpr, cond_jaxpr_consts, *consts, *carry)
-  new_cond_jaxpr, _, new_cond_consts = pe.trace_to_jaxpr_dynamic(
+  new_cond_jaxpr, _, new_cond_consts, () = pe.trace_to_jaxpr_dynamic(
       lu.wrap_init(new_cond), [*cond_consts_avals,
                                *[a.inner_aval for a in ref_avals],
                                *carry_avals])
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 4ad39a398..01c86c1db 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1027,7 +1027,7 @@ def _reduction_jaxpr(computation, aval):
           f"Reduction functions should only return an array.\n"
           f"Full return value: {result}")
     return (result,)
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(comp, (aval, aval))
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp, (aval, aval))
   if any(isinstance(c, core.Tracer) for c in consts):
     raise NotImplementedError(
         "Reduction computations can't close over Tracers. Please open an issue "
@@ -1040,7 +1040,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
   flat_in_avals, in_tree = tree_util.tree_flatten((avals, avals))
   comp = lu.wrap_init(computation)
   flat_comp, out_tree = api_util.flatten_fun_nokwargs(comp, in_tree)
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
   if any(isinstance(c, core.Tracer) for c in consts):
     raise NotImplementedError(
         "Reduction computations can't close over Tracers. Please open an issue "
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index 86b8b2d1d..b1f54d07c 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -999,7 +999,7 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
                     for a, a_in_axes in zip(in_avals, params['in_axes'])]
   with core.extend_axis_env_nd(global_axis_sizes.items()):
     with core.new_sublevel():
-      jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
+      jaxpr, mapped_out_avals, consts, () = trace_to_subjaxpr_dynamic(
           f, self.main, mapped_in_avals)
   out_axes = params['out_axes_thunk']()
   if params['spmd_out_axes_thunk'] is not None:
@@ -1340,7 +1340,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
   # NOTE: We don't extend the resource env with the mesh shape, because those
   #       resources are already in scope! It's the outermost xmap that introduces
   #       them!
-  vectorized_jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(f, local_avals)
+  vectorized_jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, local_avals)
   _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
   const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
@@ -1415,7 +1415,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
     add_spmd_axes(mesh_in_axes, spmd_in_axes)
     add_spmd_axes(mesh_out_axes, spmd_out_axes)
   global_in_avals = ctx.avals_in
-  vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
+  vectorized_jaxpr, global_out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, global_in_avals)

   sharded_global_in_nodes = [
       [mlir.wrap_with_sharding_op(
@@ -1477,7 +1477,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
   #       resources are already in scope! It's the outermost xmap that introduces
   #       them!
   global_in_avals = ctx.avals_in
-  vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
+  vectorized_jaxpr, global_out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
   const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]

   # We in-line here rather than generating a Call HLO as in the xla_call
diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 681063fb1..1253b387a 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -192,7 +192,7 @@ def _convert_block_spec_to_block_mapping(
   block_shape = tuple(
       mapped if s is None else s for s in block_shape)
   flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree)
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
   return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts))

 def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index b17309e20..37e1865d8 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -543,7 +543,7 @@ def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
   def f_lowered(ctx: LoweringRuleContext, *args, **params):
     f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
     wrapped_fun = lu.wrap_init(f, params)
-    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
     if consts:
       raise NotImplementedError
     jaxpr = pe.convert_constvars_jaxpr(jaxpr)
diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py
index ec673f72b..db31c3ed0 100644
--- a/jax/_src/pallas/mosaic/primitives.py
+++ b/jax/_src/pallas/mosaic/primitives.py
@@ -103,7 +103,7 @@ def run_scoped(f: Callable[..., None], *types, **kw_types) -> None:
   flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
   flat_fun, _ = api_util.flatten_fun(lu.wrap_init(f), in_tree)
   avals = map(lambda t: t.get_aval(), flat_types)
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
   run_scoped_p.bind(*consts, jaxpr=jaxpr)
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 6aa194c0a..5e466727d 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -26,11 +26,11 @@ from jax import api_util
 from jax import tree_util
 from jax import lax
 from jax._src import state
-from jax.interpreters import ad
-from jax.interpreters import batching
-from jax.interpreters import partial_eval as pe
-from jax.interpreters import mlir
-from jax.interpreters import xla
+from jax._src.interpreters import ad
+from jax._src.interpreters import batching
+from jax._src.interpreters import partial_eval as pe
+from jax._src.interpreters import mlir
+from jax._src.interpreters import xla
 from jax._src import ad_util
 from jax._src import core as jax_core
 from jax._src.state import primitives as sp
@@ -258,7 +258,7 @@ def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray,
     idx_avals = [i32_aval] * (len(grid) + 1)
   else:
     idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
-  block_mapping_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+  block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
       lu.wrap_init(_block_map_function), idx_avals)
   shape = aval.shape if block_mapping is None else block_mapping.block_shape
   if dim is batching.not_mapped:
@@ -387,7 +387,7 @@ def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
     consts = map(lambda x: sp.ref_get(x, ()), consts)
     all_consts = merge_lists(is_const_ref, consts, const_refs)
     return jax_core.eval_jaxpr(jaxpr, all_consts, *args)
-  hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+  hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
       lu.wrap_init(_hoist), in_avals)
   assert not consts, "All consts should have been converted to refs"
   return hoisted_jaxpr
@@ -401,8 +401,7 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
   flat_in_avals, wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
       lu.wrap_init(fun), jaxpr_in_tree)
   debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
-                                               debug)
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals, debug)
   jaxpr = _hoist_consts_to_refs(jaxpr)
   return grid_mapping, jaxpr, consts, out_tree_thunk()
diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index d6a0b47cf..7409923ad 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -53,8 +53,8 @@ from jax._src.util import merge_lists
 from jax._src.util import partition_list
 from jax._src.util import split_list
 from jax._src.util import weakref_lru_cache
-from jax.interpreters import mlir
-from jax.interpreters import partial_eval as pe
+from jax._src.interpreters import mlir
+from jax._src.interpreters import partial_eval as pe
 import jax.numpy as jnp
 from jax_triton import triton_lib
 from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
@@ -424,7 +424,7 @@ def _associative_scan_lowering(
   flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
       lu.wrap_init(body), in_tree
   )
-  combine_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+  combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
       flat_fun, in_avals
   )
   out_tree = out_tree_thunk()
@@ -572,7 +572,7 @@ def lower_fun(

   def f_lowered(ctx: TritonLoweringRuleContext, *args, **params):
     wrapped_fun = lu.wrap_init(fn, params)
-    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
+    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
     jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
     out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
     return out if multiple_results else out[0]
@@ -998,7 +998,7 @@ def _reduction_lowering(body, ctx: TritonLoweringRuleContext, a, axes):
   flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
       lu.wrap_init(body), in_tree
   )
-  combine_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+  combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
       flat_fun, [*mapped_avals, *mapped_avals]
   )
   out_tree = out_tree_thunk()
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index fffc4196c..244ab2164 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -21,7 +21,7 @@ import itertools as it
 import logging
 import operator as op
 import weakref
-from typing import Callable, cast, NamedTuple, Any, Union
+from typing import Callable, cast, NamedTuple, Any, Union, Optional
 import threading
 import warnings
@@ -129,10 +129,13 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,


 def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
-  args_flat, _, params, _, out_tree, _, _, _, arg_names = infer_params_fn(
-      *args, **kwargs)
+  args_flat, _, params, _, out_tree, _, _, _, arg_names, attrs_tracked = \
+      infer_params_fn(*args, **kwargs)
   for arg in args_flat:
     dispatch.check_arg(arg)
+  if attrs_tracked:
+    init_states = _get_states(attrs_tracked)
+    args_flat = [*init_states, *args_flat]
   try:
     out_flat = pjit_p.bind(*args_flat, **params)
   except pxla.DeviceAssignmentMismatchError as e:
@@ -142,8 +145,20 @@ def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
     msg = _device_assignment_mismatch_error(
         fun_name, fails, args_flat, api_name, arg_names)
     raise ValueError(msg) from None
+  if attrs_tracked:
+    final_states, out_flat = split_list(out_flat, [len(attrs_tracked)])
+    _set_states(attrs_tracked, final_states)
   outs = tree_unflatten(out_tree, out_flat)
-  return outs, out_flat, out_tree, args_flat, params['jaxpr']
+  return outs, out_flat, out_tree, args_flat, params['jaxpr'], attrs_tracked
+
+def _set_states(attrs_tracked, vals):
+  from jax.experimental.attrs import jax_setattr  # type: ignore
+  for ((obj, attr), val) in zip(attrs_tracked, vals):
+    jax_setattr(obj, attr, val)
+
+def _get_states(attrs_tracked):
+  from jax.experimental.attrs import jax_getattr  # type: ignore
+  return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked]


 def _python_pjit(fun: Callable, infer_params_fn):
@@ -161,7 +176,8 @@ def _python_pjit(fun: Callable, infer_params_fn):
   return wrapped


-def _get_fastpath_data(executable, out_tree, args_flat, out_flat):
+def _get_fastpath_data(executable, out_tree, args_flat, out_flat, attrs_tracked,
+                       ) -> Optional[pxla.MeshExecutableFastpathData]:
   out_flat, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)

   use_fastpath = (
@@ -172,7 +188,9 @@ def _get_fastpath_data(executable, out_tree, args_flat, out_flat):
       not executable.unsafe_call.ordered_effects and
       not executable.unsafe_call.has_unordered_effects and
       not executable.unsafe_call.has_host_callbacks and
-      all(isinstance(x, xc.ArrayImpl) for x in out_flat)
+      all(isinstance(x, xc.ArrayImpl) for x in out_flat) and
+      # no attr state effects
+      not attrs_tracked
   )

   if use_fastpath:
@@ -224,11 +242,12 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,

   @api_boundary
   def cache_miss(*args, **kwargs):
-    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
+    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
         fun, infer_params_fn, *args, **kwargs)
     executable = _read_most_recent_pjit_call_executable(jaxpr)
-    fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
-    return outs, fastpath_data
+    maybe_fastpath_data = _get_fastpath_data(
+        executable, out_tree, args_flat, out_flat, attrs_tracked)
+    return outs, maybe_fastpath_data

   if xla_extension_version >= 226:
     cpp_pjit_f = xc._xla.pjit(  # type: ignore
@@ -312,7 +331,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
       out_layouts = kwargs.pop('_out_layouts', None)
       (args_flat, flat_global_in_avals, params, in_tree, out_tree,
        donated_invars, in_layouts_flat, out_layouts_flat,
-       arg_names) = infer_params_fn(
+       arg_names, ()) = infer_params_fn(
            *args, **kwargs, _in_layouts=in_layouts, _out_layouts=out_layouts)
       resource_env = params['resource_env']
       mesh = None if resource_env is None else resource_env.physical_mesh
@@ -339,7 +358,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,

   @api_boundary
   def eval_shape(*args, **kwargs):
-    _, _, params, _, out_tree, _, _, _, _ = infer_params_fn(
+    _, _, params, _, out_tree, _, _, _, _, _ = infer_params_fn(
         *args, **kwargs, _in_layouts=None, _out_layouts=None)
     out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape)
            for x in params['jaxpr'].out_avals]
@@ -464,7 +483,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
         hashable_pytree(in_shardings), hashable_pytree(in_layouts), in_avals,
         in_tree, resource_env, dbg, device_or_backend_set, True if kwargs else False)

-  jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
+  jaxpr, consts, out_shardings, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
       flat_fun, hashable_pytree(out_shardings), hashable_pytree(out_layouts),
       in_type, dbg, device_or_backend_set, HashableFunction(out_tree, closure=()),
      HashableFunction(res_paths, closure=()), inline)
@@ -477,19 +496,19 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
     implicit_args = []
   args_flat = [*implicit_args, *explicit_args]

-  num_extra_args = len(implicit_args) + len(consts)
+  num_extra_args = len(implicit_args) + len(attrs_tracked) + len(consts)
   canonicalized_in_shardings_flat = \
       (UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
   in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
   donated_invars = (False,) * num_extra_args + donated_invars
   assert (len(canonicalized_in_shardings_flat) == len(in_layouts_flat) ==
-          len(donated_invars) == len(consts) + len(args_flat))
+          len(donated_invars) == len(attrs_tracked) + len(consts) + len(args_flat))

   # in_shardings and out_shardings here are all GSPMDSharding.
   params = dict(
       jaxpr=jaxpr,
       in_shardings=canonicalized_in_shardings_flat,
-      out_shardings=canonicalized_out_shardings_flat,
+      out_shardings=out_shardings,
       resource_env=resource_env,
       donated_invars=donated_invars,
       name=getattr(flat_fun, '__name__', ''),
   )
   return (consts + args_flat, in_type, params, in_tree, out_tree(),
           donated_invars, in_layouts_flat, out_layouts_flat,
-          dbg.arg_names if dbg else None)
+          dbg.arg_names if dbg else None, attrs_tracked)

 def _extract_implicit_args(
     in_type: Sequence[tuple[core.AbstractValue, bool]],
@@ -1047,11 +1066,13 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
     if config.dynamic_shapes.value:
       jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
           lu.annotate(fun, in_type), debug_info=pe_debug)
+      attrs_tracked = []
     else:
-      jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
+      jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
           fun, in_type, debug_info=pe_debug)

-  if not config.dynamic_shapes.value:
+  # TODO(dougalm,mattjj): enable debug info with attrs_tracked
+  if not config.dynamic_shapes.value and not attrs_tracked:
     jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())

   if config.enable_key_reuse_checks.value:
@@ -1065,7 +1086,7 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
   else:
     closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
     final_consts = []
-  return closed_jaxpr, final_consts, global_out_avals
+  return closed_jaxpr, final_consts, global_out_avals, attrs_tracked


 @lru_cache(maxsize=4096)
@@ -1108,13 +1129,12 @@ def _check_and_canonicalize_out_shardings(

 def _pjit_jaxpr(fun, out_shardings_thunk, out_layouts_thunk, in_type, debug_info,
                 device_or_backend_set, out_tree, result_paths, inline):
-  jaxpr, final_consts, out_type = _create_pjit_jaxpr(
+  jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
       fun, in_type, debug_info, result_paths, IgnoreKey(inline))
   canonicalized_out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
       out_shardings_thunk, out_layouts_thunk, out_tree, tuple(out_type),
       jaxpr.jaxpr.debug_info, device_or_backend_set)
-  # lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
-  return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat
+  return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat, attrs_tracked


 @dataclasses.dataclass(frozen=True)
@@ -1357,7 +1377,7 @@ def _pjit_call_impl(*args, jaxpr,
         donated_invars=donated_invars, name=name, keep_unused=keep_unused,
         inline=inline)
     fastpath_data = _get_fastpath_data(
-        compiled, tree_structure(out_flat), args, out_flat)
+        compiled, tree_structure(out_flat), args, out_flat, [])
     return out_flat, fastpath_data

   f = _get_jaxpr_as_fun(
@@ -1856,7 +1876,7 @@ pe.partial_eval_jaxpr_custom_rules[pjit_p] = \

 @lu.cache
 def _pjit_transpose_trace(fun, in_avals):
-  transpose_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
+  transpose_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
   transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts)
   return transpose_jaxpr
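
Note: `_python_pjit_helper` threads attr state by prepending the current attr
values to the flat arguments and peeling the updated values off the front of
the flat outputs with `split_list`; the C++ fastpath is disabled whenever
`attrs_tracked` is non-empty, so attr updates always take this Python path. A
tiny sketch of the split (the values are placeholders, not from the patch):

    from jax._src.util import split_list

    out_flat = ['x_final', 'y_final', 'out0', 'out1']  # two tracked attrs
    final_states, out_flat = split_list(out_flat, [2])
    # final_states == ['x_final', 'y_final'];  out_flat == ['out0', 'out1']
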
diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py
index 32576d988..1cdc03906 100644
--- a/jax/_src/state/discharge.py
+++ b/jax/_src/state/discharge.py
@@ -66,7 +66,7 @@ def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], *,
              else v.aval for v, d in zip(jaxpr.invars, should_discharge)]
   eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr,
                                     should_discharge, consts))
-  new_jaxpr, _ , new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
+  new_jaxpr, _ , new_consts, () = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
   return new_jaxpr, new_consts

 @dataclasses.dataclass
@@ -427,7 +427,7 @@ def _convert_outputs_to_writes(
     return []
   res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef)
                    else v.aval for v in jaxpr.outvars]
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
       eval_jaxpr, [*in_avals, *res_ref_avals])
   assert not consts
   return jaxpr, [core.ShapedArray(a.inner_aval.shape, a.inner_aval.dtype)  # pytype: disable=attribute-error
@@ -447,7 +447,7 @@ def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
       split_list([v.aval for v in jaxpr.invars], [num_res])
   res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef)
                    else aval for aval in res_val_avals]
-  jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
+  jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
       eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
   return jaxpr
@@ -648,7 +648,7 @@ def _run_state_partial_eval_custom(
   def staged(*args):
     out = run_state_p.bind(*args, **staged_params)
     return out[num_res:]
-  staged_call_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(staged,
+  staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(staged,
       [v.aval for v in res_staged_invars])
   eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
                                 staged_outvars,
@@ -707,7 +707,7 @@ def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool]
     ad.backward_pass(
         tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
     return []
-  jaxpr_trans, _, consts = pe.trace_to_jaxpr_dynamic(
+  jaxpr_trans, _, consts, () = pe.trace_to_jaxpr_dynamic(
       lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
   return jaxpr_trans, consts
@@ -770,7 +770,7 @@ def _initial_style_jaxpr(fun, in_tree, in_avals):
   fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
                                                        tree_util.treedef_tuple((in_tree,)))
   debug = pe.debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
-  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
+  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
   return jaxpr, consts, out_tree_thunk()

 def run_state(f: Callable[..., None]):
diff --git a/jax/_src/state/utils.py b/jax/_src/state/utils.py
index 82b8df1b5..78e83ddf4 100644
--- a/jax/_src/state/utils.py
+++ b/jax/_src/state/utils.py
@@ -15,7 +15,7 @@

 from collections.abc import Sequence

-from jax.interpreters import partial_eval as pe
+from jax._src.interpreters import partial_eval as pe
 from jax._src import core
 from jax._src import linear_util as lu
 from jax._src.state import AbstractRef
@@ -44,7 +44,7 @@ def hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
     consts = map(lambda x: ref_get(x, ()), consts)
     all_consts = merge_lists(is_const_ref, consts, const_refs)
     return core.eval_jaxpr(jaxpr, all_consts, *args)
-  hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+  hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
       lu.wrap_init(_hoist), in_avals)
   assert not consts, "All consts should have been converted to refs"
   return hoisted_jaxpr
diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py
new file mode 100644
index 000000000..b028cf4e8
--- /dev/null
+++ b/jax/experimental/attrs.py
@@ -0,0 +1,64 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Any
+
+from jax._src import core
+from jax._src.interpreters import partial_eval as pe
+
+JaxVal = Any
+
+getattr_p = core.Primitive('getattr')
+setattr_p = core.Primitive('setattr')
+
+def jax_getattr(obj: Any, attr: str):
+  return getattr_p.bind(obj=obj, attr=attr)
+
+def jax_setattr(obj: Any, attr: str, val: JaxVal):
+  setattr_p.bind(val, obj=obj, attr=attr)
+
+
+@getattr_p.def_impl
+def _getattr_impl(*, obj, attr):
+  return getattr(obj, attr)
+
+@setattr_p.def_impl
+def _setattr_impl(val, *, obj, attr):
+  setattr(obj, attr, val)
+
+
+def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str):
+  frame = trace.main.jaxpr_stack[-1]  # type: ignore
+  if (obj, attr) not in frame.attrs_tracked:
+    init_val = getattr(obj, attr)
+    aval = core.raise_to_shaped(core.get_aval(init_val))
+    tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current())
+    var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval)
+    setattr(obj, attr, tracer)
+    frame.attrs_tracked.append((obj, attr))
+    frame.attrs_inits.append(init_val)
+    frame.attrs_vars.append(var)
+pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked
+
+def _getattr_staging(trace, *, obj, attr):
+  trace._ensure_tracked(obj, attr)
+  return getattr(obj, attr)
+pe.custom_staging_rules[getattr_p] = _getattr_staging
+
+def _setattr_staging(trace, tracer, *, obj, attr):
+  trace._ensure_tracked(obj, attr)
+  setattr(obj, attr, tracer)
+pe.custom_staging_rules[setattr_p] = _setattr_staging
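
For orientation, a minimal use of the new module — this mirrors the
`test_basic` case added in tests/attrs_test.py below; the `Thing` dataclass is
just an illustrative holder:

    from dataclasses import dataclass
    import jax
    from jax.experimental.attrs import jax_getattr, jax_setattr

    @dataclass
    class Thing:
      x: float

    thing = Thing(1.0)

    @jax.jit
    def double_it() -> None:
      cur_x = jax_getattr(thing, "x")
      jax_setattr(thing, "x", cur_x * 2)

    double_it()  # thing.x == 2.0
    double_it()  # thing.x == 4.0

During tracing, `_ensure_tracked` swaps the attribute's value for a tracer the
first time it is touched, records the initial value, and restores it when the
frame is converted to a jaxpr; the updated value then flows back through
pjit's extra outputs.
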
pe.trace_to_jaxpr_dynamic( + hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(_hoist), in_avals) assert not consts, "All consts should have been converted to refs" return hoisted_jaxpr diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py new file mode 100644 index 000000000..b028cf4e8 --- /dev/null +++ b/jax/experimental/attrs.py @@ -0,0 +1,64 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +from jax._src import core +from jax._src.interpreters import partial_eval as pe + +JaxVal = Any + +getattr_p = core.Primitive('getattr') +setattr_p = core.Primitive('setattr') + +def jax_getattr(obj: Any, attr: str): + return getattr_p.bind(obj=obj, attr=attr) + +def jax_setattr(obj: Any, attr: str, val: JaxVal): + setattr_p.bind(val, obj=obj, attr=attr) + + +@getattr_p.def_impl +def _getattr_impl(*, obj, attr): + return getattr(obj, attr) + +@setattr_p.def_impl +def _setattr_impl(val, *, obj, attr): + setattr(obj, attr, val) + + +def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): + frame = trace.main.jaxpr_stack[-1] # type: ignore + if (obj, attr) not in frame.attrs_tracked: + init_val = getattr(obj, attr) + aval = core.raise_to_shaped(core.get_aval(init_val)) + tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current()) + var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval) + setattr(obj, attr, tracer) + frame.attrs_tracked.append((obj, attr)) + frame.attrs_inits.append(init_val) + frame.attrs_vars.append(var) +pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked + +def _getattr_staging(trace, *, obj, attr): + trace._ensure_tracked(obj, attr) + return getattr(obj, attr) +pe.custom_staging_rules[getattr_p] = _getattr_staging + +def _setattr_staging(trace, tracer, *, obj, attr): + trace._ensure_tracked(obj, attr) + setattr(obj, attr, tracer) +pe.custom_staging_rules[setattr_p] = _setattr_staging diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index 6556958ea..2759bdc52 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -459,7 +459,7 @@ class custom_partitioning: in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_partitioning") - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) assert not len(consts) closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) out_flat = custom_partitioning_p.bind( diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 9cc12b467..ed4ecd599 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -718,7 +718,7 @@ def _jet_jaxpr( ) -> tuple[core.ClosedJaxpr, Any]: f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) f_jet, out_tree_def = traceable(jet_fun(jet_subtrace(f), order), 
in_tree_def) - jaxpr_jet, _, consts = pe.trace_to_jaxpr_dynamic( + jaxpr_jet, _, consts, () = pe.trace_to_jaxpr_dynamic( f_jet, primals_and_series_avals) return core.ClosedJaxpr(jaxpr_jet, consts), out_tree_def diff --git a/jax/experimental/key_reuse/_forwarding.py b/jax/experimental/key_reuse/_forwarding.py index e02036d20..21f147165 100644 --- a/jax/experimental/key_reuse/_forwarding.py +++ b/jax/experimental/key_reuse/_forwarding.py @@ -181,7 +181,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignature args_flat, in_tree = tree_util.tree_flatten(args) in_avals_flat = [core.get_aval(arg) for arg in args_flat] wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) return get_jaxpr_type_signature(jaxpr) diff --git a/jax/experimental/key_reuse/_simple.py b/jax/experimental/key_reuse/_simple.py index 86f88a0fc..5f27cae3d 100644 --- a/jax/experimental/key_reuse/_simple.py +++ b/jax/experimental/key_reuse/_simple.py @@ -152,7 +152,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignature args_flat, in_tree = tree_util.tree_flatten(args) in_avals_flat = [core.get_aval(arg) for arg in args_flat] wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) return get_jaxpr_type_signature(jaxpr) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 8af936675..dc45196f3 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -426,7 +426,7 @@ def _shard_map_staging( in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) main = trace.main with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, genavals, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_) + jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_) out_avals_ = map(_check_shapedarray, genavals) _check_names(out_names_thunk(), out_avals_) in_rep = map(partial(_in_names_to_rep, mesh), in_names) @@ -1378,7 +1378,7 @@ def _promote_scalar_residuals_jaxpr(jaxpr, which): res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval for v, w in zip(jaxpr.constvars, which)] in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]] - jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(fun, in_avals) + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals) return jaxpr def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, @@ -1495,7 +1495,7 @@ def _add_reshapes(which, jaxpr_known, jaxpr_staged): res = [_add_singleton(x) if not x.shape else x for x in res] return [*out_known, *res] avals_in = [v.aval for v in jaxpr_known.invars] - jaxpr_known, _, () = pe.trace_to_jaxpr_dynamic(known, avals_in) + jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic(known, avals_in) @lu.wrap_init def staged(*args): @@ -1505,7 +1505,7 @@ def _add_reshapes(which, jaxpr_known, jaxpr_staged): res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval for w, v in zip(which_, jaxpr_staged.invars[:len(which)])] avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]] - jaxpr_staged, _, () = pe.trace_to_jaxpr_dynamic(staged, avals_in) + jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(staged, avals_in) return jaxpr_known, jaxpr_staged @@ 
-1773,7 +1773,7 @@ def _replication_rewrite_match( f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) f = _match_rep(f, mesh, out_rep, out_rep_dst) with core.extend_axis_env_nd(mesh.shape.items()): - jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) + jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) return core.ClosedJaxpr(jaxpr_, consts) # TODO(mattjj): caching @@ -1785,7 +1785,7 @@ def _replication_rewrite_nomatch( f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)) f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) with core.extend_axis_env_nd(mesh.shape.items()): - jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) + jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) return core.ClosedJaxpr(jaxpr_, consts), out_rep() @lu.transformation_with_aux diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 650201fa6..d5fcff0e4 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -447,7 +447,7 @@ def sparsify_raw(f): spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue) in_avals_flat = spvalues_to_avals(spenv, spvalues_flat) wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree) - jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) + jaxpr, out_avals_flat, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat) result = eval_sparse(jaxpr, consts, spvalues_flat, spenv) if len(out_avals_flat) != len(result): raise Exception("Internal: eval_sparse does not return expected number of arguments. " @@ -747,7 +747,7 @@ def _sparsify_jaxpr(spenv, jaxpr, *spvalues): args = spvalues_to_arrays(spenv, spvalues) args_flat, in_tree = tree_flatten(args) avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat] - sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat) + sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat) sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts) assert out_tree is not None return sp_jaxpr, out_tree diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 0428a66dc..706f5a2fe 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -82,7 +82,7 @@ from jax._src.interpreters.partial_eval import ( result_info as result_info, sig_info as sig_info, trace_to_jaxpr as trace_to_jaxpr, - trace_to_jaxpr_dynamic as trace_to_jaxpr_dynamic, + trace_to_jaxpr_dynamic as _trace_to_jaxpr_dynamic, trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2, trace_to_jaxpr_final as trace_to_jaxpr_final, trace_to_jaxpr_final2 as trace_to_jaxpr_final2, @@ -97,4 +97,12 @@ from jax._src.interpreters.partial_eval import ( trivial_ctx as trivial_ctx, ) + +# TODO(mattjj): remove temporary shim when trace_to_jaxpr_dynamic sig stabilizes +def trace_to_jaxpr_dynamic(fun, in_avals, debug_info=None, *, keep_inputs=None): # noqa + jaxpr, out_avals, consts, () = _trace_to_jaxpr_dynamic( + fun, in_avals, debug_info, keep_inputs=keep_inputs) + return jaxpr, out_avals, consts + + from jax._src.core import Jaxpr as Jaxpr diff --git a/tests/BUILD b/tests/BUILD index adcd74fcd..ba6d2c1be 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1301,6 +1301,15 @@ jax_test( srcs = ["clear_backends_test.py"], ) +jax_test( + name = "attrs_test", + srcs = ["attrs_test.py"], + deps = [ + "//jax:experimental", + ], +) + + jax_test( name = "experimental_rnn_test", 
diff --git a/tests/BUILD b/tests/BUILD
index adcd74fcd..ba6d2c1be 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -1301,6 +1301,15 @@ jax_test(
     srcs = ["clear_backends_test.py"],
 )

+jax_test(
+    name = "attrs_test",
+    srcs = ["attrs_test.py"],
+    deps = [
+        "//jax:experimental",
+    ],
+)
+
+
 jax_test(
     name = "experimental_rnn_test",
     srcs = ["experimental_rnn_test.py"],
diff --git a/tests/attrs_test.py b/tests/attrs_test.py
new file mode 100644
index 000000000..55cd7c22c
--- /dev/null
+++ b/tests/attrs_test.py
@@ -0,0 +1,64 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import jax
+
+from jax._src import config
+from jax._src import test_util as jtu
+from jax._src.util import safe_zip, safe_map
+
+from jax.experimental.attrs import jax_setattr, jax_getattr
+
+config.parse_flags_with_absl()
+
+map, unsafe_map = safe_map, map
+zip, unsafe_zip = safe_zip, zip
+
+@dataclass
+class Thing:
+  x: float
+
+class AttrsTest(jtu.JaxTestCase):
+
+  @parameterized.parameters([True, False])
+  def test_basic(self, jit: bool):
+    thing = Thing(1.0)
+
+    def double_it() -> None:
+      cur_x = jax_getattr(thing, "x")
+      jax_setattr(thing, "x", cur_x * 2)
+
+    if jit:
+      double_it = jax.jit(double_it)
+
+    self.assertEqual(thing.x, 1.0)
+    double_it()
+    self.assertEqual(thing.x, 2.0)
+    double_it()
+    self.assertEqual(thing.x, 4.0)
+    double_it()
+    self.assertEqual(thing.x, 8.0)
+    double_it()
+    self.assertEqual(thing.x, 16.0)
+
+
+if __name__ == '__main__':
+  absltest.main(testLoader=jtu.JaxTestLoader())
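The new test pins down the user-facing contract: jax_getattr and jax_setattr on an ordinary Python object work both eagerly and under jax.jit, with jit in effect threading the attribute value through the trace so the mutation is visible after each call. A slightly larger sketch in the same style (the Counter class and step function are illustrative, not part of the patch):

    from dataclasses import dataclass

    import jax
    import jax.numpy as jnp
    from jax.experimental.attrs import jax_setattr, jax_getattr

    @dataclass
    class Counter:
      count: jax.Array

    counter = Counter(jnp.zeros(()))

    @jax.jit
    def step(x):
      # read object state, update it, and use it in the computation
      c = jax_getattr(counter, "count")
      jax_setattr(counter, "count", c + 1)
      return x * (c + 1)

    print(step(jnp.float32(3.)))  # 3.0 on the first call
    print(counter.count)          # 1.0: the setattr escaped the jit boundary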
diff --git a/tests/core_test.py b/tests/core_test.py
index b3e3e1cf2..c96946180 100644
--- a/tests/core_test.py
+++ b/tests/core_test.py
@@ -624,7 +624,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
     def f(x, y):
       return x, y

-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
         f, [n, a, b], keep_inputs=[False, True, True])

     self.assertLen(jaxpr.invars, 3)
@@ -648,7 +648,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
         return (x, w)
       return g(x, y, x, y)

-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
         f, [n, a, b], keep_inputs=[False, True, True])

     self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
@@ -684,7 +684,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
         return (x, w)
       return g(x.shape[0], x, y, x, y)

-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
         f, [n, a, b], keep_inputs=[False, True, True])

     # { lambda ; a:i32[] b:f32[a] c:f32[a]. let
@@ -720,7 +720,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
       u = lax_internal._reduce_sum(w, [0])
       return (u,)

-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
         f, [n, a, b], keep_inputs=[False, True, True])

     self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
@@ -745,7 +745,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
       def g(x): return x
       return g(a),

-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
         f, [n, m, a, b], keep_inputs=[False, False, True, True])

     # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
     #     e:f32[a] = xla_call[
diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py
index 8332cfcc2..fe0b0ba97 100644
--- a/tests/pallas/pallas_call_tpu_test.py
+++ b/tests/pallas/pallas_call_tpu_test.py
@@ -318,7 +318,7 @@ class PallasCallDMATest(parameterized.TestCase):
       pltpu.run_scoped(body, pltpu.VMEM((8,), jnp.float32))
       return []

-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(kernel),
         [
             state.shaped_array_ref((8,), jnp.float32),
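Everything from here on is tests/state_test.py, where the same four-tuple unpacking is applied throughout. These tests trace functions that read and write Refs and then discharge the state into a pure jaxpr; attrs and Refs are separate mechanisms, so the only change is the extra () at each call site. A compressed sketch of the trace-then-discharge pattern the suite exercises, with import paths guessed from the names used in the diff (shaped_array_ref, discharge_state):

    import jax.numpy as jnp
    from jax._src import linear_util as lu
    from jax._src.interpreters import partial_eval as pe
    from jax._src.state import discharge_state, shaped_array_ref  # assumed paths

    def f(x_ref):
      x_ref[()] = x_ref[()] + 1.  # write through the Ref
      return []

    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f), [shaped_array_ref((), jnp.float32)])
    # Discharging turns the Ref into an explicit input/output of the jaxpr.
    discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)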
diff --git a/tests/state_test.py b/tests/state_test.py
index 13893d1af..c6088d237 100644
--- a/tests/state_test.py
+++ b/tests/state_test.py
@@ -132,7 +132,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
       with self.assertRaises(Exception):
         pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval])
     else:
-      jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
+      jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
           lu.wrap_init(f), [ref_aval])
       self.assertSetEqual(jaxpr.effects,
                           {ReadEffect(len(jaxpr.constvars))})
@@ -229,7 +229,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
       with self.assertRaises(Exception):
         pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
     else:
-      jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
+      jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval])
       self.assertSetEqual(jaxpr.effects, {WriteEffect(len(jaxpr.constvars))})
@@ -306,7 +306,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
       with self.assertRaises(Exception):
         pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
     else:
-      jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
+      jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval])
       self.assertSetEqual(jaxpr.effects, {AccumEffect(len(jaxpr.constvars))})
@@ -326,7 +326,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
       x[()] = jnp.int32(1)
       x[()] = jnp.int32(2)
       return (x[()],)
-    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
+    jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
     self.assertLen(consts, 0)
     self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
@@ -339,7 +339,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
     def body(x):
       ref_addupdate(x, (), jnp.int32(1))
       return (x[()],)
-    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
+    jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
     self.assertLen(consts, 0)
     self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
@@ -349,14 +349,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
     def body(x_ref):
       x = x_ref[()]
       return [x]
-    jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
     self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))

     def body(x_ref):
       x = x_ref[:, 0]
       return [x]
-    jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32)])
     self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False))
@@ -364,14 +364,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
     def body(x_ref):
       x_ref[()] = jnp.int32(2)
       return []
-    jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
     self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))

     def body(x_ref, val):
       x_ref[:, 0] = val
       return []
-    jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
                              core.ShapedArray((1,), jnp.int32)])
     self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False))
@@ -380,14 +380,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
     def body(x_ref):
       x = ref_swap(x_ref, (), jnp.int32(2))
       return [x]
-    jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
     self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))

     def body(x_ref, val):
       x = ref_swap(x_ref, (slice(None), 0), val)
       return [x]
-    jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
                              core.ShapedArray((1,), jnp.int32)])
     self.assertIn("c:i32[1], a[:,0] <- a[:,0], b",
                   jaxpr.pretty_print(use_color=False))
@@ -397,7 +397,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
     def body(x_ref):
       ref_addupdate(x_ref, (), jnp.int32(2))
       return []
-    jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
     self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))
@@ -405,7 +405,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
     def body(x_ref, val):
       ref_addupdate(x_ref, (slice(None), 0), val)
       return []
-    jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
                              core.ShapedArray((1,), jnp.int32)])
     self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False))
@@ -422,7 +422,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):

     in_avals = [shaped_array_ref((), jnp.dtype('float32')),
                 shaped_array_ref((), jnp.dtype('float32'))]
-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)

     self.assertEqual(jaxpr.eqns[0].primitive, get_p)
     self.assertEqual(jaxpr.eqns[1].primitive, get_p)
@@ -438,7 +438,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):

     in_avals = [shaped_array_ref((), jnp.dtype('float32')),
                 shaped_array_ref((), jnp.dtype('float32'))]
-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
     self.assertEqual(jaxpr.eqns[0].primitive, get_p)
     self.assertEqual(jaxpr.eqns[1].primitive, get_p)
     self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)
@@ -458,7 +458,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):

     in_avals = [shaped_array_ref((), jnp.dtype('float32')),
                 shaped_array_ref((), jnp.dtype('float32'))]
-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
     self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p)
     self.assertEqual(jaxpr.eqns[1].primitive, addupdate_p)
     self.assertEqual(jaxpr.eqns[2].primitive, get_p)
@@ -529,12 +529,12 @@ class StatePrimitivesTest(jtu.JaxTestCase):
     # discharge-of-vmap
     f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims),
                          out_axes=[out_bdim])
-    stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
+    stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f_batched), [bat_ref_aval, *bat_idx_avals])
     jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
     discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs)
     # vmap-of-discharge
-    stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
+    stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
        lu.wrap_init(f), [ref_aval, *idx_avals])
     jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
     f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
@@ -553,7 +553,7 @@ class StateDischargeTest(jtu.JaxTestCase):
       a = ref_get(a_ref, ())
       return [a + 1]
     in_avals = [shaped_array_ref((), jnp.dtype('float32'))]
-    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
+    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     # Discharging should just turn this into a jaxpr that just adds 1.
     discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
@@ -569,7 +569,7 @@ class StateDischargeTest(jtu.JaxTestCase):
       a = ref_get(a_ref, (0, 1))
       return [a + 1]
     in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
-    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
+    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     # Discharging should just turn this into a jaxpr that just adds 1.
     discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
@@ -588,7 +588,7 @@ class StateDischargeTest(jtu.JaxTestCase):
       a = a_ref[jnp.array([0, 1])]
       return [a + 1]
     in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
-    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f), in_avals)
     discharged_jaxpr, discharged_consts = discharge_state(
         stateful_jaxpr, consts)
@@ -603,7 +603,7 @@ class StateDischargeTest(jtu.JaxTestCase):
       return []
     in_avals = [shaped_array_ref((), jnp.dtype('float32')),
                 core.ShapedArray((), jnp.dtype('float32'))]
-    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
+    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     # Discharging should just turn this into a jaxpr that ignores the first
     # value and returns second value plus 1.
@@ -620,7 +620,7 @@ class StateDischargeTest(jtu.JaxTestCase):
       ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
       return []
     in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
-    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
+    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     # Discharging should just turn this into a jaxpr that just adds 1.
     discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
@@ -640,7 +640,7 @@ class StateDischargeTest(jtu.JaxTestCase):
       a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
       return []
     in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
-    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
+    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     discharged_jaxpr, discharged_consts = discharge_state(
         stateful_jaxpr, consts)
@@ -654,7 +654,7 @@ class StateDischargeTest(jtu.JaxTestCase):
       return []
     in_avals = [shaped_array_ref((), jnp.dtype('float32')),
                 core.ShapedArray((), jnp.dtype('float32'))]
-    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
+    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     # Discharging should just turn this into a jaxpr that adds the first value,
     # second value, and 1.
@@ -672,7 +672,7 @@ class StateDischargeTest(jtu.JaxTestCase):
               jnp.ones(2, dtype=jnp.dtype('float32')))
       return []
     in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
-    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
+    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
     self.assertLen(discharged_jaxpr.invars, 1)
@@ -693,7 +693,7 @@ class StateDischargeTest(jtu.JaxTestCase):
                     jnp.ones((2, 3), 'float32'))
       return []
     in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
-    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
+    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     discharged_jaxpr, discharged_consts = discharge_state(
         stateful_jaxpr, consts)
@@ -707,7 +707,7 @@ class StateDischargeTest(jtu.JaxTestCase):
       b = a + 1
       return [a, b]
     in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))]
-    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
+    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
     self.assertLen(discharged_jaxpr.invars, 1)
@@ -727,7 +727,7 @@ class StateDischargeTest(jtu.JaxTestCase):
         shaped_array_ref((4,), jnp.dtype('float32')),
         shaped_array_ref((4,), jnp.dtype('float32'))
     ]
-    stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
+    stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
                                                           in_avals)
     discharged_jaxpr, _ = discharge_state(
         stateful_jaxpr, consts, should_discharge=[False, True])
@@ -929,13 +929,13 @@ if CAN_USE_HYPOTHESIS:
       ref = get_vmap_param.bat_ref
       f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims),
                            out_axes=[out_bdim])
-      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
+      stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f_batched), [bat_ref_aval, *bat_non_slice_idx_avals])
       jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
       discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref,
                                               *non_slice_idx)
       # vmap-of-discharge
-      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
+      stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, *idx_avals])
       jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
       f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
@@ -974,13 +974,13 @@ if CAN_USE_HYPOTHESIS:
       f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
                            out_axes=[])
-      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
+      stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f_batched),
          [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
       jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
       discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val,
                                               *non_slice_idx)
       # vmap-of-discharge
-      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
+      stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
       jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
       f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
@@ -1018,13 +1018,13 @@ if CAN_USE_HYPOTHESIS:
       f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
                            out_axes=[])
-      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
+      stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f_batched),
          [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
       jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
       discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val,
                                               *non_slice_idx)
       # vmap-of-discharge
-      stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
+      stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
          lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
       jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
       f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
@@ -1323,7 +1323,7 @@ class GeneralRefTest(jtu.JaxTestCase):
       x_ref[...] = x
       ref_addupdate(x_ref, (), x)
       return [x]
-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f), [AbstractRef(core.UnshapedArray(jnp.int32))])
     self.assertIs(type(jaxpr.outvars[0].aval), core.UnshapedArray)
     self.assertEqual(jaxpr.outvars[0].aval.dtype, jnp.dtype("int32"))
@@ -1334,7 +1334,7 @@ class GeneralRefTest(jtu.JaxTestCase):
       x_ref[...] = x
       ref_addupdate(x_ref, (), x)
       return [x]
-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f), [AbstractRef(core.AbstractToken())])
     self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken)
@@ -1343,7 +1343,7 @@ class GeneralRefTest(jtu.JaxTestCase):
       x_ref = x_ref_ref[...]
       return [x_ref]  # Not sure why you'd ever want to do this, but it works!
-    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
+    jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
         lu.wrap_init(f),
        [AbstractRef(AbstractRef(core.ShapedArray((), jnp.int32)))])
     self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef)
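End to end, the observable behavior the patch introduces is the one the new test locks in: each call to a jitted function that uses jax_setattr updates the Python object, and the update compounds across calls. One last usage sketch (the accumulate function is illustrative, not part of the patch):

    from dataclasses import dataclass

    import jax
    from jax.experimental.attrs import jax_setattr, jax_getattr

    @dataclass
    class Thing:
      x: float

    thing = Thing(1.0)

    @jax.jit
    def accumulate(amt):
      # the attr update is traced and applied on the host after the call
      jax_setattr(thing, "x", jax_getattr(thing, "x") + amt)

    accumulate(2.0)
    accumulate(3.0)
    assert thing.x == 6.0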