We introduced these shims when migrating from MHLO to StableHLO, and they helped accommodate the version skew between jaxlib and JAX across different environments. Now that a sufficient amount of time has passed, these shims are no longer used anywhere and can be deleted.
PiperOrigin-RevId: 510820007
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
2) Fallback lowerings now produce StableHLO ops as well.
3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).
From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.
PiperOrigin-RevId: 497978733
In presence of dynamic shapes the ThreeFry2x32Descriptor will contain the
value n=-1, and the actual desired output length will be passed as
an additional operand. If the shape is static then the length will be
passed as part of the descriptor.
PiperOrigin-RevId: 497945778
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.
Unchanged occurrences:
1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
argument value in Lowering.as_text and Lowering.compiler_ir.
2) Documentation (changelog, JEPs, IR examples, etc).
3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
so both are necessary to disambiguate.
PiperOrigin-RevId: 495771153
But run the continuous builds by building on RBE and testing locally so as to run the multiaccelerator tests too. Locally we have 4 GPUs available.
Also make GPU presubmits blocking for JAX (re-enabled it).
PiperOrigin-RevId: 491647775
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.
Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.
PiperOrigin-RevId: 487621469
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
PiperOrigin-RevId: 487224934
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.
PiperOrigin-RevId: 483666051
XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation.
PiperOrigin-RevId: 482031830
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.
It turns out some users are relying on the API contract of the custom calls within serialized HLO remaining stable. For the moment, we reapply only the Python changes. The C++ code is already tolerant of both aliased and unaliased outputs, and this gets us all the benefit of saving a copy. We can break backwards compatibility on the serialized HLO after users upgrade their saved HLO to the aliased version.
PiperOrigin-RevId: 480134780
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.
This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.
PiperOrigin-RevId: 479670565
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.
PiperOrigin-RevId: 479642543