mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add support for 0d transpose convolution (#3643)
* Allow 0d transpose convolution * Add test for 0d conv transpose * remove whitespace
This commit is contained in:
parent
d10cf0e38f
commit
4442c333ce
@ -1487,12 +1487,14 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
|
||||
Transposed N-d convolution, with output padding following the conventions of
|
||||
keras.layers.Conv2DTranspose.
|
||||
"""
|
||||
assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) > 2
|
||||
assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2
|
||||
ndims = len(lhs.shape)
|
||||
one = (1,) * (ndims - 2)
|
||||
# Set dimensional layout defaults if not specified.
|
||||
if dimension_numbers is None:
|
||||
if ndims == 3:
|
||||
if ndims == 2:
|
||||
dimension_numbers = ('NC', 'IO', 'NC')
|
||||
elif ndims == 3:
|
||||
dimension_numbers = ('NHC', 'HIO', 'NHC')
|
||||
elif ndims == 4:
|
||||
dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
|
||||
|
@ -662,6 +662,43 @@ class LaxTest(jtu.JaxTestCase):
|
||||
# NB: below just checks for agreement, we're not calling numpy.
|
||||
self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name":
|
||||
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}_rhs_dilation={}".format(
|
||||
jtu.format_shape_dtype_string(lhs_shape, dtype),
|
||||
jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, rhs_dilation),
|
||||
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
|
||||
"strides": strides, "padding": padding, "rhs_dilation": rhs_dilation,
|
||||
"rng_factory": rng_factory, 'dspec': dspec}
|
||||
for lhs_shape, rhs_shape in [
|
||||
((b, i), (i, j))
|
||||
for b, i, j in itertools.product([2,3],[2,3],[2,3])]
|
||||
for dtype in float_dtypes
|
||||
for strides in [()]
|
||||
for padding in ["VALID", "SAME"]
|
||||
for dspec in [('NC', 'IO', 'NC'),]
|
||||
for rhs_dilation in [None, ()]
|
||||
for rng_factory in [jtu.rand_small]))
|
||||
def testConvTranspose0D(self, lhs_shape, rhs_shape, dtype, strides,
|
||||
padding, dspec, rhs_dilation, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
|
||||
|
||||
def fun(lhs, rhs):
|
||||
return lax.conv_transpose(lhs, rhs, strides, padding,
|
||||
dimension_numbers=dspec,
|
||||
rhs_dilation=rhs_dilation,
|
||||
transpose_kernel=False)
|
||||
|
||||
def fun_via_grad(lhs, rhs):
|
||||
rhs_t = self._transpose_conv_kernel(lhs, rhs, dimension_numbers=dspec)
|
||||
return self._conv_transpose_via_grad(lhs, rhs_t, strides, padding,
|
||||
rhs_dilation=rhs_dilation,
|
||||
dimension_numbers=dspec)
|
||||
|
||||
# NB: below just checks for agreement, we're not calling numpy.
|
||||
self._CheckAgainstNumpy(fun, fun_via_grad, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_lhs_shape={}_rhs_shape={}_precision={}".format(
|
||||
jtu.format_shape_dtype_string(lhs_shape, dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user