ModelLink2/inference.py

# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import modellink
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec, \
get_gpt_layer_local_spec
from megatron.core.transformer.spec_utils import import_module
from megatron.training import get_args, print_rank_0
from megatron.legacy.model import GPTModel
from megatron.training.initialize import initialize_megatron
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from modellink.tasks.inference.text_generation.infer_base import task_factory, add_text_generate_args
from modellink.tasks.inference.text_generation.module import GPTModelInfer, MegatronModuleForCausalLM


def model_provider(pre_process=True, post_process=True) -> Union[GPTModelInfer, GPTModel]:
    """Build the model.

    If use_mcore_models is set to True, the mcore GPT model is returned;
    otherwise the legacy GPT model is returned.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModelInfer, GPTModel]: The returned model.
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"

    print_rank_0('building GPT model ...')

    # Experimental: load arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_mcore_models:
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if use_te:
                transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm)
            else:
                transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm)

        model = GPTModelInfer(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=False,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
        )
    else:
        if args.context_parallel_size != 1:
            raise ValueError("Context parallelism is only supported with Megatron Core!")

        model = GPTModel(
            config,
            parallel_output=False,
            pre_process=pre_process,
            post_process=post_process
        )

    return model


def main():
    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()

    model = MegatronModuleForCausalLM.from_pretrained(
        model_provider=model_provider,
        pretrained_model_name_or_path=args.load
    )

    task_factory(args, model)


if __name__ == "__main__":
    main()
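

# Illustrative launch sketch (not taken from this repo): the script is normally
# started through a Megatron-style distributed launcher. The node count, checkpoint
# path, and tokenizer flags below are assumptions and must match your converted
# weights and local setup.
#
#   torchrun --nproc_per_node 8 inference.py \
#       --use-mcore-models \
#       --load ./ckpt \
#       --tokenizer-type PretrainedFromHF \
#       --tokenizer-name-or-path ./model_from_hf \
#       --max-new-tokens 256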