[jax2tf] Expand shape polymorphism support to use dimension polynomials as values.
The goal of this change is to support shape polymorphism for operations such as average (which needs to divide by the size of a dimension) or indexing (which needs to normalize indices by comparing them with 0 and adding the dimension size for negative indices). In both of these cases the size of a dimension needs to be used as a value in the array computation. In general, the size of a dimension is used only to customize primitives.

This change introduces `core.dimension_as_value`, which must be applied to a dimension size before using it as a value in the array computation. E.g.,

```
def average(x):
  return jnp.sum(x, axis=0) / core.dimension_as_value(x.shape[0])
```

This function is the identity for constant dimension sizes; otherwise it uses a new primitive, `shape_poly.dim_as_value_p`.

Note that this does not fundamentally change the flavor of shape polymorphism supported in jax2tf: intermediate shapes and their values may depend on the input shapes, but a shape never depends on the input values. In fact, `dimension_as_value` could already have been expressed:

```
def dimension_as_value(d):
  return jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```

We were able to support `jnp.mean`, `jnp.average`, `jnp.take`, `lax.dynamic_slice`, and `lax.dynamic_update_slice` by using `core.dimension_as_value` internally, but to fully roll out the solution we need to make `core.dimension_as_value` a public API and teach users how to use it when they want shape polymorphism. Alternatively, perhaps there is a way to automatically convert dimension polynomials to values when they are passed to the lax primitives.
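For illustration, a minimal end-to-end sketch assembled from the tests added in this change (the function name `average0` and the concrete input are illustrative only, not part of the change):

```python
# Sketch based on the test_mean0 test added below; `average0` and the sample
# input are illustrative only.
import numpy as np
import jax.numpy as jnp
from jax import core
from jax.experimental import jax2tf

def average0(x):
  # Divide by the size of axis 0, which may be a symbolic dimension under
  # jax2tf shape polymorphism.
  return jnp.sum(x, axis=0) / core.dimension_as_value(x.shape[0])

x = np.arange(12., dtype=np.float32).reshape((3, 4))
# "b, _" keeps the leading dimension symbolic while tracing for conversion.
f_tf = jax2tf.convert(average0, polymorphic_shapes=["b, _"])
print(f_tf(x))  # ~ [4., 5., 6., 7.]
```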
This commit is contained in:
parent 87f72ac446
commit b62ceba91c
CHANGELOG.md
@@ -11,15 +11,18 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.19 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.18...main).

* New features:
* Improved the support for shape polymorphism in jax2tf for operations that
need to use a dimension size in array computation, e.g., `jnp.mean`.
({jax-issue}`#7317`)

## jaxlib 0.1.70 (unreleased)
* Breaking changes:
* Support for Python 3.6 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported Python version.

* Breaking changes:
* The host_callback mechnism now uses one thread per local device for
* The host_callback mechanism now uses one thread per local device for
making the calls to the Python callbacks. Previously there was a single
thread for all devices. This means that the callbacks may now be called
interleaved. The callbacks corresponding to one device will still be
@@ -6763,7 +6763,8 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))

def _dynamic_slice_indices(operand, start_indices):
def _dynamic_slice_indices(operand, start_indices):
# Normalize the start_indices w.r.t. operand.shape
if len(start_indices) != operand.ndim:
msg = ("Length of slice indices must match number of operand dimensions ({} "
"vs {})")

@@ -6772,15 +6773,13 @@ def _dynamic_slice_indices(operand, start_indices):
if start_indices.ndim != 1:
raise ValueError("Slice indices must be a 1D sequence, got {}"
.format(start_indices.shape))
return select(lt(start_indices, _zeros(start_indices)),
add(start_indices, _const(start_indices, operand.shape)),
start_indices)
else:
return [np.asarray(i + d if i < 0 else i, getattr(i, 'dtype', dtypes.int_))
if isinstance(i, (int, np.integer))
else select(lt(i, _const(i, 0)), add(i, _const(i, d)), i)
for i, d in zip(start_indices, operand.shape)]
start_indices = [i for i in start_indices]
return [np.asarray(i + d if i < 0 else i, getattr(i, 'dtype', dtypes.int_))
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d)
else select(lt(i, _const(i, 0)),
add(i, convert_element_type(core.dimension_as_value(d), _dtype(i))),
i)
for i, d in zip(start_indices, operand.shape)]

def _const(example, val):
dtype = _dtype(example)
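For context, a minimal sketch of what this normalization change enables, mirroring the "dynamic_slice idx=tuple_arg" harness added in the tests further below (`take_rows` and the concrete input are illustrative only):

```python
# Sketch mirroring the dynamic_slice test harness in this commit; the function
# name and the sample input are illustrative only.
import numpy as np
from jax import lax
from jax.experimental import jax2tf

def take_rows(x, i0):
  # A negative start index is normalized by adding the dimension size; with a
  # polymorphic leading dimension that size is a dimension polynomial, which
  # _dynamic_slice_indices now converts via core.dimension_as_value.
  return lax.dynamic_slice(x, (i0, np.int32(1)), (x.shape[0], 2))

x = np.arange(12., dtype=np.float32).reshape((3, 4))
# The first argument has a polymorphic leading dimension; the scalar index is static.
f_tf = jax2tf.convert(take_rows, polymorphic_shapes=["b, _", None])
print(f_tf(x, np.int32(-2)))  # same values as take_rows(x, np.int32(-2))
```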
@@ -2162,9 +2162,9 @@ def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
if where is None:
if axis is None:
normalizer = size(a)
normalizer = core.dimension_as_value(size(a))
else:
normalizer = _axis_size(a, axis)
normalizer = core.dimension_as_value(_axis_size(a, axis))
else:
normalizer = sum(broadcast_to(where, shape(a)), axis, dtype=dtype, keepdims=keepdims)

@@ -2187,9 +2187,9 @@ def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
if weights is None: # Treat all weights as 1
avg = mean(a, axis=axis)
if axis is None:
weights_sum = full((), size(a), dtype=avg.dtype)
weights_sum = full((), core.dimension_as_value(size(a)), dtype=avg.dtype)
else:
weights_sum = full_like(avg, a.shape[axis], dtype=avg.dtype)
weights_sum = full_like(avg, core.dimension_as_value(a.shape[axis]), dtype=avg.dtype)
else:
weights = asarray(weights)

@@ -2212,7 +2212,7 @@ def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
if len(weights_shape) != 1:
raise ValueError("1D weights expected when shapes of a and "
"weights differ.")
if weights_shape[0] != a_shape[axis]:
if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
raise ValueError("Length of weights not "
"compatible with specified axis.")

@@ -2247,9 +2247,9 @@ def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
if where is None:
if axis is None:
normalizer = size(a)
normalizer = core.dimension_as_value(size(a))
else:
normalizer = _axis_size(a, axis)
normalizer = core.dimension_as_value(_axis_size(a, axis))
else:
normalizer = sum(broadcast_to(where, shape(a)), axis, dtype=dtype, keepdims=keepdims)
normalizer = normalizer - ddof
@@ -4782,9 +4782,14 @@ def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
def _normalize_index(index, axis_size):
"""Normalizes an index value in the range [-N, N) to the range [0, N)."""
if core.is_constant_dim(axis_size):
axis_size_val = _constant_like(index, axis_size)
else:
axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size),
_dtype(index))
return lax.select(
lax.lt(index, _constant_like(index, 0)),
lax.add(index, _constant_like(index, axis_size)),
lax.add(index, axis_size_val),
index)

@partial(jit, static_argnums=(2,))

@@ -5199,7 +5204,7 @@ def _index_to_gather(x_shape, idx, normalize_indices=True):
abstract_i = None
# Handle basic int indexes.
if isinstance(abstract_i, (ConcreteArray,ShapedArray)) and _int(abstract_i):
if x_shape[x_axis] == 0:
if core.symbolic_equal_dim(x_shape[x_axis], 0):
# XLA gives error when indexing into an axis of size 0
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i
@@ -1283,6 +1283,9 @@ class DimensionHandler:
"""Implements `0 if d == 0 else 1 + dilation * (d - 1))`"""
return 0 if d == 0 else 1 + dilation * (d - 1)

def as_value(self, d: DimSize):
"""Turns a dimension size into a JAX value that we can compute with."""
return d

_dimension_handler_int = DimensionHandler()
_SPECIAL_DIMENSION_HANDLERS: Dict[type, DimensionHandler] = {}

@@ -1378,6 +1381,11 @@ def stride_shape(s: Shape, window_size: Shape, window_stride: Shape) -> Shape:
"""(s - window_size) // window_stride + 1"""
return tuple(map(stride_dim, s, window_size, window_stride))

def dimension_as_value(d: DimSize):
"""Turns a dimension size into a JAX value that we can compute with.
This is the identity function for constant dimensions."""
handler, ds = _dim_handler_and_canonical(d)
return handler.as_value(*ds)

def _canonicalize_dimension(dim: DimSize) -> DimSize:
if type(dim) in _SPECIAL_DIMENSION_HANDLERS:
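As a behavioral note, a simplified sketch of the dispatch `dimension_as_value` performs, based on the handler registry shown above (the helper name `_dimension_as_value_simplified` is illustrative; the real function goes through `_dim_handler_and_canonical`):

```python
# Simplified sketch of the dispatch above; not the actual implementation.
from jax import core

def _dimension_as_value_simplified(d):
  # Constant (int) dimensions use the default handler, whose as_value is the
  # identity; polymorphic dimensions use their registered special handler,
  # which binds shape_poly.dim_as_value_p.
  handler = core._SPECIAL_DIMENSION_HANDLERS.get(type(d), core._dimension_handler_int)
  return handler.as_value(d)

assert _dimension_as_value_simplified(7) == 7  # identity on constant dimensions
```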
@@ -13,5 +13,5 @@
# limitations under the License.

# flake8: noqa: F401
from .jax2tf import convert, dtype_of_val, shape_as_value, split_to_logical_devices, PolyShape
from .jax2tf import convert, dtype_of_val, split_to_logical_devices, PolyShape
from .call_tf import call_tf
@@ -648,91 +648,6 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
return shape_poly.eval_shape(shape, _thread_local_state.shape_env)

def shape_as_value(x):
"""Injects the shape of `x` as an array value.

**Experimental: please give feedback, and expect changes!**

This allows the use of a shape expression as array argument to JAX functions.
A typical example is for implementing a mean operation:

jnp.sum(x) / np.prod(jax2tf.shape_as_value(x))
"""
# return shape_as_value_p.bind(x)
return NotImplementedError("shape_as_value is deprecated")

# # TODO: move this to masking or to some common library, if approved
# shape_as_value_p = core.Primitive("shape_as_value")
# shape_as_value_p.multiple_results = True
# def _shape_as_value_impl(x):
# x_shape = np.shape(x)
# def dim_to_int(dim: shape_poly.DimSize) -> int:
# dim_int = _poly_dim_to_tf_dim(dim)
# if dim_int is None:
# msg = ("shape_as_value is not implemented for non-constant shapes "
# "except for masking and jax2tf. "
# f"Has shape: {x_shape}")
# raise TypeError(msg)
# else:
# return dim_int
# return tuple(map(dim_to_int, x_shape))
#
# shape_as_value_p.def_impl(_shape_as_value_impl)
#
# def _shape_as_value_abstract(x_aval: core.AbstractValue) -> Sequence[core.AbstractValue]:
# rank = len(x_aval.shape) # type: ignore[attr-defined]
# return (core.ShapedArray((), dtypes.canonicalize_dtype(np.int_), weak_type=True),) * rank
#
# shape_as_value_p.def_abstract_eval(_shape_as_value_abstract)
#
# def _shape_as_value_translation(comp, x):
# return xla_client._xla.ops.Tuple(comp,
# tuple(xb.constant(comp, d)
# for d in comp.GetShape(x).dimensions()))
#
# xla.translations[shape_as_value_p] = _shape_as_value_translation
#
# def _shape_as_value_jvp_rule(primals, tangents):
# # The shape does not depend on the contents of the input
# x, = primals
# zero = ad.Zero.from_value(0.)
# return shape_as_value(x), (zero,) * len(x.shape)
#
# ad.primitive_jvps[shape_as_value_p] = _shape_as_value_jvp_rule
#
# def _shape_as_value__batching_rule(batched_args, batch_dims):
# xv, = batched_args
# batch_dim, = batch_dims
# batch_size = xv.shape[batch_dim]
# batched_shape = shape_as_value(xv)
# one_shape = batched_shape[0:batch_dim] + batched_shape[batch_dim+1:]
# res = tuple(jnp.broadcast_to(d, (batch_size, 1)) for d in one_shape)
# return res, (0,) * len(one_shape)
#
# batching.primitive_batchers[shape_as_value_p] = _shape_as_value__batching_rule
#
# def _shape_as_value_masking_rule(operands, operands_logical_shapes):
# x_logical_shape, = operands_logical_shapes
# return tuple(x_logical_shape)
#
# masking.masking_rules[shape_as_value_p] = _shape_as_value_masking_rule
#
# def _shape_as_value_tf(x: TfVal,
# _in_avals: Sequence[core.AbstractValue],
# _out_aval: core.AbstractValue) -> TfVal:
# x_aval = _in_avals[0]
# def dim_to_tfval(dim: shape_poly.DimSize, dim_idx: int) -> TfVal:
# dim_int = _poly_dim_to_tf_dim(dim)
# if dim_int is not None:
# return tf.convert_to_tensor(dim_int)
# else:
# return tf.shape(x)[dim_idx]
# return tuple(dim_to_tfval(dim, dim_idx)
# for dim_idx, dim in enumerate(x_aval.shape)) # type: ignore[attr-defined]
#
# tf_impl_with_avals[shape_as_value_p] = _shape_as_value_tf

# TODO(b/26854495): pylint doesn't understand slots and inheritance.
# pylint: disable=assigning-non-slot

@@ -775,8 +690,7 @@ class TensorFlowTracer(core.Tracer):
else:
assert self._aval.dtype == _to_jax_dtype(val.dtype), f"expected {self._aval.dtype} == {val.dtype}"

for aval_dim, val_dim in zip(
self._aval.shape, val_shape): # type: ignore[attr-defined]
for aval_dim, val_dim in zip(self._aval.shape, val_shape): # type: ignore[attr-defined]
if val_dim is None:
assert shape_poly.is_poly_dim(aval_dim), f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
elif not shape_poly.is_poly_dim(aval_dim):

@@ -2481,8 +2395,8 @@ def _slice(operand, start_indices, limit_indices, strides, _in_avals,
_eval_shape(strides)))
out = operand[slices]
# TODO(b/184503314): improve shape inference for __getitem__
#out.set_shape(_aval_to_tf_shape(_out_aval))
#assert False, f"start_indices={start_indices}, limit_indices={limit_indices}, strides={strides}, out={out}"
# E.g., operand.shape=(b, 5, 3), start_indices=(0, 1, 1), limit_indices=(b, 5, 3), strides=(1, 2, 1)
out.set_shape(_aval_to_tf_shape(_out_aval))
return out

@@ -3075,3 +2989,5 @@ def _register_checkpoint_pytrees():
_register_checkpoint_pytrees()

shape_poly._register_conversion_rules()
@@ -30,6 +30,8 @@ from jax._src.numpy import lax_numpy
import opt_einsum
from jax import config
from jax import core
from . import jax2tf as jax2tf_internal

import numpy as np
import tensorflow as tf # type: ignore[import]

@@ -419,6 +421,10 @@ class DimensionHandlerPoly(core.DimensionHandler):
f"window_size '{window_size}', stride '{window_stride}'. Reason: {e}.")
return d

def as_value(self, d: DimSize):
"""Turns a dimension size into a Jax value that we can compute with."""
return _dim_as_value(d)

core._SPECIAL_DIMENSION_HANDLERS[_DimPolynomial] = DimensionHandlerPoly()

def _einsum_contract_path(*operands, **kwargs):

@@ -468,6 +474,27 @@ def _einsum_contract_path(*operands, **kwargs):
lax_numpy._polymorphic_einsum_contract_path_handlers[_DimPolynomial] = _einsum_contract_path

# A JAX primitive with no array arguments but with a dimension parameter
# that is a DimPoly. The value of the primitive is the value of the dimension.
# This primitive is used only in the context of jax2tf, so it does not need
# XLA translation rules.
dim_as_value_p = core.Primitive("dim_as_value")
def _dim_as_value_abstract(dim: DimSize) -> core.AbstractValue:
return core.ShapedArray((), np.int32)

dim_as_value_p.def_abstract_eval(_dim_as_value_abstract)

def _dim_as_value(dim: DimSize):
return dim_as_value_p.bind(dim=dim)

def _dim_as_value_jax2tf(dim: DimSize):
dim_tf, = jax2tf_internal._eval_shape((dim,))
assert dim_tf.dtype == tf.int32
return dim_tf

def _register_conversion_rules():
jax2tf_internal.tf_impl[dim_as_value_p] = _dim_as_value_jax2tf

class PolyShape(tuple):
"""Tuple of polymorphic dimension specifications.
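For readers unfamiliar with this pattern, here is a self-contained sketch of a primitive that, like `dim_as_value_p` above, takes no array arguments and only a static `dim` parameter (`demo_dim_p` is a hypothetical stand-in, not part of this change):

```python
# Hypothetical stand-in for shape_poly.dim_as_value_p, showing the same
# zero-operand, parameter-only primitive pattern.
import numpy as np
from jax import core

demo_dim_p = core.Primitive("demo_dim_as_value")

def _demo_abstract(*, dim):
  # Matches _dim_as_value_abstract above: the result is a scalar int32.
  return core.ShapedArray((), np.int32)

demo_dim_p.def_abstract_eval(_demo_abstract)
# Outside of jax2tf a concrete dimension can simply be evaluated eagerly.
demo_dim_p.def_impl(lambda *, dim: np.int32(dim))

print(demo_dim_p.bind(dim=5))  # 5
```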
@@ -21,6 +21,7 @@ import functools
from functools import partial
import operator
import re
import unittest

import jax
from jax import core

@@ -35,7 +36,6 @@ import numpy as np
from jax.experimental.jax2tf.tests import tf_test_util

import tensorflow as tf # type: ignore[import]
import unittest

from jax.config import config

@@ -643,9 +643,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
tf.TensorSpec([3, 4, 8, 9]))

self.assertEqual((3, 4, 8, 8), tuple(tf_grad.output_shapes[0]))
# TODO(necula): Understand why we get partial shapes for the gradient, even
# though all inputs have known shapes.
self.assertEqual((3, 4, 8, None), tuple(tf_grad.output_shapes[1]))
self.assertEqual((3, 4, 8, 9), tuple(tf_grad.output_shapes[1]))

def test_gradients_pytree(self):
"""Shape polymorphism with gradients and pytrees for inputs and outputs."""

@@ -878,15 +876,12 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
self.assertEqual(1, f_tf(x45))

class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):

def setUp(self):
raise unittest.SkipTest("shape_as_value not supported anymore. See #6080.")
class DimAsValueTest(tf_test_util.JaxToTfTestCase):

def test_concrete_shapes(self):
# Test shape_as_value with concrete shapes. All transformations work.
# Test dim_as_value with concrete shapes.
def f(x):
return jnp.sum(x, axis=0) * jax2tf.shape_as_value(x)[0]
return jnp.sum(x, axis=0) * core.dimension_as_value(x.shape[0])

x = np.arange(3.)
self.assertAllClose(9., f(x))

@@ -903,15 +898,11 @@ class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):
res_iter = jnp.stack([f(xv[:, i, :]) for i in range(xv.shape[1])])
self.assertAllClose(res_iter, res_vmap)

res_mask2, _ = jax.mask(f, polymorphic_shapes=["(b,)"])([x], dict(b=2))
self.assertAllClose(2., res_mask2)
res_mask3, _ = jax.mask(f, polymorphic_shapes=["(b,)"])([x], dict(b=3))
self.assertAllClose(9., res_mask3)

def test_dynamic_shapes(self):
# Test shape_as_value with dynamic shapes. All transformations work.
# Test dim_as_value with dynamic shapes.
def f(x):
return jnp.sum(x, axis=0) * jax2tf.shape_as_value(x)[0]
return jnp.sum(x, axis=0) * core.dimension_as_value(x.shape[0])

x = np.arange(3.)
self.assertAllClose(9., jax2tf.convert(f, polymorphic_shapes=["(b,)"])(x))

@@ -937,53 +928,47 @@ class ShapeAsValueTest(tf_test_util.JaxToTfTestCase):
res_iter = jnp.stack([f(xv[:, i, :]) for i in range(xv.shape[1])])
self.assertAllClose(res_iter, res_vmap)

res_mask2, _ = jax.mask(f, polymorphic_shapes=["(b,)"])([x], dict(b=2))
self.assertAllClose(2., res_mask2)
res_mask3, _ = jax.mask(f, polymorphic_shapes=["(b,)"])([x], dict(b=3))
self.assertAllClose(9., res_mask3)

def test_cond(self):
# Test the primitive under conditional
def f(x):
return lax.cond(
jnp.sum(x) > 0.,
lambda _: jnp.sum(x) / functools.reduce(lax.mul,
jax2tf.shape_as_value(x)),
lambda _: 0.,
operand=None)

x = np.ones((2, 3, 4))
self.assertAllClose(1., f(x))
self.assertAllClose(1.,
jax2tf.convert(f, polymorphic_shapes=["(a, b, 4)"])(x))
res_vmap_tf = jax2tf.convert(jax.vmap(f, in_axes=1),
polymorphic_shapes=["b1, b2, ..."])(xv)
self.assertAllClose(res_iter, res_vmap_tf.numpy())

def test_mean0(self):

def f_jax(x):
return jnp.sum(x, axis=0) / jax2tf.shape_as_value(x)[0]
return jnp.sum(x, axis=0) / core.dimension_as_value(x.shape[0])

x = np.arange(12.).reshape((3, 4))
f_tf = self.CheckShapePolymorphism(
f_jax,
input_signature=[tf.TensorSpec([None, 4], dtype=x.dtype)],
polymorphic_shapes=[("batch, _")],
polymorphic_shapes=[("b, _")],
expected_output_signature=tf.TensorSpec([4]))
self.assertAllClose(np.array([4., 5., 6., 7.]), f_tf(x))

def test_mean_all_axes(self):

def f_jax(x):
return jnp.sum(x) / np.prod(jax2tf.shape_as_value(x))
return jnp.sum(x) / core.dimension_as_value(np.prod(x.shape))

x = np.arange(12.).reshape((3, 4))
f_tf = self.CheckShapePolymorphism(
f_jax,
input_signature=[tf.TensorSpec([None, 4], dtype=x.dtype)],
polymorphic_shapes=[("batch, _")],
polymorphic_shapes=[("b, _")],
expected_output_signature=tf.TensorSpec([]))

self.assertAllClose(jnp.mean(x), f_tf(x))

def test_errors(self):

with self.assertRaisesRegex(
TypeError,
"Shapes must be 1D sequences of concrete values of integer type"):
core.dimension_as_value(np.array([1, 2], dtype=np.int32))
with self.assertRaisesRegex(
TypeError,
"Shapes must be 1D sequences of concrete values of integer type"):
core.dimension_as_value(np.float32(1))

###
### We define primitive harnesses for which we will test shape-polymorphic
@@ -1033,6 +1018,7 @@ def _make_harness(group_name: str, name: str,
_f32 = np.float32

# List containing either harnesses, or lists of harnesses
_POLY_SHAPE_TEST_HARNESSES = [
_make_harness("jnp_add", "",
jnp.add,

@@ -1079,12 +1065,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
[RandArg((3, 4, 4), _f32), RandArg((3, 4), _f32)],
poly_axes=[0, 0]),

_make_harness("dynamic_update_slice", "",
# x:shape: (b, 4)
lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
[RandArg((3, 4), _f32)],
poly_axes=[0]),

_make_harness("einsum", "0",
lambda x: jnp.einsum("...i->...", x),
[RandArg((3, 4), _f32)],

@@ -1143,12 +1123,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
[RandArg((3, 4), _f32), RandArg((4, 5), _f32)],
poly_axes=[1, 0]),

# TODO(necula): not supported yet
# _make_harness("jnp_getitem", "",
# lambda a, i: a[i],
# [RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
# poly_axes=[0, 0]),

_make_harness("iota", "",
lambda x: x + lax.iota(_f32, x.shape[0]),
[RandArg((3,), _f32)],

@@ -1166,6 +1140,64 @@ _POLY_SHAPE_TEST_HARNESSES = [
poly_axes=[0, None],
tol=1e-5),

[
_make_harness("jnp_mean",
f"axis={axis}_keepdims={keepdims}_where=None",
lambda x, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=None),
[RandArg((7, 8, 4), _f32), StaticArg(axis), StaticArg(keepdims)],
poly_axes=[0])
for keepdims in [False, True]
for axis in [None, (0,), (0, 1), (1,)]
],

[
_make_harness("jnp_mean",
f"axis={axis}_keepdims={keepdims}_where=Some",
lambda x, where, axis, keepdims: jnp.mean(x, axis=axis, keepdims=keepdims, where=where),
[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_), StaticArg(axis), StaticArg(keepdims)],
poly_axes=[0, 0])
for keepdims in [False, True]
for axis in [None, (0,), (0, 1), (1,)]
],

[
_make_harness("jnp_average",
f"axis={axis}_weights=None",
lambda x, axis: jnp.average(x, axis=axis, returned=False, weights=None),
[RandArg((7, 8, 4), _f32), StaticArg(axis)],
poly_axes=[0])
for axis in [None, 0, 1]
],

[
_make_harness("jnp_average",
f"axis={axis}_weights=Some",
lambda x, weights, axis: jnp.average(x, axis=axis, returned=False, weights=weights),
[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), _f32), StaticArg(axis)],
poly_axes=[0, 0])
for axis in [None, 0, 1]
],

[
_make_harness("jnp_var",
f"axis={axis}_keepdims={keepdims}_where=None",
lambda x, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=None),
[RandArg((7, 8, 4), _f32), StaticArg(axis), StaticArg(keepdims)],
poly_axes=[0])
for keepdims in [False, True]
for axis in [None, (0,), (0, 1), (1,)]
],

[
_make_harness("jnp_var",
f"axis={axis}_keepdims={keepdims}_where=Some",
lambda x, where, axis, keepdims: jnp.var(x, axis=axis, keepdims=keepdims, where=where),
[RandArg((7, 8, 4), _f32), RandArg((7, 8, 4), np.bool_), StaticArg(axis), StaticArg(keepdims)],
poly_axes=[0, 0])
for keepdims in [False, True]
for axis in [None, (0,), (0, 1), (1,)]
],

_make_harness("jnp_where", "",
jnp.where,
[RandArg((2,), np.bool_), RandArg((), _f32), RandArg((2,), _f32)],

@@ -1182,7 +1214,6 @@ _POLY_SHAPE_TEST_HARNESSES = [
[RandArg((3, 2, 4), _f32)],
poly_axes=[0]),

# TODO: random_gamma does not work yet.
_make_harness("random_gamma", "",
lambda key, a: jax.random.gamma(key, a),
[RandArg((3, 2), np.uint32), RandArg((3, 3), _f32)],

@@ -1260,6 +1291,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
poly_axes=[0]),
]

for enable_xla in [False, True]:
_POLY_SHAPE_TEST_HARNESSES.extend([
# Reduce the poly dimension
@@ -1276,22 +1308,77 @@ for enable_xla in [False, True]:
poly_axes=[0],
enable_xla=enable_xla),

_make_harness("dynamic_slice", f"enable_xla={enable_xla}",
_make_harness("dynamic_slice", f"idx=tuple_int_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
[RandArg((3, 4), _f32)],
poly_axes=[0],
enable_xla=enable_xla),

_make_harness("dynamic_slice", f"idx=tuple_arg_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x, i0: lax.dynamic_slice(x, (i0, np.int32(1)), (x.shape[0], 2)),
[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
poly_axes=[0, None],
enable_xla=enable_xla),

_make_harness("dynamic_slice", f"idx=array_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)),
[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
poly_axes=[0, None],
enable_xla=enable_xla),

_make_harness("dynamic_update_slice", f"idx=tuple_int_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
[RandArg((3, 4), _f32)],
poly_axes=[0],
enable_xla=enable_xla),

_make_harness("dynamic_update_slice", f"idx=tuple_arg_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x, i0: lax.dynamic_update_slice(x, x, (i0, np.int32(0))),
[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
poly_axes=[0, None],
enable_xla=enable_xla),

_make_harness("dynamic_update_slice", f"idx=array_enable_xla={enable_xla}",
# x:shape: (b, 4)
lambda x, idx: lax.dynamic_update_slice(x, x, idx),
[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
poly_axes=[0, None],
enable_xla=enable_xla),

_make_harness("jnp_take", f"enable_xla={enable_xla}",
lambda a, i: jnp.take(a, i, axis=1),
[RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)],
poly_axes=[0, None], enable_xla=enable_xla),

_make_harness("jnp_getitem", f"enable_xla={enable_xla}",
# operand is non-poly, index is poly
_make_harness("jnp_getitem", f"op=static_idx=poly_enable_xla={enable_xla}",
lambda a, i: a[i],
[RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
poly_axes=[None, 0], enable_xla=enable_xla),

# operand is poly, index is integer
_make_harness("jnp_getitem", f"op=poly_idx=const_enable_xla={enable_xla}",
lambda a: a[1],
[RandArg((3, 4), _f32)],
poly_axes=[0], enable_xla=enable_xla),

# operand is poly, index is dim poly
_make_harness("jnp_getitem", f"op=poly_idx=dim_enable_xla={enable_xla}",
lambda a: a[jax.core.dimension_as_value(a.shape[0] - 2)],
[RandArg((3, 4), _f32)],
poly_axes=[0], enable_xla=enable_xla),

# Both the operand and the index are poly
_make_harness("jnp_getitem", f"op=poly_idx=poly_enable_xla={enable_xla}",
lambda a, i: a[i],
[RandArg((3, 4), _f32), np.array([1, 2, 0], np.int32)],
poly_axes=[0, 0], enable_xla=enable_xla),

])

for reduce_op in [jnp.all, jnp.any, jnp.max, jnp.min, jnp.prod, jnp.sum]:

@@ -1343,9 +1430,6 @@ def _add_vmap_primitive_harnesses():
# We do not yet support shape polymorphism for vmap for some primitives
_NOT_SUPPORTED_YET = frozenset([
# In the random._gamma_impl we do reshape(-1, 2) for the keys
"random_gamma",

# In linalg._lu_python we do reshape(-1, ...)
"lu",
"custom_linear_solve",

@@ -1409,6 +1493,14 @@ def _get_jax2tf_limitations(
_add_vmap_primitive_harnesses()

def _flatten_harnesses(harnesses):
res = []
for h in harnesses:
if isinstance(h, Sequence):
res.extend(h)
else:
res.append(h)
return res

class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
"""Tests for primitives that take shape values as parameters."""

@@ -1418,9 +1510,11 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# For each primitive "xxx" the
# test will be called "test_prim_xxx_...".
# If you want to run this test for only one harness that includes "foo"
# in the name, add parameter `one_containing="foo"` to parameterized below.
@primitive_harness.parameterized(_POLY_SHAPE_TEST_HARNESSES,
one_containing="dynamic_slice_enable_xla=True_poly_axes=[0]")
# in the name (after test_prim), add parameter `one_containing="foo"`
# to parameterized below.
@primitive_harness.parameterized(_flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
#one_containing="dynamic_slice_idx=tuple_arg_enable_xla=False_poly_axes=[0, None]"
)
def test_prim(self, harness: Harness):
args = harness.dyn_args_maker(self.rng())
poly_axes = harness.params["poly_axes"] # type: Sequence[Sequence[int]]
@@ -338,8 +338,8 @@ class JaxToTfTestCase(jtu.JaxTestCase):
"""
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
enable_xla=enable_xla)
f_tf = tf.function(f_tf, autograph=False, input_signature=input_signature)
concrete_f_tf = f_tf.get_concrete_function(*input_signature)
f_tf_func = tf.function(f_tf, autograph=False, input_signature=input_signature)
concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)
if expected_output_signature:
# Strangely, output_shapes can be a single shape for a function with a
# single result, or a list/tuple of shapes.