# Model Card for SSAE Checkpoints
This is the official model repository for the paper "Step-Level Sparse Autoencoder for Reasoning Process Interpretation". It contains the trained Step-Level Sparse Autoencoder (SSAE) checkpoints.
- Paper: Arxiv Link Here
- Code: GitHub Link Here
- Collection: HuggingFace
## Model Overview

The checkpoints are provided as PyTorch state dictionaries (`.pt` files). Each file represents an SSAE trained on a specific Base Model using a specific Dataset.
### Naming Convention

The filenames follow this structure:

`{Dataset}_{BaseModel}_{SparsityConfig}.pt`

- **Dataset**: Source data used for training (e.g., `gsm8k`, `numina`, `opencodeinstruct`).
- **Base Model**: The LLM whose activations were encoded (e.g., `Llama3.2-1b`, `Qwen2.5-0.5b`).
- **SparsityConfig**: Target sparsity (e.g., `spar-10` indicates a target sparsity of $\tau_{spar} = 10$).
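The convention above can be applied mechanically. As an illustration, here is a small helper (a hypothetical utility, not part of the released code) that splits a checkpoint filename into its three components:

```python
# Hypothetical helper: parse a checkpoint filename following the
# {Dataset}_{BaseModel}_{SparsityConfig}.pt naming convention.
def parse_ssae_filename(filename: str) -> dict:
    stem = filename.removesuffix(".pt")
    dataset, base_model, sparsity = stem.split("_")
    return {
        "dataset": dataset,
        "base_model": base_model,
        "tau_spar": int(sparsity.removeprefix("spar-")),
    }

info = parse_ssae_filename("gsm8k-385k_Llama3.2-1b_spar-10.pt")
print(info)  # {'dataset': 'gsm8k-385k', 'base_model': 'Llama3.2-1b', 'tau_spar': 10}
```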
## Checkpoints List

Below is the list of available checkpoints in this repository:

| Filename | Base Model | Training Dataset | Description |
|---|---|---|---|
| `gsm8k-385k_Llama3.2-1b_spar-10.pt` | Llama-3.2-1B | GSM8K | SSAE trained on Llama-3.2-1B using GSM8K-385K. |
| `gsm8k-385k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | GSM8K | SSAE trained on Qwen-2.5-0.5B using GSM8K-385K. |
| `numina-859k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | Numina | SSAE trained on Qwen-2.5-0.5B using Numina-859K. |
| `opencodeinstruct-36k_Llama3.2-1b_spar-10.pt` | Llama-3.2-1B | OpenCodeInstruct | SSAE trained on Llama-3.2-1B using OpenCodeInstruct-36K. |
| `opencodeinstruct-36k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | OpenCodeInstruct | SSAE trained on Qwen-2.5-0.5B using OpenCodeInstruct-36K. |
## Usage

The provided `.pt` files contain not only the model weights but also the training configuration and metadata.

Structure of the checkpoint dictionary:

- `model`: The model state dictionary (weights).
- `config`: Configuration dictionary (sparsity factor, etc.).
- `encoder_name` / `decoder_name`: Names of the base models used.
- `global_step`: Training step count.
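To make this layout concrete, the dictionary can be mirrored with toy values and round-tripped through `torch.save` / `torch.load` (a minimal sketch with placeholder weights, not the real checkpoint contents):

```python
import os
import tempfile

import torch

# Toy checkpoint mirroring the documented key layout (values are placeholders).
toy_checkpoint = {
    "model": {"encoder.weight": torch.zeros(4, 8)},  # state-dict-style weights
    "config": {"sparsity_factor": 10},
    "encoder_name": "Llama3.2-1b",
    "decoder_name": "Llama3.2-1b",
    "global_step": 1000,
}

path = os.path.join(tempfile.mkdtemp(), "toy_ssae.pt")
torch.save(toy_checkpoint, path)

reloaded = torch.load(path, map_location="cpu")
print(sorted(reloaded.keys()))
# ['config', 'decoder_name', 'encoder_name', 'global_step', 'model']
```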
### Loading Code Example

```python
import torch
from huggingface_hub import hf_hub_download

# 1. Download the checkpoint
repo_id = "Miaow-Lab/SSAE-Models"
filename = "gsm8k-385k_Llama3.2-1b_spar-10.pt"  # Example filename
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

# 2. Load the full checkpoint dictionary
# Note: map_location="cpu" is recommended for initial loading
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(f"Loaded checkpoint (Step: {checkpoint.get('global_step', 'Unknown')})")
print(f"Config: {checkpoint.get('config')}")

# 3. Initialize your model
# Use the metadata from the checkpoint to ensure correct initialization arguments
# model = MyModel(
#     tokenizer=...,
#     sparsity_factor=checkpoint['config'].get('sparsity_factor'),  # Adjust key based on your config structure
#     init_from=(checkpoint['encoder_name'], checkpoint['decoder_name'])
# )

# 4. Load the weights
# CRITICAL: The weights are stored under the "model" key
# model.load_state_dict(checkpoint["model"], strict=True)
# model.to("cuda")  # Move to GPU if needed
# model.eval()
```
## Citation
If you use these models or the associated code in your research, please cite our paper: