add __all__ and __doc__ for all files in module 'core', using 'undocumented' tags

ChenXin 2019-08-26 10:21:10 +08:00
parent 9535ec60b6
commit efe88263bb
9 changed files with 180 additions and 83 deletions


@@ -10,8 +10,72 @@ The core module implements the core framework of fastNLP; commonly used functionality can all be accessed from fa
For commonly used functionality you only need to look in :doc:`fastNLP`; if you want to understand what each submodule does, you can find the detailed documentation of each submodule below.
"""
__all__ = [
"DataSet",
"Instance",
"FieldArray",
"Padder",
"AutoPadder",
"EngChar2DPadder",
"Vocabulary",
"DataSetIter",
"BatchIter",
"TorchLoaderIter",
"Const",
"Tester",
"Trainer",
"cache_results",
"seq_len_to_mask",
"get_seq_len",
"logger",
"Callback",
"GradientClipCallback",
"EarlyStopCallback",
"FitlogCallback",
"EvaluateCallback",
"LRScheduler",
"ControlC",
"LRFinder",
"TensorboardCallback",
"WarmupCallback",
'SaveModelCallback',
"EchoCallback",
"TesterCallback",
"CallbackException",
"EarlyStopError",
"LossFunc",
"CrossEntropyLoss",
"L1Loss",
"BCELoss",
"NLLLoss",
"LossInForward",
"AccuracyMetric",
"SpanFPreRecMetric",
"ExtractiveQAMetric",
"Optimizer",
"SGD",
"Adam",
"AdamW",
"SequentialSampler",
"BucketSampler",
"RandomSampler",
"Sampler",
]
from ._logger import logger
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
@@ -28,4 +92,3 @@ from .tester import Tester
from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
from ._logger import logger
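
For illustration, a minimal usage sketch (assuming fastNLP is installed) of the public interface declared by the __all__ list above; the field name used here is only a placeholder:

# minimal sketch; the names below are re-exported by fastNLP.core as listed in __all__
from fastNLP.core import DataSet, Instance, Vocabulary

ds = DataSet()
ds.append(Instance(raw_words=["hello", "world"]))   # 'raw_words' is an illustrative field name
vocab = Vocabulary()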


@@ -1,15 +1,15 @@
import logging
import logging.config
import torch
import _pickle as pickle
import os
import sys
import warnings
"""undocumented"""
__all__ = [
'logger',
]
import logging
import logging.config
import os
import sys
import warnings
ROOT_NAME = 'fastNLP'
try:
@@ -25,7 +25,7 @@ if tqdm is not None:
class TqdmLoggingHandler(logging.Handler):
def __init__(self, level=logging.INFO):
super().__init__(level)
def emit(self, record):
try:
msg = self.format(record)
@@ -59,14 +59,14 @@ def _add_file_handler(logger, path, level='INFO'):
if os.path.abspath(path) == h.baseFilename:
# file path already added
return
# File Handler
if os.path.exists(path):
assert os.path.isfile(path)
warnings.warn('log already exists in {}'.format(path))
dirname = os.path.abspath(os.path.dirname(path))
os.makedirs(dirname, exist_ok=True)
file_handler = logging.FileHandler(path, mode='a')
file_handler.setLevel(_get_level(level))
file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
@@ -87,7 +87,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
break
if stream_handler is not None:
logger.removeHandler(stream_handler)
# Stream Handler
if stdout == 'plain':
stream_handler = logging.StreamHandler(sys.stdout)
@@ -95,7 +95,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
stream_handler = TqdmLoggingHandler(level)
else:
stream_handler = None
if stream_handler is not None:
stream_formatter = logging.Formatter('%(message)s')
stream_handler.setLevel(level)
@@ -103,38 +103,40 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
logger.addHandler(stream_handler)
class FastNLPLogger(logging.getLoggerClass()):
def __init__(self, name):
super().__init__(name)
def add_file(self, path='./log.txt', level='INFO'):
"""add log output file and level"""
_add_file_handler(self, path, level)
def set_stdout(self, stdout='tqdm', level='INFO'):
"""set stdout format and level"""
_set_stdout_handler(self, stdout, level)
logging.setLoggerClass(FastNLPLogger)
# print(logging.getLoggerClass())
# print(logging.getLogger())
def _init_logger(path=None, stdout='tqdm', level='INFO'):
"""initialize logger"""
level = _get_level(level)
# logger = logging.getLogger()
logger = logging.getLogger(ROOT_NAME)
logger.propagate = False
logger.setLevel(level)
_set_stdout_handler(logger, stdout, level)
# File Handler
if path is not None:
_add_file_handler(logger, path, level)
return logger
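
For context, a short usage sketch of the logger configured above (add_file and set_stdout are the FastNLPLogger methods defined in this file):

# usage sketch; 'logger' is the module-level instance exported by fastNLP.core
from fastNLP.core import logger

logger.add_file('./log.txt', level='INFO')        # also write log records to ./log.txt
logger.set_stdout(stdout='tqdm', level='INFO')    # print through tqdm so progress bars stay intact
logger.info('training started')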


@@ -1,11 +1,14 @@
"""undocumented"""
__all__ = []
import threading
import torch
from torch import nn
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
@@ -27,11 +30,11 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)
lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input, kwargs, device=None):
torch.set_grad_enabled(grad_enabled)
if device is None:
@@ -47,20 +50,20 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
except Exception as e:
with lock:
results[i] = e
if len(modules) > 1:
threads = [threading.Thread(target=_worker,
args=(i, module, input, kwargs, device))
for i, (module, input, kwargs, device) in
enumerate(zip(modules, inputs, kwargs_tup, devices))]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
outputs = []
for i in range(len(inputs)):
output = results[i]
@@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
:param output_device: the output_device used by nn.DataParallel
:return:
"""
def wrapper(network, *inputs, **kwargs):
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
if len(device_ids) == 1:
@@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
replicas = replicate(network, device_ids[:len(inputs)])
outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
return gather(outputs, output_device)
return wrapper
@@ -99,4 +104,4 @@ def _model_contains_inner_module(model):
if isinstance(model, nn.Module):
if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
return True
return False
return False
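
A small sanity-check sketch of the _model_contains_inner_module helper above (the import path is an assumption; the leading underscore marks it as internal):

import torch.nn as nn
from fastNLP.core._parallel_utils import _model_contains_inner_module  # assumed module path

model = nn.Linear(4, 2)
assert _model_contains_inner_module(model) is False
assert _model_contains_inner_module(nn.DataParallel(model)) is True   # DataParallel wraps an inner module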


@@ -1,3 +1,13 @@
"""
.. todo::
doc
"""
__all__ = [
"Const"
]
class Const:
"""
Naming constants for fields in fastNLP
@@ -25,47 +35,47 @@ class Const:
LOSS = 'loss'
RAW_WORD = 'raw_words'
RAW_CHAR = 'raw_chars'
@staticmethod
def INPUTS(i):
"""得到第 i 个 ``INPUT`` 的命名"""
i = int(i) + 1
return Const.INPUT + str(i)
@staticmethod
def CHAR_INPUTS(i):
"""得到第 i 个 ``CHAR_INPUT`` 的命名"""
i = int(i) + 1
return Const.CHAR_INPUT + str(i)
@staticmethod
def RAW_WORDS(i):
i = int(i) + 1
return Const.RAW_WORD + str(i)
@staticmethod
def RAW_CHARS(i):
i = int(i) + 1
return Const.RAW_CHAR + str(i)
@staticmethod
def INPUT_LENS(i):
"""得到第 i 个 ``INPUT_LEN`` 的命名"""
i = int(i) + 1
return Const.INPUT_LEN + str(i)
@staticmethod
def OUTPUTS(i):
"""得到第 i 个 ``OUTPUT`` 的命名"""
i = int(i) + 1
return Const.OUTPUT + str(i)
@staticmethod
def TARGETS(i):
"""得到第 i 个 ``TARGET`` 的命名"""
i = int(i) + 1
return Const.TARGET + str(i)
@staticmethod
def LOSSES(i):
"""得到第 i 个 ``LOSS`` 的命名"""


@@ -1,29 +1,29 @@
"""
"""undocumented
Distributed training code, still under development
"""
import logging
import os
import time
from datetime import datetime
import torch
import torch.cuda
import torch.optim
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import torch.optim
from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP
import os
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import time
from datetime import datetime, timedelta
from functools import partial
from ._logger import logger
from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException, TesterCallback
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
from .utils import _build_args
from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from ._logger import logger
import logging
from pkg_resources import parse_version
from .utils import _move_dict_value_to_device
__all__ = [
'get_local_rank',


@@ -1,18 +1,25 @@
"""
.. todo::
doc
"""
__all__ = [
"Padder",
"AutoPadder",
"EngChar2DPadder",
]
from numbers import Number
import torch
import numpy as np
from typing import Any
from abc import abstractmethod
from copy import deepcopy
from collections import Counter
from .utils import _is_iterable
from copy import deepcopy
from numbers import Number
from typing import Any
import numpy as np
import torch
from ._logger import logger
from .utils import _is_iterable
class SetInputOrTargetException(Exception):
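
For orientation, a hedged sketch of how the padders exported above are typically attached to a dataset field (DataSet.set_padder and AutoPadder's pad_val argument are assumptions based on the fastNLP 0.x API, not shown in this diff):

from fastNLP.core import DataSet, Instance
from fastNLP.core.field import AutoPadder   # assumed module path

ds = DataSet()
ds.append(Instance(words=[1, 2, 3]))
ds.append(Instance(words=[4, 5]))
ds.set_padder('words', AutoPadder(pad_val=0))   # pad 'words' entries with 0 to equal length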


@@ -1,13 +1,15 @@
"""
..todo::
check whether this class is still needed
"""
"""undocumented"""
__all__ = [
"Predictor"
]
from collections import defaultdict
import torch
from . import DataSetIter
from . import DataSet
from . import DataSetIter
from . import SequentialSampler
from .utils import _build_args, _move_dict_value_to_device, _get_model_device
@@ -21,7 +23,7 @@ class Predictor(object):
:param torch.nn.Module network: the model used to perform the prediction task
"""
def __init__(self, network):
if not isinstance(network, torch.nn.Module):
raise ValueError(
@@ -29,7 +31,7 @@ class Predictor(object):
self.network = network
self.batch_size = 1
self.batch_output = []
def predict(self, data: DataSet, seq_len_field_name=None):
"""用已经训练好的模型进行inference.
@@ -41,27 +43,27 @@ class Predictor(object):
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
prev_training = self.network.training
self.network.eval()
network_device = _get_model_device(self.network)
batch_output = defaultdict(list)
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
if hasattr(self.network, "predict"):
predict_func = self.network.predict
else:
predict_func = self.network.forward
with torch.no_grad():
for batch_x, _ in data_iterator:
_move_dict_value_to_device(batch_x, _, device=network_device)
refined_batch_x = _build_args(predict_func, **batch_x)
prediction = predict_func(**refined_batch_x)
if seq_len_field_name is not None:
seq_lens = batch_x[seq_len_field_name].tolist()
for key, value in prediction.items():
value = value.cpu().numpy()
if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
@@ -74,6 +76,6 @@ class Predictor(object):
batch_output[key].extend(tmp_batch)
else:
batch_output[key].append(value)
self.network.train(prev_training)
return batch_output
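
Putting the pieces above together, a minimal end-to-end sketch of Predictor; the toy model and field names are illustrative, and only the constructor and predict() shown above are relied on:

import torch
from fastNLP.core import DataSet, Instance
from fastNLP.core.predictor import Predictor   # assumed module path

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1))

    def forward(self, x):
        # fastNLP expects forward/predict to return a dict of tensors
        return {'pred': x * self.scale}

ds = DataSet()
ds.append(Instance(x=[1, 2, 3]))
ds.set_input('x')                               # only input fields are fed to the model

outputs = Predictor(ToyModel()).predict(ds)     # dict: 'pred' -> collected batch outputs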


@@ -1,16 +1,22 @@
"""
.. todo::
doc
"""
__all__ = [
"Vocabulary",
"VocabularyOption",
]
from functools import wraps
from collections import Counter
from functools import partial
from functools import wraps
from ._logger import logger
from .dataset import DataSet
from .utils import Option
from functools import partial
import numpy as np
from .utils import _is_iterable
from ._logger import logger
class VocabularyOption(Option):
def __init__(self,
@@ -51,7 +57,7 @@ def _check_build_status(func):
self.rebuild = True
if self.max_size is not None and len(self.word_count) >= self.max_size:
logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
"Adding more words may cause unexpected behaviour of Vocabulary. ".format(
"Adding more words may cause unexpected behaviour of Vocabulary. ".format(
self.max_size, func.__name__))
return func(self, *args, **kwargs)
@@ -199,7 +205,7 @@ class Vocabulary(object):
self.build_reverse_vocab()
self.rebuild = False
return self
def build_reverse_vocab(self):
"""
Build the `index to word` dict from the `word to index` dict.
@@ -279,19 +285,19 @@ class Vocabulary(object):
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
raise RuntimeError("Only support field with 2 dimensions.")
return [[self.to_index(c) for c in w] for w in field]
new_field_name = new_field_name or field_name
if type(new_field_name) == type(field_name):
if isinstance(new_field_name, list):
assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \
"field_name."
"field_name."
elif isinstance(new_field_name, str):
field_name = [field_name]
new_field_name = [new_field_name]
else:
raise TypeError("field_name and new_field_name can only be str or List[str].")
for idx, dataset in enumerate(datasets):
if isinstance(dataset, DataSet):
try:
@@ -377,7 +383,7 @@ class Vocabulary(object):
:return: bool
"""
return word in self._no_create_word
def to_index(self, w):
"""
Convert a word to its index. If the word is not recorded in the vocabulary, it is treated as unknown; if ``unknown=None``, a ``ValueError`` is raised::
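
To round out the Vocabulary changes, a brief usage sketch of the methods touched above (update is assumed to be the word-adding API guarded by _check_build_status; adjust to your fastNLP version):

from fastNLP.core import DataSet, Instance, Vocabulary

vocab = Vocabulary()
vocab.update(["apple", "banana", "apple"])      # assumed add API; the vocab builds lazily on first lookup
idx = vocab.to_index("apple")                   # unknown words fall back to the unknown token's index

ds = DataSet()
ds.append(Instance(words=["apple", "banana"]))
vocab.index_dataset(ds, field_name="words", new_field_name="words")   # replace tokens with indices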


@@ -8,15 +8,17 @@ __all__ = [
]
from abc import abstractmethod
import torch
from ..core.vocabulary import Vocabulary
from ..core.dataset import DataSet
from ..core.batch import DataSetIter
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from .embedding import TokenEmbedding
from ..core import logger
from ..core.batch import DataSetIter
from ..core.dataset import DataSet
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from ..core.vocabulary import Vocabulary
class ContextualEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):