Allow ConvDimensionNumbers to be passed into conv_transpose (#2915)

This commit is contained in:
tamaranorman 2020-05-04 19:02:13 +01:00 committed by GitHub
parent 4d236b5c47
commit 04102e5b9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -520,10 +520,7 @@ def conv_general_dilated(
(for a 2D convolution).
"""
dnums: ConvDimensionNumbers
if isinstance(dimension_numbers, ConvDimensionNumbers):
dnums = dimension_numbers
else:
dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
if lhs_dilation is None:
lhs_dilation = (1,) * (lhs.ndim - 2)
elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1):
@ -5055,13 +5052,16 @@ def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers):
Args:
lhs_shape: tuple of nonnegative integers, shape of the convolution input.
rhs_shape: tuple of nonnegative integers, shape of the convolution kernel.
dimension_numbers: None or a tuple/list of strings, following the
convolution dimension number specification format in xla_client.py.
dimension_numbers: None or a tuple/list of strings or a ConvDimensionNumbers
object following the convolution dimension number specification format in
xla_client.py.
Returns:
A `ConvDimensionNumbers` object that represents `dimension_numbers` in the
canonical form used by lax functions.
"""
if isinstance(dimension_numbers, ConvDimensionNumbers):
return dimension_numbers
if len(lhs_shape) != len(rhs_shape):
msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}."
raise TypeError(msg.format(len(lhs_shape), len(rhs_shape)))