!1568 Add support for the Gemma2 model series

Merge pull request !1568 from yuhui/gemma2
This commit is contained in:
yuhui 2024-09-03 01:57:47 +00:00 committed by i-robot
parent 87624bbb05
commit 705158ab41
21 changed files with 741 additions and 13 deletions

View File

@ -744,6 +744,29 @@ ModelLink trains large language models with model parallelism and data parallelism; to demonstrate…
<td><center>【GTS】</td>
<td>【Test】</td>
</tr>
<tr>
<td rowspan="2"><a href="https://huggingface.co/google">Gemma2</a></td>
<td><a href="https://huggingface.co/google/gemma-2-9b/tree/main">9B</a></td>
<td> 8K </td>
<th>Mcore</th>
<td>1x8</td>
<td>BF16 </td>
<td> 1713 </td>
<td> 1595 </td>
<td><center>【GTS】</td>
<td>【Test】</td>
</tr>
<tr>
<td><a href="https://huggingface.co/google/gemma-2-27b/tree/main">27B</a></td>
<td> 8K </td>
<th>Mcore</th>
<td>2x8</td>
<td>BF16 </td>
<td> 827 </td>
<td> 800 </td>
<td><center>【GTS】</td>
<td>【Test】</td>
</tr>
<tr>
<td rowspan="2">GPT3</td>
<td>175B</td>

View File

@ -785,6 +785,7 @@ The model evaluation scores supported by ModelLink are as follows
| Bloom-176B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 64.5% | -- | ChatGLM3-6B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 61.5% | 61.4% |
| CodeLLaMA-34B | <a href="https://huggingface.co/datasets/openai_humaneval">Human Eval</a> | 48.78% | [48.8%](https://paperswithcode.com/sota/code-generation-on-humaneval) | Gemma-2B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 39.6% | 39.7% |
| Gemma-7B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 52.2% | 52.2% | InternLM-7B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 48.7% | [51.0](https://huggingface.co/internlm/internlm-7b) |
| Gemma2-9B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 70.7% | [71.3%](https://huggingface.co/google/gemma-2-9b) | Gemma2-27B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 75.5% | [75.2%](https://huggingface.co/google/gemma-2-27b) |
| LLaMA-7B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 74.6% | [75.4](https://hub.opencompass.org.cn/dataset-detail/BoolQ) | LLaMA-13B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 79.6% | [78.7](https://hub.opencompass.org.cn/dataset-detail/BoolQ) |
| LLaMA-33B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 83.2% | [83.1](https://paperswithcode.com/sota/question-answering-on-boolq) | LLaMA-65B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 85.7% | [86.6](https://paperswithcode.com/sota/question-answering-on-boolq) |
| LLaMA2-7B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 45.7% | 45.3% | LLaMA2-13B | [BoolQ](https://paperswithcode.com/dataset/boolq) | 82.2% | [81.7](https://paperswithcode.com/sota/question-answering-on-boolq) |

View File

@ -0,0 +1,17 @@
# Modify the ascend-toolkit path
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# Convert the weight format
python tools/checkpoint/convert_ckpt.py \
--use-mcore-models \
--model-type-hf gemma2 \
--model-type GPT \
--loader hf_mcore \
--saver mg_mcore \
--params-dtype bf16 \
--post-norm \
--target-tensor-parallel-size 8 \
--target-pipeline-parallel-size 1 \
--load-dir ./model_from_hf/gemma2_hf/ \
--save-dir ./model_weights/gemma2_mcore/ \
--tokenizer-model ./model_from_hf/gemma2_hf/tokenizer.json

View File

@ -0,0 +1,11 @@
# Modify the set_env.sh path according to your actual environment
source /usr/local/Ascend/ascend-toolkit/set_env.sh
mkdir ./dataset
python ./preprocess_data.py \
--input ./dataset/train-00000-of-00001-a09b74b3ef9c3b56.parquet \
--tokenizer-name-or-path ./model_from_hf/gemma2_hf/ \
--tokenizer-type PretrainedFromHF \
--output-prefix ./dataset/enwiki \
--workers 4 \
--log-interval 1000

View File

@ -0,0 +1,69 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
# distributed config
NPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
# modify the model path and tokenizer path for this script
TOKENIZER_PATH="your tokenizer directory path"
CHECKPOINT="your model directory path"
# configure task and data path
DATA_PATH="./mmlu/test/"
TASK="mmlu"
DISTRIBUTED_ARGS="
--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"
# configure evaluation parameters
torchrun $DISTRIBUTED_ARGS evaluation.py \
--task-data-path ${DATA_PATH} \
--task ${TASK} \
--load ${CHECKPOINT} \
--use-mcore-models \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 \
--gelu-tanh \
--post-norm \
--query-pre-attn-scalar 144 \
--output-logit-softcapping 30.0 \
--attn-logit-softcapping 50.0 \
--interleave-sliding-window 4096 \
--group-query-attention \
--num-query-groups 16 \
--num-layers 46 \
--hidden-size 4608 \
--ffn-hidden-size 36864 \
--num-attention-heads 32 \
--kv-channels 128 \
--max-position-embeddings 8192 \
--seq-length 8192 \
--max-new-tokens 1 \
--position-embedding-type rope \
--disable-bias-linear \
--normalization RMSNorm \
--add-rmsnorm-offset \
--input-embeds-norm \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_PATH} \
--tokenizer-not-use-fast \
--norm-epsilon 1e-06 \
--evaluation-batch-size 1 \
--micro-batch-size 1 \
--use-fused-rmsnorm \
--no-masked-softmax-fusion \
--exit-on-missing-checkpoint \
--no-load-rng \
--no-load-optim \
--vocab-size 256000 \
--make-vocab-size-divisible-by 1 \
--bf16 \
--seed 42 | tee logs/evaluation_gemma2_27b_mcore_${TASK}.log

View File

@ -0,0 +1,69 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
# distributed config
NPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
# modify the model path and tokenizer path for this script
TOKENIZER_PATH="your tokenizer directory path"
CHECKPOINT="your model directory path"
# configure task and data path
DATA_PATH="./mmlu/test/"
TASK="mmlu"
DISTRIBUTED_ARGS="
--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"
# configure evaluation parameters
torchrun $DISTRIBUTED_ARGS evaluation.py \
--task-data-path ${DATA_PATH} \
--task ${TASK} \
--load ${CHECKPOINT} \
--use-mcore-models \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 \
--gelu-tanh \
--post-norm \
--query-pre-attn-scalar 256 \
--output-logit-softcapping 30.0 \
--attn-logit-softcapping 50.0 \
--interleave-sliding-window 4096 \
--group-query-attention \
--num-query-groups 8 \
--num-layers 42 \
--hidden-size 3584 \
--ffn-hidden-size 14336 \
--num-attention-heads 16 \
--kv-channels 256 \
--max-position-embeddings 8192 \
--seq-length 8192 \
--max-new-tokens 1 \
--position-embedding-type rope \
--disable-bias-linear \
--normalization RMSNorm \
--add-rmsnorm-offset \
--input-embeds-norm \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_PATH} \
--tokenizer-not-use-fast \
--norm-epsilon 1e-06 \
--evaluation-batch-size 1 \
--micro-batch-size 1 \
--use-fused-rmsnorm \
--no-masked-softmax-fusion \
--exit-on-missing-checkpoint \
--no-load-rng \
--no-load-optim \
--vocab-size 256000 \
--make-vocab-size-divisible-by 1 \
--bf16 \
--seed 42 | tee logs/evaluation_gemma2_9b_mcore_${TASK}.log

View File

@ -0,0 +1,66 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
# Change for multinode config
NPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
# please fill these path configurations
TOKENIZER_PATH="your tokenizer directory path"
CHECKPOINT="your model directory path"
DISTRIBUTED_ARGS="
--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"
torchrun $DISTRIBUTED_ARGS inference.py \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 \
--use-mcore-models \
--gelu-tanh \
--post-norm \
--query-pre-attn-scalar 144 \
--output-logit-softcapping 30.0 \
--attn-logit-softcapping 50.0 \
--interleave-sliding-window 4096 \
--group-query-attention \
--num-query-groups 16 \
--load ${CHECKPOINT} \
--num-layers 46 \
--hidden-size 4608 \
--kv-channels 128 \
--ffn-hidden-size 36864 \
--num-attention-heads 32 \
--position-embedding-type rope \
--seq-length 8192 \
--max-position-embeddings 8192 \
--max-new-tokens 256 \
--micro-batch-size 1 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_PATH} \
--tokenizer-not-use-fast \
--normalization RMSNorm \
--add-rmsnorm-offset \
--norm-epsilon 1e-06 \
--input-embeds-norm \
--disable-bias-linear \
--hidden-dropout 0 \
--attention-dropout 0 \
--attention-softmax-in-fp32 \
--no-load-optim \
--no-load-rng \
--no-masked-softmax-fusion \
--no-gradient-accumulation-fusion \
--exit-on-missing-checkpoint \
--make-vocab-size-divisible-by 1 \
--vocab-size 256000 \
--bf16 \
--seed 42 \
| tee logs/generate_gemma2_27b_mcore.log

View File

@ -0,0 +1,66 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
# Change for multinode config
NPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
# please fill these path configurations
TOKENIZER_PATH="your tokenizer directory path"
CHECKPOINT="your model directory path"
DISTRIBUTED_ARGS="
--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"
torchrun $DISTRIBUTED_ARGS inference.py \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 \
--use-mcore-models \
--gelu-tanh \
--post-norm \
--query-pre-attn-scalar 256 \
--output-logit-softcapping 30.0 \
--attn-logit-softcapping 50.0 \
--interleave-sliding-window 4096 \
--group-query-attention \
--num-query-groups 8 \
--load ${CHECKPOINT} \
--num-layers 42 \
--hidden-size 3584 \
--kv-channels 256 \
--ffn-hidden-size 14336 \
--num-attention-heads 16 \
--position-embedding-type rope \
--seq-length 8192 \
--max-position-embeddings 8192 \
--max-new-tokens 256 \
--micro-batch-size 1 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_PATH} \
--tokenizer-not-use-fast \
--normalization RMSNorm \
--add-rmsnorm-offset \
--norm-epsilon 1e-06 \
--input-embeds-norm \
--disable-bias-linear \
--hidden-dropout 0 \
--attention-dropout 0 \
--attention-softmax-in-fp32 \
--no-load-optim \
--no-load-rng \
--no-masked-softmax-fusion \
--no-gradient-accumulation-fusion \
--exit-on-missing-checkpoint \
--make-vocab-size-divisible-by 1 \
--vocab-size 256000 \
--bf16 \
--seed 42 \
| tee logs/generate_gemma2_9b_mcore.log

View File

@ -0,0 +1,105 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
NPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=2
NODE_RANK=0
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
CKPT_SAVE_DIR="your model save ckpt path"
DATA_PATH="your data path"
TOKENIZER_MODEL="your tokenizer path"
CKPT_LOAD_DIR="your model ckpt path"
TP=8
PP=2
DISTRIBUTED_ARGS="
--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
GPT_ARGS="
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--sequence-parallel \
--use-mcore-models \
--use-mc2 \
--use-fused-rmsnorm \
--use-fused-rotary-pos-emb \
--gelu-tanh \
--post-norm \
--query-pre-attn-scalar 144 \
--output-logit-softcapping 30.0 \
--interleave-sliding-window 4096 \
--num-layers 46 \
--num-layer-list 20,26 \
--hidden-size 4608 \
--ffn-hidden-size 36864 \
--num-attention-heads 32 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_MODEL} \
--seq-length 8192 \
--max-position-embeddings 8192 \
--micro-batch-size 1 \
--global-batch-size 64 \
--kv-channels 128 \
--group-query-attention \
--num-query-groups 16 \
--make-vocab-size-divisible-by 1 \
--lr 1.25e-6 \
--train-iters 2000 \
--lr-decay-style cosine \
--disable-bias-linear \
--attention-dropout 0.0 \
--init-method-std 0.01 \
--hidden-dropout 0.0 \
--position-embedding-type rope \
--normalization RMSNorm \
--add-rmsnorm-offset \
--norm-epsilon 1e-06 \
--input-embeds-norm \
--use-flash-attn \
--use-distributed-optimizer \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--min-lr 1.25e-7 \
--weight-decay 1e-1 \
--lr-warmup-fraction 0.01 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--tokenizer-padding-side left \
--initial-loss-scale 4096 \
--no-gradient-accumulation-fusion \
--no-load-optim \
--no-load-rng \
--vocab-size 256000 \
--bf16
"
DATA_ARGS="
--data-path $DATA_PATH \
--split 100,0,0
"
OUTPUT_ARGS="
--log-interval 1 \
--save-interval 2000 \
--eval-interval 1000 \
--eval-iters 0 \
"
torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
--load ${CKPT_LOAD_DIR} \
--save ${CKPT_SAVE_DIR} \
| tee logs/train_gemma2_27b_mcore.log
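
The 27B pretrain script above expects two 8-NPU nodes (TP=8, PP=2); --num-layer-list 20,26 splits the 46 decoder layers unevenly across the two pipeline stages. A quick sanity check of that layout, a sketch using only values copied from the script (not repository code):

# Sanity-check sketch for the Gemma2-27B parallel layout (values from the script above).
num_layers = 46
num_layer_list = [20, 26]        # layers placed on pipeline stage 0 and stage 1
assert sum(num_layer_list) == num_layers

tp, pp = 8, 2
world_size = 8 * 2               # NPUS_PER_NODE * NNODES
assert world_size == tp * pp     # one model replica spans all 16 NPUs (data-parallel size 1)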

View File

@ -0,0 +1,104 @@
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
NPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
CKPT_SAVE_DIR="your model save ckpt path"
DATA_PATH="your data path"
TOKENIZER_MODEL="your tokenizer path"
CKPT_LOAD_DIR="your model ckpt path"
TP=8
PP=1
DISTRIBUTED_ARGS="
--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
GPT_ARGS="
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--sequence-parallel \
--use-mcore-models \
--use-mc2 \
--use-fused-rmsnorm \
--use-fused-rotary-pos-emb \
--gelu-tanh \
--post-norm \
--query-pre-attn-scalar 256 \
--output-logit-softcapping 30.0 \
--interleave-sliding-window 4096 \
--num-layers 42 \
--hidden-size 3584 \
--ffn-hidden-size 14336 \
--num-attention-heads 16 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_MODEL} \
--seq-length 8192 \
--max-position-embeddings 8192 \
--micro-batch-size 1 \
--global-batch-size 64 \
--kv-channels 256 \
--group-query-attention \
--num-query-groups 8 \
--make-vocab-size-divisible-by 1 \
--lr 1.25e-6 \
--train-iters 2000 \
--lr-decay-style cosine \
--disable-bias-linear \
--attention-dropout 0.0 \
--init-method-std 0.01 \
--hidden-dropout 0.0 \
--position-embedding-type rope \
--normalization RMSNorm \
--add-rmsnorm-offset \
--norm-epsilon 1e-06 \
--input-embeds-norm \
--use-flash-attn \
--use-distributed-optimizer \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--min-lr 1.25e-7 \
--weight-decay 1e-1 \
--lr-warmup-fraction 0.01 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--tokenizer-padding-side left \
--initial-loss-scale 4096 \
--no-gradient-accumulation-fusion \
--no-load-optim \
--no-load-rng \
--vocab-size 256000 \
--bf16
"
DATA_ARGS="
--data-path $DATA_PATH \
--split 100,0,0
"
OUTPUT_ARGS="
--log-interval 1 \
--save-interval 2000 \
--eval-interval 1000 \
--eval-iters 0 \
"
torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
--load ${CKPT_LOAD_DIR} \
--save ${CKPT_SAVE_DIR} \
| tee logs/train_gemma2_9b_mcore.log

View File

@ -381,6 +381,13 @@ def _add_network_args(parser):
help='Geglu activate function.')
group.add_argument('--input-embeds-norm', action='store_true', default=False,
help='input normalization.')
group.add_argument('--gelu-tanh', action='store_true', default=False,
help='Use the tanh-approximated GeGLU activation function.')
group.add_argument('--output-logit-softcapping', type=float, help='output logit softcapping.')
group.add_argument('--attn-logit-softcapping', type=float, help='attention logit softcapping.')
group.add_argument('--query-pre-attn-scalar', type=int, help='attention scalar.')
group.add_argument('--interleave-sliding-window', type=int,
help='Window size when use interleave sliding window attention.')
return parser
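
These new flags carry the Gemma2-specific architecture knobs. The inference and evaluation scripts in this change pass them as summarized below (the pretrain scripts omit --attn-logit-softcapping); the dict itself is only an illustrative sketch, with values copied from the scripts above:

# Illustrative summary of the Gemma2-specific arguments used in this change (not repository code).
GEMMA2_EXTRA_ARGS = {
    "gemma2-9b":  {"query_pre_attn_scalar": 256, "attn_logit_softcapping": 50.0,
                   "output_logit_softcapping": 30.0, "interleave_sliding_window": 4096},
    "gemma2-27b": {"query_pre_attn_scalar": 144, "attn_logit_softcapping": 50.0,
                   "output_logit_softcapping": 30.0, "interleave_sliding_window": 4096},
}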

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
from functools import wraps
@ -106,6 +106,10 @@ def gpt_model_forward(self, input_ids: Tensor,
if args.output_multiplier_scale:
logits = logits * args.output_multiplier_scale
if args.output_logit_softcapping:
logits = logits / args.output_logit_softcapping
logits = torch.tanh(logits)
logits = logits * args.output_logit_softcapping
if labels is None:
# [s b h] => [b s h]
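
The capping added above rescales the final logits into (-30, 30): divide by the cap, apply tanh, multiply back. The same pattern is reused later for attention logits via --attn-logit-softcapping 50.0. A compact sketch of the equivalent one-liner (illustrative only, not repository code):

# Sketch: the three capping lines above collapse to cap * tanh(logits / cap).
import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # near-identity for |logits| << cap, smoothly saturating toward +/-cap
    return cap * torch.tanh(logits / cap)

capped = softcap(torch.randn(2, 5) * 100, cap=30.0)   # values now lie strictly inside (-30, 30)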

View File

@ -4,10 +4,11 @@
import math
from functools import wraps
import torch
from torch import Tensor
import torch_npu
from megatron.training import get_args
from megatron.core import mpu
from megatron.core import mpu, parallel_state, tensor_parallel
from mindspeed.core.context_parallel.ring_context_parallel import ringattn_context_parallel
from mindspeed.core.parallel_state import (get_context_parallel_group_for_hybrid_ring,
get_context_parallel_for_hybrid_ring_world_size,
@ -60,6 +61,12 @@ def dot_product_attention_init_wrapper(fn):
config.context_parallel_size = cp_size
args = get_args()
self.attn_logit_softcapping = args.attn_logit_softcapping
if args.query_pre_attn_scalar:
self.norm_factor = args.query_pre_attn_scalar ** 0.5
self.scale_mask_softmax.scale = 1.0
self.softmax_scale = 1.0 / self.norm_factor
if args.multi_head_latent_attention:
self.scale_mask_softmax.scale = True
self.hidden_size_per_partition = args.num_attention_heads * args.v_head_dim
@ -85,7 +92,119 @@ def dot_product_attention_forward_wrapper(fn):
if get_args().use_flash_attn:
return dot_product_attention_forward(self, query, key, value, attention_mask, attn_mask_type,
packed_seq_params)
return fn(self, query, key, value, attention_mask, attn_mask_type, packed_seq_params)
assert packed_seq_params is None, (
"Packed sequence is not supported by DotProductAttention."
"Please use TEDotProductAttention instead."
)
# ===================================
# Raw attention scores. [b, n/p, s, s]
# ===================================
# expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn]
# This is a noop for normal attention where ng == np. When using group query attention this
# creates a view that has the keys and values virtually repeated along their dimension to
# match the number of queries.
# attn_mask_type is not used.
if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
key = key.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
)
value = value.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
)
# [b, np, sq, sk]
output_size = (
query.size(1),
query.size(2),
query.size(0),
key.size(0),
)
# [sq, b, np, hn] -> [sq, b * np, hn]
# This will be a simple view when doing normal attention, but in group query attention
# the key and value tensors are repeated to match the queries so you can't use simple strides
# to extract the queries.
query = query.reshape(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key = key.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
(output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu",
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query.transpose(0, 1), # [b * np, sq, hn]
key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
if self.attn_logit_softcapping is not None:
matmul_result = matmul_result / self.attn_logit_softcapping
matmul_result = torch.tanh(matmul_result)
matmul_result = matmul_result * self.attn_logit_softcapping
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.config.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (
value.size(1),
value.size(2),
query.size(0),
value.size(3),
)
# change view [sk, b * np, hn]
value = value.view(value.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context = torch.bmm(attention_probs, value.transpose(0, 1))
# change view [b, np, sq, hn]
context = context.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context = context.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,)
context = context.view(*new_context_shape)
return context
return wrapper
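
Besides soft-capping the attention scores with --attn-logit-softcapping, the init wrapper above overrides Megatron's usual 1/sqrt(head_dim) scaling with 1/sqrt(query_pre_attn_scalar) (144 in the 27B scripts, 256 in the 9B scripts of this PR). A toy sketch of the resulting score computation; the shapes and values are assumptions, not repository code:

# Toy sketch of the Gemma2 attention scaling.
import math
import torch

def attn_scores(q, k, query_pre_attn_scalar):
    # the patched forward passes alpha = 1 / norm_factor to baddbmm,
    # with norm_factor = sqrt(query_pre_attn_scalar) rather than sqrt(head_dim)
    return torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(query_pre_attn_scalar)

q = torch.randn(1, 4, 16, 128)   # [batch, heads, seq, head_dim], toy sizes
k = torch.randn(1, 4, 16, 128)
scores = attn_scores(q, k, query_pre_attn_scalar=144)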

View File

@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import wraps
import torch
import torch.nn.functional as F
@ -85,6 +85,14 @@ def core_mlp_init(self, config, submodules, is_expert=False, input_size=None):
self.config.activation_func = F.gelu
self.config.bias_gelu_fusion = False
if _args.gelu_tanh:
def gelu_tanh_approximation(x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
self.config.gated_linear_unit = True
self.config.activation_func = gelu_tanh_approximation
self.config.bias_gelu_fusion = False
ffn_hidden_size = self.config.ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
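
gelu_tanh_approximation above is the standard tanh approximation of GELU, and gated_linear_unit=True is why the context line doubles ffn_hidden_size for linear_fc1. As a hedged check (not repository code), the hand-written formula should agree with PyTorch's built-in tanh GELU, available in recent PyTorch releases:

# Sketch: compare the manual tanh-GELU against F.gelu(..., approximate="tanh").
import math
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
manual = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
builtin = F.gelu(x, approximate="tanh")
assert torch.allclose(manual, builtin, atol=1e-5)   # loose tolerance for float32 rounding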

View File

@ -142,7 +142,8 @@ def transformer_block_forward(
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
if self.input_embeds_norm and self.pre_process:
hidden_states = hidden_states * (self.hidden_size ** 0.5)
normalizer = torch.tensor(self.hidden_size ** 0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
hidden_states = make_viewless_tensor(
inp=hidden_states, requires_grad=True, keep_graph=True,
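
The change above builds the sqrt(hidden_size) multiplier as a tensor in the activations' dtype instead of multiplying by a Python float, so the scale factor is rounded to bf16 before the multiply, matching the reference Hugging Face Gemma implementation (that comparison is an observation, not something stated in this diff). A toy illustration, not repository code:

# Toy illustration of the dtype-aware embedding normalizer.
import torch

hidden_size = 3584                                   # Gemma2-9B hidden size from the scripts above
h = torch.randn(4, hidden_size, dtype=torch.bfloat16)
normalizer = torch.tensor(hidden_size ** 0.5, dtype=h.dtype)   # rounded to bf16 first
scaled = h * normalizer
print(normalizer.item(), hidden_size ** 0.5)         # bf16-rounded factor vs. exact sqrt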

View File

@ -17,6 +17,7 @@ from functools import wraps
from dataclasses import dataclass, field
from typing import Dict, Union
import torch
from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.moe.moe_layer import MoELayer
@ -77,6 +78,9 @@ def transformer_layer_init_wrapper(fn):
expert.layer_number = self.layer_number
else:
self.mlp.layer_number = self.layer_number
self.is_sliding = not bool((self.layer_number - 1) % 2)
self.interleave_sliding_window = args_pos_norm.interleave_sliding_window
return wrapper
@ -88,6 +92,15 @@ def transformer_layer_forward(self, hidden_states, attention_mask, context=None,
# hidden_states: [s, b, h]
args_pos_norm = get_args()
if self.interleave_sliding_window is not None and self.is_sliding and attention_mask is not None:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.interleave_sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask).bool()
# when decoding
if attention_mask.shape[-1] <= 1:
attention_mask = attention_mask[:, :, :, -self.interleave_sliding_window:]
# Residual connection.
residual = hidden_states
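
In the forward patch above, layers whose 1-based layer_number is odd (is_sliding) tighten the causal mask: anything more than --interleave-sliding-window tokens in the past is also masked, and during decoding only the last window of the mask is kept. A simplified boolean sketch of that mask arithmetic, with toy sizes (not repository code):

# Sketch of the interleaved sliding-window mask on a toy sequence (True = masked).
import torch

seq_len, window = 8, 3
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)    # future tokens
sliding = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=-window)  # > window back
local_mask = causal | sliding
print(local_mask.int())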

View File

@ -63,7 +63,7 @@ def main():
help='Do not perform checking on the name and ordering of weights',
dest='checking')
parser.add_argument('--model-type-hf', type=str, default="llama2",
choices=['llama2', 'mixtral', 'chatglm3', 'gemma', 'bloom', 'qwen'], help='model-type')
choices=['llama2', 'mixtral', 'chatglm3', 'gemma', 'gemma2', 'bloom', 'qwen'], help='model-type')
known_args, _ = parser.parse_known_args()
loader = load_plugin('loader', known_args.loader)
saver = load_plugin('saver', known_args.saver)

View File

@ -55,6 +55,8 @@ def add_arguments(parser):
'This is added for computational efficiency reasons.')
group.add_argument('--use-mcore-models', action='store_true',
help='Use the implementation from megatron core')
group.add_argument('--post-norm', action='store_true',
help='post norm after attention or mlp.')
def verify_transformers_version():
@ -116,11 +118,17 @@ def get_message_preprocess(model, md):
return message
def get_message_layer_norm(message, model, layer_idx, md):
def get_message_layer_norm(message, model, layer_idx, md, args=None):
# Get non-parallel tensors from tp_rank 0.
message["input norm weight"] = model.get_layers_input_layernorm_weight(layer_idx=layer_idx)
message["post norm weight"] = model.get_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx)
if args.post_norm:
message["post norm weight"] = model.get_layers_self_attention_post_attention_layernorm_weight(
layer_idx=layer_idx)
message["pre mlp norm weight"] = model.get_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx)
message["post mlp norm weight"] = model.get_layers_self_attention_post_mlp_layernorm_weight(layer_idx=layer_idx)
if md.norm_has_bias:
message["input norm bias"] = model.get_layers_input_layernorm_bias(layer_idx=layer_idx)
message["post norm bias"] = model.get_layers_self_attention_pre_mlp_layernorm_bias(layer_idx=layer_idx)
@ -272,7 +280,7 @@ def _load_checkpoint(queue, args):
for layer_idx in range(margs.num_layers):
# Grab all parallel tensors for this layer.
message = {}
message = get_message_layer_norm(message, model_mg, layer_idx, md)
message = get_message_layer_norm(message, model_mg, layer_idx, md, args)
message = get_message_layer_attn(message, model_mg, layer_idx, md, args)
message = get_message_layer_mlp(message, model_mg, layer_idx, md)

View File

@ -94,11 +94,26 @@
"gemma": {
"__base__": "base",
"config_set_value": {
"seq_length": 4096,
"seq_length": 8192,
"tie_word_embeddings": true,
"kv_channels": 256
}
},
"gemma2": {
"__base__": "base",
"config_set_value": {
"seq_length": 8192,
"tie_word_embeddings": true
},
"config_hf_key_mapping": {
"kv_channels": "head_dim"
},
"model_hf_key_mapping": {
"layers_self_attention_post_attention_layernorm": "model.layers[layer_idx].post_attention_layernorm",
"layers_self_attention_pre_mlp_layernorm": "model.layers[layer_idx].pre_feedforward_layernorm",
"layers_self_attention_post_mlp_layernorm": "model.layers[layer_idx].post_feedforward_layernorm"
}
},
"bloom": {
"__base__": "base",
"config_set_value": {

View File

@ -160,10 +160,16 @@ class ModelBase(abc.ABC):
self.set_attn_state(layer_idx, src_model)
self.set_mlp_state(layer_idx, src_model)
input_layernorm_weight = src_model.get_layers_input_layernorm_weight(layer_idx=layer_idx)
pre_mlp_layernorm_weight = src_model.get_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx)
self.set_layers_input_layernorm_weight(layer_idx=layer_idx, data=input_layernorm_weight)
self.set_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx, data=pre_mlp_layernorm_weight)
if self.args.post_norm:
post_attn_layernorm_weight = src_model.get_layers_self_attention_post_attention_layernorm_weight(
layer_idx=layer_idx)
self.set_layers_self_attention_post_attention_layernorm_weight(layer_idx=layer_idx,
data=post_attn_layernorm_weight)
else:
pre_mlp_layernorm_weight = src_model.get_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx)
self.set_layers_self_attention_pre_mlp_layernorm_weight(layer_idx=layer_idx, data=pre_mlp_layernorm_weight)
if self.has_layers_input_layernorm_bias(layer_idx=layer_idx):
input_layernorm_bias = src_model.get_layers_input_layernorm_bias(layer_idx=layer_idx)
self.set_layers_input_layernorm_bias(layer_idx=layer_idx, data=input_layernorm_bias)
@ -198,6 +204,12 @@ class ModelBase(abc.ABC):
fc2_bias = src_model.get_layers_mlp_linear_fc2_bias(**kwargs)
self.set_layers_mlp_linear_fc2_bias(data=fc2_bias, **kwargs)
if self.args.post_norm:
pre_mlp_layernorm_weight = src_model.get_layers_self_attention_pre_mlp_layernorm_weight(**kwargs)
post_mlp_layernorm_weight = src_model.get_layers_self_attention_post_mlp_layernorm_weight(**kwargs)
self.set_layers_self_attention_pre_mlp_layernorm_weight(data=pre_mlp_layernorm_weight, **kwargs)
self.set_layers_self_attention_post_mlp_layernorm_weight(data=post_mlp_layernorm_weight, **kwargs)
def set_mlp_state(self, layer_idx, src_model):
args = src_model.get_args()
kwargs = {'layer_idx': layer_idx}
@ -301,6 +313,7 @@ class HuggingfaceModel(ModelBase):
self.args = SimpleNamespace(**self.args)
self.args.add_qkv_bias = self.args_cmd.add_qkv_bias
self.args.add_dense_bias = self.args_cmd.add_dense_bias
self.args.post_norm = self.args_cmd.post_norm
def get_modules_from_pretrained(self, device_map="cpu", trust_remote_code=True):
# Load Huggingface model.
@ -575,6 +588,7 @@ class MegatronModel(ModelBase):
self.args.w_pack = self.args_cmd.w_pack
self.args.add_qkv_bias = self.args_cmd.add_qkv_bias
self.args.add_dense_bias = self.args_cmd.add_dense_bias
self.args.post_norm = self.args_cmd.post_norm
self.args.tokenizer_model = getattr(self.args_cmd, 'tokenizer_model', None)
self.args.make_vocab_size_divisible_by = getattr(self.args_cmd, 'make_vocab_size_divisible_by', None)
if self.args_cmd.params_dtype == 'bf16':
@ -870,7 +884,7 @@ class MegatronMCoreModel(MegatronModel):
"layers_self_attention_linear_qkv": module_layer + "self_attention.linear_qkv",
"layers_self_attention_q_layernorm": module_layer + "self_attention.q_layernorm",
"layers_self_attention_k_layernorm": module_layer + "self_attention.k_layernorm",
"layers_self_attention_post_attention_layernorm": module_layer + "pre_mlp_layernorm",
"layers_self_attention_post_attention_layernorm": module_layer + "post_attn_norm",
"layers_self_attention_pre_mlp_layernorm": module_layer + "pre_mlp_layernorm",
"layers_mlp_linear_fc1": module_layer + "mlp.linear_fc1",
"layers_mlp_linear_fc2": module_layer + "mlp.linear_fc2",

View File

@ -164,6 +164,10 @@ def set_model_layer_norm(model_mg, msg, md, **kwargs):
margs = model_mg.get_args()
post_norm = margs.post_norm
if post_norm:
pre_mlp_norm_weight = msg.pop("pre mlp norm weight")
post_mlp_norm_weight = msg.pop("post mlp norm weight")
# Save them to the model
for ep_rank in range(margs.expert_model_parallel_size):
kwargs["ep_rank"] = ep_rank
@ -174,6 +178,10 @@ def set_model_layer_norm(model_mg, msg, md, **kwargs):
if input_norm_bias is not None:
model_mg.set_layers_input_layernorm_bias(**kwargs, data=input_norm_bias)
model_mg.set_layers_self_attention_pre_mlp_layernorm_weight(**kwargs, data=post_norm_weight)
if post_norm:
model_mg.set_layers_self_attention_pre_mlp_layernorm_weight(**kwargs, data=pre_mlp_norm_weight)
model_mg.set_layers_self_attention_post_attention_layernorm_weight(**kwargs, data=post_norm_weight)
model_mg.set_layers_self_attention_post_mlp_layernorm_weight(**kwargs, data=post_mlp_norm_weight)
if post_norm_bias is not None:
model_mg.set_layers_self_attention_pre_mlp_layernorm_bias(**kwargs, data=post_norm_bias)