Add support for axis names in jnp.{sum,min,max}

As with `jnp.einsum`, whenever we encounter an extension of the
positional NumPy API (for reductions, that means a non-integer axis was
specified), we reroute the call to a parallel primitive instead of the
standard lax reductions.

Note that this makes the parallel primitives implement a strict superset
of the functionality of the lax reductions, so in the future (when we
decide that we want axes to be truly first class) we can always swap the
lax implementation out for the parallel version. For now, it makes sense
to keep them separate for ease of prototyping.
Adam Paszke 2021-01-27 16:48:51 +00:00
parent baf6ed11cf
commit f86bf12b5a
6 changed files with 214 additions and 82 deletions
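
As a concrete illustration of the rerouting described above, here is a minimal
usage sketch (adapted from the new test added below, and assuming the
experimental `xmap` API): inside an `xmap`, `jnp.sum` can mix a named axis with
positional ones, and the named part of the reduction is lowered to `psum`.

    import numpy as np
    import jax.numpy as jnp
    from jax.experimental.maps import xmap

    x = np.random.randn(2, 5, 6)
    # Axis 0 of x becomes the named axis 'i'; jnp.sum then reduces over the
    # named axis 'i' and the remaining positional axis 0 in a single call.
    f = xmap(lambda v: jnp.sum(v, axis=('i', 0)),
             in_axes={0: 'i'}, out_axes={})
    print(f(x).shape)  # expected: (6,)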

View File

@@ -104,9 +104,9 @@ inline.
The ``reduce_sum`` primitive has named parameters ``axes`` and ``input_shape``, in
addition to the operand ``e``.
Note that even though execution of a program that calls into JAX builds a jaxpr,
Python-level control-flow and Python-level functions execute normally.
This means that just because a Python program contains functions and control-flow,
the resulting jaxpr does not have to contain control-flow or higher-order features.
For example, when tracing the function ``func3`` JAX will inline the call to
@@ -445,8 +445,8 @@ captured using the ``xla_pmap`` primitive. Consider this example
d = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
e = add c d
f = psum[ axis_index_groups=None
axis_name=('rows',) ] b
f = psum[ axes=('rows',)
axis_index_groups=None ] b
g = div e f
in (g,) }
devices=None
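
To regenerate a jaxpr like the one above and see the renamed parameter, a small
sketch (assuming at least one local device; the exact printed form may differ):

    import numpy as np
    import jax

    f = jax.pmap(lambda x: x / jax.lax.psum(x, 'rows'), axis_name='rows')
    print(jax.make_jaxpr(f)(np.ones((1, 3))))
    # the psum equation should now show axes=('rows',) rather than axis_name=('rows',)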

View File

@@ -32,7 +32,7 @@ from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax._src.util import partial, unzip2, prod
from jax._src.util import partial, unzip2, prod, canonicalize_axis, safe_map
from jax.lib import xla_client as xc
from jax.lib import xla_bridge as xb
from jax.config import config
@@ -40,6 +40,8 @@ from jax._src.numpy import lax_numpy
xops = xc.ops
unsafe_map, map = map, safe_map # type: ignore
### parallel traceables
@@ -77,11 +79,13 @@ def psum(x, axis_name, *, axis_index_groups=None):
"""
if not isinstance(axis_name, (tuple, list)):
axis_name = (axis_name,)
if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None:
raise ValueError("axis_index_groups only supported for sums over just named axes")
_validate_axis_index_groups(axis_index_groups)
leaves, treedef = tree_util.tree_flatten(x)
leaves = [lax.convert_element_type(l, np.int32)
if dtypes.dtype(l) == np.bool_ else l for l in leaves]
out_flat = psum_p.bind(*leaves, axis_name=axis_name,
out_flat = psum_p.bind(*leaves, axes=axis_name,
axis_index_groups=axis_index_groups)
return tree_util.tree_unflatten(treedef, out_flat)
@@ -139,9 +143,11 @@ def pmax(x, axis_name, *, axis_index_groups=None):
"""
if not isinstance(axis_name, (tuple, list)):
axis_name = (axis_name,)
if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None:
raise ValueError("axis_index_groups only supported for sums over just named axes")
_validate_axis_index_groups(axis_index_groups)
leaves, treedef = tree_util.tree_flatten(x)
out_flat = pmax_p.bind(*leaves, axis_name=axis_name,
out_flat = pmax_p.bind(*leaves, axes=axis_name,
axis_index_groups=axis_index_groups)
return tree_util.tree_unflatten(treedef, out_flat)
@@ -166,9 +172,11 @@ def pmin(x, axis_name, *, axis_index_groups=None):
"""
if not isinstance(axis_name, (tuple, list)):
axis_name = (axis_name,)
if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None:
raise ValueError("axis_index_groups only supported for sums over just named axes")
_validate_axis_index_groups(axis_index_groups)
leaves, treedef = tree_util.tree_flatten(x)
out_flat = pmin_p.bind(*leaves, axis_name=axis_name,
out_flat = pmin_p.bind(*leaves, axes=axis_name,
axis_index_groups=axis_index_groups)
return tree_util.tree_unflatten(treedef, out_flat)
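
Besides the stricter axis_index_groups validation above, psum/pmax/pmin now also
accept plain integer (positional) axes mixed with names, which is what the jnp
reductions below rely on. A hedged sketch under vmap (assuming
collectives-under-vmap support in this version):

    import numpy as np
    import jax

    def f(x):
      # a single psum over the named axis 'i' and the positional axis 0
      return jax.lax.psum(x, ('i', 0))

    out = jax.vmap(f, axis_name='i')(np.ones((4, 3)))
    # every entry of `out` should equal 4 * 3 = 12
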
@@ -484,31 +492,71 @@ class XeinsumSpecParser:
### parallel primitives
def _subst_all_names_in_axis_name(params: core.ParamDict, subst: core.AxisSubst) -> core.ParamDict:
axis_name = params['axis_name']
def _subst_all_names_in_param(
pname: str, params: core.ParamDict, subst: core.AxisSubst) -> core.ParamDict:
axis_name = params[pname]
if not isinstance(axis_name, (tuple, list)):
axis_name = (axis_name,)
return dict(params, axis_name=sum((subst(name) for name in axis_name), ()))
result = dict(params)
result[pname] = sum(((name,) if isinstance(name, int) else subst(name)
for name in axis_name),
())
return result
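# Illustrative expectation (with a hypothetical subst that splits 'i' into ('i0', 'i1')):
#   _subst_all_names_in_param('axes', {'axes': ('i', 0)}, subst)
# keeps the positional 0 as-is and expands only the name:
#   -> {'axes': ('i0', 'i1', 0)}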
# This is only used for collectives that do not include the vmapped axis name,
# which is why the rule is so simple.
def _collective_batcher(prim, args, dims, **params):
return prim.bind(*args, **params), dims if prim.multiple_results else dims[0]
def _batched_reduction_collective(
prim, if_mapped, if_unmapped, frame, vals_in, dims_in, axis_name,
axis_index_groups):
assert prim.multiple_results
assert frame.name in axis_name
def _reduction_with_positional_batcher(prim, vals_in, dims_in, axis_index_groups,
transform_unmapped, transform_mapped):
if axis_index_groups is not None:
raise NotImplementedError("axis_index_groups not supported in vmap collectives. "
"Please open a feature request!")
vals_out = [if_mapped(v, d) if d is not batching.not_mapped
else if_unmapped(v, frame.size) for v, d in zip(vals_in, dims_in)]
if len(axis_name) > 1:
remaining_axis_names = tuple(n for n in axis_name if n != frame.name)
vals_out = prim.bind(*vals_out, axis_name=remaining_axis_names,
axis_index_groups=None)
# TODO: Transpose all dims to 0, increment all axes
vals_in = [val if d is batching.not_mapped or d == 0 else _moveaxis(d, 0, val)
for val, d in zip(vals_in, dims_in)]
mapped_vals_in, unmapped_vals_in = partitioned_vals_in = [], []
mapped_idxs, unmapped_idxs = partitioned_idxs = [], []
for i, (val, d) in enumerate(zip(vals_in, dims_in)):
partitioned_vals_in[d is batching.not_mapped].append(val)
partitioned_idxs[d is batching.not_mapped].append(i)
vals_out = [None] * len(vals_in)
if unmapped_vals_in:
unmapped_axes, unmapped_vals_in = transform_unmapped(0, unmapped_vals_in)
unmapped_vals_out = prim.bind(*unmapped_vals_in, axes=unmapped_axes, axis_index_groups=None)
for i, val in zip(unmapped_idxs, unmapped_vals_out):
vals_out[i] = val
if mapped_vals_in:
mapped_axes, mapped_vals_in = transform_mapped(0, mapped_vals_in)
mapped_vals_out = prim.bind(*mapped_vals_in, axes=mapped_axes, axis_index_groups=None)
for i, val in zip(mapped_idxs, mapped_vals_out):
vals_out[i] = val
assert all(v is not None for v in vals_out)
return vals_out
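# A standalone illustration (not part of the change) of the partitioning idiom
# above: indexing a pair of lists with a boolean sends False (mapped) entries to
# slot 0 and True (unmapped) entries to slot 1.
#   mapped, unmapped = parts = [], []
#   for i, is_unmapped in enumerate([True, False, True]):
#     parts[is_unmapped].append(i)
#   # mapped == [1], unmapped == [0, 2]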
# This is only used for collectives that do not include the vmapped axis name,
# which is why the rule is so simple.
def _reduction_batcher(prim, vals_in, dims_in, *, axes, axis_index_groups):
if not any(isinstance(axis, int) for axis in axes):
return prim.bind(*vals_in, axes=axes, axis_index_groups=axis_index_groups), dims_in
vals_out = _reduction_with_positional_batcher(
prim, vals_in, dims_in, axis_index_groups,
lambda d, d_vals_in: (axes, d_vals_in),
lambda d, d_vals_in: (tuple(axis + (axis >= d) if isinstance(axis, int) else axis
for axis in axes),
d_vals_in))
return vals_out, dims_in
def _batched_reduction_collective(
prim, if_unmapped, frame, vals_in, dims_in, axes,
axis_index_groups):
assert prim.multiple_results
assert frame.name in axes
vals_out = _reduction_with_positional_batcher(
prim, vals_in, dims_in, axis_index_groups,
lambda d, d_vals_in: (tuple(axis for axis in axes if axis != frame.name),
[if_unmapped(v, frame.size) for v in d_vals_in]),
lambda d, d_vals_in: (tuple(axis + (axis >= d) if isinstance(axis, int) else
axis if axis != frame.name else
d
for axis in axes),
d_vals_in))
return vals_out, [batching.not_mapped] * len(vals_out)
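# A worked instance of the positional-axis shift used in both rules above: with
# the batch dimension inserted at position d = 0, axes (0, 2) become
#   tuple(axis + (axis >= 0) for axis in (0, 2)) == (1, 3)
# i.e. every positional axis at or after d moves up by one.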
def _replica_groups(axis_env, axis_name, axis_index_groups):
@@ -519,11 +567,31 @@ def _replica_groups(axis_env, axis_name, axis_index_groups):
for axis_index_group in axis_index_groups]
return replica_groups
def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups):
assert axis_index_groups is None
assert all(isinstance(axis, int) for axis in axes)
return [pos_reducer(arg, axes) for arg in args]
def _allreduce_abstract_eval(*args, axes, axis_index_groups):
pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
return [ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
arg.dtype)
for arg in args]
def _allreduce_translation_rule(prim, pos_prim, c, *args, axes, axis_index_groups,
axis_env, platform):
named_axes, positional_axes = axes_partition = [], []
for axis in axes:
axes_partition[isinstance(axis, int)].append(axis)
if positional_axes:
args = map(partial(xla.translations[pos_prim], c, axes=tuple(positional_axes)), args)
if not named_axes:
return xops.Tuple(c, args)
if platform in ("cpu", "tpu"):
return _notuple_allreduce_translation_rule(
prim, c, *args, axis_name=axis_name,
prim, c, *args, named_axes=named_axes,
axis_index_groups=axis_index_groups, axis_env=axis_env,
platform=platform)
@@ -537,7 +605,7 @@ def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
# The outputs, in the original argument order.
out = [None] * len(args)
replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
replica_groups = _replica_groups(axis_env, named_axes, axis_index_groups)
replica_groups_protos = xc.make_replica_groups(replica_groups)
for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
is_complex = dtypes.issubdtype(dtype, np.complexfloating)
@@ -562,11 +630,11 @@ def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
# TODO(b/155446630): An XLA:TPU optimization pass also doesn't support
# tuple all-reduce yet. Meanwhile, rely on deterministic compiler behavior.
def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env,
def _notuple_allreduce_translation_rule(prim, c, *args, named_axes, axis_env,
axis_index_groups, platform):
def all_reduce(x):
replica_groups_protos = xc.make_replica_groups(
_replica_groups(axis_env, axis_name, axis_index_groups))
_replica_groups(axis_env, named_axes, axis_index_groups))
scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
computation = xla.primitive_subcomputation(prim, scalar, scalar)
return xops.AllReduce(x, computation, replica_groups_protos, None, None)
@@ -581,64 +649,74 @@ def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env,
else all_reduce(x) for x in args]
return xops.Tuple(c, outs)
def _psum_transpose_rule(cts, *args, axis_name, axis_index_groups):
def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
if any(isinstance(axis, int) for axis in axes):
raise NotImplementedError
nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axis_name=axis_name,
nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axes=axes,
axis_index_groups=axis_index_groups)
return tree_util.tree_unflatten(treedef, nonzero_in_cts)
psum_p = core.Primitive('psum')
psum_p.multiple_results = True
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule, lax.add_p) # type: ignore
psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
psum_p.def_abstract_eval(_allreduce_abstract_eval)
xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule,
lax.add_p, lax.reduce_sum_p) # type: ignore
ad.deflinear2(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
batching.primitive_batchers[psum_p] = partial(_reduction_batcher, psum_p)
batching.collective_rules[psum_p] = \
partial(_batched_reduction_collective,
psum_p,
lambda v, d: v.sum(d),
lambda v, axis_size: axis_size * v)
core.axis_substitution_rules[psum_p] = _subst_all_names_in_axis_name
partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v)
core.axis_substitution_rules[psum_p] = partial(_subst_all_names_in_param, 'axes')
# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind
def psum_bind(*args, axis_name, axis_index_groups):
def psum_bind(*args, axes, axis_index_groups):
if all(not isinstance(x, core.Tracer) for x in args):
named_axes, pos_axes = axes_partition = [], []
for axis in axes:
axes_partition[isinstance(axis, int)].append(axis)
def pos_reduce(x):
if not pos_axes:
return x
return lax._reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0))
for axis in pos_axes])
if axis_index_groups is not None:
assert not pos_axes
size = len(axis_index_groups[0])
elif isinstance(axis_name, (list, tuple)):
size = prod([core.axis_frame(name).size for name in axis_name]) # type: ignore
else:
size = core.axis_frame(axis_name).size # type: ignore
return tuple(size * x for x in args)
size = prod([core.axis_frame(name).size for name in named_axes]) # type: ignore
return tuple(size * pos_reduce(x) for x in args)
return core.Primitive.bind(
psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups)
psum_p, *args, axes=axes, axis_index_groups=axis_index_groups)
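# Hedged illustration (not part of this file) of the eager path above: with no
# tracer arguments, psum folds to axis_size * positional_reduction at trace time.
#   import jax; import numpy as np
#   out = jax.vmap(lambda x: jax.lax.psum(1, 'i'), axis_name='i')(np.zeros(4))
#   # every entry of `out` should be 4, the size of the axis named 'i'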
pmax_p = core.Primitive('pmax')
pmax_p.multiple_results = True
pmax_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule, lax.max_p)
pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
pmax_p.def_abstract_eval(_allreduce_abstract_eval)
xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule,
lax.max_p, lax.reduce_max_p) # type: ignore
pxla.multi_host_supported_collectives.add(pmax_p)
batching.primitive_batchers[pmax_p] = partial(_collective_batcher, pmax_p)
batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p)
batching.collective_rules[pmax_p] = \
partial(_batched_reduction_collective, pmax_p,
lambda v, d: v.max(d), lambda v, axis_size: v)
core.axis_substitution_rules[pmax_p] = _subst_all_names_in_axis_name
partial(_batched_reduction_collective, pmax_p, lambda v, axis_size: v)
core.axis_substitution_rules[pmax_p] = partial(_subst_all_names_in_param, 'axes')
pmin_p = core.Primitive('pmin')
pmin_p.multiple_results = True
pmin_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
xla.parallel_translations[pmin_p] = partial(_allreduce_translation_rule, lax.min_p)
pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
pmin_p.def_abstract_eval(_allreduce_abstract_eval)
xla.parallel_translations[pmin_p] = partial(_allreduce_translation_rule,
lax.min_p, lax.reduce_min_p) # type: ignore
pxla.multi_host_supported_collectives.add(pmin_p)
batching.primitive_batchers[pmin_p] = partial(_collective_batcher, pmin_p)
batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p)
batching.collective_rules[pmin_p] = \
partial(_batched_reduction_collective, pmin_p,
lambda v, d: v.min(d), lambda v, axis_size: v)
core.axis_substitution_rules[pmin_p] = _subst_all_names_in_axis_name
partial(_batched_reduction_collective, pmin_p, lambda v, axis_size: v)
core.axis_substitution_rules[pmin_p] = partial(_subst_all_names_in_param, 'axes')
def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform):
@@ -670,6 +748,9 @@ def _ppermute_batcher(frame, vals_in, dims_in, axis_name, perm):
perm_indices[src] = dst
return lax_numpy.take(v, perm_indices, d), d
def _collective_batcher(prim, args, dims, **params):
return prim.bind(*args, **params), dims if prim.multiple_results else dims[0]
ppermute_p = core.Primitive('ppermute')
ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
@@ -677,7 +758,7 @@ xla.parallel_translations[ppermute_p] = _ppermute_translation_rule
pxla.multi_host_supported_collectives.add(ppermute_p)
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
batching.collective_rules[ppermute_p] = _ppermute_batcher
core.axis_substitution_rules[ppermute_p] = _subst_all_names_in_axis_name
core.axis_substitution_rules[ppermute_p] = partial(_subst_all_names_in_param, 'axis_name')
def _moveaxis(src, dst, x):
@@ -795,7 +876,7 @@ ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
pxla.multi_host_supported_collectives.add(all_to_all_p)
batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
batching.collective_rules[all_to_all_p] = _all_to_all_batched_collective
core.axis_substitution_rules[all_to_all_p] = _subst_all_names_in_axis_name
core.axis_substitution_rules[all_to_all_p] = partial(_subst_all_names_in_param, 'axis_name')
def _expand(dim, size, index, x):
@@ -947,7 +1028,7 @@ ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
pxla.multi_host_supported_collectives.add(all_gather_p)
batching.primitive_batchers[all_gather_p] = _all_gather_batcher
batching.collective_rules[all_gather_p] = _all_gather_batched_collective
core.axis_substitution_rules[all_gather_p] = _subst_all_names_in_axis_name
core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name')
def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
axis_pos = list(axis_env.names).index(axis_name)
@@ -963,7 +1044,7 @@ xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
axis_index_p.def_abstract_eval(
lambda *args, **params: ShapedArray((), np.int32))
pxla.multi_host_supported_collectives.add(axis_index_p)
core.axis_substitution_rules[axis_index_p] = _subst_all_names_in_axis_name
core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name')
# Axis index doesn't get any arguments, so that the default bind would have no
# way to call into a data-dependency based trace such as vmap. Each trace that
@@ -993,7 +1074,7 @@ batching.BatchTrace.process_axis_index = _process_axis_index # type: ignore
pdot_p = core.Primitive('pdot')
core.axis_substitution_rules[pdot_p] = _subst_all_names_in_axis_name
core.axis_substitution_rules[pdot_p] = partial(_subst_all_names_in_param, 'axis_name')
@pdot_p.def_impl
def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch):
@@ -1042,7 +1123,7 @@ def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch,
preferred_element_type=None)
if axis_name:
out_tup = xla.parallel_translations[psum_p](
c, local_out, axis_name=axis_name, axis_index_groups=None,
c, local_out, axes=axis_name, axis_index_groups=None,
axis_env=axis_env, platform=platform)
out, = xla.xla_destructure(c, out_tup)
else:

View File

@@ -1852,7 +1852,8 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
def _reduction(a, name, np_fun, op, init_val, has_identity=True,
preproc=None, bool_op=None, upcast_f16_for_computation=False,
axis=None, dtype=None, out=None, keepdims=False, initial=None, where_=None):
axis=None, dtype=None, out=None, keepdims=False, initial=None,
where_=None, parallel_reduce=None):
bool_op = bool_op or op
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
@@ -1869,7 +1870,7 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True,
a = a if isinstance(a, ndarray) else asarray(a)
a = preproc(a) if preproc else a
dims = _reduction_dims(a, axis)
pos_dims, dims = _reduction_dims(a, axis)
result_dtype = dtypes.canonicalize_dtype(dtype or _dtype(np_fun(np.ones((), dtype=_dtype(a)))))
if upcast_f16_for_computation and issubdtype(result_dtype, inexact):
computation_dtype = promote_types(result_dtype, float32)
@@ -1882,23 +1883,38 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True,
init_val = _reduction_init_val(a, init_val)
if where_ is not None:
a = where(where_, a, init_val)
result = lax.reduce(a, init_val, op, dims)
if pos_dims is not dims:
if parallel_reduce is None:
raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
result = parallel_reduce(a, dims)
else:
result = lax.reduce(a, init_val, op, dims)
if initial is not None:
result = op(_reduction_init_val(a, initial), result)
if keepdims:
result = expand_dims(result, dims)
result = expand_dims(result, pos_dims)
return lax.convert_element_type(result, dtype or result_dtype)
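# E.g. (hypothetically, under vmap/xmap with a named axis 'i'): jnp.sum(x, axis=('i', 0))
# yields pos_dims=(0,) and dims=(0, 'i') above, so the call is routed to
# parallel_reduce (lax.psum) instead of lax.reduce.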
def _canonicalize_axis_allow_named(x, rank):
try:
return _canonicalize_axis(x, rank)
except TypeError:
return x
def _reduction_dims(a, axis):
if axis is None:
return tuple(range(ndim(a)))
elif isinstance(axis, (np.ndarray, tuple, list)):
axis = tuple(_canonicalize_axis(x, ndim(a)) for x in axis)
if len(axis) != len(set(axis)):
raise ValueError(f"duplicate value in 'axis': {axis}")
return axis
return (tuple(range(ndim(a))),) * 2
elif not isinstance(axis, (np.ndarray, tuple, list)):
axis = (axis,)
canon_axis = tuple(_canonicalize_axis_allow_named(x, ndim(a))
for x in axis)
if len(canon_axis) != len(set(canon_axis)):
raise ValueError(f"duplicate value in 'axis': {axis}")
canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
if len(canon_pos_axis) != len(canon_axis):
return canon_pos_axis, canon_axis
else:
return (_canonicalize_axis(axis, ndim(a)),)
return canon_axis, canon_axis
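# Illustrative return values for a hypothetical 3-D input `a`:
#   _reduction_dims(a, None)     -> ((0, 1, 2), (0, 1, 2))
#   _reduction_dims(a, 1)        -> ((1,), (1,))
#   _reduction_dims(a, (0, 'i')) -> ((0,), (0, 'i'))   # names survive only in the second element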
def _reduction_init_val(a, init_val):
a_dtype = dtypes.canonicalize_dtype(_dtype(a))
@@ -1918,7 +1934,8 @@ def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None):
return _reduction(a, "sum", np.sum, lax.add, 0,
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where)
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.psum)
@_wraps(np.prod)
def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
@@ -1931,13 +1948,15 @@ def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None):
return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where)
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmax)
@_wraps(np.min)
def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None):
return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where)
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmin)
@_wraps(np.all)
def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,

View File

@@ -102,7 +102,7 @@ expit.defjvps(lambda g, ans, x: g * ans * (lax._const(ans, 1) - ans))
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
if b is not None:
a, b = jnp.broadcast_arrays(a, b)
dims = _reduction_dims(a, axis)
_, dims = _reduction_dims(a, axis)
dimadd = lambda x: lax.expand_dims(x, dims)
amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))

View File

@@ -1688,7 +1688,7 @@ def omnistaging_disabler() -> None:
# particular, this code gets hit when we write `axis_size = psum(1, 'i')`. We
# look up information in the dynamic axis env.
dynamic_axis_env = _thread_local_state.dynamic_axis_env
axis_name = params.pop('axis_name')
axis_name = params.pop('axes')
axis_index_groups = params.pop('axis_index_groups')
if axis_index_groups is not None:
shape = (len(axis_index_groups[0]),)

View File

@@ -338,6 +338,38 @@ class XMapTestSPMD(XMapTest):
jax.experimental.maps.make_xmap_callable.cache_clear()
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = self.old_lowering_flag
class NamedNumPyTest(jtu.JaxTestCase):
def setUp(self):
if jax.lib.version < (0, 1, 58):
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
@parameterized.named_parameters(
{"testcase_name": f"{reduction.__name__}_axes={axes}_i={mapped_axis}",
"reduction": reduction, "axes": axes, "mapped_axis": mapped_axis}
for reduction in (jnp.sum, jnp.max, jnp.min)
for axes in (0, 'i', (1,), ('i',), (0, 1), (0, 'i'), ('i', 0))
for mapped_axis in range(2))
@ignore_xmap_warning()
def testReductions(self, reduction, axes, mapped_axis):
axes_t = axes if isinstance(axes, tuple) else (axes,)
reduces_i = 'i' in axes_t
ref_red = partial(reduction,
axis=tuple(mapped_axis if a == 'i' else a + (a >= mapped_axis)
for a in axes_t))
mapped_axis_after_red = mapped_axis - sum(axis < mapped_axis if axis != 'i' else 0
for axis in axes_t)
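# Worked instance: for axes=(0, 1) and mapped_axis=1 the reference reduction
# acts on positional axes (0, 2) of the full (2, 5, 6) input, and the surviving
# named axis 'i' reappears at position 1 - 1 = 0 of the xmap output.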
xmap_red = xmap(lambda x: reduction(x, axes),
in_axes={mapped_axis: 'i'},
out_axes=({} if 'i' in axes_t else {mapped_axis_after_red: 'i'}))
rng = np.random.RandomState(0)
x = rng.randn(2, 5, 6)
self.assertAllClose(ref_red(x), xmap_red(x))
AxisIndices = Tuple[int, ...]
MatchedAxisIndices = Tuple[AxisIndices, AxisIndices]
AxisNames = Tuple[str, ...]