Skip to content

Instantly share code, notes, and snippets.

@AzureDVBB
Created November 19, 2018 22:34
Show Gist options
  • Save AzureDVBB/4315d2350a457a3c2b98e1e1a6353f4a to your computer and use it in GitHub Desktop.
Save AzureDVBB/4315d2350a457a3c2b98e1e1a6353f4a to your computer and use it in GitHub Desktop.

Revisions

  1. AzureDVBB created this gist Nov 19, 2018.
    271 changes: 271 additions & 0 deletions computations_parallel.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,271 @@
    # -*- coding: utf-8 -*-

    # look into 'streamz' package, neat pipelining with dask integration
    import cv2 # opencv-python for frame reading
    import skimage # scikit-image for loaded image analysis
    import dask # parallelized python EZ mode
    import numpy as np # yep
    import matplotlib.pyplot as plt # pretty charts no?
    import matplotlib
    from skimage.feature import match_descriptors, ORB
    from skimage.measure import ransac
    from skimage.transform import FundamentalMatrixTransform
    import os

    #standard libs
    import time
    import random
    import itertools

    # due to bugs in scikit-video with opening and reading files
    # resorted to using OpenCV for reading frames

    class VideoFile_p:

    def __init__(self, file):
    self.file = file
    # look at opencv documentation: Flags for video I/O
    # the cv2 properties did not function properly,
    # passing the integer value of the flag did
    self.capture = cv2.VideoCapture(self.file)
    self.number_of_frames = int(self.capture.get(7))
    self.current_index = 0

    def __len__(self):
    return self.number_of_frames

    def __iter__(self):
    self.current_index = 0
    self.capture = cv2.VideoCapture(self.file)
    self.number_of_frames = int(self.capture.get(7))
    return self

    def __next__(self):
    self.current_index += 1
    ret, frame = self.capture.read() # ret is false at EOF
    if ret is False:
    self.current_index = None
    self.capture = None
    raise StopIteration
    elif ret is True:
    # cv2 opens in bgr mode and needs to be converted to RGB
    return {'index': self.current_index,
    'raw_frame': cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)}

    def save_frames(self, list_of_frames, save_folder):
    # checks and makes the directory path if not existing
    if not os.path.exists(save_folder):
    os.makedirs(save_folder)
    reader = iter(self)
    for img in reader:
    # premature termination on last frame write
    if img['index'] > max(list_of_frames):
    break

    elif img['index'] in list_of_frames:
    # padding out number for up to 6 digits
    filename = os.path.join(save_folder, str(img['index']).zfill(6)) + '.jpg'
    frame = cv2.cvtColor(img['raw_frame'], cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, frame)


    from skimage.feature import match_descriptors, ORB
    from skimage.measure import ransac
    from skimage.transform import FundamentalMatrixTransform
    import numpy as np
    from numba import jit

    class Analysis_p:

    @staticmethod
    def compute_keypoints_descriptors_blur(frame, n_keypoints = 500,
    opencv=True, sift=False):
    # using opencv to reduce dependencies needed
    # estimate blur by taking the image's laplacian vairance (blurry=low)
    blur = cv2.Laplacian(frame['raw_frame'], cv2.CV_64F).var()

    if not opencv:
    # skimage has poor detection speed, 10x slower as of writing this
    # keeping it here if in the future its better

    if sift:
    orb = cv2.SIFT(nfeatures=n_keypoints)
    else:
    orb = skimage.feature.ORB(n_keypoints = n_keypoints, downscale=2)
    # skimage ORB needs grayscale image
    orb.detect_and_extract(skimage.color.rgb2gray(frame['raw_frame']))
    keypoints = orb.keypoints
    descriptors = orb.descriptors

    return {'index': frame['index'], 'blur': blur,
    'keypoints': keypoints, 'descriptors': descriptors}

    else:
    # boilerplate from opencv python reference
    orb = cv2.ORB_create(nfeatures = n_keypoints) # Initiate ORB detector
    keypoints_o = orb.detect(frame['raw_frame'], None)
    keypoints_o, descriptors = orb.compute(frame['raw_frame'], keypoints_o)
    # make keypoints compatible with scikit-image
    # array of [[x, y],] coords ndarray
    keypoints = np.ndarray(shape=(n_keypoints, 2), dtype=np.int64)

    try:
    for i, k in enumerate(keypoints_o, start=0):
    keypoints[i] = k.pt
    except:
    keypoints = None # if something goes catastrophically wrong

    #cannot pickle openCV keypoint objects unfortunately, need to convert to coords (x,y aray)
    return {'index': frame['index'], 'blur': blur,
    'keypoints': keypoints, 'descriptors': descriptors}


    @staticmethod
    def match_frames(frame1, frame2, minsamples=8, maxtrials=100, opencv=False):
    if opencv is False:
    # skimage has nicer matching then opencv
    # modified boilerplate example code from doc of skimage
    # ORB
    matches = match_descriptors(frame1['descriptors'],
    frame2['descriptors'],
    cross_check = True)
    try:
    # filtering out outliers, note first return is 'model', we dont care
    _, inliers = ransac((frame1['keypoints'][matches[:, 0]],
    frame2['keypoints'][matches[:, 1]]),
    FundamentalMatrixTransform,
    min_samples = minsamples,
    residual_threshold = 1, max_trials = maxtrials)

    # only the number of inliers matter to us
    inliers_sum = inliers.sum()
    #inliers_sum = len(matches)

    except:
    # just show raw matches if RANSAC errors out
    inliers_sum = len(matches)

    finally:
    return inliers_sum

    else:
    pass # I doubt anyone wants to use opencv here


    class FrameSelection_p:

    def __init__(self):
    pass

    def variance_picker(matches_to_base_frame, min_variance=0.1):
    new = None
    old = None

    for i, _ in enumerate(matches_to_base_frame, start=0):
    if old is None:
    old = matches_to_base_frame[i]
    new = old
    else:
    old = new
    #new = sum(matches_to_base_frame[:i])/(i+1)
    new = matches_to_base_frame[i]
    variance = abs(new - old) / old

    if variance <= min_variance:
    return i

    return None # too much variance in dataset

    def compute_best_frames(frame_stream, last_frame_index, client,
    batch_size=10, min_variance=0.05):

    from itertools import repeat, islice

    last_frame_index = last_frame_index-1 # removes infinite loop bug
    frame_generator = itertools.islice(vid_stream, last_frame_index)
    base_descriptor = None
    batch_num = 1
    descriptor_collection = []
    found_at_collection_index = None
    matches_to_base_frame = []
    good_frame_indexes = [1] # include first frame
    last_batch = False


    while True:

    if base_descriptor is None:
    base_descriptor = client.submit(
    Analysis_p.compute_keypoints_descriptors_blur,
    next(frame_generator))

    # check if the next batch is the last one
    if good_frame_indexes[-1] + batch_num*batch_size >= last_frame_index:
    if last_batch is True:
    break # end the loop if it has been
    else:
    last_batch = True
    # put the appropriate amount onto the collection
    futures = client.map(Analysis_p.compute_keypoints_descriptors_blur,
    islice(frame_generator,
    last_frame_index -
    good_frame_indexes[-1] -
    batch_size * (batch_num - 1)))

    descriptor_collection += futures

    else:
    futures = client.map(Analysis_p.compute_keypoints_descriptors_blur,
    islice(frame_generator,
    batch_size * batch_num -
    len(descriptor_collection)))

    descriptor_collection += futures

    # match all elements in the collection against base
    match_num_futures = client.map(Analysis_p.match_frames,
    repeat(base_descriptor),
    islice(descriptor_collection,
    len(matches_to_base_frame),
    batch_size * batch_num))
    # TODO: the above method passes the entire slice into method
    # need to fix that so it only sends base future and collection future
    matches_to_base_frame += client.gather(match_num_futures)

    # selection pass
    found_at_collection_index = FrameSelection_p.variance_picker(
    matches_to_base_frame,
    min_variance=min_variance)

    if found_at_collection_index is not None:
    # save the frame's index as good
    frame_index = descriptor_collection[found_at_collection_index].result()['index']
    base_descriptor = descriptor_collection[found_at_collection_index]
    good_frame_indexes.append(frame_index)
    # make the good frame the base
    base_descriptor = descriptor_collection[found_at_collection_index]

    # delete frame dictionaries at new base and before
    # and reset variables
    del descriptor_collection[:found_at_collection_index+1]
    found_at_collection_index = None
    matches_to_base_frame.clear()
    batch_num = 1

    else:
    # if not found then repeat
    batch_num += 1

    # repeat untill input frames are exhausted

    return good_frame_indexes # finished

    if __name__ == '__main__':
    from dask.distributed import Client
    file = 'G:/_SFMDatasets/VideoCodeTest/ground.mp4'
    client = Client('tcp://127.0.0.1:8786') #change address for cluster's one
    vid_stream = VideoFile_p(file)
    #slc = itertools.islice(vid_stream, 2000)
    good = FrameSelection_p.compute_best_frames(vid_stream, vid_stream.number_of_frames, client,
    min_variance=0.08, batch_size=20)
    vid_stream.save_frames(good, 'J:/selected_video')