docs/source/en/model_doc/dit.md
This model was released on 2022-03-04 and added to Hugging Face Transformers on 2022-03-10.
DiT is an image transformer pretrained on large-scale unlabeled document images. It learns to predict the missing visual tokens from a corrupted input image. The pretrained DiT model can be used as a backbone in other models for visual document tasks like document image classification and table detection.
You can find all the original DiT checkpoints under the Microsoft organization.
> [!TIP]
> Refer to the BEiT docs for more examples of how to apply DiT to different vision tasks.
The examples below demonstrate how to classify a document image with [Pipeline] or the [AutoModel] class.
```python
from transformers import pipeline

# device=0 selects the first GPU; use device=-1 to run on CPU.
classifier = pipeline(
    task="image-classification",
    model="microsoft/dit-base-finetuned-rvlcdip",
    device=0,
)
classifier("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dit-example.jpg")
```
```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load the image processor and the classifier fine-tuned on RVL-CDIP.
image_processor = AutoImageProcessor.from_pretrained(
    "microsoft/dit-base-finetuned-rvlcdip",
    use_fast=True,
)
model = AutoModelForImageClassification.from_pretrained(
    "microsoft/dit-base-finetuned-rvlcdip",
    device_map="auto",
)

# Download the example document image.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dit-example.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Preprocess the image and move the tensors to the model's device.
inputs = image_processor(image, return_tensors="pt").to(model.device)

# Run inference without tracking gradients.
with torch.no_grad():
    logits = model(**inputs).logits

# Map the highest-scoring class id to its label.
predicted_class_id = logits.argmax(dim=-1).item()
class_labels = model.config.id2label
predicted_class_label = class_labels[predicted_class_id]
print(f"The predicted class label is: {predicted_class_label}")
```
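The example above reports only the single best class. When several document categories are plausible, it can help to look at the top few candidates with their probabilities. The following is a minimal, dependency-free sketch of that step; `top_k_labels` is a hypothetical helper name, not part of Transformers, and it assumes the logits are passed in as a plain list of floats.

```python
import math


def top_k_labels(logits, id2label, k=3):
    """Return the k most likely (label, probability) pairs, highest first.

    `top_k_labels` is an illustrative helper, not a Transformers API.
    """
    # Numerically stable softmax: subtract the max logit before exponentiating.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank class ids by probability (descending) and keep the first k.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(id2label[i], probs[i]) for i in ranked]
```

With the model loaded as above, one way to call it would be `top_k_labels(logits[0].tolist(), model.config.id2label)`, which converts the first row of the logits tensor to a Python list before ranking.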
The pretrained DiT weights can be loaded in a [BEiT] model with a modeling head to predict visual tokens.
```python
from transformers import BeitForMaskedImageModeling

# Load the self-supervised DiT checkpoint into BEiT's masked image modeling head.
model = BeitForMaskedImageModeling.from_pretrained("microsoft/dit-base")
```
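To predict visual tokens, the masked image modeling head needs to know which patches of the input are hidden. The sketch below builds such a patch mask with the standard library only. It rests on several assumptions: `random_patch_mask` is a hypothetical helper, the defaults of a 224-pixel image with 16-pixel patches should be checked against the checkpoint's config, the 40% mask ratio is purely illustrative, and uniform random masking is a simplification rather than the masking strategy used during pretraining.

```python
import random


def random_patch_mask(image_size=224, patch_size=16, mask_ratio=0.4, seed=None):
    """Return a flat list of booleans, one per patch; True marks a masked patch.

    Illustrative helper only: the default sizes and ratio are assumptions.
    """
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    num_masked = int(mask_ratio * num_patches)

    # Sample the masked patch indices uniformly without replacement.
    rng = random.Random(seed)
    masked = set(rng.sample(range(num_patches), num_masked))
    return [i in masked for i in range(num_patches)]
```

Wrapping the result as `torch.tensor([mask])` gives a boolean tensor of shape `(1, num_patches)` that can be passed to the model as `bool_masked_pos` together with `pixel_values`.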