[sparse] implement __len__ on sparse objects

This commit is contained in:
Jake VanderPlas 2023-02-01 11:46:02 -08:00
parent fcb9dfb080
commit 27c068e7b7
2 changed files with 24 additions and 16 deletions

View File

@ -32,6 +32,9 @@ class JAXSparse(abc.ABC):
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
def __len__(self):
return self.shape[0]
@property
def size(self) -> int:
return util.prod(self.shape)

View File

@ -2438,33 +2438,38 @@ class SparseObjectTest(sptu.SparseTestCase):
@parameterized.named_parameters(
{"testcase_name": f"_{Obj.__name__}", "Obj": Obj}
for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO])
def test_attrs(self, Obj, shape=(5, 8), dtype=np.float16):
for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO, sparse.BCSR])
def test_attrs(self, Obj, shape=(5, 8), dtype=np.float32):
rng = rand_sparse(self.rng(), post=Obj.fromdense)
M = rng(shape, dtype)
assert isinstance(M, Obj)
assert M.shape == shape
assert M.size == np.prod(shape)
assert M.ndim == len(shape)
assert M.dtype == dtype
assert M.nse == (M.todense() != 0).sum()
assert M.data.dtype == dtype
self.assertIsInstance(M, Obj)
self.assertEqual(M.shape, shape)
self.assertEqual(M.size, np.prod(shape))
self.assertEqual(M.ndim, len(shape))
self.assertEqual(M.dtype, dtype)
self.assertEqual(M.nse, (M.todense() != 0).sum())
self.assertEqual(M.data.dtype, dtype)
self.assertEqual(len(M), M.shape[0])
with self.assertRaises(TypeError):
hash(M)
if isinstance(M, sparse.CSR):
assert len(M.data) == len(M.indices)
assert len(M.indptr) == M.shape[0] + 1
self.assertEqual(len(M.data), len(M.indices))
self.assertEqual(len(M.indptr), M.shape[0] + 1)
elif isinstance(M, sparse.CSC):
assert len(M.data) == len(M.indices)
assert len(M.indptr) == M.shape[1] + 1
self.assertEqual(len(M.data), len(M.indices))
self.assertEqual(len(M.indptr), M.shape[1] + 1)
elif isinstance(M, sparse.COO):
assert len(M.data) == len(M.row) == len(M.col)
self.assertEqual(len(M.data), len(M.row))
self.assertEqual(len(M.data), len(M.col))
elif isinstance(M, sparse.BCOO):
assert M.data.shape[M.n_batch] == M.indices.shape[-2]
assert M.indices.shape[-1] == M.n_sparse
self.assertEqual(M.data.shape[M.n_batch], M.indices.shape[-2])
self.assertEqual(M.indices.shape[-1], M.n_sparse)
elif isinstance(M, sparse.BCSR):
self.assertEqual(M.data.shape[M.n_batch], M.indices.shape[-1])
self.assertEqual(M.indptr.shape[-1], M.shape[M.n_batch] + 1)
else:
raise ValueError(f"{Obj=} not expected.")