@home:~$

Data Structures and Networking - Prefix Trie

My background is more on the operations side of the house, but having an understanding of data structures and algorithms is really helpful regardless of whether or not you actually do a lot of development. The further I go in my software development journey learning these topics, the better I find myself understanding networking concepts I thought I already knew a lot about. Aside from that, it’s just going to make your code better, period. So let’s spend some time looking at why that’s the case.

Let’s look at a very real-world example. Imagine a scenario where you’re given a 3-tuple that represents an ARP entry, something like (DEVICE, IP_ADDR, MAC), and you need to match that entry to an existing subnet so you can update each individual IP’s info in some system of record, and you need to do this on some set interval. Let’s also imagine this is a large dataset, measured in six figures. That isn’t hard to imagine in a multi-tenant environment, especially if you’re running MP-BGP, which also means you could have duplicates across different route tables.

A Test Data Set

Without having the actual data set handy, we can still gather a decently large data set to work with as an example: the global routing table. That gives us ~730K subnets at the time of writing this. The best way to get it is to hit up a looking glass that supports SSH, preferably a Juniper one. You can dump the active-path table (because we don’t care about non-active routes) to a text file like so:

ssh {user}@{looking-glass-server} 'show route protocol bgp active-path' > fulltable.txt

Bam. Full routing table in a text file. It’s not a small one though, and we’ll want to clean it up into something more useful. All we really need for this example is a tuple of (ROUTE, NEXT_HOP). The next hop from a looking glass will always be the same, for what it’s worth. Here’s a simple script to clean it up and dump a CSV of what we need:

import re

cidr_regex = (r'^(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)'
              r'\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)'
              r'/\d+')
ip_regex   = (r'(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)'
              r'\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)\.(25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)')

output = []
with open('fulltable.txt', 'r') as rfile:
    to_parse = rfile.readlines()
    current_pos = 0
    end_pos = len(to_parse)
    while current_pos < end_pos:
        ip_matched = re.search(cidr_regex, to_parse[current_pos])
        if ip_matched:
            # Next hop will be 2 lines down in JunOS CLI output for BGP. This won't catch
            # every single route but it will catch any route that's sourced from BGP.
            current_pos += 2
            next_hop_matched = re.search(ip_regex, to_parse[current_pos])
            if next_hop_matched:
                output.append(ip_matched.group(0) + ',' + next_hop_matched.group(0) + '\n')
                # Move one more line past the next hop before continuing
                current_pos += 1
        else:
            current_pos += 1

with open('routes-clean.txt', 'w') as out_file:
    out_file.writelines(output)
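
Each line of routes-clean.txt will then just be ROUTE,NEXT_HOP — something like 198.51.100.0/24,192.0.2.1 (made-up documentation addresses; your output will look different).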

That gives us subnets to query in a flat file. Now we just need some IP addresses to test with. We can generate test IPs for the queries by using the random module to give us 32-bit integers and then converting them to dotted-quad strings.

import random
import socket
import struct

generated_ips = []
for i in range(0, 500):
    r = random.randint(1, 4294967295)
    t = struct.pack('!I', r)
    ips = socket.inet_ntoa(t)
    generated_ips.append(ips)

The Naive Approach

First let’s take a look at a simple naive approach using loops and the standard library ipaddress module. We also need to ensure that the route we’re choosing is the longest match and not just the first match. For all 500 of our randomly generated IPs, check every entry in the global routing table to see if the IP belongs to that route’s network. If it does, compare that route’s prefix length against the best match we’ve found so far and keep whichever is longer.

import time
from ipaddress import IPv4Address, IPv4Network

with open('routes-clean.txt', 'r') as in_file:
    input = in_file.readlines()

input = [(e.split(',')[0], e.split(',')[1].strip('\n')) for e in input]

naive_answers = []
naive_start = time.time()
for i, ip in enumerate(generated_ips):
    test_ip = IPv4Address(ip)
    current_longest = (None, 0)
    current_prefix_len = 0

    for each in input:
        test_net = IPv4Network(each[0])
        
        if test_ip in test_net and test_net.prefixlen > current_prefix_len:
            current_longest = each
            current_prefix_len = int(each[0].split('/')[1])

    naive_answers.append(current_longest[1])

naive_end = time.time()
print('naive time: {}'.format(naive_end - naive_start))

And let’s run it! If you’re actually following along, you should take this time to go do something else for a while. Go cut your grass, or catch up on a missed episode of something.

drew@Drews-MacBook-Pro[~/Documents/lpm-trie] (master) ×  ❯ python3 test.py
naive time: 3178.157989025116

52 minutes?! Why the hell did that take so long? Let’s take a look at the runtime complexity of our algorithm using Big O notation. If we say M is the number of IPs we need to look up, and N is the number of subnets we can query against, then the runtime complexity of this solution in its worst case is O(MN), which is not great at all. On top of the runtime/memory used by the ipaddress module (we’ll touch on that next), for all 500 IPs we do membership checks on ~730K subnets, equating to ~365,000,000 operations. No wonder it takes so long.

Speeding It Up A Tad

Instead of using the standard library ipaddress module, we can forgo the runtime it takes to initialize all of its objects (it does a fair amount of validation and object construction under the hood) and deal simply with converting IPs to integers, doing bitwise operations on them, and comparing those integers. For more on using bitwise operations with IP addresses and network addresses, see this previous blog post.

naive_start = time.time()
naive_answers = []
ALL_ONES = (2**32) - 1
for i, ip in enumerate(generated_ips):
    current_longest = None
    current_prefix_len = 0

    octets = ip.split('.')
    oct1 = int(octets[0]) << 24
    oct2 = int(octets[1]) << 16
    oct3 = int(octets[2]) << 8
    oct4 = int(octets[3])
    test_ip = oct1 + oct2 + oct3 + oct4

    for each in input:
        octets = each[0].split('/')[0].split('.')
        mask = int(each[0].split('/')[1])
        oct1 = int(octets[0]) << 24
        oct2 = int(octets[1]) << 16
        oct3 = int(octets[2]) << 8
        oct4 = int(octets[3])
        test_start = oct1 + oct2 + oct3 + oct4
        # Wildcard mask (the inverted netmask) gives us the end of the network
        test_wildcard = ALL_ONES >> mask
        test_end = test_start | test_wildcard

        if test_start <= test_ip <= test_end and mask > current_prefix_len:
            current_longest = each
            current_prefix_len = mask

    if current_longest and current_longest[0] != '0.0.0.0/0':
        naive_answers.append(current_longest[1])
    else:
        naive_answers.append([])

naive_end = time.time()
print('naive time: {}'.format(naive_end - naive_start))

And if we run it again…

drew@Drews-MacBook-Pro[~/Documents/lpm-trie] (master) ×  ❯ python3 test.py
naive time: 885.8704960346222

Down to ~15 minutes. Doing bitwise operations is definitely faster. But this is still nowhere near as fast as it could be, and we’re still burning too much time/CPU on something that needs to run on an interval.

Enter the Prefix Trie

In order to speed this up, we need to think about how an IP address/subnet is viewed by the computer instead of by humans. Humans represent an IP address as a string, 4 octets delimited by periods. A computer views an IP address as a 32-bit integer, composed in binary form of 0s and 1s. We can use a data structure, a binary trie in this case, to help us do lookups more efficiently. A binary trie is a tree of nodes where each node has at most 2 children. In the case of a prefix trie, the left child represents a 0 bit and the right child represents a 1 bit. Therefore we can start at the root node and traverse down the trie looking for prefix matches until we either find the longest prefix match or we hit the bottom of the trie.

For the sake of keeping things simple and making the diagram easier to look at, let’s imagine we have 3 routes, each with a /4 mask (netmask of 240.0.0.0):

P1 - 48.0.0.0/4  - 00110000000000000000000000000000
P2 - 192.0.0.0/4 - 11000000000000000000000000000000
P3 - 128.0.0.0/4 - 10000000000000000000000000000000

[diagram: a simple prefix trie containing P1, P2, and P3]

In our simple example, all possible nodes are already populated with null values, but in the example code below we only insert nodes into the trie as needed. Using P1, starting at the root of the trie, we traverse left since the first bit is a 0. Next we check to see if a prefix exists on that node and update our current longest prefix match if it does. There’s no value there, so the longest prefix match is still none, and we keep moving down: left again for the second 0, then right twice for the two 1 bits. When we reach the 4th level of the trie there are no more children to traverse, and since the node at this level DOES have a prefix, that is our longest prefix match.
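
To make that bit-walk concrete, here’s a tiny illustrative snippet (not part of the trie code below) that converts P1’s network address into its 32-bit binary string; the first 4 bits tell us which way to turn at each level:

# Illustrative only: convert 48.0.0.0 (P1) to a 32-bit binary string
octets = [int(o) for o in '48.0.0.0'.split('.')]
as_int = (octets[0] << 24) | (octets[1] << 16) | (octets[2] << 8) | octets[3]
bits = '{0:032b}'.format(as_int)
print(bits[:4])  # '0011' -> left, left, right, right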

So we’ll need to create 2 objects: a node to be inserted into the trie, and the prefix trie itself, which creates the root node when instantiated. We’ll also need a method for inserting into the trie, which takes the CIDR notation for the route and its next hop, and another for searching the trie by IP address.

from ipaddress import IPv4Address

class PrefixTrieNode(object):
    
    def __init__(self, prefix=''):
        self.left = None
        self.right = None
        self.prefix = prefix
        # Since there WILL be collisions, make this an array, it's up to the 
        # caller to handle logic on collisions
        self.value = []

class PrefixTrie(object):
    
    def __init__(self):
        self.root = PrefixTrieNode()
    
    def _child_node_count(self, node):
        # These should cover base cases
        right_count = 0
        left_count = 0
        if node.left:
            left_count = 1 + self._child_node_count(node.left)
        if node.right:
            right_count = 1 + self._child_node_count(node.right)
        
        return left_count + right_count

    def __len__(self):
        return self._child_node_count(self.root)

    def _insert(self, bit_array, value):
        current_node = self.root
        prefix_len = len(bit_array)
        # A zero-length prefix (the default route) belongs on the root node itself
        if prefix_len == 0:
            current_node.value.append(value)
            return
        for i, bit in enumerate(bit_array):
            if int(bit) == 0:
                if not current_node.left:
                    current_node.left = PrefixTrieNode(current_node.prefix + bit)
                current_node = current_node.left
            elif int(bit) == 1:
                if not current_node.right:
                    current_node.right = PrefixTrieNode(current_node.prefix + bit)
                current_node = current_node.right

            if i == (prefix_len - 1):
                current_node.value.append(value)

    def insert(self, cidr, nexthop):
        '''
        Convert a CIDR notation string into its binary equivalent, truncate
        it to the prefix length, and insert the next hop into the trie at
        the proper spot.
        '''

        network, mask = cidr.split('/')

        octets = network.split('.')
        oct1 = int(octets[0]) << 24
        oct2 = int(octets[1]) << 16
        oct3 = int(octets[2]) << 8
        oct4 = int(octets[3])

        boundary_int = oct1 + oct2 + oct3 + oct4
        boundary = '{0:032b}'.format(boundary_int)[:int(mask)]
        self._insert(boundary, nexthop)

    def _search(self, bit_array):
        '''
        Given IPv4 address in binary representation as bit array, search
        trie and return longest prefix match or None if not found.
        '''

        current_node = self.root
        current_longest = None
        for bit in bit_array:
            if int(bit) == 0:
                current_node = current_node.left
            elif int(bit) == 1:
                current_node = current_node.right

            if not current_node:
                break
            # Only remember nodes that actually have a prefix stored on them
            if current_node.value:
                current_longest = current_node

        return current_longest.value if current_longest else None

    def search(self, ip_addr):
        '''
        Given an IPv4 address as a string, convert to binary representation
        then search trie.
        '''

        # Let's not reinvent the wheel, ipaddress module is good
        # enough right here. This will also validate input IP.
        ip_object = IPv4Address(ip_addr)
        ip_bits = '{0:032b}'.format(int(ip_object))

        return self._search(ip_bits)
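
Before pointing this at the full table, here’s a quick sanity check using the example routes from the diagram plus a more specific /16 to show the longest match winning (the next-hop values are made up for illustration):

# Quick sanity check of the trie above. Next-hop values are made up.
t = PrefixTrie()
t.insert('48.0.0.0/4', 'hop-a')
t.insert('192.0.0.0/4', 'hop-b')
t.insert('192.168.0.0/16', 'hop-c')

print(t.search('50.1.2.3'))      # ['hop-a'], falls under 48.0.0.0/4
print(t.search('192.168.10.1'))  # ['hop-c'], the /16 beats the /4
print(t.search('8.8.8.8'))       # None, no matching prefix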

Let’s modify our script a bit. We want to loop over our global route table output and insert every prefix into our trie. Once we generate our random test IPs again, we want to loop over those and search the trie for the longest prefix match on each of them. We’ll also keep track of how long it takes to both load the trie and search it.

load_start = time.time()
test_trie = trie.PrefixTrie()
for entry in input:
    test_trie.insert(entry[0], entry[1])
load_end = time.time()

generated_ips = []
for i in range(0, 500):
    r = random.randint(1, 4294967295)
    t = struct.pack('!I', r)
    ips = socket.inet_ntoa(t)
    generated_ips.append(ips)

search_start = time.time()
answers = []
for i in generated_ips:
    result = test_trie.search(i)
    answers.append(result)
search_end = time.time()

print('load time: {}'.format(load_end - load_start))
print('search time: {}'.format(search_end - search_start))

And if we run this again…

drew@Drews-MacBook-Pro[~/Documents/lpm-trie] (master) ×  ❯ python3 test.py
load time: 12.402899980545044
search time: 0.0062520503997802734

We’ve cut our runtime from 52 minutes down to about 13 seconds, most of which is the one-time cost of loading the trie! To understand why, let’s take another look at the runtime complexity when using a trie. The difference here is that we don’t need to loop over every route in the global routing table. In a worst-case scenario, we need to do 32 memory lookups (if it’s a host route) to find our longest prefix, and maybe 8 memory lookups in the best case. For the sake of comparison, let’s assume the entire routing table consisted of /24s. Searching 500 IPs at 24 lookups each would equate to 12,000 operations. That’s much, much less than our original 365,000,000.

But could it be even better?

Sure. The objects created for this use case are pretty limited to what we’re doing here. We also end up with a lot of nodes in the trie that hold no value. We could take this a step further and compress the trie in those cases (a path-compressed, or Patricia, trie). Python has modules available on PyPI, like pytricia, that do just this, and in a real-world scenario you’d be better off using an existing module like that instead of rolling your own.
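
As a rough sketch of what that looks like with pytricia (the prefixes and next hops here are made up, and it’s worth double-checking the module’s docs for the exact API):

import pytricia

pt = pytricia.PyTricia()
pt['10.0.0.0/8'] = '192.0.2.1'     # made-up next hops
pt['10.1.0.0/16'] = '192.0.2.2'

print(pt['10.1.2.3'])          # '192.0.2.2' -- longest prefix match wins
print(pt.get_key('10.1.2.3'))  # '10.1.0.0/16'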

As you can see, using the right data structure can make your code much faster, no matter what language you use or how “fast” that language is compared to others. But more on that in a later post.