From 34f970f19bb27b7c2bc06a1cd3959e5d82e97c81 Mon Sep 17 00:00:00 2001 From: WangYu <14693785+Wang----Yu@user.noreply.gitee.com> Date: Fri, 2 Aug 2024 03:40:07 +0000 Subject: [PATCH] !1466 Support reuse-fp32-param Merge pull request !1466 from WangYu/reuse-fp32 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 7 + modellink/arguments.py | 12 +- modellink/core/optimizer/distrib_optimizer.py | 285 ++++++++++++++++++ modellink/patchs/megatron_patch.py | 12 + 4 files changed, 315 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5da12a0c3..336d7218c 100644 --- a/README.md +++ b/README.md @@ -869,6 +869,13 @@ ModelLink预训练支持张量并行、流水线并行等多种加速算法和 Yes Yes + + 显存 + 参数副本复用 + --reuse-fp32-param + Yes + Yes + 通信 梯度reduce通算掩盖 diff --git a/modellink/arguments.py b/modellink/arguments.py index 1d98700b2..0cb790a49 100644 --- a/modellink/arguments.py +++ b/modellink/arguments.py @@ -290,7 +290,9 @@ def _add_network_size_args(parser): def _add_algorithm_args(parser): group = parser.add_argument_group(title='algorithm') group.add_argument('--rotary-base', type=float, help='rotary-base.') - + group.add_argument('--reuse-fp32-param', action='store_true', + help='The distributed training optimizer frees up ' + 'the FP32 param copies to save memory.') return parser @@ -462,6 +464,13 @@ def core_transformer_config_from_args_wrapper(fn): return wrapper +def _validate_optimizer(args): + if args.reuse_fp32_param and not args.bf16: + raise AssertionError('--reuse-fp32-param only supports `bf16`') + if args.reuse_fp32_param and args.enable_high_availability: + raise AssertionError('reuse-fp32-param and enable-high-availability cannot be enabled together.') + + def validate_args_decorator(megatron_validate_args): @wraps(megatron_validate_args) def wrapper(args, defaults=None): @@ -480,6 +489,7 @@ def validate_args_decorator(megatron_validate_args): _validate_high_availability(args) _validate_moe_expert_capacity_factor(args) + _validate_optimizer(args) from modellink.utils import print_args print_args('ModelLink Arguments', args) return args diff --git a/modellink/core/optimizer/distrib_optimizer.py b/modellink/core/optimizer/distrib_optimizer.py index 6bd6444e0..4feccea92 100644 --- a/modellink/core/optimizer/distrib_optimizer.py +++ b/modellink/core/optimizer/distrib_optimizer.py @@ -1,10 +1,13 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+import types import itertools from functools import wraps from typing import Callable, Dict, List, Optional, Tuple import torch +import torch_npu +import torch.distributed from apex.optimizers import FusedAdam as Adam from megatron.training import get_args @@ -12,6 +15,288 @@ from megatron.core.distributed import ParamAndGradBuffer from megatron.core.optimizer.grad_scaler import MegatronGradScaler from megatron.core.optimizer import OptimizerConfig from megatron.core.optimizer.optimizer import MixedPrecisionOptimizer +from mindspeed.optimizer.distrib_optimizer import _copy_model_params_to_main_params + +TRANSPOSE_BF16_BLOCK_SIZE = 4096 * 4096 + + +def reuse_fp32_param_distrib_optimizer_init_wrapper(init_func): + @wraps(init_func) + def reuse_fp32_param_distrib_optimizer_init(self, *args, **kwargs): + init_func(self, *args, **kwargs) + global_args = get_args() + self.reuse_fp32_param = global_args.reuse_fp32_param if hasattr(global_args, "reuse_fp32_param") else False + # A flag that disables the value subtraction when the `fp16_tensor_convert_to_fp32_tensor` function is invoked for the first time. + self.first_sub_flag = True + if self.reuse_fp32_param: + from mindspeed.op_builder import AlgorithmOpBuilder + reuse_data_ptr = AlgorithmOpBuilder().load().reuse_data_ptr + data_parallel_world_size = torch.distributed.get_world_size(self.data_parallel_group) + data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo) + self.model_param_bucket_and_res_map = {} + self.model_param_bucket_and_shard_main_param_int32_view_map = {} + self.shard_main_param_res_buffers = [] + self.bucket_num_groups = [] + if data_parallel_world_size == 1: + self.shard_fp32_param_fp16_view_group = [] + for buffer in self.buffers: + buffer_numel = buffer.param_data.numel() + shard_res_and_buffer_model_param = torch.zeros(buffer_numel * 2, dtype=torch.bfloat16, device=buffer.param_data.device) + shard_main_param_int32_view_buffer = torch.empty(buffer_numel, dtype=torch.int32, device=buffer.param_data.device) + reuse_data_ptr(shard_main_param_int32_view_buffer, shard_res_and_buffer_model_param, 0) + self.shard_main_param_res_buffers.append(shard_res_and_buffer_model_param) + self.model_param_bucket_and_shard_main_param_int32_view_map[shard_res_and_buffer_model_param] = shard_main_param_int32_view_buffer + for model_fp16_params_this_group, shard_fp32_from_float16_group in zip( + self.model_float16_groups, self.shard_fp32_from_float16_groups): + for i, (model_param, shard_fp32_main_param) in enumerate( + zip(model_fp16_params_this_group, shard_fp32_from_float16_group)): + gbuf_index, _, bucket_id = self.model_param_gbuf_map[model_param] + data_start_index, data_end_index, bucket_id = self.buffers[gbuf_index].param_index_map[model_param] + reuse_data_ptr(shard_fp32_from_float16_group[i], self.shard_main_param_res_buffers[gbuf_index], data_start_index) + old_param_data = model_param.data + model_param.data = self.shard_main_param_res_buffers[gbuf_index][data_start_index + data_end_index: 2 * data_end_index].view(old_param_data.shape) + model_param.data.detach().copy_(old_param_data) + self.shard_fp32_param_fp16_view_group.append(self.shard_main_param_res_buffers[gbuf_index][2 * data_start_index: 2 * data_end_index]) + for i, buffer in enumerate(self.buffers): + buffer_numel = buffer.param_data.numel() + reuse_data_ptr(buffer.param_data, self.shard_main_param_res_buffers[i], buffer_numel) + else: + for buffer in self.buffers: + self.bucket_num_group = [] + bucket_res_numel = 0 + res_numel = buffer.numel //
data_parallel_world_size + shard_main_param_res_buffer = torch.zeros(res_numel, dtype=torch.bfloat16, device=buffer.param_data.device) + self.shard_main_param_res_buffers.append(shard_main_param_res_buffer) + for bucket in buffer.buckets: + self.bucket_num_group.append(bucket.param_data.numel()) + param_data_dp_numel = bucket.param_data.numel() // data_parallel_world_size + shard_main_param_int32_view_bucket = torch.empty(param_data_dp_numel, dtype=torch.int32, device=bucket.param_data.device) + reuse_data_ptr( + shard_main_param_int32_view_bucket, + buffer.param_data, + (bucket_res_numel * data_parallel_world_size) // 2 + max(0, data_parallel_rank - 1) * param_data_dp_numel // 2) + self.model_param_bucket_and_res_map[bucket.param_data] = self.shard_main_param_res_buffers[-1][bucket_res_numel: bucket_res_numel + param_data_dp_numel] + self.model_param_bucket_and_shard_main_param_int32_view_map[bucket.param_data] = shard_main_param_int32_view_bucket + bucket_res_numel += param_data_dp_numel + self.bucket_num_groups.append(self.bucket_num_group) + for model_fp16_params_this_group, shard_fp32_from_float16_group in zip( + self.model_float16_groups, self.shard_fp32_from_float16_groups): + for i, (model_param, shard_fp32_main_param) in enumerate( + zip(model_fp16_params_this_group, shard_fp32_from_float16_group)): + world_range = self._get_model_param_range_map(model_param)["gbuf_world_in_bucket"] + gbuf_index, _, bucket_id = self.model_param_gbuf_map[model_param] + model_param_buffer = self.buffers[gbuf_index].param_data + bucket_offset_in_buffer = sum(self.bucket_num_groups[gbuf_index][:bucket_id]) // 2 + model_param_bucket = self.buffers[gbuf_index].buckets[bucket_id].param_data + model_param_bucket_numel_per_dp = model_param_bucket.numel() // data_parallel_world_size + shard_fp32_param_bucket_offset = world_range.start if data_parallel_rank == 0 else \ + world_range.start - model_param_bucket_numel_per_dp * (1 + data_parallel_rank) // 2 + shard_main_param_buffer_start = bucket_offset_in_buffer + shard_fp32_param_bucket_offset + reuse_data_ptr(shard_fp32_from_float16_group[i], model_param_buffer, shard_main_param_buffer_start) + torch_npu.npu.empty_cache() + self._copy_model_params_to_main_params = _copy_model_params_to_main_params + self.load_parameter_state_from_dp_zero_func = self.load_parameter_state_from_dp_zero + self.load_parameter_state_from_dp_zero = types.MethodType(load_parameter_state_from_dp_zero, self) + self.get_parameter_state_dp_zero_func = self.get_parameter_state_dp_zero + self.get_parameter_state_dp_zero = types.MethodType(get_parameter_state_dp_zero, self) + self.fp16_tensor_convert_to_fp32_tensor = types.MethodType(fp16_tensor_convert_to_fp32_tensor, self) + self.fp32_tensor_convert_to_fp16_tensor = types.MethodType(fp32_tensor_convert_to_fp16_tensor, self) + return reuse_fp32_param_distrib_optimizer_init + + +def load_parameter_state_from_dp_zero(self, state_dict): + self.load_parameter_state_from_dp_zero_func(state_dict) + self.first_sub_flag = False + data_parallel_world_size = self.data_parallel_group_gloo.size() + data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo) + data_parallel_group_gloo = self.data_parallel_group_gloo + data_parallel_global_ranks = torch.distributed.get_process_group_ranks( + self.data_parallel_group_gloo + ) + if data_parallel_world_size == 1 or \ + not hasattr(self, "shard_main_param_res_buffers"): + return + for i, shard_main_param_res_buffer in enumerate(self.shard_main_param_res_buffers): + shard_res_numel = 
shard_main_param_res_buffer.numel() + if data_parallel_rank == 0: + send_tensors = [ + state_dict["shard_main_param_res"][i][ + dpr * shard_res_numel: (dpr + 1) * shard_res_numel] for dpr in range(data_parallel_world_size) + ] + else: + send_tensors = None + shard_res_numel = shard_main_param_res_buffer.numel() + recv_tensor = torch.empty((shard_res_numel,), dtype=torch.float16, device="cpu") + torch.distributed.scatter( + recv_tensor, + send_tensors, + data_parallel_global_ranks[0], + data_parallel_group_gloo, + ) + recv_tensor_bf16_view = torch.tensor(recv_tensor.data.untyped_storage(), dtype=torch.bfloat16, device=recv_tensor.device) + shard_main_param_res_buffer.copy_(recv_tensor_bf16_view) + + +def get_parameter_state_dp_zero(self): + state = self.get_parameter_state_dp_zero_func() + data_parallel_world_size = torch.distributed.get_world_size(self.data_parallel_group) + data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo) + data_parallel_group_gloo = self.data_parallel_group_gloo + data_parallel_global_ranks = torch.distributed.get_process_group_ranks( + self.data_parallel_group_gloo + ) + if data_parallel_world_size == 1 or not hasattr(self, "shard_main_param_res_buffers"): + return state + # gather buffer res + buffer_res_full_shard = [] + for shard_main_param_res_buffer in self.shard_main_param_res_buffers: + if data_parallel_rank == 0: + recv_tensors = [torch.empty((shard_main_param_res_buffer.numel(),), dtype=torch.float16, device="cpu") for _ in range(data_parallel_world_size)] + else: + recv_tensors = None + send_tensor = torch.empty((shard_main_param_res_buffer.numel(),), dtype=torch.float16, device="cpu") + send_tensor_bf16_view = torch.tensor(send_tensor.data.untyped_storage(), dtype=torch.bfloat16, device=send_tensor.device) + send_tensor_bf16_view.copy_(shard_main_param_res_buffer.detach().cpu()) + torch.distributed.gather( + send_tensor, + recv_tensors, + data_parallel_global_ranks[0], + data_parallel_group_gloo, + ) + if data_parallel_rank == 0: + buffer_res_full_shard.append(torch.cat(recv_tensors)) + state['shard_main_param_res'] = buffer_res_full_shard + return state + + +def fp16_tensor_convert_to_fp32_tensor(self): + """ + res(0000) + bf16(pppp) -> fp32(0p0p0p0p) + + Transform the bf16 data and residuals data in the continuous memory block + into the fp32 tensor through view transposition. 
+ """ + data_parallel_world_size = torch.distributed.get_world_size(self.data_parallel_group) + data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo) + if data_parallel_world_size == 1: + for shard_fp32_param_fp16_view in self.shard_fp32_param_fp16_view_group: + shard_fp32_param_fp16_view.copy_( + shard_fp32_param_fp16_view.view(2, -1).transpose(1, 0).reshape(-1).contiguous()) + + for shard_res_and_buffer_model_param in self.shard_main_param_res_buffers: + shard_main_param_int32_view_buffer = self.model_param_bucket_and_shard_main_param_int32_view_map[ + shard_res_and_buffer_model_param] + if not self.first_sub_flag: + shard_main_param_int32_view_buffer.sub_(32768) + else: + for buffer in self.buffers: + for bucket in buffer.buckets: + bucket_param_data = bucket.param_data + param_data_dp_numel = bucket_param_data.numel() // data_parallel_world_size + bucket_res = self.model_param_bucket_and_res_map[bucket_param_data] + if data_parallel_rank == 0: + bucket_param_data[param_data_dp_numel:param_data_dp_numel * 2].copy_( + bucket_param_data[:param_data_dp_numel]) + bucket_res_position = max(0, data_parallel_rank - 1) * param_data_dp_numel + shard_fp32_main_param_view = bucket_param_data[ + bucket_res_position: bucket_res_position + param_data_dp_numel * 2] + shard_main_param_int32_view_bucket = self.model_param_bucket_and_shard_main_param_int32_view_map[ + bucket_param_data] + + loops = param_data_dp_numel // TRANSPOSE_BF16_BLOCK_SIZE + remain = param_data_dp_numel % TRANSPOSE_BF16_BLOCK_SIZE + workspace = torch.zeros( + TRANSPOSE_BF16_BLOCK_SIZE * 2, dtype=torch.bfloat16, device=bucket_res.device) + residual_space = bucket_res + bf16_space_dp_rank = max(1, data_parallel_rank) + bf16_space = bucket_param_data[ + param_data_dp_numel * bf16_space_dp_rank:param_data_dp_numel * (bf16_space_dp_rank + 1)] + + for loop in range(loops): + copy_start = loop * TRANSPOSE_BF16_BLOCK_SIZE + copy_end = (loop + 1) * TRANSPOSE_BF16_BLOCK_SIZE + workspace_convert_view = workspace[:TRANSPOSE_BF16_BLOCK_SIZE * 2] + workspace[:TRANSPOSE_BF16_BLOCK_SIZE].copy_(residual_space[copy_start: copy_end]) + workspace[TRANSPOSE_BF16_BLOCK_SIZE:TRANSPOSE_BF16_BLOCK_SIZE * 2].copy_( + bf16_space[copy_start: copy_end]) + shard_fp32_main_param_view[copy_start * 2: copy_end * 2].copy_( + workspace_convert_view.view(2, -1).transpose(1, 0).reshape(-1).contiguous()) + + if remain > 0: + workspace_convert_view = workspace[:remain * 2] + workspace[:remain].copy_(residual_space[-remain:]) + workspace[remain:remain * 2].copy_(bf16_space[-remain:]) + shard_fp32_main_param_view[-remain * 2:].copy_( + workspace_convert_view.view(2, -1).transpose(1, 0).reshape(-1).contiguous()) + + if not self.first_sub_flag: + shard_main_param_int32_view_bucket[:param_data_dp_numel].sub_(32768) + + +def fp32_tensor_convert_to_fp16_tensor(self): + """ + fp32(0p0p0p0p) -> fp32(0'p0'p0'p0'p) -> res(0000) + bf16(pppp) + + Transform the fp32 tensor in the continuous memory block + into the bf16 data and residual through view transposition. 
+ """ + data_parallel_world_size = torch.distributed.get_world_size(self.data_parallel_group) + data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo) + if data_parallel_world_size == 1: + for shard_res_and_buffer_model_param in self.shard_main_param_res_buffers: + shard_main_param_int32_view_buffer = self.model_param_bucket_and_shard_main_param_int32_view_map[ + shard_res_and_buffer_model_param] + shard_main_param_int32_view_buffer.add_(32768) + self.first_sub_flag = False + + for shard_fp32_param_fp16_view in self.shard_fp32_param_fp16_view_group: + shard_fp32_param_fp16_view.copy_( + shard_fp32_param_fp16_view.view(-1, 2).transpose(1, 0).reshape(-1).contiguous()) + else: + for buffer in self.buffers: + for bucket in buffer.buckets: + bucket_param_data = bucket.param_data + param_data_dp_numel = bucket_param_data.numel() // data_parallel_world_size + bucket_res = self.model_param_bucket_and_res_map[bucket_param_data] + shard_main_param_int32_view_bucket = self.model_param_bucket_and_shard_main_param_int32_view_map[ + bucket_param_data] + shard_main_param_int32_view_bucket[:param_data_dp_numel].add_(32768) + self.first_sub_flag = False + + bucket_res_position = max(0, data_parallel_rank - 1) * param_data_dp_numel + shard_fp32_main_param_view = bucket_param_data[ + bucket_res_position: bucket_res_position + param_data_dp_numel * 2] + + loops = param_data_dp_numel // TRANSPOSE_BF16_BLOCK_SIZE + remain = param_data_dp_numel % TRANSPOSE_BF16_BLOCK_SIZE + workspace = torch.zeros( + TRANSPOSE_BF16_BLOCK_SIZE * 2, dtype=torch.bfloat16, device=bucket_res.device) + bf16_space_dp_rank = max(0, data_parallel_rank - 1) + residual_space = bucket_res + bf16_space = bucket_param_data[ + param_data_dp_numel * bf16_space_dp_rank:param_data_dp_numel * (bf16_space_dp_rank + 1)] + + for loop in range(loops): + workspace_convert_view = workspace[:TRANSPOSE_BF16_BLOCK_SIZE * 2] + workspace_convert_view.copy_( + shard_fp32_main_param_view[ + loop * TRANSPOSE_BF16_BLOCK_SIZE * 2: (loop + 1) * TRANSPOSE_BF16_BLOCK_SIZE * 2]) + temp = workspace_convert_view.view(-1, 2).transpose(1, 0).reshape(-1).contiguous() + residual_space[loop * TRANSPOSE_BF16_BLOCK_SIZE: (loop + 1) * TRANSPOSE_BF16_BLOCK_SIZE].copy_( + temp[:TRANSPOSE_BF16_BLOCK_SIZE]) + bf16_space[loop * TRANSPOSE_BF16_BLOCK_SIZE: (loop + 1) * TRANSPOSE_BF16_BLOCK_SIZE].copy_( + temp[TRANSPOSE_BF16_BLOCK_SIZE: TRANSPOSE_BF16_BLOCK_SIZE * 2]) + + if remain > 0: + workspace_convert_view = workspace[:remain * 2] + workspace_convert_view.copy_(shard_fp32_main_param_view[-remain * 2:]) + temp = workspace_convert_view.view(-1, 2).transpose(1, 0).reshape(-1).contiguous() + residual_space[-remain:].copy_(temp[:remain]) + bf16_space[-remain:].copy_(temp[remain: remain * 2]) + + if data_parallel_rank != 0: + shard_fp32_main_param_view[param_data_dp_numel:param_data_dp_numel * 2].copy_( + shard_fp32_main_param_view[:param_data_dp_numel]) def distributed_optimizer_init( diff --git a/modellink/patchs/megatron_patch.py b/modellink/patchs/megatron_patch.py index e6648d2f1..9637431dc 100644 --- a/modellink/patchs/megatron_patch.py +++ b/modellink/patchs/megatron_patch.py @@ -89,6 +89,7 @@ def patch_megatron_noncore(): patch_training() patch_log_handler() patch_high_availability_feature() + patch_optimizer() def patch_fusions(): @@ -312,3 +313,14 @@ def patch_high_availability_feature(): PatchManager.register_patch('megatron.core.optimizer.optimizer.clip_grad_norm_fp32', clip_grad_norm_fp32_wrapper) 
PatchManager.register_patch('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__', distributed_optimizer_init_wrapper) PatchManager.register_patch('megatron.training.training.setup_model_and_optimizer', setup_model_and_optimizer_wrapper) + + +def patch_optimizer(): + if get_modellink_args().reuse_fp32_param: + from mindspeed.optimizer.optimizer import mixed_precision_optimizer_step, reuse_fp32_param_init_wrapper, \ + optimizer_config_init_wrapper + from ..core.optimizer.distrib_optimizer import reuse_fp32_param_distrib_optimizer_init_wrapper + PatchManager.register_patch('megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.step', mixed_precision_optimizer_step) + PatchManager.register_patch('megatron.core.optimizer.optimizer.Float16OptimizerWithFloat16Params.__init__', reuse_fp32_param_init_wrapper) + PatchManager.register_patch('megatron.core.optimizer.optimizer_config.OptimizerConfig.__init__', optimizer_config_init_wrapper) + PatchManager.register_patch('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__', reuse_fp32_param_distrib_optimizer_init_wrapper) \ No newline at end of file
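The conversion helpers above rely on one bit-level fact: an fp32 value is exactly a bf16 value (its high 16 bits) concatenated with a 16-bit residual (its low 16 bits), so the bf16 working copy and the fp32 master copy can share one buffer and be switched by interleaving/de-interleaving 16-bit halves. The `add_(32768)`/`sub_(32768)` on the int32 view biases the bit pattern so the bf16 half is approximately round-to-nearest rather than truncated, and `first_sub_flag` skips the very first subtraction because no bias has been applied yet at that point. The sketch below is a simplified CPU illustration of that round trip, not the patch's NPU implementation: it assumes a little-endian host and uses `Tensor.view(dtype)` plus strided copies in place of MindSpeed's `reuse_data_ptr` and the blockwise `TRANSPOSE_BF16_BLOCK_SIZE` workspace; the tensor name and size are made up.

```python
import torch

# Hypothetical fp32 master weights (stand-in for a shard of the param buffer).
fp32_master = torch.randn(1024, dtype=torch.float32)
original = fp32_master.clone()

int32_view = fp32_master.view(torch.int32)      # same bytes, reinterpreted bitwise
bf16_halves = fp32_master.view(torch.bfloat16)  # little-endian: [low16, high16] per fp32 element

# fp32 -> res + bf16: bias the bit pattern so the high half rounds to the
# nearest bf16 instead of truncating (the add_(32768) in the patch), then
# split into the two 16-bit streams.
int32_view.add_(32768)
residual = bf16_halves[0::2].clone()    # low halves: kept so fp32 can be rebuilt exactly
bf16_param = bf16_halves[1::2].clone()  # high halves: the bf16 params used by fwd/bwd

# res + bf16 -> fp32: interleave the halves back into fp32 storage and remove
# the bias (the sub_(32768) that first_sub_flag disables on the very first call,
# before any bias has ever been added).
rebuilt = torch.empty_like(original)
rebuilt_halves = rebuilt.view(torch.bfloat16)
rebuilt_halves[0::2] = residual
rebuilt_halves[1::2] = bf16_param
rebuilt.view(torch.int32).sub_(32768)

assert torch.equal(rebuilt, original)  # bit-exact round trip
```

In the patch itself, the DP=1 path keeps the residual and the bf16 copy of each parameter side by side in one doubled buffer, while the DP>1 path keeps a per-rank residual shard in `shard_main_param_res_buffers` and performs the interleaving bucket by bucket in `TRANSPOSE_BF16_BLOCK_SIZE` chunks to bound the temporary workspace.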
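A smaller detail in `get_parameter_state_dp_zero` / `load_parameter_state_from_dp_zero` above: the bf16 residual is shipped through float16 CPU tensors and only reinterpreted as bf16 at either end, presumably because the gloo-backed gather/scatter used for the dp-zero checkpoint path does not accept bfloat16 tensors. A minimal sketch of that byte reinterpretation, using `Tensor.view(dtype)` instead of the `untyped_storage()`-based construction in the patch; the tensor names and size are made up.

```python
import torch

# Hypothetical bf16 residual shard to be gathered on rank 0.
residual_bf16 = torch.randn(16, dtype=torch.bfloat16)

# Pack the same bytes into a float16 "carrier" that the collective will accept.
carrier_fp16 = torch.empty(residual_bf16.numel(), dtype=torch.float16, device="cpu")
carrier_fp16.view(torch.bfloat16).copy_(residual_bf16)

# ... carrier_fp16 is what torch.distributed.gather/scatter would move around ...

# Reinterpret the received bytes back as bf16; no numeric conversion happens.
restored = carrier_fp16.view(torch.bfloat16).clone()
assert torch.equal(restored, residual_bf16)
```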