Merge pull request #16938 from jakevdp:spsolve-grad
PiperOrigin-RevId: 553889073
commit 6e873ab816
@@ -20,10 +20,12 @@ import functools

import jax
import jax.numpy as jnp

from jax.experimental import sparse
from jax.interpreters import mlir
from jax.interpreters import xla

from jax._src import core
from jax._src.interpreters import ad
from jax._src.lib import gpu_solver

import numpy as np
@@ -549,9 +551,48 @@ def _spsolve_cpu_lowering(ctx, data, indices, indptr, b, tol, reorder):
  return result


def _spsolve_jvp_lhs(data_dot, data, indices, indptr, b, **kwds):
  # d/dM M^-1 b = -M^-1 M_dot M^-1 b
  p = spsolve(data, indices, indptr, b, **kwds)
  q = sparse.csr_matvec_p.bind(data_dot, indices, indptr, p,
                               shape=(indptr.size - 1, len(b)),
                               transpose=False)
  return -spsolve(data, indices, indptr, q, **kwds)
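Editor's note, not part of the diff: the sign in the comment above follows from differentiating the identity M M^{-1} = I. A minimal derivation in LaTeX:

    d(M M^{-1}) = dM\,M^{-1} + M\,d(M^{-1}) = 0
    \quad\Longrightarrow\quad
    d(M^{-1}) = -M^{-1}\,dM\,M^{-1}

So for x = M^{-1} b with b held fixed, dx = -M^{-1}\,dM\,M^{-1} b. This is exactly what the code computes: p = M^{-1} b, q = M_dot p, and the final solve applies the leading -M^{-1}.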
def _spsolve_jvp_rhs(b_dot, data, indices, indptr, b, **kwds):
  # d/db M^-1 b = M^-1 b_dot
  return spsolve(data, indices, indptr, b_dot, **kwds)


def _csr_transpose(data, indices, indptr):
  # Transpose of a square CSR matrix
  m = indptr.size - 1
  row = jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1
  row_T, indices_T, data_T = jax.lax.sort((indices, row, data), num_keys=2)
  indptr_T = jnp.zeros_like(indptr).at[1:].set(
      jnp.cumsum(jnp.bincount(row_T, length=m)).astype(indptr.dtype))
  return data_T, indices_T, indptr_T
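Editor's worked example, not part of the diff: the `row` line above expands indptr into per-entry row indices without a loop. For a 3x3 matrix with nse = 5 and indptr = [0, 2, 3, 5]:

# jnp.zeros_like(indices).at[indptr].add(1)  ->  [1, 0, 1, 1, 0]
#   (the out-of-bounds index 5 is silently dropped by JAX scatter semantics)
# jnp.cumsum(...) - 1                        ->  [0, 0, 1, 2, 2]
#
# i.e. entries 0-1 lie in row 0, entry 2 in row 1, entries 3-4 in row 2.
# jax.lax.sort over (indices, row, data) with num_keys=2 then reorders the
# entries by (column, row), which is the CSR layout of the transpose.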
def _spsolve_transpose(ct, data, indices, indptr, b, **kwds):
  assert not ad.is_undefined_primal(indices)
  assert not ad.is_undefined_primal(indptr)
  if ad.is_undefined_primal(b):
    # TODO(jakevdp): can we do this without an explicit transpose?
    data_T, indices_T, indptr_T = _csr_transpose(data, indices, indptr)
    ct_out = spsolve(data_T, indices_T, indptr_T, ct, **kwds)
    return data, indices, indptr, ct_out
  else:
    # Should never reach here, because JVP is linear wrt data.
    raise NotImplementedError("spsolve transpose with respect to data")
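Editor's note, not part of the diff: this rule relies on x = M^{-1} b being linear in b, so the transpose of the map b -> M^{-1} b applied to a cotangent ct is

    \bar{b} = (M^{-1})^{T}\,ct = M^{-T}\,ct

Because spsolve only consumes CSR operands, M^{T} is materialized by _csr_transpose before the solve; that explicit transpose is what the TODO above hopes to avoid.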
spsolve_p = core.Primitive('spsolve')
spsolve_p.def_impl(functools.partial(xla.apply_primitive, spsolve_p))
spsolve_p.def_abstract_eval(_spsolve_abstract_eval)
ad.defjvp(spsolve_p, _spsolve_jvp_lhs, None, None, _spsolve_jvp_rhs)
ad.primitive_transposes[spsolve_p] = _spsolve_transpose
mlir.register_lowering(spsolve_p, _spsolve_gpu_lowering, platform='cuda')
mlir.register_lowering(spsolve_p, _spsolve_cpu_lowering, platform='cpu')
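A minimal usage sketch, editorial and not from the commit, assuming a backend where spsolve is lowered (CPU, or CUDA with cusparse/cusolver); the matrix, names, and values here are illustrative only. With the JVP and transpose rules registered, jax.grad can differentiate through the solve with respect to both the nonzero values and the right-hand side:

import jax
import jax.numpy as jnp
from jax.experimental import sparse

# A small nonsingular system, converted to CSR buffers.
A = jnp.array([[2., 1., 0.],
               [0., 3., 0.],
               [0., 0., 4.]])
csr = sparse.csr_fromdense(A, nse=4)  # 4 nonzero entries
b = jnp.array([1., 2., 3.])

def loss(data, b):
  # Solve A x = b from the raw CSR buffers, then reduce to a scalar.
  x = sparse.linalg.spsolve(data, csr.indices, csr.indptr, b)
  return jnp.sum(x ** 2)

# Gradients with respect to the nonzero values and the right-hand side.
g_data, g_b = jax.grad(loss, argnums=(0, 1))(csr.data, b)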
@@ -2799,6 +2799,28 @@ class SparseSolverTest(sptu.SparseTestCase):
    self.assertAllClose(a @ x, b, rtol=1e-2, atol=1e-3)
    self._CompileAndCheck(sparse_solve, args_maker)

  @jtu.sample_product(
      size=[10, 20, 50],
      dtype=jtu.dtypes.floating,
  )
  @unittest.skipIf(jtu.device_under_test() == "tpu", "test requires CPU or GPU")
  @unittest.skipIf(jtu.device_under_test() == "cuda" and not GPU_LOWERING_ENABLED,
                   "test requires cusparse/cusolver")
  @jtu.skip_on_devices("rocm", "test requires cusolver")
  def test_sparse_qr_linear_solver_grads(self, size, dtype):
    rng = rand_sparse(self.rng())
    a = rng((size, size), dtype)
    nse = (a != 0).sum()
    data, indices, indptr = sparse_csr._csr_fromdense(a, nse=nse)

    rng_k = jtu.rand_default(self.rng())
    b = rng_k([size], dtype)

    def sparse_solve(data, b, tol=1e-8):
      return sparse.linalg.spsolve(data, indices, indptr, b, tol=tol)

    jtu.check_grads(sparse_solve, (data, b), order=1, rtol=0.05, atol=0.05)


class SparseUtilTest(sptu.SparseTestCase):