Models

All PyHealth models are PyTorch modules with a unified constructor: they take a SampleDataset (the output of base.set_task(...)) as the first argument, plus model-specific hyperparameters. The model auto-configures input/output dimensions from the dataset's schema — you don't wire layers by hand.

```python
from pyhealth.models import Transformer

model = Transformer(dataset=samples, hidden_dim=128)
```

If you pass a BaseDataset instead of a SampleDataset, the model can't introspect schemas and will error or misbehave.
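To make the auto-configuration concrete, here is a minimal dependency-free sketch. FakeSampleDataset, infer_dims, and the schema dictionaries are hypothetical stand-ins, not PyHealth API — real models read the schema of the SampleDataset returned by set_task(...).

```python
# Dependency-free sketch of schema-driven sizing. FakeSampleDataset and
# infer_dims are hypothetical stand-ins, not PyHealth classes.
class FakeSampleDataset:
    # hypothetical: vocabulary size per input feature, classes per output
    input_schema = {"diagnoses": 120, "procedures": 80}
    output_schema = {"mortality": 2}

def infer_dims(dataset):
    # a model can size its embedding and output layers from these numbers,
    # so the user never wires dimensions by hand
    in_dim = sum(dataset.input_schema.values())
    out_dim = sum(dataset.output_schema.values())
    return in_dim, out_dim

print(infer_dims(FakeSampleDataset()))  # (200, 2)
```

A BaseDataset has no such sample-level schema, which is why passing one breaks introspection.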

Choosing a model

Pick by data shape and task type, not by recency. The "newest" model is rarely the right answer.

EHR sequential codes (diagnoses, procedures, prescriptions across visits)

| Model | When to pick it |
| --- | --- |
| Transformer | Strong default. Long visit histories, attention over codes. |
| RNN (LSTM/GRU) | Smaller datasets; faster than Transformer; sensible baseline. |
| RETAIN | When interpretability matters — produces visit-level and code-level attention weights. |
| Deepr | CNN-over-codes; readmission-style tasks. |
| TCN | Long-range temporal patterns where causality matters. |
| AdaCare | Adaptive feature extraction across irregular time intervals. |
| ConCare | Contextualized representations across visits. |
| StageNet | Disease-progression staging from irregular vitals. |
| EHRMamba | State-space alternative to Transformer for long sequences. |

Drug recommendation (multilabel)

| Model | When to pick it |
| --- | --- |
| GAMENet | Drug-rec baseline with memory networks; pairs with DrugRecommendation* tasks. |
| SafeDrug | Models drug-drug interactions / safety constraints via molecular structure. |
| MICRON | Predicts medication change between visits, not the full set. |
| MoleRec | Substructure-aware molecular drug recommendation. |

Static / tabular features

| Model | When to pick it |
| --- | --- |
| LogisticRegression | Strong, fast baseline. Always run this first. |
| MLP | Static numeric vectors, no sequence order. |

Imaging / signals

| Model | When to pick it |
| --- | --- |
| CNN | Generic convolutional baseline for images and 1D signals. |
| ContraWR | Contrastive learning for biosignals. |
| SparcNet | Sparse signal prediction (seizure, sleep staging). |
| BIOT | Biosignal transformer. |

Graph-structured data

| Model | When to pick it |
| --- | --- |
| GNN | Generic graph neural net baseline. |
| GraphCare | EHR codes augmented with external medical knowledge graphs (UMLS/SNOMED). |
| GRASP | Patient-similarity graph representations. |

Text

| Model | When to pick it |
| --- | --- |
| TransformersModel | Pretrained HuggingFace transformer (BERT-family) — clinical notes, transcripts. |
| TransformerDeID | De-identification NER head on top of a transformer. |
| MedLink | Medical entity linking. |

Generative / representation

| Model | When to pick it |
| --- | --- |
| VAE | Synthetic EHR generation, anomaly detection. |
| GAN | Synthetic data with adversarial training. |

Reinforcement learning

| Model | When to pick it |
| --- | --- |
| Agent | Treatment recommendation framed as RL. |

Multimodal

| Model | When to pick it |
| --- | --- |
| MultimodalRNN | Mix of sequential codes and static tensors in one sample. |

Common arguments

Most clinical models accept:

  • dataset — the SampleDataset (required, positional)
  • hidden_dim — embedding/hidden width (default ≈128)
  • embedding_dim — separate embedding width if exposed
  • dropout — dropout rate
  • num_layers — for RNN/Transformer/TCN

Refer to the docstring (help(Transformer)) for model-specific knobs (e.g., rnn_type for RNN, num_filters for CNN, latent_dim for VAE).
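Because the extra knobs vary per model, the constructor signature can also be inspected programmatically. A sketch using a stand-in class — the real pyhealth.models.RNN follows the same keyword-argument pattern, but the defaults below are illustrative, not the library's actual values:

```python
import inspect

class RNN:  # stand-in for pyhealth.models.RNN; defaults are illustrative
    def __init__(self, dataset, hidden_dim=128, rnn_type="GRU", dropout=0.5):
        pass

def list_hyperparams(model_cls):
    """Return every constructor keyword argument with its default value."""
    sig = inspect.signature(model_cls.__init__)
    return {name: p.default for name, p in sig.parameters.items()
            if p.default is not inspect.Parameter.empty}

print(list_hyperparams(RNN))
# {'hidden_dim': 128, 'rnn_type': 'GRU', 'dropout': 0.5}
```

The same helper works on any model class, which is handy when sweeping hyperparameters across several architectures.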

When starting on a new task, work up the model ladder rather than jumping to the most exotic option:

  1. LogisticRegression — sanity check + floor.
  2. MLP if features are static, RNN if sequential.
  3. Transformer — strong general default.
  4. Specialized model (RETAIN, GAMENet, StageNet, etc.) — only if the task has a property that motivates it (interpretability, drug structure, irregular time, etc.).

Stop as soon as a model does the job. A working Transformer beats a half-debugged MoleRec.
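The ladder can be written down as a tiny helper — next_model is hypothetical, purely to encode the escalation order, and is not part of PyHealth:

```python
# Hypothetical helper encoding the ladder above; not part of PyHealth.
def next_model(tried, sequential):
    """Return the simplest model name not yet tried for this data shape."""
    ladder = ["LogisticRegression",
              "RNN" if sequential else "MLP",
              "Transformer"]
    for name in ladder:
        if name not in tried:
            return name
    return "specialized"  # only now consider RETAIN, GAMENet, StageNet, ...

print(next_model(set(), sequential=True))                   # LogisticRegression
print(next_model({"LogisticRegression"}, sequential=True))  # RNN
```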

Custom models

Subclass BaseModel if nothing fits. The dataset object provides feature extractors via dataset.input_processors — use them to keep tokenization consistent with the rest of the pipeline rather than rolling custom encoders.
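A dependency-free skeleton of that pattern. StubProcessor, StubDataset, and the encode/encode_batch methods are hypothetical stand-ins to show processor reuse; a real custom model would subclass pyhealth.models.BaseModel and add torch layers on top of the encoded indices:

```python
# Hypothetical stand-ins illustrating processor reuse; not PyHealth API.
class StubProcessor:
    vocab = ["<pad>", "401.9", "250.00"]
    def encode(self, codes):
        return [self.vocab.index(c) for c in codes]

class StubDataset:
    input_processors = {"conditions": StubProcessor()}

class MyModel:  # real version: subclass pyhealth.models.BaseModel
    def __init__(self, dataset):
        # reuse the pipeline's tokenizer rather than rolling a custom encoder,
        # so indices stay consistent with the rest of the pipeline
        self.proc = dataset.input_processors["conditions"]

    def encode_batch(self, batch):
        return [self.proc.encode(sample) for sample in batch]

model = MyModel(StubDataset())
print(model.encode_batch([["401.9"], ["250.00", "401.9"]]))  # [[1], [2, 1]]
```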