!1359 检测GPTDataset sample_index异常时主动抛出错误

Merge pull request !1359 from glhyy/master
This commit is contained in:
glhyy 2024-06-20 09:15:08 +00:00 committed by i-robot
parent baf8f2237f
commit d2e3bfaf84
2 changed files with 23 additions and 0 deletions

View File

@ -12,6 +12,7 @@ from megatron.core.datasets.utils import log_single_rank
from megatron.core.datasets.gpt_dataset import (_build_document_index,
_build_shuffle_index
)
from modellink.error_utils import GPTDatasetSampleIndexError
from .blended_megatron_dataset_builder import need_to_build_dataset
logger = logging.getLogger(__name__)
@ -153,6 +154,11 @@ def _build_document_sample_shuffle_indices(
num_epochs,
num_tokens_per_epoch,
)
if any(sample_index[:, 0] < 0):
_url = "https://gitee.com/ascend/ModelLink/wikis/megatron%20data%20helpers%E5%8F%AF%E8%83%BD%E5%BC%95%E5%85%A5%E7%9A%84%E9%97%AE%E9%A2%98"
raise GPTDatasetSampleIndexError(f"Bad sample index. Visit {_url} for more information")
numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
@ -204,6 +210,11 @@ def _build_document_sample_shuffle_indices(
)
t_beg = time.time()
sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode='r')
if any(sample_index[:, 0] < 0):
_url = "https://gitee.com/ascend/ModelLink/wikis/megatron%20data%20helpers%E5%8F%AF%E8%83%BD%E5%BC%95%E5%85%A5%E7%9A%84%E9%97%AE%E9%A2%98"
raise GPTDatasetSampleIndexError(f"Bad sample index. Visit {_url} for more information")
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

View File

@ -159,3 +159,15 @@ class IsNotValidError(Exception):
def ensure_valid(expression, error_message=None):
if not expression:
raise IsNotValidError(error_message)
class GPTDatasetSampleIndexError(Exception):
def __init__(self, error_message):
super().__init__()
self._error_message = error_message
def __repr__(self):
if self._error_message:
return self._error_message
else:
return "Bad sample index."