Merge pull request #6739 from jakevdp:sparse-op-jvp

PiperOrigin-RevId: 373870776
jax authors 2021-05-14 14:50:20 -07:00
commit 25cc3ece66
2 changed files with 185 additions and 58 deletions


@@ -38,8 +38,8 @@ from typing import Any, Tuple
from jax import api
from jax import core
from jax import jit
from jax import tree_util
from jax import lax
from jax import tree_util
from jax.interpreters import xla
from jax.lib import cusparse
from jax.lib import xla_bridge
@@ -54,6 +54,8 @@ Dtype = Any
#--------------------------------------------------------------------
# utilities
# TODO: possibly make these utilities into primitives, targeting
# csr2coo/coo2csr/SDDMM
@functools.partial(jit, static_argnums=1)
def _csr_to_coo(indptr, nnz):
return jnp.cumsum(jnp.zeros_like(indptr, shape=nnz).at[indptr].add(1)) - 1
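# For example, indptr = [0, 2, 3, 3, 5] with nnz = 5:
#   ones scattered at indptr -> [1, 0, 1, 2, 0]  (the out-of-bounds index 5 is dropped)
#   cumsum(...) - 1          -> [0, 0, 1, 3, 3], the row index of each stored entry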
@@ -63,6 +65,16 @@ def _coo_to_csr(row, nrows):
indptr = jnp.zeros(nrows + 1, row.dtype)
return indptr.at[1:].set(jnp.cumsum(jnp.bincount(row, length=nrows)))
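# For example, row = [0, 0, 1, 3, 3] with nrows = 4:
#   bincount(row, length=4) -> [2, 1, 0, 2]
#   resulting indptr        -> [0, 2, 3, 3, 5]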
@jit
def _csr_extract(indices, indptr, mat):
"""Extract values of dense matrix mat at given CSR indices."""
return _coo_extract(_csr_to_coo(indptr, len(indices)), indices, mat)
@jit
def _coo_extract(row, col, mat):
"""Extract values of dense matrix mat at given COO indices."""
return mat[row, col]
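# Both extraction helpers are linear in `mat` and act as the transpose of
# densification; the transpose rules below reuse them.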
#--------------------------------------------------------------------
# csr_todense
@@ -293,23 +305,28 @@ def _coo_todense_abstract_eval(data, row, col, *, shape):
def _coo_todense_gpu_translation_rule(c, data, row, col, *, shape):
return cusparse.coo_todense(c, data, row, col, shape=shape)
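# coo_todense is linear in `data`, so its JVP with respect to `data` is the
# primitive itself applied to the data tangent; the integer index arguments
# are non-differentiable, hence the `None` entries in ad.defjvp below.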
def _coo_todense_jvp(data_dot, data, row, col, *, shape):
return coo_todense(data_dot, row, col, shape=shape)
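# The transpose of densification is extraction: the dense cotangent is read
# back at the stored (row, col) positions to produce the data cotangent.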
def _coo_todense_transpose(ct, data, row, col, *, shape):
# Note: we assume that transpose has the same sparsity pattern.
# Can we check this?
assert ad.is_undefined_primal(data)
if ad.is_undefined_primal(row) or ad.is_undefined_primal(col):
raise ValueError("Cannot transpose with respect to sparse indices")
assert ct.shape == shape
assert row.aval.dtype == col.aval.dtype
assert ct.dtype == data.aval.dtype
return _coo_extract(row, col, ct), row, col
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
xla.translations[coo_todense_p] = xla.lower_fun(
_coo_todense_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
coo_todense_p] = _coo_todense_gpu_translation_rule
def _coo_todense_jvp_rule(primals_in, tangents_in, **params):
vals, rows, cols, = primals_in
mat_dot, rows_dot, cols_dot = tangents_in
assert type(rows_dot) is ad_util.Zero
assert type(cols_dot) is ad_util.Zero
primals_out = coo_todense(vals, rows, cols, **params)
tangents_out = ad_util.Zero.from_value(primals_out) if type(mat_dot) is ad_util.Zero else coo_todense(mat_dot, rows, cols, **params)
return primals_out, tangents_out
ad.primitive_jvps[coo_todense_p] = _coo_todense_jvp_rule
#--------------------------------------------------------------------
# coo_fromdense
@@ -357,20 +374,40 @@ def _coo_fromdense_gpu_translation_rule(c, mat, *, nnz, index_dtype):
c, mat, nnz=nnz, index_dtype=np.dtype(index_dtype))
return xops.Tuple(c, [data, row, col])
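# For coo_fromdense, the data tangent is the dense tangent sampled at the
# discovered (row, col) positions; the integer index outputs carry symbolic
# zero tangents.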
def _coo_fromdense_jvp(primals, tangents, *, nnz, index_dtype):
M, = primals
Mdot, = tangents
primals_out = coo_fromdense(M, nnz=nnz, index_dtype=index_dtype)
data, row, col = primals_out
if type(Mdot) is ad.Zero:
data_dot = ad.Zero.from_value(data)
else:
data_dot = _coo_extract(row, col, Mdot)
tangents_out = (data_dot, ad.Zero.from_value(row), ad.Zero.from_value(col))
return primals_out, tangents_out
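# The transpose of extraction is scattering: the (data, row, col) cotangent
# is pushed back to a dense cotangent via coo_todense.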
def _coo_fromdense_transpose(ct, M, *, nnz, index_dtype):
data, row, col = ct
assert len(data) == nnz
assert row.dtype == col.dtype == index_dtype
if isinstance(row, ad.Zero) or isinstance(col, ad.Zero):
raise ValueError("Cannot transpose with respect to sparse indices")
assert ad.is_undefined_primal(M)
return coo_todense(data, row, col, shape=M.aval.shape)
ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose
xla.translations[coo_fromdense_p] = xla.lower_fun(
_coo_fromdense_impl, multiple_results=True)
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
coo_fromdense_p] = _coo_fromdense_gpu_translation_rule
def _coo_fromdense_jvp_rule(primals_in, tangents_in, **params):
mat, = primals_in
mat_dot, = tangents_in
data, row, col = coo_fromdense(mat, **params)
tangents_out = ad_util.Zero.from_value(data) if type(mat_dot) is ad_util.Zero else coo_fromdense(mat_dot, **params)[0]
return (data, row, col), (tangents_out, ad_util.Zero.from_value(row), ad_util.Zero.from_value(col))
ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp_rule
#--------------------------------------------------------------------
# coo_matvec
@@ -417,28 +454,31 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
def _coo_matvec_gpu_translation_rule(c, data, row, col, v, *, shape, transpose):
return cusparse.coo_matvec(c, data, row, col, v, shape=shape, transpose=transpose)
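# coo_matvec is linear in `data` and in `v` separately, so each partial JVP
# is the primitive itself applied to the corresponding tangent.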
def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, shape, transpose):
return coo_matvec(data_dot, row, col, v, shape=shape, transpose=transpose)
def _coo_matvec_jvp_vec(v_dot, data, row, col, v, *, shape, transpose):
return coo_matvec(data, row, col, v_dot, shape=shape, transpose=transpose)
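# Transpose rule: with respect to `v`, the cotangent is a matvec with the
# `transpose` flag flipped. With respect to `data` (in the transpose=False
# case), it is the outer product ct v^T sampled at the sparse indices, which
# reduces to ct[row] * v[col].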
def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
assert not ad.is_undefined_primal(row)
assert not ad.is_undefined_primal(col)
if ad.is_undefined_primal(v):
return data, row, col, coo_matvec(data, row, col, ct, shape=shape, transpose=not transpose)
else:
v = jnp.asarray(v)
# return _coo_extract(row, col, jnp.outer(ct, v)), row, col, v
return ct[row] * v[col], row, col, v
ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
xla.translations[coo_matvec_p] = xla.lower_fun(
_coo_matvec_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
xla.backend_specific_translations['gpu'][
coo_matvec_p] = _coo_matvec_gpu_translation_rule
def _coo_matvec_jvp_rule(primals_in, tangents_in, **params):
vals, rows, cols, vec = primals_in
sparse_mat_dot, rows_dot, cols_dot, vec_dot = tangents_in
assert type(rows_dot) is ad_util.Zero
assert type(cols_dot) is ad_util.Zero
primals_out = coo_matvec(vals, rows, cols, vec, **params)
_zero = lambda p, t: lax.zeros_like_array(p) if isinstance(t, ad_util.Zero) else t
_sparse_mat_dot = _zero(vals, sparse_mat_dot)
_vec_dot = _zero(vec, vec_dot)
tangents_out = coo_matvec(_sparse_mat_dot, rows, cols, vec, **params) + coo_matvec(vals, rows, cols, _vec_dot, **params)
return primals_out, tangents_out
ad.primitive_jvps[coo_matvec_p] = _coo_matvec_jvp_rule
#--------------------------------------------------------------------
# coo_matmat


@@ -17,7 +17,10 @@ import unittest
from absl.testing import absltest
from absl.testing import parameterized
from jax import api
from jax import config
from jax import dtypes
from jax.experimental import sparse_ops
from jax.lib import cusparse
from jax.lib import xla_bridge
@@ -148,16 +151,6 @@ class cuSparseTest(jtu.JaxTestCase):
self.assertArraysEqual(M.toarray(), todense(*args))
self.assertArraysEqual(M.toarray(), jit(todense)(*args))
todense = lambda data: sparse_ops.coo_todense(data, M.row, M.col, shape=M.shape)
tangent = jnp.ones_like(M.data)
y, dy = jvp(todense, (M.data, ), (tangent, ))
self.assertArraysEqual(M.toarray(), y)
self.assertArraysEqual(todense(tangent), dy)
y, dy = jit(lambda prim, tan: jvp(todense, prim, tan))((M.data, ), (tangent, ))
self.assertArraysEqual(M.toarray(), y)
self.assertArraysEqual(todense(tangent), dy)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
@@ -182,13 +175,6 @@ class cuSparseTest(jtu.JaxTestCase):
self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
self.assertArraysEqual(col, M_coo.col.astype(index_dtype))
tangent = jnp.ones_like(M)
(data, row, col), (data_dot, row_dot, col_dot) = jvp(fromdense, (M, ), (tangent, ))
self.assertArraysEqual(data, M_coo.data.astype(dtype))
self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
self.assertArraysEqual(col, M_coo.col.astype(index_dtype))
self.assertArraysEqual(data_dot, fromdense(tangent)[0])
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
"shape": shape, "dtype": dtype, "transpose": transpose}
@@ -209,11 +195,6 @@ class cuSparseTest(jtu.JaxTestCase):
self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
y, dy = jvp(lambda x: sparse_ops.coo_matvec(M.data, M.row, M.col, x, shape=shape, transpose=transpose).sum(), (v, ), (jnp.ones_like(v), ))
self.assertAllClose((op(M) @ v).sum(), y, rtol=MATMUL_TOL)
y, dy = jvp(lambda x: sparse_ops.coo_matvec(x, M.row, M.col, v, shape=shape, transpose=transpose).sum(), (M.data, ), (jnp.ones_like(M.data), ))
self.assertAllClose((op(M) @ v).sum(), y, rtol=MATMUL_TOL)
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
@@ -240,6 +221,7 @@ class cuSparseTest(jtu.JaxTestCase):
y, dy = jvp(lambda x: sparse_ops.coo_matmat(x, M.row, M.col, B, shape=shape, transpose=transpose).sum(), (M.data, ), (jnp.ones_like(M.data), ))
self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
def test_gpu_translation_rule(self):
version = xla_bridge.get_backend().platform_version
@@ -268,6 +250,111 @@ class cuSparseTest(jtu.JaxTestCase):
M_out = todense(*args, shape=M.shape)
self.assertArraysEqual(M, M_out)
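# The AD tests below exercise forward- and reverse-mode differentiation of the
# new COO rules, comparing against dense references where applicable.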
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
def test_coo_todense_ad(self, shape, dtype):
rng = rand_sparse(self.rng(), post=jnp.array)
M = rng(shape, dtype)
data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
f = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape)
# Forward-mode
primals, tangents = api.jvp(f, [data], [jnp.ones_like(data)])
self.assertArraysEqual(primals, f(data))
self.assertArraysEqual(tangents, jnp.zeros_like(M).at[row, col].set(1))
# Reverse-mode
primals, vjp_fun = api.vjp(f, data)
data_out, = vjp_fun(primals)
self.assertArraysEqual(primals, f(data))
self.assertArraysEqual(data_out, data)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
def test_coo_fromdense_ad(self, shape, dtype):
rng = rand_sparse(self.rng(), post=jnp.array)
M = rng(shape, dtype)
nnz = (M != 0).sum()
f = lambda M: sparse_ops.coo_fromdense(M, nnz=nnz)
# Forward-mode
primals, tangents = api.jvp(f, [M], [jnp.ones_like(M)])
self.assertArraysEqual(primals[0], f(M)[0])
self.assertArraysEqual(primals[1], f(M)[1])
self.assertArraysEqual(primals[2], f(M)[2])
self.assertArraysEqual(tangents[0], jnp.ones(nnz, dtype=dtype))
self.assertEqual(tangents[1].dtype, dtypes.float0)
self.assertEqual(tangents[2].dtype, dtypes.float0)
# Reverse-mode
primals, vjp_fun = api.vjp(f, M)
M_out, = vjp_fun(primals)
self.assertArraysEqual(primals[0], f(M)[0])
self.assertArraysEqual(primals[1], f(M)[1])
self.assertArraysEqual(primals[2], f(M)[2])
self.assertArraysEqual(M_out, M)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}".format(
jtu.format_shape_dtype_string(shape, dtype),
jtu.format_shape_dtype_string(bshape, dtype)),
"shape": shape, "dtype": dtype, "bshape": bshape}
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for bshape in [shape[-1:] + s for s in [()]] # TODO: matmul autodiff
for dtype in jtu.dtypes.floating + jtu.dtypes.complex)) # TODO: other types
def test_coo_matvec_ad(self, shape, dtype, bshape):
tol = {np.float32: 1E-6, np.float64: 1E-13, np.complex64: 1E-6, np.complex128: 1E-13}
rng = rand_sparse(self.rng(), post=jnp.array)
rng_b = jtu.rand_default(self.rng())
M = rng(shape, dtype)
data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
x = rng_b(bshape, dtype)
xdot = rng_b(bshape, dtype)
# Forward-mode with respect to the vector
f_dense = lambda x: M @ x
f_sparse = lambda x: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
v_sparse, t_sparse = api.jvp(f_sparse, [x], [xdot])
v_dense, t_dense = api.jvp(f_dense, [x], [xdot])
self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)
# Reverse-mode with respect to the vector
primals_dense, vjp_dense = api.vjp(f_dense, x)
primals_sparse, vjp_sparse = api.vjp(f_sparse, x)
out_dense, = vjp_dense(primals_dense)
out_sparse, = vjp_sparse(primals_sparse)
self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)
# Forward-mode with respect to nonzero elements of the matrix
f_sparse = lambda data: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
f_dense = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape) @ x
data = rng((len(data),), data.dtype)
data_dot = rng((len(data),), data.dtype)
v_sparse, t_sparse = api.jvp(f_sparse, [data], [data_dot])
v_dense, t_dense = api.jvp(f_dense, [data], [data_dot])
self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)
# Reverse-mode with respect to nonzero elements of the matrix
primals_dense, vjp_dense = api.vjp(f_dense, data)
primals_sparse, vjp_sparse = api.vjp(f_sparse, data)
out_dense, = vjp_dense(primals_dense)
out_sparse, = vjp_sparse(primals_sparse)
self.assertAllClose(primals_dense[0], primals_sparse[0], atol=tol, rtol=tol)
self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)
class SparseObjectTest(jtu.JaxTestCase):
@parameterized.named_parameters(