FP8 训练实战:H100 Transformer Engine 性能调优指南

为什么需要 FP8?

训练大模型的成本主要是显存和算力。FP8(8位浮点数)可以:

  • 显存占用降低 40-50%
  • 训练速度提升 1.5-2 倍
  • 成本降低 30-40%

FP8 格式详解

格式指数位尾数位动态范围精度
E5M2521.4e-4 ~ 5.7e4
E4M3439.5e-5 ~ 4.4e3

推荐:前向用 E4M3(精度高),反向用 E5M2(范围大)

实战配置

1. 环境准备

pip install transformer-engine[pytorch]
nvidia-smi  # 需要 H100/H800/A100

2. 代码示例

import torch
import transformer_engine.pytorch as te

# 初始化 FP8 层
linear = te.Linear(4096, 4096)

# 自动处理 FP8 计算
output = linear(input)

3. 关键参数

fp8_recipe = te.recipe.DelayedScaling(
    fp8_format=te.recipe.Format.E4M3,
    amax_history_len=1024,
    amax_compute_algo=max
)

踩坑记录

问题 1:Loss 发散

原因:梯度下溢,缩放因子太小

解决

fp8_recipe = te.recipe.DelayedScaling(
    fp8_format=te.recipe.Format.E4M3,
    override_linear_precision={'wgrad': False}
)

问题 2:精度下降

解决:关键层保持 BF16

# Embedding 和输出层用 BF16
x = embedding(input_ids)  # BF16
with te.fp8_autocast(enabled=True):
    x = fp8_layer(x)      # FP8
x = output_head(x)        # BF16

性能对比

配置显存占用训练速度精度损失
BF16100%1.0x0%
FP8 (E4M3)58%1.8x< 0.5%
FP8 (E5M2)55%1.9x< 1%

最佳实践

  1. 渐进式迁移:先用 BF16 训练稳定,再切换到 FP8 微调
  2. 关键层保护:Embedding 和 Output 层保持 BF16
  3. 监控指标:缩放因子、梯度范数、验证集精度

总结

FP8 训练是降低成本的有效手段,但需要谨慎配置。

不适合:第一次训模型、对精度极度敏感的任务。