Retrieval-Augmented Image Captioning

[Open in Colab](https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/multi_modal/llava_multi_modal_tesla_10q.ipynb)

In this example, we show how to leverage LLaVA + Replicate for image understanding/captioning, and then retrieve relevant unstructured text and embedded tables from a Tesla 10-K filing based on that image understanding.

  1. LLaVA provides image understanding based on the user prompt.
  2. We use Unstructured to parse out the tables, and use LlamaIndex recursive retrieval to index/retrieve both tables and text.
  3. We leverage the image understanding from step 1 to retrieve relevant information from the knowledge base generated in step 2 (which is indexed by LlamaIndex); see the sketch after this list.
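At a high level, the flow looks like the following sketch. The two functions here are hypothetical placeholders, not LlamaIndex APIs; the concrete ReplicateMultiModal and query-engine calls appear later in the notebook:

```python
# Illustrative outline of the pipeline; `caption_image` and
# `retrieve_from_10k` are hypothetical placeholders for the real calls below.


def caption_image(image_path: str, prompt: str) -> str:
    # Step 1: in the notebook this is ReplicateMultiModal.complete(...)
    return "a Tesla Supercharger station"  # canned placeholder output


def retrieve_from_10k(query: str) -> str:
    # Steps 2-3: in the notebook this is query_engine.query(...)
    return f"(10-K passages relevant to: {query})"  # canned placeholder


caption = caption_image("tesla_supercharger.jpg", "what is in the image?")
print(retrieve_from_10k("please provide relevant information about: " + caption))
```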

Context for LLaVA: Large Language and Vision Assistant

For LlamaIndex: LLaVA + Replicate enables us to run image understanding and combine the resulting multi-modal knowledge with our RAG knowledge base system.

TODO: Waiting for llama-cpp-python to support LLaVA in its Python wrapper, so that LlamaIndex can leverage the LlamaCPP class to serve LLaVA directly/locally.

Using Replicate to serve the LLaVA model through LlamaIndex

Build and run LLaVA models locally through llama.cpp (Deprecated)

  1. git clone https://github.com/ggerganov/llama.cpp.git
  2. cd llama.cpp. Check out the llama.cpp repo for more details.
  3. make
  4. Download the LLaVA models, including ggml-model-* and mmproj-model-*, from this Hugging Face repo. Select one model based on your local configuration.
  5. Run ./llava to check whether LLaVA is running locally.
```python
%pip install llama-index-readers-file
%pip install llama-index-multi-modal-llms-replicate
%pip install unstructured
```

```python
%load_ext autoreload
%autoreload 2
```

```python
from unstructured.partition.html import partition_html
import pandas as pd

pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", None)
pd.set_option("display.max_colwidth", None)

Perform Data Extraction from the Tesla 10-K File

In the following sections, we use Unstructured to parse out the table and non-table elements.

Extract Elements

We use Unstructured to extract table and non-table elements from the 10-K filing.

```python
!wget "https://www.dropbox.com/scl/fi/mlaymdy1ni1ovyeykhhuk/tesla_2021_10k.htm?rlkey=qf9k4zn0ejrbm716j0gg7r802&dl=1" -O tesla_2021_10k.htm
!wget "https://docs.google.com/uc?export=download&id=1THe1qqM61lretr9N3BmINc_NWDvuthYf" -O shanghai.jpg
!wget "https://docs.google.com/uc?export=download&id=1PDVCf_CzLWXNnNoRV8CFgoJxv6U0sHAO" -O tesla_supercharger.jpg
```

```python
from llama_index.readers.file import FlatReader
from pathlib import Path

reader = FlatReader()
docs_2021 = reader.load_data(Path("tesla_2021_10k.htm"))
```

```python
from llama_index.core.node_parser import UnstructuredElementNodeParser

node_parser = UnstructuredElementNodeParser()
```

```python
import os

REPLICATE_API_TOKEN = "..."  # Your Replicate API token here
os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
```

```python
import openai

OPENAI_API_KEY = "sk-..."
openai.api_key = OPENAI_API_KEY  # add your openai api key here
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
```

```python
import os
import pickle

# Parse the filing once and cache the raw nodes to disk.
if not os.path.exists("2021_nodes.pkl"):
    raw_nodes_2021 = node_parser.get_nodes_from_documents(docs_2021)
    with open("2021_nodes.pkl", "wb") as f:
        pickle.dump(raw_nodes_2021, f)
else:
    with open("2021_nodes.pkl", "rb") as f:
        raw_nodes_2021 = pickle.load(f)
```

```python
nodes_2021, objects_2021 = node_parser.get_nodes_and_objects(raw_nodes_2021)
```
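To sanity-check the parse, you can inspect what get_nodes_and_objects returns; a minimal sketch, assuming the cells above ran successfully (nodes_2021 holds the text nodes, objects_2021 the table objects):

```python
# Optional sanity check on the parsed output.
print(f"text nodes: {len(nodes_2021)}, table objects: {len(objects_2021)}")

# Each object wraps a parsed table; its content includes the table summary
# that the retriever matches against.
if objects_2021:
    print(objects_2021[0].get_content()[:500])
```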

Setup Composable Retriever

Now that we've extracted the tables and their summaries, we can set up a composable retriever in LlamaIndex to query them.

Construct Retrievers

```python
from llama_index.core import VectorStoreIndex

# construct top-level vector index + query engine
vector_index = VectorStoreIndex(nodes=nodes_2021, objects=objects_2021)
query_engine = vector_index.as_query_engine(similarity_top_k=2, verbose=True)
```
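Before wiring in LLaVA, it can help to sanity-check the engine with a plain text query; a minimal sketch (the question here is only illustrative):

```python
# Optional: verify recursive retrieval over the 10-K works on its own.
test_response = query_engine.query(
    "What was Tesla's total revenue in fiscal year 2021?"
)
print(str(test_response))
```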
```python
from PIL import Image
import matplotlib.pyplot as plt

imageUrl = "./tesla_supercharger.jpg"
image = Image.open(imageUrl).convert("RGB")

plt.figure(figsize=(16, 5))
plt.imshow(image)
```

Running the LLaVA model using Replicate through LlamaIndex for image understanding

```python
from llama_index.multi_modal_llms.replicate import ReplicateMultiModal
from llama_index.core.schema import ImageDocument
from llama_index.multi_modal_llms.replicate.base import (
    REPLICATE_MULTI_MODAL_LLM_MODELS,
)

multi_modal_llm = ReplicateMultiModal(
    model=REPLICATE_MULTI_MODAL_LLM_MODELS["llava-13b"],
    max_new_tokens=200,
    temperature=0.1,
)

prompt = "what is the main object for tesla in the image?"

llava_response = multi_modal_llm.complete(
    prompt=prompt,
    image_documents=[ImageDocument(image_path=imageUrl)],
)
```
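It is worth printing the LLaVA output before passing it on, since this caption becomes the retrieval query in the next step:

```python
# Inspect the raw LLaVA caption that will drive the RAG query below.
print(llava_response.text)
```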

Retrieve relevant information from the LlamaIndex knowledge base according to the LLaVA image understanding

```python
prompt_template = "please provide relevant information about: "
rag_response = query_engine.query(prompt_template + llava_response.text)
```

Showing final RAG image caption results from LlamaIndex

```python
print(str(rag_response))
```
```python
from PIL import Image
import matplotlib.pyplot as plt

imageUrl = "./shanghai.jpg"
image = Image.open(imageUrl).convert("RGB")

plt.figure(figsize=(16, 5))
plt.imshow(image)
```

Retrieve relevant information from LlamaIndex for a new image

```python
prompt = "which Tesla factory is shown in the image?"

llava_response = multi_modal_llm.complete(
    prompt=prompt,
    image_documents=[ImageDocument(image_path=imageUrl)],
)
```
```python
prompt_template = "please provide relevant information about: "
rag_response = query_engine.query(prompt_template + llava_response.text)
```

Showing final RAG image caption results from LlamaIndex

```python
print(rag_response)
```