mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19231 from gnecula:poly_eq
PiperOrigin-RevId: 596400074
This commit is contained in:
commit
4998c80bcd
@ -19,6 +19,13 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
devices.
|
||||
* {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable`
|
||||
and `descending` arguments.
|
||||
* Several changes to the handling of shape polymorphism (for
|
||||
{mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`): cleaner
|
||||
pretty-printing of symbolic expressions ({jax-issue}`#19227`); simplified
|
||||
and faster equality comparisons, where we consider two symbolic dimensions
|
||||
to be equal if the normalized form of their difference reduces to 0
|
||||
({jax-issue}`#19231`; note that this may result in user-visible behavior
|
||||
changes).
|
||||
* Deprecations & Removals
|
||||
* A number of previously deprecated functions have been removed, following a
|
||||
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
|
||||
|
@ -527,14 +527,24 @@ class _DimExpr():
|
||||
return cmp_comparable(s_mon[1], o_mon[1])
|
||||
return cmp_sequence(s_mons, o_mons, cmp_mon)
|
||||
|
||||
def eq(self, other) -> bool:
|
||||
lb, ub = _ensure_poly(self - other, "eq").bounds()
|
||||
if lb == ub == 0:
|
||||
return True
|
||||
if lb > 0 or ub < 0:
|
||||
def eq(self, other: _DimExpr) -> bool:
|
||||
# Equality is used very frequently because expressions are cached. We could
|
||||
# implement a more precise version based on `(self - other).bounds() = (0, 0)`
|
||||
# but that would be too expensive. It would also have the unfortunate drawback
|
||||
# that we cannot then cache `e.bounds()` because hashing invokes equality
|
||||
# which would lead to infinite recursion.
|
||||
diff = self - other
|
||||
|
||||
# We look for `self - other == k`, and we rely on the fact that when we
|
||||
# normalize _DimExpr that represent integers as ints.
|
||||
if is_poly_dim(diff):
|
||||
# Here we really ought to raise InconclusiveDimensionOperation, but __eq__
|
||||
# cannot raise exceptions, because it is used indirectly when hashing.
|
||||
# So, we say that the expressions are disequal, which is really unsound.
|
||||
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported
|
||||
return False
|
||||
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported
|
||||
return False
|
||||
|
||||
return diff == 0
|
||||
|
||||
def inconclusive_comparison(self, operation: str, op: Any) -> Exception:
|
||||
return InconclusiveDimensionOperation(
|
||||
@ -698,9 +708,15 @@ class _DimExpr():
|
||||
raise InconclusiveDimensionOperation(f"Symbolic dimension '{self}' used in a context that requires a constant")
|
||||
|
||||
# We must overload __eq__ and __ne__, or else we get unsound defaults.
|
||||
__eq__ = eq
|
||||
def __ne__(self, other) -> bool:
|
||||
return not self.eq(other)
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, _DimExpr) and not core.is_constant_dim(other):
|
||||
return False
|
||||
else:
|
||||
other = _ensure_poly(other, "eq")
|
||||
return self.eq(other)
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
__ge__ = ge
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user