!1235 Implement merging of Lora weights into HuggingFace weights

Merge pull request !1235 from guoyiwei111/master
This commit is contained in:
guoyiwei111 2024-05-31 07:51:23 +00:00 committed by i-robot
parent 6a335fa04b
commit deb8b2ebce
7 changed files with 147 additions and 19 deletions

View File

@@ -170,6 +170,7 @@ The hardware configuration for LLaMA-7B/13B training is as follows:
--save-dir ./model_from_hf/llama-13b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-13b-hf/mg2hg/
```
To also merge Lora fine-tuned weights into the HuggingFace weights, add the --lora-dir {lora fine-tuned weight path} parameter for the conversion. \
Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks.
5. Pre-training
@@ -605,7 +606,7 @@ Hardware configuration for LLaMA-33B/65B training:
--target-pipeline-parallel-size 1 \
--save-dir ./model_from_hf/llama-65b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-65b-hf/mg2hg/
```
To also merge Lora fine-tuned weights into the HuggingFace weights, add the --lora-dir {lora fine-tuned weight path} parameter for the conversion. \
Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks.
5. Pre-training

View File

@@ -165,6 +165,7 @@ Here's a hardware summary of pre-training LLaMA-7B/13B:
--save-dir ./model_from_hf/llama-13b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-13b-hf/mg2hg/
```
If you need to combine the Lora weights into the HuggingFace weights, please add the --lora-dir {lora weight path} parameter. \
Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks.
5. Pretrain
@@ -588,6 +589,7 @@ The model was trained using alpaca datasets.
--save-dir ./model_from_hf/llama-65b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-65b-hf/mg2hg/
```
If you need to combine the Lora weights into the HuggingFace weights, please add the --lora-dir {lora weight path} parameter. \
Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks.
5. Pretrain

View File

@@ -138,7 +138,7 @@ Hardware configuration for LLAMA2-7B training:
--target-pipeline-parallel-size 1 \
--save-dir ./model_from_hf/llama-2-7b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-2-7b-hf/mg2hg/
```
To also merge Lora fine-tuned weights into the HuggingFace weights, add the --lora-dir {lora fine-tuned weight path} parameter for the conversion. \
Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks.
5. Pre-training
@@ -871,7 +871,7 @@ Hardware configuration for LLaMA2-34B/70B training:
--target-pipeline-parallel-size 1 \
--save-dir ./model_from_hf/llama-2-34b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-2-34b-hf/mg2hg/
```
To also merge Lora fine-tuned weights into the HuggingFace weights, add the --lora-dir {lora fine-tuned weight path} parameter for the conversion. \
Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks.
5. Pre-training

View File

@@ -151,6 +151,7 @@ Here's a hardware summary of pre-training LLAMA2-7B:
--save-dir ./model_from_hf/llama-2-7b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-2-7b-hf/mg2hg/
```
If you need to combine the Lora weights into the HuggingFace weights, please add the --lora-dir {lora weight path} parameter. \
Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks.
5. pre-training
@@ -886,6 +887,7 @@ Here's a hardware summary of pre-training LLaMA2-34B/70B:
--save-dir ./model_from_hf/llama-2-34b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-2-34b-hf/mg2hg/
```
If you need to combine the Lora weights into the HuggingFace weights, please add the --lora-dir {lora weight path} parameter. \
Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks.
5. pre-training

View File

@@ -0,0 +1,76 @@
import unittest
import sys
import os
import subprocess
from pathlib import Path
import numpy as np
import torch
from ascendspeed import megatron_adaptor
from transformers import AutoModelForCausalLM
from utils import judge_expression
import modellink
from tests.pipeline.common import DistributedTest
class ConvertCkptToHuggingfaceArgs:
model_type = "GPT"
loader = "megatron"
saver = "megatron"
save_dir = "/home/dataset/ci_engineering/llama-2-7b-hf"
lora_dir = "/home/dataset/ci_engineering/llama2-7B-lora-ckpt"
load_dir = "/home/dataset/ci_engineering/llama2-7B-tp8-pp1"
class TestConvertCkptFromHuggingface(DistributedTest):
def test_combine_lora_weights_to_huggingface(self):
"""
Test that the weights are converted as expected when `--lora-dir` is given: the combined weight in the
huggingface checkpoint should equal loraB @ loraA * rate + base from the megatron checkpoints.
"""
args = ConvertCkptToHuggingfaceArgs()
rate = 2
hidden_layer = 4096
num_head = 32
tp = 8
dk = 128
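# llama2-7B attention shapes: hidden size 4096, 32 attention heads, head dim 128; tp matches the TP degree
# of the megatron checkpoint (llama2-7B-tp8-pp1); rate is assumed to equal lora_alpha / lora_r of the lora checkpoint.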
base_dir = Path(__file__).absolute().parent.parent.parent
file_path = os.path.join(base_dir, "tools/checkpoint/convert_ckpt.py")
arguments = [
"--model-type", args.model_type,
"--loader", args.loader,
"--saver", args.saver,
"--save-model-type", "save_huggingface_llama",
"--load-dir", args.load_dir,
"--lora-dir", args.lora_dir,
"--target-tensor-parallel-size", "1",
"--target-pipeline-parallel-size", "1",
"--save-dir", args.save_dir
]
subprocess.run(["python3", file_path] + arguments)
output_dir = os.path.join(args.save_dir, "mg2hg")
model = AutoModelForCausalLM.from_pretrained(output_dir, trust_remote_code=True, low_cpu_mem_usage=True)
q_hf = model.state_dict()["model.layers.0.self_attn.q_proj.weight"]
judge_expression(q_hf.size() == torch.Size([4096, 4096]))
base_dir = os.path.join(args.load_dir, "iter_0000001")
weight_base = torch.load(os.path.join(base_dir, "mp_rank_00/model_optim_rng.pt"))
weight_base_content = weight_base['model']['language_model']['encoder'] # extract common content
base_qkv = weight_base_content['layers.0.self_attention.query_key_value.weight']
lora_dir = os.path.join(args.lora_dir, "iter_0000010")
weight_lora = torch.load(os.path.join(lora_dir, "mp_rank_00/model_optim_rng.pt"))
weight_lora_content = weight_lora['model']['language_model']['encoder'] # extract common content
loraB_qkv = weight_lora_content['layers.0.self_attention.query_key_value.lora_B.default.weight']
loraA_qkv = weight_lora_content['layers.0.self_attention.query_key_value.lora_A.default.weight']
res_qkv = loraB_qkv.cpu().float() @ loraA_qkv.cpu().float() * rate + base_qkv
gp1_q_mg = res_qkv.reshape(num_head // tp, 3, dk, hidden_layer)[:1, :1, :, :].reshape(dk, hidden_layer)
gp1_q_hf = q_hf.reshape(num_head, dk, hidden_layer)[:1, :, :].reshape(dk, hidden_layer)
judge_expression(np.allclose(gp1_q_mg.cpu(), gp1_q_hf.cpu(), rtol=0.001, atol=0.001))
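
As a reference for readers, here is a minimal standalone sketch of the merge rule this test verifies. The shapes and the lora_alpha / lora_r values below are illustrative placeholders rather than the actual checkpoint contents:

```python
import torch

# Illustrative shapes only: a 16x16 base weight with a rank-4 LoRA update.
hidden, rank = 16, 4
lora_alpha, lora_r = 8, 4                   # placeholder values; rate = lora_alpha / lora_r
rate = lora_alpha / lora_r

base = torch.randn(hidden, hidden, dtype=torch.float16)   # e.g. query_key_value.weight
lora_A = torch.randn(rank, hidden, dtype=torch.float16)   # ...lora_A.default.weight
lora_B = torch.randn(hidden, rank, dtype=torch.float16)   # ...lora_B.default.weight

# Same rule as the converter: accumulate in fp32, then cast back to fp16.
merged = (lora_B.float() @ lora_A.float() * rate + base.float()).to(torch.float16)
assert merged.shape == base.shape
```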

View File

@@ -63,6 +63,8 @@ def main():
help='Module name to save checkpoint, should be on python path')
parser.add_argument('--load-dir', type=str, required=True,
help='Directory to load model checkpoint from')
parser.add_argument('--lora-dir', type=str,
help='Directory to load the lora model checkpoint from')
parser.add_argument('--save-dir', type=str, required=True,
help='Directory to save model checkpoint to')
parser.add_argument('--max-queue-size', type=int, default=50,

View File

@@ -46,7 +46,7 @@ def _load_checkpoint(queue, args):
from modellink.utils import parse_args
from megatron.arguments import validate_args
from megatron.global_vars import set_args, set_global_variables
from megatron.checkpointing import load_args_from_checkpoint
from megatron.checkpointing import load_args_from_checkpoint, _load_base_checkpoint
from megatron.checkpointing import load_checkpoint as load_checkpoint_mg
from megatron.model import module
from megatron.core import mpu
@@ -236,6 +236,34 @@ def _load_checkpoint(queue, args):
# Get first pipe stage
mpu.set_pipeline_model_parallel_rank(0)
all_models = [get_models(tp_size, md.params_dtype)]
lora_target_modules = []
lora_rate = None
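# When a lora checkpoint is supplied, re-parse its saved args to recover which modules lora was applied to
# and the scaling factor lora_alpha / lora_r used when merging the lora deltas into the base weights.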
if args.lora_dir is not None:
sys.argv = ['script.py',
'--no-masked-softmax-fusion',
'--no-bias-gelu-fusion',
'--no-bias-dropout-fusion',
'--no-async-tensor-model-parallel-allreduce',
'--use-cpu-initialization',
'--micro-batch-size', '1',
'--no-load-optim',
'--no-load-rng',
'--no-save-optim',
'--no-save-rng',
'--no-initialization',
'--load', args.lora_dir
]
margs_lora = parse_args()
margs_lora, checkpoint_args_lora = load_args_from_checkpoint(margs_lora)
lora_target_modules = checkpoint_args_lora.lora_target_modules
if checkpoint_args_lora.lora_r is not None and checkpoint_args_lora.lora_r != 0:
lora_rate = checkpoint_args_lora.lora_alpha / checkpoint_args_lora.lora_r
else:
raise ZeroDivisionError("lora_r in the lora checkpoint must be non-zero to compute lora_alpha / lora_r")
print("lora_rate ", lora_rate)
print("lora_target_modules ", lora_target_modules)
models = all_models[0][0]
md.consumed_train_samples = consumed_train_samples
@@ -261,6 +289,7 @@ def _load_checkpoint(queue, args):
queue_put("embeddings", message)
total_layer_num = 0
state_dict_lora_list = []
for vp_rank in range(vp_size):
mpu.set_virtual_pipeline_model_parallel_rank(vp_rank)
for pp_rank in range(pp_size):
@@ -287,38 +316,54 @@ def _load_checkpoint(queue, args):
message["dense bias"] = layer.self_attention.dense.bias.data
# Grab all parallel tensors for this layer
qkv_weight = []
qkv_bias = []
dense_weight = []
mlp_l0_weight = []
mlp_l0_bias = []
mlp_l1_weight = []
support_target = ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']
weight_support_lora_list = [[], [], [], []]
support_target_key = ['self_attention.query_key_value', 'self_attention.dense', 'mlp.dense_h_to_4h', 'mlp.dense_4h_to_h']
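# For each lora-capable projection in support_target, collect one weight per TP rank: the base weight merged
# with its lora delta when that module was a lora target, otherwise the plain base weight.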
for tp_rank, model in enumerate(models):
if len(state_dict_lora_list) <= tp_rank:
# lazily load the lora checkpoint for this TP rank and cache its encoder state dict for reuse across layers
if args.lora_dir is not None:
mpu.set_tensor_model_parallel_rank(tp_rank)
state_dict_lora, _, _ = _load_base_checkpoint(args.lora_dir, rank0=False)
state_dict_lora_list.append(state_dict_lora["model"]["language_model"]["encoder"])
layer = model.language_model.encoder.layers[layer_num]
qkv_weight.append(layer.self_attention.query_key_value.weight.data)
dense_weight.append(layer.self_attention.dense.weight.data)
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
support_target_base = [layer.self_attention.query_key_value.weight.data, layer.self_attention.dense.weight.data,
layer.mlp.dense_h_to_4h.weight.data, layer.mlp.dense_4h_to_h.weight.data]
if md.linear_bias:
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
if args.add_qkv_bias:
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
for ind, item in enumerate(support_target):
if item in lora_target_modules:
loraB = state_dict_lora_list[tp_rank]["layers." + str(layer_num) + "." + support_target_key[ind] + ".lora_B.default.weight"]
loraA = state_dict_lora_list[tp_rank]["layers." + str(layer_num) + "." + support_target_key[ind] + ".lora_A.default.weight"]
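# Merge rule: W_merged = W_base + (lora_alpha / lora_r) * lora_B @ lora_A, accumulated in fp32 and cast back to fp16.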
tmp = loraB.float() @ loraA.float() * lora_rate + support_target_base[ind]
weight_support_lora_list[ind].append(tmp.to(dtype=torch.float16))
else:
weight_support_lora_list[ind].append(support_target_base[ind])
# Handle gated linear units
if md.swiglu:
# concat all the first halves ('W's) and all the second halves ('V's)
for tp_rank in range(tp_size):
mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0)
message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0)
message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0)
weight_support_lora_list[2][tp_rank] = torch.chunk(weight_support_lora_list[2][tp_rank], 2, dim=0)
message["mlp l0 weight W"] = torch.cat([w[0] for w in weight_support_lora_list[2]], dim=0)
message["mlp l0 weight V"] = torch.cat([w[1] for w in weight_support_lora_list[2]], dim=0)
else:
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
message["mlp l0 weight"] = torch.cat(weight_support_lora_list[2], dim=0)
# simple concat of the rest
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["dense weight"] = torch.cat(dense_weight, dim=1)
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
message["qkv weight"] = torch.cat(weight_support_lora_list[0], dim=0)
message["dense weight"] = torch.cat(weight_support_lora_list[1], dim=1)
message["mlp l1 weight"] = torch.cat(weight_support_lora_list[3], dim=1)
if md.linear_bias:
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
if md.swiglu: