54 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
b6306e3953 Remove synchronization from GPU LU decomposition kernel by adding an async batch pointers builder.
In the batched LU decomposition in cuBLAS, the output buffer must be passed as an array of pointers to the individual batch matrices. Previously this array was built on the host and then copied to the device, requiring a synchronization, but it is straightforward to instead implement a tiny CUDA kernel to do this work. This definitely isn't a bottleneck or a high-priority change, but it seemed like a reasonable time to fix a longstanding TODO.

PiperOrigin-RevId: 663686539
2024-08-16 04:37:09 -07:00
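The pointer-array computation the commit moves onto the device is simple enough to sketch on the host. This is an illustrative C++ version (function and parameter names are hypothetical, not the actual jaxlib ones); in the commit, each assignment in the loop would instead be performed by one thread of a tiny CUDA kernel, avoiding the host-to-device copy and its synchronization.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch: given the base address of a contiguous
// [batch, rows, cols] buffer, fill `out` with a pointer to each batch matrix.
// On device, thread i of a trivial kernel would perform assignment i.
void MakeBatchPointers(char* base, std::int64_t batch,
                       std::int64_t bytes_per_matrix,
                       std::vector<void*>& out) {
  out.resize(batch);
  for (std::int64_t i = 0; i < batch; ++i) {
    out[i] = base + i * bytes_per_matrix;
  }
}
```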
Dan Foreman-Mackey
ad1bd38790 Move logic about when to dispatch to batched LU decomposition algorithm on GPU into the kernel.
This simplifies the lowering logic, and means that we don't get hit with a performance penalty when exporting with shape polymorphism.

PiperOrigin-RevId: 662945116
2024-08-14 09:20:40 -07:00
Dan Foreman-Mackey
8df0c3a9cc Port Getrf GPU kernel from custom call to FFI.
PiperOrigin-RevId: 658550170
2024-08-01 15:02:25 -07:00
Dan Foreman-Mackey
f20efc630f Move jaxlib GPU handlers to separate build target.
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into a new target.

PiperOrigin-RevId: 658497960
2024-08-01 12:30:04 -07:00
Dan Foreman-Mackey
33a9db3943 Move FFI helper macros from jaxlib/cpu/lapack_kernels.cc to a jaxlib/ffi_helpers.h.
Some of the macros that were used in jaxlib's FFI calls to LAPACK turned out to
be useful for other FFI calls. This change consolidates these macros in the
ffi_helpers.h header.

PiperOrigin-RevId: 651166306
2024-07-10 15:09:45 -07:00
Dan Foreman-Mackey
4f394828e1 Fix C++ registration of FFI handlers and consolidate gpu/linalg kernel implementation.
This change does a few things (arguably too many):

1. The key change here is that it fixes the handler registration in `jaxlib/gpu/gpu_kernels.cc` for the two handlers that use the XLA FFI API. A previous attempt at this change caused downstream issues because of duplicate registrations, but we were able to fix that directly in XLA.

2. A second related change is to declare and define the XLA FFI handlers consistently using the `XLA_FFI_DECLARE_HANDLER_SYMBOL` and `XLA_FFI_DEFINE_HANDLER_SYMBOL` macros. We need to use these macros instead of the `XLA_FFI_DEFINE_HANDLER` version which produces a lambda, so that when XLA checks the address of the handler during registration it is consistent. Without this change, the downstream tests would continue to fail.

3. The final change is to consolidate the `cholesky_update_kernel` and `lu_pivot_kernels` implementations into a common `linalg_kernels` target. This makes the implementation of the `_linalg` nanobind module consistent with the other targets within `jaxlib/gpu`, and (I think!) makes the details easier to follow. This last change is less urgent, but it was what I set out to do, so I'm proposing them all together; I can split this in two if that would be preferred.

PiperOrigin-RevId: 651107659
2024-07-10 12:09:12 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if a new enough jaxlib is used.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
Dan Foreman-Mackey
9ae1c56c44 Update lu_pivots_to_permutation to use FFI dimensions on GPU.
The XLA FFI interface provides metadata about buffer dimensions, so quantities
like batch dimensions can be evaluated on the backend, instead of passed as
attributes. This change has the added benefit of allowing this FFI call to
support "vectorized" vmap and dynamic shapes.

PiperOrigin-RevId: 647343656
2024-06-27 09:27:15 -07:00
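The shape-driven batch computation described above can be sketched in a few lines. This is a hedged illustration (the function name and the convention of two trailing matrix dimensions are assumptions for this example, not the actual jaxlib code): given the dimensions the FFI reports for an operand of shape [batch..., rows, cols], the batch count is just the product of all leading dimensions, so it no longer needs to be baked in as a custom-call attribute.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch: derive the batch count from the buffer dimensions
// supplied by the FFI, rather than from a serialized attribute. All but the
// trailing matrix dimensions are treated as batch dimensions.
std::int64_t BatchSize(const std::vector<std::int64_t>& dims,
                       int trailing_matrix_dims = 2) {
  std::int64_t batch = 1;
  for (std::size_t i = 0; i + trailing_matrix_dims < dims.size(); ++i) {
    batch *= dims[i];
  }
  return batch;
}
```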
Thomas Köppe
cd93b46df4 Add initialization annotations (for the benefit of MSAN) to variables that are initialized by external functions.
PiperOrigin-RevId: 641879836
2024-06-10 06:21:16 -07:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling yet another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
jax authors
e8b06ccf56 Cholesky rank-1 update kernel for JAX.
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
Sergei Lebedev
51fc4f85ad Ported LuPivotsToPermutation to the typed XLA FFI
The typed FFI

* allows passing custom call attributes directly to backend_config=, instead
  of serializing them into a C++ struct;
* handles validation and deserialization of custom call operands.

PiperOrigin-RevId: 630067005
2024-05-02 08:12:05 -07:00
Adam Paszke
9b0319512a [Mosaic GPU] Use a custom TMA descriptor initialization method
The one bundled with the default MLIR runtime was convenient, but it is also
impractical. It allocates memory (which can deadlock due to NCCL), does a
synchronous host-to-device copy and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.

PiperOrigin-RevId: 628430358
2024-04-26 09:40:47 -07:00
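The packing step described above ("we pack them all into a single buffer so that a single asynchronous copy is sufficient") can be illustrated in plain C++. This is a toy sketch, not the Mosaic GPU runtime code: `Descriptor` stands in for the opaque 128-byte CUDA TMA descriptor, and the single `cudaMemcpyAsync` of the packed buffer is only indicated in a comment.

```cpp
#include <cstring>
#include <vector>

// Stand-in for the opaque, fixed-size TMA descriptor (an assumption for this
// illustration; real TMA descriptors are 128-byte opaque blobs).
struct Descriptor { char bytes[128]; };

// Serialize all descriptors into one contiguous buffer, so that a single
// asynchronous host-to-device copy replaces one copy per descriptor.
std::vector<char> PackDescriptors(const std::vector<Descriptor>& ds) {
  std::vector<char> buf(ds.size() * sizeof(Descriptor));
  for (std::size_t i = 0; i < ds.size(); ++i) {
    std::memcpy(buf.data() + i * sizeof(Descriptor), &ds[i],
                sizeof(Descriptor));
  }
  return buf;  // one cudaMemcpyAsync of `buf` would move all descriptors
}
```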
Marvin Kim
90e9e47a55 [Jax/Triton] Skip benchmarking while autotuning for configs that cannot be launched.
Configs that cannot be launched should be skipped, rather than benchmarked.

PiperOrigin-RevId: 626153377
2024-04-18 14:35:51 -07:00
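The filtering pattern behind this change can be sketched as follows. This is an illustrative example only (the `Config` struct, the shared-memory criterion, and the function name are assumptions, not the jax_triton autotuner code): configs whose launch would fail, e.g. because they request more shared memory than the device provides, are dropped before any benchmarking happens.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical autotuning config: only one launch-limiting resource is
// modeled here for simplicity.
struct Config { int shared_mem_bytes; };

// Drop configs that could not be launched on this device, so the autotuner
// never times a config whose launch would fail.
std::vector<Config> LaunchableConfigs(std::vector<Config> configs,
                                      int max_shared_mem) {
  configs.erase(std::remove_if(configs.begin(), configs.end(),
                               [&](const Config& c) {
                                 return c.shared_mem_bytes > max_shared_mem;
                               }),
                configs.end());
  return configs;
}
```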
Jieying Luo
44e83d4e0a Add a few custom call registrations to gpu_kernel to keep in-sync with callers of xla_client.register_custom_call_target.
PiperOrigin-RevId: 624275186
2024-04-12 13:30:18 -07:00
Henning Becker
9809aa1929 Move CUDA specific functions from asm_compiler to cuda_asm_compiler target
This avoids:
- a forward declaration of `GpuContext`
- the `:asm_compiler_header` header only target

The moved code is unchanged - I just moved it from one
file to another and fixed up includes and dependencies.

Note that this is adding just another `#ifdef` to the redzone allocator code. I will clean this up in a subsequent change.

PiperOrigin-RevId: 623285804
2024-04-09 14:43:41 -07:00
David Dunleavy
aade591fdf Move tsl/python to xla/tsl/python
PiperOrigin-RevId: 620320903
2024-03-29 13:15:21 -07:00
Michael Hudgins
023930decf Fix some load orderings for buildifier
PiperOrigin-RevId: 619575196
2024-03-27 10:28:57 -07:00
David Dunleavy
be3e39ad3b Move tsl/cuda to xla/tsl/cuda
PiperOrigin-RevId: 610550833
2024-02-26 15:45:10 -08:00
Peter Hawkins
41f0b336e3 Add minimum version checks for cublas and cusparse.
Split the code that determines CUDA library versions out of the py_extension() module and into a cc_library(), because this fixes a linking problem in Google's build. (Long story, not worth it.)

Fixes https://github.com/google/jax/issues/8289

PiperOrigin-RevId: 583544218
2023-11-17 19:30:41 -08:00
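A minimum-version check of this kind reduces to comparing two encoded version integers. The sketch below is hedged: the single-integer encoding shown (1000 * major + 10 * minor, as used by the CUDA runtime) and the function name are assumptions for illustration; the actual jaxlib check lives in the cc_library mentioned above and covers cublas and cusparse specifically.

```cpp
#include <string>

// Hypothetical sketch of a minimum-version check: accept the library only if
// the version found at runtime is at least the version jaxlib was built
// against. Versions are single integers, CUDA-runtime style
// (e.g. 11040 == 11.4). Returns an empty string on success.
std::string CheckLibraryVersion(int found, int build, const char* name) {
  if (found < build) {
    return std::string(name) + " version is too old: found " +
           std::to_string(found) + ", but jaxlib was built against " +
           std::to_string(build);
  }
  return "";
}
```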
jax authors
88fe0da6d1 Merge pull request #18078 from ROCmSoftwarePlatform:rocm-jax-triton
PiperOrigin-RevId: 574546618
2023-10-18 11:56:01 -07:00
Jieying Luo
7478fbcfd5 [PJRT C API] Add "cuda_plugin_extension" to "gpu_only_test_deps" to support bazel test for GPU plugin.
PiperOrigin-RevId: 573251982
2023-10-13 10:12:16 -07:00
Peter Hawkins
9404518201 [CUDA] Add code to jax initialization that verifies that the CUDA libraries that are found are at least as new as the versions against which JAX was built.
This is intended to flag cases where the wrong CUDA libraries are used, either because:
* the user self-installed CUDA and that installation is too old, or
* the user used the pip package installation, but due to LD_LIBRARY_PATH overrides or similar we didn't end up using the pip-installed version.

PiperOrigin-RevId: 568910422
2023-09-27 11:28:40 -07:00
Peter Hawkins
8c70288b83 Refer to CUDA stubs directly from TSL, rather than using an alias defined in xla/stream_executor.
Remove the aliases in xla/stream_executor.

PiperOrigin-RevId: 567025507
2023-09-20 11:21:54 -07:00
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
Richard Levasseur
f891cbf64b Load Python rules from rules_python
PiperOrigin-RevId: 559789250
2023-08-24 10:22:57 -07:00
Chris Jones
f70f1f8006 Internal change.
PiperOrigin-RevId: 559053761
2023-08-22 03:05:17 -07:00
Chris Jones
4ac2bdc2b1 [jax_triton] Add user-specified name field to serialized format.
PiperOrigin-RevId: 557415723
2023-08-16 02:53:51 -07:00
Chris Jones
9935445d57 [jax_triton] Simplify auto-tuning code.
PiperOrigin-RevId: 545733541
2023-07-05 11:18:18 -07:00
Chris Jones
31b862dd56 [jax_triton] Split C++ only parts of Triton custom callback from Python parts.
Register the callback with the default call target name from C++, enabling Triton calls with the default name to work in C++-only contexts (e.g. serving).

PiperOrigin-RevId: 545211452
2023-07-03 06:52:32 -07:00
Chris Jones
d4e2464340 [jax_triton] Expose Triton custom call callback in header file.
This allows users to register the callback from C++ when not using the default call target name.

PiperOrigin-RevId: 544029098
2023-06-28 05:32:02 -07:00
Chris Jones
b3527f3975 Zlib compress kernel proto.
PiperOrigin-RevId: 542529065
2023-06-22 05:22:53 -07:00
Chris Jones
f238667492 Make JAX-Triton calls serializable.
PiperOrigin-RevId: 542524794
2023-06-22 04:57:14 -07:00
Chris Jones
64e73270ff Use EncapsulateFunction utility.
PiperOrigin-RevId: 542299099
2023-06-21 10:37:52 -07:00
Sharad Vikram
1279418ce5 Link in CUDA runtime for triton in jaxlib
PiperOrigin-RevId: 535708416
2023-05-26 14:02:16 -07:00
Chris Jones
ea37043577 Switch to STATUS_RETURNING callback API.
PiperOrigin-RevId: 535568707
2023-05-26 03:15:44 -07:00
Chris Jones
2155b9181f Switch to using JAX status macros in jax-triton kernel call lib.
PiperOrigin-RevId: 535300412
2023-05-25 10:26:06 -07:00
Chris Jones
6b13d4eb86 Add branch prediction to JAX status macros.
PiperOrigin-RevId: 535233546
2023-05-25 06:23:23 -07:00
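"Branch prediction" here means annotating the error path of the status macros as unlikely, so the compiler lays out the success path as the fall-through. The sketch below is illustrative only: the `Status` type and macro names are stand-ins, not the actual jaxlib/jax-triton definitions, though the `__builtin_expect` mechanism is the standard one on GCC/Clang.

```cpp
#include <string>

// Minimal stand-in for a status type (illustrative, not the jaxlib one).
struct Status {
  bool ok;
  std::string message;
};

// On GCC/Clang, hint that the condition is usually false so the compiler
// favors the success path; elsewhere, fall back to the plain condition.
#if defined(__GNUC__) || defined(__clang__)
#define JT_PREDICT_FALSE(x) (__builtin_expect(!!(x), 0))
#else
#define JT_PREDICT_FALSE(x) (x)
#endif

// Early-return macro with the error branch marked unlikely.
#define JT_RETURN_IF_ERROR(expr)             \
  do {                                       \
    Status _s = (expr);                      \
    if (JT_PREDICT_FALSE(!_s.ok)) return _s; \
  } while (0)

Status Step(bool fail) {
  return fail ? Status{false, "step failed"} : Status{true, ""};
}

Status Pipeline(bool fail_second) {
  JT_RETURN_IF_ERROR(Step(false));
  JT_RETURN_IF_ERROR(Step(fail_second));
  return Status{true, ""};
}
```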
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information, is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which raises the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
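The effect of the stripping pass can be mimicked textually. This is a toy illustration, not the real MLIR pass: it removes simple, non-nested `loc(...)` annotations with a regex so that two modules differing only in source locations yield the same cache key, whereas the actual change walks the StableHLO module and replaces Locations properly.

```cpp
#include <regex>
#include <string>

// Toy stand-in for the location-stripping pass: delete non-nested
// `loc(...)` annotations from MLIR-like text, so the remaining text can be
// used as a location-insensitive cache key.
std::string StripLocations(const std::string& mlir_text) {
  static const std::regex loc_re(R"( ?loc\([^)]*\))");
  return std::regex_replace(mlir_text, loc_re, "");
}
```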
Peter Hawkins
b62f114524 Add support for using pip-installed CUDA wheels.
Add a currently undocumented jax[cuda11_pip] and jax[cuda12_pip] that depend on the pip CUDA wheels.
Add a currently undocumented jax[cuda11_local] and jax[cuda12_local] that avoid the CUDA wheel dependency.
2023-03-26 12:35:00 +00:00
Peter Hawkins
ab45383038 Fix build breakage from OpenXLA switch.
PiperOrigin-RevId: 516325478
2023-03-13 14:37:35 -07:00
jax authors
42ef649e65 Merge pull request #14475 from hawkinsp:openxla
PiperOrigin-RevId: 516316330
2023-03-13 14:04:41 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Yash Katariya
3e5a5053f4 Run GPU presubmits via bazel test on the RBE cluster. This speeds up the build + testing significantly (up to 10x).
But run the continuous builds by building on RBE and testing locally, so as to run the multi-accelerator tests too. Locally we have 4 GPUs available.

Also make GPU presubmits blocking for JAX (re-enabled it).

PiperOrigin-RevId: 491647775
2022-11-29 08:45:58 -08:00
Qiao Zhang
c54bc90bf4 Fix cudnn_header OSS BUILD dep.
PiperOrigin-RevId: 491465703
2022-11-28 15:58:55 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00