Skip to content

Instantly share code, notes, and snippets.

@dnanhkhoa
Last active November 2, 2019 17:21
Show Gist options
  • Save dnanhkhoa/cbd7ba6b860834011701c470e7771b0a to your computer and use it in GitHub Desktop.
Save dnanhkhoa/cbd7ba6b860834011701c470e7771b0a to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
import random
import string
from collections import defaultdict
# Seed the module-level RNG with a fixed value so that the randomly generated
# codes below come out the same on every run.
random.seed(42)
def generate_random_code(valid_chars, code_len):
    """Return one random code of ``code_len`` symbols.

    Symbols are sampled from ``valid_chars`` with replacement using the
    module-level ``random`` generator, then concatenated into a string.
    """
    picked_symbols = random.choices(valid_chars, k=code_len)
    code = "".join(picked_symbols)
    return code
def generate_random_codes(num_codes, valid_chars, code_len_range):
    """Generate ``num_codes`` distinct random codes.

    Args:
        num_codes: Number of unique codes to produce.
        valid_chars: Symbols a code may be built from.
        code_len_range: ``(min_len, max_len)`` pair; each code's length is
            drawn uniformly from this inclusive range.

    Returns:
        A set containing exactly ``num_codes`` codes (empty when
        ``num_codes`` is not positive).

    Raises:
        ValueError: If the alphabet and length range cannot yield
            ``num_codes`` distinct codes. Without this check the sampling
            loop below would never terminate.
    """
    if num_codes > 0:
        shortest, longest = code_len_range
        alphabet_size = len(set(valid_chars))
        # Count the distinct codes that can exist, stopping as soon as there
        # are enough so large ranges never force a huge computation.
        capacity = 0
        for length in range(max(shortest, 0), max(longest, 0) + 1):
            capacity += alphabet_size ** length
            if capacity >= num_codes:
                break
        if capacity < num_codes:
            raise ValueError(
                f"cannot generate {num_codes} unique codes: only {capacity} "
                f"distinct codes exist for the given characters and lengths"
            )

    codes = set()  # Using set to make sure there are no duplicates
    while len(codes) < num_codes:
        codes.add(generate_random_code(valid_chars, random.randint(*code_len_range)))
    return codes
def load_external_codes(filename):
    """Read codes from a UTF-8 text file, one code per line.

    Surrounding whitespace is stripped from every line, lines that are empty
    after stripping are ignored, and duplicates collapse because the result
    is a set.
    """
    with open(filename, "r", encoding="UTF-8") as fin:
        stripped_lines = (raw_line.strip() for raw_line in fin)
        return {entry for entry in stripped_lines if entry}
class EditDistance:
    """Scaffold for a bounded edit-distance scorer between two strings.

    Only the preliminary steps exist so far: a length-difference pre-check and
    the isolation of the segment in which the two strings actually differ.
    The distance computation itself has not been written yet.
    """

    def __init__(self, max_distance, weights):
        """Store the configuration.

        Args:
            max_distance: Largest length difference ``score`` will accept.
            weights: Stored as given; no method of this class reads it yet.
        """
        self._max_distance, self._weights = max_distance, weights

    def __find_diff_segment(self, s1, s2):
        """Strip the common suffix, then the common prefix, of two strings.

        Returns:
            A tuple ``(start, s1_span, s2_span)``: the index where the
            strings first differ and how many characters of each string lie
            between the shared prefix and the shared suffix.
        """
        end1, end2 = len(s1), len(s2)

        # Peel matching characters off the right-hand side first.
        while end1 > 0 and end2 > 0 and s1[end1 - 1] == s2[end2 - 1]:
            end1 -= 1
            end2 -= 1

        # Then walk past the shared prefix of whatever remains.
        limit = min(end1, end2)
        start = 0
        while start < limit and s1[start] == s2[start]:
            start += 1

        return start, end1 - start, end2 - start

    def score(self, s1, s2):
        """Pre-check two strings for scoring.

        Returns:
            ``-1`` when the lengths differ by more than ``max_distance``.
            Otherwise ``None``: the differing segment is located but nothing
            is computed from it yet.
        """
        length_gap = abs(len(s1) - len(s2))
        if length_gap > self._max_distance:
            return -1  # Invalid case

        # The segment bounds are not consumed yet; the scoring step is missing.
        start, s1_span, s2_span = self.__find_diff_segment(s1, s2)
        return None
class Indexer:
    """Index codes by their deletion variants.

    Every string obtainable from an indexed code by removing up to
    ``max_distance`` characters (never shrinking a string below one
    character) is mapped back to the codes it was derived from.
    """

    def __init__(self, threshold, distance_algorithm):
        """Create an empty index.

        Args:
            threshold: Stored as given; not read by any method yet.
            distance_algorithm: Object exposing ``_max_distance``, which sets
                how many characters may be deleted from a code.
        """
        self._original_codes = set()
        # deletion variant -> list of original codes that produce it
        self._deleted_codes = defaultdict(list)
        self._threshold = threshold
        self._distance_algorithm = distance_algorithm
        self._max_distance = distance_algorithm._max_distance
        self._max_code_len = 0

    def __generate_deleted_codes(self, original_code, current_code, current_distance):
        """Register every deletion variant of ``current_code`` for ``original_code``.

        Variants are expanded one deletion at a time, up to
        ``current_distance`` deletions. A visited set local to this call
        prevents the same variant from being recorded twice for one code.
        The shared ``_deleted_codes`` mapping must NOT be used for that
        check: doing so would drop every later code that shares a variant
        with an earlier one and skip its deeper deletions.
        """
        seen = set()
        frontier = [current_code]
        remaining = current_distance
        while remaining > 0 and frontier:
            next_frontier = []
            for code in frontier:
                if len(code) <= 1:
                    # Never delete down to the empty string.
                    continue
                for i in range(len(code)):
                    deleted_code = code[:i] + code[i + 1 :]
                    if deleted_code in seen:
                        continue
                    seen.add(deleted_code)
                    self._deleted_codes[deleted_code].append(original_code)
                    next_frontier.append(deleted_code)
            frontier = next_frontier
            remaining -= 1

    def index(self, codes):
        """Add ``codes`` to the index, skipping codes already indexed.

        Also tracks the length of the longest indexed code.
        """
        # This function is just called once
        for code in codes:
            if code in self._original_codes:
                continue
            self.__generate_deleted_codes(code, code, self._max_distance)
            self._original_codes.add(code)
            self._max_code_len = max(self._max_code_len, len(code))

    def search(self, code):
        """Placeholder: lookup is not implemented yet and returns ``None``."""
        pass
if __name__ == "__main__":
    # Alphabet the codes are drawn from: upper-case letters plus digits.
    valid_chars = string.ascii_uppercase + string.digits
    num_valid_chars = len(valid_chars)

    # Matching parameters.
    max_distance = 3  # [2, 3] are common in OCR correction
    threshold = 0.5  # This value should be > the lower bound estimated in DB
    weights = {}  # Define a sparse cost matrix for confusable characters

    # Build the scorer and the (still empty) index; neither touches the RNG,
    # so doing this before generating data does not change the codes produced.
    edit_distance = EditDistance(max_distance, weights)
    indexer = Indexer(threshold, edit_distance)

    # Synthetic dataset of unique random codes.
    num_codes = 50_000
    code_len_range = (3, 20)
    codes = generate_random_codes(num_codes, valid_chars, code_len_range)
    # Alternatively, read the codes from a file instead:
    # codes = load_external_codes("codes.txt")

    indexer.index(codes)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment