49 Commits

Author SHA1 Message Date
Peter Hawkins
41f0b336e3 Add minimum version checks for cublas and cusparse.
Split code to determine CUDA library versions out of py_extension() module and into a cc_library(), because it fixes a linking problem in Google's build. (Long story, not worth it.)

Fixes https://github.com/google/jax/issues/8289

PiperOrigin-RevId: 583544218
2023-11-17 19:30:41 -08:00
jax authors
88fe0da6d1 Merge pull request #18078 from ROCmSoftwarePlatform:rocm-jax-triton
PiperOrigin-RevId: 574546618
2023-10-18 11:56:01 -07:00
Jieying Luo
7478fbcfd5 [PJRT C API] Add "cuda_plugin_extension" to "gpu_only_test_deps" to support bazel test for GPU plugin.
PiperOrigin-RevId: 573251982
2023-10-13 10:12:16 -07:00
Peter Hawkins
2eca5b34b3 Add a compile-time version test that verifies CUDA is version 11.8 or newer.
Issue https://github.com/google/jax/issues/17829

PiperOrigin-RevId: 569302585
2023-09-28 15:14:04 -07:00
Peter Hawkins
53845615ff Disable nanobind leak checker in cuda/versions module.
The leak checker appears to be sensitive to the destruction order during Python shutdown.

PiperOrigin-RevId: 568962933
2023-09-27 14:43:20 -07:00
Peter Hawkins
9404518201 [CUDA] Add code to jax initialization that verifies that the CUDA libraries that are found are at least as new as the versions against which JAX was built.
This is intended to flag cases where the wrong CUDA libraries are used, either because:
* the user self-installed CUDA and that installation is too old, or
* the user used the pip package installation, but due to LD_LIBRARY_PATH overrides or similar we didn't end up using the pip-installed version.

PiperOrigin-RevId: 568910422
2023-09-27 11:28:40 -07:00
Peter Hawkins
8c70288b83 Refer to CUDA stubs directly from TSL, rather than using an alias defined in xla/stream_executor.
Remove the aliases in xla/stream_executor.

PiperOrigin-RevId: 567025507
2023-09-20 11:21:54 -07:00
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
Richard Levasseur
f891cbf64b Load Python rules from rules_python
PiperOrigin-RevId: 559789250
2023-08-24 10:22:57 -07:00
Chris Jones
f70f1f8006 Internal change.
PiperOrigin-RevId: 559053761
2023-08-22 03:05:17 -07:00
Chris Jones
4ac2bdc2b1 [jax_triton] Add user-specified name field to serialized format.
PiperOrigin-RevId: 557415723
2023-08-16 02:53:51 -07:00
Chris Jones
9935445d57 [jax_triton] Simplify auto-tuning code.
PiperOrigin-RevId: 545733541
2023-07-05 11:18:18 -07:00
Chris Jones
31b862dd56 [jax_triton] Split C++ only parts of Triton custom callback from Python parts.
Register callback with default call target name from C++, enabling Triton calls with the default name to work in C++ only contexts (e.g. serving).

PiperOrigin-RevId: 545211452
2023-07-03 06:52:32 -07:00
Chris Jones
d4e2464340 [jax_triton] Expose Triton custom call callback in header file.
This allows users to register the callback from C++ when not using the default call target name.

PiperOrigin-RevId: 544029098
2023-06-28 05:32:02 -07:00
Chris Jones
b3527f3975 Zlib compress kernel proto.
PiperOrigin-RevId: 542529065
2023-06-22 05:22:53 -07:00
Chris Jones
f238667492 Make JAX-Triton calls serializable.
PiperOrigin-RevId: 542524794
2023-06-22 04:57:14 -07:00
Chris Jones
64e73270ff Use EncapsulateFunction utility.
PiperOrigin-RevId: 542299099
2023-06-21 10:37:52 -07:00
Sharad Vikram
1279418ce5 Link in CUDA runtime for triton in jaxlib
PiperOrigin-RevId: 535708416
2023-05-26 14:02:16 -07:00
Chris Jones
ea37043577 Switch to STATUS_RETURNING callback API.
PiperOrigin-RevId: 535568707
2023-05-26 03:15:44 -07:00
Chris Jones
2155b9181f Switch to using JAX status macros in jax-triton kernel call lib.
PiperOrigin-RevId: 535300412
2023-05-25 10:26:06 -07:00
Chris Jones
6b13d4eb86 Add branch prediction to JAX status macros.
PiperOrigin-RevId: 535233546
2023-05-25 06:23:23 -07:00
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Peter Hawkins
b62f114524 Add support for using pip-installed CUDA wheels.
Add a currently undocumented jax[cuda11_pip] and jax[cuda12_pip] that depend on the pip CUDA wheels.
Add a currently undocumented jax[cuda11_local] and jax[cuda12_local] that avoid the CUDA wheel dependency.
2023-03-26 12:35:00 +00:00
Peter Hawkins
ab45383038 Fix build breakage from OpenXLA switch.
PiperOrigin-RevId: 516325478
2023-03-13 14:37:35 -07:00
jax authors
42ef649e65 Merge pull request #14475 from hawkinsp:openxla
PiperOrigin-RevId: 516316330
2023-03-13 14:04:41 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Yash Katariya
3e5a5053f4 Run GPU presubmits via bazel test on the RBE cluster. This speeds up the build + testing significantly (upto 10x).
But run the continuous builds by building on RBE and testing locally so as to run the multiaccelerator tests too. Locally we have 4 GPUs available.

