mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add a C++ implementation of a topological sort.
This is an exact port of the current Python implementation to C++ for speed. I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change. PiperOrigin-RevId: 737014989
This commit is contained in:
parent
b00a3a1986
commit
14cb7453f0
@ -797,7 +797,7 @@ def tracers_to_jaxpr(
|
||||
|
||||
processed_eqn_ids = set()
|
||||
eqns: list[core.JaxprEqn] = []
|
||||
for t in toposort([*in_tracers, *out_tracers]):
|
||||
for t in toposort((*in_tracers, *out_tracers)):
|
||||
r = t.recipe
|
||||
if isinstance(r, JaxprEqnRecipe):
|
||||
# TODO broadcast_in_dim can create a new tracer, not present in parents
|
||||
|
@ -244,52 +244,62 @@ def curry(f):
|
||||
"""
|
||||
return wraps(f)(partial(partial, f))
|
||||
|
||||
def toposort(end_nodes):
|
||||
if not end_nodes: return []
|
||||
end_nodes = _remove_duplicates(end_nodes)
|
||||
# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum.
|
||||
toposort: Callable[[Iterable[Any]], list[Any]]
|
||||
if hasattr(jaxlib_utils, "topological_sort"):
|
||||
toposort = partial(jaxlib_utils.topological_sort, "parents")
|
||||
else:
|
||||
|
||||
child_counts = {}
|
||||
stack = list(end_nodes)
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if id(node) in child_counts:
|
||||
child_counts[id(node)] += 1
|
||||
else:
|
||||
child_counts[id(node)] = 1
|
||||
stack.extend(node.parents)
|
||||
for node in end_nodes:
|
||||
child_counts[id(node)] -= 1
|
||||
def toposort(end_nodes):
|
||||
if not end_nodes:
|
||||
return []
|
||||
end_nodes = _remove_duplicates(end_nodes)
|
||||
|
||||
sorted_nodes = []
|
||||
childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
|
||||
assert childless_nodes
|
||||
while childless_nodes:
|
||||
node = childless_nodes.pop()
|
||||
sorted_nodes.append(node)
|
||||
for parent in node.parents:
|
||||
if child_counts[id(parent)] == 1:
|
||||
childless_nodes.append(parent)
|
||||
child_counts = {}
|
||||
stack = list(end_nodes)
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if id(node) in child_counts:
|
||||
child_counts[id(node)] += 1
|
||||
else:
|
||||
child_counts[id(parent)] -= 1
|
||||
sorted_nodes = sorted_nodes[::-1]
|
||||
child_counts[id(node)] = 1
|
||||
stack.extend(node.parents)
|
||||
for node in end_nodes:
|
||||
child_counts[id(node)] -= 1
|
||||
|
||||
check_toposort(sorted_nodes)
|
||||
return sorted_nodes
|
||||
sorted_nodes = []
|
||||
childless_nodes = [
|
||||
node for node in end_nodes if child_counts[id(node)] == 0
|
||||
]
|
||||
assert childless_nodes
|
||||
while childless_nodes:
|
||||
node = childless_nodes.pop()
|
||||
sorted_nodes.append(node)
|
||||
for parent in node.parents:
|
||||
if child_counts[id(parent)] == 1:
|
||||
childless_nodes.append(parent)
|
||||
else:
|
||||
child_counts[id(parent)] -= 1
|
||||
sorted_nodes = sorted_nodes[::-1]
|
||||
|
||||
def check_toposort(nodes):
|
||||
visited = set()
|
||||
for node in nodes:
|
||||
assert all(id(parent) in visited for parent in node.parents)
|
||||
visited.add(id(node))
|
||||
check_toposort(sorted_nodes)
|
||||
return sorted_nodes
|
||||
|
||||
def check_toposort(nodes):
|
||||
visited = set()
|
||||
for node in nodes:
|
||||
assert all(id(parent) in visited for parent in node.parents)
|
||||
visited.add(id(node))
|
||||
|
||||
def _remove_duplicates(node_list):
|
||||
seen = set()
|
||||
out = []
|
||||
for n in node_list:
|
||||
if id(n) not in seen:
|
||||
seen.add(id(n))
|
||||
out.append(n)
|
||||
return out
|
||||
|
||||
def _remove_duplicates(node_list):
|
||||
seen = set()
|
||||
out = []
|
||||
for n in node_list:
|
||||
if id(n) not in seen:
|
||||
seen.add(id(n))
|
||||
out.append(n)
|
||||
return out
|
||||
|
||||
def split_merge(predicate, xs):
|
||||
sides = list(map(predicate, xs))
|
||||
|
@ -214,6 +214,8 @@ nanobind_extension(
|
||||
module_name = "utils",
|
||||
deps = [
|
||||
"@com_google_absl//absl/cleanup",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@nanobind",
|
||||
|
@ -16,9 +16,13 @@ limitations under the License.
|
||||
#include <Python.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "absl/cleanup/cleanup.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
|
||||
@ -293,6 +297,69 @@ PyMethodDef safe_zip_def = {
|
||||
METH_FASTCALL,
|
||||
};
|
||||
|
||||
// Computes a topological ordering of the object graph reachable from
// `end_nodes_iterable`, where the parents of each object are obtained by
// iterating the attribute named `parents_attr`. In the returned list every
// node is placed after all of its parents.
//
// This mirrors the original Python implementation step for step. Simpler
// formulations exist, but they would pick a different (equally valid) order,
// and existing tests depend on the exact order produced here. Do not reorder
// the pushes/pops below without checking those tests.
nb::list TopologicalSort(nb::str parents_attr,
                         nb::iterable end_nodes_iterable) {
  // Deduplicate the requested end nodes by object identity, preserving the
  // order in which they were first seen.
  std::vector<nb::object> roots;
  absl::flat_hash_set<PyObject*> known_roots;
  for (nb::handle item : end_nodes_iterable) {
    if (!known_roots.insert(item.ptr()).second) {
      continue;
    }
    roots.push_back(nb::borrow(item));
  }

  nb::list result;
  if (roots.empty()) {
    return result;
  }

  // Pass 1: depth-first walk from the roots. Each time a node is popped its
  // counter is bumped; parents are only expanded on the first visit. After
  // this loop a node's counter equals the number of edges leading to it from
  // children, plus one if it was seeded as a root.
  absl::flat_hash_map<PyObject*, int> pending_children;
  std::vector<nb::object> work(roots.begin(), roots.end());
  while (!work.empty()) {
    nb::object current = std::move(work.back());
    work.pop_back();
    int& visits = pending_children[current.ptr()];
    if (visits++ == 0) {
      for (nb::handle parent : current.attr(parents_attr)) {
        work.push_back(nb::borrow(parent));
      }
    }
  }

  // Remove the extra count each root received from being seeded. Roots whose
  // counter drops to zero have no children in the graph and can be emitted
  // immediately. Roots are unique, so adjusting and testing in one step is
  // equivalent to doing it in two separate sweeps.
  std::vector<nb::object> ready;
  ready.reserve(roots.size());
  for (const nb::object& root : roots) {
    if (--pending_children[root.ptr()] == 0) {
      ready.push_back(root);
    }
  }

  // Pass 2: emit nodes children-first. A parent becomes ready once the node
  // being emitted is its last outstanding child (counter == 1); otherwise we
  // just record that one more child has been handled.
  while (!ready.empty()) {
    nb::object current = std::move(ready.back());
    ready.pop_back();
    result.append(current);
    for (nb::handle parent : current.attr(parents_attr)) {
      int& remaining = pending_children[parent.ptr()];
      if (remaining != 1) {
        --remaining;
        continue;
      }
      ready.push_back(nb::borrow(parent));
    }
  }

  // Nodes were collected children-first; flip so parents come first.
  result.reverse();
  return result;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
NB_MODULE(utils, m) {
|
||||
@ -304,6 +371,13 @@ NB_MODULE(utils, m) {
|
||||
m.attr("safe_zip") = nb::steal<nb::object>(
|
||||
PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));
|
||||
|
||||
m.def("topological_sort", &TopologicalSort, nb::arg("parents_attr"),
|
||||
nb::arg("end_nodes"),
|
||||
"Computes a topological sort of a graph of objects. parents_attr is "
|
||||
"the name of the attribute on each object that contains the list of "
|
||||
"parent objects. end_nodes is an iterable of objects from which we "
|
||||
"should start a backwards search.");
|
||||
|
||||
// Python has no reader-writer lock in its standard library, so we expose
|
||||
// bindings around absl::Mutex.
|
||||
nb::class_<absl::Mutex>(m, "Mutex")
|
||||
|
@ -201,5 +201,49 @@ class SafeZipTest(jtu.JaxTestCase):
|
||||
util.safe_zip((), range(3))
|
||||
|
||||
|
||||
class Node:
  """Minimal graph node exposing the `parents` attribute that toposort reads."""

  def __init__(self, parents):
    self.parents = parents


class TopologicalSortTest(jtu.JaxTestCase):
  """Tests for `util.toposort` on small hand-built graphs."""

  def _check_topological_sort(self, nodes, order):
    """Asserts that `order` is a valid topological ordering of `nodes`.

    Args:
      nodes: every node expected to appear in the result (any order).
      order: the list returned by the sort under test.
    """
    # Same set of objects (compared by identity), each exactly once.
    self.assertEqual(sorted(nodes, key=id), sorted(order, key=id))
    visited = set()
    # Walk the *returned* order, not the expected node list: the property
    # under test is that every parent precedes its children in the output.
    for node in order:
      self.assertTrue(all(id(parent) in visited for parent in node.parents))
      visited.add(id(node))

  def test_basic(self):
    a = Node([])
    b = Node([a])
    c = Node([a])
    d = Node([a, c])
    e = Node([b, c])
    out = util.toposort([a, d, e])
    self._check_topological_sort([a, b, c, d, e], out)

  def test_stick(self):
    a = Node([])
    b = Node([a])
    c = Node([b])
    d = Node([c])
    e = Node([d])
    out = util.toposort([e])
    self._check_topological_sort([a, b, c, d, e], out)

  def test_diamonds(self):
    a = Node([])
    b = Node([a])
    c = Node([a])
    d = Node([b, c])
    e = Node([d])
    f = Node([d])
    g = Node([e, f])
    out = util.toposort([g])
    self._check_topological_sort([a, b, c, d, e, f, g], out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user