integrate attrs in jax.jit

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
Matthew Johnson 2024-01-25 22:20:36 -08:00
parent d4660a0972
commit 4a8babb101
36 changed files with 347 additions and 165 deletions

View File

@ -241,6 +241,7 @@ py_library_providing_imports_info(
"_src/internal_test_util/**",
],
) + [
"experimental/attrs.py",
# until new parallelism APIs are moved out of experimental
"experimental/maps.py",
"experimental/pjit.py",

View File

@ -372,7 +372,7 @@ def _trace_to_jaxpr(fun, in_tree, in_avals):
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, out_tree, True, "checkpoint")
try:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
except core.ConcretizationTypeError as e:
msg, = e.args
if 'for checkpoint' not in msg:
@ -620,7 +620,7 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros, reduce_axes):
in_cts_nz, _ = partition_list(in_zeros, in_cts)
return in_cts_nz
transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
return transposed_jaxpr, cell.in_cts_zero # type: ignore

View File

@ -560,7 +560,7 @@ def xla_computation(fun: Callable,
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
ordered_effects = list(

View File

@ -746,7 +746,7 @@ def jaxpr_to_checkify_jaxpr(
fun = lu.wrap_init(checkify_jaxpr_partial)
fun, metadata = _flatten_and_get_error_metadata_thunk(fun)
new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
out_tree, error_effects = metadata()
return checked_jaxpr, out_tree, error_effects
@ -832,7 +832,7 @@ def checkify_while_body_jaxpr(
return out
new_body_f_ = lu.wrap_init(new_body_f)
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
*body_jaxpr.in_avals])
closed_jaxpr = pe.close_jaxpr(jaxpr)
err_vals, err_tree = jtu.tree_flatten(error)
@ -1128,7 +1128,7 @@ def checkify(f: Callable[..., Out],
# stage:
fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
debug = pe.debug_info(closed_f, in_tree, out_tree, False, 'checkify')
jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
# checkify:
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)

View File

@ -69,7 +69,7 @@ class custom_vmap:
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
in_tree = treedef_tuple((tree_structure(consts), in_tree))
assert self.vmap_rule is not None

View File

@ -66,7 +66,7 @@ def _resolve_kwargs(fun, args, kwargs):
return ba.args
def _initial_style_jaxpr(fun, in_avals):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return jaxpr, consts
def _close_jaxpr(jaxpr):
@ -977,7 +977,7 @@ def custom_gradient(fun):
ans_flat, out_tree = tree_flatten((ans,))
rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
def bwd(res, cts):
@ -1102,7 +1102,7 @@ def _maybe_perturbed(x: Any) -> bool:
@cache()
def _closure_convert_for_avals(fun, in_tree, in_avals):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
jaxpr, out_pvals, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
out_tree = out_tree()
(closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)

View File

@ -697,7 +697,7 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
nonzeros)
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
@lu.transformation_with_aux

View File

@ -789,7 +789,7 @@ def _batch_jaxpr2(
avals_in2 = [core.unmapped_aval(axis_size, axis_name, b, aval)
if b is not not_mapped else aval
for aval, b in unsafe_zip(avals_in, in_axes2)]
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in2)
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)
return core.ClosedJaxpr(jaxpr_out, consts), out_axes()
def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis,
@ -834,7 +834,7 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
main_type)
avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped
else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
@lu.transformation_with_aux

View File

