docs/source/en/model_doc/mbart.md
This model was released on 2020-01-22 and added to Hugging Face Transformers on 2020-11-16.
mBART is a multilingual machine translation model that pretrains the entire encoder-decoder model, unlike previous methods that pretrained only parts of it. The model is trained with a denoising objective: the input text is corrupted and the model learns to reconstruct the original. This prepares both the encoder, which reads the source language, and the decoder, which generates the target text.
mBART-50 is pretrained on an additional 25 languages.
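To make the denoising objective more concrete, here is a toy sketch of span masking on a list of words. It is an illustration under assumptions, not the original training code: the function name `mask_spans`, the `<mask>` string, the word-level tokens, and the default ratio and span length are all hypothetical choices made for this example.

```python
import random

MASK = "<mask>"  # placeholder mask string, used only in this sketch


def mask_spans(tokens, mask_ratio=0.35, max_span=3, seed=0):
    """Replace random spans of tokens with a single mask token each."""
    rng = random.Random(seed)
    budget = int(len(tokens) * mask_ratio)  # how many tokens may be hidden
    corrupted = []
    i = 0
    while i < len(tokens):
        if budget > 0 and rng.random() < mask_ratio:
            span = min(rng.randint(1, max_span), budget, len(tokens) - i)
            corrupted.append(MASK)  # the whole span collapses into one mask
            budget -= span
            i += span
        else:
            corrupted.append(tokens[i])
            i += 1
    return corrupted


original = "UN Chief Says There Is No Military Solution in Syria".split()
corrupted = mask_spans(original)

# During pretraining, the model would read `corrupted` and be trained
# to generate `original`.
print(corrupted)
print(original)
```

Because each span is replaced by a single mask, the model also has to work out how many tokens are missing, not only which ones.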
You can find all the original mBART checkpoints under the AI at Meta organization.
> [!TIP]
> Click on the mBART models in the right sidebar for more examples of applying mBART to different language tasks.
The example below demonstrates how to translate text with [Pipeline] or the [AutoModel] class.
from transformers import pipeline
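# Use a distinct variable name so the imported `pipeline` function is not shadowed.
translator = pipeline(
    task="translation",
    model="facebook/mbart-large-50-many-to-many-mmt",
    src_lang="en_XX",
    tgt_lang="fr_XX",
    device=0,  # first GPU; use device=-1 to run on CPU
)
print(translator("UN Chief Says There Is No Military Solution in Syria"))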
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
article_en = "UN Chief Says There Is No Military Solution in Syria"
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt", attn_implementation="sdpa", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer.src_lang = "en_XX"
encoded_en = tokenizer(article_en, return_tensors="pt").to(model.device)
# forced_bos_token_id makes the French language id the first generated token.
generated_tokens = model.generate(**encoded_en, forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"], cache_implementation="static")
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
You can check the full list of language codes with `tokenizer.lang_code_to_id.keys()`.
mBART requires a special language id token in the source and target text during training. The source text format is `X [eos, src_lang_code]`, where `X` is the source text. The target text format is `[tgt_lang_code] X [eos]`. The `bos` token is never used. [`~PreTrainedTokenizerBase.__call__`] encodes the source text when it is passed as the first argument or with the `text` keyword, and encodes the target text when it is passed with the `text_target` keyword.
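The two layouts described above can be sketched with plain Python lists. This is only an illustration of token order: the helper names `mbart_source` and `mbart_target` are hypothetical, and the word-level tokens stand in for the subword pieces a real tokenizer would produce.

```python
EOS = "</s>"  # end-of-sequence token string


def mbart_source(tokens, src_lang):
    """Source layout: X [eos, src_lang_code]."""
    return tokens + [EOS, src_lang]


def mbart_target(tokens, tgt_lang):
    """Target layout: [tgt_lang_code] X [eos]."""
    return [tgt_lang] + tokens + [EOS]


src = mbart_source(["UN", "Chief"], "en_XX")
tgt = mbart_target(["Şeful", "ONU"], "ro_RO")

print(src)  # ['UN', 'Chief', '</s>', 'en_XX']
print(tgt)  # ['ro_RO', 'Şeful', 'ONU', '</s>']
```

Note that the language id comes last in the source but first in the target, and that no `bos` token appears in either sequence.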
Set the `decoder_start_token_id` to the target language id for mBART.
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro", attn_implementation="sdpa", device_map="auto")
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")
article = "UN Chief Says There Is No Military Solution in Syria"
inputs = tokenizer(article, return_tensors="pt").to(model.device)
translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
print(tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0])
mBART-50 has a different text format. The language id token is used as the prefix for the source and target text. The text format is [lang_code] X [eos] where lang_code is the source language id for the source text and target language id for the target text. X is the source or target text respectively.
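As a quick illustration of the mBART-50 layout, the sketch below builds the same `[lang_code] X [eos]` order for both sides of a translation pair. The helper name `mbart50_format` is hypothetical, and the word-level tokens are placeholders for real subword pieces.

```python
EOS = "</s>"  # end-of-sequence token string


def mbart50_format(tokens, lang_code):
    """mBART-50 layout for source and target alike: [lang_code] X [eos]."""
    return [lang_code] + tokens + [EOS]


source = mbart50_format(["UN", "Chief"], "en_XX")
target = mbart50_format(["Chef", "de", "l'ONU"], "fr_XX")

print(source)  # ['en_XX', 'UN', 'Chief', '</s>']
print(target)  # ['fr_XX', 'Chef', 'de', "l'ONU", '</s>']
```

Unlike the original mBART layout, the language id is a prefix on both sides, so one function covers source and target.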
Set the `eos_token_id` as the `decoder_start_token_id` for mBART-50. To make the target language id the first generated token, pass `forced_bos_token_id` to [`~GenerationMixin.generate`].
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt", attn_implementation="sdpa", device_map="auto")
# AutoTokenizer loads the mBART-50 tokenizer that matches this checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."
tokenizer.src_lang = "ar_AR"
encoded_ar = tokenizer(article_ar, return_tensors="pt").to(model.device)
generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
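The role of the decoder start token and the forced first token in mBART-50 generation can be pictured with a toy greedy loop. This is a simplified sketch, not how `generate` is implemented: the real method works on token ids and logits, and the names `greedy_decode` and `fake_model` are hypothetical stand-ins.

```python
def greedy_decode(next_token, decoder_start_token, forced_bos_token, eos_token, max_new_tokens=10):
    """Toy greedy loop: the first generated token is forced, the rest come from `next_token`."""
    output = [decoder_start_token]
    for step in range(max_new_tokens):
        token = forced_bos_token if step == 0 else next_token(output)
        output.append(token)
        if token == eos_token:
            break
    return output


# Hypothetical stand-in for a model: replays a fixed translation, then emits eos.
canned = ["UN", "chief", "</s>"]


def fake_model(prefix):
    # Skip the start token and the forced language token when indexing.
    return canned[len(prefix) - 2]


result = greedy_decode(fake_model, "</s>", "en_XX", "</s>")
print(result)  # ['</s>', 'en_XX', 'UN', 'chief', '</s>']
```

The output starts with the eos token (the decoder start token), followed by the forced target language id, which matches the `[lang_code] X [eos]` target layout once the start token is dropped.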
[[autodoc]] MBartConfig
[[autodoc]] MBartTokenizer
[[autodoc]] MBartTokenizerFast
[[autodoc]] MBart50Tokenizer
[[autodoc]] MBart50TokenizerFast
[[autodoc]] MBartModel
[[autodoc]] MBartForConditionalGeneration
[[autodoc]] MBartForQuestionAnswering
[[autodoc]] MBartForSequenceClassification
[[autodoc]] MBartForCausalLM
    - forward