Back to Claude Scientific Skills

Seaborn Common Use Cases and Examples

scientific-skills/seaborn/references/examples.md

2.38.0 · 17.2 KB
Original Source

Seaborn Common Use Cases and Examples

This document provides practical examples for common data visualization scenarios using seaborn.

Exploratory Data Analysis

Quick Dataset Overview

python
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Read the table that this example explores
df = pd.read_csv('data.csv')

# Pairwise relationships for all numeric variables, colored by the target
pairplot_options = {
    'hue': 'target_variable',
    'corner': True,
    'diag_kind': 'kde',
}
sns.pairplot(df, **pairplot_options)

# y > 1 places the figure title slightly above the top of the grid
plt.suptitle('Dataset Overview', y=1.01)
plt.savefig('overview.png', bbox_inches='tight', dpi=300)

Distribution Exploration

python
# KDE of `measurement`: one panel per timepoint, one curve per condition
facet_layout = dict(col='timepoint', col_wrap=3, height=3, aspect=1.5)
g = sns.displot(
    data=df,
    x='measurement',
    hue='condition',
    kind='kde',
    fill=True,
    common_norm=False,  # normalize each condition's density independently
    **facet_layout,
)
g.set_titles('{col_name}')
g.set_axis_labels('Measurement Value', 'Density')

Correlation Analysis

python
# numpy is needed for the triangle mask below
import numpy as np

# Compute the correlation matrix over the numeric columns only
corr = df.select_dtypes(include='number').corr()

# Mask the upper triangle (diagonal included) so each pair is shown once
mask = np.triu(np.ones_like(corr, dtype=bool))

# Plot heatmap on an explicitly created axes
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    corr,
    mask=mask,
    annot=True,
    fmt='.2f',
    cmap='coolwarm',
    center=0,
    square=True,
    linewidths=1,
    cbar_kws={'shrink': 0.8},
    ax=ax,
)
ax.set_title('Correlation Matrix')
plt.tight_layout()

Scientific Publications

Multi-Panel Figure with Different Plot Types

python
# Publication look: 'ticks' style, 'paper' context, colorblind palette
sns.set_theme(style='ticks', context='paper', font_scale=1.1)
sns.set_palette('colorblind')

# 2 x 3 grid; each panel below occupies one or two cells
fig = plt.figure(figsize=(12, 8))
gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

# Formatting shared by the lettered panel titles
panel_title = {'loc': 'left', 'fontweight': 'bold'}

# Panel A: Time series (top row, first two columns)
ax1 = fig.add_subplot(gs[0, :2])
sns.lineplot(
    data=timeseries_df,
    x='time', y='expression',
    hue='gene', style='treatment',
    markers=True, dashes=False,
    ax=ax1,
)
ax1.set_title('A. Gene Expression Over Time', **panel_title)
ax1.set(xlabel='Time (hours)', ylabel='Expression Level (AU)')

# Panel B: Distribution comparison (top row, last column)
ax2 = fig.add_subplot(gs[0, 2])
sns.violinplot(
    data=expression_df,
    x='treatment', y='expression',
    inner='box',
    ax=ax2,
)
ax2.set_title('B. Expression Distribution', **panel_title)
ax2.set(xlabel='Treatment', ylabel='')

# Panel C: Correlation (bottom row, first column)
ax3 = fig.add_subplot(gs[1, 0])
gene_pair = dict(data=correlation_df, x='gene1', y='gene2', ax=ax3)
sns.scatterplot(hue='cell_type', alpha=0.6, **gene_pair)
# scatter=False: draw only the fitted line on top of the colored points
sns.regplot(scatter=False, color='black', **gene_pair)
ax3.set_title('C. Gene Correlation', **panel_title)
ax3.set(xlabel='Gene 1 Expression', ylabel='Gene 2 Expression')

# Panel D: Heatmap (bottom row, last two columns)
ax4 = fig.add_subplot(gs[1, 1:])
sns.heatmap(
    sample_matrix,
    cmap='RdBu_r', center=0,
    annot=True, fmt='.1f',
    cbar_kws={'label': 'Log2 Fold Change'},
    ax=ax4,
)
ax4.set_title('D. Treatment Effects', **panel_title)
ax4.set(xlabel='Sample', ylabel='Gene')

# Clean up, then write one vector and one raster copy
sns.despine()
for output_path in ('figure.pdf', 'figure.png'):
    plt.savefig(output_path, dpi=300, bbox_inches='tight')

Box Plot with Significance Annotations

python
import numpy as np
from scipy import stats

# Treatment levels in display order, shared by both layers
treatment_order = ['Control', 'Low', 'Medium', 'High']

# Create plot
fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(
    data=df,
    x='treatment',
    y='response',
    # seaborn 0.13 deprecates `palette` without `hue`; assigning the x
    # variable to hue and hiding the legend keeps the same coloring
    hue='treatment',
    order=treatment_order,
    hue_order=treatment_order,
    palette='Set2',
    legend=False,
    ax=ax
)

# Add individual points
sns.stripplot(
    data=df,
    x='treatment',
    y='response',
    order=treatment_order,
    color='black',
    alpha=0.3,
    size=3,
    ax=ax
)

# Helper for drawing significance brackets
def add_significance_bar(ax, x1, x2, y, h, text):
    """Draw a bracket from ``x1`` to ``x2`` on ``ax`` and label it.

    The bracket starts at height ``y`` and rises by ``h``; ``text`` is
    centered horizontally with its bottom edge at the bracket's top.
    """
    top = y + h
    bracket_x = [x1, x1, x2, x2]
    bracket_y = [y, top, top, y]
    ax.plot(bracket_x, bracket_y, 'k-', lw=1.5)

    midpoint = (x1 + x2) / 2
    ax.text(midpoint, top, text, ha='center', va='bottom')

# Size the bars relative to the data range so they sit sensibly above the
# boxes whatever the units of `response` are
y_max = df['response'].max()
y_range = y_max - df['response'].min()
step = 0.05 * y_range if y_range > 0 else 1.0

# NOTE: '***' and 'ns' are illustrative labels; derive them from an actual
# statistical test before publishing
add_significance_bar(ax, 0, 3, y_max + step, step / 2, '***')
add_significance_bar(ax, 0, 1, y_max + 3 * step, step / 2, 'ns')

ax.set_ylabel('Response (μM)')
ax.set_xlabel('Treatment Condition')
ax.set_title('Treatment Response Analysis')
sns.despine()

Time Series Analysis

Multiple Time Series with Confidence Bands

python
import matplotlib.dates as mdates

fig, ax = plt.subplots(figsize=(10, 6))

# Color by sensor, marker style by location, with a 95% CI band
sns.lineplot(
    data=timeseries_df,
    x='timestamp', y='value',
    hue='sensor', style='location',
    markers=True, dashes=False,
    errorbar=('ci', 95),
    ax=ax,
)

# Labels, title and a legend anchored outside the right edge of the axes
ax.set_xlabel('Date')
ax.set_ylabel('Measurement (units)')
ax.set_title('Sensor Measurements Over Time')
ax.legend(
    title='Sensor & Location',
    loc='upper left',
    bbox_to_anchor=(1.05, 1),
)

# Weekly major ticks rendered as ISO dates, rotated for readability
date_axis = ax.xaxis
date_axis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
date_axis.set_major_locator(mdates.DayLocator(interval=7))
plt.xticks(rotation=45, ha='right')

plt.tight_layout()

Faceted Time Series

python
# Line plots on a metric (rows) x location (columns) grid, one line per device
facet_options = {'sharex': True, 'sharey': False}
g = sns.relplot(
    data=long_timeseries,
    kind='line',
    x='date', y='measurement',
    hue='device',
    row='metric', col='location',
    height=3, aspect=2,
    errorbar='sd',
    facet_kws=facet_options,
)

