!210 LLama2 model: add GQA and GQA flash attention (FA) support
Merge pull request !210 from Liuchang/master
Commit a7ef12d257
@@ -313,7 +313,14 @@ def parse_args(extra_args_provider=None, defaults={},
    args.swiglu = False
    args.fp8_e4m3 = False
    args.fp8_hybrid = False
    args.group_query_attention = False

    if args.group_query_attention and args.position_embedding_type != PositionEmbeddingType.rotary:
        raise NotImplementedError(
            'Currently the group query attention only '
            'support rotary position embedding.')

    error_message = 'Triangle attn and flash attention should not be used at the same time.'
    ensure_valid(not (args.triangle_attn and args.use_flash_attn), error_message)
    # AML
    if args.aml_data_download_path is not None:
        data_paths = []
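For context, the flags consumed by the new validation above are plain Megatron-style argparse options. A minimal, hypothetical sketch of how such flags could be registered and sanity-checked (the flag names --group-query-attention and --num-query-groups come from the example script added below; everything else here is illustrative, not the repository's actual code):

# Hypothetical sketch, not the repository's code: registering and checking the GQA flags.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='network size')
group.add_argument('--group-query-attention', action='store_true',
                   help='Use group query attention (GQA) instead of full multi-head attention.')
group.add_argument('--num-query-groups', type=int, default=1,
                   help='Number of key/value head groups shared by the query heads.')

args = parser.parse_args(['--group-query-attention', '--num-query-groups', '8'])
if args.group_query_attention and args.num_query_groups < 1:
    raise ValueError('--num-query-groups must be a positive integer when GQA is enabled.')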

@@ -438,6 +445,9 @@ def _add_network_size_args(parser):
    group.add_argument('--mlp-layer-fusion', action='store_true',
                       help='Fuse gate and upprojection in MLP for llama families, '
                            'e.g. llama or internlm')
    group.add_argument('--use-flash-attn', action='store_true',
                       default=False,
                       help='Use flash attention')
    return parser

@@ -42,6 +42,10 @@ from ascendspeed.model.fused_softmax import NPUFusedScaleMaskSoftmax
from ascendspeed.model.language_model import Pooler
from ascendspeed.model.triangle_attention import TriangleAttention
from ascendspeed.error_utils import check_equal, check_divisible, ensure_valid
try:
    from einops import rearrange
except ImportError:
    rearrange = None


class RotaryEmbedding(torch.nn.Module):

@@ -301,6 +305,46 @@ class LlamaParallelMLP(MegatronModule):
        return output


class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=1., attention_dropout=0.):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v, n, attention_mask):
        if self.causal:
            output = torch_npu.npu_fusion_attention(
                q, k, v, n, "SBH",
                pse=None,
                padding_mask=None,
                atten_mask=attention_mask,
                scale=self.softmax_scale,
                pre_tockens=k.shape[0],  # seq_len
                next_tockens=0,  # 0
                keep_prob=1 - self.dropout_p,
            )[0]
            return output
        raise Exception("the attention type {} is not support!".format(self.attention_type))
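The causal branch above delegates to the fused torch_npu kernel. For readers without an NPU, here is a rough pure-PyTorch sketch of the computation it performs, assuming q, k and v all carry the same head count in the flattened [seq, batch, heads*dim] ('SBH') layout and building its own causal mask instead of taking one as an argument; the fused kernel in the diff additionally accepts grouped key/value heads and an explicit attention mask:

import torch

def causal_attention_sbh_reference(q, k, v, n, softmax_scale, dropout_p=0.0):
    # q, k, v: [s, b, n*d]; n: number of attention heads. Illustrative reference only.
    s, b, hd = q.shape
    d = hd // n
    # [s, b, n*d] -> [b, n, s, d]
    q, k, v = [x.view(s, b, n, d).permute(1, 2, 0, 3) for x in (q, k, v)]

    scores = torch.matmul(q, k.transpose(-1, -2)) * softmax_scale         # [b, n, s, s]
    future = torch.triu(torch.ones(s, s, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(future, float('-inf'))                    # causal masking
    probs = torch.nn.functional.dropout(torch.softmax(scores, dim=-1), p=dropout_p)

    out = torch.matmul(probs, v)                                          # [b, n, s, d]
    return out.permute(2, 0, 1, 3).reshape(s, b, n * d)                   # back to [s, b, n*d]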


def repeat_interleave(inputs, repeats, dim):
    shape = inputs.shape
    new_shape = shape[:dim + 1] + (repeats, ) + shape[dim + 1:]
    out_shape = shape[:dim] + (shape[dim] * repeats, ) + shape[dim + 1:]
    return inputs.unsqueeze(dim + 1).expand(new_shape).reshape(out_shape)
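A quick illustration of the helper above: it repeats each key/value group along the chosen dimension so that grouped K/V heads can be matched against the full set of query heads, and it agrees with torch.repeat_interleave (the sizes below are made up):

import torch

kv = torch.randn(1, 2, 16, 64)                      # [b, ng, sk, hn] with ng = 2 groups
expanded = repeat_interleave(kv, repeats=4, dim=1)  # [b, np, sk, hn] with np = 8 heads

assert expanded.shape == (1, 8, 16, 64)
assert torch.equal(expanded, torch.repeat_interleave(kv, 4, dim=1))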


class LlamaParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

@@ -342,13 +386,30 @@ class LlamaParallelAttention(MegatronModule):
        self.hidden_size_per_attention_head = utils.divide(projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = utils.divide(args.num_attention_heads, world_size)

        self.group_query_attention = args.group_query_attention
        self.num_query_groups = args.num_query_groups

        if self.group_query_attention:
            if args.num_query_groups % world_size != 0:
                raise NotImplementedError(
                    'Currently the num_query_groups should be '
                    'a multiple of the tensor parallel size')
            self.num_query_groups_per_partition = utils.divide(
                args.num_query_groups, world_size)
            kv_projection_size = args.kv_channels * args.num_query_groups
        else:
            kv_projection_size = args.kv_channels * args.num_attention_heads
            self.num_query_groups_per_partition = self.num_attention_heads_per_partition

        self.num_repeat = (self.num_attention_heads_per_partition //
                           self.num_query_groups_per_partition)
        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            # Adapt for internlm
            bias = getattr(config, "column_parallel_linear_bias", False)
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
-               3 * projection_size,
+               projection_size + 2 * kv_projection_size,
                config=config,
                bias=bias,
                gather_output=False,
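To make the new output size of query_key_value concrete, here is the arithmetic for the 70B configuration used in the example script added below (64 attention heads, 8 query groups, hidden size 8192), assuming the usual default kv_channels = hidden_size / num_attention_heads; the numbers are purely illustrative:

# Illustrative GQA sizing with the 70B settings from the example script below.
hidden_size = 8192
num_attention_heads = 64
num_query_groups = 8
kv_channels = hidden_size // num_attention_heads        # 128 per-head dimension

projection_size = kv_channels * num_attention_heads     # 8192 for the queries
kv_projection_size = kv_channels * num_query_groups     # 1024 each for keys and values

qkv_output_size = projection_size + 2 * kv_projection_size
print(qkv_output_size)                                  # 10240, versus 3 * 8192 = 24576 without GQA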

@@ -377,6 +438,14 @@ class LlamaParallelAttention(MegatronModule):
        if self.use_triangle_attn:
            self.triangle_attn = TriangleAttention(block_size=1024,
                                                   masked_softmax_func=self.scale_mask_softmax)

        self.use_flash_attention = args.use_flash_attn
        if self.use_flash_attention:
            self.core_attention_flash = FlashSelfAttention(
                causal=True,
                softmax_scale=(1.0 / self.norm_factor),
                attention_dropout=0)

        # Adapt for the internlm model
        bias = getattr(config, "row_parallel_linear_bias", False)
        skip_bias_add = getattr(config, "row_parallel_linear_skip_bias_add", True)

@@ -403,19 +472,38 @@ class LlamaParallelAttention(MegatronModule):
        # =====================

        if self.attention_type == AttnType.self_attn:
-           # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+           # Attention heads [sq, b, h] --> [sq, b, (np + 2ng) * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

-           new_tensor_shape = mixed_x_layer.size()[:-1] + \
-               (self.num_attention_heads_per_partition,
-                3 * self.hidden_size_per_attention_head)
+           # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
+           new_tensor_shape = mixed_x_layer.size()[:-1] + (
+               self.num_query_groups_per_partition,
+               (
+                   (self.num_repeat + 2)
+                   * self.hidden_size_per_attention_head
+               ),
+           )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

-           # [sq, b, 3 * h] --> 3 [sq, b, h]
+           # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query_layer,
             key_layer,
-            value_layer) = tensor_parallel.utils.split_tensor_along_last_dim(mixed_x_layer, 3)
+            value_layer) = torch.split(
+                mixed_x_layer,
+                [
+                    (
+                        self.num_repeat
+                        * self.hidden_size_per_attention_head
+                    ),
+                    self.hidden_size_per_attention_head,
+                    self.hidden_size_per_attention_head
+                ],
+                dim=3)

            # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
            query_layer = query_layer.view(query_layer.size(0),
                                           query_layer.size(1), -1,
                                           self.hidden_size_per_attention_head)
            # ==================================
            # Rotary Position Embedding
            # ==================================
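The view/split above packs the fused QKV output per key/value group and then separates it. A standalone sketch with made-up sizes (sq = sequence, b = batch, ng = groups per partition, np = query heads per partition, hn = head dim) shows the resulting shapes; reshape is used for the final merge here purely for generality of the illustration:

import torch

sq, b = 16, 2                # sequence length, batch size (illustrative)
np_, ng, hn = 8, 2, 64       # query heads, KV groups, head dim per partition
num_repeat = np_ // ng       # query heads served by each KV group

mixed = torch.randn(sq, b, (np_ + 2 * ng) * hn)          # fused QKV projection output
mixed = mixed.view(sq, b, ng, (num_repeat + 2) * hn)     # [sq, b, ng, (np/ng + 2) * hn]
q, k, v = torch.split(mixed, [num_repeat * hn, hn, hn], dim=3)
q = q.reshape(sq, b, -1, hn)                             # [sq, b, np, hn]

assert q.shape == (sq, b, np_, hn)
assert k.shape == v.shape == (sq, b, ng, hn)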

@@ -439,8 +527,25 @@ class LlamaParallelAttention(MegatronModule):
        if get_key_value:
            present = (key_layer, value_layer)

        # expand the key_layer and value_layer [b, ng, sk, hn] -> [b, np, sk, hn]
        if self.num_repeat > 1 and not self.use_flash_attention:
            key_layer = repeat_interleave(
                key_layer, self.num_repeat, dim=1)
            value_layer = repeat_interleave(
                value_layer, self.num_repeat, dim=1)

        if self.use_flash_attention and layer_past is None:
            query_layer, key_layer, value_layer = [
                rearrange(x, 'b n s d -> s b (n d)').contiguous()
                for x in (query_layer, key_layer, value_layer)]
            context_layer = self.core_attention_flash(
                query_layer, key_layer, value_layer,
                self.num_attention_heads_per_partition,
                attention_mask)
            output, _ = self.dense(context_layer)
            return output
        # use triangle attention
-       if self.use_triangle_attn and layer_past is None:
+       elif self.use_triangle_attn and layer_past is None:
            context_layer = self.triangle_attn(query_layer, key_layer, value_layer, attention_mask)
            output, _ = self.dense(context_layer)
            return output
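As an aside, the rearrange pattern used in the flash path above flattens per-head tensors back into the [seq, batch, heads*dim] ('SBH') layout that FlashSelfAttention expects; a tiny self-contained example with made-up sizes:

import torch
from einops import rearrange

b, n, s, d = 2, 8, 16, 128                 # batch, heads, seq, head dim (illustrative)
x = torch.randn(b, n, s, d)                # [b, n, s, d]
y = rearrange(x, 'b n s d -> s b (n d)')   # [s, b, n*d], i.e. the "SBH" layout
assert y.shape == (s, b, n * d)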

examples/llama2/pretrain_llama2_70B_ptd.sh (new file, 62 lines)
@@ -0,0 +1,62 @@
# This is an example: training llama using PTD.

# The number of parameters is not aligned
export LD_LIBRARY_PATH=/usr/local/lib:/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
export HCCL_CONNECT_TIMEOUT=1200
source /usr/local/Ascend/ascend-toolkit/set_env.sh
export COMBINED_ENABLE=1

# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=8
NODE_RANK=0
NPUS_PER_NODE=8
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))

DATA_PATH=./dataset/llama_text_document
CHECKPOINT_LOAD_PATH=./load_ckpt
CHECKPOINT_SAVE_PATH=./save_ckpt
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

# Main script
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
       pretrain_llama.py \
       --DDP-impl local \
       --use-flash-attn \
       --sequence-parallel \
       --mlp-layer-fusion \
       --tensor-model-parallel-size 8 \
       --pipeline-model-parallel-size 8 \
       --num-layers 80 \
       --hidden-size 8192 \
       --ffn-hidden-size 28672 \
       --num-attention-heads 64 \
       --group-query-attention \
       --num-query-groups 8 \
       --micro-batch-size 2 \
       --global-batch-size 1024 \
       --seq-length 4096 \
       --max-position-embeddings 4096 \
       --train-iters 5000 \
       --lr-decay-iters 320000 \
       --save $CHECKPOINT_SAVE_PATH \
       --load $CHECKPOINT_LOAD_PATH \
       --data-path $DATA_PATH \
       --tokenizer-name-or-path ./dataset/llama/ \
       --tokenizer-not-use-fast \
       --pad-vocab-size-to 32000 \
       --data-impl mmap \
       --split 949,50,1 \
       --distributed-backend nccl \
       --lr 0.00015 \
       --lr-decay-style cosine \
       --min-lr 1.0e-5 \
       --weight-decay 1e-2 \
       --clip-grad 1.0 \
       --lr-warmup-fraction .01 \
       --log-interval 1 \
       --save-interval 10000 \
       --eval-interval 1000 \
       --eval-iters 10 \
       --bf16 | tee logs/train.log
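As a sanity check on the parallel layout in this script (the arithmetic below is illustrative, not part of the commit): 8 nodes with 8 NPUs each give 64 devices; tensor parallel 8 times pipeline parallel 8 leaves a data-parallel size of 1, so a global batch of 1024 with micro batch 2 implies 512 gradient-accumulation steps per iteration:

# Illustrative arithmetic for the launch configuration above.
world_size = 8 * 8                                      # NNODES * NPUS_PER_NODE
tp, pp = 8, 8                                           # tensor / pipeline model parallel
dp = world_size // (tp * pp)                            # data parallel size -> 1

micro_batch, global_batch = 2, 1024
grad_accum_steps = global_batch // (micro_batch * dp)   # -> 512
print(dp, grad_accum_steps)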