This directory contains LLVM FileCheck tests that verify that JAX primitives can be lowered to MHLO.
These tests are intended to be a quick, easy-to-understand way to catch regressions caused by changes to the MLIR Python bindings or to the various MLIR dialects that JAX uses, without having to run the full JAX test suite.
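The following simplified, stdlib-only Python sketch shows the core check that FileCheck performs. Each `CHECK:` or `CHECK-LABEL:` pattern must appear in the printed IR, in order, after the end of the previous match. The helper names `parse_checks` and `run_checks` and the sample IR are hypothetical and written for illustration. The IR is not actual JAX output.

Real LLVM FileCheck does much more than this sketch:

- It supports directives such as `CHECK-NEXT`, `CHECK-NOT`, and `CHECK-DAG`.
- It supports `{{...}}` regular expressions and pattern variables.
- It treats `CHECK-LABEL` specially, using it to partition the input.
- It canonicalizes horizontal whitespace.

The sketch only does in-order literal substring matching.

```python
import re

# Hypothetical helper: matches "# CHECK: pattern" or "// CHECK-LABEL: pattern"
# directives, one per line. Only CHECK and CHECK-LABEL are recognized here.
_DIRECTIVE = re.compile(r"(?:#|//)[ \t]*(CHECK(?:-LABEL)?):[ \t]*(.*\S)")


def parse_checks(check_source: str) -> list[tuple[str, str]]:
    """Extract (directive, pattern) pairs in the order they appear."""
    return [(m.group(1), m.group(2)) for m in _DIRECTIVE.finditer(check_source)]


def run_checks(check_source: str, output: str) -> tuple[bool, str]:
    """Return (ok, message).

    Each pattern must occur as a literal substring somewhere after the end of
    the previous match. CHECK-LABEL is treated like a plain CHECK here.
    """
    pos = 0
    for directive, pattern in parse_checks(check_source):
        idx = output.find(pattern, pos)
        if idx == -1:
            return False, f"{directive}: {pattern!r} not found after offset {pos}"
        pos = idx + len(pattern)  # later patterns must match after this one
    return True, "all checks passed"


# Check directives, as they might appear as comments in a test file.
CHECKS = """
# CHECK-LABEL: func public @main
# CHECK: mhlo.add
# CHECK: return
"""

# Hypothetical IR text, written by hand for illustration; not real JAX output.
IR = """
func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
  %0 = mhlo.add %arg0, %arg1 : tensor<f32>
  return %0 : tensor<f32>
}
"""

print(run_checks(CHECKS, IR))  # (True, 'all checks passed')
```

Because matching is ordered, swapping the `mhlo.add` and `return` checks would make the test fail even though both strings are present. This is what lets a FileCheck test pin down the structure of a lowering rather than just the presence of an op.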