mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Reduce number of tests for select_and_gather_add.
This commit is contained in:
parent
2e5335621b
commit
f2ff176c27
@ -924,58 +924,51 @@ lax_select_and_scatter_add = tuple( # Validate dtypes
|
||||
]
|
||||
)
|
||||
|
||||
lax_select_and_gather_add = tuple(
|
||||
# Tests with 2d shapes (see tests.lax_autodiff_test.testReduceWindowGrad)
|
||||
Harness(f"2d_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",
|
||||
lax._select_and_gather_add,
|
||||
[RandArg(shape, dtype), RandArg(shape, dtype), StaticArg(select_prim),
|
||||
StaticArg(window_dimensions), StaticArg(window_strides),
|
||||
StaticArg(padding), StaticArg(base_dilation),
|
||||
StaticArg(window_dilation)],
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
window_dimensions=window_dimensions,
|
||||
window_strides=window_strides,
|
||||
padding=padding,
|
||||
base_dilation=base_dilation,
|
||||
window_dilation=window_dilation)
|
||||
def _make_select_and_gather_add_harness(
|
||||
name, *, shape=(4, 6), dtype=np.float32, select_prim=lax.le_p,
|
||||
padding='VALID', window_dimensions=(2, 2), window_strides=(1, 1),
|
||||
base_dilation=(1, 1), window_dilation=(1, 1)):
|
||||
if isinstance(padding, str):
|
||||
padding = tuple(lax.padtype_to_pads(shape, window_dimensions,
|
||||
window_strides, padding))
|
||||
return Harness(f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",
|
||||
lax._select_and_gather_add,
|
||||
[RandArg(shape, dtype), RandArg(shape, dtype),
|
||||
StaticArg(select_prim), StaticArg(window_dimensions),
|
||||
StaticArg(window_strides), StaticArg(padding),
|
||||
StaticArg(base_dilation), StaticArg(window_dilation)],
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
window_dimensions=window_dimensions,
|
||||
window_strides=window_strides,
|
||||
padding=padding,
|
||||
base_dilation=base_dilation,
|
||||
window_dilation=window_dilation)
|
||||
|
||||
lax_select_and_gather_add = tuple( # Validate dtypes
|
||||
_make_select_and_gather_add_harness("dtypes", dtype=dtype)
|
||||
for dtype in jtu.dtypes.all_floating
|
||||
for shape in [(4, 6)]
|
||||
for select_prim in [lax.le_p, lax.ge_p]
|
||||
for window_dimensions in [(2, 1), (1, 2)]
|
||||
for window_strides in [(1, 1), (2, 1), (1, 2)]
|
||||
for padding in tuple(set([tuple(lax.padtype_to_pads(shape, window_dimensions,
|
||||
window_strides, p))
|
||||
for p in ['VALID', 'SAME']] +
|
||||
[((0, 3), (1, 2))]))
|
||||
for base_dilation in [(1, 1)]
|
||||
for window_dilation in [(1, 1)]
|
||||
) + tuple(
|
||||
# Tests with 4d shapes (see tests.lax_autodiff_test.testReduceWindowGrad)
|
||||
Harness(f"4d_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",
|
||||
lax._select_and_gather_add,
|
||||
[RandArg(shape, dtype), RandArg(shape, dtype), StaticArg(select_prim),
|
||||
StaticArg(window_dimensions), StaticArg(window_strides),
|
||||
StaticArg(padding), StaticArg(base_dilation),
|
||||
StaticArg(window_dilation)],
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
window_dimensions=window_dimensions,
|
||||
window_strides=window_strides,
|
||||
padding=padding,
|
||||
base_dilation=base_dilation,
|
||||
window_dilation=window_dilation)
|
||||
for dtype in jtu.dtypes.all_floating
|
||||
for shape in [(3, 2, 4, 6)]
|
||||
for select_prim in [lax.le_p, lax.ge_p]
|
||||
for window_dimensions in [(1, 1, 2, 1), (2, 1, 2, 1)]
|
||||
for window_strides in [(1, 2, 2, 1), (1, 1, 1, 1)]
|
||||
for padding in tuple(set([tuple(lax.padtype_to_pads(shape, window_dimensions,
|
||||
window_strides, p))
|
||||
for p in ['VALID', 'SAME']] +
|
||||
[((0, 1), (1, 0), (2, 3), (0, 2))]))
|
||||
for base_dilation in [(1, 1, 1, 1)]
|
||||
for window_dilation in [(1, 1, 1, 1)]
|
||||
) + tuple( # Validate selection primitives
|
||||
[_make_select_and_gather_add_harness("select_prim", select_prim=lax.ge_p)]
|
||||
) + tuple( # Validate window dimensions
|
||||
_make_select_and_gather_add_harness("window_dimensions",
|
||||
window_dimensions=window_dimensions)
|
||||
for window_dimensions in [(2, 3)]
|
||||
) + tuple( # Validate window strides
|
||||
_make_select_and_gather_add_harness("window_strides",
|
||||
window_strides=window_strides)
|
||||
for window_strides in [(2, 3)]
|
||||
) + tuple( # Validate padding
|
||||
_make_select_and_gather_add_harness("padding", padding=padding)
|
||||
for padding in ['SAME']
|
||||
) + tuple( # Validate dilations
|
||||
_make_select_and_gather_add_harness("dilations", base_dilation=base_dilation,
|
||||
window_dilation=window_dilation)
|
||||
for base_dilation, window_dilation in [
|
||||
((2, 3), (1, 1)), # base dilation, no window dilation
|
||||
((1, 1), (2, 3)), # no base dilation, window dilation
|
||||
((2, 3), (3, 2)) # base dilation, window dilation
|
||||
]
|
||||
)
|
||||
|
||||
def _make_reduce_window_harness(name, *, shape=(4, 6), base_dilation=(1, 1),
|
||||
|
Loading…
x
Reference in New Issue
Block a user