mirror of
https://gitee.com/ascend/ModelLink.git
synced 2024-12-02 03:48:24 +08:00
1187 lines
62 KiB
Python
1187 lines
62 KiB
Python
# coding=utf-8
|
||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
"""ascendspeed arguments."""
|
||
|
||
import dataclasses
|
||
import argparse
|
||
import os
|
||
|
||
import torch
|
||
import torch.nn.functional as F
|
||
import deepspeed
|
||
from ascendspeed.core.transformer import TransformerConfig
|
||
from ascendspeed.enums import PositionEmbeddingType
|
||
from ascendspeed.data.parse import ParseDataPaths, ParseDataPathsPath
|
||
|
||
|
||
def parse_args(extra_args_provider=None, defaults=None,
               ignore_unknown_args=False):
    """Parse all AscendSpeed command-line arguments and post-process them.

    Args:
        extra_args_provider: optional callable taking and returning an
            ``argparse.ArgumentParser``, used to register caller-specific
            arguments.
        defaults: optional dict of ``{arg_name: value}`` applied only when
            the corresponding argument was left at ``None`` on the command
            line.  (Default changed from a mutable ``{}`` to ``None`` to
            avoid the shared-mutable-default pitfall; behavior is the same.)
        ignore_unknown_args: when True, unknown CLI flags are silently
            ignored instead of raising.

    Returns:
        The fully validated ``argparse.Namespace``.
    """
    if defaults is None:
        defaults = {}

    parser = argparse.ArgumentParser(description='AscendSpeed Arguments',
                                     allow_abbrev=False)

    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_regularization_args(parser)
    parser = _add_training_args(parser)
    parser = _add_initialization_args(parser)
    parser = _add_learning_rate_args(parser)
    parser = _add_checkpointing_args(parser)
    parser = _add_mixed_precision_args(parser)
    parser = _add_distributed_args(parser)
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_biencoder_args(parser)
    parser = _add_vit_args(parser)
    parser = _add_logging_args(parser)
    parser = _add_zero_args(parser)
    parser = _add_memoryopt_args(parser)
    parser = _add_activation_checkpoint_args(parser)
    parser = _add_distillation_args(parser)
    parser = _add_optimized_pipeline_args(parser)
    parser = _add_manual_layer_allocation(parser)
    parser = _add_lora_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    parser = deepspeed.add_config_arguments(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # helper argument to set deepspeed pipeline parallel or not
    args.ds_pipeline_enabled = not args.no_pipeline_parallel

    # Distributed args (launcher-provided environment variables).
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
    # Tensor model parallel size: cannot exceed the world size.
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size: bounded by ranks left after TP split.
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    # Checks.
    if args.no_pipeline_parallel:
        assert args.pipeline_model_parallel_size == 1, \
            "pipeline_model_parallel_size must be 1 if pipeline parallel is disabled"
    model_parallel_size = args.pipeline_model_parallel_size * \
        args.tensor_model_parallel_size
    # FIX: the original format string had only two '{}' placeholders for
    # three .format() arguments, so the message showed the wrong values.
    assert args.world_size % model_parallel_size == 0, \
        'world size ({}) is not divisible by tensor parallel size ({}) times ' \
        'pipeline parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size,
            args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)

    if args.data_path:
        # FIX: the original assert referenced an undefined name `message`,
        # which would raise NameError instead of a helpful AssertionError.
        assert args.train_weighted_split_paths is None, \
            '--data-path and --train-weighted-split-paths are mutually exclusive'
        setattr(args, "valid_weighted_split_names", None)
        setattr(args, "valid_weighted_split_weights", None)
        setattr(args, "valid_weighted_split_splits", None)

        setattr(args, "test_weighted_split_names", None)
        setattr(args, "test_weighted_split_weights", None)
        setattr(args, "test_weighted_split_splits", None)

    # Deprecated arguments: refuse them if set, then drop the attributes so
    # no later code can accidentally rely on them.
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for {key}:{v} \
                      with {key}:{v2}'.format(key=key, v=defaults[key],
                                              v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
    if args.global_batch_size is None:
        # Default: one micro batch per data-parallel replica.
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)

    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # If we use the distributed optimizer, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
    if args.use_distributed_optimizer:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
        args.use_contiguous_buffers_in_local_ddp = False

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # Consumed tokens.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0
    args.consumed_train_tokens = 0
    args.custom_token_counting = False

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        # FIX: 'learnig' typo in the original assertion message.
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learning rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size

    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads

    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False

    # seq_length and encoder_seq_length are aliases; exactly one may be given.
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    if (args.position_embedding_type == PositionEmbeddingType.absolute or
            args.position_embedding_type == PositionEmbeddingType.alibi):
        assert args.max_position_embeddings is not None
        # FIX: the original guard was inverted (`if not args.seq_length:`),
        # which would compare max_position_embeddings against None.
        if args.seq_length is not None:
            assert args.max_position_embeddings >= args.seq_length
        if args.decoder_seq_length is not None:
            assert args.max_position_embeddings >= args.decoder_seq_length
    else:
        # Rotary (and other) embeddings do not use max_position_embeddings.
        assert args.max_position_embeddings is None

    # FIX: guard against max_position_embeddings being None here (it is None
    # for rotary embeddings); the original unconditional comparison would
    # raise TypeError in that case.
    if args.seq_length is not None and args.max_position_embeddings is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None and \
            args.max_position_embeddings is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'
    torch_major = int(torch.__version__.split('.')[0])
    torch_minor = int(torch.__version__.split('.')[1])
    # Persistent fused layer norm: requires pytorch >= 1.11.
    if torch_major < 1 or (torch_major == 1 and torch_minor < 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print('Persistent fused layer norm kernel is supported from '
                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                  'Defaulting to no_persist_layer_norm=True')
    else:
        args.no_persist_layer_norm = False
    # Features not supported in this fork: force-disabled flags consumed by
    # core_transformer_config_from_args and downstream code.
    args.curriculum_learning_legacy = False
    args.compression_training = False
    args.apply_layernorm_1p = False
    args.overlap_p2p_comm = False
    args.swiglu = False
    args.fp8_e4m3 = False
    args.fp8_hybrid = False
    args.group_query_attention = False
    # AML: prefix every data path with the download location.
    if args.aml_data_download_path is not None:
        args.data_path = [f"{args.aml_data_download_path}/{path}"
                          for path in args.data_path]

    # manually layer distribute
    _get_manual_layer_allocation(args)

    _print_args(args)
    return args
|
||
|
||
|
||
def _print_args(args):
|
||
"""Print arguments."""
|
||
if args.rank == 0:
|
||
print('------------------------ arguments ------------------------',
|
||
flush=True)
|
||
str_list = []
|
||
for arg in vars(args):
|
||
dots = '.' * (48 - len(arg))
|
||
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
|
||
for arg in sorted(str_list, key=lambda x: x.lower()):
|
||
print(arg, flush=True)
|
||
print('-------------------- end of arguments ---------------------',
|
||
flush=True)
|
||
|
||
|
||
def _check_arg_is_not_none(args, arg):
|
||
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
|
||
|
||
|
||
def core_transformer_config_from_args(args):
    """Build a ``TransformerConfig`` from the parsed command-line arguments.

    Every ``TransformerConfig`` field that has a same-named attribute on
    *args* is copied over, after which a handful of fields are translated
    or overridden explicitly.
    """
    # Start from the fields that map 1:1 onto argument names.
    config_kwargs = {
        field.name: getattr(args, field.name)
        for field in dataclasses.fields(TransformerConfig)
        if hasattr(args, field.name)
    }

    # Fields whose argument has a different name or inverted sense.
    config_kwargs['persist_layer_norm'] = not args.no_persist_layer_norm
    config_kwargs['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p
    config_kwargs['deallocate_pipeline_outputs'] = False
    config_kwargs['pipeline_dtype'] = args.params_dtype
    config_kwargs['batch_p2p_comm'] = not args.overlap_p2p_comm

    if args.swiglu:
        # SwiGLU replaces the gelu activation and disables its fusion.
        config_kwargs['activation_func'] = F.silu
        config_kwargs['gated_linear_unit'] = True
        config_kwargs['bias_gelu_fusion'] = False

    if args.init_method_xavier_uniform:
        config_kwargs['init_method'] = torch.nn.init.xavier_uniform_
        config_kwargs['scaled_init_method'] = torch.nn.init.xavier_uniform_

    config_kwargs['fp8'] = args.fp8_e4m3 or args.fp8_hybrid
    config_kwargs['fp8_e4m3'] = args.fp8_e4m3
    # NOTE(review): fp8_margin is assigned the fp8_hybrid *flag*, which looks
    # suspicious for a margin value — confirm against upstream before changing.
    config_kwargs['fp8_margin'] = args.fp8_hybrid
    config_kwargs['num_query_groups'] = (
        args.num_query_groups if args.group_query_attention else None)

    return TransformerConfig(**config_kwargs)
|
||
|
||
|
||
def _add_network_size_args(parser):
    """Register model-architecture ("network size") command-line arguments.

    Fixes help-text typos from the original ('Tansformer', 'efficieny',
    'residula'); all option names and defaults are unchanged.
    """
    group = parser.add_argument_group(title='network size')

    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--num-experts', type=int, nargs='+', default=[1,],
                       help='number of experts list, MoE related.')
    group.add_argument('--mlp-type', type=str, default='standard',
                       help='Only applicable when num-experts > 1, accepts [standard, residual]')
    group.add_argument('--topk', type=int, default=1,
                       help='Sets the k in TopK gating for MoE layers')
    group.add_argument('--expert-interval', type=int, default=2,
                       help='Use experts in every "expert-interval" layers')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--ffn-hidden-size', type=int, default=None,
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--kv-channels', type=int, default=None,
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       ' args.hidden_size // args.num_attention_heads '
                       'if not provided.')
    group.add_argument('--embed-layernorm', action='store_true',
                       help='Use layernorm for embedding.')
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x],
                       choices=list(PositionEmbeddingType), default=PositionEmbeddingType.absolute,
                       help='Define position embedding type ("absolute" | "rotary" | "alibi"). "absolute" by default.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value.'
                       'This is added for computational efficiency reasons.')
    group.add_argument('--pad-vocab-size-to', type=int, default=None,
                       help='Pad the vocab size to this value.'
                       'This value must be greater than the initial size of the tokenizer,'
                       'needs to be divisible by TP size and `make-vocab-size-divisible-by`.')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residual connection '
                       'ordering.')
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAIs GeLU implementation. This option'
                       'should not be used unless for backward compatibility'
                       'reasons.')
    group.add_argument('--onnx-safe', type=bool, required=False,
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter')
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')
    group.add_argument('--mlp-layer-fusion', action='store_true',
                       help='Fuse gate and upprojection in MLP for llama families, '
                       'e.g. llama or internlm')
    return parser
|
||
|
||
|
||
def _add_logging_args(parser):
|
||
group = parser.add_argument_group(title='logging')
|
||
|
||
group.add_argument('--log-params-norm', action='store_true',
|
||
help='If set, calculate and log parameters norm.')
|
||
group.add_argument('--log-num-zeros-in-grad', action='store_true',
|
||
help='If set, calculate and log the number of zeros in gradient.')
|
||
group.add_argument('--timing-log-level', type=int,
|
||
default=0, choices=range(0,3),
|
||
help='Granularity level to measure and report timing. '
|
||
' 0: report only iteration time and make sure timing '
|
||
' does not introduce extra overhead.'
|
||
' 1: report timing for operations that are executed '
|
||
' very limited times (basically once) during '
|
||
' each iteration (such as gradient all-reduce) '
|
||
' 2: report timing for operations that migh be '
|
||
' executed numerous times during each iteration. '
|
||
'Note that setting the level to 1 or 2 might '
|
||
'cause increase in iteration time.')
|
||
group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
|
||
help='If not set, use barrier with level 1 time '
|
||
'measurements. Note that this is up to the user '
|
||
'to make sure calling barrier with their timers '
|
||
'will not result in hangs. This can happen if for '
|
||
'example the user adds a level 1 timer that is not '
|
||
'called by all ranks.',
|
||
dest='barrier_with_L1_time')
|
||
group.add_argument('--timing-log-option', type=str, default='flatten',
|
||
choices=['flatten', 'max', 'minmax', 'all'],
|
||
help='Options for logging timing:'
|
||
' flatten: report elapsed time in one line'
|
||
' max: report the max timing across all ranks'
|
||
' minmax: report min and max timings across all ranks'
|
||
' all: report timings of all ranks.')
|
||
group.add_argument('--tensorboard-log-interval', type=int, default=1,
|
||
help='Report to tensorboard interval.')
|
||
group.add_argument('--tensorboard-queue-size', type=int, default=1000,
|
||
help='Size of the tensorboard queue for pending events '
|
||
'and summaries before one of the ‘add’ calls forces a '
|
||
'flush to disk.')
|
||
group.add_argument('--log-timers-to-tensorboard', action='store_true',
|
||
help='If set, write timers to tensorboard.')
|
||
group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
|
||
help='If set, write batch-size to tensorboard.')
|
||
group.add_argument('--no-log-learnig-rate-to-tensorboard',
|
||
action='store_false',
|
||
help='Disable learning rate logging to tensorboard.',
|
||
dest='log_learning_rate_to_tensorboard')
|
||
group.add_argument('--no-log-loss-scale-to-tensorboard',
|
||
action='store_false',
|
||
help='Disable loss-scale logging to tensorboard.',
|
||
dest='log_loss_scale_to_tensorboard')
|
||
group.add_argument('--log-validation-ppl-to-tensorboard',
|
||
action='store_true',
|
||
help='If set, write validation perplexity to '
|
||
'tensorboard.')
|
||
group.add_argument('--log-optimizer-states-to-tensorboard',
|
||
action='store_true',
|
||
help='If set, write various optimizer states to '
|
||
'tensorboard. This feature may consume extra GPU memory.')
|
||
|
||
return parser
|
||
|
||
|
||
def _add_regularization_args(parser):
|
||
group = parser.add_argument_group(title='regularization')
|
||
|
||
group.add_argument('--attention-dropout', type=float, default=0.1,
|
||
help='Post attention dropout probability.')
|
||
group.add_argument('--hidden-dropout', type=float, default=0.1,
|
||
help='Dropout probability for hidden state transformer.')
|
||
group.add_argument('--weight-decay', type=float, default=0.01,
|
||
help='Weight decay coefficient for L2 regularization.')
|
||
group.add_argument('--clip-grad', type=float, default=1.0,
|
||
help='Gradient clipping based on global L2 norm.')
|
||
group.add_argument('--adam-beta1', type=float, default=0.9,
|
||
help='First coefficient for computing running averages '
|
||
'of gradient and its square')
|
||
group.add_argument('--adam-beta2', type=float, default=0.999,
|
||
help='Second coefficient for computing running averages '
|
||
'of gradient and its square')
|
||
group.add_argument('--adam-eps', type=float, default=1e-08,
|
||
help='Term added to the denominator to improve'
|
||
'numerical stability')
|
||
group.add_argument('--sgd-momentum', type=float, default=0.9,
|
||
help='Momentum factor for sgd')
|
||
|
||
return parser
|
||
|
||
|
||
def _add_training_args(parser):
|
||
group = parser.add_argument_group(title='training')
|
||
|
||
group.add_argument('--micro-batch-size', type=int, default=None,
|
||
help='Batch size per model instance (local batch size). '
|
||
'Global batch size is local batch size times data '
|
||
'parallel size times number of micro batches.')
|
||
group.add_argument('--batch-size', type=int, default=None,
|
||
help='Old batch size parameter, do not use. '
|
||
'Use --micro-batch-size instead')
|
||
group.add_argument('--global-batch-size', type=int, default=None,
|
||
help='Training batch size. If set, it should be a '
|
||
'multiple of micro-batch-size times data-parallel-size. '
|
||
'If this value is None, then '
|
||
'use micro-batch-size * data-parallel-size as the '
|
||
'global batch size. This choice will result in 1 for '
|
||
'number of micro-batches.')
|
||
group.add_argument('--rampup-batch-size', nargs='*', default=None,
|
||
help='Batch size ramp up with the following values:'
|
||
' --rampup-batch-size <start batch size> '
|
||
' <batch size incerement> '
|
||
' <ramp-up samples> '
|
||
'For example:'
|
||
' --rampup-batch-size 16 8 300000 \ '
|
||
' --global-batch-size 1024'
|
||
'will start with global batch size 16 and over '
|
||
' (1024 - 16) / 8 = 126 intervals will increase'
|
||
'the batch size linearly to 1024. In each interval'
|
||
'we will use approximately 300000 / 126 = 2380 samples.')
|
||
group.add_argument('--checkpoint-activations', action='store_true',
|
||
help='Checkpoint activation to allow for training '
|
||
'with larger models, sequences, and batch sizes.')
|
||
group.add_argument('--distribute-checkpointed-activations',
|
||
action='store_true',
|
||
help='If set, distribute checkpointed activations '
|
||
'across model parallel group.')
|
||
group.add_argument('--checkpoint-num-layers', type=int, default=1,
|
||
help='chunk size (number of layers) for checkpointing.')
|
||
group.add_argument('--train-iters', type=int, default=None,
|
||
help='Total number of iterations to train over all '
|
||
'training runs. Note that either train-iters or '
|
||
'train-samples should be provided.')
|
||
group.add_argument('--train-samples', type=int, default=None,
|
||
help='Total number of samples to train over all '
|
||
'training runs. Note that either train-iters or '
|
||
'train-samples should be provided.')
|
||
group.add_argument('--train-tokens', type=int, default=None,
|
||
help='Total number of tokens to train over all '
|
||
'training runs.')
|
||
group.add_argument('--random-ltd',
|
||
action='store_true',
|
||
help='enable random layer token drop')
|
||
group.add_argument('--log-interval', type=int, default=100,
|
||
help='Report loss and timing interval.')
|
||
group.add_argument('--exit-interval', type=int, default=None,
|
||
help='Exit the program after the iteration is divisible '
|
||
'by this value.')
|
||
group.add_argument('--exit-duration-in-mins', type=int, default=None,
|
||
help='Exit the program after this many minutes.')
|
||
group.add_argument('--tensorboard-dir', type=str, default=None,
|
||
help='Write TensorBoard logs to this directory.')
|
||
group.add_argument('--no-masked-softmax-fusion',
|
||
action='store_false',
|
||
help='Disable fusion of query_key_value scaling, '
|
||
'masking, and softmax.',
|
||
dest='masked_softmax_fusion')
|
||
group.add_argument('--no-bias-gelu-fusion', action='store_false',
|
||
help='Disable bias and gelu fusion.',
|
||
dest='bias_gelu_fusion')
|
||
group.add_argument('--no-bias-dropout-fusion', action='store_false',
|
||
help='Disable bias and dropout fusion.',
|
||
dest='bias_dropout_fusion')
|
||
group.add_argument('--disable-moe-token-dropping', action='store_false',
|
||
help='Disable MoE expert token dropping.',
|
||
dest='moe_token_dropping')
|
||
group.add_argument('--moe-train-capacity-factor', type=float, default=1.0,
|
||
help='The capacity of the MoE expert at training time')
|
||
group.add_argument('--moe-eval-capacity-factor', type=float, default=1.0,
|
||
help='The capacity of the MoE expert at eval time.')
|
||
group.add_argument('--moe-min-capacity', type=int, default=4,
|
||
help='The minimum capacity per MoE expert regardless of the capacity_factor.')
|
||
group.add_argument('--moe-loss-coeff', type=float, default=0.1,
|
||
help='Scaling coefficient for adding MoE loss to model loss')
|
||
group.add_argument('--create-moe-param-group', action='store_true',
|
||
help='Create separate groups for MoE params.'
|
||
'This is necessary for techniques like ZeRO.')
|
||
group.add_argument('--optimizer', type=str, default='adam',
|
||
choices=['adam', 'sgd', 'fused_adam'],
|
||
help='Optimizer function')
|
||
group.add_argument('--dataloader-type', type=str, default=None,
|
||
choices=['single', 'cyclic'],
|
||
help='Single pass vs multiple pass data loader')
|
||
group.add_argument('--ds-inference', action='store_true',
|
||
help='DeepSpeed inference engine being used')
|
||
group.add_argument('--cpu-optimizer', action='store_true',
|
||
help='Run optimizer on CPU')
|
||
group.add_argument('--cpu_torch_adam', action='store_true',
|
||
help='Use Torch Adam as optimizer on CPU.')
|
||
group.add_argument('--no-pipeline-parallel', action='store_true',
|
||
help='Disable pipeline parallelism')
|
||
group.add_argument('--use-tutel', action='store_true',
|
||
help='Use Tutel optimization for MoE')
|
||
group.add_argument('--inference', action='store_true',
|
||
help='Very basic inference mode: not allocating optim/lr - requires ZERO_STAGE=0')
|
||
group.add_argument('--use-fused-rotary-pos-emb', action='store_true',
|
||
help='use fused rotary pos emb')
|
||
return parser
|
||
|
||
|
||
def _add_initialization_args(parser):
|
||
group = parser.add_argument_group(title='initialization')
|
||
|
||
group.add_argument('--seed', type=int, default=1234,
|
||
help='Random seed used for python, numpy, '
|
||
'pytorch, and cuda.')
|
||
group.add_argument('--init-method-std', type=float, default=0.02,
|
||
help='Standard deviation of the zero mean normal '
|
||
'distribution used for weight initialization.')
|
||
group.add_argument('--init-method-xavier-uniform', action='store_true',
|
||
help='Enable Xavier uniform parameter initialization')
|
||
|
||
return parser
|
||
|
||
|
||
def _add_learning_rate_args(parser):
|
||
group = parser.add_argument_group(title='learning rate')
|
||
|
||
group.add_argument('--lr', type=float, default=None,
|
||
help='Initial learning rate. Depending on decay style '
|
||
'and initial warmup, the learing rate at each '
|
||
'iteration would be different.')
|
||
group.add_argument('--lr-decay-style', type=str, default='linear',
|
||
choices=['constant', 'linear', 'cosine'],
|
||
help='Learning rate decay function.')
|
||
group.add_argument('--lr-decay-iters', type=int, default=None,
|
||
help='number of iterations to decay learning rate over,'
|
||
' If None defaults to `--train-iters`')
|
||
group.add_argument('--lr-decay-samples', type=int, default=None,
|
||
help='number of samples to decay learning rate over,'
|
||
' If None defaults to `--train-samples`')
|
||
group.add_argument('--lr-decay-tokens', type=int, default=None,
|
||
help='number of tokens to decay learning rate over,'
|
||
' If not None will override iter/sample-based decay')
|
||
group.add_argument('--lr-warmup-fraction', type=float, default=None,
|
||
help='fraction of lr-warmup-(iters/samples) to use '
|
||
'for warmup (as a float)')
|
||
group.add_argument('--lr-warmup-iters', type=int, default=0,
|
||
help='number of iterations to linearly warmup '
|
||
'learning rate over.')
|
||
group.add_argument('--lr-warmup-samples', type=int, default=0,
|
||
help='number of samples to linearly warmup '
|
||
'learning rate over.')
|
||
group.add_argument('--lr-warmup-tokens', type=int, default=None,
|
||
help='number of tokens to linearly warmup '
|
||
'learning rate over.')
|
||
group.add_argument('--warmup', type=int, default=None,
|
||
help='Old lr warmup argument, do not use. Use one of the'
|
||
'--lr-warmup-* arguments above')
|
||
group.add_argument('--min-lr', type=float, default=0.0,
|
||
help='Minumum value for learning rate. The scheduler'
|
||
'clip values below this threshold.')
|
||
group.add_argument('--override-lr-scheduler', action='store_true',
|
||
help='Reset the values of the scheduler (learning rate,'
|
||
'warmup iterations, minimum learning rate, maximum '
|
||
'number of iterations, and decay style from input '
|
||
'arguments and ignore values from checkpoints. Note'
|
||
'that all the above values will be reset.')
|
||
group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
|
||
help='Use checkpoint to set the values of the scheduler '
|
||
'(learning rate, warmup iterations, minimum learning '
|
||
'rate, maximum number of iterations, and decay style '
|
||
'from checkpoint and ignore input arguments.')
|
||
|
||
return parser
|
||
|
||
|
||
def _add_checkpointing_args(parser):
    """Register checkpoint save/load arguments on *parser* and return it."""
    grp = parser.add_argument_group(title='checkpointing')

    grp.add_argument('--save', type=str, default=None,
                     help='Output directory to save checkpoints to.')
    grp.add_argument('--save-interval', type=int, default=None,
                     help='Number of iterations between checkpoint saves.')
    grp.add_argument('--no-save-optim', action='store_true', default=None,
                     help='Do not save current optimizer.')
    grp.add_argument('--no-save-rng', action='store_true', default=None,
                     help='Do not save current rng state.')
    grp.add_argument('--load', type=str, default=None,
                     help='Directory containing a model checkpoint.')
    grp.add_argument('--no-load-optim', action='store_true', default=None,
                     help='Do not load optimizer when loading checkpoint.')
    grp.add_argument('--no-load-rng', action='store_true', default=None,
                     help='Do not load rng state when loading checkpoint.')
    grp.add_argument('--no-load-lr-state', action='store_true',
                     help='Do not load lr state when loading checkpoint.')
    grp.add_argument('--finetune', action='store_true',
                     help=('Load model for finetuning. Do not load optimizer '
                           'or rng state from checkpoint and set iteration to 0. '
                           'Assumed when loading a release checkpoint.'))

    return parser
|
||
|
||
|
||
def _add_mixed_precision_args(parser):
    """Register mixed-precision (fp16/bf16) and loss-scaling arguments.

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with a 'mixed precision' argument group added.
    """
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--bf16', action='store_true',
                       help='Run model in bfloat16 mode.')
    # Fixed missing space in concatenated help text ("dynamicloss").
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, '
                       'dynamic loss scaling is used.')
    group.add_argument('--initial-loss-scale', type=float, default=2**32,
                       help='Initial loss-scale for dynamic loss scaling.')
    group.add_argument('--min-loss-scale', type=float, default=1.0,
                       help='Minimum loss scale for dynamic loss scale.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    group.add_argument('--fp32-residual-connection', action='store_true',
                       help='Move residual connections to fp32.')
    # store_false: scaling is enabled by default; this flag turns it off.
    group.add_argument('--no-query-key-layer-scaling', action='store_false',
                       help='Do not scale Q * K^T by 1 / layer-number.',
                       dest='apply_query_key_layer_scaling')
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
                       help='Run attention masking and softmax in fp32. '
                       'This flag is ignored unless '
                       '--no-query-key-layer-scaling is specified.')
    group.add_argument('--accumulate-allreduce-grads-in-fp32',
                       action='store_true',
                       help='Gradient accumulation and all-reduce in fp32.')
    # Fixed missing space in concatenated help text ("calculationfor").
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                       'for lm head to fp16.')

    return parser
