optional-skills/mlops/saelens/references/api.md
`SAE` is the core class representing a sparse autoencoder.
from sae_lens import SAE
# From official releases
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)
# From HuggingFace
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="username/repo-name",
sae_id="path/to/sae",
device="cuda"
)
# From local disk
sae = SAE.load_from_disk("/path/to/sae", device="cuda")
| Attribute | Shape | Description |
|---|---|---|
| `W_enc` | `[d_in, d_sae]` | Encoder weights |
| `W_dec` | `[d_sae, d_in]` | Decoder weights |
| `b_enc` | `[d_sae]` | Encoder bias |
| `b_dec` | `[d_in]` | Decoder bias |
| `cfg` | `SAEConfig` | Configuration object |
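As a quick sanity check on the shapes in the table above, the sketch below counts the parameters those four tensors imply. The helper name `sae_parameter_count` is hypothetical and not part of SAELens; it also assumes only the listed tensors exist, whereas gated or JumpReLU variants may carry additional parameters.

```python
def sae_parameter_count(d_in: int, d_sae: int) -> int:
    """Count the parameters implied by the attribute shapes (hypothetical helper)."""
    w_enc = d_in * d_sae  # W_enc: [d_in, d_sae]
    w_dec = d_sae * d_in  # W_dec: [d_sae, d_in]
    b_enc = d_sae         # b_enc: [d_sae]
    b_dec = d_in          # b_dec: [d_in]
    return w_enc + w_dec + b_enc + b_dec


# GPT-2 small residual stream (d_in=768) with a 32x dictionary (d_sae=24576)
n_params = sae_parameter_count(d_in=768, d_sae=24576)
print(f"{n_params:,} parameters")
```

Because the two weight matrices dominate, the count grows roughly as `2 * d_in * d_sae`, which is worth keeping in mind when choosing an expansion factor.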
# Encode activations to sparse features
features = sae.encode(activations)
# Input: [batch, pos, d_in]
# Output: [batch, pos, d_sae]
# Reconstruct activations from features
reconstructed = sae.decode(features)
# Input: [batch, pos, d_sae]
# Output: [batch, pos, d_in]
# Full forward pass (encode + decode)
reconstructed = sae(activations)
# Returns reconstructed activations
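To make the encode/decode shapes concrete, here is a minimal pure-Python sketch of the arithmetic for a single activation vector. It assumes a "standard" ReLU architecture in which `b_dec` is subtracted from the input before encoding; that detail, and the helper names, are assumptions for illustration and may not match every released SAE.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def vec_mat(x: Vector, W: Matrix) -> Vector:
    """Multiply a row vector x [n] by a matrix W [n, m], giving a vector [m]."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def encode(x: Vector, W_enc: Matrix, b_enc: Vector, b_dec: Vector) -> Vector:
    """[d_in] -> [d_sae]: centre the input, project, add bias, apply ReLU."""
    centered = [xi - bi for xi, bi in zip(x, b_dec)]  # assumption: b_dec is subtracted
    pre_acts = [h + b for h, b in zip(vec_mat(centered, W_enc), b_enc)]
    return [max(0.0, p) for p in pre_acts]


def decode(features: Vector, W_dec: Matrix, b_dec: Vector) -> Vector:
    """[d_sae] -> [d_in]: project features back and add the decoder bias."""
    return [r + b for r, b in zip(vec_mat(features, W_dec), b_dec)]


# Toy sizes: d_in = 2, d_sae = 3
W_enc = [[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]
b_enc = [0.0, 0.0, 0.0]
W_dec = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
b_dec = [0.0, 0.0]

x = [2.0, -1.0]
features = encode(x, W_enc, b_enc, b_dec)       # sparse: only one feature fires
reconstructed = decode(features, W_dec, b_dec)  # lossy for this toy dictionary
print(features, reconstructed)
```

The real methods operate on `[batch, pos, d_in]` tensors; the same arithmetic is applied independently at every batch and position index.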
sae.save_model("/path/to/save")
`SAEConfig` is the configuration class describing the SAE architecture and the context it was trained in.
| Parameter | Type | Description |
|---|---|---|
| `d_in` | int | Input dimension (model's `d_model`) |
| `d_sae` | int | SAE hidden dimension |
| `architecture` | str | One of "standard", "gated", "jumprelu", "topk" |
| `activation_fn_str` | str | Activation function name |
| `model_name` | str | Source model name |
| `hook_name` | str | Hook point in model |
| `normalize_activations` | str | Normalization method |
| `dtype` | str | Data type |
| `device` | str | Device |
print(sae.cfg.d_in) # 768 for GPT-2 small
print(sae.cfg.d_sae) # e.g., 24576 (32x expansion)
print(sae.cfg.hook_name) # e.g., "blocks.8.hook_resid_pre"
`LanguageModelSAERunnerConfig` is the comprehensive configuration object for training SAEs.
from sae_lens import LanguageModelSAERunnerConfig
cfg = LanguageModelSAERunnerConfig(
# Model and hook
model_name="gpt2-small",
hook_name="blocks.8.hook_resid_pre",
hook_layer=8,
d_in=768,
# SAE architecture
architecture="standard", # "standard", "gated", "jumprelu", "topk"
d_sae=768 * 8, # d_in times an 8x expansion factor
activation_fn="relu",
# Training hyperparameters
lr=4e-4,
l1_coefficient=8e-5,
lp_norm=1.0,
lr_scheduler_name="constant",
lr_warm_up_steps=500,
# Sparsity control
l1_warm_up_steps=1000,
use_ghost_grads=True,
feature_sampling_window=1000,
dead_feature_window=5000,
dead_feature_threshold=1e-8,
# Data
dataset_path="monology/pile-uncopyrighted",
streaming=True,
context_size=128,
# Batch sizes
train_batch_size_tokens=4096,
store_batch_size_prompts=16,
n_batches_in_buffer=64,
# Training duration
training_tokens=100_000_000,
# Logging
log_to_wandb=True,
wandb_project="sae-training",
wandb_log_frequency=100,
# Checkpointing
checkpoint_path="checkpoints",
n_checkpoints=5,
# Hardware
device="cuda",
dtype="float32",
)
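The batch-size and duration settings above interact, so it can help to derive the implied quantities before launching a run. The sketch below is a hypothetical helper, not an SAELens function; it assumes the step count is the floor of `training_tokens / train_batch_size_tokens`, that the buffer holds `n_batches_in_buffer * store_batch_size_prompts * context_size` token activations, and that checkpoints are spaced evenly.

```python
def derived_training_quantities(
    training_tokens: int,
    train_batch_size_tokens: int,
    store_batch_size_prompts: int,
    n_batches_in_buffer: int,
    context_size: int,
    n_checkpoints: int,
) -> dict:
    """Derive step and buffer sizes from the config values (hypothetical helper)."""
    # Assumption: one optimizer step consumes one training batch of tokens
    total_steps = training_tokens // train_batch_size_tokens
    # Assumption: each buffered batch holds store_batch_size_prompts full contexts
    buffer_tokens = n_batches_in_buffer * store_batch_size_prompts * context_size
    # Assumption: checkpoints are evenly spaced across the run
    steps_per_checkpoint = total_steps // n_checkpoints if n_checkpoints > 0 else None
    return {
        "total_steps": total_steps,
        "buffer_tokens": buffer_tokens,
        "steps_per_checkpoint": steps_per_checkpoint,
    }


derived = derived_training_quantities(
    training_tokens=100_000_000,
    train_batch_size_tokens=4096,
    store_batch_size_prompts=16,
    n_batches_in_buffer=64,
    context_size=128,
    n_checkpoints=5,
)
print(derived)
```

With the values from the example config, the warm-up windows (`lr_warm_up_steps=500`, `l1_warm_up_steps=1000`) cover only a small fraction of the run.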
| Parameter | Description |
|---|---|
| `architecture` | SAE type: "standard", "gated", "jumprelu", "topk" |
| `d_sae` | Hidden dimension (or use `expansion_factor`) |
| `expansion_factor` | Alternative to `d_sae`: `d_sae = d_in × expansion_factor` |
| `activation_fn` | "relu", "topk", etc. |
| `activation_fn_kwargs` | Dict of activation parameters (e.g., `{"k": 50}` for topk) |
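The relationship between `d_sae` and `expansion_factor` in the table above can be sketched as follows. `resolve_d_sae` is a hypothetical helper; in particular, raising an error when both or neither value is given is an assumption made here for clarity, not documented library behavior.

```python
from typing import Optional


def resolve_d_sae(
    d_in: int,
    d_sae: Optional[int] = None,
    expansion_factor: Optional[int] = None,
) -> int:
    """Return the hidden dimension from either d_sae or expansion_factor."""
    if (d_sae is None) == (expansion_factor is None):
        # Assumption: exactly one of the two should be supplied
        raise ValueError("Provide exactly one of d_sae or expansion_factor")
    if d_sae is not None:
        return d_sae
    return d_in * expansion_factor


print(resolve_d_sae(d_in=768, expansion_factor=8))
```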

| Parameter | Description |
|---|---|
| `l1_coefficient` | L1 penalty weight (higher = sparser) |
| `l1_warm_up_steps` | Steps over which the L1 penalty is ramped up |
| `use_ghost_grads` | Apply gradients to dead features |
| `dead_feature_threshold` | Activation threshold below which a feature counts as "dead" |
| `dead_feature_window` | Steps over which dead features are checked |
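One way to read `dead_feature_threshold` is as a cutoff on how often a feature fires over the tokens seen in a window. The sketch below illustrates that reading; the function name and the exact comparison are assumptions, and the library's own bookkeeping may differ.

```python
from typing import List


def find_dead_features(
    fire_counts: List[int],
    tokens_seen: int,
    dead_feature_threshold: float = 1e-8,
) -> List[int]:
    """Return indices of features whose firing frequency is below the threshold."""
    if tokens_seen <= 0:
        raise ValueError("tokens_seen must be positive")
    return [
        index
        for index, count in enumerate(fire_counts)
        if count / tokens_seen < dead_feature_threshold
    ]


# Four features observed over one training batch of 4096 tokens
dead = find_dead_features(fire_counts=[120, 0, 3, 0], tokens_seen=4096)
print(dead)
```

With a threshold as small as `1e-8`, a feature is effectively flagged only if it never fired during the window.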

| Parameter | Description |
|---|---|
| `lr` | Base learning rate |
| `lr_scheduler_name` | "constant", "cosineannealing", etc. |
| `lr_warm_up_steps` | Learning-rate warm-up steps |
| `lr_decay_steps` | Learning-rate decay steps |
`SAETrainingRunner` is the main class for executing training.
from sae_lens import SAETrainingRunner, LanguageModelSAERunnerConfig
cfg = LanguageModelSAERunnerConfig(...)
runner = SAETrainingRunner(cfg)
sae = runner.run()
# During training, metrics logged to W&B include:
# - l0: Average active features
# - ce_loss_score: Cross-entropy recovery
# - mse_loss: Reconstruction loss
# - l1_loss: Sparsity loss
# - dead_features: Count of dead features
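For intuition about the logged metrics, the sketch below computes `l0`, an L1 term, and a mean squared error from small nested lists. The normalization choices (per-token averages for `l0` and L1, per-element average for MSE) are assumptions; the values logged by the library may be scaled differently.

```python
from typing import Dict, List

Rows = List[List[float]]


def sparsity_and_reconstruction_metrics(x: Rows, x_hat: Rows, features: Rows) -> Dict[str, float]:
    """Compute illustrative l0, l1 and mse values over a batch of tokens."""
    n_tokens = len(features)
    n_elements = sum(len(row) for row in x)
    # l0: average number of non-zero features per token
    l0 = sum(sum(1 for f in row if f != 0.0) for row in features) / n_tokens
    # l1: average summed absolute feature activation per token
    l1 = sum(sum(abs(f) for f in row) for row in features) / n_tokens
    # mse: squared reconstruction error averaged over every element
    squared_error = sum(
        (a - b) ** 2 for row, row_hat in zip(x, x_hat) for a, b in zip(row, row_hat)
    )
    return {"l0": l0, "l1_loss": l1, "mse_loss": squared_error / n_elements}


metrics = sparsity_and_reconstruction_metrics(
    x=[[1.0, 2.0], [0.0, 0.0]],
    x_hat=[[1.0, 1.0], [0.0, 2.0]],
    features=[[0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
)
print(metrics)
```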
`ActivationsStore` manages activation collection and batching.
from sae_lens import ActivationsStore
store = ActivationsStore.from_sae(
model=model,
sae=sae,
store_batch_size_prompts=8,
train_batch_size_tokens=4096,
n_batches_in_buffer=32,
device="cuda",
)
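# Get a batch of tokens: [store_batch_size_prompts, context_size]
tokens = store.get_batch_tokens()
# Get a batch of activations for training
activations = store.next_batch()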
`HookedSAETransformer` integrates SAEs with TransformerLens models.
from sae_lens import HookedSAETransformer
# Load model with SAE
model = HookedSAETransformer.from_pretrained("gpt2-small")
model.add_sae(sae)
# Run with SAE in the loop
output = model.run_with_saes(tokens, saes=[sae])
# Cache with SAE activations
output, cache = model.run_with_cache_with_saes(tokens, saes=[sae])
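When reading SAE feature activations back out of the cache, the key is typically the SAE's hook name with an SAE-specific suffix appended. The helper below sketches that naming; the `hook_sae_acts_post` suffix is an assumption based on common usage, so inspect `cache.keys()` to confirm the names in your installed version.

```python
def sae_cache_key(hook_name: str, sae_hook: str = "hook_sae_acts_post") -> str:
    """Build the assumed cache key for an SAE's post-activation features."""
    return f"{hook_name}.{sae_hook}"


key = sae_cache_key("blocks.8.hook_resid_pre")
print(key)
# Hypothetical usage with the cache returned above:
# feature_acts = cache[key]  # expected shape: [batch, pos, d_sae]
```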
cfg = LanguageModelSAERunnerConfig(
architecture="standard",
activation_fn="relu",
l1_coefficient=8e-5,
)
cfg = LanguageModelSAERunnerConfig(
architecture="gated",
)
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn="topk",
activation_fn_kwargs={"k": 50}, # Keep the 50 largest activations per token
)
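The top-k activation selected above can be sketched in a few lines: keep the `k` largest pre-activations for a token and zero the rest. This is a simplified illustration with a hypothetical function name; it applies no ReLU after selection, which an actual implementation may do.

```python
from typing import List


def topk_activation(pre_acts: List[float], k: int) -> List[float]:
    """Keep the k largest values of one token's pre-activations, zero the others."""
    k = max(k, 0)
    # Indices sorted by value, largest first; ties keep their original order
    ranked = sorted(range(len(pre_acts)), key=lambda i: pre_acts[i], reverse=True)
    keep = set(ranked[:k])
    return [value if i in keep else 0.0 for i, value in enumerate(pre_acts)]


sparse = topk_activation([0.1, 2.0, -1.0, 0.7, 1.5], k=2)
print(sparse)
```

Because sparsity is fixed by `k`, no L1 penalty is needed to control the number of active features.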
cfg = LanguageModelSAERunnerConfig(
architecture="jumprelu",
)
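For comparison with ReLU, a JumpReLU activation passes a pre-activation through unchanged only when it exceeds a per-feature threshold, and outputs zero otherwise. The sketch below shows the forward computation only; how thresholds are learned during training is out of scope, and the function name is hypothetical.

```python
from typing import List


def jumprelu(pre_acts: List[float], thresholds: List[float]) -> List[float]:
    """Zero every pre-activation that does not exceed its feature's threshold."""
    return [value if value > theta else 0.0 for value, theta in zip(pre_acts, thresholds)]


gated = jumprelu([0.3, 0.8, -0.2], thresholds=[0.5, 0.5, 0.5])
print(gated)
```

Unlike ReLU, small positive pre-activations (0.3 here) are suppressed rather than passed through.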
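from sae_lens import upload_saes_to_huggingface
# Keys are the SAE ids (sub-folder names) used inside the repo.
# Authenticate first, e.g. with `huggingface-cli login`.
upload_saes_to_huggingface(
    {"blocks.8.hook_resid_pre": sae},
    hf_repo_id="username/my-saes",
)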
# Features can be viewed on Neuronpedia
# URL format: neuronpedia.org/{model}/{layer}-{sae_type}/{feature_id}
# Example: neuronpedia.org/gpt2-small/8-res-jb/1234
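The URL format above is easy to wrap in a small helper. `neuronpedia_url` is a hypothetical function, and the `https://` scheme in the default base is an assumption, since the format in this reference omits it.

```python
def neuronpedia_url(
    model: str,
    layer: int,
    sae_type: str,
    feature_id: int,
    base: str = "https://neuronpedia.org",  # assumption: scheme not given in the reference
) -> str:
    """Build a feature URL following {model}/{layer}-{sae_type}/{feature_id}."""
    return f"{base}/{model}/{layer}-{sae_type}/{feature_id}"


url = neuronpedia_url("gpt2-small", layer=8, sae_type="res-jb", feature_id=1234)
print(url)
```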