Qihong Ruan / CS336 Notes / Assignments

🛠️ Assignments — Do the Work

Study guides for the 5 implementation-heavy assignments — where the real learning is.

Assignment 1: Basics

Official starter code & handout ↗ · due Apr 15, 2026 (2026 offering) · No solutions here — by design.

You will implement the fundamental components of a modern language model from scratch and assemble them into a working system. The deliverables are:

  1. A Byte-Pair Encoding (BPE) Tokenizer: You'll implement the BPE training algorithm to learn a vocabulary and merge rules from a text corpus. Your tokenizer will have train, encode, and decode methods. The specific style is based on GPT-2, which involves a pre-tokenization step before running BPE on word-level chunks.
  2. A Transformer Model: You will build the decoder-only Transformer architecture as a torch.nn.Module. This involves implementing its constituent parts: multi-head self-attention (MHA) with Rotary Position Embeddings (RoPE), a SwiGLU feed-forward network, and RMSNorm for normalization. You will assemble these into Transformer blocks and stack them to form the full model.
  3. An AdamW Optimizer: You will implement the AdamW optimizer by creating a custom class that inherits from torch.optim.Optimizer. This requires managing the optimizer's state (first and second moment estimates) for each model parameter and implementing the step() logic, including bias correction and decoupled weight decay.
  4. A Training Harness: You will write a script that uses these three components to train a minimal language model on a real dataset (like TinyStories). This involves setting up the data loading, the training loop, calculating the cross-entropy loss, and calling your custom optimizer.

These pieces fit together to form a complete training pipeline: raw text is fed to your tokenizer to create integer sequences, which are batched and passed to your Transformer model. The model's output is used to calculate a loss, and your AdamW optimizer updates the model's weights to minimize this loss.

Why this is where the learning happens

Watching lectures gives you a map, but implementing this assignment is the journey where you actually learn the terrain. Building these components from first principles forces you to move beyond "leaky abstractions" and confront the details that make these systems work. You will internalize how architectural choices (like RMSNorm vs. LayerNorm) are not just theoretical but have real consequences for speed and memory, and how an algorithmic choice in your tokenizer directly impacts the shape of your model's embedding matrix. This struggle is what solidifies the concepts and builds the engineering mindset required to innovate.

What makes it hard

This assignment combines conceptual hurdles with non-trivial engineering. Don't underestimate it.

  • Algorithmic Complexity (BPE): A naive BPE training implementation that re-scans the corpus for every merge is computationally infeasible ($O(\text{CorpusSize} \times \text{NumMerges})$). You must devise an efficient method to track and update pair frequencies without repeated scans, which requires careful data structure design. The encoding process also has a subtle greedy logic that must be implemented deterministically.
  • Tensor Gymnastics (Transformer): The logic of multi-head attention involves a flurry of tensor reshaping, transposing, and batch matrix multiplications. Keeping track of dimensions (batch, seq_len, num_heads, head_dim) is a notorious challenge. A single incorrect transpose or view can introduce silent bugs that are hard to trace. Implementing the causal mask and RoPE correctly requires precise control over tensor indices and dimensions.
  • Numerical Stability: You are building a deep neural network where numbers can easily explode or vanish. The attention softmax, $\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$, exponentiates every logit, so it can overflow when the logits are large unless you subtract the row-wise maximum first. The $\sqrt{d_k}$ scaling factor is not just a suggestion; it's critical for stable training. Your choice of normalization (pre-norm vs. post-norm) will dramatically affect gradient flow and training stability.
  • Memory Management: A language model is a memory hog. For every parameter in your model, you must store not only its value (4 bytes for float32), but also its gradient (4 bytes), and the AdamW optimizer's state (8 bytes for the two moments). That's 16 bytes per parameter, before even considering the memory for activations, which scales with batch size and sequence length. You will quickly learn why techniques like mixed-precision training are essential.
  • Silent Bugs: The most frustrating challenge is that your code might be "wrong" but still run without errors. An incorrect attention mask, a miscalculated gradient in your optimizer, or a bug in your BPE encoding might not cause a crash, but will simply lead to a model that refuses to learn. The loss will stagnate, and you won't know why.
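The 16-bytes-per-parameter figure from the Memory Management bullet can be checked with a few lines of arithmetic. This is a minimal sketch assuming everything is kept in float32; the function name `training_state_bytes` and the 100M-parameter model are hypothetical, and activations are deliberately left out.

```python
def training_state_bytes(num_params: int, bytes_per_value: int = 4) -> dict[str, int]:
    """Persistent training state for float32 AdamW (activations excluded)."""
    weights = num_params * bytes_per_value
    grads = num_params * bytes_per_value
    adam_moments = 2 * num_params * bytes_per_value  # first and second moment
    return {
        "weights": weights,
        "grads": grads,
        "adam_moments": adam_moments,
        "total": weights + grads + adam_moments,
    }


# Hypothetical 100M-parameter model.
state = training_state_bytes(100_000_000)
print(f"{state['total'] / 1e9:.1f} GB")  # 1.6 GB
```

Activation memory comes on top of this and depends on batch size and sequence length, so treat the number as a lower bound.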

How to approach it

Tackle this assignment by building and testing components incrementally. Do not try to write everything at once.

  1. Start with the Optimizer: Implement your AdamW optimizer first. You can test it in isolation on a simple problem, like fitting a linear regression model. You can even write a unit test that performs a single step with your optimizer and asserts that the resulting parameter values match those produced by torch.optim.AdamW on the same inputs.
  2. Build the Tokenizer:
    • First, focus on the encode and decode methods. Assume you are given a vocabulary and a list of learned merges. Can you correctly apply the merges to encode a string and then decode it back? Test your round-trip capability.
    • Next, implement the train method. Start with a tiny, one-sentence corpus. Manually trace the first few merges to ensure your pair counting and update logic is correct. Then, scale up to the full training corpus.
  3. Construct the Transformer, Piece by Piece:
    • Start with the innermost components. Implement RMSNorm and the SwiGLU feed-forward block. Write small tests to verify their output shapes.
    • Tackle self-attention. First, implement a single attention head without RoPE. Get the masking and scaled dot-product right.
    • Then, wrap this in a MultiHeadAttention module. This is mostly about managing shapes and combining the outputs of the individual heads.
    • Now, add RoPE. This is a modification to the queries and keys before they are used in the attention calculation.
    • Finally, assemble a single TransformerBlock using your MHA and FFN modules, along with the residual connections and normalization.
  4. Assemble and Test the Full Model: Stack your TransformerBlocks to create the full model. Your goal is to achieve "loss-goes-down" on a single, tiny batch of data. Set up a minimal training loop with your model, a standard CrossEntropyLoss, and your custom AdamW optimizer. If you can overfit a single batch and drive the loss to near-zero, your entire pipeline is likely correct.
  5. Scale Up: Once you've successfully overfit a single batch, you can move on to training on the full dataset and trying to achieve a competitive perplexity score.
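For the "encode with a given merge list" step of the tokenizer, one possible shape of the merge loop is sketched below. It assumes merges are given as an ordered list of byte-string pairs, that input is a single pre-tokenized chunk, and that tokens are kept as byte strings rather than integer ids; `bpe_encode_word` is a hypothetical name, not the interface the handout requires.

```python
def bpe_encode_word(word: bytes, merges: list[tuple[bytes, bytes]]) -> list[bytes]:
    """Apply learned merges to one pre-tokenized chunk, lowest rank first."""
    ranks = {pair: rank for rank, pair in enumerate(merges)}
    tokens = [bytes([b]) for b in word]
    while len(tokens) > 1:
        # Find the adjacent pair that was learned earliest.
        best_pair, best_rank = None, None
        for pair in zip(tokens, tokens[1:]):
            rank = ranks.get(pair)
            if rank is not None and (best_rank is None or rank < best_rank):
                best_pair, best_rank = pair, rank
        if best_pair is None:
            break  # no applicable merge left
        # Merge every occurrence of that pair, scanning left to right.
        merged, i = [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == best_pair:
                merged.append(tokens[i] + tokens[i + 1])
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens


merges = [(b"b", b"c"), (b"a", b"b")]  # (b, c) was learned first
print(bpe_encode_word(b"abc", merges))  # [b'a', b'bc']
```

Because tokens stay as byte strings here, decoding is just `b"".join(tokens)`, which makes the round-trip test easy to write before ids are introduced.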

Common pitfalls & debugging tips

  • Incorrect Attention Mask: A classic bug. Your causal mask must prevent a position from attending to future positions. A common error is an off-by-one in the mask's dimensions or applying it incorrectly (e.g., adding 0 instead of -inf before the softmax). Tip: After computing your attention scores but before the softmax, extract a single attention matrix from your batch and print it. The upper triangle (above the diagonal) should be -inf.
  • Mismatched Tensor Shapes: Use assert statements liberally to check the shapes of your tensors at every step, especially within the attention mechanism. Comments like # shape: (batch, heads, seq_len, head_dim) are good, but assert x.shape == (b, h, s, d) is better.
  • BPE Merge Ambiguity: The encoding process must be deterministic. When multiple merges are possible (e.g., in "abc", you could merge 'ab' or 'bc'), the ambiguity is resolved by picking the merge that was learned earliest (has the lowest rank). If there's a tie, the leftmost occurrence in the sequence wins.
  • Weight Decay vs. L2 Regularization: The 'W' in AdamW stands for decoupled weight decay. It is not the same as adding an L2 regularization term to the loss. Weight decay is applied directly to the weights during the optimizer step. Make sure you implement this distinction correctly.
  • Forgetting .to(device): A simple but common error. Your model, your optimizer's state, and your data must all live on the same device (CPU or GPU).
  • The Single-Batch Test is Your Best Friend: If your model can't learn to perfectly memorize a single batch of data, it has no hope of learning from the entire dataset. If your loss isn't plummeting towards zero in this test, there is a bug in your model, loss calculation, or optimizer step function.
  • Check Gradients: If you suspect your backward pass is wrong, you can inspect the .grad attribute of your model's parameters after loss.backward(). Are they None? Are they all zero? This can point to issues like detached parts of your computation graph.
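The weight-decay distinction above is easiest to see on a single scalar parameter. The sketch below is an illustration under assumptions, not the required class: `adamw_step` is a hypothetical free function, and it applies the decay before the Adam update, which is one common ordering. The real assignment asks for a `torch.optim.Optimizer` subclass operating on tensors.

```python
import math


def adamw_step(p, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter; t is the 1-based step count."""
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the weight directly, not via the gradient.
    p = p - lr * weight_decay * p
    # Exponential moving averages of the gradient and its square.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction for the zero-initialized moments.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


p, m, v = adamw_step(1.0, grad=1.0, m=0.0, v=0.0, t=1, lr=0.1, weight_decay=0.01)
print(p)  # approximately 0.899
```

Note that with a zero gradient the parameter still shrinks by the factor `1 - lr * weight_decay`, which would not be the case in the same way if the decay were folded into the gradient and then normalized by Adam's second-moment estimate.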

Assignment 2: Systems

Official starter code & handout ↗ · due Apr 29, 2026 (2026 offering) · No solutions here — by design.

You will implement three key systems components for large-scale language modeling. First, you'll use profiling tools to benchmark a standard Transformer block and identify performance bottlenecks, particularly in the attention layer. Second, you will write a custom FlashAttention-2 forward and backward pass from scratch in Triton, a Python-based language for GPU programming. This involves implementing tiling, online softmax, and recomputation to minimize slow memory access. Finally, you will use PyTorch's Fully Sharded Data Parallelism (FSDP) to build a memory-efficient, multi-GPU training loop, allowing you to train models that wouldn't fit on a single device.

Why this is where the learning happens

Watching a lecture on GPU architecture is like reading a map; implementing a custom kernel is like navigating the terrain yourself. You will directly confront the "memory wall"—the massive gap between compute speed and memory bandwidth—that lectures can only describe. By manually managing data movement between slow HBM and fast on-chip SRAM, you will gain a physical intuition for why algorithms like FlashAttention are not just clever, but essential for modern AI. This struggle transforms abstract concepts like tiling and kernel fusion into concrete engineering skills.

What makes it hard

The primary challenge is bridging the gap between a high-level algorithm and a low-level, hardware-aware implementation.

  • Numerical Stability: The "online softmax" trick in FlashAttention requires maintaining a running maximum $m$ and a running sum of exponentials $l$. A naive implementation can easily lead to inf or NaN values. When a new block of scores $s_j$ raises the running maximum from $m_{old}$ to $m_{new}$, the update rule $l_{new} = e^{m_{old} - m_{new}} l_{old} + \sum_j e^{s_j - m_{new}}$ rescales the old sum so that every exponent stays non-positive, which prevents overflow.
  • Memory Management & Correctness: In Triton, you are responsible for calculating memory offsets for loading "tiles" of your Q, K, and V matrices. An off-by-one error in your pointer arithmetic can lead to incorrect data being loaded, silent correctness failures, or hard crashes. You must also use masks correctly to handle sequences whose lengths are not perfect multiples of your tile size.
  • Conceptual Complexity of the Backward Pass: The backward pass of attention is more complex than the forward pass. To save memory, FlashAttention recomputes attention scores from the forward pass instead of storing them. You must correctly derive and implement the gradients with respect to Q, K, and V, while re-materializing the necessary values on-the-fly within your tiled execution, all without access to the full $N \times N$ attention matrix.
  • Distributed Systems Complexity: Setting up a distributed training job with FSDP involves managing process groups and understanding how model parameters, gradients, and optimizer states are sharded across devices. A common trap is a deadlock, where one GPU waits for another in a collective operation that never completes, causing the program to hang indefinitely.
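The online-softmax bookkeeping described in the Numerical Stability bullet can be prototyped in plain Python before touching Triton. The sketch below handles a single query row with scalar values, which is a simplification chosen for illustration; the function names are hypothetical, and a real kernel would carry a vector of `(m, l)` per Q row and a matrix-valued output tile.

```python
import math


def attention_row_reference(scores, values):
    """Two-pass softmax-weighted average of scalar values."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def attention_row_online(scores, values, block_size=2):
    """Same result, visiting the scores one block at a time."""
    m = float("-inf")  # running maximum
    l = 0.0            # running sum of exp(score - m)
    o = 0.0            # running unnormalized output
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # 0.0 on the first block, since m is -inf
        p_blk = [math.exp(s - m_new) for s in s_blk]
        l = scale * l + sum(p_blk)
        o = scale * o + sum(p * v for p, v in zip(p_blk, v_blk))
        m = m_new
    return o / l


scores = [1000.0, 999.0, 1001.0, 998.0, 1002.0]  # exp() of these would overflow
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(attention_row_reference(scores, values), attention_row_online(scores, values))
```

The five scores do not divide evenly into blocks of two, so the last block is ragged; in Triton that case is what the `mask` argument to `tl.load` is for.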

How to approach it

Follow this roadmap to build your components incrementally, testing at each stage.

  1. Profile First: Before optimizing, you must measure.

    • Set up a simple benchmark for a single Transformer block using a standard PyTorch implementation.
    • Use torch.profiler to generate a trace. Identify which operations (e.g., aten::matmul, aten::softmax) consume the most GPU time. This will confirm that the attention calculation is your bottleneck and motivate the need for a custom kernel.
  2. Implement the FlashAttention Forward Pass in Triton:

    • Start with the simplest case: a "naive" attention kernel in Triton that computes attention for a single row of Q against all of K and V, without tiling the sequence length dimension. Verify its correctness against torch.nn.functional.scaled_dot_product_attention.
    • Now, implement the full tiled forward pass. In the FlashAttention-2 layout, each program instance owns one block of Q, and the loop inside the kernel iterates over blocks of the K/V sequence, processing the Q block against one K/V block per iteration.
    • Focus on the core "online softmax" state: what running statistics (the max value m and the normalization factor l) do you need to carry through the K/V loop?
    • Think carefully about how to update your output block O when you encounter a new maximum value. You'll need to rescale the existing O before adding the contribution from the new block.
    • Test your tiled implementation for numerical correctness against the PyTorch reference on various matrix sizes, especially those that are not perfectly divisible by your tile size.
  3. Implement the FlashAttention Backward Pass:

    • First, ensure you are saving the correct, minimal set of values from the forward pass to global memory for use in the backward pass (hint: you need the output O and the final softmax statistics L).
    • Structure your backward pass kernel to mirror the tiled structure of the forward pass. You will iterate through blocks of Q, K, and V to recompute the attention scores needed for the gradient calculation.
    • Derive or look up the gradients dQ, dK, and dV in terms of the upstream gradient dO, which the kernel receives as an input. Implement the calculation for dQ, dK, and dV inside your tiled loop. Remember that dK and dV are accumulated across the blocks of Q.
    • Again, test for numerical correctness at every step by comparing your kernel's output gradients to those produced by PyTorch's autograd on the reference attention function. Use torch.allclose with a reasonable tolerance.
  4. Build the Distributed Training System:

    • Start with a standard model (like the one from your previous assignments) and a training loop that runs on a single GPU. Profile its peak memory usage.
    • Wrap your model with torch.distributed.fsdp.FullyShardedDataParallel. You will need to set up a torch.distributed process group.
    • Run the FSDP-wrapped training loop on multiple GPUs. Profile the peak memory usage per GPU. You should see a significant reduction compared to the single-GPU baseline, demonstrating that you've successfully sharded the model's state.

Common pitfalls & debugging tips

  • Benchmarking Errors: Always use torch.cuda.synchronize() before starting and stopping your timer. The CPU and GPU execute asynchronously, and failing to synchronize will result in you timing only the CPU's kernel launch time, not the actual GPU execution. Also, run several warm-up iterations before measuring.
  • Triton Pointer Math: Use print() statements inside your Triton kernel (with TRITON_INTERPRET=1) to debug the offsets and pointers you are calculating. It's slow but invaluable for catching indexing errors.
  • Handling Edges: Your code must handle sequence lengths that aren't a multiple of your BLOCK_SIZE. Use the mask argument in tl.load and tl.store to prevent out-of-bounds memory access. Forgetting this will lead to garbage results or crashes.
  • Numerical Stability in Softmax: When implementing the online softmax, always subtract the current running maximum from your scores before calling tl.exp(). This prevents the exponentiation from producing inf.
  • Check Against a Reference: The most powerful debugging tool you have is a correct reference implementation. At every stage of your Triton kernel development (forward and backward), compare your output tensor against the output of torch.nn.functional.scaled_dot_product_attention using torch.allclose. The two will rarely match bit for bit, but they should agree to within a small tolerance.
  • Backward Pass State: For the backward pass, the only forward-pass results you need in global memory are the final O and L; together with the inputs Q, K, and V and the incoming gradient dO, they are sufficient. Everything else (like the S matrix) must be recomputed on the fly from Q, K, and V. If you find yourself trying to save the S matrix, you've missed the point of recomputation.
  • FSDP Hangs: If your multi-GPU training script hangs, it's almost always a deadlock. This means one or more processes have entered a collective communication call (like all_reduce) while others have not. Ensure that every rank executes the same sequence of collective calls.

Assignment 3: Scaling

Official starter code & handout ↗ · due May 6, 2026 (2026 offering) · No solutions here — by design.

You will write a suite of Python scripts to conduct a small-scale replication of the Chinchilla scaling law experiments. Your goal is to determine the compute-optimal way to train a Transformer model. You will not be building the model or trainer itself; instead, you'll use a provided API that lets you train a model of a given size (N parameters) on a given amount of data (D tokens) and returns the final validation loss.

Your deliverables will be:

  1. An Experiment Runner: A script that systematically calls the training API for various combinations of model sizes (N) and dataset sizes (D) to gather loss data.
  2. An IsoFLOP Analyzer: A tool that implements the "IsoFLOP" method. For a fixed compute budget C, this tool will train several (N, D) configurations that all cost C FLOPs and identify the optimal N that minimizes loss.
  3. A Scaling Law Fitter: A script that takes the optimal points from multiple IsoFLOP analyses and fits power laws to them. Specifically, you will find the exponents a and b in the relationships $N_{opt}(C) \propto C^a$ and $D_{opt}(C) \propto C^b$.
  4. A Predictor: A final function that uses your fitted scaling laws to predict the optimal model size and number of training tokens for a large, unseen training budget.

These components will work together to transform raw experimental data into a predictive engineering tool.

Why this is where the learning happens

Watching a lecture on scaling laws can make the concepts seem abstract and the results pre-ordained. This assignment forces you to confront the messy reality of generating the data yourself. You will gain a visceral understanding of the trade-off between model size and data by plotting the U-shaped IsoFLOP curves from your own experiments. Actually fitting the power-law functions to your noisy, hard-won data points will solidify why log-log plots are so essential and how a few small-scale runs can yield powerful predictions about massive models.

What makes it hard

The primary challenge is not in writing complex code, but in correctly orchestrating the experiments and interpreting the results.

  • Conceptual Complexity: The core of the assignment is the IsoFLOP method (Chinchilla's "Method 2"). You must internalize the compute formula, $C \approx 6ND$, and understand that for a fixed C, N and D are inversely proportional. You are searching for the minimum of a U-shaped loss curve that exists along this hyperbola.
  • Managing Experimental Sweeps: You will be running dozens of training jobs. Keeping track of the parameters (N, D, C) and their resulting loss requires careful organization. A single mislabeled data point can corrupt your analysis.
  • Curve Fitting Instability: You will likely use a library function like scipy.optimize.curve_fit to fit your scaling laws. These numerical optimizers can be sensitive. If you provide a poor functional form or bad initial guesses for the parameters, the fit can fail to converge or produce nonsensical results. The challenge is in understanding the expected shape of the data to guide the fitting process.
  • Correctness of Measurement: As the notes explain, the way you measure loss is critical. A loss value for a given (N, D) pair is only valid if it comes from a model of size N trained from scratch for D tokens with an appropriate learning rate schedule. You cannot simply take intermediate checkpoints from one long run; this will produce biased data and lead to incorrect scaling exponents.

How to approach it

Follow this roadmap to build your solution incrementally and ensure each part is working before moving to the next.

  1. Master the Training API: Before anything else, figure out how to run a single experiment. Call the provided training function with a fixed N and D and verify that you can successfully retrieve the final validation loss.

  2. Run a Simple Scaling Sweep: To build intuition, start with a one-dimensional analysis.

    • Fix a model size N. Run experiments for several different data sizes D.
    • Plot log(loss) vs. log(D). Do you see a roughly linear trend? This is your first scaling law!
  3. Implement the IsoFLOP Analysis for a Single Budget: This is the heart of the assignment.

    • Pick one, medium-sized compute budget C (e.g., 1e21 FLOPs).
    • Write a helper function that generates 5-7 pairs of (N, D) that all satisfy the compute constraint $C = 6ND$. Choose your N values to be spaced logarithmically.
    • Run the training for each of these pairs and save the results.
    • Plot loss vs. N (with N on a log scale). You should see a U-shaped curve.
    • Find the optimal model size, N_opt, that corresponds to the minimum of this curve.
  4. Scale Up the Analysis: Now, repeat the process from step 3 for several different compute budgets C. Choose budgets that are an order of magnitude apart (e.g., 1e20, 1e21, 1e22 FLOPs). At the end of this step, you should have a list of optimal points: (C_1, N_opt(C_1), D_opt(C_1)), (C_2, N_opt(C_2), D_opt(C_2)), etc.

  5. Fit the Final Power Laws: With your list of optimal points, you can now determine the scaling exponents.

    • Plot log(N_opt) vs. log(C). The data should look like a line. Perform a linear regression to find the slope, which is your exponent a.
    • Do the same for log(D_opt) vs. log(C) to find the slope b.
  6. Make Your Prediction: Use the fitted laws ($N_{opt} \propto C^a$, $D_{opt} \propto C^b$) to calculate the optimal N and D for the final, large compute budget specified in the assignment prompt.
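Steps 3 and 5 of the roadmap reduce to two small helpers: one that generates (N, D) pairs on a fixed-compute hyperbola, and one that fits a line in log-log space. The sketch below is stdlib-only and uses hypothetical names (`isoflop_pairs`, `fit_power_law`); the data fed to the fit is synthetic, generated from a made-up law purely to show that the slope is recovered.

```python
import math


def flops(n_params: float, n_tokens: float) -> float:
    """Training compute under the C ~ 6*N*D approximation."""
    return 6.0 * n_params * n_tokens


def isoflop_pairs(budget: float, n_min: float, n_max: float, count: int = 6):
    """Log-spaced model sizes N, each paired with the D that spends the budget."""
    pairs = []
    for i in range(count):
        n = n_min * (n_max / n_min) ** (i / (count - 1))
        pairs.append((n, budget / (6.0 * n)))
    return pairs


def fit_power_law(xs, ys):
    """Least-squares line in log-log space; returns (a, k) for y = k * x**a."""
    lx = [math.log(x) for x in xs]
    ly = [math.log(y) for y in ys]
    mx, my = sum(lx) / len(lx), sum(ly) / len(ly)
    slope = sum((a - mx) * (b - my) for a, b in zip(lx, ly)) / sum((a - mx) ** 2 for a in lx)
    return slope, math.exp(my - slope * mx)


for n, d in isoflop_pairs(1e21, 1e8, 1e10):
    assert math.isclose(flops(n, d), 1e21, rel_tol=1e-9)

# Synthetic optimal points from a made-up law N_opt = 0.1 * C**0.5.
budgets = [1e20, 1e21, 1e22]
n_opt = [0.1 * c ** 0.5 for c in budgets]
a, k = fit_power_law(budgets, n_opt)
print(a, k)
```

In a real sweep N and D would be rounded to whatever sizes the training API accepts, so the assertion on `6 * N * D` should use a looser tolerance there.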

Common pitfalls & debugging tips

  • "My IsoFLOP curve isn't U-shaped!"

    • Check your range of N values for a fixed C. If they are too close together, you may only see one side of the "U". Your N values should span at least two orders of magnitude.
    • Verify your compute calculation. For each (N, D) pair, assert that 6 * N * D is very close to your target C. A simple mistake here is a common source of error.
  • Curve fitting is failing or giving weird results.

    • Don't fit the complex Chinchilla loss function directly unless you have to. For finding the exponents a and b, it's much more robust to work in log-space and fit a simple line (y = mx + c).
    • If you must fit a non-linear function, provide good initial guesses (p0) to the optimizer. For example, the irreducible error term E_ir must be lower than any loss you've observed.
  • "My predicted exponents are very different from Chinchilla's."

    • This is possible! Your small-scale experiment may have different characteristics. However, first rule out errors.
    • The most critical error to avoid: Do not use intermediate checkpoints from a single long run. Each (N, D) data point must come from its own independent training run to be valid.
    • To find the minimum of the U-shaped IsoFLOP curve, don't just take the lowest point you sampled. It's more accurate to fit a simple parabola to the (log(N), loss) points and calculate the true minimum of that parabola. This smooths out experimental noise.
  • Forgetting the constants.

    • The compute formula is $C \approx 6ND$. Don't forget the 6. Write a helper function for this calculation and use it everywhere to ensure consistency.
  • Plotting on the wrong scale.

    • Scaling laws appear as straight lines on log-log plots. If your data looks curved, double-check that both your x and y axes are logarithmic. Use plt.loglog() or np.log() on your data before passing it to plt.plot().
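The parabola-fitting tip above can be sketched without any dependencies by solving the 3x3 normal equations directly. The names `isoflop_minimum` and `_solve3` are hypothetical, the fit is done in centered log10(N) for better conditioning, and the sample losses are synthetic. In practice a library least-squares routine would do the same job in one call.

```python
import math


def _solve3(A, rhs):
    """Solve a 3x3 linear system by Gaussian elimination with partial pivoting."""
    M = [row[:] + [r] for row, r in zip(A, rhs)]
    for col in range(3):
        pivot = max(range(col, 3), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, 3):
            f = M[r][col] / M[col][col]
            for c in range(col, 4):
                M[r][c] -= f * M[col][c]
    x = [0.0, 0.0, 0.0]
    for r in (2, 1, 0):
        x[r] = (M[r][3] - sum(M[r][c] * x[c] for c in range(r + 1, 3))) / M[r][r]
    return x


def isoflop_minimum(n_values, losses):
    """Fit loss = a*u**2 + b*u + c with u = centered log10(N); return (N_opt, min loss)."""
    xs = [math.log10(n) for n in n_values]
    x0 = sum(xs) / len(xs)
    us = [x - x0 for x in xs]
    s = [sum(u ** k for u in us) for k in range(5)]                       # power sums
    t = [sum((u ** k) * y for u, y in zip(us, losses)) for k in range(3)]  # moments with loss
    a, b, c = _solve3(
        [[s[4], s[3], s[2]], [s[3], s[2], s[1]], [s[2], s[1], s[0]]],
        [t[2], t[1], t[0]],
    )
    if a <= 0:
        raise ValueError("fitted curve is not U-shaped; widen the range of N")
    u_min = -b / (2 * a)
    return 10 ** (x0 + u_min), c - b * b / (4 * a)


# Synthetic U-shaped curve whose true minimum sits at log10(N) = 8.2.
ns = [10 ** e for e in (7, 7.5, 8, 8.5, 9)]
ls = [0.3 * (math.log10(n) - 8.2) ** 2 + 2.5 for n in ns]
print(isoflop_minimum(ns, ls))
```

The explicit check on the sign of `a` catches the "only one side of the U" situation described earlier instead of silently returning a maximum.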

Assignment 4: Data

Official starter code & handout ↗ · due May 20, 2026 (2026 offering) · No solutions here — by design.

Before you start — lectures: Lec 13: Data 1 · Lec 14: Data 2

You will construct a multi-stage data processing pipeline to transform raw, messy web data from Common Crawl into a clean, deduplicated, and filtered dataset suitable for pretraining a language model. This is not a single program, but a series of connected components that act like a factory assembly line for data.

Your pipeline will consist of these core stages:

  1. Extractor: A tool that reads raw web data in WARC (Web ARChive) format, parses the HTML, and extracts the main textual content, discarding boilerplate like navigation bars, ads, and scripts.
  2. Filter: A module that applies a series of cleaning and filtering rules. This includes language identification to keep only English documents and quality heuristics to remove low-quality or nonsensical text (e.g., documents with too little content, repetitive phrases, or "lorem ipsum" placeholder text).
  3. Deduplicator: A scalable system to identify and remove near-duplicate documents. This is the most complex component, as a naive document-to-document comparison is computationally impossible at this scale. You will implement an efficient, approximate algorithm to find these duplicates.

The final output will be a collection of clean text files, significantly smaller but much higher in quality than the original raw data, ready to be tokenized and fed into a model.

Why this is where the learning happens

Lectures can introduce you to concepts like "C4" or "RefinedWeb," but building your own data pipeline forces you to confront the messy reality they abstract away. You will learn firsthand that "web data" is not a clean corpus but a chaotic mix of high-quality prose, machine-generated spam, and everything in between. The true challenge—and the real learning—is in making principled, pragmatic trade-offs between processing speed, memory usage, and data quality at a scale where brute-force solutions fail.

What makes it hard

This assignment combines data engineering, algorithmic thinking, and heuristic design. The primary challenges are not in the complexity of any single line of code, but in making the entire system work correctly and efficiently at scale.

  • Scale and Memory: A single Common Crawl dump is many terabytes. Even a small fraction is too large to fit into a single machine's memory. You cannot load the whole dataset to process it; you must design a streaming pipeline where documents are processed one by one or in small batches, without assuming you can see all the data at once.
  • Algorithmic Complexity of Deduplication: The core of deduplication is finding pairs of similar documents. A naive approach would compare every document with every other document, an $O(N^2)$ operation that is computationally infeasible for millions of documents. The key challenge is to replace this with a near-linear time algorithm. This requires understanding and implementing techniques like:
    • Shingling: Representing a document as a set of n-grams (its "shingles").
    • MinHash: Creating a small, fixed-size "signature" from a document's arbitrarily large set of shingles, with the special property that, for each hash function, the probability that two documents share the same minimum hash value equals their Jaccard similarity: $P(h_{min}(A) = h_{min}(B)) = \frac{|A \cap B|}{|A \cup B|}$. The fraction of matching signature positions therefore estimates that similarity.
    • Locality Sensitive Hashing (LSH): A technique to find documents with similar MinHash signatures without comparing all pairs. It works by hashing signatures into buckets such that similar signatures are likely to land in the same bucket.
  • Heuristics and "Correctness": There is no single "correct" way to clean web text. HTML-to-text extraction is inherently lossy. Quality filters (e.g., "must have 5 sentences" or "must not contain bad words") are brittle heuristics. You will constantly face design decisions with no perfect answer, forcing you to define what "quality" means for your dataset and accept the trade-offs. For example, a filter that removes pages with { might improve quality by removing code, but it also prevents you from training a model that can code.
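To make the shingling and MinHash bullets concrete, here is a small stdlib-only sketch. It assumes word-level 3-gram shingles and simulates a family of hash functions by salting `hashlib.blake2b` with the hash index; both choices, and all function names, are illustrative assumptions rather than what the handout prescribes.

```python
import hashlib


def shingles(text: str, n: int = 3) -> set[str]:
    """Set of word n-grams; very short texts yield a single shingle."""
    words = text.lower().split()
    return {" ".join(words[i:i + n]) for i in range(max(len(words) - n + 1, 1))}


def minhash_signature(shingle_set: set[str], num_hashes: int = 128) -> list[int]:
    """For each salted hash function, keep the minimum hash over the set."""
    signature = []
    for seed in range(num_hashes):
        salt = seed.to_bytes(8, "little")  # distinct salt = distinct hash function
        signature.append(min(
            int.from_bytes(
                hashlib.blake2b(s.encode("utf-8"), digest_size=8, salt=salt).digest(), "big"
            )
            for s in shingle_set
        ))
    return signature


def estimated_jaccard(sig_a: list[int], sig_b: list[int]) -> float:
    """Fraction of signature positions that agree."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


def jaccard(a: set[str], b: set[str]) -> float:
    return len(a & b) / len(a | b)


doc_a = "the quick brown fox jumps over the lazy dog near the river bank today"
doc_b = "the quick brown fox jumps over the lazy dog near the river bank tonight"
doc_c = "stock markets closed higher after the central bank held interest rates steady"

sa, sb, sc = shingles(doc_a), shingles(doc_b), shingles(doc_c)
print("true:", jaccard(sa, sb), "estimated:", estimated_jaccard(minhash_signature(sa), minhash_signature(sb)))
print("unrelated estimate:", estimated_jaccard(minhash_signature(sa), minhash_signature(sc)))
```

The estimate is noisy for short signatures; increasing `num_hashes` tightens it at the cost of more hashing per document.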

How to approach it

Tackle this assignment by building and testing your pipeline incrementally. Do not attempt to run on a large dataset until you are confident in each component.

  1. Start Small: Begin with a single, small WARC file. Your first goal is to write a script that can read this one file, extract plain text from the HTML payloads, and print it. Don't worry about quality yet; just get text out. Manually inspect the output. Does it contain navigation links, cookie banners, or JavaScript?
  2. Build the Filter Stage:
    • First, integrate a language identification library (like a pre-trained fastText model) into your script. Process your single WARC file and verify that you can correctly label and filter out non-English documents.
    • Next, add a series of simple, rule-based quality filters inspired by datasets like C4. For example, filter out documents that are too short, have a low ratio of alphabetic characters to total characters, or contain placeholder text. Test each heuristic on your sample file and observe what it discards.
  3. Implement the Deduplication Engine (On a Small Scale): This is the most challenging part. Break it down.
    • Shingling & MinHashing: Write a function that takes a single document's text and produces its MinHash signature. To test this, create two nearly identical documents and one very different one. Verify that the MinHash signatures of the similar documents are themselves similar, while the third is different.
    • LSH for Candidate Pairs: Implement the LSH banding technique. Your goal is a function that takes a batch of MinHash signatures and outputs candidate pairs for full comparison. The key insight is that you only need to compare documents that hash to the same bucket in at least one LSH band. This is what lets you avoid the $O(N^2)$ nightmare.
    • End-to-End Deduplication: Connect the pieces. For a small collection of documents (e.g., from one WARC file), generate all MinHash signatures, use LSH to find candidate pairs, and then for those pairs only, compute the true Jaccard similarity to make a final duplication decision.
  4. Integrate and Scale: Combine your extractor, filter, and deduplicator into a single pipeline. A major engineering challenge here is managing the state needed for deduplication. You can't hold all MinHash signatures in memory. How can you process files sequentially while still comparing a new document to all previous documents? Think about how LSH bucketing can be implemented in a streaming or multi-stage fashion. Once your pipeline is robust on one file, test it on ten, then one hundred, refining your approach to handle the scale.
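The LSH banding step from the deduplication roadmap can be sketched as a bucket-then-pair pass. This version keeps all buckets in an in-memory dictionary, which is an assumption that only holds for small collections; `lsh_candidate_pairs` is a hypothetical name, and the hand-written signatures stand in for real MinHash output.

```python
from collections import defaultdict
from itertools import combinations


def lsh_candidate_pairs(
    signatures: dict[str, list[int]], bands: int, rows: int
) -> set[tuple[str, str]]:
    """Pairs of doc ids whose signatures agree on every row of at least one band."""
    buckets = defaultdict(list)
    for doc_id, sig in signatures.items():
        if len(sig) != bands * rows:
            raise ValueError("signature length must equal bands * rows")
        for band in range(bands):
            # The band index is part of the key so bands never share buckets.
            key = (band, tuple(sig[band * rows:(band + 1) * rows]))
            buckets[key].append(doc_id)
    candidates = set()
    for doc_ids in buckets.values():
        for a, b in combinations(sorted(doc_ids), 2):
            candidates.add((a, b))
    return candidates


sigs = {
    "d1": [1, 2, 3, 4],
    "d2": [1, 2, 9, 9],  # shares the first band with d1
    "d3": [7, 7, 7, 7],
}
print(lsh_candidate_pairs(sigs, bands=2, rows=2))  # {('d1', 'd2')}
```

Only the pairs this returns need a true Jaccard comparison. For a streaming variant, one option is to write `(band key, doc id)` records to disk and group them by key in a second pass.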

Common pitfalls & debugging tips

  • My deduplication is too slow. You are almost certainly falling back to a pairwise comparison. The entire point of LSH is to avoid this. Ask yourself: "Am I iterating through every document to compare it with the current one?" If so, you need to rethink your LSH implementation. Your goal is to only retrieve documents from a specific hash bucket.
  • My memory usage explodes. You are trying to keep too much data in memory. Are you storing the full text of every document? The set of shingles for every document? You should only need to keep the compact MinHash signatures (and their LSH hashes) in a central data structure. For exact-duplicate detection (a useful first pass), consider using a memory-efficient probabilistic data structure like a Bloom Filter.
  • Deduplication finds too many/too few duplicates. This is a tuning problem. The "sensitivity" of your LSH is controlled by the number of hash functions per signature ($k$), the number of bands ($B$), and the number of rows per band ($R$). What is the relationship between these parameters and the Jaccard similarity threshold at which two documents are likely to be flagged as a candidate pair? Try to reason about the probability $1 - (1 - s^R)^B$ without just plugging in numbers. If you're finding too few duplicates, you need to make it easier for documents to collide in a bucket (e.g., fewer rows per band).
  • The extracted text quality is poor. HTML parsing is messy. Don't try to write your own parser with regex. Use a robust library like trafilatura. Even then, inspect your "bad" outputs. Is the library failing to find the main content block? Is it grabbing comment sections? You may need to add post-processing heuristics on top of the library's output.
  • Sanity-check your components in isolation. Before building the full pipeline, test each part on a tiny, handcrafted dataset of 5-10 documents. For example, create doc1.txt, doc2.txt (a near-duplicate of 1), doc3.txt (in another language), doc4.txt ("asdf qwer"), and doc5.txt (a good, unique document). Run them through each stage and assert that the output is exactly what you expect. This will save you hours of debugging on a multi-gigabyte corpus.
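
Once you have reasoned about the shape of $1 - (1 - s^R)^B$ by hand, a few lines of code can confirm your intuition. The sketch below is only a calculator for that formula; the function names are hypothetical, and the threshold estimate $(1/B)^{1/R}$ is the usual textbook approximation for where the curve rises most steeply, not an exact cutoff.

```python
def candidate_probability(s: float, bands: int, rows: int) -> float:
    """Probability that two docs with Jaccard similarity s share a bucket
    in at least one of `bands` bands of `rows` rows each."""
    return 1.0 - (1.0 - s ** rows) ** bands


def approx_threshold(bands: int, rows: int) -> float:
    """Rule-of-thumb similarity at which the S-curve rises most steeply."""
    return (1.0 / bands) ** (1.0 / rows)


def s_curve(bands: int, rows: int, steps: int = 10) -> list[tuple[float, float]]:
    """(similarity, candidate probability) points for a quick printout."""
    return [
        (i / steps, candidate_probability(i / steps, bands, rows))
        for i in range(steps + 1)
    ]
```

Printing s_curve for two settings with the same number of bands but different rows per band makes the tuning direction from the bullet above visible: fewer rows per band shifts the curve left, so lower-similarity pairs start colliding.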

Assignment 5: Alignment & Reasoning RL

Official starter code & handout ↗ · due Jun 3, 2026 (2026 offering) · No solutions here — by design.

You will implement a pipeline to improve a language model's mathematical reasoning ability. This involves two main stages, mirroring the modern alignment process:

  1. Supervised Fine-Tuning (SFT) Model: You'll start with a pre-trained base model and fine-tune it on a dataset of math problems and their correct solutions (expert demonstrations). This creates a specialized model, SFT_model, that is better at math than the base model. This SFT_model will also serve as the crucial reference policy ($\pi_{ref}$) in the next stage.
  2. Direct Preference Optimization (DPO) Model: You will take your SFT_model and further train it using a preference dataset. This dataset contains pairs of responses to a math problem: one "winning" response ($y_w$) and one "losing" response ($y_l$). You will implement the DPO algorithm to train the model to prefer generating winners over losers. The final output is a DPO_model that is aligned with the preference data, hopefully making it an even better math reasoner.

The core deliverable is the implementation of the SFT and DPO training loops and the resulting models, evaluated on their ability to solve math problems.

Why this is where the learning happens

Lectures explain that DPO is a clever way to do reinforcement learning with a simple classification loss, but only by implementing it do you truly grasp the mechanics. You will confront the engineering reality of managing multiple models (the policy and the frozen reference), calculating sequence-level log-probabilities, and debugging a process where the "right" answer isn't just a label, but a preference. This struggle forces you to connect the abstract math of the DPO derivation to concrete PyTorch operations, which is where deep understanding is forged.

What makes it hard

This assignment moves beyond standard supervised learning into the more complex world of policy optimization. The challenges are both conceptual and technical.

  • Managing Multiple Models: The DPO algorithm requires simultaneous access to the policy you are training ($\pi_\theta$) and a frozen reference policy ($\pi_{ref}$, your SFT model). Keeping these two models in memory, ensuring one is frozen while the other trains, and correctly routing data through them is a significant engineering hurdle.
  • Correctly Calculating Log-Probabilities: The DPO loss is a function of the log-probabilities of entire sequences. You need to write a function that, given a model and a sequence of tokens, correctly computes $\log P(y|x) = \sum_{i} \log P(y_i | y_{<i}, x)$.
  • The DPO Loss Itself: The loss function involves a difference of log-probability ratios. The full expression is: $$ \mathcal{L}_{DPO} = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma\left(\beta \log\frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log\frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)}\right) \right] $$ Translating this equation into efficient, stable code is the main challenge. You must perform four forward passes (policy on $y_w$, policy on $y_l$, reference on $y_w$, reference on $y_l$) and combine the resulting log-probabilities correctly.
  • Debugging is Non-Intuitive: When a standard classification model fails, you can inspect the labels. Here, the "label" is a preference. If your DPO loss isn't decreasing, is it a bug in your log-prob calculation, an issue with the reference model, or a poor choice of the hyperparameter $\beta$? The feedback loop is much less direct than in SFT.
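
To make the loss formula above concrete, here is a minimal PyTorch sketch of the final combination step only. It assumes you already have the four summed completion log-probabilities as tensors of shape (batch,); the function name dpo_loss, the argument names, and the default beta are hypothetical choices for illustration, not values from the handout.

```python
import torch
import torch.nn.functional as F


def dpo_loss(
    policy_logp_w: torch.Tensor,
    policy_logp_l: torch.Tensor,
    ref_logp_w: torch.Tensor,
    ref_logp_l: torch.Tensor,
    beta: float = 0.1,  # assumed placeholder; tune for your setup
) -> torch.Tensor:
    """Mean DPO loss over a batch of (winner, loser) pairs.

    Each input holds log pi(y|x) summed over completion tokens, shape (batch,).
    """
    # log(pi_theta / pi_ref) is a difference of log-probabilities.
    log_ratio_w = policy_logp_w - ref_logp_w
    log_ratio_l = policy_logp_l - ref_logp_l
    margin = beta * (log_ratio_w - log_ratio_l)
    # logsigmoid avoids the underflow of log(sigmoid(x)) for very negative x.
    return -F.logsigmoid(margin).mean()
```

In a real training step the two reference inputs would be computed under torch.no_grad(), so that gradients flow only through the policy terms.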

How to approach it

Follow this sequence. Do not move to the next step until the current one is working and verified.

  1. Implement and Verify SFT: First, focus entirely on Supervised Fine-Tuning. Take the base model and fine-tune it on the provided math problems and their solutions. Your goal is to create a model that is demonstrably better at math than the base model. Before moving on, generate some samples from your SFT model and evaluate its accuracy on a held-out test set. This SFT model is your deliverable for the first part and becomes the essential $\pi_{ref}$ for the second.
  2. Build the DPO Data Pipeline: Create a Dataset and DataLoader for the preference data. Each item in your batch should contain a tokenized prompt, a tokenized winning response, and a tokenized losing response. Pay close attention to padding and creating the correct attention masks for each of the three components.
  3. Implement the Log-Probability Helper: Write and isolate a helper function that takes a model, a batch of prompts, and a batch of corresponding completions, and returns the sum of the log-probabilities for each completion in the batch. Test this function thoroughly. Can it handle variable-length sequences? Does it correctly ignore the prompt tokens in its calculation?
  4. Assemble the DPO Training Step: With your data pipeline and log-prob helper ready, you can now write the main DPO training logic.
    • Load your SFT model twice: once as the policy model to be trained, and once as the reference model (remember to freeze this one!).
    • In your training loop, for each batch, get the log-probabilities for both the winning and losing responses from both the policy and reference models.
    • Use these four sets of log-probabilities to compute the DPO loss according to the formula.
    • Perform backpropagation and update the policy model's weights.
  5. Train and Evaluate: Run the DPO training. Monitor not only the DPO loss but also the model's math accuracy on the validation set. A decreasing DPO loss means you are successfully learning the preferences, but only an increase in math accuracy tells you that this alignment is achieving the ultimate goal.
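
The log-probability helper from step 3 is where most of the indexing mistakes happen, so a sketch of its core may help. This version assumes your model has already produced next-token logits of shape (batch, seq_len, vocab) for the concatenated prompt and completion, and that you have built a completion_mask that is 1 on completion tokens and 0 on prompt and padding tokens. The name sequence_logprob and the mask convention are assumptions for this example.

```python
import torch
import torch.nn.functional as F


def sequence_logprob(
    logits: torch.Tensor,           # (batch, seq_len, vocab)
    input_ids: torch.Tensor,        # (batch, seq_len), prompt + completion
    completion_mask: torch.Tensor,  # (batch, seq_len), 1 on completion tokens
) -> torch.Tensor:
    """Sum of log P(y_i | y_<i, x) over completion tokens, shape (batch,)."""
    # Logits at position t predict the token at position t + 1, so drop the
    # last logit and the first token to align predictions with targets.
    log_probs = F.log_softmax(logits[:, :-1, :].float(), dim=-1)
    targets = input_ids[:, 1:]
    # Pick out the log-probability assigned to each actual next token.
    token_logp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    # Shift the mask the same way so prompt and padding contribute zero.
    mask = completion_mask[:, 1:].to(token_logp.dtype)
    return (token_logp * mask).sum(dim=-1)
```

A cheap way to test a helper like this is to feed it uniform logits: every token then has the same known log-probability, so the expected sum depends only on how many completion tokens each row contains.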

Common pitfalls & debugging tips

  • Reference model is not actually frozen. Your reference model's parameters should not be passed to the optimizer, its forward passes should run without gradient tracking (e.g., under torch.no_grad()), and you should call .eval() on it so that dropout is disabled. A simple check: verify that the weights of your reference model are not changing after a training step.
  • Incorrect log-probability calculation. This is the most common bug. The log-probability should be calculated only for the tokens in the completion (y), not the prompt (x). Ensure your attention mask correctly separates the two and that you are summing the log-probs of the correct tokens.
  • Forgetting the beta hyperparameter. The DPO loss is scaled by a temperature parameter beta. If you forget it, or set it incorrectly, your training may be unstable or ineffective. What does a very high or very low beta imply about the loss?
  • Sanity-check your initial loss. Before training, the policy model ($\pi_\theta$) is identical to the reference model ($\pi_{ref}$). Therefore, the probability ratios should be 1, and their logs should be 0. The argument to the sigmoid function will be 0. What should the value of $-\log(\sigma(0))$ be? If your initial loss is not close to this value, something is wrong in your loss calculation.
  • Overfit on a single batch. A classic debugging technique. Take a single batch of preference pairs and train on it for many iterations. Your DPO loss should go to zero (or very close). If it doesn't, you have a bug in your gradient flow or loss computation.
  • Monitor the components of the loss. Don't just watch the final loss value. Log the average log-prob ratios for winners and losers separately: $\log \pi_\theta(y_w|x) - \log \pi_{ref}(y_w|x)$ and $\log \pi_\theta(y_l|x) - \log \pi_{ref}(y_l|x)$. As training progresses, you should see the winner term rise above the loser term, ideally becoming positive for winners and negative for losers. A widening gap confirms the model is learning to increase the relative probability of preferred responses.
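
The "is the reference really frozen?" check from the first bullet can be automated with a few small helpers. This is a sketch under assumptions: freeze, snapshot, and unchanged are hypothetical names, and any torch.nn.Module can stand in for your SFT model while you test the mechanics.

```python
import torch
from torch import nn


def freeze(model: nn.Module) -> nn.Module:
    """Put a model in eval mode and stop gradient tracking on its weights."""
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    return model


def snapshot(model: nn.Module) -> list[torch.Tensor]:
    """Detached copies of all parameters, for before/after comparison."""
    return [p.detach().clone() for p in model.parameters()]


def unchanged(model: nn.Module, snap: list[torch.Tensor]) -> bool:
    """True if every parameter still exactly equals its snapshot."""
    return all(torch.equal(p, q) for p, q in zip(model.parameters(), snap))
```

Take a snapshot of both models before one optimizer step, then assert that the reference is unchanged and the policy is not. If the policy is also unchanged, the bug is in your gradient flow rather than in the freezing.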