Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 20:06:05 +00:00)
Add support for binding axis_name in gmap
This allows executing collectives over the gmapped axes. It requires some extra manipulation of the gmapped jaxpr, because gmap exposes a single logical axis name but evaluates the program using multiple "physical" axes. This change also fixes some bugs around handling `multiple_results` in the vmap collective implementation.
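For concreteness, a minimal usage sketch of what this enables, modeled on the testAxisName test added below (the single ('vectorized', None) schedule is one of the default schedules those tests exercise, and omnistaging must be enabled):

import jax.numpy as jnp
from jax import lax, vmap
from jax.experimental.general_map import gmap

def f(x):
  return x - lax.psum(x, 'i')  # collective over the gmapped axis

x = jnp.arange(8)
# With axis_name bound, gmap agrees with vmap over the same named axis.
out = gmap(f, [('vectorized', None)], axis_name='i')(x)
ref = vmap(f, axis_name='i')(x)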
Parent: e95d5701e3
Commit: 7210d6f5d0
@@ -1579,7 +1579,8 @@ def omnistaging_enabler() -> None:
     try:
       yield
     finally:
-      thread_local_state.trace_state.axis_env.pop()
+      frame_ = thread_local_state.trace_state.axis_env.pop()
+      assert frame is frame_  # Only runs if there was no exception
 
   def axis_frame(axis_name):
     frames = thread_local_state.trace_state.axis_env
@@ -30,10 +30,7 @@ from ..interpreters import partial_eval as pe
 def gmap(fun: Callable, schedule, axis_name = None) -> Callable:
   warn("gmap is an experimental feature and probably has bugs!")
   _check_callable(fun)
 
-  if axis_name is not None:
-    raise ValueError("gmap doesn't support binding axis names yet")
-
+  binds_axis_name = axis_name is not None
+  axis_name = _TempAxisName(fun) if axis_name is None else axis_name
 
   @wraps(fun)
@@ -42,6 +39,7 @@ def gmap(fun: Callable, schedule, axis_name = None) -> Callable:
     args_flat, in_tree = tree_flatten((args, kwargs))
     mapped_invars = (True,) * len(args_flat)
     axis_size = _mapped_axis_size(in_tree, args_flat, (0,) * len(args_flat), "gmap")
+    parsed_schedule = _normalize_schedule(schedule, axis_size, binds_axis_name)
     for arg in args_flat: _check_arg(arg)
     flat_fun, out_tree = flatten_fun(f, in_tree)
     outs = gmap_p.bind(
@@ -49,7 +47,8 @@ def gmap(fun: Callable, schedule, axis_name = None) -> Callable:
         axis_name=axis_name,
         axis_size=axis_size,
         mapped_invars=mapped_invars,
-        schedule=tuple(schedule))
+        schedule=parsed_schedule,
+        binds_axis_name=binds_axis_name)
     return tree_unflatten(out_tree(), outs)
   return f_gmapped
 
@@ -62,20 +61,10 @@ class LoopType(enum.Enum):
 Loop = namedtuple('Loop', ['type', 'size'])
 
 
-def gmap_impl(fun: lu.WrappedFun, *args, axis_size, axis_name, mapped_invars, schedule):
-  avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
-  scheduled_fun = _apply_schedule(fun, axis_size, mapped_invars, schedule, *avals)
-  return scheduled_fun(*args)
-
-
-def _parse_name(name):
-  if isinstance(name, LoopType):
-    return name
-  try:
-    return LoopType[name]
-  except KeyError as err:
-    raise ValueError(f"Unrecognized loop type: {name}") from err
-
-
-def _normalize_schedule(schedule, axis_size):
+def _normalize_schedule(schedule, axis_size, binds_axis_name):
+  if not schedule:
+    raise ValueError("gmap expects a non-empty schedule")
+
   scheduled = 1
   seen_none = False
   for loop in schedule:
@@ -92,31 +81,69 @@ def _normalize_schedule(schedule, axis_size):
     loop_type = _parse_name(loop[0])
     if loop_type is LoopType.vectorized and i < len(schedule) - 1:
       raise ValueError("vectorized loops can only appear as the last component of the schedule")
+    if loop_type is LoopType.sequential and binds_axis_name:
+      raise ValueError("gmaps that bind a new axis name cannot have sequential components in the schedule")
     new_schedule.append(Loop(loop_type, loop[1] or unscheduled))
-  return new_schedule
+  return tuple(new_schedule)
+
+
+def _parse_name(name):
+  if isinstance(name, LoopType):
+    return name
+  try:
+    return LoopType[name]
+  except KeyError as err:
+    raise ValueError(f"Unrecognized loop type: {name}") from err
+
+
+def gmap_impl(fun: lu.WrappedFun, *args, axis_size, axis_name, binds_axis_name, mapped_invars, schedule):
+  avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
+  scheduled_fun = _apply_schedule(fun, axis_size, axis_name, binds_axis_name,
+                                  mapped_invars, schedule, *avals)
+  return scheduled_fun(*args)
+
+
+class _GMapSubaxis:
+  def __init__(self, axis_name, index):
+    self.axis_name = axis_name
+    self.index = index
+  def __repr__(self):
+    return f'<subaxis {self.index} of {self.axis_name}>'
+  def __hash__(self):
+    return hash((self.axis_name, self.index))
+  def __eq__(self, other):
+    return (isinstance(other, _GMapSubaxis) and
+            self.axis_name == other.axis_name and
+            self.index == other.index)
+
 
 @lu.cache
-def _apply_schedule(fun: lu.WrappedFun, axis_size, mapped_invars, schedule, *avals):
+def _apply_schedule(fun: lu.WrappedFun,
+                    axis_size, full_axis_name, binds_axis_name,
+                    mapped_invars,
+                    schedule,
+                    *avals):
   assert all(mapped_invars)
   mapped_avals = [core.mapped_aval(axis_size, aval) if mapped else aval
                   for mapped, aval in zip(mapped_invars, avals)]
-  jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_avals)
+  with core.extend_axis_env(full_axis_name, axis_size, None):
+    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_avals)
 
-  schedule = _normalize_schedule(schedule, axis_size)
-  dim_sizes = tuple(loop.size for loop in schedule)
+  axis_names = tuple(_GMapSubaxis(full_axis_name, i) for i in range(len(schedule)))
+  if binds_axis_name:
+    jaxpr = subst_axis_names(jaxpr, full_axis_name, axis_names)
 
   sched_fun = lambda *args: core.eval_jaxpr(jaxpr, consts, *args)
 
-  if schedule[-1].type is LoopType.vectorized:
-    sched_fun = jax.vmap(sched_fun)
-    nonvector_schedule = schedule[:-1]
-  else:
-    nonvector_schedule = schedule
-  for (ltype, size) in nonvector_schedule[::-1]:
-    if ltype is LoopType.parallel:
-      sched_fun = jax.pmap(sched_fun)
+  for (ltype, size), axis_name in list(zip(schedule, axis_names))[::-1]:
+    if ltype is LoopType.vectorized:
+      sched_fun = jax.vmap(sched_fun, axis_name=axis_name)
+    elif ltype is LoopType.parallel:
+      sched_fun = jax.pmap(sched_fun, axis_name=axis_name)
+    elif ltype is LoopType.sequential:
+      if binds_axis_name:
+        raise NotImplementedError("gmaps with sequential components of the schedule don't support "
+                                  "collectives yet. Please open a feature request!")
+      assert not binds_axis_name
       sched_fun = lambda *args, sched_fun=sched_fun: jax.lax.map(lambda xs: sched_fun(*xs), args)
 
+  dim_sizes = tuple(loop.size for loop in schedule)
   def sched_fun_wrapper(*args):
     split_args = [arg.reshape(dim_sizes + arg.shape[1:]) for arg in args]
     results = sched_fun(*split_args)
@@ -125,3 +152,19 @@ def _apply_schedule(fun: lu.WrappedFun, axis_size, mapped_invars, schedule, *avals):
 
 gmap_p = core.MapPrimitive('gmap')
 gmap_p.def_impl(gmap_impl)
+
+
+def subst_axis_names(jaxpr, replaced_name, axis_names):
+  eqns = [subst_eqn_axis_names(eqn, replaced_name, axis_names) for eqn in jaxpr.eqns]
+  return core.Jaxpr(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, eqns)
+
+def subst_eqn_axis_names(eqn, replaced_name, axis_names):
+  if isinstance(eqn.primitive, (core.CallPrimitive, core.MapPrimitive)):
+    if eqn.params.get('axis_name', None) == replaced_name:  # Check for shadowing
+      return eqn
+    new_call_jaxpr = subst_axis_names(eqn.params['call_jaxpr'], replaced_name, axis_names)
+    return eqn._replace(params=dict(eqn.params, call_jaxpr=new_call_jaxpr))
+  elif eqn.params.get('axis_name', None) == replaced_name:
+    return eqn._replace(params=dict(eqn.params, axis_name=axis_names))
+  else:
+    return eqn
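The substitution above rewrites each collective's axis_name parameter from the single logical gmap axis to the tuple of physical subaxes. As an illustration of why that is sound (this analogy uses nested vmap rather than gmap and is not part of the commit): a psum over one named axis is equivalent to a psum over a tuple of axes that jointly partition it.

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(8.)
# One logical axis of size 8 ...
one_axis = jax.vmap(lambda v: lax.psum(v, 'i'), axis_name='i')(x)
# ... versus the same data split over two "physical" axes of sizes 2 and 4,
# with the collective taking both axis names at once.
two_axes = jax.vmap(jax.vmap(lambda v: lax.psum(v, ('a', 'b')), axis_name='b'),
                    axis_name='a')(x.reshape(2, 4))
assert jnp.allclose(one_axis, two_axes.ravel())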
@@ -150,7 +150,9 @@ class BatchTrace(Trace):
       if len(axis_names) > 1:
         return split_axis(primitive, axis_name, tracers, params)
       vals_out, dims_out = collective_rules[primitive](vals_in, dims_in, frame.size, **params)
-      return map(partial(BatchTracer, self), vals_out, dims_out)
+      results = map(partial(BatchTracer, self), vals_out, dims_out)
+      print(results)
+      return results if primitive.multiple_results else results[0]
     # TODO(mattjj,phawkins): if no rule implemented, could vmap-via-map here
     batched_primitive = get_primitive_batcher(primitive)
     val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
@@ -326,12 +326,14 @@ def _split_axis_comm_assoc(primitive, split_name, args, params):
   split_params = dict(params, axis_name=split_name)
   remain_params = dict(params, axis_name=remaining_axes)
   split_result = primitive.bind(*args, **split_params)
+  if not primitive.multiple_results:
+    split_result = (split_result,)
   return primitive.bind(*split_result, **remain_params)
 
 # NB: This is only used for collectives that do not include the vmapped axis name,
 # which is why the rule is so simple. All other collectives go through split_axis.
 def _collective_batcher(prim, args, dims, **params):
-  return prim.bind(*args, **params), dims
+  return prim.bind(*args, **params), dims if prim.multiple_results else dims[0]
 
 def _batched_reduction_collective(prim, if_mapped, if_unmapped,
                                   vals_in, dims_in, axis_size,
@@ -984,12 +984,12 @@ class BatchingTest(jtu.JaxTestCase):
        "seq": seq}
       for collective, seq in [(lax.psum, jnp.sum),
                               (lax.pmean, jnp.mean),
-                              (lambda x, n: lax.pmax(x, n)[0], jnp.max),
-                              (lambda x, n: lax.pmin(x, n)[0], jnp.min)])
+                              (lambda x, n: lax.pmax(x, n), jnp.max),
+                              (lambda x, n: lax.pmin(x, n), jnp.min)])
   @skipIf(not jax.config.omnistaging_enabled,
           "vmap collectives only supported when omnistaging is enabled")
   def testCollective(self, collective, seq):
-    x = jnp.arange(1000).reshape((10, 10, 10))
+    x = jnp.arange(64).reshape((4, 4, 4))
     self.assertAllClose(
       vmap(lambda x: x - collective(x, 'i'), axis_name='i')(x),
       x - seq(x, axis=0))
@@ -1015,7 +1015,7 @@ class BatchingTest(jtu.JaxTestCase):
     perm_pairs = np.stack([np.arange(nelem), perm], axis=-1)
     rng.shuffle(perm_pairs)
     self.assertAllClose(
-      vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs)[0], axis_name='i')(x),
+      vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs), axis_name='i')(x),
       x - x[perm])
 
 
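The two test hunks above drop the [0] indexing because single-result collectives now come back from the vmap collective rules as a single value rather than a one-element list. A hedged sketch of the post-fix behavior (assuming a JAX version with vmap collectives enabled; the pre-fix behavior is inferred from the removed indexing):

import jax.numpy as jnp
from jax import lax, vmap

x = jnp.arange(4.)
# pmax has a single result, so under vmap it should return one array
# (no [0] needed to unwrap it).
out = vmap(lambda v: lax.pmax(v, 'i'), axis_name='i')(x)
assert out.shape == x.shape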
@@ -27,8 +27,10 @@ from absl.testing import parameterized
 import jax.numpy as jnp
 from jax import test_util as jtu
 from jax import vmap
+from jax import lax
 from jax.experimental.general_map import gmap
 from jax.lib import xla_bridge
+from jax.util import curry
 
 from jax.config import config
 config.parse_flags_with_absl()
@@ -58,17 +60,36 @@ def tearDownModule():
   xla_bridge.get_backend.cache_clear()
 
 
-class GmapTest(jtu.JaxTestCase):
+@curry
+def skip_insufficient_devices(axis_size, fun):
+  @functools.wraps(fun)
+  def wrapper(*args, schedule, **kwargs):
+    for loop, n in schedule:
+      approx_n = axis_size if n is None else n
+      if loop == 'parallel' and approx_n > xla_bridge.device_count():
+        raise SkipTest("this test requires more XLA devices")
+    return fun(*args, schedule=schedule, **kwargs)
+  return wrapper
 
-  @parameterized.named_parameters(
-    {"testcase_name": "_" + name, "schedule": schedule}
-    for name, schedule in [
+
+@curry
+def check_default_schedules(cond, fun):
+  schedules = [
     ('seq', [('sequential', None)]),
     ('vec', [('vectorized', None)]),
     ('par', [('parallel', None)]),
     ('lim_vmap', [('sequential', None), ('vectorized', 2)]),
     ('soft_pmap', [('parallel', 2), ('vectorized', None)])
-    ])
+  ]
+  schedules = [s for s in schedules if cond(s[1])]
+  return parameterized.named_parameters(
+    {"testcase_name": "_" + name, "schedule": schedule}
+    for name, schedule in schedules)(fun)
 
 
+class GmapTest(jtu.JaxTestCase):
+
+  @check_default_schedules(lambda _: True)
+  @skip_insufficient_devices(8)
   @ignore_gmap_warning()
   @skipIf(not config.omnistaging_enabled,
           "vmap collectives only supported when omnistaging is enabled")
@@ -78,12 +99,30 @@ class GmapTest(jtu.JaxTestCase):
 
     x = jnp.arange(800).reshape((8, 10, 10))
 
-    for loop, n in schedule:
-      approx_n = x.shape[0] if n is None else n
-      if loop == 'parallel' and approx_n > xla_bridge.device_count():
-        raise SkipTest("this test requires more XLA devices")
-    self.assertAllClose(gmap(f, schedule)(x), vmap(f)(x))
+    self.assertAllClose(vmap(f)(x), gmap(f, schedule)(x))
+
+  @check_default_schedules(lambda s: not any(c[0] == 'sequential' for c in s))
+  @skip_insufficient_devices(8)
+  @ignore_gmap_warning()
+  @skipIf(not config.omnistaging_enabled,
+          "vmap collectives only supported when omnistaging is enabled")
+  def testAxisName(self, schedule):
+    def f(x):
+      return x - lax.psum(x, 'i')
+    x = jnp.arange(8)
+    self.assertAllClose(gmap(f, schedule, axis_name='i')(x),
+                        vmap(f, axis_name='i')(x))
+
+  @ignore_gmap_warning()
+  @skipIf(not config.omnistaging_enabled,
+          "vmap collectives only supported when omnistaging is enabled")
+  def testAxisName2d(self):
+    def f(x):
+      return x - lax.psum(x, 'i') + lax.pmax(x, 'j')
+    x = jnp.arange(8 * 8).reshape((8, 8))
+    s = [('vectorized', None)]
+    self.assertAllClose(gmap(gmap(f, s, axis_name='i'), s, axis_name='j')(x),
+                        vmap(vmap(f, axis_name='i'), axis_name='j')(x))
 
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())