Mirror of https://gitee.com/ascend/ModelLink.git, synced 2024-11-29 18:38:39 +08:00

!1658 [mcore] Add baichuan2-13B and deepseek2-coder adaptation
Merge pull request !1658 from xiongliangcheng/deepseek-coder

Parent commit: b3c029ddd9
This commit: 2b84f8f52a

@@ -186,10 +186,10 @@ ModelLink trains large language models via model and data parallelism; to demo…
       <tr>
         <td><a href="https://huggingface.co/baichuan-inc/Baichuan2-13B-Base/tree/main">13B</a></td>
         <td>4K</td>
-        <th>Legacy</th>
+        <th>Mcore</th>
         <td> 1x8</td>
         <td> BF16 </td>
-        <td> 1668 </td>
+        <td> 1754 </td>
         <td> -- </td>
         <td> 2062 </td>
         <td><center>【昇腾】</td>

@@ -72,7 +72,7 @@ def main():
                         help='Do not perform checking on the name and ordering of weights',
                         dest='checking')
     parser.add_argument('--model-type-hf', type=str, default="llama2",
-                        choices=['llama2', 'mixtral', 'chatglm3', 'gemma', 'gemma2', 'bloom', 'qwen', 'internlm2', 'deepseek2', 'minicpm', 'minicpm-moe'],
+                        choices=['baichuan2', 'llama2', 'mixtral', 'chatglm3', 'gemma', 'gemma2', 'bloom', 'qwen', 'internlm2', 'deepseek2', 'minicpm', 'minicpm-moe'],
                         help='model type of huggingface')
     known_args, _ = parser.parse_known_args()

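Registering 'baichuan2' in the choices list is all the converter needs to accept Baichuan2 weights from the command line; argparse rejects anything outside the list before any conversion code runs. A minimal standalone sketch of that behaviour (illustrative only, with a trimmed model list rather than the real one in convert_ckpt.py):

import argparse

# Trimmed-down parser that mirrors the --model-type-hf argument above (illustration only).
parser = argparse.ArgumentParser()
parser.add_argument('--model-type-hf', type=str, default="llama2",
                    choices=['baichuan2', 'llama2', 'deepseek2'])

# Accepted now that 'baichuan2' is registered as a choice.
args, _ = parser.parse_known_args(['--model-type-hf', 'baichuan2'])
print(args.model_type_hf)  # baichuan2

# A value outside choices would make argparse exit with "invalid choice".
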
@@ -0,0 +1,16 @@
# Modify the ascend-toolkit path to match your environment
source /usr/local/Ascend/ascend-toolkit/set_env.sh

# Convert the weight format
python convert_ckpt.py \
    --use-mcore-models \
    --model-type-hf baichuan2 \
    --model-type GPT \
    --load-model-type hf \
    --save-model-type mg \
    --params-dtype bf16 \
    --target-tensor-parallel-size 8 \
    --target-pipeline-parallel-size 1 \
    --load-dir ./model_from_hf/Baichuan2-13B_hf/ \
    --save-dir ./model_weights/Baichuan2-13B_mcore/ \
    --tokenizer-model ./model_from_hf/Baichuan2-13B_hf/tokenizer.model

@@ -0,0 +1,15 @@
# Modify the ascend-toolkit path to match your environment
source /usr/local/Ascend/ascend-toolkit/set_env.sh

# Convert the weight format
python convert_ckpt.py \
    --use-mcore-models \
    --model-type-hf baichuan2 \
    --model-type GPT \
    --load-model-type mg \
    --save-model-type hf \
    --params-dtype bf16 \
    --target-tensor-parallel-size 1 \
    --target-pipeline-parallel-size 1 \
    --load-dir ./model_weights/Baichuan2-13B_mcore/ \
    --save-dir ./model_from_hf/Baichuan2-13B_hf/

New file: examples/mcore/baichuan2/data_convert_baichuan2_pretrain.sh
@@ -0,0 +1,12 @@
# Modify the set_env.sh path to match your actual environment
source /usr/local/Ascend/ascend-toolkit/set_env.sh
mkdir ./dataset

# Dataset download URL: https://huggingface.co/datasets/tatsu-lab/alpaca/resolve/main/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet
python ./preprocess_data.py \
    --input ./dataset/train-00000-of-00001-a09b74b3ef9c3b56.parquet \
    --tokenizer-name-or-path ./model_from_hf/Baichuan-hf/ \
    --output-prefix ./dataset/alpaca \
    --workers 4 \
    --log-interval 1000 \
    --tokenizer-type PretrainedFromHF

New file: examples/mcore/baichuan2/evaluate_baichuan2_13b_ptd.sh
@@ -0,0 +1,53 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1

# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=8

WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))

DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

CHECKPOINT="Your ckpt file path"
TOKENIZER_PATH="Your tokenizer path"
DATA_PATH="./boolq/data/test/"
TASK="boolq"

# Different tasks need different max_new_tokens values; please follow the instructions in the README.
torchrun $DISTRIBUTED_ARGS evaluation.py \
    --task-data-path ${DATA_PATH} \
    --task ${TASK} \
    --seq-length 4096 \
    --max-new-tokens 1 \
    --max-position-embeddings 4096 \
    --tensor-model-parallel-size 8 \
    --pipeline-model-parallel-size 1 \
    --num-layers 40 \
    --hidden-size 5120 \
    --ffn-hidden-size 13696 \
    --num-attention-heads 40 \
    --disable-bias-linear \
    --swiglu \
    --position-embedding-type alibi \
    --square-alibi-mask \
    --fill-neg-inf \
    --load $CHECKPOINT \
    --normalization RMSNorm \
    --tokenizer-type PretrainedFromHF \
    --tokenizer-name-or-path ${TOKENIZER_PATH} \
    --tokenizer-not-use-fast \
    --fp16 \
    --micro-batch-size 1 \
    --use-fused-rmsnorm \
    --exit-on-missing-checkpoint \
    --no-load-rng \
    --no-load-optim \
    --untie-embeddings-and-output-weights \
    --no-masked-softmax-fusion \
    --make-vocab-size-divisible-by 32 \
    --use-mcore-models \
    --seed 42 | tee logs/eval_baichuan2_13b_mcore_${TASK}.log

New file: examples/mcore/baichuan2/generate_baichuan2_13b_ptd.sh
@@ -0,0 +1,53 @@
#!/bin/bash

export CUDA_DEVICE_MAX_CONNECTIONS=1

# Please fill in these path configurations
CHECKPOINT="your model directory path"
TOKENIZER_PATH="your tokenizer directory path"

# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=8
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))

DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

torchrun $DISTRIBUTED_ARGS inference.py \
    --use-mcore-models \
    --tensor-model-parallel-size 8 \
    --pipeline-model-parallel-size 1 \
    --num-layers 40 \
    --hidden-size 5120 \
    --ffn-hidden-size 13696 \
    --seq-length 1024 \
    --max-new-tokens 256 \
    --micro-batch-size 1 \
    --global-batch-size 16 \
    --num-attention-heads 40 \
    --max-position-embeddings 2048 \
    --position-embedding-type alibi \
    --square-alibi-mask \
    --fill-neg-inf \
    --swiglu \
    --load ${CHECKPOINT} \
    --tokenizer-type PretrainedFromHF \
    --tokenizer-name-or-path ${TOKENIZER_PATH} \
    --tokenizer-not-use-fast \
    --fp16 \
    --normalization RMSNorm \
    --untie-embeddings-and-output-weights \
    --disable-bias-linear \
    --attention-softmax-in-fp32 \
    --no-load-optim \
    --no-load-rng \
    --no-masked-softmax-fusion \
    --no-gradient-accumulation-fusion \
    --exit-on-missing-checkpoint \
    --make-vocab-size-divisible-by 32 \
    | tee logs/generate_baichuan2_13b_mcore.log

New file: examples/mcore/baichuan2/pretrain_baichuan2_13b_ptd.sh
@@ -0,0 +1,100 @@
#!/bin/bash

export CUDA_DEVICE_MAX_CONNECTIONS=1
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True

GPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

CKPT_SAVE_DIR="your model save ckpt path"
DATA_PATH="your data path"
TOKENIZER_MODEL="your tokenizer model path"
CKPT_LOAD_DIR="your model load ckpt path"

TP=8
PP=1

DISTRIBUTED_ARGS="
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

GPT_ARGS="
    --use-mcore-models \
    --tensor-model-parallel-size ${TP} \
    --pipeline-model-parallel-size ${PP} \
    --sequence-parallel \
    --use-mc2 \
    --num-layers 40 \
    --hidden-size 5120 \
    --ffn-hidden-size 13696 \
    --num-attention-heads 40 \
    --tokenizer-type Llama2Tokenizer \
    --tokenizer-model ${TOKENIZER_MODEL} \
    --seq-length 4096 \
    --disable-bias-linear \
    --max-position-embeddings 4096 \
    --micro-batch-size 2 \
    --global-batch-size 128 \
    --untie-embeddings-and-output-weights \
    --no-gradient-accumulation-fusion \
    --make-vocab-size-divisible-by 32 \
    --lr 1e-5 \
    --load ${CKPT_LOAD_DIR} \
    --train-iters 2000 \
    --lr-decay-style cosine \
    --attention-dropout 0.0 \
    --init-method-std 0.01 \
    --position-embedding-type alibi \
    --hidden-dropout 0.0 \
    --norm-epsilon 1e-6 \
    --normalization RMSNorm \
    --use-fused-rmsnorm \
    --use-flash-attn \
    --use-fused-swiglu \
    --use-mc2 \
    --swiglu \
    --no-masked-softmax-fusion \
    --attention-softmax-in-fp32 \
    --square-alibi-mask \
    --fill-neg-inf \
    --min-lr 1e-8 \
    --weight-decay 1e-4 \
    --clip-grad 1.0 \
    --seed 1234 \
    --adam-beta1 0.9 \
    --initial-loss-scale 8188.0 \
    --adam-beta2 0.98 \
    --adam-eps 1.0e-8 \
    --no-load-optim \
    --no-load-rng \
    --bf16
"

DATA_ARGS="
    --data-path ${DATA_PATH} \
    --split 100,0,0
"

OUTPUT_ARGS="
    --log-interval 1 \
    --save-interval 2000 \
    --eval-interval 2000 \
    --eval-iters 0
"

torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
    $GPT_ARGS \
    $DATA_ARGS \
    $OUTPUT_ARGS \
    --distributed-backend nccl \
    --save ${CKPT_SAVE_DIR} \
    | tee logs/train_baichuan2_13b_mcore.log

@@ -0,0 +1,18 @@
# Modify the ascend-toolkit path to match your environment
source /usr/local/Ascend/ascend-toolkit/set_env.sh

# Convert the weight format and set the required parallel strategy
python convert_ckpt.py \
    --use-mcore-models \
    --moe-grouped-gemm \
    --model-type-hf deepseek2 \
    --model-type GPT \
    --load-model-type hf \
    --save-model-type mg \
    --params-dtype bf16 \
    --target-tensor-parallel-size 1 \
    --target-pipeline-parallel-size 1 \
    --target-expert-parallel-size 8 \
    --load-dir ./model_from_hf/deepseek2-coder-hf/ \
    --save-dir ./model_weights/deepseek2-coder-mcore/ \
    --tokenizer-model ./model_from_hf/deepseek2-coder-hf/

@@ -0,0 +1,17 @@
# Modify the ascend-toolkit path to match your environment
source /usr/local/Ascend/ascend-toolkit/set_env.sh

# Convert the weight format
python convert_ckpt.py \
    --use-mcore-models \
    --moe-grouped-gemm \
    --model-type-hf deepseek2 \
    --model-type GPT \
    --load-model-type mg \
    --save-model-type hf \
    --params-dtype bf16 \
    --target-tensor-parallel-size 1 \
    --target-pipeline-parallel-size 1 \
    --target-expert-parallel-size 1 \
    --load-dir ./model_weights/deepseek2-coder-mcore/ \
    --save-dir ./model_from_hf/deepseek2-coder-hf/

@@ -0,0 +1,14 @@
# Modify the set_env.sh path to match your actual environment
source /usr/local/Ascend/ascend-toolkit/set_env.sh
mkdir ./dataset

# Dataset download URL: https://huggingface.co/datasets/lsb/enwiki20230101/blob/main/data/train-00000-of-00042-d964455e17e96d5a.parquet
python ./preprocess_data.py \
    --input ./dataset/train-00000-of-00042-d964455e17e96d5a.parquet \
    --tokenizer-name-or-path ./model_from_hf/deepseek2-coder-hf/ \
    --tokenizer-type PretrainedFromHF \
    --handler-name GeneralPretrainHandler \
    --output-prefix ./dataset/enwiki \
    --json-keys text \
    --workers 4 \
    --log-interval 1000

New file: examples/mcore/deepseek2_coder/pretrain_deepseek2_ptd_8p.sh
@@ -0,0 +1,147 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True

GPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

CKPT_SAVE_DIR="your model save ckpt path"
DATA_PATH="your data path"
TOKENIZER_MODEL="your tokenizer path"
CKPT_LOAD_DIR="your model ckpt path"

TP=1
PP=1
EP=8

DISTRIBUTED_ARGS="
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

MLA_ARGS="
    --multi-head-latent-attention \
    --qk-rope-head-dim 64 \
    --qk-nope-head-dim 128 \
    --q-lora-rank 1536 \
    --kv-lora-rank 512 \
    --v-head-dim 128 \
    --qk-layernorm \
"

MOE_ARGS="
    --moe-grouped-gemm \
    --moe-permutation-async-comm \
    --moe-token-dispatcher-type allgather \
    --first-k-dense-replace 1 \
    --moe-layer-freq 1 \
    --n-shared-experts 2 \
    --num-experts 160 \
    --moe-router-topk 6 \
    --moe-intermediate-size 1536 \
    --moe-router-load-balancing-type group_limited_greedy \
    --topk-group 3 \
    --moe-aux-loss-coeff 0.003 \
    --moe-device-level-aux-loss-coeff 0.05 \
    --moe-comm-aux-loss-coeff 0.02 \
    --routed-scaling-factor 16.0 \
    --seq-aux
"

ROPE_ARGS="
    --rope-scaling-beta-fast 32 \
    --rope-scaling-beta-slow 1 \
    --rope-scaling-factor 40 \
    --rope-scaling-mscale 0.707 \
    --rope-scaling-mscale-all-dim 0.707 \
    --rope-scaling-original-max-position-embeddings 4096 \
    --rope-scaling-type yarn
"

GPT_ARGS="
    --load $CKPT_LOAD_DIR \
    --use-distributed-optimizer \
    --use-flash-attn \
    --shape-order BNSD \
    --use-mcore-models \
    --tensor-model-parallel-size ${TP} \
    --pipeline-model-parallel-size ${PP} \
    --expert-model-parallel-size ${EP} \
    --sequence-parallel \
    --output-layer-slice-num 10 \
    --num-layers 2 \
    --hidden-size 5120 \
    --ffn-hidden-size 12288 \
    --num-attention-heads 128 \
    --tokenizer-type PretrainedFromHF \
    --tokenizer-name-or-path ${TOKENIZER_MODEL} \
    --seq-length 8192 \
    --max-position-embeddings 163840 \
    --micro-batch-size 1 \
    --global-batch-size 64 \
    --make-vocab-size-divisible-by 1 \
    --lr 1.0e-5 \
    --train-iters 2000 \
    --lr-decay-style cosine \
    --untie-embeddings-and-output-weights \
    --disable-bias-linear \
    --attention-dropout 0.0 \
    --init-method-std 0.02 \
    --hidden-dropout 0.0 \
    --position-embedding-type rope \
    --normalization RMSNorm \
    --use-fused-rotary-pos-emb \
    --use-rotary-position-embeddings \
    --use-fused-swiglu \
    --use-fused-rmsnorm \
    --swiglu \
    --no-masked-softmax-fusion \
    --attention-softmax-in-fp32 \
    --min-lr 1.0e-7 \
    --weight-decay 1e-2 \
    --lr-warmup-iters 500 \
    --clip-grad 1.0 \
    --adam-beta1 0.9 \
    --adam-beta2 0.999 \
    --initial-loss-scale 65536 \
    --vocab-size 102400 \
    --padded-vocab-size 102400 \
    --rotary-base 10000 \
    --no-gradient-accumulation-fusion \
    --norm-epsilon 1e-6 \
    --no-load-optim \
    --no-load-rng \
    --bf16
"

DATA_ARGS="
    --data-path $DATA_PATH \
    --split 100,0,0
"

OUTPUT_ARGS="
    --log-interval 1 \
    --save-interval 2000 \
    --eval-interval 2000 \
    --eval-iters 0 \
    --no-save-optim \
    --no-save-rng
"

python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \
    $GPT_ARGS \
    $DATA_ARGS \
    $OUTPUT_ARGS \
    $MLA_ARGS \
    $ROPE_ARGS \
    $MOE_ARGS \
    --distributed-backend nccl \
    --save $CKPT_SAVE_DIR \
    | tee logs/pretrain_deepseek2_ptd_8p.log

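The ROPE_ARGS block above drives YaRN context extension (scaling factor 40, mscale 0.707), and the attention patch below lists yarn_get_mscale among its imports for the matching attention-scale correction. As a hedged sketch of the formula such a helper commonly implements (the exact modellink implementation is not shown in this diff):

import math

def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # No correction when the context window is not extended; otherwise the
    # attention scale grows logarithmically with the RoPE scaling factor.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

# With the values used above (--rope-scaling-factor 40, --rope-scaling-mscale 0.707):
print(yarn_get_mscale(40, 0.707))  # ~1.26
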
@@ -17,6 +17,7 @@ from mindspeed.core.parallel_state import (get_context_parallel_group_for_hybrid
 from modellink.model.transformer import get_attention_mask
 from modellink.core.models.common.embeddings.rotary_pos_embedding import yarn_get_mscale
 from modellink.utils import get_actual_seq_len
+from modellink.model.alibi import Alibi

 try:
     from einops import rearrange

@@ -63,6 +64,19 @@ def dot_product_attention_init_wrapper(fn):

         args = get_args()
         self.attn_logit_softcapping = args.attn_logit_softcapping
+        self.square_alibi_mask = args.square_alibi_mask
+        self.fill_neg_inf = args.fill_neg_inf
+        self.beta = 1.0
+        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
+        if self.apply_query_key_layer_scaling:
+            self.beta = 1.0 / self.layer_number
+
+        if args.position_embedding_type == 'alibi':
+            get_alibi(self, args.seq_length)
+            self.alibi_output_size = None
+        else:
+            self.alibi = None
+
         if args.query_pre_attn_scalar:
             self.norm_factor = args.query_pre_attn_scalar ** 0.5
             self.scale_mask_softmax.scale = 1.0

@@ -87,6 +101,21 @@ def dot_product_attention_init_wrapper(fn):
     return wrapper


+def get_alibi(self, seq_length):
+    args = get_args()
+    self.alibi = Alibi()
+    alibi = self.alibi._build_alibi_tensor(seq_length,
+                                           args.num_attention_heads,
+                                           args.square_alibi_mask,
+                                           args.fill_neg_inf,
+                                           ).to(torch.cuda.current_device())
+    if args.params_dtype == torch.float16:
+        alibi = alibi.to(torch.float16)
+    elif args.params_dtype == torch.bfloat16:
+        alibi = alibi.to(torch.bfloat16)
+    self.alibi.alibi = alibi
+
+
 def dot_product_attention_forward_wrapper(fn):
     @wraps(fn)
     def wrapper(self, query, key, value, attention_mask, attn_mask_type, packed_seq_params):

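get_alibi delegates the actual bias construction to modellink's Alibi helper. For orientation, here is a small self-contained sketch of how an ALiBi bias tensor is typically built (per-head slopes that decay geometrically, multiplied by relative key-query distances); it is an illustration under the usual ALiBi formulation, not the _build_alibi_tensor implementation patched in above:

import torch

def build_alibi_bias(seq_length: int, num_heads: int) -> torch.Tensor:
    # Per-head slopes: a geometric sequence starting at 2 ** (-8 / num_heads).
    start = 2 ** (-8.0 / num_heads)
    slopes = torch.tensor([start ** (i + 1) for i in range(num_heads)])

    # Relative position of each key with respect to each query (<= 0 in the causal part).
    positions = torch.arange(seq_length)
    relative = positions[None, :] - positions[:, None]  # [sq, sk]

    # Bias becomes more negative with distance, scaled per head: [num_heads, sq, sk].
    return slopes[:, None, None] * relative[None, :, :]

bias = build_alibi_bias(seq_length=8, num_heads=4)
print(bias.shape)  # torch.Size([4, 8, 8])
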
@@ -134,18 +163,27 @@ def dot_product_attention_forward_wrapper(fn):
         key = key.view(output_size[3], output_size[0] * output_size[1], -1)

         # preallocting input tensor: [b * np, sq, sk]
-        matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
-            (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu",
-        )
-
-        # Raw attention scores. [b * np, sq, sk]
-        matmul_result = torch.baddbmm(
-            matmul_input_buffer,
-            query.transpose(0, 1),  # [b * np, sq, hn]
-            key.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
-            beta=0.0,
-            alpha=(1.0 / self.norm_factor),
-        )
+        if self.alibi is None:
+            matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
+                (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu",
+            )
+
+            # Raw attention scores. [b * np, sq, sk]
+            matmul_result = torch.baddbmm(
+                matmul_input_buffer,
+                query.transpose(0, 1),  # [b * np, sq, hn]
+                key.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+                beta=0.0,
+                alpha=(1.0 / self.norm_factor),
+            )
+        else:
+            if self.alibi.alibi_pse is None or self.alibi.output_size != output_size:
+                self.alibi.output_size = output_size
+                self.alibi.get_alibi_pse(attention_mask, output_size[0], output_size[2], output_size[3])
+
+            q_trans = query.transpose(0, 1).contiguous()
+            k_trans = key.transpose(0, 1).transpose(1, 2).contiguous()
+            matmul_result = self.beta * self.alibi.alibi_pse + torch.bmm(q_trans, k_trans) * (1.0 / self.norm_factor)

         if self.attn_logit_softcapping is not None:
             matmul_result = matmul_result / self.attn_logit_softcapping

@@ -160,7 +198,13 @@ def dot_product_attention_forward_wrapper(fn):
         # ===========================

         # attention scores and attention mask [b, np, sq, sk]
-        attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
+        if self.square_alibi_mask:
+            attention_scores = torch.max(
+                attention_scores, torch.tensor(torch.finfo(attention_scores.dtype).min)
+            )
+            attention_probs = torch.nn.functional.softmax(attention_scores, -1)
+        else:
+            attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)

         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.

@@ -224,7 +268,7 @@ def dot_product_attention_forward(

     args = get_args()

-    seq_length, _, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3]
+    seq_length, batch_size, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3]
     actual_seq_len = None
     if args.reset_attention_mask or args.reset_position_ids:
         query, key, value = [rearrange(x, 's b h d -> (s b) h d') for x in [query, key, value]]

@@ -241,9 +285,18 @@ def dot_product_attention_forward(
         raise AssertionError("self.hidden_size_per_attention_head should not be ZERO.")
     scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head) \
         if self.scale_mask_softmax.scale is None else self.softmax_scale
-    if attention_mask is None or args.reset_attention_mask or args.reset_position_ids:
-        attention_mask = get_attention_mask()
+    if not hasattr(self, 'attention_mask') or \
+            self.attention_mask is None or \
+            self.attention_mask.shape[0] != seq_length or \
+            args.reset_attention_mask or args.reset_position_ids:
+        if self.alibi is not None:
+            self.attention_mask = torch.triu(
+                torch.ones(seq_length, seq_length),
+                1).bool().npu()
+        elif attention_mask is None:
+            self.attention_mask = get_attention_mask()
+        else:
+            self.attention_mask = attention_mask
     if args.context_parallel_size > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']:
         return do_ring_context_parallel(
             query, key, value, head_num=n_head, softmax_scale=scale, attn_mask=attention_mask)

@@ -253,11 +306,27 @@ def dot_product_attention_forward(
     if use_sliding_windows:
         args.pre_tockens = args.sliding_window

+    pse = None
+    size_record = key.shape
+    if self.alibi is not None and (self.alibi.output_size != size_record) and pse is None:
+        if args.shape_order != 'SBH':
+            raise ValueError(
+                'FlashAttention with Alibi requires for SBH shape_order, but is {}.'.format(args.shape_order))
+
+        self.alibi.output_size = size_record
+        self.alibi.get_alibi_pse(self.attention_mask, batch_size, query.shape[0], key.shape[0])
+
+    if self.alibi and pse is None:
+        pse = self.alibi.alibi_pse.reshape(
+            batch_size, n_head, self.alibi.alibi_pse.size(1), -1) * self.beta * self.norm_factor
+        args.pre_tockens = seq_length
+        args.sparse_mode = 0
+
     output = torch_npu.npu_fusion_attention(
         query, key, value, n_head, args.shape_order,
-        pse=None,
+        pse=pse,
         padding_mask=None,
-        atten_mask=attention_mask,
+        atten_mask=self.attention_mask,
         actual_seq_qlen=actual_seq_len,
         actual_seq_kvlen=actual_seq_len,
         scale=scale,

@@ -259,6 +259,16 @@
                 "layers_mlp_linear_fc2": "model.layers[layer_idx].mlp.experts[expert_idx].w2",
                 "final_layernorm": "model.norm"
             }
+        },
+        "baichuan2": {
+            "__base__": "base",
+            "config_set_value": {
+                "qkv_type": "pack_gqa",
+                "max_position_embeddings": 4096
+            },
+            "model_hf_key_mapping": {
+                "layers_self_attention_linear_qkv_pack": "model.layers[layer_idx].self_attn.W_pack"
+            }
         }
     }
 }

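The only extra key mapping baichuan2 needs on top of the base template is W_pack, Baichuan2's fused QKV projection. As a hedged sketch of how such a packed weight is commonly split back into separate Q/K/V matrices during conversion (illustrative only; the converter's actual splitting logic is not part of this diff):

import torch

def split_w_pack(w_pack: torch.Tensor, hidden_size: int):
    # W_pack stacks the Q, K and V projection weights along the output dimension,
    # so an equal three-way split along dim 0 recovers them.
    q, k, v = torch.split(w_pack, hidden_size, dim=0)
    return q, k, v

# Baichuan2-13B uses hidden_size 5120, so W_pack has shape [3 * 5120, 5120].
w_pack = torch.empty(3 * 5120, 5120)
q, k, v = split_w_pack(w_pack, 5120)
print(q.shape, k.shape, v.shape)  # three [5120, 5120] matrices
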