mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Copybara import of the project:
-- a74c74c25572eec23c28e08dbe67781a23be19fb by George Necula <gcnecula@gmail.com>: Fix scatter in CLIP mode with uint32 and uint64 indices Clipping uses np.iinfo(indices.dtype).max and those values are too large to be converted to Python or C constants. PiperOrigin-RevId: 496883024
This commit is contained in:
parent
76f92c47a6
commit
ce5320a2e4
@ -1562,8 +1562,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
|
||||
for i in dnums.scatter_dims_to_operand_dims)
|
||||
# Stack upper_bounds into a Array[n]
|
||||
upper_bound = lax.shape_as_value(upper_bounds)
|
||||
upper_bound = lax.min(upper_bound,
|
||||
lax.convert_element_type(np.uint64(np.iinfo(indices.dtype).max), np.int64))
|
||||
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
|
||||
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
|
||||
(len(indices.shape) - 1,))
|
||||
return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64),
|
||||
|
@ -2200,7 +2200,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10, 5,), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers(
|
||||
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
]],
|
||||
|
Loading…
x
Reference in New Issue
Block a user