!1832 Support the recompute-in-advance feature

Merge pull request !1832 from yanzhixiao/recompute-in-advance
yanzhixiao 2024-11-01 09:33:05 +00:00 committed by i-robot
parent 1f9c98ad15
commit 4a88d95473
8 changed files with 112 additions and 38 deletions

OWNERS

@@ -26,3 +26,4 @@ reviewers:
 - wenjiang2357
 - ningbenzhe1
 - matrixssy
+- wx_4e191bb7af


@@ -223,12 +223,18 @@ MindSpeed-LLM, as a large-model training framework for Ascend, aims to provide Huawei [Ascend chips](
 <td>【Ascend】</td>
 </tr>
 <tr>
-<td rowspan="3">Communication overlapping</td>
+<td rowspan="4">Communication overlapping</td>
 <td><a href="https://gitee.com/ascend/MindSpeed/blob/master/docs/features/async-ddp-param-gather.md">Gradient reduce computation-communication overlapping</a></td>
 <td></td>
 <td></td>
 <td>【Ascend】</td>
 </tr>
+<tr>
+<td><a href="https://gitee.com/ascend/MindSpeed/blob/master/docs/features/recompute_independent_pipelining.md">Recompute in advance</a></td>
+<td></td>
+<td></td>
+<td>【Ascend】</td>
+</tr>
 <tr>
 <td><a href="https://gitee.com/ascend/MindSpeed/blob/master/docs/features/async-ddp-param-gather.md">Weight all-gather computation-communication overlapping</a></td>
 <td></td>


@@ -0,0 +1,33 @@
+# coding=utf-8
+# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+import contextlib
+from functools import wraps
+
+import torch
+
+from megatron.training import get_args
+from mindspeed.core.pipeline_parallel.ripipe_schedules import forward_backward_ripipe_pipelining
+
+
+def get_forward_backward_func_wrapper(get_forward_backward_func):
+    @wraps(get_forward_backward_func)
+    def wrapper(*args, **kwargs):
+        arguments = get_args()
+        if arguments.recompute_in_advance and torch.is_grad_enabled():
+            return forward_backward_ripipe_pipelining
+        return get_forward_backward_func(*args, **kwargs)
+    return wrapper
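The wrapper above only swaps in the ripipe schedule when training with --recompute-in-advance; everything else falls through to Megatron's original getter. A minimal, self-contained sketch of that behaviour (the stubs below stand in for get_args, torch.is_grad_enabled and the real schedule functions, and are not the actual Megatron/MindSpeed objects):

from functools import wraps
from types import SimpleNamespace

# Stubs for illustration only -- not the real Megatron/MindSpeed symbols.
def stub_get_args():
    return SimpleNamespace(recompute_in_advance=True)

def ripipe_schedule(*args, **kwargs):
    return "ripipe (recompute-in-advance) schedule"

def original_get_forward_backward_func():
    return lambda *a, **k: "default schedule"

def get_forward_backward_func_wrapper(get_forward_backward_func):
    @wraps(get_forward_backward_func)
    def wrapper(*args, **kwargs):
        # The real wrapper also requires torch.is_grad_enabled(), so evaluation
        # and inference keep using Megatron's default schedule.
        if stub_get_args().recompute_in_advance:
            return ripipe_schedule
        return get_forward_backward_func(*args, **kwargs)
    return wrapper

patched = get_forward_backward_func_wrapper(original_get_forward_backward_func)
print(patched()())  # -> "ripipe (recompute-in-advance) schedule"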


@@ -314,7 +314,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init_wrapper)
 
         # For recomputation
-        from ..core.transformer.transformer_block import transformer_block_checkpointed_forward_wrapper
+        from mindspeed.core.transformer.transformer_block import transformer_block_checkpointed_forward_wrapper
         MegatronAdaptation.register(
             'megatron.core.transformer.transformer_block.TransformerBlock._checkpointed_forward',
             transformer_block_checkpointed_forward_wrapper)
@@ -416,12 +416,14 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     _batched_p2p_ops)
 
         # dpo relative, we need to change the recv/send shape when using PP, then deal with it by ourselves.
-        from modellink.tasks.rl.utils import get_tensor_shapes_decorator
-        MegatronAdaptation.register(
-            'megatron.core.pipeline_parallel.schedules.get_tensor_shapes',
-            get_tensor_shapes_decorator
-        )
+        from ..tasks.rl.utils import get_tensor_shapes_decorator
+        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_tensor_shapes', get_tensor_shapes_decorator)
+
+        # For recompute-in-advance
+        from ..core.pipeline_parallel.schedules import get_forward_backward_func_wrapper
+        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_forward_backward_func', get_forward_backward_func_wrapper)
 
     def patch_tensor_parallel(self):
         from mindspeed.core.tensor_parallel.layers import vocab_parallel_embedding_forward
         from mindspeed.core.tensor_parallel.random import _set_cuda_rng_state
@@ -444,6 +446,10 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     checkpoint_forward_wrapper)
         MegatronAdaptation.register('megatron.core.tensor_parallel.random.CheckpointFunction.backward',
                                     checkpoint_backward_wrapper)
+
+        # For recompute-in-advance
+        from mindspeed.core.tensor_parallel.random import checkpoint_wrapper
+        MegatronAdaptation.register('megatron.core.tensor_parallel.random.checkpoint', checkpoint_wrapper)
 
     def patch_parallel_state(self):
         import megatron
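For readers unfamiliar with the patching layer: each MegatronAdaptation.register call above pairs a dotted attribute path inside Megatron with a wrapper to apply to it. A rough, hypothetical sketch of the underlying idea (this is not MindSpeed-LLM's actual MegatronAdaptation implementation, and it only handles module-level attributes, not class methods):

import importlib

def apply_wrapper(dotted_path, wrapper):
    """Illustrative helper: wrap the module-level attribute named by a dotted path, in place."""
    module_path, attr_name = dotted_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    original = getattr(module, attr_name)
    setattr(module, attr_name, wrapper(original))

# Conceptually, the recompute-in-advance registration amounts to:
# apply_wrapper('megatron.core.pipeline_parallel.schedules.get_forward_backward_func',
#               get_forward_backward_func_wrapper)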


@@ -109,9 +109,7 @@ def _add_deepseek_moe_args(parser):
                        help='set the coeff for device-level balance loss in deepseek moe')
     group.add_argument('--moe-comm-aux-loss-coeff', type=float, default=0.,
                        help='set the coeff for communication balance loss in deepseek moe')
-    group.add_argument('--moe-without-activation', action='store_true', default=False,
-                       help='save all the memory occupied by activations in moe layer.')
 
     return parser
@@ -430,6 +428,8 @@ def _add_algorithm_args(parser):
     group.add_argument('--recompute-activation-function-num-layers', type=int, default=None,
                        help='Can be used together with "--recompute-method block" '
                             'and "--recompute-num-layers".')
+    group.add_argument('--recompute-in-advance', action='store_true',
+                       help='recompute early to reduce bubble and improve training.')
 
     return parser
@@ -875,7 +875,9 @@ def _store_variables(args):
     # Bypass megatron validation when pp == 2 and vpp is enabled.
     if args.pipeline_model_parallel_size == 2 and args.num_layers_per_virtual_pipeline_stage is not None:
         variable_dict["num_layers_per_virtual_pipeline_stage"] = args.num_layers_per_virtual_pipeline_stage
+        variable_dict["overlap_p2p_comm"] = args.overlap_p2p_comm
         args.num_layers_per_virtual_pipeline_stage = None
+        args.overlap_p2p_comm = None
 
     return variable_dict
@@ -884,8 +886,9 @@ def _restore_variables(args, variable_dict):
     args.variable_seq_lengths = variable_dict["variable_seq_lengths"]
 
     # Bypass megatron validation when pp == 2 and vpp is enabled.
-    if variable_dict.get("num_layers_per_virtual_pipeline_stage"):
+    if variable_dict.get("num_layers_per_virtual_pipeline_stage") and args.pipeline_model_parallel_size == 2:
         args.num_layers_per_virtual_pipeline_stage = variable_dict["num_layers_per_virtual_pipeline_stage"]
+        args.overlap_p2p_comm = variable_dict["overlap_p2p_comm"]
 
     # Moe models require `--sequence-parallel` to be turned on before Megatron core_v0.7.0,
     # which conflicted with the behavior of turning it off by default during inference and evaluation.
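The two hunks above extend an existing stash-and-restore trick: fields that upstream Megatron validation would reject (num_layers_per_virtual_pipeline_stage and, now, overlap_p2p_comm when PP is 2) are remembered and blanked before validate_args runs, then put back afterwards. A generic sketch of that pattern, with all names hypothetical:

def validate_with_fields_hidden(args, validate, field_names):
    """Hide selected attributes from an upstream validator, then restore them."""
    stash = {name: getattr(args, name) for name in field_names}
    for name in field_names:
        setattr(args, name, None)  # upstream validation now ignores these settings
    try:
        validate(args)
    finally:
        for name, value in stash.items():
            setattr(args, name, value)  # restore the user's original values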
@@ -901,14 +904,12 @@ def _add_dummy_args(args):
     For arguments in mindspeed-core which is currently unsupported in mindspeed-llm.
     """
-    # reduce_recompute_for_last_chunk would be registered if recompute-in-advance is supported.
-    args.reduce_recompute_for_last_chunk = False
     args.adaptive_recompute_device_swap = False
     args.adaptive_recompute_device_size = -1
     args.adaptive_recompute_profiling_step = 10
     args.moe_tp_extend_ep = False
     args.recompute_in_bubble = False
-    args.recompute_in_advance = False
     args.use_nanopipe = False
     args.moe_alltoall_overlap_comm = False
     args.moe_allgather_overlap_comm = False
+    args.moe_without_activation = False
@@ -937,6 +938,20 @@ def _validate_vpp(args):
                              f'num_layers_per_virtual_pipeline_stage is {args.num_layers_per_virtual_pipeline_stage}')
 
 
+def _validate_recompute_in_advance(args):
+    args.reduce_recompute_for_last_chunk = False
+    if args.recompute_in_advance:
+        args.reduce_recompute_for_last_chunk = True
+        if args.recompute_method == "uniform":
+            raise AssertionError('recompute_in_advance does not support uniform recompute_method')
+        if args.recompute_granularity == 'selective':
+            raise AssertionError('recompute_in_advance does not support vanilla recompute_activations.')
+        if not args.recompute_num_layers:
+            raise AssertionError('recompute_num_layers must be configured when using recompute_in_advance')
+        if args.pipeline_model_parallel_size <= 1 or args.num_layers_per_virtual_pipeline_stage != 1:
+            raise AssertionError('recompute_in_advance only supports pipelining with interleaving, and the vpp stage should be 1.')
+
+
 def validate_args_decorator(megatron_validate_args):
     @wraps(megatron_validate_args)
     def wrapper(args, defaults=None):
@@ -953,6 +968,7 @@ def validate_args_decorator(megatron_validate_args):
         _validate_cp_args(args)
         _validate_vpp(args)
         _validate_recompute_args(args)
+        _validate_recompute_in_advance(args)
         _validate_create_attention_mask_in_dataloader(args)
         _validate_instruction_finetune(args)
         _validate_position_embedding(args)
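Taken together, these checks mean --recompute-in-advance only works with block-wise full recomputation under an interleaved pipeline whose virtual stage size is 1. A minimal illustration of a flag combination that passes (SimpleNamespace is a hypothetical stand-in for the parsed Megatron argument object; only the fields read by _validate_recompute_in_advance are shown, with values mirroring the llama3_tp2_pp2_vpp1.sh test script added below):

from types import SimpleNamespace

valid_args = SimpleNamespace(
    recompute_in_advance=True,
    recompute_granularity="full",             # 'selective' is rejected
    recompute_method="block",                 # 'uniform' is rejected
    recompute_num_layers=1,                   # must be set
    pipeline_model_parallel_size=2,           # PP must be greater than 1
    num_layers_per_virtual_pipeline_stage=1,  # vpp stage size must be exactly 1
)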


@@ -13,8 +13,8 @@
 <th>Mem.</th>
 </tr>
 <tr>
-<td rowspan="21">ST</td>
-<td rowspan="14">Pretrain</td>
+<td rowspan="22">ST</td>
+<td rowspan="15">Pretrain</td>
 <td>Mcore</td>
 <td>TP, PP, VPP, recompute, enable_recompute_layers_per_pp_rank, FA_TND</td>
 <td><a href="st/shell_scripts/llama2_tp2_pp4_vpp2_ptd.sh">llama2_tp2_pp4_vpp2.sh</a></td>
@@ -38,6 +38,14 @@
 <td>Y</td>
 <td>Y</td>
 </tr>
+<tr>
+<td>Mcore</td>
+<td>recompute_in_advance, pp2vpp</td>
+<td><a href="st/shell_scripts/llama3_tp2_pp2_vpp1.sh">llama3_tp2_pp2_vpp1.sh</a></td>
+<td>Y</td>
+<td>Y</td>
+<td>Y</td>
+</tr>
 <tr>
 <td>Mcore</td>
 <td>cp_hybridgqa</td>


@@ -17,42 +17,42 @@
         7.66312
     ],
     "throughput": [
-        50.0,
-        121.1,
-        121.2,
-        121.2,
-        121.3,
-        121.1,
-        121.3,
-        121.2,
-        121.4,
-        121.4,
-        121.3,
-        121.3,
-        121.3,
-        121.4,
-        121.4
+        6.3,
+        97.5,
+        97.5,
+        97.3,
+        97.5,
+        97.4,
+        97.5,
+        97.5,
+        97.5,
+        97.3,
+        97.4,
+        97.4,
+        97.3,
+        97.4,
+        97.3
     ],
     "memo info": [
         {
             "rank": 0,
             "allocated memory": 20003.2373046875,
-            "max allocated memory": 22943.82177734375
+            "max allocated memory": 21004.72216796875
         },
         {
             "rank": 1,
             "allocated memory": 20003.2373046875,
-            "max allocated memory": 22943.82177734375
+            "max allocated memory": 21004.72216796875
         },
         {
             "rank": 4,
-            "allocated memory": 20067.3115234375,
-            "max allocated memory": 25605.884765625
+            "allocated memory": 20067.31201171875,
+            "max allocated memory": 21069.3154296875
         },
         {
             "rank": 5,
-            "allocated memory": 20067.3115234375,
-            "max allocated memory": 25605.7587890625
+            "allocated memory": 20067.31201171875,
+            "max allocated memory": 21069.3154296875
         }
     ]
 }


@@ -29,6 +29,10 @@ GPT_ARGS="
     --tensor-model-parallel-size ${TP} \
     --pipeline-model-parallel-size ${PP} \
     --num-layers-per-virtual-pipeline-stage 1 \
+    --recompute-granularity full \
+    --recompute-method block \
+    --recompute-num-layers 1 \
+    --recompute-in-advance \
     --use-mcore-models \
     --micro-batch-size 1 \
     --global-batch-size 16 \