2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2019 The JAX Authors.
|
2019-07-09 11:38:23 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2019-07-29 10:57:27 -04:00
|
|
|
import collections
|
2021-10-05 15:25:28 -04:00
|
|
|
import functools
|
2022-08-04 07:13:01 -07:00
|
|
|
import pickle
|
2021-04-20 14:10:30 -07:00
|
|
|
import re
|
2019-07-29 10:57:27 -04:00
|
|
|
|
2019-07-09 11:38:23 -07:00
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
2022-06-14 15:14:44 -07:00
|
|
|
import jax
|
2019-07-29 10:57:27 -04:00
|
|
|
from jax import tree_util
|
2021-03-19 10:09:14 -07:00
|
|
|
from jax import flatten_util
|
2022-01-10 14:35:02 -08:00
|
|
|
from jax._src import test_util as jtu
|
2023-03-04 00:48:29 +00:00
|
|
|
from jax._src.tree_util import prefix_errors, flatten_one_level
|
2021-03-19 10:09:14 -07:00
|
|
|
import jax.numpy as jnp
|
2021-11-18 14:55:19 -05:00
|
|
|
|
2019-07-09 11:38:23 -07:00
|
|
|
|
|
|
|
def _dummy_func(*args, **kwargs):
|
|
|
|
return
|
|
|
|
|
|
|
|
|
2019-07-29 10:57:27 -04:00
|
|
|
# Plain namedtuple; tree_util handles namedtuples natively as pytree nodes.
ATuple = collections.namedtuple("ATuple", "foo bar")
|
|
|
|
|
2019-10-29 10:19:41 -04:00
|
|
|
class ANamedTupleSubclass(ATuple):
  """Subclass of a namedtuple, used to check that subclasses still flatten."""
|
2019-07-29 10:57:27 -04:00
|
|
|
|
2023-03-17 19:08:53 -07:00
|
|
|
# A namedtuple explicitly registered as a custom pytree node: `foo` is the only
# child, while `bar` travels as auxiliary data inside the treedef.
ATuple2 = collections.namedtuple("ATuple2", ("foo", "bar"))
tree_util.register_pytree_node(
    ATuple2,
    lambda node: ((node.foo,), node.bar),                 # -> (children, aux)
    lambda aux, children: ATuple2(children[0], aux))      # (aux, children) ->
|
|
|
|
|
2022-05-12 19:13:00 +01:00
|
|
|
class AnObject:
  """Simple three-attribute container used as a custom pytree node.

  Equality and hashing are defined in terms of the (x, y, z) attributes.
  """

  def __init__(self, x, y, z):
    self.x = x
    self.y = y
    self.z = z

  def __eq__(self, other):
    # Comparing against an unrelated type used to raise AttributeError on
    # `other.x`; return NotImplemented so Python falls back to its default
    # (identity-based) comparison instead.
    if not isinstance(other, AnObject):
      return NotImplemented
    return self.x == other.x and self.y == other.y and self.z == other.z

  def __hash__(self):
    # NOTE: raises TypeError if any attribute is unhashable (e.g. a list `z`).
    return hash((self.x, self.y, self.z))

  def __repr__(self):
    return f"AnObject({self.x},{self.y},{self.z})"
|
2019-07-29 10:57:27 -04:00
|
|
|
|
2023-03-18 14:32:19 -07:00
|
|
|
# Register AnObject (without key paths): `x` and `y` are children, `z` is
# auxiliary data stored in the treedef.
tree_util.register_pytree_node(
    AnObject,
    lambda obj: ((obj.x, obj.y), obj.z),
    lambda aux, children: AnObject(children[0], children[1], aux))
|
|
|
|
|
|
|
|
class AnObject2(AnObject):
  """AnObject subclass that is registered with key-path support."""
|
|
|
|
|
2023-03-04 00:48:29 +00:00
|
|
|
tree_util.register_pytree_with_keys(
    AnObject2,
    # flatten_with_keys: plain strings act as the keys for x and y; z is aux.
    lambda obj: ((("x", obj.x), ("y", obj.y)), obj.z),
    # unflatten (no key involved)
    lambda aux, children: AnObject2(children[0], children[1], aux),
)
|
2019-07-29 10:57:27 -04:00
|
|
|
|
2020-03-10 15:01:18 -07:00
|
|
|
@tree_util.register_pytree_node_class
class Special:
  """Two-field pytree node registered via `register_pytree_node_class`.

  Both fields are children; no auxiliary data is carried.
  """

  def __init__(self, x, y):
    self.x, self.y = x, y

  def __repr__(self):
    return f"Special(x={self.x}, y={self.y})"

  def tree_flatten(self):
    children = (self.x, self.y)
    return children, None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    del aux_data  # Always None: tree_flatten stores no auxiliary data.
    return cls(*children)

  def __eq__(self, other):
    # Exact type match required, so subclasses never compare equal to bases.
    if type(other) is not type(self):
      return False
    return (self.x, self.y) == (other.x, other.y)
|
|
|
|
|
2023-03-04 00:48:29 +00:00
|
|
|
|
|
|
|
@tree_util.register_pytree_with_keys_class
class SpecialWithKeys(Special):
  """`Special` variant that labels each child with a `GetAttrKey`."""

  def tree_flatten_with_keys(self):
    keyed_children = tuple(
        (tree_util.GetAttrKey(attr), getattr(self, attr)) for attr in ("x", "y"))
    return keyed_children, None
|
|
|
|
|
|
|
|
|
2020-07-30 18:31:17 +01:00
|
|
|
@tree_util.register_pytree_node_class
class FlatCache:
  """Pytree node that caches the flattened form of a structured value.

  Built either from a structured value (flattened eagerly in `__init__`) or
  directly from `leaves` + `treedef`, in which case the structured value is
  rebuilt lazily the first time `structured` is read.
  """

  def __init__(self, structured, *, leaves=None, treedef=None):
    if treedef is None:
      leaves, treedef = tree_util.tree_flatten(structured)
    self._structured = structured
    self.treedef = treedef
    self.leaves = leaves

  def __hash__(self):
    # NOTE: raises TypeError when the structured value is unhashable (e.g. a dict).
    return hash(self.structured)

  def __eq__(self, other):
    # Comparing against an unrelated type used to raise AttributeError on
    # `other.structured`; return NotImplemented so Python falls back to its
    # default (identity-based) comparison instead.
    if not isinstance(other, FlatCache):
      return NotImplemented
    return self.structured == other.structured

  def __repr__(self):
    return f"FlatCache({self.structured!r})"

  @property
  def structured(self):
    """The structured value, rebuilt from `leaves`/`treedef` on first access."""
    if self._structured is None:
      self._structured = tree_util.tree_unflatten(self.treedef, self.leaves)
    return self._structured

  def tree_flatten(self):
    # The cached treedef doubles as the node's auxiliary data.
    return self.leaves, self.treedef

  @classmethod
  def tree_unflatten(cls, meta, data):
    # `data` may hold arbitrary subtrees rather than plain leaves; in that case
    # re-flatten so `leaves`/`treedef` always describe a fully flattened tree.
    if not tree_util.all_leaves(data):
      data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data))
    return cls(None, leaves=data, treedef=meta)
