diff --git a/README.md b/README.md
index 9cd5e062c..c9512a488 100644
--- a/README.md
+++ b/README.md
@@ -306,7 +306,7 @@ ModelLink旨在为华为 [昇腾芯片](https://open.codehub.huawei.com/OpenBaiz
13B |
1x8 |
BF16 |
- 1550 |
+ 1668 |
2062 |
训练 |
@@ -959,10 +959,12 @@ ModelLink支持张量并行、流水线并行、序列并行、重计算、分
| 序列并行 | --sequence-parallel |
| 重计算 | --recompute-granularity |
| 分布式优化器 | --use-distributed-optimizer |
-| DDP allreduce 掩盖 | --overlap-grad-reduce |
+| DDP allreduce 掩盖 | --overlap-grad-reduce |
+| DDP allgather 掩盖 | --overlap-param-gather |
| Flash attention | --use-flash-attn |
| Fused rmsnorm | --use-fused-rmsnorm |
-| Fused swiglu | --use-fused-swiglu |
+| Fused swiglu | --use-fused-swiglu |
+| mc2 | --use-mc2 |
| Fused rotary
position embedding | --use-fused-rotary-pos-emb |
@@ -980,7 +982,9 @@ torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
--use-fused-rmsnorm \
--use-fused-swiglu \
--overlap-grad-reduce \
+ --overlap-param-gather \
--use-fused-rotary-pos-emb \
+ --use-mc2 \
... \
...
```
diff --git a/README_en.md b/README_en.md
index e5f03e124..7ac059c5e 100644
--- a/README_en.md
+++ b/README_en.md
@@ -307,7 +307,7 @@ For the supported models listed above, we provide training scripts and readme in
13B |
1x8 |
BF16 |
- 1550 |
+ 1668 |
2062 |
train |
@@ -966,16 +966,21 @@ ModelLink supports various acceleration algorithms such as tensor parallelism, p
| Sequence Parallel | --sequence-parallel |
| Recomputation | --recompute-granularity |
| Distributed Optimizer | --use-distributed-optimizer |
-| overlap DDP allreduce | --overlap-grad-reduce |
+| overlap DDP allreduce | --overlap-grad-reduce |
+| overlap DDP allgather | --overlap-param-gather |
| Flash attention | --use-flash-attn |
| Fused rmsnorm | --use-fused-rmsnorm |
-| Fused swiglu | --use-fused-swiglu |
+| Fused swiglu | --use-fused-swiglu |
+| mc2 | --use-mc2 |
| Fused rotary
position embedding | --use-fused-rotary-pos-emb |
+
+
```bash
torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
+ --num-layer-list 1,2,2,2,1 \
--sequence-parallel \
--recompute-granularity full \
--recompute-method block \
@@ -984,8 +989,10 @@ torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
--use-flash-attn \
--use-fused-rmsnorm \
--use-fused-swiglu \
- --use-fused-rotary-pos-emb \
--overlap-grad-reduce \
+ --overlap-param-gather \
+ --use-fused-rotary-pos-emb \
+ --use-mc2 \
... \
...
```
diff --git a/examples/baichuan2/README.md b/examples/baichuan2/README.md
index 92e62d735..c8e492897 100644
--- a/examples/baichuan2/README.md
+++ b/examples/baichuan2/README.md
@@ -416,7 +416,7 @@ Baichuan2-13B 在 **昇腾芯片** 和 **参考芯片** 上的性能对比:
| 设备 | 模型 | 迭代数 | 样本吞吐 (samples/s) | token吞吐 (tokens/p/s) | 单步迭代时间 (s/step) |
|:----:|:-------------------------:|:----:|:----------------:|:--------------------:|:---------------:|
-| NPUs | Baichuan2-13B | 1000 | - | 1550 | - |
+| NPUs | Baichuan2-13B | 1000 | - | 1668 | - |
| 参考 | Baichuan2-13B | - | - | 2062 | - |
diff --git a/examples/baichuan2/README_en.md b/examples/baichuan2/README_en.md
index 18e977462..7126fc72b 100644
--- a/examples/baichuan2/README_en.md
+++ b/examples/baichuan2/README_en.md
@@ -415,7 +415,7 @@ The performance of the Baichuan2-13B in **Ascend NPU** and **Reference**:
| Device | Model | total Iterations | throughput rate (samples/s/p) | throughput rate (tokens/s/p) | single-step time (s/step) |
|:----:|:-------------------------:|:----:|:-----------------------------:|:----------------------------:|:-------------------------:|
-| NPUs | Baichuan2-13B |1000 | - | 1550 | - |
+| NPUs | Baichuan2-13B |1000 | - | 1668 | - |
| Reference | Baichuan2-13B |-| - | 2062 | - |
diff --git a/examples/baichuan2/pretrain_baichuan2_ptd_13B.sh b/examples/baichuan2/pretrain_baichuan2_ptd_13B.sh
index bdf8a1d7e..34d722f5d 100644
--- a/examples/baichuan2/pretrain_baichuan2_ptd_13B.sh
+++ b/examples/baichuan2/pretrain_baichuan2_ptd_13B.sh
@@ -57,6 +57,8 @@ GPT_ARGS="
--normalization RMSNorm \
--use-fused-rmsnorm \
--use-flash-attn \
+ --use-fused-swiglu \
+ --use-mc2 \
--swiglu \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
@@ -72,7 +74,7 @@ GPT_ARGS="
--adam-eps 1.0e-8 \
--no-load-optim \
--no-load-rng \
- --fp16
+ --bf16
"
DATA_ARGS="
diff --git a/examples/intern/README.md b/examples/intern/README.md
index 9b99ac6b6..187665634 100644
--- a/examples/intern/README.md
+++ b/examples/intern/README.md
@@ -312,7 +312,7 @@ python ./tools/preprocess_data.py \
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 修改数据集,词表,权重等路径
CKPT_SAVE_DIR="./ckpt/internlm-65b/"
-TOKENIZER_PATH="./model_from_hf/internlm-65b/" #词表路径
+TOKENIZER_PATH="./model_from_hf/internlm-65b/tokenizer.model" #词表路径
DATA_PATH="./dataset/internlm-65b/alpaca_text_document" #数据集路径
```
diff --git a/examples/intern/README_en.md b/examples/intern/README_en.md
index 56d77e91c..7ee10001c 100644
--- a/examples/intern/README_en.md
+++ b/examples/intern/README_en.md
@@ -312,7 +312,7 @@ python ./tools/preprocess_data.py \
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# modify script orign dataset path according to your own dataset path
CKPT_SAVE_DIR="./ckpt/internlm-65b/"
-TOKENIZER_PATH="./model_from_hf/internlm-65b/" #tokenizer path
+TOKENIZER_PATH="./model_from_hf/internlm-65b/tokenizer.model" #tokenizer path
DATA_PATH="./dataset/internlm-65b/alpaca_text_document" #processed dataset
```
diff --git a/examples/llama2/pretrain_llama2_70b_ptd.sh b/examples/llama2/pretrain_llama2_70b_ptd.sh
index cec37aa0a..800f9b046 100644
--- a/examples/llama2/pretrain_llama2_70b_ptd.sh
+++ b/examples/llama2/pretrain_llama2_70b_ptd.sh
@@ -54,11 +54,12 @@ GPT_ARGS="
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--min-lr 1.0e-7 \
- --weight-decay 1e-2 \
+ --weight-decay 0.1 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--initial-loss-scale 4096.0 \
- --adam-beta2 0.999 \
+ --adam-beta2 0.95 \
+ --adam-eps 1e-5 \
--no-gradient-accumulation-fusion \
--load ${CKPT_LOAD_DIR} \
--no-load-optim \
diff --git a/modellink/arguments.py b/modellink/arguments.py
index 68b4b6b1f..190755a16 100644
--- a/modellink/arguments.py
+++ b/modellink/arguments.py
@@ -169,6 +169,8 @@ def _add_network_size_args(parser):
help="Use fused swiglu.")
group.add_argument("--use-fused-rotary-pos-emb", action='store_true',
help="Use fused rotary-pos-emb.")
+ group.add_argument("--use-mc2", action='store_true',
+ help="Use mc2 for compute-comm overlap in tp.")
return parser
diff --git a/modellink/core/tensor_parallel/__init__.py b/modellink/core/tensor_parallel/__init__.py
index e69de29bb..9a1307aec 100644
--- a/modellink/core/tensor_parallel/__init__.py
+++ b/modellink/core/tensor_parallel/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/modellink/core/tensor_parallel/ascend_turbo/__init__.py b/modellink/core/tensor_parallel/ascend_turbo/__init__.py
new file mode 100644
index 000000000..aaf493892
--- /dev/null
+++ b/modellink/core/tensor_parallel/ascend_turbo/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/modellink/core/tensor_parallel/ascend_turbo/ascend_turbo_cfg.py b/modellink/core/tensor_parallel/ascend_turbo/ascend_turbo_cfg.py
new file mode 100644
index 000000000..a49a1046c
--- /dev/null
+++ b/modellink/core/tensor_parallel/ascend_turbo/ascend_turbo_cfg.py
@@ -0,0 +1,56 @@
+# coding=utf-8
+# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class AscendConfig:
+ def __init__(self):
+ self.ColumnParallelLinear = None
+ self.RowParallelLinear = None
+ self.group_func = None
+ self.world_size_func = None
+
+ self.sequence_parallel_enabled = True
+ self.all_gather_recomputation = True
+
+ def set_sequence_parallel(self, sequence_parallel):
+ self.sequence_parallel = sequence_parallel
+
+ def set_all_gather_recomputation(self, all_gather_recomputation):
+ self.all_gather_recomputation = all_gather_recomputation
+
+ def set_group(self, group_func):
+ self.group_func = group_func
+
+ def get_group(self):
+ return self.group_func()
+
+ def set_world_size(self, world_size_func):
+ self.world_size_func = world_size_func
+
+ def get_world_size(self):
+ return self.world_size_func()
+
+ def set_column_parallel_linear(self, column_parallel_linear):
+ self.ColumnParallelLinear = column_parallel_linear
+
+ def set_row_parallel_linear(self, row_parallel_linear):
+ self.RowParallelLinear = row_parallel_linear
+
+ def parallel_linear_plugin(self, column_parallel_forward, row_parallel_forward):
+ self.ColumnParallelLinear.forward = column_parallel_forward
+ self.RowParallelLinear.forward = row_parallel_forward
+
+
+ascend_turbo_cfg = AscendConfig()
diff --git a/modellink/core/tensor_parallel/ascend_turbo/initialize.py b/modellink/core/tensor_parallel/ascend_turbo/initialize.py
new file mode 100644
index 000000000..f66648c63
--- /dev/null
+++ b/modellink/core/tensor_parallel/ascend_turbo/initialize.py
@@ -0,0 +1,72 @@
+# coding=utf-8
+# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+from megatron.core.parallel_state import (
+ get_tensor_model_parallel_group,
+ get_tensor_model_parallel_world_size
+)
+
+from .ascend_turbo_cfg import ascend_turbo_cfg
+from .mc2_linears_seq_parallel import ColumnSeqParallelLinear, RowSeqParallelLinear
+
+
+def column_parallel_forward(self, input_):
+ bias = self.bias if not self.skip_bias_add else None
+
+ output = ColumnSeqParallelLinear.apply(
+ input_,
+ self.weight,
+ bias,
+ ascend_turbo_cfg.get_group()
+ )
+
+ output_bias = self.bias if self.skip_bias_add else None
+ return output, output_bias
+
+
+def row_parallel_forward(self, input_):
+ output = RowSeqParallelLinear.apply(
+ input_,
+ self.weight,
+ None,
+ ascend_turbo_cfg.get_group()
+ )
+
+ if not self.skip_bias_add:
+ output = output + self.bias if self.bias is not None else output
+ output_bias = None
+ else:
+ output_bias = self.bias
+
+ return output, output_bias
+
+
+def initialize_cfg_from_framework():
+ ascend_turbo_cfg.set_group(get_tensor_model_parallel_group)
+ ascend_turbo_cfg.set_world_size(get_tensor_model_parallel_world_size)
+
+ ascend_turbo_cfg.set_column_parallel_linear(ColumnParallelLinear)
+ ascend_turbo_cfg.set_row_parallel_linear(RowParallelLinear)
+ ascend_turbo_cfg.parallel_linear_plugin(column_parallel_forward, row_parallel_forward)
+
+
+def initialize_cfg_from_args(args):
+ if not args.sequence_parallel or args.tensor_model_parallel_size == 1:
+ return
+
+ ascend_turbo_cfg.set_sequence_parallel(args.sequence_parallel)
+ ascend_turbo_cfg.set_all_gather_recomputation(True)
+ initialize_cfg_from_framework()
diff --git a/modellink/core/tensor_parallel/ascend_turbo/mc2_linears_seq_parallel.py b/modellink/core/tensor_parallel/ascend_turbo/mc2_linears_seq_parallel.py
new file mode 100644
index 000000000..5d5292797
--- /dev/null
+++ b/modellink/core/tensor_parallel/ascend_turbo/mc2_linears_seq_parallel.py
@@ -0,0 +1,180 @@
+# coding=utf-8
+# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch_npu
+from .ascend_turbo_cfg import ascend_turbo_cfg
+
+
+class ColumnSeqParallelLinear(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias, group):
+ ctx.save_for_backward(input_, weight)
+ ctx.use_bias = bias is not None
+
+ rank = torch.distributed.get_rank(group)
+ hcomm_info = None
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
+ global_rank
+ )
+
+ else:
+ hcomm_info = group.get_hccl_comm_name(rank)
+
+ x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
+
+ world_size = ascend_turbo_cfg.get_world_size()
+ output, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
+ x,
+ weight.t(),
+ hcomm_info,
+ world_size,
+ bias=bias,
+ gather_index=0,
+ gather_output=(not ascend_turbo_cfg.all_gather_recomputation),
+ )
+
+ output = output.view(
+ int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1]
+ )
+
+ ctx.all_gather_output = all_gather_grad_output
+ ctx.world_size = world_size
+ ctx.group = group
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight = ctx.saved_tensors
+
+ grad_output_ = grad_output.reshape(
+ grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+ )
+
+ if ascend_turbo_cfg.all_gather_recomputation:
+ dim_size = list(input_.size())
+ dim_size[0] = dim_size[0] * ctx.world_size
+ all_gather_output = torch.empty(
+ dim_size,
+ dtype=input_.dtype,
+ device=torch.cuda.current_device(),
+ requires_grad=False,
+ )
+ all_gather_work = torch.distributed._all_gather_base(
+ all_gather_output, input_.contiguous(), group=ctx.group, async_op=True
+ )
+ else:
+ all_gather_output = ctx.all_gather_output
+
+ grad_input = grad_output_.matmul(weight)
+ grad_input = grad_input.reshape(
+ grad_output.shape[0], grad_output.shape[1], weight.shape[1]
+ )
+
+ sub_grad_input = torch.empty(
+ list(input_.size()), dtype=input_.dtype, device=torch.cuda.current_device()
+ )
+ reduce_scatter_work = torch.distributed._reduce_scatter_base(
+ sub_grad_input, grad_input, group=ctx.group, async_op=True
+ )
+
+ if ascend_turbo_cfg.all_gather_recomputation:
+ all_gather_work.wait()
+ all_gather_output = all_gather_output.reshape(
+ all_gather_output.shape[0] * all_gather_output.shape[1],
+ all_gather_output.shape[2],
+ )
+
+ grad_weight = grad_output_.t().matmul(all_gather_output)
+
+ is_grad_bias_needed = ctx.needs_input_grad[2]
+ if is_grad_bias_needed and ctx.use_bias:
+ grad_bias = (
+ grad_output_.sum(dim=0)
+ if grad_output_.is_contiguous()
+ else grad_output_.t().sum(dim=1)
+ )
+ else:
+ grad_bias = None
+
+ reduce_scatter_work.wait()
+ return sub_grad_input, grad_weight, grad_bias, None
+
+
+class RowSeqParallelLinear(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input_, weight, bias, group):
+ ctx.save_for_backward(input_, weight)
+ ctx.use_bias = bias is not None
+
+ rank = torch.distributed.get_rank(group)
+ world_size = ascend_turbo_cfg.get_world_size()
+ hcomm_info = None
+ if torch.__version__ > "2.0":
+ global_rank = torch.distributed.get_global_rank(group, rank)
+ hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
+ global_rank
+ )
+ else:
+ hcomm_info = group.get_hccl_comm_name(rank)
+
+ x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
+
+ output = torch_npu.npu_mm_reduce_scatter_base(
+ x, weight.t(), hcomm_info, world_size, reduce_op="sum", bias=bias
+ )
+
+ ctx.hcomm_info = hcomm_info
+ ctx.world_size = world_size
+
+ output = output.view(
+ int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1]
+ )
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight = ctx.saved_tensors
+ hcomm_info = ctx.hcomm_info
+ world_size = ctx.world_size
+
+ grad_output_ = grad_output.reshape(
+ grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
+ )
+
+ grad_input, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
+ grad_output_, weight, hcomm_info, world_size, bias=None, gather_index=0
+ )
+
+ grad_input = grad_input.view_as(input_)
+
+ x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
+ grad_weight = all_gather_grad_output.t().matmul(x)
+
+ is_grad_bias_needed = ctx.needs_input_grad[2]
+ if is_grad_bias_needed and ctx.use_bias:
+ grad_bias = (
+ grad_output.sum(dim=0)
+ if grad_output.is_contiguous()
+ else grad_output.t().sum(dim=1)
+ )
+ else:
+ grad_bias = None
+
+ return grad_input, grad_weight, grad_bias, None
diff --git a/modellink/initialize.py b/modellink/initialize.py
index 938a741ef..56ce63025 100644
--- a/modellink/initialize.py
+++ b/modellink/initialize.py
@@ -1,6 +1,21 @@
import time
import torch
+import megatron
+from megatron import get_args
+from megatron.core import mpu
+from megatron.arguments import validate_args
+from megatron.checkpointing import load_args_from_checkpoint
+from megatron.global_vars import set_global_variables
+from megatron.initialize import (
+ _initialize_distributed, _set_random_seed,
+ _init_autoresume, _initialize_tp_communicators
+)
+
+from modellink.arguments import parse_args_decorator
+from modellink.core.tensor_parallel.ascend_turbo.initialize import initialize_cfg_from_args
+from modellink.error_utils import ensure_valid
+
def _compile_dependencies():
device_count = torch.cuda.device_count()
@@ -12,4 +27,81 @@ def _compile_dependencies():
from megatron.core.datasets.utils import compile_helpers
compile_helpers()
print('>>> done with dataset index builder. Compilation time: {:.3f} '
- 'seconds'.format(time.time() - start_time), flush=True)
\ No newline at end of file
+ 'seconds'.format(time.time() - start_time), flush=True)
+
+
+def initialize_megatron(
+ extra_args_provider=None,
+ args_defaults={},
+ ignore_unknown_args=False,
+ allow_no_cuda=False,
+ skip_mpu_initialization=False,
+):
+ """Set global variables, initialize distributed, and
+ set autoresume and random seeds.
+ `allow_no_cuda` should not be set unless using megatron for cpu only
+ data processing. In general this arg should not be set unless you know
+ what you are doing.
+ Returns a function to finalize distributed env initialization
+ (optionally, only when args.lazy_mpu_init == True)
+ """
+ if not allow_no_cuda:
+ # Make sure cuda is available.
+ ensure_valid(torch.cuda.is_available(), "Megatron requires CUDA.")
+
+ # Parse arguments
+ parse_args = parse_args_decorator(megatron.arguments.parse_args)
+ args = parse_args(extra_args_provider, ignore_unknown_args)
+
+ if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
+ ensure_valid(args.load is not None,
+ "--use-checkpoints-args requires --load argument")
+ load_args_from_checkpoint(args)
+
+ validate_args(args, args_defaults)
+
+ # set global args, build tokenizer, and set adlr-autoresume,
+ # tensorboard-writer, and timers.
+ set_global_variables(args)
+
+ # torch.distributed initialization
+ def finish_mpu_init():
+ args = get_args()
+ # Pytorch distributed.
+ _initialize_distributed()
+
+ # Random seeds for reproducibility.
+ if args.rank == 0:
+ print("> setting random seeds to {} ...".format(args.seed))
+ _set_random_seed(args.seed, args.data_parallel_random_init)
+ if args.use_mc2:
+ initialize_cfg_from_args(args)
+
+ if skip_mpu_initialization:
+ return None
+
+ args = get_args()
+ if args.lazy_mpu_init:
+ args.use_cpu_initialization = True
+ # delayed initialization of DDP-related stuff
+ # We only set basic DDP globals
+ mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
+ # and return function for external DDP manager
+ # to call when it has DDP initialized
+ mpu.set_tensor_model_parallel_rank(args.rank)
+ return finish_mpu_init
+ else:
+ # Megatron's MPU is the master. Complete initialization right away.
+ finish_mpu_init()
+
+ # Autoresume.
+ _init_autoresume()
+
+ # Compile dependencies.
+ _compile_dependencies()
+
+ if args.tp_comm_overlap:
+ _initialize_tp_communicators()
+
+ # No continuation function
+ return None
diff --git a/modellink/patchs/megatron_patch.py b/modellink/patchs/megatron_patch.py
index e4da9b1dc..b89041380 100644
--- a/modellink/patchs/megatron_patch.py
+++ b/modellink/patchs/megatron_patch.py
@@ -42,6 +42,7 @@ from ..data import build_pretraining_data_loader
from ..tokenizer import build_tokenizer
from ..arguments import parse_args_decorator
from ..checkpointing import _load_base_checkpoint_wrapper, load_checkpoint_wrapper
+from ..initialize import initialize_megatron
def exec_patch():
@@ -99,7 +100,7 @@ def patch_tensor_parallel():
megatron.core.tensor_parallel.random._set_cuda_rng_state = _set_cuda_rng_state # default_generators need replace after set_device
megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward = _VocabParallelCrossEntropyForward # change masked_target for better performance
megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward = vocab_embedding_wrapper(
- megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward)
+ VocabParallelEmbeddingForward)
megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__ = norm_wrapper(
megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__)
@@ -169,6 +170,7 @@ def patch_initialize():
megatron.initialize._compile_dependencies = _compile_dependencies # remove cuda kernel compile
megatron.initialize.set_jit_fusion_options = set_jit_fusion_options # remove cuda jit nvfuser
megatron.initialize.parse_args = parse_args_decorator(megatron.initialize.parse_args)
+ megatron.initialize.initialize_megatron = initialize_megatron
def patch_training():