# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from typing import Any, Callable, Dict, Optional, Tuple, Union, Sequence, Iterable, Type

import jax
from ..config import config
from .. import core
from ..core import raise_to_shaped, Trace, Tracer
from jax._src.ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p
from .. import linear_util as lu
from .._src.util import (unzip2, partial, safe_map, safe_zip, wrap_name, split_list,
                         canonicalize_axis, moveaxis, as_hashable_function, curry)
from . import xla
from . import partial_eval as pe

map = safe_map

BatchDim = Optional[int]
BatchDims = Sequence[BatchDim]

@lu.transformation
def batchfun(axis_name, axis_size, in_dims, main_type, *in_vals):
  # analogue of `jvpfun` in ad.py
  if axis_size is None:
    axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  in_dims = in_dims() if callable(in_dims) else in_dims
  in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
             and not isinstance(core.get_aval(x), core.AbstractUnit)
             else ax for x, ax in zip(in_vals, in_dims)]
  with core.new_main(main_type, axis_name=axis_name) as main:
    with core.extend_axis_env(axis_name, axis_size, main):
      out_vals = yield (main, in_dims, *in_vals), {}
      del main
  yield out_vals

@lu.transformation_with_aux
def batch_subtrace(main, in_dims, *in_vals):
  # analogue of `jvp_subtrace` in ad.py
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  yield out_vals, out_dims

@lu.transformation
def _match_axes(axis_size, axis_name, in_dims, out_dims_thunk, out_dim_dests,
                *in_vals):
  if axis_size is None:
    axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  out_vals = yield in_vals, {}
  out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
  out_dims = out_dims_thunk()
  for od, od_dest in zip(out_dims, out_dim_dests):
    if od is not None and not isinstance(od_dest, int):
      if not isinstance(axis_name, core._TempAxisName) and axis_name is not None:
        msg = f"vmap has mapped output (axis_name={axis_name}) but out_axes is {od_dest}"
      else:
        msg = f"vmap has mapped output but out_axes is {od_dest}"
      raise ValueError(msg)
  yield map(partial(matchaxis, axis_size), out_dims, out_dim_dests, out_vals)


### tracer

# TODO(mattjj): use a special sentinel type rather than None
NotMapped = type(None)
not_mapped = None
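
# For orientation (an illustrative note, derived from the tracer below): a
# BatchTracer pairs a batched value with the index of its batch axis. For
# example, with a batch of size 4 whose per-example result has shape (3,),
# the tracer carries `val` of shape (4, 3) with `batch_dim = 0`, and its
# `aval` reports the per-example shape (3,). An unbatched constant keeps
# `batch_dim = not_mapped`.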

class BatchTracer(Tracer):
  __slots__ = ['val', 'batch_dim']

  def __init__(self, trace, val, batch_dim: Optional[int]):
    if config.jax_enable_checks:
      assert type(batch_dim) in (int, NotMapped)
      if type(batch_dim) is int:
        aval = raise_to_shaped(core.get_aval(val))
        assert aval is core.abstract_unit or 0 <= batch_dim < len(aval.shape)  # type: ignore
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    aval = raise_to_shaped(core.get_aval(self.val))
    if self.batch_dim is not_mapped or aval is core.abstract_unit:
      return aval
    else:
      return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)

  def full_lower(self):
    if self.batch_dim is not_mapped:
      return core.full_lower(self.val)
    else:
      return self

class BatchTrace(Trace):
  def __init__(self, *args, axis_name):
    super().__init__(*args)
    self.axis_name = axis_name

  def pure(self, val):
    return BatchTracer(self, val, not_mapped)

  def lift(self, val):
    return BatchTracer(self, val, not_mapped)

  def sublift(self, val):
    return BatchTracer(self, val.val, val.batch_dim)

  def get_primitive_batcher(self, primitive):
    if primitive in initial_style_batchers:
      return partial(initial_style_batchers[primitive],
                     axis_name=self.axis_name,
                     main_type=self.main.trace_type)
    try:
      return primitive_batchers[primitive]
    except KeyError as err:
      msg = "Batching rule for '{}' not implemented"
      raise NotImplementedError(msg.format(primitive)) from err

  def process_primitive(self, primitive, tracers, params):
    vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    if (primitive in collective_rules and
        _main_trace_for_axis_names(self.main, core.used_axis_names(primitive, params))):
      frame = core.axis_frame(self.axis_name)
      val_out, dim_out = collective_rules[primitive](frame, vals_in, dims_in, **params)
    elif all(bdim is not_mapped for bdim in dims_in):
      return primitive.bind(*vals_in, **params)
    else:
      batched_primitive = self.get_primitive_batcher(primitive)
      val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
    if primitive.multiple_results:
      return map(partial(BatchTracer, self), val_out, dim_out)
    else:
      return BatchTracer(self, val_out, dim_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    assert call_primitive.multiple_results
    params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims):
      return call_primitive.bind(f, *vals, **params)
    else:
      f, dims_out = batch_subtrace(f, self.main, dims)
      vals_out = call_primitive.bind(f, *vals, **params)
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]

  def post_process_call(self, call_primitive, out_tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims)
    return vals, todo

  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(dim is not_mapped for dim in dims):
      return map_primitive.bind(f, *vals, **params)
    else:
      assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
      # The logic for the dimension math below is as follows:
      # ╔═════════════╦════════════════════════════════════════╦═══════════╗
      # ║ d / in_axis ║ None                                   ║ int       ║
      # ╠═════════════╬════════════════════════════════════════╩═══════════╣
      # ║ None        ║ No extra axis, so in_axis unaffected               ║
      # ╠═════════════╬════════════════════════════════════════╦═══════════╣
      # ║ int         ║ Not mapped, so batching dim unaffected ║ See below ║
      # ╚═════════════╩════════════════════════════════════════╩═══════════╝
      # When both d and in_axis are defined then:
      # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
      # - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed).
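      # A concrete instance (illustrative): if a value is batched at d = 0 and
      # the map primitive was asked to map its axis in_axis = 0, then in the
      # batched value the originally mapped axis now sits at position 1, so
      # new_in_axes below picks 1; inside the map each slice still carries the
      # batch dimension at 0, so new_dims keeps d = 0.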
      def both_mapped(in_out_axis, d):
        return in_out_axis is not None and d is not not_mapped
      new_in_axes = tuple(
        in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
        for d, in_axis in zip(dims, params['in_axes']))
      new_dims = tuple(
        d - 1 if both_mapped(in_axis, d) and in_axis < d else d
        for d, in_axis in zip(dims, params['in_axes']))
      f, dims_out = batch_subtrace(f, self.main, new_dims)
      out_axes_thunk = params['out_axes_thunk']
      # NOTE: This assumes that the choice of the dimensions over which outputs
      #       are batched is entirely dependent on the function and not e.g. on
      #       the data or its shapes.
      @as_hashable_function(closure=out_axes_thunk)
      def new_out_axes_thunk():
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes_thunk(), dims_out()))
      new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
      vals_out = map_primitive.bind(f, *vals, **new_params)
      dims_out = (d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
                  for d, out_axis in zip(dims_out(), out_axes_thunk()))
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]

  def post_process_map(self, call_primitive, out_tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def both_mapped(in_out_axis, d):
      return in_out_axis is not None and d is not not_mapped
    def todo(vals):
      trace = main.with_cur_sublevel()
      return [BatchTracer(trace, v, d + 1 if both_mapped(out_axis, d) and out_axis <= d else d)
              for v, d, out_axis in zip(vals, dims, params['out_axes_thunk']())]
    if call_primitive.map_primitive:
      def out_axes_transform(out_axes):
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes, dims))
      todo = (todo, out_axes_transform)
    return vals, todo

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
    out_vals = prim.bind(fun, jvp, *in_vals)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      assert out_dims == out_dims[:len(out_dims) // 2] * 2
      out_dims = out_dims[:len(out_dims) // 2]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]

  def post_process_custom_jvp_call(self, out_tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims)
    return vals, todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
                  if d is not not_mapped}
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
    bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
                               out_dims2, in_dims, self.main.trace_type)
    out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      out_dims = out_dims[-len(out_vals) % len(out_dims):]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]

  post_process_custom_vjp_call = post_process_custom_jvp_call

def _main_trace_for_axis_names(main_trace: core.MainTrace,
                               axis_name: Iterable[core.AxisName],
                               ) -> bool:
  # This function exists to identify whether a main trace corresponds to any of
  # the axis names used by a primitive. Axis names alone aren't enough because
  # axis names can shadow, so we use the main trace as a tag.
  return any(main_trace is core.axis_frame(n).main_trace for n in axis_name)

def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type):
  bwd, out_dims_thunk = batch_subtrace(bwd)
  return _match_axes_and_sum(batchfun(bwd, axis_name, axis_size, in_dims, main_type),
                             axis_size, out_dims_thunk, out_dim_dests)

@lu.transformation
def _match_axes_and_sum(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
  # this is like _match_axes, but we do reduce-sums as needed
  out_vals = yield in_vals, {}
  yield map(partial(matchaxis, axis_size, sum_match=True),
            out_dims_thunk(), out_dim_dests, out_vals)

### API

AxesSpec = Union[Callable[[], BatchDims], BatchDims]

def batch(fun: lu.WrappedFun,
          axis_name: core.AxisName,
          axis_size: Optional[int],
          in_dims: AxesSpec,
          out_dim_dests: AxesSpec,
          main_type: Type[BatchTrace] = BatchTrace,
          ) -> lu.WrappedFun:
  # analogue of `jvp` in ad.py
  # TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
  fun, out_dims_thunk = batch_subtrace(fun)
  return _match_axes(
      batchfun(fun, axis_name, axis_size, in_dims, main_type), axis_size,
      axis_name, in_dims, out_dims_thunk, out_dim_dests)
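
# Rough usage sketch (illustrative only; `jax.vmap` is the supported entry
# point, and the names below are made up for the example): `batch` expects an
# already-flattened function, wrapped with `lu.wrap_init`, that returns a
# sequence of outputs:
#
#   flat_fun = lu.wrap_init(lambda x, y: (x * y,))
#   batched = batch(flat_fun, axis_name, axis_size=None,
#                   in_dims=(0, 0), out_dim_dests=(0,))
#   out, = batched.call_wrapped(xs, ys)   # xs, ys share a leading batch axis
#
# `in_dims` and `out_dim_dests` may also be thunks, per the AxesSpec alias
# above; `axis_size=None` is inferred from the mapped inputs.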

# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat: lu.WrappedFun,
          in_axes_flat: Tuple[Optional[int], ...],
          out_axes_flat: Tuple[Optional[int], ...],
          tile_size: Optional[int],
          axis_name: core.AxisName,
          main_type: Type[BatchTrace] = BatchTrace):
  @curry
  def tile_axis(arg, axis: Optional[int], tile_size):
    if axis is None:
      return arg
    shape = list(arg.shape)
    shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
    return arg.reshape(shape)

  def untile_axis(out, axis: Optional[int]):
    if axis is None:
      return out
    shape = list(out.shape)
    shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
    return out.reshape(shape)

  @lu.transformation
  def _map_to_tile(*args_flat):
    sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
    tile_size_ = tile_size or next(sizes, None)
    assert tile_size_ is not None, "No mapped arguments?"
    outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
    yield map(untile_axis, outputs_flat, out_axes_flat)

  return _map_to_tile(batch(
      f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat, main_type=main_type))
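
# Worked example of the reshaping above (illustrative): with tile_size=2, an
# argument of shape (8, 3) mapped at axis 0 is reshaped by tile_axis to
# (2, 4, 3); `batch` then maps over the new leading axis of size 2, and
# untile_axis collapses the first two output axes back into one axis of 8.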

### primitives

BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
initial_style_batchers : Dict[core.Primitive, Any] = {}
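
# For reference, a rule registered in `primitive_batchers` has this shape
# (a schematic sketch; `my_prim` and `my_prim_batcher` are hypothetical):
#
#   def my_prim_batcher(batched_args, batch_dims, **params):
#     x, = batched_args        # values with their batch axes still present
#     bdim, = batch_dims       # int batch axis, or `not_mapped`
#     return my_prim.bind(x, **params), bdim
#   primitive_batchers[my_prim] = my_prim_batcher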

def defvectorized(prim):
  primitive_batchers[prim] = partial(vectorized_batcher, prim)

def vectorized_batcher(prim, batched_args, batch_dims, **params):
  assert all(batch_dims[0] == bd for bd in batch_dims[1:]), batch_dims
  return prim.bind(*batched_args, **params), batch_dims[0]

def defbroadcasting(prim):
  primitive_batchers[prim] = partial(broadcast_batcher, prim)

def broadcast_batcher(prim, args, dims, **params):
  """Process a primitive with built-in broadcasting.

  Args:
    args: the possibly-batched arguments
    dims: list or tuple of the same length as `args`, where each
      entry indicates the batching state of the corresponding entry to `args`:
      either an int indicating the batch dimension, or else `not_mapped`
      indicating no batching.
  """
  shapes = {(x.shape, d) for x, d in zip(args, dims) if np.ndim(x)}
  if len(shapes) == 1:
    # if there's only agreeing batch dims and scalars, just call the primitive
    d = next(d for d in dims if d is not not_mapped)
    out = prim.bind(*args, **params)
    return (out, (d,) * len(out)) if prim.multiple_results else (out, d)
  else:
    size, = {shape[d] for shape, d in shapes if d is not not_mapped}
    args = [bdim_at_front(x, d, size) if np.ndim(x) else x
            for x, d in zip(args, dims)]
    ndim = max(np.ndim(x) for x in args)  # special-case scalar broadcasting
    args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
    out = prim.bind(*args, **params)
    return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)
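
# Example of the two paths above (illustrative): for a binary op with one
# argument of shape (5, 3) batched at dim 0 and a scalar second argument, the
# `shapes` set has a single entry, so the primitive is bound directly and the
# output stays batched at dim 0. If instead the second argument has shape
# (3, 5) with its batch dim at 1 (same batch size 5), both batch dims are
# first moved to the front by bdim_at_front, giving two (5, 3) operands, and
# the result is reported as batched at dim 0.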

def _handle_scalar_broadcasting(nd, x, d):
  if d is not_mapped or nd == np.ndim(x):
    return x
  else:
    return x.reshape(x.shape + (1,) * (nd - np.ndim(x)))

def defreducer(prim):
  primitive_batchers[prim] = partial(reducer_batcher, prim)

def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
  operand, = batched_args
  bdim, = batch_dims
  axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
  bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
  if 'input_shape' in params:
    params = dict(params, input_shape=operand.shape)
  return prim.bind(operand, axes=axes, **params), bdim_out
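
# Worked example for the index arithmetic above (illustrative): reducing over
# axes=(1,) with the operand batched at bdim=0 shifts the reduced axis to
# axes=(2,) in the batched operand; deleting axis 2 from a rank-3 operand
# leaves axes [0, 1], where the batch dim 0 is found at position 0, so the
# result stays batched at dim 0.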

# sets up primitive batchers for ad_util and xla primitives

def add_batched(batched_args, batch_dims):
  bdx, bdy = batch_dims
  x, y = batched_args
  if bdx == bdy or core.get_aval(x) == core.abstract_unit:
    return add_jaxvals(x, y), bdx
  elif bdx is not_mapped:
    x = broadcast(x, y.shape[bdy], bdy)
    return add_jaxvals(x, y), bdy
  elif bdy is not_mapped:
    y = broadcast(y, x.shape[bdx], bdx)
    return add_jaxvals(x, y), bdx
  else:
    x = moveaxis(x, bdx, bdy)
    return add_jaxvals(x, y), bdy
primitive_batchers[add_jaxvals_p] = add_batched

def zeros_like_batched(batched_args, batch_dims):
  val, = batched_args
  bdim, = batch_dims
  return zeros_like_jaxval(val), bdim
primitive_batchers[zeros_like_p] = zeros_like_batched

defvectorized(xla.device_put_p)

### util

def broadcast(x, sz, axis):
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  shape = list(np.shape(x))
  shape.insert(axis, sz)
  broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)

def matchaxis(sz, src, dst, x, sum_match=False):
  try:
    aval = core.get_aval(x)
  except TypeError as e:
    raise TypeError(f"Output from batched function {repr(x)} with type "
                    f"{type(x)} is not a valid JAX type") from e
  if aval is core.abstract_unit:
    return core.unit
  if src == dst:
    return x
  elif type(src) == type(dst) == int:
    return moveaxis(x, src, dst)
  elif src is not_mapped and dst is not not_mapped:
    return broadcast(
        x, sz, canonicalize_axis(dst, np.ndim(x) + 1))
  elif dst is None and sum_match:
    return x.sum(src)
  else:
    raise ValueError((src, dst))
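
# For example (illustrative): matchaxis(sz, 2, 0, x) moves x's batch axis from
# position 2 to position 0; matchaxis(sz, not_mapped, 0, x) broadcasts an
# unbatched x to carry a new leading axis of size sz; and with sum_match=True
# a mapped source with dst=None is reduced away with a sum, which is the
# behaviour _match_axes_and_sum relies on for custom_vjp backward passes.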

def bdim_at_front(x, bdim, size):
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if bdim is not_mapped:
    return broadcast(x, size, 0)
  else:
    return moveaxis(x, bdim, 0)


def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, main_type):
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
  f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched], main_type)
  avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
              for aval, b in zip(closed_jaxpr.in_avals, in_batched)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()

@lu.transformation_with_aux
def batch_subtrace_instantiate(instantiate, axis_size, main, in_dims, *in_vals):
  # this is like `batch_subtrace` but we take an extra `instantiate` arg
  # analogue of `jvp_subtrace` in ad.py
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)

  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_vals)
  out_vals = [moveaxis(x, d, 0) if d is not not_mapped and d != 0
              else broadcast(x, axis_size, 0) if d is not_mapped and inst else x
              for x, d, inst in zip(out_vals, out_dims, instantiate)]
  out_batched = [d is not not_mapped or inst
                 for d, inst in zip(out_dims, instantiate)]
  yield out_vals, out_batched

@lu.transformation_with_aux
def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
  size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims * 2)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
  out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
  out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
  out_primals = map(partial(matchaxis, size), out_primal_bds, out_dims, out_primals)
  out_tangents = map(partial(matchaxis, size), out_tangent_bds, out_dims, out_tangents)
  yield out_primals + out_tangents, out_dims * 2

def _merge_bdims(x, y):
  if x == y:
    return x
  elif x is not_mapped:
    return y
  elif y is not_mapped:
    return x
  else:
    return x  # arbitrary


collective_rules: Dict[core.Primitive, Callable] = {}