ModelLink2/pretrain_gpt.py
guhangsong 39d6fd7336 !1218 迁移megatron patch
Merge pull request !1218 from guhangsong/patch
2024-04-23 01:57:03 +00:00

267 lines
8.9 KiB
Python

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain GPT."""
import os
from functools import partial
from typing import Union
import torch
from torch import Tensor
import modellink
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import GPTDataset
import megatron.model
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.core.transformer.spec_utils import import_module
from megatron.utils import (
get_ltor_masks_and_position_ids,
get_batch_on_this_cp_rank,
average_losses_across_data_parallel_group
)
from megatron.arguments import core_transformer_config_from_args
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_with_transformer_engine_spec,
gpt_layer_with_transformer_engine_spec_moe
)
from modellink.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets as build_instruction_dataset
from modellink.utils import get_tune_attention_mask
def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.model.GPTModel]:
    """Build the GPT model for the current pipeline stage.

    If ``args.use_mcore_models`` is set, a Megatron-Core ``GPTModel`` is
    returned; otherwise the legacy ``megatron.model.GPTModel`` is built.

    Args:
        pre_process (bool, optional): Set to True if this stage must compute
            the input embeddings. Defaults to True.
        post_process (bool, optional): Set to True if this stage must compute
            the output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.model.GPTModel]: The constructed model.

    Raises:
        ValueError: If context parallelism (``context_parallel_size != 1``)
            is requested together with the legacy (non-mcore) model.
    """
    args = get_args()
    print_rank_0('building GPT model ...')
    # Reuse the args fetched above instead of querying the global again.
    config = core_transformer_config_from_args(args)

    if args.use_mcore_models:
        # An explicit --spec module takes precedence over the built-in specs.
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        elif args.num_experts is None:
            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec()
        else:
            # The MoE spec is used as a ready-made object, not called as a factory.
            transformer_layer_spec = gpt_layer_with_transformer_engine_spec_moe

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
        )
    else:
        if args.context_parallel_size != 1:
            raise ValueError("Context parallelism is only supported with Megatron Core!")
        model = megatron.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    return model
def get_batch(data_iterator):
    """Fetch the next batch and prepare the model inputs for this pipeline stage.

    Intermediate pipeline stages (neither first nor last) do not read input
    data and receive a tuple of five ``None`` placeholders.

    Two dataset layouts are handled:

    * Instruction datasets (``args.is_instruction_dataset``): samples carry
      ``input_ids``, ``attention_mask`` and ``labels``. Positions labelled
      ``-100`` get a zero loss mask, and ``position_ids`` is ``None``.
    * Pre-training datasets: samples carry a ``text`` field. Labels are the
      tokens shifted left by one position. The attention mask, loss mask and
      position ids come from ``get_ltor_masks_and_position_ids``, and the
      result is sliced along the sequence dimension by
      ``get_batch_on_this_cp_rank``.

    Args:
        data_iterator: Iterator yielding sample dicts, or ``None`` on ranks
            without one. In that case ``None`` is handed to
            ``tensor_parallel.broadcast_data`` in place of a sample.

    Returns:
        tuple: ``(tokens, labels, loss_mask, attention_mask, position_ids)``.
    """
    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
        return None, None, None, None, None

    args = get_args()
    data_type = torch.int64
    # Guard both dataset layouts the same way: ranks without an iterator still
    # call broadcast_data, but pass None instead of a sample.
    data = next(data_iterator) if data_iterator is not None else None

    if args.is_instruction_dataset:
        keys = ['input_ids', 'attention_mask', 'labels']
        data_b = tensor_parallel.broadcast_data(keys, data, data_type)

        labels = data_b['labels'].long()
        tokens = data_b['input_ids'].long()
        attention_mask_1d = data_b['attention_mask'].long()
        # Positions labelled -100 are excluded from the loss.
        loss_mask = torch.where(labels == -100, 0, 1)
        attention_mask = get_tune_attention_mask(attention_mask_1d, args.reset_attention_mask)
        # NOTE(review): unlike the pre-training path, this batch is not passed
        # through get_batch_on_this_cp_rank; presumably context parallelism is
        # unsupported for instruction data — confirm.
        return tokens, labels, loss_mask, attention_mask, None

    tokenizer = get_tokenizer()
    keys = ['text']
    data_b = tensor_parallel.broadcast_data(keys, data, data_type)

    # Next-token prediction: the input is text[:-1] and the target is text[1:].
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    # Insertion order fixes the order of the returned tuple.
    batch = {
        'tokens': tokens,
        'labels': labels,
        'loss_mask': loss_mask,
        'attention_mask': attention_mask,
        'position_ids': position_ids,
    }
    # Slice the batch along the sequence dimension for context parallelism.
    batch = get_batch_on_this_cp_rank(batch)
    return tuple(batch.values())
def loss_func(loss_mask: Tensor, output_tensor: Tensor):
    """Reduce per-token losses to a single masked-mean loss.

    Args:
        loss_mask (Tensor): Per-token weights. They are flattened and cast to
            float, then multiplied into the losses.
        output_tensor (Tensor): Per-token losses produced by the model.

    Returns:
        tuple: The loss multiplied by ``args.context_parallel_size``, and a
        dict holding the data-parallel-averaged loss under ``'lm loss'`` for
        logging.

    Raises:
        ValueError: If ``args.check_for_nan_in_loss_and_grad`` is set and the
            local loss is NaN.
    """
    args = get_args()
    flat_losses = output_tensor.float().view(-1)
    flat_mask = loss_mask.view(-1).float()
    masked_sum = torch.sum(flat_losses * flat_mask)

    if args.context_parallel_size > 1:
        # Sum the numerator and the denominator over the context-parallel
        # group before dividing, so the mean covers the whole sequence.
        totals = torch.cat([masked_sum.view(1), flat_mask.sum().view(1)])
        torch.distributed.all_reduce(totals, group=mpu.get_context_parallel_group())
        loss = totals[0] / totals[1]
    else:
        # Note: if every position is masked this is 0/0 and yields NaN.
        loss = masked_sum / flat_mask.sum()

    # Check that each rank's local loss is finite before the DP all-reduce.
    if args.check_for_nan_in_loss_and_grad:
        global_rank = torch.distributed.get_rank()
        if loss.isnan():
            raise ValueError(f'Rank {global_rank}: found NaN in local forward loss calculation. '
                             f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}')

    # Average across data-parallel ranks for logging only.
    reported = average_losses_across_data_parallel_group([loss])[0]
    # NOTE(review): the CP-size scaling presumably offsets gradient averaging
    # over the context-parallel group — confirm.
    return loss * args.context_parallel_size, {'lm loss': reported}
def forward_step(data_iterator, model: GPTModel):
    """Run one forward pass on the next batch.

    Args:
        data_iterator: Input data iterator. It may be ``None`` on ranks that
            do not hold the data.
        model (GPTModel): The GPT model (Megatron-Core or legacy).

    Returns:
        tuple: The model output tensor, and ``loss_func`` partially applied to
        this batch's loss mask. The partial takes the output tensor as its
        remaining argument.
    """
    timers = get_timers()

    # Time batch construction separately from the forward pass.
    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch-generator').stop()

    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
    return output_tensor, partial(loss_func, loss_mask)
def is_dataset_built_on_rank():
    """Return True if this rank should build the dataset.

    Only the first or last pipeline stage qualifies, and within those stages
    only tensor-parallel rank 0. Other tensor-parallel ranks presumably
    receive their data through the broadcast in ``get_batch`` — confirm.
    """
    handles_io = mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
    if not handles_io:
        return handles_io
    return mpu.get_tensor_model_parallel_rank() == 0
def core_gpt_dataset_config_from_args(args):
    """Translate command-line arguments into a ``GPTDatasetConfig``.

    Args:
        args: Parsed Megatron arguments. The fields read are ``seed``,
            ``seq_length``, ``data_path``, ``train_data_path``,
            ``valid_data_path``, ``test_data_path``, ``split`` and
            ``data_cache_path``.

    Returns:
        GPTDatasetConfig: Config whose ``is_built_on_rank`` predicate is
        ``is_dataset_built_on_rank``.
    """
    seed = args.seed
    seq_len = args.seq_length
    blend = args.data_path
    # One blend per split, in train/valid/test order.
    per_split = [args.train_data_path, args.valid_data_path, args.test_data_path]
    return GPTDatasetConfig(
        is_built_on_rank=is_dataset_built_on_rank,
        random_seed=seed,
        sequence_length=seq_len,
        blend=blend,
        blend_per_split=per_split,
        split=args.split,
        path_to_cache=args.data_cache_path,
    )
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train, validation and test datasets.

    Instruction-tuning data is built with ModellLink's packed MTF builder.
    Anything else goes through Megatron-Core's ``BlendedMegatronDatasetBuilder``
    with ``GPTDataset``.

    Args:
        train_val_test_num_samples: Number of samples requested for the train,
            validation and test splits, in that order.

    Returns:
        tuple: ``(train_ds, valid_ds, test_ds)``.
    """
    args = get_args()
    print_rank_0("> building train, validation, and test datasets for GPT ...")

    if not args.is_instruction_dataset:
        builder = BlendedMegatronDatasetBuilder(
            GPTDataset,
            train_val_test_num_samples,
            core_gpt_dataset_config_from_args(args),
        )
        datasets = builder.build()
    else:
        datasets = build_instruction_dataset(
            data_prefix=args.data_path,
            splits_string=args.split,
            train_valid_test_num_samples=train_val_test_num_samples,
            seq_length=args.seq_length,
            seed=args.seed,
        )
    train_ds, valid_ds, test_ds = datasets

    print_rank_0("> finished creating GPT datasets ...")
    return train_ds, valid_ds, test_ds
if __name__ == "__main__":
    # NPU JIT compilation is on unless WITHOUT_JIT_COMPILE is set to a
    # non-empty value.
    jit_compile = not os.environ.get("WITHOUT_JIT_COMPILE")
    torch.npu.set_compile_mode(jit_compile=jit_compile)

    # Temporary for transition to core datasets.
    # NOTE(review): presumably makes `pretrain` call the dataset provider on
    # every rank — confirm against megatron.training.
    train_valid_test_datasets_provider.is_distributed = True

    default_args = {'tokenizer_type': 'GPT2BPETokenizer'}
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults=default_args,
    )