mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 03:07:59 +08:00
add base methods for model.base_model
This commit is contained in:
parent
4f71d44999
commit
7b46f422c7
4
.idea/deployment.xml
Normal file
4
.idea/deployment.xml
Normal file
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" persistUploadOnCheckin="false" />
|
||||
</project>
|
11
.idea/fastNLP.iml
Normal file
11
.idea/fastNLP.iml
Normal file
@ -0,0 +1,11 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
|
||||
</component>
|
||||
</module>
|
4
.idea/misc.xml
Normal file
4
.idea/misc.xml
Normal file
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.5 (PCA_emb)" project-jdk-type="Python SDK" />
|
||||
</project>
|
8
.idea/modules.xml
Normal file
8
.idea/modules.xml
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/fastNLP.iml" filepath="$PROJECT_DIR$/.idea/fastNLP.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
6
.idea/vcs.xml
Normal file
6
.idea/vcs.xml
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
@ -1,4 +1,3 @@
|
||||
Some useful reference:
|
||||
SpaCy "Doc"
|
||||
https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/doc.pyx#L80
|
||||
|
@ -8,10 +8,10 @@ class Action(object):
|
||||
self.logger = None
|
||||
|
||||
def load_config(self, args):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def load_dataset(self, args):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, args):
|
||||
self.logger.log(args)
|
||||
@ -22,7 +22,7 @@ class Action(object):
|
||||
|
||||
def batchify(self, X, Y=None):
|
||||
# a generator
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def make_log(self, *args):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
@ -29,7 +29,7 @@ class Tester(Action):
|
||||
for step in range(iterations):
|
||||
batch_x, batch_y = test_batch_generator.__next__()
|
||||
|
||||
# forward pass from test input to predicted output
|
||||
# forward pass from tests input to predicted output
|
||||
prediction = network.data_forward(batch_x)
|
||||
|
||||
# get the loss
|
||||
|
@ -11,4 +11,4 @@ class Trainer(Action):
|
||||
self.arg = arg
|
||||
|
||||
def train(self, args):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
@ -10,5 +10,4 @@ class ConfigLoader(BaseLoader):
|
||||
|
||||
@staticmethod
|
||||
def parse(string):
|
||||
# To do
|
||||
return string
|
||||
raise NotImplementedError
|
||||
|
20
model/base_model.py
Normal file
20
model/base_model.py
Normal file
@ -0,0 +1,20 @@
|
||||
class BaseModel(object):
|
||||
"""base model for all models"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def prepare_input(self, data):
|
||||
raise NotImplementedError
|
||||
|
||||
def mode(self, test=False):
|
||||
raise NotImplementedError
|
||||
|
||||
def data_forward(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def grad_backward(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def loss(self, pred, truth):
|
||||
raise NotImplementedError
|
@ -1,17 +1,12 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
import
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.datasets as dsets
|
||||
import torchvision.transforms as transforms
|
||||
import dataset as dst
|
||||
from model import CNN_text
|
||||
.dataset as dst
|
||||
from .model import CNN_text
|
||||
from torch.autograd import Variable
|
||||
|
||||
from sklearn import cross_validation
|
||||
from sklearn import datasets
|
||||
|
||||
|
||||
|
||||
# Hyper Parameters
|
||||
batch_size = 50
|
||||
learning_rate = 0.0001
|
||||
@ -51,8 +46,7 @@ if cuda:
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
|
||||
|
||||
|
||||
#train and test
|
||||
# train and tests
|
||||
best_acc = None
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
|
@ -1,13 +1,13 @@
|
||||
import os
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from model import charLM
|
||||
from utilities import *
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
from utilities import *
|
||||
|
||||
|
||||
def to_var(x):
|
||||
if torch.cuda.is_available():
|
||||
x = x.cuda()
|
||||
@ -76,18 +76,18 @@ if __name__ == "__main__":
|
||||
|
||||
|
||||
if os.path.exists("cache/data_sets.pt") is False:
|
||||
|
||||
test_text = read_data("./test.txt")
|
||||
|
||||
test_text = read_data("./tests.txt")
|
||||
test_set = np.array(text2vec(test_text, char_dict, max_word_len))
|
||||
|
||||
# Labels are next-word index in word_dict with the same length as inputs
|
||||
test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]])
|
||||
|
||||
category = {"test": test_set, "tlabel":test_label}
|
||||
category = {"tests": test_set, "tlabel": test_label}
|
||||
torch.save(category, "cache/data_sets.pt")
|
||||
else:
|
||||
data_sets = torch.load("cache/data_sets.pt")
|
||||
test_set = data_sets["test"]
|
||||
test_set = data_sets["tests"]
|
||||
test_label = data_sets["tlabel"]
|
||||
train_set = data_sets["tdata"]
|
||||
train_label = data_sets["trlabel"]
|
||||
|
@ -13,8 +13,7 @@ from .utilities import *
|
||||
|
||||
|
||||
def preprocess():
|
||||
|
||||
word_dict, char_dict = create_word_char_dict("valid.txt", "train.txt", "test.txt")
|
||||
word_dict, char_dict = create_word_char_dict("valid.txt", "train.txt", "tests.txt")
|
||||
num_words = len(word_dict)
|
||||
num_char = len(char_dict)
|
||||
char_dict["BOW"] = num_char+1
|
||||
@ -195,7 +194,7 @@ if __name__=="__main__":
|
||||
if os.path.exists("cache/data_sets.pt") is False:
|
||||
train_text = read_data("./train.txt")
|
||||
valid_text = read_data("./valid.txt")
|
||||
test_text = read_data("./test.txt")
|
||||
test_text = read_data("./tests.txt")
|
||||
|
||||
train_set = np.array(text2vec(train_text, char_dict, max_word_len))
|
||||
valid_set = np.array(text2vec(valid_text, char_dict, max_word_len))
|
||||
@ -206,14 +205,14 @@ if __name__=="__main__":
|
||||
valid_label = np.array([word_dict[w] for w in valid_text[1:]] + [word_dict[valid_text[-1]]])
|
||||
test_label = np.array([word_dict[w] for w in test_text[1:]] + [word_dict[test_text[-1]]])
|
||||
|
||||
category = {"tdata":train_set, "vdata":valid_set, "test": test_set,
|
||||
category = {"tdata": train_set, "vdata": valid_set, "tests": test_set,
|
||||
"trlabel":train_label, "vlabel":valid_label, "tlabel":test_label}
|
||||
torch.save(category, "cache/data_sets.pt")
|
||||
else:
|
||||
data_sets = torch.load("cache/data_sets.pt")
|
||||
train_set = data_sets["tdata"]
|
||||
valid_set = data_sets["vdata"]
|
||||
test_set = data_sets["test"]
|
||||
test_set = data_sets["tests"]
|
||||
train_label = data_sets["trlabel"]
|
||||
valid_label = data_sets["vlabel"]
|
||||
test_label = data_sets["tlabel"]
|
||||
|
@ -5,10 +5,10 @@ class BaseSaver(object):
|
||||
self.save_path = save_path
|
||||
|
||||
def save_bytes(self):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def save_str(self):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def compress(self):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
@ -8,4 +8,4 @@ class Logger(BaseSaver):
|
||||
super(Logger, self).__init__(save_path)
|
||||
|
||||
def log(self, string):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
Loading…
Reference in New Issue
Block a user