|
|
|
|
|
2020-03-28 13:14:40 +00:00
|
|
|
# Example pytrees for the parameterized tests. Each entry is a 1-tuple of
# arguments for `parameterized.parameters`. The leading entries are zipped
# positionally with TREE_STRINGS, so their order must not change.
TREES = (
    (None,),
    ((None,),),
    ((),),
    (([()]),),  # i.e. a list holding an empty tuple
    ((1, 2),),
    (((1, "foo"), ["bar", (3, None, 7)]),),
    ([3],),
    ([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
    ([AnObject(3, None, [4, "foo"])],),   # custom node, no key paths
    ([AnObject2(3, None, [4, "foo"])],),  # custom node with key paths
    (Special(2, 3.),),                    # register_pytree_node_class
    ({"a": 1, "b": 2},),
    (collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
    (collections.defaultdict(dict,
                             [("foo", 34), ("baz", 101), ("something", -42)]),),
    (ANamedTupleSubclass(foo="hello", bar=3.5),),
    (FlatCache(None),),
    (FlatCache(1),),
    (FlatCache({"a": [1, 2]}),),
)
|
|
|
|
|
xla: improvement to string representation of PyTreeDef
The string representation of PyTreeDef was different to how the underlying
containers are represented in python. This sometimes made it harder to read
error messages. This commit modifies the representation of tuples, lists,
dicts, and None so that it matches the pythonic representation.
The representation of custom nodes and NamedTuples is left unchanged since
their structure is not easily accessible in C++. However, to avoid confusion
they are now labelled "CustomNode" instead of "PyTreeDef". The latter is now
only used to wrap the whole representation. See below for examples.
Tests that relied on a specific string representation of PyTreeDef in error
messages are modified to be agnostic to the representation. Instead, this
commit adds a separate test of the string representation in tree_util_test.
Examples:
```
OLD: PyTreeDef(dict[['a', 'b']], [*,*])
NEW: PyTreeDef({'a': *, 'b': *})
OLD: PyTreeDef(tuple, [PyTreeDef(tuple, [*,*]),PyTreeDef(list, [*,PyTreeDef(tuple, [*,PyTreeDef(None, []),*])])])
NEW: PyTreeDef(((*, *), [*, (*, None, *)]))
OLD: PyTreeDef(list, [PyTreeDef(<class '__main__.AnObject'>[[4, 'foo']], [*,PyTreeDef(None, [])])])
NEW: PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])
```
PiperOrigin-RevId: 369298267
2021-04-19 14:06:11 -07:00
|
|
|
|
|
|
|
# Expected `str(treedef)` for the leading entries of TREES, in the same order.
# These literals are regex-escaped before being used with assertRegex.
TREE_STRINGS = (
    "PyTreeDef(None)",
    "PyTreeDef((None,))",
    "PyTreeDef(())",
    "PyTreeDef([()])",
    "PyTreeDef((*, *))",
    "PyTreeDef(((*, *), [*, (*, None, *)]))",
    "PyTreeDef([*])",
    ("PyTreeDef([*, CustomNode(namedtuple[ATuple], [(*, "
     "CustomNode(namedtuple[ATuple], [*, None])), {'baz': *}])])"),
    # Custom nodes print their auxiliary data in brackets (here AnObject.z).
    "PyTreeDef([CustomNode(AnObject[[4, 'foo']], [*, None])])",
    "PyTreeDef([CustomNode(AnObject2[[4, 'foo']], [*, None])])",
    "PyTreeDef(CustomNode(Special[None], [*, *]))",
    "PyTreeDef({'a': *, 'b': *})",
)
|
|
|
|
|
2021-04-20 14:10:30 -07:00
|
|
|
# Turn each expected string into a regex pattern: escape all metacharacters,
# then let any "__main__" module prefix match whatever module name the test
# classes live under (e.g. "tree_util_test" when collected by pytest).
STRS = []
for tree_str in TREE_STRINGS:
  tree_str = re.escape(tree_str).replace("__main__", ".*")
  STRS.append(tree_str)
TREE_STRINGS = STRS
|
xla: improvement to string representation of PyTreeDef
The string representation of PyTreeDef was different to how the underlying
containers are represented in python. This sometimes made it harder to read
error messages. This commit modifies the representation of tuples, lists,
dicts, and None so that it matches the pythonic representation.
The representation of custom nodes and NamedTuples is left unchanged since
their structure is not easily accessible in C++. However, to avoid confusion
they are now labelled "CustomNode" instead of "PyTreeDef". The latter is now
only used to wrap the whole representation. See below for examples.
Tests that relied on a specific string representation of PyTreeDef in error
messages are modified to be agnostic to the representation. Instead, this
commit adds a separate test of the string representation in tree_util_test.
Examples:
```
OLD: PyTreeDef(dict[['a', 'b']], [*,*])
NEW: PyTreeDef({'a': *, 'b': *})
OLD: PyTreeDef(tuple, [PyTreeDef(tuple, [*,*]),PyTreeDef(list, [*,PyTreeDef(tuple, [*,PyTreeDef(None, []),*])])])
NEW: PyTreeDef(((*, *), [*, (*, None, *)]))
OLD: PyTreeDef(list, [PyTreeDef(<class '__main__.AnObject'>[[4, 'foo']], [*,PyTreeDef(None, [])])])
NEW: PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])
```
PiperOrigin-RevId: 369298267
2021-04-19 14:06:11 -07:00
|
|
|
|
2020-03-28 13:14:40 +00:00
|
|
|
# Values that tree_util treats as leaves, each wrapped as a 1-tuple of
# arguments for `parameterized.parameters`.
LEAVES = tuple((leaf,) for leaf in ("foo", 0.1, 1, object()))
|
2019-07-29 10:57:27 -04:00
|
|
|
|
2023-03-04 00:48:29 +00:00
|
|
|
# Counterpart of TREES for the key-path tests. It omits the entries whose node
# types are registered without key-path support (AnObject via
# register_pytree_node; Special and FlatCache via register_pytree_node_class)
# and uses the key-aware AnObject2 / SpecialWithKeys instead. A few leaf values
# also differ from TREES (e.g. 0 instead of 2).
TREES_WITH_KEYPATH = (
    (None,),
    ((None,),),
    ((),),
    (([()]),),
    ((1, 0),),
    (((1, "foo"), ["bar", (3, None, 7)]),),
    ([3],),
    ([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
    ([AnObject2(3, None, [4, "foo"])],),
    (SpecialWithKeys(2, 3.),),
    ({"a": 1, "b": 0},),
    (collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
    (collections.defaultdict(dict,
                             [("foo", 34), ("baz", 101), ("something", -42)]),),
    (ANamedTupleSubclass(foo="hello", bar=3.5),),
)
|
|
|
|
|
2019-07-29 10:57:27 -04:00
|
|
|
|
2019-07-09 11:38:23 -07:00
|
|
|
class TreeTest(jtu.JaxTestCase):
|
|
|
|
|
2020-03-28 13:14:40 +00:00
|
|
|
@parameterized.parameters(*(TREES + LEAVES))
def testRoundtrip(self, inputs):
  """Flattening then unflattening reproduces the original value."""
  leaves, treedef = tree_util.tree_flatten(inputs)
  self.assertEqual(tree_util.tree_unflatten(treedef, leaves), inputs)
|
|
|
|
|
2020-03-28 13:14:40 +00:00
|
|
|
@parameterized.parameters(*(TREES + LEAVES))
def testRoundtripWithFlattenUpTo(self, inputs):
  """`flatten_up_to` against a tree's own treedef yields leaves that rebuild it."""
  _, treedef = tree_util.tree_flatten(inputs)
  leaves = treedef.flatten_up_to(inputs)
  rebuilt = tree_util.tree_unflatten(treedef, leaves)
  self.assertEqual(rebuilt, inputs)
|
|
|
|
|
2019-07-09 11:38:23 -07:00
|
|
|
@parameterized.parameters(
    (tree_util.Partial(_dummy_func),),
    (tree_util.Partial(_dummy_func, 1, 2),),
    (tree_util.Partial(_dummy_func, x="a"),),
    (tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5),),
)
def testRoundtripPartial(self, inputs):
  """A Partial survives a flatten/unflatten round trip.

  functools.partial does not support equality comparisons
  (https://stackoverflow.com/a/32786109/809705), so the `func`, `args` and
  `keywords` attributes are compared one by one.
  """
  leaves, treedef = tree_util.tree_flatten(inputs)
  restored = tree_util.tree_unflatten(treedef, leaves)
  for attr in ("func", "args", "keywords"):
    self.assertEqual(getattr(restored, attr), getattr(inputs, attr))
|
|
|
|
|
2021-10-05 15:25:28 -04:00
|
|
|
def testPartialDoesNotMergeWithOtherPartials(self):
  """Wrapping a functools.partial keeps only our own bound args in `.args`."""
  def target(a, b, c):
    del a, b, c

  inner = functools.partial(target, 2)
  outer = tree_util.Partial(inner, 3)
  self.assertEqual(outer.args, (3,))
|
|
|
|
|
2022-02-03 10:23:29 -05:00
|
|
|
def testPartialFuncAttributeHasStableHash(self):
  """`Partial.func` compares and hashes consistently with the wrapped callable."""
  # https://github.com/google/jax/issues/9429
  wrapped = functools.partial(print, 1)
  first = tree_util.Partial(wrapped, 2)
  second = tree_util.Partial(wrapped, 2)
  # Equality is checked in both directions against the raw partial.
  self.assertEqual(wrapped, first.func)
  self.assertEqual(first.func, wrapped)
  self.assertEqual(first.func, second.func)
  self.assertEqual(hash(first.func), hash(second.func))
|
|
|
|
|
2019-07-29 10:57:27 -04:00
|
|
|
def testChildren(self):
  """`children()` returns the treedefs of the root's immediate subtrees."""
  treedef = tree_util.tree_structure(((1, 2, 3), (4,)))
  expected = [tree_util.tree_structure((0, 0, 0)),
              tree_util.tree_structure((7,))]
  self.assertEqual(expected, treedef.children())
|
|
|
|
|
2021-08-03 09:53:53 -07:00
|
|
|
def testTreedefTupleFromChildren(self):
  """Rebuilding a tuple treedef from its children preserves leaf/node counts."""
  # https://github.com/google/jax/issues/7377
  tree = ((1, 2, (3, 4)), (5,))
  leaves, original = tree_util.tree_flatten(tree)
  rebuilt = tree_util.treedef_tuple(original.children())
  self.assertEqual(original.num_leaves, len(leaves))
  for attr in ("num_leaves", "num_nodes"):
    self.assertEqual(getattr(original, attr), getattr(rebuilt, attr))
|
|
|
|
|
2022-01-10 14:35:02 -08:00
|
|
|
def testTreedefTupleComparesEqual(self):
  """treedef_tuple of a single leaf equals the structure of a 1-tuple."""
  # https://github.com/google/jax/issues/9066
  one_tuple = tree_util.tree_structure((3,))
  leaf_def = tree_util.tree_structure(3)
  self.assertEqual(one_tuple, tree_util.treedef_tuple((leaf_def,)))
|
|
|
|
|
2022-07-28 19:22:55 -07:00
|
|
|
def testFlattenOrder(self):
  """Leaves come out in the same order regardless of container nesting.

  Every tree below holds the integers 0..9 in a different container shape;
  flattening must always recover them in ascending order. The second case
  (previously an exact duplicate of the first) now includes a dict whose
  insertion order differs from its sorted key order, since dict children are
  flattened by sorted key.
  """
  trees = [
      [0, ((1, 2), 3, (4, (5, 6, 7))), 8, 9],
      [0, ((1, 2), 3, {"b": (5, 6, 7), "a": 4}), 8, 9],
      [0, ((1, (2, 3)), (4, (5, 6, 7))), 8, 9],
  ]
  for tree in trees:
    leaves, _ = tree_util.tree_flatten(tree)
    self.assertEqual(leaves, list(range(10)))
|
|
|
|
|
2019-08-01 12:15:08 -04:00
|
|
|
def testFlattenUpTo(self):
  """flatten_up_to stops at the leaves of the reference structure."""
  treedef = tree_util.tree_structure([(1, 2), None, ATuple(foo=3, bar=7)])
  deeper = [({"foo": 7}, (3, 4)), None, ATuple(foo=(11, 9), bar=None)]
  self.assertEqual(treedef.flatten_up_to(deeper),
                   [{"foo": 7}, (3, 4), (11, 9), None])
|
2019-08-01 12:15:08 -04:00
|
|
|
|
2022-04-01 14:51:54 -07:00
|
|
|
def testTreeMap(self):
  """tree_map pairs each leaf of `x` with the matching subtree of `y`."""
  x = ((1, 2), (3, 4, 5))
  y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
  out = tree_util.tree_map(lambda *leaves: leaves, x, y)
  expected = (((1, [3]), (2, None)),
              ((3, {"foo": "bar"}), (4, 7), (5, [5, 6])))
  self.assertEqual(out, expected)
|
|
|
|
|
2022-04-01 14:51:54 -07:00
|
|
|
def testTreeMapWithIsLeafArgument(self):
  """Lists marked as leaves are passed whole to the mapped function."""
  x = ((1, 2), [3, 4, 5])
  y = (([3], None), ({"foo": "bar"}, 7, [5, 6]))
  out = tree_util.tree_map(lambda *leaves: leaves, x, y,
                           is_leaf=lambda node: isinstance(node, list))
  expected = (((1, [3]), (2, None)),
              ([3, 4, 5], ({"foo": "bar"}, 7, [5, 6])))
  self.assertEqual(out, expected)
|
|
|
|
|
2021-12-22 02:56:25 -08:00
|
|
|
@parameterized.parameters(
    tree_util.tree_leaves,
    lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[0])
def testFlattenIsLeaf(self, leaf_fn):
  """The `is_leaf` predicate decides where leaf extraction stops."""
  x = [(1, 2), (3, 4), (5, 6)]
  cases = [
      (lambda t: False, [1, 2, 3, 4, 5, 6]),
      (lambda t: isinstance(t, tuple), x),
      (lambda t: isinstance(t, list), [x]),
      (lambda t: True, [x]),
  ]
  for predicate, expected in cases:
    self.assertEqual(leaf_fn(x, is_leaf=predicate), expected)

  # Tuples nested inside lists and dicts are still picked out as leaves.
  y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
  self.assertEqual(leaf_fn(y, is_leaf=lambda t: isinstance(t, tuple)),
                   [(1,), (2,), (3,)])
|
|
|
|
|
2021-12-22 02:56:25 -08:00
|
|
|
@parameterized.parameters(
    tree_util.tree_structure,
    lambda tree, is_leaf: tree_util.tree_flatten(tree, is_leaf)[1])
def testStructureIsLeaf(self, structure_fn):
  """The `is_leaf` predicate is reflected in the treedef's leaf count."""
  x = [(1, 2), (3, 4), (5, 6)]
  cases = [
      (lambda t: False, 6),
      (lambda t: isinstance(t, tuple), 3),
      (lambda t: isinstance(t, list), 1),
      (lambda t: True, 1),
  ]
  for predicate, expected_leaves in cases:
    treedef = structure_fn(x, is_leaf=predicate)
    self.assertEqual(treedef.num_leaves, expected_leaves)

  # Tuples nested inside lists and dicts are still counted as leaves.
  y = [[[(1,)], [[(2,)], {"a": (3,)}]]]
  treedef = structure_fn(y, is_leaf=lambda t: isinstance(t, tuple))
  self.assertEqual(treedef.num_leaves, 3)
|
|
|
|
|
2020-12-09 12:38:39 +00:00
|
|
|
@parameterized.parameters(*TREES)
def testRoundtripIsLeaf(self, tree):
  """Round-tripping still works when every tuple is treated as a leaf."""
  def tuples_are_leaves(node):
    return isinstance(node, tuple)

  leaves, treedef = tree_util.tree_flatten(tree, is_leaf=tuples_are_leaves)
  self.assertEqual(tree_util.tree_unflatten(treedef, leaves), tree)
|
|
|
|
|
2020-03-28 13:14:40 +00:00
|
|
|
@parameterized.parameters(*TREES)
def testAllLeavesWithTrees(self, tree):
  """A tree's flattened leaves are all leaves; a list wrapping the tree is not."""
  self.assertTrue(tree_util.all_leaves(tree_util.tree_leaves(tree)))
  self.assertFalse(tree_util.all_leaves([tree]))
|
|
|
|
|
|
|
|
@parameterized.parameters(*LEAVES)
def testAllLeavesWithLeaves(self, leaf):
  """A singleton list holding a leaf consists only of leaves."""
  wrapped = [leaf]
  self.assertTrue(tree_util.all_leaves(wrapped))
|
2019-07-09 11:38:23 -07:00
|
|
|
|
2020-07-30 18:31:17 +01:00
|
|
|
@parameterized.parameters(*TREES)
def testCompose(self, tree):
  """`compose` substitutes the inner treedef for every outer leaf."""
  outer = tree_util.tree_structure(tree)
  inner = tree_util.tree_structure(["*", "*", "*"])
  composed = outer.compose(inner)

  n_leaves = outer.num_leaves * inner.num_leaves
  self.assertEqual(composed.num_leaves, n_leaves)
  # Internal outer nodes survive; each outer leaf becomes a full inner tree.
  n_internal = outer.num_nodes - outer.num_leaves
  self.assertEqual(composed.num_nodes,
                   n_internal + inner.num_nodes * outer.num_leaves)

  ones = [1] * n_leaves
  rebuilt = tree_util.tree_unflatten(composed, ones)
  self.assertEqual(ones, tree_util.tree_leaves(rebuilt))
|
|
|
|
|
|
|
|
@parameterized.parameters(*TREES)
def testTranspose(self, tree):
  """Transposing a tree of 3-lists yields a 3-list of copies of the tree."""
  outer_treedef = tree_util.tree_structure(tree)
  if outer_treedef.num_leaves == 0:
    self.skipTest("Skipping empty tree")
  inner_treedef = tree_util.tree_structure([1, 1, 1])
  triplicated = tree_util.tree_map(lambda leaf: [leaf] * 3, tree)
  transposed = tree_util.tree_transpose(outer_treedef, inner_treedef,
                                        triplicated)
  self.assertEqual(transposed, [tree] * 3)
|
|
|
|
|
|
|
|
def testTransposeMismatchOuter(self):
  """An outer treedef that disagrees with the input raises TypeError."""
  tree = {"a": [1, 2], "b": [3, 4]}
  wrong_outer = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
  inner = tree_util.tree_structure([1, 2])
  with self.assertRaisesRegex(TypeError, "Mismatch"):
    tree_util.tree_transpose(wrong_outer, inner, tree)
|
|
|
|
|
|
|
|
def testTransposeMismatchInner(self):
  """An inner treedef that disagrees with the input raises TypeError."""
  tree = {"a": [1, 2], "b": [3, 4]}
  outer = tree_util.tree_structure({"a": 1, "b": 2})
  wrong_inner = tree_util.tree_structure([1, 2, 3])
  with self.assertRaisesRegex(TypeError, "Mismatch"):
    tree_util.tree_transpose(outer, wrong_inner, tree)
|
|
|
|
|
|
|
|
def testTransposeWithCustomObject(self):
  """tree_transpose works when the outer node is a custom pytree class."""
  outer = tree_util.tree_structure(FlatCache({"a": 1, "b": 2}))
  inner = tree_util.tree_structure([1, 2])
  expected = [FlatCache({"a": 3, "b": 5}), FlatCache({"a": 4, "b": 6})]
  nested = FlatCache({"a": [3, 4], "b": [5, 6]})
  self.assertEqual(expected, tree_util.tree_transpose(outer, inner, nested))
|
|
|
|
|
xla: improvement to string representation of PyTreeDef
The string representation of PyTreeDef was different to how the underlying
containers are represented in python. This sometimes made it harder to read
error messages. This commit modifies the representation of tuples, lists,
dicts, and None so that it matches the pythonic representation.
The representation of custom nodes and NamedTuples is left unchanged since
their structure is not easily accessible in C++. However, to avoid confusion
they are now labelled "CustomNode" instead of "PyTreeDef". The latter is now
only used to wrap the whole representation. See below for examples.
Tests that relied on a specific string representation of PyTreeDef in error
messages are modified to be agnostic to the representation. Instead, this
commit adds a separate test of the string representation in tree_util_test.
Examples:
```
OLD: PyTreeDef(dict[['a', 'b']], [*,*])
NEW: PyTreeDef({'a': *, 'b': *})
OLD: PyTreeDef(tuple, [PyTreeDef(tuple, [*,*]),PyTreeDef(list, [*,PyTreeDef(tuple, [*,PyTreeDef(None, []),*])])])
NEW: PyTreeDef(((*, *), [*, (*, None, *)]))
OLD: PyTreeDef(list, [PyTreeDef(<class '__main__.AnObject'>[[4, 'foo']], [*,PyTreeDef(None, [])])])
NEW: PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])
```
PiperOrigin-RevId: 369298267
2021-04-19 14:06:11 -07:00
|
|
|
@parameterized.parameters(
    [(*case, pattern) for case, pattern in zip(TREES, TREE_STRINGS)])
def testStringRepresentation(self, tree, correct_string):
  """`str(treedef)` matches the expected (regex-escaped) representation."""
  self.assertRegex(str(tree_util.tree_structure(tree)), correct_string)
|
xla: improvement to string representation of PyTreeDef
The string representation of PyTreeDef was different to how the underlying
containers are represented in python. This sometimes made it harder to read
error messages. This commit modifies the representation of tuples, lists,
dicts, and None so that it matches the pythonic representation.
The representation of custom nodes and NamedTuples is left unchanged since
their structure is not easily accessible in C++. However, to avoid confusion
they are now labelled "CustomNode" instead of "PyTreeDef". The latter is now
only used to wrap the whole representation. See below for examples.
Tests that relied on a specific string representation of PyTreeDef in error
messages are modified to be agnostic to the representation. Instead, this
commit adds a separate test of the string representation in tree_util_test.
Examples:
```
OLD: PyTreeDef(dict[['a', 'b']], [*,*])
NEW: PyTreeDef({'a': *, 'b': *})
OLD: PyTreeDef(tuple, [PyTreeDef(tuple, [*,*]),PyTreeDef(list, [*,PyTreeDef(tuple, [*,PyTreeDef(None, []),*])])])
NEW: PyTreeDef(((*, *), [*, (*, None, *)]))
OLD: PyTreeDef(list, [PyTreeDef(<class '__main__.AnObject'>[[4, 'foo']], [*,PyTreeDef(None, [])])])
NEW: PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])
```
PiperOrigin-RevId: 369298267
2021-04-19 14:06:11 -07:00
|
|
|
|
2021-08-16 10:43:00 -07:00
|
|
|
def testTreeDefWithEmptyDictStringRepresentation(self):
  """An empty dict is printed as `{}` inside `PyTreeDef(...)`."""
  treedef = tree_util.tree_structure({})
  self.assertEqual(str(treedef), "PyTreeDef({})")
|
|
|
|
|
2022-08-04 07:13:01 -07:00
|
|
|
@parameterized.parameters(*TREES)
def testPickleRoundTrip(self, tree):
  """A PyTreeDef compares equal after a pickle dumps/loads round trip."""
  treedef = tree_util.tree_structure(tree)
  payload = pickle.dumps(treedef)
  self.assertEqual(treedef, pickle.loads(payload))
|
|
|
|
|
2023-01-18 13:39:58 -08:00
|
|
|
def testDictKeysSortable(self):
  """Flattening a dict whose keys cannot be ordered raises TypeError."""
  mixed_keys = {"a": 1, 2: "b"}
  with self.assertRaisesRegex(TypeError, "'<' not supported"):
    tree_util.tree_flatten(mixed_keys)
|
|
|
|
|
|
|
|
def testFlattenDictKeyOrder(self):
  """Dicts are flattened, printed and rebuilt in sorted-key order."""
  d = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
  leaves, treedef = tree_util.tree_flatten(d)
  self.assertEqual(leaves, [1, 2, 1, 2])
  self.assertEqual(str(treedef),
                   "PyTreeDef({'a': *, 'b': *, 'c': {'a': *, 'b': *}})")
  restored = tree_util.tree_unflatten(treedef, leaves)
  self.assertEqual(list(restored), ["a", "b", "c"])
|
|
|
|
|
|
|
|
def testWalk(self):
  """`walk` calls f_leaf per leaf and f_node with child results and node data."""
  d = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}}
  leaves, treedef = tree_util.tree_flatten(d)

  visited_nodes, visited_node_data, visited_leaves = [], [], []

  def on_node(node, node_data):
    visited_nodes.append(node)
    visited_node_data.append(node_data)

  # list.append returns None, just like a hand-written leaf callback would.
  treedef.walk(on_node, visited_leaves.append, leaves)

  self.assertEqual(visited_leaves, [1, 2, 1, 2])
  # Each node receives a tuple of its children's callback results (all None).
  self.assertEqual(visited_nodes, [(None, None), (None, None, None)])
  self.assertEqual(visited_node_data, [["a", "b"], ["a", "b", "c"]])
|
|
|
|
|
2023-03-04 00:48:29 +00:00
|
|
|
@parameterized.parameters(*(TREES_WITH_KEYPATH + LEAVES))
def testRoundtripWithPath(self, inputs):
  """Leaves returned by tree_flatten_with_path rebuild the original tree."""
  path_leaf_pairs, treedef = tree_util.tree_flatten_with_path(inputs)
  leaves = [leaf for _, leaf in path_leaf_pairs]
  self.assertEqual(tree_util.tree_unflatten(treedef, leaves), inputs)
|
|
|
|
|
|
|
|
def testTreeMapWithPath(self):
  """The key path handed to the mapped fn identifies each leaf's position."""
  tree = [{i: i for i in range(10)}]

  def zero_out(path, value):
    # path[0] carries the list index (.idx), path[1] the dict key (.key).
    return value - path[1].key + path[0].idx

  self.assertEqual(tree_util.tree_map_with_path(zero_out, tree),
                   [dict.fromkeys(range(10), 0)])
|
|
|
|
|
|
|
|
def testTreeMapWithPathMultipleTrees(self):
  """Mapping two trees with paths agrees with an equivalent one-tree map."""
  tree1 = [AnObject2(x=12, y={'cin': [1, 4, 10], 'bar': None},
                     z='constantdef'),
           5]
  tree2 = [AnObject2(x=2, y={'cin': [2, 2, 2], 'bar': None},
                     z='constantdef'),
           2]
  # Every leaf of tree2 is 2, so summing equals adding 2 to tree1's leaves.
  summed = tree_util.tree_map_with_path(lambda _, a, b: a + b, tree1, tree2)
  shifted = tree_util.tree_map(lambda a: a + 2, tree1)
  self.assertEqual(summed, shifted)
|
|
|
|
|
|
|
|
def testKeyStr(self):
  """keystr renders attribute, dict-key and sequence-index path entries."""
  tree1 = [ATuple(12, {'cin': [1, 4, 10], 'bar': None}), jnp.arange(5)]
  path_leaf_pairs, _ = tree_util.tree_flatten_with_path(tree1)
  rendered = [f"{tree_util.keystr(path)}: {leaf}"
              for path, leaf in path_leaf_pairs]
  expected = [
      "[0].foo: 12",
      "[0].bar['cin'][0]: 1",
      "[0].bar['cin'][1]: 4",
      "[0].bar['cin'][2]: 10",
      "[1]: [0 1 2 3 4]",
  ]
  self.assertEqual(rendered, expected)
|
|
|
|
|
|
|
|
def testTreeMapWithPathWithIsLeafArgument(self):
  """`is_leaf` is honoured when mapping with key paths."""
  x = ((1, 2), [3, 4, 5])
  y = (([3], jnp.array((0))), ([0], 7, [5, 6]))

  def tag_with_top_index(path, *leaves):
    return (path[0].idx, *leaves)

  out = tree_util.tree_map_with_path(
      tag_with_top_index, x, y, is_leaf=lambda n: isinstance(n, list))
  expected = (((0, 1, [3]), (0, 2, jnp.array((0)))),
              (1, [3, 4, 5], ([0], 7, [5, 6])))
  self.assertEqual(out, expected)
|
|
|
|
|
|
|
|
def testFlattenWithPathWithIsLeafArgument(self):
  """`is_leaf` can stop path flattening at empty containers."""
  def is_empty(node):
    try:
      children, _ = flatten_one_level(node)
    except ValueError:
      # flatten_one_level rejects leaves, so a leaf counts as "empty" here.
      return True
    return len(children) == 0

  EmptyTuple = collections.namedtuple("EmptyTuple", ())
  tree = {'a': 1,
          'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
          'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')}
  path_leaf_pairs, _ = tree_util.tree_flatten_with_path(tree, is_empty)
  rendered = [f"{tree_util.keystr(path)}: {leaf}"
              for path, leaf in path_leaf_pairs]
  expected = [
      "['a']: 1",
      "['obj']x: EmptyTuple()",
      "['obj']y: 0",
      "['sub'][0]: [1 2]",
      "['sub'][1].foo: ()",
      "['sub'][1].bar[0]: None",
  ]
  self.assertEqual(rendered, expected)
|
|
|
|
|
|
|
|
def testFlattenOneLevel(self):
  """flatten_one_level exposes only a node's immediate children."""
  EmptyTuple = collections.namedtuple("EmptyTuple", ())
  tree = {
      'a': 1,
      'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])],
      'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef'),
  }

  def children_of(node):
    return flatten_one_level(node)[0]

  self.assertEqual(children_of(tree["sub"]), tree["sub"])
  self.assertEqual(children_of(tree["sub"][1]), [(), [None]])
  self.assertEqual(children_of(tree["obj"]), [EmptyTuple(), 0])
  # Scalars and arrays are not pytree nodes and must be rejected.
  for not_a_node in (1, jnp.array((1, 2))):
    with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
      flatten_one_level(not_a_node)
|
|
|
|
|
2023-03-16 21:34:29 -07:00
|
|
|
def testOptionalFlatten(self):
|
|
|
|
@tree_util.register_pytree_with_keys_class
|
|
|
|
class FooClass:
|
|
|
|
def __init__(self, x, y):
|
|
|
|
self.x = x
|
|
|
|
self.y = y
|
|
|
|
def tree_flatten(self):
|
|
|
|
return ((self.x, self.y), 'treedef')
|
|
|
|
def tree_flatten_with_keys(self):
|
|
|
|
return (((tree_util.GetAttrKey('x'), self.x),
|
|
|
|
(tree_util.GetAttrKey('x'), self.y)), 'treedef')
|
|
|
|
@classmethod
|
|
|
|
def tree_unflatten(cls, _, children):
|
|
|
|
return cls(*children)
|
|
|
|
|
|
|
|
tree = FooClass(x=1, y=2)
|
|
|
|
self.assertEqual(
|
|
|
|
str(tree_util.tree_flatten(tree)[1]),
|
|
|
|
"PyTreeDef(CustomNode(FooClass[treedef], [*, *]))",
|
|
|
|
)
|
|
|
|
self.assertEqual(
|
|
|
|
str(tree_util.tree_flatten_with_path(tree)[1]),
|
|
|
|
"PyTreeDef(CustomNode(FooClass[treedef], [*, *]))",
|
|
|
|
)
|
|
|
|
self.assertEqual(tree_util.tree_flatten(tree)[0],
|
|
|
|
[l for _, l in tree_util.tree_flatten_with_path(tree)[0]])
|
|
|
|
|
2023-03-17 19:08:53 -07:00
|
|
|
def testPyTreeWithoutKeysIsntTreatedAsLeaf(self):
  """A node registered without key support is still flattened by path.

  Special([1, 2], [3, 4]) must produce four leaves rather than being kept
  whole as one leaf.
  """
  node = Special([1, 2], [3, 4])
  path_leaf_pairs = tree_util.tree_flatten_with_path(node)[0]
  self.assertLen(path_leaf_pairs, 4)
|
|
|
|
|
|
|
|
def testNamedTupleRegisteredWithoutKeysIsntTreatedAsLeaf(self):
  """A namedtuple with a custom keyless registration is flattened, not kept whole.

  ATuple2's registered flatten yields only `foo` as a child (`bar` goes to
  the auxiliary data), so exactly one leaf is expected.
  """
  node = ATuple2(1, 'hi')
  path_leaf_pairs = tree_util.tree_flatten_with_path(node)[0]
  self.assertLen(path_leaf_pairs, 1)
|
|
|
|
|
2021-03-19 10:09:14 -07:00
|
|
|
|
|
|
|
class RavelUtilTest(jtu.JaxTestCase):
  """Tests for flatten_util.ravel_pytree and the unravel function it returns."""

  def _assert_round_trip(self, tree, expected_dtype):
    """Ravel `tree`, check the flat array's dtype, then check exact recovery."""
    flat, unravel = flatten_util.ravel_pytree(tree)
    self.assertEqual(flat.dtype, expected_dtype)
    restored = unravel(flat)
    self.assertAllClose(tree, restored, atol=0., rtol=0.)

  def testFloats(self):
    leaves = [
        jnp.array([3.0], jnp.float32),
        jnp.array([[1.0, 2.0], [3.0, 4.0]], jnp.float32),
    ]
    self._assert_round_trip(leaves, jnp.float32)

  def testInts(self):
    leaves = [
        jnp.array([3], jnp.int32),
        jnp.array([[1, 2], [3, 4]], jnp.int32),
    ]
    self._assert_round_trip(leaves, jnp.int32)

  @jax.numpy_dtype_promotion('standard')  # Explicitly exercises implicit dtype promotion.
  def testMixedFloatInt(self):
    leaves = [
        jnp.array([3], jnp.int32),
        jnp.array([[1., 2.], [3., 4.]], jnp.float32),
    ]
    # The flat array carries the promoted dtype of the mixed leaves.
    self._assert_round_trip(leaves, jnp.promote_types(jnp.float32, jnp.int32))

  @jax.numpy_dtype_promotion('standard')  # Explicitly exercises implicit dtype promotion.
  def testMixedIntBool(self):
    leaves = [
        jnp.array([0], jnp.bool_),
        jnp.array([[1, 2], [3, 4]], jnp.int32),
    ]
    self._assert_round_trip(leaves, jnp.promote_types(jnp.bool_, jnp.int32))

  @jax.numpy_dtype_promotion('standard')  # Explicitly exercises implicit dtype promotion.
  def testMixedFloatComplex(self):
    leaves = [
        jnp.array([1.], jnp.float32),
        jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64),
    ]
    self._assert_round_trip(
        leaves, jnp.promote_types(jnp.float32, jnp.complex64))

  def testEmpty(self):
    # An empty tree ravels to float32 by convention.
    self._assert_round_trip([], jnp.float32)

  def testDtypePolymorphicUnravel(self):
    # https://github.com/google/jax/issues/7809
    values = jnp.arange(10, dtype=jnp.float32)
    flat, unravel = flatten_util.ravel_pytree(values)
    mask = flat < 5.3
    # With a single leaf dtype, unravel keeps the dtype of what it is given.
    self.assertEqual(unravel(mask).dtype, mask.dtype)

  @jax.numpy_dtype_promotion('standard')  # Explicitly exercises implicit dtype promotion.
  def testDtypeMonomorphicUnravel(self):
    # https://github.com/google/jax/issues/7809
    pair = (jnp.arange(10, dtype=jnp.float32),
            jnp.arange(10, dtype=jnp.int32))
    flat, unravel = flatten_util.ravel_pytree(pair)
    mask = flat < 5.3
    # With mixed leaf dtypes, unravel rejects an input of a different dtype.
    with self.assertRaisesRegex(TypeError, 'but expected dtype'):
      unravel(mask)

  def test_no_recompile(self):
    a = jnp.array([1, 2])
    b = jnp.array([3, 4])
    flat1, unravel1 = flatten_util.ravel_pytree((a, b))
    flat2, unravel2 = flatten_util.ravel_pytree((a, b))
    trace_count = [0]

    def add_one_and_unravel(flat, unravel):
      trace_count[0] += 1
      return unravel(flat + 1)

    # unravel is a static argument, so equal unravel functions from two
    # separate ravel_pytree calls must hit the same compilation cache entry.
    jitted = jax.jit(add_one_and_unravel, static_argnums=1)
    jitted(flat1, unravel1)
    jitted(flat2, unravel2)
    self.assertEqual(trace_count[0], 1)
