# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shape polymorphism support.

See documentation at https://jax.readthedocs.io/en/latest/export/shape_poly.html.
"""

from __future__ import annotations

import enum
from collections.abc import Callable, Sequence
import dataclasses
from enum import Enum
import functools
import itertools
import io
import copy
import operator as op
import tokenize
from typing import Any, Union, overload
import warnings

import numpy as np
import opt_einsum

import jax
from jax.interpreters import xla

from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import effects
from jax._src.lax import lax
from jax._src.interpreters import mlir
from jax._src.numpy import lax_numpy
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import util

DimSize = Union["_DimExpr", int]
TfVal = Any
DimVarEnv = dict[str, jax.Array]
DType = Any

# Tuples of terms and their coefficients, sorted with the largest term first.
SortedTerms = Sequence[tuple["_DimTerm", int]]
SortedFactors = Sequence[tuple["_DimFactor", int]]

# Normalization rules represent the explicit constraint `t*tk == e` as
# a mapping of `t` to `(e, tk)`.
NormalizationRules = dict["_DimTerm", tuple["_DimExpr", int]]
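
# For intuition, a minimal sketch of how these aliases fit together
# (hypothetical values, not executed): the expression `3*a^2*b + 2` is stored
# as the SortedTerms
#
#   ((<term a^2*b>, 3), (_DimTerm_one, 2))
#
# and an explicit constraint such as "mod(a, 2) == 0" becomes the
# normalization rule mapping <term mod(a, 2)> -> (<expr 0>, 1), so multiples
# of that term can be rewritten away during normalization.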


class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation):
  """Raised when we cannot conclusively compute with symbolic dimensions."""

  _help_msg = """
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.

Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported
for more details.
"""

  def __init__(self, message: str):
    error_msg = f"{message}{InconclusiveDimensionOperation._help_msg}"
    # https://github.com/python/mypy/issues/5887
    super().__init__(error_msg)


class UnexpectedDimVar(Exception):
  pass


class Comparator(Enum):
  EQ = 1
  GEQ = 2


@dataclasses.dataclass(frozen=True)
class _SymbolicConstraint:
  # Either e1 == e2 if cmp == Comparator.EQ else e1 >= e2
  cmp: Comparator
  debug_str: str  # The form in which the user expressed it, for error messages
  e1: DimSize  # This has been normalized w.r.t. previous constraints only
  e2: DimSize  # This has been normalized w.r.t. previous constraints only

  def __repr__(self):
    return f"Constraint({self.debug_str})"


class _DimFactor:
  """Represents a factor in a symbolic dimension expression.

  Factors are either variables, or expressions of the form floordiv(E1, E2),
  mod(E1, E2), max(E1, E2), or min(E1, E2).
  Factors are multiplied to form terms (see _DimTerm), and
  terms are added to form symbolic expressions (see _DimExpr).

  Args:
    * var: if specified then the factor is a dimension variable. `operation`
      must be `None`.
    * operation: if specified then the factor is an operation applied to
      `operands`. One of `FLOORDIV` or `MOD` or `MAX` or `MIN`.
      `var` must be `None`.
    * operands: the operands to which the operation is applied.
  """
  # The supported operations.
  # FLOORDIV(e1, e2) and MOD(e1, e2) have the same semantics as in Python:
  # FLOORDIV(e1, e2) = e1 // e2 = floor(e1 / e2)
  # if e2 > 0 then 0 <= MOD(e1, e2) < e2
  # if e2 < 0 then e2 < MOD(e1, e2) <= 0
  # e1 = e2 * FLOORDIV(e1, e2) + MOD(e1, e2)
  #
  FLOORDIV = "floordiv"
  MOD = "mod"
  MAX = "max"
  MIN = "min"
  NON_NEGATIVE = "non_negative"  # The max of the operand and 0. Replaced with
                                 # max, but kept here for backwards compatibility.

  __slots__ = ["var", "operation", "operands", "_hash", "_size"]

  def __init__(self, *operands: _DimExpr,
               var: str | None = None,
               operation: str | None = None):
    if var is not None:
      assert operation is None
      assert not operands
    else:
      assert operation is not None
    self.var = var
    self.operation = operation
    self.operands = operands
    self._hash = None
    self._size: int = 1 if var is not None else 1 + sum(o._size for o in operands)

  @staticmethod
  def from_var(v: str) -> _DimFactor:
    return _DimFactor(var=v)

  @staticmethod
  def from_operation(operation: str, *operands: DimSize,
                     scope: SymbolicScope) -> _DimFactor:
    return _DimFactor(*(_ensure_poly(o, operation, scope) for o in operands),
                      operation=operation)

  def to_var(self) -> str | None:
    return self.var

  def get_vars(self) -> set[str]:
    # All the vars that appear
    if self.var is not None:
      return {self.var}
    else:
      acc = set()
      for opnd in self.operands:
        acc.update(opnd._get_vars())
      return acc

  def __str__(self):
    if self.var is not None:
      return self.var
    opnd_str = ", ".join([str(opnd) for opnd in self.operands])
    return f"{self.operation}({opnd_str})"
  __repr__ = __str__

  def __hash__(self):
    if self._hash is None:
      self._hash = hash((self.var, self.operation, *self.operands))
    return self._hash

  def _syntactic_cmp(self, other: _DimFactor) -> int:
    """Returns -1 if self < other, 0 if self == other, 1 if self > other.
    The comparison is done lexicographically (syntactically), to be used for sorting.
    The result is not related to the semantic value.
    """
    if c := cmp_comparable(self._size, other._size): return c
    if self.var is not None:
      return cmp_comparable(self.var, other.var)
    if c := cmp_comparable(self.operation, other.operation): return c
    return cmp_sequence(self.operands, other.operands,
                        lambda s_o, o_o: s_o._syntactic_cmp(o_o))

  def __eq__(self, other: Any):
    """Lexicographic comparison."""
    if not isinstance(other, _DimFactor): return False
    return self._syntactic_cmp(other) == 0

  def __lt__(self, other: _DimFactor):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) < 0

  def __le__(self, other: _DimFactor):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) <= 0

  def __gt__(self, other: _DimFactor):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) > 0

  def __ge__(self, other: _DimFactor):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) >= 0

  def evaluate(self, env: DimVarEnv, scope: SymbolicScope):
    if self.var is not None:
      try:
        return env[self.var]
      except KeyError:
        # Perhaps there is a normalization rule for this variable
        normalized_var = _DimExpr._from_var(self.var, scope)
        if core.is_constant_dim(normalized_var):
          return normalized_var
        non_trivial_normalization = (v1 := normalized_var._to_var()) is None or v1 != self.var  # type: ignore
        if non_trivial_normalization:
          return normalized_var._evaluate(env)  # type: ignore
        err_msg = (
            f"Encountered dimension variable '{self.var}' that does not appear in the shapes of the function arguments.\n"
            "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
        raise UnexpectedDimVar(err_msg)
    else:
      operand_values = [opnd._evaluate(env) for opnd in self.operands]
      if self.operation == _DimFactor.FLOORDIV:
        return divmod(*operand_values)[0]  # type: ignore
      elif self.operation == _DimFactor.MOD:
        return divmod(*operand_values)[1]  # type: ignore
      elif self.operation == _DimFactor.MAX:
        op1, op2 = operand_values
        if core.is_constant_dim(op1) and core.is_constant_dim(op2):
          return max(op1, op2)
        if core.is_symbolic_dim(op1) or core.is_symbolic_dim(op2):
          return core.max_dim(op1, op2)
        # In the context of `evaluate` dimension variables may be mapped to
        # JAX Tracers.
        return lax.max(op1, op2)
      elif self.operation == _DimFactor.MIN:
        op1, op2 = operand_values
        if core.is_constant_dim(op1) and core.is_constant_dim(op2):
          return min(op1, op2)
        if core.is_symbolic_dim(op1) or core.is_symbolic_dim(op2):
          return core.min_dim(op1, op2)
        # In the context of `evaluate` dimension variables may be mapped to
        # JAX Tracers.
        return lax.min(op1, op2)
      else:
        assert False, self.operation

  def __deepcopy__(self, memo):
    return _DimFactor(*copy.deepcopy(self.operands, memo),
                      var=copy.deepcopy(self.var, memo),
                      operation=copy.deepcopy(self.operation, memo))
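
# For illustration only (internal API; a scope is required to build operation
# factors, and the names below are hypothetical):
#
#   f_a = _DimFactor.from_var("a")    # the variable factor `a`
#   str(f_a)                          # -> "a"
#   # f_mod = _DimFactor.from_operation(_DimFactor.MOD, a_expr, 2, scope=scope)
#   # str(f_mod)                      # -> "mod(a, 2)"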


class _DimTerm:
  """Represents a multiplication of factors.

  The representation is a sequence of _DimFactor factors along with their
  integer exponents (>= 1). The empty sequence represents the constant 1.
  """
  __slots__ = ["_factors", "_hash", "_size"]

  def __init__(self, sorted_factors: SortedFactors):
    self._factors = sorted_factors
    self._hash = None
    self._size = sum((1 + f_exp * f._size) for f, f_exp in self._factors)

  def __hash__(self):
    if self._hash is None:
      self._hash = hash(tuple(self._factors))
    return self._hash

  def __str__(self):
    return "*".join(f"{fact}^{exponent}" if exponent != 1 else str(fact)
                    for fact, exponent in sorted(self._factors))

  __repr__ = __str__

  @staticmethod
  def from_var(v: str) -> _DimTerm:
    return _DimTerm(((_DimFactor.from_var(v), 1),))

  @staticmethod
  def from_factor(f: _DimFactor, f_exp: int):
    return _DimTerm(((f, f_exp),))

  @staticmethod
  def from_operation(operation: str, *operands: DimSize,
                     scope: SymbolicScope) -> _DimTerm:
    return _DimTerm(((_DimFactor.from_operation(operation, *operands,
                                                scope=scope), 1),))

  def to_var(self) -> str | None:
    """Extract the variable name from a term.
    Return None if the term is not a single variable."""
    a = self.to_factor()
    return a.to_var() if a is not None else None

  def to_factor(self) -> _DimFactor | None:
    """Extract the single factor from a term.
    Return None if the term is not a single factor."""
    if len(self._factors) > 1: return None
    (f, f_exp), = self._factors
    if f_exp != 1: return None
    return f

  def get_vars(self) -> set[str]:
    # All the vars that appear in the term.
    acc = set()
    for f, _ in self._factors:
      acc.update(f.get_vars())
    return acc

  @property
  def is_constant(self):
    return not self._factors

  def _syntactic_cmp(self, other: _DimTerm) -> int:
    """Returns -1 if self < other, 0 if self == other, 1 if self > other.
    The comparison is done lexicographically (syntactically), to be used for sorting.
    The result is not related to the semantic value.
    """
    if c := cmp_comparable(self._size, other._size): return c
    def cmp_factor(s_f: tuple[_DimFactor, int], o_f: tuple[_DimFactor, int]) -> int:
      if c := s_f[0]._syntactic_cmp(o_f[0]): return c
      # Consider the terms with exponents to be expanded as multiplications.
      # Then a higher exponent for a "large" factor should lead to a "larger" term.
      return cmp_comparable(s_f[1], o_f[1])

    return cmp_sequence(self._factors, other._factors, cmp_factor)

  def __lt__(self, other: _DimTerm):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) < 0

  def __le__(self, other: _DimTerm):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) <= 0

  def __gt__(self, other: _DimTerm):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) > 0

  def __ge__(self, other: _DimTerm):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) >= 0

  def __eq__(self, other) -> bool:
    if not isinstance(other, _DimTerm): return False
    return self._syntactic_cmp(other) == 0

  def __ne__(self, other) -> bool:
    return not (self == other)

  def mul(self, other: _DimTerm) -> _DimTerm:
    """
    Returns the product with another term. Example: (n^2*m) * n == n^3 * m.
    """
    return _DimTerm(_DimExpr._linear_combination_sorted_pairs(self._factors, 0, 1,
                                                              other._factors, 0, 1))

  def divide(self, divisor: _DimTerm) -> _DimTerm:
    """
    Divides by another term. Raises an InconclusiveDimensionOperation
    if the result is not a term.
    For example, (n^3 * m) // n == n^2*m, but n // m fails.
    """
    new_factors = _DimExpr._linear_combination_sorted_pairs(self._factors, 0, 1,
                                                            divisor._factors, 0, -1)
    for _, f_exp in new_factors:
      if f_exp <= 0:
        raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.")
    return _DimTerm(new_factors)

  def evaluate(self, env: DimVarEnv, scope: SymbolicScope):
    prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1)
    def pow_opt(v, p: int):
      return v if p == 1 else prod([v] * p)
    return prod([pow_opt(f.evaluate(env, scope), exp) for f, exp in self._factors])

  def __deepcopy__(self, memo):
    return _DimTerm(copy.deepcopy(self._factors, memo))


# The constant 1, as a term.
_DimTerm_one = _DimTerm(())
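
# For illustration only (internal API; a minimal sketch, not executed here):
#
#   t = _DimTerm.from_var("a").mul(_DimTerm.from_var("b"))   # the term a*b
#   str(t)                              # -> "a*b"
#   t.divide(_DimTerm.from_var("b"))    # -> the term a
#   # t.divide(_DimTerm.from_var("c"))  # raises InconclusiveDimensionOperation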


class _DimExpr:
  """Symbolic expressions using dimension variables.

  A dimension expression is an addition of terms (_DimTerm), which themselves
  are products of factors (_DimFactor).

  The representation of a _DimExpr is a sequence of pairs `(term, coeff)`,
  representing the linear combination of terms with the given coefficients.
  The sequence is sorted by lexicographic (syntactic) ordering of `_DimTerm`,
  with the largest terms first. The special term `_DimTerm_one` is mapped
  to the free integer coefficient of the expression.

  We overload integer operations, but we do that soundly, raising
  :class:`InconclusiveDimensionOperation` when the result is not
  representable as a _DimExpr.
  """
  __array_priority__ = 1000   # Same as tracer, for __radd__ and others on ndarray
  __slots__ = ("_sorted_terms", "_scope", "_hash", "_size")

  def __init__(self, sorted_terms: SortedTerms,
               scope: SymbolicScope):
    # Do not construct _DimExpr directly, unless you are sure that `terms` is
    # normalized; use _DimExpr._normalize_sorted_terms instead.
    self._sorted_terms = tuple(sorted_terms) or ((_DimTerm_one, 0),)
    self._scope = scope
    self._hash = None
    # _size speeds up _syntactic_cmp, which is used a lot for hashing.
    self._size = sum((1 + abs(m_count) * m._size)
                     for m, m_count in self._sorted_terms)

  @property
  def scope(self):
    # We make the expression scope visible, but read-only.
    return self._scope

  @staticmethod
  def _coeff_to_sorted_terms(coeffs: dict[_DimTerm, int]) -> SortedTerms:
    return sorted((p for p in coeffs.items() if p[1] != 0), reverse=True)

  @staticmethod
  def _from_term(t: _DimTerm, t_k: int, scope: SymbolicScope) -> DimSize:
    return _DimExpr._normalize_sorted_terms(((t, t_k),), scope)

  @staticmethod
  def _from_var(v: str, scope: SymbolicScope) -> DimSize:
    return _DimExpr._normalize_sorted_terms(((_DimTerm.from_var(v), 1),), scope)

  @staticmethod
  def _from_operation(operation: str, *operands: DimSize,
                      scope: SymbolicScope) -> DimSize:
    if operation == _DimFactor.NON_NEGATIVE:  # For parsing, for backwards compatibility
      return _DimExpr._from_term(
          _DimTerm.from_operation(_DimFactor.MAX, *operands, 0,
                                  scope=scope), 1,
          scope=scope)
    return _DimExpr._from_term(
        _DimTerm.from_operation(operation, *operands, scope=scope), 1,
        scope=scope)

  @property
  def _leading_term(self) -> tuple[_DimTerm, int]:
    """Returns the highest degree term, i.e., the lexicographically largest."""
    return self._sorted_terms[0]

  def _to_single_term(self) -> tuple[int, int, _DimTerm] | None:
    """Extracts the single term: when the expression is `k + c * term`,
    returns (k, c, term); returns None otherwise.
    """
    n1 = 0
    n2 = 0
    term = None
    for t, t_k in self._sorted_terms:
      if t.is_constant:
        n1 = t_k
        continue
      if term is None:
        term = t
        n2 = t_k
        continue
      return None
    assert term is not None
    return (n1, n2, term)

  @staticmethod
  def _add_coeff(coeffs: dict[_DimTerm, int], t: _DimTerm, coeff: int):
    """coeffs[t] += coeff, skipping zero coefficients."""
    if coeff == 0: return
    coeffs[t] = coeffs.get(t, 0) + coeff

  @staticmethod
  def _normalize_term(t: _DimTerm, t_k: int,
                      scope: SymbolicScope) -> Sequence[tuple[_DimTerm, int]]:
    # If (t, t_k) is among the scope normalization rules, then return
    # a list of `term * coefficient` to add to the expression containing (t, t_k).
    # Returns the empty sequence if no normalizations are necessary.
    if not scope._normalization_rules: return []
    updates = []
    after, t_k_after = scope._normalization_rules.get(t, (None, 0))
    if after is not None and t_k % t_k_after == 0:
      # We have t*t_k_after -> after.
      # We subtract `t*t_k` and add `after * (t_k // t_k_after)`.
      updates.append((t, - t_k))
      updates.extend((t2, tc2 * (t_k // t_k_after))
                     for t2, tc2 in after._sorted_terms)
      return updates

    if len(t._factors) <= 1:
      return updates

    # A product of factors; look up individually
    for f, fexp in t._factors:
      f_after, f_k_after = scope._normalization_rules.get(_DimTerm(((f, fexp),)), (None, 0))
      if f_after is not None and t_k % f_k_after == 0:
        # We subtract `t*t_k`.
        updates.append((t, - t_k))
        # And add `(t // f**fexp) * f_after * (t_k // f_k_after)`
        t_without_f = t.divide(_DimTerm(((f, fexp),)))
        updates.extend((t2.mul(t_without_f), tc2 * (t_k // f_k_after))
                       for t2, tc2 in f_after._sorted_terms)
        return updates
    return updates

  @staticmethod
  def _normalize_sorted_terms(terms: SortedTerms,
                              scope: SymbolicScope) -> DimSize:
    """Constructs a _DimExpr in normal form from sorted terms.

    Ensures that the symbolic dimension is normalized: it has no
    terms with coefficient 0, it reflects all the scope
    normalization_rules, and it is represented as a Python integer if it is
    known to be a constant.

    Does not attempt to normalize the keys (terms) inside `terms`.
    """
    for t, t_k in terms:
      assert t_k != 0
      if updates := _DimExpr._normalize_term(t, t_k, scope):
        coeffs = dict(terms)
        for t1, t1_k in updates:
          _DimExpr._add_coeff(coeffs, t1, t1_k)
        terms = _DimExpr._coeff_to_sorted_terms(coeffs)
        # TODO: check the case when we need to apply multiple normalizations
        break

    if not terms: return 0
    if terms[0][0].is_constant: return terms[0][1]
    return _DimExpr(terms, scope)

  def _to_term(self) -> _DimTerm | None:
    """Extract the single term from a symbolic expression.
    Returns None if the expression is not a single term."""
    if len(self._sorted_terms) > 1: return None
    (t, t_k), = self._sorted_terms
    return t if t_k == 1 else None

  def _to_factor(self) -> _DimFactor | None:
    """Extract the factor from a symbolic expression.
    Returns None if the expression is not a single factor."""
    t = self._to_term()
    return t.to_factor() if t is not None else None

  def _to_var(self) -> str | None:
    """Extract the variable name from a symbolic expression.
    Returns None if the expression is not a single variable."""
    mon = self._to_factor()
    return mon.to_var() if mon is not None else None

  @staticmethod
  def _to_constant(e: DimSize) -> int | None:
    """Extract the constant from a symbolic expression.
    Returns None if the expression is not a single constant."""
    if not isinstance(e, _DimExpr):
      return int(e)
    m, m_c = e._leading_term
    return m_c if m.is_constant else None

  @property
  def _is_constant(self):
    return _DimExpr._to_constant(self) is not None

  def _get_vars(self) -> set[str]:
    """The variables that appear in a symbolic dimension."""
    acc = set()
    for mon, _ in self._sorted_terms:
      acc.update(mon.get_vars())
    return acc

  # There are some uses already of `get_vars`; we keep it a while longer
  # for backwards compatibility.
  get_vars = _get_vars

  @overload
  @staticmethod
  def _linear_combination_sorted_pairs(
      e1: SortedTerms, i1: int, f1: int,
      e2: SortedTerms, i2: int, f2: int) -> SortedTerms: ...  # type: ignore[bad-return-type,unused-ignore]

  @overload
  @staticmethod
  def _linear_combination_sorted_pairs(
      e1: SortedFactors, i1: int, f1: int,
      e2: SortedFactors, i2: int, f2: int) -> SortedFactors: ...  # type: ignore[bad-return-type,unused-ignore]

  @staticmethod
  def _linear_combination_sorted_pairs(
      pairs1, i1, f1,
      pairs2, i2, f2):
    """Computes pairs1[i1:] * f1 + pairs2[i2:] * f2.

    pairs1, pairs2, and the result are sorted with the largest term first.
    This is an optimization for a common operation; the unoptimized code would
    compute each subexpression in turn. It works for both SortedTerms and
    SortedFactors.
    """
    len1 = len(pairs1)
    len2 = len(pairs2)
    acc = []
    while i1 < len1 and i2 < len2:
      m1, m1_c = pairs1[i1]
      m2, m2_c = pairs2[i2]
      cmp = m1._syntactic_cmp(m2)  # Pick the largest term
      if cmp < 0:
        acc.append((m2, m2_c * f2))
        i2 += 1
      elif cmp > 0:
        acc.append((m1, m1_c * f1))
        i1 += 1
      else:  # They are equal; combine them
        i1 += 1
        i2 += 1
        m1_c = m1_c * f1 + m2_c * f2
        if m1_c == 0: continue
        acc.append((m1, m1_c))

    if i1 < len1:
      acc.extend((m1, m1_c * f1) for m1, m1_c in itertools.islice(pairs1, i1, len1) if m1_c != 0)
    if i2 < len2:
      acc.extend((m2, m2_c * f2) for m2, m2_c in itertools.islice(pairs2, i2, len2) if m2_c != 0)
    return acc

  def _syntactic_cmp(self, other: _DimExpr) -> int:
    """Returns -1 if self < other, 0 if self == other, 1 if self > other.
    The comparison is done lexicographically (syntactically), to be used for sorting.
    The result is not related to the semantic value.
    """
    s_terms = self._sorted_terms
    o_terms = other._sorted_terms
    if c := cmp_comparable(self._size, other._size): return c
    def cmp_factor(s_f: tuple[_DimTerm, int], o_f: tuple[_DimTerm, int]) -> int:
      if c := s_f[0]._syntactic_cmp(o_f[0]): return c
      return cmp_comparable(s_f[1], o_f[1])
    return cmp_sequence(s_terms, o_terms, cmp_factor)

  def _eq(self, other: _DimExpr) -> bool:
    # Equality is used very frequently because expressions are cached. We could
    # implement a more precise version based on `(self - other).bounds() == (0, 0)`
    # but that would be too expensive. It would also have the unfortunate drawback
    # that we could not then cache `e.bounds()`, because hashing invokes equality,
    # which would lead to infinite recursion.
    diff = self - other

    # We look for `self - other == k`, and we rely on the fact that
    # normalization represents integer-valued _DimExpr as Python ints.
    if is_symbolic_dim(diff):
      # Here we really ought to raise InconclusiveDimensionOperation, but __eq__
      # cannot raise exceptions, because it is used indirectly when hashing.
      # So, we say that the expressions are disequal, which is really unsound.
      # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported
      return False

    return diff == 0

  def __hash__(self):
    if self._hash is None:
      self._hash = hash((self._sorted_terms, self.scope))
    return self._hash

  def __str__(self):
    def _one_term(t, t_k):
      abs_t_k = abs(t_k)
      sgn_t_k = "+" if t_k > 0 else "-"
      if t.is_constant:
        return f"{sgn_t_k} {abs_t_k}" if abs_t_k != 0 else "0"
      if abs_t_k == 1:
        return f"{sgn_t_k} {t}"
      return f"{sgn_t_k} {abs_t_k}*{t}"
    # We print first the "larger" terms, so that the constant is last.
    res = " ".join(_one_term(t, t_k)
                   for t, t_k in self._sorted_terms)
    if res.startswith("+ "):
      res = res[2:]
    return res

  def __repr__(self):
    return str(self)

  # A special case for linear combinations because they are common
  @staticmethod
  def _linear_combination(e1: DimSize, k1: int,
                          e2: DimSize, k2: int,
                          scope: SymbolicScope) -> DimSize:
    """Computes and normalizes `e1 * k1 + e2 * k2`."""
    if isinstance(e1, _DimExpr):
      e1_terms = e1._sorted_terms
      if isinstance(e2, _DimExpr):
        e1.scope._check_same_scope(e2, when="for linear combination")
    else:
      if not isinstance(e2, _DimExpr):
        return e1 * k1 + e2 * k2  # Constants
      e1_terms = ((_DimTerm_one, op.index(e1)),)
    if isinstance(e2, _DimExpr):
      e2_terms = e2._sorted_terms
    elif e2 == 0:
      e2_terms = ()
    else:
      e2_terms = ((_DimTerm_one, op.index(e2)),)
    new_terms = _DimExpr._linear_combination_sorted_pairs(e1_terms, 0, k1,
                                                          e2_terms, 0, k2)
    return _DimExpr._normalize_sorted_terms(new_terms, scope)

  # We overload +, -, *, because they are fully defined for _DimExpr.
  def __add__(self, other):
    if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
      return self.__jax_array__().__add__(other)
    if isinstance(other, int) and other == 0: return self
    return _DimExpr._linear_combination(self, 1, other, 1, self.scope)

  def __radd__(self, other):
    if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
      return self.__jax_array__().__radd__(other)
    if isinstance(other, int) and other == 0: return self
    return _DimExpr._linear_combination(self, 1, other, 1, self.scope)

  def __sub__(self, other):
    if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
      return self.__jax_array__().__sub__(other)
    if isinstance(other, int) and other == 0: return self
    return _DimExpr._linear_combination(self, 1, other, -1, self.scope)

  def __rsub__(self, other):
    if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
      return self.__jax_array__().__rsub__(other)
    return _DimExpr._linear_combination(self, -1, other, 1, self.scope)

  def __neg__(self) -> DimSize:
    return _DimExpr._linear_combination(self, -1, 0, 0, self.scope)

  def __mul__(self, other):
    if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
      return self.__jax_array__().__mul__(other)
    if isinstance(other, int):
      if other == 1: return self
      if other == 0: return 0
      return _DimExpr._linear_combination(self, other, 0, 0, self.scope)
    other = _ensure_poly(other, "mul", self.scope)
    coeffs: dict[_DimTerm, int] = {}
    for mon1, coeff1 in self._sorted_terms:
      for mon2, coeff2 in other._sorted_terms:
        mon = mon1.mul(mon2)
        _DimExpr._add_coeff(coeffs, mon, coeff1 * coeff2)
    return _DimExpr._normalize_sorted_terms(_DimExpr._coeff_to_sorted_terms(coeffs),
                                            self.scope)

  def __rmul__(self, other):
    if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
      return self.__jax_array__().__rmul__(other)
    if isinstance(other, int):
      if other == 1: return self
      if other == 0: return 0
      return _DimExpr._linear_combination(self, other, 0, 0, self.scope)
    return _ensure_poly(other, "mul", self.scope).__mul__(self)

  def __pow__(self, power, modulo=None):
    assert modulo is None
    try:
      power = int(power)
    except Exception:
      raise InconclusiveDimensionOperation(
          f"Symbolic dimension cannot be raised to non-integer power '{self}' ^ '{power}'")
    return functools.reduce(op.mul, [self] * power)

  def __floordiv__(self, divisor):
    if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):
      return self.__jax_array__().__floordiv__(divisor)
    return self._divmod(divisor)[0]

  def __rfloordiv__(self, other):
    if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
      return self.__jax_array__().__rfloordiv__(other)
    return _ensure_poly(other, "floordiv", self.scope).__floordiv__(self)

  def __truediv__(self, divisor):
    # Used for "/", which always returns a float
    return self.__jax_array__().__truediv__(divisor)

  def __rtruediv__(self, dividend):
    # Used for "/", when dividend is not a _DimExpr
    return self.__jax_array__().__rtruediv__(dividend)

  def __mod__(self, divisor):
    if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):
      return self.__jax_array__().__mod__(divisor)
    return self._divmod(divisor)[1]

  def __rmod__(self, dividend):
    if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend):
      return self.__jax_array__().__rmod__(dividend)
    return _ensure_poly(dividend, "mod", self.scope).__mod__(self)

  def __divmod__(self, divisor):
    if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):
      return self.__jax_array__().__divmod__(divisor)
    return self._divmod(divisor)

  def __rdivmod__(self, dividend):
    if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend):
      return self.__jax_array__().__rdivmod__(dividend)
    return _ensure_poly(dividend, "divmod", self.scope).__divmod__(self)

  def __int__(self):
    if (c := _DimExpr._to_constant(self)) is not None:
      return c
    raise InconclusiveDimensionOperation(
        f"Symbolic dimension '{self}' used in a context that requires a constant")

  # We must overload __eq__ and __ne__, or else we get unsound defaults.
  def __eq__(self, other: Any) -> bool:
    if isinstance(other, _DimExpr):
      if self.scope is not other.scope:
        return False
    elif not core.is_constant_dim(other):
      return False

    # Equality is used very frequently because expressions are cached. We could
    # implement a more precise version based on `(self - other).bounds() == (0, 0)`
    # but that would be too expensive. It would also have the unfortunate drawback
    # that we could not then cache `e.bounds()`, because hashing invokes equality,
    # which would lead to infinite recursion.
    diff = self - other

    # We look for `self - other == k`, and we rely on the fact that
    # normalization represents integer-valued _DimExpr as Python ints.
    if is_symbolic_dim(diff):
      # Here we really ought to raise InconclusiveDimensionOperation, but __eq__
      # cannot raise exceptions, because it is used indirectly when hashing.
      # So, we say that the expressions are disequal, which is really unsound.
      # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported
      return False

    return diff == 0

  def __ne__(self, other: Any) -> bool:
    return not self.__eq__(other)

  def __ge__(self, other: DimSize) -> bool:
    return _geq_decision(self, other, lambda: f"'{self}' >= '{other}'")

  def __le__(self, other: DimSize):
    return _geq_decision(other, self, lambda: f"'{self}' <= '{other}'")

  def __gt__(self, other: DimSize):
    return not _geq_decision(other, self, lambda: f"'{self}' > '{other}'")

  def __lt__(self, other: DimSize):
    return not _geq_decision(self, other, lambda: f"'{self}' < '{other}'")

  def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]:
    """
    Floor division with remainder (divmod) generalized to expressions.
    If the `divisor` is not a constant, the remainder must be 0.
    If the `divisor` is a constant, the remainder may be non-zero, for
    consistency with integer divmod.

    :return: Quotient resulting from polynomial division, and integer remainder.
    """
    try:
      dividend, quotient = self, 0
      # Invariant: self = dividend + divisor * quotient.
      # quotient and dividend are changed in the loop; the leading term of
      # dividend decreases at each iteration.
      while is_symbolic_dim(dividend) and not dividend._is_constant:  # type: ignore[attribute-error,unused-ignore]
        mon, count = dividend._leading_term
        if isinstance(divisor, _DimExpr):
          dterm, dcount = divisor._leading_term
          qterm = mon.divide(dterm)
        else:
          qterm, dcount = mon, int(divisor)
        qcount, rcount = divmod(count, dcount)
        if rcount != 0:
          raise InconclusiveDimensionOperation("")

        q = _DimExpr._from_term(qterm, qcount, self.scope)
        quotient += q
        dividend -= q * divisor

      dividend = int(dividend)  # type: ignore[assignment]
      if isinstance(divisor, _DimExpr):
        if dividend != 0:
          raise InconclusiveDimensionOperation("")
        remainder = 0
      else:
        q, r = divmod(dividend, int(divisor))
        quotient += q
        remainder = r

      if config.enable_checks.value:
        v1 = divisor * quotient
        v2 = v1 + remainder
        assert self == _ensure_poly(v2, "check", self.scope), (
            self, v2, type(self), type(v2))
        assert self == _ensure_poly(divisor * quotient + remainder, "test", self.scope), (
            self, divisor, quotient, remainder)
      return quotient, remainder
    except InconclusiveDimensionOperation:
      return (_DimExpr._from_operation(_DimFactor.FLOORDIV, self, divisor,
                                       scope=self.scope),  # type: ignore
              _DimExpr._from_operation(_DimFactor.MOD, self, divisor,
                                       scope=self.scope))

  def _evaluate(self, env: DimVarEnv):
    # Evaluates as a value of dtype=core.dim_value_dtype()
    terms = [_evaluate_multiply(t.evaluate(env, self.scope), core.dim_constant(t_k))
             for t, t_k in self._sorted_terms]
    return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0]

  def max(self, other: DimSize) -> DimSize:
    lb, ub = _bounds_decision(self - other, BoundsPrecision.FOR_GEQ0_OR_LEQ0)
    if 0 <= lb: return self
    if ub <= 0: return other
    return _DimExpr._from_operation(_DimFactor.MAX, self, other, scope=self.scope)

  def rmax(self, other: DimSize) -> DimSize:
    lb, ub = _bounds_decision(self - other, BoundsPrecision.FOR_GEQ0_OR_LEQ0)
    if 0 <= lb: return self
    if ub <= 0: return other
    return _DimExpr._from_operation(_DimFactor.MAX, other, self, scope=self.scope)

  def min(self, other: DimSize) -> DimSize:
    lb, ub = _bounds_decision(self - other, BoundsPrecision.FOR_GEQ0_OR_LEQ0)
    if 0 <= lb: return other
    if ub <= 0: return self
    return _DimExpr._from_operation(_DimFactor.MIN, self, other, scope=self.scope)

  def rmin(self, other: DimSize) -> DimSize:
    lb, ub = _bounds_decision(self - other, BoundsPrecision.FOR_GEQ0_OR_LEQ0)
    if 0 <= lb: return other
    if ub <= 0: return self
    return _DimExpr._from_operation(_DimFactor.MIN, other, self, scope=self.scope)

  @staticmethod
  def _get_aval(dim: _DimExpr):
    return core.dim_value_aval()

  def dimension_as_value(self):
    """Turns a dimension size into a JAX value that we can compute with."""
    return _dim_as_value(self)

  def __jax_array__(self):
    # Used for implicit coercions of polynomials as JAX arrays
    return _dim_as_value(self)

  def __deepcopy__(self, memo):
    return _DimExpr(
        copy.deepcopy(self._sorted_terms, memo),
        copy.deepcopy(self._scope, memo))
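
# For illustration (a minimal sketch using the public jax.export API, which
# constructs these internal objects):
#
#   from jax import export
#   a, = export.symbolic_shape("a")
#   (3 * a * a) // a == 3 * a    # True: exact polynomial division
#   (3 * a * a) % a              # 0
#   q, r = divmod(a, 2)          # inexact: stays symbolic as
#                                # floordiv(a, 2) and mod(a, 2)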


def cmp_comparable(i1, i2) -> int:
  if i1 < i2: return -1
  if i1 > i2: return 1
  return 0


def cmp_sequence(s1, s2, elem_cmp) -> int:
  """Compares two sequences elementwise using `elem_cmp`."""
  l2 = len(s2)
  for i, e1 in enumerate(s1):
    if i >= l2: return 1
    if c := elem_cmp(e1, s2[i]): return c
  if len(s1) < l2: return -1
  return 0


class SymbolicScope:
  """Identifies a scope for symbolic expressions.

  All symbolic expressions that interact (e.g., appear in the argument shapes
  for one JAX function invocation, or are involved in arithmetic operations)
  must be from the same scope and must share the same SymbolicScope object.

  Holds the constraints on symbolic expressions.

  See [the README](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints)
  for more details.

  Args:
    constraints_str: A sequence of constraints on symbolic dimension expressions,
      of the form `e1 >= e2` or `e1 <= e2` or `e1 == e2`.
  """

  def __init__(self,
               constraints_str: Sequence[str] = ()):
    if isinstance(constraints_str, str):
      raise ValueError(
          "The symbolic constraints should be a sequence of strings. "
          f"Got {repr(constraints_str)}")
    self._initialized = False
    self._location_frame = source_info_util.user_frame(source_info_util.current())
    # Keep the explicit constraints in the order in which they were added.
    self._explicit_constraints: list[_SymbolicConstraint] = []

    # We cache the _DimExpr.bounds calls. The result depends only on the
    # explicit and implicit constraints, so it is safe to keep it in the
    # scope. Set the cache before we parse constraints. We also keep the
    # bounds precision with which we computed the cached result.
    self._bounds_cache: dict[_DimExpr,
                             tuple[float, float, BoundsPrecision]] = {}

    # We store here a decision procedure state initialized with all the
    # _explicit_constraints.
    self._decision_initial_state: Any | None = None

    # We turn the equality constraints into normalization rules.
    # For an explicit constraint `t*tk == e`, we keep
    # `_normalization_rules[t] = (e, tk)`.
    # During building of expressions, if we encounter the term
    # `t*tk1` and `tk1 % tk == 0`, we replace it with `e*(tk1 // tk)`.
    self._normalization_rules: NormalizationRules = {}

    for c_str in constraints_str:
      self._parse_and_process_explicit_constraint(c_str)
      self._bounds_cache.clear()
    self._initialized = True

  def __str__(self) -> str:
    extras = []
    if self._explicit_constraints:
      extras.append(" with constraints:")
      for constr in self._explicit_constraints:
        extras.append(f"  {constr.debug_str}")
    loc = source_info_util._summarize_frame(self._location_frame) if self._location_frame else "unknown"
    return f"{id(self)} created at {loc}" + "\n".join(extras)
  __repr__ = __str__

  def _parse_and_process_explicit_constraint(self, c_str: str):
    if not isinstance(c_str, str):
      raise ValueError(
          f"SymbolicScope constraint must be a string: got {repr(c_str)}")
    cmp_pos, cmp, is_geq = c_str.find("=="), Comparator.EQ, True
    if cmp_pos < 0:
      cmp_pos, cmp, is_geq = c_str.find(">="), Comparator.GEQ, True
      if cmp_pos < 0:
        cmp_pos, cmp, is_geq = c_str.find("<="), Comparator.GEQ, False
      if cmp_pos < 0:
        raise ValueError("Constraint parsing error: must contain one of '==' or '>=' or '<='")
    e1_str = c_str[:cmp_pos]
    e1, = _Parser(e1_str, None, repr(e1_str), self).parse()  # type: ignore[name-error,unused-ignore]
    e2_str = c_str[cmp_pos + 2:]
    e2, = _Parser(e2_str, None, repr(e2_str), self).parse()  # type: ignore[name-error,unused-ignore]
    if cmp == Comparator.GEQ and not is_geq:
      e1, e2 = e2, e1

    diff = e1 - e2
    if (diff_const := _DimExpr._to_constant(diff)) is not None:
      if ((cmp == Comparator.EQ and diff_const != 0) or
          (cmp == Comparator.GEQ and diff_const < 0)):
        raise ValueError(f"Unsatisfiable explicit constraint: {c_str}")
      return

    if cmp == Comparator.EQ:
      if not isinstance(e1, _DimExpr):
        raise ValueError(f"Invalid equality constraint: {e1} == {e2}. "
                         "The left-hand-side must be of the form `term * coefficient`.")
      (before, before_k), *rest = e1._sorted_terms
      if rest:
        raise ValueError(f"Invalid equality constraint: {e1} == {e2}. "
                         "The left-hand-side must be of the form `term * coefficient`.")

      after = _ensure_poly(e2, "parse_constraint", e1.scope)  # type: ignore[name-error,unused-ignore]
      if before in self._normalization_rules:
        raise NotImplementedError(
            f"Found multiple equality constraints with the same left-hand-side: {before}")
      self._normalization_rules[before] = (after, before_k)

    constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2)
    self._explicit_constraints.append(constr)

  def _check_same_scope(self, other: _DimExpr,
                        when: str = "",
                        self_descr: str = " ",
                        other_descr: str = "unknown"):
    if self is not other.scope:
      raise ValueError(
          f"Invalid mixing of symbolic scopes {when}.\n"
          f"Expected {self_descr}scope {self}\n"
          f"and found for '{other}' ({other_descr}) scope {other.scope}\n"
          "See https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints.")

  def _clear_caches(self):
    self._bounds_cache.clear()
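
# A minimal usage sketch (public API, jax.export; the dimension name is
# hypothetical):
#
#   from jax import export
#   scope = export.SymbolicScope(["m >= 4"])
#   m, = export.symbolic_shape("m", scope=scope)
#   m >= 2   # -> True, decidable thanks to the explicit constraint m >= 4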


class BoundsPrecision(enum.Enum):
  """Specifies the desired precision for the bounds calculation.

  Since the bounds calculations are expensive, we allow the caller to specify
  a sufficient condition for a result. As the bounds calculation progresses,
  the lower bound is progressively increased and the upper bound is
  progressively decreased. Depending on the precision, we may stop the
  computation early, if the results are sufficient for the use case.

  The enumeration values are chosen such that, if "(lb, ub)" are sufficient
  for a precision value then they are also sufficient for any smaller
  precision.
  """

  # For static evaluation of "max(e1, e2)", we can stop if "lb >= 0 OR ub <= 0"
  FOR_GEQ0_OR_LEQ0 = 0

  # For deciding inequalities, such as "e1 >= e2", we can stop if "lb >= 0 OR ub < 0"
  FOR_GEQ0_OR_LT0 = 1

  # Do not stop early; refine the bounds as much as possible.
  BEST = 2

  def _bounds_are_sufficient(self, lb: float, ub: float) -> bool:
    if self == BoundsPrecision.FOR_GEQ0_OR_LEQ0:
      return lb >= 0 or ub <= 0
    if self == BoundsPrecision.FOR_GEQ0_OR_LT0:
      return lb >= 0 or ub < 0
    return False


#
# Calling convention:
#  _bounds_decision(e, prec)
#  returns a tuple with the lower and upper bound of e. The BoundsPrecision
#  `prec` tells the iterative bounds computation when the current bounds are
#  tight enough to stop early.
# TODO: remove this trampoline when we refactor the sources
def _bounds_decision_unimplemented(
    d: DimSize,
    prec: BoundsPrecision) -> tuple[float, float]:
  del d, prec
  raise NotImplementedError("_bounds_decision is uninitialized")

_bounds_decision: Callable[[DimSize, BoundsPrecision],
                           tuple[float, float]] = _bounds_decision_unimplemented


def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
  """Implements `e1 >= e2`.

  Args:
    e1, e2: the expressions to compare for greater-equal
    cmp_str: a callable such that `cmp_str()` describes the comparison
      for error messages, e.g., "a <= b". Without this all comparisons would
      be reported as ">=".

  Raises InconclusiveDimensionOperation if the result is not conclusive.
  """
  if isinstance(e1, _DimExpr):
    scope = e1.scope
    if isinstance(e2, _DimExpr):
      scope._check_same_scope(e2, f"when comparing {cmp_str()}")
  elif isinstance(e2, _DimExpr):
    scope = e2.scope
  else:
    return int(e1) >= int(e2)
  lb, ub = _bounds_decision(e1 - e2, BoundsPrecision.FOR_GEQ0_OR_LT0)
  if lb >= 0:
    return True
  if ub < 0:
    return False

  if scope._explicit_constraints:
    describe_scope = f"\nUsing symbolic scope {scope}"
  else:
    describe_scope = ""
  raise InconclusiveDimensionOperation(
      f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
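
# For illustration: without constraints, comparisons that are not decidable
# from the bounds of the difference raise InconclusiveDimensionOperation:
#
#   a, b = jax.export.symbolic_shape("a, b")
#   a + 1 >= a   # -> True (the difference has lower bound 1)
#   # a >= b     # raises InconclusiveDimensionOperation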

core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
xla.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
dtypes._weak_types.append(_DimExpr)


def _convertible_to_int(p: DimSize) -> bool:
  try:
    op.index(p)  # type: ignore
    return True
  except Exception:
    return False


def _ensure_poly(p: DimSize,
                 operation_name: str,
                 scope: SymbolicScope) -> _DimExpr:
  if isinstance(p, _DimExpr):
    scope._check_same_scope(p, when=f"for operation {operation_name}")
    return p
  if _convertible_to_int(p):
    return _DimExpr(((_DimTerm_one, op.index(p)),), scope)
  raise TypeError(f"Symbolic dimension {operation_name} not supported for {p}.")


def _convertible_to_poly(p: DimSize) -> bool:
  return isinstance(p, _DimExpr) or _convertible_to_int(p)


def is_symbolic_dim(p: DimSize) -> bool:
  """Checks if a dimension is symbolic."""
  return isinstance(p, _DimExpr)


dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int]


def _einsum_contract_path(*operands, **kwargs):
  """Like opt_einsum.contract_path, with support for DimExpr shapes.

  We use opt_einsum.contract_path to compute the schedule, using a fixed
  constant for all dimension variables. This is safe because we throw an
  error if there is more than one contraction. Essentially, we just use
  opt_einsum.contract_path to parse the specification.
  """

  # Replace the polymorphic shapes with some concrete shapes for calling
  # into opt_einsum.contract_path, because the latter wants to compute the
  # sizes of operands and intermediate results.
  fake_ops = []
  for operand in operands:
    # We replace only array operands
    if not hasattr(operand, "dtype"):
      fake_ops.append(operand)
    else:
      shape = np.shape(operand)
      def fake_dim(d):
        if core.is_constant_dim(d):
          return d
        else:
          if not isinstance(d, _DimExpr):
            raise TypeError(f"Encountered unexpected shape dimension {d}")
          # It is Ok to replace all polynomials with the same value. We may miss
          # here some errors due to non-equal dimensions, but we catch them
          # later.
          return 8
      fake_ops.append(jax.ShapeDtypeStruct(tuple(map(fake_dim, shape)),
                                           operand.dtype))

  contract_fake_ops, contractions = opt_einsum.contract_path(*fake_ops,
                                                             **kwargs)
  contract_operands = []
  for operand in contract_fake_ops:
    idx = tuple(i for i, fake_op in enumerate(fake_ops) if operand is fake_op)
    assert len(idx) == 1
    contract_operands.append(operands[idx[0]])
  return contract_operands, contractions

lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path


# To implement shape-constraint checking we use a shape assertion primitive.
#  shape_assertion_p.bind(assert_what: bool, *error_message_inputs,
#                         error_message="...{0}...{1}")
# where "{0}" refers to error_message_inputs[0], etc.
shape_assertion_p = core.Primitive("shape_assertion")
shape_assertion_p.multiple_results = True
shape_assertion_p.def_effectful_abstract_eval(
    lambda *_, **__: ((), {shape_assertion_effect}))  # type: ignore

def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext,
                                   assert_what: mlir.ir.Value,
                                   *error_message_inputs: mlir.ir.Value,
                                   error_message: str):
  op = mlir.custom_call(
      "shape_assertion",
      result_types=[],  # No results
      operands=[assert_what, *error_message_inputs],
      has_side_effect=True,
      extra_attributes=dict(error_message=mlir.ir.StringAttr.get(error_message))
  )
  return op.results

mlir.register_lowering(shape_assertion_p, _shape_assertion_lowering_rule)


class ShapeAssertionEffect(effects.Effect):
  __str__ = lambda _: "ShapeAssertionEffect"

shape_assertion_effect = ShapeAssertionEffect()

effects.lowerable_effects.add_type(ShapeAssertionEffect)
effects.control_flow_allowed_effects.add_type(ShapeAssertionEffect)
effects.remat_allowed_effects.add_type(ShapeAssertionEffect)
effects.custom_derivatives_allowed_effects.add_type(ShapeAssertionEffect)


def shape_assertion(assert_what: jax.Array,
                    *error_message_inputs: jax.Array,
                    error_message: str) -> None:
  """Adds a shape assertion to the code.

  Args:
    assert_what: a boolean asserted to be true. Must be computed based only
      on dimension expressions, so that it can be evaluated after shape
      refinement.
    error_message_inputs: integer expressions whose values can be referenced
      in the `error_message`. Must be computed based only
      on dimension expressions, so that they can be evaluated after shape
      refinement.
    error_message: an error message, possibly containing format specifiers
      {0}, {1}, ..., referencing the values of the `error_message_inputs`.
      The format specifiers are sometimes processed with Python's
      `str.format` method, and sometimes with `llvm::formatv`.
  """
  shape_assertion_p.bind(assert_what, *error_message_inputs,
                         error_message=error_message)
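
# A minimal sketch of the intended usage (hypothetical: normally the export
# machinery binds these assertions during lowering):
#
#   dim_val = _dim_as_value(dim)   # `dim` a _DimExpr; `_dim_as_value` is below
#   shape_assertion(dim_val % 2 == 0, dim_val,
#                   error_message="Expected an even dimension, got {0}")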


# A JAX primitive with no array arguments but with a dimension parameter
# that is a DimExpr. The value of the primitive is the value of the dimension,
# using int64 in x64 mode or int32 otherwise (core.dim_value_dtype()).
dim_as_value_p = core.Primitive("dim_as_value")
dim_as_value_p.def_abstract_eval(lambda dim: core.dim_value_aval())

def dim_as_value_impl(dim: DimSize):
  raise NotImplementedError(
      "Evaluation rule for 'dim_as_value' is not implemented. "
      "It seems that you are using shape polymorphism outside jax.export.")

dim_as_value_p.def_impl(dim_as_value_impl)

def _dim_as_value(dim: DimSize):
  return dim_as_value_p.bind(dim=dim)

def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *,
                           dim):
  res, = mlir.eval_dynamic_shape(ctx, (dim,))
  out_type = mlir.aval_to_ir_type(ctx.avals_out[0])
  if out_type != res.type:  # type: ignore
    return [mlir.hlo.convert(out_type, res)]
  else:
    return [res]

mlir.register_lowering(dim_as_value_p, _dim_as_value_lowering)
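
# For context, a minimal sketch: this primitive is what `_DimExpr.__jax_array__`
# binds, so a symbolic dimension used in array arithmetic is implicitly
# converted into an int32/int64 value via `dim_as_value_p`:
#
#   a, = jax.export.symbolic_shape("a")
#   def f(x):                  # x: f32[a]
#     return x * x.shape[0]    # x.shape[0] is `a`; coerced via __jax_array__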


class PolyShape(tuple):
  """Tuple of polymorphic dimension specifications.

  See docstring of :func:`jax2tf.convert`.
  """

  def __init__(self, *dim_specs):
    warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes",
                  DeprecationWarning, stacklevel=2)
    tuple.__init__(dim_specs)

  def __new__(cls, *dim_specs):
    warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes",
                  DeprecationWarning, stacklevel=2)
    for ds in dim_specs:
      if not isinstance(ds, (int, str)) and ds != ...:
        msg = (f"Invalid polymorphic shape element: {ds!r}; must be a string "
               "representing a dimension variable, or an integer, or ...")
        raise ValueError(msg)
    return tuple.__new__(PolyShape, dim_specs)

  def __str__(self):
    return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")"


def symbolic_shape(shape_spec: str | None,
                   *,
                   constraints: Sequence[str] = (),
                   scope: SymbolicScope | None = None,
                   like: Sequence[int | None] | None = None
                   ) -> Sequence[DimSize]:
  """Constructs a symbolic shape from a string representation.

  See https://jax.readthedocs.io/en/latest/export/shape_poly.html for examples.

  Args:
    shape_spec: a symbolic shape specification. None stands for "...".
      A shape specification is the string representation of a tuple (the
      parentheses are optional) with comma-separated dimension expressions.
      A dimension expression can be either: an integer constant,
      a dimension variable (alphanumeric, starting with a letter),
      e1 + e2, e1 - e2, e1 * e2, floordiv(e1, e2),
      mod(e1, e2), max(e1, e2), or min(e1, e2).
    constraints: a sequence of constraints on symbolic dimension expressions, of
      the form `e1 >= e2` or `e1 <= e2`, or `e1 == e2`.
      See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints)
      for usage.
    scope: optionally, you can specify that the parsed symbolic expressions
      be created in the given scope. If this is missing, then a new
      `SymbolicScope` is created with the given `constraints`.
      You cannot specify both a `scope` and `constraints`.
      See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints)
      for usage.
    like: when `shape_spec` contains placeholders ("_", "..."), use this
      shape to fill in the placeholders.
      The dimensions of `like` that are used for filling
      must not be `None`. If a dimension in `like` is not `None` and
      the corresponding dimension in `shape_spec` is a constant then they
      must be equal.

  Returns: a tuple with integers or symbolic expressions involving dimension variables.
  """
  shape_spec_repr = repr(shape_spec)
  if shape_spec is None:
    shape_spec = "..."
  elif isinstance(shape_spec, PolyShape):  # TODO: deprecate
    shape_spec = str(shape_spec)
  elif not isinstance(shape_spec, str):
    raise ValueError("polymorphic shape spec should be None or a string. "
                     f"Found {shape_spec_repr}.")
  if scope is None:
    scope = SymbolicScope(constraints)
  elif constraints:
    raise ValueError("Cannot specify both a `scope` and `constraints`.")
  dimensions = _Parser(shape_spec, like, shape_spec_repr, scope).parse()
  return dimensions
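
# A minimal usage sketch (the dimension names are hypothetical):
#
#   from jax import export
#   a, b = export.symbolic_shape("(a, b)", constraints=["a >= 2"])
#   spec = jax.ShapeDtypeStruct((a, 2 * b), np.float32)
#   export.symbolic_shape("_, b", like=(8, None))   # -> (8, b)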

def symbolic_args_specs(
    args,  # pytree of arguments
    shapes_specs,  # prefix pytree of strings
    constraints: Sequence[str] = (),
    scope: SymbolicScope | None = None,
):
  """Constructs a pytree of jax.ShapeDtypeStruct argument specs for `export`.

  See the documentation of :func:`jax.export.symbolic_shape` and
  the [shape polymorphism documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html) for details.

  Args:
    args: a pytree of arguments. These can be jax.Array or jax.ShapeDtypeStruct.
      They are used to learn the pytree structure of the arguments, their dtypes,
      and to fill in the actual shapes where the `shapes_specs` contains
      placeholders. Note that only the shape dimensions for which
      `shapes_specs` is a placeholder are used from `args`.
    shapes_specs: should be `None` (all arguments have static shapes),
      a single string (see `shape_spec` for :func:`jax.export.symbolic_shape`;
      applies to all arguments), or a pytree matching a prefix
      of the `args`.
      See [how optional parameters are matched to
      arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
    constraints: as for :func:`jax.export.symbolic_shape`.
    scope: as for :func:`jax.export.symbolic_shape`.

  Returns: a pytree of jax.ShapeDtypeStruct matching the `args` with the shapes
    replaced with symbolic dimensions as specified by `shapes_specs`.
  """
  polymorphic_shapes = shapes_specs
  args_flat, args_tree = tree_util.tree_flatten(args)

  shapes_and_dtypes = tuple(map(shape_and_dtype_jax_array, args_flat))
  shapes, dtypes = util.unzip2(shapes_and_dtypes)

  if isinstance(args, tuple) and isinstance(polymorphic_shapes, list):
    # TODO: Remove backward-compatibility workaround
    polymorphic_shapes_ = tuple(polymorphic_shapes)
  else:
    polymorphic_shapes_ = polymorphic_shapes

  try:
    polymorphic_shapes_flat = tree_util.broadcast_prefix(
        polymorphic_shapes_, args,
        is_leaf=lambda x: x is None)
  except ValueError:
    e, *_ = tree_util.prefix_errors(
        polymorphic_shapes_, args,
        is_leaf=lambda x: x is None)
    raise e("export.symbolic_args_specs shapes_specs") from None

  # Now add in the polymorphic shapes
  if scope is None:
    scope = SymbolicScope(constraints)
  elif constraints:
    raise ValueError("Cannot use both `scope` and `constraints`")
  args_specs_flat = (
      jax.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=scope), t)
      for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat))

  return args_tree.unflatten(args_specs_flat)
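
# A minimal usage sketch (hypothetical arguments):
#
#   args = (np.ones((8, 3), dtype=np.float32),)
#   specs = symbolic_args_specs(args, "b, _")
#   # specs[0].shape == (b, 3), where `b` is a fresh dimension variable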


def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
  """Returns the shape and dtype of a jax.Array or a jax.ShapeDtypeStruct."""
  if isinstance(a, jax.ShapeDtypeStruct):
    return a.shape, a.dtype
  aval = core.raise_to_shaped(core.get_aval(a))
  return aval.shape, aval.dtype


class _Parser:
  def __init__(self,
               shape_spec: str,
               like_shape: Sequence[int | None] | None,
               shape_spec_repr: str,
               scope: SymbolicScope):
    self.shape_spec = shape_spec
    self.shape_spec_repr = shape_spec_repr  # For error messages
    self.like_shape = like_shape
    self.dimensions: list[DimSize] = []  # dimensions we have parsed
    self.scope = scope

  def parse(self) -> Sequence[DimSize]:
    self.tokstream = tokenize.tokenize(
        io.BytesIO(self.shape_spec.encode("utf-8")).readline)
    tok = self.consume_token(self.next_tok(), tokenize.ENCODING)  # Always 1st
    sh, tok = self.shape(tok)
    self.expect_token(tok, [tokenize.ENDMARKER])
    return sh

  def add_dim(self, expr: DimSize | None, tok: tokenize.TokenInfo):
    if expr is None:
      raise self.parse_err(tok,
                           ("unexpected placeholder for unknown dimension; "
                            f"like={self.like_shape}"))

    if core.is_constant_dim(expr) and self.like_shape is not None:
      like_shape_dim = self.like_shape[len(self.dimensions)]
      if expr != like_shape_dim:
        raise self.parse_err(tok,
                             (f"different size {expr} for known dimension; "
                              f"like={self.like_shape}"))
    self.dimensions.append(expr)

  def parse_err(self, tok: tokenize.TokenInfo | None, detail: str) -> Exception:
    msg = (
        f"syntax error in symbolic shape {self.shape_spec_repr} "
        f"in dimension {len(self.dimensions)}: {detail}. ")
    if tok is not None:
      msg += f"Parsed '{tok.line[:tok.start[1]]}', remaining '{tok.line[tok.start[1]:]}'."
    return ValueError(msg)

  def next_tok(self) -> tokenize.TokenInfo:
    while True:
      try:
        t = next(self.tokstream)  # type: ignore[attribute-error,unused-ignore]
      except StopIteration:
        raise self.parse_err(None, "unexpected end of string")
      if t.exact_type not in [tokenize.NEWLINE, tokenize.INDENT, tokenize.DEDENT]:
        return t

  def expect_token(self, tok: tokenize.TokenInfo, expected: Sequence[int]) -> None:
    if tok.exact_type not in expected:
      msg = ("expecting one of {" +
             ", ".join(tokenize.tok_name[t] for t in expected) + "} but found " +
             tokenize.tok_name[tok.exact_type])
      raise self.parse_err(tok, msg)

  def consume_token(self, tok: tokenize.TokenInfo, expected: int) -> tokenize.TokenInfo:
    self.expect_token(tok, [expected])
    return self.next_tok()

  def integer(self, tok: tokenize.TokenInfo) -> tuple[int, tokenize.TokenInfo]:
    self.expect_token(tok, [tokenize.NUMBER])
    try:
      val = int(tok.string)
    except Exception:
      raise self.parse_err(tok, f"expecting integer, found {tok.string}")
    return val, self.next_tok()

  # What can follow a shape?
  FOLLOW_SHAPE = [tokenize.ENDMARKER, tokenize.RPAR]
  def shape(self, tok: tokenize.TokenInfo) -> tuple[Sequence[DimSize], tokenize.TokenInfo]:
    # A comma-separated list of _DimExpr, or "_", possibly ended with ...
    if tok.exact_type == tokenize.LPAR:
      res, tok = self.shape(self.next_tok())
      tok = self.consume_token(tok, tokenize.RPAR)
      return res, tok

    while True:
      if tok.exact_type in self.FOLLOW_SHAPE:
        break
      # Error checking in presence of placeholders
      if (tok.exact_type == tokenize.ELLIPSIS or
          tok.exact_type == tokenize.NAME and tok.string == "_"):
        if self.like_shape is None:
          raise self.parse_err(tok,
                               "spec contains ... but no 'like' shape was given")
        if tok.exact_type == tokenize.ELLIPSIS:
          min_len_like_shape = len(self.dimensions)
        else:
          min_len_like_shape = len(self.dimensions) + 1
        if len(self.like_shape) < min_len_like_shape:
          raise self.parse_err(
              tok,
              f"cannot resolve placeholder '{tok.string}' because we parsed "
              f"{len(self.dimensions)} already and 'like' shape has "
              f"only {len(self.like_shape)} dimensions")
        if tok.exact_type == tokenize.ELLIPSIS:
          to_add = self.like_shape[len(self.dimensions):]  # type: ignore[index]
          for ad in to_add:
            self.add_dim(ad, tok)
          tok = self.next_tok()
          break

      if tok.exact_type == tokenize.NAME and tok.string == "_":
        e = self.like_shape[len(self.dimensions)]  # type: ignore[index]
        tok = self.next_tok()
      else:
        e, tok = self.expr(tok)
      self.add_dim(e, tok)
      if tok.exact_type in self.FOLLOW_SHAPE:
        break
      tok = self.consume_token(tok, tokenize.COMMA)

    return tuple(self.dimensions), tok

  # What token can follow a _DimExpr
  FOLLOW_EXPR = FOLLOW_SHAPE + [tokenize.COMMA]

  def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
    # A sum of terms
    next_t_negated = (tok.exact_type == tokenize.MINUS)
    if next_t_negated:
      tok = self.next_tok()
    elif tok.exact_type == tokenize.PLUS:
      tok = self.next_tok()
    acc = None
    while True:
      t, tok = self.term(tok)
      t_sign = - t if next_t_negated else t
      acc = acc + t_sign if acc is not None else t_sign  # type: ignore[operator]
      if tok.exact_type in self.FOLLOW_EXPR:
        return acc, tok
      next_t_negated = (tok.exact_type == tokenize.MINUS)
      self.expect_token(tok, [tokenize.PLUS, tokenize.MINUS])
      tok = self.next_tok()

  FOLLOW_TERM = FOLLOW_EXPR + [tokenize.PLUS, tokenize.MINUS]
  def term(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
    # A term is a product of factors. Each factor may be raised to an integer
    # power.
    acc = None
    while True:
      f, tok = self.factor(tok)
      if tok.exact_type == tokenize.CIRCUMFLEX:
        tok = self.next_tok()
        self.expect_token(tok, [tokenize.NUMBER])
        power, tok = self.integer(tok)
        f = f ** power

      acc = acc * f if acc is not None else f  # type: ignore[operator]
      if tok.exact_type in self.FOLLOW_TERM:
        return acc, tok  # type: ignore[bad-return-type,unused-ignore]
      tok = self.consume_token(tok, tokenize.STAR)

  def factor(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
    if tok.exact_type == tokenize.NAME:
      if tok.string in (_DimFactor.MOD, _DimFactor.FLOORDIV, _DimFactor.MAX, _DimFactor.MIN):
        return self.factor_binary_op(tok.string, self.next_tok())
      if tok.string == _DimFactor.NON_NEGATIVE:  # We still parse this for backwards compatibility
        return self.factor_unary_op(_DimFactor.NON_NEGATIVE, self.next_tok())
      return _DimExpr._from_var(tok.string, self.scope), self.next_tok()
    number_sign = 1
    if tok.exact_type == tokenize.MINUS:  # -k are negative constants
      number_sign = -1
      tok = self.next_tok()
      self.expect_token(tok, [tokenize.NUMBER])
    if tok.exact_type == tokenize.NUMBER:
      v, tok = self.integer(tok)
      return v * number_sign, tok
    self.expect_token(tok, [tokenize.NAME, tokenize.MINUS, tokenize.NUMBER])
    assert False

  def factor_unary_op(self, op: str, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]:
    tok = self.consume_token(tok, tokenize.LPAR)
    e1, tok = self.expr(tok)
    tok = self.consume_token(tok, tokenize.RPAR)
    return _DimExpr._from_operation(op, e1,
                                    scope=self.scope), tok

  def factor_binary_op(self, op: str, tok) -> tuple[DimSize, tokenize.TokenInfo]:
    tok = self.consume_token(tok, tokenize.LPAR)
    e1, tok = self.expr(tok)
    tok = self.consume_token(tok, tokenize.COMMA)
    e2, tok = self.expr(tok)
    tok = self.consume_token(tok, tokenize.RPAR)
    if op == _DimFactor.MAX:
      return core.max_dim(e1, e2), tok
    if op == _DimFactor.MIN:
      return core.min_dim(e1, e2), tok
    return _DimExpr._from_operation(op, e1, e2,
                                    scope=self.scope), tok
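

# Informal sketch of the grammar accepted by _Parser, derived from the methods
# above (for orientation only; the code is the source of truth):
#
#   shape  ::= "(" shape ")" | dims
#   dims   ::= dim ("," dim)*             # "..." consumes the rest of `like`
#   dim    ::= "..." | "_" | expr         # placeholders resolved via `like`
#   expr   ::= ["+" | "-"] term (("+" | "-") term)*
#   term   ::= factor ["^" NUMBER] ("*" factor ["^" NUMBER])*
#   factor ::= NAME | ["-"] NUMBER
#            | ("mod" | "floordiv" | "max" | "min") "(" expr "," expr ")"
#            | "non_negative" "(" expr ")"  # parsed for backwards compatibility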


def _evaluate_add(v1, v2):
  # Addition that simplifies `0 + v` and `v + 0` when one operand is a
  # compile-time constant 0. `op.index` raises when the operand is not an
  # integer (e.g., when it is a traced jax.Array), in which case we fall
  # through to building the addition.
  try:
    if op.index(v1) == 0:
      return v2
  except Exception:
    pass
  try:
    if op.index(v2) == 0:
      return v1
  except Exception:
    pass
  return v1 + v2


def _evaluate_multiply(v1, v2):
  # Multiplication that simplifies `1 * v` and `v * 1` when one operand is a
  # compile-time constant 1.
  try:
    if op.index(v1) == 1:
      return v2
  except Exception:
    pass
  try:
    if op.index(v2) == 1:
      return v1
  except Exception:
    pass
  return v1 * v2
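

# Sketch: _evaluate_add(0, v) returns `v` without building an addition node,
# while _evaluate_add(d0, v) for a traced `d0` falls through to `d0 + v`.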


# dimension_size(operand, dimension=i) returns operand.shape[i] as a
# value of type core.dim_value_dtype().
dimension_size_p = core.Primitive("dimension_size")

def _dimension_size_abstract_eval(aval: core.AbstractValue, **_) -> core.AbstractValue:
  return core.dim_value_aval()

dimension_size_p.def_abstract_eval(_dimension_size_abstract_eval)

def _dimension_size_impl(arg, *, dimension):
  return core.dim_constant(arg.shape[dimension])
dimension_size_p.def_impl(_dimension_size_impl)

def _dimension_size_lowering_rule(ctx, arg, *, dimension):
  dim_size = mlir.hlo.get_dimension_size(arg, dimension)
  dim_type = mlir.aval_to_ir_type(core.dim_value_aval())
  if dim_size.type != dim_type:
    dim_size = mlir.hlo.convert(dim_type, dim_size)
  return [dim_size]

mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule)
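

# Usage sketch (inside a traced computation; `x` is a hypothetical operand):
#   d0 = dimension_size_p.bind(x, dimension=0)
# binds the primitive so that `d0` is x.shape[0] as a value of
# core.dim_value_dtype(), lowered via hlo.get_dimension_size.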


def all_dim_vars(args_avals: Sequence[core.ShapedArray]) -> Sequence[str]:
  dim_vars: set[str] = set()
  for a in args_avals:
    for d in a.shape:
      if is_symbolic_dim(d):
        dim_vars = dim_vars.union(d._get_vars())
  return sorted(dim_vars)


class ShapeEvaluator:
  def __init__(self, env: DimVarEnv):
    self.env = env

  def evaluate(self, e: DimSize):
    if core.is_constant_dim(e):
      res = op.index(e)  # type: ignore
    else:
      res = e._evaluate(self.env)  # type: ignore
    return res
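

# Sketch: given env = {"a": core.dim_constant(3)} and a symbolic expression
# `e` over the variable "a", ShapeEvaluator(env).evaluate(e) returns the
# dim-valued result of substituting the environment (e.g., 4 for a + 1),
# while constant dimensions evaluate to Python ints.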


@dataclasses.dataclass(frozen=True)
class ShapeConstraint:

  comp: Comparator
  left: DimSize
  right: DimSize
  # `error_message_pieces` is a list of strings and DimSize. The error message
  # is formed by evaluating the DimSize and concatenating the sequence.
  error_message_pieces: Sequence[str | DimSize]

  def check_statically(self, eval: ShapeEvaluator) -> None:
    """Evaluates a constraint statically."""
    left, right = eval.evaluate(self.left), eval.evaluate(self.right)
    try:
      if self.comp == Comparator.EQ:
        ok = (left == right)
      elif self.comp == Comparator.GEQ:
        ok = (left >= right)
      else:
        assert False  # We are in a context where we know we can evaluate
                      # all symbolic expressions to constants.
    except InconclusiveDimensionOperation as e:
      raise self.make_error(eval) from e
    if not ok:
      raise self.make_error(eval)

  def compute(self, eval: ShapeEvaluator) -> jax.Array | None:
    """Computes whether the constraint is satisfied.

    If the constraint can be resolved statically, returns None when it is
    satisfied and raises ValueError when it is not. If the constraint cannot
    be resolved statically, returns a value representing whether the
    constraint is satisfied.
    """
    left, right = eval.evaluate(self.left), eval.evaluate(self.right)
    # Try to evaluate the constraint statically.
    if core.is_constant_shape((left, right)):
      left_int, right_int = op.index(left), op.index(right)
      if self.comp == Comparator.EQ:
        if not (left_int == right_int):
          raise self.make_error(eval)
      elif self.comp == Comparator.GEQ:
        if not (left_int >= right_int):
          raise self.make_error(eval)
      else: assert False
      return None

    if self.comp == Comparator.EQ:
      is_ok = lax.eq(left, right)
    elif self.comp == Comparator.GEQ:
      is_ok = lax.ge(left, right)
    else: assert False
    return is_ok

  def __str__(self):
    return (f"{self.left} {'==' if self.comp == Comparator.EQ else '>='} {self.right}"
            f" ({self.error_message_pieces})")
  __repr__ = __str__

  def error_message_and_inputs(
      self,
      eval: ShapeEvaluator) -> tuple[str, Sequence[Any]]:
    """Forms the error_message and error_message_inputs.
    See shape_assertion.
    """
    # There is currently a limitation in the shape assertion checker that
    # it supports at most 32 error_message_inputs. We try to stay within the
    # limit, reusing a format specifier if possible.
    max_error_message_inputs = 32
    format_specifiers: dict[DimSize, str] = {}
    error_message_inputs: list[Any] = []
    error_message_strings: list[str] = []
    for e in self.error_message_pieces:
      if isinstance(e, str):
        error_message_strings.append(e)
        continue
      cached_spec = format_specifiers.get(e)
      if cached_spec is not None:
        error_message_strings.append(cached_spec)
        continue
      if len(error_message_inputs) >= max_error_message_inputs:
        error_message_strings.append("N/A")
        continue
      spec = "{" + str(len(error_message_inputs)) + "}"
      format_specifiers[e] = spec
      error_message_strings.append(spec)
      error_message_inputs.append(eval.evaluate(e))
    return ("".join(error_message_strings),
            error_message_inputs)

  def make_error(self, eval: ShapeEvaluator) -> Exception:
    error_message, error_message_inputs = self.error_message_and_inputs(eval)
    return ValueError(error_message.format(*error_message_inputs))
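

# Illustrative sketch of the error-message mechanism above: with
# error_message_pieces = ["expected ", d, " >= 1 but got ", d] the pieces are
# rendered as the format string "expected {0} >= 1 but got {0}" with
# error_message_inputs = [eval.evaluate(d)]; the repeated DimSize piece `d`
# reuses its format specifier to stay within the 32-input limit.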


class ShapeConstraints:
  def __init__(self):
    self.constraints: list[ShapeConstraint] = []

  def add_constraint(self,
                     comp: Comparator,
                     left: DimSize, right: DimSize,
                     error_message_pieces: Sequence[str | DimSize]):
    c = ShapeConstraint(comp, left, right, error_message_pieces)
    self.constraints.append(c)

  def check_statically(self, eval: ShapeEvaluator) -> None:
    """Evaluates all the constraints statically.

    If the static checking of any constraint fails, raises ValueError.
    """
    for constraint in self.constraints:
      constraint.check_statically(eval)

  def shape_assertions(self, eval: ShapeEvaluator) -> None:
    """Computes the shape assertions for the set of constraints.

    See jax_export.Exported docstring.
    """
    # We want to report the errors in the same order as `check_statically`.
    # So, we process them in order, in case some fail statically, and we
    # generate the shape assertions in the same order.
    for constraint in self.constraints:
      is_ok = constraint.compute(eval)
      if is_ok is None: continue  # Was resolved statically
      error_message, error_message_inputs = constraint.error_message_and_inputs(eval)
      shape_assertion(
          is_ok, *error_message_inputs,
          error_message=error_message)


@dataclasses.dataclass
class _DimEquation:
  # Encodes that `aval_dim_expr`, which is a symbolic expression containing
  # unknown dimension variables from the abstract values, is the specification
  # for the dimension named `dim_name` (e.g., "args[0].field.shape[2]").
  aval_dim_expr: _DimExpr
  dim_name: str

  def __str__(self):
    return f"Dimension size of {self.dim_name} with specification '{self.aval_dim_expr}'"
  __repr__ = __str__


def args_kwargs_path_to_str(path: tree_util.KeyPath) -> str:
  # String description of `args` or `kwargs`, assuming the path is into a
  # tree for the tuple `(args, kwargs)`.
  if path[0] == tree_util.SequenceKey(0):
    return f"args{tree_util.keystr(path[1:])}"
  elif path[0] == tree_util.SequenceKey(1):
    return f"kwargs{tree_util.keystr(path[1:])}"
  else:
    assert False

@functools.lru_cache(128)
def _cached_pretty_print_dimension_descriptor(
    args_kwargs_tree: tree_util.PyTreeDef,
    flat_arg_idx: int) -> str:
  args_kwargs_with_paths, _ = tree_util.tree_flatten_with_path(
      args_kwargs_tree.unflatten((0,) * args_kwargs_tree.num_leaves))
  arg_str = args_kwargs_path_to_str(args_kwargs_with_paths[flat_arg_idx][0])
  return arg_str

def pretty_print_dimension_descriptor(
    args_kwargs_tree: tree_util.PyTreeDef,
    flat_arg_idx: int, dim_idx: int | None) -> str:
  arg_str = _cached_pretty_print_dimension_descriptor(args_kwargs_tree, flat_arg_idx)
  if dim_idx is not None:
    arg_str += f".shape[{dim_idx}]"
  return arg_str
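

# Sketch (hypothetical tree): for the pytree of the tuple ((x,), {"a": y}),
# i.e., (args, kwargs) with two flat leaves,
# pretty_print_dimension_descriptor(tree, flat_arg_idx=1, dim_idx=0)
# returns "kwargs['a'].shape[0]".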


@util.cache()
def solve_dim_vars(
    args_avals: Sequence[core.ShapedArray],
    args_kwargs_tree: tree_util.PyTreeDef,
) -> tuple[DimVarEnv, ShapeConstraints, Sequence[tuple[str, int, int]]]:
  """Solves dimension variables in a called function's avals in terms of actual argument shapes.

  For example, given:

     args_avals = [ShapedArray((3, a, a + b), f32)]

  we introduce fresh "synthetic" dimension variables to represent the actual
  dimension size of actual arguments for each non-constant dimension.
  Each synthetic variable has a name, an arg_idx, and a dim_idx, e.g.:

    synthetic_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)]

  and then we express the solution for the unknown dimension variables {a, b}
  as symbolic expressions in terms of the synthetic variables:

    dict(a=args[0].shape[1], b=args[0].shape[2] - args[0].shape[1])

  Not all equations are solvable. For now, we solve first the linear
  uni-variate equations, then the solved variables are used to simplify the
  remaining equations to linear uni-variate equations, and the process
  continues until all dimension variables are solved.

  Args:
    args_avals: the abstract values of the `args`, with shapes that may
      include unknown dimension variables.
    args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)`
      from which the flat sequence `args_avals` is extracted. Used for
      describing args and kwargs in synthetic variable names and in
      error messages.

  Returns: a 3-tuple with: (a) the solution for the unknown dimension variables,
    (b) a list of constraints that must be satisfied for the solution to be a
    valid one, and (c) the list of synthetic variables that may appear in
    the solution and the constraints.

  Raises ValueError if it cannot solve some dimension variable.
  """
  dim_equations: list[_DimEquation] = []
  synth_dimension_vars: list[tuple[str, int, int]] = []
  # Tuples with an argument name and its polymorphic shape, e.g.,
  # ('args[0]', '(a, a + b)').
  polymorphic_shape_specs: list[tuple[str, str]] = []
  for arg_idx, aval in enumerate(args_avals):
    if all(not is_symbolic_dim(d) for d in aval.shape):
      continue
    polymorphic_shape_specs.append(
        (pretty_print_dimension_descriptor(args_kwargs_tree, arg_idx, None),
         str(aval.shape)))
    for dim_idx, aval_d in enumerate(aval.shape):
      if is_symbolic_dim(aval_d):
        synth_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree,
                                                          arg_idx, dim_idx)
        synth_dimension_vars.append((synth_dim_var, arg_idx, dim_idx))
        dim_equations.append(
            _DimEquation(aval_dim_expr=aval_d,
                         dim_name=synth_dim_var))

  solution, shape_constraints = _solve_dim_equations(dim_equations,
                                                     polymorphic_shape_specs)
  return solution, shape_constraints, synth_dimension_vars


def compute_dim_vars_from_arg_shapes(
    args_avals: Sequence[core.ShapedArray],
    *actual_args: jax.Array,
    args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]:
  """Computes values of dimension variables to unify args_avals with actual arguments.

  Like `solve_dim_vars` except that here we express the solution as
  JAX arrays that reference the `actual_args`. This function can be used to
  generate the code for computing the dimension variables. It also generates
  the shape assertions.

  Returns: the values of the dimension variables, in the order determined by
    `all_dim_vars(args_avals)`.
  """
  dim_vars = all_dim_vars(args_avals)
  solution, shape_constraints, synth_dim_vars = solve_dim_vars(
      tuple(args_avals), args_kwargs_tree=args_kwargs_tree)

  # Replace the synthetic vars with the dynamic shape of the actual arg
  synthetic_env: DimVarEnv = {
      vname: dimension_size_p.bind(actual_args[arg_idx], dimension=dim_idx)
      for (vname, arg_idx, dim_idx) in synth_dim_vars
  }
  synthetic_eval = ShapeEvaluator(synthetic_env)
  shape_constraints.shape_assertions(synthetic_eval)
  dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars]
  return tuple(dim_values)
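

# Worked sketch (shapes assumed): continuing the solve_dim_vars example with
# args_avals = [ShapedArray((3, a, a + b), f32)] and an actual argument of
# shape (3, 5, 12), the synthetic environment maps "args[0].shape[1]" to 5
# and "args[0].shape[2]" to 12, so the returned values are a = 5 and
# b = 12 - 5 = 7, after the shape assertions have been emitted.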


def _solve_dim_equations(
    eqns: list[_DimEquation],
    polymorphic_shape_specs: Sequence[tuple[str, str]]
) -> tuple[DimVarEnv, ShapeConstraints]:
  # Returns a shape environment and the shape constraints if it can solve all
  # dimension variables. Raises an exception if it cannot.
  shape_env: DimVarEnv = {}
  solution_error_message_pieces: list[str | DimSize] = [
      " Obtained dimension variables: "
  ]  # Error message describing the solution
  # Prepare error message piece describing the polymorphic shape specs
  poly_specs_err_msg = (
      " Using the following polymorphic shapes specifications: " +
      ",".join(f"{arg_name}.shape = {arg_spec}"
               for arg_name, arg_spec in polymorphic_shape_specs)) + "."
  solution_err_msg_trailer_errors = ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details."

  shape_constraints = ShapeConstraints()  # accumulate shape constraints
  scope: SymbolicScope | None = None

  def process_one_eqn(eqn: _DimEquation) -> bool:
    # We start with a DimEquation of the form `dim_expr = dim_value`.
    # Try to rewrite the equation as `var * var_k = dim_value_2` (a linear
    # uni-variate equation). Returns `False` if this rewrite fails.
    # Otherwise, compute the `var` value as `dim_value_2 // var_k`, add it to
    # `shape_env` and return `True`.
    #
    # Invariant:
    #     var * var_k + remaining_terms_from_dim_expr = dim_value
    var, var_k = None, None
    nonlocal scope
    if scope is None:
      scope = eqn.aval_dim_expr.scope
    elif config.enable_checks.value:
      scope._check_same_scope(eqn.aval_dim_expr, when=f"solving equation {eqn}")

    dim_value = _DimExpr._from_var(eqn.dim_name, scope)

    for term, term_k in eqn.aval_dim_expr._sorted_terms:
      # Perhaps we can already evaluate this term (all vars solved)
      try:
        term_value = term.evaluate(shape_env, scope)
      except UnexpectedDimVar:
        # `term` still uses some variables not yet solved. We handle only the
        # case when `term` is a single variable.
        v = term.to_var()
        if v is not None and var is None:
          var, var_k = v, term_k
          continue
      else:
        dim_value = dim_value + core.dim_constant(-1) * _evaluate_multiply(term_value, core.dim_constant(term_k))
        continue
      return False  # This equation cannot yet be used to solve a variable

    if var is not None:
      if var_k == 1:
        var_value = dim_value
      else:
        var_value, var_remainder = divmod(dim_value, core.dim_constant(var_k))  # type: ignore
        shape_constraints.add_constraint(
            Comparator.EQ, var_remainder, 0,
            error_message_pieces=([
                "Input shapes do not match the polymorphic shapes specification. "
                "Division had remainder ", var_remainder,
                f" when computing the value of '{var}'." + poly_specs_err_msg
            ] + solution_error_message_pieces + [
                solution_err_msg_trailer_errors]))

      if not isinstance(var_value, _DimExpr):
        assert var_value.dtype == core.dim_value_dtype()  # type: ignore[attribute-error,unused-ignore]
      shape_env[var] = var_value  # type: ignore
      solution_error_message_pieces.extend([  # type: ignore[container-type-mismatch,unused-ignore]
          f"'{var}' = ", var_value,
          f" from specification '{eqn.aval_dim_expr}' "
          f"for dimension {eqn.dim_name} (= ",
          _DimExpr._from_var(eqn.dim_name, eqn.aval_dim_expr.scope),
          "), "])

      shape_constraints.add_constraint(
          Comparator.GEQ, var_value, 1,
          error_message_pieces=[
              "Input shapes do not match the polymorphic shapes specification. "
              f"Expected value >= 1 for dimension variable '{var}'." +
              poly_specs_err_msg
          ] + solution_error_message_pieces + [
              solution_err_msg_trailer_errors])

      return True
    else:
      # All variables are resolved for this equation, we emit an assertion
      shape_constraints.add_constraint(
          Comparator.EQ,
          _DimExpr._from_var(eqn.dim_name, eqn.aval_dim_expr.scope),
          eqn.aval_dim_expr._evaluate(shape_env),
          error_message_pieces=([
              "Input shapes do not match the polymorphic shapes specification. "
              f"Found inconsistency between dimension size {eqn.dim_name} (= ",
              _DimExpr._from_var(eqn.dim_name, eqn.aval_dim_expr.scope),
              f") and the specification '{eqn.aval_dim_expr}' (= ",
              eqn.aval_dim_expr._evaluate(shape_env),
              ")." + poly_specs_err_msg] + solution_error_message_pieces +
              [solution_err_msg_trailer_errors])
      )
      return True

  def add_explicit_symbolic_constraints(shape_env: DimVarEnv):
    if not shape_env: return
    assert scope is not None
    for constr in scope._explicit_constraints:
      # We can't just construct constr.e1 - constr.e2 because for an equality
      # constraint it would be reduced to 0.
      c_e1 = constr.e1._evaluate(shape_env) if not core.is_constant_dim(constr.e1) else constr.e1  # type: ignore
      c_e2 = constr.e2._evaluate(shape_env) if not core.is_constant_dim(constr.e2) else constr.e2  # type: ignore
      c_diff = c_e1 - c_e2
      shape_constraints.add_constraint(
          constr.cmp, c_diff, 0,
          error_message_pieces=[
              f"Input shapes do not match the symbolic shape constraint {constr.debug_str}. "
              f"Expected '{constr.e1} - {constr.e2}' to be "
              f"{'greater or equal' if constr.cmp == Comparator.GEQ else 'equal'} to 0, "
              "but found ", c_diff,
              ". " + poly_specs_err_msg
          ] + solution_error_message_pieces + [
              solution_err_msg_trailer_errors])

  while True:
    nr_eqns = len(eqns)
    eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)]
    if not eqns:
      add_explicit_symbolic_constraints(shape_env)
      return shape_env, shape_constraints  # SUCCESS
    elif len(eqns) >= nr_eqns:
      break

  # We have some equations that we cannot solve further
  unsolved_vars: set[str] = set()
  unsolved_polys: list[_DimExpr] = []
  for eqn in eqns:
    unsolved_vars = unsolved_vars.union(eqn.aval_dim_expr._get_vars())
    unsolved_polys.append(eqn.aval_dim_expr)
  unsolved_vars = unsolved_vars.difference(shape_env.keys())
  err_msg = (
      f"Cannot solve for values of dimension variables {unsolved_vars}. "
      "We can only solve linear uni-variate constraints." + poly_specs_err_msg +
      " Unprocessed specifications: " +
      ", ".join(f"'{eqn.aval_dim_expr}' for dimension size {eqn.dim_name}"
                for eqn in eqns) +
      ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details."
  )
  raise ValueError(err_msg)
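

# Worked sketch of process_one_eqn (spec assumed): for an equation
# "2*b + 3 = args[0].shape[0]", the constant term 3 is moved to the
# right-hand side (dim_value becomes args[0].shape[0] - 3), the unsolved
# term selects var = "b" with var_k = 2, and the solution
# b = (args[0].shape[0] - 3) // 2 is recorded together with an EQ constraint
# that the division has remainder 0 and a GEQ constraint that b >= 1.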