mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 04:07:35 +08:00
cancel restriction for base model
This commit is contained in:
parent
cca276b8c0
commit
4c9c791304
@ -3,31 +3,12 @@ import torch
|
|||||||
|
|
||||||
class BaseModel(torch.nn.Module):
|
class BaseModel(torch.nn.Module):
|
||||||
"""Base PyTorch model for all models.
|
"""Base PyTorch model for all models.
|
||||||
Three network modules presented:
|
To do: add some useful common features
|
||||||
- encoder module
|
|
||||||
- aggregation module
|
|
||||||
- decoder module
|
|
||||||
Subclasses must implement these three modules with "components".
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(BaseModel, self).__init__()
|
super(BaseModel, self).__init__()
|
||||||
|
|
||||||
def forward(self, *inputs):
|
|
||||||
x = self.encode(*inputs)
|
|
||||||
x = self.aggregate(x)
|
|
||||||
x = self.decode(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def aggregate(self, x):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class Vocabulary(object):
|
class Vocabulary(object):
|
||||||
"""A look-up table that allows you to access `Lexeme` objects. The `Vocab`
|
"""A look-up table that allows you to access `Lexeme` objects. The `Vocab`
|
||||||
|
Loading…
Reference in New Issue
Block a user