@ -1797,7 +1797,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
else:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
out, tokens = jaxpr_subcomp(

View File

@ -663,7 +663,7 @@ def _closed_call_param_updater(params, _, __):
call_param_updaters[core.closed_call_p] = _closed_call_param_updater
def abstract_eval_fun(fun, *avals, debug_info=None, **params):
_, avals_out, _ = trace_to_jaxpr_dynamic(
_, avals_out, _, () = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params), avals, debug_info)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
return avals_out
@ -1113,7 +1113,7 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
return [*known_vals_out, *residuals]
known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]
jaxpr_known, _, consts_known = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
(out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking
# check jaxpr_known and jaxpr_unknown in isolation
@ -1754,6 +1754,9 @@ class JaxprStackFrame:
eqns: list[JaxprEqn]
invars: list[Var]
effects: core.Effects
attrs_tracked: list[tuple[Any, str]]
attrs_inits: list
attrs_vars: list[Var]
debug_info: DebugInfo | None
def __init__(self):
@ -1765,23 +1768,29 @@ class JaxprStackFrame:
self.eqns = [] # cleared when we pop frame from main
self.invars = []
self.effects = set()
self.attrs_tracked = []
self.attrs_inits = []
self.attrs_vars = []
self.debug_info = None
def add_eqn(self, eqn: core.JaxprEqn):
self.eqns.append(eqn)
def to_jaxpr(self, out_tracers: Sequence[Tracer]) -> tuple[Jaxpr, list[Any]]:
def to_jaxpr(self, out_tracers: Sequence[Tracer]
) -> tuple[Jaxpr, list[Any], list[tuple[Any, str]]]:
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
constvals: Sequence[Any]
invars = self.attrs_vars + self.invars
state_outvars = [self.tracer_to_var[id(t)] for t in get_states(self.attrs_tracked)]
explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
outvars = state_outvars + explicit_outvars
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars,
self.eqns)
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects)
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns)
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects)
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
return jaxpr, list(constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore
set_states(self.attrs_tracked, self.attrs_inits)
return jaxpr, list(constvals), self.attrs_tracked
def to_jaxpr2(self, out_tracers):
# It's not necessary, but we keep the tracer-to-var mapping injective:
@ -2064,7 +2073,7 @@ class DynamicJaxprTrace(core.Trace):
for a, in_axis in zip(in_avals, params['in_axes'])]
with core.extend_axis_env(axis_name, params["global_axis_size"], None): # type: ignore
with core.new_sublevel():
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
jaxpr, reduced_out_avals, consts, () = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals,
debug_info=debug_info_final(f, map_primitive.name))
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
@ -2098,7 +2107,7 @@ class DynamicJaxprTrace(core.Trace):
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
main_ = ref(self.main)
@ -2108,7 +2117,7 @@ class DynamicJaxprTrace(core.Trace):
nz_tangent_avals, zero_avals = partition_list(in_zeros, in_avals)
jvp_, out_zeros = _jvp_jaxpr_zeros(jvp, in_zeros, tuple(zero_avals))
in_avals_ = (*in_avals, *nz_tangent_avals)
jaxpr, _, out_consts = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_)
jaxpr, _, out_consts, () = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_)
return jaxpr, out_consts, out_zeros()
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
@ -2132,7 +2141,7 @@ class DynamicJaxprTrace(core.Trace):
symbolic_zeros):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
main_ = ref(self.main)
@ -2170,7 +2179,7 @@ class DynamicJaxprTrace(core.Trace):
in_avals_t = [*[t.aval for t in tracers_res], *out_types]
with core.new_sublevel():
call_jaxpr, out_avals, call_consts = trace_to_subjaxpr_dynamic(
call_jaxpr, out_avals, call_consts, () = trace_to_subjaxpr_dynamic(
call, self.main, in_avals_p)
closed_call_jaxpr = core.ClosedJaxpr(
convert_constvars_jaxpr(call_jaxpr), ())
@ -2183,8 +2192,8 @@ class DynamicJaxprTrace(core.Trace):
@_memoize
def transpose_jaxpr_thunk():
for store in transpose_flat.stores: store.reset()
jaxpr, _, consts = trace_to_subjaxpr_dynamic(transpose_flat, main_(),
in_avals_t)
jaxpr, _, consts, () = trace_to_subjaxpr_dynamic(
transpose_flat, main_(), in_avals_t)
return jaxpr, consts
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
@ -2299,13 +2308,13 @@ def trace_to_jaxpr_dynamic(
debug_info: DebugInfo | None = None,
*,
keep_inputs: list[bool] | None = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]:
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
del main, fun
return jaxpr, out_avals, consts
return jaxpr, out_avals, consts, attrs_tracked
def trace_to_subjaxpr_dynamic(
@ -2315,7 +2324,7 @@ def trace_to_subjaxpr_dynamic(
*,
keep_inputs: Sequence[bool] | None = None,
debug_info: DebugInfo | None = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]:
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
@ -2326,10 +2335,10 @@ def trace_to_subjaxpr_dynamic(
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
ans = fun.call_wrapped(*in_tracers_)
out_tracers = map(trace.full_raise, ans)
jaxpr, consts = frame.to_jaxpr(out_tracers)
jaxpr, consts, attrs_tracked = frame.to_jaxpr(out_tracers)
del fun, main, trace, frame, in_tracers, out_tracers, ans
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, [v.aval for v in jaxpr.outvars], consts
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
@profiler.annotate_function
@ -2380,7 +2389,7 @@ def trace_to_jaxpr_final(
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(
fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
del fun, main
return jaxpr, out_avals, consts
@ -2404,6 +2413,15 @@ AbstractedAxesSpec = Union[
tuple[AbstractedAxisName, ...],
]
# Each tracked attribute is identified by an (object, attribute-name) pair.
AttrsTracked = list[tuple[Any, str]]
# Values for the tracked attributes, in the same order as the pairs above.
AttrStates = list


def set_states(attrs_tracked: AttrsTracked, vals: AttrStates) -> None:
  """Assign each value in `vals` to the corresponding tracked attribute.

  Args:
    attrs_tracked: `(obj, attr_name)` pairs naming the attributes to write.
    vals: one value per pair, in the same order.

  Raises:
    ValueError: if `attrs_tracked` and `vals` differ in length. The check runs
      before any `setattr`, so a mismatch never leaves the objects partially
      updated.
  """
  if len(attrs_tracked) != len(vals):
    raise ValueError(
        f"set_states got {len(attrs_tracked)} tracked attributes but "
        f"{len(vals)} values")
  for (obj, attr), val in zip(attrs_tracked, vals):
    setattr(obj, attr, val)


def get_states(attrs_tracked: AttrsTracked) -> AttrStates:
  """Return the current value of every tracked attribute, in order."""
  return [getattr(obj, attr) for obj, attr in attrs_tracked]
def infer_lambda_input_type(
axes_specs: Sequence[AbstractedAxesSpec] | None,
@ -2629,7 +2647,7 @@ def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
in_avals = [substitute(v.aval) for v in jaxpr.invars]
eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
padded_jaxpr, _, padded_consts = trace_to_jaxpr_dynamic(eval_padded, in_avals)
padded_jaxpr, _, padded_consts, () = trace_to_jaxpr_dynamic(eval_padded, in_avals)
return padded_jaxpr, padded_consts
class BoundedAxisSize(NamedTuple):

View File

@ -58,7 +58,7 @@ def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, out_tree, False,
primitive_name or "<unknown>")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
return jaxpr, consts, out_tree()
@weakref_lru_cache
@ -226,7 +226,7 @@ def _prune_zeros(ts):
return [t for t in ts if type(t) is not ad_util.Zero]
def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
return core.ClosedJaxpr(jaxpr, consts)
def _show_diff(array1, array2):

View File

@ -88,7 +88,7 @@ def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
consts = map(lambda x: ref_get(x, ()), consts)
all_consts = merge_lists(is_const_ref, consts, const_refs)
return core.eval_jaxpr(jaxpr, all_consts, i, *args)
hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_hoist), in_avals)
assert not consts, "All consts should have been converted to refs"
return hoisted_jaxpr
@ -98,7 +98,7 @@ def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
f, out_tree_thunk = flatten_fun_nokwargs(
lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
f, state_avals)
return jaxpr, consts, out_tree_thunk()
@ -607,7 +607,7 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
return for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps,
reverse=reverse, which_linear=jaxpr_known_which_linear,
unroll=unroll)
call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
known, [v.aval for v in known_invars])
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars],
@ -629,7 +629,7 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
_, ans = split_list(out_flat, [num_res])
_, ans = partition_list(out_inst, ans)
return ans
call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
staged, [v.aval for v in [*resvars, *eqn.invars]])
assert len(jaxpr_staged.invars) - 1 == len(call_jaxpr_.invars)
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
@ -667,7 +667,7 @@ def _convert_outputs_to_writes(
AbstractRef(core.ShapedArray((nsteps, *v.aval.shape), # pytype: disable=attribute-error
v.aval.dtype)) # pytype: disable=attribute-error
for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)]
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*in_avals, *res_ref_avals])
assert not consts
return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] # pytype: disable=attribute-error
@ -693,7 +693,7 @@ def _convert_inputs_to_reads(
aval.dtype)) # pytype: disable=attribute-error
for aval, loop_invar in zip(res_val_avals, loop_invar_res)]
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])
return jaxpr
@ -720,7 +720,7 @@ def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr:
ad.backward_pass(
tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
return []
jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr_trans, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
return jaxpr_trans

View File

@ -985,7 +985,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res])
out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known)
return [*intensive_res, *out_loop]
call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
known, [v.aval for v in ins_known])
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
eqn_known = pe.new_jaxpr_eqn(
@ -1108,8 +1108,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
new_in_avals = [*remaining_const_avals, *[a.inner_aval for a in in_ref_avals],
*carry_avals,
*[core.mapped_aval(length, 0, a) for a in xs_avals]]
new_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped),
new_in_avals)
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), new_in_avals)
new_linear = (*remaining_consts_linear, *in_refs_linear,
*carry_linear, *xs_linear)
all_out = scan_p.bind(*remaining_consts, *in_refs, *carry, *xs,
@ -1768,7 +1767,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
*carry)
carry, refs_out = split_list(carry_refs, [num_carry])
return [*refs_out, *carry]
new_body_jaxpr, _, new_body_consts = pe.trace_to_jaxpr_dynamic(
new_body_jaxpr, _, new_body_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(new_body), [*remaining_body_const_avals, *[a.inner_aval for a
in ref_avals],
*carry_avals])
@ -1782,7 +1781,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
consts_refs_carry, [cond_nconsts, num_refs])
del refs # We don't use them here!
return core.eval_jaxpr(cond_jaxpr, cond_jaxpr_consts, *consts, *carry)
new_cond_jaxpr, _, new_cond_consts = pe.trace_to_jaxpr_dynamic(
new_cond_jaxpr, _, new_cond_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(new_cond), [*cond_consts_avals,
*[a.inner_aval for a in ref_avals],
*carry_avals])

