mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
add saved_input_vjp (alias si_vjp) to let the user pass primal inputs to the bwd pass
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
parent
6978f35293
commit
2bb7dbaa32
@ -25,6 +25,7 @@ from __future__ import annotations
|
||||
import atexit
|
||||
import collections
|
||||
from collections.abc import Callable, Hashable, Iterable, Sequence
|
||||
import dataclasses
|
||||
from functools import partial, lru_cache
|
||||
import inspect
|
||||
import math
|
||||
@ -41,7 +42,8 @@ from jax._src import stages
|
||||
from jax._src.tree_util import (
|
||||
tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
|
||||
tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix,
|
||||
prefix_errors, generate_key_paths, tree_flatten_with_path)
|
||||
prefix_errors, generate_key_paths, tree_flatten_with_path,
|
||||
equality_errors_pytreedef)
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
@ -2031,6 +2033,82 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
|
||||
return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux)
|
||||
|
||||
|
||||
def saved_input_vjp(f: Callable, which: Sequence[bool], *primals,
                    allow_unused: bool = True, allow_opaque: bool = True):
  """Evaluate ``f(*primals)`` and build a VJP function that is re-fed saved inputs.

  ``f`` is linearized at ``primals``. Every residual of the linearization that
  is (by object identity) a leaf of an input marked in ``which`` is *not*
  stored in the returned function; the caller passes those inputs again when
  invoking it. All remaining residuals ("opaque" residuals) are stored in the
  returned function.

  Args:
    f: the function to differentiate; called as ``f(*primals)``.
    which: one boolean per positional entry of ``primals``; ``True`` marks an
      input that the caller will pass again to the returned function.
    *primals: the (pytrees of) input values at which ``f`` is linearized.
    allow_unused: if ``False``, raise when an input marked in ``which`` does
      not appear among the residuals.
    allow_opaque: if ``False``, raise when any residual is not a marked input.

  Returns:
    A pair ``(out_primals, f_vjp)``. ``f_vjp`` is called as
    ``f_vjp(ct, *saved_primals)``, where ``saved_primals`` has the structure of
    ``tuple(x for x, w in zip(primals, which) if w)``, and returns values
    arranged with the same tree structure as ``primals``.

  Raises:
    ValueError: if ``len(which) != len(primals)``.
    Exception: if ``allow_unused`` or ``allow_opaque`` is ``False`` and the
      corresponding condition described above is violated.
  """
  if len(which) != len(primals):
    raise ValueError(
        "length of 'which' argument must equal the number of primal input values, "
        f"but got {len(which)=} and {len(primals)=}")

  dbg = debug_info("saved_input_vjp", f, primals, {})
  fun = lu.wrap_init(f, debug_info=dbg)
  primals_flat, in_tree = tree_flatten(primals)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
  out_primals_flat, _, jaxpr, residuals = ad.linearize(fun, *primals_flat)
  primals_filt, filt_tree = tree_flatten(tuple(p for w, p in zip(which, primals) if w))
  # Residuals are matched against the saved inputs by object identity.
  id_map = {id(x): i for i, x in enumerate(primals_filt)}
  opaque_residuals: list = []
  res_spec = []
  for r in residuals:
    if id(r) in id_map:
      # This residual is one of the saved input leaves: record only its index.
      res_spec.append(RSpec(id_map[id(r)], True))
    else:
      # Not a saved input: keep the value and record where it is stored.
      res_spec.append(RSpec(len(opaque_residuals), False))
      opaque_residuals.append(r)
  f_vjp = Partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, out_tree(),
                  jaxpr, opaque_residuals)

  if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}):
    # Work per flattened leaf so that pytree-valued inputs are handled: expand
    # each top-level flag in `which` to one flag per leaf of that input.
    which_flat = [w for w, p in zip(which, primals) for _ in tree_leaves(p)]
    # NOTE(review): this indexes dbg.arg_names per flattened leaf; confirm that
    # matches the DebugInfo.arg_names convention.
    unused = [(i, core.get_aval(x))
              for i, (x, w) in enumerate(zip(primals_flat, which_flat))
              if w and id(x) not in res_ids]
    assert unused
    if len(unused) == 1:
      (i, a), = unused
      start, was = "an input value", "was"
      msg = f" {dbg.arg_names[i]} of type {a.str_short()}"
    else:
      start, was = "multiple input values", "were"
      msg = "\n" + "\n".join(f" * {dbg.arg_names[i]} of type {a.str_short()}"
                             for i, a in unused)
    raise Exception(f"with {allow_unused=}, {start} marked to be saved {was} "
                    f"not used by the backward pass:{msg}")

  if not allow_opaque and opaque_residuals:
    msg = ", ".join(core.get_aval(x).str_short() for x in opaque_residuals)
    raise Exception(f"with {allow_opaque=}, the backward pass requires opaque "
                    f"(non-input) residuals: {msg}")

  out_primals = tree_unflatten(out_tree(), out_primals_flat)
  return out_primals, f_vjp
