/* Copyright 2019 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Caution: this code uses exceptions. The exception use is local to the // binding code and the idiomatic way to emit Python exceptions. #include #include #include #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/hash/hash.h" #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "include/pybind11/pybind11.h" #include "include/pybind11/pytypes.h" #include "include/pybind11/stl.h" namespace jax { namespace py = pybind11; // Registry of custom node types. class CustomNodeRegistry { public: struct Registration { // The Python type object, used to identify the type. py::object type; // A function with signature: object -> (iterable, aux_data) py::function to_iterable; // A function with signature: (aux_data, iterable) -> object py::function from_iterable; }; // Registers a new custom type. Objects of `type` will be treated as container // node types in PyTrees. static void Register(py::object type, py::function to_iterable, py::function from_iterable); // Finds the custom type registration for `type`. Returns nullptr if none // exists. static const Registration* Lookup(py::handle type); private: static CustomNodeRegistry* Singleton(); struct TypeHash { size_t operator()(const py::object& t) const { return py::hash(t); } }; struct TypeEq { bool operator()(const py::object& a, const py::object& b) const { return a.equal(b); } }; absl::flat_hash_map, TypeHash, TypeEq> registrations_; }; /*static*/ CustomNodeRegistry* CustomNodeRegistry::Singleton() { static auto* registry = new CustomNodeRegistry; return registry; } /*static*/ void CustomNodeRegistry::Register(py::object type, py::function to_iterable, py::function from_iterable) { CustomNodeRegistry* registry = Singleton(); auto registration = absl::make_unique(); registration->type = type; registration->to_iterable = std::move(to_iterable); registration->from_iterable = std::move(from_iterable); auto it = registry->registrations_.emplace(type, std::move(registration)); if (!it.second) { throw std::invalid_argument( absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.", py::repr(type))); } } /*static*/ const CustomNodeRegistry::Registration* CustomNodeRegistry::Lookup( py::handle type) { CustomNodeRegistry* registry = Singleton(); auto it = registry->registrations_.find(py::reinterpret_borrow(type)); return it == registry->registrations_.end() ? nullptr : it->second.get(); } // A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of // Python values, where the interior nodes are tuples, lists, dictionaries, or // user-defined containers, and the leaves are other objects. class PyTreeDef { public: PyTreeDef() = default; // Flattens a Pytree into a list of leaves and a PyTreeDef. static std::pair> Flatten(py::handle x); // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of // the tree-structure of 'x'. For example, if we flatten a value // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the // list of leaves [1, (2, 3), {"foo": 4}]. py::list FlattenUpTo(py::handle x) const; // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef. py::object Unflatten(py::iterable leaves) const; // Composes two PyTreeDefs, replacing the leaves of this tree with copies of // `inner`. std::unique_ptr Compose(const PyTreeDef& inner) const; // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs. static std::unique_ptr Tuple(const std::vector& defs); std::vector> Children() const; // Maps a function over a PyTree structure, applying f_leaf to each leaf, and // f_node to each container node. // TODO(phawkins): use flattening everywhere instead and delete this method. py::object Walk(const py::function& f_node, py::handle f_leaf, py::iterable leaves) const; // Given a tree of iterables with the same node/leaf structure as this PyTree, // build the corresponding PyTree. // TODO(phawkins): use flattening everywhere instead and delete this method. py::object FromIterableTree(py::handle xs) const; int num_leaves() const { if (traversal_.empty()) { return 0; } return traversal_.back().num_leaves; } int num_nodes() const { return traversal_.size(); } size_t Hash() const; bool operator==(const PyTreeDef& other) const; bool operator!=(const PyTreeDef& other) const { return !(*this == other); } std::string ToString() const; private: enum class Kind { kLeaf, // An opaque leaf node kNone, // None. kTuple, // A tuple kNamedTuple, // A collections.namedtuple kList, // A list kDict, // A dict kCustom, // A custom type. }; struct Node { Kind kind = Kind::kLeaf; // Arity for non-kLeaf types. int arity = 0; // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type // object. For a kDict, contains a sorted list of keys. For a kCustom type, // contains the auxiliary data returned by the `to_iterable` function. py::object node_data; const CustomNodeRegistry::Registration* custom = nullptr; // Number of leaf nodes in the subtree rooted at this node. int num_leaves = 0; // Number of leaf and interior nodes in the subtree rooted at this node. int num_nodes = 0; }; template friend H AbslHashValue(H h, const Node& n); template friend H AbslHashValue(H h, const PyTreeDef& t); // Helper that manufactures an instance of a node given its children. static py::object MakeNode(const Node& node, absl::Span children); // Recursive helper used to implement Flatten() static void FlattenHelper(py::handle handle, py::list* leaves, PyTreeDef* tree); // Recursive helper used to implement FromIterableTree() py::object FromIterableTreeHelper( py::handle xs, std::vector::const_reverse_iterator* it) const; // Nodes, in a post-order traversal. We use an ordered traversal to minimize // allocations, and post-order corresponds to the order we need to rebuild the // tree structure. std::vector traversal_; }; template H AbslHashValue(H h, const PyTreeDef::Node& n) { h = H::combine(std::move(h), n.kind, n.arity, n.custom); return h; } template H AbslHashValue(H h, const PyTreeDef& t) { return H::combine_contiguous(std::move(h), t.traversal_.data(), t.traversal_.size()); } bool PyTreeDef::operator==(const PyTreeDef& other) const { if (traversal_.size() != other.traversal_.size()) { return false; } for (size_t i = 0; i < traversal_.size(); ++i) { const Node& a = traversal_[i]; const Node& b = other.traversal_[i]; if (a.kind != b.kind || a.arity != b.arity || (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) || a.custom != b.custom) { return false; } if (a.node_data && a.node_data.not_equal(b.node_data)) { return false; } // We don't need to test equality of num_leaves and num_nodes since they // are derivable from the other node data. } return true; } void PyTreeDef::FlattenHelper(py::handle handle, py::list* leaves, PyTreeDef* tree) { Node node; int start_num_nodes = tree->traversal_.size(); int start_num_leaves = leaves->size(); if (py::isinstance(handle)) { node.kind = Kind::kNone; } else if (PyTuple_CheckExact(handle.ptr())) { py::tuple tuple = py::reinterpret_borrow(handle); node.kind = Kind::kTuple; node.arity = tuple.size(); for (py::handle entry : tuple) { FlattenHelper(entry, leaves, tree); } } else if (PyList_CheckExact(handle.ptr())) { py::list list = py::reinterpret_borrow(handle); node.kind = Kind::kList; node.arity = list.size(); for (py::handle entry : list) { FlattenHelper(entry, leaves, tree); } } else if (PyDict_CheckExact(handle.ptr())) { py::dict dict = py::reinterpret_borrow(handle); py::list keys = py::reinterpret_steal(PyDict_Keys(dict.ptr())); if (PyList_Sort(keys.ptr())) { throw std::runtime_error("Dictionary key sort failed."); } for (py::handle key : keys) { FlattenHelper(dict[key], leaves, tree); } node.kind = Kind::kDict; node.arity = dict.size(); node.node_data = std::move(keys); } else if ((node.custom = CustomNodeRegistry::Lookup(handle.get_type()))) { node.kind = Kind::kCustom; py::tuple out = py::cast(node.custom->to_iterable(handle)); if (out.size() != 2) { throw std::runtime_error( "PyTree custom to_iterable function should return a pair"); } node.node_data = out[1]; node.arity = 0; for (py::handle entry : py::cast(out[0])) { ++node.arity; FlattenHelper(entry, leaves, tree); } } else if (py::isinstance(handle) && py::hasattr(handle, "_fields")) { // We can only identify namedtuples heuristically, here by the presence of // a _fields attribute. py::tuple tuple = py::reinterpret_borrow(handle); node.kind = Kind::kNamedTuple; node.arity = tuple.size(); node.node_data = py::reinterpret_borrow(tuple.get_type()); for (py::handle entry : tuple) { FlattenHelper(entry, leaves, tree); } } else { node.kind = Kind::kLeaf; leaves->append(py::reinterpret_borrow(handle)); } node.num_nodes = tree->traversal_.size() - start_num_nodes + 1; node.num_leaves = leaves->size() - start_num_leaves; tree->traversal_.push_back(std::move(node)); } /*static*/ std::pair> PyTreeDef::Flatten( py::handle x) { py::list leaves; auto tree = absl::make_unique(); FlattenHelper(x, &leaves, tree.get()); return std::make_pair(std::move(leaves), std::move(tree)); } py::object PyTreeDef::Unflatten(py::iterable leaves) const { std::vector agenda; auto it = leaves.begin(); int leaf_count = 0; for (const Node& node : traversal_) { if (agenda.size() < node.arity) { throw std::logic_error("Too few elements for TreeDef node."); } switch (node.kind) { case Kind::kLeaf: if (it == leaves.end()) { throw std::invalid_argument(absl::StrFormat( "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(), leaf_count)); } agenda.push_back(py::reinterpret_borrow(*it)); ++it; ++leaf_count; break; case Kind::kNone: case Kind::kTuple: case Kind::kNamedTuple: case Kind::kList: case Kind::kDict: case Kind::kCustom: { const int size = agenda.size(); absl::Span span; if (node.arity > 0) { span = absl::Span(&agenda[size - node.arity], node.arity); } py::object o = MakeNode(node, span); agenda.resize(size - node.arity); agenda.push_back(o); break; } } } if (it != leaves.end()) { throw std::invalid_argument(absl::StrFormat( "Too many leaves for PyTreeDef; expected %d.", num_leaves())); } if (agenda.size() != 1) { throw std::logic_error("PyTreeDef traversal did not yield a singleton."); } return std::move(agenda.back()); } /*static*/ py::object PyTreeDef::MakeNode(const PyTreeDef::Node& node, absl::Span children) { if (children.size() != node.arity) { throw std::logic_error("Node arity mismatch."); } switch (node.kind) { case Kind::kLeaf: throw std::logic_error("MakeNode not implemented for leaves."); case Kind::kNone: return py::none(); case Kind::kTuple: case Kind::kNamedTuple: { py::tuple tuple(node.arity); for (int i = 0; i < node.arity; ++i) { tuple[i] = std::move(children[i]); } if (node.kind == Kind::kNamedTuple) { return node.node_data(*tuple); } else { return std::move(tuple); } } case Kind::kList: { py::list list(node.arity); for (int i = 0; i < node.arity; ++i) { list[i] = std::move(children[i]); } return std::move(list); } case Kind::kDict: { py::dict dict; py::list keys = py::reinterpret_borrow(node.node_data); for (int i = 0; i < node.arity; ++i) { dict[keys[i]] = std::move(children[i]); } return std::move(dict); break; } case Kind::kCustom: { py::tuple tuple(node.arity); for (int i = 0; i < node.arity; ++i) { tuple[i] = std::move(children[i]); } return node.custom->from_iterable(node.node_data, tuple); } } } py::list PyTreeDef::FlattenUpTo(py::handle xs) const { py::list leaves(num_leaves()); std::vector agenda; agenda.push_back(py::reinterpret_borrow(xs)); auto it = traversal_.rbegin(); int leaf = num_leaves() - 1; while (!agenda.empty()) { if (it == traversal_.rend()) { throw std::invalid_argument( absl::StrFormat("Tree structures did not match: %s vs %s", py::repr(xs), ToString())); } const Node& node = *it; py::object object = agenda.back(); agenda.pop_back(); ++it; switch (node.kind) { case Kind::kLeaf: if (leaf < 0) { throw std::logic_error("Leaf count mismatch."); } leaves[leaf] = py::reinterpret_borrow(object); --leaf; break; case Kind::kNone: break; case Kind::kTuple: { if (!PyTuple_CheckExact(object.ptr())) { throw std::invalid_argument( absl::StrFormat("Expected tuple, got %s.", py::repr(object))); } py::tuple tuple = py::reinterpret_borrow(object); if (tuple.size() != node.arity) { throw std::invalid_argument( absl::StrFormat("Tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), node.arity, py::repr(object))); } for (py::handle entry : tuple) { agenda.push_back(py::reinterpret_borrow(entry)); } break; } case Kind::kList: { if (!PyList_CheckExact(object.ptr())) { throw std::invalid_argument( absl::StrFormat("Expected list, got %s.", py::repr(object))); } py::list list = py::reinterpret_borrow(object); if (list.size() != node.arity) { throw std::invalid_argument( absl::StrFormat("List arity mismatch: %d != %d; list: %s.", list.size(), node.arity, py::repr(object))); } for (py::handle entry : list) { agenda.push_back(py::reinterpret_borrow(entry)); } break; } case Kind::kDict: { if (!PyDict_CheckExact(object.ptr())) { throw std::invalid_argument( absl::StrFormat("Expected dict, got %s.", py::repr(object))); } py::dict dict = py::reinterpret_borrow(object); py::list keys = py::reinterpret_steal(PyDict_Keys(dict.ptr())); if (PyList_Sort(keys.ptr())) { throw std::runtime_error("Dictionary key sort failed."); } if (keys.not_equal(node.node_data)) { throw std::invalid_argument( absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.", py::repr(node.node_data), py::repr(object))); } for (py::handle key : keys) { agenda.push_back(dict[key]); } break; } case Kind::kNamedTuple: { if (!py::isinstance(object) || !py::hasattr(object, "_fields")) { throw std::invalid_argument(absl::StrFormat( "Expected named tuple, got %s.", py::repr(object))); } py::tuple tuple = py::reinterpret_borrow(object); if (tuple.size() != node.arity) { throw std::invalid_argument(absl::StrFormat( "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), node.arity, py::repr(object))); } if (tuple.get_type().not_equal(node.node_data)) { throw std::invalid_argument(absl::StrFormat( "Named tuple type mismatch: expected type: %s, tuple: %s.", py::repr(node.node_data), py::repr(object))); } for (py::handle entry : tuple) { agenda.push_back(py::reinterpret_borrow(entry)); } break; } case Kind::kCustom: { auto* registration = CustomNodeRegistry::Lookup(object.get_type()); if (registration != node.custom) { throw std::invalid_argument(absl::StrFormat( "Custom node type mismatch: expected type: %s, value: %s.", py::repr(node.custom->type), py::repr(object))); } py::tuple out = py::cast(node.custom->to_iterable(object)); if (out.size() != 2) { throw std::runtime_error( "PyTree custom to_iterable function should return a pair"); } if (node.node_data.not_equal(out[1])) { throw std::invalid_argument(absl::StrFormat( "Mismatch custom node data: %s != %s; value: %s.", py::repr(node.node_data), py::repr(out[1]), py::repr(object))); } int arity = 0; for (py::handle entry : py::cast(out[0])) { ++arity; agenda.push_back(py::reinterpret_borrow(entry)); } if (arity != node.arity) { throw std::invalid_argument(absl::StrFormat( "Custom type arity mismatch: %d != %d; value: %s.", arity, node.arity, py::repr(object))); } break; } } } if (it != traversal_.rend() || leaf != -1) { throw std::invalid_argument( absl::StrFormat("Tree structures did not match: %s vs %s", py::repr(xs), ToString())); } return leaves; } py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf, py::iterable leaves) const { std::vector agenda; auto it = leaves.begin(); for (const Node& node : traversal_) { switch (node.kind) { case Kind::kLeaf: { if (it == leaves.end()) { throw std::invalid_argument("Too few leaves for PyTreeDef"); } py::object leaf = py::reinterpret_borrow(*it); agenda.push_back(f_leaf.is_none() ? std::move(leaf) : f_leaf(std::move(leaf))); ++it; break; } case Kind::kNone: case Kind::kTuple: case Kind::kNamedTuple: case Kind::kList: case Kind::kDict: case Kind::kCustom: { if (agenda.size() < node.arity) { throw std::logic_error("Too few elements for custom type."); } py::tuple tuple(node.arity); for (int i = node.arity - 1; i >= 0; --i) { tuple[i] = agenda.back(); agenda.pop_back(); } agenda.push_back(f_node(tuple)); } } } if (it != leaves.end()) { throw std::invalid_argument("Too many leaves for PyTreeDef"); } if (agenda.size() != 1) { throw std::logic_error("PyTreeDef traversal did not yield a singleton."); } return std::move(agenda.back()); } py::object PyTreeDef::FromIterableTreeHelper( py::handle xs, std::vector::const_reverse_iterator* it) const { if (*it == traversal_.rend()) { throw std::invalid_argument("Tree structures did not match."); } const Node& node = **it; ++*it; if (node.kind == Kind::kLeaf) { return py::reinterpret_borrow(xs); } py::iterable iterable = py::reinterpret_borrow(xs); std::vector ys; ys.reserve(node.arity); for (py::handle x : iterable) { ys.push_back(py::reinterpret_borrow(x)); } if (ys.size() != node.arity) { throw std::invalid_argument("Arity mismatch between trees"); } for (int j = node.arity - 1; j >= 0; --j) { ys[j] = FromIterableTreeHelper(ys[j], it); } return MakeNode(node, absl::MakeSpan(ys)); } py::object PyTreeDef::FromIterableTree(py::handle xs) const { auto it = traversal_.rbegin(); py::object out = FromIterableTreeHelper(xs, &it); if (it != traversal_.rend()) { throw std::invalid_argument("Tree structures did not match."); } return out; } std::unique_ptr PyTreeDef::Compose(const PyTreeDef& inner) const { auto out = absl::make_unique(); for (const Node& n : traversal_) { if (n.kind == Kind::kLeaf) { absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_)); } else { out->traversal_.push_back(n); } } return out; } /*static*/ std::unique_ptr PyTreeDef::Tuple( const std::vector& defs) { auto out = absl::make_unique(); for (const PyTreeDef& def : defs) { absl::c_copy(def.traversal_, std::back_inserter(out->traversal_)); } Node node; node.kind = Kind::kTuple; node.arity = defs.size(); out->traversal_.push_back(node); return out; } std::vector> PyTreeDef::Children() const { std::vector> children; if (traversal_.empty()) { return children; } Node const& root = traversal_.back(); children.resize(root.arity); int pos = traversal_.size() - 1; for (int i = root.arity - 1; i >= 0; --i) { children[i] = absl::make_unique(); const Node& node = traversal_.at(pos - 1); if (pos < node.num_nodes) { throw std::logic_error("children() walked off start of array"); } std::copy(traversal_.begin() + pos - node.num_nodes, traversal_.begin() + pos, std::back_inserter(children[i]->traversal_)); pos -= node.num_nodes; } if (pos != 0) { throw std::logic_error("pos != 0 at end of PyTreeDef::Children"); } return children; } std::string PyTreeDef::ToString() const { std::vector agenda; for (const Node& node : traversal_) { if (agenda.size() < node.arity) { throw std::logic_error("Too few elements for container."); } std::string kind; switch (node.kind) { case Kind::kLeaf: agenda.push_back("*"); continue; case Kind::kNone: kind = "None"; break; case Kind::kNamedTuple: kind = "namedtuple"; break; case Kind::kTuple: kind = "tuple"; break; case Kind::kList: kind = "list"; break; case Kind::kDict: kind = "dict"; break; case Kind::kCustom: kind = static_cast(py::str(node.custom->type)); break; } std::string children = absl::StrJoin(agenda.end() - node.arity, agenda.end(), ","); agenda.erase(agenda.end() - node.arity, agenda.end()); std::string data; if (node.node_data) { data = absl::StrFormat("[%s]", py::str(node.node_data)); } agenda.push_back( absl::StrFormat("PyTreeDef(%s%s, [%s])", kind, data, children)); } if (agenda.size() != 1) { throw std::logic_error("PyTreeDef traversal did not yield a singleton."); } return std::move(agenda.back()); } PYBIND11_MODULE(pytree, m) { m.def("flatten", &PyTreeDef::Flatten); m.def("tuple", &PyTreeDef::Tuple); py::class_(m, "PyTreeDef") .def("unflatten", &PyTreeDef::Unflatten) .def("flatten_up_to", &PyTreeDef::FlattenUpTo) .def("compose", &PyTreeDef::Compose) .def("walk", &PyTreeDef::Walk) .def("from_iterable_tree", &PyTreeDef::FromIterableTree) .def("children", &PyTreeDef::Children) .def_property_readonly("num_leaves", &PyTreeDef::num_leaves) .def_property_readonly("num_nodes", &PyTreeDef::num_nodes) .def("__repr__", &PyTreeDef::ToString) .def("__eq__", [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; }) .def("__ne__", [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; }) .def("__hash__", [](const PyTreeDef& t) { return absl::Hash()(t); }); m.def("register_node", [](py::object type, py::function to_iterable, py::function from_iterable) { return CustomNodeRegistry::Register(type, to_iterable, from_iterable); }); } } // namespace jax