98 Commits

Author SHA1 Message Date
Sebastian Bodenstein
e3b8177af3 Internal change.
PiperOrigin-RevId: 671583042
2024-09-05 18:42:22 -07:00
Peter Hawkins
45b871950e Fix a number of minor problems in the ROCM build.
Change in preparation for adding more presubmits for AMD ROCM.

PiperOrigin-RevId: 667766343
2024-08-26 17:04:01 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Dan Foreman-Mackey
b56ed8eedd Port GPU kernel for Householder transformation to FFI.
PiperOrigin-RevId: 666305682
2024-08-22 05:23:09 -07:00
Krishna Haridasan
3713b966c2 Fix a potential segfault in triton kernel call caching
If an error occurs later during initialization, a null pointer can be inserted into the cache
and never updated with a valid kernel call. This change updates the cache to store either
an error or a valid kernel call.

PiperOrigin-RevId: 666161091
2024-08-21 20:45:35 -07:00
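
A minimal sketch of the caching fix described in the commit above, assuming an absl-based cache keyed by an opaque string; `KernelCall`, `Initialize`, and `GetKernelCall` are hypothetical stand-ins, not the actual jax-triton code:

```
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"

struct KernelCall {};  // hypothetical stand-in for the real kernel-call object

// Hypothetical initializer that may fail partway through; stubbed out here.
absl::StatusOr<std::unique_ptr<KernelCall>> Initialize(const std::string& opaque) {
  if (opaque.empty()) return absl::InvalidArgumentError("empty kernel descriptor");
  return std::make_unique<KernelCall>();
}

absl::Mutex mu;
// The cached value is a StatusOr, so a failed initialization is remembered as
// an error instead of leaving a null pointer behind in the map.
absl::flat_hash_map<std::string, absl::StatusOr<std::unique_ptr<KernelCall>>> cache;

absl::StatusOr<KernelCall*> GetKernelCall(const std::string& opaque) {
  absl::MutexLock lock(&mu);
  auto it = cache.find(opaque);
  if (it == cache.end()) {
    it = cache.try_emplace(opaque, Initialize(opaque)).first;
  }
  if (!it->second.ok()) return it->second.status();
  return it->second->get();
}
```
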
Dan Foreman-Mackey
bd90968a25 Port the GPU Cholesky update custom call to the FFI.
PiperOrigin-RevId: 665319689
2024-08-20 05:46:03 -07:00
Dan Foreman-Mackey
71a93d0c87 Port QR factorization GPU kernel to FFI.
The biggest change here is that we now ignore the `info` parameter returned by `geqrf`. Previously, the batched implementation would return an error, and the non-batched version would set the relevant matrix entries to NaN, whenever `info != 0`. But since `info` is only used for shape checking (see the LAPACK, cuBLAS and cuSolver docs), I argue that we will never see `info != 0`, because the kernel already includes all of the shape checks.

PiperOrigin-RevId: 665307128
2024-08-20 05:07:04 -07:00
Dan Foreman-Mackey
30d54ec6ff Refactor FFI shape inference functions to include dimension check.
Previously, extracting the batch size always took two steps: (1) check that the buffer has enough dimensions, then (2) get the shape. In a few cases, the first check was missing. These steps are now combined into one function that returns a StatusOr.

As part of this, I needed to fix our implementation of the `ASSIGN_OR_RETURN` macro to properly handle parentheses.

PiperOrigin-RevId: 664803225
2024-08-19 07:41:28 -07:00
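
A hedged sketch of the combined check-and-extract pattern described in the commit above; the helper and macro names are invented, and the parenthesis handling mentioned above is omitted from the simplified macro:

```
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"

// One step instead of two: verify that the buffer has at least `trailing`
// dimensions, then return the product of the leading (batch) dimensions.
absl::StatusOr<std::int64_t> BatchSize(absl::Span<const std::int64_t> dims,
                                       int trailing) {
  if (dims.size() < static_cast<size_t>(trailing)) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "expected a buffer with at least %d dimensions, got %d", trailing,
        dims.size()));
  }
  return std::accumulate(dims.begin(), dims.end() - trailing, std::int64_t{1},
                         std::multiplies<std::int64_t>());
}

// Simplified assign-or-return macro; the real one also has to cope with a
// parenthesized left-hand side so that types containing commas work.
#define ASSIGN_OR_RETURN(lhs, rexpr)             \
  auto lhs##_or = (rexpr);                       \
  if (!lhs##_or.ok()) return lhs##_or.status();  \
  auto lhs = *std::move(lhs##_or);

absl::Status LaunchExample(absl::Span<const std::int64_t> dims) {
  ASSIGN_OR_RETURN(batch, BatchSize(dims, /*trailing=*/2));
  (void)batch;  // ... dispatch the kernel over `batch` matrices ...
  return absl::OkStatus();
}
```
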
Dan Foreman-Mackey
b6306e3953 Remove synchronization from GPU LU decomposition kernel by adding an async batch pointers builder.
In the batched LU decomposition in cuBLAS, the output buffer is required to be an array of pointers to the appropriate batch matrices. Previously this array was built on the host and then copied to the device, requiring a synchronization, but it is straightforward to instead implement a tiny CUDA kernel to do this work. This definitely isn't a bottleneck or a high-priority change, but it seemed like a reasonable time to fix a longstanding TODO.

PiperOrigin-RevId: 663686539
2024-08-16 04:37:09 -07:00
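
A minimal CUDA sketch of the idea in the commit above, not jaxlib's actual kernel: build the array of per-batch matrix pointers directly on the device, so no host staging buffer, copy, or synchronization is needed.

```
#include <cstdint>

#include <cuda_runtime.h>

// Each thread writes one pointer: out[i] = base + i * stride_bytes.
__global__ void MakeBatchPointers(char* base, void** out, std::int64_t batch,
                                  std::int64_t stride_bytes) {
  std::int64_t i =
      blockIdx.x * static_cast<std::int64_t>(blockDim.x) + threadIdx.x;
  if (i < batch) out[i] = base + i * stride_bytes;
}

// Launched on the same stream as the batched cuBLAS call that consumes `out`,
// so the pointers are ready by the time the library kernel runs.
void LaunchMakeBatchPointers(cudaStream_t stream, void* base, void** out,
                             std::int64_t batch, std::int64_t stride_bytes) {
  constexpr int kBlockDim = 256;
  int grid_dim = static_cast<int>((batch + kBlockDim - 1) / kBlockDim);
  MakeBatchPointers<<<grid_dim, kBlockDim, 0, stream>>>(
      static_cast<char*>(base), out, batch, stride_bytes);
}
```
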
Dan Foreman-Mackey
ad1bd38790 Move logic about when to dispatch to batched LU decomposition algorithm on GPU into the kernel.
This simplifies the lowering logic, and means that we don't get hit with a performance penalty when exporting with shape polymorphism.

PiperOrigin-RevId: 662945116
2024-08-14 09:20:40 -07:00
jax authors
be4d52b814 Merge pull request #22667 from ROCm:rocm-jax-triton-add-get_arch_detail
PiperOrigin-RevId: 662007143
2024-08-12 02:30:49 -07:00
Rahul Batra
4b7c198a1c [ROCm]: Add get_arch_details for triton kernel call 2024-08-12 06:16:27 +00:00
Dan Foreman-Mackey
11d9c2de2c Update GPU implementation of lu_pivots_to_permutation to infer the permutation size directly from the input dimensions, instead of using an input parameter.
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (it accepts, but ignores, an input `permutation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyway.

In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.

PiperOrigin-RevId: 660831000
2024-08-08 07:35:47 -07:00
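
A hedged sketch of the dimension-derived sizes described in the commit above; the function and parameter names are illustrative rather than jaxlib's:

```
#include <algorithm>
#include <cstdint>

#include "absl/status/status.h"
#include "absl/types/span.h"

// Everything is read off the buffer shapes; nothing arrives as an attribute.
// The checks mirror the ones in the abstract eval: matching batch dimensions,
// and pivots no longer than the permutation they index into.
absl::Status CheckAndGetSizes(absl::Span<const std::int64_t> pivots_dims,
                              absl::Span<const std::int64_t> permutation_dims,
                              std::int64_t* pivot_size,
                              std::int64_t* permutation_size) {
  if (pivots_dims.empty() || permutation_dims.empty()) {
    return absl::InvalidArgumentError(
        "pivots and permutation must have rank >= 1");
  }
  if (pivots_dims.size() != permutation_dims.size() ||
      !std::equal(pivots_dims.begin(), pivots_dims.end() - 1,
                  permutation_dims.begin())) {
    return absl::InvalidArgumentError("batch dimensions must match");
  }
  *pivot_size = pivots_dims.back();
  *permutation_size = permutation_dims.back();
  if (*pivot_size > *permutation_size) {
    return absl::InvalidArgumentError(
        "pivots cannot be longer than the permutation");
  }
  return absl::OkStatus();
}
```
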
Ruturaj4
a2d79936df [ROCM] Fix BUILD.bazel library source paths 2024-08-07 09:18:20 -05:00
Dan Foreman-Mackey
8df0c3a9cc Port Getrf GPU kernel from custom call to FFI.
PiperOrigin-RevId: 658550170
2024-08-01 15:02:25 -07:00
Dan Foreman-Mackey
f20efc630f Move jaxlib GPU handlers to separate build target.
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into a new target.

PiperOrigin-RevId: 658497960
2024-08-01 12:30:04 -07:00
George Necula
65450d165e Remove forward compatibility mode for old PRNG custom call on GPU
The backend support for the new custom call was added on June 28th.
Also add a backwards compatibility test for the new custom call.

PiperOrigin-RevId: 658011228
2024-07-31 08:10:17 -07:00
Dan Foreman-Mackey
33a9db3943 Move FFI helper macros from jaxlib/cpu/lapack_kernels.cc to a jaxlib/ffi_helpers.h.
Some of the macros that were used in jaxlib's FFI calls to LAPACK turned out to
be useful for other FFI calls. This change consolidates these macros in the
ffi_helpers header.

PiperOrigin-RevId: 651166306
2024-07-10 15:09:45 -07:00
Dan Foreman-Mackey
4f394828e1 Fix C++ registration of FFI handlers and consolidate gpu/linalg kernel implementation.
This change does a few things (arguably too many):

1. The key change here is that it fixes the handler registration in `jaxlib/gpu/gpu_kernels.cc` for the two handlers that use the XLA FFI API. A previous attempt at this change caused downstream issues because of duplicate registrations, but we were able to fix that directly in XLA.

2. A second related change is to declare and define the XLA FFI handlers consistently using the `XLA_FFI_DECLARE_HANDLER_SYMBOL` and `XLA_FFI_DEFINE_HANDLER_SYMBOL` macros. We need to use these macros instead of the `XLA_FFI_DEFINE_HANDLER` version which produces a lambda, so that when XLA checks the address of the handler during registration it is consistent. Without this change, the downstream tests would continue to fail.

3. The final change is to consolidate the `cholesky_update_kernel` and `lu_pivot_kernels` implementations into a common `linalg_kernels` target. This makes the implementation of the `_linalg` nanobind module consistent with the other targets within `jaxlib/gpu`, and (I think!) makes the details easier to follow. This last change is less urgent, but it was what I set out to do so that's why I'm suggesting them all together, but I can split this in two if that would be preferred.

PiperOrigin-RevId: 651107659
2024-07-10 12:09:12 -07:00
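
A minimal sketch of the declare/define/register pattern from point 2 above, assuming recent XLA FFI headers; the handler body and names are placeholders, and the exact binding types vary across XLA versions:

```
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Placeholder implementation: a real handler would read from `x`, write to
// `y`, and report failures through the returned ffi::Error.
static ffi::Error ExampleImpl(ffi::AnyBuffer x, ffi::Result<ffi::AnyBuffer> y) {
  return ffi::Error::Success();
}

// The *_SYMBOL macro defines a named handler, `kExample`, with a single
// stable address, unlike the lambda produced by XLA_FFI_DEFINE_HANDLER, so
// XLA's duplicate-registration check compares equal addresses.
XLA_FFI_DEFINE_HANDLER_SYMBOL(kExample, ExampleImpl,
                              ffi::Ffi::Bind()
                                  .Arg<ffi::AnyBuffer>()    // x
                                  .Ret<ffi::AnyBuffer>());  // y

// Registration as shown in the XLA FFI examples; the header providing
// GetXlaFfiApi() has moved between XLA versions.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "example_handler", "CUDA",
                         kExample);
```
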
Eugene Zhulenev
d49a0c5a63 [jax] Remove dead code from JAX custom calls defined as FFI handlers
PiperOrigin-RevId: 651025363
2024-07-10 08:11:12 -07:00
Eugene Zhulenev
1e03917c43 [xla:ffi] Use lazy decoding for Buffer<dtype,rank>
name                old cpu/op   new cpu/op   delta
BM_AnyBufferArgX1   11.0ns ± 3%  11.2ns ±10%   +1.76%  (p=0.000 n=67+69)
BM_AnyBufferArgX4   12.4ns ± 3%  12.4ns ± 4%   -0.31%  (p=0.006 n=69+69)
BM_BufferArgX1      12.5ns ± 1%  11.1ns ± 4%  -11.20%  (p=0.000 n=62+76)
BM_BufferArgX4      19.1ns ± 1%  14.4ns ± 4%  -24.84%  (p=0.000 n=64+73)
BM_BufferArgX8      36.0ns ± 5%  20.3ns ± 4%  -43.59%  (p=0.000 n=79+75)
BM_TupleOfI32Attrs  66.4ns ± 1%  66.4ns ± 2%   -0.03%  (p=0.000 n=66+72)

PiperOrigin-RevId: 650691450
2024-07-09 11:07:25 -07:00
George Necula
2f808e9da9 Fix error in custom call registration for some FFI functions
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```

It seems that with the FFI registration mechanism based on `XLA_FFI_REGISTER_HANDLER`
it is no longer possible to register a call target twice.

The fix here is to rollback the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.

PiperOrigin-RevId: 647993991
2024-06-29 12:18:34 -07:00
Dan Foreman-Mackey
9b33df6438 Update C++ registration of cu_lu_pivots_to_permutation to use XLA_FFI_REGISTER_HANDLER
PiperOrigin-RevId: 647734115
2024-06-28 10:53:33 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
Dan Foreman-Mackey
9ae1c56c44 Update lu_pivots_to_permutation to use FFI dimensions on GPU.
The XLA FFI interface provides metadata about buffer dimensions, so quantities
like batch dimensions can be evaluated on the backend, instead of passed as
attributes. This change has the added benefit of allowing this FFI call to
support "vectorized" vmap and dynamic shapes.

PiperOrigin-RevId: 647343656
2024-06-27 09:27:15 -07:00
Ruturaj4
79fccf6c82 add cholesky changes in bazel 2024-05-18 00:37:09 +00:00
jax authors
e8b06ccf56 Cholesky rank-1 update kernel for JAX.
PiperOrigin-RevId: 633722940
2024-05-14 15:21:38 -07:00
Sergei Lebedev
51fc4f85ad Ported LuPivotsToPermutation to the typed XLA FFI
The typed FFI

* allows passing custom call attributes directly to backend_config=, instead
  of serializing them into a C++ struct;
* handles validation and deserialization of custom call operands.

PiperOrigin-RevId: 630067005
2024-05-02 08:12:05 -07:00
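
A small sketch of what passing a typed attribute directly from backend_config= can look like on the C++ side, under the same XLA FFI assumptions as the sketch further up; the handler and attribute names are invented for illustration:

```
#include <cstdint>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// `permutation_size` arrives as a typed attribute straight from
// backend_config=; the FFI layer validates and decodes it before the
// implementation is called, so there is no hand-rolled descriptor struct.
static ffi::Error PivotsExampleImpl(ffi::AnyBuffer pivots,
                                    ffi::Result<ffi::AnyBuffer> permutation,
                                    std::int64_t permutation_size) {
  // ... launch the kernel using permutation_size ...
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(kPivotsExample, PivotsExampleImpl,
                              ffi::Ffi::Bind()
                                  .Arg<ffi::AnyBuffer>()   // pivots
                                  .Ret<ffi::AnyBuffer>()   // permutation
                                  .Attr<std::int64_t>("permutation_size"));
```
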
Marvin Kim
90e9e47a55 [Jax/Triton] Skip benchmarking while autotuning for configs that cannot be launched.
Configs that cannot be launched should not be benchmarked at all.

PiperOrigin-RevId: 626153377
2024-04-18 14:35:51 -07:00
Jieying Luo
44e83d4e0a Add a few custom call registrations to gpu_kernel to keep in sync with callers of xla_client.register_custom_call_target.
PiperOrigin-RevId: 624275186
2024-04-12 13:30:18 -07:00
Henning Becker
9809aa1929 Move CUDA specific functions from asm_compiler to cuda_asm_compiler target
This avoids:
- a forward declaration of `GpuContext`
- the `:asm_compiler_header` header only target

The moved code is unchanged - I just move it from one
file to another and fix up includes and dependencies.

Note that this is adding just another `#ifdef` to the redzone allocator code. I will clean this up in a subsequent change.

PiperOrigin-RevId: 623285804
2024-04-09 14:43:41 -07:00
Marvin Kim
722708052c [JAX] Fix typo in comment.
PiperOrigin-RevId: 621827985
2024-04-04 05:35:28 -07:00
David Dunleavy
aade591fdf Move tsl/python to xla/tsl/python
PiperOrigin-RevId: 620320903
2024-03-29 13:15:21 -07:00
Rahul Batra
8575055571 [ROCm]: Add missing hipStreamWaitEvent API call 2024-03-20 16:58:21 +00:00
Peter Hawkins
c2bbf9c577 Remove some code to support older CUDA and CUSPARSE versions.
The minimum CUDA version supported by JAX is CUDA 11.8, which ships with CUSPARSE 11.7.5.

PiperOrigin-RevId: 616892230
2024-03-18 11:25:03 -07:00
Andrey Portnoy
dcb58bb540 Include <cstdint> in files where it is used 2024-03-06 11:58:15 -05:00
jax authors
7514d5c7aa [triton] Add clustering support and test
PiperOrigin-RevId: 612417957
2024-03-04 05:51:10 -08:00
Eugene Zhulenev
1ae2022918 [jax-triton] Do not capture jax-triton calls that require autotuning
PiperOrigin-RevId: 611823473
2024-03-01 10:28:47 -08:00
Eugene Zhulenev
3a69b80774 [jax-triton] Synchronize autotuning stream with a main one
PiperOrigin-RevId: 609792049
2024-02-23 11:42:30 -08:00
Chris Jones
fcc8b54789 [jax_triton] Use ReaderLock on fast path to reduce lock contention in multi-GPU settings.
PiperOrigin-RevId: 606648981
2024-02-13 09:31:50 -08:00
Anlun Xu
d62071066e [jax:triton] Add a workaround for calling cuStreamGetCtx inside graph capture
A bug in CUDA prevents us from calling gpuStreamGetCtx inside graph capture. We use cuCtxGetCurrent as a workaround for now.

PiperOrigin-RevId: 605417225
2024-02-08 13:49:45 -08:00
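
A hedged sketch of the workaround described in the commit above; the helper is invented, but the driver API calls are the real cuStreamIsCapturing, cuCtxGetCurrent, and cuStreamGetCtx:

```
#include <cuda.h>

// Hypothetical helper: pick the context to use for `stream`. cuStreamGetCtx
// cannot be called while the stream is being captured into a graph, so in
// that case fall back to the thread's current context.
CUresult GetContextForStream(CUstream stream, CUcontext* context) {
  CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
  CUresult err = cuStreamIsCapturing(stream, &capture_status);
  if (err != CUDA_SUCCESS) return err;
  if (capture_status == CU_STREAM_CAPTURE_STATUS_ACTIVE) {
    return cuCtxGetCurrent(context);  // workaround during graph capture
  }
  return cuStreamGetCtx(stream, context);  // normal path
}
```
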
Rahul Batra
f01c27f65a [ROCm]: Add ROCm command buffer support for triton kernel 2024-02-05 19:34:12 +00:00
Anlun Xu
16636f9c97 [jax_triton] Only use side stream to do autotuning when doing graph capture
When graph capture is not enabled, autotuning and kernel launch should be on the same stream to avoid a race condition.

PiperOrigin-RevId: 603728867
2024-02-02 10:48:26 -08:00
Anlun Xu
5e009f9ff1 Make triton kernels compatible with command buffers
Autotuning is not compatible with graph capture because it requires synchronization.

We use cuThreadExchangeStreamCaptureMode to execute a sequence of commands that are not recorded to graphs, similar to what NCCL does here: b6d7438d31/src/include/alloc.h (L171)

PiperOrigin-RevId: 602436960
2024-01-29 11:00:29 -08:00
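
A minimal sketch of the cuThreadExchangeStreamCaptureMode trick described in the commit above; the surrounding helper is invented, the driver call is real:

```
#include <cuda.h>

// Temporarily switches the calling thread to the RELAXED capture mode, so the
// work issued by `fn` (for example autotuning launches on a side stream and
// the synchronization they need) is permitted while a graph capture is in
// progress elsewhere and is not recorded into the captured graph, then
// restores the previous mode (the exchange call swaps modes in place).
template <typename Fn>
CUresult RunUncaptured(Fn&& fn) {
  CUstreamCaptureMode mode = CU_STREAM_CAPTURE_MODE_RELAXED;
  CUresult err = cuThreadExchangeStreamCaptureMode(&mode);  // `mode` now holds the old value
  if (err != CUDA_SUCCESS) return err;
  fn();
  return cuThreadExchangeStreamCaptureMode(&mode);  // restore the old mode
}
```
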
Anlun Xu
88f5eaca3e [xla:gpu] Make cu_threefry2x32 custom call compatible with command buffers
PiperOrigin-RevId: 600937786
2024-01-23 16:14:21 -08:00
jax authors
5761e393fa The Triton autotuner ignores configs that use too much shmem
The autotuner runs a series of benchmarks to determine the best configuration
for a Triton kernel. However, if it encounters a config that does not fit in
shared memory, it throws an error and stops. In that case it should just skip
the config and continue.

PiperOrigin-RevId: 600730507
2024-01-23 03:08:57 -08:00
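
An illustrative sketch of the intended control flow; all names are hypothetical stand-ins for the real autotuner:

```
#include <limits>
#include <vector>

#include "absl/status/statusor.h"

struct Config {};  // hypothetical Triton launch configuration

// Hypothetical benchmark helper, stubbed out so the sketch is self-contained;
// a real implementation would time the kernel and surface launch failures
// (for example, too much shared memory) as an error instead of aborting.
absl::StatusOr<float> Benchmark(const Config&) { return 1.0f; }

// Picks the fastest launchable config; configs that cannot be launched are
// skipped rather than treated as fatal.
const Config* PickBestConfig(const std::vector<Config>& configs) {
  const Config* best = nullptr;
  float best_ms = std::numeric_limits<float>::infinity();
  for (const Config& config : configs) {
    absl::StatusOr<float> ms = Benchmark(config);
    if (!ms.ok()) continue;  // e.g. does not fit in shared memory
    if (*ms < best_ms) {
      best_ms = *ms;
      best = &config;
    }
  }
  return best;
}
```
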
Rahul Batra
f997609e76 [ROCm]: Updates hip headers path for ROCm 6.0 2024-01-22 16:08:37 +00:00
jax authors
ab3c1b5146 [triton] Pass cluster_dims to TritonKernel and use cuLaunchKernel if size <= 1
PiperOrigin-RevId: 599809560
2024-01-19 05:55:41 -08:00
jax authors
59ea9f3fde [triton] Use cuLaunchKernelEx instead of cuLaunchKernel
PiperOrigin-RevId: 597555083
2024-01-11 07:52:07 -08:00
Peter Hawkins
95e2d3fc2b [JAX:GPU] Generalize gesvdj kernel to iterate over the unbatched Jacobi kernel in cases where we cannot use the batched kernel.
If gesvdj() is preferable to gesvd() in the absence of a batch dimension, then even when there is a batch dimension we should prefer a loop of gesvdj() over a loop of gesvd().

PiperOrigin-RevId: 582279549
2023-11-14 04:52:15 -08:00
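
A hedged sketch of the dispatch preference described in the commit above; the chooser is invented, and the 32x32 limit on the batched Jacobi kernel comes from the cuSOLVER documentation:

```
#include <cstdint>

enum class SvdKernel {
  kGesvdjBatched,  // batched Jacobi kernel
  kGesvdjLoop,     // loop of unbatched Jacobi kernels
  kGesvdLoop,      // loop of unbatched gesvd calls
};

// Hypothetical chooser: `prefer_jacobi` stands for whatever heuristic already
// decides that gesvdj beats gesvd for an unbatched problem of this shape.
SvdKernel ChooseSvdKernel(std::int64_t batch, std::int64_t m, std::int64_t n,
                          bool prefer_jacobi) {
  if (!prefer_jacobi) return SvdKernel::kGesvdLoop;
  // cusolverDn<t>gesvdjBatched only supports m, n <= 32 (cuSOLVER docs).
  if (batch > 1 && m <= 32 && n <= 32) return SvdKernel::kGesvdjBatched;
  return SvdKernel::kGesvdjLoop;  // iterate the unbatched Jacobi kernel
}
```
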