[jax2tf] Reduce number of tests for select_and_gather_add.

This commit is contained in:
Benjamin Chetioui 2020-10-30 09:53:28 +01:00
parent 2e5335621b
commit f2ff176c27

View File

@ -924,58 +924,51 @@ lax_select_and_scatter_add = tuple( # Validate dtypes
]
)
lax_select_and_gather_add = tuple(
# Tests with 2d shapes (see tests.lax_autodiff_test.testReduceWindowGrad)
Harness(f"2d_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",
lax._select_and_gather_add,
[RandArg(shape, dtype), RandArg(shape, dtype), StaticArg(select_prim),
StaticArg(window_dimensions), StaticArg(window_strides),
StaticArg(padding), StaticArg(base_dilation),
StaticArg(window_dilation)],
shape=shape,
dtype=dtype,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation)
def _make_select_and_gather_add_harness(
name, *, shape=(4, 6), dtype=np.float32, select_prim=lax.le_p,
padding='VALID', window_dimensions=(2, 2), window_strides=(1, 1),
base_dilation=(1, 1), window_dilation=(1, 1)):
if isinstance(padding, str):
padding = tuple(lax.padtype_to_pads(shape, window_dimensions,
window_strides, padding))
return Harness(f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",
lax._select_and_gather_add,
[RandArg(shape, dtype), RandArg(shape, dtype),
StaticArg(select_prim), StaticArg(window_dimensions),
StaticArg(window_strides), StaticArg(padding),
StaticArg(base_dilation), StaticArg(window_dilation)],
shape=shape,
dtype=dtype,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation)
lax_select_and_gather_add = tuple( # Validate dtypes
_make_select_and_gather_add_harness("dtypes", dtype=dtype)
for dtype in jtu.dtypes.all_floating
for shape in [(4, 6)]
for select_prim in [lax.le_p, lax.ge_p]
for window_dimensions in [(2, 1), (1, 2)]
for window_strides in [(1, 1), (2, 1), (1, 2)]
for padding in tuple(set([tuple(lax.padtype_to_pads(shape, window_dimensions,
window_strides, p))
for p in ['VALID', 'SAME']] +
[((0, 3), (1, 2))]))
for base_dilation in [(1, 1)]
for window_dilation in [(1, 1)]
) + tuple(
# Tests with 4d shapes (see tests.lax_autodiff_test.testReduceWindowGrad)
Harness(f"4d_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",
lax._select_and_gather_add,
[RandArg(shape, dtype), RandArg(shape, dtype), StaticArg(select_prim),
StaticArg(window_dimensions), StaticArg(window_strides),
StaticArg(padding), StaticArg(base_dilation),
StaticArg(window_dilation)],
shape=shape,
dtype=dtype,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding=padding,
base_dilation=base_dilation,
window_dilation=window_dilation)
for dtype in jtu.dtypes.all_floating
for shape in [(3, 2, 4, 6)]
for select_prim in [lax.le_p, lax.ge_p]
for window_dimensions in [(1, 1, 2, 1), (2, 1, 2, 1)]
for window_strides in [(1, 2, 2, 1), (1, 1, 1, 1)]
for padding in tuple(set([tuple(lax.padtype_to_pads(shape, window_dimensions,
window_strides, p))
for p in ['VALID', 'SAME']] +
[((0, 1), (1, 0), (2, 3), (0, 2))]))
for base_dilation in [(1, 1, 1, 1)]
for window_dilation in [(1, 1, 1, 1)]
) + tuple( # Validate selection primitives
[_make_select_and_gather_add_harness("select_prim", select_prim=lax.ge_p)]
) + tuple( # Validate window dimensions
_make_select_and_gather_add_harness("window_dimensions",
window_dimensions=window_dimensions)
for window_dimensions in [(2, 3)]
) + tuple( # Validate window strides
_make_select_and_gather_add_harness("window_strides",
window_strides=window_strides)
for window_strides in [(2, 3)]
) + tuple( # Validate padding
_make_select_and_gather_add_harness("padding", padding=padding)
for padding in ['SAME']
) + tuple( # Validate dilations
_make_select_and_gather_add_harness("dilations", base_dilation=base_dilation,
window_dilation=window_dilation)
for base_dilation, window_dilation in [
((2, 3), (1, 1)), # base dilation, no window dilation
((1, 1), (2, 3)), # no base dilation, window dilation
((2, 3), (3, 2)) # base dilation, window dilation
]
)
def _make_reduce_window_harness(name, *, shape=(4, 6), base_dilation=(1, 1),