mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix scatter in CLIP mode with uint32 and uint64 indices
Clipping uses np.iinfo(indices.dtype).max and those values are too large to be converted to Python constants or C constants. This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to failures when jax_array=False. Since that use case will go away soon we just enable the fix for when jax_array=True. PiperOrigin-RevId: 502568204
This commit is contained in:
parent
7ce9fa2f87
commit
7e0041c903
@ -1562,7 +1562,14 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
|
||||
for i in dnums.scatter_dims_to_operand_dims)
|
||||
# Stack upper_bounds into a Array[n]
|
||||
upper_bound = lax.shape_as_value(upper_bounds)
|
||||
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
|
||||
if jax.config.jax_array:
|
||||
# This fix fails lax_test_no_jax_array
|
||||
upper_bound = lax.min(upper_bound,
|
||||
lax.convert_element_type(np.uint64(np.iinfo(indices.dtype).max),
|
||||
np.int64))
|
||||
else:
|
||||
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
|
||||
|
||||
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
|
||||
(len(indices.shape) - 1,))
|
||||
return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64),
|
||||
|
@ -2378,10 +2378,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
if harness.fullname.find(s) != -1:
|
||||
raise unittest.SkipTest("TODO(necula): crashes in simplifyDynamicGatherToGather")
|
||||
|
||||
if harness.fullname.find("vmap_dynamic_update_slice_shapes_operand=float32[3]_update=float32[1]_start_indices=(array(1, dtype=uint32),)_enable_xla=True") != -1:
|
||||
# Python int 4294967295 too large to convert to int32
|
||||
raise unittest.SkipTest("TODO(necula): re-land cl/496854361")
|
||||
|
||||
harness.run_test(self)
|
||||
|
||||
|
||||
|
@ -2226,7 +2226,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
|
||||
((10, 5,), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(1,), inserted_window_dims=(0,),
|
||||
scatter_dims_to_operand_dims=(0,))),
|
||||
]],
|
||||
|
Loading…
x
Reference in New Issue
Block a user