[sparse] fix GPU translation rule for coo/csr matmat

Jake VanderPlas 2021-08-10 10:13:29 -07:00
parent 17a606a95d
commit f8081a9a52
2 changed files with 23 additions and 7 deletions
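
For context (not part of the commit itself): the failure mode reported in https://github.com/google/jax/issues/7533, which the regression test below encodes, can be sketched as follows. Under `jit`, XLA may hand the dense operand of the cusparse custom call over as a column-major buffer (for example, a transpose view), while the GPU kernel assumed a row-major layout.

```python
# Minimal sketch of the bug, adapted from the regression test added in
# this commit (see https://github.com/google/jax/issues/7533).
import jax.numpy as jnp
from jax import jit
from jax.experimental import sparse

d = jnp.array([1.0, 2.0, 3.0, 4.0])  # COO data
i = jnp.array([0, 0, 1, 2])          # COO row indices
j = jnp.array([0, 2, 0, 0])          # COO column indices
x = jnp.arange(9.0).reshape(3, 3)

f = lambda x: sparse.coo_matmat(d, i, j, x.T, shape=(3, 3))

print(f(x))       # eager: x.T is materialized row-major, result correct
print(jit(f)(x))  # on GPU, before this fix: wrong, because XLA could pass
                  # x.T as a column-major view that the kernel read row-major
```

The fix below pins the layout of the dense operand so that XLA materializes it row-major before the kernel runs.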


@@ -137,8 +137,9 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
   dtype = np.dtype(c.get_shape(data).element_type())
   index_dtype = np.dtype(c.get_shape(indices).element_type())
   B_dtype = np.dtype(c.get_shape(B).element_type())
+  B_shape = c.get_shape(B).dimensions()
   rows, cols = shape
-  _, Ccols = c.get_shape(B).dimensions()
+  _, Ccols = B_shape
   nnz, = c.get_shape(data).dimensions()
   if compute_dtype is None:
@@ -154,11 +155,10 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
       b"cusparse_csr_matmat",
       operands=(data, indices, indptr, B),
       operand_shapes_with_layout=(
-          # All are 1D, so no layout necessary
           c.get_shape(data),
           c.get_shape(indices),
           c.get_shape(indptr),
-          c.get_shape(B),
+          _Shape.array_shape(B_dtype, B_shape, (1, 0)),
       ),
       shape_with_layout=_Shape.tuple_shape((
           _Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
@@ -272,8 +272,9 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
   dtype = np.dtype(c.get_shape(data).element_type())
   index_dtype = np.dtype(c.get_shape(row).element_type())
   B_dtype = np.dtype(c.get_shape(B).element_type())
+  B_shape = c.get_shape(B).dimensions()
   rows, cols = shape
-  _, Ccols = c.get_shape(B).dimensions()
+  _, Ccols = B_shape
   nnz, = c.get_shape(data).dimensions()
   if compute_dtype is None:
@@ -289,11 +290,10 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
       b"cusparse_coo_matmat",
       operands=(data, row, col, B),
       operand_shapes_with_layout=(
-          # All are 1D, so no layout necessary
           c.get_shape(data),
           c.get_shape(row),
           c.get_shape(col),
-          c.get_shape(B),
+          _Shape.array_shape(B_dtype, B_shape, (1, 0)),
       ),
       shape_with_layout=_Shape.tuple_shape((
           _Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
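
The substantive change in both rules is the operand shape passed for B: instead of forwarding `c.get_shape(B)`, whose layout depends on how the jitted computation produced B, the rule now pins an explicit layout. XLA layout tuples list dimensions from minor to major, so `(1, 0)` declares a row-major 2D array. A rough illustration of the notation, assuming `_Shape` is jaxlib's `xla_client.Shape` as elsewhere in this module (an assumption; the import is not shown in these hunks):

```python
# Sketch of what the pinned layout expresses, assuming _Shape is
# xla_client.Shape. The layout tuple is minor-to-major: (1, 0) means
# dimension 1 varies fastest in memory, i.e. row-major.
import numpy as np
from jax.lib import xla_client

_Shape = xla_client.Shape

row_major = _Shape.array_shape(np.dtype('float32'), (3, 3), (1, 0))
col_major = _Shape.array_shape(np.dtype('float32'), (3, 3), (0, 1))
print(row_major)  # e.g. f32[3,3]{1,0}
print(col_major)  # e.g. f32[3,3]{0,1}
```

With the layout pinned on the operand, XLA inserts a physical transpose whenever the incoming buffer happens to be column-major, so the cusparse kernel always sees the row-major data it expects.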


@@ -202,7 +202,6 @@ class cuSparseTest(jtu.JaxTestCase):
     self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
     self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)

-  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
        "shape": shape, "dtype": dtype, "transpose": transpose}
@@ -229,6 +228,23 @@ class cuSparseTest(jtu.JaxTestCase):
     y, dy = jvp(lambda x: sparse.coo_matmat(x, M.row, M.col, B, shape=shape, transpose=transpose).sum(), (M.data, ), (jnp.ones_like(M.data), ))
     self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)

+  def test_coo_matmat_layout(self):
+    # Regression test for https://github.com/google/jax/issues/7533
+    d = jnp.array([1.0, 2.0, 3.0, 4.0])
+    i = jnp.array([0, 0, 1, 2])
+    j = jnp.array([0, 2, 0, 0])
+    shape = (3, 3)
+
+    x = jnp.arange(9).reshape(3, 3).astype(d.dtype)
+
+    def f(x):
+      return sparse.coo_matmat(d, i, j, x.T, shape=shape)
+
+    result = f(x)
+    result_jit = jit(f)(x)
+
+    self.assertAllClose(result, result_jit)
+
   @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
   def test_gpu_translation_rule(self):
     version = xla_bridge.get_backend().platform_version
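
An aside on why comparing the eager and jitted results catches the bug: in eager mode `x.T` is materialized as a fresh row-major array before `coo_matmat` sees it, whereas under `jit` XLA may keep `x.T` as the original buffer tagged with a column-major layout. Reading such a buffer as if it were row-major silently recovers the untransposed matrix, as this plain-NumPy sketch (not part of the commit) demonstrates:

```python
# A transpose can share memory with the original array, so interpreting
# its buffer with the wrong (row-major) layout recovers x, not x.T.
import numpy as np

x = np.arange(9.0).reshape(3, 3)
xt = x.T                       # column-major view, no copy
assert np.shares_memory(x, xt)

# Read xt's raw buffer in memory order and reinterpret it row-major:
misread = np.frombuffer(xt.tobytes(order='A')).reshape(3, 3)
np.testing.assert_array_equal(misread, x)  # equals x, not x.T
```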