From 4a93c8b30c4dafd1c53ab788aba1122b5e56e458 Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey
Date: Wed, 5 Mar 2025 10:21:52 -0800
Subject: [PATCH] Reverts 342cb7b99a09180472823a33c7cdad8a8db77875

PiperOrigin-RevId: 733782497
---
 CHANGELOG.md                      |   3 -
 jax/_src/custom_derivatives.py    | 253 ++++++++++++++++++++++++++----
 jax/custom_derivatives.py         |   1 +
 jax/experimental/jax2tf/jax2tf.py |  12 +-
 tests/api_test.py                 |  78 +++------
 5 files changed, 250 insertions(+), 97 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 65f1dafa8..86d0cab0f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,9 +23,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     true, matching the current behavior. If set to false, JAX does not need
     to emit code clamping negative indices, which improves code size.
 
-* Breaking changes
-  * The ``jax.custom_derivatives.remat_opt_p`` helper primitive was removed.
-
 ## jax 0.5.1 (Feb 24, 2025)
 
 * New Features
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 579086f36..32856106a 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -16,7 +16,7 @@ from __future__ import annotations
 
 from collections.abc import Callable, Sequence
 import dataclasses
-from functools import update_wrapper, reduce, partial
+from functools import update_wrapper, reduce, partial, wraps
 from typing import Any, Generic, TypeVar
 
 from jax._src import config
@@ -32,7 +32,6 @@ from jax._src.ad_util import (
 from jax._src.api_util import (
     argnums_partial, flatten_fun_nokwargs, resolve_kwargs, prepend_static_args,
     debug_info)
-from jax._src.custom_dce import custom_dce
 from jax._src.errors import UnexpectedTracerError
 from jax._src.state.types import AbstractRef
 from jax._src.interpreters import ad
@@ -658,12 +657,10 @@ class custom_vjp(Generic[ReturnValue]):
     # TODO(necula): figure out how to construct the debug_bwd args
     debug_bwd = debug_info("custom_vjp bwd", self.bwd, args, {})
     if self.optimize_remat:
-      if self.symbolic_zeros:
-        # TODO(dfm): This probably shouldn't be too hard to support.
-        raise NotImplementedError(
-            "remat optimization for custom_vjp does not support symbolic zeros")
       fwd = optimize_remat_of_custom_vjp_fwd(
-          self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums)
+          self.fun, debug_fun, self.fwd, debug_fwd,
+          nondiff_argnums=self.nondiff_argnums,
+          symbolic_zeros=self.symbolic_zeros)
     else:
       fwd = self.fwd
     if config.enable_custom_vjp_by_custom_transpose.value:
@@ -1574,31 +1571,229 @@ custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr")
 # simpler, but it would be worth revisiting this.
 def optimize_remat_of_custom_vjp_fwd(
     fun: Callable[..., ReturnValue],
+    debug_fun: core.DebugInfo,
     fwd: Callable[..., tuple[ReturnValue, Any]],
+    debug_fwd: core.DebugInfo,
     nondiff_argnums: Sequence[int] = (),
+    symbolic_zeros: bool = False,
 ) -> Callable[..., tuple[ReturnValue, Any]]:
-  wrapped_fwd = custom_dce(
-      # It might seem like we don't need this lambda, but there are some real
-      # world use cases where the signature of `fwd` is wrong, and we shouldn't
-      # error out when resolving the arguments in those cases. This is fine,
-      # because the arguments have already been resolved in custom_vjp.
-      lambda *args: fwd(*args),  # pylint: disable=unnecessary-lambda
-      static_argnums=nondiff_argnums,
-  )
+  if symbolic_zeros:
+    # TODO(dfm): This probably shouldn't be too hard to support.
+    raise NotImplementedError(
+        "remat optimization for custom_vjp does not support symbolic zeros")
 
-  @wrapped_fwd.def_dce
-  def _(*args):
-    static_args, used_outs, args = split_list(args, [len(nondiff_argnums), 1])
-    static_args_iter = iter(static_args)
-    args_iter = iter(args)
-    nondiff_argnums_ = set(nondiff_argnums)
-    fun_args = [
-        next(static_args_iter) if i in nondiff_argnums_ else next(args_iter)
-        for i in range(len(static_args) + len(args))]
-    used_outs, = used_outs
-    _, used_res = used_outs
-    if any(tree_leaves(used_res)):
-      return fwd(*fun_args)
-    return fun(*fun_args), None
+  @wraps(fwd)
+  def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
+    # TODO(dfm): This initial logic is duplicated from custom_vjp.__call__
+    # above and it would be good to consolidate it.
+    fwd_name = debug_fwd.func_name if debug_fwd else str(fwd)
+    # Note: we use `fun` instead of `fwd` here for consistency with
+    # custom_vjp.__call__ above.
+    args = resolve_kwargs(fun, args, kwargs)
+    if nondiff_argnums:
+      for i in nondiff_argnums: _check_for_tracers(args[i])
+      nondiff_argnums_ = set(nondiff_argnums)
+      dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums_]
+      f_, dyn_args = argnums_partial(lu.wrap_init(fun, debug_info=debug_fun),
+                                     dyn_argnums,
+                                     args, require_static_args_hashable=False)
+      fwd_, _ = argnums_partial(lu.wrap_init(fwd, debug_info=debug_fwd),
+                                dyn_argnums, args,
+                                require_static_args_hashable=False)
+    else:
+      f_, dyn_args = lu.wrap_init(fun, debug_info=debug_fun), args
+      fwd_ = lu.wrap_init(fwd, debug_info=debug_fwd)
+    args_flat, in_tree = tree_flatten(dyn_args)
+    flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
+    flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False,
+                                       debug_fun, debug_fwd, in_tree, out_type)
+    flat_fwd = _fix_fwd_args(flat_fwd)
+
+    in_avals = [core.get_aval(x) for x in args_flat]
+    fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals)
+    fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr))
+    prim_tree, res_tree = out_trees()
+    num_res = res_tree.num_leaves
+
+    if fwd_jaxpr.effects:
+      raise NotImplementedError(
+          "remat optimization for custom_vjp does not support forward "
+          f"functions with side effects, but {fwd_name} has the following "
+          f"effects: {fwd_jaxpr.effects}")
+
+    @pe._memoize
+    def fun_jaxpr_thunk():
+      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
+      return jaxpr, consts
+
+    out_flat = remat_opt_p.bind(*consts, *args_flat,
+                                num_consts=len(consts),
+                                num_res=num_res,
+                                fwd_jaxpr=fwd_jaxpr,
+                                fun_jaxpr_thunk=fun_jaxpr_thunk)
+    res, out_flat = split_list(out_flat, [num_res])
+    out_tree = treedef_tuple((prim_tree, res_tree))
+    return tree_unflatten(out_tree, (*out_flat, *res))
 
   return wrapped_fwd
+
+@lu.transformation2
+def _fix_fwd_args(f, *args):
+  args = [(x, True) for x in args]
+  args = [x for pair in args for x in pair]
+  return f(*args)
+
+def _remat_opt_impl(
+    *args,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  del num_consts, num_res, fun_jaxpr_thunk  # unused
+  return core.jaxpr_as_fun(fwd_jaxpr)(*args)
+
+def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_):
+  del args
+  return fwd_jaxpr.out_avals, fwd_jaxpr.effects
+
+def _remat_opt_vmap(
+    axis_data, args, in_dims,
+    *,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
+          else x for x, d in zip(args, in_dims)]
+  in_batched = [d is not not_mapped for d in in_dims]
+  batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
+      fwd_jaxpr, axis_data, in_batched, False)
+  extra_consts = batched_fwd_jaxpr.consts
+  batched_fwd_jaxpr = pe.close_jaxpr(
+      pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr))
+  out_dims = [0 if b else not_mapped for b in out_batched]
+
+  _, prim_batched = split_list(in_batched, [num_consts])
+
+  @pe._memoize
+  def batched_fun_jaxpr_thunk():
+    fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
+    batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
+        fun_jaxpr, axis_data, prim_batched, False)
+    return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts
+
+  batched_outs = remat_opt_p.bind(*extra_consts, *args,
+                                  num_consts=num_consts + len(extra_consts),
+                                  num_res=num_res,
+                                  fwd_jaxpr=batched_fwd_jaxpr,
+                                  fun_jaxpr_thunk=batched_fun_jaxpr_thunk)
+
+  return batched_outs, out_dims
+
+def _remat_opt_jvp(
+    primals,
+    tangents,
+    *,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  consts, primals = split_list(primals, [num_consts])
+  consts_dot, tangents = split_list(tangents, [num_consts])
+  # Tangents must be instantiated in case we end up DCEing later.
+  tangents = map(ad.instantiate_zeros, tangents)
+  consts_nz = [not isinstance(t, Zero) for t in consts_dot]
+  consts_dot = [c for nz, c in zip(consts_nz, consts_dot) if nz]
+  in_nz = consts_nz + [True] * len(tangents)
+  fwd_jaxpr_jvp_, out_nz = ad.jvp_jaxpr(fwd_jaxpr, in_nz, True)
+  num_out = len(out_nz) - num_res
+  fwd_jaxpr_jvp_ = ad.rearrange_binders(
+      fwd_jaxpr_jvp_, [num_consts, len(primals)],
+      [len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out])
+  fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr))
+
+  # @pe._memoize
+  def fun_jvp_jaxpr_thunk():
+    fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
+    in_nz = [True] * len(primals)
+    fun_jvp_jaxpr, _ = ad.jvp_jaxpr(fun_jaxpr, in_nz, True)
+    return fun_jvp_jaxpr.jaxpr, fun_jvp_jaxpr.consts
+
+  new_num_consts = len(fwd_jaxpr_jvp_.consts) + num_consts + len(consts_dot)
+  outs = remat_opt_p.bind(*fwd_jaxpr_jvp_.consts, *consts, *consts_dot,
+                          *primals, *tangents, num_consts=new_num_consts,
+                          num_res=2 * num_res, fwd_jaxpr=fwd_jaxpr_jvp,
+                          fun_jaxpr_thunk=fun_jvp_jaxpr_thunk)
+  res, res_dot, outs, outs_dot = split_list(outs, [num_res, num_res, num_out])
+  return (*res, *outs), (*res_dot, *outs_dot)
+
+def _remat_opt_transpose(
+    cts, *args,
+    num_consts: int,
+    num_res: int,
+    fwd_jaxpr: core.ClosedJaxpr,
+    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
+):
+  # TODO(dfm): It shouldn't be too hard to implement this as needed in the
+  # future.
+  raise NotImplementedError(
+      "remat optimization for custom_vjp does not support higher-order AD")
+
+def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
+  if not any(used_outs) and not pe.has_effects(eqn):
+    return [False] * len(eqn.invars), None
+  used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]])
+  outvars = [v for used, v in zip(used_outs, eqn.outvars) if used]
+  if any(used_res):
+    # If any of the residuals are used, we still need to run fwd at this point,
+    # but we may end up DCEing again in the future, so we must instantiate all
+    # the input primals.
+    instantiate = [False] * eqn.params["num_consts"]
+    instantiate += [True] * (len(eqn.invars) - eqn.params["num_consts"])
+    new_jaxpr, used_ins = pe.dce_jaxpr(eqn.params["fwd_jaxpr"].jaxpr, used_outs,
+                                       instantiate=instantiate)
+    assert not new_jaxpr.constvars
+    closed_jaxpr = pe.close_jaxpr(new_jaxpr)
+    invars = [v for used, v in zip(used_ins, eqn.invars) if used]
+    new_params = dict(eqn.params)
+    new_num_consts = sum(split_list(used_ins, [eqn.params["num_consts"]])[0])
+    new_params["num_consts"] = new_num_consts
+    new_params["fwd_jaxpr"] = closed_jaxpr
+    new_params["num_res"] = sum(used_res)
+    new_eqn = pe.new_jaxpr_eqn(
+        invars, outvars, remat_opt_p, new_params, closed_jaxpr.effects,
+        eqn.source_info, eqn.ctx)
+    return used_ins, new_eqn
+  else:
+    # If none of the residuals are used, we run the primal computation instead.
+    # At this point we drop this custom DCE behavior, but since the primal might
+    # have different consts than fwd, we build a new JaxprEqn with a closed_call
+    # primitive.
+    fun_jaxpr, consts = eqn.params["fun_jaxpr_thunk"]()
+    new_jaxpr, used_consts, used_ins = pe.dce_jaxpr_consts(fun_jaxpr, used_prims)
+    consts = [c for used, c in zip(used_consts, consts) if used]
+    closed_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
+    _, invars = split_list(eqn.invars, [eqn.params["num_consts"]])
+    invars = [v for used, v in zip(used_ins, invars) if used]
+    new_eqn = pe.new_jaxpr_eqn(
+        invars, outvars, core.closed_call_p, dict(call_jaxpr=closed_jaxpr),
+        closed_jaxpr.effects, eqn.source_info, eqn.ctx)
+    used_ins = [False] * eqn.params["num_consts"] + used_ins
+    return used_ins, new_eqn
+
+remat_opt_p = core.Primitive("remat_opt")
+remat_opt_p.multiple_results = True
+remat_opt_p.def_impl(_remat_opt_impl)
+remat_opt_p.def_effectful_abstract_eval(_remat_opt_abstract_eval)
+xla.register_initial_style_primitive(remat_opt_p)
+mlir.register_lowering(remat_opt_p, mlir.lower_fun(
+    _remat_opt_impl, multiple_results=True))
+
+
+batching.fancy_primitive_batchers[remat_opt_p] = _remat_opt_vmap
+ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp
+ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose
+pe.dce_rules[remat_opt_p] = _remat_opt_dce
diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py
index 0b0c8621e..3628ae4aa 100644
--- a/jax/custom_derivatives.py
+++ b/jax/custom_derivatives.py
@@ -30,6 +30,7 @@ from jax._src.custom_derivatives import (
   custom_vjp_primal_tree_values as custom_vjp_primal_tree_values,
   CustomVJPPrimal as CustomVJPPrimal,
   linear_call as linear_call,
+  remat_opt_p as remat_opt_p,
 )
 
 from jax._src.ad_util import (
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 8f39f53ea..d58a1bb0d 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -45,7 +45,6 @@ from jax._src import api
 from jax._src import api_util
 from jax._src import config
 from jax._src import core
-from jax._src import custom_dce
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import linear_util as lu
@@ -3474,14 +3473,15 @@ def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
 tf_impl[ad.custom_lin_p] = _custom_lin
 
 
-def _custom_dce(*args: TfVal, num_consts: int, fun_jaxpr: core.ClosedJaxpr,
-                dce_jaxpr_thunk: Callable) -> Sequence[TfVal]:
-  del num_consts, dce_jaxpr_thunk
-  return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_dce_call",
-                          fresh_constant_cache=False)
+def _remat_opt(*args: TfVal, num_consts: int, num_res: int,
+               fwd_jaxpr: core.ClosedJaxpr,
+               fun_jaxpr_thunk: Callable) -> Sequence[TfVal]:
+  del num_consts, num_res, fun_jaxpr_thunk
+  return _interpret_jaxpr(fwd_jaxpr, *args, extra_name_stack="remat_opt",
+                          fresh_constant_cache=False)
 
 
-tf_impl[custom_dce.custom_dce_p] = _custom_dce
+tf_impl[custom_derivatives.remat_opt_p] = _remat_opt
 
 
 PartitionsOrReplicated = Union[tuple[int, ...], None]
diff --git a/tests/api_test.py b/tests/api_test.py
index a5e441259..543335529 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -9599,7 +9599,10 @@ class CustomVJPTest(jtu.JaxTestCase):
       return np.array([2.0])*x*x/np.array([1.0]), (x,)
 
     x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
+        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
+
     self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x)  # Shouldn't hit custom DCE
     self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x)  # Should be DCEed
 
@@ -9609,7 +9612,9 @@ class CustomVJPTest(jtu.JaxTestCase):
     def fwd(x):
       return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,)
 
     x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
+        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
     self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x)
     self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x)
@@ -9620,7 +9625,9 @@ class CustomVJPTest(jtu.JaxTestCase):
       return x*x, (x,)
 
     x = jnp.linspace(0, 5.0, 10)
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}),
+        fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {}))
 
     def g(x):
       return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x)
@@ -9634,7 +9641,9 @@ class CustomVJPTest(jtu.JaxTestCase):
     def fwd_(x):
       return x*x, (x,)
 
-    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(fun, fwd_)
+    fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd(
+        fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}),
+        fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {}))
     calc = jax.jvp(fwd, (3.2,), (1.0,))
     expected = jax.jvp(fwd_, (3.2,), (1.0,))
     self.assertAllClose(calc, expected)
@@ -9731,55 +9740,6 @@ class CustomVJPTest(jtu.JaxTestCase):
     x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5)
     jax.jit(jax.vmap(jax.grad(f)))(x, y)  # Doesn't error
 
-  def test_optimize_remat_nondiff_argnums(self):
-    @partial(jax.custom_vjp, nondiff_argnums=(2,))
-    def f(x, y, fun):
-      return fun(x, y)
-
-    def f_fwd(x, y, fun):
-      del fun
-      return jnp.cos(x) * y, (jnp.cos(x), jnp.sin(x), y)
-
-    def f_bwd(fun, res, g):
-      del fun
-      cos_x, sin_x, y = res
-      return (cos_x * g * y, sin_x * g)
-
-    def fun(x, y):
-      return jnp.sin(x) * y
-
-    f.defvjp(f_fwd, f_bwd, optimize_remat=True)
-    x, y = 0.5, 0.1
-    res = jax.value_and_grad(lambda *args: f(*args, fun))(x, y)[0]
-    self.assertAllClose(res, f_fwd(x, y, fun)[0])
-    res = jax.jit(lambda *args: jax.value_and_grad(
-        lambda *args: f(*args, fun))(*args)[0])(x, y)
-    self.assertAllClose(res, fun(x, y))
-
-  def test_optimize_remat_incorrect_signature(self):
-    def f_(x, y):
-      return jnp.sin(x) * y
-
-    @jax.custom_vjp
-    def f(x, y):
-      return f_(x, y)
-
-    def wrong_signature(x, y, z):
-      self.fail("wrong_signature should not be called")
-
-    @functools.wraps(wrong_signature)
-    def f_fwd(x, y):
-      return f(x, y), (jnp.cos(x), jnp.sin(x), y)
-
-    def f_bwd(res, g):
-      cos_x, sin_x, y = res
-      return (cos_x * g * y, sin_x * g)
-
-    f.defvjp(f_fwd, f_bwd, optimize_remat=True)
-    x, y = 3.2, 1.0
-    self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y))
-
-
   def test_dce(self):
     @jax.custom_vjp
     def f(x, y):
@@ -10508,20 +10468,20 @@ class CustomDceTest(jtu.JaxTestCase):
     self.assertAllClose(v, jnp.tan(3.2)**2)
 
   def test_static_argnums(self):
-    @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(1,))
-    def g(x, f):
+    @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,))
+    def g(f, x):
       return f(x), 10 * f(x)
 
     @g.def_dce
     def g_dce(f, used_outs, x):
       # note: static_argnums are always passed first
       self.assertTrue(callable(f))
-      return [2 * v if used else None for used, v in zip(used_outs, g(x, f))]
+      return [2 * v if used else None for used, v in zip(used_outs, g(f, x))]
 
     x = 1.1234
     f = lambda x: jnp.exp(x)
-    expected = g(x, f)
-    self.assertAllClose(jax.jit(lambda x: g(x, f)[0])(x), 2 * expected[0])
-    self.assertAllClose(jax.jit(lambda x: g(x, f)[1])(x), 2 * expected[1])
+    expected = g(f, x)
+    self.assertAllClose(jax.jit(lambda x: g(f, x)[0])(x), 2 * expected[0])
+    self.assertAllClose(jax.jit(lambda x: g(f, x)[1])(x), 2 * expected[1])
 
   def test_shape_mismatch_error(self):
     @jax.experimental.custom_dce.custom_dce
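
Usage sketch: the user-facing entry point this revert restores is the optimize_remat=True flag on custom_vjp.defvjp, which routes the forward rule through optimize_remat_of_custom_vjp_fwd and the remat_opt_p primitive above. The minimal example below mirrors the tests in tests/api_test.py; the concrete functions and values are illustrative only.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  # Residuals saved for the backward pass; with optimize_remat=True a
  # residual-only forward computation can be dead-code-eliminated when
  # only the primal output is needed.
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd, optimize_remat=True)

x, y = 3.2, 1.0
print(jax.value_and_grad(f)(x, y))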