!1466 Support reuse-fp32-param

Merge pull request !1466 from WangYu/reuse-fp32
WangYu 2024-08-02 03:40:07 +00:00 committed by i-robot
parent e79fc1081f
commit 34f970f19b
4 changed files with 315 additions and 1 deletion


@@ -869,6 +869,13 @@ ModelLink pre-training supports multiple acceleration algorithms such as tensor parallelism and pipeline parallelism, and
<td>Yes</td>
<td>Yes</td>
</tr>
<tr>
<td>Device memory</td>
<td>Parameter-copy reuse</td>
<td>--reuse-fp32-param</td>
<td>Yes</td>
<td>Yes</td>
</tr>
<tr>
<td rowspan="3">Communication</td>
<td>Gradient-reduce communication/computation overlap</td>
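The new row documents `--reuse-fp32-param` as a device-memory optimization for the distributed optimizer. A back-of-envelope sketch of where the saving comes from, based on my reading of the buffer layout in the optimizer change further down (illustrative numbers, not official figures):

```python
# Rough accounting (assumption from the code below, not a measured result):
# without reuse: bf16 weights (2 B/param) + fp32 master shard (4 B/param / dp)
# with reuse:    bf16 weights (2 B/param) + bf16 residual shard (2 B/param / dp)
params, dp = 7e9, 8                      # e.g. a 7B model at data-parallel size 8
saved_bytes = params / dp * (4 - 2)      # the fp32 master shard shrinks to a bf16 residual
print(f"~{saved_bytes / 2**30:.1f} GiB saved per rank")
```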


@@ -290,7 +290,9 @@ def _add_network_size_args(parser):
def _add_algorithm_args(parser):
group = parser.add_argument_group(title='algorithm')
group.add_argument('--rotary-base', type=float, help='rotary-base.')
group.add_argument('--reuse-fp32-param', action='store_true',
help='The distributed training optimizer frees up '
'the FP32 parameter copies to save memory.')
return parser
@@ -462,6 +464,13 @@ def core_transformer_config_from_args_wrapper(fn):
return wrapper
def _validate_optimizer(args):
if args.reuse_fp32_param and not args.bf16:
raise AssertionError('--reuse-fp32-param is only supported together with --bf16.')
if args.reuse_fp32_param and args.enable_high_availability:
raise AssertionError('--reuse-fp32-param and --enable-high-availability cannot be enabled together.')
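A stand-alone sketch of the combinations this check allows, using a plain namespace in place of Megatron's parsed arguments (illustrative only):

```python
from types import SimpleNamespace

def check(args):
    # Mirrors _validate_optimizer above.
    if args.reuse_fp32_param and not args.bf16:
        raise AssertionError('--reuse-fp32-param is only supported together with --bf16.')
    if args.reuse_fp32_param and args.enable_high_availability:
        raise AssertionError('--reuse-fp32-param and --enable-high-availability cannot be enabled together.')

check(SimpleNamespace(reuse_fp32_param=True, bf16=True, enable_high_availability=False))    # passes
# check(SimpleNamespace(reuse_fp32_param=True, bf16=False, enable_high_availability=False)) # would raise
```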
def validate_args_decorator(megatron_validate_args):
@wraps(megatron_validate_args)
def wrapper(args, defaults=None):
@@ -480,6 +489,7 @@ def validate_args_decorator(megatron_validate_args):
_validate_high_availability(args)
_validate_moe_expert_capacity_factor(args)
_validate_optimizer(args)
from modellink.utils import print_args
print_args('ModelLink Arguments', args)
return args


@@ -1,10 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
import types
import itertools
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple
import torch
import torch_npu
import torch.distributed
from apex.optimizers import FusedAdam as Adam
from megatron.training import get_args
@@ -12,6 +15,288 @@ from megatron.core.distributed import ParamAndGradBuffer
from megatron.core.optimizer.grad_scaler import MegatronGradScaler
from megatron.core.optimizer import OptimizerConfig
from megatron.core.optimizer.optimizer import MixedPrecisionOptimizer
from mindspeed.optimizer.distrib_optimizer import _copy_model_params_to_main_params
TRANSPOSE_BF16_BLOCK_SIZE = 4096 * 4096
def reuse_fp32_param_distrib_optimizer_init_wrapper(init_func):
@wraps(init_func)
def reuse_fp32_param_distrib_optimizer_init(self, *args, **kwargs):
init_func(self, *args, **kwargs)
global_args = get_args()
self.reuse_fp32_param = global_args.reuse_fp32_param if hasattr(global_args, "reuse_fp32_param") else False
# A flag that disables the value subtraction when the `fp16_tensor_convert_to_fp32_tensor` function is invoked for the first time.
self.first_sub_flag = True
if self.reuse_fp32_param:
from mindspeed.op_builder import AlgorithmOpBuilder
reuse_data_ptr = AlgorithmOpBuilder().load().reuse_data_ptr
data_parallel_world_size = torch.distributed.get_world_size(self.data_parallel_group)
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
self.model_param_bucket_and_res_map = {}
self.model_param_bucket_and_shard_main_param_int32_view_map = {}
self.shard_main_param_res_buffers = []
self.bucket_num_groups = []
if data_parallel_world_size == 1:
self.shard_fp32_param_fp16_view_group = []
for buffer in self.buffers:
buffer_numel = buffer.param_data.numel()
shard_res_and_buffer_model_param = torch.zeros(buffer_numel * 2, dtype=torch.bfloat16, device=buffer.param_data.device)
shard_main_param_int32_view_buffer = torch.empty(buffer_numel, dtype=torch.int32, device=buffer.param_data.device)
reuse_data_ptr(shard_main_param_int32_view_buffer, shard_res_and_buffer_model_param, 0)
self.shard_main_param_res_buffers.append(shard_res_and_buffer_model_param)
self.model_param_bucket_and_shard_main_param_int32_view_map[shard_res_and_buffer_model_param] = shard_main_param_int32_view_buffer
for model_fp16_params_this_group, shard_fp32_from_float16_group in zip(
self.model_float16_groups, self.shard_fp32_from_float16_groups):
for i, (model_param, shard_fp32_main_param) in enumerate(
zip(model_fp16_params_this_group, shard_fp32_from_float16_group)):
gbuf_index, _, bucket_id = self.model_param_gbuf_map[model_param]
data_start_index, data_end_index, bucket_id = self.buffers[gbuf_index].param_index_map[model_param]
reuse_data_ptr(shard_fp32_from_float16_group[i], self.shard_main_param_res_buffers[gbuf_index], data_start_index)
old_param_data = model_param.data
model_param.data = self.shard_main_param_res_buffers[gbuf_index][data_start_index + data_end_index: 2 * data_end_index].view(old_param_data.shape)
model_param.data.detach().copy_(old_param_data)
self.shard_fp32_param_fp16_view_group.append(self.shard_main_param_res_buffers[gbuf_index][2 * data_start_index: 2 * data_end_index])
for i, buffer in enumerate(self.buffers):
buffer_numel = buffer.param_data.numel()
reuse_data_ptr(buffer.param_data, self.shard_main_param_res_buffers[i], buffer_numel)
else:
for buffer in self.buffers:
self.bucket_num_group = []
bucket_res_numel = 0
res_numel = buffer.numel // data_parallel_world_size
shard_main_param_res_buffer = torch.zeros(res_numel, dtype=torch.bfloat16, device=buffer.param_data.device)
self.shard_main_param_res_buffers.append(shard_main_param_res_buffer)
for bucket in buffer.buckets:
self.bucket_num_group.append(bucket.param_data.numel())
param_data_dp_numel = bucket.param_data.numel() // data_parallel_world_size
shard_main_param_int32_view_bucket = torch.empty(param_data_dp_numel, dtype=torch.int32, device=bucket.param_data.device)
reuse_data_ptr(
shard_main_param_int32_view_bucket,
buffer.param_data,
(bucket_res_numel * data_parallel_world_size) // 2 + max(0, data_parallel_rank - 1) * param_data_dp_numel // 2)
self.model_param_bucket_and_res_map[bucket.param_data] = self.shard_main_param_res_buffers[-1][bucket_res_numel: bucket_res_numel + param_data_dp_numel]
self.model_param_bucket_and_shard_main_param_int32_view_map[bucket.param_data] = shard_main_param_int32_view_bucket
bucket_res_numel += param_data_dp_numel
self.bucket_num_groups.append(self.bucket_num_group)
for model_fp16_params_this_group, shard_fp32_from_float16_group in zip(
self.model_float16_groups, self.shard_fp32_from_float16_groups):
for i, (model_param, shard_fp32_main_param) in enumerate(
zip(model_fp16_params_this_group, shard_fp32_from_float16_group)):
world_range = self._get_model_param_range_map(model_param)["gbuf_world_in_bucket"]
gbuf_index, _, bucket_id = self.model_param_gbuf_map[model_param]
model_param_buffer = self.buffers[gbuf_index].param_data
bucket_offset_in_buffer = sum(self.bucket_num_groups[gbuf_index][:bucket_id]) // 2
model_param_bucket = self.buffers[gbuf_index].buckets[bucket_id].param_data
model_param_bucket_numel_per_dp = model_param_bucket.numel() // data_parallel_world_size
shard_fp32_param_bucket_offset = world_range.start if data_parallel_rank == 0 else \
world_range.start - model_param_bucket_numel_per_dp * (1 + data_parallel_rank) // 2
shard_main_param_buffer_start = bucket_offset_in_buffer + shard_fp32_param_bucket_offset
reuse_data_ptr(shard_fp32_from_float16_group[i], model_param_buffer, shard_main_param_buffer_start)
torch_npu.npu.empty_cache()
self._copy_model_params_to_main_params = _copy_model_params_to_main_params
self.load_parameter_state_from_dp_zero_func = self.load_parameter_state_from_dp_zero
self.load_parameter_state_from_dp_zero = types.MethodType(load_parameter_state_from_dp_zero, self)
self.get_parameter_state_dp_zero_func = self.get_parameter_state_dp_zero
self.get_parameter_state_dp_zero = types.MethodType(get_parameter_state_dp_zero, self)
self.fp16_tensor_convert_to_fp32_tensor = types.MethodType(fp16_tensor_convert_to_fp32_tensor, self)
self.fp32_tensor_convert_to_fp16_tensor = types.MethodType(fp32_tensor_convert_to_fp16_tensor, self)
return reuse_fp32_param_distrib_optimizer_init
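Everything in this wrapper leans on `reuse_data_ptr`, which as far as I can tell re-points one tensor at another tensor's storage at a given element offset so that no second copy is allocated; the fp32 "main" params and the int32 views are all windows onto the same bytes. A rough stand-in built from ordinary PyTorch views (the real op comes from MindSpeed's AlgorithmOpBuilder; the semantics here are an assumption, and mixed-size `Tensor.view(dtype)` needs a reasonably recent PyTorch):

```python
import torch

base = torch.zeros(8, dtype=torch.bfloat16)   # plays the role of a shard_main_param_res_buffer
alias = base[2:6]                             # slice view, roughly reuse_data_ptr(alias, base, 2)
int32_view = base.view(torch.int32)           # same bytes reinterpreted, like the per-bucket int32 views

alias.fill_(1.0)
print(base)        # the write through the alias shows up in base: no extra memory was allocated
print(int32_view)  # and in the reinterpreted view as well
```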
def load_parameter_state_from_dp_zero(self, state_dict):
self.load_parameter_state_from_dp_zero_func(state_dict)
self.first_sub_flag = False
data_parallel_world_size = self.data_parallel_group_gloo.size()
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
self.data_parallel_group_gloo
)
if data_parallel_world_size == 1 or \
not hasattr(self, "shard_main_param_res_buffers"):
return
for i, shard_main_param_res_buffer in enumerate(self.shard_main_param_res_buffers):
shard_res_numel = shard_main_param_res_buffer.numel()
if data_parallel_rank == 0:
send_tensors = [
state_dict["shard_main_param_res"][i][
dpr * shard_res_numel: (dpr + 1) * shard_res_numel] for dpr in range(data_parallel_world_size)
]
else:
send_tensors = None
shard_res_numel = shard_main_param_res_buffer.numel()
recv_tensor = torch.empty((shard_res_numel,), dtype=torch.float16, device="cpu")
torch.distributed.scatter(
recv_tensor,
send_tensors,
data_parallel_global_ranks[0],
data_parallel_group_gloo,
)
recv_tensor_bf16_view = torch.tensor(recv_tensor.data.untyped_storage(), dtype=torch.bfloat16, device=recv_tensor.device)
shard_main_param_res_buffer.copy_(recv_tensor_bf16_view)
def get_parameter_state_dp_zero(self):
state = self.get_parameter_state_dp_zero_func()
data_parallel_world_size = torch.distributed.get_world_size(self.data_parallel_group)
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
self.data_parallel_group_gloo
)
if data_parallel_world_size == 1 or not hasattr(self, "shard_main_param_res_buffers"):
return state
# gather buffer res
buffer_res_full_shard = []
for shard_main_param_res_buffer in self.shard_main_param_res_buffers:
if data_parallel_rank == 0:
recv_tensors = [torch.empty((shard_main_param_res_buffer.numel(),), dtype=torch.float16, device="cpu") for _ in range(data_parallel_world_size)]
else:
recv_tensors = None
send_tensor = torch.empty((shard_main_param_res_buffer.numel(),), dtype=torch.float16, device="cpu")
send_tensor_bf16_view = torch.tensor(send_tensor.data.untyped_storage(), dtype=torch.bfloat16, device=send_tensor.device)
send_tensor_bf16_view.copy_(shard_main_param_res_buffer.detach().cpu())
torch.distributed.gather(
send_tensor,
recv_tensors,
data_parallel_global_ranks[0],
data_parallel_group_gloo,
)
if data_parallel_rank == 0:
buffer_res_full_shard.append(torch.cat(recv_tensors))
state['shard_main_param_res'] = buffer_res_full_shard
return state
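Both checkpoint paths above scatter and gather the residuals through CPU tensors declared as fp16 even though the payload is bf16, presumably because the gloo group used here does not handle bf16 directly; the bytes are simply reinterpreted on either side. A self-contained sketch of that reinterpretation (using `Tensor.view(dtype)` rather than the storage-based view in the code above):

```python
import torch

residual = torch.randn(8, dtype=torch.bfloat16)   # what we actually want to move
carrier = torch.empty(8, dtype=torch.float16)     # same byte size, acceptable to the collective
carrier.view(torch.bfloat16).copy_(residual)      # write the bf16 bits into the fp16 container
restored = carrier.view(torch.bfloat16).clone()   # receiver reinterprets the same bytes
assert torch.equal(restored, residual)            # bit-exact round trip
```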
def fp16_tensor_convert_to_fp32_tensor(self):
"""
res(0000) + bf16(pppp) -> fp32(0p0p0p0p)
Reassemble the bf16 data and the residual data held in the contiguous memory block
into fp32 tensors through a view transposition.
"""
data_parallel_world_size = torch.distributed.get_world_size(self.data_parallel_group)
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
if data_parallel_world_size == 1:
for shard_fp32_param_fp16_view in self.shard_fp32_param_fp16_view_group:
shard_fp32_param_fp16_view.copy_(
shard_fp32_param_fp16_view.view(2, -1).transpose(1, 0).reshape(-1).contiguous())
for shard_res_and_buffer_model_param in self.shard_main_param_res_buffers:
shard_main_param_int32_view_buffer = self.model_param_bucket_and_shard_main_param_int32_view_map[
shard_res_and_buffer_model_param]
if not self.first_sub_flag:
shard_main_param_int32_view_buffer.sub_(32768)
else:
for buffer in self.buffers:
for bucket in buffer.buckets:
bucket_param_data = bucket.param_data
param_data_dp_numel = bucket_param_data.numel() // data_parallel_world_size
bucket_res = self.model_param_bucket_and_res_map[bucket_param_data]
if data_parallel_rank == 0:
bucket_param_data[param_data_dp_numel:param_data_dp_numel * 2].copy_(
bucket_param_data[:param_data_dp_numel])
bucket_res_position = max(0, data_parallel_rank - 1) * param_data_dp_numel
shard_fp32_main_param_view = bucket_param_data[
bucket_res_position: bucket_res_position + param_data_dp_numel * 2]
shard_main_param_int32_view_bucket = self.model_param_bucket_and_shard_main_param_int32_view_map[
bucket_param_data]
loops = param_data_dp_numel // TRANSPOSE_BF16_BLOCK_SIZE
remain = param_data_dp_numel % TRANSPOSE_BF16_BLOCK_SIZE
workspace = torch.zeros(
TRANSPOSE_BF16_BLOCK_SIZE * 2, dtype=torch.bfloat16, device=bucket_res.device)
residual_space = bucket_res
bf16_space_dp_rank = max(1, data_parallel_rank)
bf16_space = bucket_param_data[
param_data_dp_numel * bf16_space_dp_rank:param_data_dp_numel * (bf16_space_dp_rank + 1)]
for loop in range(loops):
copy_start = loop * TRANSPOSE_BF16_BLOCK_SIZE
copy_end = (loop + 1) * TRANSPOSE_BF16_BLOCK_SIZE
workspace_convert_view = workspace[:TRANSPOSE_BF16_BLOCK_SIZE * 2]
workspace[:TRANSPOSE_BF16_BLOCK_SIZE].copy_(residual_space[copy_start: copy_end])
workspace[TRANSPOSE_BF16_BLOCK_SIZE:TRANSPOSE_BF16_BLOCK_SIZE * 2].copy_(
bf16_space[copy_start: copy_end])
shard_fp32_main_param_view[copy_start * 2: copy_end * 2].copy_(
workspace_convert_view.view(2, -1).transpose(1, 0).reshape(-1).contiguous())
if remain > 0:
workspace_convert_view = workspace[:remain * 2]
workspace[:remain].copy_(residual_space[-remain:])
workspace[remain:remain * 2].copy_(bf16_space[-remain:])
shard_fp32_main_param_view[-remain * 2:].copy_(
workspace_convert_view.view(2, -1).transpose(1, 0).reshape(-1).contiguous())
if not self.first_sub_flag:
shard_main_param_int32_view_bucket[:param_data_dp_numel].sub_(32768)
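The `res(0000) + bf16(pppp) -> fp32(0p0p0p0p)` notation is a pure bit-level rearrangement: on a little-endian device, each fp32 value is its low 16 bits (the residual) followed by its high 16 bits (where the bf16 weight lives), so interleaving the two bf16 planes rebuilds the fp32 master weights in place. A stand-alone illustration (not the PR's code; assumes little-endian and a PyTorch recent enough for mixed-size `Tensor.view(dtype)`):

```python
import torch

x = torch.randn(4, dtype=torch.float32)
halves = x.view(torch.bfloat16).view(-1, 2)          # per fp32 value: [low 16 bits, high 16 bits]
residual, bf16_part = halves[:, 0], halves[:, 1]

# Interleaving the residual plane with the bf16 plane is exactly what the transpose-copy does.
rebuilt = torch.stack([residual, bf16_part], dim=1).reshape(-1).contiguous().view(torch.float32)
assert torch.equal(rebuilt, x)                       # bit-exact reconstruction
```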
def fp32_tensor_convert_to_fp16_tensor(self):
"""
fp32(0p0p0p0p) -> fp32(0'p0'p0'p0'p) -> res(0000) + bf16(pppp)
Split the fp32 tensors in the contiguous memory block
back into bf16 data and residuals through a view transposition.
"""
data_parallel_world_size = torch.distributed.get_world_size(self.data_parallel_group)
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
if data_parallel_world_size == 1:
for shard_res_and_buffer_model_param in self.shard_main_param_res_buffers:
shard_main_param_int32_view_buffer = self.model_param_bucket_and_shard_main_param_int32_view_map[
shard_res_and_buffer_model_param]
shard_main_param_int32_view_buffer.add_(32768)
self.first_sub_flag = False
for shard_fp32_param_fp16_view in self.shard_fp32_param_fp16_view_group:
shard_fp32_param_fp16_view.copy_(
shard_fp32_param_fp16_view.view(-1, 2).transpose(1, 0).reshape(-1).contiguous())
else:
for buffer in self.buffers:
for bucket in buffer.buckets:
bucket_param_data = bucket.param_data
param_data_dp_numel = bucket_param_data.numel() // data_parallel_world_size
bucket_res = self.model_param_bucket_and_res_map[bucket_param_data]
shard_main_param_int32_view_bucket = self.model_param_bucket_and_shard_main_param_int32_view_map[
bucket_param_data]
shard_main_param_int32_view_bucket[:param_data_dp_numel].add_(32768)
self.first_sub_flag = False
bucket_res_position = max(0, data_parallel_rank - 1) * param_data_dp_numel
shard_fp32_main_param_view = bucket_param_data[
bucket_res_position: bucket_res_position + param_data_dp_numel * 2]
loops = param_data_dp_numel // TRANSPOSE_BF16_BLOCK_SIZE
remain = param_data_dp_numel % TRANSPOSE_BF16_BLOCK_SIZE
workspace = torch.zeros(
TRANSPOSE_BF16_BLOCK_SIZE * 2, dtype=torch.bfloat16, device=bucket_res.device)
bf16_space_dp_rank = max(0, data_parallel_rank - 1)
residual_space = bucket_res
bf16_space = bucket_param_data[
param_data_dp_numel * bf16_space_dp_rank:param_data_dp_numel * (bf16_space_dp_rank + 1)]
for loop in range(loops):
workspace_convert_view = workspace[:TRANSPOSE_BF16_BLOCK_SIZE * 2]
workspace_convert_view.copy_(
shard_fp32_main_param_view[
loop * TRANSPOSE_BF16_BLOCK_SIZE * 2: (loop + 1) * TRANSPOSE_BF16_BLOCK_SIZE * 2])
temp = workspace_convert_view.view(-1, 2).transpose(1, 0).reshape(-1).contiguous()
residual_space[loop * TRANSPOSE_BF16_BLOCK_SIZE: (loop + 1) * TRANSPOSE_BF16_BLOCK_SIZE].copy_(
temp[:TRANSPOSE_BF16_BLOCK_SIZE])
bf16_space[loop * TRANSPOSE_BF16_BLOCK_SIZE: (loop + 1) * TRANSPOSE_BF16_BLOCK_SIZE].copy_(
temp[TRANSPOSE_BF16_BLOCK_SIZE: TRANSPOSE_BF16_BLOCK_SIZE * 2])
if remain > 0:
workspace_convert_view = workspace[:remain * 2]
workspace_convert_view.copy_(shard_fp32_main_param_view[-remain * 2:])
temp = workspace_convert_view.view(-1, 2).transpose(1, 0).reshape(-1).contiguous()
residual_space[-remain:].copy_(temp[:remain])
bf16_space[-remain:].copy_(temp[remain: remain * 2])
if data_parallel_rank != 0:
shard_fp32_main_param_view[param_data_dp_numel:param_data_dp_numel * 2].copy_(
shard_fp32_main_param_view[:param_data_dp_numel])
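The `add_(32768)` / `sub_(32768)` pair on the int32 views looks like a rounding bias: adding 0x8000 to a 32-bit pattern before keeping its high half turns truncation to bf16 into round-half-up, and subtracting it on the way back restores the original fp32 bits exactly; `first_sub_flag` skips the very first subtraction because the freshly zeroed residuals were never biased. A small sketch of the effect (my interpretation, illustrative only; little-endian assumed):

```python
import torch

x = torch.randn(1 << 16, dtype=torch.float32)
biased = x.view(torch.int32) + 32768                        # bias the low 16 bits of every value
high_half = biased.view(torch.bfloat16).view(-1, 2)[:, 1]   # keep the upper 16 bits as bf16
match = (high_half == x.to(torch.bfloat16)).float().mean()
print(match)   # ~1.0: agrees with round-to-nearest except on exact ties (half-up vs half-even)
```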
def distributed_optimizer_init(


@@ -89,6 +89,7 @@ def patch_megatron_noncore():
patch_training()
patch_log_handler()
patch_high_availability_feature()
patch_optimizer()
def patch_fusions():
@@ -312,3 +313,14 @@ def patch_high_availability_feature():
PatchManager.register_patch('megatron.core.optimizer.optimizer.clip_grad_norm_fp32', clip_grad_norm_fp32_wrapper)
PatchManager.register_patch('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__', distributed_optimizer_init_wrapper)
PatchManager.register_patch('megatron.training.training.setup_model_and_optimizer', setup_model_and_optimizer_wrapper)
def patch_optimizer():
if get_modellink_args().reuse_fp32_param:
from mindspeed.optimizer.optimizer import mixed_precision_optimizer_step, reuse_fp32_param_init_wrapper, \
optimizer_config_init_wrapper
from ..core.optimizer.distrib_optimizer import reuse_fp32_param_distrib_optimizer_init_wrapper
PatchManager.register_patch('megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.step', mixed_precision_optimizer_step)
PatchManager.register_patch('megatron.core.optimizer.optimizer.Float16OptimizerWithFloat16Params.__init__', reuse_fp32_param_init_wrapper)
PatchManager.register_patch('megatron.core.optimizer.optimizer_config.OptimizerConfig.__init__', optimizer_config_init_wrapper)
PatchManager.register_patch('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__', reuse_fp32_param_distrib_optimizer_init_wrapper)
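`patch_optimizer` only registers these wrappers when `--reuse-fp32-param` is set, so the default optimizer path is untouched. The registration itself is the usual decorator-style monkey patch; a self-contained sketch of the pattern (illustrative names, not PatchManager's actual implementation):

```python
from functools import wraps

class Optim:                                   # stand-in for DistributedOptimizer
    def __init__(self):
        self.reuse_fp32_param = False

def init_wrapper(init_func):                   # same shape as reuse_fp32_param_distrib_optimizer_init_wrapper
    @wraps(init_func)
    def wrapped(self, *args, **kwargs):
        init_func(self, *args, **kwargs)       # run the original __init__ first
        self.reuse_fp32_param = True           # then layer the extra setup on top
    return wrapped

Optim.__init__ = init_wrapper(Optim.__init__)  # what PatchManager.register_patch effectively arranges
assert Optim().reuse_fp32_param
```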