diff --git a/README.md b/README.md
index 7b28b72d..7489d0fb 100644
--- a/README.md
+++ b/README.md
@@ -186,10 +186,10 @@ ModelLink 通过模型并行与数据并行来训练大语言模型,为了演
13B |
4K |
- Legacy |
+ Mcore |
1x8 |
BF16 |
- 1668 |
+ 1754 |
-- |
2062 |
【昇腾】 |
diff --git a/convert_ckpt.py b/convert_ckpt.py
index 30e97d43..0b7b1c09 100644
--- a/convert_ckpt.py
+++ b/convert_ckpt.py
@@ -72,7 +72,7 @@ def main():
help='Do not perform checking on the name and ordering of weights',
dest='checking')
parser.add_argument('--model-type-hf', type=str, default="llama2",
- choices=['llama2', 'mixtral', 'chatglm3', 'gemma', 'gemma2', 'bloom', 'qwen', 'internlm2', 'deepseek2', 'minicpm', 'minicpm-moe'],
+ choices=['baichuan2', 'llama2', 'mixtral', 'chatglm3', 'gemma', 'gemma2', 'bloom', 'qwen', 'internlm2', 'deepseek2', 'minicpm', 'minicpm-moe'],
help='model type of huggingface')
known_args, _ = parser.parse_known_args()
diff --git a/examples/mcore/baichuan2/ckpt_convert_baichuan2_13b_hf2mcore.sh b/examples/mcore/baichuan2/ckpt_convert_baichuan2_13b_hf2mcore.sh
new file mode 100644
index 00000000..a8b2dc1d
--- /dev/null
+++ b/examples/mcore/baichuan2/ckpt_convert_baichuan2_13b_hf2mcore.sh
@@ -0,0 +1,16 @@
+# Modify the ascend-toolkit path for your environment
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+# Weight format conversion
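+# Note: the target TP/PP below (TP=8, PP=1) must match the parallel layout used by the pretrain/inference/evaluate scripts for Baichuan2-13B.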
+python convert_ckpt.py \
+ --use-mcore-models \
+ --model-type-hf baichuan2 \
+ --model-type GPT \
+ --load-model-type hf \
+ --save-model-type mg \
+ --params-dtype bf16 \
+ --target-tensor-parallel-size 8 \
+ --target-pipeline-parallel-size 1 \
+ --load-dir ./model_from_hf/Baichuan2-13B_hf/ \
+ --save-dir ./model_weights/Baichuan2-13B_mcore/ \
+ --tokenizer-model ./model_from_hf/Baichuan2-13B_hf/tokenizer.model
diff --git a/examples/mcore/baichuan2/ckpt_convert_baichuan2_13b_mcore2hf.sh b/examples/mcore/baichuan2/ckpt_convert_baichuan2_13b_mcore2hf.sh
new file mode 100644
index 00000000..a691d514
--- /dev/null
+++ b/examples/mcore/baichuan2/ckpt_convert_baichuan2_13b_mcore2hf.sh
@@ -0,0 +1,15 @@
+# Modify the ascend-toolkit path for your environment
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+# Weight format conversion
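+# Target TP=1/PP=1 merges the mcore tensor-parallel shards back into a single Hugging Face-format checkpoint.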
+python convert_ckpt.py \
+ --use-mcore-models \
+ --model-type-hf baichuan2 \
+ --model-type GPT \
+ --load-model-type mg \
+ --save-model-type hf \
+ --params-dtype bf16 \
+ --target-tensor-parallel-size 1 \
+ --target-pipeline-parallel-size 1 \
+ --load-dir ./model_weights/Baichuan2-13B_mcore/ \
+ --save-dir ./model_from_hf/Baichuan2-13B_hf/
diff --git a/examples/mcore/baichuan2/data_convert_baichuan2_pretrain.sh b/examples/mcore/baichuan2/data_convert_baichuan2_pretrain.sh
new file mode 100644
index 00000000..456d68b0
--- /dev/null
+++ b/examples/mcore/baichuan2/data_convert_baichuan2_pretrain.sh
@@ -0,0 +1,12 @@
+# Modify the set_env.sh path according to your actual environment
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+mkdir ./dataset
+
+# Dataset download URL: https://huggingface.co/datasets/tatsu-lab/alpaca/resolve/main/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet
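+# The preprocessed output is written under the ./dataset/alpaca prefix (alpaca_text_document.bin/.idx by default); point DATA_PATH in the pretrain script to this prefix.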
+python ./preprocess_data.py \
+ --input ./dataset/train-00000-of-00001-a09b74b3ef9c3b56.parquet \
+    --tokenizer-name-or-path ./model_from_hf/Baichuan2-13B_hf/ \
+ --output-prefix ./dataset/alpaca \
+ --workers 4 \
+ --log-interval 1000 \
+ --tokenizer-type PretrainedFromHF
diff --git a/examples/mcore/baichuan2/evaluate_baichuan2_13b_ptd.sh b/examples/mcore/baichuan2/evaluate_baichuan2_13b_ptd.sh
new file mode 100644
index 00000000..c0942925
--- /dev/null
+++ b/examples/mcore/baichuan2/evaluate_baichuan2_13b_ptd.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+# Change for multinode config
+MASTER_ADDR=localhost
+MASTER_PORT=6001
+NNODES=1
+NODE_RANK=0
+NPUS_PER_NODE=8
+
+WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
+
+DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
+
+CHECKPOINT="Your ckpt file path"
+TOKENIZER_PATH="Your tokenizer path"
+DATA_PATH="./boolq/data/test/"
+TASK="boolq"
+
+# Different tasks need different max_new_tokens values; please follow the instructions in the README.
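+# The alibi flags below (--position-embedding-type alibi, --square-alibi-mask, --fill-neg-inf) must match the configuration the checkpoint was trained with.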
+torchrun $DISTRIBUTED_ARGS evaluation.py \
+ --task-data-path ${DATA_PATH} \
+ --task ${TASK} \
+ --seq-length 4096 \
+ --max-new-tokens 1 \
+ --max-position-embeddings 4096 \
+ --tensor-model-parallel-size 8 \
+ --pipeline-model-parallel-size 1 \
+ --num-layers 40 \
+ --hidden-size 5120 \
+ --ffn-hidden-size 13696 \
+ --num-attention-heads 40 \
+ --disable-bias-linear \
+ --swiglu \
+ --position-embedding-type alibi \
+ --square-alibi-mask \
+ --fill-neg-inf \
+ --load $CHECKPOINT \
+ --normalization RMSNorm \
+ --tokenizer-type PretrainedFromHF \
+ --tokenizer-name-or-path ${TOKENIZER_PATH} \
+ --tokenizer-not-use-fast \
+ --fp16 \
+ --micro-batch-size 1 \
+ --use-fused-rmsnorm \
+ --exit-on-missing-checkpoint \
+ --no-load-rng \
+ --no-load-optim \
+ --untie-embeddings-and-output-weights \
+ --no-masked-softmax-fusion \
+ --make-vocab-size-divisible-by 32 \
+ --use-mcore-models \
+ --seed 42 | tee logs/eval_baichuan2_13b_mcore_${TASK}.log
diff --git a/examples/mcore/baichuan2/generate_baichuan2_13b_ptd.sh b/examples/mcore/baichuan2/generate_baichuan2_13b_ptd.sh
new file mode 100644
index 00000000..b474774b
--- /dev/null
+++ b/examples/mcore/baichuan2/generate_baichuan2_13b_ptd.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+# please fill these path configurations
+CHECKPOINT="your model directory path"
+TOKENIZER_PATH="your tokenizer directory path"
+
+# Change for multinode config
+MASTER_ADDR=localhost
+MASTER_PORT=6001
+NNODES=1
+NODE_RANK=0
+NPUS_PER_NODE=8
+WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
+
+DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
+
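+# TP=8 assumes the checkpoint was converted with --target-tensor-parallel-size 8 (see ckpt_convert_baichuan2_13b_hf2mcore.sh).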
+torchrun $DISTRIBUTED_ARGS inference.py \
+ --use-mcore-models \
+ --tensor-model-parallel-size 8 \
+ --pipeline-model-parallel-size 1 \
+ --num-layers 40 \
+ --hidden-size 5120 \
+ --ffn-hidden-size 13696 \
+ --seq-length 1024 \
+ --max-new-tokens 256 \
+ --micro-batch-size 1 \
+ --global-batch-size 16 \
+ --num-attention-heads 40 \
+ --max-position-embeddings 2048 \
+ --position-embedding-type alibi \
+ --square-alibi-mask \
+ --fill-neg-inf \
+ --swiglu \
+ --load ${CHECKPOINT} \
+ --tokenizer-type PretrainedFromHF \
+ --tokenizer-name-or-path ${TOKENIZER_PATH} \
+ --tokenizer-not-use-fast \
+ --fp16 \
+ --normalization RMSNorm \
+ --untie-embeddings-and-output-weights \
+ --disable-bias-linear \
+ --attention-softmax-in-fp32 \
+ --no-load-optim \
+ --no-load-rng \
+ --no-masked-softmax-fusion \
+ --no-gradient-accumulation-fusion \
+ --exit-on-missing-checkpoint \
+ --make-vocab-size-divisible-by 32 \
+ | tee logs/generate_baichuan2_13b_mcore.log
+
+
diff --git a/examples/mcore/baichuan2/pretrain_baichuan2_13b_ptd.sh b/examples/mcore/baichuan2/pretrain_baichuan2_13b_ptd.sh
new file mode 100644
index 00000000..1a420c0b
--- /dev/null
+++ b/examples/mcore/baichuan2/pretrain_baichuan2_13b_ptd.sh
@@ -0,0 +1,100 @@
+#!/bin/bash
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
+
+GPUS_PER_NODE=8
+MASTER_ADDR=localhost
+MASTER_PORT=6000
+NNODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+
+CKPT_SAVE_DIR="your model save ckpt path"
+DATA_PATH="your data path"
+TOKENIZER_MODEL="your tokenizer model path"
+CKPT_LOAD_DIR="your model load ckpt path"
+
+TP=8
+PP=1
+
+
+DISTRIBUTED_ARGS="
+ --nproc_per_node $GPUS_PER_NODE \
+ --nnodes $NNODES \
+ --node_rank $NODE_RANK \
+ --master_addr $MASTER_ADDR \
+ --master_port $MASTER_PORT
+"
+
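+# Baichuan2-13B uses alibi positional bias instead of RoPE; --square-alibi-mask and --fill-neg-inf select the square alibi mask handled in modellink/core/transformer/dot_product_attention.py.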
+GPT_ARGS="
+ --use-mcore-models \
+ --tensor-model-parallel-size ${TP} \
+ --pipeline-model-parallel-size ${PP} \
+ --sequence-parallel \
+ --use-mc2 \
+ --num-layers 40 \
+ --hidden-size 5120 \
+ --ffn-hidden-size 13696 \
+ --num-attention-heads 40 \
+ --tokenizer-type Llama2Tokenizer \
+ --tokenizer-model ${TOKENIZER_MODEL} \
+ --seq-length 4096 \
+ --disable-bias-linear \
+ --max-position-embeddings 4096 \
+ --micro-batch-size 2 \
+ --global-batch-size 128 \
+ --untie-embeddings-and-output-weights \
+ --no-gradient-accumulation-fusion \
+ --make-vocab-size-divisible-by 32 \
+ --lr 1e-5 \
+ --load ${CKPT_LOAD_DIR} \
+ --train-iters 2000 \
+ --lr-decay-style cosine \
+ --attention-dropout 0.0 \
+ --init-method-std 0.01 \
+ --position-embedding-type alibi \
+ --hidden-dropout 0.0 \
+ --norm-epsilon 1e-6 \
+ --normalization RMSNorm \
+ --use-fused-rmsnorm \
+ --use-flash-attn \
+ --use-fused-swiglu \
+ --swiglu \
+ --no-masked-softmax-fusion \
+ --attention-softmax-in-fp32 \
+ --square-alibi-mask \
+ --fill-neg-inf \
+ --min-lr 1e-8 \
+ --weight-decay 1e-4 \
+ --clip-grad 1.0 \
+ --seed 1234 \
+ --adam-beta1 0.9 \
+ --initial-loss-scale 8188.0 \
+ --adam-beta2 0.98 \
+ --adam-eps 1.0e-8 \
+ --no-load-optim \
+ --no-load-rng \
+ --bf16
+"
+
+DATA_ARGS="
+ --data-path ${DATA_PATH} \
+ --split 100,0,0
+"
+
+OUTPUT_ARGS="
+ --log-interval 1 \
+ --save-interval 2000 \
+ --eval-interval 2000 \
+ --eval-iters 0
+"
+
+torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
+ $GPT_ARGS \
+ $DATA_ARGS \
+ $OUTPUT_ARGS \
+ --distributed-backend nccl \
+ --save ${CKPT_SAVE_DIR} \
+ | tee logs/train_baichuan2_13b_mcore.log
diff --git a/examples/mcore/deepseek2_coder/ckpt_convert_deepseek2_hf2mcore.sh b/examples/mcore/deepseek2_coder/ckpt_convert_deepseek2_hf2mcore.sh
new file mode 100644
index 00000000..79c770b5
--- /dev/null
+++ b/examples/mcore/deepseek2_coder/ckpt_convert_deepseek2_hf2mcore.sh
@@ -0,0 +1,18 @@
+# Modify the ascend-toolkit path for your environment
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+# Weight format conversion; set the required parallel strategy
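+# EP=8 matches --expert-model-parallel-size in pretrain_deepseek2_ptd_8p.sh; TP and PP stay at 1.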
+python convert_ckpt.py \
+ --use-mcore-models \
+ --moe-grouped-gemm \
+ --model-type-hf deepseek2 \
+ --model-type GPT \
+ --load-model-type hf \
+ --save-model-type mg \
+ --params-dtype bf16 \
+ --target-tensor-parallel-size 1 \
+ --target-pipeline-parallel-size 1 \
+ --target-expert-parallel-size 8 \
+ --load-dir ./model_from_hf/deepseek2-coder-hf/ \
+ --save-dir ./model_weights/deepseek2-coder-mcore/ \
+ --tokenizer-model ./model_from_hf/deepseek2-coder-hf/
\ No newline at end of file
diff --git a/examples/mcore/deepseek2_coder/ckpt_convert_deepseek2_mcore2hf.sh b/examples/mcore/deepseek2_coder/ckpt_convert_deepseek2_mcore2hf.sh
new file mode 100644
index 00000000..7decb6ef
--- /dev/null
+++ b/examples/mcore/deepseek2_coder/ckpt_convert_deepseek2_mcore2hf.sh
@@ -0,0 +1,17 @@
+# Modify the ascend-toolkit path for your environment
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+# Weight format conversion
+python convert_ckpt.py \
+ --use-mcore-models \
+ --moe-grouped-gemm \
+ --model-type-hf deepseek2 \
+ --model-type GPT \
+ --load-model-type mg \
+ --save-model-type hf \
+ --params-dtype bf16 \
+ --target-tensor-parallel-size 1 \
+ --target-pipeline-parallel-size 1 \
+ --target-expert-parallel-size 1 \
+ --load-dir ./model_weights/deepseek2-coder-mcore/ \
+ --save-dir ./model_from_hf/deepseek2-coder-hf/
\ No newline at end of file
diff --git a/examples/mcore/deepseek2_coder/data_convert_deepseek2_pretrain.sh b/examples/mcore/deepseek2_coder/data_convert_deepseek2_pretrain.sh
new file mode 100644
index 00000000..051054ee
--- /dev/null
+++ b/examples/mcore/deepseek2_coder/data_convert_deepseek2_pretrain.sh
@@ -0,0 +1,14 @@
+# Modify the set_env.sh path according to your actual environment
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+mkdir ./dataset
+
+# Dataset download URL: https://huggingface.co/datasets/lsb/enwiki20230101/blob/main/data/train-00000-of-00042-d964455e17e96d5a.parquet
+python ./preprocess_data.py \
+ --input ./dataset/train-00000-of-00042-d964455e17e96d5a.parquet \
+ --tokenizer-name-or-path ./model_from_hf/deepseek2-coder-hf/ \
+ --tokenizer-type PretrainedFromHF \
+ --handler-name GeneralPretrainHandler \
+ --output-prefix ./dataset/enwiki \
+ --json-keys text \
+ --workers 4 \
+ --log-interval 1000
\ No newline at end of file
diff --git a/examples/mcore/deepseek2_coder/pretrain_deepseek2_ptd_8p.sh b/examples/mcore/deepseek2_coder/pretrain_deepseek2_ptd_8p.sh
new file mode 100644
index 00000000..c05e5927
--- /dev/null
+++ b/examples/mcore/deepseek2_coder/pretrain_deepseek2_ptd_8p.sh
@@ -0,0 +1,147 @@
+#!/bin/bash
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
+
+GPUS_PER_NODE=8
+MASTER_ADDR=localhost
+MASTER_PORT=6000
+NNODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+
+CKPT_SAVE_DIR="your model save ckpt path"
+DATA_PATH="your data path"
+TOKENIZER_MODEL="your tokenizer path"
+CKPT_LOAD_DIR="your model ckpt path"
+
+TP=1
+PP=1
+EP=8
+
+DISTRIBUTED_ARGS="
+ --nproc_per_node $GPUS_PER_NODE \
+ --nnodes $NNODES \
+ --node_rank $NODE_RANK \
+ --master_addr $MASTER_ADDR \
+ --master_port $MASTER_PORT
+"
+
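+# Multi-head latent attention (MLA): separate rope/nope head dims plus low-rank q/kv projections, matching the published DeepSeek-V2 / DeepSeek-Coder-V2 configuration.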
+MLA_ARGS="
+ --multi-head-latent-attention \
+ --qk-rope-head-dim 64 \
+ --qk-nope-head-dim 128 \
+ --q-lora-rank 1536 \
+ --kv-lora-rank 512 \
+ --v-head-dim 128 \
+ --qk-layernorm \
+"
+
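+# MoE layout: 160 routed experts with top-6 group-limited greedy routing plus 2 shared experts; --first-k-dense-replace 1 keeps the first layer dense.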
+MOE_ARGS="
+ --moe-grouped-gemm \
+ --moe-permutation-async-comm \
+ --moe-token-dispatcher-type allgather \
+ --first-k-dense-replace 1 \
+ --moe-layer-freq 1 \
+ --n-shared-experts 2 \
+ --num-experts 160 \
+ --moe-router-topk 6 \
+ --moe-intermediate-size 1536 \
+ --moe-router-load-balancing-type group_limited_greedy \
+ --topk-group 3 \
+ --moe-aux-loss-coeff 0.003 \
+ --moe-device-level-aux-loss-coeff 0.05 \
+ --moe-comm-aux-loss-coeff 0.02 \
+ --routed-scaling-factor 16.0 \
+ --seq-aux
+"
+
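+# YaRN rope scaling: the original 4096-token context is extended by a factor of 40, matching --max-position-embeddings 163840.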
+ROPE_ARGS="
+ --rope-scaling-beta-fast 32 \
+ --rope-scaling-beta-slow 1 \
+ --rope-scaling-factor 40 \
+ --rope-scaling-mscale 0.707 \
+ --rope-scaling-mscale-all-dim 0.707 \
+ --rope-scaling-original-max-position-embeddings 4096 \
+ --rope-scaling-type yarn
+"
+
+GPT_ARGS="
+ --load $CKPT_LOAD_DIR \
+ --use-distributed-optimizer \
+ --use-flash-attn \
+ --shape-order BNSD \
+ --use-mcore-models \
+ --tensor-model-parallel-size ${TP} \
+ --pipeline-model-parallel-size ${PP} \
+ --expert-model-parallel-size ${EP} \
+ --sequence-parallel \
+ --output-layer-slice-num 10 \
+ --num-layers 2 \
+ --hidden-size 5120 \
+ --ffn-hidden-size 12288 \
+ --num-attention-heads 128 \
+ --tokenizer-type PretrainedFromHF \
+ --tokenizer-name-or-path ${TOKENIZER_MODEL} \
+ --seq-length 8192 \
+ --max-position-embeddings 163840 \
+ --micro-batch-size 1 \
+ --global-batch-size 64 \
+ --make-vocab-size-divisible-by 1 \
+ --lr 1.0e-5 \
+ --train-iters 2000 \
+ --lr-decay-style cosine \
+ --untie-embeddings-and-output-weights \
+ --disable-bias-linear \
+ --attention-dropout 0.0 \
+ --init-method-std 0.02 \
+ --hidden-dropout 0.0 \
+ --position-embedding-type rope \
+ --normalization RMSNorm \
+ --use-fused-rotary-pos-emb \
+ --use-rotary-position-embeddings \
+ --use-fused-swiglu \
+ --use-fused-rmsnorm \
+ --swiglu \
+ --no-masked-softmax-fusion \
+ --attention-softmax-in-fp32 \
+ --min-lr 1.0e-7 \
+ --weight-decay 1e-2 \
+ --lr-warmup-iters 500 \
+ --clip-grad 1.0 \
+ --adam-beta1 0.9 \
+ --adam-beta2 0.999 \
+ --initial-loss-scale 65536 \
+ --vocab-size 102400 \
+ --padded-vocab-size 102400 \
+ --rotary-base 10000 \
+ --no-gradient-accumulation-fusion \
+ --norm-epsilon 1e-6 \
+ --no-load-optim \
+ --no-load-rng \
+ --bf16
+"
+
+DATA_ARGS="
+ --data-path $DATA_PATH \
+ --split 100,0,0
+"
+
+OUTPUT_ARGS="
+ --log-interval 1 \
+ --save-interval 2000 \
+ --eval-interval 2000 \
+ --eval-iters 0 \
+ --no-save-optim \
+ --no-save-rng
+"
+
+torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
+ $GPT_ARGS \
+ $DATA_ARGS \
+ $OUTPUT_ARGS \
+ $MLA_ARGS \
+ $ROPE_ARGS \
+ $MOE_ARGS \
+ --distributed-backend nccl \
+ --save $CKPT_SAVE_DIR \
+ | tee logs/pretrain_deepseek2_ptd_8p.log
diff --git a/modellink/core/transformer/dot_product_attention.py b/modellink/core/transformer/dot_product_attention.py
index 8d10b77b..a4668064 100644
--- a/modellink/core/transformer/dot_product_attention.py
+++ b/modellink/core/transformer/dot_product_attention.py
@@ -17,6 +17,7 @@ from mindspeed.core.parallel_state import (get_context_parallel_group_for_hybrid
from modellink.model.transformer import get_attention_mask
from modellink.core.models.common.embeddings.rotary_pos_embedding import yarn_get_mscale
from modellink.utils import get_actual_seq_len
+from modellink.model.alibi import Alibi
try:
from einops import rearrange
@@ -63,6 +64,19 @@ def dot_product_attention_init_wrapper(fn):
args = get_args()
self.attn_logit_softcapping = args.attn_logit_softcapping
+ self.square_alibi_mask = args.square_alibi_mask
+ self.fill_neg_inf = args.fill_neg_inf
+ self.beta = 1.0
+ self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
+ if self.apply_query_key_layer_scaling:
+ self.beta = 1.0 / self.layer_number
+
+ if args.position_embedding_type == 'alibi':
+ get_alibi(self, args.seq_length)
+ self.alibi_output_size = None
+ else:
+ self.alibi = None
+
if args.query_pre_attn_scalar:
self.norm_factor = args.query_pre_attn_scalar ** 0.5
self.scale_mask_softmax.scale = 1.0
@@ -87,6 +101,21 @@ def dot_product_attention_init_wrapper(fn):
return wrapper
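+# Helper: build the alibi bias for the configured sequence length, cast it to the params dtype and cache it on self.alibi.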
+def get_alibi(self, seq_length):
+ args = get_args()
+ self.alibi = Alibi()
+ alibi = self.alibi._build_alibi_tensor(seq_length,
+ args.num_attention_heads,
+ args.square_alibi_mask,
+ args.fill_neg_inf,
+ ).to(torch.cuda.current_device())
+ if args.params_dtype == torch.float16:
+ alibi = alibi.to(torch.float16)
+ elif args.params_dtype == torch.bfloat16:
+ alibi = alibi.to(torch.bfloat16)
+ self.alibi.alibi = alibi
+
+
def dot_product_attention_forward_wrapper(fn):
@wraps(fn)
def wrapper(self, query, key, value, attention_mask, attn_mask_type, packed_seq_params):
@@ -134,18 +163,27 @@ def dot_product_attention_forward_wrapper(fn):
key = key.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocting input tensor: [b * np, sq, sk]
- matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
- (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu",
- )
+ if self.alibi is None:
+ matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
+ (output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu",
+ )
- # Raw attention scores. [b * np, sq, sk]
- matmul_result = torch.baddbmm(
- matmul_input_buffer,
- query.transpose(0, 1), # [b * np, sq, hn]
- key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
- beta=0.0,
- alpha=(1.0 / self.norm_factor),
- )
+ # Raw attention scores. [b * np, sq, sk]
+ matmul_result = torch.baddbmm(
+ matmul_input_buffer,
+ query.transpose(0, 1), # [b * np, sq, hn]
+ key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
+ beta=0.0,
+ alpha=(1.0 / self.norm_factor),
+ )
+ else:
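+            # Alibi path: rebuild the cached positional bias when the attention shape changes, then add self.beta * alibi_pse to the raw q·k scores scaled by 1/norm_factor.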
+ if self.alibi.alibi_pse is None or self.alibi.output_size != output_size:
+ self.alibi.output_size = output_size
+ self.alibi.get_alibi_pse(attention_mask, output_size[0], output_size[2], output_size[3])
+
+ q_trans = query.transpose(0, 1).contiguous()
+ k_trans = key.transpose(0, 1).transpose(1, 2).contiguous()
+ matmul_result = self.beta * self.alibi.alibi_pse + torch.bmm(q_trans, k_trans) * (1.0 / self.norm_factor)
if self.attn_logit_softcapping is not None:
matmul_result = matmul_result / self.attn_logit_softcapping
@@ -160,7 +198,13 @@ def dot_product_attention_forward_wrapper(fn):
# ===========================
# attention scores and attention mask [b, np, sq, sk]
- attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
+ if self.square_alibi_mask:
+ attention_scores = torch.max(
+ attention_scores, torch.tensor(torch.finfo(attention_scores.dtype).min)
+ )
+ attention_probs = torch.nn.functional.softmax(attention_scores, -1)
+ else:
+ attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
@@ -224,7 +268,7 @@ def dot_product_attention_forward(
args = get_args()
- seq_length, _, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3]
+ seq_length, batch_size, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3]
actual_seq_len = None
if args.reset_attention_mask or args.reset_position_ids:
query, key, value = [rearrange(x, 's b h d -> (s b) h d') for x in [query, key, value]]
@@ -241,9 +285,18 @@ def dot_product_attention_forward(
raise AssertionError("self.hidden_size_per_attention_head should not be ZERO.")
scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head) \
if self.scale_mask_softmax.scale is None else self.softmax_scale
- if attention_mask is None or args.reset_attention_mask or args.reset_position_ids:
- attention_mask = get_attention_mask()
-
+ if not hasattr(self, 'attention_mask') or \
+ self.attention_mask is None or \
+ self.attention_mask.shape[0] != seq_length or \
+ args.reset_attention_mask or args.reset_position_ids:
+ if self.alibi is not None:
+ self.attention_mask = torch.triu(
+ torch.ones(seq_length, seq_length),
+ 1).bool().npu()
+ elif attention_mask is None:
+ self.attention_mask = get_attention_mask()
+ else:
+ self.attention_mask = attention_mask
if args.context_parallel_size > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']:
return do_ring_context_parallel(
query, key, value, head_num=n_head, softmax_scale=scale, attn_mask=attention_mask)
@@ -253,11 +306,27 @@ def dot_product_attention_forward(
if use_sliding_windows:
args.pre_tockens = args.sliding_window
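+    # With alibi enabled, pass the precomputed bias to npu_fusion_attention via pse; this requires SBH shape_order (checked below).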
+ pse = None
+ size_record = key.shape
+ if self.alibi is not None and (self.alibi.output_size != size_record) and pse is None:
+ if args.shape_order != 'SBH':
+ raise ValueError(
+ 'FlashAttention with Alibi requires for SBH shape_order, but is {}.'.format(args.shape_order))
+
+ self.alibi.output_size = size_record
+ self.alibi.get_alibi_pse(self.attention_mask, batch_size, query.shape[0], key.shape[0])
+
+ if self.alibi and pse is None:
+ pse = self.alibi.alibi_pse.reshape(
+ batch_size, n_head, self.alibi.alibi_pse.size(1), -1) * self.beta * self.norm_factor
+ args.pre_tockens = seq_length
+ args.sparse_mode = 0
+
output = torch_npu.npu_fusion_attention(
query, key, value, n_head, args.shape_order,
- pse=None,
+ pse=pse,
padding_mask=None,
- atten_mask=attention_mask,
+ atten_mask=self.attention_mask,
actual_seq_qlen=actual_seq_len,
actual_seq_kvlen=actual_seq_len,
scale=scale,
diff --git a/modellink/tasks/checkpoint/model_cfg.json b/modellink/tasks/checkpoint/model_cfg.json
index 68475619..663b68e6 100644
--- a/modellink/tasks/checkpoint/model_cfg.json
+++ b/modellink/tasks/checkpoint/model_cfg.json
@@ -259,6 +259,16 @@
"layers_mlp_linear_fc2": "model.layers[layer_idx].mlp.experts[expert_idx].w2",
"final_layernorm": "model.norm"
}
+ },
+ "baichuan2": {
+ "__base__": "base",
+ "config_set_value": {
+ "qkv_type": "pack_gqa",
+ "max_position_embeddings": 4096
+ },
+ "model_hf_key_mapping": {
+ "layers_self_attention_linear_qkv_pack": "model.layers[layer_idx].self_attn.W_pack"
+ }
}
}
}