Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
[sparse] fix GPU translation rule for coo/csr matmat
commit f8081a9a52
parent 17a606a95d
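This fixes https://github.com/google/jax/issues/7533. The GPU translation rules below declared the dense operand B with `c.get_shape(B)`, which inherits whatever layout B happens to carry; the cuSPARSE wrappers expect row-major data, so a B that arrived column-major (for example a transpose that jit never materialized) was read with its axes swapped. The fix declares B's operand layout explicitly as (1, 0), so XLA relayouts the buffer to row-major before the custom call. A minimal numpy sketch of that failure mode (illustration only, not code from this commit):

import numpy as np

B = np.arange(9.0).reshape(3, 3)   # logical value, row-major bytes
col_major = np.asfortranarray(B)   # same logical value, column-major bytes

# A kernel that assumes row-major input and reads the raw buffer
# (order='A' exports the bytes in their in-memory layout) sees B.T:
misread = np.frombuffer(col_major.tobytes(order='A'), dtype=B.dtype).reshape(B.shape)
print(np.array_equal(misread, B.T))  # True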
@@ -137,8 +137,9 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
   dtype = np.dtype(c.get_shape(data).element_type())
   index_dtype = np.dtype(c.get_shape(indices).element_type())
   B_dtype = np.dtype(c.get_shape(B).element_type())
+  B_shape = c.get_shape(B).dimensions()
   rows, cols = shape
-  _, Ccols = c.get_shape(B).dimensions()
+  _, Ccols = B_shape
   nnz, = c.get_shape(data).dimensions()
 
   if compute_dtype is None:
@@ -154,11 +155,10 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
       b"cusparse_csr_matmat",
       operands=(data, indices, indptr, B),
       operand_shapes_with_layout=(
-          # All are 1D, so no layout necessary
           c.get_shape(data),
           c.get_shape(indices),
           c.get_shape(indptr),
-          c.get_shape(B),
+          _Shape.array_shape(B_dtype, B_shape, (1, 0)),
       ),
       shape_with_layout=_Shape.tuple_shape((
           _Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
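In this file `_Shape` is `xla_client.Shape`; its `array_shape` constructor takes an explicit minor-to-major layout, where (1, 0) denotes row-major for a 2D array and (0, 1) column-major. A small standalone sketch of the shape being constructed (assumes jaxlib's xla_client API of this era; not code from the commit):

import numpy as np
from jaxlib import xla_client

_Shape = xla_client.Shape

B_dtype = np.dtype(np.float32)
B_shape = (5, 8)

# Before the fix the operand shape echoed B's actual layout; after it,
# the layout is pinned to row-major regardless of how B is stored:
pinned = _Shape.array_shape(B_dtype, B_shape, (1, 0))
print(pinned)  # shape with its layout, e.g. f32[5,8]{1,0}

Pinning the layout trades an occasional copy for correctness; the alternative, teaching the kernel to handle both layouts, would require dispatching on the layout at lowering time.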
@@ -272,8 +272,9 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
   dtype = np.dtype(c.get_shape(data).element_type())
   index_dtype = np.dtype(c.get_shape(row).element_type())
   B_dtype = np.dtype(c.get_shape(B).element_type())
+  B_shape = c.get_shape(B).dimensions()
   rows, cols = shape
-  _, Ccols = c.get_shape(B).dimensions()
+  _, Ccols = B_shape
   nnz, = c.get_shape(data).dimensions()
 
   if compute_dtype is None:
@@ -289,11 +290,10 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
       b"cusparse_coo_matmat",
       operands=(data, row, col, B),
       operand_shapes_with_layout=(
-          # All are 1D, so no layout necessary
           c.get_shape(data),
           c.get_shape(row),
           c.get_shape(col),
-          c.get_shape(B),
+          _Shape.array_shape(B_dtype, B_shape, (1, 0)),
       ),
       shape_with_layout=_Shape.tuple_shape((
           _Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
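The coo_matmat hunks apply the identical fix to the COO kernel. For reference, the operation these kernels compute is a sparse-dense matmul, out = A @ B with A in COO form; a hedged numpy sketch of those semantics (pure illustration, the real work happens in cuSPARSE):

import numpy as np

# COO matrix A: values d at positions (i[k], j[k]) of a 3x3 matrix.
d = np.array([1.0, 2.0, 3.0, 4.0])
i = np.array([0, 0, 1, 2])
j = np.array([0, 2, 0, 0])
A = np.zeros((3, 3))
A[i, j] = d

B = np.arange(9.0).reshape(3, 3)

# coo_matmat semantics: each nonzero d[k] scales row j[k] of B and
# accumulates into row i[k] of the output (np.add.at handles repeats).
out = np.zeros((3, 3))
np.add.at(out, i, d[:, None] * B[j])
print(np.allclose(out, A @ B))  # True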
@@ -202,7 +202,6 @@ class cuSparseTest(jtu.JaxTestCase):
     self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
     self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
 
-  @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_{}_T={}".format(jtu.format_shape_dtype_string(shape, dtype), transpose),
        "shape": shape, "dtype": dtype, "transpose": transpose}
@@ -229,6 +228,23 @@ class cuSparseTest(jtu.JaxTestCase):
     y, dy = jvp(lambda x: sparse.coo_matmat(x, M.row, M.col, B, shape=shape, transpose=transpose).sum(), (M.data, ), (jnp.ones_like(M.data), ))
     self.assertAllClose((op(M) @ B).sum(), y, rtol=MATMUL_TOL)
 
+  def test_coo_matmat_layout(self):
+    # Regression test for https://github.com/google/jax/issues/7533
+    d = jnp.array([1.0, 2.0, 3.0, 4.0])
+    i = jnp.array([0, 0, 1, 2])
+    j = jnp.array([0, 2, 0, 0])
+    shape = (3, 3)
+
+    x = jnp.arange(9).reshape(3, 3).astype(d.dtype)
+
+    def f(x):
+      return sparse.coo_matmat(d, i, j, x.T, shape=shape)
+
+    result = f(x)
+    result_jit = jit(f)(x)
+
+    self.assertAllClose(result, result_jit)
+
   @unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
   def test_gpu_translation_rule(self):
     version = xla_bridge.get_backend().platform_version
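The new regression test exercises only the COO path, even though the commit pins the layout for CSR as well. A matching CSR check (hypothetical, not part of this commit; it assumes `sparse.csr_matmat` mirrors the `csr_matmat` signature in the first hunk and takes the same matrix in CSR form) might look like:

import jax.numpy as jnp
from jax import jit
from jax.experimental import sparse

# Same matrix as the COO test, in CSR form:
# nonzeros (0,0)=1, (0,2)=2, (1,0)=3, (2,0)=4.
data = jnp.array([1.0, 2.0, 3.0, 4.0])
indices = jnp.array([0, 2, 0, 0])
indptr = jnp.array([0, 2, 3, 4])
x = jnp.arange(9.0).reshape(3, 3)

def f(x):
  # x.T may reach the kernel as a column-major buffer under jit,
  # which is exactly the case the layout pin guards against.
  return sparse.csr_matmat(data, indices, indptr, x.T, shape=(3, 3))

assert jnp.allclose(f(x), jit(f)(x))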