mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #19871 from gnecula:poly_stride_sym
PiperOrigin-RevId: 608566786
This commit is contained in:
commit
c69a5daca3
@ -772,7 +772,7 @@ def slice_in_dim(operand: Array | np.ndarray, start_index: int | None,
|
||||
axis = int(axis)
|
||||
start_indices[axis] = start_index_int
|
||||
limit_indices[axis] = limit_index_int
|
||||
strides[axis] = int(stride)
|
||||
strides[axis] = core._canonicalize_dimension(stride)
|
||||
|
||||
return slice(operand, start_indices, limit_indices, strides)
|
||||
|
||||
|
@ -414,12 +414,13 @@ class _DecisionByElimination:
|
||||
(op1_l, op1_u) = self.bounds(op1, BoundsPrecision.BEST)
|
||||
(op2_l, op2_u) = self.bounds(op2, BoundsPrecision.BEST)
|
||||
|
||||
def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
|
||||
# When either a or b are infinite, the results represent the limit
|
||||
def math_floor_with_inf(a: float, b: float):
|
||||
# math.floor(a / b), but aware of inf.
|
||||
# When either a or b are infinite, the result represents the limit
|
||||
# of "a // b".
|
||||
assert b != 0
|
||||
assert b != 0 # we caught division by 0 earlier
|
||||
if not np.isinf(b): # divisor b is finite
|
||||
if not np.isinf(a):
|
||||
if not np.isinf(a): # both dividend a and divisor b are finite
|
||||
return math.floor(a / b)
|
||||
# a is infinite, b is finite
|
||||
return -np.inf if (a >= 0) != (b >= 0) else np.inf
|
||||
@ -430,6 +431,9 @@ class _DecisionByElimination:
|
||||
|
||||
# Same reasoning as for multiplication: the bounds are among the cross-product
|
||||
# of the bounds.
|
||||
if op2_l <= 0 <= op2_u:
|
||||
raise InconclusiveDimensionOperation(
|
||||
f"Possible division by 0 in division by {op2}")
|
||||
candidate_bounds = [math_floor_with_inf(op1_l, op2_l),
|
||||
math_floor_with_inf(op1_l, op2_u),
|
||||
math_floor_with_inf(op1_u, op2_l),
|
||||
|
@ -457,7 +457,7 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
# the "2*ceil(b / 2)".
|
||||
self.assertGreaterEqual(-2 * ((- b) // 2), b)
|
||||
|
||||
def poly_bounds_div(self):
|
||||
def test_bounds_floordiv(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
self.assertEqual(_bounds((a + 4) // 2), (2, np.inf))
|
||||
self.assertEqual(_bounds((a + 4) // -2), (-np.inf, -3))
|
||||
@ -471,9 +471,11 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertEqual(_bounds(a - a // 2), (1, np.inf))
|
||||
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 1))
|
||||
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 0))
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
"Possible division by 0"):
|
||||
_bounds(a // (a - 3))
|
||||
|
||||
def test_bounds_div_generated(self):
|
||||
def test_bounds_floordiv_against_concrete_evaluation(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
# Generate test cases for floordiv and mod: (a + N) // +-2, (N - a) // +-2
|
||||
# and then evaluate them for a = 1, 5, 10000
|
||||
@ -2986,11 +2988,10 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=2, axis=0),
|
||||
arg_descriptors=[RandArg((13, 4), _f32)],
|
||||
polymorphic_shapes=["b, ..."]),
|
||||
# TODO: Not yet, the slice_in_dim does int(stride)
|
||||
# PolyHarness("slice_in_dim", "stride=sym",
|
||||
# lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=x.shape[0] // 4, axis=0),
|
||||
# arg_descriptors=[RandArg((13, 4), _f32)],
|
||||
# polymorphic_shapes=["b, ..."]),
|
||||
PolyHarness("slice_in_dim", "stride_sym",
|
||||
lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=1 + x.shape[0] // 4, axis=0),
|
||||
arg_descriptors=[RandArg((13, 4), _f32)],
|
||||
polymorphic_shapes=["b, ..."]),
|
||||
PolyHarness("jnp_split", "idx_tuple_ct",
|
||||
# The indices are a tuple with constants
|
||||
lambda a: jnp.split(a, (2,)),
|
||||
|
Loading…
x
Reference in New Issue
Block a user