Merge pull request #14696 from gnecula:tf_cumsum_poly
PiperOrigin-RevId: 513835766
Commit 04fc30b58a
@@ -1825,6 +1825,11 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
   # Check that all inputs have a consistent leading dimension `num_elems`.
   axis = util.canonicalize_axis(axis, elems_flat[0].ndim)

+  if core.is_special_dim_size(elems_flat[0].shape[axis]):
+    raise NotImplementedError("associative scan over axis "
+                              f"of non-constant size: {elems_flat[0].shape[axis]}. You may be "
+                              "able to avoid this on TPU.")
+
   num_elems = int(elems_flat[0].shape[axis])
   if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
     raise ValueError('Array inputs to associative_scan must have the same '
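For context, a minimal sketch (not part of this change) of the code path the new check guards: lax.associative_scan with a binary operator is the scan that jnp.cumsum builds on, and the error above fires only when the scanned axis has a non-constant, shape-polymorphic size.

import operator
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
# With a concrete axis size the scan is unaffected by the new check;
# the result matches jnp.cumsum(x, axis=1).
print(lax.associative_scan(operator.add, x, axis=1))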
@@ -467,6 +467,9 @@ def _reduce_window_lower(
   operand_aval, = ctx.avals_in
   scalar_aval = operand_aval.update(shape=())
   scalar_type = mlir.aval_to_ir_type(scalar_aval)
+  if any(not core.is_constant_shape(s)
+         for s in [window_dimensions, window_dilation, window_strides, base_dilation, *padding]):
+    raise NotImplementedError("ReduceWindowOp for dynamic shapes")
   rw = hlo.ReduceWindowOp(
       mlir.aval_to_ir_types(aval_out), [operand],
       [mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)],
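A small usage sketch, not taken from the change, of the operation whose lowering gains this guard; with constant window parameters lax.reduce_window lowers as before, and only dynamic or polymorphic window shapes hit the new NotImplementedError.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
# Constant window parameters lower as before; the guard above only rejects
# non-constant (dynamic/polymorphic) window shapes.
out = lax.reduce_window(x, 0.0, lax.add,
                        window_dimensions=(1, 2),
                        window_strides=(1, 1),
                        padding="VALID")
print(out)  # sums over sliding windows of width 2 along the last axis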
@@ -2281,15 +2281,20 @@ def _common_reduce_window(operand, init_val, reducer, window_dimensions,

   if not isinstance(init_val, (tf.Tensor, tf.Variable)):
     init_val = tf.constant(init_val, operand.dtype)
+  window_dimensions_tf = _eval_shape(window_dimensions)
+  window_strides_tf = _eval_shape(window_strides)
+  window_dilation_tf = _eval_shape(window_dilation)
+  base_dilation_tf = _eval_shape(base_dilation)
+  padding_tf = [_eval_shape(p) for p in padding]
   out = tfxla.reduce_window(
       operand,
       init_val,
       reducer_fn,
-      window_dimensions,
-      window_strides,
-      base_dilations=base_dilation,
-      window_dilations=window_dilation,
-      padding=padding)
+      window_dimensions_tf,
+      window_strides_tf,
+      base_dilations=base_dilation_tf,
+      window_dilations=window_dilation_tf,
+      padding=padding_tf)
   # TODO: implement shape inference for XlaReduceWindow
   out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval))
   if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
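A hedged usage sketch, not part of the change: converting a cumulative sum with a shape-polymorphic leading dimension through jax2tf. Under the default graph lowering, cumulative reductions can be converted through the reduce-window path patched above, with _eval_shape turning possibly symbolic window parameters into concrete TF values; which path is actually taken depends on configuration and backend.

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# "b" marks the leading dimension as polymorphic; 5 stays concrete.
f_tf = jax2tf.convert(lambda x: jnp.cumsum(x, axis=1),
                      polymorphic_shapes=["(b, 5)"])
print(f_tf(tf.ones((3, 5), dtype=tf.float32)))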
@@ -1855,6 +1855,19 @@ _POLY_SHAPE_TEST_HARNESSES = [
                 lambda x: lax_control_flow.cummax(x, axis=1, reverse=False),
                 arg_descriptors=[RandArg((3, 4, 5), _f32)],
                 poly_axes=[0]),
+    PolyHarness("jnp.cumsum", "reduce_axis=poly",
+                lambda x: jnp.cumsum(x, axis=0),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                poly_axes=[0],
+                expect_error=(
+                    (None, None) if (not config.jax2tf_default_experimental_native_lowering or
+                                     jtu.device_under_test() == "tpu") else
+                    (NotImplementedError,
+                     "associative scan over axis of non-constant size"))),
+    PolyHarness("jnp.cumsum", "reduce_axis=static",
+                lambda x: jnp.cumsum(x, axis=1),
+                arg_descriptors=[RandArg((3, 5), _f32)],
+                poly_axes=[0]),
     PolyHarness("delta", "0",
                 lambda x: lax_internal._delta(_f32, x.shape, axes=(0, 1)) + x,
                 arg_descriptors=[RandArg((3, 1), _f32)],
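An illustrative sketch, not from the test file, of what the new "reduce_axis=poly" harness exercises: a cumulative sum along the polymorphic axis. Per the expect_error entry above, whether this raises depends on whether native lowering is enabled and on the backend, so the sketch catches the error rather than assuming it.

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

f_tf = jax2tf.convert(lambda x: jnp.cumsum(x, axis=0),
                      polymorphic_shapes=["(b, 5)"])
try:
    print(f_tf(tf.ones((3, 5), dtype=tf.float32)))
except NotImplementedError as e:
    # Raised by the associative_scan check from the first hunk when the
    # scanned axis size is symbolic and the chosen lowering cannot handle it.
    print("not supported on this path:", e)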
@@ -2567,6 +2580,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
       raise unittest.SkipTest(
           "native lowering with shape polymorphism not implemented for JAX primitives still using HLO fallback lowering; b/261682623")

+    if harness.fullname == "jnp.cumsum_reduce_axis=poly" and jtu.device_under_test() == "tpu":
+      # https://github.com/openxla/stablehlo/issues/1258
+      raise unittest.SkipTest(
+          "native lowering with shape polymorphism not implemented for window_reductions on TPU")
     harness.run_test(self)