Pai‑Megatron‑Patch: Design Principles, Key Features, and End‑to‑End Usage for Large Language Model Training
This article introduces Pai‑Megatron‑Patch, an open‑source tool from Alibaba Cloud. It explains the tool's non‑intrusive patch architecture, enumerates the supported models and features (weight conversion, Flash‑Attention 2.0, and FP8 training with Transformer Engine), and provides command‑line examples for model conversion, pre‑training, supervised fine‑tuning, inference, and RLHF pipelines.
With the rapid evolution of large language models (LLMs), developers need efficient training tools that hide the complexity of distributed training, model conversion, and performance optimization. Pai‑Megatron‑Patch, released by Alibaba Cloud's PAI algorithm team, is a patch‑based extension for Megatron‑LM that enables seamless integration with the PAI‑Lingjun platform without modifying Megatron's source code.
Main Features
Supports many popular LLMs (LLaMA, LLaMA‑2, CodeLLaMA, Baichuan, Qwen, Falcon, GLM, StarCoder, Bloom, ChatGLM, etc.).
Bidirectional weight conversion between HuggingFace, Megatron, and Transformer Engine formats.
Flash‑Attention 2.0 and Transformer Engine FP8 acceleration with guaranteed convergence.
Comprehensive examples covering pre‑training, supervised fine‑tuning, evaluation, inference, and reinforcement learning (RLHF).
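To give a flavor of what bidirectional weight conversion involves, the sketch below renames HuggingFace‑style LLaMA parameter keys to Megatron‑style names with a small regex table. The key patterns here are illustrative assumptions, not the tool's actual mapping tables, which are far more complete and version‑specific:

```python
import re

# Illustrative (hypothetical) HF -> Megatron key patterns; the real converter
# in PAI-Megatron-Patch maintains much larger, version-specific tables.
KEY_MAP = [
    (r"model\.layers\.(\d+)\.self_attn\.q_proj\.weight",
     r"encoder.layers.\1.self_attention.query.weight"),
    (r"model\.layers\.(\d+)\.mlp\.gate_proj\.weight",
     r"encoder.layers.\1.mlp.dense_h_to_4h_1.weight"),
    (r"model\.embed_tokens\.weight",
     r"embedding.word_embeddings.weight"),
]

def hf_to_megatron(key: str) -> str:
    """Translate one HuggingFace parameter name to a Megatron-style name."""
    for pattern, repl in KEY_MAP:
        new_key, n = re.subn(f"^{pattern}$", repl, key)
        if n:
            return new_key
    return key  # pass through keys that need no renaming

print(hf_to_megatron("model.layers.3.mlp.gate_proj.weight"))
```

A real converter also reshapes and splits tensors for the target parallelism layout; renaming keys is only the first step.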
Technical Architecture
Pai‑Megatron‑Patch adds functionality via external patches rather than invasive changes to Megatron‑LM. The patch builds the LLM training pipeline by depending on Megatron‑LM core libraries, ensuring that future Megatron upgrades do not break the user experience.
The toolset includes model libraries, tokenizers, model converters, RL modules, offline text generation utilities, and example scripts. Model libraries provide Megatron implementations for popular LLMs, and the converter maps HuggingFace checkpoints to an internal format before exporting to Megatron.
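The non‑intrusive approach can be illustrated with plain Python: instead of editing a dependency's source, a patch rebinds attributes on the imported module at startup, so upstream upgrades stay drop‑in. The module and function names below are invented for illustration; Pai‑Megatron‑Patch applies the same idea to Megatron‑LM entry points:

```python
import types

# Stand-in for an upstream library we must not modify (think Megatron-LM).
upstream = types.ModuleType("upstream")
def build_tokenizer(name):
    return f"upstream tokenizer: {name}"
upstream.build_tokenizer = build_tokenizer

# The "patch": wrap the original function with extra behavior, then rebind it.
_original = upstream.build_tokenizer
def patched_build_tokenizer(name):
    # Add support for a model the upstream library doesn't know about.
    if name == "llama2":
        return "patched tokenizer: llama2"
    return _original(name)  # everything else falls through unchanged
upstream.build_tokenizer = patched_build_tokenizer

print(upstream.build_tokenizer("llama2"))  # takes the patched path
print(upstream.build_tokenizer("gpt2"))    # original behavior, untouched
```

Because the upstream module's files are never edited, pulling a newer Megatron‑LM only requires re‑checking that the patched entry points still exist.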
Model Weight Conversion Example (LLaMA‑2)
cd /mnt/workspace/
mkdir llama2-ckpts
cd llama2-ckpts
wget https://atp-modelzoo-wlcb-pai.oss-cn-wulanchabu.aliyuncs.com/release/models/pai-megatron-patch/llama2-ckpts/Llama-2-7b-hf.tgz
tar -zxf Llama-2-7b-hf.tgz
mv Llama-2-7b-hf llama2-7b-hf
cd /mnt/workspace/PAI-Megatron-Patch/toolkits/model_checkpoints_convertor/llama
sh model_convertor.sh \
/root/Megatron-LM-23.04 \
/mnt/workspace/llama2-ckpts/llama2-7b-hf \
/mnt/workspace/llama2-ckpts/llama2-7b-hf-to-megatron-tp1-pp1 \
1 \
1 \
llama-7b \
0 \
false

When converting weights for tensor parallelism (TP > 1), the MLP gate and up projections require special handling:
for i in range(tp_size):
    params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.encoder")
    # Concatenate the gate (dense_h_to_4h_1) and up (dense_h_to_4h_2) projections
    # into Megatron's fused dense_h_to_4h weight for this tensor-parallel rank.
    dense_h_to_4h_1_name = 'mlp.dense_h_to_4h_1.weight'
    dense_h_to_4h_1_layer_name = f"layers.{layer}.{dense_h_to_4h_1_name}"
    dense_h_to_4h_1_weight = params_dict[dense_h_to_4h_1_layer_name]
    dense_h_to_4h_2_name = 'mlp.dense_h_to_4h_2.weight'
    dense_h_to_4h_2_layer_name = f"layers.{layer}.{dense_h_to_4h_2_name}"
    dense_h_to_4h_2_weight = params_dict[dense_h_to_4h_2_layer_name]
    dense_h_to_4h_name = 'mlp.dense_h_to_4h.weight'
    dense_h_to_4h_layer_name = f"layers.{layer}.{dense_h_to_4h_name}"
    params_dict[dense_h_to_4h_layer_name] = torch.cat([dense_h_to_4h_1_weight, dense_h_to_4h_2_weight], dim=0)

FP8 Training with Transformer Engine
Enabling FP8 on H100 GPUs is as simple as adding the following flags to the training script:
if [ $PR = fp8 ]; then
    pr_options=" \
        --bf16 \
        --fp8-hybrid \
        --fp8-amax-compute-algo max \
        --fp8-amax-history-len 1024 \
        --transformer-impl transformer_engine"
fi

Loss curves for LLaMA‑7B and LLaMA‑2‑70B trained with and without FP8 overlap closely, confirming that convergence is preserved.
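The --fp8-amax-compute-algo max and --fp8-amax-history-len flags configure Transformer Engine's delayed‑scaling recipe: each tensor's FP8 scale is derived from the maximum absolute value (amax) observed over a sliding history window. The pure‑Python sketch below shows that bookkeeping in simplified form; real FP8 scaling also applies a margin and depends on the chosen FP8 format:

```python
from collections import deque

FP8_E4M3_MAX = 448.0  # largest representable magnitude in the E4M3 format

class DelayedScaling:
    """Toy model of the amax-history scaling used for FP8 training."""
    def __init__(self, history_len=1024):
        # --fp8-amax-history-len bounds how many past steps are remembered.
        self.amax_history = deque(maxlen=history_len)

    def update(self, tensor_abs_max):
        """Record the amax observed at this training step."""
        self.amax_history.append(tensor_abs_max)

    def scale(self):
        """The 'max' amax-compute-algo: map the historical maximum
        onto the edge of the FP8 representable range."""
        amax = max(self.amax_history)
        return FP8_E4M3_MAX / amax

scaler = DelayedScaling(history_len=4)
for amax in [1.0, 3.5, 2.0]:
    scaler.update(amax)
print(scaler.scale())  # 448.0 / 3.5 = 128.0
```

Using the windowed maximum rather than the latest value makes the scale robust to transient spikes, which is part of why convergence matches the higher-precision baseline.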
End‑to‑End Training Workflow
1. Model Format Conversion – Convert HuggingFace checkpoints to Megatron format (see code above).
2. Pre‑training – Example command:
export WORK_DIR=/mnt/workspace
cd ${WORK_DIR}/PAI-Megatron-Patch/examples/llama2
bash run_pretrain_megatron_llama.sh \
dlc \
/root/Megatron-LM-23.04 \
${WORK_DIR}/PAI-Megatron-Patch \
7B 1 16 1e-5 1e-6 2048 80 0 fp16 1 1 sel true false false 100000 \
${WORK_DIR}/llama2-datasets/wudao/wudao_llamabpe_text_document \
${WORK_DIR}/llama2-ckpts/llama2-7b-hf-to-megatron-tp1-pp1 \
100000000 10000 ${WORK_DIR}/output_megatron_llama2/
3. Supervised Fine‑tuning – Run run_finetune_megatron_llama.sh with the appropriate arguments (environment, model size, batch size, learning rate, etc.).
4. Offline Inference – Use the Megatron inference script, specifying precision, tensor‑parallelism, and token limits.
5. Reinforcement Learning (RLHF) – Convert Megatron checkpoints back to HuggingFace if needed, then train a reward model and PPO using either DeepSpeed‑Chat or trlx frameworks. Example for DeepSpeed‑Chat reward‑model training:
cd PAI-Megatron-Patch/rlhf/deepspeed-chat
git clone https://github.com/microsoft/DeepSpeedExamples.git
cp -f rm_main.py DeepSpeedExamples/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py
pip install -r DeepSpeedExamples/applications/DeepSpeed-Chat/requirements.txt
cd DeepSpeedExamples/applications/DeepSpeed-Chat/
bash training_scripts/llama2/run_llama2_7b.sh
Similar scripts are provided for BLOOM and GPT‑J models using the trlx library.
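The reward‑model step in this pipeline is typically trained on preference pairs with a pairwise ranking objective of the form −log σ(r_chosen − r_rejected). A self‑contained sketch of that loss, with scalar rewards standing in for actual model outputs:

```python
import math

def pairwise_rm_loss(r_chosen: float, r_rejected: float) -> float:
    """Pairwise ranking loss commonly used for reward-model training:
    -log(sigmoid(r_chosen - r_rejected)).
    Small when the chosen response outscores the rejected one."""
    margin = r_chosen - r_rejected
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

# The loss shrinks as the reward model ranks the preferred answer higher.
print(pairwise_rm_loss(2.0, 0.0))  # small: chosen clearly preferred
print(pairwise_rm_loss(0.0, 2.0))  # large: ranking is inverted
```

The trained reward model then scores rollouts during the PPO stage, where the policy is updated to maximize reward under a KL constraint to the reference model.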
Open‑Source Ecosystem and Future Work
Pai‑Megatron‑Patch aims to provide lossless conversion between HuggingFace and Megatron/Transformer Engine formats, FP8 training on H800 clusters, and best‑practice pipelines for LLMs and RLHF on the PAI‑Lingjun platform. Future plans include expanding model coverage, adding LoRA support for Megatron, and further enhancements to the Transformer Engine integration. The community is invited to join the discussion via the DingTalk group (ID 29605038042).