Back to Chroma

Forking

examples/advanced/forking.ipynb

1.5.9 · 21.2 KB
Original Source

Forking

ChromaDB now supports forking. Below is an example that uses forking to chunk and embed a GitHub repo, fork the collection for a new GitHub branch, and apply diffs to the new branch.

python
! pip install chromadb --quiet
! pip install tree-sitter --quiet
! pip install numpy --quiet
! pip install tree-sitter-language-pack --quiet
from tree_sitter import Language, Parser, Tree
from tree_sitter_language_pack import get_language, get_parser
import requests
import base64
import os
import getpass
from tqdm import tqdm
import chromadb
from chromadb.utils.embedding_functions import JinaEmbeddingFunction
from chromadb.utils.results import query_result_to_dfs
from chromadb.api.models.Collection import Collection
from chromadb.api import ClientAPI
python

PY_LANGUAGE = get_language("python")
REPO_OWNER = "jairad26"
REPO_NAME = "Django-WebApp"
EXISTING_BRANCH = "main"
NEW_BRANCH = "test1"
os.environ["GITHUB_API_KEY"] = getpass.getpass("Github API Key:")
os.environ["CHROMA_JINA_API_KEY"] = getpass.getpass("Jina API Key:")
os.environ["CHROMA_API_KEY"] = getpass.getpass("Chroma API Key:")

Chunker and Github Helpers

Below are two helper classes: CodeChunker and GitHubRepoProcessor.

CodeChunker is a custom tree-sitter implementation that chunks files into AST nodes and converts them into embeddable chunks. GitHubRepoProcessor is a wrapper around the GitHub API that makes it easier to pull file contents and calculate diffs between branches.

python

