mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add jax.tree_util.all_leaves(iterable).
In Haiku (https://github.com/deepmind/dm-haiku) we have `FlatMapping` which is
an immutable Mapping subclass maintaining a flat internal representation. Our
goal is to allow very cheap flatten/unflatten since these objects are used to
represent parameters/state and are often passed in and out of JAX functions that
flatten their inputs (e.g. jit/pmap).
One challenge we have is that on unflatten we need a fast way of testing whether
the list of leaves provided is flat or not (since we want to cache both the
flat structure and the leaves). Consider the following case:
```python
d = FlatMapping.from_mapping({"a": 1}) # Caches the result of jax.tree_flatten.
l, t = jax.tree_flatten(d) # Fine, leaves are flat.
l = list(map(lambda x: (x, x), l)) # leaves are no longer flat.
d2 = jax.tree_unflatten(t, l) # Needs to recompute structure.
jax.tree_leaves(d2) # Should return [1, 1] not [(1, 1)]
```
Actual implementation here: d37b486e09/haiku/_src/data_structures.py (L204-L208)
This function provides an efficient way to perform this check using the JAX public API.
This commit is contained in:
parent
025d8741d5
commit
ca23be63fb
@ -87,6 +87,25 @@ def treedef_children(treedef):
|
||||
def treedef_is_leaf(treedef):
|
||||
"""Returns True if `treedef` consists of exactly one node."""
return treedef.num_nodes == 1
|
||||
|
||||
def all_leaves(iterable):
|
||||
"""Tests whether every element of the given iterable is a leaf.
|
||||
|
||||
>>> tree = {"a": [1, 2, 3]}
|
||||
>>> assert all_leaves(jax.tree_leaves(tree))
|
||||
>>> assert not all_leaves([tree])
|
||||
|
||||
This function is useful in advanced cases. For example, if a library allows
|
||||
arbitrary map operations on a flat list of leaves, it may want to check
|
||||
whether the result is still a flat list of leaves.
|
||||
|
||||
Args:
|
||||
iterable: Iterable of objects to test.
|
||||
|
||||
Returns:
|
||||
True if all elements in the input are leaves, False otherwise.
|
||||
"""
|
||||
return pytree.all_leaves(iterable)
|
||||
|
||||
def register_pytree_node(nodetype, flatten_func, unflatten_func):
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
|
@ -111,6 +111,9 @@ class PyTreeDef {
|
||||
// Flattens a Pytree into a list of leaves and a PyTreeDef.
|
||||
static std::pair<py::list, std::unique_ptr<PyTreeDef>> Flatten(py::handle x);
|
||||
|
||||
// Tests whether every element of the given iterable is a leaf, i.e. whether
// the iterable is a flat sequence of leaves.
|
||||
static bool AllLeaves(const py::iterable& x);
|
||||
|
||||
// Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of
|
||||
// the tree-structure of 'x'. For example, if we flatten a value
|
||||
// [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the
|
||||
@ -204,6 +207,10 @@ class PyTreeDef {
|
||||
py::handle xs,
|
||||
std::vector<PyTreeDef::Node>::const_reverse_iterator* it) const;
|
||||
|
||||
// Computes the node kind of a given Python object. For Kind::kCustom results,
// `*custom` receives the registration found for the object's type.
|
||||
static Kind GetKind(const py::handle& obj,
|
||||
CustomNodeRegistry::Registration const** custom);
|
||||
|
||||
// Nodes, in a post-order traversal. We use an ordered traversal to minimize
|
||||
// allocations, and post-order corresponds to the order we need to rebuild the
|
||||
// tree structure.
|
||||
@ -243,28 +250,47 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Classifies `obj` as a pytree node kind.
//
// Exact tuple, list and dict instances are recognized first, without
// consulting the custom-node registry; in those cases `*custom` is not
// written. Otherwise `*custom` receives the result of the registry lookup for
// the object's type, and the kind is kCustom when that lookup yields a
// registration. Remaining objects are classified as None, namedtuple (a tuple
// instance exposing a `_fields` attribute), or leaf, in that order.
/*static*/ PyTreeDef::Kind PyTreeDef::GetKind(
    const py::handle& obj,
    CustomNodeRegistry::Registration const** custom) {
  const PyObject* raw = obj.ptr();

  // Fast paths for the built-in container types (exact type match only).
  if (PyTuple_CheckExact(raw)) {
    return Kind::kTuple;
  }
  if (PyList_CheckExact(raw)) {
    return Kind::kList;
  }
  if (PyDict_CheckExact(raw)) {
    return Kind::kDict;
  }

  // User-registered node types take precedence over the remaining checks.
  *custom = CustomNodeRegistry::Lookup(obj.get_type());
  if (*custom != nullptr) {
    return Kind::kCustom;
  }

  if (py::isinstance<py::none>(obj)) {
    return Kind::kNone;
  }

  // We can only identify namedtuples heuristically, here by the presence of
  // a _fields attribute.
  const bool looks_like_namedtuple =
      py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields");
  if (looks_like_namedtuple) {
    return Kind::kNamedTuple;
  }

  return Kind::kLeaf;
}
|
||||
|
||||
void PyTreeDef::FlattenHelper(py::handle handle, py::list* leaves,
|
||||
PyTreeDef* tree) {
|
||||
Node node;
|
||||
int start_num_nodes = tree->traversal_.size();
|
||||
int start_num_leaves = leaves->size();
|
||||
if (py::isinstance<py::none>(handle)) {
|
||||
node.kind = Kind::kNone;
|
||||
} else if (PyTuple_CheckExact(handle.ptr())) {
|
||||
node.kind = GetKind(handle, &node.custom);
|
||||
if (node.kind == Kind::kNone) {
|
||||
// Nothing to do.
|
||||
} else if (node.kind == Kind::kTuple) {
|
||||
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
|
||||
node.kind = Kind::kTuple;
|
||||
node.arity = tuple.size();
|
||||
for (py::handle entry : tuple) {
|
||||
FlattenHelper(entry, leaves, tree);
|
||||
}
|
||||
} else if (PyList_CheckExact(handle.ptr())) {
|
||||
} else if (node.kind == Kind::kList) {
|
||||
py::list list = py::reinterpret_borrow<py::list>(handle);
|
||||
node.kind = Kind::kList;
|
||||
node.arity = list.size();
|
||||
for (py::handle entry : list) {
|
||||
FlattenHelper(entry, leaves, tree);
|
||||
}
|
||||
} else if (PyDict_CheckExact(handle.ptr())) {
|
||||
} else if (node.kind == Kind::kDict) {
|
||||
py::dict dict = py::reinterpret_borrow<py::dict>(handle);
|
||||
py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
|
||||
if (PyList_Sort(keys.ptr())) {
|
||||
@ -273,11 +299,9 @@ void PyTreeDef::FlattenHelper(py::handle handle, py::list* leaves,
|
||||
for (py::handle key : keys) {
|
||||
FlattenHelper(dict[key], leaves, tree);
|
||||
}
|
||||
node.kind = Kind::kDict;
|
||||
node.arity = dict.size();
|
||||
node.node_data = std::move(keys);
|
||||
} else if ((node.custom = CustomNodeRegistry::Lookup(handle.get_type()))) {
|
||||
node.kind = Kind::kCustom;
|
||||
} else if (node.kind == Kind::kCustom) {
|
||||
py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
|
||||
if (out.size() != 2) {
|
||||
throw std::runtime_error(
|
||||
@ -289,19 +313,15 @@ void PyTreeDef::FlattenHelper(py::handle handle, py::list* leaves,
|
||||
++node.arity;
|
||||
FlattenHelper(entry, leaves, tree);
|
||||
}
|
||||
} else if (py::isinstance<py::tuple>(handle) &&
|
||||
py::hasattr(handle, "_fields")) {
|
||||
// We can only identify namedtuples heuristically, here by the presence of
|
||||
// a _fields attribute.
|
||||
} else if (node.kind == Kind::kNamedTuple) {
|
||||
py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
|
||||
node.kind = Kind::kNamedTuple;
|
||||
node.arity = tuple.size();
|
||||
node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
|
||||
for (py::handle entry : tuple) {
|
||||
FlattenHelper(entry, leaves, tree);
|
||||
}
|
||||
} else {
|
||||
node.kind = Kind::kLeaf;
|
||||
CHECK(node.kind == Kind::kLeaf);
|
||||
leaves->append(py::reinterpret_borrow<py::object>(handle));
|
||||
}
|
||||
node.num_nodes = tree->traversal_.size() - start_num_nodes + 1;
|
||||
@ -317,6 +337,14 @@ void PyTreeDef::FlattenHelper(py::handle handle, py::list* leaves,
|
||||
return std::make_pair(std::move(leaves), std::move(tree));
|
||||
}
|
||||
|
||||
// Returns true iff GetKind classifies every element of `x` as Kind::kLeaf.
// Elements are inspected individually and iteration stops at the first
// element that is not a leaf.
/*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) {
  // GetKind reports custom registrations through an out-parameter; only the
  // returned kind matters here, so the registration itself is ignored.
  const CustomNodeRegistry::Registration* ignored_registration = nullptr;
  for (const py::handle& element : x) {
    const Kind kind = GetKind(element, &ignored_registration);
    if (kind != Kind::kLeaf) {
      return false;
    }
  }
  return true;
}
|
||||
|
||||
py::object PyTreeDef::Unflatten(py::iterable leaves) const {
|
||||
std::vector<py::object> agenda;
|
||||
auto it = leaves.begin();
|
||||
@ -749,6 +777,7 @@ std::string PyTreeDef::ToString() const {
|
||||
PYBIND11_MODULE(pytree, m) {
|
||||
m.def("flatten", &PyTreeDef::Flatten);
|
||||
m.def("tuple", &PyTreeDef::Tuple);
|
||||
m.def("all_leaves", &PyTreeDef::AllLeaves);
|
||||
|
||||
py::class_<PyTreeDef>(m, "PyTreeDef")
|
||||
.def("unflatten", &PyTreeDef::Unflatten)
|
||||
|
@ -69,8 +69,8 @@ class Special:
|
||||
def __eq__(self, other):
|
||||
# Equal only to objects of exactly the same type whose x and y both match.
return type(self) is type(other) and (self.x, self.y) == (other.x, other.y)
|
||||
|
||||
PYTREES = [
|
||||
("foo",),
|
||||
TREES = (
|
||||
(None,),
|
||||
((),),
|
||||
(([()]),),
|
||||
((1, 2),),
|
||||
@ -84,18 +84,25 @@ PYTREES = [
|
||||
(collections.defaultdict(dict,
|
||||
[("foo", 34), ("baz", 101), ("something", -42)]),),
|
||||
(ANamedTupleSubclass(foo="hello", bar=3.5),),
|
||||
]
|
||||
)
|
||||
|
||||
# Values used as pytree leaves in the tests below. Each entry is a 1-tuple so it
# can be unpacked as the positional arguments of `parameterized.parameters`.
LEAVES = (
|
||||
("foo",),
|
||||
(0.1,),
|
||||
(1,),
|
||||
(object(),),
|
||||
)
|
||||
|
||||
|
||||
class TreeTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters(*PYTREES)
|
||||
@parameterized.parameters(*(TREES + LEAVES))
|
||||
def testRoundtrip(self, inputs):
|
||||
xs, tree = tree_util.tree_flatten(inputs)
|
||||
actual = tree_util.tree_unflatten(tree, xs)
|
||||
self.assertEqual(actual, inputs)
|
||||
|
||||
@parameterized.parameters(*PYTREES)
|
||||
@parameterized.parameters(*(TREES + LEAVES))
|
||||
def testRoundtripWithFlattenUpTo(self, inputs):
|
||||
_, tree = tree_util.tree_flatten(inputs)
|
||||
if not hasattr(tree, "flatten_up_to"):
|
||||
@ -119,7 +126,7 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(actual.args, inputs.args)
|
||||
self.assertEqual(actual.keywords, inputs.keywords)
|
||||
|
||||
@parameterized.parameters(*PYTREES)
|
||||
@parameterized.parameters(*(TREES + LEAVES))
|
||||
def testRoundtripViaBuild(self, inputs):
|
||||
xs, tree = tree_util._process_pytree(tuple, inputs)
|
||||
actual = tree_util.build_tree(tree, xs)
|
||||
@ -149,6 +156,15 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out, (((1, [3]), (2, None)),
|
||||
((3, {"foo": "bar"}), (4, 7), (5, [5, 6]))))
|
||||
|
||||
@parameterized.parameters(*TREES)
|
||||
def testAllLeavesWithTrees(self, tree):
|
||||
# The leaves obtained from any tree must themselves form a flat list of leaves.
leaves = tree_util.tree_leaves(tree)
|
||||
self.assertTrue(tree_util.all_leaves(leaves))
|
||||
# The tree itself is a node rather than a leaf, so wrapping it in a list
# must not be reported as a flat list of leaves.
self.assertFalse(tree_util.all_leaves([tree]))
|
||||
|
||||
@parameterized.parameters(*LEAVES)
|
||||
def testAllLeavesWithLeaves(self, leaf):
|
||||
# A single-element list holding one leaf value must be reported as all leaves.
self.assertTrue(tree_util.all_leaves([leaf]))
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user