mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00

See the added documentation for `jax._src.export.register_pytree_node_serialization` and `jax._src.export.register_namedtuple_serialization`. Serialization of PyTree nodes is needed to serialize the `in_tree` and `out_tree` fields of `Exported` functions (not to serialize actual instances of the custom types). When writing this I looked at how TensorFlow handles namedtuples. It does so transparently, without requiring the user to register a serialization handler for the namedtuple type. But this has the disadvantage that on deserialization a fresh, distinct namedtuple type is created for each input and output type of the serialized function. This means that calling the deserialized function will return outputs of different types than the function that was serialized, which can be confusing. The Python pickle module does a bit better: it attempts to look up the namedtuple type as a module attribute in the deserializing code, automatically importing the module whose name was saved during serialization. This is too much magic for my taste, as it can result in strange import errors. Hence I added an explicit step for the user to say how they want the namedtuple to be serialized and deserialized. Since I also wanted to add support for `collections.OrderedDict`, which users are asking for, I added more general support for PyTree custom nodes. Note that this registration mechanism works in conjunction with the PyTree custom node registration mechanism. The burden is on the user to decide how to serialize and deserialize the custom auxdata that the PyTree custom registration mechanism uses. Not all custom types will be serializable, but many commonly used ones, e.g., dataclasses, can now be inputs and outputs of the serialized functions.
74 lines
3.3 KiB
Python
74 lines
3.3 KiB
Python
# Copyright 2023 The JAX Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
_deprecation_message = (
|
|
"The jax.experimental.export module is deprecated. "
|
|
"Use jax.export instead. "
|
|
"See the migration guide at https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export."
|
|
)
|
|
|
|
from jax._src.export import _export as _src_export
|
|
from jax._src.export import shape_poly as _src_shape_poly
|
|
from jax._src.export import serialization as _src_serialization
|
|
# Import only to set the shape poly decision procedure
|
|
from jax._src.export import shape_poly_decision
|
|
del shape_poly_decision
|
|
|
|
# Deprecated public names of this module, in the `name -> (message, object)`
# form handed to `deprecation_getattr` further down. Every entry shares
# `_deprecation_message`. All of these deprecations were added in June 2024
# (Jun 13-14, 2024).
_deprecations = {
    _name: (_deprecation_message, _target)
    for _name, _target in (
        # Names backed by jax._src.export._export.
        ("Exported", _src_export.Exported),
        ("DisabledSafetyCheck", _src_export.DisabledSafetyCheck),
        ("export", _src_export.export_back_compat),
        ("call", _src_export.call),
        ("call_exported", _src_export.call_exported),
        ("default_lowering_platform", _src_export.default_lowering_platform),
        ("minimum_supported_serialization_version",
         _src_export.minimum_supported_calling_convention_version),
        ("maximum_supported_serialization_version",
         _src_export.maximum_supported_calling_convention_version),
        # Names backed by jax._src.export.serialization.
        ("serialize", _src_serialization.serialize),
        ("deserialize", _src_serialization.deserialize),
        # Names backed by jax._src.export.shape_poly.
        ("SymbolicScope", _src_shape_poly.SymbolicScope),
        ("is_symbolic_dim", _src_shape_poly.is_symbolic_dim),
        ("symbolic_shape", _src_shape_poly.symbolic_shape),
        ("symbolic_args_specs", _src_shape_poly.symbolic_args_specs),
    )
}

import typing
# Static type checkers only see this branch: it gives them a plain alias for
# every deprecated name. NOTE: keep this list in sync with `_deprecations`;
# each key there should have an alias here, otherwise type checkers report a
# name as undefined even though it resolves at runtime.
if typing.TYPE_CHECKING:
  Exported = _src_export.Exported
  DisabledSafetyCheck = _src_export.DisabledSafetyCheck
  export = _src_export.export_back_compat
  call = _src_export.call
  call_exported = _src_export.call_exported
  default_lowering_platform = _src_export.default_lowering_platform
  # These two are listed in `_deprecations` and were missing here.
  minimum_supported_serialization_version = (
      _src_export.minimum_supported_calling_convention_version)
  maximum_supported_serialization_version = (
      _src_export.maximum_supported_calling_convention_version)

  serialize = _src_serialization.serialize
  deserialize = _src_serialization.deserialize

  SymbolicScope = _src_shape_poly.SymbolicScope
  is_symbolic_dim = _src_shape_poly.is_symbolic_dim
  symbolic_shape = _src_shape_poly.symbolic_shape
  symbolic_args_specs = _src_shape_poly.symbolic_args_specs
else:
  # At runtime the deprecated names are not bound as module attributes;
  # instead a module-level `__getattr__` (PEP 562) built from
  # `_deprecations` serves them on attribute access.
  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
  __getattr__ = _deprecation_getattr(__name__, _deprecations)
  del _deprecation_getattr
# `typing` was only needed for the TYPE_CHECKING switch; keep it out of the
# module namespace.
del typing

# Remove the private submodule aliases from this module's namespace (in the
# same order as before). `_deprecations` and, for type checkers, the
# TYPE_CHECKING aliases already hold direct references to the objects they
# need.
del _src_export, _src_serialization, _src_shape_poly