Improves support for conv_general_dilated in JAX for models running on the web or mobile through TFLite and TFjs.
### Before this change

Prior to my change, there were a number of limitations to using convolutions for web/mobile:

* No strides other than (1, 1) could be used.
* Padding was only possible for the values ["VALID", "SAME"].
* Transposed convolutions were unsupported.
* Depthwise convolutions were unsupported.
* Input could only be provided in a very specific format, which prevented many use cases.

### After this change

After this change, we now support the following cases:

* Any stride size can be used.
* Any padding can be used (VALID, SAME, or custom numbers).
* Transposed convolutions are supported.
* Depthwise convolutions are supported.
* Input can be provided in any format.

### Impact on examples

Before, most of the Flax examples using convolutions were failing. After, all convolutions are converting successfully.

PiperOrigin-RevId: 403302738
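As a rough illustration of the workflow this change targets (a sketch by the editor, not part of the commit), the snippet below converts a strided, custom-padded JAX convolution with `jax2tf.convert(..., enable_xla=False)` and then runs the TFLite converter on it. The shapes, strides, and padding are made-up example values.

```python
import tensorflow as tf
from jax import lax
from jax.experimental import jax2tf

def conv(x, kernel):
  # Strided convolution with custom per-axis padding; with this change it
  # should lower to tf.pad + tf.nn.conv2d instead of requiring XLA.
  return lax.conv_general_dilated(
      x, kernel,
      window_strides=(2, 2),
      padding=((1, 2), (2, 1)),
      dimension_numbers=("NHWC", "HWIO", "NHWC"))

tf_conv = tf.function(jax2tf.convert(conv, enable_xla=False), autograph=False)
concrete_fn = tf_conv.get_concrete_function(
    tf.TensorSpec((1, 28, 28, 1), tf.float32),   # input
    tf.TensorSpec((3, 3, 1, 16), tf.float32))    # kernel
tflite_model = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn]).convert()
```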
This commit is contained in:
parent 0578ba68f4
commit aaf3bb789e
@@ -1,6 +1,6 @@
# Evaluation Results
*Last generated on: 2021-10-04* (YYYY-MM-DD)
*Last generated on: 2021-10-13* (YYYY-MM-DD)
## jax2tf --> TFLite
@@ -12,15 +12,15 @@ These examples are representative for what the average ML researcher is intereste
| Example | Result | Error Message |
| --- | --- | --- |
| imagenet | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| imagenet | FAIL | NotImplementedError('Call to reduce_window cannot be converted with enable_xla=False. Unimplemented support for padding - See source code for the precise conditions under which reduce_window can be converted without XLA.')
| lm1b | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")
| mnist | SUCCESS |
| nlp_seq | FAIL | ConverterError('/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: error: \'tf.Expm1\' op is neither a custom op nor a flex op\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:3798:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:819:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:836:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:277:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/lax/lax.py:192:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/numpy/lax_numpy.py:661:0: note: called from\n/Users/marcvanzee/github/jax/jax/linear_util.py:166:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:879:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:1645:0: note: called from\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: note: Error code: ERROR_NEEDS_CUSTOM_OPS\n<unknown>:0: error: failed while converting: \'main\': \nSome ops in the model are custom ops, See instructions to implement custom ops: https://www.tensorflow.org/lite/guide/ops_custom \nCustom ops: Expm1\nDetails:\n\ttf.Expm1(tensor<2x1x2xf32>) -> (tensor<2x1x2xf32>) : {device = ""}\n\n')
| pixelcnn++ | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Input padding not supported in TensorFlow. - See source code for the precise conditions under which convolutions can be converted without XLA.')
| ppo | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| nlp_seq | FAIL | ConverterError('/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: error: \'tf.Expm1\' op is neither a custom op nor a flex op\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:3798:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:820:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:837:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:276:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/lax/lax.py:211:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/numpy/lax_numpy.py:662:0: note: called from\n/Users/marcvanzee/github/jax/jax/linear_util.py:166:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:880:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:1641:0: note: called from\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: note: Error code: ERROR_NEEDS_CUSTOM_OPS\n<unknown>:0: error: failed while converting: \'main\': \nSome ops in the model are custom ops, See instructions to implement custom ops: https://www.tensorflow.org/lite/guide/ops_custom \nCustom ops: Expm1\nDetails:\n\ttf.Expm1(tensor<2x1x2xf32>) -> (tensor<2x1x2xf32>) : {device = ""}\n\n')
| pixelcnn++ | FAIL | ConverterError('/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: error: \'tf.Expm1\' op is neither a custom op nor a flex op\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:3798:0: Error code: ERROR_NEEDS_CUSTOM_OPS\n<unknown>:0: error: failed while converting: \'main\': \nSome ops in the model are custom ops, See instructions to implement custom ops: https://www.tensorflow.org/lite/guide/ops_custom \nCustom ops: Expm1\nDetails:\n\ttf.Expm1(tensor<1x16x16x4xf32>) -> (tensor<1x16x16x4xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x16x16x8xf32>) -> (tensor<1x16x16x8xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x32x32x4xf32>) -> (tensor<1x32x32x4xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x32x32x8xf32>) -> (tensor<1x32x32x8xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x64x64x2xf32>) -> (tensor<1x64x64x2xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x64x64x4xf32>) -> (tensor<1x64x64x4xf32>) : {device = ""}\n\ttf.Expm1(tensor<1x64x64x8xf32>) -> (tensor<1x64x64x8xf32>) : {device = ""}\n\n')
| ppo | SUCCESS |
| seq2seq | SUCCESS |
| sst2 | FAIL | NotImplementedError("Call to gather cannot be converted with enable_xla=False. unsupported dimension_numbers 'GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0, 1, 2))'; op_shape=(2, 6, 3).")
| vae | FAIL | ModuleNotFoundError("No module named 'utils'")
| vae | SUCCESS |
| wmt | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")
## jax2tf --> TFjs
@@ -33,13 +33,13 @@ These examples are representative for what the average ML researcher is intereste
| Example | Result | Error Message |
| --- | --- | --- |
| imagenet | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| imagenet | FAIL | NotImplementedError('Call to reduce_window cannot be converted with enable_xla=False. Unimplemented support for padding - See source code for the precise conditions under which reduce_window can be converted without XLA.')
| lm1b | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")
| mnist | SUCCESS |
| nlp_seq | FAIL | ValueError("Error when tracing gradients for SavedModel.\n\nSee the stack trace above to see the error that was raised when converting a gradient function to a concrete function. You may need to update the custom gradient, or disable saving gradients with the option tf.saved_model.SaveOptions(custom_gradients=False).\n\tProblematic op name: IdentityN\n\tGradient inputs: (<tf.Tensor 'AddV2_12:0' shape=(2, 1, 8) dtype=float32>, <tf.Tensor 'jax2tf_arg_0:0' shape=(8,) dtype=float32>, <tf.Tensor 'jax2tf_arg_1:0' shape=(4, 8) dtype=float32>, <tf.Tensor 'jax2tf_arg_2:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_3:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_4:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_5:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_6:0' shape=(2,) dtype=float32>, <tf.Tensor 'jax2tf_arg_7:0' shape=(4, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_8:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_9:0' shape=(2, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_10:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_11:0' shape=(1, 2, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_12:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_13:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_14:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_15:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_16:0' shape=(8, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_17:0' shape=(2, 1) dtype=float32>)")
| pixelcnn++ | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Input padding not supported in TensorFlow. - See source code for the precise conditions under which convolutions can be converted without XLA.')
| ppo | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| seq2seq | FAIL | ValueError('Unsupported Ops in the model before optimization\nBitcast, BitwiseAnd, BitwiseOr, RightShift, LeftShift, BitwiseXor')
| pixelcnn++ | SUCCESS |
| ppo | SUCCESS |
| seq2seq | FAIL | ValueError('Unsupported Ops in the model before optimization\nBitwiseXor, Bitcast, LeftShift, BitwiseAnd, RightShift, BitwiseOr')
| sst2 | FAIL | NotImplementedError("Call to gather cannot be converted with enable_xla=False. unsupported dimension_numbers 'GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0, 1, 2))'; op_shape=(2, 6, 3).")
| vae | SUCCESS |
| wmt | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
import builtins
from functools import partial
import contextlib
import os
@@ -1453,6 +1454,56 @@ def _precision_config_proto(precision: Optional[Tuple[PrecisionType,
  return proto


def _transpose_for_tf_conv(lhs, rhs, dimension_numbers):
  """Transposes lhs and rhs to respectively NHWC and HWIO so they can be passed to TF functions."""
  # TODO(marcvanzee): Add tests for these ops for shape polymorphism.
  lhs_perm, rhs_perm, _ = dimension_numbers

  # TODO(marcvanzee): Consider merging transposes if we want to optimize.
  # For `lhs_perm` / `output_perm`, perm (0, 1, 2, 3) corresponds to "NCHW".
  lhs = tf.transpose(lhs, lhs_perm)  # lhs --> "NCHW"
  # However, the TF ops only support "NHWC" on CPU, so we transpose again.
  lhs = tf.transpose(lhs, (0, 2, 3, 1))  # "NCHW" --> "NHWC"
  # For `rhs_perm`, perm (0, 1, 2, 3) corresponds to "OIHW".
  rhs = tf.transpose(rhs, rhs_perm)  # rhs --> "OIHW"
  # For the TF ops, rhs is expected to be "HWIO".
  rhs = tf.transpose(rhs, (2, 3, 1, 0))  # "OIHW" --> "HWIO"

  return lhs, rhs
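
# --- Editor's sketch (illustration only, not part of this commit) -----------
# What the two `lhs` transposes above compose to, with hypothetical shapes and
# an "NHWC" input; for "NHWC", lhs_perm is (0, 3, 1, 2), i.e. the positions of
# (N, C, H, W) in the input.
import numpy as np  # already imported in jax2tf.py; repeated to keep the sketch self-contained

lhs_nhwc = np.zeros((8, 28, 28, 3))
nchw = np.transpose(lhs_nhwc, (0, 3, 1, 2))    # lhs_perm: "NHWC" --> "NCHW", (8, 3, 28, 28)
nhwc_again = np.transpose(nchw, (0, 2, 3, 1))  # "NCHW" --> "NHWC", back to (8, 28, 28, 3)
assert nhwc_again.shape == lhs_nhwc.shape
# -----------------------------------------------------------------------------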


def _pad_for_tf_conv(lhs, rhs_dilated_shape, window_strides, padding):
  """Pads `lhs` if padding isn't "SAME" or "VALID" and returns (pad_type, new_lhs)."""

  for pad_str in ["VALID", "SAME"]:
    lhs_sdims = lhs.shape[1:3]  # lhs == NHWC
    gen_padding = lax.padtype_to_pads(lhs_sdims, rhs_dilated_shape,
                                      window_strides, pad_str)
    if list(gen_padding) == list(padding):
      return pad_str, lhs

  # Since TF ops only accept padding in ["VALID", "SAME"], we manually pad
  # using tf.pad if the padding is different, and we pass "VALID" padding to TF.
  # Add empty padding for batch and feature dimensions.
  no_pad = tf.constant([[0, 0]])
  padding = tf.concat([no_pad, padding, no_pad], 0)
  lhs = tf.pad(lhs, padding)
  return "VALID", lhs
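
# --- Editor's sketch (illustration only, not part of this commit) -----------
# How the SAME/VALID check above behaves for hypothetical sizes: a 9x9 input,
# a 3x3 (dilated) kernel and stride 2. Callers passing exactly the generated
# pads take the string fast path; any other padding is applied explicitly with
# tf.pad and "VALID" is handed to TF.
from jax import lax  # already imported in jax2tf.py

assert list(lax.padtype_to_pads((9, 9), (3, 3), (2, 2), "VALID")) == [(0, 0), (0, 0)]
assert list(lax.padtype_to_pads((9, 9), (3, 3), (2, 2), "SAME")) == [(1, 1), (1, 1)]
# -----------------------------------------------------------------------------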


def _is_valid_padding(kernel_sdims, strides, padding):
  """Returns True if `padding` corresponds to "VALID" padding for a transposed convolution."""
  # This is simply the padding == 'VALID' part of lax._conv_transpose_padding.
  for (begin, end), k, s in zip(padding, kernel_sdims, strides):
    pad_len = k + s - 2 + builtins.max(k - s, 0)
    pad_a = k - 1
    pad_b = pad_len - pad_a
    if begin != pad_a or end != pad_b:
      return False

  return True
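
# --- Editor's sketch (illustration only, not part of this commit) -----------
# Worked numbers for the check above, for a hypothetical 3x3 kernel with
# stride 2. lax's padding == 'VALID' rule for transposed convolutions gives
#   pad_len = k + s - 2 + max(k - s, 0) = 3 + 2 - 2 + 1 = 4
#   pad_a   = k - 1                     = 2
#   pad_b   = pad_len - pad_a           = 2
# so a per-axis padding of (2, 2) is accepted and anything else is rejected.
assert _is_valid_padding((3, 3), (2, 2), ((2, 2), (2, 2)))
assert not _is_valid_padding((3, 3), (2, 2), ((1, 1), (1, 1)))
# -----------------------------------------------------------------------------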


def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
                 dimension_numbers, feature_group_count, batch_group_count,
                 preferred_element_type: Optional[DType], out_shape) -> TfVal:
@@ -1462,114 +1513,108 @@ def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
              "convolutions can be converted without XLA.")
    return _xla_disabled_error("conv_general_dilated", f"{msg} - {suffix}")

  # TODO(bchetioui): this function is not exhaustive wrt which convolution cases
  # can be translated into TF primitives. Further investigation is needed to
  # fully flesh it out.
  if lhs.dtype not in [tf.float16, tf.float32, tf.float64]:
    raise error(f"tf.nn.convolution is not supported for dtype {lhs.dtype}")
  if feature_group_count != 1:
    raise error("tf.nn.convolution does not support grouped convolutions")
  # TODO(bchetioui): is there something to do with batch_group_count?
  nr_spatial_dimensions = len(lhs.shape) - 2

  # Currently we only support 2D convolutions because it keeps the code
  # relatively simple and covers most cases.
  if nr_spatial_dimensions != 2:
    error(f"We only support 2D convolutions, but found {nr_spatial_dimensions}.")

  # We can implement batch grouping when there is a need for it.
  if batch_group_count != 1:
    raise error("Unimplemented support for batch_group_count != 1")
  nb_spatial_dimensions = len(lhs.shape) - 2
  # TF can only deal with 1D, 2D and 3D convolution
  if nb_spatial_dimensions < 1 or nb_spatial_dimensions > 3:
    raise error("TensorFlow can only handle convolutions with 1, 2, or 3 "
                "spatial dimensions")
  # TODO(bchetioui): handle different stride cases
  if list(window_strides) != [1] * nb_spatial_dimensions:
    raise error("Unimplemented support for window_strides != "
                f"{tuple([1] * nb_spatial_dimensions)}")
    raise error("Unimplemented support for batch_group_count != 1 "
                f"(found {batch_group_count})")

  if preferred_element_type is not None and preferred_element_type != lhs.dtype:
    raise error("Unimplemented support for preferred_element_type")

  def convert_padding() -> str:
    # TODO(bchetioui): in this instance, we can not use padtype_to_pads as
    # string padding is not implemented for transposed convolution.
    if list(lhs_dilation) != [1] * nb_spatial_dimensions:
      raise error("Padding conversion is not supported for transposed "
                  "convolution.")
    lhs_perm, rhs_perm, _ = dimension_numbers
    effective_rhs_shape = [
        (k - 1) * r + 1
        for k, r in zip(np.take(rhs.shape, rhs_perm)[2:], rhs_dilation)
    ]
    lhs_shape = np.take(lhs.shape, lhs_perm)[2:]
    # TF only allows 'VALID' and 'SAME' padding
    for pad_str in ["VALID", "SAME"]:
      gen_padding = lax.padtype_to_pads(lhs_shape, effective_rhs_shape,
                                        window_strides, pad_str)
      if list(gen_padding) == list(padding):
        return pad_str
    raise error("Input padding not supported in TensorFlow.")
  lhs, rhs = _transpose_for_tf_conv(lhs, rhs, dimension_numbers)
  output_perm = dimension_numbers[2]

  def convert_dim_nums() -> str:
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    # TF only allows filters with shape:
    # spatial_filter_shape + [in_channels, out_channels]. In JAX however,
    # rhs_spec is represented as a tuple containing the following:
    # [out_channels, in_channels] + spatial_filter_shape.
    supported_rhs_shape = ([nb_spatial_dimensions + 1, nb_spatial_dimensions] +
                           list(range(nb_spatial_dimensions)))
    if list(rhs_spec) != supported_rhs_shape:
      raise error("Input filter (RHS) shape format not supported in "
                  "TensorFlow.")
    # TF only supports same LHS and output data format
    if lhs_spec != out_spec:
      raise error("TensorFlow requires the same data format for LHS and "
                  "output.")
    # Alphabet extracted from the documentation of tf.conv{1,2,3}d
    spatial_dim_alphabet = "DHW"[-nb_spatial_dimensions:]
    # TF only supports the following data formats:
    # - [batch_size, in_channels] + input_spatial_shape
  in_channels = lhs.shape[-1]
  *rhs_spatial_shapes, _, rhs_out_channel = rhs.shape

    # TODO(bchetioui): TF currently does not support the above on CPU. To avoid
    # failing on this platform, this path is commented out for now.
    # if list(lhs_spec) == list(range(len(lhs_spec))):
    #   return "NC" + spatial_dim_alphabet
  is_depthwise = in_channels == feature_group_count and feature_group_count > 1
  is_transpose = list(lhs_dilation) != [1] * nr_spatial_dimensions
  is_atrous = list(rhs_dilation) != [1] * nr_spatial_dimensions

    # - [batch_size] + input_spatial_shape + [in_channels]
    if list(lhs_spec) == ([0, len(lhs_spec) - 1] +
                          list(range(1,
                                     len(lhs_spec) - 1))):
      return "N" + spatial_dim_alphabet + "C"
    raise error("Data format is unsupported by TensorFlow.")
  if feature_group_count > 1 and not is_depthwise:
    raise error("Grouped convolutions are unsupported")

  def convert_dilation_and_compute_result(tf_padding: str,
                                          tf_dim_nums: str) -> TfVal:
    no_dilation = [1] * nb_spatial_dimensions
    # TODO(bchetioui): is there a generic way to do a transposed atrous
    # convolution in TensorFlow?
    if not (list(lhs_dilation) == no_dilation or
            list(rhs_dilation) == no_dilation):
      raise error("Both LHS and RHS dilations are set.")
    # This is a non-dilated or atrous convolution
    if list(lhs_dilation) == no_dilation:
      return tf.nn.convolution(
          lhs,
          rhs,
          strides=window_strides,
          padding=tf_padding,
          data_format=tf_dim_nums,
          dilations=rhs_dilation)
    # TODO(bchetioui): the below path is unreachable for now, as passing a lhs
    # dilation to this function will result in convert_padding returning None
    # systematically. This must be investigated further.
    # Dilation of the LHS is transposed convolution
    return tf.nn.conv_transpose(
        lhs,
        rhs,
        out_shape,
        window_strides,
        padding=tf_padding,
        data_format=tf_dim_nums,
        dilations=lhs_dilation)
  if is_transpose:
    # We provide support for transposed convolutions called through
    # lax.conv_transpose, but only if the provided padding was VALID.
    if not _is_valid_padding(rhs_spatial_shapes, window_strides, padding):
      raise error("Can only convert Transposed Convolutions with 'VALID' padding")

  tf_padding = convert_padding()
  tf_dim_nums = convert_dim_nums()
  return convert_dilation_and_compute_result(tf_padding, tf_dim_nums)
  if [is_depthwise, is_atrous, is_transpose].count(True) > 1:
    raise error("Can only do one of depthwise, atrous and transposed convolutions")

  rhs_dilated_shape = [
      (k - 1) * r + 1 for k, r in zip(rhs_spatial_shapes, rhs_dilation)
  ]

  padding_type, padded_lhs = _pad_for_tf_conv(lhs, rhs_dilated_shape, window_strides,
                                              padding)

  if any(r > l for l, r in zip(padded_lhs.shape[1:3], rhs_dilated_shape)):
    # If the filter shape is bigger than the input shape in a spatial dimension,
    # lax returns only zeros while tf.conv2d returns an error.
    # We thus return zeros to make sure the behavior is consistent.
    return tf.broadcast_to(tf.constant(0, dtype=tf.float32), out_shape)

  # Some TF ops require len(window_strides) == 4 while others do not. We simply
  # ensure it always has length 4.
  if type(window_strides) == int:
    window_strides = [window_strides] * 2
  if len(window_strides) == 2:
    window_strides = [1] + list(window_strides) + [1]

  if is_depthwise:
    # Reshape filter from
    # [filter_height, filter_width, 1, in_channels * channel_multiplier] to
    # [filter_height, filter_width, in_channels, channel_multiplier].
    new_rhs_shape = tuple(rhs_spatial_shapes) + (in_channels, rhs_out_channel // in_channels)
    output = tf.nn.depthwise_conv2d(
        input=padded_lhs,
        filter=tf.reshape(rhs, new_rhs_shape),
        strides=window_strides,
        padding=padding_type,
        dilations=rhs_dilation)

  elif is_transpose:
    # tf.nn.conv2d_transpose requires a transposed filter.
    rhs_t = tf.reverse(rhs, [0, 1])
    rhs_t = tf.transpose(rhs_t, (0, 1, 3, 2))

    # We should transpose `out_shape` so it conforms to what TF expects.
    tf_out_shape = tuple(out_shape[i] for i in output_perm)  # "NCHW"
    tf_out_shape = tuple(tf_out_shape[i] for i in (0, 2, 3, 1))  # "NCHW" -> "NHWC"

    output = tf.nn.conv2d_transpose(
        input=lhs,
        filters=rhs_t,
        output_shape=tf_out_shape,
        strides=lhs_dilation,
        padding="VALID")

  else:
    output = tf.nn.conv2d(
        input=padded_lhs,
        filters=rhs,
        strides=window_strides,
        padding=padding_type,
        dilations=rhs_dilation)

  # TF outputs in format "NHWC", so convert to "NCHW", which is lax's default
  # format.
  output = tf.transpose(output, (0, 3, 1, 2))  # "NHWC" --> "NCHW"
  # To determine the right permutation, we compute the inverse permutation of
  # `output_perm`, so that when `output_perm` is applied to `output`, we obtain
  # the output in NCHW format.
  inverse_perm = tuple(output_perm.index(i) for i in range(4))
  output = tf.transpose(output, inverse_perm)  # "NCHW" -> desired output shape.
  return output
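
# --- Editor's sketch (illustration only, not part of this commit) -----------
# What the inverse permutation above does, for a hypothetical "NHWC" output
# spec: output_perm == (0, 3, 1, 2) gives the positions of (N, C, H, W) in the
# desired layout, and its inverse maps the "NCHW" result into that layout.
output_perm = (0, 3, 1, 2)
inverse_perm = tuple(output_perm.index(i) for i in range(4))
assert inverse_perm == (0, 2, 3, 1)  # "NCHW" --> "NHWC"
assert tuple(output_perm[i] for i in inverse_perm) == (0, 1, 2, 3)
# -----------------------------------------------------------------------------
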
def _conv_general_dilated(lhs, rhs, *,
@@ -2662,63 +2662,66 @@ def _make_conv_harness(name,
dimension_numbers=("NCHW", "OIHW", "NCHW"),
batch_group_count=1,
preferred_element_type=None,
enable_xla=True):
define(
lax.conv_general_dilated_p,
f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_windowstrides={window_strides}_padding={padding}_lhsdilation={lhs_dilation}_rhsdilation={rhs_dilation}_dimensionnumbers={dimension_numbers}_featuregroupcount={feature_group_count}_batchgroupcount={batch_group_count}_precision={precision}_preferred={jtu.dtype_str(preferred_element_type)}_enablexla={enable_xla}"
.replace(" ", ""),
lax.conv_general_dilated,
[
RandArg(lhs_shape, dtype),
RandArg(rhs_shape, dtype),
StaticArg(window_strides),
StaticArg(padding),
StaticArg(lhs_dilation),
StaticArg(rhs_dilation),
StaticArg(dimension_numbers),
StaticArg(feature_group_count),
StaticArg(batch_group_count),
StaticArg(precision),
StaticArg(preferred_element_type),
],
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dtype=dtype,
window_strides=window_strides,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=precision,
preferred_element_type=preferred_element_type,
enable_xla=enable_xla,
jax_unimplemented=[
Limitation(
"preferred_element_type=i64 not implemented",
devices="tpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int64])),
# b/183565702 - no integer convolutions for GPU
Limitation(
"preferred_element_type not implemented for integers",
devices="gpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int16, np.int32,
np.int64])),
Limitation(
"preferred_element_type=f64 not implemented",
devices="tpu",
dtypes=(np.float16, jnp.bfloat16, np.float32),
enabled=(preferred_element_type in [np.float64])),
Limitation(
"preferred_element_type=c128 not implemented",
devices="tpu",
dtypes=np.complex64,
enabled=(preferred_element_type in [np.complex128])),
],
)
works_without_xla=False):
enable_xla_cases = [True, False] if works_without_xla else [True]

for enable_xla in enable_xla_cases:
define(
lax.conv_general_dilated_p,
f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_windowstrides={window_strides}_padding={padding}_lhsdilation={lhs_dilation}_rhsdilation={rhs_dilation}_dimensionnumbers={dimension_numbers}_featuregroupcount={feature_group_count}_batchgroupcount={batch_group_count}_precision={precision}_preferred={jtu.dtype_str(preferred_element_type)}_enablexla={enable_xla}"
.replace(" ", ""),
lax.conv_general_dilated,
[
RandArg(lhs_shape, dtype),
RandArg(rhs_shape, dtype),
StaticArg(window_strides),
StaticArg(padding),
StaticArg(lhs_dilation),
StaticArg(rhs_dilation),
StaticArg(dimension_numbers),
StaticArg(feature_group_count),
StaticArg(batch_group_count),
StaticArg(precision),
StaticArg(preferred_element_type),
],
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dtype=dtype,
window_strides=window_strides,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=precision,
preferred_element_type=preferred_element_type,
enable_xla=enable_xla,
jax_unimplemented=[
Limitation(
"preferred_element_type=i64 not implemented",
devices="tpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int64])),
# b/183565702 - no integer convolutions for GPU
Limitation(
"preferred_element_type not implemented for integers",
devices="gpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int16, np.int32,
np.int64])),
Limitation(
"preferred_element_type=f64 not implemented",
devices="tpu",
dtypes=(np.float16, jnp.bfloat16, np.float32),
enabled=(preferred_element_type in [np.float64])),
Limitation(
"preferred_element_type=c128 not implemented",
devices="tpu",
dtypes=np.complex64,
enabled=(preferred_element_type in [np.complex128])),
],
)


# Validate dtypes and precision
@@ -2729,12 +2732,19 @@ for dtype in jtu.dtypes.all_inexact:
# This first harness runs the tests for all dtypes and precisions using
# default values for all the other parameters. Variations of other parameters
# can thus safely skip testing their corresponding default value.
_make_conv_harness("dtype_precision", dtype=dtype, precision=precision)
_make_conv_harness(
"dtype_precision",
dtype=dtype,
precision=precision)


# Validate preferred_element_type
for dtype, preferred_element_type in preferred_type_combinations:
works_without_xla = dtype == np.float32 and preferred_element_type == np.float32
_make_conv_harness(
"preferred", dtype=dtype, preferred_element_type=preferred_element_type)
"preferred", dtype=dtype,
preferred_element_type=preferred_element_type,
works_without_xla=works_without_xla)

# Validate variations of feature_group_count and batch_group_count
for batch_group_count, feature_group_count in [
@@ -2752,25 +2762,81 @@ for batch_group_count, feature_group_count in [
feature_group_count=feature_group_count,
batch_group_count=batch_group_count)


#--- BEGIN Tests for conv_general_dilated with works_without_xla=True ---


# feature_group_count is supported for enable_xla=False only if we are doing a
# depthwise convolution, i.e.: in_channels == feature_group_count.
# See explanation of depthwise convolution at
# https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
_make_conv_harness(
"depthwise2d",
lhs_shape=(2, 3, 9, 9), # "NCHW": in_channels == 3
rhs_shape=(12, 1, 3, 3), # "OIHW": channel_multiplier = 12/3 = 4
feature_group_count=3,
works_without_xla=True)
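
# --- Editor's sketch (illustration only, not part of this commit) -----------
# The depthwise relationship the harness above relies on, with the same
# shapes: feature_group_count == in_channels (3), the filter's "I" dimension is
# in_channels / feature_group_count == 1, and each input channel produces
# 12 / 3 == 4 output channels.
import numpy as np  # both already imported in this test module
from jax import lax

lhs = np.zeros((2, 3, 9, 9), np.float32)   # "NCHW", in_channels == 3
rhs = np.zeros((12, 1, 3, 3), np.float32)  # "OIHW", O == 12, I == 1
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding="VALID",
    dimension_numbers=("NCHW", "OIHW", "NCHW"), feature_group_count=3)
assert out.shape == (2, 12, 7, 7)          # 9 - 3 + 1 == 7 spatial, 12 channels
# -----------------------------------------------------------------------------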

# Validate variations of window_strides
for window_strides in [(2, 3)]:
_make_conv_harness("window_strides", window_strides=window_strides)
_make_conv_harness(
"window_strides",
window_strides=window_strides,
works_without_xla=True)

# Validate variations of padding
for padding in [
((1, 2), (0, 0)), # padding only one spatial axis
((1, 2), (2, 1)) # padding on both spatial axes
]:
_make_conv_harness("padding", padding=padding)
_make_conv_harness("padding", padding=padding, works_without_xla=True)

# Validate variations of dilations
for lhs_dilation, rhs_dilation in [
((2, 2), (1, 1)), # dilation only on LHS (transposed)
((1, 1), (2, 3)), # dilation only on RHS (atrous)
((2, 3), (3, 2)) # dilation on both LHS and RHS (transposed & atrous)
((1, 1), (2, 2)), # dilation only on RHS (atrous)
]:
_make_conv_harness(
"dilations", lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)
"dilations",
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
works_without_xla=True)

# Simulate a call from lax.conv_transpose.
_make_conv_harness(
"conv_tranpose2d_valid_padding",
lhs_shape=(1, 16, 16, 2),
rhs_shape=(2, 3, 2, 2),
window_strides=(1, 1),
lhs_dilation=(2, 2),
padding=((1, 1), (2, 2)),
dimension_numbers=("NHWC", "HWIO", "NHWC"),
works_without_xla=True)

# Validate rhs > lhs.
# One dimension of rhs is bigger than lhs.
_make_conv_harness(
"rhs_oob",
lhs_shape=(2, 3, 9, 10),
rhs_shape=(3, 3, 10, 5),
works_without_xla=True)

# Effective rhs size is too big after applying rhs_dilation.
_make_conv_harness(
"rhs_oob_after_dilation",
lhs_shape=(2, 3, 9, 10),
rhs_shape=(3, 3, 4, 5),
rhs_dilation=(2, 3),
works_without_xla=True)

# Effective rhs is too big after applying input padding.
_make_conv_harness(
"rhs_oob_after_pading",
lhs_shape=(1, 3, 2, 2),
rhs_shape=(64, 3, 7, 7),
window_strides=(2, 2),
padding=((3, 3), (3, 3)),
works_without_xla=True)


# Dimension numbers and corresponding permutation
for dimension_numbers, lhs_shape, rhs_shape in [
@@ -2781,57 +2847,63 @@ for dimension_numbers, lhs_shape, rhs_shape in [
"dimension_numbers",
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers)

for padding, lhs_dilation, rhs_dilation in [
("VALID", (1,), (1,)), # no dilation with "VALID" padding
("SAME", (1,), (1,)), # no dilation with "SAME" padding
("VALID", (1,), (2,)), # dilation only on RHS with "VALID" padding
("SAME", (1,), (2,)), # dilation only on RHS with "SAME" padding
# TODO(bchetioui): LHS dilation with string padding can never be done using
# TF convolution functions for now.
]:
for dimension_numbers, lhs_shape, rhs_shape in [
(("NWC", "WIO", "NWC"), (1, 28, 1), (3, 1, 16)), # TF default
# TODO(bchetioui): the NCW data format is not supported on CPU for TF
# for now. That path is thus disabled to allow the code to use XLA instead.
]:
for enable_xla in [False, True]:
_make_conv_harness(
"tf_conversion_path_1d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1,),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
enable_xla=enable_xla)
dimension_numbers=dimension_numbers,
works_without_xla=True)

for padding, lhs_dilation, rhs_dilation in [
("VALID", (1, 1), (1, 1)), # no dilation with "VALID" padding
("SAME", (1, 1), (1, 1)), # no dilation with "SAME" padding
("VALID", (1, 1), (1, 2)), # dilation only on RHS with "VALID" padding
("SAME", (1, 1), (1, 2)), # dilation only on RHS with "SAME" padding
# TODO(bchetioui): LHS dilation with string padding can never be done using
# TF convolution functions for now.
([(1, 2), (0, 1)], (1, 1), (1, 2))
]:
for dimension_numbers, lhs_shape, rhs_shape in [
(("NHWC", "HWIO", "NHWC"), (1, 28, 28, 1), (3, 3, 1, 16)), # TF default
# TODO(bchetioui): the NCHW data format is not supported on CPU for TF
# for now. That path is thus disabled to allow the code to use XLA instead.
(("NCHW", "HWIO", "NCHW"), (1, 1, 28, 28), (3, 3, 1, 16)),
]:
for enable_xla in [False, True]:
_make_conv_harness(
"tf_conversion_path_2d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1, 1),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
enable_xla=enable_xla)
_make_conv_harness(
"tf_conversion_path_2d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1, 1),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
works_without_xla=True)

#--- END Tests for conv_general_dilated with works_without_xla=True ---


for lhs_dilation, rhs_dilation in [
# Note: LHS dilation does work for enable_xla=False, but only if
# padding=='VALID' (see test above for conv_transpose2d_valid_padding).
((2, 2), (1, 1)), # dilation only on LHS (transposed)
((2, 3), (3, 2)) # dilation on both LHS and RHS (transposed & atrous)
]:
_make_conv_harness(
"dilations", lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)


for padding, lhs_dilation, rhs_dilation in [
("VALID", (1,), (1,)), # no dilation with "VALID" padding
("SAME", (1,), (1,)), # no dilation with "SAME" padding
("VALID", (1,), (2,)), # dilation only on RHS with "VALID" padding
("SAME", (1,), (2,)), # dilation only on RHS with "SAME" padding
]:
for dimension_numbers, lhs_shape, rhs_shape in [
(("NWC", "WIO", "NWC"), (1, 28, 1), (3, 1, 16)), # TF default
]:
_make_conv_harness(
"tf_conversion_path_1d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1,),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation)


for padding, lhs_dilation, rhs_dilation in [
("VALID", (1, 1, 1), (1, 1, 1)), # no dilation with "VALID" padding
@@ -2839,26 +2911,20 @@ for padding, lhs_dilation, rhs_dilation in [
("VALID", (1, 1, 1), (1, 1, 2)), # dilation only on RHS with "VALID" padding
("SAME", (1, 1, 1), (1, 1, 2)), # dilation only on RHS with "SAME" padding
# TODO(bchetioui): LHS dilation with string padding can never be done using
# TF convolution functions for now.
]:
for dimension_numbers, lhs_shape, rhs_shape in [
# TF default
(("NDHWC", "DHWIO", "NDHWC"), (1, 4, 28, 28, 1), (2, 3, 3, 1, 16)),
# TODO(bchetioui): the NCDHW data format is not supported on CPU for TF
# for now. That path is thus disabled to allow the code to use XLA instead.
]:
for enable_xla in [False, True]:
_make_conv_harness(
"tf_conversion_path_3d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1, 1, 1),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
enable_xla=enable_xla)
_make_conv_harness(
"tf_conversion_path_3d",
lhs_shape=lhs_shape,
padding=padding,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
window_strides=(1, 1, 1),
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation)


if config.jax_enable_x64: