[shape_poly] Add a decision procedure for inequalities.

In a previous PR (#19285) we added support for inequality
constaints on symbolic expressions, but with limited support
for the cases when a constrain contains more than one term,
e.g., "a >= b".

Here we add a simple decision procedure for such inequalities,
based on the elimination algorithm based on the following properties:

  * if we have two constraints "a + b >= 0" and "-a + c >= 0" we can
    eliminate "a" and infer the derived constraint "b + c >= 0".
  * the lower bound of "a + c", in presence of a constraint "a >= b"
    it greater-or-equal to "b + c".

The above rules can be generalized to cases when the eliminated
terms have coefficients different than 1.

This algorithm is exponential in the number of constraints, but
we implement a limited form. When we add a constraint we combine
it with already added constraints, but the result of the combination
is not combined further. This is sufficient for the cases we
have encountered so far.

The termination of the algorithm is ensured by always eliminating
the largest (leading) term, ensuring that the result of a combination of
constraints has a smaller leading term.

With this added power for reasoning, we can retire the previous
heuristics for handling "min", "max", "floordiv" and "mod" and replace
them with the addition of some implicit constraints for them,
e.g., "max(a, b) >= a", etc., and then letting the decision procedure
do its job.

We moved the logic for deciding inequalities, to a new file: shape_poly_decision.py.
This commit is contained in:
George Necula 2024-01-20 08:47:52 +00:00
parent 2518a6f6d2
commit e20afac46a
5 changed files with 576 additions and 305 deletions

View File

@ -34,6 +34,7 @@ py_library(
"_export.py",
"_serialization.py",
"_shape_poly.py",
"_shape_poly_decision.py",
"serialization_generated.py",
],
srcs_version = "PY3",

View File

@ -35,3 +35,4 @@ from jax.experimental.export._serialization import (
serialize,
deserialize,
)
from jax.experimental.export import _shape_poly_decision

View File

@ -25,7 +25,7 @@ This enables many JAX programs to be traced with symbolic dimensions
in some dimensions. A priority has been to enable the batch
dimension in neural network examples to be polymorphic.
This was built initially for jax2tf, but it is now customizable to be
This was built initially for jax2tf, but it is now
independent of TF. The best documentation at the moment is in the
jax2tf.convert docstring, and the
[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
@ -210,48 +210,6 @@ class _DimAtom:
"""Lexicographic comparison"""
return self._syntactic_cmp(other) >= 0
def bounds(self) -> tuple[float, float]:
"""Returns the lower and upper bounds, or -+ inf."""
if self.var is not None:
return (1, np.inf) # variables are assumed to be >= 1
opnd_bounds = [opnd.bounds() for opnd in self.operands]
if self.operation == _DimAtom.FLOORDIV: # a // b
(a_l, a_u), (b_l, b_u) = opnd_bounds
def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
assert b != 0
if not np.isinf(b): # divisor is finite
return math.floor(a / b) if not np.isinf(a) else -np.inf if (a >= 0) != (b >= 0) else np.inf
elif not np.isinf(a): # dividend is finite and divisor is infinite
return -1 if (a >= 0) != (b >= 0) else 0
else: # both dividend and divisor are infinite
return -np.inf if (a >= 0) != (b >= 0) else np.inf
# Same reasoning as for multiplication: the bounds are among the cross-product
# of the bounds.
bound_candidates = [math_floor_with_inf(a_l, b_l), math_floor_with_inf(a_l, b_u),
math_floor_with_inf(a_u, b_l), math_floor_with_inf(a_u, b_u)]
return (min(*bound_candidates), max(*bound_candidates))
elif self.operation == _DimAtom.MOD:
_, (b_l, b_u) = opnd_bounds
if b_l > 0: # positive divisor
return (0, b_u - 1)
elif b_u < 0: # negative divisor
return (b_l + 1, 0)
else:
return (-np.inf, np.inf)
elif self.operation == _DimAtom.MAX:
(a_l, a_h), (b_l, b_h) = opnd_bounds
return (max(a_l, b_l), max(a_h, b_h))
elif self.operation == _DimAtom.MIN:
(a_l, a_h), (b_l, b_h) = opnd_bounds
return (min(a_l, b_l), min(a_h, b_h))
else:
assert False
def evaluate(self, env: DimVarEnv):
if self.var is not None:
try:
@ -401,21 +359,6 @@ class _DimMon(dict):
elif diff > 0: d[key] = diff
return _DimMon(d)
def bounds(self, scope: SymbolicScope) -> tuple[float, float]:
"""Returns the lower and upper bounds, or -+inf."""
# The bounds of a product are among the product of bounds.
bounds = []
for a, exp in self.items():
a_l, a_u = a.bounds()
assert a_l <= a_u
bounds.append((a_l ** exp, a_u ** exp))
candidates = [math.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)]
calculated_bounds = (min(*candidates), max(*candidates)) # type: ignore
constrained_bounds = scope._monomial_bounds.get(self, (- np.inf, np.inf))
return (max(calculated_bounds[0], constrained_bounds[0]),
min(calculated_bounds[1], constrained_bounds[1]))
def evaluate(self, env: DimVarEnv):
prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1)
def pow_opt(v, p: int):
@ -442,7 +385,7 @@ class _DimExpr():
scope: SymbolicScope):
# Do not construct _DimExpr directly, unless you are sure that coeffs is
# normalized; Use _DimExpr.normalize.
# Takes ownership of coeffs
# Takes ownership of coeffs.
self._coeffs = coeffs or {_DimMon(): 0}
self._scope = scope
self._monomials_sorted = tuple(sorted(self._coeffs.items(), reverse=True))
@ -462,12 +405,31 @@ class _DimExpr():
@property
def leading_term(self) -> tuple[_DimMon, int]:
"""Returns the highest degree term that comes first lexicographically."""
"""Returns the highest degree term that comes last lexicographically."""
return self._monomials_sorted[0]
def to_single_term(self) -> tuple[int, int, _DimMon] | None:
"""Extracts the single term: k + c * term.
Returns None if the expression is not a single term, or (k, c, term)
"""
n1 = 0
n2 = 0
mon = None
for m, c in self.monomials():
if m.degree == 0:
n1 = c
continue
if mon is None:
mon = m
n2 = c
continue
return None
assert mon is not None
return (n1, n2, mon)
@classmethod
def _add_coeffs(cls, coeffs: dict[_DimMon, int], mon: _DimMon, coeff: int):
"""Do `coeffs[mon] += coeff` but remove 0 coefficients."""
"""Computes `coeffs[mon] += coeff` while removing 0 coefficients."""
old_c = coeffs.get(mon)
if old_c is None:
if coeff != 0: coeffs[mon] = coeff
@ -504,24 +466,6 @@ class _DimExpr():
else:
return int(free_const)
@classmethod
def normalize_floordiv_times_divisor(cls, coeffs: dict[_DimMon, int],
scope: SymbolicScope) -> DimSize:
# Look for floordiv(E, M) * M and turn into E - mod(E, M). This comes
# up when handling strided convolution.
for dec in _decompose_expr(_DimExpr(coeffs, scope), _DimAtom.FLOORDIV,
with_exp=1):
# e = factor * floordiv(operands)^exp * rest_monomial + rest_expr
if dec.rest_monomial == 1 and dec.factor == 1:
continue
m_trimmed, m_remainder = divmod(dec.factor * dec.rest_monomial, dec.operands[1])
if m_remainder == 0:
return m_trimmed * (
dec.operands[0] -
_DimExpr.from_operation(_DimAtom.MOD, *dec.operands,
scope=scope)) + dec.rest_expr
return _DimExpr.normalize(coeffs, scope)
@classmethod
def from_constant(cls, c: int, scope: SymbolicScope):
return _DimExpr({_DimMon(): op.index(c)}, scope)
@ -570,15 +514,18 @@ class _DimExpr():
mon = self.to_atom()
return mon.to_var() if mon is not None else None
def to_constant(self) -> int | None:
@classmethod
def to_constant(cls, e: DimSize) -> int | None:
"""Extract the constant from a symbolic expression.
Returns None if the expression is not a single constant."""
m, m_c = self.leading_term
if not isinstance(e, _DimExpr):
return int(e)
m, m_c = e.leading_term
return m_c if m.degree == 0 else None
@property
def is_constant(self):
return self.to_constant() is not None
return _DimExpr.to_constant(self) is not None
def get_vars(self) -> set[str]:
"""The variables that appear in a symbolic dimension."""
@ -587,6 +534,39 @@ class _DimExpr():
acc.update(mon.get_vars())
return acc
@classmethod
def _merge_sorted_terms(
cls,
e1: Sequence[tuple[_DimMon, int]], i1: int, f1: int,
e2: Sequence[tuple[_DimMon, int]], i2: int, f2: int) -> Sequence[tuple[_DimMon, int]]:
"""Computes e1[i1:] * f1 + e2[i2:] * f2.
e1, e2, and the result are sorted with largest term first.
This is an optimization for a common operation. The unoptimized code would
compute each subexpression in term.
"""
acc = []
while i1 < len(e1) and i2 < len(e2):
m1, m1_c = e1[i1]
m2, m2_c = e2[i2]
cmp = m1._syntactic_cmp(m2) # Pick the largest monomial
if cmp < 0:
acc.append((m2, m2_c * f2))
i2 += 1
elif cmp > 0:
acc.append((m1, m1_c * f1))
i1 += 1
else: # They are equal, combine them
i1 += 1
i2 += 1
m1_c = m1_c * f1 + m2_c * f2
if m1_c == 0: continue
acc.append((m1, m1_c))
acc.extend((m1, m1_c * f1) for m1, m1_c in itertools.islice(e1, i1, len(e1)))
acc.extend((m2, m2_c * f2) for m2, m2_c in itertools.islice(e2, i2, len(e2)))
return acc
def _syntactic_cmp(self, other: _DimExpr) -> int:
"""Returns -1 if self < other, 0 if self == other, 1 if self > other.
The comparison is done lexicographically (syntactic), to be used for sorting.
@ -594,6 +574,7 @@ class _DimExpr():
"""
s_mons = self._monomials_sorted
o_mons = other._monomials_sorted
if c := cmp_comparable(self._size, other._size): return c
def cmp_mon(s_mon: tuple[_DimMon, int], o_mon: tuple[_DimMon, int]) -> int:
if c := s_mon[0]._syntactic_cmp(o_mon[0]): return c
return cmp_comparable(s_mon[1], o_mon[1])
@ -618,79 +599,6 @@ class _DimExpr():
return diff == 0
def ge(self, other: DimSize,
cmp_str: Callable[[], str] | None = None) -> bool:
"""Implements `self >= other`.
Raises InconclusiveDimensionOperation if the result is not conclusive.
Uses `cmp_str()` as a description of the comparison in the exception
string.
"""
self_minus_other = _ensure_poly(self - other, "ge", self.scope)
lb, ub = self_minus_other.bounds()
if lb >= 0:
return True
if ub < 0:
return False
# Attempt to handle max. For the decomposition
# e = factor * max(op1, op2) + rest_expr
# use the rule
# e >= 0 IF (factor > 0 AND
# (factor * op1 + rest_expr >= 0 OR
# factor * op2 + rest_expr >= 0))
# OR
# (factor < 0 AND
# (factor * op1 + rest_expr >= 0 AND
# factor * op2 + rest_expr >= 0))
for dec in _decompose_expr(self_minus_other, _DimAtom.MAX,
with_exp=1, with_rest_monomial=1):
op1, op2 = dec.operands
if dec.factor > 0:
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) or
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
return True
else:
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) and
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
return True
# Attempt to handle min. For the decomposition
# e = factor * min(op1, op2) + rest_expr
# use the same rule as for
# e = max(factor * op1, factor * op2) + rest_expr
for dec in _decompose_expr(self_minus_other, _DimAtom.MIN,
with_exp=1, with_rest_monomial=1):
op1, op2 = dec.operands
if dec.factor > 0:
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) and
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
return True
else:
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) or
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
return True
# Attempt to handle floordiv >= 0
for dec in _decompose_expr(self_minus_other, _DimAtom.FLOORDIV,
with_exp=1, with_rest_monomial=1,
with_rest_expr=0):
# e = factor * floordiv(op0, op1)^1 * 1 + 0
if dec.factor > 0:
if definitely_geq_0(dec.operands[0]) and definitely_geq_0(dec.operands[1]):
return True
if cmp_str is not None:
msg = cmp_str()
else:
msg = f"'{self}' >= '{other}'"
if self.scope._explicit_constraints:
describe_scope = f"\nUsing symbolic scope {self.scope}"
else:
describe_scope = ""
raise InconclusiveDimensionOperation(
f"Symbolic dimension comparison {msg} is inconclusive.{describe_scope}")
def __hash__(self):
return self._hash
@ -719,7 +627,7 @@ class _DimExpr():
coeffs = self._coeffs.copy()
for mon, coeff in other.monomials():
_DimExpr._add_coeffs(coeffs, mon, coeff)
return _DimExpr.normalize_floordiv_times_divisor(coeffs, self.scope)
return _DimExpr.normalize(coeffs, self.scope)
def __radd__(self, other):
if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
@ -749,7 +657,7 @@ class _DimExpr():
for mon2, coeff2 in other.monomials():
mon = mon1.mul(mon2)
_DimExpr._add_coeffs(coeffs, mon, coeff1 * coeff2)
return _DimExpr.normalize_floordiv_times_divisor(coeffs, self.scope)
return _DimExpr.normalize(coeffs, self.scope)
def __rmul__(self, other):
if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
@ -815,28 +723,39 @@ class _DimExpr():
return False
elif not core.is_constant_dim(other):
return False
else:
other = _ensure_poly(other, "eq", self.scope)
return self.eq(other)
# Equality is used very frequently because expressions are cached. We could
# implement a more precise version based on `(self - other).bounds() = (0, 0)`
# but that would be too expensive. It would also have the unfortunate drawback
# that we cannot then cache `e.bounds()` because hashing invokes equality
# which would lead to infinite recursion.
diff = self - other
# We look for `self - other == k`, and we rely on the fact that when we
# normalize _DimExpr that represent integers as ints.
if is_symbolic_dim(diff):
# Here we really ought to raise InconclusiveDimensionOperation, but __eq__
# cannot raise exceptions, because it is used indirectly when hashing.
# So, we say that the expressions are disequal, which is really unsound.
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported
return False
return diff == 0
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
def __ge__(self, other: DimSize) -> bool:
return self.ge(
other, lambda: f"'{self}' >= '{other}'")
return _geq_decision(self, other, lambda: f"'{self}' >= '{other}'")
def __le__(self, other: DimSize):
return _ensure_poly(other, "le", self.scope).ge(
self, lambda: f"'{self}' <= '{other}'")
return _geq_decision(other, self, lambda: f"'{self}' <= '{other}'")
def __gt__(self, other: DimSize):
return not _ensure_poly(other, "le", self.scope).ge(
self, lambda: f"'{self}' > '{other}'")
return not _geq_decision(other, self, lambda: f"'{self}' > '{other}'")
def __lt__(self, other: DimSize):
return not self.ge(
other, lambda: f"'{self}' < '{other}'")
return not _geq_decision(self, other, lambda: f"'{self}' < '{other}'")
def divmod(self, divisor: _DimExpr) -> tuple[DimSize, int]:
"""
@ -890,41 +809,6 @@ class _DimExpr():
_DimExpr.from_operation(_DimAtom.MOD, self, divisor,
scope=self.scope))
def bounds(self) -> tuple[float, float]:
"""Returns the lower and upper bounds, or -+inf."""
lb = ub = self._coeffs.get(_DimMon(), 0) # The free coefficient
for mon, coeff in self.monomials():
if mon.degree == 0: continue # We already included the free coefficient
m_l, m_u = mon.bounds(self.scope)
assert m_l <= m_u and coeff != 0
item_l, item_u = coeff * m_l, coeff * m_u
lb = lb + min(item_l, item_u) # type: ignore
ub = ub + max(item_l, item_u) # type: ignore
bounds_from_constraints = self.scope._expr_bounds.get(self, (- np.inf, np.inf))
lb = max(lb, bounds_from_constraints[0]) # type: ignore
ub = min(ub, bounds_from_constraints[1]) # type: ignore
if lb != -np.inf or ub != np.inf:
return lb, ub
# Watch for special-case: ct*a - ct*mod(b, a) >= 1 when ct >= 0 and a >= 0
# TODO(necula): add more principled support for floordiv and mod
# For example, this will miss "1 + a - mod(b, a)"
for dec in _decompose_expr(self, _DimAtom.MOD,
with_exp=1, with_rest_monomial=1):
# E = factor*mod(op1, op2)^1 * 1 + rest_expr
if dec.rest_expr == - dec.factor * dec.operands[1]:
try:
if dec.operands[1] <= 0:
continue
except InconclusiveDimensionOperation:
continue
if dec.factor > 0:
return (-np.inf, -1)
else:
return (1, np.inf)
return lb, ub
def evaluate(self, env: DimVarEnv):
# Evaluates as a value of dtype=core.dim_value_dtype()
terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff))
@ -932,25 +816,25 @@ class _DimExpr():
return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0]
def max(self, other: DimSize) -> DimSize:
lb, ub = _ensure_poly(self - other, "max", self.scope).bounds()
lb, ub = _bounds_decision(self - other, _stop_early_for_geq0_leq0)
if 0 <= lb: return self
if ub <= 0: return other
return _DimExpr.from_operation(_DimAtom.MAX, self, other, scope=self.scope)
def rmax(self, other: DimSize) -> DimSize:
lb, ub = _ensure_poly(self - other, "max", self.scope).bounds()
lb, ub = _bounds_decision(self - other, _stop_early_for_geq0_leq0)
if 0 <= lb: return self
if ub <= 0: return other
return _DimExpr.from_operation(_DimAtom.MAX, other, self, scope=self.scope)
def min(self, other: DimSize) -> DimSize:
lb, ub = _ensure_poly(self - other, "min", self.scope).bounds()
lb, ub = _bounds_decision(self - other, _stop_early_for_geq0_leq0)
if 0 <= lb: return other
if ub <= 0: return self
return _DimExpr.from_operation(_DimAtom.MIN, self, other, scope=self.scope)
def rmin(self, other: DimSize) -> DimSize:
lb, ub = _ensure_poly(self - other, "min", self.scope).bounds()
lb, ub = _bounds_decision(self - other, _stop_early_for_geq0_leq0)
if 0 <= lb: return other
if ub <= 0: return self
return _DimExpr.from_operation(_DimAtom.MIN, other, self, scope=self.scope)
@ -981,6 +865,8 @@ def cmp_sequence(s1, s2, elem_cmp) -> int:
if len(s1) < l2: return -1
return 0
def _stop_early_for_geq0_leq0(lb, ub):
return 0 <= lb or ub <= 0
class SymbolicScope:
"""Indentifies a scope for symbolic expressions.
@ -1003,19 +889,20 @@ class SymbolicScope:
self._location_frame = source_info_util.user_frame(source_info_util.current())
self._explicit_constraints: list[tuple[_DimExpr, str]] = []
# Keep an efficient representation of
# the explicit constraints for use during reasoning.
self._monomial_bounds: dict[_DimMon, tuple[float, float]] = {}
self._expr_bounds: dict[_DimExpr, tuple[float, float]] = {}
constraints = self._parse_constraints(constraints_str)
for c, c_str in zip(constraints, constraints_str):
if (const := c.to_constant()) is not None:
if (const := _DimExpr.to_constant(c)) is not None:
if const < 0:
raise ValueError(f"Unsatisfiable explicit constraint: {c_str}")
continue
self._explicit_constraints.append((c, c_str))
self._process_constraint(c, c_str)
# We cache the _DimExpr.bounds calls. The result depends only on the
# explicit and implicit constraints, so it is safe to keep it in the
# scope.
self._bounds_cache: dict[tuple[_DimExpr,
Callable[[float, float], bool] | None],
tuple[float, float]] = {}
def __str__(self) -> str:
extras = []
@ -1051,39 +938,6 @@ class SymbolicScope:
f"Got {repr(constraints_str)}")
return tuple(parse_one(cs) for cs in constraints_str)
def _process_constraint(self, e: _DimExpr, e_str: str):
# Look for the special case m*mon + n >= 0.
# Then assert mon >= ceil(n / m) or mon <= floor(n / m)
n = m = 0
mon = None
nr_non_trivial_monomials = 0
for _mon, count in e.monomials():
if _mon.degree == 0:
n = count
continue
nr_non_trivial_monomials += 1
mon = _mon
m = count
if nr_non_trivial_monomials > 1:
# The general case, we just remember this constraint in _expr_bounds.
self._expr_bounds[e] = (0, np.inf)
return
# A single non-trivial monomial
assert isinstance(mon, _DimMon)
bounds = mon.bounds(self) # This considers default internal constraints, and
# previous external constraints
if m > 0: # mon >= ceil(-n / m)
ge = int(np.ceil(- n / m))
new_bounds = (max(ge, bounds[0]), bounds[1])
else: # mon <= floor(-n / m)
le = int(np.floor(-n / m))
new_bounds = (bounds[0], min(le, bounds[1]))
if new_bounds[0] > new_bounds[1]:
raise ValueError(f"Unsatisfiable constraints: {e_str}")
self._monomial_bounds[mon] = new_bounds
def _check_same_scope(self, other: _DimExpr,
when: str = "",
self_descr: str = " ",
@ -1096,6 +950,32 @@ class SymbolicScope:
f"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.")
# Set by the shape_poly_decision.py module.
# Calling convention:
# _geq_decision(e1, e2, cmp_str)
# where e1 and e2 are two expressions to be compared for greater-equal
# and `cmp_str()` is a string that describes the comparison for error
# messages.
# Returns: a boolean or raises InconclusiveDimensionOperation
# TODO: remove this trampoline when we refactor the sources
def _geq_decision_unimplemented(d1: DimSize, d2: DimSize,
cmp_str: Callable[[], str]) -> bool:
raise NotImplementedError("_geq_decision is uninitialized")
_geq_decision: Callable[[DimSize, DimSize, Callable[[], str]], bool] = _geq_decision_unimplemented
#
# Calling convention:
# _bounds_decision(e, stop_early)
# returns a tuple with the lower and upper bound of e.
# `stop_early(lb, ub)` can be called in an iterative process to decide if the
# current bounds are tight enough.
# TODO: remove this trampoline when we refactor the sources
def _bounds_decision_unimplemented(
d: DimSize,
stop_early: Callable[[float, float], bool] | None) -> tuple[float, float]:
raise NotImplementedError("_bounds_decision is uninitialized")
_bounds_decision: Callable[[DimSize, Callable[[float, float], bool] | None],
tuple[float, float]] = _bounds_decision_unimplemented
@dataclasses.dataclass
class _Decomposition:
"""Decomposition of an expression around an operation atom.
@ -1185,16 +1065,6 @@ def is_poly_dim(p: DimSize) -> bool:
dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int]
def definitely_geq_0(d: DimSize) -> bool:
"""Returns true when we can prove that d >=0, false otherwise.
Note that a result of False may mean that we cannot conclusively prove the
sign of `d`, it does not mean that `d < 0`.
"""
try:
return d >= 0
except InconclusiveDimensionOperation:
return False
def _einsum_contract_path(*operands, **kwargs):
"""Like opt_einsum.contract_path, with support for DimExpr shapes.

