Merge pull request #23387 from mattjj:shmap-leak-checker

PiperOrigin-RevId: 670380518
This commit is contained in:
jax authors 2024-09-02 20:27:08 -07:00
commit 281bfcdc62

View File

@ -1928,18 +1928,25 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk):
fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps)
return _match_rep(fun, mesh, out_reps_src, out_reps_dst)
def _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps):
return _efficient_transpose_outer(_efficient_transpose_inner(fun), mesh, in_reps)
@lu.transformation_with_aux
def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args):
def _efficient_transpose_outer(mesh, in_reps, *args):
lvl = core.dynamic_level()
with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main:
t = main.with_cur_sublevel()
in_tracers = map(partial(RewriteTracer, t), in_reps, args)
ans = yield in_tracers, {}
out_tracers = map(t.full_raise, ans)
out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers)
del main, t, in_tracers, out_tracers, ans
out_vals, out_reps = yield (main, mesh, in_reps, args), {}
del main
yield out_vals, out_reps
@lu.transformation
def _efficient_transpose_inner(main, mesh, in_reps, args):
t = main.with_cur_sublevel()
in_tracers = map(partial(RewriteTracer, t), in_reps, args)
ans = yield in_tracers, {}
out_tracers = map(t.full_raise, ans)
yield unzip2((t.val, t.rep) for t in out_tracers)
@lu.transformation
def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args):
outs = yield args, {}