[shape_poly] Refactor shape_poly to be independent of TF.
Separates out into shape_poly_tf.py the TF-specific parts of shape polymorphism (essentially using tf.shape to get the actual shape, along with tf.{add,subtract,multiply} to evaluate shape polynomials).

PiperOrigin-RevId: 410529872
This commit is contained in:
parent 0f56838435
commit 72f5a3ca5c
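The core move of the commit is a new backend-neutral DimExpr interface in shape_poly.py (see the class added below): shape_poly.py now manipulates dimension values only through `add`, `subtract`, `multiply`, `divmod`, and `is_known_constant`, while shape_poly_tf.py supplies the TF implementation (`DimExprTfVal`). As a rough, hypothetical sketch of what the decoupling buys -- this class is NOT part of the commit -- a backend over plain Python integers would only need the same handful of methods:

# Hypothetical illustration only; not code from this commit. It shows that
# shape_poly's solver can run on any backend implementing the DimExpr API.
from typing import Optional, Sequence, Tuple, Union

from jax.experimental.jax2tf import shape_poly

class DimExprConst(shape_poly.DimExpr):
  """Dimension expressions backed by concrete Python ints."""

  @classmethod
  def for_arg(cls, arg) -> Sequence[shape_poly.DimExpr]:
    # One expression per dimension of a concretely-shaped array.
    return tuple(DimExprConst(d) for d in arg.shape)

  def is_known_constant(self) -> Optional[int]:
    return self.raw  # always known: the raw value is an int

  def add(self, other: Union[shape_poly.DimExpr, int]) -> shape_poly.DimExpr:
    other_raw = other.raw if isinstance(other, shape_poly.DimExpr) else other
    return DimExprConst(self.raw + other_raw)

  def subtract(self, other: Union[shape_poly.DimExpr, int]) -> shape_poly.DimExpr:
    other_raw = other.raw if isinstance(other, shape_poly.DimExpr) else other
    return DimExprConst(self.raw - other_raw)

  def multiply(self, other: Union[shape_poly.DimExpr, int]) -> shape_poly.DimExpr:
    other_raw = other.raw if isinstance(other, shape_poly.DimExpr) else other
    return DimExprConst(self.raw * other_raw)

  def divmod(self, factor: int) -> Tuple[shape_poly.DimExpr, shape_poly.DimExpr]:
    q, r = divmod(self.raw, factor)
    return DimExprConst(q), DimExprConst(r)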
jax/experimental/jax2tf/jax2tf.py
@@ -47,6 +47,7 @@ from jax.interpreters import xla
 from jax._src.lib import xla_client

 from . import shape_poly
+from . import shape_poly_tf
 from . import impl_no_xla

 import numpy as np
@@ -304,12 +305,19 @@ def convert(fun: Callable,
         in_tree.children()[0],
         polymorphic_shapes_))

+    def fix_tf1_shape(arg: TfVal) -> Sequence[Optional[int]]:
+      tf_arg_shape = np.shape(arg)
+      return tuple(d.value if isinstance(d, tf.compat.v1.Dimension) else d for d in tf_arg_shape)
+    args_shapes_flat = tuple(fix_tf1_shape(a) for a in args_flat)
+
+    args_dim_exprs_flat = tuple(shape_poly_tf.DimExprTfVal.for_arg(a) for a in args_flat)
     # Construct the abstract values for the flat arguments, possibly based on
     # the input shapes and the polymorphic_shapes if given. May create new shape
     # variables. May cast the args_flat to JAX types, using JAX's interpretation
     # of types of constants.
-    args_avals_flat, shapeenv = _args_to_avals_and_env(
-        args_flat, arg_dtypes_flat, polymorphic_shapes_flat)
+    args_avals_flat, shapeenv = shape_poly.args_avals_and_env(
+        args_shapes_flat, arg_dtypes_flat, polymorphic_shapes_flat,
+        args_dim_exprs_flat)

     # This function may take pytrees of TfVals. We can only set
     # tf.custom_gradient on functions that take a flat argument list.
@@ -638,47 +646,6 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
       _thread_local_state.constant_cache[const_key] = (val, tf_val)
   return tf_val, jax_dtype


-def _args_to_avals_and_env(
-    args: Sequence[TfVal],
-    arg_jax_dtypes: Sequence[DType],
-    polymorphic_shapes: Sequence[Optional[Union[str, PolyShape]]]) -> \
-    Tuple[Sequence[core.ShapedArray], shape_poly.ShapeEnv]:
-  """Computes canonicalized args, abstract values and a dimension environment for arguments.
-
-  Args:
-    args: the arguments, TF inputs. Must be tf.Tensor or tf.Variable.
-    arg_dtypes: the inferred JAX dtypes for the args.
-    polymorphic_shapes: the polymorphic specifications for the arguments.
-  Returns: a tuple of: a sequence of abstract values corresponding to the
-    arguments, and a dimension variable environment.
-  """
-  dim_equations: List[shape_poly.DimEquation] = []
-
-  def input_aval(arg: TfVal,
-                 arg_jax_dtype: DType,
-                 polymorphic_shape: Optional[str]) -> core.ShapedArray:
-    """The abstract value for an input."""
-    arg_shape = np.shape(arg)
-    aval_shape = shape_poly.parse_spec(polymorphic_shape, arg_shape)
-    arg_tf_shape = tf.shape(arg)
-    for i, d in enumerate(aval_shape):
-      dim_size = arg_shape[i]
-      if isinstance(dim_size, tf.compat.v1.Dimension):
-        dim_size = dim_size.value
-      if not shape_poly.is_poly_dim(d):
-        assert d == dim_size
-      else:
-        dim_equations.append(shape_poly.DimEquation(
-            poly=d, tf_expr=arg_tf_shape[i]))  # type: ignore
-
-
-    return core.ShapedArray(aval_shape, arg_jax_dtype)
-
-  avals = tuple(map(input_aval, args, arg_jax_dtypes, polymorphic_shapes))  # type: ignore
-
-  shapeenv = shape_poly.solve_dim_equations(dim_equations)
-  return avals, shapeenv


 def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
   assert all(map(lambda x: x is not None, shape)), (
@@ -2589,6 +2556,11 @@ def _pjit_sharding_constraint(arg: TfVal, *,

 tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint

+def _dim_as_value_jax2tf(dim: shape_poly.DimSize):
+  dim_tf, = _eval_shape((dim,))
+  return dim_tf
+
+tf_impl[shape_poly.dim_as_value_p] = _dim_as_value_jax2tf

 def _register_checkpoint_pytrees():
   """Registers TF custom container types as pytrees."""
@@ -2620,5 +2592,3 @@ def _register_checkpoint_pytrees():


 _register_checkpoint_pytrees()
-
-shape_poly._register_conversion_rules()
jax/experimental/jax2tf/shape_poly.py
@@ -11,11 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Shape polymorphism support for jax2tf.
+"""Shape polymorphism support.

-For usage instructions, read the jax2tf.convert docstring, and the
+This was built initially for jax2tf, but it is now customizable to be
+independent of TF. The idea is that we introduce a set of dimension variables
+at the top-level of a `jit` function. They are introduced implicitly by way
+of specifying for each dimension of each argument a dimension polynomial
+in terms of some dimension variables. All dimension variables are assumed to
+range over integers greater than or equal to 1.
+
+Dimension polynomials overload some integer operations, such as
+add, multiply, equality, etc. The JAX NumPy layer and the LAX layers have been
+touched up to be sensitive to handling shapes that contain dimension
+polynomials. This enables many JAX programs to be traced with dimension
+polynomials in some dimensions. A priority has been to enable the batch
+dimension in neural network examples to be polymorphic.
+
+The best documentation at the moment is in the
+jax2tf.convert docstring, and the
 [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).

 """
 import collections
 import dataclasses
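As a concrete illustration of the docstring above (an editorial sketch, mirroring the shape_poly_test.py updates later in this diff; not part of the commit itself):

# Editorial sketch; the calls mirror shape_poly_test.py below.
from jax.experimental.jax2tf import shape_poly

# Introduce dimension variables a and b; (2, 3) is the concrete shape used
# to cross-check any non-variable dimensions in the spec.
a, b = shape_poly._parse_spec("a, b", (2, 3))

poly = 3 * a * b * a - 2                          # a dimension polynomial
assert a + 2 - a == 2                             # arithmetic can collapse to a plain int
assert (a * a - b).evaluate(dict(a=2, b=3)) == 1  # evaluate in an environment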
@@ -23,18 +37,15 @@ import itertools
 import functools
 import operator as op
 import re
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union

 import jax
 from jax._src.numpy import lax_numpy
 import opt_einsum
 from jax import config
 from jax import core
-from . import jax2tf as jax2tf_internal

 import numpy as np
-import tensorflow as tf  # type: ignore[import]

 DimSize = core.DimSize
 Shape = core.Shape
@@ -62,7 +73,7 @@ class _DimMon(dict):
   """Represents a multivariate monomial, such as n^3 * m.

   The representation is a dictionary mapping var:exponent.
-  The `var` are strings and the exponents are >= 1.
+  The `var` are strings and the exponents are integers >= 1.
   The dimension variables are assumed to range over integers >= 1.
   """
   def __hash__(self):
@@ -124,21 +135,13 @@ class _DimMon(dict):
     elif diff > 0: d[key] = diff
     return _DimMon(d)

-  def evaluate(self, env: Dict[str, Any]) -> Any:
-    prod = lambda xs: functools.reduce(op.mul, xs) if xs else 1
-    def pow_opt(v, p):
-      return v if p == 1 else pow(v, p)
+  def evaluate(self, env: "ShapeEnv") -> Union["DimExpr", int]:
+    prod = lambda xs: functools.reduce(_multiply_dim_expr, xs) if xs else 1
+    def pow_opt(v: DimExpr, p: int) -> DimExpr:
+      return v if p == 1 else prod([v] * p)
     return prod([pow_opt(env[id], deg) for id, deg in self.items()])


-def _multiply(coeff, v: TfVal) -> TfVal:
-  try:
-    coeff = int(coeff)
-  except:
-    return coeff * v
-  else:
-    return 0 if coeff == 0 else v if coeff == 1 else coeff * v
-
 class _DimPolynomial(dict):
   """Polynomial with integer coefficients for polymorphic shapes.

@@ -375,12 +378,6 @@ class _DimPolynomial(dict):
     ub = None if ub is None else ub + coeff
     return lb, ub

-  def evaluate(self, env: Dict[str, Any]) -> Any:
-    def pow_opt(v, p):
-      return v if p == 1 else pow(v, p)
-    terms = [_multiply(coeff, mon.evaluate(env)) for mon, coeff in self.items()]
-    return sum(terms) if len(terms) > 1 else terms[0]
-
   @property
   def is_constant(self):
     return len(self) == 1 and next(iter(self)).degree == 0
@@ -390,6 +387,10 @@
     """Returns the highest degree term that comes first lexicographically."""
     return max(self.items())

+  def evaluate(self, env: "ShapeEnv") -> Union["DimExpr", int]:
+    terms = [_multiply_dim_expr(mon.evaluate(env), coeff) for mon, coeff in self.items()]
+    return functools.reduce(_add_dim_expr, terms) if len(terms) > 1 else terms[0]
+

 def _ensure_poly(p: DimSize) -> _DimPolynomial:
   if isinstance(p, _DimPolynomial): return p
@@ -504,15 +505,6 @@ dim_as_value_p.def_abstract_eval(_dim_as_value_abstract)
 def _dim_as_value(dim: DimSize):
   return dim_as_value_p.bind(dim=dim)

-def _dim_as_value_jax2tf(dim: DimSize):
-  dim_tf, = jax2tf_internal._eval_shape((dim,))
-  assert dim_tf.dtype == tf.int32
-  return dim_tf
-
-def _register_conversion_rules():
-  jax2tf_internal.tf_impl[dim_as_value_p] = _dim_as_value_jax2tf
-

 class PolyShape(tuple):
   """Tuple of polymorphic dimension specifications.

@@ -527,8 +519,8 @@ class PolyShape(tuple):
     return tuple.__new__(PolyShape, dim_specs)


-def parse_spec(spec: Optional[Union[str, PolyShape]],
-               arg_shape: Sequence[Optional[int]]) -> Tuple[DimSize, ...]:
+def _parse_spec(spec: Optional[Union[str, PolyShape]],
+                arg_shape: Sequence[Optional[int]]) -> Tuple[DimSize, ...]:
   """Parse the shape polymorphic specification for one array argument.
   Args:
     spec: a shape polymorphic specification, either a string, or a PolyShape.
@@ -612,8 +604,6 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
     if isinstance(dim_spec, str):
       dim_spec = dim_spec.strip()
     dim_size = arg_shape[i]
-    if isinstance(dim_size, tf.compat.v1.Dimension):
-      dim_size = dim_size.value
     if dim_size is None:
       def need_dim_var_msg():
         msg = (f"polymorphic shape {repr(spec)} in axis {i} must contain a dimension variable "
@@ -644,22 +634,133 @@ def parse_spec(spec: Optional[Union[str, PolyShape]],
   dims = tuple([_process_dim(i, ds) for i, ds in enumerate(spec_tuple)])
   return dims


+_Expr = TypeVar("_Expr")
+
+class DimExpr:
+  """Encapsulates an expression that denotes the value of a dimension size
+  for some intermediate value in the computation. It is computed based on the
+  input values for the top-level function arguments and the polymorphic
+  specifications.
+
+  This class is intended to be subclassed. An example of DimExpr is
+  shape_poly_tf.DimExprTfVal, which computes dimension sizes using TF ops,
+  e.g., `tf.shape(arg0)[0] + 1`.
+  """
+  def __init__(self, raw: _Expr):
+    self._raw = raw
+
+  @property
+  def raw(self) -> _Expr:
+    """Returns the raw value."""
+    return self._raw
+
+  @classmethod
+  def for_arg(cls, arg: _Expr) -> Sequence["DimExpr"]:
+    """A list of dimension expressions, one for each dimension of arg."""
+    raise NotImplementedError
+
+  def is_known_constant(self) -> Optional[int]:
+    """Extract the constant if possible.
+
+    Typically this should work when the input arguments have known shape, but
+    it may not work when we compile code that has dynamic shapes.
+    """
+    raise NotImplementedError
+
+  def add(self, other: Union["DimExpr", int]) -> "DimExpr":
+    raise NotImplementedError
+
+  def subtract(self, other: Union["DimExpr", int]) -> "DimExpr":
+    raise NotImplementedError
+
+  def multiply(self, other: Union["DimExpr", int]) -> "DimExpr":
+    raise NotImplementedError
+
+  def divmod(self, factor: int) -> Tuple["DimExpr", "DimExpr"]:
+    """Like Python divmod."""
+    raise NotImplementedError
+
+
+def _add_dim_expr(v1: Union["DimExpr", int],
+                  v2: Union["DimExpr", int]) -> Union["DimExpr", int]:
+  if isinstance(v1, DimExpr):
+    return v1.add(v2) if v2 != 0 else v1
+  elif isinstance(v2, DimExpr):
+    return v2.add(v1) if v1 != 0 else v2
+  else:
+    return v1 + v2  # integers
+
+
+def _multiply_dim_expr(v1: Union["DimExpr", int],
+                       v2: Union[DimExpr, int]) -> Union[DimExpr, int]:
+  if isinstance(v1, DimExpr):
+    return v1.multiply(v2) if v2 != 1 else v1
+  elif isinstance(v2, DimExpr):
+    return v2.multiply(v1) if v1 != 1 else v2
+  else:
+    return v1 * v2  # integers
+
+
 @dataclasses.dataclass
 class DimEquation:
-  # Represents poly == tf_expr
+  # Represents poly == _expr
   poly: _DimPolynomial
-  tf_expr: TfVal
+  dim_expr: DimExpr

-# A dimension environment maps dimension variables to TF expressions that
-# compute the value of the dimension. These expressions refer to the TF
+# A dimension environment maps dimension variables to expressions that
+# compute the value of the dimension. These expressions refer to the
 # function arguments.
-ShapeEnv = Dict[str, TfVal]
+ShapeEnv = Dict[str, DimExpr]
 DType = Any

-def eval_shape(shape: Sequence[DimSize], shape_env: Dict[str, TfVal]) -> Sequence[TfVal]:
-  return tuple(d.evaluate(shape_env) if type(d) is _DimPolynomial else d  # type: ignore
+def eval_shape(shape: Sequence[DimSize], shape_env: ShapeEnv) -> Sequence[_Expr]:
+  def eval_dim(d: DimSize) -> Any:
+    d1 = d.evaluate(shape_env)  # type: ignore[union-attr]
+    if isinstance(d1, int):
+      return d1
+    else:
+      return d1.raw
+  return tuple(eval_dim(d) if type(d) is _DimPolynomial else d  # type: ignore
                for d in shape)

-def solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
+def args_avals_and_env(
+    arg_shapes: Sequence[Sequence[Optional[int]]],
+    arg_jax_dtypes: Sequence[DType],
+    polymorphic_shapes: Sequence[Optional[Union[str, PolyShape]]],
+    arg_dim_exprs: Sequence[Sequence[DimExpr]]) -> \
+    Tuple[Sequence[core.ShapedArray], ShapeEnv]:
+  """Computes abstract values and a dimension environment for arguments.
+
+  Args:
+    arg_shapes: the shapes for the arguments, possibly having None dimensions.
+    arg_dtypes: the inferred JAX dtypes for the args.
+    polymorphic_shapes: the polymorphic specifications for the arguments.
+    arg_dim_exprs: an expression that represents each dimension of each argument.
+  Returns: a tuple of: a sequence of abstract values corresponding to the
+    arguments, and a dimension variable environment.
+  """
+  dim_equations: List[DimEquation] = []
+
+  def input_aval(arg_shape: Sequence[Optional[int]],
+                 arg_jax_dtype: DType,
+                 polymorphic_shape: Optional[str],
+                 arg_dim_exprs: Sequence[DimExpr]) -> core.ShapedArray:
+    """The abstract value for an input."""
+    aval_shape = _parse_spec(polymorphic_shape, arg_shape)
+    for i, d in enumerate(aval_shape):
+      if is_poly_dim(d):
+        dim_equations.append(DimEquation(
+            poly=d, dim_expr=arg_dim_exprs[i]))  # type: ignore
+
+    return core.ShapedArray(aval_shape, arg_jax_dtype)
+
+  avals = tuple(map(input_aval, arg_shapes, arg_jax_dtypes, polymorphic_shapes, arg_dim_exprs))  # type: ignore
+
+  shapeenv = _solve_dim_equations(dim_equations)
+  return avals, shapeenv
+
+def _solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
   # Returns a shape environment if it can solve all dimension variables.
   # Raises an exception if it cannot.
   shapeenv: ShapeEnv = {}
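Pieced together from the jax2tf.py hunk above and the test changes further down, the calling protocol for the new entry point looks like this (an editorial sketch, not part of the diff; it assumes TF eager mode):

# Sketch of the args_avals_and_env calling protocol (TF eager mode assumed).
import numpy as np
import tensorflow as tf
from jax.experimental.jax2tf import shape_poly, shape_poly_tf

args_flat = (tf.zeros((4, 3)),)          # the flat TF arguments
arg_dtypes = (np.float32,)
polymorphic_shapes = ("(b, _)",)         # first axis is the variable b

arg_shapes = tuple(np.shape(a) for a in args_flat)
arg_dim_exprs = tuple(shape_poly_tf.DimExprTfVal.for_arg(a) for a in args_flat)

avals, shape_env = shape_poly.args_avals_and_env(
    arg_shapes, arg_dtypes, polymorphic_shapes, arg_dim_exprs)
# avals[0].shape is (b, 3); shape_env["b"].raw is a TF scalar holding 4.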
@@ -667,24 +768,24 @@ def solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
   def _shapeenv_to_str() -> str:
     if shapeenv:
       return (" Partial solution: " +
-              ", ".join([f"{var} = {val}" for var, val in shapeenv.items()]) + ".")
+              ", ".join([f"{var} = {val.raw}" for var, val in shapeenv.items()]) + ".")
     else:
       return ""

   def process_one_eqn(eqn: DimEquation) -> bool:
-    # Try to rewrite the equation as "var * factor_var = tf_expr" (a linear
+    # Try to rewrite the equation as "var * factor_var = dim_expr" (a linear
     # uni-variate equation). Return False if this rewrite fails.
     # Otherwise, add the variable to shapeenv and return True.

-    # The invariant is: var * factor_var + rest_eqn_poly = tf_expr
+    # The invariant is: var * factor_var + rest_eqn_poly = dim_expr
     var, factor_var = None, None
-    tf_expr = eqn.tf_expr
+    dim_expr = eqn.dim_expr

     for mon, factor in eqn.poly.items():
       # Perhaps we can already evaluate this monomial (all vars solved)
       try:
         mon_value = mon.evaluate(shapeenv)
-        tf_expr = tf.math.subtract(tf_expr, _multiply(factor, mon_value))
+        dim_expr = dim_expr.subtract(_multiply_dim_expr(mon_value, factor))
         continue
       except KeyError:
         # There are some indeterminate variables. We handle only the case of
@@ -697,37 +798,33 @@ def solve_dim_equations(eqns: List[DimEquation]) -> ShapeEnv:
         return False

     if var is not None:
-      try:
-        var_value = tf.math.floordiv(tf_expr, factor_var) if factor_var != 1 else tf_expr
-        # Check that the division is even. Works only in eager mode.
-        if tf.math.floormod(tf_expr, factor_var).numpy() != 0:
-          msg = (f"Dimension variable {var} must have integer value >= 1. "  # type: ignore
-                 f"Found value {int(tf_expr.numpy()) / factor_var} when solving "
-                 f"{eqn.poly} == {eqn.tf_expr}.{_shapeenv_to_str()}")
-          raise ValueError(msg)
-        if var_value.numpy() <= 0:
-          msg = (f"Dimension variable {var} must have integer value >= 1. "
-                 f"Found value {int(var_value.numpy())} when solving "
-                 f"{eqn.poly} == {eqn.tf_expr}.{_shapeenv_to_str()}")
-          raise ValueError(msg)
-      except AttributeError:
-        pass
+      var_value, var_remainder = dim_expr.divmod(factor_var)  # type: ignore
+      # Check that the division is even. Works only in eager mode.
+      var_remainder_int = var_remainder.is_known_constant()
+      if var_remainder_int is not None and var_remainder_int != 0:
+        # TODO(necula): check even in graph mode, by embedding the checks in
+        # the graph.
+        msg = (f"Dimension variable {var} must have integer value >= 1. "  # type: ignore
+               f"Found value {int(dim_expr.is_known_constant()) / factor_var} when solving "
+               f"{eqn.poly} == {eqn.dim_expr}.{_shapeenv_to_str()}")
+        raise ValueError(msg)
+      var_value_int = var_value.is_known_constant()
+      if var_value_int is not None and var_value_int <= 0:
+        msg = (f"Dimension variable {var} must have integer value >= 1. "
+               f"Found value {int(var_value_int)} when solving "
+               f"{eqn.poly} == {eqn.dim_expr.raw}.{_shapeenv_to_str()}")
+        raise ValueError(msg)

       shapeenv[var] = var_value
       return True
     else:
       # All variables are resolved for this equation
-      try:
-        if tf_expr.numpy() != 0:
-          err_msg = (
-              "Found inconsistency when solving "
-              f"{eqn.poly} == {eqn.tf_expr}.{_shapeenv_to_str()}")
-          raise ValueError(err_msg)
-      except AttributeError:
-        # TODO(necula): check that the equation is satisfied even in graph
-        # mode.
-        pass
+      dim_expr_int = dim_expr.is_known_constant()
+      if dim_expr_int is not None and dim_expr_int != 0:
+        err_msg = (
+            "Found inconsistency when solving "
+            f"{eqn.poly} == {eqn.dim_expr.raw}.{_shapeenv_to_str()}")
+        raise ValueError(err_msg)
       return True

   while True:
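To see how `_solve_dim_equations` and `DimExprTfVal.divmod` cooperate, here is a worked sketch (editorial, not part of the diff; assumes TF eager mode): for an argument specified as `(2*b, _)` with concrete shape `(8, 4)`, the solver rewrites `2*b == 8` into a division by the factor 2 and checks that the remainder is zero and the quotient is >= 1:

# Worked sketch of solving "2*b == 8" (editorial; TF eager mode assumed).
import tensorflow as tf
from jax.experimental.jax2tf import shape_poly_tf

arg = tf.zeros((8, 4))
dim0, dim1 = shape_poly_tf.DimExprTfVal.for_arg(arg)  # tf.shape(arg)[0], [1]

q, r = dim0.divmod(2)               # the factor of b in the polynomial is 2
assert r.is_known_constant() == 0   # even division: the spec is consistent
assert q.is_known_constant() == 4   # so b = 4, which satisfies b >= 1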
jax/experimental/jax2tf/shape_poly_tf.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Specialized shape polymorphism support for jax2tf.
+
+See the shape_poly.py module documentation, the jax2tf.convert docstring, and the
+[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
+"""
+from typing import Any, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import tensorflow as tf  # type: ignore[import]
+
+from . import shape_poly
+
+TfVal = Any
+
+class DimExprTfVal(shape_poly.DimExpr):
+  """Express dimensions using tf.shape and the TF arguments."""
+  def __init__(self, tfval: TfVal):
+    super(DimExprTfVal, self).__init__(tfval)
+
+  @classmethod
+  def for_arg(cls, arg: TfVal) -> Sequence[shape_poly.DimExpr]:
+    tf_shape = tf.shape(arg)
+    return tuple(DimExprTfVal(tf_shape[i]) for i in range(len(np.shape(arg))))
+
+  def is_known_constant(self) -> Optional[int]:
+    # When under TF eager, the dimension expressions should be constants.
+    # Under TF graph, they will not be.
+    try:
+      return self.raw.numpy()
+    except AttributeError as e:
+      assert str(e).find("numpy") > 0, e
+      return None
+
+  def add(self, other: Union[shape_poly.DimExpr, int]) -> shape_poly.DimExpr:
+    if isinstance(other, shape_poly.DimExpr):
+      other = other.raw  # type: ignore[assignment]
+    return DimExprTfVal(tf.math.add(self.raw, other))
+
+  def subtract(self, other: Union[shape_poly.DimExpr, int]) -> shape_poly.DimExpr:
+    if isinstance(other, shape_poly.DimExpr):
+      other = other.raw  # type: ignore[assignment]
+    return DimExprTfVal(tf.math.subtract(self.raw, other))
+
+  def multiply(self, other: Union[shape_poly.DimExpr, int]) -> shape_poly.DimExpr:
+    if isinstance(other, shape_poly.DimExpr):
+      other = other.raw  # type: ignore[assignment]
+    return DimExprTfVal(tf.math.multiply(self.raw, other))
+
+  def divmod(self, factor: int) -> Tuple[shape_poly.DimExpr, shape_poly.DimExpr]:
+    dividend = DimExprTfVal(tf.math.floordiv(self.raw, factor)) if factor != 1 else self
+    mod = DimExprTfVal(tf.math.floormod(self.raw, factor))
+    return dividend, mod
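A quick sketch (editorial, not part of the new file) of the two regimes that `is_known_constant` distinguishes: under TF eager the raw `tf.Tensor` has a `.numpy()` method, while inside a `tf.function` it does not and the expression stays symbolic:

# Editorial sketch: DimExprTfVal under eager execution vs. graph tracing.
import tensorflow as tf
from jax.experimental.jax2tf import shape_poly_tf

d, = shape_poly_tf.DimExprTfVal.for_arg(tf.zeros((5,)))
print(d.is_known_constant())    # eager mode: prints 5

@tf.function(input_signature=[tf.TensorSpec([None])])
def f(x):
  d, = shape_poly_tf.DimExprTfVal.for_arg(x)
  assert d.is_known_constant() is None  # graph mode: symbolic, no constant
  return d.add(1).raw                   # but still usable as a TF expression

f(tf.zeros((3,)))  # traces f; the trace-time assert above passes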
jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -26,6 +26,7 @@ import jax
 from jax import core
 from jax.experimental import jax2tf
 from jax.experimental.jax2tf import shape_poly
+from jax.experimental.jax2tf import shape_poly_tf
 from jax import lax
 import jax.numpy as jnp
 from jax._src import test_util as jtu
@@ -51,21 +52,19 @@ PS = jax2tf.PolyShape
 class DimPolynomialTest(tf_test_util.JaxToTfTestCase):

   def test_parse_poly_spec(self):
-    self.assertEqual((2, 3), shape_poly.parse_spec(None, (2, 3)))
-    self.assertEqual((2, 3), shape_poly.parse_spec("2, 3", (2, 3)))
-    self.assertEqual((2, 3), shape_poly.parse_spec("2, _", (2, 3)))
-    self.assertEqual((2, 3), shape_poly.parse_spec("2, ...", (2, 3)))
-    self.assertEqual((2, 3), shape_poly.parse_spec("...", (2, 3)))
-    self.assertEqual((2, 3), shape_poly.parse_spec(" ( 2 , 3 ) ", (2, 3)))
+    self.assertEqual((2, 3), shape_poly._parse_spec(None, (2, 3)))
+    self.assertEqual((2, 3), shape_poly._parse_spec("2, 3", (2, 3)))
+    self.assertEqual((2, 3), shape_poly._parse_spec("2, _", (2, 3)))
+    self.assertEqual((2, 3), shape_poly._parse_spec("2, ...", (2, 3)))
+    self.assertEqual((2, 3), shape_poly._parse_spec("...", (2, 3)))
+    self.assertEqual((2, 3), shape_poly._parse_spec(" ( 2 , 3 ) ", (2, 3)))

-    a, b = shape_poly.parse_spec("a, b", (2, 3))
-    self.assertEqual((a, 3), shape_poly.parse_spec("(a, ...) ", (None, 3)))
+    a, b = shape_poly._parse_spec("a, b", (2, 3))
+    self.assertEqual((a, 3), shape_poly._parse_spec("(a, ...) ", (None, 3)))
     tshape = tf.TensorShape([None, 3])
-    self.assertEqual((a, 3), shape_poly.parse_spec("(a, ...) ", tshape))
-    # Test directly with tf.compat.v1.Dimension
-    self.assertEqual((a, 3), shape_poly.parse_spec("(a, ...) ", tshape.dims))
+    self.assertEqual((a, 3), shape_poly._parse_spec("(a, ...) ", tshape))

-  a, b = shape_poly.parse_spec("a, b", (2, 3))
+  a, b = shape_poly._parse_spec("a, b", (2, 3))
   @parameterized.named_parameters(
       dict(testcase_name=f"_dim_spec={dim_spec}",
            dim_spec=dim_spec, dim_poly=dim_poly)
@@ -82,8 +81,8 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
       dim_poly=3 * a * b * a - 2):
     # For internal usage only (the polymorphic_shapes of VJP) we need to
     # parse polynomials.
-    self.assertEqual((dim_poly,), shape_poly.parse_spec(dim_spec, (2,)))
-    self.assertEqual((dim_poly,), shape_poly.parse_spec(str(dim_poly), (2,)))
+    self.assertEqual((dim_poly,), shape_poly._parse_spec(dim_spec, (2,)))
+    self.assertEqual((dim_poly,), shape_poly._parse_spec(str(dim_poly), (2,)))

   @parameterized.named_parameters(
       dict(testcase_name=f"_dim_spec={dim_spec}",
@@ -101,11 +100,11 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
       dim_poly=3 * a * b * a - 2):
     # For internal usage only (the polymorphic_shapes of VJP) we need to
     # parse polynomials.
-    self.assertEqual((dim_poly,), shape_poly.parse_spec(dim_spec, (2,)))
-    self.assertEqual((dim_poly,), shape_poly.parse_spec(str(dim_poly), (2,)))
+    self.assertEqual((dim_poly,), shape_poly._parse_spec(dim_spec, (2,)))
+    self.assertEqual((dim_poly,), shape_poly._parse_spec(str(dim_poly), (2,)))

   def test_dim_vars(self):
-    a, b, a1 = shape_poly.parse_spec("a, b, a", (2, 3, 2))
+    a, b, a1 = shape_poly._parse_spec("a, b, a", (2, 3, 2))
     self.assertEqual(True, a == a)
     self.assertEqual(True, a == a1)
     self.assertEqual(False, a != a)
@@ -134,19 +133,19 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
       b in [a, b]

   def test_get_vars(self):
-    a, b = shape_poly.parse_spec("a, b", (2, 3))
+    a, b = shape_poly._parse_spec("a, b", (2, 3))

     self.assertEqual({"a"}, a.get_vars())
     self.assertEqual({"a", "b"}, (a * b * a).get_vars())

   def test_evaluate(self):
-    a, b = shape_poly.parse_spec("a, b", (2, 3))
+    a, b = shape_poly._parse_spec("a, b", (2, 3))

     self.assertEqual(1, (a * a - b).evaluate(dict(a=2, b=3)))
     self.assertEqual(2, (a * a - b + 1).evaluate(dict(a=-2, b=3)))

   def test_dim_vars_symbolic_equal(self):
-    a, b = shape_poly.parse_spec("a, b", (2, 3))
+    a, b = shape_poly._parse_spec("a, b", (2, 3))
     self.assertTrue(core.symbolic_equal_dim(a, a))
     self.assertFalse(core.symbolic_equal_dim(a, 1))
     self.assertFalse(core.symbolic_equal_dim(a, b))
@@ -165,7 +164,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
       self.assertTrue(core.symbolic_equal_dim(1, "a"))

   def test_poly_bounds(self):
-    a, b = shape_poly.parse_spec("a, b", (2, 3))
+    a, b = shape_poly._parse_spec("a, b", (2, 3))
     self.assertEqual(a.bounds(), (1, None))
     self.assertEqual((2 * a).bounds(), (2, None))
     self.assertEqual((2 * a - 3).bounds(), (-1, None))
@@ -176,7 +175,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
     self.assertEqual((a + 2 * b - a).bounds(), (2, None))

   def test_poly_equal(self):
-    a, b = shape_poly.parse_spec("a, b", (2, 3))
+    a, b = shape_poly._parse_spec("a, b", (2, 3))
     poly3 = a + 3 - a
     self.assertTrue(poly3 == 3)
     self.assertTrue(poly3 == np.array(3, np.int64))
@@ -195,7 +194,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
       (3 * a * b * a - 2).eq(a * b * a)

   def test_poly_compare(self):
-    a, b = shape_poly.parse_spec("a, b", (2, 3))
+    a, b = shape_poly._parse_spec("a, b", (2, 3))
     poly = 4 * a + b + 3
     self.assertTrue(poly.ge(0))
     self.assertTrue(poly.ge(8))
@@ -209,7 +208,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
       (4 * a - b).ge(0)

   def test_poly_compare_overload(self):
-    a, b = shape_poly.parse_spec("a, b", (2, 3))
+    a, b = shape_poly._parse_spec("a, b", (2, 3))
     poly = 4 * a + b + 3
     self.assertTrue(poly >= 0)
     self.assertTrue(poly >= 8)
@@ -224,7 +223,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
       (4 * a - b) >= 0

   def test_core_greater_equal(self):
-    a, b = shape_poly.parse_spec("a, b", (2, 3))
+    a, b = shape_poly._parse_spec("a, b", (2, 3))
     self.assertTrue(core.greater_equal_dim(a, a))
     self.assertTrue(core.greater_equal_dim(a, 0))
     self.assertTrue(core.greater_equal_dim(a, 1))
@@ -240,7 +239,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
       core.greater_equal_dim(a, b)

   def test_poly_int_results(self):
-    a, b = shape_poly.parse_spec("a, b", (2, 3))
+    a, b = shape_poly._parse_spec("a, b", (2, 3))
     self.assertEqual(a + 2 - a, 2)
     self.assertIsInstance(a + 2 - a, int)
     self.assertEqual(a + (2 - a), 2)
@@ -299,14 +298,14 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
       self.assertEqual(quotient, dividend / divisor)

   def test_poly_truediv_error(self):
-    a, = shape_poly.parse_spec("a,", (2,))
+    a, = shape_poly._parse_spec("a,", (2,))
     with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                 "Division of '3' by dimension polynomial .* is not supported"):
       3 / a

   def test_dilate_shape(self):
     """0 if d == 0 else 1 + dilation * (d - 1))"""
-    a, = shape_poly.parse_spec("a,", (2,))
+    a, = shape_poly._parse_spec("a,", (2,))

     self.assertEqual((4, 7), core.dilate_shape((2, 3), (3, 3)))
     self.assertEqual((0, 7), core.dilate_shape((0, 3), (3, 3)))
@@ -315,7 +314,7 @@ class DimPolynomialTest(tf_test_util.JaxToTfTestCase):

   def test_stride_shape(self):
     """(s - window_size) // window_stride + 1"""
-    a, stride = shape_poly.parse_spec("a, s", (2, 3))
+    a, stride = shape_poly._parse_spec("a, s", (2, 3))

     self.assertEqual((8, 9), core.stride_shape((10, 20), window_size=(3, 3), window_stride=(1, 2)))
     self.assertEqual((a, 9), core.stride_shape((a, 20), (1, 3), (1, 2)))
@@ -418,11 +417,12 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
     # check expected_shapeenv.
     arg_dtypes = (_f32,) * len(arg_shapes)
     def f_tf(*tf_args):
-      avals, shape_env = jax2tf.jax2tf._args_to_avals_and_env(
-          tf_args, arg_dtypes, polymorphic_shapes)  # The function under test
+      avals, shape_env = shape_poly.args_avals_and_env(
+          arg_shapes, arg_dtypes, polymorphic_shapes,
+          tuple(shape_poly_tf.DimExprTfVal.for_arg(a) for a in tf_args))  # The function under test
       if expected_avals is not None:
         self.assertEqual(expected_avals, avals)
-      return shape_env
+      return {k: d.raw for k, d in shape_env.items()}
     if eager_mode:
       # If we want to check the shape_env then all arg_shapes must be known
       assert all(all(d is not None for d in a_s)
@@ -439,7 +439,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):

     def shaped_array(shape_spec: str, actual_shape: core.Shape):
       return core.ShapedArray(
-          shape_poly.parse_spec(shape_spec, actual_shape), np.float32)
+          shape_poly._parse_spec(shape_spec, actual_shape), np.float32)

     # Known shapes for the arguments
     check_avals(
@@ -649,7 +649,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
         polymorphic_shapes=["a, ...", "a"],
         eager_mode=True)

-
   def test_pytree(self):
     """Arguments and polymorphic_shapes are pytrees."""