diff --git a/jax/core.py b/jax/core.py
index 38a2af1fa..c0a3f93d2 100644
--- a/jax/core.py
+++ b/jax/core.py
@@ -1579,7 +1579,8 @@ def omnistaging_enabler() -> None:
     try:
       yield
     finally:
-      thread_local_state.trace_state.axis_env.pop()
+      frame_ = thread_local_state.trace_state.axis_env.pop()
+      assert frame is frame_  # Only runs if there was no exception
 
   def axis_frame(axis_name):
     frames = thread_local_state.trace_state.axis_env
diff --git a/jax/experimental/general_map.py b/jax/experimental/general_map.py
index 156abe79f..f810a55c9 100644
--- a/jax/experimental/general_map.py
+++ b/jax/experimental/general_map.py
@@ -30,10 +30,7 @@ from ..interpreters import partial_eval as pe
 def gmap(fun: Callable, schedule, axis_name = None) -> Callable:
   warn("gmap is an experimental feature and probably has bugs!")
   _check_callable(fun)
-
-  if axis_name is not None:
-    raise ValueError("gmap doesn't support binding axis names yet")
-
+  binds_axis_name = axis_name is not None
   axis_name = _TempAxisName(fun) if axis_name is None else axis_name
 
   @wraps(fun)
@@ -42,6 +39,7 @@ def gmap(fun: Callable, schedule, axis_name = None) -> Callable:
     args_flat, in_tree = tree_flatten((args, kwargs))
     mapped_invars = (True,) * len(args_flat)
     axis_size = _mapped_axis_size(in_tree, args_flat, (0,) * len(args_flat), "gmap")
+    parsed_schedule = _normalize_schedule(schedule, axis_size, binds_axis_name)
     for arg in args_flat: _check_arg(arg)
     flat_fun, out_tree = flatten_fun(f, in_tree)
     outs = gmap_p.bind(
@@ -49,7 +47,8 @@ def gmap(fun: Callable, schedule, axis_name = None) -> Callable:
         axis_name=axis_name,
         axis_size=axis_size,
         mapped_invars=mapped_invars,
-        schedule=tuple(schedule))
+        schedule=parsed_schedule,
+        binds_axis_name=binds_axis_name)
     return tree_unflatten(out_tree(), outs)
   return f_gmapped
 
@@ -62,20 +61,10 @@ class LoopType(enum.Enum):
 
 Loop = namedtuple('Loop', ['type', 'size'])
 
-def gmap_impl(fun: lu.WrappedFun, *args, axis_size, axis_name, mapped_invars, schedule):
-  avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
-  scheduled_fun = _apply_schedule(fun, axis_size, mapped_invars, schedule, *avals)
-  return scheduled_fun(*args)
+def _normalize_schedule(schedule, axis_size, binds_axis_name):
+  if not schedule:
+    raise ValueError("gmap expects a non-empty schedule")
 
-def _parse_name(name):
-  if isinstance(name, LoopType):
-    return name
-  try:
-    return LoopType[name]
-  except KeyError as err:
-    raise ValueError(f"Unrecognized loop type: {name}") from err
-
-def _normalize_schedule(schedule, axis_size):
   scheduled = 1
   seen_none = False
   for loop in schedule:
@@ -92,31 +81,68 @@ def _normalize_schedule(schedule, axis_size):
     loop_type = _parse_name(loop[0])
     if loop_type is LoopType.vectorized and i < len(schedule) - 1:
       raise ValueError("vectorized loops can only appear as the last component of the schedule")
+    if loop_type is LoopType.sequential and binds_axis_name:
+      raise ValueError("gmaps that bind a new axis name cannot have sequential components in the schedule")
     new_schedule.append(Loop(loop_type, loop[1] or unscheduled))
-  return new_schedule
+  return tuple(new_schedule)
+
+def _parse_name(name):
+  if isinstance(name, LoopType):
+    return name
+  try:
+    return LoopType[name]
+  except KeyError as err:
+    raise ValueError(f"Unrecognized loop type: {name}") from err
+
+
+def gmap_impl(fun: lu.WrappedFun, *args, axis_size, axis_name, binds_axis_name, mapped_invars, schedule):
+  avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
+  scheduled_fun = _apply_schedule(fun, axis_size, axis_name, binds_axis_name,
+                                  mapped_invars, schedule, *avals)
+  return scheduled_fun(*args)
+
+class _GMapSubaxis:
+  def __init__(self, axis_name, index):
+    self.axis_name = axis_name
+    self.index = index
+  def __repr__(self):
+    return f'<GMapSubaxis {self.axis_name} {self.index}>'
+  def __hash__(self):
+    return hash((self.axis_name, self.index))
+  def __eq__(self, other):
+    return (isinstance(other, _GMapSubaxis) and
+            self.axis_name == other.axis_name and
+            self.index == other.index)
 
 @lu.cache
-def _apply_schedule(fun: lu.WrappedFun, axis_size, mapped_invars, schedule, *avals):
+def _apply_schedule(fun: lu.WrappedFun,
+                    axis_size, full_axis_name, binds_axis_name,
+                    mapped_invars,
+                    schedule,
+                    *avals):
+  assert all(mapped_invars)
   mapped_avals = [core.mapped_aval(axis_size, aval) if mapped else aval
                   for mapped, aval in zip(mapped_invars, avals)]
-  jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_avals)
+  with core.extend_axis_env(full_axis_name, axis_size, None):
+    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_avals)
 
-  schedule = _normalize_schedule(schedule, axis_size)
-  dim_sizes = tuple(loop.size for loop in schedule)
+  axis_names = tuple(_GMapSubaxis(full_axis_name, i) for i in range(len(schedule)))
+  if binds_axis_name:
+    jaxpr = subst_axis_names(jaxpr, full_axis_name, axis_names)
 
   sched_fun = lambda *args: core.eval_jaxpr(jaxpr, consts, *args)
-
-  if schedule[-1].type is LoopType.vectorized:
-    sched_fun = jax.vmap(sched_fun)
-    nonvector_schedule = schedule[:-1]
-  else:
-    nonvector_schedule = schedule
-  for (ltype, size) in nonvector_schedule[::-1]:
-    if ltype is LoopType.parallel:
-      sched_fun = jax.pmap(sched_fun)
+  for (ltype, size), axis_name in list(zip(schedule, axis_names))[::-1]:
+    if ltype is LoopType.vectorized:
+      sched_fun = jax.vmap(sched_fun, axis_name=axis_name)
+    elif ltype is LoopType.parallel:
+      sched_fun = jax.pmap(sched_fun, axis_name=axis_name)
     elif ltype is LoopType.sequential:
+      if binds_axis_name:
+        raise NotImplementedError("gmaps with sequential components of the schedule don't support "
+                                  "collectives yet. Please open a feature request!")
       sched_fun = lambda *args, sched_fun=sched_fun: jax.lax.map(lambda xs: sched_fun(*xs), args)
 
+  dim_sizes = tuple(loop.size for loop in schedule)
   def sched_fun_wrapper(*args):
     split_args = [arg.reshape(dim_sizes + arg.shape[1:]) for arg in args]
     results = sched_fun(*split_args)
@@ -125,3 +151,19 @@ def _apply_schedule(fun: lu.WrappedFun, axis_size, mapped_invars, schedule, *ava
 
 gmap_p = core.MapPrimitive('gmap')
 gmap_p.def_impl(gmap_impl)
+
+
+def subst_axis_names(jaxpr, replaced_name, axis_names):
+  eqns = [subst_eqn_axis_names(eqn, replaced_name, axis_names) for eqn in jaxpr.eqns]
+  return core.Jaxpr(jaxpr.constvars, jaxpr.invars, jaxpr.outvars, eqns)
+
+def subst_eqn_axis_names(eqn, replaced_name, axis_names):
+  if isinstance(eqn.primitive, (core.CallPrimitive, core.MapPrimitive)):
+    if eqn.params.get('axis_name', None) == replaced_name:  # Check for shadowing
+      return eqn
+    new_call_jaxpr = subst_axis_names(eqn.params['call_jaxpr'], replaced_name, axis_names)
+    return eqn._replace(params=dict(eqn.params, call_jaxpr=new_call_jaxpr))
+  elif eqn.params.get('axis_name', None) == replaced_name:
+    return eqn._replace(params=dict(eqn.params, axis_name=axis_names))
+  else:
+    return eqn
diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py
index 66891d50f..ed25c88e4 100644
--- a/jax/interpreters/batching.py
+++ b/jax/interpreters/batching.py
@@ -150,7 +150,8 @@ class BatchTrace(Trace):
       if len(axis_names) > 1:
         return split_axis(primitive, axis_name, tracers, params)
       vals_out, dims_out = collective_rules[primitive](vals_in, dims_in, frame.size, **params)
-      return map(partial(BatchTracer, self), vals_out, dims_out)
+      results = map(partial(BatchTracer, self), vals_out, dims_out)
+      return results if primitive.multiple_results else results[0]
     # TODO(mattjj,phawkins): if no rule implemented, could vmap-via-map here
     batched_primitive = get_primitive_batcher(primitive)
     val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
diff --git a/jax/lax/lax_parallel.py b/jax/lax/lax_parallel.py
index 6cab76c22..2c13df9a5 100644
--- a/jax/lax/lax_parallel.py
+++ b/jax/lax/lax_parallel.py
@@ -326,16 +326,18 @@ def _split_axis_comm_assoc(primitive, split_name, args, params):
   split_params = dict(params, axis_name=split_name)
   remain_params = dict(params, axis_name=remaining_axes)
   split_result = primitive.bind(*args, **split_params)
+  if not primitive.multiple_results:
+    split_result = (split_result,)
   return primitive.bind(*split_result, **remain_params)
 
 # NB: This is only used for collectives that do not include the vmapped axis name,
 # which is why the rule is so simple. All other collectives go through split_axis.
 def _collective_batcher(prim, args, dims, **params):
-  return prim.bind(*args, **params), dims
+  return prim.bind(*args, **params), dims if prim.multiple_results else dims[0]
 
 def _batched_reduction_collective(prim, if_mapped, if_unmapped,
-                                 vals_in, dims_in, axis_size,
-                                 axis_name, axis_index_groups):
+                                  vals_in, dims_in, axis_size,
+                                  axis_name, axis_index_groups):
   if axis_index_groups is not None:
     raise NotImplementedError("axis_index_groups not implemented in vmap collectives. "
                               "Please open a feature request!")
diff --git a/tests/batching_test.py b/tests/batching_test.py
index 3210186ae..9203d2dec 100644
--- a/tests/batching_test.py
+++ b/tests/batching_test.py
@@ -984,12 +984,12 @@ class BatchingTest(jtu.JaxTestCase):
        "seq": seq}
       for collective, seq in [(lax.psum, jnp.sum),
                               (lax.pmean, jnp.mean),
-                              (lambda x, n: lax.pmax(x, n)[0], jnp.max),
-                              (lambda x, n: lax.pmin(x, n)[0], jnp.min)])
+                              (lambda x, n: lax.pmax(x, n), jnp.max),
+                              (lambda x, n: lax.pmin(x, n), jnp.min)])
   @skipIf(not jax.config.omnistaging_enabled,
           "vmap collectives only supported when omnistaging is enabled")
   def testCollective(self, collective, seq):
-    x = jnp.arange(1000).reshape((10, 10, 10))
+    x = jnp.arange(64).reshape((4, 4, 4))
     self.assertAllClose(
       vmap(lambda x: x - collective(x, 'i'), axis_name='i')(x),
       x - seq(x, axis=0))
@@ -1015,7 +1015,7 @@ class BatchingTest(jtu.JaxTestCase):
     perm_pairs = np.stack([np.arange(nelem), perm], axis=-1)
     rng.shuffle(perm_pairs)
     self.assertAllClose(
-      vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs)[0], axis_name='i')(x),
+      vmap(lambda x: x - lax.ppermute(x, 'i', perm_pairs), axis_name='i')(x),
       x - x[perm])
 
 
diff --git a/tests/gmap_test.py b/tests/gmap_test.py
index 2b10ba147..d0fbf27d7 100644
--- a/tests/gmap_test.py
+++ b/tests/gmap_test.py
@@ -27,8 +27,10 @@
 import jax.numpy as jnp
 from jax import test_util as jtu
 from jax import vmap
+from jax import lax
 from jax.experimental.general_map import gmap
 from jax.lib import xla_bridge
+from jax.util import curry
 from jax.config import config
 
 config.parse_flags_with_absl()
@@ -58,17 +60,36 @@ def tearDownModule():
   xla_bridge.get_backend.cache_clear()
 
 
+@curry
+def skip_insufficient_devices(axis_size, fun):
+  @functools.wraps(fun)
+  def wrapper(*args, schedule, **kwargs):
+    for loop, n in schedule:
+      approx_n = axis_size if n is None else n
+      if loop == 'parallel' and approx_n > xla_bridge.device_count():
+        raise SkipTest("this test requires more XLA devices")
+    return fun(*args, schedule=schedule, **kwargs)
+  return wrapper
+
+@curry
+def check_default_schedules(cond, fun):
+  schedules = [
+    ('seq', [('sequential', None)]),
+    ('vec', [('vectorized', None)]),
+    ('par', [('parallel', None)]),
+    ('lim_vmap', [('sequential', None), ('vectorized', 2)]),
+    ('soft_pmap', [('parallel', 2), ('vectorized', None)])
+  ]
+  schedules = [s for s in schedules if cond(s[1])]
+  return parameterized.named_parameters(
+      {"testcase_name": "_" + name, "schedule": schedule}
+      for name, schedule in schedules)(fun)
+
+
 class GmapTest(jtu.JaxTestCase):
 
-  @parameterized.named_parameters(
-      {"testcase_name": "_" + name, "schedule": schedule}
-      for name, schedule in [
-          ('seq', [('sequential', None)]),
-          ('vec', [('vectorized', None)]),
-          ('par', [('parallel', None)]),
-          ('lim_vmap', [('sequential', None), ('vectorized', 2)]),
-          ('soft_pmap', [('parallel', 2), ('vectorized', None)])
-      ])
+  @check_default_schedules(lambda _: True)
+  @skip_insufficient_devices(8)
   @ignore_gmap_warning()
   @skipIf(not config.omnistaging_enabled,
           "vmap collectives only supported when omnistaging is enabled")
@@ -78,12 +99,30 @@ class GmapTest(jtu.JaxTestCase):
     x = jnp.arange(800).reshape((8, 10, 10))
 
-    for loop, n in schedule:
-      approx_n = x.shape[0] if n is None else n
-      if loop == 'parallel' and approx_n > xla_bridge.device_count():
-        raise SkipTest("this test requires more XLA devices")
+    self.assertAllClose(gmap(f, schedule)(x), vmap(f)(x))
 
-    self.assertAllClose(vmap(f)(x), gmap(f, schedule)(x))
+  @check_default_schedules(lambda s: not any(c[0] == 'sequential' for c in s))
+  @skip_insufficient_devices(8)
+  @ignore_gmap_warning()
+  @skipIf(not config.omnistaging_enabled,
+          "vmap collectives only supported when omnistaging is enabled")
+  def testAxisName(self, schedule):
+    def f(x):
+      return x - lax.psum(x, 'i')
+    x = jnp.arange(8)
+    self.assertAllClose(gmap(f, schedule, axis_name='i')(x),
+                        vmap(f, axis_name='i')(x))
+
+  @ignore_gmap_warning()
+  @skipIf(not config.omnistaging_enabled,
+          "vmap collectives only supported when omnistaging is enabled")
+  def testAxisName2d(self):
+    def f(x):
+      return x - lax.psum(x, 'i') + lax.pmax(x, 'j')
+    x = jnp.arange(8 * 8).reshape((8, 8))
+    s = [('vectorized', None)]
+    self.assertAllClose(gmap(gmap(f, s, axis_name='i'), s, axis_name='j')(x),
+                        vmap(vmap(f, axis_name='i'), axis_name='j')(x))
 
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
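
Usage note (editorial addition, not part of the diff): a minimal sketch of the axis-name binding this change adds to gmap, mirroring testAxisName above. The normalize function and the purely 'vectorized' schedule are illustrative assumptions, and gmap remains experimental.

# Illustrative sketch only: gmap binds the axis name 'i', so a collective such as
# lax.psum inside the mapped function can refer to the gmapped axis.
import jax.numpy as jnp
from jax import lax
from jax.experimental.general_map import gmap

def normalize(x):
  # 'i' names the axis introduced by gmap below (hypothetical example function).
  return x / lax.psum(x, 'i')

x = jnp.arange(8.)
# A purely 'vectorized' schedule is accepted when an axis name is bound; schedules
# with 'sequential' components are rejected by _normalize_schedule in that case.
y = gmap(normalize, [('vectorized', None)], axis_name='i')(x)
# Should match vmap(normalize, axis_name='i')(x), as in testAxisName.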