add __all__ and __doc__ for all files in module 'core', using 'undocumented' tags

This commit is contained in:
ChenXin 2019-08-26 10:21:10 +08:00
parent 9535ec60b6
commit efe88263bb
9 changed files with 180 additions and 83 deletions
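The pattern applied to each file in core is uniform: a module docstring (tagged "undocumented" where no user-facing docs exist yet) followed by an explicit export list. A minimal sketch of the convention, with a hypothetical class name:

"""undocumented"""

__all__ = [
    "SomePublicClass",
]


class SomePublicClass:
    """The only name this file exposes to `from module import *`."""
    pass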

View File

@@ -10,8 +10,72 @@ The core module implements fastNLP's core framework; commonly used functionality can be accessed from fa
For common functionality you only need to look in :doc:`fastNLP`; if you want to know what each submodule does, you can find each submodule's detailed documentation below.
"""
__all__ = [
"DataSet",
"Instance",
"FieldArray",
"Padder",
"AutoPadder",
"EngChar2DPadder",
"Vocabulary",
"DataSetIter",
"BatchIter",
"TorchLoaderIter",
"Const",
"Tester",
"Trainer",
"cache_results",
"seq_len_to_mask",
"get_seq_len",
"logger",
"Callback",
"GradientClipCallback",
"EarlyStopCallback",
"FitlogCallback",
"EvaluateCallback",
"LRScheduler",
"ControlC",
"LRFinder",
"TensorboardCallback",
"WarmupCallback",
'SaveModelCallback',
"EchoCallback",
"TesterCallback",
"CallbackException",
"EarlyStopError",
"LossFunc",
"CrossEntropyLoss",
"L1Loss",
"BCELoss",
"NLLLoss",
"LossInForward",
"AccuracyMetric",
"SpanFPreRecMetric",
"ExtractiveQAMetric",
"Optimizer",
"SGD",
"Adam",
"AdamW",
"SequentialSampler",
"BucketSampler",
"RandomSampler",
"Sampler",
]
from ._logger import logger
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
@@ -28,4 +92,3 @@ from .tester import Tester
from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
from ._logger import logger
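With the export list in place, a wildcard import of fastNLP.core yields exactly the names declared above; a quick illustrative check:

import fastNLP.core as core

# every name in __all__ is a real module attribute
assert all(hasattr(core, name) for name in core.__all__)
from fastNLP.core import DataSet, Vocabulary, logger  # direct imports are unaffected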

View File

@@ -1,15 +1,15 @@
import logging
import logging.config
import torch
import _pickle as pickle
import os
import sys
import warnings
"""undocumented"""
__all__ = [
'logger',
]
import logging
import logging.config
import os
import sys
import warnings
ROOT_NAME = 'fastNLP'
try:
@@ -103,7 +103,6 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
logger.addHandler(stream_handler)
class FastNLPLogger(logging.getLoggerClass()):
def __init__(self, name):
super().__init__(name)
@@ -116,7 +115,10 @@ class FastNLPLogger(logging.getLoggerClass()):
"""set stdout format and level"""
_set_stdout_handler(self, stdout, level)
logging.setLoggerClass(FastNLPLogger)
# print(logging.getLoggerClass())
# print(logging.getLogger())
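For reference, logging.setLoggerClass makes every logger created after the call an instance of the registered subclass; a standalone sketch of the mechanism (demo names only, not this file's API):

import logging

class DemoLogger(logging.getLoggerClass()):
    def __init__(self, name):
        super().__init__(name)

logging.setLoggerClass(DemoLogger)
log = logging.getLogger('demo')     # name not seen before, so it is created now
print(isinstance(log, DemoLogger))  # True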

View File

@@ -1,11 +1,14 @@
"""undocumented"""
__all__ = []
import threading
import torch
from torch import nn
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
@@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
:param output_device: the output_device argument of nn.DataParallel
:return:
"""
def wrapper(network, *inputs, **kwargs):
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
if len(device_ids) == 1:
@@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
replicas = replicate(network, device_ids[:len(inputs)])
outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
return gather(outputs, output_device)
return wrapper
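A hedged usage sketch of the wrapper defined above, assuming a network that exposes a predict method and two visible GPUs (the model and batch names are hypothetical):

# scatter the batch, replicate the network, run `predict` on each replica,
# then gather the outputs onto GPU 0
predict_parallel = _data_parallel_wrapper('predict', device_ids=[0, 1], output_device=0)
outputs = predict_parallel(model, words=word_ids_batch)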

View File

@@ -1,3 +1,13 @@
"""
.. todo::
doc
"""
__all__ = [
"Const"
]
class Const:
"""
Constants for the field names used throughout fastNLP

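Const centralizes the field names that DataSet columns and model inputs/outputs agree on; a small usage sketch (the attribute values cited are assumptions about this codebase's conventions, and `dataset` is assumed to exist):

from fastNLP.core.const import Const

# e.g. Const.INPUT == 'words' and Const.TARGET == 'target' in this codebase
dataset.rename_field('sentence', Const.INPUT)
dataset.rename_field('label', Const.TARGET)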
View File

@@ -1,29 +1,29 @@
"""
"""undocumented
Distributed training code, still under development
"""
import logging
import os
import time
from datetime import datetime
import torch
import torch.cuda
import torch.optim
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import torch.optim
from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP
import os
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import time
from datetime import datetime, timedelta
from functools import partial
from ._logger import logger
from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException, TesterCallback
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
from .utils import _build_args
from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from ._logger import logger
import logging
from pkg_resources import parse_version
from .utils import _move_dict_value_to_device
__all__ = [
'get_local_rank',

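get_local_rank is the module's one public name so far; a hedged sketch of typical use (how the rank is obtained internally, CLI flag or environment variable, is an assumption, not verified against the body):

import torch

from fastNLP.core.dist_trainer import get_local_rank

local_rank = get_local_rank()      # rank of this worker on the current node
torch.cuda.set_device(local_rank)  # pin each launched process to its own GPU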
View File

@@ -1,18 +1,25 @@
"""
.. todo::
doc
"""
__all__ = [
"Padder",
"AutoPadder",
"EngChar2DPadder",
]
from numbers import Number
import torch
import numpy as np
from typing import Any
from abc import abstractmethod
from copy import deepcopy
from collections import Counter
from .utils import _is_iterable
from copy import deepcopy
from numbers import Number
from typing import Any
import numpy as np
import torch
from ._logger import logger
from .utils import _is_iterable
class SetInputOrTargetException(Exception):

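The three exported padders share one hook; a hedged sketch of a custom Padder (the __call__ signature is taken from this version of field.py and should be verified; FixedLenPadder is a made-up name):

import numpy as np

from fastNLP.core.field import Padder

class FixedLenPadder(Padder):
    """Pad or truncate every 1-d sequence to a fixed length (illustrative only)."""

    def __init__(self, length, pad_val=0):
        super().__init__(pad_val=pad_val)
        self.length = length

    def __call__(self, contents, field_name, field_ele_dtype, dim):
        out = np.full((len(contents), self.length), self.pad_val)
        for i, seq in enumerate(contents):
            seq = list(seq)[:self.length]
            out[i, :len(seq)] = seq
        return out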
View File

@@ -1,13 +1,15 @@
"""
.. todo::
check whether this class is still needed
"""
"""undocumented"""
__all__ = [
"Predictor"
]
from collections import defaultdict
import torch
from . import DataSetIter
from . import DataSet
from . import DataSetIter
from . import SequentialSampler
from .utils import _build_args, _move_dict_value_to_device, _get_model_device

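Predictor wraps a trained network for batched inference over a DataSet; a hedged usage sketch (argument names are assumptions, and `model` / `test_data` are placeholders):

from fastNLP.core.predictor import Predictor

predictor = Predictor(model)            # model: a trained fastNLP-style network
results = predictor.predict(test_data)  # test_data: a DataSet; returns collected outputs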
View File

@@ -1,16 +1,22 @@
"""
.. todo::
doc
"""
__all__ = [
"Vocabulary",
"VocabularyOption",
]
from functools import wraps
from collections import Counter
from functools import partial
from functools import wraps
from ._logger import logger
from .dataset import DataSet
from .utils import Option
from functools import partial
import numpy as np
from .utils import _is_iterable
from ._logger import logger
class VocabularyOption(Option):
def __init__(self,

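Vocabulary is the module's word-to-index mapping; a short usage sketch (the default <pad>/<unk> handling is assumed from this version's defaults):

from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst(["apple", "banana", "apple"])
vocab.build_vocab()                     # freeze counts and assign indices
idx = vocab.to_index("apple")
print(idx, vocab.to_word(idx))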
View File

@@ -8,15 +8,17 @@ __all__ = [
]
from abc import abstractmethod
import torch
from ..core.vocabulary import Vocabulary
from ..core.dataset import DataSet
from ..core.batch import DataSetIter
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from .embedding import TokenEmbedding
from ..core import logger
from ..core.batch import DataSetIter
from ..core.dataset import DataSet
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from ..core.vocabulary import Vocabulary
class ContextualEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):