mirror of
https://gitee.com/ascend/ModelLink.git
synced 2024-12-02 03:48:24 +08:00
!1659 llama2-34b/70b mcore适配
Merge pull request !1659 from yuhui/llama2
This commit is contained in:
parent
dc03368cf1
commit
a2552c6a49
@ -420,10 +420,10 @@ ModelLink 通过模型并行与数据并行来训练大语言模型,为了演
|
||||
<tr>
|
||||
<td><a href="https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/tree/main">34B</a></td>
|
||||
<td>4K</td>
|
||||
<th>Legacy</th>
|
||||
<th>Mcore</th>
|
||||
<td>2x8</td>
|
||||
<td>BF16 </td>
|
||||
<td> 749 </td>
|
||||
<td> 810 </td>
|
||||
<td> -- </td>
|
||||
<td> 796 </td>
|
||||
<td>【GTS】</td>
|
||||
@ -432,10 +432,10 @@ ModelLink 通过模型并行与数据并行来训练大语言模型,为了演
|
||||
<tr>
|
||||
<td><a href="https://huggingface.co/meta-llama/Llama-2-70b-hf">70B</a></td>
|
||||
<td>4K</td>
|
||||
<th>Legacy</th>
|
||||
<th>Mcore</th>
|
||||
<td>4x8</td>
|
||||
<td>BF16 </td>
|
||||
<td> 420 </td>
|
||||
<td> 439 </td>
|
||||
<td> -- </td>
|
||||
<td> 430 </td>
|
||||
<td>【GTS】</td>
|
||||
|
@ -1045,7 +1045,7 @@ ModelLink已支持模型评估分数如下:
|
||||
| LLaMA-7B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 74.6% | [75.4](https://hub.opencompass.org.cn/dataset-detail/BoolQ) | LLaMA-13B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 79.6% | [78.7](https://hub.opencompass.org.cn/dataset-detail/BoolQ) |
|
||||
| LLaMA-33B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 83.2% | [83.1](https://paperswithcode.com/sota/question-answering-on-boolq) | LLaMA-65B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 85.7% | [86.6](https://paperswithcode.com/sota/question-answering-on-boolq) |
|
||||
| LLaMA2-7B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 45.7% | 45.3% | LLaMA2-13B | [BoolQ](https://paperswithcode.com/dataset/boolq) | 82.2% | [81.7](https://paperswithcode.com/sota/question-answering-on-boolq) |
|
||||
| LLaMA2-34B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 85.9% | -- | LLaMA2-70B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 65.1% | -- |
|
||||
| LLaMA2-34B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 82.0% | -- | LLaMA2-70B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 86.4% | -- |
|
||||
| LLaMA3-8B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 65.2% | 66.6% | LLaMA3-70B | [BoolQ](https://github.com/google-research-datasets/boolean-questions) | 78.4% | 79.5% |
|
||||
| LLaMA3.1-8B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 65.26% | 66.7% | LLaMA3.1-70B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 81.8% | 79.3% |
|
||||
| Mistral-7B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 56.3% | 56.3% | Mixtral-8x7B | [MMLU](https://paperswithcode.com/dataset/mmlu) | 69.9% | [70.6%](https://paperswithcode.com/sota/multi-task-language-understanding-on-mmlu) |
|
||||
|
54
examples/mcore/llama2/chat_llama2_70b_ptd.sh
Normal file
54
examples/mcore/llama2/chat_llama2_70b_ptd.sh
Normal file
@ -0,0 +1,54 @@
|
||||
#!/bin/bash
|
||||
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
||||
|
||||
# please fill these path configurations
|
||||
CHECKPOINT="your model directory path"
|
||||
TOKENIZER_PATH="your tokenizer directory path"
|
||||
|
||||
# Change for multinode config
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
NNODES=1
|
||||
NODE_RANK=0
|
||||
NPUS_PER_NODE=8
|
||||
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
|
||||
|
||||
torchrun $DISTRIBUTED_ARGS inference.py \
|
||||
--tensor-model-parallel-size 8 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
--use-mcore-models \
|
||||
--task chat \
|
||||
--hf-chat-template \
|
||||
--add-eos-token '<|eot_id|>' \
|
||||
--top-p 0.9 \
|
||||
--temperature 1 \
|
||||
--num-layers 32 \
|
||||
--num-layers 80 \
|
||||
--hidden-size 8192 \
|
||||
--ffn-hidden-size 28672 \
|
||||
--position-embedding-type rope \
|
||||
--seq-length 4096 \
|
||||
--max-new-tokens 256 \
|
||||
--micro-batch-size 1 \
|
||||
--num-attention-heads 64 \
|
||||
--max-position-embeddings 4096 \
|
||||
--swiglu \
|
||||
--load ${CHECKPOINT} \
|
||||
--tokenizer-type PretrainedFromHF \
|
||||
--tokenizer-name-or-path ${TOKENIZER_PATH} \
|
||||
--tokenizer-not-use-fast \
|
||||
--fp16 \
|
||||
--group-query-attention \
|
||||
--num-query-groups 8 \
|
||||
--normalization RMSNorm \
|
||||
--untie-embeddings-and-output-weights \
|
||||
--disable-bias-linear \
|
||||
--attention-softmax-in-fp32 \
|
||||
--no-load-optim \
|
||||
--no-load-rng \
|
||||
--no-masked-softmax-fusion \
|
||||
--no-gradient-accumulation-fusion \
|
||||
--exit-on-missing-checkpoint \
|
||||
--make-vocab-size-divisible-by 1 \
|
||||
| tee logs/chat_llama2_70b_mcore.log
|
55
examples/mcore/llama2/evaluate_llama2_34B_ptd.sh
Normal file
55
examples/mcore/llama2/evaluate_llama2_34B_ptd.sh
Normal file
@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
||||
|
||||
# Change for multinode config
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
NNODES=1
|
||||
NODE_RANK=0
|
||||
NPUS_PER_NODE=8
|
||||
|
||||
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
|
||||
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
|
||||
|
||||
CHECKPOINT="Your ckpt file path"
|
||||
TOKENIZER_PATH="Your tokenizer path"
|
||||
DATA_PATH="./boolq/data/test/"
|
||||
TASK="boolq"
|
||||
|
||||
# Different task needs different max_new_tokens value, please follow the instruction in readme.
|
||||
torchrun $DISTRIBUTED_ARGS evaluation.py \
|
||||
--use-mcore-models \
|
||||
--task-data-path ${DATA_PATH} \
|
||||
--task ${TASK} \
|
||||
--seq-length 4096 \
|
||||
--max-new-tokens 1 \
|
||||
--max-position-embeddings 4096 \
|
||||
--tensor-model-parallel-size 8 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
--num-layers 48 \
|
||||
--hidden-size 8192 \
|
||||
--ffn-hidden-size 22016 \
|
||||
--num-attention-heads 64 \
|
||||
--padded-vocab-size 32000 \
|
||||
--rotary-base 1000000 \
|
||||
--disable-bias-linear \
|
||||
--swiglu \
|
||||
--position-embedding-type rope \
|
||||
--load ${CHECKPOINT} \
|
||||
--normalization RMSNorm \
|
||||
--tokenizer-type PretrainedFromHF \
|
||||
--tokenizer-name-or-path ${TOKENIZER_PATH} \
|
||||
--tokenizer-not-use-fast \
|
||||
--fp16 \
|
||||
--micro-batch-size 1 \
|
||||
--use-fused-rmsnorm \
|
||||
--exit-on-missing-checkpoint \
|
||||
--no-load-rng \
|
||||
--no-load-optim \
|
||||
--untie-embeddings-and-output-weights \
|
||||
--no-masked-softmax-fusion \
|
||||
--make-vocab-size-divisible-by 1 \
|
||||
--group-query-attention \
|
||||
--num-query-groups 8 \
|
||||
--seed 42 | tee logs/evaluation_llama2_34b_mcore_${TASK}.log
|
52
examples/mcore/llama2/evaluate_llama2_70B_ptd.sh
Normal file
52
examples/mcore/llama2/evaluate_llama2_70B_ptd.sh
Normal file
@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
||||
|
||||
# Change for multinode config
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
NNODES=1
|
||||
NODE_RANK=0
|
||||
NPUS_PER_NODE=8
|
||||
|
||||
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
|
||||
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
|
||||
|
||||
CHECKPOINT="Your ckpt file path"
|
||||
TOKENIZER_PATH="Your tokenizer path"
|
||||
DATA_PATH="./boolq/data/test/"
|
||||
TASK="boolq"
|
||||
|
||||
# Different task needs different max_new_tokens value, please follow the instruction in readme.
|
||||
torchrun $DISTRIBUTED_ARGS evaluation.py \
|
||||
--use-mcore-models \
|
||||
--task-data-path ${DATA_PATH} \
|
||||
--task ${TASK} \
|
||||
--seq-length 4096 \
|
||||
--max-new-tokens 1 \
|
||||
--max-position-embeddings 4096 \
|
||||
--tensor-model-parallel-size 8 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
--num-layers 80 \
|
||||
--hidden-size 8192 \
|
||||
--ffn-hidden-size 28672 \
|
||||
--num-attention-heads 64 \
|
||||
--disable-bias-linear \
|
||||
--swiglu \
|
||||
--position-embedding-type rope \
|
||||
--load ${CHECKPOINT} \
|
||||
--normalization RMSNorm \
|
||||
--tokenizer-type PretrainedFromHF \
|
||||
--tokenizer-name-or-path ${TOKENIZER_PATH} \
|
||||
--tokenizer-not-use-fast \
|
||||
--fp16 \
|
||||
--group-query-attention \
|
||||
--num-query-groups 8 \
|
||||
--micro-batch-size 1 \
|
||||
--exit-on-missing-checkpoint \
|
||||
--no-load-rng \
|
||||
--no-load-optim \
|
||||
--untie-embeddings-and-output-weights \
|
||||
--no-masked-softmax-fusion \
|
||||
--make-vocab-size-divisible-by 1 \
|
||||
--seed 42 | tee logs/eval_llama2_70b_mcore_${TASK}.log
|
55
examples/mcore/llama2/generate_llama2_34B_ptd.sh
Normal file
55
examples/mcore/llama2/generate_llama2_34B_ptd.sh
Normal file
@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
||||
|
||||
# please fill these path configurations
|
||||
CHECKPOINT="your model directory path"
|
||||
TOKENIZER_PATH="your tokenizer directory path"
|
||||
|
||||
# Change for multinode config
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
NNODES=1
|
||||
NODE_RANK=0
|
||||
NPUS_PER_NODE=8
|
||||
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
|
||||
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
|
||||
|
||||
torchrun $DISTRIBUTED_ARGS inference.py \
|
||||
--tensor-model-parallel-size 8 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
--use-mcore-models \
|
||||
--num-layers 48 \
|
||||
--hidden-size 8192 \
|
||||
--ffn-hidden-size 22016 \
|
||||
--position-embedding-type rope \
|
||||
--seq-length 4096 \
|
||||
--max-new-tokens 256 \
|
||||
--micro-batch-size 2 \
|
||||
--global-batch-size 16 \
|
||||
--num-attention-heads 64 \
|
||||
--padded-vocab-size 32000 \
|
||||
--rotary-base 1000000 \
|
||||
--max-position-embeddings 4096 \
|
||||
--swiglu \
|
||||
--load ${CHECKPOINT} \
|
||||
--tokenizer-type PretrainedFromHF \
|
||||
--tokenizer-name-or-path ${TOKENIZER_PATH} \
|
||||
--tokenizer-not-use-fast \
|
||||
--fp16 \
|
||||
--normalization RMSNorm \
|
||||
--untie-embeddings-and-output-weights \
|
||||
--disable-bias-linear \
|
||||
--attention-softmax-in-fp32 \
|
||||
--no-load-optim \
|
||||
--no-load-rng \
|
||||
--no-masked-softmax-fusion \
|
||||
--no-gradient-accumulation-fusion \
|
||||
--exit-on-missing-checkpoint \
|
||||
--make-vocab-size-divisible-by 1 \
|
||||
--vocab-size 32000 \
|
||||
--group-query-attention \
|
||||
--num-query-groups 8 \
|
||||
| tee logs/generate_llama2_34b_mcore.log
|
||||
|
||||
|
48
examples/mcore/llama2/generate_llama2_70b_ptd.sh
Normal file
48
examples/mcore/llama2/generate_llama2_70b_ptd.sh
Normal file
@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
||||
|
||||
# please fill these path configurations
|
||||
CHECKPOINT="your model directory path"
|
||||
TOKENIZER_PATH="your tokenizer directory path"
|
||||
|
||||
# Change for multinode config
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
NNODES=1
|
||||
NODE_RANK=0
|
||||
NPUS_PER_NODE=8
|
||||
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
|
||||
|
||||
torchrun $DISTRIBUTED_ARGS inference.py \
|
||||
--tensor-model-parallel-size 8 \
|
||||
--pipeline-model-parallel-size 1 \
|
||||
--use-mcore-models \
|
||||
--num-layers 80 \
|
||||
--hidden-size 8192 \
|
||||
--ffn-hidden-size 28672 \
|
||||
--position-embedding-type rope \
|
||||
--seq-length 4096 \
|
||||
--max-new-tokens 256 \
|
||||
--micro-batch-size 1 \
|
||||
--num-attention-heads 64 \
|
||||
--max-position-embeddings 4096 \
|
||||
--swiglu \
|
||||
--load ${CHECKPOINT} \
|
||||
--tokenizer-type PretrainedFromHF \
|
||||
--tokenizer-name-or-path ${TOKENIZER_PATH} \
|
||||
--tokenizer-not-use-fast \
|
||||
--fp16 \
|
||||
--group-query-attention \
|
||||
--num-query-groups 8 \
|
||||
--normalization RMSNorm \
|
||||
--untie-embeddings-and-output-weights \
|
||||
--disable-bias-linear \
|
||||
--attention-softmax-in-fp32 \
|
||||
--no-load-optim \
|
||||
--no-load-rng \
|
||||
--no-masked-softmax-fusion \
|
||||
--no-gradient-accumulation-fusion \
|
||||
--exit-on-missing-checkpoint \
|
||||
--make-vocab-size-divisible-by 1 \
|
||||
| tee logs/generate_llama2_70b_mcore.log
|
95
examples/mcore/llama2/pretrain_llama2_34B_ptd_16p.sh
Normal file
95
examples/mcore/llama2/pretrain_llama2_34B_ptd_16p.sh
Normal file
@ -0,0 +1,95 @@
|
||||
#!/bin/bash
|
||||
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
||||
|
||||
NPUS_PER_NODE=8
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
NNODES=2
|
||||
NODE_RANK=0
|
||||
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
|
||||
|
||||
CKPT_SAVE_DIR="your model save ckpt path"
|
||||
DATA_PATH="your data path"
|
||||
TOKENIZER_MODEL="your tokenizer path"
|
||||
CKPT_LOAD_DIR="your model ckpt path"
|
||||
TP=8
|
||||
PP=2
|
||||
|
||||
DISTRIBUTED_ARGS="
|
||||
--nproc_per_node $NPUS_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
--node_rank $NODE_RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT
|
||||
"
|
||||
|
||||
GPT_ARGS="
|
||||
--tensor-model-parallel-size ${TP} \
|
||||
--pipeline-model-parallel-size ${PP} \
|
||||
--sequence-parallel \
|
||||
--use-mcore-models \
|
||||
--use-fused-swiglu \
|
||||
--use-fused-rotary-pos-emb \
|
||||
--use-mc2 \
|
||||
--num-layers 48 \
|
||||
--hidden-size 8192 \
|
||||
--ffn-hidden-size 22016 \
|
||||
--num-attention-heads 64 \
|
||||
--tokenizer-type Llama2Tokenizer \
|
||||
--tokenizer-model ${TOKENIZER_MODEL} \
|
||||
--seq-length 4096 \
|
||||
--max-position-embeddings 4096 \
|
||||
--micro-batch-size 2 \
|
||||
--global-batch-size 1024 \
|
||||
--make-vocab-size-divisible-by 1 \
|
||||
--lr 1.0e-7 \
|
||||
--train-iters 2000 \
|
||||
--rotary-base 1000000 \
|
||||
--lr-decay-style cosine \
|
||||
--untie-embeddings-and-output-weights \
|
||||
--disable-bias-linear \
|
||||
--attention-dropout 0.0 \
|
||||
--init-method-std 0.01 \
|
||||
--hidden-dropout 0.0 \
|
||||
--position-embedding-type rope \
|
||||
--normalization RMSNorm \
|
||||
--use-fused-rmsnorm \
|
||||
--swiglu \
|
||||
--use-flash-attn \
|
||||
--group-query-attention \
|
||||
--num-query-groups 8 \
|
||||
--no-masked-softmax-fusion \
|
||||
--attention-softmax-in-fp32 \
|
||||
--min-lr 1.0e-8 \
|
||||
--weight-decay 1e-2 \
|
||||
--lr-warmup-fraction 0.01 \
|
||||
--clip-grad 1.0 \
|
||||
--adam-beta1 0.9 \
|
||||
--initial-loss-scale 524288.0 \
|
||||
--adam-beta2 0.999 \
|
||||
--no-gradient-accumulation-fusion \
|
||||
--no-load-optim \
|
||||
--no-load-rng \
|
||||
--bf16 \
|
||||
"
|
||||
|
||||
DATA_ARGS="
|
||||
--data-path $DATA_PATH \
|
||||
--split 949,50,1
|
||||
"
|
||||
|
||||
OUTPUT_ARGS="
|
||||
--log-interval 1 \
|
||||
--save-interval 2000 \
|
||||
--eval-interval 1000 \
|
||||
--eval-iters 10 \
|
||||
"
|
||||
|
||||
torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
|
||||
$GPT_ARGS \
|
||||
$DATA_ARGS \
|
||||
$OUTPUT_ARGS \
|
||||
--distributed-backend nccl \
|
||||
--load ${CKPT_LOAD_DIR} \
|
||||
--save ${CKPT_SAVE_DIR} \
|
||||
| tee logs/train_llama2_34b_mcore.log
|
96
examples/mcore/llama2/pretrain_llama2_70b_ptd.sh
Normal file
96
examples/mcore/llama2/pretrain_llama2_70b_ptd.sh
Normal file
@ -0,0 +1,96 @@
|
||||
#!/bin/bash
|
||||
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
||||
|
||||
NPUS_PER_NODE=8
|
||||
MASTER_ADDR=localhost
|
||||
MASTER_PORT=6001
|
||||
NNODES=4
|
||||
NODE_RANK=0
|
||||
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
|
||||
|
||||
CKPT_SAVE_DIR="your model save ckpt path"
|
||||
DATA_PATH="your data path"
|
||||
TOKENIZER_MODEL="your tokenizer path"
|
||||
CKPT_LOAD_DIR="your model ckpt path"
|
||||
TP=8
|
||||
PP=4
|
||||
|
||||
DISTRIBUTED_ARGS="
|
||||
--nproc_per_node $NPUS_PER_NODE \
|
||||
--nnodes $NNODES \
|
||||
--node_rank $NODE_RANK \
|
||||
--master_addr $MASTER_ADDR \
|
||||
--master_port $MASTER_PORT
|
||||
"
|
||||
|
||||
GPT_ARGS="
|
||||
--tensor-model-parallel-size ${TP} \
|
||||
--pipeline-model-parallel-size ${PP} \
|
||||
--sequence-parallel \
|
||||
--use-mcore-models \
|
||||
--use-mc2 \
|
||||
--use-fused-swiglu \
|
||||
--use-fused-rotary-pos-emb \
|
||||
--num-layers 80 \
|
||||
--hidden-size 8192 \
|
||||
--ffn-hidden-size 28672 \
|
||||
--num-attention-heads 64 \
|
||||
--tokenizer-type Llama2Tokenizer \
|
||||
--tokenizer-model ${TOKENIZER_MODEL} \
|
||||
--seq-length 4096 \
|
||||
--max-position-embeddings 4096 \
|
||||
--micro-batch-size 1 \
|
||||
--global-batch-size 1024 \
|
||||
--make-vocab-size-divisible-by 1 \
|
||||
--lr 1.0e-6 \
|
||||
--train-iters 2000 \
|
||||
--lr-decay-style cosine \
|
||||
--untie-embeddings-and-output-weights \
|
||||
--attention-dropout 0.0 \
|
||||
--init-method-std 0.01 \
|
||||
--hidden-dropout 0.0 \
|
||||
--position-embedding-type rope \
|
||||
--normalization RMSNorm \
|
||||
--use-fused-rmsnorm \
|
||||
--swiglu \
|
||||
--use-flash-attn \
|
||||
--no-masked-softmax-fusion \
|
||||
--attention-softmax-in-fp32 \
|
||||
--min-lr 1.0e-7 \
|
||||
--weight-decay 0.1 \
|
||||
--clip-grad 1.0 \
|
||||
--adam-beta1 0.9 \
|
||||
--initial-loss-scale 4096.0 \
|
||||
--adam-beta2 0.95 \
|
||||
--adam-eps 1e-5 \
|
||||
--no-gradient-accumulation-fusion \
|
||||
--no-load-optim \
|
||||
--no-load-rng \
|
||||
--disable-bias-linear \
|
||||
--group-query-attention \
|
||||
--num-query-groups 8 \
|
||||
--lr-warmup-fraction 0.01 \
|
||||
--bf16 \
|
||||
--num-layers-per-virtual-pipeline-stage 5 \
|
||||
"
|
||||
|
||||
DATA_ARGS="
|
||||
--data-path $DATA_PATH \
|
||||
--split 949,50,1
|
||||
"
|
||||
|
||||
OUTPUT_ARGS="
|
||||
--log-interval 1 \
|
||||
--save-interval 2000 \
|
||||
--eval-interval 1000 \
|
||||
--eval-iters 10 \
|
||||
"
|
||||
|
||||
torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
|
||||
$GPT_ARGS \
|
||||
$DATA_ARGS \
|
||||
$OUTPUT_ARGS \
|
||||
--distributed-backend nccl \
|
||||
--load ${CKPT_LOAD_DIR} \
|
||||
--save ${CKPT_SAVE_DIR} \
|
||||
| tee logs/train_llama2_70b_mcore.log
|
Loading…
Reference in New Issue
Block a user