Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it, we can make this change without breaking compatibility.
Also fix incorrect dtypes for the d and e values in the CPU LAPACK sytrd wrapper.
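A minimal sketch of the new contract, assuming the function now returns the packed Householder reflectors together with d, e, and the tau factors:

    import jax.numpy as jnp
    from jax import lax

    # A symmetric matrix to reduce to tridiagonal form.
    a = jnp.array([[4., 1., 2.],
                   [1., 3., 0.],
                   [2., 0., 5.]])

    # Post-change contract: the packed reflectors, the diagonal d,
    # the off-diagonal e, and the Householder tau factors.
    arr, d, e, taus = lax.linalg.tridiagonal(a, lower=True)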
PiperOrigin-RevId: 487621469
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
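A quick sketch of how these fit together; the high-level hessenberg call follows the commit above, while pairing `geqrf` with `householder_product` is just one illustrative use of packed reflectors:

    import jax.numpy as jnp
    from jax import lax
    from jax.scipy.linalg import hessenberg

    a = jnp.arange(16.0).reshape(4, 4) + jnp.eye(4)

    # High-level API: reduce a to upper Hessenberg form. calc_q=True also
    # returns the unitary q, with q @ h @ q.conj().T approximately equal to a.
    h, q = hessenberg(a, calc_q=True)

    # householder_product (formerly orgqr) materializes a unitary matrix
    # from packed Householder reflectors, e.g. those produced by geqrf.
    packed, taus = lax.linalg.geqrf(a)
    q_qr = lax.linalg.householder_product(packed, taus)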
PiperOrigin-RevId: 487224934
The code for CUDA and ROCm is almost identical, so with a small shim library to handle the differences we can share almost everything.
PiperOrigin-RevId: 483666051
XLA has supported batched triangular solve on GPU since February 2022, which predates the minimum jaxlib version. We can therefore delete our implementation and just use XLA's.
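As an illustration (a sketch; the user-facing API is unchanged), batched solves now simply lower to XLA's implementation:

    import jax.numpy as jnp
    from jax import lax

    # Leading dimensions are batch dimensions; XLA handles the batching.
    a = jnp.tril(jnp.ones((8, 4, 4))) + 3 * jnp.eye(4)  # lower-triangular batch
    b = jnp.ones((8, 4, 2))
    x = lax.linalg.triangular_solve(a, b, left_side=True, lower=True)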
PiperOrigin-RevId: 482031830
scipy accounts for around 400ms of the 900ms it takes to import JAX. By
loading scipy lazily, we can bring the time for `import jax` down to
about 500ms.
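A minimal sketch of one way to do this, using a PEP 562 module-level `__getattr__` (an illustration, not necessarily the exact mechanism used here):

    # In a package __init__.py: resolve `scipy` on first attribute
    # access instead of eagerly at import time.
    import importlib

    def __getattr__(name):
        if name == "scipy":
            module = importlib.import_module("scipy")
            globals()["scipy"] = module  # cache for later lookups
            return module
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")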
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU, so we can simply ask XLA to enforce the aliasing we need.
It turns out some users rely on the API contract of the custom calls within serialized HLO remaining stable. For the moment, we reapply only the Python changes. The C++ code is already tolerant of both aliased and unaliased outputs, and this still gets us the benefit of saving a copy. We can break backwards compatibility on the serialized HLO after users upgrade their saved HLO to the aliased version.
PiperOrigin-RevId: 480134780
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU, so we can simply ask XLA to enforce the aliasing we need.
This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.
PiperOrigin-RevId: 479670565
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU, so we can simply ask XLA to enforce the aliasing we need.
PiperOrigin-RevId: 479642543
This makes TensorFlow's handling of its StableHLO dependency consistent with how it handles its other dependencies. For example, LLVM goes into //third_party/llvm, and so should StableHLO.
Users of tensorflow/tensorflow (e.g. JAX) need to change Bazel builds, replacing `@org_tensorflow//tensorflow/compiler/xla/mlir_hlo/stablehlo` with `@stablehlo//`. Nothing else changes, e.g. C++ includes, C++ usage, Python bindings and Python usage all stay the same. Example: https://github.com/google/jax/pull/12174.
Users of tensorflow/mlir-hlo are unaffected, thanks to the awesome power of Copybara. There are minor changes in the StableHLO part of MLIR-HLO because the StableHLO repository and the vendored StableHLO inside tensorflow/tensorflow had diverged a little (e.g. Markdown formatting differs slightly between the repositories because I didn't have time to propagate those changes) and have now been forced to converge. These changes won't affect the behavior of either the CMake or the Bazel builds of MLIR-HLO.
Moving forward, contributions to StableHLO will only be possible through openxla/stablehlo, because tensorflow/tensorflow no longer vendors StableHLO. (tensorflow/mlir-hlo still does, but it's read-only.)
PiperOrigin-RevId: 474360128
Unlike StableHLO, which is meant to coexist with MHLO, StableHLO's CHLO is meant to replace MLIR-HLO's CHLO.
This change is the final step towards enabling adoption of StableHLO. If we kept two copies of CHLO, some users would be unable to depend on both MLIR-HLO and StableHLO, and enabling that is useful in both the short and the long term.
C++:
1) C++ dependency changes from `//third_party/tensorflow/compiler/xla/mlir_hlo` (includes CHLO, among other things) to `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops` (in CMake, from `ChloDialect` to `ChloOps`).
2) .h include changes from `#include "third_party/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"` to `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/ChloOps.h"`.
3) To register the CHLO dialect in C++, you'll need to depend on `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:register`, include `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/Register.h"` and call `mlir::stablehlo::registerAllDialects(registry)`.
4) C++ usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
Python:
5) Python dependency changes from `//third_party/py/mlir:mhlo_dialect` (includes CHLO, among other things) to `//third_party/py/mlir:chlo_dialect` (in CMake, from `MLIRHLOPythonModules` to `StablehloUnifiedPythonModules`).
6) Python imports don't change.
7) To register the CHLO dialect in Python, you'll need to change `chlo.register_chlo_dialect(context)` to `chlo.register_dialect(context)` (see the sketch after this list).
8) Python usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
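A minimal sketch of the Python-side registration change; the import paths are assumptions that depend on how the bindings are packaged (e.g. `jaxlib.mlir` in JAX):

    # Import paths assumed for illustration; adjust to your packaging.
    from jaxlib.mlir import ir
    from jaxlib.mlir.dialects import chlo

    ctx = ir.Context()
    # Before: chlo.register_chlo_dialect(ctx)
    chlo.register_dialect(ctx)  # the renamed registration entry point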
PiperOrigin-RevId: 470265566
We frequently use the pattern:

    try:
      import m
    except ImportError:
      # do something else.
This suppresses errors whenever the module can be found but fails to import
for any reason. Instead, catch only ModuleNotFoundError, so that missing
modules are still tolerated but broken modules report their errors.
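The corrected pattern looks like this (the `m = None` fallback is just for illustration):

    try:
      import m
    except ModuleNotFoundError:
      # The module is genuinely absent: fall back. A module that is
      # present but fails to import still raises, surfacing the real bug.
      m = None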
Currently JAX wheels end up with names like:
jaxlib-0.3.15-cp39-none-manylinux2014_x86_64.whl
This PR changes the wheel names to:
jaxlib-0.3.15-cp39-cp39-manylinux2014_x86_64.whl
i.e., we include the CPython ABI tag. This simply makes the wheel name
reflect the status quo; it does not change jaxlib's actual requirements.