Implementing Greedy and Beam Decoding for Large Language Models from Scratch

This article walks through the mechanics of greedy search and beam search in large language models, demonstrates both methods with GPT‑2 on the prompt "I have a dream", visualizes the decoding trees, compares their scores, and discusses the trade‑offs between efficiency and output quality.

AI Algorithm Path

In the realm of large language models (LLMs) such as GPT‑2, the model architecture and training are often highlighted, while the decoding strategies that turn logits into text are less discussed. This tutorial focuses on two core strategies—greedy search and beam search—explaining their operation, hyper‑parameters (e.g., temperature, num_beams), and practical implications.

We start with a concrete example: given the input "I have a dream", we ask GPT‑2 to generate the next five tokens. The following Python code loads the model and tokenizer from the transformers library and runs model.generate, which with its default settings for GPT‑2 (no sampling, a single beam) decodes greedily:

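from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# Run on the GPU if one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model.eval()

text = "I have a dream"
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

# max_length counts the prompt tokens too, so add 5 to the prompt length.
# GPT-2 has no pad token; reusing EOS and passing an explicit attention mask
# silences the corresponding warnings without changing the output.
outputs = model.generate(
    input_ids,
    attention_mask=torch.ones_like(input_ids),
    max_length=input_ids.shape[1] + 5,
    pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")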

The output is:

Generated text: I have a dream of being a doctor.

Although the sentence appears complete, the underlying process is more involved. The tokenizer first splits the text into subword tokens with byte‑pair encoding (BPE) and maps each token to an integer ID. The model then outputs one logit per vocabulary entry, and applying softmax to the logits at the last position yields a probability distribution over the next token. For example, the token "of" receives a 17% probability after the prompt, written P(of | I have a dream) = 17%. The joint probability of a token sequence w = (w₁, …, wₜ) factorises into a product of conditional probabilities:

$$P(w) = \prod_{i=1}^{t} P(w_i \mid w_1, \ldots, w_{i-1})$$

At every step the model therefore produces a probability for each of the 50,257 tokens in GPT‑2's vocabulary.
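The factorisation can be checked with a few lines of plain Python. This is a minimal sketch, not GPT‑2 code: apart from the 17% quoted above for "of", the per‑step probabilities are made‑up placeholders. It also shows why decoders add log‑probabilities instead of multiplying probabilities: long products of numbers below 1 underflow to zero in floating point.

```python
import math

# Hypothetical per-step conditionals P(w_i | w_1, ..., w_{i-1}) for a
# five-token continuation. Only 0.17 (the probability of "of") comes from
# the text; the other values are illustrative placeholders.
step_probs = [0.17, 0.10, 0.50, 0.03, 0.40]


def joint_probability(probs):
    """Chain rule: the joint probability is the product of the conditionals."""
    return math.prod(probs)


def joint_log_probability(probs):
    """The same quantity in log space: a sum of log-probabilities."""
    return sum(math.log(p) for p in probs)


p = joint_probability(step_probs)
log_p = joint_log_probability(step_probs)
print(f"P(w) = {p:.3e}, log P(w) = {log_p:.3f}")

# Long products of probabilities below 1 underflow in float64,
# while the log-space sum stays finite.
tiny = [1e-5] * 100
print(joint_probability(tiny))      # 0.0 (underflow)
print(joint_log_probability(tiny))  # large negative, but finite
```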

Greedy Search selects the highest‑probability token at every step, discarding all alternatives. Applied to the example, the step‑by‑step choices are:

Step 1: "of" (most likely after the prompt)

Step 2: "being"

Step 3: "a"

Step 4: "doctor"

Step 5: "."

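This method is fast because it tracks only a single sequence, but it is short‑sighted: it never revisits a choice, so a token that is only marginally more likely at one step can commit the sequence to low‑probability continuations, and the final output need not be the most probable sequence overall.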
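The short‑sightedness is easiest to see on a toy model small enough to enumerate exhaustively. The sketch below assumes a hypothetical dictionary‑based next‑token model (TOY_MODEL and next_token_probs are illustrative names, not part of transformers). Greedy decoding commits to "nice" because it is the best first token, yet the path through "dog" has a higher joint probability.

```python
import math

# Hypothetical toy next-token model standing in for GPT-2: it maps a prefix
# (tuple of tokens) to a distribution over the next token. Illustrative only.
TOY_MODEL = {
    (): {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("nice",): {"woman": 0.4, "house": 0.3, "guy": 0.3},
    ("dog",): {"has": 0.9, "runs": 0.05, "and": 0.05},
    ("car",): {"is": 0.5, "drives": 0.3, "turns": 0.2},
}


def next_token_probs(prefix):
    return TOY_MODEL[tuple(prefix)]


def greedy_decode(steps):
    """Commit to the single most probable token at every step."""
    tokens, log_p = [], 0.0
    for _ in range(steps):
        probs = next_token_probs(tokens)
        token = max(probs, key=probs.get)  # argmax over the toy vocabulary
        tokens.append(token)
        log_p += math.log(probs[token])
    return tokens, log_p


def exhaustive_best(steps):
    """Score every path (feasible only for a toy model) and return the best."""
    frontier = [([], 0.0)]
    for _ in range(steps):
        frontier = [
            (tokens + [tok], log_p + math.log(p))
            for tokens, log_p in frontier
            for tok, p in next_token_probs(tokens).items()
        ]
    return max(frontier, key=lambda item: item[1])


greedy_tokens, greedy_lp = greedy_decode(2)
best_tokens, best_lp = exhaustive_best(2)
print(greedy_tokens, round(math.exp(greedy_lp), 2))  # ['nice', 'woman'] 0.2
print(best_tokens, round(math.exp(best_lp), 2))      # ['dog', 'has'] 0.36
```

Greedy decoding ends at a joint probability of 0.2, while the best two‑token path reaches 0.36: the first step's 0.5 versus 0.4 advantage is outweighed by what follows.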

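To visualise the greedy process we record it in a tree built with networkx and plot it with graphviz. With a branching factor of 1, nx.balanced_tree(1, 5) is simply a chain of six nodes: the root holds the prompt, and each following node holds one generated token together with its probability (stored as a percentage in tokenscore). The bodies of the helper functions get_log_prob and greedy_search, and the plotting code, are not shown: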

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch

# ... (functions get_log_prob, greedy_search) ...

length = 5
beams = 1
# Branching factor 1: a chain of length + 1 nodes, with the prompt at the root
graph = nx.balanced_tree(beams, length, create_using=nx.DiGraph())
# Initialise every node; greedy_search overwrites the non-root nodes
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100   # token probability in percent
    graph.nodes[node]['token'] = text

# Start the recursion at the root node (id 0)
output_ids = greedy_search(input_ids, 0, length=length)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")
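The bodies of get_log_prob and greedy_search are elided above. Judging only from how get_log_prob is called later (its result is exponentiated and multiplied by 100 to fill tokenscore), it returns the log of one token's softmax probability. The version below is a hypothetical stand‑in over a plain Python list of logits, using the log‑sum‑exp trick for numerical stability; with PyTorch tensors the same value would be torch.log_softmax(logits, dim=-1)[token_id].item().

```python
import math


def get_log_prob(logits, token_id):
    """Hypothetical stand-in for the article's elided helper.

    Returns log(softmax(logits)[token_id]) using the log-sum-exp trick:
    subtracting the maximum logit keeps math.exp from overflowing.
    """
    m = max(logits)
    log_normalizer = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[token_id] - log_normalizer


logits = [2.0, 1.0, 0.1]  # illustrative logits for a 3-token vocabulary
log_probs = [get_log_prob(logits, i) for i in range(len(logits))]
# exp(log_prob) * 100 is the percentage the article stores as 'tokenscore'
print([round(math.exp(lp) * 100, 2) for lp in log_probs])

# Large logits would overflow a naive exp(); the shifted version does not.
print(get_log_prob([1000.0, 999.0], 0))
```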

The resulting tree (shown in the figure below) highlights that although each node holds the locally most probable token, some later tokens receive low probabilities, such as "being" (9.68%) and "doctor" (2.86%). This suggests that committing to "of" at the first step did not necessarily lead to the most probable five‑token continuation.

Greedy search decision tree

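Beam Search addresses this limitation by keeping the n most probable partial sequences at each step (n is the beam width, num_beams) instead of a single one, extending each of them by one token, and pruning the candidates back to the best n. The algorithm maintains a cumulative log‑probability log P(w) for each partial sequence. Because every log‑probability is at most 0, each additional token can only lower this sum, so the score is optionally normalised by length to avoid a bias toward shorter outputs.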
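A minimal sketch of that length bias, using hypothetical per‑token log‑probabilities: with the raw sum (alpha = 0) the shorter candidate wins simply because it has fewer negative terms, while dividing by the length (alpha = 1, roughly what the article's sequencescore does) prefers the candidate whose tokens are individually more likely. The alpha exponent is an illustrative generalisation in the spirit of the length_penalty argument of Hugging Face's generate, not that library's code.

```python
import math


def sequence_score(token_log_probs, alpha=1.0):
    """Cumulative log-probability divided by length**alpha.

    alpha=0 returns the raw sum; alpha=1 returns the mean log-prob per token.
    """
    return sum(token_log_probs) / (len(token_log_probs) ** alpha)


# Hypothetical candidates: a short sequence of mediocre tokens and a longer
# sequence whose tokens are individually more likely.
short_seq = [math.log(0.3)] * 2
long_seq = [math.log(0.6)] * 5

for alpha in (0.0, 1.0):
    short_s = sequence_score(short_seq, alpha)
    long_s = sequence_score(long_seq, alpha)
    winner = "short" if short_s > long_s else "long"
    print(f"alpha={alpha}: short={short_s:.3f}, long={long_s:.3f} -> {winner} wins")
```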

Key modifications compared with greedy search include:

Tracking cumulative sequence scores (sum of log‑probabilities).

Length normalisation to balance short and long sequences.

Expanding the tree with a branching factor equal to the beam width.
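Combining these modifications, here is a minimal pruned beam search over a hypothetical toy model (an illustrative dictionary standing in for GPT‑2, extended with an end‑of‑sequence token "<eos>"). It keeps only the num_beams best partial sequences after each step, sets aside hypotheses that emit "<eos>", and picks the final answer by length‑normalised score. This is a sketch under those assumptions, not the transformers implementation.

```python
import math

EOS = "<eos>"

# Hypothetical toy next-token model (illustrative stand-in for GPT-2).
TOY_MODEL = {
    (): {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("nice",): {"woman": 0.4, "house": 0.3, EOS: 0.3},
    ("dog",): {"has": 0.9, "runs": 0.05, EOS: 0.05},
}


def next_token_probs(prefix):
    # Assumption: prefixes the toy model does not cover simply end.
    return TOY_MODEL.get(tuple(prefix), {EOS: 1.0})


def beam_search(num_beams, max_steps, alpha=1.0):
    """Keep the num_beams best partial sequences (by cumulative log-prob) per step."""
    beams = [([], 0.0)]   # (tokens, cumulative log-probability)
    finished = []         # hypotheses that emitted EOS
    for _ in range(max_steps):
        # 1. Expand every live beam by every possible next token.
        candidates = [
            (tokens + [tok], cum + math.log(p))
            for tokens, cum in beams
            for tok, p in next_token_probs(tokens).items()
        ]
        # 2. Prune: walk candidates best-first until num_beams live beams remain.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, cum in candidates:
            if tokens[-1] == EOS:
                finished.append((tokens, cum))
            else:
                beams.append((tokens, cum))
            if len(beams) == num_beams:
                break
        if not beams:
            break
    finished.extend(beams)  # beams still open at max_steps also compete
    # 3. Length normalisation decides between hypotheses of different lengths.
    return max(finished, key=lambda c: c[1] / len(c[0]) ** alpha)


print(beam_search(num_beams=1, max_steps=2))  # behaves like greedy search
print(beam_search(num_beams=2, max_steps=2))
```

With num_beams=1 it reduces to greedy decoding and returns ["nice", "woman"]; with num_beams=2 it keeps "dog" alive after the first step and finds ["dog", "has"].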

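The implementation below adds a beam_search function that chooses each node's children either deterministically (the top tokens by logit, labelled 'greedy') or by top‑k sampling, stores the node attributes tokenscore, cumscore and sequencescore, and recurses until the desired length is reached. It reuses model, tokenizer, text, input_ids and get_log_prob from the earlier snippets. Note that this is a simplified, tree‑shaped variant: every node expands its own top beams children and nothing is pruned, so it explores beams^length leaf sequences (32 here) rather than keeping only the best beams sequences per step as standard beam search does.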

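from tqdm import tqdm
import torch.nn.functional as F

def greedy_sampling(logits, beams):
    # Deterministic: the `beams` highest-scoring token ids, sorted
    return torch.topk(logits, beams).indices

def top_k_sampling(logits, temperature, top_k, beams):
    # Minimal sketch (not shown in the original): sample `beams` distinct ids
    # from the temperature-scaled distribution over the top_k logits
    top_logits, top_ids = torch.topk(logits, top_k)
    probs = F.softmax(top_logits / temperature, dim=-1)
    return top_ids[torch.multinomial(probs, beams)]

def beam_search(input_ids, node, bar, length, beams, sampling, temperature=0.1):
    if length == 0:
        return None
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(input_ids)
    logits = outputs.logits[0, -1, :]  # next-token logits at the last position
    if sampling == 'greedy':
        top_token_ids = greedy_sampling(logits, beams)
    elif sampling == 'top_k':
        top_token_ids = top_k_sampling(logits, temperature, 20, beams)
    else:
        raise ValueError(f"unknown sampling method: {sampling}")
    for j, token_id in enumerate(top_token_ids):
        bar.update(1)
        token_score = get_log_prob(logits, token_id)  # log P(token | prefix)
        cumulative_score = graph.nodes[node]['cumscore'] + token_score
        new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0).unsqueeze(0)], dim=-1)
        token = tokenizer.decode(token_id, skip_special_tokens=True)
        current_node = list(graph.successors(node))[j]  # j-th child in the tree
        graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
        graph.nodes[current_node]['cumscore'] = cumulative_score
        # Length normalisation: divide by the total token count (prompt included)
        graph.nodes[current_node]['sequencescore'] = cumulative_score / new_input_ids.shape[-1]
        graph.nodes[current_node]['token'] = f"{token}_{length}_{j}"
        # Recurse one level deeper (deeper levels use temperature 1)
        beam_search(new_input_ids, current_node, bar, length-1, beams, sampling, 1)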

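length = 5
beams = 2
# Full binary tree: 1 + 2 + ... + 2**5 = 63 nodes, 32 leaves
graph = nx.balanced_tree(beams, length, create_using=nx.DiGraph())
bar = tqdm(total=len(graph.nodes) - 1)  # one update per non-root node
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['cumscore'] = 0
    graph.nodes[node]['sequencescore'] = 0
    graph.nodes[node]['token'] = text

# Expand from the root (node 0), taking the top tokens deterministically
beam_search(input_ids, 0, bar, length, beams, 'greedy', 1)
bar.close()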
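Because this tree‑shaped search expands the top beams children of every node instead of pruning, its cost grows exponentially with length. The arithmetic for the settings used here can be checked in plain Python; the pruned‑beam count is an assumption about a simple unbatched implementation that runs one forward pass per live beam per step.

```python
def tree_node_count(beams, length):
    """Nodes in nx.balanced_tree(beams, length): 1 + beams + ... + beams**length."""
    return sum(beams ** d for d in range(length + 1))


def tree_forward_passes(beams, length):
    """The recursive beam_search runs the model once per call with length > 0,
    i.e. once per non-leaf node."""
    return sum(beams ** d for d in range(length))


def pruned_beam_forward_passes(beams, length):
    """Assumed cost of a standard (pruned) beam search without batching:
    one pass from the prompt, then one per live beam for each later step."""
    return 1 + beams * (length - 1)


for length in (5, 10, 20):
    print(
        f"length={length}: tree nodes={tree_node_count(2, length)}, "
        f"tree passes={tree_forward_passes(2, length)}, "
        f"pruned passes={pruned_beam_forward_passes(2, length)}"
    )
```

For length 5 the tree has 63 nodes and needs 31 model calls, against 9 for the pruned variant; at length 20 the tree would need over a million calls.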

After the search finishes, we locate the leaf node with the highest sequencescore, backtrack to the root, and concatenate the tokens to obtain the best sequence:

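def get_best_sequence(G):
    # Leaves are the complete sequences
    leaf_nodes = [n for n in G.nodes() if G.out_degree(n) == 0]
    # Pick the leaf with the highest length-normalised score
    max_node = max(leaf_nodes, key=lambda n: G.nodes[n]['sequencescore'])
    # Walk from the root to that leaf; rsplit strips only the "_{length}_{j}"
    # suffix, so tokens that themselves contain "_" survive intact
    path = nx.shortest_path(G, source=0, target=max_node)
    sequence = "".join(G.nodes[n]['token'].rsplit('_', 2)[0] for n in path)
    return sequence, G.nodes[max_node]['sequencescore']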

sequence, max_score = get_best_sequence(graph)
print(f"Generated text: {sequence}")
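networkx's balanced_tree currently numbers nodes in breadth‑first order, so the parent of node v in an r‑ary tree is (v - 1) // r and the root‑to‑leaf path can be recovered by arithmetic instead of nx.shortest_path. The sketch below relies on that numbering (an implementation detail worth re‑checking against your networkx version) and uses a hypothetical labels dictionary in place of the graph. Assuming no ties, the leftmost path (child j = 0 at every level) is the greedy path, so the labels are built from the greedy tokens reported earlier.

```python
def path_from_root(node, r):
    """Node ids from the root (0) down to `node` in an r-ary tree whose nodes
    are numbered breadth-first, so the parent of v is (v - 1) // r."""
    path = [node]
    while node != 0:
        node = (node - 1) // r
        path.append(node)
    return path[::-1]


def decode_path(labels, path):
    """Join node labels, stripping the '_{length}_{j}' suffix from each token."""
    return "".join(labels[n].rsplit("_", 2)[0] for n in path)


# Hypothetical labels in the article's f"{token}_{length}_{j}" format for the
# leftmost path of the binary tree; node 0 holds the prompt.
labels = {
    0: "I have a dream",
    1: " of_5_0",
    3: " being_4_0",
    7: " a_3_0",
    15: " doctor_2_0",
    31: "._1_0",
}
path = path_from_root(31, r=2)
print(path)                       # [0, 1, 3, 7, 15, 31]
print(decode_path(labels, path))  # I have a dream of being a doctor.
```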

The beam‑search result is:

Generated text: I have a dream. I have a dream

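The visualization of the beam tree (see the figure below) shows that the best leaf has a sequence score of –0.69, whereas the greedy path scores –1.16, so beam search finds an output with a higher length‑normalised log‑probability. A higher score is not the same as better text, though: this output simply repeats the prompt.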
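Because sequencescore divides the cumulative log‑probability (which starts at 0 at the root) by the total token count, the two reported scores can be converted back into probabilities of the five generated tokens given the prompt. This back‑of‑the‑envelope sketch assumes that the prompt is 4 GPT‑2 tokens (so each full sequence has 9) and that both scores use the same normalisation; it uses the rounded values from the text, so the results are approximate.

```python
import math

PROMPT_TOKENS = 4     # assumption: "I have a dream" is 4 GPT-2 tokens
GENERATED_TOKENS = 5
TOTAL_TOKENS = PROMPT_TOKENS + GENERATED_TOKENS


def continuation_probability(sequence_score, total_tokens):
    """Undo sequencescore = cumscore / total_tokens.

    cumscore starts at 0 at the root, so it equals log P(continuation | prompt).
    """
    return math.exp(sequence_score * total_tokens)


beam_p = continuation_probability(-0.69, TOTAL_TOKENS)    # rounded score from the text
greedy_p = continuation_probability(-1.16, TOTAL_TOKENS)  # rounded score from the text
print(f"beam   ~ {beam_p:.2e}")
print(f"greedy ~ {greedy_p:.2e}")
print(f"ratio  ~ {beam_p / greedy_p:.0f}")
```

Under these assumptions the beam continuation is more than an order of magnitude more probable than the greedy one.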

Beam search decision tree

Conclusion – The tutorial examined two decoding strategies for LLMs, illustrating how greedy search is fast but can get trapped in local optima, while beam search balances quality and efficiency by exploring multiple candidate paths. Understanding these trade‑offs helps practitioners choose the appropriate method for tasks ranging from deterministic question answering to creative text generation.

Written by

AI Algorithm Path

A public account focused on deep learning, computer vision, and autonomous driving perception algorithms, covering visual CV, neural networks, pattern recognition, related hardware and software configurations, and open-source projects.
