Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Merge pull request #14418 from mattjj:vmap-spmd-axis-name-tuples
PiperOrigin-RevId: 508777043
This commit is contained in:
commit
fc507f2ebe
@ -101,7 +101,7 @@ traceback_util.register_exclusion(__file__)
|
||||
|
||||
_dtype = partial(dtypes.dtype, canonicalize=True)
|
||||
|
||||
AxisName = Any
|
||||
AxisName = Hashable
|
||||
|
||||
Device = xc.Device
|
||||
|
||||
@ -1228,7 +1228,7 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
|
||||
check_callable(fun)
|
||||
argnums = core.concrete_or_error(_ensure_index, argnums)
|
||||
reduce_axes = _ensure_str_tuple(reduce_axes)
|
||||
reduce_axes = _ensure_str_tuple(reduce_axes) # type: ignore
|
||||
|
||||
@wraps(fun, docstr=docstr, argnums=argnums)
|
||||
@api_boundary
|
||||
@ -1593,9 +1593,10 @@ def _split(x, indices, axis):
|
||||
def vmap(fun: F,
|
||||
in_axes: Union[int, Sequence[Any]] = 0,
|
||||
out_axes: Any = 0,
|
||||
axis_name: Optional[Hashable] = None,
|
||||
axis_name: Optional[AxisName] = None,
|
||||
axis_size: Optional[int] = None,
|
||||
spmd_axis_name: Optional[Hashable] = None) -> F:
|
||||
spmd_axis_name: Optional[Union[AxisName, Tuple[AxisName, ...]]] = None
|
||||
) -> F:
|
||||
"""Vectorizing map. Creates a function which maps ``fun`` over argument axes.
|
||||
|
||||
Args:
|
||||
@ -1733,6 +1734,8 @@ def vmap(fun: F,
|
||||
docstr += fun.__doc__
|
||||
|
||||
axis_name = core.no_axis_name if axis_name is None else axis_name
|
||||
if spmd_axis_name is not None and type(spmd_axis_name) is not tuple:
|
||||
spmd_axis_name = (spmd_axis_name,)
|
||||
|
||||
if isinstance(in_axes, list):
|
||||
# To be a tree prefix of the positional args tuple, in_axes can never be a
|
||||
@ -2773,7 +2776,7 @@ def vjp( # type: ignore
|
||||
-0.2524413
|
||||
"""
|
||||
check_callable(fun)
|
||||
reduce_axes = _ensure_str_tuple(reduce_axes)
|
||||
reduce_axes = _ensure_str_tuple(reduce_axes) # type: ignore
|
||||
return _vjp(
|
||||
lu.wrap_init(fun), *primals, has_aux=has_aux, reduce_axes=reduce_axes)
|
||||
|
||||
|
@ -16,8 +16,8 @@ from __future__ import annotations
|
||||
import collections
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
from typing import (Any, Callable, Dict, Hashable, Iterable, Optional, Sequence,
|
||||
Set, Tuple, Type, Union)
|
||||
from typing import (Any, Callable, Dict, Iterable, Optional, Sequence, Set,
|
||||
Tuple, Type, Union)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -25,7 +25,7 @@ import jax
|
||||
from jax.config import config
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src.core import raise_to_shaped, Trace, Tracer
|
||||
from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
|
||||
from jax._src.tree_util import (tree_unflatten, tree_flatten,
|
||||
register_pytree_node)
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
@ -111,7 +111,7 @@ class ConcatAxis:
|
||||
|
||||
def _update_annotation(
|
||||
f: lu.WrappedFun, orig_type: Optional[core.InputType],
|
||||
axis_size: core.AxisSize, axis_name: core.AxisName,
|
||||
axis_size: core.AxisSize, axis_name: AxisName,
|
||||
explicit_in_dims: Sequence[Optional[Union[int, ConcatAxis]]],
|
||||
segment_lens: Sequence[Array],
|
||||
) -> lu.WrappedFun:
|
||||
@ -479,7 +479,8 @@ class BatchTrace(Trace):
|
||||
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
|
||||
fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
|
||||
bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
|
||||
out_dims2, in_dims, self.main.trace_type, self.spmd_axis_name)
|
||||
out_dims2, in_dims, self.main.trace_type,
|
||||
self.spmd_axis_name)
|
||||
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
|
||||
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
|
||||
if not fst:
|
||||
@ -516,7 +517,7 @@ class BatchTrace(Trace):
|
||||
return vals, todo, bwd_transform
|
||||
|
||||
def _main_trace_for_axis_names(main_trace: core.MainTrace,
|
||||
axis_name: Iterable[core.AxisName],
|
||||
axis_name: Iterable[AxisName],
|
||||
) -> bool:
|
||||
# This function exists to identify whether a main trace corresponds to any of
|
||||
# the axis names used by a primitive. Axis names alone aren't enough because
|
||||
@ -525,9 +526,10 @@ def _main_trace_for_axis_names(main_trace: core.MainTrace,
|
||||
|
||||
### API for batching callables with vmappable inputs and outputs
|
||||
|
||||
def batch(fun: lu.WrappedFun, axis_name: core.AxisName, axis_size,
|
||||
def batch(fun: lu.WrappedFun, axis_name: AxisName, axis_size,
|
||||
in_dims, out_dim_dests, main_type: Type[BatchTrace] = BatchTrace,
|
||||
spmd_axis_name: Optional[Hashable] = None) -> lu.WrappedFun:
|
||||
spmd_axis_name: Optional[Tuple[AxisName, ...]] = None
|
||||
) -> lu.WrappedFun:
|
||||
# we split up _batch_inner and _batch_outer for the leak checker
|
||||
f = _batch_inner(fun, axis_size, out_dim_dests)
|
||||
return _batch_outer(f, axis_name, axis_size, in_dims, main_type,
|
||||
@ -561,7 +563,7 @@ def vtile(f_flat: lu.WrappedFun,
|
||||
in_axes_flat: Tuple[Optional[int], ...],
|
||||
out_axes_flat: Tuple[Optional[int], ...],
|
||||
tile_size: Optional[int],
|
||||
axis_name: core.AxisName,
|
||||
axis_name: AxisName,
|
||||
main_type: Type[BatchTrace] = BatchTrace):
|
||||
@curry
|
||||
def tile_axis(arg, axis: Optional[int], tile_size):
|
||||
@ -630,7 +632,7 @@ def reassemble_concat_axes(vals, dims):
|
||||
def batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
|
||||
axis_size: core.AxisSize,
|
||||
in_axes: Tuple[Union[int, NotMapped], ...],
|
||||
axis_name: core.AxisName,
|
||||
axis_name: AxisName,
|
||||
main_type: Type[BatchTrace],
|
||||
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
|
||||
return _batch_jaxpr2(closed_jaxpr, axis_size, tuple(in_axes), axis_name,
|
||||
@ -640,7 +642,7 @@ def batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
|
||||
def _batch_jaxpr2(closed_jaxpr: core.ClosedJaxpr,
|
||||
axis_size: core.AxisSize,
|
||||
in_axes: Tuple[Union[int, NotMapped], ...],
|
||||
axis_name: core.AxisName,
|
||||
axis_name: AxisName,
|
||||
main_type: Type[BatchTrace],
|
||||
) -> Tuple[core.ClosedJaxpr, Tuple[Union[int, NotMapped], ...]]:
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
|
||||
@ -763,7 +765,8 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
|
||||
|
||||
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name):
  """Wrap ``bwd`` in the batching transformations and match its output axes.

  ``bwd`` is first passed through ``batch_subtrace``, which yields the
  transformed function together with a thunk for its output dims. That
  function is then wrapped by ``_batch_outer`` and finally handed to
  ``_match_axes_and_sum``.

  Args:
    bwd: the function to transform (passed to ``batch_subtrace``).
    axis_name: axis name forwarded to ``_batch_outer`` and
      ``_match_axes_and_sum``.
    axis_size: axis size forwarded to ``_batch_outer`` and
      ``_match_axes_and_sum``.
    in_dims: input batch dims forwarded to ``_batch_outer``.
    out_dim_dests: destination dims forwarded to ``_match_axes_and_sum``.
    main_type: trace type forwarded to ``_batch_outer``.
    spmd_axis_name: forwarded unchanged to ``_batch_outer``.
      NOTE(review): presumably ``None`` or a tuple of axis names, matching the
      ``spmd_axis_name`` annotation on ``batch`` -- confirm against callers.

  Returns:
    Whatever ``_match_axes_and_sum`` returns for the wrapped function.
  """
  bwd, out_dims_thunk = batch_subtrace(bwd)
  # Single wrapped call: the earlier one-line form of this assignment was a
  # duplicate of the statement below and has been dropped.
  bwd_ = _batch_outer(bwd, axis_name, axis_size, in_dims, main_type,
                      spmd_axis_name)
  return _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests)
|
||||
|
||||
@lu.transformation
|
||||
|
@ -1466,7 +1466,7 @@ def _pjit_batcher(insert_axis, spmd_axis_name,
|
||||
|
||||
# `insert_axis` is set to True only for some `xmap` uses.
|
||||
new_parts = (axis_name,) if insert_axis else (
|
||||
() if spmd_axis_name is None else (spmd_axis_name,))
|
||||
() if spmd_axis_name is None else spmd_axis_name)
|
||||
|
||||
if resource_env is not None:
|
||||
mesh = resource_env.physical_mesh
|
||||
@ -2017,7 +2017,7 @@ def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,
|
||||
d, = dims_in
|
||||
# None means unconstrained in ParsedPartitionSpec
|
||||
new_parts = (axis_name,) if insert_axis else (
|
||||
None if spmd_axis_name is None else (spmd_axis_name,))
|
||||
None if spmd_axis_name is None else spmd_axis_name)
|
||||
unconstrained_dims = {ud + (d <= ud) for ud in unconstrained_dims}
|
||||
if new_parts is None:
|
||||
unconstrained_dims.add(d)
|
||||
|
@ -709,7 +709,7 @@ def _shard_map_batch(
|
||||
for ax in names} for names, d in zip(in_names, in_dims)]
|
||||
spmd_axis_name = trace.spmd_axis_name
|
||||
if spmd_axis_name is not None:
|
||||
new_in_names = [{**ns, d:(spmd_axis_name,)} if d is not batching.not_mapped # type: ignore
|
||||
new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped # type: ignore
|
||||
else ns for ns, d in zip(new_in_names, in_dims)]
|
||||
@as_hashable_function(closure=out_names_thunk)
|
||||
def new_out_names_thunk():
|
||||
@ -717,7 +717,7 @@ def _shard_map_batch(
|
||||
out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]
|
||||
for ax in names} for names, d in zip(out_names, out_dims())]
|
||||
if spmd_axis_name is not None:
|
||||
out_names_ = [{**ns, d:(spmd_axis_name,)} if d is not batching.not_mapped
|
||||
out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
|
||||
else ns for ns, d in zip(out_names_, out_dims())]
|
||||
return out_names_
|
||||
|
||||
|
@ -457,6 +457,21 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertIn('out_names', e.params)
|
||||
self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},))
|
||||
|
||||
def test_vmap_spmd_axis_name_pair(self):
  """Check that a tuple ``spmd_axis_name`` reaches the shard_map eqn intact.

  Builds a 2x2 mesh with axes ``('x', 'y')``, vmaps an identity ``shard_map``
  with ``spmd_axis_name=('x', 'y')``, and asserts that the single resulting
  equation maps dimension 0 to ``('x', 'y')`` in both ``in_names`` and
  ``out_names``.
  """
  device_grid = np.array(jax.devices()[:4]).reshape(2, 2)
  mesh = Mesh(device_grid, ('x', 'y'))

  def identity(x):
    return x

  f = shard_map(identity, mesh=mesh, in_specs=P(), out_specs=P())

  x = jnp.arange(4 * 4).reshape(4, 4)
  batched = jax.vmap(f, spmd_axis_name=('x', 'y'))
  traced = jax.make_jaxpr(batched)(x)
  # Exactly one equation is expected; unpacking fails otherwise.
  e, = traced.jaxpr.eqns

  expected_names = ({0: ('x', 'y',)},)
  # Same order as the checks were originally written: inputs, then outputs.
  for param in ('in_names', 'out_names'):
    self.assertIn(param, e.params)
    self.assertEqual(e.params[param], expected_names)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user