add saved_input_vjp (alias si_vjp) to let users pass primal inputs to the bwd pass

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
Matthew Johnson 2025-03-09 01:26:40 +00:00
parent 6978f35293
commit 2bb7dbaa32
3 changed files with 141 additions and 1 deletions

View File

@ -25,6 +25,7 @@ from __future__ import annotations
import atexit
import collections
from collections.abc import Callable, Hashable, Iterable, Sequence
import dataclasses
from functools import partial, lru_cache
import inspect
import math
@ -41,7 +42,8 @@ from jax._src import stages
from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix,
prefix_errors, generate_key_paths, tree_flatten_with_path)
prefix_errors, generate_key_paths, tree_flatten_with_path,
equality_errors_pytreedef)
from jax._src import config
from jax._src import core
from jax._src import dispatch
@ -2031,6 +2033,82 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux)
def saved_input_vjp(f: Callable, which: Sequence[bool], *primals,
                    allow_unused: bool = True, allow_opaque: bool = True):
  """Compute ``f(*primals)`` and a VJP function that is re-fed selected inputs.

  ``f`` is linearized at ``primals``. Residuals of the linearization that are
  the very same objects (compared by ``id``) as leaves of the inputs selected
  by ``which`` are not stored in the returned function; the caller supplies
  those inputs again when invoking it. All remaining residuals are stored as
  "opaque" residuals.

  Args:
    f: the function to differentiate.
    which: one boolean per positional entry of ``primals``; ``True`` marks an
      input that will be passed again to the returned ``f_vjp``.
    *primals: positional inputs (pytrees) at which ``f`` is linearized.
    allow_unused: if ``False``, raise when a leaf of a marked input is not
      among the residuals.
    allow_opaque: if ``False``, raise when any residual is not a leaf of a
      marked input.

  Returns:
    A pair ``(out_primals, f_vjp)``. ``f_vjp`` is called as
    ``f_vjp(ct, *saved)`` where ``saved`` has the structure of
    ``tuple(x for x, w in zip(primals, which) if w)``; its result has the
    pytree structure of ``primals``.

  Raises:
    ValueError: if ``len(which) != len(primals)``, or if the ``allow_unused``
      or ``allow_opaque`` checks fail.
  """
  if len(which) != len(primals):
    raise ValueError(
        "length of 'which' argument must equal the number of primal input values, "
        f"but got {len(which)=} and {len(primals)=}")

  dbg = debug_info("saved_input_vjp", f, primals, {})
  fun = lu.wrap_init(f, debug_info=dbg)
  primals_flat, in_tree = tree_flatten(primals)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
  out_primals_flat, _, jaxpr, residuals = ad.linearize(fun, *primals_flat)

  primals_filt, filt_tree = tree_flatten(tuple(p for w, p in zip(which, primals) if w))
  # Residuals are matched to saved inputs by object identity, not by value.
  id_map = {id(x): i for i, x in enumerate(primals_filt)}
  opaque_residuals: list = []
  res_spec = []
  for r in residuals:
    if id(r) in id_map:
      res_spec.append(RSpec(id_map[id(r)], True))
    else:
      # Index of the slot this residual is about to occupy.
      res_spec.append(RSpec(len(opaque_residuals), False))
      opaque_residuals.append(r)
  f_vjp = Partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, out_tree(),
                  jaxpr, opaque_residuals)

  if not allow_unused:
    res_ids = {id(r) for r in residuals}
    # Walk the leaves (not the top-level arguments) so that pytree-valued
    # inputs are handled; `flat_idx` is the position of the leaf among all
    # flattened primals, used to index `dbg.arg_names`.
    unused = []
    flat_idx = 0
    for x, w in zip(primals, which):
      for leaf in tree_leaves(x):
        if w and id(leaf) not in res_ids:
          unused.append((flat_idx, core.get_aval(leaf)))
        flat_idx += 1
    if unused:
      if len(unused) == 1:
        (i, a), = unused
        start, was = "an input value", "was"
        msg = f" {dbg.arg_names[i]} of type {a.str_short()}"
      else:
        start, was = "multiple input values", "were"
        msg = "\n" + "\n".join(f" * {dbg.arg_names[i]} of type {a.str_short()}"
                               for i, a in unused)
      raise ValueError(f"with {allow_unused=}, {start} marked to be saved {was} "
                       f"not used by the backward pass:{msg}")

  if not allow_opaque and opaque_residuals:
    msg = ", ".join(core.get_aval(x).str_short() for x in opaque_residuals)
    raise ValueError(f"with {allow_opaque=}, the backward pass requires opaque "
                     f"(non-input) residuals: {msg}")

  out_primals = tree_unflatten(out_tree(), out_primals_flat)
  return out_primals, f_vjp
