2023-09-05 22:15:22 -07:00
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Shape polymorphism support.
2024-06-12 08:47:17 +02:00
See documentation at https : / / jax . readthedocs . io / en / latest / export / shape_poly . html .
2023-09-05 22:15:22 -07:00
"""
2023-12-11 13:59:29 +00:00
from __future__ import annotations
2024-02-09 23:18:52 +01:00
import enum
2024-06-26 14:44:52 -04:00
from collections . abc import Callable , Sequence
2023-09-05 22:15:22 -07:00
import dataclasses
from enum import Enum
import functools
import itertools
import io
2024-02-12 23:35:35 -08:00
import copy
2023-09-05 22:15:22 -07:00
import operator as op
import tokenize
2024-06-26 14:44:52 -04:00
from typing import Any , Union , overload
2024-01-10 08:45:03 +02:00
import warnings
2023-09-05 22:15:22 -07:00
import numpy as np
import opt_einsum
import jax
from jax . interpreters import xla
2023-10-12 13:15:22 +01:00
from jax . _src import config
2023-09-05 22:15:22 -07:00
from jax . _src import core
from jax . _src import dtypes
from jax . _src import effects
from jax . _src . lax import lax
from jax . _src . interpreters import mlir
from jax . _src . numpy import lax_numpy
2024-01-01 23:09:42 +07:00
from jax . _src import source_info_util
2023-09-05 22:15:22 -07:00
from jax . _src import tree_util
from jax . _src import util
2024-01-05 14:48:53 +07:00
# Opaque aliases kept for readability of signatures.
TfVal = Any
DType = Any

# A dimension size: either a concrete integer or a symbolic expression.
DimSize = Union["_DimExpr", int]
# Maps dimension variable names to their values.
DimVarEnv = dict[str, jax.Array]

# Tuples of terms and their coefficients, sorted with the largest term first.
SortedTerms = Sequence[tuple["_DimTerm", int]]
# Tuples of factors and their integer exponents (see _DimTerm).
SortedFactors = Sequence[tuple["_DimFactor", int]]
2024-02-03 06:38:01 +02:00
[shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
2024-01-27 19:54:52 +01:00
# An explicit constraint `t*tk == e` is stored as a normalization rule:
# the key is the term `t` and the value is the pair `(e, tk)`.
NormalizationRules = dict[
    "_DimTerm",
    tuple["_DimExpr", int],
]
[shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
2024-01-27 19:54:52 +01:00
2024-02-14 11:34:40 +02:00
2023-09-05 22:15:22 -07:00
class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation):
  """Raised when we cannot conclusively compute with symbolic dimensions."""

  # Appended to every error message. The documentation URL is written as a
  # single unbroken token so that it can be copied or followed as-is.
  _help_msg = """
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.

Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported
for more details.
"""

  def __init__(self, message: str):
    """Builds the exception text from `message` followed by `_help_msg`.

    Args:
      message: describes the symbolic-dimension operation that could not be
        decided.
    """
    error_msg = f"{message}{InconclusiveDimensionOperation._help_msg}"
    # https://github.com/python/mypy/issues/5887
    super().__init__(error_msg)
2023-09-05 22:15:22 -07:00
2024-09-06 11:52:12 +03:00
class UnexpectedDimVar(Exception):
  """Raised when evaluating a dimension variable that has no known value."""
[shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
2024-01-27 19:54:52 +01:00
class Comparator(Enum):
  """How the two sides of a symbolic constraint are related."""
  EQ = 1   # e1 == e2
  GEQ = 2  # e1 >= e2
@dataclasses.dataclass(frozen=True)
class _SymbolicConstraint:
  """A symbolic constraint between two dimension sizes.

  Stands for `e1 == e2` when `cmp == Comparator.EQ`, else for `e1 >= e2`.
  """
  cmp: Comparator
  # The form in which the user expressed it, for error messages.
  debug_str: str
  # Both sides have been normalized w.r.t. previous constraints only.
  e1: DimSize
  e2: DimSize

  def __repr__(self):
    return "Constraint({})".format(self.debug_str)
[shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
2024-01-27 19:54:52 +01:00
2024-02-20 23:13:20 +01:00
class _DimFactor:
  """Represents a factor in a symbolic dimension expression.

  Factors are either variables, or expressions of the form floordiv(E1, E2) or
  mod(E1, E2), or max(E1, E2), or min(E1, E2).
  Factors are multiplied to form terms (see _DimTerm), and
  terms are added to form symbolic expressions (see _DimExpr).

  Args:
    * var: if specified then the factor is a dimension variable. `operation`
      must be `None`.
    * operation: if specified then the factor is an operation applied to
      `operands`. One of `FLOORDIV` or `MOD` or `MAX` or `MIN`.
      `var` must be `None`.
    * operands: the operands to which the operation is applied.
  """
  # The supported operations
  # FLOORDIV(e1, e2) and MOD(e1, e2) have the same semantics as in Python:
  # FLOORDIV(e1, e2) = e1 // e2 = floor(e1 / e2)
  # if e2 > 0 then 0 <= MOD(e1, e2) < e2
  # if e2 < 0 then e2 < MOD(e1, e2) <= 0
  # e1 = e2 * FLOORDIV(e1, e2) + MOD(e1, e2)
  #
  FLOORDIV = "floordiv"
  MOD = "mod"
  MAX = "max"
  MIN = "min"
  # The max of the operand and 0. Replaced with max but kept here for
  # backwards compatibility.
  NON_NEGATIVE = "non_negative"

  __slots__ = ["var", "operation", "operands", "_hash", "_size"]

  def __init__(self, *operands: _DimExpr,
               var: str | None = None,
               operation: str | None = None):
    if var is not None:
      assert operation is None
      assert not operands
    else:
      assert operation is not None
    self.var = var
    self.operation = operation
    self.operands = operands
    self._hash = None  # Computed lazily by __hash__.
    # Syntactic size; the first key compared by `_syntactic_cmp`.
    self._size: int = 1 if var is not None else 1 + sum(o._size for o in operands)

  @staticmethod
  def from_var(v: str) -> _DimFactor:
    """Returns the factor for the dimension variable `v`."""
    return _DimFactor(var=v)

  @staticmethod
  def from_operation(operation: str, *operands: DimSize,
                     scope: SymbolicScope) -> _DimFactor:
    """Returns the factor `operation(*operands)`.

    Each operand is first converted with `_ensure_poly` in `scope`.
    """
    return _DimFactor(*(_ensure_poly(o, operation, scope) for o in operands),
                      operation=operation)

  def to_var(self) -> str | None:
    """Returns the variable name, or None if this is an operation factor."""
    return self.var

  def get_vars(self) -> set[str]:
    """Returns all the dimension variables that appear in this factor."""
    if self.var is not None:
      return {self.var}
    return set().union(*(opnd._get_vars() for opnd in self.operands))

  def __str__(self):
    if self.var is not None:
      return self.var
    opnd_str = ", ".join([str(opnd) for opnd in self.operands])
    return f"{self.operation}({opnd_str})"

  __repr__ = __str__

  def __hash__(self):
    if self._hash is None:
      self._hash = hash((self.var, self.operation, *self.operands))
    return self._hash

  def _syntactic_cmp(self, other: _DimFactor) -> int:
    """Returns -1 if self < other, 0 if self == other, 1 if self > other.

    The comparison is done lexicographically (syntactic), to be used for sorting.
    The result is not related to the semantic value.
    """
    if c := cmp_comparable(self._size, other._size): return c
    # Equal sizes: a variable (size 1) is only compared against another variable.
    if self.var is not None:
      return cmp_comparable(self.var, other.var)
    if c := cmp_comparable(self.operation, other.operation): return c
    return cmp_sequence(self.operands, other.operands,
                        lambda s_o, o_o: s_o._syntactic_cmp(o_o))

  def __eq__(self, other: Any):
    """Lexicographic comparison."""
    if not isinstance(other, _DimFactor): return False
    return self._syntactic_cmp(other) == 0

  def __lt__(self, other: _DimFactor):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) < 0

  def __le__(self, other: _DimFactor):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) <= 0

  def __gt__(self, other: _DimFactor):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) > 0

  def __ge__(self, other: _DimFactor):
    """Lexicographic comparison."""
    return self._syntactic_cmp(other) >= 0

  def evaluate(self, env: DimVarEnv, scope: SymbolicScope):
    """Evaluates this factor given values for the dimension variables.

    Args:
      env: maps dimension variable names to their values. As noted below,
        the values may be JAX Tracers.
      scope: used to look up a normalization rule for a variable that is
        missing from `env`.

    Raises:
      UnexpectedDimVar: if this is a variable that is not in `env` and has no
        non-trivial normalization in `scope`.
    """
    if self.var is not None:
      try:
        return env[self.var]
      except KeyError:
        # Perhaps there is a normalization rule for this variable
        normalized_var = _DimExpr._from_var(self.var, scope)
        if core.is_constant_dim(normalized_var):
          return normalized_var
        non_trivial_normalization = (v1 := normalized_var._to_var()) is None or v1 != self.var  # type: ignore
        if non_trivial_normalization:
          return normalized_var._evaluate(env)  # type: ignore
        err_msg = (
            f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n"
            "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
        raise UnexpectedDimVar(err_msg)
    else:
      operand_values = [opnd._evaluate(env) for opnd in self.operands]
      if self.operation == _DimFactor.FLOORDIV:
        return divmod(*operand_values)[0]  # type: ignore
      elif self.operation == _DimFactor.MOD:
        return divmod(*operand_values)[1]  # type: ignore
      elif self.operation == _DimFactor.MAX:
        op1, op2 = operand_values
        if core.is_constant_dim(op1) and core.is_constant_dim(op2):
          return max(op1, op2)
        if core.is_symbolic_dim(op1) or core.is_symbolic_dim(op2):
          return core.max_dim(op1, op2)
        # In the context of `evaluate` dimension variables may be mapped to
        # JAX Tracers.
        return lax.max(op1, op2)
      elif self.operation == _DimFactor.MIN:
        op1, op2 = operand_values
        if core.is_constant_dim(op1) and core.is_constant_dim(op2):
          return min(op1, op2)
        if core.is_symbolic_dim(op1) or core.is_symbolic_dim(op2):
          return core.min_dim(op1, op2)
        # In the context of `evaluate` dimension variables may be mapped to
        # JAX Tracers.
        return lax.min(op1, op2)
      else:
        assert False, self.operation

  def __deepcopy__(self, memo):
    return _DimFactor(*copy.deepcopy(self.operands, memo),
                      var=copy.deepcopy(self.var, memo),
                      operation=copy.deepcopy(self.operation, memo))
2024-02-12 23:35:35 -08:00
[shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
2024-01-27 19:54:52 +01:00
2024-02-20 23:13:20 +01:00
class _DimTerm:
  """Represents a multiplication of factors.

  The representation is a sequence of _DimFactor factors along with their
  integer exponents (>= 1). The empty sequence represents the constant 1.
  """
  __slots__ = ["_factors", "_hash", "_size"]

  def __init__(self, sorted_factors: SortedFactors):
    self._factors = sorted_factors
    self._hash = None  # Computed lazily by __hash__.
    # Syntactic size; the first key compared by `_syntactic_cmp`.
    self._size = sum((1 + f_exp * f._size) for f, f_exp in self._factors)

  def __hash__(self):
    if self._hash is None:
      self._hash = hash(tuple(self._factors))
    return self._hash

  def __str__(self):
    return "*".join(f"{fact}^{exponent}" if exponent != 1 else str(fact)
                    for fact, exponent in sorted(self._factors))

  __repr__ = __str__

  @staticmethod
  def from_var(v: str) -> _DimTerm:
    """Returns the term `v^1` for the dimension variable `v`."""
    return _DimTerm(((_DimFactor.from_var(v), 1),))

  @staticmethod
  def from_factor(f: _DimFactor, f_exp: int):
    """Returns the term `f^f_exp`."""
    return _DimTerm(((f, f_exp),))

  @staticmethod
  def from_operation(operation: str, *operands: DimSize,
                     scope: SymbolicScope) -> _DimTerm:
    """Returns the term made of the single factor `operation(*operands)`."""
    return _DimTerm(((_DimFactor.from_operation(operation, *operands,
                                                scope=scope), 1),))

  def to_var(self) -> str | None:
    """Extract the variable name from a term.

    Return None if the term is not a single variable."""
    a = self.to_factor()
    return a.to_var() if a is not None else None

  def to_factor(self) -> _DimFactor | None:
    """Extract the single factor from a term.

    Return None if the term is not a single factor with exponent 1, including
    for the constant term (which has no factors).
    """
    # `!= 1` rather than `> 1`: the constant term has zero factors and the
    # unpacking below would raise ValueError for it.
    if len(self._factors) != 1: return None
    (f, f_exp), = self._factors
    if f_exp != 1: return None
    return f

  def get_vars(self) -> set[str]:
    """Returns all the dimension variables that appear in the term."""
    acc = set()
    for (f, _) in self._factors:
      acc.update(f.get_vars())
    return acc

  @property
  def is_constant(self):
    """True for the constant term 1, i.e., the term with no factors."""
    return not self._factors

  def _syntactic_cmp(self, other: _DimTerm) -> int:
    """Returns -1 if self < other, 0 if self == other, 1 if self > other.

    The comparison is done lexicographically (syntactic), to be used for sorting.
    The result is not related to the semantic value.
    """
    if c := cmp_comparable(self._size, other._size): return c

    def cmp_factor(s_f: tuple[_DimFactor, int], o_f: tuple[_DimFactor, int]) -> int:
      if c := s_f[0]._syntactic_cmp(o_f[0]): return c
      # Consider the terms with exponents to be expanded as multiplications.
      # Then a higher exponent for a "large" factor should lead to a "larger" term.
      return cmp_comparable(s_f[1], o_f[1])

    return cmp_sequence(self._factors, other._factors, cmp_factor)

  def __lt__(self, other: _DimTerm):
    """Lexicographic comparison"""
    return self._syntactic_cmp(other) < 0

  def __le__(self, other: _DimTerm):
    """Lexicographic comparison"""
    return self._syntactic_cmp(other) <= 0

  def __gt__(self, other: _DimTerm):
    """Lexicographic comparison"""
    return self._syntactic_cmp(other) > 0

  def __ge__(self, other: _DimTerm):
    """Lexicographic comparison"""
    return self._syntactic_cmp(other) >= 0

  def __eq__(self, other) -> bool:
    if not isinstance(other, _DimTerm): return False
    return self._syntactic_cmp(other) == 0

  def __ne__(self, other) -> bool:
    return not (self == other)

  def mul(self, other: _DimTerm) -> _DimTerm:
    """
    Returns the product with another term. Example: (n^2*m) * n == n^3 * m.
    """
    return _DimTerm(_DimExpr._linear_combination_sorted_pairs(self._factors, 0, 1,
                                                              other._factors, 0, 1))

  def divide(self, divisor: _DimTerm) -> _DimTerm:
    """
    Divides by another term. Raises a InconclusiveDimensionOperation
    if the result is not a term.
    For example, (n^3 * m) // n == n^2*m, but n // m fails.
    """
    new_factors = _DimExpr._linear_combination_sorted_pairs(self._factors, 0, 1,
                                                            divisor._factors, 0, -1)
    # A non-positive exponent means `divisor` does not divide this term.
    for _, f_exp in new_factors:
      if f_exp <= 0:
        raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.")
    return _DimTerm(new_factors)

  def evaluate(self, env: DimVarEnv, scope: SymbolicScope):
    """Evaluates the product of the factors raised to their exponents.

    The empty product (the constant term) evaluates to `core.dim_constant(1)`.
    """
    prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1)
    def pow_opt(v, p: int):
      return v if p == 1 else prod([v] * p)
    return prod([pow_opt(f.evaluate(env, scope), exp) for f, exp in self._factors])

  def __deepcopy__(self, memo):
    return _DimTerm(copy.deepcopy(self._factors, memo))


# The constant 1, as a term.
_DimTerm_one = _DimTerm(())
2024-02-15 13:53:05 +01:00
2023-09-05 22:15:22 -07:00
2024-02-03 06:38:01 +02:00
class _DimExpr :
2024-02-20 23:13:20 +01:00
""" Symbolic expressions using dimension variables.
2023-09-05 22:15:22 -07:00
2024-02-20 23:13:20 +01:00
A dimension expression is an addition of terms ( _DimTerm ) , which themselves
are products of factors ( _DimFactor ) .
2024-02-03 06:38:01 +02:00
2024-02-20 23:13:20 +01:00
The representation of a _DimExpr is as sequence of pairs ` ( term , coeff ) ` ,
2024-02-03 06:38:01 +02:00
representing the linear combination of terms with the given coefficients .
2024-02-20 23:13:20 +01:00
The sequence is sorted by lexicographic ( syntactic ) ordering of ` _DimTerm ` ,
with the largest terms first . The special term ` _DimTerm_one ` is mapped
2024-02-03 06:38:01 +02:00
to the free integer coefficient of the expression .
2023-09-05 22:15:22 -07:00
We overload integer operations , but we do that soundly , raising
: class : ` InconclusiveDimensionOperation ` when the result is not
representable as a _DimExpr .
"""
__array_priority__ = 1000 # Same as tracer, for __radd__ and others on ndarray
2024-02-20 23:13:20 +01:00
__slots__ = ( " _sorted_terms " , " _scope " , " _hash " , " _size " )
def __init__(self, sorted_terms: SortedTerms,
             scope: SymbolicScope):
  """Builds an expression from already-normalized, sorted (term, coeff) pairs.

  Do not construct _DimExpr directly, unless you are sure that the terms are
  normalized; use _DimExpr._normalize_sorted_terms instead.
  """
  terms = tuple(sorted_terms)
  if not terms:
    # An empty sum is stored as the constant term with coefficient 0.
    terms = ((_DimTerm_one, 0),)
  self._sorted_terms = terms
  self._scope = scope
  self._hash = None
  # The size speeds up _syntactic_cmp, which is used a lot for hashing.
  self._size = sum(1 + abs(coeff) * term._size for term, coeff in terms)
2024-02-03 06:38:01 +02:00
2024-01-01 23:09:42 +07:00
@property
def scope(self):
  """The scope of this expression, exposed read-only (there is no setter)."""
  return self._scope
2023-09-05 22:15:22 -07:00
2024-02-15 13:53:05 +01:00
@staticmethod
def _coeff_to_sorted_terms(coeffs: dict[_DimTerm, int]) -> SortedTerms:
  """Returns the (term, coeff) pairs with non-zero coeff, largest first."""
  non_zero = [(term, coeff) for term, coeff in coeffs.items() if coeff != 0]
  non_zero.sort(reverse=True)
  return non_zero
2024-02-03 06:38:01 +02:00
2024-02-15 13:53:05 +01:00
@staticmethod
def _from_term(t: _DimTerm, t_k: int, scope: SymbolicScope) -> DimSize:
  """Returns the normalized form of the expression `t_k * t`."""
  single = ((t, t_k),)
  return _DimExpr._normalize_sorted_terms(single, scope)
2024-02-03 06:38:01 +02:00
2024-02-15 13:53:05 +01:00
@staticmethod
def _from_var(v: str, scope: SymbolicScope) -> DimSize:
  """Returns the normalized form of the dimension variable `v`."""
  var_term = _DimTerm.from_var(v)
  return _DimExpr._normalize_sorted_terms(((var_term, 1),), scope)
2024-02-03 06:38:01 +02:00
2024-02-15 13:53:05 +01:00
@staticmethod
def _from_operation(operation: str, *operands: DimSize,
                    scope: SymbolicScope) -> DimSize:
  """Returns the normalized expression `operation(*operands)`.

  The legacy NON_NEGATIVE operation is still accepted (for parsing, for
  backwards compatibility) and is rewritten as `max(*operands, 0)`.
  """
  if operation == _DimFactor.NON_NEGATIVE:
    operation, operands = _DimFactor.MAX, (*operands, 0)
  term = _DimTerm.from_operation(operation, *operands, scope=scope)
  return _DimExpr._from_term(term, 1, scope=scope)
2024-01-05 22:10:50 +07:00
@property
def _leading_term(self) -> tuple[_DimTerm, int]:
  """Returns the (term, coeff) pair of the largest term.

  Terms are kept sorted with the largest first, so this is the highest degree
  term, i.e., the one that comes last lexicographically.
  """
  terms = self._sorted_terms
  return terms[0]
2024-01-05 14:48:53 +07:00
2024-02-20 23:13:20 +01:00
def _to_single_term(self) -> tuple[int, int, _DimTerm] | None:
  """Decomposes the expression as `k + c * term`, if possible.

  Returns:
    `(k, c, term)` when the expression consists of an optional constant part
    `k` (0 if absent) plus exactly one non-constant term `term` with
    coefficient `c`; None when there is more than one non-constant term.
  """
  const_part = 0
  single: tuple[_DimTerm, int] | None = None
  for term, coeff in self._sorted_terms:
    if term.is_constant:
      const_part = coeff
    elif single is None:
      single = (term, coeff)
    else:
      # A second non-constant term: not of the form `k + c * term`.
      return None
  assert single is not None
  return (const_part, single[1], single[0])
2024-02-15 13:53:05 +01:00
@staticmethod
def _add_coeff(coeffs: dict[_DimTerm, int], t: _DimTerm, coeff: int):
  """Performs `coeffs[t] += coeff` in place, squashing 0 coefficients.

  A zero `coeff` is ignored. If the accumulated coefficient of `t` becomes 0,
  the entry for `t` is removed instead of being stored as 0, so a dictionary
  that starts without zero coefficients never acquires any.
  """
  if coeff == 0:
    return
  new_coeff = coeffs.get(t, 0) + coeff
  if new_coeff == 0:
    # Since `coeff != 0`, `t` must already be present; it cancelled out.
    del coeffs[t]
  else:
    coeffs[t] = new_coeff
2024-02-15 13:53:05 +01:00
@staticmethod
def _normalize_term(t: _DimTerm, t_k: int,
                    scope: SymbolicScope) -> Sequence[tuple[_DimTerm, int]]:
  """Computes the updates implied by the scope normalization rules.

  If `t * t_k` matches one of the scope normalization rules, either for the
  whole term `t` or (for a product) for one of its factors, returns a list of
  `(term, coefficient)` pairs to add to the expression containing
  `(t, t_k)`: the first pair cancels `t * t_k` and the rest add its
  replacement. Returns an empty list if no normalization is necessary.
  """
  rules = scope._normalization_rules
  if not rules:
    return []

  # Rule for the whole term: `t * t_k_after -> after`.
  after, t_k_after = rules.get(t, (None, 0))
  if after is not None and t_k % t_k_after == 0:
    # Subtract `t * t_k` and add `after * (t_k // t_k_after)`.
    mult = t_k // t_k_after
    return [(t, -t_k)] + [(t2, tc2 * mult) for t2, tc2 in after._sorted_terms]

  # Otherwise, for a product of several factors, look up each factor alone.
  if len(t._factors) > 1:
    for f, fexp in t._factors:
      f_term = _DimTerm(((f, fexp),))
      f_after, f_k_after = rules.get(f_term, (None, 0))
      if f_after is not None and t_k % f_k_after == 0:
        # Subtract `t * t_k` and add
        # `(t // f**fexp) * f_after * (t_k // f_k_after)`.
        mult = t_k // f_k_after
        t_without_f = t.divide(f_term)
        return [(t, -t_k)] + [(t2.mul(t_without_f), tc2 * mult)
                              for t2, tc2 in f_after._sorted_terms]
  return []
@staticmethod
def _normalize_sorted_terms(terms: SortedTerms,
                            scope: SymbolicScope) -> DimSize:
  """Constructs a _DimExpr in normal form from sorted terms.

  Ensures that the symbolic dimension is normalized, e.g., does not have
  terms with coefficient 0, it reflects the scope normalization_rules, and
  it is represented as a Python integer if it is known to be a constant.
  Only the first term that matches a normalization rule is rewritten.

  Does not attempt to normalize the keys (terms) inside `terms`.
  """
  for t, t_k in terms:
    assert t_k != 0
    updates = _DimExpr._normalize_term(t, t_k, scope)
    if not updates:
      continue
    coeffs = dict(terms)
    for t1, t1_k in updates:
      _DimExpr._add_coeff(coeffs, t1, t1_k)
    terms = _DimExpr._coeff_to_sorted_terms(coeffs)
    # TODO: check the case when we need to apply multiple normalizations
    break

  if not terms:
    return 0
  # Terms are sorted largest first, so a constant leading term is the only one.
  leading, leading_k = terms[0]
  return leading_k if leading.is_constant else _DimExpr(terms, scope)
2023-09-05 22:15:22 -07:00
2024-02-20 23:13:20 +01:00
def _to_term(self) -> _DimTerm | None:
  """Returns the single term if the expression is exactly `1 * term`.

  Returns None when the expression has more than one term, or when its only
  term has a coefficient other than 1.
  """
  terms = self._sorted_terms
  if len(terms) > 1:
    return None
  ((term, coeff),) = terms
  return None if coeff != 1 else term
2024-01-05 14:48:53 +07:00
2024-02-20 23:13:20 +01:00
def _to_factor(self) -> _DimFactor | None:
  """Returns the single factor of the expression, or None if there is none.

  The expression must first reduce to a single term (see `_to_term`); that
  term's `to_factor()` then decides whether it is a single factor.
  """
  term = self._to_term()
  if term is None:
    return None
  return term.to_factor()
2024-01-05 14:48:53 +07:00
2024-02-20 23:13:20 +01:00
def _to_var(self) -> str | None:
  """Returns the variable name if the expression is a single variable.

  Returns None if the expression is not a single variable.
  """
  factor = self._to_factor()
  if factor is None:
    return None
  return factor.to_var()
2023-09-05 22:15:22 -07:00
2024-02-15 13:53:05 +01:00
@staticmethod
def _to_constant(e: DimSize) -> int | None:
  """Returns the constant value of `e`, or None if it is not a constant.

  A non-symbolic `e` is converted with `int(e)`. For a `_DimExpr`, the
  coefficient of its leading term is returned when that term is constant.
  """
  if isinstance(e, _DimExpr):
    term, coeff = e._leading_term
    return coeff if term.is_constant else None
  return int(e)
2024-01-01 23:09:42 +07:00
@property
def _is_constant(self):
  """Whether `_DimExpr._to_constant` finds a constant value for `self`."""
  constant = _DimExpr._to_constant(self)
  return constant is not None
2024-01-01 23:09:42 +07:00
2024-02-20 23:13:20 +01:00
def _get_vars(self) -> set[str]:
  """Returns the set of variable names used by any term of the expression."""
  return {var for term, _ in self._sorted_terms for var in term.get_vars()}

# `get_vars` already has some uses, so we keep this alias a while longer
# for backwards compatibility.
get_vars = _get_vars
2024-02-15 13:53:05 +01:00
@overload
@staticmethod
def _linear_combination_sorted_pairs(
    e1: SortedTerms, i1: int, f1: int,
    e2: SortedTerms, i2: int, f2: int) -> SortedTerms: ...  # type: ignore[bad-return-type,unused-ignore]

@overload
@staticmethod
def _linear_combination_sorted_pairs(
    e1: SortedFactors, i1: int, f1: int,
    e2: SortedFactors, i2: int, f2: int) -> SortedFactors: ...  # type: ignore[bad-return-type,unused-ignore]

@staticmethod
def _linear_combination_sorted_pairs(
    pairs1, i1, f1,
    pairs2, i2, f2):
  """Computes `pairs1[i1:] * f1 + pairs2[i2:] * f2`.

  `pairs1`, `pairs2`, and the result are lists of `(key, coefficient)` pairs
  sorted with the largest key first (compared with `_syntactic_cmp`). This is
  an optimization for a common operation: a single merge pass instead of
  computing each subexpression in turn. It works for both SortedTerms and
  SortedFactors.

  Pairs whose resulting coefficient is 0 (because equal keys cancel out, or
  because a multiplier is 0) are omitted from the result.
  """
  len1 = len(pairs1)
  len2 = len(pairs2)
  acc = []
  while i1 < len1 and i2 < len2:
    m1, m1_c = pairs1[i1]
    m2, m2_c = pairs2[i2]
    cmp = m1._syntactic_cmp(m2)  # Pick the largest key
    if cmp < 0:
      m, m_c = m2, m2_c * f2
      i2 += 1
    elif cmp > 0:
      m, m_c = m1, m1_c * f1
      i1 += 1
    else:  # They are equal, combine them
      m, m_c = m1, m1_c * f1 + m2_c * f2
      i1 += 1
      i2 += 1
    if m_c != 0:
      acc.append((m, m_c))

  # At most one of the inputs has entries left; a zero multiplier makes
  # all of them vanish.
  if i1 < len1 and f1 != 0:
    acc.extend((m1, m1_c * f1)
               for m1, m1_c in itertools.islice(pairs1, i1, len1) if m1_c != 0)
  if i2 < len2 and f2 != 0:
    acc.extend((m2, m2_c * f2)
               for m2, m2_c in itertools.islice(pairs2, i2, len2) if m2_c != 0)
  return acc
2024-01-05 14:48:53 +07:00
def _syntactic_cmp(self, other: _DimExpr) -> int:
  """Returns -1 if self < other, 0 if self == other, 1 if self > other.

  The comparison is syntactic, to be used for sorting: expressions are
  ordered first by `_size`, then lexicographically by their sorted
  `(term, coefficient)` pairs. The result is not related to the semantic
  value.
  """
  by_size = cmp_comparable(self._size, other._size)
  if by_size:
    return by_size

  def cmp_term(s_t: tuple[_DimTerm, int], o_t: tuple[_DimTerm, int]) -> int:
    # Compare the terms first, and the coefficients only to break ties.
    by_term = s_t[0]._syntactic_cmp(o_t[0])
    return by_term if by_term else cmp_comparable(s_t[1], o_t[1])

  return cmp_sequence(self._sorted_terms, other._sorted_terms, cmp_term)
2024-01-05 14:48:53 +07:00
2024-02-20 23:13:20 +01:00
def _eq(self, other: _DimExpr) -> bool:
  """Decides equality by checking whether `self - other` normalizes to 0.

  Returns False whenever the difference is still symbolic, even though the
  two expressions might be semantically equal (see the comments below).
  """
  # Equality is used very frequently because expressions are cached. We could
  # implement a more precise version based on `(self - other).bounds() = (0, 0)`
  # but that would be too expensive. It would also have the unfortunate drawback
  # that we cannot then cache `e.bounds()` because hashing invokes equality
  # which would lead to infinite recursion.
  diff = self - other
  # We look for `self - other == k`, relying on the fact that normalization
  # represents constant _DimExpr as Python ints.
  if not is_symbolic_dim(diff):
    return diff == 0
  # Here we really ought to raise InconclusiveDimensionOperation, but __eq__
  # cannot raise exceptions, because it is used indirectly when hashing.
  # So, we say that the expressions are disequal, which is really unsound.
  # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported
  return False
2023-09-05 22:15:22 -07:00
def __hash__(self):
  """Hashes `(self._sorted_terms, self.scope)`, caching the result in `_hash`."""
  cached = self._hash
  if cached is None:
    cached = self._hash = hash((self._sorted_terms, self.scope))
  return cached
2023-09-05 22:15:22 -07:00
def __str__(self):
  """Renders the expression as a sum of terms, e.g. `2*a - b + 3`."""
  def fmt_term(term, coeff):
    sign = "+" if coeff > 0 else "-"
    magnitude = abs(coeff)
    if term.is_constant:
      return "0" if magnitude == 0 else f"{sign} {magnitude}"
    if magnitude == 1:
      return f"{sign} {term}"
    return f"{sign} {magnitude}*{term}"

  # The "larger" terms come first, so any constant term is printed last.
  text = " ".join(fmt_term(term, coeff) for term, coeff in self._sorted_terms)
  # Drop a leading "+ " sign; a leading "- " is kept.
  return text[2:] if text.startswith("+ ") else text
2023-09-05 22:15:22 -07:00
def __repr__(self):
  """Uses the same human-readable rendering as `str()`."""
  rendered = str(self)
  return rendered
2024-02-15 13:53:05 +01:00
# A special case for linear combinations because they are common
@staticmethod
def _linear_combination(e1: DimSize, k1: int,
                        e2: DimSize, k2: int,
                        scope: SymbolicScope) -> DimSize:
  """Computes and normalizes `e1 * k1 + e2 * k2`.

  Args:
    e1, e2: the operands; each is a `_DimExpr` or an integer constant.
    k1, k2: integer multipliers for `e1` and `e2`.
    scope: the symbolic scope used to normalize the result.

  Returns:
    `e1 * k1 + e2 * k2` computed directly when both operands are constants;
    otherwise the normalized combination (see `_normalize_sorted_terms`).

  An operand that contributes nothing (a constant 0, or a multiplier of 0)
  is dropped before merging, so no zero-coefficient pairs are produced.
  """
  if isinstance(e1, _DimExpr):
    if isinstance(e2, _DimExpr):
      e1.scope._check_same_scope(e2, when="for linear combination")
    e1_terms = e1._sorted_terms if k1 != 0 else ()
  else:
    if not isinstance(e2, _DimExpr):
      return e1 * k1 + e2 * k2  # Constants
    c1 = op.index(e1)
    # Treat a 0 constant like an absent operand, as is done for `e2` below.
    e1_terms = ((_DimTerm_one, c1),) if c1 != 0 and k1 != 0 else ()
  if isinstance(e2, _DimExpr):
    e2_terms = e2._sorted_terms if k2 != 0 else ()
  elif e2 == 0:
    e2_terms = ()
  else:
    c2 = op.index(e2)
    e2_terms = ((_DimTerm_one, c2),) if k2 != 0 else ()
  new_terms = _DimExpr._linear_combination_sorted_pairs(e1_terms, 0, k1,
                                                        e2_terms, 0, k2)
  return _DimExpr._normalize_sorted_terms(new_terms, scope)
2023-09-05 22:15:22 -07:00
# We overload +, -, *, because they are fully defined for _DimExpr.
def __add__(self, other):
  """Returns `self + other`.

  Tracers and values not convertible to a symbolic dimension are handled by
  `self.__jax_array__().__add__(other)`.
  """
  if not isinstance(other, core.Tracer) and _convertible_to_poly(other):
    if isinstance(other, int) and other == 0:
      return self
    return _DimExpr._linear_combination(e1=self, k1=1, e2=other, k2=1,
                                        scope=self.scope)
  return self.__jax_array__().__add__(other)
2023-09-05 22:15:22 -07:00
def __radd__(self, other):
  """Returns `other + self`, staying symbolic whenever `other` allows it."""
  if not isinstance(other, core.Tracer) and _convertible_to_poly(other):
    if isinstance(other, int) and other == 0:
      return self  # Adding 0 is the identity.
    return _DimExpr._linear_combination(self, 1, other, 1, self.scope)
  # Tracers and non-convertible values: coerce `self` via __jax_array__.
  return self.__jax_array__().__radd__(other)
2023-09-05 22:15:22 -07:00
def __sub__(self, other):
  """Returns `self - other`, staying symbolic whenever `other` allows it."""
  if not isinstance(other, core.Tracer) and _convertible_to_poly(other):
    if isinstance(other, int) and other == 0:
      return self  # Subtracting 0 is the identity.
    return _DimExpr._linear_combination(self, 1, other, -1, self.scope)
  # Tracers and non-convertible values: coerce `self` via __jax_array__.
  return self.__jax_array__().__sub__(other)
2023-09-05 22:15:22 -07:00
def __rsub__(self, other):
  """Returns `other - self`, staying symbolic whenever `other` allows it."""
  if not isinstance(other, core.Tracer) and _convertible_to_poly(other):
    # other - self == (-1) * self + 1 * other
    return _DimExpr._linear_combination(self, -1, other, 1, self.scope)
  # Tracers and non-convertible values: coerce `self` via __jax_array__.
  return self.__jax_array__().__rsub__(other)
2023-09-05 22:15:22 -07:00
2024-02-15 13:53:05 +01:00
def __neg__(self) -> DimSize:
  """Returns `-self`, computed as the linear combination `(-1) * self + 0`."""
  return _DimExpr._linear_combination(e1=self, k1=-1, e2=0, k2=0,
                                      scope=self.scope)
2023-09-05 22:15:22 -07:00
def __mul__(self, other):
  """Returns `self * other`, staying symbolic whenever `other` allows it."""
  if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
    return self.__jax_array__().__mul__(other)
  if isinstance(other, int):
    # Fast paths for multiplication by an integer constant.
    if other == 1:
      return self
    if other == 0:
      return 0
    return _DimExpr._linear_combination(self, other, 0, 0, self.scope)
  rhs = _ensure_poly(other, "mul", self.scope)
  # Distribute the product over all pairs of terms, accumulating the
  # coefficients of terms that coincide.
  coeffs: dict[_DimTerm, int] = {}
  for (t1, c1), (t2, c2) in itertools.product(self._sorted_terms,
                                              rhs._sorted_terms):
    _DimExpr._add_coeff(coeffs, t1.mul(t2), c1 * c2)
  sorted_terms = _DimExpr._coeff_to_sorted_terms(coeffs)
  return _DimExpr._normalize_sorted_terms(sorted_terms, self.scope)
2023-09-05 22:15:22 -07:00
def __rmul__(self, other):
  """Returns `other * self`, staying symbolic whenever `other` allows it."""
  if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
    return self.__jax_array__().__rmul__(other)
  if not isinstance(other, int):
    # General case: convert `other` to an expression and multiply.
    return _ensure_poly(other, "mul", self.scope).__mul__(self)
  # Fast paths for multiplication by an integer constant.
  if other == 1:
    return self
  if other == 0:
    return 0
  return _DimExpr._linear_combination(self, other, 0, 0, self.scope)
2023-09-05 22:15:22 -07:00
def __pow__ ( self , power , modulo = None ) :
assert modulo is None
try :
power = int ( power )
except :
2023-09-22 14:54:31 -07:00
raise InconclusiveDimensionOperation ( f " Symbolic dimension cannot be raised to non-integer power ' { self } ' ^ ' { power } ' " )
2023-09-05 22:15:22 -07:00
return functools . reduce ( op . mul , [ self ] * power )
def __floordiv__(self, divisor):
  """Returns `self // divisor`, i.e. the quotient part of `_divmod`."""
  if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):
    return self.__jax_array__().__floordiv__(divisor)
  quotient, _ = self._divmod(divisor)
  return quotient
2023-09-05 22:15:22 -07:00
def __rfloordiv__(self, other):
  """Returns `other // self` by first converting `other` to an expression."""
  if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
    return self.__jax_array__().__rfloordiv__(other)
  dividend = _ensure_poly(other, "floordiv", self.scope)
  return dividend.__floordiv__(self)
2023-09-05 22:15:22 -07:00
def __truediv__ ( self , divisor ) :
# Used for "/", which always returns a float
return self . __jax_array__ ( ) . __truediv__ ( divisor )
def __rtruediv__ ( self , dividend ) :
# Used for "/", when dividend is not a _DimExpr
return self . __jax_array__ ( ) . __rtruediv__ ( dividend )
def __mod__(self, divisor):
  """Returns `self % divisor`, i.e. the remainder part of `_divmod`."""
  if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):
    return self.__jax_array__().__mod__(divisor)
  _, remainder = self._divmod(divisor)
  return remainder
2023-09-05 22:15:22 -07:00
def __rmod__(self, dividend):
  """Returns `dividend % self` by first converting `dividend` to an expression."""
  if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend):
    return self.__jax_array__().__rmod__(dividend)
  dividend_expr = _ensure_poly(dividend, "mod", self.scope)
  return dividend_expr.__mod__(self)
2023-09-05 22:15:22 -07:00
def __divmod__(self, divisor):
  """Returns `divmod(self, divisor)` as a `(quotient, remainder)` pair."""
  if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):
    return self.__jax_array__().__divmod__(divisor)
  quotient, remainder = self._divmod(divisor)
  return quotient, remainder
2023-09-05 22:15:22 -07:00
def __rdivmod__(self, dividend):
  """Returns `divmod(dividend, self)` by converting `dividend` to an expression."""
  if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend):
    return self.__jax_array__().__rdivmod__(dividend)
  dividend_expr = _ensure_poly(dividend, "divmod", self.scope)
  return dividend_expr.__divmod__(self)
2023-09-05 22:15:22 -07:00
def __int__(self):
  """Returns the constant value of the expression, or raises if it is symbolic."""
  constant = _DimExpr._to_constant(self)
  if constant is None:
    raise InconclusiveDimensionOperation(f"Symbolic dimension '{self}' used in a context that requires a constant")
  return constant
2023-09-05 22:15:22 -07:00
# We must overload __eq__ and __ne__, or else we get unsound defaults.
def __eq__(self, other: Any) -> bool:
  """Decides `self == other`, answering False when it cannot decide.

  Expressions from different scopes, and values that are neither a `_DimExpr`
  nor a constant dimension, are never equal. Otherwise the two are equal only
  when their difference normalizes to the integer 0.
  """
  if isinstance(other, _DimExpr):
    comparable = self.scope is other.scope
  else:
    comparable = core.is_constant_dim(other)
  if not comparable:
    return False
  # Equality is used very frequently because expressions are cached. We could
  # implement a more precise version based on `(self - other).bounds() = (0, 0)`
  # but that would be too expensive. It would also have the unfortunate drawback
  # that we cannot then cache `e.bounds()` because hashing invokes equality
  # which would lead to infinite recursion.
  diff = self - other
  # We look for `self - other == k`, relying on the fact that a _DimExpr that
  # represents an integer is normalized to an int.
  if is_symbolic_dim(diff):
    # Here we really ought to raise InconclusiveDimensionOperation, but __eq__
    # cannot raise exceptions, because it is used indirectly when hashing.
    # So, we say that the expressions are disequal, which is really unsound.
    # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported
    return False
  return diff == 0
2024-01-07 13:04:37 +02:00
def __ne__ ( self , other : Any ) - > bool :
return not self . __eq__ ( other )
2023-09-05 22:15:22 -07:00
2023-12-23 17:58:20 +07:00
def __ge__(self, other: DimSize) -> bool:
  """Decides `self >= other` by delegating to `_geq_decision`."""
  def describe() -> str:
    # Built lazily; only rendered if `_geq_decision` needs the text.
    return f"'{self}' >= '{other}'"
  return _geq_decision(self, other, describe)
2023-09-05 22:15:22 -07:00
def __le__(self, other: DimSize):
  """Decides `self <= other` as `other >= self` via `_geq_decision`."""
  def describe() -> str:
    # Built lazily; only rendered if `_geq_decision` needs the text.
    return f"'{self}' <= '{other}'"
  return _geq_decision(other, self, describe)
2023-09-05 22:15:22 -07:00
def __gt__(self, other: DimSize):
  """Decides `self > other` as `not (other >= self)` via `_geq_decision`."""
  def describe() -> str:
    # Built lazily; only rendered if `_geq_decision` needs the text.
    return f"'{self}' > '{other}'"
  other_geq_self = _geq_decision(other, self, describe)
  return not other_geq_self
2023-09-05 22:15:22 -07:00
def __lt__(self, other: DimSize):
  """Decides `self < other` as `not (self >= other)` via `_geq_decision`."""
  def describe() -> str:
    # Built lazily; only rendered if `_geq_decision` needs the text.
    return f"'{self}' < '{other}'"
  self_geq_other = _geq_decision(self, other, describe)
  return not self_geq_other
2023-09-05 22:15:22 -07:00
2024-02-20 23:13:20 +01:00
def _divmod(self, divisor: DimSize) -> tuple[DimSize, DimSize]:
  """
  Floor division with remainder (divmod) generalized to expressions.

  If the `divisor` is not a constant, the remainder must be 0.
  If the `divisor` is a constant, the remainder may be non 0, for consistency
  with integer divmod.

  If the division cannot be carried out exactly (a term whose coefficient is
  not divisible, or a non-zero remainder for a symbolic divisor), the result
  is instead the pair of symbolic `floordiv` and `mod` expressions.

  :return: Quotient resulting from polynomial division and the remainder:
    an integer on success, or a symbolic `mod` expression otherwise.
  """
  # The divisor does not change, so decide its kind once.
  divisor_is_expr = isinstance(divisor, _DimExpr)
  try:
    dividend, quotient = self, 0
    # invariant: self = dividend + divisor * quotient
    # quotient and dividend are changed in the loop; the leading term of
    # dividend decreases at each iteration.
    while is_symbolic_dim(dividend) and not dividend._is_constant:  # type: ignore[attribute-error,unused-ignore]
      mon, count = dividend._leading_term
      if divisor_is_expr:
        dterm, dcount = divisor._leading_term
        qterm = mon.divide(dterm)
      else:
        qterm, dcount = mon, int(divisor)
      qcount, rcount = divmod(count, dcount)
      if rcount != 0:
        raise InconclusiveDimensionOperation("")

      q = _DimExpr._from_term(qterm, qcount, self.scope)
      quotient += q
      dividend -= q * divisor

    dividend = int(dividend)  # type: ignore[assignment]
    if divisor_is_expr:
      # A symbolic divisor must divide exactly.
      if dividend != 0:
        raise InconclusiveDimensionOperation("")
      remainder = 0
    else:
      q, remainder = divmod(dividend, int(divisor))
      quotient += q

    if config.enable_checks.value:
      # Verify the invariant once: self == divisor * quotient + remainder.
      reconstructed = divisor * quotient + remainder
      assert self == _ensure_poly(reconstructed, "check", self.scope), (
          self, reconstructed, divisor, quotient, remainder)
    return quotient, remainder
  except InconclusiveDimensionOperation:
    return (_DimExpr._from_operation(_DimFactor.FLOORDIV, self, divisor,
                                     scope=self.scope),  # type: ignore
            _DimExpr._from_operation(_DimFactor.MOD, self, divisor,
                                     scope=self.scope))
2023-09-05 22:15:22 -07:00
2024-02-20 23:13:20 +01:00
def _evaluate(self, env: DimVarEnv):
  """Evaluates the expression with `env` supplying the dimension variables.

  Evaluates as a value of dtype=core.dim_value_dtype().
  """
  scope = self.scope
  products = [
      _evaluate_multiply(term.evaluate(env, scope), core.dim_constant(coeff))
      for term, coeff in self._sorted_terms
  ]
  if len(products) > 1:
    return functools.reduce(_evaluate_add, products)
  return products[0]
2024-01-15 15:02:33 +02:00
def max(self, other: DimSize) -> DimSize:
  """Returns `max(self, other)`, simplified when the bounds of `self - other` decide it."""
  lower, upper = _bounds_decision(self - other,
                                  BoundsPrecision.FOR_GEQ0_OR_LEQ0)
  if lower >= 0:  # self - other >= 0
    return self
  if upper <= 0:  # self - other <= 0
    return other
  return _DimExpr._from_operation(_DimFactor.MAX, self, other, scope=self.scope)
2023-12-13 10:14:27 +01:00
2024-01-15 15:02:33 +02:00
def rmax(self, other: DimSize) -> DimSize:
  """Returns `max(other, self)`, simplified when the bounds of `self - other` decide it."""
  lower, upper = _bounds_decision(self - other,
                                  BoundsPrecision.FOR_GEQ0_OR_LEQ0)
  if lower >= 0:  # self - other >= 0
    return self
  if upper <= 0:  # self - other <= 0
    return other
  return _DimExpr._from_operation(_DimFactor.MAX, other, self, scope=self.scope)
2023-12-13 10:14:27 +01:00
2024-01-15 15:02:33 +02:00
def min(self, other: DimSize) -> DimSize:
  """Returns `min(self, other)`, simplified when the bounds of `self - other` decide it."""
  lower, upper = _bounds_decision(self - other,
                                  BoundsPrecision.FOR_GEQ0_OR_LEQ0)
  if lower >= 0:  # self - other >= 0
    return other
  if upper <= 0:  # self - other <= 0
    return self
  return _DimExpr._from_operation(_DimFactor.MIN, self, other, scope=self.scope)
2023-12-13 10:14:27 +01:00
2024-01-15 15:02:33 +02:00
def rmin(self, other: DimSize) -> DimSize:
  """Returns `min(other, self)`, simplified when the bounds of `self - other` decide it."""
  lower, upper = _bounds_decision(self - other,
                                  BoundsPrecision.FOR_GEQ0_OR_LEQ0)
  if lower >= 0:  # self - other >= 0
    return other
  if upper <= 0:  # self - other <= 0
    return self
  return _DimExpr._from_operation(_DimFactor.MIN, other, self, scope=self.scope)
2023-09-05 22:15:22 -07:00
@staticmethod
def _get_aval(dim: _DimExpr):
  """Returns `core.dim_value_aval()`; the result does not depend on `dim`."""
  del dim  # Unused.
  return core.dim_value_aval()
def dimension_as_value(self):
  """Turns a dimension size into a Jax value that we can compute with."""
  value = _dim_as_value(self)
  return value
def __jax_array__(self):
  # Hook for implicit coercion of a symbolic dimension into a JAX array;
  # it produces the same value as `dimension_as_value`.
  value = _dim_as_value(self)
  return value
2024-02-12 23:35:35 -08:00
def __deepcopy__(self, memo):
  # Both the terms and the scope are copied through `memo`, so expressions
  # copied within the same `copy.deepcopy` call share one copied scope.
  terms_copy = copy.deepcopy(self._sorted_terms, memo)
  scope_copy = copy.deepcopy(self._scope, memo)
  return _DimExpr(terms_copy, scope_copy)
2024-02-20 23:13:20 +01:00
2024-01-05 14:48:53 +07:00
def cmp_comparable(i1, i2) -> int:
  """Three-way comparison: -1 if `i1 < i2`, 1 if `i1 > i2`, else 0."""
  if i1 < i2:
    return -1
  return 1 if i1 > i2 else 0
def cmp_sequence(s1, s2, elem_cmp) -> int:
  """Compares two sequences lexicographically using `elem_cmp`.

  `elem_cmp(a, b)` must return a negative, zero, or positive int. A sequence
  that is a proper prefix of the other compares as smaller.
  """
  n2 = len(s2)
  for idx, a in enumerate(s1):
    if idx >= n2:
      return 1  # s2 is a proper prefix of s1
    c = elem_cmp(a, s2[idx])
    if c:
      return c
  return -1 if len(s1) < n2 else 0
2024-01-01 23:09:42 +07:00
class SymbolicScope :
""" Indentifies a scope for symbolic expressions.
All symbolic expressions that interact ( e . g . , appear in the argument shapes
for one JAX function invocation , or are involved in arithmetic operations )
must be from the same scope and must share the same SymbolicScope object .
Holds the constraints on symbolic expressions .
2024-05-31 15:14:09 +03:00
See [ the README ] ( https : / / jax . readthedocs . io / en / latest / export / shape_poly . html #user-specified-symbolic-constraints)
2024-01-01 23:09:42 +07:00
for more details .
Args :
constraints_str : A sequence of constraints on symbolic dimension expressions ,
2024-02-20 23:13:20 +01:00
of the form ` e1 > = e2 ` or ` e1 < = e2 ` or ` e1 == e2 ` .
2024-01-01 23:09:42 +07:00
"""
[shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
2024-01-27 19:54:52 +01:00
2024-01-01 23:09:42 +07:00
def __init__(self,
             constraints_str: Sequence[str] = ()):
  # A plain string is itself a sequence of strings (its characters), so it
  # must be rejected explicitly.
  if isinstance(constraints_str, str):
    raise ValueError(
        "The symbolic constraints should be a sequence of strings. "
        f"Got {repr(constraints_str)}")
  # Flipped to True only once all explicit constraints are processed.
  self._initialized = False
  self._location_frame = source_info_util.user_frame(source_info_util.current())
  # Explicit constraints, in the order in which they were added.
  self._explicit_constraints: list[_SymbolicConstraint] = []
  # Cache of _DimExpr.bounds results together with the precision used to
  # compute them. The results depend only on the explicit and implicit
  # constraints, so the cache can live in the scope. It must exist before
  # the constraints are parsed below.
  self._bounds_cache: dict[_DimExpr,
                           tuple[float, float, BoundsPrecision]] = {}
  # Decision procedure state initialized with all the _explicit_constraints.
  self._decision_initial_state: Any | None = None
  # Equality constraints become normalization rules: for an explicit
  # constraint `t*tk == e` we keep `_normalization_rules[t] = (e, tk)`, and
  # when building expressions a term `t*tk1` with `tk1 % tk == 0` is
  # replaced by `e*(tk1 // tk)`.
  self._normalization_rules: NormalizationRules = {}
  for constraint in constraints_str:
    self._parse_and_process_explicit_constraint(constraint)
  # Discard the bounds cached while the constraints were being processed.
  self._bounds_cache.clear()
  self._initialized = True
2024-01-01 23:09:42 +07:00
def __str__(self) -> str:
  # Describe the scope by its identity and creation location; when explicit
  # constraints exist, list them after a header, one per line.
  if self._location_frame:
    where = source_info_util._summarize_frame(self._location_frame)
  else:
    where = "unknown"
  lines = []
  if self._explicit_constraints:
    lines.append(" with constraints:")
    lines.extend(f"  {constraint.debug_str}"
                 for constraint in self._explicit_constraints)
  return f"{id(self)} created at {where}" + "\n".join(lines)

__repr__ = __str__
[shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
2024-01-27 19:54:52 +01:00
def _parse_and_process_explicit_constraint(self, c_str: str):
  """Parses one explicit constraint string and records it in this scope.

  The comparison operator is located by searching `c_str` for "==", then
  ">=", then "<="; the text before and after it is parsed as a single
  dimension expression each. "e1 <= e2" is normalized to "e2 >= e1".

  Constraints whose two sides differ by a constant are checked immediately
  and, when satisfied, are not recorded. For an equality `term * k == e2`
  this also records the normalization rule
  `self._normalization_rules[term] = (e2, k)`. Every other constraint is
  appended to `self._explicit_constraints`.

  Args:
    c_str: the constraint, e.g., "a >= 3" or "mod(a, 2) == 1".

  Raises:
    ValueError: if `c_str` is not a string, contains none of the supported
      comparison operators, is unsatisfiable, or is an equality whose
      left-hand side is not of the form `term * coefficient`.
    NotImplementedError: if two equality constraints have the same
      left-hand-side term.
  """
  if not isinstance(c_str, str):
    raise ValueError(
        f"SymbolicScope constraint must be a string: got {repr(c_str)}")
  cmp_pos, cmp, is_geq = c_str.find("=="), Comparator.EQ, True
  if cmp_pos < 0:
    cmp_pos, cmp, is_geq = c_str.find(">="), Comparator.GEQ, True
  if cmp_pos < 0:
    cmp_pos, cmp, is_geq = c_str.find("<="), Comparator.GEQ, False
  if cmp_pos < 0:
    raise ValueError("Constraint parsing error: must contain one of '==' or '>=' or '<='")
  e1_str = c_str[:cmp_pos]
  e1, = _Parser(e1_str, None, repr(e1_str), self).parse()  # type: ignore[name-error,unused-ignore]
  # All supported operators are two characters long.
  e2_str = c_str[cmp_pos + 2:]
  e2, = _Parser(e2_str, None, repr(e2_str), self).parse()  # type: ignore[name-error,unused-ignore]
  # Normalize "e1 <= e2" into "e2 >= e1".
  if cmp == Comparator.GEQ and not is_geq:
    e1, e2 = e2, e1

  diff = e1 - e2
  if (diff_const := _DimExpr._to_constant(diff)) is not None:
    # The constraint is decidable right away: reject it if false, and
    # otherwise there is nothing to record.
    if ((cmp == Comparator.EQ and diff_const != 0) or
        (cmp == Comparator.GEQ and diff_const < 0)):
      raise ValueError(f"Unsatisfiable explicit constraint: {c_str}")
    return

  if cmp == Comparator.EQ:
    # Equalities become normalization rules, so the left-hand side must be a
    # single term with a coefficient.
    if not isinstance(e1, _DimExpr):
      raise ValueError(f"Invalid equality constraint: {e1} == {e2}. "
                       "The left-hand-side must be of the form `term * coefficient`.")
    (before, before_k), *rest = e1._sorted_terms
    if rest:
      raise ValueError(f"Invalid equality constraint: {e1} == {e2}. "
                       "The left-hand-side must be of the form `term * coefficient`.")

    after = _ensure_poly(e2, "parse_constraint", e1.scope)  # type: ignore[name-error,unused-ignore]
    if before in self._normalization_rules:
      raise NotImplementedError(
          f"Found multiple equality constraints with the same left-hand-side: {before}")
    self._normalization_rules[before] = (after, before_k)

  constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2)
  self._explicit_constraints.append(constr)
2024-01-01 23:09:42 +07:00
def _check_same_scope(self, other: _DimExpr,
                      when: str = "",
                      self_descr: str = " ",
                      other_descr: str = "unknown"):
  """Raises ValueError unless `other.scope` is this very scope object.

  Args:
    other: the symbolic expression whose scope is checked.
    when: describes the operation; inserted after "Invalid mixing of symbolic
      scopes" in the error message.
    self_descr: text placed right before "scope" when describing `self`.
    other_descr: parenthesized description of `other` in the error message.
  """
  if other.scope is self:
    return
  raise ValueError(
      f"Invalid mixing of symbolic scopes {when}.\n"
      f"Expected {self_descr}scope {self}\n"
      f"and found for '{other}' ({other_descr}) scope {other.scope}\n"
      f"See https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints.")
2024-01-01 23:09:42 +07:00
2024-02-05 05:37:34 -08:00
def _clear_caches(self):
  """Empties this scope's bounds cache in place."""
  bounds_cache = self._bounds_cache
  bounds_cache.clear()
2024-01-01 23:09:42 +07:00
2024-02-09 23:18:52 +01:00
class BoundsPrecision(enum.Enum):
  """How precise a bounds calculation must be before it may stop.

  Computing bounds is expensive, so a caller states the condition under which
  the bounds are already good enough. As the calculation progresses the lower
  bound is progressively increased and the upper bound progressively
  decreased; it may stop early once `_bounds_are_sufficient(lb, ub)` holds.

  The values are ordered so that bounds `(lb, ub)` that are sufficient for
  some precision are also sufficient for every smaller precision.
  """
  # Enough to statically evaluate "max(e1, e2)": stop once "lb >= 0 OR ub <= 0".
  FOR_GEQ0_OR_LEQ0 = 0
  # Enough to decide inequalities like "e1 >= e2": stop once "lb >= 0 OR ub < 0".
  FOR_GEQ0_OR_LT0 = 1
  # Never stop early; refine the bounds as much as possible.
  BEST = 2

  def _bounds_are_sufficient(self, lb: float, ub: float) -> bool:
    """Returns True if the bounds `(lb, ub)` satisfy this precision."""
    if self is BoundsPrecision.BEST:
      return False
    if lb >= 0:
      return True
    if self is BoundsPrecision.FOR_GEQ0_OR_LEQ0:
      return ub <= 0
    return ub < 0
[shape_poly] Add a decision procedure for inequalities.
In a previous PR (#19285) we added support for inequality
constraints on symbolic expressions, but with limited support
for the cases when a constraint contains more than one term,
e.g., "a >= b".
Here we add a simple decision procedure for such inequalities,
using an elimination algorithm based on the following properties:
* if we have two constraints "a + b >= 0" and "-a + c >= 0" we can
eliminate "a" and infer the derived constraint "b + c >= 0".
* the lower bound of "a + c", in presence of a constraint "a >= b",
is greater-or-equal to "b + c".
The above rules can be generalized to cases when the eliminated
terms have coefficients other than 1.
This algorithm is exponential in the number of constraints, but
we implement a limited form. When we add a constraint we combine
it with already added constraints, but the result of the combination
is not combined further. This is sufficient for the cases we
have encountered so far.
The termination of the algorithm is ensured by always eliminating
the largest (leading) term, ensuring that the result of a combination of
constraints has a smaller leading term.
With this added power for reasoning, we can retire the previous
heuristics for handling "min", "max", "floordiv" and "mod" and replace
them with the addition of some implicit constraints for them,
e.g., "max(a, b) >= a", etc., and then letting the decision procedure
do its job.
We moved the logic for deciding inequalities, to a new file: shape_poly_decision.py.
2024-01-20 08:47:52 +00:00
#
# Calling convention:
#   _bounds_decision(e, prec)
# returns a tuple `(lb, ub)` with a lower and an upper bound of the dimension
# size `e`. `prec` is a `BoundsPrecision`; the implementation may stop refining
# the bounds early once `prec._bounds_are_sufficient(lb, ub)` holds.
#
# The default below only raises. Presumably the real decision procedure
# (moved to shape_poly_decision.py) installs itself by assigning to
# `_bounds_decision` -- confirm against that module.
# TODO: remove this trampoline when we refactor the sources
def _bounds_decision_unimplemented(
    d: DimSize,
    prec: BoundsPrecision) -> tuple[float, float]:
  """Placeholder for `_bounds_decision`; always raises NotImplementedError."""
  del d, prec
  raise NotImplementedError("_bounds_decision is uninitialized")

_bounds_decision: Callable[[DimSize, BoundsPrecision],
                           tuple[float, float]] = _bounds_decision_unimplemented
2024-02-09 23:18:52 +01:00
def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
  """Implements `e1 >= e2`.

  Args:
    e1, e2: the dimension sizes to compare for greater-equal.
    cmp_str: a callable such that `cmp_str()` describes the comparison
      for error messages, e.g., "a <= b". Without this all comparisons would
      be reported as ">=". It is only called when an error is reported.

  Returns: the result of the comparison, when it can be decided.

  Raises:
    InconclusiveDimensionOperation: if the result is not conclusive.
    ValueError: if `e1` and `e2` are symbolic expressions from different
      symbolic scopes.
  """
  if isinstance(e1, _DimExpr):
    scope = e1.scope
    # Only build the comparison description when the scopes actually differ:
    # `cmp_str()` renders the expressions, which may be costly, and this
    # check runs on every symbolic comparison.
    if isinstance(e2, _DimExpr) and e2.scope is not scope:
      scope._check_same_scope(e2, f"when comparing {cmp_str()}")
  elif isinstance(e2, _DimExpr):
    scope = e2.scope
  else:
    return int(e1) >= int(e2)
  lb, ub = _bounds_decision(e1 - e2, BoundsPrecision.FOR_GEQ0_OR_LT0)
  if lb >= 0:
    return True
  if ub < 0:
    return False
  if scope._explicit_constraints:
    describe_scope = f"\nUsing symbolic scope {scope}"
  else:
    describe_scope = ""
  raise InconclusiveDimensionOperation(
      f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
2024-02-20 23:13:20 +01:00
# Compute abstract values of `_DimExpr` objects with `_DimExpr._get_aval`, in
# both the core and the xla aval-mapping tables.
for _aval_table in (core.pytype_aval_mappings, xla.pytype_aval_mappings):
  _aval_table[_DimExpr] = _DimExpr._get_aval
del _aval_table
# Add `_DimExpr` to the dtypes module's weak-type list (presumably so it is
# treated as weakly typed, like a Python int -- confirm in jax._src.dtypes).
dtypes._weak_types.append(_DimExpr)
def _convertible_to_int(p: DimSize) -> bool:
  """Returns True if `operator.index(p)` succeeds, i.e., `p` is integer-like.

  Objects implementing `__index__` qualify; floats and strings do not.
  """
  try:
    op.index(p)  # type: ignore
  except Exception:
    # Any ordinary failure of `__index__` means "not an integer". Catching
    # `Exception` (not a bare `except:`) lets KeyboardInterrupt, SystemExit
    # and other BaseExceptions propagate.
    return False
  return True
def _ensure_poly(p: DimSize,
                 operation_name: str,
                 scope: SymbolicScope) -> _DimExpr:
  """Returns `p` as a `_DimExpr` belonging to `scope`.

  Args:
    p: a `_DimExpr` (which must already belong to `scope`), or a value
      accepted by `operator.index`.
    operation_name: name of the operation, used in error messages.
    scope: the symbolic scope of the result.

  Raises:
    ValueError: if `p` is a `_DimExpr` from a different scope.
    TypeError: if `p` is neither a `_DimExpr` nor convertible to an int.
  """
  if not isinstance(p, _DimExpr):
    if not _convertible_to_int(p):
      raise TypeError(f"Symbolic dimension {operation_name} not supported for {p}.")
    # A constant is encoded as the term `_DimTerm_one` with coefficient `p`.
    return _DimExpr(((_DimTerm_one, op.index(p)),), scope)
  scope._check_same_scope(p, when=f"for operation {operation_name}")
  return p
2023-09-05 22:15:22 -07:00
def _convertible_to_poly(p: DimSize) -> bool:
  """Returns True if `p` is a `_DimExpr` or is convertible to an int."""
  if isinstance(p, _DimExpr):
    return True
  return _convertible_to_int(p)
2024-01-10 08:45:03 +02:00
def is_symbolic_dim(p: DimSize) -> bool:
  """Checks if a dimension is symbolic.

  Returns True exactly when `p` is a symbolic dimension expression
  (a `_DimExpr`), and False for anything else, e.g., a plain integer.
  """
  symbolic = isinstance(p, _DimExpr)
  return symbolic
2024-01-10 08:45:03 +02:00
def is_poly_dim(p: DimSize) -> bool:
  """Deprecated alias of `is_symbolic_dim` that emits a DeprecationWarning."""
  # TODO: deprecated January 2024, remove June 2024.
  warnings.warn("is_poly_dim is deprecated, use export.is_symbolic_dim",
                category=DeprecationWarning, stacklevel=2)
  return is_symbolic_dim(p)
2023-09-05 22:15:22 -07:00
# Give `_DimExpr` the same Python-scalar dtype entry as `int`.
dtypes.python_scalar_dtypes.update({_DimExpr: dtypes.python_scalar_dtypes[int]})
def _einsum_contract_path(*operands, **kwargs):
  """Like opt_einsum.contract_path, with support for DimExpr shapes.

  opt_einsum.contract_path is called on stand-in operands in which every
  symbolic dimension is replaced by the same constant, so it effectively just
  parses the specification and computes a schedule. Per the original design
  note, this is safe because an error is raised (elsewhere, not in this
  function) if there is more than one contraction.

  Returns:
    A pair `(contract_operands, contractions)`: the original `operands`
    reordered to match the operands returned by opt_einsum.contract_path,
    and the contractions it computed.
  """
  def _fake_dim(d):
    if core.is_constant_dim(d):
      return d
    if not isinstance(d, _DimExpr):
      raise TypeError(f"Encountered unexpected shape dimension {d}")
    # It is Ok to replace all polynomials with the same value. We may miss
    # here some errors due to non-equal dimensions, but we catch them later.
    return 8

  def _fake_operand(operand):
    # Only array operands (anything with a `dtype`) are replaced; the others
    # are passed through unchanged.
    if not hasattr(operand, "dtype"):
      return operand
    fake_shape = tuple(_fake_dim(d) for d in np.shape(operand))
    return jax.ShapeDtypeStruct(fake_shape, operand.dtype)

  fake_ops = [_fake_operand(operand) for operand in operands]
  contract_fake_ops, contractions = opt_einsum.contract_path(*fake_ops, **kwargs)

  # Map each returned stand-in back (by identity) to the original operand.
  contract_operands = []
  for fake in contract_fake_ops:
    positions = [i for i, candidate in enumerate(fake_ops) if candidate is fake]
    assert len(positions) == 1
    contract_operands.append(operands[positions[0]])
  return contract_operands, contractions


lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path
# To implement shape-constraint checking we use a shape assertion primitive.
#    shape_assertion_p.bind(assert_what: bool, *error_message_inputs,
#                           error_message="...{0}...{1}")
# where "{0}" refers to error_message_inputs[0], etc.
shape_assertion_p = core.Primitive("shape_assertion")
shape_assertion_p.multiple_results = True


def _shape_assertion_abstract_eval(*args, **params):
  # No outputs; the only effect is `shape_assertion_effect`, which is defined
  # later in this module and looked up only when this function runs.
  del args, params
  return (), {shape_assertion_effect}


shape_assertion_p.def_effectful_abstract_eval(_shape_assertion_abstract_eval)  # type: ignore
def _shape_assertion_lowering_rule(ctx: mlir.LoweringRuleContext,
                                   assert_what: mlir.ir.Value,
                                   *error_message_inputs: mlir.ir.Value,
                                   error_message: str):
  """Lowers `shape_assertion_p` to a side-effecting "shape_assertion" custom call.

  The error message is attached as a string attribute; the call has no results.
  """
  # Named `assertion_call` rather than `op` to avoid shadowing the module-level
  # `operator as op` alias.
  assertion_call = mlir.custom_call(
      "shape_assertion",
      result_types=[],  # No results
      operands=[assert_what, *error_message_inputs],
      has_side_effect=True,
      extra_attributes={"error_message": mlir.ir.StringAttr.get(error_message)},
  )
  return assertion_call.results


mlir.register_lowering(shape_assertion_p, _shape_assertion_lowering_rule)
class ShapeAssertionEffect(effects.Effect):
  """The effect produced by `shape_assertion_p`."""

  def __str__(self):
    return "ShapeAssertionEffect"


shape_assertion_effect = ShapeAssertionEffect()

# Allow the effect when lowering, inside control flow, under remat and inside
# custom derivatives.
for _effect_type_set in (effects.lowerable_effects,
                         effects.control_flow_allowed_effects,
                         effects.remat_allowed_effects,
                         effects.custom_derivatives_allowed_effects):
  _effect_type_set.add_type(ShapeAssertionEffect)
del _effect_type_set
def shape_assertion(assert_what: jax.Array,
                    *error_message_inputs: jax.Array,
                    error_message: str) -> None:
  """Adds a shape assertion in the code.

  Args:
    assert_what: a boolean asserted to be true. Must be computed based only
      on dimension expressions, so that it can be evaluated after shape
      refinement.
    error_message_inputs: integer expressions whose values can be referenced
      in the `error_message`. Must be computed based only on dimension
      expressions, so that they can be evaluated after shape refinement.
    error_message: an error message, possibly containing format specifiers
      {0}, {1}, ..., referencing the values of the `error_message_inputs`.
      The format specifiers are sometimes processed with Python's
      `str.format` method, and sometimes with `llvm::formatv`.
  """
  bind_params = {"error_message": error_message}
  shape_assertion_p.bind(assert_what, *error_message_inputs, **bind_params)
2023-09-05 22:15:22 -07:00
# A JAX primitive with no array arguments but with a dimension parameter
# that is a DimExpr. The value of the primitive is the value of the dimension,
# using int64 in x64 mode or int32 otherwise (core.dim_value_dtype())
dim_as_value_p = core.Primitive("dim_as_value")


def _dim_as_value_abstract_eval(dim):
  # The abstract value does not depend on `dim`.
  del dim
  return core.dim_value_aval()


dim_as_value_p.def_abstract_eval(_dim_as_value_abstract_eval)
def dim_as_value_impl(dim: DimSize):
  """Eager evaluation rule for `dim_as_value_p`; always raises NotImplementedError."""
  del dim
  raise NotImplementedError(
      "Evaluation rule for 'dim_as_value' is not implemented. "
      "It seems that you are using shape polymorphism outside jax.export.")


dim_as_value_p.def_impl(dim_as_value_impl)
def _dim_as_value(dim: DimSize):
  """Binds `dim_as_value_p` with `dim` as its (only) parameter."""
  value = dim_as_value_p.bind(dim=dim)
  return value
def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *,
                           dim):
  """Lowers `dim_as_value_p` by evaluating `dim` with `mlir.eval_dynamic_shape`.

  The value is converted to the IR type of the single output when the two
  types differ.
  """
  value, = mlir.eval_dynamic_shape(ctx, (dim,))
  expected_type = mlir.aval_to_ir_type(ctx.avals_out[0])
  if expected_type != value.type:  # type: ignore
    value = mlir.hlo.convert(expected_type, value)
  return [value]


mlir.register_lowering(dim_as_value_p, _dim_as_value_lowering)
class PolyShape(tuple):
  """Tuple of polymorphic dimension specifications.

  Deprecated: use string specifications for symbolic shapes instead.
  See docstring of :func:`jax2tf.convert`.

  Each element must be an `int`, a `str` representing a dimension variable,
  or `...`.
  """

  def __new__(cls, *dim_specs):
    # The deprecation warning is emitted only here (and not again in
    # __init__), so that constructing a PolyShape produces one warning.
    warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes",
                  DeprecationWarning, stacklevel=2)
    for ds in dim_specs:
      # Compare with `...` by identity: `!=` need not return a plain bool
      # for array-like elements.
      if not isinstance(ds, (int, str)) and ds is not ...:
        msg = (f"Invalid polymorphic shape element: {ds!r}; must be a string "
               "representing a dimension variable, or an integer, or ...")
        raise ValueError(msg)
    # Use `cls` (not `PolyShape`) so that subclasses get their own type.
    return super().__new__(cls, dim_specs)

  def __init__(self, *dim_specs):
    # Validation and the deprecation warning happen in __new__; the tuple
    # contents are already fixed at this point.
    del dim_specs
    super().__init__()

  def __str__(self):
    return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")"
2024-01-10 09:05:16 +02:00
def symbolic_shape(shape_spec: str | None,
                   *,
                   constraints: Sequence[str] = (),
                   scope: SymbolicScope | None = None,
                   like: Sequence[int | None] | None = None
                   ) -> Sequence[DimSize]:
  """Constructs a symbolic shape from a string representation.

  See https://jax.readthedocs.io/en/latest/export/shape_poly.html for examples.

  Args:
    shape_spec: a symbolic shape specification. None stands for "...".
      A shape specification is the string representation of a tuple (the
      parentheses are optional) with comma-separated dimension expressions.
      A dimension expression can be either: an integer constant,
      a dimension variable (alphanumeric
      starting with a letter), e1 + e2, e1 - e2, e1 * e2, floordiv(e1, e2),
      mod(e1, e2), max(e1, e2), or min(e1, e2).
    constraints: a sequence of constraints on symbolic dimension expressions, of
      the form `e1 >= e2` or `e1 <= e2`, or `e1 == e2`.
      See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints)
      for usage.
    scope: optionally, you can specify that the parsed symbolic expressions
      be created in the given scope. If this is missing, then a new
      `SymbolicScope` is created with the given `constraints`.
      You cannot specify both a `scope` and `constraints`.
      See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints)
      for usage.
    like: when `shape_spec` contains placeholders ("_", "..."), use this
      shape to fill in the placeholders.
      The dimensions of `like` that are used for filling
      must not be `None`. If a dimension in `like` is not `None` and
      the corresponding dimension in `shape_spec` is a constant then they
      must be equal.

  Returns: a tuple with integers or symbolic expressions involving dimension variables.

  Raises:
    ValueError: if `shape_spec` is neither None nor a string, or if both
      `scope` and `constraints` are given.
  """
  spec_repr = repr(shape_spec)
  if shape_spec is None:
    spec_str = "..."
  elif isinstance(shape_spec, PolyShape):  # TODO: deprecate
    spec_str = str(shape_spec)
  elif isinstance(shape_spec, str):
    spec_str = shape_spec
  else:
    raise ValueError("polymorphic shape spec should be None or a string. "
                     f"Found {spec_repr}.")
  if scope is not None and constraints:
    raise ValueError("Cannot specify both a `scope` and `constraints`.")
  if scope is None:
    scope = SymbolicScope(constraints)
  return _Parser(spec_str, like, spec_repr, scope).parse()
2023-09-05 22:15:22 -07:00
2024-01-10 09:44:31 +02:00
def symbolic_args_specs(
    args,  # pytree of arguments
    shapes_specs,  # prefix pytree of strings
    constraints: Sequence[str] = (),
    scope: SymbolicScope | None = None,
    symbolic_constraints: Sequence[str] = (),  # DEPRECATED on 6/14/24
    symbolic_scope: SymbolicScope | None = None,  # DEPRECATED on 6/14/24
):
  """Constructs a pytree of jax.ShapeDtypeStruct arguments specs for `export`.

  See the documentation of :func:`jax.export.symbolic_shape` and
  the [shape polymorphism documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html) for details.

  Args:
    args: a pytree of arguments. These can be jax.Array, or jax.ShapeDtypeStruct.
      They are used to learn the pytree structure of the arguments, their dtypes,
      and to fill-in the actual shapes where the `shapes_specs` contains
      placeholders. Note that only the shape dimensions for which
      `shapes_specs` is a placeholder are used from `args`.
    shapes_specs: should be `None` (all arguments have static shapes),
      a single string (see `shape_spec` for :func:`jax.export.symbolic_shape`;
      applies to all arguments), or a pytree matching a prefix
      of the `args`.
      See [how optional parameters are matched to
      arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
    constraints: as for :func:`jax.export.symbolic_shape`.
    scope: as for :func:`jax.export.symbolic_shape`.
    symbolic_constraints: DEPRECATED, use `constraints`.
    symbolic_scope: DEPRECATED, use `scope`.

  Returns: a pytree of jax.ShapeDtypeStruct matching the `args` with the shapes
    replaced with symbolic dimensions as specified by `shapes_specs`.

  Raises:
    ValueError: if a deprecated argument is combined with its replacement, or
      if both `scope` and `constraints` are given. If `shapes_specs` is not a
      prefix of `args`, the error built by `tree_util.prefix_errors` is raised.
  """
  if symbolic_constraints:
    warnings.warn("symbolic_constraints is deprecated, use constraints",
                  DeprecationWarning, stacklevel=2)
    if constraints:
      raise ValueError("Cannot use both symbolic_constraints and constraints")
    constraints = symbolic_constraints
  if symbolic_scope is not None:
    warnings.warn("symbolic_scope is deprecated, use scope",
                  DeprecationWarning, stacklevel=2)
    if scope is not None:
      raise ValueError("Cannot use both symbolic_scope and scope")
    scope = symbolic_scope

  args_flat, args_tree = tree_util.tree_flatten(args)
  shapes_and_dtypes = tuple(map(shape_and_dtype_jax_array, args_flat))
  # `arg_dtypes` (not `dtypes`) avoids shadowing the `dtypes` module.
  arg_shapes, arg_dtypes = util.unzip2(shapes_and_dtypes)

  polymorphic_shapes = shapes_specs
  if isinstance(args, tuple) and isinstance(polymorphic_shapes, list):
    # TODO: Remove backward-compatibility workaround
    polymorphic_shapes = tuple(polymorphic_shapes)

  try:
    polymorphic_shapes_flat = tree_util.broadcast_prefix(
        polymorphic_shapes, args,
        is_leaf=lambda x: x is None)
  except ValueError:
    e, *_ = tree_util.prefix_errors(
        polymorphic_shapes, args,
        is_leaf=lambda x: x is None)
    raise e("export.symbolic_args_specs shapes_specs") from None

  # Now add in the polymorphic shapes
  if scope is None:
    scope = SymbolicScope(constraints)
  elif constraints:
    raise ValueError("Cannot use both `scope` and `constraints`")
  args_specs_flat = (
      jax.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=scope), t)
      for s, t, spec in zip(arg_shapes, arg_dtypes, polymorphic_shapes_flat))
  return args_tree.unflatten(args_specs_flat)
