mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve reduce-window testing.
This commit is contained in:
parent
7b3eff82b7
commit
7e3433e2d3
@ -289,7 +289,10 @@ def reduce(operand, init_value, computation, dimensions): # pylint: disable=red
|
||||
def reduce_window(operand, init_value, computation, window_dimensions,
|
||||
window_strides, padding):
|
||||
op, dims, strides = operand, window_dimensions, window_strides
|
||||
pads = padtype_to_pads(op.shape, dims, strides, padding)
|
||||
if isinstance(padding, str):
|
||||
pads = padtype_to_pads(op.shape, dims, strides, padding)
|
||||
else:
|
||||
pads = padding
|
||||
view = _conv_view(op.reshape((1, 1) + op.shape), (1, 1) + dims, strides, pads,
|
||||
pad_value=init_value)[0]
|
||||
view = view.reshape(view.shape[1:1+len(dims)] + (-1,))
|
||||
|
@ -1302,9 +1302,9 @@ class LaxTest(jtu.JaxTestCase):
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_op={}_dtype={}_padding={}"
|
||||
.format(op.__name__, np.dtype(dtype).name, padding),
|
||||
"op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
|
||||
{"testcase_name": "_op={}_dtype={}"
|
||||
.format(op.__name__, np.dtype(dtype).name,),
|
||||
"op": op, "init_val": init_val, "dtype": dtype,
|
||||
"rng_factory": rng_factory}
|
||||
for init_val, op, dtypes in [
|
||||
(0, lax.add, [np.float32]),
|
||||
@ -1312,28 +1312,34 @@ class LaxTest(jtu.JaxTestCase):
|
||||
(np.inf, lax.min, [np.float32]),
|
||||
]
|
||||
for dtype in dtypes
|
||||
for padding in ["VALID", "SAME"]
|
||||
for rng_factory in [jtu.rand_small]))
|
||||
def testReduceWindow(self, op, init_val, dtype, padding, rng_factory):
|
||||
def testReduceWindow(self, op, init_val, dtype, rng_factory):
|
||||
rng = rng_factory(self.rng())
|
||||
init_val = np.asarray(init_val, dtype=dtype)
|
||||
|
||||
all_configs = itertools.chain(
|
||||
all_configs = list(itertools.chain(
|
||||
itertools.product(
|
||||
[(4, 6)],
|
||||
[(2, 1), (1, 2)],
|
||||
[(1, 1), (2, 1), (1, 2)]),
|
||||
[(1, 1), (2, 1), (1, 2)],
|
||||
["VALID", "SAME", [(0, 3), (1, 2)]]),
|
||||
itertools.product(
|
||||
[(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
|
||||
[(1, 2, 2, 1), (1, 1, 1, 1)]))
|
||||
[(1, 2, 2, 1), (1, 1, 1, 1)],
|
||||
["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]])))
|
||||
|
||||
def fun(operand, init_val):
|
||||
return lax.reduce_window(operand, init_val, op, dims, strides, padding)
|
||||
|
||||
def reference_fun(operand, init_val):
|
||||
return lax_reference.reduce_window(operand, init_val, op, dims, strides,
|
||||
padding)
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
for shape, dims, strides in all_configs:
|
||||
for shape, dims, strides, padding in all_configs:
|
||||
args_maker = lambda: [rng(shape, dtype), init_val]
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
self._CheckAgainstNumpy(fun, reference_fun, args_maker)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
# we separately test the version that uses a concrete init_val because it
|
||||
@ -1342,7 +1348,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
return lax.reduce_window(operand, init_val, op, dims, strides, padding)
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
for shape, dims, strides in all_configs:
|
||||
for shape, dims, strides, padding in all_configs:
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
Loading…
x
Reference in New Issue
Block a user