View File

@ -1027,7 +1027,7 @@ def _reduction_jaxpr(computation, aval):
f"Reduction functions should only return an array.\n"
f"Full return value: {result}")
return (result,)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(comp, (aval, aval))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp, (aval, aval))
if any(isinstance(c, core.Tracer) for c in consts):
raise NotImplementedError(
"Reduction computations can't close over Tracers. Please open an issue "
@ -1040,7 +1040,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
flat_in_avals, in_tree = tree_util.tree_flatten((avals, avals))
comp = lu.wrap_init(computation)
flat_comp, out_tree = api_util.flatten_fun_nokwargs(comp, in_tree)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
if any(isinstance(c, core.Tracer) for c in consts):
raise NotImplementedError(
"Reduction computations can't close over Tracers. Please open an issue "

View File

@ -999,7 +999,7 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
for a, a_in_axes in zip(in_avals, params['in_axes'])]
with core.extend_axis_env_nd(global_axis_sizes.items()):
with core.new_sublevel():
jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
jaxpr, mapped_out_avals, consts, () = trace_to_subjaxpr_dynamic(
f, self.main, mapped_in_avals)
out_axes = params['out_axes_thunk']()
if params['spmd_out_axes_thunk'] is not None:
@ -1340,7 +1340,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
# NOTE: We don't extend the resource env with the mesh shape, because those
# resources are already in scope! It's the outermost xmap that introduces
# them!
vectorized_jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(f, local_avals)
vectorized_jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, local_avals)
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
@ -1415,7 +1415,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
add_spmd_axes(mesh_in_axes, spmd_in_axes)
add_spmd_axes(mesh_out_axes, spmd_out_axes)
global_in_avals = ctx.avals_in
vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
vectorized_jaxpr, global_out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
sharded_global_in_nodes = [
[mlir.wrap_with_sharding_op(
@ -1477,7 +1477,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
# resources are already in scope! It's the outermost xmap that introduces
# them!
global_in_avals = ctx.avals_in
vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
vectorized_jaxpr, global_out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
# We in-line here rather than generating a Call HLO as in the xla_call

View File

@ -192,7 +192,7 @@ def _convert_block_spec_to_block_mapping(
block_shape = tuple(
mapped if s is None else s for s in block_shape)
flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts))
def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None

View File

@ -543,7 +543,7 @@ def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
def f_lowered(ctx: LoweringRuleContext, *args, **params):
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
if consts:
raise NotImplementedError
jaxpr = pe.convert_constvars_jaxpr(jaxpr)

View File

@ -103,7 +103,7 @@ def run_scoped(f: Callable[..., None], *types, **kw_types) -> None:
flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
flat_fun, _ = api_util.flatten_fun(lu.wrap_init(f), in_tree)
avals = map(lambda t: t.get_aval(), flat_types)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
run_scoped_p.bind(*consts, jaxpr=jaxpr)

View File

@ -26,11 +26,11 @@ from jax import api_util
from jax import tree_util
from jax import lax
from jax._src import state
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src.state import primitives as sp
@ -258,7 +258,7 @@ def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray,
idx_avals = [i32_aval] * (len(grid) + 1)
else:
idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
block_mapping_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function), idx_avals)
shape = aval.shape if block_mapping is None else block_mapping.block_shape
if dim is batching.not_mapped:
@ -387,7 +387,7 @@ def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
consts = map(lambda x: sp.ref_get(x, ()), consts)
all_consts = merge_lists(is_const_ref, consts, const_refs)
return jax_core.eval_jaxpr(jaxpr, all_consts, *args)
hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_hoist), in_avals)
assert not consts, "All consts should have been converted to refs"
return hoisted_jaxpr
@ -401,8 +401,7 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals,
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun), jaxpr_in_tree)
debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals, debug)
jaxpr = _hoist_consts_to_refs(jaxpr)
return grid_mapping, jaxpr, consts, out_tree_thunk()

