mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-29 18:59:01 +08:00
* fix processor.py
* add code comments
* merge *_saver.py & *_loader.py in io/
* (ancient codes) rename Loss into LossFromTorch
This commit is contained in:
parent
306eee9690
commit
27e9453d19
@@ -1,5 +1,3 @@
import torch
import hashlib
import os
import re
@@ -7,6 +5,8 @@ import shutil
import sys
import tempfile

import torch

try:
    from requests.utils import urlparse
    from requests import get as urlopen
@@ -132,7 +132,3 @@ if tqdm is None:

        sys.stderr.write('\n')


if __name__ == '__main__':
    pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context-4e86fd93.pkl', model_dir='.')
    print(type(pipeline))
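The deleted __main__ block doubled as a smoke test for load_url(). An equivalent standalone check, kept outside the library module, might look like the sketch below; the URL is the internal test endpoint from the removed code and will not resolve publicly, and the module path is an assumption:

    # Sketch only: load_url() downloads (and caches) a pickled object, then unpickles it.
    from fastNLP.api.model_zoo import load_url  # assumed module path for this file

    pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context-4e86fd93.pkl',
                        model_dir='.')  # cache into the current directory
    print(type(pipeline))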
@@ -1,14 +1,15 @@
import torch
from collections import defaultdict
import re
from collections import defaultdict

import torch

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.vocabulary import Vocabulary


class Processor:
class Processor(object):
    def __init__(self, field_name, new_added_field_name):
        self.field_name = field_name
        if new_added_field_name is None:
@@ -17,7 +18,7 @@ class Processor:
        self.new_added_field_name = new_added_field_name

    def process(self, *args, **kwargs):
        pass
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.process(*args, **kwargs)
@@ -132,13 +133,14 @@ class Num2TagProcessor(Processor):


class IndexerProcessor(Processor):
    def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False):
    def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):

        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

        super(IndexerProcessor, self).__init__(field_name, new_added_field_name)
        self.vocab = vocab
        self.delete_old_field = delete_old_field
        self.is_input = is_input

    def set_vocab(self, vocab):
        assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
@@ -146,13 +148,14 @@ class IndexerProcessor(Processor):
        self.vocab = vocab

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            tokens = ins[self.field_name]
            index = [self.vocab.to_index(token) for token in tokens]
            ins[self.new_added_field_name] = index

        dataset._set_need_tensor(**{self.new_added_field_name: True})
        if self.is_input:
            dataset.set_input(self.new_added_field_name)

        if self.delete_old_field:
            dataset.delete_field(self.field_name)
@@ -161,6 +164,9 @@ class IndexerProcessor(Processor):


class VocabProcessor(Processor):
    """Build vocabulary with a field in the data set.

    """
    def __init__(self, field_name):
        super(VocabProcessor, self).__init__(field_name, None)
        self.vocab = Vocabulary()
@@ -178,17 +184,20 @@ class VocabProcessor(Processor):


class SeqLenProcessor(Processor):
    def __init__(self, field_name, new_added_field_name='seq_lens'):
    def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
        super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
        self.is_input = is_input

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            length = len(ins[self.field_name])
            ins[self.new_added_field_name] = length
        dataset._set_need_tensor(**{self.new_added_field_name: True})
        if self.is_input:
            dataset.set_input(self.new_added_field_name)
        return dataset


class ModelProcessor(Processor):
    def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
        """
@@ -238,6 +247,7 @@ class ModelProcessor(Processor):
        device = torch.device(device)
        self.model.to(device)


class Index2WordProcessor(Processor):
    def __init__(self, vocab, field_name, new_added_field_name):
        super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
@@ -251,6 +261,7 @@ class Index2WordProcessor(Processor):


class SetTensorProcessor(Processor):
    # TODO: remove it. It is strange.
    def __init__(self, field_dict, default=False):
        super(SetTensorProcessor, self).__init__(None, None)
        self.field_dict = field_dict
@@ -264,6 +275,7 @@ class SetTensorProcessor(Processor):


class SetIsTargetProcessor(Processor):
    # TODO; remove it.
    def __init__(self, field_dict, default=False):
        super(SetIsTargetProcessor, self).__init__(None, None)
        self.field_dict = field_dict
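A minimal sketch of the reworked IndexerProcessor, assuming fastNLP at this commit is importable; with the new is_input=True default, the indexed field is also registered as a model input via dataset.set_input():

    from fastNLP.api.processor import IndexerProcessor
    from fastNLP.core.dataset import DataSet
    from fastNLP.core.vocabulary import Vocabulary

    ds = DataSet({"words": [["a", "b"], ["b", "c"]]})
    vocab = Vocabulary()
    for ins in ds:                  # build the vocabulary from the field (sketch)
        vocab.update(ins["words"])
    proc = IndexerProcessor(vocab, "words", "word_index", is_input=True)
    ds = proc(ds)                   # Processor.__call__ forwards to process()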
@@ -2,7 +2,7 @@ from .batch import Batch
from .dataset import DataSet
from .fieldarray import FieldArray
from .instance import Instance
from .losses import Loss
from .losses import LossFromTorch
from .optimizer import Optimizer
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
from .tester import Tester
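Call sites must track the rename; the updated import is simply:

    from fastNLP.core import LossFromTorch  # was: from fastNLP.core import Loss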
@@ -9,32 +9,20 @@ from fastNLP.core.utils import get_func_signature
_READERS = {}


def construct_dataset(sentences):
    """Construct a data set from a list of sentences.

    :param sentences: list of list of str
    :return dataset: a DataSet object
    """
    dataset = DataSet()
    for sentence in sentences:
        instance = Instance()
        instance['raw_sentence'] = sentence
        dataset.append(instance)
    return dataset


class DataSet(object):
    """DataSet is the collection of examples.
    DataSet provides instance-level interface. You can append and access an instance of the DataSet.
    However, it stores data in a different way: Field-first, Instance-second.

    """

    def __init__(self, data=None):
        """

        :param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field.
                All values must be of the same length.
                If it is a list, it must be a list of Instance objects.
        :param data: a dict or a list.
                If `data` is a dict, the key is the name of a FieldArray and the value is the FieldArray. All values
                must be of the same length.
                If `data` is a list, it must be a list of Instance objects.
        """
        self.field_arrays = {}
        if data is not None:
@@ -60,6 +48,7 @@ class DataSet(object):
        def iter_func():
            for idx in range(len(self)):
                yield self[idx]

        return iter_func()

    def _inner_iter(self):
@@ -69,7 +58,8 @@ class DataSet(object):
                self.idx = idx

            def __getitem__(self, item):
                assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[self.idx])
                assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[
                    self.idx])
                assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
                return self.dataset.field_arrays[item][self.idx]

@@ -79,6 +69,7 @@ class DataSet(object):
        def inner_iter_func():
            for idx in range(len(self)):
                yield Iter_ptr(self, idx)

        return inner_iter_func()

    def __getitem__(self, idx):
@@ -217,9 +208,17 @@ class DataSet(object):
            raise KeyError("{} is not a valid field name.".format(name))

    def get_input_name(self):
        """Get all field names with `is_input` as True.

        :return list field_names: a list of str
        """
        return [name for name, field in self.field_arrays.items() if field.is_input]

    def get_target_name(self):
        """Get all field names with `is_target` as True.

        :return list field_names: a list of str
        """
        return [name for name, field in self.field_arrays.items() if field.is_target]

    @classmethod
@@ -243,7 +242,7 @@ class DataSet(object):
        :return results: if new_field_name is not passed, returned values of the function over all instances.
        """
        results = [func(ins) for ins in self._inner_iter()]
        if len(list(filter(lambda x: x is not None, results)))==0:  # all None
        if len(list(filter(lambda x: x is not None, results))) == 0:  # all None
            raise ValueError("{} always return None.".format(get_func_signature(func=func)))

        extra_param = {}
@@ -269,6 +268,12 @@ class DataSet(object):
        return results

    def drop(self, func):
        """Drop instances if a condition holds.

        :param func: a function that takes an Instance object as input, and returns bool.
            The instance will be dropped if the function returns True.

        """
        results = [ins for ins in self._inner_iter() if not func(ins)]
        for name, old_field in self.field_arrays.items():
            self.field_arrays[name].content = [ins[name] for ins in results]
@@ -338,10 +343,33 @@ class DataSet(object):
        return cls(_dict)

    def save(self, path):
        """Save the DataSet object as pickle.

        :param str path: the path to the pickle
        """
        with open(path, 'wb') as f:
            pickle.dump(self, f)

    @staticmethod
    def load(path):
        """Load a DataSet object from pickle.

        :param str path: the path to the pickle
        :return DataSet data_set:
        """
        with open(path, 'rb') as f:
            return pickle.load(f)


def construct_dataset(sentences):
    """Construct a data set from a list of sentences.

    :param sentences: list of list of str
    :return dataset: a DataSet object
    """
    dataset = DataSet()
    for sentence in sentences:
        instance = Instance()
        instance['raw_sentence'] = sentence
        dataset.append(instance)
    return dataset
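The new save()/load() pair is a plain pickle round-trip, and drop() filters instances in place. A minimal sketch, assuming fastNLP at this commit:

    from fastNLP.core.dataset import DataSet

    ds = DataSet({"raw_sentence": ["hello world", "how are you today"]})
    ds.save("dataset.pkl")              # pickle.dump(self, f) under the hood
    ds2 = DataSet.load("dataset.pkl")   # staticmethod returning a DataSet

    # drop() removes every instance for which the predicate returns True
    ds2.drop(lambda ins: len(ins["raw_sentence"]) < 12)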
@@ -7,14 +7,13 @@ import torch.nn.functional as F
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import CheckRes
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_function_or_method
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _check_function_or_method
from fastNLP.core.utils import get_func_signature


class LossBase(object):
    def __init__(self):
        # key: name in target function; value: name in output function
        self.param_map = {}
        self._checked = False

@@ -159,8 +158,18 @@ class LossBase(object):

        return loss


class LossFunc(LossBase):
    """A wrapper of user-provided loss function.

    """
    def __init__(self, func, key_map=None, **kwargs):
        """

        :param func: a callable object, such as a function.
        :param dict key_map:
        :param kwargs:
        """
        super(LossFunc, self).__init__()
        _check_function_or_method(func)
        if key_map is not None:
@@ -254,19 +263,19 @@ def _prepare_losser(losser):


def squash(predict, truth, **kwargs):
    '''To reshape tensors in order to fit Loss functions in pytorch
    """To reshape tensors in order to fit loss functions in pytorch

    :param predict : Tensor, model output
    :param truth : Tensor, truth from dataset
    :param **kwargs : extra arguments

    :return predict , truth: predict & truth after processing
    '''
    """
    return predict.view(-1, predict.size()[-1]), truth.view(-1, )


def unpad(predict, truth, **kwargs):
    '''To process padded sequence output to get true loss
    """To process padded sequence output to get true loss
    Using pack_padded_sequence() method
    This method contains squash()

@@ -277,7 +286,7 @@ def unpad(predict, truth, **kwargs):
        the i-th element is true lengths of i-th sequence

    :return predict , truth: predict & truth after processing
    '''
    """
    if kwargs.get("lens") is None:
        return predict, truth
    lens = torch.LongTensor(kwargs["lens"])
@@ -288,7 +297,7 @@ def unpad(predict, truth, **kwargs):


def unpad_mask(predict, truth, **kwargs):
    '''To process padded sequence output to get true loss
    """To process padded sequence output to get true loss
    Using mask() method
    This method contains squash()

@@ -299,7 +308,7 @@ def unpad_mask(predict, truth, **kwargs):
        the i-th element is true lengths of i-th sequence

    :return predict , truth: predict & truth after processing
    '''
    """
    if kwargs.get("lens") is None:
        return predict, truth
    mas = make_mask(kwargs["lens"], truth.size()[1])
@@ -307,7 +316,7 @@ def unpad_mask(predict, truth, **kwargs):


def mask(predict, truth, **kwargs):
    '''To select specific elements from Tensor
    """To select specific elements from Tensor
    This method contains squash()

    :param predict : Tensor, [batch_size , max_len , tag_size]
@@ -317,7 +326,7 @@ def mask(predict, truth, **kwargs):
        the mask Tensor, the position that is 1 will be selected

    :return predict , truth: predict & truth after processing
    '''
    """
    if kwargs.get("mask") is None:
        return predict, truth
    mask = kwargs["mask"]
@@ -332,14 +341,14 @@ def mask(predict, truth, **kwargs):


def make_mask(lens, tar_len):
    '''to generate a mask that select [:lens[i]] for i-th element
    """to generate a mask that select [:lens[i]] for i-th element
    embezzle from fastNLP.models.sequence_modeling.seq_mask

    :param lens : list or LongTensor, [batch_size]
    :param tar_len : int

    :return mask : ByteTensor
    '''
    """
    lens = torch.LongTensor(lens)
    mask = [torch.ge(lens, i + 1) for i in range(tar_len)]
    mask = torch.stack(mask, 1)
@@ -376,9 +385,11 @@ loss_function_name = {
}


class Loss(object):
    """a Loss object is a callable object represents loss functions
class LossFromTorch(object):
    """a LossFromTorch object is a callable object represents loss functions

    This class only helps you with loss functions from PyTorch.
    It has nothing to do with Trainer.
    """

    def __init__(self, loss_name, pre_pro=[squash], **kwargs):
@@ -408,11 +419,11 @@ class Loss(object):
        self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro]

    def add_pre_pro(self, func):
        '''add a pre_pro function
        """add a pre_pro function

        :param func: a function or str, methods to reform parameters before calculating loss
            the strings will be auto translated to pre-defined functions
        '''
        """
        if not callable(func):
            func = method_dict.get(func)
            if func is None:
@@ -421,12 +432,12 @@ class Loss(object):

    @staticmethod
    def _get_loss(loss_name, **kwargs):
        '''Get loss function from torch
        """Get loss function from torch

        :param loss_name: str, the name of loss function
        :param **kwargs: kwargs for torch loss function
        :return: A callable loss function object
        '''
        """
        loss_name = loss_name.strip().lower()
        loss_name = "".join(loss_name.split("_"))

@@ -435,19 +446,19 @@ class Loss(object):
        return loss_function_name[loss_name](**kwargs)

    def get(self):
        '''This method exists just for make some existing codes run error-freely
        '''
        """This method exists just for make some existing codes run error-freely
        """
        return self

    def __call__(self, predict, truth, **kwargs):
        '''call a loss function
        """Call a loss function
        predict and truth will be processed by pre_pro methods in order of addition

        :param predict : Tensor, model output
        :param truth : Tensor, truth from dataset
        :param **kwargs : extra arguments, pass to pre_pro functions
            for example, if used unpad_mask() in pre_pro, there should be a kwarg named lens
        '''
        """
        for f in self.pre_pro:
            if f is None:
                continue
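After the rename, the wrapper is used exactly as Loss was. A minimal sketch with the "nll" entry of loss_function_name and the default pre_pro=[squash], which flattens predictions to 2-D before calling torch's NLLLoss:

    import torch
    from fastNLP.core.losses import LossFromTorch

    loss_func = LossFromTorch("nll")
    y = torch.log(torch.Tensor([[0.3, 0.7], [0.9, 0.1]]))  # log-probabilities, [N, C]
    gy = torch.LongTensor([1, 0])                          # gold labels, [N]
    los = loss_func(y, gy)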
@@ -308,6 +308,13 @@ def _prepare_metrics(metrics):
    return _metrics


"""
Attention: Codes below are not used in current FastNLP.
However, it is useful.

"""


def _conver_numpy(x):
    """convert input data to numpy array

@@ -11,6 +11,12 @@ class Optimizer(object):

class SGD(Optimizer):
    def __init__(self, model_params=None, lr=0.01, momentum=0):
        """

        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
        :param float lr: learning rate. Default: 0.01
        :param float momentum: momentum. Default: 0
        """
        super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

    def construct_from_pytorch(self, model_params):
@@ -23,6 +29,12 @@ class SGD(Optimizer):

class Adam(Optimizer):
    def __init__(self, model_params=None, lr=0.01, weight_decay=0):
        """

        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
        :param float lr: learning rate
        :param float weight_decay:
        """
        super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)

    def construct_from_pytorch(self, model_params):
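With the documented defaults, the wrappers are constructed directly from a PyTorch parameter generator; a sketch with a stand-in model:

    import torch
    from fastNLP.core.optimizer import SGD, Adam

    model = torch.nn.Linear(4, 2)  # stand-in model
    sgd = SGD(model_params=model.parameters(), lr=0.01, momentum=0.9)
    adam = Adam(model_params=model.parameters(), lr=0.01, weight_decay=1e-4)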
@@ -140,7 +140,6 @@ class Trainer(object):
    def train(self):
        """Start Training.

        :return:
        """
        try:
            if torch.cuda.is_available() and self.use_cuda:
@@ -216,14 +215,6 @@ class Trainer(object):
        pbar.close()

    def _print_train(self):
        """

        :param data_iterator:
        :param model:
        :param epoch:
        :param start:
        :return:
        """
        epoch = 1
        start = time.time()
        while epoch <= self.n_epochs:
@@ -29,19 +29,3 @@ class BaseLoader(object):
        with open(cache_path, 'wb') as f:
            pickle.dump(obj, f)
        return obj


class ToyLoader0(BaseLoader):
    """
        For CharLM
    """

    def __init__(self, data_path):
        super(ToyLoader0, self).__init__(data_path)

    def load(self):
        with open(self.data_path, 'r') as f:
            corpus = f.read().lower()
        import re
        corpus = re.sub(r"<unk>", "unk", corpus)
        return corpus.split()
@@ -1,6 +1,152 @@
import configparser
import json
import os

from fastNLP.io.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.base_loader import BaseLoader


class ConfigLoader(BaseLoader):
    """loader for configuration files"""

    def __init__(self, data_path=None):
        super(ConfigLoader, self).__init__()
        if data_path is not None:
            self.config = self.parse(super(ConfigLoader, self).load(data_path))

    @staticmethod
    def parse(string):
        raise NotImplementedError

    @staticmethod
    def load_config(file_path, sections):
        """
        :param file_path: the path of config file
        :param sections: the dict of {section_name(string): Section instance}
        Example:
            test_args = ConfigSection()
            ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
        :return: return nothing, but the value of attributes are saved in sessions
        """
        assert isinstance(sections, dict)
        cfg = configparser.ConfigParser()
        if not os.path.exists(file_path):
            raise FileNotFoundError("config file {} not found. ".format(file_path))
        cfg.read(file_path)
        for s in sections:
            attr_list = [i for i in sections[s].__dict__.keys() if
                         not callable(getattr(sections[s], i)) and not i.startswith("__")]
            if s not in cfg:
                print('section %s not found in config file' % (s))
                continue
            gen_sec = cfg[s]
            for attr in gen_sec.keys():
                try:
                    val = json.loads(gen_sec[attr])
                    # print(s, attr, val, type(val))
                    if attr in attr_list:
                        assert type(val) == type(getattr(sections[s], attr)), \
                            'type not match, except %s but got %s' % \
                            (type(getattr(sections[s], attr)), type(val))
                    """
                    if attr in attr_list then check its type and
                    update its value.
                    else add a new attr in sections[s]
                    """
                    setattr(sections[s], attr, val)
                except Exception as e:
                    print("cannot load attribute %s in section %s"
                          % (attr, s))
                    pass


class ConfigSection(object):

    def __init__(self):
        pass

    def __getitem__(self, key):
        """
        :param key: str, the name of the attribute
        :return attr: the value of this attribute
            if key not in self.__dict__.keys():
                return self[key]
            else:
                raise AttributeError
        """
        if key in self.__dict__.keys():
            return getattr(self, key)
        raise AttributeError("do NOT have attribute %s" % key)

    def __setitem__(self, key, value):
        """
        :param key: str, the name of the attribute
        :param value: the value of this attribute
            if key not in self.__dict__.keys():
                self[key] will be added
            else:
                self[key] will be updated
        """
        if key in self.__dict__.keys():
            if not isinstance(value, type(getattr(self, key))):
                raise AttributeError("attr %s except %s but got %s" %
                                     (key, str(type(getattr(self, key))), str(type(value))))
        setattr(self, key, value)

    def __contains__(self, item):
        """
        :param item: The key of item.
        :return: True if the key in self.__dict__.keys() else False.
        """
        return item in self.__dict__.keys()

    def __eq__(self, other):
        """Overwrite the == operator

        :param other: Another ConfigSection() object which to be compared.
        :return: True if value of each key in each ConfigSection() object are equal to the other, else False.
        """
        for k in self.__dict__.keys():
            if k not in other.__dict__.keys():
                return False
            if getattr(self, k) != getattr(self, k):
                return False

        for k in other.__dict__.keys():
            if k not in self.__dict__.keys():
                return False
            if getattr(self, k) != getattr(self, k):
                return False

        return True

    def __ne__(self, other):
        """Overwrite the != operator

        :param other:
        :return:
        """
        return not self.__eq__(other)

    @property
    def data(self):
        return self.__dict__


if __name__ == "__main__":
    config = ConfigLoader('there is no data')

    section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
    """
        General and My can be found in config file, so the attr and
        value will be updated
        A cannot be found in config file, so nothing will be done
    """

    config.load_config("../../test/data_for_tests/config", section)
    for s in section:
        print(s)
        for attr in section[s].__dict__.keys():
            print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr)))


class ConfigSaver(object):
@@ -125,7 +271,7 @@ class ConfigSaver(object):
                    # logger = create_logger(__name__, "./config_loader.log")
                    # logger.warning("section [%s] in config file [%s] has been changed" % (
                    #     section_name, self.file_path
                    #))
                    # ))
                    change_file = True
                    break
        if not change_file:
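The docstring's own example, expanded into a runnable sketch (the relative config path mirrors the module's __main__ block and is an assumption about the working directory):

    from fastNLP.io.config_io import ConfigLoader, ConfigSection

    test_args = ConfigSection()
    # load_config returns nothing; values land as attributes on the passed sections
    ConfigLoader.load_config("../../test/data_for_tests/config", {"POS_test": test_args})
    print(test_args.data)  # the data property exposes the section's __dict__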
@@ -1,149 +0,0 @@
import configparser
import json
import os

from fastNLP.io.base_loader import BaseLoader


class ConfigLoader(BaseLoader):
    """loader for configuration files"""

    def __init__(self, data_path=None):
        super(ConfigLoader, self).__init__()
        if data_path is not None:
            self.config = self.parse(super(ConfigLoader, self).load(data_path))

    @staticmethod
    def parse(string):
        raise NotImplementedError

    @staticmethod
    def load_config(file_path, sections):
        """
        :param file_path: the path of config file
        :param sections: the dict of {section_name(string): Section instance}
        Example:
            test_args = ConfigSection()
            ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
        :return: return nothing, but the value of attributes are saved in sessions
        """
        assert isinstance(sections, dict)
        cfg = configparser.ConfigParser()
        if not os.path.exists(file_path):
            raise FileNotFoundError("config file {} not found. ".format(file_path))
        cfg.read(file_path)
        for s in sections:
            attr_list = [i for i in sections[s].__dict__.keys() if
                         not callable(getattr(sections[s], i)) and not i.startswith("__")]
            if s not in cfg:
                print('section %s not found in config file' % (s))
                continue
            gen_sec = cfg[s]
            for attr in gen_sec.keys():
                try:
                    val = json.loads(gen_sec[attr])
                    # print(s, attr, val, type(val))
                    if attr in attr_list:
                        assert type(val) == type(getattr(sections[s], attr)), \
                            'type not match, except %s but got %s' % \
                            (type(getattr(sections[s], attr)), type(val))
                    """
                    if attr in attr_list then check its type and
                    update its value.
                    else add a new attr in sections[s]
                    """
                    setattr(sections[s], attr, val)
                except Exception as e:
                    print("cannot load attribute %s in section %s"
                          % (attr, s))
                    pass


class ConfigSection(object):

    def __init__(self):
        pass

    def __getitem__(self, key):
        """
        :param key: str, the name of the attribute
        :return attr: the value of this attribute
            if key not in self.__dict__.keys():
                return self[key]
            else:
                raise AttributeError
        """
        if key in self.__dict__.keys():
            return getattr(self, key)
        raise AttributeError("do NOT have attribute %s" % key)

    def __setitem__(self, key, value):
        """
        :param key: str, the name of the attribute
        :param value: the value of this attribute
            if key not in self.__dict__.keys():
                self[key] will be added
            else:
                self[key] will be updated
        """
        if key in self.__dict__.keys():
            if not isinstance(value, type(getattr(self, key))):
                raise AttributeError("attr %s except %s but got %s" %
                                     (key, str(type(getattr(self, key))), str(type(value))))
        setattr(self, key, value)

    def __contains__(self, item):
        """
        :param item: The key of item.
        :return: True if the key in self.__dict__.keys() else False.
        """
        return item in self.__dict__.keys()

    def __eq__(self, other):
        """Overwrite the == operator

        :param other: Another ConfigSection() object which to be compared.
        :return: True if value of each key in each ConfigSection() object are equal to the other, else False.
        """
        for k in self.__dict__.keys():
            if k not in other.__dict__.keys():
                return False
            if getattr(self, k) != getattr(self, k):
                return False

        for k in other.__dict__.keys():
            if k not in self.__dict__.keys():
                return False
            if getattr(self, k) != getattr(self, k):
                return False

        return True

    def __ne__(self, other):
        """Overwrite the != operator

        :param other:
        :return:
        """
        return not self.__eq__(other)

    @property
    def data(self):
        return self.__dict__


if __name__ == "__main__":
    config = ConfigLoader('there is no data')

    section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
    """
        General and My can be found in config file, so the attr and
        value will be updated
        A cannot be found in config file, so nothing will be done
    """

    config.load_config("../../test/data_for_tests/config", section)
    for s in section:
        print(s)
        for attr in section[s].__dict__.keys():
            print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr)))
@@ -1,4 +1,3 @@
#TODO: need fix for current DataSet
import os

from fastNLP.core.dataset import DataSet
@@ -20,8 +19,7 @@ def convert_seq_dataset(data):
    """
    dataset = DataSet()
    for word_seq in data:
        x = TextField(word_seq, is_target=False)
        dataset.append(Instance(word_seq=x))
        dataset.append(Instance(word_seq=word_seq))
    return dataset


@@ -40,11 +38,7 @@ def convert_seq2tag_dataset(data):
    """
    dataset = DataSet()
    for sample in data:
        word_seq, label = sample[0], sample[1]
        ins = Instance()
        ins.add_field("word_seq", TextField(word_seq, is_target=False)) \
            .add_field("label", LabelField(label, is_target=True))
        dataset.append(ins)
        dataset.append(Instance(word_seq=sample[0], label=sample[1]))
    return dataset


@@ -63,11 +57,7 @@ def convert_seq2seq_dataset(data):
    """
    dataset = DataSet()
    for sample in data:
        word_seq, label_seq = sample[0], sample[1]
        ins = Instance()
        ins.add_field("word_seq", TextField(word_seq, is_target=False)) \
            .add_field("label_seq", TextField(label_seq, is_target=True))
        dataset.append(ins)
        dataset.append(Instance(word_seq=sample[0], label_seq=sample[1]))
    return dataset


@@ -273,85 +263,6 @@ class ClassDataSetLoader(DataSetLoader):
        return convert_seq2tag_dataset(data)


@DataSet.set_reader('read_conll')
class ConllLoader(DataSetLoader):
    """loader for conll format files"""

    def __init__(self):
        """
        :param str data_path: the path to the conll data set
        """
        super(ConllLoader, self).__init__()

    def load(self, data_path):
        """
        :return: list lines: all lines in a conll file
        """
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        data = self.parse(lines)
        return self.convert(data)

    @staticmethod
    def parse(lines):
        """
        :param list lines: a list containing all lines in a conll file.
        :return: a 3D list
        """
        sentences = list()
        tokens = list()
        for line in lines:
            if line[0] == "#":
                # skip the comments
                continue
            if line == "\n":
                sentences.append(tokens)
                tokens = []
                continue
            tokens.append(line.split())
        return sentences

    def convert(self, data):
        pass


@DataSet.set_reader('read_lm')
class LMDataSetLoader(DataSetLoader):
    """Language Model Dataset Loader

    This loader produces data for language model training in a supervised way.
    That means it has X and Y.

    """

    def __init__(self):
        super(LMDataSetLoader, self).__init__()

    def load(self, data_path):
        if not os.path.exists(data_path):
            raise FileNotFoundError("file {} not found.".format(data_path))
        with open(data_path, "r", encoding="utf=8") as f:
            text = " ".join(f.readlines())
        tokens = text.strip().split()
        data = self.sentence_cut(tokens)
        return self.convert(data)

    def sentence_cut(self, tokens, sentence_length=15):
        start_idx = 0
        data_set = []
        for idx in range(len(tokens) // sentence_length):
            x = tokens[start_idx * idx: start_idx * idx + sentence_length]
            y = tokens[start_idx * idx + 1: start_idx * idx + sentence_length + 1]
            if start_idx * idx + sentence_length + 1 >= len(tokens):
                # ad hoc
                y.extend(["<unk>"])
            data_set.append([x, y])
        return data_set

    def convert(self, data):
        pass


@DataSet.set_reader('read_people_daily')
class PeopleDailyCorpusLoader(DataSetLoader):
    """
@@ -403,10 +314,19 @@ class PeopleDailyCorpusLoader(DataSetLoader):
            pos_tag_examples.append([sent_words, sent_pos_tag])
            ner_examples.append([sent_words, sent_ner])
        # List[List[List[str], List[str]]]
        return pos_tag_examples, ner_examples
        # ner_examples not used
        return self.convert(pos_tag_examples)

    def convert(self, data):
        pass
        data_set = DataSet()
        for item in data:
            sent_words, sent_pos_tag = item[0], item[1]
            data_set.append(Instance(words=sent_words, tags=sent_pos_tag))
        data_set.apply(lambda ins: len(ins), new_field_name="seq_len")
        data_set.set_target("tags")
        data_set.set_input("sent_words")
        data_set.set_input("seq_len")
        return data_set


class SNLIDataSetLoader(DataSetLoader):
@@ -462,17 +382,13 @@ class SNLIDataSetLoader(DataSetLoader):
        for example in data:
            p, h, l = example
            # list, list, str
            x1 = TextField(p, is_target=False)
            x2 = TextField(h, is_target=False)
            x1_len = TextField([1] * len(p), is_target=False)
            x2_len = TextField([1] * len(h), is_target=False)
            y = LabelField(l, is_target=True)
            instance = Instance()
            instance.add_field("premise", x1)
            instance.add_field("hypothesis", x2)
            instance.add_field("premise_len", x1_len)
            instance.add_field("hypothesis_len", x2_len)
            instance.add_field("truth", y)
            instance.add_field("premise", p)
            instance.add_field("hypothesis", h)
            instance.add_field("truth", l)
            data_set.append(instance)

        data_set.apply(lambda ins: len(ins["premise"]), new_field_name="premise_len")
        data_set.apply(lambda ins: len(ins["hypothesis"]), new_field_name="hypothesis_len")
        data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len")
        data_set.set_target("truth")
        return data_set
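The reworked SNLI convert() illustrates the new DataSet idiom: store raw lists on instances, derive lengths with apply(), then flag inputs and targets. The same pattern in isolation, as a sketch:

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.instance import Instance

    ds = DataSet()
    ds.append(Instance(premise=["a", "man"], hypothesis=["a", "person"], truth="entailment"))
    ds.apply(lambda ins: len(ins["premise"]), new_field_name="premise_len")
    ds.set_input("premise", "premise_len")
    ds.set_target("truth")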
@@ -1,5 +1,32 @@
import torch

from fastNLP.io.base_loader import BaseLoader


class ModelLoader(BaseLoader):
    """
    Loader for models.
    """

    def __init__(self):
        super(ModelLoader, self).__init__()

    @staticmethod
    def load_pytorch(empty_model, model_path):
        """
        Load model parameters from .pkl files into the empty PyTorch model.
        :param empty_model: a PyTorch model with initialized parameters.
        :param model_path: str, the path to the saved model.
        """
        empty_model.load_state_dict(torch.load(model_path))

    @staticmethod
    def load_pytorch_model(model_path):
        """Load the entire model.

        """
        return torch.load(model_path)


class ModelSaver(object):
    """Save a model
@@ -8,6 +35,7 @@ class ModelSaver(object):
            saver.save_pytorch(model)

    """

    def __init__(self, save_path):
        """
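A save/load round-trip with the merged module, as a sketch; the checkpoint path is illustrative, and the save_pytorch(model) call is the one shown in ModelSaver's docstring:

    import torch
    from fastNLP.io.model_io import ModelLoader, ModelSaver

    model = torch.nn.Linear(4, 2)            # stand-in model
    saver = ModelSaver("./model_ckpt.pkl")
    saver.save_pytorch(model)

    empty = torch.nn.Linear(4, 2)            # same architecture, fresh parameters
    ModelLoader.load_pytorch(empty, "./model_ckpt.pkl")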
@@ -1,28 +0,0 @@
import torch

from fastNLP.io.base_loader import BaseLoader


class ModelLoader(BaseLoader):
    """
    Loader for models.
    """

    def __init__(self):
        super(ModelLoader, self).__init__()

    @staticmethod
    def load_pytorch(empty_model, model_path):
        """
        Load model parameters from .pkl files into the empty PyTorch model.
        :param empty_model: a PyTorch model with initialized parameters.
        :param model_path: str, the path to the saved model.
        """
        empty_model.load_state_dict(torch.load(model_path))

    @staticmethod
    def load_pytorch_model(model_path):
        """Load the entire model.

        """
        return torch.load(model_path)
@@ -5,7 +5,7 @@ sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])

from fastNLP.api.processor import *
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.io.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.config_io import ConfigSection, ConfigLoader

import _pickle as pickle
import torch
@@ -13,11 +13,10 @@ from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField, SeqLabelField
from fastNLP.core.tester import Tester
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.model_loader import ModelLoader
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.io.model_io import ModelLoader, ModelSaver
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.io.model_saver import ModelSaver

BOS = '<BOS>'
EOS = '<EOS>'
@@ -2,8 +2,8 @@ import torch.nn.functional as F

from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.core.utils import ClassPreprocess as Preprocess
from fastNLP.io.config_loader import ConfigLoader
from fastNLP.io.config_loader import ConfigSection
from fastNLP.io.config_io import ConfigLoader
from fastNLP.io.config_io import ConfigSection
from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.aggregator.self_attention import SelfAttention
@@ -3,12 +3,11 @@ import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader
from fastNLP.core.utils import load_pickle
from fastNLP.io.model_saver import ModelSaver
from fastNLP.io.model_loader import ModelLoader
from fastNLP.io.model_io import ModelLoader, ModelSaver
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.core.predictor import SeqLabelInfer
setup.py (4 changed lines)
@@ -12,12 +12,12 @@ with open('requirements.txt', encoding='utf-8') as f:
    reqs = f.read()

setup(
    name='fastNLP',
    name='FastNLP',
    version='0.1.1',
    description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
    long_description=readme,
    license=license,
    author='fudanNLP',
    author='FudanNLP',
    python_requires='>=3.5',
    packages=find_packages(),
    install_requires=reqs.strip().split('\n'),
test/api/test_processor.py (new file, 12 lines)
@@ -0,0 +1,12 @@
import unittest

from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
from fastNLP.core.dataset import DataSet


class TestProcessor(unittest.TestCase):
    def test_FullSpaceToHalfSpaceProcessor(self):
        ds = DataSet({"word": ["00, u1, u), (u2, u2"]})
        proc = FullSpaceToHalfSpaceProcessor("word")
        ds = proc(ds)
        self.assertTrue(ds.field_arrays["word"].content, ["00, u1, u), (u2, u2"])
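The new test runs with the standard unittest runner, e.g. python -m unittest test.api.test_processor from the repository root.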
@@ -45,7 +45,7 @@ class TestLoss(unittest.TestCase):
        # verify the correctness of squash()

        log = math.log
        loss_func = loss.Loss("nll")
        loss_func = loss.LossFromTorch("nll")

        y = tc.Tensor(
            [
@@ -129,7 +129,7 @@ class TestLoss(unittest.TestCase):
        lens = [4, 2, 1]
        y = tc.log(y)

        loss_func = loss.Loss("nll", pre_pro=["unpad"])
        loss_func = loss.LossFromTorch("nll", pre_pro=["unpad"])
        los = loss_func(y, gy, lens=lens)

        r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
@@ -169,7 +169,7 @@ class TestLoss(unittest.TestCase):

        lens = [2, 4, 2]

        loss_func = loss.Loss("nll", pre_pro=["mask"])
        loss_func = loss.LossFromTorch("nll", pre_pro=["mask"])
        los = loss_func(y, gy, mask=mask)

        los2 = loss_func(y, gy, mask=loss.make_mask(lens, gy.size()[-1]))
@@ -205,7 +205,7 @@ class TestLoss(unittest.TestCase):

        y = tc.log(y)

        loss_func = loss.Loss("nll", pre_pro=["unpad_mask"])
        loss_func = loss.LossFromTorch("nll", pre_pro=["unpad_mask"])
        los = loss_func(y, gy, lens=lens)

        r = -log(.1) - log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
@@ -235,7 +235,7 @@ class TestLoss(unittest.TestCase):
        lens = [4, 2, 1]
        y = tc.log(y)

        loss_func = loss.Loss("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0]))
        loss_func = loss.LossFromTorch("nll", pre_pro=[], weight=tc.Tensor([1, 1, 0]))
        loss_func.add_pre_pro("unpad_mask")
        los = loss_func(y, gy, lens=lens)
@@ -1,8 +1,7 @@
import os
import unittest

from fastNLP.io.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.config_saver import ConfigSaver
from fastNLP.io.config_io import ConfigSection, ConfigLoader, ConfigSaver


class TestConfigSaver(unittest.TestCase):