class CodeChunker:
    def __init__(self, language: str = "python", max_chunk_size=500):
        """Initialize the Python code chunker.
        
        Args:
            language: The programming language of the code
            max_chunk_size: Maximum chunk size in bytes
        """
        # Create a parser
        self.parser = get_parser(language)
        
        # Set maximum chunk size (in bytes)
        self.max_chunk_size = max_chunk_size
        
        # Define what node types we consider as "chunkable"
        self.chunkable_types = {
            "function_definition", 
            "class_definition",
            "if_statement",
            "for_statement",
            "while_statement",
            "try_statement",
            "with_statement",
            "match_statement"
        }
    
    def parse_code(self, code) -> Tree:
        """Parse the code into an AST."""
        tree = self.parser.parse(bytes(code, "utf8"))
        return tree
    
    def get_node_code(self, node, code):
        """Extract the code text for a given node."""
        start_byte = node.start_byte
        end_byte = node.end_byte
        return code[start_byte:end_byte]
    
    def get_node_name(self, node, code):
        """Try to extract the name of a function or class."""
        if node.type in ("function_definition", "class_definition"):
            # Find the identifier child
            for child in node.children:
                if child.type == "identifier":
                    return self.get_node_code(child, code)
        return None
    
    def get_node_info(self, node, code, parent_type=None):
        """Get information about a node."""
        return {
            "type": node.type,
            "name": self.get_node_name(node, code),
            "parent_type": parent_type,
            "code": self.get_node_code(node, code),
            "start_line": node.start_point[0],
            "end_line": node.end_point[0],
            "start_byte": node.start_byte,
            "end_byte": node.end_byte,
            "size": node.end_byte - node.start_byte,
        }
    
    def find_module_imports(self, tree: Tree, code: str) -> list[str]:
        """Find all import statements in the module."""
        imports = []
        
        # Define import node types
        import_types = {"import_statement", "import_from_statement"}
        
        # Walk through the tree
        cursor = tree.walk()
        
        def visit_node():
            node = cursor.node
            
            assert node is not None
            
            # If this is an import node, add it to our list
            if node.type in import_types:
                imports.append(self.get_node_code(node, code))
            
            # Continue traversal
            if cursor.goto_first_child():
                visit_node()
                while cursor.goto_next_sibling():
                    visit_node()
                cursor.goto_parent()
        
        visit_node()
        return imports
    
    def find_chunkable_nodes(self, tree, code):
        """Find nodes that can be treated as independent chunks."""
        chunkable_nodes = []
        
        # Walk through the tree
        cursor = tree.walk()
        
        def visit_node():
            node = cursor.node
            
            # If this is a chunkable node, add it to our list
            if node.type in self.chunkable_types:
                chunkable_nodes.append(
                    self.get_node_info(node, code, parent_type=cursor.node.parent.type if cursor.node.parent else None)
                )
            
            # Continue traversal
            if cursor.goto_first_child():
                visit_node()
                while cursor.goto_next_sibling():
                    visit_node()
                cursor.goto_parent()
        
        visit_node()
        return chunkable_nodes
    
    def break_large_node(self, node, max_size):
        """Break a large node into smaller chunks based on lines."""
        node_code = node["code"]
        lines = node_code.splitlines()
        
        chunks = []
        current_lines = []
        current_size = 0
        
        for line in lines:
            line_size = len(line) + 1  # +1 for newline
            
            # If adding this line would exceed max chunk size, finalize current chunk
            if current_size + line_size > max_size and current_lines:
                chunks.append({
                    "parent_node": node,
                    "lines": current_lines.copy(),
                    "size": current_size
                })
                current_lines = []
                current_size = 0
            
            # Add line to current chunk
            current_lines.append(line)
            current_size += line_size
        
        # Add any remaining lines as the final chunk
        if current_lines:
            chunks.append({
                "parent_node": node,
                "lines": current_lines,
                "size": current_size
            })
        
        return chunks
    
    def create_chunks(self, code):
        """Break the code into semantic chunks based on the AST."""
        # Parse the code to get the AST
        tree = self.parse_code(code)
        
        # Get imports
        imports = self.find_module_imports(tree, code)
        imports_text = "\n".join(imports)
        imports_size = len(imports_text) + 2 if imports_text else 0  # +2 for newlines if imports exist
        
        # Find chunkable nodes
        nodes = self.find_chunkable_nodes(tree, code)
        
        # First, identify oversized nodes that need special handling
        regular_nodes = []
        line_chunked_nodes = []
        
        for node in nodes:
            # Check if the node alone would exceed our limit
            if node["size"] > self.max_chunk_size:
                # This node needs to be broken down
                sub_chunks = self.break_large_node(node, self.max_chunk_size)
                line_chunked_nodes.extend(sub_chunks)
            else:
                regular_nodes.append(node)
        
        # Sort regular nodes by size (smallest first to maximize packing)
        regular_nodes.sort(key=lambda x: x["size"])
        
        # Group regular nodes into chunks based on size
        semantic_chunks = []
        current_chunk = []
        current_size = 0
        
        # Consider imports only for the first chunk
        first_chunk_imports_size = imports_size if imports_text else 0
        
        for node in regular_nodes:
            # For the first chunk only, account for imports size
            effective_max_size = self.max_chunk_size
            effective_current_size = current_size
            
            # If this would be the first chunk, account for imports
            if not semantic_chunks and not current_chunk:
                effective_current_size += first_chunk_imports_size
            
            # If adding this node would exceed max chunk size, finalize current chunk
            if effective_current_size + node["size"] > effective_max_size and current_chunk:
                # For the first chunk, include imports
                if not semantic_chunks and imports_text:
                    chunk_code = imports_text + "\n\n" + "\n".join([n["code"] for n in current_chunk])
                else:
                    chunk_code = "\n".join([n["code"] for n in current_chunk])
                
                semantic_chunks.append({
                    "nodes": current_chunk,
                    "size": len(chunk_code),
                    "code": chunk_code,
                    "has_imports": not semantic_chunks and imports_text  # Only first chunk has imports
                })
                current_chunk = []
                current_size = 0
            
            # Add node to current chunk
            current_chunk.append(node)
            current_size += node["size"]
        
        # Add any remaining nodes as the final chunk
        if current_chunk:
            if not semantic_chunks and imports_text:
                chunk_code = imports_text + "\n\n" + "\n".join([n["code"] for n in current_chunk])
            else:
                chunk_code = "\n".join([n["code"] for n in current_chunk])
            
            semantic_chunks.append({
                "nodes": current_chunk,
                "size": len(chunk_code),
                "code": chunk_code,
                "has_imports": not semantic_chunks and imports_text  # Only first chunk has imports
            })
        
        # Now handle line-chunked nodes
        line_based_chunks = []
        for chunked_node in line_chunked_nodes:
            node_code = "\n".join(chunked_node["lines"])
            
            # Only include imports if this would be the first chunk overall
            if not semantic_chunks and not line_based_chunks and imports_text:
                chunk_code = imports_text + "\n\n" + node_code
                has_imports = True
            else:
                chunk_code = node_code
                has_imports = False
            
            # Get the original node info but with just the subset of code
            parent_node = chunked_node["parent_node"]
            line_based_chunks.append({
                "nodes": [{
                    "type": parent_node["type"],
                    "name": parent_node["name"],
                    "parent_type": parent_node["parent_type"],
                    "code": node_code,
                    "start_line": parent_node["start_line"],
                    "end_line": parent_node["end_line"],
                    "start_byte": parent_node["start_byte"],
                    "end_byte": parent_node["end_byte"],
                    "size": len(node_code),
                    "is_partial": True,
                }],
                "size": len(chunk_code),
                "code": chunk_code,
                "has_imports": has_imports
            })
        
        # Combine both types of chunks
        all_chunks = []
        
        # Ensure the chunk with imports comes first if it exists
        import_chunks = [c for c in semantic_chunks + line_based_chunks if c.get("has_imports")]
        non_import_chunks = [c for c in semantic_chunks + line_based_chunks if not c.get("has_imports")]
        
        all_chunks = import_chunks + non_import_chunks
        
        # Final verification - ensure no chunk exceeds 5000 bytes
        for i, chunk in enumerate(all_chunks):
            if chunk["size"] > 5000:
                print(f"Warning: Chunk {i} is {chunk['size']} bytes, which exceeds the 5000 byte limit.")
        
        # Add a fallback for files with no chunks but valid code
        if not all_chunks and code.strip():
            # Create a single chunk with the entire file
            chunk_code = code.strip()
            all_chunks.append({
                "nodes": [],
                "size": len(chunk_code),
                "code": chunk_code,
                "has_imports": True 
            })
        
        return all_chunks
    
    def create_embeddings_ready_chunks(self, file_path, code, include_metadata=True):
        """Create chunks with metadata ready for embedding and storage in a vector DB."""
        chunks = self.create_chunks(code)
        
        result = []
        for i, chunk in enumerate(chunks):
            if include_metadata:
                # Check if any nodes are partial chunks
                partial_info = any(n.get("is_partial", False) for n in chunk["nodes"])
                
                result.append({
                    "chunk_id": str(f"{file_path}_{i}"),
                    "code": chunk["code"],
                    "metadata": {
                        "file_path": file_path,
                        "size": chunk["size"],
                        # "node_types": str([n["type"] for n in chunk["nodes"]]),
                        # "node_names": str([n["name"] for n in chunk["nodes"] if n["name"]]),
                        # "node_count": len(chunk["nodes"]),
                        # "has_imports": chunk.get("has_imports", False),
                        # "is_partial_chunk": partial_info,
                        # "line_range": str([
                        #     min([n["start_line"] for n in chunk["nodes"]]),
                        #     max([n["end_line"] for n in chunk["nodes"]])
                        # ])
                    }
                })
            else:
                result.append(chunk["code"])
        
        return result
