Last active
November 2, 2019 17:21
-
-
Save dnanhkhoa/cbd7ba6b860834011701c470e7771b0a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding: utf-8 -*- | |
| import random | |
| import string | |
| from collections import defaultdict | |
| random.seed(42) | |
def generate_random_code(valid_chars, code_len):
    """Build one random code of ``code_len`` symbols.

    Symbols are drawn independently, with replacement, from ``valid_chars``
    using the module-level ``random`` generator.

    Args:
        valid_chars: Sequence of symbols to draw from.
        code_len: Number of symbols in the resulting code.

    Returns:
        The drawn symbols concatenated into a single string.
    """
    drawn_symbols = random.choices(valid_chars, k=code_len)
    return "".join(drawn_symbols)
def generate_random_codes(num_codes, valid_chars, code_len_range):
    """Generate ``num_codes`` distinct random codes.

    Each code gets a length drawn uniformly from the inclusive
    ``code_len_range`` and symbols drawn from ``valid_chars``.

    Args:
        num_codes: Number of distinct codes to produce.
        valid_chars: Sequence of symbols a code may contain.
        code_len_range: ``(min_len, max_len)`` pair, both ends inclusive.

    Returns:
        A set containing exactly ``num_codes`` distinct codes.

    Raises:
        ValueError: If the alphabet and length range cannot yield
            ``num_codes`` distinct codes. Without this check the sampling
            loop below would never terminate.
    """
    min_len, max_len = code_len_range
    alphabet_size = len(set(valid_chars))

    # Count how many distinct codes exist, stopping as soon as the request
    # is known to be satisfiable so large ranges never build huge powers.
    capacity = 0
    for length in range(min_len, max_len + 1):
        if capacity >= num_codes:
            break
        capacity += alphabet_size ** max(length, 0)
    if capacity < num_codes:
        raise ValueError(
            f"cannot generate {num_codes} unique codes: only {capacity} "
            f"distinct codes exist for the given alphabet and length range"
        )

    codes = set()  # A set guarantees there are no duplicates
    while len(codes) < num_codes:
        code_len = random.randint(min_len, max_len)
        codes.add(generate_random_code(valid_chars, code_len))
    return codes
def load_external_codes(filename):
    """Read codes from a UTF-8 text file, one code per line.

    Surrounding whitespace is stripped from every line, and lines that are
    empty after stripping are ignored.

    Args:
        filename: Path of the text file to read.

    Returns:
        A set of the distinct codes found in the file.
    """
    with open(filename, "r", encoding="UTF-8") as source:
        stripped_lines = (raw_line.strip() for raw_line in source)
        # A set comprehension drops duplicates while skipping blank lines
        return {entry for entry in stripped_lines if entry}
class EditDistance:
    """Bounded, optionally weighted edit distance between two codes.

    Insertions and deletions always cost 1. A substitution costs 1 unless
    ``weights`` supplies a cheaper (or dearer) cost for that character pair.

    NOTE(review): the weight format is assumed to be a mapping from
    ``(char_in_s1, char_in_s2)`` to a non-negative substitution cost; the
    only visible usage passes an empty dict, so confirm against real data.
    """

    def __init__(self, max_distance, weights):
        """Store the distance bound and the substitution cost table.

        Args:
            max_distance: Largest distance that is still considered valid.
            weights: Sparse substitution cost table (see class docstring).
        """
        self._max_distance = max_distance
        self._weights = weights

    def __find_diff_segment(self, s1, s2):
        """Locate the part of ``s1``/``s2`` left after trimming shared ends.

        The common suffix is trimmed first, then the common prefix.

        Returns:
            A tuple ``(pos, s1_diff_len, s2_diff_len)``: the start offset of
            the differing segment (identical for both strings) and its
            length in each string.
        """
        pos = 0
        s1_len = len(s1)
        s2_len = len(s2)
        min_len = min(s1_len, s2_len)

        # Trim the common suffix
        while min_len > 0 and s1[s1_len - 1] == s2[s2_len - 1]:
            s1_len -= 1
            s2_len -= 1
            min_len -= 1

        # Trim the common prefix of what is left
        while pos < min_len and s1[pos] == s2[pos]:
            pos += 1

        return pos, s1_len - pos, s2_len - pos

    def score(self, s1, s2):
        """Return the edit distance between ``s1`` and ``s2``.

        Args:
            s1: First string.
            s2: Second string.

        Returns:
            The (weighted) edit distance when it does not exceed
            ``max_distance``; ``-1`` otherwise.
        """
        if abs(len(s1) - len(s2)) > self._max_distance:
            return -1  # Invalid case

        pos, s1_diff_len, s2_diff_len = self.__find_diff_segment(s1, s2)

        # Shared prefix/suffix contribute nothing, so only the differing
        # segments need to go through the dynamic programme.
        left = s1[pos:pos + s1_diff_len]
        right = s2[pos:pos + s2_diff_len]
        weights = self._weights or {}
        limit = self._max_distance

        # Row-by-row Wagner-Fischer: previous[j] is the distance between the
        # processed prefix of ``left`` and right[:j].
        previous = list(range(len(right) + 1))
        for i, left_char in enumerate(left, 1):
            current = [i]
            for j, right_char in enumerate(right, 1):
                if left_char == right_char:
                    substitution = previous[j - 1]
                else:
                    substitution = previous[j - 1] + weights.get(
                        (left_char, right_char), 1
                    )
                current.append(
                    min(previous[j] + 1, current[j - 1] + 1, substitution)
                )
            # With non-negative costs the row minimum never decreases, so
            # once it passes the limit the final distance must as well.
            if min(current) > limit:
                return -1
            previous = current

        distance = previous[-1]
        return distance if distance <= limit else -1
