10 Commits

Author SHA1 Message Date
Peter Hawkins
883cf2b1e9 Refactor custom call building code in jaxlib to use a helper function.
Refactoring only, no functional changes intended.

This should fix a jaxlib build issue on Windows: we only have one constructor of layouts, and it explicitly requests an int64 type.

Fixes https://github.com/google/jax/issues/10474

PiperOrigin-RevId: 447076192
2022-05-06 14:51:24 -07:00
Matthew Johnson
0c5864a220 add xla_client._version checks for mhlo.ConstOp signature
fix break from 0cf08d0c6841332240cae873e4b4cf9a9b313373
2022-05-04 09:54:06 -07:00
jax authors
0cf08d0c68 Integrate LLVM at llvm/llvm-project@46cc04de34
Updates LLVM usage to match
[46cc04de341b](https://github.com/llvm/llvm-project/commit/46cc04de341b)

PiperOrigin-RevId: 446430294
2022-05-04 05:31:41 -07:00
Jake VanderPlas
c6343ddf8e jax.scipy.linalg.schur: error on 16-bit floats
Fixes https://github.com/google/jax/issues/10530

PiperOrigin-RevId: 446279906
2022-05-03 13:47:44 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
6c1461b52b [MHLO] Add MHLO lowerings for triangular_solve, cholesky, and schur.
PiperOrigin-RevId: 441769591
2022-04-14 08:38:21 -07:00
Peter Hawkins
bc658e7456 [MHLO] Add direct MHLO lowerings for most linear algebra kernels.
PiperOrigin-RevId: 439927594
2022-04-06 13:59:09 -07:00
Aden Grue
8884ce5b98 Migrate 'jaxlib' CPU custom-calls to the status-returning API
PiperOrigin-RevId: 438165260
2022-03-29 17:14:14 -07:00
Leello Tadesse Dadi
f9a246ac19 schur lapack wrapper 2021-09-29 14:29:52 +02:00
Peter Hawkins
94f97b920f Refactor JAX CPU kernels to make them usable from C++.
Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.

This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on Scipy, which we only needed for Cython cimport reasons.

When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.

Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.

PiperOrigin-RevId: 394705605
2021-09-03 10:03:54 -07:00