mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00

XLA:GPU custom call design is far from ideal, as there's apparently no way to figure out the CUDA context that will be used to run an HLO module before the custom call is first called. So, we can't preload the kernel onto the GPU, or else we'll get invalid handle errors due to the load and launch happening in different CUDA contexts...

Also fix up build_wheel.py to match the rename of the runtime lib.

PiperOrigin-RevId: 629401858
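
For illustration only, the lazy-load pattern described above might look roughly like the sketch below. It is not the code from this change: `LazyKernel`, `fake_load`, and the string context names are hypothetical stand-ins, and no real CUDA or XLA API is called. The idea is that the kernel is loaded the first time it is launched, inside whatever context is active at that point, so load and launch always share a context.

```python
import threading


class LazyKernel:
    """Hypothetical sketch: load a kernel once per context, on first launch."""

    def __init__(self, load_fn):
        self._load_fn = load_fn  # called as load_fn(context) -> callable handle
        self._handles = {}       # context -> handle loaded in that context
        self._lock = threading.Lock()

    def launch(self, context, *args):
        # The context is only known at call time, so the load happens here
        # rather than ahead of time.
        with self._lock:
            handle = self._handles.get(context)
            if handle is None:
                handle = self._load_fn(context)
                self._handles[context] = handle
        return handle(*args)


# Stand-in for a real module load: records which contexts triggered a load.
loads = []


def fake_load(context):
    loads.append(context)

    def handle(x):
        return (context, x * 2)

    return handle


kernel = LazyKernel(fake_load)
first = kernel.launch("ctx0", 3)   # triggers a load in ctx0
second = kernel.launch("ctx0", 4)  # reuses the handle loaded in ctx0
other = kernel.launch("ctx1", 5)   # a different context gets its own load
```

Keying the cache on the context is one possible design choice; it avoids reusing a handle that was loaded in a different context, which is the failure mode the message describes.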