mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Add a decision procedure for inequalities.
In a previous PR (#19285) we added support for inequality constraints on symbolic expressions, but with limited support for the cases when a constraint contains more than one term, e.g., "a >= b". Here we add a simple decision procedure for such inequalities, using an elimination algorithm based on the following properties: * if we have two constraints "a + b >= 0" and "-a + c >= 0" we can eliminate "a" and infer the derived constraint "b + c >= 0". * the lower bound of "a + c", in the presence of a constraint "a >= b", is greater-or-equal to the lower bound of "b + c". The above rules can be generalized to cases when the eliminated terms have coefficients different than 1. This algorithm is exponential in the number of constraints, but we implement a limited form. When we add a constraint we combine it with already added constraints, but the result of the combination is not combined further. This is sufficient for the cases we have encountered so far. The termination of the algorithm is ensured by always eliminating the largest (leading) term, ensuring that the result of a combination of constraints has a smaller leading term. With this added power for reasoning, we can retire the previous heuristics for handling "min", "max", "floordiv" and "mod" and replace them with the addition of some implicit constraints for them, e.g., "max(a, b) >= a", etc., and then let the decision procedure do its job. We moved the logic for deciding inequalities to a new file: _shape_poly_decision.py.
This commit is contained in:
parent
2518a6f6d2
commit
e20afac46a
@ -34,6 +34,7 @@ py_library(
|
||||
"_export.py",
|
||||
"_serialization.py",
|
||||
"_shape_poly.py",
|
||||
"_shape_poly_decision.py",
|
||||
"serialization_generated.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
|
@ -35,3 +35,4 @@ from jax.experimental.export._serialization import (
|
||||
serialize,
|
||||
deserialize,
|
||||
)
|
||||
from jax.experimental.export import _shape_poly_decision
|
||||
|
@ -25,7 +25,7 @@ This enables many JAX programs to be traced with symbolic dimensions
|
||||
in some dimensions. A priority has been to enable the batch
|
||||
dimension in neural network examples to be polymorphic.
|
||||
|
||||
This was built initially for jax2tf, but it is now customizable to be
|
||||
This was built initially for jax2tf, but it is now
|
||||
independent of TF. The best documentation at the moment is in the
|
||||
jax2tf.convert docstring, and the
|
||||
[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
|
||||
@ -210,48 +210,6 @@ class _DimAtom:
|
||||
"""Lexicographic comparison"""
|
||||
return self._syntactic_cmp(other) >= 0
|
||||
|
||||
def bounds(self) -> tuple[float, float]:
|
||||
"""Returns the lower and upper bounds, or -+ inf."""
|
||||
if self.var is not None:
|
||||
return (1, np.inf) # variables are assumed to be >= 1
|
||||
opnd_bounds = [opnd.bounds() for opnd in self.operands]
|
||||
if self.operation == _DimAtom.FLOORDIV: # a // b
|
||||
(a_l, a_u), (b_l, b_u) = opnd_bounds
|
||||
def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
|
||||
assert b != 0
|
||||
if not np.isinf(b): # divisor is finite
|
||||
return math.floor(a / b) if not np.isinf(a) else -np.inf if (a >= 0) != (b >= 0) else np.inf
|
||||
elif not np.isinf(a): # dividend is finite and divisor is infinite
|
||||
return -1 if (a >= 0) != (b >= 0) else 0
|
||||
else: # both dividend and divisor are infinite
|
||||
return -np.inf if (a >= 0) != (b >= 0) else np.inf
|
||||
|
||||
# Same reasoning as for multiplication: the bounds are among the cross-product
|
||||
# of the bounds.
|
||||
bound_candidates = [math_floor_with_inf(a_l, b_l), math_floor_with_inf(a_l, b_u),
|
||||
math_floor_with_inf(a_u, b_l), math_floor_with_inf(a_u, b_u)]
|
||||
return (min(*bound_candidates), max(*bound_candidates))
|
||||
|
||||
elif self.operation == _DimAtom.MOD:
|
||||
_, (b_l, b_u) = opnd_bounds
|
||||
if b_l > 0: # positive divisor
|
||||
return (0, b_u - 1)
|
||||
elif b_u < 0: # negative divisor
|
||||
return (b_l + 1, 0)
|
||||
else:
|
||||
return (-np.inf, np.inf)
|
||||
|
||||
elif self.operation == _DimAtom.MAX:
|
||||
(a_l, a_h), (b_l, b_h) = opnd_bounds
|
||||
return (max(a_l, b_l), max(a_h, b_h))
|
||||
|
||||
elif self.operation == _DimAtom.MIN:
|
||||
(a_l, a_h), (b_l, b_h) = opnd_bounds
|
||||
return (min(a_l, b_l), min(a_h, b_h))
|
||||
|
||||
else:
|
||||
assert False
|
||||
|
||||
def evaluate(self, env: DimVarEnv):
|
||||
if self.var is not None:
|
||||
try:
|
||||
@ -401,21 +359,6 @@ class _DimMon(dict):
|
||||
elif diff > 0: d[key] = diff
|
||||
return _DimMon(d)
|
||||
|
||||
def bounds(self, scope: SymbolicScope) -> tuple[float, float]:
|
||||
"""Returns the lower and upper bounds, or -+inf."""
|
||||
# The bounds of a product are among the product of bounds.
|
||||
bounds = []
|
||||
for a, exp in self.items():
|
||||
a_l, a_u = a.bounds()
|
||||
assert a_l <= a_u
|
||||
bounds.append((a_l ** exp, a_u ** exp))
|
||||
|
||||
candidates = [math.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)]
|
||||
calculated_bounds = (min(*candidates), max(*candidates)) # type: ignore
|
||||
constrained_bounds = scope._monomial_bounds.get(self, (- np.inf, np.inf))
|
||||
return (max(calculated_bounds[0], constrained_bounds[0]),
|
||||
min(calculated_bounds[1], constrained_bounds[1]))
|
||||
|
||||
def evaluate(self, env: DimVarEnv):
|
||||
prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1)
|
||||
def pow_opt(v, p: int):
|
||||
@ -442,7 +385,7 @@ class _DimExpr():
|
||||
scope: SymbolicScope):
|
||||
# Do not construct _DimExpr directly, unless you are sure that coeffs is
|
||||
# normalized; Use _DimExpr.normalize.
|
||||
# Takes ownership of coeffs
|
||||
# Takes ownership of coeffs.
|
||||
self._coeffs = coeffs or {_DimMon(): 0}
|
||||
self._scope = scope
|
||||
self._monomials_sorted = tuple(sorted(self._coeffs.items(), reverse=True))
|
||||
@ -462,12 +405,31 @@ class _DimExpr():
|
||||
|
||||
@property
|
||||
def leading_term(self) -> tuple[_DimMon, int]:
|
||||
"""Returns the highest degree term that comes first lexicographically."""
|
||||
"""Returns the highest degree term that comes last lexicographically."""
|
||||
return self._monomials_sorted[0]
|
||||
|
||||
def to_single_term(self) -> tuple[int, int, _DimMon] | None:
  """Extracts the single term: k + c * term.

  Returns:
    `(k, c, term)` when the expression consists of an optional constant `k`
    (0 when absent) plus exactly one non-constant monomial `term` with
    coefficient `c`. Returns None if the expression is not a single term,
    i.e., it has several non-constant monomials or none at all.
  """
  k = 0
  single = None  # The (monomial, coefficient) of the only non-constant term.
  for m, c in self.monomials():
    if m.degree == 0:
      k = c  # The free (constant) coefficient.
    elif single is None:
      single = (m, c)
    else:
      return None  # A second non-constant monomial: not a single term.
  if single is None:
    # A constant-only expression has no term to extract. Report it as None
    # rather than through an `assert`, which is stripped under `python -O`
    # and would let a `(k, 0, None)` result escape.
    return None
  mon, mon_c = single
  return (k, mon_c, mon)
