diff --git a/README.md b/README.md
index bb3813faa..d456a64c5 100644
--- a/README.md
+++ b/README.md
@@ -390,6 +390,17 @@ ModelLink旨在为华为 [昇腾芯片](https://open.codehub.huawei.com/OpenBaiz
 eval
 【昇腾贡献】
+
+GPT3
+175B
+pretrain
+--
+--
+--
+--
+--
+【社区贡献】
+
@@ -679,6 +690,14 @@ ModelLink旨在为华为 [昇腾芯片](https://open.codehub.huawei.com/OpenBaiz
 2938
 2607
+
+GPT3
+175B
+16x8
+FP16
+153
+--
+
diff --git a/README_en.md b/README_en.md
index be0a5cce6..e6baf5026 100644
--- a/README_en.md
+++ b/README_en.md
@@ -389,6 +389,17 @@ Current ModelLink supports pre-training and fine-tuning for the following models
 eval
 【Ascend】
+
+GPT3
+175B
+pretrain
+--
+--
+--
+--
+--
+【Community】
+
@@ -678,6 +689,14 @@ For the supported models listed above, we provide training scripts and readme in
 2938
 2607
+
+GPT3
+175B
+16x8
+FP16
+153
+--
+
diff --git a/examples/bloom/pretrain_bloom_ptd_7B.sh b/examples/bloom/pretrain_bloom_ptd_7B.sh
index c4e3850f7..70516cec2 100644
--- a/examples/bloom/pretrain_bloom_ptd_7B.sh
+++ b/examples/bloom/pretrain_bloom_ptd_7B.sh
@@ -46,9 +46,10 @@ GPT_ARGS="
     --attention-softmax-in-fp32 \
     --apply-query-key-layer-scaling \
     --lr 1.2e-4 \
-    --train-iters 200 \
+    --train-iters 2000 \
     --init-method-std 0.0048 \
     --hidden-dropout 0.0 \
+    --attention-dropout 0.0 \
     --position-embedding-type alibi \
     --normalization LayerNorm \
     --min-lr 6e-6 \
diff --git a/examples/gpt3/pretrain_gpt3_15B_ptd.sh b/examples/gpt3/pretrain_gpt3_15B_ptd.sh
new file mode 100644
index 000000000..6d436adf6
--- /dev/null
+++ b/examples/gpt3/pretrain_gpt3_15B_ptd.sh
@@ -0,0 +1,87 @@
+#!/bin/bash
+
+
+# Runs the "175B" parameter model with a reduced number of layers (8) on a single machine
+export COMBINED_ENABLE=1
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export AZUREML_EXPERIMENT_ID=0
+
+GPUS_PER_NODE=8
+# Change for multinode config
+MASTER_ADDR=localhost
+MASTER_PORT=6001
+NUM_NODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))
+
+DATA_PATH="your dataset path"
+VOCAB_FILE="vocab file for gpt"
+MERGE_FILE="merge file for gpt"
+
+TP=8
+PP=1
+
+DISTRIBUTED_ARGS="
+    --nproc_per_node $GPUS_PER_NODE \
+    --nnodes $NUM_NODES \
+    --master_addr $MASTER_ADDR \
+    --master_port $MASTER_PORT \
+    --node_rank $NODE_RANK
+"
+
+GPT_ARGS="
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size ${PP} \
+    --sequence-parallel \
+    --num-layers 8 \
+    --hidden-size 12288 \
+    --num-attention-heads 96 \
+    --seq-length 2048 \
+    --max-position-embeddings 2048 \
+    --transformer-impl local \
+    --micro-batch-size 1 \
+    --global-batch-size 64 \
+    --train-iters 2000 \
+    --weight-decay 0.1 \
+    --adam-beta1 0.9 \
+    --adam-beta2 0.95 \
+    --initial-loss-scale 4096 \
+    --init-method-std 0.006 \
+    --clip-grad 1.0 \
+    --fp16 \
+    --lr 6.0e-5 \
+    --lr-decay-style cosine \
+    --min-lr 6.0e-6 \
+    --lr-warmup-fraction .001 \
+    --lr-decay-iters 430000 \
+    --no-load-optim \
+    --no-load-rng \
+    --no-gradient-accumulation-fusion \
+    --no-masked-softmax-fusion \
+    --attention-softmax-in-fp32 \
+    --attention-dropout 0.0 \
+    --hidden-dropout 0.0 \
+    --use-flash-attn \
+    --no-bias-gelu-fusion \
+    --use-mc2
+"
+
+DATA_ARGS="
+    --data-path $DATA_PATH
+    --vocab-file $VOCAB_FILE
+    --merge-file $MERGE_FILE
+    --split 949,50,1
+"
+
+OUTPUT_ARGS="
+    --log-interval 1
+    --eval-interval 5000
+    --eval-iters 1
+"
+
+torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
+    $GPT_ARGS \
+    $DATA_ARGS \
+    $OUTPUT_ARGS \
+    --distributed-backend nccl 2>&1 \
+    | tee ./logs/pretrain_gpt3_175B_8layers.log
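Illustrative sketch (not part of the patch): the distributed layout of the two new GPT3 scripts is implied by a handful of flags. The plain-Python arithmetic below uses only values copied from the single-machine script above and the 16-node script that follows, and shows how world size, data-parallel size and gradient-accumulation steps fall out of them.

```python
# Sanity-check arithmetic for the GPT3 script flags (values copied from the scripts).

def summarize(gpus_per_node, num_nodes, tp, pp, micro_batch, global_batch):
    world_size = gpus_per_node * num_nodes              # total NPUs launched by torchrun
    dp = world_size // (tp * pp)                        # data-parallel size
    grad_accum = global_batch // (micro_batch * dp)     # micro-steps per training iteration
    return world_size, dp, grad_accum

# single-machine script above: 1 node x 8 NPUs, TP=8, PP=1, micro 1, global 64
print(summarize(8, 1, 8, 1, 1, 64))      # -> (8, 1, 64)

# 16-node script below: 16 nodes x 8 NPUs, TP=8, PP=16, micro 2, global 1024
print(summarize(8, 16, 8, 16, 2, 1024))  # -> (128, 1, 512)

# In the 175B script, 96 layers / PP=16 gives 6 layers per pipeline stage; with
# --num-layers-per-virtual-pipeline-stage 2, Megatron's interleaved schedule places
# 3 model chunks on each pipeline rank.
```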
diff --git a/examples/gpt3/pretrain_gpt3_175B_ptd.sh b/examples/gpt3/pretrain_gpt3_175B_ptd.sh
new file mode 100644
index 000000000..6b5576c11
--- /dev/null
+++ b/examples/gpt3/pretrain_gpt3_175B_ptd.sh
@@ -0,0 +1,88 @@
+#!/bin/bash
+
+# Runs the "175B" parameter model with the full 96 layers.
+export COMBINED_ENABLE=1
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export AZUREML_EXPERIMENT_ID=0
+
+GPUS_PER_NODE=8
+# Change for multinode config
+MASTER_ADDR=localhost
+MASTER_PORT=6001
+NUM_NODES=16
+NODE_RANK=0
+WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))
+
+DATA_PATH="your dataset path"
+VOCAB_FILE="vocab file for gpt training"
+MERGE_FILE="merge file for gpt training"
+
+TP=8
+PP=16
+
+DISTRIBUTED_ARGS="
+    --nproc_per_node $GPUS_PER_NODE \
+    --nnodes $NUM_NODES \
+    --master_addr $MASTER_ADDR \
+    --master_port $MASTER_PORT \
+    --node_rank $NODE_RANK
+"
+
+GPT_ARGS="
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size ${PP} \
+    --num-layers-per-virtual-pipeline-stage 2 \
+    --sequence-parallel \
+    --num-layers 96 \
+    --hidden-size 12288 \
+    --num-attention-heads 96 \
+    --seq-length 2048 \
+    --max-position-embeddings 2048 \
+    --transformer-impl local \
+    --micro-batch-size 2 \
+    --global-batch-size 1024 \
+    --train-iters 2000 \
+    --weight-decay 0.1 \
+    --adam-beta1 0.9 \
+    --adam-beta2 0.95 \
+    --initial-loss-scale 4096 \
+    --init-method-std 0.006 \
+    --clip-grad 1.0 \
+    --fp16 \
+    --lr 6.0e-5 \
+    --lr-decay-style cosine \
+    --min-lr 6.0e-6 \
+    --lr-warmup-fraction .001 \
+    --lr-decay-iters 430000 \
+    --no-load-optim \
+    --no-load-rng \
+    --no-gradient-accumulation-fusion \
+    --no-masked-softmax-fusion \
+    --no-bias-gelu-fusion \
+    --attention-softmax-in-fp32 \
+    --attention-dropout 0.0 \
+    --hidden-dropout 0.0 \
+    --use-flash-attn \
+    --overlap-grad-reduce \
+    --use-mc2
+"
+
+DATA_ARGS="
+    --data-path $DATA_PATH
+    --vocab-file $VOCAB_FILE
+    --merge-file $MERGE_FILE
+    --split 949,50,1
+"
+
+OUTPUT_ARGS="
+    --log-interval 1
+    --eval-interval 5000
+    --eval-iters 1
+"
+
+torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
+    $GPT_ARGS \
+    $DATA_ARGS \
+    $OUTPUT_ARGS \
+    --distributed-backend nccl 2>&1 \
+    | tee ./logs/pretrain_gpt3_175B.log
diff --git a/examples/gpt3/readme.md b/examples/gpt3/readme.md
new file mode 100644
index 000000000..bfdfd3552
--- /dev/null
+++ b/examples/gpt3/readme.md
@@ -0,0 +1,136 @@
+# GPT3 $\color{black}{\bf\tiny{【社区贡献模型】}}$
+

+    简体中文 |
+    English

+
+# 目录
+
+- [GPT3](#GPT3)
+- [目录](#目录)
+- [GPT3-175B](#GPT3-175B)
+  - [训练-175B](#训练)
+    - [脚本](#脚本)
+    - [性能](#性能)
+      - [吞吐](#吞吐)
+
+# GPT3-175B
+
+## 训练
+
+GPT3-175B 训练的硬件配置:
+
+| 硬件 |       配置        |
+| :--: | :---------------: |
+| NPU  | 128 x Ascend NPUs |
+
+### 脚本
+
+1. 克隆仓库到本地服务器:
+
+    ```shell
+    git clone https://gitee.com/ascend/ModelLink.git
+    git clone https://github.com/NVIDIA/Megatron-LM.git
+    cd Megatron-LM
+    git checkout -f bcce6f
+    cp -r megatron ../ModelLink/
+    cd ..
+    cd ModelLink
+    mkdir logs
+    mkdir vocab_file
+    mkdir dataset
+    ```
+
+2. 搭建环境
+
+    ```bash
+    # python3.8
+    conda create -n test python=3.8
+    conda activate test
+
+    # 安装 torch 和 torch_npu
+    pip install torch-2.1.0-cp38-cp38m-manylinux2014_aarch64.whl
+    pip install torch_npu-2.1.0*-cp38-cp38m-linux_aarch64.whl
+    pip install apex-0.1_ascend*-cp38-cp38m-linux_aarch64.whl
+
+    # 修改 ascend-toolkit 路径
+    source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+    # 安装 AscendSpeed
+    git clone https://gitee.com/ascend/AscendSpeed.git
+    cd AscendSpeed
+    git checkout 224ae35e8fc96778f957029d1371ddb623452a50
+    pip install -r requirements.txt
+    pip3 install -e .
+    cd ..
+
+    # 安装其他依赖
+    pip install -r requirements.txt
+    ```
+
+3. 准备数据、词表来拉起模型
+
+    3.1 准备数据
+
+    可以从 [这里](https://huggingface.co/datasets/wikipedia/tree/main/data/20220301.en) 下载原始数据
+
+    ```shell
+    # 下载 enwiki 数据
+    # 总共有 41 个文件,我们可以选择部分来制作数据
+    cd ./dataset
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00000-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00001-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00002-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00003-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00004-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00005-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00006-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00007-of-00041.parquet
+    cd ..
+
+    # 下载 vocab file 和 merge table
+    cd vocab_file
+    wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json
+    wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt
+    cd ..
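+
+    # （可选示意，非本补丁内容）如需更多 enwiki 分片，可在执行下面的预处理命令之前
+    # 用循环批量下载，例如 8~15 号分片：
+    # for i in $(seq -f "%05g" 8 15); do
+    #     wget -P ./dataset https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-${i}-of-00041.parquet
+    # done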
+
+    # 处理成训练数据
+    python ./tools/preprocess_data.py \
+        --input ./dataset/ \
+        --output-prefix ./dataset/gpt_text_sentence \
+        --tokenizer-type GPT2BPETokenizer \
+        --vocab-file ./vocab_file/gpt2-vocab.json \
+        --merge-file ./vocab_file/gpt2-merges.txt \
+        --append-eod \
+        --workers 4 \
+        --log-interval 1000
+    ```
+
+    3.2 用 ptd 模式进行预训练
+
+    配置 GPT3-175B PTD 预训练脚本: examples/gpt3/pretrain_gpt3_175B_ptd.sh
+
+    ```shell
+    # 请根据真实情况配置 ascend-toolkit 路径
+    source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+    # 请根据真实存放路径配置以下参数
+    VOCAB_FILE="./vocab_file/gpt2-vocab.json"   # 词表
+    MERGE_FILE="./vocab_file/gpt2-merges.txt"   # BPE 合并表
+    DATA_PATH="./dataset/gpt_text_sentence"     # 数据路径
+    ```
+
+    拉起 GPT3-175B PTD 预训练脚本: examples/gpt3/pretrain_gpt3_175B_ptd.sh
+
+    ```shell
+    bash examples/gpt3/pretrain_gpt3_175B_ptd.sh
+    ```
+
+### 性能
+
+#### 吞吐
+
+GPT3-175B 在 **昇腾芯片** 上的性能数据:
+
+| 设备 |   模型    | tokens吞吐 (tokens/s/p) |
+| :--: | :-------: | :---------------------: |
+| NPUs | GPT3-175B |          153.1          |
+
diff --git a/examples/gpt3/readme_en.md b/examples/gpt3/readme_en.md
new file mode 100644
index 000000000..39dda4ae4
--- /dev/null
+++ b/examples/gpt3/readme_en.md
@@ -0,0 +1,136 @@
+# GPT3 $\color{black}{\rm\tiny{【model}}$ $\color{black}{\rm\tiny{contributed}}$ $\color{black}{\rm\tiny{by}}$ $\color{black}{\rm\tiny{Community】}}$
+

+    简体中文 |
+    English

+
+# Contents
+
+- [GPT3](#GPT3)
+- [Contents](#contents)
+- [GPT3-175B](#GPT3-175B)
+  - [Training-175B](#training)
+    - [Script](#script)
+    - [Performance](#performance)
+      - [Machine performance](#machine-performance)
+
+# GPT3-175B
+
+## Training
+
+Here is a hardware summary of pre-training GPT3-175B:
+
+| Hardware |       Value       |
+| :------: | :---------------: |
+|   NPU    | 128 x Ascend NPUs |
+
+### Script
+
+1. Clone the repository to your local server:
+
+    ```shell
+    git clone https://gitee.com/ascend/ModelLink.git
+    git clone https://github.com/NVIDIA/Megatron-LM.git
+    cd Megatron-LM
+    git checkout -f bcce6f
+    cp -r megatron ../ModelLink/
+    cd ..
+    cd ModelLink
+    mkdir logs
+    mkdir vocab_file
+    mkdir dataset
+    ```
+
+2. Build environment
+
+    ```bash
+    # python3.8
+    conda create -n test python=3.8
+    conda activate test
+
+    # install torch and torch_npu
+    pip install torch-2.1.0-cp38-cp38m-manylinux2014_aarch64.whl
+    pip install torch_npu-2.1.0*-cp38-cp38m-linux_aarch64.whl
+    pip install apex-0.1_ascend*-cp38-cp38m-linux_aarch64.whl
+
+    # modify the ascend-toolkit path
+    source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+    # install AscendSpeed
+    git clone https://gitee.com/ascend/AscendSpeed.git
+    cd AscendSpeed
+    git checkout 224ae35e8fc96778f957029d1371ddb623452a50
+    pip install -r requirements.txt
+    pip3 install -e .
+    cd ..
+
+    # install other packages
+    pip install -r requirements.txt
+    ```
+
+3. Prepare the dataset and vocab files for pretraining
+
+    3.1 Prepare the dataset
+
+    Download the raw GPT dataset from [here](https://huggingface.co/datasets/wikipedia/tree/main/data/20220301.en)
+
+    ```shell
+    # download enwiki raw data
+    # There are 41 files in total; selecting a subset is enough to build the dataset.
+    cd ./dataset
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00000-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00001-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00002-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00003-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00004-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00005-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00006-of-00041.parquet
+    wget https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00007-of-00041.parquet
+    cd ..
+
+    # download the vocab file and merge table
+    cd vocab_file
+    wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json
+    wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt
+    cd ..
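+
+    # (optional, illustrative only, not part of this patch) additional enwiki shards
+    # can be fetched in a loop before running the preprocessing below, e.g. shards 8-15:
+    # for i in $(seq -f "%05g" 8 15); do
+    #     wget -P ./dataset https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-${i}-of-00041.parquet
+    # done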
+
+    # preprocess the downloaded data into the training format
+    python ./tools/preprocess_data.py \
+        --input ./dataset/ \
+        --output-prefix ./dataset/gpt_text_sentence \
+        --tokenizer-type GPT2BPETokenizer \
+        --vocab-file ./vocab_file/gpt2-vocab.json \
+        --merge-file ./vocab_file/gpt2-merges.txt \
+        --append-eod \
+        --workers 4 \
+        --log-interval 1000
+    ```
+
+    3.2 Pre-training in PTD mode
+
+    Configure the GPT3-175B PTD pre-training script: examples/gpt3/pretrain_gpt3_175B_ptd.sh
+
+    ```shell
+    # modify the ascend-toolkit path according to your own installation
+    source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+    # modify the paths below according to your actual setup
+    VOCAB_FILE="./vocab_file/gpt2-vocab.json"   # vocab file for training
+    MERGE_FILE="./vocab_file/gpt2-merges.txt"   # BPE merge file for training
+    DATA_PATH="./dataset/gpt_text_sentence"     # dataset path
+    ```
+
+    Launch the GPT3-175B PTD pre-training script: examples/gpt3/pretrain_gpt3_175B_ptd.sh
+
+    ```shell
+    bash examples/gpt3/pretrain_gpt3_175B_ptd.sh
+    ```
+
+### Performance
+
+#### Machine performance
+
+The performance of GPT3-175B on **Ascend NPUs**:
+
+| device |   model   | token throughput (tokens/s/p) |
+| :----: | :-------: | :---------------------------: |
+|  NPUs  | GPT3-175B |             153.1             |
+
diff --git a/modellink/model/transformer.py b/modellink/model/transformer.py
index 56f372d21..bd1ef11fe 100644
--- a/modellink/model/transformer.py
+++ b/modellink/model/transformer.py
@@ -420,7 +420,7 @@ class FlashSelfAttention(torch.nn.Module):
         """Implements the multihead softmax attention.
         Arguments
         ---------
-            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+            q, k, v: The tensor containing the query, key, and value. (S, B, H, D)
         """
         args = get_args()
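To make the corrected docstring above concrete, here is a minimal shape-only sketch (plain PyTorch, not part of the patch): FlashSelfAttention's q, k and v are documented as (S, B, H, D), i.e. sequence-first. With the 175B settings above (seq-length 2048, micro-batch 2, 96 heads, hidden 12288) the head dimension is 12288 / 96 = 128.

```python
import torch

# Shapes only, matching the corrected (S, B, H, D) convention documented above:
# (sequence length, micro-batch, attention heads, head dimension).
S, B, H, D = 2048, 2, 96, 128          # 128 = hidden size 12288 / 96 heads
q = torch.empty(S, B, H, D, dtype=torch.float16)
k = torch.empty(S, B, H, D, dtype=torch.float16)
v = torch.empty(S, B, H, D, dtype=torch.float16)
print(q.shape)                         # torch.Size([2048, 2, 96, 128])
```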