Move vtile to batching.py, make it possible to add new BatchTraces

No substantial behavior change right now, but the ability to use
subclasses of BatchTrace comes in handy when adding support for
nesting xmaps in the SPMD lowering.

PiperOrigin-RevId: 369445693
This commit is contained in:
Adam Paszke 2021-04-20 08:32:41 -07:00 committed by jax authors
parent 93c63d0341
commit c09037bd14
5 changed files with 101 additions and 83 deletions

View File

@ -324,7 +324,7 @@ def _custom_jvp_call_jaxpr_jvp(
ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp
def _custom_jvp_call_jaxpr_vmap( def _custom_jvp_call_jaxpr_vmap(
args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr, args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int): num_consts: int):
size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped} size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
@ -333,7 +333,8 @@ def _custom_jvp_call_jaxpr_vmap(
num_out = len(fun_jaxpr.out_avals) num_out = len(fun_jaxpr.out_avals)
in_batched = [d is not not_mapped for d in in_dims] in_batched = [d is not not_mapped for d in in_dims]
batched_fun_jaxpr, out_batched = batching.batch_jaxpr(fun_jaxpr, size, in_batched, False, axis_name) batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
fun_jaxpr, size, in_batched, False, axis_name, main_type)
out_dims1 = [0 if b else not_mapped for b in out_batched] out_dims1 = [0 if b else not_mapped for b in out_batched]
out_dims2 = [] # mutable cell updated by batched_jvp_jaxpr_thunk out_dims2 = [] # mutable cell updated by batched_jvp_jaxpr_thunk
@ -341,12 +342,14 @@ def _custom_jvp_call_jaxpr_vmap(
def batched_jvp_jaxpr_thunk(): def batched_jvp_jaxpr_thunk():
jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk()) # consts can be tracers jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk()) # consts can be tracers
_, args_batched = split_list(in_batched, [num_consts]) _, args_batched = split_list(in_batched, [num_consts])
_, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False, axis_name) _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False,
axis_name, main_type)
primals_batched, tangents_batched = split_list(all_batched, [num_out]) primals_batched, tangents_batched = split_list(all_batched, [num_out])
out_batched = map(op.or_, primals_batched, tangents_batched) out_batched = map(op.or_, primals_batched, tangents_batched)
out_dims2.append([0 if b else not_mapped for b in out_batched]) out_dims2.append([0 if b else not_mapped for b in out_batched])
batched_jvp_jaxpr, _ = batching.batch_jaxpr( batched_jvp_jaxpr, _ = batching.batch_jaxpr(
jvp_jaxpr, size, args_batched * 2, out_batched * 2, axis_name) jvp_jaxpr, size, args_batched * 2, out_batched * 2,
axis_name, main_type)
return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts
batched_outs = custom_jvp_call_jaxpr_p.bind( batched_outs = custom_jvp_call_jaxpr_p.bind(
@ -610,7 +613,7 @@ def _custom_vjp_call_jaxpr_jvp(
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
def _custom_vjp_call_jaxpr_vmap( def _custom_vjp_call_jaxpr_vmap(
args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr, args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]], fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
bwd: lu.WrappedFun, out_trees: Callable, num_consts: int): bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
axis_size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped} axis_size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
@ -620,7 +623,7 @@ def _custom_vjp_call_jaxpr_vmap(
in_batched = [d is not not_mapped for d in in_dims] in_batched = [d is not not_mapped for d in in_dims]
_, args_batched = split_list(in_batched, [num_consts]) _, args_batched = split_list(in_batched, [num_consts])
batched_fun_jaxpr, out_batched = batching.batch_jaxpr( batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
fun_jaxpr, axis_size, in_batched, False, axis_name) fun_jaxpr, axis_size, in_batched, False, axis_name, main_type)
out_dims1 = [0 if b else not_mapped for b in out_batched] out_dims1 = [0 if b else not_mapped for b in out_batched]
out_dims2 = [] out_dims2 = []
@ -628,14 +631,14 @@ def _custom_vjp_call_jaxpr_vmap(
def batched_fwd_jaxpr_thunk(): def batched_fwd_jaxpr_thunk():
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
fwd_jaxpr, axis_size, args_batched, False, axis_name) fwd_jaxpr, axis_size, args_batched, False, axis_name, main_type)
out_dims2.append([0 if b else not_mapped for b in out_batched]) out_dims2.append([0 if b else not_mapped for b in out_batched])
return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts
fwd_args_batched = [0 if b else not_mapped for b in args_batched] fwd_args_batched = [0 if b else not_mapped for b in args_batched]
fwd_out_dims = lambda: out_dims2[0] fwd_out_dims = lambda: out_dims2[0]
batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims, batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims,
fwd_args_batched) fwd_args_batched, main_type)
batched_outs = custom_vjp_call_jaxpr_p.bind( batched_outs = custom_vjp_call_jaxpr_p.bind(
*args, fun_jaxpr=batched_fun_jaxpr, *args, fun_jaxpr=batched_fun_jaxpr,

View File

@ -367,7 +367,7 @@ def _pred_bcast_select(c, pred, x, y, x_y_aval: core.AbstractValue):
bcast_pred = xops.BroadcastInDim(pred, x_shape, list(range(len(pred_shape)))) bcast_pred = xops.BroadcastInDim(pred, x_shape, list(range(len(pred_shape))))
return xops.Select(bcast_pred, x, y) return xops.Select(bcast_pred, x, y)
def _while_loop_batching_rule(args, dims, axis_name, def _while_loop_batching_rule(args, dims, axis_name, main_type,
cond_nconsts, cond_jaxpr, cond_nconsts, cond_jaxpr,
body_nconsts, body_jaxpr): body_nconsts, body_jaxpr):
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped} size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
@ -383,11 +383,12 @@ def _while_loop_batching_rule(args, dims, axis_name,
for _ in range(1 + len(carry_bat)): for _ in range(1 + len(carry_bat)):
batched = bconst_bat + carry_bat batched = bconst_bat + carry_bat
body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr( body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
body_jaxpr, size, batched, instantiate=carry_bat, axis_name=axis_name) body_jaxpr, size, batched, instantiate=carry_bat,
axis_name=axis_name, main_type=main_type)
cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr( cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
cond_jaxpr, size, cconst_bat + carry_bat, cond_jaxpr, size, cconst_bat + carry_bat,
instantiate=bool(cond_jaxpr.out_avals[0].shape), instantiate=bool(cond_jaxpr.out_avals[0].shape),
axis_name=axis_name) axis_name=axis_name, main_type=main_type)
carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out) carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
if carry_bat_out == carry_bat: if carry_bat_out == carry_bat:
break break
@ -772,7 +773,7 @@ def _bcast_select(pred, on_true, on_false):
pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx) pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx)
return lax.select(pred, on_true, on_false) return lax.select(pred, on_true, on_false)
def _cond_batching_rule(args, dims, axis_name, branches, linear): def _cond_batching_rule(args, dims, axis_name, main_type, branches, linear):
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped} size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
index, *ops = args index, *ops = args
index_dim, *op_dims = dims index_dim, *op_dims = dims
@ -786,7 +787,7 @@ def _cond_batching_rule(args, dims, axis_name, branches, linear):
index, *ops = [batching.bdim_at_front(x, d, size) for x, d in zip(args, dims)] index, *ops = [batching.bdim_at_front(x, d, size) for x, d in zip(args, dims)]
branches_batched = [ branches_batched = [
batching.batch_jaxpr(jaxpr, size, [True] * len(ops), True, axis_name)[0] batching.batch_jaxpr(jaxpr, size, [True] * len(ops), True, axis_name, main_type)[0]
for jaxpr in branches] for jaxpr in branches]
branch_outs = [] branch_outs = []
@ -806,11 +807,11 @@ def _cond_batching_rule(args, dims, axis_name, branches, linear):
for b, x, d in zip(ops_bat, ops, op_dims)] for b, x, d in zip(ops_bat, ops, op_dims)]
branches_out_bat = [ branches_out_bat = [
batching.batch_jaxpr(jaxpr, size, ops_bat, False, axis_name)[1] batching.batch_jaxpr(jaxpr, size, ops_bat, False, axis_name, main_type)[1]
for jaxpr in branches] for jaxpr in branches]
out_bat = [any(bat) for bat in zip(*branches_out_bat)] out_bat = [any(bat) for bat in zip(*branches_out_bat)]
branches_batched = tuple( branches_batched = tuple(
batching.batch_jaxpr(jaxpr, size, ops_bat, out_bat, axis_name)[0] batching.batch_jaxpr(jaxpr, size, ops_bat, out_bat, axis_name, main_type)[0]
for jaxpr in branches) for jaxpr in branches)
out_dims = [0 if b else batching.not_mapped for b in out_bat] out_dims = [0 if b else batching.not_mapped for b in out_bat]
@ -1731,7 +1732,7 @@ def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.Abstrac
return core.ClosedJaxpr(jaxpr, consts) return core.ClosedJaxpr(jaxpr, consts)
def _scan_batching_rule(args, dims, axis_name, reverse, length, jaxpr, num_consts, def _scan_batching_rule(args, dims, axis_name, main_type, reverse, length, jaxpr, num_consts,
num_carry, linear, unroll): num_carry, linear, unroll):
num_ys = len(jaxpr.out_avals) - num_carry num_ys = len(jaxpr.out_avals) - num_carry
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped} size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
@ -1749,7 +1750,8 @@ def _scan_batching_rule(args, dims, axis_name, reverse, length, jaxpr, num_const
jaxpr_batched, batched_out = batching.batch_jaxpr( jaxpr_batched, batched_out = batching.batch_jaxpr(
jaxpr, size, batched, jaxpr, size, batched,
instantiate=carry_batched + [False] * num_ys, instantiate=carry_batched + [False] * num_ys,
axis_name=axis_name) axis_name=axis_name,
main_type=main_type)
carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:] carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
if carry_batched_out == carry_batched: if carry_batched_out == carry_batched:
break break
@ -2275,7 +2277,7 @@ def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
return [None] * sum(const_lengths) + cotangent_b return [None] * sum(const_lengths) + cotangent_b
def _linear_solve_batching_rule(args, dims, axis_name, const_lengths, jaxprs): def _linear_solve_batching_rule(args, dims, axis_name, main_type, const_lengths, jaxprs):
orig_bat = [d is not batching.not_mapped for d in dims] orig_bat = [d is not batching.not_mapped for d in dims]
size, = { size, = {
a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped
@ -2295,23 +2297,27 @@ def _linear_solve_batching_rule(args, dims, axis_name, const_lengths, jaxprs):
for i in range(1 + len(orig_b_bat) + len(solve.out_avals)): for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
# Apply vecmat and solve -> new batched parts of x # Apply vecmat and solve -> new batched parts of x
solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr( solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
solve, size, solve_bat + b_bat, instantiate=x_bat, axis_name=axis_name) solve, size, solve_bat + b_bat, instantiate=x_bat,
axis_name=axis_name, main_type=main_type)
if vecmat is None: if vecmat is None:
vecmat_jaxpr_batched = None vecmat_jaxpr_batched = None
x_bat_out = solve_x_bat x_bat_out = solve_x_bat
else: else:
vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr( vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
vecmat, size, vecmat_bat + b_bat, instantiate=x_bat, axis_name=axis_name) vecmat, size, vecmat_bat + b_bat, instantiate=x_bat,
axis_name=axis_name, main_type=main_type)
x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat) x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat)
# Apply matvec and solve_t -> new batched parts of b # Apply matvec and solve_t -> new batched parts of b
matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr( matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
matvec, size, matvec_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name) matvec, size, matvec_bat + x_bat_out, instantiate=b_bat,
axis_name=axis_name, main_type=main_type)
if solve_t is None: if solve_t is None:
solve_t_jaxpr_batched = None solve_t_jaxpr_batched = None
b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat) b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
else: else:
solve_t_jaxpr_batched, solve_t_b_bat = batching.batch_jaxpr( solve_t_jaxpr_batched, solve_t_b_bat = batching.batch_jaxpr(
solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name) solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat,
axis_name=axis_name, main_type=main_type)
b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat, b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
orig_b_bat) orig_b_bat)
if x_bat_out == x_bat and b_bat_out == b_bat: if x_bat_out == x_bat and b_bat_out == b_bat:

