mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #25409 from gnecula:poly_mod
PiperOrigin-RevId: 705537824
This commit is contained in:
commit
3c649b134a
@ -129,6 +129,7 @@ class _DimFactor:
|
||||
MOD = "mod"
|
||||
MAX = "max"
|
||||
MIN = "min"
|
||||
# TODO(necula): remove non_negative
|
||||
NON_NEGATIVE = "non_negative" # The max of the operand and 0. Replaced with
|
||||
# max but kept here for backwards compatibility.
|
||||
|
||||
@ -1090,6 +1091,24 @@ class SymbolicScope:
|
||||
raise NotImplementedError(
|
||||
f"Found multiple equality constraints with the same left-hand-side: {before}")
|
||||
self._normalization_rules[before] = (after, before_k)
|
||||
# Look for constraints of the form mod(before_e1, before_k2) * 1 == 0
|
||||
if (before_k == 1 and
|
||||
isinstance(constr.e2, int) and constr.e2 == 0 and
|
||||
(before_f := before.to_factor()) and
|
||||
before_f.operation == _DimFactor.MOD and
|
||||
(before_k2 := _DimExpr._to_constant(before_f.operands[1])) is not None):
|
||||
# Add before_k2*floordiv(before_e1, before_k2) == before_e1
|
||||
k_times_floordiv = _DimExpr._from_term(
|
||||
_DimTerm.from_operation(
|
||||
_DimFactor.FLOORDIV, *before_f.operands, scope=constr.e1.scope),
|
||||
before_k2, scope=constr.e1.scope)
|
||||
before_e1 = before_f.operands[0]
|
||||
self._process_explicit_constraint(
|
||||
_SymbolicConstraint(cmp=Comparator.EQ,
|
||||
e1=k_times_floordiv, e2=before_e1,
|
||||
diff=k_times_floordiv - before_e1,
|
||||
debug_str=f"{k_times_floordiv} == {before_e1}")
|
||||
)
|
||||
|
||||
self._explicit_constraints.append(constr)
|
||||
|
||||
|
@ -1099,20 +1099,37 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.assertEqual(exp.in_avals[0], exp.in_avals[1])
|
||||
|
||||
def test_constraints_eq_threefry(self):
|
||||
# Test equalities that arise out of the threefree lowering
|
||||
# Test equalities that arise out of the threefry lowering
|
||||
# x : i32[a] # a may be even or odd
|
||||
# x_padded: i32[a + a % 2] = jnp.concat([x, jnp.zeros((a % 2,))])
|
||||
# x_reshaped: i32[2, (a + a % 2) // 2] = x_padded.reshape((-1, 2))
|
||||
# x_1 = x_reshaped.reshape((-1,))
|
||||
# x_reshaped: i32[(a + a % 2) // 2, 2] = x_padded.reshape((-1, 2))
|
||||
# x_1: i32[a + a % 2] = x_reshaped.reshape((-1,))
|
||||
a, = shape_poly.symbolic_shape(
|
||||
"a",
|
||||
constraints=("mod(a + mod(a, 2), -2) == 0",
|
||||
"2*floordiv(mod(a, 2) + a, -2) == a"))
|
||||
constraints=("mod(a + mod(a, 2), -2) == 0",))
|
||||
|
||||
x_reshaped, r = divmod(a + a % 2, -2)
|
||||
self.assertEqual(r, 0)
|
||||
self.assertEqual(x_reshaped, (a + a % 2) // -2)
|
||||
self.assertEqual(2 * x_reshaped, a)
|
||||
self.assertEqual(- x_reshaped, -1 * ((a + a % 2) // -2))
|
||||
self.assertEqual(-2 * x_reshaped, a + a % 2)
|
||||
|
||||
def test_constraints_eq_mod_0(self):
|
||||
# mod(b, N) == 0 is a common constraint, we need to ensure we can use it
|
||||
# to infer things like: N * floordiv(b, N) == b, b >= N.
|
||||
b, c, d = shape_poly.symbolic_shape(
|
||||
"b, c, d",
|
||||
constraints=("mod(b, 4) == 0",))
|
||||
|
||||
# Inequalities work, because we use more expensive reasoning
|
||||
self.assertGreaterEqual(b, 4 * (b // 4))
|
||||
self.assertGreaterEqual(4 * (b // 4), b)
|
||||
# Equalities used to fail
|
||||
self.assertEqual(b, 4 * (b // 4))
|
||||
# And an equality that may come up in a reshape
|
||||
self.assertEqual(math.prod([b, c, d]), math.prod([b // 4, c, d, 2, 2]))
|
||||
|
||||
self.assertGreaterEqual(b, b // 4)
|
||||
self.assertGreaterEqual(b, 3 * (b // 4))
|
||||
|
||||
def test_constraints_eq_a_minus_4d(self):
|
||||
# simulates d = div(a, 4) and m = mod(a, 4)
|
||||
@ -1682,6 +1699,28 @@ class ShapePolyTest(jtu.JaxTestCase):
|
||||
"mod(a + mod(a, 2), -2) == 0",
|
||||
"-2*floordiv(a + mod(a, 2), -2) == a + mod(a, 2)"])
|
||||
|
||||
def test_constraints_eq_mod_0(self):
|
||||
# mod(b, N) == 0 is a common constraint, we need to ensure we can use it
|
||||
# to infer things like: N * floordiv(b, N) == b, b >= N.
|
||||
def f(x): # x: f32[b] and b % 4 == 0
|
||||
b = x.shape[0]
|
||||
y1 = jnp.ones((1, 3, 4, b // 4), dtype=x.dtype)
|
||||
y2 = y1.reshape((1, 3, -1)) # : f32[1, 3, b]
|
||||
y3 = x.reshape((1, 1, b)) + y2 # : f32[1, 3, b]
|
||||
|
||||
slice0 = lax.slice(x, (0,), (b // 4,)) # Requires b >= b // 4
|
||||
slice1 = lax.slice(x, (0,), (2 * (b // 4),)) # Requires b >= 2 * (b // 4)
|
||||
slice2 = lax.slice(x, (0,), (3 * (b // 4),)) # Requires b >= 2 * (b // 4)
|
||||
slice3 = lax.slice(x, (0,), (4 * (b // 4),)) # Requires b >= 2 * (b // 4)
|
||||
return (jnp.sum(y3) +
|
||||
jnp.sum(slice0) + jnp.sum(slice1) +
|
||||
jnp.sum(slice2) + jnp.sum(slice3))
|
||||
|
||||
check_shape_poly(self, f,
|
||||
arg_descriptors=[RandArg((16,), _i32)],
|
||||
polymorphic_shapes=["b"],
|
||||
symbolic_constraints=["mod(b, 4) == 0"])
|
||||
|
||||
def test_constraints_for_profile(self):
|
||||
# A somewhat more involved tests to stress test the correctness and
|
||||
# performance
|
||||
|
Loading…
x
Reference in New Issue
Block a user