mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fixes padding generation for padding == 'SAME' in reduce_window to (#4110)
* Fixes padding generation for padding == 'SAME' in reduce_window to take window_dilation into account. (Fixes google/jax#3973). This commit applies the fix suggested by James on the issue, which is backed by the meaning of padding described on https://www.tensorflow.org/xla/operation_semantics#reducewindow. * Added shape tests for reduce_window when stride is 1 in each direction and padding is 'SAME'.
This commit is contained in:
parent
1e8ac24863
commit
1e6b809818
@ -1143,7 +1143,9 @@ def reduce_window(operand: Array, init_value: Array, computation: Callable,
|
||||
operator.
|
||||
"""
|
||||
if isinstance(padding, str):
|
||||
padding = tuple(padtype_to_pads(operand.shape, window_dimensions,
|
||||
dilated_window_dims = (window_dimensions if window_dilation is None else
|
||||
_dilate_shape(window_dimensions, window_dilation))
|
||||
padding = tuple(padtype_to_pads(operand.shape, dilated_window_dims,
|
||||
window_strides, padding))
|
||||
else:
|
||||
padding = tuple(padding)
|
||||
|
@ -1358,6 +1358,35 @@ class LaxTest(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": (f"_shape={shape}_windowdimensions={window_dimensions}"
|
||||
f"_basedilation={base_dilation}_windowdilation="
|
||||
f"{window_dilation}"),
|
||||
"shape": shape, "window_dimensions": window_dimensions,
|
||||
"base_dilation": base_dilation, "window_dilation": window_dilation}
|
||||
for shape, window_dimensions, base_dilation, window_dilation in (
|
||||
itertools.chain(
|
||||
itertools.product(
|
||||
[(4, 6)],
|
||||
[(1, 1), (3, 4)],
|
||||
[(1, 1), (1, 2), (2, 13), (40, 60)],
|
||||
[(1, 1), (1, 2), (2, 13), (40, 60)]),
|
||||
itertools.product(
|
||||
[(3, 2, 4, 6)],
|
||||
[(1, 1, 1, 1), (2, 1, 2, 1)],
|
||||
[(1, 1, 1, 1), (1, 2, 2, 1), (30, 40, 3, 2)],
|
||||
[(1, 1, 1, 1), (1, 2, 2, 1), (30, 40, 3, 2)])))))
|
||||
def testReduceWindowShapeDilation(self, shape, window_dimensions,
|
||||
base_dilation, window_dilation):
|
||||
operand, padding, strides = np.ones(shape), 'SAME', (1,) * len(shape)
|
||||
result = lax.reduce_window(operand, 0., lax.add, padding=padding,
|
||||
window_strides=strides,
|
||||
window_dimensions=window_dimensions)
|
||||
# With a stride of 1 in each direction and a padding of 'SAME', the
|
||||
# shape of the input should be equal to the shape of the result according
|
||||
# to https://www.tensorflow.org/xla/operation_semantics#reducewindow.
|
||||
self.assertEqual(shape, result.shape)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_shape={}_axis={}"
|
||||
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
|
||||
|
Loading…
x
Reference in New Issue
Block a user