|
||
|
||
|
||
def _add_distributed_args(parser):
    """Register distributed-training arguments (tensor/pipeline/expert
    parallelism, DDP implementation, backend selection).

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with a 'distributed' argument group added.
    """
    group = parser.add_argument_group(title='distributed')

    group.add_argument('--foldx-mode', default=None,
                       choices=['aiao', 'fifo'],
                       help='Choose fold-x pipeline parallelism.')
    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--enable-expert-tensor-parallelism', action='store_true',
                       default=False,
                       help="use tensor parallelism for expert layers in MoE")
    group.add_argument('--sequence-parallel', action='store_true',
                       default=False,
                       help="use sequence parallelism")
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--pipeline-model-parallel-split-rank',
                       type=int, default=None,
                       help='Rank where encoder and decoder should be split.')
    group.add_argument('--moe-expert-parallel-size', type=int, default=1,
                       help='Degree of the MoE expert parallelism.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo', 'ccl'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--DDP-impl', default='local',
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    # store_false flags: contiguous buffers and scatter/gather are ON by
    # default; passing the flag disables them.
    group.add_argument('--no-contiguous-buffers-in-local-ddp',
                       action='store_false', help="If set, don't use "
                       'contiguous buffer in local DDP.',
                       dest='use_contiguous_buffers_in_local_ddp')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Use scatter/gather to optimize communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    # NOTE(review): type=bool means any non-empty string parses as True
    # (e.g. --lazy-mpu-init False is truthy); kept for CLI compatibility.
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
                       help='If set to True, initialize_megatron() '
                       'skips DDP initialization and returns function to '
                       'complete it instead. Also turns on '
                       '--use-cpu-initialization flag. This is for '
                       'external DDP manager.')
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU')
    group.add_argument('--triangle-attn', action='store_true',
                       help="use triangle attention instead of self attention")
    group.add_argument('--use-distributed-optimizer', action='store_true',
                       help='Use distributed optimizer.')
    return parser
