mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00

There are currently two parameters that configure lowering: `lowering_platform` (for cross-platform lowering) and `override_lowering_rules`. Each is passed as a separate argument through several layers of internal lowering functions. This is tedious and error-prone. In fact, `override_lowering_rules` was not plumbed through in all places, and because every layer uses default arguments, the missing plumbing led to silent errors. We foresee introducing more lowering parameters, e.g., for multi-platform lowering and for controlling the lowering of effects. Here we pack all such parameters into a `mlir.LoweringParameters` dataclass and plumb that through.
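A minimal sketch of the pattern described above, in plain Python so it runs without JAX. `LoweringParameters` here is a hypothetical stand-in, not the real `mlir.LoweringParameters`. The field names mirror the two parameters mentioned above, but the toy rule table and the functions `lower_program`, `_lower_body`, and `_lower_primitive` are illustrative assumptions. The sketch shows one frozen dataclass carrying every knob through several layers. Because the internal functions take it as a keyword-only argument with no default, a layer that forgets to forward it fails loudly instead of silently using defaults.

```python
import dataclasses
from typing import Callable, Dict, List, Optional, Sequence, Tuple

# Hypothetical lowering rule: maps a primitive name to the op it lowers to.
LoweringRule = Callable[[str], str]

# Hypothetical default rule table, used when no override is given.
_DEFAULT_RULES: Dict[str, LoweringRule] = {
    "sin": lambda prim: "stablehlo.sine",
    "cos": lambda prim: "stablehlo.cosine",
}


@dataclasses.dataclass(frozen=True)
class LoweringParameters:
  """Sketch: a single bundle for all knobs that configure lowering."""
  # Target platform for cross-platform lowering; None means "cpu" in this sketch.
  lowering_platform: Optional[str] = None
  # (primitive name, rule) pairs that take precedence over _DEFAULT_RULES.
  override_lowering_rules: Optional[Tuple[Tuple[str, LoweringRule], ...]] = None
  # Future knobs (e.g. multi-platform lowering, effects) would become new
  # fields here; none of the function signatures below would change.


def _lower_primitive(prim: str, *,
                     lowering_parameters: LoweringParameters) -> str:
  # Innermost layer: the only place that actually reads the parameters.
  rules = dict(_DEFAULT_RULES)
  if lowering_parameters.override_lowering_rules:
    rules.update(lowering_parameters.override_lowering_rules)
  platform = lowering_parameters.lowering_platform or "cpu"
  return f"{rules[prim](prim)} @ {platform}"


def _lower_body(prims: Sequence[str], *,
                lowering_parameters: LoweringParameters) -> List[str]:
  # Intermediate layer: forwards one object instead of one argument per knob.
  return [_lower_primitive(p, lowering_parameters=lowering_parameters)
          for p in prims]


def lower_program(prims: Sequence[str], *,
                  lowering_parameters: LoweringParameters) -> List[str]:
  # Entry point. The keyword-only argument has no default, so a caller that
  # omits it gets a TypeError rather than silently lowering with defaults.
  return _lower_body(prims, lowering_parameters=lowering_parameters)


if __name__ == "__main__":
  print(lower_program(["sin", "cos"], lowering_parameters=LoweringParameters()))
  # ['stablehlo.sine @ cpu', 'stablehlo.cosine @ cpu']
  params = LoweringParameters(
      lowering_platform="rocm",
      override_lowering_rules=(("sin", lambda prim: "custom.sin"),))
  print(lower_program(["sin", "cos"], lowering_parameters=params))
  # ['custom.sin @ rocm', 'stablehlo.cosine @ rocm']
```

Freezing the dataclass means a layer cannot mutate the settings it was handed. A caller that needs a variant can derive one with `dataclasses.replace(params, lowering_platform="cuda")`.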