|
|
|
|
|
2021-03-19 10:09:14 -07:00
|
|
|
|
2022-01-28 15:54:19 -08:00
|
|
|
class TreePrefixErrorsTest(jtu.JaxTestCase):
|
|
|
|
|
|
|
|
def test_different_types(self):
|
|
|
|
e, = prefix_errors((1, 2), [1, 2])
|
2022-02-08 12:45:38 -08:00
|
|
|
expected = ("pytree structure error: different types at key path\n"
|
2023-03-04 00:48:29 +00:00
|
|
|
" in_axes")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e('in_axes')
|
|
|
|
|
|
|
|
def test_different_types_nested(self):
|
|
|
|
e, = prefix_errors(((1,), (2,)), ([3], (4,)))
|
2022-02-08 12:45:38 -08:00
|
|
|
expected = ("pytree structure error: different types at key path\n"
|
|
|
|
r" in_axes\[0\]")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e('in_axes')
|
|
|
|
|
|
|
|
def test_different_types_multiple(self):
|
|
|
|
e1, e2 = prefix_errors(((1,), (2,)), ([3], [4]))
|
2022-02-08 12:45:38 -08:00
|
|
|
expected = ("pytree structure error: different types at key path\n"
|
|
|
|
r" in_axes\[0\]")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e1('in_axes')
|
2022-02-08 12:45:38 -08:00
|
|
|
expected = ("pytree structure error: different types at key path\n"
|
|
|
|
r" in_axes\[1\]")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e2('in_axes')
|
|
|
|
|
2023-01-20 10:51:02 -08:00
|
|
|
def test_different_num_children_tuple(self):
|
2022-01-28 15:54:19 -08:00
|
|
|
e, = prefix_errors((1,), (2, 3))
|
2023-01-20 10:51:02 -08:00
|
|
|
expected = ("pytree structure error: different lengths of tuple "
|
|
|
|
"at key path\n"
|
2023-03-04 00:48:29 +00:00
|
|
|
" in_axes")
|
2023-01-20 10:51:02 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e('in_axes')
|
|
|
|
|
|
|
|
def test_different_num_children_list(self):
|
|
|
|
e, = prefix_errors([1], [2, 3])
|
|
|
|
expected = ("pytree structure error: different lengths of list "
|
|
|
|
"at key path\n"
|
2023-03-04 00:48:29 +00:00
|
|
|
" in_axes")
|
2023-01-20 10:51:02 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e('in_axes')
|
|
|
|
|
|
|
|
|
|
|
|
def test_different_num_children_generic(self):
|
|
|
|
e, = prefix_errors({'hi': 1}, {'hi': 2, 'bye': 3})
|
2022-01-28 15:54:19 -08:00
|
|
|
expected = ("pytree structure error: different numbers of pytree children "
|
2022-02-08 12:45:38 -08:00
|
|
|
"at key path\n"
|
2023-03-04 00:48:29 +00:00
|
|
|
" in_axes")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e('in_axes')
|
|
|
|
|
|
|
|
def test_different_num_children_nested(self):
|
|
|
|
e, = prefix_errors([[1]], [[2, 3]])
|
2023-01-20 10:51:02 -08:00
|
|
|
expected = ("pytree structure error: different lengths of list "
|
2022-02-08 12:45:38 -08:00
|
|
|
"at key path\n"
|
|
|
|
r" in_axes\[0\]")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e('in_axes')
|
|
|
|
|
|
|
|
def test_different_num_children_multiple(self):
|
|
|
|
e1, e2 = prefix_errors([[1], [2]], [[3, 4], [5, 6]])
|
2023-01-20 10:51:02 -08:00
|
|
|
expected = ("pytree structure error: different lengths of list "
|
2022-02-08 12:45:38 -08:00
|
|
|
"at key path\n"
|
|
|
|
r" in_axes\[0\]")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e1('in_axes')
|
2023-01-20 10:51:02 -08:00
|
|
|
expected = ("pytree structure error: different lengths of list "
|
2022-02-08 12:45:38 -08:00
|
|
|
"at key path\n"
|
|
|
|
r" in_axes\[1\]")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e2('in_axes')
|
|
|
|
|
2022-10-10 17:47:18 -07:00
|
|
|
def test_different_num_children_print_key_diff(self):
|
|
|
|
e, = prefix_errors({'a': 1}, {'a': 2, 'b': 3})
|
|
|
|
expected = ("so the symmetric difference on key sets is\n"
|
|
|
|
" b")
|
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e('in_axes')
|
|
|
|
|
2022-01-28 15:54:19 -08:00
|
|
|
def test_different_metadata(self):
|
|
|
|
e, = prefix_errors({1: 2}, {3: 4})
|
|
|
|
expected = ("pytree structure error: different pytree metadata "
|
2022-02-08 12:45:38 -08:00
|
|
|
"at key path\n"
|
2023-03-04 00:48:29 +00:00
|
|
|
" in_axes")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e('in_axes')
|
|
|
|
|
|
|
|
def test_different_metadata_nested(self):
|
|
|
|
e, = prefix_errors([{1: 2}], [{3: 4}])
|
|
|
|
expected = ("pytree structure error: different pytree metadata "
|
2022-02-08 12:45:38 -08:00
|
|
|
"at key path\n"
|
|
|
|
r" in_axes\[0\]")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e('in_axes')
|
|
|
|
|
|
|
|
def test_different_metadata_multiple(self):
|
|
|
|
e1, e2 = prefix_errors([{1: 2}, {3: 4}], [{3: 4}, {5: 6}])
|
|
|
|
expected = ("pytree structure error: different pytree metadata "
|
2022-02-08 12:45:38 -08:00
|
|
|
"at key path\n"
|
|
|
|
r" in_axes\[0\]")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e1('in_axes')
|
|
|
|
expected = ("pytree structure error: different pytree metadata "
|
2022-02-08 12:45:38 -08:00
|
|
|
"at key path\n"
|
|
|
|
r" in_axes\[1\]")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e2('in_axes')
|
|
|
|
|
|
|
|
def test_fallback_keypath(self):
|
|
|
|
e, = prefix_errors(Special(1, [2]), Special(3, 4))
|
2022-02-08 12:45:38 -08:00
|
|
|
expected = ("pytree structure error: different types at key path\n"
|
|
|
|
r" in_axes\[<flat index 1>\]")
|
2022-01-28 15:54:19 -08:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e('in_axes')
|
|
|
|
|
|
|
|
def test_no_errors(self):
|
|
|
|
() = prefix_errors((1, 2), ((11, 12, 13), 2))
|
|
|
|
|
2022-07-18 14:14:57 -07:00
|
|
|
def test_different_structure_no_children(self):
|
|
|
|
e, = prefix_errors({}, {'a': []})
|
|
|
|
expected = ("pytree structure error: different numbers of pytree children "
|
|
|
|
"at key path\n"
|
2023-03-04 00:48:29 +00:00
|
|
|
" in_axes")
|
2022-07-18 14:14:57 -07:00
|
|
|
with self.assertRaisesRegex(ValueError, expected):
|
|
|
|
raise e('in_axes')
|
|
|
|
|
2022-01-28 15:54:19 -08:00
|
|
|
|
2019-07-09 11:38:23 -07:00
|
|
|
if __name__ == "__main__":
  # Run this module's tests through JAX's custom test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|