mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
tweak shmap implementation to work better with leak checker
This commit is contained in:
parent
f1e0741890
commit
f2bef6bb5c
@ -1928,18 +1928,25 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk):
|
||||
fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps)
|
||||
return _match_rep(fun, mesh, out_reps_src, out_reps_dst)
|
||||
|
||||
def _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps):
|
||||
return _efficient_transpose_outer(_efficient_transpose_inner(fun), mesh, in_reps)
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args):
|
||||
def _efficient_transpose_outer(mesh, in_reps, *args):
|
||||
lvl = core.dynamic_level()
|
||||
with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main:
|
||||
t = main.with_cur_sublevel()
|
||||
in_tracers = map(partial(RewriteTracer, t), in_reps, args)
|
||||
ans = yield in_tracers, {}
|
||||
out_tracers = map(t.full_raise, ans)
|
||||
out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers)
|
||||
del main, t, in_tracers, out_tracers, ans
|
||||
out_vals, out_reps = yield (main, mesh, in_reps, args), {}
|
||||
del main
|
||||
yield out_vals, out_reps
|
||||
|
||||
@lu.transformation
|
||||
def _efficient_transpose_inner(main, mesh, in_reps, args):
|
||||
t = main.with_cur_sublevel()
|
||||
in_tracers = map(partial(RewriteTracer, t), in_reps, args)
|
||||
ans = yield in_tracers, {}
|
||||
out_tracers = map(t.full_raise, ans)
|
||||
yield unzip2((t.val, t.rep) for t in out_tracers)
|
||||
|
||||
@lu.transformation
|
||||
def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args):
|
||||
outs = yield args, {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user