Peter Hawkins d0a6813ea2 Make mlir.custom_call() more general and expose it as jax.interpreters.mlir.custom_call().
This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities.

Since most users of custom kernels probably want to build a custom call, we can get most of the benefit by providing an ergonomic helper function that builds the IR for custom calls and can be called from external primitive lowering rules.

This function has two benefits over just building the stablehlo directly:
a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes
b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults).

The next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules that use this helper.

PiperOrigin-RevId: 561042402
2023-08-29 08:50:07 -07:00

This directory contains LLVM FileCheck tests that verify that JAX primitives can be lowered to MLIR.

These tests are intended to be a quick and easy-to-understand way to catch regressions caused by changes to the MLIR Python bindings and to the various MLIR dialects used by JAX, without needing to run the full JAX test suite.
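
As a rough illustration of the idea, not a copy of any test in this directory, a FileCheck-style test can lower a small function, print the resulting MLIR module, and leave `CHECK` comments for FileCheck to match against that output. The helper name `lowered_text` and the test label are hypothetical, and the exact op spelling depends on the JAX version (for example `stablehlo.sine` versus `mhlo.sine`), so the sketch matches only on `sine`.

```python
import jax
import jax.numpy as jnp


def lowered_text(fn, *args):
    """Lower `fn` for the given example arguments and return the MLIR module as text.

    Hypothetical helper for this sketch; it only uses the public
    jax.jit(...).lower(...).as_text() API.
    """
    return jax.jit(fn).lower(*args).as_text()


# The CHECK comments below are what FileCheck would match against stdout.
# CHECK-LABEL: TEST: sin float32[]
# CHECK: sine
print("TEST: sin float32[]")
sin_ir = lowered_text(lambda x: jnp.sin(x), jnp.float32(1.0))
print(sin_ir)
```

Such a script would typically be run with its output piped into FileCheck, passing the script itself as the check file. If a dialect or binding change alters the lowering, the `CHECK` lines stop matching and the test fails.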