Skip to content

Instantly share code, notes, and snippets.

@alexlib
Created April 26, 2025 06:38
Show Gist options
  • Save alexlib/15a2334103f42dd0edbf11c7071f3bc4 to your computer and use it in GitHub Desktop.
Save alexlib/15a2334103f42dd0edbf11c7071f3bc4 to your computer and use it in GitHub Desktop.

Revisions

  1. alexlib created this gist Apr 26, 2025.
    14 changes: 14 additions & 0 deletions Readme.md
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,14 @@
    ```
    test_inpaint_equivalence.py
    ..
    ----------------------------------------------------------------------
    Ran 2 tests in 1.908s
    OK
    (.conda) user@user-NUC8i7BEH:~/Documents/repos/OpenOpticalFlow_PIV_v1$ /home/user/Documents/repos/OpenOpticalFlow_PIV_v1/.conda/bin/python /home/user/Documents/repos/OpenOpticalFlow_PIV_v1/test_inpaint_equivalence.py
    ..
    ----------------------------------------------------------------------
    Ran 2 tests in 1.504s
    OK
    ```
    19 changes: 19 additions & 0 deletions inpaint3d.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,19 @@
    import numpy as np
    from scipy.interpolate import griddata

    def inpaint_nans_3d(array):
    # Get coordinates of non-nan values
    valid_mask = ~np.isnan(array)
    coords = np.array(np.nonzero(valid_mask)).T
    values = array[valid_mask]

    # Get coordinates of nan values
    nan_coords = np.array(np.nonzero(~valid_mask)).T

    # Interpolate
    filled_values = griddata(coords, values, nan_coords, method='linear')

    # Create output array and fill with interpolated values
    result = array.copy()
    result[~valid_mask] = filled_values
    return result
    287 changes: 287 additions & 0 deletions inpaint_nans3.m
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,287 @@
    function B=inpaint_nans3(A,method)
    % INPAINT_NANS3: in-paints over nans in a 3-D array
    % usage: B=INPAINT_NANS3(A) % default method (0)
    % usage: B=INPAINT_NANS3(A,method) % specify method used
    %
    % Solves approximation to a boundary value problem to
    % interpolate and extrapolate holes in a 3-D array.
    %
    % Note that if the array is large, and there are many NaNs
    % to be filled in, this may take a long time, or run into
    % memory problems.
    %
    % arguments (input):
    % A - n1 x n2 x n3 array with some NaNs to be filled in
    %
    % method - (OPTIONAL) scalar numeric flag - specifies
    % which approach (or physical metaphor to use
    % for the interpolation.) All methods are capable
    % of extrapolation, some are better than others.
    % There are also speed differences, as well as
    % accuracy differences for smooth surfaces.
    %
    % method 0 uses a simple plate metaphor.
    % method 1 uses a spring metaphor.
    %
    % method == 0 --> (DEFAULT) Solves the Laplacian
    % equation over the set of nan elements in the
    % array.
    % Extrapolation behavior is roughly linear.
    %
    % method == 1 --+ Uses a spring metaphor. Assumes
    % springs (with a nominal length of zero)
    % connect each node with every neighbor
    % (horizontally, vertically and diagonally)
    % Since each node tries to be like its neighbors,
    % extrapolation is roughly a constant function where
    % this is consistent with the neighboring nodes.
    %
    % There are only two different methods in this code,
    % chosen as the most useful ones (IMHO) from my
    % original inpaint_nans code.
    %
    %
    % arguments (output):
    % B - n1xn2xn3 array with NaNs replaced
    %
    %
    % Example:
    % % A linear function of 3 independent variables,
    % % used to test whether inpainting will interpolate
    % % the missing elements correctly.
    % [x,y,z] = ndgrid(-10:10,-10:10,-10:10);
    % W = x + y + z;
    %
    % % Pick a set of distinct random elements to NaN out.
    % ind = unique(ceil(rand(3000,1)*numel(W)));
    % Wnan = W;
    % Wnan(ind) = NaN;
    %
    % % Do inpainting
    % Winp = inpaint_nans3(Wnan,0);
    %
    % % Show that the inpainted values are essentially
    % % within eps of the originals.
    % std(Winp(ind) - W(ind))
    % ans =
    % 4.3806e-15
    %
    %
    % See also: griddatan, inpaint_nans
    %
    % Author: John D'Errico
    % e-mail address: [email protected]
    % Release: 1
    % Release date: 8/21/08

    % Need to know which elements are NaN, and
    % what size is the array. Unroll A for the
    % inpainting, although inpainting will be done
    % fully in 3-d.
    NA = size(A);
    A = A(:);
    nt = prod(NA);
    k = isnan(A(:));

    % list the nodes which are known, and which will
    % be interpolated
    nan_list=find(k);
    known_list=find(~k);

    % how many nans overall
    nan_count=length(nan_list);

    % convert NaN indices to (r,c) form
    % nan_list==find(k) are the unrolled (linear) indices
    % (row,column) form
    [n1,n2,n3]=ind2sub(NA,nan_list);

    % both forms of index for all the nan elements in one array:
    % column 1 == unrolled index
    % column 2 == index 1
    % column 3 == index 2
    % column 4 == index 3
    nan_list=[nan_list,n1,n2,n3];

    % supply default method
    if (nargin<2) || isempty(method)
    method = 0;
    elseif ~ismember(method,[0 1])
    error 'If supplied, method must be one of: {0,1}.'
    end

    % alternative methods
    switch method
    case 0
    % The same as method == 1, except only work on those
    % elements which are NaN, or at least touch a NaN.

    % horizontal and vertical neighbors only
    talks_to = [-1 0 0;1 0 0;0 -1 0;0 1 0;0 0 -1;0 0 1];
    neighbors_list=identify_neighbors(NA,nan_list,talks_to);

    % list of all nodes we have identified
    all_list=[nan_list;neighbors_list];

    % generate sparse array with second partials on row
    % variable for each element in either list, but only
    % for those nodes which have a row index > 1 or < n
    L = find((all_list(:,2) > 1) & (all_list(:,2) < NA(1)));
    nL=length(L);
    if nL>0
    fda=sparse(repmat(all_list(L,1),1,3), ...
    repmat(all_list(L,1),1,3)+repmat([-1 0 1],nL,1), ...
    repmat([1 -2 1],nL,1),nt,nt);
    else
    fda=spalloc(nt,nt,size(all_list,1)*7);
    end

    % 2nd partials on column index
    L = find((all_list(:,3) > 1) & (all_list(:,3) < NA(2)));
    nL=length(L);
    if nL>0
    fda=fda+sparse(repmat(all_list(L,1),1,3), ...
    repmat(all_list(L,1),1,3)+repmat([-NA(1) 0 NA(1)],nL,1), ...
    repmat([1 -2 1],nL,1),nt,nt);
    end

    % 2nd partials on third index
    L = find((all_list(:,4) > 1) & (all_list(:,4) < NA(3)));
    nL=length(L);
    if nL>0
    ntimesm = NA(1)*NA(2);
    fda=fda+sparse(repmat(all_list(L,1),1,3), ...
    repmat(all_list(L,1),1,3)+repmat([-ntimesm 0 ntimesm],nL,1), ...
    repmat([1 -2 1],nL,1),nt,nt);
    end

    % eliminate knowns
    rhs=-fda(:,known_list)*A(known_list);
    k=find(any(fda(:,nan_list(:,1)),2));

    % and solve...
    B=A;
    B(nan_list(:,1))=fda(k,nan_list(:,1))\rhs(k);

    case 1
    % Spring analogy
    % interpolating operator.

    % list of all springs between a node and a horizontal
    % or vertical neighbor
    hv_list=[-1 -1 0 0;1 1 0 0;-NA(1) 0 -1 0;NA(1) 0 1 0; ...
    -NA(1)*NA(2) 0 0 -1;NA(1)*NA(2) 0 0 1];
    hv_springs=[];
    for i=1:size(hv_list,1)
    hvs=nan_list+repmat(hv_list(i,:),nan_count,1);
    k=(hvs(:,2)>=1) & (hvs(:,2)<=NA(1)) & ...
    (hvs(:,3)>=1) & (hvs(:,3)<=NA(2)) & ...
    (hvs(:,4)>=1) & (hvs(:,4)<=NA(3));
    hv_springs=[hv_springs;[nan_list(k,1),hvs(k,1)]];
    end

    % delete replicate springs
    hv_springs=unique(sort(hv_springs,2),'rows');

    % build sparse matrix of connections
    nhv=size(hv_springs,1);
    springs=sparse(repmat((1:nhv)',1,2),hv_springs, ...
    repmat([1 -1],nhv,1),nhv,prod(NA));

    % eliminate knowns
    rhs=-springs(:,known_list)*A(known_list);

    % and solve...
    B=A;
    B(nan_list(:,1))=springs(:,nan_list(:,1))\rhs;

    end

    % all done, make sure that B is the same shape as
    % A was when we came in.
    B=reshape(B,NA);


    % ====================================================
    % end of main function
    % ====================================================
    % ====================================================
    % begin subfunctions
    % ====================================================
    function neighbors_list=identify_neighbors(NA,nan_list,talks_to)
    % identify_neighbors: identifies all the neighbors of
    % those nodes in nan_list, not including the nans
    % themselves
    %
    % arguments (input):
    % NA - 1x3 vector = size(A), where A is the
    % array to be interpolated
    % nan_list - array - list of every nan element in A
    % nan_list(i,1) == linear index of i'th nan element
    % nan_list(i,2) == row index of i'th nan element
    % nan_list(i,3) == column index of i'th nan element
    % nan_list(i,4) == third index of i'th nan element
    % talks_to - px2 array - defines which nodes communicate
    % with each other, i.e., which nodes are neighbors.
    %
    % talks_to(i,1) - defines the offset in the row
    % dimension of a neighbor
    % talks_to(i,2) - defines the offset in the column
    % dimension of a neighbor
    %
    % For example, talks_to = [-1 0;0 -1;1 0;0 1]
    % means that each node talks only to its immediate
    % neighbors horizontally and vertically.
    %
    % arguments(output):
    % neighbors_list - array - list of all neighbors of
    % all the nodes in nan_list

    if ~isempty(nan_list)
    % use the definition of a neighbor in talks_to
    nan_count=size(nan_list,1);
    talk_count=size(talks_to,1);

    nn=zeros(nan_count*talk_count,3);
    j=[1,nan_count];
    for i=1:talk_count
    nn(j(1):j(2),:)=nan_list(:,2:4) + ...
    repmat(talks_to(i,:),nan_count,1);
    j=j+nan_count;
    end

    % drop those nodes which fall outside the bounds of the
    % original array
    L = (nn(:,1)<1) | (nn(:,1)>NA(1)) | ...
    (nn(:,2)<1) | (nn(:,2)>NA(2)) | ...
    (nn(:,3)<1) | (nn(:,3)>NA(3));
    nn(L,:)=[];

    % form the same format 4 column array as nan_list
    neighbors_list=[sub2ind(NA,nn(:,1),nn(:,2),nn(:,3)),nn];

    % delete replicates in the neighbors list
    neighbors_list=unique(neighbors_list,'rows');

    % and delete those which are also in the list of NaNs.
    neighbors_list=setdiff(neighbors_list,nan_list,'rows');

    else
    neighbors_list=[];
    end
















    16 changes: 16 additions & 0 deletions test_inpaint_comparison.m
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,16 @@
    % Load the test data generated by Python
    load('test_data.mat');

    % Verify the data shape
    size_data = size(data_with_nans);
    fprintf('Input data size: %dx%dx%d\n', size_data(1), size_data(2), size_data(3));

    % Run inpaint_nans3
    inpainted_data = inpaint_nans3(data_with_nans, 0);

    % Verify the output shape
    size_result = size(inpainted_data);
    fprintf('Output data size: %dx%dx%d\n', size_result(1), size_result(2), size_result(3));

    % Save the results for Python to read
    save('matlab_result.mat', 'inpainted_data');
    76 changes: 76 additions & 0 deletions test_inpaint_comparison.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,76 @@
    import numpy as np
    import scipy.io as sio
    import matplotlib.pyplot as plt
    from inpaint3d import inpaint_nans_3d

    def generate_test_data(shape=(20, 20, 20)):
    """Generate test data with known NaN positions"""
    # Create a 3D test array with a known pattern
    x, y, z = np.meshgrid(np.linspace(-2, 2, shape[0]),
    np.linspace(-2, 2, shape[1]),
    np.linspace(-2, 2, shape[2]))
    data = np.sin(x) * np.cos(y) * np.exp(-0.1 * (x**2 + y**2 + z**2))

    # Add NaN values randomly
    nan_mask = np.random.random(data.shape) < 0.3
    data_with_nans = data.copy()
    data_with_nans[nan_mask] = np.nan

    return data, data_with_nans, nan_mask

    def compare_results(original, matlab_result, python_result):
    """Compare the results between MATLAB and Python implementations"""
    # Calculate error metrics
    matlab_error = np.nanmean(np.abs(original - matlab_result))
    python_error = np.nanmean(np.abs(original - python_result))
    difference = np.nanmean(np.abs(matlab_result - python_result))

    print(f"Mean Absolute Error (MATLAB): {matlab_error:.6f}")
    print(f"Mean Absolute Error (Python): {python_error:.6f}")
    print(f"Mean Difference between implementations: {difference:.6f}")

    # Visualize middle slices
    mid_slice = original.shape[2] // 2

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    axes[0,0].imshow(original[:,:,mid_slice])
    axes[0,0].set_title('Original')

    axes[0,1].imshow(matlab_result[:,:,mid_slice])
    axes[0,1].set_title('MATLAB Result')

    axes[1,0].imshow(python_result[:,:,mid_slice])
    axes[1,0].set_title('Python Result')

    diff = np.abs(matlab_result - python_result)
    axes[1,1].imshow(diff[:,:,mid_slice])
    axes[1,1].set_title('Absolute Difference')

    plt.tight_layout()
    plt.show()

    def main():
    # Generate test data
    original, data_with_nans, nan_mask = generate_test_data()

    # Save test data for MATLAB
    sio.savemat('test_data.mat',
    {'data_with_nans': data_with_nans})

    # Run Python implementation
    python_result = inpaint_nans_3d(data_with_nans)

    # Load MATLAB results (after running MATLAB script)
    try:
    matlab_data = sio.loadmat('matlab_result.mat')
    matlab_result = matlab_data['inpainted_data']

    # Compare results
    compare_results(original, matlab_result, python_result)

    except FileNotFoundError:
    print("MATLAB results not found. Please run the MATLAB script first.")

    if __name__ == "__main__":
    main()
    75 changes: 75 additions & 0 deletions test_inpaint_equivalence.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,75 @@
    import unittest
    import numpy as np
    import scipy.io as sio
    from inpaint3d import inpaint_nans_3d

    def generate_test_data(shape=(20, 20, 20)):
    """Generate test data with known NaN positions"""
    # Create a 3D test array with a known pattern
    x, y, z = np.meshgrid(np.linspace(-2, 2, shape[0]),
    np.linspace(-2, 2, shape[1]),
    np.linspace(-2, 2, shape[2]))
    data = np.sin(x) * np.cos(y) * np.exp(-0.1 * (x**2 + y**2 + z**2))

    # Add NaN values randomly
    np.random.seed(42) # For reproducibility
    nan_mask = np.random.random(data.shape) < 0.3
    data_with_nans = data.copy()
    data_with_nans[nan_mask] = np.nan

    return data, data_with_nans, nan_mask

    class TestInpaintEquivalence(unittest.TestCase):
    def setUp(self):
    # Use the same shape as in test_inpaint_comparison.py
    self.shape = (20, 20, 20)
    self.original, self.data_with_nans, _ = generate_test_data(self.shape)

    # Save for MATLAB
    sio.savemat('test_data.mat', {'data_with_nans': self.data_with_nans})

    def test_results_close_to_original(self):
    """Test if both implementations give results close to original data"""
    # Run Python implementation
    python_result = inpaint_nans_3d(self.data_with_nans)

    try:
    # Load MATLAB results
    matlab_data = sio.loadmat('matlab_result.mat')
    matlab_result = matlab_data['inpainted_data']

    # Verify shapes match
    self.assertEqual(matlab_result.shape, self.original.shape,
    "MATLAB result shape doesn't match original data shape")

    # Test if results are close to original (within 5% error)
    python_error = np.nanmean(np.abs(self.original - python_result))
    matlab_error = np.nanmean(np.abs(self.original - matlab_result))

    self.assertLess(python_error, 0.05)
    self.assertLess(matlab_error, 0.05)

    except FileNotFoundError:
    self.skipTest("MATLAB results file not found. Run MATLAB script first.")

    def test_implementations_equivalent(self):
    """Test if both implementations give similar results"""
    python_result = inpaint_nans_3d(self.data_with_nans)

    try:
    matlab_data = sio.loadmat('matlab_result.mat')
    matlab_result = matlab_data['inpainted_data']

    # Verify shapes match
    self.assertEqual(matlab_result.shape, python_result.shape,
    "MATLAB and Python results have different shapes")

    # Test if results are within 1% of each other
    difference = np.nanmean(np.abs(matlab_result - python_result))
    self.assertLess(difference, 0.01)

    except FileNotFoundError:
    self.skipTest("MATLAB results file not found. Run MATLAB script first.")

    if __name__ == '__main__':
    unittest.main()