Chris Jones
4ac2bdc2b1
[jax_triton] Add user-specified name
field to serialized format.
...
PiperOrigin-RevId: 557415723
2023-08-16 02:53:51 -07:00
Srinivas Vasudevan
7dfc8ff49d
Add batching rules to jax.lax.linalg.tridiagonal_solve.
...
PiperOrigin-RevId: 555700103
2023-08-10 16:25:59 -07:00
Chris Jones
714156df63
[jax_triton] Add support for float scalar inputs.
...
Python `float`s are inferred as "f64". Values can be passed as "f32" using `np.float32(value)`.
PiperOrigin-RevId: 552036612
2023-07-28 22:57:08 -07:00
Chris Jones
9935445d57
[jax_triton] Simplify auto-tuning code.
...
PiperOrigin-RevId: 545733541
2023-07-05 11:18:18 -07:00
Chris Jones
31b862dd56
[jax_triton] Split C++ only parts of Triton custom callback from Python parts.
...
Register callback with default call target name from C++, enabling Triton calls with the default name to work in C++ only contexts (e.g. serving).
PiperOrigin-RevId: 545211452
2023-07-03 06:52:32 -07:00
Chris Jones
3f9da19c63
Add get_serialized_metadata
function to retrieve metadata from op's opaque data.
...
PiperOrigin-RevId: 544608895
2023-06-30 03:23:28 -07:00
Chris Jones
d4e2464340
[jax_triton] Expose Triton custom call callback in header file.
...
This allows users to register the callback from C++ when not using the default call target name.
PiperOrigin-RevId: 544029098
2023-06-28 05:32:02 -07:00
Chris Jones
b3527f3975
Zlib compress kernel proto.
...
PiperOrigin-RevId: 542529065
2023-06-22 05:22:53 -07:00
Chris Jones
f238667492
Make JAX-Triton calls serializable.
...
PiperOrigin-RevId: 542524794
2023-06-22 04:57:14 -07:00
Chris Jones
64e73270ff
Use EncapsulateFunction
utility.
...
PiperOrigin-RevId: 542299099
2023-06-21 10:37:52 -07:00
Peter Zhizhin
01ed663163
Add a comment at the end of autotuning a Triton function
...
I'd like to see what config was chosen at the end of autotuning. Tracking `is the new best config` is a bit hard.
PiperOrigin-RevId: 538465212
2023-06-07 06:10:28 -07:00
jax authors
3ba308d4fc
[triton] Raise exception on too much shared memory requested
...
PiperOrigin-RevId: 538135346
2023-06-06 03:55:53 -07:00
Chris Jones
ea37043577
Switch to STATUS_RETURNING
callback API.
...
PiperOrigin-RevId: 535568707
2023-05-26 03:15:44 -07:00
Chris Jones
2155b9181f
Switch to using JAX status macros in jax-triton kernel call lib.
...
PiperOrigin-RevId: 535300412
2023-05-25 10:26:06 -07:00
Chris Jones
6b13d4eb86
Add branch prediction to JAX status macros.
...
PiperOrigin-RevId: 535233546
2023-05-25 06:23:23 -07:00
Sharad Vikram
bf8ed6a543
Move triton_kernel_call_lib to jaxlib
...
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
Peter Hawkins
a89c377762
[GPU] Fix another instance of missing stream synchronization in RNN kernels.
...
PiperOrigin-RevId: 530660502
2023-05-09 11:08:24 -07:00
Peter Hawkins
f168a1560c
[GPU] Add missing stream synchronization to tridiagonal_solve gtsv2 call.
...
May fix flaky failures in CI.
Make stream argument to Pool::Borrow() mandatory to minimize chance of forgetting it.
PiperOrigin-RevId: 530425766
2023-05-08 15:37:04 -07:00
Peter Hawkins
6b9a109939
Use stream-synchronized copy in rnn_kernels.cc.
...
May fix flaky wrong outputs sometimes seen in CI.
Also check for errors in another use of gpuStreamSynchronize().
PiperOrigin-RevId: 530391917
2023-05-08 13:28:08 -07:00
George Necula
a2ac510dc3
[shape_poly] Add support for dynamic shapes for eigh
...
We can only handle dynamic sizes for the batch dimensions for now.
PiperOrigin-RevId: 529001830
2023-05-02 23:27:59 -07:00
Matthew Johnson
56feaca7f9
update cuDNN RNN code not to save 'workspace' scratch between fwd and bwd
...
PiperOrigin-RevId: 528928263
2023-05-02 17:05:42 -07:00
Peter Hawkins
3bb7386149
[JAX] Improve handling of metadata in compilation cache.
...
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.
There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.
This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.
PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Sharad Vikram
3c3fa042e3
Copy seq_lengths before creating descriptor
...
PiperOrigin-RevId: 519771897
2023-03-27 10:59:44 -07:00
Peter Hawkins
172a831219
Switch JAX to use the OpenXLA repository.
2023-03-13 18:38:26 +00:00
George Necula
7d452adfd3
Add support for dynamic shapes to GPU threefry2x32 custom call.
...
In presence of dynamic shapes the ThreeFry2x32Descriptor will contain the
value n=-1, and the actual desired output length will be passed as
an additional operand. If the shape is static then the length will be
passed as part of the descriptor.
PiperOrigin-RevId: 497945778
2022-12-27 04:48:26 -08:00
Peter Hawkins
e835739eda
Remove an unnecessary include/ from pybind11 include paths.
...
PiperOrigin-RevId: 492016679
2022-11-30 14:20:02 -08:00
Qiao Zhang
4d1c4bc761
Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
...
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf
Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
...
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020
Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
...
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Parker Schuh
0324cac888
Remove unused potrf kernels.
...
PiperOrigin-RevId: 489322021
2022-11-17 15:22:13 -08:00
Rahul Batra
31d8f62826
Sytrd solver and SytrdDescriptor should NOT be CUDA only
2022-11-11 22:41:51 +00:00
Peter Hawkins
352b042fe9
Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
...
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.
Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.
PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Tianjian Lu
46368e4e73
[sparse] Update the guard of cusparse SpMM and SpMv algorithms to cusparse version 11.7.1 onwards.
...
PiperOrigin-RevId: 486051658
2022-11-03 21:39:52 -07:00
Tianjian Lu
ef0f64ec5c
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
...
PiperOrigin-RevId: 485441349
2022-11-01 16:01:50 -07:00
Jake VanderPlas
06c1d8efb5
Rollback of:
...
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
Still breaks CUDA 11.1
PiperOrigin-RevId: 485151807
2022-10-31 14:38:47 -07:00
Tianjian Lu
66e75edd0b
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
...
PiperOrigin-RevId: 484351696
2022-10-27 14:34:44 -07:00
Peter Hawkins
0814770601
Fix FP8 compilation failure in jaxlib stemming from the CUDA/ROCM merge.
...
PiperOrigin-RevId: 484026031
2022-10-26 11:40:14 -07:00
Peter Hawkins
a852710a09
Merge CUDA and ROCM kernel code in jaxlib.
...
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.
PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00