mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #14696 from gnecula:tf_cumsum_poly
PiperOrigin-RevId: 513835766
This commit is contained in:
commit
04fc30b58a
@ -1825,6 +1825,11 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
|||||||
|
|
||||||
# Check that all inputs have a consistent leading dimension `num_elems`.
|
# Check that all inputs have a consistent leading dimension `num_elems`.
|
||||||
axis = util.canonicalize_axis(axis, elems_flat[0].ndim)
|
axis = util.canonicalize_axis(axis, elems_flat[0].ndim)
|
||||||
|
|
||||||
|
if core.is_special_dim_size(elems_flat[0].shape[axis]):
|
||||||
|
raise NotImplementedError("associative scan over axis "
|
||||||
|
f"of non-constant size: {elems_flat[0].shape[axis]}. You may be "
|
||||||
|
"able to avoid this on TPU.")
|
||||||
num_elems = int(elems_flat[0].shape[axis])
|
num_elems = int(elems_flat[0].shape[axis])
|
||||||
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
|
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
|
||||||
raise ValueError('Array inputs to associative_scan must have the same '
|
raise ValueError('Array inputs to associative_scan must have the same '
|
||||||
|
@ -467,6 +467,9 @@ def _reduce_window_lower(
|
|||||||
operand_aval, = ctx.avals_in
|
operand_aval, = ctx.avals_in
|
||||||
scalar_aval = operand_aval.update(shape=())
|
scalar_aval = operand_aval.update(shape=())
|
||||||
scalar_type = mlir.aval_to_ir_type(scalar_aval)
|
scalar_type = mlir.aval_to_ir_type(scalar_aval)
|
||||||
|
if any(not core.is_constant_shape(s)
|
||||||
|
for s in [window_dimensions, window_dilation, window_strides, base_dilation, *padding]):
|
||||||
|
raise NotImplementedError("ReduceWindowOp for dynamic shapes")
|
||||||
rw = hlo.ReduceWindowOp(
|
rw = hlo.ReduceWindowOp(
|
||||||
mlir.aval_to_ir_types(aval_out), [operand],
|
mlir.aval_to_ir_types(aval_out), [operand],
|
||||||
[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)],
|
[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)],
|
||||||
|
@ -2281,15 +2281,20 @@ def _common_reduce_window(operand, init_val, reducer, window_dimensions,
|
|||||||
|
|
||||||
if not isinstance(init_val, (tf.Tensor, tf.Variable)):
|
if not isinstance(init_val, (tf.Tensor, tf.Variable)):
|
||||||
init_val = tf.constant(init_val, operand.dtype)
|
init_val = tf.constant(init_val, operand.dtype)
|
||||||
|
window_dimensions_tf = _eval_shape(window_dimensions)
|
||||||
|
window_strides_tf = _eval_shape(window_strides)
|
||||||
|
window_dilation_tf = _eval_shape(window_dilation)
|
||||||
|
base_dilation_tf = _eval_shape(base_dilation)
|
||||||
|
padding_tf = [_eval_shape(p) for p in padding]
|
||||||
out = tfxla.reduce_window(
|
out = tfxla.reduce_window(
|
||||||
operand,
|
operand,
|
||||||
init_val,
|
init_val,
|
||||||
reducer_fn,
|
reducer_fn,
|
||||||
window_dimensions,
|
window_dimensions_tf,
|
||||||
window_strides,
|
window_strides_tf,
|
||||||
base_dilations=base_dilation,
|
base_dilations=base_dilation_tf,
|
||||||
window_dilations=window_dilation,
|
window_dilations=window_dilation_tf,
|
||||||
padding=padding)
|
padding=padding_tf)
|
||||||
# TODO: implement shape inference for XlaReduceWindow
|
# TODO: implement shape inference for XlaReduceWindow
|
||||||
out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval))
|
out = _ensure_tf_shape_if_dynamic(out, _aval_to_tf_shape(_out_aval))
|
||||||
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
|
||||||
|
@ -1855,6 +1855,19 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
|||||||
lambda x: lax_control_flow.cummax(x, axis=1, reverse=False),
|
lambda x: lax_control_flow.cummax(x, axis=1, reverse=False),
|
||||||
arg_descriptors=[RandArg((3, 4, 5), _f32)],
|
arg_descriptors=[RandArg((3, 4, 5), _f32)],
|
||||||
poly_axes=[0]),
|
poly_axes=[0]),
|
||||||
|
PolyHarness("jnp.cumsum", "reduce_axis=poly",
|
||||||
|
lambda x: jnp.cumsum(x, axis=0),
|
||||||
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
||||||
|
poly_axes=[0],
|
||||||
|
expect_error=(
|
||||||
|
(None, None) if (not config.jax2tf_default_experimental_native_lowering or
|
||||||
|
jtu.device_under_test() == "tpu") else
|
||||||
|
(NotImplementedError,
|
||||||
|
"associative scan over axis of non-constant size"))),
|
||||||
|
PolyHarness("jnp.cumsum", "reduce_axis=static",
|
||||||
|
lambda x: jnp.cumsum(x, axis=1),
|
||||||
|
arg_descriptors=[RandArg((3, 5), _f32)],
|
||||||
|
poly_axes=[0]),
|
||||||
PolyHarness("delta", "0",
|
PolyHarness("delta", "0",
|
||||||
lambda x: lax_internal._delta(_f32, x.shape, axes=(0, 1)) + x,
|
lambda x: lax_internal._delta(_f32, x.shape, axes=(0, 1)) + x,
|
||||||
arg_descriptors=[RandArg((3, 1), _f32)],
|
arg_descriptors=[RandArg((3, 1), _f32)],
|
||||||
@ -2567,6 +2580,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
|||||||
raise unittest.SkipTest(
|
raise unittest.SkipTest(
|
||||||
"native lowering with shape polymorphism not implemented for JAX primitives still using HLO fallback lowering; b/261682623")
|
"native lowering with shape polymorphism not implemented for JAX primitives still using HLO fallback lowering; b/261682623")
|
||||||
|
|
||||||
|
if harness.fullname == "jnp.cumsum_reduce_axis=poly" and jtu.device_under_test() == "tpu":
|
||||||
|
# https://github.com/openxla/stablehlo/issues/1258
|
||||||
|
raise unittest.SkipTest(
|
||||||
|
"native lowering with shape polymorphism not implemented for window_reductions on TPU")
|
||||||
harness.run_test(self)
|
harness.run_test(self)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user