Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)

Add public APIs for jax.lax monoidal reductions

commit e389b707ba (parent d0b6c677b0)
@@ -21,6 +21,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     decorator to support customizing the behavior of opaque functions under
     JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
     details.
+  * Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`,
+    {func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
+    {func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
 
 * Changes
   * `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
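A minimal usage sketch of the newly public APIs (illustrative values, not part of the diff):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(12, dtype=jnp.int32).reshape(3, 4)
lax.reduce_sum(x, (0,))       # shape (4,), dtype int32
lax.reduce_max(x, (0, 1))     # scalar maximum, still int32
lax.reduce_and(x > 0, (1,))   # per-row logical AND over booleans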
@@ -130,8 +130,15 @@ Operators
    real
    reciprocal
    reduce
+   reduce_and
+   reduce_max
+   reduce_min
+   reduce_or
    reduce_precision
+   reduce_prod
+   reduce_sum
    reduce_window
+   reduce_xor
    rem
    reshape
    rev
@@ -49,8 +49,11 @@ uint_dtypes = test_util.dtypes.all_unsigned
 bool_dtypes = test_util.dtypes.boolean
 
 default_dtypes = float_dtypes + int_dtypes
+number_dtypes = (
+    float_dtypes + complex_dtypes + int_dtypes + uint_dtypes
+)
 all_dtypes = (
-    float_dtypes + complex_dtypes + int_dtypes + uint_dtypes + bool_dtypes
+    number_dtypes + bool_dtypes
 )
 python_scalar_types = [bool, int, float, complex]
 
@@ -151,6 +154,25 @@ def lax_reduce_ops():
   ]
 
 
+NamedReducerOpRecord = collections.namedtuple(
+    "NamedReducerOpRecord", ["op", "reference_op", "dtypes"]
+)
+
+def lax_named_reduce_ops():
+  return [
+      NamedReducerOpRecord(lax.reduce_sum, np.sum, number_dtypes),
+      NamedReducerOpRecord(lax.reduce_prod, np.prod, number_dtypes),
+      NamedReducerOpRecord(lax.reduce_max, np.max, all_dtypes),
+      NamedReducerOpRecord(lax.reduce_min, np.min, all_dtypes),
+      NamedReducerOpRecord(lax.reduce_and, np.bitwise_and.reduce,
+                           bool_dtypes + int_dtypes + uint_dtypes),
+      NamedReducerOpRecord(lax.reduce_or, np.bitwise_or.reduce,
+                           bool_dtypes + int_dtypes + uint_dtypes),
+      NamedReducerOpRecord(lax.reduce_xor, np.bitwise_xor.reduce,
+                           bool_dtypes + int_dtypes + uint_dtypes),
+  ]
+
+
 def lax_ops():
   return [
       op_record(
@@ -1752,7 +1752,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
     # Pred can be batched
     pred = core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
     if batched:
-      pred = lax._reduce_or(pred, tuple(range(len(pred_aval.shape))))
+      pred = lax.reduce_or(pred, tuple(range(len(pred_aval.shape))))
     return pred
   def body(args):
     return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args))
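For context, this reduction collapses a batched while-loop predicate into a single boolean: the loop keeps running while any batch element's condition still holds. A standalone sketch (not part of the diff):

import jax.numpy as jnp
from jax import lax

pred = jnp.array([True, False, True])          # per-batch-element conditions
lax.reduce_or(pred, tuple(range(pred.ndim)))   # -> True: some element is still active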
@@ -2177,19 +2177,19 @@ def _get_monoid_reducer(monoid_op: Callable,
   # allow bitwise reductions for boolean and integer types
   _is_intlike = dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer)
   if monoid_op is add:
-    return _reduce_sum if np.equal(val, 0) else None
+    return reduce_sum if np.equal(val, 0) else None
   elif monoid_op is mul:
-    return _reduce_prod if np.equal(val, 1) else None
+    return reduce_prod if np.equal(val, 1) else None
   elif monoid_op is bitwise_or and _is_intlike:
-    return _reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None
+    return reduce_or if np.equal(val, _get_bitwise_or_identity(dtype)) else None
   elif monoid_op is bitwise_and and _is_intlike:
-    return _reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None
+    return reduce_and if np.equal(val, _get_bitwise_and_identity(dtype)) else None
   elif monoid_op is bitwise_xor and _is_intlike:
-    return _reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None
+    return reduce_xor if np.equal(val, _get_bitwise_or_identity(dtype)) else None
   elif monoid_op is max:
-    return _reduce_max if np.equal(val, _get_max_identity(dtype)) else None
+    return reduce_max if np.equal(val, _get_max_identity(dtype)) else None
   elif monoid_op is min:
-    return _reduce_min if np.equal(val, _get_min_identity(dtype)) else None
+    return reduce_min if np.equal(val, _get_min_identity(dtype)) else None
   return None
 
 def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray:
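_get_monoid_reducer is the dispatch that lets a generic lax.reduce with a recognized monoid computation and a matching identity init value lower to one of these named primitives. A sketch of the observable effect (illustrative, assuming a standard JAX install):

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.ones((3, 4), jnp.float32)
# init 0.0 is the identity of add, so this specializes to reduce_sum:
print(jax.make_jaxpr(lambda a: lax.reduce(a, 0.0, lax.add, (0,)))(x))
# A non-identity init (e.g. 1.0 with add) falls back to the generic reduce primitive.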
@@ -2226,25 +2226,164 @@ def _get_min_identity(dtype: DTypeLike) -> np.ndarray:
   else:
     raise ValueError(f"Unsupported dtype for min: {dtype}")
 
-def _reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array:
+def reduce_sum(operand: ArrayLike, axes: Sequence[int]) -> Array:
+  """Compute the sum of elements over one or more array axes.
+
+  Args:
+    operand: array over which to sum. Must have numerical dtype.
+    axes: sequence of zero or more unique integers specifying the axes over
+      which to sum. Each entry must satisfy ``0 <= axis < operand.ndim``.
+
+  Returns:
+    An array of the same dtype as ``operand``, with shape corresponding
+    to the dimensions of ``operand.shape`` with ``axes`` removed.
+
+  Notes:
+    Unlike :func:`jax.numpy.sum`, :func:`jax.lax.reduce_sum` does not upcast
+    narrow-width types for accumulation, so sums of 8-bit or 16-bit types
+    may be subject to rounding errors.
+
+  See also:
+    - :func:`jax.numpy.sum`: more flexible NumPy-style summation API, built
+      around :func:`jax.lax.reduce_sum`.
+    - Other low-level :mod:`jax.lax` reduction operators:
+      :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`,
+      :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
+  """
   return reduce_sum_p.bind(operand, axes=tuple(axes))
 
-def _reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array:
+def reduce_prod(operand: ArrayLike, axes: Sequence[int]) -> Array:
+  """Compute the product of elements over one or more array axes.
+
+  Args:
+    operand: array over which to compute the product. Must have numerical dtype.
+    axes: sequence of zero or more unique integers specifying the axes over
+      which to compute the product. Each entry must satisfy ``0 <= axis < operand.ndim``.
+
+  Returns:
+    An array of the same dtype as ``operand``, with shape corresponding
+    to the dimensions of ``operand.shape`` with ``axes`` removed.
+
+  Notes:
+    Unlike :func:`jax.numpy.prod`, :func:`jax.lax.reduce_prod` does not upcast
+    narrow-width types for accumulation, so products of 8-bit or 16-bit types
+    may be subject to rounding errors.
+
+  See also:
+    - :func:`jax.numpy.prod`: more flexible NumPy-style product API, built
+      around :func:`jax.lax.reduce_prod`.
+    - Other low-level :mod:`jax.lax` reduction operators:
+      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_max`, :func:`jax.lax.reduce_min`,
+      :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
+  """
   return reduce_prod_p.bind(operand, axes=tuple(axes))
 
-def _reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array:
+def reduce_max(operand: ArrayLike, axes: Sequence[int]) -> Array:
+  """Compute the maximum of elements over one or more array axes.
+
+  Args:
+    operand: array over which to compute the maximum.
+    axes: sequence of zero or more unique integers specifying the axes over
+      which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.
+
+  Returns:
+    An array of the same dtype as ``operand``, with shape corresponding
+    to the dimensions of ``operand.shape`` with ``axes`` removed.
+
+  See also:
+    - :func:`jax.numpy.max`: more flexible NumPy-style max-reduction API, built
+      around :func:`jax.lax.reduce_max`.
+    - Other low-level :mod:`jax.lax` reduction operators:
+      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_min`,
+      :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
+  """
   return reduce_max_p.bind(operand, axes=tuple(axes))
 
-def _reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array:
+def reduce_min(operand: ArrayLike, axes: Sequence[int]) -> Array:
+  """Compute the minimum of elements over one or more array axes.
+
+  Args:
+    operand: array over which to compute the minimum.
+    axes: sequence of zero or more unique integers specifying the axes over
+      which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.
+
+  Returns:
+    An array of the same dtype as ``operand``, with shape corresponding
+    to the dimensions of ``operand.shape`` with ``axes`` removed.
+
+  See also:
+    - :func:`jax.numpy.min`: more flexible NumPy-style min-reduction API, built
+      around :func:`jax.lax.reduce_min`.
+    - Other low-level :mod:`jax.lax` reduction operators:
+      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
+      :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
+  """
   return reduce_min_p.bind(operand, axes=tuple(axes))
 
-def _reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array:
+def reduce_or(operand: ArrayLike, axes: Sequence[int]) -> Array:
+  """Compute the bitwise OR of elements over one or more array axes.
+
+  Args:
+    operand: array over which to compute the reduction. Must have boolean
+      or integer dtype.
+    axes: sequence of zero or more unique integers specifying the axes over
+      which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.
+
+  Returns:
+    An array of the same dtype as ``operand``, with shape corresponding
+    to the dimensions of ``operand.shape`` with ``axes`` removed.
+
+  See also:
+    - :func:`jax.numpy.bitwise_or.reduce`: more flexible NumPy-style logical
+      reduction API, built around :func:`jax.lax.reduce_or`.
+    - Other low-level :mod:`jax.lax` reduction operators:
+      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
+      :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_xor`.
+  """
   return reduce_or_p.bind(operand, axes=tuple(axes))
 
-def _reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array:
+def reduce_and(operand: ArrayLike, axes: Sequence[int]) -> Array:
+  """Compute the bitwise AND of elements over one or more array axes.
+
+  Args:
+    operand: array over which to compute the reduction. Must have boolean
+      or integer dtype.
+    axes: sequence of zero or more unique integers specifying the axes over
+      which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.
+
+  Returns:
+    An array of the same dtype as ``operand``, with shape corresponding
+    to the dimensions of ``operand.shape`` with ``axes`` removed.
+
+  See also:
+    - :func:`jax.numpy.bitwise_and.reduce`: more flexible NumPy-style logical
+      reduction API, built around :func:`jax.lax.reduce_and`.
+    - Other low-level :mod:`jax.lax` reduction operators:
+      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
+      :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_or`, :func:`jax.lax.reduce_xor`.
+  """
   return reduce_and_p.bind(operand, axes=tuple(axes))
 
-def _reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array:
+def reduce_xor(operand: ArrayLike, axes: Sequence[int]) -> Array:
+  """Compute the bitwise XOR of elements over one or more array axes.
+
+  Args:
+    operand: array over which to compute the reduction. Must have boolean
+      or integer dtype.
+    axes: sequence of zero or more unique integers specifying the axes over
+      which to reduce. Each entry must satisfy ``0 <= axis < operand.ndim``.
+
+  Returns:
+    An array of the same dtype as ``operand``, with shape corresponding
+    to the dimensions of ``operand.shape`` with ``axes`` removed.
+
+  See also:
+    - :func:`jax.numpy.bitwise_xor.reduce`: more flexible NumPy-style logical
+      reduction API, built around :func:`jax.lax.reduce_xor`.
+    - Other low-level :mod:`jax.lax` reduction operators:
+      :func:`jax.lax.reduce_sum`, :func:`jax.lax.reduce_prod`, :func:`jax.lax.reduce_max`,
+      :func:`jax.lax.reduce_min`, :func:`jax.lax.reduce_and`, :func:`jax.lax.reduce_or`.
+  """
   return reduce_xor_p.bind(operand, axes=tuple(axes))
 
 @overload
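The dtype-preservation note in these docstrings is easy to demonstrate (illustrative sketch):

import jax.numpy as jnp
from jax import lax

x = jnp.ones(200, dtype=jnp.int8)
lax.reduce_sum(x, (0,))   # accumulates in int8: 200 wraps around to -56
jnp.sum(x)                # promotes to the default int dtype and returns 200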
@@ -3014,11 +3153,11 @@ def _unbroadcast(aval, x):
     return x
   assert not aval.shape or len(x_shape) == len(aval.shape)
   if not aval.shape:
-    return _reduce_sum(x, list(range(len(x_shape))))
+    return reduce_sum(x, list(range(len(x_shape))))
   else:
     dims = [i for i, (a, b) in enumerate(zip(x_shape, aval.shape)) if not core.definitely_equal(a, b)]
     if config.enable_checks.value: assert all(aval.shape[i] == 1 for i in dims)
-    return reshape(_reduce_sum(x, dims), aval.shape)
+    return reshape(reduce_sum(x, dims), aval.shape)
 
 def _maybe_broadcast(target_shape, x):
   x_shape = np.shape(x)
@@ -4959,7 +5098,7 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
                if core.definitely_equal(s, 1)]
   bdims = tuple(np.delete(broadcast_dimensions, unit_dims))
   axes = tuple(np.delete(range(len(shape)), bdims))
-  return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] +
+  return ([expand_dims(reduce_sum(ct, axes), unit_dims)] +
           [None] * len(dyn_shape))
 
 def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape,
@@ -5405,7 +5544,7 @@ def _pad_transpose(t, operand, padding_value, *, padding_config):
     t_padv = ad_util.Zero(padding_value.aval) if ad.is_undefined_primal(padding_value) else None
   else:
     lo, hi, interior = util.unzip3(padding_config)
-    total = lambda x: _reduce_sum(x, list(range(t.ndim)))
+    total = lambda x: reduce_sum(x, list(range(t.ndim)))
 
     def t_op():
       unpad_config = safe_zip(np.negative(lo), np.negative(hi),
@@ -6177,7 +6316,7 @@ reduce_sum_p = standard_primitive(
     'reduce_sum', sharding_rule=_reduce_op_sharding_rule)
 ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
 batching.defreducer(reduce_sum_p, _get_sum_identity)
-pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
+pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum,
                                          _get_sum_identity)
 batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule
 
@@ -6192,7 +6331,7 @@ reduce_prod_p = standard_primitive(
     'reduce_prod', sharding_rule=_reduce_op_sharding_rule)
 ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
 batching.defreducer(reduce_prod_p, _get_prod_identity)
-pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, _reduce_prod,
+pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod,
                                           _get_prod_identity)
 
 
@@ -6203,8 +6342,8 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
   shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
   location_indicators = convert_element_type(
       _eq_meet(operand, reshape(ans, shape)), g.dtype)
-  counts = _reduce_sum(location_indicators, axes)
-  return div(_reduce_sum(mul(g, location_indicators), axes), counts)
+  counts = reduce_sum(location_indicators, axes)
+  return div(reduce_sum(mul(g, location_indicators), axes), counts)
 
 
 reduce_max_p = standard_primitive(
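_reduce_chooser_jvp_rule spreads the incoming tangent evenly across tied extrema, which is visible through jax.grad (illustrative):

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 3.0, 3.0])
print(jax.grad(jnp.max)(x))   # [0.  0.5 0.5] -- tied maxima share the gradient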
@@ -6212,7 +6351,7 @@ reduce_max_p = standard_primitive(
     sharding_rule=_reduce_op_sharding_rule)
 ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
 batching.defreducer(reduce_max_p, _get_max_identity)
-pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
+pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max,
                                          _get_max_identity)
 batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule
 
@@ -6222,7 +6361,7 @@ reduce_min_p = standard_primitive(
     sharding_rule=_reduce_op_sharding_rule)
 ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
 batching.defreducer(reduce_min_p, _get_min_identity)
-pe.padding_rules[reduce_min_p] = partial(_reducer_padding, _reduce_min,
+pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min,
                                          _get_min_identity)
 
 
@@ -130,8 +130,8 @@ def psum(x, axis_name, *, axis_index_groups=None):
   def pos_reduce(x):
     if not pos_axes:
       return x
-    return lax._reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0))
-                               for axis in pos_axes])
+    return lax.reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0))
+                              for axis in pos_axes])
   if axis_index_groups is not None:
     assert not pos_axes
     size = len(axis_index_groups[0])
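pos_reduce handles only the positional-axis portion of a psum; named axes still go through the collective. A standalone sketch of the named-axis path (not part of the diff):

import jax
import jax.numpy as jnp

f = jax.vmap(lambda v: jax.lax.psum(v, 'i'), axis_name='i')
print(f(jnp.arange(4.0)))   # [6. 6. 6. 6.]: every element sees the full sum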
@@ -834,10 +834,10 @@ def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
 
 psum_p = core.Primitive('psum')
 psum_p.multiple_results = True
-psum_p.def_impl(partial(_allreduce_impl, psum_p, lax._reduce_sum))
+psum_p.def_impl(partial(_allreduce_impl, psum_p, lax.reduce_sum))
 psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
 mlir.register_lowering(
-    psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
+    psum_p, partial(_allreduce_lowering, lax.add_p, lax.reduce_sum))
 ad.deflinear2(psum_p, _psum_transpose_rule)
 batching.fancy_primitive_batchers[psum_p] = \
     partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v)
@@ -845,10 +845,10 @@ batching.skippable_batchers[psum_p] = partial(_names_in_param, 'axes')
 
 pmax_p = core.Primitive('pmax')
 pmax_p.multiple_results = True
-pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax._reduce_max))
+pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax.reduce_max))
 pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
 mlir.register_lowering(
-    pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
+    pmax_p, partial(_allreduce_lowering, lax.max_p, lax.reduce_max))
 batching.fancy_primitive_batchers[pmax_p] = \
     partial(_batched_reduction_collective, pmax_p, lambda v, axis_size: v)
 batching.skippable_batchers[pmax_p] = partial(_names_in_param, 'axes')
@@ -856,10 +856,10 @@ batching.skippable_batchers[pmax_p] = partial(_names_in_param, 'axes')
 
 pmin_p = core.Primitive('pmin')
 pmin_p.multiple_results = True
-pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax._reduce_min))
+pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax.reduce_min))
 pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
 mlir.register_lowering(
-    pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
+    pmin_p, partial(_allreduce_lowering, lax.min_p, lax.reduce_min))
 batching.fancy_primitive_batchers[pmin_p] = \
     partial(_batched_reduction_collective, pmin_p, lambda v, axis_size: v)
 batching.skippable_batchers[pmin_p] = partial(_names_in_param, 'axes')
@@ -1910,7 +1910,7 @@ def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes,
   mask = lax.bitwise_and(
       lax.ge(indices, np.int64(0)),
       lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims)))))
-  mask = lax._reduce_and(mask, [num_batch_dims])
+  mask = lax.reduce_and(mask, [num_batch_dims])
 
   # Computes the output shape and the positions of the batch dimensions in the
   # output
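Here reduce_and collapses the per-component in-range checks into one validity bit per gather index. A sketch with hypothetical shapes (not part of the diff):

import jax.numpy as jnp
from jax import lax

indices = jnp.array([[0, 1], [2, 5]])          # (num_indices, index_depth)
upper_bound = jnp.array([3, 3])
in_range = (indices >= 0) & (indices <= upper_bound)
lax.reduce_and(in_range, (1,))                 # -> [ True False]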
@@ -644,18 +644,18 @@ def _gen_reduce_choose_taylor_rule(chooser_fun):
     location_indicators = lax.convert_element_type(
         lax_internal._eq_meet(operand, lax.reshape(primal_out, shape)),
         primal_dtype)
-    counts = lax_internal._reduce_sum(location_indicators, axes)
+    counts = lax.reduce_sum(location_indicators, axes)
     def _reduce_chooser_taylor_rule(g):
       return lax.div(
-          lax_internal._reduce_sum(lax.mul(g, location_indicators), axes),
+          lax.reduce_sum(lax.mul(g, location_indicators), axes),
           counts)
     series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
     return primal_out, series_out
   return chooser_taylor_rule
 jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(
-    lax_internal._reduce_max)
+    lax.reduce_max)
 jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(
-    lax_internal._reduce_min)
+    lax.reduce_min)
 
 def _abs_taylor_rule(x, series_in, **params):
   x, = x
@@ -162,15 +162,22 @@ from jax._src.lax.lax import (
   real_p as real_p,
   reciprocal as reciprocal,
   reduce as reduce,
+  reduce_and as reduce_and,
   reduce_and_p as reduce_and_p,
+  reduce_max as reduce_max,
   reduce_max_p as reduce_max_p,
+  reduce_min as reduce_min,
   reduce_min_p as reduce_min_p,
+  reduce_or as reduce_or,
   reduce_or_p as reduce_or_p,
   reduce_p as reduce_p,
   reduce_precision as reduce_precision,
   reduce_precision_p as reduce_precision_p,
+  reduce_prod as reduce_prod,
   reduce_prod_p as reduce_prod_p,
+  reduce_sum as reduce_sum,
   reduce_sum_p as reduce_sum_p,
+  reduce_xor as reduce_xor,
   reduce_xor_p as reduce_xor_p,
   rem as rem,
   rem_p as rem_p,
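Both the traceable function and its primitive are exported; one way to see the primitive a call lowers to (illustrative):

import jax
import jax.numpy as jnp
from jax import lax

jaxpr = jax.make_jaxpr(lambda x: lax.reduce_sum(x, (0,)))(jnp.ones((3, 2)))
assert jaxpr.eqns[0].primitive is lax.reduce_sum_p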
@@ -34,7 +34,6 @@ from jax._src import util
 from jax._src import test_util as jtu
 from jax._src.core import ShapedArray, DBIdx
 from jax._src.interpreters import partial_eval as pe
-from jax._src.lax import lax as lax_internal
 from jax._src.lax import control_flow as lax_control_flow
 
 config.parse_flags_with_absl()
@@ -651,7 +650,7 @@ class DynamicShapesTest(jtu.JaxTestCase):
     def f(x, y):
       z = lax.mul(x, y)
       w = lax.sin(z)
-      u = lax_internal._reduce_sum(w, [0])
+      u = lax.reduce_sum(w, [0])
       return (u,)
 
     jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
@@ -23,8 +23,6 @@ import jax
 from jax import lax
 import numpy as np
 
-from jax._src.lax import lax as lax_internal
-
 from jax.tests.filecheck.jax_filecheck_helpers import print_ir
 
 jax.config.update("jax_enable_x64", True)
@@ -56,12 +54,12 @@ def main(_):
   partial(lax.pad, padding_value=np.int32(7),
           padding_config=((2, 3, 4), (4, 5, 6))))
 
-  # CHECK-LABEL: TEST: _reduce_sum int32[2,3,7]
+  # CHECK-LABEL: TEST: reduce_sum int32[2,3,7]
   # CHECK: hlo.reduce
   # CHECK: hlo.add
   # CHECK: tensor<3xi32>
   print_ir(np.empty([2, 3, 7], np.int32))(
-      partial(lax_internal._reduce_sum, axes=(0, 2)))
+      partial(lax.reduce_sum, axes=(0, 2)))
 
   # CHECK-LABEL: TEST: reshape int32[2,3,7]
   # CHECK: hlo.reshape
@@ -1941,6 +1941,42 @@ class LaxTest(jtu.JaxTestCase):
     jaxpr = jax.make_jaxpr(fun)(rng(shape, dtype))
     self.assertEqual(jaxpr.eqns[0].primitive, primitive)
 
+  @jtu.sample_product(
+      [
+          dict(
+              op=rec.op,
+              reference_op=rec.reference_op,
+              dtype=dtype,
+          )
+          for rec in lax_test_util.lax_named_reduce_ops()
+          for dtype in rec.dtypes
+      ],
+      [
+          dict(shape=shape, dims=dims)
+          for shape, dims in [
+              [(3, 4, 5), (0,)],
+              [(3, 4, 5), (1, 2)],
+              [(3, 4, 5), (0, 2)],
+              [(3, 4, 5), (0, 1, 2)],
+          ]
+      ],
+  )
+  def testNamedReduceOperators(self, op, reference_op, dtype, shape, dims):
+    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
+                   else jtu.rand_small)
+    rng = rng_factory(self.rng())
+    args_maker = lambda: [rng(shape, dtype)]
+
+    def lax_fun(operand):
+      return op(operand, dims)
+
+    def reference_fun(operand):
+      return reference_op(operand, dims).astype(dtype)
+
+    self._CompileAndCheck(lax_fun, args_maker)
+    self._CheckAgainstNumpy(reference_fun, lax_fun, args_maker)
+
+
   @jtu.sample_product(
       op=["add", "mul"],
       op_namespace=[lax, operator],
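The same lax-vs-NumPy correspondence the test exercises can also be checked directly outside the harness (a sketch):

import numpy as np
import jax.numpy as jnp
from jax import lax

x = np.arange(24, dtype=np.int32).reshape(2, 3, 4)
np.testing.assert_array_equal(
    lax.reduce_xor(jnp.asarray(x), (0, 2)),
    np.bitwise_xor.reduce(x, axis=(0, 2)))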