View File

@ -652,7 +652,7 @@ class EvaluationPlan(NamedTuple):
vaxis = raxes[-1] vaxis = raxes[-1]
map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes)) map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes)) map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
f = pxla.vtile(f, map_in_axes, map_out_axes, tile_size=local_tile_size, axis_name=vaxis) f = batching.vtile(f, map_in_axes, map_out_axes, tile_size=local_tile_size, axis_name=vaxis)
return f return f
def to_mesh_axes(self, in_axes, out_axes): def to_mesh_axes(self, in_axes, out_axes):

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from typing import Any, Callable, Dict, Optional, Tuple, Union, Sequence, Iterable from typing import Any, Callable, Dict, Optional, Tuple, Union, Sequence, Iterable, Type
import jax import jax
from ..config import config from ..config import config
@ -21,8 +21,8 @@ from .. import core
from ..core import raise_to_shaped, Trace, Tracer from ..core import raise_to_shaped, Trace, Tracer
from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p
from .. import linear_util as lu from .. import linear_util as lu
from .._src.util import (unzip2, partial, safe_map, wrap_name, split_list, from .._src.util import (unzip2, partial, safe_map, safe_zip, wrap_name, split_list,
canonicalize_axis, moveaxis, as_hashable_function) canonicalize_axis, moveaxis, as_hashable_function, curry)
from . import xla from . import xla
from . import partial_eval as pe from . import partial_eval as pe
@ -30,19 +30,8 @@ map = safe_map
BatchDim = Optional[int] BatchDim = Optional[int]
BatchDims = Sequence[BatchDim] BatchDims = Sequence[BatchDim]
AxesSpec = Union[Callable[[], BatchDims], BatchDims]
def batch(fun: lu.WrappedFun, axis_name: core.AxisName,
axis_size: Optional[int], in_dims: AxesSpec, out_dim_dests: AxesSpec,
) -> lu.WrappedFun:
# anlogue of `jvp` in ad.py
# TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
fun, out_dims_thunk = batch_subtrace(fun)
return _match_axes(batchfun(fun, axis_name, axis_size, in_dims),
axis_size, in_dims, out_dims_thunk, out_dim_dests)
@lu.transformation @lu.transformation
def batchfun(axis_name, axis_size, in_dims, *in_vals): def batchfun(axis_name, axis_size, in_dims, main_type, *in_vals):
# analogue of `jvpfun` in ad.py # analogue of `jvpfun` in ad.py
if axis_size is None: if axis_size is None:
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped} axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
@ -50,7 +39,7 @@ def batchfun(axis_name, axis_size, in_dims, *in_vals):
in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int) in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
and not isinstance(core.get_aval(x), core.AbstractUnit) and not isinstance(core.get_aval(x), core.AbstractUnit)
else ax for x, ax in zip(in_vals, in_dims)] else ax for x, ax in zip(in_vals, in_dims)]
with core.new_main(BatchTrace, axis_name=axis_name) as main: with core.new_main(main_type, axis_name=axis_name) as main:
with core.extend_axis_env(axis_name, axis_size, main): with core.extend_axis_env(axis_name, axis_size, main):
out_vals = yield (main, in_dims, *in_vals), {} out_vals = yield (main, in_dims, *in_vals), {}
del main del main
@ -137,7 +126,7 @@ class BatchTrace(Trace):
frame = core.axis_frame(self.axis_name) frame = core.axis_frame(self.axis_name)
val_out, dim_out = collective_rules[primitive](frame, vals_in, dims_in, **params) val_out, dim_out = collective_rules[primitive](frame, vals_in, dims_in, **params)
else: else:
batched_primitive = get_primitive_batcher(primitive, self.axis_name) batched_primitive = get_primitive_batcher(primitive, self)
val_out, dim_out = batched_primitive(vals_in, dims_in, **params) val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
if primitive.multiple_results: if primitive.multiple_results:
return map(partial(BatchTracer, self), val_out, dim_out) return map(partial(BatchTracer, self), val_out, dim_out)
@ -245,7 +234,7 @@ class BatchTrace(Trace):
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims) fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims) fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size, bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
out_dims2, in_dims) out_dims2, in_dims, self.main.trace_type)
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees) out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
if not fst: if not fst:
@ -262,9 +251,9 @@ def _main_trace_for_axis_names(main_trace: core.MainTrace,
# axis names can shadow, so we use the main trace as a tag. # axis names can shadow, so we use the main trace as a tag.
return any(main_trace is core.axis_frame(n).main_trace for n in axis_name) return any(main_trace is core.axis_frame(n).main_trace for n in axis_name)
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests): def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type):
bwd, out_dims_thunk = batch_subtrace(bwd) bwd, out_dims_thunk = batch_subtrace(bwd)
return _match_axes_and_sum(batchfun(bwd, axis_name, axis_size, in_dims), return _match_axes_and_sum(batchfun(bwd, axis_name, axis_size, in_dims, main_type),
axis_size, out_dims_thunk, out_dim_dests) axis_size, out_dims_thunk, out_dim_dests)
@lu.transformation @lu.transformation
@ -274,6 +263,55 @@ def _match_axes_and_sum(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
yield map(partial(matchaxis, axis_size, sum_match=True), yield map(partial(matchaxis, axis_size, sum_match=True),
out_dims_thunk(), out_dim_dests, out_vals) out_dims_thunk(), out_dim_dests, out_vals)
### API
AxesSpec = Union[Callable[[], BatchDims], BatchDims]
def batch(fun: lu.WrappedFun,
axis_name: core.AxisName,
axis_size: Optional[int],
in_dims: AxesSpec,
out_dim_dests: AxesSpec,
main_type: Type[BatchTrace] = BatchTrace,
) -> lu.WrappedFun:
# anlogue of `jvp` in ad.py
# TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
fun, out_dims_thunk = batch_subtrace(fun)
return _match_axes(batchfun(fun, axis_name, axis_size, in_dims, main_type),
axis_size, in_dims, out_dims_thunk, out_dim_dests)
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat: lu.WrappedFun,
in_axes_flat: Tuple[Optional[int], ...],
out_axes_flat: Tuple[Optional[int], ...],
tile_size: Optional[int],
axis_name: core.AxisName,
main_type: Type[BatchTrace] = BatchTrace):
@curry
def tile_axis(arg, axis: Optional[int], tile_size):
if axis is None:
return arg
shape = list(arg.shape)
shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
return arg.reshape(shape)
def untile_axis(out, axis: Optional[int]):
if axis is None:
return out
shape = list(out.shape)
shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
return out.reshape(shape)
@lu.transformation
def _map_to_tile(*args_flat):
sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
tile_size_ = tile_size or next(sizes, None)
assert tile_size_ is not None, "No mapped arguments?"
outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
yield map(untile_axis, outputs_flat, out_axes_flat)
return _map_to_tile(batch(
f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat, main_type=main_type))
### primitives ### primitives
@ -281,9 +319,11 @@ BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
primitive_batchers : Dict[core.Primitive, BatchingRule] = {} primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
initial_style_batchers : Dict[core.Primitive, Any] = {} initial_style_batchers : Dict[core.Primitive, Any] = {}
def get_primitive_batcher(p, axis_name): def get_primitive_batcher(p, trace):
if p in initial_style_batchers: if p in initial_style_batchers:
return partial(initial_style_batchers[p], axis_name=axis_name) return partial(initial_style_batchers[p],
axis_name=trace.axis_name,
main_type=trace.main.trace_type)
try: try:
return primitive_batchers[p] return primitive_batchers[p]
except KeyError as err: except KeyError as err:
@ -408,10 +448,10 @@ def bdim_at_front(x, bdim, size):
return moveaxis(x, bdim, 0) return moveaxis(x, bdim, 0)
def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name): def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, main_type):
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr)) f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size) f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched]) f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched], main_type)
avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
for aval, b in zip(closed_jaxpr.in_avals, in_batched)] for aval, b in zip(closed_jaxpr.in_avals, in_batched)]
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in) jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)

