mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
add __all__ and __doc__ for all files in module 'core', using 'undocumented' tags
This commit is contained in:
parent
9535ec60b6
commit
efe88263bb
@ -10,8 +10,72 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa
|
||||
|
||||
对于常用的功能,你只需要在 :doc:`fastNLP` 中查看即可。如果想了解各个子模块的具体作用,您可以在下面找到每个子模块的具体文档。
|
||||
|
||||
|
||||
"""
|
||||
__all__ = [
|
||||
"DataSet",
|
||||
|
||||
"Instance",
|
||||
|
||||
"FieldArray",
|
||||
"Padder",
|
||||
"AutoPadder",
|
||||
"EngChar2DPadder",
|
||||
|
||||
"Vocabulary",
|
||||
|
||||
"DataSetIter",
|
||||
"BatchIter",
|
||||
"TorchLoaderIter",
|
||||
|
||||
"Const",
|
||||
|
||||
"Tester",
|
||||
"Trainer",
|
||||
|
||||
"cache_results",
|
||||
"seq_len_to_mask",
|
||||
"get_seq_len",
|
||||
"logger",
|
||||
|
||||
"Callback",
|
||||
"GradientClipCallback",
|
||||
"EarlyStopCallback",
|
||||
"FitlogCallback",
|
||||
"EvaluateCallback",
|
||||
"LRScheduler",
|
||||
"ControlC",
|
||||
"LRFinder",
|
||||
"TensorboardCallback",
|
||||
"WarmupCallback",
|
||||
'SaveModelCallback',
|
||||
"EchoCallback",
|
||||
"TesterCallback",
|
||||
"CallbackException",
|
||||
"EarlyStopError",
|
||||
|
||||
"LossFunc",
|
||||
"CrossEntropyLoss",
|
||||
"L1Loss",
|
||||
"BCELoss",
|
||||
"NLLLoss",
|
||||
"LossInForward",
|
||||
|
||||
"AccuracyMetric",
|
||||
"SpanFPreRecMetric",
|
||||
"ExtractiveQAMetric",
|
||||
|
||||
"Optimizer",
|
||||
"SGD",
|
||||
"Adam",
|
||||
"AdamW",
|
||||
|
||||
"SequentialSampler",
|
||||
"BucketSampler",
|
||||
"RandomSampler",
|
||||
"Sampler",
|
||||
]
|
||||
|
||||
from ._logger import logger
|
||||
from .batch import DataSetIter, BatchIter, TorchLoaderIter
|
||||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
|
||||
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
|
||||
@ -28,4 +92,3 @@ from .tester import Tester
|
||||
from .trainer import Trainer
|
||||
from .utils import cache_results, seq_len_to_mask, get_seq_len
|
||||
from .vocabulary import Vocabulary
|
||||
from ._logger import logger
|
||||
|
@ -1,15 +1,15 @@
|
||||
import logging
|
||||
import logging.config
|
||||
import torch
|
||||
import _pickle as pickle
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
"""undocumented"""
|
||||
|
||||
__all__ = [
|
||||
'logger',
|
||||
]
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
ROOT_NAME = 'fastNLP'
|
||||
|
||||
try:
|
||||
@ -103,7 +103,6 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
|
||||
class FastNLPLogger(logging.getLoggerClass()):
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
@ -116,7 +115,10 @@ class FastNLPLogger(logging.getLoggerClass()):
|
||||
"""set stdout format and level"""
|
||||
_set_stdout_handler(self, stdout, level)
|
||||
|
||||
|
||||
logging.setLoggerClass(FastNLPLogger)
|
||||
|
||||
|
||||
# print(logging.getLoggerClass())
|
||||
# print(logging.getLogger())
|
||||
|
||||
|
@ -1,11 +1,14 @@
|
||||
"""undocumented"""
|
||||
|
||||
__all__ = []
|
||||
|
||||
import threading
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.parallel.parallel_apply import get_a_var
|
||||
|
||||
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
|
||||
from torch.nn.parallel.replicate import replicate
|
||||
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
|
||||
|
||||
|
||||
def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
|
||||
@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
|
||||
:param output_device: nn.DataParallel中的output_device
|
||||
:return:
|
||||
"""
|
||||
|
||||
def wrapper(network, *inputs, **kwargs):
|
||||
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
|
||||
if len(device_ids) == 1:
|
||||
@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
|
||||
replicas = replicate(network, device_ids[:len(inputs)])
|
||||
outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
|
||||
return gather(outputs, output_device)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
|
@ -1,3 +1,13 @@
|
||||
"""
|
||||
.. todo::
|
||||
doc
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"Const"
|
||||
]
|
||||
|
||||
|
||||
class Const:
|
||||
"""
|
||||
fastNLP中field命名常量。
|
||||
|
@ -1,29 +1,29 @@
|
||||
"""
|
||||
"""undocumented
|
||||
正在开发中的分布式训练代码
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.optim
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
import torch.optim
|
||||
from pkg_resources import parse_version
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import os
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
|
||||
from ._logger import logger
|
||||
from .batch import DataSetIter, BatchIter
|
||||
from .callback import DistCallbackManager, CallbackException, TesterCallback
|
||||
from .dataset import DataSet
|
||||
from .losses import _prepare_losser
|
||||
from .optimizer import Optimizer
|
||||
from .utils import _build_args
|
||||
from .utils import _move_dict_value_to_device
|
||||
from .utils import _get_func_signature
|
||||
from ._logger import logger
|
||||
import logging
|
||||
from pkg_resources import parse_version
|
||||
from .utils import _move_dict_value_to_device
|
||||
|
||||
__all__ = [
|
||||
'get_local_rank',
|
||||
|
@ -1,18 +1,25 @@
|
||||
"""
|
||||
.. todo::
|
||||
doc
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"Padder",
|
||||
"AutoPadder",
|
||||
"EngChar2DPadder",
|
||||
]
|
||||
|
||||
from numbers import Number
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Any
|
||||
from abc import abstractmethod
|
||||
from copy import deepcopy
|
||||
from collections import Counter
|
||||
from .utils import _is_iterable
|
||||
from copy import deepcopy
|
||||
from numbers import Number
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ._logger import logger
|
||||
from .utils import _is_iterable
|
||||
|
||||
|
||||
class SetInputOrTargetException(Exception):
|
||||
|
@ -1,13 +1,15 @@
|
||||
"""
|
||||
..todo::
|
||||
检查这个类是否需要
|
||||
"""
|
||||
"""undocumented"""
|
||||
|
||||
__all__ = [
|
||||
"Predictor"
|
||||
]
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from . import DataSetIter
|
||||
from . import DataSet
|
||||
from . import DataSetIter
|
||||
from . import SequentialSampler
|
||||
from .utils import _build_args, _move_dict_value_to_device, _get_model_device
|
||||
|
||||
|
@ -1,16 +1,22 @@
|
||||
"""
|
||||
.. todo::
|
||||
doc
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"Vocabulary",
|
||||
"VocabularyOption",
|
||||
]
|
||||
|
||||
from functools import wraps
|
||||
from collections import Counter
|
||||
from functools import partial
|
||||
from functools import wraps
|
||||
|
||||
from ._logger import logger
|
||||
from .dataset import DataSet
|
||||
from .utils import Option
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
from .utils import _is_iterable
|
||||
from ._logger import logger
|
||||
|
||||
|
||||
class VocabularyOption(Option):
|
||||
def __init__(self,
|
||||
|
@ -8,15 +8,17 @@ __all__ = [
|
||||
]
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from ..core.vocabulary import Vocabulary
|
||||
from ..core.dataset import DataSet
|
||||
from ..core.batch import DataSetIter
|
||||
from ..core.sampler import SequentialSampler
|
||||
from ..core.utils import _move_model_to_device, _get_model_device
|
||||
from .embedding import TokenEmbedding
|
||||
from ..core import logger
|
||||
from ..core.batch import DataSetIter
|
||||
from ..core.dataset import DataSet
|
||||
from ..core.sampler import SequentialSampler
|
||||
from ..core.utils import _move_model_to_device, _get_model_device
|
||||
from ..core.vocabulary import Vocabulary
|
||||
|
||||
|
||||
class ContextualEmbedding(TokenEmbedding):
|
||||
def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):
|
||||
|
Loading…
Reference in New Issue
Block a user