As of today it has been 180 days since the release of 0.4.34 where the following legacy LAPACK kernels were no longer used when lowering:
* getrf
* geqrf / orgqr
* potrf
* gesdd
* syevd
* geev
* gehrd
Following our compatibility policy, these are now safe to remove.
PiperOrigin-RevId: 746388529
If mesh axes are empty, we are setting mesh as None, resulting in an error in
this test.
This fix provides an empty mesh, when no mesh axes in dumped module are empty.
PiperOrigin-RevId: 746058506
This CL only supports lowering a module with the exact same mesh, and loading it with either the exact same mesh or different meshes.
Note that we will be introducing some restrictions under Shardy for JAX export:
- You can only lower/save the module with meshes all of the same shape, but different axis names (this PR is right now only allowing the same axis names, but this will be relaxed in a follow-up)
- When loading the module, just like with GSPMD, you can use a different mesh with a different mesh shape and axis names. However, like with the restriction in the previous point, all shardings must use the same axis shapes, but can use different axis names (again this will be relaxed in a follow-up)
We may remove the restriction of having to use the exact same mesh shapes during export saving time and exact same mesh shaped during export loading time in the future. But for now we will keep this restriction while no one is using Shardy with JAX export.
PiperOrigin-RevId: 732878916
This causes problems because internal code assumes it will not be modified. We replace this with an internal registration mechanism.
PiperOrigin-RevId: 721000907
In implementing https://github.com/jax-ml/jax/pull/25787, I realized that while we lower `tridiagonal_solve` to cuSPARSE on GPU, we were using an explicit implementation of the Thomas algorithm on CPU. We should instead lower to LAPACK's `gtsv` on CPU because it should be more numerically stable and faster.
PiperOrigin-RevId: 714069225
It turns out that the backend is rarely needed when lowering, e.g.,
for lowering callbacks. Whenever we need the backend for lowering,
we must be in single-platform lowering mode (`len(platforms) == 1`)
and we can look up the backend from `platforms[0]`.
However, in some rare cases we can have a custom `XlaBackend` whose
platform matches `platforms[0]`. We rename `backend_or_name` to just `backend`
and we restrict its type to be an optional `XlaBackend` (not a platform
string).
PiperOrigin-RevId: 712926140
Previously, we could only handle threefry for the case when
it was possible to tell statically that the size of the `count`
array is even or odd. This meant that often we had to add a constraint
that one of the dimensions is even.
Here we rewrite the handling of threefry to not require a Python-level
conditional about evenness of the size of the count array. We use
a couple of `lax.dynamic_slice` rather than a `lax.split`.
We also generalize the tests to cases where the size if fully symbolic,
and we cannot tell statically that it is even.
This is a follow up from #25640 that enabled lowering with
AbstractMesh.
This required adding `num_devices` to `lowering.compiler_args`
because in presence of an AbstractMesh the device_assignment
is not accurate.
These constraints turn out to be quite useful, e.g., when
we want to say that certain dimensions are a multiple of
a device axis.
Previously, the constraint `mod(e, k) == 0` was being useful
only to normalize away `mod(e, k)`. In particular it was not
useful for proving `k * floordiv(e, k)`. Now we add that
features.
Previously, an equality constraint was used only as a normalization
rule. This created a problem for constraints of the form "4*b=c",
because it would not allow proving that "b <= c" (since the
normalization of "4*b" kicks in only if "b" is multiplied by a
multiple of 4.
Now we add the equality constraints also in the inequality
reasoning state.
The caching used for the shape_poly.CachingShapeEvaluator leads to
leaked tracer errors. This is because the `lru_cache` is attached
to the `CachingShapeEvaluator.evaluate` and persists for the
duration of the program. It is possible to reimplement the caching,
but in this case caching does not help much so we just remove it.
With jax.experimental.export gone we can now do some cleanup in the export module.
In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.
PiperOrigin-RevId: 692398132