Merge pull request #8269 from jakevdp:sparse-dtype

PiperOrigin-RevId: 404232531
This commit is contained in:
jax authors 2021-10-19 05:45:24 -07:00
commit a91cb81613
2 changed files with 2 additions and 2 deletions

View File

@ -156,7 +156,7 @@ def _bcoo_todense_impl(data, indices, *, shape):
batch_ind = tuple(grid)[:-1]
if not sparse_ind:
data = data.sum(n_batch, keepdims=bool(batch_ind))
data = data.sum(n_batch, keepdims=bool(batch_ind), dtype=data.dtype)
return jnp.zeros(shape, data.dtype).at[batch_ind + sparse_ind].add(data)
@bcoo_todense_p.def_abstract_eval

View File

@ -384,7 +384,7 @@ class BCOOTest(jtu.JaxTestCase):
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex
for dtype in jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)))
def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense):