This document provides comprehensive information about all SHAP plotting functions, their parameters, use cases, and best practices for visualizing model explanations.
SHAP provides diverse visualization tools for explaining model predictions at both individual and global levels. Each plot type serves specific purposes in understanding feature importance, interactions, and prediction mechanisms.
Purpose: Display explanations for individual predictions, showing how each feature moves the prediction from the baseline (expected value) toward the final prediction.
Function: shap.plots.waterfall(explanation, max_display=10, show=True)
Key Parameters:
- `explanation`: Single row from an Explanation object (not multiple samples)
- `max_display`: Number of features to show (default: 10); less impactful features collapse into a single "other features" term
- `show`: Whether to display the plot immediately

Visual Elements: bars build from the baseline E[f(X)] at the bottom to the final prediction f(x) at the top; red bars push the prediction higher, blue bars push it lower.
When to Use: explaining a single prediction in detail, e.g. for case review or debugging.
Important Notes: the function expects exactly one sample, so index the Explanation first (e.g. `shap_values[0]`).
Example:
```python
import shap

# Compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)

# Plot waterfall for first prediction
shap.plots.waterfall(shap_values[0])

# Show more features
shap.plots.waterfall(shap_values[0], max_display=20)
```
Purpose: Information-dense summary of how top features impact model output across the entire dataset, combining feature importance with value distributions.
Function: shap.plots.beeswarm(shap_values, max_display=10, order=Explanation.abs.mean(0), color=None, show=True)
Key Parameters:
- `shap_values`: Explanation object containing multiple samples
- `max_display`: Number of features to display (default: 10)
- `order`: How to rank features
  - `Explanation.abs.mean(0)`: Mean absolute SHAP values (default)
  - `Explanation.abs.max(0)`: Maximum absolute values (highlights outlier impacts)
- `color`: matplotlib colormap; defaults to red-blue scheme
- `show`: Whether to display the plot immediately

Visual Elements: each dot is one sample; its horizontal position is the SHAP value, and its color encodes the feature value (red = high, blue = low).
When to Use: getting a global overview of which features matter most and how their values push predictions up or down.
Practical Variations:
```python
import matplotlib.pyplot as plt

# Standard beeswarm plot
shap.plots.beeswarm(shap_values)

# Show more features
shap.plots.beeswarm(shap_values, max_display=20)

# Order by maximum absolute values (highlight outliers)
shap.plots.beeswarm(shap_values, order=shap_values.abs.max(0))

# Plot absolute SHAP values with fixed coloring
shap.plots.beeswarm(shap_values.abs, color="shap_red")

# Custom matplotlib colormap
shap.plots.beeswarm(shap_values, color=plt.cm.viridis)
```
Purpose: Display feature importance as mean absolute SHAP values, providing clean, simple visualizations of global feature impact.
Function: shap.plots.bar(shap_values, max_display=10, clustering=None, clustering_cutoff=0.5, show=True)
Key Parameters:
- `shap_values`: Explanation object (can be single instance, global, or cohorts)
- `max_display`: Maximum number of features/bars to show
- `clustering`: Optional hierarchical clustering object from `shap.utils.hclust`
- `clustering_cutoff`: Threshold for displaying clustering structure (0-1, default: 0.5)

Plot Types:
Shows overall feature importance across all samples. Importance calculated as mean absolute SHAP value.
```python
# Global feature importance
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)
shap.plots.bar(shap_values)
```
Displays SHAP values for a single instance with feature values shown in gray.
```python
# Single prediction explanation
shap.plots.bar(shap_values[0])
```
Compares feature importance across subgroups by passing a dictionary of Explanation objects.
```python
# Compare cohorts
cohorts = {
    "Group A": shap_values[mask_A],
    "Group B": shap_values[mask_B],
}
shap.plots.bar(cohorts)
```
Feature Clustering: Identifies redundant features using model-based clustering (more accurate than correlation-based methods).
```python
# Add feature clustering
clustering = shap.utils.hclust(X_train, y_train)
shap.plots.bar(shap_values, clustering=clustering)

# Adjust clustering display threshold
shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=0.3)
```
When to Use: quick global importance summaries, single-prediction breakdowns, cohort comparisons, and redundancy detection via clustering.
Purpose: Additive force visualization showing how features push prediction higher (red) or lower (blue) from baseline.
Function: shap.plots.force(base_value, shap_values, features, feature_names=None, out_names=None, link="identity", matplotlib=False, show=True)
Key Parameters:
- `base_value`: Expected value (baseline prediction)
- `shap_values`: SHAP values for sample(s)
- `features`: Feature values for sample(s)
- `feature_names`: Optional feature names
- `link`: Transform function ("identity" or "logit")
- `matplotlib`: Use matplotlib backend (default: interactive JavaScript)

