How to Add Special Tokens to LLMs Without Losing Performance
This guide explains why naïvely adding special tokens during supervised fine‑tuning can destabilize a large language model, and provides step‑by‑step strategies—including tokenizer updates, embedding resizing, smart initialization, and LoRA‑based PEFT—to integrate new tokens while preserving the model's original capabilities.
Background
During the supervised fine‑tuning (SFT) stage of a large language model (LLM), adding special tokens such as <|user|> or <|assistant|> is common, but if the new tokens are not handled properly they can severely degrade the model’s original capabilities.
Core Issue
The newly added tokens have no pretrained vectors in either the embedding matrix or the LM head, so their rows must be initialized from scratch; training on them then produces large gradient spikes, noisy predictions, and catastrophic forgetting of the model's pretrained capabilities.
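To make the issue concrete, the minimal sketch below (using the same meta-llama/Llama-2-7b-hf checkpoint as the steps that follow) performs a naïve resize and compares the norm of the freshly added embedding row with the average norm of the pretrained rows. Depending on the transformers version, the new row is drawn from the model's default random initializer or from a mean-based distribution, but either way it carries no learned meaning.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

old_vocab_size = model.get_input_embeddings().weight.shape[0]
tokenizer.add_special_tokens({"additional_special_tokens": ["<|user|>"]})
model.resize_token_embeddings(len(tokenizer))  # naive resize: the new row is untrained

embeddings = model.get_input_embeddings().weight.data
print("pretrained mean row norm:", embeddings[:old_vocab_size].norm(dim=-1).mean().item())
print("new-token row norm:      ", embeddings[old_vocab_size:].norm(dim=-1).mean().item())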
Key Steps
Step 1 – Add tokens to the tokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
special_tokens_to_add = ["<|user|>", "<|assistant|>", "<|endoftext|>"]
tokenizer.add_special_tokens({
    "additional_special_tokens": special_tokens_to_add,
    "pad_token": "<|endoftext|>"
})
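# Optional sanity check (a sketch): each new special token should map to a single,
# freshly assigned id rather than being split into sub-word pieces.
for special_token in special_tokens_to_add:
    print(special_token, "->", tokenizer.convert_tokens_to_ids(special_token))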
print(f"Tokenizer vocabulary size: {len(tokenizer)}")Step 2 – Resize model embeddings and LM head
model.resize_token_embeddings(len(tokenizer))
print(f"Model input embedding size: {model.get_input_embeddings().weight.shape[0]}")
print(f"Model output head size: {model.get_output_embeddings().weight.shape[0]}")Step 3 – Initialise new token embeddings
Strategy A – Average of existing embeddings (recommended)
import torch
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
old_vocab_size = input_embeddings.weight.shape[0] - len(special_tokens_to_add)
avg_embedding = input_embeddings.weight.data[:old_vocab_size].mean(dim=0, keepdim=True)
with torch.no_grad():
    input_embeddings.weight.data[old_vocab_size:] = avg_embedding.clone()
    if output_embeddings is not None and output_embeddings.weight.shape[0] == len(tokenizer):
        output_embeddings.weight.data[old_vocab_size:] = avg_embedding.clone()
print("New special token embeddings have been initialized with the average of old embeddings.")Strategy B – Use semantically similar tokens
# Example: initialise with "user" and "assistant" embeddings
user_token_id = tokenizer.encode("user", add_special_tokens=False)[0]
assistant_token_id = tokenizer.encode("assistant", add_special_tokens=False)[0]
user_embedding = input_embeddings.weight.data[user_token_id].clone()
assistant_embedding = input_embeddings.weight.data[assistant_token_id].clone()
new_user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
new_assistant_token_id = tokenizer.convert_tokens_to_ids("<|assistant|>")
with torch.no_grad():
    input_embeddings.weight.data[new_user_token_id] = user_embedding
    input_embeddings.weight.data[new_assistant_token_id] = assistant_embedding
    # The corresponding LM-head rows can be initialised the same way, as in Strategy A.
Step 4 – Choose fine‑tuning method
Strategy A – Parameter‑efficient fine‑tuning (PEFT) with LoRA (recommended)
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # modules_to_save=["embed_tokens", "lm_head"]  # uncomment to fully train the resized embedding and LM head for the new tokens
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()
Using modules_to_save forces the embedding and LM head to be fully trainable, which is necessary for the new tokens, while LoRA keeps the rest of the model frozen.
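As a minimal sketch of that recommendation (the output directory name is a placeholder, and the module names embed_tokens / lm_head assume a Llama-style architecture), the config above can be rebuilt with modules_to_save enabled, and the enlarged tokenizer should be saved alongside the adapter so the new token ids survive reloading.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    modules_to_save=["embed_tokens", "lm_head"],  # train the resized embedding and LM head in full
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()  # embed_tokens / lm_head now appear among trainable modules

# After SFT, save the adapter together with the enlarged tokenizer; without the
# tokenizer the new special-token ids cannot be recovered at load time.
peft_model.save_pretrained("sft-adapter")   # placeholder output directory
tokenizer.save_pretrained("sft-adapter")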
Strategy B – Staged full fine‑tuning (if LoRA is not used)
First freeze most layers and train only embed_tokens and lm_head for a few steps, then unfreeze all layers for complete SFT.
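A minimal sketch of the stage-one freeze (assuming get_input_embeddings() and get_output_embeddings() return the token embedding and LM head, as they do for Llama-style models in transformers):
# Stage 1: freeze everything, then unfreeze only the embedding matrix and LM head
# so the new token vectors can settle before full fine-tuning.
for param in model.parameters():
    param.requires_grad = False
for param in model.get_input_embeddings().parameters():
    param.requires_grad = True
for param in model.get_output_embeddings().parameters():
    param.requires_grad = True
# ... run a short warm-up SFT phase here ...

# Stage 2: unfreeze all layers for the complete SFT run.
for param in model.parameters():
    param.requires_grad = True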
Summary and Best Practices
Add special tokens with tokenizer.add_special_tokens.
Resize model embeddings via model.resize_token_embeddings.
Initialise new embeddings with the average of existing embeddings (or semantically similar tokens).
Prefer LoRA‑based PEFT to keep most pretrained weights frozen while learning new token usage.
Ensure SFT data consistently uses the new tokens and evaluate both task performance and general benchmarks (e.g., MMLU, C‑Eval) to detect any degradation.
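As an illustration of consistent data formatting (the template itself is a project-level choice, not something the base model mandates), each SFT example can be rendered with the new tokens, for instance:
# Hypothetical formatting helper using the special tokens added above.
def format_example(user_message: str, assistant_message: str) -> str:
    return "<|user|>" + user_message + "<|assistant|>" + assistant_message + "<|endoftext|>"

sample = format_example("What does LoRA do?", "It injects low-rank adapters into frozen weight matrices.")
print(tokenizer(sample, add_special_tokens=False)["input_ids"])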
Baobao Algorithm Notes
Author of the BaiMian large model, offering technology and industry insights.