mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Add batch_jaxpr2 which tells the caller where batch dims are.
Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 501746795
This commit is contained in:
parent
94f0ccc54a
commit
e21c29476d
@ -1321,13 +1321,8 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
|
||||
jaxpr, in_shardings, out_shardings,
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics, keep_unused, inline):
|
||||
# batch_jaxpr expects all batching dimensions to be equal to 0
|
||||
vals_in = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
|
||||
else x for x, d in zip(vals_in, dims_in)]
|
||||
is_mapped_in = [d is not batching.not_mapped for d in dims_in]
|
||||
new_jaxpr, is_mapped_out = batching.batch_jaxpr(
|
||||
jaxpr, axis_size, is_mapped_in,
|
||||
instantiate=False, axis_name=axis_name, main_type=main_type)
|
||||
new_jaxpr, axes_out = batching.batch_jaxpr2(
|
||||
jaxpr, axis_size, dims_in, axis_name=axis_name, main_type=main_type)
|
||||
|
||||
# `insert_axis` is set to True only for some `xmap` uses.
|
||||
new_parts = (axis_name,) if insert_axis else (
|
||||
@ -1339,11 +1334,13 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
|
||||
mesh = None
|
||||
|
||||
in_shardings = tuple(
|
||||
_pjit_batcher_for_sharding(i, 0, new_parts, mesh, aval.ndim) if is_mapped else i
|
||||
for is_mapped, i, aval in zip(is_mapped_in, in_shardings, new_jaxpr.in_avals))
|
||||
_pjit_batcher_for_sharding(i, axis_in, new_parts, mesh, aval.ndim)
|
||||
if axis_in is not None else i
|
||||
for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals))
|
||||
out_shardings = tuple(
|
||||
_pjit_batcher_for_sharding(o, 0, new_parts, mesh, aval.ndim) if is_mapped else o
|
||||
for is_mapped, o, aval in zip(is_mapped_out, out_shardings, new_jaxpr.out_avals))
|
||||
_pjit_batcher_for_sharding(o, axis_out, new_parts, mesh, aval.ndim)
|
||||
if axis_out is not None else o
|
||||
for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals))
|
||||
vals_out = pjit_p.bind(
|
||||
*vals_in,
|
||||
jaxpr=new_jaxpr,
|
||||
@ -1356,15 +1353,15 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
|
||||
out_positional_semantics=out_positional_semantics,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
|
||||
return vals_out, dims_out
|
||||
return vals_out, axes_out
|
||||
|
||||
batching.spmd_axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
|
||||
batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
|
||||
pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)
|
||||
|
||||
def _pjit_batcher_for_sharding(
|
||||
s: Union[OpShardingSharding, _UnspecifiedValue], dim: int,
|
||||
val: Tuple[str, ...], mesh, ndim: int):
|
||||
s: Union[OpShardingSharding, _UnspecifiedValue],
|
||||
dim: int, val: Tuple[str, ...], mesh, ndim: int):
|
||||
if _is_unspecified(s):
|
||||
return s
|
||||
if not val:
|
||||
|
@ -627,6 +627,31 @@ def reassemble_concat_axes(vals, dims):
|
||||
|
||||
### API for batching jaxprs
|
||||
|
||||
def batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
|
||||
axis_size: core.AxisSize,
|
||||
in_axes: Tuple[Union[int, NotMapped], ...],
|
||||
axis_name: core.AxisName,
|
||||
main_type: Type[BatchTrace],
|
||||
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
|
||||
return _batch_jaxpr2(closed_jaxpr, axis_size, tuple(in_axes), axis_name,
|
||||
main_type)
|
||||
|
||||
@weakref_lru_cache
|
||||
def _batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
|
||||
axis_size: core.AxisSize,
|
||||
in_axes: Tuple[Union[int, NotMapped], ...],
|
||||
axis_name: core.AxisName,
|
||||
main_type: Type[BatchTrace],
|
||||
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
|
||||
f, out_axes = _batch_jaxpr_inner(f, axis_size)
|
||||
f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
|
||||
avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval)
|
||||
if b is not not_mapped else aval
|
||||
for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
|
||||
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
||||
return core.ClosedJaxpr(jaxpr_out, consts), out_axes()
|
||||
|
||||
def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name,
|
||||
main_type):
|
||||
inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate
|
||||
@ -654,7 +679,8 @@ def batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name,
|
||||
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
|
||||
axis_name, main_type):
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
|
||||
f, out_batched = _batch_jaxpr_inner(f, axis_size, out_axes_dest)
|
||||
f, out_axes = _batch_jaxpr_inner(f, axis_size)
|
||||
f, out_batched = _match_axes_jaxpr(f, axis_size, out_axes_dest, out_axes)
|
||||
f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
|
||||
avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped
|
||||
else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
|
||||
@ -662,14 +688,21 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
|
||||
return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _batch_jaxpr_inner(axis_size, out_axes_dest, main, in_axes, *in_vals):
|
||||
def _batch_jaxpr_inner(axis_size, main, in_axes, *in_vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
|
||||
for val, dim in zip(in_vals, in_axes)]
|
||||
outs = yield in_tracers, {}
|
||||
out_tracers = map(trace.full_raise, outs)
|
||||
out_vals, out_axes = unzip2((t.val, t.batch_dim) for t in out_tracers)
|
||||
yield out_vals, out_axes
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes,
|
||||
*in_vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
out_vals = yield (main, in_axes, *in_vals), {}
|
||||
out_axes = out_axes()
|
||||
out_axes_dest = [(None if src is not_mapped else 0)
|
||||
if dst is zero_if_mapped else dst
|
||||
for src, dst in unsafe_zip(out_axes, out_axes_dest)]
|
||||
|
@ -60,6 +60,7 @@ jax_test(
|
||||
jax_test(
|
||||
name = "batching_test",
|
||||
srcs = ["batching_test.py"],
|
||||
enable_configs = ["cpu_jit_pjit_api_merge"],
|
||||
shard_count = {
|
||||
"gpu": 5,
|
||||
},
|
||||
|
@ -1309,7 +1309,8 @@ class NamedArray:
|
||||
data: Array
|
||||
|
||||
def __init__(self, names, data):
|
||||
assert len(names) == data.ndim
|
||||
# TODO(mattjj): Enable it back after NamedArray is not a pytree.
|
||||
# assert len(names) == data.ndim
|
||||
self.names = names
|
||||
self.data = data
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user