docs/examples/cookbooks/GraphRAG_v1.ipynb
GraphRAG (Graphs + Retrieval Augmented Generation) combines the strengths of Retrieval Augmented Generation (RAG) and Query-Focused Summarization (QFS) to handle complex queries over large text datasets. RAG excels at fetching precise information, but it struggles with broader queries that require thematic understanding. QFS addresses that challenge, but it does not scale well. GraphRAG integrates the two approaches to offer responsive and thorough querying across extensive, diverse text corpora.
This notebook provides guidance on constructing the GraphRAG pipeline using the LlamaIndex PropertyGraph abstractions.
NOTE: This is an approximate implementation of GraphRAG. We are currently developing a series of cookbooks that will detail the exact implementation of GraphRAG.
The GraphRAG pipeline involves two steps:
Graph Generation:
Source Documents to Text Chunks: Source documents are divided into smaller text chunks for easier processing.
Text Chunks to Element Instances: Each text chunk is analyzed to identify and extract entities and relationships, resulting in a list of tuples that represent these elements.
Element Instances to Element Summaries: The extracted entities and relationships are summarized into descriptive text blocks for each element using the LLM.
Element Summaries to Graph Communities: These entities, relationships and summaries form a graph, which is then partitioned into communities with the hierarchical Leiden algorithm to establish a hierarchical structure.
Graph Communities to Community Summaries: The LLM generates summaries for each community, providing insights into the dataset’s overall topical structure and semantics.
Answering the Query:
Community Summaries to Global Answers: The summaries of the communities are utilized to respond to user queries. This involves generating intermediate answers, which are then consolidated into a comprehensive global answer.
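The query-answering step is essentially a map-reduce over community summaries: one LLM call per community produces an intermediate answer, and one final call merges them. The sketch below shows only that data flow. The `llm` argument is a hypothetical plain function from prompt to text, and `fake_llm` is a stub standing in for a real model. The prompts are illustrative and are not the cookbook's exact wording.

```python
from typing import Callable, Dict, List


def global_answer(
    query: str, community_summaries: Dict[int, str], llm: Callable[[str], str]
) -> str:
    # Map: ask the LLM for one intermediate answer per community summary.
    intermediate: List[str] = [
        llm(f"Community summary: {summary}\nQuery: {query}")
        for _, summary in sorted(community_summaries.items())
    ]
    # Reduce: one extra call merges the intermediate answers.
    combined = "\n".join(f"- {answer}" for answer in intermediate)
    return llm(f"Combine these intermediate answers:\n{combined}")


# Hypothetical stub LLM: records each prompt and returns a numbered answer.
calls: List[str] = []


def fake_llm(prompt: str) -> str:
    calls.append(prompt)
    return f"answer{len(calls)}"


summaries = {0: "Company A acquired Company B.", 1: "A central bank raised rates."}
result = global_answer("What are the main news?", summaries, fake_llm)
print(result)  # answer3: two map calls plus one reduce call
```

Answering a query therefore costs one LLM call per community plus one, which is why the number of communities directly drives query latency and cost.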
Here are the different components we implemented to build all of the processes mentioned above.
Source Documents to Text Chunks: Implemented using SentenceSplitter with a chunk size of 1024 and chunk overlap of 20 tokens.
Text Chunks to Element Instances AND Element Instances to Element Summaries: Implemented using GraphRAGExtractor.
Element Summaries to Graph Communities AND Graph Communities to Community Summaries: Implemented using GraphRAGStore.
Community Summaries to Global Answers: Implemented using GraphRAGQueryEngine.
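The chunk size and overlap settings define a sliding window. With a size of 1024 and an overlap of 20, each new chunk starts 1004 tokens after the previous one. The sketch below shows that arithmetic on a plain token list. It is not SentenceSplitter's algorithm, which also respects sentence boundaries and counts tokens with a tokenizer. The `chunk_tokens` helper is hypothetical.

```python
from typing import List


def chunk_tokens(tokens: List[str], chunk_size: int, overlap: int) -> List[List[str]]:
    """Split tokens into windows of `chunk_size` that share `overlap` tokens."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    step = chunk_size - overlap  # stride between chunk starts
    chunks: List[List[str]] = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start : start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # the last window already reaches the end
    return chunks


tokens = [f"t{i}" for i in range(10)]
for chunk in chunk_tokens(tokens, chunk_size=4, overlap=1):
    print(chunk)
```

Consecutive chunks share their boundary tokens, so a relationship that straddles a boundary has a chance of appearing whole in at least one chunk.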
Let's look at each of these components and build the GraphRAG pipeline.
We use the hierarchical_leiden function from graspologic to build communities.
!pip install llama-index graspologic numpy==1.24.4 scipy==1.12.0
We will use a sample news article dataset retrieved from Diffbot, which Tomaz has conveniently made available on GitHub for easy access.
The dataset contains 2,500 samples; for ease of experimentation, we will use 50 of these samples, which include the title and text of news articles.
import pandas as pd
from llama_index.core import Document
news = pd.read_csv(
    "https://raw.githubusercontent.com/tomasonjo/blog-datasets/main/news_articles.csv"
)[:50]
news.head()
Prepare documents as required by LlamaIndex
documents = [
    Document(text=f"{row['title']}: {row['text']}")
    for i, row in news.iterrows()
]
import os
os.environ["OPENAI_API_KEY"] = "sk-..."
from llama_index.llms.openai import OpenAI
llm = OpenAI(model="gpt-4")
The GraphRAGExtractor class is designed to extract triples (subject-relation-object) from text and enrich them by adding descriptions for entities and relationships to their properties using an LLM.
This functionality is similar to that of the SimpleLLMPathExtractor, but it includes additional enhancements to handle entity and relationship descriptions. For guidance on implementation, you may look at similar existing extractors.
Here's a breakdown of its functionality:
Key Components:
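llm: The language model used for extraction.
extract_prompt: A prompt template used to guide the LLM in extracting information.
parse_fn: A function to parse the LLM's output into structured data.
max_paths_per_chunk: Limits the number of triples extracted per text chunk.
num_workers: The number of workers for parallel processing of multiple text nodes.
Main Methods:
__call__: The entry point for processing a list of text nodes.
acall: An asynchronous version of __call__ for improved performance.
_aextract: The core method that processes each individual node.
Extraction Process:
For each input node (chunk of text):
It sends the text to the LLM along with the extraction prompt.
It parses the LLM's response into entities, relationships, and their descriptions.
It converts entities into EntityNode objects and stores each entity description in the node's properties.
It converts relationships into Relation objects and stores each relationship description in the relation's properties.
It adds these to the node's metadata under KG_NODES_KEY and KG_RELATIONS_KEY.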
NOTE: In the current implementation, we are using only relationship descriptions. In the next implementation, we will utilize entity descriptions during the retrieval stage.
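The conversion from parsed tuples to graph elements needs some care. If every entity writes its description into one shared `metadata` dict, the descriptions overwrite each other. The sketch below gives each element its own properties dict. `Entity` and `Edge` are hypothetical dataclass stand-ins for LlamaIndex's `EntityNode` and `Relation`, used so the example runs without the library. The tuple shapes match the `parse_fn` used later in this notebook.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class Entity:  # hypothetical stand-in for llama_index EntityNode
    name: str
    label: str
    properties: Dict[str, Any] = field(default_factory=dict)


@dataclass
class Edge:  # hypothetical stand-in for llama_index Relation
    label: str
    source: str
    target: str
    properties: Dict[str, Any] = field(default_factory=dict)


def to_graph_elements(
    entities: List[Tuple[str, str, str]],
    relationships: List[Tuple[str, str, str, str]],
    chunk_metadata: Dict[str, Any],
) -> Tuple[List[Entity], List[Edge]]:
    # A fresh dict per element, so descriptions never overwrite each other.
    nodes = [
        Entity(name, etype, {**chunk_metadata, "entity_description": desc})
        for name, etype, desc in entities
    ]
    # Relationship tuples are (source, target, relation, description).
    edges = [
        Edge(rel, subj, obj, {**chunk_metadata, "relationship_description": desc})
        for subj, obj, rel, desc in relationships
    ]
    return nodes, edges


nodes, edges = to_graph_elements(
    [("Albert Einstein", "Person", "Physicist"), ("Nobel Prize", "Award", "Prize")],
    [("Albert Einstein", "Nobel Prize", "won", "Einstein won it in 1921.")],
    {"source": "doc-1"},
)
print(nodes[0].properties["entity_description"])  # Physicist
print(edges[0].label, edges[0].properties["source"])  # won doc-1
```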
import asyncio
import nest_asyncio
nest_asyncio.apply()
from typing import Any, List, Callable, Optional, Union, Dict
from IPython.display import Markdown, display
from llama_index.core.async_utils import run_jobs
from llama_index.core.indices.property_graph.utils import (
default_parse_triplets_fn,
)
from llama_index.core.graph_stores.types import (
EntityNode,
KG_NODES_KEY,
KG_RELATIONS_KEY,
Relation,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import PromptTemplate
from llama_index.core.prompts.default_prompts import (
DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
)
from llama_index.core.schema import TransformComponent, BaseNode
from llama_index.core.bridge.pydantic import BaseModel, Field
class GraphRAGExtractor(TransformComponent):
    """Extract triples from text.

    Uses an LLM and a simple prompt + output parsing to extract paths (i.e. triples) and entity, relation descriptions from text.

    Args:
        llm (LLM):
            The language model to use.
        extract_prompt (Union[str, PromptTemplate]):
            The prompt to use for extracting triples.
        parse_fn (callable):
            A function to parse the output of the language model.
        num_workers (int):
            The number of workers to use for parallel processing.
        max_paths_per_chunk (int):
            The maximum number of paths to extract per chunk.
    """

    llm: LLM
    extract_prompt: PromptTemplate
    parse_fn: Callable
    num_workers: int
    max_paths_per_chunk: int

    def __init__(
        self,
        llm: Optional[LLM] = None,
        extract_prompt: Optional[Union[str, PromptTemplate]] = None,
        parse_fn: Callable = default_parse_triplets_fn,
        max_paths_per_chunk: int = 10,
        num_workers: int = 4,
    ) -> None:
        """Init params."""
        from llama_index.core import Settings

        if isinstance(extract_prompt, str):
            extract_prompt = PromptTemplate(extract_prompt)

        super().__init__(
            llm=llm or Settings.llm,
            extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
            parse_fn=parse_fn,
            num_workers=num_workers,
            max_paths_per_chunk=max_paths_per_chunk,
        )

    @classmethod
    def class_name(cls) -> str:
        return "GraphExtractor"

    def __call__(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes."""
        return asyncio.run(
            self.acall(nodes, show_progress=show_progress, **kwargs)
        )

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract triples from a node."""
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            llm_response = await self.llm.apredict(
                self.extract_prompt,
                text=text,
                max_knowledge_triplets=self.max_paths_per_chunk,
            )
            entities, entities_relationship = self.parse_fn(llm_response)
        except ValueError:
            entities = []
            entities_relationship = []

        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        metadata = node.metadata.copy()
        for entity, entity_type, description in entities:
            # Build a separate properties dict per entity so descriptions do not
            # overwrite each other. The entity description is not used in the
            # current implementation, but will be useful in future work.
            entity_node = EntityNode(
                name=entity,
                label=entity_type,
                properties={**metadata, "entity_description": description},
            )
            existing_nodes.append(entity_node)

        for triple in entities_relationship:
            subj, obj, rel, description = triple
            subj_node = EntityNode(name=subj, properties=metadata)
            obj_node = EntityNode(name=obj, properties=metadata)
            # Per-relation properties; GraphRAGStore reads "relationship_description".
            rel_node = Relation(
                label=rel,
                source_id=subj_node.id,
                target_id=obj_node.id,
                properties={**metadata, "relationship_description": description},
            )

            existing_nodes.extend([subj_node, obj_node])
            existing_relations.append(rel_node)

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triples from nodes async."""
        jobs = []
        for node in nodes:
            jobs.append(self._aextract(node))

        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text",
        )
The GraphRAGStore class extends the SimplePropertyGraphStore class to implement the GraphRAG pipeline. Here's a breakdown of its key components and functions:
The class uses community detection algorithms to group related nodes in the graph and then it generates summaries for each community using an LLM.
Key Methods:
build_communities():
Converts the internal graph representation to a NetworkX graph.
Applies the hierarchical Leiden algorithm for community detection.
Collects detailed information about each community.
Generates summaries for each community.
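generate_community_summary(text): Uses the LLM to summarize the relationships in a community. The summary includes the entity names and a synthesis of the relationship descriptions.
_create_nx_graph(): Converts the internal graph representation to a NetworkX graph for community detection.
_collect_community_info(nx_graph, clusters): Collects detailed information about each node based on its community, creating a string representation of each relationship within the community.
_summarize_communities(community_info): Generates and stores a summary for each community using the LLM.
get_community_summaries(): Returns the community summaries, building them first if that has not been done yet.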
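build_communities() delegates community detection to graspologic's hierarchical_leiden, which groups densely connected nodes and further partitions any community larger than max_cluster_size. The store code only reads each result item's `.node` and `.cluster`, in effect a node-to-community mapping. To illustrate that shape without graspologic, the sketch below uses connected components as a much cruder stand-in. Every connected component becomes one community, and there is no size-based splitting. The function name is hypothetical.

```python
from collections import defaultdict, deque
from typing import Dict, Iterable, Set, Tuple


def connected_components(edges: Iterable[Tuple[str, str]]) -> Dict[str, int]:
    """Map each node to a community id via BFS over an undirected graph.

    A crude stand-in for hierarchical Leiden: one community per connected
    component, with no max_cluster_size splitting.
    """
    adjacency: Dict[str, Set[str]] = defaultdict(set)
    for a, b in edges:
        adjacency[a].add(b)
        adjacency[b].add(a)

    community: Dict[str, int] = {}
    next_id = 0
    for start in sorted(adjacency):  # sorted for deterministic ids
        if start in community:
            continue
        community[start] = next_id
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for neighbor in adjacency[node]:
                if neighbor not in community:
                    community[neighbor] = next_id
                    queue.append(neighbor)
        next_id += 1
    return community


edges = [("A", "B"), ("B", "C"), ("X", "Y")]
print(connected_components(edges))
# {'A': 0, 'B': 0, 'C': 0, 'X': 1, 'Y': 1}
```

Leiden optimizes a modularity-style objective, so it can split a single connected component into several communities. This stand-in never does that.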
import re
from llama_index.core.graph_stores import SimplePropertyGraphStore
import networkx as nx
from graspologic.partition import hierarchical_leiden
from llama_index.core.llms import ChatMessage
class GraphRAGStore(SimplePropertyGraphStore):
    community_summary = {}
    max_cluster_size = 5

    def generate_community_summary(self, text):
        """Generate summary for a given text using an LLM."""
        messages = [
            ChatMessage(
                role="system",
                content=(
                    "You are provided with a set of relationships from a knowledge graph, each represented as "
                    "entity1->entity2->relation->relationship_description. Your task is to create a summary of these "
                    "relationships. The summary should include the names of the entities involved and a concise synthesis "
                    "of the relationship descriptions. The goal is to capture the most critical and relevant details that "
                    "highlight the nature and significance of each relationship. Ensure that the summary is coherent and "
                    "integrates the information in a way that emphasizes the key aspects of the relationships."
                ),
            ),
            ChatMessage(role="user", content=text),
        ]
        response = OpenAI().chat(messages)
        clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
        return clean_response

    def build_communities(self):
        """Builds communities from the graph and summarizes them."""
        nx_graph = self._create_nx_graph()
        community_hierarchical_clusters = hierarchical_leiden(
            nx_graph, max_cluster_size=self.max_cluster_size
        )
        community_info = self._collect_community_info(
            nx_graph, community_hierarchical_clusters
        )
        # Use a per-instance dict so summaries are not shared between stores
        # through the mutable class-level default.
        self.community_summary = {}
        self._summarize_communities(community_info)

    def _create_nx_graph(self):
        """Converts internal graph representation to NetworkX graph."""
        nx_graph = nx.Graph()
        for node in self.graph.nodes.values():
            nx_graph.add_node(str(node))
        for relation in self.graph.relations.values():
            nx_graph.add_edge(
                relation.source_id,
                relation.target_id,
                relationship=relation.label,
                description=relation.properties["relationship_description"],
            )
        return nx_graph

    def _collect_community_info(self, nx_graph, clusters):
        """Collect detailed information for each node based on their community."""
        community_mapping = {item.node: item.cluster for item in clusters}
        community_info = {}
        for item in clusters:
            cluster_id = item.cluster
            node = item.node
            if cluster_id not in community_info:
                community_info[cluster_id] = []

            for neighbor in nx_graph.neighbors(node):
                if community_mapping[neighbor] == cluster_id:
                    edge_data = nx_graph.get_edge_data(node, neighbor)
                    if edge_data:
                        detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
                        community_info[cluster_id].append(detail)
        return community_info

    def _summarize_communities(self, community_info):
        """Generate and store summaries for each community."""
        for community_id, details in community_info.items():
            details_text = (
                "\n".join(details) + "."
            )  # Ensure it ends with a period
            self.community_summary[
                community_id
            ] = self.generate_community_summary(details_text)

    def get_community_summaries(self):
        """Returns the community summaries, building them if not already done."""
        if not self.community_summary:
            self.build_communities()
        return self.community_summary
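_collect_community_info() turns each community into the `entity1 -> entity2 -> relation -> description` lines that the summarization prompt expects. It walks the neighbors of every node, so an edge inside a community can be recorded once from each endpoint. The hypothetical variant below iterates over the edge list instead. It emits each intra-community edge exactly once, in its original direction, and skips edges that cross communities. It is a pure-Python sketch with no NetworkX or graspologic dependency.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# (source, target, relation, description), as produced by the extractor
EdgeTuple = Tuple[str, str, str, str]


def community_details(
    edges: List[EdgeTuple], community: Dict[str, int]
) -> Dict[int, List[str]]:
    """Group intra-community edges into 'a -> b -> rel -> desc' lines."""
    details: Dict[int, List[str]] = defaultdict(list)
    for source, target, relation, description in edges:
        cid = community.get(source)
        # Skip edges that cross communities or touch unassigned nodes.
        if cid is None or community.get(target) != cid:
            continue
        details[cid].append(f"{source} -> {target} -> {relation} -> {description}")
    return dict(details)


edges = [
    ("A", "B", "acquired", "A bought B in a cash deal."),
    ("B", "C", "employs", "C is B's chief executive."),
    ("C", "X", "advises", "C sits on X's board."),
]
community = {"A": 0, "B": 0, "C": 0, "X": 1}
for cid, lines in community_details(edges, community).items():
    print(cid)
    print("\n".join(lines))
```

Emitting each edge once keeps the summarization prompt shorter. The trade-off is that a cross-community edge (C to X above) appears in neither summary, which is also true of the cookbook's approach.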
The GraphRAGQueryEngine class is a custom query engine designed to process queries using the GraphRAG approach. It leverages the community summaries generated by the GraphRAGStore to answer user queries. Here's a breakdown of its functionality:
Main Components:
graph_store: An instance of GraphRAGStore, which contains the community summaries.
llm: A Language Model (LLM) used for generating and aggregating answers.
Key Methods:
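custom_query(query_str: str): The main entry point for processing a query. It retrieves the community summaries, generates an answer from each summary, and aggregates those answers into a final response.
generate_answer_from_summary(community_summary, query): Uses the LLM to answer the query based on a single community summary.
aggregate_answers(community_answers): Uses the LLM to combine the individual community answers into one coherent, concise response.
Query Processing Flow:
Retrieve the community summaries from the graph store.
For each community summary, generate an answer to the query.
Aggregate all community-specific answers into a final response.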
Example usage:
query_engine = GraphRAGQueryEngine(graph_store=graph_store, llm=llm)
response = query_engine.query("query")
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.llms import LLM
class GraphRAGQueryEngine(CustomQueryEngine):
    graph_store: GraphRAGStore
    llm: LLM

    def custom_query(self, query_str: str) -> str:
        """Process all community summaries to generate answers to a specific query."""
        community_summaries = self.graph_store.get_community_summaries()
        community_answers = [
            self.generate_answer_from_summary(community_summary, query_str)
            for _, community_summary in community_summaries.items()
        ]

        final_answer = self.aggregate_answers(community_answers)
        return final_answer

    def generate_answer_from_summary(self, community_summary, query):
        """Generate an answer from a community summary based on a given query using LLM."""
        prompt = (
            f"Given the community summary: {community_summary}, "
            f"how would you answer the following query? Query: {query}"
        )
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content="I need an answer based on the above information.",
            ),
        ]
        response = self.llm.chat(messages)
        cleaned_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
        return cleaned_response

    def aggregate_answers(self, community_answers):
        """Aggregate individual community answers into a final, coherent response."""
        # intermediate_text = " ".join(community_answers)
        prompt = "Combine the following intermediate answers into a final, concise response."
        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content=f"Intermediate answers: {community_answers}",
            ),
        ]
        final_response = self.llm.chat(messages)
        cleaned_final_response = re.sub(
            r"^assistant:\s*", "", str(final_response)
        ).strip()
        return cleaned_final_response
Now that we have defined all the necessary components, let’s construct the GraphRAG pipeline:
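Create nodes/chunks from the text.
Build a PropertyGraphIndex using GraphRAGExtractor and GraphRAGStore.
Construct communities and generate a summary for each community using the graph built above.
Create a GraphRAGQueryEngine and begin querying.
from llama_index.core.node_parser import SentenceSplitter
splitter = SentenceSplitter(
    chunk_size=1024,
    chunk_overlap=20,
)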
nodes = splitter.get_nodes_from_documents(documents)
len(nodes)
Build a PropertyGraphIndex using GraphRAGExtractor and GraphRAGStore.
KG_TRIPLET_EXTRACT_TMPL = """
-Goal-
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: Type of the entity
- entity_description: Comprehensive description of the entity's attributes and activities
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relation: relationship between source_entity and target_entity
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
3. Output Formatting:
- Return the result in valid JSON format with two keys: 'entities' (list of entity objects) and 'relationships' (list of relationship objects).
- Exclude any text outside the JSON structure (e.g., no explanations or comments).
- If no entities or relationships are identified, return empty lists: { "entities": [], "relationships": [] }.
-An Output Example-
{
"entities": [
{
"entity_name": "Albert Einstein",
"entity_type": "Person",
"entity_description": "Albert Einstein was a theoretical physicist who developed the theory of relativity and made significant contributions to physics."
},
{
"entity_name": "Theory of Relativity",
"entity_type": "Scientific Theory",
"entity_description": "A scientific theory developed by Albert Einstein, describing the laws of physics in relation to observers in different frames of reference."
},
{
"entity_name": "Nobel Prize in Physics",
"entity_type": "Award",
"entity_description": "A prestigious international award in the field of physics, awarded annually by the Royal Swedish Academy of Sciences."
}
],
"relationships": [
{
"source_entity": "Albert Einstein",
"target_entity": "Theory of Relativity",
"relation": "developed",
"relationship_description": "Albert Einstein is the developer of the theory of relativity."
},
{
"source_entity": "Albert Einstein",
"target_entity": "Nobel Prize in Physics",
"relation": "won",
"relationship_description": "Albert Einstein won the Nobel Prize in Physics in 1921."
}
]
}
-Real Data-
######################
text: {text}
######################
output:"""
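The template above embeds literal JSON braces, for example `{ "entities": [], "relationships": [] }`. Python's built-in `str.format` would treat those as replacement fields and raise an error. The cookbook wraps the string in a PromptTemplate, and as far as we know LlamaIndex formats templates in a way that substitutes only the variables it is given. Treat that as an assumption and check it against your installed version. The sketch below illustrates the idea with a hypothetical `safe_format` helper that replaces `{name}` only when `name` is a supplied keyword.

```python
import re
from typing import Any


def safe_format(template: str, **kwargs: Any) -> str:
    """Replace {name} only when name was passed in; leave other braces untouched."""

    def replace(match: "re.Match[str]") -> str:
        key = match.group(1)
        return str(kwargs[key]) if key in kwargs else match.group(0)

    return re.sub(r"\{(\w+)\}", replace, template)


template = (
    'Extract up to {max_knowledge_triplets} triplets. Empty: { "entities": [] }\n'
    "text: {text}"
)
print(safe_format(template, max_knowledge_triplets=2, text="Einstein won the Nobel Prize."))

# Plain str.format chokes on the literal JSON braces.
try:
    template.format(max_knowledge_triplets=2, text="...")
except (KeyError, ValueError) as err:
    print("str.format failed:", repr(err))
```

If you format such a template with `str.format` yourself, the usual workaround is to double the literal braces (`{{` and `}}`).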
import json
def parse_fn(response_str: str) -> Any:
    json_pattern = r"\{.*\}"
    match = re.search(json_pattern, response_str, re.DOTALL)
    entities = []
    relationships = []
    if not match:
        return entities, relationships
    json_str = match.group(0)
    try:
        data = json.loads(json_str)
        entities = [
            (
                entity["entity_name"],
                entity["entity_type"],
                entity["entity_description"],
            )
            for entity in data.get("entities", [])
        ]
        relationships = [
            (
                relation["source_entity"],
                relation["target_entity"],
                relation["relation"],
                relation["relationship_description"],
            )
            for relation in data.get("relationships", [])
        ]
        return entities, relationships
    except json.JSONDecodeError as e:
        print("Error parsing JSON:", e)
        return entities, relationships
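parse_fn above handles invalid JSON, but it raises KeyError if the LLM omits a field from any entity or relationship. GraphRAGExtractor._aextract catches only ValueError, so a KeyError would propagate out of the extraction job. The hypothetical `lenient_parse_fn` below has the same return shape, but it drops malformed items instead of failing. This is a sketch of one possible design choice, not part of the cookbook.

```python
import json
import re
from typing import Any, List, Tuple

ENTITY_KEYS = ("entity_name", "entity_type", "entity_description")
RELATION_KEYS = ("source_entity", "target_entity", "relation", "relationship_description")


def _rows(items: Any, keys: Tuple[str, ...]) -> List[Tuple[str, ...]]:
    """Keep only dict items that carry every required key; drop the rest."""
    if not isinstance(items, list):
        return []
    return [
        tuple(str(item[k]) for k in keys)
        for item in items
        if isinstance(item, dict) and all(k in item for k in keys)
    ]


def lenient_parse_fn(
    response_str: str,
) -> Tuple[List[Tuple[str, ...]], List[Tuple[str, ...]]]:
    # Same greedy search as parse_fn: first "{" through last "}".
    match = re.search(r"\{.*\}", response_str, re.DOTALL)
    if not match:
        return [], []
    try:
        data = json.loads(match.group(0))
    except json.JSONDecodeError:
        return [], []
    if not isinstance(data, dict):
        return [], []
    return (
        _rows(data.get("entities"), ENTITY_KEYS),
        _rows(data.get("relationships"), RELATION_KEYS),
    )


response = (
    "Here is the JSON:\n"
    '{"entities": [{"entity_name": "Albert Einstein", "entity_type": "Person", '
    '"entity_description": "Physicist"}, {"entity_name": "Incomplete"}], '
    '"relationships": []}\n'
    "Let me know if you need anything else."
)
print(lenient_parse_fn(response))
# ([('Albert Einstein', 'Person', 'Physicist')], [])
```

Silently dropping items trades recall for robustness. Logging the skipped items is a reasonable addition if you want to monitor how often the LLM deviates from the requested format.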
kg_extractor = GraphRAGExtractor(
    llm=llm,
    extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
    max_paths_per_chunk=2,
    parse_fn=parse_fn,
)
from llama_index.core import PropertyGraphIndex
index = PropertyGraphIndex(
    nodes=nodes,
    property_graph_store=GraphRAGStore(),
    kg_extractors=[kg_extractor],
    show_progress=True,
)
list(index.property_graph_store.graph.nodes.values())[-1]
list(index.property_graph_store.graph.relations.values())[0]
list(index.property_graph_store.graph.relations.values())[0].properties[
    "relationship_description"
]
Build the communities and generate a summary for each community.
index.property_graph_store.build_communities()
query_engine = GraphRAGQueryEngine(
    graph_store=index.property_graph_store, llm=llm
)
response = query_engine.query(
    "What are the main news discussed in the document?"
)
display(Markdown(f"{response.response}"))
response = query_engine.query("What are news related to financial sector?")
display(Markdown(f"{response.response}"))
This cookbook is an approximate implementation of GraphRAG. In future cookbooks, we plan to extend it as follows: