# Mirror of https://gitee.com/ascend/ModelLink.git
# Synced 2024-12-02 11:58:26 +08:00 at commit cd5816b627
# ("Merge pull request !1598 from chenqianghw/master").
# File: Python, 268 lines, 9.4 KiB.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
|
"""Pretrain GPT."""
|
|
|
|
import os
|
|
from functools import partial
|
|
from typing import Union
|
|
|
|
import torch
|
|
import modellink
|
|
from megatron.training import get_args
|
|
from megatron.training import print_rank_0
|
|
from megatron.training import get_timers
|
|
from megatron.training import get_tokenizer
|
|
from megatron.core import mpu, tensor_parallel
|
|
from megatron.core.enums import ModelType
|
|
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
|
|
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
|
|
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
|
|
import megatron.legacy.model
|
|
from megatron.core.models.gpt import GPTModel
|
|
from modellink.training import pretrain
|
|
from megatron.core.transformer.spec_utils import import_module
|
|
from megatron.training.utils import (
|
|
get_batch_on_this_cp_rank,
|
|
get_batch_on_this_tp_rank,
|
|
average_losses_across_data_parallel_group
|
|
)
|
|
from megatron.training.arguments import core_transformer_config_from_args
|
|
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
|
|
from megatron.core.models.gpt.gpt_layer_specs import (
|
|
get_gpt_layer_local_spec,
|
|
get_gpt_layer_with_transformer_engine_spec,
|
|
)
|
|
from modellink.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets as build_instruction_dataset
|
|
from modellink.utils import get_tune_attention_mask, get_finetune_data_on_this_tp_rank, generate_actual_seq_len
|
|
|
|
|
|
def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Builds the model.

    If ``args.use_mcore_models`` is True, the Megatron Core GPT model is returned;
    otherwise the legacy GPT model is returned.

    Args:
        pre_process (bool, optional): Set to True if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to True if you want to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model.

    Raises:
        ValueError: If the legacy model is requested while ``args.context_parallel_size != 1``.
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"

    print_rank_0('building GPT model ...')
    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_mcore_models:
        # An explicit --spec module takes precedence over the built-in layer specs.
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        elif use_te:
            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm)
        else:
            transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm)

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
        )
    else:
        # The legacy model path does not implement context parallelism.
        if args.context_parallel_size != 1:
            raise ValueError("Context parallelism is only supported with Megatron Core!")

        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process
        )

    return model
|
|
|
|
|
|
def get_batch(data_iterator):
    """Fetch the next micro-batch for the current rank.

    Pretraining data (``args.is_instruction_dataset`` is false):
        The batch comes from ``get_batch_on_this_tp_rank``. If ``args.reset_position_ids``
        is set, it is passed through ``generate_actual_seq_len``. It is then handed to
        ``get_batch_on_this_cp_rank`` to be sliced along the sequence dimension for
        context parallelism. The values of the resulting dict are returned; the caller
        unpacks them as ``(tokens, labels, loss_mask, attention_mask, position_ids)``,
        which presumably matches the dict's key order (confirm).

    Instruction data, first or last pipeline stage:
        The next sample's ``input_ids``, ``attention_mask`` and ``labels`` are passed
        through ``tensor_parallel.broadcast_data`` as int64. The loss mask is 0 where
        ``labels == -100`` and 1 elsewhere. The 1-D attention mask is expanded with
        ``get_tune_attention_mask``.

    Instruction data, intermediate pipeline stage:
        Tokens and attention mask are fetched via ``get_finetune_data_on_this_tp_rank``
        only when ``args.variable_seq_lengths`` is set and there are more than two
        pipeline stages. Otherwise every element is ``None``.

    Args:
        data_iterator: Iterator over the training data.

    Returns:
        A 5-element sequence ``(tokens, labels, loss_mask, attention_mask, position_ids)``.
        Elements a given path does not produce are ``None``.
    """
    args = get_args()

    if not args.is_instruction_dataset:
        # get batches based on the TP rank you are on
        pretrain_batch = get_batch_on_this_tp_rank(data_iterator)
        if args.reset_position_ids:
            generate_actual_seq_len(pretrain_batch)
        # slice batch along sequence dimension for context parallelism
        pretrain_batch = get_batch_on_this_cp_rank(pretrain_batch)
        return pretrain_batch.values()

    is_middle_stage = not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage())
    if is_middle_stage:
        if not (args.variable_seq_lengths and args.pipeline_model_parallel_size > 2):
            return None, None, None, None, None
        middle_tokens, middle_mask = get_finetune_data_on_this_tp_rank(data_iterator)
        return middle_tokens, None, None, middle_mask, None

    broadcast = tensor_parallel.broadcast_data(
        ['input_ids', 'attention_mask', 'labels'],
        next(data_iterator),
        torch.int64,
    )

    labels = broadcast.get('labels').long()
    tokens = broadcast.get('input_ids').long()
    mask_1d = broadcast.get('attention_mask').long()
    # Positions labelled -100 are ignored by the loss.
    loss_mask = torch.where(labels == -100, 0, 1)

    return tokens, labels, loss_mask, get_tune_attention_mask(mask_1d), None