View File

@ -53,8 +53,8 @@ from jax._src.util import merge_lists
from jax._src.util import partition_list
from jax._src.util import split_list
from jax._src.util import weakref_lru_cache
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
import jax.numpy as jnp
from jax_triton import triton_lib
from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
@ -424,7 +424,7 @@ def _associative_scan_lowering(
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(body), in_tree
)
combine_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
flat_fun, in_avals
)
out_tree = out_tree_thunk()
@ -572,7 +572,7 @@ def lower_fun(
def f_lowered(ctx: TritonLoweringRuleContext, *args, **params):
wrapped_fun = lu.wrap_init(fn, params)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
return out if multiple_results else out[0]
@ -998,7 +998,7 @@ def _reduction_lowering(body, ctx: TritonLoweringRuleContext, a, axes):
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(body), in_tree
)
combine_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
flat_fun, [*mapped_avals, *mapped_avals]
)
out_tree = out_tree_thunk()

View File

@ -21,7 +21,7 @@ import itertools as it
import logging
import operator as op
import weakref
from typing import Callable, cast, NamedTuple, Any, Union
from typing import Callable, cast, NamedTuple, Any, Union, Optional
import threading
import warnings
@ -129,10 +129,13 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
args_flat, _, params, _, out_tree, _, _, _, arg_names = infer_params_fn(
*args, **kwargs)
args_flat, _, params, _, out_tree, _, _, _, arg_names, attrs_tracked = \
infer_params_fn(*args, **kwargs)
for arg in args_flat:
dispatch.check_arg(arg)
if attrs_tracked:
init_states = _get_states(attrs_tracked)
args_flat = [*init_states, *args_flat]
try:
out_flat = pjit_p.bind(*args_flat, **params)
except pxla.DeviceAssignmentMismatchError as e:
@ -142,8 +145,20 @@ def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, arg_names)
raise ValueError(msg) from None
if attrs_tracked:
final_states, out_flat = split_list(out_flat, [len(attrs_tracked)])
_set_states(attrs_tracked, final_states)
outs = tree_unflatten(out_tree, out_flat)
return outs, out_flat, out_tree, args_flat, params['jaxpr']
return outs, out_flat, out_tree, args_flat, params['jaxpr'], attrs_tracked
def _set_states(attrs_tracked, vals):
  """Write each value in `vals` to its tracked `(obj, attr)` pair.

  Uses `jax_setattr` from `jax.experimental.attrs` instead of the builtin
  `setattr`. Pairs and values are matched positionally.
  """
  # Imported at call time, exactly as the original does, not at module level.
  from jax.experimental.attrs import jax_setattr  # type: ignore
  for target, new_val in zip(attrs_tracked, vals):
    owner, name = target
    jax_setattr(owner, name, new_val)


def _get_states(attrs_tracked):
  """Read the current value of every tracked `(obj, attr)` pair, in order.

  Uses `jax_getattr` from `jax.experimental.attrs` instead of the builtin
  `getattr`, and returns the values as a list.
  """
  from jax.experimental.attrs import jax_getattr  # type: ignore
  states = []
  for owner, name in attrs_tracked:
    states.append(jax_getattr(owner, name))
  return states
