Merge pull request #14418 from mattjj:vmap-spmd-axis-name-tuples

PiperOrigin-RevId: 508777043
This commit is contained in:
jax authors 2023-02-10 16:08:32 -08:00
commit fc507f2ebe
5 changed files with 42 additions and 21 deletions

View File

@ -101,7 +101,7 @@ traceback_util.register_exclusion(__file__)
_dtype = partial(dtypes.dtype, canonicalize=True)
AxisName = Any
AxisName = Hashable
Device = xc.Device
@ -1228,7 +1228,7 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
check_callable(fun)
argnums = core.concrete_or_error(_ensure_index, argnums)
reduce_axes = _ensure_str_tuple(reduce_axes)
reduce_axes = _ensure_str_tuple(reduce_axes) # type: ignore
@wraps(fun, docstr=docstr, argnums=argnums)
@api_boundary
@ -1593,9 +1593,10 @@ def _split(x, indices, axis):
def vmap(fun: F,
in_axes: Union[int, Sequence[Any]] = 0,
out_axes: Any = 0,
axis_name: Optional[Hashable] = None,
axis_name: Optional[AxisName] = None,
axis_size: Optional[int] = None,
spmd_axis_name: Optional[Hashable] = None) -> F:
spmd_axis_name: Optional[Union[AxisName, Tuple[AxisName, ...]]] = None
) -> F:
"""Vectorizing map. Creates a function which maps ``fun`` over argument axes.
Args:
@ -1733,6 +1734,8 @@ def vmap(fun: F,
docstr += fun.__doc__
axis_name = core.no_axis_name if axis_name is None else axis_name
if spmd_axis_name is not None and type(spmd_axis_name) is not tuple:
spmd_axis_name = (spmd_axis_name,)
if isinstance(in_axes, list):
# To be a tree prefix of the positional args tuple, in_axes can never be a
@ -2773,7 +2776,7 @@ def vjp( # type: ignore
-0.2524413
"""
check_callable(fun)
reduce_axes = _ensure_str_tuple(reduce_axes)
reduce_axes = _ensure_str_tuple(reduce_axes) # type: ignore
return _vjp(
lu.wrap_init(fun), *primals, has_aux=has_aux, reduce_axes=reduce_axes)

View File

