diff --git a/modellink/core/datasets/gpt_dataset.py b/modellink/core/datasets/gpt_dataset.py index fe12a0a8a..4a37d53b2 100644 --- a/modellink/core/datasets/gpt_dataset.py +++ b/modellink/core/datasets/gpt_dataset.py @@ -12,6 +12,7 @@ from megatron.core.datasets.utils import log_single_rank from megatron.core.datasets.gpt_dataset import (_build_document_index, _build_shuffle_index ) +from modellink.error_utils import GPTDatasetSampleIndexError from .blended_megatron_dataset_builder import need_to_build_dataset logger = logging.getLogger(__name__) @@ -153,6 +154,11 @@ def _build_document_sample_shuffle_indices( num_epochs, num_tokens_per_epoch, ) + + if any(sample_index[:, 0] < 0): + _url = "https://gitee.com/ascend/ModelLink/wikis/megatron%20data%20helpers%E5%8F%AF%E8%83%BD%E5%BC%95%E5%85%A5%E7%9A%84%E9%97%AE%E9%A2%98" + raise GPTDatasetSampleIndexError(f"Bad sample index. Visit {_url} for more information") + numpy.save(path_to_sample_index, sample_index, allow_pickle=True) t_end = time.time() log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") @@ -204,6 +210,11 @@ def _build_document_sample_shuffle_indices( ) t_beg = time.time() sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode='r') + + if any(sample_index[:, 0] < 0): + _url = "https://gitee.com/ascend/ModelLink/wikis/megatron%20data%20helpers%E5%8F%AF%E8%83%BD%E5%BC%95%E5%85%A5%E7%9A%84%E9%97%AE%E9%A2%98" + raise GPTDatasetSampleIndexError(f"Bad sample index. Visit {_url} for more information") + t_end = time.time() log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") diff --git a/modellink/error_utils.py b/modellink/error_utils.py index bdf912e0c..88dc32191 100644 --- a/modellink/error_utils.py +++ b/modellink/error_utils.py @@ -159,3 +159,15 @@ class IsNotValidError(Exception): def ensure_valid(expression, error_message=None): if not expression: raise IsNotValidError(error_message) + + +class GPTDatasetSampleIndexError(Exception): + def __init__(self, error_message): + super().__init__() + self._error_message = error_message + + def __repr__(self): + if self._error_message: + return self._error_message + else: + return "Bad sample index."