Merge pull request #13157 from jakevdp:bcoo-astype

PiperOrigin-RevId: 487046458
This commit is contained in:
jax authors 2022-11-08 14:05:09 -08:00
commit 768076eec4
2 changed files with 9 additions and 2 deletions

View File

@ -787,6 +787,10 @@ def _reshape(self, *args, **kwargs):
"""Sum array along axis."""
return sparsify(lambda x: x.reshape(*args, **kwargs))(self)
def _astype(self, *args, **kwargs):
"""Copy the array and cast to a specified dtype."""
return sparsify(lambda x: x.astype(*args, **kwargs))(self)
def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
# Only sparsify the array argument; sparse indices not yet supported
@ -801,8 +805,9 @@ def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=Fa
_swap_args = lambda f: lambda a, b: f(b, a)
_bcoo_methods = {
'reshape': _reshape,
'sum': _sum,
"astype": _astype,
"reshape": _reshape,
"sum": _sum,
"__neg__": sparsify(jnp.negative),
"__pos__": sparsify(jnp.positive),
"__matmul__": sparsify(jnp.matmul),

View File

@ -2537,6 +2537,8 @@ class SparseObjectTest(jtu.JaxTestCase):
self.assertArraysEqual(M.sum(1), Msp.sum(1).todense())
self.assertArraysEqual(M.sum(), Msp.sum())
self.assertArraysEqual(M.astype(float), Msp.astype(float).todense())
@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch)
for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]