Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 21:06:06 +00:00
Support closures in all arguments of lax.custom_root (#1570)
* WIP: linear solvers
* Draft of lax.linear_solve
* Refactor pytree munging inside lax.root. The primitive's implementation and JVP rules are now 100% pytree free.
* Fixup linear_solve
* Linearize multiple times in _root_jvp to avoid zeros
* fix deftraced
* add a symmetric argument
* Fixup float64; add a test for symmetric/non-symmetric
* test zeros in linear_solve_jvp
* Revisions per review
* Adjust signature of linear_solve
* restore botched test
* variable names
* WIP: root solve jaxpr
* WIP more tests
* rewrite root
* Root works with jaxprs
* root -> custom_root
* WIP undefined tangent
* Delayed undefined JVP errors
* use raise_on_undefined_tangents inside define_implicit_gradient
* more tests on jvps with undefined tangents
* Remove define_implicit_gradient
* Support closures in custom_root
* revert api-test
* another test
* jit tests
* spelling
This commit is contained in:
parent 4595d43650
commit 5bcbce744e
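The change in one picture: after this commit, both `solve` and `tangent_solve` may close over values from the enclosing scope (here a precomputed solution). A minimal sketch distilled from the new `test_custom_root_vector_with_solve_closure` test below, using this era's import conventions (`np` for `jax.numpy`, `onp` for NumPy); the sketch itself is not part of the diff:

```python
import numpy as onp
import jax.numpy as np
from jax import api, lax

def linear_solve(a, b):
  f = lambda y: np.dot(a, y) - b
  # Both callables below capture values from the enclosing scope, which is
  # exactly what "closures in all arguments of lax.custom_root" permits.
  solution = np.linalg.solve(a, b)
  solve = lambda func, x0: solution
  tangent_solve = lambda g, y: np.linalg.solve(api.jacobian(g)(y), y)
  return lax.custom_root(f, np.zeros_like(b), solve, tangent_solve)

rng = onp.random.RandomState(0)
a, b = rng.randn(2, 2), rng.randn(2)
print(linear_solve(a, b))  # same as np.linalg.solve(a, b), but differentiable
```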
@@ -117,7 +117,6 @@ Operators
     sort_key_val
     sqrt
     square
-    stop_gradient
     sub
     tan
     tie_in

@@ -136,6 +135,15 @@ Control flow operators
     scan
     while_loop
 
+Custom gradient operators
+-------------------------
+
+.. autosummary::
+  :toctree: _autosummary
+
+  stop_gradient
+  custom_linear_solve
+  custom_root
+
 Parallel operators
 ------------------
 
@@ -44,7 +44,7 @@ from jax.lib import xla_client
 from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
                       split_dict, cache)
 from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
-                           treedef_children, tree_map)
+                           treedef_children, treedef_tuple)
 from jax import ad_util
 
 _map = safe_map
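`tree_map` drops out of the import list and `treedef_tuple` comes in: the new code pairs the treedef of the guess with itself to build the two-argument signature of `linearize_and_solve`. A standalone sketch of `treedef_tuple` (illustration only, not from the diff):

```python
from jax.tree_util import tree_flatten, tree_unflatten, treedef_tuple

leaves, guess_tree = tree_flatten({'x': 1.0, 'y': (2.0, 3.0)})
# treedef_tuple packs per-argument treedefs into the treedef of a tuple of
# those arguments -- here, the (x, b) signature used below.
pair_tree = treedef_tuple((guess_tree, guess_tree))
print(tree_unflatten(pair_tree, leaves + leaves))
# ({'x': 1.0, 'y': (2.0, 3.0)}, {'x': 1.0, 'y': (2.0, 3.0)})
```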
@@ -858,32 +858,6 @@ def _memcpy(axis, num, src, dst, offset):
 masking.masking_rules[lax.concatenate_p] = _concat_masking_rule
 
 
-def _flatten_higher_order_func(
-    f, tree, func_name, input_name,
-):
-  """Flatten a higher order function ``f`` of the form ``f(g, x)``.
-
-  ``f`` must have the type signature:
-
-  .. code-block:: haskell
-
-    f :: (a -> a) -> a -> a
-
-  ``a`` may be any arbitrary fixed pytree structure. The returned function has
-  the same structure as ``f``, except every appearance of ``a`` is replaced by a
-  flat sequence of arrays in the style used internally by JAX primitives
-  (variadic ``*args`` arguments in function calls, lists in return values).
-  """
-  def flat_fun(flat_g, *args_flat):
-    args = tree_unflatten(tree, args_flat)
-    g = partial(apply_flat_fun_nokwargs, flat_g, (tree, tree))
-    out = f(g, args)
-    out_flat, out_tree = tree_flatten(out)
-    _check_tree(func_name, input_name, out_tree, tree)
-    return out_flat
-  return flat_fun
-
-
 def _check_tree(func_name, expected_name, actual_tree, expected_tree):
   if actual_tree != expected_tree:
     raise TypeError(
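The deleted helper adapted pytree-valued higher-order functions to the flat calling convention; it is no longer needed because both callables are now traced to jaxprs up front. For reference, a toy version of the flattening idea it implemented (hypothetical code, not from JAX):

```python
from jax.tree_util import tree_flatten, tree_unflatten

def flatten_unary(f, tree):
  # Present a pytree -> pytree function as one over flat lists of leaves,
  # the convention JAX primitives use internally.
  def flat_f(*leaves):
    out = f(tree_unflatten(tree, list(leaves)))
    out_leaves, _ = tree_flatten(out)
    return out_leaves
  return flat_f

double = lambda d: {k: 2 * v for k, v in d.items()}
_, tree = tree_flatten({'a': 1, 'b': 2})
print(flatten_unary(double, tree)(1, 2))  # [2, 4]
```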
@@ -908,12 +882,33 @@ def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
                     tree_unflatten(tree2, avals2)))
 
 
-def root(f, initial_guess, solve, tangent_solve):
+def _stop_gradient_fun(f):
+  """Create a version of f() that stops all gradients."""
+  def wrapper(*args, **kwargs):
+    args_flat, in_args_tree = tree_flatten((args, kwargs))
+    args_avals = tuple(_map(_abstractify, args_flat))
+    g = lambda a, b: f(*a, **b)
+    jaxpr, consts, out_tree = _initial_style_jaxpr(g, in_args_tree, args_avals)
+    out = core.jaxpr_as_fun(jaxpr)(*lax.stop_gradient(consts + tuple(args_flat)))
+    return tree_unflatten(out_tree, out)
+  return wrapper
+
+
+_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s')
+
+
+def _split_root_args(args, const_lengths):
+  params_list = split_list(args, list(const_lengths))
+  return _RootTuple(*params_list[:-1]), params_list[-1]
+
+
+def custom_root(f, initial_guess, solve, tangent_solve):
   """Differentiably solve for a root of a function.
 
   This is a low-level routine, mostly intended for internal use in JAX.
-  Gradients of root() are defined with respect to closed-over variables from
-  the provided function f.
+  Gradients of custom_root() are defined with respect to closed-over variables
+  from the provided function ``f`` via the implicit function theorem:
+  https://en.wikipedia.org/wiki/Implicit_function_theorem
 
   Args:
     f: function for which to find a root. Should accept a single argument,
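`_stop_gradient_fun` exists so that differentiation never flows through `solve` itself: the forward value comes from `solve`, while derivatives come only from the implicit-function-theorem machinery below. A one-line illustration of the `lax.stop_gradient` semantics it relies on (illustration only):

```python
from jax import api, lax

# Gradients do not flow through lax.stop_gradient; custom_root applies the
# same cut to every value reaching solve, including closed-over constants.
print(api.grad(lambda x: lax.stop_gradient(x) ** 2)(3.0))  # 0.0
```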
@@ -945,39 +940,51 @@ def root(f, initial_guess, solve, tangent_solve):
   """
   guess_flat, in_args_tree = tree_flatten((initial_guess,))
   guess_avals = tuple(_map(_abstractify, guess_flat))
-  jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_args_tree, guess_avals)
+  f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
+      f, in_args_tree, guess_avals)
+
   in_tree, = treedef_children(in_args_tree)
   _check_tree("f", "initial_guess", out_tree, in_tree)
 
-  solve_flat = _flatten_higher_order_func(
-      solve, in_tree, "solve", "initial_guess")
-  tangent_solve_flat = _flatten_higher_order_func(
-      tangent_solve, in_tree, "tangent_solve", "initial_guess")
+  solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
+      partial(solve, _stop_gradient_fun(f)), in_args_tree, guess_avals)
+  _check_tree("solve", "initial_guess", solution_tree, in_tree)
 
-  out_flat = root_p.bind(*itertools.chain(consts, guess_flat),
-                         num_consts=len(consts), jaxpr=jaxpr, solve=solve_flat,
-                         tangent_solve=tangent_solve_flat)
+  def linearize_and_solve(x, b):
+    unchecked_zeros, f_jvp = api.linearize(f, x)
+    return tangent_solve(f_jvp, b)
+
+  l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
+      linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
+  _check_tree("tangent_solve", "x", out_tree, in_tree)
+
+  all_consts = [f_consts, solve_consts, l_and_s_consts]
+  const_lengths = _RootTuple(*_map(len, all_consts))
+  jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)
+
+  out_flat = root_p.bind(
+      *(_flatten(all_consts) + guess_flat),
+      const_lengths=const_lengths, jaxprs=jaxprs)
   return tree_unflatten(out_tree, out_flat)
 
 
 def _root_abstract_eval(*args, **kwargs):
-  return args[kwargs['num_consts']:]
+  return args[sum(kwargs['const_lengths']):]
 
 
 def _root_impl(*args, **kwargs):
-  num_consts, jaxpr, solve, _ = split_dict(
-      kwargs, ['num_consts', 'jaxpr', 'solve', 'tangent_solve'])
-  params, initial_guess = split_list(args, [num_consts])
-  f = partial(core.jaxpr_as_fun(jaxpr), *params)
-  return solve(f, *initial_guess)
+  const_lengths, jaxprs = split_dict(kwargs, ['const_lengths', 'jaxprs'])
+  params, initial_guess = _split_root_args(args, const_lengths)
+  solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess))
+  return solution
 
 
-def _root_jvp(primals, tangents, num_consts, jaxpr, solve, tangent_solve):
-  params = primals[:num_consts]
-  solution = tuple(root_p.bind(*primals, num_consts=num_consts, jaxpr=jaxpr,
-                               solve=solve, tangent_solve=tangent_solve))
-  params_dot = tangents[:num_consts]
+def _root_jvp(primals, tangents, const_lengths, jaxprs):
+  params, _ = _split_root_args(primals, const_lengths)
+  solution = tuple(root_p.bind(
+      *primals, const_lengths=const_lengths, jaxprs=jaxprs))
+
+  params_dot, _ = _split_root_args(tangents, const_lengths)
 
   # F(m, u) = 0      # system of equations in u, parameterized by m
   #                  # solution is u*(m) defined in a neighborhood
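`linearize_and_solve` is built on `api.linearize`, which evaluates `f` at `x` and returns the primal output together with the linearized (JVP) function, so `tangent_solve` only ever sees a linear function. A standalone sketch of that API (not from the diff):

```python
from jax import api

f = lambda x: x ** 3
y, f_jvp = api.linearize(f, 2.0)
print(y)           # 8.0, the primal output
print(f_jvp(1.0))  # 12.0, i.e. f'(2.0) = 3 * 2**2 applied to tangent 1.0
```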
@@ -988,13 +995,14 @@ def _root_jvp(primals, tangents, num_consts, jaxpr, solve, tangent_solve):
   #
   # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]     # jvp
 
-  f = core.jaxpr_as_fun(jaxpr)
-  f_fixed_params = lambda *solution: f(*(params + solution))
-  f_fixed_solution = lambda *params: f(*(params + solution))
-
-  _, rhs = ad.jvp(lu.wrap_init(f_fixed_solution)).call_wrapped(params, params_dot)
-  _, f_jvp_wrt_solution = api.linearize(f_fixed_params, *solution)
-  solution_dot = [-x for x in tangent_solve(f_jvp_wrt_solution, *rhs)]
+  f = core.jaxpr_as_fun(jaxprs.f)
+  linearize_and_solve = partial(
+      core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
+  f_at_solution = lambda *params: f(*itertools.chain(params, solution))
+  _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
+      params.f, params_dot.f)
+  solution_dot = _map(
+      operator.neg, linearize_and_solve(*itertools.chain(solution, rhs)))
 
   return solution, solution_dot
 
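A scalar sanity check of the identity in the comments above: take F(m, u) = u**2 - m, whose root is u*(m) = sqrt(m). Then ∂_0 F(m, u) = -1 and ∂_1 F(m, u) = 2u, so the jvp formula gives ∂ u*(m)[v] = -(2u)^{-1} · (-v) = v / (2 sqrt(m)), in agreement with differentiating sqrt(m) directly.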
@@ -1004,7 +1012,8 @@ root_p.multiple_results = True
 root_p.def_impl(_root_impl)
 root_p.def_abstract_eval(_root_abstract_eval)
 ad.primitive_jvps[root_p] = _root_jvp
-xla.initial_style_translations[root_p] = xla.lower_fun(_root_impl, initial_style=True)
+xla.initial_style_translations[root_p] = xla.lower_fun(
+    _root_impl, initial_style=True)
 # TODO(shoyer): write batching rule
 
 
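For readers unfamiliar with this registration pattern: a JAX primitive couples a concrete evaluation rule (`def_impl`), a shape/dtype rule (`def_abstract_eval`), and optional JVP/transpose/translation rules. A minimal sketch of the same pattern, assuming this era's `jax.core` API (illustration only, not from the diff):

```python
from jax import core

mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)   # concrete evaluation
mul_add_p.def_abstract_eval(lambda x, y, z: x)  # output aval matches x's
print(mul_add_p.bind(2.0, 3.0, 4.0))            # 10.0
```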
@@ -1122,13 +1131,13 @@ def custom_linear_solve(
   jaxprs = _LinearSolveTuple(
       matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)
 
-  out_flat = custom_linear_solve_p.bind(
+  out_flat = linear_solve_p.bind(
       *(_flatten(all_consts) + b_flat),
       const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
   return tree_unflatten(tree, out_flat)
 
 
-def _custom_linear_solve_abstract_eval(*args, **kwargs):
+def _linear_solve_abstract_eval(*args, **kwargs):
   return args[sum(kwargs['const_lengths']):]
 
 
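For context on the API whose primitive is being renamed here, a usage sketch of `lax.custom_linear_solve` mirroring the Cholesky test further down (a sketch, assuming, as that test suggests, that `symmetric=True` lets `solve` double as the transpose solve):

```python
import numpy as onp
import jax.numpy as np
import jax.scipy as jsp
from jax import api, lax

def positive_definite_solve(a, b):
  factors = jsp.linalg.cho_factor(a)
  def solve(matvec, x):
    # Reuses the factorization; gradients are derived from matvec by the
    # primitive's JVP and transpose rules, not by differentiating solve.
    return jsp.linalg.cho_solve(factors, x)
  return lax.custom_linear_solve(lambda x: np.dot(a, x), b, solve,
                                 symmetric=True)

rng = onp.random.RandomState(0)
a, b = rng.randn(2, 2), rng.randn(2)
spd = np.dot(a, a.T)  # symmetric positive definite
print(api.grad(lambda b: np.sum(positive_definite_solve(spd, b)))(b))
```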
@@ -1160,7 +1169,7 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs, tree):
   # ∂x = A^{-1} (∂b - ∂A x)
 
   kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
-  x = custom_linear_solve_p.bind(*primals, **kwargs)
+  x = linear_solve_p.bind(*primals, **kwargs)
 
   params, _ = _split_linear_solve_args(primals, const_lengths)
   params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)
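The identity in the comment is the product rule applied to A x = b at the solution: ∂A x + A ∂x = ∂b, hence ∂x = A^{-1} (∂b - ∂A x). The JVP therefore costs one additional solve against the same operator, with right-hand side ∂b - ∂A x, which is what the second bind of linear_solve_p in the next hunk computes.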
@@ -1174,12 +1183,12 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs, tree):
   _check_shapes("matvec", "b", matvec_tangents, x, tree)
   rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))
 
-  x_dot = custom_linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)
+  x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)
 
   return x, x_dot
 
 
-def _custom_linear_solve_transpose_rule(cotangent, *primals, **kwargs):
+def _linear_solve_transpose_rule(cotangent, *primals, **kwargs):
   const_lengths, jaxprs, tree = split_dict(
       kwargs, ['const_lengths', 'jaxprs', 'tree'])
 
@@ -1189,19 +1198,19 @@ def _custom_linear_solve_transpose_rule(cotangent, *primals, **kwargs):
 
   params, b = _split_linear_solve_args(primals, const_lengths)
   assert b == [ad.undefined_primal] * len(b)
-  cotangent_b = custom_linear_solve_p.bind(
+  cotangent_b = linear_solve_p.bind(
       *(_flatten(params.transpose()) + cotangent),
       const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose(),
       tree=tree)
   return [None] * sum(const_lengths) + cotangent_b
 
 
-custom_linear_solve_p = core.Primitive('custom_linear_solve')
-custom_linear_solve_p.multiple_results = True
-custom_linear_solve_p.def_impl(_custom_linear_solve_impl)
-custom_linear_solve_p.def_abstract_eval(_custom_linear_solve_abstract_eval)
-ad.primitive_jvps[custom_linear_solve_p] = _custom_linear_solve_jvp
-xla.initial_style_translations[custom_linear_solve_p] = xla.lower_fun(
+linear_solve_p = core.Primitive('custom_linear_solve')
+linear_solve_p.multiple_results = True
+linear_solve_p.def_impl(_custom_linear_solve_impl)
+linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
+ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
+xla.initial_style_translations[linear_solve_p] = xla.lower_fun(
     _custom_linear_solve_impl, initial_style=True)
-ad.primitive_transposes[custom_linear_solve_p] = _custom_linear_solve_transpose_rule
+ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
 # TODO(shoyer): write batching rule
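Why the transpose rule is just another solve: for fixed A, the map b ↦ x = A^{-1} b is linear, and its transpose is c ↦ A^{-T} c. Binding linear_solve_p with params.transpose() and jaxprs.transpose() performs exactly that solve against the transposed operator, while the matrix-valued constants receive None cotangents because this rule only transposes with respect to b.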
@@ -1061,7 +1061,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
 
     api.grad(lambda x: jit_run_scan(x))(0.)  # doesn't crash
 
-  def test_root_scalar(self):
+  def test_custom_root_scalar(self):
 
     # TODO(shoyer): Figure out why this fails and re-enable it, if possible. My
     # best guess is that TPUs use less stable numerics for pow().
@@ -1091,7 +1091,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
 
     def sqrt_cubed(x, tangent_solve=scalar_solve):
       f = lambda y: y ** 2 - x ** 3
-      return lax.root(f, 0.0, binary_search, tangent_solve)
+      return lax.custom_root(f, 0.0, binary_search, tangent_solve)
 
     value, grad = api.value_and_grad(sqrt_cubed)(5.0)
     self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
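`scalar_solve` and `binary_search` are defined earlier in the test file, outside this hunk. Plausible definitions consistent with how they are used here (hypothetical reconstructions, not read from the diff): `scalar_solve` solves the linear tangent equation f(x) = y via x = y / f(1.0), and `binary_search` bisects for a sign change:

```python
import jax.numpy as np
from jax import lax

def scalar_solve(f, y):
  # Hypothetical: f is linear in the tangent solve, so f(x) = f(1.0) * x.
  return y / f(1.0)

def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
  # Hypothetical: bracket the root by bisection; x0 is unused.
  del x0
  def cond(state):
    low, high = state
    return high - low > tolerance
  def body(state):
    low, high = state
    midpoint = 0.5 * (low + high)
    update_upper = func(midpoint) > 0
    low = np.where(update_upper, low, midpoint)
    high = np.where(update_upper, midpoint, high)
    return (low, high)
  solution, _ = lax.while_loop(cond, body, (low, high))
  return solution
```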
@@ -1106,33 +1106,61 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     results = api.jit(sqrt_cubed)(5.0)
     self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False)
 
-  def test_root_vector(self):
-    def oracle(func, x0):
-      del func  # unused
-      return x0
+  def test_custom_root_vector_with_solve_closure(self):
 
     def vector_solve(f, y):
       return np.linalg.solve(api.jacobian(f)(y), y)
 
     def linear_solve(a, b):
       f = lambda y: high_precision_dot(a, y) - b
-      x0 = np.linalg.solve(a, b)
-      return lax.root(f, x0, oracle, vector_solve)
+      x0 = np.zeros_like(b)
+      solution = np.linalg.solve(a, b)
+      oracle = lambda func, x0: solution
+      return lax.custom_root(f, x0, oracle, vector_solve)
 
     rng = onp.random.RandomState(0)
     a = rng.randn(2, 2)
     b = rng.randn(2)
     jtu.check_grads(linear_solve, (a, b), order=2)
 
-  def test_root_errors(self):
+    actual = api.jit(linear_solve)(a, b)
+    expected = np.linalg.solve(a, b)
+    self.assertAllClose(expected, actual, check_dtypes=True)
+
+  def test_custom_root_with_custom_linear_solve(self):
+
+    def linear_solve(a, b):
+      f = lambda x: np.dot(a, x) - b
+      factors = jsp.linalg.cho_factor(a)
+      cho_solve = lambda f, b: jsp.linalg.cho_solve(factors, b)
+      def pos_def_solve(g, b):
+        return lax.custom_linear_solve(g, b, cho_solve, symmetric=True)
+      return lax.custom_root(f, b, cho_solve, pos_def_solve)
+
+    rng = onp.random.RandomState(0)
+    a = rng.randn(2, 2)
+    b = rng.randn(2)
+
+    actual = linear_solve(np.dot(a, a.T), b)
+    expected = np.linalg.solve(np.dot(a, a.T), b)
+    self.assertAllClose(expected, actual, check_dtypes=True)
+
+    actual = api.jit(linear_solve)(np.dot(a, a.T), b)
+    expected = np.linalg.solve(np.dot(a, a.T), b)
+    self.assertAllClose(expected, actual, check_dtypes=True)
+
+    jtu.check_grads(lambda x, y: linear_solve(np.dot(x, x.T), y),
+                    (a, b), order=2)
+
+  def test_custom_root_errors(self):
     with self.assertRaisesRegex(TypeError, re.escape("f() output pytree")):
-      lax.root(lambda x: (x, x), 0.0, lambda f, x: x, lambda f, x: x)
+      lax.custom_root(lambda x: (x, x), 0.0, lambda f, x: x, lambda f, x: x)
     with self.assertRaisesRegex(TypeError, re.escape("solve() output pytree")):
-      lax.root(lambda x: x, 0.0, lambda f, x: (x, x), lambda f, x: x)
+      lax.custom_root(lambda x: x, 0.0, lambda f, x: (x, x), lambda f, x: x)
 
     def dummy_root_usage(x):
       f = lambda y: x - y
-      return lax.root(f, 0.0, lambda f, x: x, lambda f, x: (x, x))
+      return lax.custom_root(f, 0.0, lambda f, x: x, lambda f, x: (x, x))
 
     with self.assertRaisesRegex(
         TypeError, re.escape("tangent_solve() output pytree")):
@@ -1224,7 +1252,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
 
   def test_custom_linear_solve_cholesky(self):
 
-    def positive_definive_solve(a, b):
+    def positive_definite_solve(a, b):
      factors = jsp.linalg.cho_factor(a)
      def solve(matvec, x):
        return jsp.linalg.cho_solve(factors, x)
@@ -1236,16 +1264,16 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     b = rng.randn(2)
 
     expected = np.linalg.solve(high_precision_dot(a, a.T), b)
-    actual = positive_definive_solve(high_precision_dot(a, a.T), b)
+    actual = positive_definite_solve(high_precision_dot(a, a.T), b)
     self.assertAllClose(expected, actual, check_dtypes=True)
 
-    actual = api.jit(positive_definive_solve)(high_precision_dot(a, a.T), b)
+    actual = api.jit(positive_definite_solve)(high_precision_dot(a, a.T), b)
     self.assertAllClose(expected, actual, check_dtypes=True)
 
     # numerical gradients are only well defined if ``a`` is guaranteed to be
     # positive definite.
     jtu.check_grads(
-        lambda x, y: positive_definive_solve(high_precision_dot(x, x.T), y),
+        lambda x, y: positive_definite_solve(high_precision_dot(x, x.T), y),
         (a, b), order=2)
 
   def test_custom_linear_solve_lu(self):