What Is a Topological Sort?
A topological sort is a linear ordering of the nodes in a directed acyclic graph (DAG) such that for every
directed edge from node u to node v, u appears before v in the ordering. In other words, each node
appears only after all of its inputs.
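To make the definition concrete, here is a small checker. It is a sketch, and the node names describe a hypothetical graph for out = 3x + x². An ordering passes only if every edge (u, v) places u before v.

```python
def is_topological(order, edges):
    """Return True if every edge (u, v) has u placed before v in `order`."""
    position = {node: i for i, node in enumerate(order)}
    return all(position[u] < position[v] for u, v in edges)

# Hypothetical edges for out = 3x + x^2: x feeds both terms, both terms feed out.
edges = [("x", "3x"), ("x", "x^2"), ("3x", "out"), ("x^2", "out")]

print(is_topological(["x", "3x", "x^2", "out"], edges))  # True
print(is_topological(["3x", "x", "x^2", "out"], edges))  # False: 3x comes before its input x
```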
One property of topological orderings is that they’re not necessarily unique. A graph can have multiple
valid orderings, and they all work for our purposes.
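The non-uniqueness shows up even on a four-node graph. This brute-force sketch uses the same hypothetical node names for out = 3x + x². It tries every permutation and keeps the ones that respect all edges. That approach is exponential in the number of nodes, so it is only for illustration.

```python
from itertools import permutations

nodes = ["x", "3x", "x^2", "out"]
edges = [("x", "3x"), ("x", "x^2"), ("3x", "out"), ("x^2", "out")]

def respects_edges(order):
    # For every edge u -> v, u must come earlier in the ordering than v.
    pos = {n: i for i, n in enumerate(order)}
    return all(pos[u] < pos[v] for u, v in edges)

valid = [order for order in permutations(nodes) if respects_edges(order)]
for order in valid:
    print(order)
# ('x', '3x', 'x^2', 'out')
# ('x', 'x^2', '3x', 'out')
```

Both orderings put x first and out last. Only the two independent terms can swap places, and either order works for backpropagation.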
The Algorithm: Post-Order DFS
Typically, a topological sort is done through a post-order depth-first search (DFS). In a post-order DFS, a node
is appended only after all of its descendants (here, its inputs) have been visited, so every input lands in the
list before any node that uses it. For backpropagation, we then reverse the list so that the outputs come before
the inputs. We need that order because the backward pass pushes gradients from the output back toward the inputs,
and a node's gradient must be fully accumulated before it is passed further back.
```python
topo = []
visited = set()

def build_topo(v):
    if v not in visited:
        visited.add(v)
        for child in v._prev:  # sort all dependencies before v
            build_topo(child)
        topo.append(v)  # only add v after its dependencies

build_topo(root)  # root is the graph's output node
```
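To see the ordering that build_topo produces, here is a runnable sketch. The Node class is a hypothetical stand-in that has only a label and the _prev attribute the algorithm relies on. The graph is out = 3x + x², built by hand.

```python
class Node:
    """Hypothetical stand-in for a graph node: `_prev` holds its input nodes."""
    def __init__(self, label, inputs=()):
        self.label = label
        self._prev = set(inputs)  # a set, so x * x stores x only once

x = Node("x")
three = Node("3")
t1 = Node("3x", (three, x))
t2 = Node("x^2", (x, x))
out = Node("out", (t1, t2))

topo = []
visited = set()

def build_topo(v):
    if v not in visited:
        visited.add(v)
        for child in v._prev:  # visit every input first
            build_topo(child)
        topo.append(v)         # post-order: v goes in after its inputs

build_topo(out)
print([n.label for n in topo])            # inputs first, out last
print([n.label for n in reversed(topo)])  # out first, as the backward pass needs
```

Because _prev is a set, sibling order (for example, whether 3x or x^2 comes first) can vary between runs. However, x always precedes both terms, and out is always last in topo.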
Writing the Algorithm
This algorithm is most easily written as a recursive function. Before we call it, we need to create
two variables: topo, which stores the topological ordering, and visited, which records every node
visited so far and is shared across all of the recursive calls.
```python
topo = []
visited = set()
```
Now, let’s look over build_topo line by line.
```python
def build_topo(v):
    if v not in visited:       # (1)
        visited.add(v)         # (2)
        for child in v._prev:  # (3)
            build_topo(child)  # (4)
        topo.append(v)         # (5)
```
(1) — Before doing anything, we check whether this node has already been visited. If it has, the call does nothing and returns. This prevents a node from being appended to topo more than once.
(2) — Nodes are marked as visited before we recurse into their inputs. Combined with step 1, this ensures each node is processed only once, even when it feeds into several other nodes (e.g., the x that is used by both 3x and x²).
(3–4) — We recurse into every input node before appending v itself. Fully exploring each input's subgraph before moving on is what makes this a depth-first search.
(5) — v is appended only after its descendants have been processed. This post-order property is what gives us the topological ordering.
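Steps (1) and (2) matter whenever a node is shared. This sketch again uses a hypothetical minimal Node class. It runs the traversal on out = 3x + x² with and without the visited check and counts how many times x is appended.

```python
class Node:
    """Hypothetical minimal node: `_prev` holds the node's inputs."""
    def __init__(self, label, inputs=()):
        self.label = label
        self._prev = set(inputs)

x = Node("x")
t1 = Node("3x", (Node("3"), x))
t2 = Node("x^2", (x, x))  # the set collapses (x, x) to {x}
out = Node("out", (t1, t2))

def topo_without_visited(v, topo):
    # Steps (1) and (2) removed: a shared input is appended once per path to it.
    for child in v._prev:
        topo_without_visited(child, topo)
    topo.append(v)
    return topo

def topo_with_visited(v, topo, visited):
    if v not in visited:                            # (1)
        visited.add(v)                              # (2)
        for child in v._prev:                       # (3)
            topo_with_visited(child, topo, visited) # (4)
        topo.append(v)                              # (5)
    return topo

naive = topo_without_visited(out, [])
guarded = topo_with_visited(out, [], set())
print(sum(n is x for n in naive))    # 2: x is reached through both 3x and x^2
print(sum(n is x for n in guarded))  # 1
```

Without the check, x would be appended twice and its _backward() would later run twice.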
Putting It in Context
Here’s the complete backward() method from lesson 2:
```python
def backward(self):
    topo = []
    visited = set()

    def build_topo(v):
        if v not in visited:
            visited.add(v)
            for child in v._prev:
                build_topo(child)
            topo.append(v)

    build_topo(self)

    for node in topo:
        node.grad = 0.0          # reset gradients left over from an earlier pass
    self.grad = 1.0              # d(output)/d(output) = 1
    for node in reversed(topo):  # output first, leaves last
        node._backward()
```
build_topo(self) produces the topological ordering. The first loop then resets every node's grad to 0.0 so that
gradients from an earlier call don't pile up, and self.grad is seeded with 1.0 because the derivative of the output
with respect to itself is always 1. Finally, iterating over reversed(topo) starts at the output and works back
toward the inputs. Each node._backward() call propagates that node's accumulated gradient one step further toward the leaves.
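To see the whole method run end to end, here is a minimal sketch of a scalar Value class that supports only + and *. The class, its constructor arguments, and the _backward closures are assumptions modeled on the backward() method above, not the lesson's exact code. The example computes the derivative of y = 3x + x² at x = 2, where dy/dx = 3 + 2x = 7.

```python
class Value:
    """Hypothetical minimal scalar autograd node modeled on backward() above."""
    def __init__(self, data, _children=()):
        self.data = data
        self.grad = 0.0
        self._prev = set(_children)
        self._backward = lambda: None  # leaves have nothing to propagate

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))
        def _backward():
            # d(a + b)/da = d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))
        def _backward():
            # d(a * b)/da = b and d(a * b)/db = a; += handles x * x correctly
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    def backward(self):
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        for node in topo:
            node.grad = 0.0
        self.grad = 1.0
        for node in reversed(topo):
            node._backward()

x = Value(2.0)
y = x * 3 + x * x  # y = 3x + x^2
y.backward()
print(y.data)  # 10.0
print(x.grad)  # 7.0
```

Calling y.backward() a second time leaves x.grad at 7.0 rather than 14.0, because the zeroing loop clears the previous pass's gradients before the new one starts.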