mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
bcoo_todense: fix corner case
This commit is contained in:
parent
d622d5c824
commit
72fe3babee
@ -639,8 +639,8 @@ def _bcoo_todense_impl(data, indices, *, shape):
|
||||
batch_slices = tuple(slice(s) for s in shape[:n_batch])
|
||||
sparse_ind = tuple(indices[tuple(np.mgrid[batch_slices]) + (i,)] for i in range(n_sparse))
|
||||
batch_ind = tuple(np.mgrid[batch_slices + (slice(1),)])[:-1]
|
||||
if not (batch_ind or sparse_ind):
|
||||
return data[0]
|
||||
if not sparse_ind:
|
||||
data = data.sum(n_batch, keepdims=bool(batch_ind))
|
||||
return jnp.zeros(shape, data.dtype).at[batch_ind + sparse_ind].add(data)
|
||||
|
||||
@bcoo_todense_p.def_abstract_eval
|
||||
|
Loading…
x
Reference in New Issue
Block a user