add __all__ and __doc__ for all files in module 'core', using 'undocumented' tags
parent 9535ec60b6
commit efe88263bb
fastNLP/core/__init__.py

@@ -10,8 +10,72 @@ The core module implements the core framework of fastNLP; commonly used functionality can all be imported from fa
For commonly used functionality, you only need to look in :doc:`fastNLP`. To learn what each submodule does in detail, you can find each submodule's documentation below.

"""

__all__ = [
    "DataSet",

    "Instance",

    "FieldArray",
    "Padder",
    "AutoPadder",
    "EngChar2DPadder",

    "Vocabulary",

    "DataSetIter",
    "BatchIter",
    "TorchLoaderIter",

    "Const",

    "Tester",
    "Trainer",

    "cache_results",
    "seq_len_to_mask",
    "get_seq_len",
    "logger",

    "Callback",
    "GradientClipCallback",
    "EarlyStopCallback",
    "FitlogCallback",
    "EvaluateCallback",
    "LRScheduler",
    "ControlC",
    "LRFinder",
    "TensorboardCallback",
    "WarmupCallback",
    'SaveModelCallback',
    "EchoCallback",
    "TesterCallback",
    "CallbackException",
    "EarlyStopError",

    "LossFunc",
    "CrossEntropyLoss",
    "L1Loss",
    "BCELoss",
    "NLLLoss",
    "LossInForward",

    "AccuracyMetric",
    "SpanFPreRecMetric",
    "ExtractiveQAMetric",

    "Optimizer",
    "SGD",
    "Adam",
    "AdamW",

    "SequentialSampler",
    "BucketSampler",
    "RandomSampler",
    "Sampler",
]

from ._logger import logger
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
    LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
@@ -28,4 +92,3 @@ from .tester import Tester
from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
from ._logger import logger
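For context, everything in this __all__ is re-exported at package level. A minimal usage sketch of the exported names (assuming an installed fastNLP; the field names are chosen arbitrarily for illustration):

    from fastNLP.core import DataSet, Instance, logger

    ds = DataSet()
    ds.append(Instance(raw_words="hello world", target=0))  # Instance takes arbitrary named fields
    logger.info("dataset size: %d", len(ds))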
fastNLP/core/_logger.py

@@ -1,15 +1,15 @@
import logging
import logging.config
import torch
import _pickle as pickle
import os
import sys
import warnings
"""undocumented"""

__all__ = [
    'logger',
]

import logging
import logging.config
import os
import sys
import warnings

ROOT_NAME = 'fastNLP'

try:
@@ -25,7 +25,7 @@ if tqdm is not None:
    class TqdmLoggingHandler(logging.Handler):
        def __init__(self, level=logging.INFO):
            super().__init__(level)

        def emit(self, record):
            try:
                msg = self.format(record)
@@ -59,14 +59,14 @@ def _add_file_handler(logger, path, level='INFO'):
        if os.path.abspath(path) == h.baseFilename:
            # file path already added
            return

    # File Handler
    if os.path.exists(path):
        assert os.path.isfile(path)
        warnings.warn('log already exists in {}'.format(path))
    dirname = os.path.abspath(os.path.dirname(path))
    os.makedirs(dirname, exist_ok=True)

    file_handler = logging.FileHandler(path, mode='a')
    file_handler.setLevel(_get_level(level))
    file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
@@ -87,7 +87,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
            break
    if stream_handler is not None:
        logger.removeHandler(stream_handler)

    # Stream Handler
    if stdout == 'plain':
        stream_handler = logging.StreamHandler(sys.stdout)
@@ -95,7 +95,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
        stream_handler = TqdmLoggingHandler(level)
    else:
        stream_handler = None

    if stream_handler is not None:
        stream_formatter = logging.Formatter('%(message)s')
        stream_handler.setLevel(level)
@@ -103,38 +103,40 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
        logger.addHandler(stream_handler)


class FastNLPLogger(logging.getLoggerClass()):
    def __init__(self, name):
        super().__init__(name)

    def add_file(self, path='./log.txt', level='INFO'):
        """add log output file and level"""
        _add_file_handler(self, path, level)

    def set_stdout(self, stdout='tqdm', level='INFO'):
        """set stdout format and level"""
        _set_stdout_handler(self, stdout, level)


logging.setLoggerClass(FastNLPLogger)

# print(logging.getLoggerClass())
# print(logging.getLogger())

def _init_logger(path=None, stdout='tqdm', level='INFO'):
    """initialize logger"""
    level = _get_level(level)

    # logger = logging.getLogger()
    logger = logging.getLogger(ROOT_NAME)
    logger.propagate = False
    logger.setLevel(level)

    _set_stdout_handler(logger, stdout, level)

    # File Handler
    if path is not None:
        _add_file_handler(logger, path, level)

    return logger
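The add_file/set_stdout methods added above are the public way to reconfigure the shared logger. A short usage sketch (the path and levels are illustrative):

    from fastNLP import logger

    # duplicate paths are skipped; an existing file triggers a warning and is appended to
    logger.add_file(path='./log.txt', level='INFO')

    # 'tqdm' keeps log lines from breaking progress bars; 'plain' uses a plain sys.stdout handler
    logger.set_stdout(stdout='tqdm', level='INFO')

    logger.info('training started')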
fastNLP/core/_parallel_utils.py

@@ -1,11 +1,14 @@
"""undocumented"""

__all__ = []

import threading

import torch
from torch import nn
from torch.nn.parallel.parallel_apply import get_a_var

from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather


def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
@@ -27,11 +30,11 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
@@ -47,20 +50,20 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
@@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
    :param output_device: the output_device of nn.DataParallel
    :return:
    """

    def wrapper(network, *inputs, **kwargs):
        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
        if len(device_ids) == 1:
@@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
        replicas = replicate(network, device_ids[:len(inputs)])
        outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
        return gather(outputs, output_device)

    return wrapper


@@ -99,4 +104,4 @@ def _model_contains_inner_module(model):
    if isinstance(model, nn.Module):
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return True
        return False
    return False
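_model_contains_inner_module reports whether the real forward lives behind a .module attribute, as with DataParallel-style wrappers. A sketch of the expected behavior (the module path is assumed from the section above):

    import torch.nn as nn
    from fastNLP.core._parallel_utils import _model_contains_inner_module

    net = nn.Linear(4, 2)
    assert not _model_contains_inner_module(net)               # plain nn.Module
    assert _model_contains_inner_module(nn.DataParallel(net))  # wrapped: forward is on net.module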
fastNLP/core/const.py

@@ -1,3 +1,13 @@
"""
.. todo::
    doc
"""

__all__ = [
    "Const"
]


class Const:
    """
    Naming constants for fields in fastNLP.
@@ -25,47 +35,47 @@ class Const:
    LOSS = 'loss'
    RAW_WORD = 'raw_words'
    RAW_CHAR = 'raw_chars'

    @staticmethod
    def INPUTS(i):
        """Get the name of the i-th ``INPUT``"""
        i = int(i) + 1
        return Const.INPUT + str(i)

    @staticmethod
    def CHAR_INPUTS(i):
        """Get the name of the i-th ``CHAR_INPUT``"""
        i = int(i) + 1
        return Const.CHAR_INPUT + str(i)

    @staticmethod
    def RAW_WORDS(i):
        i = int(i) + 1
        return Const.RAW_WORD + str(i)

    @staticmethod
    def RAW_CHARS(i):
        i = int(i) + 1
        return Const.RAW_CHAR + str(i)

    @staticmethod
    def INPUT_LENS(i):
        """Get the name of the i-th ``INPUT_LEN``"""
        i = int(i) + 1
        return Const.INPUT_LEN + str(i)

    @staticmethod
    def OUTPUTS(i):
        """Get the name of the i-th ``OUTPUT``"""
        i = int(i) + 1
        return Const.OUTPUT + str(i)

    @staticmethod
    def TARGETS(i):
        """Get the name of the i-th ``TARGET``"""
        i = int(i) + 1
        return Const.TARGET + str(i)

    @staticmethod
    def LOSSES(i):
        """Get the name of the i-th ``LOSS``"""
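The numbered helpers append a 1-based suffix to the corresponding base constant. Assuming the usual fastNLP values Const.INPUT = 'words' and Const.TARGET = 'target' (only LOSS, RAW_WORD and RAW_CHAR are visible in this hunk), the naming scheme works like this:

    from fastNLP import Const

    Const.INPUTS(0)     # 'words1'  -- i is 0-based, the generated suffix is 1-based
    Const.INPUTS(1)     # 'words2'
    Const.TARGETS(0)    # 'target1'
    Const.RAW_WORDS(0)  # 'raw_words1'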
fastNLP/core/dist_trainer.py

@@ -1,29 +1,29 @@
"""
"""undocumented
Distributed training code, still under development
"""
import logging
import os
import time
from datetime import datetime

import torch
import torch.cuda
import torch.optim
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import torch.optim
from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP
import os
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import time
from datetime import datetime, timedelta
from functools import partial

from ._logger import logger
from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException, TesterCallback
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
from .utils import _build_args
from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from ._logger import logger
import logging
from pkg_resources import parse_version
from .utils import _move_dict_value_to_device

__all__ = [
    'get_local_rank',
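Only the import block and the get_local_rank export are visible here. Under the usual torch.distributed.launch workflow, the helper would be used roughly like the following sketch (the launcher flags are standard PyTorch, not something this diff defines):

    # launched as: python -m torch.distributed.launch --nproc_per_node=4 train.py
    import torch
    from fastNLP.core.dist_trainer import get_local_rank

    local_rank = get_local_rank()      # this process's rank on the local machine
    torch.cuda.set_device(local_rank)  # one GPU per process is the usual pairing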
fastNLP/core/field.py

@@ -1,18 +1,25 @@
"""
.. todo::
    doc
"""

__all__ = [
    "Padder",
    "AutoPadder",
    "EngChar2DPadder",
]

from numbers import Number
import torch
import numpy as np
from typing import Any
from abc import abstractmethod
from copy import deepcopy
from collections import Counter
from .utils import _is_iterable
from copy import deepcopy
from numbers import Number
from typing import Any

import numpy as np
import torch

from ._logger import logger
from .utils import _is_iterable


class SetInputOrTargetException(Exception):
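The padders exported above turn ragged batches into rectangular arrays. What AutoPadder does for a 1-d integer field, reduced to a self-contained sketch (the real class additionally infers dtypes, handles higher dimensions, and takes pad_val from the field):

    import numpy as np

    def pad_batch(contents, pad_val=0):
        # pad variable-length rows to the batch maximum, like AutoPadder's simplest case
        max_len = max(len(row) for row in contents)
        out = np.full((len(contents), max_len), pad_val, dtype=np.int64)
        for i, row in enumerate(contents):
            out[i, :len(row)] = row
        return out

    pad_batch([[1, 2, 3], [4, 5]])
    # array([[1, 2, 3],
    #        [4, 5, 0]])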
fastNLP/core/predictor.py

@@ -1,13 +1,15 @@
"""
..todo::
    check whether this class is still needed
"""
"""undocumented"""

__all__ = [
    "Predictor"
]

from collections import defaultdict

import torch

from . import DataSetIter
from . import DataSet
from . import DataSetIter
from . import SequentialSampler
from .utils import _build_args, _move_dict_value_to_device, _get_model_device

@@ -21,7 +23,7 @@ class Predictor(object):

    :param torch.nn.Module network: the model used to perform prediction
    """

    def __init__(self, network):
        if not isinstance(network, torch.nn.Module):
            raise ValueError(
@@ -29,7 +31,7 @@ class Predictor(object):
        self.network = network
        self.batch_size = 1
        self.batch_output = []

    def predict(self, data: DataSet, seq_len_field_name=None):
        """Run inference with an already trained model.

@@ -41,27 +43,27 @@ class Predictor(object):
            raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
        if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
            raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))

        prev_training = self.network.training
        self.network.eval()
        network_device = _get_model_device(self.network)
        batch_output = defaultdict(list)
        data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)

        if hasattr(self.network, "predict"):
            predict_func = self.network.predict
        else:
            predict_func = self.network.forward

        with torch.no_grad():
            for batch_x, _ in data_iterator:
                _move_dict_value_to_device(batch_x, _, device=network_device)
                refined_batch_x = _build_args(predict_func, **batch_x)
                prediction = predict_func(**refined_batch_x)

                if seq_len_field_name is not None:
                    seq_lens = batch_x[seq_len_field_name].tolist()

                for key, value in prediction.items():
                    value = value.cpu().numpy()
                    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
@@ -74,6 +76,6 @@ class Predictor(object):
                    batch_output[key].extend(tmp_batch)
                else:
                    batch_output[key].append(value)

        self.network.train(prev_training)
        return batch_output
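Putting Predictor together: it wraps the model in a DataSetIter loop, prefers a predict() method over forward(), and restores the module's training mode afterwards. A toy end-to-end sketch (the model and data are made up for illustration):

    import torch
    from fastNLP import DataSet
    from fastNLP.core.predictor import Predictor

    class ToyModel(torch.nn.Module):
        def forward(self, x):
            return {'pred': x.sum(dim=1, keepdim=True)}  # Predictor expects a dict of tensors back

    data = DataSet({'x': [[1.0, 2.0], [3.0, 4.0]]})
    data.set_input('x')  # only input fields reach the model

    outputs = Predictor(ToyModel()).predict(data)
    # outputs: defaultdict(list) mapping 'pred' -> per-batch numpy values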
fastNLP/core/vocabulary.py

@@ -1,16 +1,22 @@
"""
.. todo::
    doc
"""

__all__ = [
    "Vocabulary",
    "VocabularyOption",
]

from functools import wraps
from collections import Counter
from functools import partial
from functools import wraps

from ._logger import logger
from .dataset import DataSet
from .utils import Option
from functools import partial
import numpy as np
from .utils import _is_iterable
from ._logger import logger


class VocabularyOption(Option):
    def __init__(self,
@@ -51,7 +57,7 @@ def _check_build_status(func):
        self.rebuild = True
        if self.max_size is not None and len(self.word_count) >= self.max_size:
            logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
                        "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
                "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
                self.max_size, func.__name__))
        return func(self, *args, **kwargs)

@@ -199,7 +205,7 @@ class Vocabulary(object):
        self.build_reverse_vocab()
        self.rebuild = False
        return self

    def build_reverse_vocab(self):
        """
        Build the `index to word` dict from the `word to index` dict.
@@ -279,19 +285,19 @@ class Vocabulary(object):
            if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
                raise RuntimeError("Only support field with 2 dimensions.")
            return [[self.to_index(c) for c in w] for w in field]

        new_field_name = new_field_name or field_name

        if type(new_field_name) == type(field_name):
            if isinstance(new_field_name, list):
                assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \
                                                               "field_name."
                    "field_name."
            elif isinstance(new_field_name, str):
                field_name = [field_name]
                new_field_name = [new_field_name]
            else:
                raise TypeError("field_name and new_field_name can only be str or List[str].")

        for idx, dataset in enumerate(datasets):
            if isinstance(dataset, DataSet):
                try:
@@ -377,7 +383,7 @@ class Vocabulary(object):
        :return: bool
        """
        return word in self._no_create_word

    def to_index(self, w):
        """
        Convert a word to its index. If the word is not in the vocabulary, it is treated as unknown; if ``unknown=None``, a ``ValueError`` is raised::
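The to_index contract from the docstring above, in a minimal sketch (method names follow the public Vocabulary API):

    from fastNLP import Vocabulary

    vocab = Vocabulary()  # unknown='<unk>', padding='<pad>' by default
    vocab.add_word_lst(['apple', 'banana', 'apple'])
    vocab.build_vocab()

    vocab.to_index('apple')    # a real index
    vocab.to_index('missing')  # index of '<unk>'; with unknown=None this would raise ValueError
    'apple' in vocab           # True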
fastNLP/embeddings/contextual_embedding.py

@@ -8,15 +8,17 @@ __all__ = [
]

from abc import abstractmethod

import torch

from ..core.vocabulary import Vocabulary
from ..core.dataset import DataSet
from ..core.batch import DataSetIter
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from .embedding import TokenEmbedding
from ..core import logger
from ..core.batch import DataSetIter
from ..core.dataset import DataSet
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from ..core.vocabulary import Vocabulary


class ContextualEmbedding(TokenEmbedding):
    def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):