!1548 [mcore framework: add mg2mg, mg2hf, expert-parallel (EP) and related features]

Merge pull request !1548 from glhyy/master
glhyy 2024-08-23 07:44:29 +00:00 committed by i-robot
parent a08bb1cd12
commit c5bd9b2b58
8 changed files with 1103 additions and 389 deletions

View File

@ -17,6 +17,8 @@ import os
from functools import wraps
from megatron.training import get_args
from megatron.training.utils import print_rank_0
from megatron.training.checkpointing import _load_base_checkpoint
from .tasks.finetune.lora.utils import is_enable_lora, merge_dicts, modify_keys_with_dict
@ -54,3 +56,44 @@ def load_checkpoint_wrapper(fn):
return fn(*args, **kwargs)
return wrapper
def load_args_from_checkpoint_wrapper(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
res = fn(*args, **kwargs)
if len(res) == 1:
return res
args, checkpoint_args = res
def _set_arg(arg_name, old_arg_name=None, force=False):
if not force and getattr(args, arg_name, None) is not None:
return
if old_arg_name is not None:
checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
else:
checkpoint_value = getattr(checkpoint_args, arg_name, None)
if checkpoint_value is not None:
print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
setattr(args, arg_name, checkpoint_value)
else:
print_rank_0(f"Checkpoint did not provide arguments {arg_name}")
_set_arg('num_layer_list', force=True)
_set_arg('post_norm', force=True)
_set_arg('num_experts')
_set_arg('sequence_parallel', force=True)
state_dict, checkpoint_name, release = _load_base_checkpoint(
getattr(args, kwargs.get('load_arg', 'load')),
rank0=True,
exit_on_missing_checkpoint=kwargs.get('exit_on_missing_checkpoint', False),
checkpoint_step=args.ckpt_step
)
checkpoint_version = state_dict.get('checkpoint_version', 0)
if checkpoint_version >= 3.0:
_set_arg('expert_model_parallel_size', force=True)
return args, checkpoint_args
return wrapper
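As a reading aid, here is a minimal, self-contained sketch of the force/non-force behaviour implemented by _set_arg above, using types.SimpleNamespace stand-ins for the Megatron args objects (the names args/checkpoint_args mirror the wrapper; nothing here touches a real checkpoint):

# Illustrative only: mimics how _set_arg copies values from checkpoint_args into args.
from types import SimpleNamespace

args = SimpleNamespace(num_experts=None, sequence_parallel=False)
checkpoint_args = SimpleNamespace(num_experts=8, sequence_parallel=True)

def _set_arg(arg_name, force=False):
    # Without force, an already-set value on args wins; with force, the checkpoint wins.
    if not force and getattr(args, arg_name, None) is not None:
        return
    value = getattr(checkpoint_args, arg_name, None)
    if value is not None:
        setattr(args, arg_name, value)

_set_arg('num_experts')                    # copied: args.num_experts was None
_set_arg('sequence_parallel')              # kept: False is not None, so not overwritten
_set_arg('sequence_parallel', force=True)  # overwritten from the checkpoint
print(args)  # namespace(num_experts=8, sequence_parallel=True)

This is also why sequence_parallel and the layer-layout arguments are set with force=True: their command-line defaults would otherwise shadow the checkpoint values.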

View File

@ -45,7 +45,7 @@ from ..core.pipeline_parallel.p2p_communication import _batched_p2p_ops
from ..data import build_pretraining_data_loader
from ..tokenizer import build_tokenizer
from ..arguments import parse_args_decorator
from ..checkpointing import _load_base_checkpoint_wrapper, load_checkpoint_wrapper
from ..checkpointing import _load_base_checkpoint_wrapper, load_checkpoint_wrapper, load_args_from_checkpoint_wrapper
from ..initialize import initialize_megatron
from ..utils import emit
from ..arguments import process_args
@ -290,6 +290,7 @@ def patch_model():
# patch language model
PatchManager.register_patch('megatron.legacy.model.language_model.TransformerLanguageModel.forward', transformer_language_model_forward_wrapper)
PatchManager.register_patch('megatron.legacy.model.language_model.TransformerLanguageModel.__init__', transformer_language_model_init)
PatchManager.register_patch('megatron.training.checkpointing.load_args_from_checkpoint', load_args_from_checkpoint_wrapper)
def patch_initialize():

View File

@ -62,6 +62,8 @@ def main():
parser.add_argument('--no-checking', action='store_false',
help='Do not perform checking on the name and ordering of weights',
dest='checking')
parser.add_argument('--model-type-hf', type=str, default="llama2",
choices=['llama2', 'mixtral', 'chatglm3'], help='HuggingFace model type of the checkpoint')
known_args, _ = parser.parse_known_args()
loader = load_plugin('loader', known_args.loader)
saver = load_plugin('saver', known_args.saver)
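A small standalone sketch (hypothetical flag defaults, not the real convert_ckpt entry point) of the parse_known_args pattern used above: the front-end reads the new --model-type-hf choice and leaves the remaining flags for the loader/saver plugins to parse:

# Illustrative only: front-end parsing before plugin dispatch.
import argparse

parser = argparse.ArgumentParser(description='sketch of the converter front-end')
parser.add_argument('--loader', type=str, default='megatron')
parser.add_argument('--saver', type=str, default='megatron')
parser.add_argument('--model-type-hf', type=str, default='llama2',
                    choices=['llama2', 'mixtral', 'chatglm3'])

known_args, remaining = parser.parse_known_args(
    ['--loader', 'hf', '--model-type-hf', 'mixtral', '--target-tensor-parallel-size', '2'])
print(known_args.model_type_hf)  # 'mixtral'
print(remaining)                 # ['--target-tensor-parallel-size', '2']  (left for the plugins)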

View File

@ -16,11 +16,15 @@
import os
import sys
import types
import logging as logger
import torch
import transformers
from models import get_megatron_model
from models import get_huggingface_model
logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)
def add_arguments(parser):
group = parser.add_argument_group(title='Llama-2 HF loader.')
@ -51,8 +55,6 @@ def add_arguments(parser):
'This is added for computational efficiency reasons.')
group.add_argument('--use-mcore-models', action='store_true',
help='Use the implementation from megatron core')
group.add_argument('--model-type-hf', type=str,
help='huggingface model type e.g., llama2, qwen, ...')
def verify_transformers_version():
@ -87,6 +89,7 @@ def build_metadata(args, margs):
md.consumed_train_samples = 0
md.consumed_valid_samples = 0
md.embed_layernorm = margs.embed_layernorm
md.disable_bias_linear = margs.disable_bias_linear
return md
@ -98,14 +101,14 @@ def get_message_preprocess(model, md):
}
# bloom
if model.has_embedding_word_embeddings_norm():
if model.has_embedding_word_embeddings_norm_module():
message["word embeddings norm_w"] = model.get_embedding_word_embeddings_norm_weight()
message["word embeddings norm_b"] = model.get_embedding_word_embeddings_norm_bias()
if md.position_embedding_type == 'learned_absolute':
message["position embeddings"] = model.get_embedding_position_embeddings_weight()
else:
if model.has_embedding_position_embeddings():
if model.has_embedding_position_embeddings_module():
raise ValueError("model should have position_embeddings")
return message
@ -149,15 +152,15 @@ def get_message_layer_attn(message, model, layer_idx, md=None, args=None):
return message
def get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1):
def _get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1, **kwargs):
# Grab all parallel tensors for this layer.
mlp_l0_weight = []
mlp_l0_bias = []
mlp_l1_weight = []
mlp_l0_weight.append(model.get_layers_mlp_linear_fc1_weight(layer_idx=layer_idx))
mlp_l1_weight.append(model.get_layers_mlp_linear_fc2_weight(layer_idx=layer_idx))
mlp_l0_weight.append(model.get_layers_mlp_linear_fc1_weight(layer_idx=layer_idx, **kwargs))
mlp_l1_weight.append(model.get_layers_mlp_linear_fc2_weight(layer_idx=layer_idx, **kwargs))
if md.linear_bias:
mlp_l0_bias.append(model.get_layers_mlp_linear_fc1_bias(layer_idx=layer_idx))
mlp_l0_bias.append(model.get_layers_mlp_linear_fc1_bias(layer_idx=layer_idx, **kwargs))
# Handle gated linear units.
if md.swiglu:
# Concat all the first halves ('W's) and all the second halves ('V's).
@ -171,7 +174,7 @@ def get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1):
# Simple concat of the rest.
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
if md.linear_bias:
message["mlp l1 bias"] = model.get_layers_mlp_linear_fc2_bias(layer_idx=layer_idx)
message["mlp l1 bias"] = model.get_layers_mlp_linear_fc2_bias(layer_idx=layer_idx, **kwargs)
if md.swiglu:
for tp_rank in range(tp_size):
mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0)
@ -183,6 +186,22 @@ def get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1):
return message
def get_message_layer_mlp(message, model, layer_idx, md=None, tp_size=1):
margs = model.get_args()
if margs.num_experts:
# return _get_message_layer_mlp(message, model, layer_idx, md=md, tp_size=tp_size)
message["mlp_moe"] = {}
mlp_router_weight = model.get_layers_mlp_router_weight(layer_idx=layer_idx)
message["mlp_moe"]["mlp router weight"] = mlp_router_weight
for expert_idx in range(margs.num_experts):
kwargs = {'expert_idx': expert_idx}
expert = _get_message_layer_mlp({}, model, layer_idx, md=md, tp_size=tp_size, **kwargs)
message["mlp_moe"][f"expert {expert_idx}"] = expert
return message
else:
return _get_message_layer_mlp(message, model, layer_idx, md=md, tp_size=tp_size)
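For orientation, a rough sketch (illustrative tensor shapes only, no real model) of the nested per-layer message this MoE branch produces; the actual tensors come from the accessor calls above:

# Illustrative only: the per-layer message layout when num_experts is set.
import torch

num_experts, hidden, ffn = 2, 8, 16
message = {"mlp_moe": {"mlp router weight": torch.empty(num_experts, hidden)}}
for expert_idx in range(num_experts):
    message["mlp_moe"][f"expert {expert_idx}"] = {
        "mlp l0 weight W": torch.empty(ffn, hidden),  # gate half of fc1 when swiglu is on
        "mlp l0 weight V": torch.empty(ffn, hidden),  # up half of fc1
        "mlp l1 weight": torch.empty(hidden, ffn),    # fc2
    }
print(list(message["mlp_moe"].keys()))  # ['mlp router weight', 'expert 0', 'expert 1']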
def get_message_postprocess(model, md):
# Send final norm from tp_rank 0.
message = {
@ -224,6 +243,7 @@ def _load_checkpoint(queue, args):
model_mg.initialize_megatron_args(args_hf, queue)
model_mg.set_tensor_model_parallel_world_size(model_mg.args.tensor_model_parallel_size)
model_mg.set_expert_model_parallel_world_size(model_mg.args.expert_model_parallel_size)
model_mg.set_pipeline_model_parallel_world_size(model_mg.args.pipeline_model_parallel_size)
model_mg.set_virtual_pipeline_model_parallel_world_size(model_mg.args.virtual_pipeline_model_parallel_size)
@ -241,7 +261,7 @@ def _load_checkpoint(queue, args):
model_mg.update_module(model_hf)
def queue_put(name, msg):
print(f"sending {name}")
logger.info(f"sending {name}")
msg["name"] = name
queue.put(msg)

View File

@ -0,0 +1,299 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import json
import os
import sys
import types
import logging as logger
import torch
from models import get_megatron_model
logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)
def add_arguments(parser):
group = parser.add_argument_group(title='Megatron loader')
group.add_argument('--true-vocab-size', type=int, default=None,
help='original size of vocab, if specified will trim padding from embedding table.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file. If specified will use this to get vocab size and '
'trim padding from the embedding table.')
group.add_argument('--megatron-path', type=str, default=None,
help='Base directory of megatron repository')
parser.add_argument('--add-qkv-bias', action='store_true',
help='Add bias for attention qkv', default=False,
)
parser.add_argument('--add-dense-bias', action='store_true',
help='Add bias for attention dense', default=False,
)
parser.add_argument('--embed-layernorm', action='store_true',
help='Add embed layernorm for word embedding', default=False,
)
parser.add_argument('--params-dtype', type=str,
help='Set weight dtype', default='fp16',
)
group.add_argument('--lora-target-modules', nargs='+', type=str, default=[],
help='Lora target modules.')
group.add_argument('--lora-load', type=str, default=None,
help='Directory containing a lora model checkpoint.')
group.add_argument('--lora-r', type=int, default=16,
help='Lora r.')
group.add_argument('--lora-alpha', type=int, default=32,
help='Lora alpha.')
def build_metadata(args, margs):
# Metadata.
# Layernorm has bias; RMSNorm does not.
if hasattr(margs, 'normalization'):
norm_has_bias = margs.normalization == "LayerNorm"
else:
# older models only supported LayerNorm
norm_has_bias = True
md = types.SimpleNamespace()
md.model_type = args.model_type
md.num_layers = margs.num_layers
md.hidden_size = margs.hidden_size
md.seq_length = margs.seq_length
md.num_attention_heads = margs.num_attention_heads
md.max_position_embeddings = margs.max_position_embeddings
md.tokenizer_type = margs.tokenizer_type
md.iteration = margs.iteration
md.params_dtype = margs.params_dtype
md.bert_binary_head = margs.bert_binary_head
md.output_layer = margs.untie_embeddings_and_output_weights
md.position_embedding_type = margs.position_embedding_type
md.linear_bias = margs.add_bias_linear
md.norm_has_bias = norm_has_bias
md.swiglu = margs.swiglu
md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
md.true_vocab_size = None
md.checkpoint_args = margs
md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
md.embed_layernorm = margs.embed_layernorm
md.consumed_train_samples = 0
md.consumed_valid_samples = 0
return md
def get_message_preprocess(model, args):
# Send embeddings.
tp_size = args.tensor_model_parallel_size
message = {
"word embeddings": torch.cat(
[model.get_embedding_word_embeddings_weight(tp_rank=tp_rank) for tp_rank in range(tp_size)], dim=0
)
}
if args.position_embedding_type == 'learned_absolute':
message["position embeddings"] = model.get_embedding_position_embeddings_weight()
if args.embed_layernorm:
message["word embeddings norm_w"] = model.get_embedding_word_embeddings_norm_weight()
message["word embeddings norm_b"] = model.get_embedding_word_embeddings_norm_bias()
return message
def get_message_layer_norm(message, model, md, **kwargs):
# Get non-parallel tensors from tp_rank 0.
message["input norm weight"] = model.get_layers_input_layernorm_weight(**kwargs)
if md.norm_has_bias:
message["input norm bias"] = model.get_layers_input_layernorm_bias(**kwargs)
message["post norm weight"] = model.get_layers_self_attention_post_attention_layernorm_weight(**kwargs)
if md.norm_has_bias:
message["post norm bias"] = model.get_layers_self_attention_post_attention_layernorm_bias(**kwargs)
return message
def get_message_layer_attn(message, model, md=None, **kwargs):
# Grab all parallel tensors for this layer
qkv_weight = []
qkv_bias = []
dense_weight = []
margs = model.get_args()
for tp_rank in range(margs.tensor_model_parallel_size):
kwargs["tp_rank"] = tp_rank
qkv_weight.append(model.get_layers_self_attention_linear_qkv_weight(**kwargs))
dense_weight.append(model.get_layers_self_attention_linear_proj_weight(**kwargs))
if md.linear_bias or margs.add_qkv_bias:
qkv_bias.append(model.get_layers_self_attention_linear_qkv_bias(**kwargs))
# Handle gated linear units
# simple concat of the rest
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["dense weight"] = torch.cat(dense_weight, dim=1)
if md.linear_bias or margs.add_qkv_bias:
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
if md.linear_bias:
message["dense bias"] = model.get_layers_self_attention_linear_proj_bias(**kwargs)
return message
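A small sketch (toy shapes, not the real model) of why the concatenation axes differ above: the column-parallel QKV weight is sharded along its output dimension (dim 0), while the row-parallel projection is sharded along its input dimension (dim 1):

# Illustrative only: merging tensor-parallel shards along the sharded axis.
import torch

tp_size, hidden = 2, 8
qkv_shards = [torch.empty(3 * hidden // tp_size, hidden) for _ in range(tp_size)]   # column-parallel
proj_shards = [torch.empty(hidden, hidden // tp_size) for _ in range(tp_size)]      # row-parallel

qkv_full = torch.cat(qkv_shards, dim=0)    # (3*hidden, hidden)
proj_full = torch.cat(proj_shards, dim=1)  # (hidden, hidden)
print(qkv_full.shape, proj_full.shape)     # torch.Size([24, 8]) torch.Size([8, 8])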
def _get_message_layer_mlp(message, model, md=None, **kwargs):
margs = model.get_args()
mlp_l0_weight = []
mlp_l1_weight = []
mlp_l0_bias = []
for tp_rank in range(margs.tensor_model_parallel_size):
kwargs["tp_rank"] = tp_rank
mlp_l0_weight.append(model.get_layers_mlp_linear_fc1_weight(**kwargs))
mlp_l1_weight.append(model.get_layers_mlp_linear_fc2_weight(**kwargs))
if md.linear_bias:
mlp_l0_bias.append(model.get_layers_mlp_linear_fc1_bias(**kwargs))
# Handle gated linear units
if md.swiglu:
# concat all the first halves ('W's) and all the second halves ('V's)
for tp_rank in range(margs.tensor_model_parallel_size):
mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0)
message[f"mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0)
message[f"mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0)
else:
message[f"mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
# simple concat of the rest
message[f"mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
if md.linear_bias:
if md.swiglu:
for tp_rank in range(margs.tensor_model_parallel_size):
mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0)
message[f"mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias], dim=0)
message[f"mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias], dim=0)
else:
message[f"mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
def get_message_layer_mlp(message, model, md=None, **kwargs):
# Grab all parallel tensors for this layer
margs = model.get_args()
if margs.num_experts:
message["mlp_moe"] = {}
num_experts_local = margs.num_experts // margs.expert_model_parallel_size
mlp_router_weight = model.get_layers_mlp_router_weight(**kwargs)
message["mlp_moe"]["mlp router weight"] = mlp_router_weight
for ep_rank in range(margs.expert_model_parallel_size):
kwargs["ep_rank"] = ep_rank
for expert_idx in range(num_experts_local):
kwargs["expert_idx"] = expert_idx
global_expert_idx = expert_idx + ep_rank * num_experts_local
message["mlp_moe"][f"expert {global_expert_idx}"] = {}
expert = message["mlp_moe"][f"expert {global_expert_idx}"]
_get_message_layer_mlp(expert, model, md, **kwargs)
else:
_get_message_layer_mlp(message, model, md, **kwargs)
return message
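A worked sketch of the local-to-global expert index mapping used in the EP loop above (pure arithmetic, no model needed):

# Illustrative only: with 8 experts split over expert-parallel size 2,
# ep_rank 0 owns global experts 0-3 and ep_rank 1 owns global experts 4-7.
num_experts, ep_size = 8, 2
num_experts_local = num_experts // ep_size
for ep_rank in range(ep_size):
    for expert_idx in range(num_experts_local):
        global_expert_idx = expert_idx + ep_rank * num_experts_local
        print(f"ep_rank={ep_rank} local expert {expert_idx} -> global expert {global_expert_idx}")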
def get_message_postprocess(model, md, **kwargs):
# Send final norm from tp_rank 0.
message = {}
message[f"weight"] = model.get_final_layernorm_weight(**kwargs)
if md.norm_has_bias:
message[f"bias"] = model.get_final_layernorm_bias(**kwargs)
return message
def get_message_output_layer(model, md, **kwargs):
# Send the output layer weight, concatenated across tensor-parallel ranks.
margs = model.get_args()
tp_size = margs.tensor_model_parallel_size
message = {}
if md.output_layer:
get_output_layer_weight_list = []
for tp_rank in range(tp_size):
kwargs["tp_rank"] = tp_rank
get_output_layer_weight_list.append(
model.get_output_layer_weight(**kwargs)
)
message[f"weight"] = torch.cat(get_output_layer_weight_list, dim=0)
return message
def _load_checkpoint(queue, args):
# Search in directory above this
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__),
os.path.pardir)))
if args.megatron_path is not None:
sys.path.insert(0, args.megatron_path)
model_mg = get_megatron_model(args_cmd=args)
model_mg.initialize_megatron_args(queue=queue, loader_megatron=True)
model_mg.set_tensor_model_parallel_world_size(model_mg.args.tensor_model_parallel_size)
model_mg.set_expert_model_parallel_world_size(model_mg.args.expert_model_parallel_size)
model_mg.set_pipeline_model_parallel_world_size(model_mg.args.pipeline_model_parallel_size)
model_mg.set_virtual_pipeline_model_parallel_world_size(model_mg.args.virtual_pipeline_model_parallel_size)
# Get first pipe stage.
model_mg.set_tensor_model_parallel_rank(0)
model_mg.set_pipeline_model_parallel_rank(0)
margs = model_mg.get_args()
md = build_metadata(args, margs)
queue.put(md)
model_mg.get_modules_from_pretrained(pp_stage_cache_flag=True)
def queue_put(name, msg):
logger.info(f"sending {name}")
msg["name"] = name
queue.put(msg)
# Send embeddings
message = get_message_preprocess(model_mg, margs)
queue_put("embeddings", message)
pp_size = margs.pipeline_model_parallel_size
vp_size = margs.virtual_pipeline_model_parallel_size
if vp_size is None:
vp_size = 1
total_layer_num = 0
for vp_rank in range(vp_size):
for pp_rank in range(pp_size):
model_mg.set_pipeline_model_parallel_rank(pp_rank)
model_mg.get_modules_from_pretrained(pp_stage_cache_flag=True)
kwargs = {"vp_rank": vp_rank, 'pp_rank': pp_rank}
for layer_idx in range(len(model_mg.get_layers_module(**kwargs))):
kwargs["layer_idx"] = layer_idx
message = {}
message = get_message_layer_norm(message, model_mg, md, **kwargs)
message = get_message_layer_attn(message, model_mg, md, **kwargs)
message = get_message_layer_mlp(message, model_mg, md, **kwargs)
queue_put(f"transformer layer {total_layer_num}", message)
total_layer_num = total_layer_num + 1
kwargs = {"pp_rank": pp_size - 1, "vp_rank": vp_size - 1}
message = get_message_postprocess(model_mg, md, **kwargs)
queue_put("final norm", message)
message = get_message_output_layer(model_mg, md, **kwargs)
if message:
queue_put("output layer", message)
queue.put("done")
def load_checkpoint(queue, args):
try:
_load_checkpoint(queue, args)
except:
queue.put("exit")
raise

View File

@ -25,24 +25,25 @@
"vocab_size": "vocab_size",
"intermediate_size": "intermediate_size",
"norm_epsilon": "rms_norm_eps",
"tie_word_embeddings": "tie_word_embeddings"
"tie_word_embeddings": "tie_word_embeddings",
"torch_dtype": "torch_dtype"
},
"model_hf_key_mapping": {
"model": "module[tp_rank]",
"embedding_word_embeddings": "module[tp_rank].model.embed_tokens",
"embedding_word_embeddings_norm": "module[tp_rank].model.embedding.word_embeddings.norm",
"layers": "module[tp_rank].model.layers",
"layers_input_layernorm": "module[tp_rank].model.layers[layer_idx].input_layernorm",
"layers_self_attention_linear_proj": "module[tp_rank].model.layers[layer_idx].self_attn.o_proj",
"layers_self_attention_linear_q_proj": "module[tp_rank].model.layers[layer_idx].self_attn.q_proj",
"layers_self_attention_linear_k_proj": "module[tp_rank].model.layers[layer_idx].self_attn.k_proj",
"layers_self_attention_linear_v_proj": "module[tp_rank].model.layers[layer_idx].self_attn.v_proj",
"layers_self_attention_pre_mlp_layernorm": "module[tp_rank].model.layers[layer_idx].post_attention_layernorm",
"layers_mlp_gate_proj": "module[tp_rank].model.layers[layer_idx].mlp.gate_proj",
"layers_mlp_up_proj": "module[tp_rank].model.layers[layer_idx].mlp.up_proj",
"layers_mlp_linear_fc2": "module[tp_rank].model.layers[layer_idx].mlp.down_proj",
"final_layernorm": "module[tp_rank].model.norm",
"output_layer": "module[tp_rank].lm_head"
"model": "module[0]",
"embedding_word_embeddings": "model.embed_tokens",
"embedding_word_embeddings_norm": "model.embedding.word_embeddings.norm",
"layers": "model.layers",
"layers_input_layernorm": "model.layers[layer_idx].input_layernorm",
"layers_self_attention_linear_proj": "model.layers[layer_idx].self_attn.o_proj",
"layers_self_attention_linear_q_proj": "model.layers[layer_idx].self_attn.q_proj",
"layers_self_attention_linear_k_proj": "model.layers[layer_idx].self_attn.k_proj",
"layers_self_attention_linear_v_proj": "model.layers[layer_idx].self_attn.v_proj",
"layers_self_attention_pre_mlp_layernorm": "model.layers[layer_idx].post_attention_layernorm",
"layers_mlp_gate_proj": "model.layers[layer_idx].mlp.gate_proj",
"layers_mlp_up_proj": "model.layers[layer_idx].mlp.up_proj",
"layers_mlp_linear_fc2": "model.layers[layer_idx].mlp.down_proj",
"final_layernorm": "model.norm",
"output_layer": "lm_head"
}
},
"llama2": {
@ -66,18 +67,29 @@
"norm_epsilon": "layernorm_epsilon"
},
"model_hf_key_mapping": {
"model": "module[tp_rank]",
"embedding_word_embeddings": "module[tp_rank].transformer.embedding.word_embeddings",
"layers": "module[tp_rank].transformer.encoder.layers",
"layers_input_layernorm": "module[tp_rank].transformer.encoder.layers[layer_idx].input_layernorm",
"layers_self_attention_linear_qkv_pack": "module[tp_rank].transformer.encoder.layers[layer_idx].self_attention.query_key_value",
"layers_self_attention_linear_proj": "module[tp_rank].transformer.encoder.layers[layer_idx].self_attention.dense",
"layers_self_attention_pre_mlp_layernorm": "module[tp_rank].transformer.encoder.layers[layer_idx].post_attention_layernorm",
"layers_mlp_linear_fc1": "module[tp_rank].transformer.encoder.layers[layer_idx].mlp.dense_h_to_4h",
"layers_mlp_linear_fc2": "module[tp_rank].transformer.encoder.layers[layer_idx].mlp.dense_4h_to_h",
"final_layernorm": "module[tp_rank].transformer.encoder.final_layernorm",
"output_layer": "module[tp_rank].transformer.output_layer"
"embedding_word_embeddings": "transformer.embedding.word_embeddings",
"layers": "transformer.encoder.layers",
"layers_input_layernorm": "transformer.encoder.layers[layer_idx].input_layernorm",
"layers_self_attention_linear_qkv_pack": "transformer.encoder.layers[layer_idx].self_attention.query_key_value",
"layers_self_attention_linear_proj": "transformer.encoder.layers[layer_idx].self_attention.dense",
"layers_self_attention_pre_mlp_layernorm": "transformer.encoder.layers[layer_idx].post_attention_layernorm",
"layers_mlp_linear_fc1": "transformer.encoder.layers[layer_idx].mlp.dense_h_to_4h",
"layers_mlp_linear_fc2": "transformer.encoder.layers[layer_idx].mlp.dense_4h_to_h",
"final_layernorm": "transformer.encoder.final_layernorm",
"output_layer": "transformer.output_layer"
}
},
"mixtral": {
"__base__": "base",
"config_set_value": {
"moe_flag": true
},
"model_hf_key_mapping": {
"layers_mlp_router": "model.layers[layer_idx].block_sparse_moe.gate",
"layers_mlp_gate_proj": "model.layers[layer_idx].block_sparse_moe.experts[expert_idx].w1",
"layers_mlp_up_proj": "model.layers[layer_idx].block_sparse_moe.experts[expert_idx].w3",
"layers_mlp_linear_fc2": "model.layers[layer_idx].block_sparse_moe.experts[expert_idx].w2"
}
}
}
}

View File

@ -4,80 +4,128 @@ import sys
import re
import json
from types import SimpleNamespace
import logging as logger
from collections import OrderedDict
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from megatron.core import mpu
from megatron.training.arguments import validate_args
from megatron.training.global_vars import set_global_variables
from megatron.legacy.model import module
from megatron.core.enums import ModelType
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_args
from megatron.training.checkpointing import load_checkpoint
from megatron.core import tensor_parallel
from pretrain_gpt import model_provider
from modellink.utils import parse_args
from modellink.training import model_provider_func_wrapper
from modellink.checkpointing import load_checkpoint_wrapper
logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)
model_provider = model_provider_func_wrapper(model_provider)
load_checkpoint = load_checkpoint_wrapper(load_checkpoint)
def tensor_info(tensor):
shape = tensor.shape
mean_val = tensor.mean().item()
min_val = tensor.min().item()
max_val = tensor.max().item()
return f"shape: {shape} mean_val: {mean_val} min_val: {min_val} max_val: {max_val}"
class ModelBase(abc.ABC):
def __init__(self, args_cmd=None):
self.args_cmd = args_cmd
self.args = None
self.args_megatron_checkpoint = None
self.module = None
self.module_mapping = None
self.model_cfg = self.read_model_cfg()
self.__register_functions()
self.kwargs_idx = OrderedDict({
"vp_rank": 0,
"ep_rank": 0,
"tp_rank": 0,
"layer_idx": 0,
"expert_idx": 0
})
def update_kwargs_idx(self, **kwargs):
for key in self.kwargs_idx:
if key in kwargs:
self.kwargs_idx[key] = kwargs[key]
elif self.kwargs_idx[key] > 0:
self.kwargs_idx[key] = 0
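A standalone sketch of the index bookkeeping above: indices passed in a call are updated, and any previously set index that is not passed again falls back to 0, so a later lookup cannot reuse a stale rank or expert index (plain OrderedDict, no model involved):

# Illustrative only: mirrors update_kwargs_idx on a plain OrderedDict.
from collections import OrderedDict

kwargs_idx = OrderedDict(vp_rank=0, ep_rank=0, tp_rank=0, layer_idx=0, expert_idx=0)

def update_kwargs_idx(**kwargs):
    for key in kwargs_idx:
        if key in kwargs:
            kwargs_idx[key] = kwargs[key]
        elif kwargs_idx[key] > 0:
            kwargs_idx[key] = 0  # reset anything set by an earlier call

update_kwargs_idx(tp_rank=1, layer_idx=3, expert_idx=2)
update_kwargs_idx(layer_idx=4)  # tp_rank and expert_idx fall back to 0
print(dict(kwargs_idx))         # {'vp_rank': 0, 'ep_rank': 0, 'tp_rank': 0, 'layer_idx': 4, 'expert_idx': 0}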
def __register_functions(self):
self.get_module_mapping()
kwargs_idx = dict({"tp_rank": 0, "layer_idx": 0})
def get_obj(self, value, **kwargs):
def _get_obj(self, value, **kwargs):
pattern = r'(\w+)(?:\[(\w+)\])?'
matches = re.findall(pattern, value)
obj = self
for key in kwargs_idx:
if key in kwargs:
kwargs_idx[key] = kwargs[key]
self.update_kwargs_idx(**kwargs)
obj = self.get_model_item(**kwargs)
for attr, attr_ident in matches:
obj = getattr(obj, attr)
if hasattr(obj, attr):
obj = getattr(obj, attr)
else:
return None
if attr_ident:
if attr_ident in kwargs_idx:
attr_idx = kwargs_idx[attr_ident]
if attr_ident in self.kwargs_idx:
attr_idx = self.kwargs_idx[attr_ident]
obj = obj[attr_idx]
else:
raise AssertionError(f"check {self.__class__.__name__}.module_mapping **{attr_ident}**.")
return obj
def func_generator_get_module(value):
def _func_generator_get_module(value):
def func(self, **kwargs):
return get_obj(self, value, **kwargs)
return _get_obj(self, value, **kwargs)
return func
def func_generator_get_weight(value):
def _func_generator_get_weight(value):
def func(self, **kwargs):
return get_obj(self, value, **kwargs).weight.data
return _get_obj(self, value, **kwargs).weight.data
return func
def func_generator_get_bias(value):
def _func_generator_get_bias(value):
def func(self, **kwargs):
return get_obj(self, value, **kwargs).bias.data
return _get_obj(self, value, **kwargs).bias.data
return func
def func_generator_set_weight(value):
def _func_generator_set_weight(value):
def func(self, **kwargs):
return get_obj(self, value, **kwargs).weight.data.copy_(kwargs.get('data'))
return _get_obj(self, value, **kwargs).weight.data.copy_(kwargs.get('data'))
return func
def func_generator_set_bias(value):
def _func_generator_set_bias(value):
def func(self, **kwargs):
return get_obj(self, value, **kwargs).bias.data.copy_(kwargs.get('data'))
return _get_obj(self, value, **kwargs).bias.data.copy_(kwargs.get('data'))
return func
def _func_generator_has_module(value):
def func(self, **kwargs):
obj = _get_obj(self, value, **kwargs)
return True if obj else False
return func
def _setattr(self, func_name, value):
if not hasattr(self, func_name):
setattr(self, func_name, value)
if self.module_mapping:
for key, value in self.module_mapping.items():
setattr(self, "get_" + key + "_module", func_generator_get_module(value).__get__(self, ModelBase))
setattr(self, "get_" + key + "_weight", func_generator_get_weight(value).__get__(self, ModelBase))
setattr(self, "get_" + key + "_bias", func_generator_get_bias(value).__get__(self, ModelBase))
setattr(self, "set_" + key + "_weight", func_generator_set_weight(value).__get__(self, ModelBase))
setattr(self, "set_" + key + "_bias", func_generator_set_bias(value).__get__(self, ModelBase))
_setattr(self, "get_" + key + "_module", _func_generator_get_module(value).__get__(self, ModelBase))
_setattr(self, "get_" + key + "_weight", _func_generator_get_weight(value).__get__(self, ModelBase))
_setattr(self, "get_" + key + "_bias", _func_generator_get_bias(value).__get__(self, ModelBase))
_setattr(self, "set_" + key + "_weight", _func_generator_set_weight(value).__get__(self, ModelBase))
_setattr(self, "set_" + key + "_bias", _func_generator_set_bias(value).__get__(self, ModelBase))
_setattr(self, "has_" + key + "_module", _func_generator_has_module(value).__get__(self, ModelBase))
def update_module(self, src_model):
self.set_preprocess_state(src_model)
@ -87,51 +135,58 @@ class ModelBase(abc.ABC):
def set_preprocess_state(self, src_model):
'''Set embedding params.'''
self.set_embedding_word_embeddings_weight(data=src_model.get_embedding_word_embeddings_weight())
embeddings_weight = src_model.get_embedding_word_embeddings_weight()
self.set_embedding_word_embeddings_weight(data=embeddings_weight)
def set_postprocess_state(self, src_model):
self.set_final_layernorm_weight(data=src_model.get_final_layernorm_weight())
final_layernorm_weight = src_model.get_final_layernorm_weight()
output_layer_weight = src_model.get_output_layer_weight()
self.set_final_layernorm_weight(data=final_layernorm_weight)
if self.args.untie_embeddings_and_output_weights:
self.set_output_layer_weight(data=src_model.get_output_layer_weight())
self.set_output_layer_weight(data=output_layer_weight)
def set_layer_state(self, src_model, layer_idx):
self.set_attn_state(layer_idx, src_model)
self.set_mlp_state(layer_idx, src_model)
self.set_layers_input_layernorm_weight(
layer_idx=layer_idx,
data=src_model.get_layers_input_layernorm_weight(layer_idx=layer_idx))
self.set_layers_self_attention_pre_mlp_layernorm_weight(
layer_idx=layer_idx,
data=src_model.get_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx))
input_layernorm_weight = src_model.get_layers_input_layernorm_weight(layer_idx=layer_idx)
pre_mlp_layernorm_weight = src_model.get_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx)
self.set_layers_input_layernorm_weight(layer_idx=layer_idx, data=input_layernorm_weight)
self.set_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx, data=pre_mlp_layernorm_weight)
def set_attn_state(self, layer_idx, src_model):
'''Set self-attention params.'''
# Get attention layer & state.
self.set_layers_self_attention_linear_qkv_weight(
layer_idx=layer_idx,
data=src_model.get_layers_self_attention_linear_qkv_weight(layer_idx=layer_idx))
self.set_layers_self_attention_linear_proj_weight(
layer_idx=layer_idx,
data=src_model.get_layers_self_attention_linear_proj_weight(layer_idx=layer_idx))
qkv_weight = src_model.get_layers_self_attention_linear_qkv_weight(layer_idx=layer_idx)
proj_weight = src_model.get_layers_self_attention_linear_proj_weight(layer_idx=layer_idx)
self.set_layers_self_attention_linear_qkv_weight(layer_idx=layer_idx, data=qkv_weight)
self.set_layers_self_attention_linear_proj_weight(layer_idx=layer_idx, data=proj_weight)
if self.args.add_qkv_bias:
self.set_layers_self_attention_linear_qkv_bias(
layer_idx=layer_idx,
data=src_model.get_layers_self_attention_linear_qkv_bias(layer_idx=layer_idx))
qkv_bias = src_model.get_layers_self_attention_linear_qkv_bias(layer_idx=layer_idx)
self.set_layers_self_attention_linear_qkv_bias(layer_idx=layer_idx, data=qkv_bias)
if self.args.add_dense_bias:
self.set_layers_self_attention_linear_proj_bias(
layer_idx=layer_idx,
data=src_model.get_layers_self_attention_linear_proj_bias(layer_idx=layer_idx))
proj_bias = src_model.get_layers_self_attention_linear_proj_bias(layer_idx=layer_idx)
self.set_layers_self_attention_linear_proj_bias(layer_idx=layer_idx, data=proj_bias)
def _set_mlp_state(self, src_model, **kwargs):
'''Set MLP params.'''
fc1_weight = src_model.get_layers_mlp_linear_fc1_weight(**kwargs)
fc2_weight = src_model.get_layers_mlp_linear_fc2_weight(**kwargs)
self.set_layers_mlp_linear_fc1_weight(data=fc1_weight, **kwargs)
self.set_layers_mlp_linear_fc2_weight(data=fc2_weight, **kwargs)
def set_mlp_state(self, layer_idx, src_model):
'''Set MLP params.'''
self.set_layers_mlp_linear_fc1_weight(
layer_idx=layer_idx,
data=src_model.get_layers_mlp_linear_fc1_weight(layer_idx=layer_idx))
args = src_model.get_args()
kwargs = {'layer_idx': layer_idx}
num_experts = getattr(args, 'num_experts', None) or getattr(args, 'num_local_experts', None)
if num_experts:
router_weight = src_model.get_layers_mlp_router_weight(**kwargs)
self.set_layers_mlp_router_weight(**kwargs, data=router_weight)
for expert_idx in range(num_experts):
kwargs['expert_idx'] = expert_idx
self._set_mlp_state(src_model, **kwargs)
else:
self._set_mlp_state(src_model, **kwargs)
self.set_layers_mlp_linear_fc2_weight(
layer_idx=layer_idx,
data=src_model.get_layers_mlp_linear_fc2_weight(layer_idx=layer_idx))
def get_args(self):
return self.args
@ -145,21 +200,57 @@ class ModelBase(abc.ABC):
def get_modules_count(self):
return len(self.module)
@staticmethod
def read_model_cfg():
def merge_configs(base_config, specific_config):
merged_config = base_config.copy()
for key, value in specific_config.items():
if isinstance(value, dict) and key in merged_config:
merged_config[key] = merge_configs(merged_config[key], value)
else:
merged_config[key] = value
return merged_config
current_directory = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(current_directory, 'model_cfg.json'), 'r') as file:
config = json.load(file)
final_configs = {}
for model_name, model_config in config["model_mappings"].items():
if "__base__" in model_config:
base_model_name = model_config["__base__"]
base_config = config["model_mappings"][base_model_name]
specific_config = model_config.copy()
specific_config.pop("__base__", None)
final_config = merge_configs(base_config, specific_config)
else:
final_config = model_config
final_configs[model_name] = final_config
return final_configs
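A hedged usage sketch of the "__base__" inheritance implemented above, applied to a miniature inline config instead of model_cfg.json (same merge rule: nested dicts are merged recursively, scalars are overridden):

# Illustrative only: the same recursive merge applied to a tiny config.
def merge_configs(base_config, specific_config):
    merged_config = base_config.copy()
    for key, value in specific_config.items():
        if isinstance(value, dict) and key in merged_config:
            merged_config[key] = merge_configs(merged_config[key], value)
        else:
            merged_config[key] = value
    return merged_config

base = {"config_set_value": {"seq_length": 4096},
        "model_hf_key_mapping": {"layers": "model.layers"}}
mixtral_like = {"__base__": "base",
                "config_set_value": {"moe_flag": True},
                "model_hf_key_mapping": {"layers_mlp_router": "model.layers[layer_idx].block_sparse_moe.gate"}}

specific = {k: v for k, v in mixtral_like.items() if k != "__base__"}
print(merge_configs(base, specific))
# {'config_set_value': {'seq_length': 4096, 'moe_flag': True},
#  'model_hf_key_mapping': {'layers': 'model.layers',
#                           'layers_mlp_router': 'model.layers[layer_idx].block_sparse_moe.gate'}}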
@abc.abstractmethod
def get_module_mapping(self):
pass
@abc.abstractmethod
def get_model_item(self, **kwargs):
pass
class HuggingfaceModel(ModelBase):
def __init__(self, args_cmd):
self.model_cfg = self.read_model_cfg()
super(HuggingfaceModel, self).__init__(args_cmd)
self.initialize_args()
self.layers_self_attention_linear_qkv_caches = {"layer_idx": -1, "weight": None, "bias": None}
def initialize_args(self):
# Read huggingface args.
llama_args_path = os.path.join(self.args_cmd.load_dir, "config.json")
if self.args_cmd.save_model_type == 'huggingface':
cfg_dir = self.args_cmd.save_dir
else:
cfg_dir = self.args_cmd.load_dir
llama_args_path = os.path.join(cfg_dir, "config.json")
with open(llama_args_path) as f:
self.args = json.load(f)
@ -189,21 +280,20 @@ class HuggingfaceModel(ModelBase):
load_dir = self.args_cmd.save_dir
else:
load_dir = self.args_cmd.load_dir
self.module = [
AutoModelForCausalLM.from_pretrained(load_dir, device_map=device_map, trust_remote_code=trust_remote_code)
]
self.module = [AutoModelForCausalLM.from_pretrained(load_dir, device_map=device_map, trust_remote_code=trust_remote_code)]
if self.args.torch_dtype in ["float16", "bfloat16"]:
self.module[0] = self.module[0].to(eval(f'torch.{self.args.torch_dtype}'))
def get_module_mapping(self):
self.module_mapping = self.model_cfg.get(self.args_cmd.model_type_hf).get('model_hf_key_mapping')
def _get_layers_self_attention_linear_qkv_module(self, layer_idx=0):
def __get_layers_self_attention_linear_qkv_module(self, layer_idx=0):
if self.layers_self_attention_linear_qkv_caches["layer_idx"] == layer_idx:
return
self.layers_self_attention_linear_qkv_caches["layer_idx"] = layer_idx
# Reshape loaded weights.
nh = self.args.num_attention_heads
ng = (self.args.num_key_value_heads if self.args.group_query_attention else self.args.num_attention_heads)
# dim = self.args['kv_channels']
dim = self.args.hidden_size // self.args.num_attention_heads
if not nh % ng == 0:
raise ValueError("nh % ng should equal 0")
@ -237,7 +327,6 @@ class HuggingfaceModel(ModelBase):
qkv_pack_weight = qkv_pack.weight
full_q = dim * nh
end_k = full_q + ng * dim
hs = self.args.hidden_size
q_weight = qkv_pack_weight[:full_q, :]
k_weight = qkv_pack_weight[full_q:end_k, :]
v_weight = qkv_pack_weight[end_k:, :]
@ -253,82 +342,115 @@ class HuggingfaceModel(ModelBase):
else:
raise ValueError(f"Unsupported types. {qkv_type}")
def get_layers_mlp_linear_fc1_weight(self, layer_idx=0):
def get_layers_mlp_linear_fc1_weight(self, **kwargs):
fc_type = self.args.fc_type
if fc_type == "h_to_4h":
return self.get_layers_mlp_linear_fc1_module(layer_idx=layer_idx).weight
return self.get_layers_mlp_linear_fc1_module(**kwargs).weight
elif fc_type == "gate_up_down":
gate_proj = self.get_layers_mlp_gate_proj_weight(layer_idx=layer_idx)
up_proj = self.get_layers_mlp_up_proj_weight(layer_idx=layer_idx)
gate_proj = self.get_layers_mlp_gate_proj_weight(**kwargs)
up_proj = self.get_layers_mlp_up_proj_weight(**kwargs)
return torch.cat([gate_proj, up_proj], dim=0)
else:
raise ValueError(f"Unsupported fc_type {fc_type}")
def get_layers_self_attention_linear_qkv_weight(self, layer_idx):
self._get_layers_self_attention_linear_qkv_module(layer_idx=layer_idx)
self.__get_layers_self_attention_linear_qkv_module(layer_idx=layer_idx)
return self.layers_self_attention_linear_qkv_caches["weight"]
def get_layers_self_attention_linear_qkv_bias(self, layer_idx):
self._get_layers_self_attention_linear_qkv_module(layer_idx=layer_idx)
self.__get_layers_self_attention_linear_qkv_module(layer_idx=layer_idx)
return self.layers_self_attention_linear_qkv_caches["bias"]
@staticmethod
def read_model_cfg():
def merge_configs(base_config, specific_config):
merged_config = base_config.copy()
for key, value in specific_config.items():
if isinstance(value, dict) and key in merged_config:
merged_config[key] = merge_configs(merged_config[key], value)
else:
merged_config[key] = value
return merged_config
def set_layers_mlp_linear_fc1_weight(self, data=None, **kwargs):
gate_proj, up_proj = torch.chunk(data, 2, dim=0)
self.set_layers_mlp_gate_proj_weight(data=gate_proj, **kwargs)
self.set_layers_mlp_up_proj_weight(data=up_proj, **kwargs)
current_directory = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(current_directory, 'model_cfg.json'), 'r') as file:
config = json.load(file)
# Dictionary holding the final merged configs
final_configs = {}
def set_layers_self_attention_linear_qkv_weight(self, layer_idx=0, data=None):
def qkv_split_weight(query_key_value):
qkv_weight = query_key_value.reshape(
ng,
repeats + 2,
query_key_value.shape[0] // ng // (repeats + 2),
query_key_value.shape[1],
)
hidden_size = qkv_weight.shape[-1]
qw = qkv_weight[:, :repeats, ...].reshape(-1, hidden_size)
kw = qkv_weight[:, repeats: repeats + 1, ...].reshape(-1, hidden_size)
vw = qkv_weight[:, repeats + 1:, ...].reshape(-1, hidden_size)
return qw, kw, vw
# Iterate over all model configs
for model_name, model_config in config["model_mappings"].items():
if "__base__" in model_config:
base_model_name = model_config["__base__"]
base_config = config["model_mappings"][base_model_name]
specific_config = model_config.copy()
specific_config.pop("__base__", None)
final_config = merge_configs(base_config, specific_config)
else:
final_config = model_config
final_configs[model_name] = final_config
nh = self.args.num_attention_heads
ng = (self.args.num_key_value_heads if self.args.group_query_attention else self.args.num_attention_heads)
if not nh % ng == 0:
raise ValueError("nh % ng should equal 0")
repeats = nh // ng
return final_configs
qkv_type = self.args.qkv_type
if qkv_type == "unpack":
q_weight, k_weight, v_weight = qkv_split_weight(data)
self.set_layers_self_attention_linear_q_proj_weight(layer_idx=layer_idx, data=q_weight)
self.set_layers_self_attention_linear_k_proj_weight(layer_idx=layer_idx, data=k_weight)
self.set_layers_self_attention_linear_v_proj_weight(layer_idx=layer_idx, data=v_weight)
else:
raise ValueError(f"Unsupported types. {qkv_type}")
def set_layers_self_attention_linear_qkv_bias(self, layer_idx, data=None):
def qkv_split_bias(query_key_value):
bias_weight = query_key_value.reshape(
ng, repeats + 2, query_key_value.shape[0] // ng // (repeats + 2)
)
qw = bias_weight[:, :repeats, ...].reshape(-1)
kw = bias_weight[:, repeats: repeats + 1, ...].reshape(-1)
vw = bias_weight[:, repeats + 1:, ...].reshape(-1)
return qw, kw, vw
nh = self.args.num_attention_heads
ng = (self.args.num_key_value_heads if self.args.group_query_attention else self.args.num_attention_heads)
if not nh % ng == 0:
raise ValueError("nh % ng should equal 0")
repeats = nh // ng
qkv_type = self.args.qkv_type
if qkv_type == "unpack":
if self.args_cmd.add_qkv_bias:
q_bias, k_bias, v_bias = qkv_split_bias(data)
self.set_layers_self_attention_linear_q_proj_bias(layer_idx=layer_idx, data=q_bias)
self.set_layers_self_attention_linear_k_proj_bias(layer_idx=layer_idx, data=k_bias)
self.set_layers_self_attention_linear_v_proj_bias(layer_idx=layer_idx, data=v_bias)
else:
raise ValueError(f"Unsupported types. {qkv_type}")
def get_model_item(self, **kwargs):
return self.module[0]
class MegatronModel(ModelBase):
def __init__(self, args_cmd, md=None):
super(MegatronModel, self).__init__(args_cmd)
self.md = md
self.pp_stage_cache = []
def initialize_megatron_args(self, hf_args=None, queue=None):
def initialize_megatron_args(self, hf_args=None, queue=None, loader_megatron=False, saver_megatron=False):
sys.argv = self.get_sys_argv()
self.args = parse_args()
self.update_megatron_args_from_cmd_config() # TODO: verify whether the saver path needs all of these as well
self.update_megatron_args_from_huggingface_config(hf_args) # loader path only; the saver path skips this
self.update_megatron_args_from_megatron_checkpoint(loader_megatron)
self.update_megatron_args_from_cmd_config(loader_megatron)
self.update_megatron_args_from_huggingface_config(hf_args)
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes.
self.args.world_size = self.args.tensor_model_parallel_size * self.args.pipeline_model_parallel_size
self.update_megatron_args_from_loader_margs()
self.args = validate_args(self.args)
self.check_for_args(queue)
self.check_for_args(queue, saver_megatron)
self.args.model_type = ModelType.encoder_or_decoder
# Suppress warning about torch.distributed not being initialized.
module.MegatronModule.embedding_warning_printed = True
set_global_variables(self.args, build_tokenizer=False)
self.set_megatron_parallel_state()
set_args(self.args)
self.set_megatron_parallel_state(saver_megatron)
def update_megatron_args_from_loader_margs(self):
if self.md and hasattr(self.md, 'checkpoint_args'):
@ -342,27 +464,28 @@ class MegatronModel(ModelBase):
'recompute_num_layers', 'recompute_method', 'encoder_num_layers', 'encoder_seq_length',
'distribute_saved_activations', 'train_iters', 'lr_decay_iters', 'lr_warmup_iters',
'lr_warmup_fraction', 'start_weight_decay', 'end_weight_decay', 'make_vocab_size_divisible_by',
'masked_softmax_fusion', 'num_layer_list',
'masked_softmax_fusion', 'num_layer_list', 'lora_target_modules', 'expert_model_parallel_size'
]
for arg, value in vars(self.md.checkpoint_args).items():
if arg in args_to_keep:
continue
if not hasattr(self.args, arg):
print(f"Checkpoint had argument {arg} but new arguments does not have this.")
logger.warning(f"Checkpoint had argument {arg} but new arguments does not have this.")
continue
if getattr(self.args, arg) != value:
print(
f"Overwriting default {arg} value {getattr(self.args, arg)} with value from checkpoint {value}.")
logger.warning(
f"Overwriting default {arg} value {getattr(self.args, arg)} with value from checkpoint {value}."
)
setattr(self.args, arg, value)
if hasattr(self.md, 'consumed_train_samples'):
self.args.consumed_train_samples = self.md.consumed_train_samples
self.args.consumed_valid_samples = self.md.consumed_valid_samples
print(f"Setting consumed_train_samples to {self.args.consumed_train_samples}"
f" and consumed_valid_samples to {self.args.consumed_valid_samples}")
logger.info(f"Setting consumed_train_samples to {self.args.consumed_train_samples} "
f"and consumed_valid_samples to {self.args.consumed_valid_samples}")
else:
print("consumed_train_samples not provided.")
logger.warning("consumed_train_samples not provided.")
def update_megatron_args_from_huggingface_config(self, hf_args):
if hf_args is None:
@ -397,18 +520,33 @@ class MegatronModel(ModelBase):
):
self.args.group_query_attention = True
self.args.num_query_groups = hf_args.num_key_value_heads
if hasattr(hf_args, 'num_local_experts'):
self.args.num_experts = hf_args.num_local_experts
def update_megatron_args_from_cmd_config(self):
def update_megatron_args_from_megatron_checkpoint(self, loader_megatron):
if not loader_megatron:
return
set_args(self.args)
self.args, self.args_megatron_checkpoint = load_args_from_checkpoint(self.args)
def update_megatron_args_from_cmd_config(self, loader_megatron):
self.args.w_pack = self.args_cmd.w_pack
self.args.add_qkv_bias = self.args_cmd.add_qkv_bias
self.args.add_dense_bias = self.args_cmd.add_dense_bias
self.args.tokenizer_model = self.args_cmd.tokenizer_model
self.args.make_vocab_size_divisible_by = self.args_cmd.make_vocab_size_divisible_by
self.args.tokenizer_model = getattr(self.args_cmd, 'tokenizer_model', None)
self.args.make_vocab_size_divisible_by = getattr(self.args_cmd, 'make_vocab_size_divisible_by', None)
if self.args_cmd.params_dtype == 'bf16':
self.args.bf16 = True
elif self.args_cmd.params_dtype == 'fp16':
self.args.fp16 = True
if self.args_cmd.add_dense_bias:
self.args.skip_bias_add = False
if loader_megatron:
self.args.lora_target_modules = self.args_cmd.lora_target_modules
self.args.lora_load = self.args_cmd.lora_load
self.args.lora_r = self.args_cmd.lora_r
self.args.lora_alpha = self.args_cmd.lora_alpha
# Determine how to make our models.
if not self.args_cmd.model_type == 'GPT':
raise ValueError("Llama-2 is a GPT model.")
@ -416,34 +554,113 @@ class MegatronModel(ModelBase):
if self.md and self.args_cmd.num_layer_list:
self.args.num_layer_list = self.args_cmd.num_layer_list
def set_megatron_parallel_state(self):
self.set_tensor_model_parallel_world_size(self.args.tensor_model_parallel_size)
self.set_pipeline_model_parallel_world_size(self.args.pipeline_model_parallel_size)
self.set_virtual_pipeline_model_parallel_world_size(self.args.virtual_pipeline_model_parallel_size)
def set_megatron_parallel_state(self, saver_megatron):
if saver_megatron:
self.set_tensor_model_parallel_world_size(self.args_cmd.target_tensor_parallel_size)
self.set_expert_model_parallel_world_size(self.args_cmd.target_expert_parallel_size)
self.set_pipeline_model_parallel_world_size(self.args_cmd.target_pipeline_parallel_size)
if self.args_cmd.num_layers_per_virtual_pipeline_stage:
vp_size = (self.args.num_layers //
self.args_cmd.target_pipeline_parallel_size //
self.args_cmd.num_layers_per_virtual_pipeline_stage)
self.set_virtual_pipeline_model_parallel_world_size(vp_size)
else:
self.set_tensor_model_parallel_world_size(self.args.tensor_model_parallel_size)
self.set_pipeline_model_parallel_world_size(self.args.pipeline_model_parallel_size)
self.set_virtual_pipeline_model_parallel_world_size(self.args.virtual_pipeline_model_parallel_size)
# Get first pipe stage.
self.set_tensor_model_parallel_rank(0)
self.set_pipeline_model_parallel_rank(0)
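A one-line worked example of the virtual-pipeline size computed above, under assumed sizes (not taken from any particular checkpoint): 32 layers, target pipeline parallel size 4, and 2 layers per virtual stage give a virtual pipeline size of 4.

# Illustrative only: vp_size = num_layers // target_pp_size // layers_per_virtual_stage
num_layers, target_pp_size, layers_per_virtual_stage = 32, 4, 2
vp_size = num_layers // target_pp_size // layers_per_virtual_stage
print(vp_size)  # 4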
def get_modules_from_config(self, count=1, pre_process=True, post_process=True):
self.args.model_type = ModelType.encoder_or_decoder
self.module = [model_provider(pre_process, post_process).to(self.args.params_dtype) for _ in range(count)]
def get_modules_from_config(self, pp_stage_cache_flag=False):
self.__get_modules(pp_stage_cache_flag=pp_stage_cache_flag)
def check_for_args(self, queue):
check_args_list = {'tensor_model_parallel_size': None, 'pipeline_model_parallel_size': None,
'num_layers': None, 'hidden_size': None, 'seq_length': None,
'num_attention_heads': None, 'max_position_embeddings': None,
'position_embedding_type': None, 'tokenizer_type': None, 'iteration': 1,
'bert_binary_head': None, 'disable_bias_linear': False, 'params_dtype': None,
'swiglu': False}
def get_modules_from_pretrained(self, pp_stage_cache_flag=False):
self.__get_modules(from_pretrained=True, pp_stage_cache_flag=pp_stage_cache_flag)
def __get_modules(self, from_pretrained=False, pp_stage_cache_flag=False):
if self.args.num_experts:
tensor_parallel.model_parallel_cuda_manual_seed(123)
# Build the nested module container indexed by [vp_rank][ep_rank][tp_rank].
pp_rank = self.get_pipeline_model_parallel_rank()
if pp_stage_cache_flag and pp_rank < len(self.pp_stage_cache):
self.module = self.pp_stage_cache[pp_rank]
return
virtual_pipeline_model_parallel_size = self.args.virtual_pipeline_model_parallel_size
if virtual_pipeline_model_parallel_size is None:
virtual_pipeline_model_parallel_size = 1
models = [
[
[
None for _ in range(self.args.tensor_model_parallel_size)
]
for _ in range(self.args.expert_model_parallel_size)
]
for _ in range(virtual_pipeline_model_parallel_size)
]
for ep_rank in range(self.args.expert_model_parallel_size):
if self.args.expert_model_parallel_size > 1:
self.set_expert_model_parallel_rank(ep_rank)
for tp_rank in range(self.args.tensor_model_parallel_size):
self.set_tensor_model_parallel_rank(tp_rank)
if self.args.virtual_pipeline_model_parallel_size is not None:
model_ = []
for vp_rank in range(self.args.virtual_pipeline_model_parallel_size):
self.set_virtual_pipeline_model_parallel_rank(vp_rank)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
expert_parallel_size = mpu.get_expert_model_parallel_world_size()
this_model = model_provider(
pre_process=pre_process,
post_process=post_process
).to(self.args.params_dtype)
model_.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model_ = [model_provider(pre_process, post_process).to(self.args.params_dtype)]
self.args.consumed_train_samples = 0
self.args.consumed_valid_samples = 0
if from_pretrained:
load_checkpoint(model_, None, None)
for vp_rank in range(virtual_pipeline_model_parallel_size):
models[vp_rank][ep_rank][tp_rank] = model_[vp_rank]
if self.args.lora_target_modules and from_pretrained:
if virtual_pipeline_model_parallel_size > 1:
raise AssertionError("Virtual pipeline and LoRA weight merging "
"are not supported simultaneously")
models[vp_rank][ep_rank][tp_rank].merge_and_unload()
self.module = models
if pp_stage_cache_flag:
self.pp_stage_cache.append(models)
def check_for_args(self, queue, saver_megatron):
if saver_megatron:
return
check_args_list = {
'tensor_model_parallel_size': None, 'pipeline_model_parallel_size': None, 'num_layers': None,
'hidden_size': None, 'seq_length': None, 'num_attention_heads': None, 'max_position_embeddings': None,
'position_embedding_type': None, 'tokenizer_type': None, 'iteration': 1, 'bert_binary_head': None,
'disable_bias_linear': False, 'params_dtype': None, 'swiglu': False
}
# if hasattr(self.args, 'add_bias_linear'):
# check_args_list['disable_bias_linear'] = self.args.add_bias_linear
def check_for_arg(arg_name, default=None):
if getattr(self.args, arg_name, None) is None:
if default is not None:
setattr(self.args, arg_name, default)
elif queue is not None:
print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
print(f"Arguments: {self.args}")
logger.error(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
logger.info(f"Arguments: {self.args}")
queue.put("exit")
exit(1)
@ -465,8 +682,14 @@ class MegatronModel(ModelBase):
'--no-save-rng',
'--no-initialization',
'--save-interval', '1',
'--load', self.args_cmd.load_dir
'--mock-data', # To pass the "blend data checks" in arguments.py
'--load', self.args_cmd.load_dir,
'--finetune',
# '--disable-bias-linear'
]
if hasattr(self.args_cmd, 'add_bias_linear') and not self.args_cmd.add_bias_linear:
sys_argv.append('--disable-bias-linear')
if self.args_cmd.use_mcore_models:
sys_argv.append('--use-mcore-models')
@ -484,9 +707,18 @@ class MegatronModel(ModelBase):
'--tokenizer-type', str(self.md.tokenizer_type),
'--tensor-model-parallel-size', str(self.args_cmd.target_tensor_parallel_size),
'--pipeline-model-parallel-size', str(self.args_cmd.target_pipeline_parallel_size),
'--expert-model-parallel-size', str(self.args_cmd.target_expert_parallel_size),
'--save', self.args_cmd.save_dir
])
if self.args_cmd.num_layers_per_virtual_pipeline_stage:
sys_argv.extend(['--num-layers-per-virtual-pipeline-stage',
str(self.args_cmd.num_layers_per_virtual_pipeline_stage)])
num_experts = getattr(self.md.checkpoint_args, 'num_experts', None)
if self.args_cmd.target_tensor_parallel_size > 1 and num_experts is not None and num_experts > 1:
sys_argv.append('--sequence-parallel')
if self.md.make_vocab_size_divisible_by is not None:
sys_argv.extend(['--make-vocab-size-divisible-by', str(self.md.make_vocab_size_divisible_by)])
if self.md.params_dtype == torch.float16:
@ -504,10 +736,22 @@ class MegatronModel(ModelBase):
return sys_argv
def get_model_item(self, **kwargs):
self.update_kwargs_idx(**kwargs)
_module = self.module
for key in self.kwargs_idx:
if "rank" in key:
_module = _module[self.kwargs_idx[key]]
return _module
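For orientation, a sketch of how the nesting built by __get_modules is indexed here: the keys containing "rank" are walked in the order they appear in kwargs_idx (vp_rank, then ep_rank, then tp_rank), matching models[vp_rank][ep_rank][tp_rank] above. Toy placeholder strings stand in for the real modules:

# Illustrative only: indexing a [vp][ep][tp] nesting in the same order get_model_item does.
from collections import OrderedDict

vp_size, ep_size, tp_size = 1, 2, 2
modules = [[[f"model(vp={v},ep={e},tp={t})" for t in range(tp_size)]
            for e in range(ep_size)]
           for v in range(vp_size)]

kwargs_idx = OrderedDict(vp_rank=0, ep_rank=1, tp_rank=1, layer_idx=0, expert_idx=0)
item = modules
for key in kwargs_idx:
    if "rank" in key:
        item = item[kwargs_idx[key]]
print(item)  # model(vp=0,ep=1,tp=1)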
@staticmethod
def set_tensor_model_parallel_world_size(tensor_model_parallel_size):
mpu.set_tensor_model_parallel_world_size(tensor_model_parallel_size)
@staticmethod
def set_expert_model_parallel_world_size(expert_model_parallel_size):
mpu.set_expert_model_parallel_world_size(expert_model_parallel_size)
@staticmethod
def set_pipeline_model_parallel_world_size(pipeline_model_parallel_size):
mpu.set_pipeline_model_parallel_world_size(pipeline_model_parallel_size)
@ -524,6 +768,18 @@ class MegatronModel(ModelBase):
def set_pipeline_model_parallel_rank(pipeline_model_parallel_rank):
mpu.set_pipeline_model_parallel_rank(pipeline_model_parallel_rank)
@staticmethod
def set_expert_model_parallel_rank(expert_model_parallel_rank):
mpu.set_expert_model_parallel_rank(expert_model_parallel_rank)
@staticmethod
def set_virtual_pipeline_model_parallel_rank(virtual_pipeline_model_parallel_rank):
mpu.set_virtual_pipeline_model_parallel_rank(virtual_pipeline_model_parallel_rank)
@staticmethod
def get_pipeline_model_parallel_rank():
return mpu.get_pipeline_model_parallel_rank()
class MegatronLegacyModel(MegatronModel):
def __init__(self, args_cmd, md=None):
@ -535,31 +791,33 @@ class MegatronMCoreModel(MegatronModel):
super(MegatronMCoreModel, self).__init__(args_cmd, md)
def get_module_mapping(self):
module_tp_rank = "module[tp_rank]."
module_layer = module_tp_rank + "decoder.layers[layer_idx]."
module_layer = "decoder.layers[layer_idx]."
self.module_mapping = {
"embedding": module_tp_rank + "embedding",
"embedding_word_embeddings": module_tp_rank + "embedding.word_embeddings",
"embedding_word_embeddings_norm": module_tp_rank + "embedding.word_embeddings.norm",
"model": "module[tp_rank]",
"embedding": "embedding",
"embedding_word_embeddings": "embedding.word_embeddings",
"embedding_word_embeddings_norm": "embedding.word_embeddings.norm",
"embedding_position_embeddings": "embedding.position_embeddings",
"model": "module",
"layers_input_layernorm": module_layer + "input_layernorm",
"layers": module_tp_rank + "decoder.layers",
"layers": "decoder.layers",
"layers_self_attention_linear_proj": module_layer + "self_attention.linear_proj",
"layers_self_attention_linear_qkv": module_layer + "self_attention.linear_qkv",
"layers_self_attention_q_layernorm": module_layer + "self_attention.q_layernorm",
"layers_self_attention_k_layernorm": module_layer + "self_attention.k_layernorm",
"layers_self_attention_post_attention_layernorm": module_layer + "pre_mlp_layernorm",
"layers_self_attention_pre_mlp_layernorm": module_layer + "pre_mlp_layernorm",
"layers_mlp_linear_fc1": module_layer + "mlp.linear_fc1",
"layers_mlp_linear_fc2": module_layer + "mlp.linear_fc2",
"final_layernorm": module_tp_rank + "decoder.final_layernorm",
"output_layer": module_tp_rank + "output_layer"
"layers_self_attention_post_mlp_layernorm": module_layer + "post_mlp_layernorm",
"final_layernorm": "decoder.final_layernorm",
"output_layer": "output_layer"
}
def has_embedding_word_embeddings_norm(self):
return hasattr(self.get_embedding_word_embeddings_module(), 'norm')
def has_embedding_position_embeddings(self):
return hasattr(self.get_embedding_module(), 'position_embeddings')
config_value = self.model_cfg.get(self.args_cmd.model_type_hf).get('config_set_value')
if config_value.get('moe_flag', False):
self.module_mapping["layers_mlp_router"] = module_layer + "mlp.router"
self.module_mapping["layers_mlp_linear_fc1"] = module_layer + "mlp.experts.local_experts[expert_idx].linear_fc1"
self.module_mapping["layers_mlp_linear_fc2"] = module_layer + "mlp.experts.local_experts[expert_idx].linear_fc2"
def get_megatron_model(args_cmd, md=None):
@ -571,5 +829,3 @@ def get_megatron_model(args_cmd, md=None):
def get_huggingface_model(args_cmd):
return HuggingfaceModel(args_cmd)

View File

@ -16,13 +16,15 @@
import os
import sys
import time
import copy
import logging as logger
import torch
from megatron.training.checkpointing import save_checkpoint
from megatron.core import mpu
from models import get_megatron_model
logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)
def add_arguments(parser):
group = parser.add_argument_group(title='Megatron saver')
@ -30,21 +32,25 @@ def add_arguments(parser):
group.add_argument('--megatron-path', type=str, default=None,
help='Base directory of Megatron repository')
group.add_argument('--target-tensor-parallel-size', type=int,
group.add_argument('--target-tensor-parallel-size', type=int, default=1,
help='Target tensor model parallel size, defaults to the tensor parallel size '
'in the input checkpoint if provided by the loader, otherwise to 1')
group.add_argument('--target-pipeline-parallel-size', type=int,
group.add_argument('--target-pipeline-parallel-size', type=int, default=1,
help='Target pipeline model parallel size, defaults to the pipeline parallel size '
'in the input checkpoint if provided by the loader, otherwise to 1')
group.add_argument('--save-model-type', type=str, default='megatron',
choices=['megatron'], help='Save model type')
choices=['megatron', 'huggingface'], help='Save model type')
group.add_argument("--w-pack", type=bool,
help='True is w_pack weight for llm',
default=False)
group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
help='Number of layers per virtual pipeline stage')
group.add_argument('--target-expert-parallel-size', type=int, default=1,
help='Target expert model parallel size')
group.add_argument('--num-layer-list',
type=str, help='a list of the number of layers per pipeline stage, separated by commas; e.g., 4,4,4,4')
group.add_argument('--use-mcore-models', action='store_true',
help='Use the implementation from megatron core')
def update_padded_vocab_size(md, model_mg, orig_tensor, orig_word_embed):
@ -55,7 +61,7 @@ def update_padded_vocab_size(md, model_mg, orig_tensor, orig_word_embed):
padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)
model_mg.set_padded_vocab_size(padded_vocab_size)
else:
print("Original vocab size not specified, leaving embedding table as-is. "
logger.warning("Original vocab size not specified, leaving embedding table as-is. "
"If you've changed the tensor parallel size this could cause problems.")
model_mg.set_padded_vocab_size(orig_word_embed.shape[0])
margs = model_mg.get_args()
@ -90,56 +96,254 @@ def reset_cmd_args_from_md(args, md):
if hasattr(md, 'previous_tensor_parallel_size'):
args.target_tensor_parallel_size = md.previous_tensor_parallel_size
else:
print(
"loader did not provide a tensor parallel size and --target-tensor-parallel-size not provided on command line. "
"Default to 1.")
logger.warning("loader did not provide a tensor parallel size and "
"--target-tensor-parallel-size not provided on command line. Default to 1.")
args.target_tensor_parallel_size = 1
if args.target_pipeline_parallel_size is None:
if hasattr(md, 'previous_pipeline_parallel_size'):
args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
else:
print(
"loader did not provide a pipeline parallel size and --target-pipeline-parallel-size not provided on command line. "
"Default to 1.")
logger.warning(
"loader did not provide a pipeline parallel size and "
"--target-pipeline-parallel-size not provided on command line. Default to 1.")
args.target_pipeline_parallel_size = 1
def set_model_preprocess(model, embeddings_msg, check_message):
def set_model_preprocess(model, embeddings_msg):
md = model.get_metadata()
margs = model.get_args()
pos_embed = None
tp_size = margs.tensor_model_parallel_size
ep_size = margs.expert_model_parallel_size
if md.position_embedding_type == 'learned_absolute':
pos_embed = embeddings_msg.pop("position embeddings")
orig_word_embed = embeddings_msg.pop("word embeddings")
pos_embed = embeddings_msg.pop(f"position embeddings")
orig_word_embed = embeddings_msg.pop(f"word embeddings")
orig_word_embed_n_w, orig_word_embed_n_b = None, None
if "word embeddings norm_w" in embeddings_msg and "word embeddings norm_b" in embeddings_msg:
orig_word_embed_n_w = embeddings_msg.pop("word embeddings norm_w")
orig_word_embed_n_b = embeddings_msg.pop("word embeddings norm_b")
check_message(embeddings_msg)
if md.true_vocab_size is not None:
orig_vocab_size = orig_word_embed.shape[0]
full_word_embed = vocab_padding(orig_vocab_size, margs.padded_vocab_size, orig_word_embed)
else:
full_word_embed = orig_word_embed
# Split across the new tensor model parallel size
out_word_embed = torch.chunk(full_word_embed, margs.tensor_model_parallel_size, dim=0)
modules_count = model.get_modules_count()
for tp_rank in range(modules_count):
model.set_embedding_word_embeddings_weight(tp_rank=tp_rank, data=out_word_embed[tp_rank])
if orig_word_embed_n_w is not None:
model.set_embedding_word_embeddings_norm_weight(tp_rank=tp_rank, data=orig_word_embed_n_w)
model.set_embedding_word_embeddings_norm_bias(tp_rank=tp_rank, data=orig_word_embed_n_b)
if pos_embed is not None:
model.set_embedding_position_embeddings_weight(tp_rank=tp_rank, data=pos_embed)
if "word embeddings norm_w" in embeddings_msg:
orig_word_embed_n_w = embeddings_msg.pop(f"word embeddings norm_w")
if "word embeddings norm_b" in embeddings_msg:
orig_word_embed_n_b = embeddings_msg.pop(f"word embeddings norm_b")
out_word_embed_list = []
for ep_rank in range(ep_size):
if md.true_vocab_size is not None:
orig_vocab_size = orig_word_embed.shape[0]
full_word_embed = vocab_padding(orig_vocab_size, margs.padded_vocab_size, orig_word_embed)
else:
if hasattr(model.get_embedding_module(), 'position_embeddings'):
raise ValueError("model should have position_embeddings")
full_word_embed = orig_word_embed
return out_word_embed
# Split across the new tensor model parallel size
out_word_embed = torch.chunk(full_word_embed, margs.tensor_model_parallel_size, dim=0)
for tp_rank in range(tp_size):
model.set_embedding_word_embeddings_weight(ep_rank=ep_rank, tp_rank=tp_rank, data=out_word_embed[tp_rank])
if orig_word_embed_n_w is not None:
model.set_embedding_word_embeddings_norm_weight(ep_rank=ep_rank, tp_rank=tp_rank, data=orig_word_embed_n_w)
if orig_word_embed_n_b is not None:
model.set_embedding_word_embeddings_norm_bias(ep_rank=ep_rank, tp_rank=tp_rank, data=orig_word_embed_n_b)
if pos_embed is not None:
model.set_embedding_position_embeddings_weight(ep_rank=ep_rank, tp_rank=tp_rank, data=pos_embed)
else:
if hasattr(model.get_embedding_module(), 'position_embeddings'):
raise ValueError("model should have position_embeddings")
out_word_embed_list.append(out_word_embed)
return out_word_embed_list
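
A small numeric sketch of the padding-then-chunking step above, assuming a toy vocab of 10 tokens padded to 12 and tensor_model_parallel_size=2; pad_vocab is a toy stand-in for the real vocab_padding helper:

import torch

# Toy stand-in for vocab_padding(): pad rows with zeros up to padded_vocab_size.
def pad_vocab(word_embed, padded_vocab_size):
    pad_rows = padded_vocab_size - word_embed.shape[0]
    return torch.cat([word_embed, torch.zeros(pad_rows, word_embed.shape[1])], dim=0)

orig_word_embed = torch.randn(10, 4)                     # true_vocab_size=10, hidden=4
full_word_embed = pad_vocab(orig_word_embed, 12)
out_word_embed = torch.chunk(full_word_embed, 2, dim=0)  # one shard per TP rank
print([t.shape for t in out_word_embed])                 # [torch.Size([6, 4]), torch.Size([6, 4])]
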
def set_model_layer_norm(model_mg, msg, md, **kwargs):
# duplicated tensors
input_norm_weight = msg.pop("input norm weight")
post_norm_weight = msg.pop("post norm weight")
input_norm_bias = None
post_norm_bias = None
if md.norm_has_bias:
input_norm_bias = msg.pop("input norm bias")
if md.norm_has_bias:
post_norm_bias = msg.pop("post norm bias")
margs = model_mg.get_args()
# Save them to the model
for ep_rank in range(margs.expert_model_parallel_size):
kwargs["ep_rank"] = ep_rank
for tp_rank in range(margs.tensor_model_parallel_size):
kwargs["tp_rank"] = tp_rank
model_mg.set_layers_input_layernorm_weight(**kwargs, data=input_norm_weight)
if input_norm_bias is not None:
model_mg.set_layers_input_layernorm_bias(**kwargs, data=input_norm_bias)
model_mg.set_layers_self_attention_pre_mlp_layernorm_weight(**kwargs, data=post_norm_weight)
if post_norm_bias is not None:
model_mg.set_layers_self_attention_pre_mlp_layernorm_bias(**kwargs, data=post_norm_bias)
def set_model_layer_attn(model_mg, msg, md, **kwargs):
# duplicated tensors
margs = model_mg.get_args()
if md.linear_bias or margs.add_dense_bias:
dense_bias = msg.pop("dense bias")
if md.linear_bias or margs.add_qkv_bias:
qkv_bias = torch.chunk(msg.pop("qkv bias"), margs.tensor_model_parallel_size, dim=0)
qkv_org = msg.pop("qkv weight")
qkv_weight = torch.chunk(qkv_org, margs.tensor_model_parallel_size, dim=0)
# Split up the parallel tensors
dense_weight = torch.chunk(msg.pop("dense weight"), margs.tensor_model_parallel_size, dim=1)
# Save them to the model
for ep_rank in range(margs.expert_model_parallel_size):
kwargs["ep_rank"] = ep_rank
for tp_rank in range(margs.tensor_model_parallel_size):
kwargs["tp_rank"] = tp_rank
model_mg.set_layers_self_attention_linear_qkv_weight(**kwargs, data=qkv_weight[tp_rank])
model_mg.set_layers_self_attention_linear_proj_weight(**kwargs, data=dense_weight[tp_rank])
if md.linear_bias:
model_mg.set_layers_self_attention_linear_qkv_bias(**kwargs, data=qkv_bias[tp_rank])
model_mg.set_layers_self_attention_linear_proj_bias(**kwargs, data=dense_bias)
if margs.add_qkv_bias:
model_mg.set_layers_self_attention_linear_qkv_bias(**kwargs, data=qkv_bias[tp_rank])
if margs.add_dense_bias:
model_mg.set_layers_self_attention_linear_proj_bias(**kwargs, data=dense_bias)
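
A shape-only sketch of why the fused QKV weight is chunked along dim 0 (column parallel, output features sharded) while the projection weight is chunked along dim 1 (row parallel, input features sharded); the sizes are illustrative:

import torch

tp = 2
hidden = 8
qkv_out = 3 * hidden                               # fused q, k, v output features

qkv_weight = torch.randn(qkv_out, hidden)          # column-parallel: shard the output dim
dense_weight = torch.randn(hidden, hidden)         # row-parallel: shard the input dim

qkv_shards = torch.chunk(qkv_weight, tp, dim=0)    # each rank: (12, 8)
proj_shards = torch.chunk(dense_weight, tp, dim=1) # each rank: (8, 4)
print(qkv_shards[0].shape, proj_shards[0].shape)
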
def _set_set_model_layer_mlp(model_mg, msg, md, **kwargs):
margs = model_mg.get_args()
num_experts_local = 1
if margs.num_experts:
num_experts_local = margs.num_experts // margs.expert_model_parallel_size
# Save them to the model
if md.linear_bias:
mlp_l1_bias = msg.pop(f"mlp l1 bias")
# Split up the parallel tensors
mlp_l1_weight = torch.chunk(msg.pop(f"mlp l1 weight"), margs.tensor_model_parallel_size, dim=1)
# Special handling for swiglu
if md.swiglu:
mlp_l0_weight_W = torch.chunk(msg.pop(f"mlp l0 weight W"), margs.tensor_model_parallel_size, dim=0)
mlp_l0_weight_V = torch.chunk(msg.pop(f"mlp l0 weight V"), margs.tensor_model_parallel_size, dim=0)
mlp_l0_weight = [torch.cat(weights, dim=0) for weights in zip(mlp_l0_weight_W, mlp_l0_weight_V)]
else:
mlp_l0_weight = torch.chunk(msg.pop(f"mlp l0 weight"), margs.tensor_model_parallel_size, dim=0)
if md.linear_bias:
if md.swiglu:
mlp_l0_bias_W = torch.chunk(msg.pop(f"mlp l0 bias W"), margs.tensor_model_parallel_size, dim=0)
mlp_l0_bias_V = torch.chunk(msg.pop(f"mlp l0 bias V"), margs.tensor_model_parallel_size, dim=0)
mlp_l0_bias = [torch.cat(bias, dim=0) for bias in zip(mlp_l0_bias_W, mlp_l0_bias_V)]
else:
mlp_l0_bias = torch.chunk(msg.pop(f"mlp l0 bias"), margs.tensor_model_parallel_size, dim=0)
# duplicated tensors
for tp_rank in range(margs.tensor_model_parallel_size):
kwargs["tp_rank"] = tp_rank
model_mg.set_layers_mlp_linear_fc1_weight(**kwargs, data=mlp_l0_weight[tp_rank])
model_mg.set_layers_mlp_linear_fc2_weight(**kwargs, data=mlp_l1_weight[tp_rank])
if md.linear_bias:
model_mg.set_layers_mlp_linear_fc1_bias(**kwargs, data=mlp_l0_bias[tp_rank])
model_mg.set_layers_mlp_linear_fc2_bias(**kwargs, data=mlp_l1_bias)
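
A small sketch of the swiglu handling above: the gate (W) and up (V) halves are chunked separately and re-concatenated per rank, so every TP shard keeps a matching W/V pair; the sizes are illustrative:

import torch

tp = 2
hidden, ffn = 4, 8
mlp_l0_weight_W = torch.randn(ffn, hidden)   # gate projection
mlp_l0_weight_V = torch.randn(ffn, hidden)   # up projection

w_chunks = torch.chunk(mlp_l0_weight_W, tp, dim=0)
v_chunks = torch.chunk(mlp_l0_weight_V, tp, dim=0)
# Each rank gets [W_chunk; V_chunk] stacked on dim 0, i.e. shape (2 * ffn / tp, hidden).
mlp_l0_weight = [torch.cat(pair, dim=0) for pair in zip(w_chunks, v_chunks)]
print([w.shape for w in mlp_l0_weight])      # [torch.Size([8, 4]), torch.Size([8, 4])]
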
def set_model_layer_mlp(model_mg, msg, md, **kwargs):
margs = model_mg.get_args()
if margs.num_experts:
num_experts_local = margs.num_experts // margs.expert_model_parallel_size
mlp_moe = msg.pop("mlp_moe")
mlp_router_weight = mlp_moe.pop("mlp router weight")
for ep_rank in range(margs.expert_model_parallel_size):
kwargs["ep_rank"] = ep_rank
for tp_rank in range(margs.tensor_model_parallel_size):
kwargs['tp_rank'] = tp_rank
model_mg.set_layers_mlp_router_weight(**kwargs, data=mlp_router_weight)
for expert_idx in range(num_experts_local):
kwargs["expert_idx"] = expert_idx
global_expert_idx = expert_idx + ep_rank * num_experts_local
expert = mlp_moe.pop(f"expert {global_expert_idx}")
_set_set_model_layer_mlp(model_mg, expert, md, **kwargs)
else:
_set_set_model_layer_mlp(model_mg, msg, md, **kwargs)
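
A quick check of the global-to-local expert mapping used above, assuming num_experts=8 and expert_model_parallel_size=2 (so num_experts_local=4):

num_experts = 8
expert_model_parallel_size = 2
num_experts_local = num_experts // expert_model_parallel_size  # 4

for ep_rank in range(expert_model_parallel_size):
    for expert_idx in range(num_experts_local):
        global_expert_idx = expert_idx + ep_rank * num_experts_local
        # ep_rank 0 owns experts 0-3, ep_rank 1 owns experts 4-7
        print(ep_rank, expert_idx, "->", global_expert_idx)
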
def set_model_postprocess(model_mg, msg, md, out_word_embed_list, **kwargs):
margs = model_mg.get_args()
tp_size = margs.tensor_model_parallel_size
ep_size = margs.expert_model_parallel_size
final_norm_weight = msg.pop(f"weight")
final_norm_bias = None
if md.norm_has_bias:
final_norm_bias = msg.pop(f"bias")
for ep_rank in range(ep_size):
kwargs["ep_rank"] = ep_rank
for tp_rank in range(tp_size):
kwargs["tp_rank"] = tp_rank
model_mg.set_final_layernorm_weight(**kwargs, data=final_norm_weight)
if final_norm_bias is not None:
model_mg.set_final_layernorm_bias(**kwargs, data=final_norm_bias)
if kwargs.get("pp_rank", 0) and not md.output_layer:
# Copy word embeddings to final pipeline rank
model_mg.set_output_layer_weight(**kwargs, data=out_word_embed_list[ep_rank][tp_rank])
del final_norm_weight
if final_norm_bias is not None:
del final_norm_bias
def set_model_output_layer(model_mg, msg, md, **kwargs):
margs = model_mg.get_args()
tp_size = margs.tensor_model_parallel_size
ep_size = margs.expert_model_parallel_size
output_layer = msg.pop(f"weight")
for ep_rank in range(ep_size):
kwargs["ep_rank"] = ep_rank
if md.true_vocab_size is not None:
orig_vocab_size = output_layer.shape[0]
full_word_embed = vocab_padding(orig_vocab_size, margs.padded_vocab_size, output_layer)
else:
full_word_embed = output_layer
output_layer_weight = torch.chunk(full_word_embed, margs.tensor_model_parallel_size, dim=0)
for tp_rank in range(tp_size):
kwargs["tp_rank"] = tp_rank
model_mg.set_output_layer_weight(**kwargs, data=output_layer_weight[tp_rank])
def save_model(model_mg, md, **kwargs):
margs = model_mg.get_args()
args_cmd = model_mg.get_args_cmd()
virtual_pipeline_model_parallel_size = margs.virtual_pipeline_model_parallel_size
if virtual_pipeline_model_parallel_size is None:
virtual_pipeline_model_parallel_size = 1
for ep_rank in range(margs.expert_model_parallel_size):
model_mg.set_expert_model_parallel_rank(ep_rank)
kwargs["ep_rank"] = ep_rank
for tp_rank in range(margs.tensor_model_parallel_size):
model_mg.set_tensor_model_parallel_rank(tp_rank)
kwargs["tp_rank"] = tp_rank
vp_models = []
for vp_rank in range(virtual_pipeline_model_parallel_size):
kwargs["vp_rank"] = vp_rank
vp_models.append(model_mg.get_model_item(**kwargs))
if args_cmd.save_model_type == 'megatron':
# vp_models already holds one sub-model per virtual pipeline stage for this (ep, tp) rank
save_checkpoint(md.iteration, vp_models, None, None, 0)
elif args_cmd.save_model_type == "huggingface":
save_huggingface(args_cmd, model_mg)
def save_huggingface(args, model):
'''Convert the Megatron weights to a Huggingface model and save it.'''
from models import get_huggingface_model
model_hf = get_huggingface_model(args)
model_hf.get_modules_from_pretrained()
args_cmd = model_hf.get_args_cmd()
model_hf.update_module(model)
save_dir = os.path.join(args_cmd.save_dir, 'mg2hf')
logger.info(f'save weight to {save_dir}')
model_hf.get_model_item().save_pretrained(save_dir)
def save_model_checkpoint(queue, args):
@ -154,14 +358,14 @@ def save_model_checkpoint(queue, args):
def queue_get(name=None):
val = queue.get()
if val == "exit":
print("Loader exited, exiting saver")
logger.error("Loader exited, exiting saver")
exit(1)
if name is not None and args.checking and val["name"] != name:
val_name = val["name"]
print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
logger.error(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
exit(1)
if name is not None:
print(f"received {name}")
logger.info(f"received {name}")
return val
def check_message(msg):
@ -169,10 +373,10 @@ def save_model_checkpoint(queue, args):
return
msg_name = msg.pop("name")
if len(msg.keys()) > 0:
print(f"Unexpected values in {msg_name}:")
logger.error(f"Unexpected values in {msg_name}:")
for key in msg.keys():
print(f" {key}")
print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
logger.error(f" {key}")
logger.error(f"Exiting. If you want to ignore this, use the argument --no-checking.")
exit(1)
md = queue_get()
@ -185,182 +389,59 @@ def save_model_checkpoint(queue, args):
# We want all arguments to come from us
model_mg = get_megatron_model(args_cmd=args, md=md)
model_mg.initialize_megatron_args(queue=queue)
model_mg.initialize_megatron_args(queue=queue, saver_megatron=True)
# Make models for first pipeline stage and fill in embeddings
mpu.set_pipeline_model_parallel_rank(0)
post_process = args.target_pipeline_parallel_size == 1
model_mg.get_modules_from_config(args.target_tensor_parallel_size, pre_process=True, post_process=post_process)
model_mg.get_modules_from_config(pp_stage_cache_flag=True)
# Embeddings
embeddings_msg = queue_get("embeddings")
out_word_embed = set_model_preprocess(model_mg, embeddings_msg, check_message)
out_word_embed_list = set_model_preprocess(model_mg, embeddings_msg)
check_message(embeddings_msg)
margs = model_mg.get_args()
# Transformer layers
# -------------------
total_layer_num = 0
lst = []
if args.num_layers_per_virtual_pipeline_stage and args.save_model_type == 'megatron':
times = 3
while queue.qsize() > 3 or times >= 0:
if times >= 0:
time.sleep(1)
times -= 1
continue
lst.append(queue.get())
for pp_rank in range(args.target_pipeline_parallel_size):
# For later pipeline parallel ranks, make the new models
if pp_rank > 0:
virtual_pipeline_model_parallel_size = margs.virtual_pipeline_model_parallel_size
if virtual_pipeline_model_parallel_size is None:
virtual_pipeline_model_parallel_size = 1
for vp_rank in range(virtual_pipeline_model_parallel_size):
model_mg.set_virtual_pipeline_model_parallel_rank(vp_rank)
kwargs = {"vp_rank": vp_rank}
for pp_rank in range(args.target_pipeline_parallel_size):
# For later pipeline parallel ranks, make the new models
mpu.set_pipeline_model_parallel_rank(pp_rank)
post_process = pp_rank == args.target_pipeline_parallel_size - 1
model_mg.get_modules_from_config(args.target_tensor_parallel_size, False, post_process)
if args.num_layers_per_virtual_pipeline_stage and args.save_model_type == 'megatron':
vp_size = margs.num_layers // args.target_pipeline_parallel_size // args.num_layers_per_virtual_pipeline_stage
else:
vp_size = 1
for vpp_rank in range(vp_size):
for layer in range(len(model_mg.get_layers_module()) // vp_size):
if args.num_layers_per_virtual_pipeline_stage and args.save_model_type == 'megatron':
# The execution order between layers in the VPP model is different from that in the PP model. Here,
# it is necessary to calculate the index and arrange the layers in the actual execution order
total_layer_num = args.target_pipeline_parallel_size * vpp_rank * args.num_layers_per_virtual_pipeline_stage + pp_rank * args.num_layers_per_virtual_pipeline_stage + layer
msg = lst[total_layer_num]
else:
msg = queue_get(f"transformer layer {total_layer_num}")
# duplicated tensors
input_norm_weight = msg.pop("input norm weight")
if md.norm_has_bias:
input_norm_bias = msg.pop("input norm bias")
post_norm_weight = msg.pop("post norm weight")
if md.norm_has_bias:
post_norm_bias = msg.pop("post norm bias")
if md.linear_bias:
dense_bias = msg.pop("dense bias")
mlp_l1_bias = msg.pop("mlp l1 bias")
if args.add_qkv_bias:
qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
if args.add_dense_bias:
dense_bias = msg.pop("dense bias")
qkv_org = msg.pop("qkv weight")
qkv_weight = torch.chunk(qkv_org, args.target_tensor_parallel_size, dim=0)
# Split up the parallel tensors
dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)
# Special handling for swiglu
if md.swiglu:
mlp_l0_weight_W = torch.chunk(msg.pop("mlp l0 weight W"), args.target_tensor_parallel_size, dim=0)
mlp_l0_weight_V = torch.chunk(msg.pop("mlp l0 weight V"), args.target_tensor_parallel_size, dim=0)
mlp_l0_weight = [torch.cat(weights, dim=0) for weights in zip(mlp_l0_weight_W, mlp_l0_weight_V)]
else:
mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
if md.linear_bias:
qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
if md.swiglu:
mlp_l0_bias_W = torch.chunk(msg.pop("mlp l0 bias W"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias_V = torch.chunk(msg.pop("mlp l0 bias V"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias = [torch.cat(bias, dim=0) for bias in zip(mlp_l0_bias_W, mlp_l0_bias_V)]
else:
mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
# Save them to the model
for tp_rank in range(args.target_tensor_parallel_size):
if args.num_layers_per_virtual_pipeline_stage and args.save_model_type == 'megatron':
l_idx = vpp_rank * args.num_layers_per_virtual_pipeline_stage + layer
else:
l_idx = layer
model_mg.set_layers_input_layernorm_weight(tp_rank=tp_rank, layer_idx=l_idx, data=input_norm_weight)
if md.norm_has_bias:
model_mg.set_layers_input_layernorm_bias(tp_rank=tp_rank, layer_idx=l_idx, data=input_norm_bias)
model_mg.set_layers_self_attention_linear_qkv_weight(tp_rank=tp_rank, layer_idx=l_idx, data=qkv_weight[tp_rank])
model_mg.set_layers_self_attention_linear_proj_weight(tp_rank=tp_rank, layer_idx=l_idx, data=dense_weight[tp_rank])
model_mg.set_layers_self_attention_pre_mlp_layernorm_weight(tp_rank=tp_rank, layer_idx=l_idx, data=post_norm_weight)
if md.norm_has_bias:
model_mg.set_layers_self_attention_pre_mlp_layernorm_bias(tp_rank=tp_rank, layer_idx=l_idx, data=post_norm_bias)
model_mg.set_layers_mlp_linear_fc1_weight(tp_rank=tp_rank, layer_idx=l_idx, data=mlp_l0_weight[tp_rank])
model_mg.set_layers_mlp_linear_fc2_weight(tp_rank=tp_rank, layer_idx=l_idx, data=mlp_l1_weight[tp_rank])
if md.linear_bias:
model_mg.set_layers_self_attention_linear_qkv_bias(tp_rank=tp_rank, layer_idx=l_idx, data=qkv_bias[tp_rank])
model_mg.set_layers_self_attention_linear_proj_bias(tp_rank=tp_rank, layer_idx=l_idx, data=dense_bias)
model_mg.set_layers_mlp_linear_fc1_bias(tp_rank=tp_rank, layer_idx=l_idx, data=mlp_l0_bias[tp_rank])
model_mg.set_layers_mlp_linear_fc2_bias(tp_rank=tp_rank, layer_idx=l_idx, data=mlp_l1_bias)
if args.add_qkv_bias:
model_mg.set_layers_self_attention_linear_qkv_bias(tp_rank=tp_rank, layer_idx=l_idx, data=qkv_bias[tp_rank])
if args.add_dense_bias:
model_mg.set_layers_self_attention_linear_proj_bias(tp_rank=tp_rank, layer_idx=l_idx, data=dense_bias)
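
A worked example of the VPP layer-index calculation in the comment above, assuming target_pipeline_parallel_size=2, num_layers_per_virtual_pipeline_stage=2 and 8 layers in total:

pp_size = 2
layers_per_vpp = 2
num_layers = 8
vp_size = num_layers // pp_size // layers_per_vpp   # 2 virtual stages per PP rank

for pp_rank in range(pp_size):
    for vpp_rank in range(vp_size):
        for layer in range(layers_per_vpp):
            # Index into the stream of messages, which arrive in VPP execution order.
            total_layer_num = (pp_size * vpp_rank * layers_per_vpp
                               + pp_rank * layers_per_vpp + layer)
            print(pp_rank, vpp_rank, layer, "->", total_layer_num)
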
model_mg.get_modules_from_config(pp_stage_cache_flag=True)
kwargs["pp_rank"] = pp_rank
for layer in range(len(model_mg.get_layers_module())):
kwargs["layer_idx"] = layer
msg = queue_get(f"transformer layer {total_layer_num}")
set_model_layer_norm(model_mg, msg, md, **kwargs)
set_model_layer_attn(model_mg, msg, md, **kwargs)
set_model_layer_mlp(model_mg, msg, md, **kwargs)
total_layer_num = total_layer_num + 1
check_message(msg)
if post_process:
msg = queue_get("final norm")
final_norm_weight = msg.pop("weight")
if md.norm_has_bias:
final_norm_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
model_mg.set_final_layernorm_weight(tp_rank=tp_rank, data=final_norm_weight)
if md.norm_has_bias:
model_mg.set_final_layernorm_bias(tp_rank=tp_rank, data=final_norm_bias)
if pp_rank != 0 and not md.output_layer:
# Copy word embeddings to final pipeline rank
model_mg.set_output_layer_weight(tp_rank=tp_rank, data=out_word_embed[tp_rank])
del final_norm_weight
if md.norm_has_bias:
del final_norm_bias
check_message(msg)
if md.output_layer:
msg = queue_get("output layer")
output_layer = msg.pop("weight")
if md.true_vocab_size is not None:
orig_vocab_size = output_layer.shape[0]
full_word_embed = vocab_padding(orig_vocab_size, margs.padded_vocab_size, output_layer)
else:
full_word_embed = output_layer
output_layer_weight = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
for tp_rank in range(args.target_tensor_parallel_size):
model_mg.set_output_layer_weight(tp_rank=tp_rank, data=output_layer_weight[tp_rank])
del output_layer_weight
post_process = (
(pp_rank == args.target_pipeline_parallel_size - 1) and
(vp_rank == virtual_pipeline_model_parallel_size - 1)
)
if post_process:
msg = queue_get("final norm")
set_model_postprocess(model_mg, msg, md, out_word_embed_list, **kwargs)
check_message(msg)
msg = queue_get()
for tp_rank in range(args.target_tensor_parallel_size):
mpu.set_tensor_model_parallel_rank(tp_rank)
# Split the PP into multiple VPPs and select the corresponding layers for each VPP by copying and deleting
if args.num_layers_per_virtual_pipeline_stage:
vp_models = []
layers = margs.num_layers // args.target_pipeline_parallel_size
for vp_rank in range(vp_size):
model = copy.deepcopy(model_mg.get_model_module(tp_rank=tp_rank))
left = vp_rank * args.num_layers_per_virtual_pipeline_stage
right = (vp_rank + 1) * args.num_layers_per_virtual_pipeline_stage
for i in range(layers - 1, -1, -1):
if i >= right or i < left:
del model.decoder.layers[i]
if right < layers and pp_rank == args.target_pipeline_parallel_size - 1:
del model.decoder.final_layernorm
if getattr(model, "output_layer", None):
model.post_process = False
del model.output_layer
if pp_rank == 0 and vp_rank > 0:
model.pre_process = False
del model.embedding
vp_models.append(model)
save_checkpoint(md.iteration, vp_models, None, None, 0)
else:
save_checkpoint(md.iteration, [model_mg.get_model_module(tp_rank=tp_rank)], None, None, 0)
print("Done!")
if md.output_layer:
msg = queue_get("output layer")
set_model_output_layer(model_mg, msg, md, **kwargs)
check_message(msg)
if vp_rank == virtual_pipeline_model_parallel_size - 1:
save_model(model_mg, md, **kwargs)
logger.info("Done!")