|
||||
|
||||
@classmethod
|
||||
def _add_coeffs(cls, coeffs: dict[_DimMon, int], mon: _DimMon, coeff: int):
|
||||
"""Do `coeffs[mon] += coeff` but remove 0 coefficients."""
|
||||
"""Computes `coeffs[mon] += coeff` while removing 0 coefficients."""
|
||||
old_c = coeffs.get(mon)
|
||||
if old_c is None:
|
||||
if coeff != 0: coeffs[mon] = coeff
|
||||
@ -504,24 +466,6 @@ class _DimExpr():
|
||||
else:
|
||||
return int(free_const)
|
||||
|
||||
@classmethod
|
||||
def normalize_floordiv_times_divisor(cls, coeffs: dict[_DimMon, int],
|
||||
scope: SymbolicScope) -> DimSize:
|
||||
# Look for floordiv(E, M) * M and turn into E - mod(E, M). This comes
|
||||
# up when handling strided convolution.
|
||||
for dec in _decompose_expr(_DimExpr(coeffs, scope), _DimAtom.FLOORDIV,
|
||||
with_exp=1):
|
||||
# e = factor * floordiv(operands)^exp * rest_monomial + rest_expr
|
||||
if dec.rest_monomial == 1 and dec.factor == 1:
|
||||
continue
|
||||
m_trimmed, m_remainder = divmod(dec.factor * dec.rest_monomial, dec.operands[1])
|
||||
if m_remainder == 0:
|
||||
return m_trimmed * (
|
||||
dec.operands[0] -
|
||||
_DimExpr.from_operation(_DimAtom.MOD, *dec.operands,
|
||||
scope=scope)) + dec.rest_expr
|
||||
return _DimExpr.normalize(coeffs, scope)
|
||||
|
||||
@classmethod
|
||||
def from_constant(cls, c: int, scope: SymbolicScope):
|
||||
return _DimExpr({_DimMon(): op.index(c)}, scope)
|
||||
@ -570,15 +514,18 @@ class _DimExpr():
|
||||
mon = self.to_atom()
|
||||
return mon.to_var() if mon is not None else None
|
||||
|
||||
def to_constant(self) -> int | None:
|
||||
@classmethod
|
||||
def to_constant(cls, e: DimSize) -> int | None:
|
||||
"""Extract the constant from a symbolic expression.
|
||||
Returns None if the expression is not a single constant."""
|
||||
m, m_c = self.leading_term
|
||||
if not isinstance(e, _DimExpr):
|
||||
return int(e)
|
||||
m, m_c = e.leading_term
|
||||
return m_c if m.degree == 0 else None
|
||||
|
||||
@property
|
||||
def is_constant(self):
|
||||
return self.to_constant() is not None
|
||||
return _DimExpr.to_constant(self) is not None
|
||||
|
||||
def get_vars(self) -> set[str]:
|
||||
"""The variables that appear in a symbolic dimension."""
|
||||
@ -587,6 +534,39 @@ class _DimExpr():
|
||||
acc.update(mon.get_vars())
|
||||
return acc
|
||||
|
||||
@classmethod
def _merge_sorted_terms(
    cls,
    e1: Sequence[tuple[_DimMon, int]], i1: int, f1: int,
    e2: Sequence[tuple[_DimMon, int]], i2: int, f2: int) -> Sequence[tuple[_DimMon, int]]:
  """Computes e1[i1:] * f1 + e2[i2:] * f2.

  e1, e2, and the result are sorted with largest term first.
  This is an optimization for a common operation. The unoptimized code would
  compute each subexpression in turn.

  Every term whose resulting coefficient is 0 is dropped, whether the 0 comes
  from two terms cancelling or from a 0 factor or coefficient, so the result
  never carries zero-coefficient entries.

  Args:
    e1, e2: sequences of (monomial, coefficient) pairs, largest monomial
      first according to `_syntactic_cmp`.
    i1, i2: the index in `e1`, respectively `e2`, from which to start.
    f1, f2: the integer factors by which the coefficients are scaled.

  Returns:
    the merged list of (monomial, coefficient) pairs.
  """
  len1, len2 = len(e1), len(e2)
  acc = []

  def emit(mon, coeff):
    # Keep the result free of zero coefficients.
    if coeff != 0:
      acc.append((mon, coeff))

  while i1 < len1 and i2 < len2:
    m1, m1_c = e1[i1]
    m2, m2_c = e2[i2]
    cmp = m1._syntactic_cmp(m2)  # Pick the largest monomial
    if cmp < 0:
      emit(m2, m2_c * f2)
      i2 += 1
    elif cmp > 0:
      emit(m1, m1_c * f1)
      i1 += 1
    else:  # They are equal, combine them
      emit(m1, m1_c * f1 + m2_c * f2)
      i1 += 1
      i2 += 1

  # At most one of the two inputs still has terms left.
  for m1, m1_c in itertools.islice(e1, i1, len1):
    emit(m1, m1_c * f1)
  for m2, m2_c in itertools.islice(e2, i2, len2):
    emit(m2, m2_c * f2)
  return acc
