Here's a simple implementation of bilinear interpolation on tensors using PyTorch. I wrote this up since I ended up learning a lot about options for interpolation in both the numpy and PyTorch ecosystems. More generally than just interpolation, it's also a nice case study in how PyTorch can magically put very numpy-like code on the GPU (and, by the way, do autodiff for you too).

For interpolation in PyTorch, this [open issue](https://github.com/pytorch/pytorch/issues/1552) calls for more interpolation features. There is now a [`nn.functional.grid_sample()`](http://pytorch.org/docs/0.3.0/nn.html#torch.nn.functional.grid_sample) feature, but at least at first this didn't look like what I needed (we'll come back to this later). In particular, I wanted to take an image, `W x H x C`, and sample it many times at different random locations. Note also that this is different from [upsampling](http://pytorch.org/docs/master/nn.html#torch.nn.Upsample), which samples exhaustively and also doesn't give us flexibility with the precision of sampling.

## The implementations: numpy and PyTorch

First let's look at a comparable implementation in numpy, which is slightly modified from [here](https://stackoverflow.com/questions/12729228/simple-efficient-bilinear-interpolation-of-images-in-numpy-and-python).
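For reference, both implementations below compute the standard bilinear formula. In their variable names, $x$ indexes columns, $y$ indexes rows, $x_0 = \lfloor x \rfloor$, $x_1 = x_0 + 1$ (likewise for $y$), and $I[\text{row}, \text{col}]$ is the image:

$$
\begin{aligned}
f(x, y) &\approx w_a\, I[y_0, x_0] + w_b\, I[y_1, x_0] + w_c\, I[y_0, x_1] + w_d\, I[y_1, x_1], \\
w_a &= (x_1 - x)(y_1 - y), \qquad w_b = (x_1 - x)(y - y_0), \\
w_c &= (x - x_0)(y_1 - y), \qquad w_d = (x - x_0)(y - y_0).
\end{aligned}
$$

Away from the image border the four weights sum to 1. At $(x, y) = (3.2, 3.4)$, the point used in the correctness test further down, they are $w_a = 0.48$, $w_b = 0.32$, $w_c = 0.12$ and $w_d = 0.08$.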
```python
import numpy as np

def bilinear_interpolate_numpy(im, x, y):
    # Integer coordinates of the four neighbors around each (x, y) sample
    x0 = np.floor(x).astype(int)
    x1 = x0 + 1
    y0 = np.floor(y).astype(int)
    y1 = y0 + 1

    # Keep neighbor indices inside the image (x indexes axis 1, y indexes axis 0)
    x0 = np.clip(x0, 0, im.shape[1]-1)
    x1 = np.clip(x1, 0, im.shape[1]-1)
    y0 = np.clip(y0, 0, im.shape[0]-1)
    y1 = np.clip(y1, 0, im.shape[0]-1)

    # Pixel values at the four neighbors: N x C (or just N for a 2-D image)
    Ia = im[ y0, x0 ]
    Ib = im[ y1, x0 ]
    Ic = im[ y0, x1 ]
    Id = im[ y1, x1 ]

    # Each neighbor is weighted by the area of the sub-rectangle opposite it
    wa = (x1-x) * (y1-y)
    wb = (x1-x) * (y-y0)
    wc = (x-x0) * (y1-y)
    wd = (x-x0) * (y-y0)

    # The transposes broadcast the per-sample weights across the channel axis
    return (Ia.T*wa).T + (Ib.T*wb).T + (Ic.T*wc).T + (Id.T*wd).T
```

And here I've converted this implementation to PyTorch:

```python
import torch

dtype = torch.cuda.FloatTensor
dtype_long = torch.cuda.LongTensor

def bilinear_interpolate_torch(im, x, y):
    # x and y are 1 x N tensors of sample locations; im is W x H x C
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1
    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1

    x0 = torch.clamp(x0, 0, im.shape[1]-1)
    x1 = torch.clamp(x1, 0, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-1)
    y1 = torch.clamp(y1, 0, im.shape[0]-1)

    # Indexing with 1 x N index tensors gives 1 x N x C; [0] drops the leading 1
    Ia = im[ y0, x0 ][0]
    Ib = im[ y1, x0 ][0]
    Ic = im[ y0, x1 ][0]
    Id = im[ y1, x1 ][0]

    wa = (x1.type(dtype)-x) * (y1.type(dtype)-y)
    wb = (x1.type(dtype)-x) * (y-y0.type(dtype))
    wc = (x-x0.type(dtype)) * (y1.type(dtype)-y)
    wd = (x-x0.type(dtype)) * (y-y0.type(dtype))

    # torch.t(Ia) is C x N, the 1 x N weights broadcast across it, then transpose back to N x C
    return (torch.t(torch.t(Ia)*wa) + torch.t(torch.t(Ib)*wb) +
            torch.t(torch.t(Ic)*wc) + torch.t(torch.t(Id)*wd))
```

## Testing for correctness

Bilinear interpolation is very simple, but there are a few things that can easily be messed up. I did a quick comparison for correctness with SciPy's `interp2d`.

- Side note: there are actually a ton of [interpolation options in SciPy](http://scipy.github.io/devdocs/interpolate.html#module-scipy.interpolate), but none I tested met my criteria of (a) doing _bilinear_ interpolation for high-dimensional spaces and (b) efficiently using gridded data. The ones I tested that were built for many dimensions required me to specify sample points for all of those dimensions (and then did trilinear, or other, interpolation).
  I could get `LinearNDInterpolator` to do bilinear interpolation for high-dimensional vectors, but this does not meet criterion (b). There's probably a better option but, at any rate, I gave up and went back to my numpy and PyTorch options.

```python
# Also use scipy to check for correctness
# (interp2d is deprecated and was removed in SciPy 1.14, so this needs an older SciPy)
import scipy.interpolate

def bilinear_interpolate_scipy(image, x, y):
    # interp2d expects x to index columns (axis 1) and y to index rows (axis 0)
    x_indices = np.arange(image.shape[1])
    y_indices = np.arange(image.shape[0])
    interp_func = scipy.interpolate.interp2d(x_indices, y_indices, image, kind='linear')
    return interp_func(x, y)

# Make small sample data that's easy to interpret
image = np.ones((5, 5))
image[3, 3] = 4
image[3, 4] = 3
sample_x, sample_y = np.asarray([3.2]), np.asarray([3.4])

print("numpy result:", bilinear_interpolate_numpy(image, sample_x, sample_y))
print("scipy result:", bilinear_interpolate_scipy(image, sample_x, sample_y))

image = torch.unsqueeze(torch.from_numpy(image).type(dtype), 2)  # 5 x 5 x 1
sample_x = torch.from_numpy(sample_x).unsqueeze(0).type(dtype)    # 1 x N, as the function expects
sample_y = torch.from_numpy(sample_y).unsqueeze(0).type(dtype)
print("torch result:", bilinear_interpolate_torch(image, sample_x, sample_y))
```

The above gives:

```
numpy result: [2.68]
scipy result: [2.68]
torch result:
 2.6800
[torch.cuda.FloatTensor of size 1x1 (GPU 0)]
```

## High-dimensional bilinear interpolation

For the correctness test comparing with scipy, we couldn't do `W x H x C` interpolation for anything but `C=1`.
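One way to still check the multi-channel case is against a deliberately naive per-sample loop. The sketch below uses `bilinear_interpolate_loop`, a hypothetical helper (not from any library), that applies the same formula one sample at a time. It assumes, like the functions above, that `x` indexes columns and `y` indexes rows, and additionally that every sample lies inside the image:

```python
import math

import numpy as np

def bilinear_interpolate_loop(im, xs, ys):
    """Naive per-sample bilinear interpolation (hypothetical reference helper).

    im: array of shape (rows, cols) or (rows, cols, C).
    xs index columns (axis 1), ys index rows (axis 0); all samples are assumed
    to satisfy 0 <= x <= cols - 1 and 0 <= y <= rows - 1.
    """
    rows, cols = im.shape[:2]
    out = []
    for x, y in zip(xs, ys):
        x0, y0 = int(math.floor(x)), int(math.floor(y))
        # Clamp the far corner so samples on the last row/column stay in bounds
        x1, y1 = min(x0 + 1, cols - 1), min(y0 + 1, rows - 1)
        tx, ty = x - x0, y - y0  # fractional offsets in [0, 1)
        top = (1 - tx) * im[y0, x0] + tx * im[y0, x1]
        bottom = (1 - tx) * im[y1, x0] + tx * im[y1, x1]
        out.append((1 - ty) * top + ty * bottom)
    return np.array(out)

# Same 5 x 5 test image as above
image = np.ones((5, 5))
image[3, 3] = 4
image[3, 4] = 3
print(bilinear_interpolate_loop(image, [3.2], [3.4]))  # approximately [2.68]
```

For samples inside the image, comparing this with `np.allclose` against `bilinear_interpolate_numpy` on a random multi-channel image should agree up to floating-point error. It is far too slow for real use; it is only meant as a readable cross-check.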
Now though, we can do bilinear interpolation in either numpy or torch for arbitrary `C`:

```python
# Do high-dimensional bilinear interpolation in numpy and PyTorch
W, H, C = 25, 25, 7
image = np.random.randn(W, H, C)
num_samples = 4
# x indexes axis 1 (size H) and y indexes axis 0 (size W), as in the functions above
samples_x, samples_y = np.random.rand(num_samples)*(H-1), np.random.rand(num_samples)*(W-1)
print(bilinear_interpolate_numpy(image, samples_x, samples_y))

image = torch.from_numpy(image).type(dtype)
samples_x = torch.from_numpy(samples_x).unsqueeze(0).type(dtype)  # 1 x num_samples
samples_y = torch.from_numpy(samples_y).unsqueeze(0).type(dtype)
print(bilinear_interpolate_torch(image, samples_x, samples_y))
```

You'll find that the above numpy and torch versions give the same result.

## Benchmarking: numpy (CPU) vs. PyTorch (CPU) vs. PyTorch (GPU)

Now we do some simple benchmarking:

```python
# Timing comparison for W x H x C (where C is large for a high-dimensional descriptor)
import time

W, H, C = 640, 480, 32
image = np.random.randn(W, H, C)
num_samples = 10000
# x indexes axis 1 (size H) and y indexes axis 0 (size W)
samples_x, samples_y = np.random.rand(num_samples)*(H-1), np.random.rand(num_samples)*(W-1)

start = time.time()
bilinear_interpolate_numpy(image, samples_x, samples_y)
print("numpy took", time.time() - start)

# bilinear_interpolate_torch reads these module-level types, so switch them to CPU
dtype = torch.FloatTensor
dtype_long = torch.LongTensor

image = torch.from_numpy(image).type(dtype)
samples_x = torch.from_numpy(samples_x).unsqueeze(0).type(dtype)
samples_y = torch.from_numpy(samples_y).unsqueeze(0).type(dtype)

start = time.time()
bilinear_interpolate_torch(image, samples_x, samples_y)
print("torch on CPU took", time.time() - start)

# ...and back to the GPU
dtype = torch.cuda.FloatTensor
dtype_long = torch.cuda.LongTensor

image = image.type(dtype)
samples_x = samples_x.type(dtype)
samples_y = samples_y.type(dtype)

start = time.time()
bilinear_interpolate_torch(image, samples_x, samples_y)
print("torch on GPU took", time.time() - start)
```

On my machine (CPU: 10-core i7-6950X, GPU: GTX 1080) I get the following times (in seconds):

```
numpy took 0.00756597518921
torch on CPU took 0.12672996521
torch on GPU took 0.000642061233521
```
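One caveat when reading these numbers: CUDA kernel launches in PyTorch are asynchronous, so `time.time()` around a GPU call can stop before the kernels have finished, and the first CUDA work in a process can also pay one-off initialization costs. The single-shot GPU timing above may therefore partly reflect launch overhead rather than the full computation. A minimal sketch of a fairer measurement, assuming a hypothetical helper named `time_call` that warms up, synchronizes before stopping the clock, and reports the median of several runs:

```python
import statistics
import time

import numpy as np

def time_call(fn, *args, sync=None, warmup=1, repeats=5):
    """Median wall-clock seconds for fn(*args) (hypothetical helper).

    sync: optional zero-argument callable that blocks until queued asynchronous
    work is done, e.g. torch.cuda.synchronize when fn launches CUDA kernels.
    """
    for _ in range(warmup):
        fn(*args)  # untimed warm-up absorbs one-off setup costs
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()  # wait for asynchronous work before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.median(times)

# CPU-only usage example; no sync needed because numpy runs synchronously
a = np.random.rand(1000, 1000)
print("numpy sum took", time_call(np.sum, a))
```

With the GPU tensors from the benchmark above, the GPU measurement would become `time_call(bilinear_interpolate_torch, image, samples_x, samples_y, sync=torch.cuda.synchronize)`.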
Interestingly, torch on the GPU beats numpy (CPU-only) by about 10x. I'm not sure why torch on the CPU is that slow for this test case. Note that the ratios between these change quite drastically for different `W, H, C, num_samples`.

## Using the available `nn.functional.grid_sample()`

I ended up figuring out how to use [`nn.functional.grid_sample()`](http://pytorch.org/docs/0.3.0/nn.html#torch.nn.functional.grid_sample), although it was a bit of an odd fit for my needs. (The data needs to be a 4-D `N x C x W x H` tensor, the sample locations need to be normalized to [-1, 1], and AFAIK the `W x H` arrangement of the sample grid has no meaning here: the samples are completely separate.) It was good practice in using `permute`, multiple `unsqueeze`s, and `cat`.

```python
import time
import torch
import torch.nn.functional

dtype = torch.cuda.FloatTensor
dtype_long = torch.cuda.LongTensor

def bilinear_interpolate_torch_gridsample(image, samples_x, samples_y):
    # input image is: W x H x C
    image = image.permute(2, 0, 1)  # change to: C x W x H
    image = image.unsqueeze(0)      # change to: 1 x C x W x H
    # samples come in as 1 x N; grid_sample wants a 1 x N x 1 x 2 grid of (x, y) pairs
    samples_x = samples_x.unsqueeze(2).unsqueeze(3)
    samples_y = samples_y.unsqueeze(2).unsqueeze(3)
    samples = torch.cat([samples_x, samples_y], 3)
    # grid x runs along the last image axis and grid y along the one before it,
    # so normalize by those axis sizes rather than by global W, H
    samples[:, :, :, 0] = samples[:, :, :, 0] / (image.shape[3] - 1)  # normalize to between 0 and 1
    samples[:, :, :, 1] = samples[:, :, :, 1] / (image.shape[2] - 1)  # normalize to between 0 and 1
    samples = samples * 2 - 1  # normalize to between -1 and 1
    # align_corners=True (PyTorch >= 1.3) maps -1/1 to the corner pixel centers, matching
    # older PyTorch behavior; newer versions default to align_corners=False
    return torch.nn.functional.grid_sample(image, samples, align_corners=True)

# Correctness test
W, H, C = 5, 5, 1
test_image = torch.ones(W, H, C).type(dtype)
test_image[3, 3, :] = 4
test_image[3, 4, :] = 3

test_samples_x = torch.FloatTensor([[3.2]]).type(dtype)
test_samples_y = torch.FloatTensor([[3.4]]).type(dtype)

print(bilinear_interpolate_torch_gridsample(test_image, test_samples_x, test_samples_y))

# Benchmark (reusing image, samples_x, samples_y from the GPU benchmark above)
start = time.time()
bilinear_interpolate_torch_gridsample(image, samples_x, samples_y)
print("torch gridsample took", time.time() - start)
```

My wrapping of `grid_sample` produces the same bilinear interpolation results, at speeds comparable to our `bilinear_interpolate_torch()` function:

```
Variable containing:
(0 ,0 ,.,.) =
  2.6800
[torch.cuda.FloatTensor of size 1x1x1x1 (GPU 0)]

torch gridsample took 0.000624895095825
```

Another note about the `nn.functional.grid_sample()` interface is that it forces the sampled interpolations into a `Variable()` wrapper, even though this seems best left to the programmer to decide.

## Conclusions

- It's surprisingly easy to convert powerful vectorized numpy code into more powerful vectorized PyTorch code
- PyTorch is very fast on the GPU
- Some of the higher-level features (like `nn.functional.grid_sample`) are nice, but so is writing your own tensor manipulations, which can be comparably fast
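On the autodiff point from the intro: since PyTorch 0.4, `Variable` has been merged into `Tensor`, so nothing needs wrapping, and marking the sample coordinates with `requires_grad=True` is enough to get gradients of the interpolated values with respect to the sample locations. Below is a minimal CPU sketch of that. `bilinear_interpolate` is a hypothetical rewrite of `bilinear_interpolate_torch` that takes 1-D coordinate tensors and derives dtypes from its inputs instead of module-level globals; it uses the same weights and clamping:

```python
import torch

def bilinear_interpolate(im, x, y):
    # im: rows x cols x C; x (columns) and y (rows): 1-D float tensors of length N
    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1, y1 = x0 + 1, y0 + 1
    # Clamp neighbor indices to the image, as the original functions do
    x0 = x0.clamp(0, im.shape[1] - 1)
    x1 = x1.clamp(0, im.shape[1] - 1)
    y0 = y0.clamp(0, im.shape[0] - 1)
    y1 = y1.clamp(0, im.shape[0] - 1)
    # Integer corners carry no gradient; gradients reach x and y through the weights
    x0f, x1f = x0.to(x.dtype), x1.to(x.dtype)
    y0f, y1f = y0.to(y.dtype), y1.to(y.dtype)
    wa = (x1f - x) * (y1f - y)
    wb = (x1f - x) * (y - y0f)
    wc = (x - x0f) * (y1f - y)
    wd = (x - x0f) * (y - y0f)
    # im[y, x] is N x C; weights become N x 1 so they broadcast over channels
    return (im[y0, x0] * wa[:, None] + im[y1, x0] * wb[:, None]
            + im[y0, x1] * wc[:, None] + im[y1, x1] * wd[:, None])

# Same 5 x 5 x 1 test image as above
image = torch.ones(5, 5, 1)
image[3, 3] = 4
image[3, 4] = 3
x = torch.tensor([3.2], requires_grad=True)
y = torch.tensor([3.4], requires_grad=True)
out = bilinear_interpolate(image, x, y)
out.sum().backward()
print(out.item(), x.grad.item(), y.grad.item())  # approximately 2.68 -0.6 -2.8
```

The gradient with respect to `x` is the row-weighted difference between the right and left neighbors, here $0.6 \cdot (3 - 4) + 0.4 \cdot (1 - 1) = -0.6$, which is what autograd returns.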