Hands‑On Building a Transformer from Scratch with PyTorch
This tutorial walks you through implementing a full Transformer model in PyTorch, starting from basic linear‑regression code, adding attention mechanisms, multi‑head attention, encoder‑decoder architecture, training loops, and inference, all reinforced with practical debugging tips.
As an engineering learner, the author stresses that writing a Transformer yourself solidifies understanding, so the guide focuses on practical PyTorch implementation rather than theoretical algorithm explanations.
1. Prerequisite Knowledge
First, a simple linear‑regression example is built to illustrate data preparation and training basics.
# Sample count
num_examples = 1000
# Feature count: house size, age
num_input = 2
# Ground-truth weights and bias used to synthesize the labels
true_w = [2, -3.4]
true_b = 4.2
# Features are drawn from a standard normal distribution
features = torch.randn(num_examples, num_input, dtype=torch.float32)
# Noise-free targets: y = w0 * x0 + w1 * x1 + b
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
# Perturb the targets with Gaussian noise (mean 0, std 0.01)
_temp = torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float32)
labels = labels + _temp

# 2. Define Model (Manual Version)
# Model parameters w and b with initial values
w = torch.tensor(np.random.normal(0, 0.01, (num_input, 1)), dtype=torch.float32)
b = torch.zeros(1, dtype=torch.float32)
# Enable gradient tracking
w.requires_grad_(True)
b.requires_grad_(True)
# Linear regression function
def linreg(X, w, b):
    """Return the linear-model prediction ``X @ w + b``.

    ``torch.mm`` is a plain matrix product, so ``X`` and ``w`` must both be
    2-D; ``b`` is added to the product by broadcasting.
    """
    return torch.mm(X, w) + b
# Squared loss (MSE)
def squared_loss(y_pred, y):
    """Return the element-wise halved squared error ``(y_pred - y) ** 2 / 2``.

    ``y`` is first viewed with ``y_pred``'s shape so that a flat label vector
    is not broadcast against a column of predictions. No reduction is
    applied; the result has the same shape as ``y_pred``.
    """
    return (y_pred - y.view(y_pred.size())) ** 2 / 2
# Stochastic gradient descent
def sgd(params, lr, batch):
    """Apply one in-place mini-batch SGD step to every tensor in ``params``.

    Each parameter is moved by ``-lr * grad / batch``. Dividing by the batch
    size means a caller that back-propagates a *summed* loss still takes a
    mean-gradient step. Gradients are left untouched; the caller is
    responsible for zeroing them afterwards.
    """
    # The update itself must not be recorded by autograd.
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch

# 3. Train Model (Manual Version)
lr = 0.03
num_epochs = 5
batch_size = 10
# NOTE: data_iter is defined in section 4 below; it must be defined before
# this loop is executed.
for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        # Sum the per-sample losses; sgd() divides the step by batch_size.
        ls = squared_loss(linreg(X, w, b), y).sum()
        ls.backward()
        sgd([w, b], lr, batch_size)
        # Gradients accumulate across backward() calls, so reset them.
        w.grad.data.zero_()
        b.grad.data.zero_()
    # Report the mean loss over the whole data set after each epoch;
    # no gradients are needed for this evaluation pass.
    with torch.no_grad():
        train_l = squared_loss(linreg(features, w, b), labels)
    print('epoch %d, loss %f' % (epoch + 1, train_l.mean().item()))

# 4. Auxiliary Function: Data Iterator
def data_iter(batch, features, labels):
    """Yield shuffled ``(features, labels)`` mini-batches of size ``batch``.

    The sample order is reshuffled on every call. Every sample is yielded
    exactly once; the final batch is shorter when the sample count is not a
    multiple of ``batch``.
    """
    nums = len(features)
    indices = list(range(nums))
    random.shuffle(indices)
    for i in range(0, nums, batch):
        # Slicing clamps at the end of the list, so no explicit min() is needed.
        j = torch.tensor(indices[i:i + batch], dtype=torch.long)
        yield features.index_select(0, j), labels.index_select(0, j)

