Skip to content

llm

AsyncLM

Bases: ABC

Abstract base class for asynchronous language models.

This class provides an interface for language models that can generate token probabilities asynchronously. It handles tokenization and vocabulary management.

Parameters:

Name Type Description Default
tokenizer

A Hugging Face tokenizer instance compatible with the language model

required
Source code in genlm_backend/llm/base.py
class AsyncLM(ABC):
    """Interface for language models that score next tokens asynchronously.

    Subclasses implement single-prompt next-token log probability queries in both an
    asynchronous and a synchronous flavour; the batched variants defined here are built
    on top of those two primitives. On construction the tokenizer's vocabulary is
    decoded into byte and string forms.

    Args:
        tokenizer: A Hugging Face tokenizer instance compatible with the language model
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # decode_vocab yields the (byte_vocab, str_vocab) pair for this tokenizer.
        vocabularies = decode_vocab(self.tokenizer)
        self.byte_vocab, self.str_vocab = vocabularies

    @abstractmethod
    async def next_token_logprobs(self, token_ids):
        """Asynchronously compute log probabilities for the token following a prompt.

        Args:
            token_ids (list[int]): Token IDs making up the prompt.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """

    @abstractmethod
    def next_token_logprobs_sync(self, token_ids):
        """Synchronously compute log probabilities for the token following a prompt.

        Args:
            token_ids (list[int]): Token IDs making up the prompt.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """

    async def batch_next_token_logprobs(self, token_ids_list):
        """Score several prompts concurrently and stack the results.

        Args:
            token_ids_list (list[list[int]]): One list of token IDs per prompt.

        Returns:
            (torch.Tensor): The per-prompt log probability tensors stacked along a new
                leading dimension, in the same order as `token_ids_list`.
        """
        pending = [self.next_token_logprobs(ids) for ids in token_ids_list]
        per_prompt = await asyncio.gather(*pending)
        return torch.stack(per_prompt)

    def batch_next_token_logprobs_sync(self, token_ids_list):
        """Score several prompts one after another and stack the results.

        Args:
            token_ids_list (list[list[int]]): One list of token IDs per prompt.

        Returns:
            (torch.Tensor): The per-prompt log probability tensors stacked along a new
                leading dimension, in the same order as `token_ids_list`.
        """
        per_prompt = list(map(self.next_token_logprobs_sync, token_ids_list))
        return torch.stack(per_prompt)

    def clear_cache(self):
        """Drop any cached model state. The base implementation does nothing."""
        return None

batch_next_token_logprobs(token_ids_list) async

Batch request log probabilities for multiple token sequences asynchronously.

Parameters:

Name Type Description Default
token_ids_list list[list[int]]

A list of token ID lists.

required

Returns:

Type Description
Tensor

The per-sequence log probability tensors stacked along a new leading dimension, one entry per input sequence.

Source code in genlm_backend/llm/base.py
async def batch_next_token_logprobs(self, token_ids_list):
    """Score several prompts concurrently and stack the results.

    Args:
        token_ids_list (list[list[int]]): One list of token IDs per prompt.

    Returns:
        (torch.Tensor): The per-prompt log probability tensors stacked along a new
            leading dimension, in the same order as `token_ids_list`.
    """
    pending = [self.next_token_logprobs(ids) for ids in token_ids_list]
    per_prompt = await asyncio.gather(*pending)
    return torch.stack(per_prompt)

batch_next_token_logprobs_sync(token_ids_list)

Batch request log probabilities for multiple token sequences synchronously.

Parameters:

Name Type Description Default
token_ids_list list[list[int]]

A list of token ID lists.

required

Returns:

Type Description
Tensor

The per-sequence log probability tensors stacked along a new leading dimension, one entry per input sequence.

Source code in genlm_backend/llm/base.py
def batch_next_token_logprobs_sync(self, token_ids_list):
    """Score several prompts one after another and stack the results.

    Args:
        token_ids_list (list[list[int]]): One list of token IDs per prompt.

    Returns:
        (torch.Tensor): The per-prompt log probability tensors stacked along a new
            leading dimension, in the same order as `token_ids_list`.
    """
    per_prompt = list(map(self.next_token_logprobs_sync, token_ids_list))
    return torch.stack(per_prompt)

clear_cache()

Clear any caches used by the language model. No-op in base class.

Source code in genlm_backend/llm/base.py
def clear_cache(self):
    """Drop any cached model state. The base implementation does nothing."""
    return None

next_token_logprobs(token_ids) abstractmethod async

Request log probabilities of next token asynchronously.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs representing the prompt.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm_backend/llm/base.py
@abstractmethod
async def next_token_logprobs(self, token_ids):
    """Asynchronously compute log probabilities for the token following a prompt.

    Subclasses must override this; the base version does nothing.

    Args:
        token_ids (list[int]): Token IDs making up the prompt.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """

next_token_logprobs_sync(token_ids) abstractmethod

Request log probabilities of next token synchronously.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs representing the prompt.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm_backend/llm/base.py
@abstractmethod
def next_token_logprobs_sync(self, token_ids):
    """Synchronously compute log probabilities for the token following a prompt.

    Subclasses must override this; the base version does nothing.

    Args:
        token_ids (list[int]): Token IDs making up the prompt.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """

AsyncTransformer

Bases: AsyncLM

Asynchronous wrapper around a HuggingFace causal language model with caching support.

This class provides an asynchronous interface to HuggingFace language models with automatic batching and caching (output and KV) for improved efficiency.

Source code in genlm_backend/llm/hf.py
class AsyncTransformer(AsyncLM):
    """Asynchronous wrapper around a HuggingFace causal language model with caching support.

    This class provides an asynchronous interface to HuggingFace language models with automatic batching
    and caching (output and KV) for improved efficiency.
    """

    @classmethod
    def from_name(cls, model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs):
        """Create an AsyncTransformer instance from a pretrained HuggingFace model.

        Args:
            model_id (str): Model identifier in HuggingFace's model hub.
            bitsandbytes_opts (dict, optional): Additional configuration options for bitsandbytes quantization.
                Defaults to None.
            hf_opts (dict, optional): Additional configuration options for loading the HuggingFace model.
                Defaults to None.
            **kwargs: Additional arguments passed to the `AsyncTransformer` constructor

        Returns:
            (AsyncTransformer): An initialized `AsyncTransformer` instance.
        """
        if bitsandbytes_opts:
            bnb_config = BitsAndBytesConfig(**bitsandbytes_opts)
        else:
            bnb_config = None

        # Loading defaults; entries supplied in `hf_opts` override them.
        _hf_opts = {
            "device_map": "auto",
            "torch_dtype": "auto",
        }
        if hf_opts:
            _hf_opts.update(hf_opts)

        tok = AutoTokenizer.from_pretrained(model_id)
        mod = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=bnb_config, **_hf_opts
        )

        return cls(mod, tok, **kwargs)

    @torch.no_grad()
    def __init__(self, hf_model, hf_tokenizer, batch_size=20, timeout=0.02):
        """Initialize an AsyncTransformer instance.

        Args:
            hf_model: A HuggingFace CausalLM model instance.
            hf_tokenizer: A HuggingFace Tokenizer.
            batch_size (int, optional): Maximum queries to process in one batch during auto-batching.
                Defaults to 20.
            timeout (float, optional): Seconds to wait since last query before processing current batch.
                Defaults to 0.02.
        """
        self.model = hf_model
        self.tokenizer = hf_tokenizer
        self.device = hf_model.device
        self.cache = TokenTrie()

        # Queries to be batched. Each query is a sequence of tokens,
        # and a Future to be called when the query is resolved.
        self.queries = []
        self.batch_size = batch_size
        self.timeout = timeout
        self.timer = None

        self.model.eval()

        super().__init__(tokenizer=self.tokenizer)

    def clear_cache(self):
        """Clear the cache of log probabilities and key/value pairs.

        A fresh trie is created; only the previous root's `logprobs` are carried over.
        """
        self.cache = TokenTrie(None, self.cache.logprobs)

    def clear_kv_cache(self):
        """Clear any key and value vectors from the cache."""
        self.cache.clear_kv_cache()

    def reset_async_queries(self):
        """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing
        to completion."""
        self.queries = []

    @torch.no_grad()
    def cache_kv(self, prompt_tokens):
        """Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

        Args:
            prompt_tokens (list[int]): token ids for the prompt to cache.

        Raises:
            ValueError: If `prompt_tokens` is empty.
        """
        if not prompt_tokens:
            raise ValueError("Token ids must not be empty")

        # Ask for key/value states explicitly instead of relying on the model
        # config's `use_cache` default.
        result = self.model(
            torch.tensor([prompt_tokens]).to(self.device), use_cache=True
        )

        # Attach the prompt under the deepest already-cached prefix, the same way
        # `next_token_logprobs` does: `extend_cache` is called on the node reached
        # after `next_token_index` tokens (the root is depth 0). The logits cover
        # the whole prompt, so base is 0.
        node, next_token_index, _, _ = self.walk_cache(prompt_tokens)
        if next_token_index < len(prompt_tokens):
            node = node.extend_cache(
                next_token_index, prompt_tokens, result.logits[0], 0
            )
        node.past_key_values = result.past_key_values

    @torch.no_grad()
    def batch_evaluate_queries(self):
        """
        Process a batch of queued language model queries.

        This method is called internally when the `batch_size` has been met or the `timeout` has expired.

        Queries with identical prompt tokens *and* the same past key/value object are
        evaluated once and share the result. If building the batch or running the model
        fails, the exception is set on every pending future so awaiting callers receive
        it instead of waiting forever.
        """

        queries, self.queries = self.queries, []
        if not queries:
            return

        # Group duplicate requests. The key must include the identity of the past
        # key/value states: the same suffix tokens evaluated on top of different
        # cached prefixes (or on no prefix) produce different logits. Every query
        # holds a reference to its past, so the ids are stable while grouping.
        query_groups = defaultdict(list)
        for query in queries:
            query_groups[(tuple(query.prompt), id(query.past))].append(query)

        groups = list(query_groups.values())
        # Use one representative query from each group
        unique_queries = [group[0] for group in groups]

        try:
            past_example = next((q.past for q in unique_queries if q.past), False)
            max_past_length = max(q.past_len for q in unique_queries)
            max_query_length = max(len(q.prompt) for q in unique_queries)

            padding_token_id = (
                self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else 0
            )

            input_ids = torch.tensor(
                [
                    q.prompt_padded(padding_token_id, max_query_length)
                    for q in unique_queries
                ]
            ).to(self.device)
            attn_masks = torch.tensor(
                [
                    q.attention_mask(max_past_length, max_query_length)
                    for q in unique_queries
                ]
            ).to(self.device)
            posn_ids = torch.tensor(
                [
                    q.position_ids(max_past_length, max_query_length)
                    for q in unique_queries
                ]
            ).to(self.device)
            if past_example:
                # Concatenate each layer's key (j=0) and value (j=1) tensors across
                # the batch, padding every query's past to max_past_length.
                pasts = [
                    [
                        torch.cat(
                            (
                                *(
                                    q.past_padded(
                                        layer,
                                        j,
                                        max_past_length,
                                        past_example[0][0].dtype,
                                        self.device,
                                        past_example[0][0].shape,
                                    )
                                    for q in unique_queries
                                ),
                            ),
                            dim=0,
                        )
                        for j in range(2)
                    ]
                    for layer in range(len(past_example))
                ]
            else:
                pasts = None

            results = self.model(
                input_ids,
                attention_mask=attn_masks,
                position_ids=posn_ids,
                past_key_values=pasts,
                use_cache=pasts is not None,
            )

            assert len(results.logits) == len(unique_queries)
        except Exception as exc:
            # Deliver the failure to every waiter; otherwise their awaits never return.
            for query in queries:
                if not query.future.done():
                    query.future.set_exception(exc)
            return

        for group, logits in zip(groups, results.logits):
            for dup_query in group:
                # A future is already done if its awaiting task was cancelled;
                # set_result would raise InvalidStateError and strand the rest.
                if not dup_query.future.done():
                    dup_query.future.set_result(logits)

    @torch.no_grad()
    def add_query(self, query, future, past):
        """Add a query to be evaluated in the next batch.

        This method is called internally when a `next_token_logprobs` request is made.

        Args:
            query (list[int]): Token IDs representing the query prompt
            future (asyncio.Future): Future to store the result in
            past (list[tuple[torch.Tensor]]|None): Past key/value states from previous evaluation,
                or None if this is a new query
        """
        self.queries.append(Query(query, future, past))

        # Restart the flush timer on every new query; flush immediately once full.
        if self.timer:
            self.timer.cancel()
            self.timer = None
        if len(self.queries) >= self.batch_size:
            self.batch_evaluate_queries()
        else:
            self.timer = asyncio.get_running_loop().call_later(
                self.timeout, lambda: self.batch_evaluate_queries()
            )

    def walk_cache(self, token_ids):
        """Walk the cache tree to find the deepest node matching a sequence of tokens.

        Args:
            token_ids (list[int]): Sequence of token IDs to follow in the cache tree

        Returns:
            tuple:
                - CacheNode: The deepest node in the cache tree that matches the token sequence
                - int: Number of tokens matched from the start of token_ids
                - list[tuple[torch.Tensor]]|None: Past key/value states from the deepest node
                    on the matched path that stores them (the final node is only inspected
                    when the walk stopped before consuming all tokens), or None
                - int: Base index indicating where the past states start in token_ids
        """
        # Walk while tokens can be found
        node = self.cache
        next_token_index = 0

        past = None
        base = 0
        while next_token_index < len(token_ids):
            if node.past_key_values is not None:
                past = node.past_key_values
                base = next_token_index
            if node.has_token(token_ids[next_token_index]):
                node = node.get_token(token_ids[next_token_index])
                next_token_index += 1
            else:
                break

        return node, next_token_index, past, base

    @torch.no_grad()
    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        node, next_token_index, past, base = self.walk_cache(token_ids)

        # If we processed all tokens, then we're done.
        if next_token_index == len(token_ids):
            return node.logprobs

        # Create a future with the prompt
        future = asyncio.get_running_loop().create_future()
        self.add_query(token_ids[base:], future, past)
        logits = await future

        # Create new nodes
        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

    @torch.no_grad()
    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        # Walk while tokens can be found
        node, next_token_index, past, base = self.walk_cache(token_ids)

        if next_token_index == len(token_ids):
            return node.logprobs

        logits = self.model(
            torch.tensor([token_ids[base:]]).to(self.device),
            # `past` is the KV state walk_cache found at index `base`, i.e. it covers
            # exactly token_ids[:base], matching the suffix fed here. The deepest
            # matched node itself may hold no states.
            past_key_values=past,
            use_cache=past is not None,
        ).logits[0]

        node = node.extend_cache(next_token_index, token_ids, logits, base)

        return node.logprobs

    def next_token_logprobs_uncached(self, token_ids):
        """Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

        Args:
            token_ids (list[int]): a list of token ids, representing a prompt to the language model.

        Returns:
            logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.
        """
        if not token_ids:
            raise ValueError("Token ids must not be empty")

        with torch.no_grad():
            logits = self.model(
                torch.tensor([token_ids]).to(self.device),
                past_key_values=None,
                use_cache=False,
            ).logits[0]
            return torch.log_softmax(logits[-1], dim=0)

__init__(hf_model, hf_tokenizer, batch_size=20, timeout=0.02)

Initialize an AsyncTransformer instance.

Parameters:

Name Type Description Default
hf_model

A HuggingFace CausalLM model instance.

required
hf_tokenizer

A HuggingFace Tokenizer.

required
batch_size int

Maximum queries to process in one batch during auto-batching. Defaults to 20.

20
timeout float

Seconds to wait since last query before processing current batch. Defaults to 0.02.

0.02
Source code in genlm_backend/llm/hf.py
@torch.no_grad()
def __init__(self, hf_model, hf_tokenizer, batch_size=20, timeout=0.02):
    """Wrap an already-loaded HuggingFace model and tokenizer.

    Args:
        hf_model: A HuggingFace CausalLM model instance.
        hf_tokenizer: A HuggingFace Tokenizer.
        batch_size (int, optional): Maximum queries to process in one batch during auto-batching.
            Defaults to 20.
        timeout (float, optional): Seconds to wait since last query before processing current batch.
            Defaults to 0.02.
    """
    # The model, its tokenizer, and the device that input tensors are moved to.
    self.model, self.tokenizer = hf_model, hf_tokenizer
    self.device = hf_model.device
    # Trie holding cached next-token log probabilities and key/value states.
    self.cache = TokenTrie()

    # Auto-batching state: pending queries (token ids plus the Future that
    # receives the result), flush thresholds, and the pending flush timer.
    self.queries = []
    self.batch_size, self.timeout = batch_size, timeout
    self.timer = None

    # Switch the model to evaluation mode.
    self.model.eval()

    super().__init__(tokenizer=self.tokenizer)

add_query(query, future, past)

Add a query to be evaluated in the next batch.

This method is called internally when a next_token_logprobs request is made.

Parameters:

Name Type Description Default
query list[int]

Token IDs representing the query prompt

required
future Future

Future to store the result in

required
past list[tuple[Tensor]] | None

Past key/value states from previous evaluation, or None if this is a new query

required
Source code in genlm_backend/llm/hf.py
@torch.no_grad()
def add_query(self, query, future, past):
    """Queue a query for the next batched model call.

    Called internally whenever a `next_token_logprobs` request needs the model.
    The batch is flushed immediately once `batch_size` queries are pending;
    otherwise a timer flushes it `timeout` seconds after the latest query.

    Args:
        query (list[int]): Token IDs representing the query prompt
        future (asyncio.Future): Future to store the result in
        past (list[tuple[torch.Tensor]]|None): Past key/value states from previous evaluation,
            or None if this is a new query
    """
    self.queries.append(Query(query, future, past))

    # Any previously scheduled flush is superseded by this newer query.
    if self.timer:
        self.timer.cancel()
        self.timer = None

    if len(self.queries) < self.batch_size:
        loop = asyncio.get_running_loop()
        self.timer = loop.call_later(
            self.timeout, lambda: self.batch_evaluate_queries()
        )
    else:
        self.batch_evaluate_queries()

batch_evaluate_queries()

Process a batch of queued language model queries.

This method is called internally when the batch_size has been met or the timeout has expired.

Source code in genlm_backend/llm/hf.py
@torch.no_grad()
def batch_evaluate_queries(self):
    """
    Process a batch of queued language model queries.

    This method is called internally when the `batch_size` has been met or the `timeout` has expired.

    Queries with identical prompt tokens *and* the same past key/value object are
    evaluated once and share the result. If building the batch or running the model
    fails, the exception is set on every pending future so awaiting callers receive
    it instead of waiting forever.
    """

    queries, self.queries = self.queries, []
    if not queries:
        return

    # Group duplicate requests. The key must include the identity of the past
    # key/value states: the same suffix tokens evaluated on top of different
    # cached prefixes (or on no prefix) produce different logits. Every query
    # holds a reference to its past, so the ids are stable while grouping.
    query_groups = defaultdict(list)
    for query in queries:
        query_groups[(tuple(query.prompt), id(query.past))].append(query)

    groups = list(query_groups.values())
    # Use one representative query from each group
    unique_queries = [group[0] for group in groups]

    try:
        past_example = next((q.past for q in unique_queries if q.past), False)
        max_past_length = max(q.past_len for q in unique_queries)
        max_query_length = max(len(q.prompt) for q in unique_queries)

        padding_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else 0
        )

        input_ids = torch.tensor(
            [
                q.prompt_padded(padding_token_id, max_query_length)
                for q in unique_queries
            ]
        ).to(self.device)
        attn_masks = torch.tensor(
            [
                q.attention_mask(max_past_length, max_query_length)
                for q in unique_queries
            ]
        ).to(self.device)
        posn_ids = torch.tensor(
            [q.position_ids(max_past_length, max_query_length) for q in unique_queries]
        ).to(self.device)
        if past_example:
            # Concatenate each layer's key (j=0) and value (j=1) tensors across the
            # batch, padding every query's past to max_past_length.
            pasts = [
                [
                    torch.cat(
                        (
                            *(
                                q.past_padded(
                                    layer,
                                    j,
                                    max_past_length,
                                    past_example[0][0].dtype,
                                    self.device,
                                    past_example[0][0].shape,
                                )
                                for q in unique_queries
                            ),
                        ),
                        dim=0,
                    )
                    for j in range(2)
                ]
                for layer in range(len(past_example))
            ]
        else:
            pasts = None

        results = self.model(
            input_ids,
            attention_mask=attn_masks,
            position_ids=posn_ids,
            past_key_values=pasts,
            use_cache=pasts is not None,
        )

        assert len(results.logits) == len(unique_queries)
    except Exception as exc:
        # Deliver the failure to every waiter; otherwise their awaits never return.
        for query in queries:
            if not query.future.done():
                query.future.set_exception(exc)
        return

    for group, logits in zip(groups, results.logits):
        for dup_query in group:
            # A future is already done if its awaiting task was cancelled;
            # set_result would raise InvalidStateError and strand the rest.
            if not dup_query.future.done():
                dup_query.future.set_result(logits)

cache_kv(prompt_tokens)

Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

Parameters:

Name Type Description Default
prompt_tokens list[int]

token ids for the prompt to cache.

required
Source code in genlm_backend/llm/hf.py
@torch.no_grad()
def cache_kv(self, prompt_tokens):
    """Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens.

    Args:
        prompt_tokens (list[int]): token ids for the prompt to cache.

    Raises:
        ValueError: If `prompt_tokens` is empty.
    """
    if not prompt_tokens:
        raise ValueError("Token ids must not be empty")

    # Ask for key/value states explicitly instead of relying on the model
    # config's `use_cache` default.
    result = self.model(
        torch.tensor([prompt_tokens]).to(self.device), use_cache=True
    )

    # Attach the prompt under the deepest already-cached prefix, the same way
    # `next_token_logprobs` does: `extend_cache` is called on the node reached
    # after `next_token_index` tokens (the root is depth 0). The logits cover the
    # whole prompt, so base is 0.
    node, next_token_index, _, _ = self.walk_cache(prompt_tokens)
    if next_token_index < len(prompt_tokens):
        node = node.extend_cache(next_token_index, prompt_tokens, result.logits[0], 0)
    node.past_key_values = result.past_key_values

clear_cache()

Clear the cache of log probabilities and key/value pairs.

Source code in genlm_backend/llm/hf.py
def clear_cache(self):
    """Discard cached log probabilities and key/value pairs.

    The trie is replaced by a fresh one; only the previous root's `logprobs`
    are carried over to the new root.
    """
    root_logprobs = self.cache.logprobs
    self.cache = TokenTrie(None, root_logprobs)

clear_kv_cache()

Clear any key and value vectors from the cache.

Source code in genlm_backend/llm/hf.py
def clear_kv_cache(self):
    """Drop stored key and value vectors, delegating to the cache trie."""
    trie = self.cache
    trie.clear_kv_cache()

from_name(model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs) classmethod

Create an AsyncTransformer instance from a pretrained HuggingFace model.

Parameters:

Name Type Description Default
model_id str

Model identifier in HuggingFace's model hub.

required
bitsandbytes_opts dict

Additional configuration options for bitsandbytes quantization. Defaults to None.

None
hf_opts dict

Additional configuration options for loading the HuggingFace model. Defaults to None.

None
**kwargs

Additional arguments passed to the AsyncTransformer constructor

{}

Returns:

Type Description
AsyncTransformer

An initialized AsyncTransformer instance.

Source code in genlm_backend/llm/hf.py
@classmethod
def from_name(cls, model_id, bitsandbytes_opts=None, hf_opts=None, **kwargs):
    """Load a pretrained HuggingFace model and tokenizer and wrap them.

    Args:
        model_id (str): Model identifier in HuggingFace's model hub.
        bitsandbytes_opts (dict, optional): Additional configuration options for bitsandbytes quantization.
            Defaults to None.
        hf_opts (dict, optional): Additional configuration options for loading the HuggingFace model.
            Defaults to None.
        **kwargs: Additional arguments passed to the `AsyncTransformer` constructor

    Returns:
        (AsyncTransformer): An initialized `AsyncTransformer` instance.
    """
    bnb_config = (
        BitsAndBytesConfig(**bitsandbytes_opts) if bitsandbytes_opts else None
    )

    load_opts = dict(device_map="auto", torch_dtype="auto")
    if hf_opts:
        load_opts.update(hf_opts)  # caller-supplied options take precedence

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=bnb_config, **load_opts
    )
    return cls(model, tokenizer, **kwargs)

next_token_logprobs(token_ids) async

Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with await.

Parameters:

Name Type Description Default
token_ids list[int]

a list of token ids, representing a prompt to the language model.

required

Returns:

Name Type Description
logprobs Tensor

a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm_backend/llm/hf.py
@torch.no_grad()
async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    deepest, matched, past, base = self.walk_cache(token_ids)
    if matched == len(token_ids):
        # The whole prompt is already in the cache; no model call is needed.
        return deepest.logprobs

    # Queue only the tokens not covered by `past`, then wait for the batch.
    loop = asyncio.get_running_loop()
    pending = loop.create_future()
    self.add_query(token_ids[base:], pending, past)
    logits = await pending

    # Record the new tokens in the trie and return the final node's result.
    return deepest.extend_cache(matched, token_ids, logits, base).logprobs

next_token_logprobs_sync(token_ids)

Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

Parameters:

Name Type Description Default
token_ids list[int]

a list of token ids, representing a prompt to the language model.

required

Returns:

Name Type Description
logprobs Tensor

a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm_backend/llm/hf.py
@torch.no_grad()
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token. Not asynchronous, and does not support auto-batching.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    # Walk while tokens can be found
    node, next_token_index, past, base = self.walk_cache(token_ids)

    if next_token_index == len(token_ids):
        return node.logprobs

    logits = self.model(
        torch.tensor([token_ids[base:]]).to(self.device),
        # `past` is the KV state walk_cache found at index `base`, i.e. it covers
        # exactly token_ids[:base], matching the suffix fed here. The deepest
        # matched node itself may hold no states.
        past_key_values=past,
        use_cache=past is not None,
    ).logits[0]

    node = node.extend_cache(next_token_index, token_ids, logits, base)

    return node.logprobs

next_token_logprobs_uncached(token_ids)

Request log probabilities of next token. No KV or output caching, and does not support auto-batching.

Parameters:

Name Type Description Default
token_ids list[int]

a list of token ids, representing a prompt to the language model.

required

Returns:

Name Type Description
logprobs Tensor

a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

Source code in genlm_backend/llm/hf.py
@torch.no_grad()
def next_token_logprobs_uncached(self, token_ids):
    """Run the model on the full prompt, bypassing every cache.

    No KV or output caching is used and requests are not auto-batched.

    Args:
        token_ids (list[int]): a list of token ids, representing a prompt to the language model.

    Returns:
        logprobs (torch.Tensor): a tensor with the language model's log (normalized) probabilities for the next token following the prompt.

    Raises:
        ValueError: If `token_ids` is empty.
    """
    if not token_ids:
        raise ValueError("Token ids must not be empty")

    batch = torch.tensor([token_ids]).to(self.device)
    output = self.model(batch, past_key_values=None, use_cache=False)
    # Logits at the last prompt position score the token that follows the prompt.
    final_logits = output.logits[0][-1]
    return torch.log_softmax(final_logits, dim=0)

reset_async_queries()

Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing to completion.

Source code in genlm_backend/llm/hf.py
def reset_async_queries(self):
    """Drop all queued model queries.

    Call this after an exception stopped an inference algorithm before it
    finished, so stale queries are not included in the next batch. The queue
    is replaced with a new empty list rather than cleared in place.
    """
    self.queries = list()

walk_cache(token_ids)

Walk the cache tree to find the deepest node matching a sequence of tokens.

Parameters:

Name Type Description Default
token_ids list[int]

Sequence of token IDs to follow in the cache tree

required

Returns:

Name Type Description
tuple
  • CacheNode: The deepest node in the cache tree that matches the token sequence
  • int: Number of tokens matched from the start of token_ids
  • list[tuple[torch.Tensor]]|None: Past key/value states from the deepest cached node, or None if no cached states were found
  • int: Base index indicating where the past states start in token_ids
Source code in genlm_backend/llm/hf.py
def walk_cache(self, token_ids):
    """Follow `token_ids` down the cache tree as far as the tree has matching children.

    Args:
        token_ids (list[int]): Sequence of token IDs to follow in the cache tree

    Returns:
        tuple:
            - CacheNode: The deepest node reached while matching the token sequence
            - int: Number of tokens matched from the start of token_ids
            - list[tuple[torch.Tensor]]|None: Past key/value states of the deepest node (visited
                before the matching stopped) that had them, or None if none did
            - int: Index in token_ids at which those past states start applying
    """
    node = self.cache
    past, base = None, 0

    for matched, token in enumerate(token_ids):
        # Record past states *before* stepping down, so the final node reached is never
        # itself the source of `past`.
        if node.past_key_values is not None:
            past, base = node.past_key_values, matched
        if not node.has_token(token):
            break
        node = node.get_token(token)
    else:
        # Either every token matched or token_ids was empty.
        matched = len(token_ids)

    return node, matched, past, base

AsyncVirtualLM

Bases: AsyncLM

A wrapper around vLLM's AsyncLLMEngine for asynchronous next token log probability computations.

This class provides an asynchronous interface for computing log probabilities using vLLM's engine. It is optimized for next token log probability computations and supports caching of results (outputs and KV).

Source code in genlm_backend/llm/vllm.py
class AsyncVirtualLM(AsyncLM):
    """A wrapper around vLLM's `AsyncLLMEngine` for asynchronous next token log probability computations.

    This class provides an asynchronous interface for computing log probabilities using vLLM's engine.
    It is optimized for next token log probability computations and supports caching of results (outputs and KV).
    """

    # Each request asks vLLM for exactly one new token with a single sample. Detokenization is
    # skipped and EOS is ignored because only the next-token log probabilities are consumed
    # (see `_validate_outputs`).
    default_params = SamplingParams(
        max_tokens=1, n=1, logprobs=1, detokenize=False, stop=None, ignore_eos=True
    )

    def __init__(self, async_llm_engine, cache_size=0, cache_opts=None):
        """Initialize an `AsyncVirtualLM` instance.

        Args:
            async_llm_engine (AsyncLLMEngine): The async vLLM engine instance.
            cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
            cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm_backend.cache.OutputCache] constructor.
                Defaults to None (no extra options).

        Note:
            The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
        """
        self.async_llm_engine = async_llm_engine
        self.tokenizer = async_llm_engine.engine.get_tokenizer()
        # Source of unique request ids (stringified) for vLLM requests.
        self.request_counter = Counter()
        self.custom_sampler = DeferredSampler()
        self.cache = (
            OutputCache(maxsize=cache_size, **(cache_opts or {}))
            if cache_size > 0
            else None
        )

        # State for `_optimized_sampling_context`, which may be entered by several interleaved
        # coroutines at once: the number of active contexts and the model sampler captured on
        # the outermost entry.
        self._sampler_depth = 0
        self._original_sampler = None

        # Turn off the engine's `log_stats` flag.
        async_llm_engine.engine.log_stats = False

        super().__init__(tokenizer=self.tokenizer)

    @classmethod
    def from_name(cls, model_name, engine_opts=None, **kwargs):
        """Create an `AsyncVirtualLM` instance from a model name.

        Args:
            model_name (str): Name of the model to load.
            engine_opts (dict): Additional options to pass to the `AsyncLLMEngine`. The engine will be
                configured with prefix caching enabled and async output processing disabled by default.
            **kwargs: Additional arguments passed to `AsyncVirtualLM` constructor.

        Returns:
            (AsyncVirtualLM): An `AsyncVirtualLM` instance.

        Raises:
            ImportError: If vLLM is not available.
        """
        if not HAS_VLLM:
            raise ImportError(
                "vLLM not available. Install vLLM or use AsyncTransformer instead."
            )

        # Caller-supplied options override these defaults.
        engine_opts = {
            "enable_prefix_caching": True,
            "disable_log_requests": True,
            "disable_async_output_proc": True,
            **(engine_opts or {}),
        }

        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model=model_name, tokenizer=model_name, **engine_opts)
        )

        return cls(engine, **kwargs)

    async def next_token_logprobs(self, token_ids):
        """Request log probabilities of next token asynchronously with output caching.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            result (torch.Tensor): Normalized log probability tensor.

        Warning:
            Do not use `asyncio.run(next_token_logprobs())` as it may interfere with vLLM's background loop.
            For synchronous usage, use the `next_token_logprobs_sync()` method instead.
        """
        # Tuples are hashable, so the prompt can serve as the cache key.
        key = tuple(token_ids)

        if self.cache is not None and key in self.cache:
            return self.cache[key]

        result = await self._next_token_logprobs(key)

        if self.cache is not None:
            self.cache[key] = result

        return result

    async def _next_token_logprobs(self, token_ids):
        """Request log probabilities of next token asynchronously (no output caching).

        Args:
            token_ids (Sequence[int]): Token IDs representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        req_id = str(next(self.request_counter))
        prompt = TokensPrompt(prompt_token_ids=token_ids)

        outputs = []
        with self._optimized_sampling_context():
            async for output in self.async_llm_engine.generate(
                prompt=prompt,
                sampling_params=self.default_params,
                request_id=req_id,
            ):
                # Only the final output of the request carries the result.
                if output.finished:
                    outputs.append(output)

        return self._validate_outputs(outputs)

    def next_token_logprobs_sync(self, token_ids):
        """Request log probabilities of next token synchronously.

        Args:
            token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        return self.batch_next_token_logprobs_sync([token_ids])[0]

    def batch_next_token_logprobs_sync(self, token_ids_list):
        """
        Request log probabilities of next tokens in a batch synchronously.

        Args:
            token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.

        Returns:
            (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.

        Raises:
            AssertionError: If the engine reports a request id that was not submitted here, or
                reports the same request as finished twice.
        """
        req_ids = []
        for token_ids in token_ids_list:
            req_id = str(next(self.request_counter))
            req_ids.append(req_id)
            self.async_llm_engine.engine.add_request(
                prompt=TokensPrompt(prompt_token_ids=token_ids),
                params=self.default_params,
                request_id=req_id,
            )
        # Set for O(1) membership checks while draining engine outputs.
        expected_ids = set(req_ids)

        req_id2outputs = {}
        with self._optimized_sampling_context():
            # Step the engine directly until every queued request has finished.
            while self.async_llm_engine.engine.has_unfinished_requests():
                for out in self.async_llm_engine.engine.step():
                    if not out.finished:
                        continue
                    assert (
                        out.request_id not in req_id2outputs
                    ), f"Duplicate outputs for request {out.request_id}"
                    assert (
                        out.request_id in expected_ids
                    ), f"{out.request_id} not in requested IDs"
                    req_id2outputs[out.request_id] = out

        # Preserve the input order.
        logprobs = [
            self._validate_outputs([req_id2outputs[req_id]]) for req_id in req_ids
        ]

        return torch.stack(logprobs)

    @contextmanager
    def _optimized_sampling_context(self):
        """Install `self.custom_sampler` as the model's sampler for the duration of the context.

        This context is entered by every in-flight async request, so several coroutines can be
        inside it at once and may exit in any order. The model's original sampler is captured
        on the outermost entry and restored only when the last active context exits. This
        guarantees that a request still in flight never runs with the original sampler, and
        that the custom sampler is never left installed once all contexts have exited.
        """
        # NOTE(review): this attribute path reaches into vLLM internals and may differ across
        # vLLM versions.
        model = self.async_llm_engine.engine.model_executor.driver_worker.model_runner.model
        if self._sampler_depth == 0:
            self._original_sampler = model.sampler
            model.sampler = self.custom_sampler
        self._sampler_depth += 1
        try:
            yield
        finally:
            self._sampler_depth -= 1
            if self._sampler_depth == 0:
                model.sampler = self._original_sampler
                self._original_sampler = None

    def _validate_outputs(self, outputs):
        """Validate and extract logprobs from a vLLM output.

        Args:
            outputs: List of sequence group outputs from vLLM generation

        Returns:
            Tensor of log probabilities for the next token

        Raises:
            AssertionError: If output structure doesn't match expected format
        """
        assert len(outputs) == 1, "Expected exactly one sequence group"
        seq_group = outputs[0]

        assert (
            len(seq_group.outputs) == 1
        ), "Expected exactly one sequence in output"
        sequence = seq_group.outputs[0]

        assert len(sequence.logprobs) == 1, "Expected exactly one set of logprobs"
        token_logprobs = sequence.logprobs[0].logprobs

        return token_logprobs

    def clear_cache(self):
        """Clear output cache (no-op when caching is disabled)."""
        # Compare against None explicitly: a cache object may be falsy when merely empty.
        if self.cache is not None:
            self.cache.clear()

    def __del__(self):
        """Clean up resources on deletion."""
        self._cleanup_engine()

    def _cleanup_engine(self):
        """Clean up the vLLM engine and associated resources."""
        # getattr guards against a partially constructed instance (e.g. __init__ raised early).
        if async_engine := getattr(self, "async_llm_engine", None):
            async_engine.shutdown_background_loop()

__del__()

Clean up resources on deletion.

Source code in genlm_backend/llm/vllm.py
def __del__(self):
    """Finalizer: delegate to `_cleanup_engine()` to release engine resources."""
    cleanup = self._cleanup_engine
    cleanup()

__init__(async_llm_engine, cache_size=0, cache_opts={})

Initialize an AsyncVirtualLM instance.

Parameters:

Name Type Description Default
async_llm_engine AsyncLLMEngine

The async vLLM engine instance.

required
cache_size int

Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.

0
cache_opts dict

Additional options to pass to the OutputCache constructor. Defaults to {}.

{}
Note

The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.

Source code in genlm_backend/llm/vllm.py
def __init__(self, async_llm_engine, cache_size=0, cache_opts=None):
    """Initialize an `AsyncVirtualLM` instance.

    Args:
        async_llm_engine (AsyncLLMEngine): The async vLLM engine instance.
        cache_size (int, optional): Maximum size of the output cache. If 0, caching is disabled. Defaults to 0.
        cache_opts (dict, optional): Additional options to pass to the [`OutputCache`][genlm_backend.cache.OutputCache] constructor.
            Defaults to None (no extra options).

    Note:
        The cache stores the log probabilities for previously seen token sequences to avoid redundant requests. KV caching is handled internally by the vLLM engine.
    """
    self.async_llm_engine = async_llm_engine
    self.tokenizer = async_llm_engine.engine.get_tokenizer()
    # Source of unique request ids (stringified) for vLLM requests.
    self.request_counter = Counter()
    self.custom_sampler = DeferredSampler()
    self.cache = (
        OutputCache(maxsize=cache_size, **(cache_opts or {}))
        if cache_size > 0
        else None
    )

    # Turn off the engine's `log_stats` flag.
    async_llm_engine.engine.log_stats = False

    super().__init__(tokenizer=self.tokenizer)

batch_next_token_logprobs_sync(token_ids_list)

Request log probabilities of next tokens in a batch synchronously.

Parameters:

Name Type Description Default
token_ids_list list[list[int]]

A list of token ID lists, each representing a prompt to the language model.

required

Returns:

Type Description
Tensor

A tensor of normalized log probability tensors, one for each prompt in the input list.

Source code in genlm_backend/llm/vllm.py
def batch_next_token_logprobs_sync(self, token_ids_list):
    """
    Request log probabilities of next tokens in a batch synchronously.

    Args:
        token_ids_list (list[list[int]]): A list of token ID lists, each representing a prompt to the language model.

    Returns:
        (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.

    Raises:
        AssertionError: If the engine reports a request id that was not submitted here, or
            reports the same request as finished twice.
    """
    req_ids = []
    for token_ids in token_ids_list:
        req_id = str(next(self.request_counter))
        req_ids.append(req_id)
        self.async_llm_engine.engine.add_request(
            prompt=TokensPrompt(prompt_token_ids=token_ids),
            params=self.default_params,
            request_id=req_id,
        )
    # Set for O(1) membership checks while draining engine outputs.
    expected_ids = set(req_ids)

    req_id2outputs = {}
    with self._optimized_sampling_context():
        # Step the engine directly until every queued request has finished.
        while self.async_llm_engine.engine.has_unfinished_requests():
            for out in self.async_llm_engine.engine.step():
                if not out.finished:
                    continue
                assert (
                    out.request_id not in req_id2outputs
                ), f"Duplicate outputs for request {out.request_id}"
                assert (
                    out.request_id in expected_ids
                ), f"{out.request_id} not in requested IDs"
                req_id2outputs[out.request_id] = out

    # Preserve the input order.
    logprobs = [
        self._validate_outputs([req_id2outputs[req_id]]) for req_id in req_ids
    ]

    return torch.stack(logprobs)

clear_cache()

Clear output cache.

Source code in genlm_backend/llm/vllm.py
def clear_cache(self):
    """Clear output cache (no-op when caching is disabled)."""
    # Compare against None explicitly: a cache object may be falsy when merely empty.
    if self.cache is not None:
        self.cache.clear()

from_name(model_name, engine_opts=None, **kwargs) classmethod

Create an AsyncVirtualLM instance from a model name.

Parameters:

Name Type Description Default
model_name str

Name of the model to load.

required
engine_opts dict

Additional options to pass to the AsyncLLMEngine. The engine will be configured with prefix caching enabled and async output processing disabled by default.

None
**kwargs

Additional arguments passed to AsyncVirtualLM constructor.

{}

Returns:

Type Description
AsyncVirtualLM

An AsyncVirtualLM instance.

Source code in genlm_backend/llm/vllm.py
@classmethod
def from_name(cls, model_name, engine_opts=None, **kwargs):
    """Build an `AsyncVirtualLM` by loading a vLLM engine for `model_name`.

    Args:
        model_name (str): Name of the model to load (also used as the tokenizer name).
        engine_opts (dict): Extra `AsyncLLMEngine` options. Unless overridden here, prefix
            caching is enabled, request logging is disabled and async output processing
            is disabled.
        **kwargs: Forwarded to the `AsyncVirtualLM` constructor.

    Returns:
        (AsyncVirtualLM): The newly constructed instance.

    Raises:
        ImportError: If vLLM is not available.
    """
    if not HAS_VLLM:
        raise ImportError(
            "vLLM not available. Install vLLM or use AsyncTransformer instead."
        )

    merged_opts = dict(
        enable_prefix_caching=True,
        disable_log_requests=True,
        disable_async_output_proc=True,
    )
    # Caller-supplied options take precedence over the defaults above.
    merged_opts.update(engine_opts or {})

    engine_args = AsyncEngineArgs(model=model_name, tokenizer=model_name, **merged_opts)
    return cls(AsyncLLMEngine.from_engine_args(engine_args), **kwargs)

next_token_logprobs(token_ids) async

Request log probabilities of next token asynchronously with output caching.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs, representing a prompt to the language model.

required

Returns:

Name Type Description
result Tensor

Normalized log probability tensor.

Warning

Do not use asyncio.run(next_token_logprobs()) as it may interfere with vLLM's background loop. For synchronous usage, use the next_token_logprobs_sync() method instead.

Source code in genlm_backend/llm/vllm.py
async def next_token_logprobs(self, token_ids):
    """Request log probabilities of next token asynchronously with output caching.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        result (torch.Tensor): Normalized log probability tensor.

    Warning:
        Do not use `asyncio.run(next_token_logprobs())` as it may interfere with vLLM's background loop.
        For synchronous usage, use the `next_token_logprobs_sync()` method instead.
    """
    # Tuples are hashable, so the prompt can serve as the cache key.
    key = tuple(token_ids)

    if self.cache is not None and key in self.cache:
        return self.cache[key]

    result = await self._next_token_logprobs(key)

    if self.cache is not None:
        self.cache[key] = result

    return result

next_token_logprobs_sync(token_ids)

Request log probabilities of next token synchronously.

Parameters:

Name Type Description Default
token_ids list[int]

A list of token IDs, representing a prompt to the language model.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm_backend/llm/vllm.py
def next_token_logprobs_sync(self, token_ids):
    """Request log probabilities of next token synchronously.

    This does not consult the output cache; it delegates to `batch_next_token_logprobs_sync`
    with a batch of one prompt.

    Args:
        token_ids (list[int]): A list of token IDs, representing a prompt to the language model.

    Returns:
        (torch.Tensor): Normalized log probability tensor.
    """
    return self.batch_next_token_logprobs_sync([token_ids])[0]

MockAsyncLM

Bases: AsyncLM

Mock implementation of AsyncLM used for testing.

Source code in genlm_backend/llm/base.py
class MockAsyncLM(AsyncLM):
    """Mock implementation of AsyncLM used for testing.

    Log probabilities are pseudo-random but deterministic in the prompt: the same token IDs
    always produce the same distribution over `len(tokenizer)` tokens.
    """

    def __init__(self, tokenizer):
        """Initialize a `MockAsyncLM` instance.

        Args:
            tokenizer: Hugging Face tokenizer instance
        """
        super().__init__(tokenizer)
        self._rng = np.random.RandomState(42)

    @classmethod
    def from_name(cls, model_name, **kwargs):
        """Create a MockAsyncLM instance over the vocabulary of the model's tokenizer.

        Args:
            model_name (str): Name of pretrained model to load tokenizer from
            **kwargs: Additional arguments passed to `MockAsyncLM` constructor

        Returns:
            (MockAsyncLM): `MockAsyncLM` instance initialized with tokenizer from `model_name`
        """
        # Local import so transformers is only needed when this constructor is used.
        from transformers import AutoTokenizer

        return cls(AutoTokenizer.from_pretrained(model_name), **kwargs)

    async def next_token_logprobs(self, token_ids):
        """Get next token log probabilities asynchronously.

        Args:
            token_ids (list[int]): Input token IDs.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        return self._get_logprobs(token_ids)

    def next_token_logprobs_sync(self, token_ids):
        """Get next token log probabilities synchronously.

        Args:
            token_ids (list[int]): Input token IDs.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        return self._get_logprobs(token_ids)

    def _get_logprobs(self, token_ids):
        """Generate random but deterministic log probabilities for given tokens.

        Uses token_ids to seed the random generator, ensuring same inputs produce same outputs.
        The seed is a position-weighted sum of the token IDs, reduced modulo 2**32 because
        `RandomState.seed` only accepts values in [0, 2**32 - 1]. Long prompts can exceed that
        range; seeds already inside it are unchanged.

        Args:
            token_ids (list[int]): Input token IDs.

        Returns:
            (torch.Tensor): Normalized log probability tensor.
        """
        # Position weighting makes different orderings of the same tokens get different seeds.
        seed = sum((i + 1) * t for i, t in enumerate(token_ids)) % (2**32)
        self._rng.seed(seed)
        logits = torch.from_numpy(
            self._rng.rand(len(self.tokenizer)).astype(np.float32)
        )
        return torch.log_softmax(logits, dim=-1)

__init__(tokenizer)

Initialize a MockAsyncLM instance.

Parameters:

Name Type Description Default
tokenizer

Hugging Face tokenizer instance

required
Source code in genlm_backend/llm/base.py
def __init__(self, tokenizer):
    """Set up a `MockAsyncLM` over the vocabulary of `tokenizer`.

    Args:
        tokenizer: Hugging Face tokenizer instance
    """
    super().__init__(tokenizer)
    # Private generator with a fixed initial seed. It is reseeded for every query.
    self._rng = np.random.RandomState(seed=42)

from_name(model_name, **kwargs) classmethod

Create a MockAsyncLM instance over the vocabulary of the model's tokenizer.

Parameters:

Name Type Description Default
model_name str

Name of pretrained model to load tokenizer from

required
**kwargs

Additional arguments passed to MockAsyncLM constructor

{}

Returns:

Type Description
MockAsyncLM

MockAsyncLM instance initialized with tokenizer from model_name

Source code in genlm_backend/llm/base.py
@classmethod
def from_name(cls, model_name, **kwargs):
    """Build a `MockAsyncLM` whose vocabulary comes from a pretrained tokenizer.

    Args:
        model_name (str): Name of the pretrained model whose tokenizer is loaded
        **kwargs: Forwarded to the `MockAsyncLM` constructor

    Returns:
        (MockAsyncLM): A mock model backed by the tokenizer of `model_name`
    """
    # Imported lazily so transformers is only required when this constructor is used.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return cls(tokenizer, **kwargs)

next_token_logprobs(token_ids) async

Get next token log probabilities asynchronously.

Parameters:

Name Type Description Default
token_ids list[int]

Input token IDs.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm_backend/llm/base.py
async def next_token_logprobs(self, token_ids):
    """Async entry point returning the mock next-token distribution.

    Args:
        token_ids (list[int]): Prompt token IDs.

    Returns:
        (torch.Tensor): Normalized log probability tensor, as produced by `_get_logprobs`.
    """
    logprobs = self._get_logprobs(token_ids)
    return logprobs

next_token_logprobs_sync(token_ids)

Get next token log probabilities synchronously.

Parameters:

Name Type Description Default
token_ids list[int]

Input token IDs.

required

Returns:

Type Description
Tensor

Normalized log probability tensor.

Source code in genlm_backend/llm/base.py
def next_token_logprobs_sync(self, token_ids):
    """Blocking entry point returning the mock next-token distribution.

    Args:
        token_ids (list[int]): Prompt token IDs.

    Returns:
        (torch.Tensor): Normalized log probability tensor, as produced by `_get_logprobs`.
    """
    logprobs = self._get_logprobs(token_ids)
    return logprobs