Persistent Hash Array Mapped Tries

Published on Saturday, August 23, 2025 - 8:30 PM CDT

Bagwell's Hash Array Mapped Trie (HAMT) is one of a few foundational data structures that enable efficient immutable collections. It was originally developed as a fast, space-efficient alternative to traditional hash tables. However, Rich Hickey (I believe) was the first to observe that HAMTs could be made persistent. Clojure has fully embraced them as the data structure of choice for its immutable hash maps and sets.

What makes HAMTs so great as a persistent data structure is the same property that makes them space-efficient: the way each intermediate node in the tree encodes a sparse collection of children. Immutable data structures are expensive to modify because you must copy (at least part of) the structure in order to change it. The less data we copy, the faster an immutable data structure is.

Each node in an HAMT is one of three types: a value node, a collision node, or an array node. Value nodes contain the key-value pair you inserted. Collision nodes contain a list of key-value pairs whose keys share the same 32-bit hash. Array nodes, the most interesting of the three, contain a sparse array of children, each of which is itself one of the three node types.
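As a rough sketch of the three node shapes, they could be written as frozen dataclasses. This is not the representation used in the full listing at the end (which uses plain classes and Python lists); the frozen dataclasses and tuples are an assumption made here so that accidental in-place mutation raises an error instead of silently corrupting an older version of the map. Field names mirror the listing.

```python
from dataclasses import FrozenInstanceError, dataclass
from typing import Any, Tuple, Union


@dataclass(frozen=True)
class ValueNode:
    key: Any
    value: Any


@dataclass(frozen=True)
class CollisionNode:
    # Every key here shares the same hash, so lookups compare keys one by one.
    kvs: Tuple[Tuple[Any, Any], ...]


@dataclass(frozen=True)
class ArrayNode:
    bitmap: int  # bit i is set when logical slot i holds a child
    nodes: Tuple["Node", ...]  # children for the set bits, in bit order


Node = Union[ValueNode, CollisionNode, ArrayNode]

# A root with a single child stored in logical slot 24.
root = ArrayNode(bitmap=1 << 24, nodes=(ValueNode("a", 1),))
try:
    root.bitmap = 0  # frozen: in-place mutation is rejected
except FrozenInstanceError:
    print("immutable")
```

Because the nodes cannot be mutated, the only way to "change" one is to build a new node, which is exactly the discipline persistence requires.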

To make the structure persistent you need to follow a few rules. When modifying a value node, you shallow-copy its parent's sparse array and re-point one slot in that copy at the new node; because the parent has changed, the same copy-and-re-point step repeats for every array node on the path back up to the root. Collisions are rare, but when they happen you shallow-copy the collision node's list and update or append the new key-value pair. When adding or deleting a key-value pair, the only extra work is that shallow copy of each sparse array on the path plus whatever operation you were asked to do.
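To see why this path copying is cheap, here is a minimal sketch that uses plain nested tuples instead of sparse array nodes. The helper `assoc_path` is hypothetical and not part of the implementation at the end of this post: it copies only the nodes on the path to the changed leaf and shares every other subtree with the old version.

```python
from collections.abc import Sequence
from typing import Any


def assoc_path(node: tuple, path: Sequence[int], value: Any) -> tuple:
    """Return a copy of `node` with the leaf at `path` replaced by `value`.

    Interior nodes are tuples of children. Only the nodes along `path`
    are copied; every other child is shared by reference.
    """
    head, rest = path[0], path[1:]
    child = value if not rest else assoc_path(node[head], rest, value)
    # Shallow copy of this level with one slot re-pointed at `child`.
    return node[:head] + (child,) + node[head + 1:]


old = (("a", "b"), ("c", "d"))
new = assoc_path(old, [1, 0], "C")

print(old)  # (('a', 'b'), ('c', 'd')): the old version is untouched
print(new)  # (('a', 'b'), ('C', 'd'))
print(new[0] is old[0])  # True: the untouched subtree is shared
```

Only two tuples were rebuilt (the root and the changed child); the left subtree is reused as-is. An HAMT does the same thing, except each copied level is a sparse array that holds only its occupied slots.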

It's not my intention to write a tutorial on implementing HAMTs, but the original paper was a bit light on implementation details, so I've included some below: namely, how the sparse array is constructed and maintained.

def insert(self, hash_bits, level, key, value):
    # Shallow copy of the array. All interior values are assumed
    # immutable.
    node = ArrayNode()
    node.bitmap = self.bitmap
    node.nodes = self.nodes.copy()

    index = index_at_level(hash_bits, level)

    if index_is_occupied(node.bitmap, index):
        sparse_index = count_after(node.bitmap, index)
        new_node = node.nodes[sparse_index].insert(hash_bits, level + 1, key, value)
        node.nodes[sparse_index] = new_node
        return node
    else:
        node.bitmap = mark_index_position(node.bitmap, index)
        sparse_index = count_after(node.bitmap, index)
        node.nodes.insert(sparse_index, ValueNode(key, value))
        return node

The code might be overwhelming at first, so let's throw out everything that doesn't matter. The key, value, and self parameters can be ignored. Let's also ignore the recursive call to "insert"; this excerpt is incomplete on its own, so dwelling on it is not productive. Finally, let's ignore the first five lines of the method. They are what make the HAMT persistent, but they're not important for understanding what the HAMT is doing.

Let's also add some context to our function parameters. The "hash_bits" parameter is the 32-bit hash of the key; it's cached to prevent recomputation at every level. The "level" parameter is our depth in the tree. HAMTs have a limited depth, which depends on how wide your array nodes are (it's common for array nodes to be 32 wide): each level consumes log2(width) bits of the hash, so a 32-bit hash supports only 32 // log2(width) full levels.
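The depth limit is simple arithmetic. This hypothetical helper, `hamt_depth`, assumes power-of-two node widths and reports how many full levels a 32-bit hash supports and how many bits are left over.

```python
def hamt_depth(node_width: int, hash_bits: int = 32) -> tuple:
    """Return (full_levels, leftover_bits) for a power-of-two node width."""
    assert node_width > 1 and (node_width & (node_width - 1)) == 0
    # For a power of two, bit_length() - 1 is log2: 32 -> 5 bits per level.
    bits_per_level = node_width.bit_length() - 1
    return hash_bits // bits_per_level, hash_bits % bits_per_level


for width in (2, 4, 8, 16, 32):
    print(width, hamt_depth(width))
# 2 (32, 0)
# 4 (16, 0)
# 8 (10, 2)
# 16 (8, 0)
# 32 (6, 2)
```

For 32-wide nodes that is six full levels with two spare bits, which is why the implementation below asserts `level < 6`. An implementation could instead use the two spare bits as a seventh, partial level.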

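There are four critical operations being performed. First, we compute the "index at our level". A 32-bit hash is a sequence of 32 zeros and ones. We can carve that hash into fixed-width chunks of 1, 2, 3, 4, or 5 bits, which gives array nodes that branch 2, 4, 8, 16, or 32 ways. In our case each chunk is 5 bits wide, so each array node has 32 slots and six chunks fit in the hash, with two bits left over: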

01 11001 00111 00100 11111 11010 11000

Our index bits are extracted by multiplying the level by the bit width of each level (5), shifting the hash right by that many bits, and masking off the lowest five bits, which yields a value from 0 to 31.

(bits >> (level * 5)) & 31  # Get the 5 relevant bits for our level
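As a worked example, here is the expression applied to the hash shown above at every level. Level 0 reads the rightmost (least significant) five bits, so the written string is consumed right to left. The variable names are just for illustration.

```python
# The example hash from above, written most significant bit first. The
# leading "01" are the two leftover high bits no 5-bit level consumes.
example = 0b01_11001_00111_00100_11111_11010_11000


def index_at_level(bits: int, level: int) -> int:
    # Shift the hash right by 5 bits per level, then keep the low 5 bits.
    return (bits >> (level * 5)) & 31


indices = [index_at_level(example, level) for level in range(6)]
print(indices)  # [24, 26, 31, 4, 7, 25]
```

So a key with this hash lives in slot 24 of the root, slot 26 of the next array node down, and so on, for as many levels as are needed to separate it from its neighbours.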

Each 5-bit segment is an index, from 0 to 31, into a logical 32-slot array at a given level. Note that this index position ignores the sparsity of the child array: the node only stores children for occupied slots, so it needs some way to record which of the 32 logical positions are occupied. We can do this space-efficiently by keeping a 32-bit bitmap in each array node, where bit i is set when slot i holds a child.

bitmap | (1 << index)  # Mark the index as occupied.
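A small sketch of the bitmap in action: marking three slots and then testing membership. The names `mark` and `is_occupied` are hypothetical stand-ins for `mark_index_position` and `index_is_occupied` in the full listing. Python ints are unbounded, so nothing here enforces the 32-bit width; it holds only because every index stays within 0 to 31.

```python
def mark(bitmap: int, index: int) -> int:
    """Set bit `index` to record that logical slot `index` is occupied."""
    return bitmap | (1 << index)


def is_occupied(bitmap: int, index: int) -> bool:
    """Test bit `index` of the bitmap."""
    return ((bitmap >> index) & 1) == 1


bitmap = 0
for index in (24, 3, 31):
    bitmap = mark(bitmap, index)

print(bin(bitmap).count("1"))  # 3
print(is_occupied(bitmap, 3), is_occupied(bitmap, 4))  # True False
```

The whole occupancy record for 32 logical slots costs a single machine word, no matter how many of those slots are actually filled.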

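To determine our node's position in the child array, we count the number of 1 bits that appear in the bitmap below our index, and then insert our node at that position. This count is called a population count; on x86 the CPU instruction is POPCNT. In Python, a common idiom is to convert the masked bitmap to a binary string with the built-in "bin" function and count the "1" characters (since Python 3.10, "int.bit_count()" does the same thing directly).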

bin(bitmap & ((1 << index) - 1)).count("1")
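Here is that computation wrapped in a hypothetical `sparse_index` helper, using `int.bit_count()` where it exists and falling back to the `bin` idiom on older Pythons. The bitmap in the example is made up for illustration.

```python
import sys


def sparse_index(bitmap: int, index: int) -> int:
    """Position in the child array for logical slot `index`."""
    # Keep only the bits strictly below `index`, then count them.
    below = bitmap & ((1 << index) - 1)
    if sys.version_info >= (3, 10):
        return below.bit_count()  # int.bit_count() was added in Python 3.10
    return bin(below).count("1")  # portable fallback for older versions


# Logical slots 2, 7 and 20 are occupied.
bitmap = (1 << 2) | (1 << 7) | (1 << 20)
print([sparse_index(bitmap, i) for i in (2, 7, 20)])  # [0, 1, 2]
print(sparse_index(bitmap, 9))  # 2
```

For an unoccupied slot such as 9, the same count tells you where a new child would go: right after the children for slots 2 and 7.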

The children in an array node are ordered by their position in the bitmap. When inserting a new node, it must be placed at the sparse index computed above. Every previously inserted node whose bitmap position is higher than the new one is implicitly shifted up by one.
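The shifting is easiest to see by inserting into slots out of order. The helper `insert_child` below is hypothetical; it mirrors the "else" branch of `ArrayNode.insert` but works on a bare (bitmap, children) pair and assumes the slot is empty.

```python
def insert_child(bitmap: int, children: list, index: int, child) -> tuple:
    """Place `child` at logical slot `index` without mutating the inputs.

    Returns a new (bitmap, children) pair. Assumes the slot is empty.
    """
    assert not ((bitmap >> index) & 1), "slot already occupied"
    pos = bin(bitmap & ((1 << index) - 1)).count("1")
    new_children = children.copy()  # shallow copy, as in ArrayNode.insert
    new_children.insert(pos, child)
    return bitmap | (1 << index), new_children


bitmap, children = 0, []
for slot, name in ((20, "c"), (2, "a"), (7, "b")):
    bitmap, children = insert_child(bitmap, children, slot, name)
    print(slot, children)
# 20 ['c']
# 2 ['a', 'c']
# 7 ['a', 'b', 'c']
```

The child for slot 20 moves from sparse position 0 to 1 to 2 even though its bitmap bit never changes; its position is always recomputed from the bits below it.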

If you're interested in learning more I've written a full Python implementation below.

def hash_(key) -> int:
    # Truncate Python's hash (which may be negative or wider than 32
    # bits) to an unsigned 32-bit value.
    return hash(key) & 0xFFFFFFFF


def index_at_level(bits: int, level: int) -> int:
    """Returns the index position in the bitmap at a given level."""
    assert 0 <= level < 6

    # Each level owns 5 bits of the hash (2 ** 5 = 32 slots per array
    # node). 32 // 5 is 6, so six full levels fit in a 32-bit hash and
    # the top two bits are unused. The mask keeps the 5 least
    # significant bits.
    mask = 31  # 0b11111

    # As we increase in levels we need to shift the hash bits by a
    # multiple of 5. This is so our mask will return the 5 bits in the
    # least significant position and give us an integer value between
    # 0 and 31.
    shift = level * 5

    # Shift the bits and mask.
    return (bits >> shift) & mask


def count_after(bitmap: int, index: int) -> int:
    """Returns the index position in the sparse array.

    Count the number of set bits in positions lower than the bit
    position. For example, given the bitmap 0b...01111 and a bit
    position of 1, return 1. Given a bit position of 4, return 4.
    """
    assert 31 >= index >= 0

    # Mask that keeps every bit in a position less than the provided
    # bit index.
    mask = (1 << index) - 1

    # Count all the 1's that remain in the bitmap after being masked.
    return bin(bitmap & mask).count("1")


def index_is_occupied(bitmap: int, index: int) -> bool:
    """Return "True" if the index was occupied."""
    assert 31 >= index >= 0

    # Shift the bitmap "index" times. If the index were 0 we shift the
    # bitmap 0 times; the least significant bit remains in its current
    # position. If the index were 31, we shift the bitmap 31 times and
    # the most significant bit is now in the least significant bit's
    # position.
    shifted = bitmap >> index

    # Mask the shifted value by 1. This gives us the value of the least
    # significant bit. If the LSB is 1 then the masked result is 1. If
    # the LSB is 0 then the masked result is 0.
    masked = shifted & 1

    # Finally we take the masked result and ask if it's equal to the
    # integer 1. If so the slot was occupied and we return True.
    return masked == 1


def mark_index_position(bitmap: int, index: int) -> int:
    """Sets the bit at the given index to 1."""
    return bitmap | (1 << index)


class HAMT:

    def __init__(self, root=None):
        self.root = ArrayNode() if root is None else root

    def get(self, key):
        key_hash = hash_(key)
        return self.root.search(key_hash, 0, key)

    def insert(self, key, value):
        return HAMT(root=self.root.insert(hash_(key), 0, key, value))


class ArrayNode:
    def __init__(self):
        self.bitmap = 0
        self.nodes = []

    def insert(self, hash_bits, level, key, value):
        # Shallow copy of the array. All interior values are assumed
        # immutable.
        node = ArrayNode()
        node.bitmap = self.bitmap
        node.nodes = self.nodes.copy()

        index = index_at_level(hash_bits, level)

        if index_is_occupied(node.bitmap, index):
            sparse_index = count_after(node.bitmap, index)
            new_node = node.nodes[sparse_index].insert(hash_bits, level + 1, key, value)
            node.nodes[sparse_index] = new_node
            return node
        else:
            node.bitmap = mark_index_position(node.bitmap, index)
            sparse_index = count_after(node.bitmap, index)
            node.nodes.insert(sparse_index, ValueNode(key, value))
            return node

    def search(self, hash_bits, level, key):
        index = index_at_level(hash_bits, level)
        if index_is_occupied(self.bitmap, index):
            node = self.nodes[count_after(self.bitmap, index)]
            return node.search(hash_bits, level + 1, key)
        raise KeyError(key)


class ValueNode:
    def __init__(self, key: int, value: int):
        self.key = key
        self.value = value

    def insert(self, hash_bits, level, key, value):
        if self.key == key:
            return ValueNode(key, value)
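        elif level >= 6:
            # Every 5-bit level (the low 30 bits of the hash) matched, so
            # there are no bits left to branch on: store both pairs in a
            # collision node.
            return CollisionNode(kvs=[(self.key, self.value), (key, value)])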
        else:
            node = ArrayNode()
            node = node.insert(hash_(self.key), level, self.key, self.value)
            node = node.insert(hash_bits, level, key, value)
            return node

    def search(self, hash_bits, level, key):
        if key == self.key:
            return self.value
        raise KeyError(key)


class CollisionNode:
    def __init__(self, kvs: list[tuple[int, int]]):
        self.kvs = kvs

    def insert(self, hash_bits, level, key, value):
        entries = []
        found_key = False
        for k, v in self.kvs:
            if k == key:
                entries.append((key, value))
                found_key = True
            else:
                entries.append((k, v))
        if not found_key:
            entries.append((key, value))
        return CollisionNode(entries)

    def search(self, hash_bits, level, key):
        for k, v in self.kvs:
            if k == key:
                return v
        raise KeyError(key)