Add support for binding axis_name in gmap

This allows executing collectives over the gmapped axes. This requires
some extra manipulation of the gmapped jaxpr, since gmap exposes a
single logical axis name, but evaluates the program using multiple
"physical" axes.

This also fixes some bugs around handling `multiple_results` in
the vmap collective implementation.
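
A minimal usage sketch of the new capability, mirroring the added gmap test (the schedule shown is just one of the valid choices; `f` and the input are illustrative):

```python
import jax.numpy as jnp
from jax import lax, vmap
from jax.experimental.general_map import gmap

def f(x):
  # The bound axis name can now be used by collectives inside the mapped function.
  return x - lax.psum(x, 'i')

x = jnp.arange(8)
out = gmap(f, [('vectorized', None)], axis_name='i')(x)
# Agrees with vmap over the same axis name.
assert (out == vmap(f, axis_name='i')(x)).all()
```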
Adam Paszke 2020-08-18 09:14:38 +00:00 committed by Adam Paszke
parent e95d5701e3
commit 7210d6f5d0
6 changed files with 142 additions and 55 deletions


@@ -1579,7 +1579,8 @@ def omnistaging_enabler() -> None:
try:
yield
finally:
thread_local_state.trace_state.axis_env.pop()
frame_ = thread_local_state.trace_state.axis_env.pop()
assert frame is frame_ # Only runs if there was no exception
def axis_frame(axis_name):
frames = thread_local_state.trace_state.axis_env


@@ -30,10 +30,7 @@ from ..interpreters import partial_eval as pe
def gmap(fun: Callable, schedule, axis_name = None) -> Callable:
warn("gmap is an experimental feature and probably has bugs!")
_check_callable(fun)
if axis_name is not None:
raise ValueError("gmap doesn't support binding axis names yet")
binds_axis_name = axis_name is not None
axis_name = _TempAxisName(fun) if axis_name is None else axis_name
@wraps(fun)
@@ -42,6 +39,7 @@ def gmap(fun: Callable, schedule, axis_name = None) -> Callable:
args_flat, in_tree = tree_flatten((args, kwargs))
mapped_invars = (True,) * len(args_flat)
axis_size = _mapped_axis_size(in_tree, args_flat, (0,) * len(args_flat), "gmap")
parsed_schedule = _normalize_schedule(schedule, axis_size, binds_axis_name)
for arg in args_flat: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
outs = gmap_p.bind(
@@ -49,7 +47,8 @@ def gmap(fun: Callable, schedule, axis_name = None) -> Callable:
axis_name=axis_name,
axis_size=axis_size,
mapped_invars=mapped_invars,
schedule=tuple(schedule))
schedule=parsed_schedule,
binds_axis_name=binds_axis_name)
return tree_unflatten(out_tree(), outs)
return f_gmapped
@@ -62,20 +61,10 @@ class LoopType(enum.Enum):
Loop = namedtuple('Loop', ['type', 'size'])
def gmap_impl(fun: lu.WrappedFun, *args, axis_size, axis_name, mapped_invars, schedule):
avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
scheduled_fun = _apply_schedule(fun, axis_size, mapped_invars, schedule, *avals)
return scheduled_fun(*args)
def _normalize_schedule(schedule, axis_size, binds_axis_name):
if not schedule:
raise ValueError("gmap expects a non-empty schedule")
def _parse_name(name):
if isinstance(name, LoopType):
return name
try:
return LoopType[name]
except KeyError as err:
raise ValueError(f"Unrecognized loop type: {name}") from err
def _normalize_schedule(schedule, axis_size):
scheduled = 1
seen_none = False
for loop in schedule:
@@ -92,31 +81,69 @@ def _normalize_schedule(schedule, axis_size):
loop_type = _parse_name(loop[0])
if loop_type is LoopType.vectorized and i < len(schedule) - 1:
raise ValueError("vectorized loops can only appear as the last component of the schedule")
if loop_type is LoopType.sequential and binds_axis_name:
raise ValueError("gmaps that bind a new axis name cannot have sequential components in the schedule")
new_schedule.append(Loop(loop_type, loop[1] or unscheduled))
return new_schedule
return tuple(new_schedule)
def _parse_name(name):
if isinstance(name, LoopType):
return name
try:
return LoopType[name]
except KeyError as err:
raise ValueError(f"Unrecognized loop type: {name}") from err
def gmap_impl(fun: lu.WrappedFun, *args, axis_size, axis_name, binds_axis_name, mapped_invars, schedule):
avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
scheduled_fun = _apply_schedule(fun, axis_size, axis_name, binds_axis_name,
mapped_invars, schedule, *avals)
return scheduled_fun(*args)
class _GMapSubaxis:
def __init__(self, axis_name, index):
self.axis_name = axis_name
self.index = index
def __repr__(self):
return f'<subaxis {self.index} of {self.axis_name}>'
def __hash__(self):
return hash((self.axis_name, self.index))
def __eq__(self, other):
return (isinstance(other, _GMapSubaxis) and
self.axis_name == other.axis_name and
self.index == other.index)
@lu.cache
def _apply_schedule(fun: lu.WrappedFun, axis_size, mapped_invars, schedule, *avals):
def _apply_schedule(fun: lu.WrappedFun,
axis_size, full_axis_name, binds_axis_name,
mapped_invars,
schedule,
*avals):
assert all(mapped_invars)
mapped_avals = [core.mapped_aval(axis_size, aval) if mapped else aval
for mapped, aval in zip(mapped_invars, avals)]
with core.extend_axis_env(full_axis_name, axis_size, None):
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_avals)
schedule = _normalize_schedule(schedule, axis_size)
dim_sizes = tuple(loop.size for loop in schedule)
axis_names = tuple(_GMapSubaxis(full_axis_name, i) for i in range(len(schedule)))
if binds_axis_name:
jaxpr = subst_axis_names(jaxpr, full_axis_name, axis_names)
sched_fun = lambda *args: core.eval_jaxpr(jaxpr, consts, *args)
if schedule[-1].type is LoopType.vectorized:
sched_fun = jax.vmap(sched_fun)
nonvector_schedule = schedule[:-1]
else:
nonvector_schedule = schedule
for (ltype, size) in nonvector_schedule[::-1]:
if ltype is LoopType.parallel:
sched_fun = jax.pmap(sched_fun)
for (ltype, size), axis_name in list(zip(schedule, axis_names))[::-1]:
if ltype is LoopType.vectorized:
sched_fun = jax.vmap(sched_fun, axis_name=axis_name)
elif ltype is LoopType.parallel:
sched_fun = jax.pmap(sched_fun, axis_name=axis_name)
elif ltype is LoopType.sequential:
if binds_axis_name:
raise NotImplementedError("gmaps with sequential components of the schedule don't support "
"collectives yet. Please open a feature request!")
assert not binds_axis_name
sched_fun = lambda *args, sched_fun=sched_fun: jax.lax.map(lambda xs: sched_fun(*xs), args)
dim_sizes = tuple(loop.size for loop in schedule)
def sched_fun_wrapper(*args):
split_args = [arg.reshape(dim_sizes + arg.shape[1:]) for arg in args]
results = sched_fun(*split_args)
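
Roughly, the scheduled evaluation splits the mapped axis into the schedule's dimension sizes and nests one map per dimension, which is what the reshape above sets up. A stand-alone sketch of the idea using plain vmaps (assuming a schedule of two vectorized loops of sizes 2 and 4 over an axis of size 8; `f` is illustrative):

```python
import jax
import jax.numpy as jnp

def f(v):
  return v * 2  # illustrative per-element function

x = jnp.arange(8 * 3.).reshape((8, 3))
dim_sizes = (2, 4)                          # sizes of the schedule's loops
split = x.reshape(dim_sizes + x.shape[1:])  # split the mapped axis
out = jax.vmap(jax.vmap(f))(split)          # one nested map per loop
assert (out.reshape(x.shape) == jax.vmap(f)(x)).all()
```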
@@ -125,3 +152,19 @@ def _apply_schedule(fun: lu.WrappedFun, axis_size, mapped_invars, schedule, *avals):
gmap_p = core.MapPrimitive('gmap')
gmap_p.def_impl(gmap_impl)
def subst_axis_names(jaxpr, replaced_name, axis_names):
eqns = [subst_eqn_axis_names(eqn, replaced_name, axis_names) for eqn in jaxpr.eqns]
return core.Jaxpr(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, eqns)
def subst_eqn_axis_names(eqn, replaced_name, axis_names):
if isinstance(eqn.primitive, (core.CallPrimitive, core.MapPrimitive)):
if eqn.params.get('axis_name', None) == replaced_name: # Check for shadowing
return eqn
new_call_jaxpr = subst_axis_names(eqn.params['call_jaxpr'], replaced_name, axis_names)
return eqn._replace(params=dict(eqn.params, call_jaxpr=new_call_jaxpr))
elif eqn.params.get('axis_name', None) == replaced_name:
return eqn._replace(params=dict(eqn.params, axis_name=axis_names))
else:
return eqn
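
The substitution above rewrites collectives bound to the single logical axis so that they name every physical subaxis instead. A rough stand-alone illustration of why that is equivalent, using plain nested vmaps rather than the gmap machinery (requires omnistaging, as the tests note):

```python
import jax.numpy as jnp
from jax import lax, vmap

x = jnp.arange(16.).reshape((4, 4))
# A psum over a tuple of axis names reduces over all of the named mapped axes,
# which is what the single logical gmap axis gets rewritten into.
out = vmap(vmap(lambda v: lax.psum(v, ('i', 'j')), axis_name='j'), axis_name='i')(x)
assert (out == x.sum()).all()
```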


@@ -150,7 +150,9 @@ class BatchTrace(Trace):
if len(axis_names) > 1:
return split_axis(primitive, axis_name, tracers, params)
vals_out, dims_out = collective_rules[primitive](vals_in, dims_in, frame.size, **params)
return map(partial(BatchTracer, self), vals_out, dims_out)
results = map(partial(BatchTracer, self), vals_out, dims_out)
return results if primitive.multiple_results else results[0]
# TODO(mattjj,phawkins): if no rule implemented, could vmap-via-map here
batched_primitive = get_primitive_batcher(primitive)
val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
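
The user-visible effect of the `multiple_results` fix, reflected in the updated BatchingTest cases further down: single-result collectives such as pmax/pmin now come back from the vmap rule as a bare array, so the tests no longer index the result with `[0]`. A small sketch:

```python
import jax.numpy as jnp
from jax import lax, vmap

x = jnp.arange(64).reshape((4, 4, 4))
out = vmap(lambda v: v - lax.pmax(v, 'i'), axis_name='i')(x)
assert out.shape == x.shape  # previously the rule returned a one-element list here
```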


@@ -326,12 +326,14 @@ def _split_axis_comm_assoc(primitive, split_name, args, params):
split_params = dict(params, axis_name=split_name)
remain_params = dict(params, axis_name=remaining_axes)
split_result = primitive.bind(*args, **split_params)
if not primitive.multiple_results:
split_result = (split_result,)
return primitive.bind(*split_result, **remain_params)
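
The split is valid because these collectives are commutative and associative reductions: reducing over the vmapped axis first and over the remaining axes afterwards gives the same answer as reducing over all of them at once. A plain-array sketch of that identity:

```python
import jax.numpy as jnp

x = jnp.arange(24.).reshape((2, 3, 4))
assert jnp.allclose(x.sum(axis=0).sum(axis=0), x.sum(axis=(0, 1)))
```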
# NB: This is only used for collectives that do not include the vmapped axis name,
# which is why the rule is so simple. All other collectives go through split_axis.
def _collective_batcher(prim, args, dims, **params):
return prim.bind(*args, **params), dims
return prim.bind(*args, **params), (dims if prim.multiple_results else dims[0])
def _batched_reduction_collective(prim, if_mapped, if_unmapped,
vals_in, dims_in, axis_size,


@@ -984,12 +984,12 @@ class BatchingTest(jtu.JaxTestCase):
"seq": seq}
for collective, seq in [(lax.psum, jnp.sum),
(lax.pmean, jnp.mean),
(lambda x, n: lax.pmax(x, n)[0], jnp.max),
(lambda x, n: lax.pmin(x, n)[0], jnp.min)])
(lambda x, n: lax.pmax(x, n), jnp.max),
(lambda x, n: lax.pmin(x, n), jnp.min)])
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testCollective(self, collective, seq):
x = jnp.arange(1000).reshape((10, 10, 10))
x = jnp.arange(64).reshape((4, 4, 4))
self.assertAllClose(
vmap(lambda x: x - collective(x, 'i'), axis_name='i')(x),
x - seq(x, axis=0))
@@ -1015,7 +1015,7 @@ class BatchingTest(jtu.JaxTestCase):
perm_pairs = np.stack([np.arange(nelem), perm], axis=-1)
rng.shuffle(perm_pairs)
self.assertAllClose(
vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs)[0], axis_name='i')(x),
vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs), axis_name='i')(x),
x - x[perm])


@@ -27,8 +27,10 @@ from absl.testing import parameterized
import jax.numpy as jnp
from jax import test_util as jtu
from jax import vmap
from jax import lax
from jax.experimental.general_map import gmap
from jax.lib import xla_bridge
from jax.util import curry
from jax.config import config
config.parse_flags_with_absl()
@@ -58,17 +60,36 @@ def tearDownModule():
xla_bridge.get_backend.cache_clear()
class GmapTest(jtu.JaxTestCase):
@curry
def skip_insufficient_devices(axis_size, fun):
@functools.wraps(fun)
def wrapper(*args, schedule, **kwargs):
for loop, n in schedule:
approx_n = axis_size if n is None else n
if loop == 'parallel' and approx_n > xla_bridge.device_count():
raise SkipTest("this test requires more XLA devices")
return fun(*args, schedule=schedule, **kwargs)
return wrapper
@parameterized.named_parameters(
{"testcase_name": "_" + name, "schedule": schedule}
for name, schedule in [
@curry
def check_default_schedules(cond, fun):
schedules = [
('seq', [('sequential', None)]),
('vec', [('vectorized', None)]),
('par', [('parallel', None)]),
('lim_vmap', [('sequential', None), ('vectorized', 2)]),
('soft_pmap', [('parallel', 2), ('vectorized', None)])
])
]
schedules = [s for s in schedules if cond(s[1])]
return parameterized.named_parameters(
{"testcase_name": "_" + name, "schedule": schedule}
for name, schedule in schedules)(fun)
class GmapTest(jtu.JaxTestCase):
@check_default_schedules(lambda _: True)
@skip_insufficient_devices(8)
@ignore_gmap_warning()
@skipIf(not config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
@@ -78,12 +99,30 @@ class GmapTest(jtu.JaxTestCase):
x = jnp.arange(800).reshape((8, 10, 10))
for loop, n in schedule:
approx_n = x.shape[0] if n is None else n
if loop == 'parallel' and approx_n > xla_bridge.device_count():
raise SkipTest("this test requires more XLA devices")
self.assertAllClose(gmap(f, schedule)(x), vmap(f)(x))
self.assertAllClose(vmap(f)(x), gmap(f, schedule)(x))
@check_default_schedules(lambda s: not any(c[0] == 'sequential' for c in s))
@skip_insufficient_devices(8)
@ignore_gmap_warning()
@skipIf(not config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testAxisName(self, schedule):
def f(x):
return x - lax.psum(x, 'i')
x = jnp.arange(8)
self.assertAllClose(gmap(f, schedule, axis_name='i')(x),
vmap(f, axis_name='i')(x))
@ignore_gmap_warning()
@skipIf(not config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testAxisName2d(self):
def f(x):
return x - lax.psum(x, 'i') + lax.pmax(x, 'j')
x = jnp.arange(8 * 8).reshape((8, 8))
s = [('vectorized', None)]
self.assertAllClose(gmap(gmap(f, s, axis_name='i'), s, axis_name='j')(x),
vmap(vmap(f, axis_name='i'), axis_name='j')(x))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())