fix a serial bugs on importing

This commit is contained in:
ChenXin 2019-08-14 20:18:54 +08:00
parent afa73bf5c8
commit a8a21b169a
12 changed files with 30 additions and 20 deletions

View File

@ -1,3 +1,6 @@
"""
正在开发中的分布式训练代码
"""
import torch import torch
import torch.cuda import torch.cuda
import torch.optim import torch.optim
@ -41,7 +44,8 @@ def get_local_rank():
class DistTrainer(): class DistTrainer():
"""Distributed Trainer that support distributed and mixed precision training """
Distributed Trainer that support distributed and mixed precision training
""" """
def __init__(self, train_data, model, optimizer=None, loss=None, def __init__(self, train_data, model, optimizer=None, loss=None,
callbacks_all=None, callbacks_master=None, callbacks_all=None, callbacks_master=None,

View File

@ -1,4 +1,3 @@
from abc import abstractmethod from abc import abstractmethod
import torch import torch
@ -9,6 +8,10 @@ from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device from ..core.utils import _move_model_to_device, _get_model_device
from .embedding import TokenEmbedding from .embedding import TokenEmbedding
__all__ = [
"ContextualEmbedding"
]
class ContextualEmbedding(TokenEmbedding): class ContextualEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0): def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0):

View File

@ -1,7 +1,9 @@
""" """
用于读入和处理和保存 config 文件 用于读入和处理和保存 config 文件
.. todo::
.. todo::
这个模块中的类可能被抛弃 这个模块中的类可能被抛弃
""" """
__all__ = [ __all__ = [
"ConfigLoader", "ConfigLoader",

View File

@ -1,12 +1,12 @@
from typing import Dict, Union from typing import Dict, Union
from .loader import Loader from .loader import Loader
from ... import DataSet from ...core.dataset import DataSet
from ..file_reader import _read_conll from ..file_reader import _read_conll
from ... import Instance from ...core.instance import Instance
from .. import DataBundle from .. import DataBundle
from ..utils import check_loader_paths from ..utils import check_loader_paths
from ... import Const from ...core.const import Const
class ConllLoader(Loader): class ConllLoader(Loader):

View File

@ -1,6 +1,6 @@
from .loader import Loader from .loader import Loader
from ...core import DataSet, Instance from ...core.dataset import DataSet
from ...core.instance import Instance
class CWSLoader(Loader): class CWSLoader(Loader):

View File

@ -1,4 +1,4 @@
from ... import DataSet from ...core.dataset import DataSet
from .. import DataBundle from .. import DataBundle
from ..utils import check_loader_paths from ..utils import check_loader_paths
from typing import Union, Dict from typing import Union, Dict

View File

@ -1,12 +1,12 @@
import warnings import warnings
from .loader import Loader from .loader import Loader
from .json import JsonLoader from .json import JsonLoader
from ...core import Const from ...core.const import Const
from .. import DataBundle from .. import DataBundle
import os import os
from typing import Union, Dict from typing import Union, Dict
from ...core import DataSet from ...core.dataset import DataSet
from ...core import Instance from ...core.instance import Instance
class MNLILoader(Loader): class MNLILoader(Loader):

View File

@ -4,13 +4,14 @@ from ..base_loader import DataBundle
from ...core.vocabulary import Vocabulary from ...core.vocabulary import Vocabulary
from ...core.const import Const from ...core.const import Const
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
from ...core import DataSet, Instance from ...core.dataset import DataSet
from ...core.instance import Instance
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from .pipe import Pipe from .pipe import Pipe
import re import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
from ...core import cache_results from ...core.utils import cache_results
class _CLSPipe(Pipe): class _CLSPipe(Pipe):
""" """

View File

@ -1,7 +1,7 @@
from .pipe import Pipe from .pipe import Pipe
from .. import DataBundle from .. import DataBundle
from .utils import iob2, iob2bioes from .utils import iob2, iob2bioes
from ... import Const from ...core.const import Const
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
from .utils import _indexize, _add_words_field from .utils import _indexize, _add_words_field

View File

@ -2,8 +2,8 @@ import math
from .pipe import Pipe from .pipe import Pipe
from .utils import get_tokenizer from .utils import get_tokenizer
from ...core import Const from ...core.const import Const
from ...core import Vocabulary from ...core.vocabulary import Vocabulary
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader

View File

@ -1,6 +1,6 @@
from typing import List from typing import List
from ...core import Vocabulary from ...core.vocabulary import Vocabulary
from ...core import Const from ...core.const import Const
def iob2(tags:List[str])->List[str]: def iob2(tags:List[str])->List[str]:
""" """

View File

@ -51,7 +51,7 @@ class ChineseNERLoader(DataSetLoader):
:param paths: :param paths:
:param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d<eos>] :param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d<eos>]
:param bool, trigrams: 是否包含trigram feature[a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>] :param bool, trigrams: 是否包含trigram feature[a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>]
:return: DataBundle :return: ~fastNLP.io.DataBundle
包含以下的fields 包含以下的fields
raw_chars: List[str] raw_chars: List[str]
chars: List[int] chars: List[int]