mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 19:57:34 +08:00
1. fix bugs of DataSet.from_dataset 2. fix bugs of some tests 3. add to OneflowDriver.load_model
This commit is contained in:
parent
96e8c6bda2
commit
749415970e
@ -1048,7 +1048,7 @@ class DataSet:
|
||||
:param dataset 为实例化好的 huggingface Dataset 对象
|
||||
"""
|
||||
from datasets import Dataset
|
||||
if not isinstance(dataset, DataSet):
|
||||
if not isinstance(dataset, Dataset):
|
||||
raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!")
|
||||
|
||||
data_dict = dataset.to_dict()
|
||||
|
@ -192,7 +192,8 @@ class OneflowDriver(Driver):
|
||||
f"`only_state_dict=False`")
|
||||
if not isinstance(res, dict):
|
||||
res = res.state_dict()
|
||||
model.load_state_dict(res)
|
||||
_strict = kwargs.get("strict")
|
||||
model.load_state_dict(res, _strict)
|
||||
|
||||
@rank_zero_call
|
||||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
|
||||
|
@ -1,5 +1,5 @@
|
||||
|
||||
from fastNLP.core.collators.packer_unpacker import *
|
||||
from fastNLP.core.collators.packer_unpacker import MappingPackerUnpacker, NestedMappingPackerUnpacker, SequencePackerUnpacker
|
||||
|
||||
|
||||
def test_unpack_batch_mapping():
|
||||
|
@ -97,6 +97,7 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
|
||||
# if dist.is_initialized():
|
||||
# dist.destroy_process_group()
|
||||
|
||||
@pytest.mark.deepspeed
|
||||
@magic_argv_env_context
|
||||
def test_multi_optimizers():
|
||||
torch_model = TorchNormalModel_Classification_1(10, 10)
|
||||
|
Loading…
Reference in New Issue
Block a user