Skip to content

Instantly share code, notes, and snippets.

@cryzed
Created June 20, 2018 17:48
Show Gist options
  • Save cryzed/f25823ea594a2cdd8a41eb81e370e662 to your computer and use it in GitHub Desktop.
Save cryzed/f25823ea594a2cdd8a41eb81e370e662 to your computer and use it in GitHub Desktop.

Revisions

  1. cryzed created this gist Jun 20, 2018.
    368 changes: 368 additions & 0 deletions IPFS.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,368 @@
    import argparse
    import collections
    import itertools
    import os
    import sys

    import time

    try:
    import cPickle as pickle
    except ImportError:
    import pickle

    import PIL.Image
    import numpy as np
    import vlfeat
    import scipy.cluster.vq
    import scipy.spatial.distance
    import matplotlib.pyplot as plt

    PAGES_PATH = os.path.join('data', 'pages')
    GT_PATH = os.path.join('data', 'GT')
    IFS_MATCH_IMAGES_PATH = os.path.join('ifs_match_images')
    MATCH_IMAGES_PATH = os.path.join('match_images')
    CODE_BOOK_PATH = os.path.join('data', 'codebook.bin')
    SPATIAL_PYRAMID_TYPES = ['L', 'R', 'G', 'GL', 'GR', 'LR', 'GLR']
    CELL_MARGIN_TYPES = ['none', 'horizontal', 'vertical', 'both']

    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument('--step-size', '-s', type=int, default=15)
    argument_parser.add_argument('--cell-size', '-c', type=int, default=3)
    argument_parser.add_argument('--centroids', '-C', type=int, default=40)
    argument_parser.add_argument('--k-means-iterations', '-k', type=int, default=20)
    argument_parser.add_argument('--distance-metric', choices=['cityblock', 'cosine', 'euclidean'], default='cosine')
    argument_parser.add_argument('--spatial-pyramid-type', '-S', choices=SPATIAL_PYRAMID_TYPES, default='LR')
    argument_parser.add_argument('--pages', '-p', type=int, default=1)
    argument_parser.add_argument('--accumulator-percentile', '-a', type=float, default=95.0)
    argument_parser.add_argument('--use-ifs', '-I', action='store_true')
    argument_parser.add_argument('--use-accumulator', '-A', action='store_true')
    argument_parser.add_argument('--save-images', '-sa', action='store_true')
    argument_parser.add_argument('--verbose', action='store_true')
    argument_parser.add_argument('--cell-margin', choices=CELL_MARGIN_TYPES, default='horizontal')

    SpatialPyramid = collections.namedtuple('SpatialPyramid', ['global_', 'left', 'right'])


    def makedirs(name, mode=0777, exist_ok=False):
    if not exist_ok:
    return os.makedirs(name, mode)

    # Taken from Python 3
    try:
    os.makedirs(name, mode)
    except OSError:
    # Cannot rely on checking for EEXIST, since the operating system
    # could give priority to other errors like EACCES or EROFS
    if not exist_ok or not os.path.isdir(name):
    raise


    def load_gtp_file(path):
    entries = collections.defaultdict(list)
    with open(path) as file:
    for line in (line for line in (line.strip() for line in file) if line):
    x1, y1, x2, y2, word = line.split()
    entries[word].append((int(x1), int(y1), int(x2), int(y2)))

    return entries


    def load_codebook(path):
    input_file = open(path, 'r')
    code_book = np.fromfile(input_file, dtype='float32')
    code_book = np.reshape(code_book, (4096, 128))
    return code_book


    def make_spatial_pyramid(data, length, type_='GLR'):
    count = len(data)
    left_index = int(np.floor(count / 2))
    right_index = int(np.ceil(count / 2))

    if type_ == 'L':
    data = [], data[:left_index], []
    elif type_ == 'R':
    data = [], [], data[right_index:]
    elif type_ == 'G':
    data = data, [], []
    elif type_ == 'GL':
    data = data, data[:left_index], []
    elif type_ == 'GR':
    data = data, [], data[right_index:]
    elif type_ == 'LR':
    data = [], data[:left_index], data[right_index:]
    elif type_ == 'GLR':
    data = data, data[:left_index], data[right_index:]
    else:
    raise ValueError('unknown spatial pyramid type: %r' % type_)

    spatial_pyramid = SpatialPyramid(*(np.bincount(datum, minlength=length) for datum in data))
    return np.concatenate(spatial_pyramid)


    def load_corpus(page_names):
    defaultdict_factory = lambda: collections.defaultdict(defaultdict_factory)
    corpus = collections.defaultdict(defaultdict_factory)

    offset = 0
    images = []
    corpus_gtp = collections.defaultdict(list)
    for page_name in page_names:
    corpus['pages'][page_name]['offset'] = offset

    # Load page image
    image_path = os.path.join(PAGES_PATH, '%s.png' % page_name)
    corpus['pages'][page_name]['image_path'] = image_path
    image = PIL.Image.open(image_path)
    corpus['pages'][page_name]['image'] = image
    images.append(image)

    # Load page GTP
    gtp_path = os.path.join(GT_PATH, '%s.gtp' % page_name)
    corpus['pages'][page_name]['gtp_path'] = gtp_path
    gtp = load_gtp_file(gtp_path)
    corpus['pages'][page_name]['gtp'] = gtp

    # Update global corpus GTP with current offset
    for word, coordinates in gtp.items():
    for x1, y1, x2, y2 in coordinates:
    corpus_gtp[word].append((x1 + offset, y1, x2 + offset, y2))

    offset += image.width

    # Create Corpus image by concatenating page images horizontally
    width = sum(image.width for image in images)
    max_height = max(image.height for image in images)
    corpus_image = PIL.Image.new(images[0].mode, (width, max_height))
    x_offset = 0
    for image in images:
    corpus_image.paste(image, (x_offset, 0))
    x_offset += image.width

    corpus['gtp'] = corpus_gtp
    corpus['image'] = corpus_image
    corpus['data'] = np.array(corpus_image, dtype='float32')
    return corpus


    def pre_main(arguments):
    load_codebook(os.path.join('data', 'codebook.bin'))
    page_names = [os.path.splitext(filename)[0] for filename in sorted(os.listdir(PAGES_PATH))[:arguments.pages]]
    corpus = load_corpus(page_names)

    results1 = collections.OrderedDict()
    results2 = collections.OrderedDict()
    for accumulator_percentile in range(0, 105, 5):
    print accumulator_percentile
    arguments.use_ifs = True
    arguments.use_accumulator = True
    arguments.accumulator_percentile = accumulator_percentile
    start = time.time()
    mean_average_precision = main(arguments, corpus)
    duration = int(time.time() - start)
    results1[accumulator_percentile] = mean_average_precision
    results2[accumulator_percentile] = duration

    plt.plot(range(len(results1)), results1.values(), 'o')
    plt.xlabel('Accumulator Percentile')
    plt.ylabel('Mean Average Precision')
    plt.xticks(range(len(results1)), results1.keys())
    plt.grid(True)
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.show()

    plt.plot(range(len(results2)), results2.values(), 'o')
    plt.xlabel('Accumulator Percentile')
    plt.ylabel('Runtime')
    plt.xticks(range(len(results2)), results2.keys())
    plt.grid(True)
    plt.tight_layout()
    plt.show()


    def main(arguments, corpus):
    # Calculate SIFT data for corpus
    frames, descriptors = vlfeat.vl_dsift(
    corpus['image'] / corpus['data'].max(), step=arguments.step_size, size=arguments.cell_size,
    fast=True, float_descriptors=True)

    # Find all frames and descriptors contained inside word boundaries (minus a cell margin of cell_size * 2)
    cell_margin = 2 * arguments.cell_size
    words_frames = []
    words_descriptors = []
    previous_frame_index = 0
    word_data_indices = collections.OrderedDict()
    word_coordinates = collections.OrderedDict()
    for word, coordinates in corpus['gtp'].items():
    # Filter word frames within word bounding box
    for variation, (x1, y1, x2, y2) in enumerate(coordinates):
    if arguments.cell_margin == 'none':
    mask = (
    (frames[:, 0] >= x1) & (frames[:, 1] >= y1) &
    (frames[:, 0] <= x2) & (frames[:, 1] <= y2))
    elif arguments.cell_margin == 'horizontal':
    mask = (
    (frames[:, 0] >= x1 + cell_margin) & (frames[:, 1] >= y1) &
    (frames[:, 0] <= x2 - cell_margin) & (frames[:, 1] <= y2))
    elif arguments.cell_margin == 'vertical':
    mask = (
    (frames[:, 0] >= x1) & (frames[:, 1] >= y1 + cell_margin) &
    (frames[:, 0] <= x2) & (frames[:, 1] <= y2 - cell_margin))
    elif arguments.cell_margin == 'both':
    mask = (
    (frames[:, 0] >= x1 + cell_margin) & (frames[:, 1] >= y1 + cell_margin) &
    (frames[:, 0] <= x2 - cell_margin) & (frames[:, 1] <= y2 - cell_margin))
    else:
    raise RuntimeError('dude what the fuck are you doing')

    # Get matching frames/desc for the word
    word_frames = frames[mask]
    words_frames.append(word_frames)
    words_descriptors.append(descriptors[mask])

    # Count how many frames are contained inside the bounding box
    frame_count = word_frames.shape[0]

    # Note at which index and how many (following) frames/descs are part of a word
    key = word, variation
    word_data_indices[key] = previous_frame_index, frame_count
    word_coordinates[key] = x1, y1, x2, y2
    previous_frame_index += frame_count
    words_frames = np.concatenate(words_frames)
    words_descriptors = np.concatenate(words_descriptors)

    if arguments.centroids == 4096:
    code_book = load_codebook(CODE_BOOK_PATH)
    labels, _ = scipy.cluster.vq.vq(words_descriptors, code_book)
    else:
    # Calculate labels
    _, labels = scipy.cluster.vq.kmeans2(
    words_descriptors, arguments.centroids, iter=arguments.k_means_iterations, minit='points')

    # Word -> labels mapping
    # noinspection PyArgumentList
    word_labels = collections.OrderedDict(
    (key, labels[start:start + length]) for key, (start, length) in word_data_indices.items())

    # Create (word, variation) -> spatial pyramid mapping
    # noinspection PyArgumentList
    spatial_pyramids = collections.OrderedDict(
    (key, make_spatial_pyramid(labels, arguments.centroids, arguments.spatial_pyramid_type))
    for key, labels in word_labels.items())

    # Create IFS database
    ifs_height = len(spatial_pyramids.values()[0])
    ifs = [set() for count in range(ifs_height)]
    for word_index, spatial_pyramid in enumerate(spatial_pyramids.values()):
    for index, count in enumerate(spatial_pyramid):
    if count:
    ifs[index].add(word_index)

    # Create word index -> variation set mapping
    word_variation_indices = collections.defaultdict(set)
    for word_index, (word, variation) in enumerate(spatial_pyramids.keys()):
    word_variation_indices[word].add(word_index)

    # Find query in IFS
    spatial_pyramids_values = spatial_pyramids.values()
    word_coordinates_values = word_coordinates.values()
    average_precisions = []
    average_recalls = []
    for word_index, ((word, variation), query) in enumerate(spatial_pyramids.items()):
    # Skip words with no findable duplicates in the IFS database
    appearances = len(word_variation_indices[word]) - 1
    if not appearances:
    if arguments.verbose:
    print >> sys.stderr, 'No duplicate appearances for (%s, %d)!' % (word, variation)
    continue

    if arguments.use_ifs:
    ifs_candidate_indices = list(itertools.chain(*(ifs[index] for index, count in enumerate(query) if count)))
    candidate_indices = set(ifs_candidate_indices)
    if not candidate_indices:
    if arguments.verbose:
    print >> sys.stderr, 'No candidates for (%s, %d) after IFS!' % (word, variation)
    average_precisions.append(0)
    average_recalls.append(0)
    continue

    if arguments.use_accumulator:
    # noinspection PyArgumentList
    accumulator = collections.Counter(ifs_candidate_indices)
    # No candidates left after having applied the IFS
    if not accumulator:
    if arguments.verbose:
    print >> sys.stderr, 'No candidates for (%s, %d) after IFS + Accumulator!' % (word, variation)
    average_precisions.append(0)
    average_recalls.append(0)
    continue

    most_common = accumulator.most_common()
    rankings = sorted(set(accumulator.values()))
    percentile_ranking = rankings[max(0, int(len(rankings) * arguments.accumulator_percentile / 100.0) - 1)]
    candidate_indices = set(
    index for index, count in
    list(itertools.takewhile(lambda item: item[1] >= percentile_ranking, most_common)))
    else:
    candidate_indices = set(range(len(spatial_pyramids)))

    candidate_indices -= {word_index}
    if not candidate_indices:
    if arguments.verbose:
    print >> sys.stderr, 'No candidates for (%s, %d)' % (word, variation)
    average_precisions.append(0)
    average_recalls.append(0)
    continue

    candidate_pyramids = np.array([spatial_pyramids_values[index] for index in candidate_indices])
    query = query.reshape((1, query.shape[0]))
    distances = scipy.spatial.distance.cdist(query, candidate_pyramids, metric=arguments.distance_metric)[0]

    # Translate index in distance array to index of candidate
    distances_indices = range(distances.shape[0])
    distance_index_to_candidate_index = {
    distance_index: candidate_index for distance_index, candidate_index in
    zip(distances_indices, candidate_indices)}
    distances_sorted_indices = np.argsort(distances)
    sorted_candidate_indices = [
    distance_index_to_candidate_index[distance_index] for distance_index in distances_sorted_indices]

    hits = [1 if index in word_variation_indices[word] else 0 for index in sorted_candidate_indices]
    true_positives = sum(hits)
    # Calculate accumulated hits at index
    hits_at_k = []
    current_hits = 0
    for hit in hits:
    if hit:
    current_hits += 1
    hits_at_k.append(current_hits)
    average_precision = sum(
    (current_hits / float(index)) * hit for index, (hit, current_hits) in
    enumerate(zip(hits, hits_at_k), start=1)) / float(appearances)
    average_precisions.append(average_precision)
    average_recalls.append(true_positives / float(appearances))

    if arguments.save_images:
    match_images_path = os.path.join(MATCH_IMAGES_PATH, '%s_%d' % (word, variation))
    makedirs(match_images_path, exist_ok=True)
    coordinates = word_coordinates_values[word_index]
    corpus['image'].crop(coordinates).save(os.path.join(match_images_path, '0_original.png'))

    for rank, candidate_word_index in enumerate(sorted_candidate_indices, start=1):
    coordinates = word_coordinates_values[candidate_word_index]
    path = os.path.join(match_images_path, 'candidate_%d.png' % rank)
    corpus['image'].crop(coordinates).save(path)

    # print 'Word %s (Variation: %d): %.2f%%' % (word, variation, average_precision * 100)

    print 'Mean Recall: %f' % (np.mean(average_recalls) * 100)
    mean_average_precision = np.mean(average_precisions)
    print 'Mean Average Precision: %f' % (mean_average_precision * 100)
    return mean_average_precision


    if __name__ == '__main__':
    arguments = argument_parser.parse_args()
    pre_main(arguments)