I recently ran into a really nice performance problem while trying to traverse a graph in Django for the Toaster project at work. I don't want to bore you with the particular details, so I'll simplify the description a bit and walk through the original solution and the optimizations I made to get it to display faster.
I have a directed acyclic graph in which every node is either black or red. The graph is stored in a SQL database as Django models, and I query it using the QuerySet API. The problem I need to solve is this: starting from a node, return a list of all the black nodes reachable from it. The traversal continues through red nodes and ends at black nodes.
The simplified classes I have look like this:
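```python
from django.db import models

class Node(models.Model):
    COLOR_RED = 0
    COLOR_BLACK = 1
    COLOR_SET = (
        (COLOR_RED, 'red'),
        (COLOR_BLACK, 'black'),
    )
    color = models.IntegerField(choices=COLOR_SET)

class Edge(models.Model):
    # node.to_set: edges leaving a node (node -> edge.to_node)
    from_node = models.ForeignKey(Node, related_name='to_set', on_delete=models.CASCADE)
    # node.from_set: edges entering a node (edge.from_node -> node)
    to_node = models.ForeignKey(Node, related_name='from_set', on_delete=models.CASCADE)
```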
How big are the tables? In my practical example, the Node table holds thousands of entries, and each node has 1-10 outgoing edges (leaf nodes, obviously, have none).
My first try was a basic recursive depth-first traversal. It has the advantage of being very fast to write, easy to follow, and easy to debug.
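```python
def get_all_black_children(node):
    # one query for the outgoing edges, plus one per edge to load to_node
    child_list = [t.to_node for t in node.to_set.all()]
    # black children are where the traversal stops
    retlist = [n for n in child_list if n.color == Node.COLOR_BLACK]
    # red children are traversed recursively
    for n in [x for x in child_list if x.color == Node.COLOR_RED]:
        retlist += get_all_black_children(n)
    return retlist
```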
This approach works, but it is very slow. Each node costs at least two database queries: one for the node's properties (the Node table) and one for its children (the Edge table). The Django model API hides all that, but the queries still happen.
It gets worse: we may fetch the same child node more than once. For example, with nodes A, B and C and edges A -> B and C -> B, the data for node B is fetched twice, once through A and once through C.
The problem compounds if B is red and has a giant sub-tree underneath it. Every node in that sub-tree is re-fetched once for every red path that leads to it.
In testing, I found malignant data sets for which the traversal took 600 seconds against a localhost MySQL database that answered each query in 20 ms.
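You can reproduce the blow-up without Django. The sketch below is a pure-Python simulation, not the Toaster code. `CountingDB` and `diamond_ladder` are hypothetical stand-ins, and each `fetch_children` call counts as one round-trip (the real code issues at least two per node). A "ladder" of k stacked diamonds has only 3k red nodes but 2**k distinct paths to its single black leaf, and the naive recursion follows every one of those paths.

```python
RED, BLACK = "red", "black"


def diamond_ladder(k):
    """Hypothetical test graph: k stacked diamonds t{i} -> a{i}, b{i} -> t{i+1}.

    All of t0..t{k-1}, a{i} and b{i} are red; the final node t{k} is black.
    """
    color = {f"t{k}": BLACK}
    children = {f"t{k}": []}
    for i in range(k):
        top, left, right, nxt = f"t{i}", f"a{i}", f"b{i}", f"t{i + 1}"
        color.update({top: RED, left: RED, right: RED})
        children[top] = [left, right]
        children[left] = [nxt]
        children[right] = [nxt]
    return color, children


class CountingDB:
    """Stand-in for the database: every children lookup counts as one query."""

    def __init__(self, color, children):
        self.color = color
        self.children = children
        self.queries = 0

    def fetch_children(self, node):
        self.queries += 1
        return list(self.children[node])


def naive_black_children(db, node):
    # Same shape as get_all_black_children: no memory of visited nodes.
    child_list = db.fetch_children(node)
    result = [n for n in child_list if db.color[n] == BLACK]
    for n in child_list:
        if db.color[n] == RED:
            result += naive_black_children(db, n)
    return result


db = CountingDB(*diamond_ladder(10))
found = naive_black_children(db, "t0")
print(db.queries, len(found), set(found))  # 3069 1024 {'t10'}
```

With 10 diamonds there are only 30 red nodes, but the simulation makes 3 * (2**10 - 1) = 3069 lookups and reports the same black node 1024 times.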
My first attempt at improving the situation was to keep a cache of all the nodes already visited. My reasoning was that local memory is thousands of times faster than a database lookup, even when the database lives on the same machine and the SQL side has all the data cached in memory.
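```python
# module-level cache: red node -> list of black nodes reachable from it
_cache = {}

def get_all_black_children2(node):
    child_list = [t.to_node for t in node.to_set.all()]
    retlist = [n for n in child_list if n.color == Node.COLOR_BLACK]
    for n in [x for x in child_list if x.color == Node.COLOR_RED]:
        if n not in _cache:
            # recurse into this cached function so every level benefits
            _cache[n] = get_all_black_children2(n)
        retlist += _cache[n]
    return retlist
```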
There are several issues here, starting with the global cache, which hides later updates to the database. Still, this code works as a proof of concept. Using the cache, I reduced the computation time for the malignant examples from 600 seconds to 50 seconds, and I was happy.
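One way to keep the speed-up without a stale global is to create the cache once per top-level call and pass it down the recursion. The sketch below is again a hypothetical pure-Python simulation built on the same `CountingDB` and `diamond_ladder` stand-ins. It is not the original code. It returns a set instead of a list so that duplicates are dropped.

```python
RED, BLACK = "red", "black"


def diamond_ladder(k):
    """Hypothetical test graph: k stacked diamonds t{i} -> a{i}, b{i} -> t{i+1}."""
    color, children = {f"t{k}": BLACK}, {f"t{k}": []}
    for i in range(k):
        top, left, right, nxt = f"t{i}", f"a{i}", f"b{i}", f"t{i + 1}"
        color.update({top: RED, left: RED, right: RED})
        children.update({top: [left, right], left: [nxt], right: [nxt]})
    return color, children


class CountingDB:
    """Stand-in for the database: every children lookup counts as one query."""

    def __init__(self, color, children):
        self.color, self.children, self.queries = color, children, 0

    def fetch_children(self, node):
        self.queries += 1
        return list(self.children[node])


def memo_black_children(db, node, cache=None):
    # The cache is created per top-level call, so it cannot go stale between calls.
    if cache is None:
        cache = {}
    child_list = db.fetch_children(node)
    result = {n for n in child_list if db.color[n] == BLACK}
    for n in child_list:
        if db.color[n] == RED:
            if n not in cache:
                cache[n] = memo_black_children(db, n, cache)
            result |= cache[n]
    return result


db = CountingDB(*diamond_ladder(10))
print(memo_black_children(db, "t0"), db.queries)  # {'t10'} 30
```

On the 10-diamond ladder, each of the 30 red nodes is now fetched exactly once, instead of 3069 lookups in total.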
Before long, though, this approach rang a bell. The cache I was keeping is essentially a list of already-visited nodes, which is exactly the bookkeeping a breadth-first traversal does. So I rewrote the code as an iterative breadth-first traversal instead of a recursive depth-first one:
```python
from collections import deque

def get_all_black_children3(node):
    child_list = [t.to_node for t in node.to_set.all()]
    black_set = {n for n in child_list if n.color == Node.COLOR_BLACK}
    # FIFO queue of red nodes still to expand: popleft() gives breadth-first order
    red_queue = deque(n for n in child_list if n.color == Node.COLOR_RED)
    visited = {node}
    while red_queue:
        next_node = red_queue.popleft()
        if next_node in visited:
            continue  # already expanded via another path
        visited.add(next_node)
        child_list = [t.to_node for t in next_node.to_set.all()]
        black_set.update(n for n in child_list if n.color == Node.COLOR_BLACK)
        red_queue.extend(n for n in child_list
                         if n.color == Node.COLOR_RED and n not in visited)
    return list(black_set)
```
This was a WOW moment for me. Display time dropped from 10 seconds to 1 second, and the interface became usable. Getting rid of the recursion also made debugging far easier.
I now had a two-orders-of-magnitude improvement over the naive recursive implementation. Any chance of making it faster?
Well, we still make two distinct queries for each and every node: one to fetch the node's data and another to fetch its edges. The QuerySet API hides this, but you can see it happening if you use django-debug-toolbar.
Look at what we are doing: we fetch the same kind of information for every entry in a list. That suggests batching the SQL queries and fetching the data for the whole list of pending red nodes in one go.
This is easier to read once we have a helper function that does the batching.
```python
# fetch all children of a list of nodes in one go by selecting on Edge;
# select_related("to_node") pulls the child Node rows in through the same SQL JOIN
def get_children_for_list(node_list):
    return [e.to_node for e in
            Edge.objects.filter(from_node__in=node_list).select_related("to_node")]

def get_all_black_children4(node):
    # level 1: all children of the start node in a single query
    child_list = get_children_for_list([node])
    black_set = {n for n in child_list if n.color == Node.COLOR_BLACK}
    red_list = list({n for n in child_list if n.color == Node.COLOR_RED})
    visited = {node} | set(red_list)
    while red_list:
        # one query per depth level: every edge leaving the current red frontier
        child_list = get_children_for_list(red_list)
        black_set.update(n for n in child_list if n.color == Node.COLOR_BLACK)
        # next frontier: red children not expanded yet (the set drops duplicates)
        red_list = list({n for n in child_list
                         if n.color == Node.COLOR_RED and n not in visited})
        visited.update(red_list)
    return list(black_set)
```
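What happened here? By selecting on Edge, we went from O(N) queries (at least two per node) to O(D) queries, where D is the depth of the graph. We now issue a single query for each depth level. For a bushy graph, D is far smaller than N; for a balanced tree, it is about log N.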
Notice the select_related call. It makes Django fetch the node information in the same query as the Edge information, which saves a round-trip for every Node. This is a powerful tool in Django for eliminating SQL latency.
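A small simulation can also check both claims: one query per depth level, plus the cost of skipping select_related. This is a hypothetical pure-Python sketch with no Django involved. `fetch_edges` plays the role of `Edge.objects.filter(from_node__in=...)`, one query for a whole frontier. `fetch_node` models the extra lazy lookup Django makes for each `e.to_node` when `select_related` is not used. `layered` builds a made-up test graph.

```python
RED, BLACK = "red", "black"


def layered(depth, width):
    """Hypothetical test graph: start node "s", then `depth` levels of `width`
    nodes, each linked to every node of the next level; only the last level is black."""
    levels = [["s"]] + [[f"L{i}_{j}" for j in range(width)] for i in range(1, depth + 1)]
    color = {n: RED for level in levels[:-1] for n in level}
    color.update({n: BLACK for n in levels[-1]})
    children = {
        n: list(levels[i + 1]) if i + 1 < len(levels) else []
        for i, level in enumerate(levels)
        for n in level
    }
    return color, children


class CountingDB:
    """Stand-in for the database that counts round-trips."""

    def __init__(self, color, children):
        self.color, self.children, self.queries = color, children, 0

    def fetch_edges(self, frontier):
        # like Edge.objects.filter(from_node__in=frontier): one query for the whole list
        self.queries += 1
        return [(src, dst) for src in frontier for dst in self.children[src]]

    def fetch_node(self, node):
        # like the lazy e.to_node lookup when select_related is not used
        self.queries += 1
        return node


def batched_black_children(db, start, joined=True):
    visited, black, frontier = {start}, set(), [start]
    while frontier:
        edges = db.fetch_edges(frontier)  # one query per depth level
        # joined=True models select_related: the child rows arrived with the edges
        child_list = [dst if joined else db.fetch_node(dst) for _, dst in edges]
        black.update(n for n in child_list if db.color[n] == BLACK)
        # next level: red children not expanded yet (set() drops duplicates)
        frontier = [n for n in set(child_list) if db.color[n] == RED and n not in visited]
        visited.update(frontier)
    return black


color, children = layered(depth=4, width=50)
db = CountingDB(color, children)
print(len(batched_black_children(db, "s")), db.queries)  # 50 4
db_lazy = CountingDB(color, children)
batched_black_children(db_lazy, "s", joined=False)
print(db_lazy.queries)  # 7554
```

The batched traversal needs 4 queries for a graph of depth 4, regardless of width. Without the join, every one of the 7,550 edges triggers its own node lookup.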
With these savings I managed to display the testing data set in 200 ms, a 3000-times improvement over the initial code!
Lessons learned: latency to the database really matters, so:
- Memory caches and list operations beat trips to the database hands down. You'll say "DOOHHH", but this leads to:
- Breadth-first beats depth-first hands down when the data lives in a database, because a whole level can be fetched in one batched query.

And, more specifically for Django:
- Batch the queries by using the __in operator to bring in more rows per query.
- Use select_related to fetch all additional data in one go.