def _python_pjit(fun: Callable, infer_params_fn):
@ -161,7 +176,8 @@ def _python_pjit(fun: Callable, infer_params_fn):
return wrapped
def _get_fastpath_data(executable, out_tree, args_flat, out_flat):
def _get_fastpath_data(executable, out_tree, args_flat, out_flat, attrs_tracked,
) -> Optional[pxla.MeshExecutableFastpathData]:
out_flat, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
use_fastpath = (
@ -172,7 +188,9 @@ def _get_fastpath_data(executable, out_tree, args_flat, out_flat):
not executable.unsafe_call.ordered_effects and
not executable.unsafe_call.has_unordered_effects and
not executable.unsafe_call.has_host_callbacks and
all(isinstance(x, xc.ArrayImpl) for x in out_flat)
all(isinstance(x, xc.ArrayImpl) for x in out_flat) and
# no attr state effects
not attrs_tracked
)
if use_fastpath:
@ -224,11 +242,12 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
@api_boundary
def cache_miss(*args, **kwargs):
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
fun, infer_params_fn, *args, **kwargs)
executable = _read_most_recent_pjit_call_executable(jaxpr)
fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
return outs, fastpath_data
maybe_fastpath_data = _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked)
return outs, maybe_fastpath_data
if xla_extension_version >= 226:
cpp_pjit_f = xc._xla.pjit( # type: ignore
@ -312,7 +331,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
out_layouts = kwargs.pop('_out_layouts', None)
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
donated_invars, in_layouts_flat, out_layouts_flat,
arg_names) = infer_params_fn(
arg_names, ()) = infer_params_fn(
*args, **kwargs, _in_layouts=in_layouts, _out_layouts=out_layouts)
resource_env = params['resource_env']
mesh = None if resource_env is None else resource_env.physical_mesh
@ -339,7 +358,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
@api_boundary
def eval_shape(*args, **kwargs):
_, _, params, _, out_tree, _, _, _, _ = infer_params_fn(
_, _, params, _, out_tree, _, _, _, _, _ = infer_params_fn(
*args, **kwargs, _in_layouts=None, _out_layouts=None)
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape)
for x in params['jaxpr'].out_avals]
@ -464,7 +483,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
hashable_pytree(in_shardings), hashable_pytree(in_layouts), in_avals,
in_tree, resource_env, dbg, device_or_backend_set, True if kwargs else False)
jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
jaxpr, consts, out_shardings, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
flat_fun, hashable_pytree(out_shardings), hashable_pytree(out_layouts),
in_type, dbg, device_or_backend_set, HashableFunction(out_tree, closure=()),
HashableFunction(res_paths, closure=()), inline)
@ -477,19 +496,19 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
implicit_args = []
args_flat = [*implicit_args, *explicit_args]
num_extra_args = len(implicit_args) + len(consts)
num_extra_args = len(implicit_args) + len(attrs_tracked) + len(consts)
canonicalized_in_shardings_flat = \
(UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
donated_invars = (False,) * num_extra_args + donated_invars
assert (len(canonicalized_in_shardings_flat) == len(in_layouts_flat) ==
len(donated_invars) == len(consts) + len(args_flat))
len(donated_invars) == len(attrs_tracked) + len(consts) + len(args_flat))
# in_shardings and out_shardings here are all GSPMDSharding.
params = dict(
jaxpr=jaxpr,
in_shardings=canonicalized_in_shardings_flat,
out_shardings=canonicalized_out_shardings_flat,
out_shardings=out_shardings,
resource_env=resource_env,
donated_invars=donated_invars,
name=getattr(flat_fun, '__name__', '<unknown>'),
@ -498,7 +517,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
)
return (consts + args_flat, in_type, params, in_tree, out_tree(),
donated_invars, in_layouts_flat, out_layouts_flat,
dbg.arg_names if dbg else None)
dbg.arg_names if dbg else None, attrs_tracked)
def _extract_implicit_args(
in_type: Sequence[tuple[core.AbstractValue, bool]],
@ -1047,11 +1066,13 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
if config.dynamic_shapes.value:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
lu.annotate(fun, in_type), debug_info=pe_debug)
attrs_tracked = []
else:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
fun, in_type, debug_info=pe_debug)
if not config.dynamic_shapes.value:
# TODO(dougalm,mattjj): enable debug info with attrs_tracked
if not config.dynamic_shapes.value and not attrs_tracked:
jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
if config.enable_key_reuse_checks.value:
@ -1065,7 +1086,7 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
else:
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
final_consts = []
return closed_jaxpr, final_consts, global_out_avals
return closed_jaxpr, final_consts, global_out_avals, attrs_tracked
@lru_cache(maxsize=4096)
@ -1108,13 +1129,12 @@ def _check_and_canonicalize_out_shardings(
def _pjit_jaxpr(fun, out_shardings_thunk, out_layouts_thunk, in_type, debug_info,
device_or_backend_set, out_tree, result_paths, inline):
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
fun, in_type, debug_info, result_paths, IgnoreKey(inline))
canonicalized_out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
out_shardings_thunk, out_layouts_thunk, out_tree, tuple(out_type),
jaxpr.jaxpr.debug_info, device_or_backend_set)
# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat
return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat, attrs_tracked
@dataclasses.dataclass(frozen=True)
@ -1357,7 +1377,7 @@ def _pjit_call_impl(*args, jaxpr,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
inline=inline)
fastpath_data = _get_fastpath_data(
compiled, tree_structure(out_flat), args, out_flat)
compiled, tree_structure(out_flat), args, out_flat, [])
return out_flat, fastpath_data
f = _get_jaxpr_as_fun(
@ -1856,7 +1876,7 @@ pe.partial_eval_jaxpr_custom_rules[pjit_p] = \
@lu.cache
def _pjit_transpose_trace(fun, in_avals):
transpose_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
transpose_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts)
return transpose_jaxpr

View File

@ -66,7 +66,7 @@ def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * ,
else v.aval for v, d in zip(jaxpr.invars, should_discharge)]
eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr,
should_discharge, consts))
new_jaxpr, _ , new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
new_jaxpr, _ , new_consts, () = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
return new_jaxpr, new_consts
@dataclasses.dataclass
@ -427,7 +427,7 @@ def _convert_outputs_to_writes(
return []
res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef)
else v.aval for v in jaxpr.outvars]
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*in_avals, *res_ref_avals])
assert not consts
return jaxpr, [core.ShapedArray(a.inner_aval.shape, a.inner_aval.dtype) # pytype: disable=attribute-error
@ -447,7 +447,7 @@ def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
split_list([v.aval for v in jaxpr.invars], [num_res])
res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else
aval for aval in res_val_avals]
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
return jaxpr
@ -648,7 +648,7 @@ def _run_state_partial_eval_custom(
def staged(*args):
out = run_state_p.bind(*args, **staged_params)
return out[num_res:]
staged_call_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(staged,
staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(staged,
[v.aval for v in res_staged_invars])
eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
staged_outvars,
@ -707,7 +707,7 @@ def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool]
ad.backward_pass(
tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
return []
jaxpr_trans, _, consts = pe.trace_to_jaxpr_dynamic(
jaxpr_trans, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
return jaxpr_trans, consts
@ -770,7 +770,7 @@ def _initial_style_jaxpr(fun, in_tree, in_avals):
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
tree_util.treedef_tuple((in_tree,)))
debug = pe.debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
return jaxpr, consts, out_tree_thunk()
def run_state(f: Callable[..., None]):

View File

@ -15,7 +15,7 @@
from collections.abc import Sequence
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import partial_eval as pe
from jax._src import core
from jax._src import linear_util as lu
from jax._src.state import AbstractRef
@ -44,7 +44,7 @@ def hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
consts = map(lambda x: ref_get(x, ()), consts)
all_consts = merge_lists(is_const_ref, consts, const_refs)
return core.eval_jaxpr(jaxpr, all_consts, *args)
hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_hoist), in_avals)
assert not consts, "All consts should have been converted to refs"
return hoisted_jaxpr

