mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Cleanup the API, and more documentation
This commit is contained in:
parent
14737e365e
commit
d9468c7513
@ -129,7 +129,7 @@ sets up symbolic links from site-packages into the repository.
|
||||
|
||||
To run all the JAX tests, we recommend using `pytest-xdist`, which can run tests in
|
||||
parallel. First, install `pytest-xdist` and `pytest-benchmark` by running
|
||||
`pip install pytest-xdist pytest-benchmark`.
|
||||
`ip install -r build/test-requirements.txt`.
|
||||
Then, from the repository root directory run:
|
||||
|
||||
```
|
||||
|
12
jax/core.py
12
jax/core.py
@ -1316,25 +1316,25 @@ def symbolic_equal_one_of_dim(d1: DimSize, dlist: Sequence[DimSize]) -> bool:
|
||||
def symbolic_equal_shape(s1: Shape, s2: Shape) -> bool:
|
||||
"""See DimensionHandler.symbolic_equal."""
|
||||
return (len(s1) == len(s2) and
|
||||
all(safe_map(symbolic_equal_dim, s1, s2)))
|
||||
all(map(symbolic_equal_dim, s1, s2)))
|
||||
|
||||
def greater_equal_dim(d1: DimSize, d2: DimSize) -> bool:
|
||||
return _get_dim_handler(d1, d2).greater_equal(d1, d2)
|
||||
|
||||
def greater_equal_shape(s1: Shape, s2: Shape) -> bool:
|
||||
return all(safe_map(greater_equal_dim, s1, s2))
|
||||
return all(map(greater_equal_dim, s1, s2))
|
||||
|
||||
def sum_dim(*ds: DimSize) -> DimSize:
|
||||
return _get_dim_handler(*ds).sum(*ds)
|
||||
|
||||
def sum_shapes(*ss: Shape) -> Shape:
|
||||
return tuple(safe_map(sum_dim, *ss))
|
||||
return tuple(map(sum_dim, *ss))
|
||||
|
||||
def diff_dim(d1: DimSize, d2: DimSize) -> DimSize:
|
||||
return _get_dim_handler(d1, d2).diff(d1, d2)
|
||||
|
||||
def diff_shape(s1: Shape, s2: Shape) -> Shape:
|
||||
return tuple(safe_map(diff_dim, s1, s2))
|
||||
return tuple(map(diff_dim, s1, s2))
|
||||
|
||||
def divide_shape_sizes(s1: Shape, s2: Shape) -> int:
|
||||
s1 = s1 or (1,)
|
||||
@ -1349,14 +1349,14 @@ def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
|
||||
return _get_dim_handler(d, dilation).dilate(d, dilation)
|
||||
|
||||
def dilate_shape(s: Shape, dilations: Sequence[int]) -> Shape:
|
||||
return tuple(safe_map(dilate_dim, s, dilations))
|
||||
return tuple(map(dilate_dim, s, dilations))
|
||||
|
||||
def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
|
||||
return _get_dim_handler(d, window_size, window_stride).stride(d, window_size, window_stride)
|
||||
|
||||
def stride_shape(s: Shape, window_size: Shape, window_stride: Shape) -> Shape:
|
||||
"""(s - window_size) // window_stride + 1"""
|
||||
return tuple(safe_map(stride_dim, s, window_size, window_stride))
|
||||
return tuple(map(stride_dim, s, window_size, window_stride))
|
||||
|
||||
|
||||
def _canonicalize_dimension(dim: DimSize) -> DimSize:
|
||||
|
@ -13,5 +13,5 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from .jax2tf import convert, shape_as_value, split_to_logical_devices
|
||||
from .jax2tf import convert, shape_as_value, split_to_logical_devices, PolyShape
|
||||
from .call_tf import call_tf
|
||||
|
@ -34,7 +34,7 @@ from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src.lax import fft as lax_fft
|
||||
import jax._src.random
|
||||
from . import shape_poly
|
||||
from jax.experimental.jax2tf import shape_poly
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
@ -48,6 +48,7 @@ from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding # ty
|
||||
|
||||
from jax.lib import xla_client
|
||||
|
||||
PolyShape = shape_poly.PolyShape
|
||||
|
||||
# The scope name need to be a valid TensorFlow name. See
|
||||
# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/core/framework/node_def_util.cc#L731
|
||||
@ -119,8 +120,8 @@ def _xla_path_disabled_error(primitive_name: str) -> Exception:
|
||||
|
||||
@functools.partial(api_util.api_hook, tag="jax2tf_convert")
|
||||
def convert(fun: Callable, *,
|
||||
polymorphic_shapes_experimental: Optional[Sequence[Any]]=None,
|
||||
in_shapes=None,
|
||||
polymorphic_shapes: Optional[Sequence[Any]]=None,
|
||||
in_shapes=None, # DEPRECATED
|
||||
with_gradient=True, enable_xla=True) -> Callable:
|
||||
"""Transforms `fun` to be executed by TensorFlow.
|
||||
|
||||
@ -129,43 +130,45 @@ def convert(fun: Callable, *,
|
||||
|
||||
Args:
|
||||
fun: Function to be transformed. Its arguments and return value should be
|
||||
JAX arrays, or (nested) standard Python containers (tuple/list/dict)
|
||||
thereof.
|
||||
JAX arrays, or nested standard Python containers (tuple/list/dict)
|
||||
thereof (pytrees).
|
||||
|
||||
polymorphic_shapes_experimental: an optional sequence of shape specifications,
|
||||
one for each argument of the function to be converted. Default is `None`,
|
||||
in which case the argument shape specifications are taken from the shapes
|
||||
of the actual arguments.
|
||||
A non-default `polymorphic_shapes` is used to specify shape variables for
|
||||
some of the input dimensions, to specify that the conversion to TF must be
|
||||
done for any possible non-zero values for the shape variables.
|
||||
polymorphic_shapes: Specifies input shapes to be treated polymorphically
|
||||
during conversion.
|
||||
|
||||
.. warning::
|
||||
The shape-polymorphic conversion is an experimental feature. It is meant
|
||||
to be sound, but it is known to reject some JAX programs that are
|
||||
shape polymorphic. The details of this feature can change.
|
||||
|
||||
It should be a Python object with the same pytree structure as,
|
||||
or a prefix of, the tuple of arguments to the function,
|
||||
but with a shape specification corresponding to each argument.
|
||||
The default value is `None`, which is a shortcut for a tuple of `None`
|
||||
one for each argument, denoting that all shapes are monomorphic.
|
||||
See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
|
||||
|
||||
A shape specification for an array argument
|
||||
should be an object `PolyShape(dim0, dim1, ..., dimn)`
|
||||
where each `dim` is a dimension specification: a positive integer denoting
|
||||
a monomorphic dimension of the given size,
|
||||
or a string denoting a dimension variable assumed to range over non-zero
|
||||
dimension sizes,
|
||||
or the special placeholder string "_" denoting a monomorphic dimension
|
||||
whose size is given by the actual argument.
|
||||
As a shortcut, an Ellipsis suffix in the
|
||||
list of dimension specifications stands for a list of "_" placeholders.
|
||||
For convenience, a shape specification can also be given as a string
|
||||
representation, e.g.: "batch, ...", "batch, height, width, _", possibly
|
||||
with surrounding parentheses: "(batch, ...)".
|
||||
|
||||
The conversion fails if it cannot ensure that the it would produce the same
|
||||
sequence of TF ops for any non-zero values of shape variables. This feature
|
||||
is experimental, and it may fail loudly even for code that is actually
|
||||
shape polymorphic.
|
||||
sequence of TF ops for any non-zero values of the dimension variables.
|
||||
|
||||
If an argument is a pytree, then the
|
||||
shape specification must be a matching pytree or `None`.
|
||||
See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
|
||||
A shape specification should be a string, with comma-separated dimension
|
||||
specifications, and optionally wrapped in parentheses. A dimension
|
||||
specification is either a number, or the placeholder `_`, or a lowercase
|
||||
word denoting a name for a dimension variable.
|
||||
In presence of dimension variables, the conversion is done with a
|
||||
shape abstraction that allows any non-zero concrete value for the variable.
|
||||
Examples of shape specifications:
|
||||
* `[None, "(batch, 16)"]`: no specification for the first argument (takes
|
||||
the shape from the actual argument); the second argument is a 2D
|
||||
array with the first dimension size set to a variable `batch` and the
|
||||
second dimension 16.
|
||||
* `["(batch, _)", "(batch,)"]`: the leading dimensions of the two arguments
|
||||
must match. The second dimension of the first argument is taken from the
|
||||
actual argument shape.
|
||||
See [the README](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
|
||||
for more details.
|
||||
|
||||
in_shapes: DEPRECATED in favor of `polymorphic_shapes_experimental`.
|
||||
in_shapes: DEPRECATED in favor of `polymorphic_shapes`.
|
||||
|
||||
with_gradient: if set, will add a tf.custom_gradient to the converted
|
||||
function, by converting the ``jax.vjp(fun)``. Only first-order
|
||||
@ -185,7 +188,6 @@ def convert(fun: Callable, *,
|
||||
global _enable_xla
|
||||
_enable_xla = enable_xla
|
||||
api._check_callable(fun)
|
||||
polymorphic_shapes = polymorphic_shapes_experimental
|
||||
|
||||
def converted_fun(*args: TfVal) -> TfVal:
|
||||
# TODO: is there a better way to check if we are inside a transformation?
|
||||
@ -214,13 +216,14 @@ def convert(fun: Callable, *,
|
||||
else:
|
||||
if not isinstance(polymorphic_shapes, Sequence) or len(args) != len(polymorphic_shapes):
|
||||
msg = ("polymorphic_shapes must be a sequence with the same length as the argument list "
|
||||
f"({len(args)}). Got polymorphic_shapes_experimental={polymorphic_shapes}.")
|
||||
f"({len(args)}). Got polymorphic_shapes={polymorphic_shapes}.")
|
||||
raise TypeError(msg)
|
||||
polymorphic_shapes_ = tuple(polymorphic_shapes)
|
||||
|
||||
# Expand the in_shapes to match the argument pytree
|
||||
polymorphic_shapes_flat = tuple(api_util.flatten_axes("jax2tf.convert polymorphic_shapes",
|
||||
in_tree.children()[0], polymorphic_shapes_))
|
||||
in_tree.children()[0],
|
||||
polymorphic_shapes_))
|
||||
|
||||
# Construct the abstract values for the flat arguments, possibly based on
|
||||
# the input shapes and the in_shapes if given. May create new shape
|
||||
@ -262,7 +265,7 @@ def convert(fun: Callable, *,
|
||||
# TODO: enable higher-order gradients
|
||||
with tf.name_scope("jax2tf_vjp"):
|
||||
in_cts = convert(fun_vjp_jax, with_gradient=False,
|
||||
polymorphic_shapes_experimental=vjp_polymorphic_shapes)(args, out_cts)
|
||||
polymorphic_shapes=vjp_polymorphic_shapes)(args, out_cts)
|
||||
return in_cts
|
||||
|
||||
try:
|
||||
@ -295,6 +298,7 @@ def convert(fun: Callable, *,
|
||||
|
||||
return converted_fun
|
||||
|
||||
|
||||
# Internals
|
||||
|
||||
|
||||
@ -388,7 +392,7 @@ def _tfval_shape_dtype(val: TfVal) -> Tuple[Sequence[Optional[int]], DType]:
|
||||
# function arguments.
|
||||
_ShapeEnv = Dict[shape_poly.DimVar, TfVal]
|
||||
def _args_to_avals_and_env(args: Sequence[TfVal],
|
||||
polymorphic_shapes: Sequence[Optional[str]]) -> \
|
||||
polymorphic_shapes: Sequence[Optional[Union[str, PolyShape]]]) -> \
|
||||
Tuple[Sequence[core.AbstractValue], _ShapeEnv]:
|
||||
"""Computes abstract values and a dimension environment for arguments.
|
||||
|
||||
@ -421,7 +425,7 @@ def _args_to_avals_and_env(args: Sequence[TfVal],
|
||||
|
||||
return core.ShapedArray(aval_shape, dtype)
|
||||
|
||||
avals = tuple(map(input_aval, args, polymorphic_shapes))
|
||||
avals = tuple(map(input_aval, args, polymorphic_shapes)) # type: ignore
|
||||
return avals, shapeenv
|
||||
|
||||
# A shape environment maps shape variables to TfVal.
|
||||
|
@ -18,8 +18,7 @@ For usage instructions, read the jax2tf.convert docstring, and the
|
||||
|
||||
"""
|
||||
import collections
|
||||
import string
|
||||
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
from jax import core
|
||||
|
||||
@ -114,14 +113,20 @@ def _split_shape_ints(shape: Shape) -> Tuple[Sequence[int], Sequence[DimVar]]:
|
||||
return shape_ints, shape_vars
|
||||
|
||||
|
||||
class ShapeSyntaxError(Exception): pass
|
||||
class PolyShape(tuple):
|
||||
"""Tuple of polymorphic dimension specifications.
|
||||
|
||||
_identifiers = frozenset(string.ascii_lowercase)
|
||||
def parse_spec(spec: Optional[str],
|
||||
See docstring of :func:`jax2tf.convert`.
|
||||
"""
|
||||
def __new__(cls, *dim_specs):
|
||||
return tuple.__new__(PolyShape, dim_specs)
|
||||
|
||||
|
||||
def parse_spec(spec: Optional[Union[str, PolyShape]],
|
||||
arg_shape: Sequence[Optional[int]]) -> Tuple[DimSize, ...]:
|
||||
"""Parse the shape polymorphic specification for one array argument.
|
||||
Args:
|
||||
spec: a shape polymorphic specification.
|
||||
spec: a shape polymorphic specification, either a string, or a PolyShape.
|
||||
arg_shape: an actual shape, possibly containing unknown dimensions (None).
|
||||
|
||||
The placeholders `_` in the specification are replaced with the values from
|
||||
@ -129,55 +134,75 @@ def parse_spec(spec: Optional[str],
|
||||
|
||||
See the README.md for usage.
|
||||
"""
|
||||
shape_var_map: Dict[str, Set[int]] = collections.defaultdict(set)
|
||||
def _parse_dim(dim_spec: str, dim_size: Optional[int]) -> Union[int, DimSize]:
|
||||
if dim_spec == '_':
|
||||
if dim_size is None:
|
||||
msg = (f"polymorphic_shape '{spec}' has `_` placeholders for argument shape "
|
||||
f"dimensions that are unknown: {arg_shape}")
|
||||
raise ValueError(msg)
|
||||
return dim_size
|
||||
elif dim_spec.isdigit():
|
||||
spec_size = int(dim_spec)
|
||||
if dim_size != spec_size:
|
||||
if dim_size is None:
|
||||
msg = (f"polymorphic_shape '{spec}' must contain shape variables for argument shape "
|
||||
f"dimensions that are unknown: {arg_shape}")
|
||||
else:
|
||||
msg = (f"polymorphic_shape '{spec}' does not match argument shape {arg_shape}")
|
||||
raise ValueError(msg)
|
||||
return spec_size
|
||||
elif dim_spec[0] in _identifiers:
|
||||
if dim_size is not None:
|
||||
shape_var_map[dim_spec].add(dim_size)
|
||||
return DimVar(dim_spec)
|
||||
if spec is None:
|
||||
spec_tuple = (...,) # type: Tuple[Any,...]
|
||||
elif isinstance(spec, PolyShape):
|
||||
spec_tuple = tuple(spec)
|
||||
elif isinstance(spec, str):
|
||||
spec_ = spec.replace(" ", "")
|
||||
if spec_[0] == "(":
|
||||
if spec_[-1] != ")":
|
||||
raise ValueError(spec)
|
||||
spec_ = spec_[1:-1]
|
||||
spec_ = spec_.rstrip(",")
|
||||
if not spec_:
|
||||
spec_tuple = ()
|
||||
else:
|
||||
raise ShapeSyntaxError(dim_spec)
|
||||
specs = spec_.split(',')
|
||||
def parse_dim(ds: str):
|
||||
if ds == "...":
|
||||
return ...
|
||||
elif ds.isdigit():
|
||||
return int(ds)
|
||||
elif ds == "_" or ds.isalnum():
|
||||
return ds
|
||||
else:
|
||||
raise ValueError(f"PolyShape '{spec}' has invalid syntax")
|
||||
|
||||
if not spec:
|
||||
if any(d is None for d in arg_shape):
|
||||
msg = ("polymorphic_shape must be specified when the argument "
|
||||
f"shape {arg_shape} is partially known.")
|
||||
raise ValueError(msg)
|
||||
return tuple(arg_shape)
|
||||
|
||||
if spec[0] == '(':
|
||||
if spec[-1] != ')':
|
||||
raise ShapeSyntaxError(spec)
|
||||
spec_ = spec[1:-1]
|
||||
spec_tuple = tuple(map(parse_dim, specs))
|
||||
else:
|
||||
spec_ = spec
|
||||
specs = spec_.replace(' ', '').strip(',').split(',')
|
||||
if len(specs) != len(arg_shape):
|
||||
msg = (f"polymorphic_shape '{spec}' has different rank than argument "
|
||||
f"shape {arg_shape}")
|
||||
raise ValueError(msg)
|
||||
dims = tuple(map(_parse_dim, specs, arg_shape))
|
||||
raise ValueError(f"PolyShape '{spec}' must be either None, a string, or PolyShape.")
|
||||
|
||||
ds_ellipses = tuple(ds for ds in spec_tuple if ds == ...)
|
||||
if ds_ellipses:
|
||||
if len(ds_ellipses) > 1 or spec_tuple[-1] != ...:
|
||||
raise ValueError(f"PolyShape '{spec}' can contain Ellipsis only at the end.")
|
||||
spec_tuple = spec_tuple[0:-1]
|
||||
if len(arg_shape) >= len(spec_tuple):
|
||||
spec_tuple = spec_tuple + ("_",) * (len(arg_shape) - len(spec_tuple))
|
||||
|
||||
if len(arg_shape) != len(spec_tuple):
|
||||
raise ValueError(f"PolyShape '{spec}' must match the rank of arguments {arg_shape}.")
|
||||
|
||||
shape_var_map: Dict[str, Set[int]] = collections.defaultdict(set)
|
||||
def _process_dim(i: int, dim_spec):
|
||||
if not isinstance(dim_spec, (str, int)):
|
||||
raise ValueError(f"PolyShape '{spec}' in axis {i} must contain only integers, strings, or Ellipsis.")
|
||||
dim_size = arg_shape[i]
|
||||
if dim_size is None:
|
||||
if dim_spec == "_" or not isinstance(dim_spec, str):
|
||||
msg = (f"PolyShape '{spec}' in axis {i} must contain a shape variable "
|
||||
f"for unknown dimension in argument shape {arg_shape}")
|
||||
raise ValueError(msg)
|
||||
return DimVar(dim_spec)
|
||||
else: # dim_size is known
|
||||
if dim_spec == "_":
|
||||
return dim_size
|
||||
if isinstance(dim_spec, int):
|
||||
if dim_spec != dim_size:
|
||||
msg = (f"PolyShape '{spec}' in axis {i} must contain a constant or '_' "
|
||||
f"for known dimension in argument shape {arg_shape}")
|
||||
raise ValueError(msg)
|
||||
return dim_size
|
||||
# We have a dimension variable for a known dimension.
|
||||
shape_var_map[dim_spec].add(dim_size)
|
||||
return DimVar(dim_spec)
|
||||
|
||||
dims = tuple([_process_dim(i, ds) for i, ds in enumerate(spec_tuple)])
|
||||
for dim_var, dim_var_values in shape_var_map.items():
|
||||
if len(dim_var_values) != 1:
|
||||
msg = (f"polymorphic shape variable '{dim_var}' corresponds to multiple "
|
||||
f"values ({sorted(dim_var_values)}), in polymorphic_shape '{spec}' and "
|
||||
msg = (f"PolyShape '{spec}' has dimension variable '{dim_var}' "
|
||||
f"corresponding to multiple values ({sorted(dim_var_values)}), for "
|
||||
f"argument shape {arg_shape}")
|
||||
raise ValueError(msg)
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
"""Tests for the jax2tf conversion for control-flow primitives."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from typing import Dict, Optional, Sequence
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
import collections
|
||||
import functools
|
||||
@ -44,6 +44,8 @@ config.parse_flags_with_absl()
|
||||
# Import after parsing flags
|
||||
from jax.experimental.jax2tf.tests import primitive_harness
|
||||
|
||||
PS = jax2tf.PolyShape
|
||||
|
||||
class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -61,18 +63,18 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
self.CheckShapePolymorphism(f_jax,
|
||||
input_signature=[tf.TensorSpec([2, None])],
|
||||
polymorphic_shapes=["(_, h)"],
|
||||
polymorphic_shapes=["_, h"],
|
||||
expected_output_signature=tf.TensorSpec([2, None]))
|
||||
|
||||
self.CheckShapePolymorphism(f_jax,
|
||||
input_signature=[tf.TensorSpec([None, None])],
|
||||
polymorphic_shapes=["(h, h)"],
|
||||
polymorphic_shapes=["h, h"],
|
||||
expected_output_signature=tf.TensorSpec([None, None]))
|
||||
|
||||
def test_arg_avals(self):
|
||||
"""Test conversion of actual arguments to abstract values"""
|
||||
def check_avals(*, args: Sequence[jax2tf.jax2tf.TfVal],
|
||||
polymorphic_shapes: Sequence[Optional[str]],
|
||||
polymorphic_shapes: Sequence[Optional[Union[str, PS]]],
|
||||
expected_avals: Sequence[core.ShapedArray]):
|
||||
avals, shape_env = jax2tf.jax2tf._args_to_avals_and_env(args,
|
||||
polymorphic_shapes) # The function under test
|
||||
@ -98,11 +100,11 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
check_avals(args=[tf_const((2, 3))],
|
||||
polymorphic_shapes=[None],
|
||||
expected_avals=(shaped_array("2, 3", [2, 3]),))
|
||||
expected_avals=(shaped_array("2, 3,", [2, 3]),))
|
||||
|
||||
check_avals(args=[tf_var((2, 3))],
|
||||
polymorphic_shapes=[None],
|
||||
expected_avals=(shaped_array("2, 3", [2, 3]),))
|
||||
expected_avals=(shaped_array("(2, 3)", [2, 3]),))
|
||||
|
||||
check_avals(args=[const((2, 3))],
|
||||
polymorphic_shapes=["(2, 3)"],
|
||||
@ -112,13 +114,17 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
polymorphic_shapes=["(_, 3)"],
|
||||
expected_avals=(shaped_array("2, 3", [2, 3]),))
|
||||
|
||||
check_avals(args=[tf_const((2, 3))],
|
||||
polymorphic_shapes=[PS("_", 3)],
|
||||
expected_avals=(shaped_array("2, 3", [2, 3]),))
|
||||
|
||||
# Partially known shapes for the arguments
|
||||
check_avals(args=[tf_var([None, 3], initializer_shape=(2, 3))],
|
||||
polymorphic_shapes=["(b, 3)"],
|
||||
polymorphic_shapes=[PS("b", ...)],
|
||||
expected_avals=(shaped_array("(b, 3)", (2, 3)),))
|
||||
|
||||
check_avals(args=[tf_var([None, None], initializer_shape=(2, 3))],
|
||||
polymorphic_shapes=[("h, h")],
|
||||
polymorphic_shapes=["h, h"],
|
||||
expected_avals=(shaped_array("(h, h)", (2, 2)),))
|
||||
|
||||
check_avals(args=[tf_var([2, None], initializer_shape=(2, 3))],
|
||||
@ -130,66 +136,72 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
expected_avals=(shaped_array("(c, b, a)", (2, 3, 4)),),)
|
||||
|
||||
# Some errors
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, re.escape("PolyShape ')(' has invalid syntax")):
|
||||
check_avals(args=[const((2, 3))],
|
||||
polymorphic_shapes=[")("],
|
||||
expected_avals=None)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
re.escape("polymorphic_shape must be specified when the argument shape (2, None) is partially known")):
|
||||
re.escape("PolyShape '..., 3' can contain Ellipsis only at the end.")):
|
||||
check_avals(args=[const((2, 3))],
|
||||
polymorphic_shapes=["..., 3"],
|
||||
expected_avals=None)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
re.escape("PolyShape '2, 3, 4, ...' must match the rank of arguments (2, 3).")):
|
||||
check_avals(args=[const((2, 3))],
|
||||
polymorphic_shapes=["2, 3, 4, ..."],
|
||||
expected_avals=None)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
re.escape("PolyShape '(Ellipsis, 3)' can contain Ellipsis only at the end.")):
|
||||
check_avals(args=[const((2, 3))],
|
||||
polymorphic_shapes=[PS(..., 3)],
|
||||
expected_avals=None)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
re.escape("PolyShape 'None' in axis 1 must contain a shape variable for unknown dimension in argument shape (2, None)")):
|
||||
check_avals(args=[tf_var([2, None], initializer_shape=(2, 3))],
|
||||
polymorphic_shapes=[None],
|
||||
expected_avals=None)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.escape("polymorphic_shape '()' has different rank than argument shape (2, 3)")):
|
||||
re.escape("PolyShape '()' must match the rank of arguments (2, 3)")):
|
||||
check_avals(args=[const((2, 3))],
|
||||
polymorphic_shapes=["()"],
|
||||
expected_avals=None)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.escape("polymorphic_shape '(_, _)' has `_` placeholders for argument shape dimensions that are unknown: (2, None)")):
|
||||
re.escape("PolyShape '(_, _)' in axis 1 must contain a shape variable for unknown dimension in argument shape (2, None)")):
|
||||
check_avals(args=[tf_var([2, None], initializer_shape=(2, 3))],
|
||||
polymorphic_shapes=["(_, _)"],
|
||||
expected_avals=None)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.escape("polymorphic_shape '(2, 13)' does not match argument shape (2, 3)")):
|
||||
re.escape("PolyShape '(2, 13)' in axis 1 must contain a constant or '_' for known dimension in argument shape (2, 3)")):
|
||||
check_avals(args=[const((2, 3))],
|
||||
polymorphic_shapes=["(2, 13)"],
|
||||
expected_avals=None)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.escape("polymorphic_shape '(2, 3)' must contain shape variables for argument shape dimensions that are unknown: (2, None)")):
|
||||
re.escape("PolyShape '(2, 3)' in axis 1 must contain a shape variable for unknown dimension in argument shape (2, None)")):
|
||||
check_avals(args=[tf_var([2, None], initializer_shape=(2, 3))],
|
||||
polymorphic_shapes=["(2, 3)"],
|
||||
expected_avals=None)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
re.escape("polymorphic shape variable 'a' corresponds to multiple values ([2, 3]), in polymorphic_shape '(a, a)' and argument shape (2, 3)")):
|
||||
re.escape("PolyShape '(a, a)' has dimension variable 'a' corresponding to multiple values ([2, 3]), for argument shape (2, 3)")):
|
||||
check_avals(args=[tf_var([2, 3], initializer_shape=(2, 3))],
|
||||
polymorphic_shapes=["(a, a)"],
|
||||
expected_avals=None)
|
||||
|
||||
|
||||
def test_bad_polymorphic_shapes(self):
|
||||
def add2(x, y):
|
||||
return x + y
|
||||
|
||||
with self.assertRaisesRegex(shape_poly.ShapeSyntaxError, ""):
|
||||
self.CheckShapePolymorphism(add2,
|
||||
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
|
||||
polymorphic_shapes=[") + (", None],
|
||||
expected_output_signature=tf.TensorSpec([None]))
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
re.escape("polymorphic_shapes must be a sequence with the same length as the argument list (2). "
|
||||
"Got polymorphic_shapes_experimental=['(b, 4)']")):
|
||||
self.CheckShapePolymorphism(add2,
|
||||
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
|
||||
polymorphic_shapes=["(b, 4)"],
|
||||
expected_output_signature=tf.TensorSpec([None]))
|
||||
|
||||
def test_pytree(self):
|
||||
"""Arguments and polymorphic_shapes are pytrees."""
|
||||
|
||||
@ -204,8 +216,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
input_signature=[([tf.TensorSpec([None]), tf.TensorSpec([None])],
|
||||
[tf.TensorSpec([None])]),
|
||||
dict(a=tf.TensorSpec([None]), b=tf.TensorSpec([None]))],
|
||||
polymorphic_shapes=[(["(v,)", "(v,)"], [("v,")]),
|
||||
dict(a="(v,)", b="(v,)")],
|
||||
polymorphic_shapes=[(["v", "v"], [("v")]),
|
||||
dict(a="v", b="v")],
|
||||
expected_output_signature=tf.TensorSpec([None]))
|
||||
|
||||
# Now partial polymorphic_shapes; the parts of the polymorphic_shapes that are not specified
|
||||
@ -322,7 +334,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
y = np.ones((3,))
|
||||
res_jax = f(x, y)
|
||||
self.assertAllClose(res_jax,
|
||||
jax2tf.convert(f, polymorphic_shapes_experimental=["(b, h)", "h"])(x, y))
|
||||
jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))
|
||||
|
||||
def test_shape_error(self):
|
||||
"""Some of the examples from the README."""
|
||||
@ -339,12 +351,21 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
re.escape("add got incompatible shapes for broadcasting: (v,), (4,)")):
|
||||
jax2tf.convert(lambda x, y: x + y,
|
||||
polymorphic_shapes_experimental=["(v,)", "(4,)"])(four_ones, four_ones)
|
||||
polymorphic_shapes=["(v,)", "(4,)"])(four_ones, four_ones)
|
||||
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
re.escape("Shape variable comparison v == 4 is inconclusive")):
|
||||
jax2tf.convert(lambda x: jnp.matmul(x, x),
|
||||
polymorphic_shapes_experimental=["(v, 4)"])(np.ones((4, 4)))
|
||||
polymorphic_shapes=["(v, 4)"])(np.ones((4, 4)))
|
||||
|
||||
|
||||
def test_parse_poly_spec(self):
|
||||
self.assertEqual((2, 3), shape_poly.parse_spec(None, (2, 3)))
|
||||
self.assertEqual((2, 3), shape_poly.parse_spec("2, 3", (2, 3)))
|
||||
self.assertEqual((2, 3), shape_poly.parse_spec("2, _", (2, 3)))
|
||||
self.assertEqual((2, 3), shape_poly.parse_spec("2, ...", (2, 3)))
|
||||
self.assertEqual((2, 3), shape_poly.parse_spec("...", (2, 3)))
|
||||
self.assertEqual((2, 3), shape_poly.parse_spec(" ( 2 , 3 ) ", (2, 3)))
|
||||
|
||||
|
||||
def test_dim_vars(self):
|
||||
@ -432,9 +453,9 @@ class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):
|
||||
return jnp.sum(x, axis=0) * jax2tf.shape_as_value(x)[0]
|
||||
|
||||
x = np.arange(3.)
|
||||
self.assertAllClose(9., jax2tf.convert(f, polymorphic_shapes_experimental=["(b,)"])(x))
|
||||
self.assertAllClose(9., jax2tf.convert(jax.jit(f), polymorphic_shapes_experimental=["(b,)"])(x))
|
||||
self.assertAllClose(9., tf.function(jax2tf.convert(f, polymorphic_shapes_experimental=["(b,)"]))(x))
|
||||
self.assertAllClose(9., jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x))
|
||||
self.assertAllClose(9., jax2tf.convert(jax.jit(f), polymorphic_shapes=["(b,)"])(x))
|
||||
self.assertAllClose(9., tf.function(jax2tf.convert(f, polymorphic_shapes=["(b,)"]))(x))
|
||||
|
||||
res_primal, res_tangent = jax2tf.convert(
|
||||
lambda x, xt: jax.jvp(f, (x,), (xt,)),
|
||||
@ -443,7 +464,7 @@ class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
self.assertAllClose(np.array([3., 3., 3.]),
|
||||
jax2tf.convert(jax.grad(f),
|
||||
polymorphic_shapes_experimental=["b"])(x))
|
||||
polymorphic_shapes=["b"])(x))
|
||||
|
||||
xv = np.arange(24.).reshape((2, 3, 4))
|
||||
res_vmap = jax.vmap(f, in_axes=1)(xv)
|
||||
@ -466,7 +487,7 @@ class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):
|
||||
operand=None)
|
||||
x = np.ones((2, 3, 4))
|
||||
self.assertAllClose(1., f(x))
|
||||
self.assertAllClose(1., jax2tf.convert(f, polymorphic_shapes_experimental=["(a, b, 4)"])(x))
|
||||
self.assertAllClose(1., jax2tf.convert(f, polymorphic_shapes=["(a, b, 4)"])(x))
|
||||
|
||||
def test_mean0(self):
|
||||
def f_jax(x):
|
||||
@ -676,7 +697,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
f_jax,
|
||||
input_signature=[tf.TensorSpec(lhs_shape),
|
||||
tf.TensorSpec(rhs_shape)],
|
||||
polymorphic_shapes=["b, _, _, _", "_, _, _, _"],
|
||||
polymorphic_shapes=["B, ...", None],
|
||||
expected_output_signature=tf.TensorSpec([None, 3, 3, 1]))
|
||||
self.assertAllClose(f_jax(lhs, rhs), f_tf(lhs, rhs))
|
||||
|
||||
@ -785,7 +806,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
f_tf = self.CheckShapePolymorphism(
|
||||
f,
|
||||
input_signature=[tf.TensorSpec([None, 3, 4]), tf.TensorSpec([2], np.int32)],
|
||||
polymorphic_shapes=["batch, _, _", "_"],
|
||||
polymorphic_shapes=["(batch, _, _)", "(_)"],
|
||||
expected_output_signature=tf.TensorSpec([None, 2, 4]))
|
||||
|
||||
self.assertAllClose(f(x, i), f_tf(x, i))
|
||||
@ -877,13 +898,13 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
self.CheckShapePolymorphism(
|
||||
lambda x: x.reshape([x.shape[0], -1]),
|
||||
input_signature=[tf.TensorSpec([None, 2, 3])],
|
||||
polymorphic_shapes=["(batch, _, _)"],
|
||||
polymorphic_shapes=["batch, _, _"],
|
||||
expected_output_signature=tf.TensorSpec([None, 6]))
|
||||
|
||||
self.CheckShapePolymorphism(
|
||||
lambda x: x.reshape([x.shape[0], -1, x.shape[3], x.shape[2]]),
|
||||
input_signature=[tf.TensorSpec([None, 2, None, None, 3])],
|
||||
polymorphic_shapes=["(batch, 2, batch, height, 3)"],
|
||||
polymorphic_shapes=["batch, 2, batch, height, 3"],
|
||||
expected_output_signature=tf.TensorSpec([None, 6, None, None]))
|
||||
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
@ -891,7 +912,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
self.CheckShapePolymorphism(
|
||||
lambda x: x.reshape([x.shape[0], -1, x.shape[2]]),
|
||||
input_signature=[tf.TensorSpec([None, 2, None, None, 3])],
|
||||
polymorphic_shapes=["(batch, 2, batch, height, 3)"],
|
||||
polymorphic_shapes=["batch, 2, batch, height, 3"],
|
||||
expected_output_signature=tf.TensorSpec([None, 6, None]))
|
||||
|
||||
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
|
||||
@ -899,7 +920,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
self.CheckShapePolymorphism(
|
||||
lambda x: x.reshape([x.shape[0], -1, 3]),
|
||||
input_signature=[tf.TensorSpec([None, 2, 4])],
|
||||
polymorphic_shapes=["(batch, _, _)"],
|
||||
polymorphic_shapes=[PS("batch", ...)],
|
||||
expected_output_signature=tf.TensorSpec([None, 1]))
|
||||
|
||||
def test_reshape_compiled(self):
|
||||
@ -917,7 +938,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
traced = False
|
||||
# If we get_concrete_function we trace once
|
||||
f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes_experimental=["(b, _, _)"]),
|
||||
f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes=[PS("b", ...)]),
|
||||
autograph=False,
|
||||
jit_compile=True).get_concrete_function(tf.TensorSpec([None, 2, 3], x.dtype))
|
||||
self.assertTrue(traced)
|
||||
@ -1016,7 +1037,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
f_tf = self.CheckShapePolymorphism(
|
||||
f_jax,
|
||||
input_signature=[tf.TensorSpec([None, 1], dtype=x.dtype)],
|
||||
polymorphic_shapes=["(b, _)"],
|
||||
polymorphic_shapes=[PS("b", ...)],
|
||||
expected_output_signature=tf.TensorSpec([None]))
|
||||
|
||||
self.assertAllClose(res_jax, f_tf(x))
|
||||
@ -1028,7 +1049,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
self.CheckShapePolymorphism(
|
||||
f_jax,
|
||||
input_signature=[tf.TensorSpec([None, None])],
|
||||
polymorphic_shapes=["(b1, b2)"],
|
||||
polymorphic_shapes=[PS("b1", "b2")],
|
||||
expected_output_signature=tf.TensorSpec([None]))
|
||||
|
||||
|
||||
|
@ -251,7 +251,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
must match the `input_signature`. (see jax2tf.convert).
|
||||
"""
|
||||
f_tf = tf.function(
|
||||
jax2tf.convert(f_jax, polymorphic_shapes_experimental=polymorphic_shapes),
|
||||
jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes),
|
||||
autograph=False,
|
||||
input_signature=input_signature)
|
||||
concrete_f_tf = f_tf.get_concrete_function(*input_signature)
|
||||
|
@ -9,6 +9,8 @@ filterwarnings =
|
||||
ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
|
||||
ignore:can't resolve package from __spec__ or __package__:ImportWarning
|
||||
ignore:Using or importing the ABCs.*:DeprecationWarning
|
||||
# jax2tf tests due to mix of JAX and TF
|
||||
ignore:numpy.ufunc size changed
|
||||
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
|
||||
addopts = --doctest-glob="*.rst"
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user