9 Commits

Author SHA1 Message Date
Peter Hawkins
3f91b4b43a Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00
Dan Foreman-Mackey
a7f384cc6e Add a register_custom_type_id function to the GPU plugins.
This enables dynamic registration of custom FFI types on the appropriate platform via PJRT.

PiperOrigin-RevId: 712904085
2025-01-07 07:29:38 -08:00
Ruturaj4
e8934b95eb [ROCm] Add rocm version information 2024-11-25 10:21:48 -06:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Michael Hudgins
7f3a90c63b Change references in setup.py and utilities to reference the JAX repo move to the JAX-ML org
PiperOrigin-RevId: 676838502
2024-09-20 07:32:15 -07:00
jax authors
e5eaff84bd Replace pjrt_c_api_gpu_plugin.so symlink with XLA dependency.
The runfiles of the original targets were lost when the symlinked files were used.

This change is needed for future Hermetic CUDA implementation. Bazel will download CUDA distributives in cache, and CUDA executables and libraries will be added in the runfiles of the targets. When pjrt_c_api_gpu_plugin.so is simlinked, the content of the runfiles is lost. With proper XLA target dependency the runfiles are preserved.

PiperOrigin-RevId: 662197057
2024-08-12 13:01:18 -07:00
Ruturaj4
a00d030248 [ROCM] nits and fixes 2024-06-18 20:21:23 +00:00
Ruturaj4
99c2b7b4e9 [ROCm] Bring-up pjrt support 2024-06-17 16:49:22 +00:00