import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
import FeatureHighlights from "@site/src/components/FeatureHighlights";
import TilesGrid from "@site/src/components/TilesGrid";
import TileCard from "@site/src/components/TileCard";
import { Database, GitBranch, Boxes, Rocket, Package, FileSignature, BookOpen } from "lucide-react";
Apache Spark MLlib provides distributed machine learning algorithms for processing large-scale datasets across clusters. MLflow integrates with Spark MLlib to track distributed ML pipelines, manage models, and enable flexible deployment from cluster training to standalone inference.
<FeatureHighlights features={[ { icon: Database, title: "Pipeline Tracking", description: "Automatically log Spark ML pipelines with all stages, transformers, and estimators. Track parameters from each pipeline component and maintain complete lineage." }, { icon: GitBranch, title: "Format Flexibility", description: "Save models in native Spark format for distributed batch processing or PyFunc format for inference outside a Spark cluster with automatic DataFrame conversion." }, { icon: Boxes, title: "Datasource Autologging", description: "Track data sources automatically with paths, formats, and versions. Maintain complete data lineage for distributed ML workflows." }, { icon: Rocket, title: "Cross-Platform Deployment", description: "Deploy Spark models with PyFunc wrappers for REST API serving, or convert to ONNX for platform-independent inference." } ]} />
Log Spark MLlib models with `mlflow.spark.log_model()`:

```python
import mlflow
import mlflow.spark
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession

# Initialize Spark session
spark = SparkSession.builder.appName("MLflowSparkExample").getOrCreate()

# Prepare training data
training = spark.createDataFrame(
    [
        (0, "a b c d e spark", 1.0),
        (1, "b d", 0.0),
        (2, "spark f g h", 1.0),
        (3, "hadoop mapreduce", 0.0),
    ],
    ["id", "text", "label"],
)

# Create ML Pipeline
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.001)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

# Train and log the model
with mlflow.start_run():
    model = pipeline.fit(training)

    # Log the entire fitted pipeline
    model_info = mlflow.spark.log_model(spark_model=model, artifact_path="spark-pipeline")

    # Log parameters manually
    mlflow.log_params(
        {
            "max_iter": lr.getMaxIter(),
            "reg_param": lr.getRegParam(),
            "num_features": hashingTF.getNumFeatures(),
        }
    )

print(f"Model logged with URI: {model_info.model_uri}")
```
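`mlflow.spark.log_model()` saves the fitted pipeline, with all of its stages, in both the native Spark flavor and the PyFunc flavor. It does not record hyperparameters, which is why the example logs them explicitly with `mlflow.log_params()`.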
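Instead of picking parameters by hand, you could flatten every stage's parameters into `<stage>.<param>` keys and pass them to `mlflow.log_params()` in one call. The sketch below assumes the parameters have already been collected into a nested dict; the helper `flatten_stage_params` and its key format are illustrative, not part of the MLflow or Spark APIs.

```python
from typing import Any, Dict, Mapping


def flatten_stage_params(
    stage_params: Mapping[str, Mapping[str, Any]],
) -> Dict[str, str]:
    """Flatten {stage_name: {param_name: value}} into {"stage.param": "value"}.

    Hypothetical helper. Values are converted to strings because MLflow
    stores run parameters as strings.
    """
    flat: Dict[str, str] = {}
    for stage_name, params in stage_params.items():
        for param_name, value in params.items():
            flat[f"{stage_name}.{param_name}"] = str(value)
    return flat


if __name__ == "__main__":
    # Values taken from the pipeline defined above
    example = {
        "Tokenizer": {"inputCol": "text", "outputCol": "words"},
        "LogisticRegression": {"maxIter": 10, "regParam": 0.001},
    }
    print(flatten_stage_params(example))
```

With a fitted `PipelineModel`, the nested dict could be built as `{stage.uid: {p.name: v for p, v in stage.extractParamMap().items()} for stage in model.stages}`. Keying by `uid` rather than by class name keeps the keys unique when a pipeline contains two stages of the same type.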
<Tabs>
<TabItem value="native" label="Native Spark Format" default>

Loading the model with `mlflow.spark.load_model()` preserves full Spark ML functionality for distributed processing:

```python
# Load as a native Spark model (requires a Spark session)
spark_model = mlflow.spark.load_model(model_info.model_uri)

# Use for distributed batch scoring
test_data = spark.createDataFrame(
    [(4, "spark i j k"), (5, "l m n"), (6, "spark hadoop spark"), (7, "apache hadoop")],
    ["id", "text"],
)
predictions = spark_model.transform(test_data)
predictions.show()
```

</TabItem>
<TabItem value="pyfunc" label="PyFunc Format">

Loading the model as a PyFunc enables inference on pandas DataFrames outside a Spark cluster:

```python
import mlflow.pyfunc
import pandas as pd

# Load as a PyFunc model
pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)

# Predict on a pandas DataFrame with the "text" column the pipeline expects
test_data = pd.DataFrame({"text": ["spark machine learning", "hadoop distributed computing"]})
predictions = pyfunc_model.predict(test_data)
print(predictions)
```
The PyFunc wrapper converts the input pandas DataFrame to a Spark DataFrame and runs the pipeline on a Spark session, starting a local one if none is active. PySpark is therefore still required as a dependency at inference time.
</TabItem>
</Tabs>

Track data sources automatically during model training with `mlflow.spark.autolog()`:
```python
import mlflow
import mlflow.spark

# Enable datasource autologging
mlflow.spark.autolog()

with mlflow.start_run():
    raw_data = spark.read.parquet("s3://my-bucket/training-data/")
    model = pipeline.fit(raw_data)
    mlflow.spark.log_model(model, artifact_path="model")
```
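Datasource autologging requires Spark 3.0 or later and the `mlflow-spark` JAR on the Spark classpath, and it is not supported on Databricks shared or serverless clusters. For each datasource read, MLflow logs the path, the format, and the version (when the source is versioned, as with Delta tables).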
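One way to attach the JAR is through the `spark.jars.packages` setting when the session is built. The sketch below assumes the Maven coordinate follows the pattern `org.mlflow:mlflow-spark_<scala_version>:<mlflow_version>` and that the JAR version should match the installed MLflow version; confirm the exact coordinate in the `mlflow.spark.autolog()` API reference for your release. The helper `mlflow_spark_jar_config` is hypothetical.

```python
from typing import Dict


def mlflow_spark_jar_config(mlflow_version: str, scala_version: str = "2.12") -> Dict[str, str]:
    """Return Spark config entries that pull the mlflow-spark JAR from Maven.

    Hypothetical helper. The coordinate pattern
    org.mlflow:mlflow-spark_<scala_version>:<mlflow_version> is an assumption.
    """
    coordinate = f"org.mlflow:mlflow-spark_{scala_version}:{mlflow_version}"
    return {"spark.jars.packages": coordinate}


if __name__ == "__main__":
    # Example version string; in practice use mlflow.__version__
    print(mlflow_spark_jar_config("2.16.2"))
```

Each entry becomes a `.config(key, value)` call on `SparkSession.builder` before `getOrCreate()`. Packages are resolved only when the JVM starts, so the setting has to be in place before any Spark session exists in the process.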
Infer a signature from sample data, here for a model whose input is a Spark ML vector column:

```python
import mlflow
import mlflow.spark
from mlflow.models import infer_signature
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.functions import array_to_vector

# Build a DataFrame with a vector column (reuses `spark` from the first example)
vector_data = spark.createDataFrame(
    [([3.0, 4.0], 0.0), ([5.0, 6.0], 1.0)], ["features_array", "label"]
).select(array_to_vector("features_array").alias("features"), "label")

lr = LogisticRegression(featuresCol="features", labelCol="label")
model = lr.fit(vector_data)
predictions = model.transform(vector_data)

# Infer the signature from pandas DataFrames. Only the feature column is used
# as model input, since the label is not available at inference time.
signature = infer_signature(
    vector_data.select("features").limit(2).toPandas(),
    predictions.select("prediction").limit(2).toPandas(),
)

with mlflow.start_run():
    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="vector_model",
        signature=signature,
    )
```
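Convert Spark models to ONNX with `onnxmltools` (experimental; converters exist only for a subset of Spark ML stages):

```python
import mlflow
import onnxmltools
from onnxmltools.convert.common.data_types import StringTensorType

with mlflow.start_run():
    model = pipeline.fit(training)
    mlflow.spark.log_model(spark_model=model, artifact_path="spark_model")
    # convert_sparkml requires input types; shape [None, 1] for "text" is an assumption
    initial_types = [("text", StringTensorType([None, 1]))]
    onnx_model = onnxmltools.convert_sparkml(model, name="SparkMLPipeline", initial_types=initial_types)
    onnxmltools.utils.save_model(onnx_model, "model.onnx")
    mlflow.log_artifact("model.onnx")
```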
Register a Spark model and promote it with a registered model alias:

```python
import mlflow
import mlflow.spark
from mlflow import MlflowClient

client = MlflowClient()

with mlflow.start_run():
    model = pipeline.fit(training)
    model_info = mlflow.spark.log_model(
        spark_model=model,
        artifact_path="production_candidate",
        registered_model_name="CustomerSegmentationModel",
    )
    mlflow.set_tags(
        {
            "validation_passed": "true",
            "deployment_target": "batch_scoring",
        }
    )

# Promote the newly registered version by pointing the "staging" alias at it.
# Aliases replace the stage-based APIs (get_latest_versions,
# transition_model_version_stage), which are deprecated since MLflow 2.9.
client.set_registered_model_alias(
    name="CustomerSegmentationModel",
    alias="staging",
    version=str(model_info.registered_model_version),
)
```
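Downstream jobs can load a registered model through a `models:/` URI instead of a run-specific path. MLflow accepts `models:/<name>/<version>` for a specific version and `models:/<name>@<alias>` for an alias. The sketch below only builds these strings; the helper `registered_model_uri` is hypothetical.

```python
from typing import Optional, Union


def registered_model_uri(
    name: str,
    version: Optional[Union[int, str]] = None,
    alias: Optional[str] = None,
) -> str:
    """Build a models:/ URI for a registered model version or alias.

    Hypothetical helper: exactly one of `version` or `alias` must be given.
    """
    if (version is None) == (alias is None):
        raise ValueError("Pass exactly one of version or alias")
    if alias is not None:
        return f"models:/{name}@{alias}"
    return f"models:/{name}/{version}"


if __name__ == "__main__":
    print(registered_model_uri("CustomerSegmentationModel", version=3))
    print(registered_model_uri("CustomerSegmentationModel", alias="staging"))
```

The resulting string can be passed to `mlflow.spark.load_model()` for distributed scoring or to `mlflow.pyfunc.load_model()` for pandas inference, just like the run-based URIs used earlier on this page.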