Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00

Some users of JAX want to use the MLIR dialects defined in jaxlib. In particular, custom lowering rules need them. Add a semi-public (jax.extend) API that exposes these dialects, so that users no longer have to import jax._src.lib.mlir.

PiperOrigin-RevId: 588448489
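As a rough illustration of the import change described above, the sketch below builds an MLIR type through the semi-public path rather than the private one. It assumes a JAX release that ships `jax.extend.mlir` and that its `ir` submodule exposes the same MLIR Python bindings as `jax._src.lib.mlir.ir`. The helper name `ranked_f32_tensor_type` is hypothetical and used only for this example. Building operations from specific dialects may additionally require those dialects to be registered on the context, which is not shown here.

```python
# Assumption: this JAX release ships jax.extend.mlir (the API this change adds).
from jax.extend.mlir import ir  # previously: from jax._src.lib.mlir import ir


def ranked_f32_tensor_type(shape):
    """Hypothetical helper: textual form of an f32 ranked tensor type."""
    # MLIR objects belong to a Context, and type construction here needs a
    # default Location, so both are made active for the duration of the call.
    with ir.Context(), ir.Location.unknown():
        f32 = ir.F32Type.get()
        return str(ir.RankedTensorType.get(list(shape), f32))


print(ranked_f32_tensor_type((2, 3)))  # tensor<2x3xf32>
```

The example uses only builtin MLIR types, so it does not depend on which dialect modules a given jaxlib build exposes under `jax.extend.mlir.dialects`.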