Grokking in Neural Networks: Understanding Modular Arithmetic Through Mechanistic Interpretability¶
Introduction: The Mystery of Modern AI¶
No one fully understands modern AI systems. Each token produced by models like ChatGPT results from hundreds of billions of calculations using parameters learned from data. These models are trained simply to predict the next piece of text, yet somehow real intelligence appears to emerge from this process.
Key questions remain unanswered:

- What pathways through billions of computations are responsible for specific knowledge?
- Why do certain skills only emerge at specific model sizes or training durations?
- Are these models memorizing or truly learning?
This document explores one of the most well-understood AI phenomena: grokking - where models suddenly generalize after appearing to merely memorize training data.
The Grokking Discovery¶
Accidental Discovery at OpenAI (2021)¶
In 2021, an OpenAI research team was training small transformer models on modular arithmetic. The initial results were disappointing - models quickly memorized training data but failed to generalize to test sets.
Then serendipity struck: a researcher accidentally left a model training during vacation. Upon return, they discovered the model had suddenly achieved perfect generalization after thousands of additional training steps.
The Modular Arithmetic Task¶
The team studied modular arithmetic using addition operations:
Mathematical Setup:

- Operation: (x + y) mod p, where p is the modulus
- Example with modulus 5:
  - (1 + 4) mod 5 = 5 mod 5 = 0
  - (4 + 2) mod 5 = 6 mod 5 = 1
Data Representation:
# Example: representing "1 + 2 =" with one-hot encoding
# For modulus 5 (tokens: 0,1,2,3,4,=)
input_sequence = [
    [0,1,0,0,0,0],  # token "1"
    [0,0,1,0,0,0],  # token "2"
    [0,0,0,0,0,1],  # token "="
]
# Expected output: [0,0,0,1,0,0] for token "3"
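The same encoding can be built programmatically. Here is a small sketch; the helper names and vocabulary ordering are mine, chosen to match the example above:

import numpy as np

p = 5
vocab = list(range(p)) + ['=']                      # tokens: 0, 1, 2, 3, 4, "="
token_index = {tok: i for i, tok in enumerate(vocab)}

def one_hot(token):
    """Return the one-hot vector for a token over the 6-token vocabulary."""
    vec = np.zeros(len(vocab), dtype=int)
    vec[token_index[token]] = 1
    return vec

# "1 + 2 =" becomes three one-hot vectors; the target is the one-hot for 3
input_sequence = [one_hot(1), one_hot(2), one_hot('=')]
target = one_hot((1 + 2) % p)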
The Grokking Phenomenon¶
The term "grokking" comes from Robert Heinlein's 1961 novel Stranger in a Strange Land, meaning to understand something so thoroughly that you merge with it. The training dynamics showed:
- Memorization Phase (~140 steps): Perfect training accuracy, poor test performance
- Grokking Phase (~7,000 steps): Sudden perfect generalization
Deep Dive: Mechanistic Analysis by Nanda et al. (2023)¶
Model Architecture¶
The researchers studied a single-layer transformer with:
- Input: One-hot encoded vectors (modulus 113 → 114 dimensions)
- Embedding: 114×128 matrix mapping each one-hot token to a 128-dimensional vector
- Attention block: Combines information across positions
- MLP block: 512 neurons with nonlinear activations
- Unembedding: Maps back to vocabulary space
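As a rough illustration of the components listed above, here is a minimal PyTorch sketch of a model with this shape. The dimensions follow the text (vocabulary 114, 128-dimensional embeddings, 512 MLP neurons); the head count, activation function, and initialization are assumptions, not the exact configuration studied in the paper:

import torch
import torch.nn as nn

class OneLayerTransformer(nn.Module):
    """Minimal sketch of the single-layer transformer described above."""

    def __init__(self, vocab_size=114, d_model=128, d_mlp=512, n_heads=4, seq_len=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(seq_len, d_model) * 0.02)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_mlp),
            nn.ReLU(),
            nn.Linear(d_mlp, d_model),
        )
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens):
        # tokens: [batch, 3] integer ids for the sequence (x, y, "=")
        h = self.embed(tokens) + self.pos_embed
        attn_out, _ = self.attn(h, h, h)
        h = h + attn_out                   # residual connection around attention
        h = h + self.mlp(h)                # residual connection around the MLP
        return self.unembed(h)[:, -1, :]   # predict the answer at the "=" position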
Discovering Trigonometric Structure¶
Key Finding: The model learns to compute trigonometric functions of its inputs!
Layer 1: Learning Sine and Cosine¶
Early in the network, researchers found the model computes:
- cos(x * 8π/113) and sin(x * 8π/113) for input x
- cos(y * 8π/113) and sin(y * 8π/113) for input y
Visualization Method:
import numpy as np

# Sparse linear probe to extract trigonometric signals
def extract_trig_signal(embeddings, probe_weights):
    """Project embeddings onto a probe direction to read out sine/cosine signals."""
    signal = np.dot(embeddings, probe_weights)
    return signal

# Plotting the probe output against the input values 0..112 shows clean sine/cosine waves
Layer 2: Computing Products¶
The MLP neurons learn to compute products of these signals (writing w = 8π/113 for the frequency above):

- cos(wx) · cos(wy)
- sin(wx) · sin(wy)
Layer 3: The Trigonometric Identity¶
The crucial insight: the model learns to use the angle-addition identity

cos(w(x + y)) = cos(wx) · cos(wy) − sin(wx) · sin(wy)

This allows the network to convert products of trigonometric functions into a function of the sum x + y, solving the modular arithmetic problem!
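A small numeric sketch (not the paper's code) makes this concrete: the identity turns the layer-2 products into cos(w(x + y)), and a logit proportional to cos(w(x + y − c)) peaks exactly at c = (x + y) mod p. The values of x, y, and the frequency index k below are arbitrary choices:

import numpy as np

p, k = 113, 4                # k = 4 gives the 8π/113 frequency mentioned above
w = 2 * np.pi * k / p
x, y = 57, 92                # arbitrary inputs; (57 + 92) mod 113 = 36

# Angle-addition identity: products of the layer-2 terms give a function of x + y
lhs = np.cos(w * (x + y))
rhs = np.cos(w * x) * np.cos(w * y) - np.sin(w * x) * np.sin(w * y)
assert np.isclose(lhs, rhs)

# A logit proportional to cos(w * (x + y - c)) is maximized at c = (x + y) mod p
c = np.arange(p)
logits = np.cos(w * (x + y - c))
print(c[np.argmax(logits)], (x + y) % p)   # 36 36

In the trained model several key frequencies contribute, and summing them sharpens the peak at the correct residue.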
The Clock Analogy¶
Modular arithmetic is naturally circular - like a clock:

- 11 AM + 2 hours = 1 PM ((11 + 2) mod 12 = 1)
- The model learns circular representations in activation space
- Trigonometric functions naturally handle this circular structure
Understanding Why Grokking Occurs¶
Training Dynamics Visualization¶
During the seemingly "flat" period between memorization and generalization:
- Memorization circuits provide correct answers for training data
- Generalization circuits (trigonometric structures) slowly develop
- Cleanup phase: Memorization circuits are pruned away
- Grokking: Only robust generalization circuits remain
The Excluded Loss Metric¶
Nanda et al. introduced a brilliant diagnostic:
def excluded_loss(predictions, targets, exclude_frequencies):
    """Schematic: loss computed after removing specific frequency components.

    remove_frequencies and cross_entropy are placeholders for the actual
    Fourier projection and loss; a concrete sketch follows below."""
    # Remove the learned trigonometric frequencies from the predictions
    filtered_predictions = remove_frequencies(predictions, exclude_frequencies)
    return cross_entropy(filtered_predictions, targets)
This metric reveals the model's growing reliance on trigonometric frequencies even during the flat performance period.
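A more concrete sketch of the idea, assuming the logits for every (x, y) pair have been collected into a single [p, p, p] tensor; the tensor layout and function signature are assumptions for illustration, not the authors' code:

import torch
import torch.nn.functional as F

def excluded_loss(all_logits, labels, key_freqs, train_mask, p=113):
    """Cross-entropy on the training pairs after projecting the cos/sin(w(x+y))
    directions for each key frequency out of the logits.

    all_logits: [p, p, p] logits for every (x, y) pair
    labels: [p, p] tensor holding (x + y) % p
    train_mask: [p, p] boolean mask selecting the training pairs"""
    x = torch.arange(p).float().view(p, 1)
    y = torch.arange(p).float().view(1, p)
    logits = all_logits.clone()
    for k in key_freqs:
        w = 2 * torch.pi * k / p
        for basis in (torch.cos(w * (x + y)), torch.sin(w * (x + y))):
            basis = basis / basis.norm()                       # unit-norm [p, p] direction
            coeff = torch.einsum('xy,xyc->c', basis, logits)   # component per output class
            logits = logits - basis.unsqueeze(-1) * coeff      # remove that component
    return F.cross_entropy(logits[train_mask], labels[train_mask])

If the model's correct answers depend on the trigonometric circuit, this loss rises sharply even while the ordinary training loss stays flat.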
Modern Applications: Claude Haiku Line Breaks¶
Recent work from Anthropic shows similar mechanistic understanding in production models. Claude 3.5 Haiku uses a 6-dimensional manifold to track character counts and determine line breaks:
The QK-Twist Mechanism¶
- Character count and line length are represented on parallel helical structures
- Attention mechanism rotates these structures relative to each other
- Proximity detection identifies when ~5 characters remain before line end
- Line break decision triggered by high attention scores
This demonstrates that mechanistic interpretability can reveal precise mechanisms even in large production models.
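The published write-up describes this mechanism at a conceptual level. The toy sketch below only illustrates the general idea of circular (helical) count representations and rotation-based proximity detection; it is not Anthropic's implementation, and the periods and scoring rule are arbitrary choices:

import numpy as np

PERIODS = (8, 32, 128)   # arbitrary periods for a toy 6-dimensional helix

def helix(n):
    """Toy circular embedding of a count n: one cos/sin pair per period."""
    feats = []
    for T in PERIODS:
        angle = 2 * np.pi * n / T
        feats += [np.cos(angle), np.sin(angle)]
    return np.array(feats)

def rotated(v, shift):
    """Rotate each cos/sin pair of v by the angle corresponding to `shift` characters."""
    out = v.copy()
    for i, T in enumerate(PERIODS):
        a = 2 * np.pi * shift / T
        c, s = out[2 * i], out[2 * i + 1]
        out[2 * i], out[2 * i + 1] = c * np.cos(a) - s * np.sin(a), c * np.sin(a) + s * np.cos(a)
    return out

chars_written, line_limit = 75, 80
# Rotating the "characters written" helix forward by 5 positions and comparing it to
# the "line limit" helix yields a high similarity score only near the break point.
score = np.dot(rotated(helix(chars_written), 5), helix(line_limit))
print(round(score, 3))   # 3.0, the maximum, since chars_written + 5 == line_limit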
Implications and Future Directions¶
What Grokking Teaches Us¶
- Training dynamics are complex: Flat metrics don't mean no learning
- Multiple solutions exist: Memorization vs. generalization circuits
- Emergent algorithms: Networks discover sophisticated mathematical structures
- Interpretability is possible: With the right tools, we can understand mechanisms
Limitations and Challenges¶
- Toy problems: Modular arithmetic is far simpler than language modeling
- Scale gap: Techniques don't yet scale to full LLMs
- Cherry-picking: We may only understand the most interpretable behaviors
Research Directions¶
- Scaling mechanistic interpretability to larger models
- Automated discovery of interpretable structures
- Training dynamics understanding across different tasks
- Circuit-level interventions for model editing
Conclusion: Alien Intelligence¶
The grokking phenomenon reveals something profound about neural networks. While we interact with AI through human language, the underlying computations involve alien mathematical structures - trigonometric identities, high-dimensional manifolds, and emergent algorithms.
As AI researcher Andrej Karpathy noted, training large language models is "less like building animal intelligence and more like summoning ghosts." The patterns these models discover are fundamentally different from human reasoning, even when solving the same problems.
Understanding these alien computational strategies may be key to building more capable, reliable, and interpretable AI systems. Grokking represents a rare window into these mysterious processes - a transparent box in a world of black boxes.
Code Example: Implementing Grokking Detection¶
import numpy as np
import torch
import torch.nn as nn
class ModularArithmeticDataset:
    def __init__(self, modulus=113, operation='add'):
        self.modulus = modulus
        self.operation = operation
        self.vocab_size = modulus + 1  # +1 for the "=" token

    def generate_data(self, train_fraction=0.7):
        """Generate all possible arithmetic problems and split into train/test."""
        data = []
        for x in range(self.modulus):
            for y in range(self.modulus):
                if self.operation == 'add':
                    result = (x + y) % self.modulus
                    data.append([x, y, result])
        # Shuffle and split train/test
        np.random.shuffle(data)
        split_idx = int(len(data) * train_fraction)
        return data[:split_idx], data[split_idx:]
def evaluate_model(model, loader, criterion):
    """Return mean loss and accuracy over a loader of (inputs, targets) batches."""
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            logits = model(inputs)
            total_loss += criterion(logits, targets).item() * len(targets)
            correct += (logits.argmax(dim=-1) == targets).sum().item()
            count += len(targets)
    return total_loss / count, correct / count

def track_grokking(model, train_loader, test_loader, num_epochs=10000):
    """Train the model while tracking metrics to detect grokking."""
    train_losses, test_losses, train_accs, test_accs = [], [], [], []
    # Note: grokking setups typically also use weight decay (e.g. AdamW)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        # Training phase: gradient updates over the training set
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

        # Evaluation phase: measure loss and accuracy on both splits
        model.eval()
        train_loss, train_acc = evaluate_model(model, train_loader, criterion)
        test_loss, test_acc = evaluate_model(model, test_loader, criterion)

        # Log metrics
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_accs.append(train_acc)
        test_accs.append(test_acc)

        # Check for grokking (sudden test accuracy improvement)
        if epoch > 100 and test_acc > 0.95 and test_accs[epoch - 10] < 0.5:
            print(f"Grokking detected at epoch {epoch}!")

    return train_losses, test_losses, train_accs, test_accs
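To make the example end-to-end, here is a hypothetical usage sketch. OneLayerTransformer refers to the architecture sketch earlier in this document, and the loader construction, batch size, and epoch count are assumptions:

from torch.utils.data import DataLoader, TensorDataset

dataset = ModularArithmeticDataset(modulus=113)
train_rows, test_rows = dataset.generate_data(train_fraction=0.7)

def to_loader(rows, batch_size=512):
    rows = torch.tensor(rows)
    eq = torch.full((len(rows), 1), dataset.modulus)   # "=" is the last token id
    inputs = torch.cat([rows[:, :2], eq], dim=1)       # sequences of (x, y, "=")
    return DataLoader(TensorDataset(inputs, rows[:, 2]), batch_size=batch_size, shuffle=True)

model = OneLayerTransformer(vocab_size=dataset.vocab_size)
history = track_grokking(model, to_loader(train_rows), to_loader(test_rows), num_epochs=20000)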