-
-
Save nimanzik/1a637b64ddfa953ca6b63374695f7de0 to your computer and use it in GitHub Desktop.
PyTorch MedianPool (MedianFilter)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.modules.utils import _pair, _quadruple | |
class MedianPool2d(nn.Module):
    """Median pooling over 2-D windows (a median filter when ``stride=1``).

    The input is optionally reflect-padded, cut into sliding windows over its
    last two dimensions with ``Tensor.unfold`` and reduced with
    ``Tensor.median``. Only existing PyTorch tensor ops are used, so autograd
    works without a custom backward.

    Args:
        kernel_size: size of the pooling window, int or 2-tuple (height, width).
        stride: step between windows, int or 2-tuple (height, width).
        padding: reflect padding, int or 4-tuple (left, right, top, bottom)
            in the order expected by ``torch.nn.functional.pad``. Ignored
            when ``same`` is true.
        same: if true, derive the padding from the input size on every call
            so that the output spatial size is ``ceil(input / stride)``.

    Note:
        For windows with an even number of elements ``Tensor.median`` does
        not average the two middle values; per the PyTorch documentation it
        returns the lower of the two.
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super().__init__()
        self.k = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _quadruple(padding)  # stored as (left, right, top, bottom)
        self.same = same

    @staticmethod
    def _same_pad_total(size, kernel, stride):
        """Return the total padding along one axis for "same" output size."""
        remainder = size % stride
        # When the stride divides the size evenly the last window starts at
        # ``size - stride``; otherwise it starts at ``size - remainder``.
        covered = stride if remainder == 0 else remainder
        return max(kernel - covered, 0)

    def _padding(self, x):
        """Return the ``(left, right, top, bottom)`` padding to apply to ``x``.

        With ``same`` disabled this is the padding given to the constructor.
        Otherwise it is computed from the last two dimensions of ``x``; when
        the total along an axis is odd, the extra element goes to the
        right / bottom side.
        """
        if not self.same:
            return self.padding

        ih, iw = x.shape[-2:]
        ph = self._same_pad_total(ih, self.k[0], self.stride[0])
        pw = self._same_pad_total(iw, self.k[1], self.stride[1])
        pl = pw // 2
        pt = ph // 2
        return (pl, pw - pl, pt, ph - pt)

    def forward(self, x):
        """Apply median pooling over the last two dimensions of ``x``.

        Args:
            x: tensor whose last two dimensions are (height, width), e.g.
                ``(N, C, H, W)`` or ``(C, H, W)``. When non-zero padding is
                applied, the rank must also be one that
                ``F.pad(..., mode='reflect')`` accepts for a 4-value pad.

        Returns:
            Tensor with the same leading dimensions as ``x`` and the pooled
            height and width as its last two dimensions.
        """
        # Using existing PyTorch functions and tensor ops so that we get
        # autograd; a dedicated C/CUDA kernel would likely be more efficient.
        pad = self._padding(x)
        if any(pad):
            # Skip the call entirely when there is nothing to pad so that the
            # unpadded path does not depend on reflect-pad rank/dtype support.
            x = F.pad(x, pad, mode='reflect')

        kh, kw = self.k
        sh, sw = self.stride
        h_dim = x.dim() - 2
        # Each unfold appends the window axis at the end, so the width axis
        # stays at ``h_dim + 1`` after unfolding the height axis.
        # Result shape: (..., out_h, out_w, kh, kw).
        windows = x.unfold(h_dim, kh, sh).unfold(h_dim + 1, kw, sw)
        # Merge the two window axes and take the median over each window;
        # ``median(dim=...)`` returns (values, indices).
        return windows.flatten(start_dim=-2).median(dim=-1)[0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment