add __all__ and __doc__ for all files in module 'core', using 'undocumented' tags

This commit is contained in:
ChenXin 2019-08-26 10:21:10 +08:00
parent 9535ec60b6
commit efe88263bb
9 changed files with 180 additions and 83 deletions

View File

@ -10,8 +10,72 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa
对于常用的功能你只需要在 :doc:`fastNLP` 中查看即可如果想了解各个子模块的具体作用您可以在下面找到每个子模块的具体文档 对于常用的功能你只需要在 :doc:`fastNLP` 中查看即可如果想了解各个子模块的具体作用您可以在下面找到每个子模块的具体文档
""" """
__all__ = [
"DataSet",
"Instance",
"FieldArray",
"Padder",
"AutoPadder",
"EngChar2DPadder",
"Vocabulary",
"DataSetIter",
"BatchIter",
"TorchLoaderIter",
"Const",
"Tester",
"Trainer",
"cache_results",
"seq_len_to_mask",
"get_seq_len",
"logger",
"Callback",
"GradientClipCallback",
"EarlyStopCallback",
"FitlogCallback",
"EvaluateCallback",
"LRScheduler",
"ControlC",
"LRFinder",
"TensorboardCallback",
"WarmupCallback",
'SaveModelCallback',
"EchoCallback",
"TesterCallback",
"CallbackException",
"EarlyStopError",
"LossFunc",
"CrossEntropyLoss",
"L1Loss",
"BCELoss",
"NLLLoss",
"LossInForward",
"AccuracyMetric",
"SpanFPreRecMetric",
"ExtractiveQAMetric",
"Optimizer",
"SGD",
"Adam",
"AdamW",
"SequentialSampler",
"BucketSampler",
"RandomSampler",
"Sampler",
]
from ._logger import logger
from .batch import DataSetIter, BatchIter, TorchLoaderIter from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \ LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
@ -28,4 +92,3 @@ from .tester import Tester
from .trainer import Trainer from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary from .vocabulary import Vocabulary
from ._logger import logger

View File

@ -1,15 +1,15 @@
import logging """undocumented"""
import logging.config
import torch
import _pickle as pickle
import os
import sys
import warnings
__all__ = [ __all__ = [
'logger', 'logger',
] ]
import logging
import logging.config
import os
import sys
import warnings
ROOT_NAME = 'fastNLP' ROOT_NAME = 'fastNLP'
try: try:
@ -103,7 +103,6 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
class FastNLPLogger(logging.getLoggerClass()): class FastNLPLogger(logging.getLoggerClass()):
def __init__(self, name): def __init__(self, name):
super().__init__(name) super().__init__(name)
@ -116,7 +115,10 @@ class FastNLPLogger(logging.getLoggerClass()):
"""set stdout format and level""" """set stdout format and level"""
_set_stdout_handler(self, stdout, level) _set_stdout_handler(self, stdout, level)
logging.setLoggerClass(FastNLPLogger) logging.setLoggerClass(FastNLPLogger)
# print(logging.getLoggerClass()) # print(logging.getLoggerClass())
# print(logging.getLogger()) # print(logging.getLogger())

View File

@ -1,11 +1,14 @@
"""undocumented"""
__all__ = []
import threading import threading
import torch import torch
from torch import nn from torch import nn
from torch.nn.parallel.parallel_apply import get_a_var from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
:param output_device: nn.DataParallel中的output_device :param output_device: nn.DataParallel中的output_device
:return: :return:
""" """
def wrapper(network, *inputs, **kwargs): def wrapper(network, *inputs, **kwargs):
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0) inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
if len(device_ids) == 1: if len(device_ids) == 1:
@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
replicas = replicate(network, device_ids[:len(inputs)]) replicas = replicate(network, device_ids[:len(inputs)])
outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)]) outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
return gather(outputs, output_device) return gather(outputs, output_device)
return wrapper return wrapper

View File

@ -1,3 +1,13 @@
"""
.. todo::
doc
"""
__all__ = [
"Const"
]
class Const: class Const:
""" """
fastNLP中field命名常量 fastNLP中field命名常量

View File

@ -1,29 +1,29 @@
""" """undocumented
正在开发中的分布式训练代码 正在开发中的分布式训练代码
""" """
import logging
import os
import time
from datetime import datetime
import torch import torch
import torch.cuda import torch.cuda
import torch.optim
import torch.distributed as dist import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler import torch.optim
from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
import os from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm from tqdm import tqdm
import time
from datetime import datetime, timedelta
from functools import partial
from ._logger import logger
from .batch import DataSetIter, BatchIter from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException, TesterCallback from .callback import DistCallbackManager, CallbackException, TesterCallback
from .dataset import DataSet from .dataset import DataSet
from .losses import _prepare_losser from .losses import _prepare_losser
from .optimizer import Optimizer from .optimizer import Optimizer
from .utils import _build_args from .utils import _build_args
from .utils import _move_dict_value_to_device
from .utils import _get_func_signature from .utils import _get_func_signature
from ._logger import logger from .utils import _move_dict_value_to_device
import logging
from pkg_resources import parse_version
__all__ = [ __all__ = [
'get_local_rank', 'get_local_rank',

View File

@ -1,18 +1,25 @@
"""
.. todo::
doc
"""
__all__ = [ __all__ = [
"Padder", "Padder",
"AutoPadder", "AutoPadder",
"EngChar2DPadder", "EngChar2DPadder",
] ]
from numbers import Number
import torch
import numpy as np
from typing import Any
from abc import abstractmethod from abc import abstractmethod
from copy import deepcopy
from collections import Counter from collections import Counter
from .utils import _is_iterable from copy import deepcopy
from numbers import Number
from typing import Any
import numpy as np
import torch
from ._logger import logger from ._logger import logger
from .utils import _is_iterable
class SetInputOrTargetException(Exception): class SetInputOrTargetException(Exception):

View File

@ -1,13 +1,15 @@
""" """undocumented"""
..todo::
检查这个类是否需要 __all__ = [
""" "Predictor"
]
from collections import defaultdict from collections import defaultdict
import torch import torch
from . import DataSetIter
from . import DataSet from . import DataSet
from . import DataSetIter
from . import SequentialSampler from . import SequentialSampler
from .utils import _build_args, _move_dict_value_to_device, _get_model_device from .utils import _build_args, _move_dict_value_to_device, _get_model_device

View File

@ -1,16 +1,22 @@
"""
.. todo::
doc
"""
__all__ = [ __all__ = [
"Vocabulary", "Vocabulary",
"VocabularyOption", "VocabularyOption",
] ]
from functools import wraps
from collections import Counter from collections import Counter
from functools import partial
from functools import wraps
from ._logger import logger
from .dataset import DataSet from .dataset import DataSet
from .utils import Option from .utils import Option
from functools import partial
import numpy as np
from .utils import _is_iterable from .utils import _is_iterable
from ._logger import logger
class VocabularyOption(Option): class VocabularyOption(Option):
def __init__(self, def __init__(self,

View File

@ -8,15 +8,17 @@ __all__ = [
] ]
from abc import abstractmethod from abc import abstractmethod
import torch import torch
from ..core.vocabulary import Vocabulary
from ..core.dataset import DataSet
from ..core.batch import DataSetIter
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from .embedding import TokenEmbedding from .embedding import TokenEmbedding
from ..core import logger from ..core import logger
from ..core.batch import DataSetIter
from ..core.dataset import DataSet
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from ..core.vocabulary import Vocabulary
class ContextualEmbedding(TokenEmbedding): class ContextualEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0): def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):