88 Commits

Author SHA1 Message Date
jax authors
be4d52b814 Merge pull request #22667 from ROCm:rocm-jax-triton-add-get_arch_detail
PiperOrigin-RevId: 662007143
2024-08-12 02:30:49 -07:00
Rahul Batra
4b7c198a1c [ROCm]: Add get_arch_details for triton kernel call 2024-08-12 06:16:27 +00:00
Dan Foreman-Mackey
11d9c2de2c Update GPU implementation of lu_pivots_to_permutation to infer the permutation size directly from the input dimensions, instead of using an input parameter.
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, an input `permutation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyway.

In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.

PiperOrigin-RevId: 660831000
2024-08-08 07:35:47 -07:00
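A rough sketch of the approach described in the commit above, assuming the external XLA FFI C++ API in `xla/ffi/api/ffi.h`; the handler name, the exact `Buffer`/`Error`/`dimensions()` spellings, and the checks are illustrative approximations rather than the jaxlib source:

```cpp
#include <cstddef>
#include <cstdint>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Sketch: recover permutation size and batch size from the buffer shapes
// instead of taking a `permutation_size` attribute, and check dimensions to
// mirror the abstract eval (names and error factories are assumptions).
static ffi::Error LuPivotsToPermutationImpl(
    ffi::Buffer<ffi::DataType::S32> pivots,
    ffi::Result<ffi::Buffer<ffi::DataType::S32>> permutation) {
  auto pivot_dims = pivots.dimensions();
  auto perm_dims = permutation->dimensions();
  if (pivot_dims.size() < 1 || perm_dims.size() < 1) {
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "pivots and permutation must be at least 1-D");
  }
  // Trailing dimension of the output is the permutation size; everything in
  // front of it is batch.
  std::int64_t permutation_size = perm_dims[perm_dims.size() - 1];
  std::int64_t batch_size = 1;
  for (std::size_t i = 0; i + 1 < perm_dims.size(); ++i) {
    batch_size *= perm_dims[i];
  }
  if (pivot_dims[pivot_dims.size() - 1] > permutation_size) {
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "number of pivots must not exceed the permutation size");
  }
  // ... launch the device kernel over (batch_size, permutation_size) ...
  (void)batch_size;
  return ffi::Error::Success();
}
```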
Ruturaj4
a2d79936df [ROCM] Fix BUILD.bazel library source paths 2024-08-07 09:18:20 -05:00
Dan Foreman-Mackey
8df0c3a9cc Port Getrf GPU kernel from custom call to FFI.
PiperOrigin-RevId: 658550170
2024-08-01 15:02:25 -07:00
Dan Foreman-Mackey
f20efc630f Move jaxlib GPU handlers to separate build target.
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into a new target.

PiperOrigin-RevId: 658497960
2024-08-01 12:30:04 -07:00
George Necula
65450d165e Remove forward compatibility mode for old PRNG custom call on GPU
The backend support for the new custom call was added on June 28th.
Also add a backwards compatibility test for the new custom call.

PiperOrigin-RevId: 658011228
2024-07-31 08:10:17 -07:00
Dan Foreman-Mackey
33a9db3943 Move FFI helper macros from jaxlib/cpu/lapack_kernels.cc to jaxlib/ffi_helpers.h.
Some of the macros that were used in jaxlib's FFI calls to LAPACK turned out to
be useful for other FFI calls. This change consolidates these macros in the
ffi_helpers header.

PiperOrigin-RevId: 651166306
2024-07-10 15:09:45 -07:00
Dan Foreman-Mackey
4f394828e1 Fix C++ registration of FFI handlers and consolidate gpu/linalg kernel implementation.
This change does a few things (arguably too many):

1. The key change here is that it fixes the handler registration in `jaxlib/gpu/gpu_kernels.cc` for the two handlers that use the XLA FFI API. A previous attempt at this change caused downstream issues because of duplicate registrations, but we were able to fix that directly in XLA.

2. A second related change is to declare and define the XLA FFI handlers consistently using the `XLA_FFI_DECLARE_HANDLER_SYMBOL` and `XLA_FFI_DEFINE_HANDLER_SYMBOL` macros. We need to use these macros instead of the `XLA_FFI_DEFINE_HANDLER` version which produces a lambda, so that when XLA checks the address of the handler during registration it is consistent. Without this change, the downstream tests would continue to fail.

3. The final change is to consolidate the `cholesky_update_kernel` and `lu_pivot_kernels` implementations into a common `linalg_kernels` target. This makes the implementation of the `_linalg` nanobind module consistent with the other targets within `jaxlib/gpu`, and (I think!) makes the details easier to follow. This last change is less urgent, but it is what I set out to do, so I'm proposing all three together; I can split this in two if that would be preferred.

PiperOrigin-RevId: 651107659
2024-07-10 12:09:12 -07:00
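A minimal sketch of the macro pattern this commit describes: declare a named handler symbol, define it from an implementation function, and register that same symbol. The handler name, binding, and the API accessor passed to registration are illustrative assumptions, not the actual jaxlib declarations.

```cpp
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// In the header: declare a named handler symbol so every translation unit
// refers to the same address.
XLA_FFI_DECLARE_HANDLER_SYMBOL(kCholeskyUpdate);

// In the .cc file: define the symbol from an implementation function. Using
// XLA_FFI_DEFINE_HANDLER_SYMBOL (rather than XLA_FFI_DEFINE_HANDLER, which
// produces a lambda) keeps the handler's address stable, which is what XLA
// compares when the same handler is registered more than once.
static ffi::Error CholeskyUpdateImpl(ffi::AnyBuffer matrix,
                                     ffi::Result<ffi::AnyBuffer> out) {
  // ... kernel launch elided ...
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(kCholeskyUpdate, CholeskyUpdateImpl,
                              ffi::Ffi::Bind()
                                  .Arg<ffi::AnyBuffer>()
                                  .Ret<ffi::AnyBuffer>());

// In gpu_kernels.cc: register the same symbol for the CUDA platform (the
// exact API-getter used as the first argument is assumed here).
XLA_FFI_REGISTER_HANDLER(GetXlaFfiApi(), "cholesky_update", "CUDA",
                         kCholeskyUpdate);
```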
Eugene Zhulenev
d49a0c5a63 [jax] Remove dead code from JAX custom calls defined as FFI handlers
PiperOrigin-RevId: 651025363
2024-07-10 08:11:12 -07:00
Eugene Zhulenev
1e03917c43 [xla:ffi] Use lazy decoding for Buffer<dtype,rank>
name                old cpu/op   new cpu/op   delta
BM_AnyBufferArgX1   11.0ns ± 3%  11.2ns ±10%   +1.76%  (p=0.000 n=67+69)
BM_AnyBufferArgX4   12.4ns ± 3%  12.4ns ± 4%   -0.31%  (p=0.006 n=69+69)
BM_BufferArgX1      12.5ns ± 1%  11.1ns ± 4%  -11.20%  (p=0.000 n=62+76)
BM_BufferArgX4      19.1ns ± 1%  14.4ns ± 4%  -24.84%  (p=0.000 n=64+73)
BM_BufferArgX8      36.0ns ± 5%  20.3ns ± 4%  -43.59%  (p=0.000 n=79+75)
BM_TupleOfI32Attrs  66.4ns ± 1%  66.4ns ± 2%   -0.03%  (p=0.000 n=66+72)

PiperOrigin-RevId: 650691450
2024-07-09 11:07:25 -07:00
George Necula
2f808e9da9 Fix error in custom call registration for some FFI functions
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```

It seems that with the FFI registration mechanism based on `XLA_FFI_REGISTER_HANDLER` it is
no longer possible to register a call target twice.

The fix here is to roll back the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.

PiperOrigin-RevId: 647993991
2024-06-29 12:18:34 -07:00
Dan Foreman-Mackey
9b33df6438 Update C++ registration of cu_lu_pivots_to_permutation to use XLA_FFI_REGISTER_HANDLER
PiperOrigin-RevId: 647734115
2024-06-28 10:53:33 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
Dan Foreman-Mackey
9ae1c56c44 Update lu_pivots_to_permutation to use FFI dimensions on GPU.
The XLA FFI interface provides metadata about buffer dimensions, so quantities
like batch dimensions can be evaluated on the backend, instead of passed as
attributes. This change has the added benefit of allowing this FFI call to
support "vectorized" vmap and dynamic shapes.

PiperOrigin-RevId: 647343656
2024-06-27 09:27:15 -07:00
Ruturaj4
79fccf6c82 add cholesky changes in bazel 2024-05-18 00:37:09 +00:00
jax authors
e8b06ccf56 Cholesky rank-1 update kernel for JAX.
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
Sergei Lebedev
51fc4f85ad Ported LuPivotsToPermutation to the typed XLA FFI
The typed FFI

* allows passing custom call attributes directly to backend_config= instead
  of serializing them into a C++ struct, and
* handles validation and deserialization of custom call operands.

PiperOrigin-RevId: 630067005
2024-05-02 08:12:05 -07:00
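A short sketch of what a typed-FFI binding with an attribute looks like, assuming the `xla/ffi/api/ffi.h` builder API; the handler name and the `.Attr<std::int32_t>("permutation_size")` spelling are illustrative. The newer entries above remove this attribute again in favor of the buffer dimensions.

```cpp
#include <cstdint>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Sketch: the attribute supplied via backend_config= arrives as a decoded,
// typed handler argument; operand validation and decoding are handled by the
// binding itself rather than by a hand-rolled opaque descriptor struct.
static ffi::Error LuPivotsToPermutationImpl(
    ffi::Buffer<ffi::DataType::S32> pivots,
    ffi::Result<ffi::Buffer<ffi::DataType::S32>> permutation,
    std::int32_t permutation_size) {
  // ... kernel launch elided ...
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
    kLuPivotsToPermutation, LuPivotsToPermutationImpl,
    ffi::Ffi::Bind()
        .Arg<ffi::Buffer<ffi::DataType::S32>>()   // pivots
        .Ret<ffi::Buffer<ffi::DataType::S32>>()   // permutation
        .Attr<std::int32_t>("permutation_size"));
```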
Marvin Kim
90e9e47a55 [Jax/Triton] Skip benchmarking while autotuning for configs that cannot be launched.
Configs that cannot be launched should not be benchmarked.

PiperOrigin-RevId: 626153377
2024-04-18 14:35:51 -07:00
Jieying Luo
44e83d4e0a Add a few custom call registrations to gpu_kernel to keep in sync with callers of xla_client.register_custom_call_target.
PiperOrigin-RevId: 624275186
2024-04-12 13:30:18 -07:00
Henning Becker
9809aa1929 Move CUDA specific functions from asm_compiler to cuda_asm_compiler target
This avoids:
- a forward declaration of `GpuContext`
- the `:asm_compiler_header` header-only target

The moved code is unchanged - I just move it from one
file to another and fix up includes and dependencies.

Note that this is adding just another `#ifdef` to the redzone allocator code. I will clean this up in a subsequent change.

PiperOrigin-RevId: 623285804
2024-04-09 14:43:41 -07:00
Marvin Kim
722708052c [JAX] Fix typo in comment.
PiperOrigin-RevId: 621827985
2024-04-04 05:35:28 -07:00
David Dunleavy
aade591fdf Move tsl/python to xla/tsl/python
PiperOrigin-RevId: 620320903
2024-03-29 13:15:21 -07:00
Rahul Batra
8575055571 [ROCm]: Add missing hipStreamWaitEvent API call 2024-03-20 16:58:21 +00:00
Peter Hawkins
c2bbf9c577 Remove some code to support older CUDA and CUSPARSE versions.
The minimum CUDA version supported by JAX is CUDA 11.8, which ships with CUSPARSE 11.7.5.

PiperOrigin-RevId: 616892230
2024-03-18 11:25:03 -07:00
Andrey Portnoy
dcb58bb540 Include <cstdint> in files where it is used 2024-03-06 11:58:15 -05:00
jax authors
7514d5c7aa [triton] Add clustering support and test
PiperOrigin-RevId: 612417957
2024-03-04 05:51:10 -08:00
Eugene Zhulenev
1ae2022918 [jax-triton] Do not capture jax-triton calls that require autotuning
PiperOrigin-RevId: 611823473
2024-03-01 10:28:47 -08:00
Eugene Zhulenev
3a69b80774 [jax-triton] Synchronize autotuning stream with a main one
PiperOrigin-RevId: 609792049
2024-02-23 11:42:30 -08:00
Chris Jones
fcc8b54789 [jax_triton] Use ReaderLock on fast path to reduce lock contention in multi-GPU settings.
PiperOrigin-RevId: 606648981
2024-02-13 09:31:50 -08:00
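An illustrative sketch of the reader-lock fast path this commit describes, not the jax_triton code; the cache type and member names are assumptions. Concurrent lookups from threads driving different GPUs share the lock, and only a cache miss takes the exclusive writer lock.

```cpp
#include <string>
#include <unordered_map>

#include "absl/synchronization/mutex.h"

struct CompiledKernel {};  // placeholder for the cached value

class KernelCache {
 public:
  CompiledKernel& GetOrCreate(const std::string& key) {
    {
      absl::ReaderMutexLock lock(&mu_);  // shared lock: the fast path
      auto it = cache_.find(key);
      if (it != cache_.end()) return it->second;
    }
    absl::MutexLock lock(&mu_);  // exclusive lock: only on a miss
    return cache_[key];          // compile-and-insert elided
  }

 private:
  absl::Mutex mu_;
  std::unordered_map<std::string, CompiledKernel> cache_;
};
```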
Anlun Xu
d62071066e [jax:triton] Add a workaround for calling cuStreamGetCtx inside graph capture
A bug in CUDA prevents us from calling gpuStreamGetCtx inside graph capture. We use cuCtxGetCurrent as a workaround for now.

PiperOrigin-RevId: 605417225
2024-02-08 13:49:45 -08:00
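One possible shape of the workaround, sketched with CUDA driver-API calls that are known to exist (cuStreamIsCapturing, cuStreamGetCtx, cuCtxGetCurrent); the helper name and the conditional structure are assumptions, not the jax_triton code:

```cpp
#include <cuda.h>

// Sketch: cuStreamGetCtx cannot be called while the stream is being captured
// into a CUDA graph, so query the capture status and fall back to the
// thread's current context in that case.
CUresult GetStreamContext(CUstream stream, CUcontext* ctx) {
  CUstreamCaptureStatus status = CU_STREAM_CAPTURE_STATUS_NONE;
  CUresult err = cuStreamIsCapturing(stream, &status);
  if (err != CUDA_SUCCESS) return err;
  if (status == CU_STREAM_CAPTURE_STATUS_ACTIVE) {
    return cuCtxGetCurrent(ctx);  // workaround while capturing
  }
  return cuStreamGetCtx(stream, ctx);
}
```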
Rahul Batra
f01c27f65a [ROCm]: Add ROCm command buffer support for triton kernel 2024-02-05 19:34:12 +00:00
Anlun Xu
16636f9c97 [jax_triton] Only use side stream to do autotuning when doing graph capture
When graph capture is not enabled, autotuning and the kernel launch should be on the same stream to avoid a race condition.

PiperOrigin-RevId: 603728867
2024-02-02 10:48:26 -08:00
Anlun Xu
5e009f9ff1 Make triton kernels compatible with command buffers
Autotuning is not compatible with graph capture because it requires synchronization.

We use cuThreadExchangeStreamCaptureMode to execute a sequence of commands that are not recorded to graphs, similar to what NCCL does here: b6d7438d31/src/include/alloc.h (L171)

PiperOrigin-RevId: 602436960
2024-01-29 11:00:29 -08:00
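A hedged sketch of the cuThreadExchangeStreamCaptureMode pattern mentioned above: temporarily switch the calling thread to relaxed capture mode so the synchronizing autotuning work is not recorded into an ongoing graph capture, then restore the previous mode. The RAII guard and its name are illustrative; the driver call and mode enum are real.

```cpp
#include <cuda.h>

class ScopedRelaxedCapture {
 public:
  ScopedRelaxedCapture() : mode_(CU_STREAM_CAPTURE_MODE_RELAXED) {
    // Exchanges the thread's capture mode; the previous mode is written back
    // into mode_. Error handling is elided in this sketch.
    cuThreadExchangeStreamCaptureMode(&mode_);
  }
  ~ScopedRelaxedCapture() {
    // Restore whatever mode was active before.
    cuThreadExchangeStreamCaptureMode(&mode_);
  }

 private:
  CUstreamCaptureMode mode_;
};

// Usage: work issued inside the scope is exempt from graph capture.
// {
//   ScopedRelaxedCapture guard;
//   ... launch candidate kernels and synchronize ...
// }
```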
Anlun Xu
88f5eaca3e [xla:gpu] Make cu_threefry2x32 custom call compatible with command buffers
PiperOrigin-RevId: 600937786
2024-01-23 16:14:21 -08:00
jax authors
5761e393fa The Triton autotuner ignores configs that use too much shmem
The autotuner runs a series of benchmarks to determine the best configuration
for a Triton kernel. However, if it encounters a config that does not fit in
shared memory, it throws an error and stops. In this case it should just
continue.

PiperOrigin-RevId: 600730507
2024-01-23 03:08:57 -08:00
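An illustrative sketch (not the jax_triton autotuner) of skipping configs whose shared-memory requirement exceeds the device limit instead of aborting the whole autotuning pass; the `Config` struct and its `shared_mem_bytes` field are hypothetical, while the driver-API query is real.

```cpp
#include <cuda.h>

#include <vector>

struct Config {
  int shared_mem_bytes;  // hypothetical field: shmem this config needs
};

std::vector<Config> LaunchableConfigs(CUdevice device,
                                      const std::vector<Config>& configs) {
  int max_shared = 0;
  cuDeviceGetAttribute(&max_shared,
                       CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
                       device);
  std::vector<Config> ok;
  for (const Config& c : configs) {
    if (c.shared_mem_bytes <= max_shared) ok.push_back(c);  // keep
    // else: skip this config rather than failing the autotuner
  }
  return ok;
}
```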
Rahul Batra
f997609e76 [ROCm]: Updates hip headers path for ROCm 6.0 2024-01-22 16:08:37 +00:00
jax authors
ab3c1b5146 [triton] Pass cluster_dims to TritonKernel and use cuLaunchKernel if size <= 1
PiperOrigin-RevId: 599809560
2024-01-19 05:55:41 -08:00
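A hedged sketch combining this entry with the cuLaunchKernelEx entry below it (chronologically earlier): fall back to plain cuLaunchKernel when the cluster is trivial (size <= 1), otherwise attach the cluster dimensions as a launch attribute and use cuLaunchKernelEx. Parameter names and the wrapper itself are illustrative; the driver structures and calls are standard CUDA 12 API.

```cpp
#include <cuda.h>

CUresult Launch(CUfunction fn, unsigned grid_x, unsigned block_x,
                unsigned shared_bytes, CUstream stream, void** params,
                unsigned cluster_x, unsigned cluster_y, unsigned cluster_z) {
  if (cluster_x * cluster_y * cluster_z <= 1) {
    // No clustering requested: the classic launch path suffices.
    return cuLaunchKernel(fn, grid_x, 1, 1, block_x, 1, 1, shared_bytes,
                          stream, params, /*extra=*/nullptr);
  }
  // Describe the cluster shape as a launch attribute.
  CUlaunchAttribute attr{};
  attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  attr.value.clusterDim.x = cluster_x;
  attr.value.clusterDim.y = cluster_y;
  attr.value.clusterDim.z = cluster_z;

  CUlaunchConfig config{};
  config.gridDimX = grid_x;
  config.gridDimY = 1;
  config.gridDimZ = 1;
  config.blockDimX = block_x;
  config.blockDimY = 1;
  config.blockDimZ = 1;
  config.sharedMemBytes = shared_bytes;
  config.hStream = stream;
  config.attrs = &attr;
  config.numAttrs = 1;
  return cuLaunchKernelEx(&config, fn, params, /*extra=*/nullptr);
}
```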
jax authors
59ea9f3fde [triton] Use cuLaunchKernelEx instead of cuLaunchKernel
PiperOrigin-RevId: 597555083
2024-01-11 07:52:07 -08:00
Peter Hawkins
95e2d3fc2b [JAX:GPU] Generalize gesvdj kernel to iterate over the unbatched Jacobi kernel in cases where we cannot use the batched kernel.
If gesvdj() is preferable to gesvd() in the absence of a batch dimension, then even when there is a batch dimension we should prefer a loop of gesvdj() over a loop of gesvd().

PiperOrigin-RevId: 582279549
2023-11-14 04:52:15 -08:00
jax authors
88fe0da6d1 Merge pull request #18078 from ROCmSoftwarePlatform:rocm-jax-triton
PiperOrigin-RevId: 574546618
2023-10-18 11:56:01 -07:00
Rahul Batra
b4b97cd8e8 [ROCm]: Add jax-triton support for ROCm 2023-10-18 07:09:20 +00:00
Peter Hawkins
07fa9dc3db Fix cupti-related build failure under CUDA 11.
cuptiGetErrorMessage was added in CUDA 12.2.

PiperOrigin-RevId: 568962562
2023-09-27 14:33:30 -07:00
Peter Hawkins
9404518201 [CUDA] Add code to jax initialization that verifies that the CUDA libraries that are found are at least as new as the versions against which JAX was built.
This is intended to flag cases where the wrong CUDA libraries are used, either because:
* the user self-installed CUDA and that installation is too old, or
* the user used the pip package installation, but due to LD_LIBRARY_PATH overrides or similar we didn't end up using the pip-installed version.

PiperOrigin-RevId: 568910422
2023-09-27 11:28:40 -07:00
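A minimal sketch of this kind of version check, assuming a runtime-only comparison; the actual jaxlib check covers the individual CUDA libraries and is surfaced at import time, so the helper name and error handling here are illustrative. The compile-time `CUDART_VERSION` macro and `cudaRuntimeGetVersion` are real.

```cpp
#include <cuda_runtime_api.h>

#include <stdexcept>
#include <string>

// Sketch: compare the CUDA runtime found at load time against the version the
// binary was compiled against, and fail loudly if it is older.
void CheckCudaRuntimeVersion() {
  int found = 0;
  if (cudaRuntimeGetVersion(&found) != cudaSuccess) {
    throw std::runtime_error("Unable to query the CUDA runtime version");
  }
  // CUDART_VERSION is the version of the headers used at build time,
  // e.g. 11080 for CUDA 11.8.
  if (found < CUDART_VERSION) {
    throw std::runtime_error(
        "CUDA runtime is older than the version this build expects: found " +
        std::to_string(found) + ", built with " +
        std::to_string(CUDART_VERSION));
  }
}
```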
Andrey Portnoy
fc1c31d958 Run LSTM test using FP32 math (as opposed to TF32)
1. Add (limited) precision specifier handling to LSTM

Enables differentiating between TF32 and FP32 math. TF32 math had insufficient
precision to reliably pass LSTM correctness tests on A100 and H100.

2. Run the test using FP32

TF32 precision is not sufficient for the test to pass reliably on Ampere+ GPUs
such as A100 and H100.
2023-09-19 14:45:14 -04:00
Chris Jones
9f7a19ad50 [jax_triton] Improve error message when shared memory exceeds that available on the GPU.
PiperOrigin-RevId: 561792406
2023-08-31 16:32:18 -07:00
Peter Hawkins
46ac9e2170 Use the default CSR matmul algorithm.
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cuSPARSE SpMM CSR matmuls. In the past, cuSPARSE silently ignored this algorithm choice and used a different algorithm when ALG3 was not supported, but cuSPARSE 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior in all cases anyway, always use the default algorithm.

PiperOrigin-RevId: 560867049
2023-08-28 17:49:01 -07:00
Peter Hawkins
34010a9e4a Align dummy pointers passed to cusparse to 16 bytes
Fixes alignment errors from cuSPARSE 12.2.

PiperOrigin-RevId: 560793586
2023-08-28 12:56:27 -07:00
Peter Hawkins
ac8ea86103 Fix accidental signature change to get_serialized_metadata() from nanobind PR.
pybind11 accepts either Python strings or bytes as a std::string argument, whereas nanobind accepts only strings. Change the argument to nb::bytes instead.

PiperOrigin-RevId: 560086072
2023-08-25 07:31:31 -07:00
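A small sketch of the signature change described above: with nanobind, a Python `bytes` argument is bound explicitly as `nb::bytes` instead of relying on pybind11-style implicit conversion to `std::string`. The module name `_example` and the function body are illustrative.

```cpp
#include <string>

#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>  // std::string return-value caster

namespace nb = nanobind;

// Sketch: accept Python bytes directly; nb::bytes exposes the raw pointer and
// length of the underlying bytes object.
std::string GetSerializedMetadata(nb::bytes serialized) {
  return std::string(serialized.c_str(), serialized.size());
}

NB_MODULE(_example, m) {
  m.def("get_serialized_metadata", &GetSerializedMetadata);
}
```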
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is that nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python-version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00