mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00
7514be6f30
- restructure: move reproduction outside - add evaluate in tester
83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def batch_generator(x, batch_size):
|
|
# x: [num_words, in_channel, height, width]
|
|
# partitions x into batches
|
|
num_step = x.size()[0] // batch_size
|
|
for t in range(num_step):
|
|
yield x[t * batch_size:(t + 1) * batch_size]
|
|
|
|
|
|
def text2vec(words, char_dict, max_word_len):
|
|
""" Return list of list of int """
|
|
word_vec = []
|
|
for word in words:
|
|
vec = [char_dict[ch] for ch in word]
|
|
if len(vec) < max_word_len:
|
|
vec += [char_dict["PAD"] for _ in range(max_word_len - len(vec))]
|
|
vec = [char_dict["BOW"]] + vec + [char_dict["EOW"]]
|
|
word_vec.append(vec)
|
|
return word_vec
|
|
|
|
|
|
def seq2vec(input_words, char_embedding, char_embedding_dim, char_table):
|
|
""" convert the input strings into character embeddings """
|
|
# input_words == list of string
|
|
# char_embedding == torch.nn.Embedding
|
|
# char_embedding_dim == int
|
|
# char_table == list of unique chars
|
|
# Returns: tensor of shape [len(input_words), char_embedding_dim, max_word_len+2]
|
|
max_word_len = max([len(word) for word in input_words])
|
|
print("max_word_len={}".format(max_word_len))
|
|
tensor_list = []
|
|
|
|
start_column = torch.ones(char_embedding_dim, 1)
|
|
end_column = torch.ones(char_embedding_dim, 1)
|
|
|
|
for word in input_words:
|
|
# convert string to word attention
|
|
word_encoding = char_embedding_lookup(word, char_embedding, char_table)
|
|
# add start and end columns
|
|
word_encoding = torch.cat([start_column, word_encoding, end_column], 1)
|
|
# zero-pad right columns
|
|
word_encoding = F.pad(word_encoding, (0, max_word_len - word_encoding.size()[1] + 2)).data
|
|
# create dimension
|
|
word_encoding = word_encoding.unsqueeze(0)
|
|
|
|
tensor_list.append(word_encoding)
|
|
|
|
return torch.cat(tensor_list, 0)
|
|
|
|
|
|
def read_data(file_name):
|
|
# Return: list of strings
|
|
with open(file_name, 'r') as f:
|
|
corpus = f.read().lower()
|
|
import re
|
|
corpus = re.sub(r"<unk>", "unk", corpus)
|
|
return corpus.split()
|
|
|
|
|
|
def get_char_dict(vocabulary):
|
|
# vocabulary == dict of (word, int)
|
|
# Return: dict of (char, int), starting from 1
|
|
char_dict = dict()
|
|
count = 1
|
|
for word in vocabulary:
|
|
for ch in word:
|
|
if ch not in char_dict:
|
|
char_dict[ch] = count
|
|
count += 1
|
|
return char_dict
|
|
|
|
|
|
def create_word_char_dict(*file_name):
|
|
text = []
|
|
for file in file_name:
|
|
text += read_data(file)
|
|
word_dict = {word: ix for ix, word in enumerate(set(text))}
|
|
char_dict = get_char_dict(word_dict)
|
|
return word_dict, char_dict
|