Add JVP rules for COO sparse ops.

Updated coo_matvec JVP rule.

Make flake8 happy.
Gert-Jan 2021-05-01 12:09:03 +00:00
parent 6ce4ef46b9
commit 81903e894b
2 changed files with 83 additions and 4 deletions
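
For context: each of the four COO primitives touched here is linear in its data argument, and coo_matvec/coo_matmat are additionally linear in their dense operand, so every JVP rule in this commit simply re-applies the primitive to the tangent data, with a product rule for the two-argument ops. A minimal sketch of that structure, using a dense stand-in for coo_matvec (the helper coo_matvec_ref and the toy values are illustrative only, not part of the commit):

import jax.numpy as jnp
from jax import jvp

# Dense reference for coo_matvec: y = A(data) @ v, where A scatters
# `data` into a zero matrix at positions (row, col).
def coo_matvec_ref(data, row, col, v, *, shape):
  A = jnp.zeros(shape).at[row, col].add(data)
  return A @ v

data = jnp.array([1., 2., 3.])
row = jnp.array([0, 1, 1])
col = jnp.array([1, 0, 1])
v = jnp.array([1., -1.])

# Product rule, mirroring _coo_matvec_jvp_rule below:
# dy = A(ddata) @ v + A(data) @ dv
y, dy = jvp(lambda d, x: coo_matvec_ref(d, row, col, x, shape=(2, 2)),
            (data, v), (jnp.ones_like(data), jnp.zeros_like(v)))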


@@ -39,16 +39,17 @@ from jax import api
from jax import core
from jax import jit
from jax import tree_util
from jax import lax
from jax.interpreters import xla
from jax.lib import cusparse
from jax.lib import xla_bridge
from jax.lib import xla_client
import jax.numpy as jnp
import numpy as np
from jax.interpreters import ad
from jax import ad_util
xb = xla_bridge
xops = xla_client.ops
Dtype = Any
#--------------------------------------------------------------------
@@ -298,6 +299,17 @@ if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_todense_p] = _coo_todense_gpu_translation_rule

def _coo_todense_jvp_rule(primals_in, tangents_in, **params):
  vals, rows, cols = primals_in
  mat_dot, rows_dot, cols_dot = tangents_in
  # Index tangents are symbolic zeros: rows/cols are integer-valued.
  assert type(rows_dot) is ad_util.Zero
  assert type(cols_dot) is ad_util.Zero
  primals_out = coo_todense(vals, rows, cols, **params)
  # The op is linear in the data, so apply the same scatter to the tangent data.
  if type(mat_dot) is ad_util.Zero:
    tangents_out = ad_util.Zero.from_value(primals_out)
  else:
    tangents_out = coo_todense(mat_dot, rows, cols, **params)
  return primals_out, tangents_out
ad.primitive_jvps[coo_todense_p] = _coo_todense_jvp_rule
#--------------------------------------------------------------------
# coo_fromdense
@@ -351,6 +363,14 @@ if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_fromdense_p] = _coo_fromdense_gpu_translation_rule

def _coo_fromdense_jvp_rule(primals_in, tangents_in, **params):
  mat, = primals_in
  mat_dot, = tangents_in
  data, row, col = coo_fromdense(mat, **params)
  # Integer row/col outputs get symbolic zero tangents; only data differentiates.
  if type(mat_dot) is ad_util.Zero:
    data_dot = ad_util.Zero.from_value(data)
  else:
    data_dot = coo_fromdense(mat_dot, **params)[0]
  tangents_out = (data_dot, ad_util.Zero.from_value(row),
                  ad_util.Zero.from_value(col))
  return (data, row, col), tangents_out
ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp_rule
#--------------------------------------------------------------------
# coo_matvec
@@ -403,6 +423,22 @@ if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_matvec_p] = _coo_matvec_gpu_translation_rule

def _coo_matvec_jvp_rule(primals_in, tangents_in, **params):
  vals, rows, cols, vec = primals_in
  sparse_mat_dot, rows_dot, cols_dot, vec_dot = tangents_in
  # Index tangents are symbolic zeros: rows/cols are integer-valued.
  assert type(rows_dot) is ad_util.Zero
  assert type(cols_dot) is ad_util.Zero
  primals_out = coo_matvec(vals, rows, cols, vec, **params)
  # Materialize symbolic zeros so both product-rule terms can be evaluated.
  _zero = lambda p, t: lax.zeros_like_array(p) if isinstance(t, ad_util.Zero) else t
  _sparse_mat_dot = _zero(vals, sparse_mat_dot)
  _vec_dot = _zero(vec, vec_dot)
  # Product rule: d(A @ v) = dA @ v + A @ dv.
  tangents_out = (coo_matvec(_sparse_mat_dot, rows, cols, vec, **params)
                  + coo_matvec(vals, rows, cols, _vec_dot, **params))
  return primals_out, tangents_out
ad.primitive_jvps[coo_matvec_p] = _coo_matvec_jvp_rule
#--------------------------------------------------------------------
# coo_matmat
@@ -454,6 +490,22 @@ if cusparse and cusparse.is_supported:
  xla.backend_specific_translations['gpu'][
      coo_matmat_p] = _coo_matmat_gpu_translation_rule

def _coo_matmat_jvp_rule(primals_in, tangents_in, **params):
  vals, rows, cols, mat = primals_in
  sparse_mat_dot, rows_dot, cols_dot, mat_dot = tangents_in
  # Index tangents are symbolic zeros: rows/cols are integer-valued.
  assert type(rows_dot) is ad_util.Zero
  assert type(cols_dot) is ad_util.Zero
  primals_out = coo_matmat(vals, rows, cols, mat, **params)
  # Materialize symbolic zeros so both product-rule terms can be evaluated.
  _zero = lambda p, t: lax.zeros_like_array(p) if isinstance(t, ad_util.Zero) else t
  _sparse_mat_dot = _zero(vals, sparse_mat_dot)
  _mat_dot = _zero(mat, mat_dot)
  # Product rule: d(A @ B) = dA @ B + A @ dB.
  tangents_out = (coo_matmat(_sparse_mat_dot, rows, cols, mat, **params)
                  + coo_matmat(vals, rows, cols, _mat_dot, **params))
  return primals_out, tangents_out
ad.primitive_jvps[coo_matmat_p] = _coo_matmat_jvp_rule
#----------------------------------------------------------------------
# Sparse objects (APIs subject to change)
class JAXSparse:
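
Once registered, these JVP rules compose with the rest of JAX's forward-mode autodiff and can be spot-checked against finite differences. A sketch, reusing the toy data/row/col arrays from the note above (forward mode only, since this commit defines no transpose/VJP rules):

from jax import test_util as jtu

# check_grads compares the analytic JVP against numerical differences.
jtu.check_grads(lambda d: coo_todense(d, row, col, shape=(2, 2)).sum(),
                (data,), order=1, modes=['fwd'])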


@@ -25,10 +25,9 @@ from jax import jit
from jax import test_util as jtu
from jax import xla
import jax.numpy as jnp
from jax import jvp
import numpy as np
from scipy import sparse
config.parse_flags_with_absl()
FLAGS = config.FLAGS
@@ -149,6 +148,16 @@ class cuSparseTest(jtu.JaxTestCase):
    self.assertArraysEqual(M.toarray(), todense(*args))
    self.assertArraysEqual(M.toarray(), jit(todense)(*args))

    todense = lambda data: sparse_ops.coo_todense(data, M.row, M.col, shape=M.shape)
    tangent = jnp.ones_like(M.data)

    y, dy = jvp(todense, (M.data,), (tangent,))
    self.assertArraysEqual(M.toarray(), y)
    self.assertArraysEqual(todense(tangent), dy)

    y, dy = jit(lambda prim, tan: jvp(todense, prim, tan))((M.data,), (tangent,))
    self.assertArraysEqual(M.toarray(), y)
    self.assertArraysEqual(todense(tangent), dy)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
@@ -173,6 +182,13 @@ class cuSparseTest(jtu.JaxTestCase):
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))

    tangent = jnp.ones_like(M)
    (data, row, col), (data_dot, row_dot, col_dot) = jvp(fromdense, (M,), (tangent,))
    self.assertArraysEqual(data, M_coo.data.astype(dtype))
    self.assertArraysEqual(row, M_coo.row.astype(index_dtype))
    self.assertArraysEqual(col, M_coo.col.astype(index_dtype))
    self.assertArraysEqual(data_dot, fromdense(tangent)[0])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
@@ -193,6 +209,12 @@ class cuSparseTest(jtu.JaxTestCase):
    self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

    # JVP with respect to the dense vector argument.
    y, dy = jvp(lambda x: sparse_ops.coo_matvec(M.data, M.row, M.col, x, shape=shape,
                                                transpose=transpose).sum(),
                (v,), (jnp.ones_like(v),))
    self.assertAllClose((op(M) @ v).sum(), y, rtol=MATMUL_TOL)

    # JVP with respect to the sparse data.
    y, dy = jvp(lambda x: sparse_ops.coo_matvec(x, M.row, M.col, v, shape=shape,
                                                transpose=transpose).sum(),
                (M.data,), (jnp.ones_like(M.data),))
    self.assertAllClose((op(M) @ v).sum(), y, rtol=MATMUL_TOL)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
       "shape": shape, "dtype": dtype, "transpose": transpose}
@@ -213,6 +235,11 @@ class cuSparseTest(jtu.JaxTestCase):
    self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
    self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)

    # JVP with respect to the dense matrix argument.
    y, dy = jvp(lambda x: sparse_ops.coo_matmat(M.data, M.row, M.col, x, shape=shape,
                                                transpose=transpose).sum(),
                (B,), (jnp.ones_like(B),))
    self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)

    # JVP with respect to the sparse data.
    y, dy = jvp(lambda x: sparse_ops.coo_matmat(x, M.row, M.col, B, shape=shape,
                                                transpose=transpose).sum(),
                (M.data,), (jnp.ones_like(M.data),))
    self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)

  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
  def test_gpu_translation_rule(self):
    version = xla_bridge.get_backend().platform_version