mirror of
https://gitee.com/ascend/ModelLink.git
synced 2024-12-05 21:37:43 +08:00
!1359 检测GPTDataset sample_index异常时主动抛出错误
Merge pull request !1359 from glhyy/master
This commit is contained in:
parent
baf8f2237f
commit
d2e3bfaf84
@ -12,6 +12,7 @@ from megatron.core.datasets.utils import log_single_rank
|
|||||||
from megatron.core.datasets.gpt_dataset import (_build_document_index,
|
from megatron.core.datasets.gpt_dataset import (_build_document_index,
|
||||||
_build_shuffle_index
|
_build_shuffle_index
|
||||||
)
|
)
|
||||||
|
from modellink.error_utils import GPTDatasetSampleIndexError
|
||||||
from .blended_megatron_dataset_builder import need_to_build_dataset
|
from .blended_megatron_dataset_builder import need_to_build_dataset
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -153,6 +154,11 @@ def _build_document_sample_shuffle_indices(
|
|||||||
num_epochs,
|
num_epochs,
|
||||||
num_tokens_per_epoch,
|
num_tokens_per_epoch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if any(sample_index[:, 0] < 0):
|
||||||
|
_url = "https://gitee.com/ascend/ModelLink/wikis/megatron%20data%20helpers%E5%8F%AF%E8%83%BD%E5%BC%95%E5%85%A5%E7%9A%84%E9%97%AE%E9%A2%98"
|
||||||
|
raise GPTDatasetSampleIndexError(f"Bad sample index. Visit {_url} for more information")
|
||||||
|
|
||||||
numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
|
numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
|
||||||
t_end = time.time()
|
t_end = time.time()
|
||||||
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
|
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
|
||||||
@ -204,6 +210,11 @@ def _build_document_sample_shuffle_indices(
|
|||||||
)
|
)
|
||||||
t_beg = time.time()
|
t_beg = time.time()
|
||||||
sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode='r')
|
sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode='r')
|
||||||
|
|
||||||
|
if any(sample_index[:, 0] < 0):
|
||||||
|
_url = "https://gitee.com/ascend/ModelLink/wikis/megatron%20data%20helpers%E5%8F%AF%E8%83%BD%E5%BC%95%E5%85%A5%E7%9A%84%E9%97%AE%E9%A2%98"
|
||||||
|
raise GPTDatasetSampleIndexError(f"Bad sample index. Visit {_url} for more information")
|
||||||
|
|
||||||
t_end = time.time()
|
t_end = time.time()
|
||||||
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
|
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
|
||||||
|
|
||||||
|
@ -159,3 +159,15 @@ class IsNotValidError(Exception):
|
|||||||
def ensure_valid(expression, error_message=None):
|
def ensure_valid(expression, error_message=None):
|
||||||
if not expression:
|
if not expression:
|
||||||
raise IsNotValidError(error_message)
|
raise IsNotValidError(error_message)
|
||||||
|
|
||||||
|
|
||||||
|
class GPTDatasetSampleIndexError(Exception):
|
||||||
|
def __init__(self, error_message):
|
||||||
|
super().__init__()
|
||||||
|
self._error_message = error_message
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
if self._error_message:
|
||||||
|
return self._error_message
|
||||||
|
else:
|
||||||
|
return "Bad sample index."
|
||||||
|
Loading…
Reference in New Issue
Block a user