Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00.

Commit: Fix failures in lax.reduce_window on scalar inputs.
Fixes: https://github.com/google/jax/issues/10834
PiperOrigin-RevId: 452327212
This commit is contained in:
parent
e225317ff8
commit
7306694a09
@ -304,7 +304,8 @@ def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
|
||||
window_strides=mlir.dense_int_elements(window_strides),
|
||||
base_dilations=mlir.dense_int_elements(base_dilation),
|
||||
window_dilations=mlir.dense_int_elements(window_dilation),
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
|
||||
shape=(len(padding), 2)))
|
||||
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
|
||||
with ir.InsertionPoint(reducer):
|
||||
if jaxpr.effects:
|
||||
@ -417,7 +418,7 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
|
||||
operand_shape = lax._dilate_shape(operand_shape, base_dilation)
|
||||
if window_dilation is not None:
|
||||
window_dimensions = lax._dilate_shape(window_dimensions, window_dilation)
|
||||
pads_lo, pads_hi = zip(*padding)
|
||||
pads_lo, pads_hi = [(), ()] if len(padding) == 0 else zip(*padding)
|
||||
operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
|
||||
return core.stride_shape(operand_padded, window_dimensions, window_strides)
|
||||
|
||||
@ -453,7 +454,8 @@ def _reduce_window_lower(
|
||||
window_strides=mlir.dense_int_elements(window_strides),
|
||||
base_dilations=mlir.dense_int_elements(base_dilation),
|
||||
window_dilations=mlir.dense_int_elements(window_dilation),
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
|
||||
shape=(len(padding), 2)))
|
||||
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer):
|
||||
mhlo.ReturnOp(reduce_op(*reducer.arguments))
|
||||
@ -498,7 +500,8 @@ def _select_and_scatter_lower(
|
||||
init_value,
|
||||
window_dimensions=mlir.dense_int_elements(window_dimensions),
|
||||
window_strides=mlir.dense_int_elements(window_strides),
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
|
||||
shape=(len(padding), 2)))
|
||||
select = op.select.blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(select):
|
||||
if select_jaxpr.effects:
|
||||
@ -743,7 +746,8 @@ def _select_and_gather_add_lowering(
|
||||
window_strides=mlir.dense_int_elements(window_strides),
|
||||
base_dilations=mlir.dense_int_elements(base_dilation),
|
||||
window_dilations=mlir.dense_int_elements(window_dilation),
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64)))
|
||||
padding=ir.DenseIntElementsAttr.get(np.asarray(padding, np.int64),
|
||||
shape=(len(padding), 2)))
|
||||
scalar_type = ir.RankedTensorType.get([], double_word_type)
|
||||
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer):
|
||||
|
@ -1756,6 +1756,26 @@ class LaxTest(jtu.JaxTestCase):
|
||||
out_jit = jax.jit(fun)(arr, init)
|
||||
self.assertEqual(dtypes.is_weakly_typed(out_jit), arr_weak_type and init_weak_type)
|
||||
|
||||
def testReduceWindowScalar(self):
  """Regression test: reduce_window on a rank-0 operand with empty window
  parameters must work and agree with the NumPy reference implementation.

  See https://github.com/google/jax/issues/10834.
  """
  dtype = jnp.float32
  reduce_op = lax.add
  zero = np.asarray(0, dtype=dtype)
  sampler = jtu.rand_small(self.rng())

  def jax_fn(operand, init_val):
    # All window parameters are empty tuples: a scalar operand has rank 0,
    # so there are no dimensions to window over.
    return lax.reduce_window(
        operand, init_val, reduce_op,
        window_dimensions=(), window_strides=(), padding=(),
        base_dilation=(), window_dilation=())

  def reference_fn(operand, init_val):
    # The reference implementation takes no window_dilation argument.
    return lax_reference.reduce_window(
        operand, init_val, reduce_op,
        window_dimensions=(), window_strides=(), padding=(),
        base_dilation=())

  make_args = lambda: [sampler((), dtype), zero]
  self._CompileAndCheck(jax_fn, make_args)
  self._CheckAgainstNumpy(reference_fn, jax_fn, make_args)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
|
||||
"_basedilation={}_windowdilation={}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user