mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
move jax.interpreters.batching
to jax._src.interpreters.batching
Re-export roughly all of the same symbols via `jax.interpreters.batching` for now. PiperOrigin-RevId: 508107044
This commit is contained in:
parent
85b0c88490
commit
55c2b6dad6
945
jax/_src/interpreters/batching.py
Normal file
945
jax/_src/interpreters/batching.py
Normal file
@ -0,0 +1,945 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
from typing import (Any, Callable, Dict, Hashable, Iterable, Optional, Sequence,
|
||||
Set, Tuple, Type, Union)
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src.core import raise_to_shaped, Trace, Tracer
|
||||
from jax._src.tree_util import (tree_unflatten, tree_flatten,
|
||||
register_pytree_node)
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_p, Zero)
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
|
||||
canonicalize_axis, moveaxis, as_hashable_function,
|
||||
curry, memoize, weakref_lru_cache)
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
Array = Any
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
|
||||
# Piles
|
||||
|
||||
# i:(Fin 3) => f32[[3, 1, 4].i]
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class PileTy:
|
||||
binder: core.Var
|
||||
length: Union[int, Tracer, core.Var]
|
||||
elt_ty: core.DShapedArray
|
||||
def __repr__(self) -> str:
|
||||
return f'Var{id(self.binder)}:{self.length} => {self.elt_ty}'
|
||||
replace = dataclasses.replace
|
||||
|
||||
# [3, 1, 4].i
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class IndexedAxisSize:
|
||||
idx: core.Var
|
||||
lengths: Union[Array, core.Var, Tracer]
|
||||
def __repr__(self) -> str:
|
||||
return f'{str(self.lengths)}.Var{id(self.idx)}'
|
||||
replace = dataclasses.replace
|
||||
|
||||
# Pile(aval=a:3 => f32[[3 1 4].a],
|
||||
# data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Pile:
|
||||
aval: PileTy
|
||||
data: Array
|
||||
|
||||
# To vmap over a pile, one must specify the axis as PileAxis.
|
||||
class PileAxis: pass
|
||||
pile_axis = PileAxis()
|
||||
|
||||
# As a temporary measure before we have more general JITable / ADable interfaces
|
||||
# (analogues to vmappable), to enable Piles to be used with other
|
||||
# transformations and higher-order primitives (primarily jit, though also grad
|
||||
# with allow_int=True) we register them as pytrees.
|
||||
# TODO(mattjj): add JITable / ADable interfaces, remove this pytree registration
|
||||
def _pile_flatten(pile):
|
||||
lengths = []
|
||||
new_shape = [lengths.append(d.lengths) or d.replace(lengths=len(lengths))
|
||||
if type(d) is IndexedAxisSize else d
|
||||
for d in pile.aval.elt_ty.shape]
|
||||
elt_ty = pile.aval.elt_ty.update(shape=tuple(new_shape))
|
||||
aval = pile.aval.replace(elt_ty=elt_ty)
|
||||
return (lengths, pile.data), aval
|
||||
def _pile_unflatten(aval, x):
|
||||
lengths, data = x
|
||||
new_shape = [d.replace(lengths=lengths[d.lengths - 1])
|
||||
if type(d) is IndexedAxisSize else d
|
||||
for d in aval.elt_ty.shape]
|
||||
elt_ty = aval.elt_ty.update(shape=tuple(new_shape))
|
||||
aval = aval.replace(elt_ty=elt_ty)
|
||||
return Pile(aval, data)
|
||||
register_pytree_node(Pile, _pile_flatten, _pile_unflatten)
|
||||
|
||||
def _pile_result(axis_size, axis, segment_lens, x):
|
||||
binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
|
||||
shape = list(x.shape)
|
||||
shape[axis] = IndexedAxisSize(binder, segment_lens)
|
||||
elt_ty = core.DShapedArray(tuple(shape), x.dtype, x.weak_type)
|
||||
return Pile(PileTy(binder, axis_size, elt_ty), x)
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ConcatAxis:
|
||||
axis: int
|
||||
segment_lengths: Array
|
||||
|
||||
|
||||
def _update_annotation(
|
||||
f: lu.WrappedFun, orig_type: Optional[core.InputType],
|
||||
axis_size: core.AxisSize, axis_name: core.AxisName,
|
||||
explicit_in_dims: Sequence[Optional[Union[int, ConcatAxis]]],
|
||||
segment_lens: Sequence[Array],
|
||||
) -> lu.WrappedFun:
|
||||
if orig_type is None: return f
|
||||
# By convention, `explicit_in_dims` only accounts for explicit arguments.
|
||||
assert len(explicit_in_dims) == sum(explicit for _, explicit in orig_type)
|
||||
# We need to:
|
||||
# * if `axis_size` is dynamic, add a new implicit binder (type) for it;
|
||||
# * for each element of `segment_lengths`, add a new explicit binder for it;
|
||||
# * drop other implicit binders, replacing DBIdx which refer to them with
|
||||
# Name objects;
|
||||
# * for each (aval, in_dim) pair: if int-valued in_dim, add batch axis (int
|
||||
# size if `axis_size` is int, otherwise Name); if ConcatAxis-valued in_dim,
|
||||
# add batch axis (int if corresponding segment_lengths is concrete, Name if
|
||||
# not);
|
||||
# * generate full in_type with implicit args too.
|
||||
|
||||
class Name:
|
||||
def __init__(self, a): self.a = a
|
||||
names = [Name(a) for a, _ in orig_type]
|
||||
avals = [a.update(shape=tuple(names[d.val] if type(d) is pe.DBIdx else d # type: ignore
|
||||
for d in a.shape))
|
||||
if type(a) is core.DShapedArray else a for a, e in orig_type if e]
|
||||
|
||||
new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens]
|
||||
sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
|
||||
for a, d in zip(avals, explicit_in_dims):
|
||||
if isinstance(d, ConcatAxis):
|
||||
s = segment_lens[d.segment_lengths.val]
|
||||
if isinstance(core.get_aval(s), core.ConcreteArray):
|
||||
shape = list(a.shape) # type: ignore
|
||||
shape[d.axis] = int(s.sum()) # specialize on shape if we can
|
||||
new_avals.append(a.update(shape=tuple(shape)))
|
||||
else:
|
||||
new_avals.append(a)
|
||||
else:
|
||||
new_avals.append(core.unmapped_aval(sz, axis_name, d, a)) # type: ignore
|
||||
|
||||
mentioned = {d for a in new_avals if type(a) is core.DShapedArray
|
||||
for d in a.shape if type(d) is Name}
|
||||
expl_names = set(map(Name, new_avals))
|
||||
impl_names = mentioned - expl_names # type: ignore
|
||||
impl_part = [(n.a, False) for n in impl_names] # type: ignore
|
||||
name_map = {n: pe.DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))}
|
||||
expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape))
|
||||
if type(a) is core.DShapedArray else a, True) for a in new_avals]
|
||||
return lu.annotate(f, (*impl_part, *expl_part))
|
||||
|
||||
### vmappable typeclass
|
||||
|
||||
Vmappable = Any
|
||||
Elt = Any
|
||||
MapSpec = Any
|
||||
AxisSize = Any
|
||||
GetIdx = Callable[[], Tracer] # TODO(mattjj): revise this laziness
|
||||
ToEltHandler = Callable[[Callable, GetIdx, Vmappable, MapSpec], Elt]
|
||||
FromEltHandler = Callable[[Callable, AxisSize, Elt, MapSpec], Vmappable]
|
||||
MakeIotaHandler = Callable[[AxisSize], Array]
|
||||
|
||||
def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
|
||||
handler = to_elt_handlers.get(type(x))
|
||||
if handler:
|
||||
return handler(partial(to_elt, trace, get_idx), get_idx, x, spec)
|
||||
elif type(x) is Pile:
|
||||
if spec is not pile_axis:
|
||||
raise TypeError("pile input without using pile_axis in_axes spec")
|
||||
(d, ias), = ((i, sz) for i, sz in enumerate(x.aval.elt_ty.shape)
|
||||
if type(sz) is IndexedAxisSize)
|
||||
return BatchTracer(trace, x.data, ConcatAxis(d, ias.lengths)) # type: ignore
|
||||
elif isinstance(spec, int) or spec is None:
|
||||
spec = spec and canonicalize_axis(spec, len(np.shape(x)))
|
||||
return (BatchTracer(trace, x, spec, source_info_util.current())
|
||||
if spec is not None else x)
|
||||
else:
|
||||
assert False
|
||||
to_elt_handlers: Dict[Type, ToEltHandler] = {}
|
||||
|
||||
def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
|
||||
) -> Vmappable:
|
||||
handler = from_elt_handlers.get(type(x))
|
||||
if handler:
|
||||
return handler(partial(from_elt, trace), axis_size, x, spec)
|
||||
x_ = trace.full_raise(x)
|
||||
val, bdim = x_.val, x_.batch_dim
|
||||
if type(bdim) is ConcatAxis:
|
||||
if spec is not pile_axis:
|
||||
# TODO(mattjj): improve this error message
|
||||
raise TypeError("ragged output without using pile_axis out_axes spec")
|
||||
return _pile_result(axis_size, bdim.axis, bdim.segment_lengths, val)
|
||||
else:
|
||||
return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
|
||||
from_elt_handlers: Dict[Type, FromEltHandler] = {}
|
||||
|
||||
def make_iota(axis_size: AxisSize) -> Array:
|
||||
handler = make_iota_handlers.get(type(axis_size))
|
||||
if handler:
|
||||
return handler(axis_size)
|
||||
else:
|
||||
return jax.lax.iota('int32', int(axis_size))
|
||||
make_iota_handlers: Dict[Type, MakeIotaHandler] = {}
|
||||
|
||||
def register_vmappable(data_type: Type, spec_type: Type, axis_size_type: Type,
|
||||
to_elt: Callable, from_elt: Callable,
|
||||
make_iota: Optional[Callable]):
|
||||
vmappables[data_type] = (spec_type, axis_size_type)
|
||||
spec_types.add(spec_type)
|
||||
to_elt_handlers[data_type] = to_elt
|
||||
from_elt_handlers[data_type] = from_elt
|
||||
if make_iota: make_iota_handlers[axis_size_type] = make_iota
|
||||
vmappables: Dict[Type, Tuple[Type, Type]] = {}
|
||||
spec_types: Set[Type] = {PileAxis}
|
||||
|
||||
def unregister_vmappable(data_type: Type) -> None:
|
||||
spec_type, axis_size_type = vmappables.pop(data_type)
|
||||
spec_types.remove(spec_type)
|
||||
del to_elt_handlers[data_type]
|
||||
del from_elt_handlers[data_type]
|
||||
if axis_size_type in make_iota_handlers:
|
||||
del make_iota_handlers[axis_size_type]
|
||||
|
||||
def is_vmappable(x: Any) -> bool:
|
||||
return type(x) is Pile or type(x) in vmappables
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def flatten_fun_for_vmap(in_tree, *args_flat):
|
||||
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
|
||||
ans = yield py_args, py_kwargs
|
||||
yield tree_flatten(ans, is_leaf=is_vmappable)
|
||||
|
||||
### tracer
|
||||
|
||||
# TODO(mattjj): use a special sentinel type rather than None
|
||||
NotMapped = type(None)
|
||||
not_mapped = None
|
||||
|
||||
|
||||
class BatchTracer(Tracer):
|
||||
__slots__ = ['val', 'batch_dim', 'source_info']
|
||||
|
||||
def __init__(self, trace, val, batch_dim: Union[NotMapped, int, ConcatAxis],
|
||||
source_info: Optional[source_info_util.SourceInfo] = None):
|
||||
if config.jax_enable_checks:
|
||||
assert type(batch_dim) in (NotMapped, int, ConcatAxis)
|
||||
if type(batch_dim) is int:
|
||||
aval = raise_to_shaped(core.get_aval(val))
|
||||
assert 0 <= batch_dim < len(aval.shape) # type: ignore
|
||||
self._trace = trace
|
||||
self.val = val
|
||||
self.batch_dim = batch_dim
|
||||
self.source_info = source_info
|
||||
|
||||
@property
|
||||
def aval(self):
|
||||
aval = raise_to_shaped(core.get_aval(self.val))
|
||||
if self.batch_dim is not_mapped:
|
||||
return aval
|
||||
elif type(self.batch_dim) is int:
|
||||
return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
|
||||
elif type(self.batch_dim) is ConcatAxis:
|
||||
shape = list(aval.shape)
|
||||
size_tracer = BatchTracer(self._trace, self.batch_dim.segment_lengths, 0)
|
||||
shape[self.batch_dim.axis] = size_tracer
|
||||
return core.DShapedArray(shape=tuple(shape), dtype=aval.dtype,
|
||||
weak_type=aval.weak_type)
|
||||
|
||||
def full_lower(self):
|
||||
if self.batch_dim is not_mapped:
|
||||
return core.full_lower(self.val)
|
||||
else:
|
||||
return self
|
||||
|
||||
def _origin_msg(self):
|
||||
if self.source_info is None:
|
||||
return ""
|
||||
return (f"\nThis BatchTracer with object id {id(self)} was created on line:"
|
||||
f"\n {source_info_util.summarize(self.source_info)}")
|
||||
|
||||
def _contents(self):
|
||||
return [('val', self.val), ('batch_dim', self.batch_dim)]
|
||||
|
||||
def get_referent(self):
|
||||
if self.batch_dim is None or type(self.batch_dim) is int:
|
||||
return core.get_referent(self.val)
|
||||
else: # TODO(mattjj): could handle the ConcatAxis case?
|
||||
return self
|
||||
|
||||
class BatchTrace(Trace):
|
||||
|
||||
def __init__(self, *args, axis_name, spmd_axis_name = None):
|
||||
super().__init__(*args)
|
||||
self.axis_name = axis_name
|
||||
self.spmd_axis_name = spmd_axis_name
|
||||
|
||||
def pure(self, val):
|
||||
return BatchTracer(self, val, not_mapped, source_info_util.current())
|
||||
|
||||
def lift(self, val):
|
||||
return BatchTracer(self, val, not_mapped, source_info_util.current())
|
||||
|
||||
def sublift(self, val):
|
||||
return BatchTracer(self, val.val, val.batch_dim, source_info_util.current())
|
||||
|
||||
def get_primitive_batcher(self, primitive, frame):
|
||||
if primitive in primitive_batchers:
|
||||
return primitive_batchers[primitive]
|
||||
elif self.spmd_axis_name is not None and primitive in spmd_axis_primitive_batchers:
|
||||
return partial(spmd_axis_primitive_batchers[primitive],
|
||||
self.spmd_axis_name, frame.size, frame.name,
|
||||
frame.main_trace.trace_type)
|
||||
elif primitive in axis_primitive_batchers:
|
||||
return self.get_axis_primitive_batcher(primitive, frame)
|
||||
msg = "Batching rule for '{}' not implemented"
|
||||
raise NotImplementedError(msg.format(primitive))
|
||||
|
||||
def get_axis_primitive_batcher(self, primitive, frame):
|
||||
return partial(axis_primitive_batchers[primitive],
|
||||
frame.size, frame.name, frame.main_trace.trace_type)
|
||||
|
||||
def get_frame(self, vals, dims) -> core.AxisEnvFrame:
|
||||
if self.axis_name is core.no_axis_name:
|
||||
# If axis name is `no_axis_name` we can't find it via `core.axis_name` so
|
||||
# we reconstruct it from the information we have available
|
||||
sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
|
||||
for x, d in zip(vals, dims) if d is not not_mapped)
|
||||
axis_size, = core.dedup_referents(sizes)
|
||||
return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
|
||||
return core.axis_frame(self.axis_name)
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
|
||||
is_axis_primitive = primitive in axis_primitive_batchers
|
||||
used_names = core.used_axis_names(primitive, params)
|
||||
if is_axis_primitive and _main_trace_for_axis_names(self.main, used_names):
|
||||
frame = self.get_frame(vals_in, dims_in)
|
||||
batcher_primitive = self.get_axis_primitive_batcher(primitive, frame)
|
||||
val_out, dim_out = batcher_primitive(vals_in, dims_in, **params)
|
||||
elif all(bdim is not_mapped for bdim in dims_in):
|
||||
return primitive.bind(*vals_in, **params)
|
||||
else:
|
||||
frame = self.get_frame(vals_in, dims_in)
|
||||
batched_primitive = self.get_primitive_batcher(primitive, frame)
|
||||
val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
|
||||
src = source_info_util.current()
|
||||
if primitive.multiple_results:
|
||||
return [BatchTracer(self, x, d, src) for x, d in zip(val_out, dim_out)]
|
||||
else:
|
||||
return BatchTracer(self, val_out, dim_out, src)
|
||||
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
assert call_primitive.multiple_results
|
||||
params = dict(params, name=params.get('name', f.__name__))
|
||||
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
||||
if all(bdim is not_mapped for bdim in dims):
|
||||
return call_primitive.bind(f, *vals, **params)
|
||||
sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
|
||||
for x, d in zip(vals, dims) if d is not not_mapped)
|
||||
axis_size, = core.dedup_referents(sizes)
|
||||
segment_lens, dims = unpack_concat_axes(dims)
|
||||
f_, dims_out = batch_subtrace(f, self.main, tuple(dims))
|
||||
f_ = _update_annotation(f_, f.in_type, axis_size, self.axis_name, dims,
|
||||
segment_lens)
|
||||
vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
|
||||
vals_out, dims_out = reassemble_concat_axes(vals_out, dims_out())
|
||||
src = source_info_util.current()
|
||||
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)]
|
||||
|
||||
def post_process_call(self, call_primitive, out_tracers, params):
|
||||
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
|
||||
for t in out_tracers)
|
||||
main = self.main
|
||||
def todo(vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
return map(partial(BatchTracer, trace), vals, dims, srcs)
|
||||
return vals, todo
|
||||
|
||||
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
|
||||
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
||||
if all(dim is not_mapped for dim in dims):
|
||||
return map_primitive.bind(f, *vals, **params)
|
||||
else:
|
||||
assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
|
||||
# The logic for the dimension math below is as follows:
|
||||
# ╔═════════════╦════════════════════════════════════════╦═══════════╗
|
||||
# ║ d / in_axis ║ None ║ int ║
|
||||
# ╠═════════════╬════════════════════════════════════════╩═══════════╣
|
||||
# ║ None ║ No extra axis, so in_axis unaffected ║
|
||||
# ╠═════════════╬════════════════════════════════════════╦═══════════╣
|
||||
# ║ int ║ Not mapped, so batching dim unaffected ║ See below ║
|
||||
# ╚═════════════╩════════════════════════════════════════╩═══════════╝
|
||||
# When both d and in_axis are defined then:
|
||||
# - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
|
||||
# - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed).
|
||||
def both_mapped(in_out_axis, d):
|
||||
return in_out_axis is not None and d is not not_mapped
|
||||
new_in_axes = tuple(
|
||||
in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
|
||||
for d, in_axis in zip(dims, params['in_axes']))
|
||||
new_dims = tuple(
|
||||
d - 1 if both_mapped(in_axis, d) and in_axis < d else d
|
||||
for d, in_axis in zip(dims, params['in_axes']))
|
||||
f, dims_out = batch_subtrace(f, self.main, new_dims)
|
||||
out_axes_thunk = params['out_axes_thunk']
|
||||
# NOTE: This assumes that the choice of the dimensions over which outputs
|
||||
# are batched is entirely dependent on the function and not e.g. on the
|
||||
# data or its shapes.
|
||||
@as_hashable_function(closure=out_axes_thunk)
|
||||
def new_out_axes_thunk():
|
||||
return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
|
||||
for out_axis, d in zip(out_axes_thunk(), dims_out()))
|
||||
new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
|
||||
vals_out = map_primitive.bind(f, *vals, **new_params)
|
||||
dims_out_ = [d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
|
||||
for d, out_axis in zip(dims_out(), out_axes_thunk())]
|
||||
src = source_info_util.current()
|
||||
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)]
|
||||
|
||||
def post_process_map(self, call_primitive, out_tracers, params):
|
||||
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
|
||||
for t in out_tracers)
|
||||
main = self.main
|
||||
def both_mapped(in_out_axis, d):
|
||||
return in_out_axis is not None and d is not not_mapped
|
||||
def todo(vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
return [BatchTracer(trace, v, d + 1 if both_mapped(oa, d) and oa <= d else d, s)
|
||||
for v, d, oa, s in zip(vals, dims, params['out_axes_thunk'](), srcs)]
|
||||
if call_primitive.map_primitive:
|
||||
def out_axes_transform(out_axes):
|
||||
return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
|
||||
for out_axis, d in zip(out_axes, dims))
|
||||
todo = (todo, out_axes_transform)
|
||||
return vals, todo
|
||||
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
|
||||
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
||||
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
|
||||
jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
|
||||
out_vals = prim.bind(fun, jvp, *in_vals)
|
||||
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
|
||||
if not fst:
|
||||
assert out_dims == out_dims[:len(out_dims) // 2] * 2
|
||||
out_dims = out_dims[:len(out_dims) // 2]
|
||||
src = source_info_util.current()
|
||||
return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]
|
||||
|
||||
def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
|
||||
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
|
||||
for t in out_tracers)
|
||||
main = self.main
|
||||
def todo(vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
if jvp_was_run:
|
||||
primal_dims, tangent_dims = dims[:len(vals)], dims[len(vals):]
|
||||
assert primal_dims == tangent_dims
|
||||
primal_srcs = srcs[:len(vals)]
|
||||
return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs)
|
||||
else:
|
||||
return map(partial(BatchTracer, trace), vals, dims, srcs)
|
||||
return vals, todo
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees): # pytype: disable=signature-mismatch
|
||||
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
||||
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
|
||||
if d is not not_mapped}
|
||||
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
|
||||
fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
|
||||
bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
|
||||
out_dims2, in_dims, self.main.trace_type, self.spmd_axis_name)
|
||||
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
|
||||
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
|
||||
if not fst:
|
||||
_, res_tree = out_trees()
|
||||
_, out_dims = split_list(out_dims, [res_tree.num_leaves])
|
||||
src = source_info_util.current()
|
||||
return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]
|
||||
|
||||
def post_process_custom_vjp_call(self, out_tracers, _):
|
||||
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
|
||||
for t in out_tracers)
|
||||
main = self.main
|
||||
def todo(vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
return map(partial(BatchTracer, trace), vals, dims, srcs)
|
||||
return vals, todo
|
||||
|
||||
def post_process_custom_vjp_call_fwd(self, out_tracers, out_trees):
|
||||
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
|
||||
for t in out_tracers)
|
||||
axis_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
|
||||
main, trace_type = self.main, self.main.trace_type
|
||||
axis_name = self.axis_name
|
||||
_, res_tree = out_trees()
|
||||
num_res = res_tree.num_leaves
|
||||
res_dims, primal_dims = split_list(dims, [num_res])
|
||||
_, primal_srcs = split_list(srcs, [num_res])
|
||||
def todo(vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs)
|
||||
def bwd_transform(bwd):
|
||||
return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims, (None,),
|
||||
trace_type, self.spmd_axis_name)
|
||||
return vals, todo, bwd_transform
|
||||
|
||||
def _main_trace_for_axis_names(main_trace: core.MainTrace,
|
||||
axis_name: Iterable[core.AxisName],
|
||||
) -> bool:
|
||||
# This function exists to identify whether a main trace corresponds to any of
|
||||
# the axis names used by a primitive. Axis names alone aren't enough because
|
||||
# axis names can shadow, so we use the main trace as a tag.
|
||||
return any(main_trace is core.axis_frame(n).main_trace for n in axis_name)
|
||||
|
||||
### API for batching callables with vmappable inputs and outputs
|
||||
|
||||
def batch(fun: lu.WrappedFun, axis_name: core.AxisName, axis_size,
|
||||
in_dims, out_dim_dests, main_type: Type[BatchTrace] = BatchTrace,
|
||||
spmd_axis_name: Optional[Hashable] = None) -> lu.WrappedFun:
|
||||
# we split up _batch_inner and _batch_outer for the leak checker
|
||||
f = _batch_inner(fun, axis_size, out_dim_dests)
|
||||
return _batch_outer(f, axis_name, axis_size, in_dims, main_type,
|
||||
spmd_axis_name)
|
||||
|
||||
@lu.transformation
|
||||
def _batch_outer(axis_name, axis_size, in_dims, main_type, spmd_axis_name,
|
||||
*in_vals):
|
||||
with core.new_main(
|
||||
main_type, axis_name=axis_name, spmd_axis_name=spmd_axis_name) as main:
|
||||
with core.extend_axis_env(axis_name, axis_size, main):
|
||||
with source_info_util.transform_name_stack('vmap'):
|
||||
outs = yield (main, in_dims, *in_vals), {}
|
||||
del main
|
||||
yield outs
|
||||
|
||||
@lu.transformation
|
||||
def _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals):
|
||||
in_dims = in_dims() if callable(in_dims) else in_dims
|
||||
trace = main.with_cur_sublevel()
|
||||
idx = memoize(lambda: BatchTracer(trace, make_iota(axis_size), 0,
|
||||
source_info_util.current()))
|
||||
in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
|
||||
outs = yield in_tracers, {}
|
||||
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
|
||||
out_vals = map(partial(from_elt, trace, axis_size), outs, out_dim_dests)
|
||||
yield out_vals
|
||||
|
||||
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
|
||||
def vtile(f_flat: lu.WrappedFun,
|
||||
in_axes_flat: Tuple[Optional[int], ...],
|
||||
out_axes_flat: Tuple[Optional[int], ...],
|
||||
tile_size: Optional[int],
|
||||
axis_name: core.AxisName,
|
||||
main_type: Type[BatchTrace] = BatchTrace):
|
||||
@curry
|
||||
def tile_axis(arg, axis: Optional[int], tile_size):
|
||||
if axis is None:
|
||||
return arg
|
||||
shape = list(arg.shape)
|
||||
shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
|
||||
return arg.reshape(shape)
|
||||
|
||||
def untile_axis(out, axis: Optional[int]):
|
||||
if axis is None:
|
||||
return out
|
||||
shape = list(out.shape)
|
||||
shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
|
||||
return out.reshape(shape)
|
||||
|
||||
@lu.transformation
|
||||
def _map_to_tile(*args_flat):
|
||||
sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
|
||||
tile_size_ = tile_size or next(sizes, None)
|
||||
assert tile_size_ is not None, "No mapped arguments?"
|
||||
outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
|
||||
yield map(untile_axis, outputs_flat, out_axes_flat)
|
||||
|
||||
return _map_to_tile(batch(
|
||||
f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat, main_type=main_type))
|
||||
|
||||
### API for batching functions with jaxpr type inputs and outputs
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def batch_subtrace(main, in_dims, *in_vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
in_dims = in_dims() if callable(in_dims) else in_dims
|
||||
in_vals, in_dims = reassemble_concat_axes(in_vals, in_dims)
|
||||
in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
|
||||
if dim is not None else x for x, dim in zip(in_vals, in_dims)]
|
||||
outs = yield in_tracers, {}
|
||||
out_tracers = map(trace.full_raise, outs)
|
||||
out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
|
||||
segment_lens, out_dims = unpack_concat_axes(out_dims)
|
||||
yield (*segment_lens, *out_vals), out_dims
|
||||
|
||||
def unpack_concat_axes(dims):
|
||||
if not any(type(d) is ConcatAxis for d in dims):
|
||||
return [], dims
|
||||
concat_axis_map = collections.OrderedDict()
|
||||
def convert(d: ConcatAxis) -> ConcatAxis:
|
||||
_, dbidx = concat_axis_map.setdefault(
|
||||
id(core.get_referent(d.segment_lengths)),
|
||||
(d.segment_lengths, pe.DBIdx(len(concat_axis_map))))
|
||||
return ConcatAxis(d.axis, dbidx)
|
||||
new_dims = [convert(d) if isinstance(d, ConcatAxis) else d for d in dims]
|
||||
segment_lens = [s for s, _ in concat_axis_map.values()]
|
||||
return segment_lens, new_dims
|
||||
|
||||
def reassemble_concat_axes(vals, dims):
|
||||
idxs = {d.segment_lengths.val for d in dims if isinstance(d, ConcatAxis)}
|
||||
dims = [ConcatAxis(d.axis, vals[d.segment_lengths.val])
|
||||
if isinstance(d, ConcatAxis) else d for d in dims]
|
||||
vals = [x for i, x in enumerate(vals) if i not in idxs]
|
||||
return vals, dims
|
||||
|
||||
|
||||
### API for batching jaxprs
|
||||
|
||||
def batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
|
||||
axis_size: core.AxisSize,
|
||||
in_axes: Tuple[Union[int, NotMapped], ...],
|
||||
axis_name: core.AxisName,
|
||||
main_type: Type[BatchTrace],
|
||||
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
|
||||
return _batch_jaxpr2(closed_jaxpr, axis_size, tuple(in_axes), axis_name,
|
||||
main_type)
|
||||
|
||||
@weakref_lru_cache
|
||||
def _batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
|
||||
axis_size: core.AxisSize,
|
||||
in_axes: Tuple[Union[int, NotMapped], ...],
|
||||
axis_name: core.AxisName,
|
||||
main_type: Type[BatchTrace],
|
||||
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
|
||||
f, out_axes = _batch_jaxpr_inner(f, axis_size)
|
||||
f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
|
||||
avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval)
|
||||
if b is not not_mapped else aval
|
||||
for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
|
||||
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
||||
return core.ClosedJaxpr(jaxpr_out, consts), out_axes()
|
||||
|
||||
def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name,
|
||||
main_type):
|
||||
inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate
|
||||
return _batch_jaxpr(closed_jaxpr, axis_size, tuple(in_batched), inst,
|
||||
axis_name, main_type)
|
||||
|
||||
def _batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name,
|
||||
main_type):
|
||||
assert (isinstance(instantiate, bool) or
|
||||
isinstance(instantiate, (list, tuple)) and
|
||||
all(isinstance(b, bool) for b in instantiate))
|
||||
if isinstance(instantiate, bool):
|
||||
instantiate = [instantiate] * len(closed_jaxpr.out_avals)
|
||||
in_axes = [0 if b else not_mapped for b in in_batched]
|
||||
out_axes_dest = [0 if inst else zero_if_mapped for inst in instantiate]
|
||||
return batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
|
||||
axis_name, main_type)
|
||||
|
||||
def batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name,
|
||||
main_type):
|
||||
return _batch_jaxpr_axes(closed_jaxpr, axis_size, tuple(in_axes),
|
||||
tuple(out_axes_dest), axis_name, main_type)
|
||||
|
||||
@weakref_lru_cache
|
||||
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
|
||||
axis_name, main_type):
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
|
||||
f, out_axes = _batch_jaxpr_inner(f, axis_size)
|
||||
f, out_batched = _match_axes_jaxpr(f, axis_size, out_axes_dest, out_axes)
|
||||
f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
|
||||
avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped
|
||||
else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
|
||||
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
||||
return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _batch_jaxpr_inner(axis_size, main, in_axes, *in_vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
|
||||
for val, dim in zip(in_vals, in_axes)]
|
||||
outs = yield in_tracers, {}
|
||||
out_tracers = map(trace.full_raise, outs)
|
||||
out_vals, out_axes = unzip2((t.val, t.batch_dim) for t in out_tracers)
|
||||
yield out_vals, out_axes
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes,
|
||||
*in_vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
out_vals = yield (main, in_axes, *in_vals), {}
|
||||
out_axes = out_axes()
|
||||
out_axes_dest = [(None if src is not_mapped else 0)
|
||||
if dst is zero_if_mapped else dst
|
||||
for src, dst in unsafe_zip(out_axes, out_axes_dest)]
|
||||
if len(out_axes_dest) != len(out_axes):
|
||||
out_axis_dest, = out_axes_dest
|
||||
out_axes_dest = [out_axis_dest] * len(out_axes)
|
||||
out_vals = map(partial(matchaxis, trace.axis_name, axis_size),
|
||||
out_axes, out_axes_dest, out_vals)
|
||||
out_batched = [dst is not None for dst in out_axes_dest]
|
||||
yield out_vals, out_batched
|
||||
|
||||
@lu.transformation
|
||||
def _batch_jaxpr_outer(axis_name, axis_size, in_dims, main_type, *in_vals):
|
||||
if axis_size is None:
|
||||
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
|
||||
in_dims = in_dims() if callable(in_dims) else in_dims
|
||||
in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
|
||||
else ax for x, ax in unsafe_zip(in_vals, in_dims)]
|
||||
with core.new_main(main_type, axis_name=axis_name) as main:
|
||||
with core.extend_axis_env(axis_name, axis_size, main):
|
||||
out_vals = yield (main, in_dims, *in_vals), {}
|
||||
del main
|
||||
yield out_vals
|
||||
|
||||
def _merge_bdims(x, y):
|
||||
if x == y:
|
||||
return x
|
||||
elif x is not_mapped:
|
||||
return y
|
||||
elif y is not_mapped:
|
||||
return x
|
||||
else:
|
||||
return x # arbitrary
|
||||
|
||||
class ZeroIfMapped: pass
|
||||
zero_if_mapped = ZeroIfMapped()
|
||||
|
||||
### functions for handling custom_vjp
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
|
||||
size, = {x.shape[d] for x, d in zip(in_vals, in_dims * 2)
|
||||
if d is not not_mapped}
|
||||
trace = main.with_cur_sublevel()
|
||||
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
|
||||
for val, dim in zip(in_vals, in_dims * 2)]
|
||||
outs = yield in_tracers, {}
|
||||
out_tracers = map(trace.full_raise, outs)
|
||||
out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
|
||||
out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
|
||||
out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
|
||||
out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
|
||||
out_primals = map(partial(matchaxis, trace.axis_name, size),
|
||||
out_primal_bds, out_dims, out_primals)
|
||||
out_tangents = map(partial(matchaxis, trace.axis_name, size),
|
||||
out_tangent_bds, out_dims, out_tangents)
|
||||
yield out_primals + out_tangents, out_dims * 2
|
||||
|
||||
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name):
|
||||
bwd, out_dims_thunk = batch_subtrace(bwd)
|
||||
bwd_ = _batch_outer(bwd, axis_name, axis_size, in_dims, main_type, spmd_axis_name)
|
||||
return _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests)
|
||||
|
||||
@lu.transformation
|
||||
def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):
|
||||
# this is like _match_axes, but we do reduce-sums as needed
|
||||
out_vals = yield in_vals, {}
|
||||
yield map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name,
|
||||
sum_match=True), out_dims_thunk(), out_dim_dests, out_vals)
|
||||
|
||||
def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
|
||||
# Just like `matchaxis`, but handles symbolic zeros using ad_util.py
|
||||
# TODO(mattjj): dedup with matchaxis
|
||||
if isinstance(x, Zero):
|
||||
if src == dst:
|
||||
return x
|
||||
elif type(src) == type(dst) == int:
|
||||
aval = core.mapped_aval(sz, src, x.aval)
|
||||
return Zero(core.unmapped_aval(sz, name, dst, aval))
|
||||
elif src is not_mapped and dst is not not_mapped:
|
||||
return Zero(core.unmapped_aval(sz, name, dst, x.aval))
|
||||
elif dst is not_mapped and sum_match:
|
||||
return Zero(core.mapped_aval(sz, src, x.aval))
|
||||
else:
|
||||
raise ValueError((axis_name, x, src, dst))
|
||||
else:
|
||||
return matchaxis(axis_name, sz, src, dst, x, sum_match=sum_match)
|
||||
|
||||
|
||||
### utilities for defining primitives' batching rules
|
||||
|
||||
BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
|
||||
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
|
||||
axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
|
||||
spmd_axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
|
||||
|
||||
def defvectorized(prim):
|
||||
primitive_batchers[prim] = partial(vectorized_batcher, prim)
|
||||
|
||||
def vectorized_batcher(prim, batched_args, batch_dims, **params):
|
||||
assert all(batch_dims[0] == bd for bd in batch_dims[1:]), batch_dims
|
||||
return prim.bind(*batched_args, **params), batch_dims[0]
|
||||
|
||||
def defbroadcasting(prim):
|
||||
primitive_batchers[prim] = partial(broadcast_batcher, prim)
|
||||
|
||||
def broadcast_batcher(prim, args, dims, **params):
|
||||
"""Process a primitive with built-in broadcasting.
|
||||
|
||||
Args:
|
||||
args: the possibly-batched arguments
|
||||
dims: list or tuple of the same length as `args`, where each
|
||||
entry indicates the batching state of the corresponding entry to `args`:
|
||||
either an int indicating the batch dimension, or else `not_mapped`
|
||||
indicating no batching.
|
||||
"""
|
||||
assert len(args) > 1
|
||||
shape, dim = next((x.shape, d) for x, d in zip(args, dims)
|
||||
if d is not not_mapped)
|
||||
if all(core.symbolic_equal_shape(shape, x.shape) and d == dim
|
||||
for x, d in zip(args, dims) if np.ndim(x)):
|
||||
# if there's only agreeing batch dims and scalars, just call the primitive
|
||||
out = prim.bind(*args, **params)
|
||||
return (out, (dim,) * len(out)) if prim.multiple_results else (out, dim)
|
||||
else:
|
||||
# We pass size of 1 here because (1) at least one argument has a real batch
|
||||
# dimension and (2) all unmapped axes can have a singleton axis inserted and
|
||||
# then rely on the primitive's built-in broadcasting.
|
||||
args = [bdim_at_front(x, d, 1) if np.ndim(x) else x
|
||||
for x, d in zip(args, dims)]
|
||||
ndim = max(np.ndim(x) for x in args) # special-case scalar broadcasting
|
||||
args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
|
||||
out = prim.bind(*args, **params)
|
||||
return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)
|
||||
|
||||
def _handle_scalar_broadcasting(nd, x, d):
|
||||
if d is not_mapped or nd == np.ndim(x):
|
||||
return x
|
||||
else:
|
||||
return jax.lax.expand_dims(x, tuple(range(np.ndim(x), nd)))
|
||||
|
||||
def defreducer(prim):
|
||||
primitive_batchers[prim] = partial(reducer_batcher, prim)
|
||||
|
||||
def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
|
||||
operand, = batched_args
|
||||
bdim, = batch_dims
|
||||
if isinstance(bdim, int):
|
||||
axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
|
||||
bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
|
||||
if 'input_shape' in params:
|
||||
params = dict(params, input_shape=operand.shape)
|
||||
return prim.bind(operand, axes=axes, **params), bdim_out
|
||||
elif isinstance(bdim, ConcatAxis):
|
||||
if bdim.axis in axes:
|
||||
other_axes = [i for i in axes if i != bdim.axis]
|
||||
if other_axes:
|
||||
operand = prim.bind(operand, axes=other_axes, **params)
|
||||
c_axis = bdim.axis - sum(d < bdim.axis for d in other_axes)
|
||||
operand = bdim_at_front(operand, c_axis, operand.shape[c_axis])
|
||||
return segment_sum(operand, bdim.segment_lengths), 0
|
||||
else:
|
||||
raise NotImplementedError # TODO(mattjj)
|
||||
else:
|
||||
assert False
|
||||
|
||||
# TODO(mattjj): replace with jax.lax.ops.segment_sum (once it's easier to trace
|
||||
# under dynamic shapes)
|
||||
def segment_sum(operand, segment_lens):
|
||||
scat_idx = jax.numpy.cumsum(segment_lens) - segment_lens
|
||||
segment_ids = jax.numpy.cumsum(
|
||||
jax.numpy.zeros(operand.shape[0], 'int32').at[scat_idx].set(1)) - 1
|
||||
out = jax.numpy.zeros((len(segment_lens), *operand.shape[1:]),
|
||||
operand.dtype).at[segment_ids].add(operand)
|
||||
return out
|
||||
|
||||
### general utilities for manipulating axes on jaxpr types (not vmappables)
|
||||
|
||||
def broadcast(x, sz, axis):
|
||||
shape = list(np.shape(x))
|
||||
shape.insert(axis, sz)
|
||||
broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
|
||||
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)
|
||||
|
||||
def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
|
||||
if dst == pile_axis:
|
||||
x = bdim_at_front(x, src, sz)
|
||||
elt_ty = x.aval.update(shape=x.shape[1:])
|
||||
aval = PileTy(core.Var(0, '', core.ShapedArray((), np.dtype('int32'))),
|
||||
x.shape[0], elt_ty)
|
||||
return Pile(aval, x)
|
||||
try:
|
||||
_ = core.get_aval(x)
|
||||
except TypeError as e:
|
||||
raise TypeError(f"Output from batched function {repr(x)} with type "
|
||||
f"{type(x)} is not a valid JAX type") from e
|
||||
if src == dst:
|
||||
return x
|
||||
elif type(src) == type(dst) == int:
|
||||
return moveaxis(x, src, dst)
|
||||
elif src is not_mapped and dst is not not_mapped:
|
||||
return broadcast(x, sz, canonicalize_axis(dst, np.ndim(x) + 1))
|
||||
elif dst is not_mapped and sum_match:
|
||||
return x.sum(src)
|
||||
else:
|
||||
if (not isinstance(axis_name, core._TempAxisName) and
|
||||
axis_name is not core.no_axis_name):
|
||||
raise ValueError(f'vmap has mapped output ({axis_name=}) but out_axes is {dst}')
|
||||
else:
|
||||
raise ValueError(f'vmap has mapped output but out_axes is {dst}')
|
||||
|
||||
def bdim_at_front(x, bdim, size):
|
||||
if bdim is not_mapped:
|
||||
return broadcast(x, size, 0)
|
||||
else:
|
||||
return moveaxis(x, bdim, 0)
|
||||
|
||||
# sets up primitive batchers for ad_util and xla primitives
|
||||
|
||||
def add_batched(batched_args, batch_dims):
|
||||
bdx, bdy = batch_dims
|
||||
x, y = batched_args
|
||||
if bdx == bdy:
|
||||
return add_jaxvals(x, y), bdx
|
||||
elif bdx is not_mapped:
|
||||
x = broadcast(x, y.shape[bdy], bdy)
|
||||
return add_jaxvals(x, y), bdy
|
||||
elif bdy is not_mapped:
|
||||
y = broadcast(y, x.shape[bdx], bdx)
|
||||
return add_jaxvals(x, y), bdx
|
||||
else:
|
||||
x = moveaxis(x, bdx, bdy)
|
||||
return add_jaxvals(x, y), bdy
|
||||
primitive_batchers[add_jaxvals_p] = add_batched
|
||||
|
||||
def zeros_like_batched(batched_args, batch_dims):
|
||||
val, = batched_args
|
||||
bdim, = batch_dims
|
||||
return zeros_like_jaxval(val), bdim
|
||||
primitive_batchers[zeros_like_p] = zeros_like_batched
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user