|
||
|
||
|
||
def _add_validation_args(parser):
    """Register evaluation-schedule arguments.

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with a 'validation' argument group added.
    """
    group = parser.add_argument_group(title='validation')

    # Fixed missing space in concatenated help text ("evaluationvalidation").
    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation '
                       'validation/test for.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                       'validation set.')

    return parser
|
||
|
||
|
||
def _add_data_args(parser):
    """Register dataset, tokenizer and dataloader arguments.

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with a 'data and dataloader' argument group added.
    """
    group = parser.add_argument_group(title='data and dataloader')

    group.add_argument('--aml-data-download-path', type=str, default=None,
                       help='Path to mounted input dataset')
    # Fixed missing spaces in concatenated help text.
    group.add_argument('--data-path', nargs='*', default=None,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    group.add_argument('--is-instruction-dataset', action='store_true',
                       help='use instruction dataset or not')
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
    group.add_argument('--vocab-extra-ids', type=int, default=0,
                       help='Number of additional vocabulary tokens. '
                       'They are used for span masking in the T5 model')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--encoder-seq-length', type=int, default=None,
                       help='Maximum encoder sequence length to process. '
                       'This should be exclusive of --seq-length')
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help="Maximum decoder sequence length to process.")
    group.add_argument('--retriever-seq-length', type=int, default=256,
                       help='Maximum sequence length for the biencoder model '
                       'for retriever')
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='sample rate for training data. Supposed to be 0 '
                       '< sample_rate < 1')
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help="Dataloader number of workers.")
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer',
                                'PretrainedFromHF'],
                       help='What type of tokenizer to use.')
    group.add_argument("--tokenizer-name-or-path", type=str, default=None,
                       help="Name or path of the huggingface tokenizer.")
    group.add_argument("--tokenizer-not-use-fast", action='store_false',
                       help="HuggingFace tokenizer not use the fast version.")
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    # Fixed typos "posistion" and "maske" in help text.
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention mask after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')
    group.add_argument('--loss-on-targets-only', action='store_true',
                       help='Mask loss on input sequence.')
    group.add_argument('--train-data-exact-num-epochs', type=int, default=None,
                       help='When building the train dataset, force it to be '
                       'an exact number of epochs of the raw data')
    group.add_argument('--return-data-index', action='store_true',
                       help='Return the index of data sample.')
    group.add_argument('--data-efficiency-curriculum-learning', action='store_true',
                       help='Use DeepSpeed data efficiency library curriculum learning feature.')
    group.add_argument('--train-idx-path', type=str, default=None,
                       help='Force to use certain index file.')
    group.add_argument('--train-doc-idx-path', type=str, default=None,
                       help='Force to use certain index file.')
    group.add_argument('--train-sample-idx-path', type=str, default=None,
                       help='Force to use certain index file.')
    group.add_argument('--train-shuffle-idx-path', type=str, default=None,
                       help='Force to use certain index file.')

    # Weighted split groups are parsed by the custom ParseDataPaths /
    # ParseDataPathsPath argparse actions (see ascendspeed.data.parse).
    group.add_argument('--train-weighted-split-paths', nargs='*', default=None,
                       help='Weights, splits and paths to groups of datasets. '
                       'Accepted format: ONE dataset group could be '
                       'submitted in the following form between double quotes: '
                       '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2" '
                       'e.g.: "NAME_ABC: 0.6 0:0.6 A, 0.3 0:1 B, 0.1 0:1 C" '
                       'WEIGHT is used to up and down sample each dataset A,B,C in the group. '
                       'START:END indicates the split portion of the dataset',
                       action=ParseDataPaths)
    group.add_argument('--valid-weighted-split-paths', nargs='*', default=None,
                       help='Weights, splits and paths to groups of datasets. '
                       'Accepted format: one or many dataset groups could be '
                       'submitted in the following form each between double quotes: '
                       '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2" '
                       'e.g.: "NAME_ABC: 0.6 0.6:0.8 A, 0.3 0:1 B, 0.1 0:1 C" '
                       '"NAME_CDE: 0.6 0.6:0.8 C, 0.3 0:1 D, 0.1 0:1 E" '
                       'validation will be run on each of those groups independently',
                       action=ParseDataPaths)
    group.add_argument('--test-weighted-split-paths', nargs='*', default=None,
                       help='Weights, splits and paths to groups of datasets. '
                       'Accepted format: one or many dataset groups could be '
                       'submitted in the following form each between double quotes: '
                       '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2" '
                       'e.g.: "NAME_ABC: 0.6 0.6:0.8 A, 0.3 0:1 B, 0.1 0:1 C" '
                       '"NAME_CDE: 0.6 0.6:0.8 C, 0.3 0:1 D, 0.1 0:1 E" '
                       'test will be run on each of those groups independently',
                       action=ParseDataPaths)

    group.add_argument('--train-weighted-split-paths-path', type=str, action=ParseDataPathsPath, default=None)
    group.add_argument('--valid-weighted-split-paths-path', type=str, action=ParseDataPathsPath, default=None)
    group.add_argument('--test-weighted-split-paths-path', type=str, action=ParseDataPathsPath, default=None)

    return parser
|
||
|
||
|
||
def _add_autoresume_args(parser):
    """Register cluster auto-resume arguments.

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with an 'autoresume' argument group added.
    """
    group = parser.add_argument_group(title='autoresume')

    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    # Fixed missing space in concatenated help text ("autoresumetermination").
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Intervals over which check for autoresume '
                       'termination signal')

    return parser
|
||
|
||
|
||
def _add_biencoder_args(parser):
    """Register biencoder/retriever (ICT, REALM, DPR) arguments.

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with a 'biencoder' argument group added.
    """
    group = parser.add_argument_group(title='biencoder')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and '
                       'REALM (paper default: 128)')
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper'
                       ' default: 128)')
    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                       help='Whether to share the parameters of the query '
                       'and context models or not')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    # Fixed "an BertModel" -> "a BertModel" in help text.
    group.add_argument('--bert-load', type=str, default=None,
                       help='Directory containing a BertModel checkpoint '
                       '(needed to start ICT and REALM)')

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for '
                       'ICT dataset')
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one sentence documents in ICT')
    # Fixed typo "frm" -> "from" in help text.
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence from DPR paper')

    # training
    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                       default=[], help="Which top-k accuracies to report "
                       "(e.g. '1 5 20')")
    group.add_argument('--retriever-score-scaling', action='store_true',
                       help='Whether to scale retriever scores by inverse '
                       'square root of hidden size')

    # faiss index
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')
    group.add_argument('--embedding-path', type=str, default=None,
                       help='Where to save/load Open-Retrieval Embedding'
                       ' data to/from')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='How large of batches to use when doing indexing '
                       'jobs')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches should the indexer '
                       'report progress')
    return parser
