mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Allow ConvDimensionNumbers to be passed into conv_transpose (#2915)
This commit is contained in:
parent
4d236b5c47
commit
04102e5b9d
@ -520,10 +520,7 @@ def conv_general_dilated(
|
||||
(for a 2D convolution).
|
||||
"""
|
||||
dnums: ConvDimensionNumbers
|
||||
if isinstance(dimension_numbers, ConvDimensionNumbers):
|
||||
dnums = dimension_numbers
|
||||
else:
|
||||
dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
|
||||
dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
|
||||
if lhs_dilation is None:
|
||||
lhs_dilation = (1,) * (lhs.ndim - 2)
|
||||
elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1):
|
||||
@ -5055,13 +5052,16 @@ def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers):
|
||||
Args:
|
||||
lhs_shape: tuple of nonnegative integers, shape of the convolution input.
|
||||
rhs_shape: tuple of nonnegative integers, shape of the convolution kernel.
|
||||
dimension_numbers: None or a tuple/list of strings, following the
|
||||
convolution dimension number specification format in xla_client.py.
|
||||
dimension_numbers: None or a tuple/list of strings or a ConvDimensionNumbers
|
||||
object following the convolution dimension number specification format in
|
||||
xla_client.py.
|
||||
|
||||
Returns:
|
||||
A `ConvDimensionNumbers` object that represents `dimension_numbers` in the
|
||||
canonical form used by lax functions.
|
||||
"""
|
||||
if isinstance(dimension_numbers, ConvDimensionNumbers):
|
||||
return dimension_numbers
|
||||
if len(lhs_shape) != len(rhs_shape):
|
||||
msg = "convolution requires lhs and rhs ndim to be equal, got {} and {}."
|
||||
raise TypeError(msg.format(len(lhs_shape), len(rhs_shape)))
|
||||
|
Loading…
x
Reference in New Issue
Block a user