cryzed created this gist (cryzed/f25823ea594a2cdd8a41eb81e370e662) on June 20, 2018.
import argparse
import collections
import itertools
import os
import sys
import time

try:
    import cPickle as pickle
except ImportError:
    import pickle

import PIL.Image
import numpy as np
import vlfeat
import scipy.cluster.vq
import scipy.spatial.distance
import matplotlib.pyplot as plt

PAGES_PATH = os.path.join('data', 'pages')
GT_PATH = os.path.join('data', 'GT')
IFS_MATCH_IMAGES_PATH = os.path.join('ifs_match_images')
MATCH_IMAGES_PATH = os.path.join('match_images')
CODE_BOOK_PATH = os.path.join('data', 'codebook.bin')
SPATIAL_PYRAMID_TYPES = ['L', 'R', 'G', 'GL', 'GR', 'LR', 'GLR']
CELL_MARGIN_TYPES = ['none', 'horizontal', 'vertical', 'both']

argument_parser = argparse.ArgumentParser()
argument_parser.add_argument('--step-size', '-s', type=int, default=15)
argument_parser.add_argument('--cell-size', '-c', type=int, default=3)
argument_parser.add_argument('--centroids', '-C', type=int, default=40)
argument_parser.add_argument('--k-means-iterations', '-k', type=int, default=20)
argument_parser.add_argument('--distance-metric', choices=['cityblock', 'cosine', 'euclidean'], default='cosine')
argument_parser.add_argument('--spatial-pyramid-type', '-S', choices=SPATIAL_PYRAMID_TYPES, default='LR')
argument_parser.add_argument('--pages', '-p', type=int, default=1)
argument_parser.add_argument('--accumulator-percentile', '-a', type=float, default=95.0)
argument_parser.add_argument('--use-ifs', '-I', action='store_true')
argument_parser.add_argument('--use-accumulator', '-A', action='store_true')
argument_parser.add_argument('--save-images', '-sa', action='store_true')
argument_parser.add_argument('--verbose', action='store_true')
argument_parser.add_argument('--cell-margin', choices=CELL_MARGIN_TYPES, default='horizontal')

SpatialPyramid = collections.namedtuple('SpatialPyramid', ['global_', 'left', 'right'])


def makedirs(name, mode=0777, exist_ok=False):
    if not exist_ok:
        return os.makedirs(name, mode)

    # Taken from Python 3
    try:
        os.makedirs(name, mode)
    except OSError:
        # Cannot rely on checking for EEXIST, since the operating system
        # could give priority to other errors like EACCES or EROFS
        if not exist_ok or not os.path.isdir(name):
            raise


def load_gtp_file(path):
    # Parse a ground-truth file with one "x1 y1 x2 y2 word" entry per line
    entries = collections.defaultdict(list)

    with open(path) as file:
        for line in (line for line in (line.strip() for line in file) if line):
            x1, y1, x2, y2, word = line.split()
            entries[word].append((int(x1), int(y1), int(x2), int(y2)))

    return entries


def load_codebook(path):
    # The codebook is stored as raw float32 values: 4096 centroids of dimension 128
    with open(path, 'rb') as input_file:
        code_book = np.fromfile(input_file, dtype='float32')
    return np.reshape(code_book, (4096, 128))


def make_spatial_pyramid(data, length, type_='GLR'):
    count = len(data)
    left_index = int(np.floor(count / 2))
    right_index = int(np.ceil(count / 2))

    if type_ == 'L':
        data = [], data[:left_index], []
    elif type_ == 'R':
        data = [], [], data[right_index:]
    elif type_ == 'G':
        data = data, [], []
    elif type_ == 'GL':
        data = data, data[:left_index], []
    elif type_ == 'GR':
        data = data, [], data[right_index:]
    elif type_ == 'LR':
        data = [], data[:left_index], data[right_index:]
    elif type_ == 'GLR':
        data = data, data[:left_index], data[right_index:]
    else:
        raise ValueError('unknown spatial pyramid type: %r' % type_)

    spatial_pyramid = SpatialPyramid(*(np.bincount(datum, minlength=length) for datum in data))
    return np.concatenate(spatial_pyramid)


def load_corpus(page_names):
    defaultdict_factory = lambda: collections.defaultdict(defaultdict_factory)
    corpus = collections.defaultdict(defaultdict_factory)

    offset = 0
    images = []
    corpus_gtp = collections.defaultdict(list)
    for page_name in page_names:
        corpus['pages'][page_name]['offset'] = offset

        # Load page image
        image_path = os.path.join(PAGES_PATH, '%s.png' % page_name)
        corpus['pages'][page_name]['image_path'] = image_path
        image = PIL.Image.open(image_path)
        corpus['pages'][page_name]['image'] = image
        images.append(image)

        # Load page GTP
        gtp_path = os.path.join(GT_PATH, '%s.gtp' % page_name)
        corpus['pages'][page_name]['gtp_path'] = gtp_path
        gtp = load_gtp_file(gtp_path)
        corpus['pages'][page_name]['gtp'] = gtp

        # Update global corpus GTP with current offset
        for word, coordinates in gtp.items():
            for x1, y1, x2, y2 in coordinates:
                corpus_gtp[word].append((x1 + offset, y1, x2 + offset, y2))

        offset += image.width

    # Create corpus image by concatenating page images horizontally
    width = sum(image.width for image in images)
    max_height = max(image.height for image in images)
    corpus_image = PIL.Image.new(images[0].mode, (width, max_height))

    x_offset = 0
    for image in images:
        corpus_image.paste(image, (x_offset, 0))
        x_offset += image.width

    corpus['gtp'] = corpus_gtp
    corpus['image'] = corpus_image
    corpus['data'] = np.array(corpus_image, dtype='float32')
    return corpus


def pre_main(arguments):
    load_codebook(CODE_BOOK_PATH)
    page_names = [os.path.splitext(filename)[0] for filename in sorted(os.listdir(PAGES_PATH))[:arguments.pages]]
    corpus = load_corpus(page_names)

    # Sweep the accumulator percentile and record mean average precision and runtime for each setting
    results1 = collections.OrderedDict()
    results2 = collections.OrderedDict()
    for accumulator_percentile in range(0, 105, 5):
        print accumulator_percentile
        arguments.use_ifs = True
        arguments.use_accumulator = True
        arguments.accumulator_percentile = accumulator_percentile
        start = time.time()
        mean_average_precision = main(arguments, corpus)
        duration = int(time.time() - start)
        results1[accumulator_percentile] = mean_average_precision
        results2[accumulator_percentile] = duration

    plt.plot(range(len(results1)), results1.values(), 'o')
    plt.xlabel('Accumulator Percentile')
    plt.ylabel('Mean Average Precision')
    plt.xticks(range(len(results1)), results1.keys())
    plt.grid(True)
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.show()

    plt.plot(range(len(results2)), results2.values(), 'o')
    plt.xlabel('Accumulator Percentile')
    plt.ylabel('Runtime')
    plt.xticks(range(len(results2)), results2.keys())
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def main(arguments, corpus):
    # Calculate dense SIFT frames and descriptors for the normalised corpus image
    frames, descriptors = vlfeat.vl_dsift(
        corpus['data'] / corpus['data'].max(), step=arguments.step_size, size=arguments.cell_size, fast=True,
        float_descriptors=True)

    # Find all frames and descriptors contained inside word boundaries (minus a cell margin of cell_size * 2)
    cell_margin = 2 * arguments.cell_size
    words_frames = []
    words_descriptors = []
    previous_frame_index = 0
    word_data_indices = collections.OrderedDict()
    word_coordinates = collections.OrderedDict()
    for word, coordinates in corpus['gtp'].items():
        # Filter word frames within word bounding box
        for variation, (x1, y1, x2, y2) in enumerate(coordinates):
            if arguments.cell_margin == 'none':
                mask = (
                    (frames[:, 0] >= x1) & (frames[:, 1] >= y1) &
                    (frames[:, 0] <= x2) & (frames[:, 1] <= y2))
            elif arguments.cell_margin == 'horizontal':
                mask = (
                    (frames[:, 0] >= x1 + cell_margin) & (frames[:, 1] >= y1) &
                    (frames[:, 0] <= x2 - cell_margin) & (frames[:, 1] <= y2))
            elif arguments.cell_margin == 'vertical':
                mask = (
                    (frames[:, 0] >= x1) & (frames[:, 1] >= y1 + cell_margin) &
                    (frames[:, 0] <= x2) & (frames[:, 1] <= y2 - cell_margin))
            elif arguments.cell_margin == 'both':
                mask = (
                    (frames[:, 0] >= x1 + cell_margin) & (frames[:, 1] >= y1 + cell_margin) &
                    (frames[:, 0] <= x2 - cell_margin) & (frames[:, 1] <= y2 - cell_margin))
            else:
                raise ValueError('unknown cell margin type: %r' % arguments.cell_margin)

            # Get matching frames/descriptors for the word
            word_frames = frames[mask]
            words_frames.append(word_frames)
            words_descriptors.append(descriptors[mask])

            # Count how many frames are contained inside the bounding box
            frame_count = word_frames.shape[0]

            # Note at which index and how many (following) frames/descriptors are part of a word
            key = word, variation
            word_data_indices[key] = previous_frame_index, frame_count
            word_coordinates[key] = x1, y1, x2, y2
            previous_frame_index += frame_count

    words_frames = np.concatenate(words_frames)
    words_descriptors = np.concatenate(words_descriptors)

    if arguments.centroids == 4096:
        code_book = load_codebook(CODE_BOOK_PATH)
        labels, _ = scipy.cluster.vq.vq(words_descriptors, code_book)
    else:
        # Calculate labels
        _, labels = scipy.cluster.vq.kmeans2(
            words_descriptors, arguments.centroids, iter=arguments.k_means_iterations, minit='points')

    # Word -> labels mapping
    # noinspection PyArgumentList
    word_labels = collections.OrderedDict(
        (key, labels[start:start + length]) for key, (start, length) in word_data_indices.items())

    # Create (word, variation) -> spatial pyramid mapping
    # noinspection PyArgumentList
    spatial_pyramids = collections.OrderedDict(
        (key, make_spatial_pyramid(labels, arguments.centroids, arguments.spatial_pyramid_type))
        for key, labels in word_labels.items())

    # Create IFS database
    ifs_height = len(spatial_pyramids.values()[0])
    ifs = [set() for count in range(ifs_height)]
    for word_index, spatial_pyramid in enumerate(spatial_pyramids.values()):
        for index, count in enumerate(spatial_pyramid):
            if count:
                ifs[index].add(word_index)

    # Create word -> word index set mapping
    word_variation_indices = collections.defaultdict(set)
    for word_index, (word, variation) in enumerate(spatial_pyramids.keys()):
        word_variation_indices[word].add(word_index)

    # Find query in IFS
    spatial_pyramids_values = spatial_pyramids.values()
    word_coordinates_values = word_coordinates.values()
    average_precisions = []
    average_recalls = []
    for word_index, ((word, variation), query) in enumerate(spatial_pyramids.items()):
        # Skip words with no findable duplicates in the IFS database
        appearances = len(word_variation_indices[word]) - 1
        if not appearances:
            if arguments.verbose:
                print >> sys.stderr, 'No duplicate appearances for (%s, %d)!' % (word, variation)
            continue

        if arguments.use_ifs:
            ifs_candidate_indices = list(itertools.chain(*(ifs[index] for index, count in enumerate(query) if count)))
            candidate_indices = set(ifs_candidate_indices)
            if not candidate_indices:
                if arguments.verbose:
                    print >> sys.stderr, 'No candidates for (%s, %d) after IFS!' % (word, variation)
                average_precisions.append(0)
                average_recalls.append(0)
                continue

            if arguments.use_accumulator:
                # noinspection PyArgumentList
                accumulator = collections.Counter(ifs_candidate_indices)

                # No candidates left after having applied the IFS
                if not accumulator:
                    if arguments.verbose:
                        print >> sys.stderr, 'No candidates for (%s, %d) after IFS + Accumulator!' % (word, variation)
                    average_precisions.append(0)
                    average_recalls.append(0)
                    continue

                most_common = accumulator.most_common()
                rankings = sorted(set(accumulator.values()))
                percentile_ranking = rankings[max(0, int(len(rankings) * arguments.accumulator_percentile / 100.0) - 1)]
                candidate_indices = set(
                    index for index, count in
                    list(itertools.takewhile(lambda item: item[1] >= percentile_ranking, most_common)))
        else:
            candidate_indices = set(range(len(spatial_pyramids)))

        candidate_indices -= {word_index}
        if not candidate_indices:
            if arguments.verbose:
                print >> sys.stderr, 'No candidates for (%s, %d)' % (word, variation)
            average_precisions.append(0)
            average_recalls.append(0)
            continue

        candidate_pyramids = np.array([spatial_pyramids_values[index] for index in candidate_indices])
        query = query.reshape((1, query.shape[0]))
        distances = scipy.spatial.distance.cdist(query, candidate_pyramids, metric=arguments.distance_metric)[0]

        # Translate index in distance array to index of candidate
        distances_indices = range(distances.shape[0])
        distance_index_to_candidate_index = {
            distance_index: candidate_index
            for distance_index, candidate_index in zip(distances_indices, candidate_indices)}
        distances_sorted_indices = np.argsort(distances)
        sorted_candidate_indices = [
            distance_index_to_candidate_index[distance_index] for distance_index in distances_sorted_indices]

        hits = [1 if index in word_variation_indices[word] else 0 for index in sorted_candidate_indices]
        true_positives = sum(hits)

        # Calculate accumulated hits at index
        hits_at_k = []
        current_hits = 0
        for hit in hits:
            if hit:
                current_hits += 1
            hits_at_k.append(current_hits)

        average_precision = sum(
            (current_hits / float(index)) * hit
            for index, (hit, current_hits) in enumerate(zip(hits, hits_at_k), start=1)) / float(appearances)
        average_precisions.append(average_precision)
        average_recalls.append(true_positives / float(appearances))

        if arguments.save_images:
            match_images_path = os.path.join(MATCH_IMAGES_PATH, '%s_%d' % (word, variation))
            makedirs(match_images_path, exist_ok=True)
            coordinates = word_coordinates_values[word_index]
            corpus['image'].crop(coordinates).save(os.path.join(match_images_path, '0_original.png'))

            for rank, candidate_word_index in enumerate(sorted_candidate_indices, start=1):
                coordinates = word_coordinates_values[candidate_word_index]
                path = os.path.join(match_images_path, 'candidate_%d.png' % rank)
                corpus['image'].crop(coordinates).save(path)

        # print 'Word %s (Variation: %d): %.2f%%' % (word, variation, average_precision * 100)

    print 'Mean Recall: %f' % (np.mean(average_recalls) * 100)
    mean_average_precision = np.mean(average_precisions)
    print 'Mean Average Precision: %f' % (mean_average_precision * 100)
    return mean_average_precision


if __name__ == '__main__':
    arguments = argument_parser.parse_args()
    pre_main(arguments)
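Usage sketch, inferred from the path constants and argument parser above; the filename word_spotting.py is only a placeholder, since the gist's file name is not shown here. The script is Python 2 and expects page images under data/pages/<page>.png, ground-truth files under data/GT/<page>.gtp with one "x1 y1 x2 y2 word" entry per line, and a codebook at data/codebook.bin holding 4096 x 128 raw float32 values. With that layout in place, a run such as

    python2 word_spotting.py --pages 2 --centroids 40 --spatial-pyramid-type LR

loads the first two pages, after which pre_main sweeps the accumulator percentile from 0 to 100 in steps of 5 (forcing --use-ifs and --use-accumulator for every run of main) and plots mean average precision and runtime against the percentile.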