From 465b48f044f5867a04f9ff75fda14cf1aa471414 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 18 Oct 2021 13:52:42 -0700 Subject: [PATCH] [sparse] preserve dtype in bcoo_todense --- jax/experimental/sparse/bcoo.py | 2 +- tests/sparse_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 71dc867d1..547b65a9b 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -156,7 +156,7 @@ def _bcoo_todense_impl(data, indices, *, shape): batch_ind = tuple(grid)[:-1] if not sparse_ind: - data = data.sum(n_batch, keepdims=bool(batch_ind)) + data = data.sum(n_batch, keepdims=bool(batch_ind), dtype=data.dtype) return jnp.zeros(shape, data.dtype).at[batch_ind + sparse_ind].add(data) @bcoo_todense_p.def_abstract_eval diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 83e2830ad..fbb84c72c 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -384,7 +384,7 @@ class BCOOTest(jtu.JaxTestCase): jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense), "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense} for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] - for dtype in jtu.dtypes.floating + jtu.dtypes.complex + for dtype in jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex for n_batch in range(len(shape) + 1) for n_dense in range(len(shape) + 1 - n_batch))) def test_bcoo_dense_round_trip(self, shape, dtype, n_batch, n_dense):