Benjamin Chetioui a9ab614123 [Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using warpgroup semantics.
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.

This is in preparation for changes implementing transform inference.

PiperOrigin-RevId: 732091266
2025-02-28 04:38:25 -08:00

jaxlib: support library for JAX

jaxlib is the support library for JAX. While JAX itself is a pure Python package, jaxlib contains the binary (C/C++) parts of the library, including Python bindings, the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. For more information, including installation and build instructions, refer to the main JAX README: https://github.com/jax-ml/jax/.