Skip to content

Instantly share code, notes, and snippets.

@harishkashyap
Forked from Arachnid/automata.py
Created December 11, 2015 18:24
Show Gist options
  • Save harishkashyap/3da8ba71613260835f04 to your computer and use it in GitHub Desktop.
Save harishkashyap/3da8ba71613260835f04 to your computer and use it in GitHub Desktop.

Revisions

  1. @Arachnid Arachnid revised this gist Jul 28, 2010. 2 changed files with 98 additions and 23 deletions.
    46 changes: 23 additions & 23 deletions automata.py
    Original file line number Diff line number Diff line change
    @@ -91,34 +91,32 @@ def next_state(self, src, input):
    return state_transitions.get(input, self.defaults.get(src, None))

    def next_valid_string(self, input):
    s = self.start_state
    state = self.start_state
    stack = []

    # Evaluate the DFA as far as possible
    for i, x in enumerate(input):
    stack.append((i, s, x))
    s = self.next_state(s, x)
    if not s: break
    stack.append((input[:i], state, x))
    state = self.next_state(state, x)
    if not state: break
    else:
    stack.append((i + 1, s, None))
    stack.append((input[:i+1], state, None))

    if self.is_final(s):
    if self.is_final(state):
    # Input word is already valid
    return input

    # Walk up the stack until we find a state that we can reach a final state from
    for i, s, x in reversed(stack):
    tail = ''
    # Attempt to extend the string to reach a final state
    while s:
    x = self.find_next_edge(s, x)
    if not x:
    # No edge with label greater than the one provided
    break
    tail += x
    s = self.next_state(s, x)
    if self.is_final(s):
    return input[:i] + tail
    x = None
    # Perform a 'wall following' search for the lexicographically smallest
    # accepting state.
    while stack:
    path, state, x = stack.pop()
    x = self.find_next_edge(state, x)
    if x:
    path += x
    state = self.next_state(state, x)
    if self.is_final(state):
    return path
    stack.append((path, state, None))
    return None

    def find_next_edge(self, s, x):
    @@ -168,10 +166,12 @@ def find_all_matches(word, k, lookup_func):
    Every matching word within levenshtein distance k from the database.
    """
    lev = levenshtein_automata(word, k).to_dfa()
    match = lookup_func('')
    match = lev.next_valid_string(u'\0')
    while match:
    next = lev.next_valid_string(match)
    next = lookup_func(match)
    if not next:
    return
    if match == next:
    yield match
    next = next + u'\0'
    match = lookup_func(next)
    match = lev.next_valid_string(next)
    75 changes: 75 additions & 0 deletions automata_test.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,75 @@
    import automata
    import bisect
    import random


    class Matcher(object):
        """Callable lookup over a word list that counts how often it is probed.

        Calling the instance with a word returns the first entry of the list
        that is greater than or equal to that word, or None when every entry
        is smaller. The list is searched with bisect, so it is expected to be
        sorted already.
        """

        def __init__(self, l):
            # Number of lookups performed so far.
            self.probes = 0
            self.l = l

        def __call__(self, w):
            self.probes += 1
            index = bisect.bisect_left(self.l, w)
            # bisect_left returns len(self.l) when w sorts after every entry.
            return self.l[index] if index < len(self.l) else None


    # Load the system word list, stripped, lower-cased and decoded from UTF-8,
    # then sort it so that Matcher can binary-search it. The file is opened in a
    # `with` block so the handle is closed as soon as the list has been built.
    with open('/usr/share/dict/web2') as word_file:
        words = [x.strip().lower().decode('utf-8') for x in word_file]
    words.sort()
    # Random samples keeping roughly 10% and 1% of the words (not used by the
    # statements visible here).
    words10 = [x for x in words if random.random() <= 0.1]
    words100 = [x for x in words if random.random() <= 0.01]


    # Count the matches for 'food' at edit distance 1 and report how many
    # database probes the search needed.
    m = Matcher(words)
    assert len(list(automata.find_all_matches('food', 1, m))) == 18
    print(m.probes)

    # Same query at edit distance 2, with a fresh probe counter.
    m = Matcher(words)
    assert len(list(automata.find_all_matches('food', 2, m))) == 283
    print(m.probes)


    def levenshtein(s1, s2):
        """Return the Levenshtein edit distance between two sequences.

        Args:
            s1: First sequence (typically a string).
            s2: Second sequence.
        Returns:
            The minimum number of single-element insertions, deletions and
            substitutions needed to turn one sequence into the other.
        """
        if len(s1) < len(s2):
            # Keep the shorter sequence as s2 so the rows stay as short as possible.
            return levenshtein(s2, s1)
        if not s1:
            return len(s2)

        # `range` instead of the Python 2-only `xrange`; wrapped in list() so the
        # row is an indexable list on both Python 2 and Python 3.
        previous_row = list(range(len(s2) + 1))
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                # previous_row and current_row are one element longer than s2,
                # hence j + 1 rather than j.
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    class BKNode(object):
        """Tree node holding one term, with children keyed by edit distance.

        Each child sits under the Levenshtein distance between its term and
        this node's term.
        """

        def __init__(self, term):
            self.children = {}
            self.term = term

        def insert(self, other):
            """Add `other` below this node, descending while the slot is taken."""
            distance = levenshtein(self.term, other)
            try:
                subtree = self.children[distance]
            except KeyError:
                self.children[distance] = BKNode(other)
            else:
                subtree.insert(other)

        def search(self, term, k, results=None):
            """Collect terms within distance k of `term`; return nodes visited.

            Matching terms are appended to `results`. When no list is passed a
            private one is used, so only the visit count is available to the
            caller in that case.
            """
            if results is None:
                results = []
            distance = levenshtein(self.term, term)
            if distance <= k:
                results.append(self.term)
            visited = 1  # this node
            lowest = max(0, distance - k)
            # Only children whose edge distance lies in [distance - k, distance + k]
            # are descended into.
            for edge in range(lowest, distance + k + 1):
                subtree = self.children.get(edge)
                if subtree:
                    visited += subtree.search(term, k, results)
            return visited
  2. @Arachnid Arachnid created this gist Jul 27, 2010.
    177 changes: 177 additions & 0 deletions automata.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,177 @@
    import bisect

    class NFA(object):
        """Nondeterministic finite automaton with epsilon and wildcard edges.

        `transitions` maps a source state to a dict of label -> set of
        destination states. Two sentinel labels are provided: EPSILON (an edge
        taken without consuming input) and ANY (an edge taken on every input).
        A "current state" of the automaton is a set of individual states.
        """

        # Sentinel edge labels.
        EPSILON = object()
        ANY = object()

        def __init__(self, start_state):
            self.transitions = {}
            self.final_states = set()
            self._start_state = start_state

        @property
        def start_state(self):
            """The epsilon closure of the initial state, as a frozenset."""
            return frozenset(self._expand(set([self._start_state])))

        def add_transition(self, src, input, dest):
            """Add an edge from `src` to `dest` labelled `input`."""
            self.transitions.setdefault(src, {}).setdefault(input, set()).add(dest)

        def add_final_state(self, state):
            """Mark an individual state as accepting."""
            self.final_states.add(state)

        def is_final(self, states):
            """Return the accepting states contained in `states` (empty if none)."""
            return self.final_states.intersection(states)

        def _expand(self, states):
            """Add every state reachable through EPSILON edges to `states`.

            The set is updated in place and also returned.
            """
            frontier = set(states)
            while frontier:
                state = frontier.pop()
                epsilon_targets = self.transitions.get(state, {}).get(NFA.EPSILON, set())
                new_states = epsilon_targets.difference(states)
                frontier.update(new_states)
                states.update(new_states)
            return states

        def next_state(self, states, input):
            """Return the frozenset of states reached from `states` on `input`.

            Edges labelled `input` and edges labelled ANY are both followed, and
            the result is closed under EPSILON edges.
            """
            dest_states = set()
            for state in states:
                state_transitions = self.transitions.get(state, {})
                dest_states.update(state_transitions.get(input, []))
                dest_states.update(state_transitions.get(NFA.ANY, []))
            return frozenset(self._expand(dest_states))

        def get_inputs(self, states):
            """Return the set of edge labels leaving any state in `states`."""
            inputs = set()
            for state in states:
                inputs.update(self.transitions.get(state, {}))
            return inputs

        def to_dfa(self):
            """Build an equivalent DFA whose states are frozensets of NFA states.

            ANY edges become the DFA state's default transition; every other
            label becomes an explicit transition.
            """
            start = self.start_state
            dfa = DFA(start)
            # The start state is never produced as the target of a transition
            # below unless the automaton loops back to it, so its finality has
            # to be recorded here; otherwise the DFA would reject the empty
            # string even when the NFA accepts it.
            if self.is_final(start):
                dfa.add_final_state(start)
            frontier = [start]
            seen = set([start])
            while frontier:
                current = frontier.pop()
                for label in self.get_inputs(current):
                    if label is NFA.EPSILON:
                        # Already folded into the states by _expand.
                        continue
                    new_state = self.next_state(current, label)
                    if new_state not in seen:
                        frontier.append(new_state)
                        seen.add(new_state)
                        if self.is_final(new_state):
                            dfa.add_final_state(new_state)
                    if label is NFA.ANY:
                        dfa.set_default_transition(current, new_state)
                    else:
                        dfa.add_transition(current, label, new_state)
            return dfa


    class DFA(object):
        """Deterministic finite automaton with optional per-state default edges.

        `transitions` maps a state to a dict of label -> destination state and
        `defaults` maps a state to the destination used for every label without
        an explicit transition. Labels are single characters.
        """

        # `unichr` only exists on Python 2; `chr` plays the same role on Python 3.
        try:
            _to_char = staticmethod(unichr)
        except NameError:
            _to_char = staticmethod(chr)

        def __init__(self, start_state):
            self.start_state = start_state
            self.transitions = {}
            self.defaults = {}
            self.final_states = set()

        def add_transition(self, src, input, dest):
            """Set the transition taken from `src` on `input`."""
            self.transitions.setdefault(src, {})[input] = dest

        def set_default_transition(self, src, dest):
            """Set the transition taken from `src` on any unlisted input."""
            self.defaults[src] = dest

        def add_final_state(self, state):
            """Mark `state` as accepting."""
            self.final_states.add(state)

        def is_final(self, state):
            """Return True if `state` is accepting."""
            return state in self.final_states

        def next_state(self, src, input):
            """Return the state reached from `src` on `input`, or None.

            An explicit transition wins over the default transition; None means
            the automaton has no move.
            """
            state_transitions = self.transitions.get(src, {})
            return state_transitions.get(input, self.defaults.get(src, None))

        def next_valid_string(self, input):
            """Return the smallest accepted string that is >= `input`.

            Args:
                input: The string to start from; may be empty.
            Returns:
                `input` itself when it is accepted, otherwise the
                lexicographically next accepted string, or None if there is
                none. The search descends depth-first, so an automaton with a
                cycle that never reaches a final state may not terminate.
            """
            state = self.start_state
            stack = []

            # Evaluate the DFA as far as possible. For every character read,
            # remember the prefix before it, the state it was read in and the
            # character itself, so the search can later resume just after it.
            for i, x in enumerate(input):
                stack.append((input[:i], state, x))
                state = self.next_state(state, x)
                if state is None:
                    break
            else:
                # Whole input consumed (also the case for an empty input).
                if self.is_final(state):
                    # Input word is already valid
                    return input
                stack.append((input, state, None))

            # Perform a 'wall following' search for the lexicographically
            # smallest accepting state.
            while stack:
                prefix, state, x = stack.pop()
                x = self.find_next_edge(state, x)
                if x is None:
                    # No edge with a label greater than the one already tried.
                    continue
                # Keep this state on the stack so the next-larger edge is tried
                # if everything below `x` turns out to be a dead end.
                stack.append((prefix, state, x))
                path = prefix + x
                state = self.next_state(state, x)
                if self.is_final(state):
                    return path
                stack.append((path, state, None))
            return None

        def find_next_edge(self, s, x):
            """Return the smallest usable edge label out of `s` greater than `x`.

            Args:
                s: The state to leave.
                x: The last label tried, or None to start from the smallest.
            Returns:
                The label, or None when no such edge exists.
            """
            if x is None:
                x = u'\0'
            else:
                try:
                    x = self._to_char(ord(x) + 1)
                except ValueError:
                    # x is already the largest code point; nothing sorts after it.
                    return None
            state_transitions = self.transitions.get(s, {})
            if x in state_transitions or s in self.defaults:
                # With a default transition every label is usable.
                return x
            labels = sorted(state_transitions)
            pos = bisect.bisect_left(labels, x)
            if pos < len(labels):
                return labels[pos]
            return None


    def levenshtein_automata(term, k):
        """Build an NFA over states (position in term, errors used so far).

        The automaton starts in (0, 0) and accepts in (len(term), e) for every
        e from 0 to k.

        Args:
            term: The word the automaton is built around.
            k: The maximum number of errors allowed.
        Returns:
            The constructed NFA.
        """
        nfa = NFA((0, 0))
        error_levels = range(k + 1)
        for pos, char in enumerate(term):
            for errors in error_levels:
                here = (pos, errors)
                # Correct character
                nfa.add_transition(here, char, (pos + 1, errors))
                if errors == k:
                    # Error budget used up: only the exact character may follow.
                    continue
                worse = errors + 1
                # Deletion
                nfa.add_transition(here, NFA.ANY, (pos, worse))
                # Insertion
                nfa.add_transition(here, NFA.EPSILON, (pos + 1, worse))
                # Substitution
                nfa.add_transition(here, NFA.ANY, (pos + 1, worse))
        end = len(term)
        for errors in error_levels:
            if errors < k:
                # Extra input after the end of the term costs one error each.
                nfa.add_transition((end, errors), NFA.ANY, (end, errors + 1))
            nfa.add_final_state((end, errors))
        return nfa


    def find_all_matches(word, k, lookup_func):
        """Uses lookup_func to find all words within levenshtein distance k of word.
        Args:
          word: The word to look up
          k: Maximum edit distance
          lookup_func: A single argument function that returns the first word in the
            database that is greater than or equal to the input argument.
        Yields:
          Every matching word within levenshtein distance k from the database.
        """
        lev = levenshtein_automata(word, k).to_dfa()
        match = lookup_func('')
        while match:
            candidate = lev.next_valid_string(match)
            if candidate is None:
                # The automaton accepts nothing at or after `match`, so no later
                # database entry can match. Passing None on to lookup_func would
                # restart the scan (Python 2) or raise a TypeError (Python 3).
                return
            if match == candidate:
                yield match
                # Smallest string strictly greater than the word just yielded.
                candidate = candidate + u'\0'
            match = lookup_func(candidate)