mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
fixup from 5415306: remove extraneous lines
also add test
This commit is contained in:
parent
ee6749608a
commit
39c2f8b051
@ -278,8 +278,6 @@ def _generic_reduce_window_batch_rule(
|
||||
operands, init_values = util.split_list(batched_args, [num_operands])
|
||||
operand_bdims, init_value_bdims = util.split_list(batch_dims, [num_operands])
|
||||
|
||||
operand, init = batched_args
|
||||
bdim, init_bdim = batch_dims
|
||||
if any(init_bdim is not None for init_bdim in init_value_bdims):
|
||||
raise NotImplementedError("reduce_window batching is not implemented for "
|
||||
"initial values")
|
||||
|
@ -24,6 +24,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
|
||||
@ -790,6 +791,37 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
# TODO Collapse
|
||||
# TODO Scatter
|
||||
|
||||
# TODO(b/183233858): variadic reduce-window is not implemented on XLA:GPU
|
||||
@jtu.skip_on_devices("gpu")
|
||||
def test_variadic_reduce_window(self):
|
||||
# https://github.com/google/jax/discussions/9818 and
|
||||
# https://github.com/google/jax/issues/9837
|
||||
def normpool(x):
|
||||
norms = jnp.linalg.norm(x, axis=-1)
|
||||
idxs = jnp.arange(x.shape[0])
|
||||
|
||||
def g(a, b):
|
||||
an, ai = a
|
||||
bn, bi = b
|
||||
which = an >= bn
|
||||
return (jnp.where(which, an, bn), jnp.where(which, ai, bi))
|
||||
|
||||
_, idxs = lax.reduce_window((norms, idxs), (-np.inf, -1), g,
|
||||
window_dimensions=(2,), window_strides=(2,),
|
||||
padding=((0, 0),))
|
||||
return x[idxs]
|
||||
|
||||
|
||||
inpt = jnp.array([
|
||||
[1.0, 0.0, 1.0],
|
||||
[2.0, 2.0, 0.0],
|
||||
[3.0, 0.0, 1.0],
|
||||
[0.0, 1.0, 1.0],
|
||||
])
|
||||
output = jax.vmap(normpool)(inpt[None, ...]) # doesn't crash
|
||||
expected = jnp.array([[[2.0, 2.0, 0.0], [3.0, 0.0, 1.0]]])
|
||||
self.assertAllClose(output, expected, check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user