[shape_poly] Fixes for the lexicographic ordering of monomials.

There were several bugs in the ordering of atoms and
monomials. The ordering for atoms and monomials is used
for sorting, and the __eq__ is also used for hashing.

One bug was that the ordering of atoms sometimes used
the `id` ordering. Another (performance) bug was that
the __eq__ for atoms used the (semantic) __eq__ for
DimExpr. The latter is expensive to compute, but for
sorting all we need is a syntactic comparison.

We introduce a `_syntactic_cmp` method for atoms,
monomials and expressions and we use it exclusively
for the ordering of atoms and monomials.

We also clean up printing and add tests for ordering and
pretty printing. Now we print monomials in "decreasing" order.
This is a change from before, in the sense that "a + b" is
printed as "b + a".
This commit is contained in:
George Necula 2024-01-05 14:48:53 +07:00
parent cbd453ec96
commit f1c87e0176
5 changed files with 235 additions and 95 deletions

View File

@ -44,7 +44,7 @@ import math
import operator as op
import threading
import tokenize
from typing import Any, Optional, Union
from typing import Any, Union
import numpy as np
import opt_einsum
@ -61,9 +61,9 @@ from jax._src.interpreters import mlir
from jax._src.numpy import lax_numpy
from jax._src import tree_util
from jax._src import util
from jax._src.typing import DimSize, Shape
DimSize = Union["_DimExpr", int]
TfVal = Any
DimVarEnv = dict[str, jax.Array]
DType = Any
@ -129,6 +129,11 @@ class _DimAtom:
def from_var(cls, v: str) -> _DimAtom:
return _DimAtom(var=v)
@classmethod
def from_operation(cls, operation: str, *operands: DimSize) -> _DimAtom:
return _DimAtom(*(_ensure_poly(o, operation) for o in operands),
operation=operation)
def to_var(self) -> str | None:
return self.var
@ -142,10 +147,6 @@ class _DimAtom:
acc.update(opnd.get_vars())
return acc
@classmethod
def from_operation(cls, operation: str, *operands: _DimExpr) -> _DimAtom:
return _DimAtom(*operands, operation=operation)
def __str__(self):
if self.var is not None:
return self.var
@ -156,38 +157,41 @@ class _DimAtom:
def __hash__(self):
return hash((self.var, self.operation, *self.operands))
def __eq__(self, other: Any):
# Used only for hashing
if not isinstance(other, _DimAtom): return False
if (self.var is None) != (other.var is None): return False
def _syntactic_cmp(self, other: _DimAtom) -> int:
"""Returns -1 if self < other, 0 if self == other, 1 if self > other.
The comparison is done lexicographically (syntactic), to be used for sorting.
The result is not related to the semantic value.
"""
if self.var is not None:
return self.var == other.var
else:
def symbolic_equal(e1: _DimExpr, e2: _DimExpr) -> bool:
try:
return e1 == e2
except InconclusiveDimensionOperation:
return False
return (self.operation == other.operation and
all(symbolic_equal(self_o, other_o)
for self_o, other_o in zip(self.operands, other.operands)))
if other.var is not None:
return cmp_comparable(self.var, other.var)
else:
return -1
if other.var is not None: return 1
if c := cmp_comparable(self.operation, other.operation): return c # type: ignore
return cmp_sequence(self.operands, other.operands,
lambda s_o, o_o: s_o._syntactic_cmp(o_o))
def __eq__(self, other: Any):
"""Lexicographic comparison."""
if not isinstance(other, _DimAtom): return False
return self._syntactic_cmp(other) == 0
def __lt__(self, other: _DimAtom):
"""
Comparison to another atom in graded reverse lexicographic order.
Used only for determining a sorting order, does not relate to the
comparison of the values of the atom.
"""
if self.var is not None and other.var is not None:
return self.var < other.var
elif self.var is not None:
return True
elif other.var is not None:
return True
elif self.operation != other.operation:
return self.operation < other.operation # type: ignore
else:
return id(self) < id(other)
"""Lexicographic comparison."""
return self._syntactic_cmp(other) < 0
def __le__(self, other: _DimAtom):
"""Lexicographic comparison."""
return self._syntactic_cmp(other) <= 0
def __gt__(self, other: _DimAtom):
"""Lexicographic comparison."""
return self._syntactic_cmp(other) > 0
def __ge__(self, other: _DimAtom):
"""Lexicographic comparison"""
return self._syntactic_cmp(other) >= 0
def bounds(self) -> tuple[float, float]:
"""Returns the lower and upper bounds, or -+ inf."""
@ -275,16 +279,26 @@ class _DimMon(dict):
def from_atom(clscls, a: _DimAtom, aexp: int):
return _DimMon({a: aexp})
@classmethod
def from_operation(cls, operation: str, *operands: DimSize) -> _DimMon:
return _DimMon({_DimAtom.from_operation(operation, *operands): 1})
def to_var(self) -> str | None:
"""Extract the variable name "x", from a monomial "x".
Return None, if the monomial is not a single variable."""
"""Extract the variable name from a monomial.
Return None if the monomial is not a single variable."""
a = self.to_atom()
return a.to_var() if a is not None else None
def to_atom(self) -> _DimAtom | None:
"""Extract the single atom from a monomial.
Return None if the monomial is not a single atom."""
items = self.items()
if len(items) != 1:
return None
(a, aexp), = items
if aexp != 1:
return None
return a.to_var()
return a
def get_vars(self) -> set[str]:
# All the vars that appear in the monomial
@ -293,23 +307,39 @@ class _DimMon(dict):
acc.update(a.get_vars())
return acc
@classmethod
def from_operation(cls, operation: str, *operands: _DimExpr) -> _DimMon:
return _DimMon({_DimAtom.from_operation(operation, *operands): 1})
@property
def degree(self):
return sum(self.values())
def _syntactic_cmp(self, other: _DimMon) -> int:
"""Returns -1 if self < other, 0 if self == other, 1 if self > other.
The comparison is done lexicographically (syntactic), to be used for sorting.
The result is not related to the semantic value.
"""
# Compare by total degree first; only monomials of equal degree reach the
# item-by-item comparison below.
if c := cmp_comparable(self.degree, other.degree): return c
# Compares two (atom, exponent) items: by atom first, then by exponent with
# the sign inverted (see the explanation inside).
def cmp_atom(s_a: tuple[_DimAtom, int], o_a: tuple[_DimAtom, int]) -> int:
if c := s_a[0]._syntactic_cmp(o_a[0]): return c
# Consider the monomials with exponents to be expanded as multiplications.
# Then a higher exponent for a "small" atom should lead to a "smaller" monomial.
return - cmp_comparable(s_a[1], o_a[1])
# Walk the (atom, exponent) items of both monomials in sorted order.
return cmp_sequence(sorted(self.items()), sorted(other.items()), cmp_atom)
def __lt__(self, other: _DimMon):
"""
Comparison to another monomial in graded reverse lexicographic order.
Used only for determining a sorting order, does not relate to the
comparison of the values of the monomial.
"""
self_key = -self.degree, tuple(sorted(self))
other_key = -other.degree, tuple(sorted(other))
return self_key > other_key
"""Lexicographic comparison"""
return self._syntactic_cmp(other) < 0
def __le__(self, other: _DimMon):
"""Lexicographic comparison"""
return self._syntactic_cmp(other) <= 0
def __gt__(self, other: _DimMon):
"""Lexicographic comparison"""
return self._syntactic_cmp(other) > 0
def __ge__(self, other: _DimMon):
"""Lexicographic comparison"""
return self._syntactic_cmp(other) >= 0
def mul(self, other: _DimMon) -> _DimMon:
"""
@ -344,7 +374,6 @@ class _DimMon(dict):
candidates = [math.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)]
return (min(*candidates), max(*candidates)) # type: ignore
def evaluate(self, env: DimVarEnv):
prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1)
def pow_opt(v, p: int):
@ -366,7 +395,6 @@ class _DimExpr():
integer coefficients. The special monomial `_DimMon()` is mapped to the
free integer coefficient of the expression.
"""
__array_priority__ = 1000 # Same as tracer, for __radd__ and others on ndarray
def __init__(self, coeffs: dict[_DimMon, int]):
# Do not construct _DimExpr directly, unless you are sure that coeffs is
@ -377,6 +405,11 @@ class _DimExpr():
def monomials(self) -> Iterable[tuple[_DimMon, int]]:
return self._coeffs.items()
def monomials_sorted(self, reverse=False):
"""The monomials in sorted lexicographic order.
Higher-degree monomials come later in the order.

Returns a new list with the `(monomial, coefficient)` pairs produced by
`self.monomials()`. If `reverse` is true, the order is reversed.
"""
return sorted(self.monomials(), reverse=reverse)
@classmethod
def _add_coeffs(cls, coeffs: dict[_DimMon, int], mon: _DimMon, coeff: int):
"""Do `coeffs[mon] += coeff` but remove 0 coefficients."""
@ -430,26 +463,39 @@ class _DimExpr():
return _DimExpr.normalize(coeffs)
@classmethod
def from_monomial(cls, mon: _DimMon, exp: int):
return _DimExpr.normalize({mon: exp})
def from_monomial(cls, mon: _DimMon, count: int):
return _DimExpr.normalize({mon: count})
@classmethod
def from_var(cls, v: str) -> _DimExpr:
return _DimExpr({_DimMon.from_var(v): 1})
@classmethod
def from_operation(cls, operation: str, *operands: _DimExpr) -> _DimExpr:
def from_operation(cls, operation: str, *operands: DimSize) -> _DimExpr:
return _DimExpr.from_monomial(_DimMon.from_operation(operation, *operands), 1)
def to_var(self) -> str | None:
"""Extract the variable name "x", from a symbolic expression."""
def to_monomial(self) -> _DimMon | None:
"""Extract the single monomial from a symbolic expression.
Returns None if the expression is not a single monomial."""
items = self.monomials()
if len(items) != 1: # type: ignore
return None
(mon, mon_count), = items
if mon_count != 1:
return None
return mon.to_var()
return mon
def to_atom(self) -> _DimAtom | None:
"""Extract the atom from a symbolic expression.
Returns None if the expression is not a single atom."""
mon = self.to_monomial()
return mon.to_atom() if mon is not None else None
def to_var(self) -> str | None:
  """Extract the variable name from a symbolic expression.

  Returns None if the expression is not a single variable.
  """
  # `to_atom()` yields None unless the expression is exactly one atom; the
  # atom itself then decides whether it is a plain variable.
  atom = self.to_atom()
  return atom.to_var() if atom is not None else None
