Move vtile to batching.py, make it possible to add new BatchTraces

No substantial behavior change right now, but the ability to use
subclasses of BatchTrace comes in handy when adding support for
nesting xmaps in the SPMD lowering.

PiperOrigin-RevId: 369445693
commit c09037bd14 (parent 93c63d0341)
Author: Adam Paszke (committed by jax authors)
Date:   2021-04-20 08:32:41 -07:00
5 changed files with 101 additions and 83 deletions
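
For illustration, a rough, untested sketch of what the new main_type hook allows once this lands (not code from this commit; SPMDBatchTrace and batch_with_custom_trace are hypothetical names used only to show the plumbing):

# Hypothetical: substitute a BatchTrace subclass for the whole batching trace.
from jax import linear_util as lu
from jax.interpreters import batching

class SPMDBatchTrace(batching.BatchTrace):
  pass  # overrides would go here; with none, behavior matches plain BatchTrace

def batch_with_custom_trace(fun, axis_name, axis_size, in_dims, out_dims):
  # batching.batch now accepts main_type (default BatchTrace) and forwards it to
  # core.new_main inside batchfun, so every tracer created during the transform
  # belongs to a trace of the requested subclass.
  return batching.batch(lu.wrap_init(fun), axis_name, axis_size,
                        in_dims, out_dims, main_type=SPMDBatchTrace)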


@@ -324,7 +324,7 @@ def _custom_jvp_call_jaxpr_jvp(
ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp
def _custom_jvp_call_jaxpr_vmap(
args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr,
args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int):
size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
@@ -333,7 +333,8 @@ def _custom_jvp_call_jaxpr_vmap(
num_out = len(fun_jaxpr.out_avals)
in_batched = [d is not not_mapped for d in in_dims]
batched_fun_jaxpr, out_batched = batching.batch_jaxpr(fun_jaxpr, size, in_batched, False, axis_name)
batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
fun_jaxpr, size, in_batched, False, axis_name, main_type)
out_dims1 = [0 if b else not_mapped for b in out_batched]
out_dims2 = [] # mutable cell updated by batched_jvp_jaxpr_thunk
@@ -341,12 +342,14 @@ def _custom_jvp_call_jaxpr_vmap(
def batched_jvp_jaxpr_thunk():
jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk()) # consts can be tracers
_, args_batched = split_list(in_batched, [num_consts])
_, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False, axis_name)
_, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False,
axis_name, main_type)
primals_batched, tangents_batched = split_list(all_batched, [num_out])
out_batched = map(op.or_, primals_batched, tangents_batched)
out_dims2.append([0 if b else not_mapped for b in out_batched])
batched_jvp_jaxpr, _ = batching.batch_jaxpr(
jvp_jaxpr, size, args_batched * 2, out_batched * 2, axis_name)
jvp_jaxpr, size, args_batched * 2, out_batched * 2,
axis_name, main_type)
return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts
batched_outs = custom_jvp_call_jaxpr_p.bind(
@@ -610,7 +613,7 @@ def _custom_vjp_call_jaxpr_jvp(
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
def _custom_vjp_call_jaxpr_vmap(
args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr,
args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
axis_size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
@@ -620,7 +623,7 @@ def _custom_vjp_call_jaxpr_vmap(
in_batched = [d is not not_mapped for d in in_dims]
_, args_batched = split_list(in_batched, [num_consts])
batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
fun_jaxpr, axis_size, in_batched, False, axis_name)
fun_jaxpr, axis_size, in_batched, False, axis_name, main_type)
out_dims1 = [0 if b else not_mapped for b in out_batched]
out_dims2 = []
@@ -628,14 +631,14 @@ def _custom_vjp_call_jaxpr_vmap(
def batched_fwd_jaxpr_thunk():
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
fwd_jaxpr, axis_size, args_batched, False, axis_name)
fwd_jaxpr, axis_size, args_batched, False, axis_name, main_type)
out_dims2.append([0 if b else not_mapped for b in out_batched])
return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts
fwd_args_batched = [0 if b else not_mapped for b in args_batched]
fwd_out_dims = lambda: out_dims2[0]
batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims,
fwd_args_batched)
fwd_args_batched, main_type)
batched_outs = custom_vjp_call_jaxpr_p.bind(
*args, fun_jaxpr=batched_fun_jaxpr,


@@ -367,7 +367,7 @@ def _pred_bcast_select(c, pred, x, y, x_y_aval: core.AbstractValue):
bcast_pred = xops.BroadcastInDim(pred, x_shape, list(range(len(pred_shape))))
return xops.Select(bcast_pred, x, y)
def _while_loop_batching_rule(args, dims, axis_name,
def _while_loop_batching_rule(args, dims, axis_name, main_type,
cond_nconsts, cond_jaxpr,
body_nconsts, body_jaxpr):
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
@@ -383,11 +383,12 @@ def _while_loop_batching_rule(args, dims, axis_name,
for _ in range(1 + len(carry_bat)):
batched = bconst_bat + carry_bat
body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
body_jaxpr, size, batched, instantiate=carry_bat, axis_name=axis_name)
body_jaxpr, size, batched, instantiate=carry_bat,
axis_name=axis_name, main_type=main_type)
cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
cond_jaxpr, size, cconst_bat + carry_bat,
instantiate=bool(cond_jaxpr.out_avals[0].shape),
axis_name=axis_name)
axis_name=axis_name, main_type=main_type)
carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
if carry_bat_out == carry_bat:
break
@@ -772,7 +773,7 @@ def _bcast_select(pred, on_true, on_false):
pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx)
return lax.select(pred, on_true, on_false)
def _cond_batching_rule(args, dims, axis_name, branches, linear):
def _cond_batching_rule(args, dims, axis_name, main_type, branches, linear):
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
index, *ops = args
index_dim, *op_dims = dims
@@ -786,7 +787,7 @@ def _cond_batching_rule(args, dims, axis_name, branches, linear):
index, *ops = [batching.bdim_at_front(x, d, size) for x, d in zip(args, dims)]
branches_batched = [
batching.batch_jaxpr(jaxpr, size, [True] * len(ops), True, axis_name)[0]
batching.batch_jaxpr(jaxpr, size, [True] * len(ops), True, axis_name, main_type)[0]
for jaxpr in branches]
branch_outs = []
@@ -806,11 +807,11 @@ def _cond_batching_rule(args, dims, axis_name, branches, linear):
for b, x, d in zip(ops_bat, ops, op_dims)]
branches_out_bat = [
batching.batch_jaxpr(jaxpr, size, ops_bat, False, axis_name)[1]
batching.batch_jaxpr(jaxpr, size, ops_bat, False, axis_name, main_type)[1]
for jaxpr in branches]
out_bat = [any(bat) for bat in zip(*branches_out_bat)]
branches_batched = tuple(
batching.batch_jaxpr(jaxpr, size, ops_bat, out_bat, axis_name)[0]
batching.batch_jaxpr(jaxpr, size, ops_bat, out_bat, axis_name, main_type)[0]
for jaxpr in branches)
out_dims = [0 if b else batching.not_mapped for b in out_bat]
@@ -1731,7 +1732,7 @@ def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.Abstrac
return core.ClosedJaxpr(jaxpr, consts)
def _scan_batching_rule(args, dims, axis_name, reverse, length, jaxpr, num_consts,
def _scan_batching_rule(args, dims, axis_name, main_type, reverse, length, jaxpr, num_consts,
num_carry, linear, unroll):
num_ys = len(jaxpr.out_avals) - num_carry
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
@@ -1749,7 +1750,8 @@ def _scan_batching_rule(args, dims, axis_name, reverse, length, jaxpr, num_const
jaxpr_batched, batched_out = batching.batch_jaxpr(
jaxpr, size, batched,
instantiate=carry_batched + [False] * num_ys,
axis_name=axis_name)
axis_name=axis_name,
main_type=main_type)
carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
if carry_batched_out == carry_batched:
break
@@ -2275,7 +2277,7 @@ def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
return [None] * sum(const_lengths) + cotangent_b
def _linear_solve_batching_rule(args, dims, axis_name, const_lengths, jaxprs):
def _linear_solve_batching_rule(args, dims, axis_name, main_type, const_lengths, jaxprs):
orig_bat = [d is not batching.not_mapped for d in dims]
size, = {
a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped
@@ -2295,23 +2297,27 @@ def _linear_solve_batching_rule(args, dims, axis_name, const_lengths, jaxprs):
for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
# Apply vecmat and solve -> new batched parts of x
solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
solve, size, solve_bat + b_bat, instantiate=x_bat, axis_name=axis_name)
solve, size, solve_bat + b_bat, instantiate=x_bat,
axis_name=axis_name, main_type=main_type)
if vecmat is None:
vecmat_jaxpr_batched = None
x_bat_out = solve_x_bat
else:
vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
vecmat, size, vecmat_bat + b_bat, instantiate=x_bat, axis_name=axis_name)
vecmat, size, vecmat_bat + b_bat, instantiate=x_bat,
axis_name=axis_name, main_type=main_type)
x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat)
# Apply matvec and solve_t -> new batched parts of b
matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
matvec, size, matvec_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name)
matvec, size, matvec_bat + x_bat_out, instantiate=b_bat,
axis_name=axis_name, main_type=main_type)
if solve_t is None:
solve_t_jaxpr_batched = None
b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
else:
solve_t_jaxpr_batched, solve_t_b_bat = batching.batch_jaxpr(
solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name)
solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat,
axis_name=axis_name, main_type=main_type)
b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
orig_b_bat)
if x_bat_out == x_bat and b_bat_out == b_bat:


@@ -652,7 +652,7 @@ class EvaluationPlan(NamedTuple):
vaxis = raxes[-1]
map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
f = pxla.vtile(f, map_in_axes, map_out_axes, tile_size=local_tile_size, axis_name=vaxis)
f = batching.vtile(f, map_in_axes, map_out_axes, tile_size=local_tile_size, axis_name=vaxis)
return f
def to_mesh_axes(self, in_axes, out_axes):


@@ -13,7 +13,7 @@
# limitations under the License.
import numpy as np
from typing import Any, Callable, Dict, Optional, Tuple, Union, Sequence, Iterable
from typing import Any, Callable, Dict, Optional, Tuple, Union, Sequence, Iterable, Type
import jax
from ..config import config
@@ -21,8 +21,8 @@ from .. import core
from ..core import raise_to_shaped, Trace, Tracer
from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p
from .. import linear_util as lu
from .._src.util import (unzip2, partial, safe_map, wrap_name, split_list,
canonicalize_axis, moveaxis, as_hashable_function)
from .._src.util import (unzip2, partial, safe_map, safe_zip, wrap_name, split_list,
canonicalize_axis, moveaxis, as_hashable_function, curry)
from . import xla
from . import partial_eval as pe
@@ -30,19 +30,8 @@ map = safe_map
BatchDim = Optional[int]
BatchDims = Sequence[BatchDim]
AxesSpec = Union[Callable[[], BatchDims], BatchDims]
def batch(fun: lu.WrappedFun, axis_name: core.AxisName,
axis_size: Optional[int], in_dims: AxesSpec, out_dim_dests: AxesSpec,
) -> lu.WrappedFun:
# analogue of `jvp` in ad.py
# TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
fun, out_dims_thunk = batch_subtrace(fun)
return _match_axes(batchfun(fun, axis_name, axis_size, in_dims),
axis_size, in_dims, out_dims_thunk, out_dim_dests)
@lu.transformation
def batchfun(axis_name, axis_size, in_dims, *in_vals):
def batchfun(axis_name, axis_size, in_dims, main_type, *in_vals):
# analogue of `jvpfun` in ad.py
if axis_size is None:
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
@@ -50,7 +39,7 @@ def batchfun(axis_name, axis_size, in_dims, *in_vals):
in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
and not isinstance(core.get_aval(x), core.AbstractUnit)
else ax for x, ax in zip(in_vals, in_dims)]
with core.new_main(BatchTrace, axis_name=axis_name) as main:
with core.new_main(main_type, axis_name=axis_name) as main:
with core.extend_axis_env(axis_name, axis_size, main):
out_vals = yield (main, in_dims, *in_vals), {}
del main
@@ -137,7 +126,7 @@ class BatchTrace(Trace):
frame = core.axis_frame(self.axis_name)
val_out, dim_out = collective_rules[primitive](frame, vals_in, dims_in, **params)
else:
batched_primitive = get_primitive_batcher(primitive, self.axis_name)
batched_primitive = get_primitive_batcher(primitive, self)
val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
if primitive.multiple_results:
return map(partial(BatchTracer, self), val_out, dim_out)
@@ -245,7 +234,7 @@ class BatchTrace(Trace):
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
out_dims2, in_dims)
out_dims2, in_dims, self.main.trace_type)
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
if not fst:
@@ -262,9 +251,9 @@ def _main_trace_for_axis_names(main_trace: core.MainTrace,
# axis names can shadow, so we use the main trace as a tag.
return any(main_trace is core.axis_frame(n).main_trace for n in axis_name)
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests):
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type):
bwd, out_dims_thunk = batch_subtrace(bwd)
return _match_axes_and_sum(batchfun(bwd, axis_name, axis_size, in_dims),
return _match_axes_and_sum(batchfun(bwd, axis_name, axis_size, in_dims, main_type),
axis_size, out_dims_thunk, out_dim_dests)
@lu.transformation
@@ -274,6 +263,55 @@ def _match_axes_and_sum(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
yield map(partial(matchaxis, axis_size, sum_match=True),
out_dims_thunk(), out_dim_dests, out_vals)
### API
AxesSpec = Union[Callable[[], BatchDims], BatchDims]
def batch(fun: lu.WrappedFun,
axis_name: core.AxisName,
axis_size: Optional[int],
in_dims: AxesSpec,
out_dim_dests: AxesSpec,
main_type: Type[BatchTrace] = BatchTrace,
) -> lu.WrappedFun:
# analogue of `jvp` in ad.py
# TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
fun, out_dims_thunk = batch_subtrace(fun)
return _match_axes(batchfun(fun, axis_name, axis_size, in_dims, main_type),
axis_size, in_dims, out_dims_thunk, out_dim_dests)
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat: lu.WrappedFun,
in_axes_flat: Tuple[Optional[int], ...],
out_axes_flat: Tuple[Optional[int], ...],
tile_size: Optional[int],
axis_name: core.AxisName,
main_type: Type[BatchTrace] = BatchTrace):
@curry
def tile_axis(arg, axis: Optional[int], tile_size):
if axis is None:
return arg
shape = list(arg.shape)
shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
return arg.reshape(shape)
def untile_axis(out, axis: Optional[int]):
if axis is None:
return out
shape = list(out.shape)
shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
return out.reshape(shape)
@lu.transformation
def _map_to_tile(*args_flat):
sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
tile_size_ = tile_size or next(sizes, None)
assert tile_size_ is not None, "No mapped arguments?"
outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
yield map(untile_axis, outputs_flat, out_axes_flat)
return _map_to_tile(batch(
f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat, main_type=main_type))
### primitives
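
To make the NOTE above concrete, here is a rough, untested usage sketch against the internal API exactly as it appears in this diff (not code from this commit): with tile_size=4, an input mapped along axis 0 with size 8 is reshaped to (4, 2, ...), the wrapped function sees the per-tile size, and the output axis is merged back.

import numpy as np
from jax import linear_util as lu
from jax.interpreters import batching

def f(x):
  # Inside the tiled call x has shape (2, 3): axis 0 was divided by tile_size.
  # Flat functions return a sequence of outputs.
  return (x * 2.0,)

tiled = batching.vtile(lu.wrap_init(f), in_axes_flat=(0,), out_axes_flat=(0,),
                       tile_size=4, axis_name='tile')
y, = tiled.call_wrapped(np.ones((8, 3)))  # reshaped to (4, 2, 3), vmapped over 4 tiles
assert y.shape == (8, 3)                  # output axis multiplied back by tile_size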
@@ -281,9 +319,11 @@ BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
initial_style_batchers : Dict[core.Primitive, Any] = {}
def get_primitive_batcher(p, axis_name):
def get_primitive_batcher(p, trace):
if p in initial_style_batchers:
return partial(initial_style_batchers[p], axis_name=axis_name)
return partial(initial_style_batchers[p],
axis_name=trace.axis_name,
main_type=trace.main.trace_type)
try:
return primitive_batchers[p]
except KeyError as err:
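
Since get_primitive_batcher now receives the trace, initial-style batching rules get main_type alongside axis_name. A hedged sketch of the resulting rule shape, mirroring the control-flow rules updated above (my_prim is a hypothetical primitive):

def _my_prim_batching_rule(args, dims, axis_name, main_type, **params):
  # A rule can now forward both keywords into recursive calls, e.g.
  # batching.batch_jaxpr(jaxpr, size, batched, instantiate, axis_name, main_type),
  # and then bind the primitive on the batched jaxprs.
  ...

# Registration is unchanged:
# batching.initial_style_batchers[my_prim] = _my_prim_batching_rule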
@@ -408,10 +448,10 @@ def bdim_at_front(x, bdim, size):
return moveaxis(x, bdim, 0)
def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name):
def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, main_type):
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched])
f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched], main_type)
avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
for aval, b in zip(closed_jaxpr.in_avals, in_batched)]
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)


@@ -47,7 +47,7 @@ from ..abstract_arrays import array_types
from ..core import ConcreteArray, ShapedArray
from .._src.util import (partial, unzip3, prod, safe_map, safe_zip,
extend_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, curry)
tuple_insert, tuple_delete)
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from ..lib import pmap_lib
@@ -1381,10 +1381,10 @@ def mesh_callable(fun: lu.WrappedFun,
# TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
for name, size in reversed(mesh.shape.items()):
if tile_by_mesh_axes:
fun = vtile(fun,
tuple(a.get(name, None) for a in in_axes),
tuple(a.get(name, None) for a in out_axes_thunk()),
tile_size=size, axis_name=name)
fun = batching.vtile(fun,
tuple(a.get(name, None) for a in in_axes),
tuple(a.get(name, None) for a in out_axes_thunk()),
tile_size=size, axis_name=name)
global_in_untiled_avals = [untile_aval_nd(global_axis_sizes, aval_in_axes, aval)
for aval, aval_in_axes in safe_zip(in_tiled_avals, in_axes)]
in_jaxpr_avals = global_in_untiled_avals
@@ -1500,37 +1500,6 @@ def compile_and_wrap_mesh_hlo(computation: xc.XlaComputation, backend,
handle_outs)
return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat,
in_axes_flat: Tuple[Optional[int], ...],
out_axes_flat: Tuple[Optional[int], ...],
tile_size: Optional[int], axis_name):
@curry
def tile_axis(arg, axis: Optional[int], tile_size):
if axis is None:
return arg
shape = list(arg.shape)
shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
return arg.reshape(shape)
def untile_axis(out, axis: Optional[int]):
if axis is None:
return out
shape = list(out.shape)
shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
return out.reshape(shape)
@lu.transformation
def _map_to_tile(*args_flat):
sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
tile_size_ = tile_size or next(sizes, None)
assert tile_size_ is not None, "No mapped arguments?"
outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
yield map(untile_axis, outputs_flat, out_axes_flat)
return _map_to_tile(
batching.batch(f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat))
_forbidden_primitives = {
'xla_pmap': 'pmap',
'sharded_call': 'sharded_jit',