mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 18:06:07 +00:00

Refactoring only, no functional changes intended.

Previously, the MLIR lowering rule signature was:

```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```

where `ctx` was a module-wide context. This change replaces it with:

```
def rule(ctx, *args, **jaxpr_params):
```

where `ctx` is a per-rule context object. The previously separate parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.

This makes it easier to add new per-rule context information without refactoring every lowering rule to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.

PiperOrigin-RevId: 416698663
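The following is a minimal, self-contained sketch of the calling convention described above. It does not import JAX. The field names `module_context`, `avals_in`, and `avals_out` come from the commit message. Everything else is hypothetical and illustrative only: `ModuleContext`, `LoweringRuleContext`, `primitive_name`, `lower_equation`, `adapt_old_rule`, and the string "avals" are simplified stand-ins, not JAX's actual classes or lowering code.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass(frozen=True)
class ModuleContext:
    """Hypothetical stand-in for the module-wide lowering context."""
    name: str


@dataclass(frozen=True)
class LoweringRuleContext:
    """Hypothetical per-rule context holding the fields named in the commit."""
    module_context: ModuleContext
    avals_in: Sequence[Any]
    avals_out: Sequence[Any]
    # Hypothetical extra field: new per-rule information is added here,
    # without changing the signature of any existing lowering rule.
    primitive_name: str = ""


def old_add_rule(ctx, avals_in, avals_out, x, y):
    """Old-style rule: module-wide ctx, avals passed as separate arguments."""
    return f"{ctx.name}: add({x}, {y}) -> {avals_out[0]}"


def new_add_rule(ctx, x, y):
    """New-style rule: everything comes from the single per-rule ctx."""
    return (f"{ctx.module_context.name}: {ctx.primitive_name}({x}, {y})"
            f" -> {ctx.avals_out[0]}")


def adapt_old_rule(old_rule: Callable) -> Callable:
    """Wrap an old-signature rule so it can be called with the new signature."""
    def rule(ctx, *args, **params):
        return old_rule(ctx.module_context, ctx.avals_in, ctx.avals_out,
                        *args, **params)
    return rule


def lower_equation(module_ctx, rule, primitive_name, avals_in, avals_out,
                   *args, **params):
    """Hypothetical caller: builds a fresh per-rule context for each equation."""
    ctx = LoweringRuleContext(module_ctx, tuple(avals_in), tuple(avals_out),
                              primitive_name)
    return rule(ctx, *args, **params)


mod = ModuleContext(name="main")
avals_in, avals_out = ["f32[3]", "f32[3]"], ["f32[3]"]
print(lower_equation(mod, new_add_rule, "add", avals_in, avals_out, "%a", "%b"))
# prints: main: add(%a, %b) -> f32[3]
print(lower_equation(mod, adapt_old_rule(old_add_rule), "add",
                     avals_in, avals_out, "%a", "%b"))
# prints: main: add(%a, %b) -> f32[3]
```

In this sketch, adding a field such as `primitive_name` only touches the context dataclass and the single place that constructs it. Rules that need the new information read it from `ctx`, and rules that do not are unaffected. That is the property the refactoring is meant to provide.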