Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[jax2tf] Add paths that do not use XLA in conv_general_dilated.
This adds partial support for running convolutions without having XLA linked in. These paths can seemingly be converted for TFJS as well. Due to a so-far-unknown bug in some of the conversions, the paths are disabled by default; the "ENABLE_TF_CONVOLUTION" global variable in jax2tf.py must be explicitly toggled to use them. See the comment associated with ENABLE_TF_CONVOLUTION for context.
This commit is contained in:
parent e6399cd7b5
commit c674c19b34
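As a usage sketch (not part of this commit), the new path could be exercised as below. The import path for the internal jax2tf module and the example shapes are assumptions; only the flag name and its location in jax2tf.py come from the commit message.

import numpy as np
from jax import lax
from jax.experimental import jax2tf
# Assumed import path for the module that defines ENABLE_TF_CONVOLUTION.
from jax.experimental.jax2tf import jax2tf as jax2tf_internal

# Opt in to the non-XLA convolution paths (disabled by default, see below).
jax2tf_internal.ENABLE_TF_CONVOLUTION = True

def conv(lhs, rhs):
  # 2D convolution in the TF-default NHWC/HWIO/NHWC layout, one of the cases
  # that _try_tf_conv can lower to tf.nn.convolution.
  return lax.conv_general_dilated(
      lhs, rhs, window_strides=(1, 1), padding="VALID",
      lhs_dilation=(1, 1), rhs_dilation=(1, 1),
      dimension_numbers=("NHWC", "HWIO", "NHWC"))

lhs = np.ones((1, 28, 28, 1), np.float32)  # [batch, H, W, in_channels]
rhs = np.ones((3, 3, 1, 16), np.float32)   # [H, W, in_channels, out_channels]
tf_conv = jax2tf.convert(conv)             # uses the TF path when it applies
print(tf_conv(lhs, rhs).shape)             # (1, 26, 26, 16)

If _try_tf_conv returns a reason string instead of a tensor, _conv_general_dilated falls back to the XLA-based lowering (see the second hunk below), so leaving the toggle on does not break unsupported cases.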
@@ -675,6 +675,128 @@ def _conv_general_precision_config_proto(precision):
  proto.operand_precision.append(int(precision))
  return proto

# _try_tf_conv returns a Tensor when it succeeds, or a string describing why
# it did not succeed otherwise.
def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
                 dimension_numbers, feature_group_count, batch_group_count,
                 out_shape) -> Union[str, TfVal]:
  # TODO(bchetioui): this function is not exhaustive wrt which convolution cases
  # can be translated into TF primitives. Further investigation is needed to
  # fully flesh it out.
  if not lhs.dtype in [tf.float16, tf.float32, tf.float64]:
    return f"tf.nn.convolution is not supported for dtype {lhs.dtype}"
  if feature_group_count != 1:
    return "tf.nn.convolution does not support grouped convolutions"
  # TODO(bchetioui): is there something to do with batch_group_count?
  if batch_group_count != 1:
    return "Unimplemented support for batch_group_count != 1"
  nb_spatial_dimensions = len(lhs.shape) - 2
  # TF can only deal with 1D, 2D and 3D convolution
  if nb_spatial_dimensions < 1 or nb_spatial_dimensions > 3:
    return ("TensorFlow can only handle convolutions with 1, 2, or 3 "
            "spatial dimensions")
  # TODO(bchetioui): handle different stride cases
  if list(window_strides) != [1] * nb_spatial_dimensions:
    return ("Unimplemented support for window_strides != "
            f"{tuple([1] * nb_spatial_dimensions)}")

  success = lambda res: (res, None)
  failure = lambda msg: (None, msg)

  def convert_padding():
    # TODO(bchetioui): in this instance, we can not use padtype_to_pads as
    # string padding is not implemented for transposed convolution.
    if list(lhs_dilation) != [1] * nb_spatial_dimensions:
      return failure("Padding conversion is not supported for transposed "
                     "convolution.")
    lhs_perm, rhs_perm, _ = dimension_numbers
    effective_rhs_shape = [(k-1) * r + 1 for k, r in
                           zip(np.take(rhs.shape, rhs_perm)[2:], rhs_dilation)]
    lhs_shape = np.take(lhs.shape, lhs_perm)[2:]
    # TF only allows 'VALID' and 'SAME' padding
    for pad_str in ['VALID', 'SAME']:
      gen_padding = lax.padtype_to_pads(
          lhs_shape, effective_rhs_shape, window_strides, pad_str)
      if list(gen_padding) == list(padding):
        return success(pad_str)
    return failure("Input padding not supported in TensorFlow.")

  def convert_dim_nums():
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    # TF only allows filters with shape:
    # spatial_filter_shape + [in_channels, out_channels]. In JAX however,
    # rhs_spec is represented as a tuple containing the following:
    # [out_channels, in_channels] + spatial_filter_shape.
    supported_rhs_shape = ([nb_spatial_dimensions + 1, nb_spatial_dimensions] +
                           list(range(nb_spatial_dimensions)))
    if list(rhs_spec) != supported_rhs_shape:
      return failure("Input filter (RHS) shape format not supported in "
                     "TensorFlow")
    # TF only supports same LHS and output data format
    if lhs_spec != out_spec:
      return failure("TensorFlow requires the same data format for LHS and "
                     "output.")
    # Alphabet extracted from the documentation of tf.conv{1,2,3}d
    spatial_dim_alphabet = 'DHW'[-nb_spatial_dimensions:]
    # TF only supports the following data formats:
    # - [batch_size, in_channels] + input_spatial_shape

    # TODO(bchetioui): TF currently does not support the above on CPU. To avoid
    # failing on this platform, this path is commented out for now.
    #if list(lhs_spec) == list(range(len(lhs_spec))):
    #  return "NC" + spatial_dim_alphabet

    # - [batch_size] + input_spatial_shape + [in_channels]
    if list(lhs_spec) == ([0, len(lhs_spec) - 1] +
                          list(range(1, len(lhs_spec) - 1))):
      return success("N" + spatial_dim_alphabet + "C")
    return failure("Data format is unsupported by TensorFlow")

  def convert_dilation_and_compute_result(tf_padding, tf_dim_nums):
    no_dilation = [1] * nb_spatial_dimensions
    # TODO(bchetioui): is there a generic way to do a transposed atrous
    # convolution in TensorFlow?
    if not (list(lhs_dilation) == no_dilation or
            list(rhs_dilation) == no_dilation):
      return "Both LHS and RHS dilations are set"
    # This is a non-dilated or atrous convolution
    if list(lhs_dilation) == no_dilation:
      return tf.nn.convolution(
          lhs, rhs, strides=window_strides, padding=tf_padding,
          data_format=tf_dim_nums, dilations=rhs_dilation)
    # TODO(bchetioui): the below path is unreachable for now, as passing a lhs
    # dilation to this function will result in convert_padding returning None
    # systematically. This must be investigated further.
    # Dilation of the LHS is transposed convolution
    return tf.nn.conv_transpose(
        lhs, rhs, out_shape, window_strides, padding=tf_padding,
        data_format=tf_dim_nums, dilations=lhs_dilation)

  tf_padding, error = convert_padding()
  if tf_padding is None:
    return error
  tf_dim_nums, error = convert_dim_nums()
  if tf_dim_nums is None:
    return error
  return convert_dilation_and_compute_result(tf_padding, tf_dim_nums)

# TODO(bchetioui): enabling this flag permits using a conversion path purely
# based on TF (and not XLA) for _conv_general_dilated in cases when it is
# possible. It is disabled by default due to a so far unknown bug when running
# a test in compiled mode. The test that fails is
#
# test_conv_general_dilated_tf_conversion_path_3d_lhs=float32[1,4,28,28,1]_rhs=float32[2,3,3,1,16]_windowstrides=(1,1,1)_padding=VALID_lhsdilation=(1,1,1)_rhsdilation=(1,1,2)_dimensionnumbers=('NDHWC','DHWIO','NDHWC')_featuregroupcount=1_batchgroupcount=1_precision=None
#
# with the following assertion error in TensorFlowTrace.process_primitive:
#
# AssertionError: conv_general_dilated: out.aval = ShapedArray(float32[1,3,24,26,16]); expected ShapedArray(float32[1,3,26,24,16])
#
# Deactivating this assertion is enough to pass the test, which suggests that
# the end shape is indeed the correct one (i.e. (1,3,26,24,16)). Further
# investigation is required to really understand this behavior, which we have
# not managed to reproduce as a pure TF test.
ENABLE_TF_CONVOLUTION = False

def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
                          rhs_dilation, dimension_numbers, feature_group_count,
                          batch_group_count, lhs_shape, rhs_shape, precision):
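The padding recovery in convert_padding above can be illustrated in isolation. A minimal sketch (not part of the commit): since TF only accepts the strings 'VALID' and 'SAME', the explicit padding list is compared against the pads each string would generate for the same shapes and strides.

from jax import lax

explicit_padding = [(1, 1)]  # one spatial dimension, padded by 1 on each side

# Input spatial shape 28, effective filter width 3, stride 1.
for pad_str in ["VALID", "SAME"]:
  gen_padding = lax.padtype_to_pads((28,), (3,), (1,), pad_str)
  if list(gen_padding) == list(explicit_padding):
    print("padding maps to", pad_str)  # prints: padding maps to SAME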
@@ -683,6 +805,15 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count, lhs_shape,
      rhs_shape, precision)

  if ENABLE_TF_CONVOLUTION:
    info_or_result = _try_tf_conv(
        lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers, feature_group_count, batch_group_count, out_shape
    )
    if not isinstance(info_or_result, str):
      return info_or_result

  dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
  precision_config_proto = _conv_general_precision_config_proto(precision)
  assert batch_group_count == 1  # TODO(phawkins): implement batch_group_count
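As a side illustration (not part of the commit) of the convert_dim_nums check in the first hunk: for the TF-default layout, JAX's rhs_spec is exactly the [out_channels, in_channels] + spatial permutation that the function accepts, as this small sketch using public lax helpers shows.

from jax import lax

# ConvDimensionNumbers for the TF-default ("NHWC", "HWIO", "NHWC") layout of a
# 2D convolution: rhs_spec holds the (out_channels, in_channels, *spatial) indices.
dnums = lax.conv_dimension_numbers((1, 28, 28, 1), (3, 3, 1, 16),
                                   ("NHWC", "HWIO", "NHWC"))
print(dnums.rhs_spec)  # (3, 2, 0, 1)

nb_spatial_dimensions = 2
supported_rhs_shape = ([nb_spatial_dimensions + 1, nb_spatial_dimensions] +
                       list(range(nb_spatial_dimensions)))
print(list(dnums.rhs_spec) == supported_rhs_shape)  # True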
@@ -989,4 +989,80 @@ lax_conv_general_dilated = tuple( # Validate dtypes and precision
    (("NHWC", "HWIO", "NHWC"), (2, 9, 10, 3), (4, 5, 3, 3)),  # TF default
    (("NCHW", "HWIO", "NHWC"), (2, 3, 9, 10), (4, 5, 3, 3)),  # custom
  ]
) + tuple(
  _make_conv_harness("tf_conversion_path_1d", lhs_shape=lhs_shape,
                     padding=padding, rhs_shape=rhs_shape,
                     dimension_numbers=dimension_numbers, window_strides=(1,),
                     lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)
  for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1,), (1,)),  # no dilation with "VALID" padding
    ("SAME", (1,), (1,)),  # no dilation with "SAME" padding
    ("VALID", (1,), (2,)),  # dilation only on RHS with "VALID" padding
    ("SAME", (1,), (2,)),  # dilation only on RHS with "SAME" padding
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
  ]
  for dimension_numbers, lhs_shape, rhs_shape in [
    (("NWC", "WIO", "NWC"), (1, 28, 1), (3, 1, 16)),  # TF default
    # TODO(bchetioui): the NCW data format is not supported on CPU for TF
    # for now. That path is thus disabled to allow the code to use XLA instead.
  ]
) + tuple(
  _make_conv_harness("tf_conversion_path_2d", lhs_shape=lhs_shape,
                     padding=padding, rhs_shape=rhs_shape,
                     dimension_numbers=dimension_numbers, window_strides=(1, 1),
                     lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation)
  for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1, 1), (1, 1)),  # no dilation with "VALID" padding
    ("SAME", (1, 1), (1, 1)),  # no dilation with "SAME" padding
    ("VALID", (1, 1), (1, 2)),  # dilation only on RHS with "VALID" padding
    ("SAME", (1, 1), (1, 2)),  # dilation only on RHS with "SAME" padding
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
  ]
  for dimension_numbers, lhs_shape, rhs_shape in [
    (("NHWC", "HWIO", "NHWC"), (1, 28, 28, 1), (3, 3, 1, 16)),  # TF default
    # TODO(bchetioui): the NCHW data format is not supported on CPU for TF
    # for now. That path is thus disabled to allow the code to use XLA instead.
  ]
) + tuple(
  _make_conv_harness("tf_conversion_path_3d", lhs_shape=lhs_shape,
                     padding=padding, rhs_shape=rhs_shape,
                     dimension_numbers=dimension_numbers,
                     window_strides=(1, 1, 1), lhs_dilation=lhs_dilation,
                     rhs_dilation=rhs_dilation)
  for padding, lhs_dilation, rhs_dilation in [
    ("VALID", (1, 1, 1), (1, 1, 1)),  # no dilation with "VALID" padding
    ("SAME", (1, 1, 1), (1, 1, 1)),  # no dilation with "SAME" padding
    ("VALID", (1, 1, 1), (1, 1, 2)),  # dilation only on RHS with "VALID" padding
    ("SAME", (1, 1, 1), (1, 1, 2)),  # dilation only on RHS with "SAME" padding
    # TODO(bchetioui): LHS dilation with string padding can never be done using
    # TF convolution functions for now.
  ]
  for dimension_numbers, lhs_shape, rhs_shape in [
    # TF default
    (("NDHWC", "DHWIO", "NDHWC"), (1, 4, 28, 28, 1), (2, 3, 3, 1, 16)),
    # TODO(bchetioui): the NCDHW data format is not supported on CPU for TF
    # for now. That path is thus disabled to allow the code to use XLA instead.
  ]
) + tuple(
  # tf.nn.convolution only supports a subset of the possible dtypes for JAX
  # convolutions (float16, float32, float64). With the below tests, we ensure
  # that we avoid this branch in cases when it would not succeed.
  _make_conv_harness("tf_conversion_path_dtype", lhs_shape=lhs_shape,
                     padding='VALID', rhs_shape=rhs_shape, dtype=dtype,
                     dimension_numbers=dimension_numbers, lhs_dilation=(1, 1),
                     rhs_dilation=(1, 1))
  for dtype in jtu.dtypes.all_inexact
  for dimension_numbers, lhs_shape, rhs_shape in [
    (("NHWC", "HWIO", "NHWC"), (1, 28, 28, 1), (3, 3, 1, 16)),  # TF default
  ]
) + tuple(
  # Validate that feature_group_count != 1 does not go through tf.nn.convolution
  # as that would result in an exception.
  [_make_conv_harness("tf_avoid_path_feature_group_count", dtype=np.float32,
                      lhs_shape=(1, 28, 28, 32), rhs_shape=(3, 3, 16, 8),
                      padding='VALID', batch_group_count=1, lhs_dilation=(1, 1),
                      rhs_dilation=(1, 1), feature_group_count=2,
                      dimension_numbers=('NHWC', 'HWIO', 'NHWC'))]
)
@@ -614,22 +614,27 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):

   @primitive_harness.parameterized(primitive_harness.lax_conv_general_dilated)
   def test_conv_general_dilated(self, harness: primitive_harness.Harness):
-    if jtu.device_under_test() == "gpu" and harness.params["dtype"] in [np.complex64, np.complex128]:
+    dtype, device = harness.params["dtype"], jtu.device_under_test()
+    if device == "gpu" and dtype in [np.complex64, np.complex128]:
       raise unittest.SkipTest("TODO: crash on GPU in TF")

     tol = None
-    if jtu.device_under_test() == "gpu":
+    if device == "gpu":
       tol = 1e-4
-    elif jtu.device_under_test() == "tpu":
+    elif device == "tpu":
       tol = 1e-3
     # TODO(bchetioui): significant discrepancies in some float16 cases.
-    if harness.params["dtype"] == np.float16:
+    if dtype == np.float16:
       tol = 1.
     # TODO(bchetioui): slight occasional discrepancy in float32 cases.
-    elif harness.params["dtype"] == np.float32:
-      tol = 0.5 if jtu.device_under_test() == "tpu" else 1e-4
-    elif harness.params["dtype"] == np.complex64 and jtu.device_under_test() == "tpu":
+    elif dtype == np.float32:
+      tol = 0.5 if device == "tpu" else 1e-4
+    elif dtype == np.complex64 and device == "tpu":
       tol = 0.1
+    # TODO(bchetioui): slight discrepancy when going through the path using
+    # tf.nn.convolution.
+    elif dtype == np.float64 and device == "cpu":
+      tol = 1e-13
     self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                            atol=tol, rtol=tol)