Skip to content

Instantly share code, notes, and snippets.

@soravux
Created November 23, 2017 00:39
Show Gist options
  • Save soravux/18a13892e700b57a964a484e97c314ed to your computer and use it in GitHub Desktop.
Save soravux/18a13892e700b57a964a484e97c314ed to your computer and use it in GitHub Desktop.

Revisions

  1. soravux created this gist Nov 23, 2017.
    73 changes: 73 additions & 0 deletions match_wb.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,73 @@
    import numpy as np
    from scipy.misc import imread, imsave

    from matplotlib import pyplot as plt


    # Get a linear version of the images
    target_im = (imread("lighting_gt.png").astype('float32') / 255.)**2.2 # Poor man un-sRGB
    source_im = (imread("warped_0150.png").astype('float32') / 255.)**2.2

    # This will not work robustly
    # # Get a pixel that *seems* white
    # target_luminance = target_im.dot([0.299, 0.587, 0.114])
    # source_luminance = source_im.dot([0.299, 0.587, 0.114])

    # target_white_gray = np.percentile(target_luminance[target_luminance < 1.], 80)
    # src_white_gray = np.percentile(source_luminance[source_luminance < 1.], 80)

    # target_idx_h, target_idx_w = np.where(target_luminance == target_white_gray)
    # src_idx_h, src_idx_w = np.where(source_luminance == src_white_gray)


    mouse_x, mouse_y = -1, -1
    def onclick(event):
    global mouse_x, mouse_y
    mouse_x = event.xdata
    mouse_y = event.ydata

    # Ask for point in target image
    fig = plt.figure()
    plt.imshow(target_im**(1./2.2))
    cid = fig.canvas.mpl_connect('button_press_event', onclick)
    plt.show()

    target_idx_h = int(mouse_y)
    target_idx_w = int(mouse_x)


    # Ask for point in source image
    fig = plt.figure()
    plt.imshow(source_im**(1./2.2))
    cid = fig.canvas.mpl_connect('button_press_event', onclick)
    plt.show()

    src_idx_h = int(mouse_y)
    src_idx_w = int(mouse_x)


    # Get area of 5x5 pixels around the selected positions
    target_white = np.mean(target_im[target_idx_h-2:target_idx_h+3, target_idx_w-2:target_idx_w+3, :], axis=(0, 1))
    src_white = np.mean(source_im[src_idx_h-2:src_idx_h+3, src_idx_w-2:src_idx_w+3, :], axis=(0, 1))

    # Normalize the colors (we don't want to change brightness, just color correction)
    target_white /= target_white.sum()
    src_white /= src_white.sum()

    print("target white:", target_white)
    print("source white:", src_white)

    # Apply white balance correction
    fix_matrix = np.diag(target_white/src_white)
    source_fixed = source_im.dot(fix_matrix)
    print("Correction matrix:\n", fix_matrix)


    plt.subplot(221); plt.imshow(target_im**(1./2.2)); plt.scatter(target_idx_w, target_idx_h, s=1); plt.axis('off')
    plt.subplot(222); plt.imshow(source_im**(1./2.2)); plt.scatter(src_idx_w, src_idx_h, s=1); plt.axis('off')
    plt.subplot(223); plt.imshow(np.clip(source_fixed, 0, 1)**(1./2.2)); plt.axis('off')
    plt.show()

    # Save back as sRGB
    source_fixed_sRGB = np.clip((255.*source_fixed**(1./2.2)), 0, 255).astype('uint8')
    imsave("warped_0150_fixed.png", source_fixed_sRGB)