Support closures in all arguments of lax.custom_root (#1570)

* WIP: linear solvers

* Draft of lax.linear_solve

* Refactor pytree munging inside lax.root.

The primitive's implementation and JVP rules are now 100% pytree free.

* Fixup linear_solve

* Linearize multiple times in _root_jvp to avoid zeros

* fix deftraced

* add a symmetric argument

* Fixup float64; add a test for symmetric/non-symmetric

* test zeros in linear_solve_jvp

* Revisions per review

* Adjust signature of linear_solve

* restore botched test

* variable names

* WIP: root solve jaxpr

* WIP more tests

* rewrite root

* Root works with jaxprs

* root -> custom_root

* WIP undefined tangent

* Delayed undefined JVP errors

* use raise_on_undefined_tangents inside define_implicit_gradient

* more tests on jvps with undefined tangents

* Remove define_implicit_gradient

* Support closures in custom_root

* revert api-test

* another test

* jit tests

* spelling
This commit is contained in:
Stephan Hoyer 2019-10-29 16:00:00 -07:00 committed by GitHub
parent 4595d43650
commit 5bcbce744e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 132 additions and 87 deletions

View File

@ -117,7 +117,6 @@ Operators
sort_key_val
sqrt
square
stop_gradient
sub
tan
tie_in
@ -136,6 +135,15 @@ Control flow operators
scan
while_loop
Custom gradient operators
-------------------------
.. autosummary::
:toctree: _autosummary
stop_gradient
custom_linear_solve
custom_root
Parallel operators
------------------

View File

@ -44,7 +44,7 @@ from jax.lib import xla_client
from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
split_dict, cache)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
treedef_children, tree_map)
treedef_children, treedef_tuple)
from jax import ad_util
_map = safe_map
@ -858,32 +858,6 @@ def _memcpy(axis, num, src, dst, offset):
masking.masking_rules[lax.concatenate_p] = _concat_masking_rule
def _flatten_higher_order_func(
f, tree, func_name, input_name,
):
"""Flatten a higher order function ``f`` of the form ``f(g, x)``.
``f`` must have the type signature:
.. code-block:: haskell
f :: (a -> a) -> a -> a
``a`` may be any arbitrary fixed pytree structure. The returned function has
the same structure as ``f``, except every appearance of ``a`` is replaced by a
flat sequence of arrays in the style used internally by JAX primitives
(variadic ``*args`` arguments in function calls, lists in return values).
"""
def flat_fun(flat_g, *args_flat):
args = tree_unflatten(tree, args_flat)
g = partial(apply_flat_fun_nokwargs, flat_g, (tree, tree))
out = f(g, args)
out_flat, out_tree = tree_flatten(out)
_check_tree(func_name, input_name, out_tree, tree)
return out_flat
return flat_fun
def _check_tree(func_name, expected_name, actual_tree, expected_tree):
if actual_tree != expected_tree:
raise TypeError(
@ -908,12 +882,33 @@ def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
tree_unflatten(tree2, avals2)))
def root(f, initial_guess, solve, tangent_solve):
def _stop_gradient_fun(f):
  """Create a version of f() that stops all gradients.

  Unlike applying ``lax.stop_gradient`` to the arguments alone, staging ``f``
  out to a jaxpr first surfaces its closed-over values as ``consts``, so
  gradients are stopped through the closure as well as the explicit arguments.
  """
  def wrapper(*args, **kwargs):
    # Flatten (args, kwargs) into a flat list of leaves plus a treedef.
    args_flat, in_args_tree = tree_flatten((args, kwargs))
    args_avals = tuple(_map(_abstractify, args_flat))
    # Adapter so the staged function takes the (args, kwargs) pair positionally.
    g = lambda a, b: f(*a, **b)
    # Trace g to a jaxpr; values f closes over come back as `consts`.
    jaxpr, consts, out_tree = _initial_style_jaxpr(g, in_args_tree, args_avals)
    # Re-evaluate the jaxpr with stop_gradient applied to consts and args alike.
    out = core.jaxpr_as_fun(jaxpr)(*lax.stop_gradient(consts + tuple(args_flat)))
    return tree_unflatten(out_tree, out)
  return wrapper
# Groups per-jaxpr data for the root primitive: `f` (the root function),
# `solve` (the forward solver), and `l_and_s` (linearize-and-tangent-solve,
# used by the JVP rule).
_RootTuple = collections.namedtuple('_RootTuple', 'f, solve, l_and_s')


def _split_root_args(args, const_lengths):
  """Split flat primitive args into per-jaxpr constants and the guess.

  ``const_lengths`` is a _RootTuple of ints: how many leading entries of
  ``args`` are constants for each jaxpr. Returns a pair of
  (_RootTuple of constant lists, remaining args i.e. the initial guess).
  """
  params_list = split_list(args, list(const_lengths))
  return _RootTuple(*params_list[:-1]), params_list[-1]
def custom_root(f, initial_guess, solve, tangent_solve):
"""Differentiably solve for a root of a function.
This is a low-level routine, mostly intended for internal use in JAX.
Gradients of root() are defined with respect to closed-over variables from
the provided function f.
Gradients of custom_root() are defined with respect to closed-over variables
from the provided function ``f`` via the implicit function theorem:
https://en.wikipedia.org/wiki/Implicit_function_theorem
Args:
f: function for which to find a root. Should accept a single argument,
@ -945,39 +940,51 @@ def root(f, initial_guess, solve, tangent_solve):
"""
guess_flat, in_args_tree = tree_flatten((initial_guess,))
guess_avals = tuple(_map(_abstractify, guess_flat))
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_args_tree, guess_avals)
f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
f, in_args_tree, guess_avals)
in_tree, = treedef_children(in_args_tree)
_check_tree("f", "initial_guess", out_tree, in_tree)
solve_flat = _flatten_higher_order_func(
solve, in_tree, "solve", "initial_guess")
tangent_solve_flat = _flatten_higher_order_func(
tangent_solve, in_tree, "tangent_solve", "initial_guess")
solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
partial(solve, _stop_gradient_fun(f)), in_args_tree, guess_avals)
_check_tree("solve", "initial_guess", solution_tree, in_tree)
out_flat = root_p.bind(*itertools.chain(consts, guess_flat),
num_consts=len(consts), jaxpr=jaxpr, solve=solve_flat,
tangent_solve=tangent_solve_flat)
def linearize_and_solve(x, b):
unchecked_zeros, f_jvp = api.linearize(f, x)
return tangent_solve(f_jvp, b)
l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2)
_check_tree("tangent_solve", "x", out_tree, in_tree)
all_consts = [f_consts, solve_consts, l_and_s_consts]
const_lengths = _RootTuple(*_map(len, all_consts))
jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)
out_flat = root_p.bind(
*(_flatten(all_consts) + guess_flat),
const_lengths=const_lengths, jaxprs=jaxprs)
return tree_unflatten(out_tree, out_flat)
def _root_abstract_eval(*args, **kwargs):
return args[kwargs['num_consts']:]
return args[sum(kwargs['const_lengths']):]
def _root_impl(*args, **kwargs):
num_consts, jaxpr, solve, _ = split_dict(
kwargs, ['num_consts', 'jaxpr', 'solve', 'tangent_solve'])
params, initial_guess = split_list(args, [num_consts])
f = partial(core.jaxpr_as_fun(jaxpr), *params)
return solve(f, *initial_guess)
const_lengths, jaxprs = split_dict(kwargs, ['const_lengths', 'jaxprs'])
params, initial_guess = _split_root_args(args, const_lengths)
solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess))
return solution
def _root_jvp(primals, tangents, num_consts, jaxpr, solve, tangent_solve):
params = primals[:num_consts]
solution = tuple(root_p.bind(*primals, num_consts=num_consts, jaxpr=jaxpr,
solve=solve, tangent_solve=tangent_solve))
params_dot = tangents[:num_consts]
def _root_jvp(primals, tangents, const_lengths, jaxprs):
params, _ = _split_root_args(primals, const_lengths)
solution = tuple(root_p.bind(
*primals, const_lengths=const_lengths, jaxprs=jaxprs))
params_dot, _ = _split_root_args(tangents, const_lengths)
# F(m, u) = 0 # system of equations in u, parameterized by m
# # solution is u*(m) defined in a neighborhood
@ -988,13 +995,14 @@ def _root_jvp(primals, tangents, num_consts, jaxpr, solve, tangent_solve):
#
# ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]] # jvp
f = core.jaxpr_as_fun(jaxpr)
f_fixed_params = lambda *solution: f(*(params + solution))
f_fixed_solution = lambda *params: f(*(params + solution))
_, rhs = ad.jvp(lu.wrap_init(f_fixed_solution)).call_wrapped(params, params_dot)
_, f_jvp_wrt_solution = api.linearize(f_fixed_params, *solution)
solution_dot = [-x for x in tangent_solve(f_jvp_wrt_solution, *rhs)]
f = core.jaxpr_as_fun(jaxprs.f)
linearize_and_solve = partial(
core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
f_at_solution = lambda *params: f(*itertools.chain(params, solution))
_, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
params.f, params_dot.f)
solution_dot = _map(
operator.neg, linearize_and_solve(*itertools.chain(solution, rhs)))
return solution, solution_dot
@ -1004,7 +1012,8 @@ root_p.multiple_results = True
root_p.def_impl(_root_impl)
root_p.def_abstract_eval(_root_abstract_eval)
ad.primitive_jvps[root_p] = _root_jvp
xla.initial_style_translations[root_p] = xla.lower_fun(_root_impl, initial_style=True)
xla.initial_style_translations[root_p] = xla.lower_fun(
_root_impl, initial_style=True)
# TODO(shoyer): write batching rule
@ -1122,13 +1131,13 @@ def custom_linear_solve(
jaxprs = _LinearSolveTuple(
matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)
out_flat = custom_linear_solve_p.bind(
out_flat = linear_solve_p.bind(
*(_flatten(all_consts) + b_flat),
const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
return tree_unflatten(tree, out_flat)
def _custom_linear_solve_abstract_eval(*args, **kwargs):
def _linear_solve_abstract_eval(*args, **kwargs):
return args[sum(kwargs['const_lengths']):]
@ -1160,7 +1169,7 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs, tree):
# ∂x = A^{-1} (∂b - ∂A x)
kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
x = custom_linear_solve_p.bind(*primals, **kwargs)
x = linear_solve_p.bind(*primals, **kwargs)
params, _ = _split_linear_solve_args(primals, const_lengths)
params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)
@ -1174,12 +1183,12 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs, tree):
_check_shapes("matvec", "b", matvec_tangents, x, tree)
rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))
x_dot = custom_linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)
x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)
return x, x_dot
def _custom_linear_solve_transpose_rule(cotangent, *primals, **kwargs):
def _linear_solve_transpose_rule(cotangent, *primals, **kwargs):
const_lengths, jaxprs, tree = split_dict(
kwargs, ['const_lengths', 'jaxprs', 'tree'])
@ -1189,19 +1198,19 @@ def _custom_linear_solve_transpose_rule(cotangent, *primals, **kwargs):
params, b = _split_linear_solve_args(primals, const_lengths)
assert b == [ad.undefined_primal] * len(b)
cotangent_b = custom_linear_solve_p.bind(
cotangent_b = linear_solve_p.bind(
*(_flatten(params.transpose()) + cotangent),
const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose(),
tree=tree)
return [None] * sum(const_lengths) + cotangent_b
custom_linear_solve_p = core.Primitive('custom_linear_solve')
custom_linear_solve_p.multiple_results = True
custom_linear_solve_p.def_impl(_custom_linear_solve_impl)
custom_linear_solve_p.def_abstract_eval(_custom_linear_solve_abstract_eval)
ad.primitive_jvps[custom_linear_solve_p] = _custom_linear_solve_jvp
xla.initial_style_translations[custom_linear_solve_p] = xla.lower_fun(
linear_solve_p = core.Primitive('custom_linear_solve')
linear_solve_p.multiple_results = True
linear_solve_p.def_impl(_custom_linear_solve_impl)
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
xla.initial_style_translations[linear_solve_p] = xla.lower_fun(
_custom_linear_solve_impl, initial_style=True)
ad.primitive_transposes[custom_linear_solve_p] = _custom_linear_solve_transpose_rule
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
# TODO(shoyer): write batching rule

View File

@ -1061,7 +1061,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
api.grad(lambda x: jit_run_scan(x))(0.) # doesn't crash
def test_root_scalar(self):
def test_custom_root_scalar(self):
# TODO(shoyer): Figure out why this fails and re-enable it, if possible. My
# best guess is that TPUs use less stable numerics for pow().
@ -1091,7 +1091,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def sqrt_cubed(x, tangent_solve=scalar_solve):
f = lambda y: y ** 2 - x ** 3
return lax.root(f, 0.0, binary_search, tangent_solve)
return lax.custom_root(f, 0.0, binary_search, tangent_solve)
value, grad = api.value_and_grad(sqrt_cubed)(5.0)
self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
@ -1106,33 +1106,61 @@ class LaxControlFlowTest(jtu.JaxTestCase):
results = api.jit(sqrt_cubed)(5.0)
self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False)
def test_root_vector(self):
def oracle(func, x0):
del func # unused
return x0
def test_custom_root_vector_with_solve_closure(self):
def vector_solve(f, y):
return np.linalg.solve(api.jacobian(f)(y), y)
def linear_solve(a, b):
f = lambda y: high_precision_dot(a, y) - b
x0 = np.linalg.solve(a, b)
return lax.root(f, x0, oracle, vector_solve)
x0 = np.zeros_like(b)
solution = np.linalg.solve(a, b)
oracle = lambda func, x0: solution
return lax.custom_root(f, x0, oracle, vector_solve)
rng = onp.random.RandomState(0)
a = rng.randn(2, 2)
b = rng.randn(2)
jtu.check_grads(linear_solve, (a, b), order=2)
def test_root_errors(self):
actual = api.jit(linear_solve)(a, b)
expected = np.linalg.solve(a, b)
self.assertAllClose(expected, actual, check_dtypes=True)
def test_custom_root_with_custom_linear_solve(self):
def linear_solve(a, b):
f = lambda x: np.dot(a, x) - b
factors = jsp.linalg.cho_factor(a)
cho_solve = lambda f, b: jsp.linalg.cho_solve(factors, b)
def pos_def_solve(g, b):
return lax.custom_linear_solve(g, b, cho_solve, symmetric=True)
return lax.custom_root(f, b, cho_solve, pos_def_solve)
rng = onp.random.RandomState(0)
a = rng.randn(2, 2)
b = rng.randn(2)
actual = linear_solve(np.dot(a, a.T), b)
expected = np.linalg.solve(np.dot(a, a.T), b)
self.assertAllClose(expected, actual, check_dtypes=True)
actual = api.jit(linear_solve)(np.dot(a, a.T), b)
expected = np.linalg.solve(np.dot(a, a.T), b)
self.assertAllClose(expected, actual, check_dtypes=True)
jtu.check_grads(lambda x, y: linear_solve(np.dot(x, x.T), y),
(a, b), order=2)
def test_custom_root_errors(self):
with self.assertRaisesRegex(TypeError, re.escape("f() output pytree")):
lax.root(lambda x: (x, x), 0.0, lambda f, x: x, lambda f, x: x)
lax.custom_root(lambda x: (x, x), 0.0, lambda f, x: x, lambda f, x: x)
with self.assertRaisesRegex(TypeError, re.escape("solve() output pytree")):
lax.root(lambda x: x, 0.0, lambda f, x: (x, x), lambda f, x: x)
lax.custom_root(lambda x: x, 0.0, lambda f, x: (x, x), lambda f, x: x)
def dummy_root_usage(x):
f = lambda y: x - y
return lax.root(f, 0.0, lambda f, x: x, lambda f, x: (x, x))
return lax.custom_root(f, 0.0, lambda f, x: x, lambda f, x: (x, x))
with self.assertRaisesRegex(
TypeError, re.escape("tangent_solve() output pytree")):
@ -1224,7 +1252,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def test_custom_linear_solve_cholesky(self):
def positive_definive_solve(a, b):
def positive_definite_solve(a, b):
factors = jsp.linalg.cho_factor(a)
def solve(matvec, x):
return jsp.linalg.cho_solve(factors, x)
@ -1236,16 +1264,16 @@ class LaxControlFlowTest(jtu.JaxTestCase):
b = rng.randn(2)
expected = np.linalg.solve(high_precision_dot(a, a.T), b)
actual = positive_definive_solve(high_precision_dot(a, a.T), b)
actual = positive_definite_solve(high_precision_dot(a, a.T), b)
self.assertAllClose(expected, actual, check_dtypes=True)
actual = api.jit(positive_definive_solve)(high_precision_dot(a, a.T), b)
actual = api.jit(positive_definite_solve)(high_precision_dot(a, a.T), b)
self.assertAllClose(expected, actual, check_dtypes=True)
# numerical gradients are only well defined if ``a`` is guaranteed to be
# positive definite.
jtu.check_grads(
lambda x, y: positive_definive_solve(high_precision_dot(x, x.T), y),
lambda x, y: positive_definite_solve(high_precision_dot(x, x.T), y),
(a, b), order=2)
def test_custom_linear_solve_lu(self):