mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 11:48:09 +08:00
* add tqdm in requirements.txt
* fix FieldArray type check bugs
This commit is contained in:
parent
661780b975
commit
4b099bb0dd
@ -83,12 +83,12 @@ class FieldArray(object):
|
||||
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.")
|
||||
if len(val) == 0:
|
||||
raise RuntimeError("Cannot append an empty list.")
|
||||
val_list_type = [type(_) for _ in val] # type check
|
||||
val_list_type = set([type(_) for _ in val]) # type check
|
||||
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type:
|
||||
# up-cast int to float
|
||||
val_type = float
|
||||
elif len(val_list_type) == 1:
|
||||
val_type = val_list_type[0]
|
||||
val_type = val_list_type.pop()
|
||||
else:
|
||||
raise RuntimeError("Cannot append a list of {}".format(val_list_type))
|
||||
else:
|
||||
|
@ -1,3 +1,4 @@
|
||||
numpy>=1.14.2
|
||||
torch>=0.4.0
|
||||
tensorboardX
|
||||
tqdm
|
@ -1,8 +1,8 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.instance import Instance
|
||||
@ -26,6 +26,7 @@ def prepare_fake_dataset():
|
||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
|
||||
return data_set
|
||||
|
||||
|
||||
def prepare_fake_dataset2(*args, size=100):
|
||||
ys = np.random.randint(4, size=100)
|
||||
data = {'y': ys}
|
||||
@ -33,6 +34,7 @@ def prepare_fake_dataset2(*args, size=100):
|
||||
data[arg] = np.random.randn(size, 5)
|
||||
return DataSet(data=data)
|
||||
|
||||
|
||||
class TrainerTestGround(unittest.TestCase):
|
||||
def test_case(self):
|
||||
data_set = prepare_fake_dataset()
|
||||
@ -55,15 +57,20 @@ class TrainerTestGround(unittest.TestCase):
|
||||
check_code_level=2,
|
||||
use_tqdm=True)
|
||||
trainer.train()
|
||||
"""
|
||||
# 应该正确运行
|
||||
"""
|
||||
|
||||
def test_trainer_suggestion1(self):
|
||||
# 检查报错提示能否正确提醒用户。
|
||||
# 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。
|
||||
dataset = prepare_fake_dataset2('x')
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(5, 4)
|
||||
|
||||
def forward(self, x1, x2, y):
|
||||
x1 = self.fc(x1)
|
||||
x2 = self.fc(x2)
|
||||
@ -72,10 +79,12 @@ class TrainerTestGround(unittest.TestCase):
|
||||
return {'loss': loss}
|
||||
|
||||
model = Model()
|
||||
trainer = Trainer(
|
||||
train_data=dataset,
|
||||
model=model
|
||||
)
|
||||
|
||||
with self.assertRaises(NameError):
|
||||
trainer = Trainer(
|
||||
train_data=dataset,
|
||||
model=model
|
||||
)
|
||||
"""
|
||||
# 应该获取到的报错提示
|
||||
NameError:
|
||||
@ -91,10 +100,12 @@ class TrainerTestGround(unittest.TestCase):
|
||||
# 这里传入forward需要的数据,看是否可以运行
|
||||
dataset = prepare_fake_dataset2('x1', 'x2')
|
||||
dataset.set_input('x1', 'x2', 'y', flag=True)
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(5, 4)
|
||||
|
||||
def forward(self, x1, x2, y):
|
||||
x1 = self.fc(x1)
|
||||
x2 = self.fc(x2)
|
||||
@ -119,10 +130,12 @@ class TrainerTestGround(unittest.TestCase):
|
||||
# 这里传入forward需要的数据,但是forward没有返回loss这个key
|
||||
dataset = prepare_fake_dataset2('x1', 'x2')
|
||||
dataset.set_input('x1', 'x2', 'y', flag=True)
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(5, 4)
|
||||
|
||||
def forward(self, x1, x2, y):
|
||||
x1 = self.fc(x1)
|
||||
x2 = self.fc(x2)
|
||||
@ -142,7 +155,6 @@ class TrainerTestGround(unittest.TestCase):
|
||||
# 应该正确运行
|
||||
"""
|
||||
|
||||
|
||||
def test_case2(self):
|
||||
# check metrics Wrong
|
||||
data_set = prepare_fake_dataset2('x1', 'x2')
|
||||
|
Loading…
Reference in New Issue
Block a user