mirror of https://gitee.com/ascend/ModelLink.git
synced 2024-11-29 18:38:39 +08:00
!1548 [mcore framework: add mg2mg, mg2hf, EP, and other features]
Merge pull request !1548 from glhyy/master
This commit is contained in:
parent a08bb1cd12
commit c5bd9b2b58
@@ -17,6 +17,8 @@ import os

from functools import wraps
from megatron.training import get_args
from megatron.training.utils import print_rank_0
from megatron.training.checkpointing import _load_base_checkpoint
from .tasks.finetune.lora.utils import is_enable_lora, merge_dicts, modify_keys_with_dict


@@ -54,3 +56,44 @@ def load_checkpoint_wrapper(fn):
        return fn(*args, **kwargs)

    return wrapper


def load_args_from_checkpoint_wrapper(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        res = fn(*args, **kwargs)
        if len(res) == 1:
            return res
        args, checkpoint_args = res

        def _set_arg(arg_name, old_arg_name=None, force=False):
            if not force and getattr(args, arg_name, None) is not None:
                return
            if old_arg_name is not None:
                checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
            else:
                checkpoint_value = getattr(checkpoint_args, arg_name, None)
            if checkpoint_value is not None:
                print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
                setattr(args, arg_name, checkpoint_value)
            else:
                print_rank_0(f"Checkpoint did not provide arguments {arg_name}")

        _set_arg('num_layer_list', force=True)
        _set_arg('post_norm', force=True)
        _set_arg('num_experts')
        _set_arg('sequence_parallel', force=True)

        state_dict, checkpoint_name, release = _load_base_checkpoint(
            getattr(args, kwargs.get('load_arg', 'load')),
            rank0=True,
            exit_on_missing_checkpoint=kwargs.get('exit_on_missing_checkpoint', False),
            checkpoint_step=args.ckpt_step
        )
        checkpoint_version = state_dict.get('checkpoint_version', 0)
        if checkpoint_version >= 3.0:
            _set_arg('expert_model_parallel_size', force=True)

        return args, checkpoint_args

    return wrapper
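A quick illustration of the `_set_arg` semantics added above, with made-up argument values (not part of the diff): `force=True` lets the checkpoint value overwrite whatever the current args hold, while the default form only fills in values that are still None.

from types import SimpleNamespace

args = SimpleNamespace(num_experts=None, sequence_parallel=False)
checkpoint_args = SimpleNamespace(num_experts=8, sequence_parallel=True)

# Mirrors _set_arg('num_experts'): args.num_experts is None, so it is taken from the checkpoint.
if getattr(args, 'num_experts', None) is None:
    args.num_experts = getattr(checkpoint_args, 'num_experts', None)

# Mirrors _set_arg('sequence_parallel', force=True): the checkpoint value wins unconditionally.
args.sequence_parallel = getattr(checkpoint_args, 'sequence_parallel', None)

assert args.num_experts == 8 and args.sequence_parallel is True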
@@ -45,7 +45,7 @@ from ..core.pipeline_parallel.p2p_communication import _batched_p2p_ops
from ..data import build_pretraining_data_loader
from ..tokenizer import build_tokenizer
from ..arguments import parse_args_decorator
from ..checkpointing import _load_base_checkpoint_wrapper, load_checkpoint_wrapper
from ..checkpointing import _load_base_checkpoint_wrapper, load_checkpoint_wrapper, load_args_from_checkpoint_wrapper
from ..initialize import initialize_megatron
from ..utils import emit
from ..arguments import process_args
@@ -290,6 +290,7 @@ def patch_model():
    # patch language model
    PatchManager.register_patch('megatron.legacy.model.language_model.TransformerLanguageModel.forward', transformer_language_model_forward_wrapper)
    PatchManager.register_patch('megatron.legacy.model.language_model.TransformerLanguageModel.__init__', transformer_language_model_init)
    PatchManager.register_patch('megatron.training.checkpointing.load_args_from_checkpoint', load_args_from_checkpoint_wrapper)


def patch_initialize():
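For context, registering this patch is roughly equivalent to the monkey-patch sketched below; the exact PatchManager mechanics are not shown in this diff, so treat the rebinding as an assumption.

# Rough sketch only: assumes PatchManager ultimately rebinds the target attribute
# with the wrapped version of the original function.
import megatron.training.checkpointing as mt_checkpointing
from modellink.checkpointing import load_args_from_checkpoint_wrapper

mt_checkpointing.load_args_from_checkpoint = load_args_from_checkpoint_wrapper(
    mt_checkpointing.load_args_from_checkpoint
)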
@@ -62,6 +62,8 @@ def main():
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')
    parser.add_argument('--model-type-hf', type=str, default="llama2",
                        choices=['llama2', 'mixtral', 'chatglm3'], help='model-type')
    known_args, _ = parser.parse_known_args()
    loader = load_plugin('loader', known_args.loader)
    saver = load_plugin('saver', known_args.saver)
@@ -16,11 +16,15 @@
import os
import sys
import types
import logging as logger

import torch
import transformers
from models import get_megatron_model
from models import get_huggingface_model

logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)


def add_arguments(parser):
    group = parser.add_argument_group(title='Llama-2 HF loader.')
@@ -51,8 +55,6 @@ def add_arguments(parser):
                            'This is added for computational efficiency reasons.')
    group.add_argument('--use-mcore-models', action='store_true',
                       help='Use the implementation from megatron core')
    group.add_argument('--model-type-hf', type=str,
                       help='huggingface model type e.g., llama2, qwen, ...')


def verify_transformers_version():
@@ -87,6 +89,7 @@ def build_metadata(args, margs):
    md.consumed_train_samples = 0
    md.consumed_valid_samples = 0
    md.embed_layernorm = margs.embed_layernorm
    md.disable_bias_linear = margs.disable_bias_linear

    return md

@@ -98,14 +101,14 @@ def get_message_preprocess(model, md):
    }

    # bloom
    if model.has_embedding_word_embeddings_norm():
    if model.has_embedding_word_embeddings_norm_module():
        message["word embeddings norm_w"] = model.get_embedding_word_embeddings_norm_weight()
        message["word embeddings norm_b"] = model.get_embedding_word_embeddings_norm_bias()

    if md.position_embedding_type == 'learned_absolute':
        message["position embeddings"] = model.get_embedding_position_embeddings_weight()
    else:
        if model.has_embedding_position_embeddings():
        if model.has_embedding_position_embeddings_module():
            raise ValueError("model should have position_embeddings")

    return message
@@ -149,15 +152,15 @@ def get_message_layer_attn(message, model, layer_idx, md=None, args=None):
    return message


def get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1):
def _get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1, **kwargs):
    # Grab all parallel tensors for this layer.
    mlp_l0_weight = []
    mlp_l0_bias = []
    mlp_l1_weight = []
    mlp_l0_weight.append(model.get_layers_mlp_linear_fc1_weight(layer_idx=layer_idx))
    mlp_l1_weight.append(model.get_layers_mlp_linear_fc2_weight(layer_idx=layer_idx))
    mlp_l0_weight.append(model.get_layers_mlp_linear_fc1_weight(layer_idx=layer_idx, **kwargs))
    mlp_l1_weight.append(model.get_layers_mlp_linear_fc2_weight(layer_idx=layer_idx, **kwargs))
    if md.linear_bias:
        mlp_l0_bias.append(model.get_layers_mlp_linear_fc1_bias(layer_idx=layer_idx))
        mlp_l0_bias.append(model.get_layers_mlp_linear_fc1_bias(layer_idx=layer_idx, **kwargs))
    # Handle gated linear units.
    if md.swiglu:
        # Concat all the first halves ('W's) and all the second halves ('V's).
@@ -171,7 +174,7 @@ def get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1):
    # Simple concat of the rest.
    message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
    if md.linear_bias:
        message["mlp l1 bias"] = model.get_layers_mlp_linear_fc2_bias(layer_idx=layer_idx)
        message["mlp l1 bias"] = model.get_layers_mlp_linear_fc2_bias(layer_idx=layer_idx, **kwargs)
        if md.swiglu:
            for tp_rank in range(tp_size):
                mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0)
@@ -183,6 +186,22 @@ def get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1):
    return message


def get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1):
    margs = model.get_args()
    if margs.num_experts:
        # return _get_message_layer_mlp(message, model, layer_idx, md=md, tp_size=tp_size)
        message["mlp_moe"] = {}
        mlp_router_weight = model.get_layers_mlp_router_weight(layer_idx=layer_idx)
        message["mlp_moe"]["mlp router weight"] = mlp_router_weight
        for expert_idx in range(margs.num_experts):
            kwargs = {'expert_idx': expert_idx}
            expert = _get_message_layer_mlp({}, model, layer_idx, md=md, tp_size=tp_size, **kwargs)
            message["mlp_moe"][f"expert {expert_idx}"] = expert
        return message
    else:
        return _get_message_layer_mlp(message, model, layer_idx, md=md, tp_size=tp_size)


def get_message_postprocess(model, md):
    # Send final norm from tp_rank 0.
    message = {
@@ -224,6 +243,7 @@ def _load_checkpoint(queue, args):
    model_mg.initialize_megatron_args(args_hf, queue)

    model_mg.set_tensor_model_parallel_world_size(model_mg.args.tensor_model_parallel_size)
    model_mg.set_expert_model_parallel_world_size(model_mg.args.expert_model_parallel_size)
    model_mg.set_pipeline_model_parallel_world_size(model_mg.args.pipeline_model_parallel_size)
    model_mg.set_virtual_pipeline_model_parallel_world_size(model_mg.args.virtual_pipeline_model_parallel_size)

@@ -241,7 +261,7 @@ def _load_checkpoint(queue, args):
    model_mg.update_module(model_hf)

    def queue_put(name, msg):
        print(f"sending {name}")
        logger.info(f"sending {name}")
        msg["name"] = name
        queue.put(msg)
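For orientation, the MoE branch of get_message_layer_mlp above produces a nested message along these lines (a sketch with an assumed expert count and placeholder strings standing in for tensors, not output from a real run).

# Hypothetical shape of `message` for one transformer layer with num_experts = 2.
message = {
    "mlp_moe": {
        "mlp router weight": "<router tensor>",
        "expert 0": {"mlp l0 weight W": "<tensor>", "mlp l0 weight V": "<tensor>", "mlp l1 weight": "<tensor>"},
        "expert 1": {"mlp l0 weight W": "<tensor>", "mlp l0 weight V": "<tensor>", "mlp l1 weight": "<tensor>"},
    }
}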
tools/checkpoint/loader_mg_mcore.py (new file, 299 lines)
@@ -0,0 +1,299 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import json
import os
import sys
import types
import logging as logger

import torch
from models import get_megatron_model

logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)


def add_arguments(parser):
    group = parser.add_argument_group(title='Megatron loader')

    group.add_argument('--true-vocab-size', type=int, default=None,
                       help='original size of vocab, if specified will trim padding from embedding table.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file. If specified will use this to get vocab size and '
                            'trim padding from the embedding table.')
    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of megatron repository')
    parser.add_argument('--add-qkv-bias', action='store_true',
                        help='Add bias for attention qkv', default=False,
                        )
    parser.add_argument('--add-dense-bias', action='store_true',
                        help='Add bias for attention dense', default=False,
                        )
    parser.add_argument('--embed-layernorm', action='store_true',
                        help='Add embed layernorm for word embedding', default=False,
                        )
    parser.add_argument('--params-dtype', type=str,
                        help='Set weight dtype', default='fp16',
                        )
    group.add_argument('--lora-target-modules', nargs='+', type=str, default=[],
                       help='Lora target modules.')
    group.add_argument('--lora-load', type=str, default=None,
                       help='Directory containing a lora model checkpoint.')
    group.add_argument('--lora-r', type=int, default=16,
                       help='Lora r.')
    group.add_argument('--lora-alpha', type=int, default=32,
                       help='Lora alpha.')


def build_metadata(args, margs):
    # Metadata.

    # Layernorm has bias; RMSNorm does not.
    if hasattr(margs, 'normalization'):
        norm_has_bias = margs.normalization == "LayerNorm"
    else:
        # older models only supported LayerNorm
        norm_has_bias = True

    md = types.SimpleNamespace()
    md.model_type = args.model_type
    md.num_layers = margs.num_layers
    md.hidden_size = margs.hidden_size
    md.seq_length = margs.seq_length
    md.num_attention_heads = margs.num_attention_heads
    md.max_position_embeddings = margs.max_position_embeddings
    md.tokenizer_type = margs.tokenizer_type
    md.iteration = margs.iteration
    md.params_dtype = margs.params_dtype
    md.bert_binary_head = margs.bert_binary_head
    md.output_layer = margs.untie_embeddings_and_output_weights
    md.position_embedding_type = margs.position_embedding_type
    md.linear_bias = margs.add_bias_linear
    md.norm_has_bias = norm_has_bias
    md.swiglu = margs.swiglu
    md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
    md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
    md.true_vocab_size = None
    md.checkpoint_args = margs
    md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
    md.embed_layernorm = margs.embed_layernorm

    md.consumed_train_samples = 0
    md.consumed_valid_samples = 0

    return md


def get_message_preprocess(model, args):
    # Send embeddings.
    tp_size = args.tensor_model_parallel_size
    message = {
        "word embeddings": torch.cat(
            [model.get_embedding_word_embeddings_weight(tp_rank=tp_rank) for tp_rank in range(tp_size)], dim=0
        )
    }
    if args.position_embedding_type == 'learned_absolute':
        message["position embeddings"] = model.get_embedding_position_embeddings_weight()
    if args.embed_layernorm:
        message["word embeddings norm_w"] = model.get_embedding_word_embeddings_norm_weight()
        message["word embeddings norm_b"] = model.get_embedding_word_embeddings_norm_bias()

    return message


def get_message_layer_norm(message, model, md, **kwargs):
    # Get non-parallel tensors from tp_rank 0.
    message["input norm weight"] = model.get_layers_input_layernorm_weight(**kwargs)
    if md.norm_has_bias:
        message["input norm bias"] = model.get_layers_input_layernorm_bias(**kwargs)

    message["post norm weight"] = model.get_layers_self_attention_post_attention_layernorm_weight(**kwargs)
    if md.norm_has_bias:
        message["post norm bias"] = model.get_layers_self_attention_post_attention_layernorm_bias(**kwargs)

    return message


def get_message_layer_attn(message, model, md=None, **kwargs):
    # Grab all parallel tensors for this layer
    qkv_weight = []
    qkv_bias = []
    dense_weight = []
    margs = model.get_args()

    for tp_rank in range(margs.tensor_model_parallel_size):
        kwargs["tp_rank"] = tp_rank
        qkv_weight.append(model.get_layers_self_attention_linear_qkv_weight(**kwargs))
        dense_weight.append(model.get_layers_self_attention_linear_proj_weight(**kwargs))

        if md.linear_bias or margs.add_qkv_bias:
            qkv_bias.append(model.get_layers_self_attention_linear_qkv_bias(**kwargs))

    # Handle gated linear units
    # simple concat of the rest
    message["qkv weight"] = torch.cat(qkv_weight, dim=0)
    message["dense weight"] = torch.cat(dense_weight, dim=1)
    if md.linear_bias or margs.add_qkv_bias:
        message["qkv bias"] = torch.cat(qkv_bias, dim=0)

    if md.linear_bias:
        message["dense bias"] = model.get_layers_self_attention_linear_proj_bias(**kwargs)

    return message


def _get_message_layer_mlp(message, model, md=None, **kwargs):
    margs = model.get_args()
    mlp_l0_weight = []
    mlp_l1_weight = []
    mlp_l0_bias = []
    for tp_rank in range(margs.tensor_model_parallel_size):
        kwargs["tp_rank"] = tp_rank
        mlp_l0_weight.append(model.get_layers_mlp_linear_fc1_weight(**kwargs))
        mlp_l1_weight.append(model.get_layers_mlp_linear_fc2_weight(**kwargs))
        if md.linear_bias:
            mlp_l0_bias.append(model.get_layers_mlp_linear_fc1_bias(**kwargs))

    # Handle gated linear units
    if md.swiglu:
        # concat all the first halves ('W's) and all the second halves ('V's)
        for tp_rank in range(margs.tensor_model_parallel_size):
            mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0)
        message[f"mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0)
        message[f"mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0)
    else:
        message[f"mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)

    # simple concat of the rest
    message[f"mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
    if md.linear_bias:
        if md.swiglu:
            for tp_rank in range(margs.tensor_model_parallel_size):
                mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0)
            message[f"mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias], dim=0)
            message[f"mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias], dim=0)
        else:
            message[f"mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)


def get_message_layer_mlp(message, model, md=None, **kwargs):
    # Grab all parallel tensors for this layer
    margs = model.get_args()
    if margs.num_experts:
        message["mlp_moe"] = {}
        num_experts_local = margs.num_experts // margs.expert_model_parallel_size
        mlp_router_weight = model.get_layers_mlp_router_weight(**kwargs)
        message["mlp_moe"]["mlp router weight"] = mlp_router_weight
        for ep_rank in range(margs.expert_model_parallel_size):
            kwargs["ep_rank"] = ep_rank
            for expert_idx in range(num_experts_local):
                kwargs["expert_idx"] = expert_idx
                global_expert_idx = expert_idx + ep_rank * num_experts_local
                message["mlp_moe"][f"expert {global_expert_idx}"] = {}
                expert = message["mlp_moe"][f"expert {global_expert_idx}"]
                _get_message_layer_mlp(expert, model, md, **kwargs)
    else:
        _get_message_layer_mlp(message, model, md, **kwargs)

    return message
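The global expert index used above is just expert_idx + ep_rank * num_experts_local; a tiny check with assumed sizes (8 experts over expert-parallel size 2, i.e. 4 local experts per rank).

num_experts, expert_model_parallel_size = 8, 2            # assumed example values
num_experts_local = num_experts // expert_model_parallel_size
indices = [expert_idx + ep_rank * num_experts_local
           for ep_rank in range(expert_model_parallel_size)
           for expert_idx in range(num_experts_local)]
assert indices == [0, 1, 2, 3, 4, 5, 6, 7]                # each local expert maps to a unique global slot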
def get_message_postprocess(model, md, **kwargs):
    # Send final norm from tp_rank 0.
    message = {}
    message[f"weight"] = model.get_final_layernorm_weight(**kwargs)
    if md.norm_has_bias:
        message[f"bias"] = model.get_final_layernorm_bias(**kwargs)

    return message


def get_message_output_layer(model, md, **kwargs):
    # Send final norm from tp_rank 0.
    margs = model.get_args()
    tp_size = margs.tensor_model_parallel_size
    message = {}
    if md.output_layer:
        get_output_layer_weight_list = []
        for tp_rank in range(tp_size):
            kwargs["tp_rank"] = tp_rank
            get_output_layer_weight_list.append(
                model.get_output_layer_weight(**kwargs)
            )
        message[f"weight"] = torch.cat(get_output_layer_weight_list, dim=0)

    return message


def _load_checkpoint(queue, args):

    # Search in directory above this
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    model_mg = get_megatron_model(args_cmd=args)
    model_mg.initialize_megatron_args(queue=queue, loader_megatron=True)

    model_mg.set_tensor_model_parallel_world_size(model_mg.args.tensor_model_parallel_size)
    model_mg.set_expert_model_parallel_world_size(model_mg.args.expert_model_parallel_size)
    model_mg.set_pipeline_model_parallel_world_size(model_mg.args.pipeline_model_parallel_size)
    model_mg.set_virtual_pipeline_model_parallel_world_size(model_mg.args.virtual_pipeline_model_parallel_size)

    # Get first pipe stage.
    model_mg.set_tensor_model_parallel_rank(0)
    model_mg.set_pipeline_model_parallel_rank(0)

    margs = model_mg.get_args()

    md = build_metadata(args, margs)
    queue.put(md)
    model_mg.get_modules_from_pretrained(pp_stage_cache_flag=True)

    def queue_put(name, msg):
        logger.info(f"sending {name}")
        msg["name"] = name
        queue.put(msg)

    # Send embeddings
    message = get_message_preprocess(model_mg, margs)
    queue_put("embeddings", message)

    pp_size = margs.pipeline_model_parallel_size
    vp_size = margs.virtual_pipeline_model_parallel_size
    if vp_size is None:
        vp_size = 1

    total_layer_num = 0
    for vp_rank in range(vp_size):
        for pp_rank in range(pp_size):
            model_mg.set_pipeline_model_parallel_rank(pp_rank)
            model_mg.get_modules_from_pretrained(pp_stage_cache_flag=True)
            kwargs = {"vp_rank": vp_rank, 'pp_rank': pp_rank}
            for layer_idx in range(len(model_mg.get_layers_module(**kwargs))):
                kwargs["layer_idx"] = layer_idx
                message = {}
                message = get_message_layer_norm(message, model_mg, md, **kwargs)
                message = get_message_layer_attn(message, model_mg, md, **kwargs)
                message = get_message_layer_mlp(message, model_mg, md, **kwargs)
                queue_put(f"transformer layer {total_layer_num}", message)
                total_layer_num = total_layer_num + 1

    kwargs = {"pp_rank": pp_size - 1, "vp_rank": vp_size - 1}
    message = get_message_postprocess(model_mg, md, **kwargs)
    queue_put("final norm", message)

    message = get_message_output_layer(model_mg, md, **kwargs)
    if message:
        queue_put("output layer", message)

    queue.put("done")


def load_checkpoint(queue, args):
    try:
        _load_checkpoint(queue, args)
    except:
        queue.put("exit")
        raise
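Taken together, this loader streams a checkpoint to a saver over a queue. A minimal consumer loop might look like the sketch below; the message names follow the queue_put calls above, and handle() is a hypothetical saver-side function.

def consume(queue):
    md = queue.get()                      # metadata namespace sent first via queue.put(md)
    while True:
        msg = queue.get()
        if msg == "done":                 # normal end of stream
            break
        if msg == "exit":                 # the loader hit an exception
            raise RuntimeError("loader failed")
        name = msg.pop("name")            # e.g. "embeddings", "transformer layer 0", "final norm", "output layer"
        handle(name, msg, md)             # hypothetical saver-side handler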
@@ -25,24 +25,25 @@
        "vocab_size": "vocab_size",
        "intermediate_size": "intermediate_size",
        "norm_epsilon": "rms_norm_eps",
        "tie_word_embeddings": "tie_word_embeddings"
        "tie_word_embeddings": "tie_word_embeddings",
        "torch_dtype": "torch_dtype"
    },
    "model_hf_key_mapping": {
        "model": "module[tp_rank]",
        "embedding_word_embeddings": "module[tp_rank].model.embed_tokens",
        "embedding_word_embeddings_norm": "module[tp_rank].model.embedding.word_embeddings.norm",
        "layers": "module[tp_rank].model.layers",
        "layers_input_layernorm": "module[tp_rank].model.layers[layer_idx].input_layernorm",
        "layers_self_attention_linear_proj": "module[tp_rank].model.layers[layer_idx].self_attn.o_proj",
        "layers_self_attention_linear_q_proj": "module[tp_rank].model.layers[layer_idx].self_attn.q_proj",
        "layers_self_attention_linear_k_proj": "module[tp_rank].model.layers[layer_idx].self_attn.k_proj",
        "layers_self_attention_linear_v_proj": "module[tp_rank].model.layers[layer_idx].self_attn.v_proj",
        "layers_self_attention_pre_mlp_layernorm": "module[tp_rank].model.layers[layer_idx].post_attention_layernorm",
        "layers_mlp_gate_proj": "module[tp_rank].model.layers[layer_idx].mlp.gate_proj",
        "layers_mlp_up_proj": "module[tp_rank].model.layers[layer_idx].mlp.up_proj",
        "layers_mlp_linear_fc2": "module[tp_rank].model.layers[layer_idx].mlp.down_proj",
        "final_layernorm": "module[tp_rank].model.norm",
        "output_layer": "module[tp_rank].lm_head"
        "model": "module[0]",
        "embedding_word_embeddings": "model.embed_tokens",
        "embedding_word_embeddings_norm": "model.embedding.word_embeddings.norm",
        "layers": "model.layers",
        "layers_input_layernorm": "model.layers[layer_idx].input_layernorm",
        "layers_self_attention_linear_proj": "model.layers[layer_idx].self_attn.o_proj",
        "layers_self_attention_linear_q_proj": "model.layers[layer_idx].self_attn.q_proj",
        "layers_self_attention_linear_k_proj": "model.layers[layer_idx].self_attn.k_proj",
        "layers_self_attention_linear_v_proj": "model.layers[layer_idx].self_attn.v_proj",
        "layers_self_attention_pre_mlp_layernorm": "model.layers[layer_idx].post_attention_layernorm",
        "layers_mlp_gate_proj": "model.layers[layer_idx].mlp.gate_proj",
        "layers_mlp_up_proj": "model.layers[layer_idx].mlp.up_proj",
        "layers_mlp_linear_fc2": "model.layers[layer_idx].mlp.down_proj",
        "final_layernorm": "model.norm",
        "output_layer": "lm_head"
    }
},
"llama2": {
@@ -66,18 +67,29 @@
        "norm_epsilon": "layernorm_epsilon"
    },
    "model_hf_key_mapping": {
        "model": "module[tp_rank]",
        "embedding_word_embeddings": "module[tp_rank].transformer.embedding.word_embeddings",
        "layers": "module[tp_rank].transformer.encoder.layers",
        "layers_input_layernorm": "module[tp_rank].transformer.encoder.layers[layer_idx].input_layernorm",
        "layers_self_attention_linear_qkv_pack": "module[tp_rank].transformer.encoder.layers[layer_idx].self_attention.query_key_value",
        "layers_self_attention_linear_proj": "module[tp_rank].transformer.encoder.layers[layer_idx].self_attention.dense",
        "layers_self_attention_pre_mlp_layernorm": "module[tp_rank].transformer.encoder.layers[layer_idx].post_attention_layernorm",
        "layers_mlp_linear_fc1": "module[tp_rank].transformer.encoder.layers[layer_idx].mlp.dense_h_to_4h",
        "layers_mlp_linear_fc2": "module[tp_rank].transformer.encoder.layers[layer_idx].mlp.dense_4h_to_h",
        "final_layernorm": "module[tp_rank].transformer.encoder.final_layernorm",
        "output_layer": "module[tp_rank].transformer.output_layer"
        "embedding_word_embeddings": "transformer.embedding.word_embeddings",
        "layers": "transformer.encoder.layers",
        "layers_input_layernorm": "transformer.encoder.layers[layer_idx].input_layernorm",
        "layers_self_attention_linear_qkv_pack": "transformer.encoder.layers[layer_idx].self_attention.query_key_value",
        "layers_self_attention_linear_proj": "transformer.encoder.layers[layer_idx].self_attention.dense",
        "layers_self_attention_pre_mlp_layernorm": "transformer.encoder.layers[layer_idx].post_attention_layernorm",
        "layers_mlp_linear_fc1": "transformer.encoder.layers[layer_idx].mlp.dense_h_to_4h",
        "layers_mlp_linear_fc2": "transformer.encoder.layers[layer_idx].mlp.dense_4h_to_h",
        "final_layernorm": "transformer.encoder.final_layernorm",
        "output_layer": "transformer.output_layer"
    }
},
"mixtral": {
    "__base__": "base",
    "config_set_value": {
        "moe_flag": true
    },
    "model_hf_key_mapping": {
        "layers_mlp_router": "model.layers[layer_idx].block_sparse_moe.gate",
        "layers_mlp_gate_proj": "model.layers[layer_idx].block_sparse_moe.experts[expert_idx].w1",
        "layers_mlp_up_proj": "model.layers[layer_idx].block_sparse_moe.experts[expert_idx].w3",
        "layers_mlp_linear_fc2": "model.layers[layer_idx].block_sparse_moe.experts[expert_idx].w2"
    }
}
}
}
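These mapping strings are resolved by the regex-driven _get_obj helper in the models module further down. A standalone sketch of how a path such as "model.layers[layer_idx].mlp.gate_proj" would be walked (the object and index names here are assumptions, not part of the diff).

import re

def resolve(obj, path, idx):
    # Same pattern as _get_obj: attribute name plus optional [index_name] suffix.
    for attr, attr_ident in re.findall(r'(\w+)(?:\[(\w+)\])?', path):
        obj = getattr(obj, attr)
        if attr_ident:
            obj = obj[idx[attr_ident]]    # e.g. idx = {"layer_idx": 3, "expert_idx": 0}
    return obj

# resolve(hf_model, "model.layers[layer_idx].mlp.gate_proj", {"layer_idx": 3})
# would return hf_model.model.layers[3].mlp.gate_proj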
@@ -4,80 +4,128 @@ import sys
import re
import json
from types import SimpleNamespace
import logging as logger
from collections import OrderedDict
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from megatron.core import mpu
from megatron.training.arguments import validate_args
from megatron.training.global_vars import set_global_variables
from megatron.legacy.model import module
from megatron.core.enums import ModelType
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_args
from megatron.training.checkpointing import load_checkpoint
from megatron.core import tensor_parallel
from pretrain_gpt import model_provider
from modellink.utils import parse_args
from modellink.training import model_provider_func_wrapper
from modellink.checkpointing import load_checkpoint_wrapper

logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)

model_provider = model_provider_func_wrapper(model_provider)
load_checkpoint = load_checkpoint_wrapper(load_checkpoint)


def tensor_info(tensor):
    shape = tensor.shape
    mean_val = tensor.mean().item()
    min_val = tensor.min().item()
    max_val = tensor.max().item()
    return f"shape: {shape} mean_val: {mean_val} min_val: {min_val} max_val: {max_val}"


class ModelBase(abc.ABC):
    def __init__(self, args_cmd=None):
        self.args_cmd = args_cmd
        self.args = None
        self.args_megatron_checkpoint = None
        self.module = None
        self.module_mapping = None
        self.model_cfg = self.read_model_cfg()
        self.__register_functions()
        self.kwargs_idx = OrderedDict({
            "vp_rank": 0,
            "ep_rank": 0,
            "tp_rank": 0,
            "layer_idx": 0,
            "expert_idx": 0
        })

    def update_kwargs_idx(self, **kwargs):
        for key in self.kwargs_idx:
            if key in kwargs:
                self.kwargs_idx[key] = kwargs[key]
            elif self.kwargs_idx[key] > 0:
                self.kwargs_idx[key] = 0

    def __register_functions(self):
        self.get_module_mapping()
        kwargs_idx = dict({"tp_rank": 0, "layer_idx": 0})

        def get_obj(self, value, **kwargs):
        def _get_obj(self, value, **kwargs):
            pattern = r'(\w+)(?:\[(\w+)\])?'
            matches = re.findall(pattern, value)
            obj = self
            for key in kwargs_idx:
                if key in kwargs:
                    kwargs_idx[key] = kwargs[key]
            self.update_kwargs_idx(**kwargs)
            obj = self.get_model_item(**kwargs)
            for attr, attr_ident in matches:
                obj = getattr(obj, attr)
                if hasattr(obj, attr):
                    obj = getattr(obj, attr)
                else:
                    return None
                if attr_ident:
                    if attr_ident in kwargs_idx:
                        attr_idx = kwargs_idx[attr_ident]
                    if attr_ident in self.kwargs_idx:
                        attr_idx = self.kwargs_idx[attr_ident]
                        obj = obj[attr_idx]
                    else:
                        raise AssertionError(f"check {self.__class__.__name__}.module_mapping **{attr_ident}**.")
            return obj

        def func_generator_get_module(value):
        def _func_generator_get_module(value):
            def func(self, **kwargs):
                return get_obj(self, value, **kwargs)
                return _get_obj(self, value, **kwargs)
            return func

        def func_generator_get_weight(value):
        def _func_generator_get_weight(value):
            def func(self, **kwargs):
                return get_obj(self, value, **kwargs).weight.data
                return _get_obj(self, value, **kwargs).weight.data
            return func

        def func_generator_get_bias(value):
        def _func_generator_get_bias(value):
            def func(self, **kwargs):
                return get_obj(self, value, **kwargs).bias.data
                return _get_obj(self, value, **kwargs).bias.data
            return func

        def func_generator_set_weight(value):
        def _func_generator_set_weight(value):
            def func(self, **kwargs):
                return get_obj(self, value, **kwargs).weight.data.copy_(kwargs.get('data'))
                return _get_obj(self, value, **kwargs).weight.data.copy_(kwargs.get('data'))
            return func

        def func_generator_set_bias(value):
        def _func_generator_set_bias(value):
            def func(self, **kwargs):
                return get_obj(self, value, **kwargs).bias.data.copy_(kwargs.get('data'))
                return _get_obj(self, value, **kwargs).bias.data.copy_(kwargs.get('data'))
            return func

        def _func_generator_has_module(value):
            def func(self, **kwargs):
                obj = _get_obj(self, value, **kwargs)
                return True if obj else False
            return func

        def _setattr(self, func_name, value):
            if not hasattr(self, func_name):
                setattr(self, func_name, value)

        if self.module_mapping:
            for key, value in self.module_mapping.items():
                setattr(self, "get_" + key + "_module", func_generator_get_module(value).__get__(self, ModelBase))
                setattr(self, "get_" + key + "_weight", func_generator_get_weight(value).__get__(self, ModelBase))
                setattr(self, "get_" + key + "_bias", func_generator_get_bias(value).__get__(self, ModelBase))
                setattr(self, "set_" + key + "_weight", func_generator_set_weight(value).__get__(self, ModelBase))
                setattr(self, "set_" + key + "_bias", func_generator_set_bias(value).__get__(self, ModelBase))
                _setattr(self, "get_" + key + "_module", _func_generator_get_module(value).__get__(self, ModelBase))
                _setattr(self, "get_" + key + "_weight", _func_generator_get_weight(value).__get__(self, ModelBase))
                _setattr(self, "get_" + key + "_bias", _func_generator_get_bias(value).__get__(self, ModelBase))
                _setattr(self, "set_" + key + "_weight", _func_generator_set_weight(value).__get__(self, ModelBase))
                _setattr(self, "set_" + key + "_bias", _func_generator_set_bias(value).__get__(self, ModelBase))
                _setattr(self, "has_" + key + "_module", _func_generator_has_module(value).__get__(self, ModelBase))

    def update_module(self, src_model):
        self.set_preprocess_state(src_model)
@@ -87,51 +135,58 @@ class ModelBase(abc.ABC):

    def set_preprocess_state(self, src_model):
        '''Set embedding params.'''
        self.set_embedding_word_embeddings_weight(data=src_model.get_embedding_word_embeddings_weight())
        embeddings_weight = src_model.get_embedding_word_embeddings_weight()
        self.set_embedding_word_embeddings_weight(data=embeddings_weight)

    def set_postprocess_state(self, src_model):
        self.set_final_layernorm_weight(data=src_model.get_final_layernorm_weight())
        final_layernorm_weight = src_model.get_final_layernorm_weight()
        output_layer_weight = src_model.get_output_layer_weight()
        self.set_final_layernorm_weight(data=final_layernorm_weight)
        if self.args.untie_embeddings_and_output_weights:
            self.set_output_layer_weight(data=src_model.get_output_layer_weight())
            self.set_output_layer_weight(data=output_layer_weight)

    def set_layer_state(self, src_model, layer_idx):
        self.set_attn_state(layer_idx, src_model)
        self.set_mlp_state(layer_idx, src_model)
        self.set_layers_input_layernorm_weight(
            layer_idx=layer_idx,
            data=src_model.get_layers_input_layernorm_weight(layer_idx=layer_idx))
        self.set_layers_self_attention_pre_mlp_layernorm_weight(
            layer_idx=layer_idx,
            data=src_model.get_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx))
        input_layernorm_weight = src_model.get_layers_input_layernorm_weight(layer_idx=layer_idx)
        pre_mlp_layernorm_weight = src_model.get_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx)
        self.set_layers_input_layernorm_weight(layer_idx=layer_idx, data=input_layernorm_weight)
        self.set_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx, data=pre_mlp_layernorm_weight)

    def set_attn_state(self, layer_idx, src_model):
        '''Set self-attention params.'''
        # Get attention layer & state.
        self.set_layers_self_attention_linear_qkv_weight(
            layer_idx=layer_idx,
            data=src_model.get_layers_self_attention_linear_qkv_weight(layer_idx=layer_idx))

        self.set_layers_self_attention_linear_proj_weight(
            layer_idx=layer_idx,
            data=src_model.get_layers_self_attention_linear_proj_weight(layer_idx=layer_idx))
        qkv_weight = src_model.get_layers_self_attention_linear_qkv_weight(layer_idx=layer_idx)
        proj_weight = src_model.get_layers_self_attention_linear_proj_weight(layer_idx=layer_idx)
        self.set_layers_self_attention_linear_qkv_weight(layer_idx=layer_idx, data=qkv_weight)
        self.set_layers_self_attention_linear_proj_weight(layer_idx=layer_idx, data=proj_weight)
        if self.args.add_qkv_bias:
            self.set_layers_self_attention_linear_qkv_bias(
                layer_idx=layer_idx,
                data=src_model.get_layers_self_attention_linear_qkv_bias(layer_idx=layer_idx))
            qkv_bias = src_model.get_layers_self_attention_linear_qkv_bias(layer_idx=layer_idx)
            self.set_layers_self_attention_linear_qkv_bias(layer_idx=layer_idx, data=qkv_bias)
        if self.args.add_dense_bias:
            self.set_layers_self_attention_linear_proj_bias(
                layer_idx=layer_idx,
                data=src_model.get_layers_self_attention_linear_proj_bias(layer_idx=layer_idx))
            proj_bias = src_model.get_layers_self_attention_linear_proj_bias(layer_idx=layer_idx)
            self.set_layers_self_attention_linear_proj_bias(layer_idx=layer_idx, data=proj_bias)

    def _set_mlp_state(self, src_model, **kwargs):
        '''Set MLP params.'''
        fc1_weight = src_model.get_layers_mlp_linear_fc1_weight(**kwargs)
        fc2_weight = src_model.get_layers_mlp_linear_fc2_weight(**kwargs)
        self.set_layers_mlp_linear_fc1_weight(data=fc1_weight, **kwargs)
        self.set_layers_mlp_linear_fc2_weight(data=fc2_weight, **kwargs)

    def set_mlp_state(self, layer_idx, src_model):
        '''Set MLP params.'''
        self.set_layers_mlp_linear_fc1_weight(
            layer_idx=layer_idx,
            data=src_model.get_layers_mlp_linear_fc1_weight(layer_idx=layer_idx))
        args = src_model.get_args()
        kwargs = {'layer_idx': layer_idx}
        num_experts = getattr(args, 'num_experts', None) or getattr(args, 'num_local_experts', None)
        if num_experts:
            router_weight = src_model.get_layers_mlp_router_weight(**kwargs)
            self.set_layers_mlp_router_weight(**kwargs, data=router_weight)
            for expert_idx in range(num_experts):
                kwargs['expert_idx'] = expert_idx
                self._set_mlp_state(src_model, **kwargs)
        else:
            self._set_mlp_state(src_model, **kwargs)

        self.set_layers_mlp_linear_fc2_weight(
            layer_idx=layer_idx,
            data=src_model.get_layers_mlp_linear_fc2_weight(layer_idx=layer_idx))

    def get_args(self):
        return self.args
@@ -145,21 +200,57 @@ class ModelBase(abc.ABC):
    def get_modules_count(self):
        return len(self.module)

    @staticmethod
    def read_model_cfg():
        def merge_configs(base_config, specific_config):
            merged_config = base_config.copy()
            for key, value in specific_config.items():
                if isinstance(value, dict) and key in merged_config:
                    merged_config[key] = merge_configs(merged_config[key], value)
                else:
                    merged_config[key] = value
            return merged_config

        current_directory = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(current_directory, 'model_cfg.json'), 'r') as file:
            config = json.load(file)
        final_configs = {}

        for model_name, model_config in config["model_mappings"].items():
            if "__base__" in model_config:
                base_model_name = model_config["__base__"]
                base_config = config["model_mappings"][base_model_name]
                specific_config = model_config.copy()
                specific_config.pop("__base__", None)
                final_config = merge_configs(base_config, specific_config)
            else:
                final_config = model_config
            final_configs[model_name] = final_config

        return final_configs

    @abc.abstractmethod
    def get_module_mapping(self):
        pass

    @abc.abstractmethod
    def get_model_item(self, **kwargs):
        pass


class HuggingfaceModel(ModelBase):
    def __init__(self, args_cmd):
        self.model_cfg = self.read_model_cfg()
        super(HuggingfaceModel, self).__init__(args_cmd)
        self.initialize_args()
        self.layers_self_attention_linear_qkv_caches = {"layer_idx": -1, "weight": None, "bias": None}

    def initialize_args(self):
        # Read huggingface args.
        llama_args_path = os.path.join(self.args_cmd.load_dir, "config.json")
        if self.args_cmd.save_model_type == 'huggingface':
            cfg_dir = self.args_cmd.save_dir
        else:
            cfg_dir = self.args_cmd.load_dir
        llama_args_path = os.path.join(cfg_dir, "config.json")
        with open(llama_args_path) as f:
            self.args = json.load(f)
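As a quick illustration of what read_model_cfg's recursive merge does for an entry that declares "__base__" (toy dictionaries, not the real model_cfg.json):

base = {"config_set_value": {"seq_length": 4096}, "model_hf_key_mapping": {"layers": "model.layers"}}
mixtral_specific = {"config_set_value": {"moe_flag": True}}

def merge_configs(base_config, specific_config):
    merged_config = base_config.copy()
    for key, value in specific_config.items():
        if isinstance(value, dict) and key in merged_config:
            merged_config[key] = merge_configs(merged_config[key], value)
        else:
            merged_config[key] = value
    return merged_config

merged = merge_configs(base, mixtral_specific)
# Nested dicts are merged key by key instead of being replaced wholesale:
assert merged["config_set_value"] == {"seq_length": 4096, "moe_flag": True}
assert merged["model_hf_key_mapping"] == {"layers": "model.layers"}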
@@ -189,21 +280,20 @@ class HuggingfaceModel(ModelBase):
            load_dir = self.args_cmd.save_dir
        else:
            load_dir = self.args_cmd.load_dir
        self.module = [
            AutoModelForCausalLM.from_pretrained(load_dir, device_map=device_map, trust_remote_code=trust_remote_code)
        ]
        self.module = [AutoModelForCausalLM.from_pretrained(load_dir, device_map=device_map, trust_remote_code=trust_remote_code)]
        if self.args.torch_dtype in ["float16", "bfloat16"]:
            self.module[0] = self.module[0].to(eval(f'torch.{self.args.torch_dtype}'))

    def get_module_mapping(self):
        self.module_mapping = self.model_cfg.get(self.args_cmd.model_type_hf).get('model_hf_key_mapping')

    def _get_layers_self_attention_linear_qkv_module(self, layer_idx=0):
    def __get_layers_self_attention_linear_qkv_module(self, layer_idx=0):
        if self.layers_self_attention_linear_qkv_caches["layer_idx"] == layer_idx:
            return
        self.layers_self_attention_linear_qkv_caches["layer_idx"] = layer_idx
        # Reshape loaded weights.
        nh = self.args.num_attention_heads
        ng = (self.args.num_key_value_heads if self.args.group_query_attention else self.args.num_attention_heads)
        # dim = self.args['kv_channels']
        dim = self.args.hidden_size // self.args.num_attention_heads
        if not nh % ng == 0:
            raise ValueError("nh % ng should equal 0")
@@ -237,7 +327,6 @@ class HuggingfaceModel(ModelBase):
        qkv_pack_weight = qkv_pack.weight
        full_q = dim * nh
        end_k = full_q + ng * dim
        hs = self.args.hidden_size
        q_weight = qkv_pack_weight[:full_q, :]
        k_weight = qkv_pack_weight[full_q:end_k, :]
        v_weight = qkv_pack_weight[end_k:, :]
@@ -253,82 +342,115 @@ class HuggingfaceModel(ModelBase):
        else:
            raise ValueError(f"Unsupported types. {qkv_type}")

    def get_layers_mlp_linear_fc1_weight(self, layer_idx=0):
    def get_layers_mlp_linear_fc1_weight(self, **kwargs):
        fc_type = self.args.fc_type
        if fc_type == "h_to_4h":
            return self.get_layers_mlp_linear_fc1_module(layer_idx=layer_idx).weight
            return self.get_layers_mlp_linear_fc1_module(**kwargs).weight
        elif fc_type == "gate_up_down":
            gate_proj = self.get_layers_mlp_gate_proj_weight(layer_idx=layer_idx)
            up_proj = self.get_layers_mlp_up_proj_weight(layer_idx=layer_idx)
            gate_proj = self.get_layers_mlp_gate_proj_weight(**kwargs)
            up_proj = self.get_layers_mlp_up_proj_weight(**kwargs)
            return torch.cat([gate_proj, up_proj], dim=0)
        else:
            raise ValueError(f"Unsupported fc_type {fc_type}")

    def get_layers_self_attention_linear_qkv_weight(self, layer_idx):
        self._get_layers_self_attention_linear_qkv_module(layer_idx=layer_idx)
        self.__get_layers_self_attention_linear_qkv_module(layer_idx=layer_idx)
        return self.layers_self_attention_linear_qkv_caches["weight"]

    def get_layers_self_attention_linear_qkv_bias(self, layer_idx):
        self._get_layers_self_attention_linear_qkv_module(layer_idx=layer_idx)
        self.__get_layers_self_attention_linear_qkv_module(layer_idx=layer_idx)
        return self.layers_self_attention_linear_qkv_caches["bias"]

    @staticmethod
    def read_model_cfg():
        def merge_configs(base_config, specific_config):
            merged_config = base_config.copy()
            for key, value in specific_config.items():
                if isinstance(value, dict) and key in merged_config:
                    merged_config[key] = merge_configs(merged_config[key], value)
                else:
                    merged_config[key] = value
            return merged_config
    def set_layers_mlp_linear_fc1_weight(self, data=None, **kwargs):
        gate_proj, up_proj = torch.chunk(data, 2, dim=0)
        self.set_layers_mlp_gate_proj_weight(data=gate_proj, **kwargs)
        self.set_layers_mlp_up_proj_weight(data=up_proj, **kwargs)

        current_directory = os.path.dirname(os.path.abspath(__file__))
        with open(os.path.join(current_directory, 'model_cfg.json'), 'r') as file:
            config = json.load(file)
        # Dictionary holding the final merged configs
        final_configs = {}
    def set_layers_self_attention_linear_qkv_weight(self, layer_idx=0, data=None):
        def qkv_split_weight(query_key_value):
            qkv_weight = query_key_value.reshape(
                ng,
                repeats + 2,
                query_key_value.shape[0] // ng // (repeats + 2),
                query_key_value.shape[1],
            )
            hidden_size = qkv_weight.shape[-1]
            qw = qkv_weight[:, :repeats, ...].reshape(-1, hidden_size)
            kw = qkv_weight[:, repeats: repeats + 1, ...].reshape(-1, hidden_size)
            vw = qkv_weight[:, repeats + 1:, ...].reshape(-1, hidden_size)
            return qw, kw, vw

        # Iterate over all model configurations
        for model_name, model_config in config["model_mappings"].items():
            if "__base__" in model_config:
                base_model_name = model_config["__base__"]
                base_config = config["model_mappings"][base_model_name]
                specific_config = model_config.copy()
                specific_config.pop("__base__", None)
                final_config = merge_configs(base_config, specific_config)
            else:
                final_config = model_config
            final_configs[model_name] = final_config
        nh = self.args.num_attention_heads
        ng = (self.args.num_key_value_heads if self.args.group_query_attention else self.args.num_attention_heads)
        if not nh % ng == 0:
            raise ValueError("nh % ng should equal 0")
        repeats = nh // ng

        return final_configs
        qkv_type = self.args.qkv_type
        if qkv_type == "unpack":
            q_weight, k_weight, v_weight = qkv_split_weight(data)
            self.set_layers_self_attention_linear_q_proj_weight(layer_idx=layer_idx, data=q_weight)
            self.set_layers_self_attention_linear_k_proj_weight(layer_idx=layer_idx, data=k_weight)
            self.set_layers_self_attention_linear_v_proj_weight(layer_idx=layer_idx, data=v_weight)
        else:
            raise ValueError(f"Unsupported types. {qkv_type}")

    def set_layers_self_attention_linear_qkv_bias(self, layer_idx, data=None):
        def qkv_split_bias(query_key_value):
            bias_weight = query_key_value.reshape(
                ng, repeats + 2, query_key_value.shape[0] // ng // (repeats + 2)
            )
            qw = bias_weight[:, :repeats, ...].reshape(-1)
            kw = bias_weight[:, repeats: repeats + 1, ...].reshape(-1)
            vw = bias_weight[:, repeats + 1:, ...].reshape(-1)
            return qw, kw, vw

        nh = self.args.num_attention_heads
        ng = (self.args.num_key_value_heads if self.args.group_query_attention else self.args.num_attention_heads)
        if not nh % ng == 0:
            raise ValueError("nh % ng should equal 0")
        repeats = nh // ng

        qkv_type = self.args.qkv_type
        if qkv_type == "unpack":
            if self.args_cmd.add_qkv_bias:
                q_bias, k_bias, v_bias = qkv_split_bias(data)
                self.set_layers_self_attention_linear_q_proj_bias(layer_idx=layer_idx, data=q_bias)
                self.set_layers_self_attention_linear_k_proj_bias(layer_idx=layer_idx, data=k_bias)
                self.set_layers_self_attention_linear_v_proj_bias(layer_idx=layer_idx, data=v_bias)
        else:
            raise ValueError(f"Unsupported types. {qkv_type}")

    def get_model_item(self, **kwargs):
        return self.module[0]


class MegatronModel(ModelBase):
    def __init__(self, args_cmd, md=None):
        super(MegatronModel, self).__init__(args_cmd)
        self.md = md
        self.pp_stage_cache = []

    def initialize_megatron_args(self, hf_args=None, queue=None):
    def initialize_megatron_args(self, hf_args=None, queue=None, loader_megatron=False, saver_megatron=False):
        sys.argv = self.get_sys_argv()
        self.args = parse_args()

        self.update_megatron_args_from_cmd_config()  # whether the saver also needs all of this still has to be verified
        self.update_megatron_args_from_huggingface_config(hf_args)  # taken on the loader path, not on the saver path
        self.update_megatron_args_from_megatron_checkpoint(loader_megatron)
        self.update_megatron_args_from_cmd_config(loader_megatron)
        self.update_megatron_args_from_huggingface_config(hf_args)

        # Arguments do sanity checks on the world size, but we don't care,
        # so trick it into thinking we are plenty of processes.
        self.args.world_size = self.args.tensor_model_parallel_size * self.args.pipeline_model_parallel_size
        self.update_megatron_args_from_loader_margs()
        self.args = validate_args(self.args)
        self.check_for_args(queue)
        self.check_for_args(queue, saver_megatron)

        self.args.model_type = ModelType.encoder_or_decoder
        # Suppress warning about torch.distributed not being initialized.
        module.MegatronModule.embedding_warning_printed = True

        set_global_variables(self.args, build_tokenizer=False)
        self.set_megatron_parallel_state()
        set_args(self.args)
        self.set_megatron_parallel_state(saver_megatron)

    def update_megatron_args_from_loader_margs(self):
        if self.md and hasattr(self.md, 'checkpoint_args'):
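The qkv_split_weight reshape above undoes the grouped QKV packing before the per-projection weights are written back. A small self-contained check with assumed sizes (2 query groups, 2 query heads per group, head dim 4, hidden size 8):

import torch

ng, repeats, head_dim, hidden = 2, 2, 4, 8            # assumed toy sizes
packed = torch.arange(ng * (repeats + 2) * head_dim * hidden, dtype=torch.float32)
packed = packed.reshape(ng * (repeats + 2) * head_dim, hidden)

qkv = packed.reshape(ng, repeats + 2, packed.shape[0] // ng // (repeats + 2), packed.shape[1])
qw = qkv[:, :repeats, ...].reshape(-1, hidden)               # query heads of every group
kw = qkv[:, repeats: repeats + 1, ...].reshape(-1, hidden)   # one key head per group
vw = qkv[:, repeats + 1:, ...].reshape(-1, hidden)           # one value head per group

assert qw.shape == (ng * repeats * head_dim, hidden)
assert kw.shape == (ng * head_dim, hidden) and vw.shape == kw.shape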
@@ -342,27 +464,28 @@ class MegatronModel(ModelBase):
            'recompute_num_layers', 'recompute_method', 'encoder_num_layers', 'encoder_seq_length',
            'distribute_saved_activations', 'train_iters', 'lr_decay_iters', 'lr_warmup_iters',
            'lr_warmup_fraction', 'start_weight_decay', 'end_weight_decay', 'make_vocab_size_divisible_by',
            'masked_softmax_fusion', 'num_layer_list',
            'masked_softmax_fusion', 'num_layer_list', 'lora_target_modules', 'expert_model_parallel_size'
        ]

        for arg, value in vars(self.md.checkpoint_args).items():
            if arg in args_to_keep:
                continue
            if not hasattr(self.args, arg):
                print(f"Checkpoint had argument {arg} but new arguments does not have this.")
                logger.warning(f"Checkpoint had argument {arg} but new arguments does not have this.")
                continue
            if getattr(self.args, arg) != value:
                print(
                    f"Overwriting default {arg} value {getattr(self.args, arg)} with value from checkpoint {value}.")
                logger.warning(
                    f"Overwriting default {arg} value {getattr(self.args, arg)} with value from checkpoint {value}."
                )
                setattr(self.args, arg, value)

        if hasattr(self.md, 'consumed_train_samples'):
            self.args.consumed_train_samples = self.md.consumed_train_samples
            self.args.consumed_valid_samples = self.md.consumed_valid_samples
            print(f"Setting consumed_train_samples to {self.args.consumed_train_samples}"
                  f" and consumed_valid_samples to {self.args.consumed_valid_samples}")
            logger.info(f"Setting consumed_train_samples to {self.args.consumed_train_samples} "
                        f"and consumed_valid_samples to {self.args.consumed_valid_samples}")
        else:
            print("consumed_train_samples not provided.")
            logger.warning("consumed_train_samples not provided.")

    def update_megatron_args_from_huggingface_config(self, hf_args):
        if hf_args is None:
@@ -397,18 +520,33 @@ class MegatronModel(ModelBase):
        ):
            self.args.group_query_attention = True
            self.args.num_query_groups = hf_args.num_key_value_heads
        if hasattr(hf_args, 'num_local_experts'):
            self.args.num_experts = hf_args.num_local_experts

    def update_megatron_args_from_cmd_config(self):
    def update_megatron_args_from_megatron_checkpoint(self, loader_megatron):
        if not loader_megatron:
            return
        set_args(self.args)
        self.args, self.args_megatron_checkpoint = load_args_from_checkpoint(self.args)

    def update_megatron_args_from_cmd_config(self, loader_megatron):
        self.args.w_pack = self.args_cmd.w_pack
        self.args.add_qkv_bias = self.args_cmd.add_qkv_bias
        self.args.add_dense_bias = self.args_cmd.add_dense_bias
        self.args.tokenizer_model = self.args_cmd.tokenizer_model
        self.args.make_vocab_size_divisible_by = self.args_cmd.make_vocab_size_divisible_by
        self.args.tokenizer_model = getattr(self.args_cmd, 'tokenizer_model', None)
        self.args.make_vocab_size_divisible_by = getattr(self.args_cmd, 'make_vocab_size_divisible_by', None)
        if self.args_cmd.params_dtype == 'bf16':
            self.args.bf16 = True
        elif self.args_cmd.params_dtype == 'fp16':
            self.args.fp16 = True
        if self.args_cmd.add_dense_bias:
            self.args.skip_bias_add = False

        if loader_megatron:
            self.args.lora_target_modules = self.args_cmd.lora_target_modules
            self.args.lora_load = self.args_cmd.lora_load
            self.args.lora_r = self.args_cmd.lora_r
            self.args.lora_alpha = self.args_cmd.lora_alpha
        # Determine how to make our models.
        if not self.args_cmd.model_type == 'GPT':
            raise ValueError("Llama-2 is a GPT model.")
@ -416,34 +554,113 @@ class MegatronModel(ModelBase):
|
||||
if self.md and self.args_cmd.num_layer_list:
|
||||
self.args.num_layer_list = self.args_cmd.num_layer_list
|
||||
|
||||
def set_megatron_parallel_state(self):
|
||||
self.set_tensor_model_parallel_world_size(self.args.tensor_model_parallel_size)
|
||||
self.set_pipeline_model_parallel_world_size(self.args.pipeline_model_parallel_size)
|
||||
self.set_virtual_pipeline_model_parallel_world_size(self.args.virtual_pipeline_model_parallel_size)
|
||||
def set_megatron_parallel_state(self, saver_megatron):
|
||||
if saver_megatron:
|
||||
self.set_tensor_model_parallel_world_size(self.args_cmd.target_tensor_parallel_size)
|
||||
self.set_expert_model_parallel_world_size(self.args_cmd.target_expert_parallel_size)
|
||||
self.set_pipeline_model_parallel_world_size(self.args_cmd.target_pipeline_parallel_size)
|
||||
if self.args_cmd.num_layers_per_virtual_pipeline_stage:
|
||||
vp_size = (self.args.num_layers //
|
||||
self.args_cmd.target_pipeline_parallel_size //
|
||||
self.args_cmd.num_layers_per_virtual_pipeline_stage)
|
||||
self.set_virtual_pipeline_model_parallel_world_size(vp_size)
|
||||
else:
|
||||
self.set_tensor_model_parallel_world_size(self.args.tensor_model_parallel_size)
|
||||
self.set_pipeline_model_parallel_world_size(self.args.pipeline_model_parallel_size)
|
||||
self.set_virtual_pipeline_model_parallel_world_size(self.args.virtual_pipeline_model_parallel_size)
|
||||
|
||||
# Get first pipe stage.
|
||||
self.set_tensor_model_parallel_rank(0)
|
||||
self.set_pipeline_model_parallel_rank(0)
|
||||
|
||||
    def get_modules_from_config(self, count=1, pre_process=True, post_process=True):
        self.args.model_type = ModelType.encoder_or_decoder
        self.module = [model_provider(pre_process, post_process).to(self.args.params_dtype) for _ in range(count)]
    def get_modules_from_config(self, pp_stage_cache_flag=False):
        self.__get_modules(pp_stage_cache_flag=pp_stage_cache_flag)

    def check_for_args(self, queue):
        check_args_list = {'tensor_model_parallel_size': None, 'pipeline_model_parallel_size': None,
                           'num_layers': None, 'hidden_size': None, 'seq_length': None,
                           'num_attention_heads': None, 'max_position_embeddings': None,
                           'position_embedding_type': None, 'tokenizer_type': None, 'iteration': 1,
                           'bert_binary_head': None, 'disable_bias_linear': False, 'params_dtype': None,
                           'swiglu': False}
    def get_modules_from_pretrained(self, pp_stage_cache_flag=False):
        self.__get_modules(from_pretrained=True, pp_stage_cache_flag=pp_stage_cache_flag)

    def __get_modules(self, from_pretrained=False, pp_stage_cache_flag=False):
        if self.args.num_experts:
            tensor_parallel.model_parallel_cuda_manual_seed(123)
        # Initialize the dictionary for the parallel mode of the model
        pp_rank = self.get_pipeline_model_parallel_rank()
        if pp_stage_cache_flag and pp_rank < len(self.pp_stage_cache):
            self.module = self.pp_stage_cache[pp_rank]
            return

        virtual_pipeline_model_parallel_size = self.args.virtual_pipeline_model_parallel_size
        if virtual_pipeline_model_parallel_size is None:
            virtual_pipeline_model_parallel_size = 1

        models = [
            [
                [
                    None for _ in range(self.args.tensor_model_parallel_size)
                ]
                for _ in range(self.args.expert_model_parallel_size)
            ]
            for _ in range(virtual_pipeline_model_parallel_size)
        ]

        for ep_rank in range(self.args.expert_model_parallel_size):
            if self.args.expert_model_parallel_size > 1:
                self.set_expert_model_parallel_rank(ep_rank)
            for tp_rank in range(self.args.tensor_model_parallel_size):
                self.set_tensor_model_parallel_rank(tp_rank)
                if self.args.virtual_pipeline_model_parallel_size is not None:
                    model_ = []
                    for vp_rank in range(self.args.virtual_pipeline_model_parallel_size):
                        self.set_virtual_pipeline_model_parallel_rank(vp_rank)
                        # Set pre_process and post_process only after virtual rank is set.
                        pre_process = mpu.is_pipeline_first_stage()
                        post_process = mpu.is_pipeline_last_stage()
                        expert_parallel_size = mpu.get_expert_model_parallel_world_size()
                        this_model = model_provider(
                            pre_process=pre_process,
                            post_process=post_process
                        ).to(self.args.params_dtype)
                        model_.append(this_model)
                else:
                    pre_process = mpu.is_pipeline_first_stage()
                    post_process = mpu.is_pipeline_last_stage()
                    model_ = [model_provider(pre_process, post_process).to(self.args.params_dtype)]
                self.args.consumed_train_samples = 0
                self.args.consumed_valid_samples = 0
                if from_pretrained:
                    load_checkpoint(model_, None, None)
                for vp_rank in range(virtual_pipeline_model_parallel_size):
                    models[vp_rank][ep_rank][tp_rank] = model_[vp_rank]
                    if self.args.lora_target_modules and from_pretrained:
                        if virtual_pipeline_model_parallel_size > 1:
                            raise AssertionError("Virtual pipeline and LoRA weight merging "
                                                 "are not supported simultaneously")
                        models[vp_rank][ep_rank][tp_rank].merge_and_unload()

        self.module = models

        if pp_stage_cache_flag:
            self.pp_stage_cache.append(models)

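To make the nesting above concrete, here is a minimal sketch of how the cached structure is indexed; the placeholder strings stand in for real Megatron modules and the sizes are illustrative only:

# Placeholders arranged as models[vp_rank][ep_rank][tp_rank], mirroring __get_modules.
vp, ep, tp = 2, 2, 4
models = [[[f"module(vp={v}, ep={e}, tp={t})" for t in range(tp)]
           for e in range(ep)]
          for v in range(vp)]

# get_model_item(vp_rank=..., ep_rank=..., tp_rank=...) walks these ranks in order.
print(models[1][0][3])  # -> "module(vp=1, ep=0, tp=3)"
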
    def check_for_args(self, queue, saver_megatron):
        if saver_megatron:
            return
        check_args_list = {
            'tensor_model_parallel_size': None, 'pipeline_model_parallel_size': None, 'num_layers': None,
            'hidden_size': None, 'seq_length': None, 'num_attention_heads': None, 'max_position_embeddings': None,
            'position_embedding_type': None, 'tokenizer_type': None, 'iteration': 1, 'bert_binary_head': None,
            'disable_bias_linear': False, 'params_dtype': None, 'swiglu': False
        }
        # if hasattr(self.args, 'add_bias_linear'):
        #     check_args_list['disable_bias_linear'] = self.args.add_bias_linear

        def check_for_arg(arg_name, default=None):
            if getattr(self.args, arg_name, None) is None:
                if default is not None:
                    setattr(self.args, arg_name, default)
                elif queue is not None:
                    print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
                    print(f"Arguments: {self.args}")
                    logger.error(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
                    logger.info(f"Arguments: {self.args}")
                    queue.put("exit")
                    exit(1)

@ -465,8 +682,14 @@ class MegatronModel(ModelBase):
            '--no-save-rng',
            '--no-initialization',
            '--save-interval', '1',
            '--load', self.args_cmd.load_dir
            '--mock-data',  # To pass the "blend data checks" in arguments.py
            '--load', self.args_cmd.load_dir,
            '--finetune',
            # '--disable-bias-linear'
        ]

        if hasattr(self.args_cmd, 'add_bias_linear') and not self.args_cmd.add_bias_linear:
            sys_argv.append('--disable-bias-linear')

        if self.args_cmd.use_mcore_models:
            sys_argv.append('--use-mcore-models')
@ -484,9 +707,18 @@ class MegatronModel(ModelBase):
                '--tokenizer-type', str(self.md.tokenizer_type),
                '--tensor-model-parallel-size', str(self.args_cmd.target_tensor_parallel_size),
                '--pipeline-model-parallel-size', str(self.args_cmd.target_pipeline_parallel_size),
                '--expert-model-parallel-size', str(self.args_cmd.target_expert_parallel_size),
                '--save', self.args_cmd.save_dir
            ])

            if self.args_cmd.num_layers_per_virtual_pipeline_stage:
                sys_argv.extend(['--num-layers-per-virtual-pipeline-stage',
                                 str(self.args_cmd.num_layers_per_virtual_pipeline_stage)])

            num_experts = getattr(self.md.checkpoint_args, 'num_experts', None)
            if self.args_cmd.target_tensor_parallel_size > 1 and num_experts is not None and num_experts > 1:
                sys_argv.append('--sequence-parallel')

            if self.md.make_vocab_size_divisible_by is not None:
                sys_argv.extend(['--make-vocab-size-divisible-by', str(self.md.make_vocab_size_divisible_by)])
            if self.md.params_dtype == torch.float16:
@ -504,10 +736,22 @@ class MegatronModel(ModelBase):

        return sys_argv

    def get_model_item(self, **kwargs):
        self.update_kwargs_idx(**kwargs)
        _module = self.module
        for key in self.kwargs_idx:
            if "rank" in key:
                _module = _module[self.kwargs_idx[key]]
        return _module

    @staticmethod
    def set_tensor_model_parallel_world_size(tensor_model_parallel_size):
        mpu.set_tensor_model_parallel_world_size(tensor_model_parallel_size)

    @staticmethod
    def set_expert_model_parallel_world_size(expert_model_parallel_size):
        mpu.set_expert_model_parallel_world_size(expert_model_parallel_size)

    @staticmethod
    def set_pipeline_model_parallel_world_size(pipeline_model_parallel_size):
        mpu.set_pipeline_model_parallel_world_size(pipeline_model_parallel_size)
@ -524,6 +768,18 @@ class MegatronModel(ModelBase):
    def set_pipeline_model_parallel_rank(pipeline_model_parallel_rank):
        mpu.set_pipeline_model_parallel_rank(pipeline_model_parallel_rank)

    @staticmethod
    def set_expert_model_parallel_rank(pipeline_model_parallel_rank):
        mpu.set_expert_model_parallel_rank(pipeline_model_parallel_rank)

    @staticmethod
    def set_virtual_pipeline_model_parallel_rank(pipeline_model_parallel_rank):
        mpu.set_virtual_pipeline_model_parallel_rank(pipeline_model_parallel_rank)

    @staticmethod
    def get_pipeline_model_parallel_rank():
        return mpu.get_pipeline_model_parallel_rank()


class MegatronLegacyModel(MegatronModel):
    def __init__(self, args_cmd, md=None):
@ -535,31 +791,33 @@ class MegatronMCoreModel(MegatronModel):
        super(MegatronMCoreModel, self).__init__(args_cmd, md)

    def get_module_mapping(self):
        module_tp_rank = "module[tp_rank]."
        module_layer = module_tp_rank + "decoder.layers[layer_idx]."
        module_layer = "decoder.layers[layer_idx]."
        self.module_mapping = {
            "embedding": module_tp_rank + "embedding",
            "embedding_word_embeddings": module_tp_rank + "embedding.word_embeddings",
            "embedding_word_embeddings_norm": module_tp_rank + "embedding.word_embeddings.norm",
            "model": "module[tp_rank]",
            "embedding": "embedding",
            "embedding_word_embeddings": "embedding.word_embeddings",
            "embedding_word_embeddings_norm": "embedding.word_embeddings.norm",
            "embedding_position_embeddings": "embedding.position_embeddings",
            "model": "module",
            "layers_input_layernorm": module_layer + "input_layernorm",
            "layers": module_tp_rank + "decoder.layers",
            "layers": "decoder.layers",
            "layers_self_attention_linear_proj": module_layer + "self_attention.linear_proj",
            "layers_self_attention_linear_qkv": module_layer + "self_attention.linear_qkv",
            "layers_self_attention_q_layernorm": module_layer + "self_attention.q_layernorm",
            "layers_self_attention_k_layernorm": module_layer + "self_attention.k_layernorm",
            "layers_self_attention_post_attention_layernorm": module_layer + "pre_mlp_layernorm",
            "layers_self_attention_pre_mlp_layernorm": module_layer + "pre_mlp_layernorm",
            "layers_mlp_linear_fc1": module_layer + "mlp.linear_fc1",
            "layers_mlp_linear_fc2": module_layer + "mlp.linear_fc2",
            "final_layernorm": module_tp_rank + "decoder.final_layernorm",
            "output_layer": module_tp_rank + "output_layer"
            "layers_self_attention_post_mlp_layernorm": module_layer + "post_mlp_layernorm",
            "final_layernorm": "decoder.final_layernorm",
            "output_layer": "output_layer"
        }

    def has_embedding_word_embeddings_norm(self):
        return hasattr(self.get_embedding_word_embeddings_module(), 'norm')

    def has_embedding_position_embeddings(self):
        return hasattr(self.get_embedding_module(), 'position_embeddings')
        config_value = self.model_cfg.get(self.args_cmd.model_type_hf).get('config_set_value')
        if config_value.get('moe_flag', False):
            self.module_mapping["layers_mlp_router"] = module_layer + "mlp.router"
            self.module_mapping["layers_mlp_linear_fc1"] = module_layer + "mlp.experts.local_experts[expert_idx].linear_fc1"
            self.module_mapping["layers_mlp_linear_fc2"] = module_layer + "mlp.experts.local_experts[expert_idx].linear_fc2"

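The mapping values above are attribute-path templates; the bracketed names (layer_idx, expert_idx) are substituted with concrete indices before the path is walked. A minimal sketch of that idea, assuming a simple string-substitution resolver (resolve_mapping is hypothetical and not part of this patch):

# Hypothetical helper, for illustration only: resolve a template such as
# "decoder.layers[layer_idx].mlp.experts.local_experts[expert_idx].linear_fc1"
# against a root module once the indices are known.
def resolve_mapping(root, template, **indices):
    path = template
    for name, value in indices.items():
        path = path.replace(f"[{name}]", f"[{value}]")
    obj = root
    for part in path.split("."):
        if "[" in part:
            attr, idx = part[:-1].split("[")
            obj = getattr(obj, attr)[int(idx)]
        else:
            obj = getattr(obj, part)
    return obj
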
def get_megatron_model(args_cmd, md=None):
@ -571,5 +829,3 @@ def get_megatron_model(args_cmd, md=None):

def get_huggingface_model(args_cmd):
    return HuggingfaceModel(args_cmd)

@ -16,13 +16,15 @@

import os
import sys
import time
import copy
import logging as logger
import torch
from megatron.training.checkpointing import save_checkpoint
from megatron.core import mpu
from models import get_megatron_model

logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)


def add_arguments(parser):
    group = parser.add_argument_group(title='Megatron saver')
@ -30,21 +32,25 @@ def add_arguments(parser):
    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of Megatron repository')

    group.add_argument('--target-tensor-parallel-size', type=int,
    group.add_argument('--target-tensor-parallel-size', type=int, default=1,
                       help='Target tensor model parallel size, defaults to the tensor parallel size '
                            'in the input checkpoint if provided by the loader, otherwise to 1')
    group.add_argument('--target-pipeline-parallel-size', type=int,
    group.add_argument('--target-pipeline-parallel-size', type=int, default=1,
                       help='Target pipeline model parallel size, defaults to the pipeline parallel size '
                            'in the input checkpoint if provided by the loader, otherwise to 1')
    group.add_argument('--save-model-type', type=str, default='megatron',
                       choices=['megatron'], help='Save model type')
                       choices=['megatron', 'huggingface'], help='Save model type')
    group.add_argument("--w-pack", type=bool,
                       help='Set to True if the LLM uses packed w_pack QKV weights',
                       default=False)
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--target-expert-parallel-size', type=int, default=1,
                       help='Target expert model parallel size')
    group.add_argument('--num-layer-list',
                       type=str, help='a list of number of layers, separated by comma; e.g., 4,4,4,4')
    group.add_argument('--use-mcore-models', action='store_true',
                       help='Use the implementation from megatron core')

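A short usage sketch of the saver options above, assuming this module's add_arguments is available in scope; the argv values are illustrative placeholders, not taken from the patch:

# Minimal sketch: exercise the argument group defined in add_arguments above.
import argparse

parser = argparse.ArgumentParser()
add_arguments(parser)
args = parser.parse_args([
    '--target-tensor-parallel-size', '2',
    '--target-pipeline-parallel-size', '2',
    '--target-expert-parallel-size', '1',
    '--save-model-type', 'megatron',      # or 'huggingface' for the mg2hf path
])
print(args.target_tensor_parallel_size)   # -> 2
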
def update_padded_vocab_size(md, model_mg, orig_tensor, orig_word_embed):
@ -55,7 +61,7 @@ def update_padded_vocab_size(md, model_mg, orig_tensor, orig_word_embed):
        padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)
        model_mg.set_padded_vocab_size(padded_vocab_size)
    else:
        print("Original vocab size not specified, leaving embedding table as-is. "
        logger.warning("Original vocab size not specified, leaving embedding table as-is. "
                       "If you've changed the tensor parallel size this could cause problems.")
        model_mg.set_padded_vocab_size(orig_word_embed.shape[0])
    margs = model_mg.get_args()
@ -90,56 +96,254 @@ def reset_cmd_args_from_md(args, md):
        if hasattr(md, 'previous_tensor_parallel_size'):
            args.target_tensor_parallel_size = md.previous_tensor_parallel_size
        else:
            print(
                "loader did not provide a tensor parallel size and --target-tensor-parallel-size not provided on command line. "
                "Default to 1.")
            logger.warning("loader did not provide a tensor parallel size and "
                           "--target-tensor-parallel-size not provided on command line. Default to 1.")
            args.target_tensor_parallel_size = 1

    if args.target_pipeline_parallel_size is None:
        if hasattr(md, 'previous_pipeline_parallel_size'):
            args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
        else:
            print(
                "loader did not provide a pipeline parallel size and --target-pipeline-parallel-size not provided on command line. "
                "Default to 1.")
            logger.warning(
                "loader did not provide a pipeline parallel size and "
                "--target-pipeline-parallel-size not provided on command line. Default to 1.")
            args.target_pipeline_parallel_size = 1

def set_model_preprocess(model, embeddings_msg, check_message):
def set_model_preprocess(model, embeddings_msg):
    md = model.get_metadata()
    margs = model.get_args()
    pos_embed = None
    tp_size = margs.tensor_model_parallel_size
    ep_size = margs.expert_model_parallel_size
    if md.position_embedding_type == 'learned_absolute':
        pos_embed = embeddings_msg.pop("position embeddings")
    orig_word_embed = embeddings_msg.pop("word embeddings")
        pos_embed = embeddings_msg.pop(f"position embeddings")
    orig_word_embed = embeddings_msg.pop(f"word embeddings")
    orig_word_embed_n_w, orig_word_embed_n_b = None, None
    if "word embeddings norm_w" in embeddings_msg and "word embeddings norm_b" in embeddings_msg:
        orig_word_embed_n_w = embeddings_msg.pop("word embeddings norm_w")
        orig_word_embed_n_b = embeddings_msg.pop("word embeddings norm_b")
    check_message(embeddings_msg)

    if md.true_vocab_size is not None:
        orig_vocab_size = orig_word_embed.shape[0]
        full_word_embed = vocab_padding(orig_vocab_size, margs.padded_vocab_size, orig_word_embed)
    else:
        full_word_embed = orig_word_embed

    # Split into new tensor model parallel sizes tensor_model_parallel_size
    out_word_embed = torch.chunk(full_word_embed, margs.tensor_model_parallel_size, dim=0)

    modules_count = model.get_modules_count()
    for tp_rank in range(modules_count):
        model.set_embedding_word_embeddings_weight(tp_rank=tp_rank, data=out_word_embed[tp_rank])
        if orig_word_embed_n_w is not None:
            model.set_embedding_word_embeddings_norm_weight(tp_rank=tp_rank, data=orig_word_embed_n_w)
            model.set_embedding_word_embeddings_norm_bias(tp_rank=tp_rank, data=orig_word_embed_n_b)
        if pos_embed is not None:
            model.set_embedding_position_embeddings_weight(tp_rank=tp_rank, data=pos_embed)
    if "word embeddings norm_w" in embeddings_msg:
        orig_word_embed_n_w = embeddings_msg.pop(f"word embeddings norm_w")
    if "word embeddings norm_b" in embeddings_msg:
        orig_word_embed_n_b = embeddings_msg.pop(f"word embeddings norm_b")
    out_word_embed_list = []
    for ep_rank in range(ep_size):
        if md.true_vocab_size is not None:
            orig_vocab_size = orig_word_embed.shape[0]
            full_word_embed = vocab_padding(orig_vocab_size, margs.padded_vocab_size, orig_word_embed)
        else:
            if hasattr(model.get_embedding_module(), 'position_embeddings'):
                raise ValueError("model should have position_embeddings")
            full_word_embed = orig_word_embed

    return out_word_embed
        # Split into new tensor model parallel sizes tensor_model_parallel_size
        out_word_embed = torch.chunk(full_word_embed, margs.tensor_model_parallel_size, dim=0)
        for tp_rank in range(tp_size):
            model.set_embedding_word_embeddings_weight(ep_rank=ep_rank, tp_rank=tp_rank, data=out_word_embed[tp_rank])
            if orig_word_embed_n_w is not None:
                model.set_embedding_word_embeddings_norm_weight(ep_rank=ep_rank, tp_rank=tp_rank, data=orig_word_embed_n_w)
            if orig_word_embed_n_b is not None:
                model.set_embedding_word_embeddings_norm_bias(ep_rank=ep_rank, tp_rank=tp_rank, data=orig_word_embed_n_b)
            if pos_embed is not None:
                model.set_embedding_position_embeddings_weight(ep_rank=ep_rank, tp_rank=tp_rank, data=pos_embed)
            else:
                if hasattr(model.get_embedding_module(), 'position_embeddings'):
                    raise ValueError("model should have position_embeddings")

        out_word_embed_list.append(out_word_embed)

    return out_word_embed_list

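The preprocess step above pads the word-embedding table to the padded vocab size and then slices it row-wise across tensor-parallel ranks. A minimal sketch with toy shapes; the padding here simply appends zero rows, which may differ from the project's vocab_padding helper:

import torch

# Toy embedding: vocab of 10 tokens, hidden size 4; pad to 12 rows, then split over TP=2.
orig_word_embed = torch.randn(10, 4)
padded_vocab_size, tp_size = 12, 2

pad_rows = padded_vocab_size - orig_word_embed.shape[0]
full_word_embed = torch.cat([orig_word_embed, torch.zeros(pad_rows, 4)], dim=0)

out_word_embed = torch.chunk(full_word_embed, tp_size, dim=0)
assert [t.shape for t in out_word_embed] == [torch.Size([6, 4])] * tp_size
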
def set_model_layer_norm(model_mg, msg, md, **kwargs):
    # duplicated tensors
    input_norm_weight = msg.pop("input norm weight")
    post_norm_weight = msg.pop("post norm weight")
    input_norm_bias = None
    post_norm_bias = None
    if md.norm_has_bias:
        input_norm_bias = msg.pop("input norm bias")
    if md.norm_has_bias:
        post_norm_bias = msg.pop("post norm bias")

    margs = model_mg.get_args()

    # Save them to the model
    for ep_rank in range(margs.expert_model_parallel_size):
        kwargs["ep_rank"] = ep_rank
        for tp_rank in range(margs.tensor_model_parallel_size):
            kwargs["tp_rank"] = tp_rank

            model_mg.set_layers_input_layernorm_weight(**kwargs, data=input_norm_weight)
            if input_norm_bias:
                model_mg.set_layers_input_layernorm_bias(**kwargs, data=input_norm_bias)
            model_mg.set_layers_self_attention_pre_mlp_layernorm_weight(**kwargs, data=post_norm_weight)
            if post_norm_bias:
                model_mg.set_layers_self_attention_pre_mlp_layernorm_bias(**kwargs, data=post_norm_bias)


def set_model_layer_attn(model_mg, msg, md, **kwargs):
    # duplicated tensors
    margs = model_mg.get_args()
    if md.linear_bias or margs.add_dense_bias:
        dense_bias = msg.pop("dense bias")
    if md.linear_bias or margs.add_qkv_bias:
        qkv_bias = torch.chunk(msg.pop("qkv bias"), margs.tensor_model_parallel_size, dim=0)

    qkv_org = msg.pop("qkv weight")
    qkv_weight = torch.chunk(qkv_org, margs.tensor_model_parallel_size, dim=0)

    # Split up the parallel tensors
    dense_weight = torch.chunk(msg.pop("dense weight"), margs.tensor_model_parallel_size, dim=1)

    # Save them to the model
    for ep_rank in range(margs.expert_model_parallel_size):
        kwargs["ep_rank"] = ep_rank
        for tp_rank in range(margs.tensor_model_parallel_size):
            kwargs["tp_rank"] = tp_rank
            model_mg.set_layers_self_attention_linear_qkv_weight(**kwargs, data=qkv_weight[tp_rank])
            model_mg.set_layers_self_attention_linear_proj_weight(**kwargs, data=dense_weight[tp_rank])

            if md.linear_bias:
                model_mg.set_layers_self_attention_linear_qkv_bias(**kwargs, data=qkv_bias[tp_rank])
                model_mg.set_layers_self_attention_linear_proj_bias(**kwargs, data=dense_bias)

            if margs.add_qkv_bias:
                model_mg.set_layers_self_attention_linear_qkv_bias(**kwargs, data=qkv_bias[tp_rank])
            if margs.add_dense_bias:
                model_mg.set_layers_self_attention_linear_proj_bias(**kwargs, data=dense_bias)

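Note the split axes above: the QKV weight is column-parallel and is chunked along dim 0, while the attention projection is row-parallel and is chunked along dim 1. A minimal sketch with toy shapes:

import torch

tp_size, hidden = 2, 8
qkv_weight = torch.randn(3 * hidden, hidden)    # column-parallel: split output rows
dense_weight = torch.randn(hidden, hidden)      # row-parallel: split input columns

qkv_shards = torch.chunk(qkv_weight, tp_size, dim=0)
proj_shards = torch.chunk(dense_weight, tp_size, dim=1)

assert qkv_shards[0].shape == (3 * hidden // tp_size, hidden)
assert proj_shards[0].shape == (hidden, hidden // tp_size)
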
def _set_set_model_layer_mlp(model_mg, msg, md, **kwargs):
    margs = model_mg.get_args()
    num_experts_local = 1
    if margs.num_experts:
        num_experts_local = margs.num_experts // margs.expert_model_parallel_size
    # Save them to the model

    if md.linear_bias:
        mlp_l1_bias = msg.pop(f"mlp l1 bias")
    # Split up the parallel tensors
    mlp_l1_weight = torch.chunk(msg.pop(f"mlp l1 weight"), margs.tensor_model_parallel_size, dim=1)

    # Special handling for swiglu
    if md.swiglu:
        mlp_l0_weight_W = torch.chunk(msg.pop(f"mlp l0 weight W"), margs.tensor_model_parallel_size, dim=0)
        mlp_l0_weight_V = torch.chunk(msg.pop(f"mlp l0 weight V"), margs.tensor_model_parallel_size, dim=0)
        mlp_l0_weight = [torch.cat(weights, dim=0) for weights in zip(mlp_l0_weight_W, mlp_l0_weight_V)]
    else:
        mlp_l0_weight = torch.chunk(msg.pop(f"mlp l0 weight"), margs.tensor_model_parallel_size, dim=0)
    if md.linear_bias:
        if md.swiglu:
            mlp_l0_bias_W = torch.chunk(msg.pop(f"mlp l0 bias W"), margs.tensor_model_parallel_size, dim=0)
            mlp_l0_bias_V = torch.chunk(msg.pop(f"mlp l0 bias V"), margs.tensor_model_parallel_size, dim=0)
            mlp_l0_bias = [torch.cat(bias, dim=0) for bias in zip(mlp_l0_bias_W, mlp_l0_bias_V)]
        else:
            mlp_l0_bias = torch.chunk(msg.pop(f"mlp l0 bias"), margs.tensor_model_parallel_size, dim=0)

    # duplicated tensors
    for tp_rank in range(margs.tensor_model_parallel_size):
        kwargs["tp_rank"] = tp_rank
        model_mg.set_layers_mlp_linear_fc1_weight(**kwargs, data=mlp_l0_weight[tp_rank])
        model_mg.set_layers_mlp_linear_fc2_weight(**kwargs, data=mlp_l1_weight[tp_rank])

        if md.linear_bias:
            model_mg.set_layers_mlp_linear_fc1_bias(**kwargs, data=mlp_l0_bias[tp_rank])
            model_mg.set_layers_mlp_linear_fc2_bias(**kwargs, data=mlp_l1_bias)


def set_model_layer_mlp(model_mg, msg, md, **kwargs):
    margs = model_mg.get_args()
    if margs.num_experts:
        num_experts_local = margs.num_experts // margs.expert_model_parallel_size
        mlp_moe = msg.pop("mlp_moe")
        mlp_router_weight = mlp_moe.pop("mlp router weight")
        for ep_rank in range(margs.expert_model_parallel_size):
            kwargs["ep_rank"] = ep_rank
            for tp_rank in range(margs.tensor_model_parallel_size):
                kwargs['tp_rank'] = tp_rank
                model_mg.set_layers_mlp_router_weight(**kwargs, data=mlp_router_weight)
            for expert_idx in range(num_experts_local):
                kwargs["expert_idx"] = expert_idx
                global_expert_idx = expert_idx + ep_rank * num_experts_local
                expert = mlp_moe.pop(f"expert {global_expert_idx}")
                _set_set_model_layer_mlp(model_mg, expert, md, **kwargs)
    else:
        _set_set_model_layer_mlp(model_mg, msg, md, **kwargs)

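The MoE branch above assigns each expert-parallel rank a contiguous block of experts via global_expert_idx = expert_idx + ep_rank * num_experts_local. A small worked example with illustrative numbers only:

# Illustrative: 8 experts split over expert-parallel size 4 -> 2 local experts per rank.
num_experts, ep_size = 8, 4
num_experts_local = num_experts // ep_size

layout = {ep_rank: [expert_idx + ep_rank * num_experts_local
                    for expert_idx in range(num_experts_local)]
          for ep_rank in range(ep_size)}
print(layout)  # {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}
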
def set_model_postprocess(model_mg, msg, md, out_word_embed_list, **kwargs):
    margs = model_mg.get_args()
    tp_size = margs.tensor_model_parallel_size
    ep_size = margs.expert_model_parallel_size
    final_norm_weight = msg.pop(f"weight")
    final_norm_bias = None
    if md.norm_has_bias:
        final_norm_bias = msg.pop(f"bias")
    for ep_rank in range(ep_size):
        kwargs["ep_rank"] = ep_rank
        for tp_rank in range(tp_size):
            kwargs["tp_rank"] = tp_rank
            model_mg.set_final_layernorm_weight(**kwargs, data=final_norm_weight)
            if final_norm_bias:
                model_mg.set_final_layernorm_bias(**kwargs, data=final_norm_bias)
            if kwargs.get("pp_rank", 0) and not md.output_layer:
                # Copy word embeddings to final pipeline rank
                model_mg.set_output_layer_weight(**kwargs, data=out_word_embed_list[ep_rank][tp_rank])
    del final_norm_weight
    if final_norm_bias:
        del final_norm_bias


def set_model_output_layer(model_mg, msg, md, **kwargs):
    margs = model_mg.get_args()
    tp_size = margs.tensor_model_parallel_size
    ep_size = margs.expert_model_parallel_size
    output_layer = msg.pop(f"weight")
    for ep_rank in range(ep_size):
        kwargs["ep_rank"] = ep_rank
        if md.true_vocab_size is not None:
            orig_vocab_size = output_layer.shape[0]
            full_word_embed = vocab_padding(orig_vocab_size, margs.padded_vocab_size, output_layer)
        else:
            full_word_embed = output_layer
        output_layer_weight = torch.chunk(full_word_embed, margs.tensor_model_parallel_size, dim=0)
        for tp_rank in range(tp_size):
            kwargs["tp_rank"] = tp_rank
            model_mg.set_output_layer_weight(**kwargs, data=output_layer_weight[tp_rank])


def save_model(model_mg, md, **kwargs):
    margs = model_mg.get_args()
    args_cmd = model_mg.get_args_cmd()
    virtual_pipeline_model_parallel_size = margs.virtual_pipeline_model_parallel_size
    if virtual_pipeline_model_parallel_size is None:
        virtual_pipeline_model_parallel_size = 1
    for ep_rank in range(margs.expert_model_parallel_size):
        model_mg.set_expert_model_parallel_rank(ep_rank)
        kwargs["ep_rank"] = ep_rank
        for tp_rank in range(margs.tensor_model_parallel_size):
            model_mg.set_tensor_model_parallel_rank(tp_rank)
            kwargs["tp_rank"] = tp_rank
            vp_models = []
            for vp_rank in range(virtual_pipeline_model_parallel_size):
                kwargs["vp_rank"] = vp_rank
                vp_models.append(model_mg.get_model_item(**kwargs))
            if args_cmd.save_model_type == 'megatron':
                # Split the PP into multiple VPPs and select the corresponding layers for each VPP by copying and deleting
                save_checkpoint(md.iteration, vp_models, None, None, 0)
            elif args_cmd.save_model_type == "huggingface":
                save_huggingface(args_cmd, model_mg)


def save_huggingface(args, model):
    '''Set model params.'''
    from models import get_huggingface_model
    model_hf = get_huggingface_model(args)
    model_hf.get_modules_from_pretrained()
    args_cmd = model_hf.get_args_cmd()

    model_hf.update_module(model)

    save_dir = os.path.join(args_cmd.save_dir, 'mg2hf')
    logger.info(f'save weight to {save_dir}')
    model_hf.get_model_item().save_pretrained(save_dir)

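queue_get and check_message below assume a simple loader-to-saver protocol: each payload is a dict carrying a "name" key plus named tensors, and the bare string "exit" signals a loader failure. A minimal sketch of a message the saver would accept (tensor shapes and values are placeholders):

# Hypothetical message, shaped like what queue_get()/check_message() expect.
import torch

embeddings_msg = {
    "name": "embeddings",                       # checked against the expected name
    "word embeddings": torch.randn(12, 4),      # consumed via msg.pop(...)
    "position embeddings": torch.randn(16, 4),  # only for learned_absolute embeddings
}
# After the expected keys are popped, check_message() requires the dict (minus "name")
# to be empty; otherwise it exits unless --no-checking was passed.
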
def save_model_checkpoint(queue, args):
@ -154,14 +358,14 @@ def save_model_checkpoint(queue, args):
    def queue_get(name=None):
        val = queue.get()
        if val == "exit":
            print("Loader exited, exiting saver")
            logger.error("Loader exited, exiting saver")
            exit(1)
        if name is not None and args.checking and val["name"] != name:
            val_name = val["name"]
            print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
            logger.error(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
            exit(1)
        if name is not None:
            print(f"received {name}")
            logger.info(f"received {name}")
        return val

    def check_message(msg):
@ -169,10 +373,10 @@ def save_model_checkpoint(queue, args):
            return
        msg_name = msg.pop("name")
        if len(msg.keys()) > 0:
            print(f"Unexpected values in {msg_name}:")
            logger.error(f"Unexpected values in {msg_name}:")
            for key in msg.keys():
                print(f"   {key}")
            print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
                logger.error(f"   {key}")
            logger.error(f"Exiting. If you want to ignore this, use the argument --no-checking.")
            exit(1)

    md = queue_get()
@ -185,182 +389,59 @@ def save_model_checkpoint(queue, args):

    # We want all arguments to come from us
    model_mg = get_megatron_model(args_cmd=args, md=md)
    model_mg.initialize_megatron_args(queue=queue)
    model_mg.initialize_megatron_args(queue=queue, saver_megatron=True)

    # Make models for first pipeline stage and fill in embeddings
    mpu.set_pipeline_model_parallel_rank(0)
    post_process = args.target_pipeline_parallel_size == 1
    model_mg.get_modules_from_config(args.target_tensor_parallel_size, pre_process=True, post_process=post_process)
    model_mg.get_modules_from_config(pp_stage_cache_flag=True)

    # Embeddings
    embeddings_msg = queue_get("embeddings")
    out_word_embed = set_model_preprocess(model_mg, embeddings_msg, check_message)

    out_word_embed_list = set_model_preprocess(model_mg, embeddings_msg)
    check_message(embeddings_msg)
    margs = model_mg.get_args()

    # Transformer layers
    # -------------------
    total_layer_num = 0
    lst = []
    if args.num_layers_per_virtual_pipeline_stage and args.save_model_type == 'megatron':
        times = 3
        while queue.qsize() > 3 or times >= 0:
            if times >= 0:
                time.sleep(1)
                times -= 1
                continue
            lst.append(queue.get())
    for pp_rank in range(args.target_pipeline_parallel_size):
        # For later pipeline parallel ranks, make the new models
        if pp_rank > 0:

    virtual_pipeline_model_parallel_size = margs.virtual_pipeline_model_parallel_size
    if virtual_pipeline_model_parallel_size is None:
        virtual_pipeline_model_parallel_size = 1

    for vp_rank in range(virtual_pipeline_model_parallel_size):
        model_mg.set_virtual_pipeline_model_parallel_rank(vp_rank)
        kwargs = {"vp_rank": vp_rank}
        for pp_rank in range(args.target_pipeline_parallel_size):
            # For later pipeline parallel ranks, make the new models
            mpu.set_pipeline_model_parallel_rank(pp_rank)
            post_process = pp_rank == args.target_pipeline_parallel_size - 1
            model_mg.get_modules_from_config(args.target_tensor_parallel_size, False, post_process)

            if args.num_layers_per_virtual_pipeline_stage and args.save_model_type == 'megatron':
                vp_size = margs.num_layers // args.target_pipeline_parallel_size // args.num_layers_per_virtual_pipeline_stage
            else:
                vp_size = 1
            for vpp_rank in range(vp_size):
                for layer in range(len(model_mg.get_layers_module()) // vp_size):
                    if args.num_layers_per_virtual_pipeline_stage and args.save_model_type == 'megatron':
                        # The execution order between layers in the VPP model is different from that in the PP model. Here,
                        # it is necessary to calculate the index and arrange the layers in the actual execution order
                        total_layer_num = args.target_pipeline_parallel_size * vpp_rank * args.num_layers_per_virtual_pipeline_stage + pp_rank * args.num_layers_per_virtual_pipeline_stage + layer
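                        # Worked example (illustrative numbers only, not from the patch):
                        # with target_pipeline_parallel_size=2, num_layers_per_virtual_pipeline_stage=2,
                        # pp_rank=1, vpp_rank=1, layer=0 the formula gives
                        # total_layer_num = 2*1*2 + 1*2 + 0 = 6, i.e. the 7th layer in execution order,
                        # which is the message index to read from the buffered queue list.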
                        msg = lst[total_layer_num]
                    else:
                        msg = queue_get(f"transformer layer {total_layer_num}")

                    # duplicated tensors
                    input_norm_weight = msg.pop("input norm weight")
                    if md.norm_has_bias:
                        input_norm_bias = msg.pop("input norm bias")
                    post_norm_weight = msg.pop("post norm weight")
                    if md.norm_has_bias:
                        post_norm_bias = msg.pop("post norm bias")
                    if md.linear_bias:
                        dense_bias = msg.pop("dense bias")
                        mlp_l1_bias = msg.pop("mlp l1 bias")

                    if args.add_qkv_bias:
                        qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
                    if args.add_dense_bias:
                        dense_bias = msg.pop("dense bias")

                    qkv_org = msg.pop("qkv weight")
                    qkv_weight = torch.chunk(qkv_org, args.target_tensor_parallel_size, dim=0)

                    # Split up the parallel tensors
                    dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
                    mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)

                    # Special handling for swiglu
                    if md.swiglu:
                        mlp_l0_weight_W = torch.chunk(msg.pop("mlp l0 weight W"), args.target_tensor_parallel_size, dim=0)
                        mlp_l0_weight_V = torch.chunk(msg.pop("mlp l0 weight V"), args.target_tensor_parallel_size, dim=0)
                        mlp_l0_weight = [torch.cat(weights, dim=0) for weights in zip(mlp_l0_weight_W, mlp_l0_weight_V)]
                    else:
                        mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)

                    if md.linear_bias:
                        qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
                        if md.swiglu:
                            mlp_l0_bias_W = torch.chunk(msg.pop("mlp l0 bias W"), args.target_tensor_parallel_size, dim=0)
                            mlp_l0_bias_V = torch.chunk(msg.pop("mlp l0 bias V"), args.target_tensor_parallel_size, dim=0)
                            mlp_l0_bias = [torch.cat(bias, dim=0) for bias in zip(mlp_l0_bias_W, mlp_l0_bias_V)]
                        else:
                            mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)

                    # Save them to the model
                    for tp_rank in range(args.target_tensor_parallel_size):
                        if args.num_layers_per_virtual_pipeline_stage and args.save_model_type == 'megatron':
                            l_idx = vpp_rank * args.num_layers_per_virtual_pipeline_stage + layer
                        else:
                            l_idx = layer

                        model_mg.set_layers_input_layernorm_weight(tp_rank=tp_rank, layer_idx=l_idx, data=input_norm_weight)
                        if md.norm_has_bias:
                            model_mg.set_layers_input_layernorm_bias(tp_rank=tp_rank, layer_idx=l_idx, data=input_norm_bias)
                        model_mg.set_layers_self_attention_linear_qkv_weight(tp_rank=tp_rank, layer_idx=l_idx, data=qkv_weight[tp_rank])
                        model_mg.set_layers_self_attention_linear_proj_weight(tp_rank=tp_rank, layer_idx=l_idx, data=dense_weight[tp_rank])
                        model_mg.set_layers_self_attention_pre_mlp_layernorm_weight(tp_rank=tp_rank, layer_idx=l_idx, data=post_norm_weight)
                        if md.norm_has_bias:
                            model_mg.set_layers_self_attention_pre_mlp_layernorm_bias(tp_rank=tp_rank, layer_idx=l_idx, data=post_norm_bias)

                        model_mg.set_layers_mlp_linear_fc1_weight(tp_rank=tp_rank, layer_idx=l_idx, data=mlp_l0_weight[tp_rank])
                        model_mg.set_layers_mlp_linear_fc2_weight(tp_rank=tp_rank, layer_idx=l_idx, data=mlp_l1_weight[tp_rank])

                        if md.linear_bias:
                            model_mg.set_layers_self_attention_linear_qkv_bias(tp_rank=tp_rank, layer_idx=l_idx, data=qkv_bias[tp_rank])
                            model_mg.set_layers_self_attention_linear_proj_bias(tp_rank=tp_rank, layer_idx=l_idx, data=dense_bias)
                            model_mg.set_layers_mlp_linear_fc1_bias(tp_rank=tp_rank, layer_idx=l_idx, data=mlp_l0_bias[tp_rank])
                            model_mg.set_layers_mlp_linear_fc2_bias(tp_rank=tp_rank, layer_idx=l_idx, data=mlp_l1_bias)

                        if args.add_qkv_bias:
                            model_mg.set_layers_self_attention_linear_qkv_bias(tp_rank=tp_rank, layer_idx=l_idx, data=qkv_bias[tp_rank])
                        if args.add_dense_bias:
                            model_mg.set_layers_self_attention_linear_proj_bias(tp_rank=tp_rank, layer_idx=l_idx, data=dense_bias)
            model_mg.get_modules_from_config(pp_stage_cache_flag=True)
            kwargs["pp_rank"] = pp_rank
            for layer in range(len(model_mg.get_layers_module())):
                kwargs["layer_idx"] = layer
                msg = queue_get(f"transformer layer {total_layer_num}")
                set_model_layer_norm(model_mg, msg, md, **kwargs)
                set_model_layer_attn(model_mg, msg, md, **kwargs)
                set_model_layer_mlp(model_mg, msg, md, **kwargs)

                total_layer_num = total_layer_num + 1
                check_message(msg)

            if post_process:
                msg = queue_get("final norm")
                final_norm_weight = msg.pop("weight")
                if md.norm_has_bias:
                    final_norm_bias = msg.pop("bias")
                for tp_rank in range(args.target_tensor_parallel_size):
                    model_mg.set_final_layernorm_weight(tp_rank=tp_rank, data=final_norm_weight)
                    if md.norm_has_bias:
                        model_mg.set_final_layernorm_bias(tp_rank=tp_rank, data=final_norm_bias)
                    if pp_rank != 0 and not md.output_layer:
                        # Copy word embeddings to final pipeline rank
                        model_mg.set_output_layer_weight(tp_rank=tp_rank, data=out_word_embed[tp_rank])
                del final_norm_weight
                if md.norm_has_bias:
                    del final_norm_bias
                check_message(msg)

                if md.output_layer:
                    msg = queue_get("output layer")
                    output_layer = msg.pop("weight")
                    if md.true_vocab_size is not None:
                        orig_vocab_size = output_layer.shape[0]
                        full_word_embed = vocab_padding(orig_vocab_size, margs.padded_vocab_size, output_layer)
                    else:
                        full_word_embed = output_layer

                    output_layer_weight = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)

                    for tp_rank in range(args.target_tensor_parallel_size):
                        model_mg.set_output_layer_weight(tp_rank=tp_rank, data=output_layer_weight[tp_rank])
                    del output_layer_weight
            post_process = (
                (pp_rank == args.target_pipeline_parallel_size - 1) &
                (vp_rank == virtual_pipeline_model_parallel_size - 1)
            )
            if post_process:
                msg = queue_get("final norm")
                set_model_postprocess(model_mg, msg, md, out_word_embed_list, **kwargs)
                check_message(msg)
            msg = queue_get()

            for tp_rank in range(args.target_tensor_parallel_size):
                mpu.set_tensor_model_parallel_rank(tp_rank)
                # Split the PP into multiple VPPs and select the corresponding layers for each VPP by copying and deleting
                if args.num_layers_per_virtual_pipeline_stage:
                    vp_models = []
                    layers = margs.num_layers // args.target_pipeline_parallel_size
                    for vp_rank in range(vp_size):
                        model = copy.deepcopy(model_mg.get_model_module(tp_rank=tp_rank))
                        left = vp_rank * args.num_layers_per_virtual_pipeline_stage
                        right = (vp_rank + 1) * args.num_layers_per_virtual_pipeline_stage
                        for i in range(layers - 1, -1, -1):
                            if i >= right or i < left:
                                del model.decoder.layers[i]
                        if right < layers and pp_rank == args.target_pipeline_parallel_size - 1:
                            del model.decoder.final_layernorm
                            if getattr(model, "output_layer", None):
                                model.post_process = False
                                del model.output_layer
                        if pp_rank == 0 and vp_rank > 0:
                            model.pre_process = False
                            del model.embedding
                        vp_models.append(model)
                    save_checkpoint(md.iteration, vp_models, None, None, 0)
                else:
                    save_checkpoint(md.iteration, [model_mg.get_model_module(tp_rank=tp_rank)], None, None, 0)
    print("Done!")
            if md.output_layer:
                msg = queue_get("output layer")
                set_model_output_layer(model_mg, msg, md, **kwargs)
                check_message(msg)

            if vp_rank == virtual_pipeline_model_parallel_size - 1:
                save_model(model_mg, md, **kwargs)
    logger.info("Done!")
