!1460 Support chatglm3-6B with the mcore r0.6.0 structure

Merge pull request !1460 from WangYu/chatglm3
WangYu 2024-08-01 10:52:21 +00:00 committed by i-robot
parent bc0b4076b1
commit 34d2bbb412
5 changed files with 380 additions and 1 deletion


@@ -202,7 +202,7 @@ ModelLink已发布版本维护策略
<td>【Pass】</td>
</tr>
<tr>
<td rowspan="1"><a href="https://huggingface.co/THUDM">ChatGLM3</a></td>
<td rowspan="4"><a href="https://huggingface.co/THUDM">ChatGLM3</a></td>
<td><a href="https://huggingface.co/THUDM/chatglm3-6b/tree/main">6B</a></td>
<td> 8K </td>
<th>Legacy</th>
@@ -212,6 +212,38 @@ ModelLink已发布版本维护策略
<td> 4267 </td>
<td><center>【昇腾】</td>
<td>【Test】</td>
<tr>
<td>6B</td>
<td>8K</td>
<th>Mcore</th>
<td>1x8</td>
<td> FP16 </td>
<td> 4611 </td>
<td> 4543 </td>
<td>【昇腾】</td>
<td>【Test】</td>
</tr>
<tr>
<td>6B</td>
<td>32K</td>
<th>Mcore</th>
<td>1x8</td>
<td> FP16 </td>
<td> 2650 </td>
<td> 2887 </td>
<td>【昇腾】</td>
<td>【Test】</td>
</tr>
<tr>
<td>6B</td>
<td>64K</td>
<th>Mcore</th>
<td>2x8</td>
<td> FP16 </td>
<td> 1724 </td>
<td> 2097 </td>
<td>【昇腾】</td>
<td>【Test】</td>
</tr>
<tr>
<td rowspan="1"><a href="https://huggingface.co/codellama">CodeLlama</a></td>


@@ -0,0 +1,111 @@
#!/bin/bash
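# Pretrain ChatGLM3-6B (Mcore) with a 32K sequence length on a single node (1x8 NPUs).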
export CUDA_DEVICE_MAX_CONNECTIONS=1
NPUS_PER_NODE=8
MASTER_ADDR=localhost  # single-node job (NNODES=1), so the master address is the local host
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
WORLD_SIZE=$((NPUS_PER_NODE*$NNODES))
CKPT_SAVE_DIR="your model save ckpt path"
DATA_PATH="your data path"
TOKENIZER_PATH="your tokenizer path"
CKPT_LOAD_DIR="your model ckpt path"
TP=1
PP=1
CP=8
MBS=1
GBS=32
SEQ_LEN=32768
CP_ALGO=ulysses_cp_algo
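# No tensor or pipeline parallelism (TP=PP=1); the 32K sequence is split across all
# 8 devices with Ulysses context parallelism (CP=8, ulysses_cp_algo).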
DISTRIBUTED_ARGS="
--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
"
GPT_ARGS="
--use-mcore-models \
--transformer-impl local \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--sequence-parallel \
--num-layers 28 \
--hidden-size 4096 \
--ffn-hidden-size 13696 \
--num-attention-heads 32 \
--seq-length ${SEQ_LEN} \
--micro-batch-size ${MBS} \
--global-batch-size ${GBS} \
--context-parallel-algo ${CP_ALGO} \
--context-parallel-size ${CP} \
--max-position-embeddings ${SEQ_LEN} \
--padded-vocab-size 65024 \
--make-vocab-size-divisible-by 1 \
--group-query-attention \
--num-query-groups 2 \
--disable-bias-linear \
--add-qkv-bias \
--position-embedding-type rope \
--no-rope-fusion \
--use-distributed-optimizer \
--use-partial-rope \
--use-flash-attn \
--use-fused-rmsnorm \
--use-fused-swiglu \
--normalization RMSNorm \
--swiglu \
--no-create-attention-mask-in-dataloader \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_PATH} \
--lr 1e-6 \
--train-iters 2000 \
--lr-decay-style cosine \
--untie-embeddings-and-output-weights \
--attention-dropout 0.0 \
--init-method-std 0.01 \
--hidden-dropout 0.0 \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--min-lr 1e-8 \
--weight-decay 1e-1 \
--lr-warmup-fraction 0.01 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--initial-loss-scale 4096 \
--adam-beta2 0.95 \
--no-gradient-accumulation-fusion \
--no-load-optim \
--no-load-rng \
--fp16 \
--kv-head-repeat-before-uly-alltoall \
--use-cp-send-recv-overlap \
--overlap-grad-reduce \
--overlap-param-gather \
"
DATA_ARGS="
--data-path $DATA_PATH \
--split 949,50,1 \
"
OUTPUT_ARGS="
--log-interval 1 \
--save-interval 1000 \
--eval-interval 1000 \
--eval-iters 10 \
"
python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
--save $CKPT_SAVE_DIR \
| tee logs/train_mcore_chatglm3_6B_32K.log


@@ -0,0 +1,130 @@
#!/bin/bash
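# Pretrain ChatGLM3-6B (Mcore) with a 64K sequence length on two nodes (2x8 NPUs).
# Fill IPs with both node addresses; each node derives its own NODE_RANK from its local IP.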
export CUDA_DEVICE_MAX_CONNECTIONS=1
NPUS_PER_NODE=8
IPs=('IP1' 'IP2')
LOCAL_HOST=$(hostname -I | awk '{print $1}')
echo $LOCAL_HOST
MASTER_ADDR=${IPs[0]}
MASTER_PORT=6010
NNODES=${#IPs[@]}
NODE_RANK=""
for i in "${!IPs[@]}";
do
if [ "$LOCAL_HOST" == "${IPs[$i]}" ];
then
echo "Node Rank : ${i}"
NODE_RANK=$i
break
fi
done
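# NODE_RANK stays empty if the local IP is not found in IPs; in that case check the IPs
# array, otherwise torch.distributed.launch is started with an empty --node_rank.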
WORLD_SIZE=$((NPUS_PER_NODE*$NNODES))
CKPT_SAVE_DIR="your model save ckpt path"
DATA_PATH="your data path"
TOKENIZER_PATH="your tokenizer path"
CKPT_LOAD_DIR="your model ckpt path"
TP=1
PP=1
CP=16
MBS=1
GBS=64
SEQ_LEN=65536
CP_ALGO=hybrid_cp_algo
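# Hybrid context parallelism across 16 devices: the Ulysses group size is 8
# (--ulysses-degree-in-cp below), and ring attention covers the remaining factor of 2.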
DISTRIBUTED_ARGS="
--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
GPT_ARGS="
--use-mcore-models \
--transformer-impl local \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--sequence-parallel \
--num-layers 28 \
--hidden-size 4096 \
--ffn-hidden-size 13696 \
--num-attention-heads 32 \
--ulysses-degree-in-cp 8 \
--seq-length ${SEQ_LEN} \
--micro-batch-size ${MBS} \
--global-batch-size ${GBS} \
--context-parallel-algo ${CP_ALGO} \
--context-parallel-size ${CP} \
--max-position-embeddings ${SEQ_LEN} \
--padded-vocab-size 65024 \
--make-vocab-size-divisible-by 1 \
--group-query-attention \
--num-query-groups 2 \
--disable-bias-linear \
--add-qkv-bias \
--position-embedding-type rope \
--no-rope-fusion \
--use-distributed-optimizer \
--use-partial-rope \
--use-flash-attn \
--use-fused-rmsnorm \
--use-fused-swiglu \
--normalization RMSNorm \
--swiglu \
--no-create-attention-mask-in-dataloader \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_PATH} \
--lr 1e-6 \
--train-iters 2000 \
--lr-decay-style cosine \
--untie-embeddings-and-output-weights \
--attention-dropout 0.0 \
--init-method-std 0.01 \
--hidden-dropout 0.0 \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--min-lr 1e-8 \
--weight-decay 1e-1 \
--lr-warmup-fraction 0.01 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--initial-loss-scale 4096 \
--adam-beta2 0.95 \
--no-gradient-accumulation-fusion \
--no-load-optim \
--no-load-rng \
--fp16 \
--num-workers 1 \
--kv-head-repeat-before-uly-alltoall \
--no-shared-storage \
--use-cp-send-recv-overlap \
--overlap-grad-reduce \
--overlap-param-gather \
"
DATA_ARGS="
--data-path $DATA_PATH \
--split 949,50,1
"
OUTPUT_ARGS="
--log-interval 1 \
--save-interval 1000 \
--eval-interval 1000 \
--eval-iters 10 \
"
python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
--save $CKPT_SAVE_DIR \
| tee logs/train_mcore_chatglm3_6B_64K.log


@@ -0,0 +1,103 @@
#!/bin/bash
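# Pretrain ChatGLM3-6B (Mcore) with an 8K sequence length on a single node (1x8 NPUs).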
export CUDA_DEVICE_MAX_CONNECTIONS=1
NPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6021
NNODES=1
NODE_RANK=0
WORLD_SIZE=$((NPUS_PER_NODE*$NNODES))
CKPT_SAVE_DIR="your model save ckpt path"
DATA_PATH="your data path"
TOKENIZER_PATH="your tokenizer path"
CKPT_LOAD_DIR="your model ckpt path"
TP=1
PP=2
MBS=1
GBS=128
SEQ_LEN=8192
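# Two pipeline stages (PP=2) with TP=1; no context parallelism is used at 8K.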
DISTRIBUTED_ARGS="
--nproc_per_node $NPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
GPT_ARGS="
--use-mcore-models \
--transformer-impl local \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--sequence-parallel \
--num-layers 28 \
--hidden-size 4096 \
--ffn-hidden-size 13696 \
--num-attention-heads 32 \
--seq-length ${SEQ_LEN} \
--micro-batch-size ${MBS} \
--global-batch-size ${GBS} \
--max-position-embeddings ${SEQ_LEN} \
--padded-vocab-size 65024 \
--make-vocab-size-divisible-by 1 \
--group-query-attention \
--num-query-groups 2 \
--disable-bias-linear \
--add-qkv-bias \
--position-embedding-type rope \
--no-rope-fusion \
--use-partial-rope \
--normalization RMSNorm \
--use-fused-rmsnorm \
--swiglu \
--use-fused-swiglu \
--use-flash-attn \
--use-distributed-optimizer \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path ${TOKENIZER_PATH} \
--lr 1e-6 \
--train-iters 2000 \
--lr-decay-style cosine \
--untie-embeddings-and-output-weights \
--attention-dropout 0.0 \
--init-method-std 0.01 \
--hidden-dropout 0.0 \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--min-lr 1e-8 \
--weight-decay 1e-1 \
--lr-warmup-fraction 0.01 \
--clip-grad 1.0 \
--adam-beta1 0.9 \
--initial-loss-scale 4096 \
--adam-beta2 0.95 \
--no-gradient-accumulation-fusion \
--no-load-optim \
--no-load-rng \
--fp16 \
--overlap-grad-reduce \
--overlap-param-gather \
"
DATA_ARGS="
--data-path $DATA_PATH \
--split 949,50,1
"
OUTPUT_ARGS="
--log-interval 1 \
--save-interval 1000 \
--eval-interval 1000 \
--eval-iters 10 \
"
python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
--save $CKPT_SAVE_DIR \
| tee logs/train_mcore_chatglm3_6B_8K.log


@@ -112,6 +112,9 @@ def _validate_cp_args(args):
                            'for ulysses alltoall as well as memory usage.')
    if args.context_parallel_size <= 1:
        if args.kv_head_repeat_before_uly_alltoall:
            args.kv_head_repeat_before_uly_alltoall = False
            print_rank0_by_args(args, f"When context_parallel is not activated, kv_head_repeat_before_uly_alltoall would be set to False for reducing memory usage.")
        return
    # In context parallel we use FA