cancel restriction for base model

This commit is contained in:
FengZiYjun 2018-07-07 16:59:59 +08:00
parent cca276b8c0
commit 4c9c791304

View File

@ -3,31 +3,12 @@ import torch
class BaseModel(torch.nn.Module): class BaseModel(torch.nn.Module):
"""Base PyTorch model for all models. """Base PyTorch model for all models.
Three network modules presented: To do: add some useful common features
- encoder module
- aggregation module
- decoder module
Subclasses must implement these three modules with "components".
""" """
def __init__(self): def __init__(self):
super(BaseModel, self).__init__() super(BaseModel, self).__init__()
def forward(self, *inputs):
x = self.encode(*inputs)
x = self.aggregate(x)
x = self.decode(x)
return x
def encode(self, x):
raise NotImplementedError
def aggregate(self, x):
raise NotImplementedError
def decode(self, x):
raise NotImplementedError
class Vocabulary(object): class Vocabulary(object):
"""A look-up table that allows you to access `Lexeme` objects. The `Vocab` """A look-up table that allows you to access `Lexeme` objects. The `Vocab`