Created
November 19, 2018 22:34
-
-
Save AzureDVBB/4315d2350a457a3c2b98e1e1a6353f4a to your computer and use it in GitHub Desktop.
Revisions
-
AzureDVBB created this gist
Nov 19, 2018 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,271 @@ # -*- coding: utf-8 -*- # look into 'streamz' package, neat pipelining with dask integration import cv2 # opencv-python for frame reading import skimage # scikit-image for loaded image analysis import dask # parallelized python EZ mode import numpy as np # yep import matplotlib.pyplot as plt # pretty charts no? import matplotlib from skimage.feature import match_descriptors, ORB from skimage.measure import ransac from skimage.transform import FundamentalMatrixTransform import os #standard libs import time import random import itertools # due to bugs in scikit-video with opening and reading files # resorted to using OpenCV for reading frames class VideoFile_p: def __init__(self, file): self.file = file # look at opencv documentation: Flags for video I/O # the cv2 properties did not function properly, # passing the integer value of the flag did self.capture = cv2.VideoCapture(self.file) self.number_of_frames = int(self.capture.get(7)) self.current_index = 0 def __len__(self): return self.number_of_frames def __iter__(self): self.current_index = 0 self.capture = cv2.VideoCapture(self.file) self.number_of_frames = int(self.capture.get(7)) return self def __next__(self): self.current_index += 1 ret, frame = self.capture.read() # ret is false at EOF if ret is False: self.current_index = None self.capture = None raise StopIteration elif ret is True: # cv2 opens in bgr mode and needs to be converted to RGB return {'index': self.current_index, 'raw_frame': cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)} def save_frames(self, list_of_frames, save_folder): # checks and makes the directory path if not existing if not os.path.exists(save_folder): os.makedirs(save_folder) reader = iter(self) for img in reader: # premature termination on last frame write if img['index'] > max(list_of_frames): break elif img['index'] in list_of_frames: # padding out number for up to 6 digits filename = os.path.join(save_folder, str(img['index']).zfill(6)) + '.jpg' frame = cv2.cvtColor(img['raw_frame'], cv2.COLOR_RGB2BGR) cv2.imwrite(filename, frame) from skimage.feature import match_descriptors, ORB from skimage.measure import ransac from skimage.transform import FundamentalMatrixTransform import numpy as np from numba import jit class Analysis_p: @staticmethod def compute_keypoints_descriptors_blur(frame, n_keypoints = 500, opencv=True, sift=False): # using opencv to reduce dependencies needed # estimate blur by taking the image's laplacian vairance (blurry=low) blur = cv2.Laplacian(frame['raw_frame'], cv2.CV_64F).var() if not opencv: # skimage has poor detection speed, 10x slower as of writing this # keeping it here if in the future its better if sift: orb = cv2.SIFT(nfeatures=n_keypoints) else: orb = skimage.feature.ORB(n_keypoints = n_keypoints, downscale=2) # skimage ORB needs grayscale image orb.detect_and_extract(skimage.color.rgb2gray(frame['raw_frame'])) keypoints = orb.keypoints descriptors = orb.descriptors return {'index': frame['index'], 'blur': blur, 'keypoints': keypoints, 'descriptors': descriptors} else: # boilerplate from opencv python reference orb = cv2.ORB_create(nfeatures = n_keypoints) # Initiate ORB detector keypoints_o = orb.detect(frame['raw_frame'], None) keypoints_o, descriptors = orb.compute(frame['raw_frame'], keypoints_o) # make keypoints compatible with scikit-image # array of [[x, y],] coords ndarray keypoints = np.ndarray(shape=(n_keypoints, 2), dtype=np.int64) try: for i, k in enumerate(keypoints_o, start=0): keypoints[i] = k.pt except: keypoints = None # if something goes catastrophically wrong #cannot pickle openCV keypoint objects unfortunately, need to convert to coords (x,y aray) return {'index': frame['index'], 'blur': blur, 'keypoints': keypoints, 'descriptors': descriptors} @staticmethod def match_frames(frame1, frame2, minsamples=8, maxtrials=100, opencv=False): if opencv is False: # skimage has nicer matching then opencv # modified boilerplate example code from doc of skimage # ORB matches = match_descriptors(frame1['descriptors'], frame2['descriptors'], cross_check = True) try: # filtering out outliers, note first return is 'model', we dont care _, inliers = ransac((frame1['keypoints'][matches[:, 0]], frame2['keypoints'][matches[:, 1]]), FundamentalMatrixTransform, min_samples = minsamples, residual_threshold = 1, max_trials = maxtrials) # only the number of inliers matter to us inliers_sum = inliers.sum() #inliers_sum = len(matches) except: # just show raw matches if RANSAC errors out inliers_sum = len(matches) finally: return inliers_sum else: pass # I doubt anyone wants to use opencv here class FrameSelection_p: def __init__(self): pass def variance_picker(matches_to_base_frame, min_variance=0.1): new = None old = None for i, _ in enumerate(matches_to_base_frame, start=0): if old is None: old = matches_to_base_frame[i] new = old else: old = new #new = sum(matches_to_base_frame[:i])/(i+1) new = matches_to_base_frame[i] variance = abs(new - old) / old if variance <= min_variance: return i return None # too much variance in dataset def compute_best_frames(frame_stream, last_frame_index, client, batch_size=10, min_variance=0.05): from itertools import repeat, islice last_frame_index = last_frame_index-1 # removes infinite loop bug frame_generator = itertools.islice(vid_stream, last_frame_index) base_descriptor = None batch_num = 1 descriptor_collection = [] found_at_collection_index = None matches_to_base_frame = [] good_frame_indexes = [1] # include first frame last_batch = False while True: if base_descriptor is None: base_descriptor = client.submit( Analysis_p.compute_keypoints_descriptors_blur, next(frame_generator)) # check if the next batch is the last one if good_frame_indexes[-1] + batch_num*batch_size >= last_frame_index: if last_batch is True: break # end the loop if it has been else: last_batch = True # put the appropriate amount onto the collection futures = client.map(Analysis_p.compute_keypoints_descriptors_blur, islice(frame_generator, last_frame_index - good_frame_indexes[-1] - batch_size * (batch_num - 1))) descriptor_collection += futures else: futures = client.map(Analysis_p.compute_keypoints_descriptors_blur, islice(frame_generator, batch_size * batch_num - len(descriptor_collection))) descriptor_collection += futures # match all elements in the collection against base match_num_futures = client.map(Analysis_p.match_frames, repeat(base_descriptor), islice(descriptor_collection, len(matches_to_base_frame), batch_size * batch_num)) # TODO: the above method passes the entire slice into method # need to fix that so it only sends base future and collection future matches_to_base_frame += client.gather(match_num_futures) # selection pass found_at_collection_index = FrameSelection_p.variance_picker( matches_to_base_frame, min_variance=min_variance) if found_at_collection_index is not None: # save the frame's index as good frame_index = descriptor_collection[found_at_collection_index].result()['index'] base_descriptor = descriptor_collection[found_at_collection_index] good_frame_indexes.append(frame_index) # make the good frame the base base_descriptor = descriptor_collection[found_at_collection_index] # delete frame dictionaries at new base and before # and reset variables del descriptor_collection[:found_at_collection_index+1] found_at_collection_index = None matches_to_base_frame.clear() batch_num = 1 else: # if not found then repeat batch_num += 1 # repeat untill input frames are exhausted return good_frame_indexes # finished if __name__ == '__main__': from dask.distributed import Client file = 'G:/_SFMDatasets/VideoCodeTest/ground.mp4' client = Client('tcp://127.0.0.1:8786') #change address for cluster's one vid_stream = VideoFile_p(file) #slc = itertools.islice(vid_stream, 2000) good = FrameSelection_p.compute_best_frames(vid_stream, vid_stream.number_of_frames, client, min_variance=0.08, batch_size=20) vid_stream.save_frames(good, 'J:/selected_video')