Adapt to LoRA and get the whole process (training, inference, evaluation) running end to end.

ningbenzhe1 2023-11-24 09:16:02 +08:00
parent 8757d302a5
commit 8a282348ca
8 changed files with 378 additions and 50 deletions

View File

@ -19,6 +19,7 @@ import os
import random
import sys
import numpy as np
from deepspeed import PipelineEngine
from deepspeed.accelerator import get_accelerator
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@ -26,12 +27,11 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from ascendspeed.enums import PositionEmbeddingType
from ascendspeed.utils import WRITE_FILE_DEFAULT_FLAGS, WRITE_FILE_DEFAULT_MODES
from ascendspeed import (get_args,
is_rank_0,
print_rank_0,
update_num_microbatches,
utils)
is_rank_0,
print_rank_0,
update_num_microbatches,
utils)
from ascendspeed.core import parallel_state, tensor_parallel
from ascendspeed.model import DistributedDataParallel as LocalDDP, Float16Module
from ascendspeed.model.lora_utils import is_enable_lora, get_lora_state_dict, lora_custom_load_fn_for_deepspeed, \
@ -69,8 +69,8 @@ def check_checkpoint_args(checkpoint_args):
checkpoint_value = getattr(checkpoint_args, arg_name)
args_value = getattr(args, arg_name)
error_info = '{} value from checkpoint ({}) is not equal to the ' \
'input argument value ({}).'.format(
arg_name, checkpoint_value, args_value)
'input argument value ({}).'.format(
arg_name, checkpoint_value, args_value)
check_equal(checkpoint_value, args_value, error_info)
if not args.mos and not args.kd:
@ -144,7 +144,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
iteration, args.save))
if not torch.distributed.is_initialized() or parallel_state.get_data_parallel_rank() == 0 \
or args.deepspeed:
or args.deepspeed:
# Arguments, iteration, and model.
state_dict = {}
@ -266,8 +266,8 @@ def _transpose_first_dim(t, num_splits, num_splits_first, model):
intermediate_shape = \
(num_attention_heads_per_partition,
hidden_size_per_attention_head, num_splits) +\
input_shape[1:]
hidden_size_per_attention_head, num_splits) + \
input_shape[1:]
t = t.view(*intermediate_shape)
t = t.transpose(1, 2).contiguous()
@ -304,7 +304,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
sys.exit()
param.data.copy_(fixed_param)
print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version))
" checkpoint version {}".format(checkpoint_version))
def read_tracker(load_dir):
@ -404,11 +404,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
print_rank_0(f" will not load any checkpoints and will start from random")
return 0
custom_load_fn, load_dir = get_custom_load_fn(model=model[0], load_dir=load_dir, lora_load_dir=lora_load_dir)
load_zero_optim = sum(['zero' in file for file in os.listdir(load_dir)]) > 0
if args.no_pipeline_parallel:
load_zero_optim = sum(['zero' in file for file in os.listdir(load_dir)]) > 0
else:
load_zero_optim = sum(['global' in file for file in os.listdir(load_dir)]) > 0
release = not load_zero_optim
loaded_dir, state_dict = model[0].load_checkpoint(
load_dir,
load_module_strict=strict,
# Load non-strictly only when LoRA is enabled and the original (base) model checkpoint is being loaded.
load_module_strict=not (release and is_enable_lora()),
load_module_only=not load_zero_optim,
load_optimizer_states=load_zero_optim,
load_lr_scheduler_states=load_zero_optim,
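For readability, here is a minimal sketch of the decision this hunk encodes, assuming the same `args.no_pipeline_parallel` flag, `is_enable_lora()` helper, and checkpoint directory layout the code relies on; the function name is illustrative and not part of the codebase:

```python
import os

def resolve_load_mode(load_dir, no_pipeline_parallel, lora_enabled):
    """Sketch: decide whether optimizer states exist and how strictly to load the module."""
    # ZeRO checkpoints carry 'zero' in their file names; pipeline-parallel DeepSpeed
    # checkpoints use 'global_step*' directories instead (assumption based on this hunk).
    marker = 'zero' if no_pipeline_parallel else 'global'
    load_zero_optim = any(marker in name for name in os.listdir(load_dir))
    release = not load_zero_optim  # weights-only ("release") checkpoint
    # With LoRA enabled and only base weights on disk, strict loading would fail
    # because the adapter keys are missing, so strictness is relaxed in that case.
    load_module_strict = not (release and lora_enabled)
    return load_zero_optim, load_module_strict
```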
@ -452,10 +456,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
checkpoint_args = state_dict['args']
check_checkpoint_args(checkpoint_args)
args.consumed_train_samples = getattr(checkpoint_args,
'consumed_train_samples', 0)
'consumed_train_samples', 0)
update_num_microbatches(consumed_samples=args.consumed_train_samples)
args.consumed_valid_samples = getattr(checkpoint_args,
'consumed_valid_samples', 0)
'consumed_valid_samples', 0)
else:
print_rank_0('could not find arguments in the checkpoint ...')
@ -512,6 +516,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
def get_custom_load_fn(model, load_dir, lora_load_dir=None):
custom_load_fn = None
if isinstance(model, PipelineEngine):
return custom_load_fn, load_dir
if is_enable_lora():
if lora_load_dir:
custom_load_fn = get_lora_load_fn_with_deepspeed(model=model, base_model_load_dir=load_dir)
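As a reading aid, a hedged sketch of the new early exit: when the model is already a DeepSpeed `PipelineEngine`, no LoRA-specific loader is installed and the load directory is returned unchanged (the function name below is illustrative):

```python
from deepspeed import PipelineEngine  # same import this commit adds at the top of the file

def choose_custom_load_fn(model, load_dir):
    """Sketch of the PipelineEngine bypass; the LoRA branches follow it in the real code."""
    custom_load_fn = None
    if isinstance(model, PipelineEngine):
        # The pipeline engine manages its own layer checkpoints, so the
        # LoRA-aware custom loader must not be attached to it.
        return custom_load_fn, load_dir
    # ... non-pipeline path: pick a LoRA load function as shown in the hunk above.
    return custom_load_fn, load_dir
```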
@ -566,7 +574,7 @@ def load_state_dict_from_checkpoint_with_megatron(checkpoint_name, model_checkpo
def load_biencoder_checkpoint(model, only_query_model=False,
only_context_model=False, custom_load_path=None):
only_context_model=False, custom_load_path=None):
"""
selectively load retrieval models for indexing/retrieving
from saved checkpoints

View File

@ -104,7 +104,7 @@ class LoraParallelLinear(ColumnParallelLinear, RowParallelLinear, LoraParalleLay
else:
parallel_linear_kwargs['gather_output'] = parallel_linear.gather_output
type(parallel_linear).__init__(self, input_size=parallel_linear.input_size,
output_size=parallel_linear.output_size, bias=parallel_linear.bias,
output_size=parallel_linear.output_size, bias=parallel_linear.bias is not None,
skip_bias_add=parallel_linear.skip_bias_add,
**parallel_linear_kwargs)
LoraParalleLayer.__init__(self, in_features=parallel_linear.input_size,
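The substantive fix here is that these parallel linear layers take `bias` as a boolean switch controlling whether a bias parameter is created, not as the bias tensor itself, so the wrapper must forward `parallel_linear.bias is not None`. A tiny self-contained illustration of the pattern (hypothetical layer, not the real AscendSpeed API):

```python
import torch

class TinyLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        # `bias` is a flag deciding whether a bias parameter is created.
        self.bias = torch.nn.Parameter(torch.empty(out_features)) if bias else None

base = TinyLinear(4, 8, bias=True)
# Rebuilding a layer around `base` must forward a bool, not the Parameter object:
rebuilt = TinyLinear(4, 8, bias=base.bias is not None)  # correct
# TinyLinear(4, 8, bias=base.bias) would hand a tensor to a boolean switch.
```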

View File

@ -33,7 +33,6 @@ from ascendspeed.core import parallel_state, tensor_parallel
from ascendspeed.model.lora_utils import is_enable_lora, get_lora_model_classes
from ascendspeed.error_utils import ensure_valid
_FLOAT_TYPES = (torch.FloatTensor, get_accelerator().FloatTensor)
_HALF_TYPES = (torch.HalfTensor, get_accelerator().HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor)
@ -154,6 +153,7 @@ def conversion_helper(val, conversion):
def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16/bf16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (torch.nn.parameter.Parameter, torch.autograd.Variable)):
@ -167,6 +167,7 @@ def fp32_to_float16(val, float16_convertor):
def float16_to_fp32(val):
"""Convert fp16/bf16 `val` to fp32"""
def float_conversion(val):
if val is None:
return val
@ -429,7 +430,7 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
checked_ids = []
for per_ids in ids:
if per_ids == torch.Size([]) and torch.max(per_ids) >= len(tokenizer):
warning_info = "The output ids exceeds the tokenizer length, "\
warning_info = "The output ids exceeds the tokenizer length, " \
"the clamp operation is enforced, please check!!"
logging.warning(warning_info)
checked_ids.append(torch.clamp(per_ids, min=0, max=len(tokenizer)) - 1)
@ -459,18 +460,20 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
args.load = pretrained_model_name_or_path
if args.deepspeed:
if is_enable_lora():
unwrap_classes = get_lora_model_classes()
# The DeepSpeed pipeline checks the model's base class, so the PEFT (LoRA) wrapper has to be unwrapped first.
args.model = unwrap_model(args.model, unwrap_classes)
args.model[0] = cls._init_deepspeed_inference(args.model[0], args)
if args.load:
load_checkpoint(args.model, None, None)
if not args.deepspeed:
unwrap_classes = (torchDDP, LocalDDP, Float16Module)
if is_enable_lora():
unwrap_classes += get_lora_model_classes()
else:
unwrap_classes = (torchDDP, LocalDDP, Float16Module, deepspeed.DeepSpeedEngine)
unwrap_classes = (torchDDP, LocalDDP, Float16Module)
if args.deepspeed:
unwrap_classes += (deepspeed.DeepSpeedEngine,)
# The returned model only exposes the MegatronModuleForCausalLM interface; actual inference still uses args.model.
return unwrap_model(args.model, unwrap_classes)[0]
def generate(self, input_ids=None, **kwargs):
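A condensed sketch of how the unwrap tuple ends up being assembled after this change, written with the same imports the file already uses; it is a reading aid rather than the literal code:

```python
import deepspeed
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from ascendspeed.model import DistributedDataParallel as LocalDDP, Float16Module
from ascendspeed.model.lora_utils import is_enable_lora, get_lora_model_classes

def build_unwrap_classes(use_deepspeed: bool):
    """Sketch: wrapper classes stripped before from_pretrained returns the model."""
    classes = (torchDDP, LocalDDP, Float16Module)
    if use_deepspeed:
        # The DeepSpeed engine wrapper appears after _init_deepspeed_inference;
        # LoRA wrappers were already stripped before that call (see the hunk above).
        classes += (deepspeed.DeepSpeedEngine,)
    elif is_enable_lora():
        # Pure-Megatron path: the PEFT/LoRA wrappers still surround the module here.
        classes += get_lora_model_classes()
    return classes
```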
@ -610,7 +613,7 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
output_checked = self._ids_check(output, self.tokenizer)
output = self.tokenizer.batch_decode(output_checked, skip_special_tokens=True)
except Exception as e:
error_info = "Meet errors when trying to decode the tokens. "\
error_info = "Meet errors when trying to decode the tokens. " \
"Please handle it by yourself."
logging.error(error_info)
logging.error(e)

View File

@ -29,8 +29,8 @@ def beam_search(model, tokens, **kwargs):
# ==========================
# Pad tokens
# ==========================
final_sequence_length = args.max_length_ori
prompt_length, context_lengths, tokens = _pad_tokens(args, tokens, beam_size, num_return_gen)
final_sequence_length = args.max_length_ori
# ==========================
# Forward step

View File

@ -52,7 +52,7 @@ from ascendspeed.data.data_samplers import build_pretraining_data_loader
from ascendspeed.utils import calc_params_l2_norm
from ascendspeed.utils import report_memory, throughput_calculator, checkpoint_throughput_calculator
from ascendspeed.model.transformer import ParallelTransformerLayer
from ascendspeed.model.lora_utils import is_enable_lora, handle_model_with_lora
from ascendspeed.model.lora_utils import is_enable_lora, handle_model_with_lora, get_lora_model_classes
from ascendspeed.core.pipeline_parallel.schedules import forward_backward_pipelining_with_foldx_fifo
from ascendspeed.core.pipeline_parallel.schedules import forward_backward_pipelining_with_foldx_aiao
from ascendspeed.core.pipeline_parallel.schedules import get_forward_backward_func, get_forward_func
@ -582,8 +582,10 @@ def setup_model_and_optimizer(model_provider_func,
model = [model]
model = [init_compression(model[0].module, args.deepspeed_config, tensor_parallel)]
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
unwrap_model_classes = (torchDDP, LocalDDP, Float16Module)
if is_enable_lora():
unwrap_model_classes += get_lora_model_classes()
unwrapped_model = unwrap_model(model, unwrap_model_classes)
if args.inference:
optimizer = None
@ -624,7 +626,7 @@ def setup_model_and_optimizer(model_provider_func,
train_ds, _, _ = build_train_valid_test_datasets_provider(
train_val_test_num_samples)
model, optimizer, args.deepspeed_dataloader, lr_scheduler = deepspeed.initialize(
model=model[0],
model=unwrapped_model[0],
optimizer=optimizer,
args=args,
lr_scheduler=lr_scheduler,
@ -634,7 +636,7 @@ def setup_model_and_optimizer(model_provider_func,
model.set_data_post_process_func(data_post_process)
else:
model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model[0],
model=unwrapped_model[0],
optimizer=optimizer,
args=args,
lr_scheduler=lr_scheduler,
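Both `deepspeed.initialize` calls in this hunk now receive `unwrapped_model[0]` instead of `model[0]`. A hedged restatement of the pattern with explanatory comments, using the names already imported in training.py and omitting the keyword arguments not shown above:

```python
# Strip DDP/float16/LoRA wrappers before handing the module to DeepSpeed; otherwise the
# wrapper would become the engine's module and checkpoint keys would no longer match.
unwrap_model_classes = (torchDDP, LocalDDP, Float16Module)
if is_enable_lora():
    unwrap_model_classes += get_lora_model_classes()
unwrapped_model = unwrap_model(model, unwrap_model_classes)

model, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=unwrapped_model[0],   # bare module, as in the hunk above
    optimizer=optimizer,
    args=args,
    lr_scheduler=lr_scheduler,
    # ... remaining keyword arguments unchanged from the original call
)
```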

View File

@ -2,7 +2,6 @@
<p align="left">
<b>简体中文</b> |
<b><a href="https://gitee.com/ascend/AscendSpeed/blob/master/examples/bloom/README_en.md">English</a> </b>
</p>
</p>
@ -15,7 +14,9 @@
- [Throughput](#吞吐)
- [Accuracy](#精度)
- [Inference](#推理)
- [Script](#脚本)
- [deepspeed_pipeline](#deepspeed_pipeline)
- [megatron](#megatron)
- [Evaluation](#评估)
- [Bloom-176B](#Bloom-176B)
- [Training](#训练)
- [Script](#脚本)
@ -23,8 +24,10 @@
- [Throughput](#吞吐)
- [Accuracy](#精度)
- [Inference](#推理)
- [Script](#脚本)
- [deepspeed_pipeline](#deepspeed_pipeline)
- [megatron](#megatron)
- [Evaluation](#评估)
- [Example](#举例)
# Bloom-7B
## Training
@ -179,6 +182,21 @@ NPU vs reference loss relative error
AscendSpeed supports text generation inference with BLOOM 7B.
### deepspeed_pipeline
```shell
# Modify the model weight path and the tokenizer path
CHECKPOINT=/home/model/bloom_7B
VOCAB_FILE=/home/bloom_data/vocab_file/
```
```shell
bash ./examples/bloom/generate_bloom_7b_deepspeed_pipeline.sh
```
### megatron
Use [convert_weights_from_gptmodelpipe_to_gptmodel.sh](../../tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel.sh) to convert the Bloom-7B weights into the inference format
```bash
@ -190,13 +208,12 @@ python $SCRIPT_PATH \
--pipeline-model-parallel-size 1 \
--type 7B
```
### Script
Configure the Bloom-7B inference script: examples/bloom/generate_bloom_7B_tp8_pp1.sh
```shell
# Modify the model weight path and the tokenizer path
CHECKPOINT=/home/bloom_data/enwiki_100k/enwiki-100k_text_document
CHECKPOINT=/home/model/bloom_7B
VOCAB_FILE=/home/bloom_data/vocab_file/
```
@ -204,6 +221,59 @@ VOCAB_FILE=/home/bloom_data/vocab_file/
bash ./examples/bloom/generate_bloom_7B_tp8_pp1.sh
```
## Evaluation
Configure the Bloom-7B evaluation script: tasks/evaluation/eval_bloom.sh
```shell
# Modify the model weight path, the tokenizer path, and the dataset/task path
CHECKPOINT=/home/model/bloom_7B
VOCAB_FILE=/home/bloom_data/vocab_file/
DATA_PATH="/dataset/boolq/test"
TASK="boolq"
```
In addition, you need to set the parameters according to the model size:
```shell
--num-layers 30
--hidden-size 4096
--num-attention-heads 32
```
```shell
bash ./tasks/evaluation/eval_bloom.sh
```
<table>
<thead>
<tr>
<th>Task</th>
<th>Validation set</th>
<th>Model</th>
<th>Ascend value</th>
<th>Community value</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="https://huggingface.co/datasets/boolq">Boolq</a></td>
<td>Test</td>
<th>bloom 7b</th>
<td>0.614</td>
<td>--</td>
</tr>
</tbody>
<tbody>
<tr>
<td><a href="https://huggingface.co/datasets/cais/mmlu">mmlu</a></td>
<td>Test</td>
<th>bloom 7b</th>
<td>0.251</td>
<td><a href="https://www.hayo.com/article/648ace24409528db3186ef1c">0.254</a></td>
</tr>
</tbody>
</table>
# Bloom-176B
## Training
@ -367,6 +437,21 @@ NPU vs reference loss
## Inference
AscendSpeed supports online text generation inference with BLOOM 176B
We support AscendSpeed Inference for text generation with BLOOM 176B (deepspeed or megatron).
### deepspeed_pipeline
```shell
# Modify the model weight path and the tokenizer path
CHECKPOINT=/home/model/bloom_176B
VOCAB_FILE=/home/bloom_data/vocab_file/
```
```shell
bash ./examples/bloom/generate_bloom_176b_deepspeed_pipeline.sh
```
### megatron
Use the [convert_weights_from_gptmodelpipe_to_gptmodel.sh](../../tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel.sh) script to convert the weights into the inference format.
Inference runs on two nodes, so the weights must be synchronized to both nodes manually: node 0 needs the weights of layers 1-37 and node 1 needs the weights of layers 38-74. Run the following script:
@ -391,7 +476,7 @@ MASTER_ADDR=localhost
NODE_RANK=0
# Modify the dataset path and the vocabulary path
CHECKPOINT=/home/bloom_data/enwiki_100k/enwiki-100k_text_document
CHECKPOINT=/home/model/bloom_176B
VOCAB_FILE=/home/bloom_data/vocab_file/
```
@ -399,6 +484,50 @@ VOCAB_FILE=/home/bloom_data/vocab_file/
bash ./examples/bloom/generate_bloom_176b_2nodes.sh
```
## Evaluation
Configure the Bloom-176B evaluation script: tasks/evaluation/eval_bloom.sh
```shell
# Modify the model weight path, the tokenizer path, and the dataset/task path
CHECKPOINT=/home/model/bloom_176B
VOCAB_FILE=/home/bloom_data/vocab_file/
DATA_PATH="/dataset/boolq/test"
TASK="boolq"
```
In addition, you need to set the parameters according to the model size:
```shell
--num-layers 70
--hidden-size 14336
--num-attention-heads 112
```
```shell
bash ./tasks/evaluation/eval_bloom.sh
```
<table>
<thead>
<tr>
<th>Task</th>
<th>Validation set</th>
<th>Model</th>
<th>Ascend value</th>
<th>Community value</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="https://huggingface.co/datasets/boolq">Boolq</a></td>
<td>Test</td>
<th>bloom 176b</th>
<td>0.657</td>
<td>--</td>
</tr>
</tbody>
</table>
## Example
1. bloom 7b
@ -409,9 +538,11 @@ bash ./examples/bloom/generate_bloom_176b_2nodes.sh
## Citation
```
@article{scao2022bloom,
title={Bloom: A 176b-parameter open-access multilingual language model},
author={Scao, Teven Le and Fan, Angela and Akiki, Christopher and Pavlick, Ellie and Ili{\'c}, Suzana and Hesslow, Daniel and Castagn{\'e}, Roman and Luccioni, Alexandra Sasha and Yvon, Fran{\c{c}}ois and Gall{\'e}, Matthias and others},
journal={arXiv preprint arXiv:2211.05100},
year={2022}
}
}
```

View File

@ -2,9 +2,10 @@
<p align="left">
<b><a href="https://gitee.com/ascend/AscendSpeed/blob/master/examples/bloom/README.md">简体中文</a></b> |
<b>English</b>
</p>
</p>
# Contents
- [Bloom-7B](#contents)
@ -66,6 +67,7 @@ conda activate bloom7b
# install torch and torch_npu and apex
pip install torch-2.0.1-cp38-cp38-manylinux2014_aarch64.whl
pip install torch_npu-2.0.1rc1.postxxxxxxxx-cp38-cp38-linux_aarch64.whl
pip install apex-0.1_ascend*-cp38-cp38m-linux_aarch64.whl
# install megatron-core
pip3 install --no-use-pep517 -e git+https://github.com/NVIDIA/Megatron-LM.git@23.05#egg=megatron-core
@ -83,7 +85,7 @@ pip install -r requirements.txt
3. Prepare pretrained weights
Download the Bloom-7B weights from [here](https://huggingface.co/bigscience/bloom-7b1/tree/main).
Download the Bloom-7B tokenizer from [here](https://huggingface.co/bigscience/bloom-7b1/tree/main).
```shell
mkdir tokenizer
@ -91,7 +93,6 @@ cd tokenizer
wget https://huggingface.co/bigscience/bloom/resolve/main/special_tokens_map.json
wget https://huggingface.co/bigscience/bloom/resolve/main/tokenizer.json
wget https://huggingface.co/bigscience/bloom/resolve/main/tokenizer_config.json
...
cd ..
```
@ -180,9 +181,23 @@ NPU vs GPU loss relative error.
## Inference
We support AscendSpeed Inference for text generation with BLOOM 7B.
We support AscendSpeed Inference for text generation with BLOOM 7B (deepspeed or megatron).
Use [convert_weights_from_gptmodelpipe_to_gptmodel.sh](../../tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel.sh) to convert the deepspeed checkpoint to the megatron format.
### deepspeed_pipeline
```shell
# modify the model weight path and tokenizer path
CHECKPOINT=/home/model/bloom_7B
VOCAB_FILE=/home/bloom_data/vocab_file/
```
```shell
bash ./examples/bloom/generate_bloom_7b_deepspeed_pipeline.sh
```
### megatron
Use [convert_weights_from_gptmodelpipe_to_gptmodel.sh](../../tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel.sh), converting deepspeed checkpoints to megatron.
```bash
SCRIPT_PATH=./tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel_v2.py
@ -193,7 +208,6 @@ python $SCRIPT_PATH \
--pipeline-model-parallel-size 1 \
--type 7B
```
### Script
We generate text samples using the `generate_bloom` script. Inference differs from pre-training: for example, we need to load the pre-trained checkpoint and set the length of the output samples:
@ -201,7 +215,7 @@ Config Bloom-7B inference script: examples/bloom/generate_bloom_7B_tp8_pp1.sh
```shell
# modify the model weight path and tokenizer path
CHECKPOINT=/home/bloom_data/enwiki_100k/enwiki-100k_text_document
CHECKPOINT=/home/model/bloom_7B
VOCAB_FILE=/home/bloom_data/vocab_file/
```
@ -209,6 +223,59 @@ VOCAB_FILE=/home/bloom_data/vocab_file/
bash ./examples/bloom/generate_bloom_7B_tp8_pp1.sh
```
## Evaluation
Config Bloom-7B evaluation script: tasks/evaluation/eval_bloom.sh
```shell
# modify the model weight path and tokenizer path
CHECKPOINT=/home/model/bloom_7B
VOCAB_FILE=/home/bloom_data/vocab_file/
DATA_PATH="/dataset/boolq/test"
TASK="boolq"
```
In addition, you need to set the corresponding parameters according to the model size; the Bloom-7B parameters are:
```shell
--num-layers 30
--hidden-size 4096
--num-attention-heads 32
```
```shell
bash ./tasks/evaluation/eval_bloom.sh
```
<table>
<thead>
<tr>
<th>Task</th>
<th>Validation set</th>
<th>Model</th>
<th>Ascend value</th>
<th>Community value</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="https://huggingface.co/datasets/boolq">Boolq</a></td>
<td>Test</td>
<th>bloom 7b</th>
<td>0.614</td>
<td>--</td>
</tr>
</tbody>
<tbody>
<tr>
<td><a href="https://huggingface.co/datasets/cais/mmlu">mmlu</a></td>
<td>Test</td>
<th>bloom 7b</th>
<td>0.251</td>
<td><a href="https://www.hayo.com/article/648ace24409528db3186ef1c">0.254</a></td>
</tr>
</tbody>
</table>
# Bloom-176B
## Training
@ -371,7 +438,21 @@ and GPU on a single-node system. The average relative error is 0.1%, less than 2
## Inference
We support AscendSpeed Inference for text generation with BLOOM 176B.
We support AscendSpeed Inference for text generation with BLOOM 176B (deepspeed or megatron).
### deepspeed_pipeline
```shell
# modify the model weight path and tokenizer path
CHECKPOINT=/home/model/bloom_176B
VOCAB_FILE=/home/bloom_data/vocab_file/
```
```shell
bash ./examples/bloom/generate_bloom_176b_deepspeed_pipeline.sh
```
### megatron
Use [convert_weights_from_gptmodelpipe_to_gptmodel.sh](../../tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel.sh) to convert the deepspeed checkpoint to the megatron format.
@ -399,7 +480,7 @@ MASTER_ADDR=localhost
NODE_RANK=0
# modify the model weight path and tokenizer path
CHECKPOINT=/home/bloom_data/enwiki_100k/enwiki-100k_text_document
CHECKPOINT=/home/model/bloom_176B
VOCAB_FILE=/home/bloom_data/vocab_file/
```
@ -407,6 +488,48 @@ VOCAB_FILE=/home/bloom_data/vocab_file/
bash ./examples/bloom/generate_bloom_176b_2nodes.sh
```
## Evaluation
Config Bloom-176B evaluation script: tasks/evaluation/eval_bloom.sh
```shell
# modify the model weight path and tokenizer path
CHECKPOINT=/home/model/bloom_176B
VOCAB_FILE=/home/bloom_data/vocab_file/
DATA_PATH="/dataset/boolq/test"
TASK="boolq"
```
In addition, you need to set the corresponding parameters according to the model size; the Bloom-176B parameters are:
```shell
--num-layers 70
--hidden-size 14336
--num-attention-heads 112
```
```shell
bash ./tasks/evaluation/eval_bloom.sh
```
<table>
<thead>
<tr>
<th>Task</th>
<th>Validation set</th>
<th>Model</th>
<th>Ascend value</th>
<th>Community value</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="https://huggingface.co/datasets/boolq">Boolq</a></td>
<td>Test</td>
<th>bloom 176b</th>
<td>0.657</td>
<td>--</td>
</tr>
</tbody>
</table>
## Example
1. bloom 7b
![bloom_7b_generate.png](..%2F..%2Fsources%2Fimages%2Fbloom_7b_generate.png)
@ -421,9 +544,14 @@ All the provided scripts are tested on 910 64GB NPUs for BLOOM 7B and BLOOM 176B
You may also consider original work in your reference:
```
@article{scao2022bloom,
title={Bloom: A 176b-parameter open-access multilingual language model},
author={Scao, Teven Le and Fan, Angela and Akiki, Christopher and Pavlick, Ellie and Ili{\'c}, Suzana and Hesslow, Daniel and Castagn{\'e}, Roman and Luccioni, Alexandra Sasha and Yvon, Fran{\c{c}}ois and Gall{\'e}, Matthias and others},
journal={arXiv preprint arXiv:2211.05100},
year={2022}
}
}
```
\
\
<font size=1>If downloading a file with 'wget' fails, you can download it manually, making sure the source website is trustworthy.</font>

View File

@ -0,0 +1,56 @@
#!/bin/bash
export TOKENIZERS_PARALLELISM=false
NNODES=1
NPUS_PER_NODE=8
CHECKPOINT="your megatron checkpoint path"
VOCAB_FILE="your vocab path"
ZERO_STAGE=0
MICRO_BATCH_SIZE=1
config_json="./ds_config.json"
cat <<EOT > $config_json
{
"train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": $ZERO_STAGE
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1,
"initial_scale_power": 12
},
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
EOT
deepspeed --num_nodes $NNODES --num_gpus $NPUS_PER_NODE \
./tasks/inference/inference_bloom_pipeline.py \
--no-contiguous-buffers-in-local-ddp \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 \
--num-layers 70 \
--hidden-size 14336 \
--num-attention-heads 112 \
--max-position-embeddings 2048 \
--position-embedding-type alibi \
--embed-layernorm \
--tokenizer-type PretrainedFromHF \
--load "${CHECKPOINT}" \
--tokenizer-name-or-path "$VOCAB_FILE" \
--tokenizer-not-use-fast \
--fp16 \
--micro-batch-size 1 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42 \
--deepspeed \
--deepspeed_config ${config_json} \