2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2018-11-17 18:03:33 -08:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2022-10-17 11:15:14 -07:00
|
|
|
from __future__ import annotations
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-17 11:15:14 -07:00
|
|
|
import collections
|
|
|
|
import dataclasses
|
2021-09-13 17:24:44 -04:00
|
|
|
from functools import partial
|
2022-08-08 18:33:26 -07:00
|
|
|
from typing import (Any, Callable, Dict, Hashable, Iterable, Optional, Sequence,
|
|
|
|
Set, Tuple, Type, Union)
|
2021-09-13 17:24:44 -04:00
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
import numpy as np
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-08 13:46:10 -07:00
|
|
|
import jax
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax.config import config
|
2022-12-16 20:59:41 -08:00
|
|
|
from jax._src import core
|
2021-10-28 11:06:58 -07:00
|
|
|
from jax._src import source_info_util
|
2022-12-16 20:59:41 -08:00
|
|
|
from jax._src.core import raise_to_shaped, Trace, Tracer
|
2022-10-17 11:15:14 -07:00
|
|
|
from jax._src.tree_util import (tree_unflatten, tree_flatten,
|
|
|
|
register_pytree_node)
|
2021-08-17 17:06:19 -07:00
|
|
|
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
|
|
|
zeros_like_p, Zero)
|
2022-12-20 14:49:27 -08:00
|
|
|
from jax._src import linear_util as lu
|
2022-11-10 11:59:16 -08:00
|
|
|
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
|
|
|
|
canonicalize_axis, moveaxis, as_hashable_function,
|
|
|
|
curry, memoize, weakref_lru_cache)
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax.interpreters import partial_eval as pe
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-17 11:15:14 -07:00
|
|
|
# Loose alias used in annotations throughout this module.
Array = Any

# Rebind `map`/`zip` to the `safe_*` variants from jax._src.util (presumably
# these verify that the argument lengths agree -- see util). The builtins stay
# reachable as `unsafe_map` / `unsafe_zip`.
unsafe_map = map
map = safe_map
unsafe_zip = zip
zip = safe_zip
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-17 11:15:14 -07:00
|
|
|
|
|
|
|
# Piles
|
|
|
|
|
|
|
|
# i:(Fin 3) => f32[[3, 1, 4].i]
|
|
|
|
@dataclasses.dataclass(frozen=True)
class PileTy:
  """Type of a pile: a family of array types indexed by `binder`.

  `binder` ranges over `length` values and `elt_ty` is the element type; its
  shape may mention `binder` through an `IndexedAxisSize`.
  """
  binder: core.Var
  length: Union[int, Tracer, core.Var]
  elt_ty: core.DShapedArray

  def __repr__(self) -> str:
    return 'Var{}:{} => {}'.format(id(self.binder), self.length, self.elt_ty)

  def replace(self, **changes):
    """Returns a copy with the given fields replaced (dataclasses.replace)."""
    return dataclasses.replace(self, **changes)
|
|
|
|
|
|
|
|
# [3, 1, 4].i
|
|
|
|
@dataclasses.dataclass(frozen=True)
class IndexedAxisSize:
  """An axis size selected by an index variable, e.g. `[3, 1, 4].i`.

  `lengths` holds the candidate sizes and `idx` is the variable choosing one.
  """
  idx: core.Var
  lengths: Union[Array, core.Var, Tracer]

  def __repr__(self) -> str:
    return str(self.lengths) + '.Var' + str(id(self.idx))

  def replace(self, **changes):
    """Returns a copy with the given fields replaced (dataclasses.replace)."""
    return dataclasses.replace(self, **changes)
|
|
|
|
|
|
|
|
# Pile(aval=a:3 => f32[[3 1 4].a],
|
|
|
|
# data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
|
|
|
|
@dataclasses.dataclass(frozen=True, eq=True)
class Pile:
  """Ragged batch: a PileTy plus the data buffer holding all elements.

  In the example above, `data` holds the elements of lengths 3, 1 and 4
  back-to-back in one flat array.
  """
  aval: PileTy  # type of the pile (binder, length, element type)
  data: Array   # concatenated payload
|
|
|
|
|
2022-11-06 22:56:51 -08:00
|
|
|
# To vmap over a pile, one must specify the axis as `pile_axis` (the PileAxis instance).
|
|
|
|
class PileAxis:
  """Marker type of `pile_axis`, the in_axes/out_axes spec for piles."""


pile_axis = PileAxis()
|
|
|
|
|
|
|
|
# As a temporary measure before we have more general JITable / ADable interfaces
|
|
|
|
# (analogous to vmappable), to enable Piles to be used with other
|
|
|
|
# transformations and higher-order primitives (primarily jit, though also grad
|
|
|
|
# with allow_int=True) we register them as pytrees.
|
|
|
|
# TODO(mattjj): add JITable / ADable interfaces, remove this pytree registration
|
2022-10-17 11:15:14 -07:00
|
|
|
def _pile_flatten(pile):
  """Flattens a Pile into pytree (children, aux_data).

  The `lengths` of each `IndexedAxisSize` in the element shape are moved into
  a children list. In the aux data, each one is replaced by its 1-based
  position in that list, which `_pile_unflatten` uses to restore it.

  Returns:
    ((lengths, data), aval): `lengths` is the list of extracted length
    arrays, `data` is the pile's buffer, and `aval` is the pile type with
    positions in place of lengths.
  """
  lengths = []
  new_shape = []
  for d in pile.aval.elt_ty.shape:
    if type(d) is IndexedAxisSize:
      lengths.append(d.lengths)
      # 1-based position of the entry just appended (see _pile_unflatten).
      d = d.replace(lengths=len(lengths))
    new_shape.append(d)
  elt_ty = pile.aval.elt_ty.update(shape=tuple(new_shape))
  aval = pile.aval.replace(elt_ty=elt_ty)
  return (lengths, pile.data), aval
|
|
|
|
def _pile_unflatten(aval, x):
  """Inverse of `_pile_flatten`: rebuilds a Pile from aux data and children.

  Each `IndexedAxisSize.lengths` left as an int by `_pile_flatten` is a 1-based
  position into the flattened lengths list, and is swapped back for the
  corresponding length array.
  """
  seg_lengths, payload = x

  def restore(dim):
    if type(dim) is not IndexedAxisSize:
      return dim
    return dim.replace(lengths=seg_lengths[dim.lengths - 1])

  restored_shape = tuple(restore(dim) for dim in aval.elt_ty.shape)
  new_elt_ty = aval.elt_ty.update(shape=restored_shape)
  return Pile(aval.replace(elt_ty=new_elt_ty), payload)
|
|
|
|
# Register Pile as a pytree node: its children are (lengths, data) and its aux
# data is the pile type with IndexedAxisSize lengths replaced by positions.
register_pytree_node(Pile, _pile_flatten, _pile_unflatten)
|
|
|
|
|
|
|
|
def _pile_result(axis_size, axis, segment_lens, x):
  """Wraps a batched value `x` as a Pile.

  Axis `axis` of `x` gets the size `[segment_lens].i` for a fresh int32 binder
  `i`, and `axis_size` becomes the pile's length. Presumably `x` is
  concatenated along `axis` with per-element lengths `segment_lens` (as for a
  ConcatAxis batch dim) -- confirm at call sites.
  """
  idx_var = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
  dims = list(x.shape)
  dims[axis] = IndexedAxisSize(idx_var, segment_lens)
  pile_ty = PileTy(idx_var, axis_size,
                   core.DShapedArray(tuple(dims), x.dtype, x.weak_type))
  return Pile(pile_ty, x)
|
|
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ConcatAxis:
  """Batch dimension for ragged data concatenated along one axis.

  Attributes:
    axis: axis of the underlying value along which elements are concatenated.
    segment_lengths: per-element lengths along `axis` (typed as Array; some
      code paths store an index object with a `.val` here instead).
  """
  axis: int
  segment_lengths: Array
