At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.
Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.
Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).
In the process of implementing this we have done some small cleanup of the Exported structure:
* renamed serialization_version to mlir_module_serialization_version
* renamed disabled_checks to disabled_safety_checks
This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.
There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.
PiperOrigin-RevId: 590078785
This fixes our RTD failures, which were caused by RTD installing an older version of pygments:
```
jupyterlab-pygments 0.1.1 requires pygments<3,>=2.4.1, but you'll have pygments 2.3.1 which is incompatible.
nbconvert 6.0.1 requires pygments>=2.4.1, but you'll have pygments 2.3.1 which is incompatible.
```
* Refactored host_callback to use the C++ runtime.
* The new runtime makes it unnecessary to start the outfeed_receiver
in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
I am trying to solve this by (a) making host_callback_test stop the
outfeed receiver on finish and infeed_test on start, and (b)
telling pytest-xdist to run all the tests from one file into
a single worker.
* Moved all notebooks to docs/notebooks.
Now all notebooks are in the same place, thus all are subject
to auto-doc generation at readthedocs.io and to automated testing
with travis.
Some notebooks are too slow, exclude them at docs/conf.py:exclude_patterns.
Cleanup a bit the section headings in notebooks so that they show
up well in readtehdocs.io.
* Increase the cell timeout for executing notebooks
* Exclude also the neural network notebook from auto-generation (timing out)
* Disable the score_matching notebook from auto-doc (travis does not have sklearn)
Had to extend the docs/requirements.txt file to install
matplotlb (needed by the Gotchas notebook) and ".",
needed by everything. This results in a reduction
of the sphinx warnings from 3300 to 1200!