The functionality comes from the jax.experimental.export
module, which will be deprecated.
The following APIs are introduced:
```
from jax import export
def f(...): ...
ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)
blob: bytearray = ex.serialize()
rehydrated: export.Export = export.deserialize(blob)
def caller(...):
... rehydrated.call(*args, **kwargs)
```
Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.
Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:
* Instead of `jax.experimental.export.call(exp)` we now write
`exp.call`
* The `jax.experimental.export.export` allowed the function
argument to be any Python callable and it would wrap it with
a `jax.jit`. This is not supported anymore by export, and instead
the user must use `jax.jit`.
In some of the cases when we use many symbolic expressions for shapes, the operations with
symbolic expressions are becoming somewhat costly. These benchmarks help with picking
good internal representations.
PiperOrigin-RevId: 604289958