Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Merge pull request #6726 from njunge94:auxiliary_solver_data

PiperOrigin-RevId: 376899659
Commit: edd203e305
@@ -49,7 +49,7 @@ from jax._src.util import (partial, unzip2, unzip3, safe_map, safe_zip,
                            split_list, cache, extend_name_stack)
 from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                            treedef_children, treedef_tuple, tree_multimap,
-                           tree_leaves)
+                           tree_leaves, tree_structure)
 from jax import ad_util
 from jax.config import config
@@ -1954,7 +1954,19 @@ def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
                     f"{tree_unflatten(tree2, avals2)}.")


-def _check_tree(func_name, expected_name, actual_tree, expected_tree):
+def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False):
+  if has_aux:
+    actual_tree_children = actual_tree.children()
+
+    if len(actual_tree_children) == 2:
+      # select first child as result tree
+      actual_tree = tree_structure(actual_tree_children[0])
+    else:
+      raise ValueError(
+          f"{func_name}() produced a pytree with structure "
+          f"{actual_tree}, but a pytree tuple with auxiliary "
+          f"output was expected because has_aux was set to True.")
+
   if actual_tree != expected_tree:
     raise TypeError(
       f"{func_name}() output pytree structure must match {expected_name}, "
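Note: the pytree helpers used above behave as in this standalone sketch; the out value is a hypothetical (solution, aux) pair, not taken from the commit:

from jax.tree_util import tree_structure

# A solve() called with has_aux=True returns a (solution, aux) pair,
# so its treedef has exactly two children; the first child is the
# treedef of the solution alone, which _check_tree compares against b.
out = ({"x": 1.0}, {"nfev": 3.0})   # hypothetical output pytree
children = tree_structure(out).children()
assert len(children) == 2
assert children[0] == tree_structure({"x": 1.0})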
@@ -2141,7 +2153,7 @@ def _check_shapes(func_name, expected_name, actual, expected):

 @api_boundary
 def custom_linear_solve(
-    matvec, b, solve, transpose_solve=None, symmetric=False):
+    matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False):
   """Perform a matrix-free linear solve with implicitly defined gradients.

   This function allows for overriding or defining gradients for a linear
@@ -2160,7 +2172,7 @@ def custom_linear_solve(
     b: constant right hand side of the equation. May be any nested structure
       of arrays.
     solve: higher level function that solves for solution to the linear
-      equation, i.e., ``solve(matvec, x)) == x`` for all ``x`` of the same form
+      equation, i.e., ``solve(matvec, x) == x`` for all ``x`` of the same form
       as ``b``. This function need not be differentiable.
     transpose_solve: higher level function for solving the transpose linear
       equation, i.e., ``transpose_solve(vecmat, x) == x``, where ``vecmat`` is
@@ -2169,10 +2181,12 @@ def custom_linear_solve(
       ``symmetric=True``, in which case ``solve`` provides the default value.
     symmetric: bool indicating if it is safe to assume the linear map
       corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.
+    has_aux: bool indicating whether the ``solve`` and ``transpose_solve`` functions
+      return auxiliary data like solver diagnostics as a second argument.

   Returns:
     Result of ``solve(matvec, b)``, with gradients defined assuming that the
     solution ``x`` satisfies the linear equation ``matvec(x) == b``.
   """
   if transpose_solve is None and symmetric:
     transpose_solve = solve
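Note: a minimal usage sketch of the new flag, modeled on the test added at the end of this commit. The dense fallback solver and the diagnostics dict are illustrative, not part of the library:

import jax
import jax.numpy as jnp
from jax import lax

def dense_solve_aux(matvec, b):
  # Materialize the matrix from the linear map, solve densely, and
  # return illustrative diagnostics as the auxiliary output.
  a = jax.jacobian(matvec)(b)
  return lax.stop_gradient(jnp.linalg.solve(a, b)), {"nfev": jnp.array(1.0)}

def solve_with_aux(a, b):
  return lax.custom_linear_solve(
      lambda v: a @ v, b, dense_solve_aux, dense_solve_aux,
      symmetric=True, has_aux=True)

a = jnp.array([[2.0, 1.0], [1.0, 3.0]])  # symmetric, so solve also serves
b = jnp.array([1.0, 2.0])                # as transpose_solve
x, aux = jax.jit(solve_with_aux)(a, b)   # aux passes through jit unchanged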
@@ -2182,22 +2196,29 @@ def custom_linear_solve(

   tree, = treedef_children(in_args_tree)

-  def _shape_checked(fun, name):
+  def _shape_checked(fun, name, has_aux):
     def f(x):
       y = fun(x)
       _check_shapes(name, "b", y, b_flat)
       return y
-    return f
+
+    def f_aux(x):
+      y, aux = fun(x)
+      _check_shapes(name, "b", y, b_flat)
+      return y, aux
+
+    return f_aux if has_aux else f

+  # no auxiliary data assumed for matvec
   matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
-      _shape_checked(matvec, "matvec"), in_args_tree, b_avals,
+      _shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
       'custom_linear_solve')
-  _check_tree("matvec", "b", out_tree, tree)
+  _check_tree("matvec", "b", out_tree, tree, False)

   solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
-      _shape_checked(partial(solve, matvec), "solve"), in_args_tree, b_avals,
+      _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
       'custom_linear_solve')
-  _check_tree("solve", "b", out_tree, tree)
+  _check_tree("solve", "b", out_tree, tree, has_aux)

   if transpose_solve is None:
     vecmat_jaxpr = tr_solve_jaxpr = None
@@ -2214,9 +2235,9 @@ def custom_linear_solve(
       assert out_tree == tree

     tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
-        _shape_checked(partial(transpose_solve, vecmat), "transpose_solve"),
+        _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux),
         in_args_tree, b_avals, 'custom_linear_solve')
-    _check_tree("transpose_solve", "b", out_tree, tree)
+    _check_tree("transpose_solve", "b", out_tree, tree, has_aux)

   all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
   const_lengths = _LinearSolveTuple(*_map(len, all_consts))
@@ -2226,11 +2247,21 @@ def custom_linear_solve(
   out_flat = linear_solve_p.bind(
       *(_flatten(all_consts) + b_flat),
       const_lengths=const_lengths, jaxprs=jaxprs)
-  return tree_unflatten(tree, out_flat)
+
+  return tree_unflatten(out_tree, out_flat)


 def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
-  return _map(raise_to_shaped, args[sum(const_lengths):])
+  args_to_raise = args[sum(const_lengths):]
+
+  # raise aux_args to shaped arrays as well if present
+  # number of aux args is the difference in out_avals
+  # of solve and matvec (since they map to the same vector space)
+
+  num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals)
+  if num_aux > 0:
+    args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:])
+  return _map(raise_to_shaped, args_to_raise)


 def _custom_linear_solve_impl(*args, const_lengths, jaxprs):
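Note: the counting trick above works because matvec and solve map into the same vector space, so any extra flat outputs of solve must be aux leaves. A sketch with hypothetical stand-ins for the abstract values:

# Hypothetical flattened output avals (strings standing in for ShapedArrays).
matvec_out_avals = ["f32[3]"]                    # solution leaves only
solve_out_avals = ["f32[3]", "f32[]", "f32[]"]   # solution plus two aux leaves

num_aux = len(solve_out_avals) - len(matvec_out_avals)
assert num_aux == 2
assert solve_out_avals[-num_aux:] == ["f32[]", "f32[]"]  # the aux entries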
@@ -2263,16 +2294,29 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs):
   params, _ = _split_linear_solve_args(primals, const_lengths)
   params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)

+  num_x_leaves = len(b_dot)
+  # x is a flat tree with possible aux values appended
+  # since x_tree == b_tree == b_dot_tree, we can cut off
+  # aux values with len info provided by b_dot tree here
+  x_leaves, _ = split_list(x, [num_x_leaves])
+
   if all(type(p) is ad_util.Zero for p in params_dot.matvec):
     # no need to evaluate matvec_tangents
     rhs = b_dot
   else:
     matvec_tangents = _tangent_linear_map(
-        core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x)
+        core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x_leaves)
     rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))

   x_dot = linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)

+  # split into x tangents and aux tangents (these become zero)
+  dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves])
+
+  daux_leaves = _map(ad_util.Zero.from_value, daux_leaves)
+
+  x_dot = dx_leaves + daux_leaves
+
   return x, x_dot

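Note: as a consequence, forward-mode differentiation propagates tangents only through the solution, and the auxiliary output receives zero tangents. Reusing the illustrative solve_with_aux, a, and b from the sketch above:

(x, aux), (dx, daux) = jax.jvp(
    solve_with_aux, (a, b), (jnp.zeros_like(a), jnp.ones_like(b)))
# dx solves a @ dx = db, while daux is instantiated to zeros and
# carries no derivative information.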
@@ -2282,10 +2326,14 @@ def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
       'differentiation of custom_linear_solve')

   params, b = _split_linear_solve_args(primals, const_lengths)
+  # split off symbolic zeros in the cotangent if present
+  x_cotangent, _ = split_list(cotangent, [len(b)])
   assert all(ad.is_undefined_primal(x) for x in b)
-  cotangent_b = linear_solve_p.bind(
-      *(_flatten(params.transpose()) + cotangent),
+  cotangent_b_full = linear_solve_p.bind(
+      *(_flatten(params.transpose()) + x_cotangent),
       const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose())
+  # drop aux values in cotangent computation
+  cotangent_b, _ = split_list(cotangent_b_full, [len(b)])
   return [None] * sum(const_lengths) + cotangent_b

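Note: the reverse-mode mirror of the JVP rule: aux cotangents are dropped before the transposed solve, so gradients of any function of the solution are unaffected by the auxiliary output. Again reusing the illustrative solve_with_aux, a, and b from above:

def loss(a, b):
  x, _ = solve_with_aux(a, b)   # aux is returned but not differentiated
  return jnp.sum(x)

db = jax.grad(loss, argnums=1)(a, b)  # runs the transposed solve internally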
@@ -2302,6 +2350,7 @@ def _linear_solve_batching_rule(args, dims, axis_name, main_type, const_lengths,
   (matvec, vecmat, solve, solve_t) = jaxprs
   (matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat

+  num_aux = len(solve.out_avals) - len(matvec.out_avals)
   # Fixpoint computation of which parts of x and b are batched; we need to
   # ensure this is consistent between all four jaxprs
   b_bat = orig_b_bat
@@ -2318,7 +2367,9 @@ def _linear_solve_batching_rule(args, dims, axis_name, main_type, const_lengths,
       vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
           vecmat, size, vecmat_bat + b_bat, instantiate=x_bat,
           axis_name=axis_name, main_type=main_type)
-      x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat)
+      # batch all aux data by default
+      x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux,
+                       solve_x_bat)

     # Apply matvec and solve_t -> new batched parts of b
     matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
         matvec, size, matvec_bat + x_bat_out, instantiate=b_bat,
@@ -2360,7 +2411,7 @@ def _linear_solve_batching_rule(args, dims, axis_name, main_type, const_lengths,
       *(new_params + new_b),
       const_lengths=const_lengths,
       jaxprs=batched_jaxprs)
-  out_dims = [0 if batched else batching.not_mapped for batched in b_bat]
+  out_dims = [0 if batched else batching.not_mapped for batched in solve_x_bat]
   return outs, out_dims

@@ -2067,6 +2067,50 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     actual = api.vmap(linear_solve, (None, 1), 1)(a, c)
     self.assertAllClose(expected, actual)

+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
+  def test_custom_linear_solve_aux(self):
+    def explicit_jacobian_solve_aux(matvec, b):
+      x = lax.stop_gradient(jnp.linalg.solve(api.jacobian(matvec)(b), b))
+      return x, array_aux
+
+    def matrix_free_solve_aux(matvec, b):
+      return lax.custom_linear_solve(
+          matvec, b, explicit_jacobian_solve_aux, explicit_jacobian_solve_aux,
+          symmetric=True, has_aux=True)
+
+    def linear_solve_aux(a, b):
+      return matrix_free_solve_aux(partial(high_precision_dot, a), b)
+
+    # array aux values, to be able to use jtu.check_grads
+    array_aux = {"converged": np.array(1.), "nfev": np.array(12345.)}
+
+    rng = np.random.RandomState(0)
+    a = rng.randn(3, 3)
+    a = a + a.T
+    b = rng.randn(3)
+
+    expected = jnp.linalg.solve(a, b)
+    actual_nojit, nojit_aux = linear_solve_aux(a, b)
+    actual_jit, jit_aux = api.jit(linear_solve_aux)(a, b)
+
+    self.assertAllClose(expected, actual_nojit)
+    self.assertAllClose(expected, actual_jit)
+    # scalar dict equality check
+    self.assertDictEqual(nojit_aux, array_aux)
+    self.assertDictEqual(jit_aux, array_aux)
+
+    # jvp / vjp test
+    jtu.check_grads(linear_solve_aux, (a, b), order=2, rtol=2e-3)
+
+    # vmap test
+    c = rng.randn(3, 2)
+    expected = jnp.linalg.solve(a, c)
+    expected_aux = tree_util.tree_map(partial(np.repeat, repeats=2), array_aux)
+    actual_vmap, vmap_aux = api.vmap(linear_solve_aux, (None, 1), -1)(a, c)
+
+    self.assertAllClose(expected, actual_vmap)
+    jtu.check_eq(expected_aux, vmap_aux)
+
   @jtu.skip_on_flag("jax_skip_slow_tests", True)
   def test_custom_linear_solve_zeros(self):
     def explicit_jacobian_solve(matvec, b):