add kmax pooling module

This commit is contained in:
Ke Zhen 2018-07-02 15:12:18 +08:00
parent 2569c85c8e
commit 7b7826544e
3 changed files with 20 additions and 0 deletions

View File

@ -0,0 +1,20 @@
# python: 3.6
# encoding: utf-8
import torch
import torch.nn as nn
# import torch.nn.functional as F
class KMaxPool(nn.Module):
"""K max-pooling module."""
def __init__(self, k):
super(KMaxPool, self).__init__()
self.k = k
def forward(self, x):
# [N,C,L] -> [N,C*k]
x, index = torch.topk(x, self.k, dim=-1, sorted=False)
x = torch.reshape(x, (x.size(0), -1))
return x