Mastering Large‑Model Interview Questions: MHA, KV‑Cache, Scaled Dot‑Product, and Speculative Decoding

This guide walks through common large‑model interview challenges, including a hands‑on implementation of multi‑head attention with KV‑cache, the mathematical reason for scaling by sqrt(dₖ), a concise speculative decoding algorithm, and systematic debugging steps for NaN loss during training.

Wu Shixiong's Large Model Academy

The article begins with a practical interview question: implement multi‑head attention (MHA) with KV‑cache. It explains that MHA is essentially multiple copies of single‑head attention whose outputs are concatenated. Each head has its own Q, K, V matrices, analogous to different focus areas in a meeting.

Core Idea Decomposition

Head 1 focuses on technical details.

Head 2 focuses on business logic.

Head 3 focuses on timing.

Head 4 focuses on resource allocation.

Each head maintains its own Q, K, V matrices.

Simplified Code Implementation

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.cache = {}

    def forward(self, query, key, value, mask=None, use_cache=False, cache_key="default"):
        batch_size = query.size(0)
        seq_len = query.size(1)
        # 1. Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        # 2. Reshape for multiple heads
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # 3. KV‑cache logic (interview bonus)
        if use_cache and cache_key in self.cache:
            cached_K, cached_V = self.cache[cache_key]
            K = torch.cat([cached_K, K], dim=2)
            V = torch.cat([cached_V, V], dim=2)
        if use_cache:
            self.cache[cache_key] = (K, V)
        # 4. Scaled dot‑product attention
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_weights = torch.softmax(attention_scores, dim=-1)
        attended_values = torch.matmul(attention_weights, V)
        # 5. Concatenate heads
        attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        # 6. Final linear projection
        output = self.W_o(attended_values)
        return output, attention_weights
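To see why reusing cached K/V is safe, here is a minimal single‑head sketch (random weights and illustrative shapes, not the learned projections from the class above) showing that the incremental cached path reproduces a full recompute exactly:

```python
import math
import torch

torch.manual_seed(0)
d_k = 8
W_k, W_v = torch.randn(d_k, d_k), torch.randn(d_k, d_k)  # stand-ins for W_K, W_V

history = torch.randn(10, d_k)   # embeddings of 10 already-generated tokens
new_tok = torch.randn(1, d_k)    # embedding of the newest token

# Without a cache: re-project K/V for the entire sequence at every step.
full = torch.cat([history, new_tok])
K_full, V_full = full @ W_k, full @ W_v

# With a cache: old K/V were stored earlier; only the new token is projected.
K_cache, V_cache = history @ W_k, history @ W_v
K_inc = torch.cat([K_cache, new_tok @ W_k])
V_inc = torch.cat([V_cache, new_tok @ W_v])
assert torch.allclose(K_full, K_inc) and torch.allclose(V_full, V_inc)

# Attention is only needed for the new position.
scores = (new_tok @ K_inc.T) / math.sqrt(d_k)
out = torch.softmax(scores, dim=-1) @ V_inc
print(out.shape)  # torch.Size([1, 8])
```

The saving grows with sequence length: each decode step projects one token instead of the whole history.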

Why Scale by sqrt(dₖ) in Transformer Attention?

The dot‑product of Q and K has variance proportional to dₖ. Without scaling, large dₖ makes the variance huge, causing the softmax to produce extremely peaked distributions (most probability mass on a single token) and leading to gradient vanishing. Dividing by sqrt(dₖ) normalizes the variance to 1, preserving softmax diversity, stabilizing training, and preventing gradient collapse. The article demonstrates this with a Python script that prints score ranges and softmax extremes for different dₖ values, showing how unscaled scores become extreme as dₖ grows.
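A small script in that spirit (random Q/K rows and illustrative dₖ values, not the article's exact code) makes the effect visible: the raw dot‑product variance grows with dₖ, while the scaled variance stays near 1, and large scores drive the softmax toward a one‑hot distribution.

```python
import math
import torch

torch.manual_seed(0)
for d_k in (16, 64, 256):
    q, k = torch.randn(100_000, d_k), torch.randn(100_000, d_k)
    raw = (q * k).sum(dim=-1)          # one dot product per row
    scaled = raw / math.sqrt(d_k)
    print(f"d_k={d_k:4d}  var(raw)={raw.var().item():7.1f}  "
          f"var(scaled)={scaled.var().item():.3f}")

# Unscaled-magnitude scores saturate the softmax; scaled ones stay diverse.
scores = torch.tensor([[12.0, 0.0, -12.0]])
print(torch.softmax(scores, dim=-1))                   # nearly one-hot
print(torch.softmax(scores / math.sqrt(256), dim=-1))  # much flatter
```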

Speculative Decoding Explained

Speculative decoding speeds up generation by letting a small draft model predict a batch of candidate tokens, which the large target model then verifies in parallel. The workflow consists of:

Draft stage – the small model quickly generates k tokens.

Verify stage – the large model evaluates all k tokens in a single forward pass.

Accept/Reject – a drafted token is accepted outright when its probability under the large model meets or exceeds the draft model's probability; otherwise it is kept with probability p/q. On rejection, a replacement token is resampled from the corrected distribution and the remaining drafts are discarded.

This parallel verification replaces up to k sequential large‑model forward passes with a single one, often yielding 1.5×–3× speedups depending on the model gap, task difficulty, and the chosen k.
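A back‑of‑the‑envelope for that speedup: if each drafted token is accepted independently with probability α (the simplifying assumption used in the standard speculative sampling analysis), the expected number of tokens emitted per large‑model pass is the geometric series (1 − α^(k+1)) / (1 − α):

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    """Expected tokens per target-model forward pass: sum of alpha^i for i=0..k."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)

# With k=4 drafts and an 80% acceptance rate, each large-model pass
# yields about 3.36 tokens instead of 1.
print(expected_tokens_per_pass(0.8, 4))  # ≈ 3.36
```

Whether that translates into wall‑clock gains still depends on how cheap the draft model is relative to the target model.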

Simplified Speculative Decoding Code

import random

import torch

def speculative_decoding(draft_model, target_model, input_ids, k=4, max_new_tokens=32):
    """Simplified speculative decoding. Assumes each model maps a 1-D token
    tensor to per-position logits of shape [seq_len, vocab_size]."""
    output = input_ids
    while output.size(-1) - input_ids.size(-1) < max_new_tokens:
        # Draft stage: the small model proposes k tokens autoregressively.
        draft_tokens, draft_dists = [], []
        draft_input = output
        for _ in range(k):
            with torch.no_grad():
                probs = torch.softmax(draft_model(draft_input)[-1], dim=-1)
            token = torch.multinomial(probs, 1)
            draft_tokens.append(token)
            draft_dists.append(probs)
            draft_input = torch.cat([draft_input, token], dim=-1)
        # Verify stage: one forward pass of the large model covers all k drafts.
        with torch.no_grad():
            target_logits = target_model(draft_input)
        base = output.size(-1)
        for i in range(k):
            # Accept/Reject: position base + i - 1 predicts the (base + i)-th token.
            p_dist = torch.softmax(target_logits[base + i - 1], dim=-1)
            p = p_dist[draft_tokens[i]].item()
            q = draft_dists[i][draft_tokens[i]].item()
            if random.random() < min(1.0, p / q):  # accept with prob min(1, p/q)
                output = torch.cat([output, draft_tokens[i]], dim=-1)
            else:
                # Rejected: resample from the corrected distribution max(0, p - q),
                # then discard the remaining drafts.
                residual = torch.clamp(p_dist - draft_dists[i], min=0)
                token = torch.multinomial(residual / residual.sum(), 1)
                output = torch.cat([output, token], dim=-1)
                break
        else:
            # All k drafts accepted: take one bonus token from the target model.
            bonus = torch.softmax(target_logits[-1], dim=-1)
            output = torch.cat([output, torch.multinomial(bonus, 1)], dim=-1)
    return output

Debugging NaN Loss in Large‑Model Training

The article lists four common causes of NaN loss and provides concrete diagnostic code:

Gradient explosion – monitor gradient norms and clip them with torch.nn.utils.clip_grad_norm_().

Learning rate too high – start with a small LR and use schedulers like torch.optim.lr_scheduler.LinearLR.

Numerical overflow/underflow – employ mixed‑precision training with torch.cuda.amp.GradScaler and check for NaNs after each forward pass.

Data issues – validate inputs and labels for NaN/Inf values before feeding them to the model.
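The first three checks can sit inside a single training step. A minimal sketch (illustrative model, optimizer, and loss; mixed precision is enabled only when a GPU is present):

```python
import torch
import torch.nn as nn

# Illustrative model and optimizer; in practice these come from your setup.
model = nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # start with a small LR
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

def train_step(x, y):
    optimizer.zero_grad()
    # Cause 4: validate inputs before they reach the model.
    if not torch.isfinite(x).all() or not torch.isfinite(y).all():
        raise ValueError("NaN/Inf found in batch")
    # Cause 3: mixed precision with loss scaling.
    with torch.autocast("cuda" if use_amp else "cpu", enabled=use_amp):
        loss = nn.functional.mse_loss(model(x), y)
    if not torch.isfinite(loss):
        raise RuntimeError(f"non-finite loss: {loss.item()}")
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # so clipping sees the true gradients
    # Cause 1: monitor and clip the gradient norm.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
    return loss.item(), grad_norm.item()

loss_val, grad_norm = train_step(torch.randn(8, 16), torch.randn(8, 1))
```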

It then presents a full “NaN detector” class that logs loss, checks gradients periodically, and inspects model parameters for NaNs or Infs, offering a systematic troubleshooting pipeline.
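A sketch of such a detector (class and method names here are illustrative, not the article's exact code):

```python
import torch

class NaNDetector:
    """Logs loss, periodically checks gradients and parameters for NaN/Inf."""

    def __init__(self, model, check_every=100):
        self.model = model
        self.check_every = check_every
        self.step = 0
        self.loss_history = []

    def check_loss(self, loss):
        self.loss_history.append(loss.item() if torch.isfinite(loss) else float("nan"))
        if not torch.isfinite(loss):
            raise RuntimeError(f"step {self.step}: non-finite loss")

    def check_gradients(self):
        for name, p in self.model.named_parameters():
            if p.grad is not None and not torch.isfinite(p.grad).all():
                raise RuntimeError(f"step {self.step}: NaN/Inf gradient in {name}")

    def check_parameters(self):
        for name, p in self.model.named_parameters():
            if not torch.isfinite(p).all():
                raise RuntimeError(f"step {self.step}: NaN/Inf parameter in {name}")

    def after_backward(self, loss):
        self.step += 1
        self.check_loss(loss)
        if self.step % self.check_every == 0:
            self.check_gradients()
            self.check_parameters()
```

Call `detector.after_backward(loss)` between `loss.backward()` and `optimizer.step()` so bad gradients are caught before they corrupt the weights.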

Interview Preparation Advice

The concluding section emphasizes four pillars for success in large‑model interviews: solid grasp of core concepts (Transformer, attention), ability to explain underlying mathematics (e.g., scaling by sqrt(dₖ)), hands‑on experience with real‑world issues (NaN loss debugging), and awareness of cutting‑edge techniques (KV‑cache, speculative decoding). It advises candidates to articulate principles and intuition rather than merely reciting memorized answers.

Key Takeaways

Implement MHA with KV‑cache to reduce redundant computation during autoregressive generation.

Scale dot‑product attention by sqrt(dₖ) to keep the score variance near 1 and avoid softmax saturation.

Speculative decoding leverages a fast draft model and parallel verification by a large model for significant speed gains.

Systematically debug NaN loss by checking gradients, learning rates, precision settings, and data quality.

Tags: Transformer, Speculative Decoding, KV cache, Large Model Interview, NaN Debugging, Scaled Dot‑Product, Multi‑Head Attention
Written by

Wu Shixiong's Large Model Academy

We continuously share large‑model know‑how, helping you master core skills—LLM, RAG, fine‑tuning, deployment—from zero to job offer, tailored for career‑switchers, autumn‑recruitment candidates, and those seeking stable large‑model positions.
