mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

commit 4a8babb101 (parent d4660a0972)

integrate attrs in jax.jit

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
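This change adds a new jax.experimental.attrs module (jax_getattr / jax_setattr) and threads the attribute state it tracks through jaxpr tracing and jit/pjit dispatch. A minimal usage sketch, adapted from the tests/attrs_test.py file added by this commit (the Thing dataclass is just an illustrative container):

from dataclasses import dataclass

import jax
from jax.experimental.attrs import jax_getattr, jax_setattr

@dataclass
class Thing:
  x: float

thing = Thing(1.0)

@jax.jit
def double_it() -> None:
  # Read the attribute as a traced value, then write the updated value back.
  # Under jit, the read/write is staged out and the state is threaded through
  # the compiled computation as an extra input/output.
  cur_x = jax_getattr(thing, "x")
  jax_setattr(thing, "x", cur_x * 2)

double_it()  # thing.x == 2.0
double_it()  # thing.x == 4.0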
@@ -241,6 +241,7 @@ py_library_providing_imports_info(
         "_src/internal_test_util/**",
     ],
 ) + [
+    "experimental/attrs.py",
     # until new parallelism APIs are moved out of experimental
     "experimental/maps.py",
     "experimental/pjit.py",
@ -372,7 +372,7 @@ def _trace_to_jaxpr(fun, in_tree, in_avals):
|
||||
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
debug = pe.debug_info(fun, in_tree, out_tree, True, "checkpoint")
|
||||
try:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
except core.ConcretizationTypeError as e:
|
||||
msg, = e.args
|
||||
if 'for checkpoint' not in msg:
|
||||
@ -620,7 +620,7 @@ def _transpose_jaxpr(jaxpr, in_lin, out_zeros, reduce_axes):
|
||||
in_cts_nz, _ = partition_list(in_zeros, in_cts)
|
||||
return in_cts_nz
|
||||
|
||||
transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
|
||||
transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
|
||||
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
|
||||
return transposed_jaxpr, cell.in_cts_zero # type: ignore
|
||||
|
||||
|
@ -560,7 +560,7 @@ def xla_computation(fun: Callable,
|
||||
with ExitStack() as stack:
|
||||
for axis_name, size in axis_env or []:
|
||||
stack.enter_context(core.extend_axis_env(axis_name, size, None))
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
|
||||
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
|
||||
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
||||
axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
|
||||
ordered_effects = list(
|
||||
|
@ -746,7 +746,7 @@ def jaxpr_to_checkify_jaxpr(
|
||||
fun = lu.wrap_init(checkify_jaxpr_partial)
|
||||
fun, metadata = _flatten_and_get_error_metadata_thunk(fun)
|
||||
|
||||
new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
|
||||
new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
|
||||
checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
|
||||
out_tree, error_effects = metadata()
|
||||
return checked_jaxpr, out_tree, error_effects
|
||||
@ -832,7 +832,7 @@ def checkify_while_body_jaxpr(
|
||||
return out
|
||||
new_body_f_ = lu.wrap_init(new_body_f)
|
||||
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
|
||||
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
|
||||
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
|
||||
*body_jaxpr.in_avals])
|
||||
closed_jaxpr = pe.close_jaxpr(jaxpr)
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
@ -1128,7 +1128,7 @@ def checkify(f: Callable[..., Out],
|
||||
# stage:
|
||||
fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
|
||||
debug = pe.debug_info(closed_f, in_tree, out_tree, False, 'checkify')
|
||||
jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
|
||||
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
|
||||
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
|
||||
# checkify:
|
||||
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
|
||||
|
@ -69,7 +69,7 @@ class custom_vmap:
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
in_tree = treedef_tuple((tree_structure(consts), in_tree))
|
||||
assert self.vmap_rule is not None
|
||||
|
@ -66,7 +66,7 @@ def _resolve_kwargs(fun, args, kwargs):
|
||||
return ba.args
|
||||
|
||||
def _initial_style_jaxpr(fun, in_avals):
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
return jaxpr, consts
|
||||
|
||||
def _close_jaxpr(jaxpr):
|
||||
@ -977,7 +977,7 @@ def custom_gradient(fun):
|
||||
ans_flat, out_tree = tree_flatten((ans,))
|
||||
rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
|
||||
ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
|
||||
return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
|
||||
|
||||
def bwd(res, cts):
|
||||
@ -1102,7 +1102,7 @@ def _maybe_perturbed(x: Any) -> bool:
|
||||
@cache()
|
||||
def _closure_convert_for_avals(fun, in_tree, in_avals):
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
|
||||
jaxpr, out_pvals, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
|
||||
out_tree = out_tree()
|
||||
|
||||
(closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)
|
||||
|
@ -697,7 +697,7 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
|
||||
nonzeros)
|
||||
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
|
||||
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
|
||||
jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
|
||||
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
|
||||
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
|
||||
|
||||
@lu.transformation_with_aux
|
||||
|
@ -789,7 +789,7 @@ def _batch_jaxpr2(
|
||||
avals_in2 = [core.unmapped_aval(axis_size, axis_name, b, aval)
|
||||
if b is not not_mapped else aval
|
||||
for aval, b in unsafe_zip(avals_in, in_axes2)]
|
||||
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in2)
|
||||
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)
|
||||
return core.ClosedJaxpr(jaxpr_out, consts), out_axes()
|
||||
|
||||
def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis,
|
||||
@ -834,7 +834,7 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
|
||||
main_type)
|
||||
avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped
|
||||
else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
|
||||
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
||||
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
||||
return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
|
||||
|
||||
@lu.transformation_with_aux
|
||||
|
@ -1797,7 +1797,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
|
||||
wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
|
||||
else:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
|
||||
|
||||
out, tokens = jaxpr_subcomp(
|
||||
|
@ -663,7 +663,7 @@ def _closed_call_param_updater(params, _, __):
|
||||
call_param_updaters[core.closed_call_p] = _closed_call_param_updater
|
||||
|
||||
def abstract_eval_fun(fun, *avals, debug_info=None, **params):
|
||||
_, avals_out, _ = trace_to_jaxpr_dynamic(
|
||||
_, avals_out, _, () = trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(fun, params), avals, debug_info)
|
||||
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
|
||||
return avals_out
|
||||
@ -1113,7 +1113,7 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
|
||||
return [*known_vals_out, *residuals]
|
||||
|
||||
known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]
|
||||
jaxpr_known, _, consts_known = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
|
||||
jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
|
||||
(out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking
|
||||
|
||||
# check jaxpr_known and jaxpr_unknown in isolation
|
||||
@@ -1754,6 +1754,9 @@ class JaxprStackFrame:
   eqns: list[JaxprEqn]
   invars: list[Var]
   effects: core.Effects
+  attrs_tracked: list[tuple[Any, str]]
+  attrs_inits: list
+  attrs_vars: list[Var]
   debug_info: DebugInfo | None

   def __init__(self):
@@ -1765,23 +1768,29 @@ class JaxprStackFrame:
     self.eqns = []      # cleared when we pop frame from main
     self.invars = []
     self.effects = set()
+    self.attrs_tracked = []
+    self.attrs_inits = []
+    self.attrs_vars = []
     self.debug_info = None

   def add_eqn(self, eqn: core.JaxprEqn):
     self.eqns.append(eqn)

-  def to_jaxpr(self, out_tracers: Sequence[Tracer]) -> tuple[Jaxpr, list[Any]]:
+  def to_jaxpr(self, out_tracers: Sequence[Tracer]
+               ) -> tuple[Jaxpr, list[Any], list[tuple[Any, str]]]:
     # It's not necessary, but we keep the tracer-to-var mapping injective:
     assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
-    outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
-    constvals: Sequence[Any]
+    invars = self.attrs_vars + self.invars
+    state_outvars = [self.tracer_to_var[id(t)] for t in get_states(self.attrs_tracked)]
+    explicit_outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
+    outvars = state_outvars + explicit_outvars
     constvars, constvals = unzip2(self.constvar_to_val.items())
-    jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars,
-                                       self.eqns)
-    jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects)
+    jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns)
+    jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects)
     jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
-    jaxpr, constvals = _inline_literals(jaxpr, constvals)
-    return jaxpr, list(constvals)
+    jaxpr, constvals = _inline_literals(jaxpr, constvals)  # type: ignore
+    set_states(self.attrs_tracked, self.attrs_inits)
+    return jaxpr, list(constvals), self.attrs_tracked

   def to_jaxpr2(self, out_tracers):
     # It's not necessary, but we keep the tracer-to-var mapping injective:
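The net effect of the to_jaxpr change above is that each tracked (object, attribute) pair contributes a leading input variable and a leading output variable to the traced jaxpr, and the pairs themselves are reported to callers through the new attrs_tracked return value. A minimal sketch of how a caller might consume the new four-element result of trace_to_jaxpr_dynamic (the helper name trace_with_attrs is hypothetical, not part of this commit):

from jax._src import core
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe

def trace_with_attrs(fun, in_avals):
  # trace_to_jaxpr_dynamic now returns a fourth element: the list of
  # (object, attr_name) pairs whose state was read or written while tracing.
  jaxpr, out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(fun), in_avals)
  # The tracked states appear as the leading invars/outvars of the jaxpr, so a
  # caller would prepend the current attribute values to the arguments and
  # split the updated values off the front of the outputs (as pjit does below).
  return core.ClosedJaxpr(jaxpr, consts), attrs_tracked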
@ -2064,7 +2073,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
for a, in_axis in zip(in_avals, params['in_axes'])]
|
||||
with core.extend_axis_env(axis_name, params["global_axis_size"], None): # type: ignore
|
||||
with core.new_sublevel():
|
||||
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
jaxpr, reduced_out_avals, consts, () = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, reduced_in_avals,
|
||||
debug_info=debug_info_final(f, map_primitive.name))
|
||||
ordered_effects = effects.ordered_effects.filter_in(jaxpr.effects)
|
||||
@ -2098,7 +2107,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
|
||||
in_avals = [t.aval for t in tracers]
|
||||
with core.new_sublevel():
|
||||
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
||||
fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
||||
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
|
||||
main_ = ref(self.main)
|
||||
|
||||
@ -2108,7 +2117,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
nz_tangent_avals, zero_avals = partition_list(in_zeros, in_avals)
|
||||
jvp_, out_zeros = _jvp_jaxpr_zeros(jvp, in_zeros, tuple(zero_avals))
|
||||
in_avals_ = (*in_avals, *nz_tangent_avals)
|
||||
jaxpr, _, out_consts = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_)
|
||||
jaxpr, _, out_consts, () = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_)
|
||||
return jaxpr, out_consts, out_zeros()
|
||||
|
||||
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
||||
@ -2132,7 +2141,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
symbolic_zeros):
|
||||
in_avals = [t.aval for t in tracers]
|
||||
with core.new_sublevel():
|
||||
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
||||
fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
||||
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
|
||||
|
||||
main_ = ref(self.main)
|
||||
@ -2170,7 +2179,7 @@ class DynamicJaxprTrace(core.Trace):
|
||||
in_avals_t = [*[t.aval for t in tracers_res], *out_types]
|
||||
|
||||
with core.new_sublevel():
|
||||
call_jaxpr, out_avals, call_consts = trace_to_subjaxpr_dynamic(
|
||||
call_jaxpr, out_avals, call_consts, () = trace_to_subjaxpr_dynamic(
|
||||
call, self.main, in_avals_p)
|
||||
closed_call_jaxpr = core.ClosedJaxpr(
|
||||
convert_constvars_jaxpr(call_jaxpr), ())
|
||||
@ -2183,8 +2192,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
@_memoize
|
||||
def transpose_jaxpr_thunk():
|
||||
for store in transpose_flat.stores: store.reset()
|
||||
jaxpr, _, consts = trace_to_subjaxpr_dynamic(transpose_flat, main_(),
|
||||
in_avals_t)
|
||||
jaxpr, _, consts, () = trace_to_subjaxpr_dynamic(
|
||||
transpose_flat, main_(), in_avals_t)
|
||||
return jaxpr, consts
|
||||
|
||||
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
||||
@@ -2299,13 +2308,13 @@ def trace_to_jaxpr_dynamic(
     debug_info: DebugInfo | None = None,
     *,
     keep_inputs: list[bool] | None = None,
-) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
+) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]:
   with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
     main.jaxpr_stack = ()  # type: ignore
-    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
+    jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
       fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
     del main, fun
-  return jaxpr, out_avals, consts
+  return jaxpr, out_avals, consts, attrs_tracked


 def trace_to_subjaxpr_dynamic(
@@ -2315,7 +2324,7 @@ def trace_to_subjaxpr_dynamic(
     *,
     keep_inputs: Sequence[bool] | None = None,
     debug_info: DebugInfo | None = None,
-) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
+) -> tuple[Jaxpr, list[AbstractValue], list[Any], list[tuple[Any, str]]]:
   keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs

   frame = JaxprStackFrame()
@@ -2326,10 +2335,10 @@ def trace_to_subjaxpr_dynamic(
     in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
     ans = fun.call_wrapped(*in_tracers_)
     out_tracers = map(trace.full_raise, ans)
-    jaxpr, consts = frame.to_jaxpr(out_tracers)
+    jaxpr, consts, attrs_tracked = frame.to_jaxpr(out_tracers)
     del fun, main, trace, frame, in_tracers, out_tracers, ans
   config.enable_checks.value and core.check_jaxpr(jaxpr)
-  return jaxpr, [v.aval for v in jaxpr.outvars], consts
+  return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked


 @profiler.annotate_function
@@ -2380,7 +2389,7 @@ def trace_to_jaxpr_final(
   with core.new_base_main(DynamicJaxprTrace) as main:  # type: ignore
     main.jaxpr_stack = ()  # type: ignore
     with core.new_sublevel():
-      jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
+      jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(
         fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)
   del fun, main
   return jaxpr, out_avals, consts
@@ -2404,6 +2413,15 @@ AbstractedAxesSpec = Union[
     tuple[AbstractedAxisName, ...],
 ]

+AttrsTracked = list[tuple[Any, str]]
+AttrStates = list
+def set_states(attrs_tracked: AttrsTracked, vals: AttrStates):
+  for ((obj, attr), val) in zip(attrs_tracked, vals):
+    setattr(obj, attr, val)
+
+def get_states(attrs_tracked: AttrsTracked):
+  return [getattr(obj, attr) for (obj, attr) in attrs_tracked]
+

 def infer_lambda_input_type(
     axes_specs: Sequence[AbstractedAxesSpec] | None,
@ -2629,7 +2647,7 @@ def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
|
||||
|
||||
in_avals = [substitute(v.aval) for v in jaxpr.invars]
|
||||
eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
|
||||
padded_jaxpr, _, padded_consts = trace_to_jaxpr_dynamic(eval_padded, in_avals)
|
||||
padded_jaxpr, _, padded_consts, () = trace_to_jaxpr_dynamic(eval_padded, in_avals)
|
||||
return padded_jaxpr, padded_consts
|
||||
|
||||
class BoundedAxisSize(NamedTuple):
|
||||
|
@ -58,7 +58,7 @@ def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
debug = pe.debug_info(fun, in_tree, out_tree, False,
|
||||
primitive_name or "<unknown>")
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
|
||||
return jaxpr, consts, out_tree()
|
||||
|
||||
@weakref_lru_cache
|
||||
@ -226,7 +226,7 @@ def _prune_zeros(ts):
|
||||
return [t for t in ts if type(t) is not ad_util.Zero]
|
||||
|
||||
def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
|
||||
return core.ClosedJaxpr(jaxpr, consts)
|
||||
|
||||
def _show_diff(array1, array2):
|
||||
|
@ -88,7 +88,7 @@ def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
consts = map(lambda x: ref_get(x, ()), consts)
|
||||
all_consts = merge_lists(is_const_ref, consts, const_refs)
|
||||
return core.eval_jaxpr(jaxpr, all_consts, i, *args)
|
||||
hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(_hoist), in_avals)
|
||||
assert not consts, "All consts should have been converted to refs"
|
||||
return hoisted_jaxpr
|
||||
@ -98,7 +98,7 @@ def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
|
||||
) -> tuple[core.Jaxpr, list[Any], PyTreeDef]:
|
||||
f, out_tree_thunk = flatten_fun_nokwargs(
|
||||
lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
f, state_avals)
|
||||
return jaxpr, consts, out_tree_thunk()
|
||||
|
||||
@ -607,7 +607,7 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
|
||||
return for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps,
|
||||
reverse=reverse, which_linear=jaxpr_known_which_linear,
|
||||
unroll=unroll)
|
||||
call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
|
||||
call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
known, [v.aval for v in known_invars])
|
||||
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
|
||||
eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars],
|
||||
@ -629,7 +629,7 @@ def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
|
||||
_, ans = split_list(out_flat, [num_res])
|
||||
_, ans = partition_list(out_inst, ans)
|
||||
return ans
|
||||
call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
|
||||
call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
staged, [v.aval for v in [*resvars, *eqn.invars]])
|
||||
assert len(jaxpr_staged.invars) - 1 == len(call_jaxpr_.invars)
|
||||
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
|
||||
@ -667,7 +667,7 @@ def _convert_outputs_to_writes(
|
||||
AbstractRef(core.ShapedArray((nsteps, *v.aval.shape), # pytype: disable=attribute-error
|
||||
v.aval.dtype)) # pytype: disable=attribute-error
|
||||
for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
eval_jaxpr, [*in_avals, *res_ref_avals])
|
||||
assert not consts
|
||||
return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] # pytype: disable=attribute-error
|
||||
@ -693,7 +693,7 @@ def _convert_inputs_to_reads(
|
||||
aval.dtype)) # pytype: disable=attribute-error
|
||||
for aval, loop_invar in zip(res_val_avals, loop_invar_res)]
|
||||
|
||||
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
|
||||
eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])
|
||||
return jaxpr
|
||||
|
||||
@ -720,7 +720,7 @@ def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: list[bool]) -> core.Jaxpr:
|
||||
ad.backward_pass(
|
||||
tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
|
||||
return []
|
||||
jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
|
||||
jaxpr_trans, _, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
|
||||
return jaxpr_trans
|
||||
|
||||
|
@ -985,7 +985,7 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
||||
intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res])
|
||||
out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known)
|
||||
return [*intensive_res, *out_loop]
|
||||
call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
|
||||
call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
known, [v.aval for v in ins_known])
|
||||
call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
|
||||
eqn_known = pe.new_jaxpr_eqn(
|
||||
@ -1108,8 +1108,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
|
||||
new_in_avals = [*remaining_const_avals, *[a.inner_aval for a in in_ref_avals],
|
||||
*carry_avals,
|
||||
*[core.mapped_aval(length, 0, a) for a in xs_avals]]
|
||||
new_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped),
|
||||
new_in_avals)
|
||||
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), new_in_avals)
|
||||
new_linear = (*remaining_consts_linear, *in_refs_linear,
|
||||
*carry_linear, *xs_linear)
|
||||
all_out = scan_p.bind(*remaining_consts, *in_refs, *carry, *xs,
|
||||
@ -1768,7 +1767,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
|
||||
*carry)
|
||||
carry, refs_out = split_list(carry_refs, [num_carry])
|
||||
return [*refs_out, *carry]
|
||||
new_body_jaxpr, _, new_body_consts = pe.trace_to_jaxpr_dynamic(
|
||||
new_body_jaxpr, _, new_body_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(new_body), [*remaining_body_const_avals, *[a.inner_aval for a
|
||||
in ref_avals],
|
||||
*carry_avals])
|
||||
@ -1782,7 +1781,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr,
|
||||
consts_refs_carry, [cond_nconsts, num_refs])
|
||||
del refs # We don't use them here!
|
||||
return core.eval_jaxpr(cond_jaxpr, cond_jaxpr_consts, *consts, *carry)
|
||||
new_cond_jaxpr, _, new_cond_consts = pe.trace_to_jaxpr_dynamic(
|
||||
new_cond_jaxpr, _, new_cond_consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(new_cond), [*cond_consts_avals,
|
||||
*[a.inner_aval for a in ref_avals],
|
||||
*carry_avals])
|
||||
|
@ -1027,7 +1027,7 @@ def _reduction_jaxpr(computation, aval):
|
||||
f"Reduction functions should only return an array.\n"
|
||||
f"Full return value: {result}")
|
||||
return (result,)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(comp, (aval, aval))
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(comp, (aval, aval))
|
||||
if any(isinstance(c, core.Tracer) for c in consts):
|
||||
raise NotImplementedError(
|
||||
"Reduction computations can't close over Tracers. Please open an issue "
|
||||
@ -1040,7 +1040,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree):
|
||||
flat_in_avals, in_tree = tree_util.tree_flatten((avals, avals))
|
||||
comp = lu.wrap_init(computation)
|
||||
flat_comp, out_tree = api_util.flatten_fun_nokwargs(comp, in_tree)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_comp, tuple(flat_in_avals))
|
||||
if any(isinstance(c, core.Tracer) for c in consts):
|
||||
raise NotImplementedError(
|
||||
"Reduction computations can't close over Tracers. Please open an issue "
|
||||
|
@ -999,7 +999,7 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
for a, a_in_axes in zip(in_avals, params['in_axes'])]
|
||||
with core.extend_axis_env_nd(global_axis_sizes.items()):
|
||||
with core.new_sublevel():
|
||||
jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
jaxpr, mapped_out_avals, consts, () = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, mapped_in_avals)
|
||||
out_axes = params['out_axes_thunk']()
|
||||
if params['spmd_out_axes_thunk'] is not None:
|
||||
@ -1340,7 +1340,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
|
||||
# NOTE: We don't extend the resource env with the mesh shape, because those
|
||||
# resources are already in scope! It's the outermost xmap that introduces
|
||||
# them!
|
||||
vectorized_jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(f, local_avals)
|
||||
vectorized_jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, local_avals)
|
||||
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
|
||||
const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
|
||||
|
||||
@ -1415,7 +1415,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
|
||||
add_spmd_axes(mesh_in_axes, spmd_in_axes)
|
||||
add_spmd_axes(mesh_out_axes, spmd_out_axes)
|
||||
global_in_avals = ctx.avals_in
|
||||
vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
|
||||
vectorized_jaxpr, global_out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
|
||||
|
||||
sharded_global_in_nodes = [
|
||||
[mlir.wrap_with_sharding_op(
|
||||
@ -1477,7 +1477,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
|
||||
# resources are already in scope! It's the outermost xmap that introduces
|
||||
# them!
|
||||
global_in_avals = ctx.avals_in
|
||||
vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
|
||||
vectorized_jaxpr, global_out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
|
||||
const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
|
||||
|
||||
# We in-line here rather than generating a Call HLO as in the xla_call
|
||||
|
@ -192,7 +192,7 @@ def _convert_block_spec_to_block_mapping(
|
||||
block_shape = tuple(
|
||||
mapped if s is None else s for s in block_shape)
|
||||
flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
|
||||
return BlockMapping(block_shape, jax_core.ClosedJaxpr(jaxpr, consts))
|
||||
|
||||
def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
|
||||
|
@ -543,7 +543,7 @@ def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
|
||||
def f_lowered(ctx: LoweringRuleContext, *args, **params):
|
||||
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
|
||||
wrapped_fun = lu.wrap_init(f, params)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
if consts:
|
||||
raise NotImplementedError
|
||||
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
|
||||
|
@ -103,7 +103,7 @@ def run_scoped(f: Callable[..., None], *types, **kw_types) -> None:
|
||||
flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
|
||||
flat_fun, _ = api_util.flatten_fun(lu.wrap_init(f), in_tree)
|
||||
avals = map(lambda t: t.get_aval(), flat_types)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
|
||||
run_scoped_p.bind(*consts, jaxpr=jaxpr)
|
||||
|
||||
|
||||
|
@ -26,11 +26,11 @@ from jax import api_util
|
||||
from jax import tree_util
|
||||
from jax import lax
|
||||
from jax._src import state
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src import ad_util
|
||||
from jax._src import core as jax_core
|
||||
from jax._src.state import primitives as sp
|
||||
@ -258,7 +258,7 @@ def _batch_block_mapping(grid: tuple[int, ...], aval: jax_core.ShapedArray,
|
||||
idx_avals = [i32_aval] * (len(grid) + 1)
|
||||
else:
|
||||
idx_avals = [i32_aval, *block_mapping.index_map_jaxpr.in_avals]
|
||||
block_mapping_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(_block_map_function), idx_avals)
|
||||
shape = aval.shape if block_mapping is None else block_mapping.block_shape
|
||||
if dim is batching.not_mapped:
|
||||
@ -387,7 +387,7 @@ def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
|
||||
consts = map(lambda x: sp.ref_get(x, ()), consts)
|
||||
all_consts = merge_lists(is_const_ref, consts, const_refs)
|
||||
return jax_core.eval_jaxpr(jaxpr, all_consts, *args)
|
||||
hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(_hoist), in_avals)
|
||||
assert not consts, "All consts should have been converted to refs"
|
||||
return hoisted_jaxpr
|
||||
@ -401,8 +401,7 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals,
|
||||
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(fun), jaxpr_in_tree)
|
||||
debug = pe.debug_info(fun, jaxpr_in_tree, out_tree_thunk, False, "pallas_call")
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
|
||||
debug)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals, debug)
|
||||
jaxpr = _hoist_consts_to_refs(jaxpr)
|
||||
return grid_mapping, jaxpr, consts, out_tree_thunk()
|
||||
|
||||
|
@ -53,8 +53,8 @@ from jax._src.util import merge_lists
|
||||
from jax._src.util import partition_list
|
||||
from jax._src.util import split_list
|
||||
from jax._src.util import weakref_lru_cache
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
import jax.numpy as jnp
|
||||
from jax_triton import triton_lib
|
||||
from jax_triton.triton_lib import compile_ttir_to_ptx_inplace
|
||||
@ -424,7 +424,7 @@ def _associative_scan_lowering(
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(body), in_tree
|
||||
)
|
||||
combine_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
flat_fun, in_avals
|
||||
)
|
||||
out_tree = out_tree_thunk()
|
||||
@ -572,7 +572,7 @@ def lower_fun(
|
||||
|
||||
def f_lowered(ctx: TritonLoweringRuleContext, *args, **params):
|
||||
wrapped_fun = lu.wrap_init(fn, params)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
|
||||
out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
|
||||
return out if multiple_results else out[0]
|
||||
@ -998,7 +998,7 @@ def _reduction_lowering(body, ctx: TritonLoweringRuleContext, a, axes):
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(body), in_tree
|
||||
)
|
||||
combine_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
combine_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
flat_fun, [*mapped_avals, *mapped_avals]
|
||||
)
|
||||
out_tree = out_tree_thunk()
|
||||
|
@ -21,7 +21,7 @@ import itertools as it
|
||||
import logging
|
||||
import operator as op
|
||||
import weakref
|
||||
from typing import Callable, cast, NamedTuple, Any, Union
|
||||
from typing import Callable, cast, NamedTuple, Any, Union, Optional
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
@@ -129,10 +129,13 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,


 def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
-  args_flat, _, params, _, out_tree, _, _, _, arg_names = infer_params_fn(
-      *args, **kwargs)
+  args_flat, _, params, _, out_tree, _, _, _, arg_names, attrs_tracked = \
+      infer_params_fn(*args, **kwargs)
   for arg in args_flat:
     dispatch.check_arg(arg)
+  if attrs_tracked:
+    init_states = _get_states(attrs_tracked)
+    args_flat = [*init_states, *args_flat]
   try:
     out_flat = pjit_p.bind(*args_flat, **params)
   except pxla.DeviceAssignmentMismatchError as e:
@@ -142,8 +145,20 @@ def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
     msg = _device_assignment_mismatch_error(
         fun_name, fails, args_flat, api_name, arg_names)
     raise ValueError(msg) from None
+  if attrs_tracked:
+    final_states, out_flat = split_list(out_flat, [len(attrs_tracked)])
+    _set_states(attrs_tracked, final_states)
   outs = tree_unflatten(out_tree, out_flat)
-  return outs, out_flat, out_tree, args_flat, params['jaxpr']
+  return outs, out_flat, out_tree, args_flat, params['jaxpr'], attrs_tracked

+def _set_states(attrs_tracked, vals):
+  from jax.experimental.attrs import jax_setattr  # type: ignore
+  for ((obj, attr), val) in zip(attrs_tracked, vals):
+    jax_setattr(obj, attr, val)
+
+def _get_states(attrs_tracked):
+  from jax.experimental.attrs import jax_getattr  # type: ignore
+  return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked]
+

 def _python_pjit(fun: Callable, infer_params_fn):
@@ -161,7 +176,8 @@ def _python_pjit(fun: Callable, infer_params_fn):
   return wrapped


-def _get_fastpath_data(executable, out_tree, args_flat, out_flat):
+def _get_fastpath_data(executable, out_tree, args_flat, out_flat, attrs_tracked,
+                       ) -> Optional[pxla.MeshExecutableFastpathData]:
   out_flat, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)

   use_fastpath = (
@@ -172,7 +188,9 @@ def _get_fastpath_data(executable, out_tree, args_flat, out_flat):
       not executable.unsafe_call.ordered_effects and
       not executable.unsafe_call.has_unordered_effects and
       not executable.unsafe_call.has_host_callbacks and
-      all(isinstance(x, xc.ArrayImpl) for x in out_flat)
+      all(isinstance(x, xc.ArrayImpl) for x in out_flat) and
+      # no attr state effects
+      not attrs_tracked
   )

   if use_fastpath:
@@ -224,11 +242,12 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,

   @api_boundary
   def cache_miss(*args, **kwargs):
-    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
+    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
         fun, infer_params_fn, *args, **kwargs)
     executable = _read_most_recent_pjit_call_executable(jaxpr)
-    fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
-    return outs, fastpath_data
+    maybe_fastpath_data = _get_fastpath_data(
+        executable, out_tree, args_flat, out_flat, attrs_tracked)
+    return outs, maybe_fastpath_data

   if xla_extension_version >= 226:
     cpp_pjit_f = xc._xla.pjit(  # type: ignore
@ -312,7 +331,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
out_layouts = kwargs.pop('_out_layouts', None)
|
||||
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
|
||||
donated_invars, in_layouts_flat, out_layouts_flat,
|
||||
arg_names) = infer_params_fn(
|
||||
arg_names, ()) = infer_params_fn(
|
||||
*args, **kwargs, _in_layouts=in_layouts, _out_layouts=out_layouts)
|
||||
resource_env = params['resource_env']
|
||||
mesh = None if resource_env is None else resource_env.physical_mesh
|
||||
@ -339,7 +358,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
|
||||
@api_boundary
|
||||
def eval_shape(*args, **kwargs):
|
||||
_, _, params, _, out_tree, _, _, _, _ = infer_params_fn(
|
||||
_, _, params, _, out_tree, _, _, _, _, _ = infer_params_fn(
|
||||
*args, **kwargs, _in_layouts=None, _out_layouts=None)
|
||||
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape)
|
||||
for x in params['jaxpr'].out_avals]
|
||||
@ -464,7 +483,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
|
||||
hashable_pytree(in_shardings), hashable_pytree(in_layouts), in_avals,
|
||||
in_tree, resource_env, dbg, device_or_backend_set, True if kwargs else False)
|
||||
|
||||
jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
|
||||
jaxpr, consts, out_shardings, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
|
||||
flat_fun, hashable_pytree(out_shardings), hashable_pytree(out_layouts),
|
||||
in_type, dbg, device_or_backend_set, HashableFunction(out_tree, closure=()),
|
||||
HashableFunction(res_paths, closure=()), inline)
|
||||
@ -477,19 +496,19 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
|
||||
implicit_args = []
|
||||
args_flat = [*implicit_args, *explicit_args]
|
||||
|
||||
num_extra_args = len(implicit_args) + len(consts)
|
||||
num_extra_args = len(implicit_args) + len(attrs_tracked) + len(consts)
|
||||
canonicalized_in_shardings_flat = \
|
||||
(UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
|
||||
in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
|
||||
donated_invars = (False,) * num_extra_args + donated_invars
|
||||
assert (len(canonicalized_in_shardings_flat) == len(in_layouts_flat) ==
|
||||
len(donated_invars) == len(consts) + len(args_flat))
|
||||
len(donated_invars) == len(attrs_tracked) + len(consts) + len(args_flat))
|
||||
|
||||
# in_shardings and out_shardings here are all GSPMDSharding.
|
||||
params = dict(
|
||||
jaxpr=jaxpr,
|
||||
in_shardings=canonicalized_in_shardings_flat,
|
||||
out_shardings=canonicalized_out_shardings_flat,
|
||||
out_shardings=out_shardings,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
name=getattr(flat_fun, '__name__', '<unknown>'),
|
||||
@ -498,7 +517,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
|
||||
)
|
||||
return (consts + args_flat, in_type, params, in_tree, out_tree(),
|
||||
donated_invars, in_layouts_flat, out_layouts_flat,
|
||||
dbg.arg_names if dbg else None)
|
||||
dbg.arg_names if dbg else None, attrs_tracked)
|
||||
|
||||
def _extract_implicit_args(
|
||||
in_type: Sequence[tuple[core.AbstractValue, bool]],
|
||||
@@ -1047,11 +1066,13 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
   if config.dynamic_shapes.value:
     jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
         lu.annotate(fun, in_type), debug_info=pe_debug)
+    attrs_tracked = []
   else:
-    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
+    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
         fun, in_type, debug_info=pe_debug)

-  if not config.dynamic_shapes.value:
+  # TODO(dougalm,mattjj): enable debug info with attrs_tracked
+  if not config.dynamic_shapes.value and not attrs_tracked:
     jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())

   if config.enable_key_reuse_checks.value:
@@ -1065,7 +1086,7 @@ def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
   else:
     closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
     final_consts = []
-  return closed_jaxpr, final_consts, global_out_avals
+  return closed_jaxpr, final_consts, global_out_avals, attrs_tracked


 @lru_cache(maxsize=4096)
@@ -1108,13 +1129,12 @@ def _check_and_canonicalize_out_shardings(

 def _pjit_jaxpr(fun, out_shardings_thunk, out_layouts_thunk, in_type, debug_info,
                 device_or_backend_set, out_tree, result_paths, inline):
-  jaxpr, final_consts, out_type = _create_pjit_jaxpr(
+  jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
      fun, in_type, debug_info, result_paths, IgnoreKey(inline))
   canonicalized_out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
       out_shardings_thunk, out_layouts_thunk, out_tree, tuple(out_type),
       jaxpr.jaxpr.debug_info, device_or_backend_set)
   # lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
-  return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat
+  return jaxpr, final_consts, canonicalized_out_shardings_flat, out_layouts_flat, attrs_tracked


 @dataclasses.dataclass(frozen=True)
@ -1357,7 +1377,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
fastpath_data = _get_fastpath_data(
|
||||
compiled, tree_structure(out_flat), args, out_flat)
|
||||
compiled, tree_structure(out_flat), args, out_flat, [])
|
||||
return out_flat, fastpath_data
|
||||
|
||||
f = _get_jaxpr_as_fun(
|
||||
@ -1856,7 +1876,7 @@ pe.partial_eval_jaxpr_custom_rules[pjit_p] = \
|
||||
|
||||
@lu.cache
|
||||
def _pjit_transpose_trace(fun, in_avals):
|
||||
transpose_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
transpose_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts)
|
||||
return transpose_jaxpr
|
||||
|
||||
|
@ -66,7 +66,7 @@ def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * ,
|
||||
else v.aval for v, d in zip(jaxpr.invars, should_discharge)]
|
||||
eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr,
|
||||
should_discharge, consts))
|
||||
new_jaxpr, _ , new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
|
||||
new_jaxpr, _ , new_consts, () = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
|
||||
return new_jaxpr, new_consts
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -427,7 +427,7 @@ def _convert_outputs_to_writes(
|
||||
return []
|
||||
res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef)
|
||||
else v.aval for v in jaxpr.outvars]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
eval_jaxpr, [*in_avals, *res_ref_avals])
|
||||
assert not consts
|
||||
return jaxpr, [core.ShapedArray(a.inner_aval.shape, a.inner_aval.dtype) # pytype: disable=attribute-error
|
||||
@ -447,7 +447,7 @@ def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
split_list([v.aval for v in jaxpr.invars], [num_res])
|
||||
res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else
|
||||
aval for aval in res_val_avals]
|
||||
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
|
||||
eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
|
||||
return jaxpr
|
||||
|
||||
@ -648,7 +648,7 @@ def _run_state_partial_eval_custom(
|
||||
def staged(*args):
|
||||
out = run_state_p.bind(*args, **staged_params)
|
||||
return out[num_res:]
|
||||
staged_call_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(staged,
|
||||
staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(staged,
|
||||
[v.aval for v in res_staged_invars])
|
||||
eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
|
||||
staged_outvars,
|
||||
@ -707,7 +707,7 @@ def _transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: Sequence[bool]
|
||||
ad.backward_pass(
|
||||
tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
|
||||
return []
|
||||
jaxpr_trans, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
jaxpr_trans, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
|
||||
return jaxpr_trans, consts
|
||||
|
||||
@ -770,7 +770,7 @@ def _initial_style_jaxpr(fun, in_tree, in_avals):
|
||||
fun_, out_tree_thunk = api_util.flatten_fun_nokwargs(lu.wrap_init(fun),
|
||||
tree_util.treedef_tuple((in_tree,)))
|
||||
debug = pe.debug_info(fun_, in_tree, out_tree_thunk, False, 'run_state')
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
|
||||
return jaxpr, consts, out_tree_thunk()
|
||||
|
||||
def run_state(f: Callable[..., None]):
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.state import AbstractRef
|
||||
@ -44,7 +44,7 @@ def hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
consts = map(lambda x: ref_get(x, ()), consts)
|
||||
all_consts = merge_lists(is_const_ref, consts, const_refs)
|
||||
return core.eval_jaxpr(jaxpr, all_consts, *args)
|
||||
hoisted_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(_hoist), in_avals)
|
||||
assert not consts, "All consts should have been converted to refs"
|
||||
return hoisted_jaxpr
|
||||
|
jax/experimental/attrs.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Any
+
+from jax._src import core
+from jax._src.interpreters import partial_eval as pe
+
+JaxVal = Any
+
+getattr_p = core.Primitive('getattr')
+setattr_p = core.Primitive('setattr')
+
+def jax_getattr(obj: Any, attr: str):
+  return getattr_p.bind(obj=obj, attr=attr)
+
+def jax_setattr(obj: Any, attr: str, val: JaxVal):
+  setattr_p.bind(val, obj=obj, attr=attr)
+
+
+@getattr_p.def_impl
+def _getattr_impl(*, obj, attr):
+  return getattr(obj, attr)
+
+@setattr_p.def_impl
+def _setattr_impl(val, *, obj, attr):
+  setattr(obj, attr, val)
+
+
+def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str):
+  frame = trace.main.jaxpr_stack[-1]  # type: ignore
+  if (obj, attr) not in frame.attrs_tracked:
+    init_val = getattr(obj, attr)
+    aval = core.raise_to_shaped(core.get_aval(init_val))
+    tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current())
+    var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval)
+    setattr(obj, attr, tracer)
+    frame.attrs_tracked.append((obj, attr))
+    frame.attrs_inits.append(init_val)
+    frame.attrs_vars.append(var)
+pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked
+
+def _getattr_staging(trace, *, obj, attr):
+  trace._ensure_tracked(obj, attr)
+  return getattr(obj, attr)
+pe.custom_staging_rules[getattr_p] = _getattr_staging
+
+def _setattr_staging(trace, tracer, *, obj, attr):
+  trace._ensure_tracked(obj, attr)
+  setattr(obj, attr, tracer)
+pe.custom_staging_rules[setattr_p] = _setattr_staging
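The staging rules above are the only tracing hook: outside a trace the primitives fall back to plain getattr/setattr via def_impl, and under DynamicJaxprTrace the attribute is swapped for a tracer so its state becomes part of the jaxpr. Because the hook is on the trace itself, tracked attributes should compose with ordinary jit arguments; a hedged sketch (the Counter example is illustrative, not from this commit):

from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax.experimental.attrs import jax_getattr, jax_setattr

@dataclass
class Counter:
  count: float

counter = Counter(0.0)

@jax.jit
def accumulate(x):
  # The tracked attribute is threaded through the jitted computation alongside
  # the explicit argument x; only the explicit result is returned to the caller.
  total = jax_getattr(counter, "count") + jnp.sum(x)
  jax_setattr(counter, "count", total)
  return total * 2.0

y = accumulate(jnp.arange(3.0))  # counter.count becomes 3.0, y == 6.0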
@ -459,7 +459,7 @@ class custom_partitioning:
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
debug = pe.debug_info(self.fun, in_tree, out_tree, False,
|
||||
"custom_partitioning")
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
|
||||
assert not len(consts)
|
||||
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
out_flat = custom_partitioning_p.bind(
|
||||
|
@ -718,7 +718,7 @@ def _jet_jaxpr(
|
||||
) -> tuple[core.ClosedJaxpr, Any]:
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
||||
f_jet, out_tree_def = traceable(jet_fun(jet_subtrace(f), order), in_tree_def)
|
||||
jaxpr_jet, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
jaxpr_jet, _, consts, () = pe.trace_to_jaxpr_dynamic(
|
||||
f_jet, primals_and_series_avals)
|
||||
return core.ClosedJaxpr(jaxpr_jet, consts), out_tree_def
|
||||
|
||||
|
@ -181,7 +181,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignature
|
||||
args_flat, in_tree = tree_util.tree_flatten(args)
|
||||
in_avals_flat = [core.get_aval(arg) for arg in args_flat]
|
||||
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
|
||||
return get_jaxpr_type_signature(jaxpr)
|
||||
|
||||
|
||||
|
@ -152,7 +152,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignature
|
||||
args_flat, in_tree = tree_util.tree_flatten(args)
|
||||
in_avals_flat = [core.get_aval(arg) for arg in args_flat]
|
||||
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
|
||||
return get_jaxpr_type_signature(jaxpr)
|
||||
|
||||
|
||||
|
@ -426,7 +426,7 @@ def _shard_map_staging(
|
||||
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
|
||||
main = trace.main
|
||||
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr, genavals, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
|
||||
jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
|
||||
out_avals_ = map(_check_shapedarray, genavals)
|
||||
_check_names(out_names_thunk(), out_avals_)
|
||||
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
|
||||
@ -1378,7 +1378,7 @@ def _promote_scalar_residuals_jaxpr(jaxpr, which):
|
||||
res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
|
||||
for v, w in zip(jaxpr.constvars, which)]
|
||||
in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
return jaxpr
|
||||
|
||||
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
|
||||
@ -1495,7 +1495,7 @@ def _add_reshapes(which, jaxpr_known, jaxpr_staged):
|
||||
res = [_add_singleton(x) if not x.shape else x for x in res]
|
||||
return [*out_known, *res]
|
||||
avals_in = [v.aval for v in jaxpr_known.invars]
|
||||
jaxpr_known, _, () = pe.trace_to_jaxpr_dynamic(known, avals_in)
|
||||
jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic(known, avals_in)
|
||||
|
||||
@lu.wrap_init
|
||||
def staged(*args):
|
||||
@ -1505,7 +1505,7 @@ def _add_reshapes(which, jaxpr_known, jaxpr_staged):
|
||||
res_avals = [core.unmapped_aval(1, None, 0, v.aval) if w else v.aval
|
||||
for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
|
||||
avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
|
||||
jaxpr_staged, _, () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
|
||||
jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
|
||||
|
||||
return jaxpr_known, jaxpr_staged
|
||||
|
||||
@ -1773,7 +1773,7 @@ def _replication_rewrite_match(
|
||||
f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
|
||||
f = _match_rep(f, mesh, out_rep, out_rep_dst)
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
|
||||
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
|
||||
return core.ClosedJaxpr(jaxpr_, consts)
|
||||
|
||||
# TODO(mattjj): caching
|
||||
@ -1785,7 +1785,7 @@ def _replication_rewrite_nomatch(
|
||||
f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
|
||||
f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
|
||||
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
|
||||
return core.ClosedJaxpr(jaxpr_, consts), out_rep()
|
||||
|
||||
@lu.transformation_with_aux
|
||||
|
@ -447,7 +447,7 @@ def sparsify_raw(f):
|
||||
spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue)
|
||||
in_avals_flat = spvalues_to_avals(spenv, spvalues_flat)
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
|
||||
jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
|
||||
jaxpr, out_avals_flat, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
|
||||
result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
|
||||
if len(out_avals_flat) != len(result):
|
||||
raise Exception("Internal: eval_sparse does not return expected number of arguments. "
|
||||
@ -747,7 +747,7 @@ def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
|
||||
args = spvalues_to_arrays(spenv, spvalues)
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
|
||||
sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
|
||||
sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
|
||||
sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
|
||||
assert out_tree is not None
|
||||
return sp_jaxpr, out_tree
@ -82,7 +82,7 @@ from jax._src.interpreters.partial_eval import (
  result_info as result_info,
  sig_info as sig_info,
  trace_to_jaxpr as trace_to_jaxpr,
  trace_to_jaxpr_dynamic as trace_to_jaxpr_dynamic,
  trace_to_jaxpr_dynamic as _trace_to_jaxpr_dynamic,
  trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2,
  trace_to_jaxpr_final as trace_to_jaxpr_final,
  trace_to_jaxpr_final2 as trace_to_jaxpr_final2,

@ -97,4 +97,12 @@ from jax._src.interpreters.partial_eval import (
  trivial_ctx as trivial_ctx,
)


# TODO(mattjj): remove temporary shim when trace_to_jaxpr_dynamic sig stabilizes
def trace_to_jaxpr_dynamic(fun, in_avals, debug_info=None, *, keep_inputs=None): # noqa
  jaxpr, out_avals, consts, () = _trace_to_jaxpr_dynamic(
      fun, in_avals, debug_info, keep_inputs=keep_inputs)
  return jaxpr, out_avals, consts


from jax._src.core import Jaxpr as Jaxpr
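The shim above keeps the public jax.interpreters.partial_eval re-export on the old three-tuple convention while the underlying jax._src function now also returns the attrs tracked during tracing. A hedged sketch of what each import path sees (illustrative only; f and the aval are made up):

from jax._src import core, linear_util as lu
from jax._src.interpreters import partial_eval as pe_private
from jax.interpreters import partial_eval as pe_public
import jax.numpy as jnp

def f(x):
  return [x * 2.0]

avals = [core.ShapedArray((), jnp.dtype('float32'))]

# Private path: four results; the last is the attrs tracked (empty here).
jaxpr, out_avals, consts, attrs_tracked = pe_private.trace_to_jaxpr_dynamic(
    lu.wrap_init(f), avals)
assert not attrs_tracked

# Public path through the shim: the familiar three-tuple, so downstream users
# of jax.interpreters.partial_eval keep working until the signature stabilizes.
jaxpr2, out_avals2, consts2 = pe_public.trace_to_jaxpr_dynamic(
    lu.wrap_init(f), avals)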
@ -1301,6 +1301,15 @@ jax_test(
    srcs = ["clear_backends_test.py"],
)

jax_test(
    name = "attrs_test",
    srcs = ["attrs_test.py"],
    deps = [
        "//jax:experimental",
    ],
)

jax_test(
    name = "experimental_rnn_test",
    srcs = ["experimental_rnn_test.py"],
tests/attrs_test.py (new file, 64 lines)
@ -0,0 +1,64 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass

from absl.testing import absltest
from absl.testing import parameterized

import jax

from jax._src import config
from jax._src import test_util as jtu
from jax._src.util import safe_zip, safe_map

from jax.experimental.attrs import jax_setattr, jax_getattr

config.parse_flags_with_absl()

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

@dataclass
class Thing:
  x: float

class AttrsTest(jtu.JaxTestCase):

  @parameterized.parameters([True, False])
  def test_basic(self, jit: bool):
    thing = Thing(1.0)

    def double_it() -> None:
      cur_x = jax_getattr(thing, "x")
      jax_setattr(thing, "x", cur_x * 2)

    if jit:
      double_it = jax.jit(double_it)

    self.assertEqual(thing.x, 1.0)
    double_it()
    self.assertEqual(thing.x, 2.0)
    double_it()
    self.assertEqual(thing.x, 4.0)
    double_it()
    self.assertEqual(thing.x, 8.0)
    double_it()
    self.assertEqual(thing.x, 16.0)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
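For reference, the pattern test_basic exercises: jax_getattr/jax_setattr read and write an attribute of an ordinary Python object from inside a function, and under jax.jit the updated value is reflected on the object once the call returns. A compact sketch mirroring the test above (the Box class and names are illustrative; the API is the one imported by the test):

from dataclasses import dataclass
import jax
from jax.experimental.attrs import jax_getattr, jax_setattr

@dataclass
class Box:
  x: float

box = Box(1.0)

@jax.jit
def double_it() -> None:
  # Read box.x, double it, and write it back; the write is visible on the
  # Python object after the jitted call completes.
  jax_setattr(box, "x", jax_getattr(box, "x") * 2)

double_it()
double_it()
assert box.x == 4.0  # 1.0 -> 2.0 -> 4.0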
@ -624,7 +624,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
def f(x, y):
return x, y

jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])

self.assertLen(jaxpr.invars, 3)

@ -648,7 +648,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
return (x, w)
return g(x, y, x, y)

jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])

self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs

@ -684,7 +684,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
return (x, w)
return g(x.shape[0], x, y, x, y)

jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])

# { lambda ; a:i32[] b:f32[a] c:f32[a]. let

@ -720,7 +720,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
u = lax_internal._reduce_sum(w, [0])
return (u,)

jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, a, b], keep_inputs=[False, True, True])

self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs

@ -745,7 +745,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
def g(x): return x
return g(a),

jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
f, [n, m, a, b], keep_inputs=[False, False, True, True])
# { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
# e:f32[a] = xla_call[
@ -318,7 +318,7 @@ class PallasCallDMATest(parameterized.TestCase):
pltpu.run_scoped(body, pltpu.VMEM((8,), jnp.float32))
return []

jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(kernel),
[
state.shaped_array_ref((8,), jnp.float32),
@ -132,7 +132,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval])
else:
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval])
self.assertSetEqual(jaxpr.effects,
{ReadEffect(len(jaxpr.constvars))})

@ -229,7 +229,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
else:
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval])
self.assertSetEqual(jaxpr.effects,
{WriteEffect(len(jaxpr.constvars))})

@ -306,7 +306,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
with self.assertRaises(Exception):
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), [ref_aval, val_aval])
else:
jaxpr, out_avals, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, out_avals, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval])
self.assertSetEqual(jaxpr.effects,
{AccumEffect(len(jaxpr.constvars))})

@ -326,7 +326,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x[()] = jnp.int32(1)
x[()] = jnp.int32(2)
return (x[()],)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])

@ -339,7 +339,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x):
ref_addupdate(x, (), jnp.int32(1))
return (x[()],)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
@ -349,14 +349,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x_ref):
x = x_ref[()]
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))

def body(x_ref):
x = x_ref[:, 0]
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32)])
self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False))

@ -364,14 +364,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x_ref):
x_ref[()] = jnp.int32(2)
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))

def body(x_ref, val):
x_ref[:, 0] = val
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False))

@ -380,14 +380,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x_ref):
x = ref_swap(x_ref, (), jnp.int32(2))
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])
self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))

def body(x_ref, val):
x = ref_swap(x_ref, (slice(None), 0), val)
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("c:i32[1], a[:,0] <- a[:,0], b",
@ -397,7 +397,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x_ref):
ref_addupdate(x_ref, (), jnp.int32(2))
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((), jnp.int32)])

self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))

@ -405,7 +405,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def body(x_ref, val):
ref_addupdate(x_ref, (slice(None), 0), val)
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False))

@ -422,7 +422,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):

in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, get_p)
self.assertEqual(jaxpr.eqns[1].primitive, get_p)

@ -438,7 +438,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):

in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, get_p)
self.assertEqual(jaxpr.eqns[1].primitive, get_p)
self.assertEqual(jaxpr.eqns[2].primitive, lax.sin_p)

@ -458,7 +458,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):

in_avals = [shaped_array_ref((), jnp.dtype('float32')),
shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, addupdate_p)
self.assertEqual(jaxpr.eqns[1].primitive, addupdate_p)
self.assertEqual(jaxpr.eqns[2].primitive, get_p)
@ -529,12 +529,12 @@ class StatePrimitivesTest(jtu.JaxTestCase):

# discharge-of-vmap
f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, *bat_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs)
# vmap-of-discharge
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
@ -553,7 +553,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a = ref_get(a_ref, ())
return [a + 1]
in_avals = [shaped_array_ref((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)

@ -569,7 +569,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a = ref_get(a_ref, (0, 1))
return [a + 1]
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)

@ -588,7 +588,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a = a_ref[jnp.array([0, 1])]
return [a + 1]
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), in_avals)
discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts)

@ -603,7 +603,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return []
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that ignores the first
# value and returns second value plus 1.
@ -620,7 +620,7 @@ class StateDischargeTest(jtu.JaxTestCase):
ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
discharged_jaxpr, () = discharge_state(stateful_jaxpr, consts)

@ -640,7 +640,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
return []
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts)

@ -654,7 +654,7 @@ class StateDischargeTest(jtu.JaxTestCase):
return []
in_avals = [shaped_array_ref((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that adds the first value,
# second value, and 1.

@ -672,7 +672,7 @@ class StateDischargeTest(jtu.JaxTestCase):
jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)
@ -693,7 +693,7 @@ class StateDischargeTest(jtu.JaxTestCase):
jnp.ones((2, 3), 'float32'))
return []
in_avals = [shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, discharged_consts = discharge_state(
stateful_jaxpr, consts)

@ -707,7 +707,7 @@ class StateDischargeTest(jtu.JaxTestCase):
b = a + 1
return [a, b]
in_avals = [shaped_array_ref((4,), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = discharge_state(stateful_jaxpr, consts)
self.assertLen(discharged_jaxpr.invars, 1)

@ -727,7 +727,7 @@ class StateDischargeTest(jtu.JaxTestCase):
shaped_array_ref((4,), jnp.dtype('float32')),
shaped_array_ref((4,), jnp.dtype('float32'))
]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = discharge_state(
stateful_jaxpr, consts, should_discharge=[False, True])
@ -929,13 +929,13 @@ if CAN_USE_HYPOTHESIS:
ref = get_vmap_param.bat_ref

f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim])
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx)

# vmap-of-discharge
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),

@ -974,13 +974,13 @@ if CAN_USE_HYPOTHESIS:

f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
out_axes=[])
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)

# vmap-of-discharge
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),

@ -1018,13 +1018,13 @@ if CAN_USE_HYPOTHESIS:

f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims),
out_axes=[])
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f_batched), [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals])
jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts)
discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx)

# vmap-of-discharge
stateful_jaxpr, _, stateful_consts = pe.trace_to_jaxpr_dynamic(
stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [ref_aval, val_aval, *idx_avals])
jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts)
f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_),
@ -1323,7 +1323,7 @@ class GeneralRefTest(jtu.JaxTestCase):
x_ref[...] = x
ref_addupdate(x_ref, (), x)
return [x]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [AbstractRef(core.UnshapedArray(jnp.int32))])
self.assertIs(type(jaxpr.outvars[0].aval), core.UnshapedArray)
self.assertEqual(jaxpr.outvars[0].aval.dtype, jnp.dtype("int32"))

@ -1334,7 +1334,7 @@ class GeneralRefTest(jtu.JaxTestCase):
x_ref[...] = x
ref_addupdate(x_ref, (), x)
return [x]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [AbstractRef(core.AbstractToken())])
self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken)

@ -1343,7 +1343,7 @@ class GeneralRefTest(jtu.JaxTestCase):
x_ref = x_ref_ref[...]
return [x_ref]
# Not sure why you'd ever want to do this, but it works!
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f),
[AbstractRef(AbstractRef(core.ShapedArray((), jnp.int32)))])
self.assertIs(type(jaxpr.outvars[0].aval), AbstractRef)