mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 03:07:59 +08:00
* FieldArray添加对list of np.array的支持
* 添加测试:FieldArray的初始化
This commit is contained in:
parent
e4f997d52a
commit
b93ca9bb30
@ -112,13 +112,17 @@ class FieldArray(object):
|
|||||||
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])])
|
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])])
|
||||||
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])
|
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])
|
||||||
|
|
||||||
注意:np.array必须仅在最外层,即np.array([np.array, np.array]) 和 list of np.array不考虑
|
|
||||||
类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。
|
类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
content = content
|
# 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list
|
||||||
|
# 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list]
|
||||||
|
if len(content) == 1 and isinstance(content[0], np.ndarray):
|
||||||
|
# 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field])
|
||||||
|
# 将[np.array] 转化为 list of list
|
||||||
|
content[0] = content[0].tolist()
|
||||||
elif isinstance(content, np.ndarray):
|
elif isinstance(content, np.ndarray):
|
||||||
content = content.tolist() # convert np.ndarray into 2-D list
|
content = content.tolist() # convert np.ndarray into 2-D list
|
||||||
else:
|
else:
|
||||||
|
@ -144,6 +144,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--train", type=str, help="training conll file", default="/home/zyfeng/data/sample.conllx")
|
parser.add_argument("--train", type=str, help="training conll file", default="/home/zyfeng/data/sample.conllx")
|
||||||
parser.add_argument("--dev", type=str, help="dev conll file", default="/home/zyfeng/data/sample.conllx")
|
parser.add_argument("--dev", type=str, help="dev conll file", default="/home/zyfeng/data/sample.conllx")
|
||||||
parser.add_argument("--test", type=str, help="test conll file", default=None)
|
parser.add_argument("--test", type=str, help="test conll file", default=None)
|
||||||
|
parser.add_argument("--save", type=str, help="path to save", default=None)
|
||||||
|
|
||||||
parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training")
|
parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training")
|
||||||
parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model")
|
parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model")
|
||||||
|
@ -5,8 +5,59 @@ import numpy as np
|
|||||||
from fastNLP.core.fieldarray import FieldArray
|
from fastNLP.core.fieldarray import FieldArray
|
||||||
|
|
||||||
|
|
||||||
|
class TestFieldArrayInit(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray:
|
||||||
|
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
|
||||||
|
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])})
|
||||||
|
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]})
|
||||||
|
2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray;
|
||||||
|
然后后面的样本使用FieldArray.append进行添加。
|
||||||
|
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])])
|
||||||
|
2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))])
|
||||||
|
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])])
|
||||||
|
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_init_v1(self):
|
||||||
|
# 二维list
|
||||||
|
fa = FieldArray("x", [[1, 2], [3, 4]] * 5, is_input=True)
|
||||||
|
|
||||||
|
def test_init_v2(self):
|
||||||
|
# 二维array
|
||||||
|
fa = FieldArray("x", np.array([[1, 2], [3, 4]] * 5), is_input=True)
|
||||||
|
|
||||||
|
def test_init_v3(self):
|
||||||
|
# 三维list
|
||||||
|
fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True)
|
||||||
|
|
||||||
|
def test_init_v4(self):
|
||||||
|
# 一维list
|
||||||
|
val = [1, 2, 3, 4]
|
||||||
|
fa = FieldArray("x", [val], is_input=True)
|
||||||
|
fa.append(val)
|
||||||
|
|
||||||
|
def test_init_v5(self):
|
||||||
|
# 一维array
|
||||||
|
val = np.array([1, 2, 3, 4])
|
||||||
|
fa = FieldArray("x", [val], is_input=True)
|
||||||
|
fa.append(val)
|
||||||
|
|
||||||
|
def test_init_v6(self):
|
||||||
|
# 二维list
|
||||||
|
val = [[1, 2], [3, 4]]
|
||||||
|
fa = FieldArray("x", [val], is_input=True)
|
||||||
|
fa.append(val)
|
||||||
|
|
||||||
|
def test_init_v7(self):
|
||||||
|
# 二维array
|
||||||
|
val = np.array([[1, 2], [3, 4]])
|
||||||
|
fa = FieldArray("x", [val], is_input=True)
|
||||||
|
fa.append(val)
|
||||||
|
|
||||||
|
|
||||||
class TestFieldArray(unittest.TestCase):
|
class TestFieldArray(unittest.TestCase):
|
||||||
def test(self):
|
def test_main(self):
|
||||||
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
|
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
|
||||||
self.assertEqual(len(fa), 5)
|
self.assertEqual(len(fa), 5)
|
||||||
fa.append(6)
|
fa.append(6)
|
||||||
|
@ -408,12 +408,12 @@ class TestTutorial(unittest.TestCase):
|
|||||||
model=model,
|
model=model,
|
||||||
loss=CrossEntropyLoss(pred='pred', target='label'),
|
loss=CrossEntropyLoss(pred='pred', target='label'),
|
||||||
metrics=AccuracyMetric(),
|
metrics=AccuracyMetric(),
|
||||||
n_epochs=5,
|
n_epochs=3,
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
print_every=-1,
|
print_every=-1,
|
||||||
validate_every=-1,
|
validate_every=-1,
|
||||||
dev_data=dev_data,
|
dev_data=dev_data,
|
||||||
use_cuda=True,
|
use_cuda=False,
|
||||||
optimizer=Adam(lr=1e-3, weight_decay=0),
|
optimizer=Adam(lr=1e-3, weight_decay=0),
|
||||||
check_code_level=-1,
|
check_code_level=-1,
|
||||||
metric_key='acc',
|
metric_key='acc',
|
||||||
|
Loading…
Reference in New Issue
Block a user