|
|
|
|
|
|
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    """Reduce per-position losses to a masked mean and report it for logging.

    Args:
        loss_mask (torch.Tensor): 0/1-style mask selecting which positions contribute to the loss.
            For instruction datasets, its first position along the last dimension is dropped first.
        output_tensor (torch.Tensor): The tensor with the losses.

    Returns:
        tuple: A pair of values.
            * The masked mean loss multiplied by ``args.context_parallel_size``.
            * ``{'lm loss': ...}``, holding the loss averaged across the data-parallel group.

    Raises:
        ValueError: If ``args.check_for_nan_in_loss_and_grad`` is set and the local loss is NaN.
    """
    args = get_args()

    losses = output_tensor.float()
    # Instruction data drops the first mask position along the last dim. Presumably this
    # aligns the mask with shifted next-token losses -- TODO confirm against the model output.
    mask_view = loss_mask[..., 1:] if args.is_instruction_dataset else loss_mask
    flat_mask = mask_view.view(-1).float()

    masked_loss_sum = torch.sum(losses.view(-1) * flat_mask)
    if args.context_parallel_size > 1:
        # Sum numerator and mask count over the context-parallel group, then divide.
        # The mean therefore covers unmasked positions from every CP rank.
        sums = torch.cat([masked_loss_sum.view(1), flat_mask.sum().view(1)])
        torch.distributed.all_reduce(sums, group=mpu.get_context_parallel_group())
        loss = sums[0] / sums[1]
    else:
        loss = masked_loss_sum / flat_mask.sum()

    # Check individual rank losses are not NaN prior to DP all-reduce.
    if args.check_for_nan_in_loss_and_grad:
        global_rank = torch.distributed.get_rank()
        if loss.isnan():
            raise ValueError(f'Rank {global_rank}: found NaN in local forward loss calculation. '
                             f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}')

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss * args.context_parallel_size, {'lm loss': averaged_loss[0]}
|
|
|
|
|
|
def forward_step(data_iterator, model: GPTModel):
    """Forward training step.

    Fetches a batch with ``get_batch`` and times it under the ``'batch-generator'``
    timer. Then runs the model on ``(tokens, position_ids, attention_mask)`` with
    ``labels`` passed as a keyword.

    Args:
        data_iterator: Input data iterator.
        model (GPTModel): The GPT Model.

    Returns:
        tuple: ``(output_tensor, loss_fn)``. ``loss_fn`` is ``loss_func`` with this
        batch's ``loss_mask`` already bound, so it only needs the output tensor.
    """
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch-generator').stop()

    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

    return output_tensor, partial(loss_func, loss_mask)
|
|
|
|
|
|
def is_dataset_built_on_rank():
    """Return True only on tensor-parallel rank 0.

    ``train_valid_test_datasets_provider`` hands this predicate to
    ``BlendedMegatronDatasetBuilder``.
    """
    tp_rank = mpu.get_tensor_model_parallel_rank()
    return tp_rank == 0
|
|
|
|
|
|
def core_gpt_dataset_config_from_args(args):
    """Translate parsed command-line arguments into a ``GPTDatasetConfig``.

    Args:
        args: Parsed training arguments, as returned by ``get_args()``.

    Returns:
        GPTDatasetConfig: Config for the core GPT datasets, built with the global tokenizer.
    """
    active_tokenizer = get_tokenizer()

    config_kwargs = {
        "random_seed": args.seed,
        "sequence_length": args.seq_length,
        "blend": args.data_path,
        # Per-split sources, ordered train / validation / test.
        "blend_per_split": [args.train_data_path, args.valid_data_path, args.test_data_path],
        "split": args.split,
        "path_to_cache": args.data_cache_path,
        "mock": args.mock_data,
        "mmap_bin_files": args.mmap_bin_files,
        "tokenizer": active_tokenizer,
        "reset_position_ids": args.reset_position_ids,
        "reset_attention_mask": args.reset_attention_mask,
        "eod_mask_loss": args.eod_mask_loss,
        "create_attention_mask": args.create_attention_mask_in_dataloader,
    }
    return GPTDatasetConfig(**config_kwargs)
|
|
|
|
|
|
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train, validation and test datasets.

    Instruction datasets are built with ModelLink's ``build_instruction_dataset``.
    All other data is built with ``BlendedMegatronDatasetBuilder`` from a
    ``GPTDatasetConfig``. The dataset class is ``MockGPTDataset`` when the config
    requests mock data and ``GPTDataset`` otherwise.

    Args:
        train_val_test_num_samples: A list containing the number of samples for the
            train, validation and test splits, in that order.

    Returns:
        tuple: ``(train_ds, valid_ds, test_ds)``.
    """
    args = get_args()

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    if args.is_instruction_dataset:
        # The instruction path does not use the core dataset config. Building it (and
        # looking up the tokenizer for it) would be wasted work, so it is skipped here.
        train_ds, valid_ds, test_ds = build_instruction_dataset(
            data_prefix=args.data_path,
            splits_string=args.split,
            train_valid_test_num_samples=train_val_test_num_samples,
            seq_length=args.seq_length,
            seed=args.seed)
    else:
        config = core_gpt_dataset_config_from_args(args)
        dataset_type = MockGPTDataset if config.mock else GPTDataset
        train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
            dataset_type,
            train_val_test_num_samples,
            is_dataset_built_on_rank,
            config
        ).build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds
|
|
|
|
|
|
def main():
    """Launch ModelLink's ``pretrain`` with this file's GPT providers.

    Before training starts, ``train_valid_test_datasets_provider.is_distributed`` is
    set to True.
    """
    # Temporary for transition to core datasets
    dataset_provider = train_valid_test_datasets_provider
    dataset_provider.is_distributed = True

    pretrain(
        dataset_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
    )


if __name__ == "__main__":
    # Run training only when executed as a script, not on import.
    main()
|