
SAELens API Reference

optional-skills/mlops/saelens/references/api.md

2026.6.5 (6.8 KB)


SAE Class

The core class representing a Sparse Autoencoder.

Loading Pre-trained SAEs

```python
from sae_lens import SAE

# From official releases
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda",
)

# From HuggingFace
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="username/repo-name",
    sae_id="path/to/sae",
    device="cuda",
)

# From local disk
sae = SAE.load_from_disk("/path/to/sae", device="cuda")
```

SAE Attributes

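| Attribute | Shape / type | Description |
|-----------|--------------|-------------|
| W_enc | [d_in, d_sae] | Encoder weights |
| W_dec | [d_sae, d_in] | Decoder weights |
| b_enc | [d_sae] | Encoder bias |
| b_dec | [d_in] | Decoder bias |
| cfg | SAEConfig | Configuration object |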
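The shapes in the table above determine how large an SAE is. The sketch below counts parameters from those four tensors. It assumes a standard SAE with no extra parameters; gated and JumpReLU variants add per-feature parameters, so their counts differ. `count_sae_params` is a hypothetical helper, not part of SAELens.

```python
def count_sae_params(d_in: int, d_sae: int) -> int:
    """Parameter count implied by the W_enc, W_dec, b_enc, b_dec shapes.

    Hypothetical helper: assumes a standard SAE with no extra parameters.
    """
    w_enc = d_in * d_sae   # W_enc: [d_in, d_sae]
    w_dec = d_sae * d_in   # W_dec: [d_sae, d_in]
    b_enc = d_sae          # b_enc: [d_sae]
    b_dec = d_in           # b_dec: [d_in]
    return w_enc + w_dec + b_enc + b_dec


# GPT-2 small residual stream (d_in=768) with a 32x expansion (d_sae=24576)
n = count_sae_params(768, 24576)
print(f"{n:,} parameters")  # 37,774,080 parameters
```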

Core Methods

encode()

```python
# Encode activations to sparse features
features = sae.encode(activations)
# Input: [batch, pos, d_in]
# Output: [batch, pos, d_sae]
```

decode()

```python
# Reconstruct activations from features
reconstructed = sae.decode(features)
# Input: [batch, pos, d_sae]
# Output: [batch, pos, d_in]
```

forward()

```python
# Full forward pass (encode + decode)
reconstructed = sae(activations)
# Returns reconstructed activations
```
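To build intuition for what `encode()`, `decode()`, and the forward pass compute, here is a minimal pure-Python sketch of a standard ReLU SAE applied to a single activation vector. It assumes the common convention of subtracting `b_dec` from the input before encoding. The real SAELens implementation works on batched tensors and also handles activation normalization and hooks, so treat this as illustrative only. The helper names are hypothetical.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major: Matrix[i][j]


def vec_mat(x: Vector, w: Matrix) -> Vector:
    """Row vector x (len n) times matrix w ([n, m]) -> vector of len m."""
    m = len(w[0])
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(m)]


def encode(x: Vector, W_enc: Matrix, b_enc: Vector, b_dec: Vector) -> Vector:
    # f = ReLU((x - b_dec) @ W_enc + b_enc), length d_sae
    centered = [xi - bi for xi, bi in zip(x, b_dec)]
    pre = [p + b for p, b in zip(vec_mat(centered, W_enc), b_enc)]
    return [max(0.0, p) for p in pre]


def decode(f: Vector, W_dec: Matrix, b_dec: Vector) -> Vector:
    # x_hat = f @ W_dec + b_dec, length d_in
    return [r + b for r, b in zip(vec_mat(f, W_dec), b_dec)]


def forward(x: Vector, W_enc: Matrix, b_enc: Vector,
            W_dec: Matrix, b_dec: Vector) -> Vector:
    # Encode then decode, like calling sae(activations)
    return decode(encode(x, W_enc, b_enc, b_dec), W_dec, b_dec)


# Toy SAE with d_in=2, d_sae=3
W_enc = [[1.0, -1.0, 0.5],
         [0.0, 1.0, 0.5]]   # [d_in, d_sae]
W_dec = [[1.0, 0.0],
         [0.0, 1.0],
         [0.5, 0.5]]        # [d_sae, d_in]
b_enc = [0.0, 0.0, -1.0]
b_dec = [0.0, 0.0]

x = [2.0, 1.0]
print(encode(x, W_enc, b_enc, b_dec))           # [2.0, 0.0, 0.5]
print(forward(x, W_enc, b_enc, W_dec, b_dec))   # [2.25, 0.25]
```

The second feature's pre-activation is negative, so ReLU zeroes it. That zeroing is the source of sparsity in the standard architecture.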

save_model()

```python
sae.save_model("/path/to/save")
```

SAEConfig

Configuration class for SAE architecture and training context.

Key Parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| d_in | int | Input dimension (model's d_model) |
| d_sae | int | SAE hidden dimension |
| architecture | str | One of "standard", "gated", "jumprelu", "topk" |
| activation_fn_str | str | Activation function name |
| model_name | str | Source model name |
| hook_name | str | Hook point in model |
| normalize_activations | str | Normalization method |
| dtype | str | Data type |
| device | str | Device |

Accessing Config

```python
print(sae.cfg.d_in)      # 768 for GPT-2 small
print(sae.cfg.d_sae)     # e.g., 24576 (32x expansion)
print(sae.cfg.hook_name) # e.g., "blocks.8.hook_resid_pre"
```

LanguageModelSAERunnerConfig

Comprehensive configuration for training SAEs.

Example Configuration

```python
from sae_lens import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    # Model and hook
    model_name="gpt2-small",
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,
    d_in=768,

    # SAE architecture
    architecture="standard",  # "standard", "gated", "jumprelu", "topk"
    d_sae=768 * 8,            # 8x expansion factor: d_sae = 6144
    activation_fn="relu",

    # Training hyperparameters
    lr=4e-4,
    l1_coefficient=8e-5,
    lp_norm=1.0,
    lr_scheduler_name="constant",
    lr_warm_up_steps=500,

    # Sparsity control
    l1_warm_up_steps=1000,
    use_ghost_grads=True,
    feature_sampling_window=1000,
    dead_feature_window=5000,
    dead_feature_threshold=1e-8,

    # Data
    dataset_path="monology/pile-uncopyrighted",
    streaming=True,
    context_size=128,

    # Batch sizes
    train_batch_size_tokens=4096,
    store_batch_size_prompts=16,
    n_batches_in_buffer=64,

    # Training duration
    training_tokens=100_000_000,

    # Logging
    log_to_wandb=True,
    wandb_project="sae-training",
    wandb_log_frequency=100,

    # Checkpointing
    checkpoint_path="checkpoints",
    n_checkpoints=5,

    # Hardware
    device="cuda",
    dtype="float32",
)
```
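The batch-size and duration settings in the example configuration combine into a few numbers worth checking before launching a run. The sketch below computes them under stated assumptions:

- one optimizer step consumes `train_batch_size_tokens` activations;
- the buffer holds `n_batches_in_buffer × store_batch_size_prompts × context_size` token activations;
- each value is stored in float32 (4 bytes).

`training_budget` is a hypothetical helper, and SAELens's internal bookkeeping may differ in detail.

```python
def training_budget(
    training_tokens: int,
    train_batch_size_tokens: int,
    store_batch_size_prompts: int,
    context_size: int,
    n_batches_in_buffer: int,
    d_in: int,
    bytes_per_value: int = 4,  # assumed float32 storage
) -> dict:
    # Assumption: one optimizer step per train_batch_size_tokens activations
    n_steps = training_tokens // train_batch_size_tokens
    # Tokens produced by one model forward pass over a batch of prompts
    tokens_per_store_batch = store_batch_size_prompts * context_size
    # Assumption: the buffer holds this many token activations at once
    buffer_tokens = n_batches_in_buffer * tokens_per_store_batch
    buffer_mib = buffer_tokens * d_in * bytes_per_value / 2**20
    return {
        "n_steps": n_steps,
        "tokens_per_store_batch": tokens_per_store_batch,
        "buffer_tokens": buffer_tokens,
        "buffer_mib": buffer_mib,
    }


budget = training_budget(
    training_tokens=100_000_000,
    train_batch_size_tokens=4096,
    store_batch_size_prompts=16,
    context_size=128,
    n_batches_in_buffer=64,
    d_in=768,
)
print(budget)
# {'n_steps': 24414, 'tokens_per_store_batch': 2048, 'buffer_tokens': 131072, 'buffer_mib': 384.0}
```

Doubling `n_batches_in_buffer` under these assumptions doubles the buffer memory. That memory cost is the usual trade-off against better shuffling across prompts.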

Key Parameters Explained

Architecture Parameters

| Parameter | Description |
|-----------|-------------|
| architecture | SAE type: "standard", "gated", "jumprelu", "topk" |
| d_sae | Hidden dimension (or use expansion_factor) |
| expansion_factor | Alternative to d_sae: d_sae = d_in × expansion_factor |
| activation_fn | "relu", "topk", etc. |
| activation_fn_kwargs | Dict for activation params (e.g., {"k": 50} for topk) |
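The table lists two ways to size the hidden layer. The hypothetical helper below spells out the relationship `d_sae = d_in × expansion_factor` and rejects inputs where the two settings disagree. It is an illustration, not SAELens's own validation logic.

```python
from typing import Optional


def resolve_d_sae(
    d_in: int,
    d_sae: Optional[int] = None,
    expansion_factor: Optional[int] = None,
) -> int:
    """Return the SAE hidden size from either d_sae or expansion_factor."""
    if d_sae is not None and expansion_factor is not None:
        # Both given: they must describe the same size
        if d_sae != d_in * expansion_factor:
            raise ValueError("d_sae and expansion_factor disagree")
        return d_sae
    if d_sae is not None:
        return d_sae
    if expansion_factor is not None:
        return d_in * expansion_factor
    raise ValueError("set d_sae or expansion_factor")


print(resolve_d_sae(768, expansion_factor=8))   # 6144, as in the example config
print(resolve_d_sae(768, expansion_factor=32))  # 24576, as in gpt2-small-res-jb
```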

Sparsity Parameters

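| Parameter | Description |
|-----------|-------------|
| l1_coefficient | L1 penalty weight (higher = sparser) |
| l1_warm_up_steps | Steps over which the L1 penalty ramps up to l1_coefficient |
| use_ghost_grads | Use ghost gradients to revive dead features |
| dead_feature_threshold | Firing frequency below which a feature counts as dead |
| dead_feature_window | Steps a feature can go without firing before it is treated as dead |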
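To make these sparsity settings concrete, the sketch below shows an L1 warm-up and a dead-feature check. It assumes three simplifications:

- the L1 coefficient ramps linearly from 0 to `l1_coefficient` over `l1_warm_up_steps`;
- the penalty is a plain sum of absolute activations;
- a feature is flagged dead once it has gone more than `dead_feature_window` steps without firing.

The function names are hypothetical, and SAELens's exact schedule and loss weighting may differ.

```python
from typing import List


def l1_coefficient_at(step: int, l1_coefficient: float, l1_warm_up_steps: int) -> float:
    # Assumed linear ramp from 0 to l1_coefficient, then constant
    if l1_warm_up_steps <= 0:
        return l1_coefficient
    return l1_coefficient * min(1.0, step / l1_warm_up_steps)


def l1_penalty(features: List[float], coefficient: float) -> float:
    # Plain L1: sum of absolute feature activations times the current coefficient
    return coefficient * sum(abs(f) for f in features)


def dead_feature_mask(steps_since_fired: List[int], dead_feature_window: int) -> List[bool]:
    # A feature counts as dead once it has been silent longer than the window
    return [s > dead_feature_window for s in steps_since_fired]


print(l1_coefficient_at(500, 8e-5, 1000))   # 4e-05 (half strength mid warm-up)
print(l1_coefficient_at(2000, 8e-5, 1000))  # 8e-05 (full strength)
print(l1_penalty([0.5, 0.0, 2.0], coefficient=8e-5))
print(dead_feature_mask([10, 6000, 5000], dead_feature_window=5000))  # [False, True, False]
```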

Learning Rate Parameters

| Parameter | Description |
|-----------|-------------|
| lr | Base learning rate |
| lr_scheduler_name | "constant", "cosineannealing", etc. |
| lr_warm_up_steps | LR warmup steps |
| lr_decay_steps | Steps for LR decay |
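The sketch below shows how `lr` and `lr_warm_up_steps` could combine with a scheduler. It assumes a linear warm-up from 0, followed by either a constant rate or cosine annealing to zero over the remaining steps. SAELens builds its schedulers on PyTorch, so details such as the warm-up start value, the final LR, and how `lr_decay_steps` applies may differ. `lr_at` is a hypothetical name.

```python
import math


def lr_at(step: int, lr: float, warm_up_steps: int, total_steps: int,
          scheduler: str = "constant") -> float:
    # Assumed linear warm-up from 0 to lr
    if step < warm_up_steps:
        return lr * step / warm_up_steps
    if scheduler == "constant":
        return lr
    if scheduler == "cosineannealing":
        # Assumed cosine decay from lr to 0 over the post-warm-up steps
        decay_steps = max(1, total_steps - warm_up_steps)
        progress = min(1.0, (step - warm_up_steps) / decay_steps)
        return lr * 0.5 * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f"unknown scheduler: {scheduler}")


for step in (0, 250, 500, 12_000, 24_414):
    print(step, lr_at(step, lr=4e-4, warm_up_steps=500, total_steps=24_414,
                      scheduler="cosineannealing"))
```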

SAETrainingRunner

Main class for executing training.

Basic Training

```python
from sae_lens import SAETrainingRunner, LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(...)
runner = SAETrainingRunner(cfg)
sae = runner.run()
```

Accessing Training Metrics

```python
# During training, metrics logged to W&B include:
# - l0: average number of active (nonzero) features per token
# - ce_loss_score: fraction of cross-entropy loss recovered with the SAE spliced in
# - mse_loss: reconstruction loss
# - l1_loss: sparsity loss
# - dead_features: count of dead features
```
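The sketch below computes `l0`, an MSE reconstruction loss, and a CE loss score from their usual definitions. The `ce_loss_score` formula assumes the common convention: the score is the share of the gap between zero-ablating the hook and the clean model that the SAE recovers (1.0 means no loss increase, 0.0 means as bad as zero ablation). The function names are hypothetical, and SAELens may average or normalize these quantities differently.

```python
from typing import List


def l0(feature_batch: List[List[float]]) -> float:
    # Average number of nonzero features per token
    return sum(sum(1 for f in row if f != 0) for row in feature_batch) / len(feature_batch)


def mse(x: List[float], x_hat: List[float]) -> float:
    # Mean squared reconstruction error for one activation vector
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def ce_loss_score(ce_clean: float, ce_with_sae: float, ce_zero_ablation: float) -> float:
    # Fraction of the zero-ablation loss increase that the SAE recovers
    return (ce_zero_ablation - ce_with_sae) / (ce_zero_ablation - ce_clean)


print(l0([[1.0, 0.0, 2.0], [0.0, 0.0, 3.0]]))  # 1.5
print(mse([2.0, 1.0], [2.25, 0.25]))           # 0.3125
print(ce_loss_score(ce_clean=3.0, ce_with_sae=3.5, ce_zero_ablation=8.0))  # 0.9
```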

ActivationsStore

Manages activation collection and batching.

Basic Usage

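```python
from sae_lens import ActivationsStore

# `model` is a HookedTransformer (or HookedSAETransformer); `sae` is a loaded SAE
store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device="cuda",
)

# Get a batch of token IDs (not activations): [store_batch_size_prompts, context_size]
batch_tokens = store.get_batch_tokens()

# Get the next shuffled batch of activations from the buffer
activations = store.next_batch()
```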
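The activation store keeps a buffer of activations gathered from many prompts. It then serves shuffled training batches, so that consecutive batches are not dominated by tokens from the same context. The toy class below sketches that idea in pure Python. It is not the SAELens implementation, and `ToyActivationBuffer` and its method are hypothetical.

```python
import random
from typing import Iterator, List, Sequence


class ToyActivationBuffer:
    """Pool activations from several store batches, shuffle, serve train batches."""

    def __init__(self, train_batch_size_tokens: int, seed: int = 0) -> None:
        self.train_batch_size_tokens = train_batch_size_tokens
        self.rng = random.Random(seed)

    def batches(self, store_batches: Sequence[List[list]]) -> Iterator[List[list]]:
        # Flatten [n_store_batches][tokens][d_in] into one pool of token activations
        pool = [act for batch in store_batches for act in batch]
        # Shuffle so each training batch mixes tokens from different prompts
        self.rng.shuffle(pool)
        # Serve full batches only; in this toy, a partial tail batch is dropped
        size = self.train_batch_size_tokens
        for start in range(0, len(pool) - size + 1, size):
            yield pool[start:start + size]


# Two store batches of 4 fake 2-dim activations each, served 3 at a time
store_batches = [[[float(i), float(b)] for i in range(4)] for b in range(2)]
buffer = ToyActivationBuffer(train_batch_size_tokens=3)
for batch in buffer.batches(store_batches):
    print(len(batch), batch)
```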

HookedSAETransformer

Integration of SAEs with TransformerLens models.

Basic Usage

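```python
from sae_lens import HookedSAETransformer

# Load the base model; SAEs are attached separately
model = HookedSAETransformer.from_pretrained("gpt2-small")
tokens = model.to_tokens("Sparse autoencoders find features.")

# Run with the SAE spliced in for this call only
output = model.run_with_saes(tokens, saes=[sae])

# Same, but also cache activations (including the SAE's own hook points)
output, cache = model.run_with_cache_with_saes(tokens, saes=[sae])

# Or attach the SAE persistently, run normally, then detach it
model.add_sae(sae)
output = model(tokens)
model.reset_saes()
```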
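Conceptually, running with an SAE replaces the activation at the SAE's hook point with the SAE's reconstruction before the rest of the model runs. `run_with_saes` does this for a single call, while `add_sae` keeps the SAE attached until it is removed. The toy pipeline below illustrates that splice with plain functions. `ToyModel`, its hook names, and `lossy_sae` are hypothetical stand-ins, not TransformerLens or SAELens internals.

```python
from typing import Callable, Dict, List, Optional

Vector = List[float]
Layer = Callable[[Vector], Vector]


class ToyModel:
    """Named layers run in order; an SAE can replace one layer's output."""

    def __init__(self, layers: Dict[str, Layer]) -> None:
        self.layers = layers                  # dict order is execution order
        self.attached: Dict[str, Layer] = {}  # persistent SAEs (like add_sae)

    def add_sae(self, hook_name: str, sae: Layer) -> None:
        self.attached[hook_name] = sae

    def reset_saes(self) -> None:
        self.attached.clear()

    def run(self, x: Vector, saes: Optional[Dict[str, Layer]] = None) -> Vector:
        # Per-call SAEs (like run_with_saes) are added on top of persistent ones
        active = {**self.attached, **(saes or {})}
        for name, layer in self.layers.items():
            x = layer(x)
            if name in active:
                x = active[name](x)  # splice: activation -> SAE reconstruction
        return x


def lossy_sae(v: Vector) -> Vector:
    # Stand-in for sae(x): an imperfect reconstruction of its input
    return [float(round(a)) for a in v]


model = ToyModel({
    "blocks.0": lambda v: [2 * a for a in v],
    "blocks.1": lambda v: [a + 1 for a in v],
})
x = [0.3, 1.6]
print(model.run(x))                                # clean run
print(model.run(x, saes={"blocks.0": lossy_sae}))  # [2.0, 4.0]: SAE for this call only
model.add_sae("blocks.0", lossy_sae)
print(model.run(x))                                # [2.0, 4.0]: SAE now persistent
model.reset_saes()
```

Comparing the clean and spliced outputs is the toy analogue of measuring how much downstream behavior changes when the SAE reconstruction replaces the real activation.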

SAE Architectures

Standard (ReLU + L1)

```python
cfg = LanguageModelSAERunnerConfig(
    architecture="standard",
    activation_fn="relu",
    l1_coefficient=8e-5,
)
```

Gated

```python
cfg = LanguageModelSAERunnerConfig(
    architecture="gated",
)
```
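A gated SAE separates the decision of which features fire from the estimate of how strongly they fire. The sketch below shows that split for one token. It assumes the usual gated formulation, in which the gate and magnitude paths share encoder directions and the magnitude path rescales them per feature by `exp(r_mag)`. The inputs are precomputed projections, and the names are hypothetical rather than SAELens attributes.

```python
import math
from typing import List


def gated_activations(proj: List[float], b_gate: List[float],
                      r_mag: List[float], b_mag: List[float]) -> List[float]:
    """Gated SAE features for one token.

    proj[j] is the shared encoder projection ((x - b_dec) @ W_gate)[j].
    """
    feats = []
    for p, bg, r, bm in zip(proj, b_gate, r_mag, b_mag):
        gate_pre = p + bg                # decides whether feature j fires
        mag_pre = math.exp(r) * p + bm   # estimates how strongly it fires
        feats.append(max(0.0, mag_pre) if gate_pre > 0 else 0.0)
    return feats


print(gated_activations(
    proj=[1.0, 1.0, -0.5],
    b_gate=[0.0, -2.0, 0.0],
    r_mag=[0.0, 0.0, 0.0],
    b_mag=[0.5, 0.5, 0.5],
))  # [1.5, 0.0, 0.0]
```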

TopK

```python
cfg = LanguageModelSAERunnerConfig(
    architecture="topk",
    activation_fn="topk",
    activation_fn_kwargs={"k": 50},  # At most 50 active features per token
)
```
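With TopK, sparsity is fixed by `k` rather than tuned through an L1 coefficient. The sketch below shows the activation for one token. It assumes a ReLU is applied to the kept values, which is why at most `k` features (not exactly `k`) end up nonzero, and that ties are broken by sort order. `topk_activations` is a hypothetical name.

```python
from typing import List


def topk_activations(pre: List[float], k: int) -> List[float]:
    """Keep the k largest pre-activations (then ReLU); zero out the rest."""
    top = sorted(range(len(pre)), key=lambda j: pre[j], reverse=True)[:k]
    keep = set(top)
    return [max(0.0, p) if j in keep else 0.0 for j, p in enumerate(pre)]


pre = [0.2, -1.0, 3.0, 0.7, 1.5]
print(topk_activations(pre, k=2))  # [0.0, 0.0, 3.0, 0.0, 1.5]
```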

JumpReLU (used by Gemma Scope)

```python
cfg = LanguageModelSAERunnerConfig(
    architecture="jumprelu",
)
```
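JumpReLU replaces ReLU's fixed cutoff at zero with a learned, per-feature threshold. Pre-activations above the threshold pass through unchanged, and everything else becomes zero. The sketch below shows the forward computation only, assuming positive thresholds. Learning the thresholds needs a straight-through-style gradient estimator, which is not shown here. `jumprelu` is a hypothetical name.

```python
from typing import List


def jumprelu(pre: List[float], threshold: List[float]) -> List[float]:
    """JumpReLU: pass z through unchanged if z > theta, else output 0."""
    return [z if z > t else 0.0 for z, t in zip(pre, threshold)]


pre = [0.2, 0.9, -0.3, 1.4]
theta = [0.5, 0.5, 0.1, 1.5]  # one learned threshold per feature (assumed positive)
print(jumprelu(pre, theta))   # [0.0, 0.9, 0.0, 0.0]
```

Unlike ReLU, a feature that clears its threshold keeps its full value (0.9 here, not 0.9 minus 0.5). Small, noisy activations are suppressed without shrinking the large ones.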

Utility Functions

Upload to HuggingFace

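```python
from sae_lens import upload_saes_to_huggingface

# Authenticate first (e.g. `huggingface-cli login`).
# Keys are the SAE IDs under which each SAE is stored in the repo.
upload_saes_to_huggingface(
    {"blocks.8.hook_resid_pre": sae},
    hf_repo_id="username/my-saes",
)
```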

Neuronpedia Integration

```python
# Features can be viewed on Neuronpedia
# URL format: neuronpedia.org/{model}/{layer}-{sae_type}/{feature_id}
# Example: neuronpedia.org/gpt2-small/8-res-jb/1234
```
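To link features programmatically, the URL format can be wrapped in a small helper. `neuronpedia_url` is hypothetical and the `https://` scheme is an addition; the path follows the `{model}/{layer}-{sae_type}/{feature_id}` format given in the comment. Some SAE sets on Neuronpedia may use IDs that do not follow this pattern.

```python
def neuronpedia_url(model: str, layer: int, sae_type: str, feature_id: int) -> str:
    """Build a Neuronpedia feature URL from the documented format (hypothetical helper)."""
    return f"https://neuronpedia.org/{model}/{layer}-{sae_type}/{feature_id}"


print(neuronpedia_url("gpt2-small", 8, "res-jb", 1234))
# https://neuronpedia.org/gpt2-small/8-res-jb/1234
```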