This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.
Note: since we are still falling back to hashing devices +
platform is the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.
Testing: test workload.
PiperOrigin-RevId: 590378036
When sorting by granule key is disabled, the granules are used to create the mesh in the order in which they appear in the sequence of devices.
PiperOrigin-RevId: 590228169
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.
Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.
Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).
In the process of implementing this we have done some small cleanup of the Exported structure:
* renamed serialization_version to mlir_module_serialization_version
* renamed disabled_checks to disabled_safety_checks
This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.
There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.
PiperOrigin-RevId: 590078785
A while ago we moved the native exporting code out of jax2tf, to
jax.experimental.export, but we left behind shape_poly.py shim for
backwards compatibility. Now we remove this shim.
Replace previous
from jax.experimental.jax2tf import shape_poly
with
from jax.experimental.export import shape_poly
PiperOrigin-RevId: 589850402