Mastering Large‑Model Interview Questions: MHA, KV‑Cache, Scaled Dot‑Product, and Speculative Decoding
This guide walks through common large‑model interview challenges, including a hands‑on implementation of multi‑head attention with KV‑cache, the mathematical reason for scaling by sqrt(dₖ), a concise speculative decoding algorithm, and systematic debugging steps for NaN loss during training.
The article begins with a practical interview question: implement multi‑head attention (MHA) with KV‑cache. It explains that MHA is essentially multiple copies of single‑head attention whose outputs are concatenated. Each head has its own Q, K, V matrices, analogous to different focus areas in a meeting.
Core Idea Decomposition
Head 1 focuses on technical details.
Head 2 focuses on business logic.
Head 3 focuses on timing.
Head 4 focuses on resource allocation.
Each head maintains its own Q, K, V matrices.
Simplified Code Implementation
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.cache = {}

    def forward(self, query, key, value, mask=None, use_cache=False, cache_key="default"):
        batch_size = query.size(0)
        seq_len = query.size(1)
        # 1. Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        # 2. Reshape for multiple heads
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # 3. KV-cache logic (interview bonus)
        if use_cache and cache_key in self.cache:
            cached_K, cached_V = self.cache[cache_key]
            K = torch.cat([cached_K, K], dim=2)
            V = torch.cat([cached_V, V], dim=2)
        if use_cache:
            self.cache[cache_key] = (K, V)
        # 4. Scaled dot-product attention
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_weights = torch.softmax(attention_scores, dim=-1)
        attended_values = torch.matmul(attention_weights, V)
        # 5. Concatenate heads
        attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        # 6. Final linear projection
        output = self.W_o(attended_values)
        return output, attention_weights
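As a quick sanity check, here is a minimal usage sketch of the class above (the sizes and shapes are illustrative, not from the article): the prompt is processed once to fill the cache, and each later decoding step passes only the newest token while the cached K/V grow.
import torch

mha = MultiHeadAttention(d_model=64, num_heads=8)
prompt = torch.randn(1, 10, 64)                        # 10-token prompt
out, _ = mha(prompt, prompt, prompt, use_cache=True)   # fills the cache
step = torch.randn(1, 1, 64)                           # one new decoding step
out, _ = mha(step, step, step, use_cache=True)         # reuses cached K/V, attends over 11 positions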
Why Scale by sqrt(dₖ) in Transformer Attention?
The dot‑product of Q and K has variance proportional to dₖ (assuming the components of Q and K are roughly zero‑mean with unit variance). Without scaling, a large dₖ inflates the score variance, the softmax collapses onto a near one‑hot distribution (most probability mass on a single token), and the gradients through it vanish. Dividing by sqrt(dₖ) brings the score variance back to roughly 1, preserving softmax diversity, stabilizing training, and preventing gradient collapse. The article demonstrates this with a Python script that prints score ranges and softmax extremes for different dₖ values, showing how unscaled scores become extreme as dₖ grows.
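The exact script isn't reproduced here, but a small sketch in the same spirit could look like this:
import torch

torch.manual_seed(0)
for d_k in [16, 64, 256, 1024]:
    q = torch.randn(d_k)
    k = torch.randn(10, d_k)                  # 10 candidate keys
    raw = k @ q                               # unscaled scores, variance grows with d_k
    scaled = raw / d_k ** 0.5                 # scaled scores, variance stays near 1
    print(f"d_k={d_k:4d}  raw range [{raw.min().item():.1f}, {raw.max().item():.1f}]  "
          f"max softmax raw={torch.softmax(raw, -1).max().item():.3f}  "
          f"scaled={torch.softmax(scaled, -1).max().item():.3f}")
As dₖ grows, the unscaled softmax maximum approaches 1.0 (a near one‑hot distribution), while the scaled version stays spread out.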
Speculative Decoding Explained
Speculative decoding speeds up generation by letting a small draft model predict a batch of candidate tokens, which the large target model then verifies in parallel. The workflow consists of:
Draft stage – the small model quickly generates k tokens.
Verify stage – the large model evaluates all k tokens in a single forward pass.
Accept/Reject – a drafted token is accepted outright when its probability under the large model meets or exceeds the draft model's probability; otherwise it may still be accepted with probability p_target/p_draft, and generation resumes from the first rejection.
Because the large model verifies all k candidates in one pass instead of k sequential passes, the number of expensive target‑model forward steps per k tokens drops from k to 1 in the best case, often yielding 1.5×‑3× speedups depending on the model gap, task difficulty, and chosen k.
Simplified Speculative Decoding Code
import random
import torch

def speculative_decoding(draft_model, target_model, input_ids, k=4, max_length=64):
    """Simplified speculative decoding sketch.

    Assumes both models return logits of shape [batch, seq, vocab], that
    sample_token(logits) returns a 1-D tensor holding one token id, and that
    get_prob(logits, token) returns that token's probability under logits.
    """
    accepted_tokens = []
    current_input = input_ids
    while len(accepted_tokens) < max_length:
        # Draft stage: the small model proposes k tokens autoregressively
        draft_tokens, draft_probs = [], []
        draft_input = current_input
        for _ in range(k):
            with torch.no_grad():
                draft_logits = draft_model(draft_input)[:, -1, :]
            next_token = sample_token(draft_logits)
            draft_tokens.append(next_token)
            draft_probs.append(get_prob(draft_logits, next_token))
            draft_input = torch.cat([draft_input, next_token.unsqueeze(0)], dim=-1)
        # Verify stage: one forward pass of the large model over prompt + k drafts
        with torch.no_grad():
            target_logits = target_model(draft_input)
        # Accept/Reject: walk through the drafts until the first rejection
        prefix_len = current_input.size(-1)
        accepted_count = 0
        for i in range(k):
            # The logits at position prefix_len + i - 1 predict draft token i
            target_prob = get_prob(target_logits[:, prefix_len + i - 1, :], draft_tokens[i])
            draft_prob = draft_probs[i]
            if target_prob >= draft_prob:
                accepted_tokens.append(draft_tokens[i])
                accepted_count += 1
            else:
                # Accept with probability target/draft, then stop this round
                if random.random() < target_prob / draft_prob:
                    accepted_tokens.append(draft_tokens[i])
                    accepted_count += 1
                break
        if accepted_count == 0:
            # Fall back to one token from the target model so the loop always advances
            fallback = sample_token(target_logits[:, prefix_len - 1, :])
            accepted_tokens.append(fallback)
            accepted_count = 1
        new_tokens = [t.unsqueeze(0) for t in accepted_tokens[-accepted_count:]]
        current_input = torch.cat([current_input] + new_tokens, dim=-1)
    return accepted_tokens

Debugging NaN Loss in Large‑Model Training
The article lists four common causes of NaN loss and pairs each with concrete diagnostic code (illustrative sketches follow the list):
Gradient explosion – monitor gradient norms and clip them with torch.nn.utils.clip_grad_norm_().
Learning rate too high – start with a small LR and use schedulers like torch.optim.lr_scheduler.LinearLR.
Numerical overflow/underflow – employ mixed‑precision training with torch.cuda.amp.GradScaler and check for NaNs after each forward pass.
Data issues – validate inputs and labels for NaN/Inf values before feeding them to the model.
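As a rough illustration (not the article's exact code), a single training step that combines these checks — gradient clipping, mixed precision via GradScaler, and input validation — might look like this; the batch format and helper names are assumptions:
import torch

scaler = torch.cuda.amp.GradScaler()

def training_step(model, optimizer, batch, loss_fn, max_grad_norm=1.0):
    inputs, labels = batch
    # Data check: refuse batches containing NaN/Inf before they reach the model
    if not torch.isfinite(inputs).all() or not torch.isfinite(labels).all():
        raise ValueError("Batch contains NaN/Inf values")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():                     # mixed precision forward pass
        loss = loss_fn(model(inputs), labels)
    if not torch.isfinite(loss):                        # check loss right after the forward pass
        raise RuntimeError(f"Non-finite loss: {loss.item()}")
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                          # unscale so the clip sees true gradients
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()
    return loss.item(), grad_norm.item()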
It then presents a full “NaN detector” class that logs loss, checks gradients periodically, and inspects model parameters for NaNs or Infs, offering a systematic troubleshooting pipeline.
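The full class is not reproduced here, but a minimal sketch in that spirit (names and structure are hypothetical) could be:
import torch

class NaNDetector:
    """Sketch: log the loss, periodically check gradients, and scan parameters."""
    def __init__(self, model, check_every=100):
        self.model = model
        self.check_every = check_every
        self.step = 0

    def after_backward(self, loss):
        self.step += 1
        if not torch.isfinite(loss):
            raise RuntimeError(f"Step {self.step}: non-finite loss {loss.item()}")
        if self.step % self.check_every == 0:
            for name, p in self.model.named_parameters():
                if p.grad is not None and not torch.isfinite(p.grad).all():
                    raise RuntimeError(f"Step {self.step}: NaN/Inf gradient in {name}")
                if not torch.isfinite(p).all():
                    raise RuntimeError(f"Step {self.step}: NaN/Inf parameter in {name}")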
Interview Preparation Advice
The concluding section emphasizes four pillars for success in large‑model interviews: solid grasp of core concepts (Transformer, attention), ability to explain underlying mathematics (e.g., scaling by sqrt(dₖ)), hands‑on experience with real‑world issues (NaN loss debugging), and awareness of cutting‑edge techniques (KV‑cache, speculative decoding). It advises candidates to articulate principles and intuition rather than merely reciting memorized answers.
Key Takeaways
Implement MHA with KV‑cache to reduce redundant computation during autoregressive generation.
Scale dot‑product attention scores by sqrt(dₖ) so their variance stays near 1 and the softmax does not saturate.
Speculative decoding leverages a fast draft model and parallel verification by a large model for significant speed gains.
Systematically debug NaN loss by checking gradients, learning rates, precision settings, and data quality.
Wu Shixiong's Large Model Academy
We continuously share large‑model know‑how, helping you master core skills (LLM, RAG, fine‑tuning, deployment) from zero to job offer, tailored for career‑switchers, autumn campus‑recruitment candidates, and anyone seeking a stable large‑model position.