mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #19492 from gnecula:poly_tests
PiperOrigin-RevId: 601050430
This commit is contained in:
commit
a74b04a43f
@ -17,9 +17,11 @@ from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from collections.abc import Sequence
|
||||
import cProfile
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from pstats import Stats
|
||||
from typing import Any, Callable
|
||||
import unittest
|
||||
|
||||
@ -64,10 +66,54 @@ expect_error_associative_scan = (
|
||||
"associative scan over axis of non-constant size",
|
||||
)
|
||||
|
||||
# We want to have a complete test suite, even if the decision procedure is not
|
||||
# yet as tight as we want. We will write the expected values as
|
||||
# _expect(best=..., current=...) and we use the `current` value in the
|
||||
# test, but we aspire to a decision procedure where we could compute `best`.
|
||||
def _expect(*, current, best):
|
||||
return current
|
||||
|
||||
|
||||
def _bounds(e: shape_poly.DimSize) -> tuple[float, float]:
|
||||
if not isinstance(e, shape_poly._DimExpr):
|
||||
e = shape_poly._ensure_poly(e, "_bounds", shape_poly.SymbolicScope())
|
||||
return e.bounds()
|
||||
|
||||
def _assert_equal_bounds(tst: jtu.JaxTestCase,
|
||||
e: shape_poly.DimSize,
|
||||
bounds: tuple[float, float]):
|
||||
if isinstance(e, shape_poly._DimExpr):
|
||||
scope = e.scope
|
||||
else:
|
||||
scope = shape_poly.SymbolicScope()
|
||||
decision = shape_poly._make_decision_state(scope)
|
||||
found_bounds = decision.bounds(e)
|
||||
tst.assertEqual(bounds, found_bounds)
|
||||
|
||||
def _start_profile(tst: jtu.JaxTestCase):
|
||||
tst.prof = None
|
||||
if os.getenv("JAX_PROFILE_TEST", False):
|
||||
tst.prof = cProfile.Profile()
|
||||
tst.prof.enable()
|
||||
|
||||
def _stop_profile(tst: jtu.JaxTestCase):
|
||||
if tst.prof is not None:
|
||||
p = Stats(tst.prof)
|
||||
p.strip_dirs()
|
||||
p.sort_stats("cumtime").print_stats(.2)
|
||||
p.print_callers(.2)
|
||||
|
||||
@jtu.with_config(jax_enable_key_reuse_checks=False)
|
||||
class DimExprTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
_start_profile(self)
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
_stop_profile(self)
|
||||
|
||||
class AssertionType(enum.Enum):
|
||||
EQ = 1,
|
||||
GEQ = 2
|
||||
@ -143,7 +189,7 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
("3 * floordiv(a + 2, b + 2) * 2", 3 * ((a + 2) // (b + 2)) * 2),
|
||||
# Keep for backwards compatibility. We ought to be able to parse
|
||||
# non_negative
|
||||
("non_negative(a - 2)", core.max_dim(a - 2, 0)),
|
||||
("non_negative(a - 2)", "build_inside"),
|
||||
("max(a, b)", "build_inside"),
|
||||
("min(a, b)", "build_inside"),
|
||||
]])
|
||||
@ -316,58 +362,70 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
shape_poly._DimMon.from_operation(shape_poly._DimAtom.NON_NEGATIVE,
|
||||
a - 2 * b - 1,
|
||||
scope=a.scope))
|
||||
|
||||
def test_poly_bounds_arithmetic(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
def test_bounds_arithmetic(self):
|
||||
a, b, c = shape_poly.symbolic_shape("a, b, c")
|
||||
bounded_le4 = 5 - a
|
||||
bounded_ge2 = b + 1
|
||||
bounded_ge0_le4 = a % 5
|
||||
self.assertEqual(a.bounds(), (1, np.inf))
|
||||
self.assertEqual((- a).bounds(), (-np.inf, -1))
|
||||
self.assertEqual(bounded_le4.bounds(), (-np.inf, 4))
|
||||
self.assertEqual(bounded_ge2.bounds(), (2, np.inf))
|
||||
self.assertEqual(bounded_ge0_le4.bounds(), (0, 4))
|
||||
self.assertEqual(_bounds(a), (1, np.inf))
|
||||
self.assertEqual(_bounds(- a), (-np.inf, -1))
|
||||
self.assertEqual(_bounds(bounded_le4), (-np.inf, 4))
|
||||
self.assertEqual(_bounds(bounded_ge2), (2, np.inf))
|
||||
self.assertEqual(_bounds(bounded_ge0_le4), (0, 4))
|
||||
|
||||
self.assertEqual(_bounds(b + 1 - b), (1, 1))
|
||||
self.assertEqual(_bounds(a), (1, np.inf))
|
||||
self.assertEqual(_bounds(a - 1), (0, np.inf))
|
||||
self.assertEqual(_bounds(5*a - 2), (3, np.inf))
|
||||
self.assertEqual(_bounds(-5*a - 2), (-np.inf, -7))
|
||||
self.assertEqual(_bounds(2*a + 3*b + 5*c - 1), (9, np.inf))
|
||||
self.assertEqual(_bounds(2*a - 3*b + 5*c - 1), (-np.inf, np.inf))
|
||||
self.assertEqual(_bounds(-2*a + -3*b + -5*c + 20), (-np.inf, 10))
|
||||
|
||||
# Additions
|
||||
self.assertEqual((bounded_ge0_le4 + bounded_le4).bounds(), (-np.inf, 8))
|
||||
self.assertEqual((bounded_ge0_le4 + bounded_ge2).bounds(), (2, np.inf))
|
||||
self.assertEqual((bounded_le4 + bounded_ge2).bounds(), (-np.inf, np.inf))
|
||||
self.assertEqual(_bounds(bounded_ge0_le4 + bounded_le4), (-np.inf, 8))
|
||||
self.assertEqual(_bounds(bounded_ge0_le4 + bounded_ge2), (2, np.inf))
|
||||
self.assertEqual(_bounds(bounded_le4 + bounded_ge2), (-np.inf, np.inf))
|
||||
|
||||
# Subtractions
|
||||
self.assertEqual((bounded_ge0_le4 - bounded_le4).bounds(), (-4, np.inf))
|
||||
self.assertEqual((- bounded_ge0_le4 + bounded_le4).bounds(), (-np.inf, 4))
|
||||
self.assertEqual((bounded_ge0_le4 - bounded_ge2).bounds(), (-np.inf, 2))
|
||||
self.assertEqual((- bounded_ge0_le4 + bounded_ge2).bounds(), (-2, np.inf))
|
||||
self.assertEqual((bounded_le4 - bounded_ge2).bounds(), (-np.inf, 2))
|
||||
self.assertEqual((- bounded_le4 + bounded_ge2).bounds(), (-2, np.inf))
|
||||
self.assertEqual(_bounds(bounded_ge0_le4 - bounded_le4), (-4, np.inf))
|
||||
self.assertEqual(_bounds(- bounded_ge0_le4 + bounded_le4), (-np.inf, 4))
|
||||
self.assertEqual(_bounds(bounded_ge0_le4 - bounded_ge2), (-np.inf, 2))
|
||||
self.assertEqual(_bounds(- bounded_ge0_le4 + bounded_ge2), (-2, np.inf))
|
||||
self.assertEqual(_bounds(bounded_le4 - bounded_ge2), (-np.inf, 2))
|
||||
self.assertEqual(_bounds(- bounded_le4 + bounded_ge2), (-2, np.inf))
|
||||
|
||||
# Multiplications
|
||||
self.assertEqual((2 * a - 3).bounds(), (-1, np.inf))
|
||||
self.assertEqual((-2 * a - 3).bounds(), (-np.inf, -5))
|
||||
self.assertEqual((3 * a * b * b + 5 * a - 7).bounds(), (1, np.inf))
|
||||
self.assertEqual((3 * a * b * b - 5 * a - 7).bounds(), (-np.inf, np.inf))
|
||||
self.assertEqual((a + b - a * b + a * b * a).bounds(), (-np.inf, np.inf))
|
||||
self.assertEqual((a + 2 * b - a).bounds(), (2, np.inf))
|
||||
self.assertEqual((a + 2 * b - a).bounds(), (2, np.inf))
|
||||
self.assertEqual(_bounds(2 * a - 3), (-1, np.inf))
|
||||
self.assertEqual(_bounds(-2 * a - 3), (-np.inf, -5))
|
||||
self.assertEqual(_bounds(3 * a * b * b + 5 * a - 7), (1, np.inf))
|
||||
self.assertEqual(_bounds(3 * a * b * b - 5 * a - 7), (-np.inf, np.inf))
|
||||
self.assertEqual(_bounds(a + b - a * b + a * b * a), (-np.inf, np.inf))
|
||||
self.assertEqual(_bounds(a + 2 * b - a), (2, np.inf))
|
||||
self.assertEqual(_bounds(a + 2 * b - a), (2, np.inf))
|
||||
|
||||
def test_poly_bounds_mod(self):
|
||||
# Higher order polynomial
|
||||
self.assertEqual(_bounds(a*a + b - 2), (0, np.inf))
|
||||
self.assertEqual(_bounds(-2*a*b - b - 2), (-np.inf, -5))
|
||||
|
||||
def test_bounds_mod(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
|
||||
self.assertEqual((5 - a % 5).bounds(), (1, 5))
|
||||
self.assertEqual((-5 - a % (-5)).bounds(), (-5, -1))
|
||||
self.assertEqual((a - 5 % a).bounds(), (1, np.inf))
|
||||
self.assertEqual((a - 5 % a).bounds(), (1, np.inf))
|
||||
self.assertEqual((3 * (a + b) - 5 % (3 * (a + b))).bounds(), (1, np.inf))
|
||||
self.assertEqual((- a + (b - 5) % a).bounds(), (-np.inf, -1))
|
||||
# self.assertEqual(_bounds(5 - a % 5), (1, 5))
|
||||
self.assertEqual(_bounds(-5 - a % (-5)), (-5, -1))
|
||||
self.assertEqual(_bounds(a - 5 % a), (1, np.inf))
|
||||
self.assertEqual(_bounds(a - 5 % a), (1, np.inf))
|
||||
self.assertEqual(_bounds(3 * (a + b) - 5 % (3 * (a + b))), (1, np.inf))
|
||||
self.assertEqual(_bounds(- a + (b - 5) % a), (-np.inf, -1))
|
||||
|
||||
# mod
|
||||
self.assertEqual(((b + 1) % 2).bounds(), (0, 1))
|
||||
self.assertEqual(((b + 1) % -2).bounds(), (-1, 0))
|
||||
self.assertEqual(((b - 4) % 2).bounds(), (0, 1))
|
||||
self.assertEqual(((b + 1) % a).bounds(), (0, np.inf))
|
||||
self.assertEqual((11 % (a + 1)).bounds(), (0, np.inf))
|
||||
self.assertEqual((-11 % (a + 1)).bounds(), (0, np.inf))
|
||||
self.assertEqual((b % (a - 2)).bounds(), (-np.inf, np.inf))
|
||||
self.assertEqual(_bounds((b + 1) % 2), (0, 1))
|
||||
self.assertEqual(_bounds((b + 1) % -2), (-1, 0))
|
||||
self.assertEqual(_bounds((b - 4) % 2), (0, 1))
|
||||
self.assertEqual(_bounds((b + 1) % a), (0, np.inf))
|
||||
self.assertEqual(_bounds(11 % (a + 1)), (0, np.inf))
|
||||
self.assertEqual(_bounds(-11 % (a + 1)), (0, np.inf))
|
||||
self.assertEqual(_bounds(b % (a - 2)), (-np.inf, np.inf))
|
||||
|
||||
# This arises in convolutions, because we use "-2 * div(-b, 2)" to get
|
||||
# the "2*ceil(b / 2)".
|
||||
@ -375,17 +433,21 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
|
||||
def poly_bounds_div(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
self.assertEqual(((a + 4) // 2).bounds(), (2, np.inf))
|
||||
self.assertEqual(((a + 4) // -2).bounds(), (-np.inf, -3))
|
||||
self.assertEqual(((a + 5) // 2).bounds(), (3, np.inf))
|
||||
self.assertEqual(((a + 5) // -2).bounds(), (-np.inf, -3))
|
||||
self.assertEqual((11 // (a + 1)).bounds(), (0, 5))
|
||||
self.assertEqual((-11 // (a + 1)).bounds(), (-6, -1))
|
||||
self.assertEqual((-11 // (- a)).bounds(), (0, 11)) # finite negative dividend, infinite divisor
|
||||
self.assertEqual(((b + 1) // (a + 1)).bounds(), (0, np.inf))
|
||||
self.assertEqual((-b // (a + 1)).bounds(), (-np.inf, -1))
|
||||
self.assertEqual(_bounds((a + 4) // 2), (2, np.inf))
|
||||
self.assertEqual(_bounds((a + 4) // -2), (-np.inf, -3))
|
||||
self.assertEqual(_bounds((a + 5) // 2), (3, np.inf))
|
||||
self.assertEqual(_bounds((a + 5) // -2), (-np.inf, -3))
|
||||
self.assertEqual(_bounds(11 // (a + 1)), (0, 5))
|
||||
self.assertEqual(_bounds(-11 // (a + 1)), (-6, -1))
|
||||
self.assertEqual(_bounds(-11 // (- a)), (0, 11)) # finite negative dividend, infinite divisor
|
||||
self.assertEqual(_bounds((b + 1) // (a + 1)), (0, np.inf))
|
||||
self.assertEqual(_bounds(-b // (a + 1)), (-np.inf, -1))
|
||||
|
||||
def test_poly_bounds_div_generated(self):
|
||||
self.assertEqual(_bounds(a - a // 2), (1, np.inf))
|
||||
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 1))
|
||||
self.assertEqual(_bounds(a - 2 * (a // 2)), (0, 0))
|
||||
|
||||
def test_bounds_div_generated(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
# Generate test cases for floordiv and mod: (a + N) // +-2, (N - a) // +-2
|
||||
# and then evaluate them for a = 1, 5, 10000
|
||||
@ -397,53 +459,32 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
for operation in (op.floordiv, op.mod)
|
||||
]
|
||||
for atom in div_mod_atoms:
|
||||
lb, ub = atom.bounds()
|
||||
lb, ub = _bounds(atom)
|
||||
self.assertLessEqual(lb, ub)
|
||||
for a_val in (1, 5, 10000):
|
||||
atom_val = atom.evaluate(dict(a=a_val))
|
||||
self.assertGreaterEqual(atom_val, lb)
|
||||
self.assertLessEqual(atom_val, ub)
|
||||
|
||||
def test_poly_bounds_non_negative(self):
|
||||
def test_bounds_non_negative(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
|
||||
self.assertEqual(core.non_negative_dim(a).bounds(), (1, np.inf))
|
||||
self.assertEqual(core.non_negative_dim(a - 5).bounds(), (0, np.inf))
|
||||
self.assertEqual(core.non_negative_dim(15 - a).bounds(), (0, 14))
|
||||
self.assertEqual((core.non_negative_dim(15 - a) // 3).bounds(), (0, 4))
|
||||
self.assertEqual(_bounds(core.non_negative_dim(a)), (1, np.inf))
|
||||
self.assertEqual(_bounds(core.non_negative_dim(a - 5)), (0, np.inf))
|
||||
self.assertEqual(_bounds(core.non_negative_dim(15 - a)), (0, 14))
|
||||
self.assertEqual(_bounds(core.non_negative_dim(15 - a) // 3), (0, 4))
|
||||
self.assertEqual(_bounds(a - core.non_negative_dim(a - 3)),
|
||||
_expect(best=(1, 3), current=(-np.inf, np.inf)))
|
||||
|
||||
def test_min_dim(self):
|
||||
a, b, c = shape_poly.symbolic_shape("a, b, c")
|
||||
|
||||
self.assertEqual(core.min_dim(a, b).bounds(), (1, np.inf))
|
||||
self.assertEqual(core.min_dim(2, b).bounds(), (1, 2))
|
||||
self.assertEqual(core.min_dim(a, -2), -2)
|
||||
self.assertEqual(core.min_dim(a - 5, 1).bounds(), (-4, 1))
|
||||
self.assertEqual(core.min_dim(15 - a, 10).bounds(), (-np.inf, 10))
|
||||
self.assertEqual(core.min_dim(15 - a, 20).bounds(), (-np.inf, 14))
|
||||
|
||||
self.assertEqual(a, core.min_dim(a, a + 2))
|
||||
self.assertEqual(a - 2, core.min_dim(a, a - 2))
|
||||
self.assertEqual(core.min_dim(a % 2 - 1, -1), -1)
|
||||
self.assertEqual(core.min_dim(-1, a % 2 - 1), -1)
|
||||
self.assertGreaterEqual(a, core.min_dim(a, b))
|
||||
self.assertGreaterEqual(a + c - 1, core.min_dim(a, b))
|
||||
self.assertGreaterEqual(b, core.min_dim(a, b))
|
||||
self.assertGreaterEqual(b + c - 1, core.min_dim(a, b))
|
||||
|
||||
self.sampled_assertion(core.min_dim(a, 5),
|
||||
core.min_dim, a, 5)
|
||||
self.sampled_assertion(core.min_dim(5, a),
|
||||
core.min_dim, 5, a)
|
||||
def test_max_dim(self):
|
||||
a, b, c = shape_poly.symbolic_shape("a, b, c")
|
||||
a, b, c, d = shape_poly.symbolic_shape("a, b, c, d")
|
||||
|
||||
self.assertEqual(core.max_dim(a, b).bounds(), (1, np.inf))
|
||||
self.assertEqual(core.max_dim(2, b).bounds(), (2, np.inf))
|
||||
self.assertEqual(core.max_dim(a, 2).bounds(), (2, np.inf))
|
||||
self.assertEqual(core.max_dim(a - 5, 1).bounds(), (1, np.inf))
|
||||
self.assertEqual(core.max_dim(15 - a, 0).bounds(), (0, 14))
|
||||
self.assertEqual((core.max_dim(15 - a, 0) // 3).bounds(), (0, 4))
|
||||
self.assertEqual(_bounds(core.max_dim(a, b)), (1, np.inf))
|
||||
self.assertEqual(_bounds(core.max_dim(2, b)), (2, np.inf))
|
||||
self.assertEqual(_bounds(core.max_dim(a, 2)), (2, np.inf))
|
||||
self.assertEqual(_bounds(core.max_dim(a - 5, 1)), (1, np.inf))
|
||||
self.assertEqual(_bounds(core.max_dim(15 - a, 0)), (0, 14))
|
||||
self.assertEqual(_bounds(core.max_dim(15 - a, 0) // 3), (0, 4))
|
||||
|
||||
self.assertEqual(a + 2, core.max_dim(a, a + 2))
|
||||
self.assertEqual(a , core.max_dim(a, a - 2))
|
||||
@ -455,10 +496,55 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.assertGreaterEqual(core.max_dim(a, b) + c - 1, b)
|
||||
|
||||
self.assertGreaterEqual(core.max_dim(a, b), core.min_dim(a, b))
|
||||
self.sampled_assertion(core.max_dim(a, 5),
|
||||
core.max_dim, a, 5)
|
||||
self.sampled_assertion(core.max_dim(5, a),
|
||||
core.max_dim, 5, a)
|
||||
|
||||
self.assertEqual(_bounds(b - core.max_dim(b - a, 0)),
|
||||
_expect(best=(1, np.inf), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(a - core.min_dim(a, b)),
|
||||
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(b - core.min_dim(a, b)),
|
||||
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
|
||||
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) - a),
|
||||
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) - b),
|
||||
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
|
||||
self.assertEqual((0, 0), _bounds(core.max_dim(1 - b, 0)))
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) - a - core.max_dim(b - a, 0)),
|
||||
_expect(best=(0, 0), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) + core.max_dim(c, d) -
|
||||
core.max_dim(a + core.max_dim(c, d),
|
||||
b + core.max_dim(c, d))),
|
||||
_expect(best=(0, 0), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) + core.min_dim(c, d) -
|
||||
core.max_dim(a + core.min_dim(c, d),
|
||||
b + core.min_dim(c, d))),
|
||||
_expect(best=(0, 0), current=(-np.inf, np.inf)))
|
||||
|
||||
self.sampled_assertion(core.max_dim(a, 5), core.max_dim, a, 5)
|
||||
self.sampled_assertion(core.max_dim(5, a), core.max_dim, 5, a)
|
||||
|
||||
def test_min_dim(self):
|
||||
a, b, c = shape_poly.symbolic_shape("a, b, c")
|
||||
|
||||
self.assertEqual(_bounds(core.min_dim(a, b)), (1, np.inf))
|
||||
self.assertEqual(_bounds(core.min_dim(2, b)), (1, 2))
|
||||
self.assertEqual(core.min_dim(a, -2), -2)
|
||||
self.assertEqual(_bounds(core.min_dim(a - 5, 1)), (-4, 1))
|
||||
self.assertEqual(_bounds(core.min_dim(15 - a, 10)), (-np.inf, 10))
|
||||
self.assertEqual(_bounds(core.min_dim(15 - a, 20)), (-np.inf, 14))
|
||||
|
||||
# Test simplification during construction
|
||||
self.assertEqual(a, core.min_dim(a, a + 2))
|
||||
self.assertEqual(a - 2, core.min_dim(a, a - 2))
|
||||
self.assertEqual(core.min_dim(a % 2 - 1, -1), -1)
|
||||
self.assertEqual(core.min_dim(-1, a % 2 - 1), -1)
|
||||
self.assertGreaterEqual(a, core.min_dim(a, b))
|
||||
self.assertGreaterEqual(a + c - 1, core.min_dim(a, b))
|
||||
self.assertGreaterEqual(b, core.min_dim(a, b))
|
||||
self.assertGreaterEqual(b + c - 1, core.min_dim(a, b))
|
||||
|
||||
self.sampled_assertion(core.min_dim(a, 5), core.min_dim, a, 5)
|
||||
self.sampled_assertion(core.min_dim(5, a), core.min_dim, 5, a)
|
||||
|
||||
def test_clamp_dim(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
@ -467,13 +553,21 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.assertLessEqual(b, clamp)
|
||||
self.assertLessEqual(clamp, b + 10)
|
||||
|
||||
def test_poly_bounds_complex(self):
|
||||
def test_bounds_complex(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
min_a_b = b - core.non_negative_dim(b - a)
|
||||
# This comes up in slicing with stride
|
||||
self.assertGreaterEqual(min_a_b // 2, 0)
|
||||
|
||||
def test_poly_equal(self):
|
||||
# Comes up in slice[-2:-2b-2:-2]
|
||||
self.assertGreaterEqual((-1 * core.max_dim(-1, a - 2 * b - 2) + core.max_dim(-1, a - 2) + 1),
|
||||
0)
|
||||
self.assertGreaterEqual((-1 * core.max_dim(-1, a - 2 * b - 2) + core.max_dim(-1, a - 2) + 1) // 2,
|
||||
0)
|
||||
self.assertEqual((0, 0),
|
||||
_bounds(core.max_dim(-1*b + 1, 0) // 2))
|
||||
|
||||
def test_equal(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
poly3 = a + 3 - a
|
||||
self.assertEqual(poly3, 3)
|
||||
@ -518,30 +612,9 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
# self.sampled_assertion((a // 2) * 6,
|
||||
# lambda x: x, 3 * a - 3 * (a % 2))
|
||||
|
||||
def test_poly_compare(self):
|
||||
def test_compare_ge(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
poly = 4 * a + b + 3
|
||||
self.assertTrue(poly.ge(0))
|
||||
self.assertTrue(poly.ge(8))
|
||||
self.assertTrue(poly.ge(poly))
|
||||
self.assertTrue(poly.ge(poly - 1))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
core.InconclusiveDimensionOperation,
|
||||
re.escape("comparison 'b + 4*a + 3' >= '9' is inconclusive")):
|
||||
poly.ge(9)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
core.InconclusiveDimensionOperation,
|
||||
"comparison the_comp is inconclusive"):
|
||||
poly.ge(9, lambda: "the_comp")
|
||||
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
"inconclusive"):
|
||||
(4 * a - b).ge(0)
|
||||
|
||||
def test_poly_compare_overload(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
self.assertGreaterEqual(a, a)
|
||||
self.assertGreaterEqual(a, 0)
|
||||
self.assertGreaterEqual(a, 1)
|
||||
@ -549,25 +622,33 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
|
||||
a >= 2
|
||||
|
||||
poly = 4 * a + b + 3
|
||||
self.assertGreaterEqual(poly, 0)
|
||||
self.assertGreaterEqual(poly, 8)
|
||||
self.assertGreater(poly, 7)
|
||||
self.assertGreaterEqual(poly, poly)
|
||||
self.assertGreaterEqual(poly, poly - 1)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
core.InconclusiveDimensionOperation,
|
||||
re.escape("comparison 'b + 4*a + 3' >= '9' is inconclusive")):
|
||||
poly >= 9
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
core.InconclusiveDimensionOperation,
|
||||
re.escape("comparison 'b + 4*a + 3' > '9' is inconclusive")):
|
||||
poly > 9
|
||||
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
"inconclusive"):
|
||||
(4 * a - b) >= 0
|
||||
|
||||
# LHS is an integer
|
||||
self.assertLessEqual(8, poly)
|
||||
self.assertLess(7, poly)
|
||||
self.assertGreaterEqual(-8, -poly)
|
||||
self.assertGreater(-7, -poly)
|
||||
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
|
||||
poly >= 9
|
||||
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
|
||||
(4 * a - b) >= 0
|
||||
|
||||
def test_poly_int_results(self):
|
||||
def test_int_results(self):
|
||||
# Whenever the result is an integer, it should be represented as a
|
||||
# Python integer, not a symbolic dimension.
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
@ -597,7 +678,7 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
(2 * a * b + b * b, a + b, "floordiv(2*a*b + b^2, b + a)", "mod(2*a*b + b^2, b + a)"),
|
||||
(3, a, "floordiv(3, a)", "mod(3, a)"),
|
||||
]])
|
||||
def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
|
||||
def test_divmod(self, *, dividend, quotient, divisor, remainder):
|
||||
if isinstance(quotient, str):
|
||||
d1, d2 = divmod(dividend, divisor)
|
||||
self.assertEqual((quotient, remainder), (str(d1), str(d2)))
|
||||
@ -650,8 +731,8 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
def test_constraints_basic(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b",
|
||||
constraints=("a >= 5", "b <= 16"))
|
||||
self.assertEqual(a.bounds(), (5, np.inf))
|
||||
self.assertEqual(b.bounds(), (1, 16))
|
||||
self.assertEqual(_bounds(a), (5, np.inf))
|
||||
self.assertEqual(_bounds(b), (1, 16))
|
||||
|
||||
def test_constraints_trivial(self):
|
||||
a, = shape_poly.symbolic_shape("a",
|
||||
@ -665,9 +746,22 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Unsatisfiable.*a <= 0"):
|
||||
_ = shape_poly.symbolic_shape("a, b",
|
||||
# Contradicts the default a >= 1
|
||||
constraints=("a <= 0",))
|
||||
a, b = shape_poly.symbolic_shape("a, b",
|
||||
# Contradicts the default a >= 1
|
||||
constraints=("a <= 0",))
|
||||
a >= b
|
||||
|
||||
def test_constraints_a_minus_4d(self):
|
||||
# simulates d = div(a, 4) and m = mod(a, 4)
|
||||
assumptions = ["a >= 4*d + m ",
|
||||
"a <= 4*d + m",
|
||||
"m >= 0", "m <= 3"]
|
||||
scope = shape_poly.SymbolicScope(assumptions)
|
||||
a, d = shape_poly.symbolic_shape("a, d", scope=scope)
|
||||
self.assertEqual(_bounds(a - 4*d),
|
||||
_expect(best=(1, 3), current=(-np.inf, np.inf))) # a - 4d = m >= 1
|
||||
self.assertEqual(_bounds(a - 2*d),
|
||||
_expect(best=(3, np.inf), current=(-np.inf, np.inf)))
|
||||
|
||||
def test_constraints_errors(self):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
@ -708,12 +802,50 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
constraints=("a >= b",))
|
||||
self.assertGreaterEqual(a, b)
|
||||
|
||||
def test_constraints_complex(self):
|
||||
a, b, c = shape_poly.symbolic_shape(
|
||||
"a, b, c",
|
||||
constraints=("a + 2 <= b", "b <= a + 5", "a + b >= c"))
|
||||
self.assertEqual(_bounds(b),
|
||||
_expect(best=(3, np.inf), current=(1, np.inf)))
|
||||
self.assertEqual(_bounds(b - a),
|
||||
_expect(best=(2, 5), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(b - a - 7),
|
||||
_expect(best=(-5, -2), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(c - 2*a - 5),
|
||||
_expect(best=(-np.inf, 0), current=(-np.inf, np.inf)))
|
||||
|
||||
def test_constraints_fractional(self):
|
||||
a, = shape_poly.symbolic_shape("a",
|
||||
constraints=("2 * a >= 5", "3 * a <= 10",))
|
||||
self.assertEqual(_bounds(5*a - 2), (13, 13))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(constraint="2*a + 3*b >= 10", exp="a + b - 2",
|
||||
bounds=_expect(best=(2, np.inf), current=(0, np.inf))),
|
||||
dict(constraint="-2*a + 3*b >= 10", exp="a + 2*b",
|
||||
bounds=_expect(best=(9, np.inf), current=(3, np.inf))),
|
||||
dict(constraint="-2*a + -3*b >= -10", exp="-1*a + 2*b",
|
||||
bounds=_expect(best=(-1, 3), current=(-np.inf, np.inf))),
|
||||
dict(constraint="2*a + -3*b >= 10", exp="-1*a + 2*b", bounds=(-np.inf, np.inf)),
|
||||
]
|
||||
)
|
||||
def test_constraints_complex_gen(self,
|
||||
constraint: str, exp: str,
|
||||
bounds: tuple[float, float]):
|
||||
a, b, exp = shape_poly.symbolic_shape(
|
||||
"a, b, " + exp,
|
||||
constraints=(constraint,))
|
||||
self.assertEqual(bounds, _bounds(exp))
|
||||
|
||||
def test_constraints_override(self):
|
||||
# Some constaints override other
|
||||
a, b = shape_poly.symbolic_shape("a, b",
|
||||
constraints=("a >= 5", "b <= 16",
|
||||
"a >= 10", "b <= 10"))
|
||||
self.assertEqual(a.bounds(), (10, np.inf))
|
||||
self.assertEqual(b.bounds(), (1, 10))
|
||||
self.assertEqual(_bounds(a), (10, np.inf))
|
||||
self.assertEqual(_bounds(b), (1, 10))
|
||||
|
||||
def test_constraints_error_msg(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b",
|
||||
@ -728,7 +860,83 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
a, b = shape_poly.symbolic_shape("a, b",
|
||||
constraints=("a >= 5", "a <= 10"))
|
||||
self.assertIs(a.scope, b.scope)
|
||||
self.assertEqual(a.bounds(), (5, 10))
|
||||
self.assertEqual(_bounds(a), (5, 10))
|
||||
|
||||
def test_constraints_seq(self):
|
||||
a1, a2, a3, a4, a5 = shape_poly.symbolic_shape(
|
||||
"a1, a2, a3, a4, a5",
|
||||
constraints=(
|
||||
"a1 >= a2",
|
||||
"a2 >= a3",
|
||||
"a3 >= a4",
|
||||
"a4 >= a5",
|
||||
)
|
||||
)
|
||||
self.assertEqual(_bounds(a1 - a5),
|
||||
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
|
||||
|
||||
def test_constraints_rounding_monomials(self):
|
||||
a1, a2 = shape_poly.symbolic_shape(
|
||||
"a1, a2",
|
||||
constraints=(
|
||||
# a1 >= 2 and a1 <= 2
|
||||
"2 * a1 >= 3", "-2 * a1 >= -5"
|
||||
)
|
||||
)
|
||||
self.assertEqual(_bounds(a1), (2, 2))
|
||||
|
||||
def test_constraints_rounding_not_monomials(self):
|
||||
a1, a2 = shape_poly.symbolic_shape(
|
||||
"a1, a2",
|
||||
constraints=(
|
||||
# a1 >= a2 + 2 and a1 <= a2 + 2
|
||||
"2*a1 >= 2*a2 + 3", "-2*a1 + 2*a2 >= -5"
|
||||
)
|
||||
)
|
||||
self.assertEqual(_bounds(a1 - a2 - 2),
|
||||
_expect(best=(0, 0), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(a2 - a1 + 2),
|
||||
_expect(best=(0, 0), current=(-np.inf, np.inf)))
|
||||
|
||||
def test_constraints_unsat_trivial(self):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Unsatisfiable explicit constraint: a1 >= a1 \+ 1"):
|
||||
_ = shape_poly.symbolic_shape(
|
||||
"a1", constraints=("a1 >= a1 + 1",))
|
||||
|
||||
def test_constraints_unsat_monomials(self):
|
||||
with self.assertRaisesRegex(_expect(best=ValueError,
|
||||
current=shape_poly.InconclusiveDimensionOperation),
|
||||
_expect(best="Unsatisfiable constraint",
|
||||
current="inconclusive")):
|
||||
a1, a2, *_ = shape_poly.symbolic_shape(
|
||||
"a1, a2, a3, a4",
|
||||
constraints=(
|
||||
# The following -> a1 >= 5
|
||||
"a1 >= a3 + 1", "a3 >= 4",
|
||||
# The following -> a1 <= 4
|
||||
"a1 <= a4 + 2", "a4 <= 2"))
|
||||
a1 >= a2
|
||||
|
||||
def test_constraints_unsat_not_monomials(self):
|
||||
with self.assertRaisesRegex(_expect(best=ValueError,
|
||||
current=shape_poly.InconclusiveDimensionOperation),
|
||||
_expect(best="Unsatisfiable constraint",
|
||||
current="inconclusive")):
|
||||
a1, a2, a3, a4 = shape_poly.symbolic_shape(
|
||||
"a1, a2, a3, a4",
|
||||
constraints=(
|
||||
# The following -> a1 >= a2 + 5
|
||||
"a1 >= a3 + 1", "a3 >= a2 + 4",
|
||||
# The following -> a1 <= a2 + 4
|
||||
"a1 <= a4 + 2", "a4 <= a2 + 2"))
|
||||
|
||||
self.assertGreaterEqual(a1, a2)
|
||||
|
||||
def test_constraints_sharing(self):
|
||||
a, = shape_poly.symbolic_shape("a",
|
||||
constraints=("a >= 5", "a <= 10"))
|
||||
self.assertEqual(_bounds(a), (5, 10))
|
||||
# The constraints order does not matter, and they are canonicalized
|
||||
a1, = shape_poly.symbolic_shape("a",
|
||||
constraints=("2*a + 5 <= a + 15", "a >= 5"))
|
||||
@ -739,7 +947,7 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
def test_constraints_different_scope(self):
|
||||
a, = shape_poly.symbolic_shape("a",
|
||||
constraints=("a >= 5", "a <= 10"))
|
||||
self.assertEqual(a.bounds(), (5, 10))
|
||||
self.assertEqual(_bounds(a), (5, 10))
|
||||
a1, = shape_poly.symbolic_shape("a",
|
||||
constraints=("a <= 10",))
|
||||
self.assertNotEqual(a, a1)
|
||||
@ -882,6 +1090,14 @@ def check_shape_poly(tst, f_jax: Callable, *,
|
||||
@jtu.with_config(jax_enable_key_reuse_checks=False)
|
||||
class ShapePolyTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
_start_profile(self)
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
_stop_profile(self)
|
||||
|
||||
def test_simple_unary(self):
|
||||
"""Test shape polymorphism for a simple case, unary function."""
|
||||
|
||||
@ -1118,10 +1334,28 @@ class ShapePolyTest(jtu.JaxTestCase):
|
||||
polymorphic_shapes=["a"],
|
||||
symbolic_constraints=["a >= 8"])
|
||||
|
||||
def test_constraints_for_profile(self):
|
||||
# A somewhat more involved tests to stress test the correctness and
|
||||
# performance
|
||||
def f(x): # x: i32[a, b]
|
||||
acc = 0
|
||||
for start in range(0, 10):
|
||||
slice = x[start::2] # exercises floordiv and min
|
||||
acc += jnp.sum(slice, axis=0)
|
||||
|
||||
slice = x[start:(x.shape[0] - x.shape[0] % 2):2] # exercises max and min
|
||||
acc += jnp.sum(slice, axis=0)
|
||||
return acc
|
||||
|
||||
exp = export.export(f)(jax.ShapeDtypeStruct(export.symbolic_shape("a, b"),
|
||||
np.int32))
|
||||
|
||||
|
||||
|
||||
def test_constraints_compile_time_check(self):
|
||||
def f(x): # x: i32[a]
|
||||
a = x.shape[0]
|
||||
assert a.bounds() == (2, 4)
|
||||
assert _bounds(a) == (2, 4)
|
||||
return lax.dynamic_slice_in_dim(x, 1, 2, 0)
|
||||
|
||||
x_spec = jax.ShapeDtypeStruct(
|
||||
@ -1154,7 +1388,7 @@ class ShapePolyTest(jtu.JaxTestCase):
|
||||
nonlocal f_tracing_count
|
||||
f_tracing_count += 1
|
||||
a = x.shape[0]
|
||||
assert a.bounds() == expected_a_bounds
|
||||
assert _bounds(a) == expected_a_bounds
|
||||
|
||||
x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a"), np.int32)
|
||||
_ = export.export(f)(x_spec)
|
||||
@ -2692,21 +2926,12 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
|
||||
"""This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES."""
|
||||
|
||||
def setUp(self):
|
||||
self.prof = None
|
||||
if os.getenv("JAX_PROFILE_TEST", False):
|
||||
import cProfile
|
||||
self.prof = cProfile.Profile()
|
||||
self.prof.enable()
|
||||
_start_profile(self)
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
if self.prof is not None:
|
||||
from pstats import Stats
|
||||
p = Stats(self.prof)
|
||||
p.strip_dirs()
|
||||
p.sort_stats("cumtime").print_stats(.2)
|
||||
p.print_callers(.2)
|
||||
_stop_profile(self)
|
||||
|
||||
# For each primitive "xxx" the test will be called "test_harness_xxx_...".
|
||||
# If you want to run this test for only one harness that includes "foo"
|
||||
@ -2787,6 +3012,5 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
|
||||
for fname, _ in harness.override_jax_config_flags.items():
|
||||
jax.config.update(fname, prev_jax_config_flags[fname])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user