fastNLP/reproduction/Char-aware_NLM/utilities.py
FengZiYjun 7514be6f30 - add validation loss into trainer.train
- restructure: move reproduction outside
- add evaluate in tester
2018-07-11 21:51:35 +08:00

83 lines
2.6 KiB
Python

import torch
import torch.nn.functional as F
def batch_generator(x, batch_size):
# x: [num_words, in_channel, height, width]
# partitions x into batches
num_step = x.size()[0] // batch_size
for t in range(num_step):
yield x[t * batch_size:(t + 1) * batch_size]
def text2vec(words, char_dict, max_word_len):
""" Return list of list of int """
word_vec = []
for word in words:
vec = [char_dict[ch] for ch in word]
if len(vec) < max_word_len:
vec += [char_dict["PAD"] for _ in range(max_word_len - len(vec))]
vec = [char_dict["BOW"]] + vec + [char_dict["EOW"]]
word_vec.append(vec)
return word_vec
def seq2vec(input_words, char_embedding, char_embedding_dim, char_table):
""" convert the input strings into character embeddings """
# input_words == list of string
# char_embedding == torch.nn.Embedding
# char_embedding_dim == int
# char_table == list of unique chars
# Returns: tensor of shape [len(input_words), char_embedding_dim, max_word_len+2]
max_word_len = max([len(word) for word in input_words])
print("max_word_len={}".format(max_word_len))
tensor_list = []
start_column = torch.ones(char_embedding_dim, 1)
end_column = torch.ones(char_embedding_dim, 1)
for word in input_words:
# convert string to word attention
word_encoding = char_embedding_lookup(word, char_embedding, char_table)
# add start and end columns
word_encoding = torch.cat([start_column, word_encoding, end_column], 1)
# zero-pad right columns
word_encoding = F.pad(word_encoding, (0, max_word_len - word_encoding.size()[1] + 2)).data
# create dimension
word_encoding = word_encoding.unsqueeze(0)
tensor_list.append(word_encoding)
return torch.cat(tensor_list, 0)
def read_data(file_name):
# Return: list of strings
with open(file_name, 'r') as f:
corpus = f.read().lower()
import re
corpus = re.sub(r"<unk>", "unk", corpus)
return corpus.split()
def get_char_dict(vocabulary):
# vocabulary == dict of (word, int)
# Return: dict of (char, int), starting from 1
char_dict = dict()
count = 1
for word in vocabulary:
for ch in word:
if ch not in char_dict:
char_dict[ch] = count
count += 1
return char_dict
def create_word_char_dict(*file_name):
text = []
for file in file_name:
text += read_data(file)
word_dict = {word: ix for ix, word in enumerate(set(text))}
char_dict = get_char_dict(word_dict)
return word_dict, char_dict