|
|
|
|
|
|
|
|
|
2022-03-30 17:52:55 -07:00
|
|
|
def _update_annotation(
    f: lu.WrappedFun, orig_type: Optional[core.InputType],
    axis_size: core.AxisSize, axis_name: core.AxisName,
    explicit_in_dims: Sequence[Optional[Union[int, ConcatAxis]]],
    segment_lens: Sequence[Array],
  ) -> lu.WrappedFun:
  """Rebuilds `f`'s input-type annotation for a batched call.

  Args:
    f: the function whose annotation is replaced.
    orig_type: the unbatched input type, as (aval, explicit) pairs, or None
      if there is no annotation (then `f` is returned unchanged).
    axis_size: the vmapped axis size; a Tracer when dynamic.
    axis_name: the vmapped axis name, passed to `core.unmapped_aval`.
    explicit_in_dims: batch dim of each explicit argument: None, an int, or a
      ConcatAxis whose `segment_lengths.val` indexes into `segment_lens`.
    segment_lens: segment-length arrays; they become new leading explicit
      arguments.

  Returns:
    `f` annotated with (*implicit binders, *segment-length avals,
    *batched explicit avals).
  """
  if orig_type is None: return f
  # By convention, `explicit_in_dims` only accounts for explicit arguments.
  assert len(explicit_in_dims) == sum(explicit for _, explicit in orig_type)
  # We need to:
  # * if `axis_size` is dynamic, add a new implicit binder (type) for it;
  # * for each element of `segment_lengths`, add a new explicit binder for it;
  # * drop other implicit binders, replacing DBIdx which refer to them with
  #   Name objects;
  # * for each (aval, in_dim) pair: if int-valued in_dim, add batch axis (int
  #   size if `axis_size` is int, otherwise Name); if ConcatAxis-valued in_dim,
  #   add batch axis (int if corresponding segment_lengths is concrete, Name if
  #   not);
  # * generate full in_type with implicit args too.

  # Placeholder for a size binder. No __eq__/__hash__ is defined, so Names
  # compare and hash by identity.
  class Name:
    def __init__(self, a): self.a = a
  # One Name per original binder, indexed like `orig_type`, so a DBIdx(i)
  # size can be replaced by names[i].
  names = [Name(a) for a, _ in orig_type]
  # Explicit-argument avals with DBIdx sizes replaced by their Name.
  avals = [a.update(shape=tuple(names[d.val] if type(d) is pe.DBIdx else d  # type: ignore
                                for d in a.shape))
           if type(a) is core.DShapedArray else a for a, e in orig_type if e]

  # The segment-length arrays come first, as new explicit arguments.
  new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens]
  # A dynamic (Tracer) axis size is represented by a Name of its aval.
  sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
  for a, d in zip(avals, explicit_in_dims):
    if isinstance(d, ConcatAxis):
      # Here `d.segment_lengths` holds an index (`.val`) into `segment_lens`.
      s = segment_lens[d.segment_lengths.val]
      if isinstance(core.get_aval(s), core.ConcreteArray):
        shape = list(a.shape)  # type: ignore
        shape[d.axis] = int(s.sum())  # specialize on shape if we can
        new_avals.append(a.update(shape=tuple(shape)))
      else:
        new_avals.append(a)
    else:
      # int (or None) in_dim: core.unmapped_aval adds the batch axis of size
      # `sz` (presumably a no-op for None -- confirm in core).
      new_avals.append(core.unmapped_aval(sz, axis_name, d, a))  # type: ignore

  # Every Name still referenced from some shape.
  mentioned = {d for a in new_avals if type(a) is core.DShapedArray
               for d in a.shape if type(d) is Name}
  # NOTE(review): these are fresh Name objects, and Name hashes by identity,
  # so none of them is in `mentioned`. `impl_names` is therefore all of
  # `mentioned`, and `expl_names` only reserves DBIdx numbers after the
  # implicit ones. Confirm this is intended.
  expl_names = set(map(Name, new_avals))
  impl_names = mentioned - expl_names  # type: ignore
  impl_part = [(n.a, False) for n in impl_names]  # type: ignore
  # Implicit binders are numbered first, matching the order of the in_type
  # returned below.
  name_map = {n: pe.DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))}
  expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape))
                if type(a) is core.DShapedArray else a, True) for a in new_avals]
  return lu.annotate(f, (*impl_part, *expl_part))
|
2022-03-30 17:52:55 -07:00
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
### vmappable typeclass
|
|
|
|
|
|
|
|
# Loosely-typed aliases for the vmappable protocol: batched user data, its
# per-element view, the in/out axis spec, and the axis size.
Vmappable = Elt = MapSpec = AxisSize = Any
# Zero-argument callable yielding the batch-index tracer on demand.
GetIdx = Callable[[], Tracer]  # TODO(mattjj): revise this laziness
# Signatures of the handlers accepted by `register_vmappable`.
ToEltHandler = Callable[[Callable, GetIdx, Vmappable, MapSpec], Elt]
FromEltHandler = Callable[[Callable, AxisSize, Elt, MapSpec], Vmappable]
MakeIotaHandler = Callable[[AxisSize], Array]
|
|
|
|
|
|
|
|
def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
  """Converts a vmap input `x` into the per-element value seen inside vmap.

  - A type registered via `register_vmappable` goes to its handler.
  - A Pile must use the `pile_axis` spec and becomes a BatchTracer with a
    ConcatAxis batch dimension.
  - Otherwise `spec` must be an int (batch along that axis) or None (`x` is
    returned unbatched).

  Raises:
    TypeError: if `x` is a Pile but `spec` is not `pile_axis`, or if `spec`
      is neither an int nor None for any other `x`.
  """
  handler = to_elt_handlers.get(type(x))
  if handler:
    return handler(partial(to_elt, trace, get_idx), get_idx, x, spec)
  elif type(x) is Pile:
    if spec is not pile_axis:
      raise TypeError("pile input without using pile_axis in_axes spec")
    # Exactly one dimension of a pile's element type is ragged.
    (d, ias), = ((i, sz) for i, sz in enumerate(x.aval.elt_ty.shape)
                 if type(sz) is IndexedAxisSize)
    return BatchTracer(trace, x.data, ConcatAxis(d, ias.lengths),  # type: ignore
                       source_info_util.current())
  elif isinstance(spec, int) or spec is None:
    spec = spec and canonicalize_axis(spec, len(np.shape(x)))
    return (BatchTracer(trace, x, spec, source_info_util.current())
            if spec is not None else x)
  else:
    # Raise explicitly (rather than `assert False`) so that a bad spec is
    # still reported when Python runs with assertions disabled (-O).
    raise TypeError(f"unexpected vmap in_axes spec {spec!r} for input of "
                    f"type {type(x)}")

to_elt_handlers: Dict[Type, ToEltHandler] = {}
|
|
|
|
|
|
|
|
def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
             ) -> Vmappable:
  """Converts a per-element vmap output `x` back into a batched value.

  - A registered vmappable type goes to its handler.
  - Otherwise `x` is raised into `trace`. A ConcatAxis batch dimension yields
    a Pile, which requires the `pile_axis` out spec.
  - Any other batch dimension is moved to `spec` by `matchaxis`.
  """
  custom = from_elt_handlers.get(type(x))
  if custom:
    return custom(partial(from_elt, trace), axis_size, x, spec)
  tracer = trace.full_raise(x)
  if type(tracer.batch_dim) is not ConcatAxis:
    return matchaxis(trace.axis_name, axis_size, tracer.batch_dim, spec,
                     tracer.val)
  if spec is not pile_axis:
    # TODO(mattjj): improve this error message
    raise TypeError("ragged output without using pile_axis out_axes spec")
  concat = tracer.batch_dim
  return _pile_result(axis_size, concat.axis, concat.segment_lengths,
                      tracer.val)

from_elt_handlers: Dict[Type, FromEltHandler] = {}
|
|
|
|
|
|
|
|
def make_iota(axis_size: AxisSize) -> Array:
  """Returns an index array for `axis_size`.

  Uses the handler registered for `type(axis_size)` if there is one;
  otherwise returns an int32 iota of length `int(axis_size)`.
  """
  custom_iota = make_iota_handlers.get(type(axis_size))
  if not custom_iota:
    return jax.lax.iota('int32', int(axis_size))
  return custom_iota(axis_size)

make_iota_handlers: Dict[Type, MakeIotaHandler] = {}
|
|
|
|
|
|
|
|
def register_vmappable(data_type: Type, spec_type: Type, axis_size_type: Type,
                       to_elt: Callable, from_elt: Callable,
                       make_iota: Optional[Callable]):
  """Registers a custom vmappable data type.

  Args:
    data_type: type of the values to treat as vmappable.
    spec_type: type of the in/out axis specs used with `data_type`.
    axis_size_type: type of axis sizes; keys the `make_iota` handler.
    to_elt: handler stored in `to_elt_handlers[data_type]`.
    from_elt: handler stored in `from_elt_handlers[data_type]`.
    make_iota: optional handler stored in `make_iota_handlers[axis_size_type]`.
  """
  vmappables[data_type] = spec_type, axis_size_type
  spec_types.add(spec_type)
  to_elt_handlers[data_type], from_elt_handlers[data_type] = to_elt, from_elt
  if make_iota:
    make_iota_handlers[axis_size_type] = make_iota

