[jax2tf] Fixes for handling of convolutions with shape_polymorphism and enable_xla=False

Issue: #11402

Due to a typo, we were not running any tests for convolutions with shape
polymorphism and enable_xla=False.

Added a few more tests from #11402 (Thanks @sdenton4).

The main issue was that in the presence of shape polymorphism we cannot
simply use `x.shape` for a TF value `x`, because it contains `None` in
place of the unknown dimensions. We must instead use the JAX abstract
values.
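
As a minimal illustration of the distinction (the function, shapes, and
names below are illustrative, not part of this change): during TF tracing
an unknown dimension appears as `None` in `x.shape`, whereas the JAX
abstract value for the same operand carries a dimension polynomial that
still supports shape arithmetic.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 16], tf.float32)])
def show_tf_shape(x):
  # During TF tracing the unknown leading dimension is None, so x.shape
  # cannot be used for shape arithmetic in the converted convolution code.
  print(x.shape)  # (None, 16)
  return x

show_tf_shape.get_concrete_function()

# Under jax2tf with polymorphic_shapes=["(b, 16)"], the JAX abstract value
# for the same operand has shape (b, 16), where `b` is a dimension
# polynomial, so the lowering reads sizes from `_in_avals` instead of
# from `x.shape`.
```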

This does not fix all the issues reported in #11402; the computation of
padding for padding="SAME" still fails. The corresponding test is
commented out.
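
For reference, a hedged sketch of the still-failing case, adapted from the
commented-out "conv_general_dilated 1d_1" harness below; the
polymorphic-shape string and the way the failure surfaces are assumptions,
not something this change establishes.

```python
import tensorflow as tf
from jax import lax
from jax.experimental import jax2tf

def conv1d_same(lhs, rhs):
  # 1D convolution with stride 2 and padding="SAME", as in the harness.
  return lax.conv_general_dilated(
      lhs, rhs, window_strides=(2,), padding="SAME",
      dimension_numbers=lax.ConvDimensionNumbers(
          lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1)))

f_tf = tf.function(
    jax2tf.convert(conv1d_same,
                   polymorphic_shapes=["(1, t, 16)", None],
                   enable_xla=False),
    autograph=False)

# Tracing with a symbolic spatial dimension `t` is still expected to fail
# with an InconclusiveDimensionOperation while computing the SAME padding.
f_tf.get_concrete_function(tf.TensorSpec([1, None, 16], tf.float32),
                           tf.TensorSpec([4, 16, 16], tf.float32))
```
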
George Necula 2022-07-12 16:27:57 +03:00
parent 86ab8a84ee
commit b22121c0c1
4 changed files with 137 additions and 67 deletions


@@ -71,29 +71,41 @@ def _invert_permutation(perm):
return tuple(perm.index(i) for i in range(len(perm)))
def _transpose_for_tf_conv(lhs, rhs, dimension_numbers):
"""Tranposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions."""
def _transpose_with_shape(x: TfVal, x_shape: core.Shape, permutation) -> Tuple[TfVal, core.Shape]:
"""Computes transposition of x and its shape.
x_shape matches x.shape in the known dimensions, and it has dimension
polynomials elsewhere, while x.shape has None."""
return tf.transpose(x, perm=permutation), tuple(x_shape[i] for i in permutation)
def _transpose_for_tf_conv(lhs, lhs_shape: core.Shape,
rhs, rhs_shape: core.Shape, dimension_numbers):
"""Tranposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions.
The shapes passed in and returned may contain polynomials, and thus may
be different than lhs.shape and rhs.shape."""
# TODO(marcvanzee): Add tests for these ops for shape polymorphism.
lhs_perm, rhs_perm, _ = dimension_numbers
# TODO(marcvanzee): Consider merging transposes if we want to optimize.
# For `lhs_perm` / `output_perm`, perm (0, 1, 2, 3) corresponds to "NCHW".
lhs = tf.transpose(lhs, lhs_perm) # lhs --> "NCHW"
lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, lhs_perm) # lhs --> "NCHW"
if len(lhs_perm) == 3:
# For 1D convolution, we add a trivial "W" dimension, so that 2D Convolution
# logic can be applied downstream.
lhs = lhs[:, :, :, np.newaxis]
lhs_shape = tuple(lhs_shape) + (1,)
# However, the TF ops only support "NHWC" on CPU, so we transpose again.
lhs = tf.transpose(lhs, (0, 2, 3, 1)) # "NCHW" --> "NHWC"
lhs, lhs_shape = _transpose_with_shape(lhs, lhs_shape, (0, 2, 3, 1)) # "NCHW" --> "NHWC"
# For `rhs_perm`, perm (0, 1, 2, 3) corresponds to "OIHW".
rhs = tf.transpose(rhs, rhs_perm) # rhs --> "OIHW"
rhs, rhs_shape = _transpose_with_shape(rhs, rhs_shape, rhs_perm) # rhs --> "OIHW"
# Handle conv1d case.
if len(rhs_perm) == 3:
rhs = rhs[:, :, :, np.newaxis]
rhs_shape = tuple(rhs_shape) + (1,)
# For the tf ops, rhs is expected to be "HWIO".
rhs = tf.transpose(rhs, (2, 3, 1, 0)) # "OIHW" --> "HWIO"
return lhs, rhs
rhs, rhs_shape = _transpose_with_shape(rhs, rhs_shape, (2, 3, 1, 0)) # "OIHW" --> "HWIO"
return lhs, lhs_shape, rhs, rhs_shape
def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
@@ -104,18 +116,22 @@ def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
return "EXPLICIT"
def _pad_spatial_dims(in_shape, padding, is_conv1d):
"""Pads `in_shape` using `padding`, which specifies padding for the spatial dimensions."""
def _pad_spatial_dims(x, x_shape, padding, is_conv1d):
"""Pads `x` using `padding`, which specifies padding for the spatial dimensions."""
# Add empty padding for batch and feature dimensions.
no_pad = tf.constant([[0, 0]])
no_pad = ((0, 0),)
padding = tuple(padding)
if is_conv1d:
padding = tf.concat([no_pad, padding, no_pad], 0)
padding = no_pad + padding + no_pad
# Add empty padding for dummy dimension, too.
padding = tf.concat([no_pad, padding, no_pad, no_pad], 0)
padding = no_pad + padding + no_pad + no_pad
else:
padding = tf.concat([no_pad, padding, no_pad], 0)
in_shape = tf.pad(in_shape, padding)
return in_shape
padding = no_pad + padding + no_pad
x = tf.pad(x, padding)
assert len(x.shape) == len(padding)
x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding))
jax2tf._assert_matching_abstract_shape(x, x_shape)
return x, x_shape
def _conv_transpose_pads_to_padtype(kernel_sdims, lhs_dilation, padding):
@@ -224,19 +240,27 @@ def _conv_general_dilated(
preferred_element_type: Optional[DType],
_in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray):
"""Implementation of lax.conv_general_dilated_p using XlaConv."""
# In presence of shape polymorphism, lhs.shape and rhs.shape may contain
# None. The actual dimension polynomial shapes are in _in_avals.
del lhs_shape, rhs_shape, precision # Unused arguments.
out_shape = jax2tf._aval_to_tf_shape(_out_aval)
_validate_spatial_dimensions(len(lhs.shape) - 2)
is_conv1d = len(lhs.shape) - 2 == 1
lhs_shape, rhs_shape = _in_avals[0].shape, _in_avals[1].shape
jax2tf._assert_matching_abstract_shape(lhs, lhs_shape)
jax2tf._assert_matching_abstract_shape(rhs, rhs_shape)
out_shape = _out_aval.shape
_validate_spatial_dimensions(len(lhs_shape) - 2)
is_conv1d = len(lhs_shape) - 2 == 1
tf_window_strides = _normalize_window_strides(window_strides)
padding, lhs_dilation, rhs_dilation = _normalize_padding_and_dilations(
padding, lhs_dilation, rhs_dilation, is_conv1d)
lhs, rhs = _transpose_for_tf_conv(lhs, rhs, dimension_numbers)
in_channels = lhs.shape[-1]
*rhs_spatial_shapes, _, rhs_out_channel = rhs.shape
lhs, lhs_shape, rhs, rhs_shape = _transpose_for_tf_conv(lhs, lhs_shape,
rhs, rhs_shape,
dimension_numbers)
jax2tf._assert_matching_abstract_shape(lhs, lhs_shape)
jax2tf._assert_matching_abstract_shape(rhs, rhs_shape)
in_channels = lhs_shape[-1]
*rhs_spatial_shapes, _, rhs_out_channel = rhs_shape
is_transpose = any([d != 1 for d in lhs_dilation])
is_atrous = any([d != 1 for d in rhs_dilation])
@@ -255,18 +279,18 @@ def _conv_general_dilated(
rhs_spatial_shapes, lhs_dilation, padding)
else:
padding_type = pads_to_padtype(
lhs.shape[1:3], rhs_dilated_shape, window_strides, padding)
lhs_shape[1:3], rhs_dilated_shape, window_strides, padding)
# We only manually pad if we aren't using a transposed convolution.
if padding_type == "EXPLICIT":
lhs = _pad_spatial_dims(lhs, padding, is_conv1d)
lhs, lhs_shape = _pad_spatial_dims(lhs, lhs_shape, padding, is_conv1d)
padding_type = "VALID"
if any(r > l for l, r in zip(lhs.shape[1:3], rhs_dilated_shape)
) and padding_type != "SAME":
# If the filter shape is bigger than the input shape in a spatial dimension,
if padding_type != "SAME" and any(l < r for l, r in zip(lhs_shape[1:3], rhs_dilated_shape)):
# If the input shape is smaller than the filter shape in a spatial dimension,
# lax returns only zeros while tf.conv2d returns an error.
# We thus return zeros to make sure the behavior is consistent.
return tf.broadcast_to(tf.constant(0, dtype=tf.float32), out_shape)
return tf.broadcast_to(tf.constant(0, dtype=tf.float32),
jax2tf._eval_shape(out_shape))
if is_depthwise:
# Reshape filter from
@@ -276,7 +300,7 @@ def _conv_general_dilated(
rhs_out_channel // in_channels)
output = tf.nn.depthwise_conv2d(
input=lhs,
filter=tf.reshape(rhs, new_rhs_shape),
filter=tf.reshape(rhs, jax2tf._eval_shape(new_rhs_shape)),
strides=tf_window_strides,
padding=padding_type,
dilations=rhs_dilation)
@@ -286,7 +310,7 @@ def _conv_general_dilated(
rhs_t = tf.reverse(rhs, [0, 1])
rhs_t = tf.transpose(rhs_t, (0, 1, 3, 2))
# We should tranpose `out_shape` to "NHWC", which is what TF expects.
# We should transpose `out_shape` to "NHWC", which is what TF expects.
# First transpose to "NCHW".
if is_conv1d:
tf_out_shape = tuple(out_shape[i] for i in output_perm) + (1,)
@@ -297,7 +321,7 @@ def _conv_general_dilated(
output = tf.nn.conv2d_transpose(
input=lhs,
filters=rhs_t,
output_shape=tf_out_shape,
output_shape=jax2tf._eval_shape(tf_out_shape),
strides=lhs_dilation,
padding=padding_type)


@@ -687,6 +687,15 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize]) -> Sequence[TfVal]:
dim_values, dim_avals, "")) # type: ignore
return shape_values
def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize]):
"""Asserts that shape matches x.shape in the known dimensions and has
dimension polynomials elsewhere."""
# Ensures that the shape does not contain None; it should contain polynomials
assert (len(x.shape) == len(shape) and
all((xd is None and isinstance(sd, shape_poly._DimPolynomial) or
core.is_constant_dim(sd) and xd == sd)
for xd, sd in zip(x.shape, shape))), \
f"Shape {shape} does not match x.shape {x.shape}"
# TODO(b/26854495): pylint doesn't understand slots and inheritance.
# pylint: disable=assigning-non-slot


@@ -101,7 +101,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
@primitive_harness.parameterized(
primitive_harness.all_harnesses,
include_jax_unimpl=False,
#one_containing="cumprod_dtype_by_fun_shape=float16[8,9]_axis=0_reverse=False"
#one_containing="conv_general_dilated_dtype_precision_lhs=float16[2,3,9,10]_rhs=float16[3,3,4,5]_windowstrides=(1,1)_padding=((0,0),(0,0))_lhsdilation=(1,1)_rhsdilation=(1,1)_dimensionnumbers=('NCHW','OIHW','NCHW')_featuregroupcount=1_batchgroupcount=1_precision=DEFAULT_preferred=float64_enablexla=True"
)
@jtu.ignore_warning(
category=UserWarning, message="Using reduced precision for gradient.*")


@@ -1117,7 +1117,7 @@ def _make_harness(group_name: str, name: str,
check_result=True,
skip_jax_run=True,
tol=None,
enable_and_diable_xla=False,
enable_and_disable_xla=False,
expect_error=(None, None),
**params) -> Union[Harness, Sequence[Harness]]:
"""The `poly_axes` must correspond to the non-static arguments, and for each
@ -1139,15 +1139,15 @@ def _make_harness(group_name: str, name: str,
`expect_error` is a pair of an Exception type and a regular expression to
match the expected exception string.
enable_and_diable_xla=True means that we generate two harnesses,
enable_and_disable_xla=True means that we generate two harnesses,
one with enable_xla=False.
"""
if enable_and_diable_xla:
if enable_and_disable_xla:
return [
_make_harness(group_name, name + ("" if enable_xla else "_noxla"), # type: ignore
func, args, poly_axes=poly_axes,
check_result=check_result, tol=tol, enable_xla=enable_xla,
enable_and_diable_xla=False, skip_jax_run=skip_jax_run,
enable_and_disable_xla=False, skip_jax_run=skip_jax_run,
expect_error=expect_error,
**params)
for enable_xla in [True, False]
@@ -1187,7 +1187,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda op: jnp.arange(2 * op.shape[0], dtype=_f32),
[RandArg((3,), _f32)],
poly_axes=[0],
enable_and_diable_xla=True),
enable_and_disable_xla=True),
_make_harness("arange", "start_no_dtype",
lambda op: jnp.arange(op.shape[0]),
[RandArg((3,), _f32)],
@@ -1212,13 +1212,13 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda op: lax.argmax(op, axis=0, index_dtype=np.int32),
[RandArg((3, 4, 5), _f32)],
poly_axes=[0],
enable_and_diable_xla=True),
enable_and_disable_xla=True),
# Reduce the non-poly dimension
_make_harness("argmax", "1",
lambda op: lax.argmax(op, axis=1, index_dtype=np.int32),
[RandArg((3, 4, 5), _f32)],
poly_axes=[0],
enable_and_diable_xla=True),
enable_and_disable_xla=True),
[
_make_harness("average",
f"axis={axis}_weights=None",
@@ -1276,18 +1276,55 @@ _POLY_SHAPE_TEST_HARNESSES = [
jax.grad(lambda x: jnp.sum(jnp.concatenate([x, x], axis=0))),
[RandArg((3, 4, 5), _f32)],
poly_axes=[(0, 1)]),
# Issue #11402 InconclusiveDimensionOperation: Dimension polynomial '-1*t' is not a multiple of '2'
# TODO(still fails)
# _make_harness("conv_general_dilated", "1d_1",
# lambda lhs, rhs: lax.conv_general_dilated(
# lhs, rhs,
# window_strides=(2,),
# padding="SAME",
# rhs_dilation=None,
# dimension_numbers=lax.ConvDimensionNumbers(lhs_spec=(0, 2, 1),
# rhs_spec=(2, 1, 0),
# out_spec=(0, 2, 1))),
# [RandArg((1, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
# poly_axes=[1, None],
# enable_and_disable_xla=True),
# Issue #11402
_make_harness("conv_general_dilated", "1d_2",
lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
strides=(2,),
padding="SAME",
rhs_dilation=None,
transpose_kernel=False),
[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
poly_axes=[0, None],
enable_and_disable_xla=True),
# Issue #11402
_make_harness("conv_general_dilated", "1d_3",
lambda lhs, rhs: lax.conv_transpose(lhs, rhs,
strides=(2,),
padding="SAME",
rhs_dilation=None,
transpose_kernel=False),
[RandArg((5, 12, 16), _f32), RandArg((4, 16, 16), _f32)],
poly_axes=[1, None],
enable_and_disable_xla=True),
_make_harness("conv_general_dilated", "",
lambda lhs, rhs: lax.conv_general_dilated(lhs, rhs,
window_strides=(2, 3),
padding=((0, 0), (0, 0)),
lhs_dilation=(1, 1),
rhs_dilation=(1, 2),
dimension_numbers=("NCHW", "OIHW", "NCHW"),
feature_group_count=1,
batch_group_count=1,
precision=None),
lambda lhs, rhs: lax.conv_general_dilated(
lhs, rhs,
window_strides=(2, 3),
padding=((0, 0), (0, 0)),
lhs_dilation=(1, 1),
rhs_dilation=(1, 2),
dimension_numbers=("NCHW", "OIHW", "NCHW"),
feature_group_count=1,
batch_group_count=1,
precision=None),
[RandArg((7, 3, 9, 10), _f32), RandArg((3, 3, 4, 5), _f32)],
poly_axes=[0, None]),
poly_axes=[0, None],
enable_and_disable_xla=True),
_make_harness("cummax", "",
lambda x: lax_control_flow.cummax(x, axis=1, reverse=False),
[RandArg((3, 4, 5), _f32)],
@@ -1306,43 +1343,43 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
[RandArg((3, 4), _f32)],
poly_axes=[0],
enable_and_diable_xla=True),
enable_and_disable_xla=True),
_make_harness("dynamic_slice", "idx=tuple_arg",
# x:shape: (b, 4)
lambda x, i0: lax.dynamic_slice(x, (i0, np.int32(1)), (x.shape[0], 2)),
[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
poly_axes=[0, None],
enable_and_diable_xla=True),
enable_and_disable_xla=True),
_make_harness("dynamic_slice", "idx=array",
# x:shape: (b, 4)
lambda x, idx: lax.dynamic_slice(x, idx, (x.shape[0], 2)),
[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
poly_axes=[0, None],
enable_and_diable_xla=True),
enable_and_disable_xla=True),
_make_harness("dynamic_slice_in_dim", "idx=0",
# x:shape: (b, 4)
lambda x: lax.dynamic_slice_in_dim(x, 0, x.shape[0], axis=0),
[RandArg((3, 4), _f32)],
poly_axes=[0],
enable_and_diable_xla=True),
enable_and_disable_xla=True),
_make_harness("dynamic_update_slice", "idx=tuple_int",
# x:shape: (b, 4)
lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
[RandArg((3, 4), _f32)],
poly_axes=[0],
enable_and_diable_xla=True),
enable_and_disable_xla=True),
_make_harness("dynamic_update_slice", "idx=tuple_arg",
# x:shape: (b, 4)
lambda x, i0: lax.dynamic_update_slice(x, x, (i0, np.int32(0))),
[RandArg((3, 4), _f32), np.array(-2, dtype=np.int32)],
poly_axes=[0, None],
enable_and_diable_xla=True),
enable_and_disable_xla=True),
_make_harness("dynamic_update_slice", "idx=array",
# x:shape: (b, 4)
lambda x, idx: lax.dynamic_update_slice(x, x, idx),
[RandArg((3, 4), _f32), np.array([-2, -1], dtype=np.int32)],
poly_axes=[0, None],
enable_and_diable_xla=True),
enable_and_disable_xla=True),
_make_harness("einsum", "0",
lambda x: jnp.einsum("...i->...", x),
[RandArg((3, 4), _f32)],
@@ -1417,45 +1454,45 @@ _POLY_SHAPE_TEST_HARNESSES = [
_make_harness("getitem", "op=static_idx=poly",
lambda a, i: a[i],
[RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
poly_axes=[None, 0], enable_and_diable_xla=True),
poly_axes=[None, 0], enable_and_disable_xla=True),
# operand is poly, index is integer
_make_harness("getitem", "op=poly_idx=const",
lambda a: a[1],
[RandArg((3, 4), _f32)],
poly_axes=[0], enable_and_diable_xla=True),
poly_axes=[0], enable_and_disable_xla=True),
# operand is poly, index is dim poly
_make_harness("getitem", "op=poly_idx=dim",
lambda a: a[jax.core.dimension_as_value(a.shape[0] - 2)],
[RandArg((3, 4), _f32)],
poly_axes=[0], enable_and_diable_xla=True),
poly_axes=[0], enable_and_disable_xla=True),
# Both the operand and the index are poly
_make_harness("getitem", "op=poly_idx=poly",
lambda a, i: a[i],
[RandArg((3, 4), _f32), np.array([1, 2, 0], np.int32)],
poly_axes=[0, 0], enable_and_diable_xla=True),
poly_axes=[0, 0], enable_and_disable_xla=True),
# op is poly and index is an entire slice
_make_harness("getitem", "op=poly_idx=slice-all",
lambda a: a[:],
[RandArg((3, 4), _f32)],
poly_axes=[0], enable_and_diable_xla=True),
poly_axes=[0], enable_and_disable_xla=True),
# op is poly and index is a partial slice
_make_harness("getitem", "op=poly_idx=slice-ct-1",
lambda a: a[:2],
[RandArg((3, 4), _f32)],
poly_axes=[0], enable_and_diable_xla=True,
poly_axes=[0], enable_and_disable_xla=True,
expect_error=(IndexError, "Cannot use NumPy slice indexing on an array dimension")),
_make_harness("getitem", "op=poly_idx=slice-ct-2",
lambda a: a[:, :2],
[RandArg((3, 4), _f32)],
poly_axes=[0], enable_and_diable_xla=True),
poly_axes=[0], enable_and_disable_xla=True),
_make_harness("getitem", "op=poly_idx=slice-None-1",
lambda a: a[:a.shape[0]],
[RandArg((3, 4), _f32)],
poly_axes=[0], enable_and_diable_xla=True),
poly_axes=[0], enable_and_disable_xla=True),
_make_harness("getitem", "op=poly_idx=slice-poly",
lambda a: a[:a.shape[0] - 1],
[RandArg((3, 4), _f32)],
poly_axes=[0], enable_and_diable_xla=True,
poly_axes=[0], enable_and_disable_xla=True,
expect_error=(IndexError, "Array slice indices must have static")),
_make_harness("image_resize", "linear_0",
lambda x: jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]),
@@ -1668,7 +1705,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
_make_harness("take", "",
lambda a, i: jnp.take(a, i, axis=1),
[RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)],
poly_axes=[0, None], enable_and_diable_xla=True),
poly_axes=[0, None], enable_and_disable_xla=True),
_make_harness("take_along_axis", "0",
lambda x, y: jnp.take_along_axis(x, y, axis=0),
[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)],
@@ -1807,7 +1844,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
# to parameterized below.
@primitive_harness.parameterized(
_flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
#one_containing="take_along_axis_1_poly_axes=[0, 0]"
#one_containing="conv_general_dilated_1d_2_noxla_poly_axes=[0, None]"
)
def test_prim(self, harness: Harness):
_test_one_harness(self, harness)