def get_vars(self) -> set[str]:
"""The variables that appear in a symbolic dimension."""
@ -458,7 +504,19 @@ class _DimExpr():
acc.update(mon.get_vars())
return acc
def eq(self, other: DimSize) -> bool:
def _syntactic_cmp(self, other: _DimExpr) -> int:
"""Returns -1 if self < other, 0 if self == other, 1 if self > other.
The comparison is done lexicographically (syntactic), to be used for sorting.
The result is not related to the semantic value.
"""
# Both sides are put in the same sorted order so that the pairwise walk
# below compares corresponding monomials.
s_mons = self.monomials_sorted()
o_mons = other.monomials_sorted()
# Compares two (monomial, coefficient) pairs: by monomial first, then by
# coefficient.
def cmp_mon(s_mon: tuple[_DimMon, int], o_mon: tuple[_DimMon, int]) -> int:
if c := s_mon[0]._syntactic_cmp(o_mon[0]): return c
return cmp_comparable(s_mon[1], o_mon[1])
return cmp_sequence(s_mons, o_mons, cmp_mon)
def eq(self, other) -> bool:
lb, ub = _ensure_poly(self - other, "eq").bounds()
if lb == ub == 0:
return True
@ -522,8 +580,11 @@ class _DimExpr():
if c == 1:
return str(mon)
return f"{c}*{mon}"
return " + ".join(_one_monomial(mon, c)
for mon, c in sorted(self.monomials(), reverse=True))
# We print first the "larger" monomials, so that the constant is last.
res = " + ".join(_one_monomial(mon, c)
for mon, c in self.monomials_sorted(reverse=True))
res = res.replace(" + -", " - ")
return res
def __repr__(self):
return str(self)
@ -627,7 +688,7 @@ class _DimExpr():
# We must overload __eq__ and __ne__, or else we get unsound defaults.
__eq__ = eq
def __ne__(self, other: DimSize) -> bool:
def __ne__(self, other) -> bool:
return not self.eq(other)
__ge__ = ge
@ -759,6 +820,21 @@ class _DimExpr():
# Used for implicit coercions of polynomials as JAX arrays
return _dim_as_value(self)
def cmp_comparable(i1, i2) -> int:
  """Three-way comparison of two values that support `<` and `>`.

  Args:
    i1: the left-hand value.
    i2: the right-hand value.

  Returns:
    -1 if `i1 < i2`, 1 if `i1 > i2`, and 0 otherwise.
  """
  if i1 < i2:
    return -1
  if i1 > i2:
    return 1
  return 0
