Skip to content

trie

AsyncTokenCharacterTrie

An asynchronous wrapper for TokenCharacterTrie implementations.

This class provides asynchronous access to mass sum calculations, with automatic batching of concurrent requests. It maintains a background task that processes queued requests.

Source code in genlm_backend/trie/async_impl.py
class AsyncTokenCharacterTrie:
    """An asynchronous wrapper for `TokenCharacterTrie` implementations.

    This class provides asynchronous access to mass sum calculations, with automatic batching of concurrent requests.
    It maintains a background task that processes queued requests.

    Each `mass_sum` call enqueues its distribution together with a future. The background task
    collects everything that is queued, evaluates it with a single `batch_mass_sum` call on the
    wrapped trie and resolves every future with the matching entry of the result.
    """

    def __init__(self, trie):
        """Initialize an `AsyncTokenCharacterTrie`.

        Args:
            trie (TokenCharacterTrie|ParallelTokenCharacterTrie): The underlying `TokenCharacterTrie` or `ParallelTokenCharacterTrie` instance
        """
        self.trie = trie
        self._queue = asyncio.Queue()  # pending (p_llm, future) pairs
        self._task = None  # background consumer; created lazily by `start`

    @classmethod
    def from_vocab(cls, byte_vocab, backend="parallel", **kwargs):
        """Creates an `AsyncTokenCharacterTrie` from a byte vocabulary.

        Args:
            byte_vocab (list[bytes]): The byte vocabulary over which the trie will be defined.
            backend (str, optional): The trie implementation to use - either 'sequential' or 'parallel'.
                    Defaults to 'parallel' which uses GPU acceleration when available.
            **kwargs: Additional arguments passed to the trie constructor

        Returns:
            (AsyncTokenCharacterTrie): The initialized asynchronous trie instance.

        Raises:
            ValueError: If `backend` is neither 'sequential' nor 'parallel'.
        """
        if backend == "sequential":
            trie = TokenCharacterTrie(decode=byte_vocab, **kwargs)
        elif backend == "parallel":
            trie = ParallelTokenCharacterTrie(decode=byte_vocab, **kwargs)
        else:
            raise ValueError(
                f"Unknown backend: {backend}. Must be one of ['sequential', 'parallel']"
            )
        return cls(trie)

    async def mass_sum(self, p_llm):
        """Asynchronously computes the mass at each node of the trie.

        This method queues the mass calculation to be processed in a background task.
        Multiple concurrent requests are automatically batched together.

        Args:
            p_llm (torch.Tensor): Probability distribution over the trie's vocabulary of length `len(trie.decode)`.

        Returns:
            The entry of the wrapped trie's `batch_mass_sum` output that belongs to `p_llm`,
            i.e. the masses computed for this distribution (not a single float).

        Raises:
            Exception: Whatever the batched computation raised for the batch containing this request.
            asyncio.CancelledError: If the request was still queued when `shutdown` was called.
        """
        # `start` is idempotent and also revives a background task that has finished.
        self.start()

        # Let the running loop build the future so it is bound to the loop that awaits it.
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((p_llm, future))
        return await future

    def start(self):
        """Start the background processing task if not already running.

        A task that has already finished (for example because it was cancelled from
        outside) is replaced, otherwise new requests would wait forever.
        """
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._background_loop())

    async def _do_mass_sums(self, p_llms):
        """Compute mass sums for a batch of distributions.

        Args:
            p_llms (list[torch.Tensor]): List of distributions over trie vocabulary.

        Returns:
            The output of the wrapped trie's `batch_mass_sum` for the stacked distributions.
        """
        return self.trie.batch_mass_sum(torch.stack(p_llms))  # XXX handle device

    async def _background_loop(self):
        """Background task that processes queued mass sum requests.

        Waits for a request, adds every other request that is already queued to the same
        batch, evaluates the batch and resolves the futures.

        An error raised while evaluating a batch is delivered to the futures of that batch
        and the loop keeps serving later requests. If the task is interrupted by a
        non-`Exception` error (such as cancellation) while a batch is in flight, the futures
        of that batch are cancelled before the error propagates.
        """
        while True:
            request, future = await self._queue.get()
            requests, futures = [request], [future]

            # Everything that is already waiting joins the current batch.
            while not self._queue.empty():
                request, future = self._queue.get_nowait()
                requests.append(request)
                futures.append(future)

            try:
                logger.debug(f"Processing batch of {len(requests)} requests.")
                results = await self._do_mass_sums(requests)
            except Exception as e:
                for future in futures:
                    if not future.done():
                        future.set_exception(e)
                # The callers have been notified; keep the loop alive for later requests.
                continue
            except BaseException:
                for future in futures:
                    if not future.done():
                        future.cancel()
                raise

            for future, result in zip(futures, results):
                # A caller may have cancelled its future in the meantime; setting a result
                # on it would raise and must not affect the rest of the batch.
                if not future.done():
                    future.set_result(result)

    def shutdown(self):
        """Stop the background processing task and cleanup resources.

        Requests that are still queued are cancelled so that their callers do not wait
        for a result that will never be produced.
        """
        if self._task:
            self._task.cancel()
            self._task = None

        while not self._queue.empty():
            _, future = self._queue.get_nowait()
            if not future.done():
                future.cancel()

    def __del__(self):
        # Best-effort cleanup: the object may be only partially initialised
        # (AttributeError) or its event loop may already be closed (RuntimeError).
        try:
            self.shutdown()
        except (AttributeError, RuntimeError):
            pass

__init__(trie)

Initialize an AsyncTokenCharacterTrie.

Parameters:

Name Type Description Default
trie TokenCharacterTrie | ParallelTokenCharacterTrie

The underlying TokenCharacterTrie or ParallelTokenCharacterTrie instance

required
Source code in genlm_backend/trie/async_impl.py
def __init__(self, trie):
    """Initialize an `AsyncTokenCharacterTrie`.

    Only stores the trie and creates the request queue; no background task exists
    yet (`_task` is `None`).

    Args:
        trie (TokenCharacterTrie|ParallelTokenCharacterTrie): The underlying `TokenCharacterTrie` or `ParallelTokenCharacterTrie` instance
    """
    self.trie = trie
    # Queue through which `mass_sum` hands (distribution, future) pairs to the background task.
    self._queue = asyncio.Queue()
    # Handle of the background task; stays None until `start` creates one.
    self._task = None

from_vocab(byte_vocab, backend='parallel', **kwargs) classmethod

Creates an AsyncTokenCharacterTrie from a byte vocabulary.

Parameters:

Name Type Description Default
byte_vocab list[bytes]

The byte vocabulary over which the trie will be defined.

required
backend str

The trie implementation to use - either 'sequential' or 'parallel'. Defaults to 'parallel' which uses GPU acceleration when available.

'parallel'
**kwargs

Additional arguments passed to the trie constructor

{}

Returns:

Type Description
AsyncTokenCharacterTrie

The initialized asynchronous trie instance.

Source code in genlm_backend/trie/async_impl.py
@classmethod
def from_vocab(cls, byte_vocab, backend="parallel", **kwargs):
    """Build an `AsyncTokenCharacterTrie` directly from a byte vocabulary.

    Args:
        byte_vocab (list[bytes]): Vocabulary the trie is built over; passed to the trie as `decode`.
        backend (str, optional): Which trie implementation to wrap: 'sequential' selects
            `TokenCharacterTrie`, 'parallel' (the default) selects `ParallelTokenCharacterTrie`.
        **kwargs: Forwarded unchanged to the selected trie constructor.

    Returns:
        (AsyncTokenCharacterTrie): A new instance of `cls` wrapping the freshly built trie.

    Raises:
        ValueError: If `backend` is neither 'sequential' nor 'parallel'.
    """
    # Reject unsupported backends up front so the happy path below stays flat.
    if backend not in ("sequential", "parallel"):
        raise ValueError(
            f"Unknown backend: {backend}. Must be one of ['sequential', 'parallel']"
        )

    trie_cls = (
        TokenCharacterTrie if backend == "sequential" else ParallelTokenCharacterTrie
    )
    return cls(trie_cls(decode=byte_vocab, **kwargs))

mass_sum(p_llm) async

Asynchronously computes the mass at each node of the trie.

This method queues the mass calculation to be processed in a background task. Multiple concurrent requests are automatically batched together.

Parameters:

Name Type Description Default
p_llm Tensor

Probability distribution over the trie's vocabulary of length len(trie.decode).

required

Returns:

Type Description
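ndarray

The masses computed for the given distribution, i.e. the row of the underlying trie's batch_mass_sum output that corresponds to p_llm (one value per trie node).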

Source code in genlm_backend/trie/async_impl.py
async def mass_sum(self, p_llm):
    """Asynchronously computes the mass at each node of the trie.

    This method queues the mass calculation to be processed in a background task.
    Multiple concurrent requests are automatically batched together.

    Args:
        p_llm (torch.Tensor): Probability distribution over the trie's vocabulary of length `len(trie.decode)`.

    Returns:
        The entry of the underlying trie's `batch_mass_sum` output that belongs to `p_llm`,
        i.e. the masses computed for this distribution (not a single float).

    Raises:
        Exception: Any exception the background task attached to this request's future.
    """
    if not self._task:
        self.start()

    # Let the running loop build the future so it is bound to the loop that awaits it.
    future = asyncio.get_running_loop().create_future()
    await self._queue.put((p_llm, future))
    return await future

shutdown()

Stop the background processing task and cleanup resources.

Source code in genlm_backend/trie/async_impl.py
def shutdown(self):
    """Stop the background processing task and cleanup resources.

    Cancels the background task (if any) and then cancels the futures of all
    requests that are still queued, so their callers are released instead of
    waiting for a result that will never be produced. Safe to call repeatedly.
    """
    if self._task:
        self._task.cancel()
        self._task = None

    # Nothing will consume the queue any more: release whoever is still waiting.
    while not self._queue.empty():
        _, future = self._queue.get_nowait()
        if not future.done():
            future.cancel()

start()

Start the background processing task if not already running.

Source code in genlm_backend/trie/async_impl.py
def start(self):
    """Start the background processing task if not already running.

    A task object that has already finished (cancelled or terminated by an error)
    no longer consumes the queue, so it is replaced by a fresh one.
    Must be called while an event loop is running, because the task is created
    with `asyncio.create_task`.
    """
    if self._task is None or self._task.done():
        self._task = asyncio.create_task(self._background_loop())

ParallelTokenCharacterTrie

Bases: TokenCharacterTrie

A GPU-optimized version of TokenCharacterTrie that performs mass_sum in parallel.

Inherits from TokenCharacterTrie.

The mass at leaf nodes is propagated to their ancestors through sparse matrix multiplication with a reachability matrix. The reachability matrix is constructed at initialization.

Implementation details:

The reachability matrix M is a num_leafs × num_nodes matrix
where M[i,j] = 1 if:

    - leaf_indices[i] == j (self connection) or
    - j is an ancestor of leaf_indices[i] in the trie

Example:

    Trie:          M:
         0           [[1, 1, 0, 1],
        / \           [1, 0, 1, 0]]
       1   2 (leaf index = 1)
       |
       3 (leaf index = 0)

The matrix is stored as a sparse tensor in CSR (Compressed Sparse Row) format,
built from COO (Coordinate) format. For example,

    rows = [1, 0, 1, 0, 0] (index of leaf node)
    cols = [2, 3, 0, 1, 0] (connections)
    vals = [1, 1, 1, 1, 1] (connection weights)

When computing masses (batch_size × num_leafs) @ M, each leaf node's mass
flows up to all its ancestors.
Source code in genlm_backend/trie/parallel.py
class ParallelTokenCharacterTrie(TokenCharacterTrie):
    """A GPU-optimized version of `TokenCharacterTrie` that performs `mass_sum` in parallel.

    Inherits from `TokenCharacterTrie`.

    The mass at leaf nodes is propagated to their ancestors through sparse matrix
    multiplication with a reachability matrix. The reachability matrix is constructed at initialization.

    Implementation details:\n
        The reachability matrix M is a num_leafs × num_nodes matrix
        where M[i,j] = 1 if:\n
            - leaf_indices[i] == j (self connection) or
            - j is an ancestor of leaf_indices[i] in the trie

        Example:\n
            Trie:          M:
                 0           [[1, 1, 0, 1],
                / \\           [1, 0, 1, 0]]
               1   2 (leaf index = 1)
               |
               3 (leaf index = 0)

        The matrix is stored as a sparse tensor in CSR (Compressed Sparse Row) format,
        built from COO (Coordinate) format. For example,\n
            rows = [1, 0, 1, 0, 0] (index of leaf node)
            cols = [2, 3, 0, 1, 0] (connections)
            vals = [1, 1, 1, 1, 1] (connection weights)

        When computing masses (batch_size × num_leafs) @ M, each leaf node's mass
        flows up to all its ancestors.
    """

    def __init__(self, decode, device=None, **kwargs):
        """Initialize a `ParallelTokenCharacterTrie`.

        Args:
            decode (list[bytes]): Token vocabulary, forwarded to `TokenCharacterTrie.__init__`.
            device (str|None): 'cpu' or 'cuda'. When not given, 'cuda' is used if
                `torch.cuda.is_available()` and 'cpu' otherwise.
            **kwargs: Additional arguments forwarded to `TokenCharacterTrie.__init__`.

        Raises:
            ValueError: If the resolved device is neither 'cpu' nor 'cuda'.
        """
        super().__init__(decode, **kwargs)

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if self.device not in ["cpu", "cuda"]:
            raise ValueError(f"Invalid device: {device}. Must be 'cpu', 'cuda' or None")

        self.M = self._build_reachability_matrix()
        # Column 0 of token_id_to_leaf holds the token ids, in the same row order as the
        # leaf column used to build M, so p_llms[:, token_ids] lines up with the rows of M.
        self.token_ids = torch.tensor(
            self.token_id_to_leaf[:, 0], dtype=torch.long, device=self.device
        )

    def _build_parent_map(self):
        """Builds a mapping from each node to its parent node in the trie.

        Returns:
            (dict): A dictionary where keys are child nodes and values are their parent nodes.
                The root has no entry.
        """
        parent = {}
        for node in range(len(self.children)):
            for child in self.jump[node]:
                parent[child] = node
        return parent

    def _build_reachability_matrix(self):
        """Constructs a sparse reachability matrix for efficient mass propagation.

        The matrix M is constructed such that M[i,j] = 1 if node j is either:
        - The leaf node i itself (self-connection)
        - An ancestor of leaf node i in the trie

        Returns:
            (torch.Tensor): A sparse CSR matrix of shape (num_leafs × num_nodes)
        """
        rows, cols = [], []
        leaf_indices = self.token_id_to_leaf[:, 1]

        # add self connections
        for i, node in enumerate(leaf_indices):
            rows.append(i)
            cols.append(node)

        # add all ancestor connections
        parent = self._build_parent_map()
        for i, node in enumerate(leaf_indices):
            current = node
            while current in parent:  # Walk up to root
                ancestor = parent[current]
                rows.append(i)
                cols.append(ancestor)
                current = ancestor

        indices = torch.tensor([rows, cols], dtype=torch.long, device=self.device)
        values = torch.ones(len(rows), device=self.device)
        M = torch.sparse_coo_tensor(
            indices, values, (len(leaf_indices), len(self.children))
        ).to_sparse_csr()

        return M

    def mass_sum(self, p_llm):
        """Computes the sum of masses for a single probability distribution.

        The distribution is evaluated as a batch of size one and the batch result is
        returned as is, so the output keeps a leading dimension of length 1.

        Args:
            p_llm (torch.Tensor): Probability distribution over tokens from the LLM.

        Returns:
            (numpy.ndarray): Summed masses for each node in the trie,
                shape (1 × num_nodes).
        """
        return self.batch_mass_sum(p_llm.unsqueeze(0))

    def batch_mass_sum(self, p_llms):
        """Computes mass sums for a batch of probability distributions.

        Args:
            p_llms (torch.Tensor): Batch of probability distributions over tokens,
                shape (batch_size × vocab_size).

        Returns:
            (numpy.ndarray): Summed masses for each node in the trie,
                shape (batch_size × num_nodes).
        """
        # `self.device` is a plain string while `p_llms.device` is a `torch.device`, so
        # comparing the two is not a meaningful guard. `Tensor.to` already returns the
        # tensor unchanged when it lives on the target device, so call it unconditionally.
        p_llms = p_llms.to(self.device)
        masses = torch.sparse.mm(p_llms[:, self.token_ids], self.M)
        return masses.cpu().numpy()

batch_mass_sum(p_llms)

Computes mass sums for a batch of probability distributions.

Parameters:

Name Type Description Default
p_llms Tensor

Batch of probability distributions over tokens, shape (batch_size × vocab_size).

required

Returns:

Type Description
ndarray

Summed masses for each node in the trie, shape (batch_size × num_nodes).

Source code in genlm_backend/trie/parallel.py
def batch_mass_sum(self, p_llms):
    """Computes mass sums for a batch of probability distributions.

    Selects the columns of `p_llms` given by `self.token_ids`, multiplies them with the
    reachability matrix `self.M` and returns the product as a NumPy array on the host.

    Args:
        p_llms (torch.Tensor): Batch of probability distributions over tokens,
            shape (batch_size × vocab_size).

    Returns:
        (numpy.ndarray): Summed masses for each node in the trie,
            shape (batch_size × num_nodes).
    """
    # `self.device` is a plain string while `p_llms.device` is a `torch.device`, so
    # comparing the two is not a meaningful guard. `Tensor.to` already returns the
    # tensor unchanged when it lives on the target device, so call it unconditionally.
    p_llms = p_llms.to(self.device)
    masses = torch.sparse.mm(p_llms[:, self.token_ids], self.M)
    return masses.cpu().numpy()

mass_sum(p_llm)

Computes the sum of masses for a single probability distribution.

Parameters:

Name Type Description Default
p_llm Tensor

Probability distribution over tokens from the LLM.

required

Returns:

Type Description
ndarray

Summed masses for each node in the trie, returned with a leading batch dimension of size one (shape 1 × num_nodes).

Source code in genlm_backend/trie/parallel.py
def mass_sum(self, p_llm):
    """Evaluate a single distribution by running it as a batch of size one.

    Args:
        p_llm (torch.Tensor): Probability distribution over tokens from the LLM.

    Returns:
        (numpy.ndarray): Whatever `batch_mass_sum` returns for the one-row batch, i.e. the
            summed masses for each node in the trie with a leading dimension of length 1.
    """
    # Insert the batch axis in front, then reuse the batched implementation unchanged.
    single_row_batch = p_llm.unsqueeze(0)
    return self.batch_mass_sum(single_row_batch)

TokenCharacterTrie

A trie data structure for efficient token-to-character mapping and probability mass computation.

Each node in the trie corresponds to a token prefix. The probability mass computation provides the marginal probability of each prefix under a given distribution over the token vocabulary.

Source code in genlm_backend/trie/base.py
class TokenCharacterTrie:
    """A trie data structure for efficient token-to-character mapping and probability mass computation.

    Each node in the trie corresponds to a token prefix. The probability mass computation provides the marginal
    probability of each prefix under a given distribution over the token vocabulary.
    """

    def __init__(self, decode, old_eos=None, new_eos=None):
        """Initialize a `TokenCharacterTrie`.

        Args:
            decode (list[bytes]): List of byte strings representing the token vocabulary.
            old_eos (str|bytes|None): The current end-of-sequence token to be replaced. If provided as str,
                                    will be encoded to bytes.
            new_eos (str|bytes|None): The new end-of-sequence token to use. If provided as str, will be
                                    encoded to bytes.

        Raises:
            ValueError: If `decode` contains a non-bytes element, or if the EOS arguments are
                rejected by `_convert_eos`.
        """
        if not all(isinstance(x, bytes) for x in decode):
            raise ValueError("All elements in decode must be byte strings")

        self.decode = decode
        self._convert_eos(old_eos, new_eos)
        self._build_trie()

    def _convert_eos(self, old_eos, new_eos):
        """Configure EOS token conversion settings.

        Sets `old_eos`, `new_eos`, `convert_eos` and `old_eos_id` on the instance.

        Args:
            old_eos (str|bytes|None): Original EOS token to be converted
            new_eos (str|bytes|None): New EOS token to convert to

        Raises:
            ValueError: If only one of old_eos or new_eos is provided, if `new_eos` is
                already part of the vocabulary, or if `old_eos` is not in the vocabulary.
        """
        if (old_eos is None) != (new_eos is None):
            raise ValueError(
                "Both old_eos and new_eos must be provided together, or neither should be provided"
            )

        # Test against None rather than truthiness so that an empty `str` is encoded too.
        if old_eos is not None and not isinstance(old_eos, bytes):
            old_eos = old_eos.encode("utf-8")
        if new_eos is not None and not isinstance(new_eos, bytes):
            new_eos = new_eos.encode("utf-8")

        if (new_eos is not None) and (new_eos in self.decode):
            raise ValueError(f"new_eos token {new_eos!r} already exists in vocabulary")

        self.old_eos = old_eos
        self.new_eos = new_eos
        self.convert_eos = (old_eos is not None) and (new_eos is not None)
        self.old_eos_id = self.decode.index(self.old_eos) if self.convert_eos else None

    def _build_trie(self):
        """Construct the trie structure from the vocabulary.

        Populates `children`, `root`, `word2leaf`, `leaf2word`, `token_id_to_leaf`, `jump`,
        `ordering` and `node2prefix`.
        """
        self.word2leaf = {}
        self.children = [{}]  # First node is root
        self.root = 0
        self.token_id_to_leaf = []

        for token_id, word in enumerate(self.decode):
            if self.convert_eos and word == self.old_eos:
                word = self.new_eos  # coerce old eos to new eos

            # Iterating over bytes yields ints, so edges are keyed by byte value.
            curr = self.root
            for letter in word:
                if letter not in self.children[curr]:
                    self.children[curr][letter] = len(self.children)
                    self.children.append({})
                curr = self.children[curr][letter]

            # Every token ends in a dedicated leaf reached through the `None` edge.
            self.children[curr][None] = last = len(self.children)
            self.children.append({})
            assert (
                word not in self.word2leaf
            ), "Can't have duplicate words in vocabulary"
            self.word2leaf[word] = last

            self.token_id_to_leaf.append((token_id, last))

        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))
        self.jump = List(
            [np.array(sorted(x.values()), dtype=np.int32) for x in self.children]
        )
        self.ordering = np.array(list(self._order(self.root)), np.int32)

        # Renumber the states of the trie so that they are named by a contiguous
        # range of integers that respects the topological ordering of the trie.
        # This improves the efficiency of updating the trie, as it improves
        # memory locality.
        ordering = {}
        for i, x in enumerate(self._order_full(self.root)):
            ordering[x] = i
        self._rename(f=lambda x: ordering[x])

        # After renumbering every node has a larger index than its children (the root is
        # last), so walking the indices downwards visits a parent before its children.
        node2prefix = {self.root: b""}
        for x in reversed(range(len(self.children))):
            for letter, y in self.children[x].items():
                if isinstance(letter, int):
                    letter = bytes([letter])
                if letter is None:
                    # End-of-token edge: the leaf spells the same prefix as its parent.
                    node2prefix[y] = node2prefix[x]
                else:
                    node2prefix[y] = node2prefix[x] + letter
        self.node2prefix = node2prefix

    def _rename(self, f):
        """Rename all node indices in the trie using the provided mapping function.

        Rewrites `root`, `children`, `word2leaf`, `leaf2word`, `token_id_to_leaf`,
        `ordering` and `jump` in terms of the new indices.

        Args:
            f (callable): Function that maps old node indices to new node indices
        """
        N = len(self.children)

        new_children = [{} for _ in range(N)]
        nodes = range(N)

        for x in nodes:
            for letter, y in self.children[x].items():
                new_children[f(x)][letter] = f(y)

        self.root = f(self.root)
        self.children = new_children
        self.word2leaf = {w: f(x) for w, x in self.word2leaf.items()}
        self.leaf2word = dict(zip(self.word2leaf.values(), self.word2leaf.keys()))

        self.token_id_to_leaf = np.array(
            [(i, f(x)) for i, x in self.token_id_to_leaf], dtype=np.int32
        )

        self.ordering = np.array([f(x) for x in self.ordering])
        self.jump = List(
            [np.array(sorted(x.values()), dtype=np.int32) for x in new_children]
        )

    def _alloc_mass(self):
        """Allocate an array to store probability mass values for all nodes.

        Returns:
            np.ndarray: Zero-initialized float64 array with one entry per trie node
        """
        return np.zeros(len(self.children), dtype=np.float64)

    def mass_sum(self, p_llm):
        """Compute probability mass for each node in the trie.

        Args:
            p_llm (torch.Tensor|np.ndarray): Token probabilities from language model

        Returns:
            (np.ndarray): Probability mass values for each node in the trie.
                The mass corresponds to the marginal probability under `p_llm` of the prefix represented by the node.
        """
        if isinstance(p_llm, torch.Tensor):
            if p_llm.device.type != "cpu":
                p_llm = p_llm.cpu()
            p_llm = p_llm.numpy()
        mass = self._alloc_mass()
        if self.convert_eos:
            mass[self.word2leaf[self.new_eos]] = p_llm[self.old_eos_id]
        _update_trie_numba(
            mass=mass,
            _p=p_llm,
            token_id_to_leaf=self.token_id_to_leaf,
            jump=self.jump,
            ordering=self.ordering,
        )
        return mass

    def batch_mass_sum(self, p_llms):
        """Compute probability mass for multiple distributions over tokens.

        Args:
            p_llms (list[torch.Tensor|np.ndarray]): Batch of token probability distributions

        Returns:
            (np.ndarray): Array with one row of per-node probability masses for each
                distribution in `p_llms`
        """
        return np.array([self.mass_sum(p_llm) for p_llm in p_llms])

    def _order(self, node):
        """Generate a topological ordering of nodes beneath the given node.

        Children are yielded before their parent; children reached through the
        end-of-token (`None`) edge are skipped.

        Args:
            node (int): Starting node index

        Yields:
            int: Node indices in topological order
        """
        for a in self.children[node]:
            if a is None:
                pass
            else:
                yield from self._order(self.children[node][a])
        yield node

    def _order_full(self, node):
        """Generate a complete topological ordering including all child nodes.

        Same traversal as `_order`, but end-of-token leaves are included.

        Args:
            node (int): Starting node index

        Yields:
            (int): Node indices in complete topological order
        """
        for a in self.children[node]:
            yield from self._order_full(self.children[node][a])
        yield node

    def visualize(self, mass=None):
        """Visualize the trie structure using Graphviz.

        Args:
            mass (np.ndarray|None): Optional mass vector to display at each node.
                                Should be of length `len(self.children)`.

        Returns:
            (graphviz.Digraph): The generated graph object

        Raises:
            ImportError: If the `graphviz` package cannot be imported.
            ValueError: If `mass` is given and its length differs from the number of nodes.
        """
        try:
            import graphviz
        except ImportError as err:
            raise ImportError("Please install graphviz: pip install graphviz") from err

        if mass is not None and len(mass) != len(self.children):
            raise ValueError(
                f"Mass vector length ({len(mass)}) must match number of nodes ({len(self.children)})"
            )

        dot = graphviz.Digraph(comment="Token Character Trie")
        dot.attr(rankdir="LR")

        # Create a subgraph for the legend
        with dot.subgraph(name="cluster_legend") as legend:
            legend.attr(label="Legend", fontsize="10")
            legend.attr("node", fontsize="7", width="0.1", height="0.1")

            # Example internal node
            legend.node(
                "legend_internal",
                "Internal Node ID\n'Prefix'\nMass (if provided)",
                shape="circle",
            )

            # Example leaf node
            legend.node("legend_leaf", "Complete Token", shape="doublecircle")

            legend.edge(
                "legend_internal",
                "legend_leaf",
                label="Character (Byte value)",
                fontsize="10",
            )

            # Align legend horizontally
            legend.attr(rankdir="TB")
            legend.attr(rank="same")

        # The colour scale depends only on the largest mass, so compute it once
        # instead of once per node.
        max_mass = mass.max() if mass is not None and len(mass) > 0 else None

        # Add the main trie nodes and edges
        for node_id in range(len(self.children)):
            prefix = self.node2prefix[node_id].decode("utf-8", errors="replace")

            if mass is not None:
                label = f"{node_id}\n'{prefix}'\n{mass[node_id]:.4f}"
            else:
                label = f"{node_id}\n'{prefix}'"

            # Color nodes based on mass if provided: white for no mass, greener the
            # closer a node's mass is to the maximum.
            if mass is not None and max_mass > 0:
                intensity = int(255 * (1 - mass[node_id] / max_mass))
                color = f"#{intensity:02x}{255:02x}{intensity:02x}"
            else:
                color = "#ffffff"

            # Leaves (complete tokens) are drawn with a double circle.
            shape = "doublecircle" if node_id in self.leaf2word else "circle"
            dot.node(str(node_id), label, shape=shape, style="filled", fillcolor=color)

        for node_id, children in enumerate(self.children):
            for char, child_id in children.items():
                if char is not None:
                    if isinstance(char, int):
                        s_char = bytes([char]).decode("utf-8", errors="replace")
                        edge_label = str(s_char) + f" ({char})"
                    else:
                        edge_label = str(char)
                else:
                    edge_label = "End-of-Token"

                dot.edge(str(node_id), str(child_id), label=edge_label)

        return dot

__init__(decode, old_eos=None, new_eos=None)

Initialize a TokenCharacterTrie.

Parameters:

Name Type Description Default
decode list[bytes]

List of byte strings representing the token vocabulary.

required
old_eos str | bytes | None

The current end-of-sequence token to be replaced. If provided as str, will be encoded to bytes.

None
new_eos str | bytes | None

The new end-of-sequence token to use. If provided as str, will be encoded to bytes.

None
Source code in genlm_backend/trie/base.py
def __init__(self, decode, old_eos=None, new_eos=None):
    """Initialize a `TokenCharacterTrie`.

    Validates the vocabulary, records the EOS conversion settings and then builds the trie.

    Args:
        decode (list[bytes]): List of byte strings representing the token vocabulary.
        old_eos (str|bytes|None): The current end-of-sequence token to be replaced. If provided as str,
                                will be encoded to bytes.
        new_eos (str|bytes|None): The new end-of-sequence token to use. If provided as str, will be
                                encoded to bytes.

    Raises:
        ValueError: If any element of `decode` is not a `bytes` object.
    """
    if not all(isinstance(x, bytes) for x in decode):
        raise ValueError("All elements in decode must be byte strings")

    self.decode = decode
    # The EOS settings must be in place before the trie is built.
    self._convert_eos(old_eos, new_eos)
    self._build_trie()

batch_mass_sum(p_llms)

Compute probability mass for multiple distributions over tokens.

Parameters:

Name Type Description Default
p_llms list[Tensor | ndarray]

Batch of token probability distributions

required

Returns:

Type Description
ndarray

Array of probability masses with one row per distribution in p_llms and one entry per trie node in each row.

Source code in genlm_backend/trie/base.py
def batch_mass_sum(self, p_llms):
    """Run `mass_sum` on every distribution of a batch and collect the results.

    Args:
        p_llms (list[torch.Tensor|np.ndarray]): Batch of token probability distributions

    Returns:
        (np.ndarray): The `mass_sum` results for the entries of `p_llms`, in the same order,
            gathered into a single array
    """
    # Distributions are evaluated one after another, in input order.
    per_distribution = list(map(self.mass_sum, p_llms))
    return np.array(per_distribution)

mass_sum(p_llm)

Compute probability mass for each node in the trie.

Parameters:

Name Type Description Default
p_llm Tensor | ndarray

Token probabilities from language model

required

Returns:

Type Description
ndarray

Probability mass values for each node in the trie. The mass corresponds to the marginal probability under p_llm of the prefix represented by the node.

Source code in genlm_backend/trie/base.py
def mass_sum(self, p_llm):
    """Compute the probability mass of every trie node for one token distribution.

    Args:
        p_llm (torch.Tensor|np.ndarray): Token probabilities from language model

    Returns:
        (np.ndarray): Probability mass values for each node in the trie.
            The mass corresponds to the marginal probability under `p_llm` of the prefix represented by the node.
    """
    # Torch tensors are brought to the host (if necessary) and viewed as NumPy arrays.
    if isinstance(p_llm, torch.Tensor):
        host_tensor = p_llm if p_llm.device.type == "cpu" else p_llm.cpu()
        p_llm = host_tensor.numpy()

    node_mass = self._alloc_mass()

    # With EOS conversion enabled, the leaf of the new EOS token is seeded with the
    # probability of the original EOS token id.
    if self.convert_eos:
        eos_probability = p_llm[self.old_eos_id]
        node_mass[self.word2leaf[self.new_eos]] = eos_probability

    # The helper fills `node_mass` in place.
    _update_trie_numba(
        mass=node_mass,
        _p=p_llm,
        token_id_to_leaf=self.token_id_to_leaf,
        jump=self.jump,
        ordering=self.ordering,
    )
    return node_mass

visualize(mass=None)

Visualize the trie structure using Graphviz.

Parameters:

Name Type Description Default
mass ndarray | None

Optional mass vector to display at each node. Should be of length len(self.children).

None

Returns:

Type Description
Digraph

The generated graph object

Source code in genlm_backend/trie/base.py
def visualize(self, mass=None):
    """Visualize the trie structure using Graphviz.

    Args:
        mass (np.ndarray|None): Optional mass vector to display at each node.
                            Should be of length `len(self.children)`.

    Returns:
        (graphviz.Digraph): The generated graph object

    Raises:
        ImportError: If the `graphviz` package cannot be imported.
        ValueError: If `mass` is given and its length differs from the number of nodes.
    """
    try:
        import graphviz
    except ImportError as err:
        raise ImportError("Please install graphviz: pip install graphviz") from err

    if mass is not None and len(mass) != len(self.children):
        raise ValueError(
            f"Mass vector length ({len(mass)}) must match number of nodes ({len(self.children)})"
        )

    dot = graphviz.Digraph(comment="Token Character Trie")
    dot.attr(rankdir="LR")

    # Create a subgraph for the legend
    with dot.subgraph(name="cluster_legend") as legend:
        legend.attr(label="Legend", fontsize="10")
        legend.attr("node", fontsize="7", width="0.1", height="0.1")

        # Example internal node
        legend.node(
            "legend_internal",
            "Internal Node ID\n'Prefix'\nMass (if provided)",
            shape="circle",
        )

        # Example leaf node
        legend.node("legend_leaf", "Complete Token", shape="doublecircle")

        legend.edge(
            "legend_internal",
            "legend_leaf",
            label="Character (Byte value)",
            fontsize="10",
        )

        # Align legend horizontally
        legend.attr(rankdir="TB")
        legend.attr(rank="same")

    # The colour scale depends only on the largest mass, so compute it once
    # instead of once per node.
    max_mass = mass.max() if mass is not None and len(mass) > 0 else None

    # Add the main trie nodes and edges
    for node_id in range(len(self.children)):
        prefix = self.node2prefix[node_id].decode("utf-8", errors="replace")

        if mass is not None:
            label = f"{node_id}\n'{prefix}'\n{mass[node_id]:.4f}"
        else:
            label = f"{node_id}\n'{prefix}'"

        # Color nodes based on mass if provided: white for no mass, greener the
        # closer a node's mass is to the maximum.
        if mass is not None and max_mass > 0:
            intensity = int(255 * (1 - mass[node_id] / max_mass))
            color = f"#{intensity:02x}{255:02x}{intensity:02x}"
        else:
            color = "#ffffff"

        # Leaves (complete tokens) are drawn with a double circle.
        shape = "doublecircle" if node_id in self.leaf2word else "circle"
        dot.node(str(node_id), label, shape=shape, style="filled", fillcolor=color)

    for node_id, children in enumerate(self.children):
        for char, child_id in children.items():
            if char is not None:
                if isinstance(char, int):
                    s_char = bytes([char]).decode("utf-8", errors="replace")
                    edge_label = str(s_char) + f" ({char})"
                else:
                    edge_label = str(char)
            else:
                edge_label = "End-of-Token"

            dot.edge(str(node_id), str(child_id), label=edge_label)

    return dot