Mirror of https://gitee.com/fastnlp/fastNLP.git
Synced 2024-12-02 04:07:35 +08:00
1. Fix a potential multi-step update bug in Trainer; 2. Modify LSTM to work under data parallelism; 3. Fix a bug in embed_loader, and allow manual initialization;
This commit is contained in:
parent 8b8d184026
commit 9c1b4914d8
@@ -548,7 +548,7 @@ class LRScheduler(Callback):
         else:
             raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")
 
-    def on_epoch_begin(self):
+    def on_epoch_end(self):
         self.scheduler.step(self.epoch)
 
 
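With this change the wrapped scheduler steps at the end of each epoch instead of at the beginning. A hedged usage sketch (the model and Trainer wiring are assumed, not shown in this diff):

    from torch.optim import SGD
    from torch.optim.lr_scheduler import StepLR

    optimizer = SGD(model.parameters(), lr=0.1)
    lr_callback = LRScheduler(StepLR(optimizer, step_size=10, gamma=0.5))
    # pass to the Trainer via callbacks=[lr_callback]; the LR now decays after each epoch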
@@ -801,17 +801,19 @@ class DataSet(object):
         else:
             return DataSet()
 
-    def split(self, ratio):
+    def split(self, ratio, shuffle=True):
         """
         Split the DataSet by ratio and return two DataSets.
 
         :param float ratio: 0 < ratio < 1; the first returned DataSet holds `(1-ratio)` of the data, the second holds `ratio`
+        :param bool shuffle: whether to shuffle before splitting
         :return: [DataSet, DataSet]
         """
         assert isinstance(ratio, float)
         assert 0 < ratio < 1
         all_indices = [_ for _ in range(len(self))]
-        np.random.shuffle(all_indices)
+        if shuffle:
+            np.random.shuffle(all_indices)
         split = int(ratio * len(self))
         dev_indices = all_indices[:split]
         train_indices = all_indices[split:]
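A quick sketch of the new flag (per the docstring, the first returned DataSet holds `(1-ratio)` of the data; dataset construction is elided):

    train_ds, dev_ds = dataset.split(0.1)                 # shuffled split; dev gets ~10% of the data
    train_ds, dev_ds = dataset.split(0.1, shuffle=False)  # preserve the original instance order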
@@ -36,6 +36,23 @@ class Optimizer(object):
         """
         return [param for param in params if param.requires_grad]
 
+
+class NullOptimizer(Optimizer):
+    """
+    Pass in this optimizer when you do not want the Trainer to update parameters; make sure the parameters
+    are updated through a callback instead.
+
+    """
+    def __init__(self):
+        super().__init__(None)
+
+    def construct_from_pytorch(self, model_params):
+        pass
+
+    def __getattr__(self, item):
+        def pass_func(*args, **kwargs):
+            pass
+
+        return pass_func
+
+
 class SGD(Optimizer):
     """
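Because `__getattr__` returns a no-op function for every attribute, calls such as `optimizer.step()` inside the Trainer silently do nothing. A hedged usage sketch (the keyword names follow fastNLP's usual Trainer API; `ManualUpdateCallback` is a hypothetical callback that performs the actual parameter updates):

    from fastNLP import Trainer
    from fastNLP.core.optimizer import NullOptimizer

    trainer = Trainer(train_data=train_data, model=model,
                      optimizer=NullOptimizer(),            # Trainer's optimizer.step() becomes a no-op
                      callbacks=[ManualUpdateCallback()])   # hypothetical: updates parameters itself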
@@ -615,7 +615,7 @@ class Trainer(object):
                     if self.step % self.print_every == 0:
                         avg_loss = float(avg_loss) / self.print_every
                         if self.use_tqdm:
-                            print_output = "loss:{0:<6.5f}".format(avg_loss)
+                            print_output = "loss:{:<6.5f}".format(avg_loss)
                             pbar.update(self.print_every)
                         else:
                             end = time.time()
@@ -679,7 +679,7 @@ class Trainer(object):
         """Perform weight update on a model.
 
         """
-        if self.optimizer is not None and (self.step + 1) % self.update_every == 0:
+        if self.step % self.update_every == 0:
             self.optimizer.step()
 
     def _data_forward(self, network, x):
@@ -697,7 +697,7 @@ class Trainer(object):
 
         For PyTorch, just do "loss.backward()"
         """
-        if self.step % self.update_every == 0:
+        if (self.step-1) % self.update_every == 0:
             self.model.zero_grad()
         loss.backward()
 
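Together, these two Trainer hunks realign gradient accumulation: with `update_every = n`, gradients are now zeroed at the first step of each window and applied at the last, so every backward pass in between accumulates. A standalone sketch of the pattern in plain PyTorch (loader, model, loss_fn, and optimizer are assumed to exist):

    update_every = 4
    for step, (x, y) in enumerate(loader, start=1):
        if (step - 1) % update_every == 0:
            model.zero_grad()            # start of an accumulation window
        loss = loss_fn(model(x), y)
        loss.backward()                  # gradients accumulate across the window
        if step % update_every == 0:
            optimizer.step()             # one parameter update per window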
@@ -38,7 +38,8 @@ class EmbedLoader(BaseLoader):
         super(EmbedLoader, self).__init__()
 
     @staticmethod
-    def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, error='ignore'):
+    def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True,
+                        error='ignore', init_method=None):
         """
         Extract the embeddings of the words in vocab from the pretrained vectors at embed_filepath. EmbedLoader
         automatically detects whether embed_filepath is in word2vec format (the first line has only two elements) or glove format.
@@ -52,6 +53,7 @@ class EmbedLoader(BaseLoader):
         :param bool normalize: whether to normalize each vector to unit norm
         :param str error: `ignore` or `strict`; with `ignore`, errors are skipped automatically; with `strict`, they are raised.
             The main sources of error are blank lines in the file and inconsistent vector dimensions.
+        :param callable init_method: takes a numpy.ndarray and returns a numpy.ndarray, used to initialize the embedding matrix
         :return numpy.ndarray: shape [len(vocab), dimension], where dimension is determined by the pretrained embedding.
         """
         assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
@@ -69,6 +71,8 @@ class EmbedLoader(BaseLoader):
             dim = len(parts) - 1
             f.seek(0)
         matrix = np.random.randn(len(vocab), dim).astype(dtype)
+        if init_method:
+            matrix = init_method(matrix)
         for idx, line in enumerate(f, start_idx):
             try:
                 parts = line.strip().split()
@@ -91,14 +95,15 @@ class EmbedLoader(BaseLoader):
                     raise e
         total_hits = sum(hit_flags)
         print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
-        found_vectors = matrix[hit_flags]
-        if len(found_vectors) != 0:
-            mean = np.mean(found_vectors, axis=0, keepdims=True)
-            std = np.std(found_vectors, axis=0, keepdims=True)
-            unfound_vec_num = len(vocab) - total_hits
-            r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean
-            matrix[hit_flags == False] = r_vecs
-
+        if init_method is None:
+            found_vectors = matrix[hit_flags]
+            if len(found_vectors) != 0:
+                mean = np.mean(found_vectors, axis=0, keepdims=True)
+                std = np.std(found_vectors, axis=0, keepdims=True)
+                unfound_vec_num = len(vocab) - total_hits
+                r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean
+                matrix[hit_flags == False] = r_vecs
+
         if normalize:
             matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)
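With `init_method` set, the loader keeps whatever that function produced for out-of-vocabulary rows instead of re-randomizing them from the mean/std of the found vectors. A hedged example (the file path and `zero_init` helper are illustrative, not part of fastNLP):

    import numpy as np

    def zero_init(matrix: np.ndarray) -> np.ndarray:
        # words absent from the pretrained file keep exact zero vectors
        return np.zeros_like(matrix)

    embedding = EmbedLoader.load_with_vocab("glove.6B.100d.txt", vocab,
                                            init_method=zero_init,
                                            normalize=False)  # avoid dividing zero rows by their zero norm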
@@ -157,13 +162,17 @@ class EmbedLoader(BaseLoader):
         if dim == -1:
             raise RuntimeError("{} is an empty file.".format(embed_filepath))
         matrix = np.random.randn(len(vocab), dim).astype(dtype)
+        for key, vec in vec_dict.items():
+            index = vocab.to_index(key)
+            matrix[index] = vec
+
         if (unknown is not None and not found_unknown) or (padding is not None and not found_pad):
             start_idx = 0
             if padding is not None:
                 start_idx += 1
             if unknown is not None:
                 start_idx += 1
 
             mean = np.mean(matrix[start_idx:], axis=0, keepdims=True)
             std = np.std(matrix[start_idx:], axis=0, keepdims=True)
             if (unknown is not None and not found_unknown):
@@ -171,10 +180,6 @@ class EmbedLoader(BaseLoader):
             if (padding is not None and not found_pad):
                 matrix[0] = np.random.randn(1, dim).astype(dtype) * std + mean
 
-        for key, vec in vec_dict.items():
-            index = vocab.to_index(key)
-            matrix[index] = vec
-
         if normalize:
             matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)

(Filling `matrix` from `vec_dict` before computing mean/std is the embed_loader bug fix from the commit message: the statistics used to initialize a missing padding/unknown row now come from the actual pretrained vectors rather than from the random initialization.)
@@ -73,21 +73,12 @@ class LSTM(nn.Module):
                 x = x[:, sort_idx]
             x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first)
             output, hx = self.lstm(x, hx)  # -> [N,L,C]
-            output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first)
+            output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
             _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
             if self.batch_first:
                 output = output[unsort_idx]
             else:
                 output = output[:, unsort_idx]
-            # Workaround: LSTM cannot be used under DataParallel, https://github.com/pytorch/pytorch/issues/1591
-            if self.batch_first:
-                if output.size(1) < max_len:
-                    dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1))
-                    output = torch.cat([output, dummy_tensor], 1)
-            else:
-                if output.size(0) < max_len:
-                    dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1))
-                    output = torch.cat([output, dummy_tensor], 0)
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
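The removed workaround padded the output back to max_len by hand (and its non-batch-first branch even sized the pad from the wrong dimension, `output.size(1)` instead of `output.size(0)`). The `total_length` argument of `pad_packed_sequence` achieves the same thing natively: under nn.DataParallel each replica sees only its shard of the batch, so without `total_length` each replica pads to its own local maximum and the gathered outputs cannot be concatenated. A minimal standalone sketch:

    import torch
    import torch.nn.utils.rnn as rnn

    lstm = torch.nn.LSTM(8, 16, batch_first=True)
    x = torch.randn(4, 10, 8)                       # batch of 4, global max length 10
    lengths = torch.tensor([10, 7, 5, 3])           # sorted descending, as packing requires
    packed = rnn.pack_padded_sequence(x, lengths, batch_first=True)
    out, _ = lstm(packed)
    out, _ = rnn.pad_packed_sequence(out, batch_first=True, total_length=10)
    assert out.size(1) == 10                        # uniform length regardless of the shard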
@@ -82,6 +82,8 @@ def get_embeddings(init_embed):
     if isinstance(init_embed, tuple):
         res = nn.Embedding(
             num_embeddings=init_embed[0], embedding_dim=init_embed[1])
+        nn.init.uniform_(res.weight.data, a=-np.sqrt(3/res.weight.data.size(1)),
+                         b=np.sqrt(3/res.weight.data.size(1)))
     elif isinstance(init_embed, nn.Module):
         res = init_embed
     elif isinstance(init_embed, torch.Tensor):
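The bound a = sqrt(3/d) is chosen so that U(-a, a) has variance a²/3 = 1/d per dimension, the usual 1/embedding_dim scaling. A quick sketch (the vocabulary size and dimension are illustrative):

    embed = get_embeddings((5000, 100))     # fresh nn.Embedding(5000, 100), now uniform-initialized
    print(embed.weight.data.var().item())   # ≈ 1/100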
setup.py (2 changed lines)
@@ -13,7 +13,7 @@ with open('requirements.txt', encoding='utf-8') as f:
 
 setup(
     name='FastNLP',
-    version='0.4.0',
+    version='dev0.5.0',
     description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
     long_description=readme,
     long_description_content_type='text/markdown',