mirror of https://gitee.com/ascend/ModelLink.git (synced 2024-11-30 02:48:33 +08:00)

commit d87e921410 (parent bc2a4a33d5): added trick
@@ -533,7 +533,7 @@ def _add_training_args(parser):
                        help='Create separate groups for MoE params.'
                        'This is necessary for techniques like ZeRO.')
     group.add_argument('--optimizer', type=str, default='adam',
-                       choices=['adam', 'sgd'],
+                       choices=['adam', 'sgd', 'fused_adam'],
                        help='Optimizer function')
     group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic'],
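The hunk above only registers 'fused_adam' as an accepted value for --optimizer. A tiny self-contained sketch of the extended argument (the group name is illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('training')  # group name is illustrative
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd', 'fused_adam'],
                       help='Optimizer function')

    # '--optimizer fused_adam' now parses instead of being rejected.
    print(parser.parse_args(['--optimizer', 'fused_adam']).optimizer)  # fused_adam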
@@ -78,7 +78,6 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.weight = torch.nn.parameter.Parameter(torch.Tensor(*normalized_shape))
         self.bias = torch.nn.parameter.Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
-        self.no_persist_layer_norm = True
         self.sequence_parallel = sequence_parallel

         # set sequence parallelism flag on weight and bias parameters
@@ -90,9 +89,4 @@ class MixedFusedLayerNorm(torch.nn.Module):
         init.zeros_(self.bias)

     def forward(self, input):
-        if self.no_persist_layer_norm:
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
-        else:
-            output = FastLayerNormFN.apply(input, self.weight, self.bias, self.eps)
-            output = make_viewless_tensor(inp=output, requires_grad=input.requires_grad, keep_graph=True)
-            return output
+        return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
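With the persist/FastLayerNormFN branch removed, MixedFusedLayerNorm.forward always falls through to torch.nn.functional.layer_norm. A minimal CPU-runnable sketch of the resulting behaviour (SimpleLayerNorm is a stand-in name, not the repository's class):

    import torch
    import torch.nn.functional as F

    class SimpleLayerNorm(torch.nn.Module):
        """Stand-in mirroring the simplified forward: no FastLayerNormFN path."""
        def __init__(self, normalized_shape, eps=1e-5):
            super().__init__()
            if isinstance(normalized_shape, int):
                normalized_shape = (normalized_shape,)
            self.normalized_shape = tuple(normalized_shape)
            self.eps = eps
            self.weight = torch.nn.Parameter(torch.ones(*self.normalized_shape))
            self.bias = torch.nn.Parameter(torch.zeros(*self.normalized_shape))

        def forward(self, input):
            # Same call the patched forward() now makes unconditionally.
            return F.layer_norm(input, self.normalized_shape, self.weight,
                                self.bias, self.eps)

    x = torch.randn(2, 4, 8)
    print(SimpleLayerNorm(8)(x).shape)  # torch.Size([2, 4, 8])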
@@ -18,6 +18,7 @@ import math

 import deepspeed
 import torch
+import torch_npu
 import torch.nn.functional as F
 from torch import distributed as dist
 from deepspeed.accelerator import get_accelerator
@@ -104,7 +105,7 @@ class ParallelMLP(MegatronModule):

         if self.bias_gelu_fusion:
             intermediate_parallel = \
-                bias_gelu_impl(intermediate_parallel, bias_parallel)
+                torch_npu.fast_gelu(intermediate_parallel + bias_parallel)
         else:
             intermediate_parallel = \
                 self.activation_func(intermediate_parallel + bias_parallel)
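The bias-GeLU fusion kernel is swapped for a single torch_npu.fast_gelu applied to the already bias-added activations. torch_npu only runs on Ascend NPUs, so the sketch below uses the common sigmoid approximation gelu(x) ≈ x · sigmoid(1.702x) as a CPU stand-in; torch_npu.fast_gelu's exact formula may differ.

    import torch
    import torch.nn.functional as F

    def fast_gelu_approx(x: torch.Tensor) -> torch.Tensor:
        # Stand-in for torch_npu.fast_gelu; approximation only.
        return x * torch.sigmoid(1.702 * x)

    intermediate_parallel = torch.randn(4, 16)
    bias_parallel = torch.randn(16)

    exact = F.gelu(intermediate_parallel + bias_parallel)        # unfused reference
    fused_style = fast_gelu_approx(intermediate_parallel + bias_parallel)
    print((exact - fused_style).abs().max())                     # small but nonzero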
@@ -542,7 +543,7 @@ class ParallelTransformerLayer(MegatronModule):
             with torch.enable_grad():
                 layernorm_input = bias_dropout_add_func(
                     attention_output,
-                    attention_bias.expand_as(residual),
+                    attention_bias,
                     residual,
                     self.hidden_dropout)

@@ -564,7 +565,7 @@ class ParallelTransformerLayer(MegatronModule):
             with torch.enable_grad():
                 layernorm_input = bias_dropout_add_func(
                     attention_output,
-                    attention_bias.expand_as(residual),
+                    attention_bias,
                     residual,
                     self.hidden_dropout)

@@ -589,7 +590,7 @@ class ParallelTransformerLayer(MegatronModule):
             with torch.enable_grad():
                 output = bias_dropout_add_func(
                     mlp_output,
-                    mlp_bias.expand_as(residual),
+                    mlp_bias,
                     residual,
                     self.hidden_dropout)

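All three hunks above drop the explicit .expand_as(residual) on the bias before calling bias_dropout_add_func. This works because a Megatron-style bias_dropout_add broadcasts the bias anyway; pre-expanding only materialised a larger tensor. An illustrative re-implementation (not the repository's exact function):

    import torch

    def bias_dropout_add(x, bias, residual, prob, training=True):
        # The bias broadcasts over the leading [seq, batch] dims inside x + bias.
        out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
        return residual + out

    attention_output = torch.randn(8, 2, 32)   # [seq, batch, hidden]
    attention_bias = torch.randn(32)           # broadcasts over seq and batch
    residual = torch.randn(8, 2, 32)

    torch.manual_seed(0)
    a = bias_dropout_add(attention_output, attention_bias, residual, 0.1)
    torch.manual_seed(0)
    b = bias_dropout_add(attention_output, attention_bias.expand_as(residual), residual, 0.1)
    print(torch.equal(a, b))  # True: the pre-expanded bias adds nothing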
@@ -46,7 +46,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         # Create a mask of valid vocab ids (1 means it needs to be masked).
         target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
         masked_target = target.clone() - vocab_start_index
-        masked_target[target_mask] = 0
+        masked_target *= ~target_mask

         # Get predicted-logits = logits[target].
         # For Simplicity, we convert logits to a 2-D tensor with size
@@ -58,7 +58,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         predicted_logits_1d = logits_2d[arange_1d, masked_target_1d.long()]
         predicted_logits_1d = predicted_logits_1d.clone().contiguous()
         predicted_logits = predicted_logits_1d.view_as(target)
-        predicted_logits[target_mask] = 0.0
+        predicted_logits *= ~target_mask
         # All reduce is needed to get the chunks from other GPUs.
         torch.distributed.all_reduce(predicted_logits,
                                      op=torch.distributed.ReduceOp.SUM,
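Both hunks replace a boolean-index assignment with a multiply by the negated mask, which yields the same values while avoiding an index_put_-style scatter (presumably friendlier to the NPU kernels this mirror targets). A small equivalence check:

    import torch

    target = torch.tensor([3, 7, 12, 5])
    vocab_start_index, vocab_end_index = 4, 10      # this rank owns ids [4, 10)
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)

    masked_a = target.clone() - vocab_start_index
    masked_a[target_mask] = 0                        # original: boolean-index assignment

    masked_b = target.clone() - vocab_start_index
    masked_b *= ~target_mask                         # patched: multiply by negated mask

    print(torch.equal(masked_a, masked_b))           # True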
@@ -200,7 +200,8 @@ class VocabParallelEmbedding(torch.nn.Module):
                                       self.sparse)
         # Mask the output embedding.
         if self.tensor_model_parallel_size > 1:
-            output_parallel[input_mask, :] = 0.0
+            compare = torch.zeros_like(output_parallel)
+            output_parallel = torch.lerp(output_parallel, compare, input_mask[..., None].half())
         # Reduce across all the model parallel GPUs.
         output = reduce_from_tensor_model_parallel_region(output_parallel)
         if hasattr(self, 'norm'):
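Here the masked write output_parallel[input_mask, :] = 0.0 becomes a torch.lerp toward a zero tensor, with the mask as the interpolation weight: rows where the mask is 1 land exactly on zero. The .half() in the patch assumes fp16 activations (matching the --fp16 scripts below); the sketch uses .float() so it runs on CPU.

    import torch

    output_parallel = torch.randn(5, 4)
    input_mask = torch.tensor([False, True, False, True, False])

    ref = output_parallel.clone()
    ref[input_mask, :] = 0.0                                   # original masking

    compare = torch.zeros_like(output_parallel)
    lerped = torch.lerp(output_parallel, compare, input_mask[..., None].float())

    print(torch.allclose(ref, lerped))                         # True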
@@ -78,6 +78,14 @@ def get_megatron_optimizer(model):
                               lr=args.lr,
                               weight_decay=args.weight_decay,
                               momentum=args.sgd_momentum)
+    elif args.optimizer == 'fused_adam':
+        from deepspeed.ops.adam.fused_adam import FusedAdam
+        optimizer = FusedAdam(param_groups,
+                              lr=args.lr,
+                              weight_decay=args.weight_decay,
+                              betas=(args.adam_beta1, args.adam_beta2),
+                              eps=args.adam_eps)
+
     else:
         raise Exception('{} optimizer is not supported.'.format(
             args.optimizer))
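A minimal sketch of how the new 'fused_adam' branch builds its optimizer, assuming a DeepSpeed install whose fused Adam op is usable on the target device; the try/except fallback to torch.optim.Adam (same constructor arguments) is only here to keep the sketch runnable anywhere:

    import torch

    model = torch.nn.Linear(16, 16)
    param_groups = [{'params': model.parameters()}]

    try:
        from deepspeed.ops.adam.fused_adam import FusedAdam as AdamImpl
    except ImportError:
        AdamImpl = torch.optim.Adam  # CPU-friendly stand-in with the same arguments

    optimizer = AdamImpl(param_groups,
                         lr=1e-4,
                         weight_decay=0.01,
                         betas=(0.9, 0.95),
                         eps=1e-8)

    loss = model(torch.randn(2, 16)).sum()
    loss.backward()
    optimizer.step()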
@@ -81,7 +81,7 @@ TRANSFORMERS_OFFLINE=1 \
     --fp16 \
     --seed 42 \
     --position-embedding-type alibi \
-    --optimizer adam \
+    --optimizer fused_adam \
     --adam-beta1 0.9 \
     --adam-beta2 0.95 \
     --adam-eps 1e-8 \
@@ -91,7 +91,7 @@ TRANSFORMERS_OFFLINE=1 \
     --fp16 \
     --seed 42 \
     --position-embedding-type alibi \
-    --optimizer adam \
+    --optimizer fused_adam \
     --adam-beta1 0.9 \
     --adam-beta2 0.95 \
     --adam-eps 1e-8 \