Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

[sparse] add JVP & transpose rules for coo primitives

parent fd6069c450
commit 926de5a2bc
@@ -38,6 +38,7 @@ from typing import Any, Tuple

from jax import api
from jax import core
from jax import jit
from jax import lax
from jax import tree_util
from jax.interpreters import xla
@@ -54,6 +55,8 @@ Dtype = Any


#--------------------------------------------------------------------
# utilities
# TODO: possibly make these utilities into primitives, targeting
# csr2coo/coo2csr/SPDDMM

@functools.partial(jit, static_argnums=1)
def _csr_to_coo(indptr, nnz):
  return jnp.cumsum(jnp.zeros_like(indptr, shape=nnz).at[indptr].add(1)) - 1
@@ -63,6 +66,16 @@ def _coo_to_csr(row, nrows):
  indptr = jnp.zeros(nrows + 1, row.dtype)
  return indptr.at[1:].set(jnp.cumsum(jnp.bincount(row, length=nrows)))


@jit
def _csr_extract(indices, indptr, mat):
  """Extract values of dense matrix mat at given CSR indices."""
  return _coo_extract(_csr_to_coo(indptr, len(indices)), indices, mat)


@jit
def _coo_extract(row, col, mat):
  """Extract values of dense matrix mat at given COO indices."""
  return mat[row, col]

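A quick illustration (not part of the diff) of what _csr_to_coo computes: scattering a 1 at each row's start offset and taking a cumulative sum recovers the per-entry row index.

# Illustrative sketch, assuming the helpers above are in scope:
# indptr = [0, 2, 3, 5] describes a 3-row matrix with nnz = 5.
# zeros(5).at[indptr].add(1) -> [1, 0, 1, 1, 0]  (out-of-bounds index 5 is dropped)
# cumsum(...) - 1            -> [0, 0, 1, 2, 2]  (the COO row indices)
rows = _csr_to_coo(jnp.array([0, 2, 3, 5]), 5)   # == [0, 0, 1, 2, 2]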
#--------------------------------------------------------------------
# csr_todense
@@ -293,23 +306,38 @@ def _coo_todense_abstract_eval(data, row, col, *, shape):

def _coo_todense_gpu_translation_rule(c, data, row, col, *, shape):
  return cusparse.coo_todense(c, data, row, col, shape=shape)


def _coo_todense_jvp(primals, tangents, *, shape):
  data, row, col = primals
  data_dot, row_dot, col_dot = tangents

  assert isinstance(row_dot, ad.Zero)
  assert isinstance(col_dot, ad.Zero)
  # TODO: propagate symbolic zeros if possible.
  data_dot = lax.zeros_like_array(data) if isinstance(data_dot, ad.Zero) else data_dot

  # Note: we assume that the tangent has the same sparsity pattern. Can we assert this?
  primals_out = coo_todense(data, row, col, shape=shape)
  tangents_out = coo_todense(data_dot, row, col, shape=shape)

  return primals_out, tangents_out


def _coo_todense_transpose(ct, data, row, col, *, shape):
  assert ad.is_undefined_primal(data)
  if ad.is_undefined_primal(row) or ad.is_undefined_primal(col):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ct.shape == shape
  assert row.aval.dtype == col.aval.dtype
  assert ct.dtype == data.aval.dtype
  return _coo_extract(row, col, ct), row, col


ad.primitive_jvps[coo_todense_p] = _coo_todense_jvp
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
xla.translations[coo_todense_p] = xla.lower_fun(
    _coo_todense_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_todense_p] = _coo_todense_gpu_translation_rule
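Why _coo_extract is the right transpose: coo_todense is linear in data, and extraction at (row, col) is its adjoint. A hedged numerical sanity check (illustrative only, assuming coo_todense and _coo_extract are in scope):

d = jnp.array([1., 2.])
row, col = jnp.array([0, 1]), jnp.array([1, 0])
Y = jnp.arange(4.).reshape(2, 2)
# <coo_todense(d), Y> == <d, Y[row, col]> for any cotangent Y:
lhs = jnp.vdot(coo_todense(d, row, col, shape=(2, 2)), Y)  # 1*1 + 2*2 = 5
rhs = jnp.vdot(d, _coo_extract(row, col, Y))               # [1,2]·[1,2] = 5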

def _coo_todense_jvp_rule(primals_in, tangents_in, **params):
  vals, rows, cols = primals_in
  mat_dot, rows_dot, cols_dot = tangents_in
  assert type(rows_dot) is ad_util.Zero
  assert type(cols_dot) is ad_util.Zero

  primals_out = coo_todense(vals, rows, cols, **params)
  tangents_out = (ad_util.Zero.from_value(primals_out) if type(mat_dot) is ad_util.Zero
                  else coo_todense(mat_dot, rows, cols, **params))
  return primals_out, tangents_out
ad.primitive_jvps[coo_todense_p] = _coo_todense_jvp_rule
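End-to-end usage sketch (illustrative, not part of the diff; it mirrors the new test_coo_todense_ad below):

from jax import api
import jax.numpy as jnp
from jax.experimental import sparse_ops

M = jnp.array([[1., 0., 2.], [0., 3., 0.]])
data, row, col = sparse_ops.coo_fromdense(M, nnz=3)
f = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape)

# Forward mode: the tangent reuses the primal's sparsity pattern.
primals, tangents = api.jvp(f, [data], [jnp.ones_like(data)])
# tangents == zeros_like(M).at[row, col].set(1)

# Reverse mode: the transpose rule extracts the cotangent at (row, col).
primals, vjp_fun = api.vjp(f, data)
data_bar, = vjp_fun(jnp.ones_like(M))   # ones at each stored position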
#--------------------------------------------------------------------
# coo_fromdense
@@ -357,20 +385,36 @@ def _coo_fromdense_gpu_translation_rule(c, mat, *, nnz, index_dtype):
      c, mat, nnz=nnz, index_dtype=np.dtype(index_dtype))
  return xops.Tuple(c, [data, row, col])


def _coo_fromdense_jvp(primals, tangents, *, nnz, index_dtype):
  M, = primals
  Mdot, = tangents

  # TODO: propagate symbolic zeros if possible.
  Mdot = lax.zeros_like_array(M) if isinstance(Mdot, ad.Zero) else Mdot

  primals_out = coo_fromdense(M, nnz=nnz, index_dtype=index_dtype)
  _, row, col = primals_out
  tangents_out = _coo_extract(row, col, Mdot), ad.Zero(row.aval), ad.Zero(col.aval)
  return primals_out, tangents_out


def _coo_fromdense_transpose(ct, M, *, nnz, index_dtype):
  data, row, col = ct
  assert len(data) == nnz
  assert row.dtype == col.dtype == index_dtype
  if isinstance(row, ad.Zero) or isinstance(col, ad.Zero):
    raise ValueError("Cannot transpose with respect to sparse indices")
  assert ad.is_undefined_primal(M)
  return coo_todense(data, row, col, shape=M.aval.shape)


ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose

xla.translations[coo_fromdense_p] = xla.lower_fun(
    _coo_fromdense_impl, multiple_results=True)
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_fromdense_p] = _coo_fromdense_gpu_translation_rule

def _coo_fromdense_jvp_rule(primals_in, tangents_in, **params):
  mat, = primals_in
  mat_dot, = tangents_in
  data, row, col = coo_fromdense(mat, **params)
  tangents_out = (ad_util.Zero.from_value(data) if type(mat_dot) is ad_util.Zero
                  else coo_fromdense(mat_dot, **params)[0])
  return (data, row, col), (tangents_out, ad_util.Zero.from_value(row),
                            ad_util.Zero.from_value(col))
ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp_rule
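Illustrative behavior (a sketch, not part of the diff; reuses api, sparse_ops, and jnp from the earlier snippet): only the data output carries a real tangent, while the integer index outputs get symbolic zeros, which surface as float0 in the tests below.

M = jnp.array([[0., 1.], [2., 0.]])
f = lambda M: sparse_ops.coo_fromdense(M, nnz=2)
(data, row, col), (data_dot, row_dot, col_dot) = api.jvp(f, [M], [jnp.ones_like(M)])
# data_dot == [1., 1.]   (tangent extracted at the nonzero positions)
# row_dot and col_dot have dtype float0: integer indices are not differentiable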
#--------------------------------------------------------------------
# coo_matvec
@@ -417,28 +461,44 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):

def _coo_matvec_gpu_translation_rule(c, data, row, col, v, *, shape, transpose):
  return cusparse.coo_matvec(c, data, row, col, v, shape=shape, transpose=transpose)


def _coo_matvec_jvp(primals, tangents, *, shape, transpose):
  data, row, col, v = primals
  data_dot, row_dot, col_dot, v_dot = tangents

  assert isinstance(row_dot, ad.Zero)
  assert isinstance(col_dot, ad.Zero)

  # TODO: propagate symbolic zeros if possible.
  _zero = lambda p, t: lax.zeros_like_array(p) if isinstance(t, ad.Zero) else t
  data_dot = _zero(data, data_dot)
  v_dot = _zero(v, v_dot)

  primals_out = coo_matvec(data, row, col, v, shape=shape, transpose=transpose)
  tangents_out = (
      coo_matvec(data_dot, row, col, v, shape=shape, transpose=transpose) +
      coo_matvec(data, row, col, v_dot, shape=shape, transpose=transpose)
  )
  return primals_out, tangents_out


def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
  assert not ad.is_undefined_primal(row)
  assert not ad.is_undefined_primal(col)

  if ad.is_undefined_primal(v):
    return data, row, col, coo_matvec(data, row, col, ct, shape=shape,
                                      transpose=not transpose)
  else:
    v = jnp.asarray(v)
    # return _coo_extract(row, col, jnp.outer(ct, v)), row, col, v
    return ct[row] * v[col], row, col, v


ad.primitive_jvps[coo_matvec_p] = _coo_matvec_jvp
ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
xla.translations[coo_matvec_p] = xla.lower_fun(
    _coo_matvec_impl, multiple_results=False)
if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_matvec_p] = _coo_matvec_gpu_translation_rule

def _coo_matvec_jvp_rule(primals_in, tangents_in, **params):
  vals, rows, cols, vec = primals_in
  sparse_mat_dot, rows_dot, cols_dot, vec_dot = tangents_in
  assert type(rows_dot) is ad_util.Zero
  assert type(cols_dot) is ad_util.Zero

  primals_out = coo_matvec(vals, rows, cols, vec, **params)
  _zero = lambda p, t: lax.zeros_like_array(p) if isinstance(t, ad_util.Zero) else t

  _sparse_mat_dot = _zero(vals, sparse_mat_dot)
  _vec_dot = _zero(vec, vec_dot)

  tangents_out = (coo_matvec(_sparse_mat_dot, rows, cols, vec, **params) +
                  coo_matvec(vals, rows, cols, _vec_dot, **params))
  return primals_out, tangents_out
ad.primitive_jvps[coo_matvec_p] = _coo_matvec_jvp_rule
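A usage sketch (illustrative only; reuses api, sparse_ops, and jnp from the earlier snippets): with the transpose rule in place, gradients flow to both the stored values and the vector.

M = jnp.array([[1., 0.], [0., 2.]])
data, row, col = sparse_ops.coo_fromdense(M, nnz=2)
v = jnp.array([3., 4.])
loss = lambda data, v: sparse_ops.coo_matvec(data, row, col, v, shape=M.shape).sum()

# d(loss)/dv routes through coo_matvec(..., transpose=not transpose): M.T @ ones
api.grad(loss, argnums=1)(data, v)   # == [1., 2.]
# d(loss)/d(data) is ct[row] * v[col], the sparse slice of outer(ct, v)
api.grad(loss, argnums=0)(data, v)   # == [3., 4.]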
#--------------------------------------------------------------------
# coo_matmat
@@ -17,7 +17,10 @@ import unittest

from absl.testing import absltest
from absl.testing import parameterized

from jax import api
from jax import config
from jax import dtypes
from jax.experimental import sparse_ops
from jax.lib import cusparse
from jax.lib import xla_bridge
@@ -148,16 +151,6 @@ class cuSparseTest(jtu.JaxTestCase):
    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

    todense = lambda data: sparse_ops.coo_todense(data, M.row, M.col, shape=M.shape)
    tangent = jnp.ones_like(M.data)
    y, dy = jvp(todense, (M.data,), (tangent,))
    self.assertArraysEqual(M.toarray(), y)
    self.assertArraysEqual(todense(tangent), dy)

    y, dy = jit(lambda prim, tan: jvp(todense, prim, tan))((M.data,), (tangent,))
    self.assertArraysEqual(M.toarray(), y)
    self.assertArraysEqual(todense(tangent), dy)
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
@@ -182,13 +175,6 @@ class cuSparseTest(jtu.JaxTestCase):
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    tangent = jnp.ones_like(M)
    (data, row, col), (data_dot, row_dot, col_dot) = jvp(fromdense, (M,), (tangent,))
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))
    self.assertArraysEqual(data_dot, fromdense(tangent)[0])
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
@@ -209,11 +195,6 @@ class cuSparseTest(jtu.JaxTestCase):
    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

    y, dy = jvp(lambda x: sparse_ops.coo_matvec(M.data, M.row, M.col, x, shape=shape, transpose=transpose).sum(), (v,), (jnp.ones_like(v),))
    self.assertAllClose((op(M) @ v).sum(), y, rtol=MATMUL_TOL)

    y, dy = jvp(lambda x: sparse_ops.coo_matvec(x, M.row, M.col, v, shape=shape, transpose=transpose).sum(), (M.data,), (jnp.ones_like(M.data),))
    self.assertAllClose((op(M) @ v).sum(), y, rtol=MATMUL_TOL)
  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
@@ -240,6 +221,7 @@ class cuSparseTest(jtu.JaxTestCase):

    y, dy = jvp(lambda x: sparse_ops.coo_matmat(x, M.row, M.col, B, shape=shape, transpose=transpose).sum(), (M.data,), (jnp.ones_like(M.data),))
    self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  def test_gpu_translation_rule(self):
    version = xla_bridge.get_backend().platform_version
@@ -268,6 +250,108 @@ class cuSparseTest(jtu.JaxTestCase):
    M_out = todense(*args, shape=M.shape)
    self.assertArraysEqual(M, M_out)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_todense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
    f = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape)

    # Forward-mode
    primals, tangents = api.jvp(f, [data], [jnp.ones_like(data)])
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(tangents, jnp.zeros_like(M).at[row, col].set(1))

    # Reverse-mode
    primals, vjp_fun = api.vjp(f, data)
    data_out, = vjp_fun(primals)
    self.assertArraysEqual(primals, f(data))
    self.assertArraysEqual(data_out, data)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
  def test_coo_fromdense_ad(self, shape, dtype):
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    nnz = (M != 0).sum()
    f = lambda M: sparse_ops.coo_fromdense(M, nnz=nnz)

    # Forward-mode
    primals, tangents = api.jvp(f, [M], [jnp.ones_like(M)])
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(tangents[0], jnp.ones(nnz, dtype=dtype))
    self.assertEqual(tangents[1].dtype, dtypes.float0)
    self.assertEqual(tangents[2].dtype, dtypes.float0)

    # Reverse-mode
    primals, vjp_fun = api.vjp(f, M)
    M_out, = vjp_fun(primals)
    self.assertArraysEqual(primals[0], f(M)[0])
    self.assertArraysEqual(primals[1], f(M)[1])
    self.assertArraysEqual(primals[2], f(M)[2])
    self.assertArraysEqual(M_out, M)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          jtu.format_shape_dtype_string(bshape, dtype)),
       "shape": shape, "dtype": dtype, "bshape": bshape}
      for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
      for bshape in [shape[-1:] + s for s in [()]]  # TODO: matmul autodiff
      for dtype in jtu.dtypes.floating + jtu.dtypes.complex))  # TODO: other types
  def test_coo_matvec_ad(self, shape, dtype, bshape):
    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())

    M = rng(shape, dtype)
    data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
    x = rng_b(bshape, dtype)
    xdot = rng_b(bshape, dtype)

    # Forward-mode with respect to the vector
    f_dense = lambda x: M @ x
    f_sparse = lambda x: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
    v_sparse, t_sparse = api.jvp(f_sparse, [x], [xdot])
    v_dense, t_dense = api.jvp(f_dense, [x], [xdot])
    self.assertAllClose(v_sparse, v_dense)
    self.assertAllClose(t_sparse, t_dense)

    # Reverse-mode with respect to the vector
    primals_dense, vjp_dense = api.vjp(f_dense, x)
    primals_sparse, vjp_sparse = api.vjp(f_sparse, x)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0])
    self.assertAllClose(out_dense, out_sparse)

    # Forward-mode with respect to nonzero elements of the matrix
    f_sparse = lambda data: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
    f_dense = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape) @ x
    data = rng((len(data),), data.dtype)
    data_dot = rng((len(data),), data.dtype)
    v_sparse, t_sparse = api.jvp(f_sparse, [data], [data_dot])
    v_dense, t_dense = api.jvp(f_dense, [data], [data_dot])
    self.assertAllClose(v_sparse, v_dense)
    self.assertAllClose(t_sparse, t_dense)

    # Reverse-mode with respect to nonzero elements of the matrix
    primals_dense, vjp_dense = api.vjp(f_dense, data)
    primals_sparse, vjp_sparse = api.vjp(f_sparse, data)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    self.assertAllClose(primals_dense[0], primals_sparse[0])
    self.assertAllClose(out_dense, out_sparse)


class SparseObjectTest(jtu.JaxTestCase):
  @parameterized.named_parameters(