Cleanup the API, and more documentation

George Necula 2021-04-06 11:43:06 +03:00
parent 14737e365e
commit d9468c7513
8 changed files with 200 additions and 148 deletions

View File

@ -129,7 +129,7 @@ sets up symbolic links from site-packages into the repository.
To run all the JAX tests, we recommend using `pytest-xdist`, which can run tests in
parallel. First, install `pytest-xdist` and `pytest-benchmark` by running
`pip install pytest-xdist pytest-benchmark`.
`pip install -r build/test-requirements.txt`.
Then, from the repository root directory run:
```
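# assumed command -- the hunk truncates here; per JAX's developer docs:
pytest -n auto tests
```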

View File

@ -1316,25 +1316,25 @@ def symbolic_equal_one_of_dim(d1: DimSize, dlist: Sequence[DimSize]) -> bool:
def symbolic_equal_shape(s1: Shape, s2: Shape) -> bool:
"""See DimensionHandler.symbolic_equal."""
return (len(s1) == len(s2) and
all(safe_map(symbolic_equal_dim, s1, s2)))
all(map(symbolic_equal_dim, s1, s2)))
def greater_equal_dim(d1: DimSize, d2: DimSize) -> bool:
return _get_dim_handler(d1, d2).greater_equal(d1, d2)
def greater_equal_shape(s1: Shape, s2: Shape) -> bool:
return all(safe_map(greater_equal_dim, s1, s2))
return all(map(greater_equal_dim, s1, s2))
def sum_dim(*ds: DimSize) -> DimSize:
return _get_dim_handler(*ds).sum(*ds)
def sum_shapes(*ss: Shape) -> Shape:
return tuple(safe_map(sum_dim, *ss))
return tuple(map(sum_dim, *ss))
def diff_dim(d1: DimSize, d2: DimSize) -> DimSize:
return _get_dim_handler(d1, d2).diff(d1, d2)
def diff_shape(s1: Shape, s2: Shape) -> Shape:
return tuple(safe_map(diff_dim, s1, s2))
return tuple(map(diff_dim, s1, s2))
def divide_shape_sizes(s1: Shape, s2: Shape) -> int:
s1 = s1 or (1,)
@ -1349,14 +1349,14 @@ def dilate_dim(d: DimSize, dilation: DimSize) -> DimSize:
return _get_dim_handler(d, dilation).dilate(d, dilation)
def dilate_shape(s: Shape, dilations: Sequence[int]) -> Shape:
return tuple(safe_map(dilate_dim, s, dilations))
return tuple(map(dilate_dim, s, dilations))
def stride_dim(d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
return _get_dim_handler(d, window_size, window_stride).stride(d, window_size, window_stride)
def stride_shape(s: Shape, window_size: Shape, window_stride: Shape) -> Shape:
"""(s - window_size) // window_stride + 1"""
return tuple(safe_map(stride_dim, s, window_size, window_stride))
return tuple(map(stride_dim, s, window_size, window_stride))
def _canonicalize_dimension(dim: DimSize) -> DimSize:
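These helpers route all shape arithmetic through a `DimensionHandler`, so the same code paths serve both concrete integer dimensions and symbolic dimension variables. For orientation, a self-contained sketch of the concrete-integer case (illustrative only; the real helpers dispatch via `_get_dim_handler`):

```
from typing import Sequence, Tuple

def stride_dim(d: int, window_size: int, window_stride: int) -> int:
    # One output position per full placement of the window.
    return (d - window_size) // window_stride + 1

def stride_shape(s: Sequence[int], window_size: Sequence[int],
                 window_stride: Sequence[int]) -> Tuple[int, ...]:
    # (s - window_size) // window_stride + 1, elementwise.
    return tuple(map(stride_dim, s, window_size, window_stride))

def dilate_dim(d: int, dilation: int) -> int:
    # LAX-style dilation: insert `dilation - 1` holes between elements.
    return 0 if d == 0 else 1 + dilation * (d - 1)

print(stride_shape((8, 10), (2, 3), (2, 2)))  # (4, 4)
print(dilate_dim(5, 2))                       # 9
```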

View File

@ -13,5 +13,5 @@
# limitations under the License.
# flake8: noqa: F401
from .jax2tf import convert, shape_as_value, split_to_logical_devices
from .jax2tf import convert, shape_as_value, split_to_logical_devices, PolyShape
from .call_tf import call_tf
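With `PolyShape` re-exported from the package, callers can reach it from the top level, as the tests below do with `PS = jax2tf.PolyShape`. A one-line sketch:

```
from jax.experimental import jax2tf

# Leading dimension is the variable "batch"; the rest are taken as-is.
spec = jax2tf.PolyShape("batch", ...)
```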

View File

@ -34,7 +34,7 @@ from jax._src.lax import linalg as lax_linalg
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import fft as lax_fft
import jax._src.random
from . import shape_poly
from jax.experimental.jax2tf import shape_poly
import numpy as np
import tensorflow as tf # type: ignore[import]
@ -48,6 +48,7 @@ from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding # ty
from jax.lib import xla_client
PolyShape = shape_poly.PolyShape
# The scope name needs to be a valid TensorFlow name. See
# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/core/framework/node_def_util.cc#L731
@ -119,8 +120,8 @@ def _xla_path_disabled_error(primitive_name: str) -> Exception:
@functools.partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun: Callable, *,
polymorphic_shapes_experimental: Optional[Sequence[Any]]=None,
in_shapes=None,
polymorphic_shapes: Optional[Sequence[Any]]=None,
in_shapes=None, # DEPRECATED
with_gradient=True, enable_xla=True) -> Callable:
"""Transforms `fun` to be executed by TensorFlow.
@ -129,43 +130,45 @@ def convert(fun: Callable, *,
Args:
fun: Function to be transformed. Its arguments and return value should be
JAX arrays, or (nested) standard Python containers (tuple/list/dict)
thereof.
JAX arrays, or nested standard Python containers (tuple/list/dict)
thereof (pytrees).
polymorphic_shapes_experimental: an optional sequence of shape specifications,
one for each argument of the function to be converted. Default is `None`,
in which case the argument shape specifications are taken from the shapes
of the actual arguments.
A non-default `polymorphic_shapes` is used to specify shape variables for
some of the input dimensions, requiring that the conversion to TF be valid
for any possible non-zero values of those shape variables.
polymorphic_shapes: Specifies input shapes to be treated polymorphically
during conversion.
.. warning::
The shape-polymorphic conversion is an experimental feature. It is meant
to be sound, but it is known to reject some JAX programs that are
shape polymorphic. The details of this feature can change.
It should be a Python object with the same pytree structure as,
or a prefix of, the tuple of arguments to the function,
but with a shape specification corresponding to each argument.
The default value is `None`, which is a shortcut for a tuple of `None`,
one for each argument, denoting that all shapes are monomorphic.
See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
A shape specification for an array argument
should be an object `PolyShape(dim0, dim1, ..., dimn)`
where each `dim` is a dimension specification: a positive integer denoting
a monomorphic dimension of the given size,
or a string denoting a dimension variable assumed to range over non-zero
dimension sizes,
or the special placeholder string "_" denoting a monomorphic dimension
whose size is given by the actual argument.
As a shortcut, an Ellipsis suffix in the
list of dimension specifications stands for a list of "_" placeholders.
For convenience, a shape specification can also be given as a string
representation, e.g.: "batch, ...", "batch, height, width, _", possibly
with surrounding parentheses: "(batch, ...)".
The conversion fails if it cannot ensure that it would produce the same
sequence of TF ops for any non-zero values of shape variables. This feature
is experimental, and it may fail loudly even for code that is actually
shape polymorphic.
sequence of TF ops for any non-zero values of the dimension variables.
If an argument is a pytree, then the
shape specification must be a matching pytree or `None`.
See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
A shape specification should be a string, with comma-separated dimension
specifications, and optionally wrapped in parentheses. A dimension
specification is either a number, or the placeholder `_`, or a lowercase
word denoting a name for a dimension variable.
In the presence of dimension variables, the conversion is done with a
shape abstraction that allows any non-zero concrete value for each variable.
Examples of shape specifications:
* `[None, "(batch, 16)"]`: no specification for the first argument (takes
the shape from the actual argument); the second argument is a 2D
array with the first dimension size set to a variable `batch` and the
second dimension 16.
* `["(batch, _)", "(batch,)"]`: the leading dimensions of the two arguments
must match. The second dimension of the first argument is taken from the
actual argument shape.
See [the README](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
in_shapes: DEPRECATED in favor of `polymorphic_shapes_experimental`.
in_shapes: DEPRECATED in favor of `polymorphic_shapes`.
with_gradient: if set, will add a tf.custom_gradient to the converted
function, by converting the ``jax.vjp(fun)``. Only first-order
@ -185,7 +188,6 @@ def convert(fun: Callable, *,
global _enable_xla
_enable_xla = enable_xla
api._check_callable(fun)
polymorphic_shapes = polymorphic_shapes_experimental
def converted_fun(*args: TfVal) -> TfVal:
# TODO: is there a better way to check if we are inside a transformation?
@ -214,13 +216,14 @@ def convert(fun: Callable, *,
else:
if not isinstance(polymorphic_shapes, Sequence) or len(args) != len(polymorphic_shapes):
msg = ("polymorphic_shapes must be a sequence with the same length as the argument list "
f"({len(args)}). Got polymorphic_shapes_experimental={polymorphic_shapes}.")
f"({len(args)}). Got polymorphic_shapes={polymorphic_shapes}.")
raise TypeError(msg)
polymorphic_shapes_ = tuple(polymorphic_shapes)
# Expand the in_shapes to match the argument pytree
polymorphic_shapes_flat = tuple(api_util.flatten_axes("jax2tf.convert polymorphic_shapes",
in_tree.children()[0], polymorphic_shapes_))
in_tree.children()[0],
polymorphic_shapes_))
# Construct the abstract values for the flat arguments, possibly based on
# the input shapes and the in_shapes if given. May create new shape
@ -262,7 +265,7 @@ def convert(fun: Callable, *,
# TODO: enable higher-order gradients
with tf.name_scope("jax2tf_vjp"):
in_cts = convert(fun_vjp_jax, with_gradient=False,
polymorphic_shapes_experimental=vjp_polymorphic_shapes)(args, out_cts)
polymorphic_shapes=vjp_polymorphic_shapes)(args, out_cts)
return in_cts
try:
@ -295,6 +298,7 @@ def convert(fun: Callable, *,
return converted_fun
# Internals
@ -388,7 +392,7 @@ def _tfval_shape_dtype(val: TfVal) -> Tuple[Sequence[Optional[int]], DType]:
# function arguments.
_ShapeEnv = Dict[shape_poly.DimVar, TfVal]
def _args_to_avals_and_env(args: Sequence[TfVal],
polymorphic_shapes: Sequence[Optional[str]]) -> \
polymorphic_shapes: Sequence[Optional[Union[str, PolyShape]]]) -> \
Tuple[Sequence[core.AbstractValue], _ShapeEnv]:
"""Computes abstract values and a dimension environment for arguments.
@ -421,7 +425,7 @@ def _args_to_avals_and_env(args: Sequence[TfVal],
return core.ShapedArray(aval_shape, dtype)
avals = tuple(map(input_aval, args, polymorphic_shapes))
avals = tuple(map(input_aval, args, polymorphic_shapes)) # type: ignore
return avals, shapeenv
# A shape environment maps shape variables to TfVal.
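Putting the renamed API together, a minimal usage sketch (the function `f`, shapes, and values below are illustrative, not from this commit):

```
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):  # x: f32[batch, 8]
  return jnp.sum(x, axis=1)

# "b" is a dimension variable: internally it becomes a shape_poly.DimVar in
# the abstract values, and the conversion is valid for any batch size.
f_tf = tf.function(jax2tf.convert(f, polymorphic_shapes=["(b, 8)"]),
                   autograph=False,
                   input_signature=[tf.TensorSpec([None, 8], tf.float32)])
print(f_tf(np.ones((3, 8), np.float32)))  # tf.Tensor([8. 8. 8.], shape=(3,))
```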

View File

@ -18,8 +18,7 @@ For usage instructions, read the jax2tf.convert docstring, and the
"""
import collections
import string
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
from jax import core
@ -114,14 +113,20 @@ def _split_shape_ints(shape: Shape) -> Tuple[Sequence[int], Sequence[DimVar]]:
return shape_ints, shape_vars
class ShapeSyntaxError(Exception): pass
class PolyShape(tuple):
"""Tuple of polymorphic dimension specifications.
_identifiers = frozenset(string.ascii_lowercase)
def parse_spec(spec: Optional[str],
See docstring of :func:`jax2tf.convert`.
"""
def __new__(cls, *dim_specs):
return tuple.__new__(PolyShape, dim_specs)
def parse_spec(spec: Optional[Union[str, PolyShape]],
arg_shape: Sequence[Optional[int]]) -> Tuple[DimSize, ...]:
"""Parse the shape polymorphic specification for one array argument.
Args:
spec: a shape polymorphic specification.
spec: a shape polymorphic specification, either a string, or a PolyShape.
arg_shape: an actual shape, possibly containing unknown dimensions (None).
The placeholders `_` in the specification are replaced with the values from
@ -129,55 +134,75 @@ def parse_spec(spec: Optional[str],
See the README.md for usage.
"""
shape_var_map: Dict[str, Set[int]] = collections.defaultdict(set)
def _parse_dim(dim_spec: str, dim_size: Optional[int]) -> Union[int, DimSize]:
if dim_spec == '_':
if dim_size is None:
msg = (f"polymorphic_shape '{spec}' has `_` placeholders for argument shape "
f"dimensions that are unknown: {arg_shape}")
raise ValueError(msg)
return dim_size
elif dim_spec.isdigit():
spec_size = int(dim_spec)
if dim_size != spec_size:
if dim_size is None:
msg = (f"polymorphic_shape '{spec}' must contain shape variables for argument shape "
f"dimensions that are unknown: {arg_shape}")
else:
msg = (f"polymorphic_shape '{spec}' does not match argument shape {arg_shape}")
raise ValueError(msg)
return spec_size
elif dim_spec[0] in _identifiers:
if dim_size is not None:
shape_var_map[dim_spec].add(dim_size)
return DimVar(dim_spec)
if spec is None:
spec_tuple = (...,) # type: Tuple[Any,...]
elif isinstance(spec, PolyShape):
spec_tuple = tuple(spec)
elif isinstance(spec, str):
spec_ = spec.replace(" ", "")
if spec_[0] == "(":
if spec_[-1] != ")":
raise ValueError(spec)
spec_ = spec_[1:-1]
spec_ = spec_.rstrip(",")
if not spec_:
spec_tuple = ()
else:
raise ShapeSyntaxError(dim_spec)
specs = spec_.split(',')
def parse_dim(ds: str):
if ds == "...":
return ...
elif ds.isdigit():
return int(ds)
elif ds == "_" or ds.isalnum():
return ds
else:
raise ValueError(f"PolyShape '{spec}' has invalid syntax")
if not spec:
if any(d is None for d in arg_shape):
msg = ("polymorphic_shape must be specified when the argument "
f"shape {arg_shape} is partially known.")
raise ValueError(msg)
return tuple(arg_shape)
if spec[0] == '(':
if spec[-1] != ')':
raise ShapeSyntaxError(spec)
spec_ = spec[1:-1]
spec_tuple = tuple(map(parse_dim, specs))
else:
spec_ = spec
specs = spec_.replace(' ', '').strip(',').split(',')
if len(specs) != len(arg_shape):
msg = (f"polymorphic_shape '{spec}' has different rank than argument "
f"shape {arg_shape}")
raise ValueError(msg)
dims = tuple(map(_parse_dim, specs, arg_shape))
raise ValueError(f"PolyShape '{spec}' must be either None, a string, or PolyShape.")
ds_ellipses = tuple(ds for ds in spec_tuple if ds == ...)
if ds_ellipses:
if len(ds_ellipses) > 1 or spec_tuple[-1] != ...:
raise ValueError(f"PolyShape '{spec}' can contain Ellipsis only at the end.")
spec_tuple = spec_tuple[0:-1]
if len(arg_shape) >= len(spec_tuple):
spec_tuple = spec_tuple + ("_",) * (len(arg_shape) - len(spec_tuple))
if len(arg_shape) != len(spec_tuple):
raise ValueError(f"PolyShape '{spec}' must match the rank of arguments {arg_shape}.")
shape_var_map: Dict[str, Set[int]] = collections.defaultdict(set)
def _process_dim(i: int, dim_spec):
if not isinstance(dim_spec, (str, int)):
raise ValueError(f"PolyShape '{spec}' in axis {i} must contain only integers, strings, or Ellipsis.")
dim_size = arg_shape[i]
if dim_size is None:
if dim_spec == "_" or not isinstance(dim_spec, str):
msg = (f"PolyShape '{spec}' in axis {i} must contain a shape variable "
f"for unknown dimension in argument shape {arg_shape}")
raise ValueError(msg)
return DimVar(dim_spec)
else: # dim_size is known
if dim_spec == "_":
return dim_size
if isinstance(dim_spec, int):
if dim_spec != dim_size:
msg = (f"PolyShape '{spec}' in axis {i} must contain a constant or '_' "
f"for known dimension in argument shape {arg_shape}")
raise ValueError(msg)
return dim_size
# We have a dimension variable for a known dimension.
shape_var_map[dim_spec].add(dim_size)
return DimVar(dim_spec)
dims = tuple([_process_dim(i, ds) for i, ds in enumerate(spec_tuple)])
for dim_var, dim_var_values in shape_var_map.items():
if len(dim_var_values) != 1:
msg = (f"polymorphic shape variable '{dim_var}' corresponds to multiple "
f"values ({sorted(dim_var_values)}), in polymorphic_shape '{spec}' and "
msg = (f"PolyShape '{spec}' has dimension variable '{dim_var}' "
f"corresponding to multiple values ({sorted(dim_var_values)}), for "
f"argument shape {arg_shape}")
raise ValueError(msg)
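Net effect of the new parser: string specifications, `PolyShape` objects, and the `...` shortcut all normalize to the same tuple of dimensions. A small sketch of the equivalence (printed forms are indicative):

```
from jax.experimental.jax2tf import PolyShape, shape_poly

# For an argument of known shape (2, 3), each spec parses to (batch, 3),
# where "batch" becomes a dimension variable (DimVar):
for spec in ["(batch, _)", "batch, ...", PolyShape("batch", ...)]:
  print(shape_poly.parse_spec(spec, (2, 3)))
```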

View File

@ -14,7 +14,7 @@
"""Tests for the jax2tf conversion for control-flow primitives."""
from absl.testing import absltest
from typing import Dict, Optional, Sequence
from typing import Dict, Optional, Sequence, Union
import collections
import functools
@ -44,6 +44,8 @@ config.parse_flags_with_absl()
# Import after parsing flags
from jax.experimental.jax2tf.tests import primitive_harness
PS = jax2tf.PolyShape
class ShapePolyTest(tf_test_util.JaxToTfTestCase):
def setUp(self):
@ -61,18 +63,18 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
self.CheckShapePolymorphism(f_jax,
input_signature=[tf.TensorSpec([2, None])],
polymorphic_shapes=["(_, h)"],
polymorphic_shapes=["_, h"],
expected_output_signature=tf.TensorSpec([2, None]))
self.CheckShapePolymorphism(f_jax,
input_signature=[tf.TensorSpec([None, None])],
polymorphic_shapes=["(h, h)"],
polymorphic_shapes=["h, h"],
expected_output_signature=tf.TensorSpec([None, None]))
def test_arg_avals(self):
"""Test conversion of actual arguments to abstract values"""
def check_avals(*, args: Sequence[jax2tf.jax2tf.TfVal],
polymorphic_shapes: Sequence[Optional[str]],
polymorphic_shapes: Sequence[Optional[Union[str, PS]]],
expected_avals: Sequence[core.ShapedArray]):
avals, shape_env = jax2tf.jax2tf._args_to_avals_and_env(args,
polymorphic_shapes) # The function under test
@ -98,11 +100,11 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
check_avals(args=[tf_const((2, 3))],
polymorphic_shapes=[None],
expected_avals=(shaped_array("2, 3", [2, 3]),))
expected_avals=(shaped_array("2, 3,", [2, 3]),))
check_avals(args=[tf_var((2, 3))],
polymorphic_shapes=[None],
expected_avals=(shaped_array("2, 3", [2, 3]),))
expected_avals=(shaped_array("(2, 3)", [2, 3]),))
check_avals(args=[const((2, 3))],
polymorphic_shapes=["(2, 3)"],
@ -112,13 +114,17 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
polymorphic_shapes=["(_, 3)"],
expected_avals=(shaped_array("2, 3", [2, 3]),))
check_avals(args=[tf_const((2, 3))],
polymorphic_shapes=[PS("_", 3)],
expected_avals=(shaped_array("2, 3", [2, 3]),))
# Partially known shapes for the arguments
check_avals(args=[tf_var([None, 3], initializer_shape=(2, 3))],
polymorphic_shapes=["(b, 3)"],
polymorphic_shapes=[PS("b", ...)],
expected_avals=(shaped_array("(b, 3)", (2, 3)),))
check_avals(args=[tf_var([None, None], initializer_shape=(2, 3))],
polymorphic_shapes=[("h, h")],
polymorphic_shapes=["h, h"],
expected_avals=(shaped_array("(h, h)", (2, 2)),))
check_avals(args=[tf_var([2, None], initializer_shape=(2, 3))],
@ -130,66 +136,72 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
expected_avals=(shaped_array("(c, b, a)", (2, 3, 4)),),)
# Some errors
with self.assertRaisesRegex(
ValueError, re.escape("PolyShape ')(' has invalid syntax")):
check_avals(args=[const((2, 3))],
polymorphic_shapes=[")("],
expected_avals=None)
with self.assertRaisesRegex(ValueError,
re.escape("polymorphic_shape must be specified when the argument shape (2, None) is partially known")):
re.escape("PolyShape '..., 3' can contain Ellipsis only at the end.")):
check_avals(args=[const((2, 3))],
polymorphic_shapes=["..., 3"],
expected_avals=None)
with self.assertRaisesRegex(ValueError,
re.escape("PolyShape '2, 3, 4, ...' must match the rank of arguments (2, 3).")):
check_avals(args=[const((2, 3))],
polymorphic_shapes=["2, 3, 4, ..."],
expected_avals=None)
with self.assertRaisesRegex(ValueError,
re.escape("PolyShape '(Ellipsis, 3)' can contain Ellipsis only at the end.")):
check_avals(args=[const((2, 3))],
polymorphic_shapes=[PS(..., 3)],
expected_avals=None)
with self.assertRaisesRegex(ValueError,
re.escape("PolyShape 'None' in axis 1 must contain a shape variable for unknown dimension in argument shape (2, None)")):
check_avals(args=[tf_var([2, None], initializer_shape=(2, 3))],
polymorphic_shapes=[None],
expected_avals=None)
with self.assertRaisesRegex(
ValueError,
re.escape("polymorphic_shape '()' has different rank than argument shape (2, 3)")):
re.escape("PolyShape '()' must match the rank of arguments (2, 3)")):
check_avals(args=[const((2, 3))],
polymorphic_shapes=["()"],
expected_avals=None)
with self.assertRaisesRegex(
ValueError,
re.escape("polymorphic_shape '(_, _)' has `_` placeholders for argument shape dimensions that are unknown: (2, None)")):
re.escape("PolyShape '(_, _)' in axis 1 must contain a shape variable for unknown dimension in argument shape (2, None)")):
check_avals(args=[tf_var([2, None], initializer_shape=(2, 3))],
polymorphic_shapes=["(_, _)"],
expected_avals=None)
with self.assertRaisesRegex(
ValueError,
re.escape("polymorphic_shape '(2, 13)' does not match argument shape (2, 3)")):
re.escape("PolyShape '(2, 13)' in axis 1 must contain a constant or '_' for known dimension in argument shape (2, 3)")):
check_avals(args=[const((2, 3))],
polymorphic_shapes=["(2, 13)"],
expected_avals=None)
with self.assertRaisesRegex(
ValueError,
re.escape("polymorphic_shape '(2, 3)' must contain shape variables for argument shape dimensions that are unknown: (2, None)")):
re.escape("PolyShape '(2, 3)' in axis 1 must contain a shape variable for unknown dimension in argument shape (2, None)")):
check_avals(args=[tf_var([2, None], initializer_shape=(2, 3))],
polymorphic_shapes=["(2, 3)"],
expected_avals=None)
with self.assertRaisesRegex(
ValueError,
re.escape("polymorphic shape variable 'a' corresponds to multiple values ([2, 3]), in polymorphic_shape '(a, a)' and argument shape (2, 3)")):
re.escape("PolyShape '(a, a)' has dimension variable 'a' corresponding to multiple values ([2, 3]), for argument shape (2, 3)")):
check_avals(args=[tf_var([2, 3], initializer_shape=(2, 3))],
polymorphic_shapes=["(a, a)"],
expected_avals=None)
def test_bad_polymorphic_shapes(self):
def add2(x, y):
return x + y
with self.assertRaisesRegex(shape_poly.ShapeSyntaxError, ""):
self.CheckShapePolymorphism(add2,
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
polymorphic_shapes=[") + (", None],
expected_output_signature=tf.TensorSpec([None]))
with self.assertRaisesRegex(TypeError,
re.escape("polymorphic_shapes must be a sequence with the same length as the argument list (2). "
"Got polymorphic_shapes_experimental=['(b, 4)']")):
self.CheckShapePolymorphism(add2,
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
polymorphic_shapes=["(b, 4)"],
expected_output_signature=tf.TensorSpec([None]))
def test_pytree(self):
"""Arguments and polymorphic_shapes are pytrees."""
@ -204,8 +216,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
input_signature=[([tf.TensorSpec([None]), tf.TensorSpec([None])],
[tf.TensorSpec([None])]),
dict(a=tf.TensorSpec([None]), b=tf.TensorSpec([None]))],
polymorphic_shapes=[(["(v,)", "(v,)"], [("v,")]),
dict(a="(v,)", b="(v,)")],
polymorphic_shapes=[(["v", "v"], [("v")]),
dict(a="v", b="v")],
expected_output_signature=tf.TensorSpec([None]))
# Now partial polymorphic_shapes; the parts of the polymorphic_shapes that are not specified
@ -322,7 +334,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
y = np.ones((3,))
res_jax = f(x, y)
self.assertAllClose(res_jax,
jax2tf.convert(f, polymorphic_shapes_experimental=["(b, h)", "h"])(x, y))
jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])(x, y))
def test_shape_error(self):
"""Some of the examples from the README."""
@ -339,12 +351,21 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
with self.assertRaisesRegex(TypeError,
re.escape("add got incompatible shapes for broadcasting: (v,), (4,)")):
jax2tf.convert(lambda x, y: x + y,
polymorphic_shapes_experimental=["(v,)", "(4,)"])(four_ones, four_ones)
polymorphic_shapes=["(v,)", "(4,)"])(four_ones, four_ones)
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
re.escape("Shape variable comparison v == 4 is inconclusive")):
jax2tf.convert(lambda x: jnp.matmul(x, x),
polymorphic_shapes_experimental=["(v, 4)"])(np.ones((4, 4)))
polymorphic_shapes=["(v, 4)"])(np.ones((4, 4)))
def test_parse_poly_spec(self):
self.assertEqual((2, 3), shape_poly.parse_spec(None, (2, 3)))
self.assertEqual((2, 3), shape_poly.parse_spec("2, 3", (2, 3)))
self.assertEqual((2, 3), shape_poly.parse_spec("2, _", (2, 3)))
self.assertEqual((2, 3), shape_poly.parse_spec("2, ...", (2, 3)))
self.assertEqual((2, 3), shape_poly.parse_spec("...", (2, 3)))
self.assertEqual((2, 3), shape_poly.parse_spec(" ( 2 , 3 ) ", (2, 3)))
def test_dim_vars(self):
@ -432,9 +453,9 @@ class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):
return jnp.sum(x, axis=0) * jax2tf.shape_as_value(x)[0]
x = np.arange(3.)
self.assertAllClose(9., jax2tf.convert(f, polymorphic_shapes_experimental=["(b,)"])(x))
self.assertAllClose(9., jax2tf.convert(jax.jit(f), polymorphic_shapes_experimental=["(b,)"])(x))
self.assertAllClose(9., tf.function(jax2tf.convert(f, polymorphic_shapes_experimental=["(b,)"]))(x))
self.assertAllClose(9., jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x))
self.assertAllClose(9., jax2tf.convert(jax.jit(f), polymorphic_shapes=["(b,)"])(x))
self.assertAllClose(9., tf.function(jax2tf.convert(f, polymorphic_shapes=["(b,)"]))(x))
res_primal, res_tangent = jax2tf.convert(
lambda x, xt: jax.jvp(f, (x,), (xt,)),
@ -443,7 +464,7 @@ class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(np.array([3., 3., 3.]),
jax2tf.convert(jax.grad(f),
polymorphic_shapes_experimental=["b"])(x))
polymorphic_shapes=["b"])(x))
xv = np.arange(24.).reshape((2, 3, 4))
res_vmap = jax.vmap(f, in_axes=1)(xv)
@ -466,7 +487,7 @@ class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):
operand=None)
x = np.ones((2, 3, 4))
self.assertAllClose(1., f(x))
self.assertAllClose(1., jax2tf.convert(f, polymorphic_shapes_experimental=["(a, b, 4)"])(x))
self.assertAllClose(1., jax2tf.convert(f, polymorphic_shapes=["(a, b, 4)"])(x))
def test_mean0(self):
def f_jax(x):
@ -676,7 +697,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
f_jax,
input_signature=[tf.TensorSpec(lhs_shape),
tf.TensorSpec(rhs_shape)],
polymorphic_shapes=["b, _, _, _", "_, _, _, _"],
polymorphic_shapes=["B, ...", None],
expected_output_signature=tf.TensorSpec([None, 3, 3, 1]))
self.assertAllClose(f_jax(lhs, rhs), f_tf(lhs, rhs))
@ -785,7 +806,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
f_tf = self.CheckShapePolymorphism(
f,
input_signature=[tf.TensorSpec([None, 3, 4]), tf.TensorSpec([2], np.int32)],
polymorphic_shapes=["batch, _, _", "_"],
polymorphic_shapes=["(batch, _, _)", "(_)"],
expected_output_signature=tf.TensorSpec([None, 2, 4]))
self.assertAllClose(f(x, i), f_tf(x, i))
@ -877,13 +898,13 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
self.CheckShapePolymorphism(
lambda x: x.reshape([x.shape[0], -1]),
input_signature=[tf.TensorSpec([None, 2, 3])],
polymorphic_shapes=["(batch, _, _)"],
polymorphic_shapes=["batch, _, _"],
expected_output_signature=tf.TensorSpec([None, 6]))
self.CheckShapePolymorphism(
lambda x: x.reshape([x.shape[0], -1, x.shape[3], x.shape[2]]),
input_signature=[tf.TensorSpec([None, 2, None, None, 3])],
polymorphic_shapes=["(batch, 2, batch, height, 3)"],
polymorphic_shapes=["batch, 2, batch, height, 3"],
expected_output_signature=tf.TensorSpec([None, 6, None, None]))
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
@ -891,7 +912,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
self.CheckShapePolymorphism(
lambda x: x.reshape([x.shape[0], -1, x.shape[2]]),
input_signature=[tf.TensorSpec([None, 2, None, None, 3])],
polymorphic_shapes=["(batch, 2, batch, height, 3)"],
polymorphic_shapes=["batch, 2, batch, height, 3"],
expected_output_signature=tf.TensorSpec([None, 6, None]))
with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
@ -899,7 +920,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
self.CheckShapePolymorphism(
lambda x: x.reshape([x.shape[0], -1, 3]),
input_signature=[tf.TensorSpec([None, 2, 4])],
polymorphic_shapes=["(batch, _, _)"],
polymorphic_shapes=[PS("batch", ...)],
expected_output_signature=tf.TensorSpec([None, 1]))
def test_reshape_compiled(self):
@ -917,7 +938,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
traced = False
# If we get_concrete_function we trace once
f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes_experimental=["(b, _, _)"]),
f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes=[PS("b", ...)]),
autograph=False,
jit_compile=True).get_concrete_function(tf.TensorSpec([None, 2, 3], x.dtype))
self.assertTrue(traced)
@ -1016,7 +1037,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
f_tf = self.CheckShapePolymorphism(
f_jax,
input_signature=[tf.TensorSpec([None, 1], dtype=x.dtype)],
polymorphic_shapes=["(b, _)"],
polymorphic_shapes=[PS("b", ...)],
expected_output_signature=tf.TensorSpec([None]))
self.assertAllClose(res_jax, f_tf(x))
@ -1028,7 +1049,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
self.CheckShapePolymorphism(
f_jax,
input_signature=[tf.TensorSpec([None, None])],
polymorphic_shapes=["(b1, b2)"],
polymorphic_shapes=[PS("b1", "b2")],
expected_output_signature=tf.TensorSpec([None]))

View File

@ -251,7 +251,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
must match the `input_signature` (see jax2tf.convert).
"""
f_tf = tf.function(
jax2tf.convert(f_jax, polymorphic_shapes_experimental=polymorphic_shapes),
jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes),
autograph=False,
input_signature=input_signature)
concrete_f_tf = f_tf.get_concrete_function(*input_signature)

View File

@ -9,6 +9,8 @@ filterwarnings =
ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
ignore:can't resolve package from __spec__ or __package__:ImportWarning
ignore:Using or importing the ABCs.*:DeprecationWarning
# jax2tf tests due to mix of JAX and TF
ignore:numpy.ufunc size changed
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"