diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 16e090dd4..95c09ae94 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -797,7 +797,7 @@ def tracers_to_jaxpr(
   processed_eqn_ids = set()
   eqns: list[core.JaxprEqn] = []
 
-  for t in toposort([*in_tracers, *out_tracers]):
+  for t in toposort((*in_tracers, *out_tracers)):
     r = t.recipe
     if isinstance(r, JaxprEqnRecipe):
       # TODO broadcast_in_dim can create a new tracer, not present in parents
diff --git a/jax/_src/util.py b/jax/_src/util.py
index 408106c12..3e52a2584 100644
--- a/jax/_src/util.py
+++ b/jax/_src/util.py
@@ -244,52 +244,64 @@ def curry(f):
   """
   return wraps(f)(partial(partial, f))
 
-def toposort(end_nodes):
-  if not end_nodes: return []
-  end_nodes = _remove_duplicates(end_nodes)
+# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum.
+toposort: Callable[[Iterable[Any]], list[Any]]
+if hasattr(jaxlib_utils, "topological_sort"):
+  toposort = partial(jaxlib_utils.topological_sort, "parents")
+else:
 
-  child_counts = {}
-  stack = list(end_nodes)
-  while stack:
-    node = stack.pop()
-    if id(node) in child_counts:
-      child_counts[id(node)] += 1
-    else:
-      child_counts[id(node)] = 1
-      stack.extend(node.parents)
-  for node in end_nodes:
-    child_counts[id(node)] -= 1
+  def toposort(end_nodes):
+    # Deduplicate first: this also materializes arbitrary iterables, so the
+    # emptiness check below works for generators, matching the C++ version.
+    end_nodes = _remove_duplicates(end_nodes)
+    if not end_nodes:
+      return []
 
-  sorted_nodes = []
-  childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
-  assert childless_nodes
-  while childless_nodes:
-    node = childless_nodes.pop()
-    sorted_nodes.append(node)
-    for parent in node.parents:
-      if child_counts[id(parent)] == 1:
-        childless_nodes.append(parent)
+    child_counts = {}
+    stack = list(end_nodes)
+    while stack:
+      node = stack.pop()
+      if id(node) in child_counts:
+        child_counts[id(node)] += 1
       else:
-        child_counts[id(parent)] -= 1
-  sorted_nodes = sorted_nodes[::-1]
+        child_counts[id(node)] = 1
+        stack.extend(node.parents)
+    for node in end_nodes:
+      child_counts[id(node)] -= 1
 
-  check_toposort(sorted_nodes)
-  return sorted_nodes
+    sorted_nodes = []
+    childless_nodes = [
+        node for node in end_nodes if child_counts[id(node)] == 0
+    ]
+    assert childless_nodes
+    while childless_nodes:
+      node = childless_nodes.pop()
+      sorted_nodes.append(node)
+      for parent in node.parents:
+        if child_counts[id(parent)] == 1:
+          childless_nodes.append(parent)
+        else:
+          child_counts[id(parent)] -= 1
+    sorted_nodes = sorted_nodes[::-1]
 
-def check_toposort(nodes):
-  visited = set()
-  for node in nodes:
-    assert all(id(parent) in visited for parent in node.parents)
-    visited.add(id(node))
+    check_toposort(sorted_nodes)
+    return sorted_nodes
+
+  def check_toposort(nodes):
+    visited = set()
+    for node in nodes:
+      assert all(id(parent) in visited for parent in node.parents)
+      visited.add(id(node))
+
+  def _remove_duplicates(node_list):
+    seen = set()
+    out = []
+    for n in node_list:
+      if id(n) not in seen:
+        seen.add(id(n))
+        out.append(n)
+    return out
 
-def _remove_duplicates(node_list):
-  seen = set()
-  out = []
-  for n in node_list:
-    if id(n) not in seen:
-      seen.add(id(n))
-      out.append(n)
-  return out
 
 def split_merge(predicate, xs):
   sides = list(map(predicate, xs))
diff --git a/jaxlib/BUILD b/jaxlib/BUILD
index 33ea0e07d..93c2b483c 100644
--- a/jaxlib/BUILD
+++ b/jaxlib/BUILD
@@ -214,6 +214,8 @@ nanobind_extension(
     module_name = "utils",
     deps = [
         "@com_google_absl//absl/cleanup",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/synchronization",
         "@nanobind",
diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc
index 8b673ef68..bf50b3a52 100644
--- a/jaxlib/utils.cc
+++ b/jaxlib/utils.cc
@@ -16,9 +16,13 @@ limitations under the License.
 #include <Python.h>
 
 #include <cstddef>
+#include <utility>
+#include <vector>
 
 #include "nanobind/nanobind.h"
 #include "absl/cleanup/cleanup.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/synchronization/mutex.h"
 
@@ -293,6 +297,69 @@ PyMethodDef safe_zip_def = {
     METH_FASTCALL,
 };
 
+nb::list TopologicalSort(nb::str parents_attr,
+                         nb::iterable end_nodes_iterable) {
+  // This is a direct conversion of the original Python implementation.
+  // More efficient implementations of a topological sort are possible (and
+  // indeed, easier to write), but changing the choice of topological order
+  // would break existing tests.
+  std::vector<nb::object> end_nodes;
+  absl::flat_hash_set<PyObject*> seen;
+  for (nb::handle n : end_nodes_iterable) {
+    nb::object node = nb::borrow(n);
+    if (seen.insert(node.ptr()).second) {
+      end_nodes.push_back(node);
+    }
+  }
+
+  nb::list sorted_nodes;
+  if (end_nodes.empty()) {
+    return sorted_nodes;
+  }
+
+  std::vector<nb::object> stack = end_nodes;
+  absl::flat_hash_map<PyObject*, int> child_counts;
+  while (!stack.empty()) {
+    nb::object node = std::move(stack.back());
+    stack.pop_back();
+    auto& count = child_counts[node.ptr()];
+    if (count == 0) {
+      for (nb::handle parent : node.attr(parents_attr)) {
+        stack.push_back(nb::borrow(parent));
+      }
+    }
+    ++count;
+  }
+
+  for (nb::handle n : end_nodes) {
+    child_counts[n.ptr()] -= 1;
+  }
+
+  std::vector<nb::object> childless_nodes;
+  childless_nodes.reserve(end_nodes.size());
+  for (nb::handle n : end_nodes) {
+    if (child_counts[n.ptr()] == 0) {
+      childless_nodes.push_back(nb::borrow(n));
+    }
+  }
+
+  while (!childless_nodes.empty()) {
+    nb::object node = std::move(childless_nodes.back());
+    childless_nodes.pop_back();
+    sorted_nodes.append(node);
+    for (nb::handle parent : node.attr(parents_attr)) {
+      auto& count = child_counts[parent.ptr()];
+      if (count == 1) {
+        childless_nodes.push_back(nb::borrow(parent));
+      } else {
+        --count;
+      }
+    }
+  }
+  sorted_nodes.reverse();
+  return sorted_nodes;
+}
+
 }  // namespace
 
 NB_MODULE(utils, m) {
@@ -304,6 +371,13 @@ NB_MODULE(utils, m) {
   m.attr("safe_zip") = nb::steal(
       PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));
 
+  m.def("topological_sort", &TopologicalSort, nb::arg("parents_attr"),
+        nb::arg("end_nodes"),
+        "Computes a topological sort of a graph of objects. parents_attr is "
+        "the name of the attribute on each object that contains the list of "
+        "parent objects. end_nodes is an iterable of objects from which we "
+        "should start a backwards search.");
+
   // Python has no reader-writer lock in its standard library, so we expose
   // bindings around absl::Mutex.
   nb::class_<absl::Mutex>(m, "Mutex")
diff --git a/tests/util_test.py b/tests/util_test.py
index cb803d66b..53414dae9 100644
--- a/tests/util_test.py
+++ b/tests/util_test.py
@@ -201,5 +201,50 @@ class SafeZipTest(jtu.JaxTestCase):
       util.safe_zip((), range(3))
 
 
+class Node:
+  def __init__(self, parents):
+    self.parents = parents
+
+
+class TopologicalSortTest(jtu.JaxTestCase):
+
+  def _check_topological_sort(self, nodes, order):
+    # `order` must hold exactly `nodes`, with every parent before its children.
+    self.assertEqual(sorted(nodes, key=id), sorted(order, key=id))
+    visited = set()
+    for node in order:
+      self.assertTrue(all(id(parent) in visited for parent in node.parents))
+      visited.add(id(node))
+
+  def test_basic(self):
+    a = Node([])
+    b = Node([a])
+    c = Node([a])
+    d = Node([a, c])
+    e = Node([b, c])
+    out = util.toposort([a, d, e])
+    self._check_topological_sort([a, b, c, d, e], out)
+
+  def test_stick(self):
+    a = Node([])
+    b = Node([a])
+    c = Node([b])
+    d = Node([c])
+    e = Node([d])
+    out = util.toposort([e])
+    self._check_topological_sort([a, b, c, d, e], out)
+
+  def test_diamonds(self):
+    a = Node([])
+    b = Node([a])
+    c = Node([a])
+    d = Node([b, c])
+    e = Node([d])
+    f = Node([d])
+    g = Node([e, f])
+    out = util.toposort([g])
+    self._check_topological_sort([a, b, c, d, e, f, g], out)
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())