benchmarking/ai/document_embedding/spark.ipynb
from __future__ import annotations
import pymupdf
import pandas as pd
import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pyspark.sql.functions import udf, explode, col, pandas_udf
from pyspark.sql.types import (
ArrayType,
StructType,
StructField,
StringType,
IntegerType,
FloatType,
)
%%configure -f
{
"executorCores": 1,
"conf": {
"spark.sql.execution.arrow.maxRecordsPerBatch": "10"
}
}
def extract_text_from_parsed_pdf(pdf_bytes: bytes):
    """Extract per-page text from an in-memory PDF.

    Args:
        pdf_bytes: Raw bytes of a PDF file (the ``content`` column of a
            Spark ``binaryFile`` read).

    Returns:
        A list of ``{"text": str, "page_number": int}`` dicts, one per page,
        or ``None`` when the document has more than 100 pages (skipped to
        bound per-row work) or cannot be parsed at all.
    """
    try:
        # Use the context manager so the native document handle is always
        # released — the original leaked one handle per row on the 100-page
        # early-return path and on success.
        with pymupdf.Document(stream=pdf_bytes, filetype="pdf") as doc:
            if len(doc) > 100:
                return None
            return [{"text": page.get_text(), "page_number": page.number} for page in doc]
    except Exception:
        # Deliberate best-effort: malformed/encrypted PDFs become NULL rows
        # and are filtered out downstream instead of failing the job.
        return None
# Result schema for extract_text_from_parsed_pdf: an array of per-page
# structs, each carrying the page's text and its 0-based page number.
_page_fields = [
    StructField("text", StringType(), True),
    StructField("page_number", IntegerType(), True),
]
extract_schema = ArrayType(StructType(_page_fields))

# Row-at-a-time UDF; returns NULL for unparseable or oversized PDFs.
extract_udf = udf(extract_text_from_parsed_pdf, extract_schema)
# Lazily-built, per-process splitter cache — mirrors the _model_cache
# pattern used for the embedding model below. The original constructed a
# fresh RecursiveCharacterTextSplitter on every call (i.e. once per page).
_splitter_cache = {"splitter": None}


def chunk(text: str):
    """Split one page's text into overlapping chunks for embedding.

    Args:
        text: Full text of a single PDF page.

    Returns:
        A list of ``{"text": str, "chunk_id": int}`` dicts, where
        ``chunk_id`` is the 0-based position of the chunk within the page.
    """
    if _splitter_cache["splitter"] is None:
        _splitter_cache["splitter"] = RecursiveCharacterTextSplitter(
            chunk_size=2048, chunk_overlap=200
        )
    splitter = _splitter_cache["splitter"]
    return [
        {"text": t, "chunk_id": idx}
        for idx, t in enumerate(splitter.split_text(text))
    ]
# Result schema for chunk: an array of chunk structs, each holding the
# chunk text plus its 0-based index within the source page.
_chunk_fields = [
    StructField("text", StringType(), True),
    StructField("chunk_id", IntegerType(), True),
]
chunk_schema = ArrayType(StructType(_chunk_fields))

# Row-at-a-time UDF applied to each page's extracted text.
chunk_udf = udf(chunk, chunk_schema)
# Per-process singleton so the model is loaded once per executor, not once
# per batch.
_model_cache = {"model": None}


def get_model():
    """Return the cached SentenceTransformer, loading it on first call.

    Redirects every cache directory to /tmp because Spark executors
    typically cannot write to the default home-directory cache locations.

    Returns:
        A (possibly torch.compile'd) ``SentenceTransformer`` instance for
        ``sentence-transformers/all-MiniLM-L6-v2``.
    """
    import os

    os.environ["TORCH_HOME"] = "/tmp/torch"
    os.environ["XDG_CACHE_HOME"] = "/tmp"
    os.environ["HF_HOME"] = "/tmp/huggingface"
    os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
    if _model_cache["model"] is None:
        # Import lazily so the driver never needs sentence_transformers.
        from sentence_transformers import SentenceTransformer

        # Fall back to CPU on executors without a visible GPU — the
        # original hard-coded "cuda" and crashed on CPU-only workers.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2", device=device
        )
        model.compile()
        _model_cache["model"] = model
    return _model_cache["model"]
@pandas_udf(ArrayType(FloatType()))
def embed_udf(texts: pd.Series) -> pd.Series:
    """Vectorized (Arrow) UDF: embed a batch of chunk texts.

    Args:
        texts: Series of chunk strings for one Arrow batch.

    Returns:
        Series of embedding vectors (list[float]), aligned with ``texts``.
    """
    model = get_model()
    # Guard the degenerate empty batch before calling encode.
    if texts.empty:
        return pd.Series([[]] * len(texts))
    # NOTE(review): torch_dtype is forwarded to encode() as an extra kwarg —
    # confirm the installed sentence-transformers version accepts it.
    batch_embeddings = model.encode(
        texts.tolist(),
        convert_to_tensor=True,
        torch_dtype=torch.bfloat16,
    )
    host_vectors = batch_embeddings.cpu().numpy()
    return pd.Series([vector.tolist() for vector in host_vectors])
# Resolve the S3 paths of every PDF listed in the public metadata table.
pdf_metadata = spark.read.parquet(
    "s3://daft-public-datasets/digitalcorpora_metadata"
).filter(col("file_name").endswith(".pdf"))
pdf_paths = [row.uploaded_pdf_path for row in pdf_metadata.collect()]

# PDF bytes -> pages -> chunks -> embeddings, as one chained pipeline.
result = (
    spark.read.format("binaryFile")
    .load(pdf_paths)
    .withColumnRenamed("path", "uploaded_pdf_path")
    .withColumn("pages", extract_udf(col("content")))
    .withColumn("page", explode("pages"))
    .withColumn("page_text", col("page.text"))
    .withColumn("page_number", col("page.page_number"))
    .filter(col("page_text").isNotNull())  # drop unparseable/oversized PDFs
    .withColumn("chunks", chunk_udf(col("page_text")))
    .withColumn("chunk", explode("chunks"))
    .withColumn("chunk_text", col("chunk.text"))
    .withColumn("chunk_id", col("chunk.chunk_id"))
    .filter(col("chunk_text").isNotNull())
    .withColumn("embedding", embed_udf(col("chunk_text")))
    .select(
        "uploaded_pdf_path", "page_number", "chunk_id", "chunk_text", "embedding"
    )
)

result.write.mode("append").parquet(
    "s3://eventual-dev-benchmarking-results/ai-benchmark-results/document-embedding-results"
)