Mirror of https://gitee.com/ascend/ModelLink.git (synced 2024-12-05 05:17:40 +08:00)

Commit 8a282348ca (parent 8757d302a5): Adapt to LoRA and get the whole process working end to end.
@@ -19,6 +19,7 @@ import os
 import random
 import sys
 import numpy as np
+from deepspeed import PipelineEngine
 from deepspeed.accelerator import get_accelerator
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@@ -26,12 +27,11 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from ascendspeed.enums import PositionEmbeddingType
 from ascendspeed.utils import WRITE_FILE_DEFAULT_FLAGS, WRITE_FILE_DEFAULT_MODES
-

 from ascendspeed import (get_args,
-    is_rank_0,
-    print_rank_0,
-    update_num_microbatches,
-    utils)
+                         is_rank_0,
+                         print_rank_0,
+                         update_num_microbatches,
+                         utils)
 from ascendspeed.core import parallel_state, tensor_parallel
 from ascendspeed.model import DistributedDataParallel as LocalDDP, Float16Module
 from ascendspeed.model.lora_utils import is_enable_lora, get_lora_state_dict, lora_custom_load_fn_for_deepspeed, \
@@ -69,8 +69,8 @@ def check_checkpoint_args(checkpoint_args):
         checkpoint_value = getattr(checkpoint_args, arg_name)
         args_value = getattr(args, arg_name)
         error_info = '{} value from checkpoint ({}) is not equal to the ' \
-            'input argument value ({}).'.format(
-            arg_name, checkpoint_value, args_value)
+                     'input argument value ({}).'.format(
+                         arg_name, checkpoint_value, args_value)
         check_equal(checkpoint_value, args_value, error_info)

     if not args.mos and not args.kd:
@@ -144,7 +144,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         iteration, args.save))

     if not torch.distributed.is_initialized() or parallel_state.get_data_parallel_rank() == 0 \
-        or args.deepspeed:
+            or args.deepspeed:

         # Arguments, iteration, and model.
         state_dict = {}
@@ -266,8 +266,8 @@ def _transpose_first_dim(t, num_splits, num_splits_first, model):

         intermediate_shape = \
             (num_attention_heads_per_partition,
-            hidden_size_per_attention_head, num_splits) +\
-            input_shape[1:]
+             hidden_size_per_attention_head, num_splits) + \
+             input_shape[1:]

         t = t.view(*intermediate_shape)
         t = t.transpose(1, 2).contiguous()
@@ -304,7 +304,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
                 sys.exit()
             param.data.copy_(fixed_param)
     print_rank_0(" succesfully fixed query-key-values ordering for"
-        " checkpoint version {}".format(checkpoint_version))
+                 " checkpoint version {}".format(checkpoint_version))


 def read_tracker(load_dir):
@@ -404,11 +404,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
         print_rank_0(f" will not load any checkpoints and will start from random")
         return 0
     custom_load_fn, load_dir = get_custom_load_fn(model=model[0], load_dir=load_dir, lora_load_dir=lora_load_dir)
-    load_zero_optim = sum(['zero' in file for file in os.listdir(load_dir)]) > 0
+    if args.no_pipeline_parallel:
+        load_zero_optim = sum(['zero' in file for file in os.listdir(load_dir)]) > 0
+    else:
+        load_zero_optim = sum(['global' in file for file in os.listdir(load_dir)]) > 0
     release = not load_zero_optim
     loaded_dir, state_dict = model[0].load_checkpoint(
         load_dir,
-        load_module_strict=strict,
+        # It is only loaded not strictly when lora is turned on and the original model is loaded.
+        load_module_strict=not (release and is_enable_lora()),
         load_module_only=not load_zero_optim,
         load_optimizer_states=load_zero_optim,
         load_lr_scheduler_states=load_zero_optim,
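The `zero`-versus-`global` filename probe above decides whether resumable optimizer state exists in the checkpoint directory. A minimal sketch of the same idea, assuming DeepSpeed's usual artifact names (`zero_*` optimizer shards, `global_step*` pipeline-module entries) — the exact names are an assumption here, not something this hunk guarantees:

```python
import os

def has_resumable_optim_state(load_dir: str, no_pipeline_parallel: bool) -> bool:
    # ZeRO runs leave shard files whose names contain "zero"; pipeline-engine
    # runs leave "global_step..." entries. Either marker implies optimizer and
    # LR-scheduler states can be restored, rather than module weights only.
    marker = 'zero' if no_pipeline_parallel else 'global'
    return any(marker in name for name in os.listdir(load_dir))
```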
@@ -452,10 +456,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
             checkpoint_args = state_dict['args']
             check_checkpoint_args(checkpoint_args)
             args.consumed_train_samples = getattr(checkpoint_args,
-                'consumed_train_samples', 0)
+                                                  'consumed_train_samples', 0)
             update_num_microbatches(consumed_samples=args.consumed_train_samples)
             args.consumed_valid_samples = getattr(checkpoint_args,
-                'consumed_valid_samples', 0)
+                                                  'consumed_valid_samples', 0)
         else:
             print_rank_0('could not find arguments in the checkpoint ...')

@@ -512,6 +516,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):

 def get_custom_load_fn(model, load_dir, lora_load_dir=None):
     custom_load_fn = None
+
+    if isinstance(model, PipelineEngine):
+        return custom_load_fn, load_dir
+
     if is_enable_lora():
         if lora_load_dir:
             custom_load_fn = get_lora_load_fn_with_deepspeed(model=model, base_model_load_dir=load_dir)
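For context on what this function hands back: DeepSpeed's `engine.load_checkpoint` accepts a `custom_load_fn(src, dst)` callback that replaces the default module state-dict load. A hedged sketch of a LoRA-aware callback in that shape (illustrative only, not the repo's `get_lora_load_fn_with_deepspeed`):

```python
def lora_aware_load_fn(src, dst):
    """src: the checkpoint's module state dict; dst: the live module."""
    # The base checkpoint carries no LoRA tensors, so load non-strictly and
    # leave the freshly initialized LoRA adapters untouched.
    missing, unexpected = dst.load_state_dict(src, strict=False)
    # Only LoRA parameters should be absent from a base-model checkpoint.
    assert all('lora' in key for key in missing), f"unexpected gaps: {missing}"
```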
@@ -566,7 +574,7 @@ def load_state_dict_from_checkpoint_with_megatron(checkpoint_name, model_checkpo


 def load_biencoder_checkpoint(model, only_query_model=False,
-        only_context_model=False, custom_load_path=None):
+                              only_context_model=False, custom_load_path=None):
     """
     selectively load retrieval models for indexing/retrieving
     from saved checkpoints
@@ -104,7 +104,7 @@ class LoraParallelLinear(ColumnParallelLinear, RowParallelLinear, LoraParalleLayer):
         else:
             parallel_linear_kwargs['gather_output'] = parallel_linear.gather_output
         type(parallel_linear).__init__(self, input_size=parallel_linear.input_size,
-                                       output_size=parallel_linear.output_size, bias=parallel_linear.bias,
+                                       output_size=parallel_linear.output_size, bias=parallel_linear.bias is not None,
                                        skip_bias_add=parallel_linear.skip_bias_add,
                                        **parallel_linear_kwargs)
         LoraParalleLayer.__init__(self, in_features=parallel_linear.input_size,
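The `bias=parallel_linear.bias is not None` change above fixes a real bug: the constructor expects a boolean flag, while `parallel_linear.bias` is the bias tensor itself (or `None`). A self-contained illustration with plain `torch.nn` (names are hypothetical):

```python
import torch.nn as nn

src = nn.Linear(4, 4, bias=True)
# Wrong: passing the Parameter where a bool is expected makes the internal
# `if bias:` check evaluate the truthiness of a multi-element tensor, which raises.
# bad = nn.Linear(src.in_features, src.out_features, bias=src.bias)
# Right: derive the flag from whether a bias tensor actually exists.
good = nn.Linear(src.in_features, src.out_features, bias=src.bias is not None)
```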
@@ -33,7 +33,6 @@ from ascendspeed.core import parallel_state, tensor_parallel
 from ascendspeed.model.lora_utils import is_enable_lora, get_lora_model_classes
 from ascendspeed.error_utils import ensure_valid

-
 _FLOAT_TYPES = (torch.FloatTensor, get_accelerator().FloatTensor)
 _HALF_TYPES = (torch.HalfTensor, get_accelerator().HalfTensor)
 _BF16_TYPES = (torch.BFloat16Tensor)
@@ -154,6 +153,7 @@ def conversion_helper(val, conversion):

 def fp32_to_float16(val, float16_convertor):
+    """Convert fp32 `val` to fp16/bf16"""

     def half_conversion(val):
         val_typecheck = val
         if isinstance(val_typecheck, (torch.nn.parameter.Parameter, torch.autograd.Variable)):
@@ -167,6 +167,7 @@ def fp32_to_float16(val, float16_convertor):

 def float16_to_fp32(val):
+    """Convert fp16/bf16 `val` to fp32"""

     def float_conversion(val):
         if val is None:
             return val
@@ -429,7 +430,7 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
         checked_ids = []
         for per_ids in ids:
             if per_ids == torch.Size([]) and torch.max(per_ids) >= len(tokenizer):
-                warning_info = "The output ids exceeds the tokenizer length, "\
+                warning_info = "The output ids exceeds the tokenizer length, " \
                     "the clamp operation is enforced, please check!!"
                 logging.warning(warning_info)
                 checked_ids.append(torch.clamp(per_ids, min=0, max=len(tokenizer)) - 1)
@@ -459,18 +460,20 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
         args.load = pretrained_model_name_or_path

         if args.deepspeed:
+            if is_enable_lora():
+                unwrap_classes = get_lora_model_classes()
+                # The deepspeed pipeline needs to verify the model base class. Therefore, the peft package needs to be unpacked.
+                args.model = unwrap_model(args.model, unwrap_classes)
             args.model[0] = cls._init_deepspeed_inference(args.model[0], args)

         if args.load:
             load_checkpoint(args.model, None, None)

-        if not args.deepspeed:
-            unwrap_classes = (torchDDP, LocalDDP, Float16Module)
-            if is_enable_lora():
-                unwrap_classes += get_lora_model_classes()
-        else:
-            unwrap_classes = (torchDDP, LocalDDP, Float16Module, deepspeed.DeepSpeedEngine)
+        unwrap_classes = (torchDDP, LocalDDP, Float16Module)
+        if is_enable_lora():
+            unwrap_classes += get_lora_model_classes()
+
+        if args.deepspeed:
+            unwrap_classes += (deepspeed.DeepSpeedEngine,)
+        # The returned model provides the MegatronModuleForCausalLM class identifier. In actual inference, args.model is still used.
         return unwrap_model(args.model, unwrap_classes)[0]

     def generate(self, input_ids=None, **kwargs):
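`unwrap_model` peels wrapper containers (DDP, `Float16Module`, the DeepSpeed engine, and now the LoRA/peft wrappers) off the underlying module. A simplified sketch of the Megatron-style helper this code relies on (illustrative, not the repo's exact implementation):

```python
def unwrap_model(model, module_instances):
    return_list = isinstance(model, list)
    models = model if return_list else [model]
    unwrapped = []
    for m in models:
        while isinstance(m, module_instances):
            m = m.module  # each wrapper keeps the inner model in .module
        unwrapped.append(m)
    return unwrapped if return_list else unwrapped[0]
```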
@@ -610,7 +613,7 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
             output_checked = self._ids_check(output, self.tokenizer)
             output = self.tokenizer.batch_decode(output_checked, skip_special_tokens=True)
         except Exception as e:
-            error_info = "Meet errors when trying to decode the tokens. "\
+            error_info = "Meet errors when trying to decode the tokens. " \
                 "Please handle it by yourself."
             logging.error(error_info)
             logging.error(e)
@@ -29,8 +29,8 @@ def beam_search(model, tokens, **kwargs):
     # ==========================
     # Pad tokens
     # ==========================
-    final_sequence_length = args.max_length_ori
     prompt_length, context_lengths, tokens = _pad_tokens(args, tokens, beam_size, num_return_gen)
+    final_sequence_length = args.max_length_ori

     # ==========================
     # Forward step
@@ -52,7 +52,7 @@ from ascendspeed.data.data_samplers import build_pretraining_data_loader
 from ascendspeed.utils import calc_params_l2_norm
 from ascendspeed.utils import report_memory, throughput_calculator, checkpoint_throughput_calculator
 from ascendspeed.model.transformer import ParallelTransformerLayer
-from ascendspeed.model.lora_utils import is_enable_lora, handle_model_with_lora
+from ascendspeed.model.lora_utils import is_enable_lora, handle_model_with_lora, get_lora_model_classes
 from ascendspeed.core.pipeline_parallel.schedules import forward_backward_pipelining_with_foldx_fifo
 from ascendspeed.core.pipeline_parallel.schedules import forward_backward_pipelining_with_foldx_aiao
 from ascendspeed.core.pipeline_parallel.schedules import get_forward_backward_func, get_forward_func
@@ -582,8 +582,10 @@ def setup_model_and_optimizer(model_provider_func,
         model = [model]
         model = [init_compression(model[0].module, args.deepspeed_config, tensor_parallel)]

-    unwrapped_model = unwrap_model(model,
-                                   (torchDDP, LocalDDP, Float16Module))
+    unwrap_model_classes = (torchDDP, LocalDDP, Float16Module)
+    if is_enable_lora():
+        unwrap_model_classes += get_lora_model_classes()
+    unwrapped_model = unwrap_model(model, unwrap_model_classes)

     if args.inference:
         optimizer = None
@@ -624,7 +626,7 @@ def setup_model_and_optimizer(model_provider_func,
         train_ds, _, _ = build_train_valid_test_datasets_provider(
             train_val_test_num_samples)
         model, optimizer, args.deepspeed_dataloader, lr_scheduler = deepspeed.initialize(
-            model=model[0],
+            model=unwrapped_model[0],
             optimizer=optimizer,
             args=args,
             lr_scheduler=lr_scheduler,
@@ -634,7 +636,7 @@ def setup_model_and_optimizer(model_provider_func,
         model.set_data_post_process_func(data_post_process)
     else:
         model, optimizer, _, lr_scheduler = deepspeed.initialize(
-            model=model[0],
+            model=unwrapped_model[0],
             optimizer=optimizer,
             args=args,
             lr_scheduler=lr_scheduler,
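Both `deepspeed.initialize` hunks make the same change: DeepSpeed should wrap the bare `nn.Module`, not a model already inside DDP/fp16/LoRA containers, so the freshly unwrapped handle is what gets passed in. A minimal, self-contained sketch of the call shape (toy module and config; assumes a DeepSpeed launcher environment):

```python
import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # stand-in for unwrapped_model[0]
engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config={"train_micro_batch_size_per_gpu": 1},  # minimal illustrative config
)
```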
@@ -2,7 +2,6 @@
 <p align="left">
         <b>Simplified Chinese</b> |
         <b><a href="https://gitee.com/ascend/AscendSpeed/blob/master/examples/bloom/README_en.md">English</a> </b>
 </p>
-</p>


@@ -15,7 +14,9 @@
     - [Throughput](#吞吐)
     - [Accuracy](#精度)
   - [Inference](#推理)
-    - [Script](#脚本)
+    - [deepspeed_pipeline](#deepspeed_pipeline)
+    - [megatron](#megatron)
+  - [Evaluation](#评估)
 - [Bloom-176B](#Bloom-176B)
   - [Training](#训练)
     - [Script](#脚本)
@@ -23,8 +24,10 @@
     - [Throughput](#吞吐)
     - [Accuracy](#精度)
   - [Inference](#推理)
-    - [Script](#脚本)
-
+    - [deepspeed_pipeline](#deepspeed_pipeline)
+    - [megatron](#megatron)
+  - [Evaluation](#评估)
+  - [Examples](#举例)
 # Bloom-7B

 ## Training
@@ -179,6 +182,21 @@ NPU vs. reference loss relative error

 AscendSpeed supports text generation inference for BLOOM 7B.

+### deepspeed_pipeline
+
+```shell
+# Modify the model weight path and the tokenizer path
+CHECKPOINT=/home/model/bloom_7B
+VOCAB_FILE=/home/bloom_data/vocab_file/
+```
+
+```shell
+bash ./examples/bloom/generate_bloom_7b_deepspeed_pipeline.sh
+```
+
+
+### megatron
+
 Use [convert_weights_from_gptmodelpipe_to_gptmodel.sh](../../tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel.sh) to convert the Bloom-7B weights into the inference format.

 ```bash
@@ -190,13 +208,12 @@ python $SCRIPT_PATH \
     --pipeline-model-parallel-size 1 \
     --type 7B
 ```
 ### Script
-
 Configure the Bloom-7B inference script: examples/bloom/generate_bloom_7B_tp8_pp1.sh

 ```shell
 # Modify the model weight path and the tokenizer path
-CHECKPOINT=/home/bloom_data/enwiki_100k/enwiki-100k_text_document
+CHECKPOINT=/home/model/bloom_7B
 VOCAB_FILE=/home/bloom_data/vocab_file/
 ```

@@ -204,6 +221,59 @@ VOCAB_FILE=/home/bloom_data/vocab_file/
 bash ./examples/bloom/generate_bloom_7B_tp8_pp1.sh
 ```

+## Evaluation
+Configure the Bloom-7B evaluation script: tasks/evaluation/eval_bloom.sh
+
+```shell
+# Modify the model weight path, the tokenizer path, and the dataset/task path
+CHECKPOINT=/home/model/bloom_7B
+VOCAB_FILE=/home/bloom_data/vocab_file/
+DATA_PATH="/dataset/boolq/test"
+TASK="boolq"
+```
+
+In addition, you need to set the following parameters according to the model size:
+```shell
+--num-layers 30
+--hidden-size 4096
+--num-attention-heads 32
+```
+
+```shell
+bash ./tasks/evaluation/eval_bloom.sh
+```
+
+<table>
+  <thead>
+    <tr>
+      <th>Task</th>
+      <th>Validation set</th>
+      <th>Model</th>
+      <th>Ascend value</th>
+      <th>Community value</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td><a href="https://huggingface.co/datasets/boolq">Boolq</a></td>
+      <td>Test</td>
+      <th>bloom 7b</th>
+      <td>0.614</td>
+      <td>--</td>
+    </tr>
+  </tbody>
+  <tbody>
+    <tr>
+      <td><a href="https://huggingface.co/datasets/cais/mmlu">mmlu</a></td>
+      <td>Test</td>
+      <th>bloom 7b</th>
+      <td>0.251</td>
+      <td><a href="https://www.hayo.com/article/648ace24409528db3186ef1c">0.254</a></td>
+    </tr>
+  </tbody>
+</table>
+
+
 # Bloom-176B

 ## Training
@@ -367,6 +437,21 @@ NPU vs. reference loss
 ## Inference

 AscendSpeed supports online text generation inference for BLOOM 176B.
+We support AscendSpeed Inference for text generation with BLOOM 176B (deepspeed or megatron).
+
+### deepspeed_pipeline
+
+```shell
+# Modify the model weight path and the tokenizer path
+CHECKPOINT=/home/model/bloom_176B
+VOCAB_FILE=/home/bloom_data/vocab_file/
+```
+
+```shell
+bash ./examples/bloom/generate_bloom_176b_deepspeed_pipeline.sh
+```
+
+### megatron

 Use the [convert_weights_from_gptmodelpipe_to_gptmodel.sh](../../tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel.sh) script to convert the weights into the inference format.
 Inference runs on two nodes, so the weights have to be distributed to them manually: node 0 needs the weights for layers 1-37, node 1 the weights for layers 38-74. Then run the script as follows:
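A minimal sketch of that manual placement, with host names and the layer-file naming as assumptions rather than repo guarantees:

```shell
# Copy each pipeline stage's layer weights to the node that serves it.
SRC=/home/model/bloom_176B
scp "${SRC}"/layer_{01..37}* node0:/home/model/bloom_176B/
scp "${SRC}"/layer_{38..74}* node1:/home/model/bloom_176B/
```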
@@ -391,7 +476,7 @@ MASTER_ADDR=localhost
 NODE_RANK=0

 # Modify the dataset path and the vocab path
-CHECKPOINT=/home/bloom_data/enwiki_100k/enwiki-100k_text_document
+CHECKPOINT=/home/model/bloom_176B
 VOCAB_FILE=/home/bloom_data/vocab_file/
 ```

@@ -399,6 +484,50 @@ VOCAB_FILE=/home/bloom_data/vocab_file/
 bash ./examples/bloom/generate_bloom_176b_2nodes.sh
 ```

+
+## Evaluation
+Configure the Bloom-176B evaluation script: tasks/evaluation/eval_bloom.sh
+
+```shell
+# Modify the model weight path, the tokenizer path, and the dataset/task path
+CHECKPOINT=/home/model/bloom_176B
+VOCAB_FILE=/home/bloom_data/vocab_file/
+DATA_PATH="/dataset/boolq/test"
+TASK="boolq"
+```
+
+In addition, you need to set the following parameters according to the model size:
+
+```shell
+--num-layers 70
+--hidden-size 14336
+--num-attention-heads 112
+```
+
+```shell
+bash ./tasks/evaluation/eval_bloom.sh
+```
+<table>
+  <thead>
+    <tr>
+      <th>Task</th>
+      <th>Validation set</th>
+      <th>Model</th>
+      <th>Ascend value</th>
+      <th>Community value</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td><a href="https://huggingface.co/datasets/boolq">Boolq</a></td>
+      <td>Test</td>
+      <th>bloom 176b</th>
+      <td>0.657</td>
+      <td>--</td>
+    </tr>
+  </tbody>
+</table>
+
 ## Examples
 1. bloom 7b

@@ -409,9 +538,11 @@ bash ./examples/bloom/generate_bloom_176b_2nodes.sh

 ## Citation

+```
 @article{scao2022bloom,
   title={Bloom: A 176b-parameter open-access multilingual language model},
   author={Scao, Teven Le and Fan, Angela and Akiki, Christopher and Pavlick, Ellie and Ili{\'c}, Suzana and Hesslow, Daniel and Castagn{\'e}, Roman and Luccioni, Alexandra Sasha and Yvon, Fran{\c{c}}ois and Gall{\'e}, Matthias and others},
   journal={arXiv preprint arXiv:2211.05100},
   year={2022}
-}
+}
+```
@@ -2,9 +2,10 @@
 <p align="left">
         <b><a href="https://gitee.com/ascend/AscendSpeed/blob/master/examples/bloom/README.md">Simplified Chinese</a></b> |
         <b>English</b>
-</p>
+</p>
+


 # Contents

 - [Bloom-7B](#contents)
@@ -66,6 +67,7 @@ conda activate bloom7b
 # install torch and torch_npu and apex
 pip install torch-2.0.1-cp38-cp38-manylinux2014_aarch64.whl
 pip install torch_npu-2.0.1rc1.postxxxxxxxx-cp38-cp38-linux_aarch64.whl
+pip install apex-0.1_ascend*-cp38-cp38m-linux_aarch64.whl

 # install megatron-core
 pip3 install --no-use-pep517 -e git+https://github.com/NVIDIA/Megatron-LM.git@23.05#egg=megatron-core
@@ -83,7 +85,7 @@ pip install -r requirements.txt

 3. Prepare pretrained weights

-Download the Bloom-7B weights from [here](https://huggingface.co/bigscience/bloom-7b1/tree/main).
+Download the Bloom-7B tokenizer from [here](https://huggingface.co/bigscience/bloom-7b1/tree/main).

 ```shell
 mkdir tokenizer
@@ -91,7 +93,6 @@ cd tokenizer
 wget https://huggingface.co/bigscience/bloom/resolve/main/special_tokens_map.json
 wget https://huggingface.co/bigscience/bloom/resolve/main/tokenizer.json
 wget https://huggingface.co/bigscience/bloom/resolve/main/tokenizer_config.json
-...
 cd ..
 ```

@@ -180,9 +181,23 @@ NPU vs GPU loss relative error.

 ## Inference

-We support AscendSpeed Inference for text generation with BLOOM 7B.
+We support AscendSpeed Inference for text generation with BLOOM 7B (deepspeed or megatron).

-Use [convert_weights_from_gptmodelpipe_to_gptmodel.sh](../../tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel.sh) to convert DeepSpeed checkpoints to Megatron format.
+### deepspeed_pipeline
+
+```shell
+# modify the model weight path and tokenizer path
+CHECKPOINT=/home/model/bloom_7B
+VOCAB_FILE=/home/bloom_data/vocab_file/
+```
+
+```shell
+bash ./examples/bloom/generate_bloom_7b_deepspeed_pipeline.sh
+```
+
+### megatron
+
+Use [convert_weights_from_gptmodelpipe_to_gptmodel.sh](../../tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel.sh) to convert DeepSpeed checkpoints to Megatron format.

 ```bash
 SCRIPT_PATH=./tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel_v2.py
@@ -193,7 +208,6 @@ python $SCRIPT_PATH \
     --pipeline-model-parallel-size 1 \
     --type 7B
 ```
 ### Script
-
 We generate text samples using the `generate_bloom` script. Inference differs from pre-training in that, for example, we need to load the pretrained checkpoint and set the length of the output samples:

@@ -201,7 +215,7 @@ Config Bloom-7B inference script: examples/bloom/generate_bloom_7B_tp8_pp1.sh

 ```shell
 # modify the model weight path and tokenizer path
-CHECKPOINT=/home/bloom_data/enwiki_100k/enwiki-100k_text_document
+CHECKPOINT=/home/model/bloom_7B
 VOCAB_FILE=/home/bloom_data/vocab_file/
 ```

@@ -209,6 +223,59 @@ VOCAB_FILE=/home/bloom_data/vocab_file/
 bash ./examples/bloom/generate_bloom_7B_tp8_pp1.sh
 ```

+## Evaluation
+Config Bloom-7B evaluation script: tasks/evaluation/eval_bloom.sh
+
+```shell
+# modify the model weight path and tokenizer path
+CHECKPOINT=/home/model/bloom_7B
+VOCAB_FILE=/home/bloom_data/vocab_file/
+DATA_PATH="/dataset/boolq/test"
+TASK="boolq"
+```
+
+In addition, you need to set the corresponding parameters according to the model size; the Bloom-7B parameters are:
+```shell
+--num-layers 30
+--hidden-size 4096
+--num-attention-heads 32
+```
+
+```shell
+bash ./tasks/evaluation/eval_bloom.sh
+```
+
+<table>
+  <thead>
+    <tr>
+      <th>Task</th>
+      <th>Validation set</th>
+      <th>Model</th>
+      <th>Ascend value</th>
+      <th>Community value</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td><a href="https://huggingface.co/datasets/boolq">Boolq</a></td>
+      <td>Test</td>
+      <th>bloom 7b</th>
+      <td>0.614</td>
+      <td>--</td>
+    </tr>
+  </tbody>
+  <tbody>
+    <tr>
+      <td><a href="https://huggingface.co/datasets/cais/mmlu">mmlu</a></td>
+      <td>Test</td>
+      <th>bloom 7b</th>
+      <td>0.251</td>
+      <td><a href="https://www.hayo.com/article/648ace24409528db3186ef1c">0.254</a></td>
+    </tr>
+  </tbody>
+</table>
+
+
 # Bloom-176B

 ## Training
@@ -371,7 +438,21 @@ and GPU on a single-node system. The average relative error is 0.1%, less than 2

 ## Inference

-We support AscendSpeed Inference for text generation with BLOOM 176B.
+We support AscendSpeed Inference for text generation with BLOOM 176B (deepspeed or megatron).

+### deepspeed_pipeline
+
+```shell
+# modify the model weight path and tokenizer path
+CHECKPOINT=/home/model/bloom_176B
+VOCAB_FILE=/home/bloom_data/vocab_file/
+```
+
+```shell
+bash ./examples/bloom/generate_bloom_176b_deepspeed_pipeline.sh
+```
+
+### megatron

 Use [convert_weights_from_gptmodelpipe_to_gptmodel.sh](../../tools/ckpt_convert/bloom/convert_weights_from_gptmodelpipe_to_gptmodel.sh) to convert DeepSpeed checkpoints to Megatron format.

@@ -399,7 +480,7 @@ MASTER_ADDR=localhost
 NODE_RANK=0

 # modify the model weight path and tokenizer path
-CHECKPOINT=/home/bloom_data/enwiki_100k/enwiki-100k_text_document
+CHECKPOINT=/home/model/bloom_176B
 VOCAB_FILE=/home/bloom_data/vocab_file/
 ```

@@ -407,6 +488,48 @@ VOCAB_FILE=/home/bloom_data/vocab_file/
 bash ./examples/bloom/generate_bloom_176b_2nodes.sh
 ```

+## Evaluation
+Config Bloom-176B evaluation script: tasks/evaluation/eval_bloom.sh
+
+```shell
+# modify the model weight path and tokenizer path
+CHECKPOINT=/home/model/bloom_176B
+VOCAB_FILE=/home/bloom_data/vocab_file/
+DATA_PATH="/dataset/boolq/test"
+TASK="boolq"
+```
+
+In addition, you need to set the corresponding parameters according to the model size; the Bloom-176B parameters are:
+```shell
+--num-layers 70
+--hidden-size 14336
+--num-attention-heads 112
+```
+
+```shell
+bash ./tasks/evaluation/eval_bloom.sh
+```
+<table>
+  <thead>
+    <tr>
+      <th>Task</th>
+      <th>Validation set</th>
+      <th>Model</th>
+      <th>Ascend value</th>
+      <th>Community value</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td><a href="https://huggingface.co/datasets/boolq">Boolq</a></td>
+      <td>Test</td>
+      <th>bloom 176b</th>
+      <td>0.657</td>
+      <td>--</td>
+    </tr>
+  </tbody>
+</table>
+
 ## Example
 1. bloom 7b
 ![bloom_7b_generate.png](..%2F..%2Fsources%2Fimages%2Fbloom_7b_generate.png)
@@ -421,9 +544,14 @@ All the provided scripts are tested on 910 64GB NPUs for BLOOM 7B and BLOOM 176B

 You may also consider original work in your reference:

+```
 @article{scao2022bloom,
   title={Bloom: A 176b-parameter open-access multilingual language model},
   author={Scao, Teven Le and Fan, Angela and Akiki, Christopher and Pavlick, Ellie and Ili{\'c}, Suzana and Hesslow, Daniel and Castagn{\'e}, Roman and Luccioni, Alexandra Sasha and Yvon, Fran{\c{c}}ois and Gall{\'e}, Matthias and others},
   journal={arXiv preprint arXiv:2211.05100},
   year={2022}
-}
+}
+```
+\
+\
+<font size=1>If the download of the file fails using 'wget', you can download it manually while ensuring website security.</font>
examples/bloom/generate_bloom_176b_deepspeed_pipeline.sh (new file, 56 lines)
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+export TOKENIZERS_PARALLELISM=false
+
+NNODES=1
+NPUS_PER_NODE=8
+
+CHECKPOINT="your megatron checkpoint path"
+VOCAB_FILE="your vocab path"
+
+ZERO_STAGE=0
+MICRO_BATCH_SIZE=1
+config_json="./ds_config.json"
+
+cat <<EOT > $config_json
+{
+  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
+  "gradient_clipping": 1.0,
+  "zero_optimization": {
+    "stage": $ZERO_STAGE
+  },
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 500,
+    "hysteresis": 2,
+    "min_loss_scale": 1,
+    "initial_scale_power": 12
+  },
+  "steps_per_print": 2000,
+  "wall_clock_breakdown": false
+}
+EOT
+
+deepspeed --num_nodes $NNODES --num_gpus $NPUS_PER_NODE \
+    ./tasks/inference/inference_bloom_pipeline.py \
+    --no-contiguous-buffers-in-local-ddp \
+    --tensor-model-parallel-size 8 \
+    --pipeline-model-parallel-size 1 \
+    --num-layers 70 \
+    --hidden-size 14336 \
+    --num-attention-heads 112 \
+    --max-position-embeddings 2048 \
+    --position-embedding-type alibi \
+    --embed-layernorm \
+    --tokenizer-type PretrainedFromHF \
+    --load "${CHECKPOINT}" \
+    --tokenizer-name-or-path "$VOCAB_FILE" \
+    --tokenizer-not-use-fast \
+    --fp16 \
+    --micro-batch-size 1 \
+    --seq-length 1024 \
+    --max-new-tokens 256 \
+    --seed 42 \
+    --deepspeed \
+    --deepspeed_config ${config_json} \
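Once CHECKPOINT and VOCAB_FILE point at real paths, the script is launched exactly as the README shows:

```shell
bash ./examples/bloom/generate_bloom_176b_deepspeed_pipeline.sh
```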