Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 21:06:06 +00:00)

Previously we used `from jax.experimental.export import export` and `export.export(fun)`. Now we want to add the public API directly to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialize(ser)
export.call(exp1)
```

This requires changing the type of `jax.experimental.export.export` from a module to a function. That confuses pytype for targets with strict type checking, which is why I attempt to make this change atomically throughout the internal code base.

To preserve backwards compatibility with OSS packages, this change also adds explicit JAX version checks to several OSS packages, and attaches to the `export` function the attributes that the old `export` module had.

PiperOrigin-RevId: 596563481
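The backwards-compatibility trick described above (a function that also carries the attributes of the module it replaces) can be sketched in plain Python. This is a minimal illustration, not the actual JAX code: `Exported`, `call_exported`, and the function bodies are hypothetical stand-ins, and only the attribute-attachment pattern is the point.

```python
class Exported:
    """Hypothetical stand-in for the real `Exported` class."""

    def __init__(self, fun):
        self.fun = fun


def export(fun):
    """Stand-in for the new function-valued `export` entry point."""
    return Exported(fun)


def call_exported(exp, *args):
    """Hypothetical helper that invokes the wrapped function."""
    return exp.fun(*args)


# Compatibility shim: Python functions accept arbitrary attributes, so the
# function can expose the names that the old module of the same name had.
# Old callers that wrote `export.export(fun)` against the module then keep
# working when `export` is the function.
export.export = export
export.Exported = Exported
export.call_exported = call_exported

# Old-style usage, as if `export` had been imported as a module:
old_exp = export.export(lambda x: x + 1)

# New-style usage, calling the function directly:
new_exp = export(lambda x: x * 2)
```

In this sketch both call styles return the same kind of object, which is what lets downstream packages migrate at their own pace.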