# 5. PyTorch Version of Linear Model
class LinearNet(nn.Module):
    """Linear regression model: one fully connected layer with a single output."""

    def __init__(self, n_feature):
        """Create the layer mapping ``n_feature`` inputs to one output."""
        super().__init__()
        self.linear = nn.Linear(n_feature, 1)

    def forward(self, x):
        """Return the prediction of the linear layer for ``x``."""
        return self.linear(x)
net = LinearNet(num_input)
# LinearNet is a plain nn.Module (not nn.Sequential), so its layer is reached
# through the ``linear`` attribute rather than by index.
init.normal_(net.linear.weight, mean=0, std=0.01)
init.constant_(net.linear.bias, val=0)
loss = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.03)

# 6. Train PyTorch Model
num_epochs = 3
for epoch in range(1, num_epochs + 1):
    # data_iter is a generator function and has to be called to produce
    # batches; batch_size, features and labels come from the earlier sections.
    for X, y in data_iter(batch_size, features, labels):
        output = net(X)
        # Labels are reshaped to a column to match the network output.
        l_sum = loss(output, y.view(-1, 1))
        # Clear stale gradients before computing the new ones.
        optimizer.zero_grad()
        l_sum.backward()
        optimizer.step()
    # Reports the loss of the last mini-batch of the epoch.
    print('epoch %d, loss: %f' % (epoch, l_sum.item()))

# 7. Transformer Implementation
The following sections build the core components of a Transformer: Attention, Multi‑Head Attention, Encoder, Decoder, and the full model.
7.1 Attention
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        """Return the attention-weighted combination of ``V``.

        ``attn_mask`` is a boolean tensor broadcastable to the score matrix;
        positions where it is True are excluded from attention.
        """
        # Take d_k from the query's last dimension instead of relying on a
        # module-level global.
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / d_k ** 0.5
        # A very negative score becomes a ~0 weight after the softmax. The
        # out-of-place fill avoids mutating a tensor autograd may still need.
        scores = scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V)

# 7.2 Multi‑Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and layer norm.

    The sizes ``d_model``, ``d_k``, ``d_v`` and ``n_heads`` are read from
    module-level globals when the layer is constructed and are then kept on
    the instance, so ``forward`` no longer depends on the globals.
    """

    def __init__(self):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(d_v * n_heads, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        """Attend from ``input_Q`` to ``input_K``/``input_V`` under ``attn_mask``.

        Returns a tensor with the same shape as ``input_Q``.
        """
        # assumes attn_mask is (batch, query_len, key_len) — TODO confirm
        # against get_attn_pad_mask, which is defined outside this snippet.
        residual, batch = input_Q, input_Q.size(0)
        # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
        Q = self.W_Q(input_Q).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(input_K).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(input_V).view(batch, -1, self.n_heads, self.d_v).transpose(1, 2)
        # Every head uses the same mask.
        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)
        prob = ScaledDotProductAttention()(Q, K, V, attn_mask)
        # (batch, heads, seq, d_v) -> (batch, seq, heads * d_v)
        prob = prob.transpose(1, 2).contiguous().view(batch, -1, self.n_heads * self.d_v)
        output = self.fc(prob)
        return self.layer_norm(residual + output)

# 7.3 Encoder
class Encoder(nn.Module):
    """Simplified encoder: token embedding followed by one self-attention block.

    The vocabulary (``source_vocab``) and ``d_model`` are module-level globals.
    """

    def __init__(self):
        super().__init__()
        self.source_embedding = nn.Embedding(len(source_vocab), d_model)
        self.attention = MultiHeadAttention()

    def forward(self, encoder_input):
        """Embed ``encoder_input`` token ids and apply self-attention."""
        embedded = self.source_embedding(encoder_input)
        # get_attn_pad_mask is defined outside this snippet; presumably it
        # marks padding positions so they are ignored — verify.
        mask = get_attn_pad_mask(encoder_input, encoder_input)
        return self.attention(embedded, embedded, embedded, mask)

# 7.4 Decoder
class Decoder(nn.Module):
    """Simplified decoder: masked self-attention, then encoder-decoder attention.

    The vocabulary (``target_vocab``) and ``d_model`` are module-level globals.
    """

    def __init__(self):
        super().__init__()
        self.target_embedding = nn.Embedding(len(target_vocab), d_model)
        # Masked self-attention over the target sequence.
        self.attention = MultiHeadAttention()
        # Encoder-decoder attention gets its own layer so the two sub-layers
        # do not share projection and layer-norm weights.
        self.cross_attention = MultiHeadAttention()

    def forward(self, decoder_input, encoder_input, encoder_output):
        """Decode ``decoder_input`` token ids against ``encoder_output``.

        ``encoder_input`` is only used to build the encoder-decoder mask.
        """
        decoder_embedded = self.target_embedding(decoder_input)
        # Both mask helpers are defined outside this snippet; presumably one
        # marks padding and the other marks future positions — verify.
        decoder_self_attn_mask = get_attn_pad_mask(decoder_input, decoder_input)
        decoder_subsequent_mask = get_attn_subsequent_mask(decoder_input)
        # A position is masked when either mask flags it.
        decoder_self_mask = torch.gt(decoder_self_attn_mask + decoder_subsequent_mask, 0)
        decoder_output = self.attention(decoder_embedded, decoder_embedded, decoder_embedded, decoder_self_mask)
        decoder_encoder_attn_mask = get_attn_pad_mask(decoder_input, encoder_input)
        return self.cross_attention(decoder_output, encoder_output, encoder_output, decoder_encoder_attn_mask)

# 7.5 Full Transformer Model
class Transformer(nn.Module):
    """Encoder-decoder model with a linear projection onto the target vocabulary."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.fc = nn.Linear(d_model, len(target_vocab), bias=False)

    def forward(self, encoder_input, decoder_input):
        """Return vocabulary logits flattened to two dimensions.

        The result has ``len(target_vocab)`` columns and one row per decoder
        position across the whole batch.
        """
        enc_out = self.encoder(encoder_input)
        dec_out = self.decoder(decoder_input, encoder_input, enc_out)
        return self.fc(dec_out).view(-1, self.fc.out_features)