|
||||
|
||||
def _syntactic_cmp(self, other: _DimExpr) -> int:
|
||||
"""Returns -1 if self < other, 0 if self == other, 1 if self > other.
|
||||
The comparison is done lexicographically (syntactic), to be used for sorting.
|
||||
@ -594,6 +574,7 @@ class _DimExpr():
|
||||
"""
|
||||
s_mons = self._monomials_sorted
|
||||
o_mons = other._monomials_sorted
|
||||
if c := cmp_comparable(self._size, other._size): return c
|
||||
def cmp_mon(s_mon: tuple[_DimMon, int], o_mon: tuple[_DimMon, int]) -> int:
|
||||
if c := s_mon[0]._syntactic_cmp(o_mon[0]): return c
|
||||
return cmp_comparable(s_mon[1], o_mon[1])
|
||||
@ -618,79 +599,6 @@ class _DimExpr():
|
||||
|
||||
return diff == 0
|
||||
|
||||
def ge(self, other: DimSize,
|
||||
cmp_str: Callable[[], str] | None = None) -> bool:
|
||||
"""Implements `self >= other`.
|
||||
|
||||
Raises InconclusiveDimensionOperation if the result is not conclusive.
|
||||
Uses `cmp_str()` as a description of the comparison in the exception
|
||||
string.
|
||||
"""
|
||||
self_minus_other = _ensure_poly(self - other, "ge", self.scope)
|
||||
lb, ub = self_minus_other.bounds()
|
||||
if lb >= 0:
|
||||
return True
|
||||
if ub < 0:
|
||||
return False
|
||||
# Attempt to handle max. For the decomposition
|
||||
# e = factor * max(op1, op2) + rest_expr
|
||||
# use the rule
|
||||
# e >= 0 IF (factor > 0 AND
|
||||
# (factor * op1 + rest_expr >= 0 OR
|
||||
# factor * op2 + rest_expr >= 0))
|
||||
# OR
|
||||
# (factor < 0 AND
|
||||
# (factor * op1 + rest_expr >= 0 AND
|
||||
# factor * op2 + rest_expr >= 0))
|
||||
for dec in _decompose_expr(self_minus_other, _DimAtom.MAX,
|
||||
with_exp=1, with_rest_monomial=1):
|
||||
op1, op2 = dec.operands
|
||||
if dec.factor > 0:
|
||||
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) or
|
||||
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
|
||||
return True
|
||||
else:
|
||||
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) and
|
||||
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
|
||||
return True
|
||||
|
||||
# Attempt to handle min. For the decomposition
|
||||
# e = factor * min(op1, op2) + rest_expr
|
||||
# use the same rule as for
|
||||
# e = max(factor * op1, factor * op2) + rest_expr
|
||||
for dec in _decompose_expr(self_minus_other, _DimAtom.MIN,
|
||||
with_exp=1, with_rest_monomial=1):
|
||||
op1, op2 = dec.operands
|
||||
if dec.factor > 0:
|
||||
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) and
|
||||
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
|
||||
return True
|
||||
else:
|
||||
if (definitely_geq_0(dec.factor * op1 + dec.rest_expr) or
|
||||
definitely_geq_0(dec.factor * op2 + dec.rest_expr)):
|
||||
return True
|
||||
|
||||
# Attempt to handle floordiv >= 0
|
||||
for dec in _decompose_expr(self_minus_other, _DimAtom.FLOORDIV,
|
||||
with_exp=1, with_rest_monomial=1,
|
||||
with_rest_expr=0):
|
||||
# e = factor * floordiv(op0, op1)^1 * 1 + 0
|
||||
if dec.factor > 0:
|
||||
if definitely_geq_0(dec.operands[0]) and definitely_geq_0(dec.operands[1]):
|
||||
return True
|
||||
|
||||
if cmp_str is not None:
|
||||
msg = cmp_str()
|
||||
else:
|
||||
msg = f"'{self}' >= '{other}'"
|
||||
|
||||
if self.scope._explicit_constraints:
|
||||
describe_scope = f"\nUsing symbolic scope {self.scope}"
|
||||
else:
|
||||
describe_scope = ""
|
||||
raise InconclusiveDimensionOperation(
|
||||
f"Symbolic dimension comparison {msg} is inconclusive.{describe_scope}")
|
||||
|
||||
def __hash__(self):
|
||||
return self._hash
|
||||
|
||||
@ -719,7 +627,7 @@ class _DimExpr():
|
||||
coeffs = self._coeffs.copy()
|
||||
for mon, coeff in other.monomials():
|
||||
_DimExpr._add_coeffs(coeffs, mon, coeff)
|
||||
return _DimExpr.normalize_floordiv_times_divisor(coeffs, self.scope)
|
||||
return _DimExpr.normalize(coeffs, self.scope)
|
||||
|
||||
def __radd__(self, other):
|
||||
if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
|
||||
@ -749,7 +657,7 @@ class _DimExpr():
|
||||
for mon2, coeff2 in other.monomials():
|
||||
mon = mon1.mul(mon2)
|
||||
_DimExpr._add_coeffs(coeffs, mon, coeff1 * coeff2)
|
||||
return _DimExpr.normalize_floordiv_times_divisor(coeffs, self.scope)
|
||||
return _DimExpr.normalize(coeffs, self.scope)
|
||||
|
||||
def __rmul__(self, other):
|
||||
if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
|
||||
@ -815,28 +723,39 @@ class _DimExpr():
|
||||
return False
|
||||
elif not core.is_constant_dim(other):
|
||||
return False
|
||||
else:
|
||||
other = _ensure_poly(other, "eq", self.scope)
|
||||
return self.eq(other)
|
||||
|
||||
# Equality is used very frequently because expressions are cached. We could
|
||||
# implement a more precise version based on `(self - other).bounds() = (0, 0)`
|
||||
# but that would be too expensive. It would also have the unfortunate drawback
|
||||
# that we cannot then cache `e.bounds()` because hashing invokes equality
|
||||
# which would lead to infinite recursion.
|
||||
diff = self - other
|
||||
|
||||
# We look for `self - other == k`, and we rely on the fact that when we
|
||||
# normalize _DimExpr that represent integers as ints.
|
||||
if is_symbolic_dim(diff):
|
||||
# Here we really ought to raise InconclusiveDimensionOperation, but __eq__
|
||||
# cannot raise exceptions, because it is used indirectly when hashing.
|
||||
# So, we say that the expressions are disequal, which is really unsound.
|
||||
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported
|
||||
return False
|
||||
|
||||
return diff == 0
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __ge__(self, other: DimSize) -> bool:
|
||||
return self.ge(
|
||||
other, lambda: f"'{self}' >= '{other}'")
|
||||
return _geq_decision(self, other, lambda: f"'{self}' >= '{other}'")
|
||||
|
||||
def __le__(self, other: DimSize):
|
||||
return _ensure_poly(other, "le", self.scope).ge(
|
||||
self, lambda: f"'{self}' <= '{other}'")
|
||||
return _geq_decision(other, self, lambda: f"'{self}' <= '{other}'")
|
||||
|
||||
def __gt__(self, other: DimSize):
|
||||
return not _ensure_poly(other, "le", self.scope).ge(
|
||||
self, lambda: f"'{self}' > '{other}'")
|
||||
return not _geq_decision(other, self, lambda: f"'{self}' > '{other}'")
|
||||
|
||||
def __lt__(self, other: DimSize):
|
||||
return not self.ge(
|
||||
other, lambda: f"'{self}' < '{other}'")
|
||||
return not _geq_decision(self, other, lambda: f"'{self}' < '{other}'")
|
||||
|
||||
def divmod(self, divisor: _DimExpr) -> tuple[DimSize, int]:
|
||||
"""
|
||||
@ -890,41 +809,6 @@ class _DimExpr():
|
||||
_DimExpr.from_operation(_DimAtom.MOD, self, divisor,
|
||||
scope=self.scope))
|
||||
|
||||
def bounds(self) -> tuple[float, float]:
|
||||
"""Returns the lower and upper bounds, or -+inf."""
|
||||
lb = ub = self._coeffs.get(_DimMon(), 0) # The free coefficient
|
||||
for mon, coeff in self.monomials():
|
||||
if mon.degree == 0: continue # We already included the free coefficient
|
||||
m_l, m_u = mon.bounds(self.scope)
|
||||
assert m_l <= m_u and coeff != 0
|
||||
item_l, item_u = coeff * m_l, coeff * m_u
|
||||
lb = lb + min(item_l, item_u) # type: ignore
|
||||
ub = ub + max(item_l, item_u) # type: ignore
|
||||
|
||||
bounds_from_constraints = self.scope._expr_bounds.get(self, (- np.inf, np.inf))
|
||||
lb = max(lb, bounds_from_constraints[0]) # type: ignore
|
||||
ub = min(ub, bounds_from_constraints[1]) # type: ignore
|
||||
if lb != -np.inf or ub != np.inf:
|
||||
return lb, ub
|
||||
# Watch for special-case: ct*a - ct*mod(b, a) >= 1 when ct >= 0 and a >= 0
|
||||
# TODO(necula): add more principled support for floordiv and mod
|
||||
# For example, this will miss "1 + a - mod(b, a)"
|
||||
for dec in _decompose_expr(self, _DimAtom.MOD,
|
||||
with_exp=1, with_rest_monomial=1):
|
||||
# E = factor*mod(op1, op2)^1 * 1 + rest_expr
|
||||
if dec.rest_expr == - dec.factor * dec.operands[1]:
|
||||
try:
|
||||
if dec.operands[1] <= 0:
|
||||
continue
|
||||
except InconclusiveDimensionOperation:
|
||||
continue
|
||||
if dec.factor > 0:
|
||||
return (-np.inf, -1)
|
||||
else:
|
||||
return (1, np.inf)
|
||||
|
||||
return lb, ub
|
||||
|
||||
def evaluate(self, env: DimVarEnv):
|
||||
# Evaluates as a value of dtype=core.dim_value_dtype()
|
||||
terms = [_evaluate_multiply(mon.evaluate(env), core.dim_constant(coeff))
|
||||
@ -932,25 +816,25 @@ class _DimExpr():
|
||||
return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0]
|
||||
|
||||
def max(self, other: DimSize) -> DimSize:
|
||||
lb, ub = _ensure_poly(self - other, "max", self.scope).bounds()
|
||||
lb, ub = _bounds_decision(self - other, _stop_early_for_geq0_leq0)
|
||||
if 0 <= lb: return self
|
||||
if ub <= 0: return other
|
||||
return _DimExpr.from_operation(_DimAtom.MAX, self, other, scope=self.scope)
|
||||
|
||||
def rmax(self, other: DimSize) -> DimSize:
|
||||
lb, ub = _ensure_poly(self - other, "max", self.scope).bounds()
|
||||
lb, ub = _bounds_decision(self - other, _stop_early_for_geq0_leq0)
|
||||
if 0 <= lb: return self
|
||||
if ub <= 0: return other
|
||||
return _DimExpr.from_operation(_DimAtom.MAX, other, self, scope=self.scope)
|
||||
|
||||
def min(self, other: DimSize) -> DimSize:
|
||||
lb, ub = _ensure_poly(self - other, "min", self.scope).bounds()
|
||||
lb, ub = _bounds_decision(self - other, _stop_early_for_geq0_leq0)
|
||||
if 0 <= lb: return other
|
||||
if ub <= 0: return self
|
||||
return _DimExpr.from_operation(_DimAtom.MIN, self, other, scope=self.scope)
|
||||
|
||||
def rmin(self, other: DimSize) -> DimSize:
|
||||
lb, ub = _ensure_poly(self - other, "min", self.scope).bounds()
|
||||
lb, ub = _bounds_decision(self - other, _stop_early_for_geq0_leq0)
|
||||
if 0 <= lb: return other
|
||||
if ub <= 0: return self
|
||||
return _DimExpr.from_operation(_DimAtom.MIN, other, self, scope=self.scope)
|
||||
@ -981,6 +865,8 @@ def cmp_sequence(s1, s2, elem_cmp) -> int:
|
||||
if len(s1) < l2: return -1
|
||||
return 0
|
||||
|
||||
def _stop_early_for_geq0_leq0(lb, ub):
  """Returns whether the bounds `(lb, ub)` already show a value is >= 0 or <= 0."""
  sign_is_settled = lb >= 0 or ub <= 0
  return sign_is_settled
|
||||
|
||||
class SymbolicScope:
|
||||
"""Indentifies a scope for symbolic expressions.
|
||||
@ -1003,19 +889,20 @@ class SymbolicScope:
|
||||
self._location_frame = source_info_util.user_frame(source_info_util.current())
|
||||
self._explicit_constraints: list[tuple[_DimExpr, str]] = []
|
||||
|
||||
# Keep an efficient representation of
|
||||
# the explicit constraints for use during reasoning.
|
||||
self._monomial_bounds: dict[_DimMon, tuple[float, float]] = {}
|
||||
self._expr_bounds: dict[_DimExpr, tuple[float, float]] = {}
|
||||
|
||||
constraints = self._parse_constraints(constraints_str)
|
||||
for c, c_str in zip(constraints, constraints_str):
|
||||
if (const := c.to_constant()) is not None:
|
||||
if (const := _DimExpr.to_constant(c)) is not None:
|
||||
if const < 0:
|
||||
raise ValueError(f"Unsatisfiable explicit constraint: {c_str}")
|
||||
continue
|
||||
self._explicit_constraints.append((c, c_str))
|
||||
self._process_constraint(c, c_str)
|
||||
|
||||
# We cache the _DimExpr.bounds calls. The result depends only on the
|
||||
# explicit and implicit constraints, so it is safe to keep it in the
|
||||
# scope.
|
||||
self._bounds_cache: dict[tuple[_DimExpr,
|
||||
Callable[[float, float], bool] | None],
|
||||
tuple[float, float]] = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
extras = []
|
||||
@ -1051,39 +938,6 @@ class SymbolicScope:
|
||||
f"Got {repr(constraints_str)}")
|
||||
return tuple(parse_one(cs) for cs in constraints_str)
|
||||
|
||||
def _process_constraint(self, e: _DimExpr, e_str: str):
|
||||
# Look for the special case m*mon + n >= 0.
|
||||
# Then assert mon >= ceil(n / m) or mon <= floor(n / m)
|
||||
n = m = 0
|
||||
mon = None
|
||||
nr_non_trivial_monomials = 0
|
||||
for _mon, count in e.monomials():
|
||||
if _mon.degree == 0:
|
||||
n = count
|
||||
continue
|
||||
nr_non_trivial_monomials += 1
|
||||
mon = _mon
|
||||
m = count
|
||||
|
||||
if nr_non_trivial_monomials > 1:
|
||||
# The general case, we just remember this constraint in _expr_bounds.
|
||||
self._expr_bounds[e] = (0, np.inf)
|
||||
return
|
||||
|
||||
# A single non-trivial monomial
|
||||
assert isinstance(mon, _DimMon)
|
||||
bounds = mon.bounds(self) # This considers default internal constraints, and
|
||||
# previous external constraints
|
||||
if m > 0: # mon >= ceil(-n / m)
|
||||
ge = int(np.ceil(- n / m))
|
||||
new_bounds = (max(ge, bounds[0]), bounds[1])
|
||||
else: # mon <= floor(-n / m)
|
||||
le = int(np.floor(-n / m))
|
||||
new_bounds = (bounds[0], min(le, bounds[1]))
|
||||
if new_bounds[0] > new_bounds[1]:
|
||||
raise ValueError(f"Unsatisfiable constraints: {e_str}")
|
||||
self._monomial_bounds[mon] = new_bounds
|
||||
|
||||
def _check_same_scope(self, other: _DimExpr,
|
||||
when: str = "",
|
||||
self_descr: str = " ",
|
||||
@ -1096,6 +950,32 @@ class SymbolicScope:
|
||||
f"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.")
|
||||
|
||||
|
||||
# Set by the _shape_poly_decision.py module.
|
||||
# Calling convention:
|
||||
# _geq_decision(e1, e2, cmp_str)
|
||||
# where e1 and e2 are two expressions to be compared for greater-equal
|
||||
# and `cmp_str()` is a string that describes the comparison for error
|
||||
# messages.
|
||||
# Returns: a boolean or raises InconclusiveDimensionOperation
|
||||
# TODO: remove this trampoline when we refactor the sources
|
||||
def _geq_decision_unimplemented(d1: DimSize, d2: DimSize,
                                cmp_str: Callable[[], str]) -> bool:
  """Placeholder for the `d1 >= d2` decision procedure; always raises.

  Raises:
    NotImplementedError: unconditionally.
  """
  reason = "_geq_decision is uninitialized"
  raise NotImplementedError(reason)
