The descriptor-creation function bundled with the default MLIR runtime is convenient, but impractical: it allocates memory (which can deadlock due to NCCL), performs a synchronous host-to-device copy, and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors. We also pack them all into a single buffer, so that a single asynchronous copy is sufficient. Finally, we allocate the scratch buffer through a scratch output, which lets us rely on XLA:GPU for memory management.

PiperOrigin-RevId: 628430358
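The packing step can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the actual runtime code. Each descriptor is treated as an opaque fixed-size blob (128 bytes with 64-byte alignment, the size of a CUDA tensor-map descriptor; which descriptors the commit means is an assumption). `make_descriptor` is a hypothetical stand-in for the runtime function. `upload` stands in for the single asynchronous host-to-device copy into the scratch buffer.

```python
import struct

# Assumption: descriptors are opaque fixed-size blobs. 128 bytes / 64-byte
# alignment matches a CUDA tensor-map descriptor; the real layout is not
# shown in the commit.
DESCRIPTOR_SIZE = 128
DESCRIPTOR_ALIGN = 64


def make_descriptor(index: int) -> bytes:
    """Hypothetical stand-in for the runtime function that encodes one descriptor."""
    payload = struct.pack("<Q", index)  # 8-byte tag so the blobs are distinguishable
    return payload + bytes(DESCRIPTOR_SIZE - len(payload))


def pack_descriptors(descriptors: list[bytes]) -> tuple[bytearray, list[int]]:
    """Pack descriptors back to back into one host buffer.

    Returns the buffer and the byte offset of each descriptor, so that after a
    single copy descriptor i lives at (device_base + offsets[i]).
    """
    # Round the per-descriptor stride up to the required alignment.
    stride = -(-DESCRIPTOR_SIZE // DESCRIPTOR_ALIGN) * DESCRIPTOR_ALIGN
    buffer = bytearray(stride * len(descriptors))
    offsets = []
    for i, desc in enumerate(descriptors):
        if len(desc) != DESCRIPTOR_SIZE:
            raise ValueError(f"descriptor {i} has wrong size {len(desc)}")
        offset = i * stride
        buffer[offset:offset + DESCRIPTOR_SIZE] = desc
        offsets.append(offset)
    return buffer, offsets


def upload(host_buffer: bytearray, scratch: bytearray) -> None:
    """Hypothetical stand-in for one async host-to-device copy into the scratch buffer."""
    if len(scratch) < len(host_buffer):
        raise ValueError("scratch buffer too small for packed descriptors")
    # Same-length slice assignment: overwrites bytes without resizing scratch.
    scratch[:len(host_buffer)] = host_buffer


descriptors = [make_descriptor(i) for i in range(3)]
host_buffer, offsets = pack_descriptors(descriptors)
scratch = bytearray(512)  # stands in for the XLA-allocated scratch output
upload(host_buffer, scratch)
```

Because each descriptor sits at a fixed, aligned offset, the kernel needs only the base pointer of the scratch buffer plus `offsets[i]`. One copy of `len(host_buffer)` bytes replaces one synchronous copy per descriptor. The sketch assumes the destination buffer itself is at least 64-byte aligned.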