Add batch_jaxpr2 which tells the caller where batch dims are.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
Authored by Yash Katariya on 2023-01-12 21:16:18 -08:00; committed by jax authors
parent 94f0ccc54a
commit e21c29476d
4 changed files with 50 additions and 18 deletions
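
In brief: the old `batch_jaxpr` required callers to normalize every batched input to axis 0 and only reported *whether* each output was batched, while `batch_jaxpr2` takes and returns actual axis positions. A rough before/after sketch, paraphrased from the `_pjit_batcher` hunks below (illustrative, not a complete excerpt):

# Before: inputs first moved to axis 0; booleans in, booleans out.
is_mapped_in = [d is not batching.not_mapped for d in dims_in]
new_jaxpr, is_mapped_out = batching.batch_jaxpr(
    jaxpr, axis_size, is_mapped_in,
    instantiate=False, axis_name=axis_name, main_type=main_type)

# After: actual axis positions in, actual axis positions out.
new_jaxpr, axes_out = batching.batch_jaxpr2(
    jaxpr, axis_size, dims_in, axis_name=axis_name, main_type=main_type)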

@@ -1321,13 +1321,8 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
                   jaxpr, in_shardings, out_shardings,
                   resource_env, donated_invars, name, in_positional_semantics,
                   out_positional_semantics, keep_unused, inline):
-  # batch_jaxpr expects all batching dimensions to be equal to 0
-  vals_in = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
-             else x for x, d in zip(vals_in, dims_in)]
-  is_mapped_in = [d is not batching.not_mapped for d in dims_in]
-  new_jaxpr, is_mapped_out = batching.batch_jaxpr(
-      jaxpr, axis_size, is_mapped_in,
-      instantiate=False, axis_name=axis_name, main_type=main_type)
+  new_jaxpr, axes_out = batching.batch_jaxpr2(
+      jaxpr, axis_size, dims_in, axis_name=axis_name, main_type=main_type)
   # `insert_axis` is set to True only for some `xmap` uses.
   new_parts = (axis_name,) if insert_axis else (
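
The moveaxis normalization deleted above is observable from the public API: vmapping a pjit-ed function along a non-leading axis used to force a transpose to axis 0 before the jaxpr was batched. A minimal end-to-end sketch of that code path, assuming `pjit` with unspecified shardings behaves like `jit` on a single device:

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit

f = pjit(lambda v: jnp.sin(v))
x = jnp.ones((3, 8))  # batch along axis 1, size 8

# With batch_jaxpr2 the batcher hands dim 1 through directly instead of
# first calling batching.moveaxis(x, 1, 0) on every batched operand.
y = jax.vmap(f, in_axes=1, out_axes=1)(x)
print(y.shape)  # (3, 8)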
@@ -1339,11 +1334,13 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
     mesh = None
   in_shardings = tuple(
-      _pjit_batcher_for_sharding(i, 0, new_parts, mesh, aval.ndim) if is_mapped else i
-      for is_mapped, i, aval in zip(is_mapped_in, in_shardings, new_jaxpr.in_avals))
+      _pjit_batcher_for_sharding(i, axis_in, new_parts, mesh, aval.ndim)
+      if axis_in is not None else i
+      for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals))
   out_shardings = tuple(
-      _pjit_batcher_for_sharding(o, 0, new_parts, mesh, aval.ndim) if is_mapped else o
-      for is_mapped, o, aval in zip(is_mapped_out, out_shardings, new_jaxpr.out_avals))
+      _pjit_batcher_for_sharding(o, axis_out, new_parts, mesh, aval.ndim)
+      if axis_out is not None else o
+      for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals))
   vals_out = pjit_p.bind(
       *vals_in,
       jaxpr=new_jaxpr,
@@ -1356,15 +1353,15 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
       out_positional_semantics=out_positional_semantics,
       keep_unused=keep_unused,
       inline=inline)
-  dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
-  return vals_out, dims_out
+  return vals_out, axes_out

 batching.spmd_axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
 batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
 pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)

 def _pjit_batcher_for_sharding(
-    s: Union[OpShardingSharding, _UnspecifiedValue], dim: int,
-    val: Tuple[str, ...], mesh, ndim: int):
+    s: Union[OpShardingSharding, _UnspecifiedValue],
+    dim: int, val: Tuple[str, ...], mesh, ndim: int):
   if _is_unspecified(s):
     return s
   if not val:
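
Passing the true batch position into `_pjit_batcher_for_sharding` means the batch dimension is spliced into the sharding at index `dim` rather than always at index 0. A conceptual sketch of that splice on a PartitionSpec-like tuple; this is not the real OpSharding manipulation, and `insert_batch_axis` is a hypothetical helper, just the shape of the idea:

def insert_batch_axis(spec, dim, new_parts):
  # spec: a PartitionSpec-like tuple; the batched axis gets new_parts
  # (e.g. (axis_name,) when insert_axis is set, else None/replicated).
  return spec[:dim] + (new_parts if new_parts else None,) + spec[dim:]

print(insert_batch_axis(('x', None), 1, ()))  # ('x', None, None)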

@@ -627,6 +627,31 @@ def reassemble_concat_axes(vals, dims):

 ### API for batching jaxprs

+def batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
+                 axis_size: core.AxisSize,
+                 in_axes: Tuple[Union[int, NotMapped], ...],
+                 axis_name: core.AxisName,
+                 main_type: Type[BatchTrace],
+                 ) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
+  return _batch_jaxpr2(closed_jaxpr, axis_size, tuple(in_axes), axis_name,
+                       main_type)
+
+@weakref_lru_cache
+def _batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
+                  axis_size: core.AxisSize,
+                  in_axes: Tuple[Union[int, NotMapped], ...],
+                  axis_name: core.AxisName,
+                  main_type: Type[BatchTrace],
+                  ) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
+  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
+  f, out_axes = _batch_jaxpr_inner(f, axis_size)
+  f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
+  avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval)
+              if b is not not_mapped else aval
+              for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
+  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
+  return core.ClosedJaxpr(jaxpr_out, consts), out_axes()
+
 def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name,
                 main_type):
   inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate
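
As a usage sketch, the new entry point can be exercised on a hand-made closed jaxpr. The names used (`make_jaxpr`, `BatchTrace`, `no_axis_name`) are real JAX internals of this era, but the exact invocation is illustrative rather than taken from the commit:

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import batching

closed_jaxpr = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.)(jnp.ones(3))

# Batch with axis size 4, the single input batched along axis 0.
batched_jaxpr, out_axes = batching.batch_jaxpr2(
    closed_jaxpr, 4, (0,),
    axis_name=core.no_axis_name, main_type=batching.BatchTrace)
print(out_axes)  # e.g. (0,): the caller now learns where the batch dims sit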
@@ -654,7 +679,8 @@ def batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name,
 def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
                       axis_name, main_type):
   f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
-  f, out_batched = _batch_jaxpr_inner(f, axis_size, out_axes_dest)
+  f, out_axes = _batch_jaxpr_inner(f, axis_size)
+  f, out_batched = _match_axes_jaxpr(f, axis_size, out_axes_dest, out_axes)
   f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
   avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped
               else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
@@ -662,14 +688,21 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
   return core.ClosedJaxpr(jaxpr_out, consts), out_batched()

 @lu.transformation_with_aux
-def _batch_jaxpr_inner(axis_size, out_axes_dest, main, in_axes, *in_vals):
+def _batch_jaxpr_inner(axis_size, main, in_axes, *in_vals):
   trace = main.with_cur_sublevel()
   in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                 for val, dim in zip(in_vals, in_axes)]
   outs = yield in_tracers, {}
   out_tracers = map(trace.full_raise, outs)
   out_vals, out_axes = unzip2((t.val, t.batch_dim) for t in out_tracers)
+  yield out_vals, out_axes
+
+@lu.transformation_with_aux
+def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes,
+                      *in_vals):
+  trace = main.with_cur_sublevel()
+  out_vals = yield (main, in_axes, *in_vals), {}
+  out_axes = out_axes()
+  out_axes_dest = [(None if src is not_mapped else 0)
+                   if dst is zero_if_mapped else dst
+                   for src, dst in unsafe_zip(out_axes, out_axes_dest)]
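
The comprehension at the end of this hunk resolves `zero_if_mapped` destinations: an output whose source axis is mapped goes to axis 0, and an unmapped one stays unbatched (None). A toy, self-contained rendering of just that rule, with stand-in sentinels for the real module-level ones:

not_mapped = None          # stand-in for batching.not_mapped
zero_if_mapped = object()  # stand-in for batching.zero_if_mapped

def resolve(out_axes, out_axes_dest):
  return [(None if src is not_mapped else 0) if dst is zero_if_mapped else dst
          for src, dst in zip(out_axes, out_axes_dest)]

print(resolve([0, not_mapped], [zero_if_mapped, zero_if_mapped]))  # [0, None]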

@@ -60,6 +60,7 @@ jax_test(
 jax_test(
     name = "batching_test",
     srcs = ["batching_test.py"],
+    enable_configs = ["cpu_jit_pjit_api_merge"],
     shard_count = {
         "gpu": 5,
     },

@@ -1309,7 +1309,8 @@ class NamedArray:
   data: Array

   def __init__(self, names, data):
-    assert len(names) == data.ndim
+    # TODO(mattjj): Enable it back after NamedArray is not a pytree.
+    # assert len(names) == data.ndim
     self.names = names
     self.data = data
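
The assert is disabled because this test's NamedArray is registered as a pytree, and vmap reconstructs pytree outputs with batched leaves: the rebuilt `data` carries the extra batch dimension while `names` does not, so `len(names) == data.ndim` fails at the output boundary. A self-contained illustration; the pytree registration here is a plausible stand-in, not the test file's exact code:

import jax
import jax.numpy as jnp

class NamedArray:
  def __init__(self, names, data):
    # assert len(names) == data.ndim  # would fire below: data gains a batch dim
    self.names = names
    self.data = data

jax.tree_util.register_pytree_node(
    NamedArray,
    lambda na: ((na.data,), na.names),              # children, aux data
    lambda names, children: NamedArray(names, *children))

xs = NamedArray(('i',), jnp.ones((4, 3)))           # leading axis 4 is batched
out = jax.vmap(lambda na: NamedArray(na.names, na.data * 2.))(xs)
print(out.data.shape, out.names)  # (4, 3) ('i',) -- ndim 2 vs. one name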