The functionality comes from the `jax.experimental.export` module, which will be deprecated. The following APIs are introduced:

```
from jax import export

def f(...): ...

ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

blob: bytearray = ex.serialize()
rehydrated: export.Exported = export.deserialize(blob)

def caller(...):
  ...
  rehydrated.call(*args, **kwargs)
```

Module documentation will follow shortly. For now, there are no changes to the `jax.experimental.export` APIs.

Most of the changes in this PR are in tests, due to some differences between the new `jax.export` APIs and `jax.experimental.export`:

* Instead of `jax.experimental.export.call(exp)`, we now write `exp.call`.
* `jax.experimental.export.export` accepted any Python callable as the function argument and wrapped it in `jax.jit` automatically. The new `export.export` no longer does this; the user must pass a function that is already wrapped with `jax.jit`.
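For a concrete picture of the export, serialize, deserialize, and call round trip summarized above, here is a minimal runnable sketch. It assumes a JAX version that ships the `jax.export` module. The function `f`, its argument shapes, and the name `caller` are illustrative only. The arguments are described with `jax.ShapeDtypeStruct`, so no concrete data is needed at export time.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export


def f(x, y):
    # Illustrative function; any jit-compatible function would do.
    return jnp.sin(x) * y


# jax.export requires an explicitly jitted function: a plain Python
# callable is no longer wrapped in jax.jit automatically.
f_jit = jax.jit(f)

# Export for abstract argument shapes/dtypes (concrete arrays also work).
x_spec = jax.ShapeDtypeStruct((3,), jnp.float32)
y_spec = jax.ShapeDtypeStruct((3,), jnp.float32)
exp: export.Exported = export.export(f_jit)(x_spec, y_spec)

# Serialize to bytes, e.g. to write to disk or send to another process.
blob: bytearray = exp.serialize()

# Later, possibly in a different process: rehydrate the Exported object.
rehydrated: export.Exported = export.deserialize(blob)


def caller(x, y):
    # Previously written as jax.experimental.export.call(rehydrated)(x, y).
    return rehydrated.call(x, y)


x = jnp.arange(3, dtype=jnp.float32)
y = jnp.full((3,), 2.0, dtype=jnp.float32)
result = caller(x, y)
expected = f(x, y)
print(result, expected)
```

Passing `f` itself instead of `f_jit` to `export.export` is rejected, which is the second difference listed above and the reason many tests had to be updated.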