Fix failures in lax.reduce_window on scalar inputs.

Fixes https://github.com/google/jax/issues/10834

PiperOrigin-RevId: 452327212
This commit is contained in:
Peter Hawkins 2022-06-01 10:23:42 -07:00 committed by jax authors
parent e225317ff8
commit 7306694a09
2 changed files with 29 additions and 5 deletions

View File

@ -304,7 +304,8 @@ def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
window_strides=mlir.dense_int_elements(window_strides),
base_dilations=mlir.dense_int_elements(base_dilation),
window_dilations=mlir.dense_int_elements(window_dilation),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=(len(padding), 2)))
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
with ir.InsertionPoint(reducer):
if jaxpr.effects:
@ -417,7 +418,7 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
operand_shape = lax._dilate_shape(operand_shape, base_dilation)
if window_dilation is not None:
window_dimensions = lax._dilate_shape(window_dimensions, window_dilation)
pads_lo, pads_hi = zip(*padding)
pads_lo, pads_hi = [(), ()] if len(padding) == 0 else zip(*padding)
operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
return core.stride_shape(operand_padded, window_dimensions, window_strides)
@ -453,7 +454,8 @@ def _reduce_window_lower(
window_strides=mlir.dense_int_elements(window_strides),
base_dilations=mlir.dense_int_elements(base_dilation),
window_dilations=mlir.dense_int_elements(window_dilation),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=(len(padding), 2)))
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer):
mhlo.ReturnOp(reduce_op(*reducer.arguments))
@ -498,7 +500,8 @@ def _select_and_scatter_lower(
init_value,
window_dimensions=mlir.dense_int_elements(window_dimensions),
window_strides=mlir.dense_int_elements(window_strides),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=(len(padding), 2)))
select = op.select.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(select):
if select_jaxpr.effects:
@ -743,7 +746,8 @@ def _select_and_gather_add_lowering(
window_strides=mlir.dense_int_elements(window_strides),
base_dilations=mlir.dense_int_elements(base_dilation),
window_dilations=mlir.dense_int_elements(window_dilation),
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
shape=(len(padding), 2)))
scalar_type = ir.RankedTensorType.get([], double_word_type)
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer):

View File

@ -1756,6 +1756,26 @@ class LaxTest(jtu.JaxTestCase):
out_jit = jax.jit(fun)(arr, init)
self.assertEqual(dtypes.is_weakly_typed(out_jit), arr_weak_type and init_weak_type)
def testReduceWindowScalar(self):
rng = jtu.rand_small(self.rng())
dtype = jnp.float32
init_val = np.asarray(0, dtype=dtype)
op = lax.add
def fun(operand, init_val):
return lax.reduce_window(
operand, init_val, op, window_dimensions=(), window_strides=(),
padding=(), base_dilation=(), window_dilation=())
def reference_fun(operand, init_val):
return lax_reference.reduce_window(
operand, init_val, op, window_dimensions=(), window_strides=(),
padding=(), base_dilation=())
args_maker = lambda: [rng((), dtype), init_val]
self._CompileAndCheck(fun, args_maker)
self._CheckAgainstNumpy(reference_fun, fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
"_basedilation={}_windowdilation={}")