plugins/plugin-local-inference/native/kokoro_training/README.md
Training implementation for English text-to-speech using the Kokoro Transformer architecture with LJSpeech dataset support.
This is a simplified training implementation based on the Kokoro architecture. The official Kokoro-82M uses a decoder-only architecture based on StyleTTS 2 and iSTFTNet. It employs a phoneme-level BERT text encoder, a style encoder for prosody control, a WavLM-based discriminator (12 layers, pre-trained on 94k hours), and an iSTFTNet vocoder that generates magnitude and phase for inverse STFT conversion. Training uses two stages: acoustic modules for mel-spectrogram reconstruction, then TTS prediction modules with style diffusion and adversarial training.
This implementation instead uses explicit MFA-derived durations with a duration predictor, teacher forcing with standard multi-head attention, no style encoder or multi-speaker embeddings, a simple encoder-decoder transformer (~22M parameters vs 82M), and an external HiFi-GAN vocoder. It prioritizes training clarity and educational value over production architecture.
| Component | Kokoro-82M (Official) | This Implementation |
|---|---|---|
| Architecture | Decoder-only (StyleTTS 2 + iSTFTNet) | Encoder-decoder transformer |
| Parameters | 82M | ~22M |
| Text Encoder | Phoneme-level BERT (pre-trained) | Standard transformer (6 layers) |
| Style Encoder | Yes (prosody/speaker control) | No |
| Alignment | Learned via diffusion | Explicit MFA durations |
| Discriminator | WavLM (12 layers, 94k hours) | None |
| Training | Two-stage + adversarial | Single-stage supervised |
| Vocoder | Integrated iSTFTNet | External HiFi-GAN |
| Multi-speaker | Yes (zero-shot) | No |
| Training Data | Few hundred hours | LJSpeech (24 hours) |
Full encoder-decoder transformer with multi-head attention, phoneme-level duration alignment using Montreal Forced Aligner (MFA), CUDA support with optional mixed precision training, experiment tracking via Weights & Biases, checkpoint management for resuming training, and adaptive memory management with gradient checkpointing.
Install dependencies:
pip install -r requirements.txt
The system uses g2p_en for grapheme-to-phoneme conversion (ARPA phonemes), which perfectly matches MFA's english_us_arpa alignment model.
Download LJSpeech dataset with pre-aligned MFA annotations (3.8GB, recommended - saves 1-3 hours):
python setup_ljspeech.py --zenodo
Or download the raw dataset and run the alignment yourself:
# Download dataset only
python setup_ljspeech.py
# Then run alignment with standard dictionary (matches g2p_en)
python setup_ljspeech.py --align --no-custom-dict
Note: The --no-custom-dict flag uses MFA's standard english_us_arpa dictionary, whose ARPA phoneme set matches the output of our g2p_en phoneme processor. Alignments and training inputs therefore use the same phoneme symbols.
Start training:
python training_english.py --corpus LJSpeech-1.1 --wandb
Resume from checkpoint:
python training_english.py --corpus LJSpeech-1.1 --resume auto --wandb
Generate speech:
python inference_english.py \
--model kokoro_english_model/kokoro_english_final.pth \
--text "Hello world, this is a test." \
--output output.wav
| Argument | Default | Description |
|---|---|---|
| `--corpus` | LJSpeech-1.1 | Path to LJSpeech dataset |
| `--output` | ./kokoro_english_model | Output directory for checkpoints |
| `--batch-size` | from config (32) | Training batch size |
| `--epochs` | from config (300) | Number of training epochs |
| `--learning-rate` | from config (1e-3) | Learning rate |
| `--model-size` | default | Model size: small (6M), medium (25M, recommended for LJSpeech), default (62M), large (120M) |
| `--device` | auto | Device: auto, cuda, mps, cpu |
| `--resume` | None | Resume from checkpoint (auto for latest, or path to .pth) |
| `--wandb` | False | Enable Weights & Biases logging |
| `--no-gradient-checkpointing` | False | Disable gradient checkpointing (uses more memory) |
| `--no-mixed-precision` | False | Disable mixed precision training |
| `--test-mode` | False | Quick test with 100 samples, 5 epochs |
The model consists of a text encoder (6-layer transformer with 8 attention heads), duration predictor (MLP for phoneme durations), length regulator (expands encoder outputs), mel decoder (6-layer transformer with masked attention), PostNet (5-layer convolutional refinement network from Tacotron 2), and stop token predictor.
Configuration: 512 hidden dimensions, 6 encoder/decoder layers, 8 attention heads, 2048 feed-forward dimensions, 80 mel channels, 22,050 Hz sample rate. Gradient checkpointing enabled for memory efficiency.
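The length regulator is the step that turns phoneme-level encoder outputs into frame-level decoder inputs. The following is a minimal pure-Python sketch of that expansion, not the code in kokoro/model.py: the function name `length_regulate` is hypothetical, durations are assumed to be integer frame counts, and the real module presumably works on batched tensors.

```python
from typing import List, Sequence


def length_regulate(encoder_out: Sequence[Sequence[float]],
                    durations: Sequence[int]) -> List[List[float]]:
    """Repeat each phoneme vector durations[i] times along the time axis."""
    if len(encoder_out) != len(durations):
        raise ValueError("need exactly one duration per phoneme")
    frames: List[List[float]] = []
    for vector, n_frames in zip(encoder_out, durations):
        if n_frames < 0:
            raise ValueError("durations must be non-negative")
        # A duration of 0 drops the phoneme; otherwise it is copied n_frames times.
        frames.extend(list(vector) for _ in range(n_frames))
    return frames


# Three phonemes with 2-dimensional hidden states (illustrative values only).
hidden = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
expanded = length_regulate(hidden, [2, 0, 3])
print(len(expanded))  # number of mel frames: 2 + 0 + 3
```

The output length equals the sum of the durations, which is why accurate MFA durations (or a well-trained duration predictor) determine the length of the generated mel spectrogram.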
Scheduled Sampling (CRITICAL for inference quality):
PostNet Architecture:
mel_final = mel_coarse + 0.5 * mel_residual
Loss Configuration:
Training Optimizations:
Ground-truth durations during early epochs (use_gt_durations_until_epoch)
Input: Phoneme indices
↓
Text Encoder (Transformer)
↓
Duration Predictor → predicted durations
↓
Length Regulator (expand by durations)
↓
Decoder (Transformer with scheduled sampling)
├─ Teacher forcing (0-500 batches)
├─ Mixed sampling (500-2000 batches)
└─ Full exposure (2000+ batches)
↓
Mel Projection Coarse → coarse mels
↓
PostNet (5-layer Conv1D) → residual
↓
mel_final = coarse + 0.5 * residual
↓
Clamp to [-11.5, 0.0]
↓
Output: Mel spectrogram → HiFi-GAN vocoder → Audio
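The last two mel-processing steps of the pipeline above (residual combination and clamping) can be sketched on plain lists of mel values. This is an illustration under stated assumptions: `combine_and_clamp` is a hypothetical helper, and only the 0.5 scale and the [-11.5, 0.0] range come from the diagram.

```python
from typing import List

RESIDUAL_SCALE = 0.5           # from the diagram: coarse + 0.5 * residual
MEL_MIN, MEL_MAX = -11.5, 0.0  # clamp range from the diagram


def combine_and_clamp(coarse: List[float], residual: List[float]) -> List[float]:
    """Add the scaled PostNet residual to the coarse mel, then clamp each value."""
    if len(coarse) != len(residual):
        raise ValueError("coarse and residual must have the same length")
    return [
        min(MEL_MAX, max(MEL_MIN, c + RESIDUAL_SCALE * r))
        for c, r in zip(coarse, residual)
    ]


# Illustrative values: one below the range, one inside, one above after refinement.
mel_final = combine_and_clamp([-12.0, -5.0, -0.2], [0.0, 1.0, 2.0])
print(mel_final)
```

Clamping keeps the refined mel values inside the range shown in the diagram before they are handed to the vocoder.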
Scheduled Sampling prevents exposure bias - the most common failure mode in autoregressive TTS:
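The three phases in the diagram above (teacher forcing for batches 0-500, mixed sampling for 500-2000, full exposure from 2000 on) could be expressed as a teacher-forcing probability. The sketch below is hypothetical: the linear decay during the mixed phase and the reading of "full exposure" as "the decoder only sees its own predictions" are assumptions, and neither function name comes from the repository.

```python
import random

TEACHER_FORCING_END = 500    # batches, from the diagram
FULL_EXPOSURE_START = 2000   # batches, from the diagram


def teacher_forcing_prob(batch_idx: int) -> float:
    """Probability of feeding the ground-truth mel frame to the decoder."""
    if batch_idx < TEACHER_FORCING_END:
        return 1.0
    if batch_idx >= FULL_EXPOSURE_START:
        return 0.0
    # Assumption: linear decay across the mixed-sampling phase.
    span = FULL_EXPOSURE_START - TEACHER_FORCING_END
    return 1.0 - (batch_idx - TEACHER_FORCING_END) / span


def choose_decoder_input(gt_frame, predicted_frame, batch_idx, rng=random):
    """Pick the ground-truth frame or the model's own previous prediction."""
    if rng.random() < teacher_forcing_prob(batch_idx):
        return gt_frame
    return predicted_frame


for step in (0, 1250, 2000):
    print(step, teacher_forcing_prob(step))
```

Because the decoder gradually consumes its own outputs during training, the conditions at inference time (where no ground-truth frames exist) are no longer unseen.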
PostNet on Complete Sequences is required for proper temporal context:
Loss Weight Balance prevents gradient imbalance:
Based on overfit test success (mel loss 0.016 after 3000 iterations):
Note: Mel loss on the full dataset will be higher than in the overfit test (0.2-0.3 vs 0.016) because the model has to generalize across about 13,000 diverse samples rather than memorize a single one. Audio quality should be excellent at ~0.3.
Training crashes with "unexpected keyword argument":
forward() wrapper now accepts use_gt_durations and decoder_input_mels
git pull or check kokoro/model.py lines 777-803
Model produces garbage audio at inference despite low training loss:
enable_scheduled_sampling: True in training/config_english.py
Mel loss stuck above 1.0 after epoch 20:
duration_loss_weight = 0.01 (not 0.25 or 1.0) in config
grep "PostNet" kokoro/model.py should find it
Audio quality not improving after epoch 50:
Out of memory errors:
Run inference tests periodically to catch issues early:
# Every 10 epochs, generate test audio
python inference_english.py \
--model kokoro_english_model/checkpoint_epoch_20.pth \
--text "The quick brown fox jumps over the lazy dog." \
--output test_epoch_20.wav
Compare audio quality across epochs:
If any checkpoint produces garbage audio, scheduled sampling may have been disabled. Check training logs.
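To make the periodic comparison repeatable, the inference command shown above can be generated for several checkpoints at once. This sketch only builds the commands. It uses the `--model`, `--text`, and `--output` flags and the `checkpoint_epoch_N.pth` naming seen in this README. The helper name `build_inference_commands` is hypothetical.

```python
import shlex
from typing import List


def build_inference_commands(epochs: List[int],
                             text: str,
                             model_dir: str = "kokoro_english_model") -> List[List[str]]:
    """Return one inference_english.py argv list per checkpoint epoch."""
    commands = []
    for epoch in epochs:
        commands.append([
            "python", "inference_english.py",
            "--model", f"{model_dir}/checkpoint_epoch_{epoch}.pth",
            "--text", text,
            "--output", f"test_epoch_{epoch}.wav",
        ])
    return commands


cmds = build_inference_commands(
    [10, 20, 30], "The quick brown fox jumps over the lazy dog."
)
for cmd in cmds:
    # Print a shell-safe command line; subprocess.run(cmd, check=True) would execute it.
    print(shlex.join(cmd))
```

Using the same sentence for every checkpoint makes differences between the resulting WAV files attributable to training progress rather than to the input text.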
The LJSpeech dataset should be organized as:
LJSpeech-1.1/
├── metadata.csv # Transcriptions
├── wavs/ # Audio files (13,100 samples)
│ ├── LJ001-0001.wav
│ └── ...
└── TextGrid/ # MFA alignments (if using Zenodo)
├── LJ001-0001.TextGrid
└── ...
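Before starting a long training run, it can help to verify that the directory matches the layout above. The following is a small standalone check, not part of the repository: `check_ljspeech_layout` is a hypothetical helper that assumes each WAV file should have a TextGrid with the same base name.

```python
from pathlib import Path
from typing import Dict


def check_ljspeech_layout(root: str) -> Dict[str, int]:
    """Count WAV files, TextGrid files, and WAVs that lack an alignment."""
    base = Path(root)
    if not (base / "metadata.csv").is_file():
        raise FileNotFoundError(f"metadata.csv not found in {base}")
    wav_ids = {p.stem for p in (base / "wavs").glob("*.wav")}
    grid_ids = {p.stem for p in (base / "TextGrid").glob("*.TextGrid")}
    return {
        "wavs": len(wav_ids),
        "textgrids": len(grid_ids),
        "missing_alignments": len(wav_ids - grid_ids),
    }


# Example call (requires the dataset to be present):
# print(check_ljspeech_layout("LJSpeech-1.1"))
```

A non-zero `missing_alignments` count suggests the MFA step did not complete for every utterance.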
Basic usage:
from inference_english import EnglishTTSInference
tts = EnglishTTSInference(
model_path="kokoro_english_model/kokoro_english_final.pth",
device="cuda"
)
tts.synthesize_to_file(
text="Hello, how are you today?",
output_path="output.wav"
)
Advanced options:
python inference_english.py \
--model kokoro_english_model/checkpoint_epoch_50.pth \
--text "Your text here" \
--output output.wav \
--device cuda \
--vocoder hifigan
ImportError: cannot import name 'TypeIs' from 'typing_extensions'
Run pip install --upgrade typing-extensions
Mixed precision errors on CUDA
Add --no-mixed-precision flag
Out of memory
Reduce --batch-size (try 8, 4, or 2)
W&B not showing loss charts
Fixed in latest version (losses log every 10 batches)
Performance tips: Use batch size 16-32 for CUDA GPUs. CPU training not recommended, but if needed use batch size 2-4. Gradient checkpointing is enabled by default. Pre-aligned Zenodo dataset saves 1-3 hours of setup.
kokoro-english-tts/
├── README.md
├── requirements.txt
├── setup_ljspeech.py # Dataset setup
├── training_english.py # Main training script
├── inference_english.py # Main inference script
├── test_english_implementation.py # Test suite
│
├── kokoro/ # Core model architecture
│ ├── __init__.py
│ ├── model.py # Kokoro TTS model
│ ├── model_transformers.py # Transformer encoder/decoder
│ └── positional_encoding.py # Sinusoidal encoding
│
├── data/ # Dataset and preprocessing
│ ├── __init__.py
│ ├── ljspeech_dataset.py # LJSpeech data loader
│ └── english_phoneme_processor.py # English G2P (g2p_en - ARPA)
│
├── audio/ # Audio processing and vocoder
│ ├── __init__.py
│ ├── audio_utils.py # Audio utilities
│ ├── vocoder_manager.py # Vocoder interface
│ └── hifigan_vocoder.py # HiFi-GAN implementation
│
└── training/ # Training infrastructure
├── __init__.py
├── config_english.py # Training configuration
├── trainer.py # Base trainer
├── english_trainer.py # English trainer with W&B
├── checkpoint_manager.py # Checkpoint utilities
├── adaptive_memory_manager.py # Memory optimization
├── interbatch_profiler.py # Performance profiling
├── mps_grad_scaler.py # MPS mixed precision
└── device_type.py # Device enumeration
See requirements.txt for full list.
Run implementation tests:
python test_english_implementation.py
Quick training test (100 samples):
python training_english.py --test-mode
Training generates checkpoints every 5 epochs, a phoneme processor file, and a final model. Each checkpoint contains model state dict, optimizer state, learning rate scheduler state, training configuration, current epoch and loss, and mixed precision scaler state.
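As an illustration of what resuming from the latest checkpoint involves, the sketch below picks the checkpoint with the highest epoch number in an output directory. It is an assumption about how such a lookup could work, not a description of what `--resume auto` does internally. The `checkpoint_epoch_N.pth` pattern follows the file names used elsewhere in this README, and `find_latest_checkpoint` is a hypothetical name.

```python
import re
from pathlib import Path
from typing import Optional

_CKPT_RE = re.compile(r"checkpoint_epoch_(\d+)\.pth$")


def find_latest_checkpoint(output_dir: str) -> Optional[Path]:
    """Return the checkpoint with the highest epoch number, or None if none exist."""
    best_epoch, best_path = -1, None
    for path in Path(output_dir).glob("checkpoint_epoch_*.pth"):
        match = _CKPT_RE.search(path.name)
        # Compare epochs numerically so that epoch 100 ranks above epoch 20.
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path


# Example call (returns None if the directory has no checkpoints yet):
# print(find_latest_checkpoint("kokoro_english_model"))
```

Sorting by the parsed epoch rather than by file name avoids the lexicographic trap where `checkpoint_epoch_5.pth` sorts after `checkpoint_epoch_100.pth`.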
This implementation is for educational and research purposes.
Based on the original Kokoro TTS model. LJSpeech dataset by Keith Ito. Montreal Forced Aligner for phoneme-level alignments. g2p_en for English grapheme-to-phoneme conversion (ARPA phonemes). Original implementation based on kokoro-ruslan.