Merge pull request #14696 from gnecula:tf_cumsum_poly

PiperOrigin-RevId: 513835766
jax authors 2023-03-03 08:36:11 -08:00
commit 04fc30b58a
4 changed files with 35 additions and 5 deletions

@@ -1825,6 +1825,11 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
   # Check that all inputs have a consistent leading dimension `num_elems`.
   axis = util.canonicalize_axis(axis, elems_flat[0].ndim)
+  if core.is_special_dim_size(elems_flat[0].shape[axis]):
+    raise NotImplementedError("associative scan over axis "
+                              f"of non-constant size: {elems_flat[0].shape[axis]}. You may be "
+                              "able to avoid this on TPU.")
   num_elems = int(elems_flat[0].shape[axis])
   if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
     raise ValueError('Array inputs to associative_scan must have the same '
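
(Illustration, not part of the diff.) A minimal sketch of the situation the new check guards against: a cumulative reduction scanned over a dimension that is shape-polymorphic under jax2tf, mirroring the "jnp.cumsum" reduce_axis=poly harness added further down. The tracing step and the native-lowering setting are assumptions about how the error would surface.

# Hedged sketch: cumsum scans over the polymorphic "b" dimension.  With
# native lowering enabled (and not on TPU), tracing is expected to hit the
# NotImplementedError("associative scan over axis of non-constant size ...")
# raised above.
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def cumsum_over_batch(x):
  return jnp.cumsum(x, axis=0)  # axis 0 is the polymorphic dimension

f_tf = tf.function(
    jax2tf.convert(cumsum_over_batch, polymorphic_shapes=["(b, 5)"]),
    autograph=False)
# Tracing with an unknown leading dimension is where the check can fire.
f_tf.get_concrete_function(tf.TensorSpec([None, 5], tf.float32))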

@@ -467,6 +467,9 @@ def _reduce_window_lower(
   operand_aval, = ctx.avals_in
   scalar_aval = operand_aval.update(shape=())
   scalar_type = mlir.aval_to_ir_type(scalar_aval)
+  if any(not core.is_constant_shape(s)
+         for s in [window_dimensions, window_dilation, window_strides, base_dilation, *padding]):
+    raise NotImplementedError("ReduceWindowOp for dynamic shapes")
   rw = hlo.ReduceWindowOp(
       mlir.aval_to_ir_types(aval_out), [operand],
       [mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)],
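
(Illustration, not part of the diff.) One way a window parameter ends up non-constant, sketched under the assumption that cumulative reductions can be phrased as a full-axis reduce_window: the window size along the scanned axis equals the axis size, so a symbolic axis size makes window_dimensions and padding non-constant and trips the new guard.

# Sketch: cumsum written as a reduce_window whose window spans the whole axis.
# If x.shape[0] were a symbolic (polymorphic) dimension, window_dimensions and
# padding below would no longer be constant shapes.
import numpy as np
import jax.numpy as jnp
from jax import lax

def cumsum_via_reduce_window(x):
  n = x.shape[0]
  return lax.reduce_window(
      x, np.float32(0), lax.add,
      window_dimensions=(n,),
      window_strides=(1,),
      padding=((n - 1, 0),))

print(cumsum_via_reduce_window(jnp.arange(4, dtype=jnp.float32)))  # [0. 1. 3. 6.]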

@@ -2281,15 +2281,20 @@ def _common_reduce_window(operand, init_val, reducer, window_dimensions,
   if not isinstance(init_val, (tf.Tensor, tf.Variable)):
     init_val = tf.constant(init_val, operand.dtype)
+  window_dimensions_tf = _eval_shape(window_dimensions)
+  window_strides_tf = _eval_shape(window_strides)
+  window_dilation_tf = _eval_shape(window_dilation)
+  base_dilation_tf = _eval_shape(base_dilation)
+  padding_tf = [_eval_shape(p) for p in padding]
   out = tfxla.reduce_window(
       operand,
       init_val,
       reducer_fn,
-      window_dimensions,
-      window_strides,
-      base_dilations=base_dilation,
-      window_dilations=window_dilation,
-      padding=padding)
+      window_dimensions_tf,
+      window_strides_tf,
+      base_dilations=base_dilation_tf,
+      window_dilations=window_dilation_tf,
+      padding=padding_tf)
   # TODO: implement shape inference for XlaReduceWindow
   out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval))
   if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
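
(Usage-level sketch, not from the diff.) With the window parameters routed through _eval_shape, reducing over a static axis is expected to convert even when another dimension is polymorphic, which is what the "jnp.cumsum" reduce_axis=static harness below exercises. The exact lowering path taken here is an assumption.

# Hedged sketch: cumsum over the static axis 1 while axis 0 stays polymorphic.
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

f_tf = tf.function(
    jax2tf.convert(lambda x: jnp.cumsum(x, axis=1),
                   polymorphic_shapes=["(b, 5)"]),
    autograph=False,
    input_signature=[tf.TensorSpec([None, 5], tf.float32)])

print(f_tf(np.ones((3, 5), np.float32)).shape)  # (3, 5)
print(f_tf(np.ones((7, 5), np.float32)).shape)  # (7, 5)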

@@ -1855,6 +1855,19 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 lambda x: lax_control_flow.cummax(x, axis=1, reverse=False),
                 arg_descriptors=[RandArg((3, 4, 5), _f32)],
                 poly_axes=[0]),
+    PolyHarness("jnp.cumsum", "reduce_axis=poly",
+                lambda x: jnp.cumsum(x, axis=0),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                poly_axes=[0],
+                expect_error=(
+                    (None, None) if (not config.jax2tf_default_experimental_native_lowering or
+                                     jtu.device_under_test() == "tpu") else
+                    (NotImplementedError,
+                     "associative scan over axis of non-constant size"))),
+    PolyHarness("jnp.cumsum", "reduce_axis=static",
+                lambda x: jnp.cumsum(x, axis=1),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                poly_axes=[0]),
     PolyHarness("delta", "0",
                 lambda x: lax_internal._delta(_f32, x.shape, axes=(0, 1)) + x,
                 arg_descriptors=[RandArg((3, 1), _f32)],
@@ -2567,6 +2580,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
       raise unittest.SkipTest(
          "native lowering with shape polymorphism not implemented for JAX primitives still using HLO fallback lowering; b/261682623")
+    if harness.fullname == "jnp.cumsum_reduce_axis=poly" and jtu.device_under_test() == "tpu":
+      # https://github.com/openxla/stablehlo/issues/1258
+      raise unittest.SkipTest(
+          "native lowering with shape polymorphism not implemented for window_reductions on TPU")
     harness.run_test(self)