python
class GitHubRepoProcessor:
    def __init__(self, owner, repo, token=None):
        """Initialize with GitHub repository details."""
        self.owner = owner
        self.repo = repo
        self.base_url = f"https://api.github.com/repos/{owner}/{repo}"
        self.headers = {}
        
        if token:
            self.headers["Authorization"] = f"token {token}"
        
        # Add required GitHub API headers
        self.headers["Accept"] = "application/vnd.github.v3+json"
        self.headers["X-GitHub-Api-Version"] = "2022-11-28"
    
    def get_file_list(self, branch="main", path=""):
        """Get a list of files in the repository, recursively."""
        all_files = []
        self._get_contents_recursive(branch, path, all_files)
        return all_files
    
    def _get_contents_recursive(self, branch, path, all_files):
        """Recursively fetch repository contents."""
        url = f"{self.base_url}/contents/{path}"
        if path == "":
            url = f"{self.base_url}/contents"
        
        response = requests.get(url, headers=self.headers, params={"ref": branch})
        
        if response.status_code != 200:
            print(f"Error fetching contents at {path}: {response.status_code}")
            print(response.json().get("message", ""))
            return
        
        contents = response.json()
        
        # Handle case where response is a file not a directory
        if not isinstance(contents, list):
            contents = [contents]
        
        for item in contents:
            if item["type"] == "file" and item["name"].endswith(".py"):
                all_files.append({
                    "path": item["path"],
                    "download_url": item["download_url"],
                    "sha": item["sha"],
                    "size": item["size"]
                })
            elif item["type"] == "dir":
                self._get_contents_recursive(branch, item["path"], all_files)
    
    def get_file_content(self, branch, file_path):
        """Get the content of a specific file."""
        url = f"{self.base_url}/contents/{file_path}"
        response = requests.get(url, headers=self.headers, params={"ref": branch})
        
        if response.status_code != 200:
            print(f"Error fetching file {file_path}: {response.status_code}")
            return None
        
        content_data = response.json()
        
        if "content" not in content_data:
            print(f"File {file_path} is too large for the GitHub API. Getting it directly...")
            # Get the raw file directly
            raw_response = requests.get(content_data.get("download_url", ""))
            if raw_response.status_code == 200:
                return raw_response.text
            else:
                print(f"Failed to get raw content for {file_path}")
                return None
        
        # Decode content from base64
        content = base64.b64decode(content_data["content"]).decode("utf-8")
        return content
    
    def get_diff_files(self, branch1, branch2) -> dict:
        """Get the diff between two branches."""
        url = f"{self.base_url}/compare/{branch1}...{branch2}"
        response = requests.get(url, headers=self.headers)
        return {file['filename']: file['status'] for file in response.json()['files']}
    