# 8. Training the Transformer
model = Transformer().to(device)
criterion = nn.CrossEntropyLoss()
# NOTE(review): lr=1e-1 is far above Adam's default of 1e-3 — confirm it is
# intentional if training turns out to be unstable.
optimizer = optim.Adam(model.parameters(), lr=1e-1)
for epoch in range(10):
    # The whole data set is processed as a single batch each epoch.
    output = model(encoder_input, decoder_input)
    # The model returns flattened logits, so the targets are flattened too.
    loss = criterion(output, target.view(-1))
    print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss.item()))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 9. Inference with the Trained Model
# NOTE(review): the vocabulary size is used as the maximum decode length —
# confirm this is intended.
target_len = len(target_vocab)
# Reverse lookup (id -> token), built once instead of scanning the vocabulary
# at every step; setdefault keeps the first token when ids repeat.
id_to_token = {}
for k, v in target_vocab.items():
    id_to_token.setdefault(v, k)

model.eval()
# Decoding needs no gradients.
with torch.no_grad():
    encoder_output = model.encoder(encoder_input)
    # assumes encoder_input holds a single sequence (batch size 1) — TODO confirm
    decoder_input = torch.zeros(1, target_len).type_as(encoder_input.data)
    next_symbol = 4  # start token
    for i in range(target_len):
        decoder_input[0][i] = next_symbol
        decoder_output = model.decoder(decoder_input, encoder_input, encoder_output)
        logits = model.fc(decoder_output).squeeze(0)
        # Greedy choice: the highest-scoring token id at every position.
        prob = logits.max(dim=1)[1]
        next_symbol = prob[i].item()
        if next_symbol in id_to_token:
            print('Step', i, ':', id_to_token[next_symbol])
        # Id 0 stops decoding — presumably the end/padding token; verify.
        if next_symbol == 0:
            break

# 10. Reference Materials
https://jalammar.github.io/illustrated-transformer/
http://nlp.seas.harvard.edu/annotated-transformer/
ChatGPT assistance during learning
Signed-in readers can open the original source through BestHub's protected redirect.
This article has been distilled and summarized from source material, then republished for learning and reference. If you believe it infringes your rights, please contact us and we will review it promptly.
Alibaba Cloud Developer
Alibaba's official tech channel, featuring all of its technology innovations.
How this landed with the community
Was this worth your time?
0 Comments
Thoughtful readers leave field notes, pushback, and hard-won operational detail here.
