mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[Pallas/Fuser] DCE fusion jaxprs before pulling (to avoid unnecessary computations being staged out in block functions)
PiperOrigin-RevId: 738218113
This commit is contained in:
parent
4d715753c4
commit
e949effcda
@ -244,6 +244,13 @@ def pull_block_spec(
|
|||||||
_unwrap_block_spec_scalar_prefetch, out_block_specs
|
_unwrap_block_spec_scalar_prefetch, out_block_specs
|
||||||
)
|
)
|
||||||
flat_block_specs, out_tree = jax.tree.flatten(block_specs_)
|
flat_block_specs, out_tree = jax.tree.flatten(block_specs_)
|
||||||
|
jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
|
||||||
|
jaxpr,
|
||||||
|
used_outputs=[True] * len(jaxpr.outvars),
|
||||||
|
instantiate=True,
|
||||||
|
)
|
||||||
|
assert all(used_invars)
|
||||||
|
assert all(used_consts)
|
||||||
in_block_specs, env, read_usage_env = _pull_block_spec(
|
in_block_specs, env, read_usage_env = _pull_block_spec(
|
||||||
jaxpr,
|
jaxpr,
|
||||||
tuple(flat_block_specs),
|
tuple(flat_block_specs),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user