examples/advanced/forking.ipynb
ChromaDB now supports forking. Below is an example that chunks and embeds a GitHub repo into a collection, forks that collection for a new GitHub branch, and applies the branch diff to the forked collection.
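At its core, a fork creates a new collection that starts with the contents of an existing one and can then be modified independently. Below is a minimal sketch of the call used later in this notebook; the collection names are placeholders and `client` is assumed to be an already-connected Chroma client (setup is covered further down).
# Minimal forking sketch (placeholder names; assumes `client` is already connected).
main = client.get_collection("repo_main")
feature = main.fork("repo_feature")  # new collection starting with the same records as "repo_main"
# Adds and deletes on `feature` from here on do not touch `main`.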
! pip install chromadb --quiet
! pip install tree-sitter --quiet
! pip install numpy --quiet
! pip install tree-sitter-language-pack --quiet
from tree_sitter import Language, Parser, Tree
from tree_sitter_language_pack import get_language, get_parser
import requests
import base64
import os
import getpass
from tqdm import tqdm
import chromadb
from chromadb.utils.embedding_functions import JinaEmbeddingFunction
from chromadb.utils.results import query_result_to_dfs
from chromadb.api.models.Collection import Collection
from chromadb.api import ClientAPI
PY_LANGUAGE = get_language("python")
REPO_OWNER = "jairad26"
REPO_NAME = "Django-WebApp"
EXISTING_BRANCH = "main"
NEW_BRANCH = "test1"
os.environ["GITHUB_API_KEY"] = getpass.getpass("Github API Key:")
os.environ["CHROMA_JINA_API_KEY"] = getpass.getpass("Jina API Key:")
os.environ["CHROMA_API_KEY"] = getpass.getpass("Chroma API Key:")
Below are two helper classes, CodeChunker and GitHubRepoProcessor.
CodeChunker is a custom tree-sitter-based chunker that parses a file into AST nodes and converts them into embeddable chunks.
GitHubRepoProcessor is a thin wrapper around the GitHub REST API that makes it easier to pull file contents and calculate diffs between branches.
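Roughly, the two classes are used together as in the sketch below. This mirrors the process_repo_file helper defined later in the notebook; the file path is a placeholder.
# Pull a single Python file from GitHub, chunk it, and produce embedding-ready records.
chunker = CodeChunker(language="python", max_chunk_size=500)
gh = GitHubRepoProcessor(REPO_OWNER, REPO_NAME, os.environ["GITHUB_API_KEY"])
code = gh.get_file_content("main", "path/to/file.py")  # placeholder path
chunks = chunker.create_embeddings_ready_chunks("path/to/file.py", code)
# Each chunk carries a chunk_id, the code text, and metadata (file_path, size).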
class CodeChunker:
def __init__(self, language: str = "python", max_chunk_size=500):
"""Initialize the Python code chunker.
Args:
language: The programming language of the code
max_chunk_size: Maximum chunk size in bytes
"""
# Create a parser
self.parser = get_parser(language)
# Set maximum chunk size (in bytes)
self.max_chunk_size = max_chunk_size
# Define what node types we consider as "chunkable"
self.chunkable_types = {
"function_definition",
"class_definition",
"if_statement",
"for_statement",
"while_statement",
"try_statement",
"with_statement",
"match_statement"
}
def parse_code(self, code) -> Tree:
"""Parse the code into an AST."""
tree = self.parser.parse(bytes(code, "utf8"))
return tree
def get_node_code(self, node, code):
"""Extract the code text for a given node."""
start_byte = node.start_byte
end_byte = node.end_byte
return code[start_byte:end_byte]
def get_node_name(self, node, code):
"""Try to extract the name of a function or class."""
if node.type in ("function_definition", "class_definition"):
# Find the identifier child
for child in node.children:
if child.type == "identifier":
return self.get_node_code(child, code)
return None
def get_node_info(self, node, code, parent_type=None):
"""Get information about a node."""
return {
"type": node.type,
"name": self.get_node_name(node, code),
"parent_type": parent_type,
"code": self.get_node_code(node, code),
"start_line": node.start_point[0],
"end_line": node.end_point[0],
"start_byte": node.start_byte,
"end_byte": node.end_byte,
"size": node.end_byte - node.start_byte,
}
def find_module_imports(self, tree: Tree, code: str) -> list[str]:
"""Find all import statements in the module."""
imports = []
# Define import node types
import_types = {"import_statement", "import_from_statement"}
# Walk through the tree
cursor = tree.walk()
def visit_node():
node = cursor.node
assert node is not None
# If this is an import node, add it to our list
if node.type in import_types:
imports.append(self.get_node_code(node, code))
# Continue traversal
if cursor.goto_first_child():
visit_node()
while cursor.goto_next_sibling():
visit_node()
cursor.goto_parent()
visit_node()
return imports
def find_chunkable_nodes(self, tree, code):
"""Find nodes that can be treated as independent chunks."""
chunkable_nodes = []
# Walk through the tree
cursor = tree.walk()
def visit_node():
node = cursor.node
# If this is a chunkable node, add it to our list
if node.type in self.chunkable_types:
chunkable_nodes.append(
self.get_node_info(node, code, parent_type=cursor.node.parent.type if cursor.node.parent else None)
)
# Continue traversal
if cursor.goto_first_child():
visit_node()
while cursor.goto_next_sibling():
visit_node()
cursor.goto_parent()
visit_node()
return chunkable_nodes
def break_large_node(self, node, max_size):
"""Break a large node into smaller chunks based on lines."""
node_code = node["code"]
lines = node_code.splitlines()
chunks = []
current_lines = []
current_size = 0
for line in lines:
line_size = len(line) + 1 # +1 for newline
# If adding this line would exceed max chunk size, finalize current chunk
if current_size + line_size > max_size and current_lines:
chunks.append({
"parent_node": node,
"lines": current_lines.copy(),
"size": current_size
})
current_lines = []
current_size = 0
# Add line to current chunk
current_lines.append(line)
current_size += line_size
# Add any remaining lines as the final chunk
if current_lines:
chunks.append({
"parent_node": node,
"lines": current_lines,
"size": current_size
})
return chunks
def create_chunks(self, code):
"""Break the code into semantic chunks based on the AST."""
# Parse the code to get the AST
tree = self.parse_code(code)
# Get imports
imports = self.find_module_imports(tree, code)
imports_text = "\n".join(imports)
imports_size = len(imports_text) + 2 if imports_text else 0 # +2 for newlines if imports exist
# Find chunkable nodes
nodes = self.find_chunkable_nodes(tree, code)
# First, identify oversized nodes that need special handling
regular_nodes = []
line_chunked_nodes = []
for node in nodes:
# Check if the node alone would exceed our limit
if node["size"] > self.max_chunk_size:
# This node needs to be broken down
sub_chunks = self.break_large_node(node, self.max_chunk_size)
line_chunked_nodes.extend(sub_chunks)
else:
regular_nodes.append(node)
# Sort regular nodes by size (smallest first to maximize packing)
regular_nodes.sort(key=lambda x: x["size"])
# Group regular nodes into chunks based on size
semantic_chunks = []
current_chunk = []
current_size = 0
# Consider imports only for the first chunk
first_chunk_imports_size = imports_size if imports_text else 0
for node in regular_nodes:
# For the first chunk only, account for imports size
effective_max_size = self.max_chunk_size
effective_current_size = current_size
# If this would be the first chunk, account for imports
if not semantic_chunks and not current_chunk:
effective_current_size += first_chunk_imports_size
# If adding this node would exceed max chunk size, finalize current chunk
if effective_current_size + node["size"] > effective_max_size and current_chunk:
# For the first chunk, include imports
if not semantic_chunks and imports_text:
chunk_code = imports_text + "\n\n" + "\n".join([n["code"] for n in current_chunk])
else:
chunk_code = "\n".join([n["code"] for n in current_chunk])
semantic_chunks.append({
"nodes": current_chunk,
"size": len(chunk_code),
"code": chunk_code,
"has_imports": not semantic_chunks and imports_text # Only first chunk has imports
})
current_chunk = []
current_size = 0
# Add node to current chunk
current_chunk.append(node)
current_size += node["size"]
# Add any remaining nodes as the final chunk
if current_chunk:
if not semantic_chunks and imports_text:
chunk_code = imports_text + "\n\n" + "\n".join([n["code"] for n in current_chunk])
else:
chunk_code = "\n".join([n["code"] for n in current_chunk])
semantic_chunks.append({
"nodes": current_chunk,
"size": len(chunk_code),
"code": chunk_code,
"has_imports": not semantic_chunks and imports_text # Only first chunk has imports
})
# Now handle line-chunked nodes
line_based_chunks = []
for chunked_node in line_chunked_nodes:
node_code = "\n".join(chunked_node["lines"])
# Only include imports if this would be the first chunk overall
if not semantic_chunks and not line_based_chunks and imports_text:
chunk_code = imports_text + "\n\n" + node_code
has_imports = True
else:
chunk_code = node_code
has_imports = False
# Get the original node info but with just the subset of code
parent_node = chunked_node["parent_node"]
line_based_chunks.append({
"nodes": [{
"type": parent_node["type"],
"name": parent_node["name"],
"parent_type": parent_node["parent_type"],
"code": node_code,
"start_line": parent_node["start_line"],
"end_line": parent_node["end_line"],
"start_byte": parent_node["start_byte"],
"end_byte": parent_node["end_byte"],
"size": len(node_code),
"is_partial": True,
}],
"size": len(chunk_code),
"code": chunk_code,
"has_imports": has_imports
})
# Combine both types of chunks
all_chunks = []
# Ensure the chunk with imports comes first if it exists
import_chunks = [c for c in semantic_chunks + line_based_chunks if c.get("has_imports")]
non_import_chunks = [c for c in semantic_chunks + line_based_chunks if not c.get("has_imports")]
all_chunks = import_chunks + non_import_chunks
# Final sanity check - warn about any chunk over 5000 bytes (chunks can exceed max_chunk_size when a single node or line is very large)
for i, chunk in enumerate(all_chunks):
if chunk["size"] > 5000:
print(f"Warning: Chunk {i} is {chunk['size']} bytes, which exceeds the 5000 byte limit.")
# Add a fallback for files with no chunks but valid code
if not all_chunks and code.strip():
# Create a single chunk with the entire file
chunk_code = code.strip()
all_chunks.append({
"nodes": [],
"size": len(chunk_code),
"code": chunk_code,
"has_imports": True
})
return all_chunks
def create_embeddings_ready_chunks(self, file_path, code, include_metadata=True):
"""Create chunks with metadata ready for embedding and storage in a vector DB."""
chunks = self.create_chunks(code)
result = []
for i, chunk in enumerate(chunks):
if include_metadata:
# Check if any nodes are partial chunks
partial_info = any(n.get("is_partial", False) for n in chunk["nodes"])
result.append({
"chunk_id": str(f"{file_path}_{i}"),
"code": chunk["code"],
"metadata": {
"file_path": file_path,
"size": chunk["size"],
# "node_types": str([n["type"] for n in chunk["nodes"]]),
# "node_names": str([n["name"] for n in chunk["nodes"] if n["name"]]),
# "node_count": len(chunk["nodes"]),
# "has_imports": chunk.get("has_imports", False),
# "is_partial_chunk": partial_info,
# "line_range": str([
# min([n["start_line"] for n in chunk["nodes"]]),
# max([n["end_line"] for n in chunk["nodes"]])
# ])
}
})
else:
result.append(chunk["code"])
return result
class GitHubRepoProcessor:
def __init__(self, owner, repo, token=None):
"""Initialize with GitHub repository details."""
self.owner = owner
self.repo = repo
self.base_url = f"https://api.github.com/repos/{owner}/{repo}"
self.headers = {}
if token:
self.headers["Authorization"] = f"token {token}"
# Add required GitHub API headers
self.headers["Accept"] = "application/vnd.github.v3+json"
self.headers["X-GitHub-Api-Version"] = "2022-11-28"
def get_file_list(self, branch="main", path=""):
"""Get a list of files in the repository, recursively."""
all_files = []
self._get_contents_recursive(branch, path, all_files)
return all_files
def _get_contents_recursive(self, branch, path, all_files):
"""Recursively fetch repository contents."""
url = f"{self.base_url}/contents/{path}"
if path == "":
url = f"{self.base_url}/contents"
response = requests.get(url, headers=self.headers, params={"ref": branch})
if response.status_code != 200:
print(f"Error fetching contents at {path}: {response.status_code}")
print(response.json().get("message", ""))
return
contents = response.json()
# Handle case where response is a file not a directory
if not isinstance(contents, list):
contents = [contents]
for item in contents:
if item["type"] == "file" and item["name"].endswith(".py"):
all_files.append({
"path": item["path"],
"download_url": item["download_url"],
"sha": item["sha"],
"size": item["size"]
})
elif item["type"] == "dir":
self._get_contents_recursive(branch, item["path"], all_files)
def get_file_content(self, branch, file_path):
"""Get the content of a specific file."""
url = f"{self.base_url}/contents/{file_path}"
response = requests.get(url, headers=self.headers, params={"ref": branch})
if response.status_code != 200:
print(f"Error fetching file {file_path}: {response.status_code}")
return None
content_data = response.json()
if "content" not in content_data:
print(f"File {file_path} is too large for the GitHub API. Getting it directly...")
# Get the raw file directly
raw_response = requests.get(content_data.get("download_url", ""))
if raw_response.status_code == 200:
return raw_response.text
else:
print(f"Failed to get raw content for {file_path}")
return None
# Decode content from base64
content = base64.b64decode(content_data["content"]).decode("utf-8")
return content
def get_diff_files(self, branch1, branch2) -> dict:
"""Get the files that differ between two branches, mapped to their change status."""
url = f"{self.base_url}/compare/{branch1}...{branch2}"
response = requests.get(url, headers=self.headers)
if response.status_code != 200:
print(f"Error comparing {branch1}...{branch2}: {response.status_code}")
return {}
return {file['filename']: file['status'] for file in response.json()['files']}
def process_repo_file(chunker: CodeChunker, github_processor: GitHubRepoProcessor, file_path: str, branch="main"):
code = github_processor.get_file_content(branch, file_path)
if not code or not code.strip():
return []
file_chunks = chunker.create_embeddings_ready_chunks(file_path, code)
return file_chunks
def process_repo_files(chunker: CodeChunker, github_processor: GitHubRepoProcessor, branch="main"):
"""Process all Python files in the repository and generate chunks."""
# Get list of Python files
python_files = github_processor.get_file_list(branch)
print(f"Found {len(python_files)} Python files.")
all_chunks = []
for file_info in tqdm(python_files, desc="Processing files"):
file_path = file_info["path"]
try:
file_chunks = process_repo_file(chunker, github_processor, file_path, branch)
all_chunks.extend(file_chunks)
except Exception as e:
print(f"Error processing {file_path}: {str(e)}")
print(f"Generated {len(all_chunks)} chunks in total.")
return all_chunks
chunker = CodeChunker(language="python", max_chunk_size=500)
gh_processor = GitHubRepoProcessor(REPO_OWNER, REPO_NAME, os.environ["GITHUB_API_KEY"])
chunks = process_repo_files(chunker=chunker, github_processor=gh_processor, branch=EXISTING_BRANCH)
Below are helper functions for forking a collection and adding chunks to a collection.
# Fork an existing collection
def get_or_create_new_fork(client: ClientAPI, existing_collection: Collection, new_name: str) -> Collection:
try:
return existing_collection.fork(new_name)
except Exception as e:
if client.get_collection(new_name) is not None:
return client.get_collection(new_name)
else:
raise e
# add chunks to a collection
def add_chunks(collection: Collection, chunks: list[dict]):
# only add 100 chunks at a time
for i in range(0, len(chunks), 100):
collection.add(
ids=[chunk["chunk_id"] for chunk in chunks[i:i+100]],
documents=[chunk["code"] for chunk in chunks[i:i+100]],
metadatas=[chunk["metadata"] for chunk in chunks[i:i+100]]
)
# populate a forked collection with the branch diff: delete the existing chunks for every changed file,
# then re-chunk and re-add the files that were added or modified
# the diff dictionary maps file path to status, i.e. {'file_path': 'status'}
# status can be 'added', 'modified', or 'removed'
def populate_branch_diff(collection: Collection, diff_files: dict):
collection.delete(where={"file_path": {"$in": list(diff_files.keys())}})
for file_path, status in diff_files.items():
if status == "added" or status == "modified":
chunks = process_repo_file(chunker, gh_processor, file_path, NEW_BRANCH)
add_chunks(collection, chunks)
client = chromadb.HttpClient(
ssl=True,
host='api.trychroma.com',
tenant='fc152910-6412-4b6b-b67a-4eb229ef50ce',
database='Example Demo',
headers={
'x-chroma-token': os.environ["CHROMA_API_KEY"]
}
)
main_collection = client.get_or_create_collection(
name=f"{REPO_OWNER}_{REPO_NAME}_{EXISTING_BRANCH}",
configuration={
"embedding_function": JinaEmbeddingFunction(
model_name="jina-embeddings-v2-base-code"
)
}
)
add_chunks(
collection=main_collection,
chunks=chunks
)
diff = gh_processor.get_diff_files(EXISTING_BRANCH, NEW_BRANCH)
new_branch_collection = get_or_create_new_fork(
client=client,
existing_collection=main_collection,
new_name=f"{REPO_OWNER}_{REPO_NAME}_{NEW_BRANCH}"
)
populate_branch_diff(new_branch_collection, diff)
# run the same code search query against both collections
query = "print('Hello, forking!')"
main_results = main_collection.query(
query_texts=[query],
n_results=10
)
new_branch_results = new_branch_collection.query(
query_texts=[query],
n_results=10
)
for i, df in enumerate(query_result_to_dfs(main_results)):
print(df.to_string())
for i, df in enumerate(query_result_to_dfs(new_branch_results)):
print(df.to_string())