added NPU tricks: mask arithmetic instead of indexed assignment, torch_npu.fast_gelu, and a fused_adam optimizer option

wiyr 2023-06-30 11:00:38 +08:00 committed by dilililiwhy
parent bc2a4a33d5
commit d87e921410
8 changed files with 21 additions and 17 deletions

View File

@@ -533,7 +533,7 @@ def _add_training_args(parser):
                        help='Create separate groups for MoE params.'
                        'This is necessary for techniques like ZeRO.')
     group.add_argument('--optimizer', type=str, default='adam',
-                       choices=['adam', 'sgd'],
+                       choices=['adam', 'sgd', 'fused_adam'],
                        help='Optimizer function')
     group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic'],
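
The new choice only has to pass argparse validation here; the optimizer itself is constructed in get_megatron_optimizer (see the optimizer diff further down). A minimal standalone sketch of the extended flag, using plain argparse rather than Megatron's real parser:

import argparse

# Standalone illustration of the extended choice list; the real flag lives
# inside Megatron's _add_training_args() argument group.
parser = argparse.ArgumentParser()
parser.add_argument('--optimizer', type=str, default='adam',
                    choices=['adam', 'sgd', 'fused_adam'],
                    help='Optimizer function')

args = parser.parse_args(['--optimizer', 'fused_adam'])
print(args.optimizer)  # fused_adam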

View File

@@ -78,7 +78,6 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.weight = torch.nn.parameter.Parameter(torch.Tensor(*normalized_shape))
         self.bias = torch.nn.parameter.Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
-        self.no_persist_layer_norm = True
         self.sequence_parallel = sequence_parallel

         # set sequence parallelism flag on weight and bias parameters
@@ -90,9 +89,4 @@ class MixedFusedLayerNorm(torch.nn.Module):
         init.zeros_(self.bias)

     def forward(self, input):
-        if self.no_persist_layer_norm:
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
-        else:
-            output = FastLayerNormFN.apply(input, self.weight, self.bias, self.eps)
-            output = make_viewless_tensor(inp=output, requires_grad=input.requires_grad, keep_graph=True)
-            return output
+        return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
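
With the no_persist_layer_norm flag and the FastLayerNormFN / make_viewless_tensor path removed, every call now goes through torch.nn.functional.layer_norm, which dispatches to whatever kernel the active device provides (CPU, CUDA, or NPU via torch_npu). A minimal self-contained sketch of the resulting behaviour, written as an illustrative stand-in rather than the fork's exact class:

import torch
import torch.nn.functional as F
from torch.nn import init


class SimpleFusedLayerNorm(torch.nn.Module):
    # Illustrative stand-in for MixedFusedLayerNorm after this commit:
    # no persistent-kernel special case, just the generic functional layer norm.

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(*normalized_shape))
        self.bias = torch.nn.Parameter(torch.empty(*normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        # Single code path: F.layer_norm picks the device kernel, so the
        # FastLayerNormFN branch is no longer needed.
        return F.layer_norm(input, self.normalized_shape,
                            self.weight, self.bias, self.eps)


x = torch.randn(2, 4, 8)
print(SimpleFusedLayerNorm(8)(x).shape)  # torch.Size([2, 4, 8])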

View File

@@ -18,6 +18,7 @@ import math
 import deepspeed
 import torch
+import torch_npu
 import torch.nn.functional as F
 from torch import distributed as dist
 from deepspeed.accelerator import get_accelerator
@@ -104,7 +105,7 @@ class ParallelMLP(MegatronModule):
         if self.bias_gelu_fusion:
             intermediate_parallel = \
-                bias_gelu_impl(intermediate_parallel, bias_parallel)
+                torch_npu.fast_gelu(intermediate_parallel + bias_parallel)
         else:
             intermediate_parallel = \
                 self.activation_func(intermediate_parallel + bias_parallel)
@@ -542,7 +543,7 @@ class ParallelTransformerLayer(MegatronModule):
         with torch.enable_grad():
             layernorm_input = bias_dropout_add_func(
                 attention_output,
-                attention_bias.expand_as(residual),
+                attention_bias,
                 residual,
                 self.hidden_dropout)
@@ -564,7 +565,7 @@ class ParallelTransformerLayer(MegatronModule):
         with torch.enable_grad():
             layernorm_input = bias_dropout_add_func(
                 attention_output,
-                attention_bias.expand_as(residual),
+                attention_bias,
                 residual,
                 self.hidden_dropout)
@@ -589,7 +590,7 @@ class ParallelTransformerLayer(MegatronModule):
         with torch.enable_grad():
             output = bias_dropout_add_func(
                 mlp_output,
-                mlp_bias.expand_as(residual),
+                mlp_bias,
                 residual,
                 self.hidden_dropout)
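
Two separate tricks land in this file: the bias-GELU fusion branch now calls torch_npu.fast_gelu on the bias-added activations instead of Megatron's jit-fused bias_gelu_impl, and the bias handed to bias_dropout_add_func is no longer materialized with .expand_as(residual), leaving the expansion to broadcasting. A small sketch of both; the CPU/GPU fallback GELU is an assumption for machines without torch_npu and does not reproduce fast_gelu bit-for-bit:

import torch

try:
    import torch_npu  # Ascend-only package; provides the fused fast_gelu op
    def fused_gelu(x):
        return torch_npu.fast_gelu(x)
except ImportError:
    def fused_gelu(x):
        # Stand-in approximation so the sketch runs without an NPU.
        return torch.nn.functional.gelu(x, approximate='tanh')


def bias_dropout_add(x, bias, residual, prob, training=True):
    # The bias arrives un-expanded; broadcasting over the leading dims is
    # exactly what dropping `.expand_as(residual)` relies on.
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out


hidden = torch.randn(16, 2, 64)    # [sequence, batch, hidden]
bias = torch.randn(64)             # broadcasts over sequence and batch
residual = torch.randn(16, 2, 64)

print(fused_gelu(hidden).shape)
print(bias_dropout_add(hidden, bias, residual, prob=0.1).shape)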

View File

@@ -46,7 +46,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         # Create a mask of valid vocab ids (1 means it needs to be masked).
         target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
         masked_target = target.clone() - vocab_start_index
-        masked_target[target_mask] = 0
+        masked_target *= ~target_mask

         # Get predicted-logits = logits[target].
         # For Simplicity, we convert logits to a 2-D tensor with size
@@ -58,7 +58,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         predicted_logits_1d = logits_2d[arange_1d, masked_target_1d.long()]
         predicted_logits_1d = predicted_logits_1d.clone().contiguous()
         predicted_logits = predicted_logits_1d.view_as(target)
-        predicted_logits[target_mask] = 0.0
+        predicted_logits *= ~target_mask

         # All reduce is needed to get the chunks from other GPUs.
         torch.distributed.all_reduce(predicted_logits,
                                      op=torch.distributed.ReduceOp.SUM,
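
Boolean-mask assignment such as masked_target[target_mask] = 0 lowers to an index_put_/scatter kernel, which is a slow path on Ascend devices; multiplying by the inverted mask zeroes the same positions with a plain element-wise op. A quick equivalence check (vocab range and shapes are made up for illustration):

import torch

vocab_start_index, vocab_end_index = 100, 200
target = torch.randint(0, 300, (8, 32))

# 1 means the target id falls outside this rank's vocab shard.
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)

masked_assign = target.clone() - vocab_start_index
masked_assign[target_mask] = 0          # old form: boolean-index assignment

masked_mul = target.clone() - vocab_start_index
masked_mul *= ~target_mask              # new form: multiply by inverted mask

print(torch.equal(masked_assign, masked_mul))  # True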

View File

@@ -200,7 +200,8 @@ class VocabParallelEmbedding(torch.nn.Module):
                                           self.sparse)
         # Mask the output embedding.
         if self.tensor_model_parallel_size > 1:
-            output_parallel[input_mask, :] = 0.0
+            compare = torch.zeros_like(output_parallel)
+            output_parallel = torch.lerp(output_parallel, compare, input_mask[..., None].half())
         # Reduce across all the model parallel GPUs.
         output = reduce_from_tensor_model_parallel_region(output_parallel)
         if hasattr(self, 'norm'):
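
The embedding masking applies the same idea through torch.lerp: with an all-zeros end tensor and a 0/1 weight built from the mask, lerp(a, 0, w) = a * (1 - w), so masked rows become zero without boolean indexing. The .half() cast matches fp16 training. A small self-contained check with illustrative shapes:

import torch

output_parallel = torch.randn(4, 6, 16, dtype=torch.half)  # [batch, seq, hidden]
input_mask = torch.rand(4, 6) < 0.3                        # True where ids are off-shard

# Old form: boolean-mask assignment.
ref = output_parallel.clone()
ref[input_mask, :] = 0.0

# New form: lerp toward zeros, weighted 1.0 on masked rows, 0.0 elsewhere.
compare = torch.zeros_like(output_parallel)
out = torch.lerp(output_parallel, compare, input_mask[..., None].half())

print(torch.allclose(ref, out))  # True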

View File

@@ -78,6 +78,14 @@ def get_megatron_optimizer(model):
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         momentum=args.sgd_momentum)
+    elif args.optimizer == 'fused_adam':
+        from deepspeed.ops.adam.fused_adam import FusedAdam
+        optimizer = FusedAdam(param_groups,
+                              lr=args.lr,
+                              weight_decay=args.weight_decay,
+                              betas=(args.adam_beta1, args.adam_beta2),
+                              eps=args.adam_eps)
     else:
         raise Exception('{} optimizer is not supported.'.format(
             args.optimizer))
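
The new branch constructs DeepSpeed's FusedAdam with the same hyper-parameters the plain adam branch already receives. A hedged standalone sketch of equivalent usage; the import path is taken from the diff, and the torch.optim.AdamW fallback is only an approximation for machines without DeepSpeed (FusedAdam applies AdamW-style decoupled weight decay by default):

import torch

model = torch.nn.Linear(128, 128)
param_groups = [{'params': model.parameters()}]
lr, weight_decay = 1e-4, 0.01
adam_beta1, adam_beta2, adam_eps = 0.9, 0.95, 1e-8

try:
    # Same import path as the diff; needs a DeepSpeed install whose fused
    # kernel is available for the current accelerator.
    from deepspeed.ops.adam.fused_adam import FusedAdam
    optimizer = FusedAdam(param_groups,
                          lr=lr,
                          weight_decay=weight_decay,
                          betas=(adam_beta1, adam_beta2),
                          eps=adam_eps)
except ImportError:
    # Approximate stand-in so the sketch still runs without DeepSpeed.
    optimizer = torch.optim.AdamW(param_groups,
                                  lr=lr,
                                  weight_decay=weight_decay,
                                  betas=(adam_beta1, adam_beta2),
                                  eps=adam_eps)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
optimizer.step()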

View File

@@ -81,7 +81,7 @@ TRANSFORMERS_OFFLINE=1 \
        --fp16 \
        --seed 42 \
        --position-embedding-type alibi \
-       --optimizer adam \
+       --optimizer fused_adam \
        --adam-beta1 0.9 \
        --adam-beta2 0.95 \
        --adam-eps 1e-8 \

View File

@@ -91,7 +91,7 @@ TRANSFORMERS_OFFLINE=1 \
        --fp16 \
        --seed 42 \
        --position-embedding-type alibi \
-       --optimizer adam \
+       --optimizer fused_adam \
        --adam-beta1 0.9 \
        --adam-beta2 0.95 \
        --adam-eps 1e-8 \