# Facet titles and shared axis labels
g.set_axis_labels('Date', 'Value')
g.set_titles('{row_name} - {col_name}')

# Tilt the x tick labels on every facet
for facet_ax in g.axes.flat:
    facet_ax.tick_params(axis='x', rotation=45)

g.tight_layout()

Categorical Comparisons

Nested Categorical Variables

python
# Two side-by-side panels sharing the same variable mapping
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
bar_ax, dist_ax = axes
mapping = dict(data=df, x='category', y='value', hue='subcategory')

# Left panel: Grouped bar plot with capped 95% CI error bars
sns.barplot(errorbar=('ci', 95), capsize=0.1, ax=bar_ax, **mapping)
bar_ax.set_title('Mean Values with 95% CI')
bar_ax.set_ylabel('Value (units)')
bar_ax.legend(title='Subcategory')

# Right panel: translucent violins with the raw points drawn on top
sns.violinplot(inner=None, alpha=0.3, ax=dist_ax, **mapping)
sns.stripplot(dodge=True, size=3, alpha=0.6, ax=dist_ax, **mapping)
dist_ax.set_title('Distribution of Individual Values')
dist_ax.set_ylabel('')
# The left panel already carries the subcategory legend
dist_ax.get_legend().remove()

plt.tight_layout()

Point Plot for Trends

python
# Mean score per timepoint for each treatment, with 95% CI error bars
# (the three markers/linestyles are assigned to the treatment levels)
trend_ax = sns.pointplot(
    data=df,
    x='timepoint', y='score',
    hue='treatment',
    markers=['o', 's', '^'],
    linestyles=['-', '--', '-.'],
    dodge=0.3,
    capsize=0.1,
    errorbar=('ci', 95),
)

trend_ax.set_xlabel('Timepoint')
trend_ax.set_ylabel('Performance Score')
trend_ax.set_title('Treatment Effects Over Time')
trend_ax.legend(title='Treatment', loc='upper left', bbox_to_anchor=(1.05, 1))
sns.despine()
plt.tight_layout()

Regression and Relationships

Linear Regression with Facets

python
# One column per cell line; within each, a separate fit per treatment
point_style = {'alpha': 0.5, 's': 50}
g = sns.lmplot(
    data=df,
    x='predictor', y='response',
    hue='treatment', col='cell_line',
    palette='Set2',
    ci=95,
    scatter_kws=point_style,
    height=4, aspect=1.2,
)

g.set_titles('{col_name}')
g.set_axis_labels('Predictor Variable', 'Response Variable')
g.tight_layout()

Polynomial Regression

python
# Compare polynomial fits of increasing order, one panel per order
poly_orders = (1, 2, 3)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for panel, order in zip(axes, poly_orders):
    sns.regplot(
        data=df, x='x', y='y',
        order=order,
        ci=95,
        scatter_kws={'alpha': 0.5},
        line_kws={'color': 'red'},
        ax=panel,
    )
    panel.set_title(f'Order {order} Polynomial Fit')
    panel.set_xlabel('X Variable')
    panel.set_ylabel('Y Variable')

plt.tight_layout()

Residual Analysis

python
from scipy import stats as sp_stats

# 2 x 2 diagnostic grid for a simple linear fit of y on x
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fit_ax, resid_ax = axes[0, 0], axes[0, 1]
qq_ax, hist_ax = axes[1, 0], axes[1, 1]

# Main regression
sns.regplot(data=df, x='x', y='y', ax=fit_ax)
fit_ax.set_title('Regression Fit')

# Residual plot with a LOWESS trend and a dashed zero reference line
sns.residplot(
    data=df, x='x', y='y',
    lowess=True,
    scatter_kws={'alpha': 0.5},
    line_kws={'color': 'red', 'lw': 2},
    ax=resid_ax,
)
resid_ax.set_title('Residuals vs Fitted')
resid_ax.axhline(0, ls='--', color='gray')

# Q-Q plot (using scipy) of the residuals from a degree-1 least-squares fit
linear_coeffs = np.polyfit(df['x'], df['y'], 1)
residuals = df['y'] - np.polyval(linear_coeffs, df['x'])
sp_stats.probplot(residuals, dist="norm", plot=qq_ax)
qq_ax.set_title('Q-Q Plot')

# Histogram of residuals with a KDE overlay
sns.histplot(residuals, kde=True, ax=hist_ax)
hist_ax.set_title('Residual Distribution')
hist_ax.set_xlabel('Residuals')

plt.tight_layout()

Bivariate and Joint Distributions

Joint Plot with Multiple Representations

python
# Scatter with histogram + KDE marginals, colored by category.
# jointplot(kind='scatter') switches its marginals to kdeplot when `hue` is
# set, and kdeplot rejects histogram options such as `kde` and `bins`.
# Building the JointGrid directly keeps the histogram marginals.
g = sns.JointGrid(
    data=df,
    x='var1',
    y='var2',
    hue='category',
    height=8,
    ratio=4,
    space=0.1
)
g.plot_joint(sns.scatterplot, alpha=0.5, s=50)
g.plot_marginals(sns.histplot, kde=True, bins=30)

# Add reference lines
g.ax_joint.axline((0, 0), slope=1, color='r', ls='--', alpha=0.5, label='y=x')
g.ax_joint.legend()

g.set_axis_labels('Variable 1', 'Variable 2', fontsize=12)

KDE Contour Plot

python
fig, ax = plt.subplots(figsize=(8, 8))
xy_on_ax = dict(data=df, x='x', y='y', ax=ax)

# Bivariate KDE with filled contours: 10 levels, lowest 5% not drawn
sns.kdeplot(
    fill=True,
    levels=10,
    thresh=0.05,
    cmap='viridis',
    **xy_on_ax,
)

# Overlay scatter: white points with black edges on top of the contours
sns.scatterplot(
    color='white',
    edgecolor='black',
    s=50,
    alpha=0.6,
    **xy_on_ax,
)

ax.set_title('Bivariate Distribution')
ax.set_xlabel('X Variable')
ax.set_ylabel('Y Variable')

Hexbin with Marginals

python
# For large datasets: hexagonal bins in the center, histograms on the margins
hex_style = {'gridsize': 30, 'cmap': 'viridis'}
margin_style = {'bins': 50, 'color': 'skyblue'}

g = sns.jointplot(
    data=large_df,
    x='x', y='y',
    kind='hex',
    height=8, ratio=5, space=0.1,
    joint_kws=hex_style,
    marginal_kws=margin_style,
)

g.set_axis_labels('X Variable', 'Y Variable')

Matrix and Heatmap Visualizations

Hierarchical Clustering Heatmap

python
# Prepare data (samples x features)
data_matrix = df.set_index('sample_id')[feature_columns]

# Create color annotations.
# Row colors are indexed by sample_id, matching data_matrix's index.
row_colors = df.set_index('sample_id')['condition'].map({
    'control': '#1f77b4',
    'treatment': '#ff7f0e'
})

# Column colors must be indexed by the column labels: clustermap aligns a
# Series to the data by label, so a default 0..n-1 index would match nothing
# and the colors would be lost.
col_colors = pd.Series(
    ['#2ca02c' if 'gene' in col else '#d62728' for col in data_matrix.columns],
    index=data_matrix.columns,
)

# Plot
g = sns.clustermap(
    data_matrix,
    method='ward',
    metric='euclidean',
    z_score=0,  # Normalize rows
    cmap='RdBu_r',
    center=0,
    row_colors=row_colors,
    col_colors=col_colors,
    figsize=(12, 10),
    dendrogram_ratio=(0.1, 0.1),
    cbar_pos=(0.02, 0.8, 0.03, 0.15),
    linewidths=0.5
)

g.ax_heatmap.set_xlabel('Features')
g.ax_heatmap.set_ylabel('Samples')
plt.savefig('clustermap.png', dpi=300, bbox_inches='tight')

Annotated Heatmap with Custom Colorbar

python
# Pivot data for heatmap: long rows -> row_var x col_var matrix of `value`
pivot_data = df.pivot(index='row_var', columns='col_var', values='value')

# Color scale: spans the data's min..max, centered on the mean of column means
color_center = pivot_data.mean().mean()
color_min = pivot_data.min().min()
color_max = pivot_data.max().max()

colorbar_options = {
    'label': 'Value (units)',
    'orientation': 'vertical',
    'shrink': 0.8,
    'aspect': 20,
}

# Create heatmap
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    pivot_data,
    annot=True, fmt='.1f',
    cmap='RdYlGn',
    center=color_center, vmin=color_min, vmax=color_max,
    linewidths=0.5, linecolor='gray',
    cbar_kws=colorbar_options,
    ax=ax,
)

ax.set_title('Variable Relationships', fontsize=14, pad=20)
ax.set_xlabel('Column Variable', fontsize=12)
ax.set_ylabel('Row Variable', fontsize=12)

plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()

Statistical Comparisons

Before/After Comparison

python
# Reshape data for paired comparison: one row per (subject, timepoint)
df_paired = df.melt(
    id_vars='subject',
    value_vars=['before', 'after'],
    var_name='timepoint',
    value_name='measurement'
)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Left: Individual trajectories, one faint gray line per subject
for subject in df_paired['subject'].unique():
    subject_data = df_paired[df_paired['subject'] == subject]
    axes[0].plot(subject_data['timepoint'], subject_data['measurement'],
                 'o-', alpha=0.3, color='gray')

# Overlay the group mean with a 95% CI.
# `scale=` is deprecated in seaborn 0.13; marker and line size are set
# directly through Line2D properties instead.
sns.pointplot(
    data=df_paired,
    x='timepoint',
    y='measurement',
    color='red',
    markers='D',
    markersize=8,
    linewidth=4,
    errorbar=('ci', 95),
    capsize=0.2,
    ax=axes[0]
)
axes[0].set_title('Individual Changes')
axes[0].set_ylabel('Measurement')

# Right: Distribution comparison (violins with the raw points on top)
sns.violinplot(
    data=df_paired,
    x='timepoint',
    y='measurement',
    inner='box',
    ax=axes[1]
)
sns.swarmplot(
    data=df_paired,
    x='timepoint',
    y='measurement',
    color='black',
    alpha=0.5,
    size=3,
    ax=axes[1]
)
axes[1].set_title('Distribution Comparison')
axes[1].set_ylabel('')

plt.tight_layout()

Dose-Response Curve

python
# Create dose-response plot
fig, ax = plt.subplots(figsize=(8, 6))

# Dose levels in ascending order, computed once and shared by both layers
dose_order = sorted(dose_df['dose'].unique())

# Plot individual points
sns.stripplot(
    data=dose_df,
    x='dose',
    y='response',
    order=dose_order,
    color='gray',
    alpha=0.3,
    jitter=0.2,
    ax=ax
)

# Overlay mean with CI.
# `scale=` is deprecated in seaborn 0.13; marker and line size are set
# directly through Line2D properties instead.
sns.pointplot(
    data=dose_df,
    x='dose',
    y='response',
    order=dose_order,
    color='blue',
    markers='o',
    markersize=7,
    linewidth=3,
    errorbar=('ci', 95),
    capsize=0.1,
    ax=ax
)

# Fit sigmoid curve
from scipy.optimize import curve_fit

def sigmoid(x, bottom, top, ec50, hill):
    """Four-parameter logistic (Hill) dose-response curve.

    Args:
        x: Dose or doses; a scalar or any array-like of numbers.
        bottom: Response approached as the dose goes to zero (for hill > 0).
        top: Response approached as the dose grows large (for hill > 0).
        ec50: Dose at which the response is halfway between bottom and top.
        hill: Hill slope controlling the steepness of the transition.

    Returns:
        The response at each dose, as a numpy scalar or array matching ``x``.
    """
    # Coerce so that scalars and plain lists behave like arrays
    x = np.asarray(x, dtype=float)
    # A zero dose makes ec50 / x infinite, which is the intended limit, so
    # silence the divide-by-zero warning rather than fail
    with np.errstate(divide='ignore'):
        ratio = ec50 / x
    return bottom + (top - bottom) / (1 + ratio ** hill)

doses_numeric = dose_df['dose'].astype(float)
dose_levels = np.sort(doses_numeric.unique())

# Start the optimizer from data-driven guesses; curve_fit otherwise starts
# every parameter at 1, which can be far from a sensible dose-response fit
initial_guess = [
    dose_df['response'].min(),  # bottom
    dose_df['response'].max(),  # top
    np.median(dose_levels),     # ec50
    1.0,                        # hill
]
params, _ = curve_fit(
    sigmoid, doses_numeric, dose_df['response'],
    p0=initial_guess, maxfev=10000,
)

# Evaluate the fitted curve on a dense log-spaced grid
# (assumes every dose is > 0, since log10 is taken)
x_smooth = np.logspace(np.log10(dose_levels.min()),
                       np.log10(dose_levels.max()), 100)
y_smooth = sigmoid(x_smooth, *params)

# The strip/point plots place the dose levels at positions 0..n-1, so map
# the smooth doses onto that axis by interpolating in log-dose
x_positions = np.interp(np.log10(x_smooth),
                        np.log10(dose_levels),
                        np.arange(len(dose_levels)))
ax.plot(x_positions, y_smooth, 'r-', linewidth=2, label='Sigmoid Fit')

ax.set_xlabel('Dose')
ax.set_ylabel('Response')
ax.set_title('Dose-Response Analysis')
ax.legend()
sns.despine()

Custom Styling

Custom Color Palette from Hex Codes

python
# Five hex colors used as the categorical palette
custom_palette = [
    '#E64B35',
    '#4DBBD5',
    '#00A087',
    '#3C5488',
    '#F39B7F',
]

# Option 1: make it the default palette for subsequent plots
sns.set_palette(custom_palette)

# Option 2: pass it to a specific plot
sns.scatterplot(data=df, x='x', y='y', hue='category', palette=custom_palette)

Publication-Ready Theme

python
# matplotlib rc overrides applied on top of the seaborn theme
publication_rc = {
    # Output resolution and default file format
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.format': 'pdf',
    # Axes frame and labels
    'axes.linewidth': 1.0,
    'axes.labelweight': 'bold',
    # Tick marks
    'xtick.major.width': 1.0,
    'ytick.major.width': 1.0,
    'xtick.direction': 'out',
    'ytick.direction': 'out',
    'legend.frameon': False,
    'pdf.fonttype': 42,  # True Type fonts for PDFs
}

# Set comprehensive theme
sns.set_theme(
    style='ticks',
    context='paper',
    palette='colorblind',
    font='Arial',
    font_scale=1.1,
    rc=publication_rc,
)

Diverging Colormap Centered on Zero

python
# For data with meaningful zero point (e.g., log fold change)
from matplotlib.colors import TwoSlopeNorm

# Find data range from the matrix that is actually plotted
vmin, vmax = pivot_data.min().min(), pivot_data.max().max()
vcenter = 0

# TwoSlopeNorm requires vmin < vcenter < vmax. If the data lies on one side
# of zero, fall back to limits that are symmetric around the center.
if not vmin < vcenter < vmax:
    limit = max(abs(vmin), abs(vmax))
    vmin, vmax = -limit, limit

# Create norm
norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)

# Plot. The norm already pins zero to the middle of the colormap, so
# `center=` is not passed as well: heatmap would re-center the colormap a
# second time and zero would no longer land on the neutral color.
sns.heatmap(
    pivot_data,
    cmap='RdBu_r',
    norm=norm,
    annot=True,
    fmt='.2f'
)

