# coding=utf-8
# Copyright (c) 2020, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ascendspeed.core.parallel_state import get_tensor_model_parallel_group
from ascendspeed.core.parallel_state import get_tensor_model_parallel_rank
from ascendspeed.core.parallel_state import get_tensor_model_parallel_src_rank
from deepspeed.accelerator import get_accelerator

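# _MAX_DATA_DIM bounds tensor rank: each broadcast tensor may have at most
# _MAX_DATA_DIM - 1 dimensions, and _build_key_size_numel_dictionaries
# reserves _MAX_DATA_DIM zero-padded size slots per key.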
_MAX_DATA_DIM = 5


def _check_data_types(keys, data, target_dtype):
    """Check that all the keys have the same target data type."""
    for key in keys:
        assert data[key].dtype == target_dtype, '{} has data type {} which '\
            'is different than {}'.format(key, data[key].dtype, target_dtype)


def _build_key_size_numel_dictionaries(keys, data):
    """Build the size on rank 0 and broadcast."""
    max_dim = _MAX_DATA_DIM
    sizes = [0 for _ in range(max_dim) for _ in keys]

    # Pack the sizes on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        offset = 0
        for key in keys:
            assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
            size = data[key].size()
            for i, s in enumerate(size):
                sizes[i + offset] = s
            offset += max_dim

    # Move to GPU and broadcast.
    sizes_cuda = get_accelerator().LongTensor(sizes)
    torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(),
                                group=get_tensor_model_parallel_group())

    # Move back to cpu and unpack.
    sizes_cpu = sizes_cuda.cpu()
    key_size = {}
    key_numel = {}
    total_numel = 0
    offset = 0
    for key in keys:
        i = 0
        size = []
        numel = 1
        while sizes_cpu[offset + i] > 0:
            this_size = sizes_cpu[offset + i]
            size.append(this_size)
            numel *= this_size
            i += 1
        key_size[key] = size
        key_numel[key] = numel
        total_numel += numel
        offset += max_dim

    return key_size, key_numel, total_numel
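# For example, with _MAX_DATA_DIM = 5 and keys = ['text', 'labels'], where
# rank 0 holds 'text' of shape (4, 512) and 'labels' of shape (4,), the
# packed `sizes` buffer broadcast above is
#     [4, 512, 0, 0, 0,  4, 0, 0, 0, 0]
# and the unpacking loop reads each key's dimensions until the first zero.
# The key names and shapes here are illustrative only.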


def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each model parallel group to the
    members of the same model parallel group.

    Arguments:
        keys: list of keys in the data dictionary to be broadcasted.
        data: data dictionary of string keys and cpu tensor values.
        datatype: torch data type of all tensors in data associated
                  with keys.
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
                                                                          data)

    # Pack on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)
        # Flatten the data associated with the keys.
        flatten_data = torch.cat(
            [data[key].contiguous().view(-1) for key in keys], dim=0).to(get_accelerator().device_name())
    else:
        flatten_data = torch.empty(total_numel,
                                   device=get_accelerator().current_device_name(),
                                   dtype=datatype)

    # Broadcast.
    torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
                                group=get_tensor_model_parallel_group())

    # Unpack.
    output = {}
    offset = 0
    for key in keys:
        size = key_size[key]
        numel = key_numel[key]
        output[key] = flatten_data.narrow(0, offset, numel).view(size)
        offset += numel

    return output
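

# Illustrative usage sketch: it assumes torch.distributed and the AscendSpeed
# tensor model parallel state have already been initialized elsewhere, and
# that every rank of a tensor model parallel group calls broadcast_data with
# the same `keys` and `datatype`. The key names and shapes are placeholders.
def _example_broadcast_batch():
    keys = ['text', 'labels']
    datatype = torch.int64

    if get_tensor_model_parallel_rank() == 0:
        # Only the source rank needs real CPU tensors; the other ranks never
        # read `data`, so they may pass None.
        data = {
            'text': torch.randint(0, 32000, (4, 512), dtype=datatype),
            'labels': torch.randint(0, 32000, (4, 512), dtype=datatype),
        }
    else:
        data = None

    batch = broadcast_data(keys, data, datatype)
    # Every rank in the group now holds identical device tensors.
    return batch['text'], batch['labels']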