mirror of https://gitee.com/ascend/ModelLink.git
synced 2024-12-05 05:17:40 +08:00
!223 (PART 3) Supplement ST test cases and examples for the online inference framework
Merge pull request !223 from shishaoyu/master
This commit is contained in:
commit 8e06ed062f
176 README.md
@@ -324,9 +324,9 @@ For llama and other LLMs without bias in FFN, the linear transformation in FFN c
<tbody>
<tr>
<td> Baichuan </td>
<td> 7B </td>
<td> -- </td>
<td> 13B </td>
<td> -- </td>
<td> <a href="https://gitee.com/ascend/AscendSpeed/tree/master/examples/baichuan/generate_baichuan_13B_tp8_pp1.sh">inference</a> </td>
<td> -- </td>
<td> -- </td>
</tr>
@@ -357,28 +357,28 @@ For llama and other LLMs without bias in FFN, the linear transformation in FFN c
<td rowspan="4">LLaMA</td>
<td>7B</td>
<td> <a href="https://gitee.com/ascend/AscendSpeed/tree/master/examples/alpaca/finetune_chinese_llama_alpaca_7_13_33b_tp4_pp2.sh">lora</a> </td>
<td> -- </td>
<td> <a href="https://gitee.com/ascend/AscendSpeed/tree/master/examples/llama/generate_llama_7B_tp2_pp2.sh">inference </a> </td>
<td> -- </td>
<td> <a href="https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json">alpaca_data.json</a> </td>
</tr>
<tr>
<td>13B</td>
<td> <a href="https://gitee.com/ascend/AscendSpeed/tree/master/examples/alpaca/finetune_chinese_llama_alpaca_7_13_33b_tp4_pp2.sh">lora</a> </td>
<td> -- </td>
<td> <a href="https://gitee.com/ascend/AscendSpeed/tree/master/examples/llama/generate_llama_13B_tp8_pp1.sh">inference </a> </td>
<td> -- </td>
<td> <a href="https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json">alpaca_data.json</a> </td>
</tr>
<tr>
<td>33B</td>
<td> <a href="https://gitee.com/ascend/AscendSpeed/tree/master/examples/alpaca/finetune_chinese_llama_alpaca_7_13_33b_tp4_pp2.sh">lora</a> </td>
<td> -- </td>
<td> <a href="https://gitee.com/ascend/AscendSpeed/tree/master/examples/llama/generate_llama_33B_tp8_pp1.sh">inference </a> </td>
<td> -- </td>
<td> <a href="https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json">alpaca_data.json</a> </td>
</tr>
<tr>
<td > 65B </td>
<td > -- </td>
<td> -- </td>
<td> <a href="https://gitee.com/ascend/AscendSpeed/tree/master/examples/llama/generate_llama_65B_tp8_pp1.sh">inference </a> </td>
<td> -- </td>
<td> -- </td>
</tr>
@@ -386,7 +386,7 @@ For llama and other LLMs without bias in FFN, the linear transformation in FFN c
<td>LLaMA2</td>
<td>7B</td>
<td> -- </td>
<td> -- </td>
<td> <a href="https://gitee.com/ascend/AscendSpeed/tree/master/examples/llama/generate_llama_7B_tp2_pp2.sh">inference </a> </td>
<td> -- </td>
<td> -- </td>
</tr>
@@ -395,6 +395,7 @@ For llama and other LLMs without bias in FFN, the linear transformation in FFN c

### <span id="jump11"> Dataset Processing </span>
#### Quick Start
@@ -507,7 +508,15 @@ In addition, `BelleMultiTurnInstructionHandler` is used to handle [belle dataset
### <span id="jump12"> Finetune </span>
#### Lora

Now, we support Lora to fine-tune your models. You just need to add this argument in your script to open Lora:
Now, we support Lora to fine-tune your models.

First, you need to install version 0.4.0 of the peft library, like this:
```shell
pip install peft==0.4.0
```
You can also choose to install from [the source package in the GitHub repository](https://github.com/huggingface/peft/archive/refs/tags/v0.4.0.tar.gz), so you can modify the setup.py file to avoid some dependency issues.
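If you go the source route, a minimal sketch of the steps might look like this (the tarball URL is the one linked above; editing setup.py is only needed if you hit dependency conflicts):

```shell
# Build peft 0.4.0 from the source tarball instead of PyPI
wget https://github.com/huggingface/peft/archive/refs/tags/v0.4.0.tar.gz
tar -xzf v0.4.0.tar.gz
cd peft-0.4.0
# (optional) edit setup.py here to relax pinned dependencies
pip install .
```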
Next, you just need to add this argument in your script to enable Lora:

```shell
# Llama example
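For reference, the Lora switches used by the lora scripts elsewhere in this page (e.g. examples/llama/generate_llama_7B_lora_tp1_pp1.sh shown further down) look like this sketch:

```shell
# Lora flags as used by the 7B lora inference script later in this merge request;
# adjust rank, alpha and target modules to your own fine-tuning setup.
--lora-r 16 \
--lora-alpha 32 \
--lora-target-modules query_key_value dense gate_proj up_proj down_proj \
```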
@@ -550,8 +559,8 @@ AscendSpeed:
### <span id="jump13"> Inference </span>
Currently, we support the following four strategies for inference:
- PTD only
- Deepspeed ZeRO only
- Deepspeed ZeRO in PIPELINE with TP
- DeepSpeed ZeRO only
- DeepSpeed ZeRO in PIPELINE with TP
- Model fine-tuned with lora

#### Quick Start
@@ -559,40 +568,47 @@ Here are some example scripts in different mode mentioned above for you to launc

***Please Note that:***
1. If you want to use the weight from huggingface, please run the weight conversion script first.
Take Llama-7B, for example:
```bash
python tools/ckpt_convert/llama/convert_weights_from_huggingface.py --input-model-dir llama-7b-hf \
    --output-model-dir llama-7b-tp2-pp2 \
    --tensor-model-parallel-size 2 \
    --pipeline-model-parallel-size 2 \
    --type 7B
```
Here are some open source model weights available for download:
- [Llama-7B](https://huggingface.co/yahma/llama-7b-hf/tree/main)
- [Llama-13B](https://huggingface.co/yahma/llama-13b-hf/tree/main)
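For example, the Llama-7B weights listed above can be fetched with git-lfs before running the conversion (the repository name is the one linked above):

```shell
# Clone the Llama-7B Hugging Face repository listed above (requires git-lfs)
git lfs install
git clone https://huggingface.co/yahma/llama-7b-hf
```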
Take Llama-7B, for example:

- PTD only

```bash
python tools/ckpt_convert/llama/convert_weights_from_huggingface.py --input-model-dir llama-7b-hf \
    --output-model-dir llama-7b-tp2-pp2 \
    --tensor-model-parallel-size 2 \
    --pipeline-model-parallel-size 2 \
    --type 7B
```

- DeepSpeed ZeRO only
```bash
python tools/ckpt_convert/llama/convert_weights_from_huggingface.py --input-model-dir llama-7b-hf \
    --output-model-dir llama-7b-deepspeed \
    --type 7B \
    --deepspeed
```

2. You need to modify some variables in the shell script such as **model weight path** and **vocab path**.
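For example, each generate_*.sh script below exposes these paths as variables near the top of the file; point them at your converted checkpoint and tokenizer:

```shell
# Placeholders exactly as they appear in the example scripts in this repository
CHECKPOINT="your megatron checkpoint path"
VOCAB_FILE="your vocab path"
```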
- **PTD only:** In this mode, the model is split by pipeline parallel and tensor parallel mode in megatron ways.
```bash
sh examples/llama/generate_llama_7B_tp2_pp2.sh
```
```bash
sh examples/llama/generate_llama_7B_tp2_pp2.sh
```
- **Deepspeed ZeRO only:** In this mode, the model uses DeepSpeed ZeRO 1, 2 or 3 definition with tp=1, pp=1.
```bash
sh examples/llama/generate_alpaca_13B_deepspeed.sh
```
```bash
sh examples/alpaca/generate_alpaca_13B_deepspeed.sh
```
- **Deepspeed ZeRO in Pipe with TP:** In this mode, the model uses pipe model definition in DeepSpeed ZeRO 1, 2 or 3 with tp>1, pp=1.
```bash
sh examples/llama/generate_llama_7B_deepspeed_pipeline.sh
```
```bash
sh examples/llama/generate_llama_7B_deepspeed_pipeline.sh
```
- **If you want to use a lora model**, for details, refer to:
```bash
sh examples/llama/generate_alpaca_13B_lora_deepspeed.sh
```
```bash
sh examples/alpaca/generate_alpaca_13B_lora_deepspeed.sh
```

***Some examples with [Chinese-LLaMA-Alpaca-13B weights](https://github.com/ymcui/Chinese-LLaMA-Alpaca) are shown [here](#case1)***

#### Usage Guide
Follow these steps to write your own inference code:
@@ -638,53 +654,53 @@ pretrained_model_name_or_path(`str`, *optional*, defaults to None):
```
##### <span id="case1"> Generate text in HuggingFace-like ways </span>

###### Greedy
```python
responses = model.generate(
    "Write quick sort code in python",
    max_new_tokens=512
)
```
<img src="sources/images/greedy.png">
- Greedy Search
```python
responses = model.generate(
    "Write quick sort code in python",
    max_new_tokens=512
)
```
<img src="sources/images/greedy.png">

###### Do sample with Top-k and Top-p
```python
responses = model.generate(
    "Write quick sort code in python",
    do_sample=True,
    temperature=1.0,
    top_k=50,
    top_p=0.95,
    max_new_tokens=512
)
```
<img src="sources/images/sampling.png">
- Do sample with top-k and top-p
```python
responses = model.generate(
    "Write quick sort code in python",
    do_sample=True,
    temperature=1.0,
    top_k=50,
    top_p=0.95,
    max_new_tokens=512
)
```
<img src="sources/images/sampling.png">

###### Beam search with Top-k and Top-p
```python
responses = model.generate(
    "Write quick sort code in python",
    num_beams=4,
    top_k=50,
    top_p=0.95,
    max_new_tokens=512
)
```
<img src="sources/images/beam_search.png">
- Beam search with top-k and top-p
```python
responses = model.generate(
    "Write quick sort code in python",
    num_beams=4,
    top_k=50,
    top_p=0.95,
    max_new_tokens=512
)
```
<img src="sources/images/beam_search.png">

###### Beam search with Top-k and Top-p sampling
```python
responses = model.generate(
    "Write quick sort code in python",
    do_sample=True,
    temperature=0.6,
    num_beams=4,
    top_k=50,
    top_p=0.95,
    max_new_tokens=512
)
```
<img src="sources/images/beam_search_sampling.png">
- Beam search with top-k and top-p sampling
```python
responses = model.generate(
    "Write quick sort code in python",
    do_sample=True,
    temperature=0.6,
    num_beams=4,
    top_k=50,
    top_p=0.95,
    max_new_tokens=512
)
```
<img src="sources/images/beam_search_sampling.png">

## <span id="jump14"> Evaluation with Benchmarks </span>
### Quick Show
@@ -642,7 +642,7 @@ class LlamaParallelAttention(MegatronModule):
        value_layer = value_layer.permute(1, 2, 0, 3).contiguous()

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = pse[:output_size[0] * output_size[1], :, :output_size[3]]
        matmul_result = pse[:, :output_size[3]]
        # Raw attention scores. [b * np, sq, sk]
        q_trans = query_layer.transpose(0, 1).contiguous()
        k_trans = key_layer.transpose(0, 1).transpose(1, 2).contiguous()
@@ -777,8 +777,7 @@ class LlamaParallelTransformerLayer(MegatronModule):
            tp_index = parallel_state.get_tensor_model_parallel_rank()
            alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index]

        alibi = alibi.repeat(batch_size, 1, 1)
        return alibi
        return alibi[0]

    def forward(self, hidden_states, attention_mask=None,
                layer_past=None, get_key_value=False):
@@ -235,7 +235,7 @@ class MegatronModuleForCausalLMABC(torch.nn.Module, abc.ABC):
        self.pad_token_id = None
        self.num_return_sequences = 1
        self.length_penalty = 1.0
        self.tokenizer = None
        self.tokenizer_new = None
        self.recompute = True
        self.detokenize = True
        self.include_input = False
@@ -334,7 +334,7 @@ class MegatronModuleForCausalLMABC(torch.nn.Module, abc.ABC):
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.tokenizer = kwargs.pop("tokenizer", None)
        self.tokenizer_new = kwargs.pop("tokenizer", None)
        self.recompute = kwargs.pop("recompute", True)
        self.detokenize = kwargs.pop("detokenize", True)
        self.include_input = kwargs.pop("include_input", False)
@@ -362,7 +362,7 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
        self.padded_vocab_size = args.padded_vocab_size
        self.pipeline_size_larger_than_one = args.pipeline_model_parallel_size > 1

        self.tokenizer_ori = get_tokenizer().tokenizer
        self.tokenizer = get_tokenizer().tokenizer

        # import module to avoid error of circular import
        self.greedy_search_or_sampling = greedy_search_or_sampling
@@ -570,9 +570,12 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
        return self._check_output(output, self.stream)

    def _init_tokenizer(self, args):
        if self.tokenizer is None:
        if self.tokenizer_new is None:
            self.tokenizer = ascendspeed.global_vars.rebuild_tokenizer(
                args, tokenizer=self.tokenizer_ori)
                args, tokenizer=self.tokenizer)
        else:
            self.tokenizer = ascendspeed.global_vars.rebuild_tokenizer(
                args, tokenizer=self.tokenizer_new)

        if self.pad_token_id is not None:
            self.tokenizer.pad_token_id = self.pad_token_id
@@ -156,9 +156,6 @@ def _with_pipelining_forward_step(model, inputs, inference_params, micro_batch_s
        (batch_size, sequence_length, args.padded_vocab_size),
        dtype=torch.float32, device=torch.cuda.current_device())

    # Preallocate recv buffer.
    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)

    for micro_batch_index in range(num_micro_batches):
        # Slice among the batch dimension.
        start = micro_batch_index * micro_batch_size
@@ -167,10 +164,6 @@ def _with_pipelining_forward_step(model, inputs, inference_params, micro_batch_s
        tokens2use = tokens[start:end, ...]
        position_ids2use = position_ids[start:end, ...]

        # Run a simple forward pass.
        if this_micro_batch_size != micro_batch_size:
            recv_buffer = None

        output = _forward_step_helper(model,
                                      tokens2use,
                                      position_ids=position_ids2use,
@@ -36,11 +36,14 @@ class ST_Test:
            TEST_DIR, st_dir, llama_dir, "test_llama_ptd.sh")
        lora_shell_file = os.path.join(
            TEST_DIR, st_dir, llama_dir, "test_lora_llama_ptd.sh")
        llama_inference_shell_file = os.path.join(
            TEST_DIR, st_dir, llama_dir, "test_llama_inference_ptd.sh")

        self.shell_file_list = [
            llama_inference_shell_file,
            llama_shell_file,
            bloom_shell_file,
            lora_shell_file
            lora_shell_file,
        ]

    def run_shell(self):
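The new ST case registered above ships as a standalone script in this merge request (tests/st/test_llama/test_llama_inference_ptd.sh, shown near the end of this page), so it can presumably also be launched directly outside the ST_Test runner:

```shell
# Run the new llama inference ST case on its own
# (the vocab path and parallel sizes are hard-coded inside the script)
sh tests/st/test_llama/test_llama_inference_ptd.sh
```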
@@ -45,8 +45,8 @@ deepspeed --num_nodes $NNODES --num_gpus $NPUS_PER_NODE \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 256 \
--max-new-tokens 64 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42 \
--deepspeed \
--deepspeed_config ${config_json} \

@@ -48,8 +48,8 @@ deepspeed --num_nodes $NNODES --num_gpus $NPUS_PER_NODE \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 256 \
--max-new-tokens 64 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42 \
--lora-r 16 \
--lora-alpha 32 \

@@ -22,7 +22,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/inference/inference
--pipeline-model-parallel-size 1 \
--num-layers 40 \
--hidden-size 5120 \
--ffn-hidden-size 13824 \
--ffn-hidden-size 13696 \
--load "${CHECKPOINT}" \
--num-attention-heads 40 \
--max-position-embeddings 2048 \
@@ -31,6 +31,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/inference/inference
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 512 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42
--position-embedding-type alibi \

@@ -34,8 +34,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/inference/inference
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 2048 \
--max-new-tokens 64 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42 \
--load "${CHECKPOINT}" \
--embed-layernorm \

@@ -49,8 +49,8 @@ deepspeed --num_nodes $NNODES --num_gpus $NPUS_PER_NODE \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 256 \
--max-new-tokens 64 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42 \
--deepspeed \
--deepspeed_config ${config_json} \
36 examples/llama/generate_llama_13B_tp8_pp1.sh Normal file
@@ -0,0 +1,36 @@
#!/bin/bash
export TOKENIZERS_PARALLELISM=false

MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=8

DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"

CHECKPOINT="your megatron checkpoint path"
VOCAB_FILE="your vocab path"

python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/inference/inference_llama.py \
--no-contiguous-buffers-in-local-ddp \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 \
--num-layers 40 \
--hidden-size 5120 \
--ffn-hidden-size 13824 \
--load "${CHECKPOINT}" \
--num-attention-heads 40 \
--max-position-embeddings 2048 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path "$VOCAB_FILE" \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42
36 examples/llama/generate_llama_33B_tp8_pp1.sh Normal file
@@ -0,0 +1,36 @@
#!/bin/bash
export TOKENIZERS_PARALLELISM=false

MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=8

DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"

CHECKPOINT="your megatron checkpoint path"
VOCAB_FILE="your vocab path"

python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/inference/inference_llama.py \
--no-contiguous-buffers-in-local-ddp \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 \
--num-layers 60 \
--hidden-size 6656 \
--ffn-hidden-size 17920 \
--load "${CHECKPOINT}" \
--num-attention-heads 52 \
--max-position-embeddings 2048 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path "$VOCAB_FILE" \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42
36 examples/llama/generate_llama_65B_tp8_pp1.sh Normal file
@@ -0,0 +1,36 @@
#!/bin/bash
export TOKENIZERS_PARALLELISM=false

MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=8

DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"

CHECKPOINT="your megatron checkpoint path"
VOCAB_FILE="your vocab path"

python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/inference/inference_llama.py \
--no-contiguous-buffers-in-local-ddp \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 \
--num-layers 80 \
--hidden-size 8192 \
--ffn-hidden-size 22016 \
--load "${CHECKPOINT}" \
--num-attention-heads 64 \
--max-position-embeddings 2048 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path "$VOCAB_FILE" \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42
@@ -47,8 +47,8 @@ deepspeed --num_nodes $NNODES --num_gpus $NPUS_PER_NODE \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 256 \
--max-new-tokens 64 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42 \
--deepspeed \
--deepspeed_config ${config_json} \
42 examples/llama/generate_llama_7B_lora_tp1_pp1.sh Normal file
@@ -0,0 +1,42 @@
#!/bin/bash
export TOKENIZERS_PARALLELISM=false

MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=1

DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"

CHECKPOINT="your megatron checkpoint path"
LORA_CHECKPOINT="your lora checkpoint path"
VOCAB_FILE="your vocab path"

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
./tasks/inference/inference_llama.py \
--no-contiguous-buffers-in-local-ddp \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 32 \
--hidden-size 4096 \
--ffn-hidden-size 11008 \
--load "${CHECKPOINT}" \
--lora-load "${LORA_CHECKPOINT}" \
--num-attention-heads 32 \
--seq-length 1024 \
--max-position-embeddings 2048 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path "$VOCAB_FILE" \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--max-new-tokens 256 \
--seed 42 \
--lora-r 16 \
--lora-alpha 32 \
--lora-target-modules query_key_value dense gate_proj up_proj down_proj \
@@ -31,6 +31,6 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/inference/inference
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 256 \
--max-new-tokens 64 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42
36 tests/st/test_llama/test_llama_inference_ptd.sh Normal file
@@ -0,0 +1,36 @@
#!/bin/bash
export TOKENIZERS_PARALLELISM=false

MASTER_ADDR=localhost
MASTER_PORT=6661
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=8

DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"

VOCAB_FILE=/home/dataset/llama
basepath=$(cd `dirname $0`; cd ../../..; pwd)
export PYTHONPATH=${basepath}:$PYTHONPATH

python3 -m torch.distributed.launch $DISTRIBUTED_ARGS ${basepath}/tasks/inference/inference_llama.py \
--task 1 2 3 4 5 \
--tensor-model-parallel-size 2 \
--pipeline-model-parallel-size 4 \
--num-layers 32 \
--hidden-size 4096 \
--ffn-hidden-size 11008 \
--num-attention-heads 32 \
--max-position-embeddings 2048 \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path "$VOCAB_FILE" \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 256 \
--max-new-tokens 64 \
--seed 42