mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00

The descriptor-creation function bundled with the default MLIR runtime was convenient, but it is also impractical: it allocates memory (which can deadlock due to NCCL), performs a synchronous host-to-device copy, and then leaks the descriptor after the kernel runs. With this change, we use our own runtime function to create all the descriptors. We also pack them all into a single buffer so that a single asynchronous copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer, which lets XLA:GPU handle memory management.

PiperOrigin-RevId: 628430358
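The single-buffer packing described above can be sketched in Python. This is a minimal host-side illustration under stated assumptions, not the actual jaxlib runtime code. The 128-byte descriptor size, the 64-byte alignment, and the `pack_descriptors` helper are all hypothetical and were chosen only to show the idea. Once every descriptor sits at a known offset in one contiguous buffer, a single asynchronous host-to-device copy can move all of them, and a kernel can locate descriptor `i` at `base + offsets[i]`.

```python
# Hypothetical sketch: pack several fixed-size descriptors into one
# contiguous host buffer so a single host-to-device copy can move them all.
DESCRIPTOR_SIZE = 128  # assumption: every descriptor occupies 128 bytes
ALIGNMENT = 64         # assumption: each descriptor starts on a 64-byte boundary


def align_up(offset: int, alignment: int) -> int:
    """Round offset up to the next multiple of alignment."""
    return (offset + alignment - 1) // alignment * alignment


def pack_descriptors(descriptors: list[bytes]) -> tuple[bytearray, list[int]]:
    """Return one packed buffer plus the byte offset of each descriptor in it."""
    offsets = []
    offset = 0
    # First pass: compute an aligned offset for each descriptor.
    for desc in descriptors:
        if len(desc) != DESCRIPTOR_SIZE:
            raise ValueError(f"expected {DESCRIPTOR_SIZE} bytes, got {len(desc)}")
        offset = align_up(offset, ALIGNMENT)
        offsets.append(offset)
        offset += DESCRIPTOR_SIZE
    # Second pass: copy each descriptor into its slot in the shared buffer.
    buffer = bytearray(offset)
    for desc, off in zip(descriptors, offsets):
        buffer[off:off + DESCRIPTOR_SIZE] = desc
    return buffer, offsets


# Usage: three dummy descriptors end up back to back in one 384-byte buffer.
descs = [bytes([i]) * DESCRIPTOR_SIZE for i in range(3)]
buf, offs = pack_descriptors(descs)
print(offs, len(buf))
```

Because the buffer is built on the host in one piece, the device side needs only one transfer and one base pointer, rather than one allocation and one synchronous copy per descriptor.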
jaxlib: support library for JAX
jaxlib is the support library for JAX. While JAX itself is a pure Python package, jaxlib contains the binary (C/C++) parts of the library, including Python bindings, the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. For more information, including installation and build instructions, refer to the main JAX README: https://github.com/google/jax/.