Tasks

A task turns a BaseDataset (raw patients) into a SampleDataset (supervised samples). Tasks define input_schema (which fields go to the model) and output_schema (the label).

```python
from pyhealth.tasks import MortalityPredictionMIMIC3
samples = base.set_task(MortalityPredictionMIMIC3())  # base: a MIMIC-III BaseDataset
```

Tasks are dataset-specific. Pairing a task with the wrong dataset (e.g., MortalityPredictionMIMIC3 on a MIMIC-IV dataset) will fail, so match the suffix of the task class name to the dataset.
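
One way to catch a bad pairing before calling set_task is a small lookup keyed by task class name. The sketch below is a hypothetical helper, not part of PyHealth: `TASK_TO_DATASET` and `check_pair` are names introduced here, and the table only transcribes the mortality and readmission rows of the matrix in this section. An explicit table is used instead of string-suffix matching because names such as MortalityPredictionEICU2 do not end in a clean dataset suffix.

```python
# Hypothetical lookup, transcribed from the mortality and readmission rows of
# the compatibility matrix. It is not part of PyHealth.
TASK_TO_DATASET = {
    "MortalityPredictionMIMIC3": "MIMIC-III",
    "MortalityPredictionMIMIC4": "MIMIC-IV",
    "InHospitalMortalityMIMIC4": "MIMIC-IV",
    "MortalityPredictionEICU": "eICU",
    "MortalityPredictionEICU2": "eICU",
    "MortalityPredictionOMOP": "OMOP",
    "ReadmissionPredictionMIMIC3": "MIMIC-III",
    "ReadmissionPredictionMIMIC4": "MIMIC-IV",
    "ReadmissionPredictionEICU": "eICU",
    "ReadmissionPredictionOMOP": "OMOP",
}


def check_pair(task_name: str, dataset_label: str) -> None:
    """Raise KeyError for an unknown task, ValueError for a mismatched dataset."""
    expected = TASK_TO_DATASET.get(task_name)
    if expected is None:
        raise KeyError(f"{task_name!r} is not in the lookup table")
    if expected != dataset_label:
        raise ValueError(
            f"{task_name} is listed for {expected}, not {dataset_label}"
        )


check_pair("MortalityPredictionMIMIC3", "MIMIC-III")  # passes silently
```

Calling `check_pair("MortalityPredictionMIMIC3", "MIMIC-IV")` raises a ValueError with a readable message, which is easier to diagnose than a failure partway through sample generation.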

Task → Dataset compatibility matrix

Mortality prediction (binary)

| Task class | Dataset |
| --- | --- |
| MortalityPredictionMIMIC3 | MIMIC-III |
| MortalityPredictionMIMIC4 | MIMIC-IV |
| InHospitalMortalityMIMIC4 | MIMIC-IV (in-hospital, narrower than next-visit) |
| MortalityPredictionEICU, MortalityPredictionEICU2 | eICU |
| MortalityPredictionOMOP | OMOP |
| MortalityPredictionStageNetMIMIC4 | MIMIC-IV (paired with StageNet model) |

Readmission prediction (binary)

| Task class | Dataset |
| --- | --- |
| ReadmissionPredictionMIMIC3 | MIMIC-III |
| ReadmissionPredictionMIMIC4 | MIMIC-IV |
| ReadmissionPredictionEICU | eICU |
| ReadmissionPredictionOMOP | OMOP |

Length-of-stay prediction (multiclass)

| Task class | Dataset |
| --- | --- |
| LengthOfStayPredictionMIMIC3 | MIMIC-III |
| LengthOfStayPredictionMIMIC4 | MIMIC-IV |
| LengthOfStayPredictioneICU | eICU |
| LengthOfStayPredictionOMOP | OMOP |

LOS is bucketed into discrete classes (e.g., <1 day, 1-2 days, …, >14 days), so treat it as multiclass classification rather than regression.
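
To make the bucketing concrete, here is a minimal sketch of mapping a length of stay in days to a class index. The bin edges in `LOS_EDGES_DAYS` are illustrative assumptions chosen for this example; they are not PyHealth's actual buckets, and `los_to_class` is a hypothetical helper name.

```python
from bisect import bisect_right

# Illustrative bin edges in days (an assumption for this sketch; the built-in
# LOS tasks define their own buckets).
LOS_EDGES_DAYS = [1, 2, 3, 7, 14]


def los_to_class(days: float) -> int:
    """Map a length of stay to a class index in 0..len(LOS_EDGES_DAYS).

    Intervals are half-open: class k covers [edge[k-1], edge[k]).
    """
    if days < 0:
        raise ValueError("length of stay cannot be negative")
    return bisect_right(LOS_EDGES_DAYS, days)


print([los_to_class(d) for d in (0.5, 1.5, 5, 20)])  # [0, 1, 3, 5]
```

With five edges there are six classes, so the label is an integer category and class-based metrics apply.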

Drug recommendation (multilabel)

| Task class | Dataset |
| --- | --- |
| DrugRecommendationMIMIC3 | MIMIC-III |
| DrugRecommendationMIMIC4 | MIMIC-IV |
| DrugRecommendationEICU | eICU |

Multilabel = each visit has a set of drugs prescribed; predict the set. Use models with drug-aware structure (GAMENet, SafeDrug, MICRON, MoleRec) or fall back to Transformer / RNN.
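
As a rough illustration of what "predict the set" means, the sketch below encodes a visit's drug set as a multi-hot vector and scores a predicted set with the Jaccard index. The vocabulary and drug codes are made up for the example, and `multi_hot` and `jaccard` are hypothetical helpers rather than PyHealth functions.

```python
def multi_hot(drugs: set[str], vocab: list[str]) -> list[int]:
    """Encode a set of drug codes as a 0/1 vector over a fixed vocabulary."""
    return [1 if code in drugs else 0 for code in vocab]


def jaccard(true: set[str], pred: set[str]) -> float:
    """Intersection over union of two sets; defined as 1.0 when both are empty."""
    union = true | pred
    return len(true & pred) / len(union) if union else 1.0


# Made-up vocabulary and visit, for illustration only.
vocab = ["A01", "B01", "C03", "N02"]
prescribed = {"B01", "N02"}
predicted = {"B01", "C03"}

print(multi_hot(prescribed, vocab))    # [0, 1, 0, 1]
print(jaccard(prescribed, predicted))  # one shared code out of three in the union
```

Each position in the vector is an independent yes/no decision, which is why multilabel tasks use per-sample set metrics instead of plain accuracy.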

Specialized clinical

| Task class | What it predicts |
| --- | --- |
| DKAPredictionMIMIC4 | Diabetic ketoacidosis risk |
| MIMIC3ICD9Coding | ICD-9 codes for a discharge note (multilabel) |

Sleep & EEG

| Task class | Dataset | Predicts |
| --- | --- | --- |
| SleepStagingSleepEDF | SleepEDF | Sleep stage (multiclass) |
| EEGEventsTUEV | TUEV | EEG events |
| EEGAbnormalTUAB | TUAB | EEG abnormality (binary) |

Imaging

| Task class | Dataset | Predicts |
| --- | --- | --- |
| COVID19CXRClassification | COVID19-CXR | COVID-19 (multiclass) |
| ChestXray14BinaryClassification | ChestX-ray14 | Single-disease binary |
| ChestXray14MultilabelClassification | ChestX-ray14 | Multi-disease multilabel |
| cardiology_isAR_fn, _isBBBFB_fn, _isAD_fn, _isCD_fn, _isWA_fn | Cardiology | Various ECG abnormalities |

Text / NLP

| Task class | Dataset | Predicts |
| --- | --- | --- |
| MedicalTranscriptionsClassification | Medical Transcriptions | Specialty/category |
| DeIDNERTask | PhysioNet DeID | De-identification NER |

Genomics

| Task class | Dataset | Predicts |
| --- | --- | --- |
| VariantClassificationClinVar | ClinVar | Variant pathogenicity |
| MutationPathogenicityPrediction | COSMIC | Mutation pathogenicity |
| CancerSurvivalPrediction | TCGA-PRAD | Cancer survival |
| CancerMutationBurden | TCGA-PRAD | Tumor mutation burden |

Benchmarks

| Task class | Use |
| --- | --- |
| BenchmarkEHRShot | Multi-task EHR few-shot benchmark on EHRShot |

Picking the right monitor metric

The Trainer.train(monitor=...) argument decides which checkpoint gets saved. Match it to the task type:

| Task type | Good monitor choices |
| --- | --- |
| Binary (mortality, readmission, EEG abnormal) | "pr_auc", "roc_auc", "f1" |
| Multiclass (LOS, sleep staging, COVID CXR) | "accuracy", "f1_macro", "cohen_kappa" |
| Multilabel (drug rec, ICD coding, ChestXray14) | "pr_auc_samples", "jaccard_samples", "f1_samples" |

Mismatched monitor (e.g., "pr_auc" on a multiclass task) silently saves the wrong epoch.
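
Because a mismatch is silent, it can help to validate the monitor string yourself before training. The sketch below is a hypothetical guard, not a PyHealth API: `MONITORS_BY_MODE` and `validate_monitor` are names introduced here, and the allowed values are just the ones listed in the table in this section.

```python
# Hypothetical guard built from the monitor table in this section.
MONITORS_BY_MODE = {
    "binary": ("pr_auc", "roc_auc", "f1"),
    "multiclass": ("accuracy", "f1_macro", "cohen_kappa"),
    "multilabel": ("pr_auc_samples", "jaccard_samples", "f1_samples"),
}


def validate_monitor(mode: str, monitor: str) -> str:
    """Return `monitor` unchanged if it suits `mode`, else raise ValueError."""
    if mode not in MONITORS_BY_MODE:
        raise ValueError(f"unknown task type {mode!r}")
    allowed = MONITORS_BY_MODE[mode]
    if monitor not in allowed:
        raise ValueError(
            f"{monitor!r} is not a listed monitor for {mode} tasks; "
            f"try one of {allowed}"
        )
    return monitor


monitor = validate_monitor("multiclass", "cohen_kappa")
print(monitor)  # cohen_kappa
```

The returned string can then be passed as the monitor argument of Trainer.train, so a typo or a binary metric on a multiclass task fails loudly before any epoch runs.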

Custom tasks

When no built-in task fits, subclass BaseTask:

```python
from pyhealth.tasks import BaseTask

class MyTask(BaseTask):
    task_name = "MyTask"
    input_schema = {"diagnoses": "sequence", "procedures": "sequence"}
    output_schema = {"label": "binary"}

    def __call__(self, patient):
        # Iterate the patient's visits, decide which become samples,
        # extract features, compute the label, and return a list of dicts.
        samples = []
        for i, visit in enumerate(patient.visits):
            if i == len(patient.visits) - 1:
                continue  # need at least one future visit for the label
            samples.append({
                "patient_id": patient.patient_id,
                "visit_id": visit.visit_id,
                "diagnoses": visit.get_code_list("DIAGNOSES_ICD"),
                "procedures": visit.get_code_list("PROCEDURES_ICD"),
                "label": int(self._compute_label(patient, visit)),
            })
        return samples

    def _compute_label(self, patient, visit):
        # Placeholder: return a truthy/falsy value for this visit's label.
        raise NotImplementedError
```

__call__ is invoked once per patient. Returning [] for a patient excludes that patient from the SampleDataset. The schema strings ("sequence", "binary", "multilabel", "multiclass", "regression") tell PyHealth's processors how to handle each field.
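
The once-per-patient contract can be pictured with a simplified stand-in that needs no PyHealth install. This is a conceptual model only, not how set_task is implemented: `ToyPatient`, `toy_task`, and `apply_task` are hypothetical names, and the toy task just emits one sample per visit except the last.

```python
from dataclasses import dataclass, field


@dataclass
class ToyPatient:
    """Stand-in for a patient record; not a PyHealth class."""
    patient_id: str
    visits: list = field(default_factory=list)


def toy_task(patient: ToyPatient) -> list:
    """Emit one sample per visit that has a later visit to label against."""
    return [
        {"patient_id": patient.patient_id, "visit_index": i}
        for i in range(len(patient.visits) - 1)
    ]


def apply_task(patients: list, task) -> list:
    """Call the task once per patient and concatenate the returned samples."""
    samples = []
    for patient in patients:
        samples.extend(task(patient))  # an empty list contributes nothing
    return samples


patients = [ToyPatient("p1", ["v1", "v2", "v3"]), ToyPatient("p2", ["v1"])]
samples = apply_task(patients, toy_task)
print(len(samples))                          # 2
print({s["patient_id"] for s in samples})    # {'p1'}
```

Patient p2 has a single visit, so the task returns [] and p2 never appears in the resulting samples, which mirrors the exclusion behavior described above.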