This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities.
Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules.
This function has two benefits over just building the stablehlo directly:
a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes
b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults).
Next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper.
PiperOrigin-RevId: 561042402
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.
Unchanged occurrences:
1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
argument value in Lowering.as_text and Lowering.compiler_ir.
2) Documentation (changelog, JEPs, IR examples, etc).
3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
so both are necessary to disambiguate.
PiperOrigin-RevId: 495771153
In passing refactor and fix some bugs in the MHLO helper code:
* mlir.ir_constant() failed to propagate its canonicalize_types argument to its callee.
* Refactor the code to convert an XLA computation to an MHLO module and to merge two MHLO modules from the XLA fallback translation rule path.
* Fix symbol (alpha) renaming of call operator callees when merging MHLO modules.
Attempt 2: In this iteration of the merge_mhlo_modules function, move all the operators into the target module first before doing any symbol table manipulation.
PiperOrigin-RevId: 442904129
In passing refactor and fix some bugs in the MHLO helper code:
* mlir.ir_constant() failed to propagate its canonicalize_types argument to its callee.
* Refactor the code to convert an XLA computation to an MHLO module and to merge two MHLO modules from the XLA fallback translation rule path.
* Fix symbol (alpha) renaming of call operator callees when merging MHLO modules.
PiperOrigin-RevId: 442803807
In passing refactor and fix some bugs in the MHLO helper code:
* mlir.ir_constant() failed to propagate its canonicalize_types argument to its callee.
* Refactor the code to convert an XLA computation to an MHLO module and to merge two MHLO modules from the XLA fallback translation rule path.
* Fix symbol (alpha) renaming of call operator callees when merging MHLO modules.
PiperOrigin-RevId: 442798170
This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes.
PiperOrigin-RevId: 439324450
Some primitives have very large lowerings. This is particularly true for lowerings that use `mlir.lower_fun` (e.g., the threefry PRNG kernel) or some XLA fallback lowerings. In this case it makes sense to lower such computations once for each signature as an out of line function that we can call multiple times.
XLA will inline these functions early in compilation at the moment, but this avoids the need to repeatedly trace, e.g., the threefry kernel when emitting MHLO.
PiperOrigin-RevId: 416818325