64
jax/experimental/attrs.py Normal file
View File

@ -0,0 +1,64 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from jax._src import core
from jax._src.interpreters import partial_eval as pe
# Alias used to annotate values written by jax_setattr; currently just Any.
JaxVal = Any
# Primitives bound by jax_getattr / jax_setattr. Each gets an eager impl and a
# custom staging rule further down in this module.
getattr_p = core.Primitive('getattr')
setattr_p = core.Primitive('setattr')
def jax_getattr(obj: Any, attr: str):
  """Read ``obj.<attr>`` by binding the ``getattr`` primitive.

  Both ``obj`` and ``attr`` are passed as primitive parameters (there are no
  positional operands); whatever ``getattr_p.bind`` produces is returned.
  """
  bind_params = dict(obj=obj, attr=attr)
  return getattr_p.bind(**bind_params)
def jax_setattr(obj: Any, attr: str, val: JaxVal):
  """Write ``val`` to ``obj.<attr>`` by binding the ``setattr`` primitive.

  ``val`` is the only positional operand; ``obj`` and ``attr`` are passed as
  primitive parameters. The result of the bind is discarded, so this always
  returns ``None``.
  """
  bind_params = dict(obj=obj, attr=attr)
  setattr_p.bind(val, **bind_params)
@getattr_p.def_impl
def _getattr_impl(*, obj, attr):
  """Impl registered on ``getattr_p``: a plain Python attribute read."""
  value = getattr(obj, attr)
  return value
@setattr_p.def_impl
def _setattr_impl(val, *, obj, attr):
  """Impl registered on ``setattr_p``: a plain Python attribute write."""
  setattr(obj, attr, val)
  return None
def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str):
  """Start tracking ``obj.<attr>`` in the innermost jaxpr frame, if needed.

  The first time an ``(obj, attr)`` pair is seen, the attribute's current value
  is recorded as the initial state, a fresh tracer with the matching abstract
  value is bound to a new frame variable, and the attribute is overwritten
  with that tracer. Later calls for the same pair are no-ops.

  Args:
    trace: the staging trace whose top frame records tracked attributes.
    obj: the object owning the attribute.
    attr: the attribute name.
  """
  frame = trace.main.jaxpr_stack[-1]  # type: ignore
  # Match tracked objects by identity. A plain `(obj, attr) in list` test falls
  # back to `obj.__eq__` for every entry holding a *different* object, which
  # (a) conflates distinct-but-equal objects, leaving the second one untracked,
  # and (b) can raise when `__eq__` compares fields that already hold tracers
  # (e.g. a default dataclass).
  if any(o is obj and a == attr for o, a in frame.attrs_tracked):
    return
  init_val = getattr(obj, attr)
  aval = core.raise_to_shaped(core.get_aval(init_val))
  tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current())
  var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval)
  # From here on, reads of the attribute during tracing see the tracer.
  setattr(obj, attr, tracer)
  # The three lists are kept index-aligned: pair, initial value, input var.
  frame.attrs_tracked.append((obj, attr))
  frame.attrs_inits.append(init_val)
  frame.attrs_vars.append(var)
# Attach the helper as a method on the trace class so the staging rules in this
# module can call `trace._ensure_tracked(obj, attr)`.
pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked
def _getattr_staging(trace, *, obj, attr):
  """Staging rule for ``getattr_p``.

  Makes sure ``(obj, attr)`` is tracked by ``trace`` and then returns whatever
  the attribute currently holds.
  """
  ensure = trace._ensure_tracked
  ensure(obj, attr)
  current = getattr(obj, attr)
  return current
# Register the rule for getattr_p in partial_eval's custom staging table.
# NOTE(review): presumably consulted when the primitive is bound under a
# DynamicJaxprTrace; confirm against partial_eval.
pe.custom_staging_rules[getattr_p] = _getattr_staging
def _setattr_staging(trace, tracer, *, obj, attr):
  """Staging rule for ``setattr_p``.

  Makes sure ``(obj, attr)`` is tracked by ``trace`` first (so its pre-write
  value is what gets recorded), then stores ``tracer`` on the attribute.
  Returns ``None``.
  """
  ensure = trace._ensure_tracked
  ensure(obj, attr)
  setattr(obj, attr, tracer)
# Register the rule for setattr_p in partial_eval's custom staging table.
# NOTE(review): presumably consulted when the primitive is bound under a
# DynamicJaxprTrace; confirm against partial_eval.
pe.custom_staging_rules[setattr_p] = _setattr_staging

View File

