fix a bug with eager pmap + vmap + custom_jvp interaction

This uses the same implementation technique as in shard_map.py, e.g. in ShardMapTrace.process_custom_jvp_call, which is sound; I can't remember why we implemented the eager pmap version the way we did.

This fixes an internal test, but unfortunately I wasn't able to figure out a simple repro :/
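For context, here is a minimal sketch of the kind of composition at issue (hypothetical, since the failing internal test didn't reduce to a simple repro; eager pmap is the MapTrace path taken under jax.disable_jit()):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return f(x), jnp.cos(x) * x_dot

# Under disable_jit, pmap runs eagerly via MapTrace, so composing it with
# vmap over a custom_jvp function exercises the process_custom_jvp_call
# rule rewritten below.
with jax.disable_jit():
  xs = jnp.ones((jax.device_count(), 3))
  print(jax.pmap(jax.vmap(f))(xs))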
Matthew Johnson 2023-12-14 12:41:56 -08:00
parent cf5a49584d
commit b70ac9047d


@@ -441,10 +441,11 @@ class MapTrace(core.Trace):
       return self.process_primitive(fake_primitive, tracers, params)
     axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
         params["in_axes"], params["out_axes_thunk"], params["axis_size"])
-    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
+    vals, shard_axes = unzip2((t.val, t.shard_axes) for t in tracers)
     shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
                   if ax is not None else s
                   for v, ax, s in zip(vals, in_axes, shard_axes)]
+    # TODO(mattjj): use _emap_subtrace here?
     with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
       t = self.main.with_cur_sublevel()
       in_tracers = map(partial(MapTracer, t), vals, shard_axes)
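The one-line change in this hunk is incidental cleanup: unzip2 accepts any iterable of pairs and returns a pair of tuples, so the intermediate list was unnecessary. For illustration (unzip2 here is the internal helper jax._src.util.unzip2, not public API):

from jax._src.util import unzip2  # internal utility

pairs = ((1, 'a'), (2, 'b'), (3, 'c'))
nums, letters = unzip2(p for p in pairs)  # a generator works as well as a list
assert nums == (1, 2, 3) and letters == ('a', 'b', 'c')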
@@ -457,22 +458,29 @@ class MapTrace(core.Trace):
     return map(partial(MapTracer, self), out, outaxes)

   def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
-    bind = HashableFunction(
-        lambda *args, **kwargs: prim.bind(
-            fun, jvp, *args, symbolic_zeros=symbolic_zeros, **kwargs),
-        (prim, fun, jvp, symbolic_zeros))
-    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
-    return self.process_primitive(fake_primitive, tracers, {})
+    if symbolic_zeros:
+      msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. "
+             "Please open an issue at https://github.com/google/jax/issues !")
+      raise NotImplementedError(msg)
+    del prim, jvp, symbolic_zeros  # always base main, can drop jvp
+    in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
+    fun, out_axes = _emap_subtrace(fun, self.main, in_axes)
+    with core.new_sublevel():
+      out_vals = fun.call_wrapped(*in_vals)
+    return map(partial(MapTracer, self), out_vals, out_axes())

   def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
                               out_trees, symbolic_zeros):
-    bind = HashableFunction(
-        lambda *args, **kwargs: primitive.bind(
-            fun, fwd, bwd, *args, out_trees=out_trees,
-            symbolic_zeros=symbolic_zeros, **kwargs),
-        (primitive, fun, fwd, bwd))
-    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
-    return self.process_primitive(fake_primitive, tracers, {})
+    if symbolic_zeros:
+      msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. "
+             "Please open an issue at https://github.com/google/jax/issues !")
+      raise NotImplementedError(msg)
+    del primitive, fwd, bwd, out_trees, symbolic_zeros  # always base main, drop vjp
+    in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
+    fun, out_axes = _emap_subtrace(fun, self.main, in_axes)
+    with core.new_sublevel():
+      out_vals = fun.call_wrapped(*in_vals)
+    return map(partial(MapTracer, self), out_vals, out_axes())

   def process_axis_index(self, frame):
     bind = HashableFunction(
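The substantive change: rather than re-binding the custom_jvp/custom_vjp primitive through a FakePrimitive, the new rules trace fun directly at a fresh sublevel of the same main trace and drop the jvp/fwd/bwd rules entirely. That is sound because the eager-pmap MapTrace always sits at the base main, so no transformation above it can ever need the custom rule; it mirrors ShardMapTrace.process_custom_jvp_call. Since the custom rule is dropped, symbolic zeros can't be threaded through, hence the explicit guards. A sketch of what a guard looks like from user code (hypothetical example, assuming eager single-device execution):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def g(x):
  return 2. * x

def g_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return g(x), 2. * x_dot

g.defjvp(g_jvp, symbolic_zeros=True)

with jax.disable_jit():
  try:
    jax.pmap(g)(jnp.ones((jax.device_count(),)))
  except NotImplementedError as e:
    print(e)  # custom_jvp with symbolic_zeros=True not supported with eager pmap. ...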
@@ -484,6 +492,16 @@ class MapTrace(core.Trace):
     dummy_tracer = MapTracer(self, range, {frame.name: 0})
     return self.process_primitive(fake_primitive, (dummy_tracer,), {})

+@lu.transformation_with_aux
+def _emap_subtrace(main, in_axes, *in_vals):
+  t = main.with_cur_sublevel()
+  in_tracers = map(partial(MapTracer, t), in_vals, in_axes)
+  ans = yield in_tracers, {}
+  out_tracers = map(t.full_raise, ans)
+  out_vals, out_axes = unzip2((t.val, t.shard_axes) for t in out_tracers)
+  del t, in_tracers, ans, out_tracers
+  yield out_vals, out_axes
+
 def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
                    annotation: int | None) -> int | None:
   if annotation is None: return None
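_emap_subtrace is a linear_util generator transformation with an auxiliary output: the first yield hands the tracer-wrapped inputs to the wrapped function, and the second yield returns the unwrapped outputs together with their shard axes, which the caller later reads back through the out_axes thunk. A toy transformation in the same style (the _scale_outputs below is hypothetical; lu is the internal jax._src.linear_util module, matching the API this diff uses):

from jax._src import linear_util as lu

@lu.transformation_with_aux
def _scale_outputs(factor, *in_vals):
  out_vals = yield in_vals, {}  # call the wrapped function on in_vals
  yield [factor * v for v in out_vals], len(out_vals)  # (outputs, aux)

f, num_outs = _scale_outputs(lu.wrap_init(lambda x, y: [x + y]), 10.)
assert f.call_wrapped(1., 2.) == [30.]
assert num_outs() == 1  # the aux value, read back after the call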