Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 20:36:05 +00:00)

This lets us break a dependency on standard MLIR dialects while serializing the program into HLO. The scheme is simple: we make a lightweight, lazy fork of the existing dialects by mangling the dialect name and otherwise keeping the structure of the ops identical. This keeps serialization and deserialization simple for as long as the upstream dialects don't change much. If they do, we have to increment our version counter and write rules that update the IR structure.

Note that this scheme only protects us from changes such as modifications to the attributes annotating the ops (renaming, etc.). It does not protect us from changes to the attributes that a dialect itself defines. Still, as far as I can tell, the only such attributes we depend on are enums (which are plain integer attributes, so we can remap their values) and affine maps (which are unlikely to change much, I hope).

This does not actually wire up the pass yet, as we are currently reorganizing the Python/C++ boundary significantly. The integration should be completed once that work is done.

PiperOrigin-RevId: 595128374
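The scheme described above can be illustrated with a minimal Python sketch. It does not use the MLIR API and is not the actual pass: the `Op` class, the `frozen.` prefix, the version numbers, and the `pad` to `padding` rename are all hypothetical names chosen for illustration. It only shows the idea of mangling dialect names on the way out and applying versioned upgrade rules on the way back in.

```python
from dataclasses import dataclass, field

# Hypothetical values: the commit message does not state the real prefix
# or how the version counter is stored.
MANGLED_PREFIX = "frozen."
CURRENT_VERSION = 2


@dataclass
class Op:
    name: str  # fully qualified op name, e.g. "vector.load"
    attrs: dict = field(default_factory=dict)


def serialize(ops):
    """Mangle the dialect name of every op; the op structure is unchanged."""
    mangled = [Op(MANGLED_PREFIX + op.name, dict(op.attrs)) for op in ops]
    return {"version": CURRENT_VERSION, "ops": mangled}


def _upgrade_from_v1(op):
    # Illustrative rule only: pretend upstream renamed an attribute
    # between version 1 and version 2 of the serialized format.
    if op.name == MANGLED_PREFIX + "vector.load" and "pad" in op.attrs:
        op.attrs["padding"] = op.attrs.pop("pad")


# Upgrade rules keyed by the version they upgrade *from*.
UPGRADE_RULES = {1: _upgrade_from_v1}


def deserialize(module):
    """Upgrade old IR step by step, then strip the mangled prefix."""
    version = module["version"]
    if version > CURRENT_VERSION:
        raise ValueError(f"unsupported serialized version: {version}")
    ops = [Op(op.name, dict(op.attrs)) for op in module["ops"]]
    for v in range(version, CURRENT_VERSION):
        rule = UPGRADE_RULES.get(v)
        if rule is not None:
            for op in ops:
                rule(op)
    for op in ops:
        if not op.name.startswith(MANGLED_PREFIX):
            raise ValueError(f"op is not in a mangled dialect: {op.name}")
        op.name = op.name[len(MANGLED_PREFIX):]
    return ops
```

In this sketch, a module written at the current version round-trips unchanged, while a module tagged with version 1 passes through `_upgrade_from_v1` before its op names are unmangled. Remapping enum values, as mentioned above, would be one more rule of the same shape that rewrites an integer attribute.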
jaxlib: support library for JAX
jaxlib is the support library for JAX. While JAX itself is a pure Python package, jaxlib contains the binary (C/C++) parts of the library, including Python bindings, the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. For more information, including installation and build instructions, refer to the main JAX README: https://github.com/google/jax/.