# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Parallelization primitives.
"""

import collections

import numpy as np

from jax import core
from jax import ad_util
from jax import dtypes
from jax import tree_util
from jax.lax import lax
from jax.abstract_arrays import ShapedArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import parallel
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.util import partial, unzip2, prod
from jax.lib import xla_client as xc

from jax.interpreters.pxla import axis_index

xops = xc.ops


### parallel traceables

def psum(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function
  to each leaf in the tree.

  Inputs of boolean dtype are converted to integers before the reduction.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce sum along the axis ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [6 6 6 6]
  >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [0.         0.16666667 0.33333334 0.5       ]
  """
  _validate_axis_index_groups(axis_index_groups)
  leaves, treedef = tree_util.tree_flatten(x)
  leaves = [lax.convert_element_type(l, np.int32)
            if dtypes.dtype(l) == np.bool_ else l for l in leaves]
  out_flat = psum_p.bind(*leaves, axis_name=axis_name,
                         axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, out_flat)

def pmean(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce mean on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function
  to each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform pmeans over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce mean along the axis ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.pmean(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [1.5 1.5 1.5 1.5]
  >>> y = jax.pmap(lambda x: x / jax.lax.pmean(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [0.         0.66666667 1.33333334 2.0       ]
  """
  x = psum(x, axis_name=axis_name, axis_index_groups=axis_index_groups)
  n = psum(1, axis_name=axis_name, axis_index_groups=axis_index_groups)
  return tree_util.tree_map(lambda v: v / n, x)

def pmax(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce max on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function
  to each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform pmaxes over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce max along the axis ``axis_name``.
  """
  _validate_axis_index_groups(axis_index_groups)
  return tree_util.tree_map(partial(
      pmax_p.bind, axis_name=axis_name,
      axis_index_groups=axis_index_groups), x)

def pmin(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce min on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function
  to each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform pmins over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce min along the axis ``axis_name``.
  """
  _validate_axis_index_groups(axis_index_groups)
  return tree_util.tree_map(partial(
      pmin_p.bind, axis_name=axis_name,
      axis_index_groups=axis_index_groups), x)

def _validate_axis_index_groups(axis_index_groups):
  if axis_index_groups is None:
    return
  len_0 = len(axis_index_groups[0])
  if any(len(g) != len_0 for g in axis_index_groups):
    raise ValueError("axis_index_groups must all be the same size")
  axis_space = range(len_0 * len(axis_index_groups))
  if set(i for g in axis_index_groups for i in g) != set(axis_space):
    raise ValueError("axis_index_groups must cover all indices exactly once")

def ppermute(x, axis_name, perm):
  """Perform a collective permutation according to the permutation ``perm``.

  If ``x`` is a pytree then the result is equivalent to mapping this function
  to each leaf in the tree.

  This function is an analog of the CollectivePermute XLA HLO.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    perm: list of pairs of ints, representing
      ``(source_index, destination_index)`` pairs that encode how the mapped
      axis named ``axis_name`` should be shuffled. The integer values are
      treated as indices into the mapped axis ``axis_name``.
      No two pairs may share a source index or a destination index. For each
      index of the axis ``axis_name`` that does not correspond to a destination
      index in ``perm``, the corresponding values in the result are filled with
      zeros of the appropriate type.

  Returns:
    Array(s) with the same shape as ``x`` with slices along the axis
    ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
  """
  return tree_util.tree_map(
      partial(ppermute_p.bind, axis_name=axis_name, perm=tuple(perm)), x)

def pshuffle(x, axis_name, perm):
  """Convenience wrapper of ``jax.lax.ppermute`` with an alternate permutation encoding.

  If ``x`` is a pytree then the result is equivalent to mapping this function
  to each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    perm: list of ints encoding sources for the permutation to be applied to
      the axis named ``axis_name``, so that the output at axis index i comes
      from the input at axis index perm[i]. Every integer in [0, N) should be
      included exactly once for axis size N.

  Returns:
    Array(s) with the same shape as ``x`` with slices along the axis
    ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
  """
  if set(perm) != set(range(len(perm))):
    raise ValueError(f"`perm` does not represent a permutation: {perm}")
  return ppermute(x, axis_name, list(zip(perm, range(len(perm)))))

def pswapaxes(x, axis_name, axis):
  """Swap the pmapped axis ``axis_name`` with the unmapped axis ``axis``.

  If ``x`` is a pytree then the result is equivalent to mapping this function
  to each leaf in the tree.

  The mapped axis size must be equal to the size of the unmapped axis; that is,
  we must have ``lax.psum(1, axis_name) == x.shape[axis]``.

  This function is a special case of ``all_to_all`` where the pmapped axis of
  the input is placed at the position ``axis`` in the output. That is, it is
  equivalent to ``all_to_all(x, axis_name, axis, axis)``.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis: int indicating the unmapped axis of ``x`` to map with the name
      ``axis_name``.

  Returns:
    Array(s) with shape ``np.insert(np.delete(x.shape, axis), axis, axis_size)``
    where ``axis_size`` is the size of the mapped axis named ``axis_name`` in
    the input ``x``.
  """
  return all_to_all(x, axis_name, axis, axis)

def all_to_all(x, axis_name, split_axis, concat_axis):
  """Materialize the mapped axis and map a different axis.

  If ``x`` is a pytree then the result is equivalent to mapping this function
  to each leaf in the tree.

  In the output, the input mapped axis ``axis_name`` is materialized at the
  logical axis position ``concat_axis``, and the input unmapped axis at
  position ``split_axis`` is mapped with the name ``axis_name``.

  The input mapped axis size must be equal to the size of the axis to be
  mapped; that is, we must have
  ``lax.psum(1, axis_name) == x.shape[split_axis]``.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    split_axis: int indicating the unmapped axis of ``x`` to map with the name
      ``axis_name``.
    concat_axis: int indicating the position in the output to materialize the
      mapped axis of the input with the name ``axis_name``.

  Returns:
    Array(s) with shape given by the expression::

      np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)

    where ``axis_size`` is the size of the mapped axis named ``axis_name`` in
    the input ``x``, i.e. ``axis_size = lax.psum(1, axis_name)``.
  """
  def bind(x):
    if psum(1, axis_name) != x.shape[split_axis]:
      msg = ("all_to_all requires the size of the mapped axis axis_name to "
             "equal x.shape[split_axis], but they are {} and {} respectively.")
      raise ValueError(msg.format(psum(1, axis_name), x.shape[split_axis]))
    return all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis,
                             axis_name=axis_name)
  return tree_util.tree_map(bind, x)


### parallel primitives

def standard_pmap_primitive(name, multiple_results=False):
  prim = core.Primitive(name)
  prim.multiple_results = multiple_results
  prim.def_impl(partial(pxla.apply_parallel_primitive, prim))
  prim.def_abstract_eval(lambda x, *args, **params: x)
  return prim


def _allreduce_split_axis_rule(prim, reducer, vals, which_mapped, axis_name,
                               axis_index_groups):
  assert tuple(which_mapped) == (True,)
  if axis_index_groups is not None:
    raise NotImplementedError("soft_pmap does not yet support axis_index_groups")
  vals = (reducer(x, [0]) for x in vals)
  return prim.bind(*vals, axis_name=axis_name), False

def _allreduce_translation_rule(prim, c, val, replica_groups, platform=None):
  dtype = c.get_shape(val).numpy_dtype()
  scalar = ShapedArray((), dtype)
  computation = xla.primitive_subcomputation(prim, scalar, scalar)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  return xops.AllReduce(val, computation, replica_groups_protos, None, None)

# psum translation rule has special handling for complex dtypes
def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
  if platform in ("cpu", "tpu"):
    return _notuple_psum_translation_rule(c, *args, replica_groups=replica_groups)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce for each argument input type.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    n = len(dtype_args)
    if is_complex:
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(lax.add_p, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex:
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i))
            for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)

# TODO(b/150476027): CPU doesn't support tuple all-reduce correctly. But
# fortunately we don't really need it in that case because CPU doesn't support
# cross-task communication either.
# TODO(b/155446630): An XLA:TPU optimization pass also doesn't support
# tuple all-reduce yet. Meanwhile, rely on deterministic compiler behavior.
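
# Hedged, illustrative sketch only (this helper is not used anywhere in this
# module and its name is ours): at the traceable level, a psum over a complex
# value is equivalent to reducing the real and imaginary parts independently
# and recombining them. The complex-dtype branches of the translation rules
# above and below implement the same decomposition at the XLA level.
def _psum_complex_reference_sketch(x, axis_name):
  # Decompose into real/imaginary parts, reduce each part, then recombine.
  return lax.complex(psum(lax.real(x), axis_name),
                     psum(lax.imag(x), axis_name))
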
def _notuple_psum_translation_rule(c, *args, replica_groups):
  def _translate(val):
    psum = partial(_allreduce_translation_rule, lax.add_p, c,
                   replica_groups=replica_groups)
    dtype = c.get_shape(val).numpy_dtype()
    if dtypes.issubdtype(dtype, np.complexfloating):
      return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
    else:
      return psum(val)
  return xops.Tuple(c, list(map(_translate, args)))

def _psum_transpose_rule(cts, axis_name, axis_index_groups):
  nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
  nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axis_name=axis_name,
                               axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, nonzero_in_cts)

psum_p = standard_pmap_primitive('psum', multiple_results=True)
psum_p.def_abstract_eval(
    lambda *args, **params: tuple(map(raise_to_shaped, args)))
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)
ad.deflinear(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)


pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)


pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
    partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)


def _ppermute_translation_rule(c, x, replica_groups, perm, platform=None):
  group_size = len(replica_groups[0])
  srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm)
  if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))):
    msg = "ppermute sources and destinations must be unique, got {}."
    raise ValueError(msg.format(perm))

  full_perm = []
  for grp in replica_groups:
    grp = list(sorted(grp))
    full_perm.extend((grp[src], grp[dst]) for src, dst in perm)
  return xops.CollectivePermute(x, full_perm)

def _ppermute_transpose_rule(t, perm, axis_name):
  srcs, dsts = unzip2(perm)
  inverse_perm = list(zip(dsts, srcs))
  return [ppermute(t, axis_name=axis_name, perm=inverse_perm)]

ppermute_p = standard_pmap_primitive('ppermute')
ad.deflinear(ppermute_p, _ppermute_transpose_rule)
xla.parallel_translations[ppermute_p] = _ppermute_translation_rule
pxla.multi_host_supported_collectives.add(ppermute_p)


def _all_to_all_translation_rule(c, x, split_axis, concat_axis, replica_groups,
                                 platform=None):
  # Workaround for AllToAll not being implemented on CPU.
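  # When each replica group contains a single replica, the all-to-all is the
  # identity on its operand, so returning ``x`` unchanged is both correct and
  # avoids emitting the unsupported op.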
  if len(replica_groups[0]) == 1:
    return x
  else:
    split_count = len(replica_groups[0])
    if not all(split_count == len(g) for g in replica_groups):
      raise ValueError('Replica groups must be equally sized')
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    return xops.AllToAll(x, split_axis, concat_axis, split_count,
                         replica_groups_protos)

def _all_to_all_split_axis_rule(vals, which_mapped, split_axis, concat_axis,
                                axis_name):
  assert tuple(which_mapped) == (True,)
  x, = vals
  # perform the communication to swap the hardware-mapped axes
  stacked = all_to_all_p.bind(x, split_axis=split_axis + 1, concat_axis=0,
                              axis_name=axis_name)
  # transpose the newly mapped axis to the front, newly unmapped to concat_axis
  out = _moveaxis(split_axis + 1, 0, stacked)
  out = _moveaxis(1, concat_axis + 1, out)
  return out, True

def _moveaxis(src, dst, x):
  perm = [i for i in range(x.ndim) if i != src]
  perm.insert(dst, src)
  return lax.transpose(x, perm)

all_to_all_p = standard_pmap_primitive('all_to_all')
xla.parallel_translations[all_to_all_p] = _all_to_all_translation_rule
pxla.split_axis_rules[all_to_all_p] = _all_to_all_split_axis_rule


### papply rules
# TODO(skye): it would be nice if we could put these with their corresponding
# primitives, but that currently causes circular dependencies. More refactoring
# might fix this.


def _drop(x, dim, axis_name):
  return lax.dynamic_index_in_dim(x, axis_index(axis_name), dim, False)

def _expand(dim, size, axis_name, x):
  shape = list(x.shape)
  shape.insert(dim, size)
  out = lax.full(shape, lax._const(x, 0))
  return lax.dynamic_update_index_in_dim(out, x, axis_index(axis_name), dim)

def _allgather(x, dim, size, axis_name):
  outs = tree_util.tree_map(partial(_expand, dim, size, axis_name), x)
  return psum(outs, axis_name)

def all_gather(x, axis_name):
  """Gather values of x across all replicas.

  If ``x`` is a pytree then the result is equivalent to mapping this function
  to each leaf in the tree.

  This is equivalent to, but faster than, all_to_all(broadcast(x)).

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).

  Returns:
    Array(s) representing the result of an all-gather along the axis
    ``axis_name``. Shapes are the same as ``x.shape``, but with a leading
    dimension of the axis_size.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [[0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]]
  """
  return _allgather(x, 0, psum(1, axis_name), axis_name)


def _broadcasting_papply(prim, name, size, vals, axes, **params):
  x, y = vals
  xdim, ydim = axes

  if xdim is None:
    if x.shape:
      if x.shape[ydim] == 1:
        x = x.reshape(np.delete(x.shape, ydim))
      else:
        x = _drop(x, ydim, name)
    return prim.bind(x, y, **params), ydim
  elif ydim is None:
    if y.shape:
      if y.shape[xdim] == 1:
        y = y.reshape(np.delete(y.shape, xdim))
      else:
        y = _drop(y, xdim, name)
    return prim.bind(x, y, **params), xdim
  elif xdim == ydim:
    return prim.bind(x, y, **params), xdim
  else:
    x_tosplit = ydim - int(xdim <= ydim)
    y_tosplit = xdim - int(ydim <= xdim)
    if y.shape[y_tosplit] == 1:
      y = _allgather(y, ydim, size, name)
      y = y.reshape(np.delete(y.shape, xdim))
      return prim.bind(x, y, **params), ydim
    elif x.shape[x_tosplit] == 1:
      x = _allgather(x, xdim, size, name)
      x = x.reshape(np.delete(x.shape, ydim))
      return prim.bind(x, y, **params), ydim
    else:
      x = all_to_all(x, name, x_tosplit, xdim)
      return prim.bind(x, y, **params), ydim

def _defbroadcasting(prim):
  parallel.papply_primitive_rules[prim] = partial(_broadcasting_papply, prim)


def _vectorized_papply(prim, name, size, vals, axes, **params):
  assert all(axes[0] == a for a in axes[1:])
  return prim.bind(*vals, **params), axes[0]

def _defvectorized(prim):
  parallel.papply_primitive_rules[prim] = partial(_vectorized_papply, prim)


def _reducer_papply(prim, collective, name, size, vals, papply_axes, axes,
                    **kwargs):
  operand, = vals
  papply_axis, = papply_axes

  other_axes = [i for i in axes if i != papply_axis]
  other_axes = [i - 1 if i > papply_axis else i for i in other_axes]

  if other_axes:
    if 'input_shape' in kwargs:  # special to the reduce-sum family
      s = kwargs['input_shape']
      kwargs['input_shape'] = s[:papply_axis] + s[papply_axis + 1:]
    result = prim.bind(operand, axes=tuple(other_axes), **kwargs)
  else:
    result = operand

  if not axes or papply_axis in axes:
    return collective(result, axis_name=name), None
  else:
    new_papply_axis = papply_axis - np.sum(np.less(other_axes, papply_axis))
    return result, new_papply_axis

def _defreducer(prim, collective_prim):
  parallel.papply_primitive_rules[prim] = partial(_reducer_papply, prim,
                                                  collective_prim)


def _identity_papply(prim, argnum, name, size, vals, axes, **params):
  return prim.bind(*vals, **params), axes[argnum]

def _defidentity(prim, argnum=0):
  parallel.papply_primitive_rules[prim] = partial(_identity_papply, prim, argnum)


_defvectorized(lax.neg_p)
_defvectorized(lax.sign_p)
_defvectorized(lax.floor_p)
_defvectorized(lax.ceil_p)
_defvectorized(lax.round_p)
_defvectorized(lax.is_finite_p)
_defvectorized(lax.exp_p)
_defvectorized(lax.log_p)
_defvectorized(lax.expm1_p)
_defvectorized(lax.log1p_p)
_defvectorized(lax.tanh_p)
_defvectorized(lax.sin_p)
_defvectorized(lax.cos_p)
_defvectorized(lax.lgamma_p)
_defvectorized(lax.digamma_p)
_defvectorized(lax.erf_p)
_defvectorized(lax.erfc_p)
_defvectorized(lax.erf_inv_p)
_defvectorized(lax.real_p)
_defvectorized(lax.imag_p)
_defvectorized(lax.conj_p)
_defvectorized(lax.abs_p)
_defvectorized(lax.sqrt_p)

_defbroadcasting(lax.atan2_p)
_defbroadcasting(lax.complex_p)
_defbroadcasting(lax.pow_p)
_defbroadcasting(lax.and_p)
_defbroadcasting(lax.or_p)
_defbroadcasting(lax.xor_p)
_defbroadcasting(lax.add_p)
_defbroadcasting(lax.sub_p)
_defbroadcasting(lax.mul_p)
_defbroadcasting(lax.div_p)
_defbroadcasting(lax.rem_p)
_defbroadcasting(lax.max_p)
_defbroadcasting(lax.min_p)
_defbroadcasting(lax.shift_left_p)
_defbroadcasting(lax.shift_right_arithmetic_p)
_defbroadcasting(lax.shift_right_logical_p)

_defidentity(lax.tie_in_p)

_defreducer(lax.reduce_sum_p, psum)
_defreducer(lax.reduce_max_p, pmax)
_defreducer(lax.reduce_min_p, pmin)


def _dot_general_papply_rule(name, size, vals, dims, dimension_numbers,
                             precision):
  x, y = vals
  xdim, ydim = dims

  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers

  if lhs_batch or rhs_batch:
    raise NotImplementedError(
        ('papply of dot_general with batch dimensions: '
         'xdim={}, ydim={}, dimension_numbers={}').format(
             xdim, ydim, dimension_numbers))

  def adjust_dims(dims, thresh):
    return tuple(i - 1 if i > thresh else i for i in dims if i != thresh)

  def sub_dims(xdim, ydim, xcontract, ycontract, xbatch, ybatch):
    if xdim is not None:
      xbatch = adjust_dims(xbatch, xdim)
      xcontract = adjust_dims(xcontract, xdim)
    if ydim is not None:
      ybatch = adjust_dims(ybatch, ydim)
      ycontract = adjust_dims(ycontract, ydim)
    return ((xcontract, ycontract), (xbatch, ybatch))

  def cases(x, y, xdim, ydim, xc, yc, xb, yb):
    # Consider three states in which an operand may be
    #   1: split, contracting
    #   2: split, not contracting
    #   3: not split
    #
    # We will handle the following cases, marked by corresponding letter
    # symbols:
    #
    #    |1 2 3|y
    #   -+-----+-
    #   1|a b c
    #   2|d e f
    #   3|g h i
    #   -+
    #   x|
    #
    # Case i is already covered and we can assume that it is excluded at the
    # outset, since a papply rule is not invoked when no operands are split.

    if xdim in xc:
      # cases a, b, c
      if ydim in yc:
        # case a: both operands are split and contracting
        # TODO(frostig): Might the following work?
        # z = lax.dot_general(
        #     x, y, sub_dims(xdim, ydim, xc, yc, xb, yb), precision)
        # return True, (psum(z, name), None)
        return False, 'both operands split and contracting'
      elif ydim is not None:
        # case b: x split and contracting, y split but not contracting
        # TODO(frostig): Might the following work?
        # new_ydim = yc[xc.index(xdim)]
        # y = all_to_all(y, name, new_ydim, ydim)
        # z = lax.dot_general(
        #     x, y, sub_dims(xdim, new_ydim, xc, yc, xb, yb), precision)
        # return True, (psum(z, name), None)
        return False, 'rhs split but not contracting, lhs split and contracting'
      else:
        # case c: x split and contracting, y not split
        assert ydim is None
        return False, 'one operand split and contracting, other is not split'

    elif xdim is not None:
      # cases d, e, f
      if ydim in yc:
        # case d: x split but not contracting, y split and contracting
        # TODO(frostig): Might the following work?
        # new_xdim = xc[yc.index(ydim)]
        # x = all_to_all(x, name, new_xdim, xdim)
        # z = lax.dot_general(
        #     x, y, sub_dims(new_xdim, ydim, xc, yc, xb, yb), precision)
        # return True, (psum(z, name), None)
        return False, 'lhs split but not contracting, rhs split and contracting'
      elif ydim is not None:
        # case e: both operands are split but not contracting
        y = _allgather(y, ydim, size, name)
        z = lax.dot_general(
            x, y, sub_dims(xdim, None, xc, yc, xb, yb), precision)
        zdim = xdim + len(xb) - len([d for d in range(xdim) if d in xc])
        return True, (z, zdim)
      else:
        # case f: x split but not contracting, y not split
        assert ydim is None
        z = lax.dot_general(
            x, y, sub_dims(xdim, None, xc, yc, xb, yb), precision)
        zdim = xdim + len(xb) - len([d for d in range(xdim) if d in xc])
        return True, (z, zdim)

    else:
      # cases g, h
      assert xdim is None
      if ydim in yc:
        # case g: x not split, y split and contracting
        return False, 'one operand split and contracting, other is not split'
      else:
        # case h: x not split, y split but not contracting
        assert ydim is not None
        # TODO(frostig): Might the following work?
        # z = lax.dot_general(
        #     x, y, sub_dims(None, ydim, xc, yc, xb, yb), precision)
        # zdim = (
        #     ydim + len(xb) +    # batch dimensions
        #     x.ndim - len(xc) -  # non-contracting x dimensions
        #     len([d for d in range(ydim) if d in yc]))
        # return True, (z, zdim)
        return False, 'lhs not split, rhs split but not contracting'

    assert False, 'unreachable'

  ok, out = cases(
      x, y, xdim, ydim, lhs_contract, rhs_contract, lhs_batch, rhs_batch)
  if ok:
    return out
  else:
    raise NotImplementedError(
        ('papply of dot_general, {}: '
         'xdim={}, ydim={}, dimension_numbers={}').format(
             out, xdim, ydim, dimension_numbers))


def _reshape_papply_rule(name, size, vals, axes, new_sizes, dimensions):
  operand, = vals
  axis, = axes
  old_sizes = tuple(np.insert(operand.shape, axis, size))

  def filter_ones(xs):
    return filter(lambda x: x != 1, xs)

  def find_new_axis(old_axis, old_sizes, new_sizes):
    left = np.prod(old_sizes[:old_axis])
    size = old_sizes[old_axis]
    prod = 1
    for i, cur_sz in enumerate(new_sizes):
      if prod == left and cur_sz == size:
        return i
      prod = prod * cur_sz
    return None

  if dimensions is None:
    new_axis = find_new_axis(axis, old_sizes, new_sizes)
    if new_axis is not None:
      new_sizes_ = new_sizes[:new_axis] + new_sizes[new_axis + 1:]
      return lax.reshape(operand, new_sizes_, dimensions=dimensions), new_axis
    else:
      raise NotImplementedError(
          'papply of reshape that would change hidden dimension size')
  else:
    raise NotImplementedError('papply of reshape with `dimensions`')


def _transpose_papply_rule(name, size, vals, dims, permutation):
  x, = vals
  xdim, = dims
  local_perm = [i if i < xdim else i - 1 for i in permutation if i != xdim]
  return lax.transpose(x, local_perm), permutation.index(xdim)


def _select_papply_rule(name, size, vals, dims):
  dimset = {d for d in dims if d is not None}
  if len(dimset) != 1:
    raise NotImplementedError(
        'papply of select with operands split along different dimensions')
  dim, = dimset
  def drop(x, d):
    return _drop(x, dim, name) if d is None else x
  return lax.select_p.bind(*map(drop, vals, dims)), dim


def _add_jaxvals_papply_rule(name, size, vals, dims):
  x, y = vals
  xdim, ydim = dims
  if xdim == ydim:
    out_dim = xdim
  else:
    raise NotImplementedError
  # elif ydim is None:
  #   y = lax.psplit_like(y, x, name)
  #   out_dim = xdim
  # else:
  #   x = lax.psplit_like(x, y, name)
  #   out_dim = ydim
  return ad_util.add_jaxvals_p.bind(x, y), out_dim


def _convert_element_type_papply_rule(
    name, size, vals, dims, new_dtype, **params):
  operand, = vals
  dim, = dims
  return lax.convert_element_type(operand, new_dtype), dim


def _conv_general_dilated_papply_rule(
    name, size, vals, dims, window_strides, padding, lhs_dilation,
    rhs_dilation, dimension_numbers, feature_group_count, precision,
    **unused_kwargs):
  lhs, rhs = vals
  lhs_dim, rhs_dim = dims
  lhs_spec_batch_dim = dimension_numbers.lhs_spec[0]
  if rhs_dim is None and lhs_dim == lhs_spec_batch_dim:
    lhs = lax.reshape(lhs, tuple(np.insert(lhs.shape, lhs_dim, 1)))
    out = lax.conv_general_dilated(
        lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers, feature_group_count, precision)
    return out, lhs_dim
  else:
    raise NotImplementedError(
        "splitting a convolution along anything but input batch dimension")


def _broadcast_in_dim_papply_rule(name, size, vals, dims, shape,
                                  broadcast_dimensions):
  operand, = vals
  dim, = dims
  out_dim = broadcast_dimensions[dim]
  if shape[out_dim] != shape[dim]:
    raise ValueError(
        "broadcast_in_dim changes hidden dimension size: {} to {}".format(
            shape[dim], shape[out_dim]))
  sub_bdims = tuple(np.delete(broadcast_dimensions, dim))
  sub_shape = tuple(np.delete(shape, out_dim))
  return lax.broadcast_in_dim(operand, sub_shape, sub_bdims), out_dim


def _pad_papply_rule(name, size, vals, dims, padding_config):
  operand, padding_value = vals
  operand_dim, padding_value_dim = dims
  assert padding_value_dim is None
  padding_config = list(padding_config)
  if padding_config[operand_dim] == (0, 0, 0):
    padded = lax.pad(
        operand,
        padding_value,
        padding_config[:operand_dim] + padding_config[operand_dim + 1:])
    return padded, operand_dim
  else:
    raise NotImplementedError(
        'pad changes size of hidden dimension {} with config {}'.format(
            operand_dim, padding_config))


def _slice_papply_rule(name, size, vals, dims, start_indices, limit_indices,
                       strides, **kwargs):
  operand, = vals
  dim, = dims
  start_indices = list(start_indices)
  limit_indices = list(limit_indices)

  if (start_indices[dim] != 0 or
      limit_indices[dim] != size or
      strides is not None and strides[dim] != 1):
    raise NotImplementedError('slice changes size of hidden dimension')

  out = lax.slice(
      operand,
      start_indices[:dim] + start_indices[dim + 1:],
      limit_indices[:dim] + limit_indices[dim + 1:],
      strides[:dim] + strides[dim + 1:] if strides is not None else None)
  return out, dim


def _gather_papply_rule(
    name, size, vals, dims, dimension_numbers, slice_sizes, operand_shape):
  operand, start_indices = vals
  operand_dim, start_indices_dim = dims
  if (operand_dim is None and start_indices_dim is not None and
      start_indices_dim not in dimension_numbers.offset_dims and
      dimension_numbers.collapsed_slice_dims == (0,)):
    offset_dims = tuple(i - 1 if i > start_indices_dim else i
                        for i in dimension_numbers.offset_dims)
    dnums = lax.GatherDimensionNumbers(
        offset_dims=offset_dims,
        collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
        start_index_map=dimension_numbers.start_index_map)
    out = lax.gather(operand, start_indices, dimension_numbers=dnums,
                     slice_sizes=slice_sizes)
    out_dim = start_indices_dim + np.sum(
        np.less_equal(offset_dims, start_indices_dim))
    return out, out_dim
  else:
    raise NotImplementedError


parallel.papply_primitive_rules[lax.dot_general_p] = _dot_general_papply_rule
parallel.papply_primitive_rules[lax.reshape_p] = _reshape_papply_rule
parallel.papply_primitive_rules[lax.transpose_p] = _transpose_papply_rule
parallel.papply_primitive_rules[lax.select_p] = _select_papply_rule
parallel.papply_primitive_rules[ad_util.add_jaxvals_p] = \
    _add_jaxvals_papply_rule
parallel.papply_primitive_rules[lax.convert_element_type_p] = \
    _convert_element_type_papply_rule
parallel.papply_primitive_rules[lax.conv_general_dilated_p] = \
    _conv_general_dilated_papply_rule
parallel.papply_primitive_rules[lax.broadcast_in_dim_p] = \
    _broadcast_in_dim_papply_rule
parallel.papply_primitive_rules[lax.pad_p] = _pad_papply_rule
parallel.papply_primitive_rules[lax.slice_p] = _slice_papply_rule
parallel.papply_primitive_rules[lax.gather_p] = _gather_papply_rule
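

# Hedged usage sketch (illustrative only, not executed here; it relies on the
# public ``jax.lax`` re-exports of these traceables and assumes at least two
# local devices). A right-rotation across the mapped axis via ppermute:
#
#   import jax
#   import jax.numpy as jnp
#   n = jax.local_device_count()
#   rotate = jax.pmap(
#       lambda v: jax.lax.ppermute(v, 'i', [(j, (j + 1) % n) for j in range(n)]),
#       axis_name='i')
#   print(rotate(jnp.arange(n)))  # with n == 4: [3 0 1 2]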