Merge pull request #8581 from jakevdp:sparse-matmul-dtype

PiperOrigin-RevId: 410893387
This commit is contained in:
jax authors 2021-11-18 14:03:03 -08:00
commit bf74f2e50c
2 changed files with 19 additions and 6 deletions

View File

@ -46,6 +46,7 @@ from jax._src import dtypes
from jax._src.lib import cusparse
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.numpy.lax_numpy import _promote_dtypes
import jax.numpy as jnp
xb = xla_bridge
@ -875,10 +876,12 @@ class CSR(JAXSparse):
return csr_todense(self.data, self.indices, self.indptr, shape=self.shape)
def matvec(self, v):
return csr_matvec(self.data, self.indices, self.indptr, v, shape=self.shape)
data, v = _promote_dtypes(self.data, v)
return csr_matvec(data, self.indices, self.indptr, v, shape=self.shape)
def matmat(self, B):
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape)
data, B = _promote_dtypes(self.data, B)
return csr_matmat(data, self.indices, self.indptr, B, shape=self.shape)
def transpose(self, axes=None):
assert axes is None
@ -912,10 +915,12 @@ class CSC(JAXSparse):
return csr_todense(self.data, self.indices, self.indptr, shape=self.shape[::-1]).T
def matvec(self, v):
return csr_matvec(self.data, self.indices, self.indptr, v, shape=self.shape[::-1], transpose=True)
data, v = _promote_dtypes(self.data, v)
return csr_matvec(data, self.indices, self.indptr, v, shape=self.shape[::-1], transpose=True)
def matmat(self, B):
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape[::-1], transpose=True)
data, B = _promote_dtypes(self.data, B)
return csr_matmat(data, self.indices, self.indptr, B, shape=self.shape[::-1], transpose=True)
def transpose(self, axes=None):
assert axes is None
@ -949,10 +954,12 @@ class COO(JAXSparse):
return coo_todense(self.data, self.row, self.col, shape=self.shape)
def matvec(self, v):
return coo_matvec(self.data, self.row, self.col, v, shape=self.shape)
data, v = _promote_dtypes(self.data, v)
return coo_matvec(data, self.row, self.col, v, shape=self.shape)
def matmat(self, B):
return coo_matmat(self.data, self.row, self.col, B, shape=self.shape)
data, B = _promote_dtypes(self.data, B)
return coo_matmat(data, self.row, self.col, B, shape=self.shape)
def transpose(self, axes=None):
assert axes is None

View File

@ -1496,9 +1496,15 @@ class SparseObjectTest(jtu.JaxTestCase):
rng_b = jtu.rand_default(self.rng())
M = rng(shape, dtype)
Msp = Obj.fromdense(M)
# Test matching type
x = rng_b(bshape, dtype)
x = jnp.asarray(x)
self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)
# Test mismatched type
x = rng_b(bshape, np.int32)
x = jnp.asarray(x)
self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)
@parameterized.named_parameters(jtu.cases_from_list(