Mirror of https://gitee.com/ascend/ModelLink.git (synced 2024-12-04 21:07:51 +08:00)
!1578 [mcore-llm] Performance optimization for DeepSeek-V2-like models: local slicing of the large-vocabulary output matmul when TP is not used, and support for a specified softmax_scale in the MLA scenario
Merge pull request !1578 from 丁子叉/master
This commit is contained in:
parent
143128b9e3
commit
cee50d023e
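The slicing half of the change can be pictured with a few lines of plain PyTorch: with tensor parallelism off, the large-vocabulary output projection is split locally along the vocab dimension and the partial logits are concatenated, which is what the SegmentedColumnParallelLinear added below does. A minimal sketch with made-up sizes (hidden=64, vocab=1024, 8 slices are illustrative, not the model's real dimensions):

import torch
import torch.nn.functional as F

hidden, vocab, slices = 64, 1024, 8      # illustrative sizes; vocab must divide evenly by slices
x = torch.randn(4, 2, hidden)            # [sequence, batch, hidden]
w = torch.randn(vocab, hidden)           # full output-layer weight

full = F.linear(x, w)                    # one large [s, b, vocab] matmul

# Split the weight along the vocab dimension, run one smaller matmul per slice,
# then concatenate the partial logits along the last dimension.
parts = [F.linear(x, w_i) for w_i in torch.split(w, vocab // slices, dim=0)]
sliced = torch.cat(parts, dim=2)

assert torch.allclose(full, sliced, atol=1e-5)

Each slice is a smaller matmul over the same activations, so the concatenated result is identical to the single large multiply.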
@@ -47,7 +47,9 @@ MOE_ARGS="
    --moe-intermediate-size 1536 \
    --moe-router-load-balancing-type group_limited_greedy \
    --topk-group 3 \
    --moe-aux-loss-coeff 0.001 \
    --moe-aux-loss-coeff 0.003 \
    --moe-device-level-aux-loss-coeff 0.05 \
    --moe-comm-aux-loss-coeff 0.02 \
    --routed-scaling-factor 16.0 \
    --seq-aux
"
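For orientation only: the three coefficients above correspond to DeepSeek-V2-style expert-level, device-level, and communication balance losses. A minimal sketch of how such coefficient-weighted terms are typically folded into the training loss (function and variable names here are illustrative assumptions, not ModelLink internals):

def combine_losses(lm_loss, expert_balance_loss, device_balance_loss, comm_balance_loss):
    # Coefficients mirror the flags above; the loss terms themselves are assumed inputs.
    return (lm_loss
            + 0.003 * expert_balance_loss   # --moe-aux-loss-coeff
            + 0.05 * device_balance_loss    # --moe-device-level-aux-loss-coeff
            + 0.02 * comm_balance_loss)     # --moe-comm-aux-loss-coeff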
@@ -71,13 +73,14 @@ GPT_ARGS="
    --pipeline-model-parallel-size ${PP} \
    --expert-model-parallel-size ${EP} \
    --sequence-parallel \
    --output-layer-slice-num 8 \
    --num-layers 2 \
    --hidden-size 5120 \
    --ffn-hidden-size 12288 \
    --num-attention-heads 128 \
    --tokenizer-type PretrainedFromHF \
    --tokenizer-name-or-path ${TOKENIZER_MODEL} \
    --seq-length 4096 \
    --seq-length 8192 \
    --max-position-embeddings 163840 \
    --micro-batch-size 1 \
    --global-batch-size 64 \
@@ -338,6 +338,8 @@ def _add_network_size_args(parser):
                       help="Use mc2 for compute-comm overlap in tp.")
    group.add_argument('--sliding-window', type=int, default=None,
                       help='Window size when use sliding window attention.')
    group.add_argument('--output-layer-slice-num', type=int, default=1,
                       help='Set the number of slices for the weight of the output_layer')
    return parser
@@ -612,6 +614,18 @@ def _validate_group_limited_greedy(args):
            args.expert_model_parallel_size))


def _validate_output_layer_slice_num(args):
    if args.output_layer_slice_num < 1:
        raise AssertionError('Output_layer_slice_num must be greater than 0.')
    elif args.output_layer_slice_num > 1:
        if args.tensor_model_parallel_size > 1:
            raise AssertionError('When output_layer_slice_num is greater than 1, only TP size 1 is supported.')
        if (args.padded_vocab_size is not None) and (args.padded_vocab_size % args.output_layer_slice_num != 0):
            raise AssertionError('padded_vocab_size must be divisible by output_layer_slice_num.')
        elif (args.vocab_size is not None) and (args.vocab_size % args.output_layer_slice_num != 0):
            raise AssertionError('vocab_size must be divisible by output_layer_slice_num.')


def core_transformer_config_from_args_wrapper(fn):
    @wraps(fn)
    def wrapper(args):
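A quick way to see what the new check enforces, using a bare argparse.Namespace in place of the real Megatron args object (the vocabulary size below is made up, and _validate_output_layer_slice_num is assumed to be imported from the patched arguments module):

from argparse import Namespace

# Passes: TP size is 1 and the padded vocabulary splits evenly into 8 slices.
ok = Namespace(output_layer_slice_num=8, tensor_model_parallel_size=1,
               padded_vocab_size=129024, vocab_size=None)
_validate_output_layer_slice_num(ok)

# Raises AssertionError: slicing the output layer is only supported without tensor parallelism.
bad = Namespace(output_layer_slice_num=8, tensor_model_parallel_size=2,
                padded_vocab_size=129024, vocab_size=None)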
@@ -659,6 +673,7 @@ def validate_args_decorator(megatron_validate_args):
        _validate_transformer_block_build_layers(args)
        _validate_group_limited_greedy(args)
        _validate_evaluation_args(args)
        _validate_output_layer_slice_num(args)

        _validate_optimizer(args)
        from modellink.utils import print_args
@@ -14,12 +14,39 @@
# limitations under the License.

from torch import Tensor
from functools import wraps

from megatron.core import InferenceParams
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.training import get_args


from modellink.core.tensor_parallel.layers import SegmentedColumnParallelLinear


def gpt_model_init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        config = args[1] if len(args) > 1 else kwargs['config']
        if get_args().output_layer_slice_num > 1:
            self.output_layer = SegmentedColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=False,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=self.pre_process
                and self.share_embeddings_and_output_weights,
                embedding_activation_buffer=self.embedding_activation_buffer,
                grad_output_buffer=self.grad_output_buffer,
            )

    return wrapper


def gpt_model_forward(self, input_ids: Tensor,
                      position_ids: Tensor, attention_mask: Tensor,
                      decoder_input: Tensor = None,
@@ -13,9 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from functools import wraps
from typing import Optional

from megatron.training import get_args
from megatron.core.tensor_parallel import copy_to_tensor_model_parallel_region, gather_from_tensor_model_parallel_region
from megatron.core.tensor_parallel.layers import linear_with_frozen_weight, linear_with_grad_accumulation_and_async_allreduce, ColumnParallelLinear


def vocab_embedding_wrapper(fn):
@@ -27,3 +32,92 @@ def vocab_embedding_wrapper(fn):
        output = self.norm(output)
        return output * args_.embedding_multiplier_scale if args_.embedding_multiplier_scale else output
    return wrapper


class SegmentedColumnParallelLinear(ColumnParallelLinear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
        """Forward of ColumnParallelLinear

        Args:
            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]

            weight (optional): weight tensor to use, compulsory when
                skip_weight_param_allocation is True.

        Returns:
            - output
            - bias

        """
        args_ = get_args()
        if weight is None:
            if self.weight is None:
                raise RuntimeError(
                    "weight was not supplied to ColumnParallelLinear forward pass "
                    "and skip_weight_param_allocation is True."
                )
            weight = self.weight
        else:
            # Check the weight passed in is the correct shape
            expected_shape = (self.output_size_per_partition, self.input_size)
            if weight.shape != expected_shape:
                raise RuntimeError(
                    f"supplied weight's shape is {tuple(weight.shape)}, "
                    f"not {expected_shape} as expected"
                )

        if self.config._cpu_offloading_context is not None:
            if self.config._cpu_offloading_context.inside_context == True:
                assert (
                    self.config.cpu_offloading == False
                ), "CPU Offloading cannot be enabled while using non-TE modules"

        bias = self.bias if not self.skip_bias_add else None

        if (
            self.async_tensor_model_parallel_allreduce
            or self.sequence_parallel
            or self.explicit_expert_comm
        ):
            input_parallel = input_
        else:
            input_parallel = copy_to_tensor_model_parallel_region(input_)

        if self.config.defer_embedding_wgrad_compute:
            self.embedding_activation_buffer.append(input_parallel)

        # Matrix multiply.
        if not weight.requires_grad:
            self._forward_impl = linear_with_frozen_weight
        else:
            self._forward_impl = linear_with_grad_accumulation_and_async_allreduce

        weight = torch.split(weight, weight.shape[0] // args_.output_layer_slice_num, dim=0)

        output_parallel = []
        for i in range(args_.output_layer_slice_num):
            output_parallel.append(self._forward_impl(
                input=input_parallel,
                weight=weight[i],
                bias=bias,
                gradient_accumulation_fusion=self.gradient_accumulation_fusion,
                async_grad_allreduce=False
                if self.explicit_expert_comm
                else self.async_tensor_model_parallel_allreduce,
                sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
                grad_output_buffer=self.grad_output_buffer
                if self.config.defer_embedding_wgrad_compute
                else None,
            ))
        output_parallel = torch.cat(output_parallel, dim=2)
        if self.gather_output:
            # All-gather across the partitions.
            assert not self.sequence_parallel
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
@@ -61,6 +61,7 @@ def dot_product_attention_init_wrapper(fn):

        args = get_args()
        if args.multi_head_latent_attention:
            self.scale_mask_softmax.scale = True
            self.hidden_size_per_partition = args.num_attention_heads * args.v_head_dim
            self.q_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
            self.softmax_scale = self.q_head_dim ** (-0.5)
@@ -73,6 +74,8 @@ def dot_product_attention_init_wrapper(fn):
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

            self.norm_factor = 1.0 / self.softmax_scale

    return wrapper
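Numerically, with DeepSeek-V2-style head dimensions (assumed here: qk_nope_head_dim=128, qk_rope_head_dim=64) and treating the YaRN mscale as given, the scale installed by the wrapper works out as follows (illustrative sketch, not the wrapper itself):

qk_nope_head_dim, qk_rope_head_dim = 128, 64        # assumed DeepSeek-V2-like values
q_head_dim = qk_nope_head_dim + qk_rope_head_dim    # 192
softmax_scale = q_head_dim ** (-0.5)                # ~0.0722
mscale = 1.0                                        # stand-in for yarn_get_mscale(scaling_factor, mscale_all_dim)
softmax_scale = softmax_scale * mscale * mscale
norm_factor = 1.0 / softmax_scale                   # ~13.86, i.e. sqrt(q_head_dim) when mscale == 1.0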
@@ -135,6 +135,7 @@ def patch_core_models(args):
    from ..core.transformer.dot_product_attention import dot_product_attention_init_wrapper, \
        dot_product_attention_forward_wrapper
    from ..core.transformer.attention import attention_init_wrapper, attention_forward
    from ..core.models.gpt.gpt_model import gpt_model_init_wrapper

    # Embedding
    PatchManager.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.get_pos_emb_on_this_cp_rank', get_pos_emb_on_this_cp_rank)
@@ -159,6 +160,7 @@ def patch_core_models(args):
    PatchManager.register_patch('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank)
    PatchManager.register_patch('megatron.training.dist_signal_handler.get_device', get_device_wrapper)
    PatchManager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
    PatchManager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init_wrapper)

    # For recomputation
    from ..core.transformer.transformer_block import transformer_block_checkpointed_forward_wrapper
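The new registration hooks gpt_model_init_wrapper onto megatron's GPTModel.__init__. A minimal sketch of what such a patch presumably boils down to (PatchManager's real implementation is not shown in this diff; Model and the string values below are illustrative stand-ins):

from functools import wraps

class Model:                                   # stand-in for megatron's GPTModel
    def __init__(self):
        self.output_layer = 'column_parallel'

def init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)              # run the original __init__ first
        self.output_layer = 'segmented'        # then swap the module, as gpt_model_init_wrapper does
    return wrapper

# Registering the patch amounts to re-binding the wrapped callable onto the target attribute.
Model.__init__ = init_wrapper(Model.__init__)
assert Model().output_layer == 'segmented'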