Stop using generators for linear_util transformations.

They lead to confusing code, nasty bugs, and unhelpful (but terse!) stack traces.
Dougal 2024-11-12 22:39:26 -08:00
parent ed9fdbbf0a
commit 1c9b23c566
17 changed files with 311 additions and 287 deletions
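The pattern repeated across every file below: under the old API, a transformation was a generator that first yielded the transformed (args, kwargs), received the wrapped function's result back at that same yield, and then yielded the transformed result (plus any auxiliary output). Under the new API, a transformation is an ordinary function that receives the wrapped callable f (and, for aux-producing transforms, a store) as explicit leading arguments and simply calls f. A minimal before/after sketch; the transform name and its behavior are hypothetical, for illustration only:

# Old style: generator protocol. An error raised inside the wrapped
# function resurfaces through an opaque generator frame, which is why
# the resulting stack traces were terse and unhelpful.
@lu.transformation_with_aux
def scale_outputs(factor, *args):
  outs = yield args, {}  # suspend; the machinery calls the wrapped fn here
  yield [factor * x for x in outs], len(outs)  # (result, aux)

# New style: plain function. f is called directly, so control flow and
# stack traces run straight through the transformation.
@lu.transformation_with_aux2
def scale_outputs(f, store, factor, *args):
  outs = f(*args)  # ordinary call, ordinary traceback
  store.store(len(outs))  # auxiliary output goes into the store
  return [factor * x for x in outs]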

View File

@@ -68,11 +68,13 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
else:
return tuple(map(_ensure_str, x))
@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
@lu.transformation_with_aux2
def flatten_fun(f, store, in_tree, *args_flat):
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
ans = yield py_args, py_kwargs
yield tree_flatten(ans)
ans = f(*py_args, **py_kwargs)
ans, out_tree = tree_flatten(ans)
store.store(out_tree)
return ans
def apply_flat_fun(fun, io_tree, *py_args):
in_tree_expected, out_tree = io_tree
@@ -82,11 +84,13 @@ def apply_flat_fun(fun, io_tree, *py_args):
ans = fun(*args)
return tree_unflatten(out_tree, ans)
@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
@lu.transformation_with_aux2
def flatten_fun_nokwargs(f, store, in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = yield py_args, {}
yield tree_flatten(ans)
ans = f(*py_args)
ans, out_tree = tree_flatten(ans)
store.store(out_tree)
return ans
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
in_tree_expected, out_tree = io_tree
@@ -118,17 +122,18 @@ def flattened_fun_in_tree(
else:
return in_tree, lambda: out_tree_store.val, has_kwargs
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
@lu.transformation_with_aux2
def flatten_fun_nokwargs2(f, store, in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
pair = yield py_args, {}
pair = f(*py_args)
if not isinstance(pair, (list, tuple)) or len(pair) != 2:
raise TypeError("expected function with aux output to return a two-element "
f"tuple, but got type {type(pair)} with value {pair!r}")
ans, aux = pair
ans_flat, ans_tree = tree_flatten(ans)
aux_flat, aux_tree = tree_flatten(aux)
yield (ans_flat, aux_flat), (ans_tree, aux_tree)
store.store((ans_tree, aux_tree))
return ans_flat, aux_flat
class _HashableWithStrictTypeEquality:
"""Box object used when comparing static arguments as a jit key.
@@ -277,8 +282,8 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
@lu.transformation2
def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs):
sentinel = object()
args = [sentinel] * (len(fixed_args) + len(dyn_args))
for i, arg in zip(dyn_argnums, dyn_args):
@@ -286,9 +291,7 @@ def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
fixed_args_ = iter(fixed_args)
args = [next(fixed_args_).val if x is sentinel else x for x in args]
assert next(fixed_args_, sentinel) is sentinel
ans = yield args, kwargs
yield ans
return f(*args, **kwargs)
def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
kwargs: dict[str, Any]):
@@ -311,11 +314,10 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
@lu.transformation
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
@lu.transformation2
def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs)
ans = yield args, kwargs
yield ans
return f(*args, **kwargs)
@lru_cache(maxsize=4096)
@@ -435,9 +437,9 @@ def flat_out_axes(
f, out_axes = _flat_out_axes(f, tuple(leaves), treedef)
return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef))
@lu.transformation_with_aux
def _flat_out_axes(leaves, treedef, *args, **kwargs):
ans = yield args, kwargs
@lu.transformation_with_aux2
def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
ans = f(*args, **kwargs)
spec = tree_unflatten(treedef, leaves)
try:
spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None))
@@ -449,7 +451,8 @@ def _flat_out_axes(leaves, treedef, *args, **kwargs):
"that the `out_axes` argument to `pmap` is a pytree prefix of the "
"pmapped function's output.")
raise ValueError(msg) from None
yield ans, spec_flat
store.store(spec_flat)
return ans
def check_callable(fun):
# In Python 3.10+, the only thing stopping us from supporting staticmethods
@@ -683,11 +686,12 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
@lu.transformation_with_aux
def result_paths(*args, **kwargs):
@lu.transformation_with_aux2
def result_paths(f, store, *args, **kwargs):
"linear_util transform to get output pytree paths of pre-flattened function."
ans = yield args, kwargs
yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]
ans = f(*args, **kwargs)
store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
result_paths: tuple[str, ...] | None = None,

View File

@@ -330,11 +330,12 @@ def update_error(error, pred, code, metadata, payload, effect_type):
## Checkify transformation for plumbing functional error values.
@lu.transformation_with_aux
def _flatten_and_get_error_metadata_thunk(*invals):
error, out = yield invals, {}
@lu.transformation_with_aux2
def _flatten_and_get_error_metadata_thunk(f, store, *invals):
error, out = f(*invals)
out_vals, out_tree = jtu.tree_flatten((error, out))
yield out_vals, (out_tree, set(error._pred.keys()))
store.store((out_tree, set(error._pred.keys())))
return out_vals
def default_checkify_rule(primitive: core.Primitive, error: Error,
enabled_errors, *invals: core.Value,
@@ -438,10 +439,12 @@ def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
consts = tuple(c.x for c in hashable_consts)
return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args)
@lu.transformation_with_aux
def flatten_fun_output(*args):
ans = yield args, {}
yield tree_flatten(ans)
@lu.transformation_with_aux2
def flatten_fun_output(f, store, *args):
ans = f(*args)
ans, out_tree = tree_flatten(ans)
store.store(out_tree)
return ans
def _reduce_any_error(error: Error):

View File

@@ -75,13 +75,14 @@ _stop_gradient = partial(
# like the api_util.py function, but also grabs output avals for error checking
@lu.transformation_with_aux
def _flatten_fun_nokwargs(in_tree, *args_flat):
@lu.transformation_with_aux2
def _flatten_fun_nokwargs(f, store, in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = yield py_args, {}
ans = f(*py_args)
ans_flat, ans_tree = tree_flatten(ans)
ans_avals = [core.get_aval(x) for x in ans_flat]
yield ans_flat, (ans_tree, ans_avals)
store.store((ans_tree, ans_avals))
return ans_flat
### JVPs
@@ -266,18 +267,18 @@ class custom_jvp(Generic[ReturnValue]):
def _add_args(f, extra_args):
return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args))
@lu.transformation
def _add_args_(extra_args, *args, **kwargs):
@lu.transformation2
def _add_args_(f, extra_args, *args, **kwargs):
extra_args = tuple(arg.val for arg in extra_args)
all_args = (extra_args + args)
yield (yield all_args, kwargs)
return f(*all_args, **kwargs)
@partial(lu.transformation_with_aux, use_eq_store=True)
def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
@partial(lu.transformation_with_aux2, use_eq_store=True)
def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args):
primals_in, tangents_in = split_list(args, [len(args) // 2])
py_primals = tree_unflatten(in_tree, primals_in)
py_tangents = tree_unflatten(in_tree, tangents_in)
pair_out = yield (py_primals, py_tangents), {}
pair_out = f(py_primals, py_tangents)
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = (f"Custom JVP rule {jvp_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) representing "
@@ -348,7 +349,8 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
if av_et != av_t)
raise TypeError(msg.format('\n'.join(disagreements)))
yield primals_out + tangents_out, (out_tree, primal_avals)
store.store((out_tree, primal_avals))
return primals_out + tangents_out
class CustomJVPCallPrimitive(core.Primitive):
multiple_results = True
@@ -652,15 +654,15 @@ def _check_for_tracers(x):
"arguments should typically not be indicated as nondiff_argnums.")
raise UnexpectedTracerError(msg)
@partial(lu.transformation_with_aux, use_eq_store=True)
def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
@partial(lu.transformation_with_aux2, use_eq_store=True)
def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
*args):
if symbolic_zeros:
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
else:
args = args[::2]
py_args = tree_unflatten(in_tree, args)
pair_out = yield py_args, {}
pair_out = f(*py_args)
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) where the first "
@@ -710,16 +712,17 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
"shapes/dtypes of:\n"
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
yield (*res, *primals_out), (out_tree, res_tree)
store.store((out_tree, res_tree))
return (*res, *primals_out)
@lu.transformation
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
@lu.transformation2
def _flatten_bwd(f, in_tree, in_avals, out_trees, *args):
out_tree, res_tree = out_trees()
assert len(args) == res_tree.num_leaves + out_tree.num_leaves
res, cts_out = split_list(args, [res_tree.num_leaves])
py_res = tree_unflatten(res_tree, res)
py_cts_out = tree_unflatten(out_tree, cts_out)
py_cts_in = yield (py_res, py_cts_out), {}
py_cts_in = f(py_res, py_cts_out)
if isinstance(py_cts_in, list) and len(py_cts_in) == len(treedef_children(in_tree)):
py_cts_in = tuple(py_cts_in)
# For each None in py_cts_in, indicating an argument for which the rule
@@ -775,7 +778,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
f"to an input of shape/dtype {a.str_short()}.")
raise ValueError(msg)
results.append(ct)
yield results
return results
# TODO(mattjj): remove both these exceptions to cotangent compatibility check
def _temporary_dtype_exception(a, a_) -> bool:
@@ -1425,11 +1428,11 @@ def optimize_remat_of_custom_vjp_fwd(
return wrapped_fwd
@lu.transformation
def _fix_fwd_args(*args):
@lu.transformation2
def _fix_fwd_args(f, *args):
args = [(x, True) for x in args]
args = [x for pair in args for x in pair]
yield (yield args, {})
return f(*args)
def _remat_opt_impl(
*args,

View File

@@ -68,42 +68,43 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
fun, aux = jvp_subtrace_aux(fun)
return jvpfun(fun, instantiate, transform_stack), aux
@lu.transformation
def jvpfun(instantiate, transform_stack, primals, tangents):
@lu.transformation2
def jvpfun(f, instantiate, transform_stack, primals, tangents):
tag = core.TraceTag()
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
and dtype(t) == float0 else t for t in tangents]
ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
else contextlib.nullcontext())
with ctx:
out_primals, out_tangents = yield (tag, primals, tangents), {}
out_primals, out_tangents = f(tag, primals, tangents)
if type(instantiate) is bool:
instantiate = [instantiate] * len(out_tangents)
out_tangents = [instantiate_zeros(t) if inst else t for t, inst
in zip(out_tangents, instantiate)]
yield out_primals, out_tangents
return out_primals, out_tangents
@lu.transformation
def jvp_subtrace(tag, primals, tangents):
@lu.transformation2
def jvp_subtrace(f, tag, primals, tangents):
with core.take_current_trace() as parent_trace:
trace = JVPTrace(parent_trace, tag)
in_tracers = [maybe_jvp_tracer(trace, x, t)
for x, t in zip(primals, tangents)]
with core.set_current_trace(trace):
ans = yield in_tracers, {}
ans = f(*in_tracers)
out = unzip2(map(trace.to_primal_tangent_pair, ans))
yield out
return out
@lu.transformation_with_aux
def jvp_subtrace_aux(tag, primals, tangents):
@lu.transformation_with_aux2
def jvp_subtrace_aux(f, store, tag, primals, tangents):
with core.take_current_trace() as parent_trace:
trace = JVPTrace(parent_trace, tag)
with core.set_current_trace(trace):
ans, aux = yield map(partial(maybe_jvp_tracer, trace), primals, tangents), {}
ans, aux = f(*(map(partial(maybe_jvp_tracer, trace), primals, tangents)))
out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans))
aux_primals = [x.primal if isinstance(x, JVPTracer) and x._trace.tag is tag
else x for x in aux]
yield (out_primals, out_tangents), aux_primals
store.store(aux_primals)
return out_primals, out_tangents
def linearize(traceable, *primals, **kwargs):
has_aux = kwargs.pop('has_aux', False)
@@ -262,10 +263,11 @@ def get_primitive_transpose(p):
"Transpose rule (for reverse-mode differentiation) for '{}' "
"not implemented".format(p)) from err
@lu.transformation_with_aux
def nonzero_tangent_outputs(*args, **kwargs):
results = (_, tangents_out) = yield args, kwargs
yield results, [type(r) is not Zero for r in tangents_out]
@lu.transformation_with_aux2
def nonzero_tangent_outputs(f, store, *args, **kwargs):
results = (_, tangents_out) = f(*args, **kwargs)
store.store([type(r) is not Zero for r in tangents_out])
return results
class JVPTrace(Trace):
@@ -543,15 +545,16 @@ deflinear2(add_jaxvals_p, lambda t, *args: (t, t))
def instantiate_zeros(tangent):
return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent
@lu.transformation_with_aux
def traceable(in_tree, *primals_and_tangents):
@lu.transformation_with_aux2
def traceable(f, store, in_tree, *primals_and_tangents):
primals, tangents = tree_unflatten(in_tree, primals_and_tangents)
tangents = [Zero.from_primal_value(p) if t is None else t
for p, t in zip(primals, tangents)]
primals_out, tangents_out = yield (primals, tangents), {}
primals_out, tangents_out = f(primals, tangents)
tangents_out = [None if type(t) is Zero else t for t in tangents_out]
out_flat, out_tree = tree_flatten((primals_out, tangents_out))
yield out_flat, out_tree
store.store(out_tree)
return out_flat
def call_transpose(primitive, params, call_jaxpr, args, ct, _):
@@ -588,10 +591,11 @@ def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals):
primitive_transposes[core.closed_call_p] = _closed_call_transpose
@lu.transformation_with_aux
def nonzero_outputs(*args, **kwargs):
results = yield args, kwargs
yield results, [type(r) is not Zero for r in results]
@lu.transformation_with_aux2
def nonzero_outputs(f, store, *args, **kwargs):
results = f(*args, **kwargs)
store.store([type(r) is not Zero for r in results])
return results
def map_transpose(primitive, params, call_jaxpr, args, ct, _):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
@@ -655,17 +659,18 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
@lu.transformation_with_aux
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
@lu.transformation_with_aux2
def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents):
num_primals = len(nonzeros)
primals = list(primals_and_nztangents[:num_primals])
nonzero_tangents = iter(primals_and_nztangents[num_primals:])
tangents = [next(nonzero_tangents) if nz else Zero.from_primal_value(p)
for p, nz in zip(primals, nonzeros)]
primals_out, tangents_out = yield (primals, tangents), {}
primals_out, tangents_out = f(primals, tangents)
out_nonzeros = [type(t) is not Zero for t in tangents_out]
nonzero_tangents_out = [t for t in tangents_out if type(t) is not Zero]
yield list(primals_out) + nonzero_tangents_out, out_nonzeros
store.store(out_nonzeros)
return list(primals_out) + nonzero_tangents_out
def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)

View File

@@ -327,11 +327,13 @@ def unregister_vmappable(data_type: type) -> None:
def is_vmappable(x: Any) -> bool:
return type(x) is Jumble or type(x) in vmappables
@lu.transformation_with_aux
def flatten_fun_for_vmap(in_tree, *args_flat):
@lu.transformation_with_aux2
def flatten_fun_for_vmap(f, store, in_tree, *args_flat):
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
ans = yield py_args, py_kwargs
yield tree_flatten(ans, is_leaf=is_vmappable)
ans = f(*py_args, **py_kwargs)
ans, out_tree = tree_flatten(ans, is_leaf=is_vmappable)
store.store(out_tree)
return ans
# Propagate ragged masking rules from invars to outvars
# rule([params], [raggedness_per_invar], outvars) ->
@@ -580,16 +582,16 @@ def batch(fun: lu.WrappedFun, axis_data,
f = _batch_inner(fun, axis_data, out_dim_dests)
return _batch_outer(f, axis_data, in_dims)
@lu.transformation
def _batch_outer(axis_data, in_dims, *in_vals):
@lu.transformation2
def _batch_outer(f, axis_data, in_dims, *in_vals):
tag = TraceTag()
with source_info_util.transform_name_stack('vmap'):
outs, trace = yield (tag, in_dims, *in_vals), {}
outs, trace = f(tag, in_dims, *in_vals)
with core.ensure_no_leaks(trace): del trace
yield outs
return outs
@lu.transformation
def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals):
@lu.transformation2
def _batch_inner(f, axis_data, out_dim_dests, tag, in_dims, *in_vals):
in_dims = in_dims() if callable(in_dims) else in_dims
with core.take_current_trace() as parent_trace:
trace = BatchTrace(parent_trace, tag, axis_data)
@@ -599,13 +601,13 @@ def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals):
with (core.set_current_trace(trace),
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
core.add_spmd_axis_names(axis_data.spmd_name)):
outs = yield in_tracers, {}
outs = f(*in_tracers)
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)),
outs, out_dim_dests)
yield out_vals, trace
return out_vals, trace
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat: lu.WrappedFun,
@@ -628,21 +630,21 @@ def vtile(f_flat: lu.WrappedFun,
shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
return out.reshape(shape)
@lu.transformation
def _map_to_tile(*args_flat):
@lu.transformation2
def _map_to_tile(f, *args_flat):
sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
tile_size_ = tile_size or next(sizes, None)
assert tile_size_ is not None, "No mapped arguments?"
outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
yield map(untile_axis, outputs_flat, out_axes_flat)
outputs_flat = f(*map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat))
return map(untile_axis, outputs_flat, out_axes_flat)
axis_data = AxisData(axis_name, tile_size, None)
return _map_to_tile(batch(f_flat, axis_data, in_axes_flat, out_axes_flat))
### API for batching functions with jaxpr type inputs and outputs
@lu.transformation_with_aux
def batch_subtrace(tag, axis_data, in_dims, *in_vals):
@lu.transformation_with_aux2
def batch_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
with core.take_current_trace() as parent_trace:
trace = BatchTrace(parent_trace, tag, axis_data)
with core.set_current_trace(trace):
@@ -650,10 +652,11 @@ def batch_subtrace(tag, axis_data, in_dims, *in_vals):
in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims)
in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
if dim is not None else x for x, dim in zip(in_vals, in_dims)]
outs = yield in_tracers, {}
outs = f(*in_tracers)
out_vals, out_dims = unzip2(map(trace.to_batch_info, outs))
segment_lens, out_dims = indirectify_ragged_axes(out_dims)
yield (*segment_lens, *out_vals), out_dims
store.store(out_dims)
return (*segment_lens, *out_vals)
def indirectify_ragged_axes(dims):
if not any(type(d) is RaggedAxis for d in dims):
@@ -789,8 +792,8 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
@lu.transformation_with_aux
def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals):
@lu.transformation_with_aux2
def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals):
with core.take_current_trace() as parent_trace:
trace = BatchTrace(parent_trace, tag, axis_data)
_, in_axes = resolve_ragged_axes(in_vals, in_axes)
@@ -799,16 +802,17 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals):
with (core.set_current_trace(trace),
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
core.add_spmd_axis_names(axis_data.spmd_name)):
outs = yield in_tracers, {}
outs = f(*in_tracers)
out_vals, out_axes = unzip2(map(trace.to_batch_info, outs))
new_out_axes = indirectify_ragged_axes_against_inputs_outputs(
out_axes, in_vals, out_vals)
yield out_vals, new_out_axes
store.store(new_out_axes)
return out_vals
@lu.transformation_with_aux
def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes,
@lu.transformation_with_aux2
def _match_axes_jaxpr(f, store, axis_data, out_axes_dest, out_axes, trace, in_axes,
*in_vals):
out_vals = yield (trace, in_axes, *in_vals), {}
out_vals = f(trace, in_axes, *in_vals)
out_axes = out_axes()
out_axes_dest = [(None if src is not_mapped else 0)
if dst is zero_if_mapped else dst
@@ -819,16 +823,16 @@ def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes,
out_vals = map(partial(matchaxis, axis_data.name, axis_data.size),
out_axes, out_axes_dest, out_vals)
out_batched = [dst is not None for dst in out_axes_dest]
yield out_vals, out_batched
store.store(out_batched)
return out_vals
@lu.transformation
def _batch_jaxpr_outer(axis_data, in_dims, *in_vals):
@lu.transformation2
def _batch_jaxpr_outer(f, axis_data, in_dims, *in_vals):
in_dims = in_dims() if callable(in_dims) else in_dims
in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
else ax for x, ax in unsafe_zip(in_vals, in_dims)]
tag = TraceTag()
out_vals = yield (tag, in_dims, *in_vals), {}
yield out_vals
return f(tag, in_dims, *in_vals)
def _merge_bdims(x, y):
if x == y:
@@ -845,8 +849,8 @@ zero_if_mapped = ZeroIfMapped()
### functions for handling custom_vjp
@lu.transformation_with_aux
def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals):
@lu.transformation_with_aux2
def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
size = axis_data.size
with core.take_current_trace() as parent_trace:
trace = BatchTrace(parent_trace, tag, axis_data)
@@ -855,7 +859,7 @@ def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals):
if type(val) is SymbolicZero else BatchTracer(trace, val, dim)
for val, dim in zip(in_vals, in_dims * 2)]
with core.set_current_trace(trace):
outs = yield in_tracers, {}
outs = f(*in_tracers)
# TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can
# be wasteful in the rare case it actually triggers; handle symbolically!
outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs]
@@ -868,7 +872,8 @@ def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals):
out_primal_bds, out_dims, out_primals)
out_tangents = map(partial(matchaxis, trace.axis_data.name, size),
out_tangent_bds, out_dims, out_tangents)
yield out_primals + out_tangents, out_dims * 2
store.store(out_dims * 2)
return out_primals + out_tangents
def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
axis_size = axis_data.size
@@ -886,11 +891,11 @@ def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
return bwd_.call_wrapped(*args)
return new_bwd
@lu.transformation
def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):
@lu.transformation2
def _match_axes_and_sum(f, axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):
# this is like _match_axes, but we do reduce-sums as needed
out_vals = yield in_vals, {}
yield map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name,
out_vals = f(*in_vals)
return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name,
sum_match=True), out_dims_thunk(), out_dim_dests, out_vals)
def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):

View File

@@ -475,18 +475,19 @@ def partition_pvals(
consts = [pval.get_known() for pval in pvals if pval.is_known()]
return knowns, avals, consts
@lu.transformation_with_aux
@lu.transformation_with_aux2
def partial_eval_wrapper_nounits(
in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue],
f, store, in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue],
*in_consts: Any):
in_avals_, in_consts_ = iter(in_avals), iter(in_consts)
in_pvals = [PartialVal.known(next(in_consts_)) if known else
PartialVal.unknown(next(in_avals_)) for known in in_knowns]
sentinel = object()
assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel
jaxpr, (*maybe_fwds, out_pvals, res, env) = yield (in_pvals,), {}
jaxpr, (*maybe_fwds, out_pvals, res, env) = f(in_pvals)
out_knowns, out_avals, out_consts = partition_pvals(out_pvals)
yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env)
store.store((*maybe_fwds, out_knowns, out_avals, jaxpr, env))
return (*out_consts, *res)
custom_partial_eval_rules: dict[Primitive, Callable] = {}
call_partial_eval_rules: dict[Primitive, Callable] = {}
@@ -574,20 +575,22 @@ def trace_to_jaxpr_nounits(
return jaxpr, out_pvals, consts
# TODO(mattjj): superfluous wrapper...?
@lu.transformation
@lu.transformation2
def trace_to_subjaxpr_nounits(
f,
trace: JaxprTrace,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits(
trace, instantiate, in_pvals)
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
out_pvals = [t.pval for t in out_tracers]
del out_tracers
yield jaxpr, (out_pvals, out_consts, env)
return jaxpr, (out_pvals, out_consts, env)
@lu.transformation
@lu.transformation2
def trace_to_subjaxpr_nounits2(
f,
tag: TraceTag,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
@@ -596,19 +599,19 @@ def trace_to_subjaxpr_nounits2(
current_name_stack = source_info_util.current_name_stack()
with core.take_current_trace() as parent_trace:
trace = JaxprTrace(parent_trace, current_name_stack, tag)
out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits(
trace, instantiate, in_pvals)
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
out_pvals = [t.pval for t in out_tracers]
del out_tracers
yield jaxpr, (out_pvals, out_consts, env)
return jaxpr, (out_pvals, out_consts, env)
def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals):
def _trace_to_subjaxpr_nounits(f, trace:JaxprTrace, instantiate, in_pvals):
in_knowns = [pval.is_known() for pval in in_pvals]
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
in_args = merge_lists(in_knowns, in_tracers, in_consts)
with core.set_current_trace(trace):
ans = yield in_args, {}
ans = f(*in_args)
assert isinstance(ans, (list, tuple)), (
f"Got unexpected return type when tracing function to jaxpr: {ans}")
assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), (
@@ -625,8 +628,9 @@ def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals):
# The below variant implements an optimization where residuals which are also
# inputs are indicated in auxiliary data rather than passed as outputs.
# TODO(mattjj): update all callers to use this version, delete other version.
@lu.transformation
@lu.transformation2
def trace_to_subjaxpr_nounits_fwd(
f,
tag: TraceTag,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
@@ -635,8 +639,8 @@ def trace_to_subjaxpr_nounits_fwd(
with core.take_current_trace() as parent_trace:
trace = JaxprTrace(parent_trace, current_name_stack, tag)
with core.set_current_trace(trace):
out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits(
trace, instantiate, in_pvals)
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
out_pvals = [t.pval for t in out_tracers]
# Which out_consts (aka residuals) are just forwarded inputs? Check obj id.
@@ -646,15 +650,16 @@ def trace_to_subjaxpr_nounits_fwd(
pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None]
del out_tracers
yield jaxpr, (fwds, out_pvals, pruned_consts, env)
return jaxpr, (fwds, out_pvals, pruned_consts, env)
# The below variant implements two optimizations:
# 1. residuals that are also primal inputs are indicated in aux data rather
# than passed as outputs;
# 2. residuals that are also primal outputs are indicated in aux data rather
# than passed as redundant outputs.
@lu.transformation
@lu.transformation2
def trace_to_subjaxpr_nounits_fwd2(
f,
tag: TraceTag,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
@@ -662,8 +667,8 @@ def trace_to_subjaxpr_nounits_fwd2(
current_name_stack = source_info_util.current_name_stack()
with core.take_current_trace() as parent_trace:
trace = JaxprTrace(parent_trace, current_name_stack, tag)
out_tracers, jaxpr, consts, env = yield from _trace_to_subjaxpr_nounits(
trace, instantiate, in_pvals)
out_tracers, jaxpr, consts, env = _trace_to_subjaxpr_nounits(
f, trace, instantiate, in_pvals)
out_pvals = [t.pval for t in out_tracers]
# Which consts (aka residuals) are just forwarded inputs? Check obj id.
@@ -680,7 +685,7 @@ def trace_to_subjaxpr_nounits_fwd2(
if f1 is None and f2 is None]
del out_tracers
yield jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env)
return jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env)
FreeVar = namedtuple('FreeVar', ['val'])
@@ -2066,10 +2071,10 @@ class DynamicJaxprTrace(core.Trace):
custom_staging_rules: dict[Primitive, Callable] = {}
@lu.transformation
def _interleave_fun(every_others, *args, **kwargs):
@lu.transformation2
def _interleave_fun(f, every_others, *args, **kwargs):
args_ = [x for pair in zip(args, every_others) for x in pair]
yield (yield (args_, kwargs))
return f(*args_, **kwargs)
# TODO: consider renaming to "lazy_thunk"
def _memoize(fn):
@@ -2083,18 +2088,19 @@ def _memoize(fn):
return out
return memoized
@lu.transformation_with_aux
def _jvp_jaxpr_zeros(in_zeros, zero_avals, *primal_tangent_avals):
@lu.transformation_with_aux2
def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
in_primals, nz_in_tangents = split_list(primal_tangent_avals, [len(in_zeros)])
symbolic_zeros = map(ad_util.SymbolicZero, zero_avals)
tangents = merge_lists(in_zeros, nz_in_tangents, symbolic_zeros)
out = yield (*in_primals, *tangents), {}
out = f(*in_primals, *tangents)
n, ragged = divmod(len(out), 2)
assert not ragged
out_primals, out_tangents = out[:n], out[n:]
out_zeros = [type(t) is ad_util.SymbolicZero for t in out_tangents]
out_nz_tangents, _ = partition_list(out_zeros, out_tangents)
yield [*out_primals, *out_nz_tangents], out_zeros
store.store(out_zeros)
return [*out_primals, *out_nz_tangents]
# TODO(mattjj): remove this DebugInfo and helper functions, replace with
# api_util.py versions

View File

@@ -690,15 +690,15 @@ def find_replicas(
num_global_replicas = global_axis_size * jaxpr_replicas
return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas)
@lu.transformation
def _change_argument_ranks(in_axes, out_axes_thunk, *args):
@lu.transformation2
def _change_argument_ranks(f, in_axes, out_axes_thunk, *args):
args = tuple(
arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,))
for in_axis, arg in zip(in_axes, args)
)
results = yield (args, {})
results = f(*args)
out_axes = out_axes_thunk()
yield tuple(
return tuple(
x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,))
for x, axis in zip(results, out_axes)
)

View File

@@ -64,6 +64,7 @@ data must be immutable, because it will be stored in function memoization tables
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from typing import Any, NamedTuple
import weakref
@@ -149,10 +150,11 @@ class WrappedFun:
params: extra parameters to pass as keyword arguments to `f`, along with the
transformed keyword arguments.
"""
__slots__ = ("f", "transforms", "stores", "params", "in_type", "debug_info")
__slots__ = ("f", "f_transformed", "transforms", "stores", "params", "in_type", "debug_info")
def __init__(self, f, transforms, stores, params, in_type, debug_info):
def __init__(self, f, f_transformed, transforms, stores, params, in_type, debug_info):
self.f = f
self.f_transformed = f_transformed
self.transforms = transforms
self.stores = stores
self.params = params
@@ -165,8 +167,14 @@ class WrappedFun:
def wrap(self, gen, gen_static_args, out_store) -> WrappedFun:
"""Add another transform and its store."""
return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
(out_store,) + self.stores, self.params, None, None)
if out_store is None:
return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args),
((gen, gen_static_args),) + self.transforms,
(out_store,) + self.stores, self.params, None, None)
else:
return WrappedFun(self.f, partial(gen, self.f_transformed, out_store, *gen_static_args),
((gen, gen_static_args),) + self.transforms,
(out_store,) + self.stores, self.params, None, None)
def populate_stores(self, stores):
"""Copy the values from the `stores` into `self.stores`."""
@@ -175,47 +183,8 @@ class WrappedFun:
self_store.store(other_store.val)
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
ans = self.f(*args, **dict(self.params, **kwargs))
except:
# Some transformations yield from inside context managers, so we have to
# interrupt them before reraising the exception. Otherwise they will only
# get garbage-collected at some later time, running their cleanup tasks
# only after this exception is handled, which can corrupt the global
# state.
while stack:
stack.pop()[0].close()
raise
args = kwargs = None
while stack:
gen, out_store = stack.pop()
try:
ans = gen.send(ans)
except:
# As above does for the first half of the transformation, exceptions
# raised in the second half of the transformation also require us to
# clean up references here.
while stack:
stack.pop()[0].close()
raise
if out_store is not None:
ans, side = ans
out_store.store(side)
return ans
"""Calls the transformed function"""
return self.f_transformed(*args, **kwargs)
def __repr__(self):
def transform_to_str(x):
@@ -234,7 +203,7 @@ class WrappedFun:
self.debug_info == other.debug_info)
@curry
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
def transformation2(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
"""Adds one more transformation to a WrappedFun.
Args:
@@ -244,8 +213,28 @@ def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
"""
return fun.wrap(gen, gen_static_args, None)
# Backwards compat only. TODO: deprecate
@curry
def transformation_with_aux(
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
def gen2(f, *args, **kwargs):
gen_inst = gen(*args, **kwargs)
args_, kwargs_ = next(gen_inst)
return gen_inst.send(f(*args_, **kwargs_))
return transformation2(gen2, fun, *gen_static_args)()
# Backwards compat only. TODO: deprecate
@curry
def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
def gen2(f, store, *args, **kwargs):
gen_inst = gen(*args, **kwargs)
args_, kwargs_ = next(gen_inst)
ans, aux = gen_inst.send(f(*args_, **kwargs_))
store.store(aux)
return ans
return transformation_with_aux2(gen2, fun, *gen_static_args)()
@curry
def transformation_with_aux2(
gen, fun: WrappedFun, *gen_static_args, use_eq_store: bool = False
) -> tuple[WrappedFun, Callable[[], Any]]:
"""Adds one more transformation with auxiliary output to a WrappedFun."""
@@ -261,8 +250,9 @@ def fun_name(f):
def wrap_init(f, params=None) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
params_dict = {} if params is None else params
params = () if params is None else tuple(sorted(params.items()))
return WrappedFun(f, (), (), params, None, None)
return WrappedFun(f, partial(f, **params_dict), (), (), params, None, None)
def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun:
@@ -270,7 +260,7 @@ def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun:
if in_type is None:
return f
_check_input_type(in_type)
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type, f.debug_info)
return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, in_type, f.debug_info)
def _check_input_type(in_type: core.InputType) -> None:
# Check that in_type is syntactically well-formed
@@ -317,7 +307,7 @@ def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None
assert f.debug_info is None
if debug_info is None:
return f
return WrappedFun(f.f, f.transforms, f.stores, f.params, f.in_type, debug_info)
return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, f.in_type, debug_info)
def cache(call: Callable, *, explain: Callable | None = None):
@@ -357,9 +347,9 @@ def cache(call: Callable, *, explain: Callable | None = None):
cache_clearing_funs.add(memoized_fun.cache_clear)
return memoized_fun
@transformation
def hashable_partial(*args):
yield (yield args, {})
@transformation2
def hashable_partial(f, *args):
return f(*args)
def merge_linear_aux(aux1, aux2):

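A note on the WrappedFun change above: f_transformed is now built eagerly in wrap, so stacking transforms is plain nested partial application and call_wrapped reduces to a single direct call. A rough sketch of the resulting call chain, with hypothetical transforms t1 (no aux) and t2 (with aux):

# wrap_init(f, params)    ->  ft0 = partial(f, **params)
# apply t1                ->  ft1 = partial(t1_gen, ft0, *t1_static_args)
# apply t2 (with store)   ->  ft2 = partial(t2_gen, ft1, t2_store, *t2_static_args)
# wf.call_wrapped(*args)  ->  ft2(*args): t2_gen runs, calls ft1, which runs
#                             t1_gen, which finally calls the original f.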
View File

@@ -824,14 +824,13 @@ def debug_print_lowering_rule(ctx, *args, **params):
# because they should appear as atomic JAX values to the users.
# TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU
# inferred by the compiler.
@lu.transformation
def wrap_with_transforms(transforms, *args):
@lu.transformation2
def wrap_with_transforms(f, transforms, *args):
new_args = tuple(
state_types.TransformedRef(a, t) if t else a
for a, t in zip(args, transforms)
)
res = yield new_args, {}
yield res
return f(*new_args)
run_scoped_p = jax_core.Primitive("run_scoped")

View File

@@ -97,34 +97,34 @@ def jvp(f, primals, tangents, attr_tangents):
out_tangents = tree_unflatten(out_tree(), out_tangents_flat)
return out_primals, out_tangents, tangent_attrs_out
@lu.transformation
def _set_attrs(attrs, attr_vals, *args):
@lu.transformation2
def _set_attrs(f, attrs, attr_vals, *args):
for (o, a), x in zip(attrs, attr_vals):
jax_setattr(o, a, x)
yield (yield args, {})
return f(*args)
def _jvp(fun: lu.WrappedFun):
return jvpfun2(jvp_subtrace2(fun))
@lu.transformation
def jvpfun2(primals, tangents):
@lu.transformation2
def jvpfun2(f, primals, tangents):
tag = core.TraceTag()
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
and dtype(t) == float0 else t for t in tangents]
ctx = source_info_util.transform_name_stack('jvp')
with ctx:
out_primals, out_tangents, tangent_attrs_out = yield (tag, primals, tangents), {}
yield out_primals, out_tangents, tangent_attrs_out
out_primals, out_tangents, tangent_attrs_out = f(tag, primals, tangents)
return out_primals, out_tangents, tangent_attrs_out
@lu.transformation
def jvp_subtrace2(tag, primals, tangents):
@lu.transformation2
def jvp_subtrace2(f, tag, primals, tangents):
with core.take_current_trace() as parent_trace:
trace = ad.JVPTrace(parent_trace, tag)
tag.attrs_tracked = [] # attrs written to
in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x
for x, t in zip(primals, tangents)]
with core.set_current_trace(trace):
ans = yield in_tracers, {}
ans = f(*in_tracers)
out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans))
tangent_attrs_out = []
for (obj, name) in tag.attrs_tracked:
@@ -133,7 +133,7 @@ def jvp_subtrace2(tag, primals, tangents):
if type(tangent) is not ad.Zero:
tangent_attrs_out.append((obj, name, tangent))
del tag.attrs_tracked
yield out_primals, out_tangents, tangent_attrs_out
return out_primals, out_tangents, tangent_attrs_out
def _setattr_jvp(trace, obj, attr, maybe_tracer):
primal, tangent = trace.to_primal_tangent_pair(maybe_tracer)
@@ -175,11 +175,12 @@ def _linearize(traceable: lu.WrappedFun, *primals):
return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals],
jaxpr, consts, attrs())
@lu.transformation_with_aux
def _split_attrs(*args, **kwargs):
primals, tangents, tangent_attrs = yield args, kwargs
@lu.transformation_with_aux2
def _split_attrs(f, store, *args, **kwargs):
primals, tangents, tangent_attrs = f(*args, **kwargs)
attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs)
yield (primals, tangents, tangent_attr_vals), attrs
store.store(attrs)
return primals, tangents, tangent_attr_vals
def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs):
in_tree, out_tree = io_tree

View File

@@ -1040,20 +1040,20 @@ def _convert_jax_impl(impl_jax: Callable, *,
return wrapped_tf
@lu.transformation
def _interpret_subtrace(in_avals: Sequence[core.ShapedArray],
@lu.transformation2
def _interpret_subtrace(f, in_avals: Sequence[core.ShapedArray],
*in_vals: TfVal):
trace = TensorFlowTrace()
in_tracers = tuple(
TensorFlowTracer(trace, val, aval)
for val, aval in zip(in_vals, in_avals))
with core.set_current_trace(trace):
outs = yield in_tracers, {} # type: Sequence[TfVal]
outs = f(*in_tracers)
out_tracers: Iterable[TensorFlowTracer] = (
map(trace.to_tf_tracer, outs))
out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = (
tuple((t.val, t.aval) for t in out_tracers))
yield out_vals_with_avals
return out_vals_with_avals
def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal,

View File

@@ -141,40 +141,43 @@ def jet(fun, primals, series):
if not treedef_is_leaf(treedef):
raise ValueError(f"term {j} for argument {i} is not an array")
@lu.transformation_with_aux
def flatten_fun_output(*args):
ans = yield args, {}
yield tree_flatten(ans)
@lu.transformation_with_aux2
def flatten_fun_output(f, store, *args):
ans = f(*args)
ans, tree = tree_flatten(ans)
store.store(tree)
return ans
f, out_tree = flatten_fun_output(lu.wrap_init(fun))
out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)
@lu.transformation
def jet_fun(order, primals, series):
@lu.transformation2
def jet_fun(f, order, primals, series):
tag = core.TraceTag()
out_primals, out_terms = yield (tag, order, primals, series), {}
out_primals, out_terms = f(tag, order, primals, series)
out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s
for p, s in zip(out_primals, out_terms)]
yield out_primals, out_terms
return out_primals, out_terms
@lu.transformation
def jet_subtrace(tag, order, primals, series):
@lu.transformation2
def jet_subtrace(f, tag, order, primals, series):
with core.take_current_trace() as parent_trace:
trace = JetTrace(tag, parent_trace, order)
in_tracers = map(partial(JetTracer, trace), primals, series)
with core.set_current_trace(trace):
ans = yield in_tracers, {}
ans = f(*in_tracers)
out_primals, out_terms = unzip2(map(trace.to_primal_terms_pair, ans))
yield out_primals, out_terms
return out_primals, out_terms
@lu.transformation_with_aux
def traceable(in_tree_def, *primals_and_series):
@lu.transformation_with_aux2
def traceable(f, store, in_tree_def, *primals_and_series):
primals_in, series_in = tree_unflatten(in_tree_def, primals_and_series)
primals_out, series_out = yield (primals_in, series_in), {}
primals_out, series_out = f(primals_in, series_in)
out_flat, out_tree_def = tree_flatten((primals_out, series_out))
yield out_flat, out_tree_def
store.store(out_tree_def)
return out_flat
class JetTracer(core.Tracer):

View File

@@ -47,12 +47,12 @@ zip = safe_zip
def ravel_first_arg(f, unravel):
return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
@lu.transformation
def ravel_first_arg_(unravel, y_flat, *args):
@lu.transformation2
def ravel_first_arg_(f, unravel, y_flat, *args):
y = unravel(y_flat)
ans = yield (y,) + args, {}
ans = f(y, *args)
ans_flat, _ = ravel_pytree(ans)
yield ans_flat
return ans_flat
def interp_fit_dopri(y0, y1, k, dt):
# Fit a polynomial to the results of a Runge-Kutta step.

View File

@@ -1479,15 +1479,15 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
return pe.merge_lists(out_knowns, out_tracers, out_consts)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
@lu.transformation
def _promote_scalar_residuals(*args, **kwargs):
jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs
@lu.transformation2
def _promote_scalar_residuals(f, *args, **kwargs):
jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs)
which = [f1 is None and f2 is None and not v.aval.shape
for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]
jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
for x in out_consts]
yield jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env)
return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env)
def _promote_scalar_residuals_jaxpr(jaxpr, which):
@lu.wrap_init
@@ -1728,13 +1728,13 @@ def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name):
check_rep=False, auto=frozenset()),
in_specs, out_specs)
@lu.transformation
def _handle_reshapes(in_axes, out_axes_thunk, *args, **kwargs):
@lu.transformation2
def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs):
args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax),
list(args), list(in_axes))
out = yield args, {}
yield tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax),
list(out), list(out_axes_thunk()))
out = f(*args)
return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax),
list(out), list(out_axes_thunk()))
def _axis_to_spec(axis_name, ax):
if isinstance(ax, int):
@@ -1855,27 +1855,28 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk):
fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps)
return _match_rep(fun, mesh, out_reps_src, out_reps_dst)
@lu.transformation_with_aux
def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args):
@lu.transformation_with_aux2
def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args):
with core.take_current_trace() as parent:
tag = core.TraceTag()
t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh)
in_tracers = map(partial(RewriteTracer, t), in_reps, args)
with core.set_current_trace(t):
ans = yield in_tracers, {}
ans = f(*in_tracers)
out_vals, out_reps = unzip2(map(t.to_val_rep_pair, ans))
del t, in_tracers, ans
yield out_vals, out_reps
store.store(out_reps)
return out_vals
@lu.transformation
def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args):
outs = yield args, {}
@lu.transformation2
def _match_rep(f, mesh, out_reps_src_, out_reps_dst_, *args):
outs = f(*args)
out_reps_src = out_reps_src_() if callable(out_reps_src_) else out_reps_src_
out_reps_dst = out_reps_dst_() if callable(out_reps_dst_) else out_reps_dst_
_check_reps2(mesh, out_reps_dst, out_reps_src)
outs = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst
else x for x, src, dst in zip(outs, out_reps_src, out_reps_dst)]
yield outs
return outs
# TODO(mattjj): caching
def _replication_rewrite_match(
@@ -1901,16 +1902,17 @@ def _replication_rewrite_nomatch(
jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
return core.ClosedJaxpr(jaxpr_, consts), out_rep()
@lu.transformation_with_aux
def _rewrite_subtrace(tag, mesh, in_reps, *in_vals):
@lu.transformation_with_aux2
def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals):
with core.take_current_trace() as parent_trace:
assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals))
t = RewriteTrace(parent_trace, tag, mesh)
in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals)
with core.set_current_trace(t):
outs = yield in_tracers, {}
ans = unzip2(map(t.to_val_rep_pair, outs))
yield ans
outs = f(*in_tracers)
out_vals, out_reps = unzip2(map(t.to_val_rep_pair, outs))
store.store(out_reps)
return out_vals
def _rewrite_bwd(bwd, mesh, in_reps, reps_dst):
def new_bwd(*args):

View File

@@ -340,16 +340,17 @@ class SparseTrace(core.Trace):
with core.set_current_trace(self):
return fun.call_wrapped(*tracers)
@lu.transformation_with_aux
def sparsify_subtrace(tag, spenv, spvalues, *bufs):
@lu.transformation_with_aux2
def sparsify_subtrace(f, store, tag, spenv, spvalues, *bufs):
with core.take_current_trace() as parent:
trace = SparseTrace(parent, tag, spenv)
with core.set_current_trace(trace):
in_tracers = [SparseTracer(trace, spvalue=spvalue) for spvalue in spvalues]
outs = yield in_tracers, {}
outs = f(*in_tracers)
out_traces = [trace.to_sparse_tracer(out) for out in outs]
buffers = spenv._buffers
yield buffers, [out._spvalue for out in out_traces]
store.store([out._spvalue for out in out_traces])
return buffers
def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]):
tag = core.TraceTag()

View File

@@ -22,5 +22,7 @@ from jax._src.linear_util import (
merge_linear_aux as merge_linear_aux,
transformation as transformation,
transformation_with_aux as transformation_with_aux,
transformation2 as transformation2,
transformation_with_aux2 as transformation_with_aux2,
wrap_init as wrap_init,
)

View File

@@ -42,8 +42,8 @@ class UtilTest(jtu.JaxTestCase):
assert not kwargs
return tuple(a * factor for a in args)
@lu.transformation_with_aux
def kw_to_positional(factor, *args, **kwargs):
@lu.transformation_with_aux2
def kw_to_positional(f, store, factor, *args, **kwargs):
"""A transformation with auxiliary output.
Turns all keyword parameters into positional ones.
@@ -55,12 +55,12 @@ class UtilTest(jtu.JaxTestCase):
kwargs_keys = kwargs.keys()
new_args = tuple(kwargs[k] for k in kwargs_keys)
new_kwargs = dict(factor=factor)
results = yield args + new_args, new_kwargs # Yield transformed (args, kwargs)
results = f(*(args + new_args), **new_kwargs) # Call f with the transformed (args, kwargs)
# Assume results correspond 1:1 to the args + new_args
assert len(results) == len(args) + len(new_args)
aux_output = len(new_args)
yield (results[0:len(args)],
dict(zip(kwargs_keys, results[len(args):]))), aux_output
store.store(aux_output)
return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):])))
wf = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`.
wf, out_thunk = kw_to_positional(wf, 2)
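A plausible continuation of this test, assuming f multiplies each positional argument by factor as its fragment above suggests: call_wrapped keeps its old signature, and the auxiliary value is read back through the returned thunk.

# 'three' and 'four' are moved into positional slots before f runs.
pos_results, kw_results = wf.call_wrapped(1, 2, three=3, four=4)
assert pos_results == (2, 4)  # (1*2, 2*2)
assert kw_results == {'three': 6, 'four': 8}
assert out_thunk() == 2  # two keyword arguments were made positional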