# Registered vmappable type -> (spec type, axis size type).
vmappables: Dict[Type, Tuple[Type, Type]] = {}
# All accepted custom spec types; PileAxis is built in.
spec_types: Set[Type] = {PileAxis}
|
2021-10-06 14:18:07 -07:00
|
|
|
|
|
|
|
def unregister_vmappable(data_type: Type) -> None:
  """Undoes `register_vmappable(data_type, ...)`.

  The spec type and the iota handler for the axis-size type are dropped only
  when no still-registered vmappable shares them. This way, unregistering one
  type cannot break another type registered with the same spec or axis-size
  type.

  Raises:
    KeyError: if `data_type` is not registered.
  """
  spec_type, axis_size_type = vmappables.pop(data_type)
  del to_elt_handlers[data_type]
  del from_elt_handlers[data_type]
  remaining = list(vmappables.values())
  # PileAxis is a built-in member of `spec_types` and must never be dropped.
  if (spec_type is not PileAxis
      and all(s is not spec_type for s, _ in remaining)):
    spec_types.discard(spec_type)
  if all(a is not axis_size_type for _, a in remaining):
    make_iota_handlers.pop(axis_size_type, None)
|
|
|
|
|
|
|
|
def is_vmappable(x: Any) -> bool:
  """Returns whether `x` is a Pile or an instance of a registered vmappable."""
  x_type = type(x)
  return x_type in vmappables or x_type is Pile
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def flatten_fun_for_vmap(in_tree, *args_flat):
  """Flat-in / flat-out wrapper used by vmap.

  Rebuilds the Python (args, kwargs) from the flat inputs using `in_tree`,
  calls the wrapped function, and yields the pytree flattening of its result.
  Vmappable values (see `is_vmappable`) are treated as leaves.
  """
  args, kwargs = tree_unflatten(in_tree, args_flat)
  result = yield args, kwargs
  yield tree_flatten(result, is_leaf=is_vmappable)
|
2020-01-15 15:00:38 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
### tracer
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
# Batch dimension of a value that is not batched, and its type.
# TODO(mattjj): use a special sentinel type rather than None
not_mapped = None
NotMapped = type(not_mapped)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-10-17 11:15:14 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class BatchTracer(Tracer):
  """Tracer for a value with an optional batch dimension.

  `val` is the underlying value. `batch_dim` locates the batch axis:
  `not_mapped` (None), an int axis of `val`, or a ConcatAxis for ragged data.
  `source_info` optionally records where the tracer was created.
  """
  __slots__ = ['val', 'batch_dim', 'source_info']

  def __init__(self, trace, val, batch_dim: Union[NotMapped, int, ConcatAxis],
               source_info: Optional[source_info_util.SourceInfo] = None):
    if config.jax_enable_checks:
      bdim_type = type(batch_dim)
      assert bdim_type in (NotMapped, int, ConcatAxis)
      if bdim_type is int:
        ndim = len(raise_to_shaped(core.get_aval(val)).shape)  # type: ignore
        assert 0 <= batch_dim < ndim  # type: ignore
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim
    self.source_info = source_info

  @property
  def aval(self):
    base = raise_to_shaped(core.get_aval(self.val))
    bdim = self.batch_dim
    if bdim is not_mapped:
      return base
    if type(bdim) is int:
      return core.mapped_aval(base.shape[bdim], bdim, base)
    if type(bdim) is ConcatAxis:
      # The ragged axis size is a tracer over the per-element segment lengths,
      # batched along their axis 0.
      dims = list(base.shape)
      dims[bdim.axis] = BatchTracer(self._trace, bdim.segment_lengths, 0)
      return core.DShapedArray(shape=tuple(dims), dtype=base.dtype,
                               weak_type=base.weak_type)

  def full_lower(self):
    if self.batch_dim is not not_mapped:
      return self
    return core.full_lower(self.val)

  def _origin_msg(self):
    if self.source_info is not None:
      return (f"\nThis BatchTracer with object id {id(self)} was created on line:"
              f"\n {source_info_util.summarize(self.source_info)}")
    return ""

  def _contents(self):
    return [(field, getattr(self, field)) for field in ('val', 'batch_dim')]

  def get_referent(self):
    bdim = self.batch_dim
    if bdim is not None and type(bdim) is not int:
      return self  # TODO(mattjj): could handle the ConcatAxis case?
    return core.get_referent(self.val)
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
class BatchTrace(Trace):
|
2022-08-08 18:33:26 -07:00
|
|
|
|
|
|
|
def __init__(self, *args, axis_name, spmd_axis_name = None):
  """Creates a batch trace.

  Args:
    *args: forwarded to `Trace.__init__`.
    axis_name: name of the vmapped axis (may be `core.no_axis_name`).
    spmd_axis_name: optional; when set, `get_primitive_batcher` also consults
      `spmd_axis_primitive_batchers`.
  """
  super().__init__(*args)
  self.spmd_axis_name = spmd_axis_name
  self.axis_name = axis_name
|
2020-10-26 10:11:13 +00:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def pure(self, val):
  """Wraps `val` as an unbatched BatchTracer of this trace."""
  src = source_info_util.current()
  return BatchTracer(self, val, not_mapped, src)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def lift(self, val):
  """Lifts `val` into this trace as an unbatched BatchTracer."""
  src = source_info_util.current()
  return BatchTracer(self, val, not_mapped, src)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def sublift(self, val):
  """Rewraps BatchTracer `val` in this trace, keeping its value and batch dim."""
  inner_val, inner_bdim = val.val, val.batch_dim
  return BatchTracer(self, inner_val, inner_bdim, source_info_util.current())
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2021-08-26 13:34:01 -07:00
|
|
|
def get_primitive_batcher(self, primitive, frame):
  """Returns the batching rule for `primitive`.

  Lookup order:
  1. `primitive_batchers`.
  2. `spmd_axis_primitive_batchers`, if this trace has an `spmd_axis_name`.
  3. `axis_primitive_batchers`.
  Rules from steps 2 and 3 are partially applied with axis info from `frame`.

  Raises:
    NotImplementedError: if no rule is registered for `primitive`.
  """
  if primitive in primitive_batchers:
    return primitive_batchers[primitive]
  if (self.spmd_axis_name is not None
      and primitive in spmd_axis_primitive_batchers):
    return partial(spmd_axis_primitive_batchers[primitive],
                   self.spmd_axis_name, frame.size, frame.name,
                   frame.main_trace.trace_type)
  if primitive in axis_primitive_batchers:
    return self.get_axis_primitive_batcher(primitive, frame)
  raise NotImplementedError(
      "Batching rule for '{}' not implemented".format(primitive))
|
|
|
|
|
|
|
|
def get_axis_primitive_batcher(self, primitive, frame):
  """Returns the axis-aware rule for `primitive`, bound to `frame`'s size,
  name and main-trace type."""
  rule = axis_primitive_batchers[primitive]
  return partial(rule, frame.size, frame.name, frame.main_trace.trace_type)
|
|
|
|
|
|
|
|
def get_frame(self, vals, dims) -> core.AxisEnvFrame:
  """Returns the axis-environment frame for this trace's axis.

  Named axes are looked up with `core.axis_frame`. For `core.no_axis_name`,
  the frame is rebuilt from the sizes of the batched inputs, which must all
  agree.
  """
  if self.axis_name is not core.no_axis_name:
    return core.axis_frame(self.axis_name)
  # `no_axis_name` can't be found via `core.axis_frame`, so reconstruct the
  # frame from the information we have available.
  sizes = (len(d.segment_lengths) if type(d) is not int else x.shape[d]
           for x, d in zip(vals, dims) if d is not not_mapped)
  axis_size, = core.dedup_referents(sizes)
  return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
