mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[export] Add support for serialization for some custom PyTree nodes
See the added documentation for `jax._src.export.register_pytree_node_serialization` and `jax._src.export.register_namedtuple_serialization`. Serialization of PyTree nodes is needed to serialize the `in_tree` and `out_tree` fields of `Exported` functions (not to serialize actual instances of the custom types). When writing this I have looked at how TensorFlow handles namedtuple. It does so transparently, without requiring the user to register a serialization handler for the namedtuple type. But this has the disadvantage that on deserializaton a fresh distinct namedtuple type is created for each input and output type of the serialized function. This means that calling the deserialized function will return outputs of different types than then function that was serialized. This can be confusing. The Python pickle mode does a bit better: it attempts to look up the namedtuple type as a module attribute in the deserializing code, importing automatically the module whose name was saved during serialization. This is too much magic for my taste, as it can result in strange import errors. Hence I added an explicit step for the user to say how they want the namedtuple to be serialized and deserialized. Since I wanted to also add support for `collections.OrderedDict`, which users are asking for, I added more general support for PyTree custom nodes. Note that this registration mechanism works in conjunction with the PyTree custom node registration mechanism. The burden is on the user to decide how to serialize and deserialize the custom auxdata that the PyTree custom registration mechanism uses. Not all custom types will be serializable, but many commonly used ones, e.g., dataclasses, can now be inputs and outputs of the serialized functions.
This commit is contained in:
parent
bb271aaff8
commit
2feea414ac
@ -17,13 +17,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import Callable, Sequence
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, Protocol, TypeVar, Union, cast
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
@ -344,6 +346,189 @@ def deserialize(blob: bytearray) -> Exported:
|
||||
return deserialize(blob)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
PyTreeAuxData = Any # alias for tree_util._AuxData
|
||||
|
||||
|
||||
class _SerializeAuxData(Protocol):
|
||||
def __call__(self, aux_data: PyTreeAuxData) -> bytes:
|
||||
"""Serializes the PyTree node AuxData.
|
||||
|
||||
The AuxData is returned by the `flatten_func` registered by
|
||||
`tree_util.register_pytree_node`).
|
||||
"""
|
||||
|
||||
|
||||
class _DeserializeAuxData(Protocol):
|
||||
def __call__(self, serialized_aux_data: bytes) -> PyTreeAuxData:
|
||||
"""Deserializes the PyTree node AuxData.
|
||||
|
||||
The result will be passed to `_BuildFromChildren`.
|
||||
"""
|
||||
|
||||
|
||||
class _BuildFromChildren(Protocol):
|
||||
def __call__(self, aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any:
|
||||
"""Materializes a T given a deserialized AuxData and children.
|
||||
|
||||
This is similar in scope with the `unflatten_func`.
|
||||
"""
|
||||
|
||||
|
||||
serialization_registry: dict[type, tuple[str, _SerializeAuxData]] = {}
|
||||
|
||||
|
||||
deserialization_registry: dict[
|
||||
str,
|
||||
tuple[type, _DeserializeAuxData, _BuildFromChildren]] = {}
|
||||
|
||||
|
||||
def _is_namedtuple(nodetype: type) -> bool:
|
||||
return (issubclass(nodetype, tuple) and
|
||||
hasattr(nodetype, "_fields") and
|
||||
isinstance(nodetype._fields, Sequence) and
|
||||
all(isinstance(f, str) for f in nodetype._fields))
|
||||
|
||||
def register_pytree_node_serialization(
|
||||
nodetype: type[T],
|
||||
*,
|
||||
serialized_name: str,
|
||||
serialize_auxdata: _SerializeAuxData,
|
||||
deserialize_auxdata: _DeserializeAuxData,
|
||||
from_children: _BuildFromChildren | None = None
|
||||
) -> type[T]:
|
||||
"""Registers a custom PyTree node for serialization and deserialization.
|
||||
|
||||
You must use this function before you can serialize and deserialize PyTree
|
||||
nodes for the types not supported natively. We serialize PyTree nodes for
|
||||
the `in_tree` and `out_tree` fields of `Exported`, which are part of the
|
||||
exported function's calling convention.
|
||||
|
||||
This function must be called after calling
|
||||
`jax.tree_util.register_pytree_node` (except for `collections.namedtuple`,
|
||||
which do not require a call to `register_pytree_node`).
|
||||
|
||||
Args:
|
||||
nodetype: the type whose PyTree nodes we want to serialize. It is an
|
||||
error to attempt to register multiple serializations for a `nodetype`.
|
||||
serialized_name: a string that will be present in the serialization and
|
||||
will be used to look up the registration during deserialization. It is an
|
||||
error to attempt to register multiple serializations for a
|
||||
`serialized_name`.
|
||||
serialize_auxdata: serialize the PyTree auxdata (returned by the
|
||||
`flatten_func` argument to `jax.tree_util.register_pytree_node`.).
|
||||
deserialize_auxdata: deserialize the auxdata that was serialized by the
|
||||
`serialize_auxdata`.
|
||||
from_children: if present, this is a function that takes that result of
|
||||
`deserialize_auxdata` along with some children and creates an instance
|
||||
of `nodetype`. This is similar to the `unflatten_func` passed to
|
||||
`jax.tree_util.register_pytree_node`. If not present, we look up
|
||||
and use the `unflatten_func`. This is needed for `collections.namedtuple`,
|
||||
which does not have a `register_pytree_node`, but it can be useful to
|
||||
override that function. Note that the result of `from_children` is
|
||||
only used with `jax.tree_util.tree_structure` to construct a proper
|
||||
PyTree node, it is not used to construct the outputs of the serialized
|
||||
function.
|
||||
|
||||
Returns:
|
||||
the same type passed as `nodetype`, so that this function can
|
||||
be used as a class decorator.
|
||||
"""
|
||||
if nodetype in serialization_registry:
|
||||
raise ValueError(
|
||||
f"Duplicate serialization registration for type `{nodetype}`. "
|
||||
"Previous registration was with serialized_name "
|
||||
f"`{serialization_registry[nodetype][0]}`.")
|
||||
if serialized_name in deserialization_registry:
|
||||
raise ValueError(
|
||||
"Duplicate serialization registration for "
|
||||
f"serialized_name `{serialized_name}`. "
|
||||
"Previous registration was for type "
|
||||
f"`{deserialization_registry[serialized_name][0]}`.")
|
||||
if from_children is None:
|
||||
if nodetype not in tree_util._registry:
|
||||
raise ValueError(
|
||||
f"If `from_children` is not present, you must call first"
|
||||
f"`jax.tree_util.register_pytree_node` for `{nodetype}`")
|
||||
from_children = tree_util._registry[nodetype].from_iter
|
||||
|
||||
serialization_registry[nodetype] = (
|
||||
serialized_name, serialize_auxdata)
|
||||
deserialization_registry[serialized_name] = (
|
||||
nodetype, deserialize_auxdata, from_children)
|
||||
return nodetype
|
||||
|
||||
|
||||
def register_namedtuple_serialization(
|
||||
nodetype: type[T],
|
||||
*,
|
||||
serialized_name: str) -> type[T]:
|
||||
"""Registers a namedtuple for serialization and deserialization.
|
||||
|
||||
JAX has native PyTree support for `collections.namedtuple`, and does not
|
||||
require a call to `jax.tree_util.register_pytree_node`. However, if you
|
||||
want to serialize functions that have inputs of outputs of a
|
||||
namedtuple type, you must register that type for serialization.
|
||||
|
||||
Args:
|
||||
nodetype: the type whose PyTree nodes we want to serialize. It is an
|
||||
error to attempt to register multiple serializations for a `nodetype`.
|
||||
On deserialization, this type must have the same set of keys that
|
||||
were present during serialization.
|
||||
serialized_name: a string that will be present in the serialization and
|
||||
will be used to look up the registration during deserialization. It is an
|
||||
error to attempt to register multiple serializations for
|
||||
a `serialized_name`.
|
||||
|
||||
Returns:
|
||||
the same type passed as `nodetype`, so that this function can
|
||||
be used as a class decorator.
|
||||
"""
|
||||
if not _is_namedtuple(nodetype):
|
||||
raise ValueError("Use `jax.export.register_pytree_node_serialization` for "
|
||||
"types other than `collections.namedtuple`.")
|
||||
|
||||
def serialize_auxdata(aux_data: PyTreeAuxData) -> bytes:
|
||||
# Store the serialized keys in the serialized auxdata
|
||||
del aux_data
|
||||
return json.dumps(nodetype._fields).encode("utf-8")
|
||||
|
||||
def deserialize_auxdata(serialized_aux_data: bytes) -> PyTreeAuxData:
|
||||
return json.loads(serialized_aux_data.decode("utf-8"))
|
||||
|
||||
def from_children(aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any:
|
||||
# Use our own "from_children" because namedtuples do not have a pytree
|
||||
# registration.
|
||||
ser_keys = cast(Sequence[str], aux_data)
|
||||
assert len(ser_keys) == len(children)
|
||||
return nodetype(** dict(zip(ser_keys, children)))
|
||||
|
||||
return register_pytree_node_serialization(
|
||||
nodetype,
|
||||
serialized_name=serialized_name,
|
||||
serialize_auxdata=serialize_auxdata,
|
||||
deserialize_auxdata=deserialize_auxdata,
|
||||
from_children=from_children)
|
||||
|
||||
|
||||
# collections.OrderedDict is registered as a pytree node with auxdata being
|
||||
# `tuple(x.keys())`.
|
||||
def _serialize_ordereddict_keys(keys):
|
||||
if isinstance(keys, Sequence) and all(isinstance(k, str) for k in keys):
|
||||
return json.dumps(keys).encode("utf-8")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Serialization of collections.OrderedDict is supported only when the "
|
||||
f"keys are strings. Found keys: {keys}.")
|
||||
|
||||
|
||||
register_pytree_node_serialization(
|
||||
collections.OrderedDict,
|
||||
serialized_name="collections.OrderedDict",
|
||||
serialize_auxdata=_serialize_ordereddict_keys,
|
||||
deserialize_auxdata=lambda b: json.loads(b.decode("utf-8")))
|
||||
|
||||
|
||||
def default_export_platform() -> str:
|
||||
"""Retrieves the default export platform.
|
||||
|
||||
@ -404,9 +589,10 @@ def export_back_compat(
|
||||
disabled_checks: the safety checks to disable. See docstring
|
||||
of `DisabledSafetyCheck`.
|
||||
|
||||
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
|
||||
or values with `.shape` and `.dtype` attributes, and returns an
|
||||
`Exported`.
|
||||
Returns:
|
||||
a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
|
||||
or values with `.shape` and `.dtype` attributes, and returns an
|
||||
`Exported`.
|
||||
|
||||
Usage:
|
||||
|
||||
@ -480,9 +666,10 @@ def export(
|
||||
disabled_checks: the safety checks to disable. See documentation for
|
||||
of `jax.export.DisabledSafetyCheck`.
|
||||
|
||||
Returns: a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`,
|
||||
or values with `.shape` and `.dtype` attributes, and returns an
|
||||
`Exported`.
|
||||
Returns:
|
||||
a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`,
|
||||
or values with `.shape` and `.dtype` attributes, and returns an
|
||||
`Exported`.
|
||||
|
||||
Usage:
|
||||
|
||||
|
@ -28,12 +28,15 @@ enum PyTreeDefKind: byte {
|
||||
tuple = 2,
|
||||
list = 3,
|
||||
dict = 4,
|
||||
custom = 5,
|
||||
}
|
||||
|
||||
table PyTreeDef {
|
||||
kind: PyTreeDefKind;
|
||||
children: [PyTreeDef];
|
||||
children_names: [string]; // only for "dict"
|
||||
children_names: [string]; // only for "kind==dict"
|
||||
custom_name: string; // only for "kind==custom"
|
||||
custom_auxdata: [byte]; // only for "kind==custom"
|
||||
}
|
||||
|
||||
enum AbstractValueKind: byte {
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import TypeVar
|
||||
@ -45,6 +46,8 @@ SerT = TypeVar("SerT")
|
||||
# even if the change is backwards compatible.
|
||||
# Version 1, Nov 2023, first version.
|
||||
# Version 2, Dec 16th, 2023, adds the f0 dtype.
|
||||
# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types
|
||||
# This version is backwards compatible with Version 2.
|
||||
_SERIALIZATION_VERSION = 2
|
||||
|
||||
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
|
||||
@ -152,13 +155,13 @@ def _serialize_array(
|
||||
|
||||
def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
|
||||
serialization_version = exp.SerializationVersion()
|
||||
if serialization_version != _SERIALIZATION_VERSION:
|
||||
if serialization_version not in [2, 3]:
|
||||
raise NotImplementedError(
|
||||
f"deserialize unsupported version {serialization_version}"
|
||||
)
|
||||
|
||||
fun_name = exp.FunctionName().decode("utf-8")
|
||||
_, in_tree = tree_util.tree_flatten(
|
||||
in_tree = tree_util.tree_structure(
|
||||
_deserialize_pytreedef_to_pytree(exp.InTree())
|
||||
)
|
||||
scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints
|
||||
@ -166,7 +169,7 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
|
||||
in_avals = _deserialize_tuple(
|
||||
exp.InAvalsLength, exp.InAvals, deser_aval
|
||||
)
|
||||
_, out_tree = tree_util.tree_flatten(
|
||||
out_tree = tree_util.tree_structure(
|
||||
_deserialize_pytreedef_to_pytree(exp.OutTree())
|
||||
)
|
||||
out_avals = _deserialize_tuple(
|
||||
@ -246,23 +249,40 @@ def _serialize_pytreedef(
|
||||
children_vector_offset = _serialize_array(
|
||||
builder, _serialize_pytreedef, children
|
||||
)
|
||||
custom_name = None
|
||||
custom_auxdata = None
|
||||
node_type = node_data and node_data[0]
|
||||
|
||||
if node_data is None: # leaf
|
||||
kind = ser_flatbuf.PyTreeDefKind.leaf
|
||||
elif node_data[0] is type(None):
|
||||
elif node_type is types.NoneType:
|
||||
kind = ser_flatbuf.PyTreeDefKind.none
|
||||
elif node_data[0] is tuple:
|
||||
elif node_type is tuple:
|
||||
kind = ser_flatbuf.PyTreeDefKind.tuple
|
||||
elif node_data[0] is list:
|
||||
elif node_type is list:
|
||||
kind = ser_flatbuf.PyTreeDefKind.list
|
||||
elif node_data[0] is dict:
|
||||
elif node_type is dict:
|
||||
kind = ser_flatbuf.PyTreeDefKind.dict
|
||||
assert len(node_data[1]) == len(children)
|
||||
children_names_vector_offset = _serialize_array(
|
||||
builder, lambda b, s: b.CreateString(s), node_data[1]
|
||||
)
|
||||
elif node_type in _export.serialization_registry:
|
||||
kind = ser_flatbuf.PyTreeDefKind.custom
|
||||
serialized_name, serialize_auxdata = _export.serialization_registry[node_type]
|
||||
custom_name = builder.CreateString(serialized_name)
|
||||
serialized_auxdata = serialize_auxdata(node_data[1])
|
||||
if not isinstance(serialized_auxdata, (bytes, bytearray)):
|
||||
raise ValueError(
|
||||
"The custom serialization function for `node_type` must "
|
||||
f"return a `bytes` object. It returned a {type(serialized_auxdata)}.")
|
||||
custom_auxdata = builder.CreateByteVector(serialized_auxdata)
|
||||
else:
|
||||
raise NotImplementedError(f"serializing PyTreeDef {node_data}")
|
||||
raise ValueError(
|
||||
"Cannot serialize PyTreeDef containing an "
|
||||
f"unregistered type `{node_type}`. "
|
||||
"Use `export.register_pytree_node_serialization` or "
|
||||
"`export.register_namedtuple_serialization`.")
|
||||
|
||||
ser_flatbuf.PyTreeDefStart(builder)
|
||||
ser_flatbuf.PyTreeDefAddKind(builder, kind)
|
||||
@ -270,6 +290,10 @@ def _serialize_pytreedef(
|
||||
ser_flatbuf.PyTreeDefAddChildren(builder, children_vector_offset)
|
||||
if children_names_vector_offset:
|
||||
ser_flatbuf.PyTreeDefAddChildrenNames(builder, children_names_vector_offset)
|
||||
if custom_name is not None:
|
||||
ser_flatbuf.PyTreeDefAddCustomName(builder, custom_name)
|
||||
if custom_auxdata is not None:
|
||||
ser_flatbuf.PyTreeDefAddCustomAuxdata(builder, custom_auxdata)
|
||||
return ser_flatbuf.PyTreeDefEnd(builder)
|
||||
|
||||
|
||||
@ -294,6 +318,17 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
|
||||
assert p.ChildrenNamesLength() == nr_children
|
||||
keys = [p.ChildrenNames(i).decode("utf-8") for i in range(nr_children)]
|
||||
return dict(zip(keys, children))
|
||||
elif kind == ser_flatbuf.PyTreeDefKind.custom:
|
||||
serialized_name = p.CustomName().decode("utf-8")
|
||||
if serialized_name not in _export.deserialization_registry:
|
||||
raise ValueError(
|
||||
"Cannot deserialize a PyTreeDef containing an "
|
||||
f"unregistered type `{serialized_name}`. "
|
||||
"Use `export.register_pytree_node_serialization` or "
|
||||
"`export.register_namedtuple_serialization`.")
|
||||
nodetype, deserialize_auxdata, from_iter = _export.deserialization_registry[serialized_name]
|
||||
auxdata = deserialize_auxdata(p.CustomAuxdataAsNumpy().tobytes())
|
||||
return from_iter(auxdata, children)
|
||||
else:
|
||||
assert False, kind
|
||||
|
||||
|
@ -21,20 +21,21 @@ import flatbuffers
|
||||
from flatbuffers.compat import import_numpy
|
||||
np = import_numpy()
|
||||
|
||||
class PyTreeDefKind:
|
||||
class PyTreeDefKind(object):
|
||||
leaf = 0
|
||||
none = 1
|
||||
tuple = 2
|
||||
list = 3
|
||||
dict = 4
|
||||
custom = 5
|
||||
|
||||
|
||||
class AbstractValueKind:
|
||||
class AbstractValueKind(object):
|
||||
shapedArray = 0
|
||||
abstractToken = 1
|
||||
|
||||
|
||||
class DType:
|
||||
class DType(object):
|
||||
bool = 0
|
||||
i8 = 1
|
||||
i16 = 2
|
||||
@ -60,18 +61,18 @@ class DType:
|
||||
f0 = 22
|
||||
|
||||
|
||||
class ShardingKind:
|
||||
class ShardingKind(object):
|
||||
unspecified = 0
|
||||
hlo_sharding = 1
|
||||
|
||||
|
||||
class DisabledSafetyCheckKind:
|
||||
class DisabledSafetyCheckKind(object):
|
||||
platform = 0
|
||||
custom_call = 1
|
||||
shape_assertions = 2
|
||||
|
||||
|
||||
class PyTreeDef:
|
||||
class PyTreeDef(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
@ -140,8 +141,42 @@ class PyTreeDef:
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
||||
return o == 0
|
||||
|
||||
# PyTreeDef
|
||||
def CustomName(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
||||
if o != 0:
|
||||
return self._tab.String(o + self._tab.Pos)
|
||||
return None
|
||||
|
||||
# PyTreeDef
|
||||
def CustomAuxdata(self, j):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
||||
if o != 0:
|
||||
a = self._tab.Vector(o)
|
||||
return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
|
||||
return 0
|
||||
|
||||
# PyTreeDef
|
||||
def CustomAuxdataAsNumpy(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
||||
if o != 0:
|
||||
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o)
|
||||
return 0
|
||||
|
||||
# PyTreeDef
|
||||
def CustomAuxdataLength(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
||||
if o != 0:
|
||||
return self._tab.VectorLen(o)
|
||||
return 0
|
||||
|
||||
# PyTreeDef
|
||||
def CustomAuxdataIsNone(self):
|
||||
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
||||
return o == 0
|
||||
|
||||
def PyTreeDefStart(builder):
|
||||
builder.StartObject(3)
|
||||
builder.StartObject(5)
|
||||
|
||||
def PyTreeDefAddKind(builder, kind):
|
||||
builder.PrependInt8Slot(0, kind, 0)
|
||||
@ -158,12 +193,21 @@ def PyTreeDefAddChildrenNames(builder, childrenNames):
|
||||
def PyTreeDefStartChildrenNamesVector(builder, numElems):
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
|
||||
def PyTreeDefAddCustomName(builder, customName):
|
||||
builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(customName), 0)
|
||||
|
||||
def PyTreeDefAddCustomAuxdata(builder, customAuxdata):
|
||||
builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(customAuxdata), 0)
|
||||
|
||||
def PyTreeDefStartCustomAuxdataVector(builder, numElems):
|
||||
return builder.StartVector(1, numElems, 1)
|
||||
|
||||
def PyTreeDefEnd(builder):
|
||||
return builder.EndObject()
|
||||
|
||||
|
||||
|
||||
class AbstractValue:
|
||||
class AbstractValue(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
@ -235,7 +279,7 @@ def AbstractValueEnd(builder):
|
||||
|
||||
|
||||
|
||||
class Sharding:
|
||||
class Sharding(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
@ -304,7 +348,7 @@ def ShardingEnd(builder):
|
||||
|
||||
|
||||
|
||||
class Effect:
|
||||
class Effect(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
@ -340,7 +384,7 @@ def EffectEnd(builder):
|
||||
|
||||
|
||||
|
||||
class DisabledSafetyCheck:
|
||||
class DisabledSafetyCheck(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
@ -386,7 +430,7 @@ def DisabledSafetyCheckEnd(builder):
|
||||
|
||||
|
||||
|
||||
class Exported:
|
||||
class Exported(object):
|
||||
__slots__ = ['_tab']
|
||||
|
||||
@classmethod
|
||||
|
@ -35,8 +35,8 @@ _deprecations = {
|
||||
"call": (_deprecation_message, _src_export.call),
|
||||
"call_exported": (_deprecation_message, _src_export.call_exported),
|
||||
"default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform),
|
||||
"minimum_supported_serialization_version" : (_deprecation_message, _src_export.minimum_supported_calling_convention_version),
|
||||
"maximum_supported_serialization_version" : (_deprecation_message, _src_export.maximum_supported_calling_convention_version),
|
||||
"minimum_supported_serialization_version": (_deprecation_message, _src_export.minimum_supported_calling_convention_version),
|
||||
"maximum_supported_serialization_version": (_deprecation_message, _src_export.maximum_supported_calling_convention_version),
|
||||
|
||||
"serialize": (_deprecation_message, _src_serialization.serialize),
|
||||
"deserialize": (_deprecation_message, _src_serialization.deserialize),
|
||||
|
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
__all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize",
|
||||
"register_pytree_node_serialization",
|
||||
"register_namedtuple_serialization",
|
||||
"maximum_supported_calling_convention_version",
|
||||
"minimum_supported_calling_convention_version",
|
||||
"default_export_platform",
|
||||
@ -23,6 +25,8 @@ from jax._src.export._export import (
|
||||
Exported,
|
||||
export,
|
||||
deserialize,
|
||||
register_pytree_node_serialization,
|
||||
register_namedtuple_serialization,
|
||||
maximum_supported_calling_convention_version,
|
||||
minimum_supported_calling_convention_version,
|
||||
default_export_platform)
|
||||
|
@ -13,11 +13,13 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import Callable, Sequence
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import unittest
|
||||
@ -32,6 +34,7 @@ from jax.experimental.shard_map import shard_map
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import Mesh
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax import tree_util
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
@ -311,6 +314,123 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f((a, b), a=a, b=b),
|
||||
exp_f.call((a, b), a=a, b=b))
|
||||
|
||||
def test_pytree_namedtuple(self):
|
||||
T = collections.namedtuple("SomeType", ("a", "b", "c"))
|
||||
export.register_namedtuple_serialization(
|
||||
T,
|
||||
serialized_name="test_pytree_namedtuple.SomeType",
|
||||
)
|
||||
x = T(a=1, b=2, c=3)
|
||||
|
||||
def f(x):
|
||||
return (x, x) # return 2 copies, to check that types are shared
|
||||
|
||||
exp = export.export(jax.jit(f))(x)
|
||||
res = exp.call(x)
|
||||
self.assertEqual(tree_util.tree_structure(res),
|
||||
tree_util.tree_structure((x, x)))
|
||||
self.assertEqual(type(res[0]), type(x))
|
||||
self.assertEqual(type(res[1]), type(x))
|
||||
ser = exp.serialize()
|
||||
exp2 = export.deserialize(ser)
|
||||
self.assertEqual(exp2.in_tree, exp.in_tree)
|
||||
self.assertEqual(exp2.out_tree, exp.out_tree)
|
||||
res2 = exp2.call(x)
|
||||
self.assertEqual(tree_util.tree_structure(res2),
|
||||
tree_util.tree_structure(res))
|
||||
|
||||
def test_pytree_namedtuple_error(self):
|
||||
T = collections.namedtuple("SomeType", ("a", "b"))
|
||||
x = T(a=1, b=2)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Cannot serialize .* unregistered type .*SomeType"):
|
||||
export.export(jax.jit(lambda x: x))(x).serialize()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"If `from_children` is not present.* must call.*register_pytree_node"
|
||||
):
|
||||
export.register_pytree_node_serialization(
|
||||
T,
|
||||
serialized_name="test_pytree_namedtuple.SomeType_V2",
|
||||
serialize_auxdata=lambda x: b"",
|
||||
deserialize_auxdata=lambda b: None
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Use .*register_pytree_node_serialization"):
|
||||
export.register_namedtuple_serialization(str, serialized_name="n/a")
|
||||
|
||||
export.register_namedtuple_serialization(
|
||||
T,
|
||||
serialized_name="test_pytree_namedtuple_error.SomeType",
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Duplicate serialization registration .*test_pytree_namedtuple_error.SomeType"
|
||||
):
|
||||
export.register_namedtuple_serialization(
|
||||
T,
|
||||
serialized_name="test_pytree_namedtuple_error.OtherType",
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Duplicate serialization registration for serialized_name.*test_pytree_namedtuple_error.SomeType"
|
||||
):
|
||||
export.register_namedtuple_serialization(
|
||||
collections.namedtuple("SomeOtherType", ("a", "b")),
|
||||
serialized_name="test_pytree_namedtuple_error.SomeType",
|
||||
)
|
||||
|
||||
def test_pytree_custom_types(self):
|
||||
x1 = collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)])
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
class CustomType:
|
||||
def __init__(self, a: int, b: CustomType | None, string: str):
|
||||
self.a = a
|
||||
self.b = b
|
||||
self.string = string
|
||||
|
||||
def tree_flatten(self):
|
||||
return ((self.a, self.b), self.string)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
string = aux_data
|
||||
return cls(*children, string)
|
||||
|
||||
export.register_pytree_node_serialization(
|
||||
CustomType,
|
||||
serialized_name="test_pytree_custom_types.CustomType",
|
||||
serialize_auxdata=lambda aux: aux.encode("utf-8"),
|
||||
deserialize_auxdata=lambda b: b.decode("utf-8")
|
||||
)
|
||||
x2 = CustomType(4, 5, "foo")
|
||||
|
||||
def f(x1, x2):
|
||||
return (x1, x2, x1, x2) # return 2 copies, to check that types are shared
|
||||
|
||||
exp = export.export(jax.jit(f))(x1, x2)
|
||||
res = exp.call(x1, x2)
|
||||
self.assertEqual(tree_util.tree_structure(res),
|
||||
tree_util.tree_structure(((x1, x2, x1, x2))))
|
||||
self.assertEqual(type(res[0]), type(x1))
|
||||
self.assertEqual(type(res[1]), type(x2))
|
||||
self.assertEqual(type(res[2]), type(x1))
|
||||
self.assertEqual(type(res[3]), type(x2))
|
||||
ser = exp.serialize()
|
||||
exp2 = export.deserialize(ser)
|
||||
self.assertEqual(exp2.in_tree, exp.in_tree)
|
||||
self.assertEqual(exp2.out_tree, exp.out_tree)
|
||||
res2 = exp2.call(x1, x2)
|
||||
self.assertEqual(tree_util.tree_structure(res2),
|
||||
tree_util.tree_structure(res))
|
||||
|
||||
|
||||
def test_error_wrong_intree(self):
|
||||
def f(a_b_pair, *, c):
|
||||
return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c
|
||||
|
Loading…
x
Reference in New Issue
Block a user