mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[Pallas/Fuser] DCE fusion jaxprs before pulling (to avoid unnecessary computations being staged out in block functions)
PiperOrigin-RevId: 738218113
This commit is contained in:
parent
4d715753c4
commit
e949effcda
@ -244,6 +244,13 @@ def pull_block_spec(
|
||||
_unwrap_block_spec_scalar_prefetch, out_block_specs
|
||||
)
|
||||
flat_block_specs, out_tree = jax.tree.flatten(block_specs_)
|
||||
jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
|
||||
jaxpr,
|
||||
used_outputs=[True] * len(jaxpr.outvars),
|
||||
instantiate=True,
|
||||
)
|
||||
assert all(used_invars)
|
||||
assert all(used_consts)
|
||||
in_block_specs, env, read_usage_env = _pull_block_spec(
|
||||
jaxpr,
|
||||
tuple(flat_block_specs),
|
||||
|
Loading…
x
Reference in New Issue
Block a user