mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
delete predictor.py
This commit is contained in:
parent
f3a2aa6da5
commit
8445bdbc79
@ -1,79 +0,0 @@
|
||||
"""
|
||||
..todo::
|
||||
检查这个类是否需要
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from . import DataSetIter
|
||||
from . import DataSet
|
||||
from . import SequentialSampler
|
||||
from .utils import _build_args, _move_dict_value_to_device, _get_model_device
|
||||
|
||||
|
||||
class Predictor(object):
|
||||
"""
|
||||
一个根据训练模型预测输出的预测器(Predictor)
|
||||
|
||||
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。
|
||||
这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。
|
||||
|
||||
:param torch.nn.Module network: 用来完成预测任务的模型
|
||||
"""
|
||||
|
||||
def __init__(self, network):
|
||||
if not isinstance(network, torch.nn.Module):
|
||||
raise ValueError(
|
||||
"Only fastNLP.models.BaseModel or torch.nn,Module is allowed, not {}".format(type(network)))
|
||||
self.network = network
|
||||
self.batch_size = 1
|
||||
self.batch_output = []
|
||||
|
||||
def predict(self, data: DataSet, seq_len_field_name=None):
|
||||
"""用已经训练好的模型进行inference.
|
||||
|
||||
:param fastNLP.DataSet data: 待预测的数据集
|
||||
:param str seq_len_field_name: 表示序列长度信息的field名字
|
||||
:return: dict dict里面的内容为模型预测的结果
|
||||
"""
|
||||
if not isinstance(data, DataSet):
|
||||
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
|
||||
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
|
||||
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
|
||||
|
||||
prev_training = self.network.training
|
||||
self.network.eval()
|
||||
network_device = _get_model_device(self.network)
|
||||
batch_output = defaultdict(list)
|
||||
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
|
||||
|
||||
if hasattr(self.network, "predict"):
|
||||
predict_func = self.network.predict
|
||||
else:
|
||||
predict_func = self.network.forward
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_x, _ in data_iterator:
|
||||
_move_dict_value_to_device(batch_x, _, device=network_device)
|
||||
refined_batch_x = _build_args(predict_func, **batch_x)
|
||||
prediction = predict_func(**refined_batch_x)
|
||||
|
||||
if seq_len_field_name is not None:
|
||||
seq_lens = batch_x[seq_len_field_name].tolist()
|
||||
|
||||
for key, value in prediction.items():
|
||||
value = value.cpu().numpy()
|
||||
if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
|
||||
batch_output[key].extend(value.tolist())
|
||||
else:
|
||||
if seq_len_field_name is not None:
|
||||
tmp_batch = []
|
||||
for idx, seq_len in enumerate(seq_lens):
|
||||
tmp_batch.append(value[idx, :seq_len])
|
||||
batch_output[key].extend(tmp_batch)
|
||||
else:
|
||||
batch_output[key].append(value)
|
||||
|
||||
self.network.train(prev_training)
|
||||
return batch_output
|
@ -1,48 +0,0 @@
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.instance import Instance
|
||||
from fastNLP.core.predictor import Predictor
|
||||
|
||||
|
||||
def prepare_fake_dataset():
|
||||
mean = np.array([-3, -3])
|
||||
cov = np.array([[1, 0], [0, 1]])
|
||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
|
||||
|
||||
mean = np.array([3, 3])
|
||||
cov = np.array([[1, 0], [0, 1]])
|
||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
|
||||
|
||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
|
||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
|
||||
return data_set
|
||||
|
||||
|
||||
class LinearModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(LinearModel, self).__init__()
|
||||
self.linear = torch.nn.Linear(2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return {"predict": self.linear(x)}
|
||||
|
||||
|
||||
class TestPredictor(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
model = LinearModel()
|
||||
predictor = Predictor(model)
|
||||
data = prepare_fake_dataset()
|
||||
data.set_input("x")
|
||||
ans = predictor.predict(data)
|
||||
self.assertTrue(isinstance(ans, defaultdict))
|
||||
self.assertTrue("predict" in ans)
|
||||
self.assertTrue(isinstance(ans["predict"], list))
|
||||
|
||||
def test_sequence(self):
|
||||
# test sequence input/output
|
||||
pass
|
Loading…
Reference in New Issue
Block a user