Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Move vtile to batching.py, make it possible to add new BatchTraces

No substantial behavior change right now, but the ability to use subclasses of BatchTrace comes in handy when adding support for nesting xmaps in the SPMD lowering.

PiperOrigin-RevId: 369445693
This commit is contained in:
parent 93c63d0341
commit c09037bd14
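For context, a minimal sketch of the hook this commit adds. Only the `main_type: Type[BatchTrace] = BatchTrace` parameter is part of the change; the subclass and wrapper names below are hypothetical:

import jax.linear_util as lu
from jax import core
from jax.interpreters import batching

class SPMDBatchTrace(batching.BatchTrace):
  """Hypothetical BatchTrace subclass, e.g. for the SPMD lowering."""

def batch_with_custom_trace(fun: lu.WrappedFun, axis_name: core.AxisName,
                            axis_size: int, in_dims, out_dim_dests):
  # After this change, batchfun calls core.new_main(main_type, ...) instead of
  # hard-coding BatchTrace, so the subclass drives the whole transform.
  return batching.batch(fun, axis_name, axis_size, in_dims, out_dim_dests,
                        main_type=SPMDBatchTrace)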
@@ -324,7 +324,7 @@ def _custom_jvp_call_jaxpr_jvp(
 ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp

 def _custom_jvp_call_jaxpr_vmap(
-    args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr,
+    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
     jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
     num_consts: int):
   size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
@@ -333,7 +333,8 @@ def _custom_jvp_call_jaxpr_vmap(
   num_out = len(fun_jaxpr.out_avals)

   in_batched = [d is not not_mapped for d in in_dims]
-  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(fun_jaxpr, size, in_batched, False, axis_name)
+  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
+      fun_jaxpr, size, in_batched, False, axis_name, main_type)
   out_dims1 = [0 if b else not_mapped for b in out_batched]
   out_dims2 = []  # mutable cell updated by batched_jvp_jaxpr_thunk
@@ -341,12 +342,14 @@ def _custom_jvp_call_jaxpr_vmap(
   def batched_jvp_jaxpr_thunk():
     jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
     _, args_batched = split_list(in_batched, [num_consts])
-    _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False, axis_name)
+    _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False,
+                                          axis_name, main_type)
     primals_batched, tangents_batched = split_list(all_batched, [num_out])
     out_batched = map(op.or_, primals_batched, tangents_batched)
     out_dims2.append([0 if b else not_mapped for b in out_batched])
     batched_jvp_jaxpr, _ = batching.batch_jaxpr(
-        jvp_jaxpr, size, args_batched * 2, out_batched * 2, axis_name)
+        jvp_jaxpr, size, args_batched * 2, out_batched * 2,
+        axis_name, main_type)
     return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts

   batched_outs = custom_jvp_call_jaxpr_p.bind(
@@ -610,7 +613,7 @@ def _custom_vjp_call_jaxpr_jvp(
 ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

 def _custom_vjp_call_jaxpr_vmap(
-    args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr,
+    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
     fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
     bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
   axis_size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
@@ -620,7 +623,7 @@ def _custom_vjp_call_jaxpr_vmap(
   in_batched = [d is not not_mapped for d in in_dims]
   _, args_batched = split_list(in_batched, [num_consts])
   batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
-      fun_jaxpr, axis_size, in_batched, False, axis_name)
+      fun_jaxpr, axis_size, in_batched, False, axis_name, main_type)
   out_dims1 = [0 if b else not_mapped for b in out_batched]
   out_dims2 = []
@@ -628,14 +631,14 @@ def _custom_vjp_call_jaxpr_vmap(
   def batched_fwd_jaxpr_thunk():
     fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk())  # consts can be tracers
     batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
-        fwd_jaxpr, axis_size, args_batched, False, axis_name)
+        fwd_jaxpr, axis_size, args_batched, False, axis_name, main_type)
     out_dims2.append([0 if b else not_mapped for b in out_batched])
     return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts

   fwd_args_batched = [0 if b else not_mapped for b in args_batched]
   fwd_out_dims = lambda: out_dims2[0]
   batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims,
-                                              fwd_args_batched)
+                                              fwd_args_batched, main_type)

   batched_outs = custom_vjp_call_jaxpr_p.bind(
       *args, fun_jaxpr=batched_fun_jaxpr,

@@ -367,7 +367,7 @@ def _pred_bcast_select(c, pred, x, y, x_y_aval: core.AbstractValue):
   bcast_pred = xops.BroadcastInDim(pred, x_shape, list(range(len(pred_shape))))
   return xops.Select(bcast_pred, x, y)

-def _while_loop_batching_rule(args, dims, axis_name,
+def _while_loop_batching_rule(args, dims, axis_name, main_type,
                               cond_nconsts, cond_jaxpr,
                               body_nconsts, body_jaxpr):
   size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
@@ -383,11 +383,12 @@ def _while_loop_batching_rule(args, dims, axis_name,
   for _ in range(1 + len(carry_bat)):
     batched = bconst_bat + carry_bat
     body_jaxpr_batched, carry_bat_out = batching.batch_jaxpr(
-        body_jaxpr, size, batched, instantiate=carry_bat, axis_name=axis_name)
+        body_jaxpr, size, batched, instantiate=carry_bat,
+        axis_name=axis_name, main_type=main_type)
     cond_jaxpr_batched, (pred_bat,) = batching.batch_jaxpr(
         cond_jaxpr, size, cconst_bat + carry_bat,
         instantiate=bool(cond_jaxpr.out_avals[0].shape),
-        axis_name=axis_name)
+        axis_name=axis_name, main_type=main_type)
     carry_bat_out = _map(partial(operator.or_, pred_bat), carry_bat_out)
     if carry_bat_out == carry_bat:
       break
@@ -772,7 +773,7 @@ def _bcast_select(pred, on_true, on_false):
     pred = lax.broadcast_in_dim(pred, np.shape(on_true), idx)
   return lax.select(pred, on_true, on_false)

-def _cond_batching_rule(args, dims, axis_name, branches, linear):
+def _cond_batching_rule(args, dims, axis_name, main_type, branches, linear):
   size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
   index, *ops = args
   index_dim, *op_dims = dims
@@ -786,7 +787,7 @@ def _cond_batching_rule(args, dims, axis_name, branches, linear):
     index, *ops = [batching.bdim_at_front(x, d, size) for x, d in zip(args, dims)]

     branches_batched = [
-        batching.batch_jaxpr(jaxpr, size, [True] * len(ops), True, axis_name)[0]
+        batching.batch_jaxpr(jaxpr, size, [True] * len(ops), True, axis_name, main_type)[0]
         for jaxpr in branches]

     branch_outs = []
@@ -806,11 +807,11 @@ def _cond_batching_rule(args, dims, axis_name, branches, linear):
            for b, x, d in zip(ops_bat, ops, op_dims)]

     branches_out_bat = [
-        batching.batch_jaxpr(jaxpr, size, ops_bat, False, axis_name)[1]
+        batching.batch_jaxpr(jaxpr, size, ops_bat, False, axis_name, main_type)[1]
         for jaxpr in branches]
     out_bat = [any(bat) for bat in zip(*branches_out_bat)]
     branches_batched = tuple(
-        batching.batch_jaxpr(jaxpr, size, ops_bat, out_bat, axis_name)[0]
+        batching.batch_jaxpr(jaxpr, size, ops_bat, out_bat, axis_name, main_type)[0]
         for jaxpr in branches)

     out_dims = [0 if b else batching.not_mapped for b in out_bat]
@@ -1731,7 +1732,7 @@ def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.Abstrac
   return core.ClosedJaxpr(jaxpr, consts)


-def _scan_batching_rule(args, dims, axis_name, reverse, length, jaxpr, num_consts,
+def _scan_batching_rule(args, dims, axis_name, main_type, reverse, length, jaxpr, num_consts,
                         num_carry, linear, unroll):
   num_ys = len(jaxpr.out_avals) - num_carry
   size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
@@ -1749,7 +1750,8 @@ def _scan_batching_rule(args, dims, axis_name, reverse, length, jaxpr, num_const
     jaxpr_batched, batched_out = batching.batch_jaxpr(
         jaxpr, size, batched,
         instantiate=carry_batched + [False] * num_ys,
-        axis_name=axis_name)
+        axis_name=axis_name,
+        main_type=main_type)
     carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:]
     if carry_batched_out == carry_batched:
       break
@@ -2275,7 +2277,7 @@ def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs):
   return [None] * sum(const_lengths) + cotangent_b


-def _linear_solve_batching_rule(args, dims, axis_name, const_lengths, jaxprs):
+def _linear_solve_batching_rule(args, dims, axis_name, main_type, const_lengths, jaxprs):
   orig_bat = [d is not batching.not_mapped for d in dims]
   size, = {
       a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped
@@ -2295,23 +2297,27 @@ def _linear_solve_batching_rule(args, dims, axis_name, const_lengths, jaxprs):
   for i in range(1 + len(orig_b_bat) + len(solve.out_avals)):
     # Apply vecmat and solve -> new batched parts of x
     solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr(
-        solve, size, solve_bat + b_bat, instantiate=x_bat, axis_name=axis_name)
+        solve, size, solve_bat + b_bat, instantiate=x_bat,
+        axis_name=axis_name, main_type=main_type)
     if vecmat is None:
       vecmat_jaxpr_batched = None
       x_bat_out = solve_x_bat
     else:
       vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
-          vecmat, size, vecmat_bat + b_bat, instantiate=x_bat, axis_name=axis_name)
+          vecmat, size, vecmat_bat + b_bat, instantiate=x_bat,
+          axis_name=axis_name, main_type=main_type)
       x_bat_out = _map(operator.or_, vecmat_x_bat, solve_x_bat)
     # Apply matvec and solve_t -> new batched parts of b
     matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
-        matvec, size, matvec_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name)
+        matvec, size, matvec_bat + x_bat_out, instantiate=b_bat,
+        axis_name=axis_name, main_type=main_type)
     if solve_t is None:
       solve_t_jaxpr_batched = None
       b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
     else:
       solve_t_jaxpr_batched, solve_t_b_bat = batching.batch_jaxpr(
-          solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat, axis_name=axis_name)
+          solve_t, size, solve_t_bat + x_bat_out, instantiate=b_bat,
+          axis_name=axis_name, main_type=main_type)
       b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat,
                        orig_b_bat)
     if x_bat_out == x_bat and b_bat_out == b_bat:

@@ -652,7 +652,7 @@ class EvaluationPlan(NamedTuple):
       vaxis = raxes[-1]
       map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
       map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
-      f = pxla.vtile(f, map_in_axes, map_out_axes, tile_size=local_tile_size, axis_name=vaxis)
+      f = batching.vtile(f, map_in_axes, map_out_axes, tile_size=local_tile_size, axis_name=vaxis)
     return f

   def to_mesh_axes(self, in_axes, out_axes):

@@ -13,7 +13,7 @@
 # limitations under the License.

 import numpy as np
-from typing import Any, Callable, Dict, Optional, Tuple, Union, Sequence, Iterable
+from typing import Any, Callable, Dict, Optional, Tuple, Union, Sequence, Iterable, Type

 import jax
 from ..config import config
@@ -21,8 +21,8 @@ from .. import core
 from ..core import raise_to_shaped, Trace, Tracer
 from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p
 from .. import linear_util as lu
-from .._src.util import (unzip2, partial, safe_map, wrap_name, split_list,
-                         canonicalize_axis, moveaxis, as_hashable_function)
+from .._src.util import (unzip2, partial, safe_map, safe_zip, wrap_name, split_list,
+                         canonicalize_axis, moveaxis, as_hashable_function, curry)
 from . import xla
 from . import partial_eval as pe
@@ -30,19 +30,8 @@ map = safe_map

 BatchDim = Optional[int]
 BatchDims = Sequence[BatchDim]
-AxesSpec = Union[Callable[[], BatchDims], BatchDims]
-
-def batch(fun: lu.WrappedFun, axis_name: core.AxisName,
-          axis_size: Optional[int], in_dims: AxesSpec, out_dim_dests: AxesSpec,
-          ) -> lu.WrappedFun:
-  # analogue of `jvp` in ad.py
-  # TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
-  fun, out_dims_thunk = batch_subtrace(fun)
-  return _match_axes(batchfun(fun, axis_name, axis_size, in_dims),
-                     axis_size, in_dims, out_dims_thunk, out_dim_dests)
-
 @lu.transformation
-def batchfun(axis_name, axis_size, in_dims, *in_vals):
+def batchfun(axis_name, axis_size, in_dims, main_type, *in_vals):
   # analogue of `jvpfun` in ad.py
   if axis_size is None:
     axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
@@ -50,7 +39,7 @@ def batchfun(axis_name, axis_size, in_dims, *in_vals):
   in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
              and not isinstance(core.get_aval(x), core.AbstractUnit)
              else ax for x, ax in zip(in_vals, in_dims)]
-  with core.new_main(BatchTrace, axis_name=axis_name) as main:
+  with core.new_main(main_type, axis_name=axis_name) as main:
     with core.extend_axis_env(axis_name, axis_size, main):
       out_vals = yield (main, in_dims, *in_vals), {}
   del main
@@ -137,7 +126,7 @@ class BatchTrace(Trace):
       frame = core.axis_frame(self.axis_name)
       val_out, dim_out = collective_rules[primitive](frame, vals_in, dims_in, **params)
     else:
-      batched_primitive = get_primitive_batcher(primitive, self.axis_name)
+      batched_primitive = get_primitive_batcher(primitive, self)
       val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
     if primitive.multiple_results:
       return map(partial(BatchTracer, self), val_out, dim_out)
@@ -245,7 +234,7 @@ class BatchTrace(Trace):
     fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
     fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
     bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
-                               out_dims2, in_dims)
+                               out_dims2, in_dims, self.main.trace_type)
     out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
     fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
     if not fst:
@@ -262,9 +251,9 @@ def _main_trace_for_axis_names(main_trace: core.MainTrace,
   # axis names can shadow, so we use the main trace as a tag.
   return any(main_trace is core.axis_frame(n).main_trace for n in axis_name)

-def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests):
+def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type):
   bwd, out_dims_thunk = batch_subtrace(bwd)
-  return _match_axes_and_sum(batchfun(bwd, axis_name, axis_size, in_dims),
+  return _match_axes_and_sum(batchfun(bwd, axis_name, axis_size, in_dims, main_type),
                              axis_size, out_dims_thunk, out_dim_dests)

 @lu.transformation
@@ -274,6 +263,55 @@ def _match_axes_and_sum(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
   yield map(partial(matchaxis, axis_size, sum_match=True),
             out_dims_thunk(), out_dim_dests, out_vals)

+### API
+
+AxesSpec = Union[Callable[[], BatchDims], BatchDims]
+
+def batch(fun: lu.WrappedFun,
+          axis_name: core.AxisName,
+          axis_size: Optional[int],
+          in_dims: AxesSpec,
+          out_dim_dests: AxesSpec,
+          main_type: Type[BatchTrace] = BatchTrace,
+          ) -> lu.WrappedFun:
+  # analogue of `jvp` in ad.py
+  # TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
+  fun, out_dims_thunk = batch_subtrace(fun)
+  return _match_axes(batchfun(fun, axis_name, axis_size, in_dims, main_type),
+                     axis_size, in_dims, out_dims_thunk, out_dim_dests)
+
+# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
+def vtile(f_flat: lu.WrappedFun,
+          in_axes_flat: Tuple[Optional[int], ...],
+          out_axes_flat: Tuple[Optional[int], ...],
+          tile_size: Optional[int],
+          axis_name: core.AxisName,
+          main_type: Type[BatchTrace] = BatchTrace):
+  @curry
+  def tile_axis(arg, axis: Optional[int], tile_size):
+    if axis is None:
+      return arg
+    shape = list(arg.shape)
+    shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
+    return arg.reshape(shape)
+
+  def untile_axis(out, axis: Optional[int]):
+    if axis is None:
+      return out
+    shape = list(out.shape)
+    shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
+    return out.reshape(shape)
+
+  @lu.transformation
+  def _map_to_tile(*args_flat):
+    sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
+    tile_size_ = tile_size or next(sizes, None)
+    assert tile_size_ is not None, "No mapped arguments?"
+    outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
+    yield map(untile_axis, outputs_flat, out_axes_flat)
+
+  return _map_to_tile(batch(
+      f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat, main_type=main_type))
+
 ### primitives
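The tile/untile reshapes in the `vtile` added above are easy to check in isolation. A standalone NumPy sketch of the same arithmetic (shapes and values are illustrative only):

import numpy as np

x = np.arange(24.0).reshape(8, 3)  # mapped axis 0 has size 8
tile_size = 2

# tile_axis: shape[0:1] -> [tile_size, 8 // tile_size], i.e. (2, 4, 3)
tiled = x.reshape(tile_size, x.shape[0] // tile_size, 3)

# batch(...) then maps over the leading axis of size tile_size, so the wrapped
# function only sees (4, 3)-shaped tiles; this is the sense in which vtile
# "divides the in_axes by the tile_size".

# untile_axis: shape[0:2] -> [2 * 4], restoring the original (8, 3)
untiled = tiled.reshape(-1, 3)
assert (untiled == x).all()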
@@ -281,9 +319,11 @@ BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
 primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
 initial_style_batchers : Dict[core.Primitive, Any] = {}

-def get_primitive_batcher(p, axis_name):
+def get_primitive_batcher(p, trace):
   if p in initial_style_batchers:
-    return partial(initial_style_batchers[p], axis_name=axis_name)
+    return partial(initial_style_batchers[p],
+                   axis_name=trace.axis_name,
+                   main_type=trace.main.trace_type)
   try:
     return primitive_batchers[p]
   except KeyError as err:
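Every initial-style rule updated in this change (while, cond, scan, linear_solve, and the custom-derivative rules) now follows the same calling convention. A hedged sketch with a hypothetical primitive and `call_jaxpr` parameter:

from jax.interpreters import batching

def _my_prim_batching_rule(args, dims, axis_name, main_type, *, call_jaxpr):
  # axis_name and main_type arrive as keywords, injected by
  # get_primitive_batcher from the live trace (trace.axis_name,
  # trace.main.trace_type).
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  in_batched = [d is not batching.not_mapped for d in dims]
  # Thread main_type straight through to batch_jaxpr, as the rules above do:
  jaxpr_batched, out_batched = batching.batch_jaxpr(
      call_jaxpr, size, in_batched, False, axis_name, main_type)
  ...  # bind the batched primitive and return (vals_out, dims_out)

# Registration (my_prim_p is hypothetical):
# batching.initial_style_batchers[my_prim_p] = _my_prim_batching_rule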
@@ -408,10 +448,10 @@ def bdim_at_front(x, bdim, size):
   return moveaxis(x, bdim, 0)


-def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name):
+def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, main_type):
   f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
   f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
-  f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched])
+  f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched], main_type)
   avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
               for aval, b in zip(closed_jaxpr.in_avals, in_batched)]
   jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)

@@ -47,7 +47,7 @@ from ..abstract_arrays import array_types
 from ..core import ConcreteArray, ShapedArray
 from .._src.util import (partial, unzip3, prod, safe_map, safe_zip,
                          extend_name_stack, wrap_name, assert_unreachable,
-                         tuple_insert, tuple_delete, curry)
+                         tuple_insert, tuple_delete)
 from ..lib import xla_bridge as xb
 from ..lib import xla_client as xc
 from ..lib import pmap_lib
@@ -1381,10 +1381,10 @@ def mesh_callable(fun: lu.WrappedFun,
   # TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
   for name, size in reversed(mesh.shape.items()):
     if tile_by_mesh_axes:
-      fun = vtile(fun,
-                  tuple(a.get(name, None) for a in in_axes),
-                  tuple(a.get(name, None) for a in out_axes_thunk()),
-                  tile_size=size, axis_name=name)
+      fun = batching.vtile(fun,
+                           tuple(a.get(name, None) for a in in_axes),
+                           tuple(a.get(name, None) for a in out_axes_thunk()),
+                           tile_size=size, axis_name=name)
   global_in_untiled_avals = [untile_aval_nd(global_axis_sizes, aval_in_axes, aval)
                              for aval, aval_in_axes in safe_zip(in_tiled_avals, in_axes)]
   in_jaxpr_avals = global_in_untiled_avals
@@ -1500,37 +1500,6 @@ def compile_and_wrap_mesh_hlo(computation: xc.XlaComputation, backend,
                               handle_outs)
   return partial(execute_replicated, compiled, backend, handle_args, handle_outs)

-# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
-def vtile(f_flat,
-          in_axes_flat: Tuple[Optional[int], ...],
-          out_axes_flat: Tuple[Optional[int], ...],
-          tile_size: Optional[int], axis_name):
-  @curry
-  def tile_axis(arg, axis: Optional[int], tile_size):
-    if axis is None:
-      return arg
-    shape = list(arg.shape)
-    shape[axis:axis+1] = [tile_size, shape[axis] // tile_size]
-    return arg.reshape(shape)
-
-  def untile_axis(out, axis: Optional[int]):
-    if axis is None:
-      return out
-    shape = list(out.shape)
-    shape[axis:axis+2] = [shape[axis] * shape[axis+1]]
-    return out.reshape(shape)
-
-  @lu.transformation
-  def _map_to_tile(*args_flat):
-    sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None)
-    tile_size_ = tile_size or next(sizes, None)
-    assert tile_size_ is not None, "No mapped arguments?"
-    outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {}
-    yield map(untile_axis, outputs_flat, out_axes_flat)
-
-  return _map_to_tile(
-      batching.batch(f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat))
-
 _forbidden_primitives = {
   'xla_pmap': 'pmap',
   'sharded_call': 'sharded_jit',