python
def process_repo_file(chunker: CodeChunker, github_processor: GitHubRepoProcessor, file_path: str, branch="main"):
    code = github_processor.get_file_content(branch, file_path)
    if not code or not code.strip():
        return []
    file_chunks = chunker.create_embeddings_ready_chunks(file_path, code)
    return file_chunks

def process_repo_files(chunker: CodeChunker, github_processor: GitHubRepoProcessor, branch="main"):
    """Process all Python files in the repository and generate chunks."""
    # Get list of Python files
    python_files = github_processor.get_file_list()
    print(f"Found {len(python_files)} Python files.")
    
    all_chunks = []
    
    for file_info in tqdm(python_files, desc="Processing files"):
        file_path = file_info["path"]
        
        try:
            file_chunks = process_repo_file(chunker, github_processor, file_path, branch)
            all_chunks.extend(file_chunks)
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
    
    print(f"Generated {len(all_chunks)} chunks in total.")
    return all_chunks

Generate Chunks

python
chunker = CodeChunker(language="python", max_chunk_size=500)
gh_processor = GitHubRepoProcessor(REPO_OWNER, REPO_NAME, os.environ["GITHUB_API_KEY"])
chunks = process_repo_files(chunker=chunker, github_processor=gh_processor, branch="main")

Below are helper functions to fork a collection and to add chunks to a collection.

python
# Fork an existing collection
def get_or_create_new_fork(client: ClientAPI, existing_collection: Collection, new_name: str) -> Collection:
    try:
        return existing_collection.fork(new_name)
    except Exception as e:
        if client.get_collection(new_name) is not None:
            return client.get_collection(new_name)
        else:
            raise e

# add chunks to a collection
def add_chunks(collection: Collection, chunks: list[dict]):
    # only add 100 chunks at a time
    for i in range(0, len(chunks), 100):
        collection.add(
            ids=[chunk["chunk_id"] for chunk in chunks[i:i+100]],
            documents=[chunk["code"] for chunk in chunks[i:i+100]],
            metadatas=[chunk["metadata"] for chunk in chunks[i:i+100]]
        )

# populate a collection with the diff files by deleting the existing chunks, regenerating chunks & adding
# the diff dictionary looks like this:
# {'file_path': 'status'}
# status can be 'added', 'modified', 'removed'
def populate_branch_diff(collection: Collection, diff_files: dict):
    collection.delete(where={"file_path": {"$in": list(diff_files.keys())}})
    for file_path, status in diff_files.items():
        if status == "added" or status == "modified":
            chunks = process_repo_file(chunker, gh_processor, file_path, NEW_BRANCH)
            add_chunks(collection, chunks)            

ChromaDB Impl

python
client = chromadb.HttpClient(
  ssl=True,
  host='api.trychroma.com',
  tenant='fc152910-6412-4b6b-b67a-4eb229ef50ce',
  database='Example Demo',
  headers={
    'x-chroma-token': os.environ["CHROMA_API_KEY"]
  }
)
  

main_collection = client.get_or_create_collection(
  name=f"{REPO_OWNER}_{REPO_NAME}_{EXISTING_BRANCH}",
  configuration={
    "embedding_function": JinaEmbeddingFunction(
      model_name="jina-embeddings-v2-base-code"
    )
  }
)

add_chunks(
  collection=main_collection,
  chunks=chunks
)

Fork the new branch and make updates

python
diff = gh_processor.get_diff_files(EXISTING_BRANCH, NEW_BRANCH)

new_branch_collection = get_or_create_new_fork(
  client=client,
  existing_collection=main_collection,
  new_name=f"{REPO_OWNER}_{REPO_NAME}_{NEW_BRANCH}"
)

populate_branch_diff(new_branch_collection, diff)

Query both collections

python
# code search both collections 
query = "print('Hello, forking!')"

main_results = main_collection.query(
  query_texts=[query],
  n_results=10
)

new_branch_results = new_branch_collection.query(
  query_texts=[query],
  n_results=10
)

for i, df in enumerate(query_result_to_dfs(main_results)):
  print(df.to_string())

for i, df in enumerate(query_result_to_dfs(new_branch_results)):
  print(df.to_string())