def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, jaxpr,
                        opaque_residuals, ct, *saved_primals):
  """Run the backward pass using caller-supplied inputs as residuals.

  The first six parameters are bound ahead of time by ``saved_input_vjp``;
  callers provide only ``ct`` and ``*saved_primals``.

  Args:
    res_spec: one ``RSpec`` per residual, saying whether to read it from the
      flattened ``saved_primals`` or from ``opaque_residuals``, and at which
      index.
    filtered_tree: pytree structure expected of ``saved_primals``.
    in_tree: pytree structure used to unflatten the returned cotangents.
    out_tree: pytree structure expected of ``ct``.
    jaxpr: the jaxpr handed to ``ad.backward_pass``.
    opaque_residuals: residuals that are not supplied by the caller.
    ct: the output cotangent.
    *saved_primals: the inputs that were marked to be saved.

  Returns:
    The input cotangents, unflattened with ``in_tree``.

  Raises:
    ValueError: if ``saved_primals`` or ``ct`` does not have the expected
      pytree structure.
  """
  primals_filtered, filtered_tree_ = tree_flatten(saved_primals)
  if filtered_tree != filtered_tree_:
    raise ValueError(
        "inputs passed to f_vjp must be a tuple of (pytrees of) "
        "arrays with the same structure as\n"
        " tuple(x for x, w in zip(inputs, which) if w)\n"
        "given the original call\n"
        " _, f_vjp = saved_input_vjp(f, which, *inputs, ...)\n"
        "but the structures differ:\n" +
        "\n".join(f" * inputs{keystr(path)} was a {thing1} in the original "
                  f"call, but a {thing2} here, so {explanation}"
                  for path, thing1, thing2, explanation
                  in equality_errors_pytreedef(filtered_tree, filtered_tree_)))
  residuals = [primals_filtered[i.idx] if i.primal else opaque_residuals[i.idx]
               for i in res_spec]
  # Every jaxpr input is marked undefined for the backward pass.
  dummy_args = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars]
  cts_flat, out_tree_ = tree_flatten(ct)
  # `ct` comes from the caller, so validate it with a real exception rather
  # than an `assert` (which disappears under `python -O`).
  if out_tree_ != out_tree:
    raise ValueError(
        "cotangent passed to f_vjp must have the same pytree structure as the "
        "output of the function passed to saved_input_vjp, but got "
        f"{out_tree_} where {out_tree} was expected")
  arg_cts = ad.backward_pass(jaxpr, True, residuals, dummy_args, cts_flat)
  return tree_unflatten(in_tree, arg_cts)
@dataclasses.dataclass(frozen=True)
class RSpec:
  """Immutable record saying where one backward-pass residual is found.

  When ``primal`` is true, ``idx`` is a position in the flattened list of
  saved inputs; otherwise ``idx`` is a position in the list of opaque
  residuals.
  """

  # Position within whichever list `primal` selects.
  idx: int

  # True: read from the saved inputs. False: read from the opaque residuals.
  primal: bool
# Short alias: `si_vjp` is the same function object as `saved_input_vjp`.
si_vjp = saved_input_vjp
def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
"""Transpose a function that is promised to be linear.

View File

@ -19,6 +19,10 @@ from jax.experimental.x64_context import (
enable_x64 as enable_x64,
disable_x64 as disable_x64,
)
from jax._src.api import (
saved_input_vjp as saved_input_vjp,
si_vjp as si_vjp
)
from jax._src.callback import (
io_callback as io_callback
)

View File

@ -11496,5 +11496,63 @@ class OverrideLoweringTest(jtu.JaxTestCase):
self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir)
class InputSavedVJPTest(jtu.JaxTestCase):
  """Tests for ``api.si_vjp``, whose VJP function is re-fed saved inputs."""

  def test_basic(self):
    def f(x, y):
      return x * y

    # Both inputs are marked, so both are handed back to the VJP function.
    inputs = (2., 3.)
    out, f_vjp = api.si_vjp(f, [True, True], *inputs)
    input_cts = f_vjp(1., *inputs)
    self.assertAllClose(out, 6.)
    self.assertAllClose(input_cts, (3., 2.))

  def test_basic_unused(self):
    fn = jnp.sin
    inputs = (3.,)

    out, f_vjp = api.si_vjp(fn, [True], *inputs)
    (x_ct,) = f_vjp(1., *inputs)
    self.assertAllClose(out, jnp.sin(3.))
    self.assertAllClose(x_ct, jnp.cos(3.))

    # The same call is expected to be rejected once unused saved inputs are
    # disallowed.
    with self.assertRaisesRegex(Exception, "not used by the backward pass: x"):
      _ = api.si_vjp(fn, [True], *inputs, allow_unused=False)

  def test_basic_opaque(self):
    fn = jnp.sin
    inputs = (3.,)

    # Expected to be rejected once opaque residuals are disallowed.
    with self.assertRaisesRegex(Exception, "the backward pass requires opaque"):
      _ = api.si_vjp(fn, [True], *inputs, allow_opaque=False)

  def test_basic_pytree_error(self):
    def f(x):
      return [x['hi'] * x['bye']]

    out, f_vjp = api.si_vjp(f, [True], {'hi': 2., 'bye': 3.})
    (input_ct,) = f_vjp([1.], {'hi': 2., 'bye': 3.})
    self.assertAllClose(out, [6.])
    self.assertAllClose(input_ct, {'hi': 3., 'bye': 2.})

    # A saved input with a missing key must be reported as a structure error.
    with self.assertRaisesRegex(ValueError, "but the structures differ"):
      f_vjp(1., {'hi': 2.})

  def test_fsdp(self):
    # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp"
    def f2(x, w):
      x = 1. * x
      x = x @ w
      x = 2. * x
      return x

    acts = jnp.ones((3, 4))
    weights = jnp.ones((4, 4))

    # Only the second argument is marked as saved and passed back later.
    out, f2_sivjp = api.si_vjp(f2, [False, True], acts, weights)
    out_grad = jnp.ones_like(out)
    acts_grad, weights_grad = f2_sivjp(out_grad, weights)
    self.assertAllClose(acts_grad, 2. * out_grad @ weights.T)
    self.assertAllClose(weights_grad, 2. * acts.T @ out_grad)
# When run as a script, hand control to absltest using the JAX test loader.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())