docs/source/en/model_doc/vibevoice_acoustic_tokenizer.md
This model was released on 2025-08-26 and added to Hugging Face Transformers on 2026-02-06.
VibeVoice is a novel framework for synthesizing high-fidelity, long-form speech with multiple speakers by employing a next-token diffusion approach within a Large Language Model (LLM) structure. It's designed to capture the authentic conversational "vibe" and is particularly suited for generating audio content like podcasts and multi-participant audiobooks.
One key feature of VibeVoice is the use of two continuous audio tokenizers, one for extracting acoustic features and another for semantic features.
A model checkpoint is available at [microsoft/VibeVoice-AcousticTokenizer](https://huggingface.co/microsoft/VibeVoice-AcousticTokenizer).
This model was contributed by Eric Bezzam.
The architecture is a mirror-symmetric encoder-decoder structure. The encoder employs a hierarchical design with 7 stages of ConvNeXt-like blocks, which use 1D depth-wise causal convolutions for efficient streaming processing. Six downsampling layers achieve a cumulative 3200x downsampling rate from a 24 kHz input, yielding 7.5 tokens (frames) per second. Each encoder/decoder component has approximately 340M parameters, for a total of around 680M parameters. The training objective follows that of DAC, including its discriminator and loss designs.
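As a quick sanity check on these rates (the 224,000-sample clip and the 70 latent frames come from the example below):

```python
sampling_rate = 24_000       # Hz, model input rate
downsampling_rate = 3_200    # cumulative downsampling across the six layers

print(sampling_rate / downsampling_rate)   # 7.5 latent frames (tokens) per second
print(224_000 / downsampling_rate)         # 70.0 frames for the ~9.3 s clip in the example below
```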
The Acoustic Tokenizer adopts the principles of a Variational Autoencoder (VAE). The encoder maps the input audio to the mean of a latent distribution; together with a fixed standard deviation, a latent vector is then sampled using the reparameterization trick. Please refer to the technical report for further details.
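As a rough illustration (not the model's actual implementation), the reparameterization trick amounts to adding scaled Gaussian noise to the predicted mean; passing `sample=False` to `encode`, as in the examples below, skips this sampling for deterministic output. The standard deviation value here is a placeholder, not the one used by the checkpoint.

```python
import torch

mean = torch.randn(1, 70, 64)   # stand-in for the encoder's predicted means (shape from the example below)
fixed_std = 0.1                 # hypothetical value, not the checkpoint's actual standard deviation
latents = mean + fixed_std * torch.randn_like(mean)  # reparameterization: mean + std * epsilon
```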
Below is an example of how to encode and decode audio:
```python
import torch
from scipy.io import wavfile
from transformers import AutoFeatureExtractor, VibeVoiceAcousticTokenizerModel
from transformers.audio_utils import load_audio_librosa

model_id = "microsoft/VibeVoice-AcousticTokenizer"

# load model
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = VibeVoiceAcousticTokenizerModel.from_pretrained(model_id, device_map="auto")
print("Model loaded on device:", model.device)
print("Model dtype:", model.dtype)

# load audio
audio = load_audio_librosa(
    "https://huggingface.co/datasets/bezzam/vibevoice_samples/resolve/main/voices/en-Alice_woman.wav",
    sampling_rate=feature_extractor.sampling_rate,
)

# preprocess audio
inputs = feature_extractor(
    audio,
    sampling_rate=feature_extractor.sampling_rate,
    pad_to_multiple_of=3200,
).to(model.device, model.dtype)
print("Input audio shape:", inputs.input_values.shape)
# Input audio shape: torch.Size([1, 1, 224000])

with torch.no_grad():
    # set VAE sampling to False for deterministic output
    encoded_outputs = model.encode(inputs.input_values, sample=False)
    print("Latent shape:", encoded_outputs.latents.shape)
    # Latent shape: torch.Size([1, 70, 64])

    decoded_outputs = model.decode(**encoded_outputs)
    print("Reconstructed audio shape:", decoded_outputs.audio.shape)
    # Reconstructed audio shape: torch.Size([1, 1, 224000])

# Save audio
output_fp = "vibevoice_acoustic_tokenizer_reconstructed.wav"
wavfile.write(output_fp, feature_extractor.sampling_rate, decoded_outputs.audio.squeeze().float().cpu().numpy())
print(f"Reconstructed audio saved to: {output_fp}")
```
For streaming ASR or TTS, where cached states need to be tracked, the `use_cache` parameter can be used when encoding or decoding audio:
```python
import torch
from scipy.io import wavfile
from transformers import AutoFeatureExtractor, VibeVoiceAcousticTokenizerModel
from transformers.audio_utils import load_audio_librosa

model_id = "microsoft/VibeVoice-AcousticTokenizer"

# load model
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = VibeVoiceAcousticTokenizerModel.from_pretrained(model_id, device_map="auto")
print("Model loaded on device:", model.device)
print("Model dtype:", model.dtype)

# load audio
audio = load_audio_librosa(
    "https://huggingface.co/datasets/bezzam/vibevoice_samples/resolve/main/voices/en-Alice_woman.wav",
    sampling_rate=feature_extractor.sampling_rate,
)

# preprocess audio
inputs = feature_extractor(
    audio,
    sampling_rate=feature_extractor.sampling_rate,
    pad_to_multiple_of=3200,
).to(model.device, model.dtype)
print("Input audio shape:", inputs.input_values.shape)
# Input audio shape: torch.Size([1, 1, 224000])

# cache will be initialized after a first pass
encoder_cache = None
decoder_cache = None

with torch.no_grad():
    # set VAE sampling to False for deterministic output
    encoded_outputs = model.encode(inputs.input_values, sample=False, padding_cache=encoder_cache, use_cache=True)
    print("Latent shape:", encoded_outputs.latents.shape)
    # Latent shape: torch.Size([1, 70, 64])

    decoded_outputs = model.decode(encoded_outputs.latents, padding_cache=decoder_cache, use_cache=True)
    print("Reconstructed audio shape:", decoded_outputs.audio.shape)
    # Reconstructed audio shape: torch.Size([1, 1, 224000])

# `padding_cache` can be extracted from the outputs for subsequent passes
encoder_cache = encoded_outputs.padding_cache
print("Number of cached encoder layers:", len(encoder_cache.per_layer_in_channels))
# Number of cached encoder layers: 34

decoder_cache = decoded_outputs.padding_cache
print("Number of cached decoder layers:", len(decoder_cache.per_layer_in_channels))
# Number of cached decoder layers: 34

# Save audio
output_fp = "vibevoice_acoustic_tokenizer_reconstructed.wav"
wavfile.write(output_fp, feature_extractor.sampling_rate, decoded_outputs.audio.squeeze().float().cpu().numpy())
print(f"Reconstructed audio saved to: {output_fp}")
```
[[autodoc]] VibeVoiceAcousticTokenizerConfig

[[autodoc]] VibeVoiceAcousticTokenizerEncoderConfig

[[autodoc]] VibeVoiceAcousticTokenizerDecoderConfig

[[autodoc]] VibeVoiceAcousticTokenizerFeatureExtractor
    - __call__

[[autodoc]] VibeVoiceAcousticTokenizerModel
    - encode
    - decode
    - forward

[[autodoc]] VibeVoiceAcousticTokenizerEncoderModel
    - forward

[[autodoc]] VibeVoiceAcousticTokenizerDecoderModel
    - forward