# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
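"""Text-generation inference entry point.

Builds a GPT model (Megatron Core or legacy, depending on the parsed
arguments) and dispatches ModelLink text-generation tasks.
"""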
from typing import Union

# Importing modellink applies its Megatron adaptations (patches) as a side effect.
import modellink
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec, \
    get_gpt_layer_local_spec
from megatron.core.transformer.spec_utils import import_module
from megatron.training import get_args, print_rank_0
from megatron.legacy.model import GPTModel
from megatron.training.initialize import initialize_megatron
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml

from modellink.tasks.inference.text_generation.infer_base import task_factory, add_text_generate_args
from modellink.tasks.inference.text_generation.module import GPTModelInfer, MegatronModuleForCausalLM


def model_provider(pre_process=True, post_process=True) -> Union[GPTModelInfer, GPTModel]:
    """Builds the model.

    If args.use_mcore_models is True, this returns a Megatron Core GPT model;
    otherwise it returns the legacy GPT model.

    Args:
        pre_process (bool, optional): Set to True if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to True if you want to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModelInfer, GPTModel]: The returned model.
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"

    print_rank_0('building GPT model ...')
    # Experimental: load arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_mcore_models:
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if use_te:
                transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                    args.num_experts, args.moe_grouped_gemm)
            else:
                transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm)

        model = GPTModelInfer(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            # Inference needs the full logits on every rank, so the output is
            # gathered rather than kept split across tensor-parallel ranks.
            parallel_output=False,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
        )
    else:
        if args.context_parallel_size != 1:
            raise ValueError("Context parallelism is only supported with Megatron Core!")

        model = GPTModel(
            config,
            parallel_output=False,
            pre_process=pre_process,
            post_process=post_process
        )

    return model
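
# Which model model_provider builds is driven by standard Megatron flags
# (flag spellings assumed from Megatron-LM's argument parser):
#   --use-mcore-models                            -> GPTModelInfer (Megatron Core path)
#   --transformer-impl {transformer_engine,local} -> Transformer Engine vs. local layer spec
#   --spec <module>                               -> custom layer spec loaded via import_module
# Without --use-mcore-models, the legacy GPTModel is built.
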
def main():
    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()

    model = MegatronModuleForCausalLM.from_pretrained(
        model_provider=model_provider,
        pretrained_model_name_or_path=args.load
    )

    task_factory(args, model)


if __name__ == "__main__":
    main()
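
# A hypothetical launch sketch (script name, paths, and flag values are
# placeholders; the maintained launch scripts live under the repo's
# examples/ directory):
#   torchrun --nproc_per_node 8 inference.py \
#       --use-mcore-models \
#       --load ./model_weights \
#       --tokenizer-name-or-path ./tokenizer \
#       --max-new-tokens 256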