Joint Speech Text Training for the MuST-C English to German Speech Translation task

Joint Training Baseline: based on the paper "A general multi-task learning framework to leverage text data for speech to text tasks"

Enhanced Joint Training: joint training enhanced with pre-trained models, cross attentive regularization (CAR) and online knowledge distillation (KD), based on the paper "Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task"

Prepare Data

Download files

Prepare MuST-C data set

  • Please follow the data preparation in the S2T example
  • Convert source text under the "src_text" column in the tsv file into phoneme representation.
bash
    python examples/speech_text_joint_to_text/scripts/g2p_encode.py \
        --lower-case --do-filter --use-word-start --no-punc \
        --reserve-word examples/speech_text_joint_to_text/configs/mustc_noise.list \
        --data-path ${must_c_en_de_src_text} \
        --out-path ${must_c_en_de_src_text_pho}
  • Replace the source text under the "src_text" column in the tsv file with the corresponding phoneme representation generated in the step above.
  • Prepare a phoneme dictionary and save it to $MANIFEST_ROOT as src_dict.txt
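
The last two steps can be sketched in Python as follows. This is a minimal illustration, not the authors' script: it assumes the phoneme file has one line per tsv row in the same order, that phoneme tokens are separated by spaces, and that the dictionary uses the fairseq "<token> <count>" line format. The function names and the sample phoneme string are hypothetical.

```python
from collections import Counter


def replace_src_text(tsv_lines, phoneme_lines):
    """Return tsv lines whose src_text column holds the phoneme strings."""
    header = tsv_lines[0].rstrip("\n").split("\t")
    col = header.index("src_text")  # raises ValueError if the column is missing
    rows = [line.rstrip("\n").split("\t") for line in tsv_lines[1:]]
    if len(rows) != len(phoneme_lines):
        raise ValueError("tsv rows and phoneme lines must align one-to-one")
    for row, pho in zip(rows, phoneme_lines):
        row[col] = pho.strip()
    return ["\t".join(header)] + ["\t".join(row) for row in rows]


def build_dict_lines(phoneme_lines):
    """Count space-separated phoneme tokens and emit '<token> <count>' lines."""
    counts = Counter(tok for line in phoneme_lines for tok in line.split())
    # Sort by descending count, then alphabetically for a stable order.
    ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    return [f"{tok} {n}" for tok, n in ordered]


if __name__ == "__main__":
    # Toy data standing in for a real manifest and g2p_encode.py output.
    tsv = ["id\tsrc_text\ttgt_text", "utt1\thello world\thallo welt"]
    pho = ["▁HH AH0 L OW1 ▁W ER1 L D"]
    print("\n".join(replace_src_text(tsv, pho)))
    print("\n".join(build_dict_lines(pho)))
```

In practice the returned lines would be written back to the manifest tsv and to src_dict.txt under $MANIFEST_ROOT.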

Prepare WMT text data

  • Download the WMT data
  • Convert the source text (English) into phoneme representation as above
  • Generate the binary parallel files for training (as in the translation example) and save them in $parallel_text_data

Training

The model is trained with 8 V100 GPUs.

Download pretrained models

Training scripts

  • Jointly trained model from scratch
bash
python train.py ${MANIFEST_ROOT} \
    --save-dir ${save_dir} \
    --num-workers 8 \
    --task speech_text_joint_to_text \
    --arch dualinputs2ttransformer_s \
    --user-dir examples/speech_text_joint_to_text \
    --max-epoch 100 --update-mix-data \
    --optimizer adam --lr-scheduler inverse_sqrt \
    --lr 0.001 --update-freq 4 --clip-norm 10.0 \
    --criterion guided_label_smoothed_cross_entropy_with_accuracy \
    --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
    --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
    --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
    --dropout 0.1 --warmup-updates 20000  \
    --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
    --text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
    --log-format json --langpairs en-de --noise-token '▁NOISE' \
    --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
    --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
    --keep-last-epochs 10
  • Jointly trained model with good initialization, cross attentive loss and online knowledge distillation
bash
python train.py ${MANIFEST_ROOT} \
    --save-dir ${save_dir} \
    --num-workers 8 \
    --task speech_text_joint_to_text \
    --arch dualinputs2ttransformer_m \
    --user-dir examples/speech_text_joint_to_text \
    --max-epoch 100 --update-mix-data \
    --optimizer adam --lr-scheduler inverse_sqrt \
    --lr 0.002 --update-freq 4 --clip-norm 10.0 \
    --criterion guided_label_smoothed_cross_entropy_with_accuracy \
    --guide-alpha 0.8 --disable-text-guide-update-num 5000 \
    --label-smoothing 0.1 --max-tokens 10000 --max-tokens-text 10000 \
    --max-positions-text 400 --seed 2 --speech-encoder-layers 12 \
    --text-encoder-layers 6 --encoder-shared-layers 6 --decoder-layers 6 \
    --dropout 0.1 --warmup-updates 20000 --attentive-cost-regularization 0.02 \
    --text-sample-ratio 0.25 --parallel-text-data ${parallel_text_data} \
    --text-input-cost-ratio 0.5 --enc-grad-mult 2.0 --add-speech-eos \
    --log-format json --langpairs en-de --noise-token '▁NOISE' \
    --mask-text-ratio 0.0 --max-tokens-valid 20000 --ddp-backend no_c10d \
    --log-interval 100 --data-buffer-size 50 --config-yaml config.yaml \
    --load-pretrain-speech-encoder ${pretrain_encoder} \
    --load-pretrain-decoder ${pretrain_nmt} \
    --load-pretrain-text-encoder-last ${pretrain_nmt} \
    --keep-last-epochs 10
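
As a rough illustration of what --lr, --warmup-updates, --update-freq and --max-tokens imply in the commands above, the following Python sketch reproduces an inverse square root schedule with linear warmup and computes the upper bound on speech tokens per optimizer step. It assumes the warmup starts from a learning rate of 0, that warmup_updates is positive, and that --max-tokens is counted per GPU. The function names are illustrative and are not part of fairseq.

```python
import math


def inverse_sqrt_lr(step, peak_lr=0.001, warmup_updates=20000, warmup_init_lr=0.0):
    """Linear warmup to peak_lr, then decay proportional to 1/sqrt(step)."""
    if step < warmup_updates:
        # Warmup phase: interpolate linearly from warmup_init_lr to peak_lr.
        return warmup_init_lr + step * (peak_lr - warmup_init_lr) / warmup_updates
    # Decay phase: the rate equals peak_lr at step == warmup_updates.
    return peak_lr * math.sqrt(warmup_updates / step)


def max_tokens_per_update(max_tokens=10000, num_gpus=8, update_freq=4):
    """Upper bound on tokens consumed per optimizer step across all GPUs."""
    return max_tokens * num_gpus * update_freq


if __name__ == "__main__":
    for step in (5000, 20000, 80000):
        print(step, inverse_sqrt_lr(step))
    print(max_tokens_per_update())
```

With fewer GPUs, raising --update-freq so that the product of GPU count and update frequency stays at 32 keeps this upper bound unchanged.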

Evaluation

bash
python ./fairseq_cli/generate.py \
        ${MANIFEST_ROOT} \
        --task speech_text_joint_to_text \
        --max-tokens 25000 \
        --nbest 1 \
        --results-path ${infer_results} \
        --batch-size 512 \
        --path ${model} \
        --gen-subset tst-COMMON \
        --config-yaml config_spm.yaml \
        --scoring sacrebleu \
        --beam 5 --lenpen 1.0 \
        --user-dir examples/speech_text_joint_to_text \
        --load-speech-only
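
To inspect the translations afterwards, the output written under ${infer_results} can be parsed with a short Python helper such as the one below. This is a sketch under assumptions: it expects tab-separated lines of the form "T-<id>" followed by the reference, and "D-<id>" followed by a score and the detokenized hypothesis, and it ignores all other lines. The function name is hypothetical.

```python
def read_generate_output(lines):
    """Collect detokenized hypotheses (D-) and references (T-) by sample id."""
    refs, hyps = {}, {}
    for line in lines:
        line = line.rstrip("\n")
        if line.startswith("T-"):
            key, text = line.split("\t", 1)
            refs[int(key[2:])] = text
        elif line.startswith("D-"):
            key, _score, text = line.split("\t", 2)
            hyps[int(key[2:])] = text
    # Keep only ids that have both a hypothesis and a reference, in id order.
    ids = sorted(refs.keys() & hyps.keys())
    return [hyps[i] for i in ids], [refs[i] for i in ids]


if __name__ == "__main__":
    # Toy lines standing in for a real output file.
    sample = [
        "T-1\tGuten Tag",
        "D-1\t-0.3\tGuten Tag",
        "T-0\tHallo",
        "D-0\t-0.1\tHallo Welt",
    ]
    hyps, refs = read_generate_output(sample)
    for h, r in zip(hyps, refs):
        print(f"{h}\t{r}")
```

The two aligned lists can then be passed to any scoring tool to cross-check the BLEU value reported by the command above.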

Results (Joint training with initialization + CAR + online KD)

| Direction | En-De | En-Es | En-Fr |
| --- | --- | --- | --- |
| BLEU | 27.4 | 31.2 | 37.6 |
| checkpoint | link | link | link |