|
2021-06-03 04:13:02 -07:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def process_primitive(self, primitive, tracers, params):
  """Applies `primitive` to BatchTracers using its batching rule.

  - If the primitive is an axis primitive whose used axis names resolve to
    this trace's main (per `_main_trace_for_axis_names`), its axis-aware rule
    is used.
  - Otherwise, if no input is batched, the primitive is bound directly on
    the unwrapped values.
  - Otherwise the regular batching rule is used.
  Rule outputs are wrapped as BatchTracers of this trace.
  """
  vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
  is_axis_primitive = primitive in axis_primitive_batchers
  used_names = core.used_axis_names(primitive, params)
  if is_axis_primitive and _main_trace_for_axis_names(self.main, used_names):
    frame = self.get_frame(vals_in, dims_in)
    rule = self.get_axis_primitive_batcher(primitive, frame)
  elif all(d is not_mapped for d in dims_in):
    return primitive.bind(*vals_in, **params)
  else:
    frame = self.get_frame(vals_in, dims_in)
    rule = self.get_primitive_batcher(primitive, frame)
  val_out, dim_out = rule(vals_in, dims_in, **params)
  src = source_info_util.current()
  if not primitive.multiple_results:
    return BatchTracer(self, val_out, dim_out, src)
  return [BatchTracer(self, x, d, src) for x, d in zip(val_out, dim_out)]
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2022-03-30 17:52:55 -07:00
|
|
|
def process_call(self, call_primitive, f, tracers, params):
  """Batches a call primitive by batching the called function `f`.

  If no argument is batched, the call is bound unchanged. Otherwise:
  - `f` is wrapped with `batch_subtrace`;
  - its input annotation is rebuilt by `_update_annotation`;
  - the segment lengths of ConcatAxis arguments are passed as extra leading
    arguments.
  """
  assert call_primitive.multiple_results
  params = dict(params, name=params.get('name', f.__name__))
  vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
  if all(d is not_mapped for d in dims):
    return call_primitive.bind(f, *vals, **params)
  sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
           for x, d in zip(vals, dims) if d is not not_mapped)
  axis_size, = core.dedup_referents(sizes)
  segment_lens, dims = unpack_concat_axes(dims)
  batched_f, out_dims_thunk = batch_subtrace(f, self.main, tuple(dims))
  batched_f = _update_annotation(batched_f, f.in_type, axis_size,
                                 self.axis_name, dims, segment_lens)
  outs = call_primitive.bind(batched_f, *segment_lens, *vals, **params)
  outs, out_dims = reassemble_concat_axes(outs, out_dims_thunk())
  src = source_info_util.current()
  return [BatchTracer(self, v, d, src) for v, d in zip(outs, out_dims)]
|
2020-04-21 18:12:02 -07:00
|
|
|
|
|
|
|
def post_process_call(self, call_primitive, out_tracers, params):
  """Unwraps `out_tracers`; returns (values, todo).

  `todo` rewraps values as BatchTracers of this trace's main at its current
  sublevel, reusing the original batch dims and source info.
  """
  vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                            for t in out_tracers)
  main = self.main

  def todo(new_vals):
    rewrap = partial(BatchTracer, main.with_cur_sublevel())
    return map(rewrap, new_vals, dims, srcs)

  return vals, todo
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-09 20:41:01 +01:00
|
|
|
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
|
2019-06-04 18:33:52 -07:00
|
|
|
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
2019-07-27 15:46:14 -07:00
|
|
|
if all(dim is not_mapped for dim in dims):
|
2019-06-04 18:33:52 -07:00
|
|
|
return map_primitive.bind(f, *vals, **params)
|
|
|
|
else:
|
2020-11-05 11:54:05 +00:00
|
|
|
assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
|
|
|
|
# The logic for the dimension math below is as follows:
|
|
|
|
# ╔═════════════╦════════════════════════════════════════╦═══════════╗
|
|
|
|
# ║ d / in_axis ║ None ║ int ║
|
|
|
|
# ╠═════════════╬════════════════════════════════════════╩═══════════╣
|
|
|
|
# ║ None ║ No extra axis, so in_axis unaffected ║
|
|
|
|
# ╠═════════════╬════════════════════════════════════════╦═══════════╣
|
|
|
|
# ║ int ║ Not mapped, so batching dim unaffected ║ See below ║
|
|
|
|
# ╚═════════════╩════════════════════════════════════════╩═══════════╝
|
|
|
|
# When both d and in_axis are defined then:
|
|
|
|
# - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
|
|
|
|
# - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed).
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse, a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes, output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
def both_mapped(in_out_axis, d):
|
|
|
|
return in_out_axis is not None and d is not not_mapped
|
|
|
|
new_in_axes = tuple(
|
|
|
|
in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
|
|
|
|
for d, in_axis in zip(dims, params['in_axes']))
|
|
|
|
new_dims = tuple(
|
|
|
|
d - 1 if both_mapped(in_axis, d) and in_axis < d else d
|
|
|
|
for d, in_axis in zip(dims, params['in_axes']))
|
2020-11-05 11:54:05 +00:00
|
|
|
f, dims_out = batch_subtrace(f, self.main, new_dims)
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse, a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
  outputs = yield args, kwargs
  yield outputs, compute_statistic(outputs)

wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
  old_out_axes = params['out_axes_thunk']
  return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assume there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure what
it is. I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
out_axes_thunk = params['out_axes_thunk']
|
2020-12-02 14:13:05 +00:00
|
|
|
# NOTE: This assumes that the choice of the dimensions over which outputs
|
|
|
|
# are batched is entirely dependent on the function and not e.g. on the
|
|
|
|
# data or its shapes.
|
|
|
|
@as_hashable_function(closure=out_axes_thunk)
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse, many of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
  outputs = yield args, kwargs
  yield outputs, compute_statistic(outputs)

wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
  old_out_axes = params['out_axes_thunk']
  return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assume there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure what
it is. I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
def new_out_axes_thunk():
|
|
|
|
return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
|
|
|
|
for out_axis, d in zip(out_axes_thunk(), dims_out()))
|
|
|
|
new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
|
|
|
|
vals_out = map_primitive.bind(f, *vals, **new_params)
|
2022-10-20 22:23:29 -07:00
|
|
|
dims_out_ = [d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
|
|
|
|
for d, out_axis in zip(dims_out(), out_axes_thunk())]
|
2022-01-19 11:43:02 -08:00
|
|
|
src = source_info_util.current()
|
2022-10-20 22:23:29 -07:00
|
|
|
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)]
|
2019-07-27 15:46:14 -07:00
|
|
|
|
2020-04-21 18:12:02 -07:00
|
|
|
def post_process_map(self, call_primitive, out_tracers, params):
  """Prepare batched outputs of a map primitive to leave this trace.

  Returns the raw values underlying ``out_tracers`` plus a ``todo`` callback
  that re-wraps values as ``BatchTracer``s on this trace's main at the current
  sublevel. For map primitives (``call_primitive.map_primitive``) the callback
  is returned as a pair ``(todo, out_axes_transform)``, where
  ``out_axes_transform`` rewrites the primitive's flat out axes to account for
  the batch dimensions.
  """
  vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                            for t in out_tracers)
  main = self.main

  def both_mapped(in_out_axis, d):
    # Mapped by the map primitive (axis is not None) and by this vmap.
    return in_out_axis is not None and d is not not_mapped

  def todo(vals):
    trace = main.with_cur_sublevel()
    map_out_axes = params['out_axes_thunk']()
    wrapped = []
    for val, dim, out_axis, src in zip(vals, dims, map_out_axes, srcs):
      # An out axis placed at or before the batch dim pushes it right by one.
      if both_mapped(out_axis, dim) and out_axis <= dim:
        dim = dim + 1
      wrapped.append(BatchTracer(trace, val, dim, src))
    return wrapped

  if not call_primitive.map_primitive:
    return vals, todo

  def out_axes_transform(out_axes):
    # Counterpart of the shift in `todo`: a batch dim strictly before an out
    # axis pushes that out axis right by one.
    return tuple(out_axis + 1 if both_mapped(out_axis, dim) and dim < out_axis
                 else out_axis
                 for out_axis, dim in zip(out_axes, dims))

  return vals, (todo, out_axes_transform)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
  """Batch a custom_jvp call by batching both `fun` and its `jvp` rule.

  Both callables are wrapped with batch subtraces over this trace's main, the
  primitive is bound on the raw input values, and the outputs are returned as
  ``BatchTracer``s using whichever subtrace's output dims were recorded.
  """
  in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
  main = self.main
  fun, out_dims1 = batch_subtrace(fun, main, in_dims)
  jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, main, in_dims)
  out_vals = prim.bind(fun, jvp, *in_vals)
  fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
  if not fst:
    # The dims came from the jvp subtrace: its two halves must agree, and only
    # the first half is kept to match `out_vals`.
    half = len(out_dims) // 2
    assert out_dims == out_dims[:half] * 2
    out_dims = out_dims[:half]
  src = source_info_util.current()
  return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]
|
2020-03-28 14:15:46 -07:00
|
|
|
|
2021-12-11 14:07:30 -08:00
|
|
|
def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
  """Prepare batched custom_jvp outputs to leave this trace.

  Returns the raw output values and a ``todo`` callback that re-wraps values
  as ``BatchTracer``s at the current sublevel. When ``jvp_was_run``, the first
  ``len(vals)`` recorded dims/source infos are treated as the primal ones and
  the remaining dims must equal them.
  """
  vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                            for t in out_tracers)
  main = self.main

  def todo(primals):
    trace = main.with_cur_sublevel()
    if not jvp_was_run:
      return map(partial(BatchTracer, trace), primals, dims, srcs)
    n = len(primals)
    primal_dims, tangent_dims = dims[:n], dims[n:]
    assert primal_dims == tangent_dims
    return map(partial(BatchTracer, trace), primals, primal_dims, srcs[:n])

  return vals, todo
