Peter Hawkins 18baa6e93b [MLIR] Add a @mlir.cache_lowering decorator that lowers a primitive out-of-line as a reusable function.
Some primitives have very large lowerings. This is particularly true for lowerings that use `mlir.lower_fun` (e.g., the threefry PRNG kernel) and for some XLA fallback lowerings. In these cases it makes sense to lower such computations once for each signature as an out-of-line function that we can call multiple times.

At the moment XLA inlines these functions early in compilation, but this still avoids the need to repeatedly trace, e.g., the threefry kernel when emitting MHLO.

PiperOrigin-RevId: 416818325
2021-12-16 08:34:52 -08:00
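The caching idea in the commit message can be illustrated with a small, self-contained Python sketch. It does not use JAX or MLIR: `ToyModule`, `call_cached`, and `fake_threefry_lowering` are hypothetical names, and the cache key is assumed to be the primitive name plus each operand's shape and dtype. The point it shows is that the expensive lowering runs once per signature, while every use emits only a call to the shared function.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical cache key: primitive name plus each operand's (shape, dtype).
Signature = Tuple[str, Tuple[Tuple[Tuple[int, ...], str], ...]]


class ToyModule:
    """Toy stand-in for a module under construction; not JAX's real API."""

    def __init__(self) -> None:
        # Maps each signature to (function name, lowered body).
        self.functions: Dict[Signature, Tuple[str, List[str]]] = {}
        self.body: List[str] = []
        self.lowerings_run = 0

    def call_cached(self, sig: Signature, lower: Callable[[], List[str]]) -> str:
        """Lower `sig` once as an out-of-line function, then emit a call to it."""
        if sig not in self.functions:
            # Only the first use of a signature pays for the (possibly large) lowering.
            self.lowerings_run += 1
            name = f"{sig[0]}_{len(self.functions)}"
            self.functions[sig] = (name, lower())
        name, _ = self.functions[sig]
        self.body.append(f"call @{name}")
        return name


def fake_threefry_lowering() -> List[str]:
    # Hypothetical placeholder for a big lowering, e.g. one built via mlir.lower_fun.
    return [f"op_{i}" for i in range(1000)]


module = ToyModule()
key_sig: Signature = ("threefry2x32", (((2,), "uint32"), ((2,), "uint32")))
for _ in range(3):
    module.call_cached(key_sig, fake_threefry_lowering)

print(module.lowerings_run)  # 1
print(module.body)           # ['call @threefry2x32_0', 'call @threefry2x32_0', 'call @threefry2x32_0']
```

A call with a different signature (for example, operands of shape `(4,)`) would miss the cache and produce a second function, `threefry2x32_1`. As the commit notes, XLA may inline such functions later anyway; the saving is in tracing and emitting the lowering only once.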

This directory contains LLVM FileCheck tests that verify that JAX primitives can be lowered to MHLO.

These tests are intended to be a quick, easy-to-understand way to catch regressions caused by changes to the MLIR Python bindings and to the various MLIR dialects used by JAX, without needing to run the full JAX test suite.
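A FileCheck test feeds the text a program prints into LLVM's `FileCheck` tool, which matches it against the `CHECK:` patterns written in the test file. The Python sketch that follows models only the simplest rule, under stated simplifications: each pattern is treated as a literal substring that must appear on a line after the previous match. Regex blocks, variables, `CHECK-NEXT`, and `CHECK-NOT` are not modeled. The `check_in_order` helper and the MHLO snippet are hypothetical illustrations, not JAX's test harness or real compiler output.

```python
from typing import List


def check_in_order(output: str, patterns: List[str]) -> bool:
    """Simplified model of FileCheck's plain CHECK: directive.

    Each pattern is a literal substring that must occur on a line after the
    line matched by the previous pattern. This is NOT LLVM FileCheck itself.
    """
    lines = output.splitlines()
    pos = 0
    for pattern in patterns:
        # Scan forward for the first line containing the pattern.
        while pos < len(lines) and pattern not in lines[pos]:
            pos += 1
        if pos == len(lines):
            return False  # pattern not found after the previous match
        pos += 1  # the next pattern must match on a later line
    return True


# Hand-written, illustrative MHLO-style text for adding two scalars.
lowered = """\
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
  %0 = mhlo.add %arg0, %arg1 : tensor<f32>
  return %0 : tensor<f32>
}"""

print(check_in_order(lowered, ["func @main", "mhlo.add", "return"]))  # True
print(check_in_order(lowered, ["return", "mhlo.add"]))               # False
```

Because the patterns must match in order, a test like this fails if, for example, a change to a dialect renames `mhlo.add` or a lowering stops emitting it, which is the kind of regression these tests are meant to catch quickly.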