Merge pull request #14809 from jakevdp:fix-oob-correction

PiperOrigin-RevId: 514572448
This commit is contained in:
jax authors 2023-03-06 17:25:16 -08:00
commit 0ec82f4d62

View File

@ -1439,7 +1439,7 @@ def _coo_correct_out_of_bound_indices(row, col, shape, transpose):
f = partial(_coo_correct_out_of_bound_indices,
shape=shape[row.ndim:], transpose=transpose)
return nfold_vmap(f, row.ndim)(row, col)
mask = (row > shape[0]) | (col > shape[1])
mask = (row >= shape[0]) | (col >= shape[1])
if transpose:
row = jnp.where(mask, 0, row)
col = jnp.where(mask, shape[1], col)