mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
fix a bug with eager pmap + vmap + custom_jvp interaction
I used the same implementation technique in shard_map.py, e.g. in ShardMapTrace.process_custom_jvp_call, and it's sound, whereas I can't remember why we implemented the eager pmap stuff the way we did. This fixes an internal test, but unfortunately I wasn't able to figure out a simple repro :/
This commit is contained in:
parent
cf5a49584d
commit
b70ac9047d
@ -441,10 +441,11 @@ class MapTrace(core.Trace):
|
||||
return self.process_primitive(fake_primitive, tracers, params)
|
||||
axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
|
||||
params["in_axes"], params["out_axes_thunk"], params["axis_size"])
|
||||
vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
|
||||
vals, shard_axes = unzip2((t.val, t.shard_axes) for t in tracers)
|
||||
shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
|
||||
if ax is not None else s
|
||||
for v, ax, s in zip(vals, in_axes, shard_axes)]
|
||||
# TODO(mattjj): use _emap_subtrace here?
|
||||
with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
|
||||
t = self.main.with_cur_sublevel()
|
||||
in_tracers = map(partial(MapTracer, t), vals, shard_axes)
|
||||
@ -457,22 +458,29 @@ class MapTrace(core.Trace):
|
||||
return map(partial(MapTracer, self), out, outaxes)
|
||||
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
|
||||
bind = HashableFunction(
|
||||
lambda *args, **kwargs: prim.bind(
|
||||
fun, jvp, *args, symbolic_zeros=symbolic_zeros, **kwargs),
|
||||
(prim, fun, jvp, symbolic_zeros))
|
||||
fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
|
||||
return self.process_primitive(fake_primitive, tracers, {})
|
||||
if symbolic_zeros:
|
||||
msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. "
|
||||
"Please open an issue at https://github.com/google/jax/issues !")
|
||||
raise NotImplementedError(msg)
|
||||
del prim, jvp, symbolic_zeros # always base main, can drop jvp
|
||||
in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
|
||||
fun, out_axes = _emap_subtrace(fun, self.main, in_axes)
|
||||
with core.new_sublevel():
|
||||
out_vals = fun.call_wrapped(*in_vals)
|
||||
return map(partial(MapTracer, self), out_vals, out_axes())
|
||||
|
||||
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
|
||||
out_trees, symbolic_zeros):
|
||||
bind = HashableFunction(
|
||||
lambda *args, **kwargs: primitive.bind(
|
||||
fun, fwd, bwd, *args, out_trees=out_trees,
|
||||
symbolic_zeros=symbolic_zeros, **kwargs),
|
||||
(primitive, fun, fwd, bwd))
|
||||
fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
|
||||
return self.process_primitive(fake_primitive, tracers, {})
|
||||
if symbolic_zeros:
|
||||
msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. "
|
||||
"Please open an issue at https://github.com/google/jax/issues !")
|
||||
raise NotImplementedError(msg)
|
||||
del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp
|
||||
in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers)
|
||||
fun, out_axes = _emap_subtrace(fun, self.main, in_axes)
|
||||
with core.new_sublevel():
|
||||
out_vals = fun.call_wrapped(*in_vals)
|
||||
return map(partial(MapTracer, self), out_vals, out_axes())
|
||||
|
||||
def process_axis_index(self, frame):
|
||||
bind = HashableFunction(
|
||||
@ -484,6 +492,16 @@ class MapTrace(core.Trace):
|
||||
dummy_tracer = MapTracer(self, range, {frame.name: 0})
|
||||
return self.process_primitive(fake_primitive, (dummy_tracer,), {})
|
||||
|
||||
@lu.transformation_with_aux
def _emap_subtrace(main, in_axes, *in_vals):
  """Run the wrapped function on MapTracers built from raw values.

  Each entry of ``in_vals`` is paired with the matching entry of ``in_axes``
  and wrapped in a ``MapTracer`` belonging to ``main`` at its current
  sublevel. The first ``yield`` passes those tracers (with empty kwargs) out
  and receives the wrapped function's outputs back. The outputs are raised
  to the same trace with ``full_raise`` and split into their ``val`` and
  ``shard_axes`` parts. The second ``yield`` emits the values as the result
  and the shard axes as the auxiliary output.
  """
  trace = main.with_cur_sublevel()
  wrapped_inputs = map(partial(MapTracer, trace), in_vals, in_axes)
  results = yield wrapped_inputs, {}
  raised = map(trace.full_raise, results)
  out_vals, out_axes = unzip2(
      (tracer.val, tracer.shard_axes) for tracer in raised)
  # Release every local reference to the trace and its tracers before the
  # final yield (presumably so nothing tracer-related outlives this
  # subtrace -- confirm against the leak checker).
  del trace, wrapped_inputs, results, raised
  yield out_vals, out_axes
|
||||
|
||||
def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
|
||||
annotation: int | None) -> int | None:
|
||||
if annotation is None: return None
|
||||
|
Loading…
x
Reference in New Issue
Block a user