llama 13b: add custom selective recomputation (user-chosen number of recomputed layers per pipeline stage), compatible with all previous logic.

l00619700 2023-10-20 11:05:38 +08:00
parent 5e27f2e9b6
commit a511221b5c
5 changed files with 32 additions and 91 deletions

View File

@@ -9,10 +9,10 @@ AscendSpeed provides end-to-end test scripts for large language models on [Ascen
Currently, the pre-training of the following models has been supported:
* <a href="https://github.com/baichuan-inc" style="color:green">Baichuan</a>-[[README: 7B/13B]](examples/baichuan/README.md)
* <a href="https://arxiv.org/pdf/2108.12409.pdf" style="color:green">Bloom</a>-[[README: 7B1/176B]](examples/bloom/README.md)
* <a href="https://internlm.intern-ai.org.cn/" style="color:green">InternLM</a>-7B
* <a href="https://huggingface.co/docs/transformers/main/model_doc/llama" style="color:green">LLaMA</a>-7B/13B/[[README: 65B]](examples/llama/README.md)
* <a href="https://huggingface.co/docs/transformers/main/model_doc/llama2" style="color:green">LLaMA2</a>-7B
* <a href="https://arxiv.org/pdf/2108.12409.pdf" style="color:green">Bloom</a>-[[README: 7B/176B]](examples/bloom/README.md)
* <a href="https://internlm.intern-ai.org.cn/" style="color:green">InternLM</a>-[[README: 7B]](examples/intern/README.md)
* <a href="https://huggingface.co/docs/transformers/main/model_doc/llama" style="color:green">LLaMA</a>-[[README: 7B/13B/65B]](examples/llama/README.md)
* <a href="https://huggingface.co/docs/transformers/main/model_doc/llama2" style="color:green">LLaMA2</a>-[[README: 7B]](examples/llama2/README.md)
LLaMA-33B, LLaMA2-13B/70B, Aquila-7B are coming soon ...
@@ -103,7 +103,7 @@ The plan for more tasks, like RLHF and RM, is under way ...
<td> <a href="examples/bloom/pretrain_bloom_176b.sh">Train</a> </td>
</tr>
<tr>
<td>InternLM</td>
<td><a href="examples/intern/README.md">InternLM</td>
<td>7B</td>
<td>1x8</td>
<td>FP16</td>
@@ -113,7 +113,7 @@ The plan for more tasks, like RLHF and RM, is under way ...
<td> <a href="examples/intern/pretrain_internlm_7b_zero.sh">Train</a> </td>
</tr>
<tr>
<td rowspan="4">LLaMA</td>
<td rowspan="4"><a href="examples/llama/README.md">LLaMA</td>
<td>7B</td>
<td>1x8</td>
<td>FP16</td>
@@ -143,7 +143,7 @@ The plan for more tasks, like RLHF and RM, is under way ...
<td> <a href="examples/llama/pretrain_llama_65B_ptd_32p.sh">Train</a> </td>
</tr>
<tr>
<td>LLaMA2</td>
<td><a href="examples/llama2/README.md">LLaMA2</td>
<td>7B</td>
<td>1x8</td>
<td>FP16 </td>

View File

@@ -557,6 +557,11 @@ def _add_training_args(parser):
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--recomputation-layer-num', nargs='+',
type=int, help='Number of layers to recompute in each pipeline-parallel stage. '
'Defaults to None. For example, with pp=4 and 8 layers per stage, '
'passing 4 4 4 4 recomputes only 4 layers in each stage.')
group.add_argument('--distribute-checkpointed-activations',
action='store_true',
help='If set, distribute checkpointed activations '
@@ -1110,7 +1115,7 @@ def _add_activation_checkpoint_args(parser):
help='does a synchronize at the beginning and end of each checkpointed layer.')
group.add_argument('--profile-backward', action='store_true',
help='Enables backward pass profiling for checkpointed layers.')
group.add_argument('--checkpoint_policy', type=str, default='full', choices=['full', 'block'],
group.add_argument('--checkpoint_policy', type=str, default='full', choices=['full', 'block', 'custom'],
help="activation checkpoint policy")
group.add_argument('--checkpoint_block_layer', type=int, default=25,
help="activation checkpoint block layer number")

View File

@@ -779,6 +779,19 @@ class LlamaParallelTransformer(MegatronModule):
check_divisible(args.num_layers, parallel_state.get_pipeline_model_parallel_world_size(), error_info)
self.num_layers = args.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
if self.checkpoint_policy == "block":
self.recomputation_layer_num = [self.checkpoint_block_layer] * \
parallel_state.get_pipeline_model_parallel_world_size()
elif self.checkpoint_policy == "custom":
if args.recomputation_layer_num is not None and len(args.recomputation_layer_num) == \
parallel_state.get_pipeline_model_parallel_world_size():
self.recomputation_layer_num = args.recomputation_layer_num
else:
raise ValueError("`--recomputation-layer-num` must supply exactly one value per pipeline-parallel stage.")
else:
self.recomputation_layer_num = [self.num_layers] * \
parallel_state.get_pipeline_model_parallel_world_size()
# Transformer layers.
def build_layer(layer_number):
return LlamaParallelTransformerLayer(
@@ -834,7 +847,6 @@ class LlamaParallelTransformer(MegatronModule):
def _checkpointed_forward(self, hidden_states, attention_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
@@ -848,32 +860,11 @@ class LlamaParallelTransformer(MegatronModule):
# Make sure memory is freed.
tensor_parallel.reset_checkpointed_activations_memory_buffer()
l = 0
while l < self.num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.checkpoint_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask)
l += self.checkpoint_num_layers
return hidden_states
def _checkpointed_forward_block(self, hidden_states, attention_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask=attention_mask)
return x_
return custom_forward
# Make sure memory is freed.
for idx in range(self.num_layers):
if idx < self.checkpoint_block_layer:
if idx < self.recomputation_layer_num[
parallel_state.get_pipeline_model_parallel_rank()]:
hidden_states = tensor_parallel.checkpoint(
custom(idx, idx + 1),
self.distribute_saved_activations,
@@ -921,10 +912,8 @@ class LlamaParallelTransformer(MegatronModule):
# See set_input_tensor()
hidden_states = self.input_tensor
if self.checkpoint_activations and self.checkpoint_policy == 'full':
if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states, attention_mask)
elif self.checkpoint_activations and self.checkpoint_policy == 'block':
hidden_states = self._checkpointed_forward_block(hidden_states, attention_mask)
else:
if get_key_value:
presents = []
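The net effect in this file: _checkpointed_forward_block is removed, and a single _checkpointed_forward consults a per-stage recomputation_layer_num list built once from the chosen policy. A condensed sketch of that flow, with the Megatron-style dependencies (parallel_state, tensor_parallel.checkpoint) stubbed out; the helper name below is illustrative and not taken from the repository:

```python
# Condensed sketch of the recomputation selection added above; names other than
# the policy strings are illustrative.

def resolve_recomputation_layers(policy, pp_world_size, layers_per_stage,
                                 block_layer=25, custom_layers=None):
    """Return how many layers each pipeline stage should recompute."""
    if policy == "block":
        # Same cap for every stage, taken from --checkpoint_block_layer.
        return [block_layer] * pp_world_size
    if policy == "custom":
        # One user-chosen value per stage, from --recomputation-layer-num.
        if custom_layers is None or len(custom_layers) != pp_world_size:
            raise ValueError("one value per pipeline-parallel stage is required")
        return list(custom_layers)
    # "full" (the previous default): recompute every layer in every stage.
    return [layers_per_stage] * pp_world_size


# The example from the argument help text: pp=4, 8 layers per stage, plan 4 4 4 4.
plan = resolve_recomputation_layers("custom", pp_world_size=4,
                                    layers_per_stage=8, custom_layers=[4, 4, 4, 4])

pp_rank = 0  # this process's pipeline-parallel rank
for idx in range(8):  # layers owned by this stage
    if idx < plan[pp_rank]:
        print(f"layer {idx}: checkpointed, recomputed in backward")
    else:
        print(f"layer {idx}: activations kept in memory")
```

With the 'full' policy every layer is still checkpointed, which is how the change stays compatible with the previous behaviour, although checkpointing now proceeds one layer at a time rather than in checkpoint_num_layers chunks.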

View File

@@ -50,4 +50,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--eval-iters 10 \
--initial-loss-scale 4096.0 \
--checkpoint-activations \
--checkpoint-policy custom \
--recomputation-layer-num 3 2 1 0 0 0 0 0 \
--triangle-attn \
--fp16 | tee logs/train_13B.log
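The eight values passed to --recomputation-layer-num imply eight pipeline-parallel stages (the script's --pipeline-model-parallel-size setting is outside this hunk, so that is an assumption); a quick check of the resulting per-stage plan:

```python
# Sanity check for the plan used in the 13B script above; assumes eight
# pipeline-parallel stages, which is not visible in this hunk.
recomputation_layer_num = [3, 2, 1, 0, 0, 0, 0, 0]
pipeline_stages = 8

assert len(recomputation_layer_num) == pipeline_stages, \
    "one value is required per pipeline-parallel stage"

for stage, n in enumerate(recomputation_layer_num):
    print(f"pipeline stage {stage}: recompute {n} layer(s), keep the rest resident")
```

Front-loading recomputation onto the first stages is consistent with pipeline schedules in which the earliest stages hold activations for the most in-flight microbatches.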

View File

@@ -1,56 +0,0 @@
# This is an example: train llama using PTD.
# The number of parameters is not aligned
export LD_LIBRARY_PATH=/usr/local/lib:/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
export HCCL_CONNECT_TIMEOUT=1200
export COMBINED_ENABLE=1
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=8
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
DATA_PATH=./dataset/llama_text_document
CHECKPOINT_PATH=./ckpt
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
# Main script
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_llama.py \
--DDP-impl local \
--use-distributed-optimizer \
--tensor-model-parallel-size 4 \
--pipeline-model-parallel-size 2 \
--num-layers 30 \
--hidden-size 6656 \
--ffn-hidden-size 17920 \
--num-attention-heads 52 \
--micro-batch-size 1 \
--global-batch-size 32 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--train-iters 500000 \
--lr-decay-iters 320000 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--tokenizer-name-or-path $TOKENIZER_PATH \
--tokenizer-not-use-fast \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.00015 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction .01 \
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--initial-loss-scale 4096.0 \
--fp16 | tee logs/train.log