@ -459,7 +459,7 @@ class custom_partitioning:
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, out_tree, False,
"custom_partitioning")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
out_flat = custom_partitioning_p.bind(

View File

@ -718,7 +718,7 @@ def _jet_jaxpr(
) -> tuple[core.ClosedJaxpr, Any]:
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f_jet, out_tree_def = traceable(jet_fun(jet_subtrace(f), order), in_tree_def)
jaxpr_jet, _, consts = pe.trace_to_jaxpr_dynamic(
jaxpr_jet, _, consts, () = pe.trace_to_jaxpr_dynamic(
f_jet, primals_and_series_avals)
return core.ClosedJaxpr(jaxpr_jet, consts), out_tree_def

View File

@ -181,7 +181,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignature
args_flat, in_tree = tree_util.tree_flatten(args)
in_avals_flat = [core.get_aval(arg) for arg in args_flat]
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
return get_jaxpr_type_signature(jaxpr)

View File

@ -152,7 +152,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignature
args_flat, in_tree = tree_util.tree_flatten(args)
in_avals_flat = [core.get_aval(arg) for arg in args_flat]
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
return get_jaxpr_type_signature(jaxpr)

View File

@ -426,7 +426,7 @@ def _shard_map_staging(
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
main = trace.main
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, genavals, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
out_avals_ = map(_check_shapedarray, genavals)
_check_names(out_names_thunk(), out_avals_)
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
@ -1378,7 +1378,7 @@ def _promote_scalar_residuals_jaxpr(jaxpr, which):
res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
for v, w in zip(jaxpr.constvars, which)]
in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(fun, in_avals)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return jaxpr
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
@ -1495,7 +1495,7 @@ def _add_reshapes(which, jaxpr_known, jaxpr_staged):
res = [_add_singleton(x) if not x.shape else x for x in res]
return [*out_known, *res]
avals_in = [v.aval for v in jaxpr_known.invars]
jaxpr_known, _, () = pe.trace_to_jaxpr_dynamic(known, avals_in)
jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic(known, avals_in)
@lu.wrap_init
def staged(*args):
@ -1505,7 +1505,7 @@ def _add_reshapes(which, jaxpr_known, jaxpr_staged):
res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
jaxpr_staged, _, () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
return jaxpr_known, jaxpr_staged
@ -1773,7 +1773,7 @@ def _replication_rewrite_match(
f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
f = _match_rep(f, mesh, out_rep, out_rep_dst)
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
return core.ClosedJaxpr(jaxpr_, consts)
# TODO(mattjj): caching
@ -1785,7 +1785,7 @@ def _replication_rewrite_nomatch(
f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
return core.ClosedJaxpr(jaxpr_, consts), out_rep()
@lu.transformation_with_aux

View File

@ -447,7 +447,7 @@ def sparsify_raw(f):
spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue)
in_avals_flat = spvalues_to_avals(spenv, spvalues_flat)
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
jaxpr, out_avals_flat, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
if len(out_avals_flat) != len(result):
raise Exception("Internal: eval_sparse does not return expected number of arguments. "
@ -747,7 +747,7 @@ def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
args = spvalues_to_arrays(spenv, spvalues)
args_flat, in_tree = tree_flatten(args)
avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
assert out_tree is not None
return sp_jaxpr, out_tree

View File

@ -82,7 +82,7 @@ from jax._src.interpreters.partial_eval import (
result_info as result_info,
sig_info as sig_info,
trace_to_jaxpr as trace_to_jaxpr,
trace_to_jaxpr_dynamic as trace_to_jaxpr_dynamic,
trace_to_jaxpr_dynamic as _trace_to_jaxpr_dynamic,
trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2,
trace_to_jaxpr_final as trace_to_jaxpr_final,
trace_to_jaxpr_final2 as trace_to_jaxpr_final2,
@ -97,4 +97,12 @@ from jax._src.interpreters.partial_eval import (
trivial_ctx as trivial_ctx,
)
# TODO(mattjj): remove temporary shim when trace_to_jaxpr_dynamic sig stabilizes
def trace_to_jaxpr_dynamic(fun, in_avals, debug_info=None, *, keep_inputs=None):  # noqa
  """Three-result wrapper around the internal ``trace_to_jaxpr_dynamic``.

  The internal function returns a fourth element. It is unpacked against ``()``
  here, so it must be empty; a non-empty fourth element raises ``ValueError``
  at the unpacking step.

  Returns:
    ``(jaxpr, out_avals, consts)``.
  """
  traced = _trace_to_jaxpr_dynamic(
      fun, in_avals, debug_info, keep_inputs=keep_inputs)
  jaxpr, out_avals, consts, () = traced
  return jaxpr, out_avals, consts
from jax._src.core import Jaxpr as Jaxpr

View File

@ -1301,6 +1301,15 @@ jax_test(
srcs = ["clear_backends_test.py"],
)
jax_test(
name = "attrs_test",
srcs = ["attrs_test.py"],
deps = [
"//jax:experimental",
],
)
jax_test(
name = "experimental_rnn_test",
srcs = ["experimental_rnn_test.py"],

64
tests/attrs_test.py Normal file
View File

@ -0,0 +1,64 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.util import safe_zip, safe_map
from jax.experimental.attrs import jax_setattr, jax_getattr
# NOTE(review): presumably parses absl command-line flags into the JAX config
# before the tests run; confirm against jax._src.config.
config.parse_flags_with_absl()
# Rebind `map` / `zip` to the safe_* helpers imported from jax._src.util, while
# keeping the builtins reachable as `unsafe_map` / `unsafe_zip`.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@dataclass
class Thing:
  """Minimal mutable holder with a single attribute, used as the test target."""

  # The attribute read and written through jax_getattr / jax_setattr.
  x: float
class AttrsTest(jtu.JaxTestCase):
  """Tests for jax_getattr / jax_setattr, with and without jax.jit."""

  @parameterized.parameters([True, False])
  def test_basic(self, jit: bool):
    """Doubling ``thing.x`` in place four times yields 2, 4, 8, 16."""
    thing = Thing(1.0)

    def double_it() -> None:
      # Read-modify-write of the attribute through the attrs primitives.
      jax_setattr(thing, "x", jax_getattr(thing, "x") * 2)

    if jit:
      double_it = jax.jit(double_it)

    self.assertEqual(thing.x, 1.0)
    # Each call must update the Python-visible attribute, including on
    # repeated calls of the same (possibly jitted) function.
    for expected in (2.0, 4.0, 8.0, 16.0):
      double_it()
      self.assertEqual(thing.x, expected)
# When run as a script, hand control to absltest using the test loader from
# jax._src.test_util.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -624,7 +624,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
def f(x, y):
return x, y
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 3)
@ -648,7 +648,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
return (x, w)
return g(x, y, x, y)
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
@ -684,7 +684,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
return (x, w)
return g(x.shape[0], x, y, x, y)
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
# { lambda ; a:i32[] b:f32[a] c:f32[a]. let
@ -720,7 +720,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
u = lax_internal._reduce_sum(w, [0])
return (u,)
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])
self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs
@ -745,7 +745,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
def g(x): return x
return g(a),
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, m, a, b], keep_inputs=[False, False, True, True])
# { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
# e:f32[a] = xla_call[

View File

@ -318,7 +318,7 @@ class PallasCallDMATest(parameterized.TestCase):
pltpu.run_scoped(body, pltpu.VMEM((8,), jnp.float32))
return []
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(kernel),
[
state.shaped_array_ref((8,), jnp.float32),

View File

@ -132,7 +132,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval])
else:
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval])
self.assertSetEqual(jaxpr.effects,
{ReadEffect(len(jaxpr.constvars))})
@ -229,7 +229,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
else:
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval])
self.assertSetEqual(jaxpr.effects,
{WriteEffect(len(jaxpr.constvars))})
@ -306,7 +306,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
else:
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval])
self.assertSetEqual(jaxpr.effects,
{AccumEffect(len(jaxpr.constvars))})
@ -326,7 +326,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x[()] = jnp.int32(1)
x[()] = jnp.int32(2)
return (x[()],)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
@ -339,7 +339,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x):
ref_addupdate(x, (), jnp.int32(1))
return (x[()],)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
@ -349,14 +349,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x_ref):
x = x_ref[()]
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))
def body(x_ref):
x = x_ref[:, 0]
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32)])
self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False))
@ -364,14 +364,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x_ref):
x_ref[()] = jnp.int32(2)
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))
def body(x_ref, val):
x_ref[:, 0] = val
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False))
@ -380,14 +380,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x_ref):
x = ref_swap(x_ref, (), jnp.int32(2))
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))
def body(x_ref, val):
x = ref_swap(x_ref, (slice(None), 0), val)
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("c:i32[1], a[:,0] <- a[:,0], b",
@ -397,7 +397,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x_ref):
ref_addupdate(x_ref, (), jnp.int32(2))
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))
@ -405,7 +405,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x_ref, val):
ref_addupdate(x_ref, (slice(None), 0), val)
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False))
@ -422,7 +422,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, get_p)
self.assertEqual(jaxpr.eqns[1].primitive, get_p)
@ -438,7 +438,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, get_p)
self.assertEqual(jaxpr.eqns[1].primitive, get_p)
self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)
@ -458,7 +458,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p)
self.assertEqual(jaxpr.eqns[1].primitive, addupdate_p)
self.assertEqual(jaxpr.eqns[2].primitive, get_p)
@ -529,12 +529,12 @@ class StatePrimitivesTest(jtu.JaxTestCase):
# discharge-of-vmap
f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, *bat_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs)
# vmap-of-discharge
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
@ -553,7 +553,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a = ref_get(a_ref, ())
return [a + 1]
in_avals = [shaped_array_ref((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
@ -569,7 +569,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a = ref_get(a_ref, (0, 1))
return [a + 1]
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
@ -588,7 +588,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a = a_ref[jnp.array([0, 1])]
return [a + 1]
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), in_avals)
discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts)
@ -603,7 +603,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return []
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that ignores the first
# value and returns second value plus 1.
@ -620,7 +620,7 @@ class StateDischargeTest(jtu.JaxTestCase):
ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)
@ -640,7 +640,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
return []
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts)
@ -654,7 +654,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return []
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that adds the first value,
# second value, and 1.
@ -672,7 +672,7 @@ class StateDischargeTest(jtu.JaxTestCase):
jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
@ -693,7 +693,7 @@ class StateDischargeTest(jtu.JaxTestCase):
jnp.ones((2, 3), 'float32'))
return []
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts)
@ -707,7 +707,7 @@ class StateDischargeTest(jtu.JaxTestCase):
b = a + 1
return [a, b]
in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
@ -727,7 +727,7 @@ class StateDischargeTest(jtu.JaxTestCase):
shaped_array_ref((4,), jnp.dtype('float32')),
shaped_array_ref((4,), jnp.dtype('float32'))
]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = discharge_state(
stateful_jaxpr, consts, should_discharge=[False, True])
@ -929,13 +929,13 @@ if CAN_USE_HYPOTHESIS:
ref = get_vmap_param.bat_ref
f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx)
# vmap-of-discharge
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
@ -974,13 +974,13 @@ if CAN_USE_HYPOTHESIS:
f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
out_axes=[])
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)
# vmap-of-discharge
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
@ -1018,13 +1018,13 @@ if CAN_USE_HYPOTHESIS:
f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
out_axes=[])
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)
# vmap-of-discharge
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
@ -1323,7 +1323,7 @@ class GeneralRefTest(jtu.JaxTestCase):
x_ref[...] = x
ref_addupdate(x_ref, (), x)
return [x]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [AbstractRef(core.UnshapedArray(jnp.int32))])
self.assertIs(type(jaxpr.outvars[0].aval), core.UnshapedArray)
self.assertEqual(jaxpr.outvars[0].aval.dtype, jnp.dtype("int32"))
@ -1334,7 +1334,7 @@ class GeneralRefTest(jtu.JaxTestCase):
x_ref[...] = x
ref_addupdate(x_ref, (), x)
return [x]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [AbstractRef(core.AbstractToken())])
self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken)
@ -1343,7 +1343,7 @@ class GeneralRefTest(jtu.JaxTestCase):
x_ref = x_ref_ref[...]
return [x_ref]
# Not sure why you'd ever want to do this, but it works!
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f),
[AbstractRef(AbstractRef(core.ShapedArray((), jnp.int32)))])
self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef)