Merge pull request #16938 from jakevdp:spsolve-grad

PiperOrigin-RevId: 553889073
Author: jax authors
Date: 2023-08-04 12:59:23 -07:00
Commit: 6e873ab816
2 changed files with 63 additions and 0 deletions


@@ -20,10 +20,12 @@ import functools
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src import core
from jax._src.interpreters import ad
from jax._src.lib import gpu_solver
import numpy as np
@@ -549,9 +551,48 @@ def _spsolve_cpu_lowering(ctx, data, indices, indptr, b, tol, reorder):
  return result


def _spsolve_jvp_lhs(data_dot, data, indices, indptr, b, **kwds):
  # forward-mode: d(M^-1 b) = -M^-1 M_dot M^-1 b
  p = spsolve(data, indices, indptr, b, **kwds)
  q = sparse.csr_matvec_p.bind(data_dot, indices, indptr, p,
                               shape=(indptr.size - 1, len(b)),
                               transpose=False)
  return -spsolve(data, indices, indptr, q, **kwds)


def _spsolve_jvp_rhs(b_dot, data, indices, indptr, b, **kwds):
  # d/db M^-1 b = M^-1 b_dot
  return spsolve(data, indices, indptr, b_dot, **kwds)


def _csr_transpose(data, indices, indptr):
  # Transpose of a square CSR matrix
  m = indptr.size - 1
  # Recover the row index of each stored entry from indptr.
  row = jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1
  # Stable sort by (column, row): columns become the rows of the transpose.
  row_T, indices_T, data_T = jax.lax.sort((indices, row, data), num_keys=2)
  indptr_T = jnp.zeros_like(indptr).at[1:].set(
      jnp.cumsum(jnp.bincount(row_T, length=m)).astype(indptr.dtype))
  return data_T, indices_T, indptr_T


def _spsolve_transpose(ct, data, indices, indptr, b, **kwds):
  assert not ad.is_undefined_primal(indices)
  assert not ad.is_undefined_primal(indptr)
  if ad.is_undefined_primal(b):
    # TODO(jakevdp): can we do this without an explicit transpose?
    data_T, indices_T, indptr_T = _csr_transpose(data, indices, indptr)
    ct_out = spsolve(data_T, indices_T, indptr_T, ct, **kwds)
    return data, indices, indptr, ct_out
  else:
    # Should never reach here, because JVP is linear wrt data.
    raise NotImplementedError("spsolve transpose with respect to data")

spsolve_p = core.Primitive('spsolve')
spsolve_p.def_impl(functools.partial(xla.apply_primitive, spsolve_p))
spsolve_p.def_abstract_eval(_spsolve_abstract_eval)
ad.defjvp(spsolve_p, _spsolve_jvp_lhs, None, None, _spsolve_jvp_rhs)
ad.primitive_transposes[spsolve_p] = _spsolve_transpose
mlir.register_lowering(spsolve_p, _spsolve_gpu_lowering, platform='cuda')
mlir.register_lowering(spsolve_p, _spsolve_cpu_lowering, platform='cpu')
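
With the JVP and transpose rules registered on spsolve_p above, spsolve becomes differentiable with respect to both the CSR data and the right-hand side. A minimal usage sketch, not part of this commit; the matrix, vector, and loss below are illustrative, and it assumes a runtime where the spsolve lowering is available (CPU, or CUDA with cusparse/cusolver):

import jax
import jax.numpy as jnp
from jax.experimental import sparse

# CSR representation of the 3x3 matrix [[4, 1, 0], [1, 3, 0], [0, 0, 2]].
data = jnp.array([4.0, 1.0, 1.0, 3.0, 2.0])
indices = jnp.array([0, 1, 0, 1, 2])
indptr = jnp.array([0, 2, 4, 5])
b = jnp.array([1.0, 2.0, 3.0])

def loss(data, b):
  # Solve M x = b through the spsolve primitive defined above.
  x = sparse.linalg.spsolve(data, indices, indptr, b)
  return jnp.sum(x ** 2)

# Reverse-mode gradients with respect to the CSR data and the right-hand side.
d_data, d_b = jax.grad(loss, argnums=(0, 1))(data, b)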


@@ -2799,6 +2799,28 @@ class SparseSolverTest(sptu.SparseTestCase):
    self.assertAllClose(a @ x, b, rtol=1e-2, atol=1e-3)
    self._CompileAndCheck(sparse_solve, args_maker)

  @jtu.sample_product(
      size=[10, 20, 50],
      dtype=jtu.dtypes.floating,
  )
  @unittest.skipIf(jtu.device_under_test() == "tpu", "test requires CPU or GPU")
  @unittest.skipIf(jtu.device_under_test() == "cuda" and not GPU_LOWERING_ENABLED,
                   "test requires cusparse/cusolver")
  @jtu.skip_on_devices("rocm", "test requires cusolver")
  def test_sparse_qr_linear_solver_grads(self, size, dtype):
    rng = rand_sparse(self.rng())
    a = rng((size, size), dtype)
    nse = (a != 0).sum()
    data, indices, indptr = sparse_csr._csr_fromdense(a, nse=nse)
    rng_k = jtu.rand_default(self.rng())
    b = rng_k([size], dtype)

    def sparse_solve(data, b, tol=1e-8):
      return sparse.linalg.spsolve(data, indices, indptr, b, tol=tol)

    jtu.check_grads(sparse_solve, (data, b), order=1, rtol=0.05, atol=0.05)

class SparseUtilTest(sptu.SparseTestCase):
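
The sort-based CSR transpose used by the transpose rule can also be sanity-checked against scipy outside the test suite. A standalone sketch, not part of this commit; csr_transpose_sketch is a hypothetical copy of the _csr_transpose helper above:

import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse

def csr_transpose_sketch(data, indices, indptr):
  # Same trick as _csr_transpose: recover row ids from indptr, sort by
  # (column, row), then rebuild indptr from a bincount of the new row ids.
  m = indptr.size - 1
  row = jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1
  row_T, indices_T, data_T = jax.lax.sort((indices, row, data), num_keys=2)
  indptr_T = jnp.zeros_like(indptr).at[1:].set(
      jnp.cumsum(jnp.bincount(row_T, length=m)).astype(indptr.dtype))
  return data_T, indices_T, indptr_T

A = scipy.sparse.random(8, 8, density=0.3, random_state=0, format="csr")
data_T, indices_T, indptr_T = csr_transpose_sketch(
    jnp.asarray(A.data), jnp.asarray(A.indices), jnp.asarray(A.indptr))
A_T = A.T.tocsr()
np.testing.assert_allclose(np.asarray(data_T), A_T.data, rtol=1e-6)
np.testing.assert_array_equal(np.asarray(indices_T), A_T.indices)
np.testing.assert_array_equal(np.asarray(indptr_T), A_T.indptr)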