|
||||
_geq_decision: Callable[[DimSize, DimSize, Callable[[], str]], bool] = _geq_decision_unimplemented
|
||||
#
|
||||
# Calling convention:
|
||||
# _bounds_decision(e, stop_early)
|
||||
# returns a tuple with the lower and upper bound of e.
|
||||
# `stop_early(lb, ub)` can be called in an iterative process to decide if the
|
||||
# current bounds are tight enough.
|
||||
# TODO: remove this trampoline when we refactor the sources
|
||||
def _bounds_decision_unimplemented(
    d: DimSize,
    stop_early: Callable[[float, float], bool] | None) -> tuple[float, float]:
  """Placeholder for the bounds decision procedure; always raises.

  Raises:
    NotImplementedError: unconditionally.
  """
  reason = "_bounds_decision is uninitialized"
  raise NotImplementedError(reason)
|
||||
_bounds_decision: Callable[[DimSize, Callable[[float, float], bool] | None],
|
||||
tuple[float, float]] = _bounds_decision_unimplemented
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Decomposition:
|
||||
"""Decomposition of an expression around an operation atom.
|
||||
@ -1185,16 +1065,6 @@ def is_poly_dim(p: DimSize) -> bool:
|
||||
|
||||
dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int]
|
||||
|
||||
def definitely_geq_0(d: DimSize) -> bool:
|
||||
"""Returns true when we can prove that d >=0, false otherwise.
|
||||
Note that a result of False may mean that we cannot conclusively prove the
|
||||
sign of `d`, it does not mean that `d < 0`.
|
||||
"""
|
||||
try:
|
||||
return d >= 0
|
||||
except InconclusiveDimensionOperation:
|
||||
return False
|
||||
|
||||
def _einsum_contract_path(*operands, **kwargs):
|
||||
"""Like opt_einsum.contract_path, with support for DimExpr shapes.
|
||||
|
||||
|
412
jax/experimental/export/_shape_poly_decision.py
Normal file
412
jax/experimental/export/_shape_poly_decision.py
Normal file
@ -0,0 +1,412 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shape polymorphism support for deciding inequalities of symbolic dimensions.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import Sequence
|
||||
import itertools
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax.experimental.export import _shape_poly
|
||||
from jax.experimental.export._shape_poly import (
|
||||
_DimExpr, _DimMon, _DimAtom,
|
||||
SymbolicScope,
|
||||
DimSize,
|
||||
InconclusiveDimensionOperation,
|
||||
)
|
||||
|
||||
|
||||
def geq_decision(e1: DimSize, e2: DimSize,
                 cmp_str: Callable[[], str]) -> bool:
  """Decides `e1 >= e2`.

  When neither operand is a `_DimExpr` the comparison is done on their `int`
  values. Otherwise the bounds of `e1 - e2` are computed in the scope of the
  symbolic operand(s) and the answer is read off those bounds.

  Args:
    e1, e2: the expressions to compare for greater-equal.
    cmp_str: a callable such that `cmp_str()` describes the comparison
      for error messages, e.g., "a <= b". Without this all comparisons would
      be reported as ">=".

  Returns:
    True if the lower bound of `e1 - e2` is >= 0, False if its upper bound
    is < 0.

  Raises:
    InconclusiveDimensionOperation: if the bounds do not settle the comparison.
  """
  e1_is_symbolic = isinstance(e1, _DimExpr)
  e2_is_symbolic = isinstance(e2, _DimExpr)
  if not (e1_is_symbolic or e2_is_symbolic):
    return int(e1) >= int(e2)

  if e1_is_symbolic:
    scope = e1.scope
    if e2_is_symbolic:
      # Two symbolic operands must belong to the same scope.
      scope._check_same_scope(e2, f"when comparing {cmp_str()}")
  else:
    scope = e2.scope

  lb, ub = _DecisionByElimination(scope).bounds(e1 - e2, _stop_early_for_geq0)
  if lb >= 0:
    return True
  if ub < 0:
    return False

  # Mention the scope in the error only when it has explicit constraints.
  describe_scope = (f"\nUsing symbolic scope {scope}"
                    if scope._explicit_constraints else "")
  raise InconclusiveDimensionOperation(
      f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
|
||||
|
||||
_shape_poly._geq_decision = geq_decision
|
||||
|
||||
def _stop_early_for_geq0(lb, ub):
  """Returns whether the bounds `(lb, ub)` already decide a `>= 0` test."""
  decided = lb >= 0 or ub < 0
  return decided
|
||||
|
||||
def bounds_decision(e: DimSize,
                    stop_early: Callable[[float, float], bool] | None) -> tuple[float, float]:
  """Returns a tuple with the lower and upper bound of `e`.

  Args:
    e: the value to bound. Anything that is not a `_DimExpr` is converted
      with `int` and used as both bounds.
    stop_early: passed through to `_DecisionByElimination.bounds`.
  """
  if isinstance(e, _DimExpr):
    return _DecisionByElimination(e.scope).bounds(e, stop_early)
  return (int(e), int(e))
|
||||
|
||||
_shape_poly._bounds_decision = bounds_decision
|
||||
|
||||
|
||||
class _DecisionByElimination:
|
||||
"""A decision procedure based on elimination of terms.
|
||||
|
||||
Given an expression `e = m_k*m + rest_e` for which we want to compute bounds,
|
||||
and a constraint `c = m_c*m + rest_c >= 0`,
|
||||
|
||||
Let `e0 = abs(m_c)*e - sgn(m_c)*m_k*c`. (Note that we eliminated `m` from
|
||||
`e0`, since `abs(m_c)*m_k = sgn(m_c)*m_k*m_c`.)
|
||||
|
||||
Since `c >= 0`,
|
||||
if `sgn(m_c)*m_k > 0`:
|
||||
then `abs(m_c)*e >= e0`, hence, `LB(e) >= ceil(LB(e0) / abs(m_c))`,
|
||||
|
||||
if `sgn(m_c)*m_k < 0`
|
||||
then `abs(m_c)*e <= e0`, hence, `UB(e) <= floor(UB(e0) / abs(m_c))`,
|
||||
"""
|
||||
def __init__(self, scope: SymbolicScope):
  """Initializes the state and loads the explicit constraints of `scope`."""
  self.scope = scope
  self._processed_for_internal_constraints: set[_DimMon] = set()
  # The remaining fields keep an efficient representation of the explicit
  # constraints.
  self._term_bounds: dict[_DimMon, tuple[float, float]] = {}
  # _expr_constraints holds the constraints that are not just simple
  # monomials, as a mapping from a monomial "m" to tuples (k, c) where
  # "c >= 0" is a constraint that has "m" as the leading monomial with
  # coefficient "k".
  self._expr_constraints: dict[_DimMon, set[tuple[int, _DimExpr]]] = collections.defaultdict(set)

  # TODO: find a way to reuse the state reflecting the explicit constraints
  def by_sorted_monomials(entry):
    return entry[0]._monomials_sorted

  # Process the constraints in sorted order, so that the results of the
  # heuristics do not depend on the order in which the user wrote them.
  ordered = sorted(scope._explicit_constraints, key=by_sorted_monomials)
  for constraint, constraint_text in ordered:
    self.add_constraint(constraint, 0, constraint_text)
|
||||
|
||||
def add_constraint(self,
                   e1: _DimExpr | int | float,
                   e2: _DimExpr | int | float,
                   constraint_str: str | None = None):
  """Adds a constraint "e1 >= e2" to the internal state.

  The constraint is recorded, and then combined with the constraints already
  present; the resulting combinations are recorded too.

  Args:
    e1, e2: the two sides of the constraint. Float values must be either
      integral or an infinity; "+inf >= e2" and "e1 >= -inf" hold trivially
      and are ignored.
    constraint_str: a description of the constraint for error messages.
      Defaults to "{e1} >= {e2}".

  Raises:
    ValueError: if `e1 - e2` is a negative constant, or (from the
      combination step) a derived constraint is a negative constant.
  """
  if isinstance(e1, float):
    if np.isinf(e1) and e1 >= 0: return  # "+inf >= e2" adds no information.
    assert e1 == np.floor(e1)
    e1 = int(e1)
  if isinstance(e2, float):
    if np.isinf(e2) and e2 <= 0: return  # "e1 >= -inf" adds no information.
    # Same integrality requirement as for e1; without it `int` would
    # silently truncate a fractional bound.
    assert e2 == np.floor(e2)
    e2 = int(e2)
  e = e1 if isinstance(e2, (int, float)) and e2 == 0 else e1 - e2
  if constraint_str is None:
    constraint_str = f"{e1} >= {e2}"
  if (const := _DimExpr.to_constant(e)) is not None:
    if const < 0:
      raise ValueError(f"Unsatisfiable constraint: {constraint_str}")
    return  # A non-negative constant is trivially satisfied.
  assert isinstance(e, _DimExpr)
  self._add_to_state(e, constraint_str)
  combinations = self._combine_with_existing_constraints(e, constraint_str)
  for a in combinations:
    self._add_to_state(a, f"{a} >= 0")
|
||||
|
||||
|
||||
def _combine_with_existing_constraints(self,
                                       e: _DimExpr,
                                       debug_str: str) -> set[_DimExpr]:
  """Derives new constraints by combining "e >= 0" with the current state.

  The derived constraints are not scanned for new internal constraints
  (because there are no new monomials), and they are also not combined
  further.
  TODO: this results in incompleteness, but it is probably a good compromise.

  Args:
    e: the expression of the constraint "e >= 0" being added.
    debug_str: a description of the constraint, for error messages.

  Returns:
    the set of derived expressions, each standing for "expr >= 0".

  Raises:
    ValueError: if a combination reduces to a negative constant.
  """
  derived: set[_DimExpr] = set()

  def record(candidate: _DimExpr | int):
    const = _DimExpr.to_constant(candidate)
    if const is None:
      derived.add(candidate)  # type: ignore
    elif const < 0:
      raise ValueError(f"Unsatisfiable constraints: {debug_str}")

  # First combine with the existing monomial bounds: a monomial with a
  # positive coefficient is replaced by its upper bound, one with a negative
  # coefficient by its lower bound, when that bound is known.
  for mon, coeff in e.monomials():
    if mon.degree == 0: continue
    lower, upper = self._term_bounds.get(mon, (-np.inf, np.inf))
    if coeff > 0:
      usable, bound = upper < np.inf, upper
    else:
      usable, bound = lower > -np.inf, lower
    if not usable: continue
    without_mon = _DimExpr._merge_sorted_terms(e._monomials_sorted, 0, 1,
                                               [(mon, coeff)], 0, -1)
    with_bound = _DimExpr._merge_sorted_terms(without_mon, 0, 1,
                                              [(_DimMon(), 1)], 0, coeff * int(bound))
    record(_DimExpr(dict(with_bound), e.scope))

  # Then combine "e" with each previous constraint, on the first monomial of
  # "e" that appears in it with a coefficient of the opposite sign.
  for prev_group in self._expr_constraints.values():
    for _, prev in prev_group:
      for mon, coeff in e.monomials():
        if mon.degree == 0: continue
        prev_coeff = prev._coeffs.get(mon)
        if prev_coeff is None or prev_coeff * coeff >= 0: continue
        # Scale both so that the coefficients of `mon` cancel out.
        merged = _DimExpr._merge_sorted_terms(
            e._monomials_sorted, 0, abs(prev_coeff),
            prev._monomials_sorted, 0, abs(coeff))
        record(_DimExpr(dict(merged), e.scope))
        break

  return derived
|
||||
|
||||
def _add_to_state(self, e: _DimExpr,
                  constraint_str: str):
  """Records the constraint "e >= 0" in the internal state.

  A constraint on a single term tightens the constant bounds kept for that
  term; any other constraint is stored under its leading monomial.

  Args:
    e: a non-constant expression, standing for "e >= 0".
    constraint_str: the text used in the error message.

  Raises:
    ValueError: if the tightened bounds of a term become empty.
  """
  assert _DimExpr.to_constant(e) is None
  # Make sure the implicit constraints of every monomial in e are present.
  for mon, _ in e.monomials():
    if mon.degree != 0:
      _add_internal_constraints(self, mon, e.scope)

  single = e.to_single_term()
  if single is None:
    # A general expression: index it by its leading monomial.
    lead_mon, lead_coeff = e.leading_term
    self._expr_constraints[lead_mon].add((lead_coeff, e))
    return

  # Here e is "const + coeff * mon", which gives a constant bound for mon.
  const, coeff, mon = single
  old_lb, old_ub = self._term_bounds.get(mon, (- np.inf, np.inf))
  if coeff > 0:
    # mon >= -const / coeff, rounded up because mon is an integer.
    new_lb, new_ub = max(int(np.ceil(- const / coeff)), old_lb), old_ub
  else:
    # Dividing by a negative coefficient flips the inequality:
    # mon <= -const / coeff, rounded down.
    new_lb, new_ub = old_lb, min(int(np.floor(-const / coeff)), old_ub)
  if new_lb > new_ub:
    raise ValueError(f"Unsatisfiable constraint: {constraint_str}")

  self._term_bounds[mon] = (new_lb, new_ub)
|
||||
|
||||
def bounds(self, e: DimSize,
           stop_early: Callable[[float, float], bool] | None
           ) -> tuple[float, float]:
  """Returns the lower and upper bounds, or -+inf.

  Results for symbolic expressions are memoized in the scope's bounds cache,
  keyed by the expression together with the `stop_early` callback.

  See more details in `_shape_poly.bounds_decision`.
  """
  const = _DimExpr.to_constant(e)
  if const is not None:
    # A constant is its own lower and upper bound.
    return (const, const)
  assert isinstance(e, _DimExpr)

  key = (e, stop_early)
  cached = self.scope._bounds_cache.get(key)
  if cached is not None:
    return cached
  computed = self._bounds_for_sorted_terms(e.scope, e._monomials_sorted, 0,
                                           stop_early)
  self.scope._bounds_cache[key] = computed
  return computed
|
||||
|
||||
def _bounds_for_sorted_terms(self,
|
||||
scope: SymbolicScope,
|
||||
e: Sequence[tuple[_DimMon, int]],
|
||||
i: int,
|
||||
stop_early: Callable[[float, float], bool] | None) -> tuple[float, float]:
|
||||
"""The lower and upper bounds of e[i:].
|
||||
|
||||
See comments about soundness and `cmp_with` in the `_shape_poly.bounds_decision`` method.
|
||||
"""
|
||||
if i >= len(e): return (0, 0)
|
||||
|
||||
m, m_c = e[i]
|
||||
if len(m) == 0: # A constant
|
||||
assert i == len(e) - 1 # Must be last
|
||||
return (m_c, m_c)
|
||||
|
||||
_add_internal_constraints(self, m, scope)
|
||||
lb = -np.inf
|
||||
ub = np.inf
|
||||
|
||||
# Look among the term bounds
|
||||
if m in self._term_bounds:
|
||||
m_lb, m_ub = self._term_bounds.get(m, (- np.inf, np.inf))
|
||||
rest_lb, rest_ub = self._bounds_for_sorted_terms(scope, e, i + 1, None)
|
||||
if m_c > 0:
|
||||
lb = max(lb, m_c * m_lb + rest_lb)
|
||||
ub = min(ub, m_c * m_ub + rest_ub)
|
||||
else:
|
||||
lb = max(lb, m_c * m_ub + rest_lb)
|
||||
ub = min(ub, m_c * m_lb + rest_ub)
|
||||
|
||||
if stop_early is not None and stop_early(lb, ub): return (lb, ub)
|
||||
|
||||
# Now look through the _expr_constraints
|
||||
if m in self._expr_constraints:
|
||||
for m_k, c in self._expr_constraints[m]:
|
||||
# A complex expression. See comments from top of class.
|
||||
sgn_m_k = 1 if m_k > 0 else -1
|
||||
abs_m_k = m_k * sgn_m_k
|
||||
# The recursive call has a smaller leading monomial, because we are only
|
||||
# looking at the tail of e, and in c the largest monomial is m, and the
|
||||
# merging will cancel the m.
|
||||
rest = _DimExpr._merge_sorted_terms(e, i, abs_m_k,
|
||||
c._monomials_sorted, 0, - sgn_m_k * m_c)
|
||||
rest_lb, rest_ub = self._bounds_for_sorted_terms(scope, rest, 0, None)
|
||||
if m_c / m_k > 0:
|
||||
lb = max(lb, np.ceil(rest_lb / abs_m_k))
|
||||
else:
|
||||
ub = min(ub, np.floor(rest_ub / abs_m_k))
|
||||
if stop_early is not None and stop_early(lb, ub): return (lb, ub)
|
||||
|
||||
# Now look for special rules for atoms
|
||||
if (m_a := m.to_atom()) is not None:
|
||||
if m_a.operation in [_DimAtom.MAX, _DimAtom.MIN]:
|
||||
# m_c*MAX(op1, op2) + rest_e >= max(m_c * op1 + rest_e, m_c * op2 + rest_e)
|
||||
# if m_c > 0. Similar rules for when m_c < 0 and for MIN.
|
||||
op1, op2 = m_a.operands
|
||||
rest1 = _DimExpr._merge_sorted_terms(e, i + 1, 1,
|
||||
op1._monomials_sorted, 0, m_c)
|
||||
rest2 = _DimExpr._merge_sorted_terms(e, i + 1, 1,
|
||||
op2._monomials_sorted, 0, m_c)
|
||||
rest1_lb, rest1_ub = self._bounds_for_sorted_terms(scope, rest1, 0, None)
|
||||
rest2_lb, rest2_ub = self._bounds_for_sorted_terms(scope, rest2, 0, None)
|
||||
like_max = (m_c > 0 if m_a.operation == _DimAtom.MAX else m_c < 0)
|
||||
if like_max:
|
||||
lb = max(lb, max(rest1_lb, rest2_lb))
|
||||
ub = min(ub, max(rest1_ub, rest2_ub))
|
||||
else:
|
||||
lb = max(lb, min(rest1_lb, rest2_lb))
|
||||
ub = min(ub, min(rest1_ub, rest2_ub))
|
||||
if stop_early is not None and stop_early(lb, ub): return (lb, ub)
|
||||
|
||||
return lb, ub
|
||||
|
||||
def _add_internal_constraints(decision: _DecisionByElimination, m: _DimMon, scope: SymbolicScope):
|
||||
"""Adds the internal constraints for the monomial `m`."""
|
||||
if m in decision._processed_for_internal_constraints: return
|
||||
decision._processed_for_internal_constraints.add(m)
|
||||
m_e = _DimExpr.from_monomial(m, 1, scope) # m as a _DimExpr
|
||||
a = m.to_atom()
|
||||
if a is None:
|
||||
# This is a multiplication of atoms. Try to compute bounds based on
|
||||
# the bounds of the atoms.
|
||||
bounds = []
|
||||
for a, exp in m.items():
|
||||
a_l, a_u = decision.bounds(_DimExpr.from_monomial(_DimMon.from_atom(a, 1),
|
||||
1, scope), None)
|
||||
assert a_l <= a_u
|
||||
bounds.append((a_l ** exp, a_u ** exp))
|
||||
|
||||
candidate_bounds = [math.prod(atom_bounds)
|
||||
for atom_bounds in itertools.product(*bounds)]
|
||||
m_l = min(*candidate_bounds)
|
||||
m_u = max(*candidate_bounds)
|
||||
decision.add_constraint(m_e, m_l)
|
||||
decision.add_constraint(m_u, m_e)
|
||||
return
|
||||
|
||||
# It is an atom, is it a variable?
|
||||
if (v := a.to_var()) is not None:
|
||||
decision.add_constraint(m_e, 1) # v >= 1
|
||||
return
|
||||
|
||||
if a.operation == _DimAtom.MOD:
|
||||
op1, op2 = a.operands
|
||||
op2_b_l, op2_b_u = decision.bounds(op2, _stop_early_for_geq0)
|
||||
if op2_b_l > 0: # positive divisor
|
||||
decision.add_constraint(m_e, 0) # m >= 0
|
||||
decision.add_constraint(op2 - 1, m_e) # m <= op2 - 1
|
||||
decision.add_constraint(op2_b_u - 1, m_e)
|
||||
elif op2_b_u < 0: # negative divisor
|
||||
decision.add_constraint(m_e, op2 + 1) # m >= op2 + 1
|
||||
decision.add_constraint(m_e, op2_b_l + 1)
|
||||
decision.add_constraint(0, m_e) # m <= 0
|
||||
return
|
||||
|
||||
if a.operation == _DimAtom.FLOORDIV:
|
||||
op1, op2 = a.operands
|
||||
(op1_l, op1_u) = decision.bounds(op1, None)
|
||||
(op2_l, op2_u) = decision.bounds(op2, None)
|
||||
|
||||
def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
|
||||
# When either a or b are infinite, the results represent the limit
|
||||
# of "a // b".
|
||||
assert b != 0
|
||||
if not np.isinf(b): # divisor b is finite
|
||||
if not np.isinf(a):
|
||||
return math.floor(a / b)
|
||||
# a is infinite, b is finite
|
||||
return -np.inf if (a >= 0) != (b >= 0) else np.inf
|
||||
elif not np.isinf(a): # dividend a is finite and divisor b is infinite
|
||||
return -1 if (a >= 0) != (b >= 0) else 0
|
||||
else: # both dividend and divisor are infinite
|
||||
return -np.inf if (a >= 0) != (b >= 0) else np.inf
|
||||
|
||||
# Same reasoning as for multiplication: the bounds are among the cross-product
|
||||
# of the bounds.
|
||||
candidate_bounds = [math_floor_with_inf(op1_l, op2_l),
|
||||
math_floor_with_inf(op1_l, op2_u),
|
||||
math_floor_with_inf(op1_u, op2_l),
|
||||
math_floor_with_inf(op1_u, op2_u)]
|
||||
m_l = min(*candidate_bounds)
|
||||
m_u = max(*candidate_bounds)
|
||||
decision.add_constraint(m_e, m_l)
|
||||
decision.add_constraint(m_u, m_e)
|
||||
if op2_l >= 0:
|
||||
if decision.bounds(op1, _stop_early_for_geq0)[0] >= 0:
|
||||
decision.add_constraint(m_e, 0)
|
||||
mod_e = _DimExpr.from_operation(_DimAtom.MOD, op1, op2,
|
||||
scope=scope)
|
||||
combined = op2 * m_e + mod_e
|
||||
decision.add_constraint(op1, combined)
|
||||
decision.add_constraint(combined, op1)
|
||||
return
|
||||
|
||||
if a.operation == _DimAtom.MAX:
|
||||
op1, op2 = a.operands
|
||||
op1_b_l, op1_b_u = decision.bounds(op1, None)
|
||||
op2_b_l, op2_b_u = decision.bounds(op2, None)
|
||||
decision.add_constraint(m_e, max(op1_b_l, op2_b_l))
|
||||
decision.add_constraint(max(op1_b_u, op2_b_u), m_e)
|
||||
decision.add_constraint(m_e, op1)
|
||||
decision.add_constraint(m_e, op2)
|
||||
|
||||
if a.operation == _DimAtom.MIN:
|
||||
op1, op2 = a.operands
|
||||
op1_b_l, op1_b_u = decision.bounds(op1, None)
|
||||
op2_b_l, op2_b_u = decision.bounds(op2, None)
|
||||
decision.add_constraint(m_e, min(op1_b_l, op2_b_l))
|
||||
decision.add_constraint(min(op1_b_u, op2_b_u), m_e)
|
||||
decision.add_constraint(op1, m_e)
|
||||
decision.add_constraint(op2, m_e)
|
@ -37,6 +37,7 @@ import re
|
||||
import jax
|
||||
from jax.experimental import export
|
||||
from jax.experimental.export import _shape_poly as shape_poly
|
||||
from jax.experimental.export import _shape_poly_decision as shape_poly_decision
|
||||
from jax.experimental import pjit
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
@ -73,11 +74,13 @@ expect_error_associative_scan = (
|
||||
def _expect(*, current, best):
|
||||
return current
|
||||
|
||||
|
||||
def _bounds(e: shape_poly.DimSize) -> tuple[float, float]:
  """Returns the (lower, upper) bounds of `e` from a fresh decision procedure.

  A symbolic expression is analyzed in its own scope, so that the explicit
  constraints of that scope are used; any other value gets a new empty scope.
  """
  # The rendered diff fused the removed body (which returned `e.bounds()`)
  # with the added one, leaving the code below unreachable. Only the added
  # body is kept.
  if isinstance(e, shape_poly._DimExpr):
    scope = e.scope
  else:
    scope = shape_poly.SymbolicScope()
  decision = shape_poly_decision._DecisionByElimination(scope)
  return decision.bounds(e, None)
|
||||
|
||||
def _assert_equal_bounds(tst: jtu.JaxTestCase,
|
||||
e: shape_poly.DimSize,
|
||||
@ -473,8 +476,7 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.assertEqual(_bounds(core.non_negative_dim(a - 5)), (0, np.inf))
|
||||
self.assertEqual(_bounds(core.non_negative_dim(15 - a)), (0, 14))
|
||||
self.assertEqual(_bounds(core.non_negative_dim(15 - a) // 3), (0, 4))
|
||||
self.assertEqual(_bounds(a - core.non_negative_dim(a - 3)),
|
||||
_expect(best=(1, 3), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(a - core.non_negative_dim(a - 3)), (1, 3))
|
||||
|
||||
def test_max_dim(self):
|
||||
a, b, c, d = shape_poly.symbolic_shape("a, b, c, d")
|
||||
@ -497,28 +499,23 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertGreaterEqual(core.max_dim(a, b), core.min_dim(a, b))
|
||||
|
||||
self.assertEqual(_bounds(b - core.max_dim(b - a, 0)),
|
||||
_expect(best=(1, np.inf), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(a - core.min_dim(a, b)),
|
||||
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(b - core.min_dim(a, b)),
|
||||
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(b - core.max_dim(b - a, 0)), (1, np.inf))
|
||||
self.assertEqual(_bounds(a - core.min_dim(a, b)), (0, np.inf))
|
||||
self.assertEqual(_bounds(b - core.min_dim(a, b)), (0, np.inf))
|
||||
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) - a),
|
||||
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) - b),
|
||||
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) - a), (0, np.inf))
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) - b), (0, np.inf))
|
||||
self.assertEqual((0, 0), _bounds(core.max_dim(1 - b, 0)))
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) - a - core.max_dim(b - a, 0)),
|
||||
_expect(best=(0, 0), current=(-np.inf, np.inf)))
|
||||
_expect(best=(0, 0), current=(0, np.inf)))
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) + core.max_dim(c, d) -
|
||||
core.max_dim(a + core.max_dim(c, d),
|
||||
b + core.max_dim(c, d))),
|
||||
_expect(best=(0, 0), current=(-np.inf, np.inf)))
|
||||
_expect(best=(0, 0), current=(0, np.inf)))
|
||||
self.assertEqual(_bounds(core.max_dim(a, b) + core.min_dim(c, d) -
|
||||
core.max_dim(a + core.min_dim(c, d),
|
||||
b + core.min_dim(c, d))),
|
||||
_expect(best=(0, 0), current=(-np.inf, np.inf)))
|
||||
_expect(best=(0, 0), current=(0, np.inf)))
|
||||
|
||||
self.sampled_assertion(core.max_dim(a, 5), core.max_dim, a, 5)
|
||||
self.sampled_assertion(core.max_dim(5, a), core.max_dim, 5, a)
|
||||
@ -758,10 +755,11 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
"m >= 0", "m <= 3"]
|
||||
scope = shape_poly.SymbolicScope(assumptions)
|
||||
a, d = shape_poly.symbolic_shape("a, d", scope=scope)
|
||||
self.assertEqual(_bounds(a - 4*d),
|
||||
_expect(best=(1, 3), current=(-np.inf, np.inf))) # a - 4d = m >= 1
|
||||
self.assertEqual(_bounds(a - 2*d),
|
||||
_expect(best=(3, np.inf), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(a - 4*d), (1, 3)) # a - 4d = m >= 1
|
||||
self.assertEqual(_bounds(a - 2*d), (3, np.inf)) # a - 2d = m + 2d >= 3
|
||||
# TODO: The incompleteness is due to the way we combine external constraints
|
||||
self.assertEqual(_bounds(a),
|
||||
_expect(best=(5, np.inf), current=(4, np.inf))) # a >= 4d + m >= 5
|
||||
|
||||
def test_constraints_errors(self):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
@ -806,14 +804,10 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
a, b, c = shape_poly.symbolic_shape(
|
||||
"a, b, c",
|
||||
constraints=("a + 2 <= b", "b <= a + 5", "a + b >= c"))
|
||||
self.assertEqual(_bounds(b),
|
||||
_expect(best=(3, np.inf), current=(1, np.inf)))
|
||||
self.assertEqual(_bounds(b - a),
|
||||
_expect(best=(2, 5), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(b - a - 7),
|
||||
_expect(best=(-5, -2), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(c - 2*a - 5),
|
||||
_expect(best=(-np.inf, 0), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(b), (3, np.inf))
|
||||
self.assertEqual(_bounds(b - a), (2, 5))
|
||||
self.assertEqual(_bounds(b - a - 7), (-5, -2))
|
||||
self.assertEqual(_bounds(c - 2*a - 5), (-np.inf, 0))
|
||||
|
||||
def test_constraints_fractional(self):
|
||||
a, = shape_poly.symbolic_shape("a",
|
||||
@ -823,11 +817,11 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(constraint="2*a + 3*b >= 10", exp="a + b - 2",
|
||||
bounds=_expect(best=(2, np.inf), current=(0, np.inf))),
|
||||
bounds=(2, np.inf)),
|
||||
dict(constraint="-2*a + 3*b >= 10", exp="a + 2*b",
|
||||
bounds=_expect(best=(9, np.inf), current=(3, np.inf))),
|
||||
bounds=(9, np.inf)),
|
||||
dict(constraint="-2*a + -3*b >= -10", exp="-1*a + 2*b",
|
||||
bounds=_expect(best=(-1, 3), current=(-np.inf, np.inf))),
|
||||
bounds=(-1, 3)),
|
||||
dict(constraint="2*a + -3*b >= 10", exp="-1*a + 2*b", bounds=(-np.inf, np.inf)),
|
||||
]
|
||||
)
|
||||
@ -872,8 +866,7 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
"a4 >= a5",
|
||||
)
|
||||
)
|
||||
self.assertEqual(_bounds(a1 - a5),
|
||||
_expect(best=(0, np.inf), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(a1 - a5), (0, np.inf))
|
||||
|
||||
def test_constraints_rounding_monomials(self):
|
||||
a1, a2 = shape_poly.symbolic_shape(
|
||||
@ -893,10 +886,8 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
"2*a1 >= 2*a2 + 3", "-2*a1 + 2*a2 >= -5"
|
||||
)
|
||||
)
|
||||
self.assertEqual(_bounds(a1 - a2 - 2),
|
||||
_expect(best=(0, 0), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(a2 - a1 + 2),
|
||||
_expect(best=(0, 0), current=(-np.inf, np.inf)))
|
||||
self.assertEqual(_bounds(a1 - a2 - 2), (0, 0))
|
||||
self.assertEqual(_bounds(a2 - a1 + 2), (0, 0))
|
||||
|
||||
def test_constraints_unsat_trivial(self):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
@ -905,10 +896,8 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
"a1", constraints=("a1 >= a1 + 1",))
|
||||
|
||||
def test_constraints_unsat_monomials(self):
|
||||
with self.assertRaisesRegex(_expect(best=ValueError,
|
||||
current=shape_poly.InconclusiveDimensionOperation),
|
||||
_expect(best="Unsatisfiable constraint",
|
||||
current="inconclusive")):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Unsatisfiable constraint"):
|
||||
a1, a2, *_ = shape_poly.symbolic_shape(
|
||||
"a1, a2, a3, a4",
|
||||
constraints=(
|
||||
@ -919,10 +908,8 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
a1 >= a2
|
||||
|
||||
def test_constraints_unsat_not_monomials(self):
|
||||
with self.assertRaisesRegex(_expect(best=ValueError,
|
||||
current=shape_poly.InconclusiveDimensionOperation),
|
||||
_expect(best="Unsatisfiable constraint",
|
||||
current="inconclusive")):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Unsatisfiable constraint"):
|
||||
a1, a2, a3, a4 = shape_poly.symbolic_shape(
|
||||
"a1, a2, a3, a4",
|
||||
constraints=(
|
||||
|
Loading…
x
Reference in New Issue
Block a user