mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8559 from jakevdp:sparse-shape-tuple
PiperOrigin-RevId: 410506456
This commit is contained in:
commit
1063b7d3b7
@ -730,7 +730,7 @@ class JAXSparse:
|
||||
return map(_asarray_or_float0, args)
|
||||
|
||||
def __init__(self, args, *, shape):
|
||||
self.shape = shape
|
||||
self.shape = tuple(shape)
|
||||
|
||||
def __repr__(self):
|
||||
name = self.__class__.__name__
|
||||
|
@ -1363,10 +1363,10 @@ class SparseGradTest(jtu.JaxTestCase):
|
||||
class SparseObjectTest(jtu.JaxTestCase):
|
||||
def test_repr(self):
|
||||
M = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32'))
|
||||
assert repr(M) == "BCOO(float32[5], nse=4)"
|
||||
self.assertEqual(repr(M), "BCOO(float32[5], nse=4)")
|
||||
|
||||
M_invalid = sparse.BCOO(([], []), shape=100)
|
||||
assert repr(M_invalid) == "BCOO(<invalid>)"
|
||||
M_invalid = sparse.BCOO(([], []), shape=(100,))
|
||||
self.assertEqual(repr(M_invalid), "BCOO(<invalid>)")
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
|
||||
|
Loading…
x
Reference in New Issue
Block a user