def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
  """Returns the shape and dtype of a `jax.Array` or a `jax.ShapeDtypeStruct`.

  For a `jax.ShapeDtypeStruct` the `shape` and `dtype` attributes are returned
  directly. For any other value, the shape and dtype are read from its
  abstract value, obtained with `core.get_aval` and `core.raise_to_shaped`.
  """
  if isinstance(a, jax.ShapeDtypeStruct):
    return a.shape, a.dtype
  aval = core.raise_to_shaped(core.get_aval(a))
  return aval.shape, aval.dtype
2024-02-20 23:13:20 +01:00
2023-09-05 22:15:22 -07:00
class _Parser :
def __init__ ( self ,
shape_spec : str ,
2023-12-11 13:59:29 +00:00
like_shape : Sequence [ int | None ] | None ,
2024-01-01 23:09:42 +07:00
shape_spec_repr : str ,
scope : SymbolicScope ) :
2023-09-05 22:15:22 -07:00
self . shape_spec = shape_spec
self . shape_spec_repr = shape_spec_repr # For error messages
2023-11-23 09:05:37 +02:00
self . like_shape = like_shape
2023-09-05 22:15:22 -07:00
self . dimensions : list [ DimSize ] = [ ] # dimensions we have parsed
2024-01-01 23:09:42 +07:00
self . scope = scope
2023-09-05 22:15:22 -07:00
def parse ( self ) - > Sequence [ DimSize ] :
self . tokstream = tokenize . tokenize (
io . BytesIO ( self . shape_spec . encode ( " utf-8 " ) ) . readline )
tok = self . consume_token ( self . next_tok ( ) , tokenize . ENCODING ) # Always 1st
sh , tok = self . shape ( tok )
self . expect_token ( tok , [ tokenize . ENDMARKER ] )
return sh
2023-12-11 13:59:29 +00:00
def add_dim ( self , expr : DimSize | None , tok : tokenize . TokenInfo ) :
2023-09-05 22:15:22 -07:00
if expr is None :
raise self . parse_err ( tok ,
2023-11-23 09:05:37 +02:00
( " unexpected placeholder for unknown dimension; "
f " like= { self . like_shape } " ) )
if core . is_constant_dim ( expr ) and self . like_shape is not None :
like_shape_dim = self . like_shape [ len ( self . dimensions ) ]
2024-05-17 09:46:36 +01:00
if expr != like_shape_dim :
2023-09-05 22:15:22 -07:00
raise self . parse_err ( tok ,
2023-11-23 09:05:37 +02:00
( f " different size { expr } for known dimension; "
f " like= { self . like_shape } " ) )
2023-09-05 22:15:22 -07:00
self . dimensions . append ( expr )
2023-12-11 13:59:29 +00:00
def parse_err ( self , tok : tokenize . TokenInfo | None , detail : str ) - > Exception :
2023-09-05 22:15:22 -07:00
msg = (
2023-11-23 09:05:37 +02:00
f " syntax error in symbolic shape { self . shape_spec_repr } "
2023-09-05 22:15:22 -07:00
f " in dimension { len ( self . dimensions ) } : { detail } . " )
if tok is not None :
msg + = f " Parsed ' { tok . line [ : tok . start [ 1 ] ] } ' , remaining ' { tok . line [ tok . start [ 1 ] : ] } ' . "
return ValueError ( msg )
def next_tok ( self ) - > tokenize . TokenInfo :
while True :
try :
2024-06-04 22:02:36 -07:00
t = next ( self . tokstream ) # type: ignore[attribute-error,unused-ignore]
2023-09-05 22:15:22 -07:00
except StopIteration :
raise self . parse_err ( None , " unexpected end of string " )
if t . exact_type not in [ tokenize . NEWLINE , tokenize . INDENT , tokenize . DEDENT ] :
return t
def expect_token ( self , tok : tokenize . TokenInfo , expected : Sequence [ int ] ) - > None :
if tok . exact_type not in expected :
msg = ( " expecting one of { " +
" , " . join ( tokenize . tok_name [ t ] for t in expected ) + " } but found " +
tokenize . tok_name [ tok . exact_type ] )
raise self . parse_err ( tok , msg )
def consume_token ( self , tok : tokenize . TokenInfo , expected : int ) - > tokenize . TokenInfo :
self . expect_token ( tok , [ expected ] )
return self . next_tok ( )
def integer ( self , tok : tokenize . TokenInfo ) - > tuple [ int , tokenize . TokenInfo ] :
self . expect_token ( tok , [ tokenize . NUMBER ] )
try :
val = int ( tok . string )
except Exception :
raise self . parse_err ( tok , f " expecting integer, found { tok . string } " )
return val , self . next_tok ( )
# What can follow a shape?
FOLLOW_SHAPE = [ tokenize . ENDMARKER , tokenize . RPAR ]
def shape ( self , tok : tokenize . TokenInfo ) - > tuple [ Sequence [ DimSize ] , tokenize . TokenInfo ] :
# A comma-separated list of _DimExpr, or "_", possibly ended with ...
if tok . exact_type == tokenize . LPAR :
res , tok = self . shape ( self . next_tok ( ) )
tok = self . consume_token ( tok , tokenize . RPAR )
return res , tok
while True :
if tok . exact_type in self . FOLLOW_SHAPE :
break
2023-11-23 09:05:37 +02:00
# Error checking in presence of placeholders
if ( tok . exact_type == tokenize . ELLIPSIS or
tok . exact_type == tokenize . NAME and tok . string == " _ " ) :
if self . like_shape is None :
raise self . parse_err ( tok ,
" spec contains ... but no ' like ' shape was given " )
if tok . exact_type == tokenize . ELLIPSIS :
min_len_like_shape = len ( self . dimensions )
else :
min_len_like_shape = len ( self . dimensions ) + 1
if len ( self . like_shape ) < min_len_like_shape :
raise self . parse_err (
tok ,
f " cannot resolve placeholder ' { tok . string } ' because we parsed "
f " { len ( self . dimensions ) } already and ' like ' shape has "
f " only { len ( self . like_shape ) } dimensions " )
2023-09-05 22:15:22 -07:00
if tok . exact_type == tokenize . ELLIPSIS :
2023-11-23 09:05:37 +02:00
to_add = self . like_shape [ len ( self . dimensions ) : ] # type: ignore[index]
2023-09-05 22:15:22 -07:00
for ad in to_add :
self . add_dim ( ad , tok )
tok = self . next_tok ( )
break
2023-11-23 09:05:37 +02:00
2023-09-05 22:15:22 -07:00
if tok . exact_type == tokenize . NAME and tok . string == " _ " :
2023-11-23 09:05:37 +02:00
e = self . like_shape [ len ( self . dimensions ) ] # type: ignore[index]
2023-09-05 22:15:22 -07:00
tok = self . next_tok ( )
else :
e , tok = self . expr ( tok )
self . add_dim ( e , tok )
if tok . exact_type in self . FOLLOW_SHAPE :
break
tok = self . consume_token ( tok , tokenize . COMMA )
return tuple ( self . dimensions ) , tok
# What token can follow a _DimExpr
FOLLOW_EXPR = FOLLOW_SHAPE + [ tokenize . COMMA ]
def expr ( self , tok : tokenize . TokenInfo ) - > tuple [ DimSize , tokenize . TokenInfo ] :
2024-02-20 23:13:20 +01:00
# A sum of terms
next_t_negated = ( tok . exact_type == tokenize . MINUS )
if next_t_negated :
2024-02-14 11:34:40 +02:00
tok = self . next_tok ( )
elif tok . exact_type == tokenize . PLUS :
tok = self . next_tok ( )
2024-02-03 06:38:01 +02:00
acc = None
2023-09-05 22:15:22 -07:00
while True :
2024-02-20 23:13:20 +01:00
t , tok = self . term ( tok )
t_sign = - t if next_t_negated else t
2024-06-04 22:02:36 -07:00
acc = acc + t_sign if acc is not None else t_sign # type: ignore[operator]
2023-09-05 22:15:22 -07:00
if tok . exact_type in self . FOLLOW_EXPR :
return acc , tok
2024-02-20 23:13:20 +01:00
next_t_negated = ( tok . exact_type == tokenize . MINUS )
2023-09-05 22:15:22 -07:00
self . expect_token ( tok , [ tokenize . PLUS , tokenize . MINUS ] )
tok = self . next_tok ( )
2024-02-20 23:13:20 +01:00
FOLLOW_TERM = FOLLOW_EXPR + [ tokenize . PLUS , tokenize . MINUS ]
def term ( self , tok : tokenize . TokenInfo ) - > tuple [ DimSize , tokenize . TokenInfo ] :
# A term is product of factors. Each factor may be raised to an integer power.
2024-02-03 06:38:01 +02:00
acc = None
2023-09-05 22:15:22 -07:00
while True :
2024-02-20 23:13:20 +01:00
f , tok = self . factor ( tok )
2023-09-05 22:15:22 -07:00
if tok . exact_type == tokenize . CIRCUMFLEX :
tok = self . next_tok ( )
self . expect_token ( tok , [ tokenize . NUMBER ] )
power , tok = self . integer ( tok )
2024-02-20 23:13:20 +01:00
f = f * * power
2023-09-05 22:15:22 -07:00
2024-02-20 23:13:20 +01:00
acc = acc * f if acc is not None else f # type: ignore[operator]
if tok . exact_type in self . FOLLOW_TERM :
2024-06-04 22:02:36 -07:00
return acc , tok # type: ignore[bad-return-type,unused-ignore]
2023-09-05 22:15:22 -07:00
tok = self . consume_token ( tok , tokenize . STAR )
2024-02-20 23:13:20 +01:00
def factor ( self , tok : tokenize . TokenInfo ) - > tuple [ DimSize , tokenize . TokenInfo ] :
2023-09-05 22:15:22 -07:00
if tok . exact_type == tokenize . NAME :
2024-02-20 23:13:20 +01:00
if tok . string in ( _DimFactor . MOD , _DimFactor . FLOORDIV , _DimFactor . MAX , _DimFactor . MIN ) :
return self . factor_binary_op ( tok . string , self . next_tok ( ) )
if tok . string == _DimFactor . NON_NEGATIVE : # We still parse this for backwards compatibility
return self . factor_unary_op ( _DimFactor . NON_NEGATIVE , self . next_tok ( ) )
return _DimExpr . _from_var ( tok . string , self . scope ) , self . next_tok ( )
2023-09-05 22:15:22 -07:00
number_sign = 1
if tok . exact_type == tokenize . MINUS : # -k are negative constants
number_sign = - 1
tok = self . next_tok ( )
self . expect_token ( tok , [ tokenize . NUMBER ] )
if tok . exact_type == tokenize . NUMBER :
v , tok = self . integer ( tok )
return v * number_sign , tok
self . expect_token ( tok , [ tokenize . NAME , tokenize . MINUS , tokenize . NUMBER ] )
assert False
2024-02-20 23:13:20 +01:00
def factor_unary_op ( self , op : str , tok : tokenize . TokenInfo ) - > tuple [ DimSize , tokenize . TokenInfo ] :
2023-09-05 22:15:22 -07:00
tok = self . consume_token ( tok , tokenize . LPAR )
e1 , tok = self . expr ( tok )
tok = self . consume_token ( tok , tokenize . RPAR )
2024-02-20 23:13:20 +01:00
return _DimExpr . _from_operation ( op , e1 ,
2024-05-17 09:46:36 +01:00
scope = self . scope ) , tok
2023-09-05 22:15:22 -07:00
2024-02-20 23:13:20 +01:00
def factor_binary_op ( self , op : str , tok ) - > tuple [ DimSize , tokenize . TokenInfo ] :
2023-09-05 22:15:22 -07:00
tok = self . consume_token ( tok , tokenize . LPAR )
e1 , tok = self . expr ( tok )
tok = self . consume_token ( tok , tokenize . COMMA )
e2 , tok = self . expr ( tok )
tok = self . consume_token ( tok , tokenize . RPAR )
2024-02-20 23:13:20 +01:00
if op == _DimFactor . MAX :
[shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
2024-01-27 19:54:52 +01:00
return core . max_dim ( e1 , e2 ) , tok
2024-02-20 23:13:20 +01:00
if op == _DimFactor . MIN :
[shape_poly] Add limited support for equality explicit constraints.
Previously, we have added support for inequality constraints. Now
we also add equality constraints. These are useful when encountering
errors due to inability to check equalities of symbolic expressions,
e.g., in the broadcasting rules. For example, the current decision
procedure cannot decide that `mod(mod(a, 2), 2) == mod(a, 2)`.
To work around such limitations, it is now possible to add
the above as an equality constraint.
Like other explicit constraints, this will be used to decide equalities during staging, and will be checked at shape refinement time.
See more details in the README.md changes.
2024-01-27 19:54:52 +01:00
return core . min_dim ( e1 , e2 ) , tok
2024-02-20 23:13:20 +01:00
return _DimExpr . _from_operation ( op , e1 , e2 ,
2024-05-17 09:46:36 +01:00
scope = self . scope ) , tok
2023-09-05 22:15:22 -07:00
def _evaluate_add ( v1 , v2 ) :
try :
if op . index ( v1 ) == 0 :
return v2
except :
pass
try :
if op . index ( v2 ) == 0 :
return v1
except :
pass
return v1 + v2
def _evaluate_multiply ( v1 , v2 ) :
try :
if op . index ( v1 ) == 1 :
return v2
except :
pass
try :
if op . index ( v2 ) == 1 :
return v1
except :
pass
return v1 * v2
# dimension_size(operand, dimension=i) get the operand.shape[i] as a
# value of type shape_poly.dim_as_value_dtype().
dimension_size_p = core . Primitive ( " dimension_size " )
def _dimension_size_abstract_eval ( aval : core . AbstractValue , * * _ ) - > core . AbstractValue :
return core . dim_value_aval ( )
dimension_size_p . def_abstract_eval ( _dimension_size_abstract_eval )
def _dimension_size_impl ( arg , * , dimension ) :
return core . dim_constant ( arg . shape [ dimension ] )
dimension_size_p . def_impl ( _dimension_size_impl )
def _dimension_size_lowering_rule(ctx, arg, *, dimension):
  """Lowers `dimension_size_p` with `mlir.hlo.get_dimension_size`.

  The result is converted to the IR type of `core.dim_value_aval()` when the
  type produced by `get_dimension_size` differs from it.
  """
  size_val = mlir.hlo.get_dimension_size(arg, dimension)
  wanted_type = mlir.aval_to_ir_type(core.dim_value_aval())
  if size_val.type == wanted_type:
    return [size_val]
  return [mlir.hlo.convert(wanted_type, size_val)]


mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule)
2024-06-11 13:46:44 +02:00
def all_dim_vars(args_avals: Sequence[core.ShapedArray]) -> Sequence[str]:
  """Returns the sorted names of the dimension variables used in `args_avals`.

  Only symbolic dimensions (per `is_symbolic_dim`) contribute variables;
  each name appears once in the result.
  """
  var_names: set[str] = set()
  for aval in args_avals:
    for dim in filter(is_symbolic_dim, aval.shape):
      var_names.update(dim._get_vars())
  return sorted(var_names)
