Fix scatter in CLIP mode with uint32 and uint64 indices

Clipping uses np.iinfo(indices.dtype).max and those values
are too large to be converted to Python constants or C constants.

This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to
failures when jax_array=False. Since that use case will go away
soon we just enable the fix for when jax_array=True.

PiperOrigin-RevId: 502568204
This commit is contained in:
George Necula 2023-01-17 06:25:26 -08:00 committed by jax authors
parent 7ce9fa2f87
commit 7e0041c903
3 changed files with 9 additions and 6 deletions

View File

@ -1562,7 +1562,14 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
for i in dnums.scatter_dims_to_operand_dims)
# Stack upper_bounds into a Array[n]
upper_bound = lax.shape_as_value(upper_bounds)
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
if jax.config.jax_array:
# This fix fails lax_test_no_jax_array
upper_bound = lax.min(upper_bound,
lax.convert_element_type(np.uint64(np.iinfo(indices.dtype).max),
np.int64))
else:
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
(len(indices.shape) - 1,))
return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64),

View File

@ -2378,10 +2378,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
if harness.fullname.find(s) != -1:
raise unittest.SkipTest("TODO(necula): crashes in simplifyDynamicGatherToGather")
if harness.fullname.find("vmap_dynamic_update_slice_shapes_operand=float32[3]_update=float32[1]_start_indices=(array(1, dtype=uint32),)_enable_xla=True") != -1:
# Python int 4294967295 too large to convert to int32
raise unittest.SkipTest("TODO(necula): re-land cl/496854361")
harness.run_test(self)

View File

@ -2226,7 +2226,7 @@ class LaxTest(jtu.JaxTestCase):
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
((10, 5,), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]],