|
||
|
||
|
||
def _add_vit_args(parser):
    """Register Vision Transformer (ViT) arguments.

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with a 'vit' argument group added.
    """
    group = parser.add_argument_group(title="vit")

    # Fixed typo "classificaiton" -> "classification" in help text.
    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classification task')
    group.add_argument('--img-dim', type=int, default=224,
                       help='Image size for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
                       help='patch dimension used in vit')

    return parser
|
||
|
||
|
||
def _add_zero_args(parser):
    """Register DeepSpeed ZeRO optimizer configuration arguments.

    (The previous docstring, "Text generate arguments.", was a copy-paste
    error.)

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with a 'ZeRO configurations' argument group added.
    """
    group = parser.add_argument_group('ZeRO configurations', 'configurations')
    # Defaults were floats (1.0 / 0.0) for int-typed options; use ints so the
    # unspecified default has the same type as a parsed value.
    group.add_argument("--zero-stage", type=int, default=1)
    group.add_argument('--zero-reduce-scatter', action='store_true',
                       help='Use reduce scatter if specified')
    # NOTE: the option name keeps the historical misspelling ("contigious")
    # for CLI compatibility; only the help text is corrected.
    group.add_argument('--zero-contigious-gradients', action='store_true',
                       help='Use contiguous memory optimization if specified')
    group.add_argument("--zero-reduce-bucket-size", type=int, default=0)
    group.add_argument("--zero-allgather-bucket-size", type=int, default=0)
    group.add_argument('--remote-device', type=str, default='none', choices=['none', 'cpu', 'nvme'],
                       help='Remote device for ZeRO-3 initialized parameters.')
    group.add_argument('--use-pin-memory', action='store_true',
                       help='Use pinned CPU memory for ZeRO-3 initialized model parameters.')
    return parser
