Add ENAS (Efficient Neural Architecture Search)
This commit is contained in:
parent 13faa2b410
commit efeac2c427
223
fastNLP/models/enas_controller.py
Normal file
@@ -0,0 +1,223 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch
"""A module with NAS controller-related code."""
import collections
import os

import torch
import torch.nn.functional as F
import fastNLP
import fastNLP.models.enas_utils as utils
from fastNLP.models.enas_utils import Node


def _construct_dags(prev_nodes, activations, func_names, num_blocks):
    """Constructs a set of DAGs based on the actions, i.e., previous nodes and
    activation functions, sampled from the controller/policy pi.

    Args:
        prev_nodes: Previous node actions from the policy.
        activations: Activations sampled from the policy.
        func_names: List of activation function names.
        num_blocks: Number of blocks in the target RNN cell.

    Returns:
        A list of DAGs defined by the inputs.

        RNN cell DAGs are represented in the following way:

        1. Each element (node) in a DAG is a list of `Node`s.

        2. The `Node`s in the list dag[i] correspond to the subsequent nodes
           that take the output from node i as their own input.

        3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}.
           dag[-1] always feeds dag[0].
           dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its
           weights.

        4. dag[N - 1] is the node that produces the hidden state passed to
           the next timestep. dag[N - 1] is also always a leaf node, and
           therefore is always averaged with the other leaf nodes and fed to
           the output decoder.
    """
    dags = []
    for nodes, func_ids in zip(prev_nodes, activations):
        dag = collections.defaultdict(list)

        # add first node
        dag[-1] = [Node(0, func_names[func_ids[0]])]
        dag[-2] = [Node(0, func_names[func_ids[0]])]

        # add following nodes
        for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])):
            dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id]))

        leaf_nodes = set(range(num_blocks)) - dag.keys()

        # merge with avg
        for idx in leaf_nodes:
            dag[idx] = [Node(num_blocks, 'avg')]

        # This is actually y^{(t)}. h^{(t)} is node N - 1 in
        # the graph, where N is the number of nodes. I.e., h^{(t)} takes
        # only one other node as its input.
        # last h[t] node
        last_node = Node(num_blocks + 1, 'h[t]')
        dag[num_blocks] = [last_node]
        dags.append(dag)

    return dags


class Controller(torch.nn.Module):
    """Based on
    https://github.com/pytorch/examples/blob/master/word_language_model/model.py

    RL controllers do not necessarily have much to do with
    language models.

    Base the controller RNN on the GRU from:
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
    """
    def __init__(self, num_blocks=4, controller_hid=100, cuda=False):
        torch.nn.Module.__init__(self)

        # `num_tokens` here is just the number of activation functions
        # for every even step,
        self.shared_rnn_activations = ['tanh', 'ReLU', 'identity', 'sigmoid']
        self.num_tokens = [len(self.shared_rnn_activations)]
        self.controller_hid = controller_hid
        self.use_cuda = cuda
        self.num_blocks = num_blocks
        for idx in range(num_blocks):
            self.num_tokens += [idx + 1, len(self.shared_rnn_activations)]
        self.func_names = self.shared_rnn_activations

        num_total_tokens = sum(self.num_tokens)

        self.encoder = torch.nn.Embedding(num_total_tokens,
                                          controller_hid)
        self.lstm = torch.nn.LSTMCell(controller_hid, controller_hid)

        # Perhaps these weights in the decoder should be
        # shared? At least for the activation functions, which all have the
        # same size.
        self.decoders = []
        for idx, size in enumerate(self.num_tokens):
            decoder = torch.nn.Linear(controller_hid, size)
            self.decoders.append(decoder)

        self._decoders = torch.nn.ModuleList(self.decoders)

        self.reset_parameters()
        self.static_init_hidden = utils.keydefaultdict(self.init_hidden)

        def _get_default_hidden(key):
            return utils.get_variable(
                torch.zeros(key, self.controller_hid),
                self.use_cuda,
                requires_grad=False)

        self.static_inputs = utils.keydefaultdict(_get_default_hidden)

    def reset_parameters(self):
        init_range = 0.1
        for param in self.parameters():
            param.data.uniform_(-init_range, init_range)
        for decoder in self.decoders:
            decoder.bias.data.fill_(0)

    def forward(self,  # pylint:disable=arguments-differ
                inputs,
                hidden,
                block_idx,
                is_embed):
        if not is_embed:
            embed = self.encoder(inputs)
        else:
            embed = inputs

        hx, cx = self.lstm(embed, hidden)
        logits = self.decoders[block_idx](hx)

        logits /= 5.0

        # # exploration
        # if self.args.mode == 'train':
        #     logits = (2.5 * F.tanh(logits))

        return logits, (hx, cx)

    def sample(self, batch_size=1, with_details=False, save_dir=None):
        """Samples a set of `self.num_blocks` many computational nodes from the
        controller, where each node is made up of an activation function, and
        each node except the last also includes a previous node.
        """
        if batch_size < 1:
            raise Exception(f'Wrong batch_size: {batch_size} < 1')

        # [B, L, H]
        inputs = self.static_inputs[batch_size]
        hidden = self.static_init_hidden[batch_size]

        activations = []
        entropies = []
        log_probs = []
        prev_nodes = []
        # The RNN controller alternately outputs an activation,
        # followed by a previous node, for each block except the last one,
        # which only gets an activation function. The last node is the output
        # node, and its previous node is the average of all leaf nodes.
        for block_idx in range(2*(self.num_blocks - 1) + 1):
            logits, hidden = self.forward(inputs,
                                          hidden,
                                          block_idx,
                                          is_embed=(block_idx == 0))

            probs = F.softmax(logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
            # .mean() for entropy?
            entropy = -(log_prob * probs).sum(1, keepdim=False)

            action = probs.multinomial(num_samples=1).data
            selected_log_prob = log_prob.gather(
                1, utils.get_variable(action, requires_grad=False))

            # why the [:, 0] here? Should it be .squeeze(), or
            # .view()? Same below with `action`.
            entropies.append(entropy)
            log_probs.append(selected_log_prob[:, 0])

            # 0: function, 1: previous node
            mode = block_idx % 2
            inputs = utils.get_variable(
                action[:, 0] + sum(self.num_tokens[:mode]),
                requires_grad=False)

            if mode == 0:
                activations.append(action[:, 0])
            elif mode == 1:
                prev_nodes.append(action[:, 0])

        prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
        activations = torch.stack(activations).transpose(0, 1)

        dags = _construct_dags(prev_nodes,
                               activations,
                               self.func_names,
                               self.num_blocks)

        if save_dir is not None:
            for idx, dag in enumerate(dags):
                utils.draw_network(dag,
                                   os.path.join(save_dir, f'graph{idx}.png'))

        if with_details:
            return dags, torch.cat(log_probs), torch.cat(entropies)

        return dags

    def init_hidden(self, batch_size):
        zeros = torch.zeros(batch_size, self.controller_hid)
        return (utils.get_variable(zeros, self.use_cuda, requires_grad=False),
                utils.get_variable(zeros.clone(), self.use_cuda, requires_grad=False))
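Usage sketch (illustration only, not part of the committed diff): sampling a cell DAG from the Controller added above. The module path matches this commit; the sizes and the print loop are just for demonstration.

from fastNLP.models.enas_controller import Controller

controller = Controller(num_blocks=4, controller_hid=100, cuda=False)
dags = controller.sample(batch_size=1)   # list with one DAG (a defaultdict of Node lists)
for parent_id, children in sorted(dags[0].items()):
    # -1/-2 stand for x^{(t)} and h^{(t-1)}; a child with id == num_blocks is the averaged output
    print(parent_id, [(node.id, node.name) for node in children])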
388
fastNLP/models/enas_model.py
Normal file
@@ -0,0 +1,388 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

"""Module containing the shared RNN model."""
import numpy as np
import collections

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable

import fastNLP.models.enas_utils as utils
from fastNLP.models.base_model import BaseModel
import fastNLP.modules.encoder as encoder


def _get_dropped_weights(w_raw, dropout_p, is_training):
    """Drops out weights to implement DropConnect.

    Args:
        w_raw: Full, pre-dropout, weights to be dropped out.
        dropout_p: Proportion of weights to drop out.
        is_training: True iff _shared_ model is training.

    Returns:
        The dropped weights.

    Why does torch.nn.functional.dropout() return:
    1. `torch.autograd.Variable()` on the training loop
    2. `torch.nn.Parameter()` on the controller or eval loop, when
       training = False...

    Even though the call to `_setweights` in the Smerity repo's
    `weight_drop.py` does not have this behaviour, and `F.dropout` always
    returns `torch.autograd.Variable` there, even when `training=False`?

    The above TODO is the reason for the hacky check for `torch.nn.Parameter`.
    """
    dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training)

    if isinstance(dropped_w, torch.nn.Parameter):
        dropped_w = dropped_w.clone()

    return dropped_w


class EmbeddingDropout(torch.nn.Embedding):
    """Class for dropping out embeddings by zero'ing out parameters in the
    embedding matrix.

    This is equivalent to dropping out particular words, e.g., in the sentence
    'the quick brown fox jumps over the lazy dog', dropping out 'the' would
    lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the
    embedding vector space).

    See 'A Theoretically Grounded Application of Dropout in Recurrent Neural
    Networks', (Gal and Ghahramani, 2016).
    """
    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 max_norm=None,
                 norm_type=2,
                 scale_grad_by_freq=False,
                 sparse=False,
                 dropout=0.1,
                 scale=None):
        """Embedding constructor.

        Args:
            dropout: Dropout probability.
            scale: Used to scale parameters of embedding weight matrix that are
                not dropped out. Note that this is _in addition_ to the
                `1/(1 - dropout)` scaling.

        See `torch.nn.Embedding` for remaining arguments.
        """
        torch.nn.Embedding.__init__(self,
                                    num_embeddings=num_embeddings,
                                    embedding_dim=embedding_dim,
                                    max_norm=max_norm,
                                    norm_type=norm_type,
                                    scale_grad_by_freq=scale_grad_by_freq,
                                    sparse=sparse)
        self.dropout = dropout
        assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 '
                                                      'and < 1.0')
        self.scale = scale

    def forward(self, inputs):  # pylint:disable=arguments-differ
        """Embeds `inputs` with the dropped out embedding weight matrix."""
        if self.training:
            dropout = self.dropout
        else:
            dropout = 0

        if dropout:
            mask = self.weight.data.new(self.weight.size(0), 1)
            mask.bernoulli_(1 - dropout)
            mask = mask.expand_as(self.weight)
            mask = mask / (1 - dropout)
            masked_weight = self.weight * Variable(mask)
        else:
            masked_weight = self.weight
        if self.scale and self.scale != 1:
            masked_weight = masked_weight * self.scale

        return F.embedding(inputs,
                           masked_weight,
                           max_norm=self.max_norm,
                           norm_type=self.norm_type,
                           scale_grad_by_freq=self.scale_grad_by_freq,
                           sparse=self.sparse)


class LockedDropout(nn.Module):
    # code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py
    def __init__(self):
        super().__init__()

    def forward(self, x, dropout=0.5):
        if not self.training or not dropout:
            return x
        m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
        mask = Variable(m, requires_grad=False) / (1 - dropout)
        mask = mask.expand_as(x)
        return mask * x


class ENASModel(BaseModel):
    """Shared RNN model."""
    def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000):
        super(ENASModel, self).__init__()

        self.use_cuda = cuda

        self.shared_hid = shared_hid
        self.num_blocks = num_blocks
        self.decoder = nn.Linear(self.shared_hid, num_classes)
        self.encoder = EmbeddingDropout(embed_num,
                                        shared_embed,
                                        dropout=0.1)
        self.lockdrop = LockedDropout()
        self.dag = None

        # Tie weights
        # self.decoder.weight = self.encoder.weight

        # Since W^{x, c} and W^{h, c} are always summed, there
        # is no point duplicating their bias offset parameter. Likewise for
        # W^{x, h} and W^{h, h}.
        self.w_xc = nn.Linear(shared_embed, self.shared_hid)
        self.w_xh = nn.Linear(shared_embed, self.shared_hid)

        # The raw weights are stored here because the hidden-to-hidden weights
        # are weight dropped on the forward pass.
        self.w_hc_raw = torch.nn.Parameter(
            torch.Tensor(self.shared_hid, self.shared_hid))
        self.w_hh_raw = torch.nn.Parameter(
            torch.Tensor(self.shared_hid, self.shared_hid))
        self.w_hc = None
        self.w_hh = None

        self.w_h = collections.defaultdict(dict)
        self.w_c = collections.defaultdict(dict)

        for idx in range(self.num_blocks):
            for jdx in range(idx + 1, self.num_blocks):
                self.w_h[idx][jdx] = nn.Linear(self.shared_hid,
                                               self.shared_hid,
                                               bias=False)
                self.w_c[idx][jdx] = nn.Linear(self.shared_hid,
                                               self.shared_hid,
                                               bias=False)

        self._w_h = nn.ModuleList([self.w_h[idx][jdx]
                                   for idx in self.w_h
                                   for jdx in self.w_h[idx]])
        self._w_c = nn.ModuleList([self.w_c[idx][jdx]
                                   for idx in self.w_c
                                   for jdx in self.w_c[idx]])

        self.batch_norm = None
        # if args.mode == 'train':
        #     self.batch_norm = nn.BatchNorm1d(self.shared_hid)
        # else:
        #     self.batch_norm = None

        self.reset_parameters()
        self.static_init_hidden = utils.keydefaultdict(self.init_hidden)

    def setDAG(self, dag):
        if self.dag is None:
            self.dag = dag

    def forward(self, word_seq, hidden=None):
        inputs = torch.transpose(word_seq, 0, 1)

        time_steps = inputs.size(0)
        batch_size = inputs.size(1)

        self.w_hh = _get_dropped_weights(self.w_hh_raw,
                                         0.5,
                                         self.training)
        self.w_hc = _get_dropped_weights(self.w_hc_raw,
                                         0.5,
                                         self.training)

        # hidden = self.static_init_hidden[batch_size] if hidden is None else hidden
        hidden = self.static_init_hidden[batch_size]

        embed = self.encoder(inputs)

        embed = self.lockdrop(embed, 0.65 if self.training else 0)

        # The norms of the hidden states are clipped here because
        # otherwise ENAS is especially prone to exploding activations on the
        # forward pass. This could probably be fixed in a more elegant way, but
        # it might be exposing a weakness in the ENAS algorithm as currently
        # proposed.
        #
        # For more details, see
        # https://github.com/carpedm20/ENAS-pytorch/issues/6
        clipped_num = 0
        max_clipped_norm = 0
        h1tohT = []
        logits = []
        for step in range(time_steps):
            x_t = embed[step]
            logit, hidden = self.cell(x_t, hidden, self.dag)

            hidden_norms = hidden.norm(dim=-1)
            max_norm = 25.0
            if hidden_norms.data.max() > max_norm:
                # Just directly use the torch slice operations
                # in PyTorch v0.4.
                #
                # This workaround for PyTorch v0.3.1 does everything in numpy,
                # because the PyTorch slicing and slice assignment is too
                # flaky.
                hidden_norms = hidden_norms.data.cpu().numpy()

                clipped_num += 1
                if hidden_norms.max() > max_clipped_norm:
                    max_clipped_norm = hidden_norms.max()

                clip_select = hidden_norms > max_norm
                clip_norms = hidden_norms[clip_select]

                mask = np.ones(hidden.size())
                normalizer = max_norm/clip_norms
                normalizer = normalizer[:, np.newaxis]

                mask[clip_select] = normalizer

                if self.use_cuda:
                    hidden *= torch.autograd.Variable(
                        torch.FloatTensor(mask).cuda(), requires_grad=False)
                else:
                    hidden *= torch.autograd.Variable(
                        torch.FloatTensor(mask), requires_grad=False)
            logits.append(logit)
            h1tohT.append(hidden)

        h1tohT = torch.stack(h1tohT)
        output = torch.stack(logits)
        raw_output = output

        output = self.lockdrop(output, 0.4 if self.training else 0)

        # Pooling
        output = torch.mean(output, 0)

        decoded = self.decoder(output)

        extra_out = {'dropped': decoded,
                     'hiddens': h1tohT,
                     'raw': raw_output}
        return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out}

    def cell(self, x, h_prev, dag):
        """Computes a single pass through the discovered RNN cell."""
        c = {}
        h = {}
        f = {}

        f[0] = self.get_f(dag[-1][0].name)
        c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None))
        h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) +
                (1 - c[0])*h_prev)

        leaf_node_ids = []
        q = collections.deque()
        q.append(0)

        # Computes connections from the parent nodes `node_id`
        # to their child nodes `next_id` recursively, skipping leaf nodes. A
        # leaf node is a node whose id == `self.num_blocks`.
        #
        # Connections between parent i and child j should be computed as
        # h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i,
        # where c_j = \sigmoid{(W^c_{ij}*h_i)}
        #
        # See Training details from Section 3.1 of the paper.
        #
        # The following algorithm does a breadth-first (since `q.popleft()` is
        # used) search over the nodes and computes all the hidden states.
        while True:
            if len(q) == 0:
                break

            node_id = q.popleft()
            nodes = dag[node_id]

            for next_node in nodes:
                next_id = next_node.id
                if next_id == self.num_blocks:
                    leaf_node_ids.append(node_id)
                    assert len(nodes) == 1, ('parent of leaf node should have '
                                             'only one child')
                    continue

                w_h = self.w_h[node_id][next_id]
                w_c = self.w_c[node_id][next_id]

                f[next_id] = self.get_f(next_node.name)
                c[next_id] = torch.sigmoid(w_c(h[node_id]))
                h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) +
                              (1 - c[next_id])*h[node_id])

                q.append(next_id)

        # Instead of averaging loose ends, perhaps there should
        # be a set of separate unshared weights for each "loose" connection
        # between each node in a cell and the output.
        #
        # As it stands, all weights W^h_{ij} are doing double duty by
        # connecting both from i to j, as well as from i to the output.

        # average all the loose ends
        leaf_nodes = [h[node_id] for node_id in leaf_node_ids]
        output = torch.mean(torch.stack(leaf_nodes, 2), -1)

        # stabilizing the updates of omega
        if self.batch_norm is not None:
            output = self.batch_norm(output)

        return output, h[self.num_blocks - 1]

    def init_hidden(self, batch_size):
        zeros = torch.zeros(batch_size, self.shared_hid)
        return utils.get_variable(zeros, self.use_cuda, requires_grad=False)

    def get_f(self, name):
        name = name.lower()
        if name == 'relu':
            f = torch.relu
        elif name == 'tanh':
            f = torch.tanh
        elif name == 'identity':
            f = lambda x: x
        elif name == 'sigmoid':
            f = torch.sigmoid
        return f

    @property
    def num_parameters(self):
        def size(p):
            return np.prod(p.size())
        return sum([size(param) for param in self.parameters()])

    def reset_parameters(self):
        init_range = 0.025
        # init_range = 0.025 if self.args.mode == 'train' else 0.04
        for param in self.parameters():
            param.data.uniform_(-init_range, init_range)
        self.decoder.bias.data.fill_(0)

    def predict(self, word_seq):
        """
        :param word_seq: torch.LongTensor, [batch_size, seq_len]
        :return predict: dict of torch.LongTensor, [batch_size, seq_len]
        """
        output = self(word_seq)
        _, predict = output['pred'].max(dim=1)
        return {'pred': predict}
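Usage sketch (illustration only, not part of the committed diff): wiring a DAG sampled by the Controller into the shared model and running one forward pass. The vocabulary size, batch shape and class count below are arbitrary.

import torch
from fastNLP.models.enas_controller import Controller
from fastNLP.models.enas_model import ENASModel

controller = Controller(num_blocks=4)
shared = ENASModel(embed_num=100, num_classes=5, num_blocks=4)
shared.setDAG(controller.sample(1)[0])        # fix the cell architecture for this pass

word_seq = torch.randint(0, 100, (2, 7))      # [batch_size=2, seq_len=7], arbitrary word ids
out = shared(word_seq)
print(out['pred'].shape)                      # torch.Size([2, 5]) -- class logits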
385
fastNLP/models/enas_trainer.py
Normal file
@@ -0,0 +1,385 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

import os
import time
from datetime import datetime
from datetime import timedelta

import numpy as np
import torch
import math
from torch import nn

try:
    from tqdm.autonotebook import tqdm
except:
    from fastNLP.core.utils import pseudo_tqdm as tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackManager, CallbackException
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _move_dict_value_to_device
import fastNLP
import fastNLP.models.enas_utils as utils
from fastNLP.core.utils import _build_args

from torch.optim import Adam


def _get_no_grad_ctx_mgr():
    """Returns the `torch.no_grad` context manager for PyTorch version >=
    0.4, or a no-op context manager otherwise.
    """
    return torch.no_grad()


class ENASTrainer(fastNLP.Trainer):
    """A class to wrap training code."""
    def __init__(self, train_data, model, controller, **kwargs):
        """Constructor for training algorithm.

        :param DataSet train_data: the training data
        :param torch.nn.modules.module model: a PyTorch model
        :param torch.nn.modules.module controller: a PyTorch model
        """
        self.final_epochs = kwargs['final_epochs']
        kwargs.pop('final_epochs')
        super(ENASTrainer, self).__init__(train_data, model, **kwargs)
        self.controller_step = 0
        self.shared_step = 0
        self.max_length = 35

        self.shared = model
        self.controller = controller

        self.shared_optim = Adam(
            self.shared.parameters(),
            lr=20.0,
            weight_decay=1e-7)

        self.controller_optim = Adam(
            self.controller.parameters(),
            lr=3.5e-4)

    def train(self, load_best_model=True):
        """
        :param bool load_best_model: only takes effect when dev_data was provided at initialization.
            If True, the trainer reloads the model parameters that performed best on the dev set
            before returning.
        :return results: a dict with the following entries::

                seconds: float, total training time in seconds
                The following three entries are only present when dev_data was provided:
                best_eval: Dict of Dict, the evaluation results
                best_epoch: int, the epoch in which the best result was obtained
                best_step: int, the step (batch update) at which the best result was obtained

        """
        results = {}
        if self.n_epochs <= 0:
            print(f"training epoch is {self.n_epochs}, nothing was done.")
            results['seconds'] = 0.
            return results
        try:
            if torch.cuda.is_available() and self.use_cuda:
                self.model = self.model.cuda()
            self._model_device = self.model.parameters().__next__().device
            self._mode(self.model, is_test=False)

            self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
            start_time = time.time()
            print("training epochs started " + self.start_time, flush=True)

            try:
                self.callback_manager.on_train_begin()
                self._train()
                self.callback_manager.on_train_end(self.model)
            except (CallbackException, KeyboardInterrupt) as e:
                self.callback_manager.on_exception(e, self.model)

            if self.dev_data is not None:
                print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                      self.tester._format_eval_results(self.best_dev_perf),)
                results['best_eval'] = self.best_dev_perf
                results['best_epoch'] = self.best_dev_epoch
                results['best_step'] = self.best_dev_step
                if load_best_model:
                    model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
                    load_succeed = self._load_model(self.model, model_name)
                    if load_succeed:
                        print("Reloaded the best model.")
                    else:
                        print("Failed to reload the best model.")
        finally:
            pass
        results['seconds'] = round(time.time() - start_time, 2)

        return results

    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
        else:
            inner_tqdm = tqdm
        self.step = 0
        start = time.time()
        total_steps = (len(self.train_data) // self.batch_size + int(
            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
        with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
            avg_loss = 0
            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                  prefetch=self.prefetch)
            for epoch in range(1, self.n_epochs+1):
                pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                last_stage = (epoch > self.n_epochs + 1 - self.final_epochs)
                if epoch == self.n_epochs + 1 - self.final_epochs:
                    print('Entering the final stage. (Only train the selected structure)')
                # early stopping
                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)

                # 1. Training the shared parameters omega of the child models
                self.train_shared(pbar)

                # 2. Training the controller parameters theta
                if not last_stage:
                    self.train_controller()

                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                    (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                        and self.dev_data is not None:
                    if not last_stage:
                        self.derive()
                    eval_res = self._do_validation(epoch=epoch, step=self.step)
                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                total_steps) + \
                               self.tester._format_eval_results(eval_res)
                    pbar.write(eval_str)

                # lr decay; early stopping
                self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
            # =============== epochs end =================== #
            pbar.close()
        # ============ tqdm end ============== #

    def get_loss(self, inputs, targets, hidden, dags):
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            self.shared.setDAG(dag)
            inputs = _build_args(self.shared.forward, **inputs)
            inputs['hidden'] = hidden
            result = self.shared(**inputs)
            output, hidden, extra_out = result['pred'], result['hidden'], result['extra_out']

            self.callback_manager.on_loss_begin(targets, result)
            sample_loss = self._compute_loss(result, targets)
            loss += sample_loss

        assert len(dags) == 1, 'there are multiple `hidden` for multiple `dags`'
        return loss, hidden, extra_out

    def train_shared(self, pbar=None, max_step=None, dag=None):
        """Train the language model for 400 steps of minibatches of 64
        examples.

        Args:
            max_step: Used to run extra training steps as a warm-up.
            dag: If not None, is used instead of calling sample().

        BPTT is truncated at 35 timesteps.

        For each weight update, gradients are estimated by sampling M models
        from the fixed controller policy, and averaging their gradients
        computed on a batch of training data.
        """
        model = self.shared
        model.train()
        self.controller.eval()

        hidden = self.shared.init_hidden(self.batch_size)

        abs_max_grad = 0
        abs_max_hidden_norm = 0
        step = 0
        raw_total_loss = 0
        total_loss = 0
        train_idx = 0
        avg_loss = 0
        data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                              prefetch=self.prefetch)

        for batch_x, batch_y in data_iterator:
            _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
            indices = data_iterator.get_batch_indices()
            # negative sampling; replace unknown; re-weight batch_y
            self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
            # prediction = self._data_forward(self.model, batch_x)

            dags = self.controller.sample(1)
            inputs, targets = batch_x, batch_y
            # self.callback_manager.on_loss_begin(batch_y, prediction)
            loss, hidden, extra_out = self.get_loss(inputs,
                                                    targets,
                                                    hidden,
                                                    dags)
            hidden.detach_()

            avg_loss += loss.item()

            # Is loss NaN or inf? requires_grad = False
            self.callback_manager.on_backward_begin(loss, self.model)
            self._grad_backward(loss)
            self.callback_manager.on_backward_end(self.model)

            self._update()
            self.callback_manager.on_step_end(self.optimizer)

            if (self.step+1) % self.print_every == 0:
                if self.use_tqdm:
                    print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
                    pbar.update(self.print_every)
                else:
                    end = time.time()
                    diff = timedelta(seconds=round(end - start))
                    print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
                        epoch, self.step, avg_loss, diff)
                pbar.set_postfix_str(print_output)
                avg_loss = 0
            self.step += 1
            step += 1
            self.shared_step += 1
            self.callback_manager.on_batch_end()
        # ================= mini-batch end ==================== #

    def get_reward(self, dag, entropies, hidden, valid_idx=0):
        """Computes the perplexity of a single sampled model on a minibatch of
        validation data.
        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()

        data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                              prefetch=self.prefetch)

        for inputs, targets in data_iterator:
            valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
            valid_loss = utils.to_item(valid_loss.data)

            valid_ppl = math.exp(valid_loss)

            R = 80 / valid_ppl

            rewards = R + 1e-4 * entropies

            return rewards, hidden

    def train_controller(self):
        """Fixes the shared parameters and updates the controller parameters.

        The controller is updated with a score function gradient estimator
        (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
        is computed on a minibatch of validation data.

        A moving average baseline is used.

        The controller is trained for 2000 steps per epoch (i.e.,
        first (Train Shared) phase -> second (Train Controller) phase).
        """
        model = self.controller
        model.train()
        # Why can't we call shared.eval() here? Leads to loss
        # being uniformly zero for the controller.
        # self.shared.eval()

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []

        hidden = self.shared.init_hidden(self.batch_size)
        total_loss = 0
        valid_idx = 0
        for step in range(20):
            # sample models
            dags, log_probs, entropies = self.controller.sample(
                with_details=True)

            # calculate reward
            np_entropies = entropies.data.cpu().numpy()
            # No gradients should be backpropagated to the
            # shared model during controller training, obviously.
            with _get_no_grad_ctx_mgr():
                rewards, hidden = self.get_reward(dags,
                                                  np_entropies,
                                                  hidden,
                                                  valid_idx)

            reward_history.extend(rewards)
            entropy_history.extend(np_entropies)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = 0.95
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.extend(adv)

            # policy loss
            loss = -log_probs*utils.get_variable(adv,
                                                 self.use_cuda,
                                                 requires_grad=False)

            loss = loss.sum()  # or loss.mean()

            # update
            self.controller_optim.zero_grad()
            loss.backward()

            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)

            if ((step % 50) == 0) and (step > 0):
                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0

            self.controller_step += 1
            # prev_valid_idx = valid_idx
            # valid_idx = ((valid_idx + self.max_length) %
            #              (self.valid_data.size(0) - 1))
            # # Whenever we wrap around to the beginning of the
            # # validation data, we reset the hidden states.
            # if prev_valid_idx > valid_idx:
            #     hidden = self.shared.init_hidden(self.batch_size)

    def derive(self, sample_num=10, valid_idx=0):
        """We are always deriving based on the very first batch
        of validation data? This seems wrong...
        """
        hidden = self.shared.init_hidden(self.batch_size)

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
            R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        self.model.setDAG(best_dag)
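Illustration (not part of the committed diff): the controller update in train_controller() above is plain REINFORCE with an exponential moving-average baseline. With toy numbers, one step of the baseline/advantage bookkeeping looks like this.

import numpy as np

rewards = np.array([0.8, 1.2])     # R = 80 / valid_ppl, plus the 1e-4 * entropy bonus
baseline = np.array([1.0, 1.0])    # running average carried over from earlier steps
decay = 0.95

baseline = decay * baseline + (1 - decay) * rewards
adv = rewards - baseline           # advantage that multiplies -log_probs in the policy loss
print(baseline, adv)               # [0.99 1.01] [-0.19  0.19]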
56
fastNLP/models/enas_utils.py
Normal file
@@ -0,0 +1,56 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

from __future__ import print_function

from collections import defaultdict
import collections
from datetime import datetime
import os
import json

import numpy as np

import torch
from torch.autograd import Variable


def detach(h):
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(detach(v) for v in h)


def get_variable(inputs, cuda=False, **kwargs):
    if type(inputs) in [list, np.ndarray]:
        inputs = torch.Tensor(inputs)
    if cuda:
        out = Variable(inputs.cuda(), **kwargs)
    else:
        out = Variable(inputs, **kwargs)
    return out


def update_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


Node = collections.namedtuple('Node', ['id', 'name'])


class keydefaultdict(defaultdict):
    def __missing__(self, key):
        if self.default_factory is None:
            raise KeyError(key)
        else:
            ret = self[key] = self.default_factory(key)
            return ret


def to_item(x):
    """Converts x, possibly scalar and possibly tensor, to a Python scalar."""
    if isinstance(x, (float, int)):
        return x

    if float(torch.__version__[0:3]) < 0.4:
        assert (x.dim() == 1) and (len(x) == 1)
        return x[0]

    return x.item()
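Usage sketch (illustration only, not part of the committed diff): keydefaultdict passes the missing key to its factory, which is how the controller and shared model cache per-batch-size hidden states, and to_item normalizes scalars across PyTorch versions. The sizes below are arbitrary.

import torch
from fastNLP.models.enas_utils import keydefaultdict, to_item

hidden_cache = keydefaultdict(lambda batch_size: torch.zeros(batch_size, 4))
print(hidden_cache[3].shape)            # torch.Size([3, 4]), built on first access
print(to_item(torch.tensor([2.5])))     # 2.5 as a plain Python float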
112
test/models/test_enas.py
Normal file
@@ -0,0 +1,112 @@
import unittest

from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import Vocabulary
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.core.metrics import AccuracyMetric


class TestENAS(unittest.TestCase):
    def testENAS(self):
        # Read data from a csv file into a DataSet
        sample_path = "tutorials/sample_data/tutorial_sample_dataset.csv"
        dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'),
                                   sep='\t')
        print(len(dataset))
        print(dataset[0])
        print(dataset[-3])

        dataset.append(Instance(raw_sentence='fake data', label='0'))
        # Lowercase the raw sentences
        dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
        # Convert labels to int
        dataset.apply(lambda x: int(x['label']), new_field_name='label')

        # Split sentences on whitespace
        def split_sent(ins):
            return ins['raw_sentence'].split()

        dataset.apply(split_sent, new_field_name='words')

        # Add sequence-length information
        dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')
        print(len(dataset))
        print(dataset[0])

        # DataSet.drop(func) filters out instances
        dataset.drop(lambda x: x['seq_len'] <= 3)
        print(len(dataset))

        # Specify which fields of the DataSet should be converted to tensors
        # set target: the golden field used when computing the loss and evaluating the model
        dataset.set_target("label")
        # set input: fields used in the model's forward
        dataset.set_input("words", "seq_len")

        # Split into a test set and a training set
        test_data, train_data = dataset.split(0.5)
        print(len(test_data))
        print(len(train_data))

        # Build the vocabulary, Vocabulary.add(word)
        vocab = Vocabulary(min_freq=2)
        train_data.apply(lambda x: [vocab.add(word) for word in x['words']])
        vocab.build_vocab()

        # Index the sentences, Vocabulary.to_index(word)
        train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
        test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
        print(test_data[0])

        # These preprocessing utilities can also be used for projects such as
        # reinforcement learning or GANs
        from fastNLP.core.batch import Batch
        from fastNLP.core.sampler import RandomSampler

        batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())
        for batch_x, batch_y in batch_iterator:
            print("batch_x has: ", batch_x)
            print("batch_y has: ", batch_y)
            break

        from fastNLP.models.enas_model import ENASModel
        from fastNLP.models.enas_controller import Controller
        model = ENASModel(embed_num=len(vocab), num_classes=5)
        controller = Controller()

        from fastNLP.models.enas_trainer import ENASTrainer
        from copy import deepcopy

        # Rename the corresponding fields in the DataSet so that they match
        # the parameter names of the model's forward, etc.
        train_data.rename_field('words', 'word_seq')  # input field names must match forward's parameters
        train_data.rename_field('label', 'label_seq')
        test_data.rename_field('words', 'word_seq')
        test_data.rename_field('label', 'label_seq')

        loss = CrossEntropyLoss(pred="output", target="label_seq")
        metric = AccuracyMetric(pred="predict", target="label_seq")

        trainer = ENASTrainer(model=model, controller=controller, train_data=train_data, dev_data=test_data,
                              loss=CrossEntropyLoss(pred="output", target="label_seq"),
                              metrics=AccuracyMetric(pred="predict", target="label_seq"),
                              check_code_level=-1,
                              save_path=None,
                              batch_size=32,
                              print_every=1,
                              n_epochs=3,
                              final_epochs=1)
        trainer.train()
        print('Train finished!')

        # Use the Tester to evaluate on test_data
        from fastNLP import Tester

        tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"),
                        batch_size=4)

        acc = tester.test()
        print(acc)


if __name__ == '__main__':
    unittest.main()