Visual Elements: red segments push the prediction higher, blue segments push it lower; segment length reflects the magnitude of each feature's contribution.
Interactive Features (JavaScript mode): hovering reveals exact feature contributions; with multiple samples, the stacked plot can be reordered and explored interactively.
When to Use: compact explanations of single predictions, or an interactive overview of many predictions at once.
Example:
```python
# Single prediction force plot
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values[0],
    X_test.iloc[0],
    matplotlib=True,
)

# Multiple predictions (interactive JavaScript plot, scalar baseline)
shap.plots.force(
    shap_values.base_values[0],
    shap_values.values,
    X_test,
)
```
Purpose: Show relationship between feature values and their SHAP values, revealing how feature values impact predictions.
Function: shap.plots.scatter(shap_values, color=None, hist=True, alpha=1, show=True)
Key Parameters:
- `shap_values`: Explanation object; select a feature with a subscript (e.g. `shap_values[:, "Age"]`)
- `color`: Feature to use for coloring points (string name or Explanation object)
- `hist`: Show a light histogram of the feature's value distribution along the x-axis
- `alpha`: Point transparency (useful for dense plots)

Visual Elements: the x-axis shows the feature's value and the y-axis its SHAP value; vertical dispersion at a given x hints at interaction effects.
When to Use: understanding how a feature's value drives its impact on predictions, and detecting interactions via coloring.
Interaction Detection: Color points by another feature to reveal interactions.
```python
# Basic dependence plot
shap.plots.scatter(shap_values[:, "Age"])

# Color by another feature to show interactions
shap.plots.scatter(shap_values[:, "Age"], color=shap_values[:, "Education"])

# Multiple features in one plot
shap.plots.scatter(shap_values[:, ["Age", "Education", "Hours-per-week"]])

# Increase transparency for dense data
shap.plots.scatter(shap_values[:, "Age"], alpha=0.5)
```
Purpose: Visualize SHAP values for multiple samples simultaneously, showing feature impacts across instances.
Function: shap.plots.heatmap(shap_values, instance_order=None, feature_values=None, max_display=10, show=True)
Key Parameters:
- `shap_values`: Explanation object
- `instance_order`: How to order instances (can be an Explanation expression for custom ordering, e.g. `shap_values.sum(1)`)
- `feature_values`: Summary values used to order the features (default: mean absolute SHAP value)
- `max_display`: Maximum features to display

Visual Elements: instances run along the x-axis and features along the y-axis, with cell color encoding the SHAP value; a curve above the matrix traces the model output per instance.
When to Use: spotting population-level patterns and subgroups in how features impact many samples at once.
Example:
```python
# Basic heatmap
shap.plots.heatmap(shap_values)

# Order instances by total SHAP value (model output)
shap.plots.heatmap(shap_values, instance_order=shap_values.sum(1))

# Show a specific subset
shap.plots.heatmap(shap_values[:100])
```
Purpose: Similar to beeswarm plots but uses violin (kernel density) visualization instead of individual dots.
Function: shap.plots.violin(shap_values, features=None, feature_names=None, max_display=10, show=True)
When to Use: the same situations as beeswarm plots, when density shapes read better than individual dots (e.g. very large sample counts).
Example:
```python
shap.plots.violin(shap_values)
```
Purpose: Show prediction paths through cumulative SHAP values, particularly useful for multiclass classification.
Function: shap.plots.decision(base_value, shap_values, features, feature_names=None, feature_order="importance", highlight=None, link="identity", show=True)
Key Parameters:
- `base_value`: Expected value (baseline)
- `shap_values`: SHAP values for the samples
- `features`: Feature values
- `feature_order`: How to order features ("importance" or an explicit list)
- `highlight`: Indices of samples to highlight
- `link`: Transform function

