[export] Add support for serialization for some custom PyTree nodes

See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.

Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).

When writing this I have looked at how TensorFlow handles namedtuple. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. But this has the disadvantage that on
deserializaton a fresh distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function will return outputs of different types
than then function that was serialized. This can be confusing.

The Python pickle mode does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
importing automatically the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.

Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.

Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
This commit is contained in:
George Necula 2024-10-16 18:08:25 +01:00
parent bb271aaff8
commit 2feea414ac
7 changed files with 423 additions and 30 deletions

View File

@ -17,13 +17,15 @@
from __future__ import annotations
import collections
from collections.abc import Callable, Sequence
import copy
import dataclasses
import functools
import itertools
import json
import re
from typing import Any, Union, cast
from typing import Any, Protocol, TypeVar, Union, cast
import warnings
from absl import logging
@ -344,6 +346,189 @@ def deserialize(blob: bytearray) -> Exported:
return deserialize(blob)
T = TypeVar("T")
PyTreeAuxData = Any # alias for tree_util._AuxData
class _SerializeAuxData(Protocol):
def __call__(self, aux_data: PyTreeAuxData) -> bytes:
"""Serializes the PyTree node AuxData.
The AuxData is returned by the `flatten_func` registered by
`tree_util.register_pytree_node`).
"""
class _DeserializeAuxData(Protocol):
def __call__(self, serialized_aux_data: bytes) -> PyTreeAuxData:
"""Deserializes the PyTree node AuxData.
The result will be passed to `_BuildFromChildren`.
"""
class _BuildFromChildren(Protocol):
def __call__(self, aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any:
"""Materializes a T given a deserialized AuxData and children.
This is similar in scope with the `unflatten_func`.
"""
serialization_registry: dict[type, tuple[str, _SerializeAuxData]] = {}
deserialization_registry: dict[
str,
tuple[type, _DeserializeAuxData, _BuildFromChildren]] = {}
def _is_namedtuple(nodetype: type) -> bool:
return (issubclass(nodetype, tuple) and
hasattr(nodetype, "_fields") and
isinstance(nodetype._fields, Sequence) and
all(isinstance(f, str) for f in nodetype._fields))
def register_pytree_node_serialization(
nodetype: type[T],
*,
serialized_name: str,
serialize_auxdata: _SerializeAuxData,
deserialize_auxdata: _DeserializeAuxData,
from_children: _BuildFromChildren | None = None
) -> type[T]:
"""Registers a custom PyTree node for serialization and deserialization.
You must use this function before you can serialize and deserialize PyTree
nodes for the types not supported natively. We serialize PyTree nodes for
the `in_tree` and `out_tree` fields of `Exported`, which are part of the
exported function's calling convention.
This function must be called after calling
`jax.tree_util.register_pytree_node` (except for `collections.namedtuple`,
which do not require a call to `register_pytree_node`).
Args:
nodetype: the type whose PyTree nodes we want to serialize. It is an
error to attempt to register multiple serializations for a `nodetype`.
serialized_name: a string that will be present in the serialization and
will be used to look up the registration during deserialization. It is an
error to attempt to register multiple serializations for a
`serialized_name`.
serialize_auxdata: serialize the PyTree auxdata (returned by the
`flatten_func` argument to `jax.tree_util.register_pytree_node`.).
deserialize_auxdata: deserialize the auxdata that was serialized by the
`serialize_auxdata`.
from_children: if present, this is a function that takes that result of
`deserialize_auxdata` along with some children and creates an instance
of `nodetype`. This is similar to the `unflatten_func` passed to
`jax.tree_util.register_pytree_node`. If not present, we look up
and use the `unflatten_func`. This is needed for `collections.namedtuple`,
which does not have a `register_pytree_node`, but it can be useful to
override that function. Note that the result of `from_children` is
only used with `jax.tree_util.tree_structure` to construct a proper
PyTree node, it is not used to construct the outputs of the serialized
function.
Returns:
the same type passed as `nodetype`, so that this function can
be used as a class decorator.
"""
if nodetype in serialization_registry:
raise ValueError(
f"Duplicate serialization registration for type `{nodetype}`. "
"Previous registration was with serialized_name "
f"`{serialization_registry[nodetype][0]}`.")
if serialized_name in deserialization_registry:
raise ValueError(
"Duplicate serialization registration for "
f"serialized_name `{serialized_name}`. "
"Previous registration was for type "
f"`{deserialization_registry[serialized_name][0]}`.")
if from_children is None:
if nodetype not in tree_util._registry:
raise ValueError(
f"If `from_children` is not present, you must call first"
f"`jax.tree_util.register_pytree_node` for `{nodetype}`")
from_children = tree_util._registry[nodetype].from_iter
serialization_registry[nodetype] = (
serialized_name, serialize_auxdata)
deserialization_registry[serialized_name] = (
nodetype, deserialize_auxdata, from_children)
return nodetype
def register_namedtuple_serialization(
nodetype: type[T],
*,
serialized_name: str) -> type[T]:
"""Registers a namedtuple for serialization and deserialization.
JAX has native PyTree support for `collections.namedtuple`, and does not
require a call to `jax.tree_util.register_pytree_node`. However, if you
want to serialize functions that have inputs of outputs of a
namedtuple type, you must register that type for serialization.
Args:
nodetype: the type whose PyTree nodes we want to serialize. It is an
error to attempt to register multiple serializations for a `nodetype`.
On deserialization, this type must have the same set of keys that
were present during serialization.
serialized_name: a string that will be present in the serialization and
will be used to look up the registration during deserialization. It is an
error to attempt to register multiple serializations for
a `serialized_name`.
Returns:
the same type passed as `nodetype`, so that this function can
be used as a class decorator.
"""
if not _is_namedtuple(nodetype):
raise ValueError("Use `jax.export.register_pytree_node_serialization` for "
"types other than `collections.namedtuple`.")
def serialize_auxdata(aux_data: PyTreeAuxData) -> bytes:
# Store the serialized keys in the serialized auxdata
del aux_data
return json.dumps(nodetype._fields).encode("utf-8")
def deserialize_auxdata(serialized_aux_data: bytes) -> PyTreeAuxData:
return json.loads(serialized_aux_data.decode("utf-8"))
def from_children(aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any:
# Use our own "from_children" because namedtuples do not have a pytree
# registration.
ser_keys = cast(Sequence[str], aux_data)
assert len(ser_keys) == len(children)
return nodetype(** dict(zip(ser_keys, children)))
return register_pytree_node_serialization(
nodetype,
serialized_name=serialized_name,
serialize_auxdata=serialize_auxdata,
deserialize_auxdata=deserialize_auxdata,
from_children=from_children)
# collections.OrderedDict is registered as a pytree node with auxdata being
# `tuple(x.keys())`.
def _serialize_ordereddict_keys(keys):
if isinstance(keys, Sequence) and all(isinstance(k, str) for k in keys):
return json.dumps(keys).encode("utf-8")
else:
raise NotImplementedError(
"Serialization of collections.OrderedDict is supported only when the "
f"keys are strings. Found keys: {keys}.")
register_pytree_node_serialization(
collections.OrderedDict,
serialized_name="collections.OrderedDict",
serialize_auxdata=_serialize_ordereddict_keys,
deserialize_auxdata=lambda b: json.loads(b.decode("utf-8")))
def default_export_platform() -> str:
"""Retrieves the default export platform.
@ -404,9 +589,10 @@ def export_back_compat(
disabled_checks: the safety checks to disable. See docstring
of `DisabledSafetyCheck`.
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
or values with `.shape` and `.dtype` attributes, and returns an
`Exported`.
Returns:
a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
or values with `.shape` and `.dtype` attributes, and returns an
`Exported`.
Usage:
@ -480,9 +666,10 @@ def export(
disabled_checks: the safety checks to disable. See documentation for
of `jax.export.DisabledSafetyCheck`.
Returns: a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`,
or values with `.shape` and `.dtype` attributes, and returns an
`Exported`.
Returns:
a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`,
or values with `.shape` and `.dtype` attributes, and returns an
`Exported`.
Usage:

View File

@ -28,12 +28,15 @@ enum PyTreeDefKind: byte {
tuple = 2,
list = 3,
dict = 4,
custom = 5,
}
table PyTreeDef {
kind: PyTreeDefKind;
children: [PyTreeDef];
children_names: [string]; // only for "dict"
children_names: [string]; // only for "kind==dict"
custom_name: string; // only for "kind==custom"
custom_auxdata: [byte]; // only for "kind==custom"
}
enum AbstractValueKind: byte {

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import types
from collections.abc import Callable, Sequence
from functools import partial
from typing import TypeVar
@ -45,6 +46,8 @@ SerT = TypeVar("SerT")
# even if the change is backwards compatible.
# Version 1, Nov 2023, first version.
# Version 2, Dec 16th, 2023, adds the f0 dtype.
# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types
# This version is backwards compatible with Version 2.
_SERIALIZATION_VERSION = 2
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
@ -152,13 +155,13 @@ def _serialize_array(
def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
serialization_version = exp.SerializationVersion()
if serialization_version != _SERIALIZATION_VERSION:
if serialization_version not in [2, 3]:
raise NotImplementedError(
f"deserialize unsupported version {serialization_version}"
)
fun_name = exp.FunctionName().decode("utf-8")
_, in_tree = tree_util.tree_flatten(
in_tree = tree_util.tree_structure(
_deserialize_pytreedef_to_pytree(exp.InTree())
)
scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints
@ -166,7 +169,7 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
in_avals = _deserialize_tuple(
exp.InAvalsLength, exp.InAvals, deser_aval
)
_, out_tree = tree_util.tree_flatten(
out_tree = tree_util.tree_structure(
_deserialize_pytreedef_to_pytree(exp.OutTree())
)
out_avals = _deserialize_tuple(
@ -246,23 +249,40 @@ def _serialize_pytreedef(
children_vector_offset = _serialize_array(
builder, _serialize_pytreedef, children
)
custom_name = None
custom_auxdata = None
node_type = node_data and node_data[0]
if node_data is None: # leaf
kind = ser_flatbuf.PyTreeDefKind.leaf
elif node_data[0] is type(None):
elif node_type is types.NoneType:
kind = ser_flatbuf.PyTreeDefKind.none
elif node_data[0] is tuple:
elif node_type is tuple:
kind = ser_flatbuf.PyTreeDefKind.tuple
elif node_data[0] is list:
elif node_type is list:
kind = ser_flatbuf.PyTreeDefKind.list
elif node_data[0] is dict:
elif node_type is dict:
kind = ser_flatbuf.PyTreeDefKind.dict
assert len(node_data[1]) == len(children)
children_names_vector_offset = _serialize_array(
builder, lambda b, s: b.CreateString(s), node_data[1]
)
elif node_type in _export.serialization_registry:
kind = ser_flatbuf.PyTreeDefKind.custom
serialized_name, serialize_auxdata = _export.serialization_registry[node_type]
custom_name = builder.CreateString(serialized_name)
serialized_auxdata = serialize_auxdata(node_data[1])
if not isinstance(serialized_auxdata, (bytes, bytearray)):
raise ValueError(
"The custom serialization function for `node_type` must "
f"return a `bytes` object. It returned a {type(serialized_auxdata)}.")
custom_auxdata = builder.CreateByteVector(serialized_auxdata)
else:
raise NotImplementedError(f"serializing PyTreeDef {node_data}")
raise ValueError(
"Cannot serialize PyTreeDef containing an "
f"unregistered type `{node_type}`. "
"Use `export.register_pytree_node_serialization` or "
"`export.register_namedtuple_serialization`.")
ser_flatbuf.PyTreeDefStart(builder)
ser_flatbuf.PyTreeDefAddKind(builder, kind)
@ -270,6 +290,10 @@ def _serialize_pytreedef(
ser_flatbuf.PyTreeDefAddChildren(builder, children_vector_offset)
if children_names_vector_offset:
ser_flatbuf.PyTreeDefAddChildrenNames(builder, children_names_vector_offset)
if custom_name is not None:
ser_flatbuf.PyTreeDefAddCustomName(builder, custom_name)
if custom_auxdata is not None:
ser_flatbuf.PyTreeDefAddCustomAuxdata(builder, custom_auxdata)
return ser_flatbuf.PyTreeDefEnd(builder)
@ -294,6 +318,17 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
assert p.ChildrenNamesLength() == nr_children
keys = [p.ChildrenNames(i).decode("utf-8") for i in range(nr_children)]
return dict(zip(keys, children))
elif kind == ser_flatbuf.PyTreeDefKind.custom:
serialized_name = p.CustomName().decode("utf-8")
if serialized_name not in _export.deserialization_registry:
raise ValueError(
"Cannot deserialize a PyTreeDef containing an "
f"unregistered type `{serialized_name}`. "
"Use `export.register_pytree_node_serialization` or "
"`export.register_namedtuple_serialization`.")
nodetype, deserialize_auxdata, from_iter = _export.deserialization_registry[serialized_name]
auxdata = deserialize_auxdata(p.CustomAuxdataAsNumpy().tobytes())
return from_iter(auxdata, children)
else:
assert False, kind

View File

@ -21,20 +21,21 @@ import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()
class PyTreeDefKind:
class PyTreeDefKind(object):
leaf = 0
none = 1
tuple = 2
list = 3
dict = 4
custom = 5
class AbstractValueKind:
class AbstractValueKind(object):
shapedArray = 0
abstractToken = 1
class DType:
class DType(object):
bool = 0
i8 = 1
i16 = 2
@ -60,18 +61,18 @@ class DType:
f0 = 22
class ShardingKind:
class ShardingKind(object):
unspecified = 0
hlo_sharding = 1
class DisabledSafetyCheckKind:
class DisabledSafetyCheckKind(object):
platform = 0
custom_call = 1
shape_assertions = 2
class PyTreeDef:
class PyTreeDef(object):
__slots__ = ['_tab']
@classmethod
@ -140,8 +141,42 @@ class PyTreeDef:
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
return o == 0
# PyTreeDef
def CustomName(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
if o != 0:
return self._tab.String(o + self._tab.Pos)
return None
# PyTreeDef
def CustomAuxdata(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
if o != 0:
a = self._tab.Vector(o)
return self._tab.Get(flatbuffers.number_types.Int8Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 1))
return 0
# PyTreeDef
def CustomAuxdataAsNumpy(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
if o != 0:
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int8Flags, o)
return 0
# PyTreeDef
def CustomAuxdataLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
if o != 0:
return self._tab.VectorLen(o)
return 0
# PyTreeDef
def CustomAuxdataIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
return o == 0
def PyTreeDefStart(builder):
builder.StartObject(3)
builder.StartObject(5)
def PyTreeDefAddKind(builder, kind):
builder.PrependInt8Slot(0, kind, 0)
@ -158,12 +193,21 @@ def PyTreeDefAddChildrenNames(builder, childrenNames):
def PyTreeDefStartChildrenNamesVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
def PyTreeDefAddCustomName(builder, customName):
builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(customName), 0)
def PyTreeDefAddCustomAuxdata(builder, customAuxdata):
builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(customAuxdata), 0)
def PyTreeDefStartCustomAuxdataVector(builder, numElems):
return builder.StartVector(1, numElems, 1)
def PyTreeDefEnd(builder):
return builder.EndObject()
class AbstractValue:
class AbstractValue(object):
__slots__ = ['_tab']
@classmethod
@ -235,7 +279,7 @@ def AbstractValueEnd(builder):
class Sharding:
class Sharding(object):
__slots__ = ['_tab']
@classmethod
@ -304,7 +348,7 @@ def ShardingEnd(builder):
class Effect:
class Effect(object):
__slots__ = ['_tab']
@classmethod
@ -340,7 +384,7 @@ def EffectEnd(builder):
class DisabledSafetyCheck:
class DisabledSafetyCheck(object):
__slots__ = ['_tab']
@classmethod
@ -386,7 +430,7 @@ def DisabledSafetyCheckEnd(builder):
class Exported:
class Exported(object):
__slots__ = ['_tab']
@classmethod

View File

@ -35,8 +35,8 @@ _deprecations = {
"call": (_deprecation_message, _src_export.call),
"call_exported": (_deprecation_message, _src_export.call_exported),
"default_lowering_platform": (_deprecation_message, _src_export.default_lowering_platform),
"minimum_supported_serialization_version" : (_deprecation_message, _src_export.minimum_supported_calling_convention_version),
"maximum_supported_serialization_version" : (_deprecation_message, _src_export.maximum_supported_calling_convention_version),
"minimum_supported_serialization_version": (_deprecation_message, _src_export.minimum_supported_calling_convention_version),
"maximum_supported_serialization_version": (_deprecation_message, _src_export.maximum_supported_calling_convention_version),
"serialize": (_deprecation_message, _src_serialization.serialize),
"deserialize": (_deprecation_message, _src_serialization.deserialize),

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["DisabledSafetyCheck", "Exported", "export", "deserialize",
"register_pytree_node_serialization",
"register_namedtuple_serialization",
"maximum_supported_calling_convention_version",
"minimum_supported_calling_convention_version",
"default_export_platform",
@ -23,6 +25,8 @@ from jax._src.export._export import (
Exported,
export,
deserialize,
register_pytree_node_serialization,
register_namedtuple_serialization,
maximum_supported_calling_convention_version,
minimum_supported_calling_convention_version,
default_export_platform)

View File

@ -13,11 +13,13 @@
# limitations under the License.
from __future__ import annotations
import collections
from collections.abc import Callable, Sequence
import contextlib
import dataclasses
import functools
import logging
import json
import math
import re
import unittest
@ -32,6 +34,7 @@ from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax import tree_util
from jax._src import config
from jax._src import core
@ -311,6 +314,123 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertAllClose(f((a, b), a=a, b=b),
exp_f.call((a, b), a=a, b=b))
def test_pytree_namedtuple(self):
T = collections.namedtuple("SomeType", ("a", "b", "c"))
export.register_namedtuple_serialization(
T,
serialized_name="test_pytree_namedtuple.SomeType",
)
x = T(a=1, b=2, c=3)
def f(x):
return (x, x) # return 2 copies, to check that types are shared
exp = export.export(jax.jit(f))(x)
res = exp.call(x)
self.assertEqual(tree_util.tree_structure(res),
tree_util.tree_structure((x, x)))
self.assertEqual(type(res[0]), type(x))
self.assertEqual(type(res[1]), type(x))
ser = exp.serialize()
exp2 = export.deserialize(ser)
self.assertEqual(exp2.in_tree, exp.in_tree)
self.assertEqual(exp2.out_tree, exp.out_tree)
res2 = exp2.call(x)
self.assertEqual(tree_util.tree_structure(res2),
tree_util.tree_structure(res))
def test_pytree_namedtuple_error(self):
T = collections.namedtuple("SomeType", ("a", "b"))
x = T(a=1, b=2)
with self.assertRaisesRegex(
ValueError,
"Cannot serialize .* unregistered type .*SomeType"):
export.export(jax.jit(lambda x: x))(x).serialize()
with self.assertRaisesRegex(
ValueError,
"If `from_children` is not present.* must call.*register_pytree_node"
):
export.register_pytree_node_serialization(
T,
serialized_name="test_pytree_namedtuple.SomeType_V2",
serialize_auxdata=lambda x: b"",
deserialize_auxdata=lambda b: None
)
with self.assertRaisesRegex(ValueError,
"Use .*register_pytree_node_serialization"):
export.register_namedtuple_serialization(str, serialized_name="n/a")
export.register_namedtuple_serialization(
T,
serialized_name="test_pytree_namedtuple_error.SomeType",
)
with self.assertRaisesRegex(
ValueError,
"Duplicate serialization registration .*test_pytree_namedtuple_error.SomeType"
):
export.register_namedtuple_serialization(
T,
serialized_name="test_pytree_namedtuple_error.OtherType",
)
with self.assertRaisesRegex(
ValueError,
"Duplicate serialization registration for serialized_name.*test_pytree_namedtuple_error.SomeType"
):
export.register_namedtuple_serialization(
collections.namedtuple("SomeOtherType", ("a", "b")),
serialized_name="test_pytree_namedtuple_error.SomeType",
)
def test_pytree_custom_types(self):
x1 = collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)])
@tree_util.register_pytree_node_class
class CustomType:
def __init__(self, a: int, b: CustomType | None, string: str):
self.a = a
self.b = b
self.string = string
def tree_flatten(self):
return ((self.a, self.b), self.string)
@classmethod
def tree_unflatten(cls, aux_data, children):
string = aux_data
return cls(*children, string)
export.register_pytree_node_serialization(
CustomType,
serialized_name="test_pytree_custom_types.CustomType",
serialize_auxdata=lambda aux: aux.encode("utf-8"),
deserialize_auxdata=lambda b: b.decode("utf-8")
)
x2 = CustomType(4, 5, "foo")
def f(x1, x2):
return (x1, x2, x1, x2) # return 2 copies, to check that types are shared
exp = export.export(jax.jit(f))(x1, x2)
res = exp.call(x1, x2)
self.assertEqual(tree_util.tree_structure(res),
tree_util.tree_structure(((x1, x2, x1, x2))))
self.assertEqual(type(res[0]), type(x1))
self.assertEqual(type(res[1]), type(x2))
self.assertEqual(type(res[2]), type(x1))
self.assertEqual(type(res[3]), type(x2))
ser = exp.serialize()
exp2 = export.deserialize(ser)
self.assertEqual(exp2.in_tree, exp.in_tree)
self.assertEqual(exp2.out_tree, exp.out_tree)
res2 = exp2.call(x1, x2)
self.assertEqual(tree_util.tree_structure(res2),
tree_util.tree_structure(res))
def test_error_wrong_intree(self):
def f(a_b_pair, *, c):
return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c