Merge pull request #8 from keezen/master

rename and add kmax pooling module
Coet 2018-07-02 19:27:30 +08:00 committed by GitHub
commit a99895223c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 73 additions and 48 deletions


@@ -1,22 +0,0 @@
-# python: 3.6
-# encoding: utf-8
-
-import torch.nn as nn
-# import torch.nn.functional as F
-
-
-class AvgPool1d(nn.Module):
-    """1-d average pooling module."""
-
-    def __init__(self, kernel_size, stride=None, padding=0,
-                 ceil_mode=False, count_include_pad=True):
-        super(AvgPool1d, self).__init__()
-        self.pool = nn.AvgPool1d(
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            ceil_mode=ceil_mode,
-            count_include_pad=count_include_pad)
-
-    def forward(self, x):
-        return self.pool(x)


@@ -1,23 +0,0 @@
-# python: 3.6
-# encoding: utf-8
-
-import torch.nn as nn
-# import torch.nn.functional as F
-
-
-class MaxPool1d(nn.Module):
-    """1-d max-pooling module."""
-
-    def __init__(self, kernel_size, stride=None, padding=0,
-                 dilation=1, return_indices=False, ceil_mode=False):
-        super(MaxPool1d, self).__init__()
-        self.maxpool = nn.MaxPool1d(
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            return_indices=return_indices,
-            ceil_mode=ceil_mode)
-
-    def forward(self, x):
-        return self.maxpool(x)


@@ -0,0 +1,24 @@
+# python: 3.6
+# encoding: utf-8
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AvgPool(nn.Module):
+    """1-d average pooling module."""
+
+    def __init__(self, stride=None, padding=0):
+        super(AvgPool, self).__init__()
+        self.stride = stride
+        self.padding = padding
+
+    def forward(self, x):
+        # [N,C,L] -> [N,C]: average over the full length dimension
+        kernel_size = x.size(2)
+        x = F.avg_pool1d(
+            input=x,
+            kernel_size=kernel_size,
+            stride=self.stride,
+            padding=self.padding)
+        return x.squeeze(dim=-1)
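
For reference, a minimal usage sketch of the new global average pooling module (shapes and values are illustrative, not part of this commit):

    import torch

    pool = AvgPool()
    x = torch.randn(8, 16, 50)  # [N=8, C=16, L=50]
    y = pool(x)                 # pools over the whole length L
    print(y.shape)              # torch.Size([8, 16])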


@@ -5,7 +5,7 @@ import torch.nn as nn
 # import torch.nn.functional as F
 
 
-class Conv1d(nn.Module):
+class Conv(nn.Module):
     """
     Basic 1-d convolution module.
     """
@@ -13,7 +13,7 @@ class Conv1d(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size,
                  stride=1, padding=0, dilation=1,
                  groups=1, bias=True):
-        super(Conv1d, self).__init__()
+        super(Conv, self).__init__()
         self.conv = nn.Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
@@ -25,4 +25,4 @@
             bias=bias)
 
     def forward(self, x):
-        return self.conv(x)
+        return self.conv(x)  # [N,C,L]
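
A minimal sketch of the renamed Conv module in use (illustrative values; assumes the class is importable):

    import torch

    conv = Conv(in_channels=16, out_channels=32, kernel_size=3, padding=1)
    x = torch.randn(8, 16, 50)  # [N, C_in, L]
    y = conv(x)                 # [N, C_out, L]; padding=1 keeps L=50 for kernel_size=3
    print(y.shape)              # torch.Size([8, 32, 50])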


@@ -0,0 +1,20 @@
+# python: 3.6
+# encoding: utf-8
+
+import torch
+import torch.nn as nn
+# import torch.nn.functional as F
+
+
+class KMaxPool(nn.Module):
+    """K max-pooling module."""
+
+    def __init__(self, k):
+        super(KMaxPool, self).__init__()
+        self.k = k
+
+    def forward(self, x):
+        # [N,C,L] -> [N,C*k]
+        x, index = torch.topk(x, self.k, dim=-1, sorted=False)
+        x = torch.reshape(x, (x.size(0), -1))
+        return x
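
A minimal sketch of KMaxPool (illustrative): it keeps the k largest values per channel along the length dimension and flattens them, so the output width is C*k.

    import torch

    kmax = KMaxPool(k=3)
    x = torch.randn(8, 16, 50)  # [N, C, L]
    y = kmax(x)                 # 3 largest values per channel, flattened
    print(y.shape)              # torch.Size([8, 48]), i.e. C*k = 16*3

Note that with sorted=False, torch.topk does not guarantee the selected values appear in their original sequence order; a k-max pooling that must preserve positional order would sort the returned indices and gather by them.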


@@ -0,0 +1,26 @@
+# python: 3.6
+# encoding: utf-8
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MaxPool(nn.Module):
+    """1-d max-pooling module."""
+
+    def __init__(self, stride=None, padding=0, dilation=1):
+        super(MaxPool, self).__init__()
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+
+    def forward(self, x):
+        # [N,C,L] -> [N,C]
+        kernel_size = x.size(2)
+        x = F.max_pool1d(
+            input=x,
+            kernel_size=kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation)
+        return x.squeeze(dim=-1)
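
And a matching sketch for the new global max-pooling module (illustrative):

    import torch

    pool = MaxPool()
    x = torch.randn(8, 16, 50)  # [N, C, L]
    y = pool(x)                 # max over the whole length L
    print(y.shape)              # torch.Size([8, 16])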