Merge pull request #6726 from njunge94:auxiliary_solver_data

PiperOrigin-RevId: 376899659
jax authors 2021-06-01 12:58:39 -07:00
commit edd203e305
2 changed files with 115 additions and 20 deletions

jax/_src/lax/control_flow.py

@@ -49,7 +49,7 @@ from jax._src.util import (partial, unzip2, unzip3, safe_map, safe_zip,
split_list, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
treedef_children, treedef_tuple, tree_multimap,
tree_leaves)
tree_leaves, tree_structure)
from jax import ad_util
from jax.config import config
@@ -1954,7 +1954,19 @@ def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
f"{tree_unflatten(tree2, avals2)}.")
def _check_tree(func_name, expected_name, actual_tree, expected_tree):
def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False):
if has_aux:
actual_tree_children = actual_tree.children()
if len(actual_tree_children) == 2:
# select first child as result tree
actual_tree = tree_structure(actual_tree_children[0])
else:
raise ValueError(
f"{func_name}() produced a pytree with structure "
f"{actual_tree}, but a pytree tuple with auxiliary "
f"output was expected because has_aux was set to True.")
if actual_tree != expected_tree:
raise TypeError(
f"{func_name}() output pytree structure must match {expected_name}, "
@@ -2141,7 +2153,7 @@ def _check_shapes(func_name, expected_name, actual, expected):
@api_boundary
def custom_linear_solve(
matvec, b, solve, transpose_solve=None, symmetric=False):
matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False):
"""Perform a matrix-free linear solve with implicitly defined gradients.
This function allows for overriding or defining gradients for a linear
@@ -2160,7 +2172,7 @@ def custom_linear_solve(
b: constant right-hand side of the equation. May be any nested structure
of arrays.
solve: higher level function that solves for the solution to the linear
equation, i.e., ``solve(matvec, x)) == x`` for all ``x`` of the same form
equation, i.e., ``solve(matvec, x) == x`` for all ``x`` of the same form
as ``b``. This function need not be differentiable.
transpose_solve: higher level function for solving the transpose linear
equation, i.e., ``transpose_solve(vecmat, x) == x``, where ``vecmat`` is
@@ -2169,10 +2181,12 @@
``symmetric=True``, in which case ``solve`` provides the default value.
symmetric: bool indicating if it is safe to assume the linear map
corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.
has_aux: bool indicating whether the ``solve`` and ``transpose_solve`` functions
return auxiliary data, such as solver diagnostics, as a second output.
Returns:
Result of ``solve(matvec, b)``, with gradients defined assuming that the
solution ``x`` satisfies the linear equation ``matvec(x) == b``.
solution ``x`` satisfies the linear equation ``matvec(x) == b``.
"""
if transpose_solve is None and symmetric:
transpose_solve = solve
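
A hypothetical usage sketch of the new flag (the dense solver and all names below are illustrative, not from this commit): ``solve`` returns a ``(solution, aux)`` pair, and ``custom_linear_solve`` passes the aux data through:

    import jax.numpy as jnp
    from jax import lax

    A = jnp.array([[2.0, 1.0], [1.0, 3.0]])  # symmetric, so transpose_solve defaults to solve
    b = jnp.array([1.0, 2.0])

    def matvec(x):
      return A @ x

    def solve_with_diag(mv, rhs):
      # stand-in for an iterative solver that reports diagnostics
      x = jnp.linalg.solve(A, rhs)
      aux = {"residual_norm": jnp.linalg.norm(mv(x) - rhs)}
      return x, aux

    x, aux = lax.custom_linear_solve(
        matvec, b, solve_with_diag, symmetric=True, has_aux=True)
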
@@ -2182,22 +2196,29 @@ def custom_linear_solve(
tree, = treedef_children(in_args_tree)
def _shape_checked(fun, name):
def _shape_checked(fun, name, has_aux):
def f(x):
y = fun(x)
_check_shapes(name, "b", y, b_flat)
return y
return f
def f_aux(x):
y, aux = fun(x)
_check_shapes(name, "b", y, b_flat)
return y, aux
return f_aux if has_aux else f
# no auxiliary data assumed for matvec
matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
_shape_checked(matvec, "matvec"), in_args_tree, b_avals,
_shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
'custom_linear_solve')
_check_tree("matvec", "b", out_tree, tree)
_check_tree("matvec", "b", out_tree, tree, False)
solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
_shape_checked(partial(solve, matvec), "solve"), in_args_tree, b_avals,
_shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
'custom_linear_solve')
_check_tree("solve", "b", out_tree, tree)
_check_tree("solve", "b", out_tree, tree, has_aux)
if transpose_solve is None:
vecmat_jaxpr = tr_solve_jaxpr = None
@@ -2214,9 +2235,9 @@ def custom_linear_solve(
assert out_tree == tree
tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
_shape_checked(partial(transpose_solve, vecmat), "transpose_solve"),
_shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux),
in_args_tree, b_avals, 'custom_linear_solve')
_check_tree("transpose_solve", "b", out_tree, tree)
_check_tree("transpose_solve", "b", out_tree, tree, has_aux)
all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
const_lengths = _LinearSolveTuple(*_map(len, all_consts))
@@ -2226,11 +2247,21 @@ def custom_linear_solve(
out_flat = linear_solve_p.bind(
*(_flatten(all_consts) + b_flat),
const_lengths=const_lengths, jaxprs=jaxprs)
return tree_unflatten(tree, out_flat)
return tree_unflatten(out_tree, out_flat)
def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
return _map(raise_to_shaped, args[sum(const_lengths):])
args_to_raise = args[sum(const_lengths):]
# raise aux args to shaped arrays as well if present; the number of
# aux args is the difference in the number of out_avals between solve
# and matvec (since both map into the same vector space as b)
num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
if num_aux > 0:
args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
return _map(raise_to_shaped, args_to_raise)
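
A toy illustration of the count inference above (hypothetical avals, not from this commit): ``matvec`` maps b-space to b-space, so any extra outputs of ``solve`` must be aux leaves:

    import numpy as np
    from jax.core import ShapedArray

    matvec_out_avals = [ShapedArray((3,), np.float32)]   # one b-shaped output
    solve_out_avals = [ShapedArray((3,), np.float32),    # the solution
                       ShapedArray((), np.float32)]      # one aux scalar
    num_aux = len(solve_out_avals) - len(matvec_out_avals)
    assert num_aux == 1
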
def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
@@ -2263,16 +2294,29 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
params, _ = _split_linear_solve_args(primals, const_lengths)
params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)
num_x_leaves = len(b_dot)
# x is a flat list of solution leaves with any aux leaves appended;
# since x_tree == b_tree == b_dot_tree, the length of b_dot tells us
# where the solution leaves end and the aux leaves begin
x_leaves, _ = split_list(x, [num_x_leaves])
if all(type(p) is ad_util.Zero for p in params_dot.matvec):
# no need to evaluate matvec_tangents
rhs = b_dot
else:
matvec_tangents = _tangent_linear_map(
core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x)
core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x_leaves)
rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))
x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)
# split into x tangents and aux tangents (these become zero)
dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves])
daux_leaves = _map(ad_util.Zero.from_value, daux_leaves)
x_dot = dx_leaves + daux_leaves
return x, x_dot
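
The splits above rely on ``jax._src.util.split_list``; a minimal sketch of its behavior (hypothetical leaf values):

    from jax._src.util import split_list

    x_flat = ["x0", "x1", "aux0"]                  # solution leaves, then aux leaves
    x_leaves, aux_leaves = split_list(x_flat, [2])
    assert x_leaves == ["x0", "x1"]
    assert aux_leaves == ["aux0"]
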
@@ -2282,10 +2326,14 @@ def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
'differentiation of custom_linear_solve')
params, b = _split_linear_solve_args(primals, const_lengths)
# split off symbolic zeros in the cotangent if present
x_cotangent, _ = split_list(cotangent, [len(b)])
assert all(ad.is_undefined_primal(x) for x in b)
cotangent_b = linear_solve_p.bind(
*(_flatten(params.transpose()) + cotangent),
cotangent_b_full = linear_solve_p.bind(
*(_flatten(params.transpose()) + x_cotangent),
const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose())
# drop aux values in cotangent computation
cotangent_b, _ = split_list(cotangent_b_full, [len(b)])
return [None] * sum(const_lengths) + cotangent_b
@@ -2302,6 +2350,7 @@ def _linear_solve_batching_rule(args, dims, axis_name, main_type, const_lengths,
(matvec, vecmat, solve, solve_t) = jaxprs
(matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat
num_aux = len(solve.out_avals) - len(matvec.out_avals)
# Fixpoint computation of which parts of x and b are batched; we need to
# ensure this is consistent between all four jaxprs
b_bat = orig_b_bat
@@ -2318,7 +2367,9 @@ def _linear_solve_batching_rule(args, dims, axis_name, main_type, const_lengths,
vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
vecmat, size, vecmat_bat + b_bat, instantiate=x_bat,
axis_name=axis_name, main_type=main_type)
x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat)
# batch all aux data by default
x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat)
# Apply matvec and solve_t -> new batched parts of b
matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
matvec, size, matvec_bat + x_bat_out, instantiate=b_bat,
@@ -2360,7 +2411,7 @@ def _linear_solve_batching_rule(args, dims, axis_name, main_type, const_lengths,
*(new_params + new_b),
const_lengths=const_lengths,
jaxprs=batched_jaxprs)
out_dims = [0 if batched else batching.not_mapped for batched in b_bat]
out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat]
return outs, out_dims
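
Why ``out_dims`` now follows ``solve_x_bat`` rather than ``b_bat``: the primitive's outputs are exactly the outputs of ``solve`` (solution leaves plus aux leaves), and the aux leaves are instantiated as batched even when ``b`` is not. A toy sketch of the final mapping (hypothetical flags, not from this commit):

    not_mapped = None                       # stands in for batching.not_mapped
    solve_x_bat = [True, False] + [True]    # solution flags + one always-batched aux flag
    out_dims = [0 if batched else not_mapped for batched in solve_x_bat]
    assert out_dims == [0, None, 0]
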

tests/lax_control_flow_test.py

@@ -2067,6 +2067,50 @@ class LaxControlFlowTest(jtu.JaxTestCase):
actual = api.vmap(linear_solve, (None, 1), 1)(a, c)
self.assertAllClose(expected, actual)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_linear_solve_aux(self):
def explicit_jacobian_solve_aux(matvec, b):
x = lax.stop_gradient(jnp.linalg.solve(api.jacobian(matvec)(b), b))
return x, array_aux
def matrix_free_solve_aux(matvec, b):
return lax.custom_linear_solve(
matvec, b, explicit_jacobian_solve_aux, explicit_jacobian_solve_aux,
symmetric=True, has_aux=True)
def linear_solve_aux(a, b):
return matrix_free_solve_aux(partial(high_precision_dot, a), b)
# array aux values, to be able to use jtu.check_grads
array_aux = {"converged": np.array(1.), "nfev": np.array(12345.)}
rng = np.random.RandomState(0)
a = rng.randn(3, 3)
a = a + a.T
b = rng.randn(3)
expected = jnp.linalg.solve(a, b)
actual_nojit, nojit_aux = linear_solve_aux(a, b)
actual_jit, jit_aux = api.jit(linear_solve_aux)(a, b)
self.assertAllClose(expected, actual_nojit)
self.assertAllClose(expected, actual_jit)
# scalar dict equality check
self.assertDictEqual(nojit_aux, array_aux)
self.assertDictEqual(jit_aux, array_aux)
# jvp / vjp test
jtu.check_grads(linear_solve_aux, (a, b), order=2, rtol=2e-3)
# vmap test
c = rng.randn(3, 2)
expected = jnp.linalg.solve(a, c)
expected_aux = tree_util.tree_map(partial(np.repeat, repeats=2), array_aux)
actual_vmap, vmap_aux = api.vmap(linear_solve_aux, (None, 1), -1)(a, c)
self.assertAllClose(expected, actual_vmap)
jtu.check_eq(expected_aux, vmap_aux)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_custom_linear_solve_zeros(self):
def explicit_jacobian_solve(matvec, b):