Adam Paszke 9b0319512a [Mosaic GPU] Use a custom TMA descriptor initialization method
The one bundled with the default MLIR runtime is convenient, but it is also
impractical. It allocates memory (which can deadlock due to NCCL), does a
synchronous host-to-device copy and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.
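The packing idea described above can be illustrated with a small Python sketch. This is not Mosaic GPU's runtime function. `pack_descriptors`, `DESCRIPTOR_SIZE` and `DESCRIPTOR_ALIGN` are hypothetical names. The 128-byte size and 64-byte alignment are an assumption based on the CUDA driver API's `CUtensorMap` type. The sketch only shows how several descriptors can share one contiguous buffer, each identified by an offset, so that a single copy moves all of them.

```python
from typing import List, Tuple

# Assumption: sizes mirror CUDA's CUtensorMap (128 bytes, 64-byte aligned).
DESCRIPTOR_SIZE = 128
DESCRIPTOR_ALIGN = 64


def _align_up(n: int, align: int) -> int:
    """Round n up to the next multiple of align."""
    return (n + align - 1) // align * align


def pack_descriptors(descriptors: List[bytes]) -> Tuple[bytearray, List[int]]:
    """Hypothetical helper: place every descriptor in one aligned buffer.

    Returns the packed buffer and the byte offset of each descriptor, so a
    kernel could receive one base pointer plus per-descriptor offsets.
    """
    offsets: List[int] = []
    size = 0
    for d in descriptors:
        if len(d) != DESCRIPTOR_SIZE:
            raise ValueError(f"expected {DESCRIPTOR_SIZE}-byte descriptor, got {len(d)}")
        size = _align_up(size, DESCRIPTOR_ALIGN)  # keep each slot aligned
        offsets.append(size)
        size += DESCRIPTOR_SIZE
    buf = bytearray(size)
    for off, d in zip(offsets, descriptors):
        buf[off:off + DESCRIPTOR_SIZE] = d  # copy descriptor into its slot
    return buf, offsets


if __name__ == "__main__":
    descs = [bytes([i]) * DESCRIPTOR_SIZE for i in range(3)]
    buf, offsets = pack_descriptors(descs)
    print(offsets, len(buf))
```

Because the whole buffer is contiguous, transferring it to the device takes one asynchronous copy instead of one synchronous copy per descriptor, which is the property the change relies on.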

PiperOrigin-RevId: 628430358
2024-04-26 09:40:47 -07:00

jaxlib: support library for JAX

jaxlib is the support library for JAX. While JAX itself is a pure Python package, jaxlib contains the binary (C/C++) parts of the library, including Python bindings, the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. For more information, including installation and build instructions, refer to the main JAX README: https://github.com/google/jax/.