28 Commits

Author SHA1 Message Date
Peter Hawkins
3f91b4b43a Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00
jax authors
1aca76fc13 Update :build_jaxlib flag to control whether we should add py_import dependencies to the test targets.
This change enables testing the wheels produced by the build rules in the presubmit using one `bazel test` command only.

There are three options for running the tests:

1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.

PiperOrigin-RevId: 735765819
2025-03-11 08:31:43 -07:00
Sergei Lebedev
7929cd8410 [pallas:triton] The lowering now uses PTX instead of Triton IR
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.

See #25196.

A few notes

* Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead,
  compilation is done via a new PjRt extension, which uses its own compilation
  pipeline mirrored after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
  deprecated and will be removed after 6 months as per
  [compatibility guarantees] [*]

[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees

PiperOrigin-RevId: 722773884
2025-02-03 13:21:40 -08:00
Dan Foreman-Mackey
a7f384cc6e Add a register_custom_type_id function to the GPU plugins.
This enables dynamic registration of custom FFI types on the appropriate platform via PJRT.

PiperOrigin-RevId: 712904085
2025-01-07 07:29:38 -08:00
jax authors
6d4278d290 Merge pull request #25091 from gspschmid:gschmid/nvidia-cuda-nvcc-cu12_12-6-85
PiperOrigin-RevId: 701143135
2024-11-28 18:47:55 -08:00
Georg Stefan Schmid
9866372d31 [cuda] Bump nvidia-cuda-nvcc-cu12 dependency to 12.6.85 2024-11-25 17:42:36 +00:00
Ruturaj4
e8934b95eb [ROCm] Add rocm version information 2024-11-25 10:21:48 -06:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
Peter Hawkins
e4790b634e Don't pass --nocheck_visibility to Bazel.
This no longer appears to be needed.
2024-09-27 11:19:42 -04:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Michael Hudgins
7f3a90c63b Change references in setup.py and utilities to reference the JAX repo move to the JAX-ML org
PiperOrigin-RevId: 676838502
2024-09-20 07:32:15 -07:00
jax authors
e5eaff84bd Replace pjrt_c_api_gpu_plugin.so symlink with XLA dependency.
The runfiles of the original targets were lost when the symlinked files were used.

This change is needed for future Hermetic CUDA implementation. Bazel will download CUDA distributives in cache, and CUDA executables and libraries will be added in the runfiles of the targets. When pjrt_c_api_gpu_plugin.so is simlinked, the content of the runfiles is lost. With proper XLA target dependency the runfiles are preserved.

PiperOrigin-RevId: 662197057
2024-08-12 13:01:18 -07:00
Peter Hawkins
d1c0d993fc Bump the minimum CUDNN version to v9.1.
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.

Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.

PiperOrigin-RevId: 657225917
2024-07-29 09:28:47 -07:00
Vadym Matsishevskyi
f089ecc47a Fix gpu_jax_head_jaxlib_pypi_latest job after migrating to plugin structure for jaxlib dependency
PiperOrigin-RevId: 648863763
2024-07-02 15:35:32 -07:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
jax authors
96cf5d53c8 Merge pull request #21916 from ROCm:ci_pjrt
PiperOrigin-RevId: 646793145
2024-06-26 02:43:21 -07:00
Ruturaj4
a00d030248 [ROCM] nits and fixes 2024-06-18 20:21:23 +00:00
Ruturaj4
99c2b7b4e9 [ROCm] Bring-up pjrt support 2024-06-17 16:49:22 +00:00
Peter Hawkins
b13733c13f Update JAX dependencies, extras, and documentation for plugins.
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
2024-06-13 11:36:23 -04:00
jax authors
0be07e6aec Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

Reverts 910a31d7b7510e3375718ab1ea0d38df7bd2c0d5

PiperOrigin-RevId: 618911489
2024-03-25 11:46:39 -07:00
jax authors
910a31d7b7 Reverts bed4f65438a62777ed100ecec2b0eb3f7cf87a0e
PiperOrigin-RevId: 618249855
2024-03-22 12:10:53 -07:00
jax authors
bed4f65438 Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

PiperOrigin-RevId: 618195554
2024-03-22 09:05:39 -07:00
Jieying Luo
087f99a31c Support mocking number of GPUs in CUDA plugin.
Also move reading jax config value to be right before the client is created. Previously they were read before calling register_plugin, which happens during import and before any call of jax.config.update.

The decorator in mock_gpu_test was used wrongly. jtu.run_on_devices will create the client before jax.config.update is called, which is not desired. Remove the decorator will not fail CPU/TPU tests because the mesh will check the num_shard and the number of devices in the client and skip it if it does not match.

generate_pjrt_gpu_plugin_options is only used in places that do not require compatibility so do not need to update xla_client version.

PiperOrigin-RevId: 611610915
2024-02-29 15:15:06 -08:00
Jieying Luo
29f1d3b033 [PJRT C API] Use xla_client.generate_pjrt_gpu_plugin_options to generate options for CUDA plugin.
PiperOrigin-RevId: 603074180
2024-01-31 09:42:58 -08:00
Jieying Luo
1559d6495e Remove local version in jax-cuda-plugin and jax-cuda-pjrt package.
PiperOrigin-RevId: 591057013
2023-12-14 14:44:49 -08:00
Jieying Luo
d6c5910105 [PJRT C API] Move cuda_plugin_extension from jaxlib to jax-cuda-plugin (the package for cuda kernels).
PiperOrigin-RevId: 583406466
2023-11-17 09:11:46 -08:00
Jieying Luo
88685d8de0 Support bazel test without bazel build for CUDA PJRT plugin.
- Add build target for jax_plugins/ and jax_plugins/cuda for bazel test.
- Update jax_plugins/cuda/__init__.py to fallback to local `.so` file path.
- Add a flag --//jax:build_cuda_plugin to control whether to link in local cuda plugin.

The following command will test with cuda plugin:
```
bazel test tests:python_callback_test_gpu --test_output=all --test_filter=PythonCallbackTest.test_send_zero_dim_arrays_pure --config=tensorflow_testing_rbe_linux --config=rbe_linux_cuda12.2_nvcc_py3.9 --//jax:build_cuda_plugin=false
```

Default behavior (without `--//jax:build_cuda_plugin=false`) remains unchanged.

PiperOrigin-RevId: 582728477
2023-11-15 10:38:19 -08:00
Jieying Luo
ec21e04201 [PJRT C API] Rename the folder "plugins" to "jax_plugins".
With this change, existing plugin discovery mechanism can discover local plugins without pip install.

Update jax_plugins/cuda/__init__.py to return without registering the plugin if the .so file does not exist.

PiperOrigin-RevId: 582431300
2023-11-14 13:56:13 -08:00