2023-09-05 22:15:22 -07:00
2024-11-09 11:10:16 +02:00
class ShapeEvaluator:
  """Evaluates dimension sizes in an environment of dimension-variable values."""

  def __init__(self, env: DimVarEnv):
    # Maps dimension-variable names to their values.
    self.env = env

  def evaluate(self, e: DimSize):
    """Returns `e` as an `int` if it is constant, else evaluates it in `env`."""
    if not core.is_constant_dim(e):
      return e._evaluate(self.env)  # type: ignore
    return op.index(e)  # type: ignore
@dataclasses.dataclass(frozen=True)
class ShapeConstraint:
  """A comparison between two dimension sizes, with an error message.

  `comp` selects the relation: `Comparator.EQ` means `left == right` and
  `Comparator.GEQ` means `left >= right`.
  """
  comp: Comparator
  left: DimSize
  right: DimSize
  # `error_message_pieces` is a list of strings and DimSize. The error message
  # is formed by evaluating the DimSize and concatenating the sequence.
  error_message_pieces: Sequence[str | DimSize]

  def _holds(self, left, right, eq_fn, ge_fn):
    """Applies `eq_fn` or `ge_fn` to `(left, right)`, as selected by `comp`."""
    if self.comp == Comparator.EQ:
      return eq_fn(left, right)
    if self.comp == Comparator.GEQ:
      return ge_fn(left, right)
    assert False

  def check_statically(self, eval: ShapeEvaluator) -> None:
    """Evaluates the constraint statically; raises ValueError if it fails.

    A comparison that raises InconclusiveDimensionOperation is also reported
    as a failure, chained to that exception.
    """
    lhs, rhs = eval.evaluate(self.left), eval.evaluate(self.right)
    try:
      satisfied = self._holds(lhs, rhs, op.eq, op.ge)
    except InconclusiveDimensionOperation as e:
      raise self.make_error(eval) from e
    if not satisfied:
      raise self.make_error(eval)

  def compute(self, eval: ShapeEvaluator) -> jax.Array | None:
    """Computes whether the constraint is satisfied.

    When both sides evaluate to constants, the check is done statically: it
    returns None on success and raises ValueError on failure. Otherwise it
    returns the `lax.eq`/`lax.ge` value representing whether the constraint
    holds.
    """
    lhs, rhs = eval.evaluate(self.left), eval.evaluate(self.right)
    if not core.is_constant_shape((lhs, rhs)):
      return self._holds(lhs, rhs, lax.eq, lax.ge)
    if not self._holds(op.index(lhs), op.index(rhs), op.eq, op.ge):
      raise self.make_error(eval)
    return None

  def __str__(self):
    relation = "==" if self.comp == Comparator.EQ else ">="
    return f"{self.left} {relation} {self.right} ({self.error_message_pieces})"

  __repr__ = __str__

  def error_message_and_inputs(
      self,
      eval: ShapeEvaluator) -> tuple[str, Sequence[Any]]:
    """Forms the error message format string and its inputs.

    Each non-string piece becomes a `{i}` format specifier. Equal pieces share
    a specifier, and the evaluated value is appended to the inputs. See
    shape_assertion.
    """
    # The shape assertion checker supports at most 32 error_message_inputs;
    # beyond that, pieces are rendered as "N/A".
    max_error_message_inputs = 32
    spec_for_piece: dict[DimSize, str] = {}
    inputs: list[Any] = []
    parts: list[str] = []
    for piece in self.error_message_pieces:
      if isinstance(piece, str):
        parts.append(piece)
      elif (known_spec := spec_for_piece.get(piece)) is not None:
        parts.append(known_spec)
      elif len(inputs) >= max_error_message_inputs:
        parts.append("N/A")
      else:
        new_spec = "{" + str(len(inputs)) + "}"
        spec_for_piece[piece] = new_spec
        parts.append(new_spec)
        inputs.append(eval.evaluate(piece))
    return ("".join(parts), inputs)

  def make_error(self, eval: ShapeEvaluator) -> Exception:
    """Returns (does not raise) a ValueError with the formatted error message."""
    fmt, fmt_inputs = self.error_message_and_inputs(eval)
    return ValueError(fmt.format(*fmt_inputs))
