mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13157 from jakevdp:bcoo-astype
PiperOrigin-RevId: 487046458
This commit is contained in:
commit
768076eec4
@ -787,6 +787,10 @@ def _reshape(self, *args, **kwargs):
|
||||
"""Sum array along axis."""
|
||||
return sparsify(lambda x: x.reshape(*args, **kwargs))(self)
|
||||
|
||||
def _astype(self, *args, **kwargs):
|
||||
"""Copy the array and cast to a specified dtype."""
|
||||
return sparsify(lambda x: x.astype(*args, **kwargs))(self)
|
||||
|
||||
def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
|
||||
mode=None, fill_value=None):
|
||||
# Only sparsify the array argument; sparse indices not yet supported
|
||||
@ -801,8 +805,9 @@ def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=Fa
|
||||
_swap_args = lambda f: lambda a, b: f(b, a)
|
||||
|
||||
_bcoo_methods = {
|
||||
'reshape': _reshape,
|
||||
'sum': _sum,
|
||||
"astype": _astype,
|
||||
"reshape": _reshape,
|
||||
"sum": _sum,
|
||||
"__neg__": sparsify(jnp.negative),
|
||||
"__pos__": sparsify(jnp.positive),
|
||||
"__matmul__": sparsify(jnp.matmul),
|
||||
|
@ -2537,6 +2537,8 @@ class SparseObjectTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(M.sum(1), Msp.sum(1).todense())
|
||||
self.assertArraysEqual(M.sum(), Msp.sum())
|
||||
|
||||
self.assertArraysEqual(M.astype(float), Msp.astype(float).todense())
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, n_batch=n_batch)
|
||||
for shape in [(5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user