rocm_jax/jax/lax/lax_parallel.py

931 lines
35 KiB
Python

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Parallelization primitives.
"""
import collections
import numpy as np
from jax import core
from jax import ad_util
from jax import dtypes
from jax import tree_util
from jax.lax import lax
from jax.abstract_arrays import ShapedArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import parallel
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.util import partial, unzip2, prod
from jax.lib import xla_client as xc
from jax.interpreters.pxla import axis_index
xops = xc.ops
### parallel traceables
def psum(x, axis_name, *, axis_index_groups=None):
  """Sum ``x`` across all replicas of the pmapped axis ``axis_name``.

  When ``x`` is a pytree, the reduction is applied to every leaf and the tree
  structure is preserved.

  Boolean leaves are cast to ``int32`` before they are summed.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists of axis indices selecting
      independent reduction groups (e.g. for an axis of size 4,
      ``[[0, 1], [2, 3]]`` sums the first two and the last two replicas
      separately). Every axis index must appear in exactly one group, and all
      groups must have the same size.

  Returns:
    Array(s) with the same shape as ``x`` holding the all-reduce sum along
    the axis ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [6 6 6 6]
  >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [ 0.          0.16666667  0.33333334  0.5       ]
  """
  _validate_axis_index_groups(axis_index_groups)
  flat_inputs, structure = tree_util.tree_flatten(x)

  # Booleans cannot be summed directly, so promote them to int32 first.
  operands = []
  for leaf in flat_inputs:
    if dtypes.dtype(leaf) == np.bool_:
      leaf = lax.convert_element_type(leaf, np.int32)
    operands.append(leaf)

  summed = psum_p.bind(*operands, axis_name=axis_name,
                       axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(structure, summed)
def pmean(x, axis_name, *, axis_index_groups=None):
  """Average ``x`` across all replicas of the pmapped axis ``axis_name``.

  When ``x`` is a pytree, the reduction is applied to every leaf and the tree
  structure is preserved.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists of axis indices selecting
      independent reduction groups (e.g. for an axis of size 4,
      ``[[0, 1], [2, 3]]`` averages the first two and the last two replicas
      separately). Every axis index must appear in exactly one group, and all
      groups must have the same size.

  Returns:
    Array(s) with the same shape as ``x`` holding the all-reduce mean along
    the axis ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.pmean(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [ 1.5  1.5  1.5  1.5 ]
  >>> y = jax.pmap(lambda x: x / jax.lax.pmean(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [ 0.          0.66666667  1.33333334  2.0       ]
  """
  total = psum(x, axis_name=axis_name, axis_index_groups=axis_index_groups)
  # Summing the constant 1 over the same groups yields the group size.
  count = psum(1, axis_name=axis_name, axis_index_groups=axis_index_groups)

  def _divide(leaf):
    return leaf / count

  return tree_util.tree_map(_divide, total)
def pmax(x, axis_name, *, axis_index_groups=None):
  """Take the maximum of ``x`` across replicas of the axis ``axis_name``.

  When ``x`` is a pytree, the reduction is applied to every leaf and the tree
  structure is preserved.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists of axis indices selecting
      independent reduction groups (e.g. for an axis of size 4,
      ``[[0, 1], [2, 3]]`` reduces the first two and the last two replicas
      separately). Every axis index must appear in exactly one group, and all
      groups must have the same size.

  Returns:
    Array(s) with the same shape as ``x`` holding the all-reduce max along
    the axis ``axis_name``.
  """
  _validate_axis_index_groups(axis_index_groups)

  def _reduce_leaf(leaf):
    return pmax_p.bind(leaf, axis_name=axis_name,
                       axis_index_groups=axis_index_groups)

  return tree_util.tree_map(_reduce_leaf, x)
def pmin(x, axis_name, *, axis_index_groups=None):
  """Take the minimum of ``x`` across replicas of the axis ``axis_name``.

  When ``x`` is a pytree, the reduction is applied to every leaf and the tree
  structure is preserved.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists of axis indices selecting
      independent reduction groups (e.g. for an axis of size 4,
      ``[[0, 1], [2, 3]]`` reduces the first two and the last two replicas
      separately). Every axis index must appear in exactly one group, and all
      groups must have the same size.

  Returns:
    Array(s) with the same shape as ``x`` holding the all-reduce min along
    the axis ``axis_name``.
  """
  _validate_axis_index_groups(axis_index_groups)

  def _reduce_leaf(leaf):
    return pmin_p.bind(leaf, axis_name=axis_name,
                       axis_index_groups=axis_index_groups)

  return tree_util.tree_map(_reduce_leaf, x)
def _validate_axis_index_groups(axis_index_groups):
if axis_index_groups is None:
return
len_0 = len(axis_index_groups[0])
if any(len(g) != len_0 for g in axis_index_groups):
raise ValueError("axis_index_groups must all be the same size")
axis_space = range(len_0 * len(axis_index_groups))
if set(i for g in axis_index_groups for i in g) != set(axis_space):
raise ValueError("axis_index_groups must cover all indices exactly once")
def ppermute(x, axis_name, perm):
  """Send slices between replicas of ``axis_name`` as described by ``perm``.

  When ``x`` is a pytree, the permutation is applied to every leaf and the
  tree structure is preserved.

  This function is an analog of the CollectivePermute XLA HLO.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    perm: list of ``(source_index, destination_index)`` pairs of ints, each an
      index into the mapped axis ``axis_name``. No two pairs should share a
      source index or a destination index. Every axis index that is not a
      destination in ``perm`` receives zeros of the appropriate type.

  Returns:
    Array(s) with the same shape as ``x`` with slices along the axis
    ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
  """
  # Freeze the pairs once so the bound parameter is a tuple.
  frozen_perm = tuple(perm)

  def _permute_leaf(leaf):
    return ppermute_p.bind(leaf, axis_name=axis_name, perm=frozen_perm)

  return tree_util.tree_map(_permute_leaf, x)
def pshuffle(x, axis_name, perm):
  """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding.

  When ``x`` is a pytree, the permutation is applied to every leaf and the
  tree structure is preserved.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    perm: list of ints giving, for each output axis index ``i``, the input
      axis index ``perm[i]`` it is read from. Every integer in ``[0, N)``
      should appear exactly once for axis size ``N``.

  Returns:
    Array(s) with the same shape as ``x`` with slices along the axis
    ``axis_name`` gathered from ``x`` according to the permutation ``perm``.

  Raises:
    ValueError: if ``perm`` is not a permutation of ``range(len(perm))``.
  """
  sources = set(perm)
  count = len(perm)
  if sources != set(range(count)):
    raise ValueError(f"`perm` does not represent a permutation: {perm}")
  # ppermute expects (source, destination) pairs; here position i is the
  # destination and perm[i] is the source.
  pairs = [(src, dst) for dst, src in enumerate(perm)]
  return ppermute(x, axis_name, pairs)
def pswapaxes(x, axis_name, axis):
  """Exchange the pmapped axis ``axis_name`` with the unmapped axis ``axis``.

  When ``x`` is a pytree, the swap is applied to every leaf and the tree
  structure is preserved.

  The mapped axis size must equal the size of the unmapped axis; that is,
  ``lax.psum(1, axis_name) == x.shape[axis]`` must hold.

  This is the special case of ``all_to_all`` in which the mapped axis of the
  input lands at position ``axis`` of the output, i.e. it is equivalent to
  ``all_to_all(x, axis_name, axis, axis)``.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis: int indicating the unmapped axis of ``x`` to map with the name
      ``axis_name``.

  Returns:
    Array(s) with shape ``np.insert(np.delete(x.shape, axis), axis, axis_size)``
    where ``axis_size`` is the size of the mapped axis named ``axis_name`` in
    the input ``x``.
  """
  return all_to_all(x, axis_name, split_axis=axis, concat_axis=axis)
def all_to_all(x, axis_name, split_axis, concat_axis):
  """Materialize the mapped axis and map a different axis.

  When ``x`` is a pytree, the operation is applied to every leaf and the tree
  structure is preserved.

  In the output, the input mapped axis ``axis_name`` is materialized at the
  logical axis position ``concat_axis``, and the input unmapped axis at
  position ``split_axis`` is mapped with the name ``axis_name``.

  The input mapped axis size must equal the size of the axis to be mapped;
  that is, ``lax.psum(1, axis_name) == x.shape[split_axis]`` must hold.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    split_axis: int indicating the unmapped axis of ``x`` to map with the name
      ``axis_name``.
    concat_axis: int indicating the position in the output to materialize the
      mapped axis of the input with the name ``axis_name``.

  Returns:
    Array(s) with shape given by the expression::

      np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)

    where ``axis_size`` is the size of the mapped axis named ``axis_name`` in
    the input ``x``, i.e. ``axis_size = lax.psum(1, axis_name)``.

  Raises:
    ValueError: if the mapped axis size differs from ``x.shape[split_axis]``.
  """
  def _exchange(leaf):
    mapped_size = psum(1, axis_name)
    unmapped_size = leaf.shape[split_axis]
    if mapped_size != unmapped_size:
      template = ("all_to_all requires the size of the mapped axis axis_name to "
                  "equal x.shape[split_axis], but they are {} and {} respectively.")
      raise ValueError(template.format(mapped_size, unmapped_size))
    return all_to_all_p.bind(leaf, split_axis=split_axis,
                             concat_axis=concat_axis, axis_name=axis_name)

  return tree_util.tree_map(_exchange, x)
### parallel primitives
def standard_pmap_primitive(name, multiple_results=False):
  """Create a collective primitive with the default pmap implementation.

  Args:
    name: name given to the new ``core.Primitive``.
    multiple_results: value stored on the primitive's ``multiple_results``
      attribute.

  Returns:
    The new primitive. Its impl rule is ``pxla.apply_parallel_primitive``
    partially applied to the primitive, and its abstract-eval rule returns the
    first operand unchanged.
  """
  primitive = core.Primitive(name)
  primitive.multiple_results = multiple_results
  primitive.def_impl(partial(pxla.apply_parallel_primitive, primitive))

  def _same_as_first_operand(x, *args, **params):
    return x

  primitive.def_abstract_eval(_same_as_first_operand)
  return primitive
def _allreduce_split_axis_rule(prim, reducer, vals, which_mapped, axis_name,
axis_index_groups):
assert tuple(which_mapped) == (True,)
if axis_index_groups is not None:
raise NotImplementedError("soft_pmap does not yet support axis_index_groups")
vals = (reducer(x, [0]) for x in vals)
return prim.bind(*vals, axis_name=axis_name), False
def _allreduce_translation_rule(prim, c, val, replica_groups, platform=None):
  """Emit an XLA AllReduce whose combiner is the binary primitive ``prim``.

  Args:
    prim: binary primitive used as the reduction computation.
    c: XLA computation builder; used here to look up the shape of ``val``.
    val: XLA operand to reduce.
    replica_groups: replica groups that reduce together.
    platform: accepted for interface compatibility; not used.

  Returns:
    The XLA op produced by ``xops.AllReduce``.
  """
  element_dtype = c.get_shape(val).numpy_dtype()
  scalar_aval = ShapedArray((), element_dtype)
  # The reduction is described to XLA as a scalar (scalar, scalar) computation.
  reducer = xla.primitive_subcomputation(prim, scalar_aval, scalar_aval)
  groups = xc.make_replica_groups(replica_groups)
  return xops.AllReduce(val, reducer, groups, None, None)
# psum translation rule has special handling for complex dtypes
def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
  """Emit the XLA ops for a psum over ``args``.

  Args:
    c: XLA computation builder.
    *args: XLA operands to sum across replicas.
    replica_groups: replica groups that reduce together.
    platform: target platform name; ``"cpu"`` and ``"tpu"`` are routed to
      ``_notuple_psum_translation_rule``.

  Returns:
    An XLA tuple with one summed value per operand, in argument order.
  """
  if platform in ("cpu", "tpu"):
    return _notuple_psum_translation_rule(c, *args, replica_groups=replica_groups)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce for each argument input type.
  # Maps dtype -> (positions of the operands in `args`, the operands).
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    n = len(dtype_args)
    if is_complex:
      # Complex operands are reduced as 2n real operands: the n real parts
      # followed by the n imaginary parts.
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(lax.add_p, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex:
      # Recombine element i (real part) with element n + i (imaginary part).
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i)) for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    # Scatter this dtype's results back to their original argument positions.
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
# TODO(b/150476027): CPU doesn't support tuple all-reduce correctly. But
# fortunately we don't really need it in that case because CPU doesn't support
# cross-task communication either.
# TODO(b/155446630): An XLA:TPU optimization pass also doesn't support
# tuple all-reduce yet. Meanwhile, rely on deterministic compiler behavior.
def _notuple_psum_translation_rule(c, *args, replica_groups):
  """Translate psum with one AllReduce per operand instead of a tuple one.

  Args:
    c: XLA computation builder.
    *args: XLA operands to sum across replicas.
    replica_groups: replica groups that reduce together.

  Returns:
    An XLA tuple with one summed value per operand, in argument order.
  """
  allreduce_sum = partial(_allreduce_translation_rule, lax.add_p, c,
                          replica_groups=replica_groups)
  results = []
  for operand in args:
    operand_dtype = c.get_shape(operand).numpy_dtype()
    if dtypes.issubdtype(operand_dtype, np.complexfloating):
      # Complex values are summed as separate real and imaginary reductions.
      real_sum = allreduce_sum(xops.Real(operand))
      imag_sum = allreduce_sum(xops.Imag(operand))
      results.append(xops.Complex(real_sum, imag_sum))
    else:
      results.append(allreduce_sum(operand))
  return xops.Tuple(c, results)
def _psum_transpose_rule(cts, axis_name, axis_index_groups):
  """Transpose rule for psum: apply psum to the cotangents.

  Args:
    cts: cotangents of the psum outputs (a pytree).
    axis_name: mapped axis name forwarded to the re-bound psum.
    axis_index_groups: grouping forwarded to the re-bound psum.

  Returns:
    The summed cotangents, arranged in the same tree structure as ``cts``.
  """
  flat_cts, structure = tree_util.tree_flatten(cts)
  summed_cts = psum_p.bind(*flat_cts, axis_name=axis_name,
                           axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(structure, summed_cts)
# psum takes any number of operands in a single application, so it is created
# with multiple_results=True.
psum_p = standard_pmap_primitive('psum', multiple_results=True)
# Replace the default abstract-eval rule (which returns only the first
# operand) with one that yields a value per operand.
psum_p.def_abstract_eval(
    lambda *args, **params: tuple(map(raise_to_shaped, args)))
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
# Pure rule: each operand is scaled by the product of `shape`.
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
ad.deflinear(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)

# pmax: all-reduce using lax.max_p as the combiner.
pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

# pmin: all-reduce using lax.min_p as the combiner.
pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
    partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)
def _ppermute_translation_rule(c, x, replica_groups, perm, platform=None):
group_size = len(replica_groups[0])
srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm)
if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))):
msg = "ppermute sources and destinations must be unique, got {}."
raise ValueError(msg.format(perm))
full_perm = []
for grp in replica_groups:
grp = list(sorted(grp))
full_perm.extend((grp[src], grp[dst]) for src, dst in perm)
return xops.CollectivePermute(x, full_perm)
def _ppermute_transpose_rule(t, perm, axis_name):
  """Transpose rule for ppermute: send cotangents back along reversed pairs.

  Args:
    t: cotangent of the ppermute output.
    perm: the ``(source, destination)`` pairs of the forward permutation.
    axis_name: mapped axis name.

  Returns:
    A one-element list holding the cotangent permuted by the inverse of
    ``perm``.
  """
  reversed_pairs = [(dst, src) for src, dst in perm]
  return [ppermute(t, axis_name=axis_name, perm=reversed_pairs)]
# ppermute: registered as linear with _ppermute_transpose_rule as its
# transpose, given an XLA translation, and added to pxla's set of
# multi-host-supported collectives.
ppermute_p = standard_pmap_primitive('ppermute')
ad.deflinear(ppermute_p, _ppermute_transpose_rule)
xla.parallel_translations[ppermute_p] = _ppermute_translation_rule
pxla.multi_host_supported_collectives.add(ppermute_p)
def _all_to_all_translation_rule(c, x, split_axis, concat_axis, replica_groups,
platform=None):
# Workaround for AllToAll not being implemented on CPU.
if len(replica_groups[0]) == 1:
return x
else:
split_count = len(replica_groups[0])
if not all(split_count == len(g) for g in replica_groups):
raise ValueError('Replica groups must be equally sized')
replica_groups_protos = xc.make_replica_groups(replica_groups)
return xops.AllToAll(x, split_axis, concat_axis, split_count,
replica_groups_protos)
def _all_to_all_split_axis_rule(vals, which_mapped, split_axis, concat_axis,
                                axis_name):
  """Split-axis rule for all_to_all.

  Args:
    vals: one-element sequence holding the operand.
    which_mapped: per-operand flags; must be exactly ``(True,)``.
    split_axis: split axis of the original all_to_all.
    concat_axis: concat axis of the original all_to_all.
    axis_name: mapped axis name.

  Returns:
    A pair ``(result, True)``.
  """
  assert tuple(which_mapped) == (True,)
  (operand,) = vals
  # perform the communication to swap the hardware-mapped axes
  exchanged = all_to_all_p.bind(operand, split_axis=split_axis + 1,
                                concat_axis=0, axis_name=axis_name)
  # transpose the newly mapped axis to the front, newly unmapped to concat_axis
  fronted = _moveaxis(split_axis + 1, 0, exchanged)
  result = _moveaxis(1, concat_axis + 1, fronted)
  return result, True
def _moveaxis(src, dst, x):
perm = [i for i in range(x.ndim) if i != src]
perm.insert(dst, src)
return lax.transpose(x, perm)
# all_to_all: create the primitive and register its XLA translation and
# split-axis rules.
all_to_all_p = standard_pmap_primitive('all_to_all')
xla.parallel_translations[all_to_all_p] = _all_to_all_translation_rule
pxla.split_axis_rules[all_to_all_p] = _all_to_all_split_axis_rule
### papply rules
# TODO(skye): it would be nice if we could put these with their corresponding
# primitives, but that currently causes circular dependencies. More refactoring
# might fix this.
def _drop(x, dim, axis_name):
  """Index ``x`` along ``dim`` at this replica's index on ``axis_name``.

  The indexed dimension is removed from the result.
  """
  replica_index = axis_index(axis_name)
  return lax.dynamic_index_in_dim(x, replica_index, dim, False)
def _expand(dim, size, axis_name, x):
  """Embed ``x`` in a zero array with a new axis of length ``size`` at ``dim``.

  ``x`` is written at this replica's index along the new axis; every other
  position along that axis is zero.
  """
  dims = list(x.shape)
  padded_shape = dims[:dim] + [size] + dims[dim:]
  zeros = lax.full(padded_shape, lax._const(x, 0))
  return lax.dynamic_update_index_in_dim(zeros, x, axis_index(axis_name), dim)
def _allgather(x, dim, size, axis_name):
  """All-gather ``x`` along a new axis ``dim`` of length ``size``.

  Every replica places its value at its own index in an otherwise-zero array
  and the arrays are summed across ``axis_name``.
  """
  def _place(leaf):
    return _expand(dim, size, axis_name, leaf)

  placed = tree_util.tree_map(_place, x)
  return psum(placed, axis_name)
def all_gather(x, axis_name):
  """Collect the values of ``x`` from every replica of ``axis_name``.

  When ``x`` is a pytree, the gather is applied to every leaf and the tree
  structure is preserved.

  This is equivalent to, but faster than, all_to_all(broadcast(x)).

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).

  Returns:
    Array(s) representing the result of an all-gather along the axis
    ``axis_name``. Shapes are the same as ``x.shape``, but with a leading
    dimension of the axis_size.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [[0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]]
  """
  # Summing the constant 1 over the axis yields its size.
  axis_size = psum(1, axis_name)
  return _allgather(x, 0, axis_size, axis_name)
def _broadcasting_papply(prim, name, size, vals, axes, **params):
  """papply rule for binary primitives whose operands broadcast together.

  Args:
    prim: the binary primitive to bind.
    name: name of the mapped axis.
    size: size of the mapped axis.
    vals: pair ``(x, y)`` of operands.
    axes: pair ``(xdim, ydim)`` of operand axes; ``None`` presumably marks an
      operand that is not split along the mapped axis.
    **params: forwarded unchanged to ``prim.bind``.

  Returns:
    A pair ``(result, dim)`` of the bound result and the axis reported for it.
  """
  x, y = vals
  xdim, ydim = axes

  if xdim is None:
    # Only ydim is set. An x with an empty shape is passed through unchanged.
    if x.shape:
      if x.shape[ydim] == 1:
        # Size-1 dimension at ydim: reshape it away.
        x = x.reshape(np.delete(x.shape, ydim))
      else:
        # Otherwise take this replica's slice of x along ydim.
        x = _drop(x, ydim, name)
    return prim.bind(x, y, **params), ydim
  elif ydim is None:
    # Mirror image of the branch above, with the roles of x and y swapped.
    if y.shape:
      if y.shape[xdim] == 1:
        y = y.reshape(np.delete(y.shape, xdim))
      else:
        y = _drop(y, xdim, name)
    return prim.bind(x, y, **params), xdim
  elif xdim == ydim:
    # Both operands use the same axis: bind directly.
    return prim.bind(x, y, **params), xdim
  else:
    # The operands use different axes. Compute, for each operand, the local
    # position of the other operand's axis (shifted down by one when it lies
    # after the operand's own axis).
    x_tosplit = ydim - int(xdim <= ydim)
    y_tosplit = xdim - int(ydim <= xdim)
    # NOTE(review): all three branches below report ydim as the result axis,
    # including the first one where y is gathered and reshaped to line up with
    # x -- confirm that this is intended.
    if y.shape[y_tosplit] == 1:
      y = _allgather(y, ydim, size, name)
      y = y.reshape(np.delete(y.shape, xdim))
      return prim.bind(x, y, **params), ydim
    elif x.shape[x_tosplit] == 1:
      x = _allgather(x, xdim, size, name)
      x = x.reshape(np.delete(x.shape, ydim))
      return prim.bind(x, y, **params), ydim
    else:
      # Neither operand has a size-1 dimension to exploit: re-map x with an
      # all_to_all before binding.
      x = all_to_all(x, name, x_tosplit, xdim)
      return prim.bind(x, y, **params), ydim
def _defbroadcasting(prim):
  """Register ``_broadcasting_papply`` as the papply rule for ``prim``."""
  rule = partial(_broadcasting_papply, prim)
  parallel.papply_primitive_rules[prim] = rule
def _vectorized_papply(prim, name, size, vals, axes, **params):
assert all(axes[0] == a for a in axes[1:])
return prim.bind(*vals, **params), axes[0]
def _defvectorized(prim):
  """Register ``_vectorized_papply`` as the papply rule for ``prim``."""
  rule = partial(_vectorized_papply, prim)
  parallel.papply_primitive_rules[prim] = rule
def _reducer_papply(prim, collective, name, size, vals, papply_axes, axes, **kwargs):
operand, = vals
papply_axis, = papply_axes
other_axes = [i for i in axes if i != papply_axis]
other_axes = [i - 1 if i > papply_axis else i for i in other_axes]
if other_axes:
if 'input_shape' in kwargs: # special to the reduce-sum family
s = kwargs['input_shape']
kwargs['input_shape'] = s[:papply_axis] + s[papply_axis + 1:]
result = prim.bind(operand, axes=tuple(other_axes), **kwargs)
else:
result = operand
if not axes or papply_axis in axes:
return collective(result, axis_name=name), None
else:
new_papply_axis = papply_axis - np.sum(np.less(other_axes, papply_axis))
return result, new_papply_axis
def _defreducer(prim, collective_prim):
  """Register ``_reducer_papply`` for ``prim``, paired with ``collective_prim``."""
  rule = partial(_reducer_papply, prim, collective_prim)
  parallel.papply_primitive_rules[prim] = rule
def _identity_papply(prim, argnum, name, size, vals, axes, **params):
return prim.bind(*vals, **params), axes[argnum]
def _defidentity(prim, argnum=0):
  """Register ``_identity_papply`` for ``prim``, following operand ``argnum``."""
  rule = partial(_identity_papply, prim, argnum)
  parallel.papply_primitive_rules[prim] = rule
# Unary primitives: bound unchanged on the local operand, keeping its axis.
_defvectorized(lax.neg_p)
_defvectorized(lax.sign_p)
_defvectorized(lax.floor_p)
_defvectorized(lax.ceil_p)
_defvectorized(lax.round_p)
_defvectorized(lax.is_finite_p)
_defvectorized(lax.exp_p)
_defvectorized(lax.log_p)
_defvectorized(lax.expm1_p)
_defvectorized(lax.log1p_p)
_defvectorized(lax.tanh_p)
_defvectorized(lax.sin_p)
_defvectorized(lax.cos_p)
_defvectorized(lax.lgamma_p)
_defvectorized(lax.digamma_p)
_defvectorized(lax.erf_p)
_defvectorized(lax.erfc_p)
_defvectorized(lax.erf_inv_p)
_defvectorized(lax.real_p)
_defvectorized(lax.imag_p)
_defvectorized(lax.conj_p)
_defvectorized(lax.abs_p)
_defvectorized(lax.sqrt_p)

# Binary primitives: handled by _broadcasting_papply, which reconciles the
# axes of the two operands before binding.
_defbroadcasting(lax.atan2_p)
_defbroadcasting(lax.complex_p)
_defbroadcasting(lax.pow_p)
_defbroadcasting(lax.and_p)
_defbroadcasting(lax.or_p)
_defbroadcasting(lax.xor_p)
_defbroadcasting(lax.add_p)
_defbroadcasting(lax.sub_p)
_defbroadcasting(lax.mul_p)
_defbroadcasting(lax.div_p)
_defbroadcasting(lax.rem_p)
_defbroadcasting(lax.max_p)
_defbroadcasting(lax.min_p)
_defbroadcasting(lax.shift_left_p)
_defbroadcasting(lax.shift_right_arithmetic_p)
_defbroadcasting(lax.shift_right_logical_p)

# tie_in: bound as-is; the result takes the axis of operand 0 (the default
# argnum).
_defidentity(lax.tie_in_p)

# Reductions, each paired with the collective that combines partial results
# across replicas.
_defreducer(lax.reduce_sum_p, psum)
_defreducer(lax.reduce_max_p, pmax)
_defreducer(lax.reduce_min_p, pmin)
def _dot_general_papply_rule(name, size, vals, dims, dimension_numbers,
                             precision):
  """papply rule for dot_general.

  Only the cases in which the lhs is split along a non-contracting dimension
  and the rhs is either not split or split along a non-contracting dimension
  are implemented (cases e and f in the table inside ``cases``).

  Args:
    name: name of the mapped axis.
    size: size of the mapped axis.
    vals: pair ``(x, y)`` of operands.
    dims: pair ``(xdim, ydim)`` of split dimensions; ``None`` means not split.
    dimension_numbers: ``((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))``.
    precision: forwarded to ``lax.dot_general``.

  Returns:
    A pair ``(z, zdim)`` of the result and its split dimension.

  Raises:
    NotImplementedError: if batch dimensions are present, or for any
      unimplemented split/contracting combination.
  """
  x, y = vals
  xdim, ydim = dims

  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers

  if lhs_batch or rhs_batch:
    raise NotImplementedError(
        ('papply of dot_general with batch dimensions: '
         'xdim={}, ydim={}, dimension_numbers={}').format(
             xdim, ydim, dimension_numbers))

  def adjust_dims(dims, thresh):
    # Drop `thresh` from dims and shift every larger entry down by one.
    return tuple(i - 1 if i > thresh else i for i in dims if i != thresh)

  def sub_dims(xdim, ydim, xcontract, ycontract, xbatch, ybatch):
    # Dimension numbers re-expressed for operands with the given split
    # dimension (if any) removed.
    if xdim is not None:
      xbatch = adjust_dims(xbatch, xdim)
      xcontract = adjust_dims(xcontract, xdim)
    if ydim is not None:
      ybatch = adjust_dims(ybatch, ydim)
      ycontract = adjust_dims(ycontract, ydim)
    return ((xcontract, ycontract), (xbatch, ybatch))

  def cases(x, y, xdim, ydim, xc, yc, xb, yb):
    # Returns (True, (z, zdim)) on success, or (False, reason) for a case that
    # is not implemented.
    #
    # Consider three states in which an operand may be
    #   1: split, contracting
    #   2: split, not contracting
    #   3: not split
    #
    # We will handle the following cases, marked by corresponding letter
    # symbols:
    #
    #    |1 2 3|y
    #   -+-----+-
    #   1|a b c
    #   2|d e f
    #   3|g h i
    #   -+
    #   x|
    #
    # Case i is already covered and we can assume that it is excluded at the
    # outset, since a papply rule is not invoked when no operands are split.

    if xdim in xc:
      # cases a, b, c

      if ydim in yc:
        # case a: both operands are split and contracting
        # TODO(frostig): Might the following work?
        # z = lax.dot_general(
        #     x, y, sub_dims(xdim, ydim, xc, yc, xb, yb), precision)
        # return True, (psum(z, name), None)
        return False, 'both operands split and contracting'

      elif ydim is not None:
        # case b: x split and contracting, y split but not contracting
        # TODO(frostig): Might the following work?
        # new_ydim = yc[xc.index(xdim)]
        # y = all_to_all(y, name, new_ydim, ydim)
        # z = lax.dot_general(
        #     x, y, sub_dims(xdim, new_ydim, xc, yc, xb, yb), precision)
        # return True, (psum(z, name), None)
        return False, 'rhs split but not contracting, lhs split and contracting'

      else:
        # case c: x split and contracting, y not split
        assert ydim is None
        return False, 'one operand split and contracting, other is not split'

    elif xdim is not None:
      # cases d, e, f

      if ydim in yc:
        # case d: x split but not contracting, y split and contracting
        # TODO(frostig): Might the following work?
        # new_xdim = xc[yc.index(ydim)]
        # x = all_to_all(x, name, new_xdim, xdim)
        # z = lax.dot_general(
        #     x, y, sub_dims(new_xdim, ydim, xc, yc, xb, yb), precision)
        # return True, (psum(z, name), None)
        return False, 'lhs split but not contracting, rhs split and contracting'

      elif ydim is not None:
        # case e: both operands are split but not contracting
        # y is gathered in full, so only x remains split for the local dot.
        y = _allgather(y, ydim, size, name)
        z = lax.dot_general(
            x, y, sub_dims(xdim, None, xc, yc, xb, yb), precision)
        # Output position of xdim: shifted by the batch dims and by the
        # contracting dims of x that precede it.
        zdim = xdim + len(xb) - len([d for d in range(xdim) if d in xc])
        return True, (z, zdim)

      else:
        # case f: x split but not contracting, y not split
        assert ydim is None
        z = lax.dot_general(
            x, y, sub_dims(xdim, None, xc, yc, xb, yb), precision)
        zdim = xdim + len(xb) - len([d for d in range(xdim) if d in xc])
        return True, (z, zdim)

    else:
      # cases g, h
      assert xdim is None

      if ydim in yc:
        # case g: x not split, y split and contracting
        return False, 'one operand split and contracting, other is not split'

      else:
        # case h: x not split, y split but not contracting
        assert ydim is not None
        # TODO(frostig): Might the following work?
        # z = lax.dot_general(
        #     x, y, sub_dims(None, ydim, xc, yc, xb, yb), precision)
        # zdim = (
        #     ydim + len(xb) +         # batch dimensions
        #     x.ndim - len(xc) -       # non-contracting x dimensions
        #     len([d for d in range(ydim) if d in yc]))
        # return True, (z, zdim)
        return False, 'lhs not split, rhs split but not contracting'

    assert False, 'unreachable'

  ok, out = cases(
      x, y, xdim, ydim, lhs_contract, rhs_contract, lhs_batch, rhs_batch)
  if ok:
    return out
  else:
    raise NotImplementedError(
        ('papply of dot_general, {}: '
         'xdim={}, ydim={}, dimension_numbers={}').format(
             out, xdim, ydim, dimension_numbers))
def _reshape_papply_rule(name, size, vals, axes, new_sizes, dimensions):
operand, = vals
axis, = axes
old_sizes = tuple(np.insert(operand.shape, axis, size))
def filter_ones(xs):
return filter(lambda x: x != 1, xs)
def find_new_axis(old_axis, old_sizes, new_sizes):
left = np.prod(old_sizes[:old_axis])
size = old_sizes[old_axis]
prod = 1
for i, cur_sz in enumerate(new_sizes):
if prod == left and cur_sz == size:
return i
prod = prod * cur_sz
return None
if dimensions is None:
new_axis = find_new_axis(axis, old_sizes, new_sizes)
if new_axis is not None:
new_sizes_ = new_sizes[:new_axis] + new_sizes[new_axis + 1:]
return lax.reshape(operand, new_sizes_, dimensions=dimensions), new_axis
else:
raise NotImplementedError(
'papply of reshape that would change hidden dimension size')
else:
raise NotImplementedError('papply of reshape with `dimensions`')
def _transpose_papply_rule(name, size, vals, dims, permutation):
x, = vals
xdim, = dims
local_perm = [i if i < xdim else i - 1 for i in permutation if i != xdim]
return lax.transpose(x, local_perm), permutation.index(xdim)
def _select_papply_rule(name, size, vals, dims):
dimset = {d for d in dims if d is not None}
if len(dimset) != 1:
raise NotImplementedError(
'papply of select with operands split along different dimensions')
dim, = dimset
def drop(x, d):
return _drop(x, dim, name) if d is None else x
return lax.select_p.bind(*map(drop, vals, dims)), dim
def _add_jaxvals_papply_rule(name, size, vals, dims):
x, y = vals
xdim, ydim = dims
if xdim == ydim:
out_dim = xdim
else:
raise NotImplementedError
# elif ydim is None:
# y = lax.psplit_like(y, x, name)
# out_dim = xdim
# else:
# x = lax.psplit_like(x, y, name)
# out_dim = ydim
return ad_util.add_jaxvals_p.bind(x, y), out_dim
def _convert_element_type_papply_rule(
name, size, vals, dims, new_dtype, **params):
operand, = vals
dim, = dims
return lax.convert_element_type(operand, new_dtype), dim
def _conv_general_dilated_papply_rule(
name, size, vals, dims, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, precision, **unused_kwargs):
lhs, rhs = vals
lhs_dim, rhs_dim = dims
lhs_spec_batch_dim = dimension_numbers.lhs_spec[0]
if rhs_dim is None and lhs_dim == lhs_spec_batch_dim:
lhs = lax.reshape(lhs, tuple(np.insert(lhs.shape, lhs_dim, 1)))
out = lax.conv_general_dilated(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, precision)
return out, lhs_dim
else:
raise NotImplementedError(
"splitting a convolution along anything but input batch dimension")
def _broadcast_in_dim_papply_rule(name, size, vals, dims, shape,
broadcast_dimensions):
operand, = vals
dim, = dims
out_dim = broadcast_dimensions[dim]
if shape[out_dim] != shape[dim]:
raise ValueError(
"broadcast_in_dim changes hidden dimension size: {} to {}".format(
shape[dim], shape[out_dim]))
sub_bdims = tuple(np.delete(broadcast_dimensions, dim))
sub_shape = tuple(np.delete(shape, out_dim))
return lax.broadcast_in_dim(operand, sub_shape, sub_bdims), out_dim
def _pad_papply_rule(name, size, vals, dims, padding_config):
operand, padding_value = vals
operand_dim, padding_value_dim = dims
assert padding_value_dim is None
padding_config = list(padding_config)
if padding_config[operand_dim] == (0, 0, 0):
padded = lax.pad(
operand,
padding_value,
padding_config[:operand_dim] + padding_config[operand_dim + 1:])
return padded, operand_dim
else:
raise NotImplementedError(
'pad changes size of hidden dimension {} with config {}'.format(
operand_dim, padding_config))
def _slice_papply_rule(name, size, vals, dims, start_indices, limit_indices,
strides, **kwargs):
operand, = vals
dim, = dims
start_indices = list(start_indices)
limit_indices = list(limit_indices)
if (start_indices[dim] != 0 or
limit_indices[dim] != size or
strides is not None and strides[dim] != 1):
raise NotImplementedError('slice changes side of hidden dimension')
out = lax.slice(
operand,
start_indices[:dim] + start_indices[dim + 1:],
limit_indices[:dim] + limit_indices[dim + 1:],
strides[:dim] + strides[dim + 1:] if strides is not None else None)
return out, dim
def _gather_papply_rule(
name, size, vals, dims, dimension_numbers, slice_sizes, operand_shape):
operand, start_indices = vals
operand_dim, start_indices_dim = dims
if (operand_dim is None and
start_indices_dim is not None and
start_indices_dim not in dimension_numbers.offset_dims and
dimension_numbers.collapsed_slice_dims == (0,)):
offset_dims = tuple(i - 1 if i > start_indices_dim else i
for i in dimension_numbers.offset_dims)
dnums = lax.GatherDimensionNumbers(
offset_dims=offset_dims,
collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
start_index_map=dimension_numbers.start_index_map)
out = lax.gather(operand, start_indices, dimension_numbers=dnums,
slice_sizes=slice_sizes)
out_dim = start_indices_dim + np.sum(
np.less_equal(offset_dims, start_indices_dim))
return out, out_dim
else:
raise NotImplementedError
# Register the primitive-specific papply rules in the papply rule table.
parallel.papply_primitive_rules[lax.dot_general_p] = _dot_general_papply_rule
parallel.papply_primitive_rules[lax.reshape_p] = _reshape_papply_rule
parallel.papply_primitive_rules[lax.transpose_p] = _transpose_papply_rule
parallel.papply_primitive_rules[lax.select_p] = _select_papply_rule
parallel.papply_primitive_rules[ad_util.add_jaxvals_p] = \
    _add_jaxvals_papply_rule
parallel.papply_primitive_rules[lax.convert_element_type_p] = \
    _convert_element_type_papply_rule
parallel.papply_primitive_rules[lax.conv_general_dilated_p] = \
    _conv_general_dilated_papply_rule
parallel.papply_primitive_rules[lax.broadcast_in_dim_p] = \
    _broadcast_in_dim_papply_rule
parallel.papply_primitive_rules[lax.pad_p] = _pad_papply_rule
parallel.papply_primitive_rules[lax.slice_p] = _slice_papply_rule
parallel.papply_primitive_rules[lax.gather_p] = _gather_papply_rule