From e949effcda6ccc806b5ca00cd3d7bf27927a3447 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 18 Mar 2025 18:59:50 -0700 Subject: [PATCH] [Pallas/Fuser] DCE fusion jaxprs before pulling (to avoid unnecessary computations being staged out in block functions) PiperOrigin-RevId: 738218113 --- jax/_src/pallas/fuser/block_spec.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index d0767aeeb..de0cdd204 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -244,6 +244,13 @@ def pull_block_spec( _unwrap_block_spec_scalar_prefetch, out_block_specs ) flat_block_specs, out_tree = jax.tree.flatten(block_specs_) + jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts( + jaxpr, + used_outputs=[True] * len(jaxpr.outvars), + instantiate=True, + ) + assert all(used_invars) + assert all(used_consts) in_block_specs, env, read_usage_env = _pull_block_spec( jaxpr, tuple(flat_block_specs),