
This works around another buggy part of the PTX documentation. While PTX explicitly says that TMA descriptors can be in global memory, the CUDA C++ programming guide heavily discourages this, because it can lead to incorrect results. That matches what we've sometimes observed: a cache coherency issue, unless a TMA fence is explicitly inserted at the beginning of the kernel.

Note that this approach has a big downside of making the kernel unsafe for concurrent use. I don't think that XLA:GPU will ever dispatch it concurrently, so I didn't insert any extra synchronization for now, but we should seriously consider it. My hope at the moment is that we'll be able to start passing in TMA descriptors as kernel arguments soon (pending upstreaming LLVM changes...) and we won't have to deal with this again.

For the programming guide, see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays

PiperOrigin-RevId: 643972675
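As a rough illustration (not the actual Mosaic GPU change), the fence being described is PTX's tensormap proxy acquire fence. Below is a minimal CUDA sketch, assuming a hypothetical helper name and a 128-byte TMA descriptor that lives in global memory; the real kernel emits the equivalent instruction through LLVM/MLIR rather than inline asm.

```cuda
// Hypothetical helper: make a TMA descriptor stored in global memory visible
// to the async proxy before any thread uses it for a TMA copy. The trailing
// 128 is the size of the descriptor in bytes, as required by the PTX ISA.
__device__ void tma_descriptor_fence_acquire(const void* desc) {
  asm volatile("fence.proxy.tensormap::generic.acquire.gpu [%0], 128;"
               :
               : "l"(desc)
               : "memory");
}

// Hypothetical kernel: the fence is issued once at the very beginning,
// before any thread issues a TMA transfer that reads the descriptor.
__global__ void my_kernel(const void* tma_desc_in_global_memory) {
  tma_descriptor_fence_acquire(tma_desc_in_global_memory);
  __syncthreads();
  // ... TMA copies using the descriptor would follow here ...
}
```

Passing the descriptor as a kernel argument instead (the hoped-for end state mentioned above) would make this per-kernel fence unnecessary, which is why it is described as a temporary workaround.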
jaxlib: support library for JAX
jaxlib is the support library for JAX. While JAX itself is a pure Python package, jaxlib contains the binary (C/C++) parts of the library, including Python bindings, the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. For more information, including installation and build instructions, refer to the main JAX README: https://github.com/google/jax/.