|
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
  """Batch a custom_vjp call by batching `fun`, `fwd` and `bwd`.

  `fun` and `fwd` are wrapped with `batch_subtrace`; `bwd` is wrapped with
  `batch_custom_vjp_bwd`, which receives the aux output dims of the wrapped
  `fwd`. Unpacking the set of mapped input sizes raises ValueError unless the
  mapped inputs share exactly one axis size.
  """
  in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
  mapped_sizes = {x.shape[d] for x, d in zip(in_vals, in_dims)
                  if d is not not_mapped}
  axis_size, = mapped_sizes
  fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
  fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
  bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size, out_dims2,
                             in_dims, self.main.trace_type,
                             self.spmd_axis_name)
  out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
  fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
  if not fst:
    # The dims came from `fwd`, whose outputs lead with residuals; drop those.
    _, res_tree = out_trees()
    _, out_dims = split_list(out_dims, [res_tree.num_leaves])
  src = source_info_util.current()
  return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]
|
2020-03-28 14:15:46 -07:00
|
|
|
|
2021-12-11 14:07:30 -08:00
|
|
|
def post_process_custom_vjp_call(self, out_tracers, _):
  """Prepare batched custom_vjp outputs to leave this trace.

  Returns the raw output values and a ``todo`` callback that re-wraps values
  as ``BatchTracer``s at the current sublevel, reusing the recorded batch dims
  and source infos.
  """
  main = self.main
  vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                            for t in out_tracers)

  def todo(new_vals):
    return map(partial(BatchTracer, main.with_cur_sublevel()),
               new_vals, dims, srcs)

  return vals, todo
|
|
|
|
|
|
|
|
def post_process_custom_vjp_call_fwd(self, out_tracers, out_trees):
  """Prepare batched outputs of a custom_vjp forward pass to leave this trace.

  The first ``num_res`` outputs (``num_res`` = leaves of the residual tree,
  the second element returned by ``out_trees()``) are residuals and the rest
  are primal outputs. Returns ``(vals, todo, bwd_transform)``: ``todo``
  re-wraps primal outputs as ``BatchTracer``s, and ``bwd_transform`` wraps a
  backward rule with ``batch_custom_vjp_bwd`` using all output dims.
  """
  vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                            for t in out_tracers)
  axis_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
  main = self.main
  trace_type = main.trace_type
  axis_name = self.axis_name
  _, res_tree = out_trees()
  num_res = res_tree.num_leaves
  _, primal_dims = split_list(dims, [num_res])
  _, primal_srcs = split_list(srcs, [num_res])

  def todo(primals):
    trace = main.with_cur_sublevel()
    return map(partial(BatchTracer, trace), primals, primal_dims, primal_srcs)

  def bwd_transform(bwd):
    # `self.spmd_axis_name` is read when the transform is applied.
    return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims, (None,),
                                trace_type, self.spmd_axis_name)

  return vals, todo, bwd_transform
|
2021-01-19 19:08:23 -08:00
|
|
|
|
2020-11-24 09:58:44 -08:00
|
|
|
def _main_trace_for_axis_names(main_trace: core.MainTrace,
                               axis_name: Iterable[core.AxisName],
                               ) -> bool:
  """Return True if `main_trace` is the main trace bound to any of `axis_name`.

  Axis names can shadow one another, so a name alone does not identify an
  axis; the main trace recorded in the axis frame is used as the tag instead.
  """
  for name in axis_name:
    if core.axis_frame(name).main_trace is main_trace:
      return True
  return False
|
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
### API for batching callables with vmappable inputs and outputs
|
2021-01-21 21:29:09 -08:00
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
def batch(fun: lu.WrappedFun, axis_name: core.AxisName, axis_size,
          in_dims, out_dim_dests, main_type: Type[BatchTrace] = BatchTrace,
          spmd_axis_name: Optional[Hashable] = None) -> lu.WrappedFun:
  """Wrap `fun` so that it runs batched over `axis_name` of size `axis_size`.

  `in_dims` and `out_dim_dests` may be sequences or thunks returning them.
  The result composes `_batch_inner` (tracer wrapping/unwrapping) inside
  `_batch_outer` (main-trace and axis-env setup).
  """
  # The two halves are kept as separate transformations for the leak checker.
  inner = _batch_inner(fun, axis_size, out_dim_dests)
  return _batch_outer(inner, axis_name, axis_size, in_dims, main_type,
                      spmd_axis_name)
|
2021-04-20 08:32:41 -07:00
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
@lu.transformation
def _batch_outer(axis_name, axis_size, in_dims, main_type, spmd_axis_name,
                 *in_vals):
  """Outer half of `batch`: set up a fresh main trace and axis environment.

  Opens a new `main_type` main trace for `axis_name`, binds `axis_name` to
  `axis_size` in the axis env, and runs the wrapped function (`_batch_inner`,
  as composed in `batch`) under a 'vmap' transform name stack, passing it the
  main trace and `in_dims` ahead of the input values.
  """
  with core.new_main(main_type, axis_name=axis_name,
                     spmd_axis_name=spmd_axis_name) as main_trace:
    with core.extend_axis_env(axis_name, axis_size, main_trace):
      with source_info_util.transform_name_stack('vmap'):
        results = yield (main_trace, in_dims, *in_vals), {}
    # Drop the local reference before `new_main` exits (presumably so its leak
    # checker sees no lingering reference).
    del main_trace
  yield results
|
2021-04-20 08:32:41 -07:00
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
@lu.transformation
def _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals):
  """Inner half of `batch`: wrap inputs as batch elements and unwrap outputs.

  `in_dims` and `out_dim_dests` may be thunks; they are forced before use.
  The index tracer (an iota of length `axis_size` batched along axis 0) is
  built lazily and at most once via `memoize`.
  """
  if callable(in_dims):
    in_dims = in_dims()
  trace = main.with_cur_sublevel()

  def make_index_tracer():
    return BatchTracer(trace, make_iota(axis_size), 0,
                       source_info_util.current())

  idx = memoize(make_index_tracer)
  in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
  outs = yield in_tracers, {}
  if callable(out_dim_dests):
    out_dim_dests = out_dim_dests()
  yield map(partial(from_elt, trace, axis_size), outs, out_dim_dests)
|
2021-04-20 08:32:41 -07:00
|
|
|
|
|
|
|
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat: lu.WrappedFun,
          in_axes_flat: Tuple[Optional[int], ...],
          out_axes_flat: Tuple[Optional[int], ...],
          tile_size: Optional[int],
          axis_name: core.AxisName,
          main_type: Type[BatchTrace] = BatchTrace):
  """Batch `f_flat` over tiles of its mapped input axes.

  Each mapped input axis of length n is reshaped to (tile, n // tile) and
  the result of `batch` is run on those inputs; each mapped output axis and
  the one after it are merged back into a single axis. When `tile_size` is
  falsy, the tile used for reshaping is the length of the first mapped input
  axis.
  NOTE(review): `batch` receives the original `tile_size` as its axis size,
  even when it is None -- confirm this is intended.
  """
  @curry
  def tile_axis(arg, axis: Optional[int], tile_size):
    if axis is None:
      return arg
    new_shape = list(arg.shape)
    # Split dimension `axis` into (tile_size, remainder).
    new_shape[axis:axis + 1] = [tile_size, new_shape[axis] // tile_size]
    return arg.reshape(new_shape)

  def untile_axis(out, axis: Optional[int]):
    if axis is None:
      return out
    new_shape = list(out.shape)
    # Merge dimensions `axis` and `axis + 1` into one.
    new_shape[axis:axis + 2] = [new_shape[axis] * new_shape[axis + 1]]
    return out.reshape(new_shape)

  @lu.transformation
  def _map_to_tile(*args_flat):
    sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat)
             if i is not None)
    tile_size_ = tile_size or next(sizes, None)
    assert tile_size_ is not None, "No mapped arguments?"
    tiled_args = map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat)
    outputs_flat = yield tiled_args, {}
    yield map(untile_axis, outputs_flat, out_axes_flat)

  batched = batch(f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat,
                  main_type=main_type)
  return _map_to_tile(batched)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
