13 Commits

Author SHA1 Message Date
George Necula
3df602882c [shape_poly] Small improvement in the documentation
Added an example for equality constraints.
2024-07-06 08:10:55 +03:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
George Necula
47f1b3de2c [export] Add documentation for debugging and for ensuring compatibility.
The rendered documentation is at https://jax--21976.org.readthedocs.build/en/21976/export/export.html#developer-documentation (for the export developer documentation, including compatibility) and https://jax--21976.org.readthedocs.build/en/21976/export/shape_poly.html#debugging (for the shape polymorphism debugging documentation)

While testing the compatibility mechanism I discovered that it can be circumvented by caches.
To fix this, I added export_ignore_forward_compatibility to mlir.LoweringParameters.
2024-06-28 08:36:55 +03:00
Peng Wang
99d90b23ee Fixes an error in jax.export.Exported's docstring.
PiperOrigin-RevId: 647115474
2024-06-26 16:27:07 -07:00
Lukas Geiger
8df518f7f8
Fix typo in jax.export migration guide 2024-06-18 16:27:18 +01:00
George Necula
b1a8c65883 [shape_poly] Add documentation for workaround with dimension parameters. 2024-06-17 20:14:20 +03:00
George Necula
b58ff2ba20 [shape_poly] Add documentation for shape polymorphism
This involved writing some new content and also moving and adapting
the documentation that existed as part of the jax2tf
README file:

https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion
2024-06-15 18:20:54 +03:00
Jake VanderPlas
a92fa547a0 Re-land https://github.com/google/jax/pull/21847
Reverts 0bcc81ceb33e3065110e3dd56ca215dbb62f0a7b

PiperOrigin-RevId: 643202512
2024-06-13 19:53:53 -07:00
jax authors
0bcc81ceb3 Reverts 5aedafc214cf930f5b196b1eb130fd7ec866bc5e
PiperOrigin-RevId: 643131144
2024-06-13 14:58:54 -07:00
George Necula
7af03a8fd1 [export] Deprecate jax.experimental.export
And announce jax.export.

While turning on the DeprecationWarning I discovered a couple
of tests that needed adjustment.
2024-06-13 21:46:18 +03:00
George Necula
7c3a4db3e4 [export] Rename some API entry points
We take the opportunity of a new jax.export package to rename some
of the API entry points:

  * `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants`
    because this is more accurate. The dimension variables are global
    constants, but so is the platform index. And we need to run
    global constant propagation and shape refinement for all of these.
  * We rename "serialization version" with "calling convention version".
    Hence we now have `Exported.calling_convention_version`,
    and the configuration flag is renamed from `--jax-serialization-version`
    to `--jax-export-calling-convention-version`. Also,
    `jax.export.minimum_supported_serialization_version` is now
    `jax.export.minimum_supported_calling_convention_version`.
   * We rename `lowering_platforms` to `platforms` both as a field
    of `Exported` and as the kwarg to `export.export`.
   * We rename `jax.export.default_lowering_platform` to `jax.export.default_export_version`.
2024-06-13 06:44:13 +02:00
Jake VanderPlas
6e837da326 Document jax.export serialization version numbers 2024-06-12 12:44:42 -07:00
George Necula
105cc9a103 [export] Add documentation for jax.export 2024-06-12 19:44:47 +02:00