
Interpretability - Text Explainers

In this example, we use LIME and Kernel SHAP explainers to explain a text classification model.

First, we import the packages and define a UDF and a plotting function that we will need later.

python
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.functions import vector_to_array
from synapse.ml.explainers import *
from synapse.ml.featurize.text import TextFeaturizer
from synapse.ml.core.platform import *

# UDF that extracts element i of an ML vector column as a float
vec_access = udf(lambda v, i: float(v[i]), FloatType())

Load the training data, and convert the rating to a binary label: reviews rated above 3 are labeled positive.

python
data = (
    spark.read.parquet("wasbs://[email protected]/BookReviewsFromAmazon10K.parquet")
    .withColumn("label", (col("rating") > 3).cast(LongType()))
    .select("label", "text")
    .cache()
)

display(data)
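
Before training, it's worth checking the resulting class balance; a minimal sketch using a standard PySpark aggregation:

python
# count rows per binary label to inspect the class balance
display(data.groupBy("label").count())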

We train a text classification model, and randomly sample 10 rows to explain.

python
train, test = data.randomSplit([0.60, 0.40])

pipeline = Pipeline(
    stages=[
        TextFeaturizer(
            inputCol="text",
            outputCol="features",
            useStopWordsRemover=True,
            useIDF=True,
            minDocFreq=20,
            numFeatures=1 << 16,
        ),
        LogisticRegression(maxIter=100, regParam=0.005, labelCol="label", featuresCol="features"),
    ]
)

model = pipeline.fit(train)

prediction = model.transform(test)

explain_instances = prediction.orderBy(rand()).limit(10)
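
Both explainers below transform these same ten rows, so caching the sample avoids recomputing the prediction and random ordering on each pass (an optional step, sketched here):

python
# cache the sampled rows so LIME and SHAP work from the same materialized sample
explain_instances = explain_instances.cache()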

Next, we define a helper function to plot a confusion matrix, and use it to check the trained model on the test set.

python
def plotConfusionMatrix(df, label, prediction, classLabels):
    from synapse.ml.plot import confusionMatrix
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(4.5, 4.5))
    confusionMatrix(df, label, prediction, classLabels)
    if running_on_synapse():
        plt.show()
    else:
        display(fig)


plotConfusionMatrix(model.transform(test), "label", "prediction", [0, 1])
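
To complement the confusion matrix with a single summary number, a standard spark.ml evaluator can score the test predictions. This sketch computes AUROC, using the probability vector as the raw score:

python
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# area under the ROC curve on the held-out test set
evaluator = BinaryClassificationEvaluator(
    rawPredictionCol="probability", labelCol="label", metricName="areaUnderROC"
)
print(evaluator.evaluate(prediction))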

First we use the LIME text explainer to explain the model's predicted probability for a given observation.

python
lime = TextLIME(
    model=model,
    outputCol="weights",
    inputCol="text",
    targetCol="probability",
    targetClasses=[1],
    tokensCol="tokens",
    samplingFraction=0.7,
    numSamples=2000,
)

lime_results = (
    lime.transform(explain_instances)
    .select("tokens", "weights", "r2", "probability", "text")
    # extract the probability of the positive class (index 1)
    .withColumn("probability", vec_access("probability", lit(1)))
    # "weights" holds one weight vector per target class; take the first (class 1)
    .withColumn("weights", vector_to_array(col("weights").getItem(0)))
    .withColumn("r2", vec_access("r2", lit(0)))
    # pair each token with its LIME weight for display
    .withColumn("tokens_weights", arrays_zip("tokens", "weights"))
)

display(lime_results.select("probability", "r2", "tokens_weights", "text").orderBy(col("probability").desc()))
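
To see which tokens carry the most weight across all ten sampled rows, the zipped column can be exploded and sorted. This sketch uses only the columns computed above:

python
# flatten (token, weight) pairs and rank tokens by LIME weight
token_weights = (
    lime_results.select(explode("tokens_weights").alias("tw"))
    .select(col("tw.tokens").alias("token"), col("tw.weights").alias("weight"))
    .orderBy(col("weight").desc())
)
display(token_weights.limit(20))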

Then we use the Kernel SHAP text explainer to explain the model's predicted probability for a given observation.

Notice that we drop the base value (the first element of each SHAP vector) before displaying the SHAP values. The base value is the model output for an empty string.

python
shap = TextSHAP(
    model=model,
    outputCol="shaps",
    inputCol="text",
    targetCol="probability",
    targetClasses=[1],
    tokensCol="tokens",
    numSamples=5000,
)

shap_results = (
    shap.transform(explain_instances)
    .select("tokens", "shaps", "r2", "probability", "text")
    # extract the probability of the positive class (index 1)
    .withColumn("probability", vec_access("probability", lit(1)))
    # "shaps" holds one vector per target class; take the first (class 1)
    .withColumn("shaps", vector_to_array(col("shaps").getItem(0)))
    # drop the leading base value (slice is 1-based), keeping only per-token values
    .withColumn("shaps", slice(col("shaps"), lit(2), size(col("shaps"))))
    .withColumn("r2", vec_access("r2", lit(0)))
    # pair each token with its SHAP value for display
    .withColumn("tokens_shaps", arrays_zip("tokens", "shaps"))
)

display(shap_results.select("probability", "r2", "tokens_shaps", "text").orderBy(col("probability").desc()))
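
To verify the base value that the slice removed, score an empty string directly; its positive-class probability should match. A minimal sketch reusing the vec_access UDF defined earlier:

python
# the SHAP base value equals the model's output for an empty document
empty_df = spark.createDataFrame([("",)], ["text"])
display(model.transform(empty_df).select(vec_access("probability", lit(1)).alias("base_value")))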