The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.
PiperOrigin-RevId: 716596848
1. For apply, llvm::StringMap()::insert(MapEntryTy*) will cause dangling reference if not constructing mlir::tpu::extensions::rules() with const-reference. However, if we do construct it with const-reference, the signature is not const-qualified and fails to compile. Hence, change it to llvm::StringMap()::insert(std::pair<...>) and get extension rules by const-reference.
2. Pass default tiling to infer rule, we need it to infer single op. See infer of tpu::MatmulOp.
PiperOrigin-RevId: 716274818
It is still not efficiently implemented, this is mostly to clean up some logic. We may be able to fuse the creation of masks for different tiles into the creation of a single one. But this is also a problem for the later gens.
This also cleans up an unreachable return statement.
PiperOrigin-RevId: 714847066
In implementing https://github.com/jax-ml/jax/pull/25787, I realized that while we lower `tridiagonal_solve` to cuSPARSE on GPU, we were using an explicit implementation of the Thomas algorithm on CPU. We should instead lower to LAPACK's `gtsv` on CPU because it should be more numerically stable and faster.
PiperOrigin-RevId: 714069225
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.
Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.
To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` - see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
Boolean fields in the descriptor struct led to padding, which let random
bytes in the string representation of the struct and variance in HLO
from run to run.
This change does not yet do the work necessary to make any tests pass with threading enabled, which will come in future changes.
This approach is broadly inspired by a6d205dd4c/testtools/testsuite.py (L113) and by unittest-ft.
We add a custom TestResult class that batches up any test result actions and applies them under a lock. We also add a custom TestSuite class that runs individual test cases in parallel using a thread-pool.
We need a reader-writer lock to implement a `@jtu.thread_hostile_test` decorator, which we do by adding bindings around absl::Mutex to jaxlib.
PiperOrigin-RevId: 713312937
This often lets us avoid ambiguities between selecting the (8, 128) and (16, 128) tiling,
by biasing the layout inference to prefer the latter.
PiperOrigin-RevId: 713270421
In our MSAN CI, the copy of LAPACK we use is not MSAN-instrumented, leading to false positives. Suppress those false-positives via annotations.
PiperOrigin-RevId: 712607044
Note that offsets outside of first tile are still disabled (for both infer and apply), and once we support it we will want to assign offsets differently, this is mostly to avoid assigning invalid layouts (that may not just be outside the first tile, but outside the vreg slice)
PiperOrigin-RevId: 709168368
We can enable them later but at least this way the support is available to build on
(e.g. in the new insert relayouts pass)
Reverts 05f3a701e769748ff1ec51d50324a3595c4aff0d
PiperOrigin-RevId: 708397219