mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[shape_poly] Add better support for division, and working with strides
Previously, division was only supported in certain situation, and this led to errors, e.g., when using strides. Now we generalize the polynomials to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition to dimension variables. A symbolic dimension is now a sum of products of atoms. (We also changed the documentation to use symbolic dimension instead of dimension polynomials).
This commit is contained in:
parent
1641c8f141
commit
d25bcac93d
@ -26,6 +26,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* {func}`jax2tf.call_tf` has a new parameter `has_side_effects` (default `True`)
|
||||
that can be used to declare whether an instance can be removed or replicated
|
||||
by JAX optimizations such as dead-code elimination ({jax-issue}`#13980`).
|
||||
* Added more support for floordiv and mod for jax2tf shape polymorphism. Previously,
|
||||
certain division operations resulted in errors in presence of symbolic dimensions
|
||||
({jax-issue}`#14108`).
|
||||
|
||||
## jaxlib 0.4.2 (Jan 20, 2023)
|
||||
|
||||
|
@ -786,7 +786,8 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
|
||||
raise ValueError("Negative padding is larger than the size of the corresponding dimension: "
|
||||
f"got padding={pads} for lhs_shape[2:]={lhs_shape[2:]}")
|
||||
out_space = core.stride_shape(lhs_padded, rhs_shape[2:], strides)
|
||||
out_space = np.maximum(0, out_space)
|
||||
out_space = [d if core.greater_equal_dim(d, 0) else 0
|
||||
for d in out_space]
|
||||
if batch_group_count > 1:
|
||||
assert lhs_shape[0] % batch_group_count == 0
|
||||
out_shape_0 = lhs_shape[0] // batch_group_count
|
||||
|
@ -2543,7 +2543,7 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *,
|
||||
|
||||
if core.is_special_dim_size(repeats):
|
||||
if total_repeat_length is not None:
|
||||
raise ValueError("jnp.repeat with a DimPolynomial `repeats` is supported only "
|
||||
raise ValueError("jnp.repeat with a non-constant `repeats` is supported only "
|
||||
"when `total_repeat_length` is None")
|
||||
|
||||
# If total_repeat_length is not given, use a default.
|
||||
|
@ -1097,12 +1097,11 @@ def threefry_2x32(keypair, count):
|
||||
msg = "threefry_2x32 requires uint32 arguments, got {}"
|
||||
raise TypeError(msg.format([lax.dtype(x) for x in [key1, key2, count]]))
|
||||
|
||||
try:
|
||||
odd_size = count.size % 2
|
||||
except core.InconclusiveDimensionOperation as e:
|
||||
odd_size = count.size % 2
|
||||
if not isinstance(odd_size, int):
|
||||
msg = ("jax.random functions have limited support for shape polymorphism. "
|
||||
"In particular, the product of the known dimensions must be even.")
|
||||
raise core.InconclusiveDimensionOperation(msg) from e
|
||||
raise core.InconclusiveDimensionOperation(msg)
|
||||
|
||||
if odd_size:
|
||||
x = list(jnp.split(jnp.concatenate([count.ravel(), np.uint32([0])]), 2))
|
||||
|
@ -288,7 +288,7 @@ therefore outside of the `XlaSharding` wrapper.
|
||||
## Shape-polymorphic conversion
|
||||
|
||||
**The shape polymorphism support is work in progress. It is meant to be sound,
|
||||
but it may fail to lower some programs. Please report any bugs you encounter.**
|
||||
but it may raise errors on some programs. Please report any bugs you encounter.**
|
||||
|
||||
We described above how to include in the SavedModel several specializations
|
||||
of a lowered function for a few specific input shapes. `jax2tf` can
|
||||
@ -311,13 +311,12 @@ f_tf = tf.function(jax2tf.convert(f_jax,
|
||||
f_tf.get_concrete_function(tf.TensorSpec([None, 28, 28], tf.float32))
|
||||
```
|
||||
|
||||
The `polymorphic_shapes` parameter, in the form of a sequence of strings corresponding
|
||||
to the sequence of positional
|
||||
The `polymorphic_shapes` parameter, in the form of a pytree of strings corresponding
|
||||
to the pytree of positional
|
||||
arguments, introduces one or more dimension variables, e.g., `b`, to stand for shape
|
||||
dimensions that are assumed to be unknown at JAX tracing time, even if the actual
|
||||
parameter value (here `tf.TensorSpec(...)`) happens to have fully known shape.
|
||||
dimensions that are assumed to be unknown at JAX tracing time.
|
||||
Dimension variables are assumed to range
|
||||
over all strictly positive integers.
|
||||
over all integers that are greater or equal to 1.
|
||||
In this particular example, we can
|
||||
also abbreviate `polymorphic_shapes=["(b, _, _)"]`,
|
||||
because the `_` placeholders take their value
|
||||
@ -363,7 +362,7 @@ known `tf.TensorSpec`, and any concrete input `x` whose shape matches `abs_sig`:
|
||||
It is crucial to understand that `f_jax(x)` has the freedom to re-invoke the JAX tracing machinery,
|
||||
and in fact it does so for each distinct concrete input shape, while the generation of `f_tf`
|
||||
uses JAX tracing only once, and invoking `f_tf(x)` does not use JAX tracing anymore. In fact,
|
||||
invoking the latter invocation may happen after the `f_tf` has been serialized
|
||||
the latter invocation may happen after the `f_tf` has been serialized
|
||||
to a SavedModel and reloaded in an environment where `f_jax` and the JAX
|
||||
tracing machinery are not available anymore.
|
||||
|
||||
@ -383,22 +382,22 @@ lowered with the batch dimension polymorphic and the remaining dimensions concre
|
||||
It is reasonable to expect that there will be JAX programs for which there is a
|
||||
shape-polymorphic TensorFlow graph, but which will give an error when lowering with jax2tf.
|
||||
In general, you should expect that shape polymorphism can handle those programs for which
|
||||
all the intermediate shapes can be expressed as polynomials in the dimension variables
|
||||
appearing in the input shapes. In particular, this does not include programs whose
|
||||
all the intermediate shapes can be expressed as simple expressions in the dimension variables
|
||||
appearing in the input shapes. In particular, this does not apply to programs whose
|
||||
intermediate shapes depend on the data.
|
||||
|
||||
### Details
|
||||
|
||||
In order to be able to use shape polymorphism effectively with jax2tf, it
|
||||
is worth considering what happens under the hood. When the lowered function
|
||||
is invoked with a `TensorSpec`, `jax2tf` will combine the
|
||||
`TensorSpec` from the actual argument with the `polymorphic_shapes` parameter to
|
||||
obtain a shape abstraction to be used to specialize the lowered function.
|
||||
is invoked with a `TensorSpec`, `jax2tf` will use the `polymorphic_shapes` parameter
|
||||
to obtain a shape abstraction for the inputs. The dimension sizes from the
|
||||
`TensorSpec` are used to fill in the `_` and `...` placeholders from `polymorphic_shapes`.
|
||||
Normally, the shape abstraction contains the dimension sizes, but in the
|
||||
presence of shape polymorphism, some dimensions may be dimension variables.
|
||||
|
||||
The `polymorphic_shapes` parameter must be either `None`,
|
||||
or a sequence (one per argument) of shape specifiers.
|
||||
or a pytree of shape specifiers corresponding to the pytree of arguments.
|
||||
(A value `None` for `polymorphic_shapes` is equivalent to a list of `None`.
|
||||
See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).)
|
||||
A shape specifier is combined with a `TensorSpec` as follows:
|
||||
@ -421,7 +420,7 @@ A shape specifier is combined with a `TensorSpec` as follows:
|
||||
for any argument are assumed to be equal.
|
||||
|
||||
Note that `polymorphic_shapes` controls the shape abstraction used by JAX when tracing
|
||||
the function (with `_` placeholders given by the `TensorSpec`). The `TensorSpec`
|
||||
the function. The `TensorSpec`
|
||||
gives the shape abstraction that TensorFlow will associate with the produced
|
||||
graph, and can be more specific.
|
||||
|
||||
@ -437,7 +436,7 @@ A few examples of shape specifications and uses:
|
||||
`polymorphic_shapes=["(b, 28, 28)", "(28, 16)"]`.
|
||||
|
||||
* `polymorphic_shapes=["(batch, _)", "(batch,)"]`: the leading dimensions of the two arguments
|
||||
must match, and are assumed to be greater than 0.
|
||||
must match, and are assumed to be greater than 1.
|
||||
The second dimension of the first argument is taken from the
|
||||
actual `TensorSpec`. This can be used with a `TensorSpec` pair `[None, 16]`
|
||||
and `[None]`. It can also be used with a pair of shapes `[8, 16]` and `[8]`.
|
||||
@ -447,14 +446,14 @@ A few examples of shape specifications and uses:
|
||||
JAX keeps track of the shape of all intermediate results. When those shapes depend
|
||||
on dimension variables JAX computes them as symbolic expressions
|
||||
involving dimension variables. The symbolic expressions can represent the result
|
||||
of applying arithmetic operators (add, sub, mul,
|
||||
of applying arithmetic operators (add, sub, mul, floordiv, mod,
|
||||
including the NumPy variants `np.sum`, `np.prod`, etc.) **on dimension
|
||||
variables and integers** (`int`, `np.int`, or anything convertible by `operator.index`).
|
||||
These symbolic dimensions can then be used in shape-parameters of JAX primitives
|
||||
and APIs, e.g., in `jnp.reshape`, `jnp.arange`, slicing indices, etc.
|
||||
|
||||
For example, in the following code to flatten a 2D array, the computation
|
||||
`x.shape[0] * x.shape[1]` computes the dimension polynomial `4 * b` as the
|
||||
`x.shape[0] * x.shape[1]` computes the symbolic dimension `4 * b` as the
|
||||
new shape:
|
||||
|
||||
```python
|
||||
@ -541,7 +540,7 @@ The solution is to avoid `np.array`, `float`, or JAX arrays in operations whose
|
||||
results are used as shapes, e.g., instead of `np.arange(n) * x.shape[0]` write
|
||||
`[i * x.shape[0] for i in range(n)]`.
|
||||
|
||||
### Comparison of shape polynomials is partially supported
|
||||
### Comparison of symbolic dimensions is partially supported
|
||||
|
||||
Inside JAX there are a number of equality and inequality comparisons
|
||||
involving shapes, e.g., for doing shape checking or even for choosing
|
||||
@ -576,10 +575,11 @@ as `False` and produce a lowered function that returns `1` just because the dime
|
||||
are not identical: there are some concrete input shapes for which the function
|
||||
should return `0`.
|
||||
|
||||
### Division of shape polynomials is partially supported
|
||||
### Division of symbolic dimensions is partially supported
|
||||
|
||||
Unlike addition and multiplication, which are fully supported on
|
||||
shape polynomials, division is only supported when either (a) there
|
||||
JAX will attempt to simplify division and modulo operations,
|
||||
e.g., `(a * b + a) // (b + 1) == a` and `6*a + 4 % 3 == 1`.
|
||||
In particular, JAX will handle the cases when either (a) there
|
||||
is no remainder, or (b) the divisor is a constant
|
||||
in which case there may be a constant remainder.
|
||||
For example, the code below results in a division error when trying to
|
||||
@ -609,7 +609,11 @@ jax2tf.convert(lambda x: jnp.reshape(x, (-1, x.shape[0])),
|
||||
You may also encounter division errors when working with strides, such as
|
||||
when computing the padding in a strided convolution.
|
||||
|
||||
In some cases you may know that one of the dimension variables
|
||||
When JAX cannot simplify the result of symbolic dimension division it
|
||||
will construct symbolic expressions of the form `floordiv(E, N)` and
|
||||
`mod(E, N)` and it will use a number of heuristics to evaluate comparisons
|
||||
involving these. If you encounter `InconclusiveDimensionOperation` exceptions
|
||||
you can specify that a dimension variable
|
||||
is a multiple of the divisor,
|
||||
e.g., `b` in the above example of dividing `35*b` by `-2` may
|
||||
be known to be a multiple of `2`. You can specify that by replacing
|
||||
@ -623,7 +627,7 @@ jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
|
||||
### Dimension variables must be solvable from the input shapes
|
||||
|
||||
`jax2tf` will generate code to derive the values of the dimension variables
|
||||
from the input shapes. This works only if dimension polynomials in the input shapes are linear.
|
||||
from the input shapes. This works only if the symbolic dimensions in the input shapes are linear.
|
||||
For example, the following `polymorphic_shapes` will result in errors:
|
||||
|
||||
```python
|
||||
|
@ -498,7 +498,7 @@ def make_custom_gradient_fn_tf(
|
||||
"This should not happen for first-order differentiation. "
|
||||
f"{variables=}")
|
||||
|
||||
out_cts_flat_polymorphic_shapes = tuple(str(out_aval.shape) # Note: may be polynomials, not just DimVar
|
||||
out_cts_flat_polymorphic_shapes = tuple(str(out_aval.shape) # Note: may be _DimExpr, not just DimVar
|
||||
for out_aval in out_avals) # type: ignore
|
||||
vjp_polymorphic_shapes = [
|
||||
polymorphic_shapes_flat, out_cts_flat_polymorphic_shapes
|
||||
@ -996,12 +996,12 @@ def _ensure_tf_shape_if_dynamic(x: TfVal, shape):
|
||||
def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize]):
|
||||
"""Asserts that shape matches x.shape in the known dimensions and has
|
||||
dimension polynomials elsewhere."""
|
||||
# Ensures that the shape does not contain None; it should contain polynomials
|
||||
# Ensures that the shape does not contain None; it should contain symbolic expressions.
|
||||
def check_one(xd: Optional[int], sd: Any):
|
||||
if core.is_constant_dim(sd):
|
||||
return xd == sd
|
||||
else:
|
||||
assert isinstance(sd, shape_poly._DimPolynomial)
|
||||
assert isinstance(sd, shape_poly._DimExpr)
|
||||
return True
|
||||
assert (len(x.shape) == len(shape) and
|
||||
all(check_one(xd, sd)
|
||||
@ -1858,8 +1858,9 @@ def _conv_general_dilated(lhs, rhs, *,
|
||||
if tf_version >= (2, 8):
|
||||
# TODO(necula): remove when 2.8.0 is the stable TF version (and supports
|
||||
# batch_group_count.
|
||||
padding_tf = [_eval_shape(p) for p in padding]
|
||||
out = tfxla.conv(
|
||||
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
lhs, rhs, window_strides, padding_tf, lhs_dilation, rhs_dilation,
|
||||
dnums_proto,
|
||||
feature_group_count=feature_group_count,
|
||||
batch_group_count=batch_group_count,
|
||||
@ -1871,8 +1872,9 @@ def _conv_general_dilated(lhs, rhs, *,
|
||||
raise ValueError(
|
||||
"The batch_group_count parameter for conv requires TF version "
|
||||
"at least 2.8.0. You may want to use tf-nightly.")
|
||||
padding_tf = [_eval_shape(p) for p in padding]
|
||||
out = tfxla.conv(
|
||||
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
lhs, rhs, window_strides, padding_tf, lhs_dilation, rhs_dilation,
|
||||
dnums_proto,
|
||||
feature_group_count=feature_group_count,
|
||||
precision_config=precision_config_proto,
|
||||
|
@ -15,14 +15,14 @@
|
||||
|
||||
We introduce a set of dimension variables at the top-level of a `jit` function.
|
||||
They are introduced implicitly by way of specifying for each dimension of each
|
||||
argument a dimension polynomial in terms of some dimension variables.
|
||||
argument a symbolic dimension expression in terms of some dimension variables.
|
||||
All dimension variables are assumed to range over integers greater or equal to 1.
|
||||
|
||||
Dimension polynomials overload some integer operations, such as
|
||||
add, multiply, equality, etc. The JAX NumPy layer and the LAX layers have been
|
||||
touched up to be sensitive to handling shapes that contain dimension
|
||||
polynomials. This enables many JAX programs to be traced with dimension
|
||||
polynomials in some dimensions. A priority has been to enable the batch
|
||||
Symbolic dimensions overload some integer operations, such as
|
||||
add, multiply, divide, equality, etc. The JAX NumPy layer and the LAX layers have been
|
||||
touched up to be sensitive to handling shapes that contain symbolic dimensions.
|
||||
This enables many JAX programs to be traced with symbolic dimensions
|
||||
in some dimensions. A priority has been to enable the batch
|
||||
dimension in neural network examples to be polymorphic.
|
||||
|
||||
This was built initially for jax2tf, but it is now customizeable to be
|
||||
@ -34,6 +34,7 @@ import collections
|
||||
import dataclasses
|
||||
import itertools
|
||||
import functools
|
||||
import math
|
||||
import operator as op
|
||||
import re
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
|
||||
@ -63,9 +64,9 @@ class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation):
|
||||
"""Raised when we cannot conclusively compute with symbolic dimensions."""
|
||||
|
||||
_help_msg = """
|
||||
This error arises for arithmetic or comparison operations with shapes that
|
||||
This error arises for comparison operations with shapes that
|
||||
are non-constant, and the result of the operation cannot be represented as
|
||||
a polynomial of dimension variables, or a boolean constant (for comparisons).
|
||||
a boolean value for all values of the symbolic dimensions involved.
|
||||
|
||||
Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables
|
||||
for more details.
|
||||
@ -76,13 +77,152 @@ for more details.
|
||||
# https://github.com/python/mypy/issues/5887
|
||||
super().__init__(error_msg) # type: ignore
|
||||
|
||||
class _DimAtom:
|
||||
"""Represents an atom in a symbolic dimension expression.
|
||||
|
||||
Atoms are either variables, or expressions of the form floordiv(E1, E2) or
|
||||
mod(E1, E2). Atoms are multiplied to form monomials (see _DimMon), and
|
||||
monomials are added to form symbolic expressions (see _DimExpr).
|
||||
|
||||
Args:
|
||||
* var: if specified then the atom is a dimension variable. `operation`
|
||||
must be `None`.
|
||||
* operation: if specified then the atom is an operation applied to
|
||||
`operands`. One of `FLOORDIR` or `MOD`. `var` must be `None`
|
||||
* operands: the operands to which the operation is applied.
|
||||
"""
|
||||
# The supported operations
|
||||
FLOORDIV = "floordiv"
|
||||
MOD = "mod"
|
||||
|
||||
def __init__(self, *operands: '_DimExpr',
|
||||
var: Optional[str] = None,
|
||||
operation: Optional[str] = None):
|
||||
if var is not None:
|
||||
assert operation is None
|
||||
assert not operands
|
||||
else:
|
||||
assert operation is not None
|
||||
self.var = var
|
||||
self.operation = operation
|
||||
self.operands = operands
|
||||
|
||||
@classmethod
|
||||
def from_var(cls, v: str) -> '_DimAtom':
|
||||
return _DimAtom(var=v)
|
||||
|
||||
def to_var(self) -> Optional[str]:
|
||||
return self.var
|
||||
|
||||
def get_vars(self) -> Set[str]:
|
||||
# All the vars that appear
|
||||
if self.var is not None:
|
||||
return {self.var}
|
||||
else:
|
||||
acc = set()
|
||||
for opnd in self.operands:
|
||||
acc.update(opnd.get_vars())
|
||||
return acc
|
||||
|
||||
@classmethod
|
||||
def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimAtom':
|
||||
return _DimAtom(*operands, operation=operation)
|
||||
|
||||
def __str__(self):
|
||||
if self.var is not None:
|
||||
return self.var
|
||||
opnd_str = ", ".join([str(opnd) for opnd in self.operands])
|
||||
return f"{self.operation}({opnd_str})"
|
||||
__repr__ = __str__
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.var, self.operation, *self.operands))
|
||||
|
||||
def __eq__(self, other: Any):
|
||||
# Used only for hashing
|
||||
if not isinstance(other, _DimAtom): return False
|
||||
if (self.var is None) != (other.var is None): return False
|
||||
if self.var is not None:
|
||||
return self.var == other.var
|
||||
else:
|
||||
def symbolic_equal(e1: '_DimExpr', e2: '_DimExpr') -> bool:
|
||||
try:
|
||||
return e1 == e2
|
||||
except InconclusiveDimensionOperation:
|
||||
return False
|
||||
return (self.operation == other.operation and
|
||||
all(symbolic_equal(self_o, other_o)
|
||||
for self_o, other_o in zip(self.operands, other.operands)))
|
||||
|
||||
def __lt__(self, other: '_DimAtom'):
|
||||
"""
|
||||
Comparison to another atom in graded reverse lexicographic order.
|
||||
Used only for determining a sorting order, does not relate to the
|
||||
comparison of the values of the atom.
|
||||
"""
|
||||
if self.var is not None and other.var is not None:
|
||||
return self.var < other.var
|
||||
elif self.var is not None:
|
||||
return True
|
||||
elif other.var is not None:
|
||||
return True
|
||||
elif self.operation != other.operation:
|
||||
return self.operation < other.operation # type: ignore
|
||||
else:
|
||||
return id(self) < id(other)
|
||||
|
||||
def bounds(self) -> Tuple[float, float]:
|
||||
"""Returns the lower and upper bounds, or -+ inf."""
|
||||
if self.var is not None:
|
||||
return (1, np.PINF) # variables are assumed to be >= 1
|
||||
opnd_bounds = [opnd.bounds() for opnd in self.operands]
|
||||
if self.operation == _DimAtom.FLOORDIV: # a // b
|
||||
(a_l, a_u), (b_l, b_u) = opnd_bounds
|
||||
def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
|
||||
assert b != 0
|
||||
if not np.isinf(b): # divisor is finite
|
||||
return math.floor(a / b) if not np.isinf(a) else np.NINF if (a >= 0) != (b >= 0) else np.PINF
|
||||
elif not np.isinf(a): # dividend is finite and divisor is infinite
|
||||
return -1 if (a >= 0) != (b >= 0) else 0
|
||||
else: # both dividend and divisor are infinite
|
||||
return np.NINF if (a >= 0) != (b >= 0) else np.PINF
|
||||
|
||||
# Same reasoning as for multiplication: the bounds are among the cross-product
|
||||
# of the bounds.
|
||||
bound_candidates = [math_floor_with_inf(a_l, b_l), math_floor_with_inf(a_l, b_u),
|
||||
math_floor_with_inf(a_u, b_l), math_floor_with_inf(a_u, b_u)]
|
||||
return (min(*bound_candidates), max(*bound_candidates))
|
||||
|
||||
elif self.operation == _DimAtom.MOD:
|
||||
_, (b_l, b_u) = opnd_bounds
|
||||
if b_l > 0: # positive divisor
|
||||
return (0, b_u - 1)
|
||||
elif b_u < 0: # negative divisor
|
||||
return (b_l + 1, 0)
|
||||
else:
|
||||
return (np.NINF, np.PINF)
|
||||
|
||||
else:
|
||||
assert False
|
||||
|
||||
def evaluate(self, env: ShapeEnv):
|
||||
if self.var is not None:
|
||||
return env[self.var]
|
||||
else:
|
||||
operand_values = [opnd.evaluate(env) for opnd in self.operands]
|
||||
div_mod = divmod(*operand_values) # type: ignore
|
||||
if self.operation == _DimAtom.FLOORDIV:
|
||||
return div_mod[0]
|
||||
elif self.operation == _DimAtom.MOD:
|
||||
return div_mod[1]
|
||||
else:
|
||||
assert False, self.operation
|
||||
|
||||
class _DimMon(dict):
|
||||
"""Represents a multivariate monomial, such as n^3 * m.
|
||||
"""Represents a multiplication of atoms.
|
||||
|
||||
The representation is a dictionary mapping var:exponent.
|
||||
The `var` are strings and the exponents are integers >= 1.
|
||||
The dimension variables are assumed to range over integers >= 1.
|
||||
The representation is a dictionary mapping _DimAtom to exponent.
|
||||
The exponents are integers >= 1.
|
||||
"""
|
||||
def __hash__(self):
|
||||
return hash(frozenset(self.items()))
|
||||
@ -93,7 +233,7 @@ class _DimMon(dict):
|
||||
|
||||
@classmethod
|
||||
def from_var(cls, v: str) -> '_DimMon':
|
||||
return _DimMon({v: 1})
|
||||
return _DimMon({_DimAtom.from_var(v): 1})
|
||||
|
||||
def to_var(self) -> Optional[str]:
|
||||
"""Extract the variable name "x", from a monomial "x".
|
||||
@ -101,13 +241,21 @@ class _DimMon(dict):
|
||||
items = self.items()
|
||||
if len(items) != 1:
|
||||
return None
|
||||
(v, vexp), = items
|
||||
if vexp != 1:
|
||||
(a, aexp), = items
|
||||
if aexp != 1:
|
||||
return None
|
||||
return v
|
||||
return a.to_var()
|
||||
|
||||
def get_vars(self) -> Set[str]:
|
||||
return set(self.keys())
|
||||
# All the vars that appear in the monomial
|
||||
acc = set()
|
||||
for a in self.keys():
|
||||
acc.update(a.get_vars())
|
||||
return acc
|
||||
|
||||
@classmethod
|
||||
def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimMon':
|
||||
return _DimMon({_DimAtom.from_operation(operation, *operands): 1})
|
||||
|
||||
@property
|
||||
def degree(self):
|
||||
@ -116,7 +264,8 @@ class _DimMon(dict):
|
||||
def __lt__(self, other: '_DimMon'):
|
||||
"""
|
||||
Comparison to another monomial in graded reverse lexicographic order.
|
||||
Used for sorting.
|
||||
Used only for determining a sorting order, does not relate to the
|
||||
comparison of the values of the monomial.
|
||||
"""
|
||||
self_key = -self.degree, tuple(sorted(self))
|
||||
other_key = -other.degree, tuple(sorted(other))
|
||||
@ -143,60 +292,114 @@ class _DimMon(dict):
|
||||
elif diff > 0: d[key] = diff
|
||||
return _DimMon(d)
|
||||
|
||||
def bounds(self) -> Tuple[float, float]:
|
||||
"""Returns the lower and upper bounds, or -+inf."""
|
||||
# The bounds of a product are among the product of bounds.
|
||||
bounds = []
|
||||
for a, exp in self.items():
|
||||
a_l, a_u = a.bounds()
|
||||
assert a_l <= a_u
|
||||
bounds.append((a_l ** exp, a_u ** exp))
|
||||
|
||||
candidates = [np.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)]
|
||||
return (min(*candidates), max(*candidates)) # type: ignore
|
||||
|
||||
|
||||
def evaluate(self, env: ShapeEnv):
|
||||
prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else dim_constant(1)
|
||||
def pow_opt(v, p: int):
|
||||
return v if p == 1 else prod([v] * p)
|
||||
return prod([pow_opt(env[id], deg) for id, deg in self.items()])
|
||||
return prod([pow_opt(a.evaluate(env), deg) for a, deg in self.items()])
|
||||
|
||||
|
||||
class _DimPolynomial():
|
||||
"""Polynomial with integer coefficients for polymorphic shapes.
|
||||
class _DimExpr():
|
||||
"""Symbolic expression in terms of dimension variables.
|
||||
|
||||
The dimension variables are assumed to range over integers >= 1.
|
||||
A dimension expression is an addition of products (_DimMon)
|
||||
f atoms (_DimAtom).
|
||||
|
||||
We overload integer operations, but we do that soundly, raising
|
||||
:class:`InconclusiveDimensionOperation` when the result is not
|
||||
representable as a polynomial.
|
||||
representable as a _DimExpr.
|
||||
|
||||
The representation of a polynomial is as a dictionary mapping _DimMonomial to
|
||||
integer coefficients. The special monomial `_DimMonomial()` is mapped to the
|
||||
free integer coefficient of the polynomial. The constant result of arithmetic
|
||||
operations is represented as a Python constant.
|
||||
The representation of a _DimExpr is as a dictionary mapping _DimMon to
|
||||
integer coefficients. The special monomial `_DimMon()` is mapped to the
|
||||
free integer coefficient of the expression.
|
||||
"""
|
||||
|
||||
__array_priority__ = 1000 # Same as tracer, for __radd__ and others on ndarray
|
||||
def __init__(self, coeffs: Dict[_DimMon, int]):
|
||||
# Makes sure Polynomials are always in canonical form
|
||||
coeffs = {mon: op.index(coeff)
|
||||
for mon, coeff in coeffs.items() if coeff != 0}
|
||||
self._coeffs = coeffs or {_DimMon(): 0}
|
||||
# Do not construct _DimExpr directly, use _DimExpr.normalize
|
||||
self._coeffs = coeffs.copy() or {_DimMon(): 0}
|
||||
|
||||
def monomials(self) -> Iterable[Tuple[_DimMon, int]]:
|
||||
return self._coeffs.items()
|
||||
|
||||
@classmethod
|
||||
def from_coeffs(cls, coeffs: Dict[_DimMon, int]) -> DimSize:
|
||||
"""Constructs _DimPolynomial or an int."""
|
||||
def normalize(cls, coeffs: Dict[_DimMon, int]) -> DimSize:
|
||||
"""The main constructor for _DimExpr.
|
||||
|
||||
Ensures that the symbolic dimension is normalized, e.g.,
|
||||
it is represented as a Python int if it is known to be a constant.
|
||||
"""
|
||||
has_non_zero_degree = False
|
||||
zero_degree_const = 0
|
||||
new_coeffs = {}
|
||||
for mon, count in coeffs.items():
|
||||
if count != 0:
|
||||
if mon.degree == 0:
|
||||
zero_degree_const = count
|
||||
else:
|
||||
has_non_zero_degree = True
|
||||
new_coeffs[mon] = count
|
||||
if not has_non_zero_degree:
|
||||
return int(zero_degree_const)
|
||||
return _DimPolynomial(new_coeffs)
|
||||
free_const = 0
|
||||
new_coeffs: Dict[_DimMon, int] = {}
|
||||
for mon, coeff in coeffs.items():
|
||||
if coeff == 0: continue
|
||||
if mon.degree == 0: # A constant, there can be a single one
|
||||
free_const = coeff
|
||||
else:
|
||||
has_non_zero_degree = True
|
||||
|
||||
# Look for floordiv(E, M) * M and turn into E - mod(E, M). This comes
|
||||
# up when handling strided convolution.
|
||||
def normalize_floordiv_times(m: _DimMon, coeff: int) -> Optional['_DimExpr']:
|
||||
floordivs = [(a, aexp) for a, aexp in m.items() if a.operation == _DimAtom.FLOORDIV]
|
||||
# A single floordiv with exponent 1
|
||||
if len(floordivs) != 1 or floordivs[0][1] != 1: return None
|
||||
floordiv, _ = floordivs[0]
|
||||
floordiv_dividend_monomials = list(floordiv.operands[1].monomials())
|
||||
if len(floordiv_dividend_monomials) != 1: return None
|
||||
floordiv_dividend_monomial, floordiv_dividend_coeff = floordiv_dividend_monomials[0]
|
||||
if coeff % floordiv_dividend_coeff: return None
|
||||
try:
|
||||
m_trimmed = m.divide(floordiv_dividend_monomial)
|
||||
except InconclusiveDimensionOperation:
|
||||
return None
|
||||
c = coeff // floordiv_dividend_coeff
|
||||
m_trimmed = m_trimmed.divide(_DimMon({floordiv: 1})) # Remove the floordiv
|
||||
return (_DimExpr.from_monomial(m_trimmed, c) *
|
||||
(floordiv.operands[0] - _DimExpr.from_operation(_DimAtom.MOD, *floordiv.operands)))
|
||||
|
||||
mon_poly = normalize_floordiv_times(mon, coeff)
|
||||
if mon_poly is not None:
|
||||
monomials = mon_poly.monomials()
|
||||
else:
|
||||
monomials = [(mon, coeff)]
|
||||
for m, c in monomials:
|
||||
new_coeffs[m] = new_coeffs.get(m, 0) + c
|
||||
|
||||
if has_non_zero_degree:
|
||||
return _DimExpr(new_coeffs)
|
||||
else:
|
||||
return int(free_const)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_var(cls, v: str) -> '_DimPolynomial':
|
||||
return _DimPolynomial({_DimMon.from_var(v): 1})
|
||||
def from_monomial(cls, mon: _DimMon, exp: int):
|
||||
return _DimExpr.normalize({mon: exp})
|
||||
|
||||
@classmethod
|
||||
def from_var(cls, v: str) -> '_DimExpr':
|
||||
return _DimExpr({_DimMon.from_var(v): 1})
|
||||
|
||||
@classmethod
|
||||
def from_operation(cls, operation: str, *operands: '_DimExpr') -> '_DimExpr':
|
||||
return _DimExpr.from_monomial(_DimMon.from_operation(operation, *operands), 1)
|
||||
|
||||
def to_var(self) -> Optional[str]:
|
||||
"""Extract the variable name "x", from a polynomial "x" """
|
||||
"""Extract the variable name "x", from a symbolic expression."""
|
||||
items = self.monomials()
|
||||
if len(items) != 1: # type: ignore
|
||||
return None
|
||||
@ -206,7 +409,7 @@ class _DimPolynomial():
|
||||
return mon.to_var()
|
||||
|
||||
def get_vars(self) -> Set[str]:
|
||||
"""The variables that appear in a polynomial."""
|
||||
"""The variables that appear in a symbolic dimension."""
|
||||
acc = set()
|
||||
for mon, _ in self.monomials():
|
||||
acc.update(mon.get_vars())
|
||||
@ -216,23 +419,23 @@ class _DimPolynomial():
|
||||
lb, ub = _ensure_poly(self - other, "eq").bounds()
|
||||
if lb == ub == 0:
|
||||
return True
|
||||
if lb is not None and lb > 0:
|
||||
if lb > 0:
|
||||
return False
|
||||
if ub is not None and ub < 0:
|
||||
if ub < 0:
|
||||
return False
|
||||
raise InconclusiveDimensionOperation(
|
||||
f"Dimension polynomial comparison '{self}' == '{other}' is inconclusive.\n"
|
||||
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-dimension-polynomials-is-partially-supported.")
|
||||
f"Symbolic dimension comparison '{self}' == '{other}' is inconclusive.\n"
|
||||
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported.")
|
||||
|
||||
def ge(self, other: DimSize) -> bool:
|
||||
lb, ub = _ensure_poly(self - other, "ge").bounds()
|
||||
if lb is not None and lb >= 0:
|
||||
if lb >= 0:
|
||||
return True
|
||||
if ub is not None and ub < 0:
|
||||
if ub < 0:
|
||||
return False
|
||||
raise InconclusiveDimensionOperation(
|
||||
f"Dimension polynomial comparison '{self}' >= '{other}' is inconclusive.\n"
|
||||
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-dimension-polynomials-is-partially-supported.")
|
||||
f"Symbolic dimension comparison '{self}' >= '{other}' is inconclusive.\n"
|
||||
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic0dimensions-is-partially-supported.")
|
||||
|
||||
def __hash__(self):
|
||||
return hash(tuple(sorted(self.monomials())))
|
||||
@ -250,7 +453,7 @@ class _DimPolynomial():
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
# We overload +, -, *, because they are fully defined for _DimPolynomial.
|
||||
# We overload +, -, *, because they are fully defined for _DimExpr.
|
||||
def __add__(self, other):
|
||||
if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
|
||||
return self.__jax_array__().__add__(other)
|
||||
@ -259,7 +462,7 @@ class _DimPolynomial():
|
||||
coeffs = self._coeffs.copy()
|
||||
for mon, coeff in other.monomials():
|
||||
coeffs[mon] = coeffs.get(mon, 0) + coeff
|
||||
return _DimPolynomial.from_coeffs(coeffs)
|
||||
return _DimExpr.normalize(coeffs)
|
||||
|
||||
def __radd__(self, other):
|
||||
if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
|
||||
@ -276,18 +479,19 @@ class _DimPolynomial():
|
||||
return self.__jax_array__().__rsub__(other)
|
||||
return _ensure_poly(other, "sub").__sub__(self)
|
||||
|
||||
def __neg__(self) -> '_DimPolynomial':
|
||||
return _DimPolynomial({mon: -coeff for mon, coeff in self.monomials()})
|
||||
def __neg__(self) -> '_DimExpr':
|
||||
return _DimExpr({mon: -coeff for mon, coeff in self.monomials()})
|
||||
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
|
||||
return self.__jax_array__().__mul__(other)
|
||||
other = _ensure_poly(other, "mul")
|
||||
coeffs: Dict[_DimMon, int] = {}
|
||||
for (mon1, coeff1), (mon2, coeff2) in itertools.product(self.monomials(), other.monomials()):
|
||||
mon = mon1.mul(mon2)
|
||||
coeffs[mon] = coeffs.get(mon, 0) + coeff1 * coeff2
|
||||
return _DimPolynomial.from_coeffs(coeffs)
|
||||
for mon1, coeff1 in self.monomials():
|
||||
for mon2, coeff2 in other.monomials():
|
||||
mon = mon1.mul(mon2)
|
||||
coeffs[mon] = coeffs.get(mon, 0) + coeff1 * coeff2
|
||||
return _DimExpr.normalize(coeffs)
|
||||
|
||||
def __rmul__(self, other):
|
||||
if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
|
||||
@ -299,7 +503,7 @@ class _DimPolynomial():
|
||||
try:
|
||||
power = int(power)
|
||||
except:
|
||||
raise InconclusiveDimensionOperation(f"Dimension polynomial cannot be raised to non-integer power '{self}' ^ '{power}'")
|
||||
raise InconclusiveDimensionOperation(f"Symblic dimension cannot be raised to non-integer power '{self}' ^ '{power}'")
|
||||
return functools.reduce(op.mul, [self] * power)
|
||||
|
||||
def __floordiv__(self, divisor):
|
||||
@ -308,8 +512,7 @@ class _DimPolynomial():
|
||||
return self.divmod(_ensure_poly(divisor, "floordiv"))[0]
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
# A special case for int // poly: we use the __jax_array__ path
|
||||
if isinstance(other, core.Tracer) or _convertible_to_int(other) or not _convertible_to_poly(other):
|
||||
if isinstance(other, core.Tracer) or not _convertible_to_poly(other):
|
||||
return self.__jax_array__().__rfloordiv__(other)
|
||||
return _ensure_poly(other, "floordiv").__floordiv__(self)
|
||||
|
||||
@ -318,7 +521,7 @@ class _DimPolynomial():
|
||||
return self.__jax_array__().__truediv__(divisor)
|
||||
|
||||
def __rtruediv__(self, dividend):
|
||||
# Used for "/", when dividend is not a _DimPolynomial
|
||||
# Used for "/", when dividend is not a _DimExpr
|
||||
return self.__jax_array__().__rtruediv__(dividend)
|
||||
|
||||
def __mod__(self, divisor):
|
||||
@ -327,8 +530,7 @@ class _DimPolynomial():
|
||||
return self.divmod(_ensure_poly(divisor, "mod"))[1]
|
||||
|
||||
def __rmod__(self, dividend):
|
||||
# A special case for int // poly: we use the __jax_array__ path
|
||||
if isinstance(dividend, core.Tracer) or _convertible_to_int(dividend) or not _convertible_to_poly(dividend):
|
||||
if isinstance(dividend, core.Tracer) or not _convertible_to_poly(dividend):
|
||||
return self.__jax_array__().__rmod__(dividend)
|
||||
return _ensure_poly(dividend, "mod").__mod__(self)
|
||||
|
||||
@ -346,7 +548,7 @@ class _DimPolynomial():
|
||||
if self.is_constant:
|
||||
return op.index(next(iter(self._coeffs.values())))
|
||||
else:
|
||||
raise InconclusiveDimensionOperation(f"Dimension polynomial '{self}' used in a context that requires a constant")
|
||||
raise InconclusiveDimensionOperation(f"Symbolic dimension '{self}' used in a context that requires a constant")
|
||||
|
||||
# We must overload __eq__ and __ne__, or else we get unsound defaults.
|
||||
__eq__ = eq
|
||||
@ -364,14 +566,7 @@ class _DimPolynomial():
|
||||
def __lt__(self, other: DimSize):
|
||||
return not self.__ge__(other)
|
||||
|
||||
def _division_error_msg(self, dividend, divisor, details: str = "") -> str:
|
||||
msg = f"Cannot divide '{dividend}' by '{divisor}'."
|
||||
if details:
|
||||
msg += f"\nDetails: {details}."
|
||||
msg += "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported."
|
||||
return msg
|
||||
|
||||
def divmod(self, divisor: "_DimPolynomial") -> Tuple[DimSize, int]:
|
||||
def divmod(self, divisor: "_DimExpr") -> Tuple[DimSize, int]:
|
||||
"""
|
||||
Floor division with remainder (divmod) generalized to polynomials.
|
||||
If the `divisor` is not a constant, the remainder must be 0.
|
||||
@ -380,53 +575,55 @@ class _DimPolynomial():
|
||||
|
||||
:return: Quotient resulting from polynomial division and integer remainder.
|
||||
"""
|
||||
assert isinstance(divisor, _DimPolynomial)
|
||||
dmon, dcount = divisor.leading_term
|
||||
dividend, quotient = self, 0
|
||||
# invariant: self = dividend + divisor * quotient
|
||||
# the leading term of dividend decreases through the loop.
|
||||
while is_poly_dim(dividend) and not dividend.is_constant:
|
||||
mon, count = dividend.leading_term
|
||||
try:
|
||||
qmon = mon.divide(dmon)
|
||||
except InconclusiveDimensionOperation as e:
|
||||
raise InconclusiveDimensionOperation(
|
||||
self._division_error_msg(self, divisor, str(e)))
|
||||
qcount, rcount = divmod(count, dcount)
|
||||
if rcount != 0:
|
||||
raise InconclusiveDimensionOperation(
|
||||
self._division_error_msg(self, divisor))
|
||||
assert isinstance(divisor, _DimExpr)
|
||||
try:
|
||||
dmon, dcount = divisor.leading_term
|
||||
dividend, quotient = self, 0
|
||||
# invariant: self = dividend + divisor * quotient
|
||||
# quotient and dividend are changed in the loop; the leading term of
|
||||
# dividend decreases at each iteration.
|
||||
while is_poly_dim(dividend) and not dividend.is_constant:
|
||||
mon, count = dividend.leading_term
|
||||
try:
|
||||
qmon = mon.divide(dmon)
|
||||
except InconclusiveDimensionOperation:
|
||||
raise InconclusiveDimensionOperation("")
|
||||
qcount, rcount = divmod(count, dcount)
|
||||
if rcount != 0:
|
||||
raise InconclusiveDimensionOperation("")
|
||||
|
||||
q = _DimPolynomial.from_coeffs({qmon: qcount})
|
||||
quotient += q
|
||||
dividend -= q * divisor # type: ignore[assignment]
|
||||
q = _DimExpr.from_monomial(qmon, qcount)
|
||||
quotient += q
|
||||
dividend -= q * divisor # type: ignore[assignment]
|
||||
|
||||
dividend = int(dividend) # type: ignore[assignment]
|
||||
if divisor.is_constant:
|
||||
q, r = divmod(dividend, int(divisor)) # type: ignore
|
||||
quotient += q
|
||||
remainder = r
|
||||
else:
|
||||
if dividend != 0:
|
||||
raise InconclusiveDimensionOperation(
|
||||
self._division_error_msg(self, divisor))
|
||||
remainder = 0
|
||||
dividend = int(dividend) # type: ignore[assignment]
|
||||
if divisor.is_constant:
|
||||
q, r = divmod(dividend, int(divisor)) # type: ignore
|
||||
quotient += q
|
||||
remainder = r
|
||||
else:
|
||||
if dividend != 0:
|
||||
raise InconclusiveDimensionOperation("")
|
||||
remainder = 0
|
||||
|
||||
if config.jax_enable_checks:
|
||||
assert self == divisor * quotient + remainder
|
||||
return quotient, remainder
|
||||
if config.jax_enable_checks:
|
||||
assert self == divisor * quotient + remainder
|
||||
return quotient, remainder
|
||||
except InconclusiveDimensionOperation:
|
||||
return (_DimExpr.from_operation(_DimAtom.FLOORDIV, self, divisor), # type: ignore
|
||||
_DimExpr.from_operation(_DimAtom.MOD, self, divisor))
|
||||
|
||||
def bounds(self) -> Tuple[Optional[int], Optional[int]]:
|
||||
"""Returns the lower and upper bounds, if defined."""
|
||||
def bounds(self) -> Tuple[float, float]:
|
||||
"""Returns the lower and upper bounds, or -+inf."""
|
||||
lb = ub = self._coeffs.get(_DimMon(), 0) # The free coefficient
|
||||
for mon, coeff in self.monomials():
|
||||
if mon.degree == 0: continue
|
||||
if coeff > 0:
|
||||
ub = None # type: ignore
|
||||
lb = None if lb is None else lb + coeff
|
||||
else:
|
||||
lb = None # type: ignore
|
||||
ub = None if ub is None else ub + coeff
|
||||
if mon.degree == 0: continue # We already included the free coefficient
|
||||
m_l, m_u = mon.bounds()
|
||||
assert m_l <= m_u and coeff != 0
|
||||
item_l, item_u = coeff * m_l, coeff * m_u
|
||||
lb = lb + min(item_l, item_u) # type: ignore
|
||||
ub = ub + max(item_l, item_u) # type: ignore
|
||||
|
||||
return lb, ub
|
||||
|
||||
@property
|
||||
@ -444,7 +641,7 @@ class _DimPolynomial():
|
||||
return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0]
|
||||
|
||||
@staticmethod
|
||||
def get_aval(dim: "_DimPolynomial"):
|
||||
def get_aval(dim: "_DimExpr"):
|
||||
return dim_as_value_abstract(dim)
|
||||
|
||||
def dimension_as_value(self):
|
||||
@ -455,9 +652,9 @@ class _DimPolynomial():
|
||||
# Used for implicit coercions of polynomials as JAX arrays
|
||||
return _dim_as_value(self)
|
||||
|
||||
core.pytype_aval_mappings[_DimPolynomial] = _DimPolynomial.get_aval
|
||||
xla.pytype_aval_mappings[_DimPolynomial] = _DimPolynomial.get_aval
|
||||
dtypes._weak_types.append(_DimPolynomial)
|
||||
core.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval
|
||||
xla.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval
|
||||
dtypes._weak_types.append(_DimExpr)
|
||||
|
||||
def _convertible_to_int(p: DimSize) -> bool:
|
||||
try:
|
||||
@ -467,17 +664,17 @@ def _convertible_to_int(p: DimSize) -> bool:
|
||||
return False
|
||||
|
||||
def _ensure_poly(p: DimSize,
|
||||
operation_name: str) -> _DimPolynomial:
|
||||
if isinstance(p, _DimPolynomial): return p
|
||||
operation_name: str) -> _DimExpr:
|
||||
if isinstance(p, _DimExpr): return p
|
||||
if _convertible_to_int(p):
|
||||
return _DimPolynomial({_DimMon(): op.index(p)})
|
||||
raise TypeError(f"Dimension polynomial {operation_name} not supported for {p}.")
|
||||
return _DimExpr({_DimMon(): op.index(p)})
|
||||
raise TypeError(f"Symnbolic dimension {operation_name} not supported for {p}.")
|
||||
|
||||
def _convertible_to_poly(p: DimSize) -> bool:
|
||||
return isinstance(p, _DimPolynomial) or _convertible_to_int(p)
|
||||
return isinstance(p, _DimExpr) or _convertible_to_int(p)
|
||||
|
||||
def is_poly_dim(p: DimSize) -> bool:
|
||||
return isinstance(p, _DimPolynomial)
|
||||
return isinstance(p, _DimExpr)
|
||||
|
||||
|
||||
class DimensionHandlerPoly(core.DimensionHandler):
|
||||
@ -486,7 +683,7 @@ class DimensionHandlerPoly(core.DimensionHandler):
|
||||
Most methods are inherited.
|
||||
"""
|
||||
def is_constant(self, d: DimSize) -> bool:
|
||||
assert isinstance(d, _DimPolynomial)
|
||||
assert isinstance(d, _DimExpr)
|
||||
return False
|
||||
|
||||
def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool:
|
||||
@ -508,7 +705,7 @@ class DimensionHandlerPoly(core.DimensionHandler):
|
||||
q, r = _ensure_poly(sz1, "divide_shape").divmod(_ensure_poly(sz2, "divide_shape"))
|
||||
except InconclusiveDimensionOperation as e:
|
||||
raise InconclusiveDimensionOperation(err_msg + f"\nDetails: {e}")
|
||||
if r != 0:
|
||||
if not core.symbolic_equal_dim(r, 0):
|
||||
raise InconclusiveDimensionOperation(err_msg + f"\nRemainder is not zero: {r}")
|
||||
return q # type: ignore[return-value]
|
||||
|
||||
@ -528,11 +725,11 @@ class DimensionHandlerPoly(core.DimensionHandler):
|
||||
"""Turns a dimension size into a Jax value that we can compute with."""
|
||||
return _dim_as_value(d)
|
||||
|
||||
core._SPECIAL_DIMENSION_HANDLERS[_DimPolynomial] = DimensionHandlerPoly()
|
||||
dtypes.python_scalar_dtypes[_DimPolynomial] = dtypes.python_scalar_dtypes[int]
|
||||
core._SPECIAL_DIMENSION_HANDLERS[_DimExpr] = DimensionHandlerPoly()
|
||||
dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int]
|
||||
|
||||
def _einsum_contract_path(*operands, **kwargs):
|
||||
"""Like opt_einsum.contract_path, with support for DimPolynomial shapes.
|
||||
"""Like opt_einsum.contract_path, with support for DimExpr shapes.
|
||||
|
||||
We use opt_einsum.contract_path to compute the schedule, using a fixed
|
||||
constant for all dimension variables. This is safe because we throw an
|
||||
@ -554,7 +751,7 @@ def _einsum_contract_path(*operands, **kwargs):
|
||||
if core.is_constant_dim(d):
|
||||
return d
|
||||
else:
|
||||
if not isinstance(d, _DimPolynomial):
|
||||
if not isinstance(d, _DimExpr):
|
||||
raise TypeError(f"Encountered unexpected shape dimension {d}")
|
||||
# It is Ok to replace all polynomials with the same value. We may miss
|
||||
# here some errors due to non-equal dimensions, but we catch them
|
||||
@ -572,10 +769,10 @@ def _einsum_contract_path(*operands, **kwargs):
|
||||
contract_operands.append(operands[idx[0]])
|
||||
return contract_operands, contractions
|
||||
|
||||
lax_numpy._poly_einsum_handlers[_DimPolynomial] = _einsum_contract_path
|
||||
lax_numpy._poly_einsum_handlers[_DimExpr] = _einsum_contract_path
|
||||
|
||||
# A JAX primitive with no array arguments but with a dimension parameter
|
||||
# that is a DimPoly. The value of the primitive is the value of the dimension,
|
||||
# that is a DimExpr. The value of the primitive is the value of the dimension,
|
||||
# using int64 in x64 mode or int32 otherwise (dim_as_value_dtype())
|
||||
dim_as_value_p = core.Primitive("dim_as_value")
|
||||
def dim_as_value_dtype():
|
||||
@ -701,7 +898,7 @@ def _parse_spec(spec: Optional[Union[str, PolyShape]],
|
||||
m = re.match(r"^([a-zA-Z]\w*)(\^(\d+))?$", factor_spec)
|
||||
if not m:
|
||||
raise ValueError(f"polymorphic shape {repr(spec)} has invalid syntax (unexpected term '{factor_spec}')")
|
||||
var = _DimPolynomial.from_var(m.group(1))
|
||||
var = _DimExpr.from_var(m.group(1))
|
||||
if m.group(3) is None:
|
||||
return var
|
||||
return var ** int(m.group(3))
|
||||
@ -793,7 +990,7 @@ _JaxValue = Any
|
||||
@dataclasses.dataclass
|
||||
class DimEquation:
|
||||
# Represents poly == _expr
|
||||
poly: _DimPolynomial
|
||||
poly: _DimExpr
|
||||
dim_expr: _JaxValue # Of type dim_as_value_dtype()
|
||||
|
||||
|
||||
@ -810,7 +1007,7 @@ def get_shape_evaluator(dim_vars: Sequence[str], shape: Sequence[DimSize]) ->\
|
||||
def eval_dim(d: DimSize):
|
||||
return d.evaluate(shape_env_jax) # type: ignore[union-attr]
|
||||
|
||||
return tuple(eval_dim(d) if type(d) is _DimPolynomial else np.array(d, dtype=dim_as_value_dtype()) # type: ignore
|
||||
return tuple(eval_dim(d) if type(d) is _DimExpr else np.array(d, dtype=dim_as_value_dtype()) # type: ignore
|
||||
for d in shape)
|
||||
return eval_shape
|
||||
|
||||
@ -941,7 +1138,7 @@ def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
|
||||
|
||||
# We have some equations that we cannot solve further
|
||||
unsolved_vars: Set[str] = set()
|
||||
unsolved_polys: List[_DimPolynomial] = []
|
||||
unsolved_polys: List[_DimExpr] = []
|
||||
for eqn in eqns:
|
||||
unsolved_vars = unsolved_vars.union(eqn.poly.get_vars())
|
||||
unsolved_polys.append(eqn.poly)
|
||||
|
@ -21,7 +21,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
import collections
|
||||
import functools
|
||||
from functools import partial
|
||||
import operator
|
||||
import operator as op
|
||||
import re
|
||||
|
||||
import jax
|
||||
@ -54,7 +54,7 @@ from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation
|
||||
PS = jax2tf.PolyShape
|
||||
_f32 = np.float32
|
||||
|
||||
class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
class DimExprTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
def test_parse_poly_spec(self):
|
||||
self.assertEqual((2, 3), shape_poly._parse_spec(None, (2, 3)))
|
||||
@ -115,17 +115,17 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertEqual(False, a != a)
|
||||
with self.assertRaisesRegex(
|
||||
core.InconclusiveDimensionOperation,
|
||||
"Dimension polynomial comparison 'a' == 'b' is inconclusive"):
|
||||
"Symbolic dimension comparison 'a' == 'b' is inconclusive"):
|
||||
a.eq(b)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
core.InconclusiveDimensionOperation,
|
||||
"Dimension polynomial comparison 'a' == 'b' is inconclusive"):
|
||||
"Symbolic dimension comparison 'a' == 'b' is inconclusive"):
|
||||
a == b
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
core.InconclusiveDimensionOperation,
|
||||
"Dimension polynomial comparison 'a' == 'b' is inconclusive"):
|
||||
"Symbolic dimension comparison 'a' == 'b' is inconclusive"):
|
||||
a != b
|
||||
|
||||
self.assertLen({a, a}, 1)
|
||||
@ -134,7 +134,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertIn(b, {a, b})
|
||||
self.assertIn(a, [a, b])
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
"Dimension polynomial comparison .* is inconclusive"):
|
||||
"Symbolic dimension comparison .* is inconclusive"):
|
||||
b in [a, b]
|
||||
|
||||
def test_get_vars(self):
|
||||
@ -147,7 +147,8 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
a, b = shape_poly._parse_spec("a, b", (2, 3))
|
||||
|
||||
self.assertEqual(1, (a * a - b).evaluate(dict(a=2, b=3)))
|
||||
self.assertEqual(2, (a * a - b + 1).evaluate(dict(a=-2, b=3)))
|
||||
self.assertEqual(1, ((a * a) // b).evaluate(dict(a=2, b=3)))
|
||||
self.assertEqual(4, ((a * a) % b).evaluate(dict(a=5, b=7)))
|
||||
|
||||
def test_dim_vars_symbolic_equal(self):
|
||||
a, b = shape_poly._parse_spec("a, b", (2, 3))
|
||||
@ -170,14 +171,75 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
def test_poly_bounds(self):
|
||||
a, b = shape_poly._parse_spec("a, b", (2, 3))
|
||||
self.assertEqual(a.bounds(), (1, None))
|
||||
self.assertEqual((2 * a).bounds(), (2, None))
|
||||
self.assertEqual((2 * a - 3).bounds(), (-1, None))
|
||||
self.assertEqual((-2 * a - 3).bounds(), (None, -5))
|
||||
self.assertEqual((3 * a * b * b + 5 * a - 7).bounds(), (1, None))
|
||||
self.assertEqual((3 * a * b * b - 5 * a - 7).bounds(), (None, None))
|
||||
self.assertEqual((a + b - a * b + a * b * a).bounds(), (None, None))
|
||||
self.assertEqual((a + 2 * b - a).bounds(), (2, None))
|
||||
bounded_le4 = 5 - a
|
||||
bounded_ge2 = b + 1
|
||||
bounded_ge0_le4 = a % 5
|
||||
self.assertEqual(a.bounds(), (1, np.PINF))
|
||||
self.assertEqual(bounded_le4.bounds(), (np.NINF, 4))
|
||||
self.assertEqual(bounded_ge2.bounds(), (2, np.PINF))
|
||||
self.assertEqual(bounded_ge0_le4.bounds(), (0, 4))
|
||||
|
||||
# Additions
|
||||
self.assertEqual((bounded_ge0_le4 + bounded_le4).bounds(), (np.NINF, 8))
|
||||
self.assertEqual((bounded_ge0_le4 + bounded_ge2).bounds(), (2, np.PINF))
|
||||
self.assertEqual((bounded_le4 + bounded_ge2).bounds(), (np.NINF, np.PINF))
|
||||
|
||||
# Subtractions
|
||||
self.assertEqual((bounded_ge0_le4 - bounded_le4).bounds(), (-4, np.PINF))
|
||||
self.assertEqual((- bounded_ge0_le4 + bounded_le4).bounds(), (np.NINF, 4))
|
||||
self.assertEqual((bounded_ge0_le4 - bounded_ge2).bounds(), (np.NINF, 2))
|
||||
self.assertEqual((- bounded_ge0_le4 + bounded_ge2).bounds(), (-2, np.PINF))
|
||||
self.assertEqual((bounded_le4 - bounded_ge2).bounds(), (np.NINF, 2))
|
||||
self.assertEqual((- bounded_le4 + bounded_ge2).bounds(), (-2, np.PINF))
|
||||
|
||||
# Multiplications
|
||||
self.assertEqual((2 * a - 3).bounds(), (-1, np.PINF))
|
||||
self.assertEqual((-2 * a - 3).bounds(), (np.NINF, -5))
|
||||
self.assertEqual((3 * a * b * b + 5 * a - 7).bounds(), (1, np.PINF))
|
||||
self.assertEqual((3 * a * b * b - 5 * a - 7).bounds(), (np.NINF, np.PINF))
|
||||
self.assertEqual((a + b - a * b + a * b * a).bounds(), (np.NINF, np.PINF))
|
||||
self.assertEqual((a + 2 * b - a).bounds(), (2, np.PINF))
|
||||
self.assertEqual((a + 2 * b - a).bounds(), (2, np.PINF))
|
||||
|
||||
# mod
|
||||
self.assertEqual(((b + 1) % 2).bounds(), (0, 1))
|
||||
self.assertEqual(((b + 1) % -2).bounds(), (-1, 0))
|
||||
self.assertEqual(((b - 4) % 2).bounds(), (0, 1))
|
||||
self.assertEqual(((b + 1) % a).bounds(), (0, np.PINF))
|
||||
self.assertEqual((11 % (a + 1)).bounds(), (0, np.PINF))
|
||||
self.assertEqual((-11 % (a + 1)).bounds(), (0, np.PINF))
|
||||
self.assertEqual((b % (a - 2)).bounds(), (np.NINF, np.PINF))
|
||||
|
||||
# floordiv
|
||||
self.assertEqual(((a + 4) // 2).bounds(), (2, np.PINF))
|
||||
self.assertEqual(((a + 4) // -2).bounds(), (np.NINF, -3))
|
||||
self.assertEqual(((a + 5) // 2).bounds(), (3, np.PINF))
|
||||
self.assertEqual(((a + 5) // -2).bounds(), (np.NINF, -3))
|
||||
self.assertEqual((11 // (a + 1)).bounds(), (0, 5))
|
||||
self.assertEqual((-11 // (a + 1)).bounds(), (-6, -1))
|
||||
self.assertEqual((-11 // (- a)).bounds(), (0, 11)) # finite negative dividend, infinite divisor
|
||||
self.assertEqual(((b + 1) // (a + 1)).bounds(), (0, np.PINF))
|
||||
self.assertEqual((-b // (a + 1)).bounds(), (np.NINF, -1))
|
||||
|
||||
|
||||
|
||||
|
||||
# Generate test cases for floordiv and mod: (a + N) // +-2, (N - a) // +-2
|
||||
# and then evaluate them for a = 1, 5, 10000
|
||||
div_mod_atoms = [
|
||||
operation(op1 + n, div)
|
||||
for op1 in (a, a + 10, a + 11, -a, -a + 10, -a + 11)
|
||||
for n in (-3, -1, 0, 1, 3)
|
||||
for div in (-2, 2, a + 4, -4 - a) # Either negative, or positive
|
||||
for operation in (op.floordiv, op.mod)
|
||||
]
|
||||
for atom in div_mod_atoms:
|
||||
lb, ub = atom.bounds()
|
||||
self.assertLessEqual(lb, ub)
|
||||
for a_val in (1, 5, 10000):
|
||||
atom_val = atom.evaluate(dict(a=a_val))
|
||||
self.assertGreaterEqual(atom_val, lb)
|
||||
self.assertLessEqual(atom_val, ub)
|
||||
|
||||
def test_poly_equal(self):
|
||||
a, b = shape_poly._parse_spec("a, b", (2, 3))
|
||||
@ -195,9 +257,27 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertFalse((2 * a * b * a + 1).eq(a * b * a))
|
||||
self.assertFalse((3 * a * b * a - 1).eq(a * b * a))
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
re.escape("Dimension polynomial comparison '3*a^2*b + -2' == 'a^2*b' is inconclusive")):
|
||||
re.escape("Symbolic dimension comparison '3*a^2*b + -2' == 'a^2*b' is inconclusive")):
|
||||
(3 * a * b * a - 2).eq(a * b * a)
|
||||
|
||||
self.assertTrue(a % b == a % b)
|
||||
self.assertTrue(a % b - a % b == 0)
|
||||
self.assertTrue(a // b == a // b)
|
||||
self.assertTrue(a // b - a // b == 0)
|
||||
|
||||
self.assertTrue(a % b == (2 * a // 2) % (a + b - a))
|
||||
self.assertTrue(a // b == (2 * a // 2) // (a + b - a))
|
||||
|
||||
self.assertTrue(a, a + (a + b) // b - (b + a) // b)
|
||||
|
||||
# Test the normaliation (a // b) * b == a - a % b
|
||||
self.assertTrue((a // 2) * 2 == a - a % 2)
|
||||
self.assertTrue((a // 2) + (a // 2) == a - a % 2)
|
||||
self.assertTrue((a // 2) * 6 == 3 * a - 3 * (a % 2))
|
||||
self.assertTrue((a // b) * b == a - a % b)
|
||||
self.assertTrue(2 * (a // b) * b * b == 2 * b * a - 2 * b * (a % b))
|
||||
self.assertTrue(a // (2 * b) * 2 * b == a - a % (2 * b))
|
||||
|
||||
def test_poly_compare(self):
|
||||
a, b = shape_poly._parse_spec("a, b", (2, 3))
|
||||
poly = 4 * a + b + 3
|
||||
@ -236,14 +316,16 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertTrue(core.greater_equal_shape((a, 2), (1, 1)))
|
||||
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
"Dimension polynomial comparison .* is inconclusive"):
|
||||
"Symbolic dimension comparison .* is inconclusive"):
|
||||
core.greater_equal_dim(a, 2)
|
||||
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
"Dimension polynomial comparison .* is inconclusive"):
|
||||
"Symbolic dimension comparison .* is inconclusive"):
|
||||
core.greater_equal_dim(a, b)
|
||||
|
||||
def test_poly_int_results(self):
|
||||
# Whenever the result is an integer, it should be represented as an
|
||||
# Python integer, not a symbolic dimension.
|
||||
a, b = shape_poly._parse_spec("a, b", (2, 3))
|
||||
self.assertEqual(a + 2 - a, 2)
|
||||
self.assertIsInstance(a + 2 - a, int)
|
||||
@ -265,16 +347,15 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
(3 * a - 2, 3, a - 1, 1),
|
||||
(3 * a * a * b + 2 * b * b * a, a * b, 3 * a + 2 * b, 0),
|
||||
(a * a - b * b, a + b, a - b, 0),
|
||||
(a, b, None, None),
|
||||
(3 * a, 2, None, None),
|
||||
(2 * a * b + b * b, a + b, None, None),
|
||||
(3, a, None, None),
|
||||
(a, b, "floordiv(a, b)", "mod(a, b)"),
|
||||
(3 * a, 2, "floordiv(3*a, 2)", "mod(3*a, 2)"),
|
||||
(2 * a * b + b * b, a + b, "floordiv(2*a*b + b^2, a + b)", "mod(2*a*b + b^2, a + b)"),
|
||||
(3, a, "floordiv(3, a)", "mod(3, a)"),
|
||||
])
|
||||
def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
|
||||
if quotient is None:
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
"Cannot divide .* by .*"):
|
||||
divmod(dividend, divisor)
|
||||
if isinstance(quotient, str):
|
||||
d1, d2 = divmod(dividend, divisor)
|
||||
self.assertEqual((quotient, remainder), (str(d1), str(d2)))
|
||||
else:
|
||||
self.assertEqual((quotient, remainder), divmod(dividend, divisor))
|
||||
|
||||
@ -297,11 +378,9 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
|
||||
self.assertEqual((a - 1, 9), core.stride_shape((a, 20), (2, 3), (1, 2)))
|
||||
self.assertEqual((a + 1, 9), core.stride_shape((a * stride + 2, 20), (2, 3), (stride, 2)))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
core.InconclusiveDimensionOperation,
|
||||
re.escape(
|
||||
"Cannot compute stride for dimension 'a', window_size '1', stride '2'.\nDetails: Cannot divide 'a + -1' by '2'")):
|
||||
core.stride_shape((a, 20), (1, 3), (2, 2))
|
||||
(stride0, stride1) = core.stride_shape((a, 20), (1, 3), (2, 2))
|
||||
self.assertEqual("floordiv(a + -1, 2) + 1", str(stride0))
|
||||
self.assertEqual(9, stride1)
|
||||
|
||||
|
||||
class PolyHarness(Harness):
|
||||
@ -859,7 +938,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
# Arguments are of the form [([x00, x01], [x10]), dict(a=ya, b=yb)]
|
||||
def add_all_jax(x_pair_of_list, y_dict):
|
||||
x_list_0, x_list_1 = x_pair_of_list
|
||||
return functools.reduce(operator.add,
|
||||
return functools.reduce(op.add,
|
||||
x_list_0 + x_list_1 + [y_dict["a"], y_dict["b"]])
|
||||
|
||||
check_shape_poly(self,
|
||||
@ -1193,15 +1272,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
jax2tf.convert(lambda x: jnp.reshape(x, (np.prod(x.shape),)),
|
||||
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
"Shapes must be 1D sequences of concrete values of integer type, got Traced"):
|
||||
jax2tf.convert(lambda x: jnp.reshape(x, (x.shape[0] * np.array([x.shape[1]]))),
|
||||
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
|
||||
|
||||
jax2tf.convert(lambda x: (x + x.shape[0] + jnp.sin(x.shape[0]),
|
||||
5. + x.shape[0],
|
||||
np.ones((5,), dtype=np.int32) - x.shape[0]),
|
||||
polymorphic_shapes = ["b"])(np.ones((3,)))
|
||||
jax2tf.convert(lambda x: x + x.shape[0] + jnp.sin(x.shape[0]),
|
||||
polymorphic_shapes=["b"])(np.ones(3))
|
||||
|
||||
jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
|
||||
polymorphic_shapes=["(v, _)"])(np.ones((3, 4)))
|
||||
@ -1238,7 +1310,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
polymorphic_shapes=["(v, 4)"])(np.ones((4, 4)))
|
||||
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
re.compile("Cannot divide evenly the sizes of shapes \\(b, 5, 7\\) and \\(2, -1\\).*Details: Cannot divide '35\\*b' by '-2'",
|
||||
re.compile("Cannot divide evenly the sizes of shapes \\(b, 5, 7\\) and \\(2, -1\\)",
|
||||
re.DOTALL)):
|
||||
jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
|
||||
polymorphic_shapes=["(b, _, _)"])(np.ones((4, 5, 7)))
|
||||
@ -1255,7 +1327,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
core.InconclusiveDimensionOperation,
|
||||
re.escape("Dimension polynomial comparison 'a + 1' == 'b' is inconclusive")):
|
||||
re.escape("Symbolic dimension comparison 'a + 1' == 'b' is inconclusive")):
|
||||
jax2tf.convert(lambda x: 0 if x.shape[0] + 1 == x.shape[1] else 1,
|
||||
polymorphic_shapes=["(a, b)"])(np.ones((4, 4)))
|
||||
|
||||
@ -1365,8 +1437,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
dict(testcase_name=f"_{op.__name__}_other={other}:{type(other)}{'_other_jnp_array' if other_jnp_array else ''}{'_swap' if swap else ''}",
|
||||
op=op, other=other,
|
||||
other_jnp_array=other_jnp_array, swap=swap)
|
||||
for op in [operator.add, operator.mul, operator.sub,
|
||||
operator.mod, operator.floordiv, operator.truediv]
|
||||
for op in [op.add, op.mul, op.sub,
|
||||
op.mod, op.floordiv, op.truediv]
|
||||
for other in [
|
||||
2, np.int32(2), 2., np.float32(2),
|
||||
np.array(2, dtype=np.int32), np.arange(1, 5, dtype=np.int32),
|
||||
@ -1376,10 +1448,10 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
[True, False] if np.shape(other) == (7,) else [False]) # type: ignore
|
||||
for swap in [False, True] # The poly is the left op by default
|
||||
])
|
||||
def test_poly_binary_op(self, *, op=operator.truediv,
|
||||
other=2,
|
||||
def test_poly_binary_op(self, *, op=op.add,
|
||||
other=np.arange(2, dtype=np.int32),
|
||||
other_jnp_array=False,
|
||||
swap=False):
|
||||
swap=True):
|
||||
# Test arithmetic operations with poly and a variety of other operand types
|
||||
if config.jax2tf_default_experimental_native_lowering:
|
||||
raise unittest.SkipTest("TODO(necula): dim_as_value in native mode")
|
||||
@ -1392,15 +1464,14 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
# If the other op is an integer then the result is a symbolic dim
|
||||
try:
|
||||
operator.index(other)
|
||||
op.index(other)
|
||||
other_isint = True
|
||||
except Exception:
|
||||
other_isint = False
|
||||
|
||||
if (hasattr(poly, "dimension_as_value") and
|
||||
other_isint and
|
||||
op.__name__ != "truediv" and
|
||||
not (swap and op.__name__ in ["floordiv", "mod"])):
|
||||
op.__name__ != "truediv"):
|
||||
# If we running under jax2tf and "other" is an integer the result
|
||||
# should be a symbolic dimension
|
||||
self.assertTrue(isinstance(res, int) or hasattr(res, "dimension_as_value"))
|
||||
@ -1625,26 +1696,19 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
arg_descriptors=[RandArg((3, 4, 5), _f32)],
|
||||
poly_axes=[(0, 1)]),
|
||||
|
||||
# Issue #11402
|
||||
# We play a trick here. Since the stride is 2, when we compute the padding
|
||||
# for "SAME" we need to divide by 2. We cannot do this in general, so we
|
||||
# write the test with the assumption that the dimension is a multiple of 2.
|
||||
# We pass the lhs as (1, b, 2, 16) and then we
|
||||
# reshape it as (1, 2*b, 16), so that we know that the lhs's dimension 1
|
||||
# is a multiple of 2.
|
||||
PolyHarness("conv_general_dilated", "1d_1",
|
||||
PolyHarness("conv_general_dilated", "1d_stride=1",
|
||||
lambda lhs, rhs: lax.conv_general_dilated(
|
||||
jnp.reshape(lhs, (1, -1, 16)), rhs,
|
||||
window_strides=(2,),
|
||||
lhs, rhs,
|
||||
window_strides=(1,),
|
||||
padding="SAME",
|
||||
rhs_dilation=None,
|
||||
dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1),
|
||||
rhs_spec=(2, 1, 0),
|
||||
out_spec=(0, 2, 1))),
|
||||
arg_descriptors=[RandArg((1, 6, 2, 16), _f32), RandArg((4, 16, 16), _f32)],
|
||||
arg_descriptors=[RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
|
||||
poly_axes=[1, None]).both_enable_and_disable_xla(),
|
||||
# The same example from above, but without the reshape trick.
|
||||
PolyHarness("conv_general_dilated", "1d_1err",
|
||||
# The same example from above, but with stride=2.
|
||||
PolyHarness("conv_general_dilated", "1d_stride=2_even",
|
||||
lambda lhs, rhs: lax.conv_general_dilated(
|
||||
lhs, rhs,
|
||||
window_strides=(2,),
|
||||
@ -1655,9 +1719,20 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
out_spec=(0, 2, 1))),
|
||||
arg_descriptors=[RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
|
||||
poly_axes=[1, None],
|
||||
expect_error=(core.InconclusiveDimensionOperation,
|
||||
"Cannot divide .* by '2'")
|
||||
).both_enable_and_disable_xla(),
|
||||
# The same example from above, but with stride=2 and odd input size.
|
||||
PolyHarness("conv_general_dilated", "1d_stride=2_odd",
|
||||
lambda lhs, rhs: lax.conv_general_dilated(
|
||||
lhs, rhs,
|
||||
window_strides=(2,),
|
||||
padding="SAME",
|
||||
rhs_dilation=None,
|
||||
dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1),
|
||||
rhs_spec=(2, 1, 0),
|
||||
out_spec=(0, 2, 1))),
|
||||
arg_descriptors=[RandArg((1, 13, 16), _f32), RandArg((4, 16, 16), _f32)],
|
||||
poly_axes=[1, None],
|
||||
).both_enable_and_disable_xla(),
|
||||
# Issue #11402
|
||||
PolyHarness("conv_general_dilated", "1d_2",
|
||||
lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
|
||||
@ -1797,7 +1872,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
polymorphic_shapes=["(2, b0)", "(2, b1)"],
|
||||
input_signature=[tf.TensorSpec((2, None)), tf.TensorSpec((2, None))],
|
||||
expect_error=(core.InconclusiveDimensionOperation,
|
||||
"Dimension polynomial comparison 'b1' == 'b0' is inconclusive")),
|
||||
"Symbolic dimension comparison 'b1' == 'b0' is inconclusive")),
|
||||
PolyHarness("eye", "N=poly_M=None",
|
||||
lambda x: jnp.eye(x.shape[0]),
|
||||
arg_descriptors=[RandArg((3, 4), _f32)],
|
||||
@ -2045,7 +2120,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda x: jnp.repeat(x, repeats=x.shape[0], axis=None, total_repeat_length=8),
|
||||
arg_descriptors=[RandArg((3, 2), _f32)],
|
||||
poly_axes=[0],
|
||||
expect_error=(ValueError, "jnp.repeat with a DimPolynomial `repeats` is supported only .*")),
|
||||
expect_error=(ValueError, "jnp.repeat with a non-constant `repeats` is supported only .*")),
|
||||
PolyHarness("reshape", "0",
|
||||
lambda x: x.reshape([x.shape[0], -1]),
|
||||
arg_descriptors=[RandArg((3, 2, 3), _f32)],
|
||||
@ -2130,21 +2205,19 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda x: lax.slice_in_dim(x, 0, -1, stride=1, axis=0),
|
||||
arg_descriptors=[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0]),
|
||||
PolyHarness("jnp.split", "",
|
||||
lambda x: jnp.split(x, 2, axis=0),
|
||||
arg_descriptors=[RandArg((8, 5), _f32)],
|
||||
polymorphic_shapes=["2*b, ..."],
|
||||
input_signature=[tf.TensorSpec([None, 5], dtype=_f32)]),
|
||||
PolyHarness("jnp.array_split", "even",
|
||||
lambda x: jnp.array_split(x, 2, axis=0),
|
||||
arg_descriptors=[RandArg((8, 5), _f32)],
|
||||
polymorphic_shapes=["2*b, ..."],
|
||||
input_signature=[tf.TensorSpec([None, 5], dtype=_f32)]),
|
||||
PolyHarness("jnp.array_split", "odd",
|
||||
lambda x: jnp.array_split(x, 2, axis=0),
|
||||
arg_descriptors=[RandArg((9, 5), _f32)],
|
||||
polymorphic_shapes=["2*b + 1, ..."],
|
||||
input_signature=[tf.TensorSpec([None, 5], dtype=_f32)]),
|
||||
PolyHarness("slice_in_dim", "stride=2_even",
|
||||
lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=2, axis=0),
|
||||
arg_descriptors=[RandArg((12, 4), _f32)],
|
||||
poly_axes=[0]),
|
||||
PolyHarness("slice_in_dim", "stride=2_odd",
|
||||
lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=2, axis=0),
|
||||
arg_descriptors=[RandArg((13, 4), _f32)],
|
||||
poly_axes=[0]),
|
||||
# Not yet, the slice_in_dim does int(stride)
|
||||
# PolyHarness("slice_in_dim", "stride=sym",
|
||||
# lambda x: lax.slice_in_dim(x, 0, x.shape[0], stride=x.shape[0] // 4, axis=0),
|
||||
# arg_descriptors=[RandArg((13, 4), _f32)],
|
||||
# poly_axes=[0]),
|
||||
PolyHarness("squeeze", "axis=empty",
|
||||
jnp.squeeze,
|
||||
arg_descriptors=[RandArg((5,), _f32), StaticArg(())],
|
||||
|
Loading…
x
Reference in New Issue
Block a user