mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 12:17:35 +08:00
22 lines
518 B
Python
22 lines
518 B
Python
from collections import namedtuple
|
|
|
|
import numpy as np
|
|
|
|
from action.trainer import Trainer
|
|
from model.base_model import ToyModel
|
|
|
|
|
|
def test_trainer():
    """Smoke-test ``Trainer.train`` with a ``ToyModel`` and random input.

    A three-field config namedtuple (``epochs``, ``validate``,
    ``save_when_better``) is handed to ``Trainer``; training is then invoked
    with a fresh ``ToyModel`` and two independently drawn random arrays of
    shape (20, 6), passed as the training and dev data respectively.

    The function makes no assertions and returns nothing: it succeeds as
    long as none of the calls raise.
    """
    field_names = ["epochs", "validate", "save_when_better"]
    Config = namedtuple("config", field_names)

    # The trainer is constructed before the model, as in the original order.
    trainer = Trainer(Config(epochs=5, validate=True, save_when_better=True))
    model = ToyModel()

    # Training array is drawn first, then the dev array, both (20, 6).
    rows, cols = 20, 6
    train_array = np.random.rand(rows, cols)
    dev_array = np.random.rand(rows, cols)

    trainer.train(model, train_array, dev_array)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly as a script, without a test runner.
    test_trainer()
|