mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #8269 from jakevdp:sparse-dtype
PiperOrigin-RevId: 404232531
This commit is contained in:
commit
a91cb81613
@ -156,7 +156,7 @@ def _bcoo_todense_impl(data, indices, *, shape):
|
||||
batch_ind = tuple(grid)[:-1]
|
||||
|
||||
if not sparse_ind:
|
||||
data = data.sum(n_batch, keepdims=bool(batch_ind))
|
||||
data = data.sum(n_batch, keepdims=bool(batch_ind), dtype=data.dtype)
|
||||
return jnp.zeros(shape, data.dtype).at[batch_ind + sparse_ind].add(data)
|
||||
|
||||
@bcoo_todense_p.def_abstract_eval
|
||||
|
@ -384,7 +384,7 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
|
||||
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
|
||||
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
|
||||
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
|
||||
for dtype in jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex
|
||||
for n_batch in range(len(shape) + 1)
|
||||
for n_dense in range(len(shape) + 1 - n_batch)))
|
||||
def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense):
|
||||
|
Loading…
x
Reference in New Issue
Block a user