Large Datasets

Downsampling Strategy

python
# For very large datasets, sample intelligently
def smart_sample(df, target_size=10000, category_col=None, random_state=None):
    """Downsample ``df`` to roughly ``target_size`` rows for plotting.

    Args:
        df: DataFrame to sample from.
        target_size: Desired number of rows. If ``df`` already has this many
            rows or fewer, it is returned unchanged.
        category_col: Optional column name. When given, each category gets an
            equal quota of ``target_size // n_categories`` rows (at least one,
            and never more than the category contains), so small categories
            stay visible. Rows whose category is missing are dropped.
        random_state: Optional seed passed to ``DataFrame.sample`` so the
            result is reproducible.

    Returns:
        The original ``df`` or a sampled DataFrame with the same columns.
    """
    if len(df) <= target_size:
        return df

    if not category_col:
        # Simple random sampling
        return df.sample(target_size, random_state=random_state)

    # Stratified sampling. The quota is computed once, and is at least 1 so
    # that a target smaller than the number of categories does not return an
    # empty frame.
    n_categories = df[category_col].nunique()
    per_category = max(1, target_size // max(n_categories, 1))

    # Iterating the groups (rather than groupby.apply) keeps the grouping
    # column in the result on every pandas version.
    samples = [
        group.sample(min(len(group), per_category), random_state=random_state)
        for _, group in df.groupby(category_col)
    ]
    if not samples:
        return df.iloc[:0]
    return pd.concat(samples)

# Draw from a category-stratified sample instead of the full table
df_sampled = smart_sample(
    large_df,
    category_col='category',
    target_size=5000,
)

sns.scatterplot(
    data=df_sampled,
    x='x', y='y',
    hue='category',
    alpha=0.5,
)

Hexbin for Dense Scatter Plots

python
# For millions of points: compare a raw scatter with hexagonal binning
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
scatter_ax, hexbin_ax = axes
x_values, y_values = df['x'], df['y']

# Regular scatter (slow): every point is drawn individually
scatter_ax.scatter(x_values, y_values, alpha=0.1, s=1)
scatter_ax.set_title('Scatter (all points)')

# Hexbin (fast): points are counted per hexagon; mincnt=1 hides empty cells
hb = hexbin_ax.hexbin(x_values, y_values, gridsize=50, cmap='viridis', mincnt=1)
hexbin_ax.set_title('Hexbin Aggregation')
plt.colorbar(hb, ax=hexbin_ax, label='Count')

plt.tight_layout()

Interactive Elements for Notebooks

Adjustable Parameters

python
from ipywidgets import interact, FloatSlider

@interact(bandwidth=FloatSlider(min=0.1, max=3.0, step=0.1, value=1.0))
def plot_kde(bandwidth):
    """Draw the per-category KDE of ``value`` using the slider's ``bandwidth``.

    ``bandwidth`` is passed to seaborn as ``bw_adjust``.
    """
    plt.figure(figsize=(10, 6))
    kde_options = {'bw_adjust': bandwidth, 'fill': True}
    sns.kdeplot(data=df, x='value', hue='category', **kde_options)
    plt.title(f'KDE with bandwidth adjustment = {bandwidth}')
    plt.show()

Dynamic Filtering

python
from ipywidgets import interact, SelectMultiple

# Distinct categories, in order of first appearance, offered in the picker
categories = df['category'].unique().tolist()
category_picker = SelectMultiple(options=categories, value=[categories[0]])


@interact(selected=category_picker)
def filtered_plot(selected):
    """Draw violins of ``value`` for only the categories in ``selected``."""
    keep_rows = df['category'].isin(selected)
    filtered_df = df[keep_rows]

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.violinplot(data=filtered_df, x='category', y='value', ax=ax)
    ax.set_title(f'Showing {len(selected)} categories')
    plt.show()