docs/ddd/training-pipeline-domain-model.md
The Training & ML Pipeline is the subsystem of WiFi-DensePose that turns raw public CSI datasets into a trained pose estimation model and its downstream derivatives: contrastive embeddings, domain-generalized weights, and deterministic proof bundles. It is the bridge between research data and deployable inference.
This document defines the system using Domain-Driven Design (DDD): bounded contexts that own their data and rules, aggregate roots that enforce invariants, value objects that carry meaning, and domain events that connect everything. The goal is to make the pipeline's structure match the physics and mathematics it implements -- so that anyone reading the code (or an AI agent modifying it) understands why each piece exists, not just what it does.
Bounded Contexts:
| # | Context | Responsibility | Key ADRs | Code |
|---|---|---|---|---|
| 1 | Dataset Management | Load, validate, normalize, and preprocess training data from MM-Fi and Wi-Pose | ADR-015 | train/src/dataset.rs, train/src/subcarrier.rs |
| 2 | Model Architecture | Define the neural network, forward pass, attention mechanisms, and spatial decoding | ADR-016, ADR-020 | train/src/model.rs, train/src/graph_transformer.rs |
| 3 | Training Orchestration | Run the training loop, compute composite loss, checkpoint, and verify deterministic proofs | ADR-015, ADR-016 | train/src/trainer.rs, train/src/losses.rs, train/src/metrics.rs, train/src/proof.rs |
| 4 | Embedding & Transfer | Produce AETHER contrastive embeddings, MERIDIAN domain-generalized features, and LoRA adapters | ADR-024, ADR-027 | train/src/embedding.rs, train/src/domain.rs, train/src/sona.rs |
All code paths shown are relative to rust-port/wifi-densepose-rs/crates/ and carry the wifi-densepose- crate prefix unless otherwise noted; for example, train/src/dataset.rs lives in the wifi-densepose-train crate.
| Term | Definition |
|---|---|
| Training Run | A complete training session: configuration, epoch loop, checkpoint history, and final model weights |
| Epoch | One full pass through the training dataset; produces train loss and validation metrics |
| Checkpoint | A snapshot of model weights at a given epoch, identified by SHA-256 hash and validation PCK |
| CSI Sample | A single observation: amplitude + phase tensors, ground-truth keypoints, and visibility flags |
| Subcarrier Interpolation | Resampling CSI from source subcarrier count to the canonical 56 (114->56 for MM-Fi, 30->56 for Wi-Pose) |
| Teacher-Student | Training regime where a camera-based RGB model generates pseudo-labels; at inference the camera is removed |
| Pseudo-Label | DensePose UV surface coordinates generated by Detectron2 from paired RGB frames |
| PCK@0.2 | Percentage of Correct Keypoints within 20% of torso diameter; primary accuracy metric |
| OKS | Object Keypoint Similarity; per-keypoint Gaussian-weighted distance used in COCO evaluation |
| MPJPE | Mean Per Joint Position Error in millimeters; 3D accuracy metric |
| Hungarian Assignment | Bipartite matching of predicted persons to ground-truth using min-cost assignment |
| Dynamic Min-Cut | O(n^1.5 log n) person-to-GT assignment maintained across frames |
| Compressed CSI Buffer | Tiered-quantization temporal window: hot frames at 8-bit, warm at 5/7-bit, cold at 3-bit |
| Proof Verification | Deterministic check: fixed seed -> N training steps -> loss decreases AND SHA-256 hash matches |
| AETHER Embedding | 128-dim L2-normalized contrastive vector from the CsiToPoseTransformer backbone |
| InfoNCE Loss | Contrastive loss that pushes same-identity embeddings together and different-identity apart |
| HNSW Index | Hierarchical Navigable Small World graph for approximate nearest-neighbor embedding search |
| Domain Factorizer | Splits latent features into pose-invariant (h_pose) and environment-specific (h_env) components |
| Gradient Reversal Layer | Identity in forward pass; multiplies gradient by -lambda in backward pass to force domain invariance |
| GRL Lambda | Adversarial weight annealed from 0.0 to 1.0 over the first 20 epochs |
| FiLM Conditioning | Feature-wise Linear Modulation: gamma * features + beta, conditioned on geometry encoding |
| Hardware Normalizer | Resamples CSI from any chipset to canonical 56 subcarriers with z-score amplitude normalization |
| LoRA Adapter | Low-Rank Adaptation weights (rank r, alpha) for few-shot environment-specific fine-tuning |
| Rapid Adaptation | 10-second unlabeled calibration producing a per-room LoRA adapter via contrastive test-time training |
Responsibility: Load raw CSI data from public datasets (MM-Fi, Wi-Pose), validate structural invariants, resample subcarriers to the canonical 56, apply phase sanitization, and present typed samples to the training loop. Memory efficiency via tiered temporal compression.
+----------------------------------------------------------+
| Dataset Management Context |
+----------------------------------------------------------+
| |
| +---------------+ +---------------+ |
| | MM-Fi Loader | | Wi-Pose | |
| | (.npy files, | | Loader | |
| | 114 sub, | | (.mat files, | |
| | 40 subjects)| | 30 sub, | |
| +-------+-------+ | 12 subjects)| |
| | +-------+-------+ |
| | | |
| +--------+-----------+ |
| v |
| +----------------+ |
| | Subcarrier | |
| | Interpolator | |
| | (114->56 or | |
| | 30->56) | |
| +--------+-------+ |
| v |
| +----------------+ |
| | Phase | |
| | Sanitizer | |
| | (SOTA algs | |
| | from signal) | |
| +--------+-------+ |
| v |
| +----------------+ |
| | Compressed CSI |--> CsiSample |
| | Buffer | |
| | (tiered quant) | |
| +----------------+ |
| |
+----------------------------------------------------------+
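The interpolation and normalization stages in the diagram can be sketched in plain Rust. This shows only the linear fallback path (the preferred sparse route lives in ruvector-solver's NeumannSolver) plus the HardwareNormalizer z-score step; function names here are illustrative, not the crate's actual API.

```rust
/// Linearly resample one antenna pair's amplitudes from `src.len()`
/// subcarriers to `target` subcarriers (e.g. MM-Fi's 114 -> canonical 56).
pub fn interpolate_linear(src: &[f32], target: usize) -> Vec<f32> {
    assert!(src.len() >= 2 && target >= 2);
    let scale = (src.len() - 1) as f32 / (target - 1) as f32;
    (0..target)
        .map(|i| {
            let pos = i as f32 * scale;
            let lo = pos.floor() as usize;
            let hi = (lo + 1).min(src.len() - 1);
            let frac = pos - lo as f32;
            // Linear blend of the two bracketing source subcarriers
            src[lo] * (1.0 - frac) + src[hi] * frac
        })
        .collect()
}

/// Z-score normalize an amplitude frame in place (HardwareNormalizer step).
pub fn zscore(frame: &mut [f32]) {
    let n = frame.len() as f32;
    let mean = frame.iter().sum::<f32>() / n;
    let var = frame.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let std = var.sqrt().max(1e-8); // guard against constant frames
    for x in frame.iter_mut() {
        *x = (*x - mean) / std;
    }
}
```

Endpoints map exactly (source subcarrier 0 to target 0, last to last), which keeps the canonical 56-subcarrier grid aligned with the source band edges.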
Aggregates:
- MmFiDataset (Aggregate Root) -- Manages the MM-Fi data lifecycle
- WiPoseDataset (Aggregate Root) -- Manages the Wi-Pose data lifecycle

Value Objects:
- CsiSample -- Single observation with amplitude, phase, keypoints, visibility
- SubcarrierConfig -- Source count, target count, interpolation method
- DatasetSplit -- Train / Validation / Test subject partitioning
- CompressedCsiBuffer -- Tiered temporal window backed by TemporalTensorCompressor

Domain Services:
- SubcarrierInterpolationService -- Resamples subcarriers via sparse least-squares or linear fallback
- PhaseSanitizationService -- Applies SpotFi / MUSIC phase correction from wifi-densepose-signal
- TeacherLabelService -- Runs Detectron2 on paired RGB frames to produce DensePose UV pseudo-labels
- HardwareNormalizerService -- Z-score normalization + chipset-invariant phase sanitization

RuVector Integration:
- ruvector-solver -> NeumannSolver for sparse O(sqrt(n)) subcarrier interpolation (114->56)
- ruvector-temporal-tensor -> TemporalTensorCompressor for 50-75% memory reduction in CSI windows

Responsibility: Define the WiFiDensePoseModel: CSI embedding, cross-attention between keypoint queries and CSI features, GNN message passing, attention-gated modality fusion, and spatial decoding heads for keypoints and DensePose UV.
+----------------------------------------------------------+
| Model Architecture Context |
+----------------------------------------------------------+
| |
| +---------------+ +---------------+ |
| | CSI Embed | | Keypoint | |
| | (Linear | | Queries | |
| | 56 -> d) | | (17 learned | |
| +-------+-------+ | embeddings) | |
| | +-------+-------+ |
| | | |
| +--------+-----------+ |
| v |
| +----------------+ |
| | Cross-Attention| |
| | (Q=queries, | |
| | K,V=csi) | |
| +--------+-------+ |
| v |
| +----------------+ |
| | GNN Stack | |
| | (2-layer GCN | |
| | skeleton | |
| | adjacency) | |
| +--------+-------+ |
| v |
| body_part_features [17 x d_model] |
| | |
| +-------+--------+--------+ |
| v v v v |
| +----------+ +------+ +-----+ +-------+ |
| | Modality | | xyz | | UV | |Spatial| |
| | Transl. | | Head | | Head| |Attn | |
| | (attn | | | | | |Decoder| |
| | mincut) | | | | | | | |
| +----------+ +------+ +-----+ +-------+ |
| |
+----------------------------------------------------------+
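The query-to-CSI cross-attention at the heart of the diagram can be sketched in plain Rust: Q is the 17 learned keypoint queries, K and V are the embedded CSI tokens. This is a single-head, Vec-based sketch of the arithmetic only; the real CsiToPoseTransformer runs on tensor kernels.

```rust
/// Scaled dot-product cross-attention: one output row per keypoint query,
/// each a softmax-weighted mix of CSI value vectors.
pub fn cross_attention(
    queries: &[Vec<f32>], // [17 x d]
    keys: &[Vec<f32>],    // [n_csi x d]
    values: &[Vec<f32>],  // [n_csi x d]
) -> Vec<Vec<f32>> {
    let d = queries[0].len() as f32;
    queries
        .iter()
        .map(|q| {
            // Attention logits: q . k / sqrt(d)
            let logits: Vec<f32> = keys
                .iter()
                .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() / d.sqrt())
                .collect();
            // Softmax over CSI tokens (max-subtracted for stability)
            let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
            let sum: f32 = exps.iter().sum();
            // Weighted sum of values -> one body-part feature row
            let mut out = vec![0.0; values[0].len()];
            for (w, v) in exps.iter().zip(values) {
                for (o, x) in out.iter_mut().zip(v) {
                    *o += (w / sum) * x;
                }
            }
            out
        })
        .collect()
}
```

Stacking the 17 output rows gives exactly the body_part_features [17 x d_model] block that the GNN stack and decoding heads consume.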
Aggregates:
- WiFiDensePoseModel (Aggregate Root) -- The complete model graph

Entities:
- ModalityTranslator -- Attention-gated CSI fusion using min-cut
- CsiToPoseTransformer -- Cross-attention + GNN backbone
- KeypointHead -- Regresses 17 x (x, y, z, confidence) from body_part_features
- DensePoseHead -- Predicts body part labels and UV surface coordinates

Value Objects:
- ModelConfig -- Architecture hyperparameters (d_model, n_heads, n_gnn_layers)
- AttentionOutput -- Attended values + gating result from min-cut attention
- BodyPartFeatures -- [17 x d_model] intermediate representation

Domain Services:
- AttentionGatingService -- Applies attn_mincut to prune irrelevant antenna paths
- SpatialDecodingService -- Graph-based spatial attention among feature map locations

RuVector Integration:
- ruvector-attn-mincut -> attn_mincut for antenna-path gating in ModalityTranslator
- ruvector-attention -> ScaledDotProductAttention for spatial decoder long-range dependencies

Responsibility: Run the training loop across epochs, compute the composite loss (keypoint MSE + DensePose part CE + UV Smooth L1 + transfer MSE), evaluate validation metrics (PCK@0.2, OKS, MPJPE), manage checkpoints, and verify deterministic proof correctness.
+----------------------------------------------------------+
| Training Orchestration Context |
+----------------------------------------------------------+
| |
| +---------------+ +---------------+ |
| | Training Loop | | Loss Computer | |
| | (epoch iter, | | (composite: | |
| | batch fwd/ | | kp_mse + | |
| | bwd, optim) | | part_ce + | |
| +-------+-------+ | uv_l1 + | |
| | | transfer) | |
| | +-------+-------+ |
| +--------+-----------+ |
| v |
| +----------------+ |
| | Metric | |
| | Evaluator | |
| | (PCK, OKS, | |
| | MPJPE, | |
| | Hungarian) | |
| +--------+-------+ |
| v |
| +-------------+-------------+ |
| v v |
| +----------------+ +----------------+ |
| | Checkpoint | | Proof Verifier | |
| | Manager | | (fixed seed, | |
| | (best-by-PCK, | | 50 steps, | |
| | SHA-256 hash) | | loss + hash) | |
| +----------------+ +----------------+ |
| |
+----------------------------------------------------------+
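The Loss Computer box aggregates per-component losses into one scalar. A minimal sketch of the UV term's Smooth L1 and of the weighted sum, using the default weights documented in LossWeights (keypoint 1.0, part CE 0.5, UV 0.5, transfer 0.2); the components themselves come from the model's forward pass.

```rust
/// Smooth L1 (Huber with beta = 1), used for the DensePose UV regression term.
pub fn smooth_l1(pred: &[f64], target: &[f64]) -> f64 {
    let n = pred.len() as f64;
    pred.iter()
        .zip(target)
        .map(|(p, t)| {
            let d = (p - t).abs();
            // Quadratic near zero, linear in the tails
            if d < 1.0 { 0.5 * d * d } else { d - 0.5 }
        })
        .sum::<f64>()
        / n
}

/// Composite loss with the documented default LossWeights.
pub fn composite_total(kp_mse: f64, part_ce: f64, uv_l1: f64, transfer_mse: f64) -> f64 {
    1.0 * kp_mse + 0.5 * part_ce + 0.5 * uv_l1 + 0.2 * transfer_mse
}
```

The optional contrastive and domain-adversarial terms (when AETHER / MERIDIAN are enabled) add two more weighted summands of the same shape.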
Aggregates:
- TrainingRun (Aggregate Root) -- The complete training session

Entities:
- CheckpointManager -- Persists and selects model snapshots
- ProofVerifier -- Deterministic verification against stored hashes

Value Objects:
- TrainingConfig -- Epochs, batch_size, learning_rate, loss_weights, optimizer params
- Checkpoint -- Epoch number, model weights SHA-256, validation PCK at that epoch
- LossWeights -- Relative weights for each loss component
- CompositeTrainingLoss -- Combined scalar loss with per-component breakdown
- OksScore -- Per-keypoint Object Keypoint Similarity with sigma values
- PckScore -- Percentage of Correct Keypoints at threshold 0.2
- MpjpeScore -- Mean Per Joint Position Error in millimeters
- ProofResult -- Seed, steps, loss_decreased flag, hash_matches flag

Domain Services:
- LossComputationService -- Computes composite loss from model outputs and ground truth
- MetricEvaluationService -- Computes PCK, OKS, MPJPE over validation set
- HungarianAssignmentService -- Bipartite matching for multi-person evaluation
- DynamicPersonMatcherService -- Frame-persistent assignment via ruvector-mincut
- ProofVerificationService -- Fixed-seed training + SHA-256 verification

RuVector Integration:
- ruvector-mincut -> DynamicMinCut for O(n^1.5 log n) multi-person assignment in metrics
- hungarian_assignment kept for single-frame static matching in proof verification

Responsibility: Produce AETHER contrastive embeddings from the model backbone, train domain-adversarial features via MERIDIAN, manage the HNSW embedding index for re-ID and fingerprinting, and generate LoRA adapters for few-shot environment adaptation.
+----------------------------------------------------------+
| Embedding & Transfer Context |
+----------------------------------------------------------+
| |
| body_part_features [17 x d_model] |
| | |
| +--------+-----------+ |
| v v |
| +---------------+ +---------------+ |
| | AETHER | | MERIDIAN | |
| | Projection | | Domain | |
| | Head | | Factorizer | |
| | (MeanPool -> | | (PoseEncoder | |
| | fc -> 128d) | | + EnvEncoder)| |
| +-------+-------+ +-------+-------+ |
| | | |
| v v |
| +---------------+ +---------------+ |
| | InfoNCE Loss | | Gradient | |
| | + Hard Neg | | Reversal | |
| | Mining (HNSW) | | Layer (GRL) | |
| +-------+-------+ +-------+-------+ |
| | | |
| v v |
| +---------------+ +---------------+ |
| | Embedding | | Geometry | |
| | Index (HNSW) | | Encoder + | |
| | (fingerprint | | FiLM Cond. | |
| | store) | | (zero-shot) | |
| +---------------+ +-------+-------+ |
| | |
| v |
| +---------------+ |
| | Rapid Adapt. | |
| | (LoRA + TTT, | |
| | 10-sec cal.) | |
| +---------------+ |
| |
+----------------------------------------------------------+
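The InfoNCE branch of the diagram can be sketched directly: for an anchor embedding, the positive's similarity is pushed up relative to the negatives. This is a minimal sketch with cosine similarity and temperature tau; HNSW hard-negative mining and batching are elided, and the function name is illustrative.

```rust
/// InfoNCE over L2-normalized AETHER embeddings.
/// Because inputs are unit-length, the dot product is cosine similarity.
pub fn info_nce(anchor: &[f32], positive: &[f32], negatives: &[Vec<f32>], tau: f32) -> f32 {
    let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    let pos = (dot(anchor, positive) / tau).exp();
    let neg: f32 = negatives
        .iter()
        .map(|n| (dot(anchor, n.as_slice()) / tau).exp())
        .sum();
    // -log softmax probability of the positive among {positive} U negatives
    -(pos / (pos + neg)).ln()
}
```

Lower tau sharpens the softmax, making hard negatives (the pairs HardNegativeMiningService surfaces) dominate the gradient.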
Aggregates:
- EmbeddingIndex (Aggregate Root) -- HNSW-indexed store of AETHER fingerprints
- DomainAdaptationState (Aggregate Root) -- Tracks GRL lambda, domain classifier accuracy, factorization quality

Entities:
- ProjectionHead -- MLP mapping body_part_features to 128-dim embedding space
- DomainFactorizer -- Splits features into h_pose and h_env
- DomainClassifier -- Classifies domain from h_pose (trained adversarially via GRL)
- GeometryEncoder -- Fourier positional encoding + DeepSets for AP positions
- LoraAdapter -- Low-rank adaptation weights for environment-specific fine-tuning

Value Objects:
- AetherEmbedding -- 128-dim L2-normalized contrastive vector
- FingerprintType -- ReIdentification / RoomFingerprint / PersonFingerprint
- DomainLabel -- Environment identifier for adversarial training
- GrlSchedule -- Lambda annealing parameters (max_lambda, warmup_epochs)
- GeometryInput -- AP positions in meters relative to room origin
- FilmParameters -- Gamma (scale) and beta (shift) vectors from geometry conditioning
- LoraConfig -- Rank, alpha, target layers
- AdaptationLoss -- ContrastiveTTT / EntropyMin / Combined

Domain Services:
- ContrastiveLossService -- Computes InfoNCE loss with temperature scaling
- HardNegativeMiningService -- HNSW k-NN search for difficult negative pairs
- DomainAdversarialService -- Manages GRL annealing and domain classification
- GeometryConditioningService -- Encodes AP layout and produces FiLM parameters
- VirtualDomainAugmentationService -- Generates synthetic environment shifts for training diversity
- RapidAdaptationService -- Produces LoRA adapter from 10-second unlabeled calibration

pub struct TrainingRun {
/// Unique run identifier
pub id: TrainingRunId,
/// Full training configuration
pub config: TrainingConfig,
/// Datasets loaded for this run
pub datasets: Vec<DatasetHandle>,
/// Ordered history of per-epoch metrics
pub epoch_history: Vec<EpochRecord>,
/// Best checkpoint by validation PCK
pub best_checkpoint: Option<Checkpoint>,
/// Current epoch (0-indexed)
pub current_epoch: usize,
/// Run status
pub status: RunStatus,
/// Proof verification result (if run)
pub proof_result: Option<ProofResult>,
}
pub enum RunStatus {
Initializing,
Training,
Completed,
Failed { reason: String },
ProofVerified,
}
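The best_checkpoint field advances only on a strict improvement in validation PCK. A minimal sketch of that selection rule over a simplified record (hashes, paths, and persistence elided; Ckpt is a hypothetical stand-in for Checkpoint):

```rust
#[derive(Clone, Debug)]
pub struct Ckpt {
    pub epoch: usize,
    pub validation_pck: f64,
}

/// Replace `best` only on strictly better validation PCK.
/// Returns true when the candidate was accepted.
pub fn maybe_update_best(best: &mut Option<Ckpt>, candidate: Ckpt) -> bool {
    let improved = best
        .as_ref()
        .map_or(true, |b| candidate.validation_pck > b.validation_pck);
    if improved {
        *best = Some(candidate);
    }
    improved
}
```

Strict (rather than >=) comparison keeps the earliest epoch that reached a given PCK, making checkpoint selection deterministic across reruns.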
Invariants:
- best_checkpoint is updated only when a new epoch's validation PCK exceeds all prior epochs
- proof_result can only be set once and is immutable after verification

pub struct MmFiDataset {
/// Root directory containing .npy files
pub data_root: PathBuf,
/// Subject IDs in this split
pub subject_ids: Vec<u32>,
/// Number of action classes
pub n_actions: usize, // 27
/// Source subcarrier count
pub source_subcarriers: usize, // 114
/// Target subcarrier count after interpolation
pub target_subcarriers: usize, // 56
/// Antenna configuration: 1 TX x 3 RX
pub antenna_pairs: usize, // 3
/// Sampling rate in Hz
pub sample_rate_hz: f32, // 100.0
/// Temporal window size (frames per sample)
pub window_frames: usize, // 10
/// Compressed buffer for memory-efficient storage
pub buffer: CompressedCsiBuffer,
/// Total loaded samples
pub n_samples: usize,
}
pub struct WiPoseDataset {
/// Root directory containing .mat files
pub data_root: PathBuf,
/// Subject IDs in this split
pub subject_ids: Vec<u32>,
/// Source subcarrier count
pub source_subcarriers: usize, // 30
/// Target subcarrier count after zero-padding
pub target_subcarriers: usize, // 56
/// Antenna configuration: 3 TX x 3 RX
pub antenna_pairs: usize, // 9
/// Keypoint count (18 AlphaPose, mapped to 17 COCO)
pub source_keypoints: usize, // 18
/// Compressed buffer
pub buffer: CompressedCsiBuffer,
/// Total loaded samples
pub n_samples: usize,
}
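The two Wi-Pose-specific normalizations (30 -> 56 zero-padding and the 18 -> 17 keypoint mapping, dropping the AlphaPose neck at index 1, as described in the WiPoseAdapter below) can be sketched as:

```rust
/// Zero-pad a 30-subcarrier Wi-Pose vector up to the canonical 56,
/// padding the high-frequency end with zeros.
pub fn zero_pad(src: &[f32], target: usize) -> Vec<f32> {
    let mut out = src.to_vec();
    out.resize(target, 0.0);
    out
}

/// Map 18 AlphaPose keypoints to 17 COCO keypoints by dropping the neck.
pub fn drop_neck(kps: &[[f32; 3]; 18]) -> Vec<[f32; 3]> {
    kps.iter()
        .enumerate()
        .filter(|(i, _)| *i != 1) // AlphaPose neck is index 1
        .map(|(_, k)| *k)
        .collect()
}
```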
pub struct WiFiDensePoseModel {
/// CSI embedding layer: Linear(56, d_model)
pub csi_embed: Linear,
/// Learned keypoint query embeddings [17 x d_model]
pub keypoint_queries: Tensor,
/// Cross-attention: Q=queries, K,V=csi_embed
pub cross_attention: MultiHeadAttention,
/// GNN message passing on skeleton graph
pub gnn_stack: GnnStack,
/// Modality translator with attention-gated fusion
pub modality_translator: ModalityTranslator,
/// Keypoint regression head
pub keypoint_head: KeypointHead,
/// DensePose UV prediction head
pub densepose_head: DensePoseHead,
/// Spatial attention decoder
pub spatial_decoder: SpatialAttentionDecoder,
/// Model dimensionality
pub d_model: usize, // 64
}
pub struct EmbeddingIndex {
/// HNSW graph for approximate nearest-neighbor search
pub hnsw: HnswIndex,
/// Stored embeddings with metadata
pub entries: Vec<EmbeddingEntry>,
/// Embedding dimensionality
pub dim: usize, // 128
/// Number of indexed embeddings
pub count: usize,
/// HNSW construction parameters
pub ef_construction: usize, // 200
pub m_connections: usize, // 16
}
pub struct EmbeddingEntry {
pub id: EmbeddingId,
pub embedding: Vec<f32>, // [128], L2-normalized
pub fingerprint_type: FingerprintType,
pub source_domain: Option<DomainLabel>,
pub created_at: u64,
}
pub enum FingerprintType {
ReIdentification,
RoomFingerprint,
PersonFingerprint,
}
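AetherEmbedding entries are stored L2-normalized, so the index's inner-product distances behave as cosine similarity. A minimal sketch of that hygiene step (names illustrative):

```rust
/// Scale a vector to unit L2 norm, guarding against the zero vector.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
    for x in v.iter_mut() {
        *x /= norm;
    }
}

/// Dot product; equals cosine similarity only for L2-normalized inputs.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```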
pub struct CsiSample {
/// Amplitude tensor [n_antenna_pairs x n_subcarriers x n_time_frames]
pub amplitude: Vec<f32>,
/// Phase tensor [n_antenna_pairs x n_subcarriers x n_time_frames]
pub phase: Vec<f32>,
/// Ground-truth 3D keypoints [17 x 3] (x, y, z in meters)
pub keypoints: [[f32; 3]; 17],
/// Per-keypoint visibility flags
pub visibility: [f32; 17],
/// DensePose UV pseudo-labels (optional, from teacher model)
pub densepose_uv: Option<DensePoseLabels>,
/// Domain label for adversarial training
pub domain_label: Option<DomainLabel>,
/// Hardware source type
pub hardware_type: HardwareType,
}
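The amplitude and phase fields are flat Vec<f32> in [n_antenna_pairs x n_subcarriers x n_time_frames] layout. Assuming row-major order, addressing one (pair, subcarrier, frame) cell is:

```rust
/// Row-major index into a flat [pairs x subcarriers x frames] tensor.
pub fn csi_index(pair: usize, sc: usize, frame: usize, n_sc: usize, n_frames: usize) -> usize {
    (pair * n_sc + sc) * n_frames + frame
}
```

For the MM-Fi shape [3 x 56 x 10] the last valid index is csi_index(2, 55, 9, 56, 10) = 1679, i.e. the buffer holds 3 * 56 * 10 = 1680 floats.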
pub struct TrainingConfig {
/// Number of training epochs
pub epochs: usize,
/// Mini-batch size
pub batch_size: usize,
/// Initial learning rate
pub learning_rate: f64, // 1e-3
/// Learning rate schedule: step decay at these epochs
pub lr_decay_epochs: Vec<usize>, // [40, 80]
/// Learning rate decay factor
pub lr_decay_factor: f64, // 0.1
/// Loss component weights
pub loss_weights: LossWeights,
/// Optimizer (Adam)
pub optimizer: OptimizerConfig,
/// Validation subject IDs (MM-Fi: 33-40)
pub val_subjects: Vec<u32>,
/// Random seed for reproducibility
pub seed: u64,
/// Enable MERIDIAN domain-adversarial training
pub meridian_enabled: bool,
/// Enable AETHER contrastive learning
pub aether_enabled: bool,
}
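The step-decay schedule implied by learning_rate, lr_decay_epochs, and lr_decay_factor can be sketched as: multiply the base rate by the decay factor once for each decay epoch already reached.

```rust
/// Step-decay learning rate: base * factor^(number of decay epochs passed).
pub fn lr_at_epoch(base: f64, decay_epochs: &[usize], factor: f64, epoch: usize) -> f64 {
    let n = decay_epochs.iter().filter(|&&e| epoch >= e).count() as i32;
    base * factor.powi(n)
}
```

With the documented defaults (1e-3, decay at epochs [40, 80], factor 0.1) the rate steps 1e-3 -> 1e-4 -> 1e-5.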
pub struct LossWeights {
/// Keypoint heatmap MSE weight
pub keypoint_mse: f32, // 1.0
/// DensePose body part cross-entropy weight
pub densepose_part_ce: f32, // 0.5
/// DensePose UV Smooth L1 weight
pub uv_smooth_l1: f32, // 0.5
/// Teacher-student transfer MSE weight
pub transfer_mse: f32, // 0.2
/// AETHER contrastive loss weight (ADR-024)
pub contrastive: f32, // 0.1
/// MERIDIAN domain adversarial weight (ADR-027)
pub domain_adversarial: f32, // annealed 0.0 -> 1.0
}
pub struct Checkpoint {
/// Epoch at which this checkpoint was saved
pub epoch: usize,
/// SHA-256 hash of serialized model weights
pub weights_hash: String,
/// Validation PCK@0.2 at this epoch
pub validation_pck: f64,
/// Validation OKS at this epoch
pub validation_oks: f64,
/// File path to saved weights
pub path: PathBuf,
/// Timestamp
pub created_at: u64,
}
pub struct ProofResult {
/// Seed used for model initialization
pub model_seed: u64, // MODEL_SEED = 0
/// Seed used for proof data generation
pub proof_seed: u64, // PROOF_SEED = 42
/// Number of training steps in proof
pub steps: usize, // 50
/// Whether loss decreased monotonically
pub loss_decreased: bool,
/// Whether final weights hash matches stored expected hash
pub hash_matches: bool,
/// The computed SHA-256 hash
pub computed_hash: String,
/// The expected SHA-256 hash (from file)
pub expected_hash: String,
}
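A sketch of the two ProofResult predicates in isolation. The loss check shown here compares first versus last step; the ProofVerifier's stricter monotonic variant would compare every adjacent pair. SHA-256 computation itself is elided, and the function names are illustrative.

```rust
/// Did the loss trace from the fixed-seed proof steps decrease overall?
pub fn loss_decreased(trace: &[f64]) -> bool {
    trace.len() >= 2 && trace.last().unwrap() < trace.first().unwrap()
}

/// Proof passes when the loss decreased AND the computed weights hash
/// matches the stored expected hash.
pub fn proof_passes(trace: &[f64], computed_hash: &str, expected_hash: &str) -> bool {
    loss_decreased(trace) && computed_hash == expected_hash
}
```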
pub struct LoraAdapter {
/// Low-rank decomposition rank
pub rank: usize, // 4
/// LoRA alpha scaling factor
pub alpha: f32, // 1.0
/// Per-layer weight matrices (A and B for each adapted layer)
pub weights: Vec<LoraLayerWeights>,
/// Source domain this adapter was calibrated for
pub source_domain: DomainLabel,
/// Calibration duration in seconds
pub calibration_duration_secs: f32,
/// Number of calibration frames used
pub calibration_frames: usize,
}
pub struct LoraLayerWeights {
/// Layer name in the model
pub layer_name: String,
/// Down-projection: [d_model x rank]
pub a: Vec<f32>,
/// Up-projection: [rank x d_model]
pub b: Vec<f32>,
}
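The adapted layer's forward pass follows the usual LoRA math: y = W x + (alpha / rank) * B (A x), with A the down-projection to `rank` dims and B the up-projection back to d_model. The sketch below assumes row-major flat storage with A as a [rank x d_model] matrix and B as [d_model x rank]; LoraLayerWeights' shape comments use the transposed convention, so adjust the layout to match the actual serialization.

```rust
/// Row-major matrix-vector product: m is [rows x cols], x is [cols].
fn matvec(m: &[f32], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
    (0..rows)
        .map(|r| (0..cols).map(|c| m[r * cols + c] * x[c]).sum::<f32>())
        .collect()
}

/// y = W x + (alpha / rank) * B (A x)
pub fn lora_forward(
    w: &[f32], a: &[f32], b: &[f32],
    d_model: usize, rank: usize, alpha: f32,
    x: &[f32],
) -> Vec<f32> {
    let base = matvec(w, d_model, d_model, x);   // frozen base layer
    let down = matvec(a, rank, d_model, x);      // project to rank dims
    let up = matvec(b, d_model, rank, &down);    // project back to d_model
    let scale = alpha / rank as f32;
    base.iter().zip(&up).map(|(y, u)| y + scale * u).collect()
}
```

Because the base weights stay frozen, a per-room adapter only stores the small A and B matrices, which is what makes the 10-second rapid calibration practical.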
pub enum DatasetEvent {
/// Dataset loaded and validated
DatasetLoaded {
dataset_type: DatasetType,
n_samples: usize,
n_subjects: u32,
source_subcarriers: usize,
timestamp: u64,
},
/// Subcarrier interpolation completed for a dataset
SubcarrierInterpolationComplete {
dataset_type: DatasetType,
source_count: usize,
target_count: usize,
method: InterpolationMethod,
timestamp: u64,
},
/// Teacher pseudo-labels generated for a batch
PseudoLabelsGenerated {
n_samples: usize,
n_with_uv: usize,
timestamp: u64,
},
}
pub enum DatasetType {
MmFi,
WiPose,
Synthetic,
}
pub enum InterpolationMethod {
/// ruvector-solver NeumannSolver sparse least-squares
SparseNeumannSolver,
/// Fallback linear interpolation
LinearInterpolation,
/// Wi-Pose zero-padding
ZeroPad,
}
pub enum TrainingEvent {
/// One epoch of training completed
EpochCompleted {
epoch: usize,
train_loss: f64,
val_pck: f64,
val_oks: f64,
val_mpjpe_mm: f64,
learning_rate: f64,
grl_lambda: f32,
timestamp: u64,
},
/// New best checkpoint saved
CheckpointSaved {
epoch: usize,
weights_hash: String,
validation_pck: f64,
path: String,
timestamp: u64,
},
/// Deterministic proof verification completed
ProofVerified {
model_seed: u64,
proof_seed: u64,
steps: usize,
loss_decreased: bool,
hash_matches: bool,
timestamp: u64,
},
/// Training run completed or failed
TrainingRunFinished {
run_id: String,
status: RunStatus,
total_epochs: usize,
best_pck: f64,
best_oks: f64,
timestamp: u64,
},
}
pub enum EmbeddingEvent {
/// New AETHER embedding indexed
EmbeddingIndexed {
embedding_id: String,
fingerprint_type: FingerprintType,
nearest_neighbor_distance: f32,
index_size: usize,
timestamp: u64,
},
/// Hard negative pair discovered during mining
HardNegativeFound {
anchor_id: String,
negative_id: String,
similarity: f32,
timestamp: u64,
},
/// Domain adaptation completed for a target environment
DomainAdaptationComplete {
source_domain: String,
target_domain: String,
pck_before: f64,
pck_after: f64,
adaptation_method: String,
timestamp: u64,
},
/// LoRA adapter generated via rapid calibration
LoraAdapterGenerated {
domain: String,
rank: usize,
calibration_frames: usize,
calibration_seconds: f32,
timestamp: u64,
},
}
Invariants:
- CompressedCsiBuffer must preserve signal fidelity within quantization error bounds (hot: <1% error)
- csi_embed input dimension must equal the canonical 56 subcarriers
- keypoint_queries must have exactly 17 entries (one per COCO keypoint)
- attn_mincut seq_len must equal n_antenna_pairs * n_time_frames
- best_checkpoint is updated if and only if current val_pck > all previous val_pck values
- DomainFactorizer output dimensions: h_pose = [17 x 64], h_env = [32]
- GeometryEncoder must be permutation-invariant with respect to AP ordering (DeepSets guarantee)

Resamples CSI subcarriers from source to target count using physically-motivated sparse interpolation.
pub trait SubcarrierInterpolationService {
/// Sparse interpolation via NeumannSolver (O(sqrt(n)), preferred)
fn interpolate_sparse(
&self,
source: &[f32],
source_count: usize,
target_count: usize,
tolerance: f64,
) -> Result<Vec<f32>, InterpolationError>;
/// Linear interpolation fallback (O(n))
fn interpolate_linear(
&self,
source: &[f32],
source_count: usize,
target_count: usize,
) -> Vec<f32>;
/// Zero-pad for Wi-Pose (30 -> 56)
fn zero_pad(
&self,
source: &[f32],
target_count: usize,
) -> Vec<f32>;
}
Computes the composite training loss from model outputs and ground truth.
pub trait LossComputationService {
/// Compute composite loss with per-component breakdown
fn compute(
&self,
predictions: &ModelOutput,
targets: &GroundTruth,
weights: &LossWeights,
) -> CompositeTrainingLoss;
}
pub struct CompositeTrainingLoss {
/// Total weighted loss (scalar for backprop)
pub total: f64,
/// Keypoint heatmap MSE component
pub keypoint_mse: f64,
/// DensePose body part cross-entropy component
pub densepose_part_ce: f64,
/// DensePose UV Smooth L1 component
pub uv_smooth_l1: f64,
/// Teacher-student transfer MSE component
pub transfer_mse: f64,
/// AETHER contrastive loss (if enabled)
pub contrastive: Option<f64>,
/// MERIDIAN domain adversarial loss (if enabled)
pub domain_adversarial: Option<f64>,
}
Evaluates model accuracy on the validation set using standard pose estimation metrics.
pub trait MetricEvaluationService {
/// PCK@0.2: fraction of keypoints within 20% of torso diameter
fn compute_pck(&self, predictions: &[PosePrediction], targets: &[PoseTarget], threshold: f64) -> PckScore;
/// OKS: Object Keypoint Similarity with per-keypoint sigmas
fn compute_oks(&self, predictions: &[PosePrediction], targets: &[PoseTarget]) -> OksScore;
/// MPJPE: Mean Per Joint Position Error in millimeters
fn compute_mpjpe(&self, predictions: &[PosePrediction], targets: &[PoseTarget]) -> MpjpeScore;
/// Multi-person assignment via Hungarian (static, deterministic)
fn assign_hungarian(&self, pred: &[PosePrediction], gt: &[PoseTarget]) -> Vec<(usize, usize)>;
/// Multi-person assignment via DynamicMinCut (persistent, O(n^1.5 log n))
fn assign_dynamic(&mut self, pred: &[PosePrediction], gt: &[PoseTarget]) -> Vec<(usize, usize)>;
}
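The PCK computation reduces to a thresholded distance count. A minimal single-person sketch in 2D, counting only visible keypoints (an assumption about the service's exact convention, consistent with CsiSample's visibility flags):

```rust
/// Fraction of visible keypoints within threshold * torso_diameter of GT.
pub fn pck(
    pred: &[[f32; 2]],
    gt: &[[f32; 2]],
    visible: &[bool],
    torso_diameter: f32,
    threshold: f32, // 0.2 for PCK@0.2
) -> f32 {
    let mut correct = 0usize;
    let mut total = 0usize;
    for ((p, g), &v) in pred.iter().zip(gt).zip(visible) {
        if !v {
            continue;
        }
        total += 1;
        let d = ((p[0] - g[0]).powi(2) + (p[1] - g[1]).powi(2)).sqrt();
        if d <= threshold * torso_diameter {
            correct += 1;
        }
    }
    if total == 0 { 0.0 } else { correct as f32 / total as f32 }
}
```

Multi-person evaluation first pairs predictions to ground truth (Hungarian or DynamicMinCut, per the trait above), then applies this per matched pair.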
Manages the MERIDIAN gradient reversal training regime.
pub trait DomainAdversarialService {
/// Compute GRL lambda for the current epoch
fn grl_lambda(&self, epoch: usize, max_warmup_epochs: usize) -> f32;
/// Forward pass through domain classifier with gradient reversal
fn classify_domain(
&self,
h_pose: &Tensor,
lambda: f32,
) -> Tensor;
/// Compute domain adversarial loss (cross-entropy on domain logits)
fn domain_loss(
&self,
domain_logits: &Tensor,
domain_labels: &Tensor,
) -> f64;
}
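The GrlSchedule annealing and the reversal semantics can be sketched as follows. The linear ramp is an assumption consistent with "annealed from 0.0 to 1.0 over the first 20 epochs"; the DANN-style sigmoid ramp is a common alternative.

```rust
/// Linear GRL lambda ramp: 0.0 -> max_lambda over warmup_epochs, then held.
pub fn grl_lambda(epoch: usize, warmup_epochs: usize, max_lambda: f32) -> f32 {
    if warmup_epochs == 0 {
        return max_lambda;
    }
    max_lambda * (epoch as f32 / warmup_epochs as f32).min(1.0)
}

/// Gradient reversal semantics: identity in the forward pass; in the
/// backward pass the upstream gradient is scaled by -lambda.
/// Shown on a scalar gradient for illustration.
pub fn grl_backward(upstream_grad: f32, lambda: f32) -> f32 {
    -lambda * upstream_grad
}
```

Starting lambda at 0.0 lets the pose features stabilize before the adversarial pressure toward domain invariance kicks in.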
+------------------------------------------------------------------+
| Training Pipeline System |
+------------------------------------------------------------------+
| |
| +------------------+ CsiSample +------------------+ |
| | Dataset |-------------->| Training | |
| | Management | | Orchestration | |
| | Context | | Context | |
| +--------+---------+ +--------+-----------+ |
| | | |
| | Publishes | Publishes |
| | DatasetEvent | TrainingEvent |
| v v |
| +------------------------------------------------------+ |
| | Event Bus (Domain Events) | |
| +------------------------------------------------------+ |
| | | |
| v v |
| +------------------+ +------------------+ |
| | Model |<-------------| Embedding & | |
| | Architecture | body_part_ | Transfer | |
| | Context | features | Context | |
| +------------------+ +------------------+ |
| |
+------------------------------------------------------------------+
| UPSTREAM (Conformist) |
| +--------------+ +--------------+ +--------------+ |
| |wifi-densepose| |wifi-densepose| |wifi-densepose| |
| | -signal | | -nn | | -core | |
| | (phase algs,| | (ONNX, | | (CsiFrame, | |
| | SpotFi) | | Candle) | | error) | |
| +--------------+ +--------------+ +--------------+ |
| |
+------------------------------------------------------------------+
| SIBLING (Partnership) |
| +--------------+ +--------------+ +--------------+ |
| | RuvSense | | MAT | | Sensing | |
| | (pose | | (triage, | | Server | |
| | tracker, | | survivor) | | (inference | |
| | field | | | | deployment) | |
| | model) | | | | | |
| +--------------+ +--------------+ +--------------+ |
| |
+------------------------------------------------------------------+
| EXTERNAL (Published Language) |
| +--------------+ +--------------+ +--------------+ |
| | MM-Fi | | Wi-Pose | | Detectron2 | |
| | (NeurIPS | | (NjtechCV | | (teacher | |
| | dataset) | | dataset) | | labels) | |
| +--------------+ +--------------+ +--------------+ |
+------------------------------------------------------------------+
Relationship Types: upstream crates are consumed Conformist-style (we adopt their models as-is), sibling systems evolve with us in Partnership, and the external datasets and tools define a Published Language of fixed public formats.
/// Translates raw MM-Fi numpy files into domain CsiSample values.
/// Handles the 114->56 subcarrier interpolation and 1TX/3RX antenna layout.
pub struct MmFiAdapter {
/// Subcarrier interpolation service
interpolator: Box<dyn SubcarrierInterpolationService>,
/// Phase sanitizer from wifi-densepose-signal
phase_sanitizer: PhaseSanitizer,
/// Hardware normalizer for z-score normalization
normalizer: HardwareNormalizer,
}
impl MmFiAdapter {
/// Load a single MM-Fi sample from .npy tensors and produce a CsiSample.
/// Steps:
/// 1. Read amplitude [3, 114, 10] and phase [3, 114, 10]
/// 2. Interpolate 114 -> 56 subcarriers per antenna pair
/// 3. Sanitize phase (remove linear offset, unwrap)
/// 4. Z-score normalize amplitude per frame
/// 5. Read 17-keypoint COCO annotations
pub fn adapt(&self, raw: &MmFiRawSample) -> Result<CsiSample, AdapterError>;
}
/// Translates Wi-Pose .mat files into domain CsiSample values.
/// Handles 30->56 zero-padding and 18->17 keypoint mapping.
pub struct WiPoseAdapter {
/// Zero-padding service
interpolator: Box<dyn SubcarrierInterpolationService>,
/// Phase sanitizer
phase_sanitizer: PhaseSanitizer,
}
impl WiPoseAdapter {
/// Load a Wi-Pose sample from .mat format and produce a CsiSample.
/// Steps:
/// 1. Read CSI [9, 30] (3x3 antenna pairs, 30 subcarriers)
/// 2. Zero-pad 30 -> 56 subcarriers (high-frequency padding)
/// 3. Sanitize phase
/// 4. Map 18 AlphaPose keypoints -> 17 COCO (drop neck, index 1)
pub fn adapt(&self, raw: &WiPoseRawSample) -> Result<CsiSample, AdapterError>;
}
/// Adapts Detectron2 DensePose outputs into domain DensePoseLabels.
/// Used during teacher-student pseudo-label generation.
pub struct TeacherModelAdapter;
impl TeacherModelAdapter {
/// Run Detectron2 DensePose on an RGB frame and produce pseudo-labels.
/// Output: (part_labels [H x W], u_coords [H x W], v_coords [H x W])
pub fn generate_pseudo_labels(
&self,
rgb_frame: &RgbFrame,
) -> Result<DensePoseLabels, AdapterError>;
}
/// Adapts ruvector-attn-mincut API to the model's tensor format.
/// Handles the Tensor <-> Vec<f32> conversion overhead per batch element.
pub struct AttnMinCutAdapter;
impl AttnMinCutAdapter {
/// Apply min-cut gated attention to antenna-path features.
/// Converts [B, n_ant, n_sc] tensor to flat Vec<f32> per batch element,
/// calls attn_mincut, and reshapes output back to tensor.
pub fn apply(
&self,
features: &Tensor,
n_antenna_paths: usize,
n_subcarriers: usize,
lambda: f32,
) -> Result<Tensor, AdapterError>;
}
/// Persists and retrieves training run state
pub trait TrainingRunRepository {
fn save(&self, run: &TrainingRun) -> Result<(), RepositoryError>;
fn find_by_id(&self, id: &TrainingRunId) -> Result<Option<TrainingRun>, RepositoryError>;
fn find_latest(&self) -> Result<Option<TrainingRun>, RepositoryError>;
fn list_completed(&self) -> Result<Vec<TrainingRun>, RepositoryError>;
}
/// Persists model checkpoints
pub trait CheckpointRepository {
fn save(&self, checkpoint: &Checkpoint) -> Result<(), RepositoryError>;
fn find_best(&self, run_id: &TrainingRunId) -> Result<Option<Checkpoint>, RepositoryError>;
fn find_by_epoch(&self, run_id: &TrainingRunId, epoch: usize) -> Result<Option<Checkpoint>, RepositoryError>;
fn list_all(&self, run_id: &TrainingRunId) -> Result<Vec<Checkpoint>, RepositoryError>;
}
/// Persists AETHER embedding index
pub trait EmbeddingRepository {
fn save_index(&self, index: &EmbeddingIndex) -> Result<(), RepositoryError>;
fn load_index(&self) -> Result<Option<EmbeddingIndex>, RepositoryError>;
fn add_entry(&self, entry: &EmbeddingEntry) -> Result<(), RepositoryError>;
fn search_knn(&self, query: &[f32], k: usize) -> Result<Vec<(EmbeddingEntry, f32)>, RepositoryError>;
}
/// Persists LoRA adapters for environment-specific fine-tuning
pub trait LoraRepository {
fn save(&self, adapter: &LoraAdapter) -> Result<(), RepositoryError>;
fn find_by_domain(&self, domain: &DomainLabel) -> Result<Option<LoraAdapter>, RepositoryError>;
fn list_all(&self) -> Result<Vec<LoraAdapter>, RepositoryError>;
}