Mirror of https://gitee.com/ascend/ModelLink.git (synced 2024-12-01 19:39:02 +08:00)
Commit 7a21f0bf58, parent be5d413ec6

.gitignore (vendored): 6 lines changed
@@ -137,3 +137,9 @@ dmypy.json

# Cython debug symbols
cython_debug/

# Vscode stuff
.vscode

# pycharm stuff
.idea
OWNERS: 7 lines removed
@@ -1,7 +0,0 @@
approvers:
- dilililiwhy
- wujianping996
reviewers:
- xqiangx1991
- sanjaychan
- fengliangjun66
@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
from functools import wraps
import logging

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
import deepspeed
try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    import deepspeed_npu
except Exception as e:
    logging.warning("Warning: You did not install torch_npu or deepspeed_npu")
from torch import distributed as dist
import deepspeed_npu
from deepspeed.accelerator import get_accelerator
from .package_info import (
    __description__,
@@ -41,7 +44,7 @@ from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .initialize import initialize_megatron
from .initialize import initialize_megatron, adaptor_deepspeed_initialize


def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
@@ -51,9 +54,10 @@ def print_rank_0(message):
    else:
        print(message, flush=True)


def is_last_rank():
    return torch.distributed.get_rank() == (
        torch.distributed.get_world_size() - 1)
    return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1)


def print_rank_last(message):
    """If distributed is initialized, print only on last rank."""
@@ -63,18 +67,25 @@ def print_rank_last(message):
    else:
        print(message, flush=True)


def is_aml():
    # Are we running inside an Azure Machine Learning (AML) environment?
    return 'AZUREML_EXPERIMENT_ID' in os.environ


def is_rank_0():
    """Check whether it is rank 0. For AML, check if it is rank 0 of a node"""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0 or (
            is_aml() and torch.distributed.get_rank() % get_accelerator().device_count() == 0
        ):
                is_aml() and torch.distributed.get_rank() % get_accelerator().device_count() == 0
        ):
            return True
        else:
            return False
    else:
        return True


# Fixes a bug where the DeepSpeed PipelineEngine is incompatible with sequence parallelism.
# When sequence parallelism is used, PipelineEngine cannot be used; otherwise, precision problems occur.
deepspeed.initialize = adaptor_deepspeed_initialize
@@ -15,11 +15,14 @@

"""ascendspeed arguments."""

import dataclasses
import argparse
import os

import torch
import torch.nn.functional as F
import deepspeed
from ascendspeed.core.transformer import TransformerConfig
from ascendspeed.enums import PositionEmbeddingType
from ascendspeed.data.parse import ParseDataPaths, ParseDataPathsPath

@@ -50,6 +53,8 @@ def parse_args(extra_args_provider=None, defaults={},
    parser = _add_activation_checkpoint_args(parser)
    parser = _add_distillation_args(parser)
    parser = _add_optimized_pipeline_args(parser)
    parser = _add_manual_layer_allocation(parser)
    parser = _add_lora_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
@@ -175,11 +180,21 @@ def parse_args(extra_args_provider=None, defaults={},
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have
    # local DDP and we should set the use-contiguous-buffers-in-ddp.
    # If we do accumulation and all-reduces in fp32, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        args.use_contiguous_buffers_in_ddp = True
        assert args.use_contiguous_buffers_in_local_ddp

    # If we use the distributed optimizer, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
    if args.use_distributed_optimizer:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
        args.use_contiguous_buffers_in_local_ddp = False

    if args.dataloader_type is None:
        args.dataloader_type = 'single'
@@ -221,6 +236,7 @@ def parse_args(extra_args_provider=None, defaults={},
        'can only specify one of lr-warmup-fraction ' \
        'and lr-warmup-samples'


    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
@@ -274,10 +290,25 @@ def parse_args(extra_args_provider=None, defaults={},
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'

    torch_major = int(torch.__version__.split('.')[0])
    torch_minor = int(torch.__version__.split('.')[1])
    # Persistent fused layer norm.
    if torch_major < 1 or (torch_major == 1 and torch_minor < 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print('Persistent fused layer norm kernel is supported from '
                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                  'Defaulting to no_persist_layer_norm=True')
    else:
        args.no_persist_layer_norm = False
    args.curriculum_learning_legacy = False
    args.compression_training = False

    args.apply_layernorm_1p = False
    args.overlap_p2p_comm = False
    args.swiglu = False
    args.fp8_e4m3 = False
    args.fp8_hybrid = False
    args.group_query_attention = False
    # AML
    if args.aml_data_download_path is not None:
        data_paths = []
@@ -285,9 +316,13 @@ def parse_args(extra_args_provider=None, defaults={},
            data_paths.append(f"{args.aml_data_download_path}/{path}")
        args.data_path = data_paths

    # manual layer distribution
    _get_manual_layer_allocation(args)

    _print_args(args)
    return args


def _print_args(args):
    """Print arguments."""
    if args.rank == 0:
@@ -307,6 +342,35 @@ def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


def core_transformer_config_from_args(args):
    # Translate args to core transformer configuration
    kw_args = {}
    for f in dataclasses.fields(TransformerConfig):
        if hasattr(args, f.name):
            kw_args[f.name] = getattr(args, f.name)
    kw_args['persist_layer_norm'] = not args.no_persist_layer_norm
    kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p
    kw_args['deallocate_pipeline_outputs'] = False
    kw_args['pipeline_dtype'] = args.params_dtype
    kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
    if args.swiglu:
        kw_args['activation_func'] = F.silu
        kw_args['gated_linear_unit'] = True
        kw_args['bias_gelu_fusion'] = False
    if args.init_method_xavier_uniform:
        kw_args['init_method'] = torch.nn.init.xavier_uniform_
        kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_
    kw_args['fp8'] = args.fp8_e4m3 or args.fp8_hybrid
    kw_args['fp8_e4m3'] = args.fp8_e4m3
    kw_args['fp8_margin'] = args.fp8_hybrid
    if args.group_query_attention:
        kw_args['num_query_groups'] = args.num_query_groups
    else:
        kw_args['num_query_groups'] = None

    return TransformerConfig(**kw_args)


def _add_network_size_args(parser):
    group = parser.add_argument_group(title='network size')

@@ -563,6 +627,8 @@ def _add_training_args(parser):
                       help='Use Tutel optimization for MoE')
    group.add_argument('--inference', action='store_true',
                       help='Very basic inference mode: not allocating optim/lr - requires ZERO_STAGE=0')
    group.add_argument('--use-fused-rotary-pos-emb', action='store_true',
                       help='use fused rotary pos emb')
    return parser


@@ -701,6 +767,9 @@ def _add_mixed_precision_args(parser):
def _add_distributed_args(parser):
    group = parser.add_argument_group(title='distributed')

    group.add_argument('--foldx-mode', default=None,
                       choices=['aiao', 'fifo'],
                       help='Choose fold-x pipeline parallelism.')
    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--enable-expert-tensor-parallelism', action='store_true',
@@ -711,6 +780,9 @@ def _add_distributed_args(parser):
                       help="use sequence parallelism")
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--pipeline-model-parallel-split-rank',
                       type=int, default=None,
                       help='Rank where encoder and decoder should be split.')
    group.add_argument('--moe-expert-parallel-size', type=int, default=1,
                       help='Degree of the MoE expert parallelism.')
    group.add_argument('--model-parallel-size', type=int, default=None,
@@ -725,9 +797,10 @@ def _add_distributed_args(parser):
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                            'to use.')
    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
                       help='If set, use contiguous buffer in DDP. Note that '
                            'this option only works with local DDP.')
    group.add_argument('--no-contiguous-buffers-in-local-ddp',
                       action='store_false', help='If set, dont use '
                       'contiguous buffer in local DDP.',
                       dest='use_contiguous_buffers_in_local_ddp')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Use scatter/gather to optimize communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
@@ -742,6 +815,10 @@ def _add_distributed_args(parser):
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU')
    group.add_argument('--triangle-attn', action='store_true',
                       help="use triangle attention instead of self attention")
    group.add_argument('--use-distributed-optimizer', action='store_true',
                       help='Use distributed optimizer.')
    return parser


@@ -768,6 +845,8 @@ def _add_data_args(parser):
                       '1) a single data path, 2) multiple datasets in the'
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    group.add_argument('--is-instruction-dataset', action='store_true',
                       help='use instruction dataset or not')
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
@@ -1016,6 +1095,10 @@ def _add_activation_checkpoint_args(parser):
                       help='does a synchronize at the beginning and end of each checkpointed layer.')
    group.add_argument('--profile-backward', action='store_true',
                       help='Enables backward pass profiling for checkpointed layers.')
    group.add_argument('--checkpoint_policy', type=str, default='full', choices=['full', 'block'],
                       help="activation checkpoint policy")
    group.add_argument('--checkpoint_block_layer', type=int, default=25,
                       help="activation checkpoint block layer number")
    return parser


@@ -1047,6 +1130,7 @@ def _add_distillation_args(parser):

    return parser


def _add_optimized_pipeline_args(parser):
    group = parser.add_argument_group(title='optimized_pipeline')

@@ -1060,3 +1144,43 @@ def _add_optimized_pipeline_args(parser):
                       '--manual-mbs example-config-2')

    return parser


def _add_manual_layer_allocation(parser):
    group = parser.add_argument_group(title='manual_layer_allocation')
    group.add_argument('--use-manual-layer-allocation', action='store_true',
                       help='Enable manually allocated layers for pipeline model parallel.')
    group.add_argument('--manual-layers', type=str, help='a list of number of layers, '
                       'separated by commas; e.g., 4,4,4,4')

    return parser


def _get_manual_layer_allocation(args=None):
    if args is not None and args.use_manual_layer_allocation:
        assert args.manual_layers is not None
        manual_layers = list(map(int, args.manual_layers.split(',')))
        assert len(manual_layers) == args.pipeline_model_parallel_size
        args.manual_layers = manual_layers


def _add_lora_args(parser):
    group = parser.add_argument_group(title='lora')

    group.add_argument('--lora-target-modules', nargs='+', type=str, default=[],
                       help='Lora target modules.')
    group.add_argument('--lora-load', type=str, default=None,
                       help='Directory containing a lora model checkpoint.')
    group.add_argument('--lora-r', type=int, default=16,
                       help='Lora r.')
    group.add_argument('--lora-alpha', type=int, default=32,
                       help='Lora alpha.')
    group.add_argument('--lora-modules-to-save', nargs='+', type=str, default=None,
                       help='Lora modules to save.')
    group.add_argument('--lora-register-forward-hook', nargs='+', type=str,
                       default=['word_embeddings', 'input_layernorm'],
                       help='Lora register forward hook.')
    group.add_argument('--lora-adapter-name', type=str, default='default',
                       help='Lora adapter name.')

    return parser
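
To illustrate the new LoRA flags, a minimal sketch of parsing them with the group defined above; the standalone parser and the sample command line are hypothetical, since in the real code path these flags are added inside parse_args():

# Minimal sketch, assuming _add_lora_args as defined above is importable/in scope.
import argparse

parser = argparse.ArgumentParser()
parser = _add_lora_args(parser)
args = parser.parse_args(
    '--lora-r 8 --lora-alpha 16 '
    '--lora-target-modules query_key_value dense '
    '--lora-load ./ckpt_lora'.split()
)
print(args.lora_r, args.lora_alpha)        # 8 16
print(args.lora_target_modules)            # ['query_key_value', 'dense']
print(args.lora_register_forward_hook)     # default: ['word_embeddings', 'input_layernorm']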
@@ -21,6 +21,8 @@ import sys
import numpy as np
from deepspeed.accelerator import get_accelerator
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from ascendspeed.enums import PositionEmbeddingType


@@ -31,6 +33,10 @@ from ascendspeed import (get_args,
                         update_num_microbatches,
                         utils)
from ascendspeed.core import parallel_state
from ascendspeed.model import DistributedDataParallel as LocalDDP
from ascendspeed.model.lora_utils import is_enable_lora, get_lora_state_dict, lora_custom_load_fn_for_deepspeed, \
    get_lora_model_classes, get_lora_state_dict_with_deepspeed, update_model_state_dict_with_megatron, \
    get_lora_load_fn_with_deepspeed, handle_lora_modules_to_save_key_with_megatron

_CHECKPOINT_VERSION = None

@@ -90,7 +96,7 @@ def ensure_directory_exists(filename):


def get_checkpoint_name(checkpoints_path, iteration,
                        release=False):
                        release=False, model_name='model_optim_rng.pt'):
    """A unified checkpoint name."""
    if release:
        directory = 'release'
@@ -101,12 +107,12 @@ def get_checkpoint_name(checkpoints_path, iteration,
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}'.format(
                                parallel_state.get_tensor_model_parallel_rank()),
                            'model_optim_rng.pt')
                            model_name)
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}'.format(
                            parallel_state.get_tensor_model_parallel_rank(),
                            parallel_state.get_pipeline_model_parallel_rank()),
                        'model_optim_rng.pt')
                        model_name)

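A short sketch of what the new model_name parameter changes. The directory and file names below are illustrative only, and the call assumes an already initialized model-parallel job whose tensor-parallel rank is 0; which of the two path layouts is used depends on a pipeline-parallel check that is elided in this hunk:

# Hypothetical sketch: the default call keeps the old file name, while
# model_name lets callers (e.g. LoRA saving) write a differently named file
# into the same per-rank directory.
default_path = get_checkpoint_name('/ckpts', iteration=1000, release=True)
# e.g. /ckpts/release/mp_rank_00/model_optim_rng.pt

lora_path = get_checkpoint_name('/ckpts', iteration=1000, release=True,
                                model_name='model_lora.pt')
# e.g. /ckpts/release/mp_rank_00/model_lora.pt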
def get_checkpoint_tracker_filename(checkpoints_path):
@@ -121,7 +127,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):

    # Only rank zero of the data parallel writes to the disk.
    if not args.deepspeed:
        model = utils.unwrap_model(model)
        unwrap_model_classes = (torchDDP, LocalDDP)
        if is_enable_lora():
            unwrap_model_classes += get_lora_model_classes()
        model = utils.unwrap_model(model, unwrap_model_classes)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))
@@ -138,12 +147,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):

    # DeepSpeed saves the model/optimizer/scheduler
    if not args.deepspeed:
        if len(model) == 1:
            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i in range(len(model)):
                parallel_state.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
        get_model_state_dict(model, state_dict)

        # Optimizer stuff.
        if not args.no_save_optim:
@@ -173,6 +177,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        if args.no_pipeline_parallel:
            original_state_dict = model[0].module.state_dict
            model[0].module.state_dict = model[0].module.state_dict_for_save_checkpoint
        if is_enable_lora():
            model[0].module.state_dict = get_lora_state_dict_with_deepspeed(model=model[0])

        # Saving is a collective communication
        checkpoint_name = get_checkpoint_name(args.save, iteration)
@@ -185,6 +191,25 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        if args.no_pipeline_parallel:
            model[0].module.state_dict = original_state_dict

    save_checkpoint_post_process(iteration)


def get_model_state_dict(model, state_dict):
    if len(model) == 1:
        state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        if is_enable_lora():
            state_dict['model'] = get_lora_state_dict(state_dict['model'])
    else:
        for i in range(len(model)):
            parallel_state.set_virtual_pipeline_model_parallel_rank(i)
            state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
            if is_enable_lora():
                state_dict['model%d' % i] = get_lora_state_dict(state_dict['model%d' % i])


def save_checkpoint_post_process(iteration):
    args = get_args()

    # Wait so everyone is done (necessary)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
@@ -202,6 +227,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    if torch.distributed.is_initialized():
        torch.distributed.barrier()


def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
@@ -271,6 +297,88 @@ def fix_query_key_value_ordering(model, checkpoint_version):
        print_rank_0(" successfully fixed query-key-values ordering for"
                     " checkpoint version {}".format(checkpoint_version))


def read_tracker(load_dir):
    args = get_args()
    iteration = 0
    release = False
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0('    will not load any checkpoints and will start from '
                     'random')
        return False, iteration, release

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()

    if not args.mos and not args.kd:
        assert iteration > 0 or release, 'error parsing metadata file {}'.format(
            tracker_filename)

    return True, iteration, release
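The tracker file holds either an integer iteration or the literal string 'release'. A minimal sketch of the resulting contract, with a hypothetical checkpoint directory; the call assumes the surrounding module (get_args, the tracker-filename helper) is available as above:

# Hypothetical sketch of read_tracker's return values:
#   tracker file content  ->  (found, iteration, release)
#   missing file              (False, 0, False)   # training starts from random init
#   "1000"                    (True, 1000, False)
#   "release"                 (True, 0, True)
#   anything else             prints an error and exits
found, iteration, release = read_tracker('/ckpts')
if found and not release:
    print(f'resuming from iteration {iteration}')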
def get_state_dict_and_release(load_dir, lora_load_dir=None):
    args = get_args()

    read_tracker_success, iteration, release = read_tracker(load_dir)
    if not read_tracker_success:
        raise ValueError(f"{load_dir} does not have a tracker.")
    if lora_load_dir:
        read_tracker_success, lora_iteration, lora_release = read_tracker(lora_load_dir)
        if not read_tracker_success:
            raise ValueError(f"{lora_load_dir} does not have a tracker.")

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
    print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
    model_checkpoint_name = None
    if lora_load_dir:  # When a LoRA directory is given, all other parameters are read from it; the load directory only provides the base model weights.
        model_checkpoint_name = checkpoint_name
        checkpoint_name = get_checkpoint_name(lora_load_dir, lora_iteration, lora_release)
        print_rank_0(
            f' loading lora checkpoint from {args.lora_load} at iteration {lora_iteration} release:{lora_release}')
        release = lora_release

    # Load the checkpoint.
    try:
        state_dict = load_state_dict_from_checkpoint_with_megatron(checkpoint_name,
                                                                   model_checkpoint_name=model_checkpoint_name)
    except ModuleNotFoundError:
        from ascendspeed.fp16_deprecated import loss_scaler
        # For backward compatibility.
        print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'ascendspeed.fp16_deprecated.loss_scaler']
        sys.modules['ascendspeed.fp16.loss_scaler'] = sys.modules[
            'ascendspeed.fp16_deprecated.loss_scaler']
        state_dict = load_state_dict_from_checkpoint_with_megatron(checkpoint_name,
                                                                   model_checkpoint_name=model_checkpoint_name)
        sys.modules.pop('fp16.loss_scaler', None)
        sys.modules.pop('ascendspeed.fp16.loss_scaler', None)
    except BaseException as e:
        print_rank_0('could not load the checkpoint')
        print_rank_0(e)
        sys.exit()

    return state_dict, release, checkpoint_name


def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True, load_only_weights=False):
    """Load a model checkpoint and return the iteration.
    strict (bool): whether to strictly enforce that the keys in
@@ -279,70 +387,41 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    """
    args = get_args()
    load_dir = getattr(args, load_arg)
    lora_load_dir = getattr(args, 'lora_load')

    if args.deepspeed:
        load_optimizer_states = False if args.no_load_optim else True
        loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states)
        if not os.path.exists(load_dir):
            print_rank_0(f"WARNING: could not find the metadata file {load_dir}")
            print_rank_0(f"    will not load any checkpoints and will start from random")
            return 0
        custom_load_fn, load_dir = get_custom_load_fn(model=model[0], load_dir=load_dir, lora_load_dir=lora_load_dir)
        load_zero_optim = sum(['zero' in file for file in os.listdir(load_dir)]) > 0
        release = not load_zero_optim
        loaded_dir, state_dict = model[0].load_checkpoint(
            load_dir,
            load_module_strict=strict,
            load_module_only=not load_zero_optim,
            load_optimizer_states=load_zero_optim,
            load_lr_scheduler_states=load_zero_optim,
            custom_load_fn=custom_load_fn
        )
        if loaded_dir is None:
            print_rank_0(f"WARNING: could not find the metadata file {load_dir}")
            print_rank_0(f"    will not load any checkpoints and will start from random")
            return 0
        release = False
        checkpoint_name = loaded_dir  # With LoRA enabled, the main parameters are read from lora_load, so checkpoint_name carries the path used in the final log message.
    else:
        model = utils.unwrap_model(model)
        unwrap_model_classes = (torchDDP, LocalDDP)
        if is_enable_lora():
            unwrap_model_classes += get_lora_model_classes()
        model = utils.unwrap_model(model, unwrap_model_classes)

        # Read the tracker file and set the iteration.
        tracker_filename = get_checkpoint_tracker_filename(load_dir)

        # If no tracker file, return iteration zero.
        if not os.path.isfile(tracker_filename):
            print_rank_0('WARNING: could not find the metadata file {} '.format(
                tracker_filename))
            print_rank_0('    will not load any checkpoints and will start from '
                         'random')
            return 0

        # Otherwise, read the tracker file and either set the iteration or
        # mark it as a release checkpoint.
        iteration = 0
        release = False
        with open(tracker_filename, 'r') as f:
            metastring = f.read().strip()
            try:
                iteration = int(metastring)
            except ValueError:
                release = metastring == 'release'
                if not release:
                    print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                        tracker_filename))
                    sys.exit()

        if not args.mos and not args.kd:
            assert iteration > 0 or release, 'error parsing metadata file {}'.format(
                tracker_filename)

        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
        print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')

        # Load the checkpoint.
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except ModuleNotFoundError:
            from ascendspeed.fp16_deprecated import loss_scaler
            # For backward compatibility.
            print_rank_0(' > deserializing using the old code structure ...')
            sys.modules['fp16.loss_scaler'] = sys.modules[
                'ascendspeed.fp16_deprecated.loss_scaler']
            sys.modules['ascendspeed.fp16.loss_scaler'] = sys.modules[
                'ascendspeed.fp16_deprecated.loss_scaler']
            state_dict = torch.load(checkpoint_name, map_location='cpu')
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('ascendspeed.fp16.loss_scaler', None)
        except BaseException as e:
            print_rank_0('could not load the checkpoint')
            print_rank_0(e)
            sys.exit()
            state_dict, release, checkpoint_name = get_state_dict_and_release(load_dir=load_dir,
                                                                              lora_load_dir=lora_load_dir)
        except ValueError as e:
            print_rank_0(f"{e}")
            return 0

    # set checkpoint version
    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
@@ -353,18 +432,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
        # Make DeepSpeed engine aware of this reset of iteration
        model[0].global_steps = 0
    else:
        try:
            iteration = state_dict['iteration']
            if 'tokens' in state_dict:
                args.consumed_train_tokens = state_dict['tokens']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name))
                sys.exit()
        iteration = load_iteration_from_state_dict(state_dict, checkpoint_name)

    # Check arguments.
    reset_train_valid_samples = args.reset_iteration
@@ -384,8 +452,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True

    # Model.
    if not args.deepspeed:
        if is_enable_lora() and iteration == 0:
            strict = False
        if len(model) == 1:
            model[0].load_state_dict(state_dict['model'], strict=strict)
            result = model[0].load_state_dict(state_dict['model'], strict=strict)
            if not strict and result:
                print_rank_0(f"load checkpoint result:{result}")
        else:
            for i in range(len(model)):
                parallel_state.set_virtual_pipeline_model_parallel_rank(i)
@@ -399,17 +471,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    # Optimizer.
    if not args.deepspeed:
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(state_dict['optimizer'])
                if lr_scheduler is not None and not args.no_load_lr_state:
                    lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
            except KeyError:
                print_rank_0('Unable to load optimizer from checkpoint {}. '
                             'Specify --no-load-optim or --finetune to prevent '
                             'attempting to load the optimizer state, '
                             'exiting ...'.format(checkpoint_name))
                sys.exit()
            load_optimizer_from_state_dict(optimizer, lr_scheduler, state_dict, checkpoint_name)

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
@@ -434,12 +496,66 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(f'  successfully loaded checkpoint from {args.load} '
                 f'at iteration {iteration}')
    print_rank_0(f'  successfully loaded checkpoint from {checkpoint_name} at iteration {iteration}')

    return iteration


def get_custom_load_fn(model, load_dir, lora_load_dir=None):
    custom_load_fn = None
    if is_enable_lora():
        if lora_load_dir:
            custom_load_fn = get_lora_load_fn_with_deepspeed(model=model, base_model_load_dir=load_dir)
            load_dir = lora_load_dir
        else:
            custom_load_fn = lora_custom_load_fn_for_deepspeed
    return custom_load_fn, load_dir


def load_optimizer_from_state_dict(optimizer, lr_scheduler, state_dict, checkpoint_name):
    args = get_args()

    try:
        if optimizer is not None:
            optimizer.load_state_dict(state_dict['optimizer'])
        if lr_scheduler is not None and not args.no_load_lr_state:
            lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
    except KeyError:
        print_rank_0('Unable to load optimizer from checkpoint {}. '
                     'Specify --no-load-optim or --finetune to prevent '
                     'attempting to load the optimizer state, '
                     'exiting ...'.format(checkpoint_name))
        sys.exit()


def load_iteration_from_state_dict(state_dict, checkpoint_name):
    args = get_args()

    try:
        iteration = state_dict['iteration']
        if 'tokens' in state_dict:
            args.consumed_train_tokens = state_dict['tokens']
    except KeyError:
        try:  # Backward compatible with older checkpoints
            iteration = state_dict['total_iters']
        except KeyError:
            print_rank_0('A metadata file exists but unable to load '
                         'iteration from checkpoint {}, exiting'.format(
                             checkpoint_name))
            sys.exit()
    return iteration


def load_state_dict_from_checkpoint_with_megatron(checkpoint_name, model_checkpoint_name=None):
    state_dict = torch.load(checkpoint_name, map_location='cpu')
    if model_checkpoint_name:
        model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
        state_dict = update_model_state_dict_with_megatron(model_state_dict=model_state_dict,
                                                           state_dict=state_dict)
        state_dict = handle_lora_modules_to_save_key_with_megatron(state_dict)
    return state_dict


def load_biencoder_checkpoint(model, only_query_model=False,
                              only_context_model=False, custom_load_path=None):
    """
@@ -2,11 +2,10 @@ import ascendspeed.core.parallel_state
import ascendspeed.core.tensor_parallel
import ascendspeed.core.utils

from .inference_params import InferenceParams
from .model_parallel_config import ModelParallelConfig

# Alias parallel_state as mpu, its legacy name
mpu = parallel_state

__all__ = [
    "parallel_state",
    "tensor_parallel",
    "utils",
]
__all__ = ["parallel_state", "tensor_parallel", "utils", "InferenceParams", "ModelParallelConfig"]
ascendspeed/core/enums.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.


import enum


class ModelType(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2
    retro_encoder = 3
    retro_decoder = 4


class LayerType(enum.Enum):
    encoder = 1
    decoder = 2


class AttnType(enum.Enum):
    self_attn = 1
    cross_attn = 2


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2  # Overrides `attention_mask` to be a lower triangular matrix
    prefix = 3
    # Forces one to pass an `attention_mask` that's 1 if we need to mask.
    # Tensor that can be broadcast to [micro_batch_size, n_head, seq_length, seq_length]
    custom = 4
ascendspeed/core/inference_params.py (new file, 27 lines)
@@ -0,0 +1,27 @@
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_idx):
        """swap between batches"""
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("Should not swap when dict is empty.")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
            assert (
                len(batch_idx) == inference_key_memory.shape[1]
            )  # make sure batch size is the same
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )
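
A minimal usage sketch of the class above; the shapes and the fake layer entry are illustrative, and the import relies on the re-export added to core/__init__.py in this commit:

import torch
from ascendspeed.core import InferenceParams

params = InferenceParams(max_batch_size=4, max_sequence_length=2048)

# A model layer would normally populate the per-layer KV cache; here we fake
# layer 1 with [seq, batch, head_dim]-shaped buffers so that dim 1 is the batch dim.
params.key_value_memory_dict[1] = (torch.zeros(2048, 4, 8), torch.zeros(2048, 4, 8))

# Reorder the cached keys/values along the batch dimension, e.g. after the
# inference engine reshuffles which request occupies which batch slot.
params.swap_key_value_dict(batch_idx=torch.tensor([2, 0, 1, 3]))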
ascendspeed/core/model_parallel_config.py (new file, 167 lines)
@@ -0,0 +1,167 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Callable

import torch


@dataclass
class ModelParallelConfig:
    """Base configuration for Megatron Core

    Model Parallelism
    -----------------

    tensor_model_parallel_size (int): Intra-layer model parallelism. Splits tensors across GPU ranks. Defaults to 1.

    pipeline_model_parallel_size (int): Inter-layer model parallelism. Splits transformer layers across GPU
        ranks. Defaults to 1.

    virtual_pipeline_model_parallel_size (int): Interleaved pipeline parallelism is used to improve performance by
        reducing the pipeline bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
        The number of virtual blocks per pipeline model parallel rank is the virtual model parallel size. See Efficient
        Large-Scale Language Model Training on GPU Clusters Using Megatron-LM: https://arxiv.org/pdf/2104.04473.pdf for
        more details. Defaults to None.

    sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by
        parallelizing layer norms and dropout sequentially. See Reducing Activation Recomputation in Large Transformer
        Models: https://arxiv.org/abs/2205.05198 for more details. Defaults to False.

    Initialization
    --------------

    perform_initialization (bool, default=True): If true, weights are initialized. This option can be useful when you
        know you are going to load values from a checkpoint.

    use_cpu_initialization: (bool, default=False): When set to False, we initialize the weights directly on the GPU.
        Transferring weights from CPU to GPU can take a significant amount of time for large models. Defaults to False.

    Training
    --------

    fp16 (bool): If true, train with fp16 mixed precision training. Defaults to False.

    bf16 (bool): If true, train with bf16 mixed precision training. Defaults to False.

    params_dtype (torch.dtype): dtype used when initializing the weights. Defaults to torch.float32

    timers (optional, default=None): TODO

    Optimizations
    -------------

    gradient_accumulation_fusion (bool): If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA
        extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with
        --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\"
        ". Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.
        Defaults to False.

    async_tensor_model_parallel_allreduce (bool, default=True): If true, enables asynchronous execution of
        tensor-model-parallel all-reduce with weight gradient computation of a column-linear layer. Defaults to False.

    Pipeline Parallelism
    --------------------

    pipeline_dtype (required): dtype used in p2p communication, usually params_dtype

    grad_scale_func (optional, default=None): If using loss scaling, this function should take the loss and return the
        scaled loss. If None, no function is called on the loss.

    enable_autocast (bool): If true runs the forward step function inside torch.autocast context. Default is False.

    autocast_dtype (torch.dtype): dtype to pass to torch.amp.autocast when enabled. Default is pipeline_dtype.

    variable_seq_lengths (bool, default=False): Support for variable sequence lengths across microbatches. Setting this
        communicates the size of tensors during pipeline parallelism communication, because of this extra overhead it
        should only be set if the sequence length varies by microbatch within a global batch.

    num_microbatches_with_partial_activation_checkpoints (int, default=None): If int, set the number of microbatches
        where not all of the layers will be checkpointed and recomputed. The rest of the microbatches within the window
        of maximum outstanding microbatches will recompute all layers (either full recompute or selective recompute). If
        None, the checkpoint and recompute will be left up to the forward_step function.

    overlap_p2p_comm (bool, optional, default=False): When True some of the peer to peer communication for pipeline
        parallelism will overlap with computation. Must be False if batch_p2p_comm is true.

    batch_p2p_comm (bool, default=True): Use batch_isend_irecv instead of individual isend/irecv calls. Must be False
        if overlap_p2p_comm is True.

    batch_p2p_sync (bool, default=True): When using batch_isend_irecv, do a cuda.device.synchronize afterward to work
        around a bug in older versions of PyTorch.

    use_ring_exchange_p2p (bool, default = False): Use custom ring_exchange kernel instead of
        torch.distributed.batch_isend_irecv(). Requires custom built torch with torch.distributed.ring_exchange.

    deallocate_pipeline_outputs (optional, default=False): If True, output data is deallocated after the tensor is sent
        to the next pipeline stage. Helps with saving memory, does nothing when pipeline parallel is not used.

    no_sync_func (optional): Function that creates a context that suppresses asynchronous data-parallel
        communication. If the model is an instance of torch.nn.DistributedDataParallel, the default is to use
        torch.nn.DistributedDataParallel.no_sync.

    grad_sync_func (optional): Function that launches asynchronous gradient reductions (e.g. distributed optimizer
        gradient reduce-scatters). The function should take one argument: an iterable of parameters whose gradients are
        to be synchronized.

    param_sync_func (optional): Function that launches asynchronous parameter synchronizations (e.g. distributed
        optimizer parameter all-gathers). The function should take one argument: an iterable of parameters to be
        synchronized.

    """

    # Model parallelism
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1
    virtual_pipeline_model_parallel_size: int = None
    sequence_parallel: bool = False

    # Initialization
    perform_initialization: bool = True
    use_cpu_initialization: bool = False

    # Training
    fp16: bool = False
    bf16: bool = False
    params_dtype: torch.dtype = torch.float32
    timers: Callable = None

    # Optimizations
    gradient_accumulation_fusion: bool = False
    async_tensor_model_parallel_allreduce: bool = False

    # Pipeline Parallel
    pipeline_dtype: torch.dtype = None
    grad_scale_func: Callable = None
    enable_autocast: bool = False
    autocast_dtype: torch.dtype = None
    variable_seq_lengths: bool = False
    num_microbatches_with_partial_activation_checkpoints: int = None
    overlap_p2p_comm: bool = False
    batch_p2p_comm: bool = True
    batch_p2p_sync: bool = True
    use_ring_exchange_p2p: bool = False
    deallocate_pipeline_outputs: bool = False
    no_sync_func: Callable = None
    grad_sync_func: Callable = None
    param_sync_func: Callable = None

    def __post_init__(self):
        """ Python dataclass method that is used to modify attributes after initialization.
        See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
        """
        if self.sequence_parallel:
            if self.tensor_model_parallel_size <= 1:
                raise ValueError("Can not use sequence parallelism without tensor parallelism")
            if self.async_tensor_model_parallel_allreduce:
                # sequence_parallelism already does this async
                self.async_tensor_model_parallel_allreduce = False

        if self.pipeline_model_parallel_size > 1:
            if self.pipeline_dtype is None:
                raise ValueError(
                    "When using pipeline parallelism, pipeline_dtype must be specified"
                )

        if self.autocast_dtype is None:
            self.autocast_dtype = self.params_dtype
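
A minimal construction sketch for the dataclass above; the sizes and dtypes are illustrative, and the import relies on the re-export added to core/__init__.py in this commit:

import torch
from ascendspeed.core import ModelParallelConfig

config = ModelParallelConfig(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=4,
    sequence_parallel=True,          # requires tensor_model_parallel_size > 1
    pipeline_dtype=torch.float16,    # required whenever pipeline_model_parallel_size > 1
    params_dtype=torch.float16,
)
# Because sequence_parallel is set, __post_init__ keeps
# async_tensor_model_parallel_allreduce disabled and, since autocast_dtype was
# not given, defaults it to params_dtype.
assert config.autocast_dtype == torch.float16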
@@ -1,8 +1,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# pylint: disable=global-statement

"""Model and data parallel groups."""

from typing import Optional
import torch

@@ -21,11 +20,14 @@ _POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_GLOO = None
# FP8 amax reduction group.
_AMAX_REDUCTION_GROUP = None

_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

_PIPELINE_PREV_GROUP = None
_PIPELINE_NEXT_GROUP = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
@@ -57,7 +59,8 @@ def initialize_model_parallel(
    pipeline_model_parallel_split_rank: Optional[int] = None,
    use_fp8: bool = False,
) -> None:
    """Initialize model data parallel groups.
    """
    Initialize model data parallel groups.

    Arguments:
        tensor_model_parallel_size (int, default = 1):
@@ -115,13 +118,13 @@ def initialize_model_parallel(
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    """
    assert not use_fp8, "FP8 not supported by AscendSpeed"

    if torch.distributed.get_rank() == 0:
        print(f'> initializing tensor model parallel with size {tensor_model_parallel_size}')
        print(f'> initializing pipeline model parallel with size {pipeline_model_parallel_size}')
        print('> initializing tensor model parallel with size {}'.format(
            tensor_model_parallel_size))
        print('> initializing pipeline model parallel with size {}'.format(
            pipeline_model_parallel_size))
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
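A hypothetical sketch of how this entry point is typically invoked for the 16-GPU example in the docstring above. It assumes torch.distributed.init_process_group() has already run on every rank, and use_fp8 is left at its default because of the assert above:

from ascendspeed.core import parallel_state

# 2 nodes x 8 GPUs = 16 ranks; remaining parallelism is data parallelism:
# 16 / (tensor 2 * pipeline 4) = 2 data-parallel replicas.
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=4,
    # use_fp8 stays False: "FP8 not supported by AscendSpeed"
)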
@@ -202,6 +205,8 @@ def initialize_model_parallel(
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    global _PIPELINE_PREV_GROUP
    global _PIPELINE_NEXT_GROUP
    assert (
        _PIPELINE_MODEL_PARALLEL_GROUP is None
    ), 'pipeline model parallel group is already initialized'
@@ -217,6 +222,13 @@ def initialize_model_parallel(
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        for j in iter(range(len(ranks))):
            ranks_ = [ranks[j], ranks[(j + 1) % len(ranks)]] if world_size != 1 else [ranks[j]]
            group = torch.distributed.new_group(ranks_)
            if rank == ranks[j]:
                _PIPELINE_NEXT_GROUP = group
            if rank == ranks[(j + 1) % len(ranks)]:
                _PIPELINE_PREV_GROUP = group
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
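To make the new per-pair groups concrete, a plain-Python illustration (no torch.distributed) of the ring that the loop above builds for one hypothetical pipeline-parallel group of ranks [0, 2, 4, 6]:

ranks = [0, 2, 4, 6]
pairs = [[ranks[j], ranks[(j + 1) % len(ranks)]] for j in range(len(ranks))]
print(pairs)  # [[0, 2], [2, 4], [4, 6], [6, 0]]
# Rank 0 keeps the [0, 2] group as its _PIPELINE_NEXT_GROUP (it sends forward
# to rank 2) and the [6, 0] group as its _PIPELINE_PREV_GROUP (it receives
# from rank 6).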
@@ -276,12 +288,10 @@ def get_model_parallel_group():
    return _MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_group(check_initialized=True):
def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    if check_initialized:
        assert (
            _TENSOR_MODEL_PARALLEL_GROUP is not None
        ), 'tensor model parallel group is not initialized'
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
        'intra_layer_model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP


@@ -337,6 +347,12 @@ def get_tensor_model_parallel_world_size():
    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())


def get_model_parallel_world_size():
    assert get_pipeline_model_parallel_world_size() == 1, \
        "legacy get_model_parallel_world_size is only supported if PP is disabled"
    return get_tensor_model_parallel_world_size()


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
@@ -371,6 +387,12 @@ def get_tensor_model_parallel_rank():
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())


def get_model_parallel_rank():
    assert get_pipeline_model_parallel_world_size() == 1, \
        "legacy get_model_parallel_rank is only supported if PP is disabled"
    return get_tensor_model_parallel_rank()


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
@@ -404,10 +426,11 @@ def is_pipeline_last_stage(ignore_virtual=False):
    )
    if virtual_pipeline_model_parallel_world_size is not None \
            and get_virtual_pipeline_model_parallel_rank() != (
            virtual_pipeline_model_parallel_world_size - 1
    ):
            virtual_pipeline_model_parallel_world_size - 1):
        return False
    return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1)


def is_rank_in_embedding_group(ignore_virtual=False):
@@ -471,6 +494,7 @@ def is_pipeline_stage_at_split():
    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)


def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
@@ -541,6 +565,16 @@ def get_pipeline_model_parallel_prev_rank():
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


def get_pipeline_model_parallel_prev_rank_group():
    assert _PIPELINE_PREV_GROUP is not None
    return _PIPELINE_PREV_GROUP


def get_pipeline_model_parallel_next_rank_group():
    assert _PIPELINE_NEXT_GROUP is not None
    return _PIPELINE_NEXT_GROUP


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return torch.distributed.get_world_size(group=get_data_parallel_group())
@@ -580,6 +614,10 @@ def destroy_model_parallel():
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _PIPELINE_NEXT_GROUP
    _PIPELINE_NEXT_GROUP = None
    global _PIPELINE_PREV_GROUP
    _PIPELINE_PREV_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
@@ -598,3 +636,5 @@ def destroy_model_parallel():
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None

ascendspeed/core/pipeline_parallel/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# -*- coding: UTF-8 -*-

ascendspeed/core/pipeline_parallel/p2p_communication.py (new file, 685 lines)
@@ -0,0 +1,685 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import operator
from functools import reduce
from typing import List, Optional, Tuple, Union

import torch
from deepspeed.accelerator import get_accelerator
import torch.distributed as dist
from ascendspeed import core, mpu
from ascendspeed import get_args
from ascendspeed.core import ModelParallelConfig
from ascendspeed.core.parallel_state import (
    get_pipeline_model_parallel_group,
    get_pipeline_model_parallel_next_rank,
    get_pipeline_model_parallel_prev_rank,
    get_pipeline_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_pipeline_model_parallel_prev_rank_group,
    get_pipeline_model_parallel_next_rank_group
)

# Types
Shape = Union[List[int], torch.Size]


def _communicate(
    *,
    tensor_send_next: Optional[torch.Tensor],
    tensor_send_prev: Optional[torch.Tensor],
    recv_prev: bool,
    recv_next: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    wait_on_reqs: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Arguments:
        tensor_send_next (torch.Tensor, optional):
            Tensor to send to next rank (no tensor sent if None)

        tensor_send_prev (torch.Tensor, optional):
            Tensor to send to prev rank (no tensor sent if None)

        recv_prev (boolean, required):
            whether tensor should be received from previous rank.

        recv_next (boolean, required):
            whether tensor should be received from next rank.

        tensor_shape (List[int] or torch.Size, required):
            shape of tensor to receive (this method assumes that all
            tensors sent and received in a single function call are
            the same shape).

        wait_on_reqs (boolean, optional, default=False):
            For non-batched p2p communication, wait on each request
            before returning.

    Returns:
        tuple containing

        - tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
        - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.

    """

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    args = get_args()

    tensor_shape = tensor_shape if args.optimized_pipeline and (recv_prev or recv_next) \
        else (args.seq_length, args.micro_batch_size, args.hidden_size)

    if args.sequence_parallel:
        seq_length = args.seq_length // get_tensor_model_parallel_world_size()
        tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)

    if not config.variable_seq_lengths:
        recv_prev_shape = tensor_shape
        recv_next_shape = tensor_shape
    else:
        recv_prev_shape, recv_next_shape = _communicate_shapes(
            tensor_send_next, tensor_send_prev, recv_prev, recv_next, config
        )

    if args.scatter_gather_tensors_in_pipeline and not config.sequence_parallel:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            get_tensor_model_parallel_world_size()
        recv_prev_shape = tensor_chunk_shape
        recv_next_shape = tensor_chunk_shape

    if recv_prev:
        if config.pipeline_dtype is None:
            raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
        if tensor_shape is None:
            raise RuntimeError(
                "tensor_shape must be specified if recv_prev is True. "
                "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
            )
        tensor_recv_prev = torch.empty(
            recv_prev_shape,
            requires_grad=True,
            device=get_accelerator().current_device(),
            dtype=config.pipeline_dtype,
        )
    if recv_next:
        if config.pipeline_dtype is None:
            raise RuntimeError("dtype must be provided if recv_next is True")
        if tensor_shape is None:
            raise RuntimeError(
                "tensor_shape must be specified if recv_next is True. "
                "Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
            )
        tensor_recv_next = torch.empty(
            recv_next_shape,
            requires_grad=True,
            device=get_accelerator().current_device(),
            dtype=config.pipeline_dtype,
        )

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline and not config.sequence_parallel:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if config.use_ring_exchange_p2p:
        def _ring_exchange_wrapper(**kwargs):
            torch.distributed.ring_exchange(**kwargs)
            return []

        p2p_func = _ring_exchange_wrapper
    elif config.batch_p2p_comm:
        if not wait_on_reqs:
            raise Exception("Wait_on_reqs should be true")
        p2p_func = _batched_p2p_ops
    else:
        p2p_func = _p2p_ops

    reqs = p2p_func(
        tensor_send_prev=tensor_send_prev,
        tensor_recv_prev=tensor_recv_prev,
        tensor_send_next=tensor_send_next,
        tensor_recv_next=tensor_recv_next,
        group=get_pipeline_model_parallel_group(),
    )

    if wait_on_reqs and len(reqs) > 0:
        for req in reqs:
            req.wait()
        reqs = None

    # To protect against race condition when using batch_isend_irecv().
    # User should assert that we have a modern enough PyTorch to not need this
    get_accelerator().synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline and not config.sequence_parallel:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()

    return tensor_recv_prev, tensor_recv_next, reqs

def async_communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
|
||||
args = get_args()
|
||||
|
||||
# Create placeholder tensors for receive in forward and backward directions
|
||||
# if needed.
|
||||
tensor_recv_prev = None
|
||||
tensor_recv_next = None
|
||||
|
||||
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
|
||||
|
||||
if args.sequence_parallel:
|
||||
seq_length = args.seq_length // get_tensor_model_parallel_world_size()
|
||||
tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)
|
||||
|
||||
if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
|
||||
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
|
||||
get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
tensor_chunk_shape = tensor_shape
|
||||
dtype = args.params_dtype
|
||||
if args.fp32_residual_connection:
|
||||
dtype = torch.float
|
||||
if recv_prev:
|
||||
tensor_recv_prev = torch.empty(tensor_chunk_shape,
|
||||
requires_grad=True,
|
||||
device=get_accelerator().current_device_name(),
|
||||
dtype=dtype)
|
||||
if recv_next:
|
||||
tensor_recv_next = torch.empty(tensor_chunk_shape,
|
||||
requires_grad=True,
|
||||
device=get_accelerator().current_device_name(),
|
||||
dtype=dtype)
|
||||
|
||||
# Split tensor into smaller chunks if using scatter-gather optimization.
|
||||
if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
|
||||
if tensor_send_next is not None:
|
||||
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
|
||||
|
||||
if tensor_send_prev is not None:
|
||||
tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
|
||||
|
||||
ops = []
|
||||
if tensor_send_prev is not None:
|
||||
torch.distributed.isend(tensor_send_prev,
|
||||
get_pipeline_model_parallel_prev_rank(),
|
||||
group=get_pipeline_model_parallel_prev_rank_group())
|
||||
if tensor_recv_prev is not None:
|
||||
ops.append(torch.distributed.irecv(tensor_recv_prev,
|
||||
get_pipeline_model_parallel_prev_rank(),
|
||||
group=get_pipeline_model_parallel_prev_rank_group()))
|
||||
if tensor_send_next is not None:
|
||||
torch.distributed.isend(tensor_send_next,
|
||||
get_pipeline_model_parallel_next_rank(),
|
||||
group=get_pipeline_model_parallel_next_rank_group())
|
||||
if tensor_recv_next is not None:
|
||||
ops.append(torch.distributed.irecv(tensor_recv_next,
|
||||
get_pipeline_model_parallel_next_rank(),
|
||||
group=get_pipeline_model_parallel_next_rank_group()))
|
||||
return tensor_recv_prev, tensor_recv_next, ops
|
||||
|
||||
|
||||
def recv_gather(tensor_recv):
|
||||
args = get_args()
|
||||
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
|
||||
|
||||
if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
|
||||
tensor_recv = mpu.gather_split_1d_tensor(
|
||||
tensor_recv).view(tensor_shape).requires_grad_()
|
||||
|
||||
return tensor_recv
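# Illustrative sketch only (not part of this diff): how async_communicate and
# recv_gather above are meant to be paired. `ops` holds the irecv request
# handles, so the caller waits on them before re-gathering the scattered chunks.
#
#     tensor_recv_prev, _, ops = async_communicate(
#         tensor_send_next=None, tensor_send_prev=None,
#         recv_prev=True, recv_next=False)
#     for op in ops:
#         op.wait()
#     input_tensor = recv_gather(tensor_recv_prev)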
|
||||
|
||||
|
||||
def recv_forward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
|
||||
""" Receive tensor from previous rank in pipeline (forward receive).
|
||||
|
||||
|
||||
See _communicate for argument details.
|
||||
"""
|
||||
|
||||
if core.parallel_state.is_pipeline_first_stage():
|
||||
input_tensor = None
|
||||
else:
|
||||
if config.timers is not None:
|
||||
config.timers('forward-recv', log_level=2).start()
|
||||
input_tensor, _, _ = _communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=True,
|
||||
recv_next=False,
|
||||
tensor_shape=tensor_shape,
|
||||
config=config,
|
||||
)
|
||||
if config.timers is not None:
|
||||
config.timers('forward-recv').stop()
|
||||
return input_tensor
|
||||
|
||||
|
||||
def recv_backward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
|
||||
"""Receive tensor from next rank in pipeline (backward receive).
|
||||
|
||||
See _communicate for argument details.
|
||||
"""
|
||||
if core.parallel_state.is_pipeline_last_stage():
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
if config.timers is not None:
|
||||
config.timers('backward-recv', log_level=2).start()
|
||||
_, output_tensor_grad, _ = _communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=False,
|
||||
recv_next=True,
|
||||
tensor_shape=tensor_shape,
|
||||
config=config,
|
||||
)
|
||||
if config.timers is not None:
|
||||
config.timers('backward-recv').stop()
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_forward(output_tensor: torch.Tensor, config: ModelParallelConfig) -> None:
|
||||
"""Send tensor to next rank in pipeline (forward send).
|
||||
|
||||
See _communicate for argument details.
|
||||
"""
|
||||
|
||||
if not core.parallel_state.is_pipeline_last_stage():
|
||||
if config.timers is not None:
|
||||
config.timers('forward-send', log_level=2).start()
|
||||
_communicate(
|
||||
tensor_send_next=output_tensor,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=False,
|
||||
recv_next=False,
|
||||
tensor_shape=None,
|
||||
config=config,
|
||||
)
|
||||
if config.timers is not None:
|
||||
config.timers('forward-send').stop()
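# Illustrative sketch only (not part of this diff): the minimal warm-up flow a
# pipeline schedule builds from recv_forward/send_forward above. `model_step`,
# `tensor_shape` and `config` are hypothetical placeholders.
#
#     input_tensor = recv_forward(tensor_shape, config)
#     output_tensor = model_step(input_tensor)
#     send_forward(output_tensor, config)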
|
||||
|
||||
|
||||
def send_backward(input_tensor_grad: torch.Tensor, config: ModelParallelConfig) -> None:
|
||||
"""Send tensor to previous rank in pipeline (backward send).
|
||||
|
||||
See _communicate for argument details.
|
||||
"""
|
||||
if not core.parallel_state.is_pipeline_first_stage():
|
||||
if config.timers is not None:
|
||||
config.timers('backward-send', log_level=2).start()
|
||||
_communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=False,
|
||||
recv_next=False,
|
||||
tensor_shape=None,
|
||||
config=config,
|
||||
)
|
||||
if config.timers is not None:
|
||||
config.timers('backward-send').stop()
|
||||
|
||||
|
||||
def send_forward_recv_backward(
|
||||
output_tensor: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
|
||||
) -> torch.Tensor:
|
||||
"""Batched send and recv with next rank in pipeline.
|
||||
|
||||
See _communicate for argument details.
|
||||
"""
|
||||
if core.parallel_state.is_pipeline_last_stage():
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
if config.timers is not None:
|
||||
config.timers('forward-send-backward-recv', log_level=2).start()
|
||||
_, output_tensor_grad, _ = _communicate(
|
||||
tensor_send_next=output_tensor,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=False,
|
||||
recv_next=True,
|
||||
tensor_shape=tensor_shape,
|
||||
config=config,
|
||||
)
|
||||
if config.timers is not None:
|
||||
config.timers('forward-send-backward-recv').stop()
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_backward_recv_forward(
|
||||
input_tensor_grad: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
|
||||
) -> torch.Tensor:
|
||||
"""Batched send and recv with previous rank in pipeline.
|
||||
|
||||
See _communicate for argument details.
|
||||
"""
|
||||
if core.parallel_state.is_pipeline_first_stage():
|
||||
input_tensor = None
|
||||
else:
|
||||
if config.timers is not None:
|
||||
config.timers('backward-send-forward-recv', log_level=2).start()
|
||||
input_tensor, _, _ = _communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=True,
|
||||
recv_next=False,
|
||||
tensor_shape=tensor_shape,
|
||||
config=config,
|
||||
)
|
||||
if config.timers is not None:
|
||||
config.timers('backward-send-forward-recv').stop()
|
||||
return input_tensor
|
||||
|
||||
|
||||
def send_forward_recv_forward(
|
||||
output_tensor: torch.Tensor,
|
||||
recv_prev: bool,
|
||||
tensor_shape: Shape,
|
||||
config: ModelParallelConfig,
|
||||
overlap_p2p_comm: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Batched recv from previous rank and send to next rank in pipeline.
|
||||
|
||||
See _communicate for argument details.
|
||||
"""
|
||||
if config.timers is not None:
|
||||
config.timers('forward-send-forward-recv', log_level=2).start()
|
||||
input_tensor, _, wait_handles = _communicate(
|
||||
tensor_send_next=output_tensor,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=False,
|
||||
tensor_shape=tensor_shape,
|
||||
wait_on_reqs=(not overlap_p2p_comm),
|
||||
config=config,
|
||||
)
|
||||
if config.timers is not None:
|
||||
config.timers('forward-send-forward-recv').stop()
|
||||
if overlap_p2p_comm:
|
||||
return input_tensor, wait_handles
|
||||
return input_tensor
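# Illustrative sketch only (not part of this diff): with overlap_p2p_comm=True
# the request handles are returned instead of being waited on inside
# _communicate, so the caller can overlap independent compute and wait later.
#
#     input_tensor, wait_handles = send_forward_recv_forward(
#         output_tensor, recv_prev=True, tensor_shape=tensor_shape,
#         config=config, overlap_p2p_comm=True)
#     # ... independent compute here ...
#     for req in wait_handles:
#         req.wait()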
|
||||
|
||||
|
||||
def send_backward_recv_backward(
|
||||
input_tensor_grad: torch.Tensor,
|
||||
recv_next: bool,
|
||||
tensor_shape: Shape,
|
||||
config: ModelParallelConfig,
|
||||
overlap_p2p_comm: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Batched recv from next rank and send to previous rank in pipeline.
|
||||
|
||||
See _communicate for argument details.
|
||||
"""
|
||||
if config.timers is not None:
|
||||
config.timers('backward-send-backward-recv', log_level=2).start()
|
||||
_, output_tensor_grad, wait_handles = _communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=False,
|
||||
recv_next=recv_next,
|
||||
tensor_shape=tensor_shape,
|
||||
wait_on_reqs=(not overlap_p2p_comm),
|
||||
config=config,
|
||||
)
|
||||
if config.timers is not None:
|
||||
config.timers('backward-send-backward-recv').stop()
|
||||
if overlap_p2p_comm:
|
||||
return output_tensor_grad, wait_handles
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_forward_backward_recv_forward_backward(
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor_grad: torch.Tensor,
|
||||
recv_prev: bool,
|
||||
recv_next: bool,
|
||||
tensor_shape: Shape,
|
||||
config: ModelParallelConfig,
|
||||
) -> torch.Tensor:
|
||||
"""Batched send and recv with previous and next ranks in pipeline.
|
||||
|
||||
See _communicate for argument details.
|
||||
"""
|
||||
if config.timers is not None:
|
||||
config.timers('forward-backward-send-forward-backward-recv', log_level=2).start()
|
||||
input_tensor, output_tensor_grad, _ = _communicate(
|
||||
tensor_send_next=output_tensor,
|
||||
tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next,
|
||||
tensor_shape=tensor_shape,
|
||||
config=config,
|
||||
)
|
||||
if config.timers is not None:
|
||||
config.timers('forward-backward-send-forward-backward-recv').stop()
|
||||
return input_tensor, output_tensor_grad
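# Illustrative sketch only (not part of this diff): the 1F1B steady state that
# schedules typically express with the fused wrappers above. `forward_step` and
# `backward_step` are hypothetical placeholders for the schedule's own helpers.
#
#     output_tensor = forward_step(input_tensor)
#     output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shape, config)
#     input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad)
#     input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shape, config)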
|
||||
|
||||
|
||||
def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config):
|
||||
"""Communicate tensor shapes between stages. Used to communicate
|
||||
tensor shapes before the actual tensor communication happens.
|
||||
This is required when the sequence lengths across micro batches
|
||||
are not uniform.
|
||||
|
||||
Takes the following arguments:
|
||||
tensor_send_next: tensor to send to next rank (no tensor sent if
|
||||
set to None).
|
||||
tensor_send_prev: tensor to send to prev rank (no tensor sent if
|
||||
set to None).
|
||||
recv_prev: boolean for whether tensor should be received from
|
||||
previous rank.
|
||||
recv_next: boolean for whether tensor should be received from
|
||||
next rank.
|
||||
Returns:
|
||||
(recv_prev_shape, recv_next_shape)
|
||||
"""
|
||||
|
||||
recv_prev_shape_tensor = None
|
||||
recv_next_shape_tensor = None
|
||||
send_prev_shape_tensor = None
|
||||
send_next_shape_tensor = None
|
||||
if recv_prev:
|
||||
recv_prev_shape_tensor = torch.empty((3),
|
||||
device=get_accelerator().current_device(),
|
||||
dtype=torch.int64)
|
||||
if recv_next:
|
||||
recv_next_shape_tensor = torch.empty((3),
|
||||
device=get_accelerator().current_device(),
|
||||
dtype=torch.int64)
|
||||
if tensor_send_prev is not None:
|
||||
send_prev_shape_tensor = torch.tensor(tensor_send_prev.size(),
|
||||
device=get_accelerator().current_device(),
|
||||
dtype=torch.int64)
|
||||
if tensor_send_next is not None:
|
||||
send_next_shape_tensor = torch.tensor(tensor_send_next.size(),
|
||||
device=get_accelerator().current_device(),
|
||||
dtype=torch.int64)
|
||||
|
||||
if config.use_ring_exchange_p2p:
|
||||
torch.distributed.ring_exchange(
|
||||
tensor_send_prev=send_prev_shape_tensor,
|
||||
tensor_recv_prev=recv_prev_shape_tensor,
|
||||
tensor_send_next=send_next_shape_tensor,
|
||||
tensor_recv_next=recv_next_shape_tensor,
|
||||
group=get_pipeline_model_parallel_group(),
|
||||
)
|
||||
else:
|
||||
ops = []
|
||||
if send_prev_shape_tensor is not None:
|
||||
send_prev_op = torch.distributed.P2POp(
|
||||
torch.distributed.isend,
|
||||
send_prev_shape_tensor,
|
||||
get_pipeline_model_parallel_prev_rank(),
|
||||
)
|
||||
ops.append(send_prev_op)
|
||||
if recv_prev_shape_tensor is not None:
|
||||
recv_prev_op = torch.distributed.P2POp(
|
||||
torch.distributed.irecv,
|
||||
recv_prev_shape_tensor,
|
||||
get_pipeline_model_parallel_prev_rank(),
|
||||
)
|
||||
ops.append(recv_prev_op)
|
||||
if recv_next_shape_tensor is not None:
|
||||
recv_next_op = torch.distributed.P2POp(
|
||||
torch.distributed.irecv,
|
||||
recv_next_shape_tensor,
|
||||
get_pipeline_model_parallel_next_rank(),
|
||||
)
|
||||
ops.append(recv_next_op)
|
||||
if send_next_shape_tensor is not None:
|
||||
send_next_op = torch.distributed.P2POp(
|
||||
torch.distributed.isend,
|
||||
send_next_shape_tensor,
|
||||
get_pipeline_model_parallel_next_rank(),
|
||||
)
|
||||
ops.append(send_next_op)
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = torch.distributed.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
# To protect against race condition when using batch_isend_irecv().
|
||||
# should take this out once the bug with batch_isend_irecv is resolved.
|
||||
get_accelerator().synchronize()
|
||||
|
||||
recv_prev_shape = [0, 0, 0]
|
||||
if recv_prev_shape_tensor is not None:
|
||||
recv_prev_shape = recv_prev_shape_tensor.tolist()
|
||||
|
||||
recv_next_shape = [0, 0, 0]
|
||||
if recv_next_shape_tensor is not None:
|
||||
recv_next_shape = recv_next_shape_tensor.tolist()
|
||||
|
||||
return recv_prev_shape, recv_next_shape
|
||||
|
||||
|
||||
def _batched_p2p_ops(
|
||||
*,
|
||||
tensor_send_prev: Optional[torch.Tensor],
|
||||
tensor_recv_prev: Optional[torch.Tensor],
|
||||
tensor_send_next: Optional[torch.Tensor],
|
||||
tensor_recv_next: Optional[torch.Tensor],
|
||||
group: torch.distributed.ProcessGroup
|
||||
):
|
||||
ops = []
|
||||
if tensor_send_prev is not None:
|
||||
send_prev_op = torch.distributed.P2POp(
|
||||
torch.distributed.isend,
|
||||
tensor_send_prev,
|
||||
get_pipeline_model_parallel_prev_rank(),
|
||||
group)
|
||||
ops.append(send_prev_op)
|
||||
if tensor_recv_prev is not None:
|
||||
recv_prev_op = torch.distributed.P2POp(
|
||||
torch.distributed.irecv,
|
||||
tensor_recv_prev,
|
||||
get_pipeline_model_parallel_prev_rank(),
|
||||
group,
|
||||
)
|
||||
ops.append(recv_prev_op)
|
||||
if tensor_send_next is not None:
|
||||
send_next_op = torch.distributed.P2POp(
|
||||
torch.distributed.isend,
|
||||
tensor_send_next,
|
||||
get_pipeline_model_parallel_next_rank(),
|
||||
group,
|
||||
)
|
||||
ops.append(send_next_op)
|
||||
if tensor_recv_next is not None:
|
||||
recv_next_op = torch.distributed.P2POp(
|
||||
torch.distributed.irecv,
|
||||
tensor_recv_next,
|
||||
get_pipeline_model_parallel_next_rank(),
|
||||
group,
|
||||
)
|
||||
ops.append(recv_next_op)
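# Note (added commentary, not in the original diff): reversing the op order on
# odd pipeline ranks is presumably meant to make the send/recv ordering
# complementary between adjacent ranks, so that the batched isend/irecv pairs
# line up and cannot hang if the backend serializes them.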
|
||||
|
||||
if get_pipeline_model_parallel_rank() % 2 == 1:
|
||||
ops.reverse()
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = torch.distributed.batch_isend_irecv(ops)
|
||||
else:
|
||||
reqs = []
|
||||
return reqs
|
||||
|
||||
|
||||
def _p2p_ops(
|
||||
*,
|
||||
tensor_send_prev: Optional[torch.Tensor],
|
||||
tensor_recv_prev: Optional[torch.Tensor],
|
||||
tensor_send_next: Optional[torch.Tensor],
|
||||
tensor_recv_next: Optional[torch.Tensor],
|
||||
group: torch.distributed.ProcessGroup
|
||||
):
|
||||
reqs = []
|
||||
rank = get_pipeline_model_parallel_rank()
|
||||
if get_pipeline_model_parallel_rank() % 2 == 0:
|
||||
if tensor_send_prev is not None:
|
||||
send_prev_req = torch.distributed.isend(
|
||||
tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
|
||||
)
|
||||
reqs.append(send_prev_req)
|
||||
|
||||
if tensor_recv_prev is not None:
|
||||
recv_prev_req = torch.distributed.irecv(
|
||||
tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
|
||||
)
|
||||
reqs.append(recv_prev_req)
|
||||
|
||||
if tensor_recv_next is not None:
|
||||
recv_next_req = torch.distributed.irecv(
|
||||
tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
|
||||
)
|
||||
reqs.append(recv_next_req)
|
||||
|
||||
if tensor_send_next is not None:
|
||||
send_next_req = torch.distributed.isend(
|
||||
tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
|
||||
)
|
||||
reqs.append(send_next_req)
|
||||
|
||||
|
||||
else:
|
||||
if tensor_send_prev is not None:
|
||||
send_prev_req = torch.distributed.isend(
|
||||
tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
|
||||
)
|
||||
reqs.append(send_prev_req)
|
||||
|
||||
if tensor_recv_prev is not None:
|
||||
recv_prev_req = torch.distributed.irecv(
|
||||
tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
|
||||
)
|
||||
reqs.append(recv_prev_req)
|
||||
|
||||
if tensor_recv_next is not None:
|
||||
recv_next_req = torch.distributed.irecv(
|
||||
tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
|
||||
)
|
||||
reqs.append(recv_next_req)
|
||||
|
||||
if tensor_send_next is not None:
|
||||
send_next_req = torch.distributed.isend(
|
||||
tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
|
||||
)
|
||||
reqs.append(send_next_req)
|
||||
|
||||
return reqs
|
1814
ascendspeed/core/pipeline_parallel/schedules.py
Normal file
1814
ascendspeed/core/pipeline_parallel/schedules.py
Normal file
File diff suppressed because it is too large
3
ascendspeed/core/transformer/__init__.py
Normal file
3
ascendspeed/core/transformer/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
from .transformer_config import TransformerConfig
|
242
ascendspeed/core/transformer/transformer_config.py
Normal file
242
ascendspeed/core/transformer/transformer_config.py
Normal file
@ -0,0 +1,242 @@
|
||||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ascendspeed.core import ModelParallelConfig
|
||||
from ascendspeed.core.utils import init_method_normal, scaled_init_method_normal
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerConfig(ModelParallelConfig):
|
||||
"""Configuration object for megatron-core transformers.
|
||||
|
||||
Attributes:
|
||||
|
||||
# model architecture
|
||||
num_layers (int): Number of transformer layers in a transformer block.
|
||||
hidden_size (int): Transformer hidden size.
|
||||
ffn_hidden_size (int): Transformer Feed-Forward Network hidden size.
|
||||
This is set to 4*hidden_size if not provided. Defaults to None.')
|
||||
num_attention_heads (int): Number of transformer attention heads.
|
||||
kv_channels (int): Projection weights dimension in multi-head attention.
|
||||
This is set to hidden_size // num_attention_heads if not provided.
|
||||
Defaults to None.
|
||||
num_query_groups (int): Number of query groups for group query attention. If None, normal attention is used.
|
||||
|
||||
hidden_dropout (float): Dropout probability for transformer hidden state. Defaults to 0.1.
|
||||
attention_dropout (float): Post attention dropout probability. Defaults to 0.1.
|
||||
fp32_residual_connection (bool): If true, move residual connections to fp32.
|
||||
apply_residual_connection_post_layernorm (bool): If true, uses the original BERT residule connection ordering.
|
||||
Defaults to False.
|
||||
layernorm_epsilon (float): Layernorm epsilon. Defaults to 1e-5.
|
||||
|
||||
layernorm_zero_centered_gamma (bool): if set to 'True', the LayerNorm is adjusted to center the gamma values
|
||||
around 0. This improves numerical stability. Defaults to False.
|
||||
|
||||
add_bias_linear (bool): Include a bias term in all linear layers (QKV projections, after core attention, and two
|
||||
in MLP layer). Default is True.
|
||||
|
||||
gated_linear_unit (bool): Use a gated linear unit for the first linear layer in the MLP. Defaults to False.
|
||||
|
||||
activation_func (Callable): Activation function to use for the non-linearity in the MLP. Defaults to F.gelu.
|
||||
|
||||
# initialization
|
||||
init_method (Callable): Method to initialize weights. Note that bias is always set to
|
||||
zero. Should be a function that takes a single Tensor and
|
||||
initializes it. Defaults to
|
||||
megatron.core.utils.init_method_normal(init_method_std) which is
|
||||
torch.nn.init.normal_ with mean=0.0 and std=init_method_Std.
|
||||
|
||||
output_layer_init_method (Callable): Method to initialize weights of the output layer of
|
||||
both attention and MLP blocks. Defaults to
|
||||
megatron.core.utils.scaled_init_method_normal(init_method_std)
|
||||
which is torch.nn.init.normal_ with mean=0.0 and
|
||||
std=init_method_std / math.sqrt(2.0 * num_layers).
|
||||
|
||||
init_method_std (float): Standard deviation of the zero mean normal for the default
|
||||
initialization method, not used if init_method and
|
||||
output_layer_init_method are provided. Defaults to 0.02.
|
||||
|
||||
# mixed-precision
|
||||
apply_query_key_layer_scaling (bool): If true, scale Q * K^T by 1 / layer-number. Defaults to True.
|
||||
attention_softmax_in_fp32 (bool): If true, run attention masking and softmax in fp32.
|
||||
This should be true if apply_query_key_layer_scaling is true.
|
||||
|
||||
# fusion
|
||||
bias_gelu_fustion (bool): If true, fuses bias and gelu. Defaults to False.
|
||||
masked_softmax_fusion (bool): If true, uses softmax fusion.
|
||||
persist_layer_norm (bool): If true, uses the persistent fused layer norm kernel.
|
||||
This kernel only supports a fixed set of hidden sizes.
|
||||
Defaults to False.
|
||||
bias_dropout_fusion (bool): If true, uses bias dropout fusion.
|
||||
|
||||
# activation recomputation
|
||||
|
||||
recompute_granularity (str): megatron-core supports 'selective' activation checkpointing where only the memory
|
||||
intensive part of attention is checkpointed. These memory intensive activations
|
||||
are also less compute intensive which makes activation checkpointing more efficient
|
||||
for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer
|
||||
Models: https://arxiv.org/abs/2205.05198 for more details. 'full' will checkpoint
|
||||
the entire transformer layer. Must be 'selective' or 'full'. Defaults to None.
|
||||
|
||||
recompute_method (str): uniform will uniformly divide the total number of transformer layers in a transformer
|
||||
block and recompute the input activation of each divided chunk at the specified
|
||||
granularity. block will recompute the input activations for only a set number of
|
||||
transformer layers per pipeline stage. The rest of the layers in the pipeline stage
|
||||
will not have any activations recomputed. Must be 'uniform' or 'block'. Defaults to
|
||||
None.
|
||||
|
||||
recompute_num_layers (int): When recompute_method is uniform, recompute_num_layers is the number of transformer
|
||||
layers in each uniformly divided recompute unit. When recompute_method is block,
|
||||
recompute_num_layers is the number of transformer layers to recompute within each
|
||||
pipeline stage. Defaults to None.
|
||||
|
||||
distribute_saved_activations (bool): If true, distribute recomputed activations across the model parallel
|
||||
group. Defaults to None.
|
||||
|
||||
# fp8 related (via Transformer Engine). For detailed info, refer the the Transformer Engine docs at
|
||||
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html
|
||||
|
||||
fp8 (bool): Enables the use of FP8 precision through Transformer Engine.
|
||||
|
||||
fp8_e4m3 (bool): Enables the use of FP8 tensors in e4m3 format for both forward and backward passes.
|
||||
|
||||
fp8_margin (int): Enables the use of FP8 tensors in e4m3 format in the forward pass and e5m2 format in the
|
||||
backward pass.
|
||||
|
||||
fp8_interval (int): Controls how often the scaling factor is recomputed.
|
||||
|
||||
fp8_amax_history_len (int): The length of the amax history window used for scaling factor computation.
|
||||
|
||||
fp8_amax_compute_algo (str): Algorithm used for choosing the `amax` value for the scaling factor computation.
|
||||
There are 2 predefined choices: `max` chooses the largest `amax` in the history
|
||||
window, while `most_recent` always chooses the most recently seen value.
|
||||
|
||||
"""
|
||||
|
||||
# model architecture
|
||||
num_layers: int = 0
|
||||
hidden_size: int = 0
|
||||
num_attention_heads: int = 0
|
||||
num_query_groups: int = None
|
||||
|
||||
ffn_hidden_size: int = None
|
||||
kv_channels: int = None
|
||||
hidden_dropout: float = 0.1
|
||||
attention_dropout: float = 0.1
|
||||
fp32_residual_connection: bool = False
|
||||
apply_residual_connection_post_layernorm: bool = False
|
||||
layernorm_epsilon: float = 1e-5
|
||||
layernorm_zero_centered_gamma: bool = False
|
||||
add_bias_linear: bool = True
|
||||
gated_linear_unit: bool = False
|
||||
activation_func: Callable = F.gelu
|
||||
|
||||
# initialization
|
||||
init_method: Callable = None
|
||||
output_layer_init_method: Callable = None
|
||||
init_method_std: float = 0.02
|
||||
|
||||
# mixed-precision
|
||||
apply_query_key_layer_scaling: bool = True
|
||||
attention_softmax_in_fp32: bool = True
|
||||
|
||||
# communication
|
||||
|
||||
# fusion
|
||||
bias_gelu_fusion: bool = False # this should be bias_activation_fusion ?
|
||||
masked_softmax_fusion: bool = False
|
||||
persist_layer_norm: bool = False
|
||||
bias_dropout_fusion: bool = False # this should be bias_dropout_add_fusion?
|
||||
|
||||
# activation recomputation
|
||||
recompute_granularity: str = None
|
||||
recompute_method: str = None
|
||||
recompute_num_layers: int = None
|
||||
distribute_saved_activations: bool = None
|
||||
|
||||
# fp8 related
|
||||
fp8: bool = False
|
||||
fp8_e4m3: bool = False
|
||||
fp8_margin: int = 0
|
||||
fp8_interval: int = 1
|
||||
fp8_amax_history_len: int = 1
|
||||
fp8_amax_compute_algo: str = "most_recent"
|
||||
|
||||
def __post_init__(self):
|
||||
""" Python dataclass method that is used to modify attributes after initialization.
|
||||
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
|
||||
"""
|
||||
super().__post_init__()
|
||||
if self.fp16 and self.bf16:
|
||||
raise ValueError(f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.')
|
||||
|
||||
if self.num_attention_heads % self.tensor_model_parallel_size != 0:
|
||||
raise ValueError(f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
|
||||
f"tensor_model_parallel_size ({self.tensor_model_parallel_size}).")
|
||||
|
||||
if self.ffn_hidden_size is None:
|
||||
self.ffn_hidden_size = 4 * self.hidden_size
|
||||
|
||||
if self.kv_channels is None:
|
||||
self.kv_channels = self.hidden_size // self.num_attention_heads
|
||||
|
||||
if self.num_query_groups is None:
|
||||
self.num_query_groups = self.num_attention_heads
|
||||
|
||||
if self.num_query_groups % self.tensor_model_parallel_size != 0:
|
||||
raise ValueError(f"num_query_groups ({self.num_query_groups}) must be a multiple of "
|
||||
f"tensor_model_parallel_size ({self.tensor_model_parallel_size}).")
|
||||
|
||||
if self.apply_query_key_layer_scaling:
|
||||
self.attention_softmax_in_fp32 = True
|
||||
|
||||
if self.recompute_granularity is not None:
|
||||
self.__recompute_granularity_init()
|
||||
|
||||
if self.apply_query_key_layer_scaling:
|
||||
self.attention_softmax_in_fp32 = True
|
||||
|
||||
if self.bias_gelu_fusion:
|
||||
if not self.add_bias_linear:
|
||||
raise ValueError("When bias_gelu_fusion is True, add_bias_linear must also be True.")
|
||||
|
||||
if self.activation_func != F.gelu:
|
||||
raise ValueError(f'When bias_gelu_fusion is True, activation_func must be F.gelu.')
|
||||
|
||||
if self.init_method is None:
|
||||
self.init_method = init_method_normal(self.init_method_std)
|
||||
|
||||
if self.output_layer_init_method is None:
|
||||
self.output_layer_init_method = scaled_init_method_normal(self.init_method_std, self.num_layers)
|
||||
|
||||
def __recompute_granularity_init(self):
|
||||
if self.recompute_granularity not in ['full', 'selective']:
|
||||
raise ValueError(
|
||||
f'recompute_granularity: {self.recompute_granularity} must be "full" or "selective".')
|
||||
|
||||
if self.recompute_method is not None:
|
||||
if self.recompute_method not in ['block', 'uniform']:
|
||||
raise ValueError(f'recompute_method: {self.recompute_method} must be "block" or "uniform".')
|
||||
elif self.recompute_granularity != 'selective':
|
||||
raise ValueError(f'recompute_granularity: {self.recompute_granularity} requires '
    f'recompute_method to be "block" or "uniform".')
|
||||
|
||||
if self.recompute_num_layers is None:
|
||||
raise ValueError(f'When using recompute_granularity: {self.recompute_granularity}, '
    f'recompute_num_layers must be between 1 and '
    f'num_layers_per_pipeline_rank: '
    f'{self.num_layers // self.pipeline_model_parallel_size}')
|
||||
|
||||
if self.distribute_saved_activations and self.sequence_parallel_enabled:
|
||||
raise ValueError(f'distribute_saved_activations: {self.distribute_saved_activations} '
|
||||
f'must be false when sequence parallel is enabled: {self.sequence_parallel_enabled}')
|
||||
|
||||
if self.virtual_pipeline_model_parallel_size is not None:
|
||||
if not self.num_layers % self.virtual_pipeline_model_parallel_size == 0:
|
||||
raise ValueError(f'num_layers: {self.num_layers} must be divisible by virtual_model_parallel_size '
|
||||
f'{self.virtual_pipeline_model_parallel_size}')
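# Illustrative sketch only (not part of this diff): constructing a small config.
# The values are made up; parallelism-related fields (e.g.
# tensor_model_parallel_size, fp16/bf16) come from the ModelParallelConfig base
# class imported above and keep their defaults here.
#
#     config = TransformerConfig(
#         num_layers=24,
#         hidden_size=1024,
#         num_attention_heads=16,
#     )
#     # __post_init__ then fills in ffn_hidden_size=4096, kv_channels=64,
#     # num_query_groups=16 and the default init methods.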
|
@ -1,10 +1,12 @@
|
||||
"""Utility functions used throughout AscendSpeed core"""
|
||||
|
||||
import math
|
||||
from functools import reduce
|
||||
import operator
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator."""
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(
|
||||
@ -19,6 +21,42 @@ def divide(numerator, denominator):
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
def get_attr_wrapped_model(model, attr, allow_none=True):
|
||||
"""Get an attribute from a wrapped model"""
|
||||
if isinstance(model, list):
|
||||
raise RuntimeError("_get_attr_wrapped_model given a list of models")
|
||||
|
||||
if allow_none:
|
||||
|
||||
def condition(model, attr):
|
||||
return not hasattr(model, attr)
|
||||
|
||||
else:
|
||||
|
||||
def condition(model, attr):
|
||||
return getattr(model, attr, None) is None
|
||||
|
||||
while condition(model, attr):
|
||||
if not hasattr(model, "module"):
|
||||
raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
|
||||
|
||||
model = model.module
|
||||
return getattr(model, attr)
|
||||
|
||||
|
||||
def get_model_type(model):
|
||||
return get_attr_wrapped_model(model, 'model_type')
|
||||
|
||||
|
||||
def get_model_config(model):
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed.arguments import core_transformer_config_from_args
|
||||
args = get_args()
|
||||
if args.deepspeed:
|
||||
return core_transformer_config_from_args(get_args())
|
||||
return get_attr_wrapped_model(model, 'config', allow_none=False)
|
||||
|
||||
|
||||
class GlobalMemoryBuffer:
|
||||
"""Global buffer to avoid dynamic memory allocations.
|
||||
Caller should ensure that buffers of the same name
|
||||
@ -51,9 +89,9 @@ def _kernel_make_viewless_tensor(inp, requires_grad):
|
||||
'''
|
||||
out = torch.empty(
|
||||
(1,),
|
||||
dtype = inp.dtype,
|
||||
device = inp.device,
|
||||
requires_grad = requires_grad,
|
||||
dtype=inp.dtype,
|
||||
device=inp.device,
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
with torch.no_grad():
|
||||
out.set_(inp.data)
|
||||
@ -116,4 +154,23 @@ def split_tensor_along_last_dim(tensor, num_partitions,
|
||||
if contiguous_split_chunks:
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
|
||||
return tensor_list
|
||||
|
||||
|
||||
def init_method_normal(sigma):
|
||||
"""Init method based on N(0, sigma)."""
|
||||
|
||||
def init_(tensor):
|
||||
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
|
||||
|
||||
return init_
|
||||
|
||||
|
||||
def scaled_init_method_normal(sigma, num_layers):
|
||||
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
|
||||
std = sigma / math.sqrt(2.0 * num_layers)
|
||||
|
||||
def init_(tensor):
|
||||
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
|
||||
|
||||
return init_
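# Illustrative sketch only (not part of this diff): the two init helpers above
# return closures that are applied directly to parameter tensors.
#
#     init_fn = scaled_init_method_normal(sigma=0.02, num_layers=24)
#     init_fn(torch.empty(1024, 1024))   # in-place N(0, 0.02 / sqrt(48)) init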
|
||||
|
@ -1,199 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""BERT Style dataset."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ascendspeed import (
|
||||
get_args,
|
||||
get_tokenizer,
|
||||
mpu,
|
||||
print_rank_0
|
||||
)
|
||||
from ascendspeed.data.dataset_utils import (
|
||||
get_samples_mapping,
|
||||
get_a_and_b_segments,
|
||||
truncate_segments,
|
||||
create_tokens_and_tokentypes,
|
||||
create_masked_lm_predictions
|
||||
)
|
||||
|
||||
|
||||
class BertDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, name, indexed_dataset, data_prefix,
|
||||
num_epochs, max_num_samples, masked_lm_prob,
|
||||
max_seq_length, short_seq_prob, seed, binary_head):
|
||||
|
||||
# Params to store.
|
||||
self.name = name
|
||||
self.seed = seed
|
||||
self.masked_lm_prob = masked_lm_prob
|
||||
self.max_seq_length = max_seq_length
|
||||
self.binary_head = binary_head
|
||||
|
||||
# Dataset.
|
||||
self.indexed_dataset = indexed_dataset
|
||||
|
||||
# Build the samples mapping.
|
||||
self.samples_mapping = get_samples_mapping(self.indexed_dataset,
|
||||
data_prefix,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
self.max_seq_length - 3, # account for added tokens
|
||||
short_seq_prob,
|
||||
self.seed,
|
||||
self.name,
|
||||
self.binary_head)
|
||||
|
||||
# Vocab stuff.
|
||||
tokenizer = get_tokenizer()
|
||||
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
|
||||
self.vocab_id_to_token_dict = tokenizer.inv_vocab
|
||||
self.cls_id = tokenizer.cls
|
||||
self.sep_id = tokenizer.sep
|
||||
self.mask_id = tokenizer.mask
|
||||
self.pad_id = tokenizer.pad
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_mapping.shape[0]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
args = get_args()
|
||||
start_idx, end_idx, seq_length = self.samples_mapping[idx]
|
||||
sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
|
||||
# Note that this rng state should be numpy and not python since
|
||||
# python randint is inclusive whereas the numpy one is exclusive.
|
||||
# We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
|
||||
np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
|
||||
train_sample = build_training_sample(sample, seq_length,
|
||||
self.max_seq_length, # needed for padding
|
||||
self.vocab_id_list,
|
||||
self.vocab_id_to_token_dict,
|
||||
self.cls_id, self.sep_id,
|
||||
self.mask_id, self.pad_id,
|
||||
self.masked_lm_prob, np_rng,
|
||||
self.binary_head)
|
||||
if args.return_data_index:
|
||||
train_sample['index'] = np.array([idx], dtype=np.int64)
|
||||
return train_sample
|
||||
|
||||
|
||||
|
||||
|
||||
def build_training_sample(sample,
|
||||
target_seq_length, max_seq_length,
|
||||
vocab_id_list, vocab_id_to_token_dict,
|
||||
cls_id, sep_id, mask_id, pad_id,
|
||||
masked_lm_prob, np_rng, binary_head):
|
||||
"""Biuld training sample.
|
||||
|
||||
Arguments:
|
||||
sample: A list of sentences in which each sentence is a list token ids.
|
||||
target_seq_length: Desired sequence length.
|
||||
max_seq_length: Maximum length of the sequence. All values are padded to
|
||||
this length.
|
||||
vocab_id_list: List of vocabulary ids. Used to pick a random id.
|
||||
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
|
||||
cls_id: Start of example id.
|
||||
sep_id: Separator id.
|
||||
mask_id: Mask token id.
|
||||
pad_id: Padding token id.
|
||||
masked_lm_prob: Probability to mask tokens.
|
||||
np_rng: Random number generator. Note that this rng state should be
    numpy and not python since python randint is inclusive for
    the upper bound whereas the numpy one is exclusive.
|
||||
"""
|
||||
|
||||
if binary_head:
|
||||
# We assume that we have at least two sentences in the sample
|
||||
assert len(sample) > 1
|
||||
assert target_seq_length <= max_seq_length
|
||||
|
||||
# Divide sample into two segments (A and B).
|
||||
if binary_head:
|
||||
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
|
||||
np_rng)
|
||||
else:
|
||||
tokens_a = []
|
||||
for j in range(len(sample)):
|
||||
tokens_a.extend(sample[j])
|
||||
tokens_b = []
|
||||
is_next_random = False
|
||||
|
||||
# Truncate to `target_sequence_length`.
|
||||
max_num_tokens = target_seq_length
|
||||
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
|
||||
len(tokens_b), max_num_tokens, np_rng)
|
||||
|
||||
# Build tokens and tokentypes.
|
||||
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
|
||||
cls_id, sep_id)
|
||||
|
||||
# Masking.
|
||||
max_predictions_per_seq = masked_lm_prob * max_num_tokens
|
||||
(tokens, masked_positions, masked_labels, _, _) = create_masked_lm_predictions(
|
||||
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
|
||||
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
|
||||
|
||||
# Padding.
|
||||
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
|
||||
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
|
||||
masked_labels, pad_id, max_seq_length)
|
||||
|
||||
train_sample = {
|
||||
'text': tokens_np,
|
||||
'types': tokentypes_np,
|
||||
'labels': labels_np,
|
||||
'is_random': int(is_next_random),
|
||||
'loss_mask': loss_mask_np,
|
||||
'padding_mask': padding_mask_np,
|
||||
'truncated': int(truncated)}
|
||||
return train_sample
|
||||
|
||||
|
||||
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
|
||||
masked_labels, pad_id, max_seq_length):
|
||||
"""Pad sequences and convert them to numpy."""
|
||||
|
||||
# Some checks.
|
||||
num_tokens = len(tokens)
|
||||
padding_length = max_seq_length - num_tokens
|
||||
assert padding_length >= 0
|
||||
assert len(tokentypes) == num_tokens
|
||||
assert len(masked_positions) == len(masked_labels)
|
||||
|
||||
# Tokens and token types.
|
||||
filler = [pad_id] * padding_length
|
||||
tokens_np = np.array(tokens + filler, dtype=np.int64)
|
||||
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
|
||||
|
||||
# Padding mask.
|
||||
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
|
||||
dtype=np.int64)
|
||||
|
||||
# Lables and loss mask.
|
||||
labels = [-1] * max_seq_length
|
||||
loss_mask = [0] * max_seq_length
|
||||
for i in range(len(masked_positions)):
|
||||
assert masked_positions[i] < num_tokens
|
||||
labels[masked_positions[i]] = masked_labels[i]
|
||||
loss_mask[masked_positions[i]] = 1
|
||||
labels_np = np.array(labels, dtype=np.int64)
|
||||
loss_mask_np = np.array(loss_mask, dtype=np.int64)
|
||||
|
||||
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
|
@ -1,209 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ascendspeed import get_args, get_tokenizer, mpu, print_rank_0
|
||||
from ascendspeed.data.dataset_utils import create_masked_lm_predictions, \
|
||||
pad_and_convert_to_numpy
|
||||
from ascendspeed.core import parallel_state
|
||||
from ascendspeed.data.data_samplers import MegatronPretrainingSampler
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
def make_attention_mask(source_block, target_block):
|
||||
"""
|
||||
Returns a 2-dimensional (2-D) attention mask
|
||||
:param source_block: 1-D array
|
||||
:param target_block: 1-D array
|
||||
"""
|
||||
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
|
||||
mask = mask.astype(np.int64)
|
||||
# (source_length, target_length)
|
||||
return mask
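# Illustrative sketch only (not part of this diff): ids >= 1 count as real
# tokens, so padding (0) rows and columns are masked out.
#
#     src = np.array([5, 9, 0])        # 2 real tokens + 1 pad
#     tgt = np.array([2, 3, 4, 0])     # 3 real tokens + 1 pad
#     make_attention_mask(src, tgt)    # -> (3, 4) int64 mask, last row/col zero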
|
||||
|
||||
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
|
||||
"""Specifically one epoch to be used in an indexing job."""
|
||||
args = get_args()
|
||||
|
||||
if micro_batch_size is None:
|
||||
micro_batch_size = args.micro_batch_size
|
||||
num_workers = args.num_workers
|
||||
|
||||
# Use ascendspeed's sampler with consumed_samples set to 0, as this is
# only for evaluation and we don't intend to resume halfway through it.
# Also, set drop_last to False, as we don't intend to drop the last batch.
|
||||
batch_sampler = MegatronPretrainingSampler(
|
||||
total_samples=len(dataset),
|
||||
consumed_samples=0,
|
||||
micro_batch_size=args.micro_batch_size,
|
||||
data_parallel_rank=parallel_state.get_data_parallel_rank(),
|
||||
data_parallel_size=parallel_state.get_data_parallel_world_size(),
|
||||
drop_last=False)
|
||||
|
||||
return torch.utils.data.DataLoader(dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True)
|
||||
|
||||
|
||||
def get_ict_batch(data_iterator):
|
||||
# Items and their type.
|
||||
keys = ['query_tokens', 'query_mask',
|
||||
'context_tokens', 'context_mask', 'block_data']
|
||||
datatype = torch.int64
|
||||
|
||||
# Broadcast data.
|
||||
if data_iterator is None:
|
||||
data = None
|
||||
else:
|
||||
data = next(data_iterator)
|
||||
data_b = mpu.broadcast_data(keys, data, datatype)
|
||||
|
||||
# Unpack.
|
||||
query_tokens = data_b['query_tokens'].long()
|
||||
query_mask = data_b['query_mask'] < 0.5
|
||||
context_tokens = data_b['context_tokens'].long()
|
||||
context_mask = data_b['context_mask'] < 0.5
|
||||
block_indices = data_b['block_data'].long()
|
||||
|
||||
return query_tokens, query_mask,\
|
||||
context_tokens, context_mask, block_indices
|
||||
|
||||
|
||||
def join_str_list(str_list):
|
||||
"""Join a list of strings, handling spaces appropriately"""
|
||||
result = ""
|
||||
for s in str_list:
|
||||
if s.startswith("##"):
|
||||
result += s[2:]
|
||||
else:
|
||||
result += " " + s
|
||||
return result
|
||||
|
||||
|
||||
class BlockSampleData(object):
|
||||
"""A struct for fully describing a fixed-size block of data as used in REALM
|
||||
|
||||
:param start_idx: for first sentence of the block
|
||||
:param end_idx: for last sentence of the block (may be partially truncated in sample construction)
|
||||
:param doc_idx: the index of the document from which the block comes in the original indexed dataset
|
||||
:param block_idx: a unique integer identifier given to every block.
|
||||
"""
|
||||
def __init__(self, start_idx, end_idx, doc_idx, block_idx):
|
||||
self.start_idx = start_idx
|
||||
self.end_idx = end_idx
|
||||
self.doc_idx = doc_idx
|
||||
self.block_idx = block_idx
|
||||
|
||||
def as_array(self):
|
||||
return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
|
||||
|
||||
def as_tuple(self):
|
||||
return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
|
||||
|
||||
|
||||
class BlockSamplesMapping(object):
|
||||
def __init__(self, mapping_array):
|
||||
# make sure that the array is compatible with BlockSampleData
|
||||
assert mapping_array.shape[1] == 4
|
||||
self.mapping_array = mapping_array
|
||||
|
||||
def __len__(self):
|
||||
return self.mapping_array.shape[0]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Get the data associated with an indexed sample."""
|
||||
sample_data = BlockSampleData(*self.mapping_array[idx])
|
||||
return sample_data
|
||||
|
||||
|
||||
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
|
||||
max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
|
||||
"""Get samples mapping for a dataset over fixed size blocks. This function also requires
|
||||
a dataset of the titles for the source documents since their lengths must be taken into account.
|
||||
|
||||
:return: samples_mapping (BlockSamplesMapping)
|
||||
"""
|
||||
|
||||
if not num_epochs:
|
||||
if not max_num_samples:
|
||||
raise ValueError("Need to specify either max_num_samples "
|
||||
"or num_epochs")
|
||||
num_epochs = np.iinfo(np.int32).max - 1
|
||||
if not max_num_samples:
|
||||
max_num_samples = np.iinfo(np.int64).max - 1
|
||||
|
||||
# Filename of the index mapping
|
||||
indexmap_filename = data_prefix
|
||||
indexmap_filename += '_{}_indexmap'.format(name)
|
||||
if num_epochs != (np.iinfo(np.int32).max - 1):
|
||||
indexmap_filename += '_{}ep'.format(num_epochs)
|
||||
if max_num_samples != (np.iinfo(np.int64).max - 1):
|
||||
indexmap_filename += '_{}mns'.format(max_num_samples)
|
||||
indexmap_filename += '_{}msl'.format(max_seq_length)
|
||||
indexmap_filename += '_{}s'.format(seed)
|
||||
if use_one_sent_docs:
|
||||
indexmap_filename += '_1sentok'
|
||||
indexmap_filename += '.npy'
|
||||
|
||||
# Build the indexed mapping if not exist.
|
||||
if parallel_state.get_data_parallel_rank() == 0 and \
|
||||
not os.path.isfile(indexmap_filename):
|
||||
print(' > WARNING: could not find index map file {}, building '
|
||||
'the indices on rank 0 ...'.format(indexmap_filename))
|
||||
|
||||
# Make sure the types match the helpers input types.
|
||||
assert block_dataset.doc_idx.dtype == np.int64
|
||||
assert block_dataset.sizes.dtype == np.int32
|
||||
|
||||
# Build samples mapping
|
||||
verbose = torch.distributed.get_rank() == 0
|
||||
start_time = time.time()
|
||||
print_rank_0(' > building samples index mapping for {} ...'.format(
|
||||
name))
|
||||
|
||||
from ascendspeed.data import helpers
|
||||
mapping_array = helpers.build_blocks_mapping(
|
||||
block_dataset.doc_idx,
|
||||
block_dataset.sizes,
|
||||
title_dataset.sizes,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
max_seq_length - 3, # account for added tokens
|
||||
seed,
|
||||
verbose,
|
||||
use_one_sent_docs)
|
||||
|
||||
|
||||
print_rank_0(' > done building samples index mapping')
|
||||
np.save(indexmap_filename, mapping_array, allow_pickle=True)
|
||||
print_rank_0(' > saved the index mapping in {}'.format(
|
||||
indexmap_filename))
|
||||
# Make sure all the ranks have built the mapping
|
||||
print_rank_0(' > elapsed time to build and save samples mapping '
|
||||
'(seconds): {:4f}'.format(
|
||||
time.time() - start_time))
|
||||
|
||||
# This should be a barrier but nccl barrier assumes
|
||||
# device_index=rank which is not the case for model
|
||||
# parallel case
|
||||
counts = get_accelerator().LongTensor([1])
|
||||
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
|
||||
assert counts[0].item() == torch.distributed.get_world_size(
|
||||
group=parallel_state.get_data_parallel_group())
|
||||
|
||||
# Load indexed dataset.
|
||||
print_rank_0(' > loading indexed mapping from {}'.format(
|
||||
indexmap_filename))
|
||||
start_time = time.time()
|
||||
|
||||
mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
|
||||
samples_mapping = BlockSamplesMapping(mapping_array)
|
||||
|
||||
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
|
||||
time.time() - start_time))
|
||||
print_rank_0(' total number of samples: {}'.format(
|
||||
mapping_array.shape[0]))
|
||||
|
||||
return samples_mapping
|
@ -21,7 +21,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ascendspeed import print_rank_0
|
||||
from ascendspeed import mpu
|
||||
|
||||
|
||||
class BlendableDataset(torch.utils.data.Dataset):
|
||||
|
438
ascendspeed/data/data_handler.py
Normal file
438
ascendspeed/data/data_handler.py
Normal file
@ -0,0 +1,438 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from ascendspeed.data import indexed_dataset
|
||||
from ascendspeed.data.prompter import Prompter, AlpacaTemplate
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["get_dataset_handler", "build_dataset"]
|
||||
|
||||
DEFAULT_CACHE_DIR = "~/tmp"
|
||||
|
||||
|
||||
class BaseDatasetHandler(object):
|
||||
"""
|
||||
a base handler to tokenize and/or prompt your own dataset
|
||||
"""
|
||||
|
||||
def __init__(self, args, raw_datasets, tokenizer, splitter):
|
||||
self.args = args
|
||||
self.tokenizer = tokenizer
|
||||
self.splitter = splitter
|
||||
self.raw_datasets = raw_datasets
|
||||
self.max_seq_len = args.seq_length
|
||||
self.tokenized_dataset = None
|
||||
|
||||
@property
|
||||
def _unwrapped_tokenizer(self):
|
||||
"""get huggingface tokenizer"""
|
||||
return self.tokenizer.tokenizer
|
||||
|
||||
def get_tokenized_data(self):
|
||||
"""get tokenized(and prompted) data"""
|
||||
columns = next(iter(self.raw_datasets)).keys()
|
||||
remove_columns = list(set(columns) - set(self.args.json_keys))
|
||||
proc_kwargs = {} if self.args.streaming else {"num_proc": self.args.workers}
|
||||
return self.raw_datasets.map(self._filter, remove_columns=remove_columns, **proc_kwargs)
|
||||
|
||||
def serialize_to_disk(self):
|
||||
"""save idx and bin to disk"""
|
||||
startup_start = time.time()
|
||||
if not self.tokenized_dataset:
|
||||
self.tokenized_dataset = self.get_tokenized_data()
|
||||
output_bin_files = {}
|
||||
output_idx_files = {}
|
||||
builders = {}
|
||||
level = "document"
|
||||
if self.args.split_sentences:
|
||||
level = "sentence"
|
||||
|
||||
logger.info("Vocab size: %s", self.tokenizer.vocab_size)
|
||||
logger.info("Output prefix: %s", self.args.output_prefix)
|
||||
for key in self.args.json_keys:
|
||||
output_bin_files[key] = f"{self.args.output_prefix}_{key}_{level}.bin"
|
||||
output_idx_files[key] = f"{self.args.output_prefix}_{key}_{level}.idx"
|
||||
# vocab_size=None: use an int32 dtype, since -100 will be used in the labels
|
||||
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
|
||||
impl=self.args.dataset_impl,
|
||||
vocab_size=None)
|
||||
|
||||
startup_end = time.time()
|
||||
proc_start = time.time()
|
||||
total_bytes_processed = 0
|
||||
logger.info("Time to startup:%s", startup_end - startup_start)
|
||||
|
||||
for i, doc in enumerate(iter(self.tokenized_dataset), start=1):
|
||||
for key in self.args.json_keys:
|
||||
sentences = doc[key]
|
||||
if len(sentences) == 0:
|
||||
continue
|
||||
for sentence in sentences:
|
||||
total_bytes_processed += len(sentence) * np.int32().itemsize
|
||||
builders[key].add_item(torch.IntTensor(sentence))
|
||||
builders[key].end_document()
|
||||
if i % self.args.log_interval == 0:
|
||||
current = time.time()
|
||||
elapsed = current - proc_start
|
||||
mbs = total_bytes_processed / elapsed / 1024 / 1024
|
||||
logger.info("Processed %s documents (%s docs/s, %s MB/s).", i, i / elapsed, mbs)
|
||||
|
||||
for key in self.args.json_keys:
|
||||
builders[key].finalize(output_idx_files[key])
|
||||
|
||||
def _tokenize(self, prompt):
|
||||
result = self._unwrapped_tokenizer(text=prompt)
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
|
||||
return result
|
||||
|
||||
def _filter(self, sample):
|
||||
"""prompt and tokenize"""
|
||||
return NotImplemented
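# Illustrative sketch only (not part of this diff): how a concrete handler is
# typically driven. The constructor arguments mirror __init__ above; `args`,
# `raw_datasets`, `tokenizer` and `splitter` are assumed to be prepared by the
# preprocessing entry point.
#
#     handler = GeneralPretrainHandler(args, raw_datasets, tokenizer, splitter)
#     handler.serialize_to_disk()   # tokenizes lazily, then writes .bin/.idx files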
|
||||
|
||||
|
||||
class GeneralPretrainHandler(BaseDatasetHandler):
|
||||
"""
|
||||
a general pretrain dataset handler
|
||||
"""
|
||||
def __init__(self, args, raw_datasets, tokenizer, splitter):
|
||||
super().__init__(args, raw_datasets, tokenizer, splitter)
|
||||
if self._text_keys:
|
||||
self.args.json_keys = self._text_keys
|
||||
|
||||
@property
|
||||
def _text_keys(self):
|
||||
return []
|
||||
|
||||
def _pre_process(self, sample):
|
||||
return sample
|
||||
|
||||
def _filter(self, sample):
|
||||
sample = self._pre_process(sample)
|
||||
for key in self.args.json_keys:
|
||||
text = sample[key]
|
||||
doc_ids = []
|
||||
for sentence in self.splitter.tokenize(text):
|
||||
if len(sentence) > 0:
|
||||
sentence_ids = self._tokenize(sentence)
|
||||
doc_ids.append(sentence_ids)
|
||||
if len(doc_ids) > 0 and self.args.append_eod:
|
||||
doc_ids[-1]['input_ids'].append(self.tokenizer.eod)
|
||||
doc_ids[-1]['attention_mask'].append(1)
|
||||
doc_ids[-1]['labels'].append(self.tokenizer.eod)
|
||||
sample[key] = doc_ids
|
||||
# for now, only input_ids are saved
|
||||
sample[key] = list(map(lambda x: x['input_ids'], sample[key]))
|
||||
return sample
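# Illustrative sketch only (not part of this diff): a minimal custom pretrain
# handler for a dataset whose text lives under a hypothetical "content" key.
# Overriding _text_keys makes __init__ above rewrite args.json_keys accordingly.
#
#     class MyCorpusPretrainHandler(GeneralPretrainHandler):
#         @property
#         def _text_keys(self):
#             return ["content"]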
|
||||
|
||||
|
||||
class GeneralInstructionHandler(BaseDatasetHandler):
|
||||
"""
|
||||
a general instruction dataset handler
|
||||
"""
|
||||
def __init__(self, args, raw_datasets, tokenizer, splitter):
|
||||
super().__init__(args, raw_datasets, tokenizer, splitter)
|
||||
self.prompter = Prompter(AlpacaTemplate())
|
||||
self.train_on_inputs = False
|
||||
self.args.json_keys = ["input_ids", "attention_mask", "labels"]
|
||||
# use 'packed' string to mark that this is a packed dataset
|
||||
self.args.output_prefix = self.args.output_prefix + "_packed"
|
||||
self.ignored_label = -100
|
||||
self.is_multi_turn = self._is_multi_turn()
|
||||
|
||||
@property
|
||||
def _instruction_key(self) -> str:
|
||||
return "instruction"
|
||||
|
||||
@property
|
||||
def _input_key(self) -> str:
|
||||
return "input"
|
||||
|
||||
@property
|
||||
def _output_key(self) -> str:
|
||||
return "output"
|
||||
|
||||
@property
|
||||
def _human_prefix(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _assistant_prefix(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _is_multi_turn(self) -> bool:
|
||||
try:
|
||||
is_multi_turn = isinstance(self._human_prefix, str)
|
||||
except NotImplementedError:
|
||||
is_multi_turn = False
|
||||
return is_multi_turn
|
||||
|
||||
def _format_msg(self, sample):
|
||||
"""format sample info"""
|
||||
if not self.is_multi_turn:
|
||||
messages = [
|
||||
dict(
|
||||
role=self.prompter.user_role,
|
||||
content=sample[self._instruction_key] + "\n" + sample[self._input_key]),
|
||||
dict(role=self.prompter.assistant_role, content=sample[self._output_key])
|
||||
]
|
||||
return messages
|
||||
|
||||
messages = []
|
||||
turns = sample[self._instruction_key].split(self._human_prefix)
|
||||
|
||||
for msg in turns:
|
||||
if not msg:
|
||||
continue
|
||||
tmp = msg.split(self._assistant_prefix)
|
||||
if len(tmp) > 1:
|
||||
messages.append(dict(role=self.prompter.user_role, content=tmp[0].strip()))
|
||||
messages.append(dict(role=self.prompter.assistant_role, content=tmp[1].strip()))
|
||||
else:
|
||||
messages.append(dict(role=self.prompter.assistant_role, content=tmp[0].strip()))
|
||||
messages.pop()
|
||||
messages.append(dict(role=self.prompter.assistant_role, content=sample[self._output_key].strip()))
|
||||
return messages
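A standalone sketch of the splitting rule above, assuming "Human:"/"Assistant:" prefixes (as in BelleMultiTurnInstructionHandler) and an instruction text that ends with a trailing assistant prefix; the sample text and values below are made up for illustration:

# The raw multi-turn text is first cut on the human prefix, then each chunk
# is cut once on the assistant prefix; the trailing empty assistant turn is
# popped and replaced by the real answer from sample["output"].
raw = "Human: What is 1+1? Assistant: 2. Human: And 2+2? Assistant:"
human_prefix, assistant_prefix = "Human:", "Assistant:"

messages = []
for msg in raw.split(human_prefix):
    if not msg:
        continue
    tmp = msg.split(assistant_prefix)
    if len(tmp) > 1:
        messages.append(dict(role="user", content=tmp[0].strip()))
        messages.append(dict(role="assistant", content=tmp[1].strip()))
    else:
        messages.append(dict(role="assistant", content=tmp[0].strip()))
messages.pop()                                   # drop the empty assistant turn
messages.append(dict(role="assistant", content="4."))
# messages: user "What is 1+1?", assistant "2.", user "And 2+2?", assistant "4."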
|
||||
|
||||
def _filter(self, sample):
|
||||
messages = self._format_msg(sample)
|
||||
full_prompt = self.prompter.generate_training_prompt(messages)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
|
||||
if self.args.append_eod:
|
||||
tokenized_full_prompt["input_ids"].append(self.tokenizer.eod)
|
||||
tokenized_full_prompt["attention_mask"].append(1)
|
||||
tokenized_full_prompt["labels"].append(self.tokenizer.eod)
|
||||
|
||||
if not self.train_on_inputs:
|
||||
user_prompt = full_prompt.rsplit(self.prompter.template.assistant_token, maxsplit=1)[0] + \
|
||||
self.prompter.template.assistant_token + "\n"
|
||||
tokenized_user_prompt = self._tokenize(user_prompt)
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
tokenized_full_prompt["labels"][:user_prompt_len] = [self.ignored_label] * user_prompt_len
|
||||
|
||||
for key in self.args.json_keys:
|
||||
tokenized_full_prompt[key] = [tokenized_full_prompt[key]]
|
||||
|
||||
return tokenized_full_prompt
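A minimal illustration, with hypothetical token ids, of the masking above: every label position up to the end of the user prompt becomes -100, so the loss is computed only on the assistant response.

IGNORE_INDEX = -100                      # same value as self.ignored_label
input_ids = [1, 15, 27, 42, 9, 88, 3]    # 4 prompt tokens + 3 response tokens
user_prompt_len = 4

labels = input_ids.copy()
labels[:user_prompt_len] = [IGNORE_INDEX] * user_prompt_len
print(labels)                            # [-100, -100, -100, -100, 9, 88, 3]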
|
||||
|
||||
|
||||
class BelleMultiTurnInstructionHandler(GeneralInstructionHandler):
|
||||
"""
|
||||
BelleMultiTurn dataset handler
|
||||
"""
|
||||
@property
|
||||
def _human_prefix(self) -> str:
|
||||
return "Human:"
|
||||
|
||||
@property
|
||||
def _assistant_prefix(self) -> str:
|
||||
return "Assistant:"
|
||||
|
||||
|
||||
class MOSSInstructionHandler(GeneralInstructionHandler):
|
||||
def _filter(self, sample):
|
||||
messages = []
|
||||
tokenized_chats = []
|
||||
|
||||
for turn in sample["chat"].values():
|
||||
if not turn:
|
||||
continue
|
||||
|
||||
user = turn["Human"].replace("<eoh>", "").replace("<|Human|>: ", "").strip()
|
||||
assistant = turn["MOSS"].replace("<|MOSS|>:", "").replace("<eom>", "").strip()
|
||||
|
||||
messages.append(dict(role=self.prompter.user_role, content=user))
|
||||
messages.append(dict(role=self.prompter.assistant_role, content=assistant))
|
||||
|
||||
full_prompt = self.prompter.generate_training_prompt(messages)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
|
||||
if not self.train_on_inputs:
|
||||
user_prompt = full_prompt.rsplit(self.prompter.template.assistant_token, maxsplit=1)[0] + \
|
||||
self.prompter.template.assistant_token + "\n"
|
||||
tokenized_user_prompt = self._tokenize(user_prompt)
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
tokenized_full_prompt["labels"] = [-100] * user_prompt_len + tokenized_full_prompt["labels"][
|
||||
user_prompt_len:]
|
||||
|
||||
tokenized_chats.append(tokenized_full_prompt)
|
||||
|
||||
for key in self.args.json_keys:
|
||||
sample[key] = [chat[key] for chat in tokenized_chats]
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class LeetcodePythonInstructionHandler(GeneralInstructionHandler):
|
||||
@property
|
||||
def _instruction_key(self) -> str:
|
||||
return "code_with_problem"
|
||||
|
||||
@property
|
||||
def _input_key(self) -> str:
|
||||
return "code_only"
|
||||
|
||||
@property
|
||||
def _output_key(self) -> str:
|
||||
return "explanation_only"
|
||||
|
||||
def _format_msg(self, sample):
|
||||
"""format sample info"""
|
||||
messages = [
|
||||
dict(
|
||||
role=self.prompter.user_role,
|
||||
content=sample[self._instruction_key].split("```", maxsplit=1)[0].strip()),
|
||||
dict(
|
||||
role=self.prompter.assistant_role,
|
||||
content=sample[self._input_key] + "\n" + sample[self._output_key])
|
||||
]
|
||||
return messages
|
||||
|
||||
|
||||
class StackOverflowPythonPretrainHandler(GeneralPretrainHandler):
|
||||
@property
|
||||
def _text_keys(self):
|
||||
return ['text']
|
||||
|
||||
def _pre_process(self, sample):
|
||||
sample['text'] = f"In python, {sample['title']}\n### Question:\n{sample['question_body']}\n" \
|
||||
f"### Response:\n{sample['answer_body']}\n"
|
||||
|
||||
|
||||
def _get_handler_cls(handler_name=None):
|
||||
"""choose dataset class by dataset_name"""
|
||||
current_module = sys.modules.get(__name__)
|
||||
if not current_module:
|
||||
raise Exception("curent module not found")
|
||||
handler = getattr(current_module, handler_name, None)
|
||||
if handler is None:
|
||||
handler = GeneralPretrainHandler
|
||||
logger.info("dataset will use %s to handle dataset", handler.__name__)
|
||||
return handler
|
||||
|
||||
|
||||
def get_dataset_handler(args, raw_dataset, tokenizer, splitter):
|
||||
"""
|
||||
get a handler instance
|
||||
"""
|
||||
handler = _get_handler_cls(args.handler_name)
|
||||
|
||||
handler_instance = handler(args, raw_dataset, tokenizer, splitter)
|
||||
return handler_instance
|
||||
|
||||
|
||||
def _get_data_format(files):
|
||||
"""get format with largest number"""
|
||||
all_support_format = {
|
||||
'parquet': 'parquet',
|
||||
'arrow': 'arrow',
|
||||
'csv': 'csv',
|
||||
'json': 'json',
|
||||
'jsonl': 'json',
|
||||
'txt': 'text'
|
||||
}
|
||||
format_num = {}
|
||||
for file in files:
|
||||
ext = file.split('.')[-1]
|
||||
format_num[ext] = format_num.get(ext, 0) + 1
|
||||
exts_with_num = sorted(format_num.items(), key=lambda x: x[1], reverse=True)
|
||||
has_data_file = False
|
||||
for ext, _ in exts_with_num:
|
||||
if ext in all_support_format:
|
||||
has_data_file = True
|
||||
break
|
||||
return (ext, all_support_format.get(ext)) if has_data_file else (None, None)
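A standalone sketch of the same "majority extension wins" rule, with made-up file names:

files = ["a.jsonl", "b.jsonl", "notes.txt", "README.md"]
support = {'parquet': 'parquet', 'arrow': 'arrow', 'csv': 'csv',
           'json': 'json', 'jsonl': 'json', 'txt': 'text'}

counts = {}
for f in files:
    ext = f.split('.')[-1]
    counts[ext] = counts.get(ext, 0) + 1
# jsonl occurs twice, so ("jsonl", "json") wins and the caller keeps only *.jsonl files;
# "md" is skipped because it is not a supported format.
for ext, _ in sorted(counts.items(), key=lambda x: x[1], reverse=True):
    if ext in support:
        print(ext, support[ext])         # -> jsonl json
        break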
|
||||
|
||||
|
||||
def _has_py_script(input_name):
|
||||
if os.path.isdir(input_name):
|
||||
dir_name = os.path.basename(input_name)
|
||||
if os.path.exists(os.path.join(input_name, dir_name + '.py')):
|
||||
has_py_script = True
|
||||
else:
|
||||
has_py_script = False
|
||||
else:
|
||||
if input_name.split('.')[-1] == 'py':
|
||||
has_py_script = True
|
||||
else:
|
||||
has_py_script = False
|
||||
return has_py_script
|
||||
|
||||
|
||||
def build_dataset(args):
|
||||
"""loading dataset by huggingface"""
|
||||
if args.handler_name == "MOSSInstructionHandler":
|
||||
# for MOSS, streaming is needed.
|
||||
args.streaming = True
|
||||
if args.hf_datasets_params:
|
||||
with open(args.hf_datasets_params, 'r') as fin:
|
||||
param_dict = json.load(fin)
|
||||
return load_dataset(**param_dict)
|
||||
cache_dir = DEFAULT_CACHE_DIR
|
||||
split_flag = "train"
|
||||
load_from_local = os.path.exists(args.input)
|
||||
if load_from_local:
|
||||
if _has_py_script(args.input):
|
||||
logger.info("loading data from a local python script")
|
||||
raw_datasets = load_dataset(
|
||||
args.input,
|
||||
split=split_flag,
|
||||
num_proc=None if args.streaming else args.workers,
|
||||
cache_dir=cache_dir,
|
||||
streaming=args.streaming
|
||||
)
|
||||
else:
|
||||
data_files = [args.input] if os.path.isfile(args.input) else \
|
||||
glob.glob(os.path.join(args.input, '*'))
|
||||
ext, data_format = _get_data_format(data_files)
|
||||
filtered_data_files = list(filter(lambda x: x.split('.')[-1] == ext, data_files))
|
||||
if filtered_data_files:
|
||||
logger.info("loading data from local file, format: %s,"
|
||||
" file num: %s", data_format, len(data_files))
|
||||
raw_datasets = load_dataset(
|
||||
data_format,
|
||||
split=split_flag,
|
||||
data_files=filtered_data_files,
|
||||
num_proc=None if args.streaming else args.workers,
|
||||
cache_dir=cache_dir,
|
||||
streaming=args.streaming
|
||||
)
|
||||
else:
|
||||
raise Exception("unknown local data!")
|
||||
else:
|
||||
logger.info("loading data from remote huggingface")
|
||||
raw_datasets = load_dataset(
|
||||
args.input,
|
||||
split=split_flag,
|
||||
num_proc=None if args.streaming else args.workers,
|
||||
cache_dir=cache_dir,
|
||||
streaming=args.streaming
|
||||
)
|
||||
return raw_datasets
|
@ -1,32 +1,39 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
from ascendspeed import print_rank_0, logging
|
||||
from ascendspeed import is_rank_0, print_rank_0, get_args
|
||||
from ascendspeed.tokenizer.tokenizer import build_tokenizer
|
||||
from ascendspeed.core import parallel_state
|
||||
from ascendspeed.data.blendable_dataset import BlendableDataset
|
||||
from ascendspeed.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_, \
|
||||
get_train_valid_test_split_
|
||||
from ascendspeed.data.mtf_dataset import MTFDataset
|
||||
from ascendspeed.data.mtf_dataset import MTFDataset, get_packed_indexed_dataset
|
||||
from ascendspeed.data.indexed_dataset import make_dataset as make_indexed_dataset
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def build_train_valid_test_datasets(
|
||||
data_prefix,
|
||||
data_impl,
|
||||
splits_string,
|
||||
seq_length: int,
|
||||
pad_token: int,
|
||||
eos_token: int,
|
||||
train_valid_test_num_samples,
|
||||
seed,
|
||||
skip_warmup
|
||||
):
|
||||
"""Build train, valid, and test datasets."""
|
||||
|
||||
args = get_args()
|
||||
|
||||
tokenizer = build_tokenizer(args)
|
||||
pad_token = tokenizer.pad
|
||||
eos_token = tokenizer.eos
|
||||
|
||||
# Single dataset.
|
||||
if len(data_prefix) == 1:
|
||||
all_train_datasets, all_valid_datasets, all_test_datasets = _build_train_valid_test_datasets(
|
||||
@ -173,14 +180,13 @@ def _build_single_datasets(
|
||||
index = ["train","valid","test"].index(train_valid_test)
|
||||
|
||||
# Target indexed dataset.
|
||||
target_indexed_dataset = get_indexed_dataset(
|
||||
packed_indexed_dataset = get_packed_indexed_dataset(
|
||||
data_prefix=data_prefix,
|
||||
is_input=False,
|
||||
data_impl=data_impl,
|
||||
skip_warmup=skip_warmup
|
||||
)
|
||||
|
||||
total_num_of_documents = target_indexed_dataset.sizes.shape[0]
|
||||
total_num_of_documents = list(packed_indexed_dataset.values())[0].sizes.shape[0]
|
||||
# this corresponds to option2 for data loading on the form
|
||||
# WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3
|
||||
# splits here is an array of size 2 [start_index, end_index]
|
||||
@ -232,9 +238,9 @@ def _build_train_valid_test_datasets(
|
||||
"""Build train, valid, and test datasets."""
|
||||
|
||||
# Target indexed dataset.
|
||||
target_indexed_dataset = get_indexed_dataset(data_prefix, is_input=False, data_impl=data_impl, skip_warmup=skip_warmup)
|
||||
packed_indexed_dataset = get_packed_indexed_dataset(data_prefix=data_prefix, data_impl=data_impl, skip_warmup=skip_warmup)
|
||||
|
||||
total_num_of_documents = target_indexed_dataset.sizes.shape[0]
|
||||
total_num_of_documents = list(packed_indexed_dataset.values())[0].sizes.shape[0]
|
||||
# splits here is an array of size 4 [train_start_index, valid_start_index, test_start_index, test_end_index]
|
||||
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
|
||||
# Print stats about the splits.
|
||||
@ -295,82 +301,28 @@ class DecoderPackedMTFDataset(torch.utils.data.Dataset):
|
||||
self.pad_token = pad_token
|
||||
self.seq_length = seq_length
|
||||
|
||||
self.sample_index, self.shuffle_index = _build_index_mappings(name=name, data_prefix=data_prefix, nb_documents=len(documents), mtf_dataset=self.mtf_dataset, num_samples=num_samples, seq_length=seq_length, seed=seed)
|
||||
|
||||
self.shuffle_index = _build_index_mappings(name=name, data_prefix=data_prefix, nb_documents=len(documents), mtf_dataset=self.mtf_dataset, num_samples=num_samples, seq_length=seq_length, seed=seed)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sample_index)
|
||||
return len(self.shuffle_index)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Get the shuffled index.
|
||||
start, end = self.sample_index[idx]
|
||||
mtf_samples_indices = self.shuffle_index[start: end]
|
||||
# TODO @thomasw21 build a dataset that generates an entire batch instead of a row (allows for more optimization)
|
||||
items = [self.mtf_dataset[sample_id] for sample_id in mtf_samples_indices]
|
||||
|
||||
return self.pack_samples(items)
|
||||
|
||||
def pack_samples(self, items):
|
||||
"""
|
||||
Greedily packs samples.
|
||||
|
||||
Items:
|
||||
[
|
||||
{
|
||||
'input_tokens': array([6, 7]),
|
||||
'target_tokens': array([8])
|
||||
},
|
||||
{
|
||||
'input_tokens': array([3, 4]),
|
||||
'target_tokens': array([5])
|
||||
}
|
||||
]
|
||||
|
||||
Output:
|
||||
decoder_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens followed with padding tokens.
|
||||
decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids determine original documents.
|
||||
decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]]: `1` depicts inputs, `0` depicts target.
|
||||
"""
|
||||
|
||||
decoder_tokens = np.full((self.seq_length,), self.pad_token, dtype=np.int64)
|
||||
decoder_segment_ids = np.zeros((self.seq_length,), dtype=np.int64)
|
||||
decoder_is_inputs = np.full((self.seq_length,), False, dtype=bool)
|
||||
|
||||
# `0` is reserved for padding
|
||||
item_num = 1
|
||||
cur_len = 0
|
||||
|
||||
assert len(items) > 0
|
||||
|
||||
for token_dict in items:
|
||||
input_token_len = len(token_dict["input_tokens"])
|
||||
target_token_len = len(token_dict["target_tokens"])
|
||||
|
||||
total_len = input_token_len + target_token_len
|
||||
|
||||
if cur_len + total_len > self.seq_length:
|
||||
# This should not happen, as the indexing should only allow the correct number of items
|
||||
raise ValueError(f"""Items to be packed do not fit inside a single sample.
|
||||
current length: {cur_len}
|
||||
input tokens length: {input_token_len}
|
||||
target token length: {target_token_len}
|
||||
expected sequence length: {self.seq_length}
|
||||
""")
|
||||
|
||||
decoder_tokens[cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
|
||||
decoder_tokens[cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
|
||||
decoder_segment_ids[cur_len: cur_len + total_len] = item_num
|
||||
decoder_is_inputs[cur_len: cur_len + input_token_len] = 1 # inputs
|
||||
# targets are already 0 at init, no need to update `decoder_is_inputs`
|
||||
|
||||
item_num += 1
|
||||
cur_len += total_len
|
||||
assert cur_len <= self.seq_length
|
||||
|
||||
doc_idx = self.shuffle_index[idx]
|
||||
item = self.mtf_dataset[doc_idx]
|
||||
return {
|
||||
"decoder_token_ids": decoder_tokens,
|
||||
"decoder_segment_ids": decoder_segment_ids,
|
||||
"decoder_is_inputs": decoder_is_inputs,
|
||||
"input_ids": self._pad_token(item["input_ids"][:-1], self.pad_token, np.int64),
|
||||
"attention_mask": self._pad_token(item["attention_mask"][:-1], 0, np.int64),
|
||||
"labels": self._pad_token(item["labels"][1:], -100, np.int64),
|
||||
}
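A toy illustration, with made-up ids and a hypothetical pad id of 0, of the shift-and-pad performed above: inputs drop the last token, labels drop the first, and both are padded to seq_length (labels with -100 so padded positions are ignored by the loss).

import numpy as np

seq_length = 8
tokens = np.array([11, 12, 13, 14, 2])            # 2 stands in for the eod token

def pad_to(seq, value):
    out = np.full((seq_length,), value, dtype=np.int64)
    out[:len(seq)] = seq
    return out

model_inputs = pad_to(tokens[:-1], 0)             # [11 12 13 14  0  0  0  0]
model_labels = pad_to(tokens[1:], -100)           # [12 13 14  2 -100 ... -100]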
|
||||
|
||||
def _pad_token(self, token, pad_value, dtype):
|
||||
padded_token = np.full((self.seq_length,), pad_value, dtype=dtype)
|
||||
token_length = len(token)
|
||||
if token_length <= self.seq_length:
|
||||
padded_token[:token_length] = token
|
||||
else:
|
||||
padded_token = token[:self.seq_length]
|
||||
return padded_token.astype(dtype)
|
||||
|
||||
|
||||
def _build_index_mappings(
|
||||
@ -396,52 +348,33 @@ def _build_index_mappings(
|
||||
_filename += '_{}_indexmap'.format(name)
|
||||
_filename += '_{}ns'.format(num_samples)
|
||||
_filename += '_{}s'.format(seed)
|
||||
sample_idx_filename = _filename + '_decoder_packed_batch_idx.npy'
|
||||
shuffle_idx_filename = _filename + '_decoder_packed_shuffle_idx.npy'
|
||||
|
||||
# Build the indexed mapping if not exist.
|
||||
if torch.distributed.get_rank() == 0:
|
||||
if (not os.path.isfile(sample_idx_filename)) or \
|
||||
(not os.path.isfile(shuffle_idx_filename)):
|
||||
if is_rank_0():
|
||||
if not os.path.isfile(shuffle_idx_filename):
|
||||
|
||||
print_rank_0(' > WARNING: could not find index map files, building '
|
||||
'the indices on rank 0 ...')
|
||||
|
||||
# iteratively add the entire dataset for every epoch and see if it's enough given current packing strategy
|
||||
start_time = time.time()
|
||||
row_offset = 0
|
||||
old_sample_start = 0
|
||||
epoch = 0
|
||||
shuffle_idx = []
|
||||
sample_idx = []
|
||||
while len(sample_idx) <= num_samples:
|
||||
while len(shuffle_idx) <= num_samples:
|
||||
new_document_ids = _build_shuffle_idx(nb_documents=nb_documents, np_rng=np_rng)
|
||||
# Generate a shuffling of the entire dataset
|
||||
shuffle_idx.append(new_document_ids)
|
||||
# Packs them into a single sample
|
||||
new_samples, row_offset, old_sample_start = _build_sample_idx(
|
||||
mtf_dataset=mtf_dataset,
|
||||
document_ids=new_document_ids,
|
||||
seq_length=seq_length,
|
||||
row_offset=row_offset,
|
||||
old_sample_start=old_sample_start,
|
||||
epoch=epoch
|
||||
)
|
||||
sample_idx.extend(new_samples)
|
||||
shuffle_idx.extend(new_document_ids.tolist())
|
||||
epoch += 1
|
||||
|
||||
shuffle_idx = np.concatenate(shuffle_idx, axis=0)
|
||||
sample_idx = np.stack(sample_idx, axis=0)
|
||||
|
||||
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
|
||||
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
|
||||
print_rank_0(' > elapsed time to build and save shuffle-idx and sample-idx mapping'
|
||||
' (seconds): {:4f}'.format(time.time() - start_time))
|
||||
|
||||
# This should be a barrier but nccl barrier assumes
|
||||
# device_index=rank which is not the case for model
|
||||
# parallel case
|
||||
counts = torch.cuda.LongTensor([1])
|
||||
counts = get_accelerator().LongTensor([1])
|
||||
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
|
||||
torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
|
||||
assert counts[0].item() == (
|
||||
@ -450,51 +383,13 @@ def _build_index_mappings(
|
||||
|
||||
# Load mappings.
|
||||
start_time = time.time()
|
||||
print_rank_0(' > loading doc-idx mapping from {}'.format(
|
||||
sample_idx_filename))
|
||||
sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
|
||||
print_rank_0(' > loading shuffle-idx mapping from {}'.format(
|
||||
shuffle_idx_filename))
|
||||
shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
|
||||
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
|
||||
time.time() - start_time))
|
||||
|
||||
return sample_idx, shuffle_idx
|
||||
|
||||
def _build_sample_idx(mtf_dataset, document_ids, seq_length, row_offset, old_sample_start, epoch):
|
||||
"""Build start and off index of each `full` batch, return that list of batch + start of the unfinished batch"""
|
||||
row_length = row_offset
|
||||
|
||||
full_samples = []
|
||||
current_sample_start = old_sample_start
|
||||
epoch_offset = epoch * len(document_ids)
|
||||
|
||||
assert epoch_offset >= current_sample_start
|
||||
for current_sample_end, document_id in enumerate(document_ids):
|
||||
current_sample_end = epoch_offset + current_sample_end
|
||||
sample_sizes = mtf_dataset.size(document_id)
|
||||
|
||||
# TODO @thomasw21 figure out if we add <eos> tokens
|
||||
tok_len = sample_sizes["input_tokens"] + sample_sizes["target_tokens"]
|
||||
|
||||
row_length = row_length + tok_len
|
||||
if row_length > seq_length:
|
||||
# current sample can't be added and requires to be added in the next one
|
||||
if current_sample_end > current_sample_start:
|
||||
full_samples.append(np.asarray([current_sample_start, current_sample_end]))
|
||||
current_sample_start = current_sample_end
|
||||
row_length = tok_len
|
||||
|
||||
if tok_len > seq_length:
|
||||
# TODO @thomasw21 handle the case where a single sample cannot fit inside a row. We can
|
||||
# - silently skip that value [currently implemented]
|
||||
# - truncate to `seq_length`, and keep the right part
|
||||
logger.warning(f"Skipping sample id={document_id}. Maximum sequence length: {seq_length}, sample length: {tok_len}")
|
||||
current_sample_start = current_sample_end + 1 # skipping
|
||||
row_length = 0
|
||||
continue
|
||||
|
||||
return full_samples, row_length, current_sample_start
|
||||
return shuffle_idx
|
||||
|
||||
def _build_shuffle_idx(nb_documents: int, np_rng):
|
||||
"""Build the range [0, dataset_size) and shuffle."""
|
||||
@ -506,27 +401,3 @@ def _build_shuffle_idx(nb_documents: int, np_rng):
|
||||
np_rng.shuffle(result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_indexed_dataset(data_prefix: str, is_input: bool, data_impl: str, skip_warmup: bool):
|
||||
if is_input:
|
||||
field = "inputs"
|
||||
else:
|
||||
field = "targets"
|
||||
|
||||
return get_indexed_dataset_(f"{data_prefix}_{field}_document", data_impl, skip_warmup)
|
||||
|
||||
|
||||
def get_indexed_dataset_(path, data_impl, skip_warmup):
|
||||
"""Build indexed dataset."""
|
||||
print_rank_0(' > building dataset index ...')
|
||||
start_time = time.time()
|
||||
indexed_dataset = make_indexed_dataset(path,
|
||||
data_impl,
|
||||
skip_warmup)
|
||||
print_rank_0(' > finished creating indexed dataset in {:4f} '
|
||||
'seconds'.format(time.time() - start_time))
|
||||
print_rank_0(' number of documents: {}'.format(
|
||||
indexed_dataset.sizes.shape[0]))
|
||||
|
||||
return indexed_dataset
|
||||
|
Binary file not shown.
Binary file not shown.
@ -1,156 +0,0 @@
|
||||
import itertools
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from ascendspeed import get_tokenizer
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed.data.dataset_utils import get_indexed_dataset_
|
||||
from ascendspeed.data.realm_dataset_utils import get_block_samples_mapping
|
||||
|
||||
def make_attention_mask(source_block, target_block):
|
||||
"""
|
||||
Returns a 2-dimensional (2-D) attention mask
|
||||
:param source_block: 1-D array
|
||||
:param target_block: 1-D array
|
||||
"""
|
||||
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
|
||||
mask = mask.astype(np.int64)
|
||||
# (source_length, target_length)
|
||||
return mask
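A quick toy check of the broadcasting above, with made-up token ids (0 marks padding):

import numpy as np

source_block = np.array([5, 7, 0])                # last position is padding
target_block = np.array([2, 0])

mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# mask == [[1, 0],
#          [1, 0],
#          [0, 0]]   with shape (source_length, target_length)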
|
||||
|
||||
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
|
||||
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
|
||||
rather than for training, since it is only built with a single epoch sample mapping.
|
||||
"""
|
||||
args = get_args()
|
||||
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
|
||||
titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
|
||||
|
||||
kwargs = dict(
|
||||
name='full',
|
||||
block_dataset=block_dataset,
|
||||
title_dataset=titles_dataset,
|
||||
data_prefix=args.data_path,
|
||||
num_epochs=1,
|
||||
max_num_samples=None,
|
||||
max_seq_length=args.seq_length,
|
||||
seed=1,
|
||||
query_in_block_prob=query_in_block_prob,
|
||||
use_titles=use_titles,
|
||||
use_one_sent_docs=args.use_one_sent_docs
|
||||
)
|
||||
dataset = ICTDataset(**kwargs)
|
||||
return dataset
|
||||
|
||||
|
||||
class ICTDataset(Dataset):
|
||||
"""Dataset containing sentences and their blocks for an inverse cloze task."""
|
||||
def __init__(self, name, block_dataset, title_dataset, data_prefix,
|
||||
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
|
||||
seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
|
||||
self.name = name
|
||||
self.seed = seed
|
||||
self.max_seq_length = max_seq_length
|
||||
self.query_in_block_prob = query_in_block_prob
|
||||
self.block_dataset = block_dataset
|
||||
self.title_dataset = title_dataset
|
||||
self.rng = random.Random(self.seed)
|
||||
self.use_titles = use_titles
|
||||
self.use_one_sent_docs = use_one_sent_docs
|
||||
|
||||
self.samples_mapping = get_block_samples_mapping(
|
||||
block_dataset, title_dataset, data_prefix, num_epochs,
|
||||
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
|
||||
self.tokenizer = get_tokenizer()
|
||||
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
|
||||
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
|
||||
self.cls_id = self.tokenizer.cls
|
||||
self.sep_id = self.tokenizer.sep
|
||||
self.mask_id = self.tokenizer.mask
|
||||
self.pad_id = self.tokenizer.pad
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples_mapping)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
|
||||
sample_data = self.samples_mapping[idx]
|
||||
start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
|
||||
|
||||
if self.use_titles:
|
||||
title = self.title_dataset[int(doc_idx)]
|
||||
title_pad_offset = 3 + len(title)
|
||||
else:
|
||||
title = None
|
||||
title_pad_offset = 2
|
||||
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
|
||||
assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
|
||||
|
||||
# randint() is inclusive for Python rng
|
||||
rand_sent_idx = self.rng.randint(0, len(block) - 1)
|
||||
|
||||
# keep the query in the context query_in_block_prob fraction of the time.
|
||||
if self.rng.random() < self.query_in_block_prob:
|
||||
query = block[rand_sent_idx].copy()
|
||||
else:
|
||||
query = block.pop(rand_sent_idx)
|
||||
|
||||
# still need to truncate because blocks are concluded when
|
||||
# the sentence lengths have exceeded max_seq_length.
|
||||
query = query[:self.max_seq_length - 2]
|
||||
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
|
||||
|
||||
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
|
||||
context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
|
||||
|
||||
query_mask = make_attention_mask(query_tokens, query_tokens)
|
||||
context_mask = make_attention_mask(context_tokens, context_tokens)
|
||||
|
||||
block_data = sample_data.as_array()
|
||||
|
||||
sample = {
|
||||
'query_tokens': query_tokens,
|
||||
'query_mask': query_mask,
|
||||
'query_pad_mask': query_pad_mask,
|
||||
'context_tokens': context_tokens,
|
||||
'context_mask': context_mask,
|
||||
'context_pad_mask': context_pad_mask,
|
||||
'block_data': block_data,
|
||||
}
|
||||
|
||||
return sample
|
||||
|
||||
def get_block(self, start_idx, end_idx, doc_idx):
|
||||
"""Get the IDs for an evidence block plus the title of the corresponding document"""
|
||||
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
|
||||
title = self.title_dataset[int(doc_idx)]
|
||||
|
||||
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
|
||||
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
|
||||
|
||||
return block_tokens, block_pad_mask
|
||||
|
||||
def get_null_block(self):
|
||||
"""Get empty block and title - used in REALM pretraining"""
|
||||
block, title = [], []
|
||||
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
|
||||
|
||||
return block_tokens, block_pad_mask
|
||||
|
||||
def concat_and_pad_tokens(self, tokens, title=None):
|
||||
"""Concat with special tokens and pad sequence to self.max_seq_length"""
|
||||
tokens = list(tokens)
|
||||
if title is None:
|
||||
tokens = [self.cls_id] + tokens + [self.sep_id]
|
||||
else:
|
||||
title = list(title)
|
||||
tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
|
||||
assert len(tokens) <= self.max_seq_length
|
||||
|
||||
num_pad = self.max_seq_length - len(tokens)
|
||||
pad_mask = [1] * len(tokens) + [0] * num_pad
|
||||
tokens += [self.pad_id] * num_pad
|
||||
|
||||
return np.array(tokens), np.array(pad_mask)
|
@ -1,508 +0,0 @@
|
||||
"""Non-Causal Mask Language Model Finetune Style dataset."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ascendspeed import print_rank_0, get_tokenizer, get_args
|
||||
from ascendspeed.data.blendable_dataset import BlendableDataset
|
||||
from ascendspeed.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_
|
||||
from ascendspeed.data.dataset_utils import get_train_valid_test_split_, get_indexed_dataset_
|
||||
from ascendspeed.data.gpt_dataset import GPTDataset
|
||||
|
||||
|
||||
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
sequence_length,
|
||||
noise_density,
|
||||
mean_noise_span_length,
|
||||
seed,
|
||||
skip_warmup
|
||||
):
|
||||
assert noise_density is not None
|
||||
assert mean_noise_span_length is not None
|
||||
|
||||
if len(data_prefix) == 1:
|
||||
return _build_train_valid_test_datasets(
|
||||
data_prefix=data_prefix[0],
|
||||
data_impl=data_impl,
|
||||
splits_string=splits_string,
|
||||
train_valid_test_num_samples=train_valid_test_num_samples,
|
||||
sequence_length=sequence_length,
|
||||
noise_density=noise_density,
|
||||
mean_noise_span_length=mean_noise_span_length,
|
||||
seed=seed,
|
||||
skip_warmup=skip_warmup
|
||||
)
|
||||
# Blending dataset.
|
||||
# Parse the values.
|
||||
output = get_datasets_weights_and_num_samples(data_prefix,
|
||||
train_valid_test_num_samples)
|
||||
prefixes, weights, datasets_train_valid_test_num_samples = output
|
||||
|
||||
# Build individual datasets.
|
||||
train_datasets = []
|
||||
valid_datasets = []
|
||||
test_datasets = []
|
||||
for i in range(len(prefixes)):
|
||||
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
|
||||
data_prefix=prefixes[i],
|
||||
data_impl=data_impl,
|
||||
splits_string=splits_string,
|
||||
train_valid_test_num_samples=datasets_train_valid_test_num_samples[i],
|
||||
sequence_length=sequence_length,
|
||||
noise_density=noise_density,
|
||||
mean_noise_span_length=mean_noise_span_length,
|
||||
seed=seed,
|
||||
skip_warmup=skip_warmup
|
||||
)
|
||||
if train_ds:
|
||||
train_datasets.append(train_ds)
|
||||
if valid_ds:
|
||||
valid_datasets.append(valid_ds)
|
||||
if test_ds:
|
||||
test_datasets.append(test_ds)
|
||||
|
||||
# Blend.
|
||||
blending_train_dataset = None
|
||||
if train_datasets:
|
||||
blending_train_dataset = BlendableDataset(train_datasets, weights)
|
||||
blending_valid_dataset = None
|
||||
if valid_datasets:
|
||||
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
|
||||
blending_test_dataset = None
|
||||
if test_datasets:
|
||||
blending_test_dataset = BlendableDataset(test_datasets, weights)
|
||||
|
||||
return (blending_train_dataset, blending_valid_dataset,
|
||||
blending_test_dataset)
|
||||
|
||||
def build_dataset_group(
|
||||
dataset_group_name,
|
||||
paths,
|
||||
weights,
|
||||
splits,
|
||||
data_impl,
|
||||
train_valid_test_num_samples,
|
||||
seq_length,
|
||||
noise_density,
|
||||
mean_noise_span_length,
|
||||
seed,
|
||||
skip_warmup,
|
||||
train_valid_test
|
||||
):
|
||||
'''
Build a single dataset group corresponding to Option 2 of data loading; see arguments.py.
A dataset group is passed in the following form:
GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3
or alternatively:
GIVEN_NAME PATH1 # for a single dataset to be used fully
'''
|
||||
|
||||
assert train_valid_test in ["train","valid","test"]
|
||||
|
||||
# Single dataset.
|
||||
if len(paths) == 1:
|
||||
dataset = _build_single_datasets(
|
||||
data_prefix=paths[0],
|
||||
range_string=splits[0],
|
||||
data_impl=data_impl,
|
||||
train_valid_test_num_samples=train_valid_test_num_samples,
|
||||
sequence_length=seq_length,
|
||||
noise_density=noise_density,
|
||||
mean_noise_span_length=mean_noise_span_length,
|
||||
seed=seed,
|
||||
skip_warmup=skip_warmup,
|
||||
dataset_group_name=dataset_group_name,
|
||||
train_valid_test=train_valid_test)
|
||||
return dataset
|
||||
# Blending dataset.
|
||||
else:
|
||||
|
||||
data_prefix = []
|
||||
# data_prefix is on the shape:
|
||||
# ["WEIGHT1", "PATH1", "WEIGHT2", "PATH2", "WEIGHT3", "PATH3"]
|
||||
for w,p in zip(weights, paths):
|
||||
data_prefix += [w,p]
|
||||
|
||||
output = get_datasets_weights_and_num_samples(data_prefix,
|
||||
train_valid_test_num_samples)
|
||||
prefixes, weights, datasets_train_valid_test_num_samples = output
|
||||
|
||||
# Build individual datasets.
|
||||
datasets = []
|
||||
for i in range(len(prefixes)):
|
||||
ds = _build_single_datasets(
|
||||
data_prefix=prefixes[i],
|
||||
range_string=splits[i],
|
||||
data_impl=data_impl,
|
||||
train_valid_test_num_samples=datasets_train_valid_test_num_samples[i],
|
||||
sequence_length=seq_length,
|
||||
noise_density=noise_density,
|
||||
mean_noise_span_length=mean_noise_span_length,
|
||||
seed=seed,
|
||||
skip_warmup=skip_warmup,
|
||||
dataset_group_name=dataset_group_name,
|
||||
train_valid_test=train_valid_test
|
||||
)
|
||||
datasets.append(ds)
|
||||
all_datasets = BlendableDataset(datasets, weights)
|
||||
|
||||
return all_datasets
|
||||
|
||||
def _build_single_datasets(
|
||||
data_prefix,
|
||||
range_string,
|
||||
data_impl,
|
||||
train_valid_test_num_samples,
|
||||
sequence_length,
|
||||
noise_density,
|
||||
mean_noise_span_length,
|
||||
seed,
|
||||
skip_warmup,
|
||||
dataset_group_name,
|
||||
train_valid_test):
|
||||
"""Build a single dataset"""
|
||||
|
||||
assert train_valid_test in ["train","valid","test"]
|
||||
index = ["train","valid","test"].index(train_valid_test)
|
||||
|
||||
# Indexed dataset.
|
||||
indexed_dataset = get_indexed_dataset_(data_prefix,
|
||||
data_impl,
|
||||
skip_warmup)
|
||||
|
||||
total_num_of_documents = indexed_dataset.sizes.shape[0]
|
||||
# this corresponds to option2 for data loading on the form
|
||||
# WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3
|
||||
# splits here is an array of size 2 [start_index, end_index]
|
||||
splits = get_split_by_range_(range_string=range_string, size=total_num_of_documents)
|
||||
|
||||
# Print stats about the splits.
|
||||
print_rank_0(' > dataset split:')
|
||||
|
||||
print_rank_0(' {}:'.format(dataset_group_name))
|
||||
print_rank_0(' document indices in [{}, {}) total of {} '
|
||||
'documents'.format(splits[0], splits[1],
|
||||
splits[1] - splits[0]))
|
||||
|
||||
def build_dataset(name):
|
||||
dataset = None
|
||||
if splits[1] > splits[0]:
|
||||
documents = np.arange(start=splits[0], stop=splits[1],
|
||||
step=1, dtype=np.int32)
|
||||
dataset = MLMDataset(
|
||||
indexed_dataset=indexed_dataset,
|
||||
documents=documents,
|
||||
noise_density=noise_density,
|
||||
mean_noise_span_length=mean_noise_span_length,
|
||||
name=name,
|
||||
data_prefix=data_prefix,
|
||||
sequence_length=sequence_length,
|
||||
num_samples=train_valid_test_num_samples[index],
|
||||
seed=seed,
|
||||
)
|
||||
return dataset
|
||||
|
||||
dataset = build_dataset(dataset_group_name)
|
||||
|
||||
return dataset
|
||||
|
||||
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
sequence_length,
|
||||
noise_density,
|
||||
mean_noise_span_length,
|
||||
seed,
|
||||
skip_warmup):
|
||||
"""Build train, valid, and test datasets."""
|
||||
|
||||
|
||||
# Indexed dataset.
|
||||
indexed_dataset = get_indexed_dataset_(data_prefix,
|
||||
data_impl,
|
||||
skip_warmup)
|
||||
|
||||
total_num_of_documents = indexed_dataset.sizes.shape[0] - 1
|
||||
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
|
||||
# Print stats about the splits.
|
||||
print_rank_0(' > dataset split:')
|
||||
|
||||
def print_split_stats(name, index):
|
||||
print_rank_0(' {}:'.format(name))
|
||||
print_rank_0(' document indices in [{}, {}) total of {} '
|
||||
'documents'.format(splits[index], splits[index + 1],
|
||||
splits[index + 1] - splits[index]))
|
||||
start_index = indexed_dataset.doc_idx[splits[index]]
|
||||
end_index = indexed_dataset.doc_idx[splits[index + 1]]
|
||||
print_rank_0(' sentence indices in [{}, {}) total of {} '
|
||||
'sentences'.format(start_index, end_index,
|
||||
end_index - start_index))
|
||||
print_split_stats('train', 0)
|
||||
print_split_stats('validation', 1)
|
||||
print_split_stats('test', 2)
|
||||
|
||||
def build_dataset(index, name):
|
||||
dataset = None
|
||||
if splits[index + 1] > splits[index]:
|
||||
# Build the dataset accordingly.
|
||||
documents = np.arange(start=splits[index], stop=splits[index + 1],
|
||||
step=1, dtype=np.int32)
|
||||
dataset = MLMDataset(
|
||||
indexed_dataset=indexed_dataset,
|
||||
documents=documents,
|
||||
noise_density=noise_density,
|
||||
mean_noise_span_length=mean_noise_span_length,
|
||||
name=name,
|
||||
data_prefix=data_prefix,
|
||||
sequence_length=sequence_length,
|
||||
num_samples=train_valid_test_num_samples[index],
|
||||
seed=seed,
|
||||
)
|
||||
return dataset
|
||||
|
||||
train_dataset = build_dataset(0, 'train')
|
||||
valid_dataset = build_dataset(1, 'valid')
|
||||
test_dataset = build_dataset(2, 'test')
|
||||
|
||||
return (train_dataset, valid_dataset, test_dataset)
|
||||
|
||||
|
||||
class MLMDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
indexed_dataset,
|
||||
documents,
|
||||
data_prefix,
|
||||
sequence_length,
|
||||
num_samples,
|
||||
seed,
|
||||
noise_density=0.15,
|
||||
mean_noise_span_length=3
|
||||
):
|
||||
|
||||
# Params to store.
|
||||
self.name = name
|
||||
self.seed = seed
|
||||
self.sequence_length = sequence_length
|
||||
|
||||
# Dataset.
|
||||
self.indexed_dataset = indexed_dataset
|
||||
|
||||
self.noise_density = noise_density
|
||||
self.mean_noise_span_length = mean_noise_span_length
|
||||
# T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
|
||||
# To ensure that the input length is `sequence_length`, we need to increase the maximum length
|
||||
# according to `noise_density` and `mean_noise_span_length`. We can also define the label length accordingly.
|
||||
number_of_raw_tokens, inputs_length, targets_length, num_noise_spans = compute_input_and_target_lengths(
|
||||
sequence_length=self.sequence_length,
|
||||
noise_density=self.noise_density,
|
||||
mean_noise_span_length=self.mean_noise_span_length
|
||||
)
|
||||
self.inputs_length = inputs_length
|
||||
# In order to compute loss, we need an extra token at the end.
|
||||
self.number_of_raw_tokens = number_of_raw_tokens + 1
|
||||
self.targets_length = targets_length + 1
|
||||
self.num_noise_spans = num_noise_spans
|
||||
|
||||
# Build the samples mapping.
|
||||
self._gpt_dataset = GPTDataset(
|
||||
name=self.name,
|
||||
data_prefix=data_prefix,
|
||||
documents=documents,
|
||||
indexed_dataset=self.indexed_dataset,
|
||||
num_samples=num_samples,
|
||||
# -1 because GPTDataset will return `seq_length + 1` sequences.
|
||||
seq_length=self.number_of_raw_tokens - 1,
|
||||
seed=seed
|
||||
)
|
||||
|
||||
# Vocab stuff.
|
||||
tokenizer = get_tokenizer()
|
||||
self.sep_id = tokenizer.sep
|
||||
self.sentinel_token_ids = tokenizer.additional_special_tokens_ids
|
||||
assert self.sep_id is not None, "MLM dataset requires tokenizer to have a <sep> token"
|
||||
assert len(self.sentinel_token_ids) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
|
||||
assert len(self.sentinel_token_ids) >= self.num_noise_spans, "Not enough sentinel tokens, please add more"
|
||||
|
||||
args = get_args()
|
||||
# TODO @thomasw21 check once we merge t5
|
||||
assert self.inputs_length + self.targets_length == args.seq_length + 1
|
||||
|
||||
def __len__(self):
|
||||
return len(self._gpt_dataset)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, slice):
|
||||
raise NotImplementedError
|
||||
|
||||
sample = self._gpt_dataset[idx]["text"]
|
||||
|
||||
return build_training_sample(
|
||||
sample=sample,
|
||||
inputs_length=self.inputs_length,
|
||||
targets_length=self.targets_length,
|
||||
num_noise_spans=self.num_noise_spans,
|
||||
sep_id=self.sep_id,
|
||||
all_sentinel_token_ids=self.sentinel_token_ids,
|
||||
)
|
||||
|
||||
|
||||
def build_training_sample(
|
||||
sample,
|
||||
inputs_length,
|
||||
targets_length,
|
||||
num_noise_spans,
|
||||
sep_id,
|
||||
all_sentinel_token_ids,
|
||||
):
|
||||
"""Build training sample.
|
||||
|
||||
Arguments:
|
||||
sample: int32 tensor
|
||||
inputs_length: integer
|
||||
targets_length: integer
|
||||
num_noise_spans: integer
|
||||
sep_id: integer
|
||||
all_sentinel_token_ids: List[int]
|
||||
Returns:
|
||||
Dict with following keys:
|
||||
- `input_tokens`: an int32 tensor of length inputs_length,
- `target_tokens`: an int32 tensor of length targets_length + 1,
|
||||
"""
|
||||
|
||||
spans_start, mask_indices = random_spans_noise_mask(
|
||||
inputs_length=inputs_length,
|
||||
targets_length=targets_length,
|
||||
num_noise_spans=num_noise_spans,
|
||||
)
|
||||
spans_end = np.concatenate([
|
||||
spans_start[1:], np.full((1,), len(sample), dtype=np.int32)]
|
||||
)
|
||||
|
||||
sentinel_token_ids = all_sentinel_token_ids[:num_noise_spans]
|
||||
|
||||
input_token_ids = np.concatenate(
|
||||
[
|
||||
elt
|
||||
for start, end, sentinel_token in zip(spans_start[::2], spans_end[::2], sentinel_token_ids)
|
||||
for elt in [sample[start: end], np.full((1,), sentinel_token, dtype=np.int32)]
|
||||
] +
|
||||
[np.full((1,), sep_id, dtype=np.int32)]
|
||||
)
|
||||
target_token_ids = np.concatenate(
|
||||
[
|
||||
elt
|
||||
for start, end, sentinel_token in zip(spans_start[1::2], spans_end[1::2], sentinel_token_ids)
|
||||
for elt in [np.full((1,), sentinel_token, dtype=np.int32), sample[start: end]]
|
||||
] +
|
||||
[np.full((1,), sep_id, dtype=np.int32)]
|
||||
)
|
||||
|
||||
return {
|
||||
'input_tokens': input_token_ids,
|
||||
'target_tokens': target_token_ids
|
||||
}
|
||||
|
||||
|
||||
def compute_input_and_target_lengths(sequence_length, noise_density, mean_noise_span_length):
|
||||
"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
|
||||
Training parameters to avoid padding with random_spans_noise_mask.
|
||||
When training a model with random_spans_noise_mask, we would like to set the other
|
||||
training hyperparameters in a way that avoids padding.
|
||||
This function helps us compute these hyperparameters.
|
||||
The number of noise tokens and the number of noise spans and non-noise spans
|
||||
are determined deterministically as follows:
|
||||
num_noise_tokens = round(length * noise_density)
|
||||
num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
|
||||
We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
|
||||
and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
|
||||
This function tells us the required number of tokens in the raw example (for split_tokens())
|
||||
as well as the length of the encoded targets. Note that this function assumes
|
||||
the inputs and targets will have SEP appended and includes that in the reported length.
|
||||
Args:
sequence_length: an integer - desired total length of the tokenized inputs plus targets
noise_density: a float
mean_noise_span_length: a float
Returns:
tokens_length: length of original text in tokens
inputs_length: an integer - length in tokens of the encoded inputs sequence
targets_length: an integer - length in tokens of the encoded targets sequence
num_noise_spans: an integer - number of noise spans
"""
|
||||
|
||||
def _tokens_length_to_inputs_length_targets_length(_tokens_length):
|
||||
num_noise_tokens = int(round(_tokens_length * noise_density))
|
||||
num_nonnoise_tokens = _tokens_length - num_noise_tokens
|
||||
_num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
|
||||
# inputs contain all nonnoise tokens, sentinels for all noise spans and one SEP token.
|
||||
_input_length = num_nonnoise_tokens + _num_noise_spans + 1
|
||||
_output_length = num_noise_tokens + _num_noise_spans + 1
|
||||
return _input_length, _output_length, _num_noise_spans
|
||||
|
||||
tokens_length = sequence_length
|
||||
inputs_length, targets_length, num_noise_spans = _tokens_length_to_inputs_length_targets_length(tokens_length)
|
||||
while inputs_length + targets_length > sequence_length:
|
||||
tokens_length -= 1
|
||||
inputs_length, targets_length, num_noise_spans = _tokens_length_to_inputs_length_targets_length(tokens_length)
|
||||
|
||||
# tokens_length is the number of raw tokens we need to get
|
||||
# inputs_length will be the input
|
||||
# targets_length will be the target
|
||||
# num_noise_spans is the number of spans we have to replace
|
||||
return tokens_length, inputs_length, targets_length, num_noise_spans
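For concreteness, a worked call under assumed settings (seq_length 512 with the MLMDataset defaults noise_density=0.15 and mean_noise_span_length=3); the numbers below are derived from the formulas above, not taken from the repo:

tokens_length, inputs_length, targets_length, num_noise_spans = \
    compute_input_and_target_lengths(512, 0.15, 3)
# tokens_length   == 464  raw tokens pulled from the underlying GPT dataset
# inputs_length   == 418  == (464 - 70 noise tokens) + 23 sentinels + 1 SEP
# targets_length  ==  94  == 70 noise tokens + 23 sentinels + 1 SEP
# inputs_length + targets_length == 512 == seq_length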
|
||||
|
||||
|
||||
def random_spans_noise_mask(
|
||||
inputs_length,
|
||||
targets_length,
|
||||
num_noise_spans,
|
||||
):
|
||||
|
||||
"""This function is inspired from `random_spans_noise_mask <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
|
||||
Noise mask consisting of random spans of noise tokens.
|
||||
Spans alternate between non-noise and noise, beginning with non-noise.
|
||||
Args:
|
||||
inputs_length: int32 scalar
|
||||
targets_length: int32 scalar
|
||||
num_noise_spans: int32 scalar
|
||||
Returns:
|
||||
an integer array of span start indices with shape [2 * num_noise_spans]
a boolean array with shape [number_of_raw_tokens]
|
||||
"""
|
||||
# pick the lengths of the noise spans and the non-noise spans
|
||||
num_noise_tokens = targets_length - num_noise_spans - 1
|
||||
num_nonnoise_tokens = inputs_length - num_noise_spans - 1
|
||||
number_of_raw_tokens = num_noise_tokens + num_nonnoise_tokens
|
||||
|
||||
def _random_segmentation(num_items, num_segments):
|
||||
"""Partition a sequence of items randomly into non-empty segments.
|
||||
Args:
|
||||
num_items: an integer scalar > 0
|
||||
num_segments: an integer scalar in [1, num_items]
|
||||
Returns:
|
||||
a Tensor with shape [num_segments] containing positive integers that add
|
||||
up to num_items
|
||||
"""
|
||||
mask_indices = np.arange(num_items - 1) < (num_segments - 1)
|
||||
# TODO @thomasw21 handle random state correctly, ie synchronized across TP.
|
||||
# we might not care as get_batch_pipe broadcasts data to all devices.
|
||||
np.random.shuffle(mask_indices)
|
||||
first_in_segment = np.pad(mask_indices, [[1, 0]], constant_values=0)
|
||||
segment_id = np.cumsum(first_in_segment)
|
||||
# count length of sub segments assuming that list is sorted
|
||||
_, segment_length = np.unique(segment_id, return_counts=True)
|
||||
return segment_length
|
||||
|
||||
noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
|
||||
nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
|
||||
|
||||
interleaved_span_lengths = np.reshape(
|
||||
np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
|
||||
)
|
||||
span_starts = np.concatenate([np.full((1,), 0, dtype=np.int32), np.cumsum(interleaved_span_lengths)[:-1]])
|
||||
span_start_indicator = np.zeros((number_of_raw_tokens,), dtype=np.int8)
|
||||
span_start_indicator[span_starts] = True
|
||||
span_num = np.cumsum(span_start_indicator)
|
||||
is_noise = np.equal(span_num % 2, 1)
|
||||
|
||||
return span_starts, is_noise
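A standalone sketch, with toy span lengths rather than random ones, of the interleaving and start-index computation above:

import numpy as np

nonnoise_span_lengths = np.array([6, 4])          # 10 non-noise tokens in 2 spans
noise_span_lengths = np.array([3, 1])             # 4 noise tokens in 2 spans

interleaved = np.reshape(
    np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [4])
# interleaved == [6, 3, 4, 1]: spans alternate non-noise / noise, starting with non-noise
span_starts = np.concatenate([np.full((1,), 0, dtype=np.int32),
                              np.cumsum(interleaved)[:-1]])
# span_starts == [0, 6, 9, 13], so tokens 6-8 and token 13 are noise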
|
@ -16,6 +16,8 @@
|
||||
"""Multitask Finetune style dataset."""
|
||||
|
||||
import time
|
||||
import glob
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -37,43 +39,44 @@ class MTFDataset(torch.utils.data.Dataset):
|
||||
self.name = name
|
||||
|
||||
# Dataset.
|
||||
self.input_indexed_dataset = get_indexed_dataset(data_prefix, is_input=True, data_impl=data_impl, skip_warmup=skip_warmup)
|
||||
self.target_indexed_dataset = get_indexed_dataset(data_prefix, is_input=False, data_impl=data_impl, skip_warmup=skip_warmup)
|
||||
self.packed_indexed_dataset = get_packed_indexed_dataset(data_prefix, data_impl=data_impl, skip_warmup=skip_warmup)
|
||||
|
||||
# Checks
|
||||
assert np.min(documents) >= 0
|
||||
assert np.max(documents) < self.input_indexed_dataset.sizes.shape[0]
|
||||
assert np.max(documents) < self.target_indexed_dataset.sizes.shape[0]
|
||||
assert self.input_indexed_dataset.sizes.shape[0] == self.target_indexed_dataset.sizes.shape[0]
|
||||
assert len(self.packed_indexed_dataset) > 0
|
||||
|
||||
self.length = list(self.packed_indexed_dataset.values())[0].sizes.shape[0]
|
||||
|
||||
assert np.max(documents) < self.length
|
||||
for dataset in self.packed_indexed_dataset.values():
|
||||
assert dataset.sizes.shape[0] == self.length
|
||||
|
||||
def __len__(self):
|
||||
return len(self.input_indexed_dataset)
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
input_tokens = self.input_indexed_dataset.get(idx)
|
||||
target_tokens = self.target_indexed_dataset.get(idx)
|
||||
packed_data = dict()
|
||||
for key, dataset in self.packed_indexed_dataset.items():
|
||||
packed_data[key] = dataset.get(idx)
|
||||
assert len(packed_data[key]) > 0
|
||||
return packed_data
|
||||
|
||||
assert len(input_tokens) > 0
|
||||
assert len(target_tokens) > 0
|
||||
|
||||
return {
|
||||
'input_tokens': input_tokens,
|
||||
'target_tokens': target_tokens,
|
||||
}
|
||||
def get_packed_indexed_dataset(data_prefix: str, data_impl: str, skip_warmup: bool):
|
||||
index_dataset_name = f"{data_prefix}_packed_*_document*"
|
||||
names = glob.glob(index_dataset_name)
|
||||
template = f"{data_prefix}_packed_(.*)_document(.*)"
|
||||
all_field = set()
|
||||
for name in names:
|
||||
fields = re.match(template, name)
|
||||
all_field.add(fields.group(1))
|
||||
packed_dataset = dict()
|
||||
for field in all_field:
|
||||
packed_dataset[field] = get_indexed_dataset_(
|
||||
f"{data_prefix}_packed_{field}_document", data_impl, skip_warmup
|
||||
)
|
||||
return packed_dataset
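A small sketch, with hypothetical file names, of how the glob and regex above recover the packed field names written by the preprocessing step:

import re

names = [
    "./alpaca_packed_input_ids_document.bin",
    "./alpaca_packed_input_ids_document.idx",
    "./alpaca_packed_attention_mask_document.bin",
    "./alpaca_packed_labels_document.bin",
]
template = r"./alpaca_packed_(.*)_document(.*)"

fields = {re.match(template, name).group(1) for name in names}
print(sorted(fields))   # ['attention_mask', 'input_ids', 'labels']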
|
||||
|
||||
def size(self, index):
|
||||
return {
|
||||
'input_tokens': self.input_indexed_dataset.size(index),
|
||||
'target_tokens': self.target_indexed_dataset.size(index),
|
||||
}
|
||||
|
||||
def get_indexed_dataset(data_prefix: str, is_input: bool, data_impl: str, skip_warmup: bool):
|
||||
if is_input:
|
||||
field = "inputs"
|
||||
else:
|
||||
field = "targets"
|
||||
|
||||
return get_indexed_dataset_(f"{data_prefix}_{field}_document", data_impl, skip_warmup)
|
||||
|
||||
def get_indexed_dataset_(path, data_impl, skip_warmup):
|
||||
"""Build indexed dataset."""
|
||||
|
@ -1,205 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Wikipedia dataset from DPR code for ORQA."""
|
||||
|
||||
from abc import ABC
|
||||
import csv
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from ascendspeed import print_rank_0, get_args, get_tokenizer, mpu
|
||||
from ascendspeed.data.biencoder_dataset_utils import make_attention_mask
|
||||
|
||||
def get_open_retrieval_wiki_dataset():
|
||||
args = get_args()
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
|
||||
'evidence',
|
||||
args.evidence_data_path,
|
||||
tokenizer,
|
||||
args.retriever_seq_length)
|
||||
return dataset
|
||||
|
||||
|
||||
def get_open_retrieval_batch(data_iterator):
|
||||
# Items and their type.
|
||||
keys = ['row_id', 'context', 'context_mask', 'context_types',
|
||||
'context_pad_mask']
|
||||
datatype = torch.int64
|
||||
|
||||
# Broadcast data.
|
||||
data = None if data_iterator is None else next(data_iterator)
|
||||
data_b = mpu.broadcast_data(keys, data, datatype)
|
||||
|
||||
# Unpack.
|
||||
row_id = data_b['row_id'].long()
|
||||
context = data_b['context'].long()
|
||||
|
||||
# TODO: make the context mask a binary one
|
||||
context_mask = (data_b['context_mask'] < 0.5)
|
||||
|
||||
context_types = data_b['context_types'].long()
|
||||
context_pad_mask = data_b['context_pad_mask'].long()
|
||||
|
||||
return row_id, context, context_mask, context_types, context_pad_mask
|
||||
|
||||
|
||||
def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
|
||||
"""Build token types and paddings, trim if needed, and pad if needed."""
|
||||
|
||||
title_ids = tokenizer.tokenize(row['title'])
|
||||
context_ids = tokenizer.tokenize(row['text'])
|
||||
|
||||
# Appending the title of the context at front
|
||||
extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids
|
||||
|
||||
context_ids, context_types, context_pad_mask = \
|
||||
build_tokens_types_paddings_from_ids(extended_context_ids,
|
||||
max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
|
||||
|
||||
return context_ids, context_types, context_pad_mask
|
||||
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
|
||||
cls_id, sep_id, pad_id):
|
||||
"""Build token types and paddings, trim if needed, and pad if needed."""
|
||||
enc_ids = []
|
||||
tokentypes_enc = []
|
||||
|
||||
# [CLS].
|
||||
enc_ids.append(cls_id)
|
||||
tokentypes_enc.append(0)
|
||||
|
||||
# A.
|
||||
len_src = len(text_ids)
|
||||
enc_ids.extend(text_ids)
|
||||
tokentypes_enc.extend([0] * len_src)
|
||||
|
||||
# Cap the size.
|
||||
if len(enc_ids) > max_seq_length - 1:
|
||||
enc_ids = enc_ids[0: max_seq_length - 1]
|
||||
tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
|
||||
|
||||
# [SEP].
|
||||
enc_ids.append(sep_id)
|
||||
tokentypes_enc.append(0)
|
||||
|
||||
num_tokens_enc = len(enc_ids)
|
||||
# Padding.
|
||||
padding_length = max_seq_length - len(enc_ids)
|
||||
if padding_length > 0:
|
||||
enc_ids.extend([pad_id] * padding_length)
|
||||
tokentypes_enc.extend([pad_id] * padding_length)
|
||||
|
||||
pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
|
||||
pad_mask = np.array(pad_mask, dtype=np.int64)
|
||||
|
||||
return enc_ids, tokentypes_enc, pad_mask
|
||||
|
||||
|
||||
def build_sample(row_id, context_ids, context_types, context_pad_mask):
|
||||
"""Convert to numpy and return a sample consumed by the batch producer."""
|
||||
|
||||
context_ids = np.array(context_ids, dtype=np.int64)
|
||||
context_types = np.array(context_types, dtype=np.int64)
|
||||
context_mask = make_attention_mask(context_ids, context_ids)
|
||||
|
||||
sample = ({
|
||||
'row_id': row_id,
|
||||
'context': context_ids,
|
||||
'context_mask': context_mask,
|
||||
'context_types': context_types,
|
||||
'context_pad_mask': context_pad_mask
|
||||
})
|
||||
return sample
|
||||
|
||||
|
||||
class OpenRetrievalEvidenceDataset(ABC, Dataset):
|
||||
"""Open Retrieval Evidence dataset class."""
|
||||
|
||||
def __init__(self, task_name, dataset_name, datapath, tokenizer,
|
||||
max_seq_length):
|
||||
# Store inputs.
|
||||
self.task_name = task_name
|
||||
self.dataset_name = dataset_name
|
||||
self.tokenizer = tokenizer
|
||||
self.max_seq_length = max_seq_length
|
||||
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
|
||||
self.dataset_name))
|
||||
# Process the files.
|
||||
print_rank_0(datapath)
|
||||
self.samples, self.id2text = self.process_samples_from_single_path(
|
||||
datapath)
|
||||
|
||||
args = get_args()
|
||||
if args.sample_rate < 1: # subsample
|
||||
k = int(len(self.samples) * args.sample_rate)
|
||||
self.samples = random.sample(self.samples, k)
|
||||
|
||||
print_rank_0(' >> total number of samples: {}'.format(
|
||||
len(self.samples)))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
row = self.samples[idx]
|
||||
|
||||
context_ids, context_types, context_pad_mask = \
|
||||
build_tokens_types_paddings_from_text(row, self.tokenizer,
|
||||
self.max_seq_length)
|
||||
|
||||
sample = build_sample(row['doc_id'],
|
||||
context_ids,
|
||||
context_types,
|
||||
context_pad_mask)
|
||||
return sample
|
||||
|
||||
@staticmethod
|
||||
def process_samples_from_single_path(filename):
|
||||
print_rank_0(' > Processing {} ...'.format(filename))
|
||||
total = 0
|
||||
|
||||
rows = []
|
||||
id2text = {}
|
||||
|
||||
with open(filename) as tsvfile:
|
||||
reader = csv.reader(tsvfile, delimiter='\t')
|
||||
next(reader, None) # skip the headers
|
||||
for row in reader:
|
||||
# file format: doc_id, doc_text, title
|
||||
doc_id = int(row[0])
|
||||
text = row[1]
|
||||
title = row[2]
|
||||
|
||||
rows.append({'doc_id': doc_id,
|
||||
'text': text,
|
||||
'title': title})
|
||||
|
||||
assert doc_id not in id2text
|
||||
id2text[doc_id] = (text, title)
|
||||
|
||||
total += 1
|
||||
if total % 100000 == 0:
|
||||
print_rank_0(' > processed {} rows so far ...'.format(
|
||||
total))
|
||||
|
||||
print_rank_0(' >> processed {} samples.'.format(len(rows)))
|
||||
return rows, id2text
|
50
ascendspeed/data/prompter.py
Normal file
@ -0,0 +1,50 @@
|
||||
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os.path as osp
from dataclasses import dataclass


@dataclass
class AlpacaTemplate:
    system_token = ""
    user_token = "### Instruction:"
    assistant_token = "### Response:"
    end_token = ""
    system = "Below is an instruction that describes a task, paired with an input that provides further context. " \
             "Write a response that appropriately completes the request. " \
             "Please note that you need to think through your response logically and step by step."


class Prompter(object):

    def __init__(self, template, verbose: bool = False):
        self._verbose = verbose
        self.template = template
        self.user_role = "user"
        self.assistant_role = "assistant"

    def generate_training_prompt(self, messages) -> str:
        prompt = self.template.system_token + "\n" + self.template.system + self.template.end_token + "\n"

        for message in messages:
            if message["role"] == self.user_role:
                prompt += self.template.user_token + "\n" + message["content"] + self.template.end_token + "\n"
            else:
                prompt += self.template.assistant_token + "\n" + message["content"] \
                    + self.template.end_token + "\n"

        return prompt
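# Editor's note: a minimal usage sketch of the Prompter/AlpacaTemplate above
# (illustrative only, not part of the new file; the message dicts mirror the
# role/content keys that generate_training_prompt expects).
if __name__ == "__main__":
    prompter = Prompter(AlpacaTemplate())
    messages = [
        {"role": "user", "content": "Explain pipeline parallelism in one sentence."},
        {"role": "assistant", "content": "It splits a model's layers across devices ..."},
    ]
    # Yields the Alpaca system text followed by "### Instruction:" / "### Response:" blocks.
    print(prompter.generate_training_prompt(messages))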
|
@ -1,270 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""T5 Style dataset."""
|
||||
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ascendspeed import get_tokenizer
|
||||
from ascendspeed.data.dataset_utils import (
|
||||
create_masked_lm_predictions,
|
||||
get_samples_mapping
|
||||
)
|
||||
|
||||
class T5Dataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, name, indexed_dataset, data_prefix,
|
||||
num_epochs, max_num_samples, masked_lm_prob,
|
||||
max_seq_length, max_seq_length_dec,
|
||||
short_seq_prob, seed):
|
||||
|
||||
# Params to store.
|
||||
self.name = name
|
||||
self.seed = seed
|
||||
self.masked_lm_prob = masked_lm_prob
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_seq_length_dec = max_seq_length_dec
|
||||
|
||||
# Dataset.
|
||||
self.indexed_dataset = indexed_dataset
|
||||
|
||||
# Build the samples mapping.
|
||||
self.samples_mapping = get_samples_mapping(self.indexed_dataset,
|
||||
data_prefix,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
self.max_seq_length - 2, # account for added tokens
|
||||
short_seq_prob,
|
||||
self.seed,
|
||||
self.name,
|
||||
False)
|
||||
|
||||
# Vocab stuff.
|
||||
tokenizer = get_tokenizer()
|
||||
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
|
||||
self.vocab_id_to_token_dict = tokenizer.inv_vocab
|
||||
self.cls_id = tokenizer.cls
|
||||
self.sep_id = tokenizer.sep
|
||||
self.mask_id = tokenizer.mask
|
||||
self.pad_id = tokenizer.pad
|
||||
self.bos_id = tokenizer.bos_token_id
|
||||
self.eos_id = tokenizer.eos_token_id
|
||||
self.sentinel_tokens = tokenizer.additional_special_tokens_ids
|
||||
assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
|
||||
|
||||
def __len__(self):
|
||||
return self.samples_mapping.shape[0]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
||||
start_index, end_index, seq_length = self.samples_mapping[idx]
|
||||
sample = []
|
||||
for index in range(start_index, end_index):
|
||||
sample.append(self.indexed_dataset[index])
|
||||
# Note that this rng state should be numpy and not python since
|
||||
# python randint is inclusive whereas the numpy one is exclusive.
|
||||
np_rng = np.random.RandomState(seed=(self.seed + idx))
|
||||
return build_training_sample(sample, seq_length,
|
||||
self.max_seq_length, # needed for padding
|
||||
self.max_seq_length_dec,
|
||||
self.vocab_id_list,
|
||||
self.vocab_id_to_token_dict,
|
||||
self.cls_id, self.sep_id,
|
||||
self.mask_id, self.pad_id,
|
||||
self.masked_lm_prob, np_rng,
|
||||
self.bos_id, self.eos_id,
|
||||
self.sentinel_tokens)
|
||||
|
||||
|
||||
def build_training_sample(sample, target_seq_length,
|
||||
max_seq_length, max_seq_length_dec,
|
||||
vocab_id_list, vocab_id_to_token_dict,
|
||||
cls_id, sep_id, mask_id, pad_id,
|
||||
masked_lm_prob, np_rng, bos_id=None,
|
||||
eos_id=None, sentinel_tokens=None):
|
||||
"""Build training sample.
|
||||
|
||||
Arguments:
|
||||
sample: A list of sentences in which each sentence is a list of token ids.
|
||||
target_seq_length: Desired sequence length.
|
||||
max_seq_length: Maximum length of the sequence. All values are padded to
|
||||
this length.
|
||||
vocab_id_list: List of vocabulary ids. Used to pick a random id.
|
||||
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
|
||||
cls_id: Start of example id.
|
||||
sep_id: Separator id.
|
||||
mask_id: Mask token id.
|
||||
pad_id: Padding token id.
|
||||
masked_lm_prob: Probability to mask tokens.
|
||||
np_rng: Random number generator. Note that this rng state should be
numpy and not python since python randint is inclusive for
the upper bound whereas the numpy one is exclusive.
|
||||
bos_id: start of decoder example id
|
||||
eos_id: end of generation id
|
||||
sentinel_tokens: unique value to be substituted for every replaced span
|
||||
"""
|
||||
|
||||
assert target_seq_length <= max_seq_length
|
||||
|
||||
# flatten sentences into one list
|
||||
tokens = [token for sentence in sample for token in sentence]
|
||||
|
||||
# Truncate to `target_sequence_length`.
|
||||
max_num_tokens = target_seq_length
|
||||
truncated = len(tokens) > max_num_tokens
|
||||
tokens = tokens[:max_num_tokens]
|
||||
|
||||
# Masking.
|
||||
max_predictions_per_seq = masked_lm_prob * max_num_tokens
|
||||
(tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions(
|
||||
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
|
||||
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng,
|
||||
max_ngrams=10, geometric_dist=True, masking_style="t5")
|
||||
|
||||
# Padding.
|
||||
tokens_enc, tokens_dec_in, labels, enc_mask, \
|
||||
dec_mask, enc_dec_mask, loss_mask \
|
||||
= pad_and_convert_to_numpy(tokens, masked_positions,
|
||||
masked_labels, pad_id, max_seq_length,
|
||||
max_seq_length_dec, masked_spans,
|
||||
bos_id, eos_id, sentinel_tokens)
|
||||
|
||||
train_sample = {
|
||||
'text_enc': tokens_enc,
|
||||
'text_dec': tokens_dec_in,
|
||||
'labels': labels,
|
||||
'loss_mask': loss_mask,
|
||||
'truncated': int(truncated),
|
||||
'enc_mask': enc_mask,
|
||||
'dec_mask': dec_mask,
|
||||
'enc_dec_mask': enc_dec_mask,
|
||||
}
|
||||
return train_sample
|
||||
|
||||
|
||||
def pad_and_convert_to_numpy(tokens, masked_positions,
|
||||
masked_labels, pad_id,
|
||||
max_seq_length, max_seq_length_dec,
|
||||
masked_spans=None, bos_id=None,
|
||||
eos_id=None, sentinel_tokens=None):
|
||||
"""Pad sequences and convert them to numpy."""
|
||||
|
||||
sentinel_tokens = collections.deque(sentinel_tokens)
|
||||
t5_input = []
|
||||
(t5_decoder_in, t5_decoder_out) = ([bos_id], [])
|
||||
(start_index, end_index) = (0, None)
|
||||
for span in masked_spans:
|
||||
flag = sentinel_tokens.popleft()
|
||||
|
||||
# Append the same tokens in decoder input and output
|
||||
t5_decoder_in.append(flag)
|
||||
t5_decoder_in.extend(span.label)
|
||||
t5_decoder_out.append(flag)
|
||||
t5_decoder_out.extend(span.label)
|
||||
|
||||
end_index = span.index[0]
|
||||
t5_input.extend(tokens[start_index: end_index])
|
||||
t5_input.append(flag)
|
||||
|
||||
# the next start index is the token after the last span token
|
||||
start_index = span.index[-1] + 1
|
||||
|
||||
# Add <eos> token to the t5_decoder_out
|
||||
t5_decoder_out.append(eos_id)
|
||||
|
||||
# Add the remaining tokens to the t5 input
|
||||
t5_input.extend(tokens[start_index:])
|
||||
|
||||
# assert (len(t5_input) - len(masked_spans)) + \
|
||||
# (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)
|
||||
|
||||
# Some checks.
|
||||
|
||||
# Encoder-side padding mask.
|
||||
num_tokens = len(t5_input)
|
||||
padding_length = max_seq_length - num_tokens
|
||||
assert padding_length >= 0
|
||||
assert len(masked_positions) == len(masked_labels)
|
||||
|
||||
# Tokens.
|
||||
filler = [pad_id] * padding_length
|
||||
tokens_enc = np.array(t5_input + filler, dtype=np.int64)
|
||||
|
||||
# Decoder-side padding mask.
|
||||
num_tokens_dec = len(t5_decoder_in)
|
||||
padding_length_dec = max_seq_length_dec - num_tokens_dec
|
||||
assert padding_length_dec >= 0
|
||||
filler_dec = [pad_id] * padding_length_dec
|
||||
tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64)
|
||||
|
||||
# Create attention masks
|
||||
enc_mask = make_attention_mask(tokens_enc, tokens_enc)
|
||||
enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc)
|
||||
dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in)
|
||||
dec_mask = dec_mask * make_history_mask(tokens_dec_in)
|
||||
|
||||
# Labels mask.
|
||||
labels = t5_decoder_out + ([-1] * padding_length_dec)
|
||||
labels = np.array(labels, dtype=np.int64)
|
||||
|
||||
# Loss mask
|
||||
loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec)
|
||||
loss_mask = np.array(loss_mask, dtype=np.int64)
|
||||
|
||||
return tokens_enc, tokens_dec_in, labels, enc_mask, \
|
||||
dec_mask, enc_dec_mask, loss_mask
|
||||
|
||||
|
||||
def make_attention_mask(source_block, target_block):
|
||||
"""
|
||||
Returns a 2-dimensional (2-D) attention mask
|
||||
:param source_block: 1-D array
|
||||
:param target_block: 1-D array
|
||||
"""
|
||||
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
|
||||
mask = mask.astype(np.int64)
|
||||
# (source_length, target_length)
|
||||
return mask
|
||||
|
||||
|
||||
def make_attention_mask_3d(source_block, target_block):
|
||||
"""
|
||||
Returns a 3-dimensional (3-D) attention mask
|
||||
:param source_block: 1-D array
|
||||
:param target_block: 1-D array
|
||||
"""
|
||||
mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1)
|
||||
# (batch, source_length, target_length)
|
||||
# mask = mask.astype(np.int64)
|
||||
return mask
|
||||
|
||||
|
||||
def make_history_mask(block):
|
||||
length = block.shape[0]
|
||||
arange = np.arange(length)
|
||||
history_mask = (arange[None, ] <= arange[:, None])
|
||||
history_mask = history_mask.astype(np.int64)
|
||||
return history_mask
|
||||
|
||||
|
||||
def make_history_mask_3d(block):
|
||||
batch, length = block.shape
|
||||
arange = torch.arange(length, device=block.device)
|
||||
history_mask = (arange[None, ] <= arange[:, None])[None, ]
|
||||
history_mask = history_mask.expand(batch, length, length)
|
||||
return history_mask
|
@ -1,125 +0,0 @@
|
||||
# This file isn't really a formal automated test, it's just a place to
|
||||
# put some code used during development and manual testing of
|
||||
# indexed_dataset.
|
||||
|
||||
from ascendspeed.data import indexed_dataset
|
||||
from ascendspeed.tokenizer import build_tokenizer
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
sys.path.append(os.path.join(script_dir, "../../../"))
|
||||
|
||||
|
||||
def test_indexed_dataset(args):
|
||||
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
|
||||
tokenizer = build_tokenizer(args)
|
||||
print(len(ds.doc_idx))
|
||||
print(len(ds))
|
||||
print(ds.doc_idx[-1])
|
||||
if ds.supports_prefetch:
|
||||
# just prefetch the whole thing in test (so assume it is small)
|
||||
ds.prefetch(range(len(ds)))
|
||||
if args.count > len(ds.doc_idx) - 1:
|
||||
args.count = len(ds.doc_idx) - 1
|
||||
|
||||
for i in range(args.count):
|
||||
start = ds.doc_idx[i]
|
||||
end = ds.doc_idx[i + 1]
|
||||
ids = ds[start:end]
|
||||
print(f"Document {i}:")
|
||||
print("--------------")
|
||||
for s in ids:
|
||||
assert len(s) > 0
|
||||
l = s.data.tolist()
|
||||
text = tokenizer.detokenize(l)
|
||||
print(text)
|
||||
print("---")
|
||||
|
||||
|
||||
def test_indexed_dataset_get(args):
|
||||
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
|
||||
tokenizer = build_tokenizer(args)
|
||||
size = ds.sizes[0]
|
||||
print(f"size: {size}")
|
||||
full = ds.get(0)
|
||||
print(full)
|
||||
# print(tokenizer.detokenize(full.data.tolist()))
|
||||
print("---")
|
||||
end = ds.get(0, offset=size - 10)
|
||||
print(end)
|
||||
# print(tokenizer.detokenize(end.data.tolist()))
|
||||
|
||||
start = ds.get(0, length=10)
|
||||
print(start)
|
||||
# print(tokenizer.detokenize(start.data.tolist()))
|
||||
|
||||
part = ds.get(0, offset=2, length=8)
|
||||
print(part)
|
||||
# print(tokenizer.detokenize(part.data.tolist()))
|
||||
|
||||
# def test_albert_dataset(args):
|
||||
# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
|
||||
# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
|
||||
# # ds = AlbertDataset(idataset, tokenizer)
|
||||
# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
|
||||
# args.epochs, args.max_num_samples,
|
||||
# args.masked_lm_prob, args.seq_length,
|
||||
# args.short_seq_prob, args.seed)
|
||||
# truncated = 0
|
||||
# total = 0
|
||||
# for i, s in enumerate(ds):
|
||||
# ids = s['text']
|
||||
# tokens = ds.tokenizer.convert_ids_to_tokens(ids)
|
||||
# print(tokens)
|
||||
# if i >= args.count-1:
|
||||
# exit()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data', type=str, help='prefix to data files')
|
||||
parser.add_argument('--dataset-impl', type=str, default='infer',
|
||||
choices=['lazy', 'cached', 'mmap', 'infer'])
|
||||
parser.add_argument('--count', type=int, default=10,
|
||||
help='Number of samples/documents to print')
|
||||
|
||||
group = parser.add_argument_group(title='tokenizer')
|
||||
group.add_argument('--tokenizer-type', type=str, required=True,
|
||||
choices=['BertWordPieceLowerCase',
|
||||
'GPT2BPETokenizer'],
|
||||
help='What type of tokenizer to use.')
|
||||
group.add_argument('--vocab-file', type=str, default=None,
|
||||
help='Path to the vocab file')
|
||||
group.add_argument('--merge-file', type=str, default=None,
|
||||
help='Path to the BPE merge file (if necessary).')
|
||||
|
||||
parser.add_argument('--epochs', type=int, default=5,
|
||||
help='Number of epochs to plan for')
|
||||
parser.add_argument('--max-num-samples', type=int, default=None,
|
||||
help='Maximum number of samples to plan for')
|
||||
parser.add_argument('--masked-lm-prob', type=float, default=0.15,
|
||||
help='probability of masking tokens')
|
||||
parser.add_argument('--seq-length', type=int, default=512,
|
||||
help='maximum sequence length')
|
||||
parser.add_argument('--short-seq-prob', type=float, default=0.1,
|
||||
help='probability of creating a short sequence')
|
||||
parser.add_argument('--seed', type=int, default=1234,
|
||||
help='random seed')
|
||||
args = parser.parse_args()
|
||||
args.rank = 0
|
||||
args.make_vocab_size_divisible_by = 128
|
||||
args.tensor_model_parallel_size = 1
|
||||
|
||||
if args.dataset_impl == "infer":
|
||||
args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
|
||||
|
||||
# test_albert_dataset(args)
|
||||
test_indexed_dataset_get(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,10 +0,0 @@
#!/bin/bash

IMPL=cached
python ../preprocess_data.py \
       --input test_samples.json \
       --vocab vocab.txt \
       --dataset-impl ${IMPL} \
       --output-prefix test_samples_${IMPL} \
       --workers 1 \
       --log-interval 2
|
@ -1,58 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import torch
|
||||
from torchvision import datasets, transforms
|
||||
from ascendspeed.data.autoaugment import ImageNetPolicy
|
||||
|
||||
|
||||
def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
|
||||
|
||||
# training dataset
|
||||
train_data_path = os.path.join(data_path[0], "train")
|
||||
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
process = [
|
||||
transforms.RandomResizedCrop(crop_size),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
]
|
||||
if color_jitter:
|
||||
process += [
|
||||
transforms.ColorJitter(
|
||||
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
|
||||
)
|
||||
]
|
||||
fp16_t = transforms.ConvertImageDtype(torch.half)
|
||||
process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
|
||||
transform_train = transforms.Compose(process)
|
||||
train_data = datasets.ImageFolder(
|
||||
root=train_data_path, transform=transform_train
|
||||
)
|
||||
|
||||
# validation dataset
|
||||
val_data_path = os.path.join(data_path[0], "val")
|
||||
transform_val = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(crop_size),
|
||||
transforms.CenterCrop(crop_size),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
fp16_t
|
||||
]
|
||||
)
|
||||
val_data = datasets.ImageFolder(
|
||||
root=val_data_path, transform=transform_val
|
||||
)
|
||||
|
||||
return train_data, val_data
|
70
ascendspeed/data_classes.py
Normal file
@ -0,0 +1,70 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MergeWeightConfig:
|
||||
entire_model_dic: dict
|
||||
tp_models: list
|
||||
k: int
|
||||
pp_i: int
|
||||
tot_i: int
|
||||
dim: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class MergePpTpMergeConfig:
|
||||
pp_size: int
|
||||
tp_size: int
|
||||
input_model_dir: int
|
||||
orig_vocab_size: int
|
||||
num_heads: int
|
||||
num_layer: int
|
||||
hid_size: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenAscendWeightsConfig:
|
||||
tp_size: int
|
||||
pp_size: int
|
||||
output_model_dir: str
|
||||
make_vocab_size_divisible_by: int
|
||||
num_heads: int
|
||||
num_layer: int
|
||||
hid_size: int
|
||||
added_token_num: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenAscendWeightsAgaConfig:
|
||||
tp_size: int
|
||||
pp_size: int
|
||||
model_type: str
|
||||
out_model_dir: str
|
||||
make_vocab_size_divisible_by: int
|
||||
added_token_num: int
|
||||
|
||||
|
||||
@dataclass()
|
||||
class SaveAscendspeedModelConfig:
|
||||
model_dic: dict
|
||||
rank_model: dict
|
||||
pp_size: int
|
||||
tp_rank: int
|
||||
pp_rank: int
|
||||
release_model_dir: str
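# Editor's note: construction sketch for the config dataclasses above (illustrative
# only, not part of the new file; the field values are made up).
if __name__ == "__main__":
    cfg = GenAscendWeightsConfig(tp_size=8, pp_size=1, output_model_dir="./out",
                                 make_vocab_size_divisible_by=128, num_heads=32,
                                 num_layer=32, hid_size=4096, added_token_num=0)
    print(cfg.tp_size, cfg.hid_size)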
|
132
ascendspeed/error_utils.py
Normal file
@ -0,0 +1,132 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
class NotDivisibleError(Exception):
|
||||
def __init__(self, denominator, molecule, error_info):
|
||||
super().__init__()
|
||||
self._error_info = error_info
|
||||
self._molecule = molecule
|
||||
self._denominator = denominator
|
||||
|
||||
def __str__(self):
|
||||
if self._error_info is None:
|
||||
return f"{self._denominator} is not divisible by {self._molecule}"
|
||||
else:
|
||||
return self._error_info.format(self._denominator, self._molecule)
|
||||
|
||||
|
||||
def check_divisible(denominator, molecule, error_info=None):
|
||||
if denominator % molecule == 0:
|
||||
return
|
||||
raise NotDivisibleError(denominator, molecule, error_info)
|
||||
|
||||
|
||||
class NotEqualError(Exception):
|
||||
def __init__(self, tensor_a, tensor_b, error_info):
|
||||
super().__init__()
|
||||
self._error_info = error_info
|
||||
self._tensor_a = tensor_a
|
||||
self._tensor_b = tensor_b
|
||||
|
||||
def __str__(self):
|
||||
if self._error_info is None:
|
||||
return f"{self._tensor_a} is not equal to {self._tensor_b}"
|
||||
else:
|
||||
return self._error_info.format(self._tensor_a, self._tensor_b)
|
||||
|
||||
|
||||
def check_equal(tensor_a, tensor_b, error_info=None):
|
||||
if tensor_a == tensor_b:
|
||||
return
|
||||
raise NotEqualError(tensor_a, tensor_b, error_info)
|
||||
|
||||
|
||||
class NotExistError(Exception):
|
||||
def __init__(self, item, container, error_info):
|
||||
super().__init__()
|
||||
self._error_info = error_info
|
||||
self._item = item
|
||||
self._container = container
|
||||
|
||||
def __str__(self):
|
||||
if self._error_info is None:
|
||||
return f"{self._item} is not in {self._container}"
|
||||
else:
|
||||
return self._error_info.format(self._item, self._container)
|
||||
|
||||
|
||||
def check_exist(item, container, error_info=None):
|
||||
if item in container:
|
||||
return
|
||||
raise NotExistError(item, container, error_info)
|
||||
|
||||
|
||||
class NotExpectedTypeError(Exception):
|
||||
def __init__(self, variable, expected_type, error_message=None):
|
||||
super().__init__()
|
||||
self._variable = variable
|
||||
self._expected_type = expected_type
|
||||
self._error_message = error_message
|
||||
|
||||
def __str__(self):
|
||||
if self._error_message:
|
||||
return self._error_message
|
||||
else:
|
||||
return f"expected {self._expected_type.__name__} find {type(self._variable).__name__}"
|
||||
|
||||
|
||||
def check_type(variable, expected_type, error_message=None):
|
||||
if isinstance(variable, expected_type):
|
||||
return
|
||||
else:
|
||||
raise NotExpectedTypeError(variable, expected_type, error_message)
|
||||
|
||||
|
||||
class IsNotNoneError(Exception):
|
||||
def __init__(self, error_message):
|
||||
super().__init__()
|
||||
self._error_message = error_message
|
||||
|
||||
def __repr__(self):
|
||||
if self._error_message:
|
||||
return self._error_message
|
||||
else:
|
||||
return "expect None variable"
|
||||
|
||||
|
||||
class IsNoneError(Exception):
|
||||
def __init__(self, error_message):
|
||||
super().__init__()
|
||||
self._error_message = error_message
|
||||
|
||||
def __repr__(self):
|
||||
if self._error_message:
|
||||
return self._error_message
|
||||
else:
|
||||
return "expect not None variable"
|
||||
|
||||
|
||||
def ensure_var_is_not_none(variable, error_message=None):
|
||||
if variable is not None:
|
||||
return
|
||||
else:
|
||||
raise IsNoneError(error_message=error_message)
|
||||
|
||||
|
||||
def ensure_var_is_none(variable, error_message=None):
|
||||
if variable is None:
|
||||
return
|
||||
else:
|
||||
raise IsNotNoneError(error_message)
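# Editor's note: a short usage sketch for the validation helpers above (illustrative
# only, not part of the new file; the argument values are made up).
if __name__ == "__main__":
    check_divisible(4096, 32, error_info="hidden size {} must be divisible by head count {}")
    check_equal(2 + 2, 4)
    check_type("tp8", str)
    ensure_var_is_not_none("anything", error_message="value must be provided")
    # Each check raises its dedicated exception type instead of a bare assert,
    # so callers can attach a specific message per failure.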
|
@ -120,10 +120,14 @@ def _build_tokenizer(args):
    return _GLOBAL_TOKENIZER


def rebuild_tokenizer(args):
def rebuild_tokenizer(args, tokenizer=None):
    global _GLOBAL_TOKENIZER
    _GLOBAL_TOKENIZER = None
    return _build_tokenizer(args)
    if tokenizer is None:
        return _build_tokenizer(args)
    else:
        _GLOBAL_TOKENIZER = tokenizer
        return _GLOBAL_TOKENIZER


def _set_tensorboard_writer(args):
|
||||
|
@ -22,6 +22,8 @@ import numpy as np
|
||||
import torch
|
||||
import deepspeed
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from deepspeed.runtime.pipe.engine import PipelineEngine
|
||||
from deepspeed import initialize
|
||||
|
||||
from ascendspeed import get_adlr_autoresume
|
||||
from ascendspeed import get_args
|
||||
@ -38,9 +40,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults=None,
|
||||
"""Set global variables, initialize distributed, and
|
||||
set autoresume and random seeds.
|
||||
`allow_no_cuda` should not be set unless using ascendspeed for cpu only
|
||||
data processing. In general this arg should not be set unless you know
|
||||
data processing. In general this arg should not be set unless you know
|
||||
what you are doing.
|
||||
Returns a function to finalize distributed env initialization
|
||||
Returns a function to finalize distributed env initialization
|
||||
(optionally, only when args.lazy_mpu_init == True)
|
||||
"""
|
||||
if not args_defaults:
|
||||
@ -71,7 +73,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults=None,
|
||||
if args.lazy_mpu_init:
|
||||
args.use_cpu_initialization = True
|
||||
# delayed initialization of DDP-related stuff
|
||||
# We only set basic DDP globals
|
||||
# We only set basic DDP globals
|
||||
set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
|
||||
# and return function for external DDP manager
|
||||
# to call when it has DDP initialized
|
||||
@ -255,3 +257,14 @@ def _is_rank_0():
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def adaptor_deepspeed_initialize(*largs, **kwargs):
|
||||
return_items = initialize(*largs, **kwargs)
|
||||
args = kwargs.get('args')
|
||||
if args is not None:
|
||||
if isinstance(return_items[0], PipelineEngine):
|
||||
return_items[0].is_pipe_partitioned = return_items[0].is_pipe_partitioned and not args.sequence_parallel
|
||||
return_items[0].is_grad_partitioned = return_items[0].is_grad_partitioned and not args.sequence_parallel
|
||||
|
||||
return tuple(return_items)
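# Editor's note: a minimal sketch of calling the wrapper above (illustrative only;
# `model`, `args` and `ds_config` stand in for the caller's objects, and args must be
# passed as a keyword for the sequence-parallel check to see it).
#
#     engine, optimizer, _, lr_scheduler = adaptor_deepspeed_initialize(
#         model=model, args=args, config_params=ds_config)
#
# When the returned engine is a PipelineEngine, is_pipe_partitioned / is_grad_partitioned
# are additionally masked by `not args.sequence_parallel` before the tuple is returned.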
|
||||
|
@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from deepspeed.accelerator.real_accelerator import get_accelerator

if get_accelerator().device_name() == 'cuda':
    from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
else:
@ -22,4 +23,3 @@ from .gpt_model import GPTModel, GPTModelPipe
from .llama_model import LlamaModel, LlamaModelPipe
from .language_model import get_language_model
from .module import Float16Module
from .enums import ModelType
|
||||
|
226
ascendspeed/model/baichuan_model.py
Normal file
@ -0,0 +1,226 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, Huawei CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
import ascendspeed.model.llama_model as llama_model
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed import mpu
|
||||
from ascendspeed.core import parallel_state
|
||||
from ascendspeed.core.enums import AttnMaskType
|
||||
from ascendspeed.model.llama_model import LlamaModel
|
||||
from ascendspeed.model.llama_model import LlamaParallelTransformerLayer
|
||||
from ascendspeed.model.llama_model import RMSNorm
|
||||
from ascendspeed.model.module import MegatronModule
|
||||
|
||||
|
||||
class BaichuanParallelTransformer(MegatronModule):
|
||||
"""Transformer class."""
|
||||
|
||||
def __init__(self, init_method, output_layer_init_method,
|
||||
self_attn_mask_type=AttnMaskType.causal,
|
||||
pre_process=True, post_process=True):
|
||||
|
||||
super(BaichuanParallelTransformer, self).__init__()
|
||||
args = get_args()
|
||||
|
||||
self.bf16 = args.bf16
|
||||
self.fp32_residual_connection = args.fp32_residual_connection
|
||||
self.pre_process = pre_process
|
||||
self.post_process = post_process
|
||||
self.input_tensor = None
|
||||
self.ds_inference = args.ds_inference
|
||||
self.init_method = init_method
|
||||
self.output_layer_init_method = output_layer_init_method
|
||||
|
||||
# Store activation checkpointing flag.
|
||||
self.checkpoint_activations = args.checkpoint_activations
|
||||
self.checkpoint_num_layers = args.checkpoint_num_layers
|
||||
self.checkpoint_policy = args.checkpoint_policy
|
||||
self.checkpoint_block_layer = args.checkpoint_block_layer
|
||||
|
||||
# Number of layers.
|
||||
self.num_layers = args.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
|
||||
|
||||
# Transformer layers.
|
||||
def build_layer(layer_number):
|
||||
return LlamaParallelTransformerLayer(
|
||||
self.init_method,
|
||||
self.output_layer_init_method,
|
||||
layer_number)
|
||||
|
||||
if args.virtual_pipeline_model_parallel_size is not None:
|
||||
# Number of layers in each model chunk is the number of layers in the stage,
|
||||
# divided by the number of model chunks in a stage.
|
||||
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
|
||||
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
|
||||
# layers to stages like (each list is a model chunk):
|
||||
# Stage 0: [0] [2] [4] [6]
|
||||
# Stage 1: [1] [3] [5] [7]
|
||||
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
|
||||
# layers to stages like (each list is a model chunk):
|
||||
# Stage 0: [0, 1] [4, 5]
|
||||
# Stage 1: [2, 3] [6, 7]
|
||||
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
|
||||
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
|
||||
(parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
|
||||
else:
|
||||
# Each stage gets a contiguous set of layers.
|
||||
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
|
||||
|
||||
self.layers = []
|
||||
# Build the layers
|
||||
for i in range(self.num_layers):
|
||||
layer_num = i + 1 + offset
|
||||
self.layers.append(build_layer(layer_num))
|
||||
|
||||
self.layers = torch.nn.ModuleList(self.layers)
|
||||
|
||||
if self.post_process:
|
||||
# Final layer norm before output.
|
||||
self.final_layernorm = RMSNorm(
|
||||
args.hidden_size,
|
||||
eps=args.layernorm_epsilon,
|
||||
sequence_parallel=args.sequence_parallel)
|
||||
|
||||
def set_input_tensor(self, input_tensor):
|
||||
"""Set input tensor to be used instead of forward()'s input.
|
||||
|
||||
When doing pipeline parallelism the input from the previous
|
||||
stage comes from communication, not from the input, so the
|
||||
model's forward_step_func won't have it. This function is thus
|
||||
used by internal code to bypass the input provided by the
|
||||
forward_step_func
|
||||
"""
|
||||
self.input_tensor = input_tensor
|
||||
|
||||
def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False):
|
||||
# Reza's note: DeepSpeed inference does not support transposes
|
||||
if not self.ds_inference:
|
||||
if self.pre_process:
|
||||
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
|
||||
# If the input flag for fp32 residual connection is set, convert for float.
|
||||
if self.fp32_residual_connection:
|
||||
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
|
||||
# Otherwise, leave it as is.
|
||||
else:
|
||||
hidden_states = hidden_states.transpose(0, 1).contiguous()
|
||||
else:
|
||||
# See set_input_tensor()
|
||||
hidden_states = self.input_tensor
|
||||
|
||||
if self.checkpoint_activations and self.checkpoint_policy == 'full':
|
||||
hidden_states = self._checkpointed_forward(hidden_states, attention_mask)
|
||||
elif self.checkpoint_activations and self.checkpoint_policy == 'block':
|
||||
hidden_states = self._checkpointed_forward_block(hidden_states, attention_mask)
|
||||
else:
|
||||
if get_key_value:
|
||||
presents = []
|
||||
for index in range(self.num_layers):
|
||||
layer = self._get_layer(index)
|
||||
past = None
|
||||
if layer_past is not None:
|
||||
past = layer_past[index]
|
||||
hidden_states = layer(hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_past=past,
|
||||
get_key_value=get_key_value)
|
||||
if get_key_value:
|
||||
hidden_states, present = hidden_states
|
||||
presents.append(present)
|
||||
|
||||
# Final layer norm.
|
||||
if self.post_process:
|
||||
if not self.ds_inference:
|
||||
# Reverting data format change [s b h] --> [b s h].
|
||||
hidden_states = hidden_states.transpose(0, 1).contiguous()
|
||||
output = self.final_layernorm(hidden_states)
|
||||
else:
|
||||
output = hidden_states
|
||||
if get_key_value:
|
||||
output = [output, presents]
|
||||
|
||||
return output
|
||||
|
||||
def _get_layer(self, layer_number):
|
||||
return self.layers[layer_number]
|
||||
|
||||
def _checkpointed_forward(self, hidden_states, attention_mask):
|
||||
"""Forward method with activation checkpointing."""
|
||||
|
||||
def custom(start, end):
|
||||
def custom_forward(*inputs):
|
||||
x_ = inputs[0]
|
||||
attention_mask = inputs[1]
|
||||
for index in range(start, end):
|
||||
layer = self._get_layer(index)
|
||||
x_ = layer(x_, attention_mask=attention_mask)
|
||||
return x_
|
||||
|
||||
return custom_forward
|
||||
|
||||
# Make sure memory is freed.
|
||||
mpu.reset_checkpointed_activations_memory_buffer()
|
||||
idx = 0
|
||||
while idx < self.num_layers:
|
||||
hidden_states = mpu.checkpoint(
|
||||
custom(idx, idx + self.checkpoint_num_layers),
|
||||
hidden_states, attention_mask)
|
||||
idx += self.checkpoint_num_layers
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _checkpointed_forward_block(self, hidden_states, attention_mask):
|
||||
"""Forward method with activation checkpointing."""
|
||||
|
||||
def custom(start, end):
|
||||
def custom_forward(*inputs):
|
||||
x_ = inputs[0]
|
||||
attention_mask = inputs[1]
|
||||
for index in range(start, end):
|
||||
layer = self._get_layer(index)
|
||||
x_ = layer(x_, attention_mask=attention_mask)
|
||||
return x_
|
||||
|
||||
return custom_forward
|
||||
|
||||
# Make sure memory is freed.
|
||||
for idx in range(self.num_layers):
|
||||
if idx < self.checkpoint_block_layer:
|
||||
hidden_states = mpu.checkpoint(
|
||||
custom(idx, idx + 1),
|
||||
hidden_states, attention_mask)
|
||||
else:
|
||||
hidden_states = custom(idx, idx + 1)(hidden_states, attention_mask)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BaichuanModel(LlamaModel):
|
||||
"""Baichuan Language model."""
|
||||
|
||||
def __init__(self, config, pre_process, post_process, parallel_output=True, add_pooler=False):
|
||||
super(BaichuanModel, self).__init__(config, pre_process, post_process, parallel_output, add_pooler)
|
||||
# Transformer.
|
||||
self.language_model = BaichuanParallelTransformer(
|
||||
self.init_method,
|
||||
self.output_layer_init_method,
|
||||
self_attn_mask_type=self.self_attn_mask_type,
|
||||
pre_process=self.pre_process,
|
||||
post_process=self.post_process,
|
||||
)
|
||||
|
||||
|
||||
BaichuanModelPipe = llama_model.LlamaModelPipe
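# Editor's note: BaichuanModel reuses LlamaModel wholesale and only swaps in
# BaichuanParallelTransformer, whose extra 'block' checkpoint policy recomputes just the
# first `checkpoint_block_layer` layers; the pipeline variant is aliased from llama_model.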
|
@ -15,6 +15,7 @@
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
@ -27,10 +28,11 @@ from ascendspeed.core import parallel_state
|
||||
|
||||
class MemoryBuffer:
|
||||
|
||||
def __init__(self, numel, dtype):
|
||||
def __init__(self, numel, numel_padded, dtype):
|
||||
self.numel = numel
|
||||
self.numel_padded = numel_padded
|
||||
self.dtype = dtype
|
||||
self.data = torch.zeros(self.numel,
|
||||
self.data = torch.zeros(self.numel_padded,
|
||||
dtype=self.dtype,
|
||||
device=get_accelerator().current_device_name(),
|
||||
requires_grad=False)
|
||||
@ -121,8 +123,11 @@ class DistributedDataParallel(DistributedDataParallelBase):
|
||||
# the case we use continuous buffers.
|
||||
# ===================================
|
||||
self._grad_buffers = None
|
||||
self._grad_buffer_param_index_map = None
|
||||
if self.use_contiguous_buffers:
|
||||
self._grad_buffers = {}
|
||||
self._grad_buffer_param_index_map = {}
|
||||
data_parallel_world_size = parallel_state.get_data_parallel_world_size()
|
||||
|
||||
# Simple function to define buffer type.
|
||||
def _get_buffer_type(param):
|
||||
@ -139,7 +144,18 @@ class DistributedDataParallel(DistributedDataParallelBase):
|
||||
|
||||
# Allocate the buffer.
|
||||
for dtype, num_elements in type_num_elements.items():
|
||||
self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
|
||||
|
||||
# If using distributed optimizer, pad memory buffer to be
|
||||
# multiple of data_parallel_world_size. (This padding is done
|
||||
# due to a constraint with the reduce_scatter op, which requires
|
||||
# all tensors have equal size. See: optimizer.py.)
|
||||
num_elements_padded = data_parallel_world_size * \
|
||||
int(math.ceil(num_elements / data_parallel_world_size))
|
||||
|
||||
# Allocate grad buffer.
|
||||
self._grad_buffers[dtype] = MemoryBuffer(num_elements,
|
||||
num_elements_padded,
|
||||
dtype)
|
||||
|
||||
# Assume the back prop order is reverse the params order,
|
||||
# store the start index for the gradients.
|
||||
@ -149,6 +165,12 @@ class DistributedDataParallel(DistributedDataParallelBase):
|
||||
type_num_elements[dtype] -= param.data.nelement()
|
||||
param.main_grad = self._grad_buffers[dtype].get(
|
||||
param.data.shape, type_num_elements[dtype])
|
||||
if dtype not in self._grad_buffer_param_index_map:
|
||||
self._grad_buffer_param_index_map[dtype] = {}
|
||||
self._grad_buffer_param_index_map[dtype][param] = (
|
||||
type_num_elements[dtype],
|
||||
type_num_elements[dtype] + param.data.nelement(),
|
||||
)
|
||||
|
||||
# Backward hook.
|
||||
# Accumalation function for the gradients. We need
|
||||
@ -164,19 +186,18 @@ class DistributedDataParallel(DistributedDataParallelBase):
|
||||
grad_acc.register_hook(self._make_param_hook(param))
|
||||
self.grad_accs.append(grad_acc)
|
||||
|
||||
|
||||
def _make_param_hook(self, param):
|
||||
"""Create the all-reduce hook for backprop."""
|
||||
# Hook used for back-prop.
|
||||
def param_hook(*unused):
|
||||
# Add the gradient to the buffer.
|
||||
if param.grad.data is not None:
|
||||
if param.grad is not None:
|
||||
# The gradient function of linear layers is fused with GEMMs
|
||||
param.main_grad.add_(param.grad.data)
|
||||
# Now we can deallocate grad memory.
|
||||
param.grad = None
|
||||
return param_hook
|
||||
|
||||
|
||||
def zero_grad_buffer(self):
|
||||
"""Set the grad buffer data to zero. Needs to be called at the
|
||||
begining of each iteration."""
|
||||
@ -184,16 +205,23 @@ class DistributedDataParallel(DistributedDataParallelBase):
|
||||
for _, buffer_ in self._grad_buffers.items():
|
||||
buffer_.zero()
|
||||
|
||||
def broadcast_params(self):
|
||||
for param in self.module.parameters():
|
||||
torch.distributed.broadcast(param.data,
|
||||
src=parallel_state.get_data_parallel_src_rank(),
|
||||
group=parallel_state.get_data_parallel_group())
|
||||
|
||||
def allreduce_gradients(self):
|
||||
def allreduce_gradients(self, async_op=False):
|
||||
"""Reduce gradients across data parallel ranks."""
|
||||
# If we have buffers, simply reduce the data in the buffer.
|
||||
|
||||
handles = []
|
||||
if self._grad_buffers is not None:
|
||||
for _, buffer_ in self._grad_buffers.items():
|
||||
buffer_.data /= parallel_state.get_data_parallel_world_size()
|
||||
torch.distributed.all_reduce(
|
||||
buffer_.data, group=parallel_state.get_data_parallel_group())
|
||||
handle = torch.distributed.all_reduce(
|
||||
buffer_.data, group=parallel_state.get_data_parallel_group(),
|
||||
async_op=async_op)
|
||||
handles.append(handle)
|
||||
else:
|
||||
# Otherwise, bucketize and all-reduce
|
||||
buckets = {}
|
||||
@ -217,3 +245,4 @@ class DistributedDataParallel(DistributedDataParallelBase):
|
||||
for buf, synced in zip(grads, _unflatten_dense_tensors(
|
||||
coalesced, grads)):
|
||||
buf.copy_(synced)
|
||||
return handles
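# Editor's note: when async_op=True the handles returned above must be waited on
# (e.g. `for h in handles: h.wait()`) before the reduced gradients are consumed;
# with the default async_op=False, torch.distributed.all_reduce returns None, so the
# list then only contains None placeholders.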
|
@ -1,36 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum

class ModelType(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2

class LayerType(enum.Enum):
    encoder = 1
    decoder = 2

class AttnType(enum.Enum):
    self_attn = 1
    cross_attn = 2

class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2  # Overrides `attention_mask` to be a lower triangular matrix
    prefix = 3
    # Forces one to pass an `attention_mask` that's 1 if we need to mask.
    # Tensor that can be broadcast to [micro_batch_size, n_head, seq_length, seq_length]
    custom = 4
|
@ -16,7 +16,7 @@
import torch
import torch_npu
import torch.nn as nn
from ascendspeed.model.enums import AttnMaskType
from ascendspeed.core.enums import AttnMaskType


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@ -137,7 +137,7 @@ class NPUFusedScaleMaskSoftmax(nn.Module):
        return (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16
            and 32 < sk <= 2048  # sk must be 32 ~ 2048
            and 32 < sk <= 4096  # sk must be 32 ~ 4096
            and sq % 16 == 0  # sq must be divisible by 16
            and sk % 16 == 0  # sk must be divisible by 16
        )
|
||||
|
@ -15,24 +15,25 @@
|
||||
|
||||
"""GPT-2 model."""
|
||||
|
||||
from functools import partial
|
||||
import torch
|
||||
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec
|
||||
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed.core import tensor_parallel, parallel_state
|
||||
from .module import MegatronModule, fp32_to_float16
|
||||
from ascendspeed.model import LayerNorm
|
||||
from ascendspeed.model.fused_layer_norm import MixedFusedLayerNorm
|
||||
from ascendspeed.model.module import float16_to_fp32
|
||||
from ascendspeed.core.enums import AttnMaskType
|
||||
|
||||
from .enums import AttnMaskType
|
||||
from .language_model import parallel_lm_logits
|
||||
from .language_model import get_language_model
|
||||
from .utils import init_method_normal
|
||||
from .utils import scaled_init_method_normal
|
||||
|
||||
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec
|
||||
from ascendspeed.model import LayerNorm
|
||||
from ascendspeed.model.module import float16_to_fp32
|
||||
from .module import MegatronModule, MegatronModuleForCausalLM, fp32_to_float16
|
||||
from .language_model import EmbeddingPipe
|
||||
from .transformer import ParallelTransformerLayerPipe
|
||||
from .manual_pipe import ManuallyAllocatedPipelineModule
|
||||
|
||||
|
||||
def post_language_model_processing(lm_output, labels, logit_weights,
|
||||
@ -64,17 +65,27 @@ def post_language_model_processing(lm_output, labels, logit_weights,
|
||||
return loss
|
||||
|
||||
|
||||
class GPTModel(MegatronModule):
|
||||
class LayerNormLayer(MegatronModule):
|
||||
def __init__(self, hidden_size, eps):
|
||||
super(LayerNormLayer, self).__init__()
|
||||
self.final_layernorm = torch.nn.LayerNorm(hidden_size, eps)
|
||||
|
||||
def forward(self, norm_input):
|
||||
return self.final_layernorm(norm_input)
|
||||
|
||||
|
||||
class GPTModel(MegatronModule, MegatronModuleForCausalLM):
|
||||
"""GPT-2 Language model."""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
num_tokentypes=0,
|
||||
parallel_output=True,
|
||||
pre_process=True,
|
||||
post_process=True,
|
||||
prefix_lm=False,
|
||||
return_moe_loss=True):
|
||||
super(GPTModel, self).__init__()
|
||||
super(GPTModel, self).__init__(config,)
|
||||
args = get_args()
|
||||
|
||||
self.parallel_output = parallel_output
|
||||
@ -133,7 +144,7 @@ class GPTModel(MegatronModule):
|
||||
self.parallel_output,
|
||||
forward_method_parallel_output,
|
||||
self.fp16_lm_cross_entropy)
|
||||
|
||||
|
||||
if self.return_moe_loss:
|
||||
return (lm_output, *moe_losses)
|
||||
else:
|
||||
@ -215,11 +226,13 @@ def get_cross_entropy(is_prefix: bool):
|
||||
return loss
|
||||
return CrossEntropy
|
||||
|
||||
class GPTModelPipe(PipelineModule,MegatronModule):
|
||||
|
||||
class GPTModelPipe(ManuallyAllocatedPipelineModule, MegatronModule, MegatronModuleForCausalLM):
|
||||
"""GPT-2 Language model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
num_tokentypes=0,
|
||||
parallel_output=True,
|
||||
attn_mask_type: AttnMaskType = AttnMaskType.causal
|
||||
@ -267,13 +280,15 @@ class GPTModelPipe(PipelineModule,MegatronModule):
|
||||
|
||||
for layer_idx in range(args.num_layers):
|
||||
self.specs.append(
|
||||
LayerSpec(ParallelTransformerLayerPipe,
|
||||
LayerSpec(
|
||||
ParallelTransformerLayerPipe,
|
||||
init_method=init_method,
|
||||
output_layer_init_method=scaled_init_method_normal(args.init_method_std,
|
||||
args.num_layers),
|
||||
layer_number=layer_idx,
|
||||
# TODO: Change naming of class from GPT to something that encapsulate prefix lm.
|
||||
self_attn_mask_type=attn_mask_type))
|
||||
self_attn_mask_type=attn_mask_type)
|
||||
)
|
||||
|
||||
# Undo data format change
|
||||
def undo(x):
|
||||
@ -283,10 +298,18 @@ class GPTModelPipe(PipelineModule,MegatronModule):
|
||||
self.specs.append(undo)
|
||||
|
||||
# Final layernorm after transformer layers
|
||||
self.specs.append(
|
||||
LayerSpec(LayerNorm,
|
||||
args.hidden_size,
|
||||
eps=args.layernorm_epsilon))
|
||||
if args.sequence_parallel:
|
||||
self.specs.append(
|
||||
LayerSpec(MixedFusedLayerNorm,
|
||||
args.hidden_size,
|
||||
eps=args.layernorm_epsilon,
|
||||
sequence_parallel=args.sequence_parallel))
|
||||
|
||||
else:
|
||||
self.specs.append(
|
||||
LayerSpec(LayerNormLayer,
|
||||
args.hidden_size,
|
||||
eps=args.layernorm_epsilon))
|
||||
|
||||
def _logits_helper(embedding, lm_output):
|
||||
"""A wrapper to massage inputs/outputs from pipeline. """
|
||||
|
158
ascendspeed/model/internlm_model.py
Normal file
@ -0,0 +1,158 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed import mpu
|
||||
from ascendspeed.core import parallel_state, utils
|
||||
from ascendspeed.model.module import MegatronModule
|
||||
from ascendspeed.core.enums import AttnMaskType, AttnType
|
||||
from ascendspeed.model.utils import attention_mask_func
|
||||
from ascendspeed.model.fused_softmax import NPUFusedScaleMaskSoftmax
|
||||
from ascendspeed.model.triangle_attention import TriangleAttention
|
||||
from ascendspeed.model import llama_model
|
||||
|
||||
|
||||
class InternLMParallelAttention(MegatronModule):
|
||||
"""Parallel self-attention layer abstract class.
|
||||
|
||||
Self-attention layer takes input with size [b, s, h]
|
||||
and returns output of the same size.
|
||||
"""
|
||||
|
||||
def __init__(self, init_method,
|
||||
output_layer_init_method, layer_number,
|
||||
attention_type=AttnType.self_attn,
|
||||
attn_mask_type=AttnMaskType.causal):
|
||||
super(InternLMParallelAttention, self).__init__()
|
||||
args = get_args()
|
||||
self.fp16 = args.fp16
|
||||
self.bf16 = args.bf16
|
||||
self.sequence_parallel = args.sequence_parallel
|
||||
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
|
||||
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
|
||||
if self.apply_query_key_layer_scaling:
|
||||
self.attention_softmax_in_fp32 = True
|
||||
self.layer_number = max(1, layer_number)
|
||||
self.attention_type = attention_type
|
||||
self.attn_mask_type = attn_mask_type
|
||||
self.init_method = init_method
|
||||
self.output_layer_init_method = output_layer_init_method
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
projection_size = args.kv_channels * args.num_attention_heads
|
||||
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||
self.hidden_size_per_partition = utils.divide(projection_size, world_size)
|
||||
self.hidden_size_per_attention_head = utils.divide(
|
||||
projection_size, args.num_attention_heads)
|
||||
self.num_attention_heads_per_partition = utils.divide(
|
||||
args.num_attention_heads, world_size)
|
||||
if attention_type == AttnType.self_attn:
|
||||
self.query_key_value = mpu.ColumnParallelLinear(
|
||||
args.hidden_size, 3 * projection_size, bias=True, gather_output=False,
|
||||
init_method=self.init_method, sequence_parallel_enabled=self.sequence_parallel)
|
||||
|
||||
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
||||
|
||||
self.scale_mask_softmax = NPUFusedScaleMaskSoftmax(
|
||||
self.fp16, self.bf16,
|
||||
self.attn_mask_type,
|
||||
args.masked_softmax_fusion,
|
||||
attention_mask_func,
|
||||
self.attention_softmax_in_fp32,
|
||||
(1 / self.norm_factor))
|
||||
|
||||
self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head)
|
||||
self.use_triangle_attn = args.triangle_attn
|
||||
if self.use_triangle_attn:
|
||||
self.triangle_attn = TriangleAttention(block_size=1024,
|
||||
masked_softmax_func=self.scale_mask_softmax)
|
||||
self.dense = mpu.RowParallelLinear(
|
||||
projection_size, args.hidden_size, bias=True, input_is_parallel=True,
|
||||
init_method=self.output_layer_init_method, skip_bias_add=False,
|
||||
sequence_parallel_enabled=self.sequence_parallel)
|
||||
|
||||
def forward(self, hidden_states, attention_mask, layer_past=None,
|
||||
get_key_value=False):
|
||||
|
||||
if self.attention_type == AttnType.self_attn:
|
||||
mixed_x_layer, _ = self.query_key_value(hidden_states)
|
||||
|
||||
new_tensor_shape = mixed_x_layer.size()[:-1] + \
|
||||
(self.num_attention_heads_per_partition,
|
||||
3 * self.hidden_size_per_attention_head)
|
||||
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
||||
(query_layer,
|
||||
key_layer,
|
||||
value_layer) = utils.split_tensor_along_last_dim(mixed_x_layer, 3)
|
||||
|
||||
query_layer = query_layer.permute(1, 2, 0, 3).contiguous()
|
||||
key_layer = key_layer.permute(1, 2, 0, 3).contiguous()
|
||||
value_layer = value_layer.permute(1, 2, 0, 3).contiguous()
|
||||
|
||||
cos, sin = self.rotary_emb(value_layer, seq_len=new_tensor_shape[0])
|
||||
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, offset=0)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
key_layer = torch.cat((past_key.type_as(key_layer),
|
||||
key_layer), dim=0)
|
||||
value_layer = torch.cat((past_value.type_as(value_layer),
|
||||
value_layer), dim=0)
|
||||
if get_key_value:
|
||||
present = (key_layer, value_layer)
|
||||
|
||||
if self.use_triangle_attn and layer_past is None:
|
||||
context_layer = self.triangle_attn(query_layer, key_layer, value_layer, attention_mask)
|
||||
output, _ = self.dense(context_layer)
|
||||
return output
|
||||
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(3, 2))
|
||||
if get_key_value:
|
||||
with torch.no_grad():
|
||||
if layer_past is not None:
|
||||
attention_mask = attention_mask[
|
||||
...,
|
||||
attention_scores.size(3) - 1,
|
||||
:attention_scores.size(3)].unsqueeze(2)
|
||||
else:
|
||||
attention_mask = attention_mask[
|
||||
...,
|
||||
:attention_scores.size(3),
|
||||
:attention_scores.size(3)]
|
||||
attention_probs = self.scale_mask_softmax(attention_scores,
|
||||
attention_mask)
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
bs, nh, sq, hd = context_layer.shape
|
||||
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
||||
context_layer = context_layer.view(sq, bs, nh * hd)
|
||||
|
||||
output, _ = self.dense(context_layer)
|
||||
|
||||
if get_key_value:
|
||||
output = [output, present]
|
||||
|
||||
return output
|
||||
|
||||
RotaryEmbedding = llama_model.RotaryEmbedding
|
||||
apply_rotary_pos_emb = llama_model.apply_rotary_pos_emb
|
||||
llama_model.LlamaParallelAttention = InternLMParallelAttention
|
||||
InternModel = llama_model.LlamaModel
|
||||
InternModelPipe = llama_model.LlamaModelPipe
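Because InternLM only differs from Llama in the bias terms on the attention projections, the aliases above reuse the Llama implementation wholesale after swapping in InternLMParallelAttention. A minimal sanity-check sketch (illustrative only, not part of the patch):

# once this module has been imported, llama_model builds InternLM-style attention
from ascendspeed.model import llama_model as patched_llama

assert patched_llama.LlamaParallelAttention is InternLMParallelAttention
intern_model_cls = InternModel        # llama_model.LlamaModel with the patched attention
intern_pipe_cls = InternModelPipe     # pipeline-parallel variant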
|
@ -20,7 +20,7 @@ import torch.nn.functional as F
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed import mpu
|
||||
from ascendspeed.enums import PositionEmbeddingType
|
||||
from ascendspeed.model.enums import LayerType, AttnMaskType
|
||||
from ascendspeed.core.enums import LayerType, AttnMaskType
|
||||
from ascendspeed.model.module import MegatronModule
|
||||
from ascendspeed.model.transformer import ParallelTransformer
|
||||
from ascendspeed.model.utils import get_linear_layer
|
||||
@ -316,7 +316,6 @@ class EmbeddingPipe(Embedding):
|
||||
if hasattr(self._args, 'attn_mask'):
|
||||
return embeddings
|
||||
else:
|
||||
assert False
|
||||
return embeddings, attention_mask
|
||||
|
||||
|
||||
|
@ -29,8 +29,8 @@ from deepspeed.pipe import PipelineModule, LayerSpec
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed import mpu
|
||||
from ascendspeed.core import tensor_parallel, parallel_state, utils
|
||||
from ascendspeed.model.module import MegatronModule, float16_to_fp32, fp32_to_float16
|
||||
from ascendspeed.model.enums import AttnMaskType, AttnType
|
||||
from ascendspeed.model.module import MegatronModule, MegatronModuleForCausalLM, float16_to_fp32, fp32_to_float16
|
||||
from ascendspeed.core.enums import AttnMaskType, AttnType
|
||||
from ascendspeed.model.utils import init_method_normal, scaled_init_method_normal, attention_mask_func
|
||||
from ascendspeed.mpu.mappings import scatter_to_sequence_parallel_region
|
||||
from ascendspeed.model.fused_softmax import NPUFusedScaleMaskSoftmax
|
||||
@ -83,13 +83,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module): # for cpu
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
def __init__(self, hidden_size, eps=1e-6, sequence_parallel=False):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
setattr(self.weight, 'sequence_parallel', sequence_parallel)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
@ -119,7 +120,7 @@ class Llama2LMHead(MegatronModule):
|
||||
self.hidden_size = hidden_size
|
||||
self.init_method = init_method
|
||||
self.parallel_output = parallel_output
|
||||
|
||||
self.sequence_parallel = args.sequence_parallel
|
||||
self.lm_head = mpu.ColumnParallelLinear(input_size=self.hidden_size,
|
||||
output_size=vocab_size,
|
||||
bias=False,
|
||||
@ -129,9 +130,9 @@ class Llama2LMHead(MegatronModule):
|
||||
sequence_parallel_enabled=args.sequence_parallel)
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs = inputs.transpose(0, 1).contiguous()
|
||||
inputs = inputs.transpose(0, 1).contiguous() if self.sequence_parallel else inputs
|
||||
logits, _ = self.lm_head(inputs)
|
||||
logits = logits.transpose(0, 1).contiguous() # SBH-->BSH
|
||||
logits = logits.transpose(0, 1).contiguous() if self.sequence_parallel else logits
|
||||
return logits
|
||||
|
||||
|
||||
@ -211,14 +212,14 @@ class Llama2EmbeddingPipe(Llama2Embedding):
|
||||
if hasattr(self._args, 'attn_mask'):
|
||||
attention_mask = None
|
||||
else:
|
||||
attention_mask = inputs[1]
|
||||
attention_mask = inputs[-1]
|
||||
|
||||
embeddings = super().forward(input_ids)
|
||||
# If cmd args has attn_mask, we don't forward it as an activation.
|
||||
if hasattr(self._args, 'attn_mask'):
|
||||
return embeddings
|
||||
else:
|
||||
return embeddings, attention_mask
|
||||
if not hasattr(self._args, 'attn_mask'):
|
||||
setattr(self._args, 'attn_mask', attention_mask)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class Llama2ParallelMLP(MegatronModule):
|
||||
@ -334,11 +335,7 @@ class Llama2ParallelAttention(MegatronModule):
|
||||
init_method=self.init_method,
|
||||
sequence_parallel_enabled=self.sequence_parallel)
|
||||
|
||||
coeff = None
|
||||
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
||||
if self.apply_query_key_layer_scaling:
|
||||
coeff = self.layer_number
|
||||
self.norm_factor *= coeff
|
||||
|
||||
self.scale_mask_softmax = NPUFusedScaleMaskSoftmax(
|
||||
self.fp16, self.bf16,
|
||||
@ -346,7 +343,7 @@ class Llama2ParallelAttention(MegatronModule):
|
||||
args.masked_softmax_fusion,
|
||||
attention_mask_func,
|
||||
self.attention_softmax_in_fp32,
|
||||
coeff)
|
||||
(1 / self.norm_factor))
|
||||
|
||||
## Rotary Position Embedding
|
||||
self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head)
|
||||
@ -417,7 +414,6 @@ class Llama2ParallelAttention(MegatronModule):
|
||||
# Raw attention scores. [b, np, s, s]
|
||||
# ===================================
|
||||
|
||||
query_layer *= (1.0 / self.norm_factor)
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(3, 2))
|
||||
|
||||
# ==================================================
|
||||
@ -486,7 +482,8 @@ class Llama2ParallelTransformerLayer(MegatronModule):
|
||||
# Layernorm on the input data.
|
||||
self.input_layernorm = RMSNorm(
|
||||
args.hidden_size,
|
||||
eps=args.layernorm_epsilon)
|
||||
eps=args.layernorm_epsilon,
|
||||
sequence_parallel=args.sequence_parallel)
|
||||
|
||||
# Self attention.
|
||||
self.attention = Llama2ParallelAttention(
|
||||
@ -498,7 +495,8 @@ class Llama2ParallelTransformerLayer(MegatronModule):
|
||||
# Layernorm on the attention output
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
args.hidden_size,
|
||||
eps=args.layernorm_epsilon)
|
||||
eps=args.layernorm_epsilon,
|
||||
sequence_parallel=args.sequence_parallel)
|
||||
|
||||
# MLP
|
||||
self.rank = args.rank
|
||||
@ -642,7 +640,8 @@ class Llama2ParallelTransformer(MegatronModule):
|
||||
# Final layer norm before output.
|
||||
self.final_layernorm = RMSNorm(
|
||||
args.hidden_size,
|
||||
eps=args.layernorm_epsilon)
|
||||
eps=args.layernorm_epsilon,
|
||||
sequence_parallel=args.sequence_parallel)
|
||||
|
||||
if deepspeed.checkpointing.is_configured():
|
||||
global get_cuda_rng_tracker, checkpoint
|
||||
@ -685,7 +684,10 @@ class Llama2ParallelTransformer(MegatronModule):
|
||||
model's forward_step_func won't have it. This function is thus
|
||||
used by internal code to bypass the input provided by the
|
||||
forward_step_func"""
|
||||
self.input_tensor = input_tensor
|
||||
if isinstance(input_tensor, (list, tuple)):
|
||||
self.input_tensor = input_tensor[0]
|
||||
else:
|
||||
self.input_tensor = input_tensor
|
||||
|
||||
def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False):
|
||||
|
||||
@ -755,7 +757,7 @@ def CrossEntropy(output, labels):
|
||||
return loss
|
||||
|
||||
|
||||
class Llama2ModelPipe(PipelineModule, MegatronModule):
|
||||
class Llama2ModelPipe(PipelineModule, MegatronModule, MegatronModuleForCausalLM):
|
||||
"""llama Language model."""
|
||||
|
||||
def __init__(self, parallel_output=True):
|
||||
@ -797,7 +799,9 @@ class Llama2ModelPipe(PipelineModule, MegatronModule):
|
||||
self.specs.append(lambda x: x.transpose(0, 1).contiguous())
|
||||
|
||||
# Final layernorm after transformer layers
|
||||
self.specs.append(LayerSpec(RMSNorm, args.hidden_size, eps=args.layernorm_epsilon))
|
||||
self.specs.append(
|
||||
LayerSpec(RMSNorm, args.hidden_size, eps=args.layernorm_epsilon,
|
||||
sequence_parallel=args.sequence_parallel))
|
||||
|
||||
self.specs.append(
|
||||
LayerSpec(Llama2LMHeadPipe, hidden_size=args.hidden_size, vocab_size=args.padded_vocab_size,
|
||||
@ -825,11 +829,11 @@ class Llama2ModelPipe(PipelineModule, MegatronModule):
|
||||
partition_method='type:transformer')
|
||||
|
||||
|
||||
class Llama2Model(MegatronModule):
|
||||
class Llama2Model(MegatronModule, MegatronModuleForCausalLM):
|
||||
"""llaMA2 Language model."""
|
||||
|
||||
def __init__(self, pre_process, post_process, parallel_output=True, add_pooler=False):
|
||||
super(Llama2Model, self).__init__(share_word_embeddings=False)
|
||||
def __init__(self, config, pre_process, post_process, parallel_output=True, add_pooler=False, **kwargs):
|
||||
super(Llama2Model, self).__init__(config=config, share_word_embeddings=False)
|
||||
args = get_args()
|
||||
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
|
||||
self.hidden_size = args.hidden_size
|
||||
@ -844,8 +848,8 @@ class Llama2Model(MegatronModule):
|
||||
|
||||
if self.pre_process:
|
||||
self.embedding = Llama2Embedding(hidden_size=args.hidden_size,
|
||||
init_method=self.init_method,
|
||||
vocab_size=self.padded_vocab_size)
|
||||
init_method=self.init_method,
|
||||
vocab_size=self.padded_vocab_size)
|
||||
|
||||
# Transformer.
|
||||
self.language_model = Llama2ParallelTransformer(
|
||||
|
@ -1,4 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2023, Huawei Technologies Co., Ltd. All rights reserved.
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -19,26 +20,31 @@ Following implementation from huggingface, https://github.com/huggingface/transf
|
||||
"""
|
||||
|
||||
import math
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
try:
|
||||
import torch_npu
|
||||
except ImportError:
|
||||
logging.warning("Import torch_npu Error.")
|
||||
|
||||
import torch.nn.functional as F
|
||||
import deepspeed
|
||||
from deepspeed.pipe import PipelineModule, LayerSpec
|
||||
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed import mpu
|
||||
from ascendspeed.core import tensor_parallel, parallel_state, utils
|
||||
from ascendspeed.model.module import MegatronModule, float16_to_fp32, fp32_to_float16
|
||||
from ascendspeed.model.enums import AttnMaskType, LayerType, AttnType
|
||||
from ascendspeed.model.utils import get_linear_layer, init_method_normal, scaled_init_method_normal, attention_mask_func, \
|
||||
from ascendspeed.model.module import MegatronModule, MegatronModuleForCausalLM, float16_to_fp32, fp32_to_float16
|
||||
from ascendspeed.core.enums import AttnMaskType, LayerType, AttnType
|
||||
from ascendspeed.model.utils import get_linear_layer, init_method_normal, scaled_init_method_normal, \
|
||||
attention_mask_func, \
|
||||
openai_gelu, erf_gelu
|
||||
|
||||
from ascendspeed.mpu.mappings import scatter_to_sequence_parallel_region
|
||||
from ascendspeed.model.fused_softmax import NPUFusedScaleMaskSoftmax
|
||||
from ascendspeed.model.language_model import Pooler
|
||||
|
||||
import deepspeed
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from deepspeed.pipe import PipelineModule, LayerSpec
|
||||
from ascendspeed.model.triangle_attention import TriangleAttention
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
@ -86,18 +92,27 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def apply_fused_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
|
||||
return torch_npu.npu_rotary_mul(q, cos, sin), torch_npu.npu_rotary_mul(k, cos, sin)
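For reference, the fused kernel above performs the same rotation as the eager apply_rotary_pos_emb path; a minimal sketch of that eager formulation, assuming the usual rotate_half helper, is:

def rotate_half_reference(x):
    # swap the two halves of the head dimension with a sign flip
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_reference(q, k, cos, sin):
    # the rotation that torch_npu.npu_rotary_mul fuses into a single NPU kernel
    return (q * cos) + (rotate_half_reference(q) * sin), (k * cos) + (rotate_half_reference(k) * sin)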
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module): # for cpu
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
def __init__(self, hidden_size, eps=1e-6, sequence_parallel=False):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
setattr(self.weight, 'sequence_parallel', sequence_parallel)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon).half()
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
hidden_states = self.weight * hidden_states
|
||||
|
||||
return hidden_states
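A minimal standalone sketch of the normalization computed above (fp32 weight assumed for simplicity):

import torch

hidden = torch.randn(2, 16, 4096)                 # [batch, seq, hidden]
weight = torch.ones(4096)
variance = hidden.pow(2).mean(-1, keepdim=True)
normed = weight * hidden * torch.rsqrt(variance + 1e-6)   # matches RMSNorm(4096)(hidden)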
|
||||
@ -123,7 +138,7 @@ class LlamaLMHead(MegatronModule):
|
||||
self.hidden_size = hidden_size
|
||||
self.init_method = init_method
|
||||
self.parallel_output = parallel_output
|
||||
|
||||
self.sequence_parallel = args.sequence_parallel
|
||||
self.lm_head = mpu.ColumnParallelLinear(input_size=self.hidden_size,
|
||||
output_size=vocab_size,
|
||||
bias=False,
|
||||
@ -133,9 +148,9 @@ class LlamaLMHead(MegatronModule):
|
||||
sequence_parallel_enabled=args.sequence_parallel)
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs = inputs.transpose(0, 1).contiguous()
|
||||
inputs = inputs.transpose(0, 1).contiguous() if self.sequence_parallel else inputs
|
||||
logits, _ = self.lm_head(inputs)
|
||||
logits = logits.transpose(0, 1).contiguous() # SBH-->BSH
|
||||
logits = logits.transpose(0, 1).contiguous() if self.sequence_parallel else logits
|
||||
return logits
|
||||
|
||||
|
||||
@ -215,14 +230,14 @@ class LlamaEmbeddingPipe(LlamaEmbedding):
|
||||
if hasattr(self._args, 'attn_mask'):
|
||||
attention_mask = None
|
||||
else:
|
||||
attention_mask = inputs[1]
|
||||
attention_mask = inputs[-1]
|
||||
|
||||
embeddings = super().forward(input_ids)
|
||||
# If cmd args has attn_mask, we don't forward it as an activation.
|
||||
if hasattr(self._args, 'attn_mask'):
|
||||
return embeddings
|
||||
else:
|
||||
return embeddings, attention_mask
|
||||
if not hasattr(self._args, 'attn_mask'):
|
||||
setattr(self._args, 'attn_mask', attention_mask)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class LlamaParallelMLP(MegatronModule):
|
||||
@ -320,12 +335,9 @@ class LlamaParallelAttention(MegatronModule):
|
||||
|
||||
# Per attention head and per partition values.
|
||||
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||
self.hidden_size_per_partition = utils.divide(projection_size,
|
||||
world_size)
|
||||
self.hidden_size_per_attention_head = utils.divide(
|
||||
projection_size, args.num_attention_heads)
|
||||
self.num_attention_heads_per_partition = utils.divide(
|
||||
args.num_attention_heads, world_size)
|
||||
self.hidden_size_per_partition = utils.divide(projection_size, world_size)
|
||||
self.hidden_size_per_attention_head = utils.divide(projection_size, args.num_attention_heads)
|
||||
self.num_attention_heads_per_partition = utils.divide(args.num_attention_heads, world_size)
|
||||
|
||||
# Strided linear layer.
|
||||
if attention_type == AttnType.self_attn:
|
||||
@ -337,11 +349,7 @@ class LlamaParallelAttention(MegatronModule):
|
||||
init_method=self.init_method,
|
||||
sequence_parallel_enabled=self.sequence_parallel)
|
||||
|
||||
coeff = None
|
||||
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
||||
if self.apply_query_key_layer_scaling:
|
||||
coeff = self.layer_number
|
||||
self.norm_factor *= coeff
|
||||
|
||||
self.scale_mask_softmax = NPUFusedScaleMaskSoftmax(
|
||||
self.fp16, self.bf16,
|
||||
@ -349,11 +357,18 @@ class LlamaParallelAttention(MegatronModule):
|
||||
args.masked_softmax_fusion,
|
||||
attention_mask_func,
|
||||
self.attention_softmax_in_fp32,
|
||||
coeff)
|
||||
(1 / self.norm_factor))
|
||||
|
||||
## Rotary Position Embedding
|
||||
self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head)
|
||||
self.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||
if args.use_fused_rotary_pos_emb:
|
||||
self.apply_rotary_pos_emb = apply_fused_rotary_pos_emb
|
||||
|
||||
self.use_triangle_attn = args.triangle_attn
|
||||
if self.use_triangle_attn:
|
||||
self.triangle_attn = TriangleAttention(block_size=1024,
|
||||
masked_softmax_func=self.scale_mask_softmax)
|
||||
# Output.
|
||||
self.dense = mpu.RowParallelLinear(
|
||||
projection_size,
|
||||
@ -400,8 +415,7 @@ class LlamaParallelAttention(MegatronModule):
|
||||
value_layer = value_layer.permute(1, 2, 0, 3).contiguous()
|
||||
|
||||
cos, sin = self.rotary_emb(value_layer, seq_len=new_tensor_shape[0])
|
||||
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, offset=0)
|
||||
|
||||
query_layer, key_layer = self.apply_rotary_pos_emb(query_layer, key_layer, cos, sin, offset=0)
|
||||
|
||||
# ==================================
|
||||
# Adjust key and value for inference
|
||||
@ -416,11 +430,15 @@ class LlamaParallelAttention(MegatronModule):
|
||||
if get_key_value:
|
||||
present = (key_layer, value_layer)
|
||||
|
||||
# use triangle attention
|
||||
if self.use_triangle_attn and layer_past is None:
|
||||
context_layer = self.triangle_attn(query_layer, key_layer, value_layer, attention_mask)
|
||||
output, _ = self.dense(context_layer)
|
||||
return output
|
||||
|
||||
# ===================================
|
||||
# Raw attention scores. [b, np, s, s]
|
||||
# ===================================
|
||||
|
||||
query_layer *= (1.0 / self.norm_factor)
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(3, 2))
|
||||
|
||||
# ==================================================
|
||||
@ -447,6 +465,8 @@ class LlamaParallelAttention(MegatronModule):
|
||||
# attention scores and attention mask [b, np, sq, sk]
|
||||
attention_probs = self.scale_mask_softmax(attention_scores,
|
||||
attention_mask)
|
||||
if self.bf16:
|
||||
attention_probs = attention_probs.bfloat16()
|
||||
|
||||
# =========================
|
||||
# Context layer. [sq, b, hp]
|
||||
@ -455,7 +475,7 @@ class LlamaParallelAttention(MegatronModule):
|
||||
|
||||
bs, nh, sq, hd = context_layer.shape
|
||||
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
||||
context_layer = context_layer.view(sq, bs, nh*hd)
|
||||
context_layer = context_layer.view(sq, bs, nh * hd)
|
||||
|
||||
output, _ = self.dense(context_layer)
|
||||
|
||||
@ -489,7 +509,8 @@ class LlamaParallelTransformerLayer(MegatronModule):
|
||||
# Layernorm on the input data.
|
||||
self.input_layernorm = RMSNorm(
|
||||
args.hidden_size,
|
||||
eps=args.layernorm_epsilon)
|
||||
eps=args.layernorm_epsilon,
|
||||
sequence_parallel=args.sequence_parallel)
|
||||
|
||||
# Self attention.
|
||||
self.attention = LlamaParallelAttention(
|
||||
@ -501,7 +522,8 @@ class LlamaParallelTransformerLayer(MegatronModule):
|
||||
# Layernorm on the attention output
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
args.hidden_size,
|
||||
eps=args.layernorm_epsilon)
|
||||
eps=args.layernorm_epsilon,
|
||||
sequence_parallel=args.sequence_parallel)
|
||||
|
||||
# MLP
|
||||
self.rank = args.rank
|
||||
@ -645,7 +667,8 @@ class LlamaParallelTransformer(MegatronModule):
|
||||
# Final layer norm before output.
|
||||
self.final_layernorm = RMSNorm(
|
||||
args.hidden_size,
|
||||
eps=args.layernorm_epsilon)
|
||||
eps=args.layernorm_epsilon,
|
||||
sequence_parallel=args.sequence_parallel)
|
||||
|
||||
if deepspeed.checkpointing.is_configured():
|
||||
global get_cuda_rng_tracker, checkpoint
|
||||
@ -688,7 +711,10 @@ class LlamaParallelTransformer(MegatronModule):
|
||||
model's forward_step_func won't have it. This function is thus
|
||||
used by internal code to bypass the input provided by the
|
||||
forward_step_func"""
|
||||
self.input_tensor = input_tensor
|
||||
if isinstance(input_tensor, (list, tuple)):
|
||||
self.input_tensor = input_tensor[0]
|
||||
else:
|
||||
self.input_tensor = input_tensor
|
||||
|
||||
def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False):
|
||||
|
||||
@ -758,10 +784,10 @@ def CrossEntropy(output, labels):
|
||||
return loss
|
||||
|
||||
|
||||
class LlamaModelPipe(PipelineModule, MegatronModule):
|
||||
class LlamaModelPipe(PipelineModule, MegatronModule, MegatronModuleForCausalLM):
|
||||
"""llama Language model."""
|
||||
|
||||
def __init__(self, parallel_output=True):
|
||||
def __init__(self, config, parallel_output=True):
|
||||
args = get_args()
|
||||
|
||||
self.init_method = init_method_normal(args.init_method_std)
|
||||
@ -800,7 +826,9 @@ class LlamaModelPipe(PipelineModule, MegatronModule):
|
||||
self.specs.append(lambda x: x.transpose(0, 1).contiguous())
|
||||
|
||||
# Final layernorm after transformer layers
|
||||
self.specs.append(LayerSpec(RMSNorm, args.hidden_size, eps=args.layernorm_epsilon))
|
||||
self.specs.append(
|
||||
LayerSpec(RMSNorm, args.hidden_size, eps=args.layernorm_epsilon,
|
||||
sequence_parallel=args.sequence_parallel))
|
||||
|
||||
self.specs.append(
|
||||
LayerSpec(LlamaLMHeadPipe, hidden_size=args.hidden_size, vocab_size=args.padded_vocab_size,
|
||||
@ -828,11 +856,11 @@ class LlamaModelPipe(PipelineModule, MegatronModule):
|
||||
partition_method='type:transformer')
|
||||
|
||||
|
||||
class LlamaModel(MegatronModule):
|
||||
class LlamaModel(MegatronModule, MegatronModuleForCausalLM):
|
||||
"""llama Language model."""
|
||||
|
||||
def __init__(self, pre_process, post_process, parallel_output=True, add_pooler=False):
|
||||
super(LlamaModel, self).__init__(share_word_embeddings=False)
|
||||
def __init__(self, config, pre_process, post_process, parallel_output=True, add_pooler=False):
|
||||
super(LlamaModel, self).__init__(config, share_word_embeddings=False)
|
||||
args = get_args()
|
||||
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
|
||||
self.hidden_size = args.hidden_size
|
||||
@ -873,17 +901,14 @@ class LlamaModel(MegatronModule):
|
||||
"""See ascendspeed.model.transformer.set_input_tensor()"""
|
||||
self.language_model.set_input_tensor(input_tensor)
|
||||
|
||||
def forward(self, input_ids, attention_mask, labels=None, layer_past=None, get_key_value=False):
|
||||
args = get_args()
|
||||
|
||||
def forward(self, input_ids, attention_mask, labels=None, layer_past=None, get_key_value=False, **kwargs):
|
||||
if self.pre_process:
|
||||
hidden_states = self.embedding(input_ids)
|
||||
else:
|
||||
hidden_states = input_ids
|
||||
|
||||
# decoder
|
||||
hidden_states = self.language_model(hidden_states, attention_mask, layer_past=layer_past,
|
||||
get_key_value=get_key_value)
|
||||
get_key_value=get_key_value)
|
||||
|
||||
if self.post_process:
|
||||
if get_key_value:
|
||||
@ -908,3 +933,26 @@ class LlamaModel(MegatronModule):
|
||||
return loss
|
||||
|
||||
return hidden_states
|
||||
|
||||
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
|
||||
"""
|
||||
不能在这里调用self.state_dict函数,否则deepspeed保存时把state_dict挂成state_dict_for_save_checkpoint会循环调用
|
||||
"""
|
||||
state_dict_ = {}
|
||||
language_model_state_dict = self.language_model.state_dict_for_save_checkpoint(
|
||||
prefix=prefix + 'language_model.', keep_vars=keep_vars)
|
||||
# MoE states need to be handled separately by DeepSpeed engine, thus
|
||||
# moving them to the top level dictionary
|
||||
if "moe_state_dict" in language_model_state_dict:
|
||||
for key in list(language_model_state_dict["moe_state_dict"].keys()):
|
||||
state_dict_[key] = language_model_state_dict["moe_state_dict"].pop(key)
|
||||
del language_model_state_dict["moe_state_dict"]
|
||||
state_dict_.update(language_model_state_dict)
|
||||
# Save word_embeddings.
|
||||
if self.pre_process:
|
||||
embedding_state_dict = self.embedding.state_dict(prefix=prefix + 'embedding.', keep_vars=keep_vars)
|
||||
state_dict_.update(embedding_state_dict)
|
||||
if self.post_process:
|
||||
lm_head_state_dict = self.lm_head.state_dict(prefix=prefix + 'lm_head.', keep_vars=keep_vars)
|
||||
state_dict_.update(lm_head_state_dict)
|
||||
return state_dict_
|
||||
|
334
ascendspeed/model/lora_modules.py
Normal file
@ -0,0 +1,334 @@
|
||||
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
|
||||
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
import peft
|
||||
|
||||
assert peft.__version__ == "0.4.0"
|
||||
|
||||
from peft import LoraModel as PeftLoraModel
|
||||
from peft.tuners.lora import LoraLayer, Embedding, Conv2d, Linear
|
||||
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
|
||||
from peft.utils import transpose, _get_submodules, ModulesToSaveWrapper
|
||||
|
||||
from ascendspeed.mpu import RowParallelLinear, ColumnParallelLinear
|
||||
|
||||
if is_bnb_available():
|
||||
import bitsandbytes as bnb
|
||||
from peft.tuners.lora import Linear8bitLt, Linear4bit
|
||||
|
||||
|
||||
class LoraParalleLayer(LoraLayer):
|
||||
def __init__(self, in_features: int, out_features: int, is_paralle_a: bool = False):
|
||||
LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
|
||||
self.is_paralle_a = is_paralle_a
|
||||
|
||||
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, **kwargs):
|
||||
self.r[adapter_name] = r
|
||||
self.lora_alpha[adapter_name] = lora_alpha
|
||||
if lora_dropout > 0.0:
|
||||
lora_dropout_layer = nn.Dropout(p=lora_dropout)
|
||||
else:
|
||||
lora_dropout_layer = nn.Identity()
|
||||
|
||||
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
if self.is_paralle_a:
|
||||
lora_a = RowParallelLinear(input_size=self.in_features, output_size=r, bias=False,
|
||||
input_is_parallel=kwargs.get('input_is_parallel', True), skip_bias_add=True,
|
||||
dtype=torch.float32)  # LoRA weights must be promoted to fp32 here, otherwise they can overflow
|
||||
lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False, dtype=torch.float32)
|
||||
else:
|
||||
lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False, dtype=torch.float32)
|
||||
lora_b = ColumnParallelLinear(input_size=r, output_size=self.out_features, bias=False,
|
||||
gather_output=kwargs.get('gather_output', False), dtype=torch.float32)
|
||||
self.lora_A.update(nn.ModuleDict({adapter_name: lora_a}))
|
||||
self.lora_B.update(nn.ModuleDict({adapter_name: lora_b}))
|
||||
|
||||
self.scaling[adapter_name] = lora_alpha / r
|
||||
if init_lora_weights:
|
||||
self.reset_lora_parameters(adapter_name)
|
||||
self.to(self.weight.device)
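A shape sketch of what update_layer builds, assuming a tensor-parallel world size of 2 and a RowParallelLinear target (illustrative only):

# base RowParallelLinear weight : [out_features, in_features / 2]  each rank holds half the input dim
# lora_A (RowParallelLinear)    : [r, in_features / 2]             sharded the same way as the base layer
# lora_B (nn.Linear)            : [out_features, r]                replicated on every rank
# so lora_B(lora_A(x)) produces the same output layout as the wrapped base layer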
|
||||
|
||||
|
||||
class LoraParallelLinear(ColumnParallelLinear, RowParallelLinear, LoraParalleLayer):
|
||||
"""
|
||||
当目标层parallel_linear为RowParallelLinear时:
|
||||
- -
|
||||
| A_1 |
|
||||
| . |
|
||||
A = | . | X = [X_1, ..., X_p]
|
||||
| . |
|
||||
| A_p |
|
||||
- -
|
||||
- -
|
||||
| a_1 |
|
||||
| . |
|
||||
lora_A = | . | lora_B = [ ... ]
|
||||
| . |
|
||||
| a_p |
|
||||
- -
|
||||
为了保持输入、输出的shape一致,我们需要将lora的矩阵A进行行切分,而此时的lora_B则应该是完整的线性层;
|
||||
同理,当目标层是ColumnParallelLinear时,我们对lora_B进行列切分,而lora_A依然是完整的线性层。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adapter_name: str,
|
||||
parallel_linear: Union[ColumnParallelLinear, RowParallelLinear],
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.0,
|
||||
fan_in_fan_out: bool = False,
|
||||
# Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
**kwargs,
|
||||
):
|
||||
init_lora_weights = kwargs.pop("init_lora_weights", True)
|
||||
|
||||
self.parallel_linear_class = type(parallel_linear)
|
||||
parallel_linear_kwargs = {}
|
||||
if isinstance(parallel_linear, RowParallelLinear):
|
||||
parallel_linear_kwargs['input_is_parallel'] = parallel_linear.input_is_parallel
|
||||
else:
|
||||
parallel_linear_kwargs['gather_output'] = parallel_linear.gather_output
|
||||
type(parallel_linear).__init__(self, input_size=parallel_linear.input_size,
|
||||
output_size=parallel_linear.output_size, bias=parallel_linear.bias,
|
||||
skip_bias_add=parallel_linear.skip_bias_add,
|
||||
sequence_parallel_enabled=parallel_linear.sequence_parallel_enabled,
|
||||
**parallel_linear_kwargs)
|
||||
LoraParalleLayer.__init__(self, in_features=parallel_linear.input_size,
|
||||
out_features=parallel_linear.output_size,
|
||||
is_paralle_a=isinstance(parallel_linear, RowParallelLinear))
|
||||
|
||||
# the base weight is copied over in _replace_module
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, **parallel_linear_kwargs)
|
||||
self.active_adapter = adapter_name
|
||||
self.is_target_conv_1d_layer = False
|
||||
|
||||
def merge(self):
|
||||
if self.active_adapter not in self.lora_A.keys():
|
||||
return
|
||||
if self.merged:
|
||||
warnings.warn("Already merged. Nothing to do.")
|
||||
return
|
||||
if self.r[self.active_adapter] > 0:
|
||||
self.weight.data += self.get_delta_weight(self.active_adapter)
|
||||
self.merged = True
|
||||
|
||||
def unmerge(self):
|
||||
if self.active_adapter not in self.lora_A.keys():
|
||||
return
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
if self.r[self.active_adapter] > 0:
|
||||
self.weight.data -= self.get_delta_weight(self.active_adapter)
|
||||
self.merged = False
|
||||
|
||||
def get_delta_weight(self, adapter):
|
||||
return (
|
||||
transpose(
|
||||
self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
|
||||
self.fan_in_fan_out,
|
||||
)
|
||||
* self.scaling[adapter]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
previous_dtype = x.dtype
|
||||
|
||||
if self.active_adapter not in self.lora_A.keys():
|
||||
result, bias = self.parallel_linear_class.forward(self, x)
|
||||
return result, bias
|
||||
if self.disable_adapters:
|
||||
if self.r[self.active_adapter] > 0 and self.merged:
|
||||
self.unmerge()
|
||||
result, bias = self.parallel_linear_class.forward(self, x)
|
||||
elif self.r[self.active_adapter] > 0 and not self.merged:
|
||||
result, bias = self.parallel_linear_class.forward(self, x)
|
||||
|
||||
x = x.to(self.lora_A[self.active_adapter].weight.dtype)
|
||||
|
||||
lora_a = self.lora_A[self.active_adapter]
|
||||
lora_b = self.lora_B[self.active_adapter]
|
||||
lora_dropout = self.lora_dropout[self.active_adapter]
|
||||
scaling = self.scaling[self.active_adapter]
|
||||
|
||||
lora_result = lora_a(lora_dropout(x))
|
||||
if isinstance(lora_result, tuple):
|
||||
lora_result = lora_result[0]
|
||||
lora_result = lora_b(lora_result)
|
||||
if isinstance(lora_result, tuple):
|
||||
lora_result = lora_result[0]
|
||||
lora_result = lora_result * scaling
|
||||
|
||||
result = result.clone().detach() + lora_result
|
||||
else:
|
||||
result, bias = self.parallel_linear_class.forward(self, x)
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
|
||||
return result, bias
|
||||
|
||||
|
||||
class AscendLoraModel(PeftLoraModel):
|
||||
def _create_new_module(self, lora_config, adapter_name, target):
|
||||
bias = hasattr(target, "bias") and target.bias is not None
|
||||
kwargs = {
|
||||
"r": lora_config.r,
|
||||
"lora_alpha": lora_config.lora_alpha,
|
||||
"lora_dropout": lora_config.lora_dropout,
|
||||
"fan_in_fan_out": lora_config.fan_in_fan_out,
|
||||
"init_lora_weights": lora_config.init_lora_weights,
|
||||
}
|
||||
|
||||
new_module = self._create_new_bit_linear_module(target, adapter_name, bias, kwargs)
|
||||
if new_module is None:
|
||||
if isinstance(target, torch.nn.Embedding):
|
||||
embedding_kwargs = kwargs.copy()
|
||||
embedding_kwargs.pop("fan_in_fan_out", None)
|
||||
in_features, out_features = target.num_embeddings, target.embedding_dim
|
||||
new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
|
||||
elif isinstance(target, torch.nn.Conv2d):
|
||||
out_channels, in_channels = target.weight.size()[:2]
|
||||
kernel_size = target.weight.size()[2:]
|
||||
stride = target.stride
|
||||
padding = target.padding
|
||||
new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs)
|
||||
elif isinstance(target, (ColumnParallelLinear, RowParallelLinear)):
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
|
||||
new_module = LoraParallelLinear(adapter_name=adapter_name, parallel_linear=target, **kwargs)
|
||||
else:
|
||||
# if _create_new_linear_module cannot match the target either, it raises an exception directly
|
||||
new_module = self._create_new_linear_module(target, adapter_name, lora_config, bias, kwargs)
|
||||
|
||||
return new_module
|
||||
|
||||
def _create_new_bit_linear_module(self, target, adapter_name, bias, kwargs):
|
||||
loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
|
||||
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
|
||||
|
||||
new_module = None
|
||||
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
|
||||
eightbit_kwargs = kwargs.copy()
|
||||
eightbit_kwargs.update(
|
||||
{
|
||||
"has_fp16_weights": target.state.has_fp16_weights,
|
||||
"memory_efficient_backward": target.state.memory_efficient_backward,
|
||||
"threshold": target.state.threshold,
|
||||
"index": target.index,
|
||||
}
|
||||
)
|
||||
new_module = Linear8bitLt(
|
||||
adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs
|
||||
)
|
||||
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
|
||||
fourbit_kwargs = kwargs.copy()
|
||||
fourbit_kwargs.update(
|
||||
{
|
||||
"compute_dtype": target.compute_dtype,
|
||||
"compress_statistics": target.weight.compress_statistics,
|
||||
"quant_type": target.weight.quant_type,
|
||||
}
|
||||
)
|
||||
new_module = Linear4bit(adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs)
|
||||
return new_module
|
||||
|
||||
def _create_new_linear_module(self, target, adapter_name, lora_config, bias, kwargs):
|
||||
if isinstance(target, torch.nn.Linear):
|
||||
in_features, out_features = target.in_features, target.out_features
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
|
||||
elif isinstance(target, Conv1D):
|
||||
in_features, out_features = (
|
||||
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
|
||||
)
|
||||
kwargs["is_target_conv_1d_layer"] = True
|
||||
if not kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
|
||||
"Setting fan_in_fan_out to True."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Target module {target} is not supported. "
|
||||
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
|
||||
)
|
||||
new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)
|
||||
return new_module
|
||||
|
||||
def _unload_and_optionally_merge(self, merge=True):
|
||||
if getattr(self.model, "is_loaded_in_8bit", False) or getattr(self.model, "is_loaded_in_4bit", False):
|
||||
raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")
|
||||
|
||||
key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
|
||||
for key in key_list:
|
||||
try:
|
||||
parent, target, target_name = _get_submodules(self.model, key)
|
||||
except AttributeError:
|
||||
continue
|
||||
if isinstance(target, LoraLayer):
|
||||
if isinstance(target, nn.Embedding):
|
||||
new_module = torch.nn.Embedding(target.in_features, target.out_features)
|
||||
elif isinstance(target, nn.Conv2d):
|
||||
new_module = torch.nn.Conv2d(
|
||||
target.in_channels,
|
||||
target.out_channels,
|
||||
kernel_size=target.kernel_size,
|
||||
stride=target.stride,
|
||||
padding=target.padding,
|
||||
dilation=target.dilation,
|
||||
)
|
||||
elif isinstance(target, (ColumnParallelLinear, RowParallelLinear)):
|
||||
parallel_linear_kwargs = {}
|
||||
if target.parallel_linear_class is RowParallelLinear:
|
||||
parallel_linear_kwargs['input_is_parallel'] = target.input_is_parallel
|
||||
parallel_linear_kwargs['sequence_parallel_enabled'] = target.sequence_parallel_enabled
|
||||
else:
|
||||
parallel_linear_kwargs['gather_output'] = target.gather_output
|
||||
parallel_linear_kwargs['sequence_parallel_enabled'] = target.sequence_parallel_enabled
|
||||
new_module = target.parallel_linear_class(input_size=target.input_size,
|
||||
output_size=target.output_size, bias=target.bias,
|
||||
skip_bias_add=target.skip_bias_add,
|
||||
**parallel_linear_kwargs)
|
||||
else:
|
||||
bias = target.bias is not None
|
||||
if getattr(target, "is_target_conv_1d_layer", False):
|
||||
new_module = Conv1D(target.out_features, target.in_features)
|
||||
else:
|
||||
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
|
||||
if merge:
|
||||
target.merge()
|
||||
self._replace_module(parent, target_name, new_module, target)
|
||||
|
||||
# save any additional trainable modules part of `modules_to_save`
|
||||
if isinstance(target, ModulesToSaveWrapper):
|
||||
setattr(parent, target_name, target.modules_to_save[target.active_adapter])
|
||||
|
||||
return self.model
|
||||
|
||||
|
||||
peft.peft_model.PEFT_TYPE_TO_MODEL_MAPPING[peft.PeftType.LORA] = AscendLoraModel
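With AscendLoraModel registered in PEFT_TYPE_TO_MODEL_MAPPING, a plain get_peft_model call now routes LoRA injection through the parallel-aware classes above. An illustrative call, where megatron_model and the target module names are placeholders:

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.0, bias="none",
                         target_modules=["query_key_value", "dense"])
peft_model = get_peft_model(megatron_model, lora_config)
# ColumnParallelLinear / RowParallelLinear targets are wrapped with LoraParallelLinear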
|
200
ascendspeed/model/lora_utils.py
Normal file
@ -0,0 +1,200 @@
|
||||
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_lora_model_classes():
|
||||
from peft import PeftModel, LoraModel
|
||||
from .lora_modules import AscendLoraModel
|
||||
return PeftModel, LoraModel, AscendLoraModel
|
||||
|
||||
|
||||
def is_enable_lora():
|
||||
from ascendspeed import get_args
|
||||
args = get_args()
|
||||
return bool(args.lora_target_modules)
|
||||
|
||||
|
||||
def is_enable_lora_modules_to_save():
|
||||
from ascendspeed import get_args
|
||||
args = get_args()
|
||||
return is_enable_lora() and bool(args.lora_modules_to_save)
|
||||
|
||||
|
||||
def is_module_name_in_lora_modules_to_save(module_name):
|
||||
from ascendspeed import get_args
|
||||
args = get_args()
|
||||
for modules_to_save_name in args.lora_modules_to_save:
|
||||
if module_name.endswith(f"{modules_to_save_name}.weight"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_lora_state_dict(state_dict):
|
||||
from ascendspeed import get_args
|
||||
args = get_args()
|
||||
original_module_key = 'original_module.weight'
|
||||
modules_to_save_key = f'modules_to_save.{args.lora_adapter_name}.weight'
|
||||
state_dict_ = {}
|
||||
for key in state_dict.keys():
|
||||
if "lora_" in key or key.endswith(original_module_key) or key.endswith(modules_to_save_key):
|
||||
state_dict_[key] = state_dict[key]
|
||||
return state_dict_
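A small illustration of the filter with hypothetical key names and adapter name 'default'; only LoRA tensors and modules_to_save copies survive, the frozen base weights are dropped:

full_state_dict = {
    "attention.query_key_value.weight": ...,                  # dropped (frozen base weight)
    "attention.query_key_value.lora_A.default.weight": ...,   # kept ("lora_" in key)
    "lm_head.modules_to_save.default.weight": ...,            # kept (modules_to_save entry)
}
lora_only = get_lora_state_dict(full_state_dict)              # keeps only the last two entries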
|
||||
|
||||
|
||||
def is_lora_state_dict(state_dict):
|
||||
for key in state_dict.keys():
|
||||
if "lora_" in key:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_lora_modules_to_save_state_dict(state_dict):
|
||||
from ascendspeed import get_args
|
||||
args = get_args()
|
||||
modules_to_save_key = f'modules_to_save.{args.lora_adapter_name}'
|
||||
for key in state_dict.keys():
|
||||
if modules_to_save_key in key:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def handle_lora_modules_to_save_key(state_dict):
|
||||
if not is_enable_lora_modules_to_save():
|
||||
return state_dict
|
||||
if is_lora_modules_to_save_state_dict(state_dict):
|
||||
state_dict_ = {}
|
||||
for module_name in state_dict.keys():
|
||||
if not is_module_name_in_lora_modules_to_save(module_name):
|
||||
state_dict_[module_name] = state_dict[module_name]
|
||||
return state_dict_
|
||||
from ascendspeed import get_args
|
||||
args = get_args()
|
||||
original_module_key = 'original_module'
|
||||
modules_to_save_key = f'modules_to_save.{args.lora_adapter_name}'
|
||||
state_dict_ = {}
|
||||
for module_name in state_dict.keys():
|
||||
state_dict_[module_name] = state_dict[module_name]
|
||||
if not is_module_name_in_lora_modules_to_save(module_name):
|
||||
continue
|
||||
_module_name = module_name.split('.')
|
||||
if original_module_key not in module_name:
|
||||
original_module_name = '.'.join(_module_name[:-1] + [original_module_key] + [_module_name[-1]])
|
||||
state_dict_[original_module_name] = state_dict[module_name]
|
||||
if modules_to_save_key not in module_name:
|
||||
modules_to_save_name = '.'.join(_module_name[:-1] + [modules_to_save_key] + [_module_name[-1]])
|
||||
state_dict_[modules_to_save_name] = state_dict[module_name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
def lora_custom_load_fn_for_deepspeed(src, dst):
|
||||
model = dst.get_base_model()
|
||||
state_dict = handle_lora_modules_to_save_key(state_dict=src)
|
||||
strict = is_lora_state_dict(state_dict=state_dict)
|
||||
# The model is already a LoRA model here, but the pre-trained weights may not include LoRA parameters, in which case strict is False
|
||||
result = model.load_state_dict(state_dict, strict=strict)
|
||||
if not strict and result:
|
||||
from ascendspeed import print_rank_0
|
||||
print_rank_0(f"lora_custom_load_fn_for_deepspeed result:{result}")
|
||||
|
||||
|
||||
def get_lora_load_fn_with_deepspeed(model, base_model_load_dir=None, tag=None):
|
||||
from deepspeed.runtime.state_dict_factory import SDLoaderFactory
|
||||
from deepspeed.runtime.pipe.module import PipelineModule
|
||||
|
||||
if not base_model_load_dir:
|
||||
return lora_custom_load_fn_for_deepspeed
|
||||
|
||||
if tag is None:
|
||||
latest_tag = "latest_universal" if model.load_universal_checkpoint() else "latest"
|
||||
latest_path = os.path.join(base_model_load_dir, latest_tag)
|
||||
if os.path.isfile(latest_path):
|
||||
with open(latest_path, "r") as fd:
|
||||
tag = fd.read().strip()
|
||||
|
||||
ckpt_list = model._get_all_ckpt_names(base_model_load_dir, tag)  # the base model ckpt has to be read outside deepspeed, so this protected member is the only entry point
|
||||
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=model.checkpoint_engine)
|
||||
|
||||
is_pipe_parallel = isinstance(model.module, PipelineModule)
|
||||
|
||||
mp_rank = 0 if model.mpu is None else model.mpu.get_model_parallel_rank()
|
||||
load_path, checkpoint, _ = sd_loader.load(model.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel)
|
||||
|
||||
if checkpoint is None:
|
||||
raise ValueError(f"failed to load {base_model_load_dir}.")
|
||||
|
||||
module_state_dict = checkpoint['module']
|
||||
|
||||
def _lora_load_fn(src, dst):
|
||||
state_dict = {}
|
||||
state_dict.update(module_state_dict)
|
||||
state_dict.update(src)
|
||||
return lora_custom_load_fn_for_deepspeed(src=state_dict, dst=dst)
|
||||
|
||||
return _lora_load_fn
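The returned closure is meant to be handed to deepspeed as a custom load function, so that the frozen base-model weights and the LoRA weights are merged into one state dict before loading. A sketch of the intended wiring, assuming deepspeed's custom_load_fn hook and placeholder argument names:

load_fn = get_lora_load_fn_with_deepspeed(engine, base_model_load_dir=args.load)
engine.load_checkpoint(args.lora_load,            # directory holding only the LoRA weights
                       load_module_strict=False,
                       custom_load_fn=load_fn)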
|
||||
|
||||
|
||||
def get_lora_state_dict_with_deepspeed(model):
|
||||
original_state_dict = model.module.state_dict
|
||||
|
||||
def _state_dict(destination=None, prefix='', keep_vars=False):
|
||||
state_dict = original_state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
return get_lora_state_dict(state_dict=state_dict)
|
||||
|
||||
return _state_dict
|
||||
|
||||
|
||||
def handle_model_with_lora(model):
|
||||
from ascendspeed import get_args, print_rank_0
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from . import lora_modules  # importing this module patches peft's LoRA implementation
|
||||
|
||||
def _hook(_module, _x_in, _x_out):
|
||||
""" Extract the feature map of model"""
|
||||
_x_out.requires_grad_(True)
|
||||
|
||||
def _create_hooks(_model, layer):
|
||||
""" Make the hooks function"""
|
||||
for name, module in _model.named_modules():
|
||||
_name = name.split('.')[-1]
|
||||
if _name in layer:
|
||||
module.register_forward_hook(_hook)
|
||||
|
||||
args = get_args()
|
||||
|
||||
model_len = len(model)
|
||||
for i in range(model_len):
|
||||
lora_config = LoraConfig(
|
||||
r=args.lora_r,
|
||||
lora_alpha=args.lora_alpha,
|
||||
target_modules=args.lora_target_modules,
|
||||
modules_to_save=args.lora_modules_to_save,
|
||||
lora_dropout=0.0,
|
||||
bias="none",
|
||||
)
|
||||
model[i] = get_peft_model(model[i], lora_config)
|
||||
|
||||
_create_hooks(model[i], args.lora_register_forward_hook)
|
||||
model[i].print_trainable_parameters()
|
||||
|
||||
return model
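An illustrative call site; the lora_* attributes come from ascendspeed's argument parser and the values below are made up:

# args.lora_target_modules = ['query_key_value', 'dense']
# args.lora_r, args.lora_alpha = 16, 32
# args.lora_register_forward_hook = ['word_embeddings']
model_chunks = [model_provider()]                # the function expects a list of model chunks
model_chunks = handle_model_with_lora(model_chunks)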
|
||||
|
||||
|
||||
def update_model_state_dict_with_megatron(model_state_dict, state_dict):
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if key.startswith('model'):
|
||||
state_dict_[key] = model_state_dict[key]
|
||||
state_dict_[key].update(state_dict[key])  # update the model weights with the checkpoint entries
|
||||
else:
|
||||
state_dict_[key] = state_dict[key]
|
||||
return state_dict_
|
||||
|
||||
|
||||
def handle_lora_modules_to_save_key_with_megatron(state_dict):
|
||||
for key in state_dict:
|
||||
if key.startswith('model'):
|
||||
state_dict[key] = handle_lora_modules_to_save_key(state_dict[key])
|
||||
return state_dict
|
48
ascendspeed/model/manual_pipe.py
Normal file
@ -0,0 +1,48 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec
|
||||
import deepspeed.runtime.utils as ds_utils
|
||||
from deepspeed.utils import logger
|
||||
import torch.nn as nn
|
||||
from ascendspeed import get_args
|
||||
|
||||
|
||||
def bloom_manual_layer_allocation():
|
||||
args = get_args()
|
||||
parts = [0]
|
||||
for nlayer in args.manual_layers:
|
||||
parts.append(parts[-1] + nlayer)
|
||||
for i in range(1, len(parts)):
|
||||
parts[i] += 3
|
||||
|
||||
parts[-1] += 4
|
||||
return parts
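A worked example of the arithmetic, assuming args.manual_layers = [6, 6, 6, 6]; the +3 / +4 offsets presumably account for Bloom's leading and trailing non-transformer pipeline specs:

# cumulative transformer layers            -> [0, 6, 12, 18, 24]
# shift every stage boundary by 3          -> [0, 9, 15, 21, 27]
# add the 4 trailing specs to the last one -> [0, 9, 15, 21, 31]
# stage i then owns specs parts[i] .. parts[i + 1] - 1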
|
||||
|
||||
|
||||
def use_bloom_manual_layer_allocation():
|
||||
args = get_args()
|
||||
return args.use_manual_layer_allocation
|
||||
|
||||
|
||||
class ManuallyAllocatedPipelineModule(PipelineModule):
|
||||
def _partition_layers(self, method='uniform'):
|
||||
super(ManuallyAllocatedPipelineModule, self)._partition_layers(method)
|
||||
stage_id = self._topo.get_coord(self.global_rank).pipe
|
||||
method = method.lower()
|
||||
if method.startswith('type:'):
|
||||
if use_bloom_manual_layer_allocation():
|
||||
parts = bloom_manual_layer_allocation()
|
||||
self._set_bounds(start=parts[stage_id], stop=parts[stage_id + 1])
|
@ -13,15 +13,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""AscendSpeed Module"""
|
||||
import os
|
||||
import abc
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
from torch.autograd import Variable
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
|
||||
import deepspeed
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
from ascendspeed import get_args
|
||||
import ascendspeed
|
||||
from ascendspeed import mpu
|
||||
from ascendspeed.core import parallel_state
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed.core import parallel_state, tensor_parallel
|
||||
from ascendspeed.model.lora_utils import is_enable_lora, get_lora_model_classes
|
||||
|
||||
|
||||
_FLOAT_TYPES = (torch.FloatTensor, get_accelerator().FloatTensor)
|
||||
@ -33,23 +43,21 @@ def param_is_not_shared(param):
|
||||
return not hasattr(param, 'shared') or not param.shared
|
||||
|
||||
|
||||
|
||||
class MegatronModule(torch.nn.Module):
|
||||
"""Megatron specific extensions of torch Module with support
|
||||
for pipelining."""
|
||||
|
||||
def __init__(self, share_word_embeddings=True):
|
||||
def __init__(self, config=None, share_word_embeddings=True):
|
||||
super(MegatronModule, self).__init__()
|
||||
self.config = config
|
||||
self.share_word_embeddings = share_word_embeddings
|
||||
|
||||
|
||||
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
|
||||
keep_vars=False):
|
||||
"""Use this function to override the state dict for
|
||||
saving checkpoints."""
|
||||
return self.state_dict(destination, prefix, keep_vars)
|
||||
|
||||
|
||||
def word_embeddings_weight(self):
|
||||
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
|
||||
return self.language_model.embedding.word_embeddings.weight
|
||||
@ -61,7 +69,6 @@ class MegatronModule(torch.nn.Module):
|
||||
raise Exception('word_embeddings_weight() should be '
|
||||
'called for first and last stage only')
|
||||
|
||||
|
||||
def initialize_word_embeddings(self, init_method_normal):
|
||||
args = get_args()
|
||||
if not self.share_word_embeddings:
|
||||
@ -121,6 +128,7 @@ def conversion_helper(val, conversion):
|
||||
rtn = tuple(rtn)
|
||||
return rtn
|
||||
|
||||
|
||||
def fp32_to_float16(val, float16_convertor):
|
||||
def half_conversion(val):
|
||||
val_typecheck = val
|
||||
@ -135,6 +143,9 @@ def fp32_to_float16(val, float16_convertor):
|
||||
|
||||
def float16_to_fp32(val):
|
||||
def float_conversion(val):
|
||||
if val is None:
|
||||
return val
|
||||
|
||||
val_typecheck = val
|
||||
if isinstance(val_typecheck, (torch.nn.parameter.Parameter, torch.autograd.Variable)):
|
||||
val_typecheck = val.data
|
||||
@ -152,10 +163,12 @@ class Float16Module(MegatronModule):
|
||||
|
||||
if args.fp16:
|
||||
self.add_module('module', module.half())
|
||||
|
||||
def float16_convertor(val):
|
||||
return val.half()
|
||||
elif args.bf16:
|
||||
self.add_module('module', module.bfloat16())
|
||||
|
||||
def float16_convertor(val):
|
||||
return val.bfloat16()
|
||||
else:
|
||||
@ -163,7 +176,6 @@ class Float16Module(MegatronModule):
|
||||
|
||||
self.float16_convertor = float16_convertor
|
||||
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
if parallel_state.is_pipeline_first_stage():
|
||||
inputs = fp32_to_float16(inputs, self.float16_convertor)
|
||||
@ -172,16 +184,415 @@ class Float16Module(MegatronModule):
|
||||
outputs = float16_to_fp32(outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
return self.module.state_dict(destination, prefix, keep_vars)
|
||||
|
||||
|
||||
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
|
||||
keep_vars=False):
|
||||
return self.module.state_dict_for_save_checkpoint(destination, prefix,
|
||||
keep_vars)
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
self.module.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
class MegatronModuleForCausalLMABC(torch.nn.Module, abc.ABC):
|
||||
"""
|
||||
Megatron specific extensions of torch Module with support
|
||||
for text generation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(MegatronModuleForCausalLMABC, self).__init__()
|
||||
self.top_k = 50
|
||||
self.top_p = 1.0
|
||||
self.do_sample = False
|
||||
self.num_beams = 1
|
||||
self.temperature = 1.0
|
||||
self.max_length = 20
|
||||
self.max_new_tokens = 20
|
||||
self.eos_token_id = None
|
||||
self.pad_token_id = None
|
||||
self.num_return_sequences = 1
|
||||
self.length_penalty = 1.0
|
||||
self.tokenizer = None
|
||||
self.recompute = True
|
||||
self.detokenize = True
|
||||
self.include_input = False
|
||||
self.stream = True
|
||||
self.return_output_log_probs = False
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_provider, pretrained_model_name_or_path: Optional[Union[str, os.PathLike, None]] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
This is an API for initializing a model and loading its weights.
|
||||
|
||||
Parameters:
|
||||
----------
|
||||
model_provider(`func`):
|
||||
Function used to build the model object; it mirrors the model definition used for training.
|
||||
pretrained_model_name_or_path(`str`, *optional*, defaults to None):
|
||||
File path of the model weights in megatron format (TP/PP partitioning may be used).
|
||||
If it is None, randomly initialized weights will be used.
|
||||
"""
|
||||
|
||||
def generate(self, input_ids=None, **kwargs):
|
||||
"""
|
||||
This is an API for text generation which complies with most huggingface definition.
|
||||
|
||||
- *greedy decoding* if `do_sample=False`
|
||||
- *top-k decoding* if `top_k>0`
|
||||
- *top-p decoding* if `top_p>0.0`
|
||||
- *beam-search decoding* if `do_sample=False` and `num_beams>1`
|
||||
|
||||
Parameters:
|
||||
----------
|
||||
input_ids(str | torch.Tensor):
|
||||
The text entered by the user, e.g. 'Hello!'
|
||||
Or
|
||||
The text entered by the user after tokenizer encoding, e.g. [0, 13, 5, ...]
|
||||
do_sample (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use sampling; use greedy decoding otherwise.
|
||||
top_k (`int`, *optional*, defaults to 0):
|
||||
The number of the highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
top_p (`float`, *optional*, defaults to 1.0):
|
||||
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to
|
||||
`top_p` or higher are kept for generation.
|
||||
temperature (`float`, *optional*, defaults to 1.0):
|
||||
The value used to modulate the next token probabilities.
|
||||
num_beams (`int`, *optional*, defaults to 1):
|
||||
Number of beams for beam search. 1 means no beam search.
|
||||
max_length (`int`, *optional*, defaults to 20):
|
||||
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
|
||||
`max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
|
||||
max_new_tokens (`int`, *optional*):
|
||||
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
||||
eos_token_id (`int`, *optional*):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
pad_token_id (`int`, *optional*):
|
||||
The id of the *padding* token.
|
||||
tokenizer (`obj`, *optional*, defaults to None):
|
||||
If you don't want to use the tokenizer initialized by megatron, you can pass it in HF format here.
|
||||
length_penalty (`float`, *optional*, defaults to 1.0):
|
||||
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
||||
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
||||
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
|
||||
`length_penalty` < 0.0 encourages shorter sequences.Only activate in beam search mode.
|
||||
num_return_sequences(`int`, *optional*, defaults to 1):
|
||||
The number of independently computed returned sequences for each element in the batch. Only activate
|
||||
in beam search mode.
|
||||
recompute (`bool`, *optional*, defaults to True):
|
||||
Whether the model not to uses the last result in computing next token.
|
||||
detokenize (`bool`, *optional*, defaults to True):
|
||||
Whether to detokenize tokens into characters.
|
||||
include_input (`bool`, *optional*, defaults to False):
|
||||
Whether the output contains the context instruction.
|
||||
stream (`bool`, *optional*, defaults to True):
|
||||
Whether the output is streamed one by one.
|
||||
return_output_log_probs(`bool`, *optional*, defaults to False):
|
||||
Whether to return a probability distribution for each token.
|
||||
"""
|
||||
self.top_k = kwargs.pop("top_k", 50)
|
||||
self.top_p = kwargs.pop("top_p", 1.0)
|
||||
self.do_sample = kwargs.pop("do_sample", False)
|
||||
self.num_beams = kwargs.pop("num_beams", 1)
|
||||
self.temperature = kwargs.pop("temperature", 1.0)
|
||||
self.max_length = kwargs.pop("max_length", 20)
|
||||
self.max_new_tokens = kwargs.pop("max_new_tokens", 20)
|
||||
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
||||
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||
self.tokenizer = kwargs.pop("tokenizer", None)
|
||||
self.recompute = kwargs.pop("recompute", True)
|
||||
self.detokenize = kwargs.pop("detokenize", True)
|
||||
self.include_input = kwargs.pop("include_input", False)
|
||||
self.stream = kwargs.pop("stream", True)
|
||||
self.return_output_log_probs = kwargs.pop("return_output_log_probs", False)
|
||||
|
||||
|
||||
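Before moving on to the concrete implementation below, a minimal sketch of how the keyword arguments above override the defaults set in __init__. This is illustrative only; `model` is assumed to be an instance of a class that mixes in MegatronModuleForCausalLMABC and is not defined in this diff.

# Hypothetical usage sketch (not part of the commit):
output = model.generate(
    "Give me three tips for staying healthy.",
    do_sample=True,        # switch from greedy decoding to sampling
    top_k=50,
    top_p=0.9,
    temperature=0.7,
    max_new_tokens=64,
    stream=False,          # return the final text instead of a token stream
)
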
class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
    """
    Megatron specific extensions of torch Module with support
    for text generation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        from ascendspeed import get_tokenizer
        from ascendspeed import text_generation_utils

        args = get_args()
        args.max_tokens_to_oom = args.max_tokens_to_oom if hasattr(args, "max_tokens_to_oom") else 4096
        args.inference_batch_times_seqlen_threshold = args.inference_batch_times_seqlen_threshold \
            if hasattr(args, "inference_batch_times_seqlen_threshold") else 4

        self.padded_vocab_size = args.padded_vocab_size
        self.pipeline_size_larger_than_one = args.pipeline_model_parallel_size > 1

        self.tokenizer_ori = get_tokenizer().tokenizer

        # import the module here to avoid circular import errors
        self.utils = text_generation_utils

    @staticmethod
    def _init_deepspeed_inference(model, args):
        ds_config = {
            "fp16": {
                "enabled": True,
            },
            "bf16": {
                "enabled": False,
            },
            "zero_optimization": {
                "stage": 0,
                "reduce_bucket_size": args.hidden_size * args.hidden_size,
            },
            "steps_per_print": 2000,
            "train_batch_size": 1,
            "train_micro_batch_size_per_gpu": 1,
            "wall_clock_breakdown": False,
        }
        if hasattr(args, "ds_config") and getattr(args, "ds_config"):
            ds_config = args.ds_config
        elif hasattr(args, "deepspeed_config") and getattr(args, "deepspeed_config"):
            with open(args.deepspeed_config, encoding='utf-8', errors='ignore') as f:
                ds_config = json.load(f, strict=False)

        zero_optimization_info = ds_config.get("zero_optimization")
        if zero_optimization_info and zero_optimization_info.get("stage") > 0:
            logging.warning("Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3. "
                            "Falling back to ZeRO stage 0.")
            ds_config["zero_optimization"]["stage"] = 0

        if args.ds_inference:
            logging.warning("ds_inference is not supported yet, using normal mode instead.")

        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
            raise ValueError("In DeepSpeed mode, the pipeline-model-parallel size must not be greater than 1 for now.\n"
                             "Please set --pipeline-model-parallel-size 1.")

        engine = deepspeed.initialize(
            model=model,
            config_params=ds_config,
            mpu=parallel_state if args.no_pipeline_parallel else None
        )[0]
        engine.module.eval()

        return engine

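The fallback dictionary built above is an inference-only DeepSpeed configuration. As a rough sketch (an assumption for illustration, not repository code), a JSON file passed via the deepspeed config argument only needs the keys read here; any ZeRO stage above 0 would be forced back to 0 by the check that follows.

# Hedged sketch: writing a minimal inference config file that the method above could load.
import json

inference_ds_config = {
    "fp16": {"enabled": True},
    "bf16": {"enabled": False},
    "zero_optimization": {"stage": 0},
    "train_batch_size": 1,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": False,
}
with open("ds_inference_config.json", "w", encoding="utf-8") as f:  # hypothetical file name
    json.dump(inference_ds_config, f, indent=2)
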
    @staticmethod
    def _tokenize_and_broadcast(input_ids, tokenizer):
        broadcast_rank = torch.zeros(dist.get_world_size(),
                                     dtype=torch.int64,
                                     device=torch.device(get_accelerator().device_name()))

        if input_ids:
            if isinstance(input_ids, str):
                context_tokens = tokenizer.encode(input_ids)
            else:
                context_tokens = input_ids

            context_length = len(context_tokens)
            counts = 1
            broadcast_rank[dist.get_rank()] = 1
        else:
            context_tokens = [tokenizer.encode("EMPTY TEXT")]
            context_length = 0
            counts = 0

        input_info = [counts, context_length]
        input_info_tensor = get_accelerator().LongTensor(input_info)
        dist.all_reduce(input_info_tensor)
        dist.all_reduce(broadcast_rank)
        counts = input_info_tensor[0].item()
        if counts == 0:
            raise ValueError("Please pass a prompt on at least one process.")
        context_length = input_info_tensor[1].item() // counts
        master_rank = torch.nonzero(broadcast_rank)[0, 0]
        return context_length, context_tokens, master_rank

    @staticmethod
    def _broadcast_tokens(context_tokens, context_length, master_rank):
        if dist.get_world_size() > 1:
            if dist.get_rank() == master_rank:
                context_tokens_tensor = get_accelerator().LongTensor(context_tokens)
                dist.broadcast(context_tokens_tensor, master_rank)
            else:
                context_tokens_tensor = torch.empty(context_length,
                                                    dtype=torch.int64,
                                                    device=torch.device(get_accelerator().device_name()))
                dist.broadcast(context_tokens_tensor, master_rank)
        else:
            context_tokens_tensor = get_accelerator().LongTensor(context_tokens)

        return context_tokens_tensor

    @staticmethod
    def _check_output(output, stream):
        if not stream:
            full_output = None
            for tmp in output:
                full_output = tmp
            return full_output
        else:
            return output

    @staticmethod
    def _ids_check(ids, tokenizer):
        checked_ids = []
        for per_ids in ids:
            if torch.max(per_ids) >= len(tokenizer):
                logging.warning("The output ids exceed the tokenizer length, "
                                "the clamp operation is enforced, please check!!")
                checked_ids.append(torch.clamp(per_ids, min=0, max=len(tokenizer) - 1))
            else:
                checked_ids.append(per_ids)
        return checked_ids

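The two broadcast helpers above implement a common pattern: only one rank needs to hold the prompt, an all_reduce is used to discover which rank that is and how long the prompt is, and a broadcast then ships the token ids to every rank. A self-contained sketch of the same idea, independent of the class above and assuming the default process group is already initialized and only one rank holds the prompt:

# Hedged sketch, not part of the commit.
import torch
import torch.distributed as dist


def share_prompt(prompt_ids, device):
    """Ranks without a prompt receive it from the single rank that has one."""
    flag = torch.zeros(dist.get_world_size(), dtype=torch.int64, device=device)
    length = torch.zeros(1, dtype=torch.int64, device=device)
    if prompt_ids is not None:
        flag[dist.get_rank()] = 1
        length[0] = len(prompt_ids)
    dist.all_reduce(flag)      # every rank now knows who owns the prompt
    dist.all_reduce(length)    # ... and how long it is
    src = int(torch.nonzero(flag)[0, 0])
    if dist.get_rank() == src:
        tokens = torch.tensor(prompt_ids, dtype=torch.int64, device=device)
    else:
        tokens = torch.empty(int(length.item()), dtype=torch.int64, device=device)
    dist.broadcast(tokens, src)
    return tokens
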
    @classmethod
    def from_pretrained(
            cls,
            model_provider, pretrained_model_name_or_path: Optional[Union[str, os.PathLike, None]] = None,
            **kwargs
    ) -> MegatronModuleForCausalLMABC:
        from ascendspeed.training import get_model
        from ascendspeed.checkpointing import load_checkpoint
        from ascendspeed.model import DistributedDataParallel as LocalDDP
        from ascendspeed.utils import unwrap_model

        args = get_args()

        for addition_key, addition_val in kwargs.items():
            setattr(args, addition_key, addition_val)

        args.model = get_model(model_provider)

        if pretrained_model_name_or_path:
            args.load = pretrained_model_name_or_path

        if args.deepspeed:
            args.model[0] = cls._init_deepspeed_inference(args.model[0], args)

        if args.load:
            load_checkpoint(args.model, None, None)

        if not args.deepspeed:
            unwrap_classes = (torchDDP, LocalDDP, Float16Module)
            if is_enable_lora():
                unwrap_classes += get_lora_model_classes()
        else:
            unwrap_classes = (torchDDP, LocalDDP, Float16Module, deepspeed.DeepSpeedEngine)

        return unwrap_model(args.model, unwrap_classes)[0]

    def generate(self, input_ids=None, **kwargs):
        args = get_args()

        if not args.deepspeed and parallel_state.get_data_parallel_world_size() > 1:
            raise ValueError("Data parallelism is forbidden in this inference mode.")

        super().generate(input_ids=input_ids, **kwargs)

        setattr(args, "text_generation_config", {
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "recompute": self.recompute,
            "return_output_log_probs": self.return_output_log_probs,
        })

        args.out_seq_length = self.max_new_tokens
        args.seq_length = args.seq_length if hasattr(args, "seq_length") else self.max_length
        args.greedy = False if self.do_sample else True

        # =======================================
        # Add additional parameters to args which
        # may be used in the original logic of the code
        # =======================================
        for addition_key, addition_val in kwargs.items():
            setattr(args, addition_key, addition_val)

        # =======================================
        # Initialize the tokenizer to choose
        # whether to use a customized tokenizer
        # =======================================
        tokenizer = self._init_tokenizer(args, self.eos_token_id, self.pad_token_id, self.tokenizer)

        # =======================================
        # Tokenize the prompt and broadcast it,
        # so you don't need to pass the prompt on
        # each process.
        # =======================================
        context_length, context_tokens, master_rank = self._tokenize_and_broadcast(input_ids, tokenizer)

        # =======================================
        # In parallel mode we need to send the
        # context tokens to the other processes
        # =======================================
        context_tokens_tensor = self._broadcast_tokens(context_tokens, context_length, master_rank).unsqueeze(0)
        context_tokens = context_tokens_tensor.cpu().numpy().tolist()

        # =======================================
        # Get the streaming tokens generator
        # =======================================
        if self.num_beams > 1:
            raise NotImplementedError("Beam search is coming soon.")
        else:
            token_stream = self.utils.get_token_stream(args.model[0], context_tokens)

        # =======================================
        # Post-processing to get the final
        # output texts/tokens
        # =======================================
        output = self._post_processing(token_stream,
                                       tokenizer,
                                       context_length,
                                       self.include_input,
                                       self.detokenize)
        return self._check_output(output, self.stream)

    def _init_tokenizer(self, args, eos_token_id, pad_token_id, tokenizer):
        if tokenizer is None:
            tokenizer = ascendspeed.global_vars.rebuild_tokenizer(args, tokenizer=self.tokenizer_ori)
        else:
            tokenizer = ascendspeed.global_vars.rebuild_tokenizer(args, tokenizer=tokenizer)

        if pad_token_id is not None:
            tokenizer.pad_token_id = pad_token_id
        if eos_token_id is not None:
            tokenizer.eos_token_id = eos_token_id

        if tokenizer.eos_token_id is not None:
            args.eos_id = tokenizer.eos_token_id
            args.eod_id = tokenizer.eos_token_id
        else:
            raise ValueError("Your tokenizer doesn't include eos_token.")

        return tokenizer

    def _post_processing(self, token_stream, tokenizer, context_length, include_input, detokenize):
        for output, _, log_probs in token_stream:
            if not include_input:
                output = [val[context_length:] for val in output]
                log_probs = [val[context_length:, :] for val in log_probs] if log_probs is not None else None

            if detokenize:
                try:
                    output_checked = self._ids_check(output, tokenizer)
                    output = tokenizer.batch_decode(output_checked, skip_special_tokens=True)
                except Exception as e:
                    logging.error("Failed to decode the tokens. "
                                  "Please handle it by yourself.")
                    logging.error(e)

            output = output[0] if len(output) == 1 else output

            if log_probs is None:
                yield output
            else:
                yield output, log_probs[0] if len(log_probs) == 1 else log_probs

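Putting from_pretrained and generate together, inference with this class roughly follows the pattern below. The model_provider body and the checkpoint path are placeholders, and the returned class is assumed to mix in MegatronModuleForCausalLM so that .generate() is available; the exact arguments depend on the repository's inference scripts, so treat this as a sketch rather than a verified recipe.

# Hypothetical driver sketch, not part of the commit.
def model_provider(pre_process=True, post_process=True):
    # Build the same network object that the training scripts build.
    raise NotImplementedError("plug in the repository's model provider here")


model = MegatronModuleForCausalLM.from_pretrained(
    model_provider,
    pretrained_model_name_or_path="./checkpoints/tp8-pp1",  # assumed path
)

for partial_text in model.generate("What is tensor parallelism?",
                                   max_new_tokens=128, stream=True):
    print(partial_text, end="\r")
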
@ -30,7 +30,8 @@ from ascendspeed import mpu
from ascendspeed.core import utils, parallel_state
from ascendspeed.enums import PositionEmbeddingType
from ascendspeed.model import LayerNorm
from ascendspeed.model.enums import AttnMaskType, LayerType, AttnType
from ascendspeed.model.fused_layer_norm import MixedFusedLayerNorm
from ascendspeed.core.enums import AttnMaskType, LayerType, AttnType
from ascendspeed.model.fused_softmax import NPUFusedScaleMaskSoftmax
from ascendspeed.model.fused_bias_gelu import bias_gelu_impl
from ascendspeed.model.module import MegatronModule
@ -448,9 +449,16 @@ class ParallelTransformerLayer(MegatronModule):
        self.fp32_residual_connection = args.fp32_residual_connection
        self.sequence_parallel = args.sequence_parallel
        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if args.sequence_parallel:
            self.input_layernorm = MixedFusedLayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                sequence_parallel=self.sequence_parallel)
        else:
            self.input_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
@ -463,9 +471,15 @@ class ParallelTransformerLayer(MegatronModule):
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)
        if args.sequence_parallel:
            self.post_attention_layernorm = MixedFusedLayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                sequence_parallel=self.sequence_parallel)
        else:
            self.post_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
@ -474,9 +488,15 @@ class ParallelTransformerLayer(MegatronModule):
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)
            if self.sequence_parallel:
                self.post_inter_attention_layernorm = MixedFusedLayerNorm(
                    args.hidden_size,
                    eps=args.layernorm_epsilon,
                    sequence_parallel=self.sequence_parallel)
            else:
                self.post_inter_attention_layernorm = LayerNorm(
                    args.hidden_size,
                    eps=args.layernorm_epsilon)

        self.num_experts = num_experts
        # MLP
@ -757,9 +777,15 @@ class ParallelTransformer(MegatronModule):

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)
            if args.sequence_parallel:
                self.final_layernorm = MixedFusedLayerNorm(
                    args.hidden_size,
                    eps=args.layernorm_epsilon,
                    sequence_parallel=args.sequence_parallel)
            else:
                self.final_layernorm = LayerNorm(
                    args.hidden_size,
                    eps=args.layernorm_epsilon)

        if deepspeed.checkpointing.is_configured():
            global get_cuda_rng_tracker, checkpoint
@ -807,7 +833,10 @@ class ParallelTransformer(MegatronModule):
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor
        if isinstance(input_tensor, (list, tuple)):
            self.input_tensor = input_tensor[0]
        else:
            self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
@ -848,20 +877,23 @@ class ParallelTransformer(MegatronModule):
        else:
            if get_key_value:
                presents = []

            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask=attention_mask,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                hidden_states, moe_loss = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    layer_past=past,
                    get_key_value=get_key_value
                )
                if not self.ds_inference:
                    hidden_states, moe_loss = hidden_states
                moe_losses.append(moe_loss)

                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)
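The hunks above repeat one pattern four times: when sequence parallelism is enabled, every LayerNorm in the block is swapped for MixedFusedLayerNorm so that the norm is aware of the sequence-parallel region. A small helper in the same spirit (a sketch, not part of the commit; LayerNorm and MixedFusedLayerNorm are the classes imported at the top of transformer.py) makes the intent explicit:

def build_layernorm(args):
    # Sketch of the selection logic introduced above.
    if args.sequence_parallel:
        return MixedFusedLayerNorm(args.hidden_size,
                                   eps=args.layernorm_epsilon,
                                   sequence_parallel=True)
    return LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)
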
119
ascendspeed/model/triangle_attention.py
Normal file
119
ascendspeed/model/triangle_attention.py
Normal file
@ -0,0 +1,119 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import torch
import torch.nn as nn
try:
    import torch_npu
except ImportError:
    logging.warning("Import torch_npu Error.")


class TriangleAttention(nn.Module):
    """
    The triangle attention reduces the attention calculation of the masked
    part by dividing the q, k, and v matrices into blocks

    Arguments:
        block_size: The size of the inverted triangle block, the default is 512.
            The smaller the block_size, the more calculation is saved,
            but the more small operators are launched.
        masked_softmax_func: mask function to be applied.
        dropout_func: dropout function to be applied.
    """

    def __init__(self, block_size=512, masked_softmax_func=None, dropout_func=None):
        super(TriangleAttention, self).__init__()
        self.block_size = block_size
        self.mask_tmp_initialed = False
        self.mask_tmp_groups = []
        if masked_softmax_func is not None:
            self.scaled_masked_softmax = masked_softmax_func
        else:
            self.scaled_masked_softmax = torch_npu.npu_scaled_masked_softmax
        if dropout_func:
            self.dropout = True
            self.attn_dropout = dropout_func
        else:
            self.dropout = False

    def compute_attn(self, q_layer, k_layer, v_layer, mask_tmp):
        # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
        cur_sim = torch.matmul(q_layer, k_layer)

        attention_probs = self.scaled_masked_softmax(cur_sim, mask_tmp)

        # attention dropout
        if self.dropout:
            attention_probs = self.attn_dropout(attention_probs)

        # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
        context_layer_tmp = torch.matmul(attention_probs, v_layer)
        return context_layer_tmp

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        # input shape: [b, hn, sq, hd]
        bsz, head_num, sequence_len, head_dim = key_layer.shape
        sparse_groups = sequence_len // self.block_size
        # Determine whether sequence_len is evenly divisible by block_size
        flag = sequence_len == self.block_size * sparse_groups
        if flag:
            q_tmp_layers = torch.chunk(query_layer, sparse_groups, 2)
            k_tmp_layers = torch.chunk(key_layer, sparse_groups, 2)
            v_tmp_layers = torch.chunk(value_layer, sparse_groups, 2)
        else:
            seq_tmp = self.block_size * sparse_groups
            q_last, k_last = query_layer[:, :, seq_tmp:, :], key_layer[:, :, seq_tmp:, :].transpose(2, 3).contiguous()
            v_last, mask_last = value_layer[:, :, seq_tmp:, :], attention_mask[:, :, seq_tmp:, seq_tmp:]
            q_tmp_layers = torch.chunk(query_layer[:, :, :seq_tmp, :], sparse_groups, 2)
            k_tmp_layers = torch.chunk(key_layer[:, :, :seq_tmp, :], sparse_groups, 2)
            v_tmp_layers = torch.chunk(value_layer[:, :, :seq_tmp, :], sparse_groups, 2)
        context_list_tmp, k_tmp, v_tmp = [], (), ()
        for i in range(sparse_groups):
            # compute slice shape of q k v for each loop
            q_begin, q_end = i * self.block_size, (i + 1) * self.block_size
            kv_begin, kv_end = 0, (i + 1) * self.block_size
            q_tmp = q_tmp_layers[i]
            # slice k and v
            if i == 0:
                k_tmp = k_tmp_layers[i].transpose(2, 3).contiguous()
                v_tmp = v_tmp_layers[i].contiguous()
            else:
                k_tmp = torch.cat((k_tmp, k_tmp_layers[i].transpose(2, 3)), -1).contiguous()
                v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()

            if not self.mask_tmp_initialed:
                mask_tmp = attention_mask[:, :, q_begin:q_end, kv_begin:kv_end]
                self.mask_tmp_groups.append(mask_tmp.contiguous())
            else:
                mask_tmp = self.mask_tmp_groups[i]

            context_layer_tmp = self.compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
            context_list_tmp.append(context_layer_tmp)

        if not flag:
            # handle the tail blocks when the sequence is not evenly divisible
            context_layer_tmp = self.compute_attn(q_last, k_last, v_last, mask_last)
            context_list_tmp.append(context_layer_tmp)
        context_layer = torch.cat(context_list_tmp, 2)
        self.mask_tmp_initialed = True
        new_context_layer_shape = (sequence_len, bsz, head_num * head_dim)
        context_layer = torch.npu_confusion_transpose(context_layer, [2, 0, 1, 3], [*new_context_layer_shape], True)
        # =========================
        # Context layer. [sq, b, hp]
        # =========================
        return context_layer
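A rough usage sketch for the new module. The shapes, the stand-in softmax, and the lack of dropout are assumptions for illustration; on an Ascend NPU the default npu_scaled_masked_softmax would normally be used, and the final npu_confusion_transpose call in forward() requires torch_npu, so off-NPU this sketch would need that call replaced.

# Hedged smoke-test sketch; block_size divides the sequence length here, so the
# tail branch in forward() is not exercised. Move the tensors to the NPU device
# for a real run.
import torch


def plain_masked_softmax(scores, mask):
    # Stand-in for npu_scaled_masked_softmax: mask == True positions are excluded.
    return torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)


b, hn, sq, hd = 1, 4, 1024, 64
q = torch.randn(b, hn, sq, hd)
k = torch.randn(b, hn, sq, hd)
v = torch.randn(b, hn, sq, hd)
causal_mask = torch.triu(torch.ones(1, 1, sq, sq, dtype=torch.bool), diagonal=1)

attn = TriangleAttention(block_size=512, masked_softmax_func=plain_masked_softmax)
out = attn(q, k, v, causal_mask)
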
@ -1,211 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision Transformer(VIT) model."""

import math
import einops
import torch
import torch.nn.functional as F
from ascendspeed import get_args
from ascendspeed.model.transformer import ParallelTransformer
from ascendspeed.model.utils import (
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal,
)
from .module import MegatronModule
from deepspeed.accelerator import get_accelerator


class VitMlpHead(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, num_classes):
        super(VitMlpHead, self).__init__()
        self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.dense_out = torch.nn.Linear(hidden_size, num_classes)
        torch.nn.init.constant_(self.dense_out.bias, -10)

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
        # sequence_index: index of the token to pool.
        x = hidden_states[:, sequence_index, :]
        x = self.dense_in(x)
        x = torch.tanh(x)
        x = self.dense_out(x)
        return x


def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):

    args = get_args()
    num_patches_per_dim = args.img_dim // args.patch_dim
    num_patches = num_patches_per_dim ** 2
    seq_length = num_patches + 1
    hidden_size = args.hidden_size

    key = prefix + "weight"
    # import pdb
    # pdb.set_trace()
    assert key in state_dict
    if key in state_dict:
        input_param = state_dict[key]

        assert input_param.shape[1] == hidden_size
        if input_param.shape[0] != seq_length:
            # update input_param and load it to state_dict[key]

            num_tok_input = input_param.shape[0] - 1
            num_tok_new = seq_length - 1
            input_param_tok, input_param_grid = (
                input_param[:1, :],
                input_param[1:, :],
            )

            gs_input = int(math.sqrt(num_tok_input))
            gs_new = int(math.sqrt(num_tok_new))

            input_param_grid = input_param_grid.transpose(0, 1).contiguous()
            input_param_grid = input_param_grid.reshape(
                (1, -1, gs_input, gs_input)
            )
            input_param_grid = input_param_grid.float()
            scale_factor = gs_new / gs_input

            input_param_grid = F.interpolate(
                input_param_grid, scale_factor=scale_factor, mode="bilinear"
            )

            input_param_grid = input_param_grid.half()
            input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new))
            input_param_grid = input_param_grid.transpose(0, 1).contiguous()

            assert input_param_grid.shape[1] == hidden_size
            input_param = torch.cat((input_param_tok, input_param_grid), dim=0)
            assert (
                input_param.shape[0] == seq_length
                and input_param.shape[1] == hidden_size
            )

            state_dict[key] = input_param


class VitModel(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self, num_classes, finetune=False):
        super(VitModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        if args.init_method_xavier_uniform:
            self.init_method = torch.nn.init.xavier_uniform_
            self.scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            self.init_method = init_method_normal(args.init_method_std)
            self.scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers
            )

        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.patch_dim = args.patch_dim
        self.img_dim = args.img_dim
        self.finetune = finetune

        assert self.img_dim % self.patch_dim == 0
        self.num_patches_per_dim = self.img_dim // self.patch_dim
        self.num_patches = self.num_patches_per_dim ** 2
        self.seq_length = self.num_patches + 1
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels

        # cls_token
        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
        torch.nn.init.zeros_(self.cls_token)

        # Linear encoder
        self.linear_encoder = torch.nn.Linear(
            self.flatten_dim, self.hidden_size
        )

        # embedding
        self.position_embeddings = torch.nn.Embedding(
            self.seq_length, self.hidden_size
        )
        init_method_normal(args.init_method_std)(
            self.position_embeddings.weight
        )
        self.position_ids = torch.arange(self.seq_length).expand(1, -1).to(get_accelerator().device_name())

        self.position_embeddings._register_load_state_dict_pre_hook(
            twod_interpolate_position_embeddings_hook
        )

        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(
            self.init_method, self.scaled_init_method
        )

        # MLP head
        if not self.finetune:
            self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
        else:
            self.class_head = get_linear_layer(
                self.hidden_size, num_classes, torch.nn.init.zeros_
            )

    def forward(self, x):
        x = einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_dim,
            p2=self.patch_dim,
        )

        assert x.dtype == torch.half
        x = self.linear_encoder(x)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.position_embeddings(self.position_ids)
        x = self.embedding_dropout(x)
        x = self.transformer(x, None)

        if not self.finetune:
            x = self.mlp_head(x)
        else:
            x = self.class_head(x[:, 0, :])

        return x
@ -396,9 +396,9 @@ class ColumnParallelLinear(torch.nn.Module):
                 keep_master_weight_for_test=False,
                 skip_bias_add=False, moe=False,
                 enable_expert_tensor_parallelism=False,
                 sequence_parallel_enabled: bool = False
                 ):
        super(ColumnParallelLinear, self).__init__()
                 sequence_parallel_enabled: bool = False,
                 dtype=None):  # if dtype is not passed, the value from args is read below
        torch.nn.Module.__init__(self)

        # Keep input parameters
        self.input_size = input_size
@ -420,10 +420,11 @@ class ColumnParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        dtype = args.params_dtype if dtype is None else dtype
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=args.params_dtype))
                                                dtype=dtype))
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.output_size_per_partition, 0, init_method,
@ -431,19 +432,19 @@ class ColumnParallelLinear(torch.nn.Module):
        else:
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=get_accelerator().current_device_name(), dtype=args.params_dtype))
                device=get_accelerator().current_device_name(), dtype=dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=stride)

        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=args.params_dtype))
                    self.output_size_per_partition, dtype=dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition,
                    device=get_accelerator().current_device_name(),
                    dtype=args.params_dtype))
                    dtype=dtype))
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
@ -525,8 +526,9 @@ class RowParallelLinear(torch.nn.Module):
                 keep_master_weight_for_test=False,
                 skip_bias_add=False, moe=False,
                 enable_expert_tensor_parallelism=False,
                 sequence_parallel_enabled: bool = False):
        super(RowParallelLinear, self).__init__()
                 sequence_parallel_enabled: bool = False,
                 dtype=None):  # if dtype is not passed, the value from args is read below
        torch.nn.Module.__init__(self)

        # Keep input parameters
        self.input_size = input_size
@ -551,10 +553,11 @@ class RowParallelLinear(torch.nn.Module):
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        dtype = args.params_dtype if dtype is None else dtype
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=args.params_dtype))
                                                dtype=dtype))
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.input_size_per_partition, 1, init_method,
@ -562,17 +565,17 @@ class RowParallelLinear(torch.nn.Module):
        else:
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                device=get_accelerator().current_device_name(), dtype=args.params_dtype))
                device=get_accelerator().current_device_name(), dtype=dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=1, stride=stride)
        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=args.params_dtype))
                                                  dtype=dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size, device=get_accelerator().current_device_name(),
                    dtype=args.params_dtype))
                    dtype=dtype))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
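The new dtype argument lets a caller build a parallel linear layer in a precision other than args.params_dtype, falling back to args.params_dtype when it is None. A hedged sketch of how a caller might use it; only dtype is shown by the hunk above, the remaining keyword names follow the usual Megatron-style signature and are assumptions here, and megatron args plus the model parallel groups must already be initialized:

# Illustrative only: build a column-parallel projection in fp32 even when the
# rest of the model runs in fp16.
proj = ColumnParallelLinear(
    4096,                                   # input_size (assumed)
    11008,                                  # output_size (assumed)
    bias=False,
    gather_output=False,
    init_method=torch.nn.init.xavier_normal_,
    dtype=torch.float32,
)
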
@ -1,83 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import random
import numpy
import torch

import mpu
from deepspeed.accelerator import get_accelerator

class IdentityLayer(torch.nn.Module):
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight


def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    mpu.model_parallel_cuda_manual_seed(seed)


def initialize_distributed(backend='nccl'):
    """Initialize torch.distributed."""
    # Get local rank in case it is provided.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher')
    args = parser.parse_args()
    local_rank = args.local_rank

    # Get rank and world size.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv("WORLD_SIZE", '1'))

    print('> initializing torch.distributed with local rank: {}, '
          'rank: {}, world size: {}'.format(local_rank, rank, world_size))

    # Set the device id.
    device = rank % get_accelerator().device_count()
    if local_rank is not None:
        device = local_rank
    get_accelerator().set_device(device)

    # Call the init process.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=init_method)


def print_separator(message):
    torch.distributed.barrier()
    filler_len = (78 - len(message)) // 2
    filler = '-' * filler_len
    string = '\n' + filler + ' {} '.format(message) + filler
    if torch.distributed.get_rank() == 0:
        print(string, flush=True)
    torch.distributed.barrier()
@ -1,109 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from commons import set_random_seed
from commons import IdentityLayer
from commons import print_separator
from commons import initialize_distributed
from mpu.cross_entropy import vocab_parallel_cross_entropy
import mpu
import torch.nn.functional as F
import torch
import random
import sys
from deepspeed.accelerator import get_accelerator
sys.path.append("../..")


def torch_cross_entropy(batch_size, seq_length, vocab_size,
                        logits_scale, seed):
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size),
                             scale=logits_scale).to(get_accelerator().device_name())
    logits = identity()
    target = get_accelerator().LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size)
    loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
                           target.view(-1),
                           reduction='none').view_as(target).mean()
    loss.backward()
    return loss, identity.weight.grad


def mpu_cross_entropy(batch_size, seq_length, vocab_size,
                      logits_scale, seed):
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size),
                             scale=logits_scale).to(get_accelerator().device_name())
    logits = identity()
    logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits)
    target = get_accelerator().LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size)
    loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
    loss.backward()
    return loss, identity.weight.grad


def test_cross_entropy(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cross entropy with model parallel size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * tensor_model_parallel_size
    seed = 1234

    loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
                                                 vocab_size, logits_scale,
                                                 seed)
    loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length,
                                           vocab_size, logits_scale,
                                           seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print(' max error in loss on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print(' max error in grad on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_tensor_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test cross entropy')
        test_cross_entropy(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
@ -1,89 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from commons import print_separator
from commons import initialize_distributed
from deepspeed.accelerator import get_accelerator
from mpu import data as data_utils
import mpu
import torch
import functools
import operator
import sys
sys.path.append("../..")


def test_broadcast_data(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing broadcast_data with model parallel size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    torch.manual_seed(1234 + mpu.get_data_parallel_rank())
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    key_size_t = {'key1': [7, 11],
                  'key2': [8, 2, 1],
                  'key3': [13],
                  'key4': [5, 1, 2],
                  'key5': [5, 12]}
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
    if mpu.get_tensor_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, \
        total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].to(get_accelerator().device_name())
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    mpu.destroy_tensor_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test test broadcast data')
        test_broadcast_data(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
@ -1,95 +0,0 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from commons import print_separator
from commons import initialize_distributed
import mpu
import torch
import sys
sys.path.append("../..")


def test_initialize_model_parallel(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(tensor_model_parallel_size,
                                      torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(tensor_model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == mpu.get_tensor_model_parallel_world_size()
    assert rank == mpu.get_tensor_model_parallel_rank()
    check(mpu.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):

    if torch.distributed.get_rank() == 0:
        print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
            tensor_model_parallel_size_))
    tensor_model_parallel_size = min(tensor_model_parallel_size_,
                                     torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
    src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
    assert mpu.get_tensor_model_parallel_src_rank() == src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test initialize model parallel')
        test_initialize_model_parallel(tensor_model_parallel_size)
        print_separator('test model parallel source rank')
        test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
@ -1,530 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mpu import layers
|
||||
from commons import set_random_seed
|
||||
from commons import print_separator
|
||||
from commons import initialize_distributed
|
||||
import mpu
|
||||
from torch.nn.parameter import Parameter
|
||||
import torch.nn.init as init
|
||||
import torch
|
||||
import random
|
||||
import sys
|
||||
sys.path.append("../..")
|
||||
|
||||
device_name = get_accelerator().device_name()
|
||||
def test_parallel_embedding(tensor_model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing parallel embedding with model parallel size {} ...'.
|
||||
format(tensor_model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(tensor_model_parallel_size)
|
||||
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
|
||||
|
||||
batch_size = 17
|
||||
seq_length = 23
|
||||
vocab_size = 48
|
||||
hidden_size = 16
|
||||
seed = 1236
|
||||
|
||||
set_random_seed(123)
|
||||
input_data = torch.LongTensor(
|
||||
size=(batch_size, seq_length)).random_(0, vocab_size).to(device_name)
|
||||
loss_weight = torch.randn([batch_size, seq_length, hidden_size]).to(device_name)
|
||||
|
||||
set_random_seed(seed)
|
||||
embedding_original = torch.nn.Embedding(vocab_size, hidden_size).to(device_name)
|
||||
|
||||
output = embedding_original(input_data)
|
||||
loss_original = torch.mul(output, loss_weight).sum()
|
||||
loss_original.backward()
|
||||
|
||||
set_random_seed(seed)
|
||||
embedding_parallel = layers.ParallelEmbedding(
|
||||
vocab_size, hidden_size, init_method=init.normal_).to(device_name)
|
||||
output = embedding_parallel(input_data)
|
||||
loss_parallel = torch.mul(output, loss_weight).sum()
|
||||
loss_parallel.backward()
|
||||
|
||||
set_random_seed(seed)
|
||||
embedding_vocab_parallel = layers.VocabParallelEmbedding(
|
||||
vocab_size, hidden_size, init_method=init.normal_).to(device_name)
|
||||
output = embedding_vocab_parallel(input_data)
|
||||
loss_vocab_parallel = torch.mul(output, loss_weight).sum()
|
||||
loss_vocab_parallel.backward()
|
||||
|
||||
torch.distributed.barrier()
|
||||
error = loss_parallel.sub(loss_original).abs()
|
||||
print(' error in loss (parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
torch.distributed.barrier()
|
||||
error = loss_vocab_parallel.sub(loss_original).abs()
|
||||
print(' error in loss (vocab parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
weight_grad_orig = torch.split(embedding_original.weight.grad,
|
||||
hidden_size // tensor_model_parallel_size,
|
||||
1)[mpu.get_tensor_model_parallel_rank()]
|
||||
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
|
||||
print(' error in grad (parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
weight_grad_orig = torch.split(embedding_original.weight.grad,
|
||||
vocab_size // tensor_model_parallel_size,
|
||||
0)[mpu.get_tensor_model_parallel_rank()]
|
||||
error = embedding_vocab_parallel.weight.grad.sub(
|
||||
weight_grad_orig).abs().max()
|
||||
print(' error in grad (vocab parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
def test_initialize_affine_weight(tensor_model_parallel_size):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size,
                                     output_size_coeff, 0,
                                     torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff,
                            dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    mpu.layers._initialize_affine_weight(weight, output_size, input_size,
                                         input_size_coeff, 1,
                                         torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff,
                            dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')

class IdentityLayer2D(torch.nn.Module):
    def __init__(self, m, n):
        super(IdentityLayer2D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight

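
# A minimal single-process sketch (illustrative only, not part of the original
# test file), assuming just torch and the IdentityLayer2D class above: the
# layer's weight acts as the network input, so after loss.backward() the input
# gradient dL/dX can be read back from identity_layer.weight.grad.
def _sketch_identity_layer_usage():
    layer = IdentityLayer2D(4, 8)       # a 4x8 learnable "input"
    out = layer() * 2.0
    out.sum().backward()
    return layer.weight.grad            # every element equals 2.0
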

def test_column_parallel_linear(tensor_model_parallel_size):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).to(device_name)
    linear_layer = mpu.ColumnParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).to(device_name)
    loss_weight = torch.randn([batch_size, output_size]).to(device_name)
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.to(device_name)
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).to(device_name).t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    my_dLdb = torch.split(dLdb, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')

def test_row_parallel_linear(tensor_model_parallel_size):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).to(device_name)
    linear_layer = mpu.RowParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).to(device_name)
    loss_weight = torch.randn([batch_size, output_size]).to(device_name)
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.to(device_name)
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).to(device_name).t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')

class IdentityLayer3D(torch.nn.Module):
    def __init__(self, m, n, k):
        super(IdentityLayer3D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n, k))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight


def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).to(device_name)
    attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
                                                    dropout_prob).to(device_name)
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).to(device_name)
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).to(device_name)
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_tensor_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer


def test_parallel_self_attention(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing ParallelSelfAttention with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    dropout_prob = 0.0  # has to be zero
    batch_size = 5
    sequence_length = 13

    rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
        attention_layer_1, identity_layer_1 = parallel_self_attention(
            1, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)

    rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer = parallel_self_attention(
            tensor_model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
    assert hidden_size_1 == hidden_size

    error = loss_1.sub(loss).abs().max()
    torch.distributed.barrier()
    print(' loss error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    my_lin_grad_list = torch.split(
        attention_layer_1.query_key_value.weight.grad,
        hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
    my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
    error = my_lin_grad.sub(
        attention_layer.query_key_value.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' weight gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' input gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')

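
# A small numeric sketch (hypothetical sizes, not part of the original test file)
# of the strided slicing used above: with a fused query_key_value weight, the rows
# owned by one tensor-parallel rank are taken every `tp`-th block and concatenated.
def _sketch_qkv_grad_slicing():
    tp, block, rank = 2, 3, 1
    full = torch.arange(12)             # stands in for the fused QKV gradient rows
    mine = torch.cat(torch.split(full, block, 0)[rank::tp], dim=0)
    return mine                         # tensor([ 3,  4,  5,  9, 10, 11])
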

def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).to(device_name)
    transformer_layer = mpu.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).to(device_name)

    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).to(device_name)
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).to(device_name)
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_tensor_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer


def test_parallel_transformer_layer(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing ParallelTransformerLayer with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    batch_size = 5
    sequence_length = 13

    rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
        transformer_layer_1, identity_layer_1 = parallel_transformer(
            1, num_att_heads_per_partition,
            hidden_size_per_att_head, batch_size, sequence_length)

    rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer = parallel_transformer(
            tensor_model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, batch_size, sequence_length)

    error = loss_1.sub(loss).abs().max()
    torch.distributed.barrier()
    print(' loss error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' input gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')

if __name__ == '__main__':

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    print_separator('test initialize affine weight')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_initialize_affine_weight(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test parallel embedding')
        test_parallel_embedding(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test column-parallel linear')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_column_parallel_linear(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test row-parallel linear')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_row_parallel_linear(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test parallel self-attention')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_parallel_self_attention(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test parallel transformer')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_parallel_transformer_layer(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
@ -1,204 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from commons import print_separator
|
||||
from commons import initialize_distributed
|
||||
import mpu
|
||||
import torch
|
||||
import sys
|
||||
sys.path.append("../..")
|
||||
|
||||
|
||||
def test_set_cuda_rng_state(tensor_model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing set_rng_state with size {} ...'.
|
||||
format(tensor_model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(tensor_model_parallel_size)
|
||||
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
|
||||
|
||||
size = 123
|
||||
seed = 1234
|
||||
get_accelerator().manual_seed(1234)
|
||||
tensor = get_accelerator().FloatTensor(size)
|
||||
|
||||
# Get the state
|
||||
rng_state = get_accelerator().get_rng_state()
|
||||
rng_state_copy = rng_state.clone()
|
||||
|
||||
# Do some stuff.
|
||||
for _ in range(5):
|
||||
torch.randn(size, out=tensor)
|
||||
result_1 = tensor.clone()
|
||||
|
||||
assert rng_state.sub(rng_state_copy).max() == 0
|
||||
assert get_accelerator().get_rng_state().sub(rng_state_copy).max() > 0
|
||||
|
||||
# State should be different.
|
||||
new_rng_state = get_accelerator().get_rng_state()
|
||||
max_diff = new_rng_state.sub(rng_state).max()
|
||||
print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
|
||||
format(torch.distributed.get_rank(), max_diff))
|
||||
assert max_diff > 0
|
||||
|
||||
# Reset the rng state and do the same stuff.
|
||||
mpu.random._set_cuda_rng_state(rng_state)
|
||||
for _ in range(5):
|
||||
torch.randn(size, out=tensor)
|
||||
mpu.random._set_cuda_rng_state(rng_state)
|
||||
for _ in range(5):
|
||||
torch.randn(size, out=tensor)
|
||||
result_2 = tensor.clone()
|
||||
|
||||
# Results should be the same
|
||||
error = result_2.sub(result_1).abs().max()
|
||||
print(' max error in generated tensors (should be zero) on '
|
||||
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Input state should have remained intact.
|
||||
error = rng_state.sub(rng_state_copy).max()
|
||||
print(' max error in rng state (should be zero) on global rank {}: {}'.
|
||||
format(torch.distributed.get_rank(), error))
|
||||
assert error == 0
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
def test_cuda_rng_tracker(tensor_model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing cuda rng tracker with size {} ...'.
|
||||
format(tensor_model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(tensor_model_parallel_size)
|
||||
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
|
||||
|
||||
seed_1 = 1234
|
||||
seed_2 = 4321
|
||||
size = [12, 21]
|
||||
tensor = get_accelerator().FloatTensor(size)
|
||||
|
||||
# Set to seed_1 and generate two tensors.
|
||||
get_accelerator().manual_seed(seed_1)
|
||||
torch.randn(size, out=tensor)
|
||||
target_11 = tensor.clone()
|
||||
torch.randn(size, out=tensor)
|
||||
target_12 = tensor.clone()
|
||||
|
||||
# Set to seed_2 and generate two tensors.
|
||||
get_accelerator().manual_seed(seed_2)
|
||||
torch.randn(size, out=tensor)
|
||||
target_21 = tensor.clone()
|
||||
torch.randn(size, out=tensor)
|
||||
target_22 = tensor.clone()
|
||||
|
||||
# Now if we interleave seed_1 and seed_2,
|
||||
# we should still get the same tensors
|
||||
get_accelerator().manual_seed(seed_1)
|
||||
mpu.get_cuda_rng_tracker().add('test', seed_2)
|
||||
|
||||
torch.randn(size, out=tensor)
|
||||
result_11 = tensor.clone()
|
||||
|
||||
with mpu.get_cuda_rng_tracker().fork('test'):
|
||||
torch.randn(size, out=tensor)
|
||||
result_21 = tensor.clone()
|
||||
|
||||
torch.randn(size, out=tensor)
|
||||
result_12 = tensor.clone()
|
||||
|
||||
with mpu.get_cuda_rng_tracker().fork('test'):
|
||||
torch.randn(size, out=tensor)
|
||||
result_22 = tensor.clone()
|
||||
|
||||
diff = result_11.sub(result_21).abs().max()
|
||||
diff = min(diff, result_12.sub(result_22).abs().max())
|
||||
print(' max diff in generated tensors (should be non-zero) on '
|
||||
'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
|
||||
assert diff > 1.0e-6
|
||||
error = max(result_11.sub(target_11).abs().max(),
|
||||
result_12.sub(target_12).abs().max())
|
||||
error = max(error, result_21.sub(target_21).abs().max())
|
||||
error = max(error, result_22.sub(target_22).abs().max())
|
||||
print(' max error in generated tensors (should be zero) on '
|
||||
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Reset the tracker
|
||||
mpu.get_cuda_rng_tracker().reset()
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing model parallel cuda manual seed with size {} ...'.
|
||||
format(tensor_model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(tensor_model_parallel_size)
|
||||
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
|
||||
|
||||
mpu.model_parallel_cuda_manual_seed(12345)
|
||||
assert get_accelerator().initial_seed() == 12345
|
||||
with mpu.get_cuda_rng_tracker().fork():
|
||||
assert get_accelerator().initial_seed() == (12345 + 2718 +
|
||||
mpu.get_tensor_model_parallel_rank())
|
||||
|
||||
# Reset the tracker
|
||||
mpu.get_cuda_rng_tracker().reset()
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
initialize_distributed()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
tensor_model_parallel_size = 1
|
||||
while tensor_model_parallel_size <= world_size:
|
||||
print_separator('test set rng state')
|
||||
test_set_cuda_rng_state(tensor_model_parallel_size)
|
||||
tensor_model_parallel_size *= 2
|
||||
|
||||
tensor_model_parallel_size = 1
|
||||
while tensor_model_parallel_size <= world_size:
|
||||
print_separator('test cuda rng tracker')
|
||||
test_cuda_rng_tracker(tensor_model_parallel_size)
|
||||
tensor_model_parallel_size *= 2
|
||||
|
||||
tensor_model_parallel_size = 1
|
||||
while tensor_model_parallel_size <= world_size:
|
||||
print_separator('test model parallel cuda manual seed')
|
||||
test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
|
||||
tensor_model_parallel_size *= 2
|
@ -19,9 +19,12 @@ from torch.optim import SGD

from ascendspeed import get_args
from ascendspeed.model import LayerNorm
from ascendspeed.model.fused_layer_norm import MixedFusedLayerNorm
from ascendspeed.model.lora_utils import is_enable_lora

from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
from .distrib_optimizer import DistributedOptimizer

def _get_params_for_weight_decay_optimization(modules):
    """Divide params into with-weight-decay and without-weight-decay groups.
@ -29,29 +32,79 @@ def _get_params_for_weight_decay_optimization(modules):
    """
    args = get_args()

    weight_decay_params = {'params': [], 'name' : 'weight_decay_params'}
    weight_decay_params = {'params': [], 'name': 'weight_decay_params'}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0, 'name': 'no_weight_decay_params'}

    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, LayerNorm):
            if isinstance(module_, LayerNorm) or isinstance(module_, MixedFusedLayerNorm):
                no_weight_decay_params['params'].extend(
                    [p for p in list(module_._parameters.values())
                     if p is not None])
                     if p is not None and p.requires_grad])
            else:
                weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items())
                     if p is not None and n != 'bias'])
                     if p is not None and n != 'bias' and p.requires_grad])
                no_weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items())
                     if p is not None and n == 'bias'])
                     if p is not None and n == 'bias' and p.requires_grad])
    return weight_decay_params, no_weight_decay_params

def _get_sp_params_for_weight_decay_optimization(modules):
    """Divide params into with-weight-decay, sp-norm-without-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """
    params = 'params'
    name = 'name'
    args = get_args()

    weight_decay_params = {params: [], name: 'weight_decay_params'}
    no_weight_decay_params = {params: [], 'weight_decay': 0.0, name: 'no_weight_decay_params'}
    no_weight_decay_layernorm_params = {
        params: [],
        'weight_decay': 0.0,
        name: 'no_weight_decay_layernorm_sp_params'
    }

    def classify_params(local_module):
        nonlocal weight_decay_params
        nonlocal no_weight_decay_params
        nonlocal no_weight_decay_layernorm_params
        if isinstance(local_module, LayerNorm) or isinstance(local_module, MixedFusedLayerNorm):
            if getattr(list(local_module.named_parameters(recurse=False))[0][1], 'sequence_parallel', False):
                no_weight_decay_layernorm_params[params].extend(
                    [p for _, p in local_module.named_parameters(recurse=False)
                     if p is not None])
            else:
                no_weight_decay_params[params].extend(
                    [p for _, p in local_module.named_parameters(recurse=False)
                     if p is not None])
        else:
            for n, p in local_module.named_parameters(recurse=False):
                if p is not None and p.requires_grad:
                    if getattr(p, 'sequence_parallel', False):
                        no_weight_decay_layernorm_params[params].append(p)
                    elif 'bias' not in n:
                        weight_decay_params[params].append(p)
                    elif 'bias' in n:
                        no_weight_decay_params[params].append(p)

    for module in modules:
        for module_ in module.modules():
            classify_params(module_)

    return weight_decay_params, no_weight_decay_params, no_weight_decay_layernorm_params

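
# A minimal sketch (hypothetical toy model, not AscendSpeed code) of how such
# groups are consumed: entries carrying 'weight_decay': 0.0 override the
# optimizer-level default, so norm weights and biases stay undecayed.
def _sketch_param_group_usage():
    import torch
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
    decay, no_decay = [], []
    for _, p in model.named_parameters():
        (no_decay if p.ndim == 1 else decay).append(p)   # biases/LayerNorm are 1-D
    return torch.optim.AdamW(
        [{'params': decay},
         {'params': no_decay, 'weight_decay': 0.0}],
        lr=1e-3, weight_decay=0.01)
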
def get_megatron_optimizer(model):
    args = get_args()

    # Base optimizer.
    param_groups = _get_params_for_weight_decay_optimization(model)
    if args.deepspeed and args.sequence_parallel:
        param_groups = _get_sp_params_for_weight_decay_optimization(model)
    else:
        param_groups = _get_params_for_weight_decay_optimization(model)
    if args.create_moe_param_group:
        from deepspeed.moe.utils import is_moe_param, split_params_into_different_moe_groups_for_optimizer
        param_groups = split_params_into_different_moe_groups_for_optimizer(param_groups)
@ -95,7 +148,7 @@ def get_megatron_optimizer(model):

    # Determine whether the params have main-grad field.
    params_have_main_grad = False
    if args.DDP_impl == 'local':
    if args.DDP_impl == 'local' and not is_enable_lora():
        params_have_main_grad = True

    if args.fp16 or args.bf16:
@ -122,12 +175,17 @@ def get_megatron_optimizer(model):
            hysteresis=args.hysteresis)

        # Megatron optimizer.
        return Float16OptimizerWithFloat16Params(optimizer,
                                                 args.clip_grad,
                                                 args.log_num_zeros_in_grad,
                                                 params_have_main_grad,
                                                 args.bf16,
                                                 grad_scaler)
        opt_ty = DistributedOptimizer if args.use_distributed_optimizer else Float16OptimizerWithFloat16Params
        return opt_ty(optimizer,
                      args.clip_grad,
                      args.log_num_zeros_in_grad,
                      params_have_main_grad,
                      args.use_contiguous_buffers_in_local_ddp,
                      args.fp16,
                      args.bf16,
                      args.params_dtype,
                      grad_scaler,
                      model)

    # FP32.
    return FP32Optimizer(optimizer, args.clip_grad,

@ -4,6 +4,9 @@ import torch
import torch_npu
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from ascendspeed import mpu
from ascendspeed import get_args


def adamw(params: List[Tensor],
@ -48,6 +51,21 @@ def adamw(params: List[Tensor],
        out=(param.data, exp_avg, exp_avg_sq)
    )


def allreduce_layernorm_grads(grads):
    """All-reduce layernorm grads (for sequence parallelism)."""

    # All-reduce layernorm parameters across model parallel nodes
    # when sequence parallelism is used

    coalesced = _flatten_dense_tensors(grads)
    torch.distributed.all_reduce(
        coalesced, group=mpu.get_tensor_model_parallel_group())
    for buf, synced in zip(grads, _unflatten_dense_tensors(
            coalesced, grads)):
        buf.copy_(synced)


class AdamW(Optimizer):
    r"""Implements AdamW algorithm.
    .. math::
@ -133,6 +151,7 @@ class AdamW(Optimizer):
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        args = get_args()
        loss = None
        if closure is not None:
            with torch.enable_grad():
@ -180,7 +199,9 @@ class AdamW(Optimizer):
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

            is_sp_valid = 'layernorm_sp' in group['name'] and len(grads) > 0
            if args.deepspeed and is_sp_valid:
                allreduce_layernorm_grads(grads)
            # adamw_torch(params_with_grad,
            adamw(params_with_grad,
                  grads,
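
# A single-process sketch (no all_reduce, not part of the original file) of the
# coalescing round trip that allreduce_layernorm_grads relies on: flatten many
# grads into one buffer, operate on the buffer, then copy the pieces back.
def _sketch_coalesced_allreduce():
    grads = [torch.ones(3), torch.ones(2, 2)]
    coalesced = _flatten_dense_tensors(grads)        # shape [7]
    coalesced.mul_(2.0)                              # stands in for the all-reduce
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)                            # every grad is now 2.0
    return grads
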
@ -33,7 +33,9 @@ from ascendspeed.core import parallel_state
from ascendspeed.model.module import param_is_not_shared
from ascendspeed.mpu.layers import param_is_not_tensor_parallel_duplicate

def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
def clip_grad_norm_fp32(parameters, grads_for_norm,
                        max_norm, norm_type=2,
                        model_parallel_group=None):
    """Clips gradient norm of an iterable of parameters whose gradients
       are in fp32.

@ -58,21 +60,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    if isinstance(grads_for_norm, torch.Tensor):
        grads_for_norm = [grads_for_norm]

    # Grads.
    grads = []
    grads_for_norm = []
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        grad = param.grad.detach()
        if grad_not_none:
            # Make sure the grads are in fp32
            grads.append(grad)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grads_for_norm.append(grad)

        if param.grad is not None:
            assert param.grad.type() == 'torch.npu.FloatTensor'
            grads.append(param.grad.detach())

    # Norm parameters.
    max_norm = float(max_norm)
@ -103,12 +99,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for p in parameters:
            p.grad.detach().mul_(clip_coeff)
            if p.grad is not None:
                p.grad.detach().mul_(clip_coeff)

    return total_norm


def count_zeros_fp32(parameters):
def count_zeros_fp32(parameters, model_parallel_group):

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
@ -130,7 +127,7 @@ def count_zeros_fp32(parameters):
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(total_num_zeros,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=parallel_state.get_model_parallel_group())
                                 group=model_parallel_group)
    total_num_zeros = total_num_zeros.item()

    return total_num_zeros
ascendspeed/optimizer/distrib_optimizer.py
Normal file
729
ascendspeed/optimizer/distrib_optimizer.py
Normal file
@ -0,0 +1,729 @@
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
"""Megatron distributed optimizer."""
|
||||
|
||||
import math
|
||||
import torch
|
||||
|
||||
from ascendspeed import print_rank_0
|
||||
from ascendspeed import mpu
|
||||
from functools import reduce
|
||||
|
||||
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
|
||||
|
||||
|
||||
class Range:
|
||||
"""
|
||||
A range represents a start and end points for indexing a shard
|
||||
from a full tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, start, end):
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.size = end - start
|
||||
|
||||
def normalize(self, start=0):
|
||||
return Range(start, start + self.size)
|
||||
|
||||
def __str__(self):
|
||||
return "%d,%d [%d]" % (self.start, self.end, self.size)
|
||||
|
||||
|
||||
class DistributedOptimizer(MixedPrecisionOptimizer):
|
||||
"""Distributed optimizer, for all data types (fp16, bf16, and fp32).
|
||||
|
||||
Arguments:
|
||||
optimizer: base optimizer such as Adam or SGD
|
||||
clip_grad: clip gradients with this global L2 norm. Note
|
||||
that clipping is ignored if clip_grad == 0
|
||||
log_num_zeros_in_grad: return number of zeros in the gradients.
|
||||
params_have_main_grad: flag indicating if parameters have
|
||||
a `main_grad` field. If this is set, we are assuming
|
||||
that the model parameters are stored in the `main_grad`
|
||||
field instead of the typical `grad` field. This happens
|
||||
for the DDP cases where there is a continuous buffer
|
||||
holding the gradients. For example for bfloat16, we want
|
||||
to do gradient accumulation and all-reduces in float32
|
||||
and as a result we store those gradients in the main_grad.
|
||||
Note that main grad is not necessarily in float32.
|
||||
use_contiguous_buffers_in_local_ddp: if true, the local DDP model
|
||||
is using a contiguous buffer to hold the model grads.
|
||||
fp16: if true, the model is running in fp16.
|
||||
bf16: if true, the model is running in bfloat16.
|
||||
grad_scaler: used for scaling gradients. Note that this can be
|
||||
None. This case happens when `bf16 = True` and we don't
|
||||
use any loss scale. Note that for `bf16 = True`, we can have
|
||||
a constant gradient scaler. Also for `bf16 = False`, we
|
||||
always require a grad scaler.
|
||||
models: list of models (i.e., the virtual pipelining models). This
|
||||
is used by the distributed optimizer for mapping parameters.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
|
||||
"""
|
||||
Build mapping from param reference to grad buffer shard ranges.
|
||||
|
||||
This method builds a mapping from parameter references to grad
|
||||
buffer shard ranges, specific to each data-parallel (DP) rank's
|
||||
set of 'owned' parameters. Each grad buffer (padded to be an even
|
||||
multiple of DP-world-size) is conceptually divided into DP-world-size
|
||||
contiguous regions, where each DP rank 'owns' a contiguous region.
|
||||
Ownership in this sense means DP rank is responsible for reducing
|
||||
the relevant subset of grads, and updating the relevant subset of
|
||||
params.
|
||||
|
||||
This conceptual partitioning of the grad buffer does NOT respect
|
||||
parameter boundaries, and as such it is assumed that each created
|
||||
range references a shard (or subset) of the full parameter. It is
|
||||
easiest to think of each DP rank as operating (i.e., reducing,
|
||||
gathering) purely on views into the grad buffer, for all model-to-
|
||||
main & main-to-model operations.
|
||||
|
||||
This method creates three ranges:
|
||||
- The param's range within the entire grad buffer (i.e., world index).
|
||||
- The param's range within the DP rank's local view of the grad buffer.
|
||||
- The param's range within itself (i.e., its shard).
|
||||
"""
|
||||
|
||||
# Param range map.
|
||||
param_world_index_map = model._grad_buffer_param_index_map[dtype]
|
||||
param_range_map = {}
|
||||
for param, param_world_indexes in param_world_index_map.items():
|
||||
|
||||
# Param range.
|
||||
param_world_start, param_world_end = param_world_indexes
|
||||
param_local_start = max(
|
||||
0,
|
||||
param_world_start - gbuf_world_range.start)
|
||||
param_local_end = min(
|
||||
gbuf_world_range.size,
|
||||
param_world_end - gbuf_world_range.start)
|
||||
|
||||
# Add param, if within local gbuf range.
|
||||
if param_local_end > param_local_start:
|
||||
param_local_range = Range(param_local_start, param_local_end)
|
||||
param_world_range = param_local_range.normalize(
|
||||
param_local_start + gbuf_world_range.start)
|
||||
sub_param_start = max(0, gbuf_world_range.start - param_world_start)
|
||||
sub_param_range = param_local_range.normalize(sub_param_start)
|
||||
param_range_map[param] = {
|
||||
"gbuf_world": param_world_range,
|
||||
"gbuf_local": param_local_range,
|
||||
"param": sub_param_range,
|
||||
}
|
||||
|
||||
return param_range_map
|
||||
|
||||
@classmethod
|
||||
def build_model_gbuf_range(cls, model, dtype):
|
||||
"""
|
||||
Build mapping between params and their grad buffers.
|
||||
|
||||
This method does the initial setup for the method above. This setup
|
||||
includes determining the shard ranges into the DDP's grad buffer for
|
||||
each data-parallel (DP) rank. Each DP rank keeps range info for
|
||||
all other DP ranks, for the purpose of creating args for
|
||||
reduce-scatter and all-gather.
|
||||
"""
|
||||
|
||||
data_parallel_rank = mpu.get_data_parallel_rank()
|
||||
data_parallel_world_size = mpu.get_data_parallel_world_size()
|
||||
|
||||
# Grad buffer range.
|
||||
grad_buffer = model._grad_buffers[dtype]
|
||||
gbuf_size = grad_buffer.numel
|
||||
max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size))
|
||||
|
||||
# All world ranges. (i.e., across all data parallel ranks)
|
||||
gbuf_world_all_ranges = []
|
||||
for r in range(data_parallel_world_size):
|
||||
gbuf_world_start = r * max_gbuf_range_size
|
||||
gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_range_size)
|
||||
gbuf_world_range = Range(gbuf_world_start, gbuf_world_end)
|
||||
gbuf_world_all_ranges.append(gbuf_world_range)
|
||||
|
||||
# Local DP's ranges.
|
||||
gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
|
||||
gbuf_local_range = gbuf_world_range.normalize()
|
||||
|
||||
# Get each param's ranges.
|
||||
param_range_map = cls.build_model_gbuf_param_range_map(model,
|
||||
dtype,
|
||||
gbuf_world_range)
|
||||
|
||||
# Group into dict.
|
||||
data = {
|
||||
"local": gbuf_local_range,
|
||||
"world": gbuf_world_range,
|
||||
"world_all": gbuf_world_all_ranges,
|
||||
"param_map": param_range_map,
|
||||
"max_range_size": max_gbuf_range_size,
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
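
# A numeric sketch (hypothetical sizes) of the even data-parallel partitioning
# performed above: a 10-element grad buffer split across 4 DP ranks uses
# ceil(10 / 4) = 3 elements per rank, with the last rank holding the remainder.
def _sketch_gbuf_partitioning(gbuf_size=10, dp_world=4):
    max_range = int(math.ceil(gbuf_size / dp_world))
    return [Range(r * max_range, min(gbuf_size, (r + 1) * max_range))
            for r in range(dp_world)]    # [0,3) [3,6) [6,9) [9,10)
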
@classmethod
|
||||
def build_model_gbuf_range_map(cls, model):
|
||||
"""
|
||||
Create param-to-grad-buffer mappings, for grad buffer data types
|
||||
within a specific virtual model.
|
||||
"""
|
||||
return {
|
||||
dtype: cls.build_model_gbuf_range(model, dtype)
|
||||
for dtype in model._grad_buffers
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def build_model_param_gbuf_map(cls, model_gbuf_ranges):
|
||||
"""
|
||||
Create a reverse of the model_gbuf_ranges, for referencing in
|
||||
opposite direction.
|
||||
"""
|
||||
param_gbuf_map = {}
|
||||
for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
|
||||
for dtype, gbuf_range_map in model_gbuf_range_map.items():
|
||||
for param, param_range_map in gbuf_range_map["param_map"].items():
|
||||
param_gbuf_map[param] = (model_index, dtype)
|
||||
return param_gbuf_map
|
||||
|
||||
@classmethod
|
||||
def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
|
||||
"""
|
||||
Create optimizer groups.
|
||||
|
||||
Given the set of parameter shard ranges that are owned by the current
|
||||
data-parallel (DP) rank, gather the set of parameters that will be
|
||||
used (in the method below) to create the current DP's optimizer
|
||||
groups.
|
||||
"""
|
||||
|
||||
num_groups = len(param_groups)
|
||||
|
||||
# Param group map.
|
||||
param_group_map = {}
|
||||
for group_index, group in enumerate(param_groups):
|
||||
for param in group["params"]:
|
||||
assert param.requires_grad
|
||||
param_group_map[param] = group_index
|
||||
|
||||
# Optimizer group ranges.
|
||||
group_ranges = [{"params": []} for _ in param_groups]
|
||||
for model_gbuf_range_map in model_gbuf_ranges:
|
||||
for dtype, gbuf_range_map in model_gbuf_range_map.items():
|
||||
for param in gbuf_range_map["param_map"]:
|
||||
group_index = param_group_map[param]
|
||||
group_range = group_ranges[group_index]
|
||||
group_range["params"].append(param)
|
||||
|
||||
# Squeeze zero-size group ranges.
|
||||
for group_index, group_range in enumerate(group_ranges):
|
||||
group_range["orig_group"] = param_groups[group_index]
|
||||
group_ranges = [g for g in group_ranges if len(g["params"]) > 0]
|
||||
|
||||
return group_ranges
|
||||
|
||||
@classmethod
|
||||
def build_model_and_main_param_groups(cls,
|
||||
model_gbuf_ranges,
|
||||
param_gbuf_map,
|
||||
opt_group_ranges):
|
||||
"""
|
||||
Create main parameter groups needed for the optimizer step.
|
||||
|
||||
These groups encompass both: 1) groups used by this class, for
|
||||
reducing/gather, and 2) groups used by the inner optimizer for the
|
||||
parameter update. Given that the conceptual grad buffer partitioning
|
||||
(created in earlier method) doesn't respect parameter boundaries,
|
||||
the optimizer operates on shards of the model parameters, rather than
|
||||
the full parameters.
|
||||
"""
|
||||
|
||||
# Parameter groups:
|
||||
# model_float16_groups: original float16 parameters
|
||||
# model_fp32_groups: original fp32 parameters
|
||||
# shard_float16_groups: shards of original float16 parameters
|
||||
# shard_fp32_groups: shards of original fp32 parameters
|
||||
# shard_fp32_from_float16_groups: fp32 copy of float16 parameters
|
||||
model_float16_groups = []
|
||||
model_fp32_groups = []
|
||||
shard_float16_groups = []
|
||||
shard_fp32_groups = []
|
||||
shard_fp32_from_float16_groups = []
|
||||
|
||||
# Allocate (or slice) each group's param shard.
|
||||
for group_index, group_range in enumerate(opt_group_ranges):
|
||||
|
||||
# Params of this group.
|
||||
model_float16_params_this_group = []
|
||||
model_fp32_params_this_group = []
|
||||
shard_float16_params_this_group = []
|
||||
shard_fp32_params_this_group = []
|
||||
shard_fp32_from_float16_params_this_group = []
|
||||
model_float16_groups.append(model_float16_params_this_group)
|
||||
model_fp32_groups.append(model_fp32_params_this_group)
|
||||
shard_float16_groups.append(shard_float16_params_this_group)
|
||||
shard_fp32_groups.append(shard_fp32_params_this_group)
|
||||
shard_fp32_from_float16_groups.append(
|
||||
shard_fp32_from_float16_params_this_group)
|
||||
|
||||
for model_param in group_range["params"]:
|
||||
|
||||
assert model_param.requires_grad
|
||||
|
||||
model_index, dtype = param_gbuf_map[model_param]
|
||||
gbuf_range = model_gbuf_ranges[model_index][dtype]
|
||||
param_range = gbuf_range["param_map"][model_param]["param"]
|
||||
|
||||
# fp16, bf16 params.
|
||||
if model_param.type() in ['torch.npu.HalfTensor',
|
||||
'torch.npu.BFloat16Tensor']:
|
||||
|
||||
# Clone model -> main.
|
||||
shard_model_param = model_param.detach().view(-1) \
|
||||
[param_range.start:param_range.end]
|
||||
shard_main_param = shard_model_param.clone().float()
|
||||
mpu.copy_tensor_model_parallel_attributes(
|
||||
shard_model_param, model_param)
|
||||
mpu.copy_tensor_model_parallel_attributes(
|
||||
shard_main_param, model_param)
|
||||
if hasattr(model_param, 'shared'):
|
||||
shard_model_param.shared = model_param.shared
|
||||
shard_main_param.shared = model_param.shared
|
||||
|
||||
# Add to group.
|
||||
model_float16_params_this_group.append(model_param)
|
||||
shard_float16_params_this_group.append(shard_model_param)
|
||||
shard_fp32_from_float16_params_this_group.append(shard_main_param)
|
||||
|
||||
# fp32 params.
|
||||
elif model_param.type() == 'torch.npu.FloatTensor':
|
||||
shard_model_param = model_param.view(-1) \
|
||||
[param_range.start:param_range.end]
|
||||
model_fp32_params_this_group.append(model_param)
|
||||
shard_fp32_params_this_group.append(shard_model_param)
|
||||
mpu.copy_tensor_model_parallel_attributes(
|
||||
shard_model_param, model_param)
|
||||
if hasattr(model_param, 'shared'):
|
||||
shard_model_param.shared = model_param.shared
|
||||
|
||||
else:
|
||||
raise TypeError('Wrapped parameters must be one of '
|
||||
'torch.npu.FloatTensor, '
|
||||
'torch.npu.HalfTensor, or '
|
||||
'torch.npu.BFloat16Tensor. '
|
||||
'Received {}'.format(model_param.type()))
|
||||
|
||||
# Update optimizer's params.
|
||||
group_range["orig_group"]["params"] = [
|
||||
*shard_fp32_params_this_group,
|
||||
*shard_fp32_from_float16_params_this_group,
|
||||
]
|
||||
|
||||
return (
|
||||
model_float16_groups,
|
||||
model_fp32_groups,
|
||||
shard_float16_groups,
|
||||
shard_fp32_groups,
|
||||
shard_fp32_from_float16_groups,
|
||||
)
|
||||
|
||||
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
|
||||
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
|
||||
fp16, bf16, params_dtype, grad_scaler, models):
|
||||
"""
|
||||
See top of class definition for argument descriptions.
|
||||
|
||||
The steps in this method create the core mapping between DDP grad
|
||||
buffers, parameters, and parameter shard ranges, that is needed for
|
||||
converting between model param indexes and main parameter shard
|
||||
indexes. This method also updates the optimizer parameter groups
|
||||
with the newly created shards.
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
optimizer, clip_grad, log_num_zeros_in_grad,
|
||||
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
|
||||
fp16, bf16, params_dtype, grad_scaler, models)
|
||||
|
||||
# Verify that contiguous buffers are being used.
|
||||
# - Note: this should already be checked in arguments.py.
|
||||
assert use_contiguous_buffers_in_local_ddp
|
||||
|
||||
# Model grad buffer ranges.
|
||||
self.model_gbuf_ranges = []
|
||||
for model_index, model in enumerate(self.models):
|
||||
self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model))
|
||||
self.model_param_gbuf_map = \
|
||||
self.build_model_param_gbuf_map(self.model_gbuf_ranges)
|
||||
|
||||
# Optimizer ranges.
|
||||
self.opt_group_ranges = self.build_optimizer_group_ranges(
|
||||
self.optimizer.param_groups,
|
||||
self.model_gbuf_ranges)
|
||||
|
||||
# Allocate main param shards.
|
||||
(
|
||||
self.model_float16_groups,
|
||||
self.model_fp32_groups,
|
||||
self.shard_float16_groups,
|
||||
self.shard_fp32_groups,
|
||||
self.shard_fp32_from_float16_groups,
|
||||
) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
|
||||
self.model_param_gbuf_map,
|
||||
self.opt_group_ranges)
|
||||
|
||||
# Initialize param buffers.
|
||||
# - These are views on the DDP model's grad buffers, that share
|
||||
# storage & have their own dtype. This is safe because the param
|
||||
# dtype size is always <= grad dtype size.
|
||||
self.param_buffers = []
|
||||
for model_index, model in enumerate(self.models):
|
||||
current_param_buffers = {}
|
||||
for dtype, grad_buffer in model._grad_buffers.items():
|
||||
param_buffer = torch.tensor(torch.flatten(grad_buffer.data), # grad_buffer.data.storage()._untyped(),
|
||||
dtype=params_dtype,
|
||||
device=grad_buffer.data.device)
|
||||
|
||||
# create NPU tensor with set_() instead of tensor.storage()._untyped()
|
||||
# param_buffer = torch.tensor(1, dtype=params_dtype, device=grad_buffer.data.device)
|
||||
# size = reduce(lambda x, y: x * y, grad_buffer.data.size())
|
||||
# param_buffer.set_(grad_buffer.data.storage(), 0, [size], [1])
|
||||
|
||||
param_buffer = param_buffer[:grad_buffer.numel_padded]
|
||||
current_param_buffers[dtype] = param_buffer
|
||||
self.param_buffers.append(current_param_buffers)
|
||||
|
||||
# Update optimizer groups.
|
||||
# - Also, leverage state_dict() and load_state_dict() to
|
||||
# recast preexisting per-param state tensors.
|
||||
self.optimizer.param_groups = \
|
||||
[g["orig_group"] for g in self.opt_group_ranges]
|
||||
self.optimizer.load_state_dict(self.optimizer.state_dict())
|
||||
|
||||
def get_model_param_range_map(self, param):
|
||||
"""
|
||||
Given a model param, get the index sub-range of the param that this
|
||||
data-parallel rank owns.
|
||||
"""
|
||||
model_index, dtype = self.model_param_gbuf_map[param]
|
||||
gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
|
||||
param_range_map = gbuf_range_map["param_map"][param]
|
||||
return param_range_map
|
||||
|
||||
def get_model_parallel_group(self):
|
||||
"""
|
||||
With the distributed optimizer, the model parallel group is the
|
||||
entire world.
|
||||
"""
|
||||
return None
|
||||
|
||||
def state_dict(self):
|
||||
"""
|
||||
The state dict must contain the fp32-from-float16 shards.
|
||||
"""
|
||||
state_dict = {}
|
||||
state_dict['optimizer'] = self.optimizer.state_dict()
|
||||
if self.grad_scaler:
|
||||
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
|
||||
state_dict['shard_fp32_from_float16_groups'] = \
|
||||
self.shard_fp32_from_float16_groups
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
Load the state dict.
|
||||
"""
|
||||
|
||||
# Optimizer.
|
||||
optimizer_key = 'optimizer'
|
||||
if optimizer_key not in state_dict:
|
||||
optimizer_key = 'optimizer_state_dict'
|
||||
print_rank_0('***WARNING*** loading optimizer from '
|
||||
'an old checkpoint ...')
|
||||
self.optimizer.load_state_dict(state_dict[optimizer_key])
|
||||
|
||||
# Grad scaler.
|
||||
if 'grad_scaler' not in state_dict:
|
||||
if self.fp16:
|
||||
print_rank_0('***WARNING*** found an old checkpoint, will not '
|
||||
'load grad scaler ...')
|
||||
else:
|
||||
if self.grad_scaler:
|
||||
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
|
||||
else:
|
||||
print_rank_0('***WARNING*** found the grad scaler in the '
|
||||
'checkpoint but it is None in the class. '
|
||||
'Skipping loading grad scaler ...')
|
||||
|
||||
# Copy data for the main params.
|
||||
for current_group, saved_group in zip(
|
||||
self.shard_fp32_from_float16_groups,
|
||||
state_dict["shard_fp32_from_float16_groups"]):
|
||||
for current_param, saved_param in zip(current_group, saved_group):
|
||||
current_param.data.copy_(saved_param.data)
|
||||
|
||||
def zero_grad(self, set_to_none=True):
|
||||
"""
|
||||
Zero grads.
|
||||
|
||||
We only need to zero the model related parameters, i.e.,
|
||||
model_float16_groups & model_fp32_groups. We additionally zero
|
||||
the remaining groups as a memory optimization to reduce
|
||||
fragmentation; in the case of set_to_none==True, the space
|
||||
used by this field can be safely deallocated at this point.
|
||||
"""
|
||||
for groups in (
|
||||
self.model_float16_groups,
|
||||
self.model_fp32_groups,
|
||||
self.shard_float16_groups, # grad empty/unused here?
|
||||
self.shard_fp32_groups, # throws grad-access warning
|
||||
self.shard_fp32_from_float16_groups):
|
||||
for group in groups:
|
||||
_zero_grad_group_helper(group, set_to_none)
|
||||
|
||||
@staticmethod
|
||||
def get_model_buffer_dp_views(model_buffers):
|
||||
"""
|
||||
Get shard views of each of the DDP's param/grad buffers.
|
||||
|
||||
In this nested list, the top level is grouped by the virtual model
|
||||
index and the buffer's data type. The sub-level is a list of
|
||||
shards of that buffer, where each shard in the list represents
|
||||
a contiguous view of the buffer, that is owned by a data-parallel
|
||||
rank. The shard boundary does not respect parameter boundaries, and
|
||||
so the elements of some parameters are split across data parallel
|
||||
ranks.
|
||||
|
||||
Additionally, return references to the entire buffers, for use
|
||||
in _reduce_scatter_base and _all_gather_base.
|
||||
"""
|
||||
|
||||
data_parallel_world_size = mpu.get_data_parallel_world_size()
|
||||
|
||||
# Buffer views.
|
||||
view_items = []
|
||||
for model_index, buffers in enumerate(model_buffers):
|
||||
for dtype, buf in buffers.items():
|
||||
assert buf.numel() % data_parallel_world_size == 0
|
||||
shard_size = int(buf.numel() / data_parallel_world_size)
|
||||
buf_views = [buf[(r * shard_size):((r + 1) * shard_size)]
|
||||
for r in range(data_parallel_world_size)]
|
||||
view_items.append((model_index, dtype, buf, buf_views))
|
||||
|
||||
return view_items
|
||||
|
||||
def get_model_grad_buffer_dp_views(self):
|
||||
return self.get_model_buffer_dp_views([
|
||||
{dtype: mem_buffer.data}
|
||||
for model in self.models
|
||||
for dtype, mem_buffer in model._grad_buffers.items()])
|
||||
|
||||
def get_model_param_buffer_dp_views(self):
|
||||
return self.get_model_buffer_dp_views(self.param_buffers)
|
||||
|
||||
def reduce_model_grads(self, args, timers):
|
||||
"""
|
||||
Reduce-scatter model grads.
|
||||
|
||||
The DDP's grad buffer is used for the reduce-scatter, and thus no
|
||||
tensors are dynamically allocated.
|
||||
|
||||
Note: this is a different order of reduction, versus the non-
|
||||
distributed optimizer, which reduces: 1) layernorm grads, 2) all
|
||||
grads, 3) embedding grads.
|
||||
"""
|
||||
|
||||
# All-reduce layer-norm grads (for sequence parallelism).
|
||||
timers('layernorm-grads-all-reduce', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
self.allreduce_layernorm_grads(args)
|
||||
timers('layernorm-grads-all-reduce').stop()
|
||||
|
||||
# All-reduce embedding grads.
|
||||
timers('embedding-grads-all-reduce', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
self.allreduce_embedding_grads(args)
|
||||
timers('embedding-grads-all-reduce').stop()
|
||||
|
||||
# Reduce-scatter setup.
|
||||
timers('grads-reduce-scatter', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
data_parallel_rank = mpu.get_data_parallel_rank()
|
||||
data_parallel_world_size = mpu.get_data_parallel_world_size()
|
||||
data_parallel_group = mpu.get_data_parallel_group()
|
||||
|
||||
# Scale grad buffers by '1 / data_parallel_world_size'.
|
||||
for model in self.models:
|
||||
for dtype, gbuf in model._grad_buffers.items():
|
||||
gbuf.data /= data_parallel_world_size
|
||||
|
||||
# Reduce-scatter all grads.
|
||||
gbuf_view_items = self.get_model_grad_buffer_dp_views()
|
||||
for index, (model_index, dtype, gbuf, gbuf_views) \
|
||||
in enumerate(gbuf_view_items):
|
||||
torch.distributed._reduce_scatter_base(
|
||||
gbuf_views[data_parallel_rank],
|
||||
gbuf,
|
||||
group=data_parallel_group,
|
||||
)
|
||||
|
||||
timers('grads-reduce-scatter').stop()
|
||||
|
||||
def gather_model_params(self, args, timers):
|
||||
"""
|
||||
All-gather updated model params.
|
||||
|
||||
The DDP's param buffer is used for the all-gather, and thus no
|
||||
tensors are dynamically allocated. After the all-gather, the params
|
||||
can be copied from the param buffer to the param.
|
||||
"""
|
||||
|
||||
timers('params-all-gather', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
|
||||
data_parallel_rank = mpu.get_data_parallel_rank()
|
||||
data_parallel_group = mpu.get_data_parallel_group()
|
||||
|
||||
# All-gather updated main params.
|
||||
# - All param buffer views are guaranteed to have the same num elements
|
||||
# across all data parallel ranks, due to grad buffer padding that is
|
||||
# done in distributed.py, and extended to the param buffers. Thus,
|
||||
# all sub-views will have consistent start/end indexes across data
|
||||
# parallel ranks.
|
||||
pbuf_view_items = self.get_model_param_buffer_dp_views()
|
||||
for index, (model_index, dtype, pbuf, pbuf_views) \
|
||||
in enumerate(pbuf_view_items):
|
||||
torch.distributed._all_gather_base(
|
||||
pbuf,
|
||||
pbuf_views[data_parallel_rank],
|
||||
group=data_parallel_group,
|
||||
)
|
||||
|
||||
# Copy from param buffer to each param.
|
||||
for model_id, model in enumerate(self.models):
|
||||
for dtype, param_map in model._grad_buffer_param_index_map.items():
|
||||
for param, buf_range in param_map.items():
|
||||
param_buf = self.param_buffers[model_id][dtype]
|
||||
param_buf_shard = param_buf[buf_range[0]:buf_range[1]]
|
||||
param.view(-1).detach().copy_(param_buf_shard)
|
||||
|
||||
timers('params-all-gather').stop()
|
||||
|
||||
def _collect_main_grad_data_for_unscaling(self):
|
||||
"""
|
||||
Note: this should be equivalent to the float-16 optimizer's method,
|
||||
but written differently, so the two should be combined.
|
||||
"""
|
||||
return [
|
||||
param.grad.data
|
||||
for group in self.optimizer.param_groups
|
||||
for param in group["params"]
|
||||
]
|
||||
|
||||
    def _get_model_and_main_params_data_float16(self):
        """
        Get aligned list of model and main params.
        """
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):
        """
        Copy model grads to main grads.

        Since this step follows a reduce-scatter through the DDP's grad
        buffer, this method is responsible for copying the updated grads
        from the grad buffer to the main shard's grad field.
        """

        # Utility method for copying group grads.
        def copy_group_grads(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups,
                                                     shard_main_groups):
                for model_param, shard_main_param in zip(model_group,
                                                         shard_main_group):
                    param_range_map = self.get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param.grad = shard_model_grad.float()

        # Copy model groups to shard groups.
        copy_group_grads(self.model_float16_groups,
                         self.shard_fp32_from_float16_groups)
        copy_group_grads(self.model_fp32_groups,
                         self.shard_fp32_groups)

    def _copy_main_params_to_model_params(self):
        """
        Copy main params to model params.

        Since this step is followed by an all-gather through the DDP's grad
        buffer, this method is responsible for copying the updated params
        from the main shards into the correct position in the grad buffer.
        """

        # Utility method for copying group params.
        def copy_group_params(shard_main_groups, model_groups):
            for shard_main_group, model_group in zip(shard_main_groups,
                                                     model_groups):
                for shard_main_param, model_param in zip(shard_main_group,
                                                         model_group):
                    param_range_map = self.get_model_param_range_map(model_param)
                    world_range = param_range_map["gbuf_world"]

                    assert world_range.size == shard_main_param.nelement()

                    model_id, dtype = self.model_param_gbuf_map[model_param]
                    model_param_buffer = self.param_buffers[model_id][dtype]

                    shard_model_param = model_param_buffer.view(-1) \
                        [world_range.start:world_range.end]

                    shard_model_param.data.copy_(shard_main_param)

        # Copy shard groups to model groups.
        copy_group_params(self.shard_fp32_from_float16_groups,
                          self.model_float16_groups)
        copy_group_params(self.shard_fp32_groups,
                          self.model_fp32_groups)

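    # ------------------------------------------------------------------
    # Illustrative sketch (hypothetical helper): the range bookkeeping the two
    # copy methods above rely on. Assuming the flat buffer is padded so it
    # divides evenly across data-parallel ranks, rank r owns the half-open
    # world range [r * shard, (r + 1) * shard).
    # ------------------------------------------------------------------
    def _sketch_shard_world_range(buffer_numel, data_parallel_world_size, rank):
        padded_numel = ((buffer_numel + data_parallel_world_size - 1)
                        // data_parallel_world_size) * data_parallel_world_size
        shard_numel = padded_numel // data_parallel_world_size
        # e.g. buffer_numel=10, world_size=4 -> ranges (0,3), (3,6), (6,9), (9,12)
        return rank * shard_numel, (rank + 1) * shard_numel
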
    def _copy_model_params_to_main_params(self):
        """
        Copy model params to main params.

        During finetuning, this method is used to reload the main params from
        the model params. This copy does not make use of the grad buffer as
        an intermediary.
        """

        # Utility method for copying group params.
        def copy_group_params(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups,
                                                     shard_main_groups):
                for model_param, shard_main_param in zip(model_group,
                                                         shard_main_group):

                    param_range_map = self.get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    shard_model_param = model_param.view(-1)[param_range.start: param_range.end]
                    shard_main_param.data.copy_(shard_model_param)

        # Copy model groups to shard groups.
        copy_group_params(self.model_float16_groups,
                          self.shard_fp32_from_float16_groups)
        copy_group_params(self.model_fp32_groups,
                          self.shard_fp32_groups)

@ -19,13 +19,18 @@ from abc import ABC
from abc import abstractmethod

import torch

from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from ascendspeed import get_timers
from ascendspeed import mpu
from ascendspeed import mpu, print_rank_0
from ascendspeed.core import parallel_state
from ascendspeed import print_rank_0
from ascendspeed.model.module import param_is_not_shared
from ascendspeed.mpu.layers import param_is_not_tensor_parallel_duplicate
from ascendspeed.utils import unwrap_model
from ascendspeed.model import Float16Module, DistributedDataParallel as LocalDDP
from deepspeed.accelerator import get_accelerator
from ascendspeed.model.lora_utils import is_enable_lora, get_lora_model_classes
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32

@ -64,13 +69,14 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
|
||||
that_.copy_(this_)
|
||||
|
||||
|
||||
|
||||
class MegatronOptimizer(ABC):
|
||||
|
||||
|
||||
def __init__(self, optimizer, clip_grad,
|
||||
log_num_zeros_in_grad,
|
||||
params_have_main_grad):
|
||||
params_have_main_grad,
|
||||
use_contiguous_buffers_in_local_ddp,
|
||||
models):
|
||||
|
||||
"""Input optimizer is the base optimizer for example Adam."""
|
||||
self.optimizer = optimizer
|
||||
assert self.optimizer, 'no optimizer is provided.'
|
||||
@ -78,7 +84,19 @@ class MegatronOptimizer(ABC):
|
||||
self.clip_grad = clip_grad
|
||||
self.log_num_zeros_in_grad = log_num_zeros_in_grad
|
||||
self.params_have_main_grad = params_have_main_grad
|
||||
self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp
|
||||
|
||||
# 'models' are retained for access to the contiguous grad buffers.
|
||||
# (see distributed optimizer)
|
||||
self.models = models
|
||||
|
||||
if self.use_contiguous_buffers_in_local_ddp:
|
||||
assert self.params_have_main_grad, \
|
||||
"use of contiguous buffer requires that params have main grad"
|
||||
|
||||
self.unwrap_model_classes = (torchDDP, LocalDDP, Float16Module)
|
||||
if is_enable_lora():
|
||||
self.unwrap_model_classes += get_lora_model_classes()
|
||||
|
||||
def get_parameters(self):
|
||||
params = []
|
||||
@ -87,38 +105,53 @@ class MegatronOptimizer(ABC):
|
||||
params.append(param)
|
||||
return params
|
||||
|
||||
def get_main_grads_for_grad_norm(self):
|
||||
|
||||
# Filter parameters based on:
|
||||
# - grad should not be none
|
||||
# - parameter should not be shared
|
||||
# - should not be a replica due to tensor model parallelism
|
||||
params = self.get_parameters()
|
||||
grads_for_norm = []
|
||||
for param in params:
|
||||
grad = param.grad
|
||||
grad_not_none = grad is not None
|
||||
is_not_shared = param_is_not_shared(param)
|
||||
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
|
||||
if grad_not_none and is_not_shared and is_not_tp_duplicate:
|
||||
grads_for_norm.append(grad)
|
||||
|
||||
return grads_for_norm
|
||||
|
||||
def get_model_parallel_group(self):
|
||||
"""Default returned here, but the distributed optimizer overrides this."""
|
||||
return parallel_state.get_model_parallel_group()
|
||||
|
||||
def clip_grad_norm(self, clip_grad):
|
||||
params = self.get_parameters()
|
||||
return clip_grad_norm_fp32(params, clip_grad)
|
||||
|
||||
grads_for_norm = self.get_main_grads_for_grad_norm()
|
||||
return clip_grad_norm_fp32(
|
||||
params, grads_for_norm, clip_grad,
|
||||
model_parallel_group=self.get_model_parallel_group())
|
||||
|
||||
def count_zeros(self):
|
||||
params = self.get_parameters()
|
||||
return count_zeros_fp32(params)
|
||||
|
||||
return count_zeros_fp32(params,
|
||||
model_parallel_group=self.get_model_parallel_group())
|
||||
|
||||
@abstractmethod
|
||||
def zero_grad(self, set_to_none=True):
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_loss_scale(self):
|
||||
"""The output should be a cuda tensor of size 1."""
|
||||
pass
|
||||
|
||||
|
||||
def scale_loss(self, loss):
|
||||
"""Simple scaling."""
|
||||
return self.get_loss_scale() * loss
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def step(self):
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def reload_model_params(self):
|
||||
"""Refreshes any internal state from the current model parameters.
|
||||
@ -128,17 +161,14 @@ class MegatronOptimizer(ABC):
|
||||
with main parameters, the main parameters need to also be updated."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def state_dict(self):
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def load_state_dict(self, state_dict):
|
||||
pass
|
||||
|
||||
|
||||
# Promote state so it can be retrieved or set via
|
||||
# "optimizer_instance.state"
|
||||
def _get_state(self):
|
||||
@ -149,7 +179,6 @@ class MegatronOptimizer(ABC):
|
||||
|
||||
state = property(_get_state, _set_state)
|
||||
|
||||
|
||||
# Promote param_groups so it can be retrieved or set via
|
||||
# "optimizer_instance.param_groups"
|
||||
# (for example, to adjust the learning rate)
|
||||
@ -161,6 +190,274 @@ class MegatronOptimizer(ABC):
|
||||
|
||||
param_groups = property(_get_param_groups, _set_param_groups)
|
||||
|
||||
@abstractmethod
|
||||
def step(self, args, timers):
|
||||
pass
|
||||
|
||||
def gather_model_params(self, args, timers):
|
||||
"""
|
||||
For the case of a non-distributed-optimizer, there is nothing to
|
||||
do here.
|
||||
"""
|
||||
pass
|
||||
|
||||
def allreduce_word_embedding_grads(self, args):
|
||||
"""
|
||||
All-reduce word embedding grads.
|
||||
|
||||
Reduce grads across first and last stages to ensure that word_embeddings
|
||||
parameters stay in sync. This should only run for models that support
|
||||
pipelined model parallelism (BERT and GPT-2).
|
||||
"""
|
||||
|
||||
if parallel_state.is_rank_in_embedding_group(ignore_virtual=True) and \
|
||||
parallel_state.get_pipeline_model_parallel_world_size() > 1:
|
||||
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
|
||||
unwrapped_model = self.models[0]
|
||||
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
|
||||
unwrapped_model = self.models[-1]
|
||||
else: # We do not support the interleaved schedule for T5 yet.
|
||||
unwrapped_model = self.models[0]
|
||||
unwrapped_model = unwrap_model(
|
||||
unwrapped_model, self.unwrap_model_classes)
|
||||
|
||||
if unwrapped_model.share_word_embeddings:
|
||||
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
|
||||
if args.DDP_impl == 'local':
|
||||
grad = word_embeddings_weight.main_grad
|
||||
else:
|
||||
grad = word_embeddings_weight.grad
|
||||
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
|
||||
|
||||
def allreduce_position_embedding_grads(self, args):
|
||||
"""
|
||||
All-reduce position_embeddings grad across first (encoder) and
|
||||
split (decoder) stages to ensure that position embeddings parameters
|
||||
stay in sync. This should only run for T5 models with pipeline
|
||||
parallelism.
|
||||
"""
|
||||
if parallel_state.is_rank_in_position_embedding_group() and \
|
||||
parallel_state.get_pipeline_model_parallel_world_size() > 1 and \
|
||||
args.pipeline_model_parallel_split_rank is not None:
|
||||
unwrapped_model = self.models[0]
|
||||
unwrapped_model = unwrap_model(
|
||||
unwrapped_model, self.unwrap_model_classes)
|
||||
assert args.DDP_impl == 'local', \
|
||||
'T5 model is only supported with local DDP mode'
|
||||
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
|
||||
torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
|
||||
|
||||
def allreduce_embedding_grads(self, args):
|
||||
"""All-reduce both word and position embeddings."""
|
||||
self.allreduce_word_embedding_grads(args)
|
||||
self.allreduce_position_embedding_grads(args)
|
||||
|
||||
def allreduce_layernorm_grads(self, args):
|
||||
"""All-reduce layernorm grads (for sequence parallelism)."""
|
||||
|
||||
# All-reduce layernorm parameters across model parallel nodes
|
||||
# when sequence parallelism is used
|
||||
if parallel_state.get_tensor_model_parallel_world_size() > 1 and \
|
||||
args.sequence_parallel:
|
||||
grads = []
|
||||
for model_module in self.models:
|
||||
unwrapped_model = unwrap_model(
|
||||
model_module, self.unwrap_model_classes)
|
||||
for param in unwrapped_model.parameters():
|
||||
if getattr(param, 'sequence_parallel', False):
|
||||
grad = param.main_grad if args.DDP_impl == 'local' else param.grad
|
||||
grads.append(grad.data)
|
||||
|
||||
# print("rank [{}], len:{}\n".format(torch.cuda.current_device(), len(self.models[0])), end="", flush=True)
|
||||
coalesced = _flatten_dense_tensors(grads)
|
||||
torch.distributed.all_reduce(
|
||||
coalesced, group=parallel_state.get_tensor_model_parallel_group())
|
||||
for buf, synced in zip(grads, _unflatten_dense_tensors(
|
||||
coalesced, grads)):
|
||||
buf.copy_(synced)
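    # ------------------------------------------------------------------
    # Illustrative sketch (hypothetical helper): the flatten -> all-reduce ->
    # unflatten pattern used just above, shown on an arbitrary list of
    # same-dtype tensors. It reuses the _flatten/_unflatten imports of this
    # module.
    # ------------------------------------------------------------------
    def _sketch_coalesced_allreduce(tensors, group):
        coalesced = _flatten_dense_tensors(tensors)
        torch.distributed.all_reduce(coalesced, group=group)
        for buf, synced in zip(tensors,
                               _unflatten_dense_tensors(coalesced, tensors)):
            buf.copy_(synced)
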
|
||||
|
||||
def reduce_model_grads(self, args, timers):
|
||||
"""All-reduce all grads, and all-reduce embeddings."""
|
||||
|
||||
# All-reduce layer-norm grads (for sequence parallelism).
|
||||
timers('layernorm-grads-all-reduce', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
self.allreduce_layernorm_grads(args)
|
||||
timers('layernorm-grads-all-reduce').stop()
|
||||
|
||||
# All-reduce if needed.
|
||||
if args.DDP_impl == 'local':
|
||||
timers('grads-all-reduce', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
for model in self.models:
|
||||
model.allreduce_gradients()
|
||||
timers('grads-all-reduce').stop()
|
||||
|
||||
# All-reduce embedding grads.
|
||||
timers('embedding-grads-all-reduce', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
self.allreduce_embedding_grads(args)
|
||||
timers('embedding-grads-all-reduce').stop()
|
||||
|
||||
|
||||
class MixedPrecisionOptimizer(MegatronOptimizer):
    """Base class for both the float-16 and the distributed optimizer.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the gradients of the model parameters are stored in the
            `main_grad` field instead of the typical `grad` field. This
            happens for the DDP cases where there is a contiguous buffer
            holding the gradients. For example, for bfloat16 we want
            to do gradient accumulation and all-reduces in float32,
            and as a result we store those gradients in `main_grad`.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        params_dtype: used by the distributed optimizer.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

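    # Illustrative summary (not part of the class) of the precision / scaler
    # combinations described above:
    #   fp16=True              -> a grad scaler is required (dynamic loss scaling).
    #   bf16=True              -> the grad scaler is optional; None means unit
    #                             scale, and a constant scaler is also allowed.
    #   fp16=False, bf16=False -> plain fp32; see FP32Optimizer below, which
    #                             reports a fixed loss scale of 1.0.
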
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
|
||||
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
|
||||
fp16, bf16, params_dtype, grad_scaler,
|
||||
models):
|
||||
|
||||
super().__init__(
|
||||
optimizer, clip_grad, log_num_zeros_in_grad,
|
||||
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
|
||||
models)
|
||||
|
||||
self.fp16 = fp16
|
||||
self.bf16 = bf16
|
||||
self.params_dtype = params_dtype
|
||||
self.grad_scaler = grad_scaler
|
||||
|
||||
# None grad scaler is only supported for bf16.
|
||||
if self.grad_scaler is None:
|
||||
assert not self.fp16, 'fp16 expects a grad scaler.'
|
||||
|
||||
# Tensor used to determine if a nan/inf has happened.
|
||||
# Any non-zero value indicates inf/nan.
|
||||
# Note that we keep this for the cases that grad scaler is none.
|
||||
# We still record nan/inf if we have a bfloat16 with a grad scaler.
|
||||
if self.grad_scaler:
|
||||
self.found_inf = torch.cuda.FloatTensor([0.0])
|
||||
|
||||
# Dummy tensor needed for apex multi-apply tensor.
|
||||
# For bfloat, we don't have multi-tensor apply and for now
|
||||
# we set it to none so the multi-tensor apply gets ignored.
|
||||
if bf16:
|
||||
self._dummy_overflow_buf = None
|
||||
else:
|
||||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
|
||||
# In case grad scaler is not passed, define the unity scale.
|
||||
if self.grad_scaler is None:
|
||||
self._scale_one = torch.cuda.FloatTensor([1.0])
|
||||
|
||||
def get_loss_scale(self):
|
||||
if self.grad_scaler is None:
|
||||
return self._scale_one
|
||||
return self.grad_scaler.scale
|
||||
|
||||
def reload_model_params(self):
|
||||
self._copy_model_params_to_main_params()
|
||||
|
||||
def _unscale_main_grads_and_check_for_nan(self):
|
||||
|
||||
# Collect main grads.
|
||||
main_grads = self._collect_main_grad_data_for_unscaling()
|
||||
|
||||
# Reset found inf.
|
||||
self.found_inf.fill_(0.0)
|
||||
|
||||
# Unscale and set found inf/nan
|
||||
torch._amp_foreach_non_finite_check_and_unscale_(
|
||||
main_grads, self.found_inf, self.grad_scaler.inv_scale)
|
||||
|
||||
# Update across all model parallel instances.
|
||||
torch.distributed.all_reduce(self.found_inf,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=self.get_model_parallel_group())
|
||||
torch.distributed.all_reduce(self.found_inf,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=parallel_state.get_data_parallel_group())
|
||||
|
||||
# Check for nan.
|
||||
found_inf_flag = (self.found_inf.item() > 0)
|
||||
|
||||
return found_inf_flag
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, args, timers):
|
||||
|
||||
# Copy gradients from model params to main params.
|
||||
timers('optimizer-copy-to-main-grad', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
self._copy_model_grads_to_main_grads()
|
||||
timers('optimizer-copy-to-main-grad').stop()
|
||||
|
||||
# Do unscale, check for inf, and update grad scaler only for
|
||||
# the case that grad scaler is provided.
|
||||
if self.grad_scaler:
|
||||
|
||||
# Unscale and check for inf/nan.
|
||||
timers('optimizer-unscale-and-check-inf', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
|
||||
timers('optimizer-unscale-and-check-inf').stop()
|
||||
|
||||
# We are done with scaling gradients
|
||||
# so we can update the loss scale.
|
||||
self.grad_scaler.update(found_inf_flag)
|
||||
|
||||
# If we found inf/nan, skip the update.
|
||||
if found_inf_flag:
|
||||
return False, None, None
|
||||
|
||||
# Clip the main gradients.
|
||||
timers('optimizer-clip-main-grad', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
grad_norm = None
|
||||
if self.clip_grad > 0.0:
|
||||
grad_norm = self.clip_grad_norm(self.clip_grad)
|
||||
timers('optimizer-clip-main-grad').stop()
|
||||
|
||||
# Count the zeros in the grads.
|
||||
timers('optimizer-count-zeros', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
num_zeros_in_grad = self.count_zeros() if \
|
||||
self.log_num_zeros_in_grad else None
|
||||
timers('optimizer-count-zeros').stop()
|
||||
|
||||
# Step the optimizer.
|
||||
timers('optimizer-inner-step', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
self.optimizer.step()
|
||||
timers('optimizer-inner-step').stop()
|
||||
|
||||
# Update params from main params.
|
||||
timers('optimizer-copy-main-to-model-params', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
self._copy_main_params_to_model_params()
|
||||
timers('optimizer-copy-main-to-model-params').stop()
|
||||
|
||||
# Successful update.
|
||||
return True, grad_norm, num_zeros_in_grad
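    # ------------------------------------------------------------------
    # Illustrative sketch (hypothetical caller, not part of this file): how a
    # training step might consume the tuple returned by step(). On an fp16
    # overflow the update is skipped and the grad scaler has already backed
    # off its loss scale, so the next iteration simply retries.
    # ------------------------------------------------------------------
    def _sketch_apply_step(optimizer, args, timers):
        update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
        skipped_iter = 0 if update_successful else 1
        return skipped_iter, grad_norm, num_zeros_in_grad
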
|
||||
|
||||
|
||||
class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
@ -189,11 +486,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
|
||||
params_have_main_grad, bf16, grad_scaler):
|
||||
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
|
||||
fp16, bf16, params_dtype, grad_scaler, models):
|
||||
|
||||
super(Float16OptimizerWithFloat16Params, self).__init__(
|
||||
optimizer, clip_grad, log_num_zeros_in_grad,
|
||||
params_have_main_grad)
|
||||
params_have_main_grad, use_contiguous_buffers_in_local_ddp, models)
|
||||
|
||||
self.bf16 = bf16
|
||||
self.grad_scaler = grad_scaler
|
||||
@ -240,18 +538,11 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
# For all the parameters in this group:
|
||||
for i, param in enumerate(param_group['params']):
|
||||
if param.requires_grad:
|
||||
if param.type() == "torch.cuda.HalfTensor":
|
||||
param_type = "torch.npu.HalfTensor"
|
||||
elif param.type() == "torch.cuda.BFloat16Tensor":
|
||||
param_type = "torch.npu.BFloat16Tensor"
|
||||
elif param.type() == "torch.cuda.FloatTensor":
|
||||
param_type = "torch.npu.FloatTensor"
|
||||
else:
|
||||
param_type = param.type()
|
||||
param_type = param.type().replace('cuda', get_accelerator().device_name())
|
||||
|
||||
# float16 params:
|
||||
if param_type in ['torch.{}.HalfTensor'.format(get_accelerator().device_name()),
|
||||
'torch.{}.BFloat16Tensor'.format(get_accelerator().device_name())]:
|
||||
'torch.{}.BFloat16Tensor'.format(get_accelerator().device_name())]:
|
||||
float16_params_this_group.append(param)
|
||||
# Create a copy
|
||||
main_param = param.detach().clone().float()
|
||||
@ -279,7 +570,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
'torch.{}.FloatTensor, '
|
||||
'torch.{}.HalfTensor, or '
|
||||
'torch.{}.BFloat16Tensor. '
|
||||
'Received {}'.format(device_name,device_name,device_name,param.type()))
|
||||
'Received {}'.format(device_name, device_name, device_name, param.type()))
|
||||
|
||||
self.float16_groups.append(float16_params_this_group)
|
||||
self.fp32_from_float16_groups.append(
|
||||
@ -290,7 +581,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
# recast preexisting per-param state tensors
|
||||
self.optimizer.load_state_dict(self.optimizer.state_dict())
|
||||
|
||||
|
||||
def zero_grad(self, set_to_none=True):
|
||||
"""We only need to zero the model related parameters, i.e.,
|
||||
float16_groups & fp32_from_fp32_groups."""
|
||||
@ -299,13 +589,11 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
for group in self.fp32_from_fp32_groups:
|
||||
_zero_grad_group_helper(group, set_to_none)
|
||||
|
||||
|
||||
def get_loss_scale(self):
|
||||
if self.grad_scaler is None:
|
||||
return self._scale_one
|
||||
return self.grad_scaler.scale
|
||||
|
||||
|
||||
def _copy_model_grads_to_main_grads(self):
|
||||
# This only needs to be done for the float16 group.
|
||||
for model_group, main_group in zip(self.float16_groups,
|
||||
@ -323,7 +611,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
for model_param in model_group:
|
||||
model_param.grad = model_param.main_grad
|
||||
|
||||
|
||||
def _unscale_main_grads_and_check_for_nan(self):
|
||||
main_grads = []
|
||||
# fp32 params from float16 ones.
|
||||
@ -350,7 +637,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
found_inf_flag = (self.found_inf.item() > 0)
|
||||
return found_inf_flag
|
||||
|
||||
|
||||
def _get_model_and_main_params_data_float16(self):
|
||||
model_data = []
|
||||
main_data = []
|
||||
@ -361,27 +647,23 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
main_data.append(main_param.data)
|
||||
return model_data, main_data
|
||||
|
||||
|
||||
def _copy_main_params_to_model_params(self):
|
||||
# Only needed for the float16 params.
|
||||
model_data, main_data = self._get_model_and_main_params_data_float16()
|
||||
_multi_tensor_copy_this_to_that(this=main_data, that=model_data,
|
||||
overflow_buf=self._dummy_overflow_buf)
|
||||
|
||||
|
||||
def _copy_model_params_to_main_params(self):
|
||||
# Only needed for the float16 params.
|
||||
model_data, main_data = self._get_model_and_main_params_data_float16()
|
||||
_multi_tensor_copy_this_to_that(this=model_data, that=main_data,
|
||||
overflow_buf=self._dummy_overflow_buf)
|
||||
|
||||
|
||||
def reload_model_params(self):
|
||||
self._copy_model_params_to_main_params()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
def step(self, args, timers):
|
||||
|
||||
timers = get_timers()
|
||||
|
||||
@ -416,7 +698,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
|
||||
# count the zeros in the grads
|
||||
num_zeros_in_grad = self.count_zeros() if \
|
||||
self.log_num_zeros_in_grad else None
|
||||
self.log_num_zeros_in_grad else None
|
||||
|
||||
# Step the optimizer.
|
||||
self.optimizer.step()
|
||||
@ -429,7 +711,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
# Successful update.
|
||||
return True, grad_norm, num_zeros_in_grad
|
||||
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {}
|
||||
state_dict['optimizer'] = self.optimizer.state_dict()
|
||||
@ -438,7 +719,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
# Optimizer.
|
||||
optimizer_key = 'optimizer'
|
||||
@ -471,7 +751,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
|
||||
current_param.data.copy_(saved_param.data)
|
||||
|
||||
|
||||
|
||||
class FP32Optimizer(MegatronOptimizer):
|
||||
|
||||
def __init__(self, optimizer, clip_grad,
|
||||
@ -484,18 +763,15 @@ class FP32Optimizer(MegatronOptimizer):
|
||||
|
||||
self._scale = get_accelerator().FloatTensor([1.0])
|
||||
|
||||
|
||||
def zero_grad(self, set_to_none=True):
|
||||
"""Copied from torch.optim.optimizer"""
|
||||
for group in self.optimizer.param_groups:
|
||||
_zero_grad_group_helper(group['params'], set_to_none)
|
||||
|
||||
|
||||
def get_loss_scale(self):
|
||||
"""FP32 optimizer does not do any scaling."""
|
||||
return self._scale
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
"""Clip gradients (if needed) and step the base optimizer.
|
||||
@ -514,7 +790,7 @@ class FP32Optimizer(MegatronOptimizer):
|
||||
|
||||
# count the zeros in the grads
|
||||
num_zeros_in_grad = self.count_zeros() if \
|
||||
self.log_num_zeros_in_grad else None
|
||||
self.log_num_zeros_in_grad else None
|
||||
|
||||
# Update parameters.
|
||||
self.optimizer.step()
|
||||
@ -522,14 +798,11 @@ class FP32Optimizer(MegatronOptimizer):
|
||||
# No overflow for FP32 optimizer.
|
||||
return True, grad_norm, num_zeros_in_grad
|
||||
|
||||
|
||||
def reload_model_params(self):
|
||||
pass
|
||||
|
||||
|
||||
def state_dict(self):
|
||||
return self.optimizer.state_dict()
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.optimizer.load_state_dict(state_dict)
|
||||
|
@ -1,302 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import reduce
|
||||
import operator
|
||||
import torch
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed import mpu
|
||||
from ascendspeed.core import parallel_state
|
||||
|
||||
|
||||
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
|
||||
use_ring_exchange=False,
|
||||
recv_tensor_shape=None):
|
||||
"""Communicate tensors between stages. Used as helper method in other
|
||||
communication methods that are used in ascendspeed/schedules.py.
|
||||
|
||||
Takes the following arguments:
|
||||
tensor_send_next: tensor to send to next rank (no tensor sent if
|
||||
set to None).
|
||||
tensor_send_prev: tensor to send to prev rank (no tensor sent if
|
||||
set to None).
|
||||
recv_prev: boolean for whether tensor should be received from
|
||||
previous rank.
|
||||
recv_next: boolean for whether tensor should be received from
|
||||
next rank.
|
||||
use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
|
||||
API should be used.
|
||||
recv_tensor_shape: shape of received tensor. This can be useful when using
|
||||
optimized pipeline parallelism.
|
||||
|
||||
Returns:
|
||||
(tensor_recv_prev, tensor_recv_next)
|
||||
"""
|
||||
args = get_args()
|
||||
|
||||
# Create placeholder tensors for receive in forward and backward directions
|
||||
# if needed.
|
||||
tensor_recv_prev = None
|
||||
tensor_recv_next = None
|
||||
|
||||
if args.optimized_pipeline and (recv_prev or recv_next) and (recv_tensor_shape is None):
|
||||
raise ValueError('recv_tensor_shape has to be provided for optimized pipeline.')
|
||||
|
||||
tensor_shape = recv_tensor_shape if args.optimized_pipeline and (recv_prev or recv_next) \
|
||||
else (args.seq_length, args.micro_batch_size, args.hidden_size)
|
||||
|
||||
if args.sequence_parallel:
|
||||
seq_length = args.seq_length // parallel_state.get_tensor_model_parallel_world_size()
|
||||
tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)
|
||||
|
||||
if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
|
||||
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
|
||||
parallel_state.get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
tensor_chunk_shape = tensor_shape
|
||||
dtype = args.params_dtype
|
||||
if args.fp32_residual_connection:
|
||||
dtype = torch.float
|
||||
if recv_prev:
|
||||
tensor_recv_prev = torch.empty(tensor_chunk_shape,
|
||||
requires_grad=True,
|
||||
device=get_accelerator().current_device_name(),
|
||||
dtype=dtype)
|
||||
if recv_next:
|
||||
tensor_recv_next = torch.empty(tensor_chunk_shape,
|
||||
requires_grad=True,
|
||||
device=get_accelerator().current_device_name(),
|
||||
dtype=dtype)
|
||||
|
||||
# Split tensor into smaller chunks if using scatter-gather optimization.
|
||||
if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
|
||||
if tensor_send_next is not None:
|
||||
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
|
||||
|
||||
if tensor_send_prev is not None:
|
||||
tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
|
||||
|
||||
# Send tensors in both the forward and backward directions as appropriate.
|
||||
if use_ring_exchange:
|
||||
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
|
||||
tensor_recv_prev=tensor_recv_prev,
|
||||
tensor_send_next=tensor_send_next,
|
||||
tensor_recv_next=tensor_recv_next,
|
||||
group=parallel_state.get_pipeline_model_parallel_group())
|
||||
else:
|
||||
ops = []
|
||||
if tensor_send_prev is not None:
|
||||
send_prev_op = torch.distributed.P2POp(
|
||||
torch.distributed.isend, tensor_send_prev,
|
||||
parallel_state.get_pipeline_model_parallel_prev_rank())
|
||||
ops.append(send_prev_op)
|
||||
if tensor_recv_prev is not None:
|
||||
recv_prev_op = torch.distributed.P2POp(
|
||||
torch.distributed.irecv, tensor_recv_prev,
|
||||
parallel_state.get_pipeline_model_parallel_prev_rank())
|
||||
ops.append(recv_prev_op)
|
||||
|
||||
if args.num_layers_per_virtual_pipeline_stage is None:
|
||||
# pp
|
||||
if tensor_recv_next is not None:
|
||||
recv_next_op = torch.distributed.P2POp(
|
||||
torch.distributed.irecv, tensor_recv_next,
|
||||
parallel_state.get_pipeline_model_parallel_next_rank())
|
||||
ops.append(recv_next_op)
|
||||
if tensor_send_next is not None:
|
||||
send_next_op = torch.distributed.P2POp(
|
||||
torch.distributed.isend, tensor_send_next,
|
||||
parallel_state.get_pipeline_model_parallel_next_rank())
|
||||
ops.append(send_next_op)
|
||||
else:
|
||||
# vp
|
||||
if tensor_send_next is not None:
|
||||
send_next_op = torch.distributed.P2POp(
|
||||
torch.distributed.isend, tensor_send_next,
|
||||
parallel_state.get_pipeline_model_parallel_next_rank())
|
||||
ops.append(send_next_op)
|
||||
if tensor_recv_next is not None:
|
||||
recv_next_op = torch.distributed.P2POp(
|
||||
torch.distributed.irecv, tensor_recv_next,
|
||||
parallel_state.get_pipeline_model_parallel_next_rank())
|
||||
ops.append(recv_next_op)
|
||||
|
||||
if (args.num_layers_per_virtual_pipeline_stage is not None) \
|
||||
and (parallel_state.get_pipeline_model_parallel_rank() % 2 == 1):
|
||||
ops.reverse()
|
||||
|
||||
if len(ops) > 0:
|
||||
reqs = torch.distributed.batch_isend_irecv(ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
# To protect against race condition when using batch_isend_irecv().
|
||||
get_accelerator().synchronize()
|
||||
|
||||
# If using scatter-gather optimization, gather smaller chunks.
|
||||
if args.scatter_gather_tensors_in_pipeline and not args.sequence_parallel:
|
||||
if recv_prev:
|
||||
tensor_recv_prev = mpu.gather_split_1d_tensor(
|
||||
tensor_recv_prev).view(tensor_shape).requires_grad_()
|
||||
|
||||
if recv_next:
|
||||
tensor_recv_next = mpu.gather_split_1d_tensor(
|
||||
tensor_recv_next).view(tensor_shape).requires_grad_()
|
||||
|
||||
return tensor_recv_prev, tensor_recv_next
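# Illustrative sketch (hypothetical, standalone): the batched P2P pattern used
# in _communicate() above, reduced to a single exchange with the previous
# pipeline rank.
def _sketch_exchange_with_prev(send_tensor, recv_tensor, prev_rank):
    ops = [
        torch.distributed.P2POp(torch.distributed.isend, send_tensor, prev_rank),
        torch.distributed.P2POp(torch.distributed.irecv, recv_tensor, prev_rank),
    ]
    for req in torch.distributed.batch_isend_irecv(ops):
        req.wait()
    # A device synchronize guards against the race condition noted above when
    # using batch_isend_irecv().
    get_accelerator().synchronize()
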
|
||||
|
||||
|
||||
def recv_forward(timers=None, recv_tensor_shape=None):
|
||||
"""Receive tensor from previous rank in pipeline (forward receive)."""
|
||||
if parallel_state.is_pipeline_first_stage():
|
||||
input_tensor = None
|
||||
else:
|
||||
if timers is not None:
|
||||
timers('forward-recv', log_level=2).start()
|
||||
input_tensor, _ = _communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=True,
|
||||
recv_next=False,
|
||||
recv_tensor_shape=recv_tensor_shape)
|
||||
if timers is not None:
|
||||
timers('forward-recv').stop()
|
||||
return input_tensor
|
||||
|
||||
|
||||
def recv_backward(timers=None, recv_tensor_shape=None):
|
||||
"""Receive tensor from next rank in pipeline (backward receive)."""
|
||||
if parallel_state.is_pipeline_last_stage():
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
if timers is not None:
|
||||
timers('backward-recv', log_level=2).start()
|
||||
_, output_tensor_grad = _communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=False,
|
||||
recv_next=True,
|
||||
recv_tensor_shape=recv_tensor_shape)
|
||||
if timers is not None:
|
||||
timers('backward-recv').stop()
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_forward(output_tensor, timers=None):
|
||||
"""Send tensor to next rank in pipeline (forward send)."""
|
||||
if not parallel_state.is_pipeline_last_stage():
|
||||
if timers is not None:
|
||||
timers('forward-send', log_level=2).start()
|
||||
_communicate(
|
||||
tensor_send_next=output_tensor,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=False,
|
||||
recv_next=False)
|
||||
if timers is not None:
|
||||
timers('forward-send').stop()
|
||||
|
||||
|
||||
def send_backward(input_tensor_grad, timers=None):
|
||||
"""Send tensor to previous rank in pipeline (backward send)."""
|
||||
if not parallel_state.is_pipeline_first_stage():
|
||||
if timers is not None:
|
||||
timers('backward-send', log_level=2).start()
|
||||
_communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=False,
|
||||
recv_next=False)
|
||||
if timers is not None:
|
||||
timers('backward-send').stop()
|
||||
|
||||
|
||||
def send_forward_recv_backward(output_tensor, timers=None, recv_tensor_shape=None):
|
||||
"""Batched send and recv with next rank in pipeline."""
|
||||
if parallel_state.is_pipeline_last_stage():
|
||||
output_tensor_grad = None
|
||||
else:
|
||||
if timers is not None:
|
||||
timers('forward-send-backward-recv', log_level=2).start()
|
||||
_, output_tensor_grad = _communicate(
|
||||
tensor_send_next=output_tensor,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=False,
|
||||
recv_next=True,
|
||||
recv_tensor_shape=recv_tensor_shape)
|
||||
if timers is not None:
|
||||
timers('forward-send-backward-recv').stop()
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_backward_recv_forward(input_tensor_grad, timers=None, recv_tensor_shape=None):
|
||||
"""Batched send and recv with previous rank in pipeline."""
|
||||
if parallel_state.is_pipeline_first_stage():
|
||||
input_tensor = None
|
||||
else:
|
||||
if timers is not None:
|
||||
timers('backward-send-forward-recv', log_level=2).start()
|
||||
input_tensor, _ = _communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=True,
|
||||
recv_next=False,
|
||||
recv_tensor_shape=recv_tensor_shape)
|
||||
if timers is not None:
|
||||
timers('backward-send-forward-recv').stop()
|
||||
return input_tensor
|
||||
|
||||
|
||||
def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
|
||||
"""Batched recv from previous rank and send to next rank in pipeline."""
|
||||
if timers is not None:
|
||||
timers('forward-send-forward-recv', log_level=2).start()
|
||||
input_tensor, _ = _communicate(
|
||||
tensor_send_next=output_tensor,
|
||||
tensor_send_prev=None,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=False)
|
||||
if timers is not None:
|
||||
timers('forward-send-forward-recv').stop()
|
||||
return input_tensor
|
||||
|
||||
|
||||
def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
|
||||
"""Batched recv from next rank and send to previous rank in pipeline."""
|
||||
if timers is not None:
|
||||
timers('backward-send-backward-recv', log_level=2).start()
|
||||
_, output_tensor_grad = _communicate(
|
||||
tensor_send_next=None,
|
||||
tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=False,
|
||||
recv_next=recv_next)
|
||||
if timers is not None:
|
||||
timers('backward-send-backward-recv').stop()
|
||||
return output_tensor_grad
|
||||
|
||||
|
||||
def send_forward_backward_recv_forward_backward(
|
||||
output_tensor, input_tensor_grad, recv_prev,
|
||||
recv_next, timers=None):
|
||||
"""Batched send and recv with previous and next ranks in pipeline."""
|
||||
if timers is not None:
|
||||
timers('forward-backward-send-forward-backward-recv', log_level=2).start()
|
||||
input_tensor, output_tensor_grad = _communicate(
|
||||
tensor_send_next=output_tensor,
|
||||
tensor_send_prev=input_tensor_grad,
|
||||
recv_prev=recv_prev,
|
||||
recv_next=recv_next)
|
||||
if timers is not None:
|
||||
timers('forward-backward-send-forward-backward-recv').stop()
|
||||
return input_tensor, output_tensor_grad
|
@ -1,620 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
|
||||
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed import get_num_microbatches
|
||||
from ascendspeed import get_timers
|
||||
from ascendspeed import p2p_communication
|
||||
from ascendspeed import print_rank_0
|
||||
from ascendspeed.core import parallel_state
|
||||
from ascendspeed.utils import unwrap_model
|
||||
from ascendspeed.model import DistributedDataParallel as LocalDDP
|
||||
from ascendspeed.model import Float16Module
|
||||
|
||||
|
||||
def clear_npu_overflow_flag():
|
||||
float_status = torch.zeros(8).npu() # 8 bit for overflow
|
||||
result = torch_npu.npu_clear_float_status(float_status)
|
||||
|
||||
|
||||
def get_npu_overflow_flag():
|
||||
float_status = torch.zeros(8).npu() # 8 bit for overflow
|
||||
result = torch_npu.npu_get_float_status(float_status)
|
||||
if float_status.cpu()[0] != 0:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def set_npu_overflow_flag():
|
||||
torch.tensor([65504]).half().npu() + 100 # fp16 overflow flag
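# Illustrative sketch (hypothetical helper): how the three overflow-flag
# helpers above are combined around a sequence of backward passes, mirroring
# the pattern used in forward_backward_no_pipelining() below.
def _sketch_run_backwards_with_overflow_check(backward_closures):
    overflow_flag_all = False
    for run_backward in backward_closures:
        clear_npu_overflow_flag()   # reset the hardware flag before each backward
        run_backward()
        overflow_flag_all = get_npu_overflow_flag() or overflow_flag_all
    if overflow_flag_all:
        set_npu_overflow_flag()     # re-raise so the optimizer skips this step
    return overflow_flag_all
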
|
||||
|
||||
|
||||
def get_forward_backward_func():
|
||||
args = get_args()
|
||||
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
|
||||
if args.virtual_pipeline_model_parallel_size is not None:
|
||||
forward_backward_func = forward_backward_pipelining_with_interleaving
|
||||
else:
|
||||
forward_backward_func = forward_backward_pipelining_without_interleaving
|
||||
else:
|
||||
forward_backward_func = forward_backward_no_pipelining
|
||||
return forward_backward_func
|
||||
|
||||
|
||||
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
|
||||
"""Forward step for passed-in model.
|
||||
|
||||
If first stage, input tensor is obtained from data_iterator, otherwise
|
||||
passed-in input_tensor is used.
|
||||
|
||||
Returns output tensor.
|
||||
"""
|
||||
timers = get_timers()
|
||||
|
||||
args = get_args()
|
||||
|
||||
timers('forward-compute', log_level=2).start()
|
||||
unwrapped_model = unwrap_model(
|
||||
model, (torchDDP, LocalDDP, Float16Module))
|
||||
if not args.deepspeed:
|
||||
unwrapped_model.set_input_tensor(input_tensor)
|
||||
else:
|
||||
unwrapped_model.module.set_input_tensor(input_tensor)
|
||||
|
||||
# Note: it's recommended to NOT add any new argument to forward_step_func()
|
||||
# because it is an abstract API used by many different models and tasks.
|
||||
# Changing this API requires changing it in all models/tasks. Instead,
|
||||
# it's recommended to use args to pass additional arguments.
|
||||
output_tensor, loss_func = forward_step_func(data_iterator, model)
|
||||
if parallel_state.is_pipeline_last_stage():
|
||||
output_tensor = loss_func(output_tensor)
|
||||
loss, loss_reduced = output_tensor
|
||||
if not args.no_pipeline_parallel:
|
||||
output_tensor = loss / get_num_microbatches()
|
||||
else:
|
||||
output_tensor = loss
|
||||
losses_reduced.append(loss_reduced)
|
||||
timers('forward-compute').stop()
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, model=None):
|
||||
"""Backward step through passed-in output tensor.
|
||||
|
||||
If last stage, output_tensor_grad is None, otherwise gradient of loss
|
||||
with respect to stage's output tensor.
|
||||
|
||||
Returns gradient of loss with respect to input tensor (None if first
|
||||
stage)."""
|
||||
args = get_args()
|
||||
|
||||
if args.deepspeed:
|
||||
assert model is not None
|
||||
|
||||
timers = get_timers()
|
||||
timers('backward-compute', log_level=2).start()
|
||||
|
||||
# Retain the grad on the input_tensor.
|
||||
if input_tensor is not None:
|
||||
input_tensor.retain_grad()
|
||||
|
||||
clear_npu_overflow_flag()
|
||||
if args.deepspeed:
|
||||
model.backward(output_tensor)
|
||||
else:
|
||||
# Backward pass.
|
||||
if output_tensor_grad is None:
|
||||
output_tensor = optimizer.scale_loss(output_tensor)
|
||||
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
|
||||
|
||||
# Collect the grad of the input_tensor.
|
||||
input_tensor_grad = None
|
||||
if input_tensor is not None:
|
||||
input_tensor_grad = input_tensor.grad
|
||||
|
||||
timers('backward-compute').stop()
|
||||
|
||||
return input_tensor_grad
|
||||
|
||||
@contextmanager
|
||||
def dummy_handler():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
|
||||
optimizer, timers, forward_only):
|
||||
"""Run forward and backward passes with no pipeline parallelism
|
||||
(no inter-stage communication).
|
||||
|
||||
Returns dictionary with losses."""
|
||||
assert len(model) == 1
|
||||
model = model[0]
|
||||
|
||||
context_handler = dummy_handler
|
||||
if isinstance(model, torchDDP):
|
||||
context_handler = model.no_sync
|
||||
|
||||
losses_reduced = []
|
||||
input_tensor, output_tensor_grad = None, None
|
||||
overflow_flag_all = False
|
||||
with context_handler():
|
||||
for i in range(get_num_microbatches() - 1):
|
||||
output_tensor = forward_step(forward_step_func, data_iterator, model,
|
||||
input_tensor, losses_reduced)
|
||||
if not forward_only:
|
||||
backward_step(optimizer, input_tensor, output_tensor,
|
||||
output_tensor_grad)
|
||||
|
||||
overflow_flag = get_npu_overflow_flag()
|
||||
overflow_flag_all = overflow_flag or overflow_flag_all
|
||||
output_tensor = forward_step(forward_step_func, data_iterator, model,
|
||||
input_tensor, losses_reduced)
|
||||
if not forward_only:
|
||||
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
|
||||
|
||||
overflow_flag = get_npu_overflow_flag()
|
||||
overflow_flag_all = overflow_flag or overflow_flag_all
|
||||
|
||||
if overflow_flag_all:
|
||||
set_npu_overflow_flag()
|
||||
return losses_reduced
|
||||
|
||||
|
||||
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
|
||||
optimizer, timers, forward_only):
|
||||
"""Run interleaved 1F1B schedule (model split into model chunks), with
|
||||
communication between pipeline stages as needed.
|
||||
|
||||
Returns dictionary with losses if the last stage, empty dict otherwise."""
|
||||
|
||||
input_tensors = [[] for _ in range(len(model))]
|
||||
output_tensors = [[] for _ in range(len(model))]
|
||||
losses_reduced = []
|
||||
if not forward_only:
|
||||
output_tensor_grads = [[] for _ in range(len(model))]
|
||||
|
||||
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
|
||||
pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
|
||||
|
||||
# Compute number of warmup and remaining microbatches.
|
||||
num_model_chunks = len(model)
|
||||
num_microbatches = get_num_microbatches() * num_model_chunks
|
||||
all_warmup_microbatches = False
|
||||
if forward_only:
|
||||
num_warmup_microbatches = num_microbatches
|
||||
else:
|
||||
# Run all forward passes and then all backward passes if number of
|
||||
# microbatches is just the number of pipeline stages.
|
||||
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
|
||||
# all workers, followed by more microbatches after depending on
|
||||
# stage ID (more forward passes for earlier stages, later stages can
|
||||
# immediately start with 1F1B).
|
||||
if get_num_microbatches() == pipeline_parallel_size:
|
||||
num_warmup_microbatches = num_microbatches
|
||||
all_warmup_microbatches = True
|
||||
else:
|
||||
num_warmup_microbatches = \
|
||||
(pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
|
||||
num_warmup_microbatches += (
|
||||
num_model_chunks - 1) * pipeline_parallel_size
|
||||
num_warmup_microbatches = min(num_warmup_microbatches,
|
||||
num_microbatches)
|
||||
num_microbatches_remaining = \
|
||||
num_microbatches - num_warmup_microbatches
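    # Illustrative worked example (comments only, applying the formulas above):
    # with pipeline_parallel_size = 4, num_model_chunks = 2 and
    # get_num_microbatches() = 8, num_microbatches = 16. For pipeline rank 1:
    #   num_warmup_microbatches = (4 - 1 - 1) * 2 + (2 - 1) * 4 = 8
    #   num_microbatches_remaining = 16 - 8 = 8
    # The last rank (rank 3) warms up with only (2 - 1) * 4 = 4 microbatches
    # before entering steady-state 1F1B.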
|
||||
|
||||
def get_model_chunk_id(microbatch_id, forward):
|
||||
"""Helper method to get the model chunk ID given the iteration number."""
|
||||
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
|
||||
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
|
||||
if not forward:
|
||||
model_chunk_id = (num_model_chunks - model_chunk_id - 1)
|
||||
return model_chunk_id
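    # Illustrative example (comments only): with pipeline_parallel_size = 4 and
    # num_model_chunks = 2, microbatch 9 falls at position 9 % 8 = 1 inside its
    # group, so it maps to model chunk 1 // 4 = 0 in the forward direction and
    # to chunk (2 - 0 - 1) = 1 in the backward direction.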
|
||||
|
||||
def forward_step_helper(microbatch_id):
|
||||
"""Helper method to run forward step with model split into chunks
|
||||
(run set_virtual_pipeline_model_parallel_rank() before calling
|
||||
forward_step())."""
|
||||
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
|
||||
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
|
||||
|
||||
if parallel_state.is_pipeline_first_stage():
|
||||
if len(input_tensors[model_chunk_id]) == \
|
||||
len(output_tensors[model_chunk_id]):
|
||||
input_tensors[model_chunk_id].append(None)
|
||||
input_tensor = input_tensors[model_chunk_id][-1]
|
||||
output_tensor = forward_step(forward_step_func,
|
||||
data_iterator[model_chunk_id],
|
||||
model[model_chunk_id],
|
||||
input_tensor, losses_reduced)
|
||||
output_tensors[model_chunk_id].append(output_tensor)
|
||||
|
||||
return output_tensor
|
||||
|
||||
def backward_step_helper(microbatch_id):
|
||||
"""Helper method to run backward step with model split into chunks
|
||||
(run set_virtual_pipeline_model_parallel_rank() before calling
|
||||
backward_step())."""
|
||||
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
|
||||
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
|
||||
|
||||
if parallel_state.is_pipeline_last_stage():
|
||||
if len(output_tensor_grads[model_chunk_id]) == 0:
|
||||
output_tensor_grads[model_chunk_id].append(None)
|
||||
input_tensor = input_tensors[model_chunk_id].pop(0)
|
||||
output_tensor = output_tensors[model_chunk_id].pop(0)
|
||||
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
|
||||
input_tensor_grad = \
|
||||
backward_step(optimizer,
|
||||
input_tensor,
|
||||
output_tensor,
|
||||
output_tensor_grad)
|
||||
|
||||
return input_tensor_grad
|
||||
|
||||
# Run warmup forward passes.
|
||||
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
|
||||
input_tensors[0].append(
|
||||
p2p_communication.recv_forward(timers))
|
||||
for k in range(num_warmup_microbatches):
|
||||
output_tensor = forward_step_helper(k)
|
||||
|
||||
# Determine if tensor should be received from previous stage.
|
||||
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
|
||||
recv_prev = True
|
||||
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
|
||||
if next_forward_model_chunk_id == 0:
|
||||
recv_prev = False
|
||||
if k == (num_microbatches - 1):
|
||||
recv_prev = False
|
||||
|
||||
# Don't send tensor downstream if on last stage.
|
||||
if parallel_state.is_pipeline_last_stage():
|
||||
output_tensor = None
|
||||
|
||||
# Send and receive tensors as appropriate (send tensors computed
|
||||
# in this iteration; receive tensors for next iteration).
|
||||
if k == (num_warmup_microbatches - 1) and not forward_only and \
|
||||
not all_warmup_microbatches:
|
||||
input_tensor_grad = None
|
||||
recv_next = True
|
||||
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
|
||||
recv_next = False
|
||||
input_tensor, output_tensor_grad = \
|
||||
p2p_communication.send_forward_backward_recv_forward_backward(
|
||||
output_tensor, input_tensor_grad,
|
||||
recv_prev=recv_prev, recv_next=recv_next,
|
||||
timers=timers)
|
||||
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
|
||||
else:
|
||||
input_tensor = \
|
||||
p2p_communication.send_forward_recv_forward(
|
||||
output_tensor, recv_prev, timers)
|
||||
input_tensors[next_forward_model_chunk_id].append(input_tensor)
|
||||
|
||||
# Run 1F1B in steady state.
|
||||
for k in range(num_microbatches_remaining):
|
||||
# Forward pass.
|
||||
forward_k = k + num_warmup_microbatches
|
||||
output_tensor = forward_step_helper(forward_k)
|
||||
|
||||
# Backward pass.
|
||||
backward_k = k
|
||||
input_tensor_grad = backward_step_helper(backward_k)
|
||||
|
||||
# Send output_tensor and input_tensor_grad, receive input_tensor
|
||||
# and output_tensor_grad.
|
||||
|
||||
# Determine if current stage has anything to send in either direction,
|
||||
# otherwise set tensor to None.
|
||||
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
|
||||
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
|
||||
if parallel_state.is_pipeline_last_stage():
|
||||
output_tensor = None
|
||||
|
||||
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
|
||||
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
|
||||
if parallel_state.is_pipeline_first_stage():
|
||||
input_tensor_grad = None
|
||||
|
||||
# Determine if peers are sending, and where in data structure to put
|
||||
# received tensors.
|
||||
recv_prev = True
|
||||
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
|
||||
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
|
||||
next_forward_model_chunk_id = get_model_chunk_id(
|
||||
forward_k - (pipeline_parallel_size - 1), forward=True)
|
||||
if next_forward_model_chunk_id == (num_model_chunks - 1):
|
||||
recv_prev = False
|
||||
next_forward_model_chunk_id += 1
|
||||
else:
|
||||
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
|
||||
forward=True)
|
||||
|
||||
recv_next = True
|
||||
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
|
||||
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
|
||||
next_backward_model_chunk_id = get_model_chunk_id(
|
||||
backward_k - (pipeline_parallel_size - 1), forward=False)
|
||||
if next_backward_model_chunk_id == 0:
|
||||
recv_next = False
|
||||
next_backward_model_chunk_id -= 1
|
||||
else:
|
||||
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
|
||||
forward=False)
|
||||
|
||||
# If last iteration, don't receive; we already received one extra
|
||||
# before the start of the for loop.
|
||||
if k == (num_microbatches_remaining - 1):
|
||||
recv_prev = False
|
||||
|
||||
# Communicate tensors.
|
||||
input_tensor, output_tensor_grad = \
|
||||
p2p_communication.send_forward_backward_recv_forward_backward(
|
||||
output_tensor, input_tensor_grad,
|
||||
recv_prev=recv_prev, recv_next=recv_next,
|
||||
timers=timers)
|
||||
|
||||
# Put input_tensor and output_tensor_grad in data structures in the
|
||||
# right location.
|
||||
if recv_prev:
|
||||
input_tensors[next_forward_model_chunk_id].append(input_tensor)
|
||||
if recv_next:
|
||||
output_tensor_grads[next_backward_model_chunk_id].append(
|
||||
output_tensor_grad)
|
||||
|
||||
# Run cooldown backward passes (flush out pipeline).
|
||||
if not forward_only:
|
||||
if all_warmup_microbatches:
|
||||
output_tensor_grads[num_model_chunks-1].append(
|
||||
p2p_communication.recv_backward(timers))
|
||||
for k in range(num_microbatches_remaining, num_microbatches):
|
||||
input_tensor_grad = backward_step_helper(k)
|
||||
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
|
||||
recv_next = True
|
||||
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
|
||||
if next_backward_model_chunk_id == (num_model_chunks - 1):
|
||||
recv_next = False
|
||||
if k == (num_microbatches - 1):
|
||||
recv_next = False
|
||||
output_tensor_grads[next_backward_model_chunk_id].append(
|
||||
p2p_communication.send_backward_recv_backward(
|
||||
input_tensor_grad, recv_next, timers))
|
||||
|
||||
return losses_reduced
|
||||
|
||||
|
||||
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""

    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (parallel_state.get_pipeline_model_parallel_world_size() -
         parallel_state.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad, model)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad, model)

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced

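# Illustrative sketch (not part of the original file): how the 1F1B schedule above
# splits work into warmup / steady-state / cooldown phases, for assumed example
# values pipeline_parallel_size=4 and num_microbatches=8.
def _sketch_1f1b_phase_sizes(pipeline_parallel_size, pipeline_rank, num_microbatches):
    num_warmup = min(pipeline_parallel_size - pipeline_rank - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup
    num_cooldown = num_warmup  # each warmup forward gets a matching cooldown backward
    return num_warmup, num_steady, num_cooldown

# _sketch_1f1b_phase_sizes(4, 0, 8) == (3, 5, 3)  -> the first stage warms up the most
# _sketch_1f1b_phase_sizes(4, 3, 8) == (0, 8, 0)  -> the last stage runs pure 1F1B
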
def get_tensor_shapes():
    args = get_args()
    tensor_shapes = []
    mbs = args.manual_mbs
    for m in mbs:
        tensor_shapes.append((args.seq_length, m, args.hidden_size))

    return tensor_shapes

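# Illustrative sketch (not part of the original file), using assumed example values:
# with seq_length=2048, hidden_size=4096 and args.manual_mbs = [2, 4, 2],
# get_tensor_shapes() yields one (seq_length, micro_batch_size, hidden_size) tuple
# per microbatch, which the optimized schedule below passes to the p2p helpers as
# recv_tensor_shape.
_example_tensor_shapes = [(2048, m, 4096) for m in (2, 4, 2)]
# -> [(2048, 2, 4096), (2048, 4, 4096), (2048, 2, 4096)]
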
def optimized_forward_backward_pipelining(forward_step_func, data_iterator,
                                          model, optimizer, timers,
                                          forward_only):
    """Run non-interleaved 1F1B schedule, with reduced pipeline bubble.
    Returns dictionary with losses if the last stage, empty dict otherwise.
    """

    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    tensor_shapes = get_tensor_shapes()
    cnt_fwd, cnt_bwd = 0, 0

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (parallel_state.get_pipeline_model_parallel_world_size() -
         parallel_state.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers=timers,
                                                      recv_tensor_shape=tensor_shapes[cnt_fwd])
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers=timers)
        cnt_fwd += 1
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers=timers,
                                                      recv_tensor_shape=tensor_shapes[cnt_fwd])

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers=timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor, timers=timers,
                                                             recv_tensor_shape=tensor_shapes[cnt_bwd])

        cnt_fwd += 1

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(
                    recv_tensor_shape=tensor_shapes[cnt_fwd], timers=timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad, model)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers=timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers=timers, recv_tensor_shape=tensor_shapes[cnt_fwd])
            cnt_bwd += 1

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers=timers,
                                                                 recv_tensor_shape=tensor_shapes[cnt_bwd])

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad, model)

            p2p_communication.send_backward(input_tensor_grad, timers=timers)

            cnt_bwd += 1

    return losses_reduced
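# Illustrative sketch (not part of the original file): the pipeline bubble the
# schedules above try to shrink. With p pipeline stages and m microbatches, the
# idle fraction of a plain 1F1B schedule is roughly (p - 1) / (m + p - 1), so
# raising the microbatch count, interleaving chunks, or resizing microbatches as
# done here is what recovers utilization.
def _sketch_bubble_fraction(p, m):
    return (p - 1) / (m + p - 1)

# _sketch_bubble_fraction(4, 8)  ~= 0.27
# _sketch_bubble_fraction(4, 64) ~= 0.045
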
@ -14,25 +14,27 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for generating text."""
|
||||
import os
|
||||
import time
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
# These are needed to unwrap the model, would be nice to put these in ascendspeed.utils if possible?
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
|
||||
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed import get_tokenizer
|
||||
from ascendspeed.core import parallel_state
|
||||
|
||||
from ascendspeed.utils import get_ltor_masks_and_position_ids, unwrap_model
|
||||
from ascendspeed.p2p_communication import recv_forward, send_forward
|
||||
from ascendspeed.core.pipeline_parallel.p2p_communication import recv_forward, send_forward
|
||||
|
||||
from ascendspeed.model import DistributedDataParallel as LocalDDP
|
||||
from ascendspeed.model import Float16Module
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from ascendspeed.model.lora_utils import is_enable_lora, get_lora_model_classes
|
||||
from ascendspeed.core.utils import get_attr_wrapped_model, get_model_config, get_model_type
|
||||
|
||||
|
||||
def get_batch(context_tokens):
|
||||
"""Generate batch from context tokens."""
|
||||
@ -44,7 +46,7 @@ def get_batch(context_tokens):
|
||||
# Get the attention mask and postition ids.
|
||||
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
|
||||
tokens,
|
||||
tokenizer.eod,
|
||||
tokenizer.pad_token_id,
|
||||
args.reset_position_ids,
|
||||
args.reset_attention_mask,
|
||||
args.eod_mask_loss)
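# Illustrative sketch (not part of the original change): the kind of causal mask and
# position ids get_ltor_masks_and_position_ids produces, with the reset/eod options
# ignored for brevity. The real helper lives in ascendspeed.utils and also handles
# eod_mask_loss and the reset flags passed above.
import torch

def _sketch_ltor_mask_and_positions(tokens):
    _, seq_length = tokens.size()
    # Lower-triangular "left-to-right" mask; True marks positions that must not attend.
    attention_mask = torch.tril(
        torch.ones((1, seq_length, seq_length), device=tokens.device)
    ).unsqueeze(1) < 0.5
    position_ids = torch.arange(seq_length, dtype=torch.long, device=tokens.device)
    position_ids = position_ids.unsqueeze(0).expand_as(tokens)
    return attention_mask, position_ids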
|
||||
@ -322,14 +324,13 @@ def generate_samples_interactive(model, print_frequency=24):
|
||||
context_count += 1
|
||||
|
||||
|
||||
|
||||
def generate_samples_unconditional(model, latencies=[], model_latencies=[], single_token_latency=[]):
|
||||
|
||||
args = get_args()
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
num_samples = args.num_samples
|
||||
context_tokens = [[tokenizer.eod]
|
||||
context_tokens = [[tokenizer.pad_token_id]
|
||||
for _ in range(args.micro_batch_size)]
|
||||
ctr = 0
|
||||
while True:
|
||||
@ -343,10 +344,7 @@ def generate_samples_unconditional(model, latencies=[], model_latencies=[], sing
|
||||
start_time = time.time()
|
||||
if parallel_state.is_pipeline_last_stage() and \
|
||||
parallel_state.get_tensor_model_parallel_rank() == 0:
|
||||
#if ctr % args.log_interval == 0:
|
||||
# print('Avg s/batch:',
|
||||
# (time.time() - start_time) / min(args.log_interval, ctr + 1))
|
||||
# start_time = time.time()
|
||||
|
||||
length = len(token_stream)
|
||||
token_batch = token_stream[0].cpu().numpy().tolist()
|
||||
length_batch = token_stream[1].cpu().numpy().tolist()
|
||||
@ -397,9 +395,12 @@ def get_token_stream(model, context_tokens, model_latencies=[], single_token_lat
|
||||
args = get_args()
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
context_tokens, context_lengths = pad_batch(context_tokens,
|
||||
tokenizer.eod, args)
|
||||
if hasattr(tokenizer, "eod"):
|
||||
pad_id = tokenizer.eod
|
||||
else:
|
||||
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
|
||||
|
||||
context_tokens, context_lengths = pad_batch(context_tokens, pad_id, args)
|
||||
context_tokens_tensor = get_accelerator().LongTensor(context_tokens)
|
||||
context_length_tensor = get_accelerator().LongTensor(context_lengths)
|
||||
|
||||
@ -419,20 +420,20 @@ def get_token_stream(model, context_tokens, model_latencies=[], single_token_lat
|
||||
|
||||
count = 0
|
||||
|
||||
t0=time.time()
|
||||
for tokens, lengths in batch_token_iterator:
|
||||
t0 = time.time()
|
||||
for tokens, lengths, log_probs in batch_token_iterator:
|
||||
if count > 1:
|
||||
get_accelerator().synchronize()
|
||||
t_elapsed = time.time() - t0
|
||||
single_token_latency.append(t_elapsed)
|
||||
get_accelerator().synchronize()
|
||||
t_elapsed = time.time() - t0
|
||||
single_token_latency.append(t_elapsed)
|
||||
get_accelerator().synchronize()
|
||||
t0=time.time()
|
||||
count+=1
|
||||
t0 = time.time()
|
||||
count += 1
|
||||
context_length += 1
|
||||
if tokens is not None:
|
||||
yield tokens[:, :context_length], lengths
|
||||
yield tokens[:, :context_length], lengths, log_probs
|
||||
else:
|
||||
yield None, None
|
||||
yield None, None, None
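# Illustrative sketch (not part of the original change): how a caller might consume
# the (tokens, lengths, log_probs) triples that get_token_stream now yields. The
# decode step assumes a HuggingFace-style tokenizer exposing .decode().
def _sketch_consume_token_stream(model, tokenizer, prompt_token_ids):
    last_tokens = None
    for tokens, lengths, log_probs in get_token_stream(model, [prompt_token_ids]):
        if tokens is not None:  # only ranks that hold the output receive tensors
            last_tokens = tokens
    if last_tokens is None:
        return None
    return tokenizer.decode(last_tokens[0].tolist())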
|
||||
|
||||
|
||||
def switch(val1, val2, boolean):
|
||||
@ -452,32 +453,54 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
|
||||
args = get_args()
|
||||
orig_seq_length = args.seq_length
|
||||
args.seq_length = tokens.shape[1]
|
||||
|
||||
input_tensor = recv_forward()
|
||||
config = get_model_config(model)
|
||||
tensor_shapes = (args.seq_length, args.micro_batch_size, args.hidden_size)
|
||||
input_tensor = recv_forward(tensor_shapes, config)
|
||||
|
||||
# Forward pass through the model.
|
||||
unwrapped_model = unwrap_model(
|
||||
model, (torchDDP, LocalDDP, Float16Module))
|
||||
unwrap_classes = (torchDDP, LocalDDP, Float16Module)
|
||||
if is_enable_lora():
|
||||
unwrap_classes += get_lora_model_classes()
|
||||
unwrapped_model = unwrap_model(model, unwrap_classes)
|
||||
|
||||
if hasattr(unwrapped_model, 'set_input_tensor'):
|
||||
unwrapped_model.set_input_tensor(input_tensor)
|
||||
elif args.deepspeed or args.ds_inference:
|
||||
unwrapped_model.module.set_input_tensor(input_tensor)
|
||||
if args.deepspeed or args.ds_inference:
|
||||
unwrapped_model.module.set_input_tensor(input_tensor)
|
||||
else:
|
||||
unwrapped_model.set_input_tensor(input_tensor)
|
||||
|
||||
output_tensor = model(tokens, position_ids, attention_mask,
|
||||
tokentype_ids=tokentype_ids,
|
||||
layer_past=layer_past,
|
||||
get_key_value=get_key_value,
|
||||
forward_method_parallel_output=forward_method_parallel_output)
|
||||
if args.deepspeed and args.ds_pipeline_enabled:
|
||||
output_tensor = model.eval_batch(
|
||||
iter([[(tokens, position_ids, attention_mask), (tokens, tokens)]]),
|
||||
compute_loss=False
|
||||
)
|
||||
else:
|
||||
output_tensor = model(
|
||||
input_ids=tokens,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
tokentype_ids=tokentype_ids,
|
||||
layer_past=layer_past,
|
||||
get_key_value=get_key_value,
|
||||
forward_method_parallel_output=forward_method_parallel_output
|
||||
)
|
||||
|
||||
if isinstance(output_tensor, (list, tuple)):
|
||||
if output_tensor[0] is not None and not get_key_value:
|
||||
output_tensor = output_tensor[0]
|
||||
elif output_tensor[0] is not None and get_key_value:
|
||||
output_tensor = output_tensor[:2]
|
||||
else:
|
||||
raise ValueError("Please make sure that the output of the model is 'Tensor' or '[Tensor, ...]'")
|
||||
|
||||
if get_key_value:
|
||||
output_tensor, layer_past = output_tensor
|
||||
|
||||
send_forward(output_tensor)
|
||||
send_forward(output_tensor, config)
|
||||
|
||||
args.seq_length = orig_seq_length
|
||||
get_accelerator().synchronize()
|
||||
model_latencies.append(time.time()-t0)
|
||||
model_latencies.append(time.time() - t0)
|
||||
if get_key_value:
|
||||
return output_tensor, layer_past
|
||||
return output_tensor
|
||||
@ -495,11 +518,12 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
|
||||
context_length = context_lengths.min().item()
|
||||
|
||||
# added eos_id to support the function generate_samples_eval that passes
|
||||
# eos_id as an argument and needs termination when that id id found.
|
||||
if hasattr(args, 'eos_id'):
|
||||
# eos_id as an argument and needs termination when that id found.
|
||||
|
||||
if hasattr(args, 'eos_id') and args.eos_id is not None:
|
||||
eos_id = args.eos_id
|
||||
else:
|
||||
eos_id = tokenizer.eod
|
||||
eos_id = tokenizer.eos_token_id
|
||||
|
||||
counter = 0
|
||||
org_context_length = context_length
|
||||
@ -514,9 +538,10 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
|
||||
maxlen = org_context_length + args.out_seq_length
|
||||
|
||||
lengths = torch.ones([batch_size]).long().to(get_accelerator().device_name()) * maxlen
|
||||
output_log_probs = None
|
||||
|
||||
while context_length <= (maxlen):
|
||||
if args.recompute:
|
||||
if args.text_generation_config['recompute']:
|
||||
output = forward_step(model, tokens,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
@ -552,15 +577,19 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
|
||||
logits = output[:, -1].view(batch_size, -1).contiguous()
|
||||
|
||||
if parallel_state.is_pipeline_last_stage():
|
||||
vocab_size = torch.Tensor([logits.size(1)]).to(get_accelerator().device_name())
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
|
||||
if args.greedy:
|
||||
prev = torch.argmax(logits, dim=-1).view(-1)
|
||||
else:
|
||||
logits = logits.float()
|
||||
logits /= args.temperature
|
||||
logits = top_k_logits(logits, top_k=args.top_k,
|
||||
top_p=args.top_p)
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
|
||||
logits /= args.text_generation_config["temperature"]
|
||||
logits = top_k_logits(logits,
|
||||
top_k=args.text_generation_config["top_k"],
|
||||
top_p=args.text_generation_config["top_p"])
|
||||
logits = F.softmax(logits, dim=-1)
|
||||
prev = torch.multinomial(logits, num_samples=1).view(-1)
|
||||
|
||||
started = context_lengths <= context_length
|
||||
|
||||
@ -568,8 +597,23 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
|
||||
tokens[:, context_length].view(-1), prev, started)
|
||||
tokens[:, context_length] = new_tokens
|
||||
src = parallel_state.get_pipeline_model_parallel_last_rank()
|
||||
group = parallel_state.get_embedding_group()
|
||||
torch.distributed.broadcast(new_tokens, src, group)
|
||||
group = parallel_state.get_pipeline_model_parallel_group()
|
||||
|
||||
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
|
||||
torch.distributed.broadcast(new_tokens, src, group)
|
||||
|
||||
if args.text_generation_config['return_output_log_probs']:
|
||||
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
|
||||
torch.distributed.broadcast(vocab_size, src, group)
|
||||
torch.distributed.broadcast(log_probs, src, group)
|
||||
|
||||
if counter == 0:
|
||||
log_probs_seq = torch.zeros(
|
||||
(batch_size, maxlen + 1, int(vocab_size))
|
||||
).to(get_accelerator().device_name())
|
||||
|
||||
log_probs_seq[:, context_length, :] = log_probs
|
||||
output_log_probs = log_probs_seq[:, :context_length + 1, :]
|
||||
|
||||
done_token = (prev == eos_id).byte() & started.byte()
|
||||
just_finished = (done_token & ~is_done).bool()
|
||||
@ -580,18 +624,35 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
|
||||
src = parallel_state.get_pipeline_model_parallel_last_rank()
|
||||
group = parallel_state.get_pipeline_model_parallel_group()
|
||||
torch.distributed.broadcast(done, src, group)
|
||||
yield tokens, lengths
|
||||
|
||||
yield tokens, lengths, output_log_probs
|
||||
|
||||
else:
|
||||
if parallel_state.is_pipeline_first_stage():
|
||||
src = parallel_state.get_pipeline_model_parallel_last_rank()
|
||||
group = parallel_state.get_embedding_group()
|
||||
new_tokens = torch.empty_like(tokens[:, context_length])
|
||||
torch.distributed.broadcast(new_tokens, src, group)
|
||||
tokens[:, context_length] = new_tokens
|
||||
yield tokens, None
|
||||
else:
|
||||
yield None, None
|
||||
src = parallel_state.get_pipeline_model_parallel_last_rank()
|
||||
group = parallel_state.get_pipeline_model_parallel_group()
|
||||
|
||||
new_tokens = torch.empty_like(tokens[:, context_length])
|
||||
vocab_size = torch.empty_like(torch.Tensor([0])).to(get_accelerator().device_name())
|
||||
|
||||
torch.distributed.broadcast(new_tokens, src, group)
|
||||
|
||||
tokens[:, context_length] = new_tokens
|
||||
if args.text_generation_config['return_output_log_probs']:
|
||||
torch.distributed.broadcast(vocab_size, src, group)
|
||||
log_probs = torch.empty([batch_size, int(vocab_size)],
|
||||
dtype=torch.float32,
|
||||
device=get_accelerator().device_name())
|
||||
torch.distributed.broadcast(log_probs, src, group)
|
||||
|
||||
if counter == 0:
|
||||
log_probs_seq = torch.zeros(
|
||||
(batch_size, maxlen + 1, int(vocab_size))
|
||||
).to(get_accelerator().device_name())
|
||||
|
||||
log_probs_seq[:, context_length, :] = log_probs
|
||||
output_log_probs = log_probs_seq[:, :context_length + 1, :]
|
||||
|
||||
yield tokens, lengths, output_log_probs
|
||||
|
||||
done = get_accelerator().ByteTensor([0])
|
||||
src = parallel_state.get_pipeline_model_parallel_last_rank()
|
||||
|
@ -1,402 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tokenization classes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
import unicodedata
|
||||
import six
|
||||
|
||||
|
||||
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
|
||||
"""Checks whether the casing config is consistent with the checkpoint name."""
|
||||
|
||||
# The casing has to be passed in by the user and there is no explicit check
|
||||
# as to whether it matches the checkpoint. The casing information probably
|
||||
# should have been stored in the bert_config.json file, but it's not, so
|
||||
# we have to heuristically detect it to validate.
|
||||
|
||||
if not init_checkpoint:
|
||||
return
|
||||
|
||||
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
|
||||
if m is None:
|
||||
return
|
||||
|
||||
model_name = m.group(1)
|
||||
|
||||
lower_models = [
|
||||
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
|
||||
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
cased_models = [
|
||||
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
|
||||
"multi_cased_L-12_H-768_A-12"
|
||||
]
|
||||
|
||||
is_bad_config = False
|
||||
if model_name in lower_models and not do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "False"
|
||||
case_name = "lowercased"
|
||||
opposite_flag = "True"
|
||||
|
||||
if model_name in cased_models and do_lower_case:
|
||||
is_bad_config = True
|
||||
actual_flag = "True"
|
||||
case_name = "cased"
|
||||
opposite_flag = "False"
|
||||
|
||||
if is_bad_config:
|
||||
raise ValueError(
|
||||
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
|
||||
"However, `%s` seems to be a %s model, so you "
|
||||
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
|
||||
"how the model was pre-training. If this error is wrong, please "
|
||||
"just comment out this check." % (actual_flag, init_checkpoint,
|
||||
model_name, case_name, opposite_flag))
|
||||
|
||||
|
||||
def convert_to_unicode(text):
|
||||
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text.decode("utf-8", "ignore")
|
||||
elif isinstance(text, unicode):
|
||||
return text
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def printable_text(text):
|
||||
"""Returns text encoded in a way suitable for print or `tf.logging`."""
|
||||
|
||||
# These functions want `str` for both Python2 and Python3, but in one case
|
||||
# it's a Unicode string and in the other it's a byte string.
|
||||
if six.PY3:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, bytes):
|
||||
return text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
elif six.PY2:
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(text, unicode):
|
||||
return text.encode("utf-8")
|
||||
else:
|
||||
raise ValueError("Unsupported string type: %s" % (type(text)))
|
||||
else:
|
||||
raise ValueError("Not running on Python2 or Python 3?")
|
||||
|
||||
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
vocab = collections.OrderedDict()
|
||||
index = 0
|
||||
with open(vocab_file, "r") as reader:
|
||||
while True:
|
||||
token = convert_to_unicode(reader.readline())
|
||||
if not token:
|
||||
break
|
||||
token = token.strip()
|
||||
vocab[token] = index
|
||||
index += 1
|
||||
return vocab
|
||||
|
||||
|
||||
def convert_by_vocab(vocab, items):
|
||||
"""Converts a sequence of [tokens|ids] using the vocab."""
|
||||
output = []
|
||||
for item in items:
|
||||
output.append(vocab[item])
|
||||
return output
|
||||
|
||||
|
||||
def convert_tokens_to_ids(vocab, tokens):
|
||||
return convert_by_vocab(vocab, tokens)
|
||||
|
||||
|
||||
def convert_ids_to_tokens(inv_vocab, ids):
|
||||
return convert_by_vocab(inv_vocab, ids)
|
||||
|
||||
|
||||
def whitespace_tokenize(text):
|
||||
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
|
||||
class FullTokenizer(object):
|
||||
"""Runs end-to-end tokenziation."""
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True):
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.inv_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
|
||||
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
for token in self.basic_tokenizer.tokenize(text):
|
||||
for sub_token in self.wordpiece_tokenizer.tokenize(token):
|
||||
split_tokens.append(sub_token)
|
||||
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return convert_by_vocab(self.vocab, tokens)
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
return convert_by_vocab(self.inv_vocab, ids)
|
||||
|
||||
def vocab_size(self):
|
||||
return len(self.vocab)
|
||||
|
||||
|
||||
class BasicTokenizer(object):
|
||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
||||
|
||||
def __init__(self, do_lower_case=True):
|
||||
"""Constructs a BasicTokenizer.
|
||||
|
||||
Args:
|
||||
do_lower_case: Whether to lower case the input.
|
||||
"""
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text."""
|
||||
text = convert_to_unicode(text)
|
||||
text = self._clean_text(text)
|
||||
|
||||
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||
# models. This is also applied to the English models now, but it doesn't
|
||||
# matter since the English models were not trained on any Chinese data
|
||||
# and generally don't have any Chinese data in them (there are Chinese
|
||||
# characters in the vocabulary because Wikipedia does have some Chinese
|
||||
# words in the English Wikipedia.).
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
|
||||
orig_tokens = whitespace_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token))
|
||||
|
||||
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||
return output_tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""Strips accents from a piece of text."""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _run_split_on_punc(self, text):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if _is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""Adds whitespace around any CJK character."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self._is_chinese_char(cp):
|
||||
output.append(" ")
|
||||
output.append(char)
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _is_chinese_char(self, cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like all of the other languages.
|
||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
||||
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
||||
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
||||
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
||||
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
||||
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
||||
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
||||
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
class WordpieceTokenizer(object):
|
||||
"""Runs WordPiece tokenziation."""
|
||||
|
||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
|
||||
self.vocab = vocab
|
||||
self.unk_token = unk_token
|
||||
self.max_input_chars_per_word = max_input_chars_per_word
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Tokenizes a piece of text into its word pieces.
|
||||
|
||||
This uses a greedy longest-match-first algorithm to perform tokenization
|
||||
using the given vocabulary.
|
||||
|
||||
For example:
|
||||
input = "unaffable"
|
||||
output = ["un", "##aff", "##able"]
|
||||
|
||||
Args:
|
||||
text: A single token or whitespace separated tokens. This should have
|
||||
already been passed through `BasicTokenizer`.
|
||||
|
||||
Returns:
|
||||
A list of wordpiece tokens.
|
||||
"""
|
||||
|
||||
text = convert_to_unicode(text)
|
||||
|
||||
output_tokens = []
|
||||
for token in whitespace_tokenize(text):
|
||||
chars = list(token)
|
||||
if len(chars) > self.max_input_chars_per_word:
|
||||
output_tokens.append(self.unk_token)
|
||||
continue
|
||||
|
||||
is_bad = False
|
||||
start = 0
|
||||
sub_tokens = []
|
||||
while start < len(chars):
|
||||
end = len(chars)
|
||||
cur_substr = None
|
||||
while start < end:
|
||||
substr = "".join(chars[start:end])
|
||||
if start > 0:
|
||||
substr = "##" + substr
|
||||
if substr in self.vocab:
|
||||
cur_substr = substr
|
||||
break
|
||||
end -= 1
|
||||
if cur_substr is None:
|
||||
is_bad = True
|
||||
break
|
||||
sub_tokens.append(cur_substr)
|
||||
start = end
|
||||
|
||||
if is_bad:
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
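# Illustrative sketch (not part of the original file): the greedy longest-match-first
# behaviour described in the docstring above, exercised with a toy vocabulary.
def _sketch_wordpiece_example():
    toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
    toy_tokenizer = WordpieceTokenizer(vocab=toy_vocab)
    assert toy_tokenizer.tokenize("unaffable") == ["un", "##aff", "##able"]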
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically control characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat in ("Cc", "Cf"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
@ -13,13 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Megatron tokenizers."""
|
||||
"""Megatron tokenizers. just using huggingface implementation."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from transformers import AutoTokenizer
|
||||
from .bert_tokenization import FullTokenizer as FullBertTokenizer
|
||||
from .gpt2_tokenization import GPT2Tokenizer
|
||||
from transformers import AutoTokenizer, GPT2Tokenizer
|
||||
from transformers import BertTokenizer as FullBertTokenizer
|
||||
|
||||
|
||||
def build_tokenizer(args):
|
||||
@ -167,6 +166,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
|
||||
name = 'BERT Upper Case'
|
||||
super().__init__(name)
|
||||
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
|
||||
self.tokenizer.inv_vocab = {v: k for k, v in self.tokenizer.vocab.items()}
|
||||
self.cls_id = self.tokenizer.vocab['[CLS]']
|
||||
self.sep_id = self.tokenizer.vocab['[SEP]']
|
||||
self.pad_id = self.tokenizer.vocab['[PAD]']
|
||||
@ -205,7 +205,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.vocab_size()
|
||||
return self.tokenizer.vocab_size
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
@ -335,6 +335,7 @@ class _AutoTokenizer(AbstractTokenizer):
|
||||
|
||||
hf_tokenizer_kwargs["model_max_length"] = model_max_length
|
||||
hf_tokenizer_kwargs["use_fast"] = use_fast
|
||||
hf_tokenizer_kwargs["trust_remote_code"] = True
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **hf_tokenizer_kwargs)
|
||||
self.encoder = self.tokenizer.get_vocab()
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
@ -380,6 +381,12 @@ class _AutoTokenizer(AbstractTokenizer):
|
||||
@property
|
||||
def pad(self):
|
||||
candidate = self.tokenizer.pad_token_id
|
||||
|
||||
# just use eos_token_id if pad_token_id is not available, it is reasonable
|
||||
# https://github.com/huggingface/transformers/blob/aea761499f4b1193f2706f471442da6f9df65d65/src/transformers/trainer_seq2seq.py#L327
|
||||
# maybe add a new token, and resize embedding layer is better
|
||||
if candidate is None:
|
||||
candidate = self.tokenizer.eos_token_id
|
||||
return self._check_token_candidate(candidate)
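# Illustrative note (not part of the original change): GPT-2-style HuggingFace
# tokenizers ship without a pad token, so tokenizer.pad_token_id is None and the
# fallback above resolves to eos_token_id; that is why generation-side padding can
# safely reuse the EOS id when no dedicated PAD token exists.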
|
||||
|
||||
@property
|
||||
|
@ -13,9 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pretrain utilities."""
|
||||
|
||||
from datetime import datetime
|
||||
import os
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
@ -43,6 +41,7 @@ from ascendspeed.core import parallel_state
|
||||
from ascendspeed.checkpointing import load_checkpoint
|
||||
from ascendspeed.checkpointing import save_checkpoint
|
||||
from ascendspeed.model import Float16Module
|
||||
from ascendspeed.core.enums import ModelType
|
||||
from ascendspeed.optimizer import get_megatron_optimizer
|
||||
from ascendspeed.initialize import initialize_megatron
|
||||
from ascendspeed.initialize import write_args_to_tensorboard
|
||||
@ -52,17 +51,17 @@ from ascendspeed.utils import check_adlr_autoresume_termination
|
||||
from ascendspeed.utils import unwrap_model
|
||||
from ascendspeed.data.data_samplers import build_pretraining_data_loader
|
||||
from ascendspeed.utils import calc_params_l2_norm
|
||||
from ascendspeed.schedules import forward_backward_no_pipelining
|
||||
from ascendspeed.schedules import forward_backward_pipelining_without_interleaving
|
||||
from ascendspeed.schedules import forward_backward_pipelining_with_interleaving
|
||||
from ascendspeed.schedules import optimized_forward_backward_pipelining
|
||||
from ascendspeed.core.utils import get_model_config
|
||||
from ascendspeed.utils import report_memory, throughput_calculator, checkpoint_throughput_calculator
|
||||
from ascendspeed.model.transformer import ParallelTransformerLayer
|
||||
|
||||
|
||||
from ascendspeed.model.transformer import ParallelTransformerLayer
|
||||
from ascendspeed.model.lora_utils import is_enable_lora, handle_model_with_lora
|
||||
from ascendspeed.core.pipeline_parallel.schedules import forward_backward_pipelining_with_foldx_fifo
|
||||
from ascendspeed.core.pipeline_parallel.schedules import forward_backward_pipelining_with_foldx_aiao
|
||||
from ascendspeed.core.pipeline_parallel.schedules import get_forward_backward_func, get_forward_func
|
||||
# The earliest we can measure the start time.
|
||||
_TRAIN_START_TIME = time.time()
|
||||
|
||||
|
||||
def print_datetime(string):
|
||||
"""Note that this call will sync across all ranks."""
|
||||
torch.distributed.barrier()
|
||||
@ -107,11 +106,11 @@ def _initialize_optimized_pipeline():
|
||||
'Check either micro batch sizes or global batch sizes.'
|
||||
|
||||
|
||||
|
||||
|
||||
def pretrain(train_valid_test_dataset_provider,
|
||||
model_provider,
|
||||
model_type,
|
||||
forward_step_func,
|
||||
process_non_loss_data_func=None,
|
||||
extra_args_provider=None,
|
||||
args_defaults={},
|
||||
data_post_process=None):
|
||||
@ -128,11 +127,16 @@ def pretrain(train_valid_test_dataset_provider,
|
||||
train/valid/test dataset and returns `train, valid, test` datasets.
|
||||
model_provider: a function that returns a vanilla version of the
|
||||
model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
|
||||
model_type: an enum that specifies the type of model being trained.
|
||||
forward_step_func: a function that takes a `data iterator` and `model`,
|
||||
and returns a `loss` scalar with a dictionary with key:values being
|
||||
the info we would like to monitor during training, for example
|
||||
`lm-loss: value`. We also require that this function add
|
||||
`batch generator` to the timers class.
|
||||
process_non_loss_data_func: a function to post process outputs of the
|
||||
network. It can be used for dumping output tensors (e.g images) to
|
||||
tensorboard. It takes `collected data`(list of tensors),
|
||||
`current iteration index` and `tensorboard writer` as arguments.
|
||||
extra_args_provider: a function that takes a parser and adds arguments
|
||||
to it. It is used for programs to add their own arguments.
|
||||
args_defaults: a dictionary from argument-name to argument-value. It
|
||||
@ -166,7 +170,7 @@ def pretrain(train_valid_test_dataset_provider,
|
||||
args.deepspeed_configuration = json.load(
|
||||
open(args.deepspeed_config, 'r', encoding='utf-8'))
|
||||
if "curriculum_learning" in args.deepspeed_configuration and \
|
||||
"enabled" in args.deepspeed_configuration["curriculum_learning"]:
|
||||
"enabled" in args.deepspeed_configuration["curriculum_learning"]:
|
||||
args.curriculum_learning_legacy = args.deepspeed_configuration[ \
|
||||
"curriculum_learning"]["enabled"]
|
||||
if args.curriculum_learning_legacy and not args.no_pipeline_parallel:
|
||||
@ -181,13 +185,13 @@ def pretrain(train_valid_test_dataset_provider,
|
||||
# 2、模型并行:定义模型架构,并切割模型
|
||||
timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
|
||||
model, optimizer, lr_scheduler = setup_model_and_optimizer(
|
||||
model_provider, teacher=False, data_post_process=data_post_process,
|
||||
model_provider, model_type, teacher=False, data_post_process=data_post_process,
|
||||
build_train_valid_test_datasets_provider=train_valid_test_dataset_provider)
|
||||
|
||||
timers('model-and-optimizer-setup').stop()
|
||||
print_datetime('after model, optimizer, and learning rate '
|
||||
'scheduler are built')
|
||||
|
||||
config = get_model_config(model[0])
|
||||
# Data stuff.
|
||||
# 3、构造train/val/test数据集
|
||||
timers('train/valid/test-data-iterators-setup', log_level=0).start(barrier=True)
|
||||
@ -197,18 +201,26 @@ def pretrain(train_valid_test_dataset_provider,
|
||||
for _ in range(len(model))
|
||||
]
|
||||
train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
|
||||
if args.foldx_mode is not None:
|
||||
train_data_iterator = [[] for _ in all_data_iterators]
|
||||
if all_data_iterators[0][0] is None:
|
||||
from types import SimpleNamespace
|
||||
train_data_iterator[0] = SimpleNamespace()
|
||||
else:
|
||||
train_data_iterator[0] = all_data_iterators[0][0]
|
||||
train_data_iterator[0].dummy_iterators = train_data_iterator[1:]
|
||||
valid_data_iterator = [[
|
||||
all_data_iterators[i][1][j] for i in range(len(all_data_iterators))]
|
||||
all_data_iterators[i][1][j] for i in range(len(all_data_iterators))]
|
||||
for j in range(len(all_data_iterators[0][1]))
|
||||
]
|
||||
]
|
||||
test_data_iterator = [[
|
||||
all_data_iterators[i][2][j] for i in range(len(all_data_iterators))]
|
||||
all_data_iterators[i][2][j] for i in range(len(all_data_iterators))]
|
||||
for j in range(len(all_data_iterators[0][2]))
|
||||
]
|
||||
else:
|
||||
train_data_iterator, valid_data_iterator, test_data_iterator \
|
||||
= build_train_valid_test_data_iterators(
|
||||
train_valid_test_dataset_provider)
|
||||
train_valid_test_dataset_provider)
|
||||
if args.data_efficiency_curriculum_learning:
|
||||
if args.deepspeed_dataloader is not None:
|
||||
# We use args to pass the deepspeed_dataloader because adding
|
||||
@ -229,7 +241,7 @@ def pretrain(train_valid_test_dataset_provider,
|
||||
# line to use kd, but users do need to provide teacher model configurations
|
||||
# like args.num_layers_teacher as described in setup_teacher_model()
|
||||
args.teacher_model = None
|
||||
if args.mos or args.kd: # Set up teacher model
|
||||
if args.mos or args.kd: # Set up teacher model
|
||||
args.teacher_model = setup_teacher_model(args, model_provider)
|
||||
|
||||
# Print setup timing.
|
||||
@ -241,16 +253,16 @@ def pretrain(train_valid_test_dataset_provider,
|
||||
if args.do_train and args.train_iters > 0:
|
||||
iteration = train(forward_step_func,
|
||||
model, optimizer, lr_scheduler,
|
||||
train_data_iterator, valid_data_iterator)
|
||||
train_data_iterator, valid_data_iterator, config)
|
||||
print_datetime('after training is done')
|
||||
|
||||
if args.do_valid:
|
||||
prefix = 'the end of training for val data'
|
||||
for iterator in valid_data_iterator:
|
||||
evaluate_and_print_results(prefix, forward_step_func,
|
||||
iterator, model,
|
||||
iteration, False)
|
||||
|
||||
iterator, model,
|
||||
iteration, False)
|
||||
|
||||
# Clean the model and do evaluation again
|
||||
if args.compression_training:
|
||||
model = [redundancy_clean(model[0], args.deepspeed_config, mpu)]
|
||||
@ -258,9 +270,8 @@ def pretrain(train_valid_test_dataset_provider,
|
||||
prefix = 'the end of training and after model cleaning for val data'
|
||||
for iterator in valid_data_iterator:
|
||||
evaluate_and_print_results(prefix, forward_step_func,
|
||||
iterator, model,
|
||||
iteration, False)
|
||||
|
||||
iterator, model,
|
||||
iteration, False)
|
||||
|
||||
if args.save and iteration != 0:
|
||||
save_checkpoint(iteration, model, optimizer, lr_scheduler)
|
||||
@ -270,11 +281,11 @@ def pretrain(train_valid_test_dataset_provider,
|
||||
prefix = 'the end of training for test data'
|
||||
for iterator in test_data_iterator:
|
||||
evaluate_and_print_results(prefix, forward_step_func,
|
||||
iterator, model,
|
||||
0, True)
|
||||
iterator, model,
|
||||
0, True)
|
||||
|
||||
|
||||
def update_train_iters(args):
|
||||
|
||||
# For iteration-based training, we don't need to do anything
|
||||
if args.train_iters:
|
||||
return
|
||||
@ -303,8 +314,7 @@ def update_train_iters(args):
|
||||
print_rank_0('setting training iterations to {}'.format(args.train_iters))
|
||||
|
||||
|
||||
def setup_teacher_model(args, model_provider):
|
||||
|
||||
def setup_teacher_model(args, model_provider):
|
||||
print_rank_0('***>>>>> Student model checkpoint iteration:{}'.format(args.iteration))
|
||||
iteration_stuent = args.iteration
|
||||
num_layers_student = args.num_layers
|
||||
@ -332,13 +342,16 @@ def setup_teacher_model(args, model_provider):
|
||||
|
||||
return teacher_model
|
||||
|
||||
def get_model(model_provider_func):
|
||||
|
||||
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
|
||||
"""Build the model."""
|
||||
args = get_args()
|
||||
|
||||
args.model_type = model_type
|
||||
# Build model.
|
||||
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and \
|
||||
args.virtual_pipeline_model_parallel_size is not None:
|
||||
args.virtual_pipeline_model_parallel_size is not None:
|
||||
assert model_type != ModelType.encoder_and_decoder, \
|
||||
"Interleaved schedule not supported for model with both encoder and decoder"
|
||||
model = []
|
||||
for i in range(args.virtual_pipeline_model_parallel_size):
|
||||
parallel_state.set_virtual_pipeline_model_parallel_rank(i)
|
||||
@ -349,14 +362,37 @@ def get_model(model_provider_func):
|
||||
pre_process=pre_process,
|
||||
post_process=post_process
|
||||
)
|
||||
this_model.model_type = model_type
|
||||
|
||||
model.append(this_model)
|
||||
else:
|
||||
pre_process = parallel_state.is_pipeline_first_stage()
|
||||
post_process = parallel_state.is_pipeline_last_stage()
|
||||
model = model_provider_func(
|
||||
pre_process=pre_process,
|
||||
post_process=post_process
|
||||
)
|
||||
add_encoder = True
|
||||
add_decoder = True
|
||||
if model_type == ModelType.encoder_and_decoder:
|
||||
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
|
||||
assert args.pipeline_model_parallel_split_rank is not None, \
|
||||
"Split rank needs to be specified for model with both encoder and decoder"
|
||||
rank = parallel_state.get_pipeline_model_parallel_rank()
|
||||
split_rank = args.pipeline_model_parallel_split_rank
|
||||
world_size = parallel_state.get_pipeline_model_parallel_world_size()
|
||||
pre_process = rank == 0 or rank == split_rank
|
||||
post_process = (rank == (split_rank - 1)) or (
|
||||
rank == (world_size - 1))
|
||||
add_encoder = parallel_state.is_pipeline_stage_before_split()
|
||||
add_decoder = parallel_state.is_pipeline_stage_after_split()
|
||||
model = model_provider_func(
|
||||
pre_process=pre_process,
|
||||
post_process=post_process,
|
||||
add_encoder=add_encoder,
|
||||
add_decoder=add_decoder)
|
||||
else:
|
||||
model = model_provider_func(
|
||||
pre_process=pre_process,
|
||||
post_process=post_process
|
||||
)
|
||||
model.model_type = model_type
|
||||
|
||||
if not isinstance(model, list):
|
||||
model = [model]
|
||||
@ -375,38 +411,48 @@ def get_model(model_provider_func):
|
||||
'model parallel rank ({}, {}): {}'.format(
|
||||
parallel_state.get_tensor_model_parallel_rank(),
|
||||
parallel_state.get_pipeline_model_parallel_rank(),
|
||||
sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
|
||||
sum([sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement() for p in model_module.parameters()])
|
||||
for model_module in model])), flush=True)
|
||||
|
||||
if args.deepspeed:
|
||||
return model
|
||||
if not args.deepspeed:
|
||||
# GPU allocation.
|
||||
for model_module in model:
|
||||
device_name = get_accelerator().current_device_name()
|
||||
print_rank_0(f"model to {device_name}")
|
||||
model_module.to(device_name)
|
||||
|
||||
# GPU allocation.
|
||||
for model_module in model:
|
||||
device_name = get_accelerator().current_device_name()
|
||||
print_rank_0(f"model to {device_name}")
|
||||
model_module.to(device_name)
|
||||
model = wrap_model(model, wrap_with_ddp=wrap_with_ddp)
|
||||
|
||||
if is_enable_lora():
|
||||
model = handle_model_with_lora(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def wrap_model(model, wrap_with_ddp=True):
|
||||
args = get_args()
|
||||
# Fp16 conversion.
|
||||
if args.fp16 or args.bf16:
|
||||
model = [Float16Module(model_module, args) for model_module in model]
|
||||
if wrap_with_ddp:
|
||||
if args.DDP_impl == 'torch':
|
||||
i = get_accelerator().current_device()
|
||||
model = [torchDDP(model_module, device_ids=[i], output_device=i,
|
||||
process_group=parallel_state.get_data_parallel_group())
|
||||
for model_module in model]
|
||||
return model
|
||||
|
||||
if args.DDP_impl == 'torch':
|
||||
i = get_accelerator().current_device()
|
||||
model = [torchDDP(model_module, device_ids=[i], output_device=i,
|
||||
process_group=parallel_state.get_data_parallel_group())
|
||||
for model_module in model]
|
||||
return model
|
||||
elif args.DDP_impl == 'local':
|
||||
model = [LocalDDP(model_module,
|
||||
args.accumulate_allreduce_grads_in_fp32,
|
||||
args.use_contiguous_buffers_in_local_ddp)
|
||||
for model_module in model]
|
||||
return model
|
||||
else:
|
||||
raise NotImplementedError('Unknown DDP implementation specified: {}. '
|
||||
'Exiting.'.format(args.DDP_impl))
|
||||
|
||||
if args.DDP_impl == 'local':
|
||||
model = [LocalDDP(model_module,
|
||||
args.accumulate_allreduce_grads_in_fp32,
|
||||
args.use_contiguous_buffers_in_ddp)
|
||||
for model_module in model]
|
||||
return model
|
||||
|
||||
raise NotImplementedError('Unknown DDP implementation specified: {}. '
|
||||
'Exiting.'.format(args.DDP_impl))
|
||||
return model
|
||||
|
||||
|
||||
def get_learning_rate_scheduler(optimizer):
|
||||
@ -451,6 +497,7 @@ def get_learning_rate_scheduler(optimizer):
|
||||
|
||||
return lr_scheduler
|
||||
|
||||
|
||||
def load_model_weights_only(model_provider_func):
|
||||
"""Setup model and optimizer."""
|
||||
args = get_args()
|
||||
@ -486,23 +533,26 @@ def load_model_weights_only(model_provider_func):
|
||||
|
||||
return model, optimizer, lr_scheduler
|
||||
|
||||
|
||||
def setup_model_and_optimizer(model_provider_func,
|
||||
model_type,
|
||||
no_wd_decay_cond=None,
|
||||
scale_lr_cond=None,
|
||||
lr_mult=1.0,
|
||||
teacher=False,
|
||||
data_post_process=None,
|
||||
build_train_valid_test_datasets_provider=None):
|
||||
"""Setup model and optimizer."""
|
||||
args = get_args()
|
||||
|
||||
model = get_model(model_provider_func)
|
||||
|
||||
model = get_model(model_provider_func, model_type)
|
||||
# initialize the compression here
|
||||
student_global_steps = 0
|
||||
if args.kd or args.mos:
|
||||
model, _, _, _ = deepspeed.initialize(
|
||||
model=model[0],
|
||||
args=args,
|
||||
mpu=mpu if args.no_pipeline_parallel else None
|
||||
)
|
||||
model=model[0],
|
||||
args=args,
|
||||
mpu=parallel_state if args.no_pipeline_parallel else None
|
||||
)
|
||||
model = [model]
|
||||
if args.load is not None:
|
||||
args.iteration = load_checkpoint(model, None, None, strict=False)
|
||||
@ -511,16 +561,14 @@ def setup_model_and_optimizer(model_provider_func,
|
||||
student_global_steps = model[0].global_steps
|
||||
print_rank_0('***>>>>> Student model, global step:{}'.format(student_global_steps))
|
||||
|
||||
|
||||
if args.compression_training:
|
||||
model, _, _, _ = deepspeed.initialize(
|
||||
model=model[0],
|
||||
args=args,
|
||||
mpu=mpu if args.no_pipeline_parallel else None
|
||||
mpu=parallel_state if args.no_pipeline_parallel else None
|
||||
)
|
||||
model = [model]
|
||||
model = [init_compression(model[0].module, args.deepspeed_config, mpu)]
|
||||
|
||||
|
||||
unwrapped_model = unwrap_model(model,
|
||||
(torchDDP, LocalDDP, Float16Module))
|
||||
@ -530,12 +578,11 @@ def setup_model_and_optimizer(model_provider_func,
|
||||
lr_scheduler = None
|
||||
else:
|
||||
if teacher:
|
||||
optimizer = None
|
||||
optimizer = None
|
||||
else:
|
||||
optimizer = get_megatron_optimizer(unwrapped_model)
|
||||
optimizer = get_megatron_optimizer(model)
|
||||
lr_scheduler = get_learning_rate_scheduler(optimizer)
|
||||
|
||||
|
||||
if args.deepspeed:
|
||||
print_rank_0("DeepSpeed is enabled.")
|
||||
pp = parallel_state.get_pipeline_model_parallel_world_size()
|
||||
@ -556,11 +603,11 @@ def setup_model_and_optimizer(model_provider_func,
|
||||
# baseline's logic to build eval/test dataset later in
|
||||
# build_train_valid_test_data_iterators.
|
||||
eval_iters = (args.train_iters // args.eval_interval + 1) * \
|
||||
args.eval_iters
|
||||
args.eval_iters
|
||||
test_iters = args.eval_iters
|
||||
train_val_test_num_samples = [train_samples,
|
||||
eval_iters * args.global_batch_size,
|
||||
test_iters * args.global_batch_size]
|
||||
eval_iters * args.global_batch_size,
|
||||
test_iters * args.global_batch_size]
|
||||
# Build the datasets.
|
||||
train_ds, _, _ = build_train_valid_test_datasets_provider(
|
||||
train_val_test_num_samples)
|
||||
@ -591,6 +638,8 @@ def setup_model_and_optimizer(model_provider_func,
|
||||
assert model.grid.get_data_parallel_rank() == parallel_state.get_data_parallel_rank()
|
||||
model = [model]
|
||||
|
||||
handle_model_with_checkpoint(lr_scheduler, model, optimizer, student_global_steps)
|
||||
|
||||
# Compression has its own checkpoint loading path (e.g, loading both teacher and student models). So if compression is enabled, we skip the following checkpoint loading.
|
||||
no_post_init_checkpoint_loading = args.kd or args.mos
|
||||
if not no_post_init_checkpoint_loading:
|
||||
@ -601,6 +650,10 @@ def setup_model_and_optimizer(model_provider_func,
|
||||
torch.distributed.barrier()
|
||||
timers('load-checkpoint', log_level=0).start(barrier=True)
|
||||
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
|
||||
if args.deepspeed:
|
||||
optimizer.refresh_fp32_params()
|
||||
else:
|
||||
optimizer.reload_model_params()
|
||||
torch.distributed.barrier()
|
||||
timers('load-checkpoint').stop(barrier=True)
|
||||
timers.log(['load-checkpoint'])
|
||||
@ -615,7 +668,7 @@ def setup_model_and_optimizer(model_provider_func,
|
||||
|
||||
# get model without FP16 and/or TorchDDP wrappers
|
||||
if args.iteration == 0 and len(unwrapped_model) == 1 \
|
||||
and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
|
||||
and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
|
||||
print_rank_0("Initializing ICT from pretrained BERT model")
|
||||
unwrapped_model[0].init_state_dict_from_bert()
|
||||
if args.fp16:
|
||||
@ -628,8 +681,31 @@ def setup_model_and_optimizer(model_provider_func,
|
||||
return model, optimizer, lr_scheduler
|
||||
|
||||
|
||||
def handle_model_with_checkpoint(lr_scheduler, model, optimizer, student_global_steps=0):
|
||||
args = get_args()
|
||||
# Compression has its own checkpoint loading path (e.g, loading both teacher and student models). So if compression is enabled, we skip the following checkpoint loading.
|
||||
no_post_init_checkpoint_loading = args.kd or args.mos
|
||||
if not no_post_init_checkpoint_loading:
|
||||
print_rank_0(f"\tsetup_model_and_optimizer : no_post_init_checkpoint_loading:{no_post_init_checkpoint_loading}")
|
||||
if args.load is not None:
|
||||
print_rank_0(f"\tsetup_model_and_optimizer : args.load:{args.load}")
|
||||
timers = get_timers()
|
||||
# Extra barrier is added to make sure all ranks report the
|
||||
# max time.
|
||||
torch.distributed.barrier()
|
||||
timers('load-checkpoint', log_level=0).start(barrier=True)
|
||||
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
|
||||
torch.distributed.barrier()
|
||||
timers('load-checkpoint').stop(barrier=True)
|
||||
timers.log(['load-checkpoint'])
|
||||
else:
|
||||
args.iteration = 0
|
||||
else:
|
||||
model[0].global_steps = student_global_steps
|
||||
|
||||
|
||||
def train_step(forward_step_func, data_iterator,
|
||||
model, optimizer, lr_scheduler):
|
||||
model, optimizer, lr_scheduler, config):
|
||||
"""Single training step."""
|
||||
args = get_args()
|
||||
timers = get_timers()
|
||||
@ -640,11 +716,11 @@ def train_step(forward_step_func, data_iterator,
|
||||
assert isinstance(model[0], deepspeed.PipelineEngine)
|
||||
loss = model[0].train_batch(data_iter=data_iterator)
|
||||
grad_norm = model[0].get_global_grad_norm()
|
||||
return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad
|
||||
return {'lm loss': loss}, skipped_iter, grad_norm, num_zeros_in_grad
|
||||
|
||||
# Set grad to zero.
|
||||
if not args.deepspeed:
|
||||
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
|
||||
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
|
||||
for partition in model:
|
||||
partition.zero_grad_buffer()
|
||||
else:
|
||||
@ -652,37 +728,50 @@ def train_step(forward_step_func, data_iterator,
|
||||
|
||||
timers('forward-backward', log_level=1).start(
|
||||
barrier=args.barrier_with_L1_time)
|
||||
forward_backward_func = get_forward_backward_func()
|
||||
|
||||
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
|
||||
if args.virtual_pipeline_model_parallel_size is not None:
|
||||
forward_backward_func = forward_backward_pipelining_with_interleaving
|
||||
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
|
||||
'number of microbatches is not divisible by pipeline-parallel ' \
|
||||
'size when using interleaved schedule'
|
||||
elif args.optimized_pipeline:
|
||||
forward_backward_func = optimized_forward_backward_pipelining
|
||||
else:
|
||||
forward_backward_func = forward_backward_pipelining_without_interleaving
|
||||
else:
|
||||
forward_backward_func = forward_backward_no_pipelining
|
||||
if args.mos or args.kd:
|
||||
# args.teacher_forward is used as global variable to enable kd loss
|
||||
# calculation in forward pass. Users do not need to set it in the
|
||||
# command line to use kd.
|
||||
args.teacher_forward = True
|
||||
losses_reduced = forward_backward_func(
|
||||
forward_step_func, data_iterator, model,
|
||||
optimizer, timers, forward_only=False)
|
||||
if forward_backward_func == forward_backward_pipelining_with_foldx_fifo or\
|
||||
forward_backward_func == forward_backward_pipelining_with_foldx_aiao:
|
||||
losses_reduced = forward_backward_func(
|
||||
forward_step_func=forward_step_func,
|
||||
data_iterator=data_iterator,
|
||||
model=model,
|
||||
num_microbatches=get_num_microbatches(),
|
||||
seq_length=args.seq_length,
|
||||
micro_batch_size=args.micro_batch_size,
|
||||
decoder_seq_length=args.decoder_seq_length)
|
||||
else:
|
||||
losses_reduced = forward_backward_func(
|
||||
forward_step_func=forward_step_func,
|
||||
data_iterator=data_iterator,
|
||||
model=model,
|
||||
num_microbatches=get_num_microbatches(),
|
||||
seq_length=args.seq_length,
|
||||
micro_batch_size=args.micro_batch_size,
|
||||
decoder_seq_length=args.decoder_seq_length,
|
||||
forward_only=False)
|
||||
if args.mos or args.kd:
|
||||
args.teacher_forward = False
|
||||
|
||||
# reset timers if necessary
|
||||
if config.timers is None:
|
||||
config.timers = timers
|
||||
timers('forward-backward').stop()
|
||||
|
||||
# All-reduce if needed.
|
||||
if not args.deepspeed and args.DDP_impl == 'local':
|
||||
timers('backward-params-all-reduce', log_level=1).start(barrier=args.barrier_with_L1_time)
|
||||
for model_module in model:
|
||||
model_module.allreduce_gradients()
|
||||
if args.foldx_mode is not None:
|
||||
handles = model[0].allreduce_gradients(async_op=True)
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
else:
|
||||
for model_module in model:
|
||||
model_module.allreduce_gradients()
|
||||
timers('backward-params-all-reduce').stop()
|
||||
|
||||
# All-reduce word_embeddings' grad across first and last stages to ensure
|
||||
@ -691,23 +780,7 @@ def train_step(forward_step_func, data_iterator,
|
||||
# (BERT and GPT-2).
|
||||
timers('backward-embedding-all-reduce', log_level=1).start(barrier=args.barrier_with_L1_time)
|
||||
if not args.deepspeed:
|
||||
if (parallel_state.is_pipeline_first_stage(ignore_virtual=True) or
|
||||
parallel_state.is_pipeline_last_stage(ignore_virtual=True)) and \
|
||||
parallel_state.get_pipeline_model_parallel_world_size() > 1:
|
||||
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
|
||||
unwrapped_model = model[0]
|
||||
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
|
||||
unwrapped_model = model[-1]
|
||||
unwrapped_model = unwrap_model(
|
||||
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
|
||||
|
||||
if unwrapped_model.share_word_embeddings:
|
||||
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
|
||||
if args.DDP_impl == 'local':
|
||||
grad = word_embeddings_weight.main_grad
|
||||
else:
|
||||
grad = word_embeddings_weight.grad
|
||||
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
|
||||
optimizer.reduce_model_grads(args, timers)
|
||||
timers('backward-embedding-all-reduce').stop()
|
||||
|
||||
# Update parameters.
|
||||
@ -719,7 +792,9 @@ def train_step(forward_step_func, data_iterator,
|
||||
model[0].step(lr_kwargs={'increment': increment})
|
||||
update_successful = model[0].was_step_applied()
|
||||
else:
|
||||
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
|
||||
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
|
||||
if update_successful:
|
||||
optimizer.gather_model_params(args, timers)
|
||||
timers('optimizer').stop()
|
||||
|
||||
# Update learning rate.
|
||||
@ -727,7 +802,7 @@ def train_step(forward_step_func, data_iterator,
|
||||
skipped_iter = 0
|
||||
grad_norm = None
|
||||
num_zeros_in_grad = None
|
||||
|
||||
|
||||
loss_reduced = {}
|
||||
for key in losses_reduced[0]:
|
||||
losses_reduced_for_key = [x[key] for x in losses_reduced]
|
||||
@ -797,6 +872,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
|
||||
def add_to_logging(name):
|
||||
if name in timers._timers:
|
||||
timers_to_log.append(name)
|
||||
|
||||
add_to_logging('forward-compute')
|
||||
add_to_logging('forward-recv')
|
||||
add_to_logging('forward-send')
|
||||
@ -818,13 +894,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
|
||||
|
||||
# Calculate batch size.
|
||||
batch_size = args.micro_batch_size * args.data_parallel_size * \
|
||||
get_num_microbatches()
|
||||
get_num_microbatches()
|
||||
total_iterations = total_loss_dict[advanced_iters_key] + \
|
||||
total_loss_dict[skipped_iters_key]
|
||||
|
||||
# Tensorboard values.
|
||||
if writer and (iteration % args.tensorboard_log_interval == 0) and \
|
||||
is_last_rank():
|
||||
is_last_rank():
|
||||
writer.add_scalar('steps-vs-samples/y=steps,x=samples', iteration, args.consumed_train_samples)
|
||||
writer.add_scalar('steps-vs-samples/y=samples,x=steps', args.consumed_train_samples, iteration)
|
||||
writer.add_scalar('steps-vs-tokens/y=steps,x=tokens', iteration, args.consumed_train_tokens)
|
||||
@ -902,19 +978,19 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
|
||||
opt_stats_2 = [0.0] * 4
|
||||
for _, group in enumerate(optimizer.param_groups):
|
||||
for _, param in enumerate(group['params']):
|
||||
opt_stats[0] += (torch.norm(optimizer.state[param]['exp_avg_sq']).item())**2
|
||||
opt_stats[1] += (torch.norm(optimizer.state[param]['exp_avg_sq'].sqrt()).item())**2
|
||||
opt_stats[2] += (torch.norm(optimizer.state[param]['exp_avg']).item())**2
|
||||
opt_stats[3] += (torch.norm(param).item())**2
|
||||
opt_stats[4] += torch.norm(optimizer.state[param]['exp_avg_sq'],p=1).item()
|
||||
opt_stats[5] += torch.norm(optimizer.state[param]['exp_avg_sq'].sqrt(),p=1).item()
|
||||
opt_stats[6] += torch.norm(optimizer.state[param]['exp_avg'],p=1).item()
|
||||
opt_stats[7] += torch.norm(param,p=1).item()
|
||||
opt_stats_2[0] = max(opt_stats_2[0], abs(optimizer.state[param]['exp_avg_sq'].max().item()),
|
||||
opt_stats[0] += (torch.norm(optimizer.state[param]['exp_avg_sq']).item()) ** 2
|
||||
opt_stats[1] += (torch.norm(optimizer.state[param]['exp_avg_sq'].sqrt()).item()) ** 2
|
||||
opt_stats[2] += (torch.norm(optimizer.state[param]['exp_avg']).item()) ** 2
|
||||
opt_stats[3] += (torch.norm(param).item()) ** 2
|
||||
opt_stats[4] += torch.norm(optimizer.state[param]['exp_avg_sq'], p=1).item()
|
||||
opt_stats[5] += torch.norm(optimizer.state[param]['exp_avg_sq'].sqrt(), p=1).item()
|
||||
opt_stats[6] += torch.norm(optimizer.state[param]['exp_avg'], p=1).item()
|
||||
opt_stats[7] += torch.norm(param, p=1).item()
|
||||
opt_stats_2[0] = max(opt_stats_2[0], abs(optimizer.state[param]['exp_avg_sq'].max().item()),
|
||||
abs(optimizer.state[param]['exp_avg_sq'].min().item()))
|
||||
opt_stats_2[1] = max(opt_stats_2[1], optimizer.state[param]['exp_avg_sq']
|
||||
.sqrt().abs_().max().item())
|
||||
opt_stats_2[2] = max(opt_stats_2[2], abs(optimizer.state[param]['exp_avg'].max().item()),
|
||||
opt_stats_2[2] = max(opt_stats_2[2], abs(optimizer.state[param]['exp_avg'].max().item()),
|
||||
abs(optimizer.state[param]['exp_avg'].min().item()))
|
||||
opt_stats_2[3] = max(opt_stats_2[3], abs(param.max().item()), abs(param.min().item()))
|
||||
|
||||
@ -924,41 +1000,43 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
|
||||
torch.distributed.all_reduce(opt_stats, group=parallel_state.get_data_parallel_group())
|
||||
opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2)
|
||||
torch.distributed.all_reduce(opt_stats_2, op=torch.distributed.ReduceOp.MAX,
|
||||
group=parallel_state.get_data_parallel_group())
|
||||
group=parallel_state.get_data_parallel_group())
|
||||
|
||||
if args.tensor_model_parallel_size > 1:
|
||||
opt_stats = get_accelerator().FloatTensor(opt_stats)
|
||||
torch.distributed.all_reduce(opt_stats, group=parallel_state.get_tensor_model_parallel_group())
|
||||
opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2)
|
||||
torch.distributed.all_reduce(opt_stats_2, op=torch.distributed.ReduceOp.MAX,
|
||||
group=parallel_state.get_tensor_model_parallel_group())
|
||||
group=parallel_state.get_tensor_model_parallel_group())
|
||||
|
||||
if args.pipeline_model_parallel_size > 1:
|
||||
opt_stats = get_accelerator().FloatTensor(opt_stats)
|
||||
torch.distributed.all_reduce(opt_stats, group=parallel_state.get_pipeline_model_parallel_group())
|
||||
opt_stats_2 = get_accelerator().FloatTensor(opt_stats_2)
|
||||
torch.distributed.all_reduce(opt_stats_2, op=torch.distributed.ReduceOp.MAX,
|
||||
group=parallel_state.get_pipeline_model_parallel_group())
|
||||
group=parallel_state.get_pipeline_model_parallel_group())
|
||||
|
||||
# print('step {} rank {} after sync opt_stats {}, {}'.format(iteration, torch.distributed.get_rank(), opt_stats_2, opt_stats))
|
||||
if writer and is_last_rank():
|
||||
writer.add_scalar('optimizer/variance_l2 vs tokens', opt_stats[0]**0.5, args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/variance_sqrt_l2 vs tokens', opt_stats[1]**0.5, args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/momentum_l2 vs tokens', opt_stats[2]**0.5, args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/weight_l2 vs tokens', opt_stats[3]**0.5, args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/variance_l2 vs tokens', opt_stats[0] ** 0.5, args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/variance_sqrt_l2 vs tokens', opt_stats[1] ** 0.5,
|
||||
args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/momentum_l2 vs tokens', opt_stats[2] ** 0.5, args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/weight_l2 vs tokens', opt_stats[3] ** 0.5, args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/variance_l1 vs tokens', opt_stats[4], args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/variance_sqrt_l1 vs tokens', opt_stats[5], args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/momentum_l1 vs tokens', opt_stats[6], args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/weight_l1 vs tokens', opt_stats[7], args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/variance_abs_max vs tokens', opt_stats_2[0], args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/variance_sqrt_abs_max vs tokens', opt_stats_2[1], args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/variance_sqrt_abs_max vs tokens', opt_stats_2[1],
|
||||
args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/momentum_abs_max vs tokens', opt_stats_2[2], args.consumed_train_tokens)
|
||||
writer.add_scalar('optimizer/weight_abs_max vs tokens', opt_stats_2[3], args.consumed_train_tokens)
|
||||
|
||||
writer.add_scalar('optimizer/variance_l2', opt_stats[0]**0.5, iteration)
|
||||
writer.add_scalar('optimizer/variance_sqrt_l2', opt_stats[1]**0.5, iteration)
|
||||
writer.add_scalar('optimizer/momentum_l2', opt_stats[2]**0.5, iteration)
|
||||
writer.add_scalar('optimizer/weight_l2', opt_stats[3]**0.5, iteration)
|
||||
writer.add_scalar('optimizer/variance_l2', opt_stats[0] ** 0.5, iteration)
|
||||
writer.add_scalar('optimizer/variance_sqrt_l2', opt_stats[1] ** 0.5, iteration)
|
||||
writer.add_scalar('optimizer/momentum_l2', opt_stats[2] ** 0.5, iteration)
|
||||
writer.add_scalar('optimizer/weight_l2', opt_stats[3] ** 0.5, iteration)
|
||||
writer.add_scalar('optimizer/variance_l1', opt_stats[4], iteration)
|
||||
writer.add_scalar('optimizer/variance_sqrt_l1', opt_stats[5], iteration)
|
||||
writer.add_scalar('optimizer/momentum_l1', opt_stats[6], iteration)
|
||||
@ -978,7 +1056,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
|
||||
num_layers = args.num_layers
|
||||
vocab_size = args.padded_vocab_size
|
||||
|
||||
samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(model, args, elapsed_time, total_iterations)
|
||||
samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(model, args, elapsed_time,
|
||||
total_iterations)
|
||||
|
||||
# Compute throughput.
|
||||
samples_per_sec_per_replica = samples_per_sec / args.data_parallel_size
|
||||
@ -1041,7 +1120,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
|
||||
report_memory_flag = False
|
||||
timers.log(timers_to_log, normalizer=args.log_interval)
|
||||
|
||||
|
||||
return report_memory_flag
|
||||
|
||||
|
||||
@ -1059,7 +1137,7 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
|
||||
|
||||
|
||||
def train(forward_step_func, model, optimizer, lr_scheduler,
|
||||
train_data_iterator, valid_data_iterator):
|
||||
train_data_iterator, valid_data_iterator, config):
|
||||
"""Train the model function."""
|
||||
args = get_args()
|
||||
timers = get_timers()
|
||||
@ -1082,15 +1160,20 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
|
||||
# Iterations.
|
||||
iteration = args.iteration
|
||||
|
||||
timers('interval-time', log_level=0).start()
|
||||
# Translate args to core configuration
|
||||
if not args.deepspeed:
|
||||
config.grad_scale_func = optimizer.scale_loss
|
||||
config.timers = timers
|
||||
|
||||
timers('interval-time', log_level=0).start(barrier=True)
|
||||
print_datetime('before the start of training step')
|
||||
report_memory_flag = True
|
||||
if args.random_ltd:
|
||||
assert model[0].random_ltd_enabled()
|
||||
args.random_ltd_layer_num = model[0].random_ltd_scheduler.get_random_ltd_layer_num()
|
||||
|
||||
|
||||
while iteration < args.train_iters and (args.train_tokens is None or \
|
||||
args.consumed_train_tokens < args.train_tokens):
|
||||
args.consumed_train_tokens < args.train_tokens):
|
||||
update_num_microbatches(args.consumed_train_samples)
|
||||
if args.deepspeed:
|
||||
# inform deepspeed of any batch size changes
|
||||
@ -1101,18 +1184,18 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
|
||||
|
||||
if args.curriculum_learning_legacy and not args.no_pipeline_parallel:
|
||||
args.curriculum_seqlen = args.curriculum_scheduler.update_difficulty( \
|
||||
args.iteration + 1)
|
||||
args.iteration + 1)
|
||||
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
|
||||
train_step(forward_step_func,
|
||||
train_data_iterator,
|
||||
model,
|
||||
optimizer,
|
||||
lr_scheduler)
|
||||
lr_scheduler, config)
|
||||
iteration += 1
|
||||
args.iteration = iteration
|
||||
new_samples = parallel_state.get_data_parallel_world_size() * \
|
||||
args.micro_batch_size * \
|
||||
get_num_microbatches()
|
||||
args.micro_batch_size * \
|
||||
get_num_microbatches()
|
||||
args.consumed_train_samples += new_samples
|
||||
# This actual_seq_length is used for actual consumed tokens calculation, flops calculation, and logging.
|
||||
args.actual_seq_length = args.seq_length
|
||||
@ -1121,18 +1204,20 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
|
||||
if args.random_ltd:
|
||||
args.random_ltd_reserved_length = model[0].random_ltd_scheduler.get_current_seq()
|
||||
if args.random_ltd_reserved_length < args.actual_seq_length:
|
||||
args.actual_seq_length = (args.actual_seq_length * (args.num_layers - args.random_ltd_layer_num) + args.random_ltd_reserved_length * args.random_ltd_layer_num) // args.num_layers
|
||||
args.actual_seq_length = (args.actual_seq_length * (
|
||||
args.num_layers - args.random_ltd_layer_num)
|
||||
+ args.random_ltd_reserved_length * args.random_ltd_layer_num) // args.num_layers
|
||||
if args.curriculum_learning_legacy or args.data_efficiency_curriculum_learning:
|
||||
if hasattr(args, 'data_efficiency_curriculum_learning_numel'):
|
||||
act_mbsz = args.data_efficiency_curriculum_learning_numel / args.curriculum_seqlen
|
||||
act_token = act_mbsz * args.actual_seq_length
|
||||
args.consumed_train_tokens += parallel_state.get_data_parallel_world_size() * \
|
||||
get_num_microbatches() * act_token
|
||||
get_num_microbatches() * act_token
|
||||
else:
|
||||
args.consumed_train_tokens += new_samples * args.actual_seq_length
|
||||
else:
|
||||
args.consumed_train_tokens += new_samples * args.actual_seq_length
|
||||
|
||||
|
||||
# Logging.
|
||||
if args.deepspeed:
|
||||
if hasattr(model[0].optimizer, 'cur_scale'):
|
||||
@ -1153,23 +1238,23 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
|
||||
|
||||
# Autoresume
|
||||
if args.adlr_autoresume and \
|
||||
(iteration % args.adlr_autoresume_interval == 0):
|
||||
(iteration % args.adlr_autoresume_interval == 0):
|
||||
check_adlr_autoresume_termination(iteration, model, optimizer,
|
||||
lr_scheduler)
|
||||
|
||||
# Evaluation
|
||||
if args.eval_interval and iteration % args.eval_interval == 0 and \
|
||||
args.do_valid:
|
||||
args.do_valid:
|
||||
prefix = 'iteration {}'.format(iteration)
|
||||
for iterator in valid_data_iterator:
|
||||
evaluate_and_print_results(prefix, forward_step_func,
|
||||
iterator, model,
|
||||
iteration, False)
|
||||
iterator, model,
|
||||
iteration, False)
|
||||
|
||||
# Checkpointing
|
||||
saved_checkpoint = False
|
||||
if args.save and args.save_interval and \
|
||||
iteration % args.save_interval == 0:
|
||||
iteration % args.save_interval == 0:
|
||||
save_checkpoint_and_time(iteration, model, optimizer,
|
||||
lr_scheduler)
|
||||
saved_checkpoint = True
|
||||
@ -1198,7 +1283,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
|
||||
print_datetime('exiting program at iteration {}'.format(iteration))
|
||||
sys.exit()
|
||||
|
||||
|
||||
return iteration
|
||||
|
||||
|
||||
@ -1222,6 +1306,11 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
|
||||
|
||||
total_loss_dict = {}
|
||||
|
||||
# make the validation batch size independent of the training batch size
|
||||
eval_batch_size = args.global_batch_size
|
||||
eval_num_microbatches = eval_batch_size // \
|
||||
(args.micro_batch_size * args.data_parallel_size)
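# Worked example with hypothetical sizes: global_batch_size=64, micro_batch_size=4,
# data_parallel_size=2  ->  eval_num_microbatches = 64 // (4 * 2) = 8 micro-batches per eval step.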
|
||||
|
||||
with torch.no_grad():
|
||||
iteration = 0
|
||||
while iteration < args.eval_iters:
|
||||
@ -1229,27 +1318,23 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
|
||||
if verbose and iteration % args.log_interval == 0:
|
||||
print_rank_0('Evaluating iter {}/{}'.format(iteration,
|
||||
args.eval_iters))
|
||||
|
||||
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
|
||||
if args.virtual_pipeline_model_parallel_size is not None:
|
||||
forward_backward_func = forward_backward_pipelining_with_interleaving
|
||||
elif args.optimized_pipeline:
|
||||
forward_backward_func = optimized_forward_backward_pipelining
|
||||
else:
|
||||
forward_backward_func = forward_backward_pipelining_without_interleaving
|
||||
else:
|
||||
forward_backward_func = forward_backward_no_pipelining
|
||||
|
||||
forward_backward_func = get_forward_func()
|
||||
if args.deepspeed and args.ds_pipeline_enabled:
|
||||
# DeepSpeed uses eval_batch() and already aggregates losses.
|
||||
assert isinstance(model, list) and len(model) == 1
|
||||
loss = model[0].eval_batch(data_iterator)
|
||||
loss_dicts = [{'lm loss' : loss}] * get_num_microbatches()
|
||||
loss_dicts = [{'lm loss': loss}] * get_num_microbatches()
|
||||
else:
|
||||
loss_dicts = forward_backward_func(
|
||||
forward_step_func, data_iterator, model, optimizer=None,
|
||||
timers=None, forward_only=True)
|
||||
|
||||
forward_step_func=forward_step_func,
|
||||
data_iterator=data_iterator,
|
||||
model=model,
|
||||
num_microbatches=eval_num_microbatches,
|
||||
seq_length=args.seq_length,
|
||||
micro_batch_size=args.micro_batch_size,
|
||||
decoder_seq_length=args.decoder_seq_length,
|
||||
forward_only=True)
|
||||
|
||||
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
|
||||
# Reduce across processes.
|
||||
for loss_dict in loss_dicts:
|
||||
@ -1277,6 +1362,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
|
||||
|
||||
return total_loss_dict
|
||||
|
||||
|
||||
def evaluate_and_print_results(prefix, forward_step_func,
|
||||
data_iterator, model,
|
||||
iteration, verbose=False, test=False, **kwargs):
|
||||
@ -1324,6 +1410,7 @@ def cyclic_iter(iter):
|
||||
for x in iter:
|
||||
yield x
|
||||
|
||||
|
||||
def build_train_valid_test_data_iterators(
|
||||
build_train_valid_test_datasets_provider):
|
||||
"""XXX"""
|
||||
@ -1338,12 +1425,12 @@ def build_train_valid_test_data_iterators(
|
||||
assert args.train_samples is None, \
|
||||
'only backward compatibility support for iteration-based training'
|
||||
args.consumed_train_samples = args.iteration * args.global_batch_size
|
||||
|
||||
|
||||
if args.iteration // args.eval_interval > 0 and args.consumed_valid_samples == 0:
|
||||
assert args.train_samples is None, \
|
||||
'only backward compatibility support for iteration-based training'
|
||||
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
|
||||
args.eval_iters * args.global_batch_size
|
||||
args.eval_iters * args.global_batch_size
|
||||
|
||||
# Data loader only on rank 0 of each model parallel group.
|
||||
if parallel_state.get_tensor_model_parallel_rank() == 0:
|
||||
@ -1383,12 +1470,12 @@ def build_train_valid_test_data_iterators(
|
||||
train_dataloaders = build_pretraining_data_loader(train_ds[0], args.consumed_train_samples)
|
||||
|
||||
valid_dataloaders = [build_pretraining_data_loader(d, args.consumed_valid_samples // len(valid_ds))
|
||||
for d in valid_ds] \
|
||||
if valid_ds is not None else []
|
||||
for d in valid_ds] \
|
||||
if valid_ds is not None else []
|
||||
|
||||
# We collapse None and empty list as both should mean we don't run test
|
||||
test_dataloaders = [build_pretraining_data_loader(d, 0) for d in test_ds] \
|
||||
if test_ds is not None else []
|
||||
if test_ds is not None else []
|
||||
|
||||
# Flags to know if we need to do training/validation/testing.
|
||||
do_train = train_dataloaders is not None and args.train_iters > 0
|
||||
@ -1418,23 +1505,22 @@ def build_train_valid_test_data_iterators(
|
||||
|
||||
if train_dataloaders is not None:
|
||||
train_data_iterators = iter(train_dataloaders) if dl_type == 'single' \
|
||||
else iter(cyclic_iter(train_dataloaders))
|
||||
else iter(cyclic_iter(train_dataloaders))
|
||||
else:
|
||||
train_data_iterators = None
|
||||
|
||||
if valid_dataloaders is not None:
|
||||
valid_data_iterators = [iter(vdl) if dl_type in ['single'] \
|
||||
else iter(cyclic_iter(valid_dataloaders))
|
||||
for vdl in valid_dataloaders]
|
||||
else iter(cyclic_iter(valid_dataloaders))
|
||||
for vdl in valid_dataloaders]
|
||||
else:
|
||||
valid_data_iterators = [None] * num_valid_ds
|
||||
|
||||
if test_dataloaders is not None:
|
||||
test_data_iterators = [iter(tdl) if dl_type in ['single'] \
|
||||
else iter(cyclic_iter(test_dataloaders))
|
||||
for tdl in test_dataloaders]
|
||||
else iter(cyclic_iter(test_dataloaders))
|
||||
for tdl in test_dataloaders]
|
||||
else:
|
||||
test_data_iterators = [None] * num_test_ds
|
||||
|
||||
|
||||
return train_data_iterators, valid_data_iterators, test_data_iterators
|
||||
|
@ -291,3 +291,18 @@ def checkpoint_throughput_calculator(model, latency_second):
|
||||
checkpoint_GB = approx_parameters_in_billions * checkpoint_multiplier
|
||||
GB_per_second = checkpoint_GB / latency_second
|
||||
print_rank_0(f"Checkpoint Save GB: {round(checkpoint_GB, 3)}, GB/Sec: {round(GB_per_second,2)}, Latency(second): {round(latency_second, 3)}")
|
||||
|
||||
|
||||
def get_tune_attention_mask(attention_mask_1d, reset_attention_mask=True):
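"""Build a boolean causal attention mask from a 1-D padding mask (1 = real token, 0 = padding).
With the default arguments the result has shape [batch, 1, seq, seq]; True marks positions
that must not be attended to."""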
|
||||
micro_batch_size, seq_length = attention_mask_1d.size()
|
||||
if reset_attention_mask:
|
||||
att_mask_batch = micro_batch_size
|
||||
else:
|
||||
att_mask_batch = 1
|
||||
attention_mask = torch.tril(torch.ones(
|
||||
(att_mask_batch, seq_length, seq_length), device=attention_mask_1d.device)).view(
|
||||
att_mask_batch, 1, seq_length, seq_length)
|
||||
attention_mask = attention_mask.masked_fill((attention_mask_1d < 0.5).view(-1, 1, 1, seq_length), value=0)
|
||||
attention_mask = (attention_mask < 0.5)
|
||||
return attention_mask
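# Illustrative usage (a minimal sketch; the tensor values below are hypothetical):
#   pad_mask = torch.tensor([[1, 1, 1, 0],
#                            [1, 1, 0, 0]])   # 1 = real token, 0 = padding
#   mask = get_tune_attention_mask(pad_mask)  # bool tensor of shape [2, 1, 4, 4]
#   # True entries are blocked positions: future tokens (upper triangle) plus
#   # every key position that corresponds to padding.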
|
||||
|
||||
|
@ -1,6 +1,10 @@
|
||||
import os
|
||||
import stat
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
import xmlrunner
|
||||
|
||||
|
||||
#=============================
|
||||
# ST test, run with shell
|
||||
@ -9,23 +13,31 @@ def success_check(res):
|
||||
if res != 0:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def success_check_ut(res):
|
||||
if len(res.failures) + len(res.errors) != 0:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class ST_Test:
|
||||
|
||||
def __init__(self):
|
||||
|
||||
BASE_DIR = Path(__file__).absolute().parent.parent
|
||||
TEST_DIR = os.path.join(BASE_DIR, 'tests')
|
||||
|
||||
gpt_shell_file = os.path.join(TEST_DIR, "st", "test_gpt", "test_gpt_ptd.sh")
|
||||
llama_shell_file = os.path.join(TEST_DIR, "st", "test_llama", "test_llama_ptd.sh")
|
||||
bloom_shell_file = os.path.join(TEST_DIR, "st", "test_bloom", "test_bloom_ptd.sh")
|
||||
llama_vp_shell_file = os.path.join(TEST_DIR, "st", "test_llama", "test_llama_vp.sh")
|
||||
llama_opp_file = os.path.join(TEST_DIR, "st", "test_llama", "test_llama_opp.sh")
|
||||
st_dir = "st"
|
||||
llama_dir = "test_llama"
|
||||
bloom_dir = "test_bloom"
|
||||
|
||||
bloom_shell_file = os.path.join(
|
||||
TEST_DIR, st_dir, bloom_dir, "test_bloom_ptd.sh")
|
||||
llama_shell_file = os.path.join(
|
||||
TEST_DIR, st_dir, llama_dir, "test_llama_ptd.sh")
|
||||
|
||||
self.shell_file_list = [
|
||||
gpt_shell_file,
|
||||
llama_shell_file,
|
||||
bloom_shell_file,
|
||||
llama_vp_shell_file,
|
||||
llama_opp_file
|
||||
llama_shell_file,
|
||||
bloom_shell_file
|
||||
]
|
||||
|
||||
def run_shell(self):
|
||||
@ -39,3 +51,10 @@ class ST_Test:
|
||||
if __name__ == "__main__":
|
||||
st_test = ST_Test()
|
||||
st_test.run_shell()
|
||||
test_loader = unittest.TestLoader()
|
||||
discover = test_loader.discover(start_dir="../tests/ut", pattern="test*.py")
|
||||
|
||||
flags = os.O_WRONLY | os.O_CREAT
|
||||
modes = stat.S_IWUSR | stat.S_IRUSR
|
||||
with os.fdopen(os.open('final.xml', flags, modes), 'wb') as output:
|
||||
success_check_ut(xmlrunner.XMLTestRunner(output=output).run(discover))
|
||||
|
35
examples/baichuan/generate_baichuan_13B_tp8_pp1.sh
Normal file
@ -0,0 +1,35 @@
|
||||
#!/bin/bash
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
NNODES=1
|
||||
NODE_RANK=0
|
||||
NPUS_PER_NODE=8
|
||||
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
--node_rank $NODE_RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT"
|
||||
|
||||
CHECKPOINT="your megatron checkpoint path"
|
||||
VOCAB_FILE="your vocab path"
|
||||
|
||||
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/inference/inference_llama.py \
|
||||
--tensor-model-parallel-size 8 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
--num-layers 40 \
|
||||
--hidden-size 5120 \
|
||||
--ffn-hidden-size 13824 \
|
||||
--load "${CHECKPOINT}" \
|
||||
--num-attention-heads 40 \
|
||||
--max-position-embeddings 2048 \
|
||||
--tokenizer-type PretrainedFromHF \
|
||||
--tokenizer-name-or-path "$VOCAB_FILE" \
|
||||
--tokenizer-not-use-fast \
|
||||
--fp16 \
|
||||
--micro-batch-size 1 \
|
||||
--seq-length 512 \
|
||||
--max-new-tokens 256 \
|
||||
--seed 42
|
107
examples/baichuan/pretrain_baichuan_zero_7B.sh
Normal file
@ -0,0 +1,107 @@
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
|
||||
export HCCL_CONNECT_TIMEOUT=1200
|
||||
|
||||
# Change for multinode config
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6000
|
||||
NNODES=1
|
||||
NODE_RANK=0
|
||||
NPUS_PER_NODE=8
|
||||
WORLD_SIZE=$(($NPUS_PER_NODE * $NNODES))
|
||||
|
||||
DATA_PATH=dataset/llama_text_document
|
||||
CHECKPOINT_PATH=ckpt
|
||||
TOKENIZER_PATH=tokenizer
|
||||
|
||||
DS_CONFIG=ds_config.json
|
||||
ZERO_STAGE=2
|
||||
MICRO_BATCH=1
|
||||
GLOBAL_BATCH=8
|
||||
|
||||
rm -rf kernel_meta*
|
||||
|
||||
cat <<EOT >$DS_CONFIG
|
||||
{
|
||||
"gradient_accumulation_steps": 1,
|
||||
"train_batch_size": $GLOBAL_BATCH,
|
||||
"train_micro_batch_size_per_gpu":$MICRO_BATCH,
|
||||
"zero_allow_untested_optimizer": true,
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 8,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 2e-5,
|
||||
"eps": 1.0e-8,
|
||||
"betas": [
|
||||
0.9,
|
||||
0.95
|
||||
],
|
||||
"weight_decay": 0.0
|
||||
}
|
||||
},
|
||||
|
||||
"steps_per_print": 1,
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": $ZERO_STAGE,
|
||||
"contiguous_gradients": false,
|
||||
"allgather_bucket_size": 1e8,
|
||||
"reduce_bucket_size": 1e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true
|
||||
}
|
||||
}
|
||||
EOT
|
||||
|
||||
ds_args=""
|
||||
ds_args=" --deepspeed ${ds_args}"
|
||||
ds_args=" --no-pipeline-parallel ${ds_args}"
|
||||
ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
|
||||
ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"
|
||||
ds_args=" --deepspeed-activation-checkpointing ${ds_args}"
|
||||
|
||||
SEQ_LEN=4096
|
||||
|
||||
deepspeed pretrain_baichuan.py \
|
||||
--DDP-impl local \
|
||||
--tensor-model-parallel-size 1 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
--num-layers 32 \
|
||||
--hidden-size 4096 \
|
||||
--ffn-hidden-size 11008 \
|
||||
--num-attention-heads 32 \
|
||||
--micro-batch-size $MICRO_BATCH \
|
||||
--global-batch-size $GLOBAL_BATCH \
|
||||
--seq-length $SEQ_LEN \
|
||||
--max-position-embeddings $SEQ_LEN \
|
||||
--train-iters 1024 \
|
||||
--data-path $DATA_PATH \
|
||||
--tokenizer-name-or-path $TOKENIZER_PATH \
|
||||
--tokenizer-not-use-fast \
|
||||
--data-impl mmap \
|
||||
--split 949,50,1 \
|
||||
--distributed-backend nccl \
|
||||
--lr 0.00001 \
|
||||
--lr-decay-style cosine \
|
||||
--min-lr 1.0e-6 \
|
||||
--weight-decay 1e-2 \
|
||||
--clip-grad 1.0 \
|
||||
--lr-warmup-fraction .1 \
|
||||
--log-interval 1 \
|
||||
--save-interval 10000 \
|
||||
--eval-interval 1000 \
|
||||
--eval-iters 1 \
|
||||
--checkpoint-activations \
|
||||
--checkpoint_policy block \
|
||||
--checkpoint_block_layer 30 \
|
||||
--triangle-attn \
|
||||
$ds_args \
|
||||
--fp16 | tee logs/train.log
|
@ -116,9 +116,10 @@ TODO
|
||||
|
||||
The performance of the NPUs (**Ascend910 B1 64GB**) and the GPUs (**A100**) is compared below:
|
||||
|
||||
| Model | total Iterations | throughput rate (samples/s/p) | throughput rate (tokens/s/p) | single-step time (s/step) | floating point operation (TFLOPs/s) |
|
||||
| :-------: | :--------------: | :---------------------------: | :--------------------------: | :-----------------------: | :---------------------------------: |
|
||||
| Bloom-7B1 | 995 | 1.16 | 2386.12 | 20.68 | 5.02 |
|
||||
| Device | Model | total Iterations | throughput rate (samples/s/p) | throughput rate (tokens/s/p) | single-step time (s/step) | floating point operation (TFLOPs/s) |
|
||||
| :------: | :-------: | :--------------: | :---------------------------: | :--------------------------: | :-----------------------: | :---------------------------------: |
|
||||
| GPUs | Bloom-7B1 | 995 | 1.23 | 2540.23 | 19.44 | 5.68 |
|
||||
| NPUs | Bloom-7B1 | 995 | 1.26 | 2588.03 | 19.07 | 5.91 |
|
||||
|
||||
Notes:
|
||||
|
||||
@ -176,6 +177,22 @@ We support AscendSpeed Inference for text generation with BLOOM 7B1.
|
||||
|
||||
### Model weights
|
||||
|
||||
We provide scripts that convert pretrained weights into weights that AscendSpeed can load and use for inference. Download the BLOOM-176B checkpoint from [here](https://huggingface.co/bigscience/bloom/tree/main), make sure all chunks are downloaded completely, then use the following command to convert them into checkpoints that AscendSpeed can load. `--partition-layers` specifies the layer partitioning strategy under pipeline parallelism; you can modify it to a different strategy, but the sum of all elements of `--partition-layers` must equal 70, and the number of elements in `--partition-layers` must equal `--pipeline-model-parallel-size` (a quick sanity check of this constraint is sketched after the command).
|
||||
|
||||
```shell
|
||||
#!/bin/bash
|
||||
|
||||
SCRIPT_PATH=./tasks/ckpt_convert/bloom/convert_weights_from_huggingface.py
|
||||
python $SCRIPT_PATH \
|
||||
--input-model-dir "your huggingface checkpoint path" \
|
||||
--output-model-dir "your ascendspeed checkpoint path" \
|
||||
--tensor-model-parallel-size 8 \
|
||||
--pipeline-model-parallel-size 12 \
|
||||
--type 176B \
|
||||
--deepspeed \
|
||||
--partition-layers 6,6,6,6,6,6,6,6,6,6,6,4
|
||||
```
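
As a quick sanity check of the constraint above, the partition list can be validated before launching the conversion (a minimal sketch; `check_partition_layers` is a hypothetical helper, not part of the provided script):

```python
# Hypothetical helper: validate a --partition-layers value for BLOOM-176B (70 layers).
def check_partition_layers(partition_layers, pipeline_model_parallel_size, total_layers=70):
    parts = [int(x) for x in partition_layers.split(",")]
    assert len(parts) == pipeline_model_parallel_size, "one entry per pipeline stage"
    assert sum(parts) == total_layers, "entries must sum to the total number of layers"
    return parts

# 11 * 6 + 4 = 70 layers spread over 12 pipeline stages, matching the command above.
check_partition_layers("6,6,6,6,6,6,6,6,6,6,6,4", pipeline_model_parallel_size=12)
```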
|
||||
|
||||
Download the BLOOM model checkpoint from [here](TODO: XXXXX), make sure all chunks are downloaded completely, then use the following command to merge them into a single archive file and extract it:
|
||||
|
||||
```bash
|
||||
|
@ -1,103 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
|
||||
|
||||
from ascendspeed import get_args
|
||||
from ascendspeed import print_rank_0
|
||||
from ascendspeed import get_tokenizer
|
||||
from ascendspeed import mpu
|
||||
from ascendspeed.text_generation_utils import generate_samples_interactive
|
||||
from ascendspeed.checkpointing import load_checkpoint
|
||||
from ascendspeed.initialize import initialize_megatron
|
||||
from ascendspeed.model.gpt_model import GPTModel
|
||||
from ascendspeed.training import get_model
|
||||
from ascendspeed.utils import get_ltor_masks_and_position_ids, unwrap_model
|
||||
from ascendspeed.p2p_communication import recv_forward, send_forward
|
||||
from ascendspeed.model import DistributedDataParallel as LocalDDP
|
||||
from ascendspeed.model import Float16Module
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
def model_provider(pre_process=True, post_process=True):
|
||||
"""Build the model."""
|
||||
print_rank_0('building bloom model ...')
|
||||
model = GPTModel(num_tokentypes=0, parallel_output=False,
|
||||
pre_process=pre_process, post_process=post_process,
|
||||
return_moe_loss=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def add_text_generate_args(parser):
|
||||
"""Text generation arguments."""
|
||||
group = parser.add_argument_group(title='text generation')
|
||||
|
||||
group.add_argument("--temperature", type=float, default=1.0,
|
||||
help='Sampling temperature.')
|
||||
group.add_argument("--greedy", action='store_true', default=False,
|
||||
help='Use greedy sampling.')
|
||||
group.add_argument("--top_p", type=float, default=0.0,
|
||||
help='Top p sampling.')
|
||||
group.add_argument("--top_k", type=int, default=0,
|
||||
help='Top k sampling.')
|
||||
group.add_argument("--out-seq-length", type=int, default=1024,
|
||||
help='Size of the output generated text.')
|
||||
group.add_argument("--sample-input-file", type=str, default=None,
|
||||
help='Get input from file instead of interactive mode, '
|
||||
'each line is an input.')
|
||||
group.add_argument("--sample-output-file", type=str, default=None,
|
||||
help='Output file got from --sample-input-file')
|
||||
group.add_argument("--num-samples", type=int, default=0,
|
||||
help='Number of samples to generate unconditionally, '
|
||||
'defaults to 0 and interactive conditional sampling')
|
||||
group.add_argument("--genfile", type=str,
|
||||
help='Output file when generating unconditionally')
|
||||
group.add_argument("--recompute", action='store_true',
|
||||
help='During generation recompute all attention '
|
||||
'instead of using previously computed keys/values.')
|
||||
|
||||
return parser
|
||||
|
||||
def main():
|
||||
"""Main program."""
|
||||
initialize_megatron(extra_args_provider=add_text_generate_args)
|
||||
|
||||
args = get_args()
|
||||
if args.num_layers_per_virtual_pipeline_stage is not None:
|
||||
print("Interleaved pipeline schedule is not yet supported for text generation.")
|
||||
exit()
|
||||
|
||||
# Set up model and load checkpoint.
|
||||
model = get_model(model_provider)
|
||||
if args.load is not None:
|
||||
_ = load_checkpoint(model, None, None)
|
||||
|
||||
assert len(model) == 1, "Above condition should have caught this"
|
||||
model = model[0]
|
||||
|
||||
# Generate samples.
|
||||
assert args.num_samples == 0, "No sample is required for interactive inference"
|
||||
assert args.micro_batch_size == 1, "Interactive inference requires that micro_batch_size be set to 1"
|
||||
generate_samples_interactive(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,41 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Environment
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
|
||||
export HCCL_CONNECT_TIMEOUT=1200
|
||||
export HCCL_OP_BASE_FFTS_MODE_ENABLE=TRUE
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
|
||||
# Distributed setting
|
||||
DISTRIBUTED_ARGS="--nproc_per_node 8 --nnodes 2 --node_rank 0 --master_addr localhost --master_port 6001"
|
||||
|
||||
CHECKPOINTS_PATH=<checkpoints-path>
|
||||
TOKENIZER_PATH=<tokenizer-vocab-file-path>
|
||||
|
||||
# Real script
|
||||
python -m torch.distributed.run $DISTRIBUTED_ARGS ./examples/bloom_task/generate_bloom.py \
|
||||
--load CHECKPOINTS_PATH \
|
||||
--tokenizer-type PretrainedFromHF \
|
||||
--tokenizer-name-or-path TOKENIZER_PATH \
|
||||
--tensor-model-parallel-size 8 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
--embed-layernorm \
|
||||
--position-embedding-type alibi \
|
||||
--num-layers 30 \
|
||||
--hidden-size 4096 \
|
||||
--attention-dropout 0 \
|
||||
--hidden-dropout 0 \
|
||||
--num-attention-heads 32 \
|
||||
--micro-batch-size 1 \
|
||||
--seq-length 2048 \
|
||||
--max-position-embeddings 2048 \
|
||||
--init-method-std 0.0048 \
|
||||
--log-interval 1 \
|
||||
--layernorm-epsilon 1e-6 \
|
||||
--fp16 \
|
||||
--no-load-optim \
|
||||
--no-load-rng \
|
||||
--out-seq-length 1024 \
|
||||
--temperature 1.0 \
|
||||
--top_p 0.9 \
|
||||
--recompute
|
55
examples/bloom/generate_bloom_7b_deepspeed_pipeline.sh
Normal file
@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
|
||||
NNODES=1
|
||||
NPUS_PER_NODE=8
|
||||
|
||||
CHECKPOINT="your megatron checkpoint path"
|
||||
VOCAB_FILE="your vocab path"
|
||||
|
||||
ZERO_STAGE=0
|
||||
MICRO_BATCH_SIZE=1
|
||||
config_json="./ds_config.json"
|
||||
|
||||
cat <<EOT > $config_json
|
||||
{
|
||||
"train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
|
||||
"gradient_clipping": 1.0,
|
||||
"zero_optimization": {
|
||||
"stage": $ZERO_STAGE
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 500,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1,
|
||||
"initial_scale_power": 12
|
||||
},
|
||||
"steps_per_print": 2000,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
EOT
|
||||
|
||||
deepspeed --num_nodes $NNODES --num_gpus $NPUS_PER_NODE \
|
||||
./tasks/inference/inference_bloom_pipeline \
|
||||
--tensor-model-parallel-size 8 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
--num-layers 30 \
|
||||
--hidden-size 4096 \
|
||||
--num-attention-heads 32 \
|
||||
--max-position-embeddings 2048 \
|
||||
--position-embedding-type alibi \
|
||||
--embed-layernorm \
|
||||
--tokenizer-type PretrainedFromHF \
|
||||
--load "${CHECKPOINT}" \
|
||||
--tokenizer-name-or-path "$VOCAB_FILE" \
|
||||
--tokenizer-not-use-fast \
|
||||
--fp16 \
|
||||
--micro-batch-size 1 \
|
||||
--seq-length 256 \
|
||||
--max-new-tokens 64 \
|
||||
--seed 42 \
|
||||
--deepspeed \
|
||||
--deepspeed_config ${config_json} \
|
118
examples/bloom/pretrain_bloom_176b.sh
Normal file
@ -0,0 +1,118 @@
|
||||
#!/bin/bash
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:/home/anaconda3/lib:$LD_LIBRARY_PATH
|
||||
export HCCL_CONNECT_TIMEOUT=1600
|
||||
export HCCL_OP_BASE_FFTS_MODE_ENABLE=TRUE
|
||||
export ASCEND_GLOBAL_LOG_LEVEL=3 # 1-INFO 3-ERROR
|
||||
export ASCEND_GLOBAL_EVENT_ENABLE=0
|
||||
export ASCEND_SLOG_PRINT_TO_STDOUT=0
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
|
||||
# output data path
|
||||
CHECKPOINT_PATH='./ckpt'
|
||||
TENSORBOARD_PATH='./tensorboard/'
|
||||
LOGS_PATH='./logs/'
|
||||
mkdir -p $LOGS_PATH
|
||||
|
||||
# train parameter
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=12890
|
||||
GPUS_PER_NODE=8
|
||||
NNODES=12
|
||||
NODE_RANK=0
|
||||
PP_SIZE=12
|
||||
TP_SIZE=8
|
||||
|
||||
MICRO_BATCH_SIZE=1
|
||||
GLOBAL_BATCH_SIZE=2048
|
||||
|
||||
NLAYERS=70
|
||||
NHIDDEN=14336
|
||||
NHEADS=112
|
||||
SEQ_LEN=2048
|
||||
|
||||
SAVE_INTERVAL=5000
|
||||
|
||||
TRAIN_SAMPLES=220_000_000 # 450B tokens
|
||||
LR_DECAY_SAMPLES=200_000_000 # Decay for the first 410B tokens then continue at fixed --min-lr
|
||||
LR_WARMUP_SAMPLES=183_105 # 375M tokens
|
||||
|
||||
# dataset path
|
||||
TOKENIZER_NAME_OR_PATH=/home/bloom_data/vocab_file/
|
||||
DATA_PATH=/home/bloom_data/oscar_data_1g/my-gpt2_text_document
|
||||
|
||||
ZERO_STAGE=0 # important: bf16 must use z0! it implements its own zero stage 1 equivalent
|
||||
config_json="./ds_config.json"
|
||||
|
||||
cat <<EOT > $config_json
|
||||
{
|
||||
"train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
|
||||
"train_batch_size": $GLOBAL_BATCH_SIZE,
|
||||
"gradient_clipping": 1.0,
|
||||
"zero_optimization": {
|
||||
"stage": $ZERO_STAGE
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"steps_per_print": 2000,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
EOT
|
||||
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
|
||||
|
||||
TRANSFORMERS_OFFLINE=1 \
|
||||
python -m torch.distributed.run $DISTRIBUTED_ARGS \
|
||||
pretrain_bloom.py \
|
||||
--tokenizer-type PretrainedFromHF \
|
||||
--embed-layernorm \
|
||||
--tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \
|
||||
--data-path $DATA_PATH \
|
||||
--attention-dropout 0 \
|
||||
--hidden-dropout 0 \
|
||||
--pad-vocab-size-to 250880 \
|
||||
--tensor-model-parallel-size $TP_SIZE \
|
||||
--pipeline-model-parallel-size $PP_SIZE \
|
||||
--num-layers $NLAYERS \
|
||||
--hidden-size $NHIDDEN \
|
||||
--num-attention-heads $NHEADS \
|
||||
--seq-length $SEQ_LEN \
|
||||
--max-position-embeddings $SEQ_LEN \
|
||||
--micro-batch-size $MICRO_BATCH_SIZE \
|
||||
--rampup-batch-size 192 16 9_765_625 \
|
||||
--global-batch-size $GLOBAL_BATCH_SIZE \
|
||||
--train-samples $TRAIN_SAMPLES \
|
||||
--init-method-std 0.0048 \
|
||||
--bf16 \
|
||||
--seed 42 \
|
||||
--position-embedding-type alibi \
|
||||
--optimizer adam \
|
||||
--adam-beta1 0.9 \
|
||||
--adam-beta2 0.95 \
|
||||
--adam-eps 1e-8 \
|
||||
--lr 6e-5 \
|
||||
--min-lr 6e-6 \
|
||||
--lr-decay-style cosine \
|
||||
--lr-decay-samples $LR_DECAY_SAMPLES \
|
||||
--lr-warmup-samples $LR_WARMUP_SAMPLES \
|
||||
--clip-grad 1.0 \
|
||||
--weight-decay 1e-1 \
|
||||
--log-interval 1 \
|
||||
--save $CHECKPOINT_PATH \
|
||||
--save-interval $SAVE_INTERVAL \
|
||||
--eval-interval 1000 \
|
||||
--eval-iters 1 \
|
||||
--tensorboard-queue-size 5 \
|
||||
--load $CHECKPOINT_PATH \
|
||||
--log-timers-to-tensorboard \
|
||||
--log-batch-size-to-tensorboard \
|
||||
--log-validation-ppl-to-tensorboard \
|
||||
--data-impl mmap \
|
||||
--distributed-backend nccl \
|
||||
--deepspeed \
|
||||
--deepspeed_config ${config_json} \
|
||||
--zero-stage ${ZERO_STAGE} \
|
||||
--deepspeed-activation-checkpointing \
|
||||
--sequence-parallel \
|
||||
--checkpoint-activations
|
||||
|
@ -1,45 +0,0 @@
|
||||
# This is an example: basic gpt
|
||||
# without parameter specific and any parallel technologies
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
|
||||
export HCCL_CONNECT_TIMEOUT=1200
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
RANK=0
|
||||
WORLD_SIZE=1
|
||||
|
||||
DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence
|
||||
CHECKPOINT_PATH=./ckpt
|
||||
|
||||
export LOCAL_RANK=0
|
||||
|
||||
python pretrain_gpt.py \
|
||||
--DDP-impl local \
|
||||
--use-contiguous-buffers-in-ddp \
|
||||
--num-layers 1 \
|
||||
--hidden-size 4096 \
|
||||
--num-attention-heads 32 \
|
||||
--micro-batch-size 4 \
|
||||
--global-batch-size 8 \
|
||||
--seq-length 2048 \
|
||||
--max-position-embeddings 2048 \
|
||||
--train-iters 500000 \
|
||||
--lr-decay-iters 320000 \
|
||||
--save $CHECKPOINT_PATH \
|
||||
--load $CHECKPOINT_PATH \
|
||||
--data-path $DATA_PATH \
|
||||
--vocab-file ./dataset/gpt2-vocab.json \
|
||||
--merge-file ./dataset/gpt2-merges.txt \
|
||||
--data-impl mmap \
|
||||
--split 949,50,1 \
|
||||
--distributed-backend nccl \
|
||||
--lr 0.00015 \
|
||||
--min-lr 1.0e-5 \
|
||||
--lr-decay-style cosine \
|
||||
--weight-decay 1e-2 \
|
||||
--clip-grad 1.0 \
|
||||
--lr-warmup-fraction .01 \
|
||||
--checkpoint-activations \
|
||||
--log-interval 10 \
|
||||
--save-interval 10000 \
|
||||
--eval-interval 1000 \
|
||||
--eval-iters 10 \
|
||||
--fp16 | tee logs/train.log
|
@ -20,7 +20,6 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $
|
||||
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
|
||||
pretrain_gpt.py \
|
||||
--DDP-impl local \
|
||||
--use-contiguous-buffers-in-ddp \
|
||||
--tensor-model-parallel-size 2 \
|
||||
--pipeline-model-parallel-size 2 \
|
||||
--num-layers 8 \
|
||||
|
@ -1,52 +0,0 @@
|
||||
# This is an example: train gpt using TD,
|
||||
# the number of parameters is not aligned
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
|
||||
export HCCL_CONNECT_TIMEOUT=1200
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
GPUS_PER_NODE=8
|
||||
# Change for multinode config
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6000
|
||||
NNODES=1
|
||||
NODE_RANK=0
|
||||
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
|
||||
|
||||
DATA_PATH=./dataset/enwiki-gpt/gpt_text_sentence
|
||||
CHECKPOINT_PATH=./ckpt
|
||||
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
|
||||
|
||||
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
|
||||
pretrain_gpt.py \
|
||||
--DDP-impl local \
|
||||
--use-contiguous-buffers-in-ddp \
|
||||
--tensor-model-parallel-size 4 \
|
||||
--num-layers 8 \
|
||||
--hidden-size 4096 \
|
||||
--num-attention-heads 32 \
|
||||
--micro-batch-size 2 \
|
||||
--global-batch-size 16 \
|
||||
--seq-length 2048 \
|
||||
--max-position-embeddings 2048 \
|
||||
--train-iters 500000 \
|
||||
--lr-decay-iters 320000 \
|
||||
--save $CHECKPOINT_PATH \
|
||||
--load $CHECKPOINT_PATH \
|
||||
--data-path $DATA_PATH \
|
||||
--vocab-file ./dataset/gpt2-vocab.json \
|
||||
--merge-file ./dataset/gpt2-merges.txt \
|
||||
--data-impl mmap \
|
||||
--split 949,50,1 \
|
||||
--distributed-backend nccl \
|
||||
--lr 0.00015 \
|
||||
--lr-decay-style cosine \
|
||||
--min-lr 1.0e-5 \
|
||||
--weight-decay 1e-2 \
|
||||
--clip-grad 1.0 \
|
||||
--lr-warmup-fraction .01 \
|
||||
--checkpoint-activations \
|
||||
--log-interval 10 \
|
||||
--save-interval 10000 \
|
||||
--eval-interval 1000 \
|
||||
--eval-iters 10 \
|
||||
--fp16 | tee logs/train.log
|
92
examples/intern/pretrain_internlm_7b_zero.sh
Normal file
@ -0,0 +1,92 @@
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
|
||||
export HCCL_CONNECT_TIMEOUT=1200
|
||||
export INF_NAN_MODE_ENABLE=1
|
||||
GPUS_PER_NODE=8
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6000
|
||||
NNODES=1
|
||||
NODE_RANK=0
|
||||
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
|
||||
|
||||
DATA=./dataset/internlm_text_document
|
||||
CHECKPOINT=./ckpt/
|
||||
|
||||
DS_CONFIG=ds_config.json
|
||||
ZERO_STAGE=2
|
||||
GLOBAL_BATCH=64
|
||||
MICRO_BATCH=8
|
||||
|
||||
cat <<EOT > $DS_CONFIG
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 8,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
|
||||
"optimizer": {
|
||||
"type": "Adam"
|
||||
},
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": $ZERO_STAGE,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
|
||||
"gradient_accumulation_steps": 1,
|
||||
"train_batch_size": $GLOBAL_BATCH,
|
||||
"train_micro_batch_size_per_gpu":$MICRO_BATCH,
|
||||
"zero_allow_untested_optimizer": true
|
||||
}
|
||||
EOT
|
||||
|
||||
ds_args=""
|
||||
ds_args=" --deepspeed ${ds_args}"
|
||||
ds_args=" --no-pipeline-parallel ${ds_args}"
|
||||
ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
|
||||
ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"
|
||||
ds_args=" --deepspeed-activation-checkpointing ${ds_args}"
|
||||
|
||||
|
||||
deepspeed pretrain_intern.py \
|
||||
--DDP-impl torch \
|
||||
--tensor-model-parallel-size 1 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
--num-layers 32 \
|
||||
--hidden-size 4096 \
|
||||
--ffn-hidden-size 11008 \
|
||||
--num-attention-heads 32 \
|
||||
--micro-batch-size $MICRO_BATCH \
|
||||
--global-batch-size $GLOBAL_BATCH \
|
||||
--seq-length 2048 \
|
||||
--max-position-embeddings 2048 \
|
||||
--train-iters 500000 \
|
||||
--lr-decay-iters 320000 \
|
||||
--data-path $DATA \
|
||||
--tokenizer-name-or-path ./dataset/internlm \
|
||||
--tokenizer-not-use-fast \
|
||||
--data-impl mmap \
|
||||
--split 949,50,1 \
|
||||
--distributed-backend nccl \
|
||||
--lr 0.00015 \
|
||||
--lr-decay-style cosine \
|
||||
--min-lr 1.0e-5 \
|
||||
--weight-decay 1e-2 \
|
||||
--clip-grad 1.0 \
|
||||
--lr-warmup-fraction .01 \
|
||||
--checkpoint-activations \
|
||||
--log-interval 1 \
|
||||
--save-interval 10000 \
|
||||
--eval-interval 1000 \
|
||||
--eval-iters 10 \
|
||||
--triangle-attn \
|
||||
$ds_args \
|
||||
--fp16 | tee logs/train.log
|
@ -1,198 +0,0 @@
|
||||
# LLAMA
|
||||
|
||||
This directory contains some of the scripts that were used to produce the results in AscendSpeed. These scripts show, by example, how to run the following model from the terminal.
|
||||
|
||||
LLaMA model is from: [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971)
|
||||
|
||||
> Touvron, Hugo, et al. "Llama: Open and efficient foundation language models." arXiv preprint arXiv:2302.13971 (2023).
|
||||
# Contents
|
||||
|
||||
@DOC TODO: the second-level contents need to be updated.
|
||||
|
||||
## Pre-Training
|
||||
|
||||
LLaMA's architecture is very similar to GPT3 with a few added improvements as will be discussed later in this article.
|
||||
|
||||
Here's a quick summary of training llama:
|
||||
|
||||
| | |
|
||||
| :----- | :------------- |
|
||||
| Hardware | 96 64GB Atlas 910B NPUs |
|
||||
| Software | AscendSpeed |
|
||||
| Architecture | GPT3 w/ extras |
|
||||
| Dataset | xxxxxxxxxx |
|
||||
| Training time | xxxxxxxxxx |
|
||||
|
||||
### Datasets
|
||||
|
||||
TODO: change the context xxxx. Another important feature from Megatron-LM is the efficient data loader. During start-up of the initial training, each dataset is split into samples of the requested sequence length (2048 for BLOOM) and an index is created to number each sample. Based on the training parameters, the number of epochs for a dataset is calculated, and an ordering for that many epochs is created and then shuffled. For example, if a dataset has 10 samples and should be gone through twice, the system first lays out the sample indices in order [0, ..., 9, 0, ..., 9] and then shuffles that order to create the final global order for the dataset. Note that this means training will not simply go through the entire dataset and then repeat: it is possible to see the same sample twice before seeing another sample at all, but by the end of training the model will have seen each sample twice. This helps ensure a smooth training curve through the entire training process. These indices, including the offsets into the base dataset of each sample, are saved to a file to avoid recomputing them each time a training process is started. Several of these datasets can then be blended with varying weights into the final data seen by the training process.
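
The indexing scheme described above can be sketched as follows (a minimal NumPy illustration of the idea, not the actual Megatron-LM implementation):

```python
import numpy as np

def build_shuffled_index(num_samples, num_epochs, seed=1234):
    """Lay out the sample indices for every epoch, then shuffle the whole order once."""
    order = np.tile(np.arange(num_samples), num_epochs)  # e.g. [0..9, 0..9] for 10 samples, 2 epochs
    np.random.RandomState(seed).shuffle(order)
    return order

# Every index appears exactly num_epochs times, in one globally shuffled order.
print(build_shuffled_index(10, 2))
```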
|
||||
|
||||
- 46 Languages in 1.5TB of deduplicated massively cleaned up text, converted into 350B unique tokens
|
||||
- Vocabulary size of the model is 250,680 tokens
|
||||
- For full details please see The BigScience Corpus A 1.6TB Composite Multilingual Dataset
|
||||
|
||||
### Script

To launch pre-training, run `pretrain_llama_ptd_16B.sh`:

```Shell
sh pretrain_llama_ptd_16B.sh
```

An hourly pulse-check script runs alongside training to verify that the job is either running or scheduled.

The training log will look like this:

```Shell
XXXXX
```

### Performance

#### Machine performance

The performance of the NPUs in XXXXX (configuration) and of the GPUs is:

TODO: present the throughput and the parallel configuration in a table.

#### Accuracy of the loss

NPU vs. GPU loss. XXXX (explain more).

![NPU-LOSS](./images/7b_lm_loss.png)

NPU vs. GPU loss relative error. XXXX (explain more).

![NPU-Relative-Error](./images/relative_error.png)

## Fine-tune and Evaluation

TODO: provide the fine-tuning procedure: first load the weights, then run the fine-tuning script, in the same format as pre-training; task validation results will be provided later (under development).

## Inference

We support AscendSpeed Inference for text generation with BLOOM 7B1.

### Model weights

Download the BLOOM model checkpoint from [here](TODO: XXXXX), make sure all chunks are downloaded completely, then use the following command to merge them into a single archive file and extract it:

```bash
cat bloom-7b1.tar.part_* > bloom-7b1.tar
tar xvf bloom-7b1.tar
```

Set `CHECKPOINT_PATH` in `/generate_bloom.sh` to the path of the extracted folder. Since the checkpoint file is large, it is recommended to use an SSD or a RAM disk to reduce the checkpoint loading time. Since the checkpoint we distribute uses 8-way tensor parallelism, a conversion script is also provided if you need to change the tensor parallel dimension.

```bash
# TODO: add convert_tp tools.

python tools/convert_tp.py \
    --input-folder <SRC_CKPT_PATH> \
    --output-folder <DST_CKPT_PATH> \
    --target-tp <TARGET_TP>
```

### Script

We generate text samples using the `generate_bloom` script. Inference differs from pre-training in that, for example, we need to load a pre-trained checkpoint and set the length of the output samples:

```shell
bash ./generate_bloom_7b1.sh
```

Alternatively, you can also use DeepSpeed from source:

```Shell
TODO: XXXX
```

### Samples

<details>
<summary><b>Left-To-Right Generation Examples</b></summary>

#### Example 1
Input: Who is the greatest artist? The greatest artist is ?

Output: TODO: XXXX.

#### Example 2 (Chinese)
Input: 问题:冬天,中国哪座城市最适合避寒?问题描述:能推荐一些国内适合冬天避寒的城市吗?回答用户:旅游爱好者 回答:?

Output: 问题: XXXX.
</details>

All the provided scripts were tested on 8 Atlas 910B 64GB NPUs for BLOOM 7B1 (fp16). These scripts might not work for other models or a different number of NPUs.

> Note: Sometimes NPU memory is not freed when an inference deployment crashes. You can free this memory by running `killall python` in a terminal.

## Evaluation and Tasks

### Dataset

First of all, you must download the evaluation datasets for the [BoolQ](https://storage.googleapis.com/boolq/dev.jsonl), PIQA ([1](https://yonatanbisk.com/piqa/data/valid.jsonl), [2](https://yonatanbisk.com/piqa/data/valid-labels.lst)), and [HellaSwag](https://github.com/rowanz/hellaswag/tree/master/data/hellaswag_val.jsonl) tasks.

### LLaMA Evaluation

We include zero-shot example scripts for LLaMA evaluation of [BoolQ](https://storage.googleapis.com/boolq/dev.jsonl), PIQA ([1](https://yonatanbisk.com/piqa/data/valid.jsonl), [2](https://yonatanbisk.com/piqa/data/valid-labels.lst)), and [HellaSwag](https://github.com/rowanz/hellaswag/tree/master/data/hellaswag_val.jsonl) accuracy.

For example, you can use the following command to run the BoolQ zero-shot task on a LLaMA-7B model.

```Shell
WORLD_SIZE=8

DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
                  --nnodes 1 \
                  --node_rank 0 \
                  --master_addr localhost \
                  --master_port 6000"

TASK="BoolQ"
VALID_DATA=<boolq dev data path>.jsonl

python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
    --task $TASK \
    --valid-data $VALID_DATA \
    --tokenizer-type PretrainedFromHF \
    --tokenizer-name-or-path ./dataset/llama/ \
    --tokenizer-not-use-fast \
    --load $CHECKPOINT \
    --tensor-model-parallel-size 8 \
    --pipeline-model-parallel-size 1 \
    --num-layers 32 \
    --hidden-size 4096 \
    --ffn-hidden-size 11008 \
    --num-attention-heads 32 \
    --micro-batch-size 8 \
    --seq-length 2048 \
    --max-position-embeddings 2048 \
    --log-interval 1 \
    --layernorm-epsilon 1e-6 \
    --fp16 \
    --no-load-optim \
    --no-load-rng
```

### Zero-shot Task

The following table shows the zero-shot accuracy of the LLaMA model on NPU alongside the numbers reported in the [LLaMA paper](https://arxiv.org/abs/2302.13971), formatted as NPU | paper.

| Model Size | BoolQ | PIQA | HellaSwag |
| :---: | :---: | :---: | :---: |
| 7B | 74.7% \| 76.5% | 78.6% \| 79.8% | 73.9% \| 79.8% |
| 13B | 79.5% \| 78.1% | 80.4% \| 80.1% | 77.3% \| 80.1% |
| 33B | 83.1% \| 83.1% | 81.7% \| 82.3% | 83.0% \| 82.3% |
| 65B | 85.5% \| 85.3% | 81.2% \| 82.8% | 82.3% \| 82.8% |

## Citation

You may also consider the original work in your reference:

@article{touvron2023llama,
  title={Llama: Open and efficient foundation language models},
  author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others},
  journal={arXiv preprint arXiv:2302.13971},
  year={2023}
}
52
examples/llama/generate_alpaca_13B_deepspeed.sh
Normal file
@ -0,0 +1,52 @@
#!/bin/bash
export TOKENIZERS_PARALLELISM=false

NNODES=1
NPUS_PER_NODE=8

CHECKPOINT="your origin deepspeed checkpoint path (TP=1, PP=1)"
VOCAB_FILE="your vocab path"

ZERO_STAGE=0
MICRO_BATCH_SIZE=1
config_json="./ds_config.json"

cat <<EOT > $config_json
{
"train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": $ZERO_STAGE
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1,
"initial_scale_power": 12
},
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
EOT

deepspeed --num_nodes $NNODES --num_gpus $NPUS_PER_NODE \
./tasks/inference/inference_alpaca.py \
--num-layers 40 \
--hidden-size 5120 \
--ffn-hidden-size 13824 \
--num-attention-heads 40 \
--max-position-embeddings 2048 \
--tokenizer-type PretrainedFromHF \
--load "${CHECKPOINT}" \
--tokenizer-name-or-path "$VOCAB_FILE" \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 256 \
--max-new-tokens 64 \
--seed 42 \
--deepspeed \
--deepspeed_config ${config_json} \
--no-pipeline-parallel \
58
examples/llama/generate_alpaca_13B_lora_deepspeed.sh
Normal file
@ -0,0 +1,58 @@
#!/bin/bash

export TOKENIZERS_PARALLELISM=false

NNODES=1
NPUS_PER_NODE=8

CHECKPOINT="your origin deepspeed checkpoint path (TP=1, PP=1)"
LORA_CHECKPOINT="your lora checkpoint path"
VOCAB_FILE="your vocab path"

ZERO_STAGE=0
MICRO_BATCH_SIZE=1
config_json="./ds_config.json"

cat <<EOT > $config_json
{
"train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": $ZERO_STAGE
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1,
"initial_scale_power": 12
},
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
EOT

deepspeed --num_nodes $NNODES --num_gpus $NPUS_PER_NODE \
./tasks/inference/inference_alpaca.py \
--num-layers 40 \
--hidden-size 5120 \
--ffn-hidden-size 13824 \
--num-attention-heads 40 \
--max-position-embeddings 2048 \
--tokenizer-type PretrainedFromHF \
--load "${CHECKPOINT}" \
--lora-load "${LORA_CHECKPOINT}" \
--tokenizer-name-or-path "$VOCAB_FILE" \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 256 \
--max-new-tokens 64 \
--seed 42 \
--lora-r 16 \
--lora-alpha 32 \
--lora-target-modules query_key_value dense gate_proj up_proj down_proj \
--deepspeed \
--deepspeed_config ${config_json} \
--no-pipeline-parallel \
53
examples/llama/generate_llama_7B_deepspeed_pipeline.sh
Normal file
@ -0,0 +1,53 @@
#!/bin/bash

export TOKENIZERS_PARALLELISM=false

NNODES=1
NPUS_PER_NODE=8

CHECKPOINT="your megatron checkpoint path"
VOCAB_FILE="your vocab path"

ZERO_STAGE=0
MICRO_BATCH_SIZE=1
config_json="./ds_config.json"

cat <<EOT > $config_json
{
"train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": $ZERO_STAGE
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1,
"initial_scale_power": 12
},
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
EOT

deepspeed --num_nodes $NNODES --num_gpus $NPUS_PER_NODE \
./tasks/inference/inference_llama_pipeline.py \
--tensor-model-parallel-size 8 \
--num-layers 30 \
--hidden-size 4096 \
--ffn-hidden-size 11008 \
--num-attention-heads 32 \
--max-position-embeddings 2048 \
--tokenizer-type PretrainedFromHF \
--load "${CHECKPOINT}" \
--tokenizer-name-or-path "$VOCAB_FILE" \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 256 \
--max-new-tokens 64 \
--seed 42 \
--deepspeed \
--deepspeed_config ${config_json} \
35
examples/llama/generate_llama_7B_tp2_pp2.sh
Normal file
@ -0,0 +1,35 @@
#!/bin/bash
export TOKENIZERS_PARALLELISM=false

MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=4

DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"

CHECKPOINT="your megatron checkpoint path"
VOCAB_FILE="your vocab path"

python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/inference/inference_llama.py \
--tensor-model-parallel-size 2 \
--pipeline-model-parallel-size 2 \
--num-layers 32 \
--hidden-size 4096 \
--ffn-hidden-size 11008 \
--load "${CHECKPOINT}" \
--num-attention-heads 32 \
--max-position-embeddings 2048 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path "$VOCAB_FILE" \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 256 \
--max-new-tokens 64 \
--seed 42
97
examples/llama/pretrain_llama_13B_zero_8p.sh
Normal file
@ -0,0 +1,97 @@
# This is an example: train llama using TD,
# the number of parameters is not aligned
export LD_LIBRARY_PATH=/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
export HCCL_CONNECT_TIMEOUT=1200
source /path/to/cann/ascend-toolkit/set_env.sh

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DATA=./dataset/llama_text_document
CHECKPOINT=./ckpt

DS_CONFIG=deepspeed_config_13B.json
ZERO_STAGE=3
GLOBAL_BATCH=64
MICRO_BATCH=2

export INF_NAN_MODE_ENABLE=1

cat <<EOT > $DS_CONFIG
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 8,
"hysteresis": 2,
"min_loss_scale": 1
},

"optimizer": {
"type": "Adam"
},

"zero_optimization": {
"stage": $ZERO_STAGE,
"allgather_partitions": true,
"allgather_bucket_size": 1e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1e8,
"contiguous_gradients": true
},

"gradient_accumulation_steps": 4,
"train_batch_size": $GLOBAL_BATCH,
"train_micro_batch_size_per_gpu":$MICRO_BATCH,
"zero_allow_untested_optimizer": true
}
EOT

ds_args=""
ds_args=" --deepspeed ${ds_args}"
ds_args=" --no-pipeline-parallel ${ds_args}"
ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"
ds_args=" --deepspeed-activation-checkpointing ${ds_args}"

deepspeed pretrain_llama.py \
--DDP-impl local \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 40 \
--hidden-size 5120 \
--ffn-hidden-size 13824 \
--num-attention-heads 40 \
--micro-batch-size $MICRO_BATCH \
--global-batch-size $GLOBAL_BATCH \
--seq-length 2048 \
--max-position-embeddings 2048 \
--train-iters 500000 \
--lr-decay-iters 320000 \
--save $CHECKPOINT \
--data-path $DATA \
--tokenizer-name-or-path ./dataset/llama/ \
--tokenizer-not-use-fast \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.00015 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction .01 \
--checkpoint-activations \
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
$ds_args \
--fp16 | tee logs/train_13B.log
@ -21,7 +21,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_llama.py \
--DDP-impl local \
--use-contiguous-buffers-in-ddp \
--use-distributed-optimizer \
--tensor-model-parallel-size 4 \
--pipeline-model-parallel-size 2 \
--num-layers 30 \
@ -37,7 +37,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--tokenizer-name-or-path ./dataset/llama/ \
--tokenizer-name-or-path $TOKENIZER_PATH \
--tokenizer-not-use-fast \
--data-impl mmap \
--split 949,50,1 \
59
examples/llama/pretrain_llama_65B_ptd_32p.sh
Normal file
@ -0,0 +1,59 @@
# This is an example: train llama using PTD.

# The number of parameters is not aligned
export LD_LIBRARY_PATH=/usr/local/lib:/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
export HCCL_CONNECT_TIMEOUT=1200
export COMBINED_ENABLE=1

# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=4
NODE_RANK=0
NPUS_PER_NODE=8
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))

TOKENIZER_PATH=./dataset/llama_tokenizer
DATA_PATH=./dataset/llama_text_document
CHECKPOINT_PATH=./ckpt
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

# Main script
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_llama.py \
--DDP-impl local \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 4 \
--num-layers 80 \
--hidden-size 8192 \
--ffn-hidden-size 22016 \
--num-attention-heads 64 \
--micro-batch-size 1 \
--global-batch-size 256 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--train-iters 50000 \
--lr-decay-iters 320000 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--tokenizer-name-or-path $TOKENIZER_PATH \
--tokenizer-not-use-fast \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.00015 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction .01 \
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--checkpoint-activations \
--initial-loss-scale 524288.0 \
--sequence-parallel \
--mlp-layer-fusion \
--bf16 | tee logs/train.log
95
examples/llama/pretrain_llama_7B_zero_8p.sh
Normal file
@ -0,0 +1,95 @@
# This is an example: train llama using TD,
# the number of parameters is not aligned
export LD_LIBRARY_PATH=/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
export HCCL_CONNECT_TIMEOUT=1200
source /path/to/cann/ascend-toolkit/set_env.sh

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DATA=./dataset/llama_text_document
CHECKPOINT=./ckpt

DS_CONFIG=deepspeed_config_7B.json
ZERO_STAGE=2
GLOBAL_BATCH=64
MICRO_BATCH=8

cat <<EOT > $DS_CONFIG
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 8,
"hysteresis": 2,
"min_loss_scale": 1
},

"optimizer": {
"type": "Adam"
},

"zero_optimization": {
"stage": $ZERO_STAGE,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true
},

"gradient_accumulation_steps": 1,
"train_batch_size": $GLOBAL_BATCH,
"train_micro_batch_size_per_gpu":$MICRO_BATCH,
"zero_allow_untested_optimizer": true
}
EOT

ds_args=""
ds_args=" --deepspeed ${ds_args}"
ds_args=" --no-pipeline-parallel ${ds_args}"
ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"
ds_args=" --deepspeed-activation-checkpointing ${ds_args}"

deepspeed pretrain_llama.py \
--DDP-impl local \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 32 \
--hidden-size 4096 \
--ffn-hidden-size 11008 \
--num-attention-heads 32 \
--micro-batch-size $MICRO_BATCH \
--global-batch-size $GLOBAL_BATCH \
--seq-length 2048 \
--max-position-embeddings 2048 \
--train-iters 500000 \
--lr-decay-iters 320000 \
--save $CHECKPOINT \
--data-path $DATA \
--tokenizer-name-or-path ./dataset/llama/ \
--tokenizer-not-use-fast \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.000015 \
--lr-decay-style cosine \
--min-lr 1.0e-6 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction .01 \
--checkpoint-activations \
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
$ds_args \
--fp16 | tee logs/train_7B.log
119
examples/llama/tune_llama_deepspeed_13B.sh
Normal file
@ -0,0 +1,119 @@
# This is an example: train llama using TD,

# the number of parameters is not aligned
export LD_LIBRARY_PATH=/usr/local/lib:/home/anaconda3/lib:$LD_LIBRARY_PATH
export HCCL_CONNECT_TIMEOUT=1200
export COMBINED_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

TP=1
PP=1

DATA_PATH=<data-path>
LOAD_CHECKPOINT_PATH=<origin-ckpt-path>
SAVE_CHECKPOINT_PATH=<ckpt-path>
TOKENIZER_PATH=<tokenizer-path>

DS_CONFIG=deepspeed_config_13B.json
ZERO_STAGE=2

MICRO_BATCH=16
GRADIENT_ACCUMULATION_STEP=1
GLOBAL_BATCH=$(($MICRO_BATCH * $GRADIENT_ACCUMULATION_STEP * $WORLD_SIZE))
EPOCH=2
TRAIN_ITERS=$((52000 / $GLOBAL_BATCH * $EPOCH))
echo $TRAIN_ITERS
SAVE_INTERVAL=$(($TRAIN_ITERS / 4))
echo $SAVE_INTERVAL

export HCCL_OP_BASE_FFTS_MODE_ENABLE=TRUE

cat <<EOT > $DS_CONFIG
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 8,
"hysteresis": 2,
"min_loss_scale": 1
},

"optimizer": {
"type": "Adam"
},

"zero_optimization": {
"stage": $ZERO_STAGE,
"allgather_partitions": true,
"allgather_bucket_size": 1e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 1e8,
"contiguous_gradients": true
},

"gradient_accumulation_steps": ${GRADIENT_ACCUMULATION_STEP},
"train_batch_size": $GLOBAL_BATCH,
"train_micro_batch_size_per_gpu":$MICRO_BATCH,
"zero_allow_untested_optimizer": true
}
EOT

ds_args=""
ds_args=" --deepspeed ${ds_args}"
ds_args=" --no-pipeline-parallel ${ds_args}"
ds_args=" --deepspeed_config=$DS_CONFIG ${ds_args}"
ds_args=" --zero-stage=$ZERO_STAGE ${ds_args}"
ds_args=" --deepspeed-activation-checkpointing ${ds_args}"

deepspeed pretrain_llama.py \
--DDP-impl local \
--no-contiguous-buffers-in-local-ddp \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--num-layers 40 \
--hidden-size 5120 \
--ffn-hidden-size 13824 \
--num-attention-heads 40 \
--micro-batch-size $MICRO_BATCH \
--global-batch-size $GLOBAL_BATCH \
--seq-length 256 \
--max-position-embeddings 2048 \
--train-iters ${TRAIN_ITERS} \
--lr-decay-iters ${TRAIN_ITERS} \
--save $SAVE_CHECKPOINT_PATH \
--load $LOAD_CHECKPOINT_PATH \
--data-path $DATA_PATH \
--tokenizer-name-or-path $TOKENIZER_PATH \
--tokenizer-not-use-fast \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 1e-6 \
--lr-decay-style cosine \
--min-lr 0 \
--weight-decay 0. \
--clip-grad 1.0 \
--lr-warmup-iters 200 \
--checkpoint-activations \
--log-interval 1 \
--save-interval ${SAVE_INTERVAL} \
--eval-interval 1000 \
--eval-iters 10 \
--use-cpu-initialization \
--lora-target-modules query_key_value dense gate_proj up_proj down_proj \
--lora-r 64 \
--lora-alpha 128 \
--lora-modules-to-save word_embeddings lm_head.lm_head \
--is-instruction-dataset \
$ds_args \
--fp16 | tee logs/train_13B_deepspeed.log
82
examples/llama/tune_llama_ptd_13B.sh
Normal file
@ -0,0 +1,82 @@
# This is an example: train llama using PTD.

# The number of parameters is not aligned
export LD_LIBRARY_PATH=/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
export HCCL_CONNECT_TIMEOUT=1200
export COMBINED_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh

# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6002
NNODES=1
NODE_RANK=0 #1
NPUS_PER_NODE=8
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))

MICRO_BATCH=4
GRADIENT_ACCUMULATION_STEP=4
GLOBAL_BATCH=$(($MICRO_BATCH * $GRADIENT_ACCUMULATION_STEP * $WORLD_SIZE))
EPOCH=5
TRAIN_ITERS=$((52000 / $GLOBAL_BATCH * $EPOCH))
echo $TRAIN_ITERS
SAVE_INTERVAL=$(($TRAIN_ITERS / 4))
echo $SAVE_INTERVAL

TP=4
PP=2

DATA_PATH=<data-path>
LOAD_CHECKPOINT_PATH=<origin-ckpt-path>
SAVE_CHECKPOINT_PATH=<ckpt-path>
TOKENIZER_PATH=<tokenizer-path>

DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

# Main script
python -m torch.distributed.launch ${DISTRIBUTED_ARGS} \
pretrain_llama.py \
--DDP-impl local \
--no-contiguous-buffers-in-local-ddp \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--num-layers 40 \
--hidden-size 5120 \
--ffn-hidden-size 13824 \
--num-attention-heads 40 \
--micro-batch-size ${MICRO_BATCH} \
--global-batch-size ${GLOBAL_BATCH} \
--seq-length 256 \
--max-position-embeddings 2048 \
--train-iters ${TRAIN_ITERS} \
--save ${SAVE_CHECKPOINT_PATH} \
--load ${LOAD_CHECKPOINT_PATH} \
--checkpoint-activations \
--data-path ${DATA_PATH} \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_PATH} \
--tokenizer-not-use-fast \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 1e-6 \
--lr-decay-style cosine \
--lr-decay-iters ${TRAIN_ITERS} \
--min-lr 0 \
--weight-decay 0. \
--clip-grad 1.0 \
--lr-warmup-iters 200 \
--log-interval 1 \
--save-interval ${SAVE_INTERVAL} \
--eval-interval 1000 \
--eval-iters 10 \
--initial-loss-scale 4096.0 \
--seed 1234 \
--zero-stage 2 \
--is-instruction-dataset \
--lora-target-modules query_key_value dense gate_proj up_proj down_proj \
--lora-r 64 \
--lora-alpha 128 \
--lora-modules-to-save word_embeddings lm_head.lm_head \
--fp16 | tee logs/train_13B_megatron.log
@ -22,7 +22,6 @@ DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_llama.py \
--DDP-impl local \
--use-contiguous-buffers-in-ddp \
--tensor-model-parallel-size 4 \
--pipeline-model-parallel-size 2 \
--num-layers 30 \
106
pretrain_baichuan.py
Normal file
@ -0,0 +1,106 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain Baichuan."""

import math
from functools import partial

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

import deepspeed
import deepspeed_npu
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.accelerator.real_accelerator import get_accelerator

from ascendspeed import get_args
from ascendspeed import print_rank_0
from ascendspeed import get_timers
from ascendspeed import get_tokenizer
from ascendspeed import mpu
from ascendspeed.core import parallel_state
from ascendspeed.core.enums import ModelType
from ascendspeed.data.gpt_dataset import build_train_valid_test_datasets
from ascendspeed.model.baichuan_model import BaichuanModel, BaichuanModelPipe
from ascendspeed.training import pretrain
from ascendspeed.utils import get_ltor_masks_and_position_ids
from ascendspeed.utils import average_losses_across_data_parallel_group
from ascendspeed.arguments import core_transformer_config_from_args
from pretrain_llama import (
    data_post_process,
    get_batch,
    get_batch_pipe,
    loss_func,
    forward_step,
    train_valid_test_datasets_provider
)


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('Building Baichuan model ...')
    see_memory_usage(f"Before Building Model ...", force=True)

    args = get_args()
    config = core_transformer_config_from_args(get_args())
    with deepspeed.zero.Init(data_parallel_group=parallel_state.get_data_parallel_group(),
                             remote_device=None if args.remote_device == 'none' else args.remote_device,
                             config_dict_or_path=args.deepspeed_config,
                             enabled=args.zero_stage == 3,
                             mpu=parallel_state):
        if args.deepspeed and not args.no_pipeline_parallel:
            model = BaichuanModelPipe(config, parallel_output=True)
            # This is a hack to give us a reference to get_batch_pipe from within training.py.
            # We need to call model.set_batch_fn after deepspeed.initialize.
            model._megatron_batch_fn = get_batch_pipe

            # Precompute the attention mask and store it in args. This avoids having to
            # pipeline it as an activation during training. The mask is constant, and thus
            # we can reuse it.
            attention_mask = torch.tril(torch.ones(
                (1, args.seq_length, args.seq_length),
                device=get_accelerator().current_device_name())).view(
                1, 1, args.seq_length, args.seq_length)

            # Convert attention mask to binary:
            attention_mask = (attention_mask < 0.5)
            if args.fp16:
                attention_mask = attention_mask.half()
            elif args.bf16:
                attention_mask = attention_mask.bfloat16()

            # Attention mask must be bool.
            args.attn_mask = attention_mask.to(torch.bool)
        else:
            model = BaichuanModel(
                config=config,
                parallel_output=True,
                add_pooler=False,
                pre_process=pre_process,
                post_process=post_process
            )
    see_memory_usage(f"After Building Model", force=True)
    return model


if __name__ == "__main__":
    torch.npu.set_compile_mode(jit_compile=True)
    pretrain(train_valid_test_datasets_provider,
             model_provider,
             ModelType.encoder_or_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'PretrainedFromHF'},
             data_post_process=data_post_process)
@ -15,6 +15,7 @@

"""Pretrain BLOOM"""
from functools import partial
import math

import torch
import torch_npu
@ -32,11 +33,13 @@ from ascendspeed import get_tokenizer
from ascendspeed import mpu
from ascendspeed.core import parallel_state
from ascendspeed.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group
from ascendspeed.model import GPTModel, GPTModelPipe, ModelType
from ascendspeed.model import GPTModel, GPTModelPipe
from ascendspeed.core.enums import ModelType
from ascendspeed.enums import AttnMaskType
from ascendspeed.training import pretrain
from ascendspeed.utils import get_ltor_masks_and_position_ids
from ascendspeed.utils import average_losses_across_data_parallel_group
from ascendspeed.arguments import core_transformer_config_from_args


def model_provider(pre_process=True, post_process=True):
@ -46,6 +49,7 @@ def model_provider(pre_process=True, post_process=True):
see_memory_usage(f"Before Building Model", force=True)

args = get_args()
config = core_transformer_config_from_args(get_args())
with deepspeed.zero.Init(data_parallel_group=parallel_state.get_data_parallel_group(),
remote_device=None if args.remote_device == 'none' else args.remote_device,
config_dict_or_path=args.deepspeed_config,
@ -54,6 +58,7 @@ def model_provider(pre_process=True, post_process=True):
if args.deepspeed:
args.pretrain_causal_attention = True
model = GPTModelPipe(
config,
num_tokentypes=0,
parallel_output=True
)
@ -79,6 +84,7 @@ def model_provider(pre_process=True, post_process=True):
args.attn_mask = attention_mask.to(torch.bool)
else:
model = GPTModel(
config=config,
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
@ -98,10 +104,13 @@ def get_batch(data_iterator):
datatype = torch.int64

# Broadcast data.
if data_iterator is not None:
if hasattr(data_iterator, '__next__'):
data = next(data_iterator)
else:
data = None
if isinstance(data_iterator, list):
return data_iterator.pop(0)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)

# Unpack.
@ -117,6 +126,11 @@
args.reset_attention_mask,
args.eod_mask_loss)

if args.foldx_mode is not None:
if hasattr(data_iterator, 'dummy_iterators'):
for iterator in data_iterator.dummy_iterators:
iterator.append((tokens, labels, loss_mask, attention_mask, position_ids,))

return tokens, labels, loss_mask, attention_mask, position_ids

def data_post_process(data, data_sampler_state_dict):
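
The rendered hunk above has lost its `+`/`-` markers, so old and new lines are interleaved. As a reading aid only, the sketch below shows what the updated fetch branch appears to do: `get_batch` now accepts either an ordinary iterator or a plain list of pre-built batches (used by the foldx pipeline mode), and in the list case the pre-built batch tuple is returned directly, bypassing the broadcast step. The helper name `fetch_batch` is invented for this sketch and does not exist in the repository.

```python
def fetch_batch(data_iterator):
    """Rough sketch of the updated data-fetch branching (not the committed code).

    Returns a pair (payload, is_prebuilt): raw data that still needs to be
    broadcast, or, for list inputs, the pre-built batch itself.
    """
    if data_iterator is None:
        return None, False
    if hasattr(data_iterator, '__next__'):
        # Ordinary iterator: pull the next raw sample and broadcast it later.
        return next(data_iterator), False
    if isinstance(data_iterator, list):
        # Pre-materialized batches (foldx mode): consume them front to back
        # and return them as-is, skipping the broadcast path.
        return data_iterator.pop(0), True
    return None, False
```
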
@ -234,10 +248,11 @@ def forward_step(data_iterator, model):
timers = get_timers()

# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
if args.foldx_mode is None:
timers('batch-generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
if args.foldx_mode is None:
timers('batch-generator').stop()

output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
@ -297,7 +312,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
torch_npu.npu.set_compile_mode(jit_compile=True)

pretrain(train_valid_test_datasets_provider, model_provider,
pretrain(train_valid_test_datasets_provider, model_provider, ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}
)
)
Some files were not shown because too many files have changed in this diff.