### API for batching functions with jaxpr type inputs and outputs
|
|
|
|
|
|
|
|
@lu.transformation_with_aux
def batch_subtrace(main, in_dims, *in_vals):
  """Run the wrapped function on BatchTracers at the current sublevel of `main`.

  `in_dims` may be a thunk. Concat-axis segment lengths carried in the inputs
  are reassembled first; inputs whose dim is None are passed through
  unwrapped. Outputs are prefixed with any concat-axis segment lengths, and
  the output batch dims are returned as the aux value.
  """
  trace = main.with_cur_sublevel()
  if callable(in_dims):
    in_dims = in_dims()
  in_vals, in_dims = reassemble_concat_axes(in_vals, in_dims)
  in_tracers = []
  for x, dim in zip(in_vals, in_dims):
    if dim is None:
      in_tracers.append(x)
    else:
      in_tracers.append(BatchTracer(trace, x, dim, source_info_util.current()))
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  segment_lens, out_dims = unpack_concat_axes(out_dims)
  yield (*segment_lens, *out_vals), out_dims
|
|
|
|
|
|
|
|
def unpack_concat_axes(dims):
  """Replace concat-axis segment lengths in `dims` with de Bruijn indices.

  Returns ``(segment_lens, new_dims)``. If no dim is exactly a `ConcatAxis`,
  returns ``([], dims)`` unchanged. Otherwise each distinct segment-lengths
  object (keyed by the id of its referent) gets a `pe.DBIdx` equal to its
  first-seen position, and `segment_lens` lists those objects in that order.
  """
  if not any(type(d) is ConcatAxis for d in dims):
    return [], dims
  # id(referent) -> (segment_lengths, DBIdx); insertion order fixes the index.
  seen = collections.OrderedDict()

  def convert(d: ConcatAxis) -> ConcatAxis:
    key = id(core.get_referent(d.segment_lengths))
    entry = (d.segment_lengths, pe.DBIdx(len(seen)))
    _, dbidx = seen.setdefault(key, entry)
    return ConcatAxis(d.axis, dbidx)

  new_dims = []
  for d in dims:
    new_dims.append(convert(d) if isinstance(d, ConcatAxis) else d)
  segment_lens = [lengths for lengths, _ in seen.values()]
  return segment_lens, new_dims
|
|
|
|
|
|
|
|
def reassemble_concat_axes(vals, dims):
  """Undo the index packing produced by `unpack_concat_axes`.

  For every `ConcatAxis` in `dims`, its `segment_lengths.val` is an index into
  `vals`; the referenced value is substituted back into the axis and removed
  from the returned values. Returns ``(vals, dims)`` as new lists.
  """
  seg_positions = set()
  new_dims = []
  for d in dims:
    if isinstance(d, ConcatAxis):
      pos = d.segment_lengths.val
      seg_positions.add(pos)
      new_dims.append(ConcatAxis(d.axis, vals[pos]))
    else:
      new_dims.append(d)
  kept_vals = [x for i, x in enumerate(vals) if i not in seg_positions]
  return kept_vals, new_dims
|
2021-10-06 14:18:07 -07:00
|
|
|
|
|
|
|
|
|
|
|
### API for batching jaxprs
|
|
|
|
|
2023-01-12 21:16:18 -08:00
|
|
|
def batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
                 axis_size: core.AxisSize,
                 in_axes: Tuple[Union[int, NotMapped], ...],
                 axis_name: core.AxisName,
                 main_type: Type[BatchTrace],
                 ) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
  """Batch `closed_jaxpr` given per-input batch axes.

  Returns the batched closed jaxpr and the batch axis of each output.
  """
  # `_batch_jaxpr2` is cached, so pass `in_axes` as a hashable tuple.
  in_axes_tuple = tuple(in_axes)
  return _batch_jaxpr2(closed_jaxpr, axis_size, in_axes_tuple, axis_name,
                       main_type)
|
|
|
|
|
|
|
|
@weakref_lru_cache
def _batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
                  axis_size: core.AxisSize,
                  in_axes: Tuple[Union[int, NotMapped], ...],
                  axis_name: core.AxisName,
                  main_type: Type[BatchTrace],
                  ) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
  """Cached implementation of `batch_jaxpr2`.

  Traces the batched function to a new jaxpr on input avals that have the
  batch axis reinserted for every mapped input, and returns it together with
  the output batch axes recorded during tracing.
  """
  fun = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  fun, out_axes = _batch_jaxpr_inner(fun, axis_size)
  fun = _batch_jaxpr_outer(fun, axis_name, axis_size, in_axes, main_type)
  avals_in = []
  for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes):
    if b is not_mapped:
      avals_in.append(aval)
    else:
      avals_in.append(core.unmapped_aval(axis_size, axis_name, b, aval))
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(fun, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_axes()
|
|
|
|
|
2022-01-18 11:24:44 -08:00
|
|
|
def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name,
                main_type):
  """Batch `closed_jaxpr` given boolean per-input batching flags.

  List arguments are converted to tuples before calling `_batch_jaxpr`.
  A bool `instantiate` is passed through unchanged.
  """
  if isinstance(instantiate, list):
    instantiate = tuple(instantiate)
  return _batch_jaxpr(closed_jaxpr, axis_size, tuple(in_batched), instantiate,
                      axis_name, main_type)
|
|
|
|
|
|
|
|
def _batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name,
                 main_type):
  """Translate boolean batching flags into axes and call `batch_jaxpr_axes`.

  Batched inputs are mapped at axis 0. For each output, a True `instantiate`
  entry forces it to be batched at axis 0. A False entry maps it to
  `zero_if_mapped` (axis 0 only if it comes out mapped). A single bool
  `instantiate` applies to every output.
  """
  assert (isinstance(instantiate, bool) or
          (isinstance(instantiate, (list, tuple)) and
           all(isinstance(b, bool) for b in instantiate)))
  if isinstance(instantiate, bool):
    instantiate = [instantiate] * len(closed_jaxpr.out_avals)
  in_axes = [not_mapped if not b else 0 for b in in_batched]
  out_axes_dest = [zero_if_mapped if not inst else 0 for inst in instantiate]
  return batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
                          axis_name, main_type)
|
|
|
|
|
|
|
|
def batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name,
                     main_type):
  """Batch `closed_jaxpr` with explicit input axes and output destinations.

  Both axis sequences are converted to tuples before calling the cached
  `_batch_jaxpr_axes`.
  """
  in_axes_key = tuple(in_axes)
  out_axes_key = tuple(out_axes_dest)
  return _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes_key, out_axes_key,
                           axis_name, main_type)
|
|
|
|
|
2022-05-05 13:16:44 -07:00
|
|
|
@weakref_lru_cache
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
                      axis_name, main_type):
  """Cached implementation of `batch_jaxpr_axes`.

  Returns the batched ClosedJaxpr and the aux of `_match_axes_jaxpr`: a list
  of bools saying which outputs are batched after moving them to
  `out_axes_dest`.
  """
  fun = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  fun, out_axes_thunk = _batch_jaxpr_inner(fun, axis_size)
  fun, out_batched_thunk = _match_axes_jaxpr(fun, axis_size, out_axes_dest,
                                             out_axes_thunk)
  fun = _batch_jaxpr_outer(fun, axis_name, axis_size, in_axes, main_type)
  avals_in = [aval if b is not_mapped
              else core.unmapped_aval(axis_size, axis_name, b, aval)
              for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(fun, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched_thunk()
|
|
|
|
|
|
|
|
@lu.transformation_with_aux
def _batch_jaxpr_inner(axis_size, main, in_axes, *in_vals):
  """Innermost batching transformation for jaxpr-derived functions.

  Inputs whose axis is not None are wrapped as BatchTracers on a sublevel of
  `main`. The function is run, and its outputs are returned with their batch
  dims as the aux value. `axis_size` is accepted but not used here.
  """
  trace = main.with_cur_sublevel()
  in_tracers = [val if dim is None else BatchTracer(trace, val, dim)
                for val, dim in zip(in_vals, in_axes)]
  outs = yield in_tracers, {}
  raised = map(trace.full_raise, outs)
  out_vals, out_axes = unzip2((t.val, t.batch_dim) for t in raised)
  yield out_vals, out_axes
|
2021-10-06 14:18:07 -07:00
|
|
|
|
2023-01-12 21:16:18 -08:00
|
|
|
@lu.transformation_with_aux
def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes,
                      *in_vals):
  """Move each output's batch axis to its requested destination.

  Runs the wrapped function (passing `main` and `in_axes` through as its
  leading arguments). Then `matchaxis` is applied to each output, moving its
  batch axis from the source reported by the `out_axes` thunk to the matching
  entry of `out_axes_dest`.

  Aux output: a list of bools, True where the resolved destination is not
  None (i.e. the output ends up batched).
  """
  trace = main.with_cur_sublevel()
  out_vals = yield (main, in_axes, *in_vals), {}
  out_axes = out_axes()  # `out_axes` arrives as a thunk; force it here.
  # Resolve `zero_if_mapped`: axis 0 if the output came out mapped, else None.
  out_axes_dest = [(None if src is not_mapped else 0)
                   if dst is zero_if_mapped else dst
                   for src, dst in unsafe_zip(out_axes, out_axes_dest)]
  # A single destination is broadcast to every output.
  # NOTE(review): assuming `unsafe_zip` is the truncating builtin zip, a
  # single `zero_if_mapped` entry above is resolved against out_axes[0] only
  # and then copied to all outputs -- confirm that is intended.
  if len(out_axes_dest) != len(out_axes):
    out_axis_dest, = out_axes_dest
    out_axes_dest = [out_axis_dest] * len(out_axes)
  out_vals = map(partial(matchaxis, trace.axis_name, axis_size),
                 out_axes, out_axes_dest, out_vals)
  out_batched = [dst is not None for dst in out_axes_dest]
  yield out_vals, out_batched
|
|
|
|
|
|
|
|
@lu.transformation
def _batch_jaxpr_outer(axis_name, axis_size, in_dims, main_type, *in_vals):
  """Outer batching transformation for jaxpr-derived functions.

  Resolves `in_dims` (which may be a thunk) and infers `axis_size` from the
  mapped inputs when it is None. Integer batch dims are canonicalized. The
  wrapped function then runs under a new `main_type` main trace, with
  `axis_name` bound to `axis_size` in the axis environment.
  """
  # Force a thunk before anything iterates over `in_dims`; the axis_size
  # inference below would otherwise try to zip over a callable.
  in_dims = in_dims() if callable(in_dims) else in_dims
  if axis_size is None:
    axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
                  if d is not not_mapped}
  in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
             else ax for x, ax in unsafe_zip(in_vals, in_dims)]
  with core.new_main(main_type, axis_name=axis_name) as main:
    with core.extend_axis_env(axis_name, axis_size, main):
      out_vals = yield (main, in_dims, *in_vals), {}
      # Drop our reference to `main` before the context exits (presumably so
      # JAX's leak checking does not see it kept alive -- confirm).
      del main
  yield out_vals
