Merge pull request #25409 from gnecula:poly_mod

PiperOrigin-RevId: 705537824
This commit is contained in:
jax authors 2024-12-12 09:50:28 -08:00
commit 3c649b134a
2 changed files with 65 additions and 7 deletions

View File

@ -129,6 +129,7 @@ class _DimFactor:
MOD = "mod"
MAX = "max"
MIN = "min"
# TODO(necula): remove non_negative
NON_NEGATIVE = "non_negative" # The max of the operand and 0. Replaced with
# max but kept here for backwards compatibility.
@ -1090,6 +1091,24 @@ class SymbolicScope:
raise NotImplementedError(
f"Found multiple equality constraints with the same left-hand-side: {before}")
self._normalization_rules[before] = (after, before_k)
# Look for constraints of the form mod(before_e1, before_k2) * 1 == 0
if (before_k == 1 and
isinstance(constr.e2, int) and constr.e2 == 0 and
(before_f := before.to_factor()) and
before_f.operation == _DimFactor.MOD and
(before_k2 := _DimExpr._to_constant(before_f.operands[1])) is not None):
# Add before_k2*floordiv(before_e1, before_k2) == before_e1
k_times_floordiv = _DimExpr._from_term(
_DimTerm.from_operation(
_DimFactor.FLOORDIV, *before_f.operands, scope=constr.e1.scope),
before_k2, scope=constr.e1.scope)
before_e1 = before_f.operands[0]
self._process_explicit_constraint(
_SymbolicConstraint(cmp=Comparator.EQ,
e1=k_times_floordiv, e2=before_e1,
diff=k_times_floordiv - before_e1,
debug_str=f"{k_times_floordiv} == {before_e1}")
)
self._explicit_constraints.append(constr)

View File

@ -1099,20 +1099,37 @@ class DimExprTest(jtu.JaxTestCase):
self.assertEqual(exp.in_avals[0], exp.in_avals[1])
def test_constraints_eq_threefry(self):
# Test equalities that arise out of the threefree lowering
# Test equalities that arise out of the threefry lowering
# x : i32[a] # a may be even or odd
# x_padded: i32[a + a % 2] = jnp.concat([x, jnp.zeros((a % 2,))])
# x_reshaped: i32[2, (a + a % 2) // 2] = x_padded.reshape((-1, 2))
# x_1 = x_reshaped.reshape((-1,))
# x_reshaped: i32[(a + a % 2) // 2, 2] = x_padded.reshape((-1, 2))
# x_1: i32[a + a % 2] = x_reshaped.reshape((-1,))
a, = shape_poly.symbolic_shape(
"a",
constraints=("mod(a + mod(a, 2), -2) == 0",
"2*floordiv(mod(a, 2) + a, -2) == a"))
constraints=("mod(a + mod(a, 2), -2) == 0",))
x_reshaped, r = divmod(a + a % 2, -2)
self.assertEqual(r, 0)
self.assertEqual(x_reshaped, (a + a % 2) // -2)
self.assertEqual(2 * x_reshaped, a)
self.assertEqual(- x_reshaped, -1 * ((a + a % 2) // -2))
self.assertEqual(-2 * x_reshaped, a + a % 2)
def test_constraints_eq_mod_0(self):
# mod(b, N) == 0 is a common constraint, we need to ensure we can use it
# to infer things like: N * floordiv(b, N) == b, b >= N.
b, c, d = shape_poly.symbolic_shape(
"b, c, d",
constraints=("mod(b, 4) == 0",))
# Inequalities work, because we use more expensive reasoning
self.assertGreaterEqual(b, 4 * (b // 4))
self.assertGreaterEqual(4 * (b // 4), b)
# Equalities used to fail
self.assertEqual(b, 4 * (b // 4))
# And an equality that may come up in a reshape
self.assertEqual(math.prod([b, c, d]), math.prod([b // 4, c, d, 2, 2]))
self.assertGreaterEqual(b, b // 4)
self.assertGreaterEqual(b, 3 * (b // 4))
def test_constraints_eq_a_minus_4d(self):
# simulates d = div(a, 4) and m = mod(a, 4)
@ -1682,6 +1699,28 @@ class ShapePolyTest(jtu.JaxTestCase):
"mod(a + mod(a, 2), -2) == 0",
"-2*floordiv(a + mod(a, 2), -2) == a + mod(a, 2)"])
def test_constraints_eq_mod_0(self):
# mod(b, N) == 0 is a common constraint, we need to ensure we can use it
# to infer things like: N * floordiv(b, N) == b, b >= N.
def f(x): # x: f32[b] and b % 4 == 0
b = x.shape[0]
y1 = jnp.ones((1, 3, 4, b // 4), dtype=x.dtype)
y2 = y1.reshape((1, 3, -1)) # : f32[1, 3, b]
y3 = x.reshape((1, 1, b)) + y2 # : f32[1, 3, b]
slice0 = lax.slice(x, (0,), (b // 4,)) # Requires b >= b // 4
slice1 = lax.slice(x, (0,), (2 * (b // 4),)) # Requires b >= 2 * (b // 4)
slice2 = lax.slice(x, (0,), (3 * (b // 4),)) # Requires b >= 2 * (b // 4)
slice3 = lax.slice(x, (0,), (4 * (b // 4),)) # Requires b >= 2 * (b // 4)
return (jnp.sum(y3) +
jnp.sum(slice0) + jnp.sum(slice1) +
jnp.sum(slice2) + jnp.sum(slice3))
check_shape_poly(self, f,
arg_descriptors=[RandArg((16,), _i32)],
polymorphic_shapes=["b"],
symbolic_constraints=["mod(b, 4) == 0"])
def test_constraints_for_profile(self):
# A somewhat more involved tests to stress test the correctness and
# performance