mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Skip unneccessary unflattening of avals in pjit lowering path.
The avals get flattened again when calling `from_flat_info` (here:
1641c8f141/jax/_src/stages.py (L347)
),
so skip unflattening here.
PiperOrigin-RevId: 504260643
This commit is contained in:
parent
1641c8f141
commit
7064be1a76
@ -266,13 +266,12 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
|
||||
|
||||
if kwargs:
|
||||
args_kwargs_in_tree = in_tree
|
||||
local_in_avals = in_tree.unflatten(flat_local_in_avals)
|
||||
else:
|
||||
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
|
||||
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
|
||||
|
||||
return stages.Lowered.from_flat_info(
|
||||
lowering, args_kwargs_in_tree, local_in_avals, donate_argnums, out_tree)
|
||||
lowering, args_kwargs_in_tree, flat_local_in_avals, donate_argnums,
|
||||
out_tree)
|
||||
|
||||
wrapped.lower = lower
|
||||
return wrapped
|
||||
|
Loading…
x
Reference in New Issue
Block a user