API Reference

Predictor

class epitopegen.inference.EpitopeGenPredictor(checkpoint_path: str | None = None, model_path: str = 'gpt2-small', tokenizer_path: str | None = None, device: str | None = None, special_token_id: int = 2, batch_size: int = 32, cache_dir: str | None = None)

Bases: object

A predictor class for generating epitopes from TCR sequences using a GPT-2-based model.

This class handles model initialization, checkpoint management, and prediction generation for TCR-epitope pairs. It supports multiple checkpoints and automatic downloading of model weights from Zenodo.

ZENODO_URL

URL for downloading model checkpoints.

DEFAULT_CHECKPOINT

Path of the default checkpoint file (the same file as 'ckpt1' in AVAILABLE_CHECKPOINTS).

AVAILABLE_CHECKPOINTS

Dictionary mapping checkpoint names to their file paths.

AVAILABLE_CHECKPOINTS = {
    'ckpt1': 'checkpoints/epitopegen_weight_1/epoch_28/pytorch_model.bin',
    'ckpt10': 'checkpoints/epitopegen_weight_10/epoch_20/pytorch_model.bin',
    'ckpt11': 'checkpoints/epitopegen_weight_11/epoch_21/pytorch_model.bin',
    'ckpt2': 'checkpoints/epitopegen_weight_2/epoch_26/pytorch_model.bin',
    'ckpt3': 'checkpoints/epitopegen_weight_3/epoch_19/pytorch_model.bin',
    'ckpt4': 'checkpoints/epitopegen_weight_4/epoch_21/pytorch_model.bin',
    'ckpt5': 'checkpoints/epitopegen_weight_5/epoch_28/pytorch_model.bin',
    'ckpt6': 'checkpoints/epitopegen_weight_6/epoch_28/pytorch_model.bin',
    'ckpt7': 'checkpoints/epitopegen_weight_7/epoch_24/pytorch_model.bin',
    'ckpt8': 'checkpoints/epitopegen_weight_8/epoch_22/pytorch_model.bin',
    'ckpt9': 'checkpoints/epitopegen_weight_9/epoch_24/pytorch_model.bin',
}
DEFAULT_CHECKPOINT = 'checkpoints/epitopegen_weight_1/epoch_28/pytorch_model.bin'
ZENODO_URL = 'https://zenodo.org/records/14897624/files/checkpoints.zip'
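Because the checkpoint names above sort lexicographically (ckpt1, ckpt10, ckpt11, ckpt2, …) and DEFAULT_CHECKPOINT holds a file path rather than a name, a small lookup can help when choosing checkpoints. The following is a minimal sketch that copies the values shown above instead of importing epitopegen; checkpoints_in_numeric_order and name_for_path are hypothetical helpers, not part of the library.

```python
from __future__ import annotations

# Copy of the class attributes documented above (not imported from epitopegen).
AVAILABLE_CHECKPOINTS = {
    "ckpt1": "checkpoints/epitopegen_weight_1/epoch_28/pytorch_model.bin",
    "ckpt10": "checkpoints/epitopegen_weight_10/epoch_20/pytorch_model.bin",
    "ckpt11": "checkpoints/epitopegen_weight_11/epoch_21/pytorch_model.bin",
    "ckpt2": "checkpoints/epitopegen_weight_2/epoch_26/pytorch_model.bin",
    "ckpt3": "checkpoints/epitopegen_weight_3/epoch_19/pytorch_model.bin",
    "ckpt4": "checkpoints/epitopegen_weight_4/epoch_21/pytorch_model.bin",
    "ckpt5": "checkpoints/epitopegen_weight_5/epoch_28/pytorch_model.bin",
    "ckpt6": "checkpoints/epitopegen_weight_6/epoch_28/pytorch_model.bin",
    "ckpt7": "checkpoints/epitopegen_weight_7/epoch_24/pytorch_model.bin",
    "ckpt8": "checkpoints/epitopegen_weight_8/epoch_22/pytorch_model.bin",
    "ckpt9": "checkpoints/epitopegen_weight_9/epoch_24/pytorch_model.bin",
}
DEFAULT_CHECKPOINT = "checkpoints/epitopegen_weight_1/epoch_28/pytorch_model.bin"


def checkpoints_in_numeric_order(checkpoints: dict[str, str]) -> list[str]:
    """Hypothetical helper: sort names like 'ckpt10' by their numeric suffix."""
    return sorted(checkpoints, key=lambda name: int(name.removeprefix("ckpt")))


def name_for_path(checkpoints: dict[str, str], path: str) -> str | None:
    """Hypothetical helper: reverse lookup from a checkpoint path to its name."""
    for name, ckpt_path in checkpoints.items():
        if ckpt_path == path:
            return name
    return None


print(checkpoints_in_numeric_order(AVAILABLE_CHECKPOINTS))
print(name_for_path(AVAILABLE_CHECKPOINTS, DEFAULT_CHECKPOINT))  # ckpt1
```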
predict(tcr_sequences: list, output_path: str | None = None, top_k: int = 50, temperature: float = 0.7, top_p: float = 0.95, use_attention_mask=False) → DataFrame

Generates epitope predictions for a list of TCR sequences.

A convenience wrapper around predict_from_df that accepts a list of TCR sequences instead of a DataFrame.

Parameters:
  • tcr_sequences – List of TCR amino acid sequences to generate predictions for.

  • output_path – Path to save the prediction results CSV. If None, results are only returned as a DataFrame.

  • top_k – Number of epitope predictions to generate per TCR sequence (default: 50). Note: this is not the top-k cutoff used in top-k/top-p sampling.

  • temperature – Sampling temperature for generation. Higher values increase diversity (default: 0.7).

  • top_p – Nucleus sampling probability threshold (default: 0.95).

  • use_attention_mask – Whether to use attention masking during generation (default: False).

Returns:

DataFrame containing TCR sequences and their predicted epitopes.

Columns are [‘tcr’, ‘pred_0’, ‘pred_1’, …, ‘pred_{top_k-1}’].

Return type:

pd.DataFrame
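The temperature and top_p parameters are the standard controls for sampled text generation. As a generic illustration only (not epitopegen's own implementation), the sketch below scales a toy logit vector by the temperature, keeps the smallest set of most probable tokens whose cumulative probability reaches top_p, and samples one token from that set. The helper names are hypothetical. Per the note on top_k above, top_k in predict is a different setting: it is the number of epitope predictions generated per TCR, not a top-k token cutoff.

```python
from __future__ import annotations

import math
import random


def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Convert logits to probabilities after dividing by the temperature (> 0)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nucleus_indices(probs: list[float], top_p: float) -> list[int]:
    """Smallest set of most probable token indices whose total mass reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: list[int] = []
    cumulative = 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept


def sample_next_token(logits: list[float], temperature: float = 0.7,
                      top_p: float = 0.95, rng: random.Random | None = None) -> int:
    """Draw one token index with temperature scaling plus nucleus (top-p) filtering."""
    rng = rng or random.Random()
    probs = softmax_with_temperature(logits, temperature)
    kept = nucleus_indices(probs, top_p)
    # random.choices treats weights as relative, so the kept mass is renormalised.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


# Toy logits for a 3-token vocabulary (made-up numbers, not model output).
rng = random.Random(0)
print(sample_next_token([2.0, 1.0, 0.0], temperature=0.7, top_p=0.95, rng=rng))
```

Lower temperatures concentrate probability on the top tokens, so fewer tokens survive the top_p cutoff and samples become less diverse.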

predict_all(tcr_sequences: list, output_dir: str, models: List[str] | None = None, top_k: int = 50, temperature: float = 0.7, top_p: float = 0.95, use_attention_mask=False) → Dict[str, DataFrame]

Runs predictions using multiple model checkpoints.

Parameters:
  • tcr_sequences – List of TCR amino acid sequences to generate predictions for.

  • output_dir – Directory path where prediction results will be saved.

  • models – List of checkpoint names to use (keys of AVAILABLE_CHECKPOINTS, such as 'ckpt1'). If None, all available checkpoints are used.

  • top_k – Number of epitope predictions to generate per TCR sequence (default: 50). As in predict, this is not the top-k cutoff used in top-k/top-p sampling.

  • temperature – Sampling temperature, higher values increase diversity (default: 0.7).

  • top_p – Nucleus sampling probability threshold (default: 0.95).

  • use_attention_mask – Whether to use attention masking during generation (default: False).

Returns:

Dictionary mapping checkpoint names to prediction DataFrames.

Each DataFrame contains the input TCR sequences and their predicted epitopes.

Return type:

Dict[str, pd.DataFrame]
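predict_all returns one DataFrame per checkpoint. One possible way to combine them, sketched below under stated assumptions, is to count for each TCR how many checkpoints proposed each epitope. The sketch assumes each DataFrame uses the same 'tcr', 'pred_0', … layout as predict and has been converted to a list of row dicts (for example with DataFrame.to_dict("records")); pool_predictions is a hypothetical helper, and the toy strings are placeholders rather than real sequences.

```python
from __future__ import annotations

from collections import Counter, defaultdict


def pool_predictions(results: dict[str, list[dict[str, str]]]) -> dict[str, Counter]:
    """For each TCR, count how many checkpoints predicted each epitope.

    `results` maps checkpoint name -> rows; each row is assumed to hold a
    'tcr' key plus 'pred_0', 'pred_1', ... keys (hypothetical input layout).
    """
    pooled: dict[str, Counter] = defaultdict(Counter)
    for rows in results.values():
        for row in rows:
            # Count each epitope at most once per checkpoint for this TCR.
            epitopes = {v for k, v in row.items() if k.startswith("pred_")}
            pooled[row["tcr"]].update(epitopes)
    return dict(pooled)


# Toy placeholder strings, not real sequences or model output.
toy_results = {
    "ckpt1": [{"tcr": "TCR_A", "pred_0": "EPI_X", "pred_1": "EPI_Y"}],
    "ckpt2": [{"tcr": "TCR_A", "pred_0": "EPI_X", "pred_1": "EPI_X"}],
}
print(pool_predictions(toy_results)["TCR_A"].most_common(1))  # [('EPI_X', 2)]
```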

predict_from_df(df: DataFrame, output_path: str | None = None, top_k: int = 50, temperature: float = 0.7, top_p: float = 0.95, use_attention_mask=False) → DataFrame

Generates epitope predictions from a DataFrame containing TCR sequences.

Main prediction method that processes TCR sequences in batches and generates multiple epitope predictions for each sequence using the loaded model.

Parameters:
  • df – DataFrame containing TCR sequences in a ‘text’ column.

  • output_path – Path to save the prediction results CSV. If None, results are only returned as a DataFrame.

  • top_k – Number of epitope predictions to generate per TCR sequence (default: 50).

  • temperature – Sampling temperature for text generation. Higher values increase diversity (default: 0.7).

  • top_p – Nucleus sampling probability threshold (default: 0.95).

  • use_attention_mask – Whether to use attention masking during generation. Defaults to False to match training conditions.

Returns:

DataFrame containing TCR sequences and their predicted epitopes.

Columns are [‘tcr’, ‘pred_0’, ‘pred_1’, …, ‘pred_{top_k-1}’].

Return type:

pd.DataFrame
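When output_path is given, the results are saved as a CSV. Assuming the file's header matches the documented columns ('tcr', 'pred_0', …), the stdlib-only sketch below reads it into (tcr, rank, epitope) triples, which are often easier to filter than the wide layout. iter_predictions is a hypothetical helper; it skips any extra columns, such as an index column, because whether one is written is not documented here.

```python
from __future__ import annotations

import csv
import io
from typing import Iterator, TextIO


def iter_predictions(handle: TextIO) -> Iterator[tuple[str, int, str]]:
    """Yield (tcr, rank, epitope) triples from a predictions CSV.

    Assumes a header containing 'tcr' and 'pred_0' ... 'pred_{top_k-1}';
    any other columns are ignored, as are empty prediction cells.
    """
    for row in csv.DictReader(handle):
        for column, epitope in row.items():
            if column.startswith("pred_") and epitope:
                yield row["tcr"], int(column.removeprefix("pred_")), epitope


# Placeholder CSV text standing in for a file written via output_path.
sample = io.StringIO("tcr,pred_0,pred_1\nTCR_A,EPI_X,EPI_Y\n")
print(list(iter_predictions(sample)))
# [('TCR_A', 0, 'EPI_X'), ('TCR_A', 1, 'EPI_Y')]
```

For a real file, open it with open(path, newline="") and pass the handle to iter_predictions.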

Note

The method processes sequences in batches of self.batch_size and prints a detailed summary of the predictions, including statistics on TCR lengths, epitope lengths, and the most common predictions.
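The exact contents and format of the printed summary are not documented here. As a rough, hypothetical sketch of the kinds of statistics the note mentions, the function below computes mean TCR and epitope lengths and the most frequent predictions from rows in the documented 'tcr'/'pred_i' layout; summarize is not part of epitopegen, and the example rows are toy placeholders.

```python
from __future__ import annotations

import statistics
from collections import Counter


def summarize(rows: list[dict[str, str]], n_common: int = 3) -> dict:
    """Hypothetical summary: TCR lengths, epitope lengths, most common predictions."""
    tcr_lengths = [len(row["tcr"]) for row in rows]
    # Flatten every pred_* cell across all rows into one list of epitopes.
    epitopes = [v for row in rows for k, v in row.items() if k.startswith("pred_")]
    return {
        "n_tcrs": len(rows),
        "mean_tcr_length": statistics.mean(tcr_lengths),
        "mean_epitope_length": statistics.mean(len(e) for e in epitopes),
        "most_common": Counter(epitopes).most_common(n_common),
    }


# Toy placeholder rows, not real model output.
toy_rows = [
    {"tcr": "AAAA", "pred_0": "XYZ", "pred_1": "XYZ"},
    {"tcr": "AA", "pred_0": "XY", "pred_1": "XYZ"},
]
print(summarize(toy_rows, n_common=1))
```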

Annotator

Analysis Tools