From c51537f827ee082ccb69da85dce0c17f36c0a7e6 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Mon, 27 Feb 2023 11:12:50 +0200
Subject: [PATCH] [shape_poly] Add support for jnp.cum{sum,prod,max,min} with
 shape polymorphism

Unfortunately, on CPU and GPU where we use associative scan, we cannot
support shape polymorphism with native lowering.
---
 jax/_src/lax/control_flow/loops.py              |  5 +++++
 jax/_src/lax/windowed_reductions.py             |  3 +++
 jax/experimental/jax2tf/jax2tf.py               | 15 ++++++++++-----
 .../jax2tf/tests/shape_poly_test.py             | 17 +++++++++++++++++
 4 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index c70c06af6..6777edfff 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -1825,6 +1825,11 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):

   # Check that all inputs have a consistent leading dimension `num_elems`.
   axis = util.canonicalize_axis(axis, elems_flat[0].ndim)
+
+  if core.is_special_dim_size(elems_flat[0].shape[axis]):
+    raise NotImplementedError("associative scan over axis "
+        f"of non-constant size: {elems_flat[0].shape[axis]}. You may be "
+        "able to avoid this on TPU.")
   num_elems = int(elems_flat[0].shape[axis])
   if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
     raise ValueError('Array inputs to associative_scan must have the same '
diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index f7049d600..4a074de35 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -467,6 +467,9 @@ def _reduce_window_lower(
   operand_aval, = ctx.avals_in
   scalar_aval = operand_aval.update(shape=())
   scalar_type = mlir.aval_to_ir_type(scalar_aval)
+  if any(not core.is_constant_shape(s)
+         for s in [window_dimensions, window_dilation, window_strides, base_dilation, *padding]):
+    raise NotImplementedError("ReduceWindowOp for dynamic shapes")
   rw = hlo.ReduceWindowOp(
       mlir.aval_to_ir_types(aval_out),
       [operand],
       [mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)],
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index c2a2759fc..4116ecbc2 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -2281,15 +2281,20 @@ def _common_reduce_window(operand, init_val, reducer, window_dimensions,

   if not isinstance(init_val, (tf.Tensor, tf.Variable)):
     init_val = tf.constant(init_val, operand.dtype)
+  window_dimensions_tf = _eval_shape(window_dimensions)
+  window_strides_tf = _eval_shape(window_strides)
+  window_dilation_tf = _eval_shape(window_dilation)
+  base_dilation_tf = _eval_shape(base_dilation)
+  padding_tf = [_eval_shape(p) for p in padding]
   out = tfxla.reduce_window(
       operand,
       init_val,
       reducer_fn,
-      window_dimensions,
-      window_strides,
-      base_dilations=base_dilation,
-      window_dilations=window_dilation,
-      padding=padding)
+      window_dimensions_tf,
+      window_strides_tf,
+      base_dilations=base_dilation_tf,
+      window_dilations=window_dilation_tf,
+      padding=padding_tf)
   # TODO: implement shape inference for XlaReduceWindow
   out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval))
   if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py
index 94b899c0d..f0e3d1b97 100644
--- a/jax/experimental/jax2tf/tests/shape_poly_test.py
+++ b/jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -1855,6 +1855,19 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 lambda x: lax_control_flow.cummax(x, axis=1, reverse=False),
                 arg_descriptors=[RandArg((3, 4, 5), _f32)],
                 poly_axes=[0]),
+    PolyHarness("jnp.cumsum", "reduce_axis=poly",
+                lambda x: jnp.cumsum(x, axis=0),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                poly_axes=[0],
+                expect_error=(
+                    (None, None) if (not config.jax2tf_default_experimental_native_lowering or
+                                     jtu.device_under_test() == "tpu") else
+                    (NotImplementedError,
+                     "associative scan over axis of non-constant size"))),
+    PolyHarness("jnp.cumsum", "reduce_axis=static",
+                lambda x: jnp.cumsum(x, axis=1),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                poly_axes=[0]),
     PolyHarness("delta", "0",
                 lambda x: lax_internal._delta(_f32, x.shape, axes=(0, 1)) + x,
                 arg_descriptors=[RandArg((3, 1), _f32)],
@@ -2567,6 +2580,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
       raise unittest.SkipTest(
           "native lowering with shape polymorphism not implemented for JAX primitives still using HLO fallback lowering; b/261682623")
+    if harness.fullname == "jnp.cumsum_reduce_axis=poly" and jtu.device_under_test() == "tpu":
+      # https://github.com/openxla/stablehlo/issues/1258
+      raise unittest.SkipTest(
+          "native lowering with shape polymorphism not implemented for window_reductions on TPU")
     harness.run_test(self)
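
Not part of the commit: a minimal usage sketch of the behavior the patch and
tests above describe, assuming the public jax2tf.convert / polymorphic_shapes
API at the time of this patch; the names f_static and f_poly are illustrative
only.

# Sketch only (assumed API, not part of the patch): jnp.cumsum with a
# polymorphic batch dimension under jax2tf.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# Cumulative sum over a static axis: supported with a polymorphic batch dim.
f_static = jax2tf.convert(lambda x: jnp.cumsum(x, axis=1),
                          polymorphic_shapes=["(b, 5)"])
f_static(tf.ones((3, 5), tf.float32))

# Cumulative sum over the polymorphic axis itself: with native lowering on
# CPU/GPU, converting and calling this is expected to raise
#   NotImplementedError: associative scan over axis of non-constant size
# because the associative-scan path needs a constant axis size; the TPU
# lowering uses a window reduction and may avoid this.
f_poly = jax2tf.convert(lambda x: jnp.cumsum(x, axis=0),
                        polymorphic_shapes=["(b, 5)"])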