diff --git a/examples/gemma/generate_gemma_7b_lora_ptd.sh b/examples/gemma/generate_gemma_7b_lora_ptd.sh index 54bfbb3d3..2fbb8bafb 100644 --- a/examples/gemma/generate_gemma_7b_lora_ptd.sh +++ b/examples/gemma/generate_gemma_7b_lora_ptd.sh @@ -59,5 +59,5 @@ torchrun $DISTRIBUTED_ARGS inference.py \ --lora-r 16 \ --lora-alpha 32 \ --lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ | tee logs/generate_gemma_7b.log diff --git a/examples/llama/generate_llama_13b_lora_ptd.sh b/examples/llama/generate_llama_13b_lora_ptd.sh index 93d1b2fd6..7063cb0fe 100644 --- a/examples/llama/generate_llama_13b_lora_ptd.sh +++ b/examples/llama/generate_llama_13b_lora_ptd.sh @@ -53,5 +53,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \ --lora-alpha 32 \ --lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \ --make-vocab-size-divisible-by 1 \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ | tee logs/generate_llama_13b_lora.log diff --git a/examples/llama/generate_llama_33b_lora_ptd.sh b/examples/llama/generate_llama_33b_lora_ptd.sh index e2afa3718..9009a8a80 100644 --- a/examples/llama/generate_llama_33b_lora_ptd.sh +++ b/examples/llama/generate_llama_33b_lora_ptd.sh @@ -60,5 +60,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \ --lora-r 16 \ --lora-alpha 32 \ --make-vocab-size-divisible-by 1 \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ | tee logs/generate_llama_33b_lora.log diff --git a/examples/llama/generate_llama_65b_lora_ptd.sh b/examples/llama/generate_llama_65b_lora_ptd.sh index b61efd8b1..432dae3f3 100644 --- a/examples/llama/generate_llama_65b_lora_ptd.sh +++ b/examples/llama/generate_llama_65b_lora_ptd.sh @@ -53,5 +53,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \ --lora-alpha 32 \ --lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \ --make-vocab-size-divisible-by 1 \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ | tee logs/generate_llama_65b_lora.log diff --git a/examples/llama/generate_llama_65b_ptd.sh b/examples/llama/generate_llama_65b_ptd.sh index 0a9f48ed3..0d4f71791 100644 --- a/examples/llama/generate_llama_65b_ptd.sh +++ b/examples/llama/generate_llama_65b_ptd.sh @@ -48,5 +48,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \ --no-gradient-accumulation-fusion \ --exit-on-missing-checkpoint \ --make-vocab-size-divisible-by 1 \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ | tee logs/generate_llama_65b.log diff --git a/examples/llama/generate_llama_7b_lora_ptd.sh b/examples/llama/generate_llama_7b_lora_ptd.sh index 3fde06844..e205ca339 100644 --- a/examples/llama/generate_llama_7b_lora_ptd.sh +++ b/examples/llama/generate_llama_7b_lora_ptd.sh @@ -53,5 +53,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \ --lora-alpha 32 \ --lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \ --make-vocab-size-divisible-by 1 \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ | tee logs/generate_llama_7b_lora.log \ No newline at end of file diff --git a/examples/llama2/generate_llama2_13b_lora_ptd.sh b/examples/llama2/generate_llama2_13b_lora_ptd.sh index 2ae5c480f..914126c77 100644 --- a/examples/llama2/generate_llama2_13b_lora_ptd.sh +++ b/examples/llama2/generate_llama2_13b_lora_ptd.sh @@ -59,5 +59,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \ --lora-r 16 \ --lora-alpha 32 \ --make-vocab-size-divisible-by 1 \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ | tee logs/generate_llama2_13b_lora.log diff --git a/examples/llama2/generate_llama2_34b_lora_ptd.sh b/examples/llama2/generate_llama2_34b_lora_ptd.sh index c48f1cf51..380b1b18c 100644 --- a/examples/llama2/generate_llama2_34b_lora_ptd.sh +++ b/examples/llama2/generate_llama2_34b_lora_ptd.sh @@ -62,5 +62,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \ --make-vocab-size-divisible-by 1 \ --group-query-attention \ --num-query-groups 8 \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ | tee logs/generate_llama2_34b_lora.log diff --git a/examples/llama2/generate_llama2_70b_lora_ptd.sh b/examples/llama2/generate_llama2_70b_lora_ptd.sh index a1de86d11..a2146cb8e 100644 --- a/examples/llama2/generate_llama2_70b_lora_ptd.sh +++ b/examples/llama2/generate_llama2_70b_lora_ptd.sh @@ -55,5 +55,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \ --lora-alpha 32 \ --lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \ --make-vocab-size-divisible-by 1 \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ | tee logs/generate_llama2_70b_lora.log diff --git a/examples/llama2/generate_llama2_7b_lora_ptd.sh b/examples/llama2/generate_llama2_7b_lora_ptd.sh index e2e723729..3932d2d2c 100644 --- a/examples/llama2/generate_llama2_7b_lora_ptd.sh +++ b/examples/llama2/generate_llama2_7b_lora_ptd.sh @@ -55,5 +55,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS inference.py \ --lora-alpha 32 \ --lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \ --make-vocab-size-divisible-by 1 \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ | tee logs/generate_llama2_7b_lora.log diff --git a/examples/mcore/mistral/generate_mistral_7b_ptd.sh b/examples/mcore/mistral/generate_mistral_7b_ptd.sh index 3049297e2..455ef8876 100644 --- a/examples/mcore/mistral/generate_mistral_7b_ptd.sh +++ b/examples/mcore/mistral/generate_mistral_7b_ptd.sh @@ -65,5 +65,5 @@ torchrun $DISTRIBUTED_ARGS inference.py \ $GPT_ARGS \ $MOE_ARGS \ --distributed-backend nccl \ - --inference-prompt-type mixtral \ + --prompt-type mixtral \ | tee logs/generate_mixtral.log diff --git a/examples/mistral/generate_mistral_7b_ptd.sh b/examples/mistral/generate_mistral_7b_ptd.sh index 79321c36b..5a7be861f 100644 --- a/examples/mistral/generate_mistral_7b_ptd.sh +++ b/examples/mistral/generate_mistral_7b_ptd.sh @@ -61,5 +61,5 @@ torchrun $DISTRIBUTED_ARGS inference.py \ $GPT_ARGS \ $MOE_ARGS \ --distributed-backend nccl \ - --inference-prompt-type mixtral \ + --prompt-type mixtral \ | tee logs/generate_mixtral.log diff --git a/examples/mixtral/generate_mixtral_8x7b_ptd.sh b/examples/mixtral/generate_mixtral_8x7b_ptd.sh index c0607bc60..2f93b16f2 100644 --- a/examples/mixtral/generate_mixtral_8x7b_ptd.sh +++ b/examples/mixtral/generate_mixtral_8x7b_ptd.sh @@ -66,5 +66,5 @@ torchrun $DISTRIBUTED_ARGS inference.py \ $GPT_ARGS \ $MOE_ARGS \ --distributed-backend nccl \ - --inference-prompt-type llama \ + --prompt-type mixtral \ | tee logs/generate_mixtral.log diff --git a/examples/qwen15/generate_qwen15_32b_lora_chat_ptd.sh b/examples/qwen15/generate_qwen15_32b_lora_chat_ptd.sh index b4b9e273b..003a439bb 100644 --- a/examples/qwen15/generate_qwen15_32b_lora_chat_ptd.sh +++ b/examples/qwen15/generate_qwen15_32b_lora_chat_ptd.sh @@ -52,7 +52,7 @@ torchrun $DISTRIBUTED_ARGS inference.py \ --lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \ --lora-r 16 \ --lora-alpha 32 \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ --normalization RMSNorm \ --group-query-attention \ --hidden-dropout 0 \ diff --git a/examples/qwen15/generate_qwen15_72b_lora_chat_ptd.sh b/examples/qwen15/generate_qwen15_72b_lora_chat_ptd.sh index 64ccb10b8..cbb55d0f5 100644 --- a/examples/qwen15/generate_qwen15_72b_lora_chat_ptd.sh +++ b/examples/qwen15/generate_qwen15_72b_lora_chat_ptd.sh @@ -52,7 +52,7 @@ torchrun $DISTRIBUTED_ARGS inference.py \ --lora-target-modules query_key_value dense dense_h_to_4h dense_4h_to_h \ --lora-r 16 \ --lora-alpha 32 \ - --inference-prompt-type 'alpaca' \ + --prompt-type 'alpaca' \ --normalization RMSNorm \ --position-embedding-type rope \ --norm-epsilon 1e-6 \ diff --git a/inference.py b/inference.py index 17399db9e..5c5567ba7 100644 --- a/inference.py +++ b/inference.py @@ -100,20 +100,7 @@ def main(): pretrained_model_name_or_path=args.load ) - system_template = "" - dialog_template = "{instruction}" - - if args.inference_prompt_type == 'alpaca': - system_template = "Below is an instruction that describes a task, paired with an input that provides further " \ - "context. Write a response that appropriately completes the request. " \ - "Please note that you need to think through your response logically and step by step.\n\n" - dialog_template = "### Instruction:\n{instruction}\n\n### Response:" - - elif args.inference_prompt_type == 'mixtral': - system_template = "" - dialog_template = "[INST] {instruction} [/INST] " - - task_factory(args, model, system_template=system_template, dialog_template=dialog_template) + task_factory(args, model) if __name__ == "__main__": diff --git a/modellink/arguments.py b/modellink/arguments.py index 1aab44293..c3bde2919 100644 --- a/modellink/arguments.py +++ b/modellink/arguments.py @@ -405,6 +405,10 @@ def _add_training_args(parser): help='enable deterministic computing for npu') group.add_argument('--jit-compile', action='store_true', default=False, help='Setting jit compile mode to True') + group.add_argument('--prompt-type', type=str, default=None, + choices=['default', 'empty', 'chatglm2', 'chatglm3', 'chatglm3_system', 'chatml', 'chatml_de', 'qwen', 'llama3', 'llama2', 'mistral', 'mixtral', 'gemma'], + help='Which template to use for constructing prompts in training/inference.' + 'e.g., "qwen"') return parser @@ -519,6 +523,14 @@ def _valid_lora(args): raise AssertionError('lora_fusion for CCLoRA is forbidden without sequence_parallel.') +def _validate_inference_args(args): + if args.prompt_type is not None and hasattr(args, "hf_chat_template") and args.hf_chat_template: + raise AssertionError('Prompt-type is forbidden when use huggingface chat template.') + + if hasattr(args, "history_turns") and args.history_turns < 0: + raise AssertionError('History turns of chat must greater than 0.') + + def _validate_moe_expert_capacity_factor(args): if args.moe_expert_capacity_factor is not None: if args.moe_token_dispatcher_type != "alltoall": @@ -615,6 +627,7 @@ def validate_args_decorator(megatron_validate_args): _validate_position_embedding(args) _validate_high_availability(args) _valid_lora(args) + _validate_inference_args(args) _validate_moe_expert_capacity_factor(args) _validate_mla(args) _validate_yarn(args) diff --git a/modellink/tasks/inference/text_generation/infer_base.py b/modellink/tasks/inference/text_generation/infer_base.py index 9ea6d2810..a9ed2c1e2 100644 --- a/modellink/tasks/inference/text_generation/infer_base.py +++ b/modellink/tasks/inference/text_generation/infer_base.py @@ -18,6 +18,7 @@ import time import shutil import logging import subprocess +from copy import deepcopy import torch from torch import distributed as dist @@ -36,14 +37,13 @@ def add_text_generate_args(parser): group.add_argument("--temperature", type=float, default=0.7, help='Sampling temperature.') group.add_argument("--max-length", type=int, default=256, help='Total length of text.') group.add_argument("--max-new-tokens", type=int, default=128, help='Size of the output generated text.') - group.add_argument('--inference-prompt-type', type=str, default='llama', - help="choose the prompt type for inference") group.add_argument('--hf-chat-template', action='store_true', default=False, - help="Using Huggingface chat template") + help="Using Huggingface chat template") group.add_argument('--add-eos-token', nargs='+', type=str, default=[], help="Use additional eos tokens") group.add_argument('--use-kv-cache', action="store_true", default=False, help="Use kv cache to accelerate inference") + group.add_argument('--history-turns', type=int, default=3, help='Chat turns of histories.') return parser @@ -56,7 +56,7 @@ def print_flush(prev_str, curr_str): sys.stdout.write(difference) -def task_factory(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"): +def task_factory(args, model): task_map = { "greedy": task_greedy_search, "do_sample": task_do_sample, @@ -85,40 +85,32 @@ def task_factory(args, model, tokenizer=None, system_template="", dialog_templat task_map.get(task)( args, model, - tokenizer, - system_template=system_template, - dialog_template=dialog_template ) -def task_greedy_search(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"): +def task_greedy_search(args, model): """Greedy Search""" - prompt = "how are you?" - template = system_template + dialog_template - instruction = template.format(instruction=prompt) + instruction = "how are you?" t = time.time() output = model.generate( instruction, max_new_tokens=args.max_new_tokens, - tokenizer=tokenizer, stream=False ) if dist.get_rank() == 0: logging.info("\n=============== Greedy Search ================") - logging.info("\nYou:\n%s\n\nModelLink:\n%s", prompt, output) + logging.info("\nYou:\n%s\n\nModelLink:\n%s", instruction, output) logging.info("==============================================") logging.info("\nElapsed: %ss", round(time.time() - t, 2)) dist.barrier() -def task_do_sample(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"): +def task_do_sample(args, model): """Do Sample""" - prompt = "how are you?" - template = system_template + dialog_template - instruction = template.format(instruction=prompt) + instruction = "how are you?" t = time.time() output = model.generate( @@ -127,24 +119,21 @@ def task_do_sample(args, model, tokenizer=None, system_template="", dialog_templ top_k=args.top_k, top_p=args.top_p, max_new_tokens=args.max_new_tokens, - tokenizer=tokenizer, stream=False ) if dist.get_rank() == 0: logging.info("\n================ Do Sample =================") - logging.info("\nYou:\n%s\n\nModelLink:\n%s", prompt, output) + logging.info("\nYou:\n%s\n\nModelLink:\n%s", instruction, output) logging.info("============================================") logging.info("\nElapsed: %ss", round(time.time() - t, 2)) dist.barrier() -def task_beam_search(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"): +def task_beam_search(args, model): """Beam Search""" - prompt = "how are you?" - template = system_template + dialog_template - instruction = template.format(instruction=prompt) + instruction = "how are you?" t = time.time() output = model.generate( @@ -153,24 +142,21 @@ def task_beam_search(args, model, tokenizer=None, system_template="", dialog_tem top_k=args.top_k, top_p=args.top_p, max_new_tokens=args.max_new_tokens, - tokenizer=tokenizer, stream=False ) if dist.get_rank() == 0: logging.info("\n=============== Beam Search =================") - logging.info("\nYou:\n%s\n\nModelLink:\n%s", prompt, output) + logging.info("\nYou:\n%s\n\nModelLink:\n%s", instruction, output) logging.info("=============================================") logging.info("\nElapsed: %ss", round(time.time() - t, 2)) dist.barrier() -def task_beam_search_with_sampling(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"): +def task_beam_search_with_sampling(args, model): """Beam Search with sampling""" - prompt = "how are you?" - template = system_template + dialog_template - instruction = template.format(instruction=prompt) + instruction = "how are you?" t = time.time() output = model.generate( @@ -180,24 +166,21 @@ def task_beam_search_with_sampling(args, model, tokenizer=None, system_template= top_k=args.top_k, top_p=args.top_p, max_new_tokens=args.max_new_tokens, - tokenizer=tokenizer, stream=False ) if dist.get_rank() == 0: logging.info("\n======== Beam Search with sampling ==========") - logging.info("\nYou:\n%s\n\nModelLink:\n%s", prompt, output) + logging.info("\nYou:\n%s\n\nModelLink:\n%s", instruction, output) logging.info("=============================================") logging.info("\nElapsed: %ss", round(time.time() - t, 2)) dist.barrier() -def task_return_output_log_probs(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"): +def task_return_output_log_probs(args, model): """Returns the probability distribution of tokens""" - prompt = "how are you?" - template = system_template + dialog_template - instruction = template.format(instruction=prompt) + instruction = "how are you?" t = time.time() tokens, log_probs = model.generate( @@ -207,7 +190,6 @@ def task_return_output_log_probs(args, model, tokenizer=None, system_template="" top_p=args.top_p, temperature=args.temperature, max_new_tokens=args.max_new_tokens, - tokenizer=tokenizer, stream=False, detokenize=False, return_output_log_probs=True @@ -221,7 +203,6 @@ def task_return_output_log_probs(args, model, tokenizer=None, system_template="" top_p=args.top_p, temperature=args.temperature, max_new_tokens=args.max_new_tokens, - tokenizer=tokenizer, stream=False, detokenize=False, return_output_log_probs=True, @@ -238,32 +219,77 @@ def task_return_output_log_probs(args, model, tokenizer=None, system_template="" dist.barrier() -def task_chat(args, model, tokenizer=None, system_template="", dialog_template="{instruction}"): - """Interactive dialog mode with multiple rounds of conversation""" +def chat_get_instruction(args, histories_no_template, histories_template, prompt): + instruction = None def get_context(content): - res = system_template + res = "" for q, r in content: if r is None: - res += dialog_template.format(instruction=q) + res += q else: - res += dialog_template.format(instruction=q) + r + res += q + r return res - histories = [] - columns, rows = shutil.get_terminal_size() - output, prompt, instruction = "", "", "" - input_template, response_template = "\n\nYou >> ", "\nModelLink:\n" + if args.hf_chat_template or args.prompt_type is not None: + # Handle conversation history, there can be a better solution + if len(histories_template) > 2 * args.history_turns: + histories_template.pop(0) + histories_template.pop(0) + + histories_template.append({"role": "user", "content": prompt}) + + # use llamafactory template, We need to build the intermediate format ourselves + instruction = deepcopy(histories_template) + else: + # not use llamafactory template,keep old process + histories_no_template.append((prompt, None)) + instruction = get_context(histories_no_template) + histories_no_template.pop() + + return instruction + + +def chat_print_and_update_histories(args, responses, histories_no_template, histories_template, prompt): + response_template = "\nModelLink:\n" + output = "" + + if dist.get_rank() == 0: + sys.stdout.write(response_template) + + prev = "" + for output in responses: + if dist.get_rank() == 0: + curr = output.replace("�", "") + print_flush(prev, curr) + prev = curr + + # old propress + if args.hf_chat_template or args.prompt_type is not None: + histories_template.append({"role": "assistant", "content": output}) + else: + histories_no_template.append((prompt, output)) + if len(histories_no_template) > 3: + histories_no_template.pop() + + return output + + +def task_chat(args, model): + """Interactive dialog mode with multiple rounds of conversation""" + + histories_no_template = [] + histories_template = [] + instruction = None + prompt = "" + input_template = "\n\nYou >> " command_clear = ["clear"] - messages = [] - if args.hf_chat_template: - from megatron.training import get_tokenizer - tokenizer = get_tokenizer().tokenizer + while True: terminate_runs = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device()) if dist.get_rank() == 0: - if not histories: + if not histories_no_template and not histories_template: logging.info("===========================================================") logging.info("1. If you want to quit, please entry one of [q, quit, exit]") logging.info("2. To create new title, please entry one of [clear, new]") @@ -277,25 +303,15 @@ def task_chat(args, model, tokenizer=None, system_template="", dialog_template=" if prompt.strip() in ["clear", "new"]: subprocess.call(command_clear) - histories = [] - messages = [] + histories_no_template = [] + histories_template = [] continue if not prompt.strip(): continue + + instruction = chat_get_instruction(args, histories_no_template, histories_template, prompt) - histories.append((prompt, None)) - instruction = get_context(histories) - histories.pop() - messages.append( - {"role": "user", "content": prompt} - ) - if args.hf_chat_template: - instruction = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True - ) dist.all_reduce(terminate_runs) dist.barrier() @@ -307,23 +323,10 @@ def task_chat(args, model, tokenizer=None, system_template="", dialog_template=" do_sample=True, top_k=args.top_k, top_p=args.top_p, - tokenizer=tokenizer, + tokenizer=None, temperature=args.temperature, max_new_tokens=args.max_new_tokens, stream=True ) - if dist.get_rank() == 0: - sys.stdout.write(response_template) - - prev = "" - for output in responses: - if dist.get_rank() == 0: - curr = output.replace("�", "") - print_flush(prev, curr) - prev = curr - - histories.append((prompt, output)) - messages.append( - {"role": "assistant", "content": output} - ) + chat_print_and_update_histories(args, responses, histories_no_template, histories_template, prompt) \ No newline at end of file diff --git a/modellink/tasks/inference/text_generation/module.py b/modellink/tasks/inference/text_generation/module.py index bce100e92..8fe138c0d 100644 --- a/modellink/tasks/inference/text_generation/module.py +++ b/modellink/tasks/inference/text_generation/module.py @@ -18,7 +18,6 @@ import abc import logging from typing import Optional, Union - import torch from torch import distributed as dist from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP @@ -27,6 +26,8 @@ from megatron.core.models.gpt.gpt_model import GPTModel from megatron.training import get_args, global_vars from megatron.core import parallel_state +from modellink.tasks.preprocess.templates import Template, get_model_template + class MegatronModuleForCausalLMABC(torch.nn.Module, abc.ABC): """ @@ -189,6 +190,9 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC): self.greedy_search_or_sampling = greedy_search_or_sampling self.beam_search_in_sampling = beam_search self.broadcast_float_list = broadcast_float_list + self.template = None + if hasattr(args, "prompt_type") and args.prompt_type is not None: + self.template = get_model_template(args.prompt_type.strip()) @staticmethod def _ids_check(ids, tokenizer): @@ -274,6 +278,14 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC): args.master_rank = master_rank args.micro_batch_size = len(context_tokens) + stop_token = [args.eos_id] + stop_ids + + if hasattr(args, "prompt_type") and args.prompt_type is not None: + stop_ids = stop_ids + [self.tokenizer.convert_tokens_to_ids(token) for token in self.template.stop_words] + \ + [self.tokenizer.eos_token_id] + + stop_token = [args.eos_id] + stop_ids + # ======================================= # Get the streaming tokens generator # ======================================= @@ -282,7 +294,7 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC): args.model[0], context_tokens, beam_size=self.num_beams, - stop_token=[args.eos_id] + stop_ids, + stop_token=stop_token, num_return_gen=self.num_return_sequences, length_penalty=self.length_penalty ) @@ -361,29 +373,82 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC): if self.tokenizer.pad_token_id is None: self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + def _encode_no_template(self, input_ids): + context_tokens = [[]] + if isinstance(input_ids, str): + context_tokens = [self.tokenizer.encode(input_ids)] + elif torch.is_tensor(input_ids): + if len(input_ids.shape) == 1: + context_tokens = input_ids.unsqueeze(0).numpy().tolist() + elif len(input_ids.shape) == 2: + context_tokens = input_ids.numpy().tolist() + elif isinstance(input_ids, (tuple, list)): + if len(input_ids) and isinstance(input_ids[0], (tuple, list)): + context_tokens = input_ids + elif len(input_ids) and isinstance(input_ids[0], int): + context_tokens = [input_ids] + elif len(input_ids) and isinstance(input_ids[0], str): + context_tokens = [self.tokenizer.encode(val) for val in input_ids] + else: + raise TypeError("Please check input_ids in correct type.") + + return context_tokens + + + def _encode_by_template(self, input_ids): + context_tokens = [] + + if input_ids is None: + return [[]] + response_prompt = [{"role": "assistant", "content": ""}] + if len(input_ids) and isinstance(input_ids, str): + paired_messages = [{"role": "user", "content": "{}".format(input_ids)}] + response_prompt + tokens, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=paired_messages, tools="") + context_tokens.append(tokens) + elif len(input_ids) and isinstance(input_ids[0], (dict)): + paired_messages = input_ids + response_prompt + tokens, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=paired_messages, tools="") + context_tokens.append(tokens) + elif len(input_ids) and isinstance(input_ids[0], (str)): + for query in input_ids: + paired_messages = [{"role": "user", "content": "{}".format(query)}] + response_prompt + tokens, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=paired_messages, tools="") + context_tokens.append(tokens) + elif len(input_ids) and isinstance(input_ids[0], (tuple, list)): + for val in input_ids: + if len(val) and isinstance(val, (tuple, list)): + paired_messages = val + response_prompt + tokens, _ = self.template.encode_oneturn(tokenizer=self.tokenizer, messages=paired_messages, tools="") + context_tokens.append(tokens) + else: + raise TypeError("Please check input_ids in correct type.") + + + return context_tokens if len(context_tokens) > 0 else [context_tokens] + + def _tokenize(self, input_ids): context_tokens = [[]] broadcast_rank = torch.zeros(dist.get_world_size(), dtype=torch.int64, device=torch.device(torch.cuda.current_device())) - if input_ids is not None and len(input_ids) > 0: - if isinstance(input_ids, str): - context_tokens = [self.tokenizer.encode(input_ids)] - elif torch.is_tensor(input_ids): - if len(input_ids.shape) == 1: - context_tokens = input_ids.unsqueeze(0).numpy().tolist() - elif len(input_ids.shape) == 2: - context_tokens = input_ids.numpy().tolist() - elif isinstance(input_ids, (tuple, list)): - if len(input_ids) and isinstance(input_ids[0], (tuple, list)): - context_tokens = input_ids - elif len(input_ids) and isinstance(input_ids[0], int): - context_tokens = [input_ids] - elif len(input_ids) and isinstance(input_ids[0], str): - context_tokens = [self.tokenizer.encode(val) for val in input_ids] + args = get_args() + if args.hf_chat_template: + if not hasattr(self.tokenizer, "apply_chat_template"): + raise AssertionError('The tokenizer has no Huggingface chat template, Please use chat model.') + + context_tokens = [self.tokenizer.apply_chat_template( + input_ids, + tokenize=True, + add_generation_prompt=True + )] + elif self.template is None: + context_tokens = self._encode_no_template(input_ids) else: - raise TypeError("Please check input_ids in correct type.") + context_tokens = self._encode_by_template(input_ids) + broadcast_rank[dist.get_rank()] = 1 diff --git a/modellink/tasks/preprocess/data_handler.py b/modellink/tasks/preprocess/data_handler.py index 531b05b3f..0ada8864e 100644 --- a/modellink/tasks/preprocess/data_handler.py +++ b/modellink/tasks/preprocess/data_handler.py @@ -26,8 +26,7 @@ from datasets import load_dataset from megatron.core.datasets import indexed_dataset -from modellink.tasks.preprocess.templates import Prompter, AlpacaTemplate -from modellink.tasks.preprocess.templates import get_template_and_fix_tokenizer +from modellink.tasks.preprocess.templates import Prompter, AlpacaTemplate, get_model_template from .utils import get_dataset_list, get_handler_dataset_attr, load_single_dataset, merge_dataset, align_dataset @@ -201,7 +200,7 @@ class LlamaFactoryInstructionHandler(BaseDatasetHandler): self.args.output_prefix = self.args.output_prefix + "_packed" self.ignored_label = -100 self.is_multi_turn = True - self.llama_factory_template = get_template_and_fix_tokenizer(tokenizer.tokenizer, args.lla_fact_ins_template.strip()) + self.llama_factory_template = get_model_template(args.prompt_type.strip()) def _format_msg(self, sample): return sample diff --git a/modellink/tasks/preprocess/templates.py b/modellink/tasks/preprocess/templates.py index 1c0017674..6da986834 100644 --- a/modellink/tasks/preprocess/templates.py +++ b/modellink/tasks/preprocess/templates.py @@ -88,6 +88,7 @@ class Template: format_observation: "Formatter" format_tools: "Formatter" format_separator: "Formatter" + format_prefix: "Formatter" default_system: str stop_words: List[str] efficient_eos: bool @@ -142,16 +143,18 @@ class Template: ) -> Sequence[Tuple[List[int], List[int]]]: r""" Encodes formatted inputs to pairs of token ids. - Turn 0: system + query resp + Turn 0: prefix + system + query resp Turn t: sep + query resp """ system = system or self.default_system encoded_messages = [] for i, message in enumerate(messages): elements = [] - if i == 0 and (system or tools or self.force_system): - tool_text = self.format_tools.apply(content=tools)[0] if tools else "" - elements += self.format_system.apply(content=(system + tool_text)) + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + elements += self.format_system.apply(content=(system + tool_text)) elif i > 0 and i % 2 == 0: elements += self.format_separator.apply() @@ -182,7 +185,7 @@ class Template: if len(elem) != 0: token_ids += tokenizer.encode(elem, add_special_tokens=False) elif isinstance(elem, dict): - token_ids += [tokenizerr.convert_tokens_to_ids(elem.get("token"))] + token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))] elif isinstance(elem, set): if "bos_token" in elem and tokenizer.bos_token_id is not None: token_ids += [tokenizer.bos_token_id] @@ -241,9 +244,11 @@ class Llama2Template(Template): for i, message in enumerate(messages): elements = [] system_text = "" - if i == 0 and (system or tools or self.force_system): - tool_text = self.format_tools.apply(content=tools)[0] if tools else "" - system_text = self.format_system.apply(content=(system + tool_text))[0] + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + system_text = self.format_system.apply(content=(system + tool_text))[0] elif i > 0 and i % 2 == 0: elements += self.format_separator.apply() @@ -270,16 +275,21 @@ def get_templates() -> Dict[str, Template]: return templates -def get_template_and_fix_tokenizer( - tokenizer: "PreTrainedTokenizer", - name: Optional[str] = None, -): +def get_model_template(name): if name is None: template = templates["empty"] # placeholder else: template = get_templates().get(name, None) if template is None: raise ValueError("Template {} does not exist.".format(name)) + return template + + +def fix_model_tokenizer( + tokenizer: "PreTrainedTokenizer", + name: Optional[str] = None, +): + template = get_model_template(name) stop_words = template.stop_words if template.replace_eos: @@ -321,6 +331,7 @@ def _register_template( format_observation: Optional["Formatter"] = None, format_tools: Optional["Formatter"] = None, format_separator: Optional["Formatter"] = None, + format_prefix: Optional["Formatter"] = None, default_system: str = "", stop_words: List[str] = [], efficient_eos: bool = False, @@ -360,6 +371,7 @@ def _register_template( default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots) default_tool_formatter = ToolFormatter(tool_format="default") default_separator_formatter = EmptyFormatter() + default_prefix_formatter = EmptyFormatter() templates[name] = template_class( format_user=format_user or default_user_formatter, format_assistant=format_assistant or default_assistant_formatter, @@ -368,6 +380,7 @@ def _register_template( format_observation=format_observation or format_user or default_user_formatter, format_tools=format_tools or default_tool_formatter, format_separator=format_separator or default_separator_formatter, + format_prefix=format_prefix or default_prefix_formatter, default_system=default_system, stop_words=stop_words, efficient_eos=efficient_eos, @@ -418,6 +431,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str: jinja_template = "" + prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer) + if prefix: + jinja_template += "{{ " + prefix + " }}" + if template.default_system: jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" @@ -455,7 +472,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") _register_template( name="chatglm2", format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), - format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), format_separator=EmptyFormatter(slots=["\n\n"]), efficient_eos=True, force_system=True, @@ -466,17 +483,18 @@ _register_template( name="chatglm3", format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]), - format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), - format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), + format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), format_observation=StringFormatter( slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] ), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), stop_words=["<|user|>", "<|observation|>"], efficient_eos=True, - force_system=True, ) + _register_template( name="chatglm3_system", format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), @@ -544,4 +562,62 @@ _register_template( default_system="You are a helpful assistant.", stop_words=["<|im_end|>"], replace_eos=True, +) + + +_register_template( + name="llama3", + format_user=StringFormatter( + slots=[ + ( + "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), + format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]), + format_observation=StringFormatter( + slots=[ + ( + "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|eot_id|>"], + replace_eos=True, +) + + +_register_template( + name="mistral", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="mixtral", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="gemma", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + efficient_eos=True, +) + + +_register_template( + name="llama2", + format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), + format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), ) \ No newline at end of file diff --git a/modellink/tokenizer/tokenizer.py b/modellink/tokenizer/tokenizer.py index a72eb69cd..d54b15245 100644 --- a/modellink/tokenizer/tokenizer.py +++ b/modellink/tokenizer/tokenizer.py @@ -19,6 +19,7 @@ from transformers import AutoTokenizer from megatron.training.tokenizer import build_tokenizer as megatron_build_tokenizer from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer +from modellink.tasks.preprocess.templates import fix_model_tokenizer def build_tokenizer(args): @@ -55,6 +56,9 @@ def build_tokenizer(args): else: tokenizer = TokenizerAdaptor(megatron_build_tokenizer(args)) + if hasattr(args, "prompt_type") and args.prompt_type is not None: + fix_model_tokenizer(tokenizer.tokenizer, args.prompt_type.strip()) + return tokenizer diff --git a/preprocess_data.py b/preprocess_data.py index cf4699a63..0cd9b4059 100644 --- a/preprocess_data.py +++ b/preprocess_data.py @@ -97,10 +97,10 @@ def add_data_args(parser): group.add_argument('--keep-newlines', action='store_true', help='Keep newlines between sentences when splitting.') # LlamaFactory - group.add_argument('--lla-fact-ins-template', type=str, default=None, - choices=['chatglm2', 'chatglm3', 'chatglm3_system', 'chatml', 'chatml_de', 'default', 'empty', 'qwen'], + group.add_argument('--prompt-type', type=str, default=None, + choices=['default', 'empty', 'chatglm2', 'chatglm3', 'chatglm3_system', 'chatml', 'chatml_de', 'qwen', 'llama3', 'llama2', 'mistral', 'mixtral', 'gemma'], help='Which template to use for constructing prompts in training.' - 'ex: "qwen"') + 'e.g., "qwen"') group.add_argument("--interleave-probs", default=None, help='Probabilities to sample data from datasets. Use commas to separate multiple datasets. ' 'probabilities should sum to 1. ex: "0.1, 0.2, 0.3, 0.4"') diff --git a/tests/pipeline/llama3-8B/llama3-8B.sh b/tests/pipeline/llama3-8B/llama3-8B.sh index d98c70f34..6a18ffc49 100644 --- a/tests/pipeline/llama3-8B/llama3-8B.sh +++ b/tests/pipeline/llama3-8B/llama3-8B.sh @@ -3,4 +3,5 @@ python tests/pipeline/llama3-8B/test_convert_ckpt_from_huggingface.py pytest -s tests/pipeline/llama3-8B/test_generation.py -pytest -s tests/pipeline/llama3-8B/test_evaluation.py \ No newline at end of file +pytest -s tests/pipeline/llama3-8B/test_evaluation.py +pytest -s tests/pipeline/llama3-8B/test_chat.py \ No newline at end of file diff --git a/tests/pipeline/llama3-8B/param_config.json b/tests/pipeline/llama3-8B/param_config.json index 8fc8b553c..c659c7462 100644 --- a/tests/pipeline/llama3-8B/param_config.json +++ b/tests/pipeline/llama3-8B/param_config.json @@ -41,6 +41,14 @@ "--attention-softmax-in-fp32" ], + "INFERENCE_HF_CHAT_PARAM": [ + "--hf-chat-template" + ], + + "INFERENCE_PROMPT_CHAT_PARAM": [ + "--prompt-type", "llama3" + ], + "EVALUATION_PARAM": [ "--tokenizer-not-use-fast", "--task-data-path", "/home/dataset/eval_dataset/mmlu/test", diff --git a/tests/pipeline/llama3-8B/test_chat.py b/tests/pipeline/llama3-8B/test_chat.py new file mode 100644 index 000000000..5c7d0edcd --- /dev/null +++ b/tests/pipeline/llama3-8B/test_chat.py @@ -0,0 +1,123 @@ +import sys +import os +import nltk +import torch +from torch import distributed as dist +import torch_npu +from tests.common import DistributedTest +from utils import ParamConfig, assert_judge +import modellink +from megatron.legacy.model import GPTModel +from modellink.tasks.inference.text_generation.infer_base import add_text_generate_args, chat_get_instruction, chat_print_and_update_histories + + +class TestGeneration(DistributedTest): + world_size = 8 + + def init(self, config=ParamConfig, chat_type=None): + """ + initialize the environment and arguments + """ + sys.argv = [sys.argv[0]] + config.distributed_param + config.network_size + \ + config.inference_param + config.auxiliary_param + config.tokenizer_param + + if chat_type == "hf_chat": + sys.argv = sys.argv + config.inference_hf_chat_param + elif chat_type == "prompt_chat": + sys.argv = sys.argv + config.inference_prompt_chat_param + + from megatron.training.initialize import initialize_megatron + os.environ.update({"CUDA_DEVICE_MAX_CONNECTIONS": "1"}) + initialize_megatron(extra_args_provider=add_text_generate_args, + args_defaults={'no_load_rng': True, + 'no_load_optim': True}) + from megatron.training import get_args + self.args = get_args() + + + def edit_distance_similarity(self, text1, text2): + """ + edit distance: to compare the similarity between two texts. + """ + distance = nltk.edit_distance(text1, text2) + try: + similarity = 1 - (distance / max(len(text1), len(text2))) + except ZeroDivisionError as e: + raise e + return similarity + + + def run_chat(self, model, turn0outputExpect): + histories_no_template = [] + histories_template = [] + instruction = None + + test_questions = ["你能推荐几本深度学习的书吗?", "上面推荐的书建议学习顺序呢?", "9.11和9.9谁大?"] + + turns = 0 + while turns < 3: + + prompt = test_questions[turns] + + instruction = chat_get_instruction(self.args, histories_no_template, histories_template, prompt) + + responses = model.generate( + instruction, + do_sample=True, + top_k=self.args.top_k, + top_p=self.args.top_p, + tokenizer=None, + temperature=self.args.temperature, + max_new_tokens=self.args.max_new_tokens, + stream=True + ) + output = chat_print_and_update_histories(self.args, responses, histories_no_template, histories_template, prompt) + if torch.distributed.get_rank() == 0: + print("-------------------------------") + print(output) + + if(turns == 0): + similarity1 = self.edit_distance_similarity(output[:30], turn0outputExpect[0][:30]) + similarity2 = self.edit_distance_similarity(output[:30], turn0outputExpect[1][:30]) + print("similarity1:", similarity1) + print("similarity1:", similarity2) + assert_judge(max(similarity1, similarity2) > 0.75) + + turns = turns + 1 + + + def test_hf_chat(self): + """Interactive dialog mode with multiple rounds of conversation""" + self.init(config=ParamConfig, chat_type="hf_chat") + from inference import model_provider + model = GPTModel.from_pretrained( + model_provider=model_provider, + pretrained_model_name_or_path=self.args.load + ) + turn1outputExpect = [] + turn1outputExpect1 = "Here are some highly recommended books on deep learning that can help you dive deeper into the subject:" + turn1outputExpect2 = '''Here are some highly recommended books for deep learning:\n\n**Foundational Books**\n\n1. **"Deep Learning" by Ian Goodfellow, Yoshua Bengio, and Aaron Courville**: This is the bible of deep learning.''' + + turn1outputExpect.append(turn1outputExpect1) + turn1outputExpect.append(turn1outputExpect2) + + self.run_chat(model, turn1outputExpect) + + + def test_prompt_type_chat(self): + """Interactive dialog mode with multiple rounds of conversation""" + self.init(config=ParamConfig, chat_type="prompt_chat") + from inference import model_provider + model = GPTModel.from_pretrained( + model_provider=model_provider, + pretrained_model_name_or_path=self.args.load + ) + turn1outputExpect = [] + turn1outputExpect1 = "Here are some highly recommended books on deep learning that can help you dive deeper into the subject:" + turn1outputExpect2 = '''Here are some highly recommended books for deep learning:\n\n**Foundational Books**\n\n1. **"Deep Learning" by Ian Goodfellow, Yoshua Bengio, and Aaron Courville**: This is the bible of deep learning.''' + + turn1outputExpect.append(turn1outputExpect1) + turn1outputExpect.append(turn1outputExpect2) + + self.run_chat(model, turn1outputExpect) + diff --git a/tests/pipeline/llama3-8B/utils.py b/tests/pipeline/llama3-8B/utils.py index a795ab338..0230420b7 100644 --- a/tests/pipeline/llama3-8B/utils.py +++ b/tests/pipeline/llama3-8B/utils.py @@ -29,6 +29,9 @@ class ParamConfig: evaluation_param = config_file["EVALUATION_PARAM"] auxiliary_param = config_file["AUXILIARY_PARAM"] + inference_hf_chat_param = config_file["INFERENCE_HF_CHAT_PARAM"] + inference_prompt_chat_param = config_file["INFERENCE_PROMPT_CHAT_PARAM"] + def assert_judge(expression): if not expression: diff --git a/tests/pipeline/qwen-7B/param_config.json b/tests/pipeline/qwen-7B/param_config.json index 20a89bf8d..12f7ac0e7 100644 --- a/tests/pipeline/qwen-7B/param_config.json +++ b/tests/pipeline/qwen-7B/param_config.json @@ -52,8 +52,12 @@ "--data-path", "/home/dataset/tune-dataset-qwen-7B/alpaca", "--split", "90,5,5", "--train-iters", "5" - ], + ], + "DISTRIBUTED_PARAM_TP8_PP1": [ + "--tensor-model-parallel-size", "8", + "--pipeline-model-parallel-size", "1" + ], "PROCESS_INSTRUCTION_DATA": [ "--input", "train-00000-of-00001-a09b74b3ef9c3b56, alpaca_zh, sharegpt1, sharegpt2", @@ -64,7 +68,7 @@ "--workers", "4", "--log-interval", "1000", "--append-eod", - "--lla-fact-ins-template", "qwen", + "--prompt-type", "qwen", "--dataset-dir", "/home/dataset/tune-dataset-qwen-7B/lfhandler_tune_dataset/dataset/", "--overwrite-cache" ], @@ -79,7 +83,7 @@ "--workers", "4", "--log-interval", "1000", "--append-eod", - "--lla-fact-ins-template", "qwen", + "--prompt-type", "qwen", "--dataset-dir", "/home/dataset/tune-dataset-qwen-7B/lfhandler_tune_dataset/dataset/", "--overwrite-cache", "--interleave-probs", "0.1, 0.2, 0.3, 0.4", @@ -96,11 +100,52 @@ "--workers", "4", "--log-interval", "1000", "--append-eod", - "--lla-fact-ins-template", "qwen", + "--prompt-type", "qwen", "--dataset-dir", "/home/dataset/tune-dataset-qwen-7B/lfhandler_tune_dataset/dataset/", "--overwrite-cache", "--interleave-probs", "0.1, 0.2, 0.3, 0.4", "--mix-strategy", "interleave_over", "--max-samples", "10" + ], + + + "INFERENCE_PARAM": [ + "--max-new-tokens", "256", + "--tokenizer-not-use-fast", + "--exit-on-missing-checkpoint", + "--attention-softmax-in-fp32", + "--prompt-type", "qwen", + "--seed", "42", + "--load", "/home/dataset/Qwen-7B-v0.1-tp8-pp1/" + ], + + + "BEAM_SEARCH_AUXILIARY_PARAM": [ + "--task", "beam_search", + "--top-p", "0.95", + "--top-k", "50" + ], + + "GREEDY_SEARCH_AUXILIARY_PARAM": [ + "--task", "greedy" + ], + + "DO_SAMPLE_AUXILIARY_PARAM": [ + "--task", "do_sample", + "--top-p", "0.95", + "--top-k", "50" + ], + + "BEAM_SEARCH_WITH_SAMPLING_AUXILIARY_PARAM": [ + "--task", "beam_search_with_sampling", + "--top-p", "0.95", + "--top-k", "50" + ], + + "RETURN_OUTPUT_LOG_PROBS_AUXILIARY_PARAM": [ + "--task", "return_output_log_probs", + "--temperature 0.6", + "--top-p", "0.95", + "--top-k", "50" ] } \ No newline at end of file diff --git a/tests/pipeline/qwen-7B/qwen-7B.sh b/tests/pipeline/qwen-7B/qwen-7B.sh index 453443891..4ed9016b8 100644 --- a/tests/pipeline/qwen-7B/qwen-7B.sh +++ b/tests/pipeline/qwen-7B/qwen-7B.sh @@ -1,4 +1,6 @@ # Provide uniform access for piepline. pytest -s ./tests/pipeline/qwen-7B/test_instruction.py -pytest -s ./tests/pipeline/qwen-7B/test_process_instruction_data.py \ No newline at end of file +pytest -s ./tests/pipeline/qwen-7B/test_process_instruction_data.py +pytest -s ./tests/pipeline/qwen-7B/test_generation.py +pytest -s ./tests/pipeline/qwen-7B/test_generation2.py \ No newline at end of file diff --git a/tests/pipeline/qwen-7B/test_generation.py b/tests/pipeline/qwen-7B/test_generation.py new file mode 100644 index 000000000..db40cd61f --- /dev/null +++ b/tests/pipeline/qwen-7B/test_generation.py @@ -0,0 +1,141 @@ +import sys +import os +import torch +import nltk +from tests.common import DistributedTest +from utils import ParamConfig, assert_judge +import modellink +from megatron.legacy.model import GPTModel +from megatron.training import get_args, get_tokenizer +from megatron.training.initialize import initialize_megatron +from modellink.tasks.inference.text_generation.infer_base import add_text_generate_args + + +class TestGeneration(DistributedTest): + world_size = 8 + + def init(self, config=ParamConfig, task=None): + """ + initialize the environment and arguments + """ + sys.argv = [sys.argv[0]] + config.distributed_param_tp8_pp1 + config.network_size + \ + config.inference_param + config.beam_search_auxliary_param + config.auxiliary_param + config.tokenizer_param + + if task == "beam_search_with_sampling": + sys.argv = sys.argv + config.beam_search_with_sampling_auxliary_param + elif task == "return_output_log_probs": + sys.argv = sys.argv + config.return_output_log_probs_auxliary_param + + os.environ.update({"CUDA_DEVICE_MAX_CONNECTIONS": "1"}) + initialize_megatron(extra_args_provider=add_text_generate_args, + args_defaults={'no_load_rng': True, + 'no_load_optim': True}) + self.args = get_args() + + + def edit_distance_similarity(self, text1, text2): + """ + edit distance: to compare the similarity between two texts. + """ + distance = nltk.edit_distance(text1, text2) + try: + similarity = 1 - (distance / max(len(text1), len(text2))) + except ZeroDivisionError as e: + raise e + return similarity + + + def test_beam_search_with_sampling(self): + """Beam Search with sampling""" + self.init(config=ParamConfig, task="beam_search_with_sampling") + from inference import model_provider + model = GPTModel.from_pretrained( + model_provider=model_provider, + pretrained_model_name_or_path=self.args.load + ) + + instruction = "Give me three tips for staying healthy." + + output = model.generate( + instruction, + num_beams=2, + do_sample=True, + top_k=self.args.top_k, + top_p=self.args.top_p, + max_new_tokens=self.args.max_new_tokens, + tokenizer=None, + stream=False + ) + + expect_output1 = '''1. Get enough sleep. A good night's sleep is important for your physical and mental health.\n2. Eat a balanced diet. Eating a variety of healthy foods can help you get the nutrients your body needs.\n3. Exercise regularly. Exercise can help you maintain a healthy weight, reduce stress, and improve your overall health.''' + + expect_output2 = '''Sure, here are three tips for staying healthy:\n1. Eat a balanced diet that includes fruits, vegetables, whole grains, and lean proteins.\n2. Get regular exercise, such as going for a walk or doing yoga.\n3. Get enough sleep each night, ideally 7-8 hours.''' + + if torch.distributed.get_rank() == 0: + print(output) + tokenizer = get_tokenizer() + + similarity1 = self.edit_distance_similarity(output[:30], expect_output1[:30]) + similarity2 = self.edit_distance_similarity(output[:30], expect_output2[:30]) + print("similarity1:", similarity1) + print("similarity1:", similarity2) + assert_judge(max(similarity1, similarity2) > 0.75) + + + def test_return_output_log_probs(self): + """Returns the probability distribution of tokens""" + self.init(config=ParamConfig, task="return_output_log_probs") + from inference import model_provider + model = GPTModel.from_pretrained( + model_provider=model_provider, + pretrained_model_name_or_path=self.args.load + ) + + instruction = "What is the whether like today?" + + output1, log_probs = model.generate( + instruction, + do_sample=True, + top_k=self.args.top_k, + top_p=self.args.top_p, + temperature=self.args.temperature, + max_new_tokens=self.args.max_new_tokens, + tokenizer=None, + stream=False, + detokenize=False, + return_output_log_probs=True + ) + + if torch.distributed.get_rank() == 0: + tokenizer = get_tokenizer() + print("--------------output1-------------") + print(output1) + print(tokenizer.decode(output1)) + + expected_output1 = [2132, 686, 6761, 389, 1380, 498, 525, 304, 279, 1879, + 13, 576, 9104, 646, 387, 2155, 304, 2155, 7482, 624, + 872, 198, 3838, 374, 279, 9104, 1075, 304, 7148, 5267, + 77091, 198, 785, 9104, 304, 7148, 3351, 374, 39698, 323] + + expected_output1_ext = [2132, 686, 6761, 389, 1380, 498, 525, 7407, 13, 16503, + 498, 3291, 752, 697, 3728, 5267, 872, 198, 29596, 11902, + 198, 77091, 198, 641, 9656, 11902, 11, 432, 594, 39698, + 3351, 13, 576, 9315, 374, 220, 23, 15, 12348, 68723] + expected_output1_ext2 = [2132, 374, 83253, 16916, 3351, 382, 77091, 198, 3838, 374, + 279, 9104, 1075, 3351, 5267, 2610, 525, 264, 10950, 17847, + 13, 279, 198, 3838, 374, 279, 9104, 1075, 3351, 5267, + 2610, 525, 264, 10950, 17847, 13, 279, 198, 3838, 374] + print("--------------log_probs----------------") + print(log_probs.shape) + assert_judge(log_probs.shape[0] == 256) + assert_judge(log_probs.shape[1] == 151936) + + similarity = torch.nn.CosineSimilarity(dim=1) + cos_sim = similarity(torch.tensor(expected_output1[:40]).unsqueeze(0).float().npu(), + output1[:40].unsqueeze(0).float()) + cos_sim = max(cos_sim, similarity(torch.tensor(expected_output1_ext[:40]).unsqueeze(0).float().npu(), + output1[:40].unsqueeze(0).float())) + cos_sim = max(cos_sim, similarity(torch.tensor(expected_output1_ext2[:40]).unsqueeze(0).float().npu(), + output1[:40].unsqueeze(0).float())) + print("similarity1: ", cos_sim) + assert_judge(cos_sim > 0.75) \ No newline at end of file diff --git a/tests/pipeline/qwen-7B/test_generation2.py b/tests/pipeline/qwen-7B/test_generation2.py new file mode 100644 index 000000000..7d86c110b --- /dev/null +++ b/tests/pipeline/qwen-7B/test_generation2.py @@ -0,0 +1,165 @@ +import sys +import os +import torch +import nltk +from tests.common import DistributedTest +from utils import ParamConfig, assert_judge +import modellink +from megatron.legacy.model import GPTModel +from megatron.training import get_args, get_tokenizer +from megatron.training.initialize import initialize_megatron +from modellink.tasks.inference.text_generation.infer_base import add_text_generate_args + + +class TestGeneration(DistributedTest): + world_size = 8 + + def init(self, config=ParamConfig, task=None): + """ + initialize the environment and arguments + """ + sys.argv = [sys.argv[0]] + config.distributed_param_tp8_pp1 + config.network_size + \ + config.inference_param + config.beam_search_auxliary_param + config.auxiliary_param + config.tokenizer_param + + if task == "beam_search": + sys.argv = sys.argv + config.beam_search_auxliary_param + elif task == "greedy": + sys.argv = sys.argv + config.greedy_search_auxliary_param + elif task == "do_sample": + sys.argv = sys.argv + config.do_sample_auxliary_param + + os.environ.update({"CUDA_DEVICE_MAX_CONNECTIONS": "1"}) + initialize_megatron(extra_args_provider=add_text_generate_args, + args_defaults={'no_load_rng': True, + 'no_load_optim': True}) + self.args = get_args() + + + def test_beam_search(self): + """ + load weight to get model and construct the prompts to generate output, + and compare with expected for `beam search`. + """ + self.init(config=ParamConfig, task="beam_search") + from inference import model_provider + model = GPTModel.from_pretrained( + model_provider=model_provider, + pretrained_model_name_or_path=self.args.load + ) + + max_new_tokens = self.args.max_new_tokens + instruction = "如何提高身体素质" + output = model.generate( + instruction, + num_beams=2, + top_k=self.args.top_k, + top_p=self.args.top_p, + max_new_tokens=max_new_tokens, + tokenizer=None, + stream=False, + detokenize=False + ) + + if torch.distributed.get_rank() == 0: + print("----------------------output-------------------------") + print(output) + expected_output1 = [100627, 101099, 100838, 104339, 101194, 3837, 87752, 99639, 6684, 31338, + 96422, 28311, 16, 13, 4891, 251, 248, 68878, 101079, 5122, + 106854, 104102, 71817, 16, 20, 15, 83031, 9370, 15946, 49567, + 102660, 18830, 100316, 101079, 3837, 29524, 99234, 99314, 5373, 107530] + + expected_output2 = [30534, 100627, 101099, 100838, 3837, 73670, 103975, 87752, 101082, 28311, + 16, 13, 4891, 223, 98, 99446, 104579, 5122, 101907, 109635, + 103170, 107151, 5373, 100912, 52510, 116570, 5373, 105349, 5373, 105373, + 33108, 117094, 49567, 102100, 101252, 3837, 101153, 44636, 108461, 5373] + + similarity = torch.nn.CosineSimilarity(dim=1) + cos_sim = similarity(torch.tensor(expected_output1).unsqueeze(0).float().npu(), + output[:40].unsqueeze(0).float()) + cos_sim = max(cos_sim, similarity(torch.tensor(expected_output2).unsqueeze(0).float().npu(), + output[:40].unsqueeze(0).float())) + print("similarity: ", cos_sim) + assert_judge(cos_sim > 0.85) + + + def edit_distance_similarity(self, text1, text2): + """ + edit distance: to compare the similarity between two texts. + """ + distance = nltk.edit_distance(text1, text2) + try: + similarity = 1 - (distance / max(len(text1), len(text2))) + except ZeroDivisionError as e: + raise e + return similarity + + + def test_greedy_search(self): + """ + load weight to get model and construct the prompts to generate output, + and compare with expected for `greedy search`. + """ + self.init(config=ParamConfig, task="greedy") + from inference import model_provider + model = GPTModel.from_pretrained( + model_provider=model_provider, + pretrained_model_name_or_path=self.args.load + ) + + instruction = ["What are the characteristics of Suzhou?", "Introducing the Forbidden City in Beijing."] + output = model.generate(instruction) + expect_output1 = [ + "Suzhou is a city in China. It is known for its beautiful gardens, canals, and classical Chinese architecture. It is also known for its silk production and traditional arts and crafts. The city has a rich cultural heritage and is home to many historic temples and museums. Additionally, Suzhou is known for its cuisine, which features local specialties such as sweet and sour fish and rice cakes." + ] + expect_output2 = [ + 'The Forbidden City is a palace complex in Beijing, China. It was the home of the emperors of China for almost 500 years, from the Ming Dynasty to the end of the Qing Dynasty. The complex covers an area of 72 hectares and has over 9,000 rooms. It is a UNESCO World Heritage Site and one of the most popular tourist attractions in China..' + ] + + expect_output1_seq = "".join(expect_output1) + expect_output2_seq = ''.join(expect_output2) + + if torch.distributed.get_rank() == 0: + print("----------------------output1-------------------------") + print(output[0]) + print("----------------------output2-------------------------") + print(output[1]) + + similarity1 = self.edit_distance_similarity(output[0][:30], expect_output1_seq[:30]) + similarity2 = self.edit_distance_similarity(output[1][:30], expect_output2_seq[:30]) + print("similarity1:", similarity1) + print("similarity2:", similarity2) + assert_judge(similarity1 > 0.85) + assert_judge(similarity2 > 0.85) + + + def test_do_sample(self): + """Do Sample""" + self.init(config=ParamConfig, task="do_sample") + from inference import model_provider + model = GPTModel.from_pretrained( + model_provider=model_provider, + pretrained_model_name_or_path=self.args.load + ) + + instruction = "what is Disneyland?" + + output = model.generate( + [instruction, instruction], + do_sample=True, + top_k=self.args.top_k, + top_p=self.args.top_p, + max_new_tokens=self.args.max_new_tokens, + tokenizer=None, + stream=False + ) + + expect_output1 = "Disneyland Park is an entertainment park located in Anaheim, California, United States. It is owned by the Disney Parks, Experiences and Consumer Products division of the American multinational conglomerate corporation the Walt Disney Company. It is also the first of seven theme parks built at Walt Disney's original vision, where visitors can enjoy various attractions, entertainment, and dining." + expect_output1_seq = "".join(expect_output1) + + if torch.distributed.get_rank() == 0: + print(output) + tokenizer = get_tokenizer() + + similarity1 = self.edit_distance_similarity(output[0][:30], expect_output1_seq[:30]) + print("similarity1:", similarity1) + assert_judge(similarity1 > 0.85) \ No newline at end of file diff --git a/tests/pipeline/qwen-7B/utils.py b/tests/pipeline/qwen-7B/utils.py index ecbe964ea..0585a86db 100644 --- a/tests/pipeline/qwen-7B/utils.py +++ b/tests/pipeline/qwen-7B/utils.py @@ -28,13 +28,24 @@ class ParamConfig: network_size = config_file["NETWORK_SIZE"] tokenizer_param = config_file["TOKENIZER_PARAM"] distributed_param = config_file["DISTRIBUTED_PARAM"] + distributed_param_tp8_pp1 = config_file["DISTRIBUTED_PARAM_TP8_PP1"] auxiliary_param = config_file["AUXILIARY_PARAM"] instruction_param = config_file["INSTRUCTION_PARAM"] output_param = config_file["OUTPUT_PARAM"] + + # prepreocess instruction data instruction_data_param = config_file["PROCESS_INSTRUCTION_DATA"] instruction_data_mix_param1 = config_file["PROCESS_INSTRUCTION_DATA_MIX1"] instruction_data_mix_param2 = config_file["PROCESS_INSTRUCTION_DATA_MIX2"] + # inference + inference_param = config_file["INFERENCE_PARAM"] + beam_search_auxliary_param = config_file["BEAM_SEARCH_AUXILIARY_PARAM"] + greedy_search_auxliary_param = config_file["GREEDY_SEARCH_AUXILIARY_PARAM"] + do_sample_auxliary_param = config_file["DO_SAMPLE_AUXILIARY_PARAM"] + beam_search_with_sampling_auxliary_param = config_file["BEAM_SEARCH_WITH_SAMPLING_AUXILIARY_PARAM"] + return_output_log_probs_auxliary_param = config_file["RETURN_OUTPUT_LOG_PROBS_AUXILIARY_PARAM"] + def assert_judge(expression): if not expression: diff --git a/tests/st/test_gemma_inference_ptd.sh b/tests/st/test_gemma_inference_ptd.sh index 35b051424..aa057e627 100644 --- a/tests/st/test_gemma_inference_ptd.sh +++ b/tests/st/test_gemma_inference_ptd.sh @@ -36,6 +36,7 @@ python3.8 -m torch.distributed.launch $DISTRIBUTED_ARGS ${basepath}/inference.py --geglu \ --input-embeds-norm \ --fp16 \ + --prompt-type gemma \ --micro-batch-size 1 \ --seq-length 8192 \ --max-new-tokens 64 \ diff --git a/tests/st/test_inference.json b/tests/st/test_inference.json index d7addf8e9..56efbe044 100644 --- a/tests/st/test_inference.json +++ b/tests/st/test_inference.json @@ -34,7 +34,7 @@ "lora-r":16, "lora-alpha":32, "lora-target-modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], - "inference-prompt-type": "'alpaca'", + "prompt-type": "'alpaca'", "bf16":null, "task":"greedy" } diff --git a/tests/st/test_llama_inference_ptd.sh b/tests/st/test_llama_inference_ptd.sh index 0f05780e3..3948e39ac 100644 --- a/tests/st/test_llama_inference_ptd.sh +++ b/tests/st/test_llama_inference_ptd.sh @@ -42,6 +42,7 @@ python3.8 -m torch.distributed.launch $DISTRIBUTED_ARGS ${basepath}/inference.py --normalization RMSNorm \ --untie-embeddings-and-output-weights \ --disable-bias-linear \ + --prompt-type llama2 \ --attention-softmax-in-fp32 \ --no-masked-softmax-fusion \ --no-gradient-accumulation-fusion \ diff --git a/tests/st/test_llama_instruction_ptd.sh b/tests/st/test_llama_instruction_ptd.sh index 688607ad5..4a4ec4918 100644 --- a/tests/st/test_llama_instruction_ptd.sh +++ b/tests/st/test_llama_instruction_ptd.sh @@ -28,7 +28,7 @@ export PYTHONPATH=${basepath}:$PYTHONPATH python3.8 -m torch.distributed.launch $DISTRIBUTED_ARGS \ ${basepath}/pretrain_gpt.py \ --tensor-model-parallel-size 2 \ - --pipeline-model-parallel-size 2 \ + --pipeline-model-parallel-size 4 \ --sequence-parallel \ --num-layers 4 \ --hidden-size 4096 \ diff --git a/tests/ut/process_data/test_process_instruction_data_lf.py b/tests/ut/process_data/test_process_instruction_data_lf.py index c9dec7844..3bb7425c8 100644 --- a/tests/ut/process_data/test_process_instruction_data_lf.py +++ b/tests/ut/process_data/test_process_instruction_data_lf.py @@ -26,7 +26,7 @@ class TestProcessInstructionDataLf: "--tokenizer-name-or-path", "/data/qwen-7b/", "--workers", "4", "--log-interval", "1000", - "--lla-fact-ins-template", "qwen" + "--prompt-type", "qwen" ] self.args = get_args() self.tokenizer = build_tokenizer(self.args) @@ -44,7 +44,7 @@ class TestProcessInstructionDataLf: "--tokenizer-name-or-path", "/data/qwen-7b/", "--workers", "4", "--log-interval", "1000", - "--lla-fact-ins-template", "qwen", + "--prompt-type", "qwen", "--map-keys", '{"history":"history"}' ] self.args = get_args() @@ -61,7 +61,7 @@ class TestProcessInstructionDataLf: "--tokenizer-name-or-path", "/data/qwen-7b/", "--workers", "4", "--log-interval", "1000", - "--lla-fact-ins-template", "qwen", + "--prompt-type", "qwen", "--map-keys", '{"system":"system_prompt"}' ] @@ -79,7 +79,7 @@ class TestProcessInstructionDataLf: "--tokenizer-name-or-path", "/data/qwen-7b/", "--workers", "4", "--log-interval", "1000", - "--lla-fact-ins-template", "qwen", + "--prompt-type", "qwen", "--map-keys", '{"messages":"messages", "tags":{"role_tag": "role","content_tag": "content","user_tag": "user","assistant_tag": "assistant","system_tag": "system"} }' ]