scientific-skills/torchdrug/references/core_concepts.md
This reference covers TorchDrug's fundamental architecture, design principles, and technical implementation details.
TorchDrug separates concerns into distinct modules:
Benefits:
All components inherit from core.Configurable:
Base class for all TorchDrug components.
Key Methods:
config_dict(): Serialize to dictionary
load_config_dict(config): Load from dictionary
save(file): Save to file
load(file): Load from file
Example:
from torchdrug import core, models
model = models.GIN(input_dim=10, hidden_dims=[256, 256])
# Save configuration
config = model.config_dict()
# {'class': 'GIN', 'input_dim': 10, 'hidden_dims': [256, 256], ...}
# Reconstruct model
model2 = core.Configurable.load_config_dict(config)
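To make the round trip above concrete, here is a minimal pure-Python sketch of how a name-to-class table and a `config_dict()` / `load_config_dict()` pair can fit together. It is not TorchDrug's implementation: `REGISTRY`, `register`, `ToyConfigurable`, and `ToyModel` are hypothetical names introduced for illustration, and the sketch assumes every constructor argument is stored as an attribute of the same name.

```python
import inspect

REGISTRY = {}  # hypothetical table mapping a string key to a class


def register(name):
    """Record a class under a string key so it can be looked up later."""
    def wrapper(cls):
        REGISTRY[name] = cls
        cls._registry_key = name
        return cls
    return wrapper


class ToyConfigurable:
    """Toy base class: serializes constructor arguments kept as attributes."""

    def config_dict(self):
        params = inspect.signature(type(self).__init__).parameters
        config = {"class": self._registry_key}
        for name in params:
            if name != "self":
                config[name] = getattr(self, name)
        return config

    @staticmethod
    def load_config_dict(config):
        config = dict(config)  # copy so the caller's dict is not mutated
        cls = REGISTRY[config.pop("class")]
        return cls(**config)


@register("models.ToyModel")
class ToyModel(ToyConfigurable):
    def __init__(self, input_dim, hidden_dims):
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims


model = ToyModel(input_dim=10, hidden_dims=[256, 256])
config = model.config_dict()
clone = ToyConfigurable.load_config_dict(config)
print(config)
# {'class': 'models.ToyModel', 'input_dim': 10, 'hidden_dims': [256, 256]}
```

The point of the design is that the dictionary holds only plain data, so it can be written to a file and turned back into an equivalent object later.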
Decorator for registering models, tasks, and datasets.
Usage:
from torch import nn
from torchdrug import core as core_td

# The registration decorator lives on core.Registry
@core_td.Registry.register("models.CustomModel")
class CustomModel(nn.Module, core_td.Configurable):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)

    def forward(self, graph, input, all_loss=None, metric=None):
        # Model implementation
        pass
Benefits:
Core data structure representing molecular or protein graphs.
Attributes:
num_node: Number of nodes
num_edge: Number of edges
node_feature: Node feature tensor [num_node, feature_dim]
edge_feature: Edge feature tensor [num_edge, feature_dim]
edge_list: Edge connectivity [num_edge, 2 or 3]
num_relation: Number of edge types (for multi-relational)
Methods:
node_mask(mask): Select subset of nodes
edge_mask(mask): Select subset of edges
undirected(): Make graph undirected
directed(): Make graph directed
Batching:
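As an illustration of the edge-list layout and two of the methods listed above, the following pure-Python sketch treats a graph as a node count plus a list of `(source, target)` pairs. The function names mirror `undirected()` and `node_mask()`, but the bodies are assumptions written for this example; the library's own behavior (for instance, whether node indices are renumbered after masking) may differ.

```python
def undirected(edge_list):
    """Add the reverse of every edge, skipping duplicates; order is preserved."""
    seen = set()
    result = []
    for src, dst in edge_list:
        for edge in ((src, dst), (dst, src)):
            if edge not in seen:
                seen.add(edge)
                result.append(edge)
    return result


def node_mask(num_node, edge_list, keep):
    """Keep the nodes in `keep`, drop edges touching removed nodes, renumber."""
    kept = sorted(set(keep))
    if any(not 0 <= node < num_node for node in kept):
        raise ValueError("node index out of range")
    new_index = {old: new for new, old in enumerate(kept)}
    new_edges = [(new_index[s], new_index[d]) for s, d in edge_list
                 if s in new_index and d in new_index]
    return len(kept), new_edges


edges = [(0, 1), (1, 2), (2, 3)]          # a 4-node path, one direction only
both_ways = undirected(edges)             # 6 directed edges
sub_nodes, sub_edges = node_mask(4, both_ways, keep=[1, 2, 3])
print(sub_nodes, sub_edges)
# 3 [(0, 1), (1, 0), (1, 2), (2, 1)]
```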
Specialized graph for molecules.
Additional Attributes:
atom_type: Atomic numbers
bond_type: Bond types (single, double, triple, aromatic)
formal_charge: Atomic formal charges
explicit_hs: Explicit hydrogen counts
Methods:
from_smiles(smiles): Create from SMILES string
from_molecule(mol): Create from RDKit molecule
to_smiles(): Convert to SMILES
to_molecule(): Convert to RDKit molecule
ion_to_molecule(): Neutralize charges
Example:
from torchdrug import data
# From SMILES
mol = data.Molecule.from_smiles("CCO")
# Atom features
print(mol.atom_type)  # atomic numbers, here 6, 6, 8 (C, C, O)
print(mol.bond_type)  # one bond-type index per stored edge
Specialized graph for proteins.
Additional Attributes:
residue_type: Amino acid types
atom_name: Atom names (CA, CB, etc.)
atom_type: Atomic numbers
residue_number: Residue numbering
chain_id: Chain identifiers
Methods:
from_pdb(pdb_file): Load from PDB file
from_sequence(sequence): Create from sequence
to_pdb(pdb_file): Save to PDB file
Graph Construction:
Example:
from torchdrug import data
# Load protein
protein = data.Protein.from_pdb("1a3x.pdb")
# Build graph with multiple edge types
graph = protein.residue_graph(
    node_position="ca",  # Use Cα positions
    edge_types=["sequential", "radius"]  # Sequential + spatial edges
)
Efficient batching structure for heterogeneous graphs.
Purpose:
Attributes:
num_nodes: List of node counts per graph
num_edges: List of edge counts per graph
graph_ind: Graph index for each node
Use Cases:
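The per-graph counts and the per-node graph index listed above can be illustrated with a small pure-Python sketch. It packs several graphs of different sizes into one structure by shifting node indices, then uses the graph index to pool node values per graph. The functions `pack` and `sum_pool` and the dictionary layout are hypothetical, chosen for this example only; they are not TorchDrug's internal representation.

```python
def pack(graphs):
    """Pack a list of (num_node, edge_list) pairs into one batched structure."""
    num_nodes, num_edges, graph_ind, edge_list = [], [], [], []
    offset = 0
    for index, (num_node, edges) in enumerate(graphs):
        num_nodes.append(num_node)
        num_edges.append(len(edges))
        graph_ind.extend([index] * num_node)
        # Shift node ids so every graph occupies its own index range
        edge_list.extend((s + offset, d + offset) for s, d in edges)
        offset += num_node
    return {"num_nodes": num_nodes, "num_edges": num_edges,
            "graph_ind": graph_ind, "edge_list": edge_list}


def sum_pool(node_values, graph_ind, num_graph):
    """Sum node values per graph using the node-to-graph index."""
    pooled = [0.0] * num_graph
    for value, index in zip(node_values, graph_ind):
        pooled[index] += value
    return pooled


packed = pack([(3, [(0, 1), (1, 2)]), (2, [(0, 1)])])
pooled = sum_pool([1.0, 2.0, 3.0, 10.0, 20.0], packed["graph_ind"], 2)
print(packed["graph_ind"], packed["edge_list"], pooled)
# [0, 0, 0, 1, 1] [(0, 1), (1, 2), (3, 4)] [6.0, 30.0]
```

Because no padding is involved, graphs of very different sizes can share a batch without wasted entries.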
All TorchDrug models follow a standardized interface:
def forward(self, graph, input, all_loss=None, metric=None):
    """
    Args:
        graph (Graph): Batch of graphs
        input (Tensor): Node input features
        all_loss (Tensor, optional): Accumulator for losses
        metric (dict, optional): Dictionary for metrics
    Returns:
        dict: Output dictionary with representation keys
    """
    # Model computation
    output = self.layers(graph, input)
    return {
        "node_feature": output,
        "graph_feature": graph_pooling(output)
    }
Key Points:
graph: Batched graph structure
input: Node features [num_node, input_dim]
all_loss: Accumulated loss (for multi-task)
metric: Shared metric dictionary
All models must define:
input_dim: Expected input feature dimension
output_dim: Output representation dimension
Purpose:
Example:
from torch import nn

class CustomModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim    # expected input feature dimension
        self.output_dim = hidden_dim  # dimension of the output representation
        # ... layers ...
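One way to see why these two attributes matter is that downstream code can read them to check that components fit together. The sketch below uses plain Python objects instead of real models; `ToyEncoder` and `check_chain` are hypothetical helpers for illustration and are not part of TorchDrug.

```python
class ToyEncoder:
    """Stand-in for a model that only exposes its dimensions."""

    def __init__(self, input_dim, hidden_dim):
        self.input_dim = input_dim
        self.output_dim = hidden_dim


def check_chain(models):
    """Return (input_dim, output_dim) of the chain, or raise on a mismatch."""
    for first, second in zip(models, models[1:]):
        if first.output_dim != second.input_dim:
            raise ValueError(
                f"output_dim {first.output_dim} does not match "
                f"input_dim {second.input_dim}"
            )
    return models[0].input_dim, models[-1].output_dim


dims = check_chain([ToyEncoder(10, 256), ToyEncoder(256, 128)])
print(dims)
# (10, 128)
```

A prediction head could be sized the same way, by reading `output_dim` from the encoder rather than hard-coding a number.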
All tasks implement these methods:
from torchdrug import tasks

class CustomTask(tasks.Task):
    def preprocess(self, train_set, valid_set, test_set):
        """Dataset-specific preprocessing (optional)"""
        pass

    def predict(self, batch):
        """Generate predictions for a batch"""
        graph, label = batch
        output = self.model(graph, graph.node_feature)
        pred = self.mlp(output["graph_feature"])
        return pred

    def target(self, batch):
        """Extract ground truth labels"""
        graph, label = batch
        return label

    def forward(self, batch):
        """Compute training loss"""
        pred = self.predict(batch)
        target = self.target(batch)
        loss = self.criterion(pred, target)
        return loss

    def evaluate(self, pred, target):
        """Compute evaluation metrics"""
        # compute_auroc / compute_auprc are placeholder helpers, not defined here
        metrics = {}
        metrics["auroc"] = compute_auroc(pred, target)
        metrics["auprc"] = compute_auprc(pred, target)
        return metrics
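The `compute_auroc` helper in the `evaluate` method above is not defined in the text. As a hedged illustration of what such a helper computes, here is a pure-Python version that works on plain lists of scores and 0/1 labels. It uses the pairwise definition of AUROC (the fraction of positive/negative pairs ranked correctly, with ties counted as half), which is quadratic in the number of samples and meant for clarity rather than speed. It is an assumption for this example, not the function TorchDrug uses.

```python
def compute_auroc(pred, target):
    """AUROC as the probability that a random positive outranks a random negative."""
    pos = [p for p, t in zip(pred, target) if t == 1]
    neg = [p for p, t in zip(pred, target) if t == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative label")
    # Each correctly ordered pair scores 1, each tie scores 0.5
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


score = compute_auroc([0.9, 0.8, 0.3, 0.1], [1, 0, 1, 0])
print(score)
# 0.75
```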
Typical Task Structure:
Example:
from torchdrug import tasks, models
# Representation model
model = models.GIN(input_dim=10, hidden_dims=[256, 256])
# Task wraps model with prediction head
task = tasks.PropertyPrediction(
    model=model,
    task=["task1", "task2"],  # Multi-task
    criterion="bce",
    metric=["auroc", "auprc"],
    num_mlp_layer=2
)
import torch
from torchdrug import data, datasets, models, tasks

# 1. Load dataset and split it 80/10/10
dataset = datasets.BBBP("~/datasets/")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = torch.utils.data.random_split(dataset, lengths)

# 2. Create data loaders (data.DataLoader collates graphs into batches)
train_loader = data.DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = data.DataLoader(valid_set, batch_size=32)

# 3. Define model and task
model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256])
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=["auroc", "auprc"])
# Built-in tasks set up their prediction head in preprocess();
# core.Engine normally makes this call for you
task.preprocess(train_set, valid_set, test_set)

# 4. Setup optimizer
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)

# 5. Training loop
for epoch in range(100):
    # Train
    task.train()
    for batch in train_loader:
        loss, metric = task(batch)  # built-in tasks return (loss, metric)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validate (no gradients needed)
    task.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in valid_loader:
            preds.append(task.predict(batch))
            targets.append(task.target(batch))
    preds = torch.cat(preds)
    targets = torch.cat(targets)
    metrics = task.evaluate(preds, targets)
    print(f"Epoch {epoch}: {metrics}")
TorchDrug tasks are compatible with PyTorch Lightning:
import torch
import pytorch_lightning as pl

class LightningWrapper(pl.LightningModule):
    def __init__(self, task):
        super().__init__()
        self.task = task

    def training_step(self, batch, batch_idx):
        # Built-in tasks return (loss, metric); a custom task may return only the loss
        output = self.task(batch)
        loss = output[0] if isinstance(output, tuple) else output
        return loss

    def validation_step(self, batch, batch_idx):
        pred = self.task.predict(batch)
        target = self.task.target(batch)
        return {"pred": pred, "target": target}

    # Hook from PyTorch Lightning 1.x; 2.x replaced it with on_validation_epoch_end()
    def validation_epoch_end(self, outputs):
        preds = torch.cat([o["pred"] for o in outputs])
        targets = torch.cat([o["target"] for o in outputs])
        metrics = self.task.evaluate(preds, targets)
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
Classification:
"bce": Binary cross-entropy
"ce": Cross-entropy (multi-class)
Regression:
"mse": Mean squared error
"mae": Mean absolute error
Knowledge Graph:
"bce": Binary classification of triples
"ce": Cross-entropy ranking loss
"margin": Margin-based ranking
class CustomTask(tasks.Task):
    def forward(self, batch):
        pred = self.predict(batch)
        target = self.target(batch)
        # Custom loss computation
        loss = custom_loss_function(pred, target)
        return loss
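For reference, the formulas behind several of the loss names above can be written out in a few lines of plain Python. This is a sketch on lists of floats, intended to show what a placeholder such as `custom_loss_function` might compute; the function names are chosen for this example, and real training code would use tensor operations so that gradients can flow. The binary cross-entropy here is assumed to take raw scores (logits) rather than probabilities.

```python
import math


def mse(pred, target):
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mae(pred, target):
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def bce_with_logits(logits, target):
    """Binary cross-entropy on raw scores, in a numerically stable form."""
    terms = [max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
             for x, t in zip(logits, target)]
    return sum(terms) / len(terms)


def margin_ranking(pos_scores, neg_scores, margin=1.0):
    """Penalize each pair where the positive does not beat the negative by `margin`."""
    terms = [max(0.0, margin - p + n) for p, n in zip(pos_scores, neg_scores)]
    return sum(terms) / len(terms)


print(mse([1.0, 2.0], [1.0, 4.0]))                  # 2.0
print(mae([1.0, 2.0], [1.0, 4.0]))                  # 1.0
print(margin_ranking([2.0, 0.5], [0.0, 0.25]))      # 0.375
```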
Classification:
Regression:
Ranking (Knowledge Graph):
For multi-label or multi-task:
from torchdrug import datasets, transforms
# Add virtual node connected to all atoms
transform1 = transforms.VirtualNode()
# Add virtual edges
transform2 = transforms.VirtualEdge()
# Compose transforms
transform = transforms.Compose([transform1, transform2])
dataset = datasets.BBBP("~/datasets/", transform=transform)
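To show the idea behind a virtual node and transform composition without any dependencies, here is a small pure-Python sketch. It assumes each dataset item is a dictionary with a `"graph"` entry, and it represents a graph as a node count plus an edge list. The names `add_virtual_node`, `sort_edges`, and `compose` are hypothetical stand-ins written for this example; they are not the library's classes.

```python
def add_virtual_node(item):
    """Append one node and connect it to every existing node in both directions."""
    graph = item["graph"]
    virtual = graph["num_node"]  # index of the new node
    edges = list(graph["edge_list"])
    for node in range(graph["num_node"]):
        edges.append((node, virtual))
        edges.append((virtual, node))
    new_graph = {"num_node": virtual + 1, "edge_list": edges}
    return {**item, "graph": new_graph}  # leave the input item untouched


def sort_edges(item):
    """Second toy transform: put the edge list into a canonical order."""
    graph = item["graph"]
    return {**item, "graph": {**graph, "edge_list": sorted(graph["edge_list"])}}


def compose(transforms):
    """Chain transforms so that each one receives the previous one's output."""
    def apply(item):
        for transform in transforms:
            item = transform(item)
        return item
    return apply


item = {"graph": {"num_node": 2, "edge_list": [(0, 1), (1, 0)]}, "label": 1}
result = compose([add_virtual_node, sort_edges])(item)
print(result["graph"]["num_node"], result["graph"]["edge_list"])
# 3 [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
```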
# Truncate each protein to at most 500 residues
transform = transforms.TruncateProtein(max_length=500)
dataset = datasets.Fold("~/datasets/", transform=transform)
torch.use_deterministic_algorithms(True)
core.Configurable
input_dim and output_dim
nvidia-smi
Train a single model on multiple related tasks:
task = tasks.PropertyPrediction(
    model,
    task=["task1", "task2", "task3"],
    criterion="bce",
    metric=["auroc"],
    task_weight=[1.0, 1.0, 2.0]  # Weight task 3 more
)
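The effect of per-task weights like those in the example above can be sketched as a weighted average of per-task losses. The helper below is a hypothetical illustration in plain Python; whether TorchDrug normalizes the weights, or combines losses in exactly this way, is an assumption and not something the text states.

```python
import math


def weighted_task_loss(task_losses, task_weights):
    """Weighted average of per-task losses (weights normalized to sum to 1)."""
    if len(task_losses) != len(task_weights):
        raise ValueError("need exactly one weight per task")
    total = sum(task_weights)
    return sum(l * w for l, w in zip(task_losses, task_weights)) / total


# Task 3 has twice the weight, so its larger loss pulls the average up
combined = weighted_task_loss([0.2, 0.4, 0.6], [1.0, 1.0, 2.0])
uniform = weighted_task_loss([0.2, 0.4, 0.6], [1.0, 1.0, 1.0])
print(combined > uniform)
# True
```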
Use pre-training tasks:
AttributeMasking: Mask node features
EdgePrediction: Predict edge existence
ContextPrediction: Contrastive learning
Extend TorchDrug with custom GNN layers:
from torchdrug import layers

class CustomConv(layers.MessagePassingBase):
    def message(self, graph, input):
        # Custom message function
        pass

    def aggregate(self, graph, message):
        # Custom aggregation
        pass

    def combine(self, input, update):
        # Custom combination
        pass
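The three hooks above can be illustrated with a toy layer in plain Python, assuming they are chained in the order message, aggregate, combine. The sketch uses scalar node values and sum aggregation. `SumConv` and its argument layout are hypothetical and do not follow the `layers.MessagePassingBase` signatures, which take a graph object and tensors.

```python
class SumConv:
    """Toy message-passing layer on scalar node values with sum aggregation."""

    def message(self, edge_list, node_values):
        # One message per edge: the source node's value, addressed to the target
        return [(dst, node_values[src]) for src, dst in edge_list]

    def aggregate(self, num_node, messages):
        # Sum all messages arriving at each node
        update = [0.0] * num_node
        for dst, value in messages:
            update[dst] += value
        return update

    def combine(self, node_values, update):
        # Merge each node's own value with its aggregated neighborhood
        return [x + u for x, u in zip(node_values, update)]

    def forward(self, num_node, edge_list, node_values):
        messages = self.message(edge_list, node_values)
        update = self.aggregate(num_node, messages)
        return self.combine(node_values, update)


# Path graph 0 - 1 - 2 with edges stored in both directions
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
output = SumConv().forward(3, edges, [1.0, 2.0, 3.0])
print(output)
# [3.0, 6.0, 5.0]
```

Splitting the layer this way means a new variant usually only needs to override one of the three steps.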
input_dim and output_dim: Models won't compose