diff --git a/examples/README.md b/examples/README.md
index 3c11435ad..618519dcf 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -329,6 +329,11 @@ python ./preprocess_data.py \
     --json-keys text input output \
 ```
 
+【--n-subs】
+
+Parallel-acceleration parameter for data preprocessing. When the dataset to be preprocessed is large, preprocessing can be sped up by running it in parallel: set the `--n-subs` parameter to the desired degree of parallelism. During preprocessing, the raw dataset is split into `n_subs` subsets, the subsets are processed in parallel, and the results are then merged, which yields the speedup. Adding this parameter is recommended when the dataset to be preprocessed exceeds the GB scale.
+
+
 The naming convention and launch method for ModelLink pre-training dataset processing scripts are as follows:
 ```shell
 # Legacy
diff --git a/modellink/arguments.py b/modellink/arguments.py
index 4a18a0644..7d55361d1 100644
--- a/modellink/arguments.py
+++ b/modellink/arguments.py
@@ -427,6 +427,8 @@ def _add_distributed_args(parser):
     group.add_argument('--local-rank', type=int, default=None,
                        help='Local rank passed from distributed launcher for torch2.x.')
+    group.add_argument('--distributed-timeout-minutes', type=int, default=45,
+                       help='Timeout minutes for torch.distributed.')
 
     return parser
diff --git a/modellink/tasks/preprocess/data_handler.py b/modellink/tasks/preprocess/data_handler.py
index 0ada8864e..70f705e34 100644
--- a/modellink/tasks/preprocess/data_handler.py
+++ b/modellink/tasks/preprocess/data_handler.py
@@ -80,6 +80,7 @@ class BaseDatasetHandler(object):
             output_idx_files[key] = f"{self.args.output_prefix}_{key}_{level}.idx"
             # vocab_size=None : use int32 dtype for -100 will be used in labels
             builders[key] = indexed_dataset.IndexedDatasetBuilder(output_bin_files[key])
+        self.output_idx_files = output_idx_files
         startup_end = time.time()
         proc_start = time.time()
         total_bytes_processed = 0
diff --git a/preprocess_data.py b/preprocess_data.py
index 0835b6416..4260d835b 100644
--- a/preprocess_data.py
+++ b/preprocess_data.py
@@ -17,8 +17,10 @@
 
 import argparse
 import json
+import multiprocessing
 import os
 import sys
+import copy
 import logging
 from typing import List
@@ -33,6 +35,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
 from modellink.tokenizer import build_tokenizer
 from modellink.tasks.preprocess.data_handler import build_dataset, get_dataset_handler
+from megatron.core.datasets.indexed_dataset import IndexedDatasetBuilder
 
 logging.basicConfig(level=logging.INFO)
@@ -163,6 +166,8 @@ def add_output_args(parser):
     group = parser.add_argument_group(title='runtime')
     group.add_argument('--workers', type=int, default=1,
                        help='Number of worker processes to launch')
+    group.add_argument('--n-subs', type=int, default=1,
+                       help='Number of subsets to cut for multiprocessing')
     group.add_argument('--log-interval', type=int, default=100,
                        help='Interval between progress updates')
@@ -195,6 +200,22 @@ def validate_args(args):
         raise AssertionError('If specify prompt_type , handler name must be "LlamaFactoryInstructionHandler"、"AlpacaStyleInstructionHandler"、"SharegptStyleInstructionHandler".')
 
 
+def cut_range_to_subs(n, gap):
+    n_ = n // gap
+    mod = n % gap
+    if mod != 0:
+        return [(k * gap, (k + 1) * gap) for k in range(0, n_)] + [(gap * n_, n)]
+    else:
+        return [(k * gap, (k + 1) * gap) for k in range(0, n_)]
+
+
+def handle_subset(params):
+    """params: [args, dataset, tokenizer, splitter]"""
+    handler = get_dataset_handler(params[0], params[1], params[2], params[3])
+    handler.serialize_to_disk()
+    return handler.output_idx_files
+
+
 def main():
     args = get_args()
     validate_args(args)
@@ -205,11 +226,43 @@ def main():
     logger.info("building dataset: %s", args.input)
     raw_data = build_dataset(args)
 
-    handler = get_dataset_handler(args, raw_data, tokenizer, splitter)
-
-    # serialize to bin&idx
-    handler.serialize_to_disk()
-
+    if args.n_subs == 1:
+        handler = get_dataset_handler(args, raw_data, tokenizer, splitter)
+        # serialize to bin&idx
+        handler.serialize_to_disk()
+    else:
+        target_prefix = args.output_prefix
+        target_prefixname = os.path.basename(target_prefix)
+
+        num_samples = len(raw_data)
+        start_ends = cut_range_to_subs(num_samples, num_samples // args.n_subs)
+        subsets = [raw_data.select(range(x[0], x[1])) for x in start_ends]
+
+        # multiprocessing
+        params_list = []
+        for k, subset in enumerate(subsets):
+            args_ = copy.deepcopy(args)
+            args_.output_prefix = target_prefix.replace(target_prefixname, f'{str(k).zfill(3)}_of_{str(len(subsets)-1).zfill(3)}_{target_prefixname}')
+            params = [args_, subset, tokenizer, splitter]
+            params_list.append(params)
+        pool = multiprocessing.Pool()
+        sub_idx_files = pool.map(handle_subset, params_list)
+        pool.close()
+        pool.join()
+
+        for key in sub_idx_files[0].keys():
+            idx_files = [x[key] for x in sub_idx_files]
+            idx_files.sort()
+            target_idx = idx_files[0].replace(f'000_of_{str(len(subsets)-1).zfill(3)}_{target_prefixname}', target_prefixname)
+            target_bin = target_idx.replace('.idx', '.bin')
+            idx = IndexedDatasetBuilder(target_bin)
+            for idx_file in idx_files:
+                idx.add_index(idx_file.replace('.idx', ''))
+            idx.finalize(target_idx)
+
+            for idx_file in idx_files:
+                os.remove(idx_file)
+                os.remove(idx_file.replace('.idx', '.bin'))
 
 if __name__ == '__main__':
     main()
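
For readers following the patch, below is a minimal, self-contained sketch (not part of the patch) that reproduces `cut_range_to_subs` and the per-subset output-prefix naming used in `preprocess_data.py` above. The sample count, `--n-subs` value, and output prefix are made-up illustration values, not taken from the source.

```python
import os


def cut_range_to_subs(n, gap):
    # Split [0, n) into consecutive half-open ranges of width `gap`;
    # a non-zero remainder becomes one extra, shorter trailing range.
    n_ = n // gap
    mod = n % gap
    if mod != 0:
        return [(k * gap, (k + 1) * gap) for k in range(0, n_)] + [(gap * n_, n)]
    return [(k * gap, (k + 1) * gap) for k in range(0, n_)]


if __name__ == '__main__':
    # Hypothetical values for illustration only.
    num_samples = 10_000
    n_subs = 3
    target_prefix = './dataset/alpaca'   # assumed --output-prefix
    target_prefixname = os.path.basename(target_prefix)

    start_ends = cut_range_to_subs(num_samples, num_samples // n_subs)
    for k, (start, end) in enumerate(start_ends):
        # Same naming scheme as the patch: <k>_of_<last>_<prefixname>.
        sub_prefix = target_prefix.replace(
            target_prefixname,
            f'{str(k).zfill(3)}_of_{str(len(start_ends) - 1).zfill(3)}_{target_prefixname}')
        print(f'subset {k}: samples [{start}, {end}) -> {sub_prefix}')
```

Because the gap is computed as `num_samples // n_subs`, a non-zero remainder yields one extra short subset (the example above produces four ranges for `n_subs = 3`); the merge loop in `main()` simply consumes however many subset `.idx`/`.bin` pairs were produced before deleting them.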