diff --git a/README.md b/README.md index 9cd5e062c..c9512a488 100644 --- a/README.md +++ b/README.md @@ -306,7 +306,7 @@ ModelLink旨在为华为 [昇腾芯片](https://open.codehub.huawei.com/OpenBaiz 13B 1x8 BF16 - 1550 + 1668 2062 训练 @@ -959,10 +959,12 @@ ModelLink支持张量并行、流水线并行、序列并行、重计算、分 | 序列并行 | --sequence-parallel | | 重计算 | --recompute-granularity | | 分布式优化器 | --use-distributed-optimizer | -| DDP allreduce 掩盖 | --overlap-grad-reduce | +| DDP allreduce 掩盖 | --overlap-grad-reduce | +| DDP allgather 掩盖 | --overlap-param-gather | | Flash attention | --use-flash-attn | | Fused rmsnorm | --use-fused-rmsnorm | -| Fused swiglu | --use-fused-swiglu | +| Fused swiglu | --use-fused-swiglu | +| mc2 | --use-mc2 | | Fused rotary
position embedding | --use-fused-rotary-pos-emb | @@ -980,7 +982,9 @@ torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ --use-fused-rmsnorm \ --use-fused-swiglu \ --overlap-grad-reduce \ + --overlap-param-gather \ --use-fused-rotary-pos-emb \ + --use-mc2 \ ... \ ... ``` diff --git a/README_en.md b/README_en.md index e5f03e124..7ac059c5e 100644 --- a/README_en.md +++ b/README_en.md @@ -307,7 +307,7 @@ For the supported models listed above, we provide training scripts and readme in 13B 1x8 BF16 - 1550 + 1668 2062 train @@ -966,16 +966,21 @@ ModelLink supports various acceleration algorithms such as tensor parallelism, p | Sequence Parallel | --sequence-parallel | | Recomputation | --recompute-granularity | | Distributed Optimizer | --use-distributed-optimizer | -| overlap DDP allreduce | --overlap-grad-reduce | +| overlap DDP allreduce | --overlap-grad-reduce | +| overlap DDP allgather | --overlap-param-gather | | Flash attention | --use-flash-attn | | Fused rmsnorm | --use-fused-rmsnorm | -| Fused swiglu | --use-fused-swiglu | +| Fused swiglu | --use-fused-swiglu | +| mc2 | --use-mc2 | | Fused rotary
position embedding | --use-fused-rotary-pos-emb | + + ```bash torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ --tensor-model-parallel-size ${TP} \ --pipeline-model-parallel-size ${PP} \ + --num-layer-list 1,2,2,2,1 \ --sequence-parallel \ --recompute-granularity full \ --recompute-method block \ @@ -984,8 +989,10 @@ torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ --use-flash-attn \ --use-fused-rmsnorm \ --use-fused-swiglu \ - --use-fused-rotary-pos-emb \ --overlap-grad-reduce \ + --overlap-param-gather \ + --use-fused-rotary-pos-emb \ + --use-mc2 \ ... \ ... ``` diff --git a/examples/baichuan2/README.md b/examples/baichuan2/README.md index 92e62d735..c8e492897 100644 --- a/examples/baichuan2/README.md +++ b/examples/baichuan2/README.md @@ -416,7 +416,7 @@ Baichuan2-13B 在 **昇腾芯片** 和 **参考芯片** 上的性能对比: | 设备 | 模型 | 迭代数 | 样本吞吐 (samples/s) | token吞吐 (tokens/p/s) | 单步迭代时间 (s/step) | |:----:|:-------------------------:|:----:|:----------------:|:--------------------:|:---------------:| -| NPUs | Baichuan2-13B | 1000 | - | 1550 | - | +| NPUs | Baichuan2-13B | 1000 | - | 1668 | - | | 参考 | Baichuan2-13B | - | - | 2062 | - | diff --git a/examples/baichuan2/README_en.md b/examples/baichuan2/README_en.md index 18e977462..7126fc72b 100644 --- a/examples/baichuan2/README_en.md +++ b/examples/baichuan2/README_en.md @@ -415,7 +415,7 @@ The performance of the Baichuan2-13B in **Ascend NPU** and **Reference**: | Device | Model | total Iterations | throughput rate (samples/s/p) | throughput rate (tokens/s/p) | single-step time (s/step) | |:----:|:-------------------------:|:----:|:-----------------------------:|:----------------------------:|:-------------------------:| -| NPUs | Baichuan2-13B |1000 | - | 1550 | - | +| NPUs | Baichuan2-13B |1000 | - | 1668 | - | | Reference | Baichuan2-13B |-| - | 2062 | - | diff --git a/examples/baichuan2/pretrain_baichuan2_ptd_13B.sh b/examples/baichuan2/pretrain_baichuan2_ptd_13B.sh index bdf8a1d7e..34d722f5d 100644 --- a/examples/baichuan2/pretrain_baichuan2_ptd_13B.sh +++ b/examples/baichuan2/pretrain_baichuan2_ptd_13B.sh @@ -57,6 +57,8 @@ GPT_ARGS=" --normalization RMSNorm \ --use-fused-rmsnorm \ --use-flash-attn \ + --use-fused-swiglu \ + --use-mc2 \ --swiglu \ --no-masked-softmax-fusion \ --attention-softmax-in-fp32 \ @@ -72,7 +74,7 @@ GPT_ARGS=" --adam-eps 1.0e-8 \ --no-load-optim \ --no-load-rng \ - --fp16 + --bf16 " DATA_ARGS=" diff --git a/examples/intern/README.md b/examples/intern/README.md index 9b99ac6b6..187665634 100644 --- a/examples/intern/README.md +++ b/examples/intern/README.md @@ -312,7 +312,7 @@ python ./tools/preprocess_data.py \ source /usr/local/Ascend/ascend-toolkit/set_env.sh # 修改数据集,词表,权重等路径 CKPT_SAVE_DIR="./ckpt/internlm-65b/" -TOKENIZER_PATH="./model_from_hf/internlm-65b/" #词表路径 +TOKENIZER_PATH="./model_from_hf/internlm-65b/tokenizer.model" #词表路径 DATA_PATH="./dataset/internlm-65b/alpaca_text_document" #数据集路径 ``` diff --git a/examples/intern/README_en.md b/examples/intern/README_en.md index 56d77e91c..7ee10001c 100644 --- a/examples/intern/README_en.md +++ b/examples/intern/README_en.md @@ -312,7 +312,7 @@ python ./tools/preprocess_data.py \ source /usr/local/Ascend/ascend-toolkit/set_env.sh # modify script orign dataset path according to your own dataset path CKPT_SAVE_DIR="./ckpt/internlm-65b/" -TOKENIZER_PATH="./model_from_hf/internlm-65b/" #tokenizer path +TOKENIZER_PATH="./model_from_hf/internlm-65b/tokenizer.model" #tokenizer path DATA_PATH="./dataset/internlm-65b/alpaca_text_document" #processed dataset ``` diff --git 
a/examples/llama2/pretrain_llama2_70b_ptd.sh b/examples/llama2/pretrain_llama2_70b_ptd.sh index cec37aa0a..800f9b046 100644 --- a/examples/llama2/pretrain_llama2_70b_ptd.sh +++ b/examples/llama2/pretrain_llama2_70b_ptd.sh @@ -54,11 +54,12 @@ GPT_ARGS=" --no-masked-softmax-fusion \ --attention-softmax-in-fp32 \ --min-lr 1.0e-7 \ - --weight-decay 1e-2 \ + --weight-decay 0.1 \ --clip-grad 1.0 \ --adam-beta1 0.9 \ --initial-loss-scale 4096.0 \ - --adam-beta2 0.999 \ + --adam-beta2 0.95 \ + --adam-eps 1e-5 \ --no-gradient-accumulation-fusion \ --load ${CKPT_LOAD_DIR} \ --no-load-optim \ diff --git a/modellink/arguments.py b/modellink/arguments.py index 68b4b6b1f..190755a16 100644 --- a/modellink/arguments.py +++ b/modellink/arguments.py @@ -169,6 +169,8 @@ def _add_network_size_args(parser): help="Use fused swiglu.") group.add_argument("--use-fused-rotary-pos-emb", action='store_true', help="Use fused rotary-pos-emb.") + group.add_argument("--use-mc2", action='store_true', + help="Use mc2 for compute-comm overlap in tp.") return parser diff --git a/modellink/core/tensor_parallel/__init__.py b/modellink/core/tensor_parallel/__init__.py index e69de29bb..9a1307aec 100644 --- a/modellink/core/tensor_parallel/__init__.py +++ b/modellink/core/tensor_parallel/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/modellink/core/tensor_parallel/ascend_turbo/__init__.py b/modellink/core/tensor_parallel/ascend_turbo/__init__.py new file mode 100644 index 000000000..aaf493892 --- /dev/null +++ b/modellink/core/tensor_parallel/ascend_turbo/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/modellink/core/tensor_parallel/ascend_turbo/ascend_turbo_cfg.py b/modellink/core/tensor_parallel/ascend_turbo/ascend_turbo_cfg.py new file mode 100644 index 000000000..a49a1046c --- /dev/null +++ b/modellink/core/tensor_parallel/ascend_turbo/ascend_turbo_cfg.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class AscendConfig: + def __init__(self): + self.ColumnParallelLinear = None + self.RowParallelLinear = None + self.group_func = None + self.world_size_func = None + + self.sequence_parallel_enabled = True + self.all_gather_recomputation = True + + def set_sequence_parallel(self, sequence_parallel): + self.sequence_parallel = sequence_parallel + + def set_all_gather_recomputation(self, all_gather_recomputation): + self.all_gather_recomputation = all_gather_recomputation + + def set_group(self, group_func): + self.group_func = group_func + + def get_group(self): + return self.group_func() + + def set_world_size(self, world_size_func): + self.world_size_func = world_size_func + + def get_world_size(self): + return self.world_size_func() + + def set_column_parallel_linear(self, column_parallel_linear): + self.ColumnParallelLinear = column_parallel_linear + + def set_row_parallel_linear(self, row_parallel_linear): + self.RowParallelLinear = row_parallel_linear + + def parallel_linear_plugin(self, column_parallel_forward, row_parallel_forward): + self.ColumnParallelLinear.forward = column_parallel_forward + self.RowParallelLinear.forward = row_parallel_forward + + +ascend_turbo_cfg = AscendConfig() diff --git a/modellink/core/tensor_parallel/ascend_turbo/initialize.py b/modellink/core/tensor_parallel/ascend_turbo/initialize.py new file mode 100644 index 000000000..f66648c63 --- /dev/null +++ b/modellink/core/tensor_parallel/ascend_turbo/initialize.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_world_size +) + +from .ascend_turbo_cfg import ascend_turbo_cfg +from .mc2_linears_seq_parallel import ColumnSeqParallelLinear, RowSeqParallelLinear + + +def column_parallel_forward(self, input_): + bias = self.bias if not self.skip_bias_add else None + + output = ColumnSeqParallelLinear.apply( + input_, + self.weight, + bias, + ascend_turbo_cfg.get_group() + ) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +def row_parallel_forward(self, input_): + output = RowSeqParallelLinear.apply( + input_, + self.weight, + None, + ascend_turbo_cfg.get_group() + ) + + if not self.skip_bias_add: + output = output + self.bias if self.bias is not None else output + output_bias = None + else: + output_bias = self.bias + + return output, output_bias + + +def initialize_cfg_from_framework(): + ascend_turbo_cfg.set_group(get_tensor_model_parallel_group) + ascend_turbo_cfg.set_world_size(get_tensor_model_parallel_world_size) + + ascend_turbo_cfg.set_column_parallel_linear(ColumnParallelLinear) + ascend_turbo_cfg.set_row_parallel_linear(RowParallelLinear) + ascend_turbo_cfg.parallel_linear_plugin(column_parallel_forward, row_parallel_forward) + + +def initialize_cfg_from_args(args): + if not args.sequence_parallel or args.tensor_model_parallel_size == 1: + return + + ascend_turbo_cfg.set_sequence_parallel(args.sequence_parallel) + ascend_turbo_cfg.set_all_gather_recomputation(True) + initialize_cfg_from_framework() diff --git a/modellink/core/tensor_parallel/ascend_turbo/mc2_linears_seq_parallel.py b/modellink/core/tensor_parallel/ascend_turbo/mc2_linears_seq_parallel.py new file mode 100644 index 000000000..5d5292797 --- /dev/null +++ b/modellink/core/tensor_parallel/ascend_turbo/mc2_linears_seq_parallel.py @@ -0,0 +1,180 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch +import torch_npu +from .ascend_turbo_cfg import ascend_turbo_cfg + + +class ColumnSeqParallelLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, weight, bias, group): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + + rank = torch.distributed.get_rank(group) + hcomm_info = None + if torch.__version__ > "2.0": + global_rank = torch.distributed.get_global_rank(group, rank) + hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name( + global_rank + ) + + else: + hcomm_info = group.get_hccl_comm_name(rank) + + x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2]) + + world_size = ascend_turbo_cfg.get_world_size() + output, all_gather_grad_output = torch_npu.npu_all_gather_base_mm( + x, + weight.t(), + hcomm_info, + world_size, + bias=bias, + gather_index=0, + gather_output=(not ascend_turbo_cfg.all_gather_recomputation), + ) + + output = output.view( + int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1] + ) + + ctx.all_gather_output = all_gather_grad_output + ctx.world_size = world_size + ctx.group = group + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + + grad_output_ = grad_output.reshape( + grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2] + ) + + if ascend_turbo_cfg.all_gather_recomputation: + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * ctx.world_size + all_gather_output = torch.empty( + dim_size, + dtype=input_.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + all_gather_work = torch.distributed._all_gather_base( + all_gather_output, input_.contiguous(), group=ctx.group, async_op=True + ) + else: + all_gather_output = ctx.all_gather_output + + grad_input = grad_output_.matmul(weight) + grad_input = grad_input.reshape( + grad_output.shape[0], grad_output.shape[1], weight.shape[1] + ) + + sub_grad_input = torch.empty( + list(input_.size()), dtype=input_.dtype, device=torch.cuda.current_device() + ) + reduce_scatter_work = torch.distributed._reduce_scatter_base( + sub_grad_input, grad_input, group=ctx.group, async_op=True + ) + + if ascend_turbo_cfg.all_gather_recomputation: + all_gather_work.wait() + all_gather_output = all_gather_output.reshape( + all_gather_output.shape[0] * all_gather_output.shape[1], + all_gather_output.shape[2], + ) + + grad_weight = grad_output_.t().matmul(all_gather_output) + + is_grad_bias_needed = ctx.needs_input_grad[2] + if is_grad_bias_needed and ctx.use_bias: + grad_bias = ( + grad_output_.sum(dim=0) + if grad_output_.is_contiguous() + else grad_output_.t().sum(dim=1) + ) + else: + grad_bias = None + + reduce_scatter_work.wait() + return sub_grad_input, grad_weight, grad_bias, None + + +class RowSeqParallelLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, weight, bias, group): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + + rank = torch.distributed.get_rank(group) + world_size = ascend_turbo_cfg.get_world_size() + hcomm_info = None + if torch.__version__ > "2.0": + global_rank = torch.distributed.get_global_rank(group, rank) + hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name( + global_rank + ) + else: + hcomm_info = group.get_hccl_comm_name(rank) + + x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2]) + + output = torch_npu.npu_mm_reduce_scatter_base( + x, weight.t(), hcomm_info, world_size, reduce_op="sum", bias=bias + ) + + 
ctx.hcomm_info = hcomm_info + ctx.world_size = world_size + + output = output.view( + int(output.shape[0] / input_.shape[1]), input_.shape[1], output.shape[1] + ) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + hcomm_info = ctx.hcomm_info + world_size = ctx.world_size + + grad_output_ = grad_output.reshape( + grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2] + ) + + grad_input, all_gather_grad_output = torch_npu.npu_all_gather_base_mm( + grad_output_, weight, hcomm_info, world_size, bias=None, gather_index=0 + ) + + grad_input = grad_input.view_as(input_) + + x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2]) + grad_weight = all_gather_grad_output.t().matmul(x) + + is_grad_bias_needed = ctx.needs_input_grad[2] + if is_grad_bias_needed and ctx.use_bias: + grad_bias = ( + grad_output.sum(dim=0) + if grad_output.is_contiguous() + else grad_output.t().sum(dim=1) + ) + else: + grad_bias = None + + return grad_input, grad_weight, grad_bias, None diff --git a/modellink/initialize.py b/modellink/initialize.py index 938a741ef..56ce63025 100644 --- a/modellink/initialize.py +++ b/modellink/initialize.py @@ -1,6 +1,21 @@ import time import torch +import megatron +from megatron import get_args +from megatron.core import mpu +from megatron.arguments import validate_args +from megatron.checkpointing import load_args_from_checkpoint +from megatron.global_vars import set_global_variables +from megatron.initialize import ( + _initialize_distributed, _set_random_seed, + _init_autoresume, _initialize_tp_communicators +) + +from modellink.arguments import parse_args_decorator +from modellink.core.tensor_parallel.ascend_turbo.initialize import initialize_cfg_from_args +from modellink.error_utils import ensure_valid + def _compile_dependencies(): device_count = torch.cuda.device_count() @@ -12,4 +27,81 @@ def _compile_dependencies(): from megatron.core.datasets.utils import compile_helpers compile_helpers() print('>>> done with dataset index builder. Compilation time: {:.3f} ' - 'seconds'.format(time.time() - start_time), flush=True) \ No newline at end of file + 'seconds'.format(time.time() - start_time), flush=True) + + +def initialize_megatron( + extra_args_provider=None, + args_defaults={}, + ignore_unknown_args=False, + allow_no_cuda=False, + skip_mpu_initialization=False, +): + """Set global variables, initialize distributed, and + set autoresume and random seeds. + `allow_no_cuda` should not be set unless using megatron for cpu only + data processing. In general this arg should not be set unless you know + what you are doing. + Returns a function to finalize distributed env initialization + (optionally, only when args.lazy_mpu_init == True) + """ + if not allow_no_cuda: + # Make sure cuda is available. + ensure_valid(torch.cuda.is_available(), "Megatron requires CUDA.") + + # Parse arguments + parse_args = parse_args_decorator(megatron.arguments.parse_args) + args = parse_args(extra_args_provider, ignore_unknown_args) + + if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False): + ensure_valid(args.load is not None, + "--use-checkpoints-args requires --load argument") + load_args_from_checkpoint(args) + + validate_args(args, args_defaults) + + # set global args, build tokenizer, and set adlr-autoresume, + # tensorboard-writer, and timers. + set_global_variables(args) + + # torch.distributed initialization + def finish_mpu_init(): + args = get_args() + # Pytorch distributed. 
+ _initialize_distributed() + + # Random seeds for reproducibility. + if args.rank == 0: + print("> setting random seeds to {} ...".format(args.seed)) + _set_random_seed(args.seed, args.data_parallel_random_init) + if args.use_mc2: + initialize_cfg_from_args(args) + + if skip_mpu_initialization: + return None + + args = get_args() + if args.lazy_mpu_init: + args.use_cpu_initialization = True + # delayed initialization of DDP-related stuff + # We only set basic DDP globals + mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size) + # and return function for external DDP manager + # to call when it has DDP initialized + mpu.set_tensor_model_parallel_rank(args.rank) + return finish_mpu_init + else: + # Megatron's MPU is the master. Complete initialization right away. + finish_mpu_init() + + # Autoresume. + _init_autoresume() + + # Compile dependencies. + _compile_dependencies() + + if args.tp_comm_overlap: + _initialize_tp_communicators() + + # No continuation function + return None diff --git a/modellink/patchs/megatron_patch.py b/modellink/patchs/megatron_patch.py index e4da9b1dc..b89041380 100644 --- a/modellink/patchs/megatron_patch.py +++ b/modellink/patchs/megatron_patch.py @@ -42,6 +42,7 @@ from ..data import build_pretraining_data_loader from ..tokenizer import build_tokenizer from ..arguments import parse_args_decorator from ..checkpointing import _load_base_checkpoint_wrapper, load_checkpoint_wrapper +from ..initialize import initialize_megatron def exec_patch(): @@ -99,7 +100,7 @@ def patch_tensor_parallel(): megatron.core.tensor_parallel.random._set_cuda_rng_state = _set_cuda_rng_state # default_generators need replace after set_device megatron.core.tensor_parallel.cross_entropy._VocabParallelCrossEntropy.forward = _VocabParallelCrossEntropyForward # change masked_target for better performance megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward = vocab_embedding_wrapper( - megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward) + VocabParallelEmbeddingForward) megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__ = norm_wrapper( megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__) @@ -169,6 +170,7 @@ def patch_initialize(): megatron.initialize._compile_dependencies = _compile_dependencies # remove cuda kernel compile megatron.initialize.set_jit_fusion_options = set_jit_fusion_options # remove cuda jit nvfuser megatron.initialize.parse_args = parse_args_decorator(megatron.initialize.parse_args) + megatron.initialize.initialize_megatron = initialize_megatron def patch_training():
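
Note: the `--use-mc2` path added above swaps the separate communication and GEMM steps of sequence parallelism for the fused `torch_npu.npu_all_gather_base_mm` / `torch_npu.npu_mm_reduce_scatter_base` kernels, so the all-gather (column-parallel forward) and reduce-scatter (row-parallel forward) overlap with the matmul instead of running back to back. For reference, here is a minimal unfused sketch of what the column-parallel forward in `mc2_linears_seq_parallel.py` computes; the function name and shapes are illustrative only and assume an already-initialized tensor-parallel process group — this is not code from the patch.

```python
import torch


def column_parallel_reference(input_, weight, bias, group, world_size):
    """Unfused equivalent of ColumnSeqParallelLinear.forward (illustrative sketch)."""
    s, b, h = input_.shape                      # sequence-parallel shard: [s/tp, b, h]
    x = input_.reshape(s * b, h)

    # Step 1: all-gather the input shards from every tensor-parallel rank
    # (the fused kernel does this inside npu_all_gather_base_mm).
    gathered = torch.empty(world_size * s * b, h, dtype=x.dtype, device=x.device)
    torch.distributed._all_gather_base(gathered, x.contiguous(), group=group)

    # Step 2: local GEMM against the column-parallel weight shard.
    output = gathered.matmul(weight.t())        # [tp*s*b, h_out/tp]
    if bias is not None:
        output = output + bias

    # Restore the [tp*s, b, h_out/tp] layout, as the fused path's view does.
    return output.view(world_size * s, b, -1)
```

The fused kernel produces the same result while pipelining the HCCL communication with the matmul, which is where the throughput gain reported in the README tables comes from.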