mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

See the added documentation for `jax._src.export.register_pytree_node_serialization` and `jax._src.export.register_namedtuple_serialization`. Serialization of PyTree nodes is needed to serialize the `in_tree` and `out_tree` fields of `Exported` functions (not to serialize actual instances of the custom types). When writing this I have looked at how TensorFlow handles namedtuple. It does so transparently, without requiring the user to register a serialization handler for the namedtuple type. But this has the disadvantage that on deserialization a fresh distinct namedtuple type is created for each input and output type of the serialized function. This means that calling the deserialized function will return outputs of different types than the function that was serialized. This can be confusing. The Python pickle mode does a bit better: it attempts to look up the namedtuple type as a module attribute in the deserializing code, automatically importing the module whose name was saved during serialization. This is too much magic for my taste, as it can result in strange import errors. Hence I added an explicit step for the user to say how they want the namedtuple to be serialized and deserialized. Since I wanted to also add support for `collections.OrderedDict`, which users are asking for, I added more general support for PyTree custom nodes. Note that this registration mechanism works in conjunction with the PyTree custom node registration mechanism. The burden is on the user to decide how to serialize and deserialize the custom auxdata that the PyTree custom registration mechanism uses. Not all custom types will be serializable, but many commonly used ones, e.g., dataclasses, can now be inputs and outputs of the serialized functions.
41 lines
1.5 KiB
Python
41 lines
1.5 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# Public API of `jax.export`. Every name listed here is re-exported from the
# private implementation modules imported below; keep this list in sync with
# those imports.
__all__ = [
    # Re-exported from jax._src.export._export.
    "DisabledSafetyCheck",
    "Exported",
    "export",
    "deserialize",
    "register_pytree_node_serialization",
    "register_namedtuple_serialization",
    "maximum_supported_calling_convention_version",
    "minimum_supported_calling_convention_version",
    "default_export_platform",
    # Re-exported from jax._src.export.shape_poly.
    "SymbolicScope",
    "is_symbolic_dim",
    "symbolic_shape",
    "symbolic_args_specs",
]
|
|
|
|
from jax._src.export._export import (
|
|
DisabledSafetyCheck,
|
|
Exported,
|
|
export,
|
|
deserialize,
|
|
register_pytree_node_serialization,
|
|
register_namedtuple_serialization,
|
|
maximum_supported_calling_convention_version,
|
|
minimum_supported_calling_convention_version,
|
|
default_export_platform)
|
|
|
|
from jax._src.export import shape_poly_decision # Import only to set the decision procedure
|
|
del shape_poly_decision
|
|
from jax._src.export.shape_poly import (
|
|
SymbolicScope,
|
|
is_symbolic_dim,
|
|
symbolic_shape,
|
|
symbolic_args_specs)
|