!1455 prompt-type inference adaptation

Merge pull request !1455 from wucong/addgen
This commit is contained in:
wucong 2024-08-13 06:53:26 +00:00 committed by i-robot
parent 6f42710f0c
commit 652574b6c4
37 changed files with 809 additions and 161 deletions

View File

@ -59,5 +59,5 @@ torchrun $DISTRIBUTED_ARGS inference.py \
--lora-r 16 \
--lora-alpha 32 \
--lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
| tee logs/generate_gemma_7b.log

View File

@ -53,5 +53,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \
--lora-alpha 32 \
--lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \
--make-vocab-size-divisible-by 1 \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
| tee logs/generate_llama_13b_lora.log

View File

@ -60,5 +60,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \
--lora-r 16 \
--lora-alpha 32 \
--make-vocab-size-divisible-by 1 \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
| tee logs/generate_llama_33b_lora.log

View File

@ -53,5 +53,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \
--lora-alpha 32 \
--lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \
--make-vocab-size-divisible-by 1 \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
| tee logs/generate_llama_65b_lora.log

View File

@ -48,5 +48,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \
--no-gradient-accumulation-fusion \
--exit-on-missing-checkpoint \
--make-vocab-size-divisible-by 1 \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
| tee logs/generate_llama_65b.log

View File

@ -53,5 +53,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \
--lora-alpha 32 \
--lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \
--make-vocab-size-divisible-by 1 \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
| tee logs/generate_llama_7b_lora.log

View File

@ -59,5 +59,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \
--lora-r 16 \
--lora-alpha 32 \
--make-vocab-size-divisible-by 1 \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
| tee logs/generate_llama2_13b_lora.log

View File

@ -62,5 +62,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \
--make-vocab-size-divisible-by 1 \
--group-query-attention \
--num-query-groups 8 \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
| tee logs/generate_llama2_34b_lora.log

View File

@ -55,5 +55,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \
--lora-alpha 32 \
--lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \
--make-vocab-size-divisible-by 1 \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
| tee logs/generate_llama2_70b_lora.log

View File

@ -55,5 +55,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \
--lora-alpha 32 \
--lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \
--make-vocab-size-divisible-by 1 \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
| tee logs/generate_llama2_7b_lora.log

View File

@ -65,5 +65,5 @@ torchrun $DISTRIBUTED_ARGS inference.py \
$GPT_ARGS \
$MOE_ARGS \
--distributed-backend nccl \
--inference-prompt-type mixtral \
--prompt-type mixtral \
| tee logs/generate_mixtral.log

View File

@ -61,5 +61,5 @@ torchrun $DISTRIBUTED_ARGS inference.py \
$GPT_ARGS \
$MOE_ARGS \
--distributed-backend nccl \
--inference-prompt-type mixtral \
--prompt-type mixtral \
| tee logs/generate_mixtral.log

View File

@ -66,5 +66,5 @@ torchrun $DISTRIBUTED_ARGS inference.py \
$GPT_ARGS \
$MOE_ARGS \
--distributed-backend nccl \
--inference-prompt-type llama \
--prompt-type mixtral \
| tee logs/generate_mixtral.log

View File

@ -52,7 +52,7 @@ torchrun $DISTRIBUTED_ARGS inference.py \
--lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \
--lora-r 16 \
--lora-alpha 32 \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
--normalization RMSNorm \
--group-query-attention \
--hidden-dropout 0 \

View File

@ -52,7 +52,7 @@ torchrun $DISTRIBUTED_ARGS inference.py \
--lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \
--lora-r 16 \
--lora-alpha 32 \
--inference-prompt-type 'alpaca' \
--prompt-type 'alpaca' \
--normalization RMSNorm \
--position-embedding-type rope \
--norm-epsilon 1e-6 \

View File

@ -100,20 +100,7 @@ def main():
pretrained_model_name_or_path=args.load
)
system_template = ""
dialog_template = "{instruction}"
if args.inference_prompt_type == 'alpaca':
system_template = "Below is an instruction that describes a task, paired with an input that provides further " \
"context. Write a response that appropriately completes the request. " \
"Please note that you need to think through your response logically and step by step.\n\n"
dialog_template = "### Instruction:\n{instruction}\n\n### Response:"
elif args.inference_prompt_type == 'mixtral':
system_template = "<s>"
dialog_template = "[INST] {instruction} [/INST] "
task_factory(args, model, system_template=system_template, dialog_template=dialog_template)
task_factory(args, model)
if __name__ == "__main__":

View File

@ -405,6 +405,10 @@ def _add_training_args(parser):
help='enable deterministic computing for npu')
group.add_argument('--jit-compile', action='store_true', default=False,
help='Setting jit compile mode to True')
group.add_argument('--prompt-type', type=str, default=None,
choices=['default', 'empty', 'chatglm2', 'chatglm3', 'chatglm3_system', 'chatml', 'chatml_de', 'qwen', 'llama3', 'llama2', 'mistral', 'mixtral', 'gemma'],
help='Which template to use for constructing prompts in training/inference.'
'e.g., "qwen"')
return parser
@ -519,6 +523,14 @@ def _valid_lora(args):
raise AssertionError('lora_fusion for CCLoRA is forbidden without sequence_parallel.')
def _validate_inference_args(args):
if args.prompt_type is not None and hasattr(args, "hf_chat_template") and args.hf_chat_template:
raise AssertionError('Prompt-type is forbidden when using the Huggingface chat template.')
if hasattr(args, "history_turns") and args.history_turns < 0:
raise AssertionError('History turns of chat must be greater than or equal to 0.')
def _validate_moe_expert_capacity_factor(args):
if args.moe_expert_capacity_factor is not None:
if args.moe_token_dispatcher_type != "alltoall":
@ -615,6 +627,7 @@ def validate_args_decorator(megatron_validate_args):
_validate_position_embedding(args)
_validate_high_availability(args)
_valid_lora(args)
_validate_inference_args(args)
_validate_moe_expert_capacity_factor(args)
_validate_mla(args)
_validate_yarn(args)
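Illustrative only, not part of this diff: a minimal standalone sketch of how the new --prompt-type argument and the _validate_inference_args check fit together. It uses a plain argparse parser for illustration; the real code hooks into Megatron's argument builder.

    import argparse

    # Hypothetical standalone parser mirroring the new arguments above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt-type', type=str, default=None,
                        choices=['default', 'empty', 'chatglm2', 'chatglm3', 'chatglm3_system', 'chatml',
                                 'chatml_de', 'qwen', 'llama3', 'llama2', 'mistral', 'mixtral', 'gemma'],
                        help='Which template to use for constructing prompts in training/inference, e.g. "qwen".')
    parser.add_argument('--hf-chat-template', action='store_true', default=False,
                        help='Use the Huggingface chat template instead of a built-in prompt template.')
    parser.add_argument('--history-turns', type=int, default=3,
                        help='Chat turns of histories.')
    args = parser.parse_args(['--prompt-type', 'llama2'])

    # Same constraints as _validate_inference_args: the two template sources are
    # mutually exclusive, and the history length may not be negative.
    if args.prompt_type is not None and args.hf_chat_template:
        raise AssertionError('Prompt-type is forbidden when using the Huggingface chat template.')
    if args.history_turns < 0:
        raise AssertionError('History turns of chat must be greater than or equal to 0.')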

View File

@ -18,6 +18,7 @@ import time
import shutil
import logging
import subprocess
from copy import deepcopy
import torch
from torch import distributed as dist
@ -36,14 +37,13 @@ def add_text_generate_args(parser):
group.add_argument("--temperature", type=float, default=0.7, help='Sampling temperature.')
group.add_argument("--max-length", type=int, default=256, help='Total length of text.')
group.add_argument("--max-new-tokens", type=int, default=128, help='Size of the output generated text.')
group.add_argument('--inference-prompt-type', type=str, default='llama',
help="choose the prompt type for inference")
group.add_argument('--hf-chat-template', action='store_true', default=False,
help="Using Huggingface chat template")
help="Using Huggingface chat template")
group.add_argument('--add-eos-token', nargs='+', type=str, default=[],
help="Use additional eos tokens")
group.add_argument('--use-kv-cache', action="store_true", default=False,
help="Use kv cache to accelerate inference")
group.add_argument('--history-turns', type=int, default=3, help='Chat turns of histories.')
return parser
@ -56,7 +56,7 @@ def print_flush(prev_str, curr_str):
sys.stdout.write(difference)
def task_factory(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"):
def task_factory(args, model):
task_map = {
"greedy": task_greedy_search,
"do_sample": task_do_sample,
@ -85,40 +85,32 @@ def task_factory(args, model, tokenizer=None, system_template="", dialog_templat
task_map.get(task)(
args,
model,
tokenizer,
system_template=system_template,
dialog_template=dialog_template
)
def task_greedy_search(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"):
def task_greedy_search(args, model):
"""Greedy Search"""
prompt = "how are you?"
template = system_template + dialog_template
instruction = template.format(instruction=prompt)
instruction = "how are you?"
t = time.time()
output = model.generate(
instruction,
max_new_tokens=args.max_new_tokens,
tokenizer=tokenizer,
stream=False
)
if dist.get_rank() == 0:
logging.info("\n=============== Greedy Search ================")
logging.info("\nYou:\n%s\n\nModelLink:\n%s", prompt, output)
logging.info("\nYou:\n%s\n\nModelLink:\n%s", instruction, output)
logging.info("==============================================")
logging.info("\nElapsed: %ss", round(time.time() - t, 2))
dist.barrier()
def task_do_sample(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"):
def task_do_sample(args, model):
"""Do Sample"""
prompt = "how are you?"
template = system_template + dialog_template
instruction = template.format(instruction=prompt)
instruction = "how are you?"
t = time.time()
output = model.generate(
@ -127,24 +119,21 @@ def task_do_sample(args, model, tokenizer=None, system_template="", dialog_templ
top_k=args.top_k,
top_p=args.top_p,
max_new_tokens=args.max_new_tokens,
tokenizer=tokenizer,
stream=False
)
if dist.get_rank() == 0:
logging.info("\n================ Do Sample =================")
logging.info("\nYou:\n%s\n\nModelLink:\n%s", prompt, output)
logging.info("\nYou:\n%s\n\nModelLink:\n%s", instruction, output)
logging.info("============================================")
logging.info("\nElapsed: %ss", round(time.time() - t, 2))
dist.barrier()
def task_beam_search(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"):
def task_beam_search(args, model):
"""Beam Search"""
prompt = "how are you?"
template = system_template + dialog_template
instruction = template.format(instruction=prompt)
instruction = "how are you?"
t = time.time()
output = model.generate(
@ -153,24 +142,21 @@ def task_beam_search(args, model, tokenizer=None, system_template="", dialog_tem
top_k=args.top_k,
top_p=args.top_p,
max_new_tokens=args.max_new_tokens,
tokenizer=tokenizer,
stream=False
)
if dist.get_rank() == 0:
logging.info("\n=============== Beam Search =================")
logging.info("\nYou:\n%s\n\nModelLink:\n%s", prompt, output)
logging.info("\nYou:\n%s\n\nModelLink:\n%s", instruction, output)
logging.info("=============================================")
logging.info("\nElapsed: %ss", round(time.time() - t, 2))
dist.barrier()
def task_beam_search_with_sampling(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"):
def task_beam_search_with_sampling(args, model):
"""Beam Search with sampling"""
prompt = "how are you?"
template = system_template + dialog_template
instruction = template.format(instruction=prompt)
instruction = "how are you?"
t = time.time()
output = model.generate(
@ -180,24 +166,21 @@ def task_beam_search_with_sampling(args, model, tokenizer=None, system_template=
top_k=args.top_k,
top_p=args.top_p,
max_new_tokens=args.max_new_tokens,
tokenizer=tokenizer,
stream=False
)
if dist.get_rank() == 0:
logging.info("\n======== Beam Search with sampling ==========")
logging.info("\nYou:\n%s\n\nModelLink:\n%s", prompt, output)
logging.info("\nYou:\n%s\n\nModelLink:\n%s", instruction, output)
logging.info("=============================================")
logging.info("\nElapsed: %ss", round(time.time() - t, 2))
dist.barrier()
def task_return_output_log_probs(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"):
def task_return_output_log_probs(args, model):
"""Returns the probability distribution of tokens"""
prompt = "how are you?"
template = system_template + dialog_template
instruction = template.format(instruction=prompt)
instruction = "how are you?"
t = time.time()
tokens, log_probs = model.generate(
@ -207,7 +190,6 @@ def task_return_output_log_probs(args, model, tokenizer=None, system_template=""
top_p=args.top_p,
temperature=args.temperature,
max_new_tokens=args.max_new_tokens,
tokenizer=tokenizer,
stream=False,
detokenize=False,
return_output_log_probs=True
@ -221,7 +203,6 @@ def task_return_output_log_probs(args, model, tokenizer=None, system_template=""
top_p=args.top_p,
temperature=args.temperature,
max_new_tokens=args.max_new_tokens,
tokenizer=tokenizer,
stream=False,
detokenize=False,
return_output_log_probs=True,
@ -238,32 +219,77 @@ def task_return_output_log_probs(args, model, tokenizer=None, system_template=""
dist.barrier()
def task_chat(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"):
"""Interactive dialog mode with multiple rounds of conversation"""
def chat_get_instruction(args, histories_no_template, histories_template, prompt):
instruction = None
def get_context(content):
res = system_template
res = ""
for q, r in content:
if r is None:
res += dialog_template.format(instruction=q)
res += q
else:
res += dialog_template.format(instruction=q) + r
res += q + r
return res
histories = []
columns, rows = shutil.get_terminal_size()
output, prompt, instruction = "", "", ""
input_template, response_template = "\n\nYou >> ", "\nModelLink:\n"
if args.hf_chat_template or args.prompt_type is not None:
# Trim the conversation history to the configured number of turns; there may be a better solution
if len(histories_template) > 2 * args.history_turns:
histories_template.pop(0)
histories_template.pop(0)
histories_template.append({"role": "user", "content": prompt})
# Using the llamafactory template: build the intermediate message format ourselves
instruction = deepcopy(histories_template)
else:
# Not using the llamafactory template: keep the old process
histories_no_template.append((prompt, None))
instruction = get_context(histories_no_template)
histories_no_template.pop()
return instruction
def chat_print_and_update_histories(args, responses, histories_no_template, histories_template, prompt):
response_template = "\nModelLink:\n"
output = ""
if dist.get_rank() == 0:
sys.stdout.write(response_template)
prev = ""
for output in responses:
if dist.get_rank() == 0:
curr = output.replace("�", "")
print_flush(prev, curr)
prev = curr
# old process
if args.hf_chat_template or args.prompt_type is not None:
histories_template.append({"role": "assistant", "content": output})
else:
histories_no_template.append((prompt, output))
if len(histories_no_template) > 3:
histories_no_template.pop()
return output
def task_chat(args, model):
"""Interactive dialog mode with multiple rounds of conversation"""
histories_no_template = []
histories_template = []
instruction = None
prompt = ""
input_template = "\n\nYou >> "
command_clear = ["clear"]
messages = []
if args.hf_chat_template:
from megatron.training import get_tokenizer
tokenizer = get_tokenizer().tokenizer
while True:
terminate_runs = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
if dist.get_rank() == 0:
if not histories:
if not histories_no_template and not histories_template:
logging.info("===========================================================")
logging.info("1. If you want to quit, please entry one of [q, quit, exit]")
logging.info("2. To create new title, please entry one of [clear, new]")
@ -277,25 +303,15 @@ def task_chat(args, model, tokenizer=None, system_template="", dialog_template="
if prompt.strip() in ["clear", "new"]:
subprocess.call(command_clear)
histories = []
messages = []
histories_no_template = []
histories_template = []
continue
if not prompt.strip():
continue
instruction = chat_get_instruction(args, histories_no_template, histories_template, prompt)
histories.append((prompt, None))
instruction = get_context(histories)
histories.pop()
messages.append(
{"role": "user", "content": prompt}
)
if args.hf_chat_template:
instruction = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
dist.all_reduce(terminate_runs)
dist.barrier()
@ -307,23 +323,10 @@ def task_chat(args, model, tokenizer=None, system_template="", dialog_template="
do_sample=True,
top_k=args.top_k,
top_p=args.top_p,
tokenizer=tokenizer,
tokenizer=None,
temperature=args.temperature,
max_new_tokens=args.max_new_tokens,
stream=True
)
if dist.get_rank() == 0:
sys.stdout.write(response_template)
prev = ""
for output in responses:
if dist.get_rank() == 0:
curr = output.replace("�", "")
print_flush(prev, curr)
prev = curr
histories.append((prompt, output))
messages.append(
{"role": "assistant", "content": output}
)
chat_print_and_update_histories(args, responses, histories_no_template, histories_template, prompt)
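Illustrative only, not part of this diff: the two history shapes the refactored chat loop keeps in parallel, using the role names from the code above.

    # With --prompt-type or --hf-chat-template, history is a flat role/content list
    # that the template (or tokenizer.apply_chat_template) renders into the prompt:
    histories_template = [
        {"role": "user", "content": "how are you?"},
        {"role": "assistant", "content": "I am fine, thank you."},
    ]

    # Without a template, history stays as (query, response) pairs that get_context()
    # simply concatenates into one instruction string, as in the old code path:
    histories_no_template = [
        ("how are you?", "I am fine, thank you."),
    ]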

View File

@ -18,7 +18,6 @@ import abc
import logging
from typing import Optional, Union
import torch
from torch import distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@ -27,6 +26,8 @@ from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.training import get_args, global_vars
from megatron.core import parallel_state
from modellink.tasks.preprocess.templates import Template, get_model_template
class MegatronModuleForCausalLMABC(torch.nn.Module, abc.ABC):
"""
@ -189,6 +190,9 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
self.greedy_search_or_sampling = greedy_search_or_sampling
self.beam_search_in_sampling = beam_search
self.broadcast_float_list = broadcast_float_list
self.template = None
if hasattr(args, "prompt_type") and args.prompt_type is not None:
self.template = get_model_template(args.prompt_type.strip())
@staticmethod
def _ids_check(ids, tokenizer):
@ -274,6 +278,14 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
args.master_rank = master_rank
args.micro_batch_size = len(context_tokens)
stop_token = [args.eos_id] + stop_ids
if hasattr(args, "prompt_type") and args.prompt_type is not None:
stop_ids = stop_ids + [self.tokenizer.convert_tokens_to_ids(token) for token in self.template.stop_words] + \
[self.tokenizer.eos_token_id]
stop_token = [args.eos_id] + stop_ids
# =======================================
# Get the streaming tokens generator
# =======================================
@ -282,7 +294,7 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
args.model[0],
context_tokens,
beam_size=self.num_beams,
stop_token=[args.eos_id] + stop_ids,
stop_token=stop_token,
num_return_gen=self.num_return_sequences,
length_penalty=self.length_penalty
)
@ -361,29 +373,82 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
def _encode_no_template(self, input_ids):
context_tokens = [[]]
if isinstance(input_ids, str):
context_tokens = [self.tokenizer.encode(input_ids)]
elif torch.is_tensor(input_ids):
if len(input_ids.shape) == 1:
context_tokens = input_ids.unsqueeze(0).numpy().tolist()
elif len(input_ids.shape) == 2:
context_tokens = input_ids.numpy().tolist()
elif isinstance(input_ids, (tuple, list)):
if len(input_ids) and isinstance(input_ids[0], (tuple, list)):
context_tokens = input_ids
elif len(input_ids) and isinstance(input_ids[0], int):
context_tokens = [input_ids]
elif len(input_ids) and isinstance(input_ids[0], str):
context_tokens = [self.tokenizer.encode(val) for val in input_ids]
else:
raise TypeError("Please check input_ids in correct type.")
return context_tokens
def _encode_by_template(self, input_ids):
context_tokens = []
if input_ids is None:
return [[]]
response_prompt = [{"role": "assistant", "content": ""}]
if len(input_ids) and isinstance(input_ids, str):
paired_messages = [{"role": "user", "content": "{}".format(input_ids)}] + response_prompt
tokens, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=paired_messages, tools="")
context_tokens.append(tokens)
elif len(input_ids) and isinstance(input_ids[0], (dict)):
paired_messages = input_ids + response_prompt
tokens, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=paired_messages, tools="")
context_tokens.append(tokens)
elif len(input_ids) and isinstance(input_ids[0], (str)):
for query in input_ids:
paired_messages = [{"role": "user", "content": "{}".format(query)}] + response_prompt
tokens, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=paired_messages, tools="")
context_tokens.append(tokens)
elif len(input_ids) and isinstance(input_ids[0], (tuple, list)):
for val in input_ids:
if len(val) and isinstance(val, (tuple, list)):
paired_messages = val + response_prompt
tokens, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=paired_messages, tools="")
context_tokens.append(tokens)
else:
raise TypeError("Please check input_ids in correct type.")
return context_tokens if len(context_tokens) > 0 else [context_tokens]
def _tokenize(self, input_ids):
context_tokens = [[]]
broadcast_rank = torch.zeros(dist.get_world_size(),
dtype=torch.int64,
device=torch.device(torch.cuda.current_device()))
if input_ids is not None and len(input_ids) > 0:
if isinstance(input_ids, str):
context_tokens = [self.tokenizer.encode(input_ids)]
elif torch.is_tensor(input_ids):
if len(input_ids.shape) == 1:
context_tokens = input_ids.unsqueeze(0).numpy().tolist()
elif len(input_ids.shape) == 2:
context_tokens = input_ids.numpy().tolist()
elif isinstance(input_ids, (tuple, list)):
if len(input_ids) and isinstance(input_ids[0], (tuple, list)):
context_tokens = input_ids
elif len(input_ids) and isinstance(input_ids[0], int):
context_tokens = [input_ids]
elif len(input_ids) and isinstance(input_ids[0], str):
context_tokens = [self.tokenizer.encode(val) for val in input_ids]
args = get_args()
if args.hf_chat_template:
if not hasattr(self.tokenizer, "apply_chat_template"):
raise AssertionError('The tokenizer has no Huggingface chat template. Please use a chat model.')
context_tokens = [self.tokenizer.apply_chat_template(
input_ids,
tokenize=True,
add_generation_prompt=True
)]
elif self.template is None:
context_tokens = self._encode_no_template(input_ids)
else:
raise TypeError("Please check input_ids in correct type.")
context_tokens = self._encode_by_template(input_ids)
broadcast_rank[dist.get_rank()] = 1
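Illustrative only, not part of this diff: input shapes that _tokenize now routes to the two encoders; `model` is assumed to be an already initialized MegatronModuleForCausalLM.

    # With --prompt-type set, _encode_by_template wraps each query in the chat template:
    model.generate("how are you?")                      # a single query string
    model.generate(["first query", "second query"])     # a batch of query strings
    model.generate([                                    # one multi-turn conversation
        {"role": "user", "content": "how are you?"},
        {"role": "assistant", "content": "Fine, thanks."},
        {"role": "user", "content": "and you?"},
    ])
    # Without a template, _encode_no_template still accepts raw strings, lists of
    # token ids, or 1-D/2-D tensors exactly as before.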

View File

@ -26,8 +26,7 @@ from datasets import load_dataset
from megatron.core.datasets import indexed_dataset
from modellink.tasks.preprocess.templates import Prompter, AlpacaTemplate
from modellink.tasks.preprocess.templates import get_template_and_fix_tokenizer
from modellink.tasks.preprocess.templates import Prompter, AlpacaTemplate, get_model_template
from .utils import get_dataset_list, get_handler_dataset_attr, load_single_dataset, merge_dataset, align_dataset
@ -201,7 +200,7 @@ class LlamaFactoryInstructionHandler(BaseDatasetHandler):
self.args.output_prefix = self.args.output_prefix + "_packed"
self.ignored_label = -100
self.is_multi_turn = True
self.llama_factory_template = get_template_and_fix_tokenizer(tokenizer.tokenizer, args.lla_fact_ins_template.strip())
self.llama_factory_template = get_model_template(args.prompt_type.strip())
def _format_msg(self, sample):
return sample

View File

@ -88,6 +88,7 @@ class Template:
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
format_prefix: "Formatter"
default_system: str
stop_words: List[str]
efficient_eos: bool
@ -142,16 +143,18 @@ class Template:
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp
Turn 0: prefix + system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
if i == 0:
elements += self.format_prefix.apply()
if system or tools:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
@ -182,7 +185,7 @@ class Template:
if len(elem) != 0:
token_ids += tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids += [tokenizerr.convert_tokens_to_ids(elem.get("token"))]
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
if "bos_token" in elem and tokenizer.bos_token_id is not None:
token_ids += [tokenizer.bos_token_id]
@ -241,9 +244,11 @@ class Llama2Template(Template):
for i, message in enumerate(messages):
elements = []
system_text = ""
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
if i == 0:
elements += self.format_prefix.apply()
if system or tools:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
elif i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
@ -270,16 +275,21 @@ def get_templates() -> Dict[str, Template]:
return templates
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
):
def get_model_template(name):
if name is None:
template = templates["empty"] # placeholder
else:
template = get_templates().get(name, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
return template
def fix_model_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
):
template = get_model_template(name)
stop_words = template.stop_words
if template.replace_eos:
@ -321,6 +331,7 @@ def _register_template(
format_observation: Optional["Formatter"] = None,
format_tools: Optional["Formatter"] = None,
format_separator: Optional["Formatter"] = None,
format_prefix: Optional["Formatter"] = None,
default_system: str = "",
stop_words: List[str] = [],
efficient_eos: bool = False,
@ -360,6 +371,7 @@ def _register_template(
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
default_prefix_formatter = EmptyFormatter()
templates[name] = template_class(
format_user=format_user or default_user_formatter,
format_assistant=format_assistant or default_assistant_formatter,
@ -368,6 +380,7 @@ def _register_template(
format_observation=format_observation or format_user or default_user_formatter,
format_tools=format_tools or default_tool_formatter,
format_separator=format_separator or default_separator_formatter,
format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system,
stop_words=stop_words,
efficient_eos=efficient_eos,
@ -418,6 +431,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
jinja_template = ""
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
if prefix:
jinja_template += "{{ " + prefix + " }}"
if template.default_system:
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
@ -455,7 +472,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
_register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
force_system=True,
@ -466,17 +483,18 @@ _register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
force_system=True,
)
_register_template(
name="chatglm3_system",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@ -544,4 +562,62 @@ _register_template(
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
)
_register_template(
name="llama3",
format_user=StringFormatter(
slots=[
(
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_observation=StringFormatter(
slots=[
(
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>"],
replace_eos=True,
)
_register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="mixtral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
)
_register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
)
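Illustrative only, not part of this diff: a hedged sketch of rendering one turn through a registered template. The tokenizer path is a placeholder; any matching Huggingface tokenizer would do.

    from transformers import AutoTokenizer
    from modellink.tasks.preprocess.templates import get_model_template

    tokenizer = AutoTokenizer.from_pretrained("/path/to/llama2-hf")  # hypothetical local path
    template = get_model_template("llama2")

    # encode_oneturn takes the full message list ending with an (empty) assistant slot,
    # the same shape _encode_by_template builds for inference.
    messages = [
        {"role": "user", "content": "Give me three tips for staying healthy."},
        {"role": "assistant", "content": ""},
    ]
    prompt_ids, _ = template.encode_oneturn(tokenizer=tokenizer, messages=messages, tools="")
    print(tokenizer.decode(prompt_ids))  # roughly: <s>[INST] Give me three tips ... [/INST]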

View File

@ -19,6 +19,7 @@ from transformers import AutoTokenizer
from megatron.training.tokenizer import build_tokenizer as megatron_build_tokenizer
from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding
from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
from modellink.tasks.preprocess.templates import fix_model_tokenizer
def build_tokenizer(args):
@ -55,6 +56,9 @@ def build_tokenizer(args):
else:
tokenizer = TokenizerAdaptor(megatron_build_tokenizer(args))
if hasattr(args, "prompt_type") and args.prompt_type is not None:
fix_model_tokenizer(tokenizer.tokenizer, args.prompt_type.strip())
return tokenizer

View File

@ -97,10 +97,10 @@ def add_data_args(parser):
group.add_argument('--keep-newlines', action='store_true',
help='Keep newlines between sentences when splitting.')
# LlamaFactory
group.add_argument('--lla-fact-ins-template', type=str, default=None,
choices=['chatglm2', 'chatglm3', 'chatglm3_system', 'chatml', 'chatml_de', 'default', 'empty', 'qwen'],
group.add_argument('--prompt-type', type=str, default=None,
choices=['default', 'empty', 'chatglm2', 'chatglm3', 'chatglm3_system', 'chatml', 'chatml_de', 'qwen', 'llama3', 'llama2', 'mistral', 'mixtral', 'gemma'],
help='Which template to use for constructing prompts in training.'
'ex: "qwen"')
'e.g., "qwen"')
group.add_argument("--interleave-probs", default=None,
help='Probabilities to sample data from datasets. Use commas to separate multiple datasets. '
'probabilities should sum to 1. ex: "0.1, 0.2, 0.3, 0.4"')

View File

@ -3,4 +3,5 @@
python tests/pipeline/llama3-8B/test_convert_ckpt_from_huggingface.py
pytest -s tests/pipeline/llama3-8B/test_generation.py
pytest -s tests/pipeline/llama3-8B/test_evaluation.py
pytest -s tests/pipeline/llama3-8B/test_evaluation.py
pytest -s tests/pipeline/llama3-8B/test_chat.py

View File

@ -41,6 +41,14 @@
"--attention-softmax-in-fp32"
],
"INFERENCE_HF_CHAT_PARAM": [
"--hf-chat-template"
],
"INFERENCE_PROMPT_CHAT_PARAM": [
"--prompt-type", "llama3"
],
"EVALUATION_PARAM": [
"--tokenizer-not-use-fast",
"--task-data-path", "/home/dataset/eval_dataset/mmlu/test",

View File

@ -0,0 +1,123 @@
import sys
import os
import nltk
import torch
from torch import distributed as dist
import torch_npu
from tests.common import DistributedTest
from utils import ParamConfig, assert_judge
import modellink
from megatron.legacy.model import GPTModel
from modellink.tasks.inference.text_generation.infer_base import add_text_generate_args, chat_get_instruction, chat_print_and_update_histories
class TestGeneration(DistributedTest):
world_size = 8
def init(self, config=ParamConfig, chat_type=None):
"""
initialize the environment and arguments
"""
sys.argv = [sys.argv[0]] + config.distributed_param + config.network_size + \
config.inference_param + config.auxiliary_param + config.tokenizer_param
if chat_type == "hf_chat":
sys.argv = sys.argv + config.inference_hf_chat_param
elif chat_type == "prompt_chat":
sys.argv = sys.argv + config.inference_prompt_chat_param
from megatron.training.initialize import initialize_megatron
os.environ.update({"CUDA_DEVICE_MAX_CONNECTIONS": "1"})
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'no_load_rng': True,
'no_load_optim': True})
from megatron.training import get_args
self.args = get_args()
def edit_distance_similarity(self, text1, text2):
"""
edit distance: to compare the similarity between two texts.
"""
distance = nltk.edit_distance(text1, text2)
try:
similarity = 1 - (distance / max(len(text1), len(text2)))
except ZeroDivisionError as e:
raise e
return similarity
def run_chat(self, model, turn0outputExpect):
histories_no_template = []
histories_template = []
instruction = None
test_questions = ["你能推荐几本深度学习的书吗?", "上面推荐的书建议学习顺序呢?", "9.11和9.9谁大?"]
turns = 0
while turns < 3:
prompt = test_questions[turns]
instruction = chat_get_instruction(self.args, histories_no_template, histories_template, prompt)
responses = model.generate(
instruction,
do_sample=True,
top_k=self.args.top_k,
top_p=self.args.top_p,
tokenizer=None,
temperature=self.args.temperature,
max_new_tokens=self.args.max_new_tokens,
stream=True
)
output = chat_print_and_update_histories(self.args, responses, histories_no_template, histories_template, prompt)
if torch.distributed.get_rank() == 0:
print("-------------------------------")
print(output)
if turns == 0:
similarity1 = self.edit_distance_similarity(output[:30], turn0outputExpect[0][:30])
similarity2 = self.edit_distance_similarity(output[:30], turn0outputExpect[1][:30])
print("similarity1:", similarity1)
print("similarity1:", similarity2)
assert_judge(max(similarity1, similarity2) > 0.75)
turns = turns + 1
def test_hf_chat(self):
"""Interactive dialog mode with multiple rounds of conversation"""
self.init(config=ParamConfig, chat_type="hf_chat")
from inference import model_provider
model = GPTModel.from_pretrained(
model_provider=model_provider,
pretrained_model_name_or_path=self.args.load
)
turn1outputExpect = []
turn1outputExpect1 = "Here are some highly recommended books on deep learning that can help you dive deeper into the subject:"
turn1outputExpect2 = '''Here are some highly recommended books for deep learning:\n\n**Foundational Books**\n\n1. **"Deep Learning" by Ian Goodfellow, Yoshua Bengio, and Aaron Courville**: This is the bible of deep learning.'''
turn1outputExpect.append(turn1outputExpect1)
turn1outputExpect.append(turn1outputExpect2)
self.run_chat(model, turn1outputExpect)
def test_prompt_type_chat(self):
"""Interactive dialog mode with multiple rounds of conversation"""
self.init(config=ParamConfig, chat_type="prompt_chat")
from inference import model_provider
model = GPTModel.from_pretrained(
model_provider=model_provider,
pretrained_model_name_or_path=self.args.load
)
turn1outputExpect = []
turn1outputExpect1 = "Here are some highly recommended books on deep learning that can help you dive deeper into the subject:"
turn1outputExpect2 = '''Here are some highly recommended books for deep learning:\n\n**Foundational Books**\n\n1. **"Deep Learning" by Ian Goodfellow, Yoshua Bengio, and Aaron Courville**: This is the bible of deep learning.'''
turn1outputExpect.append(turn1outputExpect1)
turn1outputExpect.append(turn1outputExpect2)
self.run_chat(model, turn1outputExpect)

View File

@ -29,6 +29,9 @@ class ParamConfig:
evaluation_param = config_file["EVALUATION_PARAM"]
auxiliary_param = config_file["AUXILIARY_PARAM"]
inference_hf_chat_param = config_file["INFERENCE_HF_CHAT_PARAM"]
inference_prompt_chat_param = config_file["INFERENCE_PROMPT_CHAT_PARAM"]
def assert_judge(expression):
if not expression:

View File

@ -52,8 +52,12 @@
"--data-path", "/home/dataset/tune-dataset-qwen-7B/alpaca",
"--split", "90,5,5",
"--train-iters", "5"
],
],
"DISTRIBUTED_PARAM_TP8_PP1": [
"--tensor-model-parallel-size", "8",
"--pipeline-model-parallel-size", "1"
],
"PROCESS_INSTRUCTION_DATA": [
"--input", "train-00000-of-00001-a09b74b3ef9c3b56, alpaca_zh, sharegpt1, sharegpt2",
@ -64,7 +68,7 @@
"--workers", "4",
"--log-interval", "1000",
"--append-eod",
"--lla-fact-ins-template", "qwen",
"--prompt-type", "qwen",
"--dataset-dir", "/home/dataset/tune-dataset-qwen-7B/lfhandler_tune_dataset/dataset/",
"--overwrite-cache"
],
@ -79,7 +83,7 @@
"--workers", "4",
"--log-interval", "1000",
"--append-eod",
"--lla-fact-ins-template", "qwen",
"--prompt-type", "qwen",
"--dataset-dir", "/home/dataset/tune-dataset-qwen-7B/lfhandler_tune_dataset/dataset/",
"--overwrite-cache",
"--interleave-probs", "0.1, 0.2, 0.3, 0.4",
@ -96,11 +100,52 @@
"--workers", "4",
"--log-interval", "1000",
"--append-eod",
"--lla-fact-ins-template", "qwen",
"--prompt-type", "qwen",
"--dataset-dir", "/home/dataset/tune-dataset-qwen-7B/lfhandler_tune_dataset/dataset/",
"--overwrite-cache",
"--interleave-probs", "0.1, 0.2, 0.3, 0.4",
"--mix-strategy", "interleave_over",
"--max-samples", "10"
],
"INFERENCE_PARAM": [
"--max-new-tokens", "256",
"--tokenizer-not-use-fast",
"--exit-on-missing-checkpoint",
"--attention-softmax-in-fp32",
"--prompt-type", "qwen",
"--seed", "42",
"--load", "/home/dataset/Qwen-7B-v0.1-tp8-pp1/"
],
"BEAM_SEARCH_AUXILIARY_PARAM": [
"--task", "beam_search",
"--top-p", "0.95",
"--top-k", "50"
],
"GREEDY_SEARCH_AUXILIARY_PARAM": [
"--task", "greedy"
],
"DO_SAMPLE_AUXILIARY_PARAM": [
"--task", "do_sample",
"--top-p", "0.95",
"--top-k", "50"
],
"BEAM_SEARCH_WITH_SAMPLING_AUXILIARY_PARAM": [
"--task", "beam_search_with_sampling",
"--top-p", "0.95",
"--top-k", "50"
],
"RETURN_OUTPUT_LOG_PROBS_AUXILIARY_PARAM": [
"--task", "return_output_log_probs",
"--temperature 0.6",
"--top-p", "0.95",
"--top-k", "50"
]
}

View File

@ -1,4 +1,6 @@
# Provide uniform access for pipeline.
pytest -s ./tests/pipeline/qwen-7B/test_instruction.py
pytest -s ./tests/pipeline/qwen-7B/test_process_instruction_data.py
pytest -s ./tests/pipeline/qwen-7B/test_process_instruction_data.py
pytest -s ./tests/pipeline/qwen-7B/test_generation.py
pytest -s ./tests/pipeline/qwen-7B/test_generation2.py

View File

@ -0,0 +1,141 @@
import sys
import os
import torch
import nltk
from tests.common import DistributedTest
from utils import ParamConfig, assert_judge
import modellink
from megatron.legacy.model import GPTModel
from megatron.training import get_args, get_tokenizer
from megatron.training.initialize import initialize_megatron
from modellink.tasks.inference.text_generation.infer_base import add_text_generate_args
class TestGeneration(DistributedTest):
world_size = 8
def init(self, config=ParamConfig, task=None):
"""
initialize the environment and arguments
"""
sys.argv = [sys.argv[0]] + config.distributed_param_tp8_pp1 + config.network_size + \
config.inference_param + config.beam_search_auxliary_param + config.auxiliary_param + config.tokenizer_param
if task == "beam_search_with_sampling":
sys.argv = sys.argv + config.beam_search_with_sampling_auxliary_param
elif task == "return_output_log_probs":
sys.argv = sys.argv + config.return_output_log_probs_auxliary_param
os.environ.update({"CUDA_DEVICE_MAX_CONNECTIONS": "1"})
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'no_load_rng': True,
'no_load_optim': True})
self.args = get_args()
def edit_distance_similarity(self, text1, text2):
"""
edit distance: to compare the similarity between two texts.
"""
distance = nltk.edit_distance(text1, text2)
try:
similarity = 1 - (distance / max(len(text1), len(text2)))
except ZeroDivisionError as e:
raise e
return similarity
def test_beam_search_with_sampling(self):
"""Beam Search with sampling"""
self.init(config=ParamConfig, task="beam_search_with_sampling")
from inference import model_provider
model = GPTModel.from_pretrained(
model_provider=model_provider,
pretrained_model_name_or_path=self.args.load
)
instruction = "Give me three tips for staying healthy."
output = model.generate(
instruction,
num_beams=2,
do_sample=True,
top_k=self.args.top_k,
top_p=self.args.top_p,
max_new_tokens=self.args.max_new_tokens,
tokenizer=None,
stream=False
)
expect_output1 = '''1. Get enough sleep. A good night's sleep is important for your physical and mental health.\n2. Eat a balanced diet. Eating a variety of healthy foods can help you get the nutrients your body needs.\n3. Exercise regularly. Exercise can help you maintain a healthy weight, reduce stress, and improve your overall health.'''
expect_output2 = '''Sure, here are three tips for staying healthy:\n1. Eat a balanced diet that includes fruits, vegetables, whole grains, and lean proteins.\n2. Get regular exercise, such as going for a walk or doing yoga.\n3. Get enough sleep each night, ideally 7-8 hours.'''
if torch.distributed.get_rank() == 0:
print(output)
tokenizer = get_tokenizer()
similarity1 = self.edit_distance_similarity(output[:30], expect_output1[:30])
similarity2 = self.edit_distance_similarity(output[:30], expect_output2[:30])
print("similarity1:", similarity1)
print("similarity1:", similarity2)
assert_judge(max(similarity1, similarity2) > 0.75)
def test_return_output_log_probs(self):
"""Returns the probability distribution of tokens"""
self.init(config=ParamConfig, task="return_output_log_probs")
from inference import model_provider
model = GPTModel.from_pretrained(
model_provider=model_provider,
pretrained_model_name_or_path=self.args.load
)
instruction = "What is the whether like today?"
output1, log_probs = model.generate(
instruction,
do_sample=True,
top_k=self.args.top_k,
top_p=self.args.top_p,
temperature=self.args.temperature,
max_new_tokens=self.args.max_new_tokens,
tokenizer=None,
stream=False,
detokenize=False,
return_output_log_probs=True
)
if torch.distributed.get_rank() == 0:
tokenizer = get_tokenizer()
print("--------------output1-------------")
print(output1)
print(tokenizer.decode(output1))
expected_output1 = [2132, 686, 6761, 389, 1380, 498, 525, 304, 279, 1879,
13, 576, 9104, 646, 387, 2155, 304, 2155, 7482, 624,
872, 198, 3838, 374, 279, 9104, 1075, 304, 7148, 5267,
77091, 198, 785, 9104, 304, 7148, 3351, 374, 39698, 323]
expected_output1_ext = [2132, 686, 6761, 389, 1380, 498, 525, 7407, 13, 16503,
498, 3291, 752, 697, 3728, 5267, 872, 198, 29596, 11902,
198, 77091, 198, 641, 9656, 11902, 11, 432, 594, 39698,
3351, 13, 576, 9315, 374, 220, 23, 15, 12348, 68723]
expected_output1_ext2 = [2132, 374, 83253, 16916, 3351, 382, 77091, 198, 3838, 374,
279, 9104, 1075, 3351, 5267, 2610, 525, 264, 10950, 17847,
13, 279, 198, 3838, 374, 279, 9104, 1075, 3351, 5267,
2610, 525, 264, 10950, 17847, 13, 279, 198, 3838, 374]
print("--------------log_probs----------------")
print(log_probs.shape)
assert_judge(log_probs.shape[0] == 256)
assert_judge(log_probs.shape[1] == 151936)
similarity = torch.nn.CosineSimilarity(dim=1)
cos_sim = similarity(torch.tensor(expected_output1[:40]).unsqueeze(0).float().npu(),
output1[:40].unsqueeze(0).float())
cos_sim = max(cos_sim, similarity(torch.tensor(expected_output1_ext[:40]).unsqueeze(0).float().npu(),
output1[:40].unsqueeze(0).float()))
cos_sim = max(cos_sim, similarity(torch.tensor(expected_output1_ext2[:40]).unsqueeze(0).float().npu(),
output1[:40].unsqueeze(0).float()))
print("similarity1: ", cos_sim)
assert_judge(cos_sim > 0.75)

View File

@ -0,0 +1,165 @@
import sys
import os
import torch
import nltk
from tests.common import DistributedTest
from utils import ParamConfig, assert_judge
import modellink
from megatron.legacy.model import GPTModel
from megatron.training import get_args, get_tokenizer
from megatron.training.initialize import initialize_megatron
from modellink.tasks.inference.text_generation.infer_base import add_text_generate_args
class TestGeneration(DistributedTest):
world_size = 8
def init(self, config=ParamConfig, task=None):
"""
initialize the environment and arguments
"""
sys.argv = [sys.argv[0]] + config.distributed_param_tp8_pp1 + config.network_size + \
config.inference_param + config.beam_search_auxliary_param + config.auxiliary_param + config.tokenizer_param
if task == "beam_search":
sys.argv = sys.argv + config.beam_search_auxliary_param
elif task == "greedy":
sys.argv = sys.argv + config.greedy_search_auxliary_param
elif task == "do_sample":
sys.argv = sys.argv + config.do_sample_auxliary_param
os.environ.update({"CUDA_DEVICE_MAX_CONNECTIONS": "1"})
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'no_load_rng': True,
'no_load_optim': True})
self.args = get_args()
def test_beam_search(self):
"""
load weight to get model and construct the prompts to generate output,
and compare with expected for `beam search`.
"""
self.init(config=ParamConfig, task="beam_search")
from inference import model_provider
model = GPTModel.from_pretrained(
model_provider=model_provider,
pretrained_model_name_or_path=self.args.load
)
max_new_tokens = self.args.max_new_tokens
instruction = "如何提高身体素质"
output = model.generate(
instruction,
num_beams=2,
top_k=self.args.top_k,
top_p=self.args.top_p,
max_new_tokens=max_new_tokens,
tokenizer=None,
stream=False,
detokenize=False
)
if torch.distributed.get_rank() == 0:
print("----------------------output-------------------------")
print(output)
expected_output1 = [100627, 101099, 100838, 104339, 101194, 3837, 87752, 99639, 6684, 31338,
96422, 28311, 16, 13, 4891, 251, 248, 68878, 101079, 5122,
106854, 104102, 71817, 16, 20, 15, 83031, 9370, 15946, 49567,
102660, 18830, 100316, 101079, 3837, 29524, 99234, 99314, 5373, 107530]
expected_output2 = [30534, 100627, 101099, 100838, 3837, 73670, 103975, 87752, 101082, 28311,
16, 13, 4891, 223, 98, 99446, 104579, 5122, 101907, 109635,
103170, 107151, 5373, 100912, 52510, 116570, 5373, 105349, 5373, 105373,
33108, 117094, 49567, 102100, 101252, 3837, 101153, 44636, 108461, 5373]
similarity = torch.nn.CosineSimilarity(dim=1)
cos_sim = similarity(torch.tensor(expected_output1).unsqueeze(0).float().npu(),
output[:40].unsqueeze(0).float())
cos_sim = max(cos_sim, similarity(torch.tensor(expected_output2).unsqueeze(0).float().npu(),
output[:40].unsqueeze(0).float()))
print("similarity: ", cos_sim)
assert_judge(cos_sim > 0.85)
def edit_distance_similarity(self, text1, text2):
"""
edit distance: to compare the similarity between two texts.
"""
distance = nltk.edit_distance(text1, text2)
try:
similarity = 1 - (distance / max(len(text1), len(text2)))
except ZeroDivisionError as e:
raise e
return similarity
def test_greedy_search(self):
"""
load weight to get model and construct the prompts to generate output,
and compare with expected for `greedy search`.
"""
self.init(config=ParamConfig, task="greedy")
from inference import model_provider
model = GPTModel.from_pretrained(
model_provider=model_provider,
pretrained_model_name_or_path=self.args.load
)
instruction = ["What are the characteristics of Suzhou?", "Introducing the Forbidden City in Beijing."]
output = model.generate(instruction)
expect_output1 = [
"Suzhou is a city in China. It is known for its beautiful gardens, canals, and classical Chinese architecture. It is also known for its silk production and traditional arts and crafts. The city has a rich cultural heritage and is home to many historic temples and museums. Additionally, Suzhou is known for its cuisine, which features local specialties such as sweet and sour fish and rice cakes."
]
expect_output2 = [
'The Forbidden City is a palace complex in Beijing, China. It was the home of the emperors of China for almost 500 years, from the Ming Dynasty to the end of the Qing Dynasty. The complex covers an area of 72 hectares and has over 9,000 rooms. It is a UNESCO World Heritage Site and one of the most popular tourist attractions in China..'
]
expect_output1_seq = "".join(expect_output1)
expect_output2_seq = ''.join(expect_output2)
if torch.distributed.get_rank() == 0:
print("----------------------output1-------------------------")
print(output[0])
print("----------------------output2-------------------------")
print(output[1])
similarity1 = self.edit_distance_similarity(output[0][:30], expect_output1_seq[:30])
similarity2 = self.edit_distance_similarity(output[1][:30], expect_output2_seq[:30])
print("similarity1:", similarity1)
print("similarity2:", similarity2)
assert_judge(similarity1 > 0.85)
assert_judge(similarity2 > 0.85)
def test_do_sample(self):
"""Do Sample"""
self.init(config=ParamConfig, task="do_sample")
from inference import model_provider
model = GPTModel.from_pretrained(
model_provider=model_provider,
pretrained_model_name_or_path=self.args.load
)
instruction = "what is Disneyland?"
output = model.generate(
[instruction, instruction],
do_sample=True,
top_k=self.args.top_k,
top_p=self.args.top_p,
max_new_tokens=self.args.max_new_tokens,
tokenizer=None,
stream=False
)
expect_output1 = "Disneyland Park is an entertainment park located in Anaheim, California, United States. It is owned by the Disney Parks, Experiences and Consumer Products division of the American multinational conglomerate corporation the Walt Disney Company. It is also the first of seven theme parks built at Walt Disney's original vision, where visitors can enjoy various attractions, entertainment, and dining."
expect_output1_seq = "".join(expect_output1)
if torch.distributed.get_rank() == 0:
print(output)
tokenizer = get_tokenizer()
similarity1 = self.edit_distance_similarity(output[0][:30], expect_output1_seq[:30])
print("similarity1:", similarity1)
assert_judge(similarity1 > 0.85)

View File

@ -28,13 +28,24 @@ class ParamConfig:
network_size = config_file["NETWORK_SIZE"]
tokenizer_param = config_file["TOKENIZER_PARAM"]
distributed_param = config_file["DISTRIBUTED_PARAM"]
distributed_param_tp8_pp1 = config_file["DISTRIBUTED_PARAM_TP8_PP1"]
auxiliary_param = config_file["AUXILIARY_PARAM"]
instruction_param = config_file["INSTRUCTION_PARAM"]
output_param = config_file["OUTPUT_PARAM"]
# preprocess instruction data
instruction_data_param = config_file["PROCESS_INSTRUCTION_DATA"]
instruction_data_mix_param1 = config_file["PROCESS_INSTRUCTION_DATA_MIX1"]
instruction_data_mix_param2 = config_file["PROCESS_INSTRUCTION_DATA_MIX2"]
# inference
inference_param = config_file["INFERENCE_PARAM"]
beam_search_auxliary_param = config_file["BEAM_SEARCH_AUXILIARY_PARAM"]
greedy_search_auxliary_param = config_file["GREEDY_SEARCH_AUXILIARY_PARAM"]
do_sample_auxliary_param = config_file["DO_SAMPLE_AUXILIARY_PARAM"]
beam_search_with_sampling_auxliary_param = config_file["BEAM_SEARCH_WITH_SAMPLING_AUXILIARY_PARAM"]
return_output_log_probs_auxliary_param = config_file["RETURN_OUTPUT_LOG_PROBS_AUXILIARY_PARAM"]
def assert_judge(expression):
if not expression:

View File

@ -36,6 +36,7 @@ python3.8 -m torch.distributed.launch $DISTRIBUTED_ARGS ${basepath}/inference.py
--geglu \
--input-embeds-norm \
--fp16 \
--prompt-type gemma \
--micro-batch-size 1 \
--seq-length 8192 \
--max-new-tokens 64 \

View File

@ -34,7 +34,7 @@
"lora-r":16,
"lora-alpha":32,
"lora-target-modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
"inference-prompt-type": "'alpaca'",
"prompt-type": "'alpaca'",
"bf16":null,
"task":"greedy"
}

View File

@ -42,6 +42,7 @@ python3.8 -m torch.distributed.launch $DISTRIBUTED_ARGS ${basepath}/inference.py
--normalization RMSNorm \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--prompt-type llama2 \
--attention-softmax-in-fp32 \
--no-masked-softmax-fusion \
--no-gradient-accumulation-fusion \

View File

@ -28,7 +28,7 @@ export PYTHONPATH=${basepath}:$PYTHONPATH
python3.8 -m torch.distributed.launch $DISTRIBUTED_ARGS \
${basepath}/pretrain_gpt.py \
--tensor-model-parallel-size 2 \
--pipeline-model-parallel-size 2 \
--pipeline-model-parallel-size 4 \
--sequence-parallel \
--num-layers 4 \
--hidden-size 4096 \

View File

@ -26,7 +26,7 @@ class TestProcessInstructionDataLf:
"--tokenizer-name-or-path", "/data/qwen-7b/",
"--workers", "4",
"--log-interval", "1000",
"--lla-fact-ins-template", "qwen"
"--prompt-type", "qwen"
]
self.args = get_args()
self.tokenizer = build_tokenizer(self.args)
@ -44,7 +44,7 @@ class TestProcessInstructionDataLf:
"--tokenizer-name-or-path", "/data/qwen-7b/",
"--workers", "4",
"--log-interval", "1000",
"--lla-fact-ins-template", "qwen",
"--prompt-type", "qwen",
"--map-keys", '{"history":"history"}'
]
self.args = get_args()
@ -61,7 +61,7 @@ class TestProcessInstructionDataLf:
"--tokenizer-name-or-path", "/data/qwen-7b/",
"--workers", "4",
"--log-interval", "1000",
"--lla-fact-ins-template", "qwen",
"--prompt-type", "qwen",
"--map-keys", '{"system":"system_prompt"}'
]
@ -79,7 +79,7 @@ class TestProcessInstructionDataLf:
"--tokenizer-name-or-path", "/data/qwen-7b/",
"--workers", "4",
"--log-interval", "1000",
"--lla-fact-ins-template", "qwen",
"--prompt-type", "qwen",
"--map-keys", '{"messages":"messages", "tags":{"role_tag": "role","content_tag": "content","user_tag": "user","assistant_tag": "assistant","system_tag": "system"} }'
]