# Source notebook: benchmarking/ai/audio_transcription/spark.ipynb
import io

import numpy as np
import pandas as pd
import torch
import torchaudio
import torchaudio.transforms as T
from pyspark.sql.functions import col, length, pandas_udf, udf
from pyspark.sql.types import ArrayType, FloatType, IntegerType, StringType
%%configure -f
{
"executorCores": 1,
"conf": {
"spark.sql.execution.arrow.maxRecordsPerBatch": "64",
"spark.executorEnv.HF_HOME": "/tmp/huggingface"
}
}
# Hugging Face model id loaded lazily on each executor (see get_processor/get_model).
TRANSCRIPTION_MODEL = "openai/whisper-tiny"
# Target rate every clip is resampled to before feature extraction.
NEW_SAMPLING_RATE = 16000
# Per-process cache so each Spark executor loads the processor at most once.
_processor_cache = {"processor": None}


def get_processor():
    """Return the shared Whisper processor, loading it on first use.

    The `transformers` import is deferred so the driver never needs the
    package — only executors that actually run a UDF pay the import cost.
    """
    cached = _processor_cache["processor"]
    if cached is not None:
        return cached
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(TRANSCRIPTION_MODEL)
    _processor_cache["processor"] = processor
    return processor
# Per-process cache: the model, the device it lives on, and its parameter dtype.
_model_cache = {"model": None, "device": None, "dtype": None}


def get_model():
    """Return (model, device, dtype), loading the Whisper model once per process.

    Uses float16 on CUDA and float32 on CPU; the `transformers` import is
    deferred to the first call so only executors pay for it.
    """
    if _model_cache["model"] is None:
        from transformers import AutoModelForSpeechSeq2Seq

        use_gpu = torch.cuda.is_available()
        device = "cuda" if use_gpu else "cpu"
        dtype = torch.float16 if use_gpu else torch.float32
        loaded = AutoModelForSpeechSeq2Seq.from_pretrained(
            TRANSCRIPTION_MODEL,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(device)
        _model_cache.update(model=loaded, device=device, dtype=dtype)
    return _model_cache["model"], _model_cache["device"], _model_cache["dtype"]
@pandas_udf(ArrayType(FloatType()))
def resample_udf(audio_bytes: pd.Series) -> pd.Series:
    """Decode each clip's raw bytes and resample it to NEW_SAMPLING_RATE.

    Fixes vs. the original:
    - `.squeeze()` on a multi-channel clip left a 2-D tensor, producing
      nested lists that violate the declared ArrayType(FloatType()) schema.
      Downmixing with mean(dim=0) always yields 1-D, and is a no-op on the
      sample values for mono input.
    - Resample transforms are cached per source sampling rate instead of
      being rebuilt (kernel recomputed) for every row.
    """
    resamplers: dict = {}  # source sampling rate -> reusable Resample transform
    results = []
    for raw in audio_bytes:
        waveform, sampling_rate = torchaudio.load(io.BytesIO(raw))
        mono = waveform.mean(dim=0)  # (channels, time) -> (time,)
        resampler = resamplers.get(sampling_rate)
        if resampler is None:
            resampler = T.Resample(sampling_rate, NEW_SAMPLING_RATE)
            resamplers[sampling_rate] = resampler
        results.append(resampler(mono).numpy().astype(np.float32).tolist())
    return pd.Series(results)
@pandas_udf(ArrayType(ArrayType(FloatType())))
def whisper_preprocess_udf(resampled: pd.Series) -> pd.Series:
    """Turn resampled waveforms into Whisper input features (one 2-D list per row)."""
    proc = get_processor()
    batch = proc(
        resampled.tolist(), sampling_rate=NEW_SAMPLING_RATE, return_tensors="np"
    )
    feature_lists = [feat.astype(np.float32).tolist() for feat in batch.input_features]
    return pd.Series(feature_lists)
@pandas_udf(ArrayType(IntegerType()))
def transcriber_udf(extracted_features: pd.Series) -> pd.Series:
    """Run Whisper generation on a batch of input features; returns token ids per row.

    Fixes vs. the original:
    - `torch.tensor(list_of_ndarrays)` is very slow (element-wise copy) and
      emits a UserWarning; stack into one contiguous ndarray first and use
      `torch.from_numpy`, then move/cast in a single `.to(...)`.
    - Guard the empty batch: `np.stack([])` raises where the original silently
      produced an empty tensor.
    """
    model, device, dtype = get_model()
    feats = [np.asarray(feat, dtype=np.float32) for feat in extracted_features]
    if not feats:
        return pd.Series([], dtype="object")
    spectrograms = torch.from_numpy(np.stack(feats)).to(device=device, dtype=dtype)
    with torch.no_grad():
        token_ids = model.generate(spectrograms)
    return pd.Series([ids.cpu().numpy().tolist() for ids in token_ids])
@pandas_udf(StringType())
def decode_udf(token_ids: pd.Series) -> pd.Series:
    """Decode generated token-id sequences into transcription strings."""
    proc = get_processor()
    decoded = proc.batch_decode(token_ids.tolist(), skip_special_tokens=True)
    return pd.Series(decoded)
# Driver: read Common Voice, run the four-stage transcription pipeline, and
# append results (intermediate columns dropped) to the benchmark output path.
df = spark.read.parquet("s3://daft-public-datasets/common_voice_17")
df = (
    df.withColumn("resampled", resample_udf(col("audio.bytes")))
    .withColumn("extracted_features", whisper_preprocess_udf(col("resampled")))
    .withColumn("token_ids", transcriber_udf(col("extracted_features")))
    .withColumn("transcription", decode_udf(col("token_ids")))
    # Built-in `length` replaces `udf(lambda x: len(x), IntegerType())`:
    # same character count for strings, null-safe (returns null instead of
    # failing the task), and no Python-worker round trip.
    .withColumn("transcription_length", length(col("transcription")))
)
final_df = df.drop("token_ids", "extracted_features", "resampled")
final_df.write.mode("append").parquet(
    "s3://eventual-dev-benchmarking-results/ai-benchmark-results/audio-transcription"
)