rocm_jax/jax/export.py
George Necula 2feea414ac [export] Add support for serialization for some custom PyTree nodes
See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.

Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).

When writing this I looked at how TensorFlow handles namedtuples. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. The disadvantage is that on
deserialization a fresh, distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function returns outputs of different types
than the function that was serialized. This can be confusing.

The Python pickle module does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
automatically importing the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.

Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.

Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
2024-10-21 11:38:13 +02:00
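
The trade-off described in the commit message can be illustrated with a small standard-library sketch. It is not JAX's implementation: the helpers `register_namedtuple`, `serialize_type`, and `deserialize_type`, and the name `"my_app.Point"`, are hypothetical and exist only for this example. The sketch assumes that the explicit registration step simply ties a namedtuple type to a stable, user-chosen name.

```python
import collections

# Hypothetical registry for illustration only; this is NOT JAX's implementation.
_name_to_type = {}
_type_to_name = {}


def register_namedtuple(nodetype, *, serialized_name):
    """Associates a namedtuple type with a stable, user-chosen name."""
    if serialized_name in _name_to_type:
        raise ValueError(f"Duplicate serialized name: {serialized_name}")
    if nodetype in _type_to_name:
        raise ValueError(f"Type already registered: {nodetype.__name__}")
    _name_to_type[serialized_name] = nodetype
    _type_to_name[nodetype] = serialized_name
    return nodetype


def serialize_type(nodetype):
    """Returns the registered name, failing loudly for unregistered types."""
    if nodetype not in _type_to_name:
        raise ValueError(f"{nodetype.__name__} is not registered")
    return _type_to_name[nodetype]


def deserialize_type(serialized_name):
    """Looks up the type registered under the given name."""
    if serialized_name not in _name_to_type:
        raise ValueError(f"No type registered as {serialized_name!r}")
    return _name_to_type[serialized_name]


Point = collections.namedtuple("Point", ["x", "y"])
register_namedtuple(Point, serialized_name="my_app.Point")

# Explicit registration: the round trip yields the very same type object.
restored_type = deserialize_type(serialize_type(Point))

# Transparent approach: rebuilding the type from its name and fields
# yields a new, distinct type even though it looks identical.
fresh_type = collections.namedtuple("Point", ["x", "y"])
```

Here `restored_type` is the original `Point`, while `fresh_type` is a separate class: its instances compare equal to `Point` instances as tuples but fail an `isinstance` check against `Point`, which is the kind of surprise the explicit registration step is meant to avoid.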

41 lines
1.5 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize",
           "register_pytree_node_serialization",
           "register_namedtuple_serialization",
           "maximum_supported_calling_convention_version",
           "minimum_supported_calling_convention_version",
           "default_export_platform",
           "SymbolicScope", "is_symbolic_dim",
           "symbolic_shape", "symbolic_args_specs"]

from jax._src.export._export import (
  DisabledSafetyCheck,
  Exported,
  export,
  deserialize,
  register_pytree_node_serialization,
  register_namedtuple_serialization,
  maximum_supported_calling_convention_version,
  minimum_supported_calling_convention_version,
  default_export_platform)

from jax._src.export import shape_poly_decision  # Import only to set the decision procedure
del shape_poly_decision
from jax._src.export.shape_poly import (
  SymbolicScope,
  is_symbolic_dim,
  symbolic_shape,
  symbolic_args_specs)