View File

@ -47,7 +47,7 @@ from ..abstract_arrays import array_types
from ..core import ConcreteArray, ShapedArray from ..core import ConcreteArray, ShapedArray
from .._src.util import (partial, unzip3, prod, safe_map, safe_zip, from .._src.util import (partial, unzip3, prod, safe_map, safe_zip,
extend_name_stack, wrap_name, assert_unreachable, extend_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, curry) tuple_insert, tuple_delete)
from ..lib import xla_bridge as xb from ..lib import xla_bridge as xb
from ..lib import xla_client as xc from ..lib import xla_client as xc
from ..lib import pmap_lib from ..lib import pmap_lib
@ -1381,10 +1381,10 @@ def mesh_callable(fun: lu.WrappedFun,
# TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice! # TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
for name, size in reversed(mesh.shape.items()): for name, size in reversed(mesh.shape.items()):
if tile_by_mesh_axes: if tile_by_mesh_axes:
fun = vtile(fun, fun = batching.vtile(fun,
tuple(a.get(name, None) for a in in_axes), tuple(a.get(name, None) for a in in_axes),
tuple(a.get(name, None) for a in out_axes_thunk()), tuple(a.get(name, None) for a in out_axes_thunk()),
tile_size=size, axis_name=name) tile_size=size, axis_name=name)
global_in_untiled_avals = [untile_aval_nd(global_axis_sizes, aval_in_axes, aval) global_in_untiled_avals = [untile_aval_nd(global_axis_sizes, aval_in_axes, aval)
for aval, aval_in_axes in safe_zip(in_tiled_avals, in_axes)] for aval, aval_in_axes in safe_zip(in_tiled_avals, in_axes)]
in_jaxpr_avals = global_in_untiled_avals in_jaxpr_avals = global_in_untiled_avals
@ -1500,37 +1500,6 @@ def compile_and_wrap_mesh_hlo(computation: xc.XlaComputation, backend,
handle_outs) handle_outs)
return partial(execute_replicated, compiled, backend, handle_args, handle_outs) return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat,
in_axes_flat: Tuple[Optional[int], ...],
out_axes_flat: Tuple[Optional[int], ...],
tile_size: Optional[int], axis_name):
@curry
def tile_axis(arg, axis: Optional[int], tile_size):
if axis is None:
return arg
shape = list(arg.shape)
shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
return arg.reshape(shape)
def untile_axis(out, axis: Optional[int]):
if axis is None:
return out
shape = list(out.shape)
shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
return out.reshape(shape)
@lu.transformation
def _map_to_tile(*args_flat):
sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
tile_size_ = tile_size or next(sizes, None)
assert tile_size_ is not None, "No mapped arguments?"
outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
yield map(untile_axis, outputs_flat, out_axes_flat)
return _map_to_tile(
batching.batch(f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat))
_forbidden_primitives = { _forbidden_primitives = {
'xla_pmap': 'pmap', 'xla_pmap': 'pmap',
'sharded_call': 'sharded_jit', 'sharded_call': 'sharded_jit',