class Indexer:
    """Deletion-variant index over a collection of codes.

    Every indexed code is registered under each string obtainable from it by
    deleting up to ``max_distance`` characters, so that codes sharing a
    deletion variant can be found through a single dictionary lookup.
    """

    def __init__(self, threshold, distance_algorithm):
        """Create an empty index.

        Args:
            threshold: Stored for later use; not read by the indexing code.
            distance_algorithm: Object exposing ``_max_distance``; it bounds
                how many characters are deleted when building variants.
        """
        self._original_codes = set()
        # Maps a deletion variant to the original codes that produce it
        self._deleted_codes = defaultdict(list)
        self._threshold = threshold
        self._distance_algorithm = distance_algorithm
        self._max_distance = distance_algorithm._max_distance
        self._max_code_len = 0

    def __generate_deleted_codes(self, original_code, current_code, current_distance):
        """Register ``original_code`` under every deletion variant.

        Variants are produced level by level: each pass removes one more
        character from the strings of the previous pass. Strings of length 1
        are never shortened, so the empty string is not indexed.

        Args:
            original_code: Code credited for every generated variant.
            current_code: String from which deletions start.
            current_distance: Maximum number of characters to delete.
        """
        # De-duplicate per original code only. A variant that another code
        # already produced must still be credited to this code, otherwise
        # codes sharing a variant would be lost from the index.
        seen = set()
        frontier = [current_code]
        for _ in range(current_distance):
            next_frontier = []
            for code in frontier:
                if len(code) <= 1:
                    continue
                for i in range(len(code)):
                    deleted_code = code[:i] + code[i + 1:]
                    if deleted_code in seen:
                        continue
                    seen.add(deleted_code)
                    self._deleted_codes[deleted_code].append(original_code)
                    next_frontier.append(deleted_code)
            if not next_frontier:
                break
            frontier = next_frontier

    def index(self, codes):
        """Add ``codes`` to the index, skipping codes already present.

        Args:
            codes: Iterable of code strings.
        """
        # This function is just called once
        for code in codes:
            if code in self._original_codes:
                continue
            self.__generate_deleted_codes(code, code, self._max_distance)
            self._original_codes.add(code)
            self._max_code_len = max(self._max_code_len, len(code))

    def search(self, code):
        """Look up ``code`` in the index (not implemented yet)."""
        pass
if __name__ == "__main__":
    # ---- Settings ----
    # Alphabet: upper-case ASCII letters followed by the decimal digits
    valid_chars = string.ascii_uppercase + string.digits
    num_valid_chars = len(valid_chars)

    # Distances of 2 or 3 are the usual choices for OCR correction
    max_distance = 3
    # Should be greater than the lower bound estimated in the DB
    threshold = 0.5
    # Sparse cost matrix for confusable characters
    weights = {}

    # ---- Input codes: synthetic by default ----
    num_codes = 50_000
    code_len_range = (3, 20)
    codes = generate_random_codes(
        num_codes=num_codes,
        valid_chars=valid_chars,
        code_len_range=code_len_range,
    )
    # Alternatively, read the codes from a file:
    # codes = load_external_codes("codes.txt")

    # ---- Build the index ----
    edit_distance = EditDistance(max_distance=max_distance, weights=weights)
    indexer = Indexer(threshold=threshold, distance_algorithm=edit_distance)
    indexer.index(codes)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment