Here's a simple implementation of bilinear interpolation on tensors using PyTorch.

I wrote this up since I ended up learning a lot about options for interpolation in both the numpy and PyTorch ecosystems. Beyond interpolation itself, it's also a nice case study in how PyTorch can magically run very numpy-like code on the GPU (and, by the way, do autodiff for you too).

For interpolation in PyTorch, this [open issue](https://github.com/pytorch/pytorch/issues/1552) calls for more interpolation features. There is now a [`nn.functional.grid_sample()`](http://pytorch.org/docs/0.3.0/nn.html#torch.nn.functional.grid_sample) feature, but at least at first it didn't look like what I needed (we'll come back to it later).

In particular I wanted to take an image, `W x H x C`, and sample it many times at different random locations. Note also that this is different from [upsampling](http://pytorch.org/docs/master/nn.html#torch.nn.Upsample), which samples exhaustively and doesn't give us flexibility in the precision of the sampling.
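
To make that concrete, here's a sketch of the shape contract I'm after (the sizes and the `interpolate` name are just illustrative, not from any library):

```python
import numpy as np

# Hypothetical usage sketch: sample a W x H x C image at N fractional
# (x, y) locations and get back one interpolated C-vector per sample.
W, H, C, N = 640, 480, 3, 1000
image = np.random.randn(W, H, C)
xs = np.random.rand(N) * (W - 1) # fractional sample coordinates
ys = np.random.rand(N) * (H - 1)
# desired: interpolate(image, xs, ys) -> array of shape (N, C)
```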

## The implementations: numpy and PyTorch

First let's look at a comparable implementation in numpy, slightly modified from [here](https://stackoverflow.com/questions/12729228/simple-efficient-bilinear-interpolation-of-images-in-numpy-and-python).

```python
import numpy as np

def bilinear_interpolate_numpy(im, x, y):
    # Indices of the four neighboring grid points
    x0 = np.floor(x).astype(int)
    x1 = x0 + 1
    y0 = np.floor(y).astype(int)
    y1 = y0 + 1

    # Clip to the valid index range so border samples don't go out of bounds
    x0 = np.clip(x0, 0, im.shape[1]-1)
    x1 = np.clip(x1, 0, im.shape[1]-1)
    y0 = np.clip(y0, 0, im.shape[0]-1)
    y1 = np.clip(y1, 0, im.shape[0]-1)

    # Values at the four neighbors, each N x C
    Ia = im[ y0, x0 ]
    Ib = im[ y1, x0 ]
    Ic = im[ y0, x1 ]
    Id = im[ y1, x1 ]

    # Bilinear weights: areas of the opposite sub-rectangles
    wa = (x1-x) * (y1-y)
    wb = (x1-x) * (y-y0)
    wc = (x-x0) * (y1-y)
    wd = (x-x0) * (y-y0)

    # The transposes broadcast the length-N weights across the C channels
    return (Ia.T*wa).T + (Ib.T*wb).T + (Ic.T*wc).T + (Id.T*wd).T
```
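
As a quick sanity check (my own toy example, not from the original): sampling at the center of four pixels should return their average.

```python
# Hedged sketch: a 3x3 single-channel ramp image; sampling at (0.5, 0.5)
# should give the mean of the four surrounding pixels, (0+1+3+4)/4 = 2.0
img = np.arange(9, dtype=float).reshape(3, 3, 1)
print(bilinear_interpolate_numpy(img, np.array([0.5]), np.array([0.5]))) # [[ 2.]]
```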

And now here I've converted this implementation to PyTorch:

```python
import torch
dtype = torch.cuda.FloatTensor
dtype_long = torch.cuda.LongTensor

def bilinear_interpolate_torch(im, x, y):
    x0 = torch.floor(x).type(dtype_long)
    x1 = x0 + 1

    y0 = torch.floor(y).type(dtype_long)
    y1 = y0 + 1

    x0 = torch.clamp(x0, 0, im.shape[1]-1)
    x1 = torch.clamp(x1, 0, im.shape[1]-1)
    y0 = torch.clamp(y0, 0, im.shape[0]-1)
    y1 = torch.clamp(y1, 0, im.shape[0]-1)

    # x and y arrive as 1 x N tensors, so index and drop the leading dim
    Ia = im[ y0, x0 ][0]
    Ib = im[ y1, x0 ][0]
    Ic = im[ y0, x1 ][0]
    Id = im[ y1, x1 ][0]

    wa = (x1.type(dtype)-x) * (y1.type(dtype)-y)
    wb = (x1.type(dtype)-x) * (y-y0.type(dtype))
    wc = (x-x0.type(dtype)) * (y1.type(dtype)-y)
    wd = (x-x0.type(dtype)) * (y-y0.type(dtype))

    return torch.t(torch.t(Ia)*wa) + torch.t(torch.t(Ib)*wb) + torch.t(torch.t(Ic)*wc) + torch.t(torch.t(Id)*wd)
```
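
Since everything above is built from differentiable torch ops, gradients flow through the interpolation for free. A minimal sketch of that, assuming a newer PyTorch where tensors take `requires_grad` directly (on the 0.3-era API you'd wrap things in `Variable` instead):

```python
# Hedged sketch: gradients reach the image through the interpolation
im = torch.ones(5, 5, 1).type(dtype)
im.requires_grad = True # modern-API flag; use Variable(..., requires_grad=True) on 0.3
x = torch.FloatTensor([[3.2]]).type(dtype)
y = torch.FloatTensor([[3.4]]).type(dtype)

out = bilinear_interpolate_torch(im, x, y)
out.sum().backward()
print(im.grad[3:5, 3:5, 0]) # the four neighbors receive the bilinear weights
                            # (0.48, 0.12 / 0.32, 0.08)
```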

## Testing for correctness

Bilinear interpolation is very simple, but there are a few things that can easily be messed up.
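
For reference, the formula everything here implements: with `x0 = floor(x)`, `x1 = x0 + 1` (and likewise for `y`), the interpolated value is an area-weighted average of the four surrounding pixels (on a unit-spaced grid the four weights sum to 1):

```latex
f(x, y) \approx \underbrace{(x_1 - x)(y_1 - y)}_{w_a} I_{y_0,x_0}
              + \underbrace{(x_1 - x)(y - y_0)}_{w_b} I_{y_1,x_0}
              + \underbrace{(x - x_0)(y_1 - y)}_{w_c} I_{y_0,x_1}
              + \underbrace{(x - x_0)(y - y_0)}_{w_d} I_{y_1,x_1}
```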

I did a quick comparison for correctness with SciPy's `interp2d`.

- Side note: there are actually a ton of [interpolation options in SciPy](http://scipy.github.io/devdocs/interpolate.html#module-scipy.interpolate), but none I tested met my criteria of (a) doing _bilinear_ interpolation for high-dimensional spaces and (b) efficiently using gridded data. The ones built for many dimensions required me to specify sample points for all of those dimensions (and did trilinear, or other, interpolation). I could get `LinearNDInterpolator` to do bilinear interpolation for high-dimensional vectors, but it does not meet criterion (b); a sketch of why follows below. There's probably a better option but, at any rate, I gave up and went back to my numpy and PyTorch options.
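
For context, here's roughly what the `LinearNDInterpolator` route looks like, and why it fails criterion (b): the grid has to be flattened into scattered points, which triggers a Delaunay triangulation that gets expensive for image-sized grids. (This is my own sketch, not from the original.)

```python
import numpy as np
import scipy.interpolate

# Hedged sketch: grid points flattened into scattered samples
pts = np.array([[i, j] for i in range(5) for j in range(5)], dtype=float)
vals = np.random.randn(25, 7) # a 7-dim descriptor per grid point
f = scipy.interpolate.LinearNDInterpolator(pts, vals) # builds a triangulation
print(f(np.array([[3.2, 3.4]])).shape) # (1, 7)
```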

```python
# Also use scipy to check for correctness
import scipy.interpolate
def bilinear_interpolate_scipy(image, x, y):
    x_indices = np.arange(image.shape[0])
    y_indices = np.arange(image.shape[1])
    interp_func = scipy.interpolate.interp2d(x_indices, y_indices, image, kind='linear')
    return interp_func(x,y)

# Make small sample data that's easy to interpret
image = np.ones((5,5))
image[3,3] = 4
image[3,4] = 3

sample_x, sample_y = np.asarray([3.2]), np.asarray([3.4])

print "numpy result:", bilinear_interpolate_numpy(image, sample_x, sample_y)
print "scipy result:", bilinear_interpolate_scipy(image, sample_x, sample_y)

# Move the same data onto the GPU for the torch version
image = torch.unsqueeze(torch.FloatTensor(image).type(dtype),2)
sample_x = torch.FloatTensor([sample_x]).type(dtype)
sample_y = torch.FloatTensor([sample_y]).type(dtype)

print "torch result:", bilinear_interpolate_torch(image, sample_x, sample_y)
```

The above gives:

```
numpy result: [2.68]
scipy result: [2.68]
torch result:
2.6800
[torch.cuda.FloatTensor of size 1x1 (GPU 0)]
```

## High dimensional bilinear interpolation

For the correctness test comparing with scipy, we couldn't do `W x H x C` interpolation for anything but `C=1`. Now, though, we can do bilinear interpolation in either numpy or torch for arbitrary `C`:

```python
# Do high dimensional bilinear interpolation in numpy and PyTorch
W, H, C = 25, 25, 7
image = np.random.randn(W, H, C)

num_samples = 4
samples_x, samples_y = np.random.rand(num_samples)*(W-1), np.random.rand(num_samples)*(H-1)

print bilinear_interpolate_numpy(image, samples_x, samples_y)

image = torch.from_numpy(image).type(dtype)
samples_x = torch.FloatTensor([samples_x]).type(dtype)
samples_y = torch.FloatTensor([samples_y]).type(dtype)

print bilinear_interpolate_torch(image, samples_x, samples_y)
```

You'll find that the above numpy and torch versions give the same result.
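
To check that programmatically rather than by eye (my own addition; the loose `atol` absorbs float32-vs-float64 differences):

```python
# Hedged sketch: regenerate inputs so both versions see identical data
image_np = np.random.randn(W, H, C)
sx = np.random.rand(num_samples) * (W - 1)
sy = np.random.rand(num_samples) * (H - 1)

out_np = bilinear_interpolate_numpy(image_np, sx, sy)
out_torch = bilinear_interpolate_torch(
    torch.from_numpy(image_np).type(dtype),
    torch.FloatTensor([sx]).type(dtype),
    torch.FloatTensor([sy]).type(dtype))

print(np.allclose(out_np, out_torch.cpu().numpy(), atol=1e-5)) # expect True
```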

## Benchmarking: numpy (CPU) vs. PyTorch (CPU) vs. PyTorch (GPU)

Now we do some simple benchmarking:

```python
# Timing comparison for WxHxC (where C is large for a high dimensional descriptor)
W, H, C = 640, 480, 32
image = np.random.randn(W, H, C)

num_samples = 10000
samples_x, samples_y = np.random.rand(num_samples)*(W-1), np.random.rand(num_samples)*(H-1)

import time

start = time.time()
bilinear_interpolate_numpy(image, samples_x, samples_y)
print "numpy took ", time.time() - start

# Switch the globals to CPU tensor types for the torch-on-CPU run
dtype = torch.FloatTensor
dtype_long = torch.LongTensor
image = torch.FloatTensor(image).type(dtype)
samples_x = torch.FloatTensor([samples_x]).type(dtype)
samples_y = torch.FloatTensor([samples_y]).type(dtype)

start = time.time()
bilinear_interpolate_torch(image, samples_x, samples_y)
print "torch on CPU took", time.time() - start

# And back to CUDA tensor types for the GPU run
dtype = torch.cuda.FloatTensor
dtype_long = torch.cuda.LongTensor
image = image.type(dtype)
samples_x = samples_x.type(dtype)
samples_y = samples_y.type(dtype)

start = time.time()
bilinear_interpolate_torch(image, samples_x, samples_y)
print "torch on GPU took", time.time() - start
```
On my machine (CPU: 10-core i7-6950X, GPU: GTX 1080) I get the following times (in seconds):

```
numpy took 0.00756597518921
torch on CPU took 0.12672996521
torch on GPU took 0.000642061233521
```

Interestingly, torch on the GPU beats numpy (CPU-only) by about 10x here. I'm not sure why torch on the CPU is so slow for this test case. Note that the ratios between these change quite drastically for different `W, H, C, num_samples`; one way to explore that is sketched below.
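
Here's a sketch of a small sweep over `num_samples` (my own addition; the sizes are arbitrary). It calls `torch.cuda.synchronize()` around the GPU timing so asynchronous kernel launches don't flatter the GPU numbers:

```python
# Hedged sketch: sweep sample counts and compare numpy vs. torch on the GPU
image_np = np.random.randn(W, H, C)
image_gpu = torch.from_numpy(image_np).type(dtype)

for n in [100, 1000, 10000, 100000]:
    sx = np.random.rand(n) * (W - 1)
    sy = np.random.rand(n) * (H - 1)

    start = time.time()
    bilinear_interpolate_numpy(image_np, sx, sy)
    t_numpy = time.time() - start

    tx = torch.FloatTensor([sx]).type(dtype)
    ty = torch.FloatTensor([sy]).type(dtype)
    torch.cuda.synchronize() # don't count previously queued GPU work
    start = time.time()
    bilinear_interpolate_torch(image_gpu, tx, ty)
    torch.cuda.synchronize() # wait for the kernels to actually finish
    t_gpu = time.time() - start

    print("n=%6d numpy=%.5fs torch-gpu=%.5fs" % (n, t_numpy, t_gpu))
```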

## Using the available `nn.functional.grid_sample()`

I ended up figuring out how to use [`nn.functional.grid_sample()`](http://pytorch.org/docs/0.3.0/nn.html#torch.nn.functional.grid_sample), although it was a bit of an odd fit for my needs. (Data needs to be passed as an `N x C x W x H` tensor, samples need to be normalized to [-1, 1], and AFAIK the `W x H` ordering of the samples has no particular meaning -- they are completely separate samples.)

It was good practice in using `permute`, multiple `unsqueeze`s, and `cat`.

```python
import torch.nn.functional
dtype = torch.cuda.FloatTensor
dtype_long = torch.cuda.LongTensor

def bilinear_interpolate_torch_gridsample(image, samples_x, samples_y):
    # input image is: W x H x C
    W, H = image.shape[0], image.shape[1] # take the grid size from the image
                                          # itself, rather than from globals
    image = image.permute(2,0,1) # change to: C x W x H
    image = image.unsqueeze(0) # change to: 1 x C x W x H
    samples_x = samples_x.unsqueeze(2)
    samples_x = samples_x.unsqueeze(3)
    samples_y = samples_y.unsqueeze(2)
    samples_y = samples_y.unsqueeze(3)
    samples = torch.cat([samples_x, samples_y], 3)
    samples[:,:,:,0] = (samples[:,:,:,0]/(W-1)) # normalize to between 0 and 1
    samples[:,:,:,1] = (samples[:,:,:,1]/(H-1)) # normalize to between 0 and 1
    samples = samples*2-1 # normalize to between -1 and 1
    return torch.nn.functional.grid_sample(image, samples)

# Correctness test
W, H, C = 5, 5, 1
test_image = torch.ones(W,H,C).type(dtype)
test_image[3,3,:] = 4
test_image[3,4,:] = 3

test_samples_x = torch.FloatTensor([[3.2]]).type(dtype)
test_samples_y = torch.FloatTensor([[3.4]]).type(dtype)

print bilinear_interpolate_torch_gridsample(test_image, test_samples_x, test_samples_y)

# Benchmark (reuses the 640 x 480 x 32 image and samples from above)
start = time.time()
bilinear_interpolate_torch_gridsample(image, samples_x, samples_y)
print "torch gridsample took ", time.time() - start
```
My wrapping of `grid_sample` produces the same bilinear interpolation results, at speeds comparable to our `bilinear_interpolate_torch()` function:

```
Variable containing:
(0 ,0 ,.,.) =
2.6800
[torch.cuda.FloatTensor of size 1x1x1x1 (GPU 0)]
torch gridsample took 0.000624895095825
```
Another note about the `nn.functional.grid_sample()` interface: it forces the sampled interpolations into a `Variable()` wrapper, even though this seems best left to the programmer to decide.
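
If you want the raw tensor back, the 0.3-era API lets you unwrap it yourself; a one-liner sketch:

```python
# Hedged sketch: unwrap the Variable returned by the grid_sample wrapper
result = bilinear_interpolate_torch_gridsample(test_image, test_samples_x, test_samples_y)
raw = result.data # the underlying torch.cuda.FloatTensor, no autograd history
```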

## Conclusions

- It's surprisingly easy to convert powerful vectorized numpy code into more-powerful vectorized PyTorch code
- PyTorch is very fast on the GPU
- Some of the higher-level features (like `nn.functional.grid_sample()`) are nice, but so too is writing your own tensor manipulations (which can be comparably fast)