See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.
Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).
When writing this I have looked at how TensorFlow handles namedtuple. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. But this has the disadvantage that on
deserializaton a fresh distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function will return outputs of different types
than then function that was serialized. This can be confusing.
The Python pickle mode does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
importing automatically the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.
Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.
Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
We take the opportunity of a new jax.export package to rename some
of the API entry points:
* `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants`
because this is more accurate. The dimension variables are global
constants, but so is the platform index. And we need to run
global constant propagation and shape refinement for all of these.
* We rename "serialization version" with "calling convention version".
Hence we now have `Exported.calling_convention_version`,
and the configuration flag is renamed from `--jax-serialization-version`
to `--jax-export-calling-convention-version`. Also,
`jax.export.minimum_supported_serialization_version` is now
`jax.export.minimum_supported_calling_convention_version`.
* We rename `lowering_platforms` to `platforms` both as a field
of `Exported` and as the kwarg to `export.export`.
* We rename `jax.export.default_lowering_platform` to `jax.export.default_export_version`.
This is part of the work to move the export APIs out
of jax.experimental. For now, the way to use this
implementation is still through `jax.experimental.export`.
Had to add a few "#type ignore" to the _export.py because
previously the file was exempt from internal pytype.
Will try to fix these in a later PR.
PiperOrigin-RevId: 641688200