Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

commit 016fca79ca
Merge pull request #25787 from dfm:tri-diag-jvp

PiperOrigin-RevId: 714109627
@@ -2568,6 +2568,26 @@ def _tridiagonal_solve_cpu_lowering(ctx, dl, d, du, b, **kwargs):
       b_out, b_aval, _nan_like_hlo(ctx, b_aval), b_aval)]
 
 
+def _tridiagonal_product(dl, d, du, b):
+  y = lax.reshape(d, d.shape + (1,)) * b
+  y = y.at[..., 1:, :].add(dl[..., 1:, None] * b[..., :-1, :])
+  y = y.at[..., :-1, :].add(du[..., :-1, None] * b[..., 1:, :])
+  return y
+
+
+def _tridiagonal_solve_jvp_rule(primals, tangents):
+  *diags, _ = primals
+  *diags_dot, b_dot = tangents
+  ans = tridiagonal_solve_p.bind(*primals)
+  if all(type(p) is ad_util.Zero for p in diags_dot):
+    rhs = b_dot
+  else:
+    matvec_dot = _tridiagonal_product(*map(ad.instantiate_zeros, diags_dot), ans)
+    rhs = ad.add_tangents(b_dot, -matvec_dot)
+  ans_dot = tridiagonal_solve_p.bind(*diags, rhs)
+  return ans, ans_dot
+
+
 def _tridiagonal_solve_transpose_rule(cotangent, dl, d, du, b):
   # Tridiagonal solve is nonlinear in the tridiagonal arguments and linear
   # otherwise.
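For reference, the new JVP rule is the standard differentiation of a linear solve. Writing the tridiagonal system as A x = b and differentiating,

  A \dot{x} + \dot{A} x = \dot{b}
  \quad\Longrightarrow\quad
  \dot{x} = A^{-1}\,(\dot{b} - \dot{A}\,x),

where \dot{A} x is the banded matrix-vector product that _tridiagonal_product evaluates from the tangent diagonals, and the remaining application of A^{-1} is the second tridiagonal_solve_p.bind against the primal diagonals.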
@@ -2576,7 +2596,11 @@ def _tridiagonal_solve_transpose_rule(cotangent, dl, d, du, b):
   if type(cotangent) is ad_util.Zero:
     cotangent_b = ad_util.Zero(b.aval)
   else:
-    cotangent_b = tridiagonal_solve(dl, d, du, cotangent)
+    dl_trans = lax.concatenate((lax.zeros_like_array(du[..., -1:]), du[..., :-1]),
+                               du.ndim-1)
+    du_trans = lax.concatenate((dl[..., 1:], lax.zeros_like_array(dl[..., :1])),
+                               dl.ndim-1)
+    cotangent_b = tridiagonal_solve(dl_trans, d, du_trans, cotangent)
   return [None, None, None, cotangent_b]
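The change reflects that the transpose of the linear map x -> A^{-1} x is y -> A^{-T} y: the cotangent has to be solved against the transposed tridiagonal matrix, whose sub- and super-diagonals are the original super- and sub-diagonals shifted by one row. A small NumPy check of that shifted-diagonal construction, for illustration only (build_dense is a hypothetical helper using the same convention as the tests, in which dl[0] and du[-1] are ignored; the sample values mirror the old test case):

  import numpy as np

  def build_dense(dl, d, du):
    # Dense tridiagonal matrix; dl[0] and du[-1] are ignored by convention.
    return np.diag(d) + np.diag(dl[1:], -1) + np.diag(du[:-1], 1)

  dl = np.array([0.0, 2.0, 3.0])
  d = np.ones(3)
  du = np.array([1.0, 2.0, 0.0])

  # The shifted diagonals built by the transpose rule, written in NumPy form.
  dl_trans = np.concatenate([np.zeros(1), du[:-1]])
  du_trans = np.concatenate([dl[1:], np.zeros(1)])

  assert np.allclose(build_dense(dl_trans, d, du_trans), build_dense(dl, d, du).T)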
@@ -2605,9 +2629,9 @@ def _tridiagonal_solve_batching_rule(batched_args, batch_dims):
 tridiagonal_solve_p = standard_primitive(
     _tridiagonal_solve_shape_rule, _tridiagonal_solve_dtype_rule,
     'tridiagonal_solve')
+ad.primitive_jvps[tridiagonal_solve_p] = _tridiagonal_solve_jvp_rule
 ad.primitive_transposes[tridiagonal_solve_p] = _tridiagonal_solve_transpose_rule
 batching.primitive_batchers[tridiagonal_solve_p] = _tridiagonal_solve_batching_rule
-# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
 
 mlir.register_lowering(
     tridiagonal_solve_p,
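With the JVP rule registered next to the transpose and batching rules, tridiagonal_solve becomes differentiable with respect to the diagonals as well as the right-hand side. A minimal reverse-mode usage sketch, with illustrative values that are not part of this commit:

  import jax
  import jax.numpy as jnp
  from jax import lax

  def loss(d):
    dl = jnp.array([0.0, 2.0, 3.0])
    du = jnp.array([1.0, 2.0, 0.0])
    b = jnp.ones((3, 1))
    # Solve the tridiagonal system and reduce to a scalar.
    x = lax.linalg.tridiagonal_solve(dl, d, du, b)
    return jnp.sum(x ** 2)

  print(jax.grad(loss)(jnp.array([4.0, 5.0, 6.0])))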
@@ -2623,50 +2647,32 @@ mlir.register_lowering(
     platform='rocm')
 
 
-def _tridiagonal_solve_jax(dl, d, du, b, **kw):
-  """Pure JAX implementation of `tridiagonal_solve`."""
-  def prepend_zero(x):
-    return lax.concatenate(
-        [lax.full((1,) + x.shape[1:], 0, dtype=x.dtype), x[:-1]], dimension=0)
-  fwd1 = lambda tu_, x: x[1] / (x[0] - x[2] * tu_)
-
-  def fwd2(b_, x):
-    return (x[0] - x[3][np.newaxis, ...] * b_) / (
-        x[1] - x[3] * x[2])[np.newaxis, ...]
-
-  bwd1 = lambda x_, x: x[0] - x[1][np.newaxis, ...] * x_
-  double = lambda f, args: (f(*args), f(*args))
-
-  # Move relevant dimensions to the front for the scan.
-  moveaxis_fwd = lambda x: lax.transpose(x, (x.ndim - 1, *range(x.ndim - 1)))
-  moveaxis_bwd = lambda x: lax.transpose(x, (*range(1, x.ndim), 0))
-  dl = moveaxis_fwd(dl)
-  d = moveaxis_fwd(d)
-  du = moveaxis_fwd(du)
-  b = moveaxis_fwd(b)
-  b = moveaxis_fwd(b)
-
-  # Forward pass.
-  _, tu_ = lax.scan(lambda tu_, x: double(fwd1, (tu_, x)),
-                    du[0] / d[0],
-                    (d, du, dl),
-                    unroll=32)
-
-  _, b_ = lax.scan(lambda b_, x: double(fwd2, (b_, x)),
-                   b[0] / d[0:1],
-                   (b, d, prepend_zero(tu_), dl),
-                   unroll=32)
-
-  # Backsubstitution.
-  _, x_ = lax.scan(lambda x_, x: double(bwd1, (x_, x)),
-                   b_[-1],
-                   (b_[::-1], tu_[::-1]),
-                   unroll=32)
-
-  result = x_[::-1]
-  result = moveaxis_bwd(result)
-  result = moveaxis_bwd(result)
-  return result
+def _tridiagonal_solve_jax_impl(dl, d, du, b):
+  def fwd(carry, args):
+    cp, dp = carry
+    a, b, c, d = args
+    cp_next = c / (b - a * cp)
+    dp_next = (d - a * dp) / (b - a * cp)
+    return (cp_next, dp_next), (cp, dp)
+
+  (_, final), (cp, dp) = lax.scan(
+      fwd, (du[0] / d[0], b[0] / d[0]), (dl[1:], d[1:], du[1:], b[1:, :]),
+      unroll=32)
+
+  def bwd(xn, args):
+    cp, dp = args
+    x = dp - cp * xn
+    return x, xn
+
+  end, ans = lax.scan(bwd, final, (cp, dp), unroll=32, reverse=True)
+  return lax.concatenate((end[None], ans), 0)
+
+
+def _tridiagonal_solve_jax(dl, d, du, b, **_):
+  impl = _tridiagonal_solve_jax_impl
+  for _ in range(dl.ndim - 1):
+    impl = api.vmap(impl)
+  return impl(dl, d, du, b)
 
 
 mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun(
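The replacement is the Thomas algorithm written as two lax.scan passes over the row dimension, with any leading batch dimensions handled by wrapping the core routine in api.vmap instead of the manual axis shuffling of the old version. Inside fwd, a, b, c, d are the sub-diagonal, main-diagonal, super-diagonal, and right-hand-side rows, so the scans implement, in sketch form,

  c'_0 = c_0 / b_0, \qquad d'_0 = d_0 / b_0,
  c'_i = \frac{c_i}{b_i - a_i c'_{i-1}}, \qquad
  d'_i = \frac{d_i - a_i d'_{i-1}}{b_i - a_i c'_{i-1}},

followed by the reverse back-substitution x_{n-1} = d'_{n-1} and x_i = d'_i - c'_i x_{i+1}.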
@@ -2184,20 +2184,55 @@ class LaxLinalgTest(jtu.JaxTestCase):
     self.assertAllClose(
         eigvals_all[first:(last + 1)], eigvals_index, atol=atol)
 
-  @jtu.sample_product(dtype=float_types + complex_types)
-  def test_tridiagonal_solve(self, dtype):
+  @jtu.sample_product(shape=[(3,), (3, 4), (3, 4, 5)],
+                      dtype=float_types + complex_types)
+  def test_tridiagonal_solve(self, shape, dtype):
     if dtype not in float_types and jtu.test_device_matches(["gpu"]):
       self.skipTest("Data type not supported on GPU")
-    dl = np.array([0.0, 2.0, 3.0], dtype=dtype)
-    d = np.ones(3, dtype=dtype)
-    du = np.array([1.0, 2.0, 0.0], dtype=dtype)
-    m = 3
-    B = np.ones([m, 1], dtype=dtype)
-    X = lax.linalg.tridiagonal_solve(dl, d, du, B)
-    A = np.eye(3, dtype=dtype)
-    A[[1, 2], [0, 1]] = dl[1:]
-    A[[0, 1], [1, 2]] = du[:-1]
-    np.testing.assert_allclose(A @ X, B, rtol=1e-6, atol=1e-6)
+    rng = self.rng()
+    d = 1.0 + jtu.rand_positive(rng)(shape, dtype)
+    dl = jtu.rand_default(rng)(shape, dtype)
+    du = jtu.rand_default(rng)(shape, dtype)
+    b = jtu.rand_default(rng)(shape + (1,), dtype)
+    x = lax.linalg.tridiagonal_solve(dl, d, du, b)
+
+    def build_tri(dl, d, du):
+      return jnp.diag(d) + jnp.diag(dl[1:], -1) + jnp.diag(du[:-1], 1)
+    for _ in shape[:-1]:
+      build_tri = jax.vmap(build_tri)
+
+    a = build_tri(dl, d, du)
+    self.assertAllClose(a @ x, b, atol=5e-5, rtol=1e-4)
+
+  def test_tridiagonal_solve_endpoints(self):
+    # tridagonal_solve shouldn't depend on the endpoints being explicitly zero.
+    dtype = np.float32
+    size = 10
+    dl = np.linspace(-1.0, 1.0, size, dtype=dtype)
+    dlz = np.copy(dl)
+    dlz[0] = 0.0
+    d = np.linspace(1.0, 2.0, size, dtype=dtype)
+    du = np.linspace(1.0, -1.0, size, dtype=dtype)
+    duz = np.copy(du)
+    duz[-1] = 0.0
+    b = np.linspace(0.1, -0.1, size, dtype=dtype)[:, None]
+    self.assertAllClose(
+        lax.linalg.tridiagonal_solve(dl, d, du, b),
+        lax.linalg.tridiagonal_solve(dlz, d, duz, b),
+    )
+
+  @jtu.sample_product(shape=[(3,), (3, 4)], dtype=float_types + complex_types)
+  def test_tridiagonal_solve_grad(self, shape, dtype):
+    if dtype not in float_types and jtu.test_device_matches(["gpu"]):
+      self.skipTest("Data type not supported on GPU")
+    rng = self.rng()
+    d = 1.0 + jtu.rand_positive(rng)(shape, dtype)
+    dl = jtu.rand_default(rng)(shape, dtype)
+    du = jtu.rand_default(rng)(shape, dtype)
+    b = jtu.rand_default(rng)(shape + (1,), dtype)
+    args = (dl, d, du, b)
+    jtu.check_grads(lax.linalg.tridiagonal_solve, args, order=2, atol=1e-1,
+                    rtol=1e-1)
 
   @jtu.sample_product(
       shape=[(4, 4), (15, 15), (50, 50), (100, 100)],
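Forward-mode differentiation through the solve, which the new grad test exercises indirectly via check_grads, can also be invoked directly. A minimal sketch with illustrative values that are not taken from this commit:

  import jax
  import jax.numpy as jnp
  from jax import lax

  dl = jnp.array([0.0, 2.0, 3.0])
  d = jnp.array([4.0, 5.0, 6.0])
  du = jnp.array([1.0, 2.0, 0.0])
  b = jnp.ones((3, 1))

  # Tangent of the solution with respect to the main diagonal, in forward mode.
  x, x_dot = jax.jvp(
      lambda d: lax.linalg.tridiagonal_solve(dl, d, du, b),
      (d,), (jnp.ones_like(d),))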