Merge pull request #19871 from gnecula:poly_stride_sym

PiperOrigin-RevId: 608566786
This commit is contained in:
jax authors 2024-02-20 05:16:47 -08:00
commit c69a5daca3
3 changed files with 18 additions and 13 deletions

View File

@ -772,7 +772,7 @@ def slice_in_dim(operand: Array | np.ndarray, start_index: int | None,
axis = int(axis)
start_indices[axis] = start_index_int
limit_indices[axis] = limit_index_int
strides[axis] = int(stride)
strides[axis] = core._canonicalize_dimension(stride)
return slice(operand, start_indices, limit_indices, strides)

View File

@ -414,12 +414,13 @@ class _DecisionByElimination:
(op1_l, op1_u) = self.bounds(op1, BoundsPrecision.BEST)
(op2_l, op2_u) = self.bounds(op2, BoundsPrecision.BEST)
def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
# When either a or b are infinite, the results represent the limit
def math_floor_with_inf(a: float, b: float):
# math.floor(a / b), but aware of inf.
# When either a or b are infinite, the result represents the limit
# of "a // b".
assert b != 0
assert b != 0 # we caught division by 0 earlier
if not np.isinf(b): # divisor b is finite
if not np.isinf(a):
if not np.isinf(a): # both dividend a and divisor b are finite
return math.floor(a / b)
# a is infinite, b is finite
return -np.inf if (a >= 0) != (b >= 0) else np.inf
@ -430,6 +431,9 @@ class _DecisionByElimination:
# Same reasoning as for multiplication: the bounds are among the cross-product
# of the bounds.
if op2_l <= 0 <= op2_u:
raise InconclusiveDimensionOperation(
f"Possible division by 0 in division by {op2}")
candidate_bounds = [math_floor_with_inf(op1_l, op2_l),
math_floor_with_inf(op1_l, op2_u),
math_floor_with_inf(op1_u, op2_l),

View File

@ -457,7 +457,7 @@ class DimExprTest(jtu.JaxTestCase):
# the "2*ceil(b / 2)".
self.assertGreaterEqual(-2 * ((- b) // 2), b)
def poly_bounds_div(self):
def test_bounds_floordiv(self):
a, b = shape_poly.symbolic_shape("a, b")
self.assertEqual(_bounds((a + 4) // 2), (2, np.inf))
self.assertEqual(_bounds((a + 4) // -2), (-np.inf, -3))
@ -471,9 +471,11 @@ class DimExprTest(jtu.JaxTestCase):
self.assertEqual(_bounds(a - a // 2), (1, np.inf))
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 1))
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 0))
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
"Possible division by 0"):
_bounds(a // (a - 3))
def test_bounds_div_generated(self):
def test_bounds_floordiv_against_concrete_evaluation(self):
a, b = shape_poly.symbolic_shape("a, b")
# Generate test cases for floordiv and mod: (a + N) // +-2, (N - a) // +-2
# and then evaluate them for a = 1, 5, 10000
@ -2986,11 +2988,10 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=2, axis=0),
arg_descriptors=[RandArg((13, 4), _f32)],
polymorphic_shapes=["b, ..."]),
# TODO: Not yet, the slice_in_dim does int(stride)
# PolyHarness("slice_in_dim", "stride=sym",
# lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=x.shape[0] // 4, axis=0),
# arg_descriptors=[RandArg((13, 4), _f32)],
# polymorphic_shapes=["b, ..."]),
PolyHarness("slice_in_dim", "stride_sym",
lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=1 + x.shape[0] // 4, axis=0),
arg_descriptors=[RandArg((13, 4), _f32)],
polymorphic_shapes=["b, ..."]),
PolyHarness("jnp_split", "idx_tuple_ct",
# The indices are a tuple with constants
lambda a: jnp.split(a, (2,)),