diff --git a/CHANGELOG.md b/CHANGELOG.md index e4f8bee0b..69c88f00f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. decorator to support customizing the behavior of opaque functions under JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more details. + * Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`, + {func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`, + {func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`. * Changes * `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 34459d614..67c54ad46 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -130,8 +130,15 @@ Operators real reciprocal reduce + reduce_and + reduce_max + reduce_min + reduce_or reduce_precision + reduce_prod + reduce_sum reduce_window + reduce_xor rem reshape rev diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py index b57b7d085..4e28791e9 100644 --- a/jax/_src/internal_test_util/lax_test_util.py +++ b/jax/_src/internal_test_util/lax_test_util.py @@ -49,8 +49,11 @@ uint_dtypes = test_util.dtypes.all_unsigned bool_dtypes = test_util.dtypes.boolean default_dtypes = float_dtypes + int_dtypes +number_dtypes = ( + float_dtypes + complex_dtypes + int_dtypes + uint_dtypes +) all_dtypes = ( - float_dtypes + complex_dtypes + int_dtypes + uint_dtypes + bool_dtypes + number_dtypes + bool_dtypes ) python_scalar_types = [bool, int, float, complex] @@ -151,6 +154,25 @@ def lax_reduce_ops(): ] +NamedReducerOpRecord = collections.namedtuple( + "NamedReducerOpRecord", ["op", "reference_op", "dtypes"] +) + +def lax_named_reduce_ops(): + return [ + NamedReducerOpRecord(lax.reduce_sum, np.sum, number_dtypes), + NamedReducerOpRecord(lax.reduce_prod, np.prod, number_dtypes), + NamedReducerOpRecord(lax.reduce_max, np.max, all_dtypes), + NamedReducerOpRecord(lax.reduce_min, np.min, all_dtypes), + NamedReducerOpRecord(lax.reduce_and, np.bitwise_and.reduce, + bool_dtypes + int_dtypes + uint_dtypes), + NamedReducerOpRecord(lax.reduce_or, np.bitwise_or.reduce, + bool_dtypes + int_dtypes + uint_dtypes), + NamedReducerOpRecord(lax.reduce_xor, np.bitwise_xor.reduce, + bool_dtypes + int_dtypes + uint_dtypes), + ] + + def lax_ops(): return [ op_record( diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 114a37b0c..ab7451052 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1752,7 +1752,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, # Pred can be batched pred = core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0] if batched: - pred = lax._reduce_or(pred, tuple(range(len(pred_aval.shape)))) + pred = lax.reduce_or(pred, tuple(range(len(pred_aval.shape)))) return pred def body(args): return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args)) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ac99ff886..aab0e459e 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2177,19 +2177,19 @@ def _get_monoid_reducer(monoid_op: Callable, # allow bitwise reductions for boolean and integer types _is_intlike = dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer) if monoid_op is add: - return _reduce_sum if np.equal(val, 0) else None + return reduce_sum if np.equal(val, 0) else 
None elif monoid_op is mul: - return _reduce_prod if np.equal(val, 1) else None + return reduce_prod if np.equal(val, 1) else None elif monoid_op is bitwise_or and _is_intlike: - return _reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None + return reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None elif monoid_op is bitwise_and and _is_intlike: - return _reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None + return reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None elif monoid_op is bitwise_xor and _is_intlike: - return _reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None + return reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None elif monoid_op is max: - return _reduce_max if np.equal(val, _get_max_identity(dtype)) else None + return reduce_max if np.equal(val, _get_max_identity(dtype)) else None elif monoid_op is min: - return _reduce_min if np.equal(val, _get_min_identity(dtype)) else None + return reduce_min if np.equal(val, _get_min_identity(dtype)) else None return None def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray: @@ -2226,25 +2226,164 @@ def _get_min_identity(dtype: DTypeLike) -> np.ndarray: else: raise ValueError(f"Unsupported dtype for min: {dtype}") -def _reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array: +def reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array: + """Compute the sum of elements over one or more array axes. + + Args: + operand: array over which to sum. Must have numerical dtype. + axes: sequence of zero or more unique integers specifying the axes over + which to sum. Each entry must satisfy ``0 <= axis < operand.ndim``. + + Returns: + An array of the same dtype as ``operand``, with shape corresponding + to the dimensions of ``operand.shape`` with ``axes`` removed. + + Notes: + Unlike :func:`jax.numpy.sum`, :func:`jax.lax.reduce_sum` does not upcast + narrow-width types for accumulation, so sums of 8-bit or 16-bit types + may be subject to rounding errors. + + See also: + - :func:`jax.numpy.sum`: more flexible NumPy-style summation API, built + around :func:`jax.lax.reduce_sum`. + - Other low-level :mod:`jax.lax` reduction operators: + :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, + :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. + """ return reduce_sum_p.bind(operand, axes=tuple(axes)) -def _reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array: +def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array: + """Compute the product of elements over one or more array axes. + + Args: + operand: array over which to compute the product. Must have numerical dtype. + axes: sequence of zero or more unique integers specifying the axes over + which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``. + + Returns: + An array of the same dtype as ``operand``, with shape corresponding + to the dimensions of ``operand.shape`` with ``axes`` removed. + + Notes: + Unlike :func:`jax.numpy.prod`, :func:`jax.lax.reduce_prod` does not upcast + narrow-width types for accumulation, so products of 8-bit or 16-bit types + may be subject to rounding errors. + + See also: + - :func:`jax.numpy.prod`: more flexible NumPy-style product API, built + around :func:`jax.lax.reduce_prod`. 
+ - Other low-level :mod:`jax.lax` reduction operators: + :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`, + :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. + """ return reduce_prod_p.bind(operand, axes=tuple(axes)) -def _reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array: +def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array: + """Compute the maximum of elements over one or more array axes. + + Args: + operand: array over which to compute maximum. + axes: sequence of zero or more unique integers specifying the axes over + which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``. + + Returns: + An array of the same dtype as ``operand``, with shape corresponding + to the dimensions of ``operand.shape`` with ``axes`` removed. + + See also: + - :func:`jax.numpy.max`: more flexible NumPy-style max-reduction API, built + around :func:`jax.lax.reduce_max`. + - Other low-level :mod:`jax.lax` reduction operators: + :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_min`, + :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. + """ return reduce_max_p.bind(operand, axes=tuple(axes)) -def _reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array: +def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array: + """Compute the minimum of elements over one or more array axes. + + Args: + operand: array over which to compute minimum. + axes: sequence of zero or more unique integers specifying the axes over + which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``. + + Returns: + An array of the same dtype as ``operand``, with shape corresponding + to the dimensions of ``operand.shape`` with ``axes`` removed. + + See also: + - :func:`jax.numpy.min`: more flexible NumPy-style min-reduction API, built + around :func:`jax.lax.reduce_min`. + - Other low-level :mod:`jax.lax` reduction operators: + :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, + :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. + """ return reduce_min_p.bind(operand, axes=tuple(axes)) -def _reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array: +def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array: + """Compute the bitwise OR of elements over one or more array axes. + + Args: + operand: array over which to compute the reduction. Must have boolean + or integer dtype. + axes: sequence of zero or more unique integers specifying the axes over + which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``. + + Returns: + An array of the same dtype as ``operand``, with shape corresponding + to the dimensions of ``operand.shape`` with ``axes`` removed. + + See also: + - :func:`jax.numpy.bitwise_or.reduce`: more flexible NumPy-style logical + reduction API, built around :func:`jax.lax.reduce_or`. + - Other low-level :mod:`jax.lax` reduction operators: + :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, + :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_xor`. + """ return reduce_or_p.bind(operand, axes=tuple(axes)) -def _reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array: +def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array: + """Compute the bitwise AND of elements over one or more array axes. + + Args: + operand: array over which to compute the reduction. 
Must have boolean + or integer dtype. + axes: sequence of zero or more unique integers specifying the axes over + which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``. + + Returns: + An array of the same dtype as ``operand``, with shape corresponding + to the dimensions of ``operand.shape`` with ``axes`` removed. + + See also: + - :func:`jax.numpy.bitwise_and.reduce`: more flexible NumPy-style logical + reduction API, built around :func:`jax.lax.reduce_and`. + - Other low-level :mod:`jax.lax` reduction operators: + :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, + :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`. + """ return reduce_and_p.bind(operand, axes=tuple(axes)) -def _reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array: +def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array: + """Compute the bitwise XOR of elements over one or more array axes. + + Args: + operand: array over which to compute the reduction. Must have boolean + or integer dtype. + axes: sequence of zero or more unique integers specifying the axes over + which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``. + + Returns: + An array of the same dtype as ``operand``, with shape corresponding + to the dimensions of ``operand.shape`` with ``axes`` removed. + + See also: + - :func:`jax.numpy.bitwise_xor.reduce`: more flexible NumPy-style logical + reduction API, built around :func:`jax.lax.reduce_xor`. + - Other low-level :mod:`jax.lax` reduction operators: + :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, + :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`. + """ return reduce_xor_p.bind(operand, axes=tuple(axes)) @overload @@ -3014,11 +3153,11 @@ def _unbroadcast(aval, x): return x assert not aval.shape or len(x_shape) == len(aval.shape) if not aval.shape: - return _reduce_sum(x, list(range(len(x_shape)))) + return reduce_sum(x, list(range(len(x_shape)))) else: dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.definitely_equal(a, b)] if config.enable_checks.value: assert all(aval.shape[i] == 1 for i in dims) - return reshape(_reduce_sum(x, dims), aval.shape) + return reshape(reduce_sum(x, dims), aval.shape) def _maybe_broadcast(target_shape, x): x_shape = np.shape(x) @@ -4959,7 +5098,7 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape, if core.definitely_equal(s, 1)] bdims = tuple(np.delete(broadcast_dimensions, unit_dims)) axes = tuple(np.delete(range(len(shape)), bdims)) - return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] + + return ([expand_dims(reduce_sum(ct, axes), unit_dims)] + [None] * len(dyn_shape)) def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape, @@ -5405,7 +5544,7 @@ def _pad_transpose(t, operand, padding_value, *, padding_config): t_padv = ad_util.Zero(padding_value.aval) if ad.is_undefined_primal(padding_value) else None else: lo, hi, interior = util.unzip3(padding_config) - total = lambda x: _reduce_sum(x, list(range(t.ndim))) + total = lambda x: reduce_sum(x, list(range(t.ndim))) def t_op(): unpad_config = safe_zip(np.negative(lo), np.negative(hi), @@ -6177,7 +6316,7 @@ reduce_sum_p = standard_primitive( 'reduce_sum', sharding_rule=_reduce_op_sharding_rule) ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule) batching.defreducer(reduce_sum_p, _get_sum_identity) -pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum, 
+pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum, _get_sum_identity) batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule @@ -6192,7 +6331,7 @@ reduce_prod_p = standard_primitive( 'reduce_prod', sharding_rule=_reduce_op_sharding_rule) ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule batching.defreducer(reduce_prod_p, _get_prod_identity) -pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, _reduce_prod, +pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod, _get_prod_identity) @@ -6203,8 +6342,8 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): shape = [1 if i in axes else d for i, d in enumerate(operand.shape)] location_indicators = convert_element_type( _eq_meet(operand, reshape(ans, shape)), g.dtype) - counts = _reduce_sum(location_indicators, axes) - return div(_reduce_sum(mul(g, location_indicators), axes), counts) + counts = reduce_sum(location_indicators, axes) + return div(reduce_sum(mul(g, location_indicators), axes), counts) reduce_max_p = standard_primitive( @@ -6212,7 +6351,7 @@ reduce_max_p = standard_primitive( sharding_rule=_reduce_op_sharding_rule) ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_max_p, _get_max_identity) -pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max, +pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max, _get_max_identity) batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule @@ -6222,7 +6361,7 @@ reduce_min_p = standard_primitive( sharding_rule=_reduce_op_sharding_rule) ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_min_p, _get_min_identity) -pe.padding_rules[reduce_min_p] = partial(_reducer_padding, _reduce_min, +pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min, _get_min_identity) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index d8e7431e9..5abd55036 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -130,8 +130,8 @@ def psum(x, axis_name, *, axis_index_groups=None): def pos_reduce(x): if not pos_axes: return x - return lax._reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0)) - for axis in pos_axes]) + return lax.reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0)) + for axis in pos_axes]) if axis_index_groups is not None: assert not pos_axes size = len(axis_index_groups[0]) @@ -834,10 +834,10 @@ def _psum_transpose_rule(cts, *args, axes, axis_index_groups): psum_p = core.Primitive('psum') psum_p.multiple_results = True -psum_p.def_impl(partial(_allreduce_impl, psum_p, lax._reduce_sum)) +psum_p.def_impl(partial(_allreduce_impl, psum_p, lax.reduce_sum)) psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) mlir.register_lowering( - psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum)) + psum_p, partial(_allreduce_lowering, lax.add_p, lax.reduce_sum)) ad.deflinear2(psum_p, _psum_transpose_rule) batching.fancy_primitive_batchers[psum_p] = \ partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v) @@ -845,10 +845,10 @@ batching.skippable_batchers[psum_p] = partial(_names_in_param, 'axes') pmax_p = core.Primitive('pmax') pmax_p.multiple_results = True -pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax._reduce_max)) +pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax.reduce_max)) pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) mlir.register_lowering( - pmax_p, 
partial(_allreduce_lowering, lax.max_p, lax._reduce_max)) + pmax_p, partial(_allreduce_lowering, lax.max_p, lax.reduce_max)) batching.fancy_primitive_batchers[pmax_p] = \ partial(_batched_reduction_collective, pmax_p, lambda v, axis_size: v) batching.skippable_batchers[pmax_p] = partial(_names_in_param, 'axes') @@ -856,10 +856,10 @@ batching.skippable_batchers[pmax_p] = partial(_names_in_param, 'axes') pmin_p = core.Primitive('pmin') pmin_p.multiple_results = True -pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax._reduce_min)) +pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax.reduce_min)) pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) mlir.register_lowering( - pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min)) + pmin_p, partial(_allreduce_lowering, lax.min_p, lax.reduce_min)) batching.fancy_primitive_batchers[pmin_p] = \ partial(_batched_reduction_collective, pmin_p, lambda v, axis_size: v) batching.skippable_batchers[pmin_p] = partial(_names_in_param, 'axes') diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 3497fd134..9037bcab6 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1910,7 +1910,7 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes, mask = lax.bitwise_and( lax.ge(indices, np.int64(0)), lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims))))) - mask = lax._reduce_and(mask, [num_batch_dims]) + mask = lax.reduce_and(mask, [num_batch_dims]) # Computes the output shape and the positions of the batch dimensions in the # output diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 31515d4e1..0dc30f845 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -644,18 +644,18 @@ def _gen_reduce_choose_taylor_rule(chooser_fun): location_indicators = lax.convert_element_type( lax_internal._eq_meet(operand, lax.reshape(primal_out, shape)), primal_dtype) - counts = lax_internal._reduce_sum(location_indicators, axes) + counts = lax.reduce_sum(location_indicators, axes) def _reduce_chooser_taylor_rule(g): return lax.div( - lax_internal._reduce_sum(lax.mul(g, location_indicators), axes), + lax.reduce_sum(lax.mul(g, location_indicators), axes), counts) series_out = [_reduce_chooser_taylor_rule(g) for g in gs] return primal_out, series_out return chooser_taylor_rule jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule( - lax_internal._reduce_max) + lax.reduce_max) jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule( - lax_internal._reduce_min) + lax.reduce_min) def _abs_taylor_rule(x, series_in, **params): x, = x diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index de3f293b2..a26d15c14 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -162,15 +162,22 @@ from jax._src.lax.lax import ( real_p as real_p, reciprocal as reciprocal, reduce as reduce, + reduce_and as reduce_and, reduce_and_p as reduce_and_p, + reduce_max as reduce_max, reduce_max_p as reduce_max_p, + reduce_min as reduce_min, reduce_min_p as reduce_min_p, + reduce_or as reduce_or, reduce_or_p as reduce_or_p, reduce_p as reduce_p, reduce_precision as reduce_precision, reduce_precision_p as reduce_precision_p, + reduce_prod as reduce_prod, reduce_prod_p as reduce_prod_p, + reduce_sum as reduce_sum, reduce_sum_p as reduce_sum_p, + reduce_xor as reduce_xor, reduce_xor_p as reduce_xor_p, rem as rem, rem_p as rem_p, diff --git a/tests/core_test.py b/tests/core_test.py index da7fdc056..5fc906bd3 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ 
-34,7 +34,6 @@ from jax._src import util from jax._src import test_util as jtu from jax._src.core import ShapedArray, DBIdx from jax._src.interpreters import partial_eval as pe -from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow config.parse_flags_with_absl() @@ -651,7 +650,7 @@ class DynamicShapesTest(jtu.JaxTestCase): def f(x, y): z = lax.mul(x, y) w = lax.sin(z) - u = lax_internal._reduce_sum(w, [0]) + u = lax.reduce_sum(w, [0]) return (u,) jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( diff --git a/tests/filecheck/array.filecheck.py b/tests/filecheck/array.filecheck.py index 4305a450d..0ee237ed6 100644 --- a/tests/filecheck/array.filecheck.py +++ b/tests/filecheck/array.filecheck.py @@ -23,8 +23,6 @@ import jax from jax import lax import numpy as np -from jax._src.lax import lax as lax_internal - from jax.tests.filecheck.jax_filecheck_helpers import print_ir jax.config.update("jax_enable_x64", True) @@ -56,12 +54,12 @@ def main(_): partial(lax.pad, padding_value=np.int32(7), padding_config=((2, 3, 4), (4, 5, 6)))) - # CHECK-LABEL: TEST: _reduce_sum int32[2,3,7] + # CHECK-LABEL: TEST: reduce_sum int32[2,3,7] # CHECK: hlo.reduce # CHECK: hlo.add # CHECK: tensor<3xi32> print_ir(np.empty([2, 3, 7], np.int32))( - partial(lax_internal._reduce_sum, axes=(0, 2))) + partial(lax.reduce_sum, axes=(0, 2))) # CHECK-LABEL: TEST: reshape int32[2,3,7] # CHECK: hlo.reshape diff --git a/tests/lax_test.py b/tests/lax_test.py index 291832388..ef73082ab 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -1941,6 +1941,42 @@ class LaxTest(jtu.JaxTestCase): jaxpr = jax.make_jaxpr(fun)(rng(shape, dtype)) self.assertEqual(jaxpr.eqns[0].primitive, primitive) + @jtu.sample_product( + [ + dict( + op=rec.op, + reference_op=rec.reference_op, + dtype=dtype, + ) + for rec in lax_test_util.lax_named_reduce_ops() + for dtype in rec.dtypes + ], + [ + dict(shape=shape, dims=dims) + for shape, dims in [ + [(3, 4, 5), (0,)], + [(3, 4, 5), (1, 2)], + [(3, 4, 5), (0, 2)], + [(3, 4, 5), (0, 1, 2)], + ] + ], + ) + def testNamedReduceOperators(self, op, reference_op, dtype, shape, dims): + rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer) + else jtu.rand_small) + rng = rng_factory(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + def lax_fun(operand): + return op(operand, dims) + + def reference_fun(operand): + return reference_op(operand, dims).astype(dtype) + + self._CompileAndCheck(lax_fun, args_maker) + self._CheckAgainstNumpy(reference_fun, lax_fun, args_maker) + + @jtu.sample_product( op=["add", "mul"], op_namespace=[lax, operator],
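As a quick illustration (not part of the patch above): the newly exported functions are thin public wrappers that simply bind the existing reduction primitives (reduce_sum_p, reduce_prod_p, and so on) with a tuple of axes, so their autodiff and batching behavior comes from the same primitive rules as before. The snippet below is an assumed usage sketch; the shapes and values are arbitrary, and the widened dtype noted for jnp.sum depends on the default integer precision of the running configuration.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(24, dtype=jnp.int8).reshape(2, 3, 4)

# lax.reduce_sum keeps the operand dtype (int8 here); jnp.sum instead
# accumulates in the wider default integer dtype.
lax_sum = lax.reduce_sum(x, axes=(0, 2))     # shape (3,), dtype int8
jnp_sum = jnp.sum(x, axis=(0, 2))            # shape (3,), wider int dtype

# The bitwise reducers accept boolean or integer operands.
mask = x > 5
any_hit = lax.reduce_or(mask, axes=(1, 2))   # shape (2,), dtype bool
all_hit = lax.reduce_and(mask, axes=(1, 2))  # shape (2,), dtype bool

# reduce_prod, reduce_max, and reduce_min share the same (operand, axes) signature.
col_prod = lax.reduce_prod(x.astype(jnp.float32), axes=(0,))  # shape (3, 4)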