Adam Paszke 4ea73bf787 Use constant memory to pass in TMA descriptors to the kernel
This works around another buggy part of the PTX documentation. While PTX
explicitly says that TMA descriptors can reside in global memory, the C++
programming guide heavily discourages this, because it can lead to
incorrect results. This matches what we have sometimes observed: a cache
coherency issue, unless a TMA fence is explicitly inserted at the
beginning of the kernel.

Note that this approach has the big downside of making the kernel unsafe
for concurrent use. I don't think that XLA:GPU will ever dispatch it
concurrently, so I didn't insert any extra synchronization for now, but
we should seriously consider adding it. My hope at the moment is that we'll
soon be able to start passing TMA descriptors as kernel arguments (pending
upstreaming LLVM changes...) and we won't have to deal with this again.

For the programming guide, see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays
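A minimal CUDA sketch of the trade-off described above (identifier names are illustrative, not the actual Mosaic GPU lowering): the descriptor can live in a `__constant__` slot, be passed by value as a `__grid_constant__` kernel parameter (the hoped-for end state), or stay in global memory at the cost of an explicit tensormap fence:

```cuda
#include <cuda.h>  // CUtensorMap (CUDA driver API)

// Approach taken here: the host copies the descriptor into this __constant__
// slot (e.g. via cuMemcpyHtoD on the symbol) before launch. Constant memory
// is read through a coherent path, so no fence is needed -- but the slot is
// a single piece of mutable per-module state, which is exactly what makes
// concurrent launches of the kernel unsafe.
__constant__ CUtensorMap tma_desc;

__global__ void kernel_const() {
  // cp.async.bulk.tensor.* instructions reference &tma_desc; no fence needed.
}

// Hoped-for end state: pass the descriptor by value as a kernel argument.
__global__ void kernel_arg(const __grid_constant__ CUtensorMap desc) {
  // TMA instructions reference &desc directly; no shared mutable state.
}

// The alternative this change avoids: descriptor in global memory, which
// requires a tensormap proxy fence at kernel start to make host-side writes
// visible to the TMA unit.
__global__ void kernel_global(const CUtensorMap* desc) {
  asm volatile("fence.proxy.tensormap::generic.acquire.gpu [%0], 128;"
               :: "l"(desc) : "memory");
  // ... TMA copies using *desc ...
}
```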

PiperOrigin-RevId: 643972675
2024-06-17 05:31:26 -07:00

jaxlib: support library for JAX

jaxlib is the support library for JAX. While JAX itself is a pure Python package, jaxlib contains the binary (C/C++) parts of the library, including Python bindings, the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. For more information, including installation and build instructions, refer to the main JAX README: https://github.com/google/jax/.