George Necula 2feea414ac [export] Add support for serialization for some custom PyTree nodes
See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.

Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).

When writing this I have looked at how TensorFlow handles namedtuple. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. But this has the disadvantage that on
deserialization a fresh distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function will return outputs of different types
than the function that was serialized. This can be confusing.

The Python pickle module does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
automatically importing the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.

Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.

Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
2024-10-21 11:38:13 +02:00

74 lines
3.3 KiB
Python

# Copyright 2023 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
_deprecation_message = (
"The jax.experimental.export module is deprecated. "
"Use jax.export instead. "
"See the migration guide at https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export."
)
from jax._src.export import _export as _src_export
from jax._src.export import shape_poly as _src_shape_poly
from jax._src.export import serialization as _src_serialization
# Import only to set the shape poly decision procedure
from jax._src.export import shape_poly_decision
del shape_poly_decision
# Deprecated public names of `jax.experimental.export` (added Jun 13-14, 2024).
# Each entry maps a public name to a `(message, replacement_object)` pair; the
# table is handed to `deprecation_getattr` below to build the module-level
# `__getattr__` used outside of static type checking.
_deprecations = {
  _public_name: (_deprecation_message, _replacement)
  for _public_name, _replacement in (
    ("Exported", _src_export.Exported),
    ("DisabledSafetyCheck", _src_export.DisabledSafetyCheck),
    ("export", _src_export.export_back_compat),
    ("call", _src_export.call),
    ("call_exported", _src_export.call_exported),
    ("default_lowering_platform", _src_export.default_lowering_platform),
    # The old "serialization_version" names map to the newer
    # "calling_convention_version" names.
    ("minimum_supported_serialization_version",
     _src_export.minimum_supported_calling_convention_version),
    ("maximum_supported_serialization_version",
     _src_export.maximum_supported_calling_convention_version),
    ("serialize", _src_serialization.serialize),
    ("deserialize", _src_serialization.deserialize),
    ("SymbolicScope", _src_shape_poly.SymbolicScope),
    ("is_symbolic_dim", _src_shape_poly.is_symbolic_dim),
    ("symbolic_shape", _src_shape_poly.symbolic_shape),
    ("symbolic_args_specs", _src_shape_poly.symbolic_args_specs),
  )
}
import typing
if typing.TYPE_CHECKING:
  # Static type checkers only see this branch, not the `__getattr__` installed
  # in the `else` branch. So every name listed in `_deprecations` is bound here
  # as a plain alias; otherwise checkers would report it as missing.
  Exported = _src_export.Exported
  DisabledSafetyCheck = _src_export.DisabledSafetyCheck
  export = _src_export.export_back_compat
  call = _src_export.call
  call_exported = _src_export.call_exported
  default_lowering_platform = _src_export.default_lowering_platform
  # These two were listed in `_deprecations` but previously had no alias here.
  minimum_supported_serialization_version = _src_export.minimum_supported_calling_convention_version
  maximum_supported_serialization_version = _src_export.maximum_supported_calling_convention_version
  serialize = _src_serialization.serialize
  deserialize = _src_serialization.deserialize
  SymbolicScope = _src_shape_poly.SymbolicScope
  is_symbolic_dim = _src_shape_poly.is_symbolic_dim
  symbolic_shape = _src_shape_poly.symbolic_shape
  symbolic_args_specs = _src_shape_poly.symbolic_args_specs
else:
  # At runtime, attribute lookups for the deprecated names go through a
  # module-level `__getattr__` built from the `_deprecations` table.
  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
  __getattr__ = _deprecation_getattr(__name__, _deprecations)
  del _deprecation_getattr
del typing

# Remove the private module aliases so they are not left as attributes of
# this module. The objects they referenced remain reachable through
# `_deprecations` or the aliases above.
del _src_export
del _src_serialization
del _src_shape_poly