1. fix bugs of DataSet.from_dataset 2. fix bugs of some tests 3. add to OneflowDriver.load_model

This commit is contained in:
x54-729 2022-09-13 14:41:22 +08:00
parent 96e8c6bda2
commit 749415970e
4 changed files with 5 additions and 3 deletions

View File

@ -1048,7 +1048,7 @@ class DataSet:
:param dataset 为实例化好的 huggingface Dataset 对象
"""
from datasets import Dataset
if not isinstance(dataset, DataSet):
if not isinstance(dataset, Dataset):
raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!")
data_dict = dataset.to_dict()

View File

@ -192,7 +192,8 @@ class OneflowDriver(Driver):
f"`only_state_dict=False`")
if not isinstance(res, dict):
res = res.state_dict()
model.load_state_dict(res)
_strict = kwargs.get("strict")
model.load_state_dict(res, _strict)
@rank_zero_call
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):

View File

@ -1,5 +1,5 @@
from fastNLP.core.collators.packer_unpacker import *
from fastNLP.core.collators.packer_unpacker import MappingPackerUnpacker, NestedMappingPackerUnpacker, SequencePackerUnpacker
def test_unpack_batch_mapping():

View File

@ -97,6 +97,7 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
# if dist.is_initialized():
# dist.destroy_process_group()
@pytest.mark.deepspeed
@magic_argv_env_context
def test_multi_optimizers():
torch_model = TorchNormalModel_Classification_1(10, 10)