Merge pull request #19231 from gnecula:poly_eq

PiperOrigin-RevId: 596400074
This commit is contained in:
jax authors 2024-01-07 10:24:36 -08:00
commit 4998c80bcd
2 changed files with 33 additions and 10 deletions

View File

@ -19,6 +19,13 @@ Remember to align the itemized text with the first line of an item within a list
devices.
* {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable`
and `descending` arguments.
* Several changes to the handling of shape polymorphism (for
{mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`): cleaner
pretty-printing of symbolic expressions ({jax-issue}`#19227`); simplified
and faster equality comparisons, where we consider two symbolic dimensions
to be equal if the normalized form of their difference reduces to 0
({jax-issue}`#19231`; note that this may result in user-visible behavior
changes).
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).

View File

@ -527,14 +527,24 @@ class _DimExpr():
return cmp_comparable(s_mon[1], o_mon[1])
return cmp_sequence(s_mons, o_mons, cmp_mon)
def eq(self, other) -> bool:
lb, ub = _ensure_poly(self - other, "eq").bounds()
if lb == ub == 0:
return True
if lb > 0 or ub < 0:
def eq(self, other: _DimExpr) -> bool:
# Equality is used very frequently because expressions are cached. We could
# implement a more precise version based on `(self - other).bounds() = (0, 0)`
# but that would be too expensive. It would also have the unfortunate drawback
# that we cannot then cache `e.bounds()` because hashing invokes equality
# which would lead to infinite recursion.
diff = self - other
# We look for `self - other == k`, and we rely on the fact that when we
# normalize _DimExpr that represent integers as ints.
if is_poly_dim(diff):
# Here we really ought to raise InconclusiveDimensionOperation, but __eq__
# cannot raise exceptions, because it is used indirectly when hashing.
# So, we say that the expressions are disequal, which is really unsound.
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported
return False
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported
return False
return diff == 0
def inconclusive_comparison(self, operation: str, op: Any) -> Exception:
return InconclusiveDimensionOperation(
@ -698,9 +708,15 @@ class _DimExpr():
raise InconclusiveDimensionOperation(f"Symbolic dimension '{self}' used in a context that requires a constant")
# We must overload __eq__ and __ne__, or else we get unsound defaults.
__eq__ = eq
def __ne__(self, other) -> bool:
return not self.eq(other)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, _DimExpr) and not core.is_constant_dim(other):
return False
else:
other = _ensure_poly(other, "eq")
return self.eq(other)
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
__ge__ = ge