mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] lax.reduce_window (enable_xla=False): bug fix and improvements.
* Fixes https://github.com/google/jax/issues/11804: we only supported `lax.reduce_window` without batch and channel dimensions, which is wrong. This is supported, and in fact something that most users use (this case is actually not explained in the [operational semantics for XLA::ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)). I have fixed this and clarified a number of test cases with batch and channel dimensions. * Also, @sdenton4 gave a failing example in a Colab using polymorphic dimensions. I've added this as a test case to make sure it works now. * Adds support for explicit padding using the existing padding logic from convolutions. * Fixes https://github.com/google/jax/issues/11874: we were not handling SAME padding for `lax.add` correctly, since we used `tf.nn.avg_pool`, which does not include non-padding tokens (see issue for more details). I resolved it by adding manual padding and added some additional tests for this. * Ensures we call eval_shape on a shape containing polynomials before calling a TF op. * Fixes https://github.com/google/jax/issues/11929#issuecomment-1216261697: we weren't running any of the shape_poly_test.py tests for `enable_xla=False`. PiperOrigin-RevId: 467879449
This commit is contained in:
parent
da168a100a
commit
df5f3c556c
@ -150,20 +150,29 @@ function `lax.reduce_window_p` with the following conditions:
|
||||
|
||||
We provide partial support for all these ops, with the following limitations:
|
||||
|
||||
* `computation` should be one of `lax.min`, `lax.max`, or `lax.add`.
|
||||
* For `lax.min` and `lax.max`, dtypes `np.bool`, `np.uint32`, `np.uint64`,
|
||||
`np.complex64`, and `np.complex128` are not supported.
|
||||
* Additionally, for `lax.min`, dtypes `np.uint8` and `np.uint16` are not
|
||||
supported.
|
||||
* For `lax.add`, only dtypes `np.float16`, `np.float32`, and `np.float64` are
|
||||
supported.
|
||||
* We support at most 2 spatial dimension.
|
||||
* Base dilations other than `(1,) * len(operand)` are not supported.
|
||||
* `padding` should either be `VALID` or `SAME`.
|
||||
* Using `lax.add` on TPU may give very large deviations. This is due to the way
|
||||
the conversion is implemented (first take the average over the window and then
|
||||
multiply by window size). This gives large deviations on TPU due to the fact
|
||||
that it uses `bfloat16` for computations.
|
||||
* `computation` should be one of `lax.min`, `lax.max`, or `lax.add`.
|
||||
* For `lax.min` and `lax.max`, dtypes `np.bool`, `np.uint32`, `np.uint64`,
|
||||
`np.complex64`, and `np.complex128` are not supported.
|
||||
* Additionally, for `lax.min`, dtypes `np.uint8` and `np.uint16` are not
|
||||
supported.
|
||||
* For `lax.add`, only dtypes `np.float16`, `np.float32`, and `np.float64` are
|
||||
supported.
|
||||
* We support at most 2 spatial dimension.
|
||||
* Base dilations other than `(1,) * len(operand)` are not supported.
|
||||
* `padding` should either be `VALID` or `SAME`.
|
||||
* We compute `lax.reduce_window_sum_p` by calling `tf.nn.avg_pool` (through
|
||||
`tf.nn.pool`), and then multiplying the result by
|
||||
`np.prod(window_dimensions)`. If you are using an NN library that implements
|
||||
`avg_pool` using `lax.reduce_window` (such as Flax's
|
||||
[pooling.py](https://github.com/google/flax/blob/main/flax/linen/pooling.py)),
|
||||
this is usually implemented by dividing the result with
|
||||
`np.prod(window_dimensions)`. So when converting this function, the
|
||||
resulting computation for `avg_pool` is `(tf.nn.avg_pool(xs) *
|
||||
np.prod(window)) / np.prod(window)`. This is redundant and can be optimized.
|
||||
* Using `lax.add` on TPU may give very large deviations. This is due to the
|
||||
way the conversion is implemented (first take the average over the window
|
||||
and then multiply by window size). This gives large deviations on TPU due to
|
||||
the fact that it uses `bfloat16` for computations.
|
||||
|
||||
We implement all reductions using the Tensorflow function
|
||||
[tf.nn.pool](https://www.tensorflow.org/api_docs/python/tf/nn/pool).
|
||||
|
@ -120,10 +120,11 @@ def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
|
||||
|
||||
def _pad_spatial_dims(x, x_shape, padding):
|
||||
"""Pads `x` using `padding`, which specifies padding for the spatial dimensions."""
|
||||
# Add empty padding for batch and feature dimensions.
|
||||
no_pad = ((0, 0),)
|
||||
padding = tuple(padding)
|
||||
padding = no_pad + padding + no_pad
|
||||
if len(padding) == len(x_shape) - 2:
|
||||
# If necessary, add empty padding for batch and feature dimensions.
|
||||
no_pad = ((0, 0),)
|
||||
padding = no_pad + padding + no_pad
|
||||
x = tf.pad(x, padding)
|
||||
assert len(x.shape) == len(padding)
|
||||
x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding))
|
||||
@ -517,11 +518,9 @@ tf_impl_no_xla[lax.argmin_p] = partial(_argminmax, True)
|
||||
tf_impl_no_xla[lax.argmax_p] = partial(_argminmax, False)
|
||||
|
||||
|
||||
def _reduce_monoid(operand, window_dimensions, window_strides, padding,
|
||||
base_dilation, window_dilation, computation_name,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: core.ShapedArray):
|
||||
dtype = operand.dtype
|
||||
def _validate_reduce_window_inputs(operand_shape, computation_name, dtype,
|
||||
window_dimensions, window_strides,
|
||||
base_dilation, window_dilation):
|
||||
if computation_name not in ["min", "max", "add"]:
|
||||
raise _reduce_error("Reduction function should be either min, max, or add.")
|
||||
if computation_name in ["min", "max"] and dtype in [
|
||||
@ -540,35 +539,117 @@ def _reduce_monoid(operand, window_dimensions, window_strides, padding,
|
||||
raise _reduce_error("Add pooling does not support operands of type "
|
||||
f"{dtype}")
|
||||
|
||||
# In presence of shape polymorphism, operand.shape may contain None. The
|
||||
# actual dimension polynomial shapes are in _in_avals.
|
||||
operand_shape = _in_avals[0].shape
|
||||
if (len(operand_shape) != len(window_dimensions) != len(window_strides) !=
|
||||
len(window_dilation)):
|
||||
raise _reduce_error("Input shapes, window dimensions, window stride "
|
||||
"dimensions, and window dilation dimensions should "
|
||||
"match.")
|
||||
|
||||
has_only_spatial_dims = True
|
||||
if len(operand_shape) > 4:
|
||||
raise _reduce_error("Only 1D or 2D input are supported.")
|
||||
if len(operand_shape) > 2:
|
||||
# operand_shape = (batch, spatial_dims, ..., channel).
|
||||
has_only_spatial_dims = False
|
||||
|
||||
for name, value in [("window_dimensions", window_dimensions),
|
||||
("window_strides", window_strides),
|
||||
("window_dilation", window_dilation)]:
|
||||
if value[0] != value[-1] != 1:
|
||||
raise _reduce_error("Only 1D or 2D input are supported, expected "
|
||||
f"{name}=(1, spatial_dims, ..., 1), but got "
|
||||
f"{value}")
|
||||
|
||||
if list(base_dilation) != [1] * len(operand_shape):
|
||||
# TODO(marcvanzee): Add support for base dilations. We can do this using
|
||||
# a scatter on operand.
|
||||
raise _reduce_error("Unimplemented support for base dilation.")
|
||||
|
||||
return has_only_spatial_dims
|
||||
|
||||
|
||||
def _padding_reduce_window(operand, operand_shape, computation_name,
|
||||
window_dimensions, window_strides, padding):
|
||||
padding_type = pads_to_padtype(operand_shape, window_dimensions,
|
||||
window_strides, padding)
|
||||
if padding_type == "EXPLICIT":
|
||||
# TODO(marcvanzee): Add support for explicit padding. This can be done
|
||||
# similarly like we did for convolutions.
|
||||
raise _reduce_error("Only 'VALID' and 'SAME' padding are currently "
|
||||
"supported.")
|
||||
|
||||
def tf_pool(op, pooling_type):
|
||||
# Add batch and channel dimensions, these are expected by TF.
|
||||
op = tf.reshape(op, (1,) + operand_shape + (1,))
|
||||
op = tf.nn.pool(
|
||||
input=op,
|
||||
# https://github.com/google/jax/issues/11874.
|
||||
needs_manual_padding = (
|
||||
padding_type == "SAME" and computation_name == "add" and
|
||||
window_dimensions != [1] * len(operand_shape))
|
||||
|
||||
if needs_manual_padding or padding_type == "EXPLICIT":
|
||||
operand, operand_shape = _pad_spatial_dims(operand, operand_shape, padding)
|
||||
padding_type = "VALID"
|
||||
|
||||
return operand, operand_shape, padding_type
|
||||
|
||||
|
||||
def _reshape_reduce_window(operand, operand_shape, window_dimensions,
|
||||
window_strides, window_dilation, *,
|
||||
has_only_spatial_dims):
|
||||
# Reshape inputs so they are accepted by tf.nn.pool, which expects batch and
|
||||
# channel dimensions for operand but not for any of the other inputs.
|
||||
if has_only_spatial_dims: # len(operand_shape) <= 2
|
||||
# Call eval_shape on a shape that may contain polynomials, otherwise TF does
|
||||
# not know what to do with polynomials in the shape.
|
||||
operand_shape = jax2tf._eval_shape(operand_shape)
|
||||
# Add batch and channel dimensions to operand.
|
||||
operand = tf.reshape(operand, (1,) + operand_shape + (1,))
|
||||
else:
|
||||
# This branch assumes operand.shape = (batch, spatial_dims, ..., channel),
|
||||
# and dimensions, strides, dilation are all (1, spatial_values, ..., 1).
|
||||
# Input validation for this is done in _validate_reduce_window_inputs.
|
||||
window_dimensions = window_dimensions[1:-1]
|
||||
window_strides = window_strides[1:-1]
|
||||
window_dilation = window_dilation[1:-1]
|
||||
|
||||
return operand, window_dimensions, window_strides, window_dilation
|
||||
|
||||
|
||||
def _reduce_monoid(operand, window_dimensions, window_strides, padding,
|
||||
base_dilation, window_dilation, computation_name,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: core.ShapedArray):
|
||||
dtype = operand.dtype
|
||||
# In presence of shape polymorphism, operand.shape may contain None. The
|
||||
# actual dimension polynomial shapes are in _in_avals.
|
||||
operand_shape = _in_avals[0].shape
|
||||
|
||||
# TODO(marcvanzee): Put reduce_window arguments into dataclass, similar to
|
||||
# Gather, to simplify function calls.
|
||||
has_only_spatial_dims = _validate_reduce_window_inputs(
|
||||
operand_shape, computation_name, dtype, window_dimensions, window_strides,
|
||||
base_dilation, window_dilation)
|
||||
|
||||
operand, operand_shape, padding_type = _padding_reduce_window(
|
||||
operand, operand_shape, computation_name, window_dimensions,
|
||||
window_strides, padding)
|
||||
|
||||
operand, window_dimensions, window_strides, dilations = _reshape_reduce_window(
|
||||
operand,
|
||||
operand_shape,
|
||||
window_dimensions,
|
||||
window_strides,
|
||||
window_dilation,
|
||||
has_only_spatial_dims=has_only_spatial_dims)
|
||||
|
||||
def tf_pool(inputs, pooling_type):
|
||||
result = tf.nn.pool(
|
||||
inputs,
|
||||
window_shape=window_dimensions,
|
||||
pooling_type=pooling_type,
|
||||
padding=padding_type,
|
||||
strides=window_strides,
|
||||
dilations=window_dilation)
|
||||
op = tf.reshape(op, jax2tf._aval_to_tf_shape(_out_aval))
|
||||
return op
|
||||
dilations=dilations)
|
||||
|
||||
if has_only_spatial_dims:
|
||||
# If the input only had spatial dimensions we need to contract the batch
|
||||
# and channel dimensions before returning the output.
|
||||
result = tf.squeeze(result, [0, -1])
|
||||
|
||||
jax2tf._assert_matching_abstract_shape(result, _out_aval.shape)
|
||||
return result
|
||||
|
||||
negate = lambda x: tf.multiply(x, tf.constant(-1, dtype))
|
||||
if computation_name == "max":
|
||||
@ -577,8 +658,8 @@ def _reduce_monoid(operand, window_dimensions, window_strides, padding,
|
||||
return negate(tf_pool(negate(operand), "MAX"))
|
||||
elif computation_name == "add":
|
||||
# TODO(marcvanzee): This may give very large deviations on TPU when using
|
||||
# floats as inputs. We should think of a different implementation if users
|
||||
# run into this often.
|
||||
# floats as inputs. Alternatively, we could implement this using a
|
||||
# convolution with an all-1's kernel.
|
||||
return tf.multiply(tf_pool(operand, "AVG"), np.prod(window_dimensions))
|
||||
|
||||
|
||||
|
@ -124,25 +124,20 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
|
||||
# We keep here the explicit set of groups for which we don't have limitations
|
||||
harness_groups_no_limitations = {
|
||||
"abs", "add", "add_any", "and", "atan2",
|
||||
"bitcast_convert_type", "broadcast", "broadcast_in_dim", "cbrt", "ceil",
|
||||
"clamp", "concatenate", "cos", "cosh", "complex", "conj",
|
||||
"convert_element_type",
|
||||
"cummax", "cummin", "device_put", "dynamic_slice",
|
||||
"dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt",
|
||||
"imag",
|
||||
"iota", "is_finite", "le", "lt", "log", "mul", "ne", "neg", "not",
|
||||
"or", "pad", "population_count",
|
||||
"abs", "add", "add_any", "and", "atan2", "bitcast_convert_type",
|
||||
"broadcast", "broadcast_in_dim", "cbrt", "ceil", "clamp", "concatenate",
|
||||
"cos", "cosh", "complex", "conj", "convert_element_type", "cummax",
|
||||
"cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp",
|
||||
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le",
|
||||
"lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count",
|
||||
"random_categorical", "random_split", "random_uniform", "random_randint",
|
||||
"reduce",
|
||||
"reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
|
||||
"reduce_window_mul", "reduce_window_min",
|
||||
"reduce_window_max",
|
||||
"real", "reshape", "rev", "rsqrt", "scatter_max", "scatter_min",
|
||||
"select_n", "select_and_scatter_add",
|
||||
"shift_left", "shift_right_logical", "shift_right_arithmetic", "sign",
|
||||
"sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub",
|
||||
"tie_in", "transpose", "xor", "zeros_like"
|
||||
"reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
|
||||
"reduce_window_mul", "reduce_window_min", "reduce_window_max", "real",
|
||||
"reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select_n",
|
||||
"select_and_scatter_add", "shift_left", "shift_right_logical",
|
||||
"shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt",
|
||||
"squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor",
|
||||
"zeros_like"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -910,6 +905,15 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
@classmethod
|
||||
def reduce_window_add(cls, harness: primitive_harness.Harness):
|
||||
return [
|
||||
Jax2TfLimitation(
|
||||
"Small deviations on GPU for large inputs and enable_xla=False",
|
||||
dtypes=[np.float32],
|
||||
devices="gpu",
|
||||
modes=("eager", "graph", "compiled"),
|
||||
expect_tf_error=False,
|
||||
skip_comparison=False,
|
||||
enabled=not harness.params["enable_xla"],
|
||||
tol=3e-5),
|
||||
Jax2TfLimitation(
|
||||
"Large deviations on TPU for enable_xla=False",
|
||||
dtypes=[np.float16, np.float32],
|
||||
|
@ -828,7 +828,7 @@ for prim in [lax.argmin_p, lax.argmax_p]:
|
||||
for index_dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
|
||||
_make_argminmax_harness(prim, "index_dtype", index_dtype=index_dtype)
|
||||
|
||||
# Some special cases, with equal elements and NaN
|
||||
# Some special cases, with equal elements and NaN
|
||||
for name, operand in [
|
||||
("nan_0", np.array([np.nan, np.nan, 2., -2., -np.nan, -np.nan], np.float32)),
|
||||
("nan_1", np.array([np.nan, -np.nan, 2., -2.], np.float32)),
|
||||
@ -2488,19 +2488,22 @@ _make_reduce_window_harness("base_dilation", base_dilation=(1, 2),
|
||||
requires_xla=True)
|
||||
# Validate window_dilation
|
||||
_make_reduce_window_harness("window_dilation", window_dilation=(1, 2))
|
||||
# Validate squeezing behavior and dimensions in tf.nn.max_pool
|
||||
for shape, window_dimensions in [
|
||||
((2,), (2,)), # 1 spatial dimension, left and right squeeze
|
||||
((2, 1), (2, 1)), # 1 spatial dimension, left squeeze
|
||||
((1, 2), (1, 2)), # 1 spatial dimension, right squeeze
|
||||
((1, 2, 1), (1, 2, 1)), # 1 spatial dimension no squeeze
|
||||
((2, 4), (2, 2)), # 2 spatial dimensions, left and right squeeze
|
||||
((2, 4, 3), (2, 2, 2)), # 3 spatial dimensions, left and right squeeze
|
||||
((1, 4, 3, 2, 1), (1, 2, 2, 2, 1)) # 3 spatial dimensions, no squeeze
|
||||
# Validate batch and channel dimensions behavior. lax.reduce_window accepts
|
||||
# inputs that either have or do not have batch and channel dimensions.
|
||||
# N=batch, DHW=spatial, C=channel.
|
||||
# Without XLA only supports 1D/2D reductions.
|
||||
for shape, window_dimensions, requires_xla in [
|
||||
((2,), (2,), False), # W
|
||||
((2, 1), (2, 1), False), # WC
|
||||
((1, 2), (1, 2), False), # NW
|
||||
((1, 2, 1), (1, 2, 1), False), # NWC
|
||||
((2, 4), (2, 2), False), # HW
|
||||
((1, 2, 4, 1), (1, 2, 2, 1), False), # NHWC
|
||||
((2, 4, 3), (2, 2, 2), True), # DHW
|
||||
((1, 4, 3, 2, 1), (1, 2, 2, 2, 1), True) # NDHWC
|
||||
]:
|
||||
requires_xla = len(shape) > 2 # Without XLA only supports 1D/2D reductions.
|
||||
_make_reduce_window_harness(
|
||||
"squeeze_dim",
|
||||
"batch_channel_dims",
|
||||
computation=lax.max,
|
||||
shape=shape,
|
||||
dtype=np.float32,
|
||||
@ -2512,17 +2515,31 @@ for shape, window_dimensions in [
|
||||
window_dimensions=window_dimensions,
|
||||
requires_xla=requires_xla)
|
||||
|
||||
# This corresponds to SAME padding.
|
||||
_make_reduce_window_harness(
|
||||
"same_padding",
|
||||
shape=(112, 112),
|
||||
init_value=-np.inf,
|
||||
computation=lax.max,
|
||||
window_dimensions=(3, 3),
|
||||
window_strides=(2, 2),
|
||||
padding="SAME")
|
||||
for computation, id_value in [(lax.max, _get_max_identity(np.float32)),
|
||||
(lax.min, _get_min_identity(np.float32)),
|
||||
(lax.add, 0.)]:
|
||||
_make_reduce_window_harness(
|
||||
"same_padding",
|
||||
shape=(112, 112),
|
||||
init_value=id_value,
|
||||
computation=computation,
|
||||
window_dimensions=(3, 3),
|
||||
window_strides=(2, 2),
|
||||
padding="SAME")
|
||||
|
||||
# A few additional test cases for manual padding, which is applied when calling
|
||||
# reduce_window with lax.add, SAME padding and window_dimensions != (1, 1, ...).
|
||||
for window_dimensions, window_strides in [((2, 2), (1, 1)), ((3, 3), (2, 2)),
|
||||
((13, 13), (5, 6))]:
|
||||
_make_reduce_window_harness(
|
||||
"manual_padding",
|
||||
shape=(12, 12),
|
||||
init_value=0.,
|
||||
computation=lax.add,
|
||||
window_dimensions=window_dimensions,
|
||||
window_strides=window_strides,
|
||||
padding="SAME")
|
||||
|
||||
# b/240647139
|
||||
_make_reduce_window_harness(
|
||||
"init_value_1d",
|
||||
shape=(1, 16000),
|
||||
|
@ -102,7 +102,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
@primitive_harness.parameterized(
|
||||
primitive_harness.all_harnesses,
|
||||
include_jax_unimpl=False,
|
||||
#one_containing="reduce_window_max",
|
||||
#one_containing="reduce_window_add_same_padding",
|
||||
)
|
||||
@jtu.ignore_warning(
|
||||
category=UserWarning, message="Using reduced precision for gradient.*")
|
||||
|
@ -1163,7 +1163,8 @@ def _make_harness(group_name: str, name: str,
|
||||
match the expected exception string.
|
||||
|
||||
enable_and_disable_xla=True means that we generate two harnesses,
|
||||
one with enable_xla=False.
|
||||
one with enable_xla=False and one with enable_xal=True. Otherwise we create
|
||||
only one harness with enable_xla=True.
|
||||
"""
|
||||
if enable_and_disable_xla:
|
||||
return [
|
||||
@ -1189,7 +1190,8 @@ def _make_harness(group_name: str, name: str,
|
||||
dtype=np.float32,
|
||||
poly_axes=poly_axes, check_result=check_result,
|
||||
skip_jax_run=skip_jax_run, expect_error=expect_error,
|
||||
tol=tol)
|
||||
tol=tol,
|
||||
**params)
|
||||
|
||||
|
||||
_f32 = np.float32
|
||||
@ -1637,13 +1639,26 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min,
|
||||
(2, 2), (1, 1), "VALID"),
|
||||
[RandArg((3, 8), _f32)],
|
||||
poly_axes=[0]),
|
||||
poly_axes=[0],
|
||||
enable_and_disable_xla=True),
|
||||
_make_harness("reduce_window", "add",
|
||||
# x.shape = (b, 8)
|
||||
lambda x: lax.reduce_window(x, 0, lax.add, (2, 2), (1, 1),
|
||||
"VALID"),
|
||||
[RandArg((3, 8), _f32)],
|
||||
poly_axes=[0]),
|
||||
poly_axes=[0],
|
||||
enable_and_disable_xla=True),
|
||||
# https://github.com/google/jax/issues/11804
|
||||
# Use the reshape trick to simulate a polymorphic dimension of 16*b.
|
||||
# (See test "conv_general_dilated.1d_1" above for more details.)
|
||||
_make_harness("reduce_window", "add",
|
||||
# x.shape = (1, 16*b, 1)
|
||||
lambda x: lax.reduce_window(
|
||||
jnp.reshape(x, (1, -1, 1)),
|
||||
0., lax.add, (1, 4, 1), (1, 2, 1), "SAME"),
|
||||
[RandArg((1, 128, 16), _f32)],
|
||||
poly_axes=[1],
|
||||
enable_and_disable_xla=True),
|
||||
# TODO(necula): not yet supported, but also unlikely to come up.
|
||||
# _make_harness("random_uniform", "odd",
|
||||
# lambda key, a: jax.random.uniform(key, (2 * a.shape[0] + 1, a.shape[1]),
|
||||
@ -1897,7 +1912,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
# to parameterized below.
|
||||
@primitive_harness.parameterized(
|
||||
_flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
|
||||
#one_containing="reduce_window_add",
|
||||
#one_containing="reduce_window_add_noxla_poly_axes=[1]",
|
||||
)
|
||||
def test_prim(self, harness: Harness):
|
||||
_test_one_harness(self, harness)
|
||||
|
@ -392,7 +392,8 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
"""
|
||||
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
|
||||
enable_xla=enable_xla)
|
||||
f_tf_func = tf.function(f_tf, autograph=False, input_signature=input_signature)
|
||||
f_tf_func = tf.function(
|
||||
f_tf, autograph=False, input_signature=input_signature)
|
||||
concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)
|
||||
if expected_output_signature:
|
||||
# Strangely, output_shapes can be a single shape for a function with a
|
||||
|
Loading…
x
Reference in New Issue
Block a user