If the gesvdj() is preferable to gesvd() absent a batch dimension, even if there is a batch dimension we should prefer a loop of gesvdj() over a loop of gesvd().
PiperOrigin-RevId: 582279549
This is intended to flag cases where the wrong CUDA libraries are used, either because:
* the user self-installed CUDA and that installation is too old, or
* the user used the pip package installation, but due to LD_LIBRARY_PATH overrides or similar we didn't end up using the pip-installed version.
PiperOrigin-RevId: 568910422
1. Add (limited) precision specifier handling to LSTM
Enables differentiating between TF32 and FP32 math. TF32 math had insufficient
precision to reliably pass LSTM correctness tests on A100 and H100.
2. Run the test using FP32
TF32 precision is not sufficient for the test to pass reliably on Ampere+ GPUs
such as A100 and H100.
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cusparse SpMM CSR matmuls. In the past, Cusparse silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cusparse 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior anyway in all cases, use the default algorithm always.
PiperOrigin-RevId: 560867049
pybind11 accepts either Python strings or bytes as a std::string argument, whereas nanobind accepts only strings. Change the argument to nb::bytes instead.
PiperOrigin-RevId: 560086072
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.
PiperOrigin-RevId: 559898790
Register callback with default call target name from C++, enabling Triton calls with the default name to work in C++ only contexts (e.g. serving).
PiperOrigin-RevId: 545211452
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.
There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.
This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.
PiperOrigin-RevId: 525534901
In presence of dynamic shapes the ThreeFry2x32Descriptor will contain the
value n=-1, and the actual desired output length will be passed as
an additional operand. If the shape is static then the length will be
passed as part of the descriptor.
PiperOrigin-RevId: 497945778
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.
Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.
PiperOrigin-RevId: 487621469
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.
PiperOrigin-RevId: 483666051