From c09037bd144cc2e271da4d7f1077d281f9d0e7f4 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Tue, 20 Apr 2021 08:32:41 -0700
Subject: [PATCH] Move vtile to batching.py, make it possible to add new
 BatchTraces

No substantial behavior change right now, but the ability to use
subclasses of BatchTrace comes in handy when adding support for nesting
xmaps in the SPMD lowering.

PiperOrigin-RevId: 369445693
---
 jax/_src/custom_derivatives.py | 19 ++++----
 jax/_src/lax/control_flow.py   | 34 +++++++------
 jax/experimental/maps.py       |  2 +-
 jax/interpreters/batching.py   | 88 ++++++++++++++++++++++++----------
 jax/interpreters/pxla.py       | 41 ++--------------
 5 files changed, 101 insertions(+), 83 deletions(-)

diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 89bad0b96..46d0fd5ed 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -324,7 +324,7 @@ def _custom_jvp_call_jaxpr_jvp(
 ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp
 
 def _custom_jvp_call_jaxpr_vmap(
-    args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr,
+    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
     jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
     num_consts: int):
   size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
@@ -333,7 +333,8 @@ def _custom_jvp_call_jaxpr_vmap(
   num_out = len(fun_jaxpr.out_avals)
 
   in_batched = [d is not not_mapped for d in in_dims]
-  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(fun_jaxpr, size, in_batched, False, axis_name)
+  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
+      fun_jaxpr, size, in_batched, False, axis_name, main_type)
   out_dims1 = [0 if b else not_mapped for b in out_batched]
   out_dims2 = []  # mutable cell updated by batched_jvp_jaxpr_thunk
 
@@ -341,12 +342,14 @@ def _custom_jvp_call_jaxpr_vmap(
   def batched_jvp_jaxpr_thunk():
     jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
     _, args_batched = split_list(in_batched, [num_consts])
-    _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False, axis_name)
+    _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False,
+                                          axis_name, main_type)
     primals_batched, tangents_batched = split_list(all_batched, [num_out])
     out_batched = map(op.or_, primals_batched, tangents_batched)
     out_dims2.append([0 if b else not_mapped for b in out_batched])
     batched_jvp_jaxpr, _ = batching.batch_jaxpr(
-        jvp_jaxpr, size, args_batched * 2, out_batched * 2, axis_name)
+        jvp_jaxpr, size, args_batched * 2, out_batched * 2,
+        axis_name, main_type)
     return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts
 
   batched_outs = custom_jvp_call_jaxpr_p.bind(
@@ -610,7 +613,7 @@ def _custom_vjp_call_jaxpr_jvp(
 ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
 
 def _custom_vjp_call_jaxpr_vmap(
-    args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr,
+    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
     fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
     bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
   axis_size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
@@ -620,7 +623,7 @@ def _custom_vjp_call_jaxpr_vmap(
   in_batched = [d is not not_mapped for d in in_dims]
   _, args_batched = split_list(in_batched, [num_consts])
   batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
-      fun_jaxpr, axis_size, in_batched, False, axis_name)
+      fun_jaxpr, axis_size, in_batched, False, axis_name, main_type)
   out_dims1 = [0 if b else not_mapped for b in out_batched]
   out_dims2 = []
 
@@ -628,14 +631,14 @@ def _custom_vjp_call_jaxpr_vmap(
   def batched_fwd_jaxpr_thunk():
     fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk())  # consts can be tracers
     batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
-        fwd_jaxpr, axis_size, args_batched, False, axis_name)
+        fwd_jaxpr, axis_size, args_batched, False, axis_name, main_type)
     out_dims2.append([0 if b else not_mapped for b in out_batched])
     return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts
 
   fwd_args_batched = [0 if b else not_mapped for b in args_batched]
   fwd_out_dims = lambda: out_dims2[0]
   batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims,
-                                              fwd_args_batched)
+                                              fwd_args_batched, main_type)
 
   batched_outs = custom_vjp_call_jaxpr_p.bind(
       *args, fun_jaxpr=batched_fun_jaxpr,
diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py
index 4c2cb07f2..f0dd09f0d 100644
--- a/jax/_src/lax/control_flow.py
+++ b/jax/_src/lax/control_flow.py
@@ -367,7 +367,7 @@ def _pred_bcast_select(c, pred, x, y, x_y_aval: core.AbstractValue):
   bcast_pred = xops.BroadcastInDim(pred, x_shape, list(range(len(pred_shape))))
   return xops.Select(bcast_pred, x, y)
 
-def _while_loop_batching_rule(args, dims, axis_name,
+def _while_loop_batching_rule(args, dims, axis_name, main_type,
                               cond_nconsts, cond_jaxpr,
                               body_nconsts, body_jaxpr):
   size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
@@ -383,11 +383,12 @@ def _while_loop_batching_rule(args, dims, axis_name,
   for _ in range(1 + len(carry_bat)):
     batched = bconst_bat + carry_bat
     body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
-        body_jaxpr, size, batched, instantiate=carry_bat, axis_name=axis_name)
+        body_jaxpr, size, batched, instantiate=carry_bat,
+        axis_name=axis_name, main_type=main_type)
     cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
         cond_jaxpr, size, cconst_bat + carry_bat,
         instantiate=bool(cond_jaxpr.out_avals[0].shape),
-        axis_name=axis_name)
+        axis_name=axis_name, main_type=main_type)
     carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
     if carry_bat_out == carry_bat:
       break
@@ -772,7 +773,7 @@ def _bcast_select(pred, on_true, on_false):
     pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx)
   return lax.select(pred, on_true, on_false)
 
-def _cond_batching_rule(args, dims, axis_name, branches, linear):
+def _cond_batching_rule(args, dims, axis_name, main_type, branches, linear):
   size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
   index, *ops = args
   index_dim, *op_dims = dims
@@ -786,7 +787,7 @@ def _cond_batching_rule(args, dims, axis_name, branches, linear):
     index, *ops = [batching.bdim_at_front(x, d, size) for x, d in zip(args, dims)]
 
     branches_batched = [
-        batching.batch_jaxpr(jaxpr, size, [True] * len(ops), True, axis_name)[0]
+        batching.batch_jaxpr(jaxpr, size, [True] * len(ops), True, axis_name, main_type)[0]
         for jaxpr in branches]
 
     branch_outs = []
@@ -806,11 +807,11 @@ def _cond_batching_rule(args, dims, axis_name, branches, linear):
               for b, x, d in zip(ops_bat, ops, op_dims)]
 
     branches_out_bat = [
-        batching.batch_jaxpr(jaxpr, size, ops_bat, False, axis_name)[1]
+        batching.batch_jaxpr(jaxpr, size, ops_bat, False, axis_name, main_type)[1]
         for jaxpr in branches]
     out_bat = [any(bat) for bat in zip(*branches_out_bat)]
     branches_batched = tuple(
-        batching.batch_jaxpr(jaxpr, size, ops_bat, out_bat, axis_name)[0]
+        batching.batch_jaxpr(jaxpr, size, ops_bat, out_bat, axis_name, main_type)[0]
         for jaxpr in branches)
 
   out_dims = [0 if b else batching.not_mapped for b in out_bat]
@@ -1731,7 +1732,7 @@ def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.Abstrac
   return core.ClosedJaxpr(jaxpr, consts)
 
-def _scan_batching_rule(args, dims, axis_name, reverse, length, jaxpr, num_consts,
+def _scan_batching_rule(args, dims, axis_name, main_type, reverse, length, jaxpr, num_consts,
                         num_carry, linear, unroll):
   num_ys = len(jaxpr.out_avals) - num_carry
   size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
@@ -1749,7 +1750,8 @@ def _scan_batching_rule(args, dims, axis_name, reverse, length, jaxpr, num_const
 
     jaxpr_batched, batched_out = batching.batch_jaxpr(
         jaxpr, size, batched, instantiate=carry_batched + [False] * num_ys,
-        axis_name=axis_name)
+        axis_name=axis_name,
+        main_type=main_type)
     carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
     if carry_batched_out == carry_batched:
       break
@@ -2275,7 +2277,7 @@ def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
   return [None] * sum(const_lengths) + cotangent_b
 
-def _linear_solve_batching_rule(args, dims, axis_name, const_lengths, jaxprs):
+def _linear_solve_batching_rule(args, dims, axis_name, main_type, const_lengths, jaxprs):
   orig_bat = [d is not batching.not_mapped for d in dims]
   size, = {
       a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped
@@ -2295,23 +2297,27 @@ def _linear_solve_batching_rule(args, dims, axis_name, const_lengths, jaxprs):
   for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
     # Apply vecmat and solve -> new batched parts of x
     solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
-        solve, size, solve_bat + b_bat, instantiate=x_bat, axis_name=axis_name)
+        solve, size, solve_bat + b_bat, instantiate=x_bat,
+        axis_name=axis_name, main_type=main_type)
     if vecmat is None:
       vecmat_jaxpr_batched = None
       x_bat_out = solve_x_bat
     else:
       vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
-          vecmat, size, vecmat_bat + b_bat, instantiate=x_bat, axis_name=axis_name)
+          vecmat, size, vecmat_bat + b_bat, instantiate=x_bat,
+          axis_name=axis_name, main_type=main_type)
       x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat)
     # Apply matvec and solve_t -> new batched parts of b
     matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
-        matvec, size, matvec_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name)
+        matvec, size, matvec_bat + x_bat_out, instantiate=b_bat,
+        axis_name=axis_name, main_type=main_type)
     if solve_t is None:
       solve_t_jaxpr_batched = None
       b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
     else:
       solve_t_jaxpr_batched, solve_t_b_bat = batching.batch_jaxpr(
-          solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name)
+          solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat,
+          axis_name=axis_name, main_type=main_type)
       b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
                        orig_b_bat)
     if x_bat_out == x_bat and b_bat_out == b_bat:
diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py
index c449b39f7..adcbf2d2a 100644
--- a/jax/experimental/maps.py
+++ b/jax/experimental/maps.py
@@ -652,7 +652,7 @@ class EvaluationPlan(NamedTuple):
       vaxis = raxes[-1]
       map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
       map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
-      f = pxla.vtile(f, map_in_axes, map_out_axes, tile_size=local_tile_size, axis_name=vaxis)
+      f = batching.vtile(f, map_in_axes, map_out_axes, tile_size=local_tile_size, axis_name=vaxis)
     return f
 
   def to_mesh_axes(self, in_axes, out_axes):
diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py
index bd0bf2b20..463a81696 100644
--- a/jax/interpreters/batching.py
+++ b/jax/interpreters/batching.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import numpy as np
-from typing import Any, Callable, Dict, Optional, Tuple, Union, Sequence, Iterable
+from typing import Any, Callable, Dict, Optional, Tuple, Union, Sequence, Iterable, Type
 
 import jax
 from ..config import config
@@ -21,8 +21,8 @@ from .. import core
 from ..core import raise_to_shaped, Trace, Tracer
 from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p
 from .. import linear_util as lu
-from .._src.util import (unzip2, partial, safe_map, wrap_name, split_list,
-                         canonicalize_axis, moveaxis, as_hashable_function)
+from .._src.util import (unzip2, partial, safe_map, safe_zip, wrap_name, split_list,
+                         canonicalize_axis, moveaxis, as_hashable_function, curry)
 from . import xla
 from . import partial_eval as pe
 
@@ -30,19 +30,8 @@ map = safe_map
 
 BatchDim = Optional[int]
 BatchDims = Sequence[BatchDim]
-AxesSpec = Union[Callable[[], BatchDims], BatchDims]
-
-def batch(fun: lu.WrappedFun, axis_name: core.AxisName,
-          axis_size: Optional[int], in_dims: AxesSpec, out_dim_dests: AxesSpec,
-          ) -> lu.WrappedFun:
-  # anlogue of `jvp` in ad.py
-  # TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
-  fun, out_dims_thunk = batch_subtrace(fun)
-  return _match_axes(batchfun(fun, axis_name, axis_size, in_dims),
-                     axis_size, in_dims, out_dims_thunk, out_dim_dests)
-
 @lu.transformation
-def batchfun(axis_name, axis_size, in_dims, *in_vals):
+def batchfun(axis_name, axis_size, in_dims, main_type, *in_vals):
   # analogue of `jvpfun` in ad.py
   if axis_size is None:
     axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
@@ -50,7 +39,7 @@ def batchfun(axis_name, axis_size, in_dims, *in_vals):
   in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
              and not isinstance(core.get_aval(x), core.AbstractUnit)
              else ax for x, ax in zip(in_vals, in_dims)]
-  with core.new_main(BatchTrace, axis_name=axis_name) as main:
+  with core.new_main(main_type, axis_name=axis_name) as main:
    with core.extend_axis_env(axis_name, axis_size, main):
      out_vals = yield (main, in_dims, *in_vals), {}
      del main
@@ -137,7 +126,7 @@ class BatchTrace(Trace):
       frame = core.axis_frame(self.axis_name)
       val_out, dim_out = collective_rules[primitive](frame, vals_in, dims_in, **params)
     else:
-      batched_primitive = get_primitive_batcher(primitive, self.axis_name)
+      batched_primitive = get_primitive_batcher(primitive, self)
       val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
     if primitive.multiple_results:
       return map(partial(BatchTracer, self), val_out, dim_out)
@@ -245,7 +234,7 @@ class BatchTrace(Trace):
     fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
     fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
     bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
-                               out_dims2, in_dims)
+                               out_dims2, in_dims, self.main.trace_type)
     out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
     fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
     if not fst:
@@ -262,9 +251,9 @@ def _main_trace_for_axis_names(main_trace: core.MainTrace,
   # axis names can shadow, so we use the main trace as a tag.
   return any(main_trace is core.axis_frame(n).main_trace for n in axis_name)
 
-def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests):
+def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type):
   bwd, out_dims_thunk = batch_subtrace(bwd)
-  return _match_axes_and_sum(batchfun(bwd, axis_name, axis_size, in_dims),
+  return _match_axes_and_sum(batchfun(bwd, axis_name, axis_size, in_dims, main_type),
                              axis_size, out_dims_thunk, out_dim_dests)
 
 @lu.transformation
@@ -274,6 +263,55 @@ def _match_axes_and_sum(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
   yield map(partial(matchaxis, axis_size, sum_match=True),
             out_dims_thunk(), out_dim_dests, out_vals)
 
+### API
+
+AxesSpec = Union[Callable[[], BatchDims], BatchDims]
+
+def batch(fun: lu.WrappedFun,
+          axis_name: core.AxisName,
+          axis_size: Optional[int],
+          in_dims: AxesSpec,
+          out_dim_dests: AxesSpec,
+          main_type: Type[BatchTrace] = BatchTrace,
+          ) -> lu.WrappedFun:
+  # analogue of `jvp` in ad.py
+  # TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
+  fun, out_dims_thunk = batch_subtrace(fun)
+  return _match_axes(batchfun(fun, axis_name, axis_size, in_dims, main_type),
+                     axis_size, in_dims, out_dims_thunk, out_dim_dests)
+
+# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
+def vtile(f_flat: lu.WrappedFun,
+          in_axes_flat: Tuple[Optional[int], ...],
+          out_axes_flat: Tuple[Optional[int], ...],
+          tile_size: Optional[int],
+          axis_name: core.AxisName,
+          main_type: Type[BatchTrace] = BatchTrace):
+  @curry
+  def tile_axis(arg, axis: Optional[int], tile_size):
+    if axis is None:
+      return arg
+    shape = list(arg.shape)
+    shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
+    return arg.reshape(shape)
+
+  def untile_axis(out, axis: Optional[int]):
+    if axis is None:
+      return out
+    shape = list(out.shape)
+    shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
+    return out.reshape(shape)
+
+  @lu.transformation
+  def _map_to_tile(*args_flat):
+    sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
+    tile_size_ = tile_size or next(sizes, None)
+    assert tile_size_ is not None, "No mapped arguments?"
+    outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
+    yield map(untile_axis, outputs_flat, out_axes_flat)
+
+  return _map_to_tile(batch(
+      f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat, main_type=main_type))
 
 ### primitives
 
@@ -281,9 +319,11 @@ BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
 primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
 initial_style_batchers : Dict[core.Primitive, Any] = {}
 
-def get_primitive_batcher(p, axis_name):
+def get_primitive_batcher(p, trace):
   if p in initial_style_batchers:
-    return partial(initial_style_batchers[p], axis_name=axis_name)
+    return partial(initial_style_batchers[p],
+                   axis_name=trace.axis_name,
+                   main_type=trace.main.trace_type)
   try:
     return primitive_batchers[p]
   except KeyError as err:
@@ -408,10 +448,10 @@ def bdim_at_front(x, bdim, size):
     return moveaxis(x, bdim, 0)
 
-def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name):
+def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, main_type):
   f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
   f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
-  f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched])
+  f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched], main_type)
   avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
               for aval, b in zip(closed_jaxpr.in_avals, in_batched)]
   jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 8e7bdc662..3209af88b 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -47,7 +47,7 @@ from ..abstract_arrays import array_types
 from ..core import ConcreteArray, ShapedArray
 from .._src.util import (partial, unzip3, prod, safe_map, safe_zip,
                          extend_name_stack, wrap_name, assert_unreachable,
-                         tuple_insert, tuple_delete, curry)
+                         tuple_insert, tuple_delete)
 from ..lib import xla_bridge as xb
 from ..lib import xla_client as xc
 from ..lib import pmap_lib
@@ -1381,10 +1381,10 @@ def mesh_callable(fun: lu.WrappedFun,
   # TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
   for name, size in reversed(mesh.shape.items()):
     if tile_by_mesh_axes:
-      fun = vtile(fun,
-                  tuple(a.get(name, None) for a in in_axes),
-                  tuple(a.get(name, None) for a in out_axes_thunk()),
-                  tile_size=size, axis_name=name)
+      fun = batching.vtile(fun,
+                           tuple(a.get(name, None) for a in in_axes),
+                           tuple(a.get(name, None) for a in out_axes_thunk()),
+                           tile_size=size, axis_name=name)
   global_in_untiled_avals = [untile_aval_nd(global_axis_sizes, aval_in_axes, aval)
                              for aval, aval_in_axes in safe_zip(in_tiled_avals, in_axes)]
   in_jaxpr_avals = global_in_untiled_avals
@@ -1500,37 +1500,6 @@ def compile_and_wrap_mesh_hlo(computation: xc.XlaComputation, backend,
                  handle_outs)
   return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
 
-# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
-def vtile(f_flat,
-          in_axes_flat: Tuple[Optional[int], ...],
-          out_axes_flat: Tuple[Optional[int], ...],
-          tile_size: Optional[int], axis_name):
-  @curry
-  def tile_axis(arg, axis: Optional[int], tile_size):
-    if axis is None:
-      return arg
-    shape = list(arg.shape)
-    shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
-    return arg.reshape(shape)
-
-  def untile_axis(out, axis: Optional[int]):
-    if axis is None:
-      return out
-    shape = list(out.shape)
-    shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
-    return out.reshape(shape)
-
-  @lu.transformation
-  def _map_to_tile(*args_flat):
-    sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
-    tile_size_ = tile_size or next(sizes, None)
-    assert tile_size_ is not None, "No mapped arguments?"
-    outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
-    yield map(untile_axis, outputs_flat, out_axes_flat)
-
-  return _map_to_tile(
-      batching.batch(f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat))
-
 _forbidden_primitives = {
   'xla_pmap': 'pmap',
   'sharded_call': 'sharded_jit',
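
Note on usage: the point of threading main_type through all the batching
rules is that batching.batch can now trace with a subclass of BatchTrace.
A minimal sketch of what that enables, assuming the post-patch
jax.interpreters.batching module; MyBatchTrace and double_it are
hypothetical names, and a real subclass (such as the one the SPMD lowering
mentioned above would add) would override the process_* methods:

    import jax.numpy as jnp
    from jax import linear_util as lu
    from jax.interpreters import batching

    class MyBatchTrace(batching.BatchTrace):
      pass  # placeholder: a real subclass overrides process_primitive etc.

    def double_it(x):
      return [x * 2]  # lu-wrapped functions return a flat list of outputs

    # Batch over axis 0 of the input, but run the trace with the subclass
    # instead of the stock BatchTrace via the new main_type parameter.
    batched = batching.batch(lu.wrap_init(double_it), 'i', None,
                             in_dims=[0], out_dim_dests=[0],
                             main_type=MyBatchTrace)
    y, = batched.call_wrapped(jnp.arange(6.))  # same result as a plain vmap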
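
Note on the tiling NOTE ("divides the in_axes by the tile_size and
multiplies the out_axes by it"): tile_axis and untile_axis are pure shape
bookkeeping around the batched trace. A standalone illustration with plain
NumPy standing in for the traced arrays; the array values and tile size
here are made up:

    import numpy as np

    x = np.arange(12.0).reshape(6, 2)  # mapped axis 0 has size 6
    tile_size = 3

    # tile_axis: axis 0 of size 6 becomes (tile_size, 6 // tile_size), so
    # the batched function sees a mapped axis of size tile_size.
    tiled = x.reshape(3, 2, 2)

    # untile_axis: the two leading axes are merged back into one.
    untiled = tiled.reshape(6, 2)
    assert (untiled == x).all()  # the round trip is exact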