Created
May 30, 2019 10:42
-
-
Save DerThorsten/7117b9b7a41da4e0a13d6500f9a1b657 to your computer and use it in GitHub Desktop.
pytorch learnable gabor filter bank
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import math | |
| import numbers | |
| import torch | |
| import numpy | |
| from torch import nn | |
| from torch.nn import functional as F | |
| class GaborFilters(nn.Module): | |
| def __init__(self, | |
| in_channels, | |
| n_sigmas = 3, | |
| n_lambdas = 4, | |
| n_gammas = 1, | |
| n_thetas = 7, | |
| kernel_radius=15, | |
| rotation_invariant=True | |
| ): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| kernel_size = kernel_radius*2 + 1 | |
| self.kernel_size = kernel_size | |
| self.kernel_radius = kernel_radius | |
| self.n_thetas = n_thetas | |
| self.rotation_invariant = rotation_invariant | |
| def make_param(in_channels, values, requires_grad=True, dtype=None): | |
| if dtype is None: | |
| dtype = 'float32' | |
| values = numpy.require(values, dtype=dtype) | |
| n = in_channels * len(values) | |
| data=torch.from_numpy(values).view(1,-1) | |
| data = data.repeat(in_channels, 1) | |
| return torch.nn.Parameter(data=data, requires_grad=requires_grad) | |
| # build all learnable parameters | |
| self.sigmas = make_param(in_channels, 2**numpy.arange(n_sigmas)*2) | |
| self.lambdas = make_param(in_channels, 2**numpy.arange(n_lambdas)*4.0) | |
| self.gammas = make_param(in_channels, numpy.ones(n_gammas)*0.5) | |
| self.psis = make_param(in_channels, numpy.array([0, math.pi/2.0])) | |
| print(len(self.sigmas)) | |
| thetas = numpy.linspace(0.0, 2.0*math.pi, num=n_thetas, endpoint=False) | |
| thetas = torch.from_numpy(thetas).float() | |
| self.register_buffer('thetas', thetas) | |
| indices = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1)/2 | |
| self.register_buffer('indices', indices) | |
| # number of channels after the conv | |
| self._n_channels_post_conv = self.in_channels * self.sigmas.shape[1] * \ | |
| self.lambdas.shape[1] * self.gammas.shape[1] * \ | |
| self.psis.shape[1] * self.thetas.shape[0] | |
| def make_gabor_filters(self): | |
| sigmas=self.sigmas | |
| lambdas=self.lambdas | |
| gammas=self.gammas | |
| psis=self.psis | |
| thetas=self.thetas | |
| y=self.indices | |
| x=self.indices | |
| in_channels = sigmas.shape[0] | |
| assert in_channels == lambdas.shape[0] | |
| assert in_channels == gammas.shape[0] | |
| kernel_size = y.shape[0], x.shape[0] | |
| sigmas = sigmas.view (in_channels, sigmas.shape[1],1, 1, 1, 1, 1, 1) | |
| lambdas = lambdas.view(in_channels, 1, lambdas.shape[1],1, 1, 1, 1, 1) | |
| gammas = gammas.view (in_channels, 1, 1, gammas.shape[1], 1, 1, 1, 1) | |
| psis = psis.view (in_channels, 1, 1, 1, psis.shape[1], 1, 1, 1) | |
| thetas = thetas.view(1,1, 1, 1, 1, thetas.shape[0], 1, 1) | |
| y = y.view(1,1, 1, 1, 1, 1, y.shape[0], 1) | |
| x = x.view(1,1, 1, 1, 1, 1, 1, x.shape[0]) | |
| sigma_x = sigmas | |
| sigma_y = sigmas / gammas | |
| sin_t = torch.sin(thetas) | |
| cos_t = torch.cos(thetas) | |
| y_theta = -x * sin_t + y * cos_t | |
| x_theta = x * cos_t + y * sin_t | |
| gb = torch.exp(-.5 * (x_theta ** 2 / sigma_x ** 2 + y_theta ** 2 / sigma_y ** 2)) \ | |
| * torch.cos(2.0 * math.pi * x_theta / lambdas + psis) | |
| gb = gb.view(-1,kernel_size[0], kernel_size[1]) | |
| return gb | |
| def forward(self, x): | |
| batch_size = x.size(0) | |
| sy = x.size(2) | |
| sx = x.size(3) | |
| gb = self.make_gabor_filters() | |
| assert gb.shape[0] == self._n_channels_post_conv | |
| assert gb.shape[1] == self.kernel_size | |
| assert gb.shape[2] == self.kernel_size | |
| gb = gb.view(self._n_channels_post_conv,1,self.kernel_size,self.kernel_size) | |
| res = nn.functional.conv2d(input=x, weight=gb, | |
| padding=self.kernel_radius, groups=self.in_channels) | |
| if self.rotation_invariant: | |
| res = res.view(batch_size, self.in_channels, -1, self.n_thetas,sy, sx) | |
| res,_ = res.max(dim=3) | |
| res = res.view(batch_size, -1,sy, sx) | |
| return res | |
| if __name__ == "__main__": | |
| import pylab | |
| import skimage.data | |
| astronaut = skimage.data.astronaut() | |
| #astronaut[...,0] = astronaut[...,0].T | |
| astronaut = numpy.moveaxis(astronaut,-1,0)[None,...] | |
| astronaut = torch.from_numpy(astronaut).float() | |
| gb = GaborFilters(in_channels=3) | |
| res = gb(astronaut) | |
| print(res.shape) | |
| for c in range(res.size(1)): | |
| img = res[0,c,...] | |
| img = img.detach().numpy() | |
| fig = pylab.imshow(img) | |
| pylab.show() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Great to know that u have implemented this layer. I was trying to construct a very similar layer using tensor-flow . Have u seen any TF implementation of this to start with. IN case you have used this in your network, does this increase the training and inference time of the network greatly.