[shape_poly] Refactor shape_poly to be independent of TF.

Separates out into shape_poly_tf.py the TF-specific parts
of shape polymorphism (essentially using tf.shape to get
the actual shape, along with tf.{add,subtract,multiply} to
evaluate shape polynomials).
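
For context, a minimal sketch of the user-facing feature being refactored
(not itself part of this change): jax2tf.convert with a polymorphic batch
dimension "b".

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f(x):  # x: f32[b, 8]
  return jnp.sum(x, axis=0)

# "b" is a dimension variable; the converted function accepts any batch size.
f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, 8)"])
f_tf(tf.ones((3, 8), tf.float32))   # -> shape (8,)
f_tf(tf.ones((10, 8), tf.float32))  # -> shape (8,)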

PiperOrigin-RevId: 410529872
George Necula 2021-11-17 08:05:29 -08:00 committed by jax authors
parent 0f56838435
commit 72f5a3ca5c
4 changed files with 287 additions and 156 deletions

View File

@ -47,6 +47,7 @@ from jax.interpreters import xla
from jax._src.lib import xla_client
from . import shape_poly
from . import shape_poly_tf
from . import impl_no_xla
import numpy as np
@ -304,12 +305,19 @@ def convert(fun: Callable,
in_tree.children()[0],
polymorphic_shapes_))
def fix_tf1_shape(arg: TfVal) -> Sequence[Optional[int]]:
tf_arg_shape = np.shape(arg)
return tuple(d.value if isinstance(d, tf.compat.v1.Dimension) else d for d in tf_arg_shape)
args_shapes_flat = tuple(fix_tf1_shape(a) for a in args_flat)
args_dim_exprs_flat = tuple(shape_poly_tf.DimExprTfVal.for_arg(a) for a in args_flat)
# Construct the abstract values for the flat arguments, possibly based on
# the input shapes and the polymorphic_shapes if given. May create new shape
# variables. May cast the args_flat to JAX types, using JAX's interpretation
# of types of constants.
args_avals_flat, shapeenv = _args_to_avals_and_env(
args_flat, arg_dtypes_flat, polymorphic_shapes_flat)
args_avals_flat, shapeenv = shape_poly.args_avals_and_env(
args_shapes_flat, arg_dtypes_flat, polymorphic_shapes_flat,
args_dim_exprs_flat)
# This function may take pytrees of TfVals. We can only set
# tf.custom_gradient on functions that take a flat argument list.
@ -638,47 +646,6 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
_thread_local_state.constant_cache[const_key] = (val, tf_val)
return tf_val, jax_dtype
def _args_to_avals_and_env(
args: Sequence[TfVal],
arg_jax_dtypes: Sequence[DType],
polymorphic_shapes: Sequence[Optional[Union[str, PolyShape]]]) -> \
Tuple[Sequence[core.ShapedArray], shape_poly.ShapeEnv]:
"""Computes canonicalized args, abstract values and a dimension environment for arguments.
Args:
args: the arguments, TF inputs. Must be tf.Tensor or tf.Variable.
arg_dtypes: the inferred JAX dtypes for the args.
polymorphic_shapes: the polymorphic specifications for the arguments.
Returns: a tuple of: a sequence of abstract values corresponding to the
arguments, and a dimension variable environment.
"""
dim_equations: List[shape_poly.DimEquation] = []
def input_aval(arg: TfVal,
arg_jax_dtype: DType,
polymorphic_shape: Optional[str]) -> core.ShapedArray:
"""The abstract value for an input."""
arg_shape = np.shape(arg)
aval_shape = shape_poly.parse_spec(polymorphic_shape, arg_shape)
arg_tf_shape = tf.shape(arg)
for i, d in enumerate(aval_shape):
dim_size = arg_shape[i]
if isinstance(dim_size, tf.compat.v1.Dimension):
dim_size = dim_size.value
if not shape_poly.is_poly_dim(d):
assert d == dim_size
else:
dim_equations.append(shape_poly.DimEquation(
poly=d, tf_expr=arg_tf_shape[i])) # type: ignore
return core.ShapedArray(aval_shape, arg_jax_dtype)
avals = tuple(map(input_aval, args, arg_jax_dtypes, polymorphic_shapes)) # type: ignore
shapeenv = shape_poly.solve_dim_equations(dim_equations)
return avals, shapeenv
def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
assert all(map(lambda x: x is not None, shape)), (
@ -2589,6 +2556,11 @@ def _pjit_sharding_constraint(arg: TfVal, *,
tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint
def _dim_as_value_jax2tf(dim: shape_poly.DimSize):
dim_tf, = _eval_shape((dim,))
return dim_tf
tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf
def _register_checkpoint_pytrees():
"""Registers TF custom container types as pytrees."""
@ -2620,5 +2592,3 @@ def _register_checkpoint_pytrees():
_register_checkpoint_pytrees()
shape_poly._register_conversion_rules()

View File

@ -11,11 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shape polymorphism support for jax2tf.
"""Shape polymorphism support.
For usage instructions, read the jax2tf.convert docstring, and the
This was built initially for jax2tf, but it is now customizable to be
independent of TF. The idea is that we introduce a set of dimension variables
at the top-level of a `jit` function. They are introduced implicitly by way
of specifying for each dimension of each argument a dimension polynomial
in terms of some dimension variables. All dimension variables are assumed to
range over integers greater than or equal to 1.
Dimension polynomials overload some integer operations, such as
add, multiply, equality, etc. The JAX NumPy and LAX layers have been
adapted to handle shapes that contain dimension
polynomials. This enables many JAX programs to be traced with dimension
polynomials in some dimensions. A priority has been to enable the batch
dimension in neural network examples to be polymorphic.
The best documentation at the moment is in the
jax2tf.convert docstring, and the
[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
"""
import collections
import dataclasses
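
To make the overloaded arithmetic described above concrete, a small sketch
mirroring test_evaluate and test_poly_int_results in the test diff below:

from jax.experimental.jax2tf import shape_poly

# Parse a polymorphic spec into dimension polynomials for variables "a", "b".
a, b = shape_poly._parse_spec("a, b", (2, 3))
# Polynomials overload integer arithmetic and can be evaluated in an
# environment of concrete values.
assert (a * a - b).evaluate(dict(a=2, b=3)) == 1  # 2*2 - 3
# Expressions that cancel out simplify back to plain ints.
assert a + 2 - a == 2 and isinstance(a + 2 - a, int)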
@ -23,18 +37,15 @@ import itertools
import functools
import operator as op
import re
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union
import jax
from jax._src.numpy import lax_numpy
import opt_einsum
from jax import config
from jax import core
from . import jax2tf as jax2tf_internal
import numpy as np
import tensorflow as tf # type: ignore[import]
DimSize = core.DimSize
Shape = core.Shape
@ -62,7 +73,7 @@ class _DimMon(dict):
"""Represents a multivariate monomial, such as n^3 * m.
The representation is a dictionary mapping var:exponent.
The `var` are strings and the exponents are >= 1.
The `var` keys are strings and the exponents are integers >= 1.
The dimension variables are assumed to range over integers >= 1.
"""
def __hash__(self):
@ -124,21 +135,13 @@ class _DimMon(dict):
elif diff > 0: d[key] = diff
return _DimMon(d)
def evaluate(self, env: Dict[str, Any]) -> Any:
prod = lambda xs: functools.reduce(op.mul, xs) if xs else 1
def pow_opt(v, p):
return v if p == 1 else pow(v, p)
def evaluate(self, env: "ShapeEnv") -> Union["DimExpr", int]:
prod = lambda xs: functools.reduce(_multiply_dim_expr, xs) if xs else 1
def pow_opt(v: DimExpr, p: int) -> DimExpr:
return v if p == 1 else prod([v] * p)
return prod([pow_opt(env[id], deg) for id, deg in self.items()])
def _multiply(coeff, v: TfVal) -> TfVal:
try:
coeff = int(coeff)
except:
return coeff * v
else:
return 0 if coeff == 0 else v if coeff == 1 else coeff * v
class _DimPolynomial(dict):
"""Polynomial with integer coefficients for polymorphic shapes.
@ -375,12 +378,6 @@ class _DimPolynomial(dict):
ub = None if ub is None else ub + coeff
return lb, ub
def evaluate(self, env: Dict[str, Any]) -> Any:
def pow_opt(v, p):
return v if p == 1 else pow(v, p)
terms = [_multiply(coeff, mon.evaluate(env)) for mon, coeff in self.items()]
return sum(terms) if len(terms) > 1 else terms[0]
@property
def is_constant(self):
return len(self) == 1 and next(iter(self)).degree == 0
@ -390,6 +387,10 @@ class _DimPolynomial(dict):
"""Returns the highest degree term that comes first lexicographically."""
return max(self.items())
def evaluate(self, env: "ShapeEnv") -> Union["DimExpr", int]:
terms = [_multiply_dim_expr(mon.evaluate(env), coeff) for mon, coeff in self.items()]
return functools.reduce(_add_dim_expr, terms) if len(terms) > 1 else terms[0]
def _ensure_poly(p: DimSize) -> _DimPolynomial:
if isinstance(p, _DimPolynomial): return p
@ -504,15 +505,6 @@ dim_as_value_p.def_abstract_eval(_dim_as_value_abstract)
def _dim_as_value(dim: DimSize):
return dim_as_value_p.bind(dim=dim)
def _dim_as_value_jax2tf(dim: DimSize):
dim_tf, = jax2tf_internal._eval_shape((dim,))
assert dim_tf.dtype == tf.int32
return dim_tf
def _register_conversion_rules():
jax2tf_internal.tf_impl[dim_as_value_p] = _dim_as_value_jax2tf
class PolyShape(tuple):
"""Tuple of polymorphic dimension specifications.
@ -527,8 +519,8 @@ class PolyShape(tuple):
return tuple.__new__(PolyShape, dim_specs)
def parse_spec(spec: Optional[Union[str, PolyShape]],
arg_shape: Sequence[Optional[int]]) -> Tuple[DimSize, ...]:
def _parse_spec(spec: Optional[Union[str, PolyShape]],
arg_shape: Sequence[Optional[int]]) -> Tuple[DimSize, ...]:
"""Parse the shape polymorphic specification for one array argument.
Args:
spec: a shape polymorphic specification, either a string, or a PolyShape.
@ -612,8 +604,6 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
if isinstance(dim_spec, str):
dim_spec = dim_spec.strip()
dim_size = arg_shape[i]
if isinstance(dim_size, tf.compat.v1.Dimension):
dim_size = dim_size.value
if dim_size is None:
def need_dim_var_msg():
msg = (f"polymorphic shape {repr(spec)} in axis {i} must contain a dimension variable "
@ -644,22 +634,133 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
dims = tuple([_process_dim(i, ds) for i, ds in enumerate(spec_tuple)])
return dims
_Expr = TypeVar("_Expr")
class DimExpr:
"""Encapsulates an expression that denotes the value of a dimension size
for some intermediate value in the computation. It is computed based on the
input values for the top-level function arguments and the polymorphic
specifications.
This class is intended to be subclassed. An example of a DimExpr is
shape_poly_tf.DimExprTfVal, which computes dimension sizes using TF ops,
e.g., `tf.shape(arg0)[0] + 1`.
"""
def __init__(self, raw: _Expr):
self._raw = raw
@property
def raw(self) -> _Expr:
"""Returns the raw value."""
return self._raw
@classmethod
def for_arg(cls, arg: _Expr) -> Sequence["DimExpr"]:
"""A list of dimension expressions, one for each dimensionof arg."""
raise NotImplementedError
def is_known_constant(self) -> Optional[int]:
"""Extract the constant if possible.
Typically this should work when the input arguments have known shape, but
it may not work when we compile code that has dynamic shapes.
"""
raise NotImplementedError
def add(self, other: Union["DimExpr", int]) -> "DimExpr":
raise NotImplementedError
def subtract(self, other: Union["DimExpr", int]) -> "DimExpr":
raise NotImplementedError
def multiply(self, other: Union["DimExpr", int]) -> "DimExpr":
raise NotImplementedError
def divmod(self, factor: int) -> Tuple["DimExpr", "DimExpr"]:
"""Like Python divmod."""
raise NotImplementedError
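
To make the subclassing contract concrete, a hypothetical int-backed subclass
(DimExprConst is for illustration only, not part of this commit). It relies on
the module's own imports (np, Optional, Sequence, Tuple, Union) and shows that
nothing in the contract depends on TF:

class DimExprConst(DimExpr):
  """Hypothetical DimExpr backed by plain Python ints (illustration only)."""

  @classmethod
  def for_arg(cls, arg) -> Sequence["DimExprConst"]:
    # One expression per dimension of a concretely shaped argument.
    return tuple(DimExprConst(int(d)) for d in np.shape(arg))

  def is_known_constant(self) -> Optional[int]:
    return int(self.raw)  # an int is always a known constant

  def add(self, other: Union[DimExpr, int]) -> "DimExprConst":
    other_raw = other.raw if isinstance(other, DimExpr) else other
    return DimExprConst(self.raw + other_raw)

  def subtract(self, other: Union[DimExpr, int]) -> "DimExprConst":
    other_raw = other.raw if isinstance(other, DimExpr) else other
    return DimExprConst(self.raw - other_raw)

  def multiply(self, other: Union[DimExpr, int]) -> "DimExprConst":
    other_raw = other.raw if isinstance(other, DimExpr) else other
    return DimExprConst(self.raw * other_raw)

  def divmod(self, factor: int) -> Tuple["DimExprConst", "DimExprConst"]:
    q, r = divmod(self.raw, factor)
    return DimExprConst(q), DimExprConst(r)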
def _add_dim_expr(v1: Union["DimExpr", int],
v2: Union["DimExpr", int]) -> Union["DimExpr", int]:
if isinstance(v1, DimExpr):
return v1.add(v2) if v2 != 0 else v1
elif isinstance(v2, DimExpr):
return v2.add(v1) if v1 != 0 else v2
else:
return v1 + v2 # integers
def _multiply_dim_expr(v1: Union["DimExpr", int],
v2: Union[DimExpr, int]) -> Union[DimExpr, int]:
if isinstance(v1, DimExpr):
return v1.multiply(v2) if v2 != 1 else v1
elif isinstance(v2, DimExpr):
return v2.multiply(v1) if v1 != 1 else v2
else:
return v1 * v2 # integers
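
These helpers short-circuit identity elements, so no expression is built for
`+ 0` or `* 1`, and pure ints stay ints. Continuing the hypothetical
DimExprConst sketch above:

d = DimExprConst(5)
assert _add_dim_expr(d, 0) is d        # no new expression is built
assert _multiply_dim_expr(d, 1) is d
assert _multiply_dim_expr(3, 4) == 12  # two ints multiply as ints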
@dataclasses.dataclass
class DimEquation:
# Represents poly == tf_expr
# Represents poly == dim_expr
poly: _DimPolynomial
tf_expr: TfVal
dim_expr: DimExpr
# A dimension environment maps dimension variables to TF expressions that
# compute the value of the dimension. These expressions refer to the TF
# A dimension environment maps dimension variables to expressions that
# compute the value of the dimension. These expressions refer to the
# function arguments.
ShapeEnv = Dict[str, TfVal]
ShapeEnv = Dict[str, DimExpr]
DType = Any
def eval_shape(shape: Sequence[DimSize], shape_env: Dict[str, TfVal]) -> Sequence[TfVal]:
return tuple(d.evaluate(shape_env) if type(d) is _DimPolynomial else d # type: ignore
def eval_shape(shape: Sequence[DimSize], shape_env: ShapeEnv) -> Sequence[_Expr]:
def eval_dim(d: DimSize) -> Any:
d1 = d.evaluate(shape_env) # type: ignore[union-attr]
if isinstance(d1, int):
return d1
else:
return d1.raw
return tuple(eval_dim(d) if type(d) is _DimPolynomial else d # type: ignore
for d in shape)
def solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
def args_avals_and_env(
arg_shapes: Sequence[Sequence[Optional[int]]],
arg_jax_dtypes: Sequence[DType],
polymorphic_shapes: Sequence[Optional[Union[str, PolyShape]]],
arg_dim_exprs: Sequence[Sequence[DimExpr]]) -> \
Tuple[Sequence[core.ShapedArray], ShapeEnv]:
"""Computes abstract values and a dimension environment for arguments.
Args:
arg_shapes: the shapes for the arguments, possibly having None dimensions.
arg_jax_dtypes: the inferred JAX dtypes for the args.
polymorphic_shapes: the polymorphic specifications for the arguments.
arg_dim_exprs: for each argument, the expressions that represent each of its dimensions.
Returns: a tuple of: a sequence of abstract values corresponding to the
arguments, and a dimension variable environment.
"""
dim_equations: List[DimEquation] = []
def input_aval(arg_shape: Sequence[Optional[int]],
arg_jax_dtype: DType,
polymorphic_shape: Optional[str],
arg_dim_exprs: Sequence[DimExpr]) -> core.ShapedArray:
"""The abstract value for an input."""
aval_shape = _parse_spec(polymorphic_shape, arg_shape)
for i, d in enumerate(aval_shape):
if is_poly_dim(d):
dim_equations.append(DimEquation(
poly=d, dim_expr=arg_dim_exprs[i])) # type: ignore
return core.ShapedArray(aval_shape, arg_jax_dtype)
avals = tuple(map(input_aval, arg_shapes, arg_jax_dtypes, polymorphic_shapes, arg_dim_exprs)) # type: ignore
shapeenv = _solve_dim_equations(dim_equations)
return avals, shapeenv
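
A sketch of how a caller such as jax2tf uses this, mirroring the jax2tf.py and
test changes elsewhere in this diff (assumes TF eager mode and that
shape_poly_tf is importable):

import tensorflow as tf
from jax.experimental.jax2tf import shape_poly_tf

x = tf.ones((5, 3), tf.float32)
avals, shapeenv = args_avals_and_env(
    [(5, 3)], [np.float32], ["(b, 3)"],
    [shape_poly_tf.DimExprTfVal.for_arg(x)])
# avals[0].shape is (b, 3); the solver mapped "b" to tf.shape(x)[0].
b = avals[0].shape[0]
eval_shape((b + 1, 3), shapeenv)  # -> (a TF value computing b + 1, 3)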
def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
# Returns a shape environment if it can solve all dimension variables.
# Raises an exception if it cannot.
shapeenv: ShapeEnv = {}
@ -667,24 +768,24 @@ def solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
def _shapeenv_to_str() -> str:
if shapeenv:
return (" Partial solution: " +
", ".join([f"{var} = {val}" for var, val in shapeenv.items()]) + ".")
", ".join([f"{var} = {val.raw}" for var, val in shapeenv.items()]) + ".")
else:
return ""
def process_one_eqn(eqn: DimEquation) -> bool:
# Try to rewrite the equation as "var * factor_var = tf_expr" (a linear
# Try to rewrite the equation as "var * factor_var = dim_expr" (a linear
# uni-variate equation). Return False if this rewrite fails.
# Otherwise, add the variable to shapeenv and return True.
# The invariant is: var * factor_var + rest_eqn_poly = tf_expr
# The invariant is: var * factor_var + rest_eqn_poly = dim_expr
var, factor_var = None, None
tf_expr = eqn.tf_expr
dim_expr = eqn.dim_expr
for mon, factor in eqn.poly.items():
# Perhaps we can already evaluate this monomial (all vars solved)
try:
mon_value = mon.evaluate(shapeenv)
tf_expr = tf.math.subtract(tf_expr, _multiply(factor, mon_value))
dim_expr = dim_expr.subtract(_multiply_dim_expr(mon_value, factor))
continue
except KeyError:
# There are some indeterminate variables. We handle only the case of
@ -697,37 +798,33 @@ def solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
return False
if var is not None:
try:
var_value = tf.math.floordiv(tf_expr, factor_var) if factor_var != 1 else tf_expr
# Check that the division is even. Works only in eager mode.
if tf.math.floormod(tf_expr, factor_var).numpy() != 0:
msg = (f"Dimension variable {var} must have integer value >= 1. " # type: ignore
f"Found value {int(tf_expr.numpy()) / factor_var} when solving "
f"{eqn.poly} == {eqn.tf_expr}.{_shapeenv_to_str()}")
raise ValueError(msg)
if var_value.numpy() <= 0:
msg = (f"Dimension variable {var} must have integer value >= 1. "
f"Found value {int(var_value.numpy())} when solving "
f"{eqn.poly} == {eqn.tf_expr}.{_shapeenv_to_str()}")
raise ValueError(msg)
except AttributeError:
var_value, var_remainder = dim_expr.divmod(factor_var) # type: ignore
# Check that the division is even. Works only in eager mode.
var_remainder_int = var_remainder.is_known_constant()
if var_remainder_int is not None and var_remainder_int != 0:
# TODO(necula): check even in graph mode, by embedding the checks in
# the graph.
pass
msg = (f"Dimension variable {var} must have integer value >= 1. " # type: ignore
f"Found value {int(dim_expr.is_known_constant()) / factor_var} when solving "
f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
raise ValueError(msg)
var_value_int = var_value.is_known_constant()
if var_value_int is not None and var_value_int <= 0:
msg = (f"Dimension variable {var} must have integer value >= 1. "
f"Found value {int(var_value_int)} when solving "
f"{eqn.poly} == {eqn.dim_expr.raw}.{_shapeenv_to_str()}")
raise ValueError(msg)
shapeenv[var] = var_value
return True
else:
# All variables are resolved for this equation
try:
if tf_expr.numpy() != 0:
err_msg = (
"Found inconsistency when solving "
f"{eqn.poly} == {eqn.tf_expr}.{_shapeenv_to_str()}")
raise ValueError(err_msg)
except AttributeError:
# TODO(necula): check that the equation is satisfied even in graph
# mode.
pass
dim_expr_int = dim_expr.is_known_constant()
if dim_expr_int is not None and dim_expr_int != 0:
err_msg = (
"Found inconsistency when solving "
f"{eqn.poly} == {eqn.dim_expr.raw}.{_shapeenv_to_str()}")
raise ValueError(err_msg)
return True
while True:

View File

@ -0,0 +1,65 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specialized shape polymorphism support for jax2tf.
See the shape_poly.py module documentation, the jax2tf.convert docstring, and the
[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
"""
from typing import Any, Optional, Sequence, Tuple, Union
import numpy as np
import tensorflow as tf # type: ignore[import]
from . import shape_poly
TfVal = Any
class DimExprTfVal(shape_poly.DimExpr):
"""Express dimensions using tf.shape and the TF arguments."""
def __init__(self, tfval: TfVal):
super(DimExprTfVal, self).__init__(tfval)
@classmethod
def for_arg(cls, arg: TfVal) -> Sequence[shape_poly.DimExpr]:
tf_shape = tf.shape(arg)
return tuple(DimExprTfVal(tf_shape[i]) for i in range(len(np.shape(arg))))
def is_known_constant(self) -> Optional[int]:
# When under TF eager, the dimension expressions should be constants.
# Under TF graph, they will not be.
try:
return self.raw.numpy()
except AttributeError as e:
assert str(e).find("numpy") > 0, e
return None
def add(self, other: Union[shape_poly.DimExpr, int]) -> shape_poly.DimExpr:
if isinstance(other, shape_poly.DimExpr):
other = other.raw # type: ignore[assignment]
return DimExprTfVal(tf.math.add(self.raw, other))
def subtract(self, other: Union[shape_poly.DimExpr, int]) -> shape_poly.DimExpr:
if isinstance(other, shape_poly.DimExpr):
other = other.raw # type: ignore[assignment]
return DimExprTfVal(tf.math.subtract(self.raw, other))
def multiply(self, other: Union[shape_poly.DimExpr, int]) -> shape_poly.DimExpr:
if isinstance(other, shape_poly.DimExpr):
other = other.raw # type: ignore[assignment]
return DimExprTfVal(tf.math.multiply(self.raw, other))
def divmod(self, factor: int) -> Tuple[shape_poly.DimExpr, shape_poly.DimExpr]:
dividend = DimExprTfVal(tf.math.floordiv(self.raw, factor)) if factor != 1 else self
mod = DimExprTfVal(tf.math.floormod(self.raw, factor))
return dividend, mod
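
A small usage sketch, in TF eager mode:

x = tf.ones((5, 3))
d0, d1 = DimExprTfVal.for_arg(x)      # wraps tf.shape(x)[0] and tf.shape(x)[1]
assert d0.is_known_constant() == 5    # dimension values are known under eager
d2 = d0.add(1).multiply(2)            # builds TF ops: (5 + 1) * 2
assert d2.is_known_constant() == 12
# Under tf.function (graph mode), is_known_constant() returns None instead.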

View File

@ -26,6 +26,7 @@ import jax
from jax import core
from jax.experimental import jax2tf
from jax.experimental.jax2tf import shape_poly
from jax.experimental.jax2tf import shape_poly_tf
from jax import lax
import jax.numpy as jnp
from jax._src import test_util as jtu
@ -51,21 +52,19 @@ PS = jax2tf.PolyShape
class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
def test_parse_poly_spec(self):
self.assertEqual((2, 3), shape_poly.parse_spec(None, (2, 3)))
self.assertEqual((2, 3), shape_poly.parse_spec("2, 3", (2, 3)))
self.assertEqual((2, 3), shape_poly.parse_spec("2, _", (2, 3)))
self.assertEqual((2, 3), shape_poly.parse_spec("2, ...", (2, 3)))
self.assertEqual((2, 3), shape_poly.parse_spec("...", (2, 3)))
self.assertEqual((2, 3), shape_poly.parse_spec(" ( 2 , 3 ) ", (2, 3)))
self.assertEqual((2, 3), shape_poly._parse_spec(None, (2, 3)))
self.assertEqual((2, 3), shape_poly._parse_spec("2, 3", (2, 3)))
self.assertEqual((2, 3), shape_poly._parse_spec("2, _", (2, 3)))
self.assertEqual((2, 3), shape_poly._parse_spec("2, ...", (2, 3)))
self.assertEqual((2, 3), shape_poly._parse_spec("...", (2, 3)))
self.assertEqual((2, 3), shape_poly._parse_spec(" ( 2 , 3 ) ", (2, 3)))
a, b = shape_poly.parse_spec("a, b", (2, 3))
self.assertEqual((a, 3), shape_poly.parse_spec("(a, ...) ", (None, 3)))
a, b = shape_poly._parse_spec("a, b", (2, 3))
self.assertEqual((a, 3), shape_poly._parse_spec("(a, ...) ", (None, 3)))
tshape = tf.TensorShape([None, 3])
self.assertEqual((a, 3), shape_poly.parse_spec("(a, ...) ", tshape))
# Test directly with tf.compat.v1.Dimension
self.assertEqual((a, 3), shape_poly.parse_spec("(a, ...) ", tshape.dims))
self.assertEqual((a, 3), shape_poly._parse_spec("(a, ...) ", tshape))
a, b = shape_poly.parse_spec("a, b", (2, 3))
a, b = shape_poly._parse_spec("a, b", (2, 3))
@parameterized.named_parameters(
dict(testcase_name=f"_dim_spec={dim_spec}",
dim_spec=dim_spec, dim_poly=dim_poly)
@ -82,8 +81,8 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
dim_poly=3 * a * b * a - 2):
# For internal usage only (the polymorphic_shapes of VJP) we need to
# parse polynomials.
self.assertEqual((dim_poly,), shape_poly.parse_spec(dim_spec, (2,)))
self.assertEqual((dim_poly,), shape_poly.parse_spec(str(dim_poly), (2,)))
self.assertEqual((dim_poly,), shape_poly._parse_spec(dim_spec, (2,)))
self.assertEqual((dim_poly,), shape_poly._parse_spec(str(dim_poly), (2,)))
@parameterized.named_parameters(
dict(testcase_name=f"_dim_spec={dim_spec}",
@ -101,11 +100,11 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
dim_poly=3 * a * b * a - 2):
# For internal usage only (the polymorphic_shapes of VJP) we need to
# parse polynomials.
self.assertEqual((dim_poly,), shape_poly.parse_spec(dim_spec, (2,)))
self.assertEqual((dim_poly,), shape_poly.parse_spec(str(dim_poly), (2,)))
self.assertEqual((dim_poly,), shape_poly._parse_spec(dim_spec, (2,)))
self.assertEqual((dim_poly,), shape_poly._parse_spec(str(dim_poly), (2,)))
def test_dim_vars(self):
a, b, a1 = shape_poly.parse_spec("a, b, a", (2, 3, 2))
a, b, a1 = shape_poly._parse_spec("a, b, a", (2, 3, 2))
self.assertEqual(True, a == a)
self.assertEqual(True, a == a1)
self.assertEqual(False, a != a)
@ -134,19 +133,19 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
b in [a, b]
def test_get_vars(self):
a, b = shape_poly.parse_spec("a, b", (2, 3))
a, b = shape_poly._parse_spec("a, b", (2, 3))
self.assertEqual({"a"}, a.get_vars())
self.assertEqual({"a", "b"}, (a * b * a).get_vars())
def test_evaluate(self):
a, b = shape_poly.parse_spec("a, b", (2, 3))
a, b = shape_poly._parse_spec("a, b", (2, 3))
self.assertEqual(1, (a * a - b).evaluate(dict(a=2, b=3)))
self.assertEqual(2, (a * a - b + 1).evaluate(dict(a=-2, b=3)))
def test_dim_vars_symbolic_equal(self):
a, b = shape_poly.parse_spec("a, b", (2, 3))
a, b = shape_poly._parse_spec("a, b", (2, 3))
self.assertTrue(core.symbolic_equal_dim(a, a))
self.assertFalse(core.symbolic_equal_dim(a, 1))
self.assertFalse(core.symbolic_equal_dim(a, b))
@ -165,7 +164,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
self.assertTrue(core.symbolic_equal_dim(1, "a"))
def test_poly_bounds(self):
a, b = shape_poly.parse_spec("a, b", (2, 3))
a, b = shape_poly._parse_spec("a, b", (2, 3))
self.assertEqual(a.bounds(), (1, None))
self.assertEqual((2 * a).bounds(), (2, None))
self.assertEqual((2 * a - 3).bounds(), (-1, None))
@ -176,7 +175,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
self.assertEqual((a + 2 * b - a).bounds(), (2, None))
def test_poly_equal(self):
a, b = shape_poly.parse_spec("a, b", (2, 3))
a, b = shape_poly._parse_spec("a, b", (2, 3))
poly3 = a + 3 - a
self.assertTrue(poly3 == 3)
self.assertTrue(poly3 == np.array(3, np.int64))
@ -195,7 +194,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
(3 * a * b * a - 2).eq(a * b * a)
def test_poly_compare(self):
a, b = shape_poly.parse_spec("a, b", (2, 3))
a, b = shape_poly._parse_spec("a, b", (2, 3))
poly = 4 * a + b + 3
self.assertTrue(poly.ge(0))
self.assertTrue(poly.ge(8))
@ -209,7 +208,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
(4 * a - b).ge(0)
def test_poly_compare_overload(self):
a, b = shape_poly.parse_spec("a, b", (2, 3))
a, b = shape_poly._parse_spec("a, b", (2, 3))
poly = 4 * a + b + 3
self.assertTrue(poly >= 0)
self.assertTrue(poly >= 8)
@ -224,7 +223,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
(4 * a - b) >= 0
def test_core_greater_equal(self):
a, b = shape_poly.parse_spec("a, b", (2, 3))
a, b = shape_poly._parse_spec("a, b", (2, 3))
self.assertTrue(core.greater_equal_dim(a, a))
self.assertTrue(core.greater_equal_dim(a, 0))
self.assertTrue(core.greater_equal_dim(a, 1))
@ -240,7 +239,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
core.greater_equal_dim(a, b)
def test_poly_int_results(self):
a, b = shape_poly.parse_spec("a, b", (2, 3))
a, b = shape_poly._parse_spec("a, b", (2, 3))
self.assertEqual(a + 2 - a, 2)
self.assertIsInstance(a + 2 - a, int)
self.assertEqual(a + (2 - a), 2)
@ -299,14 +298,14 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
self.assertEqual(quotient, dividend / divisor)
def test_poly_truediv_error(self):
a, = shape_poly.parse_spec("a,", (2,))
a, = shape_poly._parse_spec("a,", (2,))
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
"Division of '3' by dimension polynomial .* is not supported"):
3 / a
def test_dilate_shape(self):
"""0 if d == 0 else 1 + dilation * (d - 1))"""
a, = shape_poly.parse_spec("a,", (2,))
a, = shape_poly._parse_spec("a,", (2,))
self.assertEqual((4, 7), core.dilate_shape((2, 3), (3, 3)))
self.assertEqual((0, 7), core.dilate_shape((0, 3), (3, 3)))
@ -315,7 +314,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
def test_stride_shape(self):
"""(s - window_size) // window_stride + 1"""
a, stride = shape_poly.parse_spec("a, s", (2, 3))
a, stride = shape_poly._parse_spec("a, s", (2, 3))
self.assertEqual((8, 9), core.stride_shape((10, 20), window_size=(3, 3), window_stride=(1, 2)))
self.assertEqual((a, 9), core.stride_shape((a, 20), (1, 3), (1, 2)))
@ -418,11 +417,12 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
# check expected_shapeenv.
arg_dtypes = (_f32,) * len(arg_shapes)
def f_tf(*tf_args):
avals, shape_env = jax2tf.jax2tf._args_to_avals_and_env(
tf_args, arg_dtypes, polymorphic_shapes) # The function under test
avals, shape_env = shape_poly.args_avals_and_env(
arg_shapes, arg_dtypes, polymorphic_shapes,
tuple(shape_poly_tf.DimExprTfVal.for_arg(a) for a in tf_args)) # The function under test
if expected_avals is not None:
self.assertEqual(expected_avals, avals)
return shape_env
return {k: d.raw for k, d in shape_env.items()}
if eager_mode:
# If we want to check the shape_env then all arg_shapes must be known
assert all(all(d is not None for d in a_s)
@ -439,7 +439,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
def shaped_array(shape_spec: str, actual_shape: core.Shape):
return core.ShapedArray(
shape_poly.parse_spec(shape_spec, actual_shape), np.float32)
shape_poly._parse_spec(shape_spec, actual_shape), np.float32)
# Known shapes for the arguments
check_avals(
@ -649,7 +649,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
polymorphic_shapes=["a, ...", "a"],
eager_mode=True)
def test_pytree(self):
"""Arguments and polymorphic_shapes are pytrees."""