[Pallas/Fuser] DCE fusion jaxprs before pulling (to avoid unnecessary computations being staged out in block functions)

PiperOrigin-RevId: 738218113
This commit is contained in:
Sharad Vikram 2025-03-18 18:59:50 -07:00 committed by jax authors
parent 4d715753c4
commit e949effcda

View File

@ -244,6 +244,13 @@ def pull_block_spec(
_unwrap_block_spec_scalar_prefetch, out_block_specs
)
flat_block_specs, out_tree = jax.tree.flatten(block_specs_)
jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
jaxpr,
used_outputs=[True] * len(jaxpr.outvars),
instantiate=True,
)
assert all(used_invars)
assert all(used_consts)
in_block_specs, env, read_usage_env = _pull_block_spec(
jaxpr,
tuple(flat_block_specs),