mirror of
https://gitee.com/ascend/ModelLink.git
synced 2024-12-04 21:07:51 +08:00
!1359 检测GPTDataset sample_index异常时主动抛出错误
Merge pull request !1359 from glhyy/master
This commit is contained in:
parent
baf8f2237f
commit
d2e3bfaf84
@ -12,6 +12,7 @@ from megatron.core.datasets.utils import log_single_rank
|
||||
from megatron.core.datasets.gpt_dataset import (_build_document_index,
|
||||
_build_shuffle_index
|
||||
)
|
||||
from modellink.error_utils import GPTDatasetSampleIndexError
|
||||
from .blended_megatron_dataset_builder import need_to_build_dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -153,6 +154,11 @@ def _build_document_sample_shuffle_indices(
|
||||
num_epochs,
|
||||
num_tokens_per_epoch,
|
||||
)
|
||||
|
||||
if any(sample_index[:, 0] < 0):
|
||||
_url = "https://gitee.com/ascend/ModelLink/wikis/megatron%20data%20helpers%E5%8F%AF%E8%83%BD%E5%BC%95%E5%85%A5%E7%9A%84%E9%97%AE%E9%A2%98"
|
||||
raise GPTDatasetSampleIndexError(f"Bad sample index. Visit {_url} for more information")
|
||||
|
||||
numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
|
||||
t_end = time.time()
|
||||
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
|
||||
@ -204,6 +210,11 @@ def _build_document_sample_shuffle_indices(
|
||||
)
|
||||
t_beg = time.time()
|
||||
sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode='r')
|
||||
|
||||
if any(sample_index[:, 0] < 0):
|
||||
_url = "https://gitee.com/ascend/ModelLink/wikis/megatron%20data%20helpers%E5%8F%AF%E8%83%BD%E5%BC%95%E5%85%A5%E7%9A%84%E9%97%AE%E9%A2%98"
|
||||
raise GPTDatasetSampleIndexError(f"Bad sample index. Visit {_url} for more information")
|
||||
|
||||
t_end = time.time()
|
||||
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
|
||||
|
||||
|
@ -159,3 +159,15 @@ class IsNotValidError(Exception):
|
||||
def ensure_valid(expression, error_message=None):
|
||||
if not expression:
|
||||
raise IsNotValidError(error_message)
|
||||
|
||||
|
||||
class GPTDatasetSampleIndexError(Exception):
|
||||
def __init__(self, error_message):
|
||||
super().__init__()
|
||||
self._error_message = error_message
|
||||
|
||||
def __repr__(self):
|
||||
if self._error_message:
|
||||
return self._error_message
|
||||
else:
|
||||
return "Bad sample index."
|
||||
|
Loading…
Reference in New Issue
Block a user