[jax2tf] Expand shape polymorphism support to use dimension polynomials as values.

The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding the dimension size for negative indices). In both of these cases
the size of a dimension needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.
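
To make the indexing case concrete, the normalization mentioned above amounts to
something like the following (a minimal sketch; `normalize_index` is a hypothetical
helper, not the library function):

```
import jax.numpy as jnp

def normalize_index(i, d):
  # Map an index in [-d, d) to [0, d).  When d is a polymorphic dimension,
  # it must first be turned into a value (see `dimension_as_value` below).
  return jnp.where(i < 0, i + d, i)
```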

This change introduces `core.dimension_as_value`, which must be used on
a dimension size before it can be used as a value in the array computation.
E.g.,

```
def average(x):
  return jnp.sum(x, axis=0) / core.dimension_as_value(x.shape[0])
```

This function is the identity function if the dimension size is
constant; otherwise it uses a new primitive, `shape_poly.dim_as_value_p`.
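
Here is a self-contained usage sketch, modeled on the tests added in this
change (the concrete input and the `"(b,)"` shape specification are only
illustrative):

```
import numpy as np
import jax.numpy as jnp
from jax import core
from jax.experimental import jax2tf

def average(x):
  # For a concrete x, dimension_as_value(3) is just the constant 3.
  return jnp.sum(x, axis=0) / core.dimension_as_value(x.shape[0])

x = np.arange(3.)
print(average(x))  # 1.0

# With a polymorphic leading dimension the same code still works; the value
# of `b` is recovered from the runtime shape of the TF argument.
f_tf = jax2tf.convert(average, polymorphic_shapes=["(b,)"])
print(f_tf(x))  # tf.Tensor(1.0, ...)
```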

Note that this does not fundamentally change the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but a shape never depends on the
input values. In fact, one could already have expressed
`dimension_as_value` as:

```
def dimension_as_value(d):
  return jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```
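
Compared with that workaround, the dedicated primitive is more direct: it does
not have to build and reduce an array of size `d` just to recover the dimension,
and its jax2tf conversion rule (`_dim_as_value_jax2tf` in the diff below) simply
evaluates the dimension polynomial in the shape environment.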

We were able to support `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, and `lax.dynamic_update_slice` by using
`core.dimension_as_value` internally, but to fully roll out the solution
we need to make `core.dimension_as_value` a public API and teach
users how to use it when they want shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when they are passed to lax primitives.
Author: George Necula
Date:   2021-07-16 20:01:22 +03:00
Parent: 87f72ac446
Commit: b62ceba91c
9 changed files with 227 additions and 175 deletions


@ -11,15 +11,18 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.19 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.18...main).
* New features:
* Improved the support for shape polymorphism in jax2tf for operations that
need to use a dimension size in array computation, e.g., `jnp.mean`.
({jax-issue}`#7317`)
## jaxlib 0.1.70 (unreleased)
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.
* Breaking changes:
* The host_callback mechnism now uses one thread per local device for
* The host_callback mechanism now uses one thread per local device for
making the calls to the Python callbacks. Previously there was a single
thread for all devices. This means that the callbacks may now be called
interleaved. The callbacks corresponding to one device will still be


@ -6763,7 +6763,8 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))
def _dynamic_slice_indices(operand, start_indices):
def _dynamic_slice_indices(operand, start_indices):
# Normalize the start_indices w.r.t. operand.shape
if len(start_indices) != operand.ndim:
msg = ("Length of slice indices must match number of operand dimensions ({} "
"vs {})")
@ -6772,15 +6773,13 @@ def _dynamic_slice_indices(operand, start_indices):
if start_indices.ndim != 1:
raise ValueError("Slice indices must be a 1D sequence, got {}"
.format(start_indices.shape))
return select(lt(start_indices, _zeros(start_indices)),
add(start_indices, _const(start_indices, operand.shape)),
start_indices)
else:
return [np.asarray(i + d if i < 0 else i, getattr(i, 'dtype', dtypes.int_))
if isinstance(i, (int, np.integer))
else select(lt(i, _const(i, 0)), add(i, _const(i, d)), i)
for i, d in zip(start_indices, operand.shape)]
start_indices = [i for i in start_indices]
return [np.asarray(i + d if i < 0 else i, getattr(i, 'dtype', dtypes.int_))
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d)
else select(lt(i, _const(i, 0)),
add(i, convert_element_type(core.dimension_as_value(d), _dtype(i))),
i)
for i, d in zip(start_indices, operand.shape)]
def _const(example, val):
dtype = _dtype(example)


@ -2162,9 +2162,9 @@ def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
if where is None:
if axis is None:
normalizer = size(a)
normalizer = core.dimension_as_value(size(a))
else:
normalizer = _axis_size(a, axis)
normalizer = core.dimension_as_value(_axis_size(a, axis))
else:
normalizer = sum(broadcast_to(where, shape(a)), axis, dtype=dtype, keepdims=keepdims)
@ -2187,9 +2187,9 @@ def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
if weights is None: # Treat all weights as 1
avg = mean(a, axis=axis)
if axis is None:
weights_sum = full((), size(a), dtype=avg.dtype)
weights_sum = full((), core.dimension_as_value(size(a)), dtype=avg.dtype)
else:
weights_sum = full_like(avg, a.shape[axis], dtype=avg.dtype)
weights_sum = full_like(avg, core.dimension_as_value(a.shape[axis]), dtype=avg.dtype)
else:
weights = asarray(weights)
@ -2212,7 +2212,7 @@ def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
if len(weights_shape) != 1:
raise ValueError("1D weights expected when shapes of a and "
"weights differ.")
if weights_shape[0] != a_shape[axis]:
if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
raise ValueError("Length of weights not "
"compatible with specified axis.")
@ -2247,9 +2247,9 @@ def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
if where is None:
if axis is None:
normalizer = size(a)
normalizer = core.dimension_as_value(size(a))
else:
normalizer = _axis_size(a, axis)
normalizer = core.dimension_as_value(_axis_size(a, axis))
else:
normalizer = sum(broadcast_to(where, shape(a)), axis, dtype=dtype, keepdims=keepdims)
normalizer = normalizer - ddof
@ -4782,9 +4782,14 @@ def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
def _normalize_index(index, axis_size):
"""Normalizes an index value in the range [-N, N) to the range [0, N)."""
if core.is_constant_dim(axis_size):
axis_size_val = _constant_like(index, axis_size)
else:
axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size),
_dtype(index))
return lax.select(
lax.lt(index, _constant_like(index, 0)),
lax.add(index, _constant_like(index, axis_size)),
lax.add(index, axis_size_val),
index)
@partial(jit, static_argnums=(2,))
@ -5199,7 +5204,7 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
abstract_i = None
# Handle basic int indexes.
if isinstance(abstract_i, (ConcreteArray,ShapedArray)) and _int(abstract_i):
if x_shape[x_axis] == 0:
if core.symbolic_equal_dim(x_shape[x_axis], 0):
# XLA gives error when indexing into an axis of size 0
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i


@ -1283,6 +1283,9 @@ class DimensionHandler:
"""Implements `0 if d == 0 else 1 + dilation * (d - 1))`"""
return 0 if d == 0 else 1 + dilation * (d - 1)
def as_value(self, d: DimSize):
"""Turns a dimension size into a JAX value that we can compute with."""
return d
_dimension_handler_int = DimensionHandler()
_SPECIAL_DIMENSION_HANDLERS: Dict[type, DimensionHandler] = {}
@ -1378,6 +1381,11 @@ def stride_shape(s: Shape, window_size: Shape, window_stride: Shape) -> Shape:
"""(s - window_size) // window_stride + 1"""
return tuple(map(stride_dim, s, window_size, window_stride))
def dimension_as_value(d: DimSize):
"""Turns a dimension size into a JAX value that we can compute with.
This is the identity function for constant dimensions."""
handler, ds = _dim_handler_and_canonical(d)
return handler.as_value(*ds)
def _canonicalize_dimension(dim: DimSize) -> DimSize:
if type(dim) in _SPECIAL_DIMENSION_HANDLERS:


@ -13,5 +13,5 @@
# limitations under the License.
# flake8: noqa: F401
from .jax2tf import convert, dtype_of_val, shape_as_value, split_to_logical_devices, PolyShape
from .jax2tf import convert, dtype_of_val, split_to_logical_devices, PolyShape
from .call_tf import call_tf


@ -648,91 +648,6 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
return shape_poly.eval_shape(shape, _thread_local_state.shape_env)
def shape_as_value(x):
"""Injects the shape of `x` as an array value.
**Experimental: please give feedback, and expect changes!**
This allows the use of a shape expression as array argument to JAX functions.
A typical example is for implementing a mean operation:
jnp.sum(x) / np.prod(jax2tf.shape_as_value(x))
"""
# return shape_as_value_p.bind(x)
return NotImplementedError("shape_as_value is deprecated")
# # TODO: move this to masking or to some common library, if approved
# shape_as_value_p = core.Primitive("shape_as_value")
# shape_as_value_p.multiple_results = True
# def _shape_as_value_impl(x):
# x_shape = np.shape(x)
# def dim_to_int(dim: shape_poly.DimSize) -> int:
# dim_int = _poly_dim_to_tf_dim(dim)
# if dim_int is None:
# msg = ("shape_as_value is not implemented for non-constant shapes "
# "except for masking and jax2tf. "
# f"Has shape: {x_shape}")
# raise TypeError(msg)
# else:
# return dim_int
# return tuple(map(dim_to_int, x_shape))
#
# shape_as_value_p.def_impl(_shape_as_value_impl)
#
# def _shape_as_value_abstract(x_aval: core.AbstractValue) -> Sequence[core.AbstractValue]:
# rank = len(x_aval.shape) # type: ignore[attr-defined]
# return (core.ShapedArray((), dtypes.canonicalize_dtype(np.int_), weak_type=True),) * rank
#
# shape_as_value_p.def_abstract_eval(_shape_as_value_abstract)
#
# def _shape_as_value_translation(comp, x):
# return xla_client._xla.ops.Tuple(comp,
# tuple(xb.constant(comp, d)
# for d in comp.GetShape(x).dimensions()))
#
# xla.translations[shape_as_value_p] = _shape_as_value_translation
#
# def _shape_as_value_jvp_rule(primals, tangents):
# # The shape does not depend on the contents of the input
# x, = primals
# zero = ad.Zero.from_value(0.)
# return shape_as_value(x), (zero,) * len(x.shape)
#
# ad.primitive_jvps[shape_as_value_p] = _shape_as_value_jvp_rule
#
# def _shape_as_value__batching_rule(batched_args, batch_dims):
# xv, = batched_args
# batch_dim, = batch_dims
# batch_size = xv.shape[batch_dim]
# batched_shape = shape_as_value(xv)
# one_shape = batched_shape[0:batch_dim] + batched_shape[batch_dim+1:]
# res = tuple(jnp.broadcast_to(d, (batch_size, 1)) for d in one_shape)
# return res, (0,) * len(one_shape)
#
# batching.primitive_batchers[shape_as_value_p] = _shape_as_value__batching_rule
#
# def _shape_as_value_masking_rule(operands, operands_logical_shapes):
# x_logical_shape, = operands_logical_shapes
# return tuple(x_logical_shape)
#
# masking.masking_rules[shape_as_value_p] = _shape_as_value_masking_rule
#
# def _shape_as_value_tf(x: TfVal,
# _in_avals: Sequence[core.AbstractValue],
# _out_aval: core.AbstractValue) -> TfVal:
# x_aval = _in_avals[0]
# def dim_to_tfval(dim: shape_poly.DimSize, dim_idx: int) -> TfVal:
# dim_int = _poly_dim_to_tf_dim(dim)
# if dim_int is not None:
# return tf.convert_to_tensor(dim_int)
# else:
# return tf.shape(x)[dim_idx]
# return tuple(dim_to_tfval(dim, dim_idx)
# for dim_idx, dim in enumerate(x_aval.shape)) # type: ignore[attr-defined]
#
# tf_impl_with_avals[shape_as_value_p] = _shape_as_value_tf
# TODO(b/26854495): pylint doesn't understand slots and inheritance.
# pylint: disable=assigning-non-slot
@ -775,8 +690,7 @@ class TensorFlowTracer(core.Tracer):
else:
assert self._aval.dtype == _to_jax_dtype(val.dtype), f"expected {self._aval.dtype} == {val.dtype}"
for aval_dim, val_dim in zip(
self._aval.shape, val_shape): # type: ignore[attr-defined]
for aval_dim, val_dim in zip(self._aval.shape, val_shape): # type: ignore[attr-defined]
if val_dim is None:
assert shape_poly.is_poly_dim(aval_dim), f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
elif not shape_poly.is_poly_dim(aval_dim):
@ -2481,8 +2395,8 @@ def _slice(operand, start_indices, limit_indices, strides, _in_avals,
_eval_shape(strides)))
out = operand[slices]
# TODO(b/184503314): improve shape inference for __getitem__
#out.set_shape(_aval_to_tf_shape(_out_aval))
#assert False, f"start_indices={start_indices}, limit_indices={limit_indices}, strides={strides}, out={out}"
# E.g., operand.shape=(b, 5, 3), start_indices=(0, 1, 1), limit_indices=(b, 5, 3), strides=(1, 2, 1)
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
@ -3075,3 +2989,5 @@ def _register_checkpoint_pytrees():
_register_checkpoint_pytrees()
shape_poly._register_conversion_rules()


@ -30,6 +30,8 @@ from jax._src.numpy import lax_numpy
import opt_einsum
from jax import config
from jax import core
from . import jax2tf as jax2tf_internal
import numpy as np
import tensorflow as tf # type: ignore[import]
@ -419,6 +421,10 @@ class DimensionHandlerPoly(core.DimensionHandler):
f"window_size '{window_size}', stride '{window_stride}'. Reason: {e}.")
return d
def as_value(self, d: DimSize):
"""Turns a dimension size into a Jax value that we can compute with."""
return _dim_as_value(d)
core._SPECIAL_DIMENSION_HANDLERS[_DimPolynomial] = DimensionHandlerPoly()
def _einsum_contract_path(*operands, **kwargs):
@ -468,6 +474,27 @@ def _einsum_contract_path(*operands, **kwargs):
lax_numpy._polymorphic_einsum_contract_path_handlers[_DimPolynomial] = _einsum_contract_path
# A JAX primitive with no array arguments but with a dimension parameter
# that is a DimPoly. The value of the primitive is the value of the dimension.
# This primitive is used only in the context of jax2tf, so it does not need
# XLA translation rules.
dim_as_value_p = core.Primitive("dim_as_value")
def _dim_as_value_abstract(dim: DimSize) -> core.AbstractValue:
return core.ShapedArray((), np.int32)
dim_as_value_p.def_abstract_eval(_dim_as_value_abstract)
def _dim_as_value(dim: DimSize):
return dim_as_value_p.bind(dim=dim)
def _dim_as_value_jax2tf(dim: DimSize):
dim_tf, = jax2tf_internal._eval_shape((dim,))
assert dim_tf.dtype == tf.int32
return dim_tf
def _register_conversion_rules():
jax2tf_internal.tf_impl[dim_as_value_p] = _dim_as_value_jax2tf
class PolyShape(tuple):
"""Tuple of polymorphic dimension specifications.


@ -21,6 +21,7 @@ import functools
from functools import partial
import operator
import re
import unittest
import jax
from jax import core
@ -35,7 +36,6 @@ import numpy as np
from jax.experimental.jax2tf.tests import tf_test_util
import tensorflow as tf # type: ignore[import]
import unittest
from jax.config import config
@ -643,9 +643,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
tf.TensorSpec([3, 4, 8, 9]))
self.assertEqual((3, 4, 8, 8), tuple(tf_grad.output_shapes[0]))
# TODO(necula): Understand why we get partial shapes for the gradient, even
# though all inputs have known shapes.
self.assertEqual((3, 4, 8, None), tuple(tf_grad.output_shapes[1]))
self.assertEqual((3, 4, 8, 9), tuple(tf_grad.output_shapes[1]))
def test_gradients_pytree(self):
"""Shape polymorphism with gradients and pytrees for inputs and outputs."""
@ -878,15 +876,12 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
self.assertEqual(1, f_tf(x45))
class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):
def setUp(self):
raise unittest.SkipTest("shape_as_value not supported anymore. See #6080.")
class DimAsValueTest(tf_test_util.JaxToTfTestCase):
def test_concrete_shapes(self):
# Test shape_as_value with concrete shapes. All transformations work.
# Test dim_as_value with concrete shapes.
def f(x):
return jnp.sum(x, axis=0) * jax2tf.shape_as_value(x)[0]
return jnp.sum(x, axis=0) * core.dimension_as_value(x.shape[0])
x = np.arange(3.)
self.assertAllClose(9., f(x))
@ -903,15 +898,11 @@ class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):
res_iter = jnp.stack([f(xv[:, i, :]) for i in range(xv.shape[1])])
self.assertAllClose(res_iter, res_vmap)
res_mask2, _ = jax.mask(f, polymorphic_shapes=["(b,)"])([x], dict(b=2))
self.assertAllClose(2., res_mask2)
res_mask3, _ = jax.mask(f, polymorphic_shapes=["(b,)"])([x], dict(b=3))
self.assertAllClose(9., res_mask3)
def test_dynamic_shapes(self):
# Test shape_as_value with dynamic shapes. All transformations work.
# Test dim_as_value with dynamic shapes.
def f(x):
return jnp.sum(x, axis=0) * jax2tf.shape_as_value(x)[0]
return jnp.sum(x, axis=0) * core.dimension_as_value(x.shape[0])
x = np.arange(3.)
self.assertAllClose(9., jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x))
@ -937,53 +928,47 @@ class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):
res_iter = jnp.stack([f(xv[:, i, :]) for i in range(xv.shape[1])])
self.assertAllClose(res_iter, res_vmap)
res_mask2, _ = jax.mask(f, polymorphic_shapes=["(b,)"])([x], dict(b=2))
self.assertAllClose(2., res_mask2)
res_mask3, _ = jax.mask(f, polymorphic_shapes=["(b,)"])([x], dict(b=3))
self.assertAllClose(9., res_mask3)
def test_cond(self):
# Test the primitive under conditional
def f(x):
return lax.cond(
jnp.sum(x) > 0.,
lambda _: jnp.sum(x) / functools.reduce(lax.mul,
jax2tf.shape_as_value(x)),
lambda _: 0.,
operand=None)
x = np.ones((2, 3, 4))
self.assertAllClose(1., f(x))
self.assertAllClose(1.,
jax2tf.convert(f, polymorphic_shapes=["(a, b, 4)"])(x))
res_vmap_tf = jax2tf.convert(jax.vmap(f, in_axes=1),
polymorphic_shapes=["b1, b2, ..."])(xv)
self.assertAllClose(res_iter, res_vmap_tf.numpy())
def test_mean0(self):
def f_jax(x):
return jnp.sum(x, axis=0) / jax2tf.shape_as_value(x)[0]
return jnp.sum(x, axis=0) / core.dimension_as_value(x.shape[0])
x = np.arange(12.).reshape((3, 4))
f_tf = self.CheckShapePolymorphism(
f_jax,
input_signature=[tf.TensorSpec([None, 4], dtype=x.dtype)],
polymorphic_shapes=[("batch, _")],
polymorphic_shapes=[("b, _")],
expected_output_signature=tf.TensorSpec([4]))
self.assertAllClose(np.array([4., 5., 6., 7.]), f_tf(x))
def test_mean_all_axes(self):
def f_jax(x):
return jnp.sum(x) / np.prod(jax2tf.shape_as_value(x))
return jnp.sum(x) / core.dimension_as_value(np.prod(x.shape))
x = np.arange(12.).reshape((3, 4))
f_tf = self.CheckShapePolymorphism(
f_jax,
input_signature=[tf.TensorSpec([None, 4], dtype=x.dtype)],
polymorphic_shapes=[("batch, _")],
polymorphic_shapes=[("b, _")],
expected_output_signature=tf.TensorSpec([]))
self.assertAllClose(jnp.mean(x), f_tf(x))
def test_errors(self):
with self.assertRaisesRegex(
TypeError,
"Shapes must be 1D sequences of concrete values of integer type"):
core.dimension_as_value(np.array([1, 2], dtype=np.int32))
with self.assertRaisesRegex(
TypeError,
"Shapes must be 1D sequences of concrete values of integer type"):
core.dimension_as_value(np.float32(1))
###
### We define primitive harnesses for which we will test shape-polymorphic
@ -1033,6 +1018,7 @@ def _make_harness(group_name: str, name: str,
_f32 = np.float32
# List containing either harnesses, or lists of harnesses
_POLY_SHAPE_TEST_HARNESSES = [
_make_harness("jnp_add", "",
jnp.add,
@ -1079,12 +1065,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
[RandArg((3, 4, 4), _f32), RandArg((3, 4), _f32)],
poly_axes=[0, 0]),
_make_harness("dynamic_update_slice", "",
# x:shape: (b, 4)
lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
[RandArg((3, 4), _f32)],
poly_axes=[0]),
_make_harness("einsum", "0",
lambda x: jnp.einsum("...i->...", x),
[RandArg((3, 4), _f32)],
@ -1143,12 +1123,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
[RandArg((3, 4), _f32), RandArg((4, 5), _f32)],
poly_axes=[1, 0]),
# TODO(necula): not supported yet
# _make_harness("jnp_getitem", "",
# lambda a, i: a[i],
# [RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
# poly_axes=[0, 0]),
_make_harness("iota", "",
lambda x: x + lax.iota(_f32, x.shape[0]),
[RandArg((3,), _f32)],
@ -1166,6 +1140,64 @@ _POLY_SHAPE_TEST_HARNESSES = [
poly_axes=[0, None],
tol=1e-5),
[
_make_harness("jnp_mean",
f"axis={axis}_keepdims={keepdims}_where=None",
lambda x, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=None),
[RandArg((7, 8, 4), _f32), StaticArg(axis), StaticArg(keepdims)],
poly_axes=[0])
for keepdims in [False, True]
for axis in [None, (0,), (0, 1), (1,)]
],
[
_make_harness("jnp_mean",
f"axis={axis}_keepdims={keepdims}_where=Some",
lambda x, where, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=where),
[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_), StaticArg(axis), StaticArg(keepdims)],
poly_axes=[0, 0])
for keepdims in [False, True]
for axis in [None, (0,), (0, 1), (1,)]
],
[
_make_harness("jnp_average",
f"axis={axis}_weights=None",
lambda x, axis: jnp.average(x, axis=axis, returned=False, weights=None),
[RandArg((7, 8, 4), _f32), StaticArg(axis)],
poly_axes=[0])
for axis in [None, 0, 1]
],
[
_make_harness("jnp_average",
f"axis={axis}_weights=Some",
lambda x, weights, axis: jnp.average(x, axis=axis, returned=False, weights=weights),
[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), _f32), StaticArg(axis)],
poly_axes=[0, 0])
for axis in [None, 0, 1]
],
[
_make_harness("jnp_var",
f"axis={axis}_keepdims={keepdims}_where=None",
lambda x, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=None),
[RandArg((7, 8, 4), _f32), StaticArg(axis), StaticArg(keepdims)],
poly_axes=[0])
for keepdims in [False, True]
for axis in [None, (0,), (0, 1), (1,)]
],
[
_make_harness("jnp_var",
f"axis={axis}_keepdims={keepdims}_where=Some",
lambda x, where, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=where),
[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_), StaticArg(axis), StaticArg(keepdims)],
poly_axes=[0, 0])
for keepdims in [False, True]
for axis in [None, (0,), (0, 1), (1,)]
],
_make_harness("jnp_where", "",
jnp.where,
[RandArg((2,), np.bool_), RandArg((), _f32), RandArg((2,), _f32)],
@ -1182,7 +1214,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
[RandArg((3, 2, 4), _f32)],
poly_axes=[0]),
# TODO: random_gamma does not work yet.
_make_harness("random_gamma", "",
lambda key, a: jax.random.gamma(key, a),
[RandArg((3, 2), np.uint32), RandArg((3, 3), _f32)],
@ -1260,6 +1291,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
poly_axes=[0]),
]
for enable_xla in [False, True]:
_POLY_SHAPE_TEST_HARNESSES.extend([
# Reduce the poly dimension
@ -1276,22 +1308,77 @@ for enable_xla in [False, True]:
poly_axes=[0],
enable_xla=enable_xla),
_make_harness("dynamic_slice", f"enable_xla={enable_xla}",
_make_harness("dynamic_slice", f"idx=tuple_int_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
[RandArg((3, 4), _f32)],
poly_axes=[0],
enable_xla=enable_xla),
_make_harness("dynamic_slice", f"idx=tuple_arg_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x, i0: lax.dynamic_slice(x, (i0, np.int32(1)), (x.shape[0], 2)),
[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
poly_axes=[0, None],
enable_xla=enable_xla),
_make_harness("dynamic_slice", f"idx=array_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)),
[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
poly_axes=[0, None],
enable_xla=enable_xla),
_make_harness("dynamic_update_slice", f"idx=tuple_int_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
[RandArg((3, 4), _f32)],
poly_axes=[0],
enable_xla=enable_xla),
_make_harness("dynamic_update_slice", f"idx=tuple_arg_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x, i0: lax.dynamic_update_slice(x, x, (i0, np.int32(0))),
[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
poly_axes=[0, None],
enable_xla=enable_xla),
_make_harness("dynamic_update_slice", f"idx=array_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x, idx: lax.dynamic_update_slice(x, x, idx),
[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
poly_axes=[0, None],
enable_xla=enable_xla),
_make_harness("jnp_take", f"enable_xla={enable_xla}",
lambda a, i: jnp.take(a, i, axis=1),
[RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)],
poly_axes=[0, None], enable_xla=enable_xla),
_make_harness("jnp_getitem", f"enable_xla={enable_xla}",
# operand is non-poly, index is poly
_make_harness("jnp_getitem", f"op=static_idx=poly_enable_xla={enable_xla}",
lambda a, i: a[i],
[RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
poly_axes=[None, 0], enable_xla=enable_xla),
# operand is poly, index is integer
_make_harness("jnp_getitem", f"op=poly_idx=const_enable_xla={enable_xla}",
lambda a: a[1],
[RandArg((3, 4), _f32)],
poly_axes=[0], enable_xla=enable_xla),
# operand is poly, index is dim poly
_make_harness("jnp_getitem", f"op=poly_idx=dim_enable_xla={enable_xla}",
lambda a: a[jax.core.dimension_as_value(a.shape[0] - 2)],
[RandArg((3, 4), _f32)],
poly_axes=[0], enable_xla=enable_xla),
# Both the operand and the index are poly
_make_harness("jnp_getitem", f"op=poly_idx=poly_enable_xla={enable_xla}",
lambda a, i: a[i],
[RandArg((3, 4), _f32), np.array([1, 2, 0], np.int32)],
poly_axes=[0, 0], enable_xla=enable_xla),
])
for reduce_op in [jnp.all, jnp.any, jnp.max, jnp.min, jnp.prod, jnp.sum]:
@ -1343,9 +1430,6 @@ def _add_vmap_primitive_harnesses():
# We do not yet support shape polymorphism for vmap for some primitives
_NOT_SUPPORTED_YET = frozenset([
# In the random._gamma_impl we do reshape(-1, 2) for the keys
"random_gamma",
# In linalg._lu_python we do reshape(-1, ...)
"lu",
"custom_linear_solve",
@ -1409,6 +1493,14 @@ def _get_jax2tf_limitations(
_add_vmap_primitive_harnesses()
def _flatten_harnesses(harnesses):
res = []
for h in harnesses:
if isinstance(h, Sequence):
res.extend(h)
else:
res.append(h)
return res
class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
"""Tests for primitives that take shape values as parameters."""
@ -1418,9 +1510,11 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# For each primitive "xxx" the
# test will be called "test_prim_xxx_...".
# If you want to run this test for only one harness that includes "foo"
# in the name, add parameter `one_containing="foo"` to parameterized below.
@primitive_harness.parameterized(_POLY_SHAPE_TEST_HARNESSES,
one_containing="dynamic_slice_enable_xla=True_poly_axes=[0]")
# in the name (after test_prim), add parameter `one_containing="foo"`
# to parameterized below.
@primitive_harness.parameterized(_flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
#one_containing="dynamic_slice_idx=tuple_arg_enable_xla=False_poly_axes=[0, None]"
)
def test_prim(self, harness: Harness):
args = harness.dyn_args_maker(self.rng())
poly_axes = harness.params["poly_axes"] # type: Sequence[Sequence[int]]


@ -338,8 +338,8 @@ class JaxToTfTestCase(jtu.JaxTestCase):
"""
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
enable_xla=enable_xla)
f_tf = tf.function(f_tf, autograph=False, input_signature=input_signature)
concrete_f_tf = f_tf.get_concrete_function(*input_signature)
f_tf_func = tf.function(f_tf, autograph=False, input_signature=input_signature)
concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)
if expected_output_signature:
# Strangely, output_shapes can be a single shape for a function with a
# single result, or a list/tuple of shapes.