fixup from 5415306: remove extraneous lines

also add test
This commit is contained in:
Matthew Johnson 2022-03-11 11:36:13 -08:00
parent ee6749608a
commit 39c2f8b051
2 changed files with 32 additions and 2 deletions

View File

@ -278,8 +278,6 @@ def _generic_reduce_window_batch_rule(
operands, init_values = util.split_list(batched_args, [num_operands])
operand_bdims, init_value_bdims = util.split_list(batch_dims, [num_operands])
operand, init = batched_args
bdim, init_bdim = batch_dims
if any(init_bdim is not None for init_bdim in init_value_bdims):
raise NotImplementedError("reduce_window batching is not implemented for "
"initial values")

View File

@ -24,6 +24,7 @@ from absl.testing import parameterized
import numpy as np
import jax
import jax.numpy as jnp
from jax import dtypes
from jax import lax
@ -790,6 +791,37 @@ class LaxVmapTest(jtu.JaxTestCase):
# TODO Collapse
# TODO Scatter
# TODO(b/183233858): variadic reduce-window is not implemented on XLA:GPU
@jtu.skip_on_devices("gpu")
def test_variadic_reduce_window(self):
# https://github.com/google/jax/discussions/9818 and
# https://github.com/google/jax/issues/9837
def normpool(x):
norms = jnp.linalg.norm(x, axis=-1)
idxs = jnp.arange(x.shape[0])
def g(a, b):
an, ai = a
bn, bi = b
which = an >= bn
return (jnp.where(which, an, bn), jnp.where(which, ai, bi))
_, idxs = lax.reduce_window((norms, idxs), (-np.inf, -1), g,
window_dimensions=(2,), window_strides=(2,),
padding=((0, 0),))
return x[idxs]
inpt = jnp.array([
[1.0, 0.0, 1.0],
[2.0, 2.0, 0.0],
[3.0, 0.0, 1.0],
[0.0, 1.0, 1.0],
])
output = jax.vmap(normpool)(inpt[None, ...]) # doesn't crash
expected = jnp.array([[[2.0, 2.0, 0.0], [3.0, 0.0, 1.0]]])
self.assertAllClose(output, expected, check_dtypes=False)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())