|
|
|
|
|
|
|
|
def _merge_bdims(x, y):
|
|
|
|
if x == y:
|
|
|
|
return x
|
|
|
|
elif x is not_mapped:
|
|
|
|
return y
|
|
|
|
elif y is not_mapped:
|
|
|
|
return x
|
|
|
|
else:
|
|
|
|
return x # arbitrary
|
|
|
|
|
2022-05-05 13:16:44 -07:00
|
|
|
class ZeroIfMapped:
  """Sentinel output-axis destination.

  Means: put the batch axis at 0 if the output is mapped, otherwise leave
  it unmapped.
  """


zero_if_mapped = ZeroIfMapped()
|
2021-10-06 14:18:07 -07:00
|
|
|
|
|
|
|
### functions for handling custom_jvp and custom_vjp
|
|
|
|
|
|
|
|
@lu.transformation_with_aux
def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
  """Batch a custom_jvp rule whose inputs are ``primals + tangents``.

  `in_dims` describes the primals and is reused for the tangents. After the
  rule runs, each primal/tangent output pair is moved to a shared batch dim
  (chosen by `_merge_bdims`). That dim is reported, twice over, as the aux.
  """
  all_dims = in_dims * 2
  size, = {x.shape[d] for x, d in zip(in_vals, all_dims)
           if d is not not_mapped}
  trace = main.with_cur_sublevel()
  in_tracers = [val if dim is None else BatchTracer(trace, val, dim)
                for val, dim in zip(in_vals, all_dims)]
  outs = yield in_tracers, {}
  raised = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in raised)
  half = len(out_vals) // 2
  out_primals, out_tangents = split_list(out_vals, [half])
  out_primal_bds, out_tangent_bds = split_list(out_dims, [half])
  out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
  match = partial(matchaxis, trace.axis_name, size)
  out_primals = map(match, out_primal_bds, out_dims, out_primals)
  out_tangents = map(match, out_tangent_bds, out_dims, out_tangents)
  yield out_primals + out_tangents, out_dims * 2
|
|
|
|
|
2022-08-08 18:33:26 -07:00
|
|
|
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name):
  """Batch a custom_vjp backward function.

  The function is wrapped with `batch_subtrace` and `_batch_outer`. Its
  outputs are then matched (with reduce-sums where needed) to
  `out_dim_dests`.
  """
  bwd_inner, out_dims_thunk = batch_subtrace(bwd)
  bwd_outer = _batch_outer(bwd_inner, axis_name, axis_size, in_dims,
                           main_type, spmd_axis_name)
  return _match_axes_and_sum(bwd_outer, axis_size, axis_name, out_dims_thunk,
                             out_dim_dests)
|
|
|
|
|
|
|
|
@lu.transformation
def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):
  """Like `_match_axes`, but sums out batch axes when the destination is unmapped.

  Symbolic zeros are handled by `_matchaxis_symbolic_zeros`.
  """
  out_vals = yield in_vals, {}
  match = partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name,
                  sum_match=True)
  yield map(match, out_dims_thunk(), out_dim_dests, out_vals)
|
|
|
|
|
|
|
|
def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
  """Variant of `matchaxis` that also accepts ad_util `Zero` values.

  Non-`Zero` values are forwarded to `matchaxis`. For a `Zero`, only its
  aval is rewritten to reflect the moved, added, or summed-out batch axis.
  """
  # TODO(mattjj): dedup with matchaxis
  if not isinstance(x, Zero):
    return matchaxis(axis_name, sz, src, dst, x, sum_match=sum_match)
  if src == dst:
    return x
  if type(src) == type(dst) == int:
    unbatched = core.mapped_aval(sz, src, x.aval)
    return Zero(core.unmapped_aval(sz, name, dst, unbatched))
  if src is not_mapped and dst is not not_mapped:
    return Zero(core.unmapped_aval(sz, name, dst, x.aval))
  if dst is not_mapped and sum_match:
    return Zero(core.mapped_aval(sz, src, x.aval))
  raise ValueError((axis_name, x, src, dst))
|
|
|
|
|
|
|
|
|
|
|
|
### utilities for defining primitives' batching rules
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-18 17:06:05 -04:00
|
|
|
# A batching rule is called with the batched args and their batch dims (plus
# primitive params) and returns the output(s) together with the output batch
# dim(s).
BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
# Registry of batching rules, keyed by primitive (populated e.g. by
# `defvectorized`, `defbroadcasting`, `defreducer`).
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
# NOTE(review): presumably rules that additionally need the named-axis
# context; their call signature is not visible here -- confirm with callers.
axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
# NOTE(review): presumably rules that also take `spmd_axis_name` -- confirm.
spmd_axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def defvectorized(prim):
  """Register `vectorized_batcher` as the batching rule for `prim`."""
  rule = partial(vectorized_batcher, prim)
  primitive_batchers[prim] = rule
|
|
|
|
|
|
|
|
def vectorized_batcher(prim, batched_args, batch_dims, **params):
  """Batching rule for primitives that act independently on every element.

  All batch dims must agree. The primitive is bound directly on the batched
  args and the output keeps the shared batch dim.
  """
  trailing = batch_dims[1:]
  assert all(batch_dims[0] == other for other in trailing), batch_dims
  result = prim.bind(*batched_args, **params)
  return result, batch_dims[0]
|
|
|
|
|
|
|
|
def defbroadcasting(prim):
  """Register `broadcast_batcher` as the batching rule for `prim`."""
  rule = partial(broadcast_batcher, prim)
  primitive_batchers[prim] = rule
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def broadcast_batcher(prim, args, dims, **params):
  """Process a primitive with built-in broadcasting.

  Args:
    prim: the primitive to bind.
    args: the possibly-batched arguments
    dims: list or tuple of the same length as `args`, where each
      entry indicates the batching state of the corresponding entry to `args`:
      either an int indicating the batch dimension, or else `not_mapped`
      indicating no batching.
    **params: primitive parameters forwarded to `prim.bind`.

  Returns:
    ``(out, out_dim)``, or ``(outs, out_dims)`` when `prim.multiple_results`.
  """
  assert len(args) > 1
  # Reference shape and batch dim come from the first mapped argument.
  ref_shape, ref_dim = next((x.shape, d) for x, d in zip(args, dims)
                            if d is not not_mapped)
  # Fast path: every non-scalar argument already has the reference shape and
  # the same batch dim, so the primitive can be called as-is.
  fast_path = all(core.symbolic_equal_shape(ref_shape, x.shape) and d == ref_dim
                  for x, d in zip(args, dims) if np.ndim(x))
  if fast_path:
    out_dim = ref_dim
  else:
    # We pass size of 1 here because (1) at least one argument has a real batch
    # dimension and (2) all unmapped axes can have a singleton axis inserted and
    # then rely on the primitive's built-in broadcasting.
    lifted = [bdim_at_front(x, d, 1) if np.ndim(x) else x
              for x, d in zip(args, dims)]
    rank = max(np.ndim(x) for x in lifted)  # special-case scalar broadcasting
    args = [_handle_scalar_broadcasting(rank, x, d)
            for x, d in zip(lifted, dims)]
    out_dim = 0
  out = prim.bind(*args, **params)
  if prim.multiple_results:
    return out, (out_dim,) * len(out)
  return out, out_dim