@ -16,8 +16,8 @@ from __future__ import annotations
import collections
import dataclasses
from functools import partial
from typing import (Any, Callable, Dict, Hashable, Iterable, Optional, Sequence,
Set, Tuple, Type, Union)
from typing import (Any, Callable, Dict, Iterable, Optional, Sequence, Set,
Tuple, Type, Union)
import numpy as np
@ -25,7 +25,7 @@ import jax
from jax.config import config
from jax._src import core
from jax._src import source_info_util
from jax._src.core import raise_to_shaped, Trace, Tracer
from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node)
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
@ -111,7 +111,7 @@ class ConcatAxis:
def _update_annotation(
f: lu.WrappedFun, orig_type: Optional[core.InputType],
axis_size: core.AxisSize, axis_name: core.AxisName,
axis_size: core.AxisSize, axis_name: AxisName,
explicit_in_dims: Sequence[Optional[Union[int, ConcatAxis]]],
segment_lens: Sequence[Array],
) -> lu.WrappedFun:
@ -479,7 +479,8 @@ class BatchTrace(Trace):
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
out_dims2, in_dims, self.main.trace_type, self.spmd_axis_name)
out_dims2, in_dims, self.main.trace_type,
self.spmd_axis_name)
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
if not fst:
@ -516,7 +517,7 @@ class BatchTrace(Trace):
return vals, todo, bwd_transform
def _main_trace_for_axis_names(main_trace: core.MainTrace,
axis_name: Iterable[core.AxisName],
axis_name: Iterable[AxisName],
) -> bool:
# This function exists to identify whether a main trace corresponds to any of
# the axis names used by a primitive. Axis names alone aren't enough because
@ -525,9 +526,10 @@ def _main_trace_for_axis_names(main_trace: core.MainTrace,
### API for batching callables with vmappable inputs and outputs
def batch(fun: lu.WrappedFun, axis_name: core.AxisName, axis_size,
def batch(fun: lu.WrappedFun, axis_name: AxisName, axis_size,
in_dims, out_dim_dests, main_type: Type[BatchTrace] = BatchTrace,
spmd_axis_name: Optional[Hashable] = None) -> lu.WrappedFun:
spmd_axis_name: Optional[Tuple[AxisName, ...]] = None
) -> lu.WrappedFun:
# we split up _batch_inner and _batch_outer for the leak checker
f = _batch_inner(fun, axis_size, out_dim_dests)
return _batch_outer(f, axis_name, axis_size, in_dims, main_type,
@ -561,7 +563,7 @@ def vtile(f_flat: lu.WrappedFun,
in_axes_flat: Tuple[Optional[int], ...],
out_axes_flat: Tuple[Optional[int], ...],
tile_size: Optional[int],
axis_name: core.AxisName,
axis_name: AxisName,
main_type: Type[BatchTrace] = BatchTrace):
@curry
def tile_axis(arg, axis: Optional[int], tile_size):
@ -630,7 +632,7 @@ def reassemble_concat_axes(vals, dims):
def batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
axis_size: core.AxisSize,
in_axes: Tuple[Union[int, NotMapped], ...],
axis_name: core.AxisName,
axis_name: AxisName,
main_type: Type[BatchTrace],
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
return _batch_jaxpr2(closed_jaxpr, axis_size, tuple(in_axes), axis_name,
@ -640,7 +642,7 @@ def batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
def _batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
axis_size: core.AxisSize,
in_axes: Tuple[Union[int, NotMapped], ...],
axis_name: core.AxisName,
axis_name: AxisName,
main_type: Type[BatchTrace],
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
@ -763,7 +765,8 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name):
bwd, out_dims_thunk = batch_subtrace(bwd)
bwd_ = _batch_outer(bwd, axis_name, axis_size, in_dims, main_type, spmd_axis_name)
bwd_ = _batch_outer(bwd, axis_name, axis_size, in_dims, main_type,
spmd_axis_name)
return _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests)
@lu.transformation

View File

@ -1466,7 +1466,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
# `insert_axis` is set to True only for some `xmap` uses.
new_parts = (axis_name,) if insert_axis else (
() if spmd_axis_name is None else (spmd_axis_name,))
() if spmd_axis_name is None else spmd_axis_name)
if resource_env is not None:
mesh = resource_env.physical_mesh
@ -2017,7 +2017,7 @@ def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,
d, = dims_in
# None means unconstrained in ParsedPartitionSpec
new_parts = (axis_name,) if insert_axis else (
None if spmd_axis_name is None else (spmd_axis_name,))
None if spmd_axis_name is None else spmd_axis_name)
unconstrained_dims = {ud + (d <= ud) for ud in unconstrained_dims}
if new_parts is None:
unconstrained_dims.add(d)

View File

@ -709,7 +709,7 @@ def _shard_map_batch(
for ax in names} for names, d in zip(in_names, in_dims)]
spmd_axis_name = trace.spmd_axis_name
if spmd_axis_name is not None:
new_in_names = [{**ns, d:(spmd_axis_name,)} if d is not batching.not_mapped # type: ignore
new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped # type: ignore
else ns for ns, d in zip(new_in_names, in_dims)]
@as_hashable_function(closure=out_names_thunk)
def new_out_names_thunk():
@ -717,7 +717,7 @@ def _shard_map_batch(
out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
for ax in names} for names, d in zip(out_names, out_dims())]
if spmd_axis_name is not None:
out_names_ = [{**ns, d:(spmd_axis_name,)} if d is not batching.not_mapped
out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
else ns for ns, d in zip(out_names_, out_dims())]
return out_names_

View File

@ -457,6 +457,21 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertIn('out_names', e.params)
self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},))
def test_vmap_spmd_axis_name_pair(self):
    """vmap with a tuple-valued spmd_axis_name should thread both mesh axes
    into the shard_map equation's in_names/out_names params."""
    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
    def identity(v):
        return v

    batch = jnp.arange(4 * 4).reshape(4, 4)
    jaxpr = jax.make_jaxpr(
        jax.vmap(identity, spmd_axis_name=('x', 'y')))(batch).jaxpr
    eqn, = jaxpr.eqns
    # Both axis names from the spmd_axis_name pair must appear on dim 0.
    self.assertIn('in_names', eqn.params)
    self.assertEqual(eqn.params['in_names'], ({0: ('x', 'y',)},))
    self.assertIn('out_names', eqn.params)
    self.assertEqual(eqn.params['out_names'], ({0: ('x', 'y',)},))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())