mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[shape_poly] Fixes for the lexicographic ordering of monomials.
There were several bugs in the ordering of atoms and monomials. The ordering for atoms and moomials is used for sorting, and the __eq__ is also used for hashing. One bug was that the ordering of atoms sometimes used the `id` ordering. Another (performance) bug was that the __eq__ for atoms used the (semantic) __eq__ for DimExpr. The latter is expensive to compute, but for sorting all we need is a syntactic comparison. We introduce a `_syntactic_cmp` method for atoms, monomials and expressions and we use it exclusively for the ordering of atoms and monomials. We also clean up printing and add tests for ordering and pretty printing. Now we print monomial in "decreasing" order. This is a change from before, in the sense that "a + b" is printed as "b + a".
This commit is contained in:
parent
cbd453ec96
commit
f1c87e0176
@ -44,7 +44,7 @@ import math
|
||||
import operator as op
|
||||
import threading
|
||||
import tokenize
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
import opt_einsum
|
||||
@ -61,9 +61,9 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.typing import DimSize, Shape
|
||||
|
||||
|
||||
DimSize = Union["_DimExpr", int]
|
||||
TfVal = Any
|
||||
DimVarEnv = dict[str, jax.Array]
|
||||
DType = Any
|
||||
@ -129,6 +129,11 @@ class _DimAtom:
|
||||
def from_var(cls, v: str) -> _DimAtom:
|
||||
return _DimAtom(var=v)
|
||||
|
||||
@classmethod
|
||||
def from_operation(cls, operation: str, *operands: DimSize) -> _DimAtom:
|
||||
return _DimAtom(*(_ensure_poly(o, operation) for o in operands),
|
||||
operation=operation)
|
||||
|
||||
def to_var(self) -> str | None:
|
||||
return self.var
|
||||
|
||||
@ -142,10 +147,6 @@ class _DimAtom:
|
||||
acc.update(opnd.get_vars())
|
||||
return acc
|
||||
|
||||
@classmethod
|
||||
def from_operation(cls, operation: str, *operands: _DimExpr) -> _DimAtom:
|
||||
return _DimAtom(*operands, operation=operation)
|
||||
|
||||
def __str__(self):
|
||||
if self.var is not None:
|
||||
return self.var
|
||||
@ -156,38 +157,41 @@ class _DimAtom:
|
||||
def __hash__(self):
|
||||
return hash((self.var, self.operation, *self.operands))
|
||||
|
||||
def __eq__(self, other: Any):
|
||||
# Used only for hashing
|
||||
if not isinstance(other, _DimAtom): return False
|
||||
if (self.var is None) != (other.var is None): return False
|
||||
def _syntactic_cmp(self, other: _DimAtom) -> int:
|
||||
"""Returns -1 if self < other, 0 if self == other, 1 if self > other.
|
||||
The comparison is done lexicographically (syntactic), to be used for sorting.
|
||||
The result is not related to the semantic value.
|
||||
"""
|
||||
if self.var is not None:
|
||||
return self.var == other.var
|
||||
else:
|
||||
def symbolic_equal(e1: _DimExpr, e2: _DimExpr) -> bool:
|
||||
try:
|
||||
return e1 == e2
|
||||
except InconclusiveDimensionOperation:
|
||||
return False
|
||||
return (self.operation == other.operation and
|
||||
all(symbolic_equal(self_o, other_o)
|
||||
for self_o, other_o in zip(self.operands, other.operands)))
|
||||
if other.var is not None:
|
||||
return cmp_comparable(self.var, other.var)
|
||||
else:
|
||||
return -1
|
||||
if other.var is not None: return 1
|
||||
if c := cmp_comparable(self.operation, other.operation): return c # type: ignore
|
||||
return cmp_sequence(self.operands, other.operands,
|
||||
lambda s_o, o_o: s_o._syntactic_cmp(o_o))
|
||||
|
||||
def __eq__(self, other: Any):
|
||||
"""Lexicographic comparison."""
|
||||
if not isinstance(other, _DimAtom): return False
|
||||
return self._syntactic_cmp(other) == 0
|
||||
|
||||
def __lt__(self, other: _DimAtom):
|
||||
"""
|
||||
Comparison to another atom in graded reverse lexicographic order.
|
||||
Used only for determining a sorting order, does not relate to the
|
||||
comparison of the values of the atom.
|
||||
"""
|
||||
if self.var is not None and other.var is not None:
|
||||
return self.var < other.var
|
||||
elif self.var is not None:
|
||||
return True
|
||||
elif other.var is not None:
|
||||
return True
|
||||
elif self.operation != other.operation:
|
||||
return self.operation < other.operation # type: ignore
|
||||
else:
|
||||
return id(self) < id(other)
|
||||
"""Lexicographic comparison."""
|
||||
return self._syntactic_cmp(other) < 0
|
||||
|
||||
def __le__(self, other: _DimAtom):
|
||||
"""Lexicographic comparison."""
|
||||
return self._syntactic_cmp(other) <= 0
|
||||
|
||||
def __gt__(self, other: _DimAtom):
|
||||
"""Lexicographic comparison."""
|
||||
return self._syntactic_cmp(other) > 0
|
||||
|
||||
def __ge__(self, other: _DimAtom):
|
||||
"""Lexicographic comparison"""
|
||||
return self._syntactic_cmp(other) >= 0
|
||||
|
||||
def bounds(self) -> tuple[float, float]:
|
||||
"""Returns the lower and upper bounds, or -+ inf."""
|
||||
@ -275,16 +279,26 @@ class _DimMon(dict):
|
||||
def from_atom(clscls, a: _DimAtom, aexp: int):
|
||||
return _DimMon({a: aexp})
|
||||
|
||||
@classmethod
|
||||
def from_operation(cls, operation: str, *operands: DimSize) -> _DimMon:
|
||||
return _DimMon({_DimAtom.from_operation(operation, *operands): 1})
|
||||
|
||||
def to_var(self) -> str | None:
|
||||
"""Extract the variable name "x", from a monomial "x".
|
||||
Return None, if the monomial is not a single variable."""
|
||||
"""Extract the variable name from a monomial.
|
||||
Return None if the monomial is not a single variable."""
|
||||
a = self.to_atom()
|
||||
return a.to_var() if a is not None else None
|
||||
|
||||
def to_atom(self) -> _DimAtom | None:
|
||||
"""Extract the single atom from a monomial.
|
||||
Return None if the monomial is not a single atom."""
|
||||
items = self.items()
|
||||
if len(items) != 1:
|
||||
return None
|
||||
(a, aexp), = items
|
||||
if aexp != 1:
|
||||
return None
|
||||
return a.to_var()
|
||||
return a
|
||||
|
||||
def get_vars(self) -> set[str]:
|
||||
# All the vars that appear in the monomial
|
||||
@ -293,23 +307,39 @@ class _DimMon(dict):
|
||||
acc.update(a.get_vars())
|
||||
return acc
|
||||
|
||||
@classmethod
|
||||
def from_operation(cls, operation: str, *operands: _DimExpr) -> _DimMon:
|
||||
return _DimMon({_DimAtom.from_operation(operation, *operands): 1})
|
||||
|
||||
@property
|
||||
def degree(self):
|
||||
return sum(self.values())
|
||||
|
||||
def _syntactic_cmp(self, other: _DimMon) -> int:
|
||||
"""Returns -1 if self < other, 0 if self == other, 1 if self > other.
|
||||
The comparison is done lexicographically (syntactic), to be used for sorting.
|
||||
The result is not related to the semantic value.
|
||||
"""
|
||||
if c := cmp_comparable(self.degree, other.degree): return c
|
||||
def cmp_atom(s_a: tuple[_DimAtom, int], o_a: tuple[_DimAtom, int]) -> int:
|
||||
if c := s_a[0]._syntactic_cmp(o_a[0]): return c
|
||||
# Consider the monomials with exponents to be expanded as multiplications.
|
||||
# Then a higher exponent for a "small" atom should lead to a "smaller" monomial.
|
||||
return - cmp_comparable(s_a[1], o_a[1])
|
||||
|
||||
return cmp_sequence(sorted(self.items()), sorted(other.items()), cmp_atom)
|
||||
|
||||
def __lt__(self, other: _DimMon):
|
||||
"""
|
||||
Comparison to another monomial in graded reverse lexicographic order.
|
||||
Used only for determining a sorting order, does not relate to the
|
||||
comparison of the values of the monomial.
|
||||
"""
|
||||
self_key = -self.degree, tuple(sorted(self))
|
||||
other_key = -other.degree, tuple(sorted(other))
|
||||
return self_key > other_key
|
||||
"""Lexicographic comparison"""
|
||||
return self._syntactic_cmp(other) < 0
|
||||
|
||||
def __le__(self, other: _DimMon):
|
||||
"""Lexicographic comparison"""
|
||||
return self._syntactic_cmp(other) <= 0
|
||||
|
||||
def __gt__(self, other: _DimMon):
|
||||
"""Lexicographic comparison"""
|
||||
return self._syntactic_cmp(other) > 0
|
||||
|
||||
def __ge__(self, other: _DimMon):
|
||||
"""Lexicographic comparison"""
|
||||
return self._syntactic_cmp(other) >= 0
|
||||
|
||||
def mul(self, other: _DimMon) -> _DimMon:
|
||||
"""
|
||||
@ -344,7 +374,6 @@ class _DimMon(dict):
|
||||
candidates = [math.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)]
|
||||
return (min(*candidates), max(*candidates)) # type: ignore
|
||||
|
||||
|
||||
def evaluate(self, env: DimVarEnv):
|
||||
prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1)
|
||||
def pow_opt(v, p: int):
|
||||
@ -366,7 +395,6 @@ class _DimExpr():
|
||||
integer coefficients. The special monomial `_DimMon()` is mapped to the
|
||||
free integer coefficient of the expression.
|
||||
"""
|
||||
|
||||
__array_priority__ = 1000 # Same as tracer, for __radd__ and others on ndarray
|
||||
def __init__(self, coeffs: dict[_DimMon, int]):
|
||||
# Do not construct _DimExpr directly, unless you are sure that coeffs is
|
||||
@ -377,6 +405,11 @@ class _DimExpr():
|
||||
def monomials(self) -> Iterable[tuple[_DimMon, int]]:
|
||||
return self._coeffs.items()
|
||||
|
||||
def monomials_sorted(self, reverse=False):
|
||||
"""The monomials in sorted lexicographic order.
|
||||
Higher-degree monomials come later in the order."""
|
||||
return sorted(self.monomials(), reverse=reverse)
|
||||
|
||||
@classmethod
|
||||
def _add_coeffs(cls, coeffs: dict[_DimMon, int], mon: _DimMon, coeff: int):
|
||||
"""Do `coeffs[mon] += coeff` but remove 0 coefficients."""
|
||||
@ -430,26 +463,39 @@ class _DimExpr():
|
||||
return _DimExpr.normalize(coeffs)
|
||||
|
||||
@classmethod
|
||||
def from_monomial(cls, mon: _DimMon, exp: int):
|
||||
return _DimExpr.normalize({mon: exp})
|
||||
def from_monomial(cls, mon: _DimMon, count: int):
|
||||
return _DimExpr.normalize({mon: count})
|
||||
|
||||
@classmethod
|
||||
def from_var(cls, v: str) -> _DimExpr:
|
||||
return _DimExpr({_DimMon.from_var(v): 1})
|
||||
|
||||
@classmethod
|
||||
def from_operation(cls, operation: str, *operands: _DimExpr) -> _DimExpr:
|
||||
def from_operation(cls, operation: str, *operands: DimSize) -> _DimExpr:
|
||||
return _DimExpr.from_monomial(_DimMon.from_operation(operation, *operands), 1)
|
||||
|
||||
def to_var(self) -> str | None:
|
||||
"""Extract the variable name "x", from a symbolic expression."""
|
||||
def to_monomial(self) -> _DimMon | None:
|
||||
"""Extract the single monomial from a symbolic expression.
|
||||
Returns None if the expression is not a single monomial."""
|
||||
items = self.monomials()
|
||||
if len(items) != 1: # type: ignore
|
||||
return None
|
||||
(mon, mon_count), = items
|
||||
if mon_count != 1:
|
||||
return None
|
||||
return mon.to_var()
|
||||
return mon
|
||||
|
||||
def to_atom(self) -> _DimAtom | None:
|
||||
"""Extract the atom from a symbolic expression.
|
||||
Returns None if the expression is not a single atom."""
|
||||
mon = self.to_monomial()
|
||||
return mon.to_atom() if mon is not None else None
|
||||
|
||||
def to_var(self) -> str | None:
|
||||
"""Extract the variable name from a symbolic expression.
|
||||
Returns None if the expression is not a single variable."""
|
||||
mon = self.to_atom()
|
||||
return mon.to_var() if mon is not None else None
|
||||
|
||||
def get_vars(self) -> set[str]:
|
||||
"""The variables that appear in a symbolic dimension."""
|
||||
@ -458,7 +504,19 @@ class _DimExpr():
|
||||
acc.update(mon.get_vars())
|
||||
return acc
|
||||
|
||||
def eq(self, other: DimSize) -> bool:
|
||||
def _syntactic_cmp(self, other: _DimExpr) -> int:
|
||||
"""Returns -1 if self < other, 0 if self == other, 1 if self > other.
|
||||
The comparison is done lexicographically (syntactic), to be used for sorting.
|
||||
The result is not related to the semantic value.
|
||||
"""
|
||||
s_mons = self.monomials_sorted()
|
||||
o_mons = other.monomials_sorted()
|
||||
def cmp_mon(s_mon: tuple[_DimMon, int], o_mon: tuple[_DimMon, int]) -> int:
|
||||
if c := s_mon[0]._syntactic_cmp(o_mon[0]): return c
|
||||
return cmp_comparable(s_mon[1], o_mon[1])
|
||||
return cmp_sequence(s_mons, o_mons, cmp_mon)
|
||||
|
||||
def eq(self, other) -> bool:
|
||||
lb, ub = _ensure_poly(self - other, "eq").bounds()
|
||||
if lb == ub == 0:
|
||||
return True
|
||||
@ -522,8 +580,11 @@ class _DimExpr():
|
||||
if c == 1:
|
||||
return str(mon)
|
||||
return f"{c}*{mon}"
|
||||
return " + ".join(_one_monomial(mon, c)
|
||||
for mon, c in sorted(self.monomials(), reverse=True))
|
||||
# We print first the "larger" monomials, so that the constant is last.
|
||||
res = " + ".join(_one_monomial(mon, c)
|
||||
for mon, c in self.monomials_sorted(reverse=True))
|
||||
res = res.replace(" + -", " - ")
|
||||
return res
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
@ -627,7 +688,7 @@ class _DimExpr():
|
||||
|
||||
# We must overload __eq__ and __ne__, or else we get unsound defaults.
|
||||
__eq__ = eq
|
||||
def __ne__(self, other: DimSize) -> bool:
|
||||
def __ne__(self, other) -> bool:
|
||||
return not self.eq(other)
|
||||
|
||||
__ge__ = ge
|
||||
@ -759,6 +820,21 @@ class _DimExpr():
|
||||
# Used for implicit coercions of polynomials as JAX arrays
|
||||
return _dim_as_value(self)
|
||||
|
||||
def cmp_comparable(i1, i2) -> int:
|
||||
if i1 < i2: return -1
|
||||
if i1 > i2: return 1
|
||||
return 0
|
||||
|
||||
def cmp_sequence(s1, s2, elem_cmp) -> int:
|
||||
"""Compares two sequences using `elem_cmp`."""
|
||||
l2 = len(s2)
|
||||
for i, e1 in enumerate(s1):
|
||||
if i >= l2: return 1
|
||||
if c := elem_cmp(e1, s2[i]): return c
|
||||
if len(s1) < l2: return -1
|
||||
return 0
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Decomposition:
|
||||
"""Decomposition of an expression around an operation atom.
|
||||
@ -819,7 +895,7 @@ dtypes._weak_types.append(_DimExpr)
|
||||
|
||||
def _convertible_to_int(p: DimSize) -> bool:
|
||||
try:
|
||||
op.index(p)
|
||||
op.index(p) # type: ignore
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
@ -1053,7 +1129,7 @@ class _Parser:
|
||||
|
||||
if core.is_constant_dim(expr) and self.like_shape is not None:
|
||||
like_shape_dim = self.like_shape[len(self.dimensions)]
|
||||
if expr != like_shape_dim:
|
||||
if expr != like_shape_dim: # type: ignore[operator]
|
||||
raise self.parse_err(tok,
|
||||
(f"different size {expr} for known dimension; "
|
||||
f"like={self.like_shape}"))
|
||||
@ -1178,11 +1254,11 @@ class _Parser:
|
||||
def atom(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
|
||||
if tok.exact_type == tokenize.NAME:
|
||||
if tok.string == _DimAtom.MOD:
|
||||
return self.binary_op(_DimAtom.MOD, self.next_tok())
|
||||
return self.atom_binary_op(_DimAtom.MOD, self.next_tok())
|
||||
if tok.string == _DimAtom.FLOORDIV:
|
||||
return self.binary_op(_DimAtom.FLOORDIV, self.next_tok())
|
||||
return self.atom_binary_op(_DimAtom.FLOORDIV, self.next_tok())
|
||||
if tok.string == _DimAtom.NON_NEGATIVE:
|
||||
return self.unary_op(_DimAtom.NON_NEGATIVE, self.next_tok())
|
||||
return self.atom_unary_op(_DimAtom.NON_NEGATIVE, self.next_tok())
|
||||
return _DimExpr.from_var(tok.string), self.next_tok()
|
||||
number_sign = 1
|
||||
if tok.exact_type == tokenize.MINUS: # -k are negative constants
|
||||
@ -1195,13 +1271,13 @@ class _Parser:
|
||||
self.expect_token(tok, [tokenize.NAME, tokenize.MINUS, tokenize.NUMBER])
|
||||
assert False
|
||||
|
||||
def unary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]:
|
||||
def atom_unary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]:
|
||||
tok = self.consume_token(tok, tokenize.LPAR)
|
||||
e1, tok = self.expr(tok)
|
||||
tok = self.consume_token(tok, tokenize.RPAR)
|
||||
return _DimExpr.from_operation(op, e1), tok # type: ignore
|
||||
|
||||
def binary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]:
|
||||
def atom_binary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]:
|
||||
tok = self.consume_token(tok, tokenize.LPAR)
|
||||
e1, tok = self.expr(tok)
|
||||
tok = self.consume_token(tok, tokenize.COMMA)
|
||||
@ -1274,7 +1350,7 @@ class CachingShapeEvaluator:
|
||||
@functools.lru_cache(128)
|
||||
def evaluate(self, e: DimSize):
|
||||
if core.is_constant_dim(e):
|
||||
res = op.index(e)
|
||||
res = op.index(e) # type: ignore
|
||||
else:
|
||||
res = e.evaluate(self.env) # type: ignore
|
||||
return res
|
||||
|
@ -1143,7 +1143,7 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfV
|
||||
partial(core.evaluate_shape, shape, dim_vars),
|
||||
dim_values, [core.dim_value_aval()] * len(dim_values), "") # type: ignore
|
||||
# Keep only the non-constant dimensions
|
||||
return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf
|
||||
return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf # type: ignore
|
||||
for d, d_tf in zip(shape, shape_values_tf))
|
||||
|
||||
|
||||
|
@ -604,9 +604,9 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
expect_error=(
|
||||
"Input shapes do not match the polymorphic shapes specification. "
|
||||
"Expected value >= 1 for dimension variable 'b'. "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b + c). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
|
||||
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
|
||||
"'b' = 0 from specification 'a + 2*b' for dimension args[0].shape[0] (= 2), . "
|
||||
"'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . "
|
||||
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
|
||||
)),
|
||||
dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer
|
||||
@ -614,7 +614,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
expect_error=(
|
||||
"Input shapes do not match the polymorphic shapes specification. "
|
||||
"Division had remainder 1 when computing the value of 'b'. "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b + c). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
|
||||
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . "
|
||||
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
|
||||
)),
|
||||
@ -622,10 +622,10 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
poly_spec="(a + 2*b, a, a + b)",
|
||||
expect_error=(
|
||||
"Input shapes do not match the polymorphic shapes specification. "
|
||||
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification 'a + 2*b' (= 10). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b). "
|
||||
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification '2*b + a' (= 10). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). "
|
||||
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
|
||||
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= 6), . "
|
||||
"'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . "
|
||||
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
|
||||
)),
|
||||
dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c
|
||||
@ -633,7 +633,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
expect_error=(
|
||||
"Cannot solve for values of dimension variables {'c'}. "
|
||||
"We can only solve linear uni-variate constraints. "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (2*a + b, a, c^2). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). "
|
||||
"Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. "
|
||||
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details."
|
||||
)),
|
||||
@ -1501,7 +1501,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
polymorphic_shapes=[x_polymorphic_shape, y_polymorphic_shape])(x, y)
|
||||
self.assertEqual(np.float32, zw_specs[0].dtype)
|
||||
self.assertEqual(np.float32, zw_specs[1].dtype)
|
||||
self.assertEqual(("(a, 5)", "(a + b, 5)"), zw_polymorphic_shapes)
|
||||
self.assertEqual(("(a, 5)", "(b + a, 5)"), zw_polymorphic_shapes)
|
||||
|
||||
# We can use the zw_polymorphic_shapes for jax2tf.convert
|
||||
z, w = jax2tf.convert(
|
||||
|
@ -590,10 +590,10 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c",
|
||||
expect_error_outer_exp=re.escape(
|
||||
"Expected value >= 1 for dimension variable 'b'. "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (3, a, a + b). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (3, a, b + a). "
|
||||
"Obtained dimension variables: 'a' = 4 from specification "
|
||||
"'a' for dimension args[0].shape[1] (= 4), "
|
||||
"'b' = c + -4 from specification 'a + b' for dimension args[0].shape[2] (= c),")),
|
||||
"'b' = c - 4 from specification 'b + a' for dimension args[0].shape[2] (= c),")),
|
||||
dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,c",
|
||||
expect_error_outer_exp=re.escape(
|
||||
"Found inconsistency between dimension size "
|
||||
@ -610,10 +610,10 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
|
||||
expect_error_outer_exp=re.escape(
|
||||
"Expected value >= 1 for dimension variable 'b'. "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (3, a, a + b). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (3, a, b + a). "
|
||||
"Obtained dimension variables: 'a' = c from "
|
||||
"specification 'a' for dimension args[0].shape[1] (= c), "
|
||||
"'b' = 0 from specification 'a + b' for dimension args[0].shape[2] (= c)")),
|
||||
"'b' = 0 from specification 'b + a' for dimension args[0].shape[2] (= c)")),
|
||||
dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
|
||||
expect_error_outer_exp=re.escape(
|
||||
"Shape mismatch for args[0].shape[0] (expected same constant)")),
|
||||
@ -689,9 +689,9 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
expect_error=(
|
||||
"Input shapes do not match the polymorphic shapes specification. "
|
||||
"Expected value >= 1 for dimension variable 'b'. "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b + c). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
|
||||
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
|
||||
"'b' = 0 from specification 'a + 2*b' for dimension args[0].shape[0] (= 2), . "
|
||||
"'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . "
|
||||
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
|
||||
)),
|
||||
dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer
|
||||
@ -699,7 +699,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
expect_error=(
|
||||
"Input shapes do not match the polymorphic shapes specification. "
|
||||
"Division had remainder 1 when computing the value of 'b'. "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b + c). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
|
||||
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . "
|
||||
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
|
||||
)),
|
||||
@ -707,10 +707,10 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
poly_spec="(a + 2*b, a, a + b)",
|
||||
expect_error=(
|
||||
"Input shapes do not match the polymorphic shapes specification. "
|
||||
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification 'a + 2*b' (= 10). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b). "
|
||||
"Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification '2*b + a' (= 10). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). "
|
||||
"Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
|
||||
"'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= 6), . "
|
||||
"'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . "
|
||||
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
|
||||
)),
|
||||
dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c
|
||||
@ -718,7 +718,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
expect_error=(
|
||||
"Cannot solve for values of dimension variables {'c'}. "
|
||||
"We can only solve linear uni-variate constraints. "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (2*a + b, a, c^2). "
|
||||
"Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). "
|
||||
"Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. "
|
||||
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details."
|
||||
)),
|
||||
|
@ -139,14 +139,34 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
("a + -1", a - 1),
|
||||
("3 * a * mod(a + 2, b + 2)", 3 * a * ((a + 2) % (b + 2))),
|
||||
("3 * floordiv(a + 2, b + 2) * 2", 3 * ((a + 2) // (b + 2)) * 2),
|
||||
("non_negative(a - 2)", core.non_negative_dim(a - 2)),
|
||||
("non_negative(a - 2)", "build_inside"),
|
||||
]])
|
||||
def test_parse_dim(self,
|
||||
dim_spec="-2 * a^2 * b + b^2",
|
||||
dim_poly=-2 * a * a * b + b * b):
|
||||
def test_parse_dim(self, dim_spec, dim_poly):
|
||||
if dim_spec == "non_negative(a - 2)":
|
||||
dim_poly = core.non_negative_dim(DimExprTest.a - 2)
|
||||
self.assertEqual((dim_poly,), shape_poly.symbolic_shape(dim_spec))
|
||||
self.assertEqual((dim_poly,), shape_poly.symbolic_shape(str(dim_poly)))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(dim_spec=dim_spec)
|
||||
for dim_spec in [
|
||||
"b + a",
|
||||
"a*b + a^2 + b + a",
|
||||
"mod(a, 4) + floordiv(a, 4) + a",
|
||||
"2*a^2 - 3*a - 1",
|
||||
"a^2 + 3*a - 1",
|
||||
"-1*a + 3",
|
||||
"-1*mod(a, 4) + 3",
|
||||
"-2*a + 3",
|
||||
"a*floordiv(b, 8)*mod(b, 4)",
|
||||
]
|
||||
]
|
||||
)
|
||||
def test_print_dim(self, *, dim_spec: str):
|
||||
e, = shape_poly.symbolic_shape(dim_spec)
|
||||
self.assertEqual(str(e), dim_spec)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
# sanitized shape_spec sometimes collide
|
||||
@ -233,6 +253,50 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.assertTrue(core.definitely_equal(1, jnp.add(0, 1))) # An Array
|
||||
self.assertFalse(core.definitely_equal(1, "a"))
|
||||
|
||||
def test_atoms_ordering(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
|
||||
self.assertTrue(a.to_atom() < b.to_atom())
|
||||
self.assertFalse(a.to_atom() >= b.to_atom())
|
||||
self.assertTrue(a.to_atom() <= b.to_atom())
|
||||
self.assertTrue(a.to_atom() != b.to_atom())
|
||||
|
||||
self.assertTrue(a.to_atom() < (a % 4).to_atom())
|
||||
self.assertFalse(a.to_atom() > (a % 4).to_atom())
|
||||
# FLOORDIV comes before MON because we compare operations alphabetically
|
||||
self.assertTrue((a // 4).to_atom() < (a % 4).to_atom())
|
||||
|
||||
self.assertEqual(hash((a // 4).to_atom()), hash((a // 4).to_atom()))
|
||||
|
||||
def test_monomial_ordering(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
self.assertTrue(a.to_monomial() < b.to_monomial())
|
||||
self.assertTrue(a.to_monomial() <= b.to_monomial())
|
||||
self.assertTrue(b.to_monomial() >= a.to_monomial())
|
||||
self.assertTrue(b.to_monomial() > a.to_monomial())
|
||||
|
||||
self.assertTrue(a.to_monomial() < (a * a).to_monomial())
|
||||
self.assertTrue(b.to_monomial() < (a * a).to_monomial())
|
||||
self.assertTrue((a * a * b).to_monomial() < (a * b * b).to_monomial())
|
||||
e1 = a * a * b + a * b * b + a * b + a * a + a + b
|
||||
|
||||
sorted_e1 = [shape_poly._DimExpr.from_monomial(m, m_count)
|
||||
for m, m_count in e1.monomials_sorted()]
|
||||
self.assertSequenceEqual(sorted_e1,
|
||||
[a, b, a * a, a * b, a * a * b, a * b * b])
|
||||
|
||||
e2 = a * (a // 4) + (a // 4) + b * (a // 4) + b * (a % 4) + a * a + b
|
||||
sorted_e2 = [shape_poly._DimExpr.from_monomial(m, m_count)
|
||||
for m, m_count in e2.monomials_sorted()]
|
||||
self.assertSequenceEqual(sorted_e2,
|
||||
[b, a // 4, a * a, a * (a // 4), b * (a // 4), b * (a % 4)])
|
||||
|
||||
# This failed with a previous implementation of atom equality
|
||||
self.assertNotEqual(shape_poly._DimMon.from_operation(shape_poly._DimAtom.NON_NEGATIVE,
|
||||
a - b - 1),
|
||||
shape_poly._DimMon.from_operation(shape_poly._DimAtom.NON_NEGATIVE,
|
||||
a - 2*b - 1))
|
||||
|
||||
def test_poly_bounds(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
bounded_le4 = 5 - a
|
||||
@ -437,7 +501,7 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
(a * a - b * b, a + b, a - b, 0),
|
||||
(a, b, "floordiv(a, b)", "mod(a, b)"),
|
||||
(3 * a, 2, "floordiv(3*a, 2)", "mod(3*a, 2)"),
|
||||
(2 * a * b + b * b, a + b, "floordiv(2*a*b + b^2, a + b)", "mod(2*a*b + b^2, a + b)"),
|
||||
(2 * a * b + b * b, a + b, "floordiv(b^2 + 2*a*b, b + a)", "mod(b^2 + 2*a*b, b + a)"),
|
||||
(3, a, "floordiv(3, a)", "mod(3, a)"),
|
||||
]])
|
||||
def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
|
||||
|
Loading…
x
Reference in New Issue
Block a user