|
2019-07-27 15:46:14 -07:00
|
|
|
|
|
|
|
def _handle_scalar_broadcasting(nd, x, d):
  """Pad a mapped value with trailing singleton axes up to rank `nd`.

  Unmapped values, and values already of rank `nd`, are returned unchanged.
  """
  if d is not not_mapped and nd != np.ndim(x):
    trailing_axes = tuple(range(np.ndim(x), nd))
    return jax.lax.expand_dims(x, trailing_axes)
  return x
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def defreducer(prim):
  """Register `reducer_batcher` as the batching rule for `prim`."""
  rule = partial(reducer_batcher, prim)
  primitive_batchers[prim] = rule
|
|
|
|
|
2019-01-10 15:35:15 -08:00
|
|
|
def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
  """Batching rule for reduction primitives that take an `axes` parameter.

  Args:
    prim: the reduction primitive to bind.
    batched_args: one-element sequence holding the operand.
    batch_dims: one-element sequence holding the operand's batch dim, either
      an int or a `ConcatAxis`.
    axes: reduced axes of the unbatched operand.
    **params: remaining primitive parameters. If an `input_shape` entry is
      present, it is refreshed to the shape of the operand actually bound.

  Returns:
    A pair ``(out, out_bdim)``.

  Raises:
    NotImplementedError: for a `ConcatAxis` batch dim not among `axes`.
  """
  operand, = batched_args
  bdim, = batch_dims
  if isinstance(bdim, int):
    # Reduced axes at or after the batch dim shift right by one.
    axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
    # Position of the batch dim among the surviving (non-reduced) axes.
    bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
    if 'input_shape' in params:
      params = dict(params, input_shape=operand.shape)
    return prim.bind(operand, axes=axes, **params), bdim_out
  elif isinstance(bdim, ConcatAxis):
    if bdim.axis in axes:
      # Reduce every other axis first. Axes must be a tuple (hashable), as in
      # the int branch above.
      other_axes = tuple(i for i in axes if i != bdim.axis)
      if other_axes:
        if 'input_shape' in params:
          params = dict(params, input_shape=operand.shape)
        operand = prim.bind(operand, axes=other_axes, **params)
      c_axis = bdim.axis - sum(d < bdim.axis for d in other_axes)
      operand = bdim_at_front(operand, c_axis, operand.shape[c_axis])
      # NOTE(review): segment_sum is only correct for sum-like reductions;
      # confirm no other reducer reaches this branch.
      return segment_sum(operand, bdim.segment_lengths), 0
    else:
      raise NotImplementedError  # TODO(mattjj)
  else:
    # Explicit raise so the check survives `python -O`.
    raise AssertionError(bdim)
|
|
|
|
|
|
|
|
# TODO(mattjj): replace with jax.lax.ops.segment_sum (once it's easier to trace
# under dynamic shapes)
def segment_sum(operand, segment_lens):
  """Sum consecutive segments of `operand` along its leading axis.

  Args:
    operand: array whose leading axis is the concatenation of the segments.
    segment_lens: 1D integer array of segment lengths (presumably summing to
      ``operand.shape[0]``). Zero-length segments are allowed.

  Returns:
    Array of shape ``(len(segment_lens), *operand.shape[1:])`` whose i-th row
    is the sum of the i-th segment. Empty segments give zeros.
  """
  starts = jax.numpy.cumsum(segment_lens) - segment_lens
  # Count segment starts per position with `.add`, not `.set`. Empty segments
  # share a start offset with the next segment, and `.set(1)` would collapse
  # those duplicates, shifting every later segment id. Starts equal to
  # operand.shape[0] (trailing empty segments) are out of bounds and dropped.
  boundaries = jax.numpy.zeros(operand.shape[0], 'int32').at[starts].add(1)
  segment_ids = jax.numpy.cumsum(boundaries) - 1
  out = jax.numpy.zeros((len(segment_lens), *operand.shape[1:]),
                        operand.dtype).at[segment_ids].add(operand)
  return out
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
### general utilities for manipulating axes on jaxpr types (not vmappables)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def broadcast(x, sz, axis):
  """Insert a new axis of size `sz` at position `axis` by broadcasting `x`.

  Args:
    x: array-like value.
    sz: size of the new axis.
    axis: position of the new axis in the output. Negative values count from
      the end of the output (rank ``np.ndim(x) + 1``), as in `np.expand_dims`.

  Returns:
    `x` broadcast to its shape with `sz` inserted at `axis`.
  """
  shape = list(np.shape(x))
  if axis < 0:
    # Normalize against the output rank so that `list.insert` and `np.delete`
    # below agree on where the new axis goes.
    axis += len(shape) + 1
  shape.insert(axis, sz)
  broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)
|
2019-07-27 15:46:14 -07:00
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
  """Move the batch axis of `x` from `src` to `dst`.

  Depending on `src` and `dst`, this packs `x` into a Pile (when `dst` is
  `pile_axis`), moves an existing axis, broadcasts to add one, or, when
  `sum_match` is set, sums out an axis whose destination is unmapped.

  Raises:
    TypeError: if `x` is not a valid JAX type.
    ValueError: if `x` is mapped but `dst` is unmapped and `sum_match` is False.
  """
  if dst == pile_axis:
    stacked = bdim_at_front(x, src, sz)
    elt_ty = stacked.aval.update(shape=stacked.shape[1:])
    pile_ty = PileTy(core.Var(0, '', core.ShapedArray((), np.dtype('int32'))),
                     stacked.shape[0], elt_ty)
    return Pile(pile_ty, stacked)
  try:
    core.get_aval(x)
  except TypeError as e:
    raise TypeError(f"Output from batched function {repr(x)} with type "
                    f"{type(x)} is not a valid JAX type") from e
  if src == dst:
    return x
  if type(src) == type(dst) == int:
    return moveaxis(x, src, dst)
  if src is not_mapped and dst is not not_mapped:
    return broadcast(x, sz, canonicalize_axis(dst, np.ndim(x) + 1))
  if dst is not_mapped and sum_match:
    return x.sum(src)
  has_user_name = (not isinstance(axis_name, core._TempAxisName) and
                   axis_name is not core.no_axis_name)
  if has_user_name:
    raise ValueError(f'vmap has mapped output ({axis_name=}) but out_axes is {dst}')
  raise ValueError(f'vmap has mapped output but out_axes is {dst}')
|
2019-05-15 07:25:03 -07:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def bdim_at_front(x, bdim, size):
  """Return `x` with its batch dimension at axis 0.

  A mapped `bdim` is moved to the front. An unmapped value gets a new leading
  axis of length `size` by broadcasting.
  """
  if bdim is not not_mapped:
    return moveaxis(x, bdim, 0)
  return broadcast(x, size, 0)
|
2019-05-15 07:25:03 -07:00
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
# sets up primitive batchers for ad_util and xla primitives
|
2019-05-15 07:25:03 -07:00
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
def add_batched(batched_args, batch_dims):
  """Batching rule for `add_jaxvals_p`.

  The operands are aligned to a common batch dim before adding: the unmapped
  side is broadcast, or, if both are mapped at different dims, `x` is moved
  to `y`'s batch dim.
  """
  bdx, bdy = batch_dims
  x, y = batched_args
  if bdx == bdy:
    out_dim = bdx
  elif bdx is not_mapped:
    x = broadcast(x, y.shape[bdy], bdy)
    out_dim = bdy
  elif bdy is not_mapped:
    y = broadcast(y, x.shape[bdx], bdx)
    out_dim = bdx
  else:
    x = moveaxis(x, bdx, bdy)
    out_dim = bdy
  return add_jaxvals(x, y), out_dim
primitive_batchers[add_jaxvals_p] = add_batched
|
2020-07-30 12:59:36 -07:00
|
|
|
|
2021-10-06 14:18:07 -07:00
|
|
|
def zeros_like_batched(batched_args, batch_dims):
  """Batching rule for `zeros_like_p`: the result keeps the input's batch dim."""
  (val,), (bdim,) = batched_args, batch_dims
  return zeros_like_jaxval(val), bdim
primitive_batchers[zeros_like_p] = zeros_like_batched
|