Add public APIs for jax.lax monoidal reductions

This commit is contained in:
Jake VanderPlas 2025-02-11 16:00:03 -08:00
parent d0b6c677b0
commit e389b707ba
12 changed files with 256 additions and 45 deletions

View File

@ -21,6 +21,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
decorator to support customizing the behavior of opaque functions under
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
details.
* Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`,
{func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
{func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as

View File

@ -130,8 +130,15 @@ Operators
real
reciprocal
reduce
reduce_and
reduce_max
reduce_min
reduce_or
reduce_precision
reduce_prod
reduce_sum
reduce_window
reduce_xor
rem
reshape
rev

View File

@ -49,8 +49,11 @@ uint_dtypes = test_util.dtypes.all_unsigned
bool_dtypes = test_util.dtypes.boolean
default_dtypes = float_dtypes + int_dtypes
number_dtypes = (
float_dtypes + complex_dtypes + int_dtypes + uint_dtypes
)
all_dtypes = (
float_dtypes + complex_dtypes + int_dtypes + uint_dtypes + bool_dtypes
number_dtypes + bool_dtypes
)
python_scalar_types = [bool, int, float, complex]
@ -151,6 +154,25 @@ def lax_reduce_ops():
]
NamedReducerOpRecord = collections.namedtuple(
"NamedReducerOpRecord", ["op", "reference_op", "dtypes"]
)
def lax_named_reduce_ops():
return [
NamedReducerOpRecord(lax.reduce_sum, np.sum, number_dtypes),
NamedReducerOpRecord(lax.reduce_prod, np.prod, number_dtypes),
NamedReducerOpRecord(lax.reduce_max, np.max, all_dtypes),
NamedReducerOpRecord(lax.reduce_min, np.min, all_dtypes),
NamedReducerOpRecord(lax.reduce_and, np.bitwise_and.reduce,
bool_dtypes + int_dtypes + uint_dtypes),
NamedReducerOpRecord(lax.reduce_or, np.bitwise_or.reduce,
bool_dtypes + int_dtypes + uint_dtypes),
NamedReducerOpRecord(lax.reduce_xor, np.bitwise_xor.reduce,
bool_dtypes + int_dtypes + uint_dtypes),
]
def lax_ops():
return [
op_record(

View File

@ -1752,7 +1752,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
# Pred can be batched
pred = core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
if batched:
pred = lax._reduce_or(pred, tuple(range(len(pred_aval.shape))))
pred = lax.reduce_or(pred, tuple(range(len(pred_aval.shape))))
return pred
def body(args):
return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args))

View File

@ -2177,19 +2177,19 @@ def _get_monoid_reducer(monoid_op: Callable,
# allow bitwise reductions for boolean and integer types
_is_intlike = dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer)
if monoid_op is add:
return _reduce_sum if np.equal(val, 0) else None
return reduce_sum if np.equal(val, 0) else None
elif monoid_op is mul:
return _reduce_prod if np.equal(val, 1) else None
return reduce_prod if np.equal(val, 1) else None
elif monoid_op is bitwise_or and _is_intlike:
return _reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None
return reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None
elif monoid_op is bitwise_and and _is_intlike:
return _reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None
return reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None
elif monoid_op is bitwise_xor and _is_intlike:
return _reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None
return reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None
elif monoid_op is max:
return _reduce_max if np.equal(val, _get_max_identity(dtype)) else None
return reduce_max if np.equal(val, _get_max_identity(dtype)) else None
elif monoid_op is min:
return _reduce_min if np.equal(val, _get_min_identity(dtype)) else None
return reduce_min if np.equal(val, _get_min_identity(dtype)) else None
return None
def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray:
@ -2226,25 +2226,164 @@ def _get_min_identity(dtype: DTypeLike) -> np.ndarray:
else:
raise ValueError(f"Unsupported dtype for min: {dtype}")
def _reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array:
def reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array:
"""Compute the sum of elements over one or more array axes.
Args:
operand: array over which to sum. Must have numerical dtype.
axes: sequence of zero or more unique integers specifying the axes over
which to sum. Each entry must satisfy ``0 <= axis < operand.ndim``.
Returns:
An array of the same dtype as ``operand``, with shape corresponding
to the dimensions of ``operand.shape`` with ``axes`` removed.
Notes:
Unlike :func:`jax.numpy.sum`, :func:`jax.lax.reduce_sum` does not upcast
narrow-width types for accumulation, so sums of 8-bit or 16-bit types
may be subject to rounding errors.
See also:
- :func:`jax.numpy.sum`: more flexible NumPy-style summation API, built
around :func:`jax.lax.reduce_sum`.
- Other low-level :mod:`jax.lax` reduction operators:
:func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`,
:func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
"""
return reduce_sum_p.bind(operand, axes=tuple(axes))
def _reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array:
def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array:
"""Compute the product of elements over one or more array axes.
Args:
operand: array over which to sum. Must have numerical dtype.
axes: sequence of zero or more unique integers specifying the axes over
which to sum. Each entry must satisfy ``0 <= axis < operand.ndim``.
Returns:
An array of the same dtype as ``operand``, with shape corresponding
to the dimensions of ``operand.shape`` with ``axes`` removed.
Notes:
Unlike :func:`jax.numpy.prod`, :func:`jax.lax.reduce_prod` does not upcast
narrow-width types for accumulation, so products of 8-bit or 16-bit types
may be subject to rounding errors.
See also:
- :func:`jax.numpy.prod`: more flexible NumPy-style product API, built
around :func:`jax.lax.reduce_prod`.
- Other low-level :mod:`jax.lax` reduction operators:
:func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`,
:func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
"""
return reduce_prod_p.bind(operand, axes=tuple(axes))
def _reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array:
def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array:
"""Compute the maximum of elements over one or more array axes.
Args:
operand: array over which to compute maximum.
axes: sequence of zero or more unique integers specifying the axes over
which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.
Returns:
An array of the same dtype as ``operand``, with shape corresponding
to the dimensions of ``operand.shape`` with ``axes`` removed.
See also:
- :func:`jax.numpy.max`: more flexible NumPy-style max-reduction API, built
around :func:`jax.lax.reduce_max`.
- Other low-level :mod:`jax.lax` reduction operators:
:func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_min`,
:func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
"""
return reduce_max_p.bind(operand, axes=tuple(axes))
def _reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array:
def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array:
"""Compute the minimum of elements over one or more array axes.
Args:
operand: array over which to compute minimum.
axes: sequence of zero or more unique integers specifying the axes over
which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.
Returns:
An array of the same dtype as ``operand``, with shape corresponding
to the dimensions of ``operand.shape`` with ``axes`` removed.
See also:
- :func:`jax.numpy.min`: more flexible NumPy-style min-reduction API, built
around :func:`jax.lax.reduce_min`.
- Other low-level :mod:`jax.lax` reduction operators:
:func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
:func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
"""
return reduce_min_p.bind(operand, axes=tuple(axes))
def _reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array:
def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array:
"""Compute the bitwise OR of elements over one or more array axes.
Args:
operand: array over which to compute the reduction. Must have boolean
or integer dtype.
axes: sequence of zero or more unique integers specifying the axes over
which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.
Returns:
An array of the same dtype as ``operand``, with shape corresponding
to the dimensions of ``operand.shape`` with ``axes`` removed.
See also:
- :func:`jax.numpy.bitwise_or.reduce`: more flexible NumPy-style logical
reduction API, built around :func:`jax.lax.reduce_or`.
- Other low-level :mod:`jax.lax` reduction operators:
:func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
:func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_xor`.
"""
return reduce_or_p.bind(operand, axes=tuple(axes))
def _reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array:
def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array:
"""Compute the bitwise AND of elements over one or more array axes.
Args:
operand: array over which to compute the reduction. Must have boolean
or integer dtype.
axes: sequence of zero or more unique integers specifying the axes over
which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.
Returns:
An array of the same dtype as ``operand``, with shape corresponding
to the dimensions of ``operand.shape`` with ``axes`` removed.
See also:
- :func:`jax.numpy.bitwise_and.reduce`: more flexible NumPy-style logical
reduction API, built around :func:`jax.lax.reduce_and`.
- Other low-level :mod:`jax.lax` reduction operators:
:func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
:func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
"""
return reduce_and_p.bind(operand, axes=tuple(axes))
def _reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array:
def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array:
"""Compute the bitwise XOR of elements over one or more array axes.
Args:
operand: array over which to compute the reduction. Must have boolean
or integer dtype.
axes: sequence of zero or more unique integers specifying the axes over
which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.
Returns:
An array of the same dtype as ``operand``, with shape corresponding
to the dimensions of ``operand.shape`` with ``axes`` removed.
See also:
- :func:`jax.numpy.bitwise_xor.reduce`: more flexible NumPy-style logical
reduction API, built around :func:`jax.lax.reduce_xor`.
- Other low-level :mod:`jax.lax` reduction operators:
:func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
:func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`.
"""
return reduce_xor_p.bind(operand, axes=tuple(axes))
@overload
@ -3014,11 +3153,11 @@ def _unbroadcast(aval, x):
return x
assert not aval.shape or len(x_shape) == len(aval.shape)
if not aval.shape:
return _reduce_sum(x, list(range(len(x_shape))))
return reduce_sum(x, list(range(len(x_shape))))
else:
dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.definitely_equal(a, b)]
if config.enable_checks.value: assert all(aval.shape[i] == 1 for i in dims)
return reshape(_reduce_sum(x, dims), aval.shape)
return reshape(reduce_sum(x, dims), aval.shape)
def _maybe_broadcast(target_shape, x):
x_shape = np.shape(x)
@ -4959,7 +5098,7 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
if core.definitely_equal(s, 1)]
bdims = tuple(np.delete(broadcast_dimensions, unit_dims))
axes = tuple(np.delete(range(len(shape)), bdims))
return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] +
return ([expand_dims(reduce_sum(ct, axes), unit_dims)] +
[None] * len(dyn_shape))
def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape,
@ -5405,7 +5544,7 @@ def _pad_transpose(t, operand, padding_value, *, padding_config):
t_padv = ad_util.Zero(padding_value.aval) if ad.is_undefined_primal(padding_value) else None
else:
lo, hi, interior = util.unzip3(padding_config)
total = lambda x: _reduce_sum(x, list(range(t.ndim)))
total = lambda x: reduce_sum(x, list(range(t.ndim)))
def t_op():
unpad_config = safe_zip(np.negative(lo), np.negative(hi),
@ -6177,7 +6316,7 @@ reduce_sum_p = standard_primitive(
'reduce_sum', sharding_rule=_reduce_op_sharding_rule)
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p, _get_sum_identity)
pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum,
_get_sum_identity)
batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule
@ -6192,7 +6331,7 @@ reduce_prod_p = standard_primitive(
'reduce_prod', sharding_rule=_reduce_op_sharding_rule)
ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
batching.defreducer(reduce_prod_p, _get_prod_identity)
pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, _reduce_prod,
pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod,
_get_prod_identity)
@ -6203,8 +6342,8 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
location_indicators = convert_element_type(
_eq_meet(operand, reshape(ans, shape)), g.dtype)
counts = _reduce_sum(location_indicators, axes)
return div(_reduce_sum(mul(g, location_indicators), axes), counts)
counts = reduce_sum(location_indicators, axes)
return div(reduce_sum(mul(g, location_indicators), axes), counts)
reduce_max_p = standard_primitive(
@ -6212,7 +6351,7 @@ reduce_max_p = standard_primitive(
sharding_rule=_reduce_op_sharding_rule)
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p, _get_max_identity)
pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max,
_get_max_identity)
batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule
@ -6222,7 +6361,7 @@ reduce_min_p = standard_primitive(
sharding_rule=_reduce_op_sharding_rule)
ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_min_p, _get_min_identity)
pe.padding_rules[reduce_min_p] = partial(_reducer_padding, _reduce_min,
pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min,
_get_min_identity)

View File

@ -130,8 +130,8 @@ def psum(x, axis_name, *, axis_index_groups=None):
def pos_reduce(x):
if not pos_axes:
return x
return lax._reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0))
for axis in pos_axes])
return lax.reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0))
for axis in pos_axes])
if axis_index_groups is not None:
assert not pos_axes
size = len(axis_index_groups[0])
@ -834,10 +834,10 @@ def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
psum_p = core.Primitive('psum')
psum_p.multiple_results = True
psum_p.def_impl(partial(_allreduce_impl, psum_p, lax._reduce_sum))
psum_p.def_impl(partial(_allreduce_impl, psum_p, lax.reduce_sum))
psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
mlir.register_lowering(
psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
psum_p, partial(_allreduce_lowering, lax.add_p, lax.reduce_sum))
ad.deflinear2(psum_p, _psum_transpose_rule)
batching.fancy_primitive_batchers[psum_p] = \
partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v)
@ -845,10 +845,10 @@ batching.skippable_batchers[psum_p] = partial(_names_in_param, 'axes')
pmax_p = core.Primitive('pmax')
pmax_p.multiple_results = True
pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax._reduce_max))
pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax.reduce_max))
pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
mlir.register_lowering(
pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
pmax_p, partial(_allreduce_lowering, lax.max_p, lax.reduce_max))
batching.fancy_primitive_batchers[pmax_p] = \
partial(_batched_reduction_collective, pmax_p, lambda v, axis_size: v)
batching.skippable_batchers[pmax_p] = partial(_names_in_param, 'axes')
@ -856,10 +856,10 @@ batching.skippable_batchers[pmax_p] = partial(_names_in_param, 'axes')
pmin_p = core.Primitive('pmin')
pmin_p.multiple_results = True
pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax._reduce_min))
pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax.reduce_min))
pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
mlir.register_lowering(
pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
pmin_p, partial(_allreduce_lowering, lax.min_p, lax.reduce_min))
batching.fancy_primitive_batchers[pmin_p] = \
partial(_batched_reduction_collective, pmin_p, lambda v, axis_size: v)
batching.skippable_batchers[pmin_p] = partial(_names_in_param, 'axes')

View File

@ -1910,7 +1910,7 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
mask = lax.bitwise_and(
lax.ge(indices, np.int64(0)),
lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims)))))
mask = lax._reduce_and(mask, [num_batch_dims])
mask = lax.reduce_and(mask, [num_batch_dims])
# Computes the output shape and the positions of the batch dimensions in the
# output

View File

@ -644,18 +644,18 @@ def _gen_reduce_choose_taylor_rule(chooser_fun):
location_indicators = lax.convert_element_type(
lax_internal._eq_meet(operand, lax.reshape(primal_out, shape)),
primal_dtype)
counts = lax_internal._reduce_sum(location_indicators, axes)
counts = lax.reduce_sum(location_indicators, axes)
def _reduce_chooser_taylor_rule(g):
return lax.div(
lax_internal._reduce_sum(lax.mul(g, location_indicators), axes),
lax.reduce_sum(lax.mul(g, location_indicators), axes),
counts)
series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
return primal_out, series_out
return chooser_taylor_rule
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(
lax_internal._reduce_max)
lax.reduce_max)
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(
lax_internal._reduce_min)
lax.reduce_min)
def _abs_taylor_rule(x, series_in, **params):
x, = x

View File

@ -162,15 +162,22 @@ from jax._src.lax.lax import (
real_p as real_p,
reciprocal as reciprocal,
reduce as reduce,
reduce_and as reduce_and,
reduce_and_p as reduce_and_p,
reduce_max as reduce_max,
reduce_max_p as reduce_max_p,
reduce_min as reduce_min,
reduce_min_p as reduce_min_p,
reduce_or as reduce_or,
reduce_or_p as reduce_or_p,
reduce_p as reduce_p,
reduce_precision as reduce_precision,
reduce_precision_p as reduce_precision_p,
reduce_prod as reduce_prod,
reduce_prod_p as reduce_prod_p,
reduce_sum as reduce_sum,
reduce_sum_p as reduce_sum_p,
reduce_xor as reduce_xor,
reduce_xor_p as reduce_xor_p,
rem as rem,
rem_p as rem_p,

View File

@ -34,7 +34,6 @@ from jax._src import util
from jax._src import test_util as jtu
from jax._src.core import ShapedArray, DBIdx
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
config.parse_flags_with_absl()
@ -651,7 +650,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
def f(x, y):
z = lax.mul(x, y)
w = lax.sin(z)
u = lax_internal._reduce_sum(w, [0])
u = lax.reduce_sum(w, [0])
return (u,)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(

View File

@ -23,8 +23,6 @@ import jax
from jax import lax
import numpy as np
from jax._src.lax import lax as lax_internal
from jax.tests.filecheck.jax_filecheck_helpers import print_ir
jax.config.update("jax_enable_x64", True)
@ -56,12 +54,12 @@ def main(_):
partial(lax.pad, padding_value=np.int32(7),
padding_config=((2, 3, 4), (4, 5, 6))))
# CHECK-LABEL: TEST: _reduce_sum int32[2,3,7]
# CHECK-LABEL: TEST: reduce_sum int32[2,3,7]
# CHECK: hlo.reduce
# CHECK: hlo.add
# CHECK: tensor<3xi32>
print_ir(np.empty([2, 3, 7], np.int32))(
partial(lax_internal._reduce_sum, axes=(0, 2)))
partial(lax.reduce_sum, axes=(0, 2)))
# CHECK-LABEL: TEST: reshape int32[2,3,7]
# CHECK: hlo.reshape

View File

@ -1941,6 +1941,42 @@ class LaxTest(jtu.JaxTestCase):
jaxpr = jax.make_jaxpr(fun)(rng(shape, dtype))
self.assertEqual(jaxpr.eqns[0].primitive, primitive)
@jtu.sample_product(
[
dict(
op=rec.op,
reference_op=rec.reference_op,
dtype=dtype,
)
for rec in lax_test_util.lax_named_reduce_ops()
for dtype in rec.dtypes
],
[
dict(shape=shape, dims=dims)
for shape, dims in [
[(3, 4, 5), (0,)],
[(3, 4, 5), (1, 2)],
[(3, 4, 5), (0, 2)],
[(3, 4, 5), (0, 1, 2)],
]
],
)
def testNamedReduceOperators(self, op, reference_op, dtype, shape, dims):
rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small)
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)]
def lax_fun(operand):
return op(operand, dims)
def reference_fun(operand):
return reference_op(operand, dims).astype(dtype)
self._CompileAndCheck(lax_fun, args_maker)
self._CheckAgainstNumpy(reference_fun, lax_fun, args_maker)
@jtu.sample_product(
op=["add", "mul"],
op_namespace=[lax, operator],