!1246 Update Baichuan2-13B performance to 1668

Merge pull request !1246 from fengliangjun/master
fengliangjun 2024-04-26 01:47:52 +00:00 committed by i-robot
parent 4109f95dfd
commit 791677c135
16 changed files with 462 additions and 16 deletions

View File

@ -306,7 +306,7 @@ ModelLink aims to provide, for Huawei [Ascend chips](https://open.codehub.huawei.com/OpenBaiz
<td>13B</td>
<td> 1x8</td>
<td> BF16 </td>
<td> 1550 </td>
<td> 1668 </td>
<td> 2062 </td>
<td> <a href="examples/baichuan2/pretrain_baichuan2_ptd_13B.sh">train</a> </td>
</tr>
@ -959,10 +959,12 @@ ModelLink supports tensor parallelism, pipeline parallelism, sequence parallelism, recomputation, distributed
| Sequence parallel | --sequence-parallel |
| Recomputation | --recompute-granularity |
| Distributed optimizer | --use-distributed-optimizer |
| Overlap DDP allreduce | --overlap-grad-reduce |
| Overlap DDP allgather | --overlap-param-gather |
| Flash attention | --use-flash-attn |
| Fused rmsnorm | --use-fused-rmsnorm |
| Fused swiglu | --use-fused-swiglu |
| mc2 | --use-mc2 |
| Fused rotary <br/>position embedding | --use-fused-rotary-pos-emb |
@ -980,7 +982,9 @@ torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
--use-fused-rmsnorm \
--use-fused-swiglu \
--overlap-grad-reduce \
--overlap-param-gather \
--use-fused-rotary-pos-emb \
--use-mc2 \
... \
...
```

View File

@ -307,7 +307,7 @@ For the supported models listed above, we provide training scripts and readme in
<td>13B</td>
<td> 1x8</td>
<td> BF16 </td>
<td> 1550 </td>
<td> 1668 </td>
<td> 2062 </td>
<td> <a href="examples/baichuan2/pretrain_baichuan2_ptd_13B.sh">train</a> </td>
</tr>
@ -966,16 +966,21 @@ ModelLink supports various acceleration algorithms such as tensor parallelism, p
| Sequence Parallel | --sequence-parallel |
| Recomputation | --recompute-granularity |
| Distributed Optimizer | --use-distributed-optimizer |
| overlap DDP allreduce | --overlap-grad-reduce |
| overlap DDP allgather | --overlap-param-gather |
| Flash attention | --use-flash-attn |
| Fused rmsnorm | --use-fused-rmsnorm |
| Fused swiglu | --use-fused-swiglu |
| mc2 | --use-mc2 |
| Fused rotary <br/>position embedding | --use-fused-rotary-pos-emb |
```bash
torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--num-layer-list 1,2,2,2,1 \
--sequence-parallel \
--recompute-granularity full \
--recompute-method block \
@ -984,8 +989,10 @@ torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
--use-flash-attn \
--use-fused-rmsnorm \
--use-fused-swiglu \
--use-fused-rotary-pos-emb \
--overlap-grad-reduce \
--overlap-param-gather \
--use-fused-rotary-pos-emb \
--use-mc2 \
... \
...
```
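The `--num-layer-list 1,2,2,2,1` flag in the example above appears to distribute transformer layers unevenly across pipeline stages: one entry per stage, with the entries summing to the total layer count. The helper below is a hypothetical illustration of that assumed semantics, not ModelLink's own validation code.

```python
def check_layer_list(layer_list, pipeline_parallel_size, num_layers):
    """Hypothetical check mirroring the assumed meaning of --num-layer-list."""
    stages = [int(n) for n in layer_list.split(",")]
    if len(stages) != pipeline_parallel_size:
        raise ValueError("need exactly one entry per pipeline stage")
    if sum(stages) != num_layers:
        raise ValueError("entries must sum to --num-layers")
    return stages


print(check_layer_list("1,2,2,2,1", pipeline_parallel_size=5, num_layers=8))  # [1, 2, 2, 2, 1]
```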

View File

@ -416,7 +416,7 @@ Performance comparison of Baichuan2-13B on **Ascend NPUs** and the **reference** hardware:
| Device | Model | Iterations | Sample throughput (samples/s) | Token throughput (tokens/p/s) | Per-step time (s/step) |
|:----:|:-------------------------:|:----:|:----------------:|:--------------------:|:---------------:|
| NPUs | Baichuan2-13B | 1000 | - | 1550 | - |
| NPUs | Baichuan2-13B | 1000 | - | 1668 | - |
| Reference | Baichuan2-13B | - | - | 2062 | - |

View File

@ -415,7 +415,7 @@ The performance of Baichuan2-13B on **Ascend NPU** and **Reference** hardware:
| Device | Model | Total iterations | Sample throughput (samples/s/p) | Token throughput (tokens/s/p) | Single-step time (s/step) |
|:----:|:-------------------------:|:----:|:-----------------------------:|:----------------------------:|:-------------------------:|
| NPUs | Baichuan2-13B | 1000 | - | 1550 | - |
| NPUs | Baichuan2-13B | 1000 | - | 1668 | - |
| Reference | Baichuan2-13B | - | - | 2062 | - |
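The two NPU rows show the token throughput before and after this commit (1550 → 1668 tokens/s/p). For readers who want to sanity-check such figures against their own logs, the conversion from step time to per-device token throughput is simple arithmetic; every input below is hypothetical and chosen only so the result lands near 1668.

```python
seq_len = 4096            # hypothetical sequence length
global_batch_size = 128   # hypothetical sequences per optimizer step
num_devices = 8           # the 1x8 configuration from the table above
step_time_s = 39.3        # hypothetical single-step time

tokens_per_device_per_s = seq_len * global_batch_size / (step_time_s * num_devices)
print(round(tokens_per_device_per_s))   # -> 1668 with these made-up inputs
```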

View File

@ -57,6 +57,8 @@ GPT_ARGS="
--normalization RMSNorm \
--use-fused-rmsnorm \
--use-flash-attn \
--use-fused-swiglu \
--use-mc2 \
--swiglu \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
@ -72,7 +74,7 @@ GPT_ARGS="
--adam-eps 1.0e-8 \
--no-load-optim \
--no-load-rng \
--fp16
--bf16
"
DATA_ARGS="

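The script above also switches from `--fp16` to `--bf16`. As a general PyTorch fact rather than anything ModelLink-specific, bfloat16 keeps float32's exponent range at the cost of mantissa precision, which is why bf16 training can drop the loss scaling that fp16 training needs:

```python
import torch

print(torch.finfo(torch.float16).max)    # 65504.0 — overflows easily, hence loss scaling
print(torch.finfo(torch.bfloat16).max)   # ~3.4e38 — same range as float32
print(torch.finfo(torch.float16).eps)    # ~9.8e-4 — finer mantissa precision
print(torch.finfo(torch.bfloat16).eps)   # ~7.8e-3 — coarser mantissa precision
```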
View File

@ -312,7 +312,7 @@ python ./tools/preprocess_data.py \
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# Modify the dataset, tokenizer, and checkpoint paths as needed
CKPT_SAVE_DIR="./ckpt/internlm-65b/"
TOKENIZER_PATH="./model_from_hf/internlm-65b/" # tokenizer path
TOKENIZER_PATH="./model_from_hf/internlm-65b/tokenizer.model" # tokenizer path
DATA_PATH="./dataset/internlm-65b/alpaca_text_document" # dataset path
```

View File

@ -312,7 +312,7 @@ python ./tools/preprocess_data.py \
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# modify the original dataset path according to your own dataset path
CKPT_SAVE_DIR="./ckpt/internlm-65b/"
TOKENIZER_PATH="./model_from_hf/internlm-65b/" #tokenizer path
TOKENIZER_PATH="./model_from_hf/internlm-65b/tokenizer.model" #tokenizer path
DATA_PATH="./dataset/internlm-65b/alpaca_text_document" #processed dataset
```
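The fix in both READMEs is the same: `TOKENIZER_PATH` must point at the `tokenizer.model` file itself rather than the Hugging Face directory that contains it. A minimal sketch with the `sentencepiece` package, as an illustration only — ModelLink wraps the tokenizer in its own class, and the path assumes the InternLM-65B files have already been downloaded:

```python
import sentencepiece as spm

# Load the SentencePiece model by file path, not by directory.
sp = spm.SentencePieceProcessor(model_file="./model_from_hf/internlm-65b/tokenizer.model")
print(sp.encode("hello world", out_type=int))
```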

View File

@ -54,11 +54,12 @@ GPT_ARGS="
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--min-lr 1.0e-7 \
--weight-decay 1e-2 \
--weight-decay 0.1 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--initial-loss-scale 4096.0 \
--adam-beta2 0.999 \
--adam-beta2 0.95 \
--adam-eps 1e-5 \
--no-gradient-accumulation-fusion \
--load ${CKPT_LOAD_DIR} \
--no-load-optim \
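This hunk tightens the optimizer hyperparameters: weight decay 1e-2 → 0.1 and Adam beta2 0.999 → 0.95. Purely to show what those numbers mean, here is the equivalent configuration in `torch.optim.AdamW`; Megatron/ModelLink builds its own (fused, optionally distributed) optimizer, so this is not the actual code path, and the learning rate is a placeholder:

```python
import torch

model = torch.nn.Linear(16, 16)          # stand-in module
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-5,                             # placeholder; the script sets lr/min-lr separately
    betas=(0.9, 0.95),                   # --adam-beta1 0.9, --adam-beta2 0.95
    eps=1e-5,                            # --adam-eps 1e-5
    weight_decay=0.1,                    # --weight-decay 0.1
)
```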

View File

@ -169,6 +169,8 @@ def _add_network_size_args(parser):
                       help="Use fused swiglu.")
    group.add_argument("--use-fused-rotary-pos-emb", action='store_true',
                       help="Use fused rotary-pos-emb.")
    group.add_argument("--use-mc2", action='store_true',
                       help="Use mc2 for compute-comm overlap in tp.")

    return parser
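The new `--use-mc2` switch is a plain store-true flag. Below is a small stand-in parser (not Megatron's real argument plumbing) showing how it behaves and how the rest of this commit gates the mc2 wiring on sequence parallelism and tensor-parallel size:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use-mc2", action="store_true",
                    help="Use mc2 for compute-comm overlap in tp.")
parser.add_argument("--sequence-parallel", action="store_true")
parser.add_argument("--tensor-model-parallel-size", type=int, default=1)

args = parser.parse_args(["--use-mc2", "--sequence-parallel",
                          "--tensor-model-parallel-size", "8"])

# Mirrors initialize_cfg_from_args further down: mc2 only engages with SP and TP > 1.
if args.use_mc2 and args.sequence_parallel and args.tensor_model_parallel_size > 1:
    print("mc2 compute/communication overlap would be enabled")
```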

View File

@ -0,0 +1,14 @@
# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,14 @@
# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,56 @@
# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class AscendConfig:
    # Holds the tensor-parallel context needed by the mc2 fused linear kernels.
    def __init__(self):
        self.ColumnParallelLinear = None
        self.RowParallelLinear = None
        self.group_func = None
        self.world_size_func = None
        self.sequence_parallel_enabled = True
        self.all_gather_recomputation = True

    def set_sequence_parallel(self, sequence_parallel):
        self.sequence_parallel = sequence_parallel

    def set_all_gather_recomputation(self, all_gather_recomputation):
        self.all_gather_recomputation = all_gather_recomputation

    def set_group(self, group_func):
        # group_func is a callable returning the tensor-model-parallel process group.
        self.group_func = group_func

    def get_group(self):
        return self.group_func()

    def set_world_size(self, world_size_func):
        self.world_size_func = world_size_func

    def get_world_size(self):
        return self.world_size_func()

    def set_column_parallel_linear(self, column_parallel_linear):
        self.ColumnParallelLinear = column_parallel_linear

    def set_row_parallel_linear(self, row_parallel_linear):
        self.RowParallelLinear = row_parallel_linear

    def parallel_linear_plugin(self, column_parallel_forward, row_parallel_forward):
        # Rebind the forward methods of the registered Megatron linear classes.
        self.ColumnParallelLinear.forward = column_parallel_forward
        self.RowParallelLinear.forward = row_parallel_forward


ascend_turbo_cfg = AscendConfig()
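A quick look at how this singleton is meant to be used: callers register callables and classes, and `parallel_linear_plugin` later swaps in the fused forward methods. The snippet below wires it with stand-in lambdas; the real wiring in `initialize.py`, shown next, passes Megatron's tensor-parallel group and world-size getters. It assumes ModelLink is installed so the module path added by this commit resolves.

```python
from modellink.core.tensor_parallel.ascend_turbo.ascend_turbo_cfg import ascend_turbo_cfg

ascend_turbo_cfg.set_world_size(lambda: 8)               # stand-in for get_tensor_model_parallel_world_size
ascend_turbo_cfg.set_group(lambda: "tp-process-group")   # stand-in for get_tensor_model_parallel_group
ascend_turbo_cfg.set_all_gather_recomputation(True)

print(ascend_turbo_cfg.get_world_size())   # 8
print(ascend_turbo_cfg.get_group())        # tp-process-group
```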

View File

@ -0,0 +1,72 @@
# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.parallel_state import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_world_size
)

from .ascend_turbo_cfg import ascend_turbo_cfg
from .mc2_linears_seq_parallel import ColumnSeqParallelLinear, RowSeqParallelLinear


def column_parallel_forward(self, input_):
    # Replacement forward for Megatron's ColumnParallelLinear: fused all-gather + matmul.
    bias = self.bias if not self.skip_bias_add else None
    output = ColumnSeqParallelLinear.apply(
        input_,
        self.weight,
        bias,
        ascend_turbo_cfg.get_group()
    )
    output_bias = self.bias if self.skip_bias_add else None
    return output, output_bias


def row_parallel_forward(self, input_):
    # Replacement forward for Megatron's RowParallelLinear: fused matmul + reduce-scatter.
    output = RowSeqParallelLinear.apply(
        input_,
        self.weight,
        None,
        ascend_turbo_cfg.get_group()
    )
    if not self.skip_bias_add:
        output = output + self.bias if self.bias is not None else output
        output_bias = None
    else:
        output_bias = self.bias
    return output, output_bias


def initialize_cfg_from_framework():
    ascend_turbo_cfg.set_group(get_tensor_model_parallel_group)
    ascend_turbo_cfg.set_world_size(get_tensor_model_parallel_world_size)
    ascend_turbo_cfg.set_column_parallel_linear(ColumnParallelLinear)
    ascend_turbo_cfg.set_row_parallel_linear(RowParallelLinear)
    ascend_turbo_cfg.parallel_linear_plugin(column_parallel_forward, row_parallel_forward)


def initialize_cfg_from_args(args):
    # mc2 fused kernels are only wired in when sequence parallelism and TP > 1 are active.
    if not args.sequence_parallel or args.tensor_model_parallel_size == 1:
        return
    ascend_turbo_cfg.set_sequence_parallel(args.sequence_parallel)
    ascend_turbo_cfg.set_all_gather_recomputation(True)
    initialize_cfg_from_framework()
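`parallel_linear_plugin` works by assigning new functions to the `forward` attribute of Megatron's linear classes, so every existing and future instance picks up the fused path. A toy, self-contained illustration of that class-level rebinding (the classes here are stand-ins, not Megatron's):

```python
class ToyColumnParallelLinear:
    def forward(self, x):
        return f"original forward({x})"


def fused_forward(self, x):
    return f"mc2 fused forward({x})"


# Same pattern as ascend_turbo_cfg.parallel_linear_plugin(...):
ToyColumnParallelLinear.forward = fused_forward

layer = ToyColumnParallelLinear()
print(layer.forward("input"))   # -> mc2 fused forward(input)
```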

View File

@ -0,0 +1,180 @@
# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch_npu

from .ascend_turbo_cfg import ascend_turbo_cfg


class ColumnSeqParallelLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, weight, bias, group):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None

        rank = torch.distributed.get_rank(group)
        hcomm_info = None
        if torch.__version__ > "2.0":
            global_rank = torch.distributed.get_global_rank(group, rank)
            hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
                global_rank
            )
        else:
            hcomm_info = group.get_hccl_comm_name(rank)

        # Flatten [s/tp, b, h] to [s/tp * b, h] for the fused kernel.
        x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])

        world_size = ascend_turbo_cfg.get_world_size()
        # Fused all-gather (over the sequence-parallel dimension) + matmul.
        output, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
            x,
            weight.t(),
            hcomm_info,
            world_size,
            bias=bias,
            gather_index=0,
            gather_output=(not ascend_turbo_cfg.all_gather_recomputation),
        )

        output = output.view(
            int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1]
        )

        ctx.all_gather_output = all_gather_grad_output
        ctx.world_size = world_size
        ctx.group = group
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors

        grad_output_ = grad_output.reshape(
            grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
        )

        if ascend_turbo_cfg.all_gather_recomputation:
            # Re-gather the input asynchronously instead of keeping the gathered copy alive.
            dim_size = list(input_.size())
            dim_size[0] = dim_size[0] * ctx.world_size
            all_gather_output = torch.empty(
                dim_size,
                dtype=input_.dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
            all_gather_work = torch.distributed._all_gather_base(
                all_gather_output, input_.contiguous(), group=ctx.group, async_op=True
            )
        else:
            all_gather_output = ctx.all_gather_output

        grad_input = grad_output_.matmul(weight)
        grad_input = grad_input.reshape(
            grad_output.shape[0], grad_output.shape[1], weight.shape[1]
        )

        sub_grad_input = torch.empty(
            list(input_.size()), dtype=input_.dtype, device=torch.cuda.current_device()
        )
        # Reduce-scatter the input gradient back to the sequence-parallel layout.
        reduce_scatter_work = torch.distributed._reduce_scatter_base(
            sub_grad_input, grad_input, group=ctx.group, async_op=True
        )

        if ascend_turbo_cfg.all_gather_recomputation:
            all_gather_work.wait()

        all_gather_output = all_gather_output.reshape(
            all_gather_output.shape[0] * all_gather_output.shape[1],
            all_gather_output.shape[2],
        )
        grad_weight = grad_output_.t().matmul(all_gather_output)

        is_grad_bias_needed = ctx.needs_input_grad[2]
        if is_grad_bias_needed and ctx.use_bias:
            grad_bias = (
                grad_output_.sum(dim=0)
                if grad_output_.is_contiguous()
                else grad_output_.t().sum(dim=1)
            )
        else:
            grad_bias = None

        reduce_scatter_work.wait()
        return sub_grad_input, grad_weight, grad_bias, None


class RowSeqParallelLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, weight, bias, group):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None

        rank = torch.distributed.get_rank(group)
        world_size = ascend_turbo_cfg.get_world_size()
        hcomm_info = None
        if torch.__version__ > "2.0":
            global_rank = torch.distributed.get_global_rank(group, rank)
            hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
                global_rank
            )
        else:
            hcomm_info = group.get_hccl_comm_name(rank)

        x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])

        # Fused matmul + reduce-scatter (sum) over the tensor-parallel group.
        output = torch_npu.npu_mm_reduce_scatter_base(
            x, weight.t(), hcomm_info, world_size, reduce_op="sum", bias=bias
        )

        ctx.hcomm_info = hcomm_info
        ctx.world_size = world_size

        output = output.view(
            int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1]
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        hcomm_info = ctx.hcomm_info
        world_size = ctx.world_size

        grad_output_ = grad_output.reshape(
            grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
        )

        # Fused all-gather of the output gradient + matmul with the weight shard.
        grad_input, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
            grad_output_, weight, hcomm_info, world_size, bias=None, gather_index=0
        )
        grad_input = grad_input.view_as(input_)

        x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
        grad_weight = all_gather_grad_output.t().matmul(x)

        is_grad_bias_needed = ctx.needs_input_grad[2]
        if is_grad_bias_needed and ctx.use_bias:
            grad_bias = (
                grad_output.sum(dim=0)
                if grad_output.is_contiguous()
                else grad_output.t().sum(dim=1)
            )
        else:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None
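For readers without NPU hardware, the following single-process sketch spells out what the fused kernels are assumed to compute: `npu_all_gather_base_mm` as an all-gather over the sequence-parallel dimension followed by a matmul, and `npu_mm_reduce_scatter_base` as a matmul followed by a summing reduce-scatter. Tensor-parallel ranks are simulated with plain Python lists, so no process group or torch_npu is needed; this is an approximation for intuition, not a drop-in replacement.

```python
import torch

tp, s_chunk, b, h, o = 2, 4, 3, 8, 6     # toy sizes; o = full output width
torch.manual_seed(0)

# ColumnSeqParallelLinear.forward, unfused: all-gather the sequence-sharded input,
# then multiply by this rank's column shard of the weight.
x_shards = [torch.randn(s_chunk, b, h) for _ in range(tp)]      # per-rank [s/tp, b, h]
w_col = [torch.randn(o // tp, h) for _ in range(tp)]            # per-rank [o/tp, h]
full_x = torch.cat(x_shards, dim=0)                             # the "all-gather"
y_rank0 = (full_x.reshape(-1, h) @ w_col[0].t()).view(tp * s_chunk, b, o // tp)

# RowSeqParallelLinear.forward, unfused: each rank multiplies its activation shard by its
# row shard of the weight, the partial products are summed, and each rank keeps its own
# sequence chunk of the sum (the "reduce-scatter").
w_row = [torch.randn(h, o // tp) for _ in range(tp)]            # per-rank [h, o/tp]
y_shards = [torch.randn(tp * s_chunk, b, o // tp) for _ in range(tp)]
partials = [(y_shards[i].reshape(-1, o // tp) @ w_row[i].t()).view(tp * s_chunk, b, h)
            for i in range(tp)]
summed = torch.stack(partials).sum(dim=0)                       # the "reduce"
out_rank0 = summed[:s_chunk]                                    # the "scatter": rank 0's chunk

print(y_rank0.shape, out_rank0.shape)   # torch.Size([8, 3, 3]) torch.Size([4, 3, 8])
```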

View File

@ -1,6 +1,21 @@
import time

import torch

import megatron
from megatron import get_args
from megatron.core import mpu
from megatron.arguments import validate_args
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.initialize import (
    _initialize_distributed, _set_random_seed,
    _init_autoresume, _initialize_tp_communicators
)

from modellink.arguments import parse_args_decorator
from modellink.core.tensor_parallel.ascend_turbo.initialize import initialize_cfg_from_args
from modellink.error_utils import ensure_valid


def _compile_dependencies():
    device_count = torch.cuda.device_count()
@ -12,4 +27,81 @@ def _compile_dependencies():
        from megatron.core.datasets.utils import compile_helpers
        compile_helpers()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time), flush=True)


def initialize_megatron(
    extra_args_provider=None,
    args_defaults={},
    ignore_unknown_args=False,
    allow_no_cuda=False,
    skip_mpu_initialization=False,
):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True)
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        ensure_valid(torch.cuda.is_available(), "Megatron requires CUDA.")

    # Parse arguments
    parse_args = parse_args_decorator(megatron.arguments.parse_args)
    args = parse_args(extra_args_provider, ignore_unknown_args)

    if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
        ensure_valid(args.load is not None,
                     "--use-checkpoint-args requires --load argument")
        load_args_from_checkpoint(args)

    validate_args(args, args_defaults)

    # set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print("> setting random seeds to {} ...".format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

        # Wire in the mc2 fused tensor-parallel linears once distributed state exists.
        if args.use_mc2:
            initialize_cfg_from_args(args)

    if skip_mpu_initialization:
        return None

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        mpu.set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        if args.tp_comm_overlap:
            _initialize_tp_communicators()

        # No continuation function
        return None

View File

@ -42,6 +42,7 @@ from ..data import build_pretraining_data_loader
from ..tokenizer import build_tokenizer
from ..arguments import parse_args_decorator
from ..checkpointing import _load_base_checkpoint_wrapper, load_checkpoint_wrapper
from ..initialize import initialize_megatron


def exec_patch():
@ -99,7 +100,7 @@ def patch_tensor_parallel():
    megatron.core.tensor_parallel.random._set_cuda_rng_state = _set_cuda_rng_state  # default_generators need replace after set_device
    megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward = _VocabParallelCrossEntropyForward  # change masked_target for better performance
    megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward = vocab_embedding_wrapper(
        megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward)
        VocabParallelEmbeddingForward)
    megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__ = norm_wrapper(
        megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__)
@ -169,6 +170,7 @@ def patch_initialize():
    megatron.initialize._compile_dependencies = _compile_dependencies  # remove cuda kernel compile
    megatron.initialize.set_jit_fusion_options = set_jit_fusion_options  # remove cuda jit nvfuser
    megatron.initialize.parse_args = parse_args_decorator(megatron.initialize.parse_args)
    megatron.initialize.initialize_megatron = initialize_megatron


def patch_training():