Merge pull request #8559 from jakevdp:sparse-shape-tuple

PiperOrigin-RevId: 410506456
This commit is contained in:
jax authors 2021-11-17 06:10:38 -08:00
commit 1063b7d3b7
2 changed files with 4 additions and 4 deletions

View File

@ -730,7 +730,7 @@ class JAXSparse:
return map(_asarray_or_float0, args)
def __init__(self, args, *, shape):
self.shape = shape
self.shape = tuple(shape)
def __repr__(self):
name = self.__class__.__name__

View File

@ -1363,10 +1363,10 @@ class SparseGradTest(jtu.JaxTestCase):
class SparseObjectTest(jtu.JaxTestCase):
def test_repr(self):
M = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32'))
assert repr(M) == "BCOO(float32[5], nse=4)"
self.assertEqual(repr(M), "BCOO(float32[5], nse=4)")
M_invalid = sparse.BCOO(([], []), shape=100)
assert repr(M_invalid) == "BCOO(<invalid>)"
M_invalid = sparse.BCOO(([], []), shape=(100,))
self.assertEqual(repr(M_invalid), "BCOO(<invalid>)")
@parameterized.named_parameters(
{"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}