Mirror of https://gitee.com/ascend/ModelLink.git, synced 2024-12-05 05:17:40 +08:00
!248 LLama model weight conversion: support custom configurations and GQA
Merge pull request !248 from Liuchang/master
commit ac98107b42
@@ -262,7 +262,8 @@ def parse_args(extra_args_provider=None, defaults={},
    ensure_valid(args.is_instruction_dataset, 'Dynamic padding based on instruction dataset.')

    if (args.position_embedding_type == PositionEmbeddingType.absolute or
-            args.position_embedding_type == PositionEmbeddingType.alibi):
+            args.position_embedding_type == PositionEmbeddingType.alibi or
+            args.position_embedding_type == PositionEmbeddingType.rotary):
        ensure_var_is_not_none(args.max_position_embeddings)
        if not args.seq_length:
            ensure_valid(args.max_position_embeddings >= args.seq_length)

@@ -32,6 +32,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
       --hidden-size 8192 \
       --ffn-hidden-size 28672 \
       --num-attention-heads 64 \
       --position-embedding-type rotary \
       --group-query-attention \
       --num-query-groups 8 \
       --micro-batch-size 2 \

@@ -68,20 +68,36 @@ def pad_embed(w, make_vocab_size_divisible_by, tp, added_token_num):
    return torch.cat([w, w[-(padded_size - w.shape[0]):, ...]], dim=0)


-def permute_qkv_weight(w, n_head, hidden_size, tp, split=False):
-    """ adapt for ascendspeed llama qkv layer """
+def permute_qkv_weight(w, model_config, split=False):
+    """
+    adapt for ascendspeed llama qkv layer
+    Notation:
+        n_head: Number of attention heads,
+        kv_heads: Number of key and value heads,
+        tp: Tensor model parallel size,
+        np: Number of attention heads in per tensor partition,
+        gp: Number of key and value heads in per tensor partition,
+    """
+    n_head, hidden_size, tp, kv_heads = model_config
+    if kv_heads is None:
+        kv_heads = n_head
+
    check_divisible(n_head, tp)
    check_divisible(hidden_size, n_head)
+    check_divisible(kv_heads, tp)
+    check_divisible(n_head, kv_heads)
    np = n_head // tp
+    gp = kv_heads // tp
+    repeats = np // gp
    hn = hidden_size // n_head
    w_s0, w_s1 = w.shape
-    check_equal(w_s0, np * hn * 3)
+    check_equal(w_s0, (repeats + 2) * gp * hn)
    if not split:
-        return w.reshape(3, np, hn, w.shape[1]).contiguous().permute(1, 0, 2, 3).reshape(w_s0,
-                                                                                          w_s1).contiguous().clone()
+        return w.reshape(repeats + 2, gp, hn, w.shape[1]).contiguous().permute(
+            1, 0, 2, 3).reshape(w_s0, w_s1).contiguous().clone()
    else:
-        return w.reshape(np, 3, hn, w.shape[1]).contiguous().permute(1, 0, 2, 3).reshape(w_s0,
-                                                                                          w_s1).contiguous().clone()
+        return w.reshape(gp, repeats + 2, hn, w.shape[1]).contiguous().permute(
+            1, 0, 2, 3).reshape(w_s0, w_s1).contiguous().clone()


def permute_qkv_bias(bias, n_head, hidden_size, tp, split=False):

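For readers new to the GQA layout, the docstring notation maps to concrete shapes as follows. This is a minimal sketch with assumed example values (n_head=8, kv_heads=2, tp=2, hidden_size=64), mirroring the `not split` branch of the updated function; it is not code from this PR.

# Illustrative sketch only; all values are assumptions, not from the repository.
import torch

n_head, hidden_size, tp, kv_heads = 8, 64, 2, 2
np_ = n_head // tp          # "np" in the docstring: 4 query heads per tensor-parallel rank
gp = kv_heads // tp         # "gp": 1 key/value head per rank
repeats = np_ // gp         # 4 query heads share each key/value head
hn = hidden_size // n_head  # 8 rows per head

# Per-rank fused QKV weight: (repeats + 2) * gp * hn = 48 rows, laid out as [queries, keys, values].
w = torch.cat([torch.randn(np_ * hn, hidden_size),
               torch.randn(gp * hn, hidden_size),
               torch.randn(gp * hn, hidden_size)], dim=0)

# Same reshape/permute as the not-split branch: rows are regrouped per KV group
# (its query-head slices, then its key slice, then its value slice).
out = w.reshape(repeats + 2, gp, hn, w.shape[1]).permute(1, 0, 2, 3).reshape(w.shape[0], w.shape[1])
print(out.shape)  # torch.Size([48, 64])
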
@@ -46,10 +46,19 @@ def get_args():
    parser.add_argument("--pipeline-model-parallel-size", type=int, default=1,
                        help="degree of pipeline model parallel")
    parser.add_argument("--added-token-num", type=int, default=0, help="the number of added tokens")
-    parser.add_argument("--type", type=str, choices=["7B", "13B", "30B", "65B"], default="7B")
+    parser.add_argument("--type", type=str, default="7B", help="There are four predefined types: [7B, 13B, 30B, 65B]")
+    parser.add_argument("--num_layers", type=int, default=1,
+                        help="num layers")
+    parser.add_argument("--num_heads", type=int, default=1,
+                        help="num heads")
+    parser.add_argument("--num_kv_heads", type=int, default=None,
+                        help="num kv heads")
+    parser.add_argument("--hidden_size", type=int, default=1,
+                        help="hidden size")
    parser.add_argument("--bias", action="store_true", default=False)
    parser.add_argument("--deepspeed", action="store_true", default=False)
+    parser.add_argument("--merge-mlp", action="store_true", default=False,
+                        help="Merge gate and up mlp")
    parser.add_argument("--pse", action="store_true", default=False)
    parser.add_argument("--use_wpack_rotray", action="store_true", default=False)
    parser.add_argument("--load_weight_map", action="store_true", default=False)

@@ -65,6 +74,10 @@ model_config = {

args = get_args()

+if args.type not in model_config:
+    model_config[args.type] = [args.num_layers, args.hidden_size, args.num_heads]
+
file = os.listdir(args.input_model_dir)
model_files = [f for f in file if f[-4:] == ".bin"]
input_models = {f: torch.load(os.path.join(args.input_model_dir, f), map_location="cpu") for f in model_files}

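The predefined table still covers 7B/13B/30B/65B; anything else is now described explicitly through the new flags. A short sketch of the fallback with assumed values (the model name and sizes below are illustrative, not from this PR):

# Illustrative sketch; "custom-70B" and its sizes are assumed example values.
from types import SimpleNamespace

args = SimpleNamespace(type="custom-70B", num_layers=80, hidden_size=8192, num_heads=64, num_kv_heads=8)
model_config = {"7B": [32, 4096, 32], "65B": [80, 8192, 64]}

if args.type not in model_config:
    model_config[args.type] = [args.num_layers, args.hidden_size, args.num_heads]
print(model_config[args.type])  # [80, 8192, 64]
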
@@ -141,7 +154,7 @@ def generate_ascendspeed_weights_again(config):
            vw = row_split(get_weight_from_name(f"model.layers.{ori_i}.self_attn.v_proj.weight"), tp_size, tp_rank)

-            permute_w = permute_qkv_weight(torch.cat([qw, kw, vw], dim=0), n_heads, hidden_size, tp_size)
+            permute_w = permute_qkv_weight(torch.cat([qw, kw, vw], dim=0), (n_heads, hidden_size, tp_size, args.num_kv_heads))
            rank_model[f"language_model.layers.{pp_i}.attention.query_key_value.weight"] = permute_w

            rank_model[f"language_model.layers.{pp_i}.attention.dense.weight"] = column_split(

@@ -158,11 +171,21 @@ def generate_ascendspeed_weights_again(config):
                rank_model[f"language_model.layers.{pp_i}.attention.query_key_value.bias"] = permute_bias
                rank_model[f"language_model.layers.{pp_i}.attention.dense.bias"] = \
                    get_weight_from_name(f"model.layers.{ori_i}.self_attn.o_proj.bias")

-            rank_model[f"language_model.layers.{pp_i}.mlp.gate_proj.weight"] = row_split(
+            gate_proj = row_split(
                get_weight_from_name(f"model.layers.{ori_i}.mlp.gate_proj.weight"), tp_size, tp_rank)
-            rank_model[f"language_model.layers.{pp_i}.mlp.up_proj.weight"] = row_split(
+            up_proj = row_split(
                get_weight_from_name(f"model.layers.{ori_i}.mlp.up_proj.weight"), tp_size, tp_rank)
+            if args.merge_mlp:
+                rank_model[
+                    f"language_model.layers.{pp_i}.mlp.proj.weight"] = torch.cat(
+                    [gate_proj, up_proj], 0).contiguous().clone()
+            else:
+                rank_model[
+                    f"language_model.layers.{pp_i}.mlp.gate_proj.weight"] = gate_proj
+                rank_model[
+                    f"language_model.layers.{pp_i}.mlp.up_proj.weight"] = up_proj

            rank_model[f"language_model.layers.{pp_i}.mlp.down_proj.weight"] = column_split(
                get_weight_from_name(f"model.layers.{ori_i}.mlp.down_proj.weight"), tp_size, tp_rank)

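With --merge-mlp, the gate and up projections are stored as one fused tensor instead of two. A minimal sketch with assumed per-rank shapes (hidden_size=4096, ffn_hidden_size=11008 split over tp=2), showing that the fused layout is simply gate rows followed by up rows, which is also how merge_pp_tp_models splits it back:

# Illustrative sketch; shapes are assumptions, not read from a checkpoint.
import torch

gate_proj = torch.randn(11008 // 2, 4096)   # per-rank gate projection
up_proj = torch.randn(11008 // 2, 4096)     # per-rank up projection

proj = torch.cat([gate_proj, up_proj], 0).contiguous().clone()  # fused "mlp.proj.weight"

mlp_len = proj.shape[0] // 2                # first half is gate, second half is up
assert torch.equal(proj[:mlp_len], gate_proj)
assert torch.equal(proj[mlp_len:], up_proj)
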
@@ -51,7 +51,20 @@ def get_args():
    parser.add_argument("--tgt-pipeline-model-parallel-size", type=int, default=1,
                        help="degree of pipeline model parallel")
    parser.add_argument("--added-token-num", type=int, default=0, help="the number of added tokens")
-    parser.add_argument("--type", type=str, choices=["7B", "13B", "30B", "65B"], default="7B")
+    parser.add_argument("--type", type=str, default="7B",
+                        help="There are four predefined types: [7B, 13B, 30B, 65B]")
+    parser.add_argument("--num_layers", type=int, default=1,
+                        help="num layers")
+    parser.add_argument("--num_heads", type=int, default=1,
+                        help="num heads")
+    parser.add_argument("--num_kv_heads", type=int, default=None,
+                        help="num kv heads")
+    parser.add_argument("--hidden_size", type=int, default=1,
+                        help="hidden size")
    parser.add_argument("--bias", action="store_true", default=False)
    parser.add_argument("--deepspeed", action="store_true", default=False)
+    parser.add_argument("--merge-mlp", action="store_true", default=False,
+                        help="Merge gate and up mlp")

    return parser.parse_args()

@@ -63,6 +76,7 @@ model_config = {
    "65B": [80, 8192, 64]
}

+args = get_args()
entire_model = {}

@@ -97,8 +111,8 @@ def merge_pp_tp_models(config):
    pp_size, tp_size = config.pp_size, config.tp_size
    input_model_dir = config.input_model_dir
    orig_vocab_size = config.orig_vocab_size
-    num_layer, num_heads, hid_size = config.num_heads, config.hid_size, config.hid_size
+    num_layer, num_heads, hid_size = config.num_layer, config.num_heads, config.hid_size
+    repeats = num_heads // args.num_kv_heads
    global entire_model
    check_divisible(num_heads, tp_size)
    check_divisible(num_layer, pp_size)

@@ -137,14 +151,15 @@ def merge_pp_tp_models(config):
        # qkv split
        qkv_key = "language_model.layers.{}.attention.query_key_value.weight"
        qkv_len = tp_models[0][qkv_key.format(pp_i)].shape[0]
-        check_divisible(qkv_len, 3)
-        s1, s2 = qkv_len // 3, qkv_len // 3 * 2
-
-        qs = [permute_qkv_weight(tm[qkv_key.format(pp_i)], num_heads, hid_size, tp_size, split=True)[:s1,
+        check_divisible(qkv_len, repeats + 2)
+        s1, s2 = qkv_len // (repeats + 2) * repeats, qkv_len // (repeats + 2) * (repeats + 1)
+
+        qs = [permute_qkv_weight(tm[qkv_key.format(pp_i)], (num_heads, hid_size, tp_size, args.num_kv_heads), split=True)[:s1,
              ...].clone() for tm in tp_models]
-        ks = [permute_qkv_weight(tm[qkv_key.format(pp_i)], num_heads, hid_size, tp_size, split=True)[s1:s2,
+        ks = [permute_qkv_weight(tm[qkv_key.format(pp_i)], (num_heads, hid_size, tp_size, args.num_kv_heads), split=True)[s1:s2,
              ...].clone() for tm in tp_models]
-        vs = [permute_qkv_weight(tm[qkv_key.format(pp_i)], num_heads, hid_size, tp_size, split=True)[s2:,
+        vs = [permute_qkv_weight(tm[qkv_key.format(pp_i)], (num_heads, hid_size, tp_size, args.num_kv_heads), split=True)[s2:,
              ...].clone() for tm in tp_models]

        entire_model[qkv_key.format(g_i) + "_query"] = torch.cat(qs, dim=0)

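Continuing the assumed example values from the permute_qkv_weight sketch above (repeats=4, gp=1, hn=8 per rank), the new split arithmetic works out as follows; a sketch of the arithmetic, not output from the script.

# Illustrative arithmetic only; assumed per-rank example values.
repeats, gp, hn = 4, 1, 8
qkv_len = (repeats + 2) * gp * hn                # 48 rows in the fused per-rank QKV weight
s1 = qkv_len // (repeats + 2) * repeats          # 32 -> rows [0:32] hold the query heads
s2 = qkv_len // (repeats + 2) * (repeats + 1)    # 40 -> rows [32:40] keys, rows [40:48] values
print(s1, s2)  # 32 40
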
@@ -155,6 +170,12 @@ def merge_pp_tp_models(config):
                                                 "language_model.layers.{}.attention.dense.weight",
                                                 pp_i, g_i, dim=1)
        merge_weight(merge_weight_config1)
+        if args.merge_mlp:
+            mlp_key = "language_model.layers.{}.mlp.".format(g_i)
+            mlp_len = tp_models[0][mlp_key + "proj.weight"].shape[0] // 2
+            for tm in tp_models:
+                tm[mlp_key + "gate_proj.weight"] = tm[mlp_key + "proj.weight"][:mlp_len].clone()
+                tm[mlp_key + "up_proj.weight"] = tm[mlp_key + "proj.weight"][mlp_len:].clone()
        merge_weight_config2 = MergeWeightConfig(entire_model, tp_models,
                                                 "language_model.layers.{}.mlp.gate_proj.weight",
                                                 pp_i, g_i, dim=0)

@@ -211,16 +232,26 @@ def generate_ascendspeed_weights(config):
            qw = row_split(get_weight_from_name(qkv_key + "_query"), tp_size, tp_rank)
            kw = row_split(get_weight_from_name(qkv_key + "_key"), tp_size, tp_rank)
            vw = row_split(get_weight_from_name(qkv_key + "_value"), tp_size, tp_rank)
-            permute_w = permute_qkv_weight(torch.cat([qw, kw, vw], dim=0), num_heads, hid_size, tp_size)
+            permute_w = permute_qkv_weight(torch.cat([qw, kw, vw], dim=0), (num_heads, hid_size, tp_size, args.num_kv_heads))
            rank_model[f"language_model.layers.{pp_i}.attention.query_key_value.weight"] = permute_w

            rank_model[f"language_model.layers.{pp_i}.attention.dense.weight"] = column_split(
                get_weight_from_name(f"language_model.layers.{g_i}.attention.dense.weight"), tp_size, tp_rank)

-            rank_model[f"language_model.layers.{pp_i}.mlp.gate_proj.weight"] = row_split(
+            gate_proj = row_split(
                get_weight_from_name(f"language_model.layers.{g_i}.mlp.gate_proj.weight"), tp_size, tp_rank)
-            rank_model[f"language_model.layers.{pp_i}.mlp.up_proj.weight"] = row_split(
+            up_proj = row_split(
                get_weight_from_name(f"language_model.layers.{g_i}.mlp.up_proj.weight"), tp_size, tp_rank)
+            if args.merge_mlp:
+                rank_model[
+                    f"language_model.layers.{pp_i}.mlp.proj.weight"] = torch.cat(
+                    [gate_proj, up_proj], 0).contiguous().clone()
+            else:
+                rank_model[
+                    f"language_model.layers.{pp_i}.mlp.gate_proj.weight"] = gate_proj
+                rank_model[
+                    f"language_model.layers.{pp_i}.mlp.up_proj.weight"] = up_proj

            rank_model[f"language_model.layers.{pp_i}.mlp.down_proj.weight"] = column_split(
                get_weight_from_name(f"language_model.layers.{g_i}.mlp.down_proj.weight"), tp_size, tp_rank)

@@ -248,11 +279,13 @@ def print_result(arg):


if __name__ == '__main__':
    args = get_args()
    if args.type in model_config:
        n_layer, hidden_size, n_heads = model_config[args.type]
    else:
-        raise KeyError(f"{args.type} is not in {model_config}")
+        n_layer, hidden_size, n_heads = args.num_layers, args.hidden_size, args.num_heads
+
+    if args.num_kv_heads is None:
+        args.num_kv_heads = n_heads

    check_model_dir(args.input_model_dir, args.src_tensor_model_parallel_size, args.src_pipeline_model_parallel_size)
    make_ascendspeed_model_dirs(args.output_model_dir)