class ShapeConstraints:
  """An ordered collection of `ShapeConstraint`s."""

  def __init__(self):
    self.constraints: list[ShapeConstraint] = []

  def add_constraint(self,
                     comp: Comparator,
                     left: DimSize, right: DimSize,
                     error_message_pieces: Sequence[str | DimSize]):
    """Appends the constraint `left comp right` with its error message pieces."""
    self.constraints.append(
        ShapeConstraint(comp, left, right, error_message_pieces))

  def check_statically(self, eval: ShapeEvaluator) -> None:
    """Checks every constraint statically, in insertion order.

    Raises the ValueError of the first constraint that fails.
    """
    for c in self.constraints:
      c.check_statically(eval)

  def shape_assertions(self, eval: ShapeEvaluator) -> None:
    """Emits a `shape_assertion` for each constraint not resolved statically.

    Constraints are processed in insertion order, the same order as
    `check_statically`, so errors are reported consistently. A constraint that
    is resolved statically and fails raises immediately.
    """
    for c in self.constraints:
      verdict = c.compute(eval)
      if verdict is not None:  # None means it was resolved statically.
        err_msg, err_inputs = c.error_message_and_inputs(eval)
        shape_assertion(verdict, *err_inputs, error_message=err_msg)
@dataclasses.dataclass
class _DimEquation :
# Encodes that `aval_dim_expr`, which is a symbolic expressions containing
# unknown dimension variables from the abstract values, is the specification
# for dimension named `dim_name` (e.g., "args[0].field.shape[2]").
aval_dim_expr : _DimExpr
dim_name : str
def __str__ ( self ) :
return f " Dimension size of { self . dim_name } with specification ' { self . aval_dim_expr } ' "
__repr__ = __str__
def args_kwargs_path_to_str(path: tree_util.KeyPath) -> str:
  """Renders a key path into the tuple `(args, kwargs)` as a readable string.

  The first key selects `args` (index 0) or `kwargs` (index 1). The remaining
  keys are rendered with `tree_util.keystr`, e.g. "args[0]" or "kwargs['x']".
  """
  head, rest = path[0], path[1:]
  for idx, prefix in enumerate(("args", "kwargs")):
    if head == tree_util.SequenceKey(idx):
      return prefix + tree_util.keystr(rest)
  assert False
@functools.lru_cache(128)
def _cached_pretty_print_dimension_descriptor(
    args_kwargs_tree: tree_util.PyTreeDef,
    flat_arg_idx: int) -> str:
  """Returns the readable path (e.g. "args[0]") of flat leaf `flat_arg_idx`.

  A placeholder tree with the structure `args_kwargs_tree` is built so that
  the key path of each leaf can be recovered. Results are memoized in an LRU
  cache of size 128.
  """
  placeholder_tree = args_kwargs_tree.unflatten((0,) * args_kwargs_tree.num_leaves)
  leaves_with_paths, _ = tree_util.tree_flatten_with_path(placeholder_tree)
  key_path, _ = leaves_with_paths[flat_arg_idx]
  return args_kwargs_path_to_str(key_path)
def pretty_print_dimension_descriptor(
    args_kwargs_tree: tree_util.PyTreeDef,
    flat_arg_idx: int, dim_idx: int | None) -> str:
  """Describes an argument, or one of its dimensions when `dim_idx` is given.

  For example "args[0]" or, with `dim_idx=1`, "args[0].shape[1]".
  """
  arg_desc = _cached_pretty_print_dimension_descriptor(args_kwargs_tree, flat_arg_idx)
  return arg_desc if dim_idx is None else f"{arg_desc}.shape[{dim_idx}]"
@util.cache()
def solve_dim_vars(
    args_avals: Sequence[core.ShapedArray],
    args_kwargs_tree: tree_util.PyTreeDef,
) -> tuple[DimVarEnv, ShapeConstraints, Sequence[tuple[str, int, int]]]:
  """Solves dimension variables in a called function's avals in terms of actual argument shapes.

  For example, given:

     args_avals = [ShapedArray((3, a, a + b), f32)]

  we introduce fresh "synthetic" dimension variables to represent the actual
  dimension size of actual arguments for each non-constant dimension.
  Each synthetic variable has a name, an arg_idx, and a dim_idx, e.g.:

    synthetic_vars = [("args[0].shape[1]", 0, 1), ("args[0].shape[2]", 0, 2)]

  and then we express the solution for the unknown dimension variables {a, b}
  as symbolic expressions in terms of the synthetic variables:

    dict(a=args[0].shape[1], b=args[0].shape[2] - args[0].shape[1])

  Not all equations are solvable. For now, we first solve the linear
  uni-variate equations, then the solved variables are used to simplify the
  remaining equations to linear uni-variate equations, and the process
  continues until all dimension variables are solved.

  Args:
    args_avals: the abstract values of the `args`, with shapes that may
      include unknown dimension variables. NOTE(review): because of
      `util.cache` this presumably must be hashable (callers in this file
      pass a tuple).
    args_kwargs_tree: a PyTreeDef that describes the tuple `(args, kwargs)`
      from which the flat sequence `args_avals` is extracted. Used for
      describing args and kwargs in synthetic variable names and in
      error messages.

  Returns: a 3-tuple with:
    (a) the solution for the unknown dimension variables;
    (b) a `ShapeConstraints` with the constraints that must be satisfied for
        the solution to be a valid one;
    (c) the list of synthetic variables, as `(name, arg_idx, dim_idx)`, that
        may appear in the solution and the constraints.

  Raises ValueError if it cannot solve some dimension variable.
  """
  dim_equations: list[_DimEquation] = []
  synth_dimension_vars: list[tuple[str, int, int]] = []
  # tuples with argument name and its polymorphic shape ('args[0]', '(a, a + b)')
  polymorphic_shape_specs: list[tuple[str, str]] = []
  for arg_idx, aval in enumerate(args_avals):
    # Arguments with only constant dimensions contribute no equations.
    if all(not is_symbolic_dim(d) for d in aval.shape):
      continue
    polymorphic_shape_specs.append(
        (pretty_print_dimension_descriptor(args_kwargs_tree, arg_idx, None),
         str(aval.shape)))
    for dim_idx, aval_d in enumerate(aval.shape):
      if is_symbolic_dim(aval_d):
        synth_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree,
                                                          arg_idx, dim_idx)
        synth_dimension_vars.append((synth_dim_var, arg_idx, dim_idx))
        dim_equations.append(
            _DimEquation(aval_dim_expr=aval_d,
                         dim_name=synth_dim_var))

  solution, shape_constraints = _solve_dim_equations(dim_equations,
                                                     polymorphic_shape_specs)
  return solution, shape_constraints, synth_dimension_vars
def compute_dim_vars_from_arg_shapes(
    args_avals: Sequence[core.ShapedArray],
    *actual_args: jax.Array,
    args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]:
  """Computes values of dimension variables to unify args_avals with actual arguments.

  This is the same solving as `solve_dim_vars`, but each synthetic variable is
  replaced by the `dimension_size_p` of the corresponding `actual_args`
  dimension, so the solution is expressed as JAX values. It also emits the
  shape assertions for the constraints that are not resolved statically.

  Returns: the values of the dimension variables, in the order determined by
  `all_dim_vars(args_avals)`.
  """
  var_order = all_dim_vars(args_avals)
  solution, shape_constraints, synth_dim_vars = solve_dim_vars(
      tuple(args_avals), args_kwargs_tree=args_kwargs_tree)

  # Bind every synthetic variable to the dynamic size of the actual argument.
  synthetic_env: DimVarEnv = {}
  for vname, arg_idx, dim_idx in synth_dim_vars:
    synthetic_env[vname] = dimension_size_p.bind(actual_args[arg_idx],
                                                 dimension=dim_idx)
  synthetic_eval = ShapeEvaluator(synthetic_env)

  shape_constraints.shape_assertions(synthetic_eval)
  return tuple(synthetic_eval.evaluate(solution[v]) for v in var_order)
def _solve_dim_equations(
    eqns: list[_DimEquation],
    polymorphic_shape_specs: Sequence[tuple[str, str]]
) -> tuple[DimVarEnv, ShapeConstraints]:
  """Solves the dimension equations for all the dimension variables.

  Makes repeated passes over `eqns`. An equation is used in a pass when, after
  substituting the variables solved so far, it is linear in at most one
  still-unknown variable. Passes stop when no equations are left (success) or
  when a pass makes no progress (failure).

  Args:
    eqns: the equations `aval_dim_expr = dim_name` to solve.
    polymorphic_shape_specs: pairs `(arg_name, arg_spec)`; used only to build
      error messages.

  Returns:
    A pair `(shape_env, shape_constraints)`. `shape_env` maps each solved
    dimension variable to its value, expressed in terms of the dimension-size
    names. `shape_constraints` accumulates the checks to perform on the actual
    shapes:
    - the division remainders are 0;
    - the solved variables are >= 1;
    - fully-resolved equations are consistent;
    - the explicit symbolic constraints of the scope hold.

  Raises:
    ValueError: if some dimension variables cannot be solved.
  """
  shape_env: DimVarEnv = {}
  solution_error_message_pieces: list[str | DimSize] = [
    " Obtained dimension variables: "
  ]  # Error message describing the solution
  # Prepare error message piece describing the polymorphic shape specs
  poly_specs_err_msg = (
    " Using the following polymorphic shapes specifications: " +
    ",".join(f"{arg_name}.shape = {arg_spec}"
             for arg_name, arg_spec in polymorphic_shape_specs)) + "."
  solution_err_msg_trailer_errors = ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details."

  shape_constraints = ShapeConstraints()  # accumulate shape constraints
  # The symbolic scope shared by all equations; set from the first equation.
  scope: SymbolicScope | None = None

  def process_one_eqn(eqn: _DimEquation) -> bool:
    # We start with a _DimEquation of the form `dim_expr = dim_value`.
    # Try to rewrite the equation as `var * var_k = dim_value_2` (a linear
    # uni-variate equation). Returns `False` if this rewrite fails, without
    # modifying `shape_env` or `shape_constraints`.
    # Otherwise, compute the `var` value as `dim_value_2 // var_k`, add it to
    # `shape_env`, record the shape constraints, and return `True`. If no
    # variable is left unsolved, only a consistency constraint is recorded.
    #
    # Invariant:
    #     var * var_k + remaining_terms_from_dim_expr = dim_value
    nonlocal scope
    if scope is None:
      scope = eqn.aval_dim_expr.scope
    elif config.enable_checks.value:
      scope._check_same_scope(eqn.aval_dim_expr, when=f"solving equation {eqn}")

    var, var_k = None, None
    dim_value = _DimExpr._from_var(eqn.dim_name, scope)

    for term, term_k in eqn.aval_dim_expr._sorted_terms:
      # Perhaps we can already evaluate this term (all vars solved)
      try:
        term_value = term.evaluate(shape_env, scope)
      except UnexpectedDimVar:
        # `term` still uses some variables not yet solved. We handle only the
        # case when `term` is a single variable, and only one such term.
        v = term.to_var()
        if v is None or var is not None:
          return False  # This equation cannot yet be used to solve a variable
        var, var_k = v, term_k
      else:
        # Move the already-known term to the right-hand side.
        dim_value = dim_value + core.dim_constant(-1) * _evaluate_multiply(
            term_value, core.dim_constant(term_k))

    # The symbolic dimension size, used in constraints and error messages.
    eqn_dim_size = _DimExpr._from_var(eqn.dim_name, eqn.aval_dim_expr.scope)

    if var is None:
      # All variables are resolved for this equation, we emit an assertion
      spec_value = eqn.aval_dim_expr._evaluate(shape_env)
      shape_constraints.add_constraint(
          Comparator.EQ, eqn_dim_size, spec_value,
          error_message_pieces=(
              ["Input shapes do not match the polymorphic shapes specification. "
               f"Found inconsistency between dimension size {eqn.dim_name} (= ",
               eqn_dim_size,
               f") and the specification '{eqn.aval_dim_expr}' (= ",
               spec_value,
               ")." + poly_specs_err_msg] + solution_error_message_pieces +
              [solution_err_msg_trailer_errors]))
      return True

    if var_k == 1:
      var_value = dim_value
    else:
      var_value, var_remainder = divmod(dim_value, core.dim_constant(var_k))  # type: ignore
      # The division must be exact.
      shape_constraints.add_constraint(
          Comparator.EQ, var_remainder, 0,
          error_message_pieces=([
              "Input shapes do not match the polymorphic shapes specification. "
              "Division had remainder ", var_remainder,
              f" when computing the value of '{var}'." + poly_specs_err_msg
          ] + solution_error_message_pieces + [
              solution_err_msg_trailer_errors]))

    if not isinstance(var_value, _DimExpr):
      assert var_value.dtype == core.dim_value_dtype()  # type: ignore[attribute-error,unused-ignore]
    shape_env[var] = var_value  # type: ignore
    # Record this solution so that later error messages can describe it.
    solution_error_message_pieces.extend([  # type: ignore[container-type-mismatch,unused-ignore]
        f"'{var}' = ", var_value,
        f" from specification '{eqn.aval_dim_expr}' "
        f"for dimension {eqn.dim_name} (= ",
        eqn_dim_size,
        "), "])

    # Dimension variables must be strictly positive.
    shape_constraints.add_constraint(
        Comparator.GEQ, var_value, 1,
        error_message_pieces=[
            "Input shapes do not match the polymorphic shapes specification. "
            f"Expected value >= 1 for dimension variable '{var}'." +
            poly_specs_err_msg
        ] + solution_error_message_pieces + [
            solution_err_msg_trailer_errors])
    return True

  def add_explicit_symbolic_constraints(shape_env: DimVarEnv):
    # Adds a shape constraint for each explicit constraint of `scope`,
    # evaluated in the solved `shape_env`. No-op if nothing was solved.
    if not shape_env: return
    assert scope is not None
    for constr in scope._explicit_constraints:
      # We can't just construct constr.e1 - constr.e2 because for an equality
      # constraint it would be reduced to 0.
      c_e1 = constr.e1 if core.is_constant_dim(constr.e1) else constr.e1._evaluate(shape_env)  # type: ignore
      c_e2 = constr.e2 if core.is_constant_dim(constr.e2) else constr.e2._evaluate(shape_env)  # type: ignore
      c_diff = c_e1 - c_e2
      shape_constraints.add_constraint(
          constr.cmp, c_diff, 0,
          error_message_pieces=[
              f"Input shapes do not match the symbolic shape constraint {constr.debug_str}. "
              f"Expected '{constr.e1} - {constr.e2}' to be "
              f"{'greater or equal' if constr.cmp == Comparator.GEQ else 'equal'} to 0, "
              "but found ", c_diff,
              ". " + poly_specs_err_msg
          ] + solution_error_message_pieces + [
              solution_err_msg_trailer_errors])

  while True:
    nr_eqns = len(eqns)
    # Equations are processed in order; solutions found earlier in a pass are
    # already visible to later equations in the same pass.
    eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)]
    if not eqns:
      add_explicit_symbolic_constraints(shape_env)
      return shape_env, shape_constraints  # SUCCESS
    if len(eqns) >= nr_eqns:
      break  # No progress in this pass

  # We have some equations that we cannot solve further
  unsolved_vars: set[str] = set()
  for eqn in eqns:
    unsolved_vars.update(eqn.aval_dim_expr._get_vars())
  unsolved_vars.difference_update(shape_env.keys())
  err_msg = (
      f"Cannot solve for values of dimension variables {unsolved_vars}. "
      "We can only solve linear uni-variate constraints." + poly_specs_err_msg +
      " Unprocessed specifications: " +
      ", ".join(f"'{eqn.aval_dim_expr}' for dimension size {eqn.dim_name}"
                for eqn in eqns) +
      ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details."
  )
  raise ValueError(err_msg)