When to Use: tracing how predictions accumulate feature by feature across many samples, especially for multiclass classification or for spotting outlier paths.
Example:
```python
# Decision plot for multiple predictions
shap.plots.decision(
    shap_values.base_values[0],  # scalar baseline
    shap_values.values,
    X_test,
    feature_names=X_test.columns.tolist(),
)

# Highlight specific instances
shap.plots.decision(
    shap_values.base_values[0],
    shap_values.values,
    X_test,
    highlight=[0, 5, 10],
)
```
For Individual Predictions: waterfall (detailed breakdown) or force (compact additive view).
For Global Understanding: beeswarm (importance plus value distributions) or bar (clean importance ranking).
For Feature Relationships: scatter, optionally colored by a second feature to expose interactions.
For Multiple Samples: heatmap (instance-by-feature overview) or decision (cumulative prediction paths).
For Cohort Comparison: bar with a dictionary of Explanation objects, or side-by-side beeswarm plots.
1. Start Global, Then Go Local: begin with beeswarm or bar plots to find influential features, then drill into individual predictions with waterfall or force plots.
2. Use Multiple Plot Types: each visualization reveals a different aspect of the model; combining them gives a fuller picture.
3. Adjust max_display: balance completeness against clutter; let minor features collapse into the "other features" term.
4. Color Meaningfully: in scatter plots, color by a second feature to surface interactions rather than for decoration.
5. Consider Audience: bar plots are easiest for non-technical stakeholders; beeswarm and scatter plots suit technical reviews.
6. Save High-Quality Figures:
```python
import matplotlib.pyplot as plt

# Create plot without displaying it
shap.plots.beeswarm(shap_values, show=False)

# Save with high DPI
plt.savefig('shap_plot.png', dpi=300, bbox_inches='tight')
plt.close()
```
7. Handle Large Datasets: subsample before plotting, e.g. `shap.plots.beeswarm(shap_values[:1000])`.

Pattern 1: Complete Model Explanation
```python
# 1. Global importance
shap.plots.beeswarm(shap_values)

# 2. Top feature relationships
for feature in top_features:
    shap.plots.scatter(shap_values[:, feature])

# 3. Example predictions
for i in interesting_indices:
    shap.plots.waterfall(shap_values[i])
```
Pattern 2: Model Comparison
```python
# Compute SHAP values for two models on the same data
shap_model1 = explainer1(X_test)
shap_model2 = explainer2(X_test)

# Compare feature importance side by side
shap.plots.bar({
    "Model 1": shap_model1,
    "Model 2": shap_model2,
})
```
Pattern 3: Subgroup Analysis
```python
# Define cohorts
male_mask = X_test['Sex'] == 'Male'
female_mask = X_test['Sex'] == 'Female'

# Compare cohorts in one bar plot
shap.plots.bar({
    "Male": shap_values[male_mask],
    "Female": shap_values[female_mask],
})

# Or separate beeswarm plots per cohort
shap.plots.beeswarm(shap_values[male_mask])
shap.plots.beeswarm(shap_values[female_mask])
```
Pattern 4: Debugging Predictions
```python
import numpy as np

# Identify misclassified samples
errors = (model.predict(X_test) != y_test)
error_indices = np.where(errors)[0]

# Explain the first few errors
for idx in error_indices[:5]:
    print(f"Sample {idx}:")
    shap.plots.waterfall(shap_values[idx])

# Explore key features
shap.plots.scatter(shap_values[:, "Key_Feature"])
```
Jupyter Notebooks: keep `show=True` (the default) for inline display.
Static Reports: render with `show=False` and save via `plt.savefig`; use `matplotlib=True` for force plots.
Web Applications: use the interactive JavaScript force plots, exported with `shap.save_html`.

Issue: Plots don't display. Call `plt.show()` explicitly if needed (e.g. in scripts outside notebooks).
Issue: Too many features cluttering the plot. Lower the `max_display` parameter or use feature clustering.
Issue: Colors reversed or confusing. Pass an explicit matplotlib colormap via the `color` parameter.
Issue: Slow plotting with large datasets. Subsample, e.g. `shap_values[:1000]`, for visualization.
Issue: Feature names missing. Explain a DataFrame so column names are picked up, or set `feature_names` on the Explanation object.