update sparse _reshape docstring to match jax.Array.reshape

This commit is contained in:
Blair-Johnson 2024-02-08 16:15:00 -05:00
parent 4c505f8bac
commit 8bcd3d1049

View File

@ -889,7 +889,7 @@ def _sum(self, *args, **kwargs):
return sparsify(lambda x: x.sum(*args, **kwargs))(self)
def _reshape(self, *args, **kwargs):
"""Sum array along axis."""
"""Returns an array containing the same data with a new shape."""
return sparsify(lambda x: x.reshape(*args, **kwargs))(self)
def _astype(self, *args, **kwargs):