Grokking in Neural Networks: Understanding Modular Arithmetic Through Mechanistic Interpretability

Introduction: The Mystery of Modern AI

No one fully understands modern AI systems. Each token produced by models like ChatGPT results from hundreds of billions of calculations using parameters learned from data. These models are trained simply to predict the next piece of text, yet somehow real intelligence appears to emerge from this process.

Key questions remain unanswered:

  • What pathways through billions of computations are responsible for specific knowledge?
  • Why do certain skills only emerge at specific model sizes or training durations?
  • Are these models memorizing or truly learning?

This document explores one of the best-understood phenomena in AI: grokking, in which a model suddenly generalizes after appearing to merely memorize its training data.

The Grokking Discovery

Accidental Discovery at OpenAI (2021)

In 2021, an OpenAI research team was training small transformer models on modular arithmetic. The initial results were disappointing - models quickly memorized training data but failed to generalize to test sets.

Then serendipity struck: a researcher accidentally left a model training during vacation. Upon return, they discovered the model had suddenly achieved perfect generalization after thousands of additional training steps.

The Modular Arithmetic Task

The team studied modular arithmetic, focusing on the addition operation:

Mathematical Setup:

  • Operation: (x + y) mod p, where p is the modulus
  • Example with modulus 5:
      • 1 + 4 = 5, and 5 mod 5 = 0
      • 4 + 2 = 6, and 6 mod 5 = 1
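
In code, the operation is simply the remainder of the sum:

# Addition modulo p in plain Python
p = 5
print((1 + 4) % p)  # 0
print((4 + 2) % p)  # 1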

Data Representation:

# Example: representing 1 + 2 = with one-hot encoding
# For modulus 5 (tokens: 0,1,2,3,4,=)

input_sequence = [
    [0,1,0,0,0,0],  # token "1"
    [0,0,1,0,0,0],  # token "2" 
    [0,0,0,0,0,1]   # token "="
]
# Expected output: [0,0,0,1,0,0] for token "3"

The Grokking Phenomenon

The term "grokking" comes from Robert Heinlein's 1961 novel Stranger in a Strange Land, meaning to understand something so thoroughly that you merge with it. The training dynamics showed:

  1. Memorization Phase (~140 steps): Perfect training accuracy, poor test performance
  2. Grokking Phase (~7,000 steps): Sudden perfect generalization

Deep Dive: Mechanistic Analysis by Nanda et al. (2023)

Model Architecture

The researchers studied a single-layer transformer with:

  • Input: three one-hot encoded tokens, "x", "y", and "=" (modulus 113 plus the "=" token → 114 dimensions each)
  • Embedding: maps each 114-dimensional one-hot vector to a 128-dimensional vector
  • Attention block: combines information across the three positions
  • MLP block: 512 neurons with nonlinear activations
  • Unembedding: maps back to vocabulary space

Input (114×3) → Embedding (128×3) → Attention → MLP (512) → Unembedding (114)
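
As a concrete reference, here is a minimal PyTorch sketch of a model with these dimensions. It is not the authors' code: details such as the number of attention heads, the use of nn.MultiheadAttention, and the ReLU activation are illustrative assumptions.

import torch
import torch.nn as nn

class OneLayerTransformer(nn.Module):
    def __init__(self, vocab_size=114, d_model=128, n_heads=4, d_mlp=512, seq_len=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(seq_len, d_model) * 0.02)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_mlp), nn.ReLU(), nn.Linear(d_mlp, d_model)
        )
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens):            # tokens: (batch, 3) integer ids for "x y ="
        h = self.embed(tokens) + self.pos_embed
        attn_out, _ = self.attn(h, h, h)  # attention mixes the x, y, = positions
        h = h + attn_out                  # residual connection
        h = h + self.mlp(h)               # 512-neuron MLP block, also residual
        return self.unembed(h[:, -1])     # answer logits, read at the "=" position

model = OneLayerTransformer()
logits = model(torch.tensor([[3, 4, 113]]))  # "3 + 4 =" with token 113 as "="
print(logits.shape)  # torch.Size([1, 114])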

Discovering Trigonometric Structure

Key Finding: The model learns to compute trigonometric functions of its inputs!

Layer 1: Learning Sine and Cosine

Early in the network, researchers found that the model computes:

  • cos(x * 8π/113) and sin(x * 8π/113) for input x
  • cos(y * 8π/113) and sin(y * 8π/113) for input y

Visualization Method:

import numpy as np

# Sparse linear probe to extract trigonometric signals from activations.
# probe_weights is a direction in activation space; projecting onto it reads
# out the cos/sin component if the model represents one.
def extract_trig_signal(embeddings, probe_weights):
    """Project activations onto a probe direction to recover a sine/cosine signal."""
    return np.dot(embeddings, probe_weights)

# Results show clean sine/cosine waves when plotting against input values

Layer 2: Computing Products

The MLP learns to compute the products:

  • cos(x) * cos(y)
  • sin(x) * sin(y)

Layer 3: The Trigonometric Identity

The crucial insight: the model learns to use the trigonometric identity:

cos(x) * cos(y) - sin(x) * sin(y) = cos(x + y)

This lets the network convert products of trigonometric functions of x and y into a function of the sum x + y, which is exactly what modular addition requires.
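
Here is a minimal numeric sketch (NumPy, not the trained network's weights) of the full pipeline: products of cosines and sines of x and y are combined via this identity into a score for every candidate answer z, and the score peaks exactly at z = (x + y) mod p. The single frequency k = 4 (i.e., 8π/113) is an illustrative choice; trained models use several key frequencies.

import numpy as np

p = 113
k = 4                   # illustrative key frequency; w = 2*pi*k/p = 8*pi/113
w = 2 * np.pi * k / p
x, y = 57, 88

# "Layer 2" products, combined using the angle-addition identities
cos_sum = np.cos(w * x) * np.cos(w * y) - np.sin(w * x) * np.sin(w * y)  # = cos(w*(x+y))
sin_sum = np.sin(w * x) * np.cos(w * y) + np.cos(w * x) * np.sin(w * y)  # = sin(w*(x+y))

# Score for each candidate answer z: cos(w*((x+y) - z)), maximal at z = (x+y) mod p
z = np.arange(p)
scores = cos_sum * np.cos(w * z) + sin_sum * np.sin(w * z)
print(z[np.argmax(scores)], (x + y) % p)  # prints: 32 32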

The Clock Analogy

Modular arithmetic is naturally circular, like a clock:

  • 11 AM + 2 hours = 1 PM (i.e., (11 + 2) mod 12 = 1)
  • The model learns circular representations in activation space
  • Trigonometric functions naturally handle this circular structure

Understanding Why Grokking Occurs

Training Dynamics Visualization

During the seemingly "flat" period between memorization and generalization:

  1. Memorization circuits provide correct answers for training data
  2. Generalization circuits (trigonometric structures) slowly develop
  3. Cleanup phase: Memorization circuits are pruned away
  4. Grokking: Only robust generalization circuits remain

The Excluded Loss Metric

Nanda et al. introduced a brilliant diagnostic:

def excluded_loss(predictions, targets, exclude_frequencies):
    """Loss computed after removing specific Fourier frequency components.

    remove_frequencies and cross_entropy are helpers: the first projects the
    listed trigonometric frequencies out of the model's output logits, the
    second is standard cross-entropy loss.
    """
    filtered_predictions = remove_frequencies(predictions, exclude_frequencies)
    return cross_entropy(filtered_predictions, targets)

Because this metric removes exactly the frequency components the generalizing circuit uses, it gets steadily worse as the model comes to rely on them, revealing the model's growing dependence on trigonometric frequencies even during the flat performance period.
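
A plausible sketch of the remove_frequencies helper used above (an assumption about how such a filter could be written, not the paper's code): treat the logits as a function over the p candidate answers and project out the cos/sin components at the listed frequencies.

import numpy as np

def remove_frequencies(logits, exclude_frequencies, p=113):
    """Project the cos/sin components at the given frequencies out of the logits.

    logits: array of shape (..., p), one score per candidate answer.
    exclude_frequencies: integer frequencies k whose components are removed.
    """
    z = np.arange(p)
    projector = np.eye(p)
    for k in exclude_frequencies:
        for direction in (np.cos(2 * np.pi * k * z / p), np.sin(2 * np.pi * k * z / p)):
            direction = direction / np.linalg.norm(direction)
            # cos/sin directions at distinct key frequencies are mutually
            # orthogonal over Z_p, so subtracting rank-1 projectors yields
            # the projector onto their orthogonal complement.
            projector -= np.outer(direction, direction)
    return np.asarray(logits, dtype=float) @ projector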

Modern Applications: Claude Haiku Line Breaks

Recent work from Anthropic shows similar mechanistic understanding in production models. Claude 3.5 Haiku uses a 6-dimensional manifold to track character counts and determine line breaks:

The QK-Twist Mechanism

  1. Character count and line length are represented on parallel helical structures
  2. Attention mechanism rotates these structures relative to each other
  3. Proximity detection identifies when ~5 characters remain before line end
  4. Line break decision triggered by high attention scores

This demonstrates that mechanistic interpretability can reveal precise mechanisms even in large production models.

Implications and Future Directions

What Grokking Teaches Us

  1. Training dynamics are complex: Flat metrics don't mean no learning
  2. Multiple solutions exist: Memorization vs. generalization circuits
  3. Emergent algorithms: Networks discover sophisticated mathematical structures
  4. Interpretability is possible: With the right tools, we can understand mechanisms

Limitations and Challenges

  • Toy problems: Modular arithmetic is far simpler than language modeling
  • Scale gap: Techniques don't yet scale to full LLMs
  • Cherry-picking: We may only understand the most interpretable behaviors

Research Directions

  1. Scaling mechanistic interpretability to larger models
  2. Automated discovery of interpretable structures
  3. Training dynamics understanding across different tasks
  4. Circuit-level interventions for model editing

Conclusion: Alien Intelligence

The grokking phenomenon reveals something profound about neural networks. While we interact with AI through human language, the underlying computations involve alien mathematical structures - trigonometric identities, high-dimensional manifolds, and emergent algorithms.

As AI researcher Andrej Karpathy noted, training large language models is "less like building animal intelligence and more like summoning ghosts." The patterns these models discover are fundamentally different from human reasoning, even when solving the same problems.

Understanding these alien computational strategies may be key to building more capable, reliable, and interpretable AI systems. Grokking represents a rare window into these mysterious processes - a transparent box in a world of black boxes.

Code Example: Implementing Grokking Detection

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

class ModularArithmeticDataset:
    def __init__(self, modulus=113, operation='add'):
        self.modulus = modulus
        self.operation = operation
        self.vocab_size = modulus + 1  # +1 for equals token

    def generate_data(self, train_fraction=0.7):
        """Generate all possible arithmetic problems"""
        data = []
        for x in range(self.modulus):
            for y in range(self.modulus):
                if self.operation == 'add':
                    result = (x + y) % self.modulus
                data.append([x, y, result])

        # Split train/test
        np.random.shuffle(data)
        split_idx = int(len(data) * train_fraction)
        return data[:split_idx], data[split_idx:]

def evaluate_model(model, loader, criterion):
    """Return average loss and accuracy over a data loader (no gradient updates)."""
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            logits = model(inputs)
            total_loss += criterion(logits, targets).item() * targets.size(0)
            correct += (logits.argmax(dim=-1) == targets).sum().item()
            count += targets.size(0)
    return total_loss / count, correct / count

def track_grokking(model, train_loader, test_loader, num_epochs=10000):
    """Train the model while tracking metrics to detect grokking."""
    train_losses, test_losses = [], []
    train_accs, test_accs = [], []

    # Weight decay is widely reported to matter for grokking; the value is illustrative
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        # Training phase: one pass of gradient updates over the training set
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

        # Evaluation phase: measure loss/accuracy on both splits
        model.eval()
        train_loss, train_acc = evaluate_model(model, train_loader, criterion)
        test_loss, test_acc = evaluate_model(model, test_loader, criterion)

        # Log metrics
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_accs.append(train_acc)
        test_accs.append(test_acc)

        # Check for grokking: test accuracy jumps from near-chance to near-perfect
        if epoch > 100 and test_acc > 0.95 and test_accs[epoch - 10] < 0.5:
            print(f"Grokking detected at epoch {epoch}!")

    return train_losses, test_losses, train_accs, test_accs
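
A hypothetical end-to-end usage sketch, assuming the OneLayerTransformer from the architecture section above (any model mapping three token ids to 114 logits would work); the training fraction, batch size, and "=" token id are illustrative choices.

from torch.utils.data import DataLoader, TensorDataset

dataset = ModularArithmeticDataset(modulus=113)
# Grokking is typically observed when only a fraction of all (x, y) pairs is
# used for training.
train_data, test_data = dataset.generate_data(train_fraction=0.3)

def to_loader(rows, batch_size=512, shuffle=True):
    # Encode each problem as the token sequence "x y =", with 113 as the "=" token
    xs = torch.tensor([[x, y, 113] for x, y, _ in rows])
    ys = torch.tensor([result for _, _, result in rows])
    return DataLoader(TensorDataset(xs, ys), batch_size=batch_size, shuffle=shuffle)

model = OneLayerTransformer()
history = track_grokking(model, to_loader(train_data), to_loader(test_data, shuffle=False))
train_losses, test_losses, train_accs, test_accs = history

# Plot the characteristic delayed jump in test accuracy
plt.plot(train_accs, label="train")
plt.plot(test_accs, label="test")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.show()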