mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8581 from jakevdp:sparse-matmul-dtype
PiperOrigin-RevId: 410893387
This commit is contained in:
commit
bf74f2e50c
@ -46,6 +46,7 @@ from jax._src import dtypes
|
||||
from jax._src.lib import cusparse
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.numpy.lax_numpy import _promote_dtypes
|
||||
import jax.numpy as jnp
|
||||
|
||||
xb = xla_bridge
|
||||
@ -875,10 +876,12 @@ class CSR(JAXSparse):
|
||||
return csr_todense(self.data, self.indices, self.indptr, shape=self.shape)
|
||||
|
||||
def matvec(self, v):
|
||||
return csr_matvec(self.data, self.indices, self.indptr, v, shape=self.shape)
|
||||
data, v = _promote_dtypes(self.data, v)
|
||||
return csr_matvec(data, self.indices, self.indptr, v, shape=self.shape)
|
||||
|
||||
def matmat(self, B):
|
||||
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape)
|
||||
data, B = _promote_dtypes(self.data, B)
|
||||
return csr_matmat(data, self.indices, self.indptr, B, shape=self.shape)
|
||||
|
||||
def transpose(self, axes=None):
|
||||
assert axes is None
|
||||
@ -912,10 +915,12 @@ class CSC(JAXSparse):
|
||||
return csr_todense(self.data, self.indices, self.indptr, shape=self.shape[::-1]).T
|
||||
|
||||
def matvec(self, v):
|
||||
return csr_matvec(self.data, self.indices, self.indptr, v, shape=self.shape[::-1], transpose=True)
|
||||
data, v = _promote_dtypes(self.data, v)
|
||||
return csr_matvec(data, self.indices, self.indptr, v, shape=self.shape[::-1], transpose=True)
|
||||
|
||||
def matmat(self, B):
|
||||
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape[::-1], transpose=True)
|
||||
data, B = _promote_dtypes(self.data, B)
|
||||
return csr_matmat(data, self.indices, self.indptr, B, shape=self.shape[::-1], transpose=True)
|
||||
|
||||
def transpose(self, axes=None):
|
||||
assert axes is None
|
||||
@ -949,10 +954,12 @@ class COO(JAXSparse):
|
||||
return coo_todense(self.data, self.row, self.col, shape=self.shape)
|
||||
|
||||
def matvec(self, v):
|
||||
return coo_matvec(self.data, self.row, self.col, v, shape=self.shape)
|
||||
data, v = _promote_dtypes(self.data, v)
|
||||
return coo_matvec(data, self.row, self.col, v, shape=self.shape)
|
||||
|
||||
def matmat(self, B):
|
||||
return coo_matmat(self.data, self.row, self.col, B, shape=self.shape)
|
||||
data, B = _promote_dtypes(self.data, B)
|
||||
return coo_matmat(data, self.row, self.col, B, shape=self.shape)
|
||||
|
||||
def transpose(self, axes=None):
|
||||
assert axes is None
|
||||
|
@ -1496,9 +1496,15 @@ class SparseObjectTest(jtu.JaxTestCase):
|
||||
rng_b = jtu.rand_default(self.rng())
|
||||
M = rng(shape, dtype)
|
||||
Msp = Obj.fromdense(M)
|
||||
|
||||
# Test matching type
|
||||
x = rng_b(bshape, dtype)
|
||||
x = jnp.asarray(x)
|
||||
self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)
|
||||
|
||||
# Test mismatched type
|
||||
x = rng_b(bshape, np.int32)
|
||||
x = jnp.asarray(x)
|
||||
self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
Loading…
x
Reference in New Issue
Block a user