16 Commits

Author SHA1 Message Date
Peter Hawkins
8ab0c07edc Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
George Necula
47f1b3de2c [export] Add documentation for debugging and for ensuring compatibility.
The rendered documentation is at https://jax--21976.org.readthedocs.build/en/21976/export/export.html#developer-documentation (for the export developer documentation, including compatibility) and https://jax--21976.org.readthedocs.build/en/21976/export/shape_poly.html#debugging (for the shape polymorphism debugging documentation)

While testing the compatibility mechanism I discovered that it can be circumvented by caches.
To fix this, I added export_ignore_forward_compatibility to mlir.LoweringParameters.
2024-06-28 08:36:55 +03:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
George Necula
d737abda48 [export] Fix multi-platform lowering for unknown platform, with donated_argnums
I had to ensure that the check for platforms supporting donation
only kicks in when we actually have donation.
2024-06-23 07:26:12 +03:00
Yash Katariya
6ba16e0348 Add lowering_platforms to traced.lower() to allow lowering to different backends and multi-backend lowering too. In other words, enable cross-lowering!
The motivation for doing this is 2-fold:

1) This will help with deprecating and eventually deleting `jax.xla_computation` which allows for cross backend lowering.

2) Allow for cross-backend and multi-backend lowering via jax AOT APIs which will help cleanup some hacks implemented for `jax.export`.

Note that this is only available by `.trace.lower(lowering_platforms=('tpu',))`. You cannot use `.lower` to do cross-lowering. We can introduce top-level APIs in the future to allow for composable aot apis to make this easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.

Designed with @froystig!

PiperOrigin-RevId: 644087787
2024-06-17 11:59:10 -07:00
George Necula
b1a8c65883 [shape_poly] Add documentation for workaround with dimension parameters. 2024-06-17 20:14:20 +03:00
George Necula
b58ff2ba20 [shape_poly] Add documentation for shape polymorphism
This involved writing some new content and also moving and adapting
the documentation that existed as part of the jax2tf
README file:

https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion
2024-06-15 18:20:54 +03:00
Yash Katariya
4ef33fa90e Add trace to stages.Wrapped and add docs for it.
PiperOrigin-RevId: 643128426
2024-06-13 14:49:54 -07:00
George Necula
7c3a4db3e4 [export] Rename some API entry points
We take the opportunity of a new jax.export package to rename some
of the API entry points:

  * `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants`
    because this is more accurate. The dimension variables are global
    constants, but so is the platform index. And we need to run
    global constant propagation and shape refinement for all of these.
  * We rename "serialization version" with "calling convention version".
    Hence we now have `Exported.calling_convention_version`,
    and the configuration flag is renamed from `--jax-serialization-version`
    to `--jax-export-calling-convention-version`. Also,
    `jax.export.minimum_supported_serialization_version` is now
    `jax.export.minimum_supported_calling_convention_version`.
   * We rename `lowering_platforms` to `platforms` both as a field
    of `Exported` and as the kwarg to `export.export`.
   * We rename `jax.export.default_lowering_platform` to `jax.export.default_export_version`.
2024-06-13 06:44:13 +02:00
Jake VanderPlas
6e837da326 Document jax.export serialization version numbers 2024-06-12 12:44:42 -07:00
George Necula
105cc9a103 [export] Add documentation for jax.export 2024-06-12 19:44:47 +02:00
George Necula
e3faf854b0 [export] Cleaned up types of [in|out]_shardings
Previously we declared Exported.in_shardings to be
a sequence of `core.AbstractValue`, but in reality we only
support `core.ShapedArray`. We change the type declaration and
this allowed us to clean up some `# type: ignore"
2024-06-11 13:46:44 +02:00
George Necula
b33aca6b08 [export] Create the jax.export module APIs.
The functionality comes from the jax.experimental.export
module, which will be deprecated.

The following APIs are introduced:

```
  from jax import export
  def f(...): ...
  ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

  blob: bytearray = ex.serialize()
  rehydrated: export.Export = export.deserialize(blob)

  def caller(...):
     ... rehydrated.call(*args, **kwargs)
```

Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.

Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:

  * Instead of `jax.experimental.export.call(exp)` we now write
    `exp.call`
  * The `jax.experimental.export.export` allowed the function
    argument to be any Python callable and it would wrap it with
    a `jax.jit`. This is not supported anymore by export, and instead
    the user must use `jax.jit`.
2024-06-10 19:31:51 +02:00
George Necula
14d87d3bf7 [export] Move the export implementation to jax._src.export.
This is part of the work to move the export APIs out
of jax.experimental. For now, the way to use this
implementation is still through `jax.experimental.export`.

Had to add a few "#type ignore" to the _export.py because
previously the file was exempt from internal pytype.
Will try to fix these in a later PR.

PiperOrigin-RevId: 641688200
2024-06-09 08:59:50 -07:00
George Necula
39ac584729 [shape_poly] Move to jax._src in preparation for adding to AOT APIs.
The shape polymorphism APIs are still private and are only exposed through `jax.experimental.export` as before.

PiperOrigin-RevId: 640393089
2024-06-04 22:03:24 -07:00