21 Commits

Author SHA1 Message Date
Paweł Paruzel
23fdb91252 Port Schur Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685689593
2024-10-14 06:46:42 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Paweł Paruzel
5fc992e5e1 Determine LAPACK workspaces during SVD kernel runtime
The SVD kernel implementation used to require workspace shapes to be determined prior to the custom call on the JAX's side. The new FFI kernels need not demand these shapes to be specified anymore. They are evaluated during kernel runtime.

PiperOrigin-RevId: 662413273
2024-08-13 01:17:44 -07:00
Paweł Paruzel
b2a469b361 Port Eigenvalue Decompositions to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 659492696
2024-08-05 03:18:13 -07:00
Dan Foreman-Mackey
ff4e0b1214 Rearrange the LAPACK handler definitions in jaxlib to avoid duplicate handler errors.
When linking the jaxlib `cpu_kernels` target and importing JAX, we currently silently fail to instantiate the CPU backend. This refactor means that we only ever define one version of the handlers.

PiperOrigin-RevId: 657186057
2024-07-29 06:59:44 -07:00
Dan Foreman-Mackey
33a9db3943 Move FFI helper macros from jaxlib/cpu/lapack_kernels.cc to a jaxlib/ffi_helpers.h.
Some of the macros that were used in jaxlib's FFI calls to LAPACK turned out to
be useful for other FFI calls. This change consolidates these macros in the
ffi_helper header.

PiperOrigin-RevId: 651166306
2024-07-10 15:09:45 -07:00
Paweł Paruzel
532be68461 Port Singular Value Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 650212574
2024-07-08 05:19:53 -07:00
Dan Foreman-Mackey
98b87540a7 Avoid throwing exceptions in LAPACK CPU kernels.
When an FFI kernel is executed, there isn't any global try/except block (I think!) so it's probably a good idea to avoid throwing.
Instead, it should be safer to handle mapping failures to ffi::Error manually.

PiperOrigin-RevId: 647348889
2024-06-27 09:41:07 -07:00
Paweł Paruzel
3d39b6e752 Port Cholesky Factorization to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 642954763
2024-06-13 05:44:36 -07:00
George Necula
3bcb8d6831 Remove DUCC FFT from jaxlib
JAX has stopped generating code that uses directly
the DUCC FFT custom calls.
The 6 months backwards compatibility window has also expired.

PiperOrigin-RevId: 638132572
2024-05-28 21:12:23 -07:00
jax authors
16b29a6930 Merge pull request #19288 from pearu:pearu/int32-overflow
PiperOrigin-RevId: 608701959
2024-02-20 12:43:16 -08:00
Pearu Peterson
3fa1033ac1 Prevent silent overflow in lapack worker size calculations.
Add -fexceptions to building lapack_kernels
2024-02-20 11:04:06 +02:00
Shashank Viswanadha
350b7c56b8 Add python stub files for jaxlib/cpu C++ Python extensions.
PiperOrigin-RevId: 585990748
2023-11-28 08:45:24 -08:00
Antonio Sanchez
873cffc776 Use TSL's import of DUCC.
This is a necessary first step before adding DUCC support to XLA,
otherwise the JAX tests in the XLA repo pull from JAX's copy,
which has slightly different build rules.

PiperOrigin-RevId: 576880208
2023-10-26 08:26:56 -07:00
Peter Hawkins
dbf13252f0 Copybara import of the project:
--
3905d6123bdc22f505934242363fda426c99c4cf by Peter Hawkins <phawkins@google.com>:

Update flatbuffers.

Use upstream flatbuffer bazel scripts, with a couple of small patches to fix:
* https://github.com/google/flatbuffers/issues/8087 (remove npm references)
* https://github.com/google/flatbuffers/pull/8088 (fix flatc build failure due to main() removal by linker)

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17502 from hawkinsp:fb 3905d6123bdc22f505934242363fda426c99c4cf
PiperOrigin-RevId: 563543954
2023-09-07 14:27:25 -07:00
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Peter Hawkins
ab45383038 Fix build breakage from OpenXLA switch.
PiperOrigin-RevId: 516325478
2023-03-13 14:37:35 -07:00
jax authors
42ef649e65 Merge pull request #14475 from hawkinsp:openxla
PiperOrigin-RevId: 516316330
2023-03-13 14:04:41 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00