Copybara import of the project:

--
a74c74c25572eec23c28e08dbe67781a23be19fb by George Necula <gcnecula@gmail.com>:

Fix scatter in CLIP mode with uint32 and uint64 indices

Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python or C constants.

PiperOrigin-RevId: 496883024
This commit is contained in:
George Necula 2022-12-21 03:46:01 -08:00 committed by jax authors
parent 76f92c47a6
commit ce5320a2e4
2 changed files with 2 additions and 3 deletions

View File

@ -1562,8 +1562,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
for i in dnums.scatter_dims_to_operand_dims)
# Stack upper_bounds into a Array[n]
upper_bound = lax.shape_as_value(upper_bounds)
upper_bound = lax.min(upper_bound,
lax.convert_element_type(np.uint64(np.iinfo(indices.dtype).max), np.int64))
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
(len(indices.shape) - 1,))
return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64),

View File

@ -2200,7 +2200,7 @@ class LaxTest(jtu.JaxTestCase):
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5,), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers(
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]],