47 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
154e4506c0 Some lax.linalg housekeeping.
The main aim here is to clean up lax.linalg to make it a bit easier to maintain and update with new features (e.g. batch partitioning - coming soon!). In this change, I removes some code duplication by consolidate most of the lowering logic into a helper function, and identifying some other common patterns. As part of this, I moved the remaining lowering rules from `jaxlib.lapack` into `lax.linalg`.

PiperOrigin-RevId: 725223882
2025-02-10 08:27:18 -08:00
George Necula
9f797990b5 Remove old backward compatibility mode for old PRGN custom call on GPU
The backend support for the new custom call was added on June 28th, 2024 (#20997).

PiperOrigin-RevId: 723077990
2025-02-04 07:34:52 -08:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
Roy Frostig
a60ead6fd1 enable partitionable threefry by default
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
Dan Foreman-Mackey
39ce7916f1 Activate FFI implementation of tridiagonal reduction on GPU.
PiperOrigin-RevId: 714078036
2025-01-10 09:28:15 -08:00
Dan Foreman-Mackey
c1de7c733d Add LAPACK lowering for lax.linalg.tridiagonal_solve on CPU.
In implementing https://github.com/jax-ml/jax/pull/25787, I realized that while we lower `tridiagonal_solve` to cuSPARSE on GPU, we were using an explicit implementation of the Thomas algorithm on CPU. We should instead lower to LAPACK's `gtsv` on CPU because it should be more numerically stable and faster.

PiperOrigin-RevId: 714069225
2025-01-10 08:56:46 -08:00
Paweł Paruzel
1256153200 Activate Triangular Solve to XLA's FFI
PiperOrigin-RevId: 705029286
2024-12-11 02:22:37 -08:00
Paweł Paruzel
d474feda9e Activate Tridiagonal Reduction to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Tridiagonal Reduction.

PiperOrigin-RevId: 704234350
2024-12-09 04:36:59 -08:00
Paweł Paruzel
9081e85d68 Activate Schur Decomposition to XLA's FFI
PiperOrigin-RevId: 703484916
2024-12-06 06:49:53 -08:00
Dan Foreman-Mackey
8361eb58e1 Activate the FFI implementation of SVD on GPU.
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly because our heuristics are not very carefully optimized.

This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:

1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.

2. This implementation supports shape polymorphism in all dimensions with some caveats. By default, we do use some heuristics to based on the matrix sizes to select the algorithm that is used, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But, I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.

Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.

PiperOrigin-RevId: 687106965
2024-10-17 17:57:06 -07:00
Dan Foreman-Mackey
67f24df740 Activate FFI implementation of symmetric Eigendecomposition.
These kernels support shape polymorphism in all dimensions and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks so we don't need to include any forward compatibility checks.

PiperOrigin-RevId: 682415506
2024-10-04 12:38:26 -07:00
Peter Hawkins
d3f63a66b8 Remove code to support jaxlib <= 0.4.33. 2024-10-04 11:39:05 -04:00
Dan Foreman-Mackey
c0240764bc Activate FFI implementation of the QR decomposition.
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.

The FFI kernels have been available in jaxlib for over 3 weeks already and they are included with the latest release of jaxlib on PyPI so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.

PiperOrigin-RevId: 682312752
2024-10-04 07:27:11 -07:00
Paweł Paruzel
6e9a53690c Activate Hessenberg Decomposition to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Hessenberg Decomposition.

PiperOrigin-RevId: 681047625
2024-10-01 09:20:06 -07:00
Dan Foreman-Mackey
1a1e16abcc Remove forward compatibility checks from lowering of LU decomposition.
The forward compatibility window for these checks has passed so it is now safe to remove them.

PiperOrigin-RevId: 680565099
2024-09-30 07:23:56 -07:00
Peter Hawkins
940860625e Remove code that existed to support jaxlib < 0.4.32.
New minimum versions:
* jaxlib 0.4.32
* xla_extension_version 283
* mlir_api_version 57

PiperOrigin-RevId: 675291231
2024-09-16 14:30:00 -07:00
Paweł Paruzel
3c6103f2df Activate Eigenvalue Decompositions to XLA's FFI
Two eigenvalue decomposition methods. One is intended for non-symmetric matrices - GEEV (General Eigenvalue Solver) - and the other for Symmetric or Hermitian matrices - SYEVD/HEEVD.

PiperOrigin-RevId: 668381949
2024-08-28 03:53:49 -07:00
Paweł Paruzel
c430b0c5e3 Activate QR Factorization to XLA's FFI
PiperOrigin-RevId: 666722604
2024-08-23 03:21:43 -07:00
Dan Foreman-Mackey
e51848ea3d Activate GPU kernel for LU decomposition.
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.

One of the biggest changes here is to move all the lowing logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py) since the lowering code is now identical for CPU and GPU (the only difference is the handler names).

PiperOrigin-RevId: 665829252
2024-08-21 05:08:41 -07:00
Paweł Paruzel
354293da48 Activate Singular Value Decomposition to XLA's FFI
PiperOrigin-RevId: 662436635
2024-08-13 02:41:57 -07:00
Dan Foreman-Mackey
3c014a4c27 Add support for shape polymorphism with lu_pivots_to_permutation.
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used.

