mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 03:07:59 +08:00
add kmax pooling module
This commit is contained in:
parent
2569c85c8e
commit
7b7826544e
20
fastNLP/modules/convolution/kmax_pool.py
Normal file
20
fastNLP/modules/convolution/kmax_pool.py
Normal file
@ -0,0 +1,20 @@
|
||||
# python: 3.6
|
||||
# encoding: utf-8
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
# import torch.nn.functional as F
|
||||
|
||||
|
||||
class KMaxPool(nn.Module):
|
||||
"""K max-pooling module."""
|
||||
|
||||
def __init__(self, k):
|
||||
super(KMaxPool, self).__init__()
|
||||
self.k = k
|
||||
|
||||
def forward(self, x):
|
||||
# [N,C,L] -> [N,C*k]
|
||||
x, index = torch.topk(x, self.k, dim=-1, sorted=False)
|
||||
x = torch.reshape(x, (x.size(0), -1))
|
||||
return x
|
Loading…
Reference in New Issue
Block a user