Merge pull request #25787 from dfm:tri-diag-jvp

PiperOrigin-RevId: 714109627
jax authors 2025-01-10 11:03:18 -08:00
commit 016fca79ca
2 changed files with 94 additions and 53 deletions


@@ -2568,6 +2568,26 @@ def _tridiagonal_solve_cpu_lowering(ctx, dl, d, du, b, **kwargs):
           b_out, b_aval, _nan_like_hlo(ctx, b_aval), b_aval)]
 
+
+def _tridiagonal_product(dl, d, du, b):
+  y = lax.reshape(d, d.shape + (1,)) * b
+  y = y.at[..., 1:, :].add(dl[..., 1:, None] * b[..., :-1, :])
+  y = y.at[..., :-1, :].add(du[..., :-1, None] * b[..., 1:, :])
+  return y
+
+
+def _tridiagonal_solve_jvp_rule(primals, tangents):
+  *diags, _ = primals
+  *diags_dot, b_dot = tangents
+  ans = tridiagonal_solve_p.bind(*primals)
+  if all(type(p) is ad_util.Zero for p in diags_dot):
+    rhs = b_dot
+  else:
+    matvec_dot = _tridiagonal_product(*map(ad.instantiate_zeros, diags_dot), ans)
+    rhs = ad.add_tangents(b_dot, -matvec_dot)
+  ans_dot = tridiagonal_solve_p.bind(*diags, rhs)
+  return ans, ans_dot
+
+
 def _tridiagonal_solve_transpose_rule(cotangent, dl, d, du, b):
   # Tridiagonal solve is nonlinear in the tridiagonal arguments and linear
   # otherwise.
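The JVP rule added above is the implicit-function derivative of A x = b: differentiating gives A dx + dA x = db, so dx = A^{-1}(db - dA x), where dA x is the tridiagonal matrix-vector product computed by _tridiagonal_product and the final solve reuses the primal diagonals. Below is a minimal standalone check of that identity against a dense reference; the shapes, seed, and tolerances are illustrative (not part of this change) and assume a reasonably well-conditioned system.

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

n = 5
keys = jax.random.split(jax.random.PRNGKey(0), 8)
dl, du, ddl, dd, ddu = (jax.random.normal(k, (n,)) for k in keys[:5])
d = 5.0 + jax.random.normal(keys[5], (n,))  # keep the system well conditioned
b, db = (jax.random.normal(k, (n, 1)) for k in keys[6:])
dl = dl.at[0].set(0.0)   # dl[0] and du[-1] are padding in this convention
du = du.at[-1].set(0.0)

x, x_dot = jax.jvp(lax.linalg.tridiagonal_solve,
                   (dl, d, du, b), (ddl, dd, ddu, db))

# Dense reference: x = A^{-1} b and dx = A^{-1} (db - dA x).
a = jnp.diag(d) + jnp.diag(dl[1:], -1) + jnp.diag(du[:-1], 1)
da = jnp.diag(dd) + jnp.diag(ddl[1:], -1) + jnp.diag(ddu[:-1], 1)
x_ref = jnp.linalg.solve(a, b)
np.testing.assert_allclose(x, x_ref, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(x_dot, jnp.linalg.solve(a, db - da @ x_ref),
                           rtol=1e-4, atol=1e-4)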
@@ -2576,7 +2596,11 @@ def _tridiagonal_solve_transpose_rule(cotangent, dl, d, du, b):
   if type(cotangent) is ad_util.Zero:
     cotangent_b = ad_util.Zero(b.aval)
   else:
-    cotangent_b = tridiagonal_solve(dl, d, du, cotangent)
+    dl_trans = lax.concatenate((lax.zeros_like_array(du[..., -1:]), du[..., :-1]),
+                               du.ndim-1)
+    du_trans = lax.concatenate((dl[..., 1:], lax.zeros_like_array(dl[..., :1])),
+                               dl.ndim-1)
+    cotangent_b = tridiagonal_solve(dl_trans, d, du_trans, cotangent)
   return [None, None, None, cotangent_b]
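The corrected transpose rule solves against A^T rather than A: transposing a tridiagonal matrix swaps the sub- and super-diagonals, and in the dl/d/du storage convention (where dl[0] and du[-1] are padding) that swap is exactly the shift performed by dl_trans and du_trans. A small sketch of the identity with made-up values, separate from the diff:

import jax.numpy as jnp

def dense(dl, d, du):
  # Dense counterpart of the dl/d/du operator; dl[0] and du[-1] are padding.
  return jnp.diag(d) + jnp.diag(dl[1:], -1) + jnp.diag(du[:-1], 1)

dl = jnp.array([0.0, 2.0, 3.0, 4.0])
d = jnp.ones(4)
du = jnp.array([5.0, 6.0, 7.0, 0.0])
dl_trans = jnp.concatenate([jnp.zeros(1), du[:-1]])  # sub-diagonal of A^T
du_trans = jnp.concatenate([dl[1:], jnp.zeros(1)])   # super-diagonal of A^T
assert jnp.allclose(dense(dl, d, du).T, dense(dl_trans, d, du_trans))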
@@ -2605,9 +2629,9 @@ def _tridiagonal_solve_batching_rule(batched_args, batch_dims):
 tridiagonal_solve_p = standard_primitive(
     _tridiagonal_solve_shape_rule, _tridiagonal_solve_dtype_rule,
     'tridiagonal_solve')
+ad.primitive_jvps[tridiagonal_solve_p] = _tridiagonal_solve_jvp_rule
 ad.primitive_transposes[tridiagonal_solve_p] = _tridiagonal_solve_transpose_rule
 batching.primitive_batchers[tridiagonal_solve_p] = _tridiagonal_solve_batching_rule
-# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
 
 mlir.register_lowering(
     tridiagonal_solve_p,
@@ -2623,50 +2647,32 @@ mlir.register_lowering(
     platform='rocm')
 
-def _tridiagonal_solve_jax(dl, d, du, b, **kw):
-  """Pure JAX implementation of `tridiagonal_solve`."""
-  def prepend_zero(x):
-    return lax.concatenate(
-        [lax.full((1,) + x.shape[1:], 0, dtype=x.dtype), x[:-1]], dimension=0)
-  fwd1 = lambda tu_, x: x[1] / (x[0] - x[2] * tu_)
+def _tridiagonal_solve_jax_impl(dl, d, du, b):
+  def fwd(carry, args):
+    cp, dp = carry
+    a, b, c, d = args
+    cp_next = c / (b - a * cp)
+    dp_next = (d - a * dp) / (b - a * cp)
+    return (cp_next, dp_next), (cp, dp)
 
-  def fwd2(b_, x):
-    return (x[0] - x[3][np.newaxis, ...] * b_) / (
-        x[1] - x[3] * x[2])[np.newaxis, ...]
+  (_, final), (cp, dp) = lax.scan(
+      fwd, (du[0] / d[0], b[0] / d[0]), (dl[1:], d[1:], du[1:], b[1:, :]),
+      unroll=32)
 
-  bwd1 = lambda x_, x: x[0] - x[1][np.newaxis, ...] * x_
-  double = lambda f, args: (f(*args), f(*args))
+  def bwd(xn, args):
+    cp, dp = args
+    x = dp - cp * xn
+    return x, xn
 
-  # Move relevant dimensions to the front for the scan.
-  moveaxis_fwd = lambda x: lax.transpose(x, (x.ndim - 1, *range(x.ndim - 1)))
-  moveaxis_bwd = lambda x: lax.transpose(x, (*range(1, x.ndim), 0))
-
-  dl = moveaxis_fwd(dl)
-  d = moveaxis_fwd(d)
-  du = moveaxis_fwd(du)
-  b = moveaxis_fwd(b)
-  b = moveaxis_fwd(b)
+  end, ans = lax.scan(bwd, final, (cp, dp), unroll=32, reverse=True)
+  return lax.concatenate((end[None], ans), 0)
 
-  # Forward pass.
-  _, tu_ = lax.scan(lambda tu_, x: double(fwd1, (tu_, x)),
-                    du[0] / d[0],
-                    (d, du, dl),
-                    unroll=32)
-
-  _, b_ = lax.scan(lambda b_, x: double(fwd2, (b_, x)),
-                   b[0] / d[0:1],
-                   (b, d, prepend_zero(tu_), dl),
-                   unroll=32)
-
-  # Backsubstitution.
-  _, x_ = lax.scan(lambda x_, x: double(bwd1, (x_, x)),
-                   b_[-1],
-                   (b_[::-1], tu_[::-1]),
-                   unroll=32)
-
-  result = x_[::-1]
-  result = moveaxis_bwd(result)
-  result = moveaxis_bwd(result)
-  return result
+
+def _tridiagonal_solve_jax(dl, d, du, b, **_):
+  impl = _tridiagonal_solve_jax_impl
+  for _ in range(dl.ndim - 1):
+    impl = api.vmap(impl)
+  return impl(dl, d, du, b)
 
 mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun(
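The rewritten pure-JAX path splits the solver into a single-system core (_tridiagonal_solve_jax_impl) and a wrapper that vmaps it once per leading batch dimension, rather than transposing batch axes around the scans. The fwd/bwd scans are the Thomas algorithm: forward elimination cp[i] = c[i] / (b[i] - a[i]*cp[i-1]) and dp[i] = (rhs[i] - a[i]*dp[i-1]) / (b[i] - a[i]*cp[i-1]), followed by back-substitution x[n-1] = dp[n-1], x[i] = dp[i] - cp[i]*x[i+1]. A plain-NumPy transcription of those recurrences, checked against a dense solve, is sketched below; it is illustrative only and assumes a well-conditioned system.

import numpy as np

def thomas(a, b, c, rhs):
  # a: sub-diagonal (a[0] ignored), b: diagonal, c: super-diagonal
  # (c[-1] does not affect the result).
  n = len(b)
  cp, dp = np.zeros(n), np.zeros(n)
  cp[0], dp[0] = c[0] / b[0], rhs[0] / b[0]
  for i in range(1, n):
    denom = b[i] - a[i] * cp[i - 1]
    cp[i] = c[i] / denom
    dp[i] = (rhs[i] - a[i] * dp[i - 1]) / denom
  x = np.zeros(n)
  x[-1] = dp[-1]
  for i in range(n - 2, -1, -1):
    x[i] = dp[i] - cp[i] * x[i + 1]
  return x

rng = np.random.default_rng(0)
n = 6
a, c, rhs = rng.normal(size=n), rng.normal(size=n), rng.normal(size=n)
b = 5.0 + rng.normal(size=n)  # keep the system comfortably non-singular
dense = np.diag(b) + np.diag(a[1:], -1) + np.diag(c[:-1], 1)
np.testing.assert_allclose(thomas(a, b, c, rhs), np.linalg.solve(dense, rhs),
                           rtol=1e-8, atol=1e-8)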


@@ -2184,20 +2184,55 @@ class LaxLinalgTest(jtu.JaxTestCase):
     self.assertAllClose(
         eigvals_all[first:(last + 1)], eigvals_index, atol=atol)
 
-  @jtu.sample_product(dtype=float_types + complex_types)
-  def test_tridiagonal_solve(self, dtype):
+  @jtu.sample_product(shape=[(3,), (3, 4), (3, 4, 5)],
+                      dtype=float_types + complex_types)
+  def test_tridiagonal_solve(self, shape, dtype):
     if dtype not in float_types and jtu.test_device_matches(["gpu"]):
       self.skipTest("Data type not supported on GPU")
-    dl = np.array([0.0, 2.0, 3.0], dtype=dtype)
-    d = np.ones(3, dtype=dtype)
-    du = np.array([1.0, 2.0, 0.0], dtype=dtype)
-    m = 3
-    B = np.ones([m, 1], dtype=dtype)
-    X = lax.linalg.tridiagonal_solve(dl, d, du, B)
-    A = np.eye(3, dtype=dtype)
-    A[[1, 2], [0, 1]] = dl[1:]
-    A[[0, 1], [1, 2]] = du[:-1]
-    np.testing.assert_allclose(A @ X, B, rtol=1e-6, atol=1e-6)
+    rng = self.rng()
+    d = 1.0 + jtu.rand_positive(rng)(shape, dtype)
+    dl = jtu.rand_default(rng)(shape, dtype)
+    du = jtu.rand_default(rng)(shape, dtype)
+    b = jtu.rand_default(rng)(shape + (1,), dtype)
+    x = lax.linalg.tridiagonal_solve(dl, d, du, b)
+
+    def build_tri(dl, d, du):
+      return jnp.diag(d) + jnp.diag(dl[1:], -1) + jnp.diag(du[:-1], 1)
+    for _ in shape[:-1]:
+      build_tri = jax.vmap(build_tri)
+
+    a = build_tri(dl, d, du)
+    self.assertAllClose(a @ x, b, atol=5e-5, rtol=1e-4)
+
+  def test_tridiagonal_solve_endpoints(self):
+    # tridiagonal_solve shouldn't depend on the endpoints being explicitly zero.
+    dtype = np.float32
+    size = 10
+    dl = np.linspace(-1.0, 1.0, size, dtype=dtype)
+    dlz = np.copy(dl)
+    dlz[0] = 0.0
+    d = np.linspace(1.0, 2.0, size, dtype=dtype)
+    du = np.linspace(1.0, -1.0, size, dtype=dtype)
+    duz = np.copy(du)
+    duz[-1] = 0.0
+    b = np.linspace(0.1, -0.1, size, dtype=dtype)[:, None]
+    self.assertAllClose(
+        lax.linalg.tridiagonal_solve(dl, d, du, b),
+        lax.linalg.tridiagonal_solve(dlz, d, duz, b),
+    )
+
+  @jtu.sample_product(shape=[(3,), (3, 4)], dtype=float_types + complex_types)
+  def test_tridiagonal_solve_grad(self, shape, dtype):
+    if dtype not in float_types and jtu.test_device_matches(["gpu"]):
+      self.skipTest("Data type not supported on GPU")
+    rng = self.rng()
+    d = 1.0 + jtu.rand_positive(rng)(shape, dtype)
+    dl = jtu.rand_default(rng)(shape, dtype)
+    du = jtu.rand_default(rng)(shape, dtype)
+    b = jtu.rand_default(rng)(shape + (1,), dtype)
+    args = (dl, d, du, b)
+    jtu.check_grads(lax.linalg.tridiagonal_solve, args, order=2, atol=1e-1,
+                    rtol=1e-1)
 
   @jtu.sample_product(
       shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
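The new test_tridiagonal_solve_grad exercises both forward- and reverse-mode differentiation through the JVP and transpose rules via jtu.check_grads at order=2. As a usage sketch, with these rules registered, jax.grad now works end to end through tridiagonal_solve for all four arguments; the loss function and values below are hypothetical and only illustrate the call pattern.

import jax
import jax.numpy as jnp
from jax import lax

def loss(dl, d, du, b):
  # Scalar function of the solution; any differentiable reduction works here.
  x = lax.linalg.tridiagonal_solve(dl, d, du, b)
  return jnp.sum(x ** 2)

dl = jnp.array([0.0, 0.5, -0.3, 0.2])
d = jnp.full((4,), 2.0)
du = jnp.array([0.1, -0.4, 0.6, 0.0])
b = jnp.ones((4, 1))
g_dl, g_d, g_du, g_b = jax.grad(loss, argnums=(0, 1, 2, 3))(dl, d, du, b)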