Merge pull request #19492 from gnecula:poly_tests

PiperOrigin-RevId: 601050430
This commit is contained in:
jax authors 2024-01-24 02:04:25 -08:00
commit a74b04a43f

View File

@ -17,9 +17,11 @@ from __future__ import annotations
import enum
from collections.abc import Sequence
import cProfile
import itertools
import math
import os
from pstats import Stats
from typing import Any, Callable
import unittest
@ -64,10 +66,54 @@ expect_error_associative_scan = (
"associative scan over axis of non-constant size",
)
# We want to have a complete test suite, even if the decision procedure is not
# yet as tight as we want. We will write the expected values as
# _expect(best=..., current=...) and we use the `current` value in the
# test, but we aspire to a decision procedure where we could compute `best`.
def _expect(*, current, best):
return current
def _bounds(e: shape_poly.DimSize) -> tuple[float, float]:
if not isinstance(e, shape_poly._DimExpr):
e = shape_poly._ensure_poly(e, "_bounds", shape_poly.SymbolicScope())
return e.bounds()
def _assert_equal_bounds(tst: jtu.JaxTestCase,
e: shape_poly.DimSize,
bounds: tuple[float, float]):
if isinstance(e, shape_poly._DimExpr):
scope = e.scope
else:
scope = shape_poly.SymbolicScope()
decision = shape_poly._make_decision_state(scope)
found_bounds = decision.bounds(e)
tst.assertEqual(bounds, found_bounds)
def _start_profile(tst: jtu.JaxTestCase):
tst.prof = None
if os.getenv("JAX_PROFILE_TEST", False):
tst.prof = cProfile.Profile()
tst.prof.enable()
def _stop_profile(tst: jtu.JaxTestCase):
if tst.prof is not None:
p = Stats(tst.prof)
p.strip_dirs()
p.sort_stats("cumtime").print_stats(.2)
p.print_callers(.2)
@jtu.with_config(jax_enable_key_reuse_checks=False)
class DimExprTest(jtu.JaxTestCase):
def setUp(self):
_start_profile(self)
super().setUp()
def tearDown(self):
super().tearDown()
_stop_profile(self)
class AssertionType(enum.Enum):
EQ = 1,
GEQ = 2
@ -143,7 +189,7 @@ class DimExprTest(jtu.JaxTestCase):
("3 * floordiv(a + 2, b + 2) * 2", 3 * ((a + 2) // (b + 2)) * 2),
# Keep for backwards compatibility. We ought to be able to parse
# non_negative
("non_negative(a - 2)", core.max_dim(a - 2, 0)),
("non_negative(a - 2)", "build_inside"),
("max(a, b)", "build_inside"),
("min(a, b)", "build_inside"),
]])
@ -316,58 +362,70 @@ class DimExprTest(jtu.JaxTestCase):
shape_poly._DimMon.from_operation(shape_poly._DimAtom.NON_NEGATIVE,
a - 2 * b - 1,
scope=a.scope))
def test_poly_bounds_arithmetic(self):
a, b = shape_poly.symbolic_shape("a, b")
def test_bounds_arithmetic(self):
a, b, c = shape_poly.symbolic_shape("a, b, c")
bounded_le4 = 5 - a
bounded_ge2 = b + 1
bounded_ge0_le4 = a % 5
self.assertEqual(a.bounds(), (1, np.inf))
self.assertEqual((- a).bounds(), (-np.inf, -1))
self.assertEqual(bounded_le4.bounds(), (-np.inf, 4))
self.assertEqual(bounded_ge2.bounds(), (2, np.inf))
self.assertEqual(bounded_ge0_le4.bounds(), (0, 4))
self.assertEqual(_bounds(a), (1, np.inf))
self.assertEqual(_bounds(- a), (-np.inf, -1))
self.assertEqual(_bounds(bounded_le4), (-np.inf, 4))
self.assertEqual(_bounds(bounded_ge2), (2, np.inf))
self.assertEqual(_bounds(bounded_ge0_le4), (0, 4))
self.assertEqual(_bounds(b + 1 - b), (1, 1))
self.assertEqual(_bounds(a), (1, np.inf))
self.assertEqual(_bounds(a - 1), (0, np.inf))
self.assertEqual(_bounds(5*a - 2), (3, np.inf))
self.assertEqual(_bounds(-5*a - 2), (-np.inf, -7))
self.assertEqual(_bounds(2*a + 3*b + 5*c - 1), (9, np.inf))
self.assertEqual(_bounds(2*a - 3*b + 5*c - 1), (-np.inf, np.inf))
self.assertEqual(_bounds(-2*a + -3*b + -5*c + 20), (-np.inf, 10))
# Additions
self.assertEqual((bounded_ge0_le4 + bounded_le4).bounds(), (-np.inf, 8))
self.assertEqual((bounded_ge0_le4 + bounded_ge2).bounds(), (2, np.inf))
self.assertEqual((bounded_le4 + bounded_ge2).bounds(), (-np.inf, np.inf))
self.assertEqual(_bounds(bounded_ge0_le4 + bounded_le4), (-np.inf, 8))
self.assertEqual(_bounds(bounded_ge0_le4 + bounded_ge2), (2, np.inf))
self.assertEqual(_bounds(bounded_le4 + bounded_ge2), (-np.inf, np.inf))
# Subtractions
self.assertEqual((bounded_ge0_le4 - bounded_le4).bounds(), (-4, np.inf))
self.assertEqual((- bounded_ge0_le4 + bounded_le4).bounds(), (-np.inf, 4))
self.assertEqual((bounded_ge0_le4 - bounded_ge2).bounds(), (-np.inf, 2))
self.assertEqual((- bounded_ge0_le4 + bounded_ge2).bounds(), (-2, np.inf))
self.assertEqual((bounded_le4 - bounded_ge2).bounds(), (-np.inf, 2))
self.assertEqual((- bounded_le4 + bounded_ge2).bounds(), (-2, np.inf))
self.assertEqual(_bounds(bounded_ge0_le4 - bounded_le4), (-4, np.inf))
self.assertEqual(_bounds(- bounded_ge0_le4 + bounded_le4), (-np.inf, 4))
self.assertEqual(_bounds(bounded_ge0_le4 - bounded_ge2), (-np.inf, 2))
self.assertEqual(_bounds(- bounded_ge0_le4 + bounded_ge2), (-2, np.inf))
self.assertEqual(_bounds(bounded_le4 - bounded_ge2), (-np.inf, 2))
self.assertEqual(_bounds(- bounded_le4 + bounded_ge2), (-2, np.inf))
# Multiplications
self.assertEqual((2 * a - 3).bounds(), (-1, np.inf))
self.assertEqual((-2 * a - 3).bounds(), (-np.inf, -5))
self.assertEqual((3 * a * b * b + 5 * a - 7).bounds(), (1, np.inf))
self.assertEqual((3 * a * b * b - 5 * a - 7).bounds(), (-np.inf, np.inf))
self.assertEqual((a + b - a * b + a * b * a).bounds(), (-np.inf, np.inf))
self.assertEqual((a + 2 * b - a).bounds(), (2, np.inf))
self.assertEqual((a + 2 * b - a).bounds(), (2, np.inf))
self.assertEqual(_bounds(2 * a - 3), (-1, np.inf))
self.assertEqual(_bounds(-2 * a - 3), (-np.inf, -5))
self.assertEqual(_bounds(3 * a * b * b + 5 * a - 7), (1, np.inf))
self.assertEqual(_bounds(3 * a * b * b - 5 * a - 7), (-np.inf, np.inf))
self.assertEqual(_bounds(a + b - a * b + a * b * a), (-np.inf, np.inf))
self.assertEqual(_bounds(a + 2 * b - a), (2, np.inf))
self.assertEqual(_bounds(a + 2 * b - a), (2, np.inf))
def test_poly_bounds_mod(self):
# Higher order polynomial
self.assertEqual(_bounds(a*a + b - 2), (0, np.inf))
self.assertEqual(_bounds(-2*a*b - b - 2), (-np.inf, -5))
def test_bounds_mod(self):
a, b = shape_poly.symbolic_shape("a, b")
self.assertEqual((5 - a % 5).bounds(), (1, 5))
self.assertEqual((-5 - a % (-5)).bounds(), (-5, -1))
self.assertEqual((a - 5 % a).bounds(), (1, np.inf))
self.assertEqual((a - 5 % a).bounds(), (1, np.inf))
self.assertEqual((3 * (a + b) - 5 % (3 * (a + b))).bounds(), (1, np.inf))
self.assertEqual((- a + (b - 5) % a).bounds(), (-np.inf, -1))
# self.assertEqual(_bounds(5 - a % 5), (1, 5))
self.assertEqual(_bounds(-5 - a % (-5)), (-5, -1))
self.assertEqual(_bounds(a - 5 % a), (1, np.inf))
self.assertEqual(_bounds(a - 5 % a), (1, np.inf))
self.assertEqual(_bounds(3 * (a + b) - 5 % (3 * (a + b))), (1, np.inf))
self.assertEqual(_bounds(- a + (b - 5) % a), (-np.inf, -1))
# mod
self.assertEqual(((b + 1) % 2).bounds(), (0, 1))
self.assertEqual(((b + 1) % -2).bounds(), (-1, 0))
self.assertEqual(((b - 4) % 2).bounds(), (0, 1))
self.assertEqual(((b + 1) % a).bounds(), (0, np.inf))
self.assertEqual((11 % (a + 1)).bounds(), (0, np.inf))
self.assertEqual((-11 % (a + 1)).bounds(), (0, np.inf))
self.assertEqual((b % (a - 2)).bounds(), (-np.inf, np.inf))
self.assertEqual(_bounds((b + 1) % 2), (0, 1))
self.assertEqual(_bounds((b + 1) % -2), (-1, 0))
self.assertEqual(_bounds((b - 4) % 2), (0, 1))
self.assertEqual(_bounds((b + 1) % a), (0, np.inf))
self.assertEqual(_bounds(11 % (a + 1)), (0, np.inf))
self.assertEqual(_bounds(-11 % (a + 1)), (0, np.inf))
self.assertEqual(_bounds(b % (a - 2)), (-np.inf, np.inf))
# This arises in convolutions, because we use "-2 * div(-b, 2)" to get
# the "2*ceil(b / 2)".
@ -375,17 +433,21 @@ class DimExprTest(jtu.JaxTestCase):
def poly_bounds_div(self):
a, b = shape_poly.symbolic_shape("a, b")
self.assertEqual(((a + 4) // 2).bounds(), (2, np.inf))
self.assertEqual(((a + 4) // -2).bounds(), (-np.inf, -3))
self.assertEqual(((a + 5) // 2).bounds(), (3, np.inf))
self.assertEqual(((a + 5) // -2).bounds(), (-np.inf, -3))
self.assertEqual((11 // (a + 1)).bounds(), (0, 5))
self.assertEqual((-11 // (a + 1)).bounds(), (-6, -1))
self.assertEqual((-11 // (- a)).bounds(), (0, 11)) # finite negative dividend, infinite divisor
self.assertEqual(((b + 1) // (a + 1)).bounds(), (0, np.inf))
self.assertEqual((-b // (a + 1)).bounds(), (-np.inf, -1))
self.assertEqual(_bounds((a + 4) // 2), (2, np.inf))
self.assertEqual(_bounds((a + 4) // -2), (-np.inf, -3))
self.assertEqual(_bounds((a + 5) // 2), (3, np.inf))
self.assertEqual(_bounds((a + 5) // -2), (-np.inf, -3))
self.assertEqual(_bounds(11 // (a + 1)), (0, 5))
self.assertEqual(_bounds(-11 // (a + 1)), (-6, -1))
self.assertEqual(_bounds(-11 // (- a)), (0, 11)) # finite negative dividend, infinite divisor
self.assertEqual(_bounds((b + 1) // (a + 1)), (0, np.inf))
self.assertEqual(_bounds(-b // (a + 1)), (-np.inf, -1))
def test_poly_bounds_div_generated(self):
self.assertEqual(_bounds(a - a // 2), (1, np.inf))
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 1))
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 0))
def test_bounds_div_generated(self):
a, b = shape_poly.symbolic_shape("a, b")
# Generate test cases for floordiv and mod: (a + N) // +-2, (N - a) // +-2
# and then evaluate them for a = 1, 5, 10000
@ -397,53 +459,32 @@ class DimExprTest(jtu.JaxTestCase):
for operation in (op.floordiv, op.mod)
]
for atom in div_mod_atoms:
lb, ub = atom.bounds()
lb, ub = _bounds(atom)
self.assertLessEqual(lb, ub)
for a_val in (1, 5, 10000):
atom_val = atom.evaluate(dict(a=a_val))
self.assertGreaterEqual(atom_val, lb)
self.assertLessEqual(atom_val, ub)
def test_poly_bounds_non_negative(self):
def test_bounds_non_negative(self):
a, b = shape_poly.symbolic_shape("a, b")
self.assertEqual(core.non_negative_dim(a).bounds(), (1, np.inf))
self.assertEqual(core.non_negative_dim(a - 5).bounds(), (0, np.inf))
self.assertEqual(core.non_negative_dim(15 - a).bounds(), (0, 14))
self.assertEqual((core.non_negative_dim(15 - a) // 3).bounds(), (0, 4))
self.assertEqual(_bounds(core.non_negative_dim(a)), (1, np.inf))
self.assertEqual(_bounds(core.non_negative_dim(a - 5)), (0, np.inf))
self.assertEqual(_bounds(core.non_negative_dim(15 - a)), (0, 14))
self.assertEqual(_bounds(core.non_negative_dim(15 - a) // 3), (0, 4))
self.assertEqual(_bounds(a - core.non_negative_dim(a - 3)),
_expect(best=(1, 3), current=(-np.inf, np.inf)))
def test_min_dim(self):
a, b, c = shape_poly.symbolic_shape("a, b, c")
self.assertEqual(core.min_dim(a, b).bounds(), (1, np.inf))
self.assertEqual(core.min_dim(2, b).bounds(), (1, 2))
self.assertEqual(core.min_dim(a, -2), -2)
self.assertEqual(core.min_dim(a - 5, 1).bounds(), (-4, 1))
self.assertEqual(core.min_dim(15 - a, 10).bounds(), (-np.inf, 10))
self.assertEqual(core.min_dim(15 - a, 20).bounds(), (-np.inf, 14))
self.assertEqual(a, core.min_dim(a, a + 2))
self.assertEqual(a - 2, core.min_dim(a, a - 2))
self.assertEqual(core.min_dim(a % 2 - 1, -1), -1)
self.assertEqual(core.min_dim(-1, a % 2 - 1), -1)
self.assertGreaterEqual(a, core.min_dim(a, b))
self.assertGreaterEqual(a + c - 1, core.min_dim(a, b))
self.assertGreaterEqual(b, core.min_dim(a, b))
self.assertGreaterEqual(b + c - 1, core.min_dim(a, b))
self.sampled_assertion(core.min_dim(a, 5),
core.min_dim, a, 5)
self.sampled_assertion(core.min_dim(5, a),
core.min_dim, 5, a)
def test_max_dim(self):
a, b, c = shape_poly.symbolic_shape("a, b, c")
a, b, c, d = shape_poly.symbolic_shape("a, b, c, d")
self.assertEqual(core.max_dim(a, b).bounds(), (1, np.inf))
self.assertEqual(core.max_dim(2, b).bounds(), (2, np.inf))
self.assertEqual(core.max_dim(a, 2).bounds(), (2, np.inf))
self.assertEqual(core.max_dim(a - 5, 1).bounds(), (1, np.inf))
self.assertEqual(core.max_dim(15 - a, 0).bounds(), (0, 14))
self.assertEqual((core.max_dim(15 - a, 0) // 3).bounds(), (0, 4))
self.assertEqual(_bounds(core.max_dim(a, b)), (1, np.inf))
self.assertEqual(_bounds(core.max_dim(2, b)), (2, np.inf))
self.assertEqual(_bounds(core.max_dim(a, 2)), (2, np.inf))
self.assertEqual(_bounds(core.max_dim(a - 5, 1)), (1, np.inf))
self.assertEqual(_bounds(core.max_dim(15 - a, 0)), (0, 14))
self.assertEqual(_bounds(core.max_dim(15 - a, 0) // 3), (0, 4))
self.assertEqual(a + 2, core.max_dim(a, a + 2))
self.assertEqual(a , core.max_dim(a, a - 2))
@ -455,10 +496,55 @@ class DimExprTest(jtu.JaxTestCase):
self.assertGreaterEqual(core.max_dim(a, b) + c - 1, b)
self.assertGreaterEqual(core.max_dim(a, b), core.min_dim(a, b))
self.sampled_assertion(core.max_dim(a, 5),
core.max_dim, a, 5)
self.sampled_assertion(core.max_dim(5, a),
core.max_dim, 5, a)
self.assertEqual(_bounds(b - core.max_dim(b - a, 0)),
_expect(best=(1, np.inf), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(a - core.min_dim(a, b)),
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(b - core.min_dim(a, b)),
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(core.max_dim(a, b) - a),
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(core.max_dim(a, b) - b),
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
self.assertEqual((0, 0), _bounds(core.max_dim(1 - b, 0)))
self.assertEqual(_bounds(core.max_dim(a, b) - a - core.max_dim(b - a, 0)),
_expect(best=(0, 0), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(core.max_dim(a, b) + core.max_dim(c, d) -
core.max_dim(a + core.max_dim(c, d),
b + core.max_dim(c, d))),
_expect(best=(0, 0), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(core.max_dim(a, b) + core.min_dim(c, d) -
core.max_dim(a + core.min_dim(c, d),
b + core.min_dim(c, d))),
_expect(best=(0, 0), current=(-np.inf, np.inf)))
self.sampled_assertion(core.max_dim(a, 5), core.max_dim, a, 5)
self.sampled_assertion(core.max_dim(5, a), core.max_dim, 5, a)
def test_min_dim(self):
a, b, c = shape_poly.symbolic_shape("a, b, c")
self.assertEqual(_bounds(core.min_dim(a, b)), (1, np.inf))
self.assertEqual(_bounds(core.min_dim(2, b)), (1, 2))
self.assertEqual(core.min_dim(a, -2), -2)
self.assertEqual(_bounds(core.min_dim(a - 5, 1)), (-4, 1))
self.assertEqual(_bounds(core.min_dim(15 - a, 10)), (-np.inf, 10))
self.assertEqual(_bounds(core.min_dim(15 - a, 20)), (-np.inf, 14))
# Test simplification during construction
self.assertEqual(a, core.min_dim(a, a + 2))
self.assertEqual(a - 2, core.min_dim(a, a - 2))
self.assertEqual(core.min_dim(a % 2 - 1, -1), -1)
self.assertEqual(core.min_dim(-1, a % 2 - 1), -1)
self.assertGreaterEqual(a, core.min_dim(a, b))
self.assertGreaterEqual(a + c - 1, core.min_dim(a, b))
self.assertGreaterEqual(b, core.min_dim(a, b))
self.assertGreaterEqual(b + c - 1, core.min_dim(a, b))
self.sampled_assertion(core.min_dim(a, 5), core.min_dim, a, 5)
self.sampled_assertion(core.min_dim(5, a), core.min_dim, 5, a)
def test_clamp_dim(self):
a, b = shape_poly.symbolic_shape("a, b")
@ -467,13 +553,21 @@ class DimExprTest(jtu.JaxTestCase):
self.assertLessEqual(b, clamp)
self.assertLessEqual(clamp, b + 10)
def test_poly_bounds_complex(self):
def test_bounds_complex(self):
a, b = shape_poly.symbolic_shape("a, b")
min_a_b = b - core.non_negative_dim(b - a)
# This comes up in slicing with stride
self.assertGreaterEqual(min_a_b // 2, 0)
def test_poly_equal(self):
# Comes up in slice[-2:-2b-2:-2]
self.assertGreaterEqual((-1 * core.max_dim(-1, a - 2 * b - 2) + core.max_dim(-1, a - 2) + 1),
0)
self.assertGreaterEqual((-1 * core.max_dim(-1, a - 2 * b - 2) + core.max_dim(-1, a - 2) + 1) // 2,
0)
self.assertEqual((0, 0),
_bounds(core.max_dim(-1*b + 1, 0) // 2))
def test_equal(self):
a, b = shape_poly.symbolic_shape("a, b")
poly3 = a + 3 - a
self.assertEqual(poly3, 3)
@ -518,30 +612,9 @@ class DimExprTest(jtu.JaxTestCase):
# self.sampled_assertion((a // 2) * 6,
# lambda x: x, 3 * a - 3 * (a % 2))
def test_poly_compare(self):
def test_compare_ge(self):
a, b = shape_poly.symbolic_shape("a, b")
poly = 4 * a + b + 3
self.assertTrue(poly.ge(0))
self.assertTrue(poly.ge(8))
self.assertTrue(poly.ge(poly))
self.assertTrue(poly.ge(poly - 1))
with self.assertRaisesRegex(
core.InconclusiveDimensionOperation,
re.escape("comparison 'b + 4*a + 3' >= '9' is inconclusive")):
poly.ge(9)
with self.assertRaisesRegex(
core.InconclusiveDimensionOperation,
"comparison the_comp is inconclusive"):
poly.ge(9, lambda: "the_comp")
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
"inconclusive"):
(4 * a - b).ge(0)
def test_poly_compare_overload(self):
a, b = shape_poly.symbolic_shape("a, b")
self.assertGreaterEqual(a, a)
self.assertGreaterEqual(a, 0)
self.assertGreaterEqual(a, 1)
@ -549,25 +622,33 @@ class DimExprTest(jtu.JaxTestCase):
with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
a >= 2
poly = 4 * a + b + 3
self.assertGreaterEqual(poly, 0)
self.assertGreaterEqual(poly, 8)
self.assertGreater(poly, 7)
self.assertGreaterEqual(poly, poly)
self.assertGreaterEqual(poly, poly - 1)
with self.assertRaisesRegex(
core.InconclusiveDimensionOperation,
re.escape("comparison 'b + 4*a + 3' >= '9' is inconclusive")):
poly >= 9
with self.assertRaisesRegex(
core.InconclusiveDimensionOperation,
re.escape("comparison 'b + 4*a + 3' > '9' is inconclusive")):
poly > 9
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
"inconclusive"):
(4 * a - b) >= 0
# LHS is an integer
self.assertLessEqual(8, poly)
self.assertLess(7, poly)
self.assertGreaterEqual(-8, -poly)
self.assertGreater(-7, -poly)
with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
poly >= 9
with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
(4 * a - b) >= 0
def test_poly_int_results(self):
def test_int_results(self):
# Whenever the result is an integer, it should be represented as a
# Python integer, not a symbolic dimension.
a, b = shape_poly.symbolic_shape("a, b")
@ -597,7 +678,7 @@ class DimExprTest(jtu.JaxTestCase):
(2 * a * b + b * b, a + b, "floordiv(2*a*b + b^2, b + a)", "mod(2*a*b + b^2, b + a)"),
(3, a, "floordiv(3, a)", "mod(3, a)"),
]])
def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
def test_divmod(self, *, dividend, quotient, divisor, remainder):
if isinstance(quotient, str):
d1, d2 = divmod(dividend, divisor)
self.assertEqual((quotient, remainder), (str(d1), str(d2)))
@ -650,8 +731,8 @@ class DimExprTest(jtu.JaxTestCase):
def test_constraints_basic(self):
a, b = shape_poly.symbolic_shape("a, b",
constraints=("a >= 5", "b <= 16"))
self.assertEqual(a.bounds(), (5, np.inf))
self.assertEqual(b.bounds(), (1, 16))
self.assertEqual(_bounds(a), (5, np.inf))
self.assertEqual(_bounds(b), (1, 16))
def test_constraints_trivial(self):
a, = shape_poly.symbolic_shape("a",
@ -665,9 +746,22 @@ class DimExprTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError,
"Unsatisfiable.*a <= 0"):
_ = shape_poly.symbolic_shape("a, b",
# Contradicts the default a >= 1
constraints=("a <= 0",))
a, b = shape_poly.symbolic_shape("a, b",
# Contradicts the default a >= 1
constraints=("a <= 0",))
a >= b
def test_constraints_a_minus_4d(self):
# simulates d = div(a, 4) and m = mod(a, 4)
assumptions = ["a >= 4*d + m ",
"a <= 4*d + m",
"m >= 0", "m <= 3"]
scope = shape_poly.SymbolicScope(assumptions)
a, d = shape_poly.symbolic_shape("a, d", scope=scope)
self.assertEqual(_bounds(a - 4*d),
_expect(best=(1, 3), current=(-np.inf, np.inf))) # a - 4d = m >= 1
self.assertEqual(_bounds(a - 2*d),
_expect(best=(3, np.inf), current=(-np.inf, np.inf)))
def test_constraints_errors(self):
with self.assertRaisesRegex(ValueError,
@ -708,12 +802,50 @@ class DimExprTest(jtu.JaxTestCase):
constraints=("a >= b",))
self.assertGreaterEqual(a, b)
def test_constraints_complex(self):
a, b, c = shape_poly.symbolic_shape(
"a, b, c",
constraints=("a + 2 <= b", "b <= a + 5", "a + b >= c"))
self.assertEqual(_bounds(b),
_expect(best=(3, np.inf), current=(1, np.inf)))
self.assertEqual(_bounds(b - a),
_expect(best=(2, 5), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(b - a - 7),
_expect(best=(-5, -2), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(c - 2*a - 5),
_expect(best=(-np.inf, 0), current=(-np.inf, np.inf)))
def test_constraints_fractional(self):
a, = shape_poly.symbolic_shape("a",
constraints=("2 * a >= 5", "3 * a <= 10",))
self.assertEqual(_bounds(5*a - 2), (13, 13))
@jtu.parameterized_filterable(
kwargs=[
dict(constraint="2*a + 3*b >= 10", exp="a + b - 2",
bounds=_expect(best=(2, np.inf), current=(0, np.inf))),
dict(constraint="-2*a + 3*b >= 10", exp="a + 2*b",
bounds=_expect(best=(9, np.inf), current=(3, np.inf))),
dict(constraint="-2*a + -3*b >= -10", exp="-1*a + 2*b",
bounds=_expect(best=(-1, 3), current=(-np.inf, np.inf))),
dict(constraint="2*a + -3*b >= 10", exp="-1*a + 2*b", bounds=(-np.inf, np.inf)),
]
)
def test_constraints_complex_gen(self,
constraint: str, exp: str,
bounds: tuple[float, float]):
a, b, exp = shape_poly.symbolic_shape(
"a, b, " + exp,
constraints=(constraint,))
self.assertEqual(bounds, _bounds(exp))
def test_constraints_override(self):
# Some constaints override other
a, b = shape_poly.symbolic_shape("a, b",
constraints=("a >= 5", "b <= 16",
"a >= 10", "b <= 10"))
self.assertEqual(a.bounds(), (10, np.inf))
self.assertEqual(b.bounds(), (1, 10))
self.assertEqual(_bounds(a), (10, np.inf))
self.assertEqual(_bounds(b), (1, 10))
def test_constraints_error_msg(self):
a, b = shape_poly.symbolic_shape("a, b",
@ -728,7 +860,83 @@ class DimExprTest(jtu.JaxTestCase):
a, b = shape_poly.symbolic_shape("a, b",
constraints=("a >= 5", "a <= 10"))
self.assertIs(a.scope, b.scope)
self.assertEqual(a.bounds(), (5, 10))
self.assertEqual(_bounds(a), (5, 10))
def test_constraints_seq(self):
a1, a2, a3, a4, a5 = shape_poly.symbolic_shape(
"a1, a2, a3, a4, a5",
constraints=(
"a1 >= a2",
"a2 >= a3",
"a3 >= a4",
"a4 >= a5",
)
)
self.assertEqual(_bounds(a1 - a5),
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
def test_constraints_rounding_monomials(self):
a1, a2 = shape_poly.symbolic_shape(
"a1, a2",
constraints=(
# a1 >= 2 and a1 <= 2
"2 * a1 >= 3", "-2 * a1 >= -5"
)
)
self.assertEqual(_bounds(a1), (2, 2))
def test_constraints_rounding_not_monomials(self):
a1, a2 = shape_poly.symbolic_shape(
"a1, a2",
constraints=(
# a1 >= a2 + 2 and a1 <= a2 + 2
"2*a1 >= 2*a2 + 3", "-2*a1 + 2*a2 >= -5"
)
)
self.assertEqual(_bounds(a1 - a2 - 2),
_expect(best=(0, 0), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(a2 - a1 + 2),
_expect(best=(0, 0), current=(-np.inf, np.inf)))
def test_constraints_unsat_trivial(self):
with self.assertRaisesRegex(ValueError,
r"Unsatisfiable explicit constraint: a1 >= a1 \+ 1"):
_ = shape_poly.symbolic_shape(
"a1", constraints=("a1 >= a1 + 1",))
def test_constraints_unsat_monomials(self):
with self.assertRaisesRegex(_expect(best=ValueError,
current=shape_poly.InconclusiveDimensionOperation),
_expect(best="Unsatisfiable constraint",
current="inconclusive")):
a1, a2, *_ = shape_poly.symbolic_shape(
"a1, a2, a3, a4",
constraints=(
# The following -> a1 >= 5
"a1 >= a3 + 1", "a3 >= 4",
# The following -> a1 <= 4
"a1 <= a4 + 2", "a4 <= 2"))
a1 >= a2
def test_constraints_unsat_not_monomials(self):
with self.assertRaisesRegex(_expect(best=ValueError,
current=shape_poly.InconclusiveDimensionOperation),
_expect(best="Unsatisfiable constraint",
current="inconclusive")):
a1, a2, a3, a4 = shape_poly.symbolic_shape(
"a1, a2, a3, a4",
constraints=(
# The following -> a1 >= a2 + 5
"a1 >= a3 + 1", "a3 >= a2 + 4",
# The following -> a1 <= a2 + 4
"a1 <= a4 + 2", "a4 <= a2 + 2"))
self.assertGreaterEqual(a1, a2)
def test_constraints_sharing(self):
a, = shape_poly.symbolic_shape("a",
constraints=("a >= 5", "a <= 10"))
self.assertEqual(_bounds(a), (5, 10))
# The constraints order does not matter, and they are canonicalized
a1, = shape_poly.symbolic_shape("a",
constraints=("2*a + 5 <= a + 15", "a >= 5"))
@ -739,7 +947,7 @@ class DimExprTest(jtu.JaxTestCase):
def test_constraints_different_scope(self):
a, = shape_poly.symbolic_shape("a",
constraints=("a >= 5", "a <= 10"))
self.assertEqual(a.bounds(), (5, 10))
self.assertEqual(_bounds(a), (5, 10))
a1, = shape_poly.symbolic_shape("a",
constraints=("a <= 10",))
self.assertNotEqual(a, a1)
@ -882,6 +1090,14 @@ def check_shape_poly(tst, f_jax: Callable, *,
@jtu.with_config(jax_enable_key_reuse_checks=False)
class ShapePolyTest(jtu.JaxTestCase):
def setUp(self):
_start_profile(self)
super().setUp()
def tearDown(self):
super().tearDown()
_stop_profile(self)
def test_simple_unary(self):
"""Test shape polymorphism for a simple case, unary function."""
@ -1118,10 +1334,28 @@ class ShapePolyTest(jtu.JaxTestCase):
polymorphic_shapes=["a"],
symbolic_constraints=["a >= 8"])
def test_constraints_for_profile(self):
# A somewhat more involved tests to stress test the correctness and
# performance
def f(x): # x: i32[a, b]
acc = 0
for start in range(0, 10):
slice = x[start::2] # exercises floordiv and min
acc += jnp.sum(slice, axis=0)
slice = x[start:(x.shape[0] - x.shape[0] % 2):2] # exercises max and min
acc += jnp.sum(slice, axis=0)
return acc
exp = export.export(f)(jax.ShapeDtypeStruct(export.symbolic_shape("a, b"),
np.int32))
def test_constraints_compile_time_check(self):
def f(x): # x: i32[a]
a = x.shape[0]
assert a.bounds() == (2, 4)
assert _bounds(a) == (2, 4)
return lax.dynamic_slice_in_dim(x, 1, 2, 0)
x_spec = jax.ShapeDtypeStruct(
@ -1154,7 +1388,7 @@ class ShapePolyTest(jtu.JaxTestCase):
nonlocal f_tracing_count
f_tracing_count += 1
a = x.shape[0]
assert a.bounds() == expected_a_bounds
assert _bounds(a) == expected_a_bounds
x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), np.int32)
_ = export.export(f)(x_spec)
@ -2692,21 +2926,12 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
"""This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES."""
def setUp(self):
self.prof = None
if os.getenv("JAX_PROFILE_TEST", False):
import cProfile
self.prof = cProfile.Profile()
self.prof.enable()
_start_profile(self)
super().setUp()
def tearDown(self):
super().tearDown()
if self.prof is not None:
from pstats import Stats
p = Stats(self.prof)
p.strip_dirs()
p.sort_stats("cumtime").print_stats(.2)
p.print_callers(.2)
_stop_profile(self)
# For each primitive "xxx" the test will be called "test_harness_xxx_...".
# If you want to run this test for only one harness that includes "foo"
@ -2787,6 +3012,5 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
for fname, _ in harness.override_jax_config_flags.items():
jax.config.update(fname, prev_jax_config_flags[fname])
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())