From b70ac9047d5eebcac1b7bd94ef27541d9b86e7fa Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Thu, 14 Dec 2023 12:41:56 -0800
Subject: [PATCH] fix a bug with eager pmap + vmap + custom_jvp interaction

I used the same implementation technique in shard_map.py, e.g. in
ShardMapTrace.process_custom_jvp_call, and it's sound, whereas I can't
remember why we implemented the eager pmap stuff the way we did.

This fixes an internal test, but unfortunately I wasn't able to figure
out a simple repro :/
---
 jax/_src/interpreters/pxla.py | 46 ++++++++++++++++++++++++-----------
 1 file changed, 32 insertions(+), 14 deletions(-)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index fdbe4edb8..aada85d1e 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -441,10 +441,11 @@ class MapTrace(core.Trace):
       return self.process_primitive(fake_primitive, tracers, params)
     axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
         params["in_axes"], params["out_axes_thunk"], params["axis_size"])
-    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
+    vals, shard_axes = unzip2((t.val, t.shard_axes) for t in tracers)
     shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
                   if ax is not None else s
                   for v, ax, s in zip(vals, in_axes, shard_axes)]
+    # TODO(mattjj): use _emap_subtrace here?
     with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
       t = self.main.with_cur_sublevel()
       in_tracers = map(partial(MapTracer, t), vals, shard_axes)
@@ -457,22 +458,29 @@ class MapTrace(core.Trace):
     return map(partial(MapTracer, self), out, outaxes)
 
   def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
-    bind = HashableFunction(
-        lambda *args, **kwargs: prim.bind(
-            fun, jvp, *args, symbolic_zeros=symbolic_zeros, **kwargs),
-        (prim, fun, jvp, symbolic_zeros))
-    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
-    return self.process_primitive(fake_primitive, tracers, {})
+    if symbolic_zeros:
+      msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. "
+             "Please open an issue at https://github.com/google/jax/issues !")
+      raise NotImplementedError(msg)
+    del prim, jvp, symbolic_zeros  # always base main, can drop jvp
+    in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
+    fun, out_axes = _emap_subtrace(fun, self.main, in_axes)
+    with core.new_sublevel():
+      out_vals = fun.call_wrapped(*in_vals)
+    return map(partial(MapTracer, self), out_vals, out_axes())
 
   def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees,
                               symbolic_zeros):
-    bind = HashableFunction(
-        lambda *args, **kwargs: primitive.bind(
-            fun, fwd, bwd, *args, out_trees=out_trees,
-            symbolic_zeros=symbolic_zeros, **kwargs),
-        (primitive, fun, fwd, bwd))
-    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
-    return self.process_primitive(fake_primitive, tracers, {})
+    if symbolic_zeros:
+      msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. "
+             "Please open an issue at https://github.com/google/jax/issues !")
+      raise NotImplementedError(msg)
+    del primitive, fwd, bwd, out_trees, symbolic_zeros  # always base main, drop vjp
+    in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
+    fun, out_axes = _emap_subtrace(fun, self.main, in_axes)
+    with core.new_sublevel():
+      out_vals = fun.call_wrapped(*in_vals)
+    return map(partial(MapTracer, self), out_vals, out_axes())
 
   def process_axis_index(self, frame):
     bind = HashableFunction(
@@ -484,6 +492,16 @@ class MapTrace(core.Trace):
     dummy_tracer = MapTracer(self, range, {frame.name: 0})
     return self.process_primitive(fake_primitive, (dummy_tracer,), {})
 
+@lu.transformation_with_aux
+def _emap_subtrace(main, in_axes, *in_vals):
+  t = main.with_cur_sublevel()
+  in_tracers = map(partial(MapTracer, t), in_vals, in_axes)
+  ans = yield in_tracers, {}
+  out_tracers = map(t.full_raise, ans)
+  out_vals, out_axes = unzip2((t.val, t.shard_axes) for t in out_tracers)
+  del t, in_tracers, ans, out_tracers
+  yield out_vals, out_axes
+
 def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
                    annotation: int | None) -> int | None:
   if annotation is None: return None