# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import argparse
import importlib
import os
import sys
import torch.multiprocessing as mp

# A loader is a Python file with at least two functions:
# - add_arguments - takes in a parser and adds any arguments needed
# - load_checkpoint - takes in the queue and parsed arguments
#
# A saver is similar but has save_model_checkpoint instead of
# load_checkpoint.
#
# The loader and saver processes are each given a queue. The loader
# should load the checkpoint and send the weights in messages in the
# following order; the saver should receive them in this order and
# save the checkpoint. A message consists of a python dictionary with
# a "name" for error checking and an entry for each tensor as
# indicated below. Note that the weights sent over the queue are the
# full model weights, nothing split.
#
# If the loader ever sends "exit" to the queue, that means something
# went wrong and it is exiting.
#
# - Metadata Namespace with the following attributes:
#     model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
#     num_layers - Number of transformer layers
#     hidden_size
#     seq_length
#     num_attention_heads
#     max_position_embeddings
#     tokenizer_type
#     iteration
#     params_dtype
#     bert_binary_head - Used only if model_type is BERT
#     previous_tensor_parallel_size - Optional
#     previous_pipeline_parallel_size - Optional
#     true_vocab_size
#     make_vocab_size_divisible_by
#     consumed_train_samples
#     consumed_valid_samples
# - Messages, in order:
# {
#   "name": "embeddings"
#   "position embeddings"
#   "word embeddings"
# }
# (for each transformer layer):
# {
#   "name": "transformer layer N"
#   "input layernorm weight"
#   "input layernorm bias"
#   "qkv weight"
#   "qkv bias"
#   "dense weight"
#   "dense bias"
#   "post layernorm weight"
#   "post layernorm bias"
#   "mlp l0 weight"
#   "mlp l0 bias"
#   "mlp l1 weight"
#   "mlp l1 bias"
# }
# {
#   "name": "final layer norm"
#   "weight"
#   "bias"
# }
# (if present, i.e. for BERT):
# {
#   "name": "pooler"
#   "weight"
#   "bias"
# }
# {
#   "name": "lm head"
#   "dense weight"
#   "dense bias"
#   "layernorm weight"
#   "layernorm bias"
# }
# {
#   "name": "binary head"
#   "weight"
#   "bias"
# }
# - "done"

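# For reference, a loader obeying this protocol might look like the sketch
# below. It is illustrative only and never called by this script: the tiny
# shapes, the zero tensors, and the 'NullTokenizer' value are placeholders
# standing in for a real checkpoint reader.
def _example_add_arguments(parser):
    # A real loader registers its own flags here (checkpoint format, etc.).
    parser.add_argument_group(title='example loader')


def _example_load_checkpoint(queue, args):
    import torch  # deferred import, local to this example

    # The metadata Namespace goes over the queue first.
    md = argparse.Namespace(
        model_type='GPT', num_layers=1, hidden_size=8, seq_length=16,
        num_attention_heads=2, max_position_embeddings=16,
        tokenizer_type='NullTokenizer', iteration=0,
        params_dtype=torch.float32, bert_binary_head=False,
        true_vocab_size=None, make_vocab_size_divisible_by=128,
        consumed_train_samples=0, consumed_valid_samples=0)
    queue.put(md)
    # Then one dictionary message per section, carrying the full
    # (unsplit) weights in the order documented above.
    queue.put({"name": "embeddings",
               "position embeddings": torch.zeros(16, 8),
               "word embeddings": torch.zeros(32, 8)})
    # ... one message per transformer layer, then the final layer norm ...
    queue.put("done")

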
def load_plugin(plugin_type, name):
    """Import the loader/saver module named '<plugin_type>_<name>',
    falling back to plain '<name>' if no such module exists."""
    module_name = f"{plugin_type}_{name}"
    try:
        plugin = importlib.import_module(module_name)
    except ModuleNotFoundError:
        module_name = name
        try:
            plugin = importlib.import_module(module_name)
        except ModuleNotFoundError:
            sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")

    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")

    print(f"Loaded {module_name} as the {plugin_type}.")
    return plugin


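# For example, load_plugin('loader', 'megatron') first tries to import a
# module named "loader_megatron" and falls back to plain "megatron"; both
# lookups use the normal Python path, so plugin files must be importable.

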
def main():
    parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
                                     allow_abbrev=False, conflict_handler='resolve')

    parser.add_argument('--model-type', type=str, required=True,
                        choices=['GPT', 'BERT'],
                        help='Type of the model')
    parser.add_argument('--loader', type=str, default='megatron',
                        help='Module name to load checkpoint, should be on python path')
    parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on python path')
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50,
                        help='Maximum number of tensors in the queue')
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    # Parse only the known args first, so each plugin can register its own
    # arguments before the full parse.
    known_args, _ = parser.parse_known_args()
    loader = load_plugin('loader', known_args.loader)
    saver = load_plugin('saver', known_args.saver)

    loader.add_arguments(parser)
    saver.add_arguments(parser)

    args = parser.parse_args()

    # The bounded queue applies backpressure: the loader blocks once
    # max-queue-size tensors are waiting on the saver.
    queue = mp.Queue(maxsize=args.max_queue_size)

    print("Starting saver...")
    saver_proc = mp.Process(target=saver.save_model_checkpoint, args=(queue, args))
    saver_proc.start()

    print("Starting loader...")
    loader.load_checkpoint(queue, args)

    print("Waiting for saver to complete...")
    saver_proc.join()


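# Example invocation (hypothetical: the script name and directories below
# are placeholders for whatever this file is called in your checkout):
#
#   python convert_util.py --model-type GPT \
#       --loader megatron --saver megatron \
#       --load-dir ./ckpt_in --save-dir ./ckpt_out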
if __name__ == '__main__':
    main()