|
||
|
||
|
||
def _add_memoryopt_args(parser):
    """Memory optimization arguments (DeepSpeed scattering/tiling/splitting)."""
    grp = parser.add_argument_group('Memory optimizations', 'configurations')

    grp.add_argument("--scattered-embeddings", action='store_true',
                     help=('Save memory by scattering embedding activations. '
                           'Introduces dropout differences across MP configurations.'))
    grp.add_argument("--split-transformers", action='store_true',
                     help=('Save memory by splitting transformer layers into two parts, '
                           'allowing for more frequent activation checkpoint savings.'))
    grp.add_argument("--memory-centric-tiled-linear", action="store_true",
                     help='Save memory by tiling with deepspeed.zero.TiledLinear.')
    grp.add_argument("--tile-factor", type=int, default=1,
                     help=('Make all linear layers the same size of '
                           '[hidden/tile_factor, hidden/tile_factor]. '
                           'Must be enabled with --memory-centric-tiled-linear. '
                           'Example A: if tile_factor=1, the qkv layer [hidden, 3* hidden] '
                           'would be converted into [1,3] tiles of size [hidden,hidden]. '
                           'Example B: if tile_factor=2, the intermediate layer [4*hidden, hidden] '
                           'will be converted into [8, 2] tiles of size [hidden/2, hidden/2]. '
                           'Default is 1.'))

    return parser
|
||
|
||
|
||
def _add_activation_checkpoint_args(parser):
    """Register (DeepSpeed) activation-checkpointing arguments.

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with an 'Activation Checkpointing' group added.
    """
    group = parser.add_argument_group('Activation Checkpointing',
                                      'Checkpointing Configurations')
    group.add_argument('--deepspeed-activation-checkpointing', action='store_true',
                       help='uses activation checkpointing from deepspeed')
    group.add_argument('--partition-activations', action='store_true',
                       help='partition Activations across GPUs before checkpointing.')
    # NOTE: the option name keeps the historical misspelling ("contigious")
    # for CLI compatibility; only the help text is corrected.
    group.add_argument('--contigious-checkpointing', action='store_true',
                       help='Contiguous memory checkpointing for activations.')
    group.add_argument('--checkpoint-in-cpu', action='store_true',
                       help='Move the activation checkpoints to CPU.')
    group.add_argument('--synchronize-each-layer', action='store_true',
                       help='does a synchronize at the beginning and end of each checkpointed layer.')
    group.add_argument('--profile-backward', action='store_true',
                       help='Enables backward pass profiling for checkpointed layers.')
    group.add_argument('--checkpoint_policy', type=str, default='full', choices=['full', 'block'],
                       help="activation checkpoint policy")
    group.add_argument('--checkpoint_block_layer', type=int, default=25,
                       help="activation checkpoint block layer number")
    return parser