Also make GPU presubmits blocking for JAX (re-enabled it).

PiperOrigin-RevId: 491647775
2022-11-29 08:45:58 -08:00
Qiao Zhang
c54bc90bf4 Fix cudnn_header OSS BUILD dep.
PiperOrigin-RevId: 491465703
2022-11-28 15:58:55 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
Tianjian Lu
e219d55c36 Roll-back #12892 because CUSPARSE_SPMV_COO_ALG2 is not available in CUDA 11.1
PiperOrigin-RevId: 482897448
2022-10-21 15:06:17 -07:00
Tianjian Lu
7093142f61 [sparse] Update the default cuSparse matvec algorithm in jaxlib.
PiperOrigin-RevId: 482553550
2022-10-20 11:49:09 -07:00
Peter Hawkins
5617a02fa4 Remove JAX custom call implementation of batched triangular solve.
XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation.

PiperOrigin-RevId: 482031830
2022-10-18 15:04:14 -07:00
Peter Hawkins
2693afa263 Revert: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.

PiperOrigin-RevId: 479670565
2022-10-07 14:36:24 -07:00
Peter Hawkins
93b839ace4 Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

PiperOrigin-RevId: 479642543
2022-10-07 12:22:04 -07:00
Artem Belevich
2de91d26b7 Handle FP8 types.
PiperOrigin-RevId: 479148993
2022-10-05 14:48:30 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
David Dunleavy
a8aa774a57 Use tensorflow/compiler/xla/stream_executor instead of tensorflow/stream_executor
PiperOrigin-RevId: 470804752
2022-08-29 13:46:20 -07:00
Tianjian Lu
d37b711dd4 [sparse] Add batch count and batch stride to matrix descriptors.
PiperOrigin-RevId: 468760351
2022-08-19 12:26:17 -07:00
Deniz Oktay
d5de596d17 Sparse direct solver via QR factorization CUDA implementation.
PiperOrigin-RevId: 468467698
2022-08-18 08:46:25 -07:00
Peter Hawkins
3bb0030014 Revert: Sparse direct solver using QR factorization from cuSOLVER. This is the jaxlib implementation. We will want to combine this with the sparse libraries already existing in JAX.
Reason: Breaks JAX tests.
PiperOrigin-RevId: 468346430
2022-08-17 18:54:29 -07:00
Deniz Oktay
2bc3e39cd9 Sparse direct solver using QR factorization from cuSOLVER. This is the jaxlib implementation. We will want to combine this with the sparse libraries already existing in JAX.
PiperOrigin-RevId: 468303019
2022-08-17 15:10:27 -07:00
Tianjian Lu
07da502323 [sparse] Enable batch mode of COO matmat from cusparse kernels.
PiperOrigin-RevId: 465405490
2022-08-04 14:30:02 -07:00
Peter Hawkins
aa7d291767 Replace references to absl::string_view with std::string_view.
PiperOrigin-RevId: 450768333
2022-05-24 14:21:32 -07:00
Peter Hawkins
bb0816227d Add a batched QR decomposition implementation on GPU.
PiperOrigin-RevId: 449583027
2022-05-18 14:50:18 -07:00
Peter Hawkins
08c3c2ec24 Split CUDA and HIP C++ code in jaxlib into separate directories.
PiperOrigin-RevId: 447062506
2022-05-06 13:48:00 -07:00