mirror of
https://gitee.com/ascend/ModelLink.git
synced 2024-12-05 05:17:40 +08:00
!1246 更新baichuan2-13B性能至1668
Merge pull request !1246 from fengliangjun/master
This commit is contained in:
parent
4109f95dfd
commit
791677c135
10
README.md
10
README.md
@ -306,7 +306,7 @@ ModelLink旨在为华为 [昇腾芯片](https://open.codehub.huawei.com/OpenBaiz
|
||||
<td>13B</td>
|
||||
<td> 1x8</td>
|
||||
<td> BF16 </td>
|
||||
<td> 1550 </td>
|
||||
<td> 1668 </td>
|
||||
<td> 2062 </td>
|
||||
<td> <a href="examples/baichuan2/pretrain_baichuan2_ptd_13B.sh">训练</a> </td>
|
||||
</tr>
|
||||
@ -959,10 +959,12 @@ ModelLink支持张量并行、流水线并行、序列并行、重计算、分
|
||||
| 序列并行 | --sequence-parallel |
|
||||
| 重计算 | --recompute-granularity |
|
||||
| 分布式优化器 | --use-distributed-optimizer |
|
||||
| DDP allreduce 掩盖 | --overlap-grad-reduce |
|
||||
| DDP allreduce 掩盖 | --overlap-grad-reduce |
|
||||
| DDP allgather 掩盖 | --overlap-param-gather |
|
||||
| Flash attention | --use-flash-attn |
|
||||
| Fused rmsnorm | --use-fused-rmsnorm |
|
||||
| Fused swiglu | --use-fused-swiglu |
|
||||
| Fused swiglu | --use-fused-swiglu |
|
||||
| mc2 | --use-mc2 |
|
||||
| Fused rotary <br/>position embedding | --use-fused-rotary-pos-emb |
|
||||
|
||||
|
||||
@ -980,7 +982,9 @@ torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
|
||||
--use-fused-rmsnorm \
|
||||
--use-fused-swiglu \
|
||||
--overlap-grad-reduce \
|
||||
--overlap-param-gather \
|
||||
--use-fused-rotary-pos-emb \
|
||||
--use-mc2 \
|
||||
... \
|
||||
...
|
||||
```
|
||||
|
15
README_en.md
15
README_en.md
@ -307,7 +307,7 @@ For the supported models listed above, we provide training scripts and readme in
|
||||
<td>13B</td>
|
||||
<td> 1x8</td>
|
||||
<td> BF16 </td>
|
||||
<td> 1550 </td>
|
||||
<td> 1668 </td>
|
||||
<td> 2062 </td>
|
||||
<td> <a href="examples/baichuan2/pretrain_baichuan2_ptd_13B.sh">train</a> </td>
|
||||
</tr>
|
||||
@ -966,16 +966,21 @@ ModelLink supports various acceleration algorithms such as tensor parallelism, p
|
||||
| Sequence Parallel | --sequence-parallel |
|
||||
| Recomputation | --recompute-granularity |
|
||||
| Distributed Optimizer | --use-distributed-optimizer |
|
||||
| overlap DDP allreduce | --overlap-grad-reduce |
|
||||
| overlap DDP allreduce | --overlap-grad-reduce |
|
||||
| overlap DDP allgather | --overlap-param-gather |
|
||||
| Flash attention | --use-flash-attn |
|
||||
| Fused rmsnorm | --use-fused-rmsnorm |
|
||||
| Fused swiglu | --use-fused-swiglu |
|
||||
| Fused swiglu | --use-fused-swiglu |
|
||||
| mc2 | --use-mc2 |
|
||||
| Fused rotary <br/>position embedding | --use-fused-rotary-pos-emb |
|
||||
|
||||
|
||||
|
||||
```bash
|
||||
torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
|
||||
--tensor-model-parallel-size ${TP} \
|
||||
--pipeline-model-parallel-size ${PP} \
|
||||
--num-layer-list 1,2,2,2,1 \
|
||||
--sequence-parallel \
|
||||
--recompute-granularity full \
|
||||
--recompute-method block \
|
||||
@ -984,8 +989,10 @@ torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
|
||||
--use-flash-attn \
|
||||
--use-fused-rmsnorm \
|
||||
--use-fused-swiglu \
|
||||
--use-fused-rotary-pos-emb \
|
||||
--overlap-grad-reduce \
|
||||
--overlap-param-gather \
|
||||
--use-fused-rotary-pos-emb \
|
||||
--use-mc2 \
|
||||
... \
|
||||
...
|
||||
```
|
||||
|
@ -416,7 +416,7 @@ Baichuan2-13B 在 **昇腾芯片** 和 **参考芯片** 上的性能对比:
|
||||
|
||||
| 设备 | 模型 | 迭代数 | 样本吞吐 (samples/s) | token吞吐 (tokens/p/s) | 单步迭代时间 (s/step) |
|
||||
|:----:|:-------------------------:|:----:|:----------------:|:--------------------:|:---------------:|
|
||||
| NPUs | Baichuan2-13B | 1000 | - | 1550 | - |
|
||||
| NPUs | Baichuan2-13B | 1000 | - | 1668 | - |
|
||||
| 参考 | Baichuan2-13B | - | - | 2062 | - |
|
||||
|
||||
|
||||
|
@ -415,7 +415,7 @@ The performance of the Baichuan2-13B in **Ascend NPU** and **Reference**:
|
||||
|
||||
| Device | Model | total Iterations | throughput rate (samples/s/p) | throughput rate (tokens/s/p) | single-step time (s/step) |
|
||||
|:----:|:-------------------------:|:----:|:-----------------------------:|:----------------------------:|:-------------------------:|
|
||||
| NPUs | Baichuan2-13B |1000 | - | 1550 | - |
|
||||
| NPUs | Baichuan2-13B |1000 | - | 1668 | - |
|
||||
| Reference | Baichuan2-13B |-| - | 2062 | - |
|
||||
|
||||
|
||||
|
@ -57,6 +57,8 @@ GPT_ARGS="
|
||||
--normalization RMSNorm \
|
||||
--use-fused-rmsnorm \
|
||||
--use-flash-attn \
|
||||
--use-fused-swiglu \
|
||||
--use-mc2 \
|
||||
--swiglu \
|
||||
--no-masked-softmax-fusion \
|
||||
--attention-softmax-in-fp32 \
|
||||
@ -72,7 +74,7 @@ GPT_ARGS="
|
||||
--adam-eps 1.0e-8 \
|
||||
--no-load-optim \
|
||||
--no-load-rng \
|
||||
--fp16
|
||||
--bf16
|
||||
"
|
||||
|
||||
DATA_ARGS="
|
||||
|
@ -312,7 +312,7 @@ python ./tools/preprocess_data.py \
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
# 修改数据集,词表,权重等路径
|
||||
CKPT_SAVE_DIR="./ckpt/internlm-65b/"
|
||||
TOKENIZER_PATH="./model_from_hf/internlm-65b/" #词表路径
|
||||
TOKENIZER_PATH="./model_from_hf/internlm-65b/tokenizer.model" #词表路径
|
||||
DATA_PATH="./dataset/internlm-65b/alpaca_text_document" #数据集路径
|
||||
```
|
||||
|
||||
|
@ -312,7 +312,7 @@ python ./tools/preprocess_data.py \
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
# modify script orign dataset path according to your own dataset path
|
||||
CKPT_SAVE_DIR="./ckpt/internlm-65b/"
|
||||
TOKENIZER_PATH="./model_from_hf/internlm-65b/" #tokenizer path
|
||||
TOKENIZER_PATH="./model_from_hf/internlm-65b/tokenizer.model" #tokenizer path
|
||||
DATA_PATH="./dataset/internlm-65b/alpaca_text_document" #processed dataset
|
||||
```
|
||||
|
||||
|
@ -54,11 +54,12 @@ GPT_ARGS="
|
||||
--no-masked-softmax-fusion \
|
||||
--attention-softmax-in-fp32 \
|
||||
--min-lr 1.0e-7 \
|
||||
--weight-decay 1e-2 \
|
||||
--weight-decay 0.1 \
|
||||
--clip-grad 1.0 \
|
||||
--adam-beta1 0.9 \
|
||||
--initial-loss-scale 4096.0 \
|
||||
--adam-beta2 0.999 \
|
||||
--adam-beta2 0.95 \
|
||||
--adam-eps 1e-5 \
|
||||
--no-gradient-accumulation-fusion \
|
||||
--load ${CKPT_LOAD_DIR} \
|
||||
--no-load-optim \
|
||||
|
@ -169,6 +169,8 @@ def _add_network_size_args(parser):
|
||||
help="Use fused swiglu.")
|
||||
group.add_argument("--use-fused-rotary-pos-emb", action='store_true',
|
||||
help="Use fused rotary-pos-emb.")
|
||||
group.add_argument("--use-mc2", action='store_true',
|
||||
help="Use mc2 for compute-comm overlap in tp.")
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -0,0 +1,14 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
14
modellink/core/tensor_parallel/ascend_turbo/__init__.py
Normal file
14
modellink/core/tensor_parallel/ascend_turbo/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,56 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class AscendConfig:
|
||||
def __init__(self):
|
||||
self.ColumnParallelLinear = None
|
||||
self.RowParallelLinear = None
|
||||
self.group_func = None
|
||||
self.world_size_func = None
|
||||
|
||||
self.sequence_parallel_enabled = True
|
||||
self.all_gather_recomputation = True
|
||||
|
||||
def set_sequence_parallel(self, sequence_parallel):
|
||||
self.sequence_parallel = sequence_parallel
|
||||
|
||||
def set_all_gather_recomputation(self, all_gather_recomputation):
|
||||
self.all_gather_recomputation = all_gather_recomputation
|
||||
|
||||
def set_group(self, group_func):
|
||||
self.group_func = group_func
|
||||
|
||||
def get_group(self):
|
||||
return self.group_func()
|
||||
|
||||
def set_world_size(self, world_size_func):
|
||||
self.world_size_func = world_size_func
|
||||
|
||||
def get_world_size(self):
|
||||
return self.world_size_func()
|
||||
|
||||
def set_column_parallel_linear(self, column_parallel_linear):
|
||||
self.ColumnParallelLinear = column_parallel_linear
|
||||
|
||||
def set_row_parallel_linear(self, row_parallel_linear):
|
||||
self.RowParallelLinear = row_parallel_linear
|
||||
|
||||
def parallel_linear_plugin(self, column_parallel_forward, row_parallel_forward):
|
||||
self.ColumnParallelLinear.forward = column_parallel_forward
|
||||
self.RowParallelLinear.forward = row_parallel_forward
|
||||
|
||||
|
||||
ascend_turbo_cfg = AscendConfig()
|
72
modellink/core/tensor_parallel/ascend_turbo/initialize.py
Normal file
72
modellink/core/tensor_parallel/ascend_turbo/initialize.py
Normal file
@ -0,0 +1,72 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from megatron.core.parallel_state import (
|
||||
get_tensor_model_parallel_group,
|
||||
get_tensor_model_parallel_world_size
|
||||
)
|
||||
|
||||
from .ascend_turbo_cfg import ascend_turbo_cfg
|
||||
from .mc2_linears_seq_parallel import ColumnSeqParallelLinear, RowSeqParallelLinear
|
||||
|
||||
|
||||
def column_parallel_forward(self, input_):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
output = ColumnSeqParallelLinear.apply(
|
||||
input_,
|
||||
self.weight,
|
||||
bias,
|
||||
ascend_turbo_cfg.get_group()
|
||||
)
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
def row_parallel_forward(self, input_):
|
||||
output = RowSeqParallelLinear.apply(
|
||||
input_,
|
||||
self.weight,
|
||||
None,
|
||||
ascend_turbo_cfg.get_group()
|
||||
)
|
||||
|
||||
if not self.skip_bias_add:
|
||||
output = output + self.bias if self.bias is not None else output
|
||||
output_bias = None
|
||||
else:
|
||||
output_bias = self.bias
|
||||
|
||||
return output, output_bias
|
||||
|
||||
|
||||
def initialize_cfg_from_framework():
|
||||
ascend_turbo_cfg.set_group(get_tensor_model_parallel_group)
|
||||
ascend_turbo_cfg.set_world_size(get_tensor_model_parallel_world_size)
|
||||
|
||||
ascend_turbo_cfg.set_column_parallel_linear(ColumnParallelLinear)
|
||||
ascend_turbo_cfg.set_row_parallel_linear(RowParallelLinear)
|
||||
ascend_turbo_cfg.parallel_linear_plugin(column_parallel_forward, row_parallel_forward)
|
||||
|
||||
|
||||
def initialize_cfg_from_args(args):
|
||||
if not args.sequence_parallel or args.tensor_model_parallel_size == 1:
|
||||
return
|
||||
|
||||
ascend_turbo_cfg.set_sequence_parallel(args.sequence_parallel)
|
||||
ascend_turbo_cfg.set_all_gather_recomputation(True)
|
||||
initialize_cfg_from_framework()
|
@ -0,0 +1,180 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from .ascend_turbo_cfg import ascend_turbo_cfg
|
||||
|
||||
|
||||
class ColumnSeqParallelLinear(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, group):
|
||||
ctx.save_for_backward(input_, weight)
|
||||
ctx.use_bias = bias is not None
|
||||
|
||||
rank = torch.distributed.get_rank(group)
|
||||
hcomm_info = None
|
||||
if torch.__version__ > "2.0":
|
||||
global_rank = torch.distributed.get_global_rank(group, rank)
|
||||
hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
|
||||
global_rank
|
||||
)
|
||||
|
||||
else:
|
||||
hcomm_info = group.get_hccl_comm_name(rank)
|
||||
|
||||
x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
|
||||
|
||||
world_size = ascend_turbo_cfg.get_world_size()
|
||||
output, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
|
||||
x,
|
||||
weight.t(),
|
||||
hcomm_info,
|
||||
world_size,
|
||||
bias=bias,
|
||||
gather_index=0,
|
||||
gather_output=(not ascend_turbo_cfg.all_gather_recomputation),
|
||||
)
|
||||
|
||||
output = output.view(
|
||||
int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1]
|
||||
)
|
||||
|
||||
ctx.all_gather_output = all_gather_grad_output
|
||||
ctx.world_size = world_size
|
||||
ctx.group = group
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input_, weight = ctx.saved_tensors
|
||||
|
||||
grad_output_ = grad_output.reshape(
|
||||
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
|
||||
)
|
||||
|
||||
if ascend_turbo_cfg.all_gather_recomputation:
|
||||
dim_size = list(input_.size())
|
||||
dim_size[0] = dim_size[0] * ctx.world_size
|
||||
all_gather_output = torch.empty(
|
||||
dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
requires_grad=False,
|
||||
)
|
||||
all_gather_work = torch.distributed._all_gather_base(
|
||||
all_gather_output, input_.contiguous(), group=ctx.group, async_op=True
|
||||
)
|
||||
else:
|
||||
all_gather_output = ctx.all_gather_output
|
||||
|
||||
grad_input = grad_output_.matmul(weight)
|
||||
grad_input = grad_input.reshape(
|
||||
grad_output.shape[0], grad_output.shape[1], weight.shape[1]
|
||||
)
|
||||
|
||||
sub_grad_input = torch.empty(
|
||||
list(input_.size()), dtype=input_.dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
reduce_scatter_work = torch.distributed._reduce_scatter_base(
|
||||
sub_grad_input, grad_input, group=ctx.group, async_op=True
|
||||
)
|
||||
|
||||
if ascend_turbo_cfg.all_gather_recomputation:
|
||||
all_gather_work.wait()
|
||||
all_gather_output = all_gather_output.reshape(
|
||||
all_gather_output.shape[0] * all_gather_output.shape[1],
|
||||
all_gather_output.shape[2],
|
||||
)
|
||||
|
||||
grad_weight = grad_output_.t().matmul(all_gather_output)
|
||||
|
||||
is_grad_bias_needed = ctx.needs_input_grad[2]
|
||||
if is_grad_bias_needed and ctx.use_bias:
|
||||
grad_bias = (
|
||||
grad_output_.sum(dim=0)
|
||||
if grad_output_.is_contiguous()
|
||||
else grad_output_.t().sum(dim=1)
|
||||
)
|
||||
else:
|
||||
grad_bias = None
|
||||
|
||||
reduce_scatter_work.wait()
|
||||
return sub_grad_input, grad_weight, grad_bias, None
|
||||
|
||||
|
||||
class RowSeqParallelLinear(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, group):
|
||||
ctx.save_for_backward(input_, weight)
|
||||
ctx.use_bias = bias is not None
|
||||
|
||||
rank = torch.distributed.get_rank(group)
|
||||
world_size = ascend_turbo_cfg.get_world_size()
|
||||
hcomm_info = None
|
||||
if torch.__version__ > "2.0":
|
||||
global_rank = torch.distributed.get_global_rank(group, rank)
|
||||
hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
|
||||
global_rank
|
||||
)
|
||||
else:
|
||||
hcomm_info = group.get_hccl_comm_name(rank)
|
||||
|
||||
x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
|
||||
|
||||
output = torch_npu.npu_mm_reduce_scatter_base(
|
||||
x, weight.t(), hcomm_info, world_size, reduce_op="sum", bias=bias
|
||||
)
|
||||
|
||||
ctx.hcomm_info = hcomm_info
|
||||
ctx.world_size = world_size
|
||||
|
||||
output = output.view(
|
||||
int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1]
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input_, weight = ctx.saved_tensors
|
||||
hcomm_info = ctx.hcomm_info
|
||||
world_size = ctx.world_size
|
||||
|
||||
grad_output_ = grad_output.reshape(
|
||||
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
|
||||
)
|
||||
|
||||
grad_input, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
|
||||
grad_output_, weight, hcomm_info, world_size, bias=None, gather_index=0
|
||||
)
|
||||
|
||||
grad_input = grad_input.view_as(input_)
|
||||
|
||||
x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
|
||||
grad_weight = all_gather_grad_output.t().matmul(x)
|
||||
|
||||
is_grad_bias_needed = ctx.needs_input_grad[2]
|
||||
if is_grad_bias_needed and ctx.use_bias:
|
||||
grad_bias = (
|
||||
grad_output.sum(dim=0)
|
||||
if grad_output.is_contiguous()
|
||||
else grad_output.t().sum(dim=1)
|
||||
)
|
||||
else:
|
||||
grad_bias = None
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None
|
@ -1,6 +1,21 @@
|
||||
import time
|
||||
import torch
|
||||
|
||||
import megatron
|
||||
from megatron import get_args
|
||||
from megatron.core import mpu
|
||||
from megatron.arguments import validate_args
|
||||
from megatron.checkpointing import load_args_from_checkpoint
|
||||
from megatron.global_vars import set_global_variables
|
||||
from megatron.initialize import (
|
||||
_initialize_distributed, _set_random_seed,
|
||||
_init_autoresume, _initialize_tp_communicators
|
||||
)
|
||||
|
||||
from modellink.arguments import parse_args_decorator
|
||||
from modellink.core.tensor_parallel.ascend_turbo.initialize import initialize_cfg_from_args
|
||||
from modellink.error_utils import ensure_valid
|
||||
|
||||
|
||||
def _compile_dependencies():
|
||||
device_count = torch.cuda.device_count()
|
||||
@ -12,4 +27,81 @@ def _compile_dependencies():
|
||||
from megatron.core.datasets.utils import compile_helpers
|
||||
compile_helpers()
|
||||
print('>>> done with dataset index builder. Compilation time: {:.3f} '
|
||||
'seconds'.format(time.time() - start_time), flush=True)
|
||||
'seconds'.format(time.time() - start_time), flush=True)
|
||||
|
||||
|
||||
def initialize_megatron(
|
||||
extra_args_provider=None,
|
||||
args_defaults={},
|
||||
ignore_unknown_args=False,
|
||||
allow_no_cuda=False,
|
||||
skip_mpu_initialization=False,
|
||||
):
|
||||
"""Set global variables, initialize distributed, and
|
||||
set autoresume and random seeds.
|
||||
`allow_no_cuda` should not be set unless using megatron for cpu only
|
||||
data processing. In general this arg should not be set unless you know
|
||||
what you are doing.
|
||||
Returns a function to finalize distributed env initialization
|
||||
(optionally, only when args.lazy_mpu_init == True)
|
||||
"""
|
||||
if not allow_no_cuda:
|
||||
# Make sure cuda is available.
|
||||
ensure_valid(torch.cuda.is_available(), "Megatron requires CUDA.")
|
||||
|
||||
# Parse arguments
|
||||
parse_args = parse_args_decorator(megatron.arguments.parse_args)
|
||||
args = parse_args(extra_args_provider, ignore_unknown_args)
|
||||
|
||||
if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
|
||||
ensure_valid(args.load is not None,
|
||||
"--use-checkpoints-args requires --load argument")
|
||||
load_args_from_checkpoint(args)
|
||||
|
||||
validate_args(args, args_defaults)
|
||||
|
||||
# set global args, build tokenizer, and set adlr-autoresume,
|
||||
# tensorboard-writer, and timers.
|
||||
set_global_variables(args)
|
||||
|
||||
# torch.distributed initialization
|
||||
def finish_mpu_init():
|
||||
args = get_args()
|
||||
# Pytorch distributed.
|
||||
_initialize_distributed()
|
||||
|
||||
# Random seeds for reproducibility.
|
||||
if args.rank == 0:
|
||||
print("> setting random seeds to {} ...".format(args.seed))
|
||||
_set_random_seed(args.seed, args.data_parallel_random_init)
|
||||
if args.use_mc2:
|
||||
initialize_cfg_from_args(args)
|
||||
|
||||
if skip_mpu_initialization:
|
||||
return None
|
||||
|
||||
args = get_args()
|
||||
if args.lazy_mpu_init:
|
||||
args.use_cpu_initialization = True
|
||||
# delayed initialization of DDP-related stuff
|
||||
# We only set basic DDP globals
|
||||
mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
|
||||
# and return function for external DDP manager
|
||||
# to call when it has DDP initialized
|
||||
mpu.set_tensor_model_parallel_rank(args.rank)
|
||||
return finish_mpu_init
|
||||
else:
|
||||
# Megatron's MPU is the master. Complete initialization right away.
|
||||
finish_mpu_init()
|
||||
|
||||
# Autoresume.
|
||||
_init_autoresume()
|
||||
|
||||
# Compile dependencies.
|
||||
_compile_dependencies()
|
||||
|
||||
if args.tp_comm_overlap:
|
||||
_initialize_tp_communicators()
|
||||
|
||||
# No continuation function
|
||||
return None
|
||||
|
@ -42,6 +42,7 @@ from ..data import build_pretraining_data_loader
|
||||
from ..tokenizer import build_tokenizer
|
||||
from ..arguments import parse_args_decorator
|
||||
from ..checkpointing import _load_base_checkpoint_wrapper, load_checkpoint_wrapper
|
||||
from ..initialize import initialize_megatron
|
||||
|
||||
|
||||
def exec_patch():
|
||||
@ -99,7 +100,7 @@ def patch_tensor_parallel():
|
||||
megatron.core.tensor_parallel.random._set_cuda_rng_state = _set_cuda_rng_state # default_generators need replace after set_device
|
||||
megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward = _VocabParallelCrossEntropyForward # change masked_target for better performance
|
||||
megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward = vocab_embedding_wrapper(
|
||||
megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward)
|
||||
VocabParallelEmbeddingForward)
|
||||
megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__ = norm_wrapper(
|
||||
megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__)
|
||||
|
||||
@ -169,6 +170,7 @@ def patch_initialize():
|
||||
megatron.initialize._compile_dependencies = _compile_dependencies # remove cuda kernel compile
|
||||
megatron.initialize.set_jit_fusion_options = set_jit_fusion_options # remove cuda jit nvfuser
|
||||
megatron.initialize.parse_args = parse_args_decorator(megatron.initialize.parse_args)
|
||||
megatron.initialize.initialize_megatron = initialize_megatron
|
||||
|
||||
|
||||
def patch_training():
|
||||
|
Loading…
Reference in New Issue
Block a user