def cmp_sequence(s1, s2, elem_cmp) -> int:
  """Compares two sequences lexicographically using `elem_cmp`.

  Args:
    s1: the left-hand sequence; must support iteration and `len`.
    s2: the right-hand sequence; must support iteration and `len`.
    elem_cmp: a function of two elements returning a negative, zero, or
      positive integer; the first non-zero result is returned as is.

  Returns:
    The first non-zero `elem_cmp` result over the common prefix. If the common
    prefix compares equal, -1 if `s1` is shorter, 1 if `s1` is longer, and 0
    if both have the same length.
  """
  for e1, e2 in zip(s1, s2):
    if c := elem_cmp(e1, e2):
      return c
  # The common prefix is equal: the shorter sequence is the smaller one.
  l1, l2 = len(s1), len(s2)
  if l1 < l2:
    return -1
  if l1 > l2:
    return 1
  return 0
@dataclasses.dataclass
class _Decomposition:
"""Decomposition of an expression around an operation atom.
@ -819,7 +895,7 @@ dtypes._weak_types.append(_DimExpr)
def _convertible_to_int(p: DimSize) -> bool:
try:
op.index(p)
op.index(p) # type: ignore
return True
except:
return False
@ -1053,7 +1129,7 @@ class _Parser:
if core.is_constant_dim(expr) and self.like_shape is not None:
like_shape_dim = self.like_shape[len(self.dimensions)]
if expr != like_shape_dim:
if expr != like_shape_dim: # type: ignore[operator]
raise self.parse_err(tok,
(f"different size {expr} for known dimension; "
f"like={self.like_shape}"))
@ -1178,11 +1254,11 @@ class _Parser:
def atom(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
if tok.exact_type == tokenize.NAME:
if tok.string == _DimAtom.MOD:
return self.binary_op(_DimAtom.MOD, self.next_tok())
return self.atom_binary_op(_DimAtom.MOD, self.next_tok())
if tok.string == _DimAtom.FLOORDIV:
return self.binary_op(_DimAtom.FLOORDIV, self.next_tok())
return self.atom_binary_op(_DimAtom.FLOORDIV, self.next_tok())
if tok.string == _DimAtom.NON_NEGATIVE:
return self.unary_op(_DimAtom.NON_NEGATIVE, self.next_tok())
return self.atom_unary_op(_DimAtom.NON_NEGATIVE, self.next_tok())
return _DimExpr.from_var(tok.string), self.next_tok()
number_sign = 1
if tok.exact_type == tokenize.MINUS: # -k are negative constants
@ -1195,13 +1271,13 @@ class _Parser:
self.expect_token(tok, [tokenize.NAME, tokenize.MINUS, tokenize.NUMBER])
assert False
def unary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]:
def atom_unary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]:
tok = self.consume_token(tok, tokenize.LPAR)
e1, tok = self.expr(tok)
tok = self.consume_token(tok, tokenize.RPAR)
return _DimExpr.from_operation(op, e1), tok # type: ignore
def binary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]:
def atom_binary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]:
tok = self.consume_token(tok, tokenize.LPAR)
e1, tok = self.expr(tok)
tok = self.consume_token(tok, tokenize.COMMA)
@ -1274,7 +1350,7 @@ class CachingShapeEvaluator:
@functools.lru_cache(128)
def evaluate(self, e: DimSize):
if core.is_constant_dim(e):
res = op.index(e)
res = op.index(e) # type: ignore
else:
res = e.evaluate(self.env) # type: ignore
return res

View File

@ -1143,7 +1143,7 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfV
partial(core.evaluate_shape, shape, dim_vars),
dim_values, [core.dim_value_aval()] * len(dim_values), "") # type: ignore
# Keep only the non-constant dimensions
return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf
return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf # type: ignore
for d, d_tf in zip(shape, shape_values_tf))

View File

@ -604,9 +604,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
expect_error=(
"Input shapes do not match the polymorphic shapes specification. "
"Expected value >= 1 for dimension variable 'b'. "
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b + c). "
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
"'b' = 0 from specification 'a + 2*b' for dimension args[0].shape[0] (= 2), . "
"'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . "
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
)),
dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer
@ -614,7 +614,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
expect_error=(
"Input shapes do not match the polymorphic shapes specification. "
"Division had remainder 1 when computing the value of 'b'. "
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b + c). "
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . "
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
)),
@ -622,10 +622,10 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
poly_spec="(a + 2*b, a, a + b)",
expect_error=(
"Input shapes do not match the polymorphic shapes specification. "
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification 'a + 2*b' (= 10). "
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b). "
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification '2*b + a' (= 10). "
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). "
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= 6), . "
"'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . "
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
)),
dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c
@ -633,7 +633,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
expect_error=(
"Cannot solve for values of dimension variables {'c'}. "
"We can only solve linear uni-variate constraints. "
"Using the following polymorphic shapes specifications: args[0].shape = (2*a + b, a, c^2). "
"Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). "
"Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. "
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details."
)),
@ -1501,7 +1501,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
polymorphic_shapes=[x_polymorphic_shape, y_polymorphic_shape])(x, y)
self.assertEqual(np.float32, zw_specs[0].dtype)
self.assertEqual(np.float32, zw_specs[1].dtype)
self.assertEqual(("(a, 5)", "(a + b, 5)"), zw_polymorphic_shapes)
self.assertEqual(("(a, 5)", "(b + a, 5)"), zw_polymorphic_shapes)
# We can use the zw_polymorphic_shapes for jax2tf.convert
z, w = jax2tf.convert(

View File

@ -590,10 +590,10 @@ class JaxExportTest(jtu.JaxTestCase):
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c",
expect_error_outer_exp=re.escape(
"Expected value >= 1 for dimension variable 'b'. "
"Using the following polymorphic shapes specifications: args[0].shape = (3, a, a + b). "
"Using the following polymorphic shapes specifications: args[0].shape = (3, a, b + a). "
"Obtained dimension variables: 'a' = 4 from specification "
"'a' for dimension args[0].shape[1] (= 4), "
"'b' = c + -4 from specification 'a + b' for dimension args[0].shape[2] (= c),")),
"'b' = c - 4 from specification 'b + a' for dimension args[0].shape[2] (= c),")),
dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,c",
expect_error_outer_exp=re.escape(
"Found inconsistency between dimension size "
@ -610,10 +610,10 @@ class JaxExportTest(jtu.JaxTestCase):
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
expect_error_outer_exp=re.escape(
"Expected value >= 1 for dimension variable 'b'. "
"Using the following polymorphic shapes specifications: args[0].shape = (3, a, a + b). "
"Using the following polymorphic shapes specifications: args[0].shape = (3, a, b + a). "
"Obtained dimension variables: 'a' = c from "
"specification 'a' for dimension args[0].shape[1] (= c), "
"'b' = 0 from specification 'a + b' for dimension args[0].shape[2] (= c)")),
"'b' = 0 from specification 'b + a' for dimension args[0].shape[2] (= c)")),
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
expect_error_outer_exp=re.escape(
"Shape mismatch for args[0].shape[0] (expected same constant)")),
@ -689,9 +689,9 @@ class JaxExportTest(jtu.JaxTestCase):
expect_error=(
"Input shapes do not match the polymorphic shapes specification. "
"Expected value >= 1 for dimension variable 'b'. "
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b + c). "
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
"'b' = 0 from specification 'a + 2*b' for dimension args[0].shape[0] (= 2), . "
"'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . "
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
)),
dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer
@ -699,7 +699,7 @@ class JaxExportTest(jtu.JaxTestCase):
expect_error=(
"Input shapes do not match the polymorphic shapes specification. "
"Division had remainder 1 when computing the value of 'b'. "
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b + c). "
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . "
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
)),
@ -707,10 +707,10 @@ class JaxExportTest(jtu.JaxTestCase):
poly_spec="(a + 2*b, a, a + b)",
expect_error=(
"Input shapes do not match the polymorphic shapes specification. "
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification 'a + 2*b' (= 10). "
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b). "
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification '2*b + a' (= 10). "
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). "
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= 6), . "
"'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . "
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
)),
dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c
@ -718,7 +718,7 @@ class JaxExportTest(jtu.JaxTestCase):
expect_error=(
"Cannot solve for values of dimension variables {'c'}. "
"We can only solve linear uni-variate constraints. "
"Using the following polymorphic shapes specifications: args[0].shape = (2*a + b, a, c^2). "
"Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). "
"Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. "
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details."
)),

View File

@ -139,14 +139,34 @@ class DimExprTest(jtu.JaxTestCase):
("a + -1", a - 1),
("3 * a * mod(a + 2, b + 2)", 3 * a * ((a + 2) % (b + 2))),
("3 * floordiv(a + 2, b + 2) * 2", 3 * ((a + 2) // (b + 2)) * 2),
("non_negative(a - 2)", core.non_negative_dim(a - 2)),
("non_negative(a - 2)", "build_inside"),
]])
def test_parse_dim(self,
dim_spec="-2 * a^2 * b + b^2",
dim_poly=-2 * a * a * b + b * b):
def test_parse_dim(self, dim_spec, dim_poly):
if dim_spec == "non_negative(a - 2)":
dim_poly = core.non_negative_dim(DimExprTest.a - 2)
self.assertEqual((dim_poly,), shape_poly.symbolic_shape(dim_spec))
self.assertEqual((dim_poly,), shape_poly.symbolic_shape(str(dim_poly)))
@jtu.parameterized_filterable(
kwargs=[
dict(dim_spec=dim_spec)
for dim_spec in [
"b + a",
"a*b + a^2 + b + a",
"mod(a, 4) + floordiv(a, 4) + a",
"2*a^2 - 3*a - 1",
"a^2 + 3*a - 1",
"-1*a + 3",
"-1*mod(a, 4) + 3",
"-2*a + 3",
"a*floordiv(b, 8)*mod(b, 4)",
]
]
)
def test_print_dim(self, *, dim_spec: str):
e, = shape_poly.symbolic_shape(dim_spec)
self.assertEqual(str(e), dim_spec)
@jtu.parameterized_filterable(
kwargs=[
# sanitized shape_spec sometimes collide
@ -233,6 +253,50 @@ class DimExprTest(jtu.JaxTestCase):
self.assertTrue(core.definitely_equal(1, jnp.add(0, 1))) # An Array
self.assertFalse(core.definitely_equal(1, "a"))
def test_atoms_ordering(self):
a, b = shape_poly.symbolic_shape("a, b")
# Variable atoms are ordered by variable name.
self.assertTrue(a.to_atom() < b.to_atom())
self.assertFalse(a.to_atom() >= b.to_atom())
self.assertTrue(a.to_atom() <= b.to_atom())
self.assertTrue(a.to_atom() != b.to_atom())
# A variable atom sorts before an operation atom.
self.assertTrue(a.to_atom() < (a % 4).to_atom())
self.assertFalse(a.to_atom() > (a % 4).to_atom())
# FLOORDIV comes before MOD because we compare operations alphabetically
self.assertTrue((a // 4).to_atom() < (a % 4).to_atom())
# Atoms built from the same expression must hash to the same value.
self.assertEqual(hash((a // 4).to_atom()), hash((a // 4).to_atom()))
def test_monomial_ordering(self):
a, b = shape_poly.symbolic_shape("a, b")
self.assertTrue(a.to_monomial() < b.to_monomial())
self.assertTrue(a.to_monomial() <= b.to_monomial())
self.assertTrue(b.to_monomial() >= a.to_monomial())
self.assertTrue(b.to_monomial() > a.to_monomial())
self.assertTrue(a.to_monomial() < (a * a).to_monomial())
self.assertTrue(b.to_monomial() < (a * a).to_monomial())
self.assertTrue((a * a * b).to_monomial() < (a * b * b).to_monomial())
e1 = a * a * b + a * b * b + a * b + a * a + a + b
sorted_e1 = [shape_poly._DimExpr.from_monomial(m, m_count)
for m, m_count in e1.monomials_sorted()]
self.assertSequenceEqual(sorted_e1,
[a, b, a * a, a * b, a * a * b, a * b * b])
e2 = a * (a // 4) + (a // 4) + b * (a // 4) + b * (a % 4) + a * a + b
sorted_e2 = [shape_poly._DimExpr.from_monomial(m, m_count)
for m, m_count in e2.monomials_sorted()]
self.assertSequenceEqual(sorted_e2,
[b, a // 4, a * a, a * (a // 4), b * (a // 4), b * (a % 4)])
# This failed with a previous implementation of atom equality
self.assertNotEqual(shape_poly._DimMon.from_operation(shape_poly._DimAtom.NON_NEGATIVE,
a - b - 1),
shape_poly._DimMon.from_operation(shape_poly._DimAtom.NON_NEGATIVE,
a - 2*b - 1))
def test_poly_bounds(self):
a, b = shape_poly.symbolic_shape("a, b")
bounded_le4 = 5 - a
@ -437,7 +501,7 @@ class DimExprTest(jtu.JaxTestCase):
(a * a - b * b, a + b, a - b, 0),
(a, b, "floordiv(a, b)", "mod(a, b)"),
(3 * a, 2, "floordiv(3*a, 2)", "mod(3*a, 2)"),
(2 * a * b + b * b, a + b, "floordiv(2*a*b + b^2, a + b)", "mod(2*a*b + b^2, a + b)"),
(2 * a * b + b * b, a + b, "floordiv(b^2 + 2*a*b, b + a)", "mod(b^2 + 2*a*b, b + a)"),
(3, a, "floordiv(3, a)", "mod(3, a)"),
]])
def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):