From 1c9b23c566bcc0a373f2f9c8716bcc46851000f1 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 12 Nov 2024 22:39:26 -0800 Subject: [PATCH] Stop using generators for linear_util transformations. They lead to confusing code, nasty bugs, and unhelpful (but terse!) stack traces. --- jax/_src/api_util.py | 62 +++++++++-------- jax/_src/checkify.py | 19 +++--- jax/_src/custom_derivatives.py | 47 +++++++------ jax/_src/interpreters/ad.py | 61 +++++++++-------- jax/_src/interpreters/batching.py | 85 +++++++++++++----------- jax/_src/interpreters/partial_eval.py | 64 ++++++++++-------- jax/_src/interpreters/pxla.py | 8 +-- jax/_src/linear_util.py | 96 ++++++++++++--------------- jax/_src/pallas/primitives.py | 7 +- jax/experimental/attrs.py | 31 ++++----- jax/experimental/jax2tf/jax2tf.py | 8 +-- jax/experimental/jet.py | 35 +++++----- jax/experimental/ode.py | 8 +-- jax/experimental/shard_map.py | 46 +++++++------ jax/experimental/sparse/transform.py | 9 +-- jax/extend/linear_util.py | 2 + tests/util_test.py | 10 +-- 17 files changed, 311 insertions(+), 287 deletions(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 329abd6b7..1bfce85d5 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -68,11 +68,13 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]: else: return tuple(map(_ensure_str, x)) -@lu.transformation_with_aux -def flatten_fun(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun(f, store, in_tree, *args_flat): py_args, py_kwargs = tree_unflatten(in_tree, args_flat) - ans = yield py_args, py_kwargs - yield tree_flatten(ans) + ans = f(*py_args, **py_kwargs) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def apply_flat_fun(fun, io_tree, *py_args): in_tree_expected, out_tree = io_tree @@ -82,11 +84,13 @@ def apply_flat_fun(fun, io_tree, *py_args): ans = fun(*args) return tree_unflatten(out_tree, ans) -@lu.transformation_with_aux -def flatten_fun_nokwargs(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_nokwargs(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - ans = yield py_args, {} - yield tree_flatten(ans) + ans = f(*py_args) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def apply_flat_fun_nokwargs(fun, io_tree, py_args): in_tree_expected, out_tree = io_tree @@ -118,17 +122,18 @@ def flattened_fun_in_tree( else: return in_tree, lambda: out_tree_store.val, has_kwargs -@lu.transformation_with_aux -def flatten_fun_nokwargs2(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_nokwargs2(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - pair = yield py_args, {} + pair = f(*py_args) if not isinstance(pair, (list, tuple)) or len(pair) != 2: raise TypeError("expected function with aux output to return a two-element " f"tuple, but got type {type(pair)} with value {pair!r}") ans, aux = pair ans_flat, ans_tree = tree_flatten(ans) aux_flat, aux_tree = tree_flatten(aux) - yield (ans_flat, aux_flat), (ans_tree, aux_tree) + store.store((ans_tree, aux_tree)) + return ans_flat, aux_flat class _HashableWithStrictTypeEquality: """Box object used when comparing static arguments as a jit key. 
@@ -277,8 +282,8 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args -@lu.transformation -def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs): +@lu.transformation2 +def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs): sentinel = object() args = [sentinel] * (len(fixed_args) + len(dyn_args)) for i, arg in zip(dyn_argnums, dyn_args): @@ -286,9 +291,7 @@ def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs): fixed_args_ = iter(fixed_args) args = [next(fixed_args_).val if x is sentinel else x for x in args] assert next(fixed_args_, sentinel) is sentinel - ans = yield args, kwargs - yield ans - + return f(*args, **kwargs) def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], kwargs: dict[str, Any]): @@ -311,11 +314,10 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs -@lu.transformation -def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): +@lu.transformation2 +def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs) - ans = yield args, kwargs - yield ans + return f(*args, **kwargs) @lru_cache(maxsize=4096) @@ -435,9 +437,9 @@ def flat_out_axes( f, out_axes = _flat_out_axes(f, tuple(leaves), treedef) return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef)) -@lu.transformation_with_aux -def _flat_out_axes(leaves, treedef, *args, **kwargs): - ans = yield args, kwargs +@lu.transformation_with_aux2 +def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs): + ans = f(*args, **kwargs) spec = tree_unflatten(treedef, leaves) try: spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None)) @@ -449,7 +451,8 @@ def _flat_out_axes(leaves, treedef, *args, **kwargs): "that the `out_axes` argument to `pmap` is a pytree prefix of the " "pmapped function's output.") raise ValueError(msg) from None - yield ans, spec_flat + store.store(spec_flat) + return ans def check_callable(fun): # In Python 3.10+, the only thing stopping us from supporting staticmethods @@ -683,11 +686,12 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames, return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() for path, l in generate_key_paths(x) if l is not static) -@lu.transformation_with_aux -def result_paths(*args, **kwargs): +@lu.transformation_with_aux2 +def result_paths(f, store, *args, **kwargs): "linear_util transform to get output pytree paths of pre-flattened function." - ans = yield args, kwargs - yield ans, [keystr(path) for path, _ in generate_key_paths(ans)] + ans = f(*args, **kwargs) + store.store([keystr(path) for path, _ in generate_key_paths(ans)]) + return ans def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None, result_paths: tuple[str, ...] | None = None, diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 55db5d13e..22fde8bd1 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -330,11 +330,12 @@ def update_error(error, pred, code, metadata, payload, effect_type): ## Checkify transformation for plumbing functional error values. 
-@lu.transformation_with_aux -def _flatten_and_get_error_metadata_thunk(*invals): - error, out = yield invals, {} +@lu.transformation_with_aux2 +def _flatten_and_get_error_metadata_thunk(f, store, *invals): + error, out = f(*invals) out_vals, out_tree = jtu.tree_flatten((error, out)) - yield out_vals, (out_tree, set(error._pred.keys())) + store.store((out_tree, set(error._pred.keys()))) + return out_vals def default_checkify_rule(primitive: core.Primitive, error: Error, enabled_errors, *invals: core.Value, @@ -438,10 +439,12 @@ def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors, consts = tuple(c.x for c in hashable_consts) return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args) -@lu.transformation_with_aux -def flatten_fun_output(*args): - ans = yield args, {} - yield tree_flatten(ans) +@lu.transformation_with_aux2 +def flatten_fun_output(f, store, *args): + ans = f(*args) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def _reduce_any_error(error: Error): diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index e37494c4f..69130cc18 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -75,13 +75,14 @@ _stop_gradient = partial( # like the api_util.py function, but also grabs output avals for error checking -@lu.transformation_with_aux -def _flatten_fun_nokwargs(in_tree, *args_flat): +@lu.transformation_with_aux2 +def _flatten_fun_nokwargs(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - ans = yield py_args, {} + ans = f(*py_args) ans_flat, ans_tree = tree_flatten(ans) ans_avals = [core.get_aval(x) for x in ans_flat] - yield ans_flat, (ans_tree, ans_avals) + store.store((ans_tree, ans_avals)) + return ans_flat ### JVPs @@ -266,18 +267,18 @@ class custom_jvp(Generic[ReturnValue]): def _add_args(f, extra_args): return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args)) -@lu.transformation -def _add_args_(extra_args, *args, **kwargs): +@lu.transformation2 +def _add_args_(f, extra_args, *args, **kwargs): extra_args = tuple(arg.val for arg in extra_args) all_args = (extra_args + args) - yield (yield all_args, kwargs) + return f(*all_args, **kwargs) -@partial(lu.transformation_with_aux, use_eq_store=True) -def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): +@partial(lu.transformation_with_aux2, use_eq_store=True) +def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args): primals_in, tangents_in = split_list(args, [len(args) // 2]) py_primals = tree_unflatten(in_tree, primals_in) py_tangents = tree_unflatten(in_tree, tangents_in) - pair_out = yield (py_primals, py_tangents), {} + pair_out = f(py_primals, py_tangents) if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom JVP rule {jvp_name} for function {primal_name} " "must produce a pair (list or tuple of length two) representing " @@ -348,7 +349,8 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): if av_et != av_t) raise TypeError(msg.format('\n'.join(disagreements))) - yield primals_out + tangents_out, (out_tree, primal_avals) + store.store((out_tree, primal_avals)) + return primals_out + tangents_out class CustomJVPCallPrimitive(core.Primitive): multiple_results = True @@ -652,15 +654,15 @@ def _check_for_tracers(x): "arguments should typically not be indicated as nondiff_argnums.") raise UnexpectedTracerError(msg) -@partial(lu.transformation_with_aux, use_eq_store=True) -def 
_flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, +@partial(lu.transformation_with_aux2, use_eq_store=True) +def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, *args): if symbolic_zeros: args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])] else: args = args[::2] py_args = tree_unflatten(in_tree, args) - pair_out = yield py_args, {} + pair_out = f(*py_args) if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " "must produce a pair (list or tuple of length two) where the first " @@ -710,16 +712,17 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - yield (*res, *primals_out), (out_tree, res_tree) + store.store((out_tree, res_tree)) + return (*res, *primals_out) -@lu.transformation -def _flatten_bwd(in_tree, in_avals, out_trees, *args): +@lu.transformation2 +def _flatten_bwd(f, in_tree, in_avals, out_trees, *args): out_tree, res_tree = out_trees() assert len(args) == res_tree.num_leaves + out_tree.num_leaves res, cts_out = split_list(args, [res_tree.num_leaves]) py_res = tree_unflatten(res_tree, res) py_cts_out = tree_unflatten(out_tree, cts_out) - py_cts_in = yield (py_res, py_cts_out), {} + py_cts_in = f(py_res, py_cts_out) if isinstance(py_cts_in, list) and len(py_cts_in) == len(treedef_children(in_tree)): py_cts_in = tuple(py_cts_in) # For each None in py_cts_in, indicating an argument for which the rule @@ -775,7 +778,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args): f"to an input of shape/dtype {a.str_short()}.") raise ValueError(msg) results.append(ct) - yield results + return results # TODO(mattjj): remove both these exceptions to cotangent compatibility check def _temporary_dtype_exception(a, a_) -> bool: @@ -1425,11 +1428,11 @@ def optimize_remat_of_custom_vjp_fwd( return wrapped_fwd -@lu.transformation -def _fix_fwd_args(*args): +@lu.transformation2 +def _fix_fwd_args(f, *args): args = [(x, True) for x in args] args = [x for pair in args for x in pair] - yield (yield args, {}) + return f(*args) def _remat_opt_impl( *args, diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index d080aae75..99340e728 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -68,42 +68,43 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, fun, aux = jvp_subtrace_aux(fun) return jvpfun(fun, instantiate, transform_stack), aux -@lu.transformation -def jvpfun(instantiate, transform_stack, primals, tangents): +@lu.transformation2 +def jvpfun(f, instantiate, transform_stack, primals, tangents): tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) with ctx: - out_primals, out_tangents = yield (tag, primals, tangents), {} + out_primals, out_tangents = f(tag, primals, tangents) if type(instantiate) is bool: instantiate = [instantiate] * len(out_tangents) out_tangents = [instantiate_zeros(t) if inst else t for t, inst in zip(out_tangents, instantiate)] - yield out_primals, out_tangents + return out_primals, out_tangents -@lu.transformation -def jvp_subtrace(tag, primals, tangents): +@lu.transformation2 +def jvp_subtrace(f, tag, primals, tangents): with core.take_current_trace() as 
parent_trace: trace = JVPTrace(parent_trace, tag) in_tracers = [maybe_jvp_tracer(trace, x, t) for x, t in zip(primals, tangents)] with core.set_current_trace(trace): - ans = yield in_tracers, {} + ans = f(*in_tracers) out = unzip2(map(trace.to_primal_tangent_pair, ans)) - yield out + return out -@lu.transformation_with_aux -def jvp_subtrace_aux(tag, primals, tangents): +@lu.transformation_with_aux2 +def jvp_subtrace_aux(f, store, tag, primals, tangents): with core.take_current_trace() as parent_trace: trace = JVPTrace(parent_trace, tag) with core.set_current_trace(trace): - ans, aux = yield map(partial(maybe_jvp_tracer, trace), primals, tangents), {} + ans, aux = f(*(map(partial(maybe_jvp_tracer, trace), primals, tangents))) out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) aux_primals = [x.primal if isinstance(x, JVPTracer) and x._trace.tag is tag else x for x in aux] - yield (out_primals, out_tangents), aux_primals + store.store(aux_primals) + return out_primals, out_tangents def linearize(traceable, *primals, **kwargs): has_aux = kwargs.pop('has_aux', False) @@ -262,10 +263,11 @@ def get_primitive_transpose(p): "Transpose rule (for reverse-mode differentiation) for '{}' " "not implemented".format(p)) from err -@lu.transformation_with_aux -def nonzero_tangent_outputs(*args, **kwargs): - results = (_, tangents_out) = yield args, kwargs - yield results, [type(r) is not Zero for r in tangents_out] +@lu.transformation_with_aux2 +def nonzero_tangent_outputs(f, store, *args, **kwargs): + results = (_, tangents_out) = f(*args, **kwargs) + store.store([type(r) is not Zero for r in tangents_out]) + return results class JVPTrace(Trace): @@ -543,15 +545,16 @@ deflinear2(add_jaxvals_p, lambda t, *args: (t, t)) def instantiate_zeros(tangent): return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent -@lu.transformation_with_aux -def traceable(in_tree, *primals_and_tangents): +@lu.transformation_with_aux2 +def traceable(f, store, in_tree, *primals_and_tangents): primals, tangents = tree_unflatten(in_tree, primals_and_tangents) tangents = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primals, tangents)] - primals_out, tangents_out = yield (primals, tangents), {} + primals_out, tangents_out = f(primals, tangents) tangents_out = [None if type(t) is Zero else t for t in tangents_out] out_flat, out_tree = tree_flatten((primals_out, tangents_out)) - yield out_flat, out_tree + store.store(out_tree) + return out_flat def call_transpose(primitive, params, call_jaxpr, args, ct, _): @@ -588,10 +591,11 @@ def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals): primitive_transposes[core.closed_call_p] = _closed_call_transpose -@lu.transformation_with_aux -def nonzero_outputs(*args, **kwargs): - results = yield args, kwargs - yield results, [type(r) is not Zero for r in results] +@lu.transformation_with_aux2 +def nonzero_outputs(f, store, *args, **kwargs): + results = f(*args, **kwargs) + store.store([type(r) is not Zero for r in results]) + return results def map_transpose(primitive, params, call_jaxpr, args, ct, _): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts @@ -655,17 +659,18 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate): jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() -@lu.transformation_with_aux -def f_jvp_traceable(nonzeros, *primals_and_nztangents): +@lu.transformation_with_aux2 +def f_jvp_traceable(f, store, 
nonzeros, *primals_and_nztangents): num_primals = len(nonzeros) primals = list(primals_and_nztangents[:num_primals]) nonzero_tangents = iter(primals_and_nztangents[num_primals:]) tangents = [next(nonzero_tangents) if nz else Zero.from_primal_value(p) for p, nz in zip(primals, nonzeros)] - primals_out, tangents_out = yield (primals, tangents), {} + primals_out, tangents_out = f(primals, tangents) out_nonzeros = [type(t) is not Zero for t in tangents_out] nonzero_tangents_out = [t for t in tangents_out if type(t) is not Zero] - yield list(primals_out) + nonzero_tangents_out, out_nonzeros + store.store(out_nonzeros) + return list(primals_out) + nonzero_tangents_out def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out): new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index b6325ed81..f4658ec2b 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -327,11 +327,13 @@ def unregister_vmappable(data_type: type) -> None: def is_vmappable(x: Any) -> bool: return type(x) is Jumble or type(x) in vmappables -@lu.transformation_with_aux -def flatten_fun_for_vmap(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_for_vmap(f, store, in_tree, *args_flat): py_args, py_kwargs = tree_unflatten(in_tree, args_flat) - ans = yield py_args, py_kwargs - yield tree_flatten(ans, is_leaf=is_vmappable) + ans = f(*py_args, **py_kwargs) + ans, out_tree = tree_flatten(ans, is_leaf=is_vmappable) + store.store(out_tree) + return ans # Propagate ragged masking rules from invars to outvars # rule([params], [raggedness_per_invar], outvars) -> @@ -580,16 +582,16 @@ def batch(fun: lu.WrappedFun, axis_data, f = _batch_inner(fun, axis_data, out_dim_dests) return _batch_outer(f, axis_data, in_dims) -@lu.transformation -def _batch_outer(axis_data, in_dims, *in_vals): +@lu.transformation2 +def _batch_outer(f, axis_data, in_dims, *in_vals): tag = TraceTag() with source_info_util.transform_name_stack('vmap'): - outs, trace = yield (tag, in_dims, *in_vals), {} + outs, trace = f(tag, in_dims, *in_vals) with core.ensure_no_leaks(trace): del trace - yield outs + return outs -@lu.transformation -def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals): +@lu.transformation2 +def _batch_inner(f, axis_data, out_dim_dests, tag, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) @@ -599,13 +601,13 @@ def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals): with (core.set_current_trace(trace), core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), core.add_spmd_axis_names(axis_data.spmd_name)): - outs = yield in_tracers, {} + outs = f(*in_tracers) out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)), outs, out_dim_dests) - yield out_vals, trace + return out_vals, trace # NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it. 
def vtile(f_flat: lu.WrappedFun, @@ -628,21 +630,21 @@ def vtile(f_flat: lu.WrappedFun, shape[axis:axis+2] = [shape[axis] * shape[axis+1]] return out.reshape(shape) - @lu.transformation - def _map_to_tile(*args_flat): + @lu.transformation2 + def _map_to_tile(f, *args_flat): sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None) tile_size_ = tile_size or next(sizes, None) assert tile_size_ is not None, "No mapped arguments?" - outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {} - yield map(untile_axis, outputs_flat, out_axes_flat) + outputs_flat = f(*map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat)) + return map(untile_axis, outputs_flat, out_axes_flat) axis_data = AxisData(axis_name, tile_size, None) return _map_to_tile(batch(f_flat, axis_data, in_axes_flat, out_axes_flat)) ### API for batching functions with jaxpr type inputs and outputs -@lu.transformation_with_aux -def batch_subtrace(tag, axis_data, in_dims, *in_vals): +@lu.transformation_with_aux2 +def batch_subtrace(f, store, tag, axis_data, in_dims, *in_vals): with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) with core.set_current_trace(trace): @@ -650,10 +652,11 @@ def batch_subtrace(tag, axis_data, in_dims, *in_vals): in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims) in_tracers = [BatchTracer(trace, x, dim, source_info_util.current()) if dim is not None else x for x, dim in zip(in_vals, in_dims)] - outs = yield in_tracers, {} + outs = f(*in_tracers) out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) segment_lens, out_dims = indirectify_ragged_axes(out_dims) - yield (*segment_lens, *out_vals), out_dims + store.store(out_dims) + return (*segment_lens, *out_vals) def indirectify_ragged_axes(dims): if not any(type(d) is RaggedAxis for d in dims): @@ -789,8 +792,8 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest): jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in) return core.ClosedJaxpr(jaxpr_out, consts), out_batched() -@lu.transformation_with_aux -def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals): +@lu.transformation_with_aux2 +def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals): with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) _, in_axes = resolve_ragged_axes(in_vals, in_axes) @@ -799,16 +802,17 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals): with (core.set_current_trace(trace), core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), core.add_spmd_axis_names(axis_data.spmd_name)): - outs = yield in_tracers, {} + outs = f(*in_tracers) out_vals, out_axes = unzip2(map(trace.to_batch_info, outs)) new_out_axes = indirectify_ragged_axes_against_inputs_outputs( out_axes, in_vals, out_vals) - yield out_vals, new_out_axes + store.store(new_out_axes) + return out_vals -@lu.transformation_with_aux -def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes, +@lu.transformation_with_aux2 +def _match_axes_jaxpr(f, store, axis_data, out_axes_dest, out_axes, trace, in_axes, *in_vals): - out_vals = yield (trace, in_axes, *in_vals), {} + out_vals = f(trace, in_axes, *in_vals) out_axes = out_axes() out_axes_dest = [(None if src is not_mapped else 0) if dst is zero_if_mapped else dst @@ -819,16 +823,16 @@ def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes, out_vals = map(partial(matchaxis, axis_data.name, axis_data.size), out_axes, out_axes_dest, 
out_vals) out_batched = [dst is not None for dst in out_axes_dest] - yield out_vals, out_batched + store.store(out_batched) + return out_vals -@lu.transformation -def _batch_jaxpr_outer(axis_data, in_dims, *in_vals): +@lu.transformation2 +def _batch_jaxpr_outer(f, axis_data, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int) else ax for x, ax in unsafe_zip(in_vals, in_dims)] tag = TraceTag() - out_vals = yield (tag, in_dims, *in_vals), {} - yield out_vals + return f(tag, in_dims, *in_vals) def _merge_bdims(x, y): if x == y: @@ -845,8 +849,8 @@ zero_if_mapped = ZeroIfMapped() ### functions for handling custom_vjp -@lu.transformation_with_aux -def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals): +@lu.transformation_with_aux2 +def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals): size = axis_data.size with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) @@ -855,7 +859,7 @@ def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals): if type(val) is SymbolicZero else BatchTracer(trace, val, dim) for val, dim in zip(in_vals, in_dims * 2)] with core.set_current_trace(trace): - outs = yield in_tracers, {} + outs = f(*in_tracers) # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can # be wasteful in the rare case it actually triggers; handle symbolically! outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs] @@ -868,7 +872,8 @@ def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals): out_primal_bds, out_dims, out_primals) out_tangents = map(partial(matchaxis, trace.axis_data.name, size), out_tangent_bds, out_dims, out_tangents) - yield out_primals + out_tangents, out_dims * 2 + store.store(out_dims * 2) + return out_primals + out_tangents def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests): axis_size = axis_data.size @@ -886,11 +891,11 @@ def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests): return bwd_.call_wrapped(*args) return new_bwd -@lu.transformation -def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals): +@lu.transformation2 +def _match_axes_and_sum(f, axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals): # this is like _match_axes, but we do reduce-sums as needed - out_vals = yield in_vals, {} - yield map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name, + out_vals = f(*in_vals) + return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name, sum_match=True), out_dims_thunk(), out_dim_dests, out_vals) def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False): diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 5431762d6..943c15b6e 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -475,18 +475,19 @@ def partition_pvals( consts = [pval.get_known() for pval in pvals if pval.is_known()] return knowns, avals, consts -@lu.transformation_with_aux +@lu.transformation_with_aux2 def partial_eval_wrapper_nounits( - in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue], + f, store, in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue], *in_consts: Any): in_avals_, in_consts_ = iter(in_avals), iter(in_consts) in_pvals = [PartialVal.known(next(in_consts_)) if known else PartialVal.unknown(next(in_avals_)) for known in in_knowns] sentinel = 
object() assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel - jaxpr, (*maybe_fwds, out_pvals, res, env) = yield (in_pvals,), {} + jaxpr, (*maybe_fwds, out_pvals, res, env) = f(in_pvals) out_knowns, out_avals, out_consts = partition_pvals(out_pvals) - yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env) + store.store((*maybe_fwds, out_knowns, out_avals, jaxpr, env)) + return (*out_consts, *res) custom_partial_eval_rules: dict[Primitive, Callable] = {} call_partial_eval_rules: dict[Primitive, Callable] = {} @@ -574,20 +575,22 @@ def trace_to_jaxpr_nounits( return jaxpr, out_pvals, consts # TODO(mattjj): superfluous wrapper...? -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits( + f, trace: JaxprTrace, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] del out_tracers - yield jaxpr, (out_pvals, out_consts, env) + return jaxpr, (out_pvals, out_consts, env) -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits2( + f, tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): @@ -596,19 +599,19 @@ def trace_to_subjaxpr_nounits2( current_name_stack = source_info_util.current_name_stack() with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, tag) - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] del out_tracers - yield jaxpr, (out_pvals, out_consts, env) + return jaxpr, (out_pvals, out_consts, env) -def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals): +def _trace_to_subjaxpr_nounits(f, trace:JaxprTrace, instantiate, in_pvals): in_knowns = [pval.is_known() for pval in in_pvals] in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()] in_args = merge_lists(in_knowns, in_tracers, in_consts) with core.set_current_trace(trace): - ans = yield in_args, {} + ans = f(*in_args) assert isinstance(ans, (list, tuple)), ( f"Got unexpected return type when tracing function to jaxpr: {ans}") assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), ( @@ -625,8 +628,9 @@ def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals): # The below variant implements an optimization where residuals which are also # inputs are indicated in auxiliary data rather than passed as outputs. # TODO(mattjj): update all callers to use this version, delete other version. 
-@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits_fwd( + f, tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): @@ -635,8 +639,8 @@ def trace_to_subjaxpr_nounits_fwd( with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, tag) with core.set_current_trace(trace): - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] # Which out_consts (aka residuals) are just forwarded inputs? Check obj id. @@ -646,15 +650,16 @@ def trace_to_subjaxpr_nounits_fwd( pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None] del out_tracers - yield jaxpr, (fwds, out_pvals, pruned_consts, env) + return jaxpr, (fwds, out_pvals, pruned_consts, env) # The below variant implements two optimizations: # 1. residuals that are also primal inputs are indicated in aux data rather # than passed as outputs; # 2. residuals that are also primal outputs are indicated in aux data rather # than passed as redundant outputs. -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits_fwd2( + f, tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): @@ -662,8 +667,8 @@ def trace_to_subjaxpr_nounits_fwd2( current_name_stack = source_info_util.current_name_stack() with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, tag) - out_tracers, jaxpr, consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] # Which consts (aka residuals) are just forwarded inputs? Check obj id. 
@@ -680,7 +685,7 @@ def trace_to_subjaxpr_nounits_fwd2( if f1 is None and f2 is None] del out_tracers - yield jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env) + return jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env) FreeVar = namedtuple('FreeVar', ['val']) @@ -2066,10 +2071,10 @@ class DynamicJaxprTrace(core.Trace): custom_staging_rules: dict[Primitive, Callable] = {} -@lu.transformation -def _interleave_fun(every_others, *args, **kwargs): +@lu.transformation2 +def _interleave_fun(f, every_others, *args, **kwargs): args_ = [x for pair in zip(args, every_others) for x in pair] - yield (yield (args_, kwargs)) + return f(*args_, **kwargs) # TODO: consider renaming to "lazy_thunk" def _memoize(fn): @@ -2083,18 +2088,19 @@ def _memoize(fn): return out return memoized -@lu.transformation_with_aux -def _jvp_jaxpr_zeros(in_zeros, zero_avals, *primal_tangent_avals): +@lu.transformation_with_aux2 +def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals): in_primals, nz_in_tangents = split_list(primal_tangent_avals, [len(in_zeros)]) symbolic_zeros = map(ad_util.SymbolicZero, zero_avals) tangents = merge_lists(in_zeros, nz_in_tangents, symbolic_zeros) - out = yield (*in_primals, *tangents), {} + out = f(*in_primals, *tangents) n, ragged = divmod(len(out), 2) assert not ragged out_primals, out_tangents = out[:n], out[n:] out_zeros = [type(t) is ad_util.SymbolicZero for t in out_tangents] out_nz_tangents, _ = partition_list(out_zeros, out_tangents) - yield [*out_primals, *out_nz_tangents], out_zeros + store.store(out_zeros) + return [*out_primals, *out_nz_tangents] # TODO(mattjj): remove this DebugInfo and helper functions, replace with # api_util.py versions diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 316fbc077..caa414741 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -690,15 +690,15 @@ def find_replicas( num_global_replicas = global_axis_size * jaxpr_replicas return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas) -@lu.transformation -def _change_argument_ranks(in_axes, out_axes_thunk, *args): +@lu.transformation2 +def _change_argument_ranks(f, in_axes, out_axes_thunk, *args): args = tuple( arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,)) for in_axis, arg in zip(in_axes, args) ) - results = yield (args, {}) + results = f(*args) out_axes = out_axes_thunk() - yield tuple( + return tuple( x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,)) for x, axis in zip(results, out_axes) ) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 08f94c6e8..37d812dec 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -64,6 +64,7 @@ data must be immutable, because it will be stored in function memoization tables from __future__ import annotations from collections.abc import Callable +from functools import partial from typing import Any, NamedTuple import weakref @@ -149,10 +150,11 @@ class WrappedFun: params: extra parameters to pass as keyword arguments to `f`, along with the transformed keyword arguments. 
""" - __slots__ = ("f", "transforms", "stores", "params", "in_type", "debug_info") + __slots__ = ("f", "f_transformed", "transforms", "stores", "params", "in_type", "debug_info") - def __init__(self, f, transforms, stores, params, in_type, debug_info): + def __init__(self, f, f_transformed, transforms, stores, params, in_type, debug_info): self.f = f + self.f_transformed = f_transformed self.transforms = transforms self.stores = stores self.params = params @@ -165,8 +167,14 @@ class WrappedFun: def wrap(self, gen, gen_static_args, out_store) -> WrappedFun: """Add another transform and its store.""" - return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms, - (out_store,) + self.stores, self.params, None, None) + if out_store is None: + return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args), + ((gen, gen_static_args),) + self.transforms, + (out_store,) + self.stores, self.params, None, None) + else: + return WrappedFun(self.f, partial(gen, self.f_transformed, out_store, *gen_static_args), + ((gen, gen_static_args),) + self.transforms, + (out_store,) + self.stores, self.params, None, None) def populate_stores(self, stores): """Copy the values from the `stores` into `self.stores`.""" @@ -175,47 +183,8 @@ class WrappedFun: self_store.store(other_store.val) def call_wrapped(self, *args, **kwargs): - """Calls the underlying function, applying the transforms. - - The positional `args` and keyword `kwargs` are passed to the first - transformation generator. - """ - stack = [] - for (gen, gen_static_args), out_store in zip(self.transforms, self.stores): - gen = gen(*(gen_static_args + tuple(args)), **kwargs) - args, kwargs = next(gen) - stack.append((gen, out_store)) - gen = gen_static_args = out_store = None - - try: - ans = self.f(*args, **dict(self.params, **kwargs)) - except: - # Some transformations yield from inside context managers, so we have to - # interrupt them before reraising the exception. Otherwise they will only - # get garbage-collected at some later time, running their cleanup tasks - # only after this exception is handled, which can corrupt the global - # state. - while stack: - stack.pop()[0].close() - raise - - args = kwargs = None - while stack: - gen, out_store = stack.pop() - try: - ans = gen.send(ans) - except: - # As above does for the first half of the transformation, exceptions - # raised in the second half of the transformation also require us to - # clean up references here. - while stack: - stack.pop()[0].close() - raise - if out_store is not None: - ans, side = ans - out_store.store(side) - - return ans + """Calls the transformed function""" + return self.f_transformed(*args, **kwargs) def __repr__(self): def transform_to_str(x): @@ -234,7 +203,7 @@ class WrappedFun: self.debug_info == other.debug_info) @curry -def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: +def transformation2(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: """Adds one more transformation to a WrappedFun. Args: @@ -244,8 +213,28 @@ def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: """ return fun.wrap(gen, gen_static_args, None) +# Backwards compat only. TODO: deprecate @curry -def transformation_with_aux( +def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: + def gen2(f, *args, **kwargs): + gen_inst = gen(*args, **kwargs) + args_, kwargs_ = next(gen_inst) + return gen_inst.send(f(*args_, **kwargs_)) + return transformation2(gen2, fun, *gen_static_args)() + +# Backwards compat only. 
TODO: deprecate +@curry +def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: + def gen2(f, store, *args, **kwargs): + gen_inst = gen(*args, **kwargs) + args_, kwargs_ = next(gen_inst) + ans, aux = gen_inst.send(f(*args_, **kwargs_)) + store.store(aux) + return ans + return transformation_with_aux2(gen2, fun, *gen_static_args)() + +@curry +def transformation_with_aux2( gen, fun: WrappedFun, *gen_static_args, use_eq_store: bool = False ) -> tuple[WrappedFun, Callable[[], Any]]: """Adds one more transformation with auxiliary output to a WrappedFun.""" @@ -261,8 +250,9 @@ def fun_name(f): def wrap_init(f, params=None) -> WrappedFun: """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" + params_dict = {} if params is None else params params = () if params is None else tuple(sorted(params.items())) - return WrappedFun(f, (), (), params, None, None) + return WrappedFun(f, partial(f, **params_dict), (), (), params, None, None) def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: @@ -270,7 +260,7 @@ def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: if in_type is None: return f _check_input_type(in_type) - return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type, f.debug_info) + return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, in_type, f.debug_info) def _check_input_type(in_type: core.InputType) -> None: # Check that in_type is syntactically well-formed @@ -317,7 +307,7 @@ def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None assert f.debug_info is None if debug_info is None: return f - return WrappedFun(f.f, f.transforms, f.stores, f.params, f.in_type, debug_info) + return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, f.in_type, debug_info) def cache(call: Callable, *, explain: Callable | None = None): @@ -357,9 +347,9 @@ def cache(call: Callable, *, explain: Callable | None = None): cache_clearing_funs.add(memoized_fun.cache_clear) return memoized_fun -@transformation -def hashable_partial(*args): - yield (yield args, {}) +@transformation2 +def hashable_partial(f, *args): + return f(*args) def merge_linear_aux(aux1, aux2): diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index c7bd7dd71..d77ca86c1 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -824,14 +824,13 @@ def debug_print_lowering_rule(ctx, *args, **params): # because they should appear as atomic JAX values to the users. # TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU # inferred by the compiler. 
-@lu.transformation -def wrap_with_transforms(transforms, *args): +@lu.transformation2 +def wrap_with_transforms(f, transforms, *args): new_args = tuple( state_types.TransformedRef(a, t) if t else a for a, t in zip(args, transforms) ) - res = yield new_args, {} - yield res + return f(*new_args) run_scoped_p = jax_core.Primitive("run_scoped") diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index a25d93a35..b4adbadfa 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -97,34 +97,34 @@ def jvp(f, primals, tangents, attr_tangents): out_tangents = tree_unflatten(out_tree(), out_tangents_flat) return out_primals, out_tangents, tangent_attrs_out -@lu.transformation -def _set_attrs(attrs, attr_vals, *args): +@lu.transformation2 +def _set_attrs(f, attrs, attr_vals, *args): for (o, a), x in zip(attrs, attr_vals): jax_setattr(o, a, x) - yield (yield args, {}) + return f(*args) def _jvp(fun: lu.WrappedFun): return jvpfun2(jvp_subtrace2(fun)) -@lu.transformation -def jvpfun2(primals, tangents): +@lu.transformation2 +def jvpfun2(f, primals, tangents): tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] ctx = source_info_util.transform_name_stack('jvp') with ctx: - out_primals, out_tangents, tangent_attrs_out = yield (tag, primals, tangents), {} - yield out_primals, out_tangents, tangent_attrs_out + out_primals, out_tangents, tangent_attrs_out = f(tag, primals, tangents) + return out_primals, out_tangents, tangent_attrs_out -@lu.transformation -def jvp_subtrace2(tag, primals, tangents): +@lu.transformation2 +def jvp_subtrace2(f, tag, primals, tangents): with core.take_current_trace() as parent_trace: trace = ad.JVPTrace(parent_trace, tag) tag.attrs_tracked = [] # attrs written to in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x for x, t in zip(primals, tangents)] with core.set_current_trace(trace): - ans = yield in_tracers, {} + ans = f(*in_tracers) out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) tangent_attrs_out = [] for (obj, name) in tag.attrs_tracked: @@ -133,7 +133,7 @@ def jvp_subtrace2(tag, primals, tangents): if type(tangent) is not ad.Zero: tangent_attrs_out.append((obj, name, tangent)) del tag.attrs_tracked - yield out_primals, out_tangents, tangent_attrs_out + return out_primals, out_tangents, tangent_attrs_out def _setattr_jvp(trace, obj, attr, maybe_tracer): primal, tangent = trace.to_primal_tangent_pair(maybe_tracer) @@ -175,11 +175,12 @@ def _linearize(traceable: lu.WrappedFun, *primals): return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals], jaxpr, consts, attrs()) -@lu.transformation_with_aux -def _split_attrs(*args, **kwargs): - primals, tangents, tangent_attrs = yield args, kwargs +@lu.transformation_with_aux2 +def _split_attrs(f, store, *args, **kwargs): + primals, tangents, tangent_attrs = f(*args, **kwargs) attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs) - yield (primals, tangents, tangent_attr_vals), attrs + store.store(attrs) + return primals, tangents, tangent_attr_vals def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): in_tree, out_tree = io_tree diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index c6d920918..b8acb0d1a 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1040,20 +1040,20 @@ def _convert_jax_impl(impl_jax: Callable, *, return wrapped_tf 
-@lu.transformation -def _interpret_subtrace(in_avals: Sequence[core.ShapedArray], +@lu.transformation2 +def _interpret_subtrace(f, in_avals: Sequence[core.ShapedArray], *in_vals: TfVal): trace = TensorFlowTrace() in_tracers = tuple( TensorFlowTracer(trace, val, aval) for val, aval in zip(in_vals, in_avals)) with core.set_current_trace(trace): - outs = yield in_tracers, {} # type: Sequence[TfVal] + outs = f(*in_tracers) out_tracers: Iterable[TensorFlowTracer] = ( map(trace.to_tf_tracer, outs)) out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = ( tuple((t.val, t.aval) for t in out_tracers)) - yield out_vals_with_avals + return out_vals_with_avals def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal, diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 827e4d01b..75b040a28 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -141,40 +141,43 @@ def jet(fun, primals, series): if not treedef_is_leaf(treedef): raise ValueError(f"term {j} for argument {i} is not an array") - @lu.transformation_with_aux - def flatten_fun_output(*args): - ans = yield args, {} - yield tree_flatten(ans) + @lu.transformation_with_aux2 + def flatten_fun_output(f, store, *args): + ans = f(*args) + ans, tree = tree_flatten(ans) + store.store(tree) + return ans f, out_tree = flatten_fun_output(lu.wrap_init(fun)) out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series) return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms) -@lu.transformation -def jet_fun(order, primals, series): +@lu.transformation2 +def jet_fun(f, order, primals, series): tag = core.TraceTag() - out_primals, out_terms = yield (tag, order, primals, series), {} + out_primals, out_terms = f(tag, order, primals, series) out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s for p, s in zip(out_primals, out_terms)] - yield out_primals, out_terms + return out_primals, out_terms -@lu.transformation -def jet_subtrace(tag, order, primals, series): +@lu.transformation2 +def jet_subtrace(f, tag, order, primals, series): with core.take_current_trace() as parent_trace: trace = JetTrace(tag, parent_trace, order) in_tracers = map(partial(JetTracer, trace), primals, series) with core.set_current_trace(trace): - ans = yield in_tracers, {} + ans = f(*in_tracers) out_primals, out_terms = unzip2(map(trace.to_primal_terms_pair, ans)) - yield out_primals, out_terms + return out_primals, out_terms -@lu.transformation_with_aux -def traceable(in_tree_def, *primals_and_series): +@lu.transformation_with_aux2 +def traceable(f, store, in_tree_def, *primals_and_series): primals_in, series_in = tree_unflatten(in_tree_def, primals_and_series) - primals_out, series_out = yield (primals_in, series_in), {} + primals_out, series_out = f(primals_in, series_in) out_flat, out_tree_def = tree_flatten((primals_out, series_out)) - yield out_flat, out_tree_def + store.store(out_tree_def) + return out_flat class JetTracer(core.Tracer): diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index b8e3daee4..987e461a3 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -47,12 +47,12 @@ zip = safe_zip def ravel_first_arg(f, unravel): return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped -@lu.transformation -def ravel_first_arg_(unravel, y_flat, *args): +@lu.transformation2 +def ravel_first_arg_(f, unravel, y_flat, *args): y = unravel(y_flat) - ans = yield (y,) + args, {} + ans = f(y, *args) ans_flat, _ = ravel_pytree(ans) - yield ans_flat 
+ return ans_flat def interp_fit_dopri(y0, y1, k, dt): # Fit a polynomial to the results of a Runge-Kutta step. diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 3a9446862..c658ddd3a 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1479,15 +1479,15 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, return pe.merge_lists(out_knowns, out_tracers, out_consts) pe.JaxprTrace.process_shard_map = _shard_map_partial_eval -@lu.transformation -def _promote_scalar_residuals(*args, **kwargs): - jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs +@lu.transformation2 +def _promote_scalar_residuals(f, *args, **kwargs): + jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs) which = [f1 is None and f2 is None and not v.aval.shape for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)] jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x for x in out_consts] - yield jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) + return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) def _promote_scalar_residuals_jaxpr(jaxpr, which): @lu.wrap_init @@ -1728,13 +1728,13 @@ def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name): check_rep=False, auto=frozenset()), in_specs, out_specs) -@lu.transformation -def _handle_reshapes(in_axes, out_axes_thunk, *args, **kwargs): +@lu.transformation2 +def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs): args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax), list(args), list(in_axes)) - out = yield args, {} - yield tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), - list(out), list(out_axes_thunk())) + out = f(*args) + return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), + list(out), list(out_axes_thunk())) def _axis_to_spec(axis_name, ax): if isinstance(ax, int): @@ -1855,27 +1855,28 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps) return _match_rep(fun, mesh, out_reps_src, out_reps_dst) -@lu.transformation_with_aux -def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args): +@lu.transformation_with_aux2 +def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args): with core.take_current_trace() as parent: tag = core.TraceTag() t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh) in_tracers = map(partial(RewriteTracer, t), in_reps, args) with core.set_current_trace(t): - ans = yield in_tracers, {} + ans = f(*in_tracers) out_vals, out_reps = unzip2(map(t.to_val_rep_pair, ans)) del t, in_tracers, ans - yield out_vals, out_reps + store.store(out_reps) + return out_vals -@lu.transformation -def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args): - outs = yield args, {} +@lu.transformation2 +def _match_rep(f, mesh, out_reps_src_, out_reps_dst_, *args): + outs = f(*args) out_reps_src = out_reps_src_() if callable(out_reps_src_) else out_reps_src_ out_reps_dst = out_reps_dst_() if callable(out_reps_dst_) else out_reps_dst_ _check_reps2(mesh, out_reps_dst, out_reps_src) outs = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst else x for x, src, dst in zip(outs, out_reps_src, out_reps_dst)] - yield outs + return outs # TODO(mattjj): caching def _replication_rewrite_match( @@ -1901,16 +1902,17 @@ def 
_replication_rewrite_nomatch( jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) return core.ClosedJaxpr(jaxpr_, consts), out_rep() -@lu.transformation_with_aux -def _rewrite_subtrace(tag, mesh, in_reps, *in_vals): +@lu.transformation_with_aux2 +def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals): with core.take_current_trace() as parent_trace: assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) t = RewriteTrace(parent_trace, tag, mesh) in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) with core.set_current_trace(t): - outs = yield in_tracers, {} - ans = unzip2(map(t.to_val_rep_pair, outs)) - yield ans + outs = f(*in_tracers) + out_vals, out_reps = unzip2(map(t.to_val_rep_pair, outs)) + store.store(out_reps) + return out_vals def _rewrite_bwd(bwd, mesh, in_reps, reps_dst): def new_bwd(*args): diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 7c5a96650..050d0a5e0 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -340,16 +340,17 @@ class SparseTrace(core.Trace): with core.set_current_trace(self): return fun.call_wrapped(*tracers) -@lu.transformation_with_aux -def sparsify_subtrace(tag, spenv, spvalues, *bufs): +@lu.transformation_with_aux2 +def sparsify_subtrace(f, store, tag, spenv, spvalues, *bufs): with core.take_current_trace() as parent: trace = SparseTrace(parent, tag, spenv) with core.set_current_trace(trace): in_tracers = [SparseTracer(trace, spvalue=spvalue) for spvalue in spvalues] - outs = yield in_tracers, {} + outs = f(*in_tracers) out_traces = [trace.to_sparse_tracer(out) for out in outs] buffers = spenv._buffers - yield buffers, [out._spvalue for out in out_traces] + store.store([out._spvalue for out in out_traces]) + return buffers def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]): tag = core.TraceTag() diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py index 74c52dddb..8b80d033f 100644 --- a/jax/extend/linear_util.py +++ b/jax/extend/linear_util.py @@ -22,5 +22,7 @@ from jax._src.linear_util import ( merge_linear_aux as merge_linear_aux, transformation as transformation, transformation_with_aux as transformation_with_aux, + transformation2 as transformation2, + transformation_with_aux2 as transformation_with_aux2, wrap_init as wrap_init, ) diff --git a/tests/util_test.py b/tests/util_test.py index 5f07d2f50..5e99fff4b 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -42,8 +42,8 @@ class UtilTest(jtu.JaxTestCase): assert not kwargs return tuple(a * factor for a in args) - @lu.transformation_with_aux - def kw_to_positional(factor, *args, **kwargs): + @lu.transformation_with_aux2 + def kw_to_positional(f, store, factor, *args, **kwargs): """A transformation with auxiliary output. Turns all keyword parameters into positional ones. 
@@ -55,12 +55,12 @@ class UtilTest(jtu.JaxTestCase): kwargs_keys = kwargs.keys() new_args = tuple(kwargs[k] for k in kwargs_keys) new_kwargs = dict(factor=factor) - results = yield args + new_args, new_kwargs # Yield transformed (args, kwargs) + results = f(*(args + new_args), **new_kwargs) # Yield transformed (args, kwargs) # Assume results correspond 1:1 to the args + new_args assert len(results) == len(args) + len(new_args) aux_output = len(new_args) - yield (results[0:len(args)], - dict(zip(kwargs_keys, results[len(args):]))), aux_output + store.store(aux_output) + return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):]))) wf = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`. wf, out_thunk = kw_to_positional(wf, 2)
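
For illustration, a minimal before/after sketch of the transformation API this patch migrates to, mirroring the flatten_fun_output helpers changed above. The names flatten_output_old and flatten_output_new are hypothetical and the snippet is not part of the patch; it assumes the post-patch jax.extend.linear_util exports shown in the diff.

import jax
from jax.extend import linear_util as lu

# Old style: a generator that yields (args, kwargs) to call the wrapped
# function, then yields (result, aux). Still accepted via the compat shim.
@lu.transformation_with_aux
def flatten_output_old(*args):
  ans = yield args, {}
  yield jax.tree_util.tree_flatten(ans)

# New style: the transformation receives the wrapped callable `f` and a
# `store` for the auxiliary output, and uses ordinary calls and returns.
@lu.transformation_with_aux2
def flatten_output_new(f, store, *args):
  ans = f(*args)
  flat, treedef = jax.tree_util.tree_flatten(ans)
  store.store(treedef)   # auxiliary output, retrieved via the returned thunk
  return flat

g, out_tree = flatten_output_new(lu.wrap_init(lambda x: {"y": x, "z": 2 * x}))
print(g.call_wrapped(3.0))  # [3.0, 6.0]
print(out_tree())           # treedef of {'y': *, 'z': *}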