94 Commits

Author SHA1 Message Date
Bart Chrzaszcz
ac493655bf #sdy support JAX export tests when Shardy is enabled.
This CL only supports lowering a module with the exact same mesh, and loading it with either the exact same mesh or different meshes.

Note that we will be introducing some restrictions under Shardy for JAX export:

- You can only lower/save the module with meshes all of the same shape, but different axis names (this PR is right now only allowing the same axis names, but this will be relaxed in a follow-up)
- When loading the module, just like with GSPMD, you can use a different mesh with a different mesh shape and axis names. However, like with the restriction in the previous point, all shardings must use the same axis shapes, but can use different axis names (again this will be relaxed in a follow-up)

We may remove the restriction of having to use the exact same mesh shapes during export saving time and the exact same mesh shapes during export loading time in the future. But for now we will keep this restriction while no one is using Shardy with JAX export.

PiperOrigin-RevId: 732878916
2025-03-03 04:57:06 -08:00
Jake VanderPlas
7ab7b214ac refactor: move jnp.einsum impl into its own submodule 2025-02-12 09:05:30 -08:00
George Necula
9f797990b5 Remove old backward compatibility mode for old PRNG custom call on GPU
The backend support for the new custom call was added on June 28th, 2024 (#20997).

PiperOrigin-RevId: 723077990
2025-02-04 07:34:52 -08:00
Jake VanderPlas
955e7c4793 Internal: avoid adding _DimExpr to dtypes._weak_types
This causes problems because internal code assumes it will not be modified. We replace this with an internal registration mechanism.

PiperOrigin-RevId: 721000907
2025-01-29 09:11:02 -08:00
wenscarl
638c6ae046 Add e8m0fnu support by conditional dtype. 2025-01-22 21:57:43 +00:00
Dan Foreman-Mackey
39ce7916f1 Activate FFI implementation of tridiagonal reduction on GPU.
PiperOrigin-RevId: 714078036
2025-01-10 09:28:15 -08:00
Dan Foreman-Mackey
c1de7c733d Add LAPACK lowering for lax.linalg.tridiagonal_solve on CPU.
In implementing https://github.com/jax-ml/jax/pull/25787, I realized that while we lower `tridiagonal_solve` to cuSPARSE on GPU, we were using an explicit implementation of the Thomas algorithm on CPU. We should instead lower to LAPACK's `gtsv` on CPU because it should be more numerically stable and faster.

PiperOrigin-RevId: 714069225
2025-01-10 08:56:46 -08:00
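
For reference, the Thomas algorithm mentioned in c1de7c733d can be sketched in plain Python as below. This is a minimal illustration, not the code that was removed from JAX: the function name `thomas_solve` and the convention that `dl[0]` and `du[n - 1]` are ignored are assumptions of this sketch. It performs no pivoting, whereas LAPACK's `gtsv` uses partial pivoting, which is one reason the LAPACK lowering is expected to be more numerically stable.

```python
from typing import Sequence


def thomas_solve(dl: Sequence[float], d: Sequence[float],
                 du: Sequence[float], b: Sequence[float]) -> list:
    """Solve a tridiagonal system A x = b without pivoting.

    dl, d, du are the lower, main, and upper diagonals, all of length n.
    dl[0] and du[n - 1] are ignored (assumed convention for this sketch).
    """
    n = len(d)
    c_prime = [0.0] * n  # modified upper diagonal
    d_prime = [0.0] * n  # modified right-hand side

    # Forward elimination: remove the lower diagonal row by row.
    c_prime[0] = du[0] / d[0]
    d_prime[0] = b[0] / d[0]
    for i in range(1, n):
        denom = d[i] - dl[i] * c_prime[i - 1]
        c_prime[i] = du[i] / denom
        d_prime[i] = (b[i] - dl[i] * d_prime[i - 1]) / denom

    # Back substitution: solve from the last unknown upwards.
    x = [0.0] * n
    x[-1] = d_prime[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d_prime[i] - c_prime[i] * x[i + 1]
    return x
```

Because there is no pivoting, the sketch divides by `denom` directly and can lose accuracy (or fail) on matrices that are not diagonally dominant.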
George Necula
fdb6af82d2 Clean up backend_or_name vs. platforms in lowering code.
It turns out that the backend is rarely needed when lowering, e.g.,
for lowering callbacks. Whenever we need the backend for lowering,
we must be in single-platform lowering mode (`len(platforms) == 1`)
and we can look up the backend from `platforms[0]`.

However, in some rare cases we can have a custom `XlaBackend` whose
platform matches `platforms[0]`. We rename `backend_or_name` to just `backend`
and we restrict its type to be an optional `XlaBackend` (not a platform
string).

PiperOrigin-RevId: 712926140
2025-01-07 08:42:57 -08:00
George Necula
bc3306c8bc [shape_poly] Improve threefry with symbolic shapes
Previously, we could only handle threefry for the case when
it was possible to tell statically that the size of the `count`
array is even or odd. This meant that often we had to add a constraint
that one of the dimensions is even.

Here we rewrite the handling of threefry to not require a Python-level
conditional about evenness of the size of the count array. We use
a couple of `lax.dynamic_slice` rather than a `lax.split`.

We also generalize the tests to cases where the size is fully symbolic,
and we cannot tell statically that it is even.
2025-01-07 09:10:04 +02:00
George Necula
e87a2a5929 [shape_poly] Remove old non_negative support.
This was deprecated in January 2024, replaced by
`core.max_dim(..., 0)`.

PiperOrigin-RevId: 712523579
2025-01-06 07:36:11 -08:00
Jake VanderPlas
ccc3a29537 Internal: use a single registry for abstractify APIs 2024-12-23 08:44:35 -08:00
John QiangZhang
e560c6a45c Change the namespace name to avoid using the C++ keyword `export` as a namespace name.
PiperOrigin-RevId: 708450293
2024-12-20 16:02:15 -08:00
Jake VanderPlas
c560f8e06c Unify abstractify & shaped_abstractify rules 2024-12-20 04:28:19 -08:00
Jake VanderPlas
89a54a9e85 Re-land changes from https://github.com/jax-ml/jax/pull/25555
Reverts 25524abc67d82281e8a4093480637785c03a0150

PiperOrigin-RevId: 707679094
2024-12-18 15:02:54 -08:00
jax authors
25524abc67 Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564
PiperOrigin-RevId: 707501925
2024-12-18 04:43:57 -08:00
Jake VanderPlas
3cecbf34f2 Remove core.concrete_aval and replace with abstractify 2024-12-17 18:18:25 -08:00
Jake VanderPlas
2c722d9b13 Cleanup: toward merging core.concrete_aval & xla.abstractify 2024-12-17 09:27:00 -08:00
George Necula
afcb62ea20 [export] Expand exporting to work with AbstractMesh.
This is a follow up from #25640 that enabled lowering with
AbstractMesh.

This required adding `num_devices` to `lowering.compiler_args`
because in presence of an AbstractMesh the device_assignment
is not accurate.
2024-12-16 10:30:46 +02:00
jax authors
ea63aeab01 Merge pull request #25442 from jakevdp:raise-to-shaped
PiperOrigin-RevId: 705556199
2024-12-12 10:43:17 -08:00
Jake VanderPlas
40367a9eaf Cleanup: remove uses of no-op raise_to_shaped 2024-12-12 09:49:06 -08:00
George Necula
27b024b240 [shape_poly] Improve handling of mod(e, k) == 0 constraints.
These constraints turn out to be quite useful, e.g., when
we want to say that certain dimensions are a multiple of
a device axis.

Previously, the constraint `mod(e, k) == 0` was used
only to normalize away `mod(e, k)`. In particular it was not
useful for proving `k * floordiv(e, k)`. Now we add that
feature.
2024-12-12 10:31:02 +01:00
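
The arithmetic fact behind 27b024b240 can be checked with ordinary integers. The sketch below assumes that the property in question is the identity `k * floordiv(e, k) == e` (the commit message does not spell out the right-hand side), and the helper name `floordiv_roundtrip_holds` is hypothetical. It only illustrates why a `mod(e, k) == 0` constraint carries that extra information; it says nothing about how the symbolic reasoning is implemented.

```python
def floordiv_roundtrip_holds(e: int, k: int) -> bool:
    """Return True if multiplying the floor quotient by k recovers e."""
    return k * (e // k) == e


# Whenever e is a multiple of k, the round trip is exact ...
assert all(floordiv_roundtrip_holds(e, 4) for e in range(0, 64, 4))
# ... but without the divisibility constraint it can fail.
assert not floordiv_roundtrip_holds(7, 4)
```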
jax authors
01206f839b Merge pull request #25395 from gnecula:poly_better_eq
PiperOrigin-RevId: 705105803
2024-12-11 07:51:40 -08:00
Paweł Paruzel
1256153200 Activate Triangular Solve to XLA's FFI
PiperOrigin-RevId: 705029286
2024-12-11 02:22:37 -08:00
George Necula
60f9da5d58 [shape_poly] Improve reasoning for >= in presence of == constraints.
Previously, an equality constraint was used only as a normalization
rule. This created a problem for constraints of the form "4*b=c",
because it would not allow proving that "b <= c" (since the
normalization of "4*b" kicks in only if "b" is multiplied by a
multiple of 4).

Now we add the equality constraints also in the inequality
reasoning state.
2024-12-11 10:51:49 +01:00
jax authors
90de28cd63 Merge pull request #25335 from gnecula:export_doc_call
PiperOrigin-RevId: 704589764
2024-12-10 00:45:20 -08:00
Sergei Lebedev
1ac6b762dd Ensured that JAX type checks under pytype on Python 3.12
Some errors uncovered by pytype look genuine and need to be revisited in
the future.

PiperOrigin-RevId: 704268742
2024-12-09 06:53:08 -08:00
Paweł Paruzel
d474feda9e Activate Tridiagonal Reduction to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Tridiagonal Reduction.

PiperOrigin-RevId: 704234350
2024-12-09 04:36:59 -08:00
George Necula
cc73c50c41 [export] Improved the documentation.
In particular added the docstring for `Exported.call` method,
and fixed the formatting for `Exported.in_shardings_jax`.
2024-12-08 17:39:24 +01:00
Paweł Paruzel
9081e85d68 Activate Schur Decomposition to XLA's FFI
PiperOrigin-RevId: 703484916
2024-12-06 06:49:53 -08:00
jax authors
9fc077a50b Merge pull request #25252 from gnecula:poly_power
PiperOrigin-RevId: 703471979
2024-12-06 05:54:18 -08:00
George Necula
3f5f3e1c47 [export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the list of custom calls with guaranteed compatibility.
This is because the underlying Triton IR does not guarantee compatibility.

PiperOrigin-RevId: 703127711
2024-12-05 08:42:41 -08:00
George Necula
5fe5206b6a [shape_poly] Remove some deprecated kwargs
PiperOrigin-RevId: 703116755
2024-12-05 08:02:38 -08:00
George Necula
4e17bea91a [shape_poly] Fix the handling of __pow__ for symbolic dimensions
The code for handling exponentiation was wrong, and there were
no tests.
2024-12-05 11:11:02 +01:00
jax authors
91891cb600 Merge pull request #23585 from apivovarov:float8_e4m3
PiperOrigin-RevId: 697760985
2024-11-18 14:34:59 -08:00
George Necula
45ae4dfb9e [shape_poly] Remove caching for the symbolic shape evaluator
The caching used for the shape_poly.CachingShapeEvaluator leads to
leaked tracer errors. This is because the `lru_cache` is attached
to the `CachingShapeEvaluator.evaluate` and persists for the
duration of the program. It is possible to reimplement the caching,
but in this case caching does not help much so we just remove it.
2024-11-09 11:13:48 +02:00
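
The failure mode described in 45ae4dfb9e can be reproduced with the standard library alone. In the sketch below, `Tracer`, `CachedEvaluator`, `PlainEvaluator`, and `is_leaked` are hypothetical stand-ins, not JAX classes. It shows that an `lru_cache` attached to a method lives on the class, so it keeps strong references to every argument it has seen for as long as the program runs.

```python
import functools
import gc
import weakref


class Tracer:
    """Stand-in for an object that must not outlive its trace."""


class CachedEvaluator:
    # The cache is stored on the function object held by the class, so it
    # outlives every instance and strongly references all cached arguments.
    @functools.lru_cache(maxsize=None)
    def evaluate(self, value):
        return id(value)


class PlainEvaluator:
    # Same interface, no caching: nothing retains the argument.
    def evaluate(self, value):
        return id(value)


def is_leaked(evaluator_cls) -> bool:
    """Return True if an argument is still alive after its last local name is gone."""
    tracer = Tracer()
    ref = weakref.ref(tracer)
    evaluator_cls().evaluate(tracer)
    del tracer
    gc.collect()
    return ref() is not None
```

With this sketch, `is_leaked(CachedEvaluator)` is true and `is_leaked(PlainEvaluator)` is false, which matches the trade-off the commit describes: dropping the cache removes the leak at the cost of recomputation.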
Sergei Lebedev
78da9fa432 Add float8_e4m3 and float8_e3m4 types support 2024-11-08 18:58:31 +00:00
George Necula
292a00b35a [export] Cleanup in the export module.
With jax.experimental.export gone we can now do some cleanup in the export module.

In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.

PiperOrigin-RevId: 692398132
2024-11-01 22:56:44 -07:00
Jake VanderPlas
8948e6de58 sharding cleanup: use inline checks for unimplemented and auto 2024-10-25 04:22:40 -07:00
jax authors
f4b84e1c97 Merge pull request #24342 from gnecula:export_custom_types
PiperOrigin-RevId: 688093192
2024-10-21 05:08:04 -07:00
George Necula
2feea414ac [export] Add support for serialization for some custom PyTree nodes
See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.

Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).

When writing this I have looked at how TensorFlow handles namedtuple. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. But this has the disadvantage that on
deserialization a fresh distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function will return outputs of different types
than the function that was serialized. This can be confusing.

The Python pickle module does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
importing automatically the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.

Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.

Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
2024-10-21 11:38:13 +02:00
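
The explicit registration step described in 2feea414ac can be illustrated with a small standalone registry. This is a sketch of the idea only, not JAX's implementation: `register_namedtuple`, `serialize_structure`, `deserialize_structure`, and the JSON encoding are all hypothetical. As in the commit message, only the tree structure is serialized, not the leaf values, and deserialization returns the registered type rather than a freshly created one.

```python
import collections
import json

# Hypothetical registry for this sketch: serialized name <-> namedtuple type.
_NAME_TO_TYPE = {}
_TYPE_TO_NAME = {}


def register_namedtuple(nodetype, *, serialized_name):
    """Explicitly bind a namedtuple type to a stable serialized name."""
    if serialized_name in _NAME_TO_TYPE:
        raise ValueError(f"Duplicate serialized name: {serialized_name}")
    _NAME_TO_TYPE[serialized_name] = nodetype
    _TYPE_TO_NAME[nodetype] = serialized_name
    return nodetype


def serialize_structure(value):
    """Serialize only the tree structure; every leaf becomes None."""
    def encode(v):
        name = _TYPE_TO_NAME.get(type(v))
        if name is not None:
            return {"namedtuple": name, "fields": [encode(f) for f in v]}
        return None  # leaf: the actual value is not stored
    return json.dumps(encode(value))


def deserialize_structure(data):
    """Rebuild the structure using the registered types."""
    def decode(obj):
        if obj is None:
            return None
        nodetype = _NAME_TO_TYPE[obj["namedtuple"]]  # KeyError if unregistered
        return nodetype(*(decode(f) for f in obj["fields"]))
    return decode(json.loads(data))


Point = register_namedtuple(
    collections.namedtuple("Point", ["x", "y"]), serialized_name="my.Point")
```

Round-tripping `Point(1.0, Point(2.0, 3.0))` through these helpers yields nested `Point` instances of the same registered type, which avoids both the fresh-type behavior and the implicit import lookup discussed above.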
Dan Foreman-Mackey
8361eb58e1 Activate the FFI implementation of SVD on GPU.
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly because our heuristics are not very carefully optimized.

This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:

1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.

2. This implementation supports shape polymorphism in all dimensions with some caveats. By default, we do use some heuristics based on the matrix sizes to select the algorithm that is used, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But, I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.

Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.

PiperOrigin-RevId: 687106965
2024-10-17 17:57:06 -07:00
George Necula
9aa79bffba [export] Fix github links in the export documentation
Reflects the repo change google/jax -> jax-ml/jax.
Also changes the error message to put the link to the documentation
in a more visible place.
2024-10-17 08:30:28 +01:00
Gunhyun Park
af50c21225 Remove deprecated API and migrate to new API.
Context: https://github.com/jax-ml/jax/pull/21716
P.S. minor formatting fixes.
PiperOrigin-RevId: 684896546
2024-10-11 11:16:09 -07:00
Dan Foreman-Mackey
67f24df740 Activate FFI implementation of symmetric Eigendecomposition.
These kernels support shape polymorphism in all dimensions and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks so we don't need to include any forward compatibility checks.

PiperOrigin-RevId: 682415506
2024-10-04 12:38:26 -07:00
Dan Foreman-Mackey
c0240764bc Activate FFI implementation of the QR decomposition.
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.

The FFI kernels have been available in jaxlib for over 3 weeks already and they are included with the latest release of jaxlib on PyPI so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.

PiperOrigin-RevId: 682312752
2024-10-04 07:27:11 -07:00
Paweł Paruzel
6e9a53690c Activate Hessenberg Decomposition to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Hessenberg Decomposition.

PiperOrigin-RevId: 681047625
2024-10-01 09:20:06 -07:00
Dan Foreman-Mackey
1a1e16abcc Remove forward compatibility checks from lowering of LU decomposition.
The forward compatibility window for these checks has passed so it is now safe to remove them.

PiperOrigin-RevId: 680565099
2024-09-30 07:23:56 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal type that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom derivatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
jax authors
d776f1da76 Merge pull request #23470 from gnecula:poly_fix_eq_constraints
PiperOrigin-RevId: 671727351
2024-09-06 05:53:53 -07:00
George Necula
0d8ffd33ab [shape_poly] Improve handling of equality shape constraints
This fixes several bugs in presence of equality constraints where
the left-hand side is just a dimension variable.

First, such constraints were not applied when parsing variables.
Now, with a constraint `a == b` when we parse "a" we obtain `b`.

Second, when we evaluate symbolic dimensions that contain
dimension variables that are constrained to be equal to something
else, we may fail to find the dimension variable in the environment
because the environment construction has applied the constraints.
We fix this by looking up the unknown dimension variable in
the equality constraints.

Fixes: #23437
Fixes: #23456
2024-09-06 13:55:38 +03:00
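
The second fix in 0d8ffd33ab, falling back to the equality constraints when a dimension variable is missing from the environment, can be sketched as follows. The function `lookup_dim` and the dictionary representation of constraints are hypothetical simplifications: real constraints relate symbolic expressions, while this sketch only handles the `a == b` case where both sides are plain variables.

```python
def lookup_dim(name, env, equalities):
    """Resolve a dimension variable, falling back to equality constraints.

    env maps variable names to concrete sizes. equalities maps a variable
    name to the variable it is constrained to equal (simplified for this
    sketch).
    """
    seen = set()
    while name not in env:
        # Give up on unknown variables and on cyclic constraints.
        if name in seen or name not in equalities:
            raise KeyError(f"Unknown dimension variable: {name}")
        seen.add(name)
        name = equalities[name]
    return env[name]
```

For example, with the constraint `a == b` and an environment that only contains `b`, `lookup_dim("a", {"b": 8}, {"a": "b"})` resolves `a` through the constraint instead of failing.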