Commit the llama13b script and README file

This commit is contained in:
l00619700 2023-10-18 14:23:53 +08:00
parent 45434eead6
commit 76609dec8c
5 changed files with 285 additions and 16 deletions

View File

@ -52,6 +52,8 @@ conda activate test
pip install torch-1.11.0-cp37-cp37m-linux_aarch64.whl
pip install torch_npu-1.11.0.post4_XXXXXX-cp37-cp37m-linux_aarch64.whl
pip install apex-0.1_ascend_XXXXXX-cp37-cp37m-linux_aarch64.whl
# install megatron-core
pip3 install -e git+https://github.com/NVIDIA/Megatron-LM.git@23.05#egg=megatron-core
# install deepspeed and deepspeed_npu
pip install deepspeed==0.9.2
git clone https://gitee.com/ascend/DeepSpeed.git -b v0.9.2 deepspeed_npu
@ -61,6 +63,19 @@ cd ..
# install other packages
pip install -r requirements.txt
```
*Note: if you want to train with weights from Hugging Face, first fix a DeepSpeed checkpoint-loading bug by changing `if zero_sd_list is None` to `if zero_sd_list is None or len(zero_sd_list) == 0` in the `_load_zero_checkpoint` function of `<deepspeed-installed-path>/runtime/engine.py`.*
```python
# original deepspeed/runtime/engine.py, around lines 2746-2748
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
if zero_sd_list is None:
return False
# modified
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
if zero_sd_list is None or len(zero_sd_list) == 0:
return False
```
3. Download the InternLM-7B tokenizer model and related files from [here](https://huggingface.co/internlm/internlm-7b/tree/main)
```shell

View File

@ -1,3 +1,220 @@
# LLaMA-7B/13B
- [LLaMA-7B/13B](#llama-7b13b)
- [Training](#training)
- [Script](#script)
- [Performance](#performance)
- [Machine performance](#machine-performance)
- [Accuracy of the loss](#accuracy-of-the-loss)
- [LLaMA-65B](#llama-65b)
- [Pre-Training](#pre-training)
- [Datasets](#datasets)
- [Script](#script-1)
- [Performance](#performance-1)
- [Machine performance](#machine-performance-1)
- [Accuracy of the loss](#accuracy-of-the-loss-1)
- [Citation](#citation)
## Training
Here's a hardware summary of pre-training LLaMA-7B/13B:
| Hardware | Value |
| :------: | :---------------------------------------------: |
| NPU | 8 x Ascend NPUs |
Here's a software summary of pre-training LLaMA-7B/13B:
| Software | Version | Link |
| :-----------------------: | :-------------------------------------: | :---:|
| Python | 3.7.16 |-|
| driver | 23.0.RC3.B050 |[link](https://support.huawei.com/enterprise/zh/ascend-computing/ascend-hdk-pid-252764743/software/261159045?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C252764743)|
| firmware | 7.0.t8.0.b214 |[link](https://support.huawei.com/enterprise/zh/ascend-computing/ascend-hdk-pid-252764743/software/261159045?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C252764743)|
| CANN |Ascend-cann-toolkit-7.0.T8-linux |[link](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373/software/261204647?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C251168373)|
| binary arithmetic package | Ascend-cann-kernels-910b_7.0.T8_linux |[link](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373/software/261204647?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C251168373)|
| torch | 1.11.0 |[link](https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/Daily/v1.11.0/20230915.2/pytorch_v1.11.0_py37.tar.gz)|
| torch_npu | 1.11.0.post1-20230922 |[link](https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/ascendspeed/torch_npu-1.11.0.post1_20230922-cp37-cp37m-linux_x86_64.whl)|
### Script
1. Clone the repository to your local server:
```shell
git clone https://gitee.com/ascend/AscendSpeed.git
cd AscendSpeed
mkdir logs
mkdir ckpt
```
2. Build environment
```bash
# python3.7
conda create -n test python=3.7
conda activate test
# install torch and torch_npu
pip install torch-1.11.0-cp37-cp37m-linux_aarch64.whl
pip install torch_npu-1.11.0.post4_XXXXXX-cp37-cp37m-linux_aarch64.whl
pip install apex-0.1_ascend_XXXXXX-cp37-cp37m-linux_aarch64.whl
# install megatron-core
pip3 install -e git+https://github.com/NVIDIA/Megatron-LM.git@23.05#egg=megatron-core
# install deepspeed and deepspeed_npu
pip install deepspeed==0.9.2
git clone https://gitee.com/ascend/DeepSpeed.git -b v0.9.2 deepspeed_npu
cd deepspeed_npu
pip3 install -e ./
cd ..
# install other packages
pip install -r requirements.txt
```
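After installation, an optional sanity check can confirm that the NPU-enabled torch stack imports cleanly (a minimal sketch, assuming the CANN toolkit sits under the default `/usr/local/Ascend` path; `torch.npu.is_available()` is exposed by the `torch_npu` plugin):
```shell
# optional environment check: make sure torch, torch_npu and deepspeed import
# and at least one NPU device is visible
source /usr/local/Ascend/ascend-toolkit/set_env.sh
python -c "import torch, torch_npu, deepspeed; print(torch.__version__, torch.npu.is_available())"
```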
*Note: if you want to train with weights from Hugging Face, first fix a DeepSpeed checkpoint-loading bug by changing `if zero_sd_list is None` to `if zero_sd_list is None or len(zero_sd_list) == 0` in the `_load_zero_checkpoint` function of `<deepspeed-installed-path>/runtime/engine.py`.*
```python
# original deepspeed/runtime/engine.py, around lines 2746-2748
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
if zero_sd_list is None:
return False
# modified
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
if zero_sd_list is None or len(zero_sd_list) == 0:
return False
```
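If you prefer to apply the fix from the command line, the following is a minimal sketch (it assumes DeepSpeed was installed with pip into the currently active environment):
```shell
# locate the installed deepspeed package and apply the checkpoint-loading fix in place
DS_PATH=$(python -c "import deepspeed, os; print(os.path.dirname(deepspeed.__file__))")
sed -i 's/if zero_sd_list is None:/if zero_sd_list is None or len(zero_sd_list) == 0:/' "$DS_PATH/runtime/engine.py"
```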
3. Download the LLaMA-7B/13B tokenizer model and related files from [here](https://huggingface.co/decapoda-research/llama-7b-hf/tree/main)
```shell
#!/bin/bash
mkdir -p dataset/llama
cd ./dataset/llama
# use the resolve/ path so wget fetches the raw files rather than an HTML page
wget https://huggingface.co/decapoda-research/llama-7b-hf/resolve/main/config.json
wget https://huggingface.co/decapoda-research/llama-7b-hf/resolve/main/generation_config.json
wget https://huggingface.co/decapoda-research/llama-7b-hf/resolve/main/special_tokens_map.json
wget https://huggingface.co/decapoda-research/llama-7b-hf/resolve/main/tokenizer.model
wget https://huggingface.co/decapoda-research/llama-7b-hf/resolve/main/tokenizer_config.json
cd ../..
```
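To verify the downloaded tokenizer, an optional sketch that loads the sentencepiece model and encodes a sample sentence (assumes the `sentencepiece` package is installed, as required by the preprocessing step below):
```shell
python -c "
import sentencepiece as spm
# load the tokenizer.model downloaded above and encode a short sample
sp = spm.SentencePieceProcessor(model_file='./dataset/llama/tokenizer.model')
print(sp.encode('hello llama'))
"
```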
4. Prepare dataset. Download the Alpaca dataset from [here](https://huggingface.co/datasets/tatsu-lab/alpaca/resolve/main/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet)
```shell
cd dataset/
wget https://huggingface.co/datasets/tatsu-lab/alpaca/resolve/main/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet
cd ..
```
```shell
#!/bin/bash
source /usr/local/Ascend/ascend-toolkit/set_env.sh
python ./tools/preprocess_data.py \
--input ./dataset/train-00000-of-00001-a09b74b3ef9c3b56.parquet \
--tokenizer-name-or-path ./dataset/llama \
--output-prefix ./dataset/llama \
--workers 4 \
--log-interval 1000 \
--tokenizer-type PretrainedFromHF \
--handler-name AlpacaPretrainHandler \
--tokenizer-not-use-fast \
--append-eod
```
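When preprocessing finishes, the tokenized dataset is written as an indexed binary pair whose prefix matches the `DATA` path used by the training scripts (file names assumed from `--output-prefix` above):
```shell
# quick check that the preprocessed dataset exists
ls -lh ./dataset/llama_text_document.bin ./dataset/llama_text_document.idx
```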
5. Convert weights
Download the LLaMA-7B checkpoint from [here](https://huggingface.co/decapoda-research/llama-7b-hf/tree/main)
```shell
mkdir model_from_hf
cd ./model_from_hf
# git-lfs must be installed; enable it before cloning
git lfs install
# clone into ./llama-7b so the path matches the weight-conversion step below
git clone https://huggingface.co/decapoda-research/llama-7b-hf llama-7b
cd ..
```
Download the LLaMA-13B checkpoint from [here](https://huggingface.co/decapoda-research/llama-13b-hf/tree/main)
```shell
mkdir -p model_from_hf
cd ./model_from_hf
# git-lfs must be installed; enable it before cloning
git lfs install
# clone into ./llama-13b so the path matches the weight-conversion step below
git clone https://huggingface.co/decapoda-research/llama-13b-hf llama-13b
cd ..
```
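Since the checkpoints are stored with Git LFS, it is worth confirming that the weight shards were actually fetched rather than left as small pointer files (a quick, optional check):
```shell
# each cloned directory should be tens of GB; a few hundred KB usually means the LFS
# objects were not pulled; run `git lfs pull` inside the clone in that case
du -sh model_from_hf/llama-7b model_from_hf/llama-13b
```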
To adapt the checkpoints to the LLaMA-7B/13B model structure used here, convert the pre-training weights with the following scripts.
LLaMA-7B
```shell
mkdir model_weights
SCRIPT_PATH=./tools/ckpt_convert/llama/convert_weights_from_huggingface.py
python $SCRIPT_PATH \
--input-model-dir ./model_from_hf/llama-7b/ \
--output-model-dir ./model_weights/llama-7b \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--type 7B \
--deepspeed
```
LLaMA-13B
```shell
mkdir model_weights
SCRIPT_PATH=./tools/ckpt_convert/llama/convert_weights_from_huggingface.py
python $SCRIPT_PATH \
--input-model-dir ./model_from_hf/llama-13b/ \
--output-model-dir ./model_weights/llama-13b \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 8 \
--type 13B
```
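Note that the tensor/pipeline parallel sizes used for conversion must match the `--tensor-model-parallel-size` and `--pipeline-model-parallel-size` settings of the corresponding launch script (for the 13B PTD script below: 1 and 8). After conversion, the partitioned weights land in the output directories given above:
```shell
# quick check of the converted weight directories
ls ./model_weights/llama-7b ./model_weights/llama-13b
```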
6. Configure the LLaMA-7B/13B pre-training script.
```shell
# modify the script according to your own ascend-toolkit path
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# modify the dataset path in the script according to your own dataset path
TOKENIZER_PATH=./dataset/llama #tokenizer path
DATA=./dataset/llama_text_document #processed dataset
CHECKPOINT=./model_weights/
```
*Note that if you do not load weights for pre-training, remove the `--load` parameter from the training script*
7. Launch the LLaMA-7B/13B pre-training script.
LLaMA-7B
```shell
bash examples/llama/pretrain_llama_7B_zero_8p.sh
```
LLaMA-13B
```shell
bash examples/llama/pretrain_llama_13B_ptd_8p.sh
```
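Training progress can be followed from the log written by the launch script (the 13B script pipes its output through `tee logs/train_13B.log`; the 7B log name is assumed to follow the same pattern):
```shell
# watch iteration/loss lines as they are appended to the 13B training log
tail -f logs/train_13B.log | grep --line-buffered -E "iteration|loss"
```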
### Performance
#### Machine performance
The performance of LLaMA-7B/13B on **Ascend NPU** and the **Reference** system:
| Device | Model | Total iterations | Throughput (samples/s/p) | Throughput (tokens/s/p) | Single-step time (s/step) | Floating-point operations (TFLOPs/s) |
| ------ |--------------| ---------------- |-------------------------------|------------------------------|---------------------------|-------------------------------------|
| NPUs | LLaMA-7B | 2048 | 1.398 | 2862 | 5.725 | 162.2 |
| Reference | LLaMA-7B | 2048 | 1.395 | 2859 | 5.73 | 161.8 |
| NPUs | LLaMA-13B | 2048 | 0.879 | 1800 | 18.20 | 146.1 |
| Reference | LLaMA-13B | 2048 | 0.847 | 1734 | 18.89 | 141.0 |
#### Accuracy of the loss
LLaMA-7B with Hugging Face weights: NPU vs. GPU loss.
![NPU-Loss-with-weight-and-Relative-Error](../../sources/images/llama/llama7b-loss-with-weight.png)
LLaMA-13B with Hugging Face weights: NPU vs. GPU loss.
![NPU-Loss-with-weight-and-Relative-Error](../../sources/images/llama/llama13b-loss-with-weight.png)
# LLaMA-65B
This directory contains some of the scripts that were used to produce the results in AscendSpeed. These scripts show an example of how to run LLaMA-65B from the terminal.
@ -6,22 +223,6 @@ LLaMA model is from: [LLaMA: OPen and Efficient Foundation Language Models](http
>Touvron, Hugo, et al. "LLaMA: Open and Efficient Foundation Language Models." arXiv preprint arXiv:2302.13971 (2023).
# Contents
- [Contents](#contents)
- [Pre-Training](#pre-training)
- [Datasets](#datasets)
- [Script](#script)
- [Performance](#performance)
- [Machine performance](#machine-performance)
- [Accuracy of the loss](#accuracy-of-the-loss)
- [Citation](#citation)
## Pre-Training
LLaMA's performance is better than GPT-3 with fewer parameters. The 65B LLaMA model is comparable to DeepMind's Chinchilla-70B and Google's PaLM-540B.

View File

@ -0,0 +1,53 @@
# This is an example: train LLaMA-13B using PTD parallelism.
export LD_LIBRARY_PATH=/usr/local/lib:/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
export HCCL_CONNECT_TIMEOUT=1200
export COMBINED_ENABLE=1
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=8
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
DATA_PATH=./dataset/llama_text_document
CHECKPOINT=./model_weights/llama-13b
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
# Main script
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_llama.py \
--DDP-impl local \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 8 \
--num-layers 40 \
--hidden-size 5120 \
--ffn-hidden-size 13824 \
--num-attention-heads 40 \
--micro-batch-size 1 \
--global-batch-size 128 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--train-iters 1000 \
--lr-decay-iters 640 \
--load $CHECKPOINT \
--data-path $DATA_PATH \
--tokenizer-name-or-path ./dataset/llama/ \
--tokenizer-not-use-fast \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 1.0e-6 \
--lr-decay-style cosine \
--min-lr 1.0e-7 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction .01 \
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--initial-loss-scale 4096.0 \
--checkpoint-activations \
--fp16 | tee logs/train_13B.log

Binary file not shown (new image, 41 KiB).

Binary file not shown (new image, 42 KiB).