Hands-On: Building a Transformer from Scratch with PyTorch

This tutorial walks you through implementing a full Transformer model in PyTorch. It starts from basic linear-regression code and then adds the attention mechanism, multi-head attention, the encoder–decoder architecture, a training loop, and inference, all reinforced with practical debugging tips.

Alibaba Cloud Developer
Hands-On: Building a Transformer from Scratch with PyTorch

As an engineering learner, the author stresses that writing a Transformer yourself solidifies your understanding, so this guide focuses on practical PyTorch implementation rather than on the theory behind the algorithms.

1. Preliminaries: Generating a Toy Dataset

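First, a simple linear-regression example is built to illustrate data preparation and training basics. The synthetic data follow $y = 2x_1 - 3.4x_2 + 4.2 + \epsilon$ with noise $\epsilon \sim \mathcal{N}(0, 0.01^2)$. All snippets assume `import random`, `import numpy as np`, `import torch`, `from torch import nn, optim`, and `from torch.nn import init`.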

# Synthetic regression data: y = 2*x1 - 3.4*x2 + 4.2 plus N(0, 0.01^2) noise.
num_examples = 1000  # number of samples
num_input = 2  # number of features (e.g. house size and age)
true_w = [2, -3.4]
true_b = 4.2

features = torch.randn(num_examples, num_input, dtype=torch.float32)
# Noise-free targets; evaluation order matches w1*x1 + w2*x2 + b exactly.
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
# Gaussian noise sampled with NumPy (float64) and cast down to float32.
_temp = torch.from_numpy(np.random.normal(0, 0.01, size=labels.size())).to(torch.float32)
labels = labels + _temp

2. Define Model (Manual Version)

# Parameters of the hand-written model: weights drawn from N(0, 0.01^2) with
# shape (num_input, 1), bias starting at zero. Both are marked to track
# gradients so backward() fills in their .grad fields.
w = torch.tensor(np.random.normal(0, 0.01, (num_input, 1)), dtype=torch.float32).requires_grad_(True)
b = torch.zeros(1, dtype=torch.float32).requires_grad_(True)

# Linear regression function
def linreg(X, w, b):
    """Return the linear-regression prediction ``X @ w + b``.

    ``X`` and ``w`` must be 2-D (``torch.mm`` does not broadcast); ``b`` is
    added with ordinary broadcasting.
    """
    product = torch.mm(X, w)
    return product + b

# Squared loss (MSE)
def squared_loss(y_pred, y):
    """Element-wise halved squared error ``(y_pred - y)**2 / 2``.

    ``y`` is reshaped to ``y_pred``'s shape first, so a flat label vector can be
    compared against column-shaped predictions. No reduction is applied.
    """
    diff = y_pred - y.view(y_pred.size())
    return diff ** 2 / 2

# Stochastic gradient descent
def sgd(params, lr, batch):
    """Apply one in-place mini-batch SGD step to every tensor in ``params``.

    Each parameter is updated as ``param -= lr * param.grad / batch``. Dividing
    by ``batch`` turns the gradient of a loss summed over the batch into a
    per-sample average.

    Args:
        params: iterable of tensors whose ``.grad`` has been populated by
            ``backward()``.
        lr: learning rate.
        batch: number of samples the loss was summed over.
    """
    # Update under no_grad rather than through ``.data``: the step is not
    # recorded by autograd, but in-place version tracking is still respected.
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch

3. Train Model (Manual Version)

# Manual training loop. It uses squared_loss/linreg/sgd from above and requires
# data_iter (Section 4) to be defined before it runs.
lr = 0.03
num_epochs = 5
batch_size = 10

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        # Sum over the batch; sgd() divides by the batch length to average.
        ls = squared_loss(linreg(X, w, b), y).sum()
        ls.backward()
        # Use the actual batch length so a short final batch is not under-weighted.
        sgd([w, b], lr, len(y))
        # Gradients accumulate across backward() calls, so clear them each step.
        w.grad.zero_()
        b.grad.zero_()
    # Full-dataset evaluation only; no gradients are needed.
    with torch.no_grad():
        train_l = squared_loss(linreg(features, w, b), labels)
    print('epoch %d, loss %f' % (epoch + 1, train_l.mean().item()))

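4. Auxiliary Function: Data Iterator

Note: the training loop in Section 3 calls `data_iter`, so define this function before running that loop.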

def data_iter(batch, features, labels):
    """Yield shuffled mini-batches ``(X, y)`` of at most ``batch`` rows each.

    Rows are picked along dim 0 of ``features`` and ``labels`` in an order
    shuffled with Python's ``random`` module. The last batch is smaller when
    ``len(features)`` is not a multiple of ``batch``.
    """
    total = len(features)
    order = list(range(total))
    random.shuffle(order)
    for start in range(0, total, batch):
        # Slicing past the end is safe, so no explicit min() is needed.
        idx = torch.tensor(order[start:start + batch], dtype=torch.long)
        yield features.index_select(0, idx), labels.index_select(0, idx)

5. PyTorch Version of Linear Model

class LinearNet(nn.Module):
    """Linear-regression model: a single ``nn.Linear`` layer with one output."""

    def __init__(self, n_feature):
        super().__init__()
        # Maps n_feature inputs to one output per sample.
        self.linear = nn.Linear(n_feature, 1)

    def forward(self, x):
        y = self.linear(x)
        return y

net = LinearNet(num_input)
# LinearNet is a plain nn.Module, not nn.Sequential, so it cannot be indexed
# with net[0]; initialise its layer through the attribute name instead.
init.normal_(net.linear.weight, mean=0, std=0.01)
init.constant_(net.linear.bias, val=0)
loss = nn.MSELoss()  # mean squared error over all elements
optimizer = optim.SGD(net.parameters(), lr=0.03)

6. Train PyTorch Model

# PyTorch training loop; reuses data_iter, features, labels and batch_size
# from the manual version above.
num_epochs = 3
for epoch in range(1, num_epochs + 1):
    # data_iter is a generator *function*: it must be called to get batches.
    for X, y in data_iter(batch_size, features, labels):
        output = net(X)
        # Reshape labels from (batch,) to (batch, 1) to match net's output.
        l_sum = loss(output, y.view(-1, 1))
        optimizer.zero_grad()
        l_sum.backward()
        optimizer.step()
    # Report loss over the whole dataset rather than just the last mini-batch.
    with torch.no_grad():
        train_l = loss(net(features), labels.view(-1, 1))
    print('epoch %d, loss: %f' % (epoch, train_l.item()))

7. Transformer Implementation

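The following sections build the core components of a Transformer: Attention, Multi-Head Attention, Encoder, Decoder, and the full model. The implementation is deliberately minimal. It uses a single attention layer per encoder and decoder, and it omits positional encoding and the position-wise feed-forward network. It also assumes that several names are defined elsewhere:

- the hyperparameters `d_model`, `d_k`, `d_v`, and `n_heads`;
- the vocabularies `source_vocab` and `target_vocab`;
- the mask helpers `get_attn_pad_mask` and `get_attn_subsequent_mask`;
- the training tensors `encoder_input`, `decoder_input`, and `target`, plus `device`.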

7.1 Attention

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask=None):
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q: queries, shape ``(..., len_q, d_k)``.
            K: keys, shape ``(..., len_k, d_k)``.
            V: values, shape ``(..., len_k, d_v)``.
            attn_mask: optional boolean tensor broadcastable to
                ``(..., len_q, len_k)``; ``True`` entries are blocked. ``None``
                disables masking.

        Returns:
            Context tensor of shape ``(..., len_q, d_v)``.
        """
        # Take the scale from the queries themselves instead of a global d_k.
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / d_k ** 0.5
        if attn_mask is not None:
            # A large negative (not -inf) keeps fully-masked rows NaN-free.
            scores = scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V)

7.2 Multi‑Head Attention

class MultiHeadAttention(nn.Module):
    """Multi-head attention sub-layer with a residual connection and LayerNorm.

    The module-level hyperparameters ``d_model``, ``d_k``, ``d_v`` and
    ``n_heads`` are read once, at construction time.
    """

    def __init__(self):
        super().__init__()
        # Snapshot the hyperparameters so forward() splits heads with exactly the
        # sizes the projection layers below were built with.
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(d_v * n_heads, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)
        # Parameter-free, so adding it does not change the state_dict.
        self.scaled_attention = ScaledDotProductAttention()

    def forward(self, input_Q, input_K, input_V, attn_mask):
        """Return ``LayerNorm(input_Q + fc(multi-head attention))``.

        Inputs are ``(batch, seq_len, d_model)``. ``attn_mask`` is a 3-D mask,
        presumably ``(batch, len_q, len_k)`` as built by ``get_attn_pad_mask``
        (TODO confirm); it is shared by every head.
        """
        residual, batch = input_Q, input_Q.size(0)
        # Project, then split the last dim into heads: (batch, n_heads, seq, d_k).
        Q = self.W_Q(input_Q).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(input_K).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(input_V).view(batch, -1, self.n_heads, self.d_v).transpose(1, 2)
        # expand() gives each head a view of the same mask without copying it.
        attn_mask = attn_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        context = self.scaled_attention(Q, K, V, attn_mask)
        # Merge heads back: (batch, seq, n_heads * d_v).
        context = context.transpose(1, 2).contiguous().view(batch, -1, self.n_heads * self.d_v)
        output = self.fc(context)
        return self.layer_norm(residual + output)

7.3 Encoder

class Encoder(nn.Module):
    """Minimal encoder: token embedding followed by one self-attention sub-layer.

    NOTE: no positional encoding or feed-forward sub-layer is applied here.
    """

    def __init__(self):
        super().__init__()
        self.source_embedding = nn.Embedding(len(source_vocab), d_model)
        self.attention = MultiHeadAttention()

    def forward(self, encoder_input):
        # Padding mask comes from the raw token ids (helper defined elsewhere).
        pad_mask = get_attn_pad_mask(encoder_input, encoder_input)
        x = self.source_embedding(encoder_input)
        return self.attention(x, x, x, pad_mask)

7.4 Decoder

class Decoder(nn.Module):
    """Minimal decoder: embedding, masked self-attention, then cross-attention.

    Self-attention and encoder-decoder attention are separate sub-layers with
    their own weights. Previously one module served both, so the two roles
    were forced to share parameters.
    """

    def __init__(self):
        super().__init__()
        self.target_embedding = nn.Embedding(len(target_vocab), d_model)
        # Masked self-attention over the target prefix.
        self.attention = MultiHeadAttention()
        # Queries from the decoder, keys/values from the encoder output.
        self.cross_attention = MultiHeadAttention()

    def forward(self, decoder_input, encoder_input, encoder_output):
        decoder_embedded = self.target_embedding(decoder_input)
        # Block both padding positions and future positions: the sum is
        # non-zero wherever either mask is set.
        decoder_self_attn_mask = get_attn_pad_mask(decoder_input, decoder_input)
        decoder_subsequent_mask = get_attn_subsequent_mask(decoder_input)
        decoder_self_mask = torch.gt(decoder_self_attn_mask + decoder_subsequent_mask, 0)
        decoder_output = self.attention(decoder_embedded, decoder_embedded, decoder_embedded, decoder_self_mask)
        # Mask built from decoder (query) ids and encoder (key) ids.
        decoder_encoder_attn_mask = get_attn_pad_mask(decoder_input, encoder_input)
        return self.cross_attention(decoder_output, encoder_output, encoder_output, decoder_encoder_attn_mask)

7.5 Full Transformer Model

class Transformer(nn.Module):
    """Encoder-decoder model with a bias-free projection to target-vocab logits."""

    def __init__(self):
        super().__init__()
        # Construction order is kept (encoder, decoder, fc) so parameter
        # initialisation consumes the RNG in the same sequence.
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.fc = nn.Linear(d_model, len(target_vocab), bias=False)

    def forward(self, encoder_input, decoder_input):
        memory = self.encoder(encoder_input)
        hidden = self.decoder(decoder_input, encoder_input, memory)
        logits = self.fc(hidden)
        # Flatten every leading dim so the result is (N, vocab_size), the shape
        # nn.CrossEntropyLoss expects.
        return logits.view(-1, logits.size(-1))

8. Training the Transformer

model = Transformer().to(device)
criterion = nn.CrossEntropyLoss()
# NOTE(review): lr=1e-1 is unusually high for Adam; try 1e-3 if loss diverges.
optimizer = optim.Adam(model.parameters(), lr=1e-1)

for epoch in range(10):
    # The model returns flattened logits, so flatten the targets to match.
    output = model(encoder_input, decoder_input)
    loss = criterion(output, target.view(-1))
    print('Epoch:', f'{epoch + 1:04d}', 'loss =', f'{loss.item():.6f}')
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

9. Inference with the Trained Model

# Greedy decoding: feed the prefix decoded so far and append the arg-max
# token for the current position, one step at a time.
# NOTE(review): decoder_input below has batch size 1, so encoder_input is
# presumably a single sentence here — confirm.
model.eval()
# Reverse vocabulary lookup built once; setdefault keeps the first token seen
# for an id, matching a linear scan that stops at the first match.
id_to_token = {}
for token, index in target_vocab.items():
    id_to_token.setdefault(index, token)

target_len = len(target_vocab)
with torch.no_grad():  # inference only: no autograd graphs
    encoder_output = model.encoder(encoder_input)
    decoder_input = torch.zeros(1, target_len).type_as(encoder_input)
    next_symbol = 4  # start token
    for i in range(target_len):
        decoder_input[0][i] = next_symbol
        decoder_output = model.decoder(decoder_input, encoder_input, encoder_output)
        logits = model.fc(decoder_output).squeeze(0)  # (target_len, vocab)
        predicted = logits.max(dim=1)[1]
        next_symbol = predicted[i].item()
        if next_symbol in id_to_token:
            print('Step', i, ':', id_to_token[next_symbol])
        # Token id 0 stops decoding (presumably padding/end — confirm vocab).
        if next_symbol == 0:
            break

10. Reference Materials

https://jalammar.github.io/illustrated-transformer/

http://nlp.seas.harvard.edu/annotated-transformer/

ChatGPT assistance during learning

Original Source

Signed-in readers can open the original source through BestHub's protected redirect.

Sign in to view source
Republication Notice

This article has been distilled and summarized from source material, then republished for learning and reference. If you believe it infringes your rights, please contact admin@besthub.dev and we will review it promptly.

Tags: Deep Learning · Transformer · Attention · NLP · PyTorch
Alibaba Cloud Developer
Written by

Alibaba Cloud Developer

Alibaba's official tech channel, featuring all of its technology innovations.

0 followers
Reader feedback

How this landed with the community

Sign in to like

Rate this article

Was this worth your time?

Sign in to rate
Discussion

0 Comments

Thoughtful readers leave field notes, pushback, and hard-won operational detail here.