
Audio classification examples

<!--- Copyright 2021 The HuggingFace Team. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -->

The following examples showcase how to fine-tune Wav2Vec2 for audio classification using PyTorch.

Speech recognition models that have been pretrained in an unsupervised fashion on audio data alone (e.g. Wav2Vec2, HuBERT, XLSR-Wav2Vec2) have been shown to require only a small amount of annotated data to perform well on speech classification datasets.

Single-GPU

The following command shows how to fine-tune wav2vec2-base on the 🗣️ Keyword Spotting subset of the SUPERB dataset.

```bash
python run_audio_classification.py \
    --model_name_or_path facebook/wav2vec2-base \
    --dataset_name superb \
    --dataset_config_name ks \
    --output_dir wav2vec2-base-ft-keyword-spotting \
    --remove_unused_columns False \
    --do_train \
    --do_eval \
    --fp16 \
    --learning_rate 3e-5 \
    --max_length_seconds 1 \
    --attention_mask False \
    --warmup_steps 0.1 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 32 \
    --gradient_accumulation_steps 4 \
    --per_device_eval_batch_size 32 \
    --dataloader_num_workers 4 \
    --logging_strategy steps \
    --logging_steps 10 \
    --eval_strategy epoch \
    --save_strategy epoch \
    --load_best_model_at_end True \
    --metric_for_best_model accuracy \
    --save_total_limit 3 \
    --seed 0 \
    --push_to_hub
```

On a single V100 GPU (16GB), this script should run in ~14 minutes and yield an accuracy of 98.26%.
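The `--max_length_seconds` flag bounds how much audio each training example contributes. Below is a minimal sketch of that arithmetic. It assumes the model consumes 16 kHz audio (the rate wav2vec2-base was pretrained on; check your feature extractor's `sampling_rate`) and that the script crops longer training clips to this length. Under those assumptions, one second becomes 16,000 raw samples, so a per-device batch of 32 clips holds at most 32 x 16,000 samples per forward pass. The `max_samples` helper is hypothetical.

```bash
#!/usr/bin/env bash
# Assumption: the model consumes 16 kHz audio (wav2vec2-base's pretraining rate).
SAMPLING_RATE=16000

# Hypothetical helper: convert a whole-second --max_length_seconds value
# into the maximum number of raw samples per clip (bash integer arithmetic).
max_samples() {
  echo $(( $1 * SAMPLING_RATE ))
}

clip=$(max_samples 1)      # --max_length_seconds 1
per_device_batch=32        # --per_device_train_batch_size 32
echo "max samples per clip:        $clip"
echo "max samples per micro-batch: $(( per_device_batch * clip ))"
```

Bash arithmetic is integer-only, so this sketch handles whole seconds only. For the 16-second limit used in the language-identification command, `max_samples 16` gives 256000 samples per clip.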

👀 See the results here: anton-l/wav2vec2-base-ft-keyword-spotting

If your model's classification head dimensions do not match the number of labels in the dataset, you can pass --ignore_mismatched_sizes to adapt it.
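Consider starting language identification from the keyword-spotting checkpoint linked above. That checkpoint's classification head was sized for the Keyword Spotting labels, not for the CommonLanguage languages. The sketch below is hedged and illustrative. It assembles such a run from flags used elsewhere in this README plus `--ignore_mismatched_sizes`, and the output directory name is hypothetical. The command is only printed (a dry run) so it can be reviewed before launching. With this flag, checkpoint weights whose shapes do not match the new model are skipped and freshly initialized instead of raising an error.

```bash
#!/usr/bin/env bash
# Illustrative: reuse the keyword-spotting checkpoint (head sized for KS labels)
# as the starting point for language identification on CommonLanguage.
# The output directory name below is hypothetical.
args=(
  --model_name_or_path anton-l/wav2vec2-base-ft-keyword-spotting
  --dataset_name common_language
  --audio_column_name audio
  --label_column_name language
  --output_dir wav2vec2-ks-to-lang-id
  --remove_unused_columns False
  --do_train
  --do_eval
  --ignore_mismatched_sizes
)

# Dry run: print the command for review; drop "echo" to launch training.
echo python run_audio_classification.py "${args[@]}"
```

Keeping the arguments in a bash array makes it easy to add the remaining training flags from the commands in this README without juggling line continuations.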

Multi-GPU

The following command shows how to fine-tune wav2vec2-base for 🌎 Language Identification on the CommonLanguage dataset.

```bash
python run_audio_classification.py \
    --model_name_or_path facebook/wav2vec2-base \
    --dataset_name common_language \
    --audio_column_name audio \
    --label_column_name language \
    --output_dir wav2vec2-base-lang-id \
    --remove_unused_columns False \
    --do_train \
    --do_eval \
    --fp16 \
    --learning_rate 3e-4 \
    --max_length_seconds 16 \
    --attention_mask False \
    --warmup_steps 0.1 \
    --num_train_epochs 10 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 4 \
    --per_device_eval_batch_size 1 \
    --dataloader_num_workers 8 \
    --logging_strategy steps \
    --logging_steps 10 \
    --eval_strategy epoch \
    --save_strategy epoch \
    --load_best_model_at_end True \
    --metric_for_best_model accuracy \
    --save_total_limit 3 \
    --seed 0 \
    --push_to_hub
```

On 4 V100 GPUs (16GB), this script should run in ~1 hour and yield an accuracy of 79.45%.
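The per-device batch in this command is only 8, but gradient accumulation and the four GPUs scale it back up. Below is a minimal sketch of the usual formula (per-device batch x accumulation steps x number of GPUs). It assumes all four GPUs take part in data-parallel training, and the `effective_batch` helper is hypothetical.

```bash
#!/usr/bin/env bash
# Effective (global) batch size per optimizer step:
#   per-device batch x gradient accumulation steps x number of GPUs
effective_batch() {
  local per_device=$1 accum=$2 n_gpus=$3
  echo $(( per_device * accum * n_gpus ))
}

echo "keyword spotting, 1 GPU:  $(effective_batch 32 4 1)"
echo "language ID,      4 GPUs: $(effective_batch 8 4 4)"
```

Under this assumption, both example runs update the model on 128 clips per optimizer step. The language-identification run trades per-device batch size for longer clips.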

👀 See the results here: anton-l/wav2vec2-base-lang-id

Sharing your model on 🤗 Hub

  1. If you haven't already, sign up for a 🤗 account.

  2. Make sure you have git-lfs installed and git set up.

```bash
$ apt install git-lfs
```

  3. Log in with your Hugging Face account credentials using the hf CLI:

```bash
$ hf auth login
# ...follow the prompts
```

  4. When running the script, pass the following arguments:

```bash
python run_audio_classification.py \
    --push_to_hub \
    --hub_model_id <username/model_id> \
    ...
```
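Before launching a run with --push_to_hub, it can help to confirm that the prerequisites from the steps above are in place. Below is a minimal sketch, assuming bash is available. The `check_tool` helper is hypothetical. `hf auth whoami` is part of the same hf CLI used for login, and the script calls it only when hf is on the PATH.

```bash
#!/usr/bin/env bash
# Hypothetical helper: report whether a command-line tool is on PATH.
check_tool() {
  if command -v "$1" >/dev/null 2>&1; then
    echo "found: $1"
  else
    echo "missing: $1"
  fi
}

# Tools used in the Hub-sharing steps.
for tool in git git-lfs hf; do
  check_tool "$tool"
done

# Show the logged-in account; print a hint if the command fails.
if command -v hf >/dev/null 2>&1; then
  hf auth whoami || echo "not logged in: run 'hf auth login'"
fi
```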

Examples

The following table shows a few demonstration fine-tuning runs. The script has been verified to work on the following datasets:

| Dataset | Pretrained Model | # transformer layers | Accuracy on eval | GPU setup | Training time | Fine-tuned Model & Logs |
|---|---|---|---|---|---|---|
| Keyword Spotting | ntu-spml/distilhubert | 2 | 0.9706 | 1 V100 GPU | 11min | here |
| Keyword Spotting | facebook/wav2vec2-base | 12 | 0.9826 | 1 V100 GPU | 14min | here |
| Keyword Spotting | facebook/hubert-base-ls960 | 12 | 0.9819 | 1 V100 GPU | 14min | here |
| Keyword Spotting | asapp/sew-mid-100k | 24 | 0.9757 | 1 V100 GPU | 15min | here |
| Common Language | facebook/wav2vec2-base | 12 | 0.7945 | 4 V100 GPUs | 1h10m | here |