mirror of
https://gitee.com/ascend/ModelLink.git
synced 2024-12-05 05:17:40 +08:00
提交llama13b脚本与readme文件
This commit is contained in:
parent
45434eead6
commit
76609dec8c
@ -52,6 +52,8 @@ conda activate test
|
||||
pip install torch-1.11.0-cp37-cp37m-linux_aarch64.whl
|
||||
pip install torch_npu-1.11.0.post4_XXXXXX-cp37-cp37m-linux_aarch64.whl
|
||||
pip install apex-0.1_ascend_XXXXXX-cp37-cp37m-linux_aarch64.whl
|
||||
# install megatron-core
|
||||
pip3 install -e git+https://github.com/NVIDIA/Megatron-LM.git@23.05#egg=megatron-core
|
||||
# install deepspeed and deepspeed_npu
|
||||
pip install deepspeed==0.9.2
|
||||
git clone https://gitee.com/ascend/DeepSpeed.git -b v0.9.2 deepspeed_npu
|
||||
@ -61,6 +63,19 @@ cd ..
|
||||
# install other packages
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
*Note that if you want to train with the weight from huggingface, please run fix a deepspeed loading checkpointing bug by modified `if zero_sd_list is None` as `if zero_sd_list is None or len(zero_sd_list) == 0` in the `_load_zero_checkpoint` function of `<deepspeed-installed-path>/runtime/engine.py`*
|
||||
|
||||
```python
|
||||
# original deepspeed/runtime/engine.py, about #Lines2746-2748
|
||||
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
|
||||
if zero_sd_list is None:
|
||||
return False
|
||||
|
||||
# modified
|
||||
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
|
||||
if zero_sd_list is None or len(zero_sd_list) == 0:
|
||||
return False
|
||||
```
|
||||
3. Download the Internlm-7B tokenizer model and file from [here](https://huggingface.co/internlm/internlm-7b/tree/main)
|
||||
|
||||
```shell
|
||||
|
@ -1,3 +1,220 @@
|
||||
# LLaMA-7B/13B
|
||||
|
||||
- [LLaMA-7B/13B](#llama-7b13b)
|
||||
- [Training](#training)
|
||||
- [Script](#script)
|
||||
- [Performance](#performance)
|
||||
- [Machine performance](#machine-performance)
|
||||
- [Accuracy of the loss](#accuracy-of-the-loss)
|
||||
- [LLaMA-65B](#llama-65b)
|
||||
- [Pre-Training](#pre-training)
|
||||
- [Datasets](#datasets)
|
||||
- [Script](#script-1)
|
||||
- [Performance](#performance-1)
|
||||
- [Machine performance](#machine-performance-1)
|
||||
- [Accuracy of the loss](#accuracy-of-the-loss-1)
|
||||
- [Citation](#citation)
|
||||
|
||||
## Training
|
||||
|
||||
Here's a hardware summary of pre-training LLaMA-7B/13B:
|
||||
|
||||
| Hardware | Value |
|
||||
| :------: | :---------------------------------------------: |
|
||||
| NPU | 8 x Ascend NPUs |
|
||||
|
||||
Here's a software summary of pre-training LLaMA-7B/13B:
|
||||
|
||||
|
||||
| Software | Version |link |
|
||||
| :-----------------------: | :-------------------------------------: | :---:|
|
||||
| Python | 3.7.16 |-|
|
||||
| driver | 23.0.RC3.B050 |[link](https://support.huawei.com/enterprise/zh/ascend-computing/ascend-hdk-pid-252764743/software/261159045?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C252764743)|
|
||||
| firmware | 7.0.t8.0.b214 |[link](https://support.huawei.com/enterprise/zh/ascend-computing/ascend-hdk-pid-252764743/software/261159045?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C252764743)|
|
||||
| CANN |Ascend-cann-toolkit-7.0.T8-linux |[link](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373/software/261204647?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C251168373)|
|
||||
| binary arithmetic package | Ascend-cann-kernels-910b_7.0.T8_linux |[link](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373/software/261204647?idAbsPath=fixnode01%7C23710424%7C251366513%7C22892968%7C251168373)|
|
||||
| torch | 1.11.0 |[link](https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/Daily/v1.11.0/20230915.2/pytorch_v1.11.0_py37.tar.gz)|
|
||||
| torch_npu | 1.11.0.post1-20230922 |[link](https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/ascendspeed/torch_npu-1.11.0.post1_20230922-cp37-cp37m-linux_x86_64.whl)|
|
||||
|
||||
|
||||
### Script
|
||||
|
||||
1. Clone the repository to your local server:
|
||||
```shell
|
||||
git clone https://gitee.com/ascend/AscendSpeed.git
|
||||
cd AscendSpeed
|
||||
mkdir logs
|
||||
mkdir ckpt
|
||||
```
|
||||
|
||||
2. Build environment
|
||||
|
||||
```bash
|
||||
# python3.7
|
||||
conda create -n test python=3.7
|
||||
conda activate test
|
||||
# install torch and torch_npu
|
||||
pip install torch-1.11.0-cp37-cp37m-linux_aarch64.whl
|
||||
pip install torch_npu-1.11.0.post4_XXXXXX-cp37-cp37m-linux_aarch64.whl
|
||||
pip install apex-0.1_ascend_XXXXXX-cp37-cp37m-linux_aarch64.whl
|
||||
# install megatron-core
|
||||
pip3 install -e git+https://github.com/NVIDIA/Megatron-LM.git@23.05#egg=megatron-core
|
||||
# install deepspeed and deepspeed_npu
|
||||
pip install deepspeed==0.9.2
|
||||
git clone https://gitee.com/ascend/DeepSpeed.git -b v0.9.2 deepspeed_npu
|
||||
cd deepspeed_npu
|
||||
pip3 install -e ./
|
||||
cd ..
|
||||
# install other packages
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
*Note that if you want to train with the weight from huggingface, please run fix a deepspeed loading checkpointing bug by modified `if zero_sd_list is None` as `if zero_sd_list is None or len(zero_sd_list) == 0` in the `_load_zero_checkpoint` function of `<deepspeed-installed-path>/runtime/engine.py`*
|
||||
|
||||
```python
|
||||
# original deepspeed/runtime/engine.py, about #Lines2746-2748
|
||||
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
|
||||
if zero_sd_list is None:
|
||||
return False
|
||||
|
||||
# modified
|
||||
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
|
||||
if zero_sd_list is None or len(zero_sd_list) == 0:
|
||||
return False
|
||||
```
|
||||
3. Download the LLaMA-7B/13B tokenizer model and file from [here](https://huggingface.co/decapoda-research/llama-7b-hf/tree/main)
|
||||
|
||||
|
||||
```shell
|
||||
#!/bin/bash
|
||||
mkdir -p dataset/llama
|
||||
cd ./dataset/llama
|
||||
wget https://huggingface.co/decapoda-research/llama-7b-hf/tree/main/config.json
|
||||
wget https://huggingface.co/decapoda-research/llama-7b-hf/tree/main/generation_config.json
|
||||
wget https://huggingface.co/decapoda-research/llama-7b-hf/tree/main/special_tokens_map.json
|
||||
wget https://huggingface.co/decapoda-research/llama-7b-hf/tree/main/tokenizer.model
|
||||
wget https://huggingface.co/decapoda-research/llama-7b-hf/tree/main/tokenizer_config.json
|
||||
cd ..
|
||||
```
|
||||
|
||||
|
||||
4. Prepare dataset. Download the Internlm-7B datasets from [here](https://huggingface.co/datasets/tatsu-lab/alpaca/resolve/main/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet)
|
||||
|
||||
```shell
|
||||
cd dataset/
|
||||
wget https://huggingface.co/datasets/tatsu-lab/alpaca/resolve/main/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet
|
||||
cd ..
|
||||
```
|
||||
|
||||
```shell
|
||||
#!/bin/bash
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
python ./tools/preprocess_data.py \
|
||||
--input ./dataset/train-00000-of-00001-a09b74b3ef9c3b56.parquet \
|
||||
--tokenizer-name-or-path ./dataset/llama \
|
||||
--output-prefix ./dataset/llama \
|
||||
--workers 4 \
|
||||
--log-interval 1000 \
|
||||
--tokenizer-type PretrainedFromHF \
|
||||
--handler-name AlpacaPretrainHandler \
|
||||
--tokenizer-not-use-fast \
|
||||
--append-eod
|
||||
```
|
||||
|
||||
5. Weights convert
|
||||
|
||||
Download the LLaMA-7B checkpoint from [here](https://huggingface.co/decapoda-research/llama-7b-hf/tree/main)
|
||||
```shell
|
||||
mkdir model_from_hf
|
||||
cd ./model_from_hf
|
||||
# you must install git-lfs
|
||||
git clone https://huggingface.co/decapoda-research/llama-7b-hf
|
||||
cd ..
|
||||
```
|
||||
|
||||
Download the LLaMA-13B checkpoint from [here](https://huggingface.co/decapoda-research/llama-13b-hf/tree/main)
|
||||
```shell
|
||||
mkdir model_from_hf
|
||||
cd ./model_from_hf
|
||||
# you must install git-lfs
|
||||
git clone https://huggingface.co/decapoda-research/llama-13b-hf
|
||||
cd ..
|
||||
```
|
||||
|
||||
In order to adapt to the LLaMA-7B/13B model, the following script is used to convert the model pre-training weights.
|
||||
|
||||
LLaMA-7B
|
||||
```shell
|
||||
mkdir model_weights
|
||||
SCRIPT_PATH=./tools/ckpt_convert/llama/convert_weights_from_huggingface.py
|
||||
python $SCRIPT_PATH \
|
||||
--input-model-dir ./model_from_hf/llama-7b/ \
|
||||
--output-model-dir ./model_weights/llama-7b \
|
||||
--tensor-model-parallel-size 1 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
--type 7B \
|
||||
--deepspeed
|
||||
```
|
||||
|
||||
LLaMA-13B
|
||||
```shell
|
||||
mkdir model_weights
|
||||
SCRIPT_PATH=./tools/ckpt_convert/llama/convert_weights_from_huggingface.py
|
||||
python $SCRIPT_PATH \
|
||||
--input-model-dir ./model_from_hf/llama-13b/ \
|
||||
--output-model-dir ./model_weights/llama-13b \
|
||||
--tensor-model-parallel-size 1 \
|
||||
--pipeline-model-parallel-size 8 \
|
||||
--type 13B
|
||||
```
|
||||
|
||||
6. Config LLaMA-7B/13B pre-training script.
|
||||
|
||||
```shell
|
||||
# modify the script according to your own ascend-toolkit path
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
# modify script orign dataset path according to your own dataset path
|
||||
TOKENIZER_PATH=./dataset/llama #tokenizer path
|
||||
DATA=./dataset/llama_text_document #processed dataset
|
||||
CHECKPOINT=./model_weights/
|
||||
```
|
||||
*Note that if you do not load weights for pre-training, remove the `--load` parameter from the training script*
|
||||
|
||||
7. Launch LLaMA-7B/13B pre-training script.
|
||||
|
||||
LLaMA-7B
|
||||
```shell
|
||||
bash examples/intern/pretrain_llama_7B_zero_8p.sh
|
||||
```
|
||||
|
||||
LLaMA-13B
|
||||
```shell
|
||||
bash examples/intern/pretrain_llama_13B_ptd_8p.sh
|
||||
```
|
||||
|
||||
### Performance
|
||||
|
||||
#### Machine performance
|
||||
|
||||
The performance of LLaMA-7B/13B in **Ascend NPU** and **Reference**:
|
||||
|
||||
| Device | Model | total Iterations | throughput rate (samples/s/p) | throughput rate (tokens/s/p) | single-step time (s/step) | floating point operation (TFLOPs/s) |
|
||||
| ------ |--------------| ---------------- |-------------------------------|------------------------------|---------------------------|-------------------------------------|
|
||||
| NPUs | LLaMA-7B | 2048 | 1.398 | 2862 | 5.725 | 162.2 |
|
||||
| Reference | LLaMA-7B | 2048 | 1.395 | 2859 | 5.73 | 161.8 |
|
||||
| NPUs | LLaMA-13B | 2048 | 0.879 | 1800 | 18.20 | 146.1 |
|
||||
| Reference | LLaMA-13B | 2048 | 0.847 | 1734 | 18.89 | 141.0 |
|
||||
|
||||
|
||||
|
||||
#### Accuracy of the loss
|
||||
|
||||
LLama-7b with huggingface weights NPU vs GPU loss.
|
||||
![NPU-Loss-with-weight-and-Relative-Error](../../sources/images/llama/llama7b-loss-with-weight.png)
|
||||
LLama-13b with huggingface weights NPU vs GPU loss.
|
||||
![NPU-Loss-with-weight-and-Relative-Error](../../sources/images/llama/llama13b-loss-with-weight.png)
|
||||
|
||||
|
||||
|
||||
# LLaMA-65B
|
||||
|
||||
This directory contains some of the scripts that were used to produce the results in the AscendSpeed. These scripts is to show the example how to run llama-65B in terminal.
|
||||
@ -6,22 +223,6 @@ LLaMA model is from: [LLaMA: OPen and Efficient Foundation Language Models](http
|
||||
|
||||
>Touvron, Hugo, et al. "LLaMA: OPen and Efficient Foundation Language Models." arXiv preprint arXiv:2302.13971 (2023).
|
||||
|
||||
# Contents
|
||||
|
||||
- [Contents](#contents)
|
||||
|
||||
- [Pre-Training](#pre-training)
|
||||
- [Datasets](#datasets)
|
||||
|
||||
- [Script](#script)
|
||||
|
||||
- [Performance](#performance)
|
||||
- [Machine performance](#machine-performance)
|
||||
- [Accuracy of the loss](#accuracy-of-the-loss)
|
||||
|
||||
- [Citation](#citation)
|
||||
|
||||
|
||||
## Pre-Training
|
||||
|
||||
LLaMA's model performace is better than GPT3 with less parameters. The 65B LLaMA model is comparable to Google's Chinchilla-70B and Palm-540B.
|
||||
|
53
examples/llama/pretrain_llama_13B_ptd_8p.sh
Normal file
53
examples/llama/pretrain_llama_13B_ptd_8p.sh
Normal file
@ -0,0 +1,53 @@
|
||||
# This is an example: train llama using PTD.
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
|
||||
export HCCL_CONNECT_TIMEOUT=1200
|
||||
export COMBINED_ENABLE=1
|
||||
|
||||
# Change for multinode config
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
NNODES=1
|
||||
NODE_RANK=0
|
||||
NPUS_PER_NODE=8
|
||||
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
|
||||
|
||||
DATA_PATH=./dataset/llama_text_document
|
||||
CHECKPOINT=./model_weights/llama-13b
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
|
||||
|
||||
# Main script
|
||||
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
|
||||
pretrain_llama.py \
|
||||
--DDP-impl local \
|
||||
--tensor-model-parallel-size 1 \
|
||||
--pipeline-model-parallel-size 8 \
|
||||
--num-layers 40 \
|
||||
--hidden-size 5120 \
|
||||
--ffn-hidden-size 13824 \
|
||||
--num-attention-heads 40 \
|
||||
--micro-batch-size 1 \
|
||||
--global-batch-size 128 \
|
||||
--seq-length 2048 \
|
||||
--max-position-embeddings 2048 \
|
||||
--train-iters 1000 \
|
||||
--lr-decay-iters 640 \
|
||||
--load $CHECKPOINT \
|
||||
--data-path $DATA_PATH \
|
||||
--tokenizer-name-or-path ./dataset/llama/ \
|
||||
--tokenizer-not-use-fast \
|
||||
--data-impl mmap \
|
||||
--split 949,50,1 \
|
||||
--distributed-backend nccl \
|
||||
--lr 1.0e-6 \
|
||||
--lr-decay-style cosine \
|
||||
--min-lr 1.0e-7 \
|
||||
--weight-decay 1e-2 \
|
||||
--clip-grad 1.0 \
|
||||
--lr-warmup-fraction .01 \
|
||||
--log-interval 1 \
|
||||
--save-interval 10000 \
|
||||
--eval-interval 1000 \
|
||||
--eval-iters 10 \
|
||||
--initial-loss-scale 4096.0 \
|
||||
--checkpoint-activations \
|
||||
--fp16 | tee logs/train_13B.log
|
BIN
sources/images/llama/llama13b-loss-with-weight.png
Normal file
BIN
sources/images/llama/llama13b-loss-with-weight.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 41 KiB |
BIN
sources/images/llama/llama7b-loss-with-weight.png
Normal file
BIN
sources/images/llama/llama7b-loss-with-weight.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 42 KiB |
Loading…
Reference in New Issue
Block a user