mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-02 04:07:35 +08:00)
Merge pull request #4 from choosewhatulike/paper-implement
Implementation of Hierarchical Attention Networks for Document Classification
This commit is contained in: commit f5d858e7d2
36 HAN-document_classification/README.md Normal file
@@ -0,0 +1,36 @@
## Introduction
This is a PyTorch implementation of the paper [Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf).

* The dataset consists of 600k documents extracted from [Yelp 2018](https://www.yelp.com/dataset) customer reviews.
* [NLTK](http://www.nltk.org/) and [Stanford CoreNLP](https://stanfordnlp.github.io/CoreNLP/) are used to split documents into sentences and tokenize them.
* Both CPU and GPU are supported.
* The best accuracy is 71%, matching the performance reported in the paper.

## Requirements
* python 3.6
* pytorch = 0.3.0
* numpy
* gensim
* nltk
* coreNLP

## Parameters
Following the paper and my experiments, I set the model parameters as follows:

|word embedding dimension|GRU hidden size|GRU layers|word/sentence context vector dimension|
|---|---|---|---|
|200|50|1|100|

And the training parameters:

|epochs|learning rate|momentum|batch size|
|---|---|---|---|
|3|0.01|0.9|64|
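
These settings correspond to how the network, optimizer, and loss are constructed in `model.py` and `train.py` (shown here for reference):

```
net = HAN(input_size=200, output_size=5,
          word_hidden_size=50, word_num_layers=1, word_context_size=100,
          sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
criterion = nn.NLLLoss()
```

Here `output_size=5` is the number of star-rating classes, and the sentence-level layer takes `2 * word_hidden_size` inputs because the word-level GRU is bidirectional.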
## Run
1. Prepare the dataset. Download the [Yelp dataset](https://www.yelp.com/dataset), unzip it, and place the customer reviews in a single file. Run `preprocess.py` to transform this file into the dataset the model takes as input (the resulting format is sketched after this list).
2. Train the model. The word embeddings of the training data are stored in `yelp.word2vec`. The model is trained and automatically saved to `model.dict`.
```
python train.py
```
3. Test the model.
```
python evaluate.py
```
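
For reference, `preprocess.py` writes the tokenized reviews into the `reviews/` directory as pickled chunks of up to 5000 entries, each entry a `(stars, sentences)` pair where `sentences` is a list of token lists. A minimal sketch of inspecting one chunk (the file name `samples0.pkl` is just one instance of the naming pattern used by the script):

```
import pickle

# each chunk holds up to 5000 reviews as (stars, sentences) tuples,
# where sentences is a list of CoreNLP token lists, one per sentence
with open('reviews/samples0.pkl', 'rb') as f:
    samples = pickle.load(f)

stars, sentences = samples[0]
print(stars)         # star rating, 1-5
print(sentences[0])  # tokens of the first sentence
```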
BIN HAN-document_classification/data/test_samples.pkl Normal file
Binary file not shown.
BIN HAN-document_classification/data/train_samples.pkl Normal file
Binary file not shown.
BIN HAN-document_classification/data/yelp.word2vec Normal file
Binary file not shown.
44 HAN-document_classification/evaluate.py Normal file
@@ -0,0 +1,44 @@
from model import *
from train import *


def evaluate(net, dataset, batch_size=64, use_cuda=False):
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate, num_workers=0)
    count = 0
    if use_cuda:
        net.cuda()
    for i, batch_samples in enumerate(dataloader):
        x, y = batch_samples
        doc_list = []
        for sample in x:
            doc = []
            for sent_vec in sample:
                if use_cuda:
                    sent_vec = sent_vec.cuda()
                doc.append(Variable(sent_vec, volatile=True))
            doc_list.append(pack_sequence(doc))
        if use_cuda:
            y = y.cuda()
        predicts = net(doc_list)
        p, idx = torch.max(predicts, dim=1)
        idx = idx.data
        count += torch.sum(torch.eq(idx, y))
    return count


if __name__ == '__main__':
    '''
        Evaluate the performance of the model
    '''
    from gensim.models import Word2Vec

    embed_model = Word2Vec.load('yelp.word2vec')
    embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
    del embed_model

    net = HAN(input_size=200, output_size=5,
              word_hidden_size=50, word_num_layers=1, word_context_size=100,
              sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
    net.load_state_dict(torch.load('model.dict'))
    test_dataset = YelpDocSet('reviews', 199, 4, embedding)
    correct = evaluate(net, test_dataset, use_cuda=True)
    print('accuracy {}'.format(correct / len(test_dataset)))
110 HAN-document_classification/model.py Normal file
@@ -0,0 +1,110 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F


def pack_sequence(tensor_seq, padding_value=0.0):
    # pad a list of variable-length (len_i, dim) tensors into one (batch, max_len, dim) Variable
    if len(tensor_seq) <= 0:
        return
    length = [v.size(0) for v in tensor_seq]
    max_len = max(length)
    size = [len(tensor_seq), max_len]
    size.extend(list(tensor_seq[0].size()[1:]))
    ans = torch.Tensor(*size).fill_(padding_value)
    if tensor_seq[0].data.is_cuda:
        ans = ans.cuda()
    ans = Variable(ans)
    for i, v in enumerate(tensor_seq):
        ans[i, :length[i], :] = v
    return ans


class HAN(nn.Module):
    def __init__(self, input_size, output_size,
                 word_hidden_size, word_num_layers, word_context_size,
                 sent_hidden_size, sent_num_layers, sent_context_size):
        super(HAN, self).__init__()

        self.word_layer = AttentionNet(input_size,
                                       word_hidden_size,
                                       word_num_layers,
                                       word_context_size)
        self.sent_layer = AttentionNet(2 * word_hidden_size,
                                       sent_hidden_size,
                                       sent_num_layers,
                                       sent_context_size)
        self.output_layer = nn.Linear(2 * sent_hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, batch_doc):
        # input is a sequence of matrices, one per document
        doc_vec_list = []
        for doc in batch_doc:
            sent_mat = self.word_layer(doc)  # doc's dim (num_sent, seq_len, word_dim)
            doc_vec_list.append(sent_mat)    # sent_mat's dim (num_sent, vec_dim)
        doc_vec = self.sent_layer(pack_sequence(doc_vec_list))
        output = self.softmax(self.output_layer(doc_vec))
        return output


class AttentionNet(nn.Module):
    def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size):
        super(AttentionNet, self).__init__()

        self.input_size = input_size
        self.gru_hidden_size = gru_hidden_size
        self.gru_num_layers = gru_num_layers
        self.context_vec_size = context_vec_size

        # Encoder
        self.gru = nn.GRU(input_size=input_size,
                          hidden_size=gru_hidden_size,
                          num_layers=gru_num_layers,
                          batch_first=True,
                          bidirectional=True)
        # Attention
        self.fc = nn.Linear(2 * gru_hidden_size, context_vec_size)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        # context vector
        self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1))
        self.context_vec.data.uniform_(-0.1, 0.1)

    def forward(self, inputs):
        # GRU part
        h_t, hidden = self.gru(inputs)  # inputs's dim (batch_size, seq_len, word_dim)
        u = self.tanh(self.fc(h_t))
        # Attention part
        alpha = self.softmax(torch.matmul(u, self.context_vec))  # u's dim (batch_size, seq_len, context_vec_size)
        output = torch.bmm(torch.transpose(h_t, 1, 2), alpha)    # alpha's dim (batch_size, seq_len, 1)
        return torch.squeeze(output, dim=2)  # output's dim (batch_size, 2*hidden_size, 1) before the squeeze


if __name__ == '__main__':
    '''
        Test the model correctness
    '''
    import numpy as np
    use_cuda = True
    net = HAN(input_size=200, output_size=5,
              word_hidden_size=50, word_num_layers=1, word_context_size=100,
              sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.NLLLoss()
    test_time = 10
    batch_size = 64
    if use_cuda:
        net.cuda()
    print('test training')
    for step in range(test_time):
        # random documents of shape (num_sent, seq_len, word_dim)
        x_data = [torch.randn(np.random.randint(1, 10), 200, 200) for i in range(batch_size)]
        y_data = torch.LongTensor([np.random.randint(0, 5) for i in range(batch_size)])
        if use_cuda:
            x_data = [x_i.cuda() for x_i in x_data]
            y_data = y_data.cuda()
        x = [Variable(x_i) for x_i in x_data]
        y = Variable(y_data)
        predict = net(x)
        loss = criterion(predict, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss.data[0])
51 HAN-document_classification/preprocess.py Normal file
@@ -0,0 +1,51 @@
'''
    Tokenize the Yelp dataset's documents using Stanford CoreNLP
'''
import pickle
import json
import nltk
from nltk.tokenize import stanford
import os

input_filename = 'review.json'

# config for Stanford CoreNLP
os.environ['JAVAHOME'] = 'D:\\java\\bin\\java.exe'
path_to_jar = 'E:\\College\\fudanNLP\\stanford-corenlp-full-2018-02-27\\stanford-corenlp-3.9.1.jar'
tokenizer = stanford.CoreNLPTokenizer()

in_dirname = 'review'
out_dirname = 'reviews'


# split the raw JSON reviews into pickled chunks of 5000 (stars, text) pairs
f = open(input_filename, encoding='utf-8')
samples = []
j = 0
for i, line in enumerate(f.readlines()):
    review = json.loads(line)
    samples.append((review['stars'], review['text']))
    if (i + 1) % 5000 == 0:
        print(i)
        pickle.dump(samples, open(in_dirname + '/samples%d.pkl' % j, 'wb'))
        j += 1
        samples = []
pickle.dump(samples, open(in_dirname + '/samples%d.pkl' % j, 'wb'))
f.close()
# samples = pickle.load(open(out_dirname + '/samples0.pkl', 'rb'))
# print(samples[0])


# sentence-split each review with NLTK, then tokenize each sentence with CoreNLP
for fn in os.listdir(in_dirname):
    print(fn)
    processed = []
    for stars, text in pickle.load(open(os.path.join(in_dirname, fn), 'rb')):
        tokens = []
        sents = nltk.tokenize.sent_tokenize(text)
        for s in sents:
            tokens.append(tokenizer.tokenize(s))
        processed.append((stars, tokens))
        # print(tokens)
        if len(processed) % 100 == 0:
            print(len(processed))
    pickle.dump(processed, open(os.path.join(out_dirname, fn), 'wb'))
167 HAN-document_classification/train.py Normal file
@@ -0,0 +1,167 @@
import os
import pickle

import nltk
import numpy as np
import torch
from gensim import models
from torch.utils.data import DataLoader, Dataset

from model import *


class SentIter:
    def __init__(self, dirname, count):
        self.dirname = dirname
        self.count = int(count)

    def __iter__(self):
        for fname in os.listdir(self.dirname)[:self.count]:
            with open(os.path.join(self.dirname, fname), 'rb') as f:
                for y, x in pickle.load(f):
                    for sent in x:
                        yield sent


def train_word_vec():
    # load data
    dirname = 'reviews'
    sents = SentIter(dirname, 238)
    # define model and train
    model = models.Word2Vec(size=200, sg=0, workers=4, min_count=5)
    model.build_vocab(sents)
    model.train(sents, total_examples=model.corpus_count, epochs=10)
    model.save('yelp.word2vec')
    print(model.wv.similarity('woman', 'man'))
    print(model.wv.similarity('nice', 'awful'))


class Embedding_layer:
    def __init__(self, wv, vector_size):
        self.wv = wv
        self.vector_size = vector_size

    def get_vec(self, w):
        try:
            v = self.wv[w]
        except KeyError:
            # out-of-vocabulary words get a random vector
            v = np.random.randn(self.vector_size)
        return v


class YelpDocSet(Dataset):
    def __init__(self, dirname, start_file, num_files, embedding):
        self.dirname = dirname
        self.num_files = num_files
        self._files = os.listdir(dirname)[start_file:start_file + num_files]
        self.embedding = embedding
        # cache of up to 5 loaded chunk files, indexed by file_id % 5
        self._cache = [(-1, None) for i in range(5)]

    def get_doc(self, n):
        file_id = n // 5000
        idx = file_id % 5
        if self._cache[idx][0] != file_id:
            with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
                self._cache[idx] = (file_id, pickle.load(f))
        y, x = self._cache[idx][1][n % 5000]
        sents = []
        for s_list in x:
            sents.append(' '.join(s_list))
        x = '\n'.join(sents)
        return x, y - 1

    def __len__(self):
        return len(self._files) * 5000

    def __getitem__(self, n):
        file_id = n // 5000
        idx = file_id % 5
        if self._cache[idx][0] != file_id:
            print('load {} to {}'.format(file_id, idx))
            with open(os.path.join(self.dirname, self._files[file_id]), 'rb') as f:
                self._cache[idx] = (file_id, pickle.load(f))
        y, x = self._cache[idx][1][n % 5000]
        doc = []
        for sent in x:
            if len(sent) == 0:
                continue
            sent_vec = []
            for word in sent:
                vec = self.embedding.get_vec(word)
                sent_vec.append(vec.tolist())
            sent_vec = torch.Tensor(sent_vec)
            doc.append(sent_vec)
        if len(doc) == 0:
            # empty documents get a single all-zero sentence
            doc = [torch.zeros(1, 200)]
        return doc, y - 1


def collate(iterable):
    y_list = []
    x_list = []
    for x, y in iterable:
        y_list.append(y)
        x_list.append(x)
    return x_list, torch.LongTensor(y_list)


def train(net, dataset, num_epoch, batch_size, print_size=10, use_cuda=False):
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.NLLLoss()

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            collate_fn=collate,
                            num_workers=0)
    running_loss = 0.0

    if use_cuda:
        net.cuda()
    print('start training')
    for epoch in range(num_epoch):
        for i, batch_samples in enumerate(dataloader):
            x, y = batch_samples
            doc_list = []
            for sample in x:
                doc = []
                for sent_vec in sample:
                    if use_cuda:
                        sent_vec = sent_vec.cuda()
                    doc.append(Variable(sent_vec))
                doc_list.append(pack_sequence(doc))
            if use_cuda:
                y = y.cuda()
            y = Variable(y)
            predict = net(doc_list)
            loss = criterion(predict, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.data[0]
            if i % print_size == print_size - 1:
                print('{}, {}'.format(i + 1, running_loss / print_size))
                running_loss = 0.0
                # periodic checkpoint
                torch.save(net.state_dict(), 'model.dict')
    # final save
    torch.save(net.state_dict(), 'model.dict')


if __name__ == '__main__':
    '''
        Train process
    '''
    from gensim.models import Word2Vec

    train_word_vec()

    embed_model = Word2Vec.load('yelp.word2vec')
    embedding = Embedding_layer(embed_model.wv, embed_model.wv.vector_size)
    del embed_model
    start_file = 0
    dataset = YelpDocSet('reviews', start_file, 120 - start_file, embedding)
    print('training data size {}'.format(len(dataset)))
    net = HAN(input_size=200, output_size=5,
              word_hidden_size=50, word_num_layers=1, word_context_size=100,
              sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
    try:
        net.load_state_dict(torch.load('model.dict'))
        print("loaded the model trained last time")
    except Exception:
        print("cannot load model, training from the initial model")

    train(net, dataset, num_epoch=5, batch_size=64, use_cuda=True)