Add jax.tree_util.all_leaves(iterable).

In Haiku (https://github.com/deepmind/dm-haiku) we have `FlatMapping` which is
an immutable Mapping subclass maintaining a flat internal representation. Our
goal is to allow very cheap flatten/unflatten since these objects are used to
represent parameters/state and are often passed in and out of JAX functions that
flatten their inputs (e.g. jit/pmap).

One challenge we have is that on unflatten we need a fast way of testing whether
the list of leaves provided are flat or not (since we want to cache both the
flat structure and the leaves). Consider the following case:

```python
d = FlatMapping.from_mapping({"a": 1})  # Caches the result of jax.tree_flatten.
l, t = jax.tree_flatten(d)              # Fine, leaves are flat.
l = list(map(lambda x: (x, x), l))      # leaves are no longer flat.
d2 = jax.tree_unflatten(t, l)           # Needs to recompute structure.
jax.tree_leaves(d2)                     # Should return [1, 1] not [(1, 1)]
```

Actual implementation here: d37b486e09/haiku/_src/data_structures.py (L204-L208)

This function allows an efficient way to do this using the JAX public API.
This commit is contained in:
Tom Hennigan 2020-03-28 13:14:40 +00:00 committed by George Necula
parent 025d8741d5
commit ca23be63fb
3 changed files with 86 additions and 22 deletions

View File

@ -87,6 +87,25 @@ def treedef_children(treedef):
def treedef_is_leaf(treedef):
return treedef.num_nodes == 1
def all_leaves(iterable):
"""Tests whether all elements in the given iterable are all leaves.
>>> tree = {"a": [1, 2, 3]}
>>> assert all_leaves(jax.tree_leaves(tree))
>>> assert not all_leaves([tree])
This function is useful in advanced cases, for example if a library allows
arbitrary map operations on a flat list of leaves it may want to check if
the result is still a flat list of leaves.
Args:
iterable: Iterable of leaves.
Returns:
True if all elements in the input are leaves false if not.
"""
return pytree.all_leaves(iterable)
def register_pytree_node(nodetype, flatten_func, unflatten_func):
"""Extends the set of types that are considered internal nodes in pytrees.

View File

@ -111,6 +111,9 @@ class PyTreeDef {
// Flattens a Pytree into a list of leaves and a PyTreeDef.
static std::pair<py::list, std::unique_ptr<PyTreeDef>> Flatten(py::handle x);
// Tests whether the given list is a flat list of leaves.
static bool AllLeaves(const py::iterable& x);
// Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of
// the tree-structure of 'x'. For example, if we flatten a value
// [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the
@ -204,6 +207,10 @@ class PyTreeDef {
py::handle xs,
std::vector<PyTreeDef::Node>::const_reverse_iterator* it) const;
// Computes the node kind of a given Python object.
static Kind GetKind(const py::handle& obj,
CustomNodeRegistry::Registration const** custom);
// Nodes, in a post-order traversal. We use an ordered traversal to minimize
// allocations, and post-order corresponds to the order we need to rebuild the
// tree structure.
@ -243,28 +250,47 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const {
return true;
}
/*static*/ PyTreeDef::Kind PyTreeDef::GetKind(
const py::handle& obj,
CustomNodeRegistry::Registration const** custom) {
const PyObject* ptr = obj.ptr();
if (PyTuple_CheckExact(ptr)) return Kind::kTuple;
if (PyList_CheckExact(ptr)) return Kind::kList;
if (PyDict_CheckExact(ptr)) return Kind::kDict;
if ((*custom = CustomNodeRegistry::Lookup(obj.get_type()))) {
return Kind::kCustom;
} else if (py::isinstance<py::none>(obj)) {
return Kind::kNone;
} else if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
// We can only identify namedtuples heuristically, here by the presence of
// a _fields attribute.
return Kind::kNamedTuple;
} else {
return Kind::kLeaf;
}
}
void PyTreeDef::FlattenHelper(py::handle handle, py::list* leaves,
PyTreeDef* tree) {
Node node;
int start_num_nodes = tree->traversal_.size();
int start_num_leaves = leaves->size();
if (py::isinstance<py::none>(handle)) {
node.kind = Kind::kNone;
} else if (PyTuple_CheckExact(handle.ptr())) {
node.kind = GetKind(handle, &node.custom);
if (node.kind == Kind::kNone) {
// Nothing to do.
} else if (node.kind == Kind::kTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.kind = Kind::kTuple;
node.arity = tuple.size();
for (py::handle entry : tuple) {
FlattenHelper(entry, leaves, tree);
}
} else if (PyList_CheckExact(handle.ptr())) {
} else if (node.kind == Kind::kList) {
py::list list = py::reinterpret_borrow<py::list>(handle);
node.kind = Kind::kList;
node.arity = list.size();
for (py::handle entry : list) {
FlattenHelper(entry, leaves, tree);
}
} else if (PyDict_CheckExact(handle.ptr())) {
} else if (node.kind == Kind::kDict) {
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
if (PyList_Sort(keys.ptr())) {
@ -273,11 +299,9 @@ void PyTreeDef::FlattenHelper(py::handle handle, py::list* leaves,
for (py::handle key : keys) {
FlattenHelper(dict[key], leaves, tree);
}
node.kind = Kind::kDict;
node.arity = dict.size();
node.node_data = std::move(keys);
} else if ((node.custom = CustomNodeRegistry::Lookup(handle.get_type()))) {
node.kind = Kind::kCustom;
} else if (node.kind == Kind::kCustom) {
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
if (out.size() != 2) {
throw std::runtime_error(
@ -289,19 +313,15 @@ void PyTreeDef::FlattenHelper(py::handle handle, py::list* leaves,
++node.arity;
FlattenHelper(entry, leaves, tree);
}
} else if (py::isinstance<py::tuple>(handle) &&
py::hasattr(handle, "_fields")) {
// We can only identify namedtuples heuristically, here by the presence of
// a _fields attribute.
} else if (node.kind == Kind::kNamedTuple) {
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
node.kind = Kind::kNamedTuple;
node.arity = tuple.size();
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
for (py::handle entry : tuple) {
FlattenHelper(entry, leaves, tree);
}
} else {
node.kind = Kind::kLeaf;
CHECK(node.kind == Kind::kLeaf);
leaves->append(py::reinterpret_borrow<py::object>(handle));
}
node.num_nodes = tree->traversal_.size() - start_num_nodes + 1;
@ -317,6 +337,14 @@ void PyTreeDef::FlattenHelper(py::handle handle, py::list* leaves,
return std::make_pair(std::move(leaves), std::move(tree));
}
/*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) {
const CustomNodeRegistry::Registration* custom;
for (const py::handle& h : x) {
if (GetKind(h, &custom) != Kind::kLeaf) return false;
}
return true;
}
py::object PyTreeDef::Unflatten(py::iterable leaves) const {
std::vector<py::object> agenda;
auto it = leaves.begin();
@ -749,6 +777,7 @@ std::string PyTreeDef::ToString() const {
PYBIND11_MODULE(pytree, m) {
m.def("flatten", &PyTreeDef::Flatten);
m.def("tuple", &PyTreeDef::Tuple);
m.def("all_leaves", &PyTreeDef::AllLeaves);
py::class_<PyTreeDef>(m, "PyTreeDef")
.def("unflatten", &PyTreeDef::Unflatten)

View File

@ -69,8 +69,8 @@ class Special:
def __eq__(self, other):
return type(self) is type(other) and (self.x, self.y) == (other.x, other.y)
PYTREES = [
("foo",),
TREES = (
(None,),
((),),
(([()]),),
((1, 2),),
@ -84,18 +84,25 @@ PYTREES = [
(collections.defaultdict(dict,
[("foo", 34), ("baz", 101), ("something", -42)]),),
(ANamedTupleSubclass(foo="hello", bar=3.5),),
]
)
LEAVES = (
("foo",),
(0.1,),
(1,),
(object(),),
)
class TreeTest(jtu.JaxTestCase):
@parameterized.parameters(*PYTREES)
@parameterized.parameters(*(TREES + LEAVES))
def testRoundtrip(self, inputs):
xs, tree = tree_util.tree_flatten(inputs)
actual = tree_util.tree_unflatten(tree, xs)
self.assertEqual(actual, inputs)
@parameterized.parameters(*PYTREES)
@parameterized.parameters(*(TREES + LEAVES))
def testRoundtripWithFlattenUpTo(self, inputs):
_, tree = tree_util.tree_flatten(inputs)
if not hasattr(tree, "flatten_up_to"):
@ -119,7 +126,7 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(actual.args, inputs.args)
self.assertEqual(actual.keywords, inputs.keywords)
@parameterized.parameters(*PYTREES)
@parameterized.parameters(*(TREES + LEAVES))
def testRoundtripViaBuild(self, inputs):
xs, tree = tree_util._process_pytree(tuple, inputs)
actual = tree_util.build_tree(tree, xs)
@ -149,6 +156,15 @@ class TreeTest(jtu.JaxTestCase):
self.assertEqual(out, (((1, [3]), (2, None)),
((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))
@parameterized.parameters(*TREES)
def testAllLeavesWithTrees(self, tree):
leaves = tree_util.tree_leaves(tree)
self.assertTrue(tree_util.all_leaves(leaves))
self.assertFalse(tree_util.all_leaves([tree]))
@parameterized.parameters(*LEAVES)
def testAllLeavesWithLeaves(self, leaf):
self.assertTrue(tree_util.all_leaves([leaf]))
if __name__ == "__main__":
absltest.main()