[jax2tf] Add paths that do not use XLA in conv_general_dilated.

This adds partial support for running convolutions without having
XLA linked in. These paths appear to be convertible for TFJS as
well.

Due to a so far unknown bug in some of the conversions, these
paths are disabled by default: the ENABLE_TF_CONVOLUTION global
variable in jax2tf.py must be explicitly toggled to use them. See
the comment associated with ENABLE_TF_CONVOLUTION for context.
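
For illustration, a minimal sketch of how the flag might be toggled from
user code. The module attribute path jax.experimental.jax2tf.jax2tf is an
assumption; only ENABLE_TF_CONVOLUTION itself is introduced by this commit.

    import numpy as np
    from jax import lax
    from jax.experimental import jax2tf
    from jax.experimental.jax2tf import jax2tf as jax2tf_internal  # assumed module path

    jax2tf_internal.ENABLE_TF_CONVOLUTION = True  # opt in to the XLA-free conv paths

    def conv(lhs, rhs):
      return lax.conv_general_dilated(
          lhs, rhs, window_strides=(1, 1), padding="VALID",
          dimension_numbers=("NHWC", "HWIO", "NHWC"))

    tf_conv = jax2tf.convert(conv)  # should now lower to tf.nn.convolution when possible
    out = tf_conv(np.zeros((1, 28, 28, 1), np.float32),
                  np.zeros((3, 3, 1, 16), np.float32))
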
Benjamin Chetioui 2020-10-08 17:46:37 +02:00
parent e6399cd7b5
commit c674c19b34
3 changed files with 219 additions and 7 deletions


@@ -675,6 +675,128 @@ def _conv_general_precision_config_proto(precision):
  proto.operand_precision.append(int(precision))
  return proto

# _try_tf_conv returns a Tensor when it succeeds, or a string describing why
# it did not succeed otherwise.
def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
                 dimension_numbers, feature_group_count, batch_group_count,
                 out_shape) -> Union[str, TfVal]:
  # TODO(bchetioui): this function is not exhaustive wrt which convolution cases
  # can be translated into TF primitives. Further investigation is needed to
  # fully flesh it out.
  if not lhs.dtype in [tf.float16, tf.float32, tf.float64]:
    return f"tf.nn.convolution is not supported for dtype {lhs.dtype}"
  if feature_group_count != 1:
    return "tf.nn.convolution does not support grouped convolutions"
  # TODO(bchetioui): is there something to do with batch_group_count?
  if batch_group_count != 1:
    return "Unimplemented support for batch_group_count != 1"
  nb_spatial_dimensions = len(lhs.shape) - 2
  # TF can only deal with 1D, 2D and 3D convolution
  if nb_spatial_dimensions < 1 or nb_spatial_dimensions > 3:
    return ("TensorFlow can only handle convolutions with 1, 2, or 3 "
            "spatial dimensions")
  # TODO(bchetioui): handle different stride cases
  if list(window_strides) != [1] * nb_spatial_dimensions:
    return ("Unimplemented support for window_strides != "
            f"{tuple([1] * nb_spatial_dimensions)}")

  success = lambda res: (res, None)
  failure = lambda msg: (None, msg)

  def convert_padding():
    # TODO(bchetioui): in this instance, we can not use padtype_to_pads as
    # string padding is not implemented for transposed convolution.
    if list(lhs_dilation) != [1] * nb_spatial_dimensions:
      return failure("Padding conversion is not supported for transposed "
                     "convolution.")
    lhs_perm, rhs_perm, _ = dimension_numbers
    effective_rhs_shape = [(k-1) * r + 1 for k, r in
                           zip(np.take(rhs.shape, rhs_perm)[2:], rhs_dilation)]
    lhs_shape = np.take(lhs.shape, lhs_perm)[2:]
    # TF only allows 'VALID' and 'SAME' padding
    for pad_str in ['VALID', 'SAME']:
      gen_padding = lax.padtype_to_pads(
          lhs_shape, effective_rhs_shape, window_strides, pad_str)
      if list(gen_padding) == list(padding):
        return success(pad_str)
    return failure("Input padding not supported in TensorFlow.")

  def convert_dim_nums():
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    # TF only allows filters with shape:
    # spatial_filter_shape + [in_channels, out_channels]. In JAX however,
    # rhs_spec is represented as a tuple containing the following:
    # [out_channels, in_channels] + spatial_filter_shape.
    supported_rhs_shape = ([nb_spatial_dimensions + 1, nb_spatial_dimensions] +
                           list(range(nb_spatial_dimensions)))
    if list(rhs_spec) != supported_rhs_shape:
      return failure("Input filter (RHS) shape format not supported in "
                     "TensorFlow")
    # TF only supports same LHS and output data format
    if lhs_spec != out_spec:
      return failure("TensorFlow requires the same data format for LHS and "
                     "output.")
    # Alphabet extracted from the documentation of tf.conv{1,2,3}d
    spatial_dim_alphabet = 'DHW'[-nb_spatial_dimensions:]
    # TF only supports the following data formats:
    # - [batch_size, in_channels] + input_spatial_shape
    # TODO(bchetioui): TF currently does not support the above on CPU. To avoid
    # failing on this platform, this path is commented out for now.
    #if list(lhs_spec) == list(range(len(lhs_spec))):
    #  return "NC" + spatial_dim_alphabet
    # - [batch_size] + input_spatial_shape + [in_channels]
    if list(lhs_spec) == ([0, len(lhs_spec) - 1] +
                          list(range(1, len(lhs_spec) - 1))):
      return success("N" + spatial_dim_alphabet + "C")
    return failure("Data format is unsupported by TensorFlow")

  def convert_dilation_and_compute_result(tf_padding, tf_dim_nums):
    no_dilation = [1] * nb_spatial_dimensions
    # TODO(bchetioui): is there a generic way to do a transposed atrous
    # convolution in TensorFlow?
    if not (list(lhs_dilation) == no_dilation or
            list(rhs_dilation) == no_dilation):
      return "Both LHS and RHS dilations are set"
    # This is a non-dilated or atrous convolution
    if list(lhs_dilation) == no_dilation:
      return tf.nn.convolution(
          lhs, rhs, strides=window_strides, padding=tf_padding,
          data_format=tf_dim_nums, dilations=rhs_dilation)
    # TODO(bchetioui): the below path is unreachable for now, as passing a lhs
    # dilation to this function will result in convert_padding returning None
    # systematically. This must be investigated further.
    # Dilation of the LHS is transposed convolution
    return tf.nn.conv_transpose(
        lhs, rhs, out_shape, window_strides, padding=tf_padding,
        data_format=tf_dim_nums, dilations=lhs_dilation)

  tf_padding, error = convert_padding()
  if tf_padding is None:
    return error
  tf_dim_nums, error = convert_dim_nums()
  if tf_dim_nums is None:
    return error
  return convert_dilation_and_compute_result(tf_padding, tf_dim_nums)

# TODO(bchetioui): enabling this flag permits using a conversion path purely
# based on TF (and not XLA) for _conv_general_dilated in cases when it is
# possible. It is disabled by default due to a so far unknown bug when running
# a test in compiled mode. The test that fails is
#
# test_conv_general_dilated_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None
#
# with the following assertion error in TensorFlowTrace.process_primitive:
#
# AssertionError: conv_general_dilated: out.aval = ShapedArray(float32[1,3,24,26,16]); expected ShapedArray(float32[1,3,26,24,16])
#
# Deactivating this assertion is enough to pass the test, which suggests that
# the end shape is indeed the correct one (i.e. (1,3,26,24,16)). Further
# investigation is required to really understand this behavior, which we have
# not managed to reproduce as a pure TF test.
ENABLE_TF_CONVOLUTION = False

def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
                          rhs_dilation, dimension_numbers, feature_group_count,
                          batch_group_count, lhs_shape, rhs_shape, precision):
@@ -683,6 +683,15 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count, lhs_shape,
      rhs_shape, precision)
  if ENABLE_TF_CONVOLUTION:
    info_or_result = _try_tf_conv(
        lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers, feature_group_count, batch_group_count, out_shape
    )
    if not isinstance(info_or_result, str):
      return info_or_result

  dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
  precision_config_proto = _conv_general_precision_config_proto(precision)
  assert batch_group_count == 1  # TODO(phawkins): implement batch_group_count

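As an aside, the padding conversion above works by regenerating the explicit
padding for each TF padding string and checking for a match. A small
standalone sketch of that idea follows; the shapes are chosen arbitrarily and
are not part of the commit.

    from jax import lax

    lhs_spatial_shape = (28, 28)   # spatial dims of the input
    rhs_spatial_shape = (3, 3)     # spatial dims of the filter
    rhs_dilation = (1, 2)
    window_strides = (1, 1)
    effective_rhs_shape = [(k - 1) * r + 1
                           for k, r in zip(rhs_spatial_shape, rhs_dilation)]

    for pad_str in ("VALID", "SAME"):
      gen_padding = lax.padtype_to_pads(
          lhs_spatial_shape, effective_rhs_shape, window_strides, pad_str)
      print(pad_str, list(gen_padding))
    # Explicit padding equal to one of the printed lists can be rewritten as
    # that padding string; anything else stays on the XLA path.
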

@@ -989,4 +989,80 @@ lax_conv_general_dilated = tuple( # Validate dtypes and precision
    (("NHWC", "HWIO", "NHWC"), (2, 9, 10, 3), (4, 5, 3, 3)),  # TF default
    (("NCHW", "HWIO", "NHWC"), (2, 3, 9, 10), (4, 5, 3, 3)),  # custom
  ]
) + tuple(
  _make_conv_harness("tf_conversion_path_1d", lhs_shape=lhs_shape,
                     padding=padding, rhs_shape=rhs_shape,
                     dimension_numbers=dimension_numbers, window_strides=(1,),
                     lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)
  for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1,), (1,)),  # no dilation with "VALID" padding
    ("SAME", (1,), (1,)),  # no dilation with "SAME" padding
    ("VALID", (1,), (2,)),  # dilation only on RHS with "VALID" padding
    ("SAME", (1,), (2,)),  # dilation only on RHS with "SAME" padding
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
  ]
  for dimension_numbers, lhs_shape, rhs_shape in [
    (("NWC", "WIO", "NWC"), (1, 28, 1), (3, 1, 16)),  # TF default
    # TODO(bchetioui): the NCW data format is not supported on CPU for TF
    # for now. That path is thus disabled to allow the code to use XLA instead.
  ]
) + tuple(
  _make_conv_harness("tf_conversion_path_2d", lhs_shape=lhs_shape,
                     padding=padding, rhs_shape=rhs_shape,
                     dimension_numbers=dimension_numbers, window_strides=(1, 1),
                     lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)
  for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1, 1), (1, 1)),  # no dilation with "VALID" padding
    ("SAME", (1, 1), (1, 1)),  # no dilation with "SAME" padding
    ("VALID", (1, 1), (1, 2)),  # dilation only on RHS with "VALID" padding
    ("SAME", (1, 1), (1, 2)),  # dilation only on RHS with "SAME" padding
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
  ]
  for dimension_numbers, lhs_shape, rhs_shape in [
    (("NHWC", "HWIO", "NHWC"), (1, 28, 28, 1), (3, 3, 1, 16)),  # TF default
    # TODO(bchetioui): the NCHW data format is not supported on CPU for TF
    # for now. That path is thus disabled to allow the code to use XLA instead.
  ]
) + tuple(
  _make_conv_harness("tf_conversion_path_3d", lhs_shape=lhs_shape,
                     padding=padding, rhs_shape=rhs_shape,
                     dimension_numbers=dimension_numbers,
                     window_strides=(1, 1, 1), lhs_dilation=lhs_dilation,
                     rhs_dilation=rhs_dilation)
  for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1, 1, 1), (1, 1, 1)),  # no dilation with "VALID" padding
    ("SAME", (1, 1, 1), (1, 1, 1)),  # no dilation with "SAME" padding
    ("VALID", (1, 1, 1), (1, 1, 2)),  # dilation only on RHS with "VALID" padding
    ("SAME", (1, 1, 1), (1, 1, 2)),  # dilation only on RHS with "SAME" padding
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
  ]
  for dimension_numbers, lhs_shape, rhs_shape in [
    # TF default
    (("NDHWC", "DHWIO", "NDHWC"), (1, 4, 28, 28, 1), (2, 3, 3, 1, 16)),
    # TODO(bchetioui): the NCDHW data format is not supported on CPU for TF
    # for now. That path is thus disabled to allow the code to use XLA instead.
  ]
) + tuple(
  # tf.nn.convolution only supports a subset of the possible dtypes for JAX
  # convolutions (float16, float32, float64). With the below tests, we ensure
  # that we avoid this branch in cases when it would not succeed.
  _make_conv_harness("tf_conversion_path_dtype", lhs_shape=lhs_shape,
                     padding='VALID', rhs_shape=rhs_shape, dtype=dtype,
                     dimension_numbers=dimension_numbers, lhs_dilation=(1, 1),
                     rhs_dilation=(1, 1))
  for dtype in jtu.dtypes.all_inexact
  for dimension_numbers, lhs_shape, rhs_shape in [
    (("NHWC", "HWIO", "NHWC"), (1, 28, 28, 1), (3, 3, 1, 16)),  # TF default
  ]
) + tuple(
  # Validate that feature_group_count != 1 does not go through tf.nn.convolution
  # as that would result in an exception.
  [_make_conv_harness("tf_avoid_path_feature_group_count", dtype=np.float32,
                      lhs_shape=(1, 28, 28, 32), rhs_shape=(3, 3, 16, 8),
                      padding='VALID', batch_group_count=1, lhs_dilation=(1, 1),
                      rhs_dilation=(1, 1), feature_group_count=2,
                      dimension_numbers=('NHWC', 'HWIO', 'NHWC'))]
)

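For reference, the ('NHWC', 'HWIO', 'NHWC') cases exercised above correspond
to TF's default 2D layout. A minimal TF-only sketch, with shapes borrowed from
the 2D harness (not part of the commit):

    import tensorflow as tf

    lhs = tf.zeros((1, 28, 28, 1), tf.float32)   # [batch, H, W, in_channels] (NHWC)
    rhs = tf.zeros((3, 3, 1, 16), tf.float32)    # [H, W, in_channels, out_channels] (HWIO)
    out = tf.nn.convolution(lhs, rhs, strides=(1, 1), padding="SAME",
                            data_format="NHWC", dilations=(1, 2))
    print(out.shape)  # (1, 28, 28, 16): "SAME" padding keeps the spatial dims
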

@@ -614,22 +614,27 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
   @primitive_harness.parameterized(primitive_harness.lax_conv_general_dilated)
   def test_conv_general_dilated(self, harness: primitive_harness.Harness):
-    if jtu.device_under_test() == "gpu" and harness.params["dtype"] in [np.complex64, np.complex128]:
+    dtype, device = harness.params["dtype"], jtu.device_under_test()
+    if device == "gpu" and dtype in [np.complex64, np.complex128]:
       raise unittest.SkipTest("TODO: crash on GPU in TF")
     tol = None
-    if jtu.device_under_test() == "gpu":
+    if device == "gpu":
       tol = 1e-4
-    elif jtu.device_under_test() == "tpu":
+    elif device == "tpu":
       tol = 1e-3
     # TODO(bchetioui): significant discrepancies in some float16 cases.
-    if harness.params["dtype"] == np.float16:
+    if dtype == np.float16:
       tol = 1.
     # TODO(bchetioui): slight occasional discrepancy in float32 cases.
-    elif harness.params["dtype"] == np.float32:
-      tol = 0.5 if jtu.device_under_test() == "tpu" else 1e-4
-    elif harness.params["dtype"] == np.complex64 and jtu.device_under_test() == "tpu":
+    elif dtype == np.float32:
+      tol = 0.5 if device == "tpu" else 1e-4
+    elif dtype == np.complex64 and device == "tpu":
       tol = 0.1
+    # TODO(bchetioui): slight discrepancy when going through the path using
+    # tf.nn.convolution.
+    elif dtype == np.float64 and device == "cpu":
+      tol = 1e-13
     self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                            atol=tol, rtol=tol)
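
A rough standalone analogue of the ConvertAndCompare call above, using the
float32 tolerance from the test; the helper-free setup and the specific shapes
are assumptions for illustration, not how the test suite is structured.

    import numpy as np
    from jax import lax
    from jax.experimental import jax2tf

    def conv(lhs, rhs):
      return lax.conv_general_dilated(
          lhs, rhs, window_strides=(1, 1), padding="SAME", rhs_dilation=(1, 2),
          dimension_numbers=("NHWC", "HWIO", "NHWC"))

    rng = np.random.RandomState(0)
    lhs = rng.randn(1, 28, 28, 1).astype(np.float32)
    rhs = rng.randn(3, 3, 1, 16).astype(np.float32)
    # Compare the JAX result against the converted-to-TF result.
    np.testing.assert_allclose(np.asarray(conv(lhs, rhs)),
                               np.asarray(jax2tf.convert(conv)(lhs, rhs)),
                               atol=1e-4, rtol=1e-4)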