mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #14809 from jakevdp:fix-oob-correction
PiperOrigin-RevId: 514572448
This commit is contained in:
commit
0ec82f4d62
@ -1439,7 +1439,7 @@ def _coo_correct_out_of_bound_indices(row, col, shape, transpose):
|
||||
f = partial(_coo_correct_out_of_bound_indices,
|
||||
shape=shape[row.ndim:], transpose=transpose)
|
||||
return nfold_vmap(f, row.ndim)(row, col)
|
||||
mask = (row > shape[0]) | (col > shape[1])
|
||||
mask = (row >= shape[0]) | (col >= shape[1])
|
||||
if transpose:
|
||||
row = jnp.where(mask, 0, row)
|
||||
col = jnp.where(mask, shape[1], col)
|
||||
|
Loading…
x
Reference in New Issue
Block a user