mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 11:48:09 +08:00
* add tqdm in requirements.txt
* fix FieldArray type check bugs
This commit is contained in:
parent
661780b975
commit
4b099bb0dd
@ -83,12 +83,12 @@ class FieldArray(object):
|
|||||||
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.")
|
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.")
|
||||||
if len(val) == 0:
|
if len(val) == 0:
|
||||||
raise RuntimeError("Cannot append an empty list.")
|
raise RuntimeError("Cannot append an empty list.")
|
||||||
val_list_type = [type(_) for _ in val] # type check
|
val_list_type = set([type(_) for _ in val]) # type check
|
||||||
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type:
|
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type:
|
||||||
# up-cast int to float
|
# up-cast int to float
|
||||||
val_type = float
|
val_type = float
|
||||||
elif len(val_list_type) == 1:
|
elif len(val_list_type) == 1:
|
||||||
val_type = val_list_type[0]
|
val_type = val_list_type.pop()
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Cannot append a list of {}".format(val_list_type))
|
raise RuntimeError("Cannot append a list of {}".format(val_list_type))
|
||||||
else:
|
else:
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
numpy>=1.14.2
|
numpy>=1.14.2
|
||||||
torch>=0.4.0
|
torch>=0.4.0
|
||||||
tensorboardX
|
tensorboardX
|
||||||
|
tqdm
|
@ -1,8 +1,8 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from fastNLP.core.dataset import DataSet
|
from fastNLP.core.dataset import DataSet
|
||||||
from fastNLP.core.instance import Instance
|
from fastNLP.core.instance import Instance
|
||||||
@ -26,6 +26,7 @@ def prepare_fake_dataset():
|
|||||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
|
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
|
||||||
return data_set
|
return data_set
|
||||||
|
|
||||||
|
|
||||||
def prepare_fake_dataset2(*args, size=100):
|
def prepare_fake_dataset2(*args, size=100):
|
||||||
ys = np.random.randint(4, size=100)
|
ys = np.random.randint(4, size=100)
|
||||||
data = {'y': ys}
|
data = {'y': ys}
|
||||||
@ -33,6 +34,7 @@ def prepare_fake_dataset2(*args, size=100):
|
|||||||
data[arg] = np.random.randn(size, 5)
|
data[arg] = np.random.randn(size, 5)
|
||||||
return DataSet(data=data)
|
return DataSet(data=data)
|
||||||
|
|
||||||
|
|
||||||
class TrainerTestGround(unittest.TestCase):
|
class TrainerTestGround(unittest.TestCase):
|
||||||
def test_case(self):
|
def test_case(self):
|
||||||
data_set = prepare_fake_dataset()
|
data_set = prepare_fake_dataset()
|
||||||
@ -55,15 +57,20 @@ class TrainerTestGround(unittest.TestCase):
|
|||||||
check_code_level=2,
|
check_code_level=2,
|
||||||
use_tqdm=True)
|
use_tqdm=True)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
"""
|
||||||
|
# 应该正确运行
|
||||||
|
"""
|
||||||
|
|
||||||
def test_trainer_suggestion1(self):
|
def test_trainer_suggestion1(self):
|
||||||
# 检查报错提示能否正确提醒用户。
|
# 检查报错提示能否正确提醒用户。
|
||||||
# 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。
|
# 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。
|
||||||
dataset = prepare_fake_dataset2('x')
|
dataset = prepare_fake_dataset2('x')
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fc = nn.Linear(5, 4)
|
self.fc = nn.Linear(5, 4)
|
||||||
|
|
||||||
def forward(self, x1, x2, y):
|
def forward(self, x1, x2, y):
|
||||||
x1 = self.fc(x1)
|
x1 = self.fc(x1)
|
||||||
x2 = self.fc(x2)
|
x2 = self.fc(x2)
|
||||||
@ -72,6 +79,8 @@ class TrainerTestGround(unittest.TestCase):
|
|||||||
return {'loss': loss}
|
return {'loss': loss}
|
||||||
|
|
||||||
model = Model()
|
model = Model()
|
||||||
|
|
||||||
|
with self.assertRaises(NameError):
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
train_data=dataset,
|
train_data=dataset,
|
||||||
model=model
|
model=model
|
||||||
@ -91,10 +100,12 @@ class TrainerTestGround(unittest.TestCase):
|
|||||||
# 这里传入forward需要的数据,看是否可以运行
|
# 这里传入forward需要的数据,看是否可以运行
|
||||||
dataset = prepare_fake_dataset2('x1', 'x2')
|
dataset = prepare_fake_dataset2('x1', 'x2')
|
||||||
dataset.set_input('x1', 'x2', 'y', flag=True)
|
dataset.set_input('x1', 'x2', 'y', flag=True)
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fc = nn.Linear(5, 4)
|
self.fc = nn.Linear(5, 4)
|
||||||
|
|
||||||
def forward(self, x1, x2, y):
|
def forward(self, x1, x2, y):
|
||||||
x1 = self.fc(x1)
|
x1 = self.fc(x1)
|
||||||
x2 = self.fc(x2)
|
x2 = self.fc(x2)
|
||||||
@ -119,10 +130,12 @@ class TrainerTestGround(unittest.TestCase):
|
|||||||
# 这里传入forward需要的数据,但是forward没有返回loss这个key
|
# 这里传入forward需要的数据,但是forward没有返回loss这个key
|
||||||
dataset = prepare_fake_dataset2('x1', 'x2')
|
dataset = prepare_fake_dataset2('x1', 'x2')
|
||||||
dataset.set_input('x1', 'x2', 'y', flag=True)
|
dataset.set_input('x1', 'x2', 'y', flag=True)
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fc = nn.Linear(5, 4)
|
self.fc = nn.Linear(5, 4)
|
||||||
|
|
||||||
def forward(self, x1, x2, y):
|
def forward(self, x1, x2, y):
|
||||||
x1 = self.fc(x1)
|
x1 = self.fc(x1)
|
||||||
x2 = self.fc(x2)
|
x2 = self.fc(x2)
|
||||||
@ -142,7 +155,6 @@ class TrainerTestGround(unittest.TestCase):
|
|||||||
# 应该正确运行
|
# 应该正确运行
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def test_case2(self):
|
def test_case2(self):
|
||||||
# check metrics Wrong
|
# check metrics Wrong
|
||||||
data_set = prepare_fake_dataset2('x1', 'x2')
|
data_set = prepare_fake_dataset2('x1', 'x2')
|
||||||
|
Loading…
Reference in New Issue
Block a user