mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Fixes for handling of convolutions with shape_polymorphism and enable_xla=False
Issue: #11402 Due to a typo we were running no tests for convolutions with shape polymorphism and enable_xla=False. Added a few more tests from #11402 (Thanks @sdenton4). The main issue was that in presence of shape polymorphism we cannot just use `x.shape` for a TF value `x` because it will contain `None` in the place of unknown dimensions. We must use instead the JAX abstract values. This does not fix all issues reported in #11402, there is still the computation of padding or padding="SAME". Commented out the corresponding test.
This commit is contained in:
parent
86ab8a84ee
commit
b22121c0c1
@ -71,29 +71,41 @@ def _invert_permutation(perm):
|
||||
return tuple(perm.index(i) for i in range(len(perm)))
|
||||
|
||||
|
||||
def _transpose_for_tf_conv(lhs, rhs, dimension_numbers):
|
||||
"""Tranposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions."""
|
||||
def _transpose_with_shape(x: TfVal, x_shape: core.Shape, permutation) -> Tuple[TfVal, core.Shape]:
|
||||
"""Computes transposition of x and its shape.
|
||||
x_shape matches x.shape in the known dimensions, and it has dimension
|
||||
polynomials elsewhere, while x.shape has None."""
|
||||
return tf.transpose(x, perm=permutation), tuple(x_shape[i] for i in permutation)
|
||||
|
||||
def _transpose_for_tf_conv(lhs, lhs_shape: core.Shape,
|
||||
rhs, rhs_shape: core.Shape, dimension_numbers):
|
||||
"""Tranposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions.
|
||||
|
||||
The shapes passed in and returned may contain polynomials, and thus may
|
||||
be different than lhs.shape and rhs.shape."""
|
||||
# TODO(marcvanzee): Add tests for this ops for shape polymorphism.
|
||||
lhs_perm, rhs_perm, _ = dimension_numbers
|
||||
|
||||
# TODO(marcvanzee): Consider merging tranposes if we want to optimize.
|
||||
# For `lhs_perm` / `output_perm`, perm (0, 1, 2, 3) corresponds to "NCHW".
|
||||
lhs = tf.transpose(lhs, lhs_perm) # lhs --> "NCHW"
|
||||
lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, lhs_perm) # lhs --> "NCHW"
|
||||
if len(lhs_perm) == 3:
|
||||
# For 1D convolution, we add a trivial "W" dimension, so that 2D Convolution
|
||||
# logic can be applied downstream.
|
||||
lhs = lhs[:, :, :, np.newaxis]
|
||||
lhs_shape = tuple(lhs_shape) + (1,)
|
||||
# However, the TF ops only support "NHWC" on CPU, so we transpose again.
|
||||
lhs = tf.transpose(lhs, (0, 2, 3, 1)) # "NCHW" --> "NHWC"
|
||||
lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, (0, 2, 3, 1)) # "NCHW" --> "NHWC"
|
||||
|
||||
# For `rhs_perm`, perm (0, 1, 2, 3) corresponds to "OIHW".
|
||||
rhs = tf.transpose(rhs, rhs_perm) # rhs --> "OIHW"
|
||||
rhs, rhs_shape = _transpose_with_shape(rhs, rhs_shape, rhs_perm) # rhs --> "OIHW"
|
||||
# Handle conv1d case.
|
||||
if len(rhs_perm) == 3:
|
||||
rhs = rhs[:, :, :, np.newaxis]
|
||||
rhs_shape = tuple(rhs_shape) + (1,)
|
||||
# For the tf ops, rhs is expected to be "OIHW".
|
||||
rhs = tf.transpose(rhs, (2, 3, 1, 0)) # "OIHW" --> "HWIO"
|
||||
return lhs, rhs
|
||||
rhs, rhs_shape = _transpose_with_shape(rhs, rhs_shape, (2, 3, 1, 0)) # "OIHW" --> "HWIO"
|
||||
return lhs, lhs_shape, rhs, rhs_shape
|
||||
|
||||
|
||||
def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
|
||||
@ -104,18 +116,22 @@ def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
|
||||
return "EXPLICIT"
|
||||
|
||||
|
||||
def _pad_spatial_dims(in_shape, padding, is_conv1d):
|
||||
"""Pads `in_shape` using `padding`, which specifies padding for the spatial dimensions."""
|
||||
def _pad_spatial_dims(x, x_shape, padding, is_conv1d):
|
||||
"""Pads `x` using `padding`, which specifies padding for the spatial dimensions."""
|
||||
# Add empty padding for batch and feature dimensions.
|
||||
no_pad = tf.constant([[0, 0]])
|
||||
no_pad = ((0, 0),)
|
||||
padding = tuple(padding)
|
||||
if is_conv1d:
|
||||
padding = tf.concat([no_pad, padding, no_pad], 0)
|
||||
padding = no_pad + padding + no_pad
|
||||
# Add empty padding for dummy dimension, too.
|
||||
padding = tf.concat([no_pad, padding, no_pad, no_pad], 0)
|
||||
padding = no_pad + padding + no_pad + no_pad
|
||||
else:
|
||||
padding = tf.concat([no_pad, padding, no_pad], 0)
|
||||
in_shape = tf.pad(in_shape, padding)
|
||||
return in_shape
|
||||
padding = no_pad + padding + no_pad
|
||||
x = tf.pad(x, padding)
|
||||
assert len(x.shape) == len(padding)
|
||||
x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding))
|
||||
jax2tf._assert_matching_abstract_shape(x, x_shape)
|
||||
return x, x_shape
|
||||
|
||||
|
||||
def _conv_transpose_pads_to_padtype(kernel_sdims, lhs_dilation, padding):
|
||||
@ -224,19 +240,27 @@ def _conv_general_dilated(
|
||||
preferred_element_type: Optional[DType],
|
||||
_in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray):
|
||||
"""Implementation of lax.conv_general_dilated_p using XlaConv."""
|
||||
# In presence of shape polymorphism, lhs.shape and rhs.shape may contain
|
||||
# None. The actual dimension polynomial shapes are in _in_avals.
|
||||
del lhs_shape, rhs_shape, precision # Unused arguments.
|
||||
out_shape = jax2tf._aval_to_tf_shape(_out_aval)
|
||||
_validate_spatial_dimensions(len(lhs.shape) - 2)
|
||||
is_conv1d = len(lhs.shape) - 2 == 1
|
||||
lhs_shape, rhs_shape = _in_avals[0].shape, _in_avals[1].shape
|
||||
jax2tf._assert_matching_abstract_shape(lhs, lhs_shape)
|
||||
jax2tf._assert_matching_abstract_shape(rhs, rhs_shape)
|
||||
out_shape = _out_aval.shape
|
||||
_validate_spatial_dimensions(len(lhs_shape) - 2)
|
||||
is_conv1d = len(lhs_shape) - 2 == 1
|
||||
|
||||
tf_window_strides = _normalize_window_strides(window_strides)
|
||||
padding, lhs_dilation, rhs_dilation = _normalize_padding_and_dilations(
|
||||
padding, lhs_dilation, rhs_dilation, is_conv1d)
|
||||
|
||||
lhs, rhs = _transpose_for_tf_conv(lhs, rhs, dimension_numbers)
|
||||
|
||||
in_channels = lhs.shape[-1]
|
||||
*rhs_spatial_shapes, _, rhs_out_channel = rhs.shape
|
||||
lhs, lhs_shape, rhs, rhs_shape = _transpose_for_tf_conv(lhs, lhs_shape,
|
||||
rhs, rhs_shape,
|
||||
dimension_numbers)
|
||||
jax2tf._assert_matching_abstract_shape(lhs, lhs_shape)
|
||||
jax2tf._assert_matching_abstract_shape(rhs, rhs_shape)
|
||||
in_channels = lhs_shape[-1]
|
||||
*rhs_spatial_shapes, _, rhs_out_channel = rhs_shape
|
||||
|
||||
is_transpose = any([d != 1 for d in lhs_dilation])
|
||||
is_atrous = any([d != 1 for d in rhs_dilation])
|
||||
@ -255,18 +279,18 @@ def _conv_general_dilated(
|
||||
rhs_spatial_shapes, lhs_dilation, padding)
|
||||
else:
|
||||
padding_type = pads_to_padtype(
|
||||
lhs.shape[1:3], rhs_dilated_shape, window_strides, padding)
|
||||
lhs_shape[1:3], rhs_dilated_shape, window_strides, padding)
|
||||
# We only manually pad if we aren't using a tranposed convolutions.
|
||||
if padding_type == "EXPLICIT":
|
||||
lhs = _pad_spatial_dims(lhs, padding, is_conv1d)
|
||||
lhs, lhs_shape = _pad_spatial_dims(lhs, lhs_shape, padding, is_conv1d)
|
||||
padding_type = "VALID"
|
||||
|
||||
if any(r > l for l, r in zip(lhs.shape[1:3], rhs_dilated_shape)
|
||||
) and padding_type != "SAME":
|
||||
# If the filter shape is bigger than the input shape in a spatial dimension,
|
||||
if padding_type != "SAME" and any(l < r for l, r in zip(lhs_shape[1:3], rhs_dilated_shape)):
|
||||
# If the input shape is smaller than the filter shape in a spatial dimension,
|
||||
# lax returns only zeros while tf.conv2d returns an error.
|
||||
# We thus return zeros to make sure the behavior is consistent.
|
||||
return tf.broadcast_to(tf.constant(0, dtype=tf.float32), out_shape)
|
||||
return tf.broadcast_to(tf.constant(0, dtype=tf.float32),
|
||||
jax2tf._eval_shape(out_shape))
|
||||
|
||||
if is_depthwise:
|
||||
# Reshape filter from
|
||||
@ -276,7 +300,7 @@ def _conv_general_dilated(
|
||||
rhs_out_channel // in_channels)
|
||||
output = tf.nn.depthwise_conv2d(
|
||||
input=lhs,
|
||||
filter=tf.reshape(rhs, new_rhs_shape),
|
||||
filter=tf.reshape(rhs, jax2tf._eval_shape(new_rhs_shape)),
|
||||
strides=tf_window_strides,
|
||||
padding=padding_type,
|
||||
dilations=rhs_dilation)
|
||||
@ -286,7 +310,7 @@ def _conv_general_dilated(
|
||||
rhs_t = tf.reverse(rhs, [0, 1])
|
||||
rhs_t = tf.transpose(rhs_t, (0, 1, 3, 2))
|
||||
|
||||
# We should tranpose `out_shape` to "NHWC", which is what TF expects.
|
||||
# We should transpose `out_shape` to "NHWC", which is what TF expects.
|
||||
# First transpose to "NCHW".
|
||||
if is_conv1d:
|
||||
tf_out_shape = tuple(out_shape[i] for i in output_perm) + (1,)
|
||||
@ -297,7 +321,7 @@ def _conv_general_dilated(
|
||||
output = tf.nn.conv2d_transpose(
|
||||
input=lhs,
|
||||
filters=rhs_t,
|
||||
output_shape=tf_out_shape,
|
||||
output_shape=jax2tf._eval_shape(tf_out_shape),
|
||||
strides=lhs_dilation,
|
||||
padding=padding_type)
|
||||
|
||||
|
@ -687,6 +687,15 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
|
||||
dim_values, dim_avals, "")) # type: ignore
|
||||
return shape_values
|
||||
|
||||
def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize]):
|
||||
"""Asserts that shape matches x.shape in the known dimensions and has
|
||||
dimension polynomials elsewhere."""
|
||||
# Ensures that the shape does not contain None; it should contain polynomials
|
||||
assert (len(x.shape) == len(shape) and
|
||||
all((xd is None and isinstance(sd, shape_poly._DimPolynomial) or
|
||||
core.is_constant_dim(sd) and xd == sd)
|
||||
for xd, sd in zip(x.shape, shape))), \
|
||||
f"Shape {shape} does not match x.shape {x.shape}"
|
||||
|
||||
# TODO(b/26854495): pylint doesn't understand slots and inheritance.
|
||||
# pylint: disable=assigning-non-slot
|
||||
|
@ -101,7 +101,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
@primitive_harness.parameterized(
|
||||
primitive_harness.all_harnesses,
|
||||
include_jax_unimpl=False,
|
||||
#one_containing="cumprod_dtype_by_fun_shape=float16[8,9]_axis=0_reverse=False"
|
||||
#one_containing="conv_general_dilated_dtype_precision_lhs=float16[2,3,9,10]_rhs=float16[3,3,4,5]_windowstrides=(1,1)_padding=((0,0),(0,0))_lhsdilation=(1,1)_rhsdilation=(1,1)_dimensionnumbers=('NCHW','OIHW','NCHW')_featuregroupcount=1_batchgroupcount=1_precision=DEFAULT_preferred=float64_enablexla=True"
|
||||
)
|
||||
@jtu.ignore_warning(
|
||||
category=UserWarning, message="Using reduced precision for gradient.*")
|
||||
|
@ -1117,7 +1117,7 @@ def _make_harness(group_name: str, name: str,
|
||||
check_result=True,
|
||||
skip_jax_run=True,
|
||||
tol=None,
|
||||
enable_and_diable_xla=False,
|
||||
enable_and_disable_xla=False,
|
||||
expect_error=(None, None),
|
||||
**params) -> Union[Harness, Sequence[Harness]]:
|
||||
"""The `poly_axes` must correspond to the non-static arguments, and for each
|
||||
@ -1139,15 +1139,15 @@ def _make_harness(group_name: str, name: str,
|
||||
`expect_error` is a pair of an Exception type and a regular expression to
|
||||
match the expected exception string.
|
||||
|
||||
enable_and_diable_xla=True means that we generate two harnesses,
|
||||
enable_and_disable_xla=True means that we generate two harnesses,
|
||||
one with enable_xla=False.
|
||||
"""
|
||||
if enable_and_diable_xla:
|
||||
if enable_and_disable_xla:
|
||||
return [
|
||||
_make_harness(group_name, name + ("" if enable_xla else "_noxla"), # type: ignore
|
||||
func, args, poly_axes=poly_axes,
|
||||
check_result=check_result, tol=tol, enable_xla=enable_xla,
|
||||
enable_and_diable_xla=False, skip_jax_run=skip_jax_run,
|
||||
enable_and_disable_xla=False, skip_jax_run=skip_jax_run,
|
||||
expect_error=expect_error,
|
||||
**params)
|
||||
for enable_xla in [True, False]
|
||||
@ -1187,7 +1187,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda op: jnp.arange(2 * op.shape[0], dtype=_f32),
|
||||
[RandArg((3,), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
enable_and_disable_xla=True),
|
||||
_make_harness("arange", "start_no_dtype",
|
||||
lambda op: jnp.arange(op.shape[0]),
|
||||
[RandArg((3,), _f32)],
|
||||
@ -1212,13 +1212,13 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
|
||||
[RandArg((3, 4, 5), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
enable_and_disable_xla=True),
|
||||
# Reduce the non-poly dimension
|
||||
_make_harness("argmax", "1",
|
||||
lambda op: lax.argmax(op, axis=1, index_dtype=np.int32),
|
||||
[RandArg((3, 4, 5), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
enable_and_disable_xla=True),
|
||||
[
|
||||
_make_harness("average",
|
||||
f"axis={axis}_weights=None",
|
||||
@ -1276,18 +1276,55 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
jax.grad(lambda x: jnp.sum(jnp.concatenate([x, x], axis=0))),
|
||||
[RandArg((3, 4, 5), _f32)],
|
||||
poly_axes=[(0, 1)]),
|
||||
|
||||
# Issue #11402 InconclusiveDimensionOperation: Dimension polynomial '-1*t' is not a multiple of '2'
|
||||
# TODO(still fails)
|
||||
# _make_harness("conv_general_dilated", "1d_1",
|
||||
# lambda lhs, rhs: lax.conv_general_dilated(
|
||||
# lhs, rhs,
|
||||
# window_strides=(2,),
|
||||
# padding="SAME",
|
||||
# rhs_dilation=None,
|
||||
# dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1),
|
||||
# rhs_spec=(2, 1, 0),
|
||||
# out_spec=(0, 2, 1))),
|
||||
# [RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
|
||||
# poly_axes=[1, None],
|
||||
# enable_and_disable_xla=True),
|
||||
# Issue #11402
|
||||
_make_harness("conv_general_dilated", "1d_2",
|
||||
lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
|
||||
strides=(2,),
|
||||
padding="SAME",
|
||||
rhs_dilation=None,
|
||||
transpose_kernel=False),
|
||||
[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
|
||||
poly_axes=[0, None],
|
||||
enable_and_disable_xla=True),
|
||||
# Issue #11402
|
||||
_make_harness("conv_general_dilated", "1d_3",
|
||||
lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
|
||||
strides=(2,),
|
||||
padding="SAME",
|
||||
rhs_dilation=None,
|
||||
transpose_kernel=False),
|
||||
[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
|
||||
poly_axes=[1, None],
|
||||
enable_and_disable_xla=True),
|
||||
_make_harness("conv_general_dilated", "",
|
||||
lambda lhs, rhs: lax.conv_general_dilated(lhs, rhs,
|
||||
window_strides=(2, 3),
|
||||
padding=((0, 0), (0, 0)),
|
||||
lhs_dilation=(1, 1),
|
||||
rhs_dilation=(1, 2),
|
||||
dimension_numbers=("NCHW", "OIHW", "NCHW"),
|
||||
feature_group_count=1,
|
||||
batch_group_count=1,
|
||||
precision=None),
|
||||
lambda lhs, rhs: lax.conv_general_dilated(
|
||||
lhs, rhs,
|
||||
window_strides=(2, 3),
|
||||
padding=((0, 0), (0, 0)),
|
||||
lhs_dilation=(1, 1),
|
||||
rhs_dilation=(1, 2),
|
||||
dimension_numbers=("NCHW", "OIHW", "NCHW"),
|
||||
feature_group_count=1,
|
||||
batch_group_count=1,
|
||||
precision=None),
|
||||
[RandArg((7, 3, 9, 10), _f32), RandArg((3, 3, 4, 5), _f32)],
|
||||
poly_axes=[0, None]),
|
||||
poly_axes=[0, None],
|
||||
enable_and_disable_xla=True),
|
||||
_make_harness("cummax", "",
|
||||
lambda x: lax_control_flow.cummax(x, axis=1, reverse=False),
|
||||
[RandArg((3, 4, 5), _f32)],
|
||||
@ -1306,43 +1343,43 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
enable_and_disable_xla=True),
|
||||
_make_harness("dynamic_slice", "idx=tuple_arg",
|
||||
# x:shape: (b, 4)
|
||||
lambda x, i0: lax.dynamic_slice(x, (i0, np.int32(1)), (x.shape[0], 2)),
|
||||
[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
|
||||
poly_axes=[0, None],
|
||||
enable_and_diable_xla=True),
|
||||
enable_and_disable_xla=True),
|
||||
_make_harness("dynamic_slice", "idx=array",
|
||||
# x:shape: (b, 4)
|
||||
lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)),
|
||||
[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
|
||||
poly_axes=[0, None],
|
||||
enable_and_diable_xla=True),
|
||||
enable_and_disable_xla=True),
|
||||
_make_harness("dynamic_slice_in_dim", "idx=0",
|
||||
# x:shape: (b, 4)
|
||||
lambda x: lax.dynamic_slice_in_dim(x, 0, x.shape[0], axis=0),
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
enable_and_disable_xla=True),
|
||||
_make_harness("dynamic_update_slice", "idx=tuple_int",
|
||||
# x:shape: (b, 4)
|
||||
lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0],
|
||||
enable_and_diable_xla=True),
|
||||
enable_and_disable_xla=True),
|
||||
_make_harness("dynamic_update_slice", "idx=tuple_arg",
|
||||
# x:shape: (b, 4)
|
||||
lambda x, i0: lax.dynamic_update_slice(x, x, (i0, np.int32(0))),
|
||||
[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
|
||||
poly_axes=[0, None],
|
||||
enable_and_diable_xla=True),
|
||||
enable_and_disable_xla=True),
|
||||
_make_harness("dynamic_update_slice", "idx=array",
|
||||
# x:shape: (b, 4)
|
||||
lambda x, idx: lax.dynamic_update_slice(x, x, idx),
|
||||
[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
|
||||
poly_axes=[0, None],
|
||||
enable_and_diable_xla=True),
|
||||
enable_and_disable_xla=True),
|
||||
_make_harness("einsum", "0",
|
||||
lambda x: jnp.einsum("...i->...", x),
|
||||
[RandArg((3, 4), _f32)],
|
||||
@ -1417,45 +1454,45 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
_make_harness("getitem", "op=static_idx=poly",
|
||||
lambda a, i: a[i],
|
||||
[RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
|
||||
poly_axes=[None, 0], enable_and_diable_xla=True),
|
||||
poly_axes=[None, 0], enable_and_disable_xla=True),
|
||||
# operand is poly, index is integer
|
||||
_make_harness("getitem", "op=poly_idx=const",
|
||||
lambda a: a[1],
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0], enable_and_diable_xla=True),
|
||||
poly_axes=[0], enable_and_disable_xla=True),
|
||||
# operand is poly, index is dim poly
|
||||
_make_harness("getitem", "op=poly_idx=dim",
|
||||
lambda a: a[jax.core.dimension_as_value(a.shape[0] - 2)],
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0], enable_and_diable_xla=True),
|
||||
poly_axes=[0], enable_and_disable_xla=True),
|
||||
# Both the operand and the index are poly
|
||||
_make_harness("getitem", "op=poly_idx=poly",
|
||||
lambda a, i: a[i],
|
||||
[RandArg((3, 4), _f32), np.array([1, 2, 0], np.int32)],
|
||||
poly_axes=[0, 0], enable_and_diable_xla=True),
|
||||
poly_axes=[0, 0], enable_and_disable_xla=True),
|
||||
# op is poly and index is an entire slice
|
||||
_make_harness("getitem", "op=poly_idx=slice-all",
|
||||
lambda a: a[:],
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0], enable_and_diable_xla=True),
|
||||
poly_axes=[0], enable_and_disable_xla=True),
|
||||
# op is poly and index is a partial slice
|
||||
_make_harness("getitem", "op=poly_idx=slice-ct-1",
|
||||
lambda a: a[:2],
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0], enable_and_diable_xla=True,
|
||||
poly_axes=[0], enable_and_disable_xla=True,
|
||||
expect_error=(IndexError, "Cannot use NumPy slice indexing on an array dimension")),
|
||||
_make_harness("getitem", "op=poly_idx=slice-ct-2",
|
||||
lambda a: a[:, :2],
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0], enable_and_diable_xla=True),
|
||||
poly_axes=[0], enable_and_disable_xla=True),
|
||||
_make_harness("getitem", "op=poly_idx=slice-None-1",
|
||||
lambda a: a[:a.shape[0]],
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0], enable_and_diable_xla=True),
|
||||
poly_axes=[0], enable_and_disable_xla=True),
|
||||
_make_harness("getitem", "op=poly_idx=slice-poly",
|
||||
lambda a: a[:a.shape[0] - 1],
|
||||
[RandArg((3, 4), _f32)],
|
||||
poly_axes=[0], enable_and_diable_xla=True,
|
||||
poly_axes=[0], enable_and_disable_xla=True,
|
||||
expect_error=(IndexError, "Array slice indices must have static")),
|
||||
_make_harness("image_resize", "linear_0",
|
||||
lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
|
||||
@ -1668,7 +1705,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
_make_harness("take", "",
|
||||
lambda a, i: jnp.take(a, i, axis=1),
|
||||
[RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)],
|
||||
poly_axes=[0, None], enable_and_diable_xla=True),
|
||||
poly_axes=[0, None], enable_and_disable_xla=True),
|
||||
_make_harness("take_along_axis", "0",
|
||||
lambda x, y: jnp.take_along_axis(x, y, axis=0),
|
||||
[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)],
|
||||
@ -1807,7 +1844,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
# to parameterized below.
|
||||
@primitive_harness.parameterized(
|
||||
_flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
|
||||
#one_containing="take_along_axis_1_poly_axes=[0, 0]"
|
||||
#one_containing="conv_general_dilated_1d_2_noxla_poly_axes=[0, None]"
|
||||
)
|
||||
def test_prim(self, harness: Harness):
|
||||
_test_one_harness(self, harness)
|
||||
|
Loading…
x
Reference in New Issue
Block a user