Format some code.

chenzomi 2023-07-21 01:42:26 +08:00 committed by dilililiwhy
parent 4455b80650
commit 937791fa6d
8 changed files with 101 additions and 47 deletions

View File

@ -1,9 +1,11 @@
## AscendSpeed
Acceleration library for large language models on the Ascend platform, developed on top of Megatron and DeepSpeed. Currently, GPT and LLaMA with the PTD parallel strategy are supported.
### Quick Start
1. Clone the repository to your local server:
```bash
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
@ -12,6 +14,7 @@ mkdir ckpt
```
2. Download the GPT, LLaMA, or Bloom dataset:
```bash
mv dataset AscendSpeed
cd dataset
@ -35,12 +38,14 @@ mkdir -p dataset/bloom_vocab/vocab_file
```
3. Create the conda environment:
```bash
conda create -n py37 python=3.7
conda activate py37
```
4. Install PyTorch and dependencies:
```bash
pip install apex-0.1_ascend_XXXX-cp37-cp37m-linux_XXXX.whl
pip install torch-1.11.0+cpu-cp37-cp37m-linux_XXXXX.whl
@ -50,6 +55,7 @@ pip install -r requirements
```
5. Install DeepSpeed and the adapter:
```bash
pip install deepspeed==0.9.2
git clone https://gitee.com/ascend/DeepSpeed.git  # DeepSpeed adapter v0.9.2, not master!
@ -59,26 +65,31 @@ python setup.py develop
```
6. Start your task:
```bash
sh examples/gpt_task/pretrain_gpt_ptd_8p.sh  # demo GPT with 200M parameters
sh examples/llama_task/pretrain_llama_ptd_8p.sh  # demo LLaMA with 250M parameters
sh examples/gpt/pretrain_gpt_ptd_8p.sh  # demo GPT with 200M parameters
sh examples/llama/pretrain_llama_ptd_16B.sh  # demo LLaMA with 16B parameters
```
# Using HuggingFace Tokenizer
Llama Tokenizer
------
1. Set `--tokenizer-name-or-path` in the training script:
```bash
# examples/llama_task/pretrain_llama_ptd_8p.sh
# examples/llama/pretrain_llama_ptd_8p.sh
--tokenizer-name-or-path ./dataset/llama/ \
--tokenizer-not-use-fast \
```
2. Remove the `--vocab-file` and `--merge-file` arguments.
3. Make sure the `tokenizer_type` of `args_defaults` in `pretrain_llama.py` is `PretrainedFromHF`; a quick way to verify the tokenizer path is sketched below.
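As a sanity check, the directory passed to `--tokenizer-name-or-path` should load directly with the HuggingFace `transformers` library. A minimal sketch, assuming `transformers` is installed and the directory contains a valid LLaMA tokenizer:
```python
from transformers import AutoTokenizer

# use_fast=False mirrors the --tokenizer-not-use-fast flag above.
tokenizer = AutoTokenizer.from_pretrained("./dataset/llama/", use_fast=False)
print(tokenizer("Hello, AscendSpeed!")["input_ids"])
```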
Below is ascendspeed original README:
Below is the original AscendSpeed README:
------
Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specific model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. The table below shows the model configurations along with the achieved FLOPs (both per GPU and aggregate over all GPUs). Note that the FLOPs are measured for end-to-end training, i.e., they include all operations, including data loading, optimization, and even logging.
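As a rough illustration of how layer count and hidden size set the model scale, here is a standard back-of-the-envelope parameter estimate for GPT-style transformers (an approximation for illustration only, not the exact accounting behind the configurations mentioned above):
```python
def approx_gpt_params(num_layers: int, hidden_size: int, vocab_size: int = 51200) -> int:
    """Rough GPT parameter count: 12*l*h^2 covers the attention and MLP
    blocks; vocab_size*h covers the (shared) embedding table."""
    return 12 * num_layers * hidden_size ** 2 + vocab_size * hidden_size

# A 1B-class configuration: 24 layers, hidden size 2048.
print(f"{approx_gpt_params(24, 2048) / 1e9:.2f}B")  # ~1.31B
```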

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""initialization."""
import random
@ -68,7 +67,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults=None,
_set_random_seed(args.seed)
args = get_args()
if args.lazy_mpu_init:
if args.lazy_mpu_init:
args.use_cpu_initialization=True
# delayed initialization of DDP-related stuff
# We only set basic DDP globals
@ -84,7 +83,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults=None,
# Initialize memory buffers.
_initialize_mem_buffs()
# Autoresume.
# Auto resume.
_init_autoresume()
# Compile dependencies.
@ -103,6 +102,7 @@ def _compile_dependencies():
print('>>> done with dataset index builder. Compilation time: {:.3f} '
'seconds'.format(time.time() - start_time), flush=True)
def setup_deepspeed_random_and_activation_checkpointing(args):
'''Optional DeepSpeed Activation Checkpointing features.
Gives access to partition activations, contiguous memory optimizations
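For reference, these features are typically switched on through the `activation_checkpointing` section of the DeepSpeed config. A minimal illustrative stanza (placeholder values; see the DeepSpeed documentation for the full option set):
```python
# Illustrative DeepSpeed config fragment (placeholder values).
ds_config_fragment = {
    "activation_checkpointing": {
        "partition_activations": True,            # split activations across TP ranks
        "contiguous_memory_optimization": True,   # copy checkpoints into a contiguous buffer
        "cpu_checkpointing": False,               # optionally offload checkpoints to CPU
        "number_checkpoints": 4,                  # used by contiguous memory optimization
        "synchronize_checkpoint_boundary": False,
        "profile": False,
    }
}
```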
@ -137,19 +137,26 @@ def setup_deepspeed_random_and_activation_checkpointing(args):
def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
args = get_args()
device_count = get_accelerator().device_count()
if torch.distributed.is_initialized():
# Number of GPUs available on the node hosting the current process
device_count = get_accelerator().device_count()
# If the distributed environment has already been initialized
if torch.distributed.is_initialized():
# On rank 0, log that initialization is already done
if args.rank == 0:
print('torch distributed is already initialized, '
'skipping initialization ...', flush=True)
'skipping initialization ...',
flush=True)
# Get the global rank of the current process
args.rank = torch.distributed.get_rank()
# Get the total number of processes
args.world_size = torch.distributed.get_world_size()
else:
if args.rank == 0:
print('> initializing torch distributed ...', flush=True)
# Manually set the device ids.
# 1. Assign a device to each process and set up the global process group
if device_count > 0:
device = args.rank % device_count
if args.local_rank is not None:
@ -161,9 +168,10 @@ def _initialize_distributed():
get_accelerator().set_device(device) # only do so when device_count > 0
# Call the init process
# Set up the global process group
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
master_ip = os.getenv('MASTER_ADDR', 'localhost')  # IP of the rank-0 process
master_port = os.getenv('MASTER_PORT', '6000')  # port of the rank-0 process
init_method += master_ip + ':' + master_port
if args.deepspeed or args.ds_inference:
@ -173,8 +181,10 @@ def _initialize_distributed():
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators.
# 2. Apply the DP/TP/PP grouping strategy and set up the process subgroups
if device_count > 0:
if mpu.model_parallel_is_initialized():
print('model parallel is already initialized')
@ -183,12 +193,13 @@ def _initialize_distributed():
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size)
# Enable DeepSpeed ZeRO-R activation optimizations
if args.deepspeed and args.deepspeed_activation_checkpointing:
setup_deepspeed_random_and_activation_checkpointing(args)
def _init_autoresume():
"""Set autoresume start time."""
"""Set auto resume start time."""
autoresume = get_adlr_autoresume()
if autoresume:
torch.distributed.barrier()
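Stripped of the accelerator abstraction and argument plumbing, the rendezvous above follows the standard `torch.distributed` TCP pattern. A self-contained sketch (using plain `torch.cuda` in place of `get_accelerator()` for brevity; the environment variable names match the code above):
```python
import os
import torch

def init_distributed(rank: int, world_size: int, backend: str = "nccl"):
    # Each rank picks a local device by modulo, as in _initialize_distributed.
    device_count = torch.cuda.device_count()
    if device_count > 0:
        torch.cuda.set_device(rank % device_count)
    # Rendezvous at the rank-0 process via MASTER_ADDR / MASTER_PORT.
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method='tcp://' + master_ip + ':' + master_port)
```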

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer based language model."""
import torch
@ -29,6 +28,7 @@ from ascendspeed.model.utils import init_method_normal, scaled_init_method_norma
from ascendspeed.mpu.mappings import gather_from_sequence_parallel_region
from ascendspeed.mpu.initialize import get_global_memory_buffer
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
bias=None):
"""LM logits using word embedding weights."""

View File

@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron Module"""
"""AscendSpeed Module"""
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from deepspeed.accelerator import get_accelerator
from ascendspeed import get_args
from ascendspeed import mpu
@ -28,7 +28,6 @@ _HALF_TYPES = (torch.HalfTensor, get_accelerator().HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor,)
def param_is_not_shared(param):
return not hasattr(param, 'shared') or not param.shared

View File

@ -1,4 +1,21 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Extracted from: https://github.com/EleutherAI/gpt-neox
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AscendSpeed Module"""
import torch

View File

@ -84,16 +84,23 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
tensor_model_parallel_size_))
print('> initializing pipeline model parallel with size {}'.format(
pipeline_model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
# Make sure torch.distributed has already been initialized
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
world_size = torch.distributed.get_world_size()  # total number of processes
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
# The product below is the number of GPUs one full model replica occupies; world_size must be divisible by it
ensure_divisibility(world_size,
tensor_model_parallel_size * pipeline_model_parallel_size)
# Derive DP_size from TP_size and PP_size
data_parallel_size = world_size // (tensor_model_parallel_size *
pipeline_model_parallel_size)
# Number of TP, PP, and DP groups
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
@ -104,9 +111,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
# Get the global rank of the current process
rank = torch.distributed.get_rank()
# Build the data-parallel groups.
# Build the data-parallel groups (DP groups).
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
@ -122,7 +130,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
if rank in ranks:
_DATA_PARALLEL_GROUP = group
# Build the model-parallel groups.
# Build the model-parallel groups (MP groups).
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, \
'model parallel group is already initialized'
@ -133,7 +141,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the tensor model-parallel groups.
# Build the tensor model-parallel groups (TP groups).
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
'tensor model parallel group is already initialized'
@ -146,6 +154,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
# Set up the PP groups and the embedding groups
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
@ -156,7 +165,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size,
num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
group = torch.distributed.new_group(ranks)  # create the PP group
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
@ -166,7 +175,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
embedding_ranks = [ranks[0], ranks[-1]]
else:
embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
group = torch.distributed.new_group(embedding_ranks)  # create the embedding group
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
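To make the grouping arithmetic concrete, the rank layout can be reproduced offline. A toy sketch for `world_size=8`, TP=2, PP=2 (hence DP=2), following the standard Megatron-style grouping rules rather than copying the exact loops from the file:
```python
world_size, tp, pp = 8, 2, 2
dp = world_size // (tp * pp)  # 2

# Data-parallel groups: same TP rank within the same PP stage.
dp_groups = []
for i in range(pp):
    start, end = i * world_size // pp, (i + 1) * world_size // pp
    for j in range(tp):
        dp_groups.append(list(range(start + j, end, tp)))

# Tensor model-parallel groups: consecutive ranks.
tp_groups = [list(range(i * tp, (i + 1) * tp))
             for i in range(world_size // tp)]

# Pipeline model-parallel groups: strided by the number of PP groups.
pp_groups = [list(range(i, world_size, world_size // pp))
             for i in range(world_size // pp)]

print(dp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(pp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```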

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain utilities."""
from datetime import datetime
@ -25,6 +24,11 @@ import json
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.compression.compress import init_compression, redundancy_clean
from deepspeed.runtime.data_pipeline.data_routing.helper import convert_to_random_ltd
from ascendspeed import get_args
from ascendspeed import get_timers
from ascendspeed import get_tensorboard_writer
@ -54,11 +58,6 @@ from ascendspeed.schedules import optimized_forward_backward_pipelining
from ascendspeed.utils import report_memory, throughput_calculator, checkpoint_throughput_calculator
from ascendspeed.model.transformer import ParallelTransformerLayer
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.compression.compress import init_compression, redundancy_clean
from deepspeed.runtime.data_pipeline.data_routing.helper import convert_to_random_ltd
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
@ -121,7 +120,7 @@ def pretrain(train_valid_test_dataset_provider,
1) initialize ascendspeed.
2) setup model, optimizer and lr schedule using the model_provider.
3) call train_val_test_data_provider to get train/val/test datasets.
4) train the modle using the forward_step_func.
4) train the model using the forward_step_func.
Arguments:
train_valid_test_dataset_provider: a function that takes the size of
@ -139,7 +138,8 @@ def pretrain(train_valid_test_dataset_provider,
to set already parse arguments.
"""
# Initalize and get arguments, timers, and Tensorboard writer.
# Initialize and get arguments, timers, and TensorBoard writer.
# 1. Initialize the distributed environment
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
@ -177,6 +177,7 @@ def pretrain(train_valid_test_dataset_provider,
args.compression_training = True
# Model, optimizer, and learning rate.
# 2. Model parallelism: define the model architecture and partition the model
timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(
model_provider, teacher=False, data_post_process=data_post_process,
@ -187,6 +188,7 @@ def pretrain(train_valid_test_dataset_provider,
'scheduler are built')
# Data stuff.
# 3. Build the train/val/test datasets
timers('train/valid/test-data-iterators-setup').start()
if args.virtual_pipeline_model_parallel_size is not None:
all_data_iterators = [
@ -234,6 +236,7 @@ def pretrain(train_valid_test_dataset_provider,
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
print_rank_0('training ...')
# 4. Run the training loop
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
@ -354,7 +357,6 @@ def get_model(model_provider_func):
post_process=post_process
)
if not isinstance(model, list):
model = [model]
@ -383,7 +385,6 @@ def get_model(model_provider_func):
device_name = get_accelerator().current_device_name()
print_rank_0(f"model to {device_name}")
model_module.to(device_name)
# Fp16 conversion.
if args.fp16 or args.bf16:
@ -480,13 +481,14 @@ def load_model_weights_only(model_provider_func):
print_datetime('before load checkpoint')
if args.load is not None:
iteration = load_checkpoint(model, optimizer, lr_scheduler, strict=True, load_only_weights=True)
print_datetime('after load checkpoint weights')
return model, optimizer, lr_scheduler
def setup_model_and_optimizer(model_provider_func, teacher=False,
data_post_process=None, build_train_valid_test_datasets_provider=None):
def setup_model_and_optimizer(model_provider_func,
teacher=False,
data_post_process=None,
build_train_valid_test_datasets_provider=None):
"""Setup model and optimizer."""
args = get_args()

View File

@ -1,4 +1,5 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -12,15 +13,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain Llama."""
"""Pretrain Llama"""
import math
from functools import partial
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
import deepspeed
import deepspeed_npu
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.accelerator.real_accelerator import get_accelerator
from ascendspeed import get_args
from ascendspeed import print_rank_0
@ -33,16 +38,12 @@ from ascendspeed.training import pretrain
from ascendspeed.utils import get_ltor_masks_and_position_ids
from ascendspeed.utils import average_losses_across_data_parallel_group
import deepspeed
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.accelerator.real_accelerator import get_accelerator
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building llama model ...')
see_memory_usage(f"Before Building Model", force=True)
print_rank_0('Building llama model ...')
see_memory_usage(f"Before Building Model ...", force=True)
args = get_args()
with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
@ -60,7 +61,8 @@ def model_provider(pre_process=True, post_process=True):
# pipeline it as an activation during training. The mask is constant, and thus
# we can reuse it.
attention_mask = torch.tril(torch.ones(
(1, args.seq_length, args.seq_length), device=get_accelerator().current_device_name())).view(
(1, args.seq_length, args.seq_length),
device=get_accelerator().current_device_name())).view(
1, 1, args.seq_length, args.seq_length)
# Convert attention mask to binary:
@ -72,7 +74,6 @@ def model_provider(pre_process=True, post_process=True):
# Attention mask must be bool.
args.attn_mask = attention_mask.to(torch.bool)
else:
model = LlamaModel(
parallel_output=True,
@ -170,6 +171,7 @@ def get_batch_pipe(data):
def loss_func(loss_mask, output_tensor):
args = get_args()
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
@ -182,6 +184,7 @@ def loss_func(loss_mask, output_tensor):
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
@ -212,6 +215,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
pretrain(train_valid_test_datasets_provider,
model_provider,
forward_step,
args_defaults={'tokenizer_type': 'PretrainedFromHF'},
data_post_process=data_post_process)
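As a footnote on `loss_func` above: the loss is averaged only over unmasked token positions. A self-contained illustration on toy tensors (values are made up):
```python
import torch

# Per-token losses for a batch of 2 sequences of length 4 (made-up values).
losses = torch.tensor([[2.0, 1.0, 3.0, 0.5],
                       [1.5, 2.5, 0.0, 0.0]])
# 1 = real token, 0 = padding; padded positions must not affect the loss.
loss_mask = torch.tensor([[1, 1, 1, 1],
                          [1, 1, 0, 0]]).float()

loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
print(loss)  # (2 + 1 + 3 + 0.5 + 1.5 + 2.5) / 6 = 1.75
```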