Implementing Greedy and Beam Decoding for Large Language Models from Scratch
This article walks through the mechanics of greedy search and beam search in large language models, demonstrates both methods with GPT‑2 on the prompt "I have a dream", visualizes the decoding trees, compares their scores, and discusses the trade‑offs between efficiency and output quality.
In the realm of large language models (LLMs) such as GPT‑2, the model architecture and training are often highlighted, while the decoding strategies that turn logits into text are less discussed. This tutorial focuses on two core strategies—greedy search and beam search—explaining their operation, hyper‑parameters (e.g., temperature, num_beams), and practical implications.
We start with a concrete example: given the input "I have a dream", we ask GPT‑2 to generate the next five tokens. The following Python code loads the model and tokenizer from the transformers library and runs model.generate:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model.eval()
text = "I have a dream"
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
outputs = model.generate(input_ids, max_length=len(input_ids.squeeze())+5)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")
The output is:
Generated text: I have a dream of being a doctor.
Although the sentence appears complete, the underlying process is more involved. The tokenizer first splits the text into sub-word pieces with byte‑pair encoding (BPE) and maps each piece to an ID. The model then computes logits over the entire vocabulary; applying softmax yields a probability distribution over the next token. For example, the token "of" receives a 17 % probability after the prompt, written P(of | I have a dream) = 17 %. The joint probability of a token sequence w = (w₁, …, wₜ) factorises as the product of the conditional probabilities P(wᵢ | w₁, …, wᵢ₋₁) for i = 1, …, t, so the model must produce a distribution over all 50,257 vocabulary entries at each step.
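The softmax and chain-rule steps described above can be illustrated without a model. The following is a minimal sketch in plain Python; the logits and the per-step probabilities are made-up values chosen for illustration, not GPT‑2 outputs, and the function names are hypothetical.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sequence_log_prob(step_probs):
    """Chain rule: log P(w) is the sum of the conditional log-probabilities."""
    return sum(math.log(p) for p in step_probs)

# Hypothetical logits over a 4-token toy vocabulary (GPT-2 has 50,257 entries).
probs = softmax([2.0, 1.0, 0.5, -1.0])

# Hypothetical conditional probabilities of three consecutive chosen tokens.
log_p = sequence_log_prob([0.17, 0.10, 0.03])
joint_p = math.exp(log_p)  # same value as 0.17 * 0.10 * 0.03
```

Working in log space turns the product into a sum, which avoids numerical underflow when many small probabilities are multiplied.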
Greedy Search selects the highest‑probability token at every step, discarding all alternatives. Applied to the example, the step‑by‑step choices are:
Step 1: "of" (most likely after the prompt)
Step 2: "being"
Step 3: "a"
Step 4: "doctor"
Step 5: "."
This method is fast because it tracks only a single sequence, but it is short‑sighted: early choices that are only marginally better can lead to low‑probability continuations later.
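The short-sightedness can be reproduced on a toy example. The sketch below assumes a hypothetical table of next-token distributions (TOY_MODEL, keyed by the previous token) in place of GPT‑2; the numbers are invented so that the locally best first token leads to a less probable sequence overall.

```python
import math

# Hypothetical next-token distributions keyed by the previous token.
TOY_MODEL = {
    "<s>": {"A": 0.6, "B": 0.4},
    "A": {"x": 0.55, "y": 0.45},
    "B": {"z": 0.9, "x": 0.1},
}

def greedy_decode(model, start, steps):
    """Pick the single most probable next token at every step."""
    tokens, log_prob = [start], 0.0
    for _ in range(steps):
        dist = model[tokens[-1]]
        token = max(dist, key=dist.get)  # locally optimal choice
        log_prob += math.log(dist[token])
        tokens.append(token)
    return tokens[1:], log_prob

tokens, log_prob = greedy_decode(TOY_MODEL, "<s>", steps=2)
greedy_prob = math.exp(log_prob)  # 0.6 * 0.55
better_prob = 0.4 * 0.9           # the path B -> z that greedy never visits
```

Greedy search commits to "A" because 0.6 > 0.4, yet the path through "B" has the higher joint probability (0.36 against 0.33).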
To visualise the greedy process we build a decision tree with networkx and graphviz. The code below constructs a balanced tree of height 5, records each token’s log‑probability, and plots the tree:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
# ... (functions get_log_prob, greedy_search) ...
length = 5
beams = 1
graph = nx.balanced_tree(1, length, create_using=nx.DiGraph())
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['token'] = text
output_ids = greedy_search(input_ids, 0, length=length)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")
The resulting tree (shown in the following image) highlights that, although each node is the locally optimal token, later tokens such as "being" (9.68 %) and "doctor" (2.86 %) receive low probabilities, which suggests that the initial choice "of" was sub‑optimal.
Beam Search overcomes this limitation by keeping the top n candidate tokens (n is the beam width) at each step and expanding each candidate. The algorithm maintains a cumulative log‑probability log P(w) for each partial sequence and optionally normalises it by length. Without normalisation, every additional token adds a negative log‑probability term, so the raw score is biased toward shorter outputs.
Key modifications compared with greedy search include:
Tracking cumulative sequence scores (sum of log‑probabilities).
Length normalisation to balance short and long sequences.
Expanding the tree with a branching factor equal to the beam width.
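The three modifications listed above can be sketched on a toy example. The code below is an illustration under stated assumptions, not the GPT‑2 implementation: TOY_MODEL is a hypothetical table of next-token distributions, and beam_search_pruned is a hypothetical helper. It keeps only the `beams` best hypotheses after every step, which is the usual pruning form of beam search, rather than expanding a full tree.

```python
import math

# Hypothetical next-token distributions keyed by the previous token.
TOY_MODEL = {
    "<s>": {"A": 0.6, "B": 0.4},
    "A": {"x": 0.55, "y": 0.45},
    "B": {"z": 0.9, "x": 0.1},
}

def beam_search_pruned(model, start, steps, beams):
    """Return (tokens, cumulative log-prob, length-normalised score)."""
    # Each hypothesis is (tokens, cumulative log-probability).
    hyps = [([start], 0.0)]
    for _ in range(steps):
        candidates = []
        for tokens, cum in hyps:
            for token, p in model[tokens[-1]].items():
                candidates.append((tokens + [token], cum + math.log(p)))
        # Keep only the `beams` best hypotheses by cumulative score.
        candidates.sort(key=lambda h: h[1], reverse=True)
        hyps = candidates[:beams]
    # Length-normalise over the generated tokens (start symbol excluded).
    scored = [(t[1:], c, c / (len(t) - 1)) for t, c in hyps]
    return max(scored, key=lambda s: s[2])

best_tokens, cum_score, seq_score = beam_search_pruned(
    TOY_MODEL, "<s>", steps=2, beams=2
)
```

With beams=2 the search keeps both "A" and "B" after the first step and ends on B → z (joint probability 0.36). With beams=1 it reduces to greedy search and returns A → x (0.33).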
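The implementation below adds a beam_search function that supports both greedy and top‑k sampling within the beam, updates the node attributes (tokenscore, cumscore, sequencescore), and recurses until the desired length is reached. The helper top_k_sampling is not shown here. Note that this version recurses into every selected token at every node, so it scores every path of the balanced tree rather than pruning candidates after each step.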
from tqdm import tqdm
def greedy_sampling(logits, beams):
    # Indices of the `beams` highest-scoring tokens
    return torch.topk(logits, beams).indices
def beam_search(input_ids, node, bar, length, beams, sampling, temperature=0.1):
    # Stop once the requested number of tokens has been generated
    if length == 0:
        return None
    # Logits for the token that follows the current sequence
    outputs = model(input_ids)
    logits = outputs.logits[0, -1, :]
    if sampling == 'greedy':
        top_token_ids = greedy_sampling(logits, beams)
    elif sampling == 'top_k':
        top_token_ids = top_k_sampling(logits, temperature, 20, beams)
    for j, token_id in enumerate(top_token_ids):
        bar.update(1)
        # Score of this token and of the whole path leading to it
        token_score = get_log_prob(logits, token_id)
        cumulative_score = graph.nodes[node]['cumscore'] + token_score
        # Append the token to the current sequence
        new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0).unsqueeze(0)], dim=-1)
        # Store the results on the j-th child of the current node
        token = tokenizer.decode(token_id, skip_special_tokens=True)
        current_node = list(graph.successors(node))[j]
        graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
        graph.nodes[current_node]['cumscore'] = cumulative_score
        graph.nodes[current_node]['sequencescore'] = len(new_input_ids.squeeze())**-1 * cumulative_score
        graph.nodes[current_node]['token'] = f"{token}_{length}_{j}"
        # Recurse to extend this candidate by one more token
        beam_search(new_input_ids, current_node, bar, length-1, beams, sampling, 1)
length = 5
beams = 2
graph = nx.balanced_tree(beams, length, create_using=nx.DiGraph())
bar = tqdm(total=len(graph.nodes))
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['cumscore'] = 0
    graph.nodes[node]['sequencescore'] = 0
    graph.nodes[node]['token'] = text
beam_search(input_ids, 0, bar, length, beams, 'greedy', 1)
After the search finishes, we locate the leaf node with the highest sequencescore, backtrack to the root, and concatenate the tokens to obtain the best sequence:
def get_best_sequence(G):
    leaf_nodes = [n for n in G.nodes() if G.out_degree(n) == 0]
    max_node = max(leaf_nodes, key=lambda n: G.nodes[n]['sequencescore'])
    path = nx.shortest_path(G, source=0, target=max_node)
    sequence = "".join([G.nodes[n]['token'].split('_')[0] for n in path])
    return sequence, G.nodes[max_node]['sequencescore']
sequence, max_score = get_best_sequence(graph)
print(f"Generated text: {sequence}")
The beam‑search result is:
Generated text: I have a dream. I have a dream
Visualization of the beam tree (see image below) shows that the best leaf has a sequence score of –0.69, whereas the greedy path scores –1.16, confirming that beam search found a sequence with a higher length‑normalised log‑probability.
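To get a feel for what the gap between the two scores means, the length normalisation can be undone. The sketch below is a rough illustration that rests on two assumptions: both scores are cumulative log-probabilities divided by the same total length, and that length is 9 tokens (4 prompt tokens plus 5 generated ones). The helper name joint_probability is hypothetical.

```python
import math

def joint_probability(sequence_score, num_tokens):
    """Undo length normalisation: sequence_score = log P / num_tokens."""
    return math.exp(sequence_score * num_tokens)

NUM_TOKENS = 9  # assumption: 4 prompt tokens + 5 generated tokens

p_beam = joint_probability(-0.69, NUM_TOKENS)
p_greedy = joint_probability(-1.16, NUM_TOKENS)
ratio = p_beam / p_greedy  # equals exp((1.16 - 0.69) * NUM_TOKENS)
print(f"beam: {p_beam:.2e}, greedy: {p_greedy:.2e}, ratio: {ratio:.1f}")
```

Under those assumptions the ratio is exp(0.47 × 9), roughly 69, so a difference that looks small on the normalised log scale corresponds to a large difference in joint probability.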
Conclusion – The tutorial examined two decoding strategies for LLMs, illustrating how greedy search is fast but can get trapped in local optima, while beam search balances quality and efficiency by exploring multiple candidate paths. Understanding these trade‑offs helps practitioners choose the appropriate method for tasks ranging from deterministic question answering to creative text generation.
AI Algorithm Path
A public account focused on deep learning, computer vision, and autonomous driving perception algorithms, covering visual CV, neural networks, pattern recognition, related hardware and software configurations, and open-source projects.