|
||
|
||
|
||
def _add_distillation_args(parser):
    """Register knowledge-distillation (teacher/student) arguments.

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with a 'Knowledge distillation' group added.
    """
    group = parser.add_argument_group('Knowledge distillation',
                                      'Distillation Configurations')

    group.add_argument('--num-layers-teacher', type=int, default=None,
                       help='Number of the teacher transformer layers.')
    group.add_argument('--num-experts-teacher', type=int, nargs='+', default=[1,],
                       help='number of teacher experts list, MoE related.')
    # Fixed typo "Tansformer" -> "Transformer" in help text.
    group.add_argument('--hidden-size-teacher', type=int, default=None,
                       help='Transformer teacher hidden size.')
    group.add_argument('--num-attention-heads-teacher', type=int, default=None,
                       help='Number of teacher transformer attention heads.')

    # Fixed typo "knolwedge" -> "knowledge" in help text.
    group.add_argument('--mos', action='store_true',
                       help='Enable Mixture-of-Students via knowledge distillation.')
    group.add_argument('--kd', action='store_true',
                       help='Enable knowledge distillation.')
    group.add_argument('--kd-alpha-ce', default=1, type=float)
    group.add_argument('--kd-beta-ce', default=1, type=float)
    group.add_argument('--kd-temp', default=1.0, type=float)
    group.add_argument('--reset-iteration', action='store_true',
                       help='Reset the iteration count.')

    group.add_argument('--load-teacher', type=str, default=None,
                       help='Directory containing a teacher model checkpoint.')

    return parser
|
||
|
||
|
||
def _add_optimized_pipeline_args(parser):
    """Register optimized-pipeline (bubble-time reduction) arguments.

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with an 'optimized_pipeline' group added.
    """
    group = parser.add_argument_group(title='optimized_pipeline')

    group.add_argument('--optimized-pipeline', action='store_true',
                       help='Enable optimized pipeline for bubble time reduction.')
    # Fixed typos "shoud" and "seperated" in help text.
    group.add_argument('--manual-mbs', type=str, default='',
                       help='Dynamic micro batches for optimized pipeline. '
                       'The format should be a sequence of numbers separated by '
                       'comma; e.g., 4,4,4,4. Two examples are provided by '
                       '--manual-mbs example-config-1, and '
                       '--manual-mbs example-config-2')

    return parser
|
||
|
||
|
||
def _add_manual_layer_allocation(parser):
    """Register manual pipeline-layer-allocation arguments.

    Args:
        parser: argparse.ArgumentParser to extend.

    Returns:
        The same parser with a 'manual_layer_allocation' group added.
    """
    group = parser.add_argument_group(title='manual_layer_allocation')
    group.add_argument('--use-manual-layer-allocation', action='store_true',
                       help='Enable manually allocated layers for pipeline model parallel.')
    # Fixed typo "seperated" in help text.
    group.add_argument('--manual-layers', type=str, help='a list of number of layers, '
                       'separated by comma; e.g., 4,4,4,4')

    return parser
|
||
|
||
|
||
def _get_manual_layer_allocation(args=None):
    """Parse and validate ``args.manual_layers`` in place.

    Converts the comma-separated ``--manual-layers`` string into a list of
    ints and checks that it provides one entry per pipeline stage. No-op when
    *args* is None or manual allocation is disabled.

    Args:
        args: parsed argument namespace (or None).

    Raises:
        ValueError: if ``--manual-layers`` is missing, or its length does not
            equal ``--pipeline-model-parallel-size``.
    """
    if args is not None and args.use_manual_layer_allocation:
        # Explicit raises instead of asserts so validation survives `python -O`.
        if args.manual_layers is None:
            raise ValueError('--manual-layers must be provided when '
                             '--use-manual-layer-allocation is set.')
        manual_layers = list(map(int, args.manual_layers.split(',')))
        if len(manual_layers) != args.pipeline_model_parallel_size:
            raise ValueError(
                '--manual-layers must have one entry per pipeline stage '
                '(expected {}, got {}).'.format(
                    args.pipeline_model_parallel_size, len(manual_layers)))
        args.manual_layers = manual_layers
|
||
|
||
|
||
def _add_lora_args(parser):
    """Register LoRA fine-tuning arguments on *parser* and return it."""
    grp = parser.add_argument_group(title='lora')

    grp.add_argument('--lora-target-modules', nargs='+', type=str, default=[],
                     help='Lora target modules.')
    grp.add_argument('--lora-load', type=str, default=None,
                     help='Directory containing a lora model checkpoint.')
    grp.add_argument('--lora-r', type=int, default=16, help='Lora r.')
    grp.add_argument('--lora-alpha', type=int, default=32, help='Lora alpha.')
    grp.add_argument('--lora-modules-to-save', nargs='+', type=str, default=None,
                     help='Lora modules to save.')
    grp.add_argument('--lora-register-forward-hook', nargs='+', type=str,
                     default=['word_embeddings', 'input_layernorm'],
                     help='Lora register forward hook.')
    grp.add_argument('--lora-adapter-name', type=str, default='default',
                     help='Lora adapter name.')

    return parser
|