mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Stop using generators for linear_util transformations.
They lead to confusing code, nasty bugs, and unhelpful (but terse!) stack traces.
This commit is contained in:
parent
ed9fdbbf0a
commit
1c9b23c566
@ -68,11 +68,13 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
|
||||
else:
|
||||
return tuple(map(_ensure_str, x))
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def flatten_fun(in_tree, *args_flat):
|
||||
@lu.transformation_with_aux2
|
||||
def flatten_fun(f, store, in_tree, *args_flat):
|
||||
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
|
||||
ans = yield py_args, py_kwargs
|
||||
yield tree_flatten(ans)
|
||||
ans = f(*py_args, **py_kwargs)
|
||||
ans, out_tree = tree_flatten(ans)
|
||||
store.store(out_tree)
|
||||
return ans
|
||||
|
||||
def apply_flat_fun(fun, io_tree, *py_args):
|
||||
in_tree_expected, out_tree = io_tree
|
||||
@ -82,11 +84,13 @@ def apply_flat_fun(fun, io_tree, *py_args):
|
||||
ans = fun(*args)
|
||||
return tree_unflatten(out_tree, ans)
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def flatten_fun_nokwargs(in_tree, *args_flat):
|
||||
@lu.transformation_with_aux2
|
||||
def flatten_fun_nokwargs(f, store, in_tree, *args_flat):
|
||||
py_args = tree_unflatten(in_tree, args_flat)
|
||||
ans = yield py_args, {}
|
||||
yield tree_flatten(ans)
|
||||
ans = f(*py_args)
|
||||
ans, out_tree = tree_flatten(ans)
|
||||
store.store(out_tree)
|
||||
return ans
|
||||
|
||||
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
|
||||
in_tree_expected, out_tree = io_tree
|
||||
@ -118,17 +122,18 @@ def flattened_fun_in_tree(
|
||||
else:
|
||||
return in_tree, lambda: out_tree_store.val, has_kwargs
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def flatten_fun_nokwargs2(in_tree, *args_flat):
|
||||
@lu.transformation_with_aux2
|
||||
def flatten_fun_nokwargs2(f, store, in_tree, *args_flat):
|
||||
py_args = tree_unflatten(in_tree, args_flat)
|
||||
pair = yield py_args, {}
|
||||
pair = f(*py_args)
|
||||
if not isinstance(pair, (list, tuple)) or len(pair) != 2:
|
||||
raise TypeError("expected function with aux output to return a two-element "
|
||||
f"tuple, but got type {type(pair)} with value {pair!r}")
|
||||
ans, aux = pair
|
||||
ans_flat, ans_tree = tree_flatten(ans)
|
||||
aux_flat, aux_tree = tree_flatten(aux)
|
||||
yield (ans_flat, aux_flat), (ans_tree, aux_tree)
|
||||
store.store((ans_tree, aux_tree))
|
||||
return ans_flat, aux_flat
|
||||
|
||||
class _HashableWithStrictTypeEquality:
|
||||
"""Box object used when comparing static arguments as a jit key.
|
||||
@ -277,8 +282,8 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
|
||||
|
||||
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
||||
|
||||
@lu.transformation
|
||||
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
|
||||
@lu.transformation2
|
||||
def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs):
|
||||
sentinel = object()
|
||||
args = [sentinel] * (len(fixed_args) + len(dyn_args))
|
||||
for i, arg in zip(dyn_argnums, dyn_args):
|
||||
@ -286,9 +291,7 @@ def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
|
||||
fixed_args_ = iter(fixed_args)
|
||||
args = [next(fixed_args_).val if x is sentinel else x for x in args]
|
||||
assert next(fixed_args_, sentinel) is sentinel
|
||||
ans = yield args, kwargs
|
||||
yield ans
|
||||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
|
||||
kwargs: dict[str, Any]):
|
||||
@ -311,11 +314,10 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
|
||||
|
||||
return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
|
||||
|
||||
@lu.transformation
|
||||
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
|
||||
@lu.transformation2
|
||||
def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
|
||||
kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs)
|
||||
ans = yield args, kwargs
|
||||
yield ans
|
||||
return f(*args, **kwargs)
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
@ -435,9 +437,9 @@ def flat_out_axes(
|
||||
f, out_axes = _flat_out_axes(f, tuple(leaves), treedef)
|
||||
return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef))
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _flat_out_axes(leaves, treedef, *args, **kwargs):
|
||||
ans = yield args, kwargs
|
||||
@lu.transformation_with_aux2
|
||||
def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
|
||||
ans = f(*args, **kwargs)
|
||||
spec = tree_unflatten(treedef, leaves)
|
||||
try:
|
||||
spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None))
|
||||
@ -449,7 +451,8 @@ def _flat_out_axes(leaves, treedef, *args, **kwargs):
|
||||
"that the `out_axes` argument to `pmap` is a pytree prefix of the "
|
||||
"pmapped function's output.")
|
||||
raise ValueError(msg) from None
|
||||
yield ans, spec_flat
|
||||
store.store(spec_flat)
|
||||
return ans
|
||||
|
||||
def check_callable(fun):
|
||||
# In Python 3.10+, the only thing stopping us from supporting staticmethods
|
||||
@ -683,11 +686,12 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
|
||||
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
|
||||
for path, l in generate_key_paths(x) if l is not static)
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def result_paths(*args, **kwargs):
|
||||
@lu.transformation_with_aux2
|
||||
def result_paths(f, store, *args, **kwargs):
|
||||
"linear_util transform to get output pytree paths of pre-flattened function."
|
||||
ans = yield args, kwargs
|
||||
yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]
|
||||
ans = f(*args, **kwargs)
|
||||
store.store([keystr(path) for path, _ in generate_key_paths(ans)])
|
||||
return ans
|
||||
|
||||
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
|
||||
result_paths: tuple[str, ...] | None = None,
|
||||
|
@ -330,11 +330,12 @@ def update_error(error, pred, code, metadata, payload, effect_type):
|
||||
|
||||
## Checkify transformation for plumbing functional error values.
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _flatten_and_get_error_metadata_thunk(*invals):
|
||||
error, out = yield invals, {}
|
||||
@lu.transformation_with_aux2
|
||||
def _flatten_and_get_error_metadata_thunk(f, store, *invals):
|
||||
error, out = f(*invals)
|
||||
out_vals, out_tree = jtu.tree_flatten((error, out))
|
||||
yield out_vals, (out_tree, set(error._pred.keys()))
|
||||
store.store((out_tree, set(error._pred.keys())))
|
||||
return out_vals
|
||||
|
||||
def default_checkify_rule(primitive: core.Primitive, error: Error,
|
||||
enabled_errors, *invals: core.Value,
|
||||
@ -438,10 +439,12 @@ def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
|
||||
consts = tuple(c.x for c in hashable_consts)
|
||||
return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args)
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def flatten_fun_output(*args):
|
||||
ans = yield args, {}
|
||||
yield tree_flatten(ans)
|
||||
@lu.transformation_with_aux2
|
||||
def flatten_fun_output(f, store, *args):
|
||||
ans = f(*args)
|
||||
ans, out_tree = tree_flatten(ans)
|
||||
store.store(out_tree)
|
||||
return ans
|
||||
|
||||
|
||||
def _reduce_any_error(error: Error):
|
||||
|
@ -75,13 +75,14 @@ _stop_gradient = partial(
|
||||
|
||||
|
||||
# like the api_util.py function, but also grabs output avals for error checking
|
||||
@lu.transformation_with_aux
|
||||
def _flatten_fun_nokwargs(in_tree, *args_flat):
|
||||
@lu.transformation_with_aux2
|
||||
def _flatten_fun_nokwargs(f, store, in_tree, *args_flat):
|
||||
py_args = tree_unflatten(in_tree, args_flat)
|
||||
ans = yield py_args, {}
|
||||
ans = f(*py_args)
|
||||
ans_flat, ans_tree = tree_flatten(ans)
|
||||
ans_avals = [core.get_aval(x) for x in ans_flat]
|
||||
yield ans_flat, (ans_tree, ans_avals)
|
||||
store.store((ans_tree, ans_avals))
|
||||
return ans_flat
|
||||
|
||||
|
||||
### JVPs
|
||||
@ -266,18 +267,18 @@ class custom_jvp(Generic[ReturnValue]):
|
||||
def _add_args(f, extra_args):
|
||||
return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args))
|
||||
|
||||
@lu.transformation
|
||||
def _add_args_(extra_args, *args, **kwargs):
|
||||
@lu.transformation2
|
||||
def _add_args_(f, extra_args, *args, **kwargs):
|
||||
extra_args = tuple(arg.val for arg in extra_args)
|
||||
all_args = (extra_args + args)
|
||||
yield (yield all_args, kwargs)
|
||||
return f(*all_args, **kwargs)
|
||||
|
||||
@partial(lu.transformation_with_aux, use_eq_store=True)
|
||||
def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
|
||||
@partial(lu.transformation_with_aux2, use_eq_store=True)
|
||||
def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args):
|
||||
primals_in, tangents_in = split_list(args, [len(args) // 2])
|
||||
py_primals = tree_unflatten(in_tree, primals_in)
|
||||
py_tangents = tree_unflatten(in_tree, tangents_in)
|
||||
pair_out = yield (py_primals, py_tangents), {}
|
||||
pair_out = f(py_primals, py_tangents)
|
||||
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
|
||||
msg = (f"Custom JVP rule {jvp_name} for function {primal_name} "
|
||||
"must produce a pair (list or tuple of length two) representing "
|
||||
@ -348,7 +349,8 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
|
||||
if av_et != av_t)
|
||||
|
||||
raise TypeError(msg.format('\n'.join(disagreements)))
|
||||
yield primals_out + tangents_out, (out_tree, primal_avals)
|
||||
store.store((out_tree, primal_avals))
|
||||
return primals_out + tangents_out
|
||||
|
||||
class CustomJVPCallPrimitive(core.Primitive):
|
||||
multiple_results = True
|
||||
@ -652,15 +654,15 @@ def _check_for_tracers(x):
|
||||
"arguments should typically not be indicated as nondiff_argnums.")
|
||||
raise UnexpectedTracerError(msg)
|
||||
|
||||
@partial(lu.transformation_with_aux, use_eq_store=True)
|
||||
def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
|
||||
@partial(lu.transformation_with_aux2, use_eq_store=True)
|
||||
def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
|
||||
*args):
|
||||
if symbolic_zeros:
|
||||
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
|
||||
else:
|
||||
args = args[::2]
|
||||
py_args = tree_unflatten(in_tree, args)
|
||||
pair_out = yield py_args, {}
|
||||
pair_out = f(*py_args)
|
||||
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
|
||||
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
|
||||
"must produce a pair (list or tuple of length two) where the first "
|
||||
@ -710,16 +712,17 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
|
||||
"shapes/dtypes of:\n"
|
||||
f""" {str(ty_tree_).replace("'", "")}""")
|
||||
raise TypeError(m)
|
||||
yield (*res, *primals_out), (out_tree, res_tree)
|
||||
store.store((out_tree, res_tree))
|
||||
return (*res, *primals_out)
|
||||
|
||||
@lu.transformation
|
||||
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
|
||||
@lu.transformation2
|
||||
def _flatten_bwd(f, in_tree, in_avals, out_trees, *args):
|
||||
out_tree, res_tree = out_trees()
|
||||
assert len(args) == res_tree.num_leaves + out_tree.num_leaves
|
||||
res, cts_out = split_list(args, [res_tree.num_leaves])
|
||||
py_res = tree_unflatten(res_tree, res)
|
||||
py_cts_out = tree_unflatten(out_tree, cts_out)
|
||||
py_cts_in = yield (py_res, py_cts_out), {}
|
||||
py_cts_in = f(py_res, py_cts_out)
|
||||
if isinstance(py_cts_in, list) and len(py_cts_in) == len(treedef_children(in_tree)):
|
||||
py_cts_in = tuple(py_cts_in)
|
||||
# For each None in py_cts_in, indicating an argument for which the rule
|
||||
@ -775,7 +778,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
|
||||
f"to an input of shape/dtype {a.str_short()}.")
|
||||
raise ValueError(msg)
|
||||
results.append(ct)
|
||||
yield results
|
||||
return results
|
||||
|
||||
# TODO(mattjj): remove both these exceptions to cotangent compatibility check
|
||||
def _temporary_dtype_exception(a, a_) -> bool:
|
||||
@ -1425,11 +1428,11 @@ def optimize_remat_of_custom_vjp_fwd(
|
||||
|
||||
return wrapped_fwd
|
||||
|
||||
@lu.transformation
|
||||
def _fix_fwd_args(*args):
|
||||
@lu.transformation2
|
||||
def _fix_fwd_args(f, *args):
|
||||
args = [(x, True) for x in args]
|
||||
args = [x for pair in args for x in pair]
|
||||
yield (yield args, {})
|
||||
return f(*args)
|
||||
|
||||
def _remat_opt_impl(
|
||||
*args,
|
||||
|
@ -68,42 +68,43 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
|
||||
fun, aux = jvp_subtrace_aux(fun)
|
||||
return jvpfun(fun, instantiate, transform_stack), aux
|
||||
|
||||
@lu.transformation
|
||||
def jvpfun(instantiate, transform_stack, primals, tangents):
|
||||
@lu.transformation2
|
||||
def jvpfun(f, instantiate, transform_stack, primals, tangents):
|
||||
tag = core.TraceTag()
|
||||
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
|
||||
and dtype(t) == float0 else t for t in tangents]
|
||||
ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
|
||||
else contextlib.nullcontext())
|
||||
with ctx:
|
||||
out_primals, out_tangents = yield (tag, primals, tangents), {}
|
||||
out_primals, out_tangents = f(tag, primals, tangents)
|
||||
if type(instantiate) is bool:
|
||||
instantiate = [instantiate] * len(out_tangents)
|
||||
out_tangents = [instantiate_zeros(t) if inst else t for t, inst
|
||||
in zip(out_tangents, instantiate)]
|
||||
yield out_primals, out_tangents
|
||||
return out_primals, out_tangents
|
||||
|
||||
@lu.transformation
|
||||
def jvp_subtrace(tag, primals, tangents):
|
||||
@lu.transformation2
|
||||
def jvp_subtrace(f, tag, primals, tangents):
|
||||
with core.take_current_trace() as parent_trace:
|
||||
trace = JVPTrace(parent_trace, tag)
|
||||
in_tracers = [maybe_jvp_tracer(trace, x, t)
|
||||
for x, t in zip(primals, tangents)]
|
||||
with core.set_current_trace(trace):
|
||||
ans = yield in_tracers, {}
|
||||
ans = f(*in_tracers)
|
||||
out = unzip2(map(trace.to_primal_tangent_pair, ans))
|
||||
yield out
|
||||
return out
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def jvp_subtrace_aux(tag, primals, tangents):
|
||||
@lu.transformation_with_aux2
|
||||
def jvp_subtrace_aux(f, store, tag, primals, tangents):
|
||||
with core.take_current_trace() as parent_trace:
|
||||
trace = JVPTrace(parent_trace, tag)
|
||||
with core.set_current_trace(trace):
|
||||
ans, aux = yield map(partial(maybe_jvp_tracer, trace), primals, tangents), {}
|
||||
ans, aux = f(*(map(partial(maybe_jvp_tracer, trace), primals, tangents)))
|
||||
out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans))
|
||||
aux_primals = [x.primal if isinstance(x, JVPTracer) and x._trace.tag is tag
|
||||
else x for x in aux]
|
||||
yield (out_primals, out_tangents), aux_primals
|
||||
store.store(aux_primals)
|
||||
return out_primals, out_tangents
|
||||
|
||||
def linearize(traceable, *primals, **kwargs):
|
||||
has_aux = kwargs.pop('has_aux', False)
|
||||
@ -262,10 +263,11 @@ def get_primitive_transpose(p):
|
||||
"Transpose rule (for reverse-mode differentiation) for '{}' "
|
||||
"not implemented".format(p)) from err
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def nonzero_tangent_outputs(*args, **kwargs):
|
||||
results = (_, tangents_out) = yield args, kwargs
|
||||
yield results, [type(r) is not Zero for r in tangents_out]
|
||||
@lu.transformation_with_aux2
|
||||
def nonzero_tangent_outputs(f, store, *args, **kwargs):
|
||||
results = (_, tangents_out) = f(*args, **kwargs)
|
||||
store.store([type(r) is not Zero for r in tangents_out])
|
||||
return results
|
||||
|
||||
|
||||
class JVPTrace(Trace):
|
||||
@ -543,15 +545,16 @@ deflinear2(add_jaxvals_p, lambda t, *args: (t, t))
|
||||
def instantiate_zeros(tangent):
|
||||
return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def traceable(in_tree, *primals_and_tangents):
|
||||
@lu.transformation_with_aux2
|
||||
def traceable(f, store, in_tree, *primals_and_tangents):
|
||||
primals, tangents = tree_unflatten(in_tree, primals_and_tangents)
|
||||
tangents = [Zero.from_primal_value(p) if t is None else t
|
||||
for p, t in zip(primals, tangents)]
|
||||
primals_out, tangents_out = yield (primals, tangents), {}
|
||||
primals_out, tangents_out = f(primals, tangents)
|
||||
tangents_out = [None if type(t) is Zero else t for t in tangents_out]
|
||||
out_flat, out_tree = tree_flatten((primals_out, tangents_out))
|
||||
yield out_flat, out_tree
|
||||
store.store(out_tree)
|
||||
return out_flat
|
||||
|
||||
|
||||
def call_transpose(primitive, params, call_jaxpr, args, ct, _):
|
||||
@ -588,10 +591,11 @@ def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals):
|
||||
primitive_transposes[core.closed_call_p] = _closed_call_transpose
|
||||
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def nonzero_outputs(*args, **kwargs):
|
||||
results = yield args, kwargs
|
||||
yield results, [type(r) is not Zero for r in results]
|
||||
@lu.transformation_with_aux2
|
||||
def nonzero_outputs(f, store, *args, **kwargs):
|
||||
results = f(*args, **kwargs)
|
||||
store.store([type(r) is not Zero for r in results])
|
||||
return results
|
||||
|
||||
def map_transpose(primitive, params, call_jaxpr, args, ct, _):
|
||||
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
|
||||
@ -655,17 +659,18 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
|
||||
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
|
||||
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
|
||||
@lu.transformation_with_aux2
|
||||
def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents):
|
||||
num_primals = len(nonzeros)
|
||||
primals = list(primals_and_nztangents[:num_primals])
|
||||
nonzero_tangents = iter(primals_and_nztangents[num_primals:])
|
||||
tangents = [next(nonzero_tangents) if nz else Zero.from_primal_value(p)
|
||||
for p, nz in zip(primals, nonzeros)]
|
||||
primals_out, tangents_out = yield (primals, tangents), {}
|
||||
primals_out, tangents_out = f(primals, tangents)
|
||||
out_nonzeros = [type(t) is not Zero for t in tangents_out]
|
||||
nonzero_tangents_out = [t for t in tangents_out if type(t) is not Zero]
|
||||
yield list(primals_out) + nonzero_tangents_out, out_nonzeros
|
||||
store.store(out_nonzeros)
|
||||
return list(primals_out) + nonzero_tangents_out
|
||||
|
||||
def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
|
||||
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
|
||||
|
@ -327,11 +327,13 @@ def unregister_vmappable(data_type: type) -> None:
|
||||
def is_vmappable(x: Any) -> bool:
|
||||
return type(x) is Jumble or type(x) in vmappables
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def flatten_fun_for_vmap(in_tree, *args_flat):
|
||||
@lu.transformation_with_aux2
|
||||
def flatten_fun_for_vmap(f, store, in_tree, *args_flat):
|
||||
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
|
||||
ans = yield py_args, py_kwargs
|
||||
yield tree_flatten(ans, is_leaf=is_vmappable)
|
||||
ans = f(*py_args, **py_kwargs)
|
||||
ans, out_tree = tree_flatten(ans, is_leaf=is_vmappable)
|
||||
store.store(out_tree)
|
||||
return ans
|
||||
|
||||
# Propagate ragged masking rules from invars to outvars
|
||||
# rule([params], [raggedness_per_invar], outvars) ->
|
||||
@ -580,16 +582,16 @@ def batch(fun: lu.WrappedFun, axis_data,
|
||||
f = _batch_inner(fun, axis_data, out_dim_dests)
|
||||
return _batch_outer(f, axis_data, in_dims)
|
||||
|
||||
@lu.transformation
|
||||
def _batch_outer(axis_data, in_dims, *in_vals):
|
||||
@lu.transformation2
|
||||
def _batch_outer(f, axis_data, in_dims, *in_vals):
|
||||
tag = TraceTag()
|
||||
with source_info_util.transform_name_stack('vmap'):
|
||||
outs, trace = yield (tag, in_dims, *in_vals), {}
|
||||
outs, trace = f(tag, in_dims, *in_vals)
|
||||
with core.ensure_no_leaks(trace): del trace
|
||||
yield outs
|
||||
return outs
|
||||
|
||||
@lu.transformation
|
||||
def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals):
|
||||
@lu.transformation2
|
||||
def _batch_inner(f, axis_data, out_dim_dests, tag, in_dims, *in_vals):
|
||||
in_dims = in_dims() if callable(in_dims) else in_dims
|
||||
with core.take_current_trace() as parent_trace:
|
||||
trace = BatchTrace(parent_trace, tag, axis_data)
|
||||
@ -599,13 +601,13 @@ def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals):
|
||||
with (core.set_current_trace(trace),
|
||||
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
|
||||
core.add_spmd_axis_names(axis_data.spmd_name)):
|
||||
outs = yield in_tracers, {}
|
||||
outs = f(*in_tracers)
|
||||
|
||||
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
|
||||
out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)),
|
||||
outs, out_dim_dests)
|
||||
|
||||
yield out_vals, trace
|
||||
return out_vals, trace
|
||||
|
||||
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
|
||||
def vtile(f_flat: lu.WrappedFun,
|
||||
@ -628,21 +630,21 @@ def vtile(f_flat: lu.WrappedFun,
|
||||
shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
|
||||
return out.reshape(shape)
|
||||
|
||||
@lu.transformation
|
||||
def _map_to_tile(*args_flat):
|
||||
@lu.transformation2
|
||||
def _map_to_tile(f, *args_flat):
|
||||
sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
|
||||
tile_size_ = tile_size or next(sizes, None)
|
||||
assert tile_size_ is not None, "No mapped arguments?"
|
||||
outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
|
||||
yield map(untile_axis, outputs_flat, out_axes_flat)
|
||||
outputs_flat = f(*map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat))
|
||||
return map(untile_axis, outputs_flat, out_axes_flat)
|
||||
|
||||
axis_data = AxisData(axis_name, tile_size, None)
|
||||
return _map_to_tile(batch(f_flat, axis_data, in_axes_flat, out_axes_flat))
|
||||
|
||||
### API for batching functions with jaxpr type inputs and outputs
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def batch_subtrace(tag, axis_data, in_dims, *in_vals):
|
||||
@lu.transformation_with_aux2
|
||||
def batch_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
|
||||
with core.take_current_trace() as parent_trace:
|
||||
trace = BatchTrace(parent_trace, tag, axis_data)
|
||||
with core.set_current_trace(trace):
|
||||
@ -650,10 +652,11 @@ def batch_subtrace(tag, axis_data, in_dims, *in_vals):
|
||||
in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims)
|
||||
in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
|
||||
if dim is not None else x for x, dim in zip(in_vals, in_dims)]
|
||||
outs = yield in_tracers, {}
|
||||
outs = f(*in_tracers)
|
||||
out_vals, out_dims = unzip2(map(trace.to_batch_info, outs))
|
||||
segment_lens, out_dims = indirectify_ragged_axes(out_dims)
|
||||
yield (*segment_lens, *out_vals), out_dims
|
||||
store.store(out_dims)
|
||||
return (*segment_lens, *out_vals)
|
||||
|
||||
def indirectify_ragged_axes(dims):
|
||||
if not any(type(d) is RaggedAxis for d in dims):
|
||||
@ -789,8 +792,8 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
|
||||
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
||||
return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals):
|
||||
@lu.transformation_with_aux2
|
||||
def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals):
|
||||
with core.take_current_trace() as parent_trace:
|
||||
trace = BatchTrace(parent_trace, tag, axis_data)
|
||||
_, in_axes = resolve_ragged_axes(in_vals, in_axes)
|
||||
@ -799,16 +802,17 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals):
|
||||
with (core.set_current_trace(trace),
|
||||
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
|
||||
core.add_spmd_axis_names(axis_data.spmd_name)):
|
||||
outs = yield in_tracers, {}
|
||||
outs = f(*in_tracers)
|
||||
out_vals, out_axes = unzip2(map(trace.to_batch_info, outs))
|
||||
new_out_axes = indirectify_ragged_axes_against_inputs_outputs(
|
||||
out_axes, in_vals, out_vals)
|
||||
yield out_vals, new_out_axes
|
||||
store.store(new_out_axes)
|
||||
return out_vals
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes,
|
||||
@lu.transformation_with_aux2
|
||||
def _match_axes_jaxpr(f, store, axis_data, out_axes_dest, out_axes, trace, in_axes,
|
||||
*in_vals):
|
||||
out_vals = yield (trace, in_axes, *in_vals), {}
|
||||
out_vals = f(trace, in_axes, *in_vals)
|
||||
out_axes = out_axes()
|
||||
out_axes_dest = [(None if src is not_mapped else 0)
|
||||
if dst is zero_if_mapped else dst
|
||||
@ -819,16 +823,16 @@ def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes,
|
||||
out_vals = map(partial(matchaxis, axis_data.name, axis_data.size),
|
||||
out_axes, out_axes_dest, out_vals)
|
||||
out_batched = [dst is not None for dst in out_axes_dest]
|
||||
yield out_vals, out_batched
|
||||
store.store(out_batched)
|
||||
return out_vals
|
||||
|
||||
@lu.transformation
|
||||
def _batch_jaxpr_outer(axis_data, in_dims, *in_vals):
|
||||
@lu.transformation2
|
||||
def _batch_jaxpr_outer(f, axis_data, in_dims, *in_vals):
|
||||
in_dims = in_dims() if callable(in_dims) else in_dims
|
||||
in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
|
||||
else ax for x, ax in unsafe_zip(in_vals, in_dims)]
|
||||
tag = TraceTag()
|
||||
out_vals = yield (tag, in_dims, *in_vals), {}
|
||||
yield out_vals
|
||||
return f(tag, in_dims, *in_vals)
|
||||
|
||||
def _merge_bdims(x, y):
|
||||
if x == y:
|
||||
@ -845,8 +849,8 @@ zero_if_mapped = ZeroIfMapped()
|
||||
|
||||
### functions for handling custom_vjp
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals):
|
||||
@lu.transformation_with_aux2
|
||||
def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
|
||||
size = axis_data.size
|
||||
with core.take_current_trace() as parent_trace:
|
||||
trace = BatchTrace(parent_trace, tag, axis_data)
|
||||
@ -855,7 +859,7 @@ def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals):
|
||||
if type(val) is SymbolicZero else BatchTracer(trace, val, dim)
|
||||
for val, dim in zip(in_vals, in_dims * 2)]
|
||||
with core.set_current_trace(trace):
|
||||
outs = yield in_tracers, {}
|
||||
outs = f(*in_tracers)
|
||||
# TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can
|
||||
# be wasteful in the rare case it actually triggers; handle symbolically!
|
||||
outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs]
|
||||
@ -868,7 +872,8 @@ def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals):
|
||||
out_primal_bds, out_dims, out_primals)
|
||||
out_tangents = map(partial(matchaxis, trace.axis_data.name, size),
|
||||
out_tangent_bds, out_dims, out_tangents)
|
||||
yield out_primals + out_tangents, out_dims * 2
|
||||
store.store(out_dims * 2)
|
||||
return out_primals + out_tangents
|
||||
|
||||
def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
|
||||
axis_size = axis_data.size
|
||||
@ -886,11 +891,11 @@ def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
|
||||
return bwd_.call_wrapped(*args)
|
||||
return new_bwd
|
||||
|
||||
@lu.transformation
|
||||
def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):
|
||||
@lu.transformation2
|
||||
def _match_axes_and_sum(f, axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):
|
||||
# this is like _match_axes, but we do reduce-sums as needed
|
||||
out_vals = yield in_vals, {}
|
||||
yield map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name,
|
||||
out_vals = f(*in_vals)
|
||||
return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name,
|
||||
sum_match=True), out_dims_thunk(), out_dim_dests, out_vals)
|
||||
|
||||
def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
|
||||
|
@ -475,18 +475,19 @@ def partition_pvals(
|
||||
consts = [pval.get_known() for pval in pvals if pval.is_known()]
|
||||
return knowns, avals, consts
|
||||
|
||||
@lu.transformation_with_aux
|
||||
@lu.transformation_with_aux2
|
||||
def partial_eval_wrapper_nounits(
|
||||
in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue],
|
||||
f, store, in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue],
|
||||
*in_consts: Any):
|
||||
in_avals_, in_consts_ = iter(in_avals), iter(in_consts)
|
||||
in_pvals = [PartialVal.known(next(in_consts_)) if known else
|
||||
PartialVal.unknown(next(in_avals_)) for known in in_knowns]
|
||||
sentinel = object()
|
||||
assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel
|
||||
jaxpr, (*maybe_fwds, out_pvals, res, env) = yield (in_pvals,), {}
|
||||
jaxpr, (*maybe_fwds, out_pvals, res, env) = f(in_pvals)
|
||||
out_knowns, out_avals, out_consts = partition_pvals(out_pvals)
|
||||
yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env)
|
||||
store.store((*maybe_fwds, out_knowns, out_avals, jaxpr, env))
|
||||
return (*out_consts, *res)
|
||||
|
||||
custom_partial_eval_rules: dict[Primitive, Callable] = {}
|
||||
call_partial_eval_rules: dict[Primitive, Callable] = {}
|
||||
@ -574,20 +575,22 @@ def trace_to_jaxpr_nounits(
|
||||
return jaxpr, out_pvals, consts
|
||||
|
||||
# TODO(mattjj): superfluous wrapper...?
|
||||
@lu.transformation
|
||||
@lu.transformation2
|
||||
def trace_to_subjaxpr_nounits(
|
||||
f,
|
||||
trace: JaxprTrace,
|
||||
instantiate: bool | Sequence[bool],
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
|
||||
out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits(
|
||||
trace, instantiate, in_pvals)
|
||||
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
|
||||
f, trace, instantiate, in_pvals)
|
||||
out_pvals = [t.pval for t in out_tracers]
|
||||
del out_tracers
|
||||
yield jaxpr, (out_pvals, out_consts, env)
|
||||
return jaxpr, (out_pvals, out_consts, env)
|
||||
|
||||
@lu.transformation
|
||||
@lu.transformation2
|
||||
def trace_to_subjaxpr_nounits2(
|
||||
f,
|
||||
tag: TraceTag,
|
||||
instantiate: bool | Sequence[bool],
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
@ -596,19 +599,19 @@ def trace_to_subjaxpr_nounits2(
|
||||
current_name_stack = source_info_util.current_name_stack()
|
||||
with core.take_current_trace() as parent_trace:
|
||||
trace = JaxprTrace(parent_trace, current_name_stack, tag)
|
||||
out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits(
|
||||
trace, instantiate, in_pvals)
|
||||
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
|
||||
f, trace, instantiate, in_pvals)
|
||||
out_pvals = [t.pval for t in out_tracers]
|
||||
del out_tracers
|
||||
yield jaxpr, (out_pvals, out_consts, env)
|
||||
return jaxpr, (out_pvals, out_consts, env)
|
||||
|
||||
def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals):
|
||||
def _trace_to_subjaxpr_nounits(f, trace:JaxprTrace, instantiate, in_pvals):
|
||||
in_knowns = [pval.is_known() for pval in in_pvals]
|
||||
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
|
||||
in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
|
||||
in_args = merge_lists(in_knowns, in_tracers, in_consts)
|
||||
with core.set_current_trace(trace):
|
||||
ans = yield in_args, {}
|
||||
ans = f(*in_args)
|
||||
assert isinstance(ans, (list, tuple)), (
|
||||
f"Got unexpected return type when tracing function to jaxpr: {ans}")
|
||||
assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), (
|
||||
@ -625,8 +628,9 @@ def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals):
|
||||
# The below variant implements an optimization where residuals which are also
|
||||
# inputs are indicated in auxiliary data rather than passed as outputs.
|
||||
# TODO(mattjj): update all callers to use this version, delete other version.
|
||||
@lu.transformation
|
||||
@lu.transformation2
|
||||
def trace_to_subjaxpr_nounits_fwd(
|
||||
f,
|
||||
tag: TraceTag,
|
||||
instantiate: bool | Sequence[bool],
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
@ -635,8 +639,8 @@ def trace_to_subjaxpr_nounits_fwd(
|
||||
with core.take_current_trace() as parent_trace:
|
||||
trace = JaxprTrace(parent_trace, current_name_stack, tag)
|
||||
with core.set_current_trace(trace):
|
||||
out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits(
|
||||
trace, instantiate, in_pvals)
|
||||
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
|
||||
f, trace, instantiate, in_pvals)
|
||||
out_pvals = [t.pval for t in out_tracers]
|
||||
|
||||
# Which out_consts (aka residuals) are just forwarded inputs? Check obj id.
|
||||
@ -646,15 +650,16 @@ def trace_to_subjaxpr_nounits_fwd(
|
||||
pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None]
|
||||
|
||||
del out_tracers
|
||||
yield jaxpr, (fwds, out_pvals, pruned_consts, env)
|
||||
return jaxpr, (fwds, out_pvals, pruned_consts, env)
|
||||
|
||||
# The below variant implements two optimizations:
|
||||
# 1. residuals that are also primal inputs are indicated in aux data rather
|
||||
# than passed as outputs;
|
||||
# 2. residuals that are also primal outputs are indicated in aux data rather
|
||||
# than passed as redundant outputs.
|
||||
@lu.transformation
|
||||
@lu.transformation2
|
||||
def trace_to_subjaxpr_nounits_fwd2(
|
||||
f,
|
||||
tag: TraceTag,
|
||||
instantiate: bool | Sequence[bool],
|
||||
in_pvals: Sequence[PartialVal]):
|
||||
@ -662,8 +667,8 @@ def trace_to_subjaxpr_nounits_fwd2(
|
||||
current_name_stack = source_info_util.current_name_stack()
|
||||
with core.take_current_trace() as parent_trace:
|
||||
trace = JaxprTrace(parent_trace, current_name_stack, tag)
|
||||
out_tracers, jaxpr, consts, env = yield from _trace_to_subjaxpr_nounits(
|
||||
trace, instantiate, in_pvals)
|
||||
out_tracers, jaxpr, consts, env = _trace_to_subjaxpr_nounits(
|
||||
f, trace, instantiate, in_pvals)
|
||||
out_pvals = [t.pval for t in out_tracers]
|
||||
|
||||
# Which consts (aka residuals) are just forwarded inputs? Check obj id.
|
||||
@ -680,7 +685,7 @@ def trace_to_subjaxpr_nounits_fwd2(
|
||||
if f1 is None and f2 is None]
|
||||
|
||||
del out_tracers
|
||||
yield jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env)
|
||||
return jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env)
|
||||
|
||||
|
||||
FreeVar = namedtuple('FreeVar', ['val'])
|
||||
@ -2066,10 +2071,10 @@ class DynamicJaxprTrace(core.Trace):
|
||||
|
||||
custom_staging_rules: dict[Primitive, Callable] = {}
|
||||
|
||||
@lu.transformation
|
||||
def _interleave_fun(every_others, *args, **kwargs):
|
||||
@lu.transformation2
|
||||
def _interleave_fun(f, every_others, *args, **kwargs):
|
||||
args_ = [x for pair in zip(args, every_others) for x in pair]
|
||||
yield (yield (args_, kwargs))
|
||||
return f(*args_, **kwargs)
|
||||
|
||||
# TODO: consider renaming to "lazy_thunk"
|
||||
def _memoize(fn):
|
||||
@ -2083,18 +2088,19 @@ def _memoize(fn):
|
||||
return out
|
||||
return memoized
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _jvp_jaxpr_zeros(in_zeros, zero_avals, *primal_tangent_avals):
|
||||
@lu.transformation_with_aux2
|
||||
def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
|
||||
in_primals, nz_in_tangents = split_list(primal_tangent_avals, [len(in_zeros)])
|
||||
symbolic_zeros = map(ad_util.SymbolicZero, zero_avals)
|
||||
tangents = merge_lists(in_zeros, nz_in_tangents, symbolic_zeros)
|
||||
out = yield (*in_primals, *tangents), {}
|
||||
out = f(*in_primals, *tangents)
|
||||
n, ragged = divmod(len(out), 2)
|
||||
assert not ragged
|
||||
out_primals, out_tangents = out[:n], out[n:]
|
||||
out_zeros = [type(t) is ad_util.SymbolicZero for t in out_tangents]
|
||||
out_nz_tangents, _ = partition_list(out_zeros, out_tangents)
|
||||
yield [*out_primals, *out_nz_tangents], out_zeros
|
||||
store.store(out_zeros)
|
||||
return [*out_primals, *out_nz_tangents]
|
||||
|
||||
# TODO(mattjj): remove this DebugInfo and helper functions, replace with
|
||||
# api_util.py versions
|
||||
|
@ -690,15 +690,15 @@ def find_replicas(
|
||||
num_global_replicas = global_axis_size * jaxpr_replicas
|
||||
return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas)
|
||||
|
||||
@lu.transformation
|
||||
def _change_argument_ranks(in_axes, out_axes_thunk, *args):
|
||||
@lu.transformation2
|
||||
def _change_argument_ranks(f, in_axes, out_axes_thunk, *args):
|
||||
args = tuple(
|
||||
arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,))
|
||||
for in_axis, arg in zip(in_axes, args)
|
||||
)
|
||||
results = yield (args, {})
|
||||
results = f(*args)
|
||||
out_axes = out_axes_thunk()
|
||||
yield tuple(
|
||||
return tuple(
|
||||
x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,))
|
||||
for x, axis in zip(results, out_axes)
|
||||
)
|
||||
|
@ -64,6 +64,7 @@ data must be immutable, because it will be stored in function memoization tables
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Any, NamedTuple
|
||||
import weakref
|
||||
|
||||
@ -149,10 +150,11 @@ class WrappedFun:
|
||||
params: extra parameters to pass as keyword arguments to `f`, along with the
|
||||
transformed keyword arguments.
|
||||
"""
|
||||
__slots__ = ("f", "transforms", "stores", "params", "in_type", "debug_info")
|
||||
__slots__ = ("f", "f_transformed", "transforms", "stores", "params", "in_type", "debug_info")
|
||||
|
||||
def __init__(self, f, transforms, stores, params, in_type, debug_info):
|
||||
def __init__(self, f, f_transformed, transforms, stores, params, in_type, debug_info):
|
||||
self.f = f
|
||||
self.f_transformed = f_transformed
|
||||
self.transforms = transforms
|
||||
self.stores = stores
|
||||
self.params = params
|
||||
@ -165,8 +167,14 @@ class WrappedFun:
|
||||
|
||||
def wrap(self, gen, gen_static_args, out_store) -> WrappedFun:
|
||||
"""Add another transform and its store."""
|
||||
return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
|
||||
(out_store,) + self.stores, self.params, None, None)
|
||||
if out_store is None:
|
||||
return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args),
|
||||
((gen, gen_static_args),) + self.transforms,
|
||||
(out_store,) + self.stores, self.params, None, None)
|
||||
else:
|
||||
return WrappedFun(self.f, partial(gen, self.f_transformed, out_store, *gen_static_args),
|
||||
((gen, gen_static_args),) + self.transforms,
|
||||
(out_store,) + self.stores, self.params, None, None)
|
||||
|
||||
def populate_stores(self, stores):
|
||||
"""Copy the values from the `stores` into `self.stores`."""
|
||||
@ -175,47 +183,8 @@ class WrappedFun:
|
||||
self_store.store(other_store.val)
|
||||
|
||||
def call_wrapped(self, *args, **kwargs):
|
||||
"""Calls the underlying function, applying the transforms.
|
||||
|
||||
The positional `args` and keyword `kwargs` are passed to the first
|
||||
transformation generator.
|
||||
"""
|
||||
stack = []
|
||||
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
|
||||
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
|
||||
args, kwargs = next(gen)
|
||||
stack.append((gen, out_store))
|
||||
gen = gen_static_args = out_store = None
|
||||
|
||||
try:
|
||||
ans = self.f(*args, **dict(self.params, **kwargs))
|
||||
except:
|
||||
# Some transformations yield from inside context managers, so we have to
|
||||
# interrupt them before reraising the exception. Otherwise they will only
|
||||
# get garbage-collected at some later time, running their cleanup tasks
|
||||
# only after this exception is handled, which can corrupt the global
|
||||
# state.
|
||||
while stack:
|
||||
stack.pop()[0].close()
|
||||
raise
|
||||
|
||||
args = kwargs = None
|
||||
while stack:
|
||||
gen, out_store = stack.pop()
|
||||
try:
|
||||
ans = gen.send(ans)
|
||||
except:
|
||||
# As above does for the first half of the transformation, exceptions
|
||||
# raised in the second half of the transformation also require us to
|
||||
# clean up references here.
|
||||
while stack:
|
||||
stack.pop()[0].close()
|
||||
raise
|
||||
if out_store is not None:
|
||||
ans, side = ans
|
||||
out_store.store(side)
|
||||
|
||||
return ans
|
||||
"""Calls the transformed function"""
|
||||
return self.f_transformed(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
def transform_to_str(x):
|
||||
@ -234,7 +203,7 @@ class WrappedFun:
|
||||
self.debug_info == other.debug_info)
|
||||
|
||||
@curry
|
||||
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
|
||||
def transformation2(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
|
||||
"""Adds one more transformation to a WrappedFun.
|
||||
|
||||
Args:
|
||||
@ -244,8 +213,28 @@ def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
|
||||
"""
|
||||
return fun.wrap(gen, gen_static_args, None)
|
||||
|
||||
# Backwards compat only. TODO: deprecate
|
||||
@curry
|
||||
def transformation_with_aux(
|
||||
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
|
||||
def gen2(f, *args, **kwargs):
|
||||
gen_inst = gen(*args, **kwargs)
|
||||
args_, kwargs_ = next(gen_inst)
|
||||
return gen_inst.send(f(*args_, **kwargs_))
|
||||
return transformation2(gen2, fun, *gen_static_args)()
|
||||
|
||||
# Backwards compat only. TODO: deprecate
|
||||
@curry
|
||||
def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
|
||||
def gen2(f, store, *args, **kwargs):
|
||||
gen_inst = gen(*args, **kwargs)
|
||||
args_, kwargs_ = next(gen_inst)
|
||||
ans, aux = gen_inst.send(f(*args_, **kwargs_))
|
||||
store.store(aux)
|
||||
return ans
|
||||
return transformation_with_aux2(gen2, fun, *gen_static_args)()
|
||||
|
||||
@curry
|
||||
def transformation_with_aux2(
|
||||
gen, fun: WrappedFun, *gen_static_args, use_eq_store: bool = False
|
||||
) -> tuple[WrappedFun, Callable[[], Any]]:
|
||||
"""Adds one more transformation with auxiliary output to a WrappedFun."""
|
||||
@ -261,8 +250,9 @@ def fun_name(f):
|
||||
|
||||
def wrap_init(f, params=None) -> WrappedFun:
|
||||
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
|
||||
params_dict = {} if params is None else params
|
||||
params = () if params is None else tuple(sorted(params.items()))
|
||||
return WrappedFun(f, (), (), params, None, None)
|
||||
return WrappedFun(f, partial(f, **params_dict), (), (), params, None, None)
|
||||
|
||||
|
||||
def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun:
|
||||
@ -270,7 +260,7 @@ def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun:
|
||||
if in_type is None:
|
||||
return f
|
||||
_check_input_type(in_type)
|
||||
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type, f.debug_info)
|
||||
return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, in_type, f.debug_info)
|
||||
|
||||
def _check_input_type(in_type: core.InputType) -> None:
|
||||
# Check that in_type is syntactically well-formed
|
||||
@ -317,7 +307,7 @@ def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None
|
||||
assert f.debug_info is None
|
||||
if debug_info is None:
|
||||
return f
|
||||
return WrappedFun(f.f, f.transforms, f.stores, f.params, f.in_type, debug_info)
|
||||
return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, f.in_type, debug_info)
|
||||
|
||||
|
||||
def cache(call: Callable, *, explain: Callable | None = None):
|
||||
@ -357,9 +347,9 @@ def cache(call: Callable, *, explain: Callable | None = None):
|
||||
cache_clearing_funs.add(memoized_fun.cache_clear)
|
||||
return memoized_fun
|
||||
|
||||
@transformation
|
||||
def hashable_partial(*args):
|
||||
yield (yield args, {})
|
||||
@transformation2
|
||||
def hashable_partial(f, *args):
|
||||
return f(*args)
|
||||
|
||||
|
||||
def merge_linear_aux(aux1, aux2):
|
||||
|
@ -824,14 +824,13 @@ def debug_print_lowering_rule(ctx, *args, **params):
|
||||
# because they should appear as atomic JAX values to the users.
|
||||
# TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU
|
||||
# inferred by the compiler.
|
||||
@lu.transformation
|
||||
def wrap_with_transforms(transforms, *args):
|
||||
@lu.transformation2
|
||||
def wrap_with_transforms(f, transforms, *args):
|
||||
new_args = tuple(
|
||||
state_types.TransformedRef(a, t) if t else a
|
||||
for a, t in zip(args, transforms)
|
||||
)
|
||||
res = yield new_args, {}
|
||||
yield res
|
||||
return f(*new_args)
|
||||
|
||||
|
||||
run_scoped_p = jax_core.Primitive("run_scoped")
|
||||
|
@ -97,34 +97,34 @@ def jvp(f, primals, tangents, attr_tangents):
|
||||
out_tangents = tree_unflatten(out_tree(), out_tangents_flat)
|
||||
return out_primals, out_tangents, tangent_attrs_out
|
||||
|
||||
@lu.transformation
|
||||
def _set_attrs(attrs, attr_vals, *args):
|
||||
@lu.transformation2
|
||||
def _set_attrs(f, attrs, attr_vals, *args):
|
||||
for (o, a), x in zip(attrs, attr_vals):
|
||||
jax_setattr(o, a, x)
|
||||
yield (yield args, {})
|
||||
return f(*args)
|
||||
|
||||
def _jvp(fun: lu.WrappedFun):
|
||||
return jvpfun2(jvp_subtrace2(fun))
|
||||
|
||||
@lu.transformation
|
||||
def jvpfun2(primals, tangents):
|
||||
@lu.transformation2
|
||||
def jvpfun2(f, primals, tangents):
|
||||
tag = core.TraceTag()
|
||||
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
|
||||
and dtype(t) == float0 else t for t in tangents]
|
||||
ctx = source_info_util.transform_name_stack('jvp')
|
||||
with ctx:
|
||||
out_primals, out_tangents, tangent_attrs_out = yield (tag, primals, tangents), {}
|
||||
yield out_primals, out_tangents, tangent_attrs_out
|
||||
out_primals, out_tangents, tangent_attrs_out = f(tag, primals, tangents)
|
||||
return out_primals, out_tangents, tangent_attrs_out
|
||||
|
||||
@lu.transformation
|
||||
def jvp_subtrace2(tag, primals, tangents):
|
||||
@lu.transformation2
|
||||
def jvp_subtrace2(f, tag, primals, tangents):
|
||||
with core.take_current_trace() as parent_trace:
|
||||
trace = ad.JVPTrace(parent_trace, tag)
|
||||
tag.attrs_tracked = [] # attrs written to
|
||||
in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x
|
||||
for x, t in zip(primals, tangents)]
|
||||
with core.set_current_trace(trace):
|
||||
ans = yield in_tracers, {}
|
||||
ans = f(*in_tracers)
|
||||
out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans))
|
||||
tangent_attrs_out = []
|
||||
for (obj, name) in tag.attrs_tracked:
|
||||
@ -133,7 +133,7 @@ def jvp_subtrace2(tag, primals, tangents):
|
||||
if type(tangent) is not ad.Zero:
|
||||
tangent_attrs_out.append((obj, name, tangent))
|
||||
del tag.attrs_tracked
|
||||
yield out_primals, out_tangents, tangent_attrs_out
|
||||
return out_primals, out_tangents, tangent_attrs_out
|
||||
|
||||
def _setattr_jvp(trace, obj, attr, maybe_tracer):
|
||||
primal, tangent = trace.to_primal_tangent_pair(maybe_tracer)
|
||||
@ -175,11 +175,12 @@ def _linearize(traceable: lu.WrappedFun, *primals):
|
||||
return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals],
|
||||
jaxpr, consts, attrs())
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _split_attrs(*args, **kwargs):
|
||||
primals, tangents, tangent_attrs = yield args, kwargs
|
||||
@lu.transformation_with_aux2
|
||||
def _split_attrs(f, store, *args, **kwargs):
|
||||
primals, tangents, tangent_attrs = f(*args, **kwargs)
|
||||
attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs)
|
||||
yield (primals, tangents, tangent_attr_vals), attrs
|
||||
store.store(attrs)
|
||||
return primals, tangents, tangent_attr_vals
|
||||
|
||||
def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs):
|
||||
in_tree, out_tree = io_tree
|
||||
|
@ -1040,20 +1040,20 @@ def _convert_jax_impl(impl_jax: Callable, *,
|
||||
return wrapped_tf
|
||||
|
||||
|
||||
@lu.transformation
|
||||
def _interpret_subtrace(in_avals: Sequence[core.ShapedArray],
|
||||
@lu.transformation2
|
||||
def _interpret_subtrace(f, in_avals: Sequence[core.ShapedArray],
|
||||
*in_vals: TfVal):
|
||||
trace = TensorFlowTrace()
|
||||
in_tracers = tuple(
|
||||
TensorFlowTracer(trace, val, aval)
|
||||
for val, aval in zip(in_vals, in_avals))
|
||||
with core.set_current_trace(trace):
|
||||
outs = yield in_tracers, {} # type: Sequence[TfVal]
|
||||
outs = f(*in_tracers)
|
||||
out_tracers: Iterable[TensorFlowTracer] = (
|
||||
map(trace.to_tf_tracer, outs))
|
||||
out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = (
|
||||
tuple((t.val, t.aval) for t in out_tracers))
|
||||
yield out_vals_with_avals
|
||||
return out_vals_with_avals
|
||||
|
||||
|
||||
def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal,
|
||||
|
@ -141,40 +141,43 @@ def jet(fun, primals, series):
|
||||
if not treedef_is_leaf(treedef):
|
||||
raise ValueError(f"term {j} for argument {i} is not an array")
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def flatten_fun_output(*args):
|
||||
ans = yield args, {}
|
||||
yield tree_flatten(ans)
|
||||
@lu.transformation_with_aux2
|
||||
def flatten_fun_output(f, store, *args):
|
||||
ans = f(*args)
|
||||
ans, tree = tree_flatten(ans)
|
||||
store.store(tree)
|
||||
return ans
|
||||
|
||||
f, out_tree = flatten_fun_output(lu.wrap_init(fun))
|
||||
out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
|
||||
return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)
|
||||
|
||||
@lu.transformation
|
||||
def jet_fun(order, primals, series):
|
||||
@lu.transformation2
|
||||
def jet_fun(f, order, primals, series):
|
||||
tag = core.TraceTag()
|
||||
out_primals, out_terms = yield (tag, order, primals, series), {}
|
||||
out_primals, out_terms = f(tag, order, primals, series)
|
||||
out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s
|
||||
for p, s in zip(out_primals, out_terms)]
|
||||
yield out_primals, out_terms
|
||||
return out_primals, out_terms
|
||||
|
||||
@lu.transformation
|
||||
def jet_subtrace(tag, order, primals, series):
|
||||
@lu.transformation2
|
||||
def jet_subtrace(f, tag, order, primals, series):
|
||||
with core.take_current_trace() as parent_trace:
|
||||
trace = JetTrace(tag, parent_trace, order)
|
||||
in_tracers = map(partial(JetTracer, trace), primals, series)
|
||||
with core.set_current_trace(trace):
|
||||
ans = yield in_tracers, {}
|
||||
ans = f(*in_tracers)
|
||||
|
||||
out_primals, out_terms = unzip2(map(trace.to_primal_terms_pair, ans))
|
||||
yield out_primals, out_terms
|
||||
return out_primals, out_terms
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def traceable(in_tree_def, *primals_and_series):
|
||||
@lu.transformation_with_aux2
|
||||
def traceable(f, store, in_tree_def, *primals_and_series):
|
||||
primals_in, series_in = tree_unflatten(in_tree_def, primals_and_series)
|
||||
primals_out, series_out = yield (primals_in, series_in), {}
|
||||
primals_out, series_out = f(primals_in, series_in)
|
||||
out_flat, out_tree_def = tree_flatten((primals_out, series_out))
|
||||
yield out_flat, out_tree_def
|
||||
store.store(out_tree_def)
|
||||
return out_flat
|
||||
|
||||
|
||||
class JetTracer(core.Tracer):
|
||||
|
@ -47,12 +47,12 @@ zip = safe_zip
|
||||
def ravel_first_arg(f, unravel):
|
||||
return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
|
||||
|
||||
@lu.transformation
|
||||
def ravel_first_arg_(unravel, y_flat, *args):
|
||||
@lu.transformation2
|
||||
def ravel_first_arg_(f, unravel, y_flat, *args):
|
||||
y = unravel(y_flat)
|
||||
ans = yield (y,) + args, {}
|
||||
ans = f(y, *args)
|
||||
ans_flat, _ = ravel_pytree(ans)
|
||||
yield ans_flat
|
||||
return ans_flat
|
||||
|
||||
def interp_fit_dopri(y0, y1, k, dt):
|
||||
# Fit a polynomial to the results of a Runge-Kutta step.
|
||||
|
@ -1479,15 +1479,15 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
|
||||
return pe.merge_lists(out_knowns, out_tracers, out_consts)
|
||||
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
|
||||
|
||||
@lu.transformation
|
||||
def _promote_scalar_residuals(*args, **kwargs):
|
||||
jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs
|
||||
@lu.transformation2
|
||||
def _promote_scalar_residuals(f, *args, **kwargs):
|
||||
jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs)
|
||||
which = [f1 is None and f2 is None and not v.aval.shape
|
||||
for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]
|
||||
jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
|
||||
out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
|
||||
for x in out_consts]
|
||||
yield jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env)
|
||||
return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env)
|
||||
|
||||
def _promote_scalar_residuals_jaxpr(jaxpr, which):
|
||||
@lu.wrap_init
|
||||
@ -1728,13 +1728,13 @@ def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name):
|
||||
check_rep=False, auto=frozenset()),
|
||||
in_specs, out_specs)
|
||||
|
||||
@lu.transformation
|
||||
def _handle_reshapes(in_axes, out_axes_thunk, *args, **kwargs):
|
||||
@lu.transformation2
|
||||
def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs):
|
||||
args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax),
|
||||
list(args), list(in_axes))
|
||||
out = yield args, {}
|
||||
yield tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax),
|
||||
list(out), list(out_axes_thunk()))
|
||||
out = f(*args)
|
||||
return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax),
|
||||
list(out), list(out_axes_thunk()))
|
||||
|
||||
def _axis_to_spec(axis_name, ax):
|
||||
if isinstance(ax, int):
|
||||
@ -1855,27 +1855,28 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk):
|
||||
fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps)
|
||||
return _match_rep(fun, mesh, out_reps_src, out_reps_dst)
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args):
|
||||
@lu.transformation_with_aux2
|
||||
def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args):
|
||||
with core.take_current_trace() as parent:
|
||||
tag = core.TraceTag()
|
||||
t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh)
|
||||
in_tracers = map(partial(RewriteTracer, t), in_reps, args)
|
||||
with core.set_current_trace(t):
|
||||
ans = yield in_tracers, {}
|
||||
ans = f(*in_tracers)
|
||||
out_vals, out_reps = unzip2(map(t.to_val_rep_pair, ans))
|
||||
del t, in_tracers, ans
|
||||
yield out_vals, out_reps
|
||||
store.store(out_reps)
|
||||
return out_vals
|
||||
|
||||
@lu.transformation
|
||||
def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args):
|
||||
outs = yield args, {}
|
||||
@lu.transformation2
|
||||
def _match_rep(f, mesh, out_reps_src_, out_reps_dst_, *args):
|
||||
outs = f(*args)
|
||||
out_reps_src = out_reps_src_() if callable(out_reps_src_) else out_reps_src_
|
||||
out_reps_dst = out_reps_dst_() if callable(out_reps_dst_) else out_reps_dst_
|
||||
_check_reps2(mesh, out_reps_dst, out_reps_src)
|
||||
outs = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst
|
||||
else x for x, src, dst in zip(outs, out_reps_src, out_reps_dst)]
|
||||
yield outs
|
||||
return outs
|
||||
|
||||
# TODO(mattjj): caching
|
||||
def _replication_rewrite_match(
|
||||
@ -1901,16 +1902,17 @@ def _replication_rewrite_nomatch(
|
||||
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
|
||||
return core.ClosedJaxpr(jaxpr_, consts), out_rep()
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _rewrite_subtrace(tag, mesh, in_reps, *in_vals):
|
||||
@lu.transformation_with_aux2
|
||||
def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals):
|
||||
with core.take_current_trace() as parent_trace:
|
||||
assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals))
|
||||
t = RewriteTrace(parent_trace, tag, mesh)
|
||||
in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals)
|
||||
with core.set_current_trace(t):
|
||||
outs = yield in_tracers, {}
|
||||
ans = unzip2(map(t.to_val_rep_pair, outs))
|
||||
yield ans
|
||||
outs = f(*in_tracers)
|
||||
out_vals, out_reps = unzip2(map(t.to_val_rep_pair, outs))
|
||||
store.store(out_reps)
|
||||
return out_vals
|
||||
|
||||
def _rewrite_bwd(bwd, mesh, in_reps, reps_dst):
|
||||
def new_bwd(*args):
|
||||
|
@ -340,16 +340,17 @@ class SparseTrace(core.Trace):
|
||||
with core.set_current_trace(self):
|
||||
return fun.call_wrapped(*tracers)
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def sparsify_subtrace(tag, spenv, spvalues, *bufs):
|
||||
@lu.transformation_with_aux2
|
||||
def sparsify_subtrace(f, store, tag, spenv, spvalues, *bufs):
|
||||
with core.take_current_trace() as parent:
|
||||
trace = SparseTrace(parent, tag, spenv)
|
||||
with core.set_current_trace(trace):
|
||||
in_tracers = [SparseTracer(trace, spvalue=spvalue) for spvalue in spvalues]
|
||||
outs = yield in_tracers, {}
|
||||
outs = f(*in_tracers)
|
||||
out_traces = [trace.to_sparse_tracer(out) for out in outs]
|
||||
buffers = spenv._buffers
|
||||
yield buffers, [out._spvalue for out in out_traces]
|
||||
store.store([out._spvalue for out in out_traces])
|
||||
return buffers
|
||||
|
||||
def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]):
|
||||
tag = core.TraceTag()
|
||||
|
@ -22,5 +22,7 @@ from jax._src.linear_util import (
|
||||
merge_linear_aux as merge_linear_aux,
|
||||
transformation as transformation,
|
||||
transformation_with_aux as transformation_with_aux,
|
||||
transformation2 as transformation2,
|
||||
transformation_with_aux2 as transformation_with_aux2,
|
||||
wrap_init as wrap_init,
|
||||
)
|
||||
|
@ -42,8 +42,8 @@ class UtilTest(jtu.JaxTestCase):
|
||||
assert not kwargs
|
||||
return tuple(a * factor for a in args)
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def kw_to_positional(factor, *args, **kwargs):
|
||||
@lu.transformation_with_aux2
|
||||
def kw_to_positional(f, store, factor, *args, **kwargs):
|
||||
"""A transformation with auxiliary output.
|
||||
Turns all keyword parameters into positional ones.
|
||||
|
||||
@ -55,12 +55,12 @@ class UtilTest(jtu.JaxTestCase):
|
||||
kwargs_keys = kwargs.keys()
|
||||
new_args = tuple(kwargs[k] for k in kwargs_keys)
|
||||
new_kwargs = dict(factor=factor)
|
||||
results = yield args + new_args, new_kwargs # Yield transformed (args, kwargs)
|
||||
results = f(*(args + new_args), **new_kwargs) # Yield transformed (args, kwargs)
|
||||
# Assume results correspond 1:1 to the args + new_args
|
||||
assert len(results) == len(args) + len(new_args)
|
||||
aux_output = len(new_args)
|
||||
yield (results[0:len(args)],
|
||||
dict(zip(kwargs_keys, results[len(args):]))), aux_output
|
||||
store.store(aux_output)
|
||||
return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):])))
|
||||
|
||||
wf = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`.
|
||||
wf, out_thunk = kw_to_positional(wf, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user