Skip unneccessary unflattening of avals in pjit lowering path.

The avals get flattened again when calling `from_flat_info` (here:
1641c8f141/jax/_src/stages.py (L347)),
so skip unflattening here.

PiperOrigin-RevId: 504260643
This commit is contained in:
Lena Martens 2023-01-24 06:44:51 -08:00 committed by jax authors
parent 1641c8f141
commit 7064be1a76

View File

@ -266,13 +266,12 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
if kwargs:
args_kwargs_in_tree = in_tree
local_in_avals = in_tree.unflatten(flat_local_in_avals)
else:
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
return stages.Lowered.from_flat_info(
lowering, args_kwargs_in_tree, local_in_avals, donate_argnums, out_tree)
lowering, args_kwargs_in_tree, flat_local_in_avals, donate_argnums,
out_tree)
wrapped.lower = lower
return wrapped