mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sparse] implement __len__ on sparse objects
This commit is contained in:
parent
fcb9dfb080
commit
27c068e7b7
@ -32,6 +32,9 @@ class JAXSparse(abc.ABC):
|
||||
# Ignore type because of https://github.com/python/mypy/issues/4266.
|
||||
__hash__ = None # type: ignore
|
||||
|
||||
def __len__(self):
|
||||
return self.shape[0]
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return util.prod(self.shape)
|
||||
|
@ -2438,33 +2438,38 @@ class SparseObjectTest(sptu.SparseTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{Obj.__name__}", "Obj": Obj}
|
||||
for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])
|
||||
def test_attrs(self, Obj, shape=(5, 8), dtype=np.float16):
|
||||
for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO, sparse.BCSR])
|
||||
def test_attrs(self, Obj, shape=(5, 8), dtype=np.float32):
|
||||
rng = rand_sparse(self.rng(), post=Obj.fromdense)
|
||||
M = rng(shape, dtype)
|
||||
|
||||
assert isinstance(M, Obj)
|
||||
assert M.shape == shape
|
||||
assert M.size == np.prod(shape)
|
||||
assert M.ndim == len(shape)
|
||||
assert M.dtype == dtype
|
||||
assert M.nse == (M.todense() != 0).sum()
|
||||
assert M.data.dtype == dtype
|
||||
self.assertIsInstance(M, Obj)
|
||||
self.assertEqual(M.shape, shape)
|
||||
self.assertEqual(M.size, np.prod(shape))
|
||||
self.assertEqual(M.ndim, len(shape))
|
||||
self.assertEqual(M.dtype, dtype)
|
||||
self.assertEqual(M.nse, (M.todense() != 0).sum())
|
||||
self.assertEqual(M.data.dtype, dtype)
|
||||
self.assertEqual(len(M), M.shape[0])
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
hash(M)
|
||||
|
||||
if isinstance(M, sparse.CSR):
|
||||
assert len(M.data) == len(M.indices)
|
||||
assert len(M.indptr) == M.shape[0] + 1
|
||||
self.assertEqual(len(M.data), len(M.indices))
|
||||
self.assertEqual(len(M.indptr), M.shape[0] + 1)
|
||||
elif isinstance(M, sparse.CSC):
|
||||
assert len(M.data) == len(M.indices)
|
||||
assert len(M.indptr) == M.shape[1] + 1
|
||||
self.assertEqual(len(M.data), len(M.indices))
|
||||
self.assertEqual(len(M.indptr), M.shape[1] + 1)
|
||||
elif isinstance(M, sparse.COO):
|
||||
assert len(M.data) == len(M.row) == len(M.col)
|
||||
self.assertEqual(len(M.data), len(M.row))
|
||||
self.assertEqual(len(M.data), len(M.col))
|
||||
elif isinstance(M, sparse.BCOO):
|
||||
assert M.data.shape[M.n_batch] == M.indices.shape[-2]
|
||||
assert M.indices.shape[-1] == M.n_sparse
|
||||
self.assertEqual(M.data.shape[M.n_batch], M.indices.shape[-2])
|
||||
self.assertEqual(M.indices.shape[-1], M.n_sparse)
|
||||
elif isinstance(M, sparse.BCSR):
|
||||
self.assertEqual(M.data.shape[M.n_batch], M.indices.shape[-1])
|
||||
self.assertEqual(M.indptr.shape[-1], M.shape[M.n_batch] + 1)
|
||||
else:
|
||||
raise ValueError(f"{Obj=} not expected.")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user