Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-12-02 20:27:35 +08:00
add __all__ and __doc__ for all files in module 'core', using 'undocumented' tags
This commit is contained in:
parent 9535ec60b6
commit efe88263bb
fastNLP/core/__init__.py

@@ -10,8 +10,72 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa
 
 For commonly used functionality, you only need to look in :doc:`fastNLP`. If you want to understand the specific role of each submodule, you can find each submodule's detailed documentation below.
 
 
 """
+__all__ = [
+    "DataSet",
+
+    "Instance",
+
+    "FieldArray",
+    "Padder",
+    "AutoPadder",
+    "EngChar2DPadder",
+
+    "Vocabulary",
+
+    "DataSetIter",
+    "BatchIter",
+    "TorchLoaderIter",
+
+    "Const",
+
+    "Tester",
+    "Trainer",
+
+    "cache_results",
+    "seq_len_to_mask",
+    "get_seq_len",
+    "logger",
+
+    "Callback",
+    "GradientClipCallback",
+    "EarlyStopCallback",
+    "FitlogCallback",
+    "EvaluateCallback",
+    "LRScheduler",
+    "ControlC",
+    "LRFinder",
+    "TensorboardCallback",
+    "WarmupCallback",
+    'SaveModelCallback',
+    "EchoCallback",
+    "TesterCallback",
+    "CallbackException",
+    "EarlyStopError",
+
+    "LossFunc",
+    "CrossEntropyLoss",
+    "L1Loss",
+    "BCELoss",
+    "NLLLoss",
+    "LossInForward",
+
+    "AccuracyMetric",
+    "SpanFPreRecMetric",
+    "ExtractiveQAMetric",
+
+    "Optimizer",
+    "SGD",
+    "Adam",
+    "AdamW",
+
+    "SequentialSampler",
+    "BucketSampler",
+    "RandomSampler",
+    "Sampler",
+]
+from ._logger import logger
 from .batch import DataSetIter, BatchIter, TorchLoaderIter
 from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
     LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \

@@ -28,4 +92,3 @@ from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
-from ._logger import logger
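The effect of the added `__all__` list is standard Python: it fixes exactly which names `from fastNLP.core import *` binds. A minimal runnable sketch of that mechanism, using a throwaway module rather than fastNLP itself:

```python
import sys
import types

# Throwaway module: a star-import binds only the names listed in __all__.
demo = types.ModuleType("demo")
exec("__all__ = ['DataSet']\n"
     "class DataSet: pass\n"
     "class Hidden: pass\n", demo.__dict__)
sys.modules["demo"] = demo

ns = {}
exec("from demo import *", ns)
print("DataSet" in ns)  # True
print("Hidden" in ns)   # False -- omitted from __all__, though demo.Hidden still works
```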
fastNLP/core/_logger.py

@@ -1,15 +1,15 @@
-import logging
-import logging.config
-import torch
-import _pickle as pickle
-import os
-import sys
-import warnings
+"""undocumented"""
 
 __all__ = [
     'logger',
 ]
 
+import logging
+import logging.config
+import os
+import sys
+import warnings
+
 ROOT_NAME = 'fastNLP'
 
 try:

@@ -103,7 +103,6 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
     logger.addHandler(stream_handler)
 
 
-
 class FastNLPLogger(logging.getLoggerClass()):
     def __init__(self, name):
         super().__init__(name)

@@ -116,7 +115,10 @@ class FastNLPLogger(logging.getLoggerClass()):
         """set stdout format and level"""
         _set_stdout_handler(self, stdout, level)
 
+
 logging.setLoggerClass(FastNLPLogger)
 
+
 # print(logging.getLoggerClass())
 # print(logging.getLogger())
+
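For context on the `_logger.py` hunks: `logging.setLoggerClass` is the stdlib hook this module relies on, and any logger created after registration is an instance of the subclass. A small self-contained sketch of the same pattern; the `hello` method is made up purely for illustration:

```python
import logging

class DemoLogger(logging.getLoggerClass()):
    """Same registration pattern as FastNLPLogger: extend the current logger class."""
    def hello(self):  # hypothetical extra method, illustrative only
        self.info("hello from %s", self.name)

logging.setLoggerClass(DemoLogger)
logging.basicConfig(level=logging.INFO)

log = logging.getLogger("demo")  # created after registration -> DemoLogger
log.hello()
print(type(log).__name__)  # DemoLogger
```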
fastNLP/core/_parallel_utils.py

@@ -1,11 +1,14 @@
+"""undocumented"""
+
+__all__ = []
+
 import threading
 
 import torch
 from torch import nn
 from torch.nn.parallel.parallel_apply import get_a_var
-
-from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
 from torch.nn.parallel.replicate import replicate
+from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
 
 
 def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):

@@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
     :param output_device: the output_device of nn.DataParallel
     :return:
     """
+
     def wrapper(network, *inputs, **kwargs):
         inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
         if len(device_ids) == 1:

@@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
         replicas = replicate(network, device_ids[:len(inputs)])
         outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
         return gather(outputs, output_device)
     
     return wrapper
 
+
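The `parallel_apply`/`_data_parallel_wrapper` code above calls a model method *by name* on every replica and gathers the results. A CPU-only sketch of that scatter/apply/gather flow, using threads and a single shared module instead of torch's GPU `replicate`/`gather` (so it runs anywhere torch is installed):

```python
import threading

import torch
from torch import nn

def apply_by_name(modules, func_name, inputs):
    """Call getattr(module, func_name)(x) on each (module, chunk) pair in a thread."""
    results = [None] * len(modules)

    def worker(i, module, x):
        results[i] = getattr(module, func_name)(x)  # e.g. 'forward' or 'predict'

    threads = [threading.Thread(target=worker, args=(i, m, x))
               for i, (m, x) in enumerate(zip(modules, inputs))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return torch.cat(results, dim=0)  # the "gather" step

model = nn.Linear(4, 2)
chunks = torch.randn(6, 4).chunk(2, dim=0)  # the "scatter" step
print(apply_by_name([model, model], "forward", chunks).shape)  # torch.Size([6, 2])
```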
fastNLP/core/const.py

@@ -1,3 +1,13 @@
+"""
+.. todo::
+    doc
+"""
+
+__all__ = [
+    "Const"
+]
+
+
 class Const:
     """
     Field naming constants in fastNLP.
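`Const` centralizes the strings used as field names, so renaming a field touches one place instead of every call site. A hedged sketch of the pattern; the names and values below are illustrative, not necessarily fastNLP's actual table:

```python
class DemoConst:
    # Illustrative field-name constants, in the style of fastNLP's Const.
    INPUT = 'words'
    TARGET = 'target'
    INPUT_LEN = 'seq_len'
    OUTPUT = 'pred'

batch = {DemoConst.INPUT: [[3, 7, 2]], DemoConst.INPUT_LEN: [3]}
print(batch[DemoConst.INPUT])  # consumers never hard-code the 'words' string
```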
fastNLP/core/dist_trainer.py

@@ -1,29 +1,29 @@
-"""
+"""undocumented
 Distributed training code, still under development.
 """
+import logging
+import os
+import time
+from datetime import datetime
+
 import torch
 import torch.cuda
-import torch.optim
 import torch.distributed as dist
-from torch.utils.data.distributed import DistributedSampler
+import torch.optim
+from pkg_resources import parse_version
 from torch.nn.parallel import DistributedDataParallel as DDP
-import os
+from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
-import time
-from datetime import datetime, timedelta
-from functools import partial
 
+from ._logger import logger
 from .batch import DataSetIter, BatchIter
 from .callback import DistCallbackManager, CallbackException, TesterCallback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
 from .utils import _build_args
-from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
-from ._logger import logger
-import logging
-from pkg_resources import parse_version
+from .utils import _move_dict_value_to_device
 
 __all__ = [
     'get_local_rank',
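The regrouped imports in `dist_trainer.py` point at the usual PyTorch DDP ingredients. A sketch of that setup, under the assumption of a one-process-per-GPU launch (e.g. via torchrun, which sets `LOCAL_RANK`); this is not fastNLP's DistTrainer API:

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def main():
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="nccl")      # rendezvous across processes
    torch.cuda.set_device(local_rank)

    model = DDP(torch.nn.Linear(8, 2).cuda(local_rank), device_ids=[local_rank])

    data = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
    sampler = DistributedSampler(data)           # shards the data across ranks
    loader = DataLoader(data, batch_size=8, sampler=sampler)

    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for epoch in range(2):
        sampler.set_epoch(epoch)                 # reshuffle consistently per epoch
        for x, y in loader:
            opt.zero_grad()
            loss = torch.nn.functional.cross_entropy(
                model(x.cuda(local_rank)), y.cuda(local_rank))
            loss.backward()                      # DDP syncs gradients here
            opt.step()

if __name__ == "__main__":
    main()
```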
fastNLP/core/field.py

@@ -1,18 +1,25 @@
+"""
+.. todo::
+    doc
+"""
+
 __all__ = [
     "Padder",
     "AutoPadder",
     "EngChar2DPadder",
 ]
 
-from numbers import Number
-import torch
-import numpy as np
-from typing import Any
 from abc import abstractmethod
-from copy import deepcopy
 from collections import Counter
-from .utils import _is_iterable
+from copy import deepcopy
+from numbers import Number
+from typing import Any
+
+import numpy as np
+import torch
 
 from ._logger import logger
+from .utils import _is_iterable
 
 
 class SetInputOrTargetException(Exception):
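`Padder`/`AutoPadder` turn variable-length field contents into rectangular batch arrays. A minimal sketch of the 1-d integer case, not fastNLP's implementation (`pad_val` mirrors the usual knob):

```python
import numpy as np

def pad_sequences(contents, pad_val=0):
    """Pad variable-length int sequences into one (batch, max_len) array."""
    max_len = max(len(seq) for seq in contents)
    batch = np.full((len(contents), max_len), pad_val, dtype=np.int64)
    for i, seq in enumerate(contents):
        batch[i, :len(seq)] = seq
    return batch

print(pad_sequences([[1, 2, 3], [4, 5]]))
# [[1 2 3]
#  [4 5 0]]
```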
fastNLP/core/predictor.py

@@ -1,13 +1,15 @@
-"""
-..todo::
-    check whether this class is still needed
-"""
+"""undocumented"""
+
+__all__ = [
+    "Predictor"
+]
 
 from collections import defaultdict
 
 import torch
 
-from . import DataSetIter
 from . import DataSet
+from . import DataSetIter
 from . import SequentialSampler
 from .utils import _build_args, _move_dict_value_to_device, _get_model_device
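`Predictor` wraps the standard inference loop: eval mode, `torch.no_grad()`, outputs accumulated per key (hence the `defaultdict` import). A sketch with plain tensors standing in for the DataSetIter/SequentialSampler batching:

```python
from collections import defaultdict

import torch
from torch import nn

def predict(model, batches):
    model.eval()
    results = defaultdict(list)
    with torch.no_grad():
        for x in batches:
            results["pred"].append(model(x))  # key name is illustrative
    return {k: torch.cat(v, dim=0) for k, v in results.items()}

model = nn.Linear(4, 2)
out = predict(model, [torch.randn(3, 4), torch.randn(2, 4)])
print(out["pred"].shape)  # torch.Size([5, 2])
```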
fastNLP/core/vocabulary.py

@@ -1,16 +1,22 @@
+"""
+.. todo::
+    doc
+"""
+
 __all__ = [
     "Vocabulary",
     "VocabularyOption",
 ]
 
-from functools import wraps
 from collections import Counter
+from functools import partial
+from functools import wraps
+
+from ._logger import logger
 from .dataset import DataSet
 from .utils import Option
-from functools import partial
-import numpy as np
 from .utils import _is_iterable
-from ._logger import logger
 
 
 class VocabularyOption(Option):
    def __init__(self,
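At its core, a `Vocabulary` counts tokens (the `Counter` import above), reserves padding/unknown entries, and maps words to indices with an `<unk>` fallback. A reduced sketch; the options fastNLP layers on top (e.g. size and frequency cutoffs) are assumptions here and elided:

```python
from collections import Counter

class DemoVocab:
    """Reduced vocabulary: count words, map them to ids, fall back to <unk>."""
    def __init__(self, padding='<pad>', unknown='<unk>'):
        self.word_count = Counter()
        self.word2idx = {padding: 0, unknown: 1}
        self.unknown = unknown

    def update(self, words):
        self.word_count.update(words)
        for w in words:
            self.word2idx.setdefault(w, len(self.word2idx))

    def to_index(self, word):
        return self.word2idx.get(word, self.word2idx[self.unknown])

vocab = DemoVocab()
vocab.update(["apple", "banana", "apple"])
print(vocab.to_index("apple"), vocab.to_index("cherry"))  # 2 1
```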
fastNLP/embeddings/contextual_embedding.py

@@ -8,15 +8,17 @@ __all__ = [
 ]
 
 from abc import abstractmethod
 
 import torch
 
-from ..core.vocabulary import Vocabulary
-from ..core.dataset import DataSet
-from ..core.batch import DataSetIter
-from ..core.sampler import SequentialSampler
-from ..core.utils import _move_model_to_device, _get_model_device
 from .embedding import TokenEmbedding
 from ..core import logger
+from ..core.batch import DataSetIter
+from ..core.dataset import DataSet
+from ..core.sampler import SequentialSampler
+from ..core.utils import _move_model_to_device, _get_model_device
+from ..core.vocabulary import Vocabulary
 
 
 class ContextualEmbedding(TokenEmbedding):
     def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):