import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
import FeatureHighlights from "@site/src/components/FeatureHighlights";
import TilesGrid from "@site/src/components/TilesGrid";
import TileCard from "@site/src/components/TileCard";
import { Database, GitBranch, Boxes, Rocket, Package, FileSignature, BookOpen } from "lucide-react";

# MLflow Spark MLlib Integration

Apache Spark MLlib provides distributed machine learning algorithms for processing large-scale datasets across clusters. MLflow integrates with Spark MLlib to track distributed ML pipelines, manage models, and enable flexible deployment from cluster training to standalone inference.

## Why MLflow + Spark MLlib?

<FeatureHighlights features={[ { icon: Database, title: "Pipeline Tracking", description: "Automatically log Spark ML pipelines with all stages, transformers, and estimators. Track parameters from each pipeline component and maintain complete lineage." }, { icon: GitBranch, title: "Format Flexibility", description: "Save models in native Spark format for distributed batch processing or PyFunc format for inference outside a Spark cluster with automatic DataFrame conversion." }, { icon: Boxes, title: "Datasource Autologging", description: "Track data sources automatically with paths, formats, and versions. Maintain complete data lineage for distributed ML workflows." }, { icon: Rocket, title: "Cross-Platform Deployment", description: "Deploy Spark models with PyFunc wrappers for REST APIs and edge computing, or convert to ONNX for platform-independent inference." } ]} />

## Basic Model Logging

Log Spark MLlib models with `mlflow.spark.log_model()`:

```python
import mlflow
import mlflow.spark
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession

# Initialize Spark session
spark = SparkSession.builder.appName("MLflowSparkExample").getOrCreate()

# Prepare training data
training = spark.createDataFrame(
    [
        (0, "a b c d e spark", 1.0),
        (1, "b d", 0.0),
        (2, "spark f g h", 1.0),
        (3, "hadoop mapreduce", 0.0),
    ],
    ["id", "text", "label"],
)

# Create ML Pipeline
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.001)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

# Train and log the model
with mlflow.start_run():
    model = pipeline.fit(training)

    # Log the entire pipeline
    model_info = mlflow.spark.log_model(spark_model=model, artifact_path="spark-pipeline")

    # Log parameters manually
    mlflow.log_params({
        "max_iter": lr.getMaxIter(),
        "reg_param": lr.getRegParam(),
        "num_features": hashingTF.getNumFeatures(),
    })

print(f"Model logged with URI: {model_info.model_uri}")
```

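`log_model()` saves the complete fitted pipeline, including every stage and its parameter values, in both the Spark native and PyFunc formats. It does not record run parameters by itself, which is why the example calls `mlflow.log_params()`.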
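
To record the parameters of every stage rather than a hand-picked few, you can walk the fitted pipeline's `stages` list and flatten each stage's `extractParamMap()`. This is a minimal sketch: the helper name `pipeline_stage_params` and its `<index>_<StageClass>.<param>` key format are illustrative choices, not an MLflow API. For fully automatic parameter logging of Spark ML estimators, MLflow also provides `mlflow.pyspark.ml.autolog()`.

```python
def pipeline_stage_params(pipeline_model):
    """Flatten the parameters of every stage in a fitted Spark ML PipelineModel.

    Hypothetical helper: keys look like "2_LogisticRegressionModel.maxIter", and the
    stage index prefix keeps two stages of the same class from overwriting each other.
    """
    params = {}
    for index, stage in enumerate(pipeline_model.stages):
        prefix = f"{index}_{type(stage).__name__}"
        # extractParamMap() returns {Param: value} for params that are set or have defaults
        for param, value in stage.extractParamMap().items():
            params[f"{prefix}.{param.name}"] = value
    return params
```

Inside the training run, `mlflow.log_params(pipeline_stage_params(model))` logs them in one call. MLflow stores each value as a string.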

## Model Formats and Loading

<Tabs> <TabItem value="native" label="Native Spark Format">

Preserves full Spark ML functionality for distributed processing:

```python
# Load as native Spark model (requires Spark session)
spark_model = mlflow.spark.load_model(model_info.model_uri)

# Use for distributed batch scoring
test_data = spark.createDataFrame(
    [(4, "spark i j k"), (5, "l m n"), (6, "spark hadoop spark"), (7, "apache hadoop")],
    ["id", "text"],
)

predictions = spark_model.transform(test_data)
predictions.show()
```

</TabItem> <TabItem value="pyfunc" label="PyFunc Format">

Enables inference outside a Spark cluster:

```python
import pandas as pd

# Load as PyFunc model
pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)

# Use with pandas DataFrame
test_data = pd.DataFrame({"text": ["spark machine learning", "hadoop distributed computing"]})

predictions = pyfunc_model.predict(test_data)
print(predictions)
```

PyFunc converts the pandas DataFrame to a Spark DataFrame and runs inference in a local Spark session, so PySpark (and the Java runtime Spark needs) must still be installed as a dependency.

</TabItem> </Tabs>

## Datasource Autologging

Track data sources automatically during model training:

```python
import mlflow.spark

mlflow.spark.autolog()

with mlflow.start_run():
    raw_data = spark.read.parquet("s3://my-bucket/training-data/")
    model = pipeline.fit(raw_data)
    mlflow.spark.log_model(model, artifact_path="model")
```

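Datasource autologging requires Spark 3.0 or later and the `mlflow-spark` JAR attached to the Spark session (for example through `spark.jars.packages`). It is not supported on Databricks shared or serverless clusters. For each datasource read, it logs the path and format, plus the version for versioned sources such as Delta tables.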
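
Autologging stores the datasource information on the run as a tag. Here is a minimal sketch for reading it back. It assumes the tag is named `sparkDatasourceInfo` and stores one `path=...,format=...` line per read, with a `version=...` field for Delta tables. Verify the exact name and layout against a logged run in your MLflow version. The helper names are hypothetical.

```python
import re

# Assumed line layout: "path=<path>[,version=<version>],format=<format>"
_DATASOURCE_LINE = re.compile(
    r"^path=(?P<path>.*?)(?:,version=(?P<version>[^,]*))?,format=(?P<format>[^,]*)$"
)


def parse_datasource_tag(tag_value):
    """Turn the autologged datasource tag into a list of dicts (sketch)."""
    records = []
    for line in tag_value.splitlines():
        match = _DATASOURCE_LINE.match(line.strip())
        if match:  # silently skip lines that do not follow the assumed layout
            records.append(match.groupdict())
    return records


def datasources_for_run(run_id, tag_name="sparkDatasourceInfo"):
    """Fetch and parse the datasource tag of a run (the tag name is an assumption)."""
    from mlflow import MlflowClient  # imported here so the parser works without MLflow

    run = MlflowClient().get_run(run_id)
    return parse_datasource_tag(run.data.tags.get(tag_name, ""))
```

Once the run above has finished, `datasources_for_run(mlflow.last_active_run().info.run_id)` returns one dict per recorded datasource. Versionless sources have `"version": None`.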

## Model Signatures

Infer signatures automatically for Spark ML models:

```python
from mlflow.models import infer_signature
from pyspark.ml.functions import array_to_vector

vector_data = spark.createDataFrame(
    [([3.0, 4.0], 0.0), ([5.0, 6.0], 1.0)], ["features_array", "label"]
).select(array_to_vector("features_array").alias("features"), "label")

lr = LogisticRegression(featuresCol="features", labelCol="label")
model = lr.fit(vector_data)

predictions = model.transform(vector_data)

# Infer the signature from pandas DataFrames: the model input is the features
# column only (the label is not needed at inference time), the output is the prediction
signature = infer_signature(
    vector_data.select("features").limit(2).toPandas(),
    predictions.select("prediction").limit(2).toPandas(),
)

with mlflow.start_run():
    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="vector_model",
        signature=signature,
    )
```

## ONNX Conversion

Convert Spark models to ONNX with `onnxmltools` (experimental; only pipeline stages that the converter supports can be converted):

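```python
import onnxmltools
from onnxmltools.convert.common.data_types import StringTensorType

with mlflow.start_run():
    model = pipeline.fit(training)
    mlflow.spark.log_model(spark_model=model, artifact_path="spark_model")

    # The Spark ML converter needs the type and shape of each input column;
    # this pipeline consumes a single string column named "text"
    initial_types = [("text", StringTensorType([None, 1]))]
    onnx_model = onnxmltools.convert_sparkml(
        model, name="SparkMLPipeline", initial_types=initial_types
    )
    onnxmltools.utils.save_model(onnx_model, "model.onnx")
    mlflow.log_artifact("model.onnx")
```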

## Model Registry

Register and promote Spark models:

```python
from mlflow import MlflowClient

client = MlflowClient()

with mlflow.start_run():
    model = pipeline.fit(training)

    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="production_candidate",
        registered_model_name="CustomerSegmentationModel",
    )

    mlflow.set_tags({
        "validation_passed": "true",
        "deployment_target": "batch_scoring",
    })

# Fetch the newest version that has no stage yet and move it to Staging
model_version = client.get_latest_versions("CustomerSegmentationModel", stages=["None"])[0]

client.transition_model_version_stage(
    name="CustomerSegmentationModel", version=model_version.version, stage="Staging"
)
```

## Learn More

<TilesGrid> <TileCard icon={Package} iconSize={48} title="Model Registry" description="Manage model versions, aliases, and lifecycle stages for production deployment workflows." href="/ml/model-registry" linkText="View registry docs →" containerHeight={64} /> <TileCard icon={FileSignature} iconSize={48} title="Model Signatures" description="Define input and output schemas for model validation and type checking." href="/ml/model/signatures" linkText="Learn about signatures →" containerHeight={64} /> <TileCard icon={Rocket} iconSize={48} title="Model Deployment" description="Deploy Spark models with MLflow serving, batch inference, and cloud platforms." href="/ml/deployment" linkText="Deploy models →" containerHeight={64} /> <TileCard icon={BookOpen} iconSize={48} title="MLflow Tracking" description="Track experiments, parameters, metrics, and artifacts across ML workflows." href="/ml/tracking" linkText="View tracking docs →" containerHeight={64} /> </TilesGrid>