!1658 【mcore】Add baichuan2-13B and deepseek2-coder adaptation

Merge pull request !1658 from xiongliangcheng/deepseek-coder
This commit is contained in:
xiongliangcheng 2024-09-12 01:04:21 +00:00 committed by i-robot
parent b3c029ddd9
commit 2b84f8f52a
14 changed files with 545 additions and 21 deletions

View File

@ -186,10 +186,10 @@ ModelLink trains large language models via model parallelism and data parallelism; to demonstrate
<tr>
<td><a href="https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/tree/main">13B</a></td>
<td>4K</td>
<th>Legacy</th>
<th>Mcore</th>
<td> 1x8</td>
<td> BF16 </td>
<td> 1668 </td>
<td> 1754 </td>
<td> -- </td>
<td> 2062 </td>
<td><center>【Ascend】</td>

View File

@ -72,7 +72,7 @@ def main():
help='Do not perform checking on the name and ordering of weights',
dest='checking')
parser.add_argument('--model-type-hf', type=str, default="llama2",
choices=['llama2', 'mixtral', 'chatglm3', 'gemma', 'gemma2', 'bloom', 'qwen', 'internlm2', 'deepseek2', 'minicpm', 'minicpm-moe'],
choices=['baichuan2', 'llama2', 'mixtral', 'chatglm3', 'gemma', 'gemma2', 'bloom', 'qwen', 'internlm2', 'deepseek2', 'minicpm', 'minicpm-moe'],
help='model type of huggingface')
known_args, _ = parser.parse_known_args()

View File

@ -0,0 +1,16 @@
# Modify the ascend-toolkit path
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# Weight format conversion (HF -> mcore)
python convert_ckpt.py \
--use-mcore-models \
--model-type-hf baichuan2 \
--model-type GPT \
--load-model-type hf \
--save-model-type mg \
--params-dtype bf16 \
--target-tensor-parallel-size 8 \
--target-pipeline-parallel-size 1 \
--load-dir ./model_from_hf/Baichuan2-13B_hf/ \
--save-dir ./model_weights/Baichuan2-13B_mcore/ \
--tokenizer-model ./model_from_hf/Baichuan2-13B_hf/tokenizer.model

View File

@ -0,0 +1,15 @@
# Modify the ascend-toolkit path
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# Weight format conversion (mcore -> HF)
python convert_ckpt.py \
--use-mcore-models \
--model-type-hf baichuan2 \
--model-type GPT \
--load-model-type mg \
--save-model-type hf \
--params-dtype bf16 \
--target-tensor-parallel-size 1 \
--target-pipeline-parallel-size 1 \
--load-dir ./model_weights/Baichuan2-13B_mcore/ \
--save-dir ./model_from_hf/Baichuan2-13B_hf/

View File

@ -0,0 +1,12 @@
# Please modify the set_env.sh path according to your actual environment
source /usr/local/Ascend/ascend-toolkit/set_env.sh
mkdir ./dataset
# Dataset download URL: https://huggingface.co/datasets/tatsu-lab/alpaca/resolve/main/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet
python ./preprocess_data.py \
--input ./dataset/train-00000-of-00001-a09b74b3ef9c3b56.parquet \
--tokenizer-name-or-path ./model_from_hf/Baichuan2-13B_hf/ \
--output-prefix ./dataset/alpaca \
--workers 4 \
--log-interval 1000 \
--tokenizer-type PretrainedFromHF

View File

@ -0,0 +1,53 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=8
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
CHECKPOINT="Your ckpt file path"
TOKENIZER_PATH="Your tokenizer path"
DATA_PATH="./boolq/data/test/"
TASK="boolq"
# Different tasks need different max_new_tokens values; please follow the instructions in the README.
torchrun $DISTRIBUTED_ARGS evaluation.py \
--task-data-path ${DATA_PATH} \
--task ${TASK} \
--seq-length 4096 \
--max-new-tokens 1 \
--max-position-embeddings 4096 \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 \
--num-layers 40 \
--hidden-size 5120 \
--ffn-hidden-size 13696 \
--num-attention-heads 40 \
--disable-bias-linear \
--swiglu \
--position-embedding-type alibi \
--square-alibi-mask \
--fill-neg-inf \
--load $CHECKPOINT \
--normalization RMSNorm \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_PATH} \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--use-fused-rmsnorm \
--exit-on-missing-checkpoint \
--no-load-rng \
--no-load-optim \
--untie-embeddings-and-output-weights \
--no-masked-softmax-fusion \
--make-vocab-size-divisible-by 32 \
--use-mcore-models \
--seed 42 | tee logs/eval_baichuan2_13b_mcore_${TASK}.log

View File

@ -0,0 +1,53 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
# Please fill in these path configurations
CHECKPOINT="your model directory path"
TOKENIZER_PATH="your tokenizer directory path"
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=8
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
torchrun $DISTRIBUTED_ARGS inference.py \
--use-mcore-models \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 \
--num-layers 40 \
--hidden-size 5120 \
--ffn-hidden-size 13696 \
--seq-length 1024 \
--max-new-tokens 256 \
--micro-batch-size 1 \
--global-batch-size 16 \
--num-attention-heads 40 \
--max-position-embeddings 2048 \
--position-embedding-type alibi \
--square-alibi-mask \
--fill-neg-inf \
--swiglu \
--load ${CHECKPOINT} \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_PATH} \
--tokenizer-not-use-fast \
--fp16 \
--normalization RMSNorm \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--attention-softmax-in-fp32 \
--no-load-optim \
--no-load-rng \
--no-masked-softmax-fusion \
--no-gradient-accumulation-fusion \
--exit-on-missing-checkpoint \
--make-vocab-size-divisible-by 32 \
| tee logs/generate_baichuan2_13b_mcore.log

View File

@ -0,0 +1,100 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
GPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
CKPT_SAVE_DIR="your model save ckpt path"
DATA_PATH="your data path"
TOKENIZER_MODEL="your tokenizer model path"
CKPT_LOAD_DIR="your model load ckpt path"
TP=8
PP=1
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
GPT_ARGS="
--use-mcore-models \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--sequence-parallel \
--use-mc2 \
--num-layers 40 \
--hidden-size 5120 \
--ffn-hidden-size 13696 \
--num-attention-heads 40 \
--tokenizer-type Llama2Tokenizer \
--tokenizer-model ${TOKENIZER_MODEL} \
--seq-length 4096 \
--disable-bias-linear \
--max-position-embeddings 4096 \
--micro-batch-size 2 \
--global-batch-size 128 \
--untie-embeddings-and-output-weights \
--no-gradient-accumulation-fusion \
--make-vocab-size-divisible-by 32 \
--lr 1e-5 \
--load ${CKPT_LOAD_DIR} \
--train-iters 2000 \
--lr-decay-style cosine \
--attention-dropout 0.0 \
--init-method-std 0.01 \
--position-embedding-type alibi \
--hidden-dropout 0.0 \
--norm-epsilon 1e-6 \
--normalization RMSNorm \
--use-fused-rmsnorm \
--use-flash-attn \
--use-fused-swiglu \
--swiglu \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--square-alibi-mask \
--fill-neg-inf \
--min-lr 1e-8 \
--weight-decay 1e-4 \
--clip-grad 1.0 \
--seed 1234 \
--adam-beta1 0.9 \
--initial-loss-scale 8188.0 \
--adam-beta2 0.98 \
--adam-eps 1.0e-8 \
--no-load-optim \
--no-load-rng \
--bf16
"
DATA_ARGS="
--data-path ${DATA_PATH} \
--split 100,0,0
"
OUTPUT_ARGS="
--log-interval 1 \
--save-interval 2000 \
--eval-interval 2000 \
--eval-iters 0
"
torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
--save ${CKPT_SAVE_DIR} \
| tee logs/train_baichuan2_13b_mcore.log

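Note: for the parallel layout above (8 NPUs, TP=8, PP=1), the data-parallel size and gradient-accumulation steps follow from standard Megatron batch arithmetic. A quick sanity check (a sketch, assuming no context or expert parallelism):

# Batch-size arithmetic implied by the script above (sketch; assumes no
# context- or expert-parallelism).
world_size, tp, pp = 8, 8, 1          # NPUS_PER_NODE, --tensor-..., --pipeline-...
micro_batch, global_batch = 2, 128    # --micro-batch-size, --global-batch-size

dp = world_size // (tp * pp)                      # data-parallel size: 1
grad_accum = global_batch // (micro_batch * dp)   # 64 micro-batches per optimizer step
assert micro_batch * dp * grad_accum == global_batch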
View File

@ -0,0 +1,18 @@
# Modify the ascend-toolkit path
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# Weight format conversion (HF -> mcore); set the required parallel strategy
python convert_ckpt.py \
--use-mcore-models \
--moe-grouped-gemm \
--model-type-hf deepseek2 \
--model-type GPT \
--load-model-type hf \
--save-model-type mg \
--params-dtype bf16 \
--target-tensor-parallel-size 1 \
--target-pipeline-parallel-size 1 \
--target-expert-parallel-size 8 \
--load-dir ./model_from_hf/deepseek2-coder-hf/ \
--save-dir ./model_weights/deepseek2-coder-mcore/ \
--tokenizer-model ./model_from_hf/deepseek2-coder-hf/

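Note: with --target-expert-parallel-size 8, the 160 routed experts of deepseek2-coder (see the pretrain script below) are sharded across the expert-parallel group. A minimal sketch of a contiguous expert-to-rank assignment (illustrative only; the converter's actual layout may differ):

# Illustrative sketch: mapping 160 routed experts onto EP=8 ranks with
# contiguous blocking (assumed layout, not the converter's real code).
num_experts = 160
ep_size = 8
experts_per_rank = num_experts // ep_size   # 20 experts per rank

def experts_on_rank(ep_rank: int) -> range:
    start = ep_rank * experts_per_rank
    return range(start, start + experts_per_rank)

assert list(experts_on_rank(0))[:3] == [0, 1, 2]
assert list(experts_on_rank(7))[-1] == 159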
View File

@ -0,0 +1,17 @@
# Modify the ascend-toolkit path
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# Weight format conversion (mcore -> HF)
python convert_ckpt.py \
--use-mcore-models \
--moe-grouped-gemm \
--model-type-hf deepseek2 \
--model-type GPT \
--load-model-type mg \
--save-model-type hf \
--params-dtype bf16 \
--target-tensor-parallel-size 1 \
--target-pipeline-parallel-size 1 \
--target-expert-parallel-size 1 \
--load-dir ./model_weights/deepseek2-coder-mcore/ \
--save-dir ./model_from_hf/deepseek2-coder-hf/

View File

@ -0,0 +1,14 @@
# Please modify the set_env.sh path according to your actual environment
source /usr/local/Ascend/ascend-toolkit/set_env.sh
mkdir ./dataset
# Dataset download URL: https://huggingface.co/datasets/lsb/enwiki20230101/blob/main/data/train-00000-of-00042-d964455e17e96d5a.parquet
python ./preprocess_data.py \
--input ./dataset/train-00000-of-00042-d964455e17e96d5a.parquet \
--tokenizer-name-or-path ./model_from_hf/deepseek2-coder-hf/ \
--tokenizer-type PretrainedFromHF \
--handler-name GeneralPretrainHandler \
--output-prefix ./dataset/enwiki \
--json-keys text \
--workers 4 \
--log-interval 1000

View File

@ -0,0 +1,147 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
GPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
CKPT_SAVE_DIR="your model save ckpt path"
DATA_PATH="your data path"
TOKENIZER_MODEL="your tokenizer path"
CKPT_LOAD_DIR="your model ckpt path"
TP=1
PP=1
EP=8
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
MLA_ARGS="
--multi-head-latent-attention \
--qk-rope-head-dim 64 \
--qk-nope-head-dim 128 \
--q-lora-rank 1536 \
--kv-lora-rank 512 \
--v-head-dim 128 \
--qk-layernorm \
"
MOE_ARGS="
--moe-grouped-gemm \
--moe-permutation-async-comm \
--moe-token-dispatcher-type allgather \
--first-k-dense-replace 1 \
--moe-layer-freq 1 \
--n-shared-experts 2 \
--num-experts 160 \
--moe-router-topk 6 \
--moe-intermediate-size 1536 \
--moe-router-load-balancing-type group_limited_greedy \
--topk-group 3 \
--moe-aux-loss-coeff 0.003 \
--moe-device-level-aux-loss-coeff 0.05 \
--moe-comm-aux-loss-coeff 0.02 \
--routed-scaling-factor 16.0 \
--seq-aux
"
ROPE_ARGS="
--rope-scaling-beta-fast 32 \
--rope-scaling-beta-slow 1 \
--rope-scaling-factor 40 \
--rope-scaling-mscale 0.707 \
--rope-scaling-mscale-all-dim 0.707 \
--rope-scaling-original-max-position-embeddings 4096 \
--rope-scaling-type yarn
"
GPT_ARGS="
--load $CKPT_LOAD_DIR \
--use-distributed-optimizer \
--use-flash-attn \
--shape-order BNSD \
--use-mcore-models \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--expert-model-parallel-size ${EP} \
--sequence-parallel \
--output-layer-slice-num 10 \
--num-layers 2 \
--hidden-size 5120 \
--ffn-hidden-size 12288 \
--num-attention-heads 128 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_MODEL} \
--seq-length 8192 \
--max-position-embeddings 163840 \
--micro-batch-size 1 \
--global-batch-size 64 \
--make-vocab-size-divisible-by 1 \
--lr 1.0e-5 \
--train-iters 2000 \
--lr-decay-style cosine \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--attention-dropout 0.0 \
--init-method-std 0.02 \
--hidden-dropout 0.0 \
--position-embedding-type rope \
--normalization RMSNorm \
--use-fused-rotary-pos-emb \
--use-rotary-position-embeddings \
--use-fused-swiglu \
--use-fused-rmsnorm \
--swiglu \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--min-lr 1.0e-7 \
--weight-decay 1e-2 \
--lr-warmup-iters 500 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--adam-beta2 0.999 \
--initial-loss-scale 65536 \
--vocab-size 102400 \
--padded-vocab-size 102400 \
--rotary-base 10000 \
--no-gradient-accumulation-fusion \
--norm-epsilon 1e-6 \
--no-load-optim \
--no-load-rng \
--bf16
"
DATA_ARGS="
--data-path $DATA_PATH \
--split 100,0,0
"
OUTPUT_ARGS="
--log-interval 1 \
--save-interval 2000 \
--eval-interval 2000 \
--eval-iters 0 \
--no-save-optim \
--no-save-rng
"
python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
$MLA_ARGS \
$ROPE_ARGS \
$MOE_ARGS \
--distributed-backend nccl \
--save $CKPT_SAVE_DIR \
| tee logs/pretrain_deepseek2_ptd_8p.log

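Note: --moe-router-load-balancing-type group_limited_greedy with --topk-group 3 describes DeepSeek-V2-style routing: experts are partitioned into groups, only the top topk-group groups (ranked by their best affinity) stay eligible, and the final top-k experts are chosen within them. A minimal sketch of that selection, assuming 8 expert groups and softmax affinities (illustrative, not this repo's router code):

import torch

def group_limited_greedy_topk(scores, n_group, topk_group, top_k):
    # scores: [num_tokens, num_experts] router affinities.
    num_tokens, num_experts = scores.shape
    # Per-group best affinity, then keep the topk_group best groups per token.
    group_scores = scores.view(num_tokens, n_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    # Mask out experts belonging to dropped groups, then pick top_k experts.
    expert_mask = group_mask.unsqueeze(-1).expand(num_tokens, n_group, num_experts // n_group)
    masked_scores = scores.masked_fill(expert_mask.reshape(num_tokens, -1) == 0, float('-inf'))
    return torch.topk(masked_scores, k=top_k, dim=-1)

# e.g. 160 experts in (assumed) 8 groups, keep 3 groups, route each token to 6 experts
weights, expert_idx = group_limited_greedy_topk(torch.rand(4, 160), n_group=8, topk_group=3, top_k=6)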
View File

@ -17,6 +17,7 @@ from mindspeed.core.parallel_state import (get_context_parallel_group_for_hybrid
from modellink.model.transformer import get_attention_mask
from modellink.core.models.common.embeddings.rotary_pos_embedding import yarn_get_mscale
from modellink.utils import get_actual_seq_len
from modellink.model.alibi import Alibi
try:
from einops import rearrange
@ -63,6 +64,19 @@ def dot_product_attention_init_wrapper(fn):
args = get_args()
self.attn_logit_softcapping = args.attn_logit_softcapping
self.square_alibi_mask = args.square_alibi_mask
self.fill_neg_inf = args.fill_neg_inf
self.beta = 1.0
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
if self.apply_query_key_layer_scaling:
self.beta = 1.0 / self.layer_number
if args.position_embedding_type == 'alibi':
get_alibi(self, args.seq_length)
self.alibi_output_size = None
else:
self.alibi = None
if args.query_pre_attn_scalar:
self.norm_factor = args.query_pre_attn_scalar ** 0.5
self.scale_mask_softmax.scale = 1.0
@ -87,6 +101,21 @@ def dot_product_attention_init_wrapper(fn):
return wrapper
def get_alibi(self, seq_length):
args = get_args()
self.alibi = Alibi()
alibi = self.alibi._build_alibi_tensor(seq_length,
args.num_attention_heads,
args.square_alibi_mask,
args.fill_neg_inf,
).to(torch.cuda.current_device())
if args.params_dtype == torch.float16:
alibi = alibi.to(torch.float16)
elif args.params_dtype == torch.bfloat16:
alibi = alibi.to(torch.bfloat16)
self.alibi.alibi = alibi
def dot_product_attention_forward_wrapper(fn):
@wraps(fn)
def wrapper(self, query, key, value, attention_mask, attn_mask_type, packed_seq_params):
@ -134,18 +163,27 @@ def dot_product_attention_forward_wrapper(fn):
key = key.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
(output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu",
)
if self.alibi is None:
matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
(output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu",
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query.transpose(0, 1), # [b * np, sq, hn]
key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query.transpose(0, 1), # [b * np, sq, hn]
key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
else:
if self.alibi.alibi_pse is None or self.alibi.output_size != output_size:
self.alibi.output_size = output_size
self.alibi.get_alibi_pse(attention_mask, output_size[0], output_size[2], output_size[3])
q_trans = query.transpose(0, 1).contiguous()
k_trans = key.transpose(0, 1).transpose(1, 2).contiguous()
matmul_result = self.beta * self.alibi.alibi_pse + torch.bmm(q_trans, k_trans) * (1.0 / self.norm_factor)
if self.attn_logit_softcapping is not None:
matmul_result = matmul_result / self.attn_logit_softcapping
@ -160,7 +198,13 @@ def dot_product_attention_forward_wrapper(fn):
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
if self.square_alibi_mask:
attention_scores = torch.max(
attention_scores, torch.tensor(torch.finfo(attention_scores.dtype).min)
)
attention_probs = torch.nn.functional.softmax(attention_scores, -1)
else:
attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@ -224,7 +268,7 @@ def dot_product_attention_forward(
args = get_args()
seq_length, _, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3]
seq_length, batch_size, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3]
actual_seq_len = None
if args.reset_attention_mask or args.reset_position_ids:
query, key, value = [rearrange(x, 's b h d -> (s b) h d') for x in [query, key, value]]
@ -241,9 +285,18 @@ def dot_product_attention_forward(
raise AssertionError("self.hidden_size_per_attention_head should not be ZERO.")
scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head) \
if self.scale_mask_softmax.scale is None else self.softmax_scale
if attention_mask is None or args.reset_attention_mask or args.reset_position_ids:
attention_mask = get_attention_mask()
if not hasattr(self, 'attention_mask') or \
self.attention_mask is None or \
self.attention_mask.shape[0] != seq_length or \
args.reset_attention_mask or args.reset_position_ids:
if self.alibi is not None:
self.attention_mask = torch.triu(
torch.ones(seq_length, seq_length),
1).bool().npu()
elif attention_mask is None:
self.attention_mask = get_attention_mask()
else:
self.attention_mask = attention_mask
if args.context_parallel_size > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']:
return do_ring_context_parallel(
query, key, value, head_num=n_head, softmax_scale=scale, attn_mask=attention_mask)
@ -253,11 +306,27 @@ def dot_product_attention_forward(
if use_sliding_windows:
args.pre_tockens = args.sliding_window
pse = None
size_record = key.shape
if self.alibi is not None and (self.alibi.output_size != size_record) and pse is None:
if args.shape_order != 'SBH':
raise ValueError(
'FlashAttention with ALiBi requires SBH shape_order, but got {}.'.format(args.shape_order))
self.alibi.output_size = size_record
self.alibi.get_alibi_pse(self.attention_mask, batch_size, query.shape[0], key.shape[0])
if self.alibi and pse is None:
pse = self.alibi.alibi_pse.reshape(
batch_size, n_head, self.alibi.alibi_pse.size(1), -1) * self.beta * self.norm_factor
args.pre_tockens = seq_length
args.sparse_mode = 0
output = torch_npu.npu_fusion_attention(
query, key, value, n_head, args.shape_order,
pse=None,
pse=pse,
padding_mask=None,
atten_mask=attention_mask,
atten_mask=self.attention_mask,
actual_seq_qlen=actual_seq_len,
actual_seq_kvlen=actual_seq_len,
scale=scale,

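Note: the ALiBi bias built by Alibi._build_alibi_tensor follows the standard construction: one slope per attention head (a geometric sequence derived from the head count), multiplied by the key positions, added to the attention scores before softmax. A standalone sketch of that construction, assuming the usual ALiBi slope formula (not this repo's exact implementation, which also supports the square-mask / fill-neg-inf variants):

import math
import torch

def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Standard ALiBi slopes: geometric sequence based on the nearest power of two.
    def slopes_power_of_2(n):
        start = 2.0 ** (-(2.0 ** -(math.log2(n) - 3)))
        return [start * (start ** i) for i in range(n)]
    if math.log2(num_heads).is_integer():
        return torch.tensor(slopes_power_of_2(num_heads))
    closest = 2 ** math.floor(math.log2(num_heads))
    extra = slopes_power_of_2(2 * closest)[0::2][: num_heads - closest]
    return torch.tensor(slopes_power_of_2(closest) + extra)

def build_alibi_bias(seq_len: int, num_heads: int) -> torch.Tensor:
    # Returns [num_heads, 1, seq_len]: per-head slope times key position.
    positions = torch.arange(seq_len)
    return alibi_slopes(num_heads)[:, None, None] * positions[None, None, :]

bias = build_alibi_bias(seq_len=4096, num_heads=40)   # Baichuan2-13B: 40 heads, 4K context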
View File

@ -259,6 +259,16 @@
"layers_mlp_linear_fc2": "model.layers[layer_idx].mlp.experts[expert_idx].w2",
"final_layernorm": "model.norm"
}
},
"baichuan2": {
"__base__": "base",
"config_set_value": {
"qkv_type": "pack_gqa",
"max_position_embeddings": 4096
},
"model_hf_key_mapping": {
"layers_self_attention_linear_qkv_pack": "model.layers[layer_idx].self_attn.W_pack"
}
}
}
}
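Note: the baichuan2 mapping above reads the packed QKV projection from model.layers[layer_idx].self_attn.W_pack. During conversion the packed matrix has to be split back into Q, K and V; a minimal sketch for Baichuan2-13B shapes (hidden size 5120, 40 heads, no GQA), assuming a simple row-wise [Q; K; V] packing rather than the converter's actual logic:

import torch

hidden_size = 5120                                    # Baichuan2-13B
w_pack = torch.empty(3 * hidden_size, hidden_size)    # HF W_pack weight: [3*h, h]

# Row-wise split into the three projections (assumed packing order Q, K, V).
q_proj, k_proj, v_proj = torch.split(w_pack, hidden_size, dim=0)
assert q_proj.shape == k_proj.shape == v_proj.shape == (hidden_size, hidden_size)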