mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
Merge branch 'trainer' of github.com:FengZiYjun/fastNLP into trainer
# Conflicts: # test/core/test_trainer.py
This commit is contained in:
commit
abe5ec7261
@ -11,7 +11,7 @@ class FieldArray(object):
|
||||
"""
|
||||
|
||||
:param str name: the name of the FieldArray
|
||||
:param list content: a list of int, float, or a list of list.
|
||||
:param list content: a list of int, float, str or np.ndarray, or a list of lists of one of these types.
|
||||
:param int padding_val: the integer for padding. Default: 0.
|
||||
:param bool is_target: If True, this FieldArray is used to compute loss.
|
||||
:param bool is_input: If True, this FieldArray is used as the model input.
|
||||
@ -27,35 +27,46 @@ class FieldArray(object):
|
||||
self.padding_val = padding_val
|
||||
self.is_target = is_target
|
||||
self.is_input = is_input
|
||||
|
||||
self.BASIC_TYPES = (int, float, str, np.ndarray)
|
||||
self.is_2d_list = False
|
||||
self.pytype = self._type_detection(content)
|
||||
self.dtype = self._map_to_np_type(self.pytype)
|
||||
|
||||
@staticmethod
|
||||
def _type_detection(content):
|
||||
def _type_detection(self, content):
|
||||
"""
|
||||
|
||||
:param content: a list of int, float, str or np.ndarray, or a list of lists of one of these types.
|
||||
:return type: one of int, float, str, np.ndarray
|
||||
|
||||
"""
|
||||
if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list):
|
||||
# 2-D list
|
||||
# TODO: refactor
|
||||
type_set = set([type(item) for item in content[0]])
|
||||
else:
|
||||
# 1-D list
|
||||
# content is a 2-D list
|
||||
type_set = set([self._type_detection(x) for x in content])
|
||||
if len(type_set) > 1:
|
||||
raise RuntimeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set))
|
||||
self.is_2d_list = True
|
||||
return type_set.pop()
|
||||
|
||||
elif isinstance(content, list):
|
||||
# content is a 1-D list
|
||||
if len(content) == 0:
|
||||
raise RuntimeError("Cannot create FieldArray with an empty list.")
|
||||
type_set = set([type(item) for item in content])
|
||||
|
||||
if len(type_set) == 1 and any(basic_type in type_set for basic_type in (str, int, float)):
|
||||
return type_set.pop()
|
||||
elif len(type_set) == 2 and float in type_set and int in type_set:
|
||||
# up-cast int to float
|
||||
for idx, _ in enumerate(content):
|
||||
content[idx] = float(content[idx])
|
||||
return float
|
||||
if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES:
|
||||
return type_set.pop()
|
||||
elif len(type_set) == 2 and float in type_set and int in type_set:
|
||||
# up-cast int to float
|
||||
return float
|
||||
else:
|
||||
raise RuntimeError("Cannot create FieldArray with type {}".format(*type_set))
|
||||
else:
|
||||
raise ValueError("Unsupported type conversion detected in FieldArray: {}".format(*type_set))
|
||||
raise RuntimeError("Cannot create FieldArray with type {}".format(type(content)))
|
||||
|
||||
@staticmethod
|
||||
def _map_to_np_type(basic_type):
|
||||
type_mapping = {int: np.int64, float: np.float64, str: np.str}
|
||||
type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray}
|
||||
return type_mapping[basic_type]
|
||||
|
||||
def __repr__(self):
|
||||
@ -64,29 +75,35 @@ class FieldArray(object):
|
||||
def append(self, val):
|
||||
"""Add a new item to the tail of FieldArray.
|
||||
|
||||
:param val: int, float, str, or a list of them.
|
||||
:param val: an int, float or str, or a list of values of one of these types.
|
||||
"""
|
||||
val_type = type(val)
|
||||
if val_type is int and self.pytype is float:
|
||||
# up-cast the appended value
|
||||
val = float(val)
|
||||
elif val_type is float and self.pytype is int:
|
||||
# up-cast all other values in the content
|
||||
for idx, _ in enumerate(self.content):
|
||||
self.content[idx] = float(self.content[idx])
|
||||
if val_type == list: # shape check
|
||||
if self.is_2d_list is False:
|
||||
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.")
|
||||
if len(val) == 0:
|
||||
raise RuntimeError("Cannot append an empty list.")
|
||||
val_list_type = set([type(_) for _ in val]) # type check
|
||||
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type:
|
||||
# up-cast int to float
|
||||
val_type = float
|
||||
elif len(val_list_type) == 1:
|
||||
val_type = val_list_type.pop()
|
||||
else:
|
||||
raise RuntimeError("Cannot append a list of {}".format(val_list_type))
|
||||
else:
|
||||
if self.is_2d_list is True:
|
||||
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.")
|
||||
if val_type == float and self.pytype == int:
|
||||
# up-cast
|
||||
self.pytype = float
|
||||
self.dtype = self._map_to_np_type(self.pytype)
|
||||
elif val_type is list:
|
||||
if len(val) == 0:
|
||||
raise ValueError("Cannot append an empty list.")
|
||||
else:
|
||||
if type(val[0]) != self.pytype:
|
||||
raise ValueError(
|
||||
"Cannot append a list of {}-type value into a {}-tpye FieldArray.".
|
||||
format(type(val[0]), self.pytype))
|
||||
elif val_type != self.pytype:
|
||||
raise ValueError("Cannot append a {}-type value into a {}-tpye FieldArray.".format(val_type, self.pytype))
|
||||
|
||||
elif val_type == int and self.pytype == float:
|
||||
pass
|
||||
elif val_type == self.pytype:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("Cannot append type {} into type {}".format(val_type, self.pytype))
|
||||
self.content.append(val)
|
||||
|
||||
def __getitem__(self, indices):
|
||||
@ -102,7 +119,6 @@ class FieldArray(object):
|
||||
:param indices: an int, or a list of int.
|
||||
:return:
|
||||
"""
|
||||
# TODO: 返回行为不一致,有隐患
|
||||
if isinstance(indices, int):
|
||||
return self.content[indices]
|
||||
assert self.is_input is True or self.is_target is True
|
||||
|
@ -126,6 +126,7 @@ class LossBase(object):
|
||||
for keys, val in target_dict.items():
|
||||
param_val_dict.update({keys: val})
|
||||
|
||||
# TODO: use the origin key to raise error
|
||||
if not self._checked:
|
||||
for keys in args:
|
||||
if param_map[keys] not in param_val_dict.keys():
|
||||
|
@ -1,3 +1,4 @@
|
||||
numpy>=1.14.2
|
||||
torch>=0.4.0
|
||||
tensorboardX
|
||||
tqdm
|
@ -24,19 +24,31 @@ class TestFieldArray(unittest.TestCase):
|
||||
def test_type_conversion(self):
|
||||
fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True)
|
||||
self.assertEqual(fa.pytype, float)
|
||||
self.assertEqual(fa.dtype, np.double)
|
||||
self.assertEqual(fa.dtype, np.float64)
|
||||
|
||||
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
|
||||
fa.append(1.3333)
|
||||
self.assertEqual(fa.pytype, float)
|
||||
self.assertEqual(fa.dtype, np.double)
|
||||
self.assertEqual(fa.dtype, np.float64)
|
||||
|
||||
fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=False)
|
||||
fa.append(10)
|
||||
self.assertEqual(fa.pytype, float)
|
||||
self.assertEqual(fa.dtype, np.double)
|
||||
self.assertEqual(fa.dtype, np.float64)
|
||||
|
||||
fa = FieldArray("y", ["a", "b", "c", "d"], is_input=False)
|
||||
fa.append("e")
|
||||
self.assertEqual(fa.dtype, np.str)
|
||||
self.assertEqual(fa.pytype, str)
|
||||
|
||||
def test_support_np_array(self):
|
||||
fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=False)
|
||||
self.assertEqual(fa.dtype, np.ndarray)
|
||||
|
||||
fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5]))
|
||||
self.assertEqual(fa.pytype, np.ndarray)
|
||||
|
||||
def test_nested_list(self):
|
||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=False)
|
||||
self.assertEqual(fa.pytype, float)
|
||||
self.assertEqual(fa.dtype, np.float64)
|
||||
|
@ -1,8 +1,8 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.instance import Instance
|
||||
@ -27,6 +27,7 @@ def prepare_fake_dataset():
|
||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
|
||||
return data_set
|
||||
|
||||
|
||||
def prepare_fake_dataset2(*args, size=100):
|
||||
ys = np.random.randint(4, size=100)
|
||||
data = {'y': ys}
|
||||
@ -34,6 +35,7 @@ def prepare_fake_dataset2(*args, size=100):
|
||||
data[arg] = np.random.randn(size, 5)
|
||||
return DataSet(data=data)
|
||||
|
||||
|
||||
class TrainerTestGround(unittest.TestCase):
|
||||
def test_case(self):
|
||||
data_set = prepare_fake_dataset()
|
||||
@ -56,15 +58,20 @@ class TrainerTestGround(unittest.TestCase):
|
||||
check_code_level=2,
|
||||
use_tqdm=True)
|
||||
trainer.train()
|
||||
"""
|
||||
# 应该正确运行
|
||||
"""
|
||||
|
||||
def test_trainer_suggestion1(self):
|
||||
# 检查报错提示能否正确提醒用户。
|
||||
# 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。
|
||||
dataset = prepare_fake_dataset2('x')
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(5, 4)
|
||||
|
||||
def forward(self, x1, x2, y):
|
||||
x1 = self.fc(x1)
|
||||
x2 = self.fc(x2)
|
||||
@ -73,10 +80,12 @@ class TrainerTestGround(unittest.TestCase):
|
||||
return {'loss': loss}
|
||||
|
||||
model = Model()
|
||||
trainer = Trainer(
|
||||
train_data=dataset,
|
||||
model=model
|
||||
)
|
||||
|
||||
with self.assertRaises(NameError):
|
||||
trainer = Trainer(
|
||||
train_data=dataset,
|
||||
model=model
|
||||
)
|
||||
"""
|
||||
# 应该获取到的报错提示
|
||||
NameError:
|
||||
@ -92,10 +101,12 @@ class TrainerTestGround(unittest.TestCase):
|
||||
# 这里传入forward需要的数据,看是否可以运行
|
||||
dataset = prepare_fake_dataset2('x1', 'x2')
|
||||
dataset.set_input('x1', 'x2', 'y', flag=True)
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(5, 4)
|
||||
|
||||
def forward(self, x1, x2, y):
|
||||
x1 = self.fc(x1)
|
||||
x2 = self.fc(x2)
|
||||
@ -120,10 +131,12 @@ class TrainerTestGround(unittest.TestCase):
|
||||
# 这里传入forward需要的数据,但是forward没有返回loss这个key
|
||||
dataset = prepare_fake_dataset2('x1', 'x2')
|
||||
dataset.set_input('x1', 'x2', 'y', flag=True)
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(5, 4)
|
||||
|
||||
def forward(self, x1, x2, y):
|
||||
x1 = self.fc(x1)
|
||||
x2 = self.fc(x2)
|
||||
@ -221,7 +234,6 @@ class TrainerTestGround(unittest.TestCase):
|
||||
print_every=2
|
||||
)
|
||||
|
||||
|
||||
def test_case2(self):
|
||||
# check metrics Wrong
|
||||
data_set = prepare_fake_dataset2('x1', 'x2')
|
||||
|
Loading…
Reference in New Issue
Block a user