From deb8b2ebceb704606b2f71ddc0bff295d450a5be Mon Sep 17 00:00:00 2001 From: guoyiwei111 Date: Fri, 31 May 2024 07:51:23 +0000 Subject: [PATCH] =?UTF-8?q?!1235=20=E5=AE=9E=E7=8E=B0Lora=E6=9D=83?= =?UTF-8?q?=E9=87=8D=E5=90=88=E5=85=A5HuggingFace=E6=9D=83=E9=87=8D=20Merg?= =?UTF-8?q?e=20pull=20request=20!1235=20from=20guoyiwei111/master?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/llama/README.md | 3 +- examples/llama/README_en.md | 2 + examples/llama2/README.md | 4 +- examples/llama2/README_en.md | 2 + tests/ut/test_convert_ckpt_to_huggingface.py | 76 +++++++++++++++++++ tools/checkpoint/convert_ckpt.py | 2 + tools/checkpoint/loader_megatron.py | 77 ++++++++++++++++---- 7 files changed, 147 insertions(+), 19 deletions(-) create mode 100644 tests/ut/test_convert_ckpt_to_huggingface.py diff --git a/examples/llama/README.md b/examples/llama/README.md index 85e38ff7..497ad258 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -170,6 +170,7 @@ LLaMA-7B/13B 训练的硬件配置如下: --save-dir ./model_from_hf/llama-13b-hf/ # <-- 需要填入原始HF模型路径,新权重会存于./model_from_hf/llama-13b-hf/mg2hg/ ``` + 若需将Lora微调权重一并合并到HuggingFace权重,需添加 --lora-dir {lora微调权重路径} 参数进行转换。\ 权重转换适用于预训练、微调、推理和评估,根据任务不同调整参数`target-tensor-parallel-size`和`target-pipeline-parallel-size`。 5. 预训练 @@ -605,7 +606,7 @@ LLaMA-33B/65B 训练的硬件配置: --target-pipeline-parallel-size 1 \ --save-dir ./model_from_hf/llama-65b-hf/ # <-- 需要填入原始HF模型路径,新权重会存于./model_from_hf/llama-65b-hf/mg2hg/ ``` - + 若需将Lora微调权重一并合并到HuggingFace权重,需添加 --lora-dir {lora微调权重路径} 参数进行转换。\ 权重转换适用于预训练、微调、推理和评估,根据任务不同调整参数`target-tensor-parallel-size`和`target-pipeline-parallel-size`。 5. 预训练 diff --git a/examples/llama/README_en.md b/examples/llama/README_en.md index 3d9f3433..c4134a76 100644 --- a/examples/llama/README_en.md +++ b/examples/llama/README_en.md @@ -165,6 +165,7 @@ Here's a hardware summary of pre-training LLaMA-7B/13B: --save-dir ./model_from_hf/llama-13b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-13b-hf/mg2hg/ ``` + If you need combine Lora weight to huggingface weight, please add --lora-dir {lora weight path} \ Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks. 5. Pretrain @@ -588,6 +589,7 @@ The model was trained using alpaca datasets. --save-dir ./model_from_hf/llama-65b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-65b-hf/mg2hg/ ``` + If you need combine Lora weight to huggingface weight, please add --lora-dir {lora weight path} \ Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks. 5. Pretrain diff --git a/examples/llama2/README.md b/examples/llama2/README.md index 1074568b..a6073f25 100755 --- a/examples/llama2/README.md +++ b/examples/llama2/README.md @@ -138,7 +138,7 @@ LLAMA2-7B 训练的硬件配置: --target-pipeline-parallel-size 1 \ --save-dir ./model_from_hf/llama-2-7b-hf/ # <-- 需要填入原始HF模型路径,新权重会存于./model_from_hf/llama-2-7b-hf/mg2hg/ ``` - + 若需将Lora微调权重一并合并到HuggingFace权重,需添加 --lora-dir {lora微调权重路径} 参数进行转换。\ 权重转换适用于预训练、微调、推理和评估,根据任务不同调整参数 `target-tensor-parallel-size`和 `target-pipeline-parallel-size`。 5. 预训练 @@ -871,7 +871,7 @@ LLaMA2-34B/70B 训练的硬件配置: --target-pipeline-parallel-size 1 \ --save-dir ./model_from_hf/llama-2-34b-hf/ # <-- 需要填入原始HF模型路径,新权重会存于./model_from_hf/llama-2-34b-hf/mg2hg/ ``` - + 若需将Lora微调权重一并合并到HuggingFace权重,需添加 --lora-dir {lora微调权重路径} 参数进行转换。\ 权重转换适用于预训练、微调、推理和评估,根据任务不同调整参数`target-tensor-parallel-size`和 `target-pipeline-parallel-size`。 5. 预训练 diff --git a/examples/llama2/README_en.md b/examples/llama2/README_en.md index 85c6708c..c9078508 100644 --- a/examples/llama2/README_en.md +++ b/examples/llama2/README_en.md @@ -151,6 +151,7 @@ Here's a hardware summary of pre-training LLAMA2-7B: --save-dir ./model_from_hf/llama-2-7b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-2-7b-hf/mg2hg/ ``` + If you need combine Lora weight to huggingface weight, please add --lora-dir {lora weight path} \ Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks. 5. pre-training @@ -886,6 +887,7 @@ Here's a hardware summary of pre-training LLaMA2-34B/70B: --save-dir ./model_from_hf/llama-2-34b-hf/ # <-- Fill in the original HF model path here, new weights will be saved in ./model_from_hf/llama-2-34b-hf/mg2hg/ ``` + If you need combine Lora weight to huggingface weight, please add --lora-dir {lora weight path} \ Weight conversion is suitable for pre-training, fine-tuning, inference and evaluation. Adjust the parameters `target-tensor-parallel-size` and `target-pipeline-parallel-size` according to different tasks. 5. pre-training diff --git a/tests/ut/test_convert_ckpt_to_huggingface.py b/tests/ut/test_convert_ckpt_to_huggingface.py new file mode 100644 index 00000000..e103fd2a --- /dev/null +++ b/tests/ut/test_convert_ckpt_to_huggingface.py @@ -0,0 +1,76 @@ +import unittest +import sys +import os +import subprocess +from pathlib import Path +import numpy as np +import torch +from ascendspeed import megatron_adaptor +from transformers import AutoModelForCausalLM +from utils import judge_expression +import modellink +from tests.pipeline.common import DistributedTest + + +class CovertCkptToHuggingfaceArgs: + model_type = "GPT" + loader = "megatron" + saver = "megatron" + save_dir = "/home/dataset/ci_engineering/llama-2-7b-hf" + lora_dir = "/home/dataset/ci_engineering/llama2-7B-lora-ckpt" + load_dir = "/home/dataset/ci_engineering/llama2-7B-tp8-pp1" + + +class TestConvertCkptFromHuggingface(DistributedTest): + + def test_combine_lora_weights_to_huggingface(self): + """ + Test whether the weight to be converted as we want in `--lora-dir`. We will check the combine weight + in huggingface equals loraB @ loraA * rate + base in megatron. + """ + args = CovertCkptToHuggingfaceArgs() + rate = 2 + hidden_layer = 4096 + num_head = 32 + tp = 8 + dk = 128 + + base_dir = Path(__file__).absolute().parent.parent.parent + file_path = os.path.join(base_dir, "tools/checkpoint/convert_ckpt.py") + arguments = [ + "--model-type", args.model_type, + "--loader", args.loader, + "--saver", args.saver, + "--save-model-type", "save_huggingface_llama", + "--load-dir", args.load_dir, + "--lora-dir", args.lora_dir, + "--target-tensor-parallel-size", "1", + "--target-pipeline-parallel-size", "1", + "--save-dir", args.save_dir + ] + + subprocess.run(["python3", file_path] + arguments) + + output_dir = os.path.join(args.save_dir, "mg2hg") + + model = AutoModelForCausalLM.from_pretrained(output_dir, trust_remote_code=True, low_cpu_mem_usage=True) + q_hf = model.state_dict()["model.layers.0.self_attn.q_proj.weight"] + + judge_expression(q_hf.size() == torch.Size([4096, 4096])) + + base_dir = os.path.join(args.load_dir, "iter_0000001") + weight_base = torch.load(os.path.join(base_dir, "mp_rank_00/model_optim_rng.pt")) + weight_base_content = weight_base['model']['language_model']['encoder'] # extract commmon content + base_qkv = weight_base_content['layers.0.self_attention.query_key_value.weight'] + + lora_dir = os.path.join(args.lora_dir, "iter_0000010") + weight_lora = torch.load(os.path.join(lora_dir, "mp_rank_00/model_optim_rng.pt")) + weight_lora_content = weight_lora['model']['language_model']['encoder'] # extract commmon content + loraB_qkv = weight_lora_content['layers.0.self_attention.query_key_value.lora_B.default.weight'] + loraA_qkv = weight_lora_content['layers.0.self_attention.query_key_value.lora_A.default.weight'] + + res_qkv = loraB_qkv.cpu().float() @ loraA_qkv.cpu().float() * rate + base_qkv + + gp1_q_mg = res_qkv.reshape(num_head // tp, 3, dk, hidden_layer)[:1, :1, :, :].reshape(dk, hidden_layer) + gp1_q_hf = q_hf.reshape(num_head, dk, hidden_layer)[:1, :, :].reshape(dk, hidden_layer) + judge_expression(np.allclose(gp1_q_mg.cpu(), gp1_q_hf.cpu(), rtol=0.001, atol=0.001)) \ No newline at end of file diff --git a/tools/checkpoint/convert_ckpt.py b/tools/checkpoint/convert_ckpt.py index 7a1086b0..5c716453 100644 --- a/tools/checkpoint/convert_ckpt.py +++ b/tools/checkpoint/convert_ckpt.py @@ -63,6 +63,8 @@ def main(): help='Module name to save checkpoint, shdoul be on python path') parser.add_argument('--load-dir', type=str, required=True, help='Directory to load model checkpoint from') + parser.add_argument('--lora-dir', type=str, + help='Directory to lora model checkpoint from') parser.add_argument('--save-dir', type=str, required=True, help='Directory to save model checkpoint to') parser.add_argument('--max-queue-size', type=int, default=50, diff --git a/tools/checkpoint/loader_megatron.py b/tools/checkpoint/loader_megatron.py index c2400bd0..32c08fc2 100644 --- a/tools/checkpoint/loader_megatron.py +++ b/tools/checkpoint/loader_megatron.py @@ -46,7 +46,7 @@ def _load_checkpoint(queue, args): from modellink.utils import parse_args from megatron.arguments import validate_args from megatron.global_vars import set_args, set_global_variables - from megatron.checkpointing import load_args_from_checkpoint + from megatron.checkpointing import load_args_from_checkpoint, _load_base_checkpoint from megatron.checkpointing import load_checkpoint as load_checkpoint_mg from megatron.model import module from megatron.core import mpu @@ -236,6 +236,34 @@ def _load_checkpoint(queue, args): # Get first pipe stage mpu.set_pipeline_model_parallel_rank(0) all_models = [get_models(tp_size, md.params_dtype)] + lora_target_modules = [] + lora_rate = None + if args.lora_dir is not None: + sys.argv = ['script.py', + '--no-masked-softmax-fusion', + '--no-bias-gelu-fusion', + '--no-bias-dropout-fusion', + '--no-async-tensor-model-parallel-allreduce', + '--use-cpu-initialization', + '--micro-batch-size', '1', + '--no-load-optim', + '--no-load-rng', + '--no-save-optim', + '--no-save-rng', + '--no-initialization', + '--load', args.lora_dir + ] + margs_lora = parse_args() + + margs_lora, checkpoint_args_lora = load_args_from_checkpoint(margs_lora) + lora_target_modules = checkpoint_args_lora.lora_target_modules + if checkpoint_args_lora.lora_r is not None and checkpoint_args_lora.lora_r != 0: + lora_rate = checkpoint_args_lora.lora_alpha / checkpoint_args_lora.lora_r + else: + raise ZeroDivisionError + print("lora_rate ", lora_rate) + print("lora_target_modules ", lora_target_modules) + models = all_models[0][0] md.consumed_train_samples = consumed_train_samples @@ -261,6 +289,7 @@ def _load_checkpoint(queue, args): queue_put("embeddings", message) total_layer_num = 0 + state_dict_lora_list = [] for vp_rank in range(vp_size): mpu.set_virtual_pipeline_model_parallel_rank(vp_rank) for pp_rank in range(pp_size): @@ -287,38 +316,54 @@ def _load_checkpoint(queue, args): message["dense bias"] = layer.self_attention.dense.bias.data # Grab all parallel tensors for this layer - qkv_weight = [] + qkv_bias = [] - dense_weight = [] - mlp_l0_weight = [] mlp_l0_bias = [] - mlp_l1_weight = [] + support_target = ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h'] + weight_support_lora_list = [[], [], [], []] + support_target_key = ['self_attention.query_key_value', 'self_attention.dense', 'mlp.dense_h_to_4h', 'mlp.dense_4h_to_h'] + for tp_rank, model in enumerate(models): + if len(state_dict_lora_list) <= tp_rank: + # lora + if args.lora_dir is not None: + mpu.set_tensor_model_parallel_rank(tp_rank) + state_dict_lora, _, _ = _load_base_checkpoint(args.lora_dir, rank0=False) + state_dict_lora_list.append(state_dict_lora["model"]["language_model"]["encoder"]) + layer = model.language_model.encoder.layers[layer_num] - qkv_weight.append(layer.self_attention.query_key_value.weight.data) - dense_weight.append(layer.self_attention.dense.weight.data) - mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data) - mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data) + support_target_base = [layer.self_attention.query_key_value.weight.data, layer.self_attention.dense.weight.data, + layer.mlp.dense_h_to_4h.weight.data, layer.mlp.dense_4h_to_h.weight.data] if md.linear_bias: qkv_bias.append(layer.self_attention.query_key_value.bias.data) mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data) if args.add_qkv_bias: qkv_bias.append(layer.self_attention.query_key_value.bias.data) + + + for ind, item in enumerate(support_target): + if item in lora_target_modules: + loraB = state_dict_lora_list[tp_rank]["layers." + str(layer_num) + "." + support_target_key[ind] + ".lora_B.default.weight"] + loraA = state_dict_lora_list[tp_rank]["layers." + str(layer_num) + "." + support_target_key[ind] + ".lora_A.default.weight"] + tmp = loraB.float() @ loraA.float() * lora_rate + support_target_base[ind] + weight_support_lora_list[ind].append(tmp.to(dtype=torch.float16)) + else: + weight_support_lora_list[ind].append(support_target_base[ind]) # Handle gated linear units if md.swiglu: # concat all the first halves ('W's) and all the second halves ('V's) for tp_rank in range(tp_size): - mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0) - message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0) - message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0) + weight_support_lora_list[2][tp_rank] = torch.chunk(weight_support_lora_list[2][tp_rank], 2, dim=0) + message["mlp l0 weight W"] = torch.cat([w[0] for w in weight_support_lora_list[2]], dim=0) + message["mlp l0 weight V"] = torch.cat([w[1] for w in weight_support_lora_list[2]], dim=0) else: - message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0) + message["mlp l0 weight"] = torch.cat(weight_support_lora_list[2], dim=0) # simple concat of the rest - message["qkv weight"] = torch.cat(qkv_weight, dim=0) - message["dense weight"] = torch.cat(dense_weight, dim=1) - message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1) + message["qkv weight"] = torch.cat(weight_support_lora_list[0], dim=0) + message["dense weight"] = torch.cat(weight_support_lora_list[1], dim=1) + message["mlp l1 weight"] = torch.cat(weight_support_lora_list[3], dim=1) if md.linear_bias: message["qkv bias"] = torch.cat(qkv_bias, dim=0) if md.swiglu: