!210 Add GQA and GQA flash-attention (FA) support to the LLaMA2 model

Merge pull request !210 from Liuchang/master
This commit is contained in:
i-robot 2023-11-02 02:13:17 +00:00 committed by Gitee
commit a7ef12d257
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 186 additions and 9 deletions

View File

@ -313,7 +313,14 @@ def parse_args(extra_args_provider=None, defaults={},
        args.swiglu = False
        args.fp8_e4m3 = False
        args.fp8_hybrid = False
        args.group_query_attention = False
    if args.group_query_attention and args.position_embedding_type != PositionEmbeddingType.rotary:
        raise NotImplementedError(
            'Currently group query attention only '
            'supports rotary position embedding.')
    error_message = 'Triangle attn and flash attention should not be used at the same time.'
    ensure_valid(not (args.triangle_attn and args.use_flash_attn), error_message)

    # AML
    if args.aml_data_download_path is not None:
        data_paths = []
@ -438,6 +445,9 @@ def _add_network_size_args(parser):
    group.add_argument('--mlp-layer-fusion', action='store_true',
                       help='Fuse gate and up-projection in MLP for llama families, '
                            'e.g. llama or internlm')
    group.add_argument('--use-flash-attn', action='store_true',
                       default=False,
                       help='Use flash attention')
    return parser
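For reference, here is a minimal standalone sketch of the new flags together with the two validation rules the hunk above enforces (group query attention requires rotary position embeddings; triangle attention and flash attention are mutually exclusive). The parser below is illustrative only and is not the project's actual `parse_args`:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-flash-attn', action='store_true', default=False)
parser.add_argument('--triangle-attn', action='store_true', default=False)
parser.add_argument('--group-query-attention', action='store_true', default=False)
parser.add_argument('--num-query-groups', type=int, default=1)
parser.add_argument('--position-embedding-type', default='rotary')
args = parser.parse_args(['--use-flash-attn',
                          '--group-query-attention',
                          '--num-query-groups', '8'])

# Same constraints as the patch, expressed with plain exceptions.
if args.group_query_attention and args.position_embedding_type != 'rotary':
    raise NotImplementedError('Group query attention currently supports only '
                              'rotary position embeddings.')
if args.triangle_attn and args.use_flash_attn:
    raise ValueError('Triangle attention and flash attention should not be '
                     'used at the same time.')
```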

View File

@ -42,6 +42,10 @@ from ascendspeed.model.fused_softmax import NPUFusedScaleMaskSoftmax
from ascendspeed.model.language_model import Pooler
from ascendspeed.model.triangle_attention import TriangleAttention
from ascendspeed.error_utils import check_equal, check_divisible, ensure_valid
try:
    from einops import rearrange
except ImportError:
    rearrange = None
class RotaryEmbedding(torch.nn.Module):
@ -301,6 +305,46 @@ class LlamaParallelMLP(MegatronModule):
        return output


class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=1., attention_dropout=0.):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v, n, attention_mask):
        # q, k, v are [s, b, h] tensors ("SBH" layout); n is the number of attention heads.
        if self.causal:
            output = torch_npu.npu_fusion_attention(
                q, k, v, n, "SBH",
                pse=None,
                padding_mask=None,
                atten_mask=attention_mask,
                scale=self.softmax_scale,
                pre_tockens=k.shape[0],  # seq_len
                next_tockens=0,
                keep_prob=1 - self.dropout_p,
            )[0]
            return output
        raise NotImplementedError("FlashSelfAttention currently supports only causal attention.")


def repeat_interleave(inputs, repeats, dim):
    # Repeat each slice along `dim` `repeats` times; used to expand [b, ng, sk, hn]
    # key/value layers to [b, np, sk, hn] when the non-flash attention path runs.
    shape = inputs.shape
    new_shape = shape[:dim + 1] + (repeats, ) + shape[dim + 1:]
    out_shape = shape[:dim] + (shape[dim] * repeats, ) + shape[dim + 1:]
    return inputs.unsqueeze(dim + 1).expand(new_shape).reshape(out_shape)
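The helper above reproduces `torch.repeat_interleave` along one dimension with an `expand` plus `reshape`, which is how the key/value layers are later expanded from `ng` groups to `np` heads when flash attention is not used. A quick self-contained CPU check (the helper is re-declared here so the snippet runs on its own):

```python
import torch


def repeat_interleave(inputs, repeats, dim):
    # Same logic as the patch: repeat every slice along `dim` `repeats` times.
    shape = inputs.shape
    new_shape = shape[:dim + 1] + (repeats,) + shape[dim + 1:]
    out_shape = shape[:dim] + (shape[dim] * repeats,) + shape[dim + 1:]
    return inputs.unsqueeze(dim + 1).expand(new_shape).reshape(out_shape)


kv = torch.randn(2, 2, 16, 64)                      # [b, ng, sk, hn]
expanded = repeat_interleave(kv, repeats=4, dim=1)  # [b, np, sk, hn] with np = ng * 4
assert expanded.shape == (2, 8, 16, 64)
assert torch.equal(expanded, torch.repeat_interleave(kv, 4, dim=1))
```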
class LlamaParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
@ -342,13 +386,30 @@ class LlamaParallelAttention(MegatronModule):
        self.hidden_size_per_attention_head = utils.divide(projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = utils.divide(args.num_attention_heads, world_size)

        self.group_query_attention = args.group_query_attention
        self.num_query_groups = args.num_query_groups
        if self.group_query_attention:
            if args.num_query_groups % world_size != 0:
                raise NotImplementedError(
                    'Currently the num_query_groups should be '
                    'a multiple of the tensor parallel size')
            self.num_query_groups_per_partition = utils.divide(
                args.num_query_groups, world_size)
            kv_projection_size = args.kv_channels * args.num_query_groups
        else:
            kv_projection_size = args.kv_channels * args.num_attention_heads
            self.num_query_groups_per_partition = self.num_attention_heads_per_partition
        self.num_repeat = (self.num_attention_heads_per_partition //
                           self.num_query_groups_per_partition)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            # Adapt for InternLM.
            bias = getattr(config, "column_parallel_linear_bias", False)
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                projection_size + 2 * kv_projection_size,
                config=config,
                bias=bias,
                gather_output=False,
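To make the projection sizing concrete, here is the head/group arithmetic for the 70B configuration used in the example script below (64 attention heads, 8 query groups, tensor parallel size 8). `kv_channels` is assumed to equal `hidden_size // num_attention_heads`, which is the usual default rather than something stated in this patch:

```python
num_attention_heads = 64
num_query_groups = 8
tensor_parallel_size = 8
hidden_size = 8192
kv_channels = hidden_size // num_attention_heads                      # 128 (assumed default)

projection_size = kv_channels * num_attention_heads                   # 8192
kv_projection_size = kv_channels * num_query_groups                   # 1024

heads_per_partition = num_attention_heads // tensor_parallel_size     # 8
groups_per_partition = num_query_groups // tensor_parallel_size       # 1
num_repeat = heads_per_partition // groups_per_partition              # 8

# Fused QKV output: one full projection for queries, two grouped ones for K and V.
qkv_output = projection_size + 2 * kv_projection_size                 # 10240, vs. 3 * 8192 for MHA
per_partition_qkv = qkv_output // tensor_parallel_size                # 1280 = (num_repeat + 2) * kv_channels
assert per_partition_qkv == (num_repeat + 2) * kv_channels
```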
@ -377,6 +438,14 @@ class LlamaParallelAttention(MegatronModule):
        if self.use_triangle_attn:
            self.triangle_attn = TriangleAttention(block_size=1024,
                                                   masked_softmax_func=self.scale_mask_softmax)

        self.use_flash_attention = args.use_flash_attn
        if self.use_flash_attention:
            self.core_attention_flash = FlashSelfAttention(
                causal=True,
                softmax_scale=(1.0 / self.norm_factor),
                attention_dropout=0)

        # Adapt for the InternLM model.
        bias = getattr(config, "row_parallel_linear_bias", False)
        skip_bias_add = getattr(config, "row_parallel_linear_skip_bias_add", True)
@ -403,19 +472,38 @@ class LlamaParallelAttention(MegatronModule):
        # =====================
        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            # Attention heads [sq, b, h] --> [sq, b, (np + 2ng) * hn]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_query_groups_per_partition,
                (
                    (self.num_repeat + 2)
                    * self.hidden_size_per_attention_head
                ),
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, 3 * h] --> 3 [sq, b, h]
            # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query_layer,
             key_layer,
             value_layer) = tensor_parallel.utils.split_tensor_along_last_dim(mixed_x_layer, 3)
             value_layer) = torch.split(
                 mixed_x_layer,
                 [
                     (
                         self.num_repeat
                         * self.hidden_size_per_attention_head
                     ),
                     self.hidden_size_per_attention_head,
                     self.hidden_size_per_attention_head
                 ],
                 dim=3)
            # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
            query_layer = query_layer.view(query_layer.size(0),
                                           query_layer.size(1), -1,
                                           self.hidden_size_per_attention_head)

        # ==================================
        # Rotary Position Embedding
        # ==================================
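The grouped reshape and `torch.split` above can be checked on CPU with dummy tensors. The per-partition numbers below match the 70B example (one query group per partition, eight query heads per group, head size 128) and are only used for illustration:

```python
import torch

sq, b = 16, 2                    # sequence length, micro batch
ng, num_repeat, hn = 1, 8, 128   # groups per partition, heads per group, head size

# [sq, b, ng, (np/ng + 2) * hn]: the fused QKV output viewed per group.
mixed_x_layer = torch.randn(sq, b, ng, (num_repeat + 2) * hn)

query, key, value = torch.split(
    mixed_x_layer,
    [num_repeat * hn, hn, hn],
    dim=3)
query = query.view(sq, b, -1, hn)            # [sq, b, np, hn]
assert query.shape == (sq, b, ng * num_repeat, hn)
assert key.shape == value.shape == (sq, b, ng, hn)
```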
@ -439,8 +527,25 @@ class LlamaParallelAttention(MegatronModule):
        if get_key_value:
            present = (key_layer, value_layer)

        # expand the key_layer and value_layer [b, ng, sk, hn] -> [b, np, sk, hn]
        if self.num_repeat > 1 and not self.use_flash_attention:
            key_layer = repeat_interleave(
                key_layer, self.num_repeat, dim=1)
            value_layer = repeat_interleave(
                value_layer, self.num_repeat, dim=1)

        if self.use_flash_attention and layer_past is None:
            query_layer, key_layer, value_layer = [
                rearrange(x, 'b n s d -> s b (n d)').contiguous()
                for x in (query_layer, key_layer, value_layer)]
            context_layer = self.core_attention_flash(
                query_layer, key_layer, value_layer,
                self.num_attention_heads_per_partition,
                attention_mask)
            output, _ = self.dense(context_layer)
            return output

        # use triangle attention
        if self.use_triangle_attn and layer_past is None:
        elif self.use_triangle_attn and layer_past is None:
            context_layer = self.triangle_attn(query_layer, key_layer, value_layer, attention_mask)
            output, _ = self.dense(context_layer)
            return output
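In the flash-attention branch, the einops `rearrange` turns the post-rotary `[b, n, s, d]` layers into the `[s, b, n*d]` ("SBH") layout that `FlashSelfAttention` passes to `npu_fusion_attention`. A CPU-only shape check, deliberately avoiding the NPU kernel itself (einops is the optional dependency imported at the top of the file; the sizes below mirror the 70B example and are illustrative):

```python
import torch
from einops import rearrange

b, n, s, d = 2, 8, 4096, 128          # batch, heads per partition, seq length, head size
q = torch.randn(b, n, s, d)
q_sbh = rearrange(q, 'b n s d -> s b (n d)').contiguous()
assert q_sbh.shape == (s, b, n * d)   # "SBH" layout expected by the fused kernel
```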

View File

@ -0,0 +1,62 @@
# This is an example: training LLaMA2 using PTD (pipeline/tensor/data) parallelism.
# The number of parameters is not aligned with the official model.
export LD_LIBRARY_PATH=/usr/local/lib:/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
export HCCL_CONNECT_TIMEOUT=1200
source /usr/local/Ascend/ascend-toolkit/set_env.sh
export COMBINED_ENABLE=1
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=8
NODE_RANK=0
NPUS_PER_NODE=8
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
DATA_PATH=./dataset/llama_text_document
CHECKPOINT_LOAD_PATH=./load_ckpt
CHECKPOINT_SAVE_PATH=./save_ckpt
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
# Main script
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_llama.py \
--DDP-impl local \
--use-flash-attn \
--sequence-parallel \
--mlp-layer-fusion \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 8 \
--num-layers 80 \
--hidden-size 8192 \
--ffn-hidden-size 28672 \
--num-attention-heads 64 \
--group-query-attention \
--num-query-groups 8 \
--micro-batch-size 2 \
--global-batch-size 1024 \
--seq-length 4096 \
--max-position-embeddings 4096 \
--train-iters 5000 \
--lr-decay-iters 320000 \
--save $CHECKPOINT_SAVE_PATH \
--load $CHECKPOINT_LOAD_PATH \
--data-path $DATA_PATH \
--tokenizer-name-or-path ./dataset/llama/ \
--tokenizer-not-use-fast \
--pad-vocab-size-to 32000 \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.00015 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction .01 \
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--bf16 | tee logs/train.log
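As a quick sanity check on the launch configuration above: 8 nodes with 8 NPUs each give 64 devices, which tensor parallel size 8 and pipeline parallel size 8 consume entirely, leaving a data parallel size of 1; the global batch size must then be a multiple of the micro batch size times the data parallel size. The arithmetic, spelled out:

```python
npus_per_node, nnodes = 8, 8
tp, pp = 8, 8
world_size = npus_per_node * nnodes                                 # 64
dp = world_size // (tp * pp)                                        # 1
assert world_size == tp * pp * dp

micro_batch_size, global_batch_size = 2, 1024
assert global_batch_size % (micro_batch_size * dp) == 0
num_micro_batches = global_batch_size // (micro_batch_size * dp)    # 512 micro batches per step
```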