mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add support for axis names in jnp.{sum,min,max}
Similarly to `jnp.einsum`, whenever we encounter an extension to the positional NumPy API (in the case of reductions, the extension is whenever a non-integer axis is specified), we reroute the call to a parallel primitive instead of the standard lax reductions. Note that this makes the parallel primitives implement a strict subset of functionality of the lax reductions so in the future (when we decide that we want axes to be truly first class) we can always swap out the implementation for the parallel version. But, it makes sense to keep them separate for the ease of prototyping in the near future.
This commit is contained in:
parent
baf6ed11cf
commit
f86bf12b5a
@ -104,9 +104,9 @@ inline.
|
||||
The ``reduce_sum`` primitive has named parameters ``axes`` and ``input_shape``, in
|
||||
addition to the operand ``e``.
|
||||
|
||||
Note that even though execution of a program that calls into JAX builds a jaxpr,
|
||||
Note that even though execution of a program that calls into JAX builds a jaxpr,
|
||||
Python-level control-flow and Python-level functions execute normally.
|
||||
This means that just because a Python program contains functions and control-flow,
|
||||
This means that just because a Python program contains functions and control-flow,
|
||||
the resulting jaxpr does not have to contain control-flow or higher-order features.
|
||||
|
||||
For example, when tracing the function ``func3`` JAX will inline the call to
|
||||
@ -445,8 +445,8 @@ captured using the ``xla_pmap`` primitive. Consider this example
|
||||
d = broadcast_in_dim[ broadcast_dimensions=( )
|
||||
shape=(1,) ] 1.0
|
||||
e = add c d
|
||||
f = psum[ axis_index_groups=None
|
||||
axis_name=('rows',) ] b
|
||||
f = psum[ axes=('rows',)
|
||||
axis_index_groups=None ] b
|
||||
g = div e f
|
||||
in (g,) }
|
||||
devices=None
|
||||
|
@ -32,7 +32,7 @@ from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.util import partial, unzip2, prod
|
||||
from jax._src.util import partial, unzip2, prod, canonicalize_axis, safe_map
|
||||
from jax.lib import xla_client as xc
|
||||
from jax.lib import xla_bridge as xb
|
||||
from jax.config import config
|
||||
@ -40,6 +40,8 @@ from jax._src.numpy import lax_numpy
|
||||
|
||||
xops = xc.ops
|
||||
|
||||
unsafe_map, map = map, safe_map # type: ignore
|
||||
|
||||
|
||||
### parallel traceables
|
||||
|
||||
@ -77,11 +79,13 @@ def psum(x, axis_name, *, axis_index_groups=None):
|
||||
"""
|
||||
if not isinstance(axis_name, (tuple, list)):
|
||||
axis_name = (axis_name,)
|
||||
if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None:
|
||||
raise ValueError("axis_index_groups only supported for sums over just named axes")
|
||||
_validate_axis_index_groups(axis_index_groups)
|
||||
leaves, treedef = tree_util.tree_flatten(x)
|
||||
leaves = [lax.convert_element_type(l, np.int32)
|
||||
if dtypes.dtype(l) == np.bool_ else l for l in leaves]
|
||||
out_flat = psum_p.bind(*leaves, axis_name=axis_name,
|
||||
out_flat = psum_p.bind(*leaves, axes=axis_name,
|
||||
axis_index_groups=axis_index_groups)
|
||||
return tree_util.tree_unflatten(treedef, out_flat)
|
||||
|
||||
@ -139,9 +143,11 @@ def pmax(x, axis_name, *, axis_index_groups=None):
|
||||
"""
|
||||
if not isinstance(axis_name, (tuple, list)):
|
||||
axis_name = (axis_name,)
|
||||
if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None:
|
||||
raise ValueError("axis_index_groups only supported for sums over just named axes")
|
||||
_validate_axis_index_groups(axis_index_groups)
|
||||
leaves, treedef = tree_util.tree_flatten(x)
|
||||
out_flat = pmax_p.bind(*leaves, axis_name=axis_name,
|
||||
out_flat = pmax_p.bind(*leaves, axes=axis_name,
|
||||
axis_index_groups=axis_index_groups)
|
||||
return tree_util.tree_unflatten(treedef, out_flat)
|
||||
|
||||
@ -166,9 +172,11 @@ def pmin(x, axis_name, *, axis_index_groups=None):
|
||||
"""
|
||||
if not isinstance(axis_name, (tuple, list)):
|
||||
axis_name = (axis_name,)
|
||||
if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None:
|
||||
raise ValueError("axis_index_groups only supported for sums over just named axes")
|
||||
_validate_axis_index_groups(axis_index_groups)
|
||||
leaves, treedef = tree_util.tree_flatten(x)
|
||||
out_flat = pmin_p.bind(*leaves, axis_name=axis_name,
|
||||
out_flat = pmin_p.bind(*leaves, axes=axis_name,
|
||||
axis_index_groups=axis_index_groups)
|
||||
return tree_util.tree_unflatten(treedef, out_flat)
|
||||
|
||||
@ -484,31 +492,71 @@ class XeinsumSpecParser:
|
||||
|
||||
### parallel primitives
|
||||
|
||||
def _subst_all_names_in_axis_name(params: core.ParamDict, subst: core.AxisSubst) -> core.ParamDict:
|
||||
axis_name = params['axis_name']
|
||||
def _subst_all_names_in_param(
|
||||
pname: str, params: core.ParamDict, subst: core.AxisSubst) -> core.ParamDict:
|
||||
axis_name = params[pname]
|
||||
if not isinstance(axis_name, (tuple, list)):
|
||||
axis_name = (axis_name,)
|
||||
return dict(params, axis_name=sum((subst(name) for name in axis_name), ()))
|
||||
result = dict(params)
|
||||
result[pname] = sum(((name,) if isinstance(name, int) else subst(name)
|
||||
for name in axis_name),
|
||||
())
|
||||
return result
|
||||
|
||||
# This is only used for collectives that do not include the vmapped axis name,
|
||||
# which is why the rule is so simple.
|
||||
def _collective_batcher(prim, args, dims, **params):
|
||||
return prim.bind(*args, **params), dims if prim.multiple_results else dims[0]
|
||||
|
||||
def _batched_reduction_collective(
|
||||
prim, if_mapped, if_unmapped, frame, vals_in, dims_in, axis_name,
|
||||
axis_index_groups):
|
||||
assert prim.multiple_results
|
||||
assert frame.name in axis_name
|
||||
def _reduction_with_positional_batcher(prim, vals_in, dims_in, axis_index_groups,
|
||||
transform_unmapped, transform_mapped):
|
||||
if axis_index_groups is not None:
|
||||
raise NotImplementedError("axis_index_groups not supported in vmap collectives. "
|
||||
"Please open a feature request!")
|
||||
vals_out = [if_mapped(v, d) if d is not batching.not_mapped
|
||||
else if_unmapped(v, frame.size) for v, d in zip(vals_in, dims_in)]
|
||||
if len(axis_name) > 1:
|
||||
remaining_axis_names = tuple(n for n in axis_name if n != frame.name)
|
||||
vals_out = prim.bind(*vals_out, axis_name=remaining_axis_names,
|
||||
axis_index_groups=None)
|
||||
# TODO: Transpose all dims to 0, increment all axes
|
||||
vals_in = [val if d is batching.not_mapped or d == 0 else _moveaxis(d, 0, val)
|
||||
for val, d in zip(vals_in, dims_in)]
|
||||
mapped_vals_in, unmapped_vals_in = partitioned_vals_in = [], []
|
||||
mapped_idxs, unmapped_idxs = partitioned_idxs = [], []
|
||||
for i, (val, d) in enumerate(zip(vals_in, dims_in)):
|
||||
partitioned_vals_in[d is batching.not_mapped].append(val)
|
||||
partitioned_idxs[d is batching.not_mapped].append(i)
|
||||
vals_out = [None] * len(vals_in)
|
||||
if unmapped_vals_in:
|
||||
unmapped_axes, unmapped_vals_in = transform_unmapped(0, unmapped_vals_in)
|
||||
unmapped_vals_out = prim.bind(*unmapped_vals_in, axes=unmapped_axes, axis_index_groups=None)
|
||||
for i, val in zip(unmapped_idxs, unmapped_vals_out):
|
||||
vals_out[i] = val
|
||||
if mapped_vals_in:
|
||||
mapped_axes, mapped_vals_in = transform_mapped(0, mapped_vals_in)
|
||||
mapped_vals_out = prim.bind(*mapped_vals_in, axes=mapped_axes, axis_index_groups=None)
|
||||
for i, val in zip(mapped_idxs, mapped_vals_out):
|
||||
vals_out[i] = val
|
||||
assert all(v is not None for v in vals_out)
|
||||
return vals_out
|
||||
|
||||
# This is only used for collectives that do not include the vmapped axis name,
|
||||
# which is why the rule is so simple.
|
||||
def _reduction_batcher(prim, vals_in, dims_in, *, axes, axis_index_groups):
|
||||
if not any(isinstance(axis, int) for axis in axes):
|
||||
return prim.bind(*vals_in, axes=axes, axis_index_groups=axis_index_groups), dims_in
|
||||
vals_out = _reduction_with_positional_batcher(
|
||||
prim, vals_in, dims_in, axis_index_groups,
|
||||
lambda d, d_vals_in: (axes, d_vals_in),
|
||||
lambda d, d_vals_in: (tuple(axis + (axis >= d) if isinstance(axis, int) else axis
|
||||
for axis in axes),
|
||||
d_vals_in))
|
||||
return vals_out, dims_in
|
||||
|
||||
def _batched_reduction_collective(
|
||||
prim, if_unmapped, frame, vals_in, dims_in, axes,
|
||||
axis_index_groups):
|
||||
assert prim.multiple_results
|
||||
assert frame.name in axes
|
||||
vals_out = _reduction_with_positional_batcher(
|
||||
prim, vals_in, dims_in, axis_index_groups,
|
||||
lambda d, d_vals_in: (tuple(axis for axis in axes if axis != frame.name),
|
||||
[if_unmapped(v, frame.size) for v in d_vals_in]),
|
||||
lambda d, d_vals_in: (tuple(axis + (axis >= d) if isinstance(axis, int) else
|
||||
axis if axis != frame.name else
|
||||
d
|
||||
for axis in axes),
|
||||
d_vals_in))
|
||||
return vals_out, [batching.not_mapped] * len(vals_out)
|
||||
|
||||
def _replica_groups(axis_env, axis_name, axis_index_groups):
|
||||
@ -519,11 +567,31 @@ def _replica_groups(axis_env, axis_name, axis_index_groups):
|
||||
for axis_index_group in axis_index_groups]
|
||||
return replica_groups
|
||||
|
||||
def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
|
||||
def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups):
|
||||
assert axis_index_groups is None
|
||||
assert all(isinstance(axis, int) for axis in axes)
|
||||
return [pos_reducer(arg, axes) for arg in args]
|
||||
|
||||
def _allreduce_abstract_eval(*args, axes, axis_index_groups):
|
||||
pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
|
||||
return [ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
|
||||
arg.dtype)
|
||||
for arg in args]
|
||||
|
||||
def _allreduce_translation_rule(prim, pos_prim, c, *args, axes, axis_index_groups,
|
||||
axis_env, platform):
|
||||
named_axes, positional_axes = axes_partition = [], []
|
||||
for axis in axes:
|
||||
axes_partition[isinstance(axis, int)].append(axis)
|
||||
|
||||
if positional_axes:
|
||||
args = map(partial(xla.translations[pos_prim], c, axes=tuple(positional_axes)), args)
|
||||
if not named_axes:
|
||||
return xops.Tuple(c, args)
|
||||
|
||||
if platform in ("cpu", "tpu"):
|
||||
return _notuple_allreduce_translation_rule(
|
||||
prim, c, *args, axis_name=axis_name,
|
||||
prim, c, *args, named_axes=named_axes,
|
||||
axis_index_groups=axis_index_groups, axis_env=axis_env,
|
||||
platform=platform)
|
||||
|
||||
@ -537,7 +605,7 @@ def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
|
||||
|
||||
# The outputs, in the original argument order.
|
||||
out = [None] * len(args)
|
||||
replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
|
||||
replica_groups = _replica_groups(axis_env, named_axes, axis_index_groups)
|
||||
replica_groups_protos = xc.make_replica_groups(replica_groups)
|
||||
for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
|
||||
is_complex = dtypes.issubdtype(dtype, np.complexfloating)
|
||||
@ -562,11 +630,11 @@ def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
|
||||
|
||||
# TODO(b/155446630): An XLA:TPU optimization pass also doesn't support
|
||||
# tuple all-reduce yet. Meanwhile, rely on deterministic compiler behavior.
|
||||
def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env,
|
||||
def _notuple_allreduce_translation_rule(prim, c, *args, named_axes, axis_env,
|
||||
axis_index_groups, platform):
|
||||
def all_reduce(x):
|
||||
replica_groups_protos = xc.make_replica_groups(
|
||||
_replica_groups(axis_env, axis_name, axis_index_groups))
|
||||
_replica_groups(axis_env, named_axes, axis_index_groups))
|
||||
scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
|
||||
computation = xla.primitive_subcomputation(prim, scalar, scalar)
|
||||
return xops.AllReduce(x, computation, replica_groups_protos, None, None)
|
||||
@ -581,64 +649,74 @@ def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env,
|
||||
else all_reduce(x) for x in args]
|
||||
return xops.Tuple(c, outs)
|
||||
|
||||
def _psum_transpose_rule(cts, *args, axis_name, axis_index_groups):
|
||||
def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
|
||||
if any(isinstance(axis, int) for axis in axes):
|
||||
raise NotImplementedError
|
||||
nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
|
||||
nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axis_name=axis_name,
|
||||
nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axes=axes,
|
||||
axis_index_groups=axis_index_groups)
|
||||
return tree_util.tree_unflatten(treedef, nonzero_in_cts)
|
||||
|
||||
psum_p = core.Primitive('psum')
|
||||
psum_p.multiple_results = True
|
||||
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
|
||||
xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule, lax.add_p) # type: ignore
|
||||
psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
|
||||
psum_p.def_abstract_eval(_allreduce_abstract_eval)
|
||||
xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule,
|
||||
lax.add_p, lax.reduce_sum_p) # type: ignore
|
||||
ad.deflinear2(psum_p, _psum_transpose_rule)
|
||||
pxla.multi_host_supported_collectives.add(psum_p)
|
||||
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
|
||||
batching.primitive_batchers[psum_p] = partial(_reduction_batcher, psum_p)
|
||||
batching.collective_rules[psum_p] = \
|
||||
partial(_batched_reduction_collective,
|
||||
psum_p,
|
||||
lambda v, d: v.sum(d),
|
||||
lambda v, axis_size: axis_size * v)
|
||||
core.axis_substitution_rules[psum_p] = _subst_all_names_in_axis_name
|
||||
partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v)
|
||||
core.axis_substitution_rules[psum_p] = partial(_subst_all_names_in_param, 'axes')
|
||||
|
||||
# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
|
||||
# tracing time.
|
||||
@psum_p.def_custom_bind
|
||||
def psum_bind(*args, axis_name, axis_index_groups):
|
||||
def psum_bind(*args, axes, axis_index_groups):
|
||||
if all(not isinstance(x, core.Tracer) for x in args):
|
||||
named_axes, pos_axes = axes_partition = [], []
|
||||
for axis in axes:
|
||||
axes_partition[isinstance(axis, int)].append(axis)
|
||||
def pos_reduce(x):
|
||||
if not pos_axes:
|
||||
return x
|
||||
return lax._reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0))
|
||||
for axis in pos_axes])
|
||||
if axis_index_groups is not None:
|
||||
assert not pos_axes
|
||||
size = len(axis_index_groups[0])
|
||||
elif isinstance(axis_name, (list, tuple)):
|
||||
size = prod([core.axis_frame(name).size for name in axis_name]) # type: ignore
|
||||
else:
|
||||
size = core.axis_frame(axis_name).size # type: ignore
|
||||
return tuple(size * x for x in args)
|
||||
size = prod([core.axis_frame(name).size for name in named_axes]) # type: ignore
|
||||
return tuple(size * pos_reduce(x) for x in args)
|
||||
return core.Primitive.bind(
|
||||
psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups)
|
||||
psum_p, *args, axes=axes, axis_index_groups=axis_index_groups)
|
||||
|
||||
|
||||
pmax_p = core.Primitive('pmax')
|
||||
pmax_p.multiple_results = True
|
||||
pmax_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
|
||||
xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule, lax.max_p)
|
||||
pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
|
||||
pmax_p.def_abstract_eval(_allreduce_abstract_eval)
|
||||
xla.parallel_translations[pmax_p] = partial(_allreduce_translation_rule,
|
||||
lax.max_p, lax.reduce_max_p) # type: ignore
|
||||
pxla.multi_host_supported_collectives.add(pmax_p)
|
||||
batching.primitive_batchers[pmax_p] = partial(_collective_batcher, pmax_p)
|
||||
batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p)
|
||||
batching.collective_rules[pmax_p] = \
|
||||
partial(_batched_reduction_collective, pmax_p,
|
||||
lambda v, d: v.max(d), lambda v, axis_size: v)
|
||||
core.axis_substitution_rules[pmax_p] = _subst_all_names_in_axis_name
|
||||
partial(_batched_reduction_collective, pmax_p, lambda v, axis_size: v)
|
||||
core.axis_substitution_rules[pmax_p] = partial(_subst_all_names_in_param, 'axes')
|
||||
|
||||
|
||||
pmin_p = core.Primitive('pmin')
|
||||
pmin_p.multiple_results = True
|
||||
pmin_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
|
||||
xla.parallel_translations[pmin_p] = partial(_allreduce_translation_rule, lax.min_p)
|
||||
pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
|
||||
pmin_p.def_abstract_eval(_allreduce_abstract_eval)
|
||||
xla.parallel_translations[pmin_p] = partial(_allreduce_translation_rule,
|
||||
lax.min_p, lax.reduce_min_p) # type: ignore
|
||||
pxla.multi_host_supported_collectives.add(pmin_p)
|
||||
batching.primitive_batchers[pmin_p] = partial(_collective_batcher, pmin_p)
|
||||
batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p)
|
||||
batching.collective_rules[pmin_p] = \
|
||||
partial(_batched_reduction_collective, pmin_p,
|
||||
lambda v, d: v.min(d), lambda v, axis_size: v)
|
||||
core.axis_substitution_rules[pmin_p] = _subst_all_names_in_axis_name
|
||||
partial(_batched_reduction_collective, pmin_p, lambda v, axis_size: v)
|
||||
core.axis_substitution_rules[pmin_p] = partial(_subst_all_names_in_param, 'axes')
|
||||
|
||||
|
||||
def _ppermute_translation_rule(c, x, *, axis_name, axis_env, perm, platform):
|
||||
@ -670,6 +748,9 @@ def _ppermute_batcher(frame, vals_in, dims_in, axis_name, perm):
|
||||
perm_indices[src] = dst
|
||||
return lax_numpy.take(v, perm_indices, d), d
|
||||
|
||||
def _collective_batcher(prim, args, dims, **params):
|
||||
return prim.bind(*args, **params), dims if prim.multiple_results else dims[0]
|
||||
|
||||
ppermute_p = core.Primitive('ppermute')
|
||||
ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
|
||||
ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
|
||||
@ -677,7 +758,7 @@ xla.parallel_translations[ppermute_p] = _ppermute_translation_rule
|
||||
pxla.multi_host_supported_collectives.add(ppermute_p)
|
||||
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
|
||||
batching.collective_rules[ppermute_p] = _ppermute_batcher
|
||||
core.axis_substitution_rules[ppermute_p] = _subst_all_names_in_axis_name
|
||||
core.axis_substitution_rules[ppermute_p] = partial(_subst_all_names_in_param, 'axis_name')
|
||||
|
||||
|
||||
def _moveaxis(src, dst, x):
|
||||
@ -795,7 +876,7 @@ ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
|
||||
pxla.multi_host_supported_collectives.add(all_to_all_p)
|
||||
batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
|
||||
batching.collective_rules[all_to_all_p] = _all_to_all_batched_collective
|
||||
core.axis_substitution_rules[all_to_all_p] = _subst_all_names_in_axis_name
|
||||
core.axis_substitution_rules[all_to_all_p] = partial(_subst_all_names_in_param, 'axis_name')
|
||||
|
||||
|
||||
def _expand(dim, size, index, x):
|
||||
@ -947,7 +1028,7 @@ ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
|
||||
pxla.multi_host_supported_collectives.add(all_gather_p)
|
||||
batching.primitive_batchers[all_gather_p] = _all_gather_batcher
|
||||
batching.collective_rules[all_gather_p] = _all_gather_batched_collective
|
||||
core.axis_substitution_rules[all_gather_p] = _subst_all_names_in_axis_name
|
||||
core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name')
|
||||
|
||||
def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
|
||||
axis_pos = list(axis_env.names).index(axis_name)
|
||||
@ -963,7 +1044,7 @@ xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
|
||||
axis_index_p.def_abstract_eval(
|
||||
lambda *args, **params: ShapedArray((), np.int32))
|
||||
pxla.multi_host_supported_collectives.add(axis_index_p)
|
||||
core.axis_substitution_rules[axis_index_p] = _subst_all_names_in_axis_name
|
||||
core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name')
|
||||
|
||||
# Axis index doesn't get any arguments, so that the default bind would have no
|
||||
# way to call into a data-dependency based trace such as vmap. Each trace that
|
||||
@ -993,7 +1074,7 @@ batching.BatchTrace.process_axis_index = _process_axis_index # type: ignore
|
||||
|
||||
|
||||
pdot_p = core.Primitive('pdot')
|
||||
core.axis_substitution_rules[pdot_p] = _subst_all_names_in_axis_name
|
||||
core.axis_substitution_rules[pdot_p] = partial(_subst_all_names_in_param, 'axis_name')
|
||||
|
||||
@pdot_p.def_impl
|
||||
def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch):
|
||||
@ -1042,7 +1123,7 @@ def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch,
|
||||
preferred_element_type=None)
|
||||
if axis_name:
|
||||
out_tup = xla.parallel_translations[psum_p](
|
||||
c, local_out, axis_name=axis_name, axis_index_groups=None,
|
||||
c, local_out, axes=axis_name, axis_index_groups=None,
|
||||
axis_env=axis_env, platform=platform)
|
||||
out, = xla.xla_destructure(c, out_tup)
|
||||
else:
|
||||
|
@ -1852,7 +1852,8 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
|
||||
|
||||
def _reduction(a, name, np_fun, op, init_val, has_identity=True,
|
||||
preproc=None, bool_op=None, upcast_f16_for_computation=False,
|
||||
axis=None, dtype=None, out=None, keepdims=False, initial=None, where_=None):
|
||||
axis=None, dtype=None, out=None, keepdims=False, initial=None,
|
||||
where_=None, parallel_reduce=None):
|
||||
bool_op = bool_op or op
|
||||
if out is not None:
|
||||
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
|
||||
@ -1869,7 +1870,7 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True,
|
||||
|
||||
a = a if isinstance(a, ndarray) else asarray(a)
|
||||
a = preproc(a) if preproc else a
|
||||
dims = _reduction_dims(a, axis)
|
||||
pos_dims, dims = _reduction_dims(a, axis)
|
||||
result_dtype = dtypes.canonicalize_dtype(dtype or _dtype(np_fun(np.ones((), dtype=_dtype(a)))))
|
||||
if upcast_f16_for_computation and issubdtype(result_dtype, inexact):
|
||||
computation_dtype = promote_types(result_dtype, float32)
|
||||
@ -1882,23 +1883,38 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True,
|
||||
init_val = _reduction_init_val(a, init_val)
|
||||
if where_ is not None:
|
||||
a = where(where_, a, init_val)
|
||||
result = lax.reduce(a, init_val, op, dims)
|
||||
if pos_dims is not dims:
|
||||
if parallel_reduce is None:
|
||||
raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
|
||||
result = parallel_reduce(a, dims)
|
||||
else:
|
||||
result = lax.reduce(a, init_val, op, dims)
|
||||
if initial is not None:
|
||||
result = op(_reduction_init_val(a, initial), result)
|
||||
if keepdims:
|
||||
result = expand_dims(result, dims)
|
||||
result = expand_dims(result, pos_dims)
|
||||
return lax.convert_element_type(result, dtype or result_dtype)
|
||||
|
||||
def _canonicalize_axis_allow_named(x, rank):
|
||||
try:
|
||||
return _canonicalize_axis(x, rank)
|
||||
except TypeError:
|
||||
return x
|
||||
|
||||
def _reduction_dims(a, axis):
|
||||
if axis is None:
|
||||
return tuple(range(ndim(a)))
|
||||
elif isinstance(axis, (np.ndarray, tuple, list)):
|
||||
axis = tuple(_canonicalize_axis(x, ndim(a)) for x in axis)
|
||||
if len(axis) != len(set(axis)):
|
||||
raise ValueError(f"duplicate value in 'axis': {axis}")
|
||||
return axis
|
||||
return (tuple(range(ndim(a))),) * 2
|
||||
elif not isinstance(axis, (np.ndarray, tuple, list)):
|
||||
axis = (axis,)
|
||||
canon_axis = tuple(_canonicalize_axis_allow_named(x, ndim(a))
|
||||
for x in axis)
|
||||
if len(canon_axis) != len(set(canon_axis)):
|
||||
raise ValueError(f"duplicate value in 'axis': {axis}")
|
||||
canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
|
||||
if len(canon_pos_axis) != len(canon_axis):
|
||||
return canon_pos_axis, canon_axis
|
||||
else:
|
||||
return (_canonicalize_axis(axis, ndim(a)),)
|
||||
return canon_axis, canon_axis
|
||||
|
||||
def _reduction_init_val(a, init_val):
|
||||
a_dtype = dtypes.canonicalize_dtype(_dtype(a))
|
||||
@ -1918,7 +1934,8 @@ def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
out=None, keepdims=None, initial=None, where=None):
|
||||
return _reduction(a, "sum", np.sum, lax.add, 0,
|
||||
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
|
||||
axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where)
|
||||
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where, parallel_reduce=lax.psum)
|
||||
|
||||
@_wraps(np.prod)
|
||||
def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
@ -1931,13 +1948,15 @@ def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None, initial=None, where=None):
|
||||
return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
|
||||
axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where)
|
||||
axis=axis, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where, parallel_reduce=lax.pmax)
|
||||
|
||||
@_wraps(np.min)
|
||||
def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
keepdims=None, initial=None, where=None):
|
||||
return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
|
||||
axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where)
|
||||
axis=axis, out=out, keepdims=keepdims,
|
||||
initial=initial, where_=where, parallel_reduce=lax.pmin)
|
||||
|
||||
@_wraps(np.all)
|
||||
def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
|
||||
|
@ -102,7 +102,7 @@ expit.defjvps(lambda g, ans, x: g * ans * (lax._const(ans, 1) - ans))
|
||||
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
|
||||
if b is not None:
|
||||
a, b = jnp.broadcast_arrays(a, b)
|
||||
dims = _reduction_dims(a, axis)
|
||||
_, dims = _reduction_dims(a, axis)
|
||||
dimadd = lambda x: lax.expand_dims(x, dims)
|
||||
amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
|
||||
amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
|
||||
|
@ -1688,7 +1688,7 @@ def omnistaging_disabler() -> None:
|
||||
# particular, this code gets hit when we write `axis_size = psum(1, 'i')`. We
|
||||
# look up information in the dynamic axis env.
|
||||
dynamic_axis_env = _thread_local_state.dynamic_axis_env
|
||||
axis_name = params.pop('axis_name')
|
||||
axis_name = params.pop('axes')
|
||||
axis_index_groups = params.pop('axis_index_groups')
|
||||
if axis_index_groups is not None:
|
||||
shape = (len(axis_index_groups[0]),)
|
||||
|
@ -338,6 +338,38 @@ class XMapTestSPMD(XMapTest):
|
||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = self.old_lowering_flag
|
||||
|
||||
|
||||
class NamedNumPyTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
if jax.lib.version < (0, 1, 58):
|
||||
raise SkipTest("xmap requires jaxlib version >= 0.1.58")
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{reduction.__name__}_axes={axes}_i={mapped_axis}",
|
||||
"reduction": reduction, "axes": axes, "mapped_axis": mapped_axis}
|
||||
for reduction in (jnp.sum, jnp.max, jnp.min)
|
||||
for axes in (0, 'i', (1,), ('i',), (0, 1), (0, 'i'), ('i', 0))
|
||||
for mapped_axis in range(2))
|
||||
@ignore_xmap_warning()
|
||||
def testReductions(self, reduction, axes, mapped_axis):
|
||||
axes_t = axes if isinstance(axes, tuple) else (axes,)
|
||||
reduces_i = 'i' in axes_t
|
||||
ref_red = partial(reduction,
|
||||
axis=tuple(mapped_axis if a == 'i' else a + (a >= mapped_axis)
|
||||
for a in axes_t))
|
||||
mapped_axis_after_red = mapped_axis - sum(axis < mapped_axis if axis != 'i' else 0
|
||||
for axis in axes_t)
|
||||
xmap_red = xmap(lambda x: reduction(x, axes),
|
||||
in_axes={mapped_axis: 'i'},
|
||||
out_axes=({} if 'i' in axes_t else {mapped_axis_after_red: 'i'}))
|
||||
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(2, 5, 6)
|
||||
self.assertAllClose(ref_red(x), xmap_red(x))
|
||||
|
||||
|
||||
AxisIndices = Tuple[int, ...]
|
||||
MatchedAxisIndices = Tuple[AxisIndices, AxisIndices]
|
||||
AxisNames = Tuple[str, ...]
|
||||
|
Loading…
x
Reference in New Issue
Block a user