Improves support for conv_general_dilated in jax2tf for models running on the web or mobile through TFLite and TFjs.

### Before this change

Prior to this change, there were a number of limitations when converting convolutions for web/mobile:

* No strides other than (1,1) could be used.
* Padding was only possible for the values "VALID" and "SAME".
* Transposed convolutions were unsupported
* Depthwise convolutions were unsupported
* Input could only be provided in a very specific format, which prevented many use cases.

### After this change

After this change, we support the following cases (a usage sketch follows the list):
* Any stride size can be used
* Any padding can be used (VALID, SAME, or custom numbers)
* Transposed convolutions are supported
* Depthwise convolutions are supported
* Input can be provided in any format.
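
The sketch below is a minimal, illustrative example of the new behavior: a convolution with strides != (1, 1) and explicit per-axis padding converted with `enable_xla=False` and then lowered to TFLite. The model function, shapes, and converter invocation are assumptions chosen for the example, not code from this change:

```python
import tensorflow as tf
from jax import lax
from jax.experimental import jax2tf

def conv(x, kernel):
  # Strides != (1, 1) and explicit per-axis padding, both newly supported
  # when converting with enable_xla=False.
  return lax.conv_general_dilated(
      x, kernel,
      window_strides=(2, 2),
      padding=((1, 2), (2, 1)),
      dimension_numbers=("NHWC", "HWIO", "NHWC"))

tf_fn = tf.function(jax2tf.convert(conv, enable_xla=False), autograph=False)
concrete_fn = tf_fn.get_concrete_function(
    tf.TensorSpec((1, 28, 28, 1), tf.float32),    # input, "NHWC"
    tf.TensorSpec((3, 3, 1, 16), tf.float32))     # kernel, "HWIO"
tflite_model = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn]).convert()
```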

### Impact on examples

Before this change, most of the Flax examples using convolutions failed to convert.
After this change, all convolutions convert successfully; the remaining failures in the results below involve other ops (e.g. reduce_window, gather, Expm1).

PiperOrigin-RevId: 403302738
Author: Marc van Zee, 2021-10-15 01:00:18 -07:00 (committed by jax authors)
parent 0578ba68f4
commit aaf3bb789e
3 changed files with 340 additions and 229 deletions


@@ -1,6 +1,6 @@
# Evaluation Results
*Last generated on: 2021-10-04* (YYYY-MM-DD)
*Last generated on: 2021-10-13* (YYYY-MM-DD)
## jax2tf --> TFLite
@@ -12,15 +12,15 @@ These exampls are representative for what the average ML researcher is intereste
| Example | Result | Error Message |
| --- | --- | --- |
| imagenet | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| imagenet | FAIL | NotImplementedError('Call to reduce_window cannot be converted with enable_xla=False. Unimplemented support for padding - See source code for the precise conditions under which reduce_window can be converted without XLA.')
| lm1b | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")
| mnist | SUCCESS |
| nlp_seq | FAIL | ConverterError('/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: error: \'tf.Expm1\' op is neither a custom op nor a flex op\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:3798:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:819:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:836:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:277:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/lax/lax.py:192:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/numpy/lax_numpy.py:661:0: note: called from\n/Users/marcvanzee/github/jax/jax/linear_util.py:166:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:879:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:1645:0: note: called from\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: note: Error code: ERROR_NEEDS_CUSTOM_OPS\n<unknown>:0: error: failed while converting: \'main\': \nSome ops in the model are custom ops, See instructions to implement custom ops: https://www.tensorflow.org/lite/guide/ops_custom \nCustom ops: Expm1\nDetails:\n\ttf.Expm1(tensor<2x1x2xf32>) -> (tensor<2x1x2xf32>) : {device = ""}\n\n')
| pixelcnn++ | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Input padding not supported in TensorFlow. - See source code for the precise conditions under which convolutions can be converted without XLA.')
| ppo | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| nlp_seq | FAIL | ConverterError('/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: error: \'tf.Expm1\' op is neither a custom op nor a flex op\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:3798:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:820:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:837:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:276:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/lax/lax.py:211:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/numpy/lax_numpy.py:662:0: note: called from\n/Users/marcvanzee/github/jax/jax/linear_util.py:166:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:880:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:1641:0: note: called from\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: note: Error code: ERROR_NEEDS_CUSTOM_OPS\n<unknown>:0: error: failed while converting: \'main\': \nSome ops in the model are custom ops, See instructions to implement custom ops: https://www.tensorflow.org/lite/guide/ops_custom \nCustom ops: Expm1\nDetails:\n\ttf.Expm1(tensor<2x1x2xf32>) -> (tensor<2x1x2xf32>) : {device = ""}\n\n')
| pixelcnn++ | FAIL | ConverterError('/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: error: \'tf.Expm1\' op is neither a custom op nor a flex op\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:3798:0: Error code: ERROR_NEEDS_CUSTOM_OPS\n<unknown>:0: error: failed while converting: \'main\': \nSome ops in the model are custom ops, See instructions to implement custom ops: https://www.tensorflow.org/lite/guide/ops_custom \nCustom ops: Expm1\nDetails:\n\ttf.Expm1(tensor<1x16x16x4xf32>) -> (tensor<1x16x16x4xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x16x16x8xf32>) -> (tensor<1x16x16x8xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x32x32x4xf32>) -> (tensor<1x32x32x4xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x32x32x8xf32>) -> (tensor<1x32x32x8xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x64x64x2xf32>) -> (tensor<1x64x64x2xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x64x64x4xf32>) -> (tensor<1x64x64x4xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x64x64x8xf32>) -> (tensor<1x64x64x8xf32>) : {device = ""}\n\n')
| ppo | SUCCESS |
| seq2seq | SUCCESS |
| sst2 | FAIL | NotImplementedError("Call to gather cannot be converted with enable_xla=False. unsupported dimension_numbers 'GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0, 1, 2))'; op_shape=(2, 6, 3).")
| vae | FAIL | ModuleNotFoundError("No module named 'utils'")
| vae | SUCCESS |
| wmt | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")
## jax2tf --> TFjs
@@ -33,13 +33,13 @@ These exampls are representative for what the average ML researcher is intereste
| Example | Result | Error Message |
| --- | --- | --- |
| imagenet | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| imagenet | FAIL | NotImplementedError('Call to reduce_window cannot be converted with enable_xla=False. Unimplemented support for padding - See source code for the precise conditions under which reduce_window can be converted without XLA.')
| lm1b | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")
| mnist | SUCCESS |
| nlp_seq | FAIL | ValueError("Error when tracing gradients for SavedModel.\n\nSee the stack trace above to see the error that was raised when converting a gradient function to a concrete function. You may need to update the custom gradient, or disable saving gradients with the option tf.saved_model.SaveOptions(custom_gradients=False).\n\tProblematic op name: IdentityN\n\tGradient inputs: (<tf.Tensor 'AddV2_12:0' shape=(2, 1, 8) dtype=float32>, <tf.Tensor 'jax2tf_arg_0:0' shape=(8,) dtype=float32>, <tf.Tensor 'jax2tf_arg_1:0' shape=(4, 8) dtype=float32>, <tf.Tensor 'jax2tf_arg_2:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_3:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_4:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_5:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_6:0' shape=(2,) dtype=float32>, <tf.Tensor 'jax2tf_arg_7:0' shape=(4, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_8:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_9:0' shape=(2, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_10:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_11:0' shape=(1, 2, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_12:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_13:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_14:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_15:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_16:0' shape=(8, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_17:0' shape=(2, 1) dtype=float32>)")
| pixelcnn++ | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Input padding not supported in TensorFlow. - See source code for the precise conditions under which convolutions can be converted without XLA.')
| ppo | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| seq2seq | FAIL | ValueError('Unsupported Ops in the model before optimization\nBitcast, BitwiseAnd, BitwiseOr, RightShift, LeftShift, BitwiseXor')
| pixelcnn++ | SUCCESS |
| ppo | SUCCESS |
| seq2seq | FAIL | ValueError('Unsupported Ops in the model before optimization\nBitwiseXor, Bitcast, LeftShift, BitwiseAnd, RightShift, BitwiseOr')
| sst2 | FAIL | NotImplementedError("Call to gather cannot be converted with enable_xla=False. unsupported dimension_numbers 'GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0, 1, 2))'; op_shape=(2, 6, 3).")
| vae | SUCCESS |
| wmt | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")


@@ -12,6 +1,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
import builtins
from functools import partial
import contextlib
import os
@@ -1453,6 +1454,56 @@ def _precision_config_proto(precision: Optional[Tuple[PrecisionType,
return proto
def _transpose_for_tf_conv(lhs, rhs, dimension_numbers):
"""Tranposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions."""
# TODO(marcvanzee): Add tests for this ops for shape polymorphism.
lhs_perm, rhs_perm, _ = dimension_numbers
# TODO(marcvanzee): Consider merging transposes if we want to optimize.
# For `lhs_perm` / `output_perm`, perm (0, 1, 2, 3) corresponds to "NCHW".
lhs = tf.transpose(lhs, lhs_perm) # lhs --> "NCHW"
# However, the TF ops only support "NHWC" on CPU, so we transpose again.
lhs = tf.transpose(lhs, (0, 2, 3, 1)) # "NCHW" --> "NHWC"
# For `rhs_perm`, perm (0, 1, 2, 3) corresponds to "OIHW".
rhs = tf.transpose(rhs, rhs_perm) # rhs --> "OIHW"
# However, the TF ops expect rhs to be "HWIO", so we transpose again.
rhs = tf.transpose(rhs, (2, 3, 1, 0)) # "OIHW" --> "HWIO"
return lhs, rhs
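As a small illustration of the two-step transpose in `_transpose_for_tf_conv` (the shapes and the "NHWC" input spec are assumptions for the example): an input whose `lhs_perm` is `(0, 3, 1, 2)` is first brought into "NCHW" and then into the "NHWC" layout that TF's CPU kernels accept.

```python
import tensorflow as tf

lhs = tf.zeros((2, 9, 10, 3))            # input already in "NHWC": N=2, H=9, W=10, C=3
lhs_perm = (0, 3, 1, 2)                  # from dimension_numbers: maps this input to "NCHW"
nchw = tf.transpose(lhs, lhs_perm)       # shape (2, 3, 9, 10), i.e. "NCHW"
nhwc = tf.transpose(nchw, (0, 2, 3, 1))  # shape (2, 9, 10, 3), back to "NHWC" for the TF ops
assert nhwc.shape == (2, 9, 10, 3)
```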
def _pad_for_tf_conv(lhs, rhs_dilated_shape, window_strides, padding):
"""Pads `lhs` if padding isn't "SAME" or "VALID" and returns (pad_type, new_lhs)."""
for pad_str in ["VALID", "SAME"]:
lhs_sdims = lhs.shape[1:3] # lhs == NHWC
gen_padding = lax.padtype_to_pads(lhs_sdims, rhs_dilated_shape,
window_strides, pad_str)
if list(gen_padding) == list(padding):
return pad_str, lhs
# Since TF ops only accept padding in ["VALID", "SAME"], we manually pad
# using tf.pad if the padding is different, and we pass "VALID" padding to TF.
# Add empty padding for batch and feature dimensions.
no_pad = tf.constant([[0, 0]])
padding = tf.concat([no_pad, padding, no_pad], 0)
lhs = tf.pad(lhs, padding)
return "VALID", lhs
def _is_valid_padding(kernel_sdims, strides, padding):
"""Returns True if `padding` corresponds to "VALID" padding for a transposed convolution."""
# This is simply the padding == 'VALID' part of lax._conv_transpose_padding.
for (begin, end), k, s in zip(padding, kernel_sdims, strides):
pad_len = k + s - 2 + builtins.max(k - s, 0)
pad_a = k - 1
pad_b = pad_len - pad_a
if begin != pad_a or end != pad_b:
return False
return True
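A worked instance of the formula in `_is_valid_padding`, using the kernel spatial dims `(2, 3)` and strides `(2, 2)` from the conv_transpose harness added later in this change; the helper `valid_transpose_pads` is a hypothetical name used only for this sketch.

```python
def valid_transpose_pads(kernel_sdims, strides):
  # Mirrors the padding == 'VALID' branch of lax._conv_transpose_padding.
  pads = []
  for k, s in zip(kernel_sdims, strides):
    pad_len = k + s - 2 + max(k - s, 0)
    pad_a = k - 1
    pads.append((pad_a, pad_len - pad_a))
  return pads

print(valid_transpose_pads((2, 3), (2, 2)))  # [(1, 1), (2, 2)]
```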
def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count,
preferred_element_type: Optional[DType], out_shape) -> TfVal:
@@ -1462,114 +1513,108 @@ def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
"convolutions can be converted without XLA.")
return _xla_disabled_error("conv_general_dilated", f"{msg} - {suffix}")
# TODO(bchetioui): this function is not exhaustive wrt which convolution cases
# can be translated into TF primitives. Further investigation is needed to
# fully flesh it out.
if lhs.dtype not in [tf.float16, tf.float32, tf.float64]:
raise error(f"tf.nn.convolution is not supported for dtype {lhs.dtype}")
if feature_group_count != 1:
raise error("tf.nn.convolution does not support grouped convolutions")
# TODO(bchetioui): is there something to do with batch_group_count?
nr_spatial_dimensions = len(lhs.shape) - 2
# Currently we only support 2D convolutions because it keeps the code
# relatively simple and covers most cases.
if nr_spatial_dimensions != 2:
error(f"We only support 2D convolutions, but found {nr_spatial_dimensions}.")
# We can implement batch grouping when there is a need for it.
if batch_group_count != 1:
raise error("Unimplemented support for batch_group_count != 1")
nb_spatial_dimensions = len(lhs.shape) - 2
# TF can only deal with 1D, 2D and 3D convolution
if nb_spatial_dimensions < 1 or nb_spatial_dimensions > 3:
raise error("TensorFlow can only handle convolutions with 1, 2, or 3 "
"spatial dimensions")
# TODO(bchetioui): handle different stride cases
if list(window_strides) != [1] * nb_spatial_dimensions:
raise error("Unimplemented support for window_strides != "
f"{tuple([1] * nb_spatial_dimensions)}")
raise error("Unimplemented support for batch_group_count != 1 "
f"(found {batch_group_count})")
if preferred_element_type is not None and preferred_element_type != lhs.dtype:
raise error("Unimplemented support for preferred_element_type")
def convert_padding() -> str:
# TODO(bchetioui): in this instance, we can not use padtype_to_pads as
# string padding is not implemented for transposed convolution.
if list(lhs_dilation) != [1] * nb_spatial_dimensions:
raise error("Padding conversion is not supported for transposed "
"convolution.")
lhs_perm, rhs_perm, _ = dimension_numbers
effective_rhs_shape = [
(k - 1) * r + 1
for k, r in zip(np.take(rhs.shape, rhs_perm)[2:], rhs_dilation)
]
lhs_shape = np.take(lhs.shape, lhs_perm)[2:]
# TF only allows 'VALID' and 'SAME' padding
for pad_str in ["VALID", "SAME"]:
gen_padding = lax.padtype_to_pads(lhs_shape, effective_rhs_shape,
window_strides, pad_str)
if list(gen_padding) == list(padding):
return pad_str
raise error("Input padding not supported in TensorFlow.")
lhs, rhs = _transpose_for_tf_conv(lhs, rhs, dimension_numbers)
output_perm = dimension_numbers[2]
def convert_dim_nums() -> str:
lhs_spec, rhs_spec, out_spec = dimension_numbers
# TF only allows filters with shape:
# spatial_filter_shape + [in_channels, out_channels]. In JAX however,
# rhs_spec is represented as a tuple containing the following:
# [out_channels, in_channels] + spatial_filter_shape.
supported_rhs_shape = ([nb_spatial_dimensions + 1, nb_spatial_dimensions] +
list(range(nb_spatial_dimensions)))
if list(rhs_spec) != supported_rhs_shape:
raise error("Input filter (RHS) shape format not supported in "
"TensorFlow.")
# TF only supports same LHS and output data format
if lhs_spec != out_spec:
raise error("TensorFlow requires the same data format for LHS and "
"output.")
# Alphabet extracted from the documentation of tf.conv{1,2,3}d
spatial_dim_alphabet = "DHW"[-nb_spatial_dimensions:]
# TF only supports the following data formats:
# - [batch_size, in_channels] + input_spatial_shape
in_channels = lhs.shape[-1]
*rhs_spatial_shapes, _, rhs_out_channel = rhs.shape
# TODO(bchetioui): TF currently does not support the above on CPU. To avoid
# failing on this platform, this path is commented out for now.
# if list(lhs_spec) == list(range(len(lhs_spec))):
# return "NC" + spatial_dim_alphabet
is_depthwise = in_channels == feature_group_count and feature_group_count > 1
is_transpose = list(lhs_dilation) != [1] * nr_spatial_dimensions
is_atrous = list(rhs_dilation) != [1] * nr_spatial_dimensions
# - [batch_size] + input_spatial_shape + [in_channels]
if list(lhs_spec) == ([0, len(lhs_spec) - 1] +
list(range(1,
len(lhs_spec) - 1))):
return "N" + spatial_dim_alphabet + "C"
raise error("Data format is unsupported by TensorFlow.")
if feature_group_count > 1 and not is_depthwise:
raise error("Grouped convolutions are unsupported")
def convert_dilation_and_compute_result(tf_padding: str,
tf_dim_nums: str) -> TfVal:
no_dilation = [1] * nb_spatial_dimensions
# TODO(bchetioui): is there a generic way to do a transposed atrous
# convolution in TensorFlow?
if not (list(lhs_dilation) == no_dilation or
list(rhs_dilation) == no_dilation):
raise error("Both LHS and RHS dilations are set.")
# This is a non-dilated or atrous convolution
if list(lhs_dilation) == no_dilation:
return tf.nn.convolution(
lhs,
rhs,
strides=window_strides,
padding=tf_padding,
data_format=tf_dim_nums,
dilations=rhs_dilation)
# TODO(bchetioui): the below path is unreachable for now, as passing a lhs
# dilation to this function will result in convert_padding returning None
# systematically. This must be investigated further.
# Dilation of the LHS is transposed convolution
return tf.nn.conv_transpose(
lhs,
rhs,
out_shape,
window_strides,
padding=tf_padding,
data_format=tf_dim_nums,
dilations=lhs_dilation)
if is_transpose:
# We provide support for transposed convolutions called through
# lax.conv_transpose, but only if the provided padding was VALID.
if not _is_valid_padding(rhs_spatial_shapes, window_strides, padding):
raise error("Can only convert Transposed Convolutions with 'VALID' padding")
tf_padding = convert_padding()
tf_dim_nums = convert_dim_nums()
return convert_dilation_and_compute_result(tf_padding, tf_dim_nums)
if [is_depthwise, is_atrous, is_transpose].count(True) > 1:
raise error("Can only do one of depthwise, atrous and tranposed convolutions")
rhs_dilated_shape = [
(k - 1) * r + 1 for k, r in zip(rhs_spatial_shapes, rhs_dilation)
]
padding_type, padded_lhs = _pad_for_tf_conv(lhs, rhs_dilated_shape, window_strides,
padding)
if any(r > l for l, r in zip(padded_lhs.shape[1:3], rhs_dilated_shape)):
# If the filter shape is bigger than the input shape in a spatial dimension,
# lax returns only zeros while tf.conv2d returns an error.
# We thus return zeros to make sure the behavior is consistent.
return tf.broadcast_to(tf.constant(0, dtype=tf.float32), out_shape)
# Some TF ops require len(window_strides) == 4 while others do not. We simply
# ensure it always has len(4).
if type(window_strides) == int:
window_strides = [window_strides] * 2
if len(window_strides) == 2:
window_strides = [1] + list(window_strides) + [1]
if is_depthwise:
# Reshape filter from
# [filter_height, filter_width, 1, in_channels * channel_multiplier] to
# [filter_height, filter_width, in_channels, channel_multiplier].
new_rhs_shape = tuple(rhs_spatial_shapes) + (in_channels, rhs_out_channel // in_channels)
output = tf.nn.depthwise_conv2d(
input=padded_lhs,
filter=tf.reshape(rhs, new_rhs_shape),
strides=window_strides,
padding=padding_type,
dilations=rhs_dilation)
elif is_transpose:
# tf.nn.conv2d_transpose requires a transposed filter.
rhs_t = tf.reverse(rhs, [0, 1])
rhs_t = tf.transpose(rhs_t, (0, 1, 3, 2))
# We should transpose `out_shape` so it conforms to what TF expects.
tf_out_shape = tuple(out_shape[i] for i in output_perm) # "NCHW"
tf_out_shape = tuple(tf_out_shape[i] for i in (0, 2, 3, 1)) # "NCHW" -> "NHWC"
output = tf.nn.conv2d_transpose(
input=lhs,
filters=rhs_t,
output_shape=tf_out_shape,
strides=lhs_dilation,
padding="VALID")
else:
output = tf.nn.conv2d(
input=padded_lhs,
filters=rhs,
strides=window_strides,
padding=padding_type,
dilations=rhs_dilation)
# TF outputs in format "NHWC", so convert to "NCHW", which is lax's default
# format.
output = tf.transpose(output, (0, 3, 1, 2)) # "NHWC" --> "NCHW"
# To determine the right permutation, we compute the inverse permutation of
# `output_perm`, so that when `output_perm` is applied to `output`, we obtain
# the output in NCHW format.
inverse_perm = tuple(output_perm.index(i) for i in range(4))
output = tf.transpose(output, inverse_perm) # "NCHW" -> desired output shape.
return output
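A small worked example of the inverse-permutation step above, with an assumed "NHWC" output spec:

```python
output_perm = (0, 3, 1, 2)   # "NHWC" out spec: applying it to out_shape gives "NCHW" order
inverse_perm = tuple(output_perm.index(i) for i in range(4))
print(inverse_perm)          # (0, 2, 3, 1): maps the "NCHW" TF result back to "NHWC"
```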
def _conv_general_dilated(lhs, rhs, *,


@@ -2662,63 +2662,66 @@ def _make_conv_harness(name,
dimension_numbers=("NCHW", "OIHW", "NCHW"),
batch_group_count=1,
preferred_element_type=None,
enable_xla=True):
define(
lax.conv_general_dilated_p,
f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_windowstrides={window_strides}_padding={padding}_lhsdilation={lhs_dilation}_rhsdilation={rhs_dilation}_dimensionnumbers={dimension_numbers}_featuregroupcount={feature_group_count}_batchgroupcount={batch_group_count}_precision={precision}_preferred={jtu.dtype_str(preferred_element_type)}_enablexla={enable_xla}"
.replace(" ", ""),
lax.conv_general_dilated,
[
RandArg(lhs_shape, dtype),
RandArg(rhs_shape, dtype),
StaticArg(window_strides),
StaticArg(padding),
StaticArg(lhs_dilation),
StaticArg(rhs_dilation),
StaticArg(dimension_numbers),
StaticArg(feature_group_count),
StaticArg(batch_group_count),
StaticArg(precision),
StaticArg(preferred_element_type),
],
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dtype=dtype,
window_strides=window_strides,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=precision,
preferred_element_type=preferred_element_type,
enable_xla=enable_xla,
jax_unimplemented=[
Limitation(
"preferred_element_type=i64 not implemented",
devices="tpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int64])),
# b/183565702 - no integer convolutions for GPU
Limitation(
"preferred_element_type not implemented for integers",
devices="gpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int16, np.int32,
np.int64])),
Limitation(
"preferred_element_type=f64 not implemented",
devices="tpu",
dtypes=(np.float16, jnp.bfloat16, np.float32),
enabled=(preferred_element_type in [np.float64])),
Limitation(
"preferred_element_type=c128 not implemented",
devices="tpu",
dtypes=np.complex64,
enabled=(preferred_element_type in [np.complex128])),
],
)
works_without_xla=False):
enable_xla_cases = [True, False] if works_without_xla else [True]
for enable_xla in enable_xla_cases:
define(
lax.conv_general_dilated_p,
f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_windowstrides={window_strides}_padding={padding}_lhsdilation={lhs_dilation}_rhsdilation={rhs_dilation}_dimensionnumbers={dimension_numbers}_featuregroupcount={feature_group_count}_batchgroupcount={batch_group_count}_precision={precision}_preferred={jtu.dtype_str(preferred_element_type)}_enablexla={enable_xla}"
.replace(" ", ""),
lax.conv_general_dilated,
[
RandArg(lhs_shape, dtype),
RandArg(rhs_shape, dtype),
StaticArg(window_strides),
StaticArg(padding),
StaticArg(lhs_dilation),
StaticArg(rhs_dilation),
StaticArg(dimension_numbers),
StaticArg(feature_group_count),
StaticArg(batch_group_count),
StaticArg(precision),
StaticArg(preferred_element_type),
],
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dtype=dtype,
window_strides=window_strides,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=precision,
preferred_element_type=preferred_element_type,
enable_xla=enable_xla,
jax_unimplemented=[
Limitation(
"preferred_element_type=i64 not implemented",
devices="tpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int64])),
# b/183565702 - no integer convolutions for GPU
Limitation(
"preferred_element_type not implemented for integers",
devices="gpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int16, np.int32,
np.int64])),
Limitation(
"preferred_element_type=f64 not implemented",
devices="tpu",
dtypes=(np.float16, jnp.bfloat16, np.float32),
enabled=(preferred_element_type in [np.float64])),
Limitation(
"preferred_element_type=c128 not implemented",
devices="tpu",
dtypes=np.complex64,
enabled=(preferred_element_type in [np.complex128])),
],
)
# Validate dtypes and precision
@@ -2729,12 +2732,19 @@ for dtype in jtu.dtypes.all_inexact:
# This first harness runs the tests for all dtypes and precisions using
# default values for all the other parameters. Variations of other parameters
# can thus safely skip testing their corresponding default value.
_make_conv_harness("dtype_precision", dtype=dtype, precision=precision)
_make_conv_harness(
"dtype_precision",
dtype=dtype,
precision=precision)
# Validate preferred_element_type
for dtype, preferred_element_type in preferred_type_combinations:
works_without_xla = dtype == np.float32 and preferred_element_type == np.float32
_make_conv_harness(
"preferred", dtype=dtype, preferred_element_type=preferred_element_type)
"preferred", dtype=dtype,
preferred_element_type=preferred_element_type,
works_without_xla=works_without_xla)
# Validate variations of feature_group_count and batch_group_count
for batch_group_count, feature_group_count in [
@@ -2752,25 +2762,81 @@ for batch_group_count, feature_group_count in [
feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
#--- BEGIN Tests for conv_general_dilated with works_without_xla=True ---
# feature_group_count is supported for enable_xla=False only if we are doing a
# depthwise convolution, i.e.: in_channels == feature_group_count.
# See explanation of depthwise convolution at
# https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
_make_conv_harness(
"depthwise2d",
lhs_shape=(2, 3, 9, 9), # "NCHW": in_channels == 3
rhs_shape=(12, 1, 3, 3), # "OIHW": channel_multiplier = 12/3 = 4
feature_group_count=3,
works_without_xla=True)
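To make the depthwise case concrete, here is a hedged sketch at the lax level (the shapes match the harness above; the call itself is illustrative): in_channels == feature_group_count == 3, and the "OIHW" filter carries the channel multiplier in its O dimension.

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.zeros((2, 3, 9, 9), jnp.float32)   # "NCHW", in_channels = 3
rhs = jnp.zeros((12, 1, 3, 3), jnp.float32)  # "OIHW", O = 3 * channel_multiplier(4), I = 1
out = lax.conv_general_dilated(
    lhs, rhs,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
    feature_group_count=3)
print(out.shape)  # (2, 12, 9, 9)
```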
# Validate variations of window_strides
for window_strides in [(2, 3)]:
_make_conv_harness("window_strides", window_strides=window_strides)
_make_conv_harness(
"window_strides",
window_strides=window_strides,
works_without_xla=True)
# Validate variations of padding
for padding in [
((1, 2), (0, 0)), # padding only one spatial axis
((1, 2), (2, 1)) # padding on both spatial axes
]:
_make_conv_harness("padding", padding=padding)
_make_conv_harness("padding", padding=padding, works_without_xla=True)
# Validate variations of dilations
for lhs_dilation, rhs_dilation in [
((2, 2), (1, 1)), # dilation only on LHS (transposed)
((1, 1), (2, 3)), # dilation only on RHS (atrous)
((2, 3), (3, 2)) # dilation on both LHS and RHS (transposed & atrous)
((1, 1), (2, 2)), # dilation only on RHS (atrous)
]:
_make_conv_harness(
"dilations", lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)
"dilations",
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
works_without_xla=True)
# Simulate a call from lax.conv_transpose.
_make_conv_harness(
"conv_tranpose2d_valid_padding",
lhs_shape=(1, 16, 16, 2),
rhs_shape=(2, 3, 2, 2),
window_strides=(1, 1),
lhs_dilation=(2, 2),
padding=((1, 1), (2, 2)),
dimension_numbers=("NHWC", "HWIO", "NHWC"),
works_without_xla=True)
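For context, a hedged sketch of the kind of `lax.conv_transpose` call this harness simulates (the exact call is an assumption for illustration): with "VALID" padding and strides `(2, 2)`, it lowers to `conv_general_dilated` with `lhs_dilation=(2, 2)` and exactly the explicit padding `((1, 1), (2, 2))` that `_is_valid_padding` accepts.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((1, 16, 16, 2), jnp.float32)   # "NHWC"
k = jnp.zeros((2, 3, 2, 2), jnp.float32)     # "HWIO"
y = lax.conv_transpose(x, k, strides=(2, 2), padding="VALID",
                       dimension_numbers=("NHWC", "HWIO", "NHWC"))
print(y.shape)  # (1, 32, 33, 2)
```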
# Validate rhs > lhs.
# One dimension of rhs is bigger than lhs.
_make_conv_harness(
"rhs_oob",
lhs_shape=(2, 3, 9, 10),
rhs_shape=(3, 3, 10, 5),
works_without_xla=True)
# Effective rhs size is too big after applying rhs_dilation.
_make_conv_harness(
"rhs_oob_after_dilation",
lhs_shape=(2, 3, 9, 10),
rhs_shape=(3, 3, 4, 5),
rhs_dilation=(2, 3),
works_without_xla=True)
# Effective rhs is too big after applying input padding.
_make_conv_harness(
"rhs_oob_after_pading",
lhs_shape=(1, 3, 2, 2),
rhs_shape=(64, 3, 7, 7),
window_strides=(2, 2),
padding=((3, 3), (3, 3)),
works_without_xla=True)
# Dimension numbers and corresponding permutation
for dimension_numbers, lhs_shape, rhs_shape in [
@@ -2781,57 +2847,63 @@ for dimension_numbers, lhs_shape, rhs_shape in [
"dimension_numbers",
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers)
for padding, lhs_dilation, rhs_dilation in [
("VALID", (1,), (1,)), # no dilation with "VALID" padding
("SAME", (1,), (1,)), # no dilation with "SAME" padding
("VALID", (1,), (2,)), # dilation only on RHS with "VALID" padding
("SAME", (1,), (2,)), # dilation only on RHS with "SAME" padding
# TODO(bchetioui): LHS dilation with string padding can never be done using
# TF convolution functions for now.
]:
for dimension_numbers, lhs_shape, rhs_shape in [
(("NWC", "WIO", "NWC"), (1, 28, 1), (3, 1, 16)), # TF default
# TODO(bchetioui): the NCW data format is not supported on CPU for TF
# for now. That path is thus disabled to allow the code to use XLA instead.
]:
for enable_xla in [False, True]:
_make_conv_harness(
"tf_conversion_path_1d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1,),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
enable_xla=enable_xla)
dimension_numbers=dimension_numbers,
works_without_xla=True)
for padding, lhs_dilation, rhs_dilation in [
("VALID", (1, 1), (1, 1)), # no dilation with "VALID" padding
("SAME", (1, 1), (1, 1)), # no dilation with "SAME" padding
("VALID", (1, 1), (1, 2)), # dilation only on RHS with "VALID" padding
("SAME", (1, 1), (1, 2)), # dilation only on RHS with "SAME" padding
# TODO(bchetioui): LHS dilation with string padding can never be done using
# TF convolution functions for now.
([(1, 2), (0, 1)], (1, 1), (1, 2))
]:
for dimension_numbers, lhs_shape, rhs_shape in [
(("NHWC", "HWIO", "NHWC"), (1, 28, 28, 1), (3, 3, 1, 16)), # TF default
# TODO(bchetioui): the NCHW data format is not supported on CPU for TF
# for now. That path is thus disabled to allow the code to use XLA instead.
(("NCHW", "HWIO", "NCHW"), (1, 1, 28, 28), (3, 3, 1, 16)),
]:
for enable_xla in [False, True]:
_make_conv_harness(
"tf_conversion_path_2d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1, 1),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
enable_xla=enable_xla)
_make_conv_harness(
"tf_conversion_path_2d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1, 1),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
works_without_xla=True)
#--- END Tests for conv_general_dilated with works_without_xla=True ---
for lhs_dilation, rhs_dilation in [
# Note: LHS dilation does work for enable_xla=False, but only if
# padding=='VALID' (see test above for conv_transpose2d_valid_padding).
((2, 2), (1, 1)), # dilation only on LHS (transposed)
((2, 3), (3, 2)) # dilation on both LHS and RHS (transposed & atrous)
]:
_make_conv_harness(
"dilations", lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)
for padding, lhs_dilation, rhs_dilation in [
("VALID", (1,), (1,)), # no dilation with "VALID" padding
("SAME", (1,), (1,)), # no dilation with "SAME" padding
("VALID", (1,), (2,)), # dilation only on RHS with "VALID" padding
("SAME", (1,), (2,)), # dilation only on RHS with "SAME" padding
]:
for dimension_numbers, lhs_shape, rhs_shape in [
(("NWC", "WIO", "NWC"), (1, 28, 1), (3, 1, 16)), # TF default
]:
_make_conv_harness(
"tf_conversion_path_1d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1,),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation)
for padding, lhs_dilation, rhs_dilation in [
("VALID", (1, 1, 1), (1, 1, 1)), # no dilation with "VALID" padding
@@ -2839,26 +2911,20 @@ for padding, lhs_dilation, rhs_dilation in [
("VALID", (1, 1, 1), (1, 1,
2)), # dilation only on RHS with "VALID" padding
("SAME", (1, 1, 1), (1, 1, 2)), # dilation only on RHS with "SAME" padding
# TODO(bchetioui): LHS dilation with string padding can never be done using
# TF convolution functions for now.
]:
for dimension_numbers, lhs_shape, rhs_shape in [
# TF default
(("NDHWC", "DHWIO", "NDHWC"), (1, 4, 28, 28, 1), (2, 3, 3, 1, 16)),
# TODO(bchetioui): the NCDHW data format is not supported on CPU for TF
# for now. That path is thus disabled to allow the code to use XLA instead.
]:
for enable_xla in [False, True]:
_make_conv_harness(
"tf_conversion_path_3d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1, 1, 1),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
enable_xla=enable_xla)
_make_conv_harness(
"tf_conversion_path_3d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1, 1, 1),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation)
if config.jax_enable_x64: