106 Commits

Author SHA1 Message Date
Yash Katariya
6efcf44b1a Deprecate PositionalSharding and GSPMDSharding
PiperOrigin-RevId: 746564071
2025-04-11 13:06:43 -07:00
Dan Foreman-Mackey
81722201fd Remove legacy CPU custom call kernels that have been unused since v0.4.34.
As of today it has been 180 days since the release of 0.4.34 where the following legacy LAPACK kernels were no longer used when lowering:

* getrf
* geqrf / orgqr
* potrf
* gesdd
* syevd
* geev
* gehrd

Following our compatibility policy, these are now safe to remove.

PiperOrigin-RevId: 746388529
2025-04-11 03:17:19 -07:00
Kostiantyn Liepieshov
c730bbda74 fix bug in export_module when no mesh axes are empty for shardy.
If mesh axes are empty, we are setting mesh as None, resulting in an error in
this test.

This fix provides an empty mesh, when no mesh axes in dumped module are empty.

PiperOrigin-RevId: 746058506
2025-04-10 09:21:58 -07:00
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
jax authors
c2eaedfe94 Merge pull request #27776 from gnecula:export_keys
PiperOrigin-RevId: 745038060
2025-04-08 01:53:11 -07:00
George Necula
51dbcd4dad [export] Add backwards compatibility test for annotate_device_placement.
This enables exporting functions that use memory kinds to place
data in different memories.

jax-fixit

PiperOrigin-RevId: 745008959
2025-04-08 00:10:23 -07:00
George Necula
ce7dc85104 [export] Add support for serializing functions with PRNG keys as inputs/outputs
This introduces version 4 of serialization, fully backwards compatible
with versions 2 and 3.

Fixes: #24143
2025-04-07 11:53:20 +02:00
George Necula
1941714d26 [export] Add support for override_lowering_rules to jax.export.
This parameter is already part of the internal API for the
AOT lowering function, here we just expose it to `jax.export`.
2025-04-03 16:13:16 +01:00
Jake VanderPlas
431c2c0807 cleanup now that we depend on ml_dtypes>=0.5 2025-03-28 07:44:38 -07:00
Peter Hawkins
362fb7ae9d Remove code to support jaxlib < 0.5.3.
The new xla_extension_version is 320.

PiperOrigin-RevId: 738522486
2025-03-19 13:40:04 -07:00
Yash Katariya
a4ca0dbc6c Make the signature of AbstractMesh to be AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types) instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types) so that we are consistent across all Mesh APIs: Mesh, AbstractMesh and make_mesh
PiperOrigin-RevId: 736371111
2025-03-12 21:32:31 -07:00
shuw
c099e8081d support e2m1fn 2025-03-05 17:44:34 +00:00
Bart Chrzaszcz
ac493655bf #sdy support JAX export tests when Shardy is enabled.
This CL only supports lowering a module with the exact same mesh, and loading it with either the exact same mesh or different meshes.

Note that we will be introducing some restrictions under Shardy for JAX export:

- You can only lower/save the module with meshes all of the same shape, but different axis names (this PR is right now only allowing the same axis names, but this will be relaxed in a follow-up)
- When loading the module, just like with GSPMD, you can use a different mesh with a different mesh shape and axis names. However, like with the restriction in the previous point, all shardings must use the same axis shapes, but can use different axis names (again this will be relaxed in a follow-up)

We may remove the restriction of having to use the exact same mesh shapes during export saving time and exact same mesh shaped during export loading time in the future. But for now we will keep this restriction while no one is using Shardy with JAX export.

PiperOrigin-RevId: 732878916
2025-03-03 04:57:06 -08:00
Jake VanderPlas
7ab7b214ac refactor: move jnp.einsum impl into its own submodule 2025-02-12 09:05:30 -08:00
George Necula
9f797990b5 Remove old backward compatibility mode for old PRGN custom call on GPU
The backend support for the new custom call was added on June 28th, 2024 (#20997).

PiperOrigin-RevId: 723077990
2025-02-04 07:34:52 -08:00
Jake VanderPlas
955e7c4793 Internal: avoid adding _DimExpr to dtypes._weak_types
This causes problems because internal code assumes it will not be modified. We replace this with an internal registration mechanism.

PiperOrigin-RevId: 721000907
2025-01-29 09:11:02 -08:00
wenscarl
638c6ae046 Add e8m0fnu support by conditional dtype. 2025-01-22 21:57:43 +00:00
Dan Foreman-Mackey
39ce7916f1 Activate FFI implementation of tridiagonal reduction on GPU.
PiperOrigin-RevId: 714078036
2025-01-10 09:28:15 -08:00
Dan Foreman-Mackey
c1de7c733d Add LAPACK lowering for lax.linalg.tridiagonal_solve on CPU.
In implementing https://github.com/jax-ml/jax/pull/25787, I realized that while we lower `tridiagonal_solve` to cuSPARSE on GPU, we were using an explicit implementation of the Thomas algorithm on CPU. We should instead lower to LAPACK's `gtsv` on CPU because it should be more numerically stable and faster.

PiperOrigin-RevId: 714069225
2025-01-10 08:56:46 -08:00
George Necula
fdb6af82d2 Clean up backend_or_name vs. platforms in lowering code.
It turns out that the backend is rarely needed when lowering, e.g.,
for lowering callbacks. Whenever we need the backend for lowering,
we must be in single-platform lowering mode (`len(platforms) == 1`)
and we can look up the backend from `platforms[0]`.

However, in some rare cases we can have a custom `XlaBackend` whose
platform matches `platforms[0]`. We rename `backend_or_name` to just `backend`
and we restrict its type to be an optional `XlaBackend` (not a platform
string).

PiperOrigin-RevId: 712926140
2025-01-07 08:42:57 -08:00
George Necula
bc3306c8bc [shape_poly] Improve threefry with symbolic shapes
Previously, we could only handle threefry for the case when
it was possible to tell statically that the size of the `count`
array is even or odd. This meant that often we had to add a constraint
that one of the dimensions is even.

Here we rewrite the handling of threefry to not require a Python-level
conditional about evenness of the size of the count array. We use
a couple of `lax.dynamic_slice` rather than a `lax.split`.

We also generalize the tests to cases where the size if fully symbolic,
and we cannot tell statically that it is even.
2025-01-07 09:10:04 +02:00
George Necula
e87a2a5929 [shape_poly] Remove old non_negative support.
This was deprecated in January 2024, replaced by
`core_max_dim(..., 0)`.

PiperOrigin-RevId: 712523579
2025-01-06 07:36:11 -08:00
Jake VanderPlas
ccc3a29537 Internal: use a single registry for abstractify APIs 2024-12-23 08:44:35 -08:00
John QiangZhang
e560c6a45c Change the namespace name to avoid using export c++ keyword on namespace.
PiperOrigin-RevId: 708450293
2024-12-20 16:02:15 -08:00
Jake VanderPlas
c560f8e06c Unify abstractify & shaped_abstractify rules 2024-12-20 04:28:19 -08:00
Jake VanderPlas
89a54a9e85 Re-land changes from https://github.com/jax-ml/jax/pull/25555
Reverts 25524abc67d82281e8a4093480637785c03a0150

PiperOrigin-RevId: 707679094
2024-12-18 15:02:54 -08:00
jax authors
25524abc67 Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564
PiperOrigin-RevId: 707501925
2024-12-18 04:43:57 -08:00
Jake VanderPlas
3cecbf34f2 Remove core.concrete_aval and replace with abstractify 2024-12-17 18:18:25 -08:00
Jake VanderPlas
2c722d9b13 Cleanup: toward merging core.concrete_aval & xla.abstractify 2024-12-17 09:27:00 -08:00
George Necula
afcb62ea20 [export] Expand exporting to work with AbstractMesh.
This is a follow up from #25640 that enabled lowering with
AbstractMesh.

This required adding `num_devices` to `lowering.compiler_args`
because in presence of an AbstractMesh the device_assignment
is not accurate.
2024-12-16 10:30:46 +02:00
jax authors
ea63aeab01 Merge pull request #25442 from jakevdp:raise-to-shaped
PiperOrigin-RevId: 705556199
2024-12-12 10:43:17 -08:00
Jake VanderPlas
40367a9eaf Cleanup: remove uses of no-op raise_to_shaped 2024-12-12 09:49:06 -08:00
George Necula
27b024b240 [shape_poly] Improve handling of mod(e, k) == 0 constraints.
These constraints turn out to be quite useful, e.g., when
we want to say that certain dimensions are a multiple of
a device axis.

Previously, the constraint `mod(e, k) == 0` was being useful
only to normalize away `mod(e, k)`. In particular it was not
useful for proving `k * floordiv(e, k)`. Now we add that
features.
2024-12-12 10:31:02 +01:00
jax authors
01206f839b Merge pull request #25395 from gnecula:poly_better_eq
PiperOrigin-RevId: 705105803
2024-12-11 07:51:40 -08:00
Paweł Paruzel
1256153200 Activate Triangular Solve to XLA's FFI
PiperOrigin-RevId: 705029286
2024-12-11 02:22:37 -08:00
George Necula
60f9da5d58 [shape_poly] Improve reasoning for >= in presence of == constraints.
Previously, an equality constraint was used only as a normalization
rule. This created a problem for constraints of the form "4*b=c",
because it would not allow proving that "b <= c" (since the
normalization of "4*b" kicks in only if "b" is multiplied by a
multiple of 4.

Now we add the equality constraints also in the inequality
reasoning state.
2024-12-11 10:51:49 +01:00
jax authors
90de28cd63 Merge pull request #25335 from gnecula:export_doc_call
PiperOrigin-RevId: 704589764
2024-12-10 00:45:20 -08:00
Sergei Lebedev
1ac6b762dd Ensured that JAX type checks under pytype on Python 3.12
Some errors uncovered by pytype look genuine and need to be revisited in
the in the future.

PiperOrigin-RevId: 704268742
2024-12-09 06:53:08 -08:00
Paweł Paruzel
d474feda9e Activate Tridiagonal Reduction to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Tridiagonal Reduction.

PiperOrigin-RevId: 704234350
2024-12-09 04:36:59 -08:00
George Necula
cc73c50c41 [export] Improved the documentation.
In particular added the docstring for `Exported.call` method,
and fixed the formatting for `Exported.in_shardings_jax`.
2024-12-08 17:39:24 +01:00
Paweł Paruzel
9081e85d68 Activate Schur Decomposition to XLA's FFI
PiperOrigin-RevId: 703484916
2024-12-06 06:49:53 -08:00
jax authors
9fc077a50b Merge pull request #25252 from gnecula:poly_power
PiperOrigin-RevId: 703471979
2024-12-06 05:54:18 -08:00
George Necula
3f5f3e1c47 [export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the list of custom calls with guaranteed compatibility.
This is because the underlying Triton IR does not guarantee compatibility.

PiperOrigin-RevId: 703127711
2024-12-05 08:42:41 -08:00
George Necula
5fe5206b6a [shape_poly] Remove some deprecated kwargs
PiperOrigin-RevId: 703116755
2024-12-05 08:02:38 -08:00
George Necula
4e17bea91a [shape_poly] Fix the handling of __pow__ for symbolic dimensions
The code for handling exponentiation was wrong, and there were
no tests.
2024-12-05 11:11:02 +01:00
jax authors
91891cb600 Merge pull request #23585 from apivovarov:float8_e4m3
PiperOrigin-RevId: 697760985
2024-11-18 14:34:59 -08:00
George Necula
45ae4dfb9e [shape_poly] Remove caching for the symbolic shape evaluator
The caching used for the shape_poly.CachingShapeEvaluator leads to
leaked tracer errors. This is because the `lru_cache` is attached
to the `CachingShapeEvaluator.evaluate` and persists for the
duration of the program. It is possible to reimplement the caching,
but in this case caching does not help much so we just remove it.
2024-11-09 11:13:48 +02:00
Sergei Lebedev
78da9fa432 Add float8_e4m3 and float8_e3m4 types support 2024-11-08 18:58:31 +00:00
George Necula
292a00b35a [export] Cleanup in the export module.
With jax.experimental.export gone we can now do some cleanup in the export module.

In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.

PiperOrigin-RevId: 692398132
2024-11-01 22:56:44 -07:00
Jake VanderPlas
8948e6de58 sharding cleanup: use inline checks for unimplemented and auto 2024-10-25 04:22:40 -07:00