PiperOrigin-RevId: 662024940
2024-08-12 03:39:54 -07:00
jax authors
aec6efb44b Merge pull request #22649 from ROCm:ci_jax_export_harness
PiperOrigin-RevId: 660096296
2024-08-06 14:27:13 -07:00
Ruturaj4
35c70fd3ec [ROCM] Fix export harness tests 2024-08-06 10:12:31 -05:00
Paweł Paruzel
6b0b222a38 Activate LU Decomposition to XLA's FFI
PiperOrigin-RevId: 658721697
2024-08-02 02:22:53 -07:00
George Necula
65450d165e Remove forward compatibility mode for old PRGN custom call on GPU
The backend support for the new custom call was added on June 28th.
Also add backwards compatibility test for the new custom call.

PiperOrigin-RevId: 658011228
2024-07-31 08:10:17 -07:00
Yash Katariya
30037547d7 Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
PiperOrigin-RevId: 657400413
2024-07-29 18:44:31 -07:00
Paweł Paruzel
ae40c87919 Activate Cholesky Factorization Kernel to XLA's FFI
PiperOrigin-RevId: 655990468
2024-07-25 09:59:28 -07:00
George Necula
4063373b22 Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
George Necula
0d058ce86f Reverts 0e17d26b6d81a6b281f55bdd81a6f0ab45efeafe
PiperOrigin-RevId: 655552768
2024-07-24 07:09:16 -07:00
George Necula
c5871331ba [pallas] Simplify handling of BlockMapping and GridMapping
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and coverted
to `BlockMapping` and `GridMapping`, which contains fewer
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.

`grid` now defaults to `()` instead of `None`.

Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.

Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the GridMapping.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.

Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.

Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.

Removed some dead code for implementing the interpret mode.

Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.

Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.

Added test for the calling convention, including dynamic grid dimensions.

There is more work to be done: with the new information in
`GridMapping` it should be possible to clean the code throughout
that extract various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
2024-07-24 14:48:08 +03:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
George Necula
14d87d3bf7 [export] Move the export implementation to jax._src.export.
This is part of the work to move the export APIs out
of jax.experimental. For now, the way to use this
implementation is still through `jax.experimental.export`.

Had to add a few "#type ignore" to the _export.py because
previously the file was exempt from internal pytype.
Will try to fix these in a later PR.

PiperOrigin-RevId: 641688200
2024-06-09 08:59:50 -07:00
George Necula
dbad518d2b [shape_poly] Add limited support for lax.approx_top_k.
This relies on newly introduced support for dynamic `k`
for approx_top_k, using the `stablehlo.dynamic_approx_top_k`
custom call.

We also add a backwards compatibility test.

PiperOrigin-RevId: 640557581
2024-06-05 09:51:47 -07:00
George Necula
3bcb8d6831 Remove DUCC FFT from jaxlib
JAX has stopped generating code that uses directly
the DUCC FFT custom calls.
The 6 months backwards compatibility window has also expired.

PiperOrigin-RevId: 638132572
2024-05-28 21:12:23 -07:00
Olli Lupton
9ba77f8ecd Skip a test when run with cuSolver >= 11.6
This version is shipped with CUDA 12.4. The test assumes that a
workspace size baked in with an older version of cuSolver can be used
with a newer version of cuSolver. This is not safe, and leads to an
error when upgrading from 11.5 to 11.6.
2024-05-21 14:46:43 +00:00
George Necula
b40a31006c [export] Add backwards compatibility test for Pallas call on GPUs.
Note that this adds the minimum of safety net to protect against
non-backwards-compatible changes. We really should have more tests
that cover more of the Triton MLIR.

Also enable serialization of such calls.

PiperOrigin-RevId: 630033989
2024-05-02 05:38:33 -07:00
Roy Frostig
3f9540761e reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`

PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00
Roy Frostig
69878c4924 remove Threefry GPU kernel
Cursory timing of `jit(lambda key: random.bits(key, (8, 128 * 128)))` suggests that this is a slight compile-time efficiency loss, taking roughly ~1.25x the time to compile compared to the removed kernel-based lowering. This seems worth the memory improvement, and one kernel fewer to maintain.

PiperOrigin-RevId: 629282330
2024-04-29 21:29:38 -07:00
George Necula
d92f4ae157 Reverts 9db5e693ebb4ad786c6e52b562cf32aeaba2e7e1
PiperOrigin-RevId: 628362293
2024-04-26 04:14:34 -07:00
jax authors
9db5e693eb Reverts 6bfbb4593a42fced91ba50de47271af425c74c20
PiperOrigin-RevId: 628035616
2024-04-25 04:53:22 -07:00
George Necula
6bfbb4593a Remove old ducc_fft custom call.
Starting in June 2023 we have switched the CPU lowering for FFT to use
the new custom call dynamic_ducc_fft. We are now out of the backwards
compatibility window and we remove the old ducc_fft.

We need to keep dynamic_ducc_fft a little bit longer (May 2024).

PiperOrigin-RevId: 627981921
2024-04-25 00:29:11 -07:00
Sergei Lebedev
754fab91f7 Bumped the minimum jaxlib to 0.4.23
jaxlib 0.4.23 has xla_extension_version 223 and mlir_api_version 54.
2024-04-13 08:18:33 +01:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
George Necula
02acacd999 Fix unused import error, breaks ruff. 2024-01-08 17:55:38 +02:00
George Necula
69788d18b6 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
2024-01-08 05:29:56 -08:00
George Necula
ed2a839884 Move export backwards compatibility tests out of jax2tf. Step 3.
The last part of moving the tests: move jax2tf/tests/back_compat_test.py to tests/export_back_compat_test.py.

PiperOrigin-RevId: 596555577
2024-01-08 04:48:10 -08:00