Fixes padding generation for padding == 'SAME' in reduce_window to (#4110)

* Fixes padding generation for padding == 'SAME' in reduce_window to
take window_dilation into account. (Fixes google/jax#3973).

This commit applies the fix suggested by James on the issue,
which is backed by the meaning of padding described on
https://www.tensorflow.org/xla/operation_semantics#reducewindow.

* Added shape tests for reduce_window when stride is 1 in each
direction and padding is 'SAME'.
This commit is contained in:
Benjamin Chetioui 2020-08-20 20:45:15 +02:00 committed by GitHub
parent 1e8ac24863
commit 1e6b809818
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 1 deletions

View File

@ -1143,7 +1143,9 @@ def reduce_window(operand: Array, init_value: Array, computation: Callable,
operator.
"""
if isinstance(padding, str):
padding = tuple(padtype_to_pads(operand.shape, window_dimensions,
dilated_window_dims = (window_dimensions if window_dilation is None else
_dilate_shape(window_dimensions, window_dilation))
padding = tuple(padtype_to_pads(operand.shape, dilated_window_dims,
window_strides, padding))
else:
padding = tuple(padding)

View File

@ -1358,6 +1358,35 @@ class LaxTest(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": (f"_shape={shape}_windowdimensions={window_dimensions}"
f"_basedilation={base_dilation}_windowdilation="
f"{window_dilation}"),
"shape": shape, "window_dimensions": window_dimensions,
"base_dilation": base_dilation, "window_dilation": window_dilation}
for shape, window_dimensions, base_dilation, window_dilation in (
itertools.chain(
itertools.product(
[(4, 6)],
[(1, 1), (3, 4)],
[(1, 1), (1, 2), (2, 13), (40, 60)],
[(1, 1), (1, 2), (2, 13), (40, 60)]),
itertools.product(
[(3, 2, 4, 6)],
[(1, 1, 1, 1), (2, 1, 2, 1)],
[(1, 1, 1, 1), (1, 2, 2, 1), (30, 40, 3, 2)],
[(1, 1, 1, 1), (1, 2, 2, 1), (30, 40, 3, 2)])))))
def testReduceWindowShapeDilation(self, shape, window_dimensions,
base_dilation, window_dilation):
operand, padding, strides = np.ones(shape), 'SAME', (1,) * len(shape)
result = lax.reduce_window(operand, 0., lax.add, padding=padding,
window_strides=strides,
window_dimensions=window_dimensions)
# With a stride of 1 in each direction and a padding of 'SAME', the
# shape of the input should be equal to the shape of the result according
# to https://www.tensorflow.org/xla/operation_semantics#reducewindow.
self.assertEqual(shape, result.shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_shape={}_axis={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),