!1487 Data preprocessing acceleration

Merge pull request !1487 from glhyy/master
glhyy 2024-08-16 08:13:07 +00:00 committed by i-robot
parent f6f9ba24d3
commit 03f23d1ce3
4 changed files with 66 additions and 5 deletions

View File

@@ -329,6 +329,11 @@ python ./preprocess_data.py \
--json-keys text input output \
```
【--n-subs】
Parallel-acceleration parameter for data preprocessing. When the dataset to be preprocessed is large, preprocessing can be sped up by handling it in parallel: set `--n-subs` to the desired degree of parallelism. During preprocessing the original dataset is split into `n_subs` subsets, the subsets are processed in parallel, and the results are then merged, which yields the speed-up. Adding this parameter is recommended once the dataset to be preprocessed exceeds the GB scale.
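For example, a minimal sketch of such a launch (the input path and output prefix are placeholders, and the tokenizer/handler flags omitted here follow the existing `preprocess_data.py` examples in this document):
```shell
# Sketch only: split the raw dataset into 8 subsets and preprocess them in parallel.
# <path-to-raw-dataset> and <output-prefix> are placeholders; add the usual
# tokenizer/handler flags shown in the examples above.
python ./preprocess_data.py \
    --input <path-to-raw-dataset> \
    --output-prefix <output-prefix> \
    --workers 4 \
    --n-subs 8
```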
The naming convention and launch method of the ModelLink pre-training dataset processing scripts are as follows:
```shell
# Legacy

View File

@@ -427,6 +427,8 @@ def _add_distributed_args(parser):
    group.add_argument('--local-rank', type=int, default=None,
                       help='Local rank passed from distributed launcher for torch2.x.')
    group.add_argument('--distributed-timeout-minutes', type=int, default=45,
                       help='Timeout minutes for torch.distributed.')
    return parser

View File

@@ -80,6 +80,7 @@ class BaseDatasetHandler(object):
            output_idx_files[key] = f"{self.args.output_prefix}_{key}_{level}.idx"
            # vocab_size=None: use int32 dtype because -100 will be used in labels
            builders[key] = indexed_dataset.IndexedDatasetBuilder(output_bin_files[key])
        # keep the .idx paths so that parallel subset runs can report them back for merging
        self.output_idx_files = output_idx_files
        startup_end = time.time()
        proc_start = time.time()
        total_bytes_processed = 0

View File

@@ -17,8 +17,10 @@
import argparse
import json
import multiprocessing
import os
import sys
import copy
import logging
from typing import List
@@ -33,6 +35,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
from modellink.tokenizer import build_tokenizer
from modellink.tasks.preprocess.data_handler import build_dataset, get_dataset_handler
from megatron.core.datasets.indexed_dataset import IndexedDatasetBuilder
logging.basicConfig(level=logging.INFO)
@@ -163,6 +166,8 @@
    group = parser.add_argument_group(title='runtime')
    group.add_argument('--workers', type=int, default=1,
                       help='Number of worker processes to launch')
    group.add_argument('--n-subs', type=int, default=1,
                       help='Number of subsets to cut for multiprocessing')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Interval between progress updates')
@@ -195,6 +200,22 @@ def validate_args(args):
        raise AssertionError('If prompt_type is specified, the handler name must be "LlamaFactoryInstructionHandler", "AlpacaStyleInstructionHandler" or "SharegptStyleInstructionHandler".')


def cut_range_to_subs(n, gap):
    """Split the range [0, n) into consecutive chunks of size `gap`,
    keeping any remainder as a final shorter chunk, e.g.
    cut_range_to_subs(10, 3) -> [(0, 3), (3, 6), (6, 9), (9, 10)]."""
    n_ = n // gap
    mod = n % gap
    if mod != 0:
        return [(k * gap, (k + 1) * gap) for k in range(0, n_)] + [(gap * n_, n)]
    else:
        return [(k * gap, (k + 1) * gap) for k in range(0, n_)]


def handle_subset(params):
    """params: [args, dataset, tokenizer, splitter]"""
    handler = get_dataset_handler(params[0], params[1], params[2], params[3])
    handler.serialize_to_disk()
    return handler.output_idx_files


def main():
    args = get_args()
    validate_args(args)
@@ -205,11 +226,43 @@ def main():
logger.info("building dataset: %s", args.input)
raw_data = build_dataset(args)
handler = get_dataset_handler(args, raw_data, tokenizer, splitter)
# serialize to bin&idx
handler.serialize_to_disk()
if args.n_subs == 1:
handler = get_dataset_handler(args, raw_data, tokenizer, splitter)
# serialize to bin&idx
handler.serialize_to_disk()
else:
target_prefix = args.output_prefix
target_prefixname = os.path.basename(target_prefix)
num_samples = len(raw_data)
start_ends = cut_range_to_subs(num_samples, num_samples // args.n_subs)
subsets = [raw_data.select(range(x[0], x[1])) for x in start_ends]
# multiprocessing
params_list = []
for k, subset in enumerate(subsets):
args_ = copy.deepcopy(args)
args_.output_prefix = target_prefix.replace(target_prefixname, f'{str(k).zfill(3)}_of_{str(len(subsets)-1).zfill(3)}_{target_prefixname}')
params = [args_, subset, tokenizer, splitter]
params_list.append(params)
pool = multiprocessing.Pool()
sub_idx_files = pool.map(handle_subset, params_list)
pool.close()
pool.join()
for key in sub_idx_files[0].keys():
idx_files = [x[key] for x in sub_idx_files]
idx_files.sort()
target_idx = idx_files[0].replace(f'000_of_{str(len(subsets)-1).zfill(3)}_{target_prefixname}', target_prefixname)
target_bin = target_idx.replace('.idx', '.bin')
idx = IndexedDatasetBuilder(target_bin)
for idx_file in idx_files:
idx.add_index(idx_file.replace('.idx', ''))
idx.finalize(target_idx)
for idx_file in idx_files:
os.remove(idx_file)
os.remove(idx_file.replace('.idx', '.bin'))
if __name__ == '__main__':
main()
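As a rough usage note on the merge step above: a parallel run first writes one bin/idx pair per subset under prefixes such as `000_of_003_<prefix>`, merges them back into the plain `<prefix>` pair, and deletes the intermediates. A hypothetical post-run check, with `./dataset/alpaca` standing in for the real `--output-prefix`:
```shell
# Hypothetical check after a run with --n-subs and --output-prefix ./dataset/alpaca:
# only the merged pair should remain; the NNN_of_NNN_* subset files are removed.
ls -lh ./dataset/alpaca_*.bin ./dataset/alpaca_*.idx
ls ./dataset/ | grep '_of_' || echo "no leftover subset files"
```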