
Analysis of LLaMA Model Architecture in the Transformers Library

This article walks through the core LLaMA implementation in HuggingFace’s Transformers library, detailing the inheritance hierarchy, configuration defaults, model initialization, embedding and stacked decoder layers, the RMSNorm‑based attention and MLP modules, and the forward pass that produces normalized hidden states.

Sohu Tech Products

This document provides a detailed walkthrough of the LLaMA model implementation in the transformers repository, focusing on the core model structure while omitting tensor‑parallel and memory‑saving code.

Model inheritance: The main model class LlamaModel inherits from LlamaPreTrainedModel, which in turn inherits from PreTrainedModel, the base class for all HuggingFace models.

LlamaModel -> LlamaPreTrainedModel -> PreTrainedModel

LlamaConfig: The configuration class defines hyper-parameters such as vocab_size, hidden_size, num_hidden_layers, and num_attention_heads. All parameters have sensible defaults, allowing a config object to be instantiated directly.

config = LlamaConfig()
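To make the defaults concrete, here is a minimal stand-in (not the real transformers class) whose values mirror the LLaMA-7B-sized defaults that LlamaConfig ships with; treat the exact numbers as an assumption to verify against your installed transformers version:

```python
from dataclasses import dataclass

# A minimal stand-in for LlamaConfig, NOT the real transformers class.
# The defaults below mirror the 7B-sized values used by LlamaConfig.
@dataclass
class MiniLlamaConfig:
    vocab_size: int = 32000
    hidden_size: int = 4096
    intermediate_size: int = 11008
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 32   # equal to num_attention_heads -> plain MHA
    rms_norm_eps: float = 1e-6

# Defaults can be overridden per keyword, just like the real config:
tiny = MiniLlamaConfig(hidden_size=256, num_hidden_layers=4, num_attention_heads=8)
```

Overriding a handful of fields like this is how small debug models are typically built for unit tests.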

LlamaModel initialization: The constructor stores the padding index and vocabulary size, creates the token embedding layer, a ModuleList of decoder layers, and a final RMSNorm layer, records which attention implementation is in use, and sets a flag for gradient checkpointing. It also calls post_init() to initialize weights.

def __init__(self, config: LlamaConfig):
    super().__init__(config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
    self.layers = nn.ModuleList([LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
    self._use_sdpa = config._attn_implementation == "sdpa"
    self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
    self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.gradient_checkpointing = False
    self.post_init()

The post_init() method (defined in PreTrainedModel) performs weight initialization and handles backward-compatibility for gradient checkpointing.

def post_init(self):
    """Execute code after model initialization, e.g., weight init."""
    self.init_weights()
    self._backward_compatibility_gradient_checkpointing()
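The underlying pattern, where the base class owns weight initialization and every subclass calls post_init() as the last statement of its own __init__, can be illustrated with a small stand-in hierarchy (the class names here are illustrative, not the real transformers code):

```python
# Illustrative sketch of the post_init() pattern, not the real transformers code.
class BasePreTrainedModel:
    def post_init(self):
        # In transformers this also handles gradient-checkpointing compatibility.
        self.init_weights()

    def init_weights(self):
        # Real models initialize nn.Module parameters here; we just set a flag.
        self.weights_initialized = True

class ToyModel(BasePreTrainedModel):
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size   # build all submodules first...
        self.post_init()                 # ...then initialize weights last

model = ToyModel(hidden_size=16)
```

Calling post_init() last guarantees that every submodule already exists when the initialization sweep runs over the model.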

LlamaModel forward pass: Input token IDs are embedded, then passed through each decoder layer. Hidden states are optionally collected at every layer, and the final hidden states are normalized before being returned as a BaseModelOutputWithPast object.

inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
for decoder_layer in self.layers:
    if output_hidden_states:
        all_hidden_states += (hidden_states,)
    layer_outputs = decoder_layer(
        hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_values,
        output_attentions=output_attentions,
        use_cache=use_cache,
    )
    hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
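The shape-preserving structure of this loop can be sketched with toy stand-in layers (plain numpy matrices instead of real decoder layers): the residual stream keeps its (batch, seq_len, hidden) shape through every layer, and when output_hidden_states is enabled, the per-layer states accumulate into a tuple.

```python
import numpy as np

# Toy sketch of the decoder-stack loop: stand-in "layers" preserve the
# (batch, seq_len, hidden) shape, and per-layer hidden states are collected
# the way LlamaModel does when output_hidden_states=True.
batch, seq_len, hidden = 2, 5, 8
num_layers = 4
rng = np.random.default_rng(0)
layer_weights = [rng.normal(size=(hidden, hidden)) * 0.01 for _ in range(num_layers)]

hidden_states = rng.normal(size=(batch, seq_len, hidden))  # plays "inputs_embeds"
all_hidden_states = ()
for w in layer_weights:
    all_hidden_states += (hidden_states,)              # state *entering* the layer
    hidden_states = hidden_states + hidden_states @ w  # residual stand-in "layer"
all_hidden_states += (hidden_states,)                  # final states are kept too
```

With num_layers layers, the collected tuple holds num_layers + 1 entries: the input embeddings plus the output of each layer.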

LlamaDecoderLayer: Each layer contains a self-attention module, an MLP, and two RMSNorm layers (one before attention, one before the MLP), following the pre-norm design. The forward method saves a residual, applies the input layer-norm, runs attention, adds the residual back, then applies the post-attention layer-norm, runs the MLP, and adds a second residual.

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False, **kwargs):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs

Attention module (LlamaAttention): Implements multi-head attention with rotary position embeddings and grouped-query attention. It projects the hidden states into query, key, and value tensors, reshapes them into per-head form, applies rotary embeddings, expands the key/value heads with repeat_kv, computes scaled dot-product attention scores, adds the attention mask, applies softmax and dropout, multiplies by the value tensor, reshapes, and finally projects back with o_proj.

class LlamaAttention(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                past_key_value=None, output_attentions=False, use_cache=False, **kwargs):
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        kv_seq_len = key_states.shape[-2]  # past_key_value handling is omitted in this excerpt
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights, past_key_value
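The repeat_kv helper used above is what makes grouped-query attention work: each of the num_key_value_heads is duplicated n_rep times so that keys and values line up with the num_heads query heads. A numpy sketch of the same expansion (the real helper operates on torch tensors with torch.expand and reshape):

```python
import numpy as np

# Numpy sketch of repeat_kv for grouped-query attention: each KV head is
# duplicated n_rep times along the head axis so KV tensors match the number
# of query heads. The real transformers helper does this on torch tensors.
def repeat_kv(x: np.ndarray, n_rep: int) -> np.ndarray:
    bsz, num_kv_heads, seq_len, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :]                              # insert a repeat axis
    x = np.broadcast_to(x, (bsz, num_kv_heads, n_rep, seq_len, head_dim))
    return x.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

kv = np.arange(2 * 2 * 3 * 4).reshape(2, 2, 3, 4).astype(np.float32)
expanded = repeat_kv(kv, n_rep=4)   # 2 KV heads serve 8 query heads
```

Query heads 0 to 3 all read the same key/value head 0, and heads 4 to 7 read head 1, which is exactly how GQA shares KV cache across query-head groups.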

LlamaMLP: A gated feed-forward network consisting of a gate projection, an up projection, and a down projection. The gate output is passed through the activation function (SiLU for LLaMA) before being multiplied element-wise with the up-projected tensor.

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
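The same SwiGLU-style computation, down(silu(gate(x)) * up(x)), can be sketched with plain numpy matrices to make the shapes explicit (weights here are random stand-ins, not trained parameters):

```python
import numpy as np

# Numpy sketch of LLaMA's gated MLP (SwiGLU): down(silu(gate(x)) * up(x)).
# Shapes follow hidden_size -> intermediate_size -> hidden_size.
def silu(x):
    return x / (1.0 + np.exp(-x))   # x * sigmoid(x)

hidden, intermediate = 8, 22
rng = np.random.default_rng(0)
W_gate = rng.normal(size=(hidden, intermediate))
W_up = rng.normal(size=(hidden, intermediate))
W_down = rng.normal(size=(intermediate, hidden))

x = rng.normal(size=(2, 5, hidden))                 # (batch, seq_len, hidden)
out = (silu(x @ W_gate) * (x @ W_up)) @ W_down      # back to hidden_size
```

Note that the gate and up branches both expand to intermediate_size; their element-wise product is what the down projection maps back to hidden_size.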

LlamaRMSNorm: Implements Root Mean Square Layer Normalization, which normalizes inputs by their RMS value and scales them with a learnable weight.

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
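A quick numeric check makes the behavior concrete: unlike LayerNorm, RMSNorm does not subtract the mean, it only divides by the root-mean-square of the last axis, so with the weight fixed at ones the output has unit RMS. A numpy sketch of the same computation:

```python
import numpy as np

# Numpy sketch of RMSNorm: divide by the root-mean-square over the last axis,
# then scale by a learned weight (ones here). No mean-centering is performed.
def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    variance = np.mean(x ** 2, axis=-1, keepdims=True)   # mean of squares
    return weight * (x / np.sqrt(variance + eps))

x = np.array([[1.0, 2.0, 3.0, 4.0]])
y = rms_norm(x, np.ones(4))
rms = np.sqrt(np.mean(y ** 2, axis=-1))   # ~1.0 for every row
```

The float32 upcast in the real module exists because this mean-of-squares can lose precision in fp16/bf16; the math itself is just the division above.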

In summary, the core components of the LLaMA model (token embedding, stacked decoder layers, multi-head attention with rotary embeddings, RMSNorm, and the gated MLP) have now been covered, providing a solid foundation for further experimentation or extension.

Artificial Intelligence · Deep Learning · Transformer · Llama · PyTorch · Model Architecture
Written by

Sohu Tech Products

A knowledge-sharing platform for Sohu's technology products. As a leading Chinese internet brand with media, video, search, and gaming services and over 700 million users, Sohu continuously drives tech innovation and practice. We’ll share practical insights and tech news here.