|
||||
|
||||
def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, jaxpr,
|
||||
opaque_residuals, ct, *saved_primals):
|
||||
primals_filtered, filtered_tree_ = tree_flatten(saved_primals)
|
||||
if filtered_tree != filtered_tree_:
|
||||
raise ValueError(
|
||||
"inputs passed to f_vjp must be a tuple of (pytrees of) "
|
||||
"arrays with the same structure as\n"
|
||||
" tuple(x for x, w in zip(inputs, which) if w)\n"
|
||||
"given the original call\n"
|
||||
" _, f_vjp = saved_input_vjp(f, which, *inputs, ...)\n"
|
||||
"but the structures differ:\n" +
|
||||
"\n".join(f" * inputs{keystr(path)} was a {thing1} in the original "
|
||||
f"call, but a {thing2} here, so {explanation}"
|
||||
for path, thing1, thing2, explanation
|
||||
in equality_errors_pytreedef(filtered_tree, filtered_tree_)))
|
||||
|
||||
residuals = [primals_filtered[i.idx] if i.primal else opaque_residuals[i.idx]
|
||||
for i in res_spec]
|
||||
dummy_args = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars]
|
||||
cts_flat, out_tree_ = tree_flatten(ct)
|
||||
assert out_tree_ == out_tree
|
||||
arg_cts = ad.backward_pass(jaxpr, True, residuals, dummy_args, cts_flat)
|
||||
return tree_unflatten(in_tree, arg_cts)
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class RSpec:
  """Says where one residual is found when the saved-input VJP function runs."""
  # Position within the flattened saved primals when `primal` is True,
  # otherwise position within the stored opaque residuals.
  idx: int
  # True when the residual is one of the saved primal inputs.
  primal: bool
|
||||
|
||||
# Short alias for saved_input_vjp.
si_vjp = saved_input_vjp
|
||||
|
||||
|
||||
def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
|
||||
"""Transpose a function that is promised to be linear.
|
||||
|
||||
|
@ -19,6 +19,10 @@ from jax.experimental.x64_context import (
|
||||
enable_x64 as enable_x64,
|
||||
disable_x64 as disable_x64,
|
||||
)
|
||||
from jax._src.api import (
|
||||
saved_input_vjp as saved_input_vjp,
|
||||
si_vjp as si_vjp
|
||||
)
|
||||
from jax._src.callback import (
|
||||
io_callback as io_callback
|
||||
)
|
||||
|
@ -11496,5 +11496,63 @@ class OverrideLoweringTest(jtu.JaxTestCase):
|
||||
self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir)
|
||||
|
||||
|
||||
class InputSavedVJPTest(jtu.JaxTestCase):
  """Tests for ``api.si_vjp``."""

  def test_basic(self):
    """Both inputs are saved and passed again to the VJP function."""
    def f(x, y):
      return x * y

    args = (2., 3.)
    out, pullback = api.si_vjp(f, [True, True], *args)
    cotangents = pullback(1., *args)
    self.assertAllClose(out, 6.)
    self.assertAllClose(cotangents, (3., 2.))

  def test_basic_unused(self):
    """Works by default for ``jnp.sin``; ``allow_unused=False`` raises."""
    fn = jnp.sin
    args = (3.,)
    out, pullback = api.si_vjp(fn, [True], *args)
    (x_bar,) = pullback(1., *args)
    self.assertAllClose(out, jnp.sin(3.))
    self.assertAllClose(x_bar, jnp.cos(3.))

    with self.assertRaisesRegex(Exception, "not used by the backward pass: x"):
      _ = api.si_vjp(fn, [True], *args, allow_unused=False)

  def test_basic_opaque(self):
    """``allow_opaque=False`` raises for ``jnp.sin``."""
    fn = jnp.sin
    args = (3.,)
    with self.assertRaisesRegex(Exception, "the backward pass requires opaque"):
      _ = api.si_vjp(fn, [True], *args, allow_opaque=False)

  def test_basic_pytree_error(self):
    """Dict inputs work; a differently structured saved input raises ValueError."""
    def f(x):
      return [x['hi'] * x['bye']]

    out, pullback = api.si_vjp(f, [True], {'hi': 2., 'bye': 3.})
    (x_bar,) = pullback([1.], {'hi': 2., 'bye': 3.})
    self.assertAllClose(out, [6.])
    self.assertAllClose(x_bar, {'hi': 3., 'bye': 2.})

    with self.assertRaisesRegex(ValueError, "but the structures differ"):
      pullback(1., {'hi': 2.})

  def test_fsdp(self):
    """Only the weight ``w`` is saved and passed to the VJP function."""
    # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp"
    def f2(x, w):
      x = 1. * x
      x = x @ w
      x = 2. * x
      return x

    x = jnp.ones((3, 4))
    w = jnp.ones((4, 4))
    out, f2_sivjp = api.si_vjp(f2, [False, True], x, w)
    out_bar = jnp.ones_like(out)
    x_bar, w_bar = f2_sivjp(out_bar, w)
    self.assertAllClose(x_bar, 2. * out_bar @ w.T)
    self.assertAllClose(w_bar, 2. * x.T @ out_bar)
|
||||
|
||||
|
||||
# Run the tests in this module with JAX's test loader when executed as a script.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user