Add a C++ implementation of a topological sort.

This is an exact port of the current Python implementation to C++ for speed.

I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.

PiperOrigin-RevId: 737014989
This commit is contained in:
Peter Hawkins 2025-03-14 16:03:45 -07:00 committed by jax authors
parent b00a3a1986
commit 14cb7453f0
5 changed files with 171 additions and 41 deletions

View File

@ -797,7 +797,7 @@ def tracers_to_jaxpr(
processed_eqn_ids = set()
eqns: list[core.JaxprEqn] = []
for t in toposort([*in_tracers, *out_tracers]):
for t in toposort((*in_tracers, *out_tracers)):
r = t.recipe
if isinstance(r, JaxprEqnRecipe):
# TODO broadcast_in_dim can create a new tracer, not present in parents

View File

@ -244,52 +244,62 @@ def curry(f):
"""
return wraps(f)(partial(partial, f))
def toposort(end_nodes):
if not end_nodes: return []
end_nodes = _remove_duplicates(end_nodes)
# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum.
toposort: Callable[[Iterable[Any]], list[Any]]
if hasattr(jaxlib_utils, "topological_sort"):
toposort = partial(jaxlib_utils.topological_sort, "parents")
else:
child_counts = {}
stack = list(end_nodes)
while stack:
node = stack.pop()
if id(node) in child_counts:
child_counts[id(node)] += 1
else:
child_counts[id(node)] = 1
stack.extend(node.parents)
for node in end_nodes:
child_counts[id(node)] -= 1
def toposort(end_nodes):
if not end_nodes:
return []
end_nodes = _remove_duplicates(end_nodes)
sorted_nodes = []
childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
assert childless_nodes
while childless_nodes:
node = childless_nodes.pop()
sorted_nodes.append(node)
for parent in node.parents:
if child_counts[id(parent)] == 1:
childless_nodes.append(parent)
child_counts = {}
stack = list(end_nodes)
while stack:
node = stack.pop()
if id(node) in child_counts:
child_counts[id(node)] += 1
else:
child_counts[id(parent)] -= 1
sorted_nodes = sorted_nodes[::-1]
child_counts[id(node)] = 1
stack.extend(node.parents)
for node in end_nodes:
child_counts[id(node)] -= 1
check_toposort(sorted_nodes)
return sorted_nodes
sorted_nodes = []
childless_nodes = [
node for node in end_nodes if child_counts[id(node)] == 0
]
assert childless_nodes
while childless_nodes:
node = childless_nodes.pop()
sorted_nodes.append(node)
for parent in node.parents:
if child_counts[id(parent)] == 1:
childless_nodes.append(parent)
else:
child_counts[id(parent)] -= 1
sorted_nodes = sorted_nodes[::-1]
def check_toposort(nodes):
visited = set()
for node in nodes:
assert all(id(parent) in visited for parent in node.parents)
visited.add(id(node))
check_toposort(sorted_nodes)
return sorted_nodes
def check_toposort(nodes):
visited = set()
for node in nodes:
assert all(id(parent) in visited for parent in node.parents)
visited.add(id(node))
def _remove_duplicates(node_list):
seen = set()
out = []
for n in node_list:
if id(n) not in seen:
seen.add(id(n))
out.append(n)
return out
def _remove_duplicates(node_list):
seen = set()
out = []
for n in node_list:
if id(n) not in seen:
seen.add(id(n))
out.append(n)
return out
def split_merge(predicate, xs):
sides = list(map(predicate, xs))

View File

@ -214,6 +214,8 @@ nanobind_extension(
module_name = "utils",
deps = [
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/synchronization",
"@nanobind",

View File

@ -16,9 +16,13 @@ limitations under the License.
#include <Python.h>
#include <cstddef>
#include <utility>
#include <vector>
#include "nanobind/nanobind.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/synchronization/mutex.h"
@ -293,6 +297,69 @@ PyMethodDef safe_zip_def = {
METH_FASTCALL,
};
// Computes a topological ordering of the object graph reachable from
// `end_nodes_iterable`.
//
// `parents_attr` names the attribute on every node that yields an iterable of
// that node's parents. The search runs backwards from the end nodes through
// those parent links. In the returned list a node is placed after the parents
// it references (for acyclic inputs; no cycle detection is performed here).
// Nodes are compared by object identity, and duplicate end nodes are ignored.
nb::list TopologicalSort(nb::str parents_attr,
                         nb::iterable end_nodes_iterable) {
  // NOTE: this deliberately follows the original Python implementation step
  // by step. Simpler and faster topological sorts exist, but they would pick
  // a different (still valid) order, and existing tests depend on the exact
  // order produced here.

  // Phase 1: drop repeated end nodes, keeping the first occurrence of each.
  std::vector<nb::object> roots;
  absl::flat_hash_set<PyObject*> root_ids;
  for (nb::handle item : end_nodes_iterable) {
    nb::object candidate = nb::borrow(item);
    const bool first_time = root_ids.insert(candidate.ptr()).second;
    if (first_time) {
      roots.push_back(candidate);
    }
  }

  nb::list ordering;
  if (roots.empty()) {
    return ordering;
  }

  // Phase 2: depth-first walk over parent links. Every time a node is popped
  // its counter is bumped; its parents are expanded only on the first visit.
  std::vector<nb::object> pending = roots;
  absl::flat_hash_map<PyObject*, int> child_counts;
  while (!pending.empty()) {
    nb::object current = std::move(pending.back());
    pending.pop_back();
    auto [entry, newly_added] = child_counts.try_emplace(current.ptr(), 0);
    if (newly_added) {
      // Pushing onto `pending` does not touch `child_counts`, so `entry`
      // remains valid across this loop.
      for (nb::handle parent : current.attr(parents_attr)) {
        pending.push_back(nb::borrow(parent));
      }
    }
    ++entry->second;
  }

  // Phase 3: each root was counted once for being seeded onto the stack, not
  // for being referenced by a child; undo that. Roots left with a zero count
  // have no children and seed the emission phase. Roots are unique, so the
  // adjustment and the zero test can share a single pass.
  std::vector<nb::object> ready;
  ready.reserve(roots.size());
  for (const nb::object& root : roots) {
    if (--child_counts[root.ptr()] == 0) {
      ready.push_back(root);
    }
  }

  // Phase 4: emit nodes children-first. A parent becomes ready when the node
  // being emitted is its last outstanding child (counter equal to one).
  while (!ready.empty()) {
    nb::object current = std::move(ready.back());
    ready.pop_back();
    ordering.append(current);
    for (nb::handle parent : current.attr(parents_attr)) {
      int& outstanding = child_counts[parent.ptr()];
      if (outstanding != 1) {
        --outstanding;
        continue;
      }
      ready.push_back(nb::borrow(parent));
    }
  }

  // Nodes were emitted children-first; flip so parents come first.
  ordering.reverse();
  return ordering;
}
} // namespace
NB_MODULE(utils, m) {
@ -304,6 +371,13 @@ NB_MODULE(utils, m) {
m.attr("safe_zip") = nb::steal<nb::object>(
PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));
m.def("topological_sort", &TopologicalSort, nb::arg("parents_attr"),
nb::arg("end_nodes"),
"Computes a topological sort of a graph of objects. parents_attr is "
"the name of the attribute on each object that contains the list of "
"parent objects. end_nodes is an iterable of objects from which we "
"should start a backwards search.");
// Python has no reader-writer lock in its standard library, so we expose
// bindings around absl::Mutex.
nb::class_<absl::Mutex>(m, "Mutex")

View File

@ -201,5 +201,49 @@ class SafeZipTest(jtu.JaxTestCase):
util.safe_zip((), range(3))
class Node:
  """Minimal graph node used by the topological sort tests."""

  def __init__(self, parents):
    # `parents` is stored as-is; the tests below pass lists of other Nodes.
    self.parents = parents
class TopologicalSortTest(jtu.JaxTestCase):
  """Tests for `util.toposort` on small hand-built graphs of `Node` objects."""

  def _check_topological_sort(self, nodes, order):
    """Asserts that `order` is a valid topological ordering of `nodes`.

    Args:
      nodes: every node expected to appear in the result, in any order.
      order: the list returned by the sort under test.
    """
    # Same set of objects (compared by identity), with no extras or omissions.
    self.assertEqual(sorted(nodes, key=id), sorted(order, key=id))
    visited = set()
    # Walk the *returned* order (not the expected node list) so that the
    # parents-before-children property of the result is what gets verified.
    for node in order:
      self.assertTrue(all(id(parent) in visited for parent in node.parents))
      visited.add(id(node))

  def test_basic(self):
    # Small DAG with several end nodes, one of which (a) is also an ancestor.
    a = Node([])
    b = Node([a])
    c = Node([a])
    d = Node([a, c])
    e = Node([b, c])
    out = util.toposort([a, d, e])
    self._check_topological_sort([a, b, c, d, e], out)

  def test_stick(self):
    # Linear chain: a <- b <- c <- d <- e.
    a = Node([])
    b = Node([a])
    c = Node([b])
    d = Node([c])
    e = Node([d])
    out = util.toposort([e])
    self._check_topological_sort([a, b, c, d, e], out)

  def test_diamonds(self):
    # Two stacked diamonds: a -> {b, c} -> d -> {e, f} -> g.
    a = Node([])
    b = Node([a])
    c = Node([a])
    d = Node([b, c])
    e = Node([d])
    f = Node([d])
    g = Node([e, f])
    out = util.toposort([g])
    self._check_topological_sort([a, b, c, d, e, f, g], out)
# When executed directly, run the tests through absltest using the JAX test
# loader.
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())