parallel
ParallelTokenCharacterTrie
Bases: TokenCharacterTrie
A GPU-optimized version of TokenCharacterTrie that performs mass_sum in parallel.
Inherits from TokenCharacterTrie.
The mass at leaf nodes is propagated to their ancestors through sparse matrix multiplication with a reachability matrix. The reachability matrix is constructed at initialization.
Implementation details:
The reachability matrix M is a num_leafs × num_nodes matrix
where M[i,j] = 1 if either condition below holds, and 0 otherwise:
- leaf_indices[i] == j (self connection), or
- j is an ancestor of leaf_indices[i] in the trie
Example:
    Trie:          M:
         0           [[1, 1, 0, 1],
        / \           [1, 0, 1, 0]]
       1   2 (leaf index = 1)
       |
       3 (leaf index = 0)
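The membership rule can be sketched in plain Python by walking from each leaf up to the root. This is a minimal illustration, not the library's construction code: the `parent` array encoding and the function name `reachability_matrix` are assumptions made for this example, and the matrix is built densely for readability.

```python
def reachability_matrix(parent, leaf_indices):
    """Build a dense num_leafs x num_nodes 0/1 matrix.

    parent[j] is the parent of node j, or None for the root
    (a hypothetical encoding chosen for this sketch).
    """
    num_nodes = len(parent)
    M = [[0] * num_nodes for _ in leaf_indices]
    for i, leaf in enumerate(leaf_indices):
        node = leaf
        while node is not None:
            M[i][node] = 1  # self connection first, then each ancestor
            node = parent[node]
    return M


# Example trie: 0 is the root, 1 and 2 are its children, 3 is a child of 1.
parent = [None, 0, 0, 1]
leaf_indices = [3, 2]  # leaf index 0 -> node 3, leaf index 1 -> node 2

M = reachability_matrix(parent, leaf_indices)
for row in M:
    print(row)
```

Each row has one entry per node on the path from that leaf to the root, so the number of nonzeros is bounded by the sum of leaf depths, which is why a sparse layout is a natural fit.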
The matrix is stored as a sparse tensor in CSR (Compressed Sparse Row) format,
built from COO (Coordinate) format. For the example trie, the COO entries are:
rows = [1, 0, 1, 0, 0] (leaf index of each entry)
cols = [2, 3, 0, 1, 0] (node index the leaf connects to)
vals = [1, 1, 1, 1, 1] (connection weights)
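To make the COO-to-CSR step concrete, here is a small pure-Python sketch of the conversion applied to the triples above. The function name `coo_to_csr` and the output names `crow_indices` and `col_indices` are choices made for this example. PyTorch offers this conversion through `Tensor.to_sparse_csr()`; whether the library uses that exact call is not stated here.

```python
def coo_to_csr(rows, cols, vals, num_rows):
    """Convert COO triples to CSR arrays (crow_indices, col_indices, values)."""
    # Sort entries by (row, col) so each row's entries are contiguous.
    order = sorted(range(len(rows)), key=lambda k: (rows[k], cols[k]))
    col_indices = [cols[k] for k in order]
    values = [vals[k] for k in order]

    # Count entries per row, then prefix-sum so that
    # crow_indices[r]:crow_indices[r + 1] delimits row r.
    crow_indices = [0] * (num_rows + 1)
    for r in rows:
        crow_indices[r + 1] += 1
    for r in range(num_rows):
        crow_indices[r + 1] += crow_indices[r]
    return crow_indices, col_indices, values


rows = [1, 0, 1, 0, 0]
cols = [2, 3, 0, 1, 0]
vals = [1, 1, 1, 1, 1]

crow_indices, col_indices, values = coo_to_csr(rows, cols, vals, num_rows=2)
print(crow_indices)  # row 0 owns entries 0..2, row 1 owns entries 3..4
print(col_indices)
```

CSR replaces the per-entry row array with one offset per row, which makes row-wise traversal during a matrix product cheap.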
When computing masses (batch_size × num_leafs) @ M, each leaf node's mass
flows up to all its ancestors, giving a (batch_size × num_nodes) result.
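The product can be traced by hand with the COO triples from the example. The sketch below is plain Python rather than a GPU sparse kernel, and the function name `propagate` is hypothetical. Every nonzero entry (r, c) adds the mass of leaf r to node c.

```python
def propagate(masses, rows, cols, vals, num_nodes):
    """Compute masses @ M, where M is given as COO triples.

    masses has shape batch_size x num_leafs (a list of lists here).
    """
    out = [[0.0] * num_nodes for _ in masses]
    for b, leaf_mass in enumerate(masses):
        for r, c, v in zip(rows, cols, vals):
            out[b][c] += leaf_mass[r] * v  # leaf r contributes to node c
    return out


rows = [1, 0, 1, 0, 0]
cols = [2, 3, 0, 1, 0]
vals = [1, 1, 1, 1, 1]

# Two distributions over the two leaves (leaf 0 is node 3, leaf 1 is node 2).
masses = [[0.25, 0.75],
          [0.5, 0.5]]

node_masses = propagate(masses, rows, cols, vals, num_nodes=4)
print(node_masses[0])
```

In the first row, the root (node 0) collects the mass of both leaves, node 1 collects only the mass of its descendant leaf 3, and the leaves keep their own mass through the self connection.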
Source code in genlm_backend/trie/parallel.py
batch_mass_sum(p_llms)
Computes mass sums for a batch of probability distributions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
p_llms | Tensor | Batch of probability distributions over tokens, shape (batch_size × vocab_size). | required |
Returns:
Type | Description |
---|---|
ndarray | Summed masses for each node in the trie, shape (batch_size × num_nodes). |
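The input is indexed by token while the reachability matrix is indexed by leaf, so some gather step from vocabulary positions to leaf positions is implied. The sketch below shows one way the shapes can line up. It is not the library's implementation: `batch_mass_sum_sketch` and `leaf_token_ids` are hypothetical names, and the one-token-per-leaf mapping is an assumption.

```python
def batch_mass_sum_sketch(p_llms, leaf_token_ids, M):
    """p_llms: batch_size x vocab_size. Returns batch_size x num_nodes."""
    num_nodes = len(M[0])
    out = []
    for p in p_llms:
        # Gather one mass per leaf, in leaf-index order (assumed mapping).
        leaf_mass = [p[t] for t in leaf_token_ids]
        # Dense equivalent of leaf_mass @ M.
        out.append([
            sum(leaf_mass[i] * M[i][j] for i in range(len(M)))
            for j in range(num_nodes)
        ])
    return out


M = [[1, 1, 0, 1],
     [1, 0, 1, 0]]
leaf_token_ids = [1, 0]  # hypothetical: leaf 0 holds token 1, leaf 1 holds token 0

p_llms = [[0.75, 0.25],
          [0.5, 0.5]]

result = batch_mass_sum_sketch(p_llms, leaf_token_ids, M)
print(len(result), len(result[0]))  # batch_size, num_nodes
```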
mass_sum(p_llm)
Computes the sum of masses for a single probability distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
p_llm | Tensor | Probability distribution over tokens from the LLM. | required |
Returns:
Type | Description |
---|---|
ndarray | Summed masses for each node in the trie. |