API Reference
Predictor
- class epitopegen.inference.EpitopeGenPredictor(checkpoint_path: str | None = None, model_path: str = 'gpt2-small', tokenizer_path: str | None = None, device: str | None = None, special_token_id: int = 2, batch_size: int = 32, cache_dir: str | None = None)
Bases: object
A predictor class for generating epitopes from TCR sequences using a GPT-2-based model.
This class handles model initialization, checkpoint management, and prediction generation for TCR-epitope pairs. It supports multiple checkpoints and automatic downloading of model weights from Zenodo.
- ZENODO_URL
URL for downloading model checkpoints.
- DEFAULT_CHECKPOINT
Path of the default checkpoint file (the same file as 'ckpt1' in AVAILABLE_CHECKPOINTS).
- AVAILABLE_CHECKPOINTS
Dictionary mapping checkpoint names to their file paths.
- AVAILABLE_CHECKPOINTS = {
    'ckpt1': 'checkpoints/epitopegen_weight_1/epoch_28/pytorch_model.bin',
    'ckpt2': 'checkpoints/epitopegen_weight_2/epoch_26/pytorch_model.bin',
    'ckpt3': 'checkpoints/epitopegen_weight_3/epoch_19/pytorch_model.bin',
    'ckpt4': 'checkpoints/epitopegen_weight_4/epoch_21/pytorch_model.bin',
    'ckpt5': 'checkpoints/epitopegen_weight_5/epoch_28/pytorch_model.bin',
    'ckpt6': 'checkpoints/epitopegen_weight_6/epoch_28/pytorch_model.bin',
    'ckpt7': 'checkpoints/epitopegen_weight_7/epoch_24/pytorch_model.bin',
    'ckpt8': 'checkpoints/epitopegen_weight_8/epoch_22/pytorch_model.bin',
    'ckpt9': 'checkpoints/epitopegen_weight_9/epoch_24/pytorch_model.bin',
    'ckpt10': 'checkpoints/epitopegen_weight_10/epoch_20/pytorch_model.bin',
    'ckpt11': 'checkpoints/epitopegen_weight_11/epoch_21/pytorch_model.bin'
  }
- DEFAULT_CHECKPOINT = 'checkpoints/epitopegen_weight_1/epoch_28/pytorch_model.bin'
- ZENODO_URL = 'https://zenodo.org/records/14897624/files/checkpoints.zip'
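The checkpoint attributes above amount to a name-to-path lookup. The following is a minimal sketch, not epitopegen code: it rebuilds the mapping from the values listed above, and the helper resolve_checkpoint (including its fallback to DEFAULT_CHECKPOINT and its ValueError for unknown names) is hypothetical. The paths are relative, and how the library resolves them after downloading the archive from ZENODO_URL is not documented here.

```python
from typing import Optional

# Training epoch of each checkpoint, taken from the AVAILABLE_CHECKPOINTS values above.
_EPOCHS = {1: 28, 2: 26, 3: 19, 4: 21, 5: 28, 6: 28, 7: 24, 8: 22, 9: 24, 10: 20, 11: 21}

AVAILABLE_CHECKPOINTS = {
    f"ckpt{i}": f"checkpoints/epitopegen_weight_{i}/epoch_{epoch}/pytorch_model.bin"
    for i, epoch in _EPOCHS.items()
}
DEFAULT_CHECKPOINT = AVAILABLE_CHECKPOINTS["ckpt1"]


def resolve_checkpoint(name: Optional[str] = None) -> str:
    """Hypothetical helper: map a name such as 'ckpt3' to its relative file path.

    Returns DEFAULT_CHECKPOINT when no name is given.
    """
    if name is None:
        return DEFAULT_CHECKPOINT
    try:
        return AVAILABLE_CHECKPOINTS[name]
    except KeyError:
        # List valid names in natural order (ckpt1, ckpt2, ..., ckpt11).
        valid = ", ".join(sorted(AVAILABLE_CHECKPOINTS, key=lambda k: int(k[4:])))
        raise ValueError(f"Unknown checkpoint {name!r}; expected one of: {valid}") from None
```

For example, resolve_checkpoint("ckpt10") yields the epoch-20 weights of the tenth model.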
- predict(tcr_sequences: list, output_path: str | None = None, top_k: int = 50, temperature: float = 0.7, top_p: float = 0.95, use_attention_mask=False) → DataFrame
Generates epitope predictions for a list of TCR sequences.
A convenience wrapper around predict_from_df that accepts a list of TCR sequences instead of a DataFrame.
- Parameters:
tcr_sequences – List of TCR amino acid sequences to generate predictions for.
output_path – Path to save the prediction results CSV. If None, results are only returned as DataFrame.
top_k – Number of epitope predictions to generate per TCR sequence (default: 50). Note: this is not the top-k parameter of top-k/top-p sampling.
temperature – Sampling temperature for generation. Higher values increase diversity (default: 0.7).
top_p – Nucleus sampling probability threshold (default: 0.95).
use_attention_mask – Whether to use attention masking during generation (default: False).
- Returns:
- DataFrame containing TCR sequences and their predicted epitopes.
Columns are [‘tcr’, ‘pred_0’, ‘pred_1’, …, ‘pred_{top_k-1}’].
- Return type:
pd.DataFrame
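The temperature and top_p parameters follow the usual conventions of sampling-based text generation. The sketch below illustrates those conventions on a single next-token distribution in pure Python; it is a generic illustration, not epitopegen's implementation, and the function name sample_top_p is hypothetical. The documented top_k parameter plays no role here, since it controls how many predictions are generated per TCR.

```python
import math
import random
from typing import Dict, Optional


def sample_top_p(logits: Dict[str, float], temperature: float = 0.7,
                 top_p: float = 0.95, rng: Optional[random.Random] = None) -> str:
    """Draw one token using temperature scaling followed by nucleus (top-p) filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = rng or random.Random()
    # Temperature: divide logits before the softmax; lower values sharpen the distribution.
    scaled = {tok: logit / temperature for tok, logit in logits.items()}
    peak = max(scaled.values())
    weights = {tok: math.exp(v - peak) for tok, v in scaled.items()}  # stable softmax
    total = sum(weights.values())
    probs = sorted(((tok, w / total) for tok, w in weights.items()),
                   key=lambda item: item[1], reverse=True)
    # Nucleus: keep the smallest prefix of most likely tokens whose mass reaches top_p.
    nucleus, mass = [], 0.0
    for tok, p in probs:
        nucleus.append((tok, p))
        mass += p
        if mass >= top_p:
            break
    # random.choices renormalises the kept probabilities before sampling.
    tokens = [tok for tok, _ in nucleus]
    return rng.choices(tokens, weights=[p for _, p in nucleus], k=1)[0]
```

With the default temperature=0.7, a token whose logit is well ahead of the others can hold more than 95% of the probability mass on its own; the nucleus then contains only that token and sampling becomes effectively greedy. Raising the temperature or top_p admits more candidates and increases diversity.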
- predict_all(tcr_sequences: list, output_dir: str, models: List[str] | None = None, top_k: int = 50, temperature: float = 0.7, top_p: float = 0.95, use_attention_mask=False) → Dict[str, DataFrame]
Runs predictions using multiple model checkpoints.
- Parameters:
tcr_sequences – List of TCR amino acid sequences to generate predictions for.
output_dir – Directory path where prediction results will be saved.
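models – List of checkpoint names (keys of AVAILABLE_CHECKPOINTS, e.g. 'ckpt1') to use. If None, all available checkpoints are used.
top_k – Number of epitope predictions to generate per TCR sequence with each checkpoint (default: 50). As in predict, this is not the top-k sampling parameter.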
temperature – Sampling temperature, higher values increase diversity (default: 0.7).
top_p – Nucleus sampling probability threshold (default: 0.95).
use_attention_mask – Whether to use attention masking during generation (default: False).
- Returns:
- Dictionary mapping checkpoint names to prediction DataFrames.
Each DataFrame contains the input TCR sequences and their predicted epitopes.
- Return type:
Dict[str, pd.DataFrame]
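The documented behaviour of predict_all (one run per checkpoint, all checkpoints when models is None, results keyed by checkpoint name) can be outlined as follows. This is a hypothetical sketch: run_checkpoints and the predict_with callback are illustrative stand-ins for loading a checkpoint and calling predict, and writing files to output_dir is omitted because the file naming is not documented. Rejecting unknown names with ValueError is this sketch's choice, not documented library behaviour.

```python
from typing import Callable, Dict, List, Optional, TypeVar

T = TypeVar("T")

# Keys of AVAILABLE_CHECKPOINTS, in natural order.
CHECKPOINT_NAMES = [f"ckpt{i}" for i in range(1, 12)]


def run_checkpoints(tcr_sequences: List[str],
                    predict_with: Callable[[str, List[str]], T],
                    models: Optional[List[str]] = None) -> Dict[str, T]:
    """Hypothetical outline of predict_all: one prediction run per checkpoint.

    predict_with(name, tcr_sequences) stands in for loading checkpoint `name`
    and calling predict(); its result (a DataFrame in epitopegen) is stored
    under `name`.
    """
    selected = CHECKPOINT_NAMES if models is None else list(models)
    unknown = [name for name in selected if name not in CHECKPOINT_NAMES]
    if unknown:
        raise ValueError(f"Unknown checkpoint(s): {unknown}")
    return {name: predict_with(name, tcr_sequences) for name in selected}
```

Because every checkpoint sees the same TCR list, the per-checkpoint results can be compared or pooled row by row afterwards.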
- predict_from_df(df: DataFrame, output_path: str | None = None, top_k: int = 50, temperature: float = 0.7, top_p: float = 0.95, use_attention_mask=False) → DataFrame
Generates epitope predictions from a DataFrame containing TCR sequences.
Main prediction method that processes TCR sequences in batches and generates multiple epitope predictions for each sequence using the loaded model.
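Batching and the output layout described for predict_from_df can be sketched with two small helpers. Both names, iter_batches and predictions_to_row, are hypothetical and not part of epitopegen; the batch size mirrors the constructor's batch_size argument (default 32), and the row layout follows the documented columns, 'tcr' followed by 'pred_0' through 'pred_{top_k-1}'.

```python
from typing import Dict, Iterator, List, Sequence


def iter_batches(sequences: Sequence[str], batch_size: int = 32) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size sequences, preserving order."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(sequences), batch_size):
        yield list(sequences[start:start + batch_size])


def predictions_to_row(tcr: str, epitopes: List[str]) -> Dict[str, str]:
    """Lay out one TCR's predictions as columns tcr, pred_0, ..., pred_{top_k-1}."""
    row = {"tcr": tcr}
    row.update({f"pred_{i}": epitope for i, epitope in enumerate(epitopes)})
    return row
```

For example, 70 sequences with the default batch size of 32 are processed as batches of 32, 32 and 6 sequences. A list of rows like the ones returned by predictions_to_row can be passed directly to pd.DataFrame.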
- Parameters:
df – DataFrame containing TCR sequences in a ‘text’ column.
output_path – Path to save the prediction results CSV. If None, results are only returned as DataFrame.
top_k – Number of epitope predictions to generate per TCR sequence (default: 50).
temperature – Sampling temperature for text generation. Higher values increase diversity (default: 0.7).
top_p – Nucleus sampling probability threshold (default: 0.95).
use_attention_mask – Whether to use attention masking during generation. Defaults to False to match training conditions.
- Returns:
- DataFrame containing TCR sequences and their predicted epitopes.
Columns are [‘tcr’, ‘pred_0’, ‘pred_1’, …, ‘pred_{top_k-1}’].
- Return type:
pd.DataFrame
Note
The method processes sequences in batches defined by self.batch_size and prints a detailed summary of the predictions including statistics about TCR lengths, epitope lengths, and most common predictions.
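The Note does not specify exactly which statistics are printed or in what format. To obtain similar figures programmatically, a sketch along these lines works on rows shaped like the documented output, for example rows read back from the CSV written to output_path with csv.DictReader, or obtained from the returned DataFrame with df.to_dict("records"). The function summarize_predictions and its output keys are hypothetical.

```python
from collections import Counter
from statistics import mean, median
from typing import Dict, List


def summarize_predictions(rows: List[Dict[str, str]], n_common: int = 5) -> dict:
    """Summarise rows shaped like {'tcr': ..., 'pred_0': ..., 'pred_1': ...}."""
    tcr_lengths = [len(row["tcr"]) for row in rows]
    # Collect every non-empty prediction column; other columns (e.g. an index) are ignored.
    epitopes = [value for row in rows for key, value in row.items()
                if key.startswith("pred_") and value]
    epitope_lengths = [len(ep) for ep in epitopes]
    return {
        "n_tcrs": len(rows),
        "tcr_length_mean": mean(tcr_lengths) if tcr_lengths else None,
        "tcr_length_median": median(tcr_lengths) if tcr_lengths else None,
        "epitope_length_mean": mean(epitope_lengths) if epitope_lengths else None,
        "most_common": Counter(epitopes).most_common(n_common),
    }
```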