View File

@ -0,0 +1,412 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shape polymorphism support for deciding inequalities of symbolic dimensions.
"""
from __future__ import annotations
import collections
from collections.abc import Sequence
import itertools
import math
from typing import Callable
import numpy as np
from jax.experimental.export import _shape_poly
from jax.experimental.export._shape_poly import (
_DimExpr, _DimMon, _DimAtom,
SymbolicScope,
DimSize,
InconclusiveDimensionOperation,
)
def geq_decision(e1: DimSize, e2: DimSize,
cmp_str: Callable[[], str]) -> bool:
"""Implements `e1 >= e2`.
Args:
e1, e2: the expressions to compare for greater-equal
cmp_str: a callable such that `cmp_str()` describes the comparison
for error messages, e.g., "a <= b". Without this all comparisions would
be reported as ">=".
Raises InconclusiveDimensionOperation if the result is not conclusive.
"""
if isinstance(e1, _DimExpr):
scope = e1.scope
if isinstance(e2, _DimExpr):
scope._check_same_scope(e2, f"when comparing {cmp_str()}")
elif isinstance(e2, _DimExpr):
scope = e2.scope
else:
return int(e1) >= int(e2)
decision = _DecisionByElimination(scope)
lb, ub = decision.bounds(e1 - e2, _stop_early_for_geq0)
if lb >= 0:
return True
if ub < 0:
return False
if scope._explicit_constraints:
describe_scope = f"\nUsing symbolic scope {scope}"
else:
describe_scope = ""
raise InconclusiveDimensionOperation(
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
_shape_poly._geq_decision = geq_decision
def _stop_early_for_geq0(lb, ub):
return lb >= 0 or ub < 0
def bounds_decision(e: DimSize,
stop_early: Callable[[float, float], bool] | None) -> tuple[float, float]:
if not isinstance(e, _DimExpr):
return (int(e), int(e))
decision = _DecisionByElimination(e.scope)
return decision.bounds(e, stop_early)
_shape_poly._bounds_decision = bounds_decision
class _DecisionByElimination:
"""A decision procedure based on elimination of terms.
Given an expression `e = m_k*m + rest_e` for which we want to compute bounds,
and a constraint `c = m_c*m + rest_c >= 0`,
Let `e0 = abs(m_c)*e - sgn(m_c)*m_k*c`. (Note that we eliminated `m` from
`e0`, since `abs(m_c)*m_k = sgn(m_c)*m_k*m_c`.)
Since `c >= 0`,
if `sgn(m_c)*m_k > 0`:
then `abs(m_c)*e >= e0`, hence, `LB(e) >= ceil(LB(e0) / abs(m_c))`,
if `sgn(m_c)*m_k < 0`
then `abs(m_c)*e <= e0`, hence, `UB(e) <= floor(UB(e0) / abs(m_c))`,
"""
def __init__(self, scope: SymbolicScope):
self.scope = scope
self._processed_for_internal_constraints: set[_DimMon] = set()
# The other fields are for keeping an efficient representation of
# the explicit constraints.
self._term_bounds: dict[_DimMon, tuple[float, float]] = {}
# The _expr_constraints represents a set of constraints that are not
# just simple monomials. The set is represented as a mapping from a
# monomial "m" to tuples (k, c) where "c >= 0" represents a constraint that
# has "m" as the leading monomial with coefficient "k".
self._expr_constraints: dict[_DimMon, set[tuple[int, _DimExpr]]] = collections.defaultdict(set)
# TODO: find a way to reuse the state reflecting the explicit constraints
# We sort the constraints, so that the results of the heuristics do not
# depend on the order in which the user writes the constraints.
for c, c_str in sorted(scope._explicit_constraints,
key=lambda c: c[0]._monomials_sorted):
self.add_constraint(c, 0, c_str)
def add_constraint(self,
e1: _DimExpr | int | float,
e2: _DimExpr | int | float,
constraint_str: str | None = None):
"""Adds a constraint "e1 >= e2" to the internal state."""
if isinstance(e1, float):
if np.isinf(e1) and e1 >= 0: return
assert e1 == np.floor(e1)
e1 = int(e1)
if isinstance(e2, float):
if np.isinf(e2) and e2 <= 0: return
e2 = int(e2)
e = e1 if isinstance(e2, (int, float)) and e2 == 0 else e1 - e2
if constraint_str is None:
constraint_str = f"{e1} >= {e2}"
if (const := _DimExpr.to_constant(e)) is not None:
if const < 0:
raise ValueError(f"Unsatisfiable constraint: {constraint_str}")
return
assert isinstance(e, _DimExpr)
self._add_to_state(e, constraint_str)
combinations = self._combine_with_existing_constraints(e, constraint_str)
for a in combinations:
self._add_to_state(a, f"{a} >= 0")
def _combine_with_existing_constraints(self,
e: _DimExpr,
debug_str: str) -> set[_DimExpr]:
# This combines `e` with those constraints already present. The resulting
# constraints are not scanned for new internal constraints (because there
# are no new monomials), but they are also not combined further.
# TODO: this results in incompleteness, but it is probably a good
# compromise.
combinations: set[_DimExpr] = set()
def acc_combination(e: _DimExpr | int):
if (const := _DimExpr.to_constant(e)) is not None:
if const < 0:
raise ValueError(f"Unsatisfiable constraints: {debug_str}")
else:
combinations.add(e) # type: ignore
# First combine with the existing monomial constraints
for e_m, e_c in e.monomials():
if e_m.degree == 0: continue
m_lb, m_ub = self._term_bounds.get(e_m, (-np.inf, np.inf))
if e_c > 0:
if m_ub < np.inf:
e_minus_m = _DimExpr._merge_sorted_terms(e._monomials_sorted, 0, 1,
[(e_m, e_c)], 0, -1)
e_minus_m_ub = _DimExpr._merge_sorted_terms(e_minus_m, 0, 1,
[(_DimMon(), 1)], 0, e_c * int(m_ub))
acc_combination(_DimExpr(dict(e_minus_m_ub), e.scope))
else:
if m_lb > -np.inf:
e_minus_m = _DimExpr._merge_sorted_terms(e._monomials_sorted, 0, 1,
[(e_m, e_c)], 0, -1)
e_minus_m_lb = _DimExpr._merge_sorted_terms(e_minus_m, 0, 1,
[(_DimMon(), 1)], 0, e_c * int(m_lb))
acc_combination(_DimExpr(dict(e_minus_m_lb), e.scope))
for prev_constraints in self._expr_constraints.values():
for _, prev in prev_constraints:
# Compose "e" with "prev" if they have one monomial with different
# signs
for e_m, e_c in e.monomials():
if e_m.degree == 0: continue
prev_c = prev._coeffs.get(e_m)
if prev_c is not None and prev_c * e_c < 0:
new_constraint = _DimExpr(
dict(_DimExpr._merge_sorted_terms(e._monomials_sorted, 0, abs(prev_c),
prev._monomials_sorted, 0, abs(e_c))),
e.scope)
acc_combination(new_constraint)
break
return combinations
def _add_to_state(self, e: _DimExpr,
constraint_str: str):
"""Updates the internal state to reflect "e >= 0". """
assert _DimExpr.to_constant(e) is None
for m, m_c in e.monomials():
if m.degree == 0: continue
_add_internal_constraints(self, m, e.scope)
if (mon_factors := e.to_single_term()) is not None:
n, mon_c, mon = mon_factors
bounds = self._term_bounds.get(mon, (- np.inf, np.inf))
if mon_c > 0:
mon_ge = int(np.ceil(- n / mon_c))
new_bounds = (max(mon_ge, bounds[0]), bounds[1])
else:
le = int(np.floor(-n / mon_c))
new_bounds = (bounds[0], min(le, bounds[1]))
if new_bounds[0] > new_bounds[1]:
raise ValueError(f"Unsatisfiable constraint: {constraint_str}")
self._term_bounds[mon] = new_bounds
return
lead_m, lead_m_c = e.leading_term
self._expr_constraints[lead_m].add((lead_m_c, e))
def bounds(self, e: DimSize,
stop_early: Callable[[float, float], bool] | None
) -> tuple[float, float]:
"""Returns the lower and upper bounds, or -+inf.
See more details in `_shape_poly.bounds_decision`.
"""
if (const := _DimExpr.to_constant(e)) is not None:
return (const, const)
assert isinstance(e, _DimExpr)
cache_key = (e, stop_early)
if (res := self.scope._bounds_cache.get(cache_key)) is not None: return res
res = self._bounds_for_sorted_terms(e.scope, e._monomials_sorted, 0, stop_early)
self.scope._bounds_cache[cache_key] = res
return res
def _bounds_for_sorted_terms(self,
scope: SymbolicScope,
e: Sequence[tuple[_DimMon, int]],
i: int,
stop_early: Callable[[float, float], bool] | None) -> tuple[float, float]:
"""The lower and upper bounds of e[i:].
See comments about soundness and `cmp_with` in the `_shape_poly.bounds_decision`` method.
"""
if i >= len(e): return (0, 0)
m, m_c = e[i]
if len(m) == 0: # A constant
assert i == len(e) - 1 # Must be last
return (m_c, m_c)
_add_internal_constraints(self, m, scope)
lb = -np.inf
ub = np.inf
# Look among the term bounds
if m in self._term_bounds:
m_lb, m_ub = self._term_bounds.get(m, (- np.inf, np.inf))
rest_lb, rest_ub = self._bounds_for_sorted_terms(scope, e, i + 1, None)
if m_c > 0:
lb = max(lb, m_c * m_lb + rest_lb)
ub = min(ub, m_c * m_ub + rest_ub)
else:
lb = max(lb, m_c * m_ub + rest_lb)
ub = min(ub, m_c * m_lb + rest_ub)
if stop_early is not None and stop_early(lb, ub): return (lb, ub)
# Now look through the _expr_constraints
if m in self._expr_constraints:
for m_k, c in self._expr_constraints[m]:
# A complex expression. See comments from top of class.
sgn_m_k = 1 if m_k > 0 else -1
abs_m_k = m_k * sgn_m_k
# The recursive call has a smaller leading monomial, because we are only
# looking at the tail of e, and in c the largest monomial is m, and the
# merging will cancel the m.
rest = _DimExpr._merge_sorted_terms(e, i, abs_m_k,
c._monomials_sorted, 0, - sgn_m_k * m_c)
rest_lb, rest_ub = self._bounds_for_sorted_terms(scope, rest, 0, None)
if m_c / m_k > 0:
lb = max(lb, np.ceil(rest_lb / abs_m_k))
else:
ub = min(ub, np.floor(rest_ub / abs_m_k))
if stop_early is not None and stop_early(lb, ub): return (lb, ub)
# Now look for special rules for atoms
if (m_a := m.to_atom()) is not None:
if m_a.operation in [_DimAtom.MAX, _DimAtom.MIN]:
# m_c*MAX(op1, op2) + rest_e >= max(m_c * op1 + rest_e, m_c * op2 + rest_e)
# if m_c > 0. Similar rules for when m_c < 0 and for MIN.
op1, op2 = m_a.operands
rest1 = _DimExpr._merge_sorted_terms(e, i + 1, 1,
op1._monomials_sorted, 0, m_c)
rest2 = _DimExpr._merge_sorted_terms(e, i + 1, 1,
op2._monomials_sorted, 0, m_c)
rest1_lb, rest1_ub = self._bounds_for_sorted_terms(scope, rest1, 0, None)
rest2_lb, rest2_ub = self._bounds_for_sorted_terms(scope, rest2, 0, None)
like_max = (m_c > 0 if m_a.operation == _DimAtom.MAX else m_c < 0)
if like_max:
lb = max(lb, max(rest1_lb, rest2_lb))
ub = min(ub, max(rest1_ub, rest2_ub))
else:
lb = max(lb, min(rest1_lb, rest2_lb))
ub = min(ub, min(rest1_ub, rest2_ub))
if stop_early is not None and stop_early(lb, ub): return (lb, ub)
return lb, ub
def _add_internal_constraints(decision: _DecisionByElimination, m: _DimMon, scope: SymbolicScope):
"""Adds the internal constraints for the monomial `m`."""
if m in decision._processed_for_internal_constraints: return
decision._processed_for_internal_constraints.add(m)
m_e = _DimExpr.from_monomial(m, 1, scope) # m as a _DimExpr
a = m.to_atom()
if a is None:
# This is a multiplication of atoms. Try to compute bounds based on
# the bounds of the atoms.
bounds = []
for a, exp in m.items():
a_l, a_u = decision.bounds(_DimExpr.from_monomial(_DimMon.from_atom(a, 1),
1, scope), None)
assert a_l <= a_u
bounds.append((a_l ** exp, a_u ** exp))
candidate_bounds = [math.prod(atom_bounds)
for atom_bounds in itertools.product(*bounds)]
m_l = min(*candidate_bounds)
m_u = max(*candidate_bounds)
decision.add_constraint(m_e, m_l)
decision.add_constraint(m_u, m_e)
return
# It is an atom, is it a variable?
if (v := a.to_var()) is not None:
decision.add_constraint(m_e, 1) # v >= 1
return
if a.operation == _DimAtom.MOD:
op1, op2 = a.operands
op2_b_l, op2_b_u = decision.bounds(op2, _stop_early_for_geq0)
if op2_b_l > 0: # positive divisor
decision.add_constraint(m_e, 0) # m >= 0
decision.add_constraint(op2 - 1, m_e) # m <= op2 - 1
decision.add_constraint(op2_b_u - 1, m_e)
elif op2_b_u < 0: # negative divisor
decision.add_constraint(m_e, op2 + 1) # m >= op2 + 1
decision.add_constraint(m_e, op2_b_l + 1)
decision.add_constraint(0, m_e) # m <= 0
return
if a.operation == _DimAtom.FLOORDIV:
op1, op2 = a.operands
(op1_l, op1_u) = decision.bounds(op1, None)
(op2_l, op2_u) = decision.bounds(op2, None)
def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
# When either a or b are infinite, the results represent the limit
# of "a // b".
assert b != 0
if not np.isinf(b): # divisor b is finite
if not np.isinf(a):
return math.floor(a / b)
# a is infinite, b is finite
return -np.inf if (a >= 0) != (b >= 0) else np.inf
elif not np.isinf(a): # dividend a is finite and divisor b is infinite
return -1 if (a >= 0) != (b >= 0) else 0
else: # both dividend and divisor are infinite
return -np.inf if (a >= 0) != (b >= 0) else np.inf
# Same reasoning as for multiplication: the bounds are among the cross-product
# of the bounds.
candidate_bounds = [math_floor_with_inf(op1_l, op2_l),
math_floor_with_inf(op1_l, op2_u),
math_floor_with_inf(op1_u, op2_l),
math_floor_with_inf(op1_u, op2_u)]
m_l = min(*candidate_bounds)
m_u = max(*candidate_bounds)
decision.add_constraint(m_e, m_l)
decision.add_constraint(m_u, m_e)
if op2_l >= 0:
if decision.bounds(op1, _stop_early_for_geq0)[0] >= 0:
decision.add_constraint(m_e, 0)
mod_e = _DimExpr.from_operation(_DimAtom.MOD, op1, op2,
scope=scope)
combined = op2 * m_e + mod_e
decision.add_constraint(op1, combined)
decision.add_constraint(combined, op1)
return
if a.operation == _DimAtom.MAX:
op1, op2 = a.operands
op1_b_l, op1_b_u = decision.bounds(op1, None)
op2_b_l, op2_b_u = decision.bounds(op2, None)
decision.add_constraint(m_e, max(op1_b_l, op2_b_l))
decision.add_constraint(max(op1_b_u, op2_b_u), m_e)
decision.add_constraint(m_e, op1)
decision.add_constraint(m_e, op2)
if a.operation == _DimAtom.MIN:
op1, op2 = a.operands
op1_b_l, op1_b_u = decision.bounds(op1, None)
op2_b_l, op2_b_u = decision.bounds(op2, None)
decision.add_constraint(m_e, min(op1_b_l, op2_b_l))
decision.add_constraint(min(op1_b_u, op2_b_u), m_e)
decision.add_constraint(op1, m_e)
decision.add_constraint(op2, m_e)

View File

@ -37,6 +37,7 @@ import re
import jax
from jax.experimental import export
from jax.experimental.export import _shape_poly as shape_poly
from jax.experimental.export import _shape_poly_decision as shape_poly_decision
from jax.experimental import pjit
from jax import lax
import jax.numpy as jnp
@ -73,11 +74,13 @@ expect_error_associative_scan = (
def _expect(*, current, best):
return current
def _bounds(e: shape_poly.DimSize) -> tuple[float, float]:
if not isinstance(e, shape_poly._DimExpr):
e = shape_poly._ensure_poly(e, "_bounds", shape_poly.SymbolicScope())
return e.bounds()
if isinstance(e, shape_poly._DimExpr):
scope = e.scope
else:
scope = shape_poly.SymbolicScope()
decision = shape_poly_decision._DecisionByElimination(scope)
return decision.bounds(e, None)
def _assert_equal_bounds(tst: jtu.JaxTestCase,
e: shape_poly.DimSize,
@ -473,8 +476,7 @@ class DimExprTest(jtu.JaxTestCase):
self.assertEqual(_bounds(core.non_negative_dim(a - 5)), (0, np.inf))
self.assertEqual(_bounds(core.non_negative_dim(15 - a)), (0, 14))
self.assertEqual(_bounds(core.non_negative_dim(15 - a) // 3), (0, 4))
self.assertEqual(_bounds(a - core.non_negative_dim(a - 3)),
_expect(best=(1, 3), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(a - core.non_negative_dim(a - 3)), (1, 3))
def test_max_dim(self):
a, b, c, d = shape_poly.symbolic_shape("a, b, c, d")
@ -497,28 +499,23 @@ class DimExprTest(jtu.JaxTestCase):
self.assertGreaterEqual(core.max_dim(a, b), core.min_dim(a, b))
self.assertEqual(_bounds(b - core.max_dim(b - a, 0)),
_expect(best=(1, np.inf), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(a - core.min_dim(a, b)),
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(b - core.min_dim(a, b)),
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(b - core.max_dim(b - a, 0)), (1, np.inf))
self.assertEqual(_bounds(a - core.min_dim(a, b)), (0, np.inf))
self.assertEqual(_bounds(b - core.min_dim(a, b)), (0, np.inf))
self.assertEqual(_bounds(core.max_dim(a, b) - a),
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(core.max_dim(a, b) - b),
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(core.max_dim(a, b) - a), (0, np.inf))
self.assertEqual(_bounds(core.max_dim(a, b) - b), (0, np.inf))
self.assertEqual((0, 0), _bounds(core.max_dim(1 - b, 0)))
self.assertEqual(_bounds(core.max_dim(a, b) - a - core.max_dim(b - a, 0)),
_expect(best=(0, 0), current=(-np.inf, np.inf)))
_expect(best=(0, 0), current=(0, np.inf)))
self.assertEqual(_bounds(core.max_dim(a, b) + core.max_dim(c, d) -
core.max_dim(a + core.max_dim(c, d),
b + core.max_dim(c, d))),
_expect(best=(0, 0), current=(-np.inf, np.inf)))
_expect(best=(0, 0), current=(0, np.inf)))
self.assertEqual(_bounds(core.max_dim(a, b) + core.min_dim(c, d) -
core.max_dim(a + core.min_dim(c, d),
b + core.min_dim(c, d))),
_expect(best=(0, 0), current=(-np.inf, np.inf)))
_expect(best=(0, 0), current=(0, np.inf)))
self.sampled_assertion(core.max_dim(a, 5), core.max_dim, a, 5)
self.sampled_assertion(core.max_dim(5, a), core.max_dim, 5, a)
@ -758,10 +755,11 @@ class DimExprTest(jtu.JaxTestCase):
"m >= 0", "m <= 3"]
scope = shape_poly.SymbolicScope(assumptions)
a, d = shape_poly.symbolic_shape("a, d", scope=scope)
self.assertEqual(_bounds(a - 4*d),
_expect(best=(1, 3), current=(-np.inf, np.inf))) # a - 4d = m >= 1
self.assertEqual(_bounds(a - 2*d),
_expect(best=(3, np.inf), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(a - 4*d), (1, 3)) # a - 4d = m >= 1
self.assertEqual(_bounds(a - 2*d), (3, np.inf)) # a - 2d = m + 2d >= 3
# TODO: The incompleteness is due to the way we combine external constraints
self.assertEqual(_bounds(a),
_expect(best=(5, np.inf), current=(4, np.inf))) # a >= 4d + m >= 5
def test_constraints_errors(self):
with self.assertRaisesRegex(ValueError,
@ -806,14 +804,10 @@ class DimExprTest(jtu.JaxTestCase):
a, b, c = shape_poly.symbolic_shape(
"a, b, c",
constraints=("a + 2 <= b", "b <= a + 5", "a + b >= c"))
self.assertEqual(_bounds(b),
_expect(best=(3, np.inf), current=(1, np.inf)))
self.assertEqual(_bounds(b - a),
_expect(best=(2, 5), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(b - a - 7),
_expect(best=(-5, -2), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(c - 2*a - 5),
_expect(best=(-np.inf, 0), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(b), (3, np.inf))
self.assertEqual(_bounds(b - a), (2, 5))
self.assertEqual(_bounds(b - a - 7), (-5, -2))
self.assertEqual(_bounds(c - 2*a - 5), (-np.inf, 0))
def test_constraints_fractional(self):
a, = shape_poly.symbolic_shape("a",
@ -823,11 +817,11 @@ class DimExprTest(jtu.JaxTestCase):
@jtu.parameterized_filterable(
kwargs=[
dict(constraint="2*a + 3*b >= 10", exp="a + b - 2",
bounds=_expect(best=(2, np.inf), current=(0, np.inf))),
bounds=(2, np.inf)),
dict(constraint="-2*a + 3*b >= 10", exp="a + 2*b",
bounds=_expect(best=(9, np.inf), current=(3, np.inf))),
bounds=(9, np.inf)),
dict(constraint="-2*a + -3*b >= -10", exp="-1*a + 2*b",
bounds=_expect(best=(-1, 3), current=(-np.inf, np.inf))),
bounds=(-1, 3)),
dict(constraint="2*a + -3*b >= 10", exp="-1*a + 2*b", bounds=(-np.inf, np.inf)),
]
)
@ -872,8 +866,7 @@ class DimExprTest(jtu.JaxTestCase):
"a4 >= a5",
)
)
self.assertEqual(_bounds(a1 - a5),
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(a1 - a5), (0, np.inf))
def test_constraints_rounding_monomials(self):
a1, a2 = shape_poly.symbolic_shape(
@ -893,10 +886,8 @@ class DimExprTest(jtu.JaxTestCase):
"2*a1 >= 2*a2 + 3", "-2*a1 + 2*a2 >= -5"
)
)
self.assertEqual(_bounds(a1 - a2 - 2),
_expect(best=(0, 0), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(a2 - a1 + 2),
_expect(best=(0, 0), current=(-np.inf, np.inf)))
self.assertEqual(_bounds(a1 - a2 - 2), (0, 0))
self.assertEqual(_bounds(a2 - a1 + 2), (0, 0))
def test_constraints_unsat_trivial(self):
with self.assertRaisesRegex(ValueError,
@ -905,10 +896,8 @@ class DimExprTest(jtu.JaxTestCase):
"a1", constraints=("a1 >= a1 + 1",))
def test_constraints_unsat_monomials(self):
with self.assertRaisesRegex(_expect(best=ValueError,
current=shape_poly.InconclusiveDimensionOperation),
_expect(best="Unsatisfiable constraint",
current="inconclusive")):
with self.assertRaisesRegex(ValueError,
"Unsatisfiable constraint"):
a1, a2, *_ = shape_poly.symbolic_shape(
"a1, a2, a3, a4",
constraints=(
@ -919,10 +908,8 @@ class DimExprTest(jtu.JaxTestCase):
a1 >= a2
def test_constraints_unsat_not_monomials(self):
with self.assertRaisesRegex(_expect(best=ValueError,
current=shape_poly.InconclusiveDimensionOperation),
_expect(best="Unsatisfiable constraint",
current="inconclusive")):
with self.assertRaisesRegex(ValueError,
"Unsatisfiable constraint"):
a1, a2, a3, a4 = shape_poly.symbolic_shape(
"a1, a2, a3, a4",
constraints=(