Llama3从零实现指南通过逐行代码解析,展示了如何基于Meta开源的Llama3模型权重实现完整的推理流程。
1、模型权重加载与配置解析
从Meta官方下载的模型文件(如consolidated.00.pth
和params.json
)中加载张量,解析关键参数:
model = torch.load("Meta-Llama-3-8B/consolidated.00.pth")
config = json.load(open("Meta-Llama-3-8B/params.json"))
dim = config["dim"] # 4096
n_heads = config["n_heads"] # 32
模型包含32层Transformer,每层含32个注意力头。
2、分词器实现
采用Andrej Karpathy的minbpe
库进行字节对编码(BPE),处理特殊标记如<|begin_of_text|>
,支持文本与Token的双向转换。
tokenizer = tiktoken.Encoding(
name=Path(tokenizer_path).name,
pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
mergeable_ranks=mergeable_ranks
)
3、嵌入与归一化
Token通过嵌入层转换为4096维向量,使用均方根归一化(RMSNorm)优化数值稳定性:
def rms_norm(tensor, norm_weights):
return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights
归一化后形状保持为[17, 4096]
(17为Token数量)。
1、多头注意力拆解
加载查询(Query)、键(Key)、值(Value)权重矩阵,解包为32个独立头:
q_layer0 = model["layers.0.attention.wq.weight"]
q_layer0 = q_layer0.view(n_heads, head_dim, dim) # [32, 128, 4096]
单个头的Query计算:
q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T) # [17, 128]
2、旋转位置编码(RoPE)
将128维Query向量拆分为64对,每对应用复数旋转操作,引入位置信息:
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2) # [17, 64, 2]
通过旋转角度偏移解决“相同词不同位置”的语义区分问题。
3、评分矩阵与掩码处理
计算Query-Key相关性评分矩阵(形状[17, 17]
),掩码未来Token以防止信息泄漏:
scores = torch.matmul(q_per_token, k_per_token.T) / (head_dim ** 0.5)
scores = scores.masked_fill(mask == 0, float("-inf"))
1、环境部署
pip install --pre mlscraper # 安装预发布版本
# 或从源码安装
pip install git+https://github.com/naklecha/llama3-from-scratch
2、推理验证示例
prompt = "the answer to the ultimate question of life, the universe, and everything is"
tokens = [128000] + tokenizer.encode(prompt)
embeddings = embedding_layer(torch.tensor(tokens))
# 逐层处理Transformer
final_embeddings = process_transformer_layers(embeddings)
predicted_token = decode(final_embeddings[-1]) # 预期输出为42(《银河系漫游指南》彩蛋)
适用场景
教育领域:理解Transformer层间交互与参数作用机制。
研究验证:快速测试模型修改对推理结果的影响。
局限性
仅支持推理流程,未包含训练代码。
动态渲染内容(如JavaScript生成页面)需额外适配。