Last active: October 16, 2021 18:49
Revisions
kastnerkyle revised this gist
Jun 8, 2017. 1 changed file with 31 additions and 1 deletion.

@@ -28,6 +28,7 @@
import argparse
import cPickle as pickle
import time
from itertools import izip

class Beam(object):

@@ -82,6 +83,7 @@ def __iter__(self):
def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
               start_token="<START>", end_token="<EOS>", use_log=True,
               renormalize=True, length_score=True, diversity_score=True,
               stochastic=False, temperature=1.0, random_state=None, eps=1E-9):
    """

@@ -161,8 +163,31 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
        else:
            min_prob = 1.

        if diversity_score:
            # get prefixes
            pre = [r[-1][len(start_token):] for r in prev_beam]
            base = set(pre[0])
            diversity_scores = []
            # score for first entry
            if use_log:
                diversity_scores.append(0.)
            else:
                diversity_scores.append(1.)
            if len(pre) > 1:
                for pre_i in pre[1:]:
                    s = set(pre_i)
                    union = base | s
                    # number of new things + (- number of repetitions)
                    sc = (len(union) - len(base)) - (len(pre_i) - len(s))
                    # update it
                    base = union
                    if use_log:
                        diversity_scores.append(sc)
                    else:
                        diversity_scores.append(np.exp(sc))

        # Add complete sentences that do not yet have the best probability to the current beam,
        # the rest prepare to add more words to them.
        for ni, (prefix_score, complete, prefix_prob, prefix) in enumerate(prev_beam):
            if complete == True:
                curr_beam.add(prefix_score, True, prefix_prob, prefix)
            else:

@@ -180,12 +205,16 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                        score = prefix_prob + np.log(n) + np.log(len(prefix)) - min_prob
                    else:
                        score = prefix_prob + np.log(n) - min_prob
                    if diversity_score:
                        score = score + diversity_scores[ni]
                    prob = prefix_prob + np.log(n)
                else:
                    if length_score:
                        score = (prefix_prob * n * len(prefix)) / min_prob
                    else:
                        score = (prefix_prob * n) / min_prob
                    if diversity_score:
                        score = score * diversity_scores[ni]
                    prob = prefix_prob * n

                if end_token_is_seq:

@@ -362,6 +391,7 @@ def pf(prefix):
                   end_token=end_token, clip_len=n_letters,
                   stochastic=stochastic, diversity_score=True,
                   random_state=random_state)
    # it is a generator but do this so that function prototypes are consistent
    all_r = []
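The diversity term added in this revision penalizes beams that repeat characters already covered by higher-ranked prefixes and rewards beams that contribute new ones. A minimal standalone sketch of that set-based scoring (the function name and packaging are mine, not from the gist):

```python
import numpy as np

def prefix_diversity_scores(prefixes, use_log=True):
    """Score each prefix by how many new symbols it contributes relative to
    the prefixes ranked above it, minus how many symbols it repeats
    internally. Mirrors the set-union bookkeeping in this revision; the
    function name and packaging are illustrative, not the gist's code."""
    base = set(prefixes[0])
    scores = [0. if use_log else 1.]  # the first (best) prefix gets a neutral score
    for p in prefixes[1:]:
        s = set(p)
        union = base | s
        # new symbols contributed minus repetitions inside this prefix
        sc = (len(union) - len(base)) - (len(p) - len(s))
        base = union
        scores.append(sc if use_log else np.exp(sc))
    return scores

# The repetitive second prefix is penalized, the novel third one rewarded.
print(prefix_diversity_scores(["abc", "ababab", "xyz"]))
```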
kastnerkyle revised this gist
Jun 2, 2017. 1 changed file with 2 additions and 2 deletions.

@@ -56,14 +56,14 @@ def add(self, score, complete, prob, prefix):
        while len(self.heap) > self.beam_width:
            if self.stochastic:
                # same whether logspace or no?
                probs = np.array([h[0] for h in self.heap])
                probs = probs / self.temperature
                e_x = np.exp(probs - np.max(probs))
                s_x = e_x / e_x.sum()
                is_x = 1. - s_x
                is_x = is_x / is_x.sum()
                to_remove = self.random_state.multinomial(1, is_x).argmax()
                completed = [n for n, h in enumerate(self.heap) if h[1] == True]
                # Don't remove completed sentences by randomness
                if to_remove not in completed:
                    # there must be a faster way...
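The stochastic pruning used here drops one beam entry drawn from an inverted softmax over the scores, so low-scoring prefixes are likely to be removed but never guaranteed to be. A self-contained sketch of that removal rule (the helper name `sample_index_to_prune` is mine):

```python
import numpy as np

def sample_index_to_prune(scores, temperature=1.0, random_state=None):
    """Pick one index to drop: softmax the scores, invert the probabilities,
    renormalize, then draw from the resulting multinomial. Higher-scoring
    entries are less likely to be chosen. Sketch only; it mirrors the
    pruning step in Beam.add but is not a verbatim copy."""
    if random_state is None:
        random_state = np.random.RandomState(0)
    scores = np.asarray(scores, dtype=float) / temperature
    e_x = np.exp(scores - scores.max())   # numerically stable softmax
    s_x = e_x / e_x.sum()
    inv = 1. - s_x                        # invert: good entries get small mass
    inv = inv / inv.sum()
    return random_state.multinomial(1, inv).argmax()

rng = np.random.RandomState(1999)
print(sample_index_to_prune([-1.2, -5.0, -0.3], random_state=rng))
```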
kastnerkyle revised this gist
May 20, 2017. 1 changed file with 1 addition and 1 deletion.

@@ -57,7 +57,7 @@ def add(self, score, complete, prob, prefix):
            if self.stochastic:
                # same whether logspace or no?
                probs = np.array([h[2] for h in self.heap])
                probs = probs / self.temperature
                e_x = np.exp(probs - np.max(probs))
                s_x = e_x / e_x.sum()
                is_x = 1. - s_x
kastnerkyle revised this gist
May 20, 2017. 1 changed file with 1 addition and 1 deletion.

@@ -390,7 +390,7 @@ def pf(prefix):
    default_beamwidth = 10
    default_decoder = 0
    default_randomseed = 1999
    default_maxlength = 500
    default_cache = 1
    default_print = 1
    default_verbose = 1
kastnerkyle revised this gist
May 20, 2017. 1 changed file with 14 additions and 4 deletions.

@@ -26,7 +26,7 @@
import os
import sys
import argparse
import cPickle as pickle
import time

@@ -476,12 +476,17 @@ def pf(prefix):
        raise ValueError("Decoder must be 0, 1, 2, or 3! Was set to {}".format(decoder))

    # only things that affect the language model are training data, temperature, order
    cached_name = "model_{}_t{}_o{}.pkl".format("".join(fpath.split(".")[:-1]), str(temperature).replace(".", "pt"), order)
    if cache == 1 and os.path.exists(cached_name):
        print("Found cached model at {}, loading...".format(cached_name))
        start_time = time.time()
        with open(cached_name, "rb") as f:
            lm = pickle.load(f)
        """
        # codec troubles :(
        with open(cached_name, "r") as f:
            lm = json.load(f, encoding="latin1")
        """
        stop_time = time.time()
        print("Time to load: {} s".format(stop_time - start_time))
    else:

@@ -492,8 +497,13 @@ def pf(prefix):
        print("Time to train: {} s".format(stop_time - start_time))
        if cache == 1:
            print("Caching model now...")
            with open(cached_name, "wb") as f:
                pickle.dump(lm, f)
            """
            # codec troubles :(
            with open(cached_name, "w") as f:
                json.dump(lm, f, encoding="latin1")
            """
            print("Caching complete!")

    # All this logic to handle/match different start keys
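This revision switches model caching back from JSON to pickle because of codec trouble with non-ASCII keys. A small sketch of the same cache-by-filename pattern (the `train` callable and the cache path below are placeholders of mine):

```python
import os
import pickle

def load_or_train(train, cache_path, use_cache=True):
    """Load a pickled model if the cache file exists, otherwise train and
    save it. `train` is any zero-argument callable; the gist folds the corpus
    name, temperature, and Markov order into the cache filename."""
    if use_cache and os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    model = train()
    if use_cache:
        with open(cache_path, "wb") as f:
            pickle.dump(model, f)
    return model

# Hypothetical usage following the gist's naming convention:
# lm = load_or_train(lambda: train_char_lm("shakespeare_input.txt", order=6),
#                    "model_shakespeare_input_t1pt0_o6.pkl")
```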
kastnerkyle revised this gist
May 20, 2017. 1 changed file with 36 additions and 20 deletions.

@@ -26,7 +26,7 @@
import os
import sys
import argparse
import json
import time

@@ -222,7 +222,7 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                else:
                    stop = -1
                yield (best_prefix[skip:stop], best_score, best_prob)
                sorted_beam.pop()
                completed_beams += 1
                any_removals = True

@@ -306,8 +306,8 @@ def generate_letter(lm, history, order, stochastic, random_state):
def step_text(lm, order, stochastic, random_seed, history=None, end=None,
              beam_width=1, n_letters=1000, verbose=False):
    # beam_width argument is ignored, as is end, and verbose
    if history is None or history == "<START>":
        history = "~" * order
    else:

@@ -324,7 +324,7 @@ def step_text(lm, order, stochastic, random_seed, history=None, end=None,
def beam_text(lm, order, stochastic, random_seed, history=None, end=None,
              beam_width=10, n_letters=1000, verbose=False):
    def pf(prefix):
        history = prefix[-order:]
        # lm wants key as a single string

@@ -364,14 +364,21 @@ def pf(prefix):
                   stochastic=stochastic,
                   random_state=random_state)
    # it is a generator but do this so that function prototypes are consistent
    all_r = []
    for r in b:
        all_r.append((r[0], r[1], r[2]))
    # reorder so final scoring is matched (rather than completion order)
    all_r = sorted(all_r, key=lambda x: x[1])
    returns = []
    for r in all_r:
        s_r = "".join(r[0])
        if verbose:
            s_r = s_r + "\nScore: {}".format(r[1]) + "\nProbability: {}".format(r[2])
        returns.append(s_r)
    # return list of all beams, ordered by score
    return returns

if __name__ == "__main__":

@@ -383,10 +390,10 @@ def pf(prefix):
    default_beamwidth = 10
    default_decoder = 0
    default_randomseed = 1999
    default_maxlength = 400
    default_cache = 1
    default_print = 1
    default_verbose = 1
    # TODO: Faster cache
    parser = argparse.ArgumentParser(description="A Markov chain character level language model with beamsearch decoding",

@@ -403,6 +410,7 @@ def pf(prefix):
    parser.add_argument("-m", "--maxlength", help="Max generation length.\nDefault {}".format(default_maxlength), default=default_maxlength)
    parser.add_argument("-c", "--cache", help="Whether to cache models for faster use.\nDefault {}".format(default_cache), default=default_cache)
    parser.add_argument("-a", "--allbeams", help="Print all beams for beamsearch, 0 for top only, 1 for all.\nDefault {}".format(default_print), default=default_print)
    parser.add_argument("-v", "--verbose", help="Print the score and probability for beams.\nDefault {}".format(default_verbose), default=default_verbose)
    args = parser.parse_args()

@@ -422,6 +430,7 @@ def pf(prefix):
    random_seed = int(args.randomseed)
    maxlength = int(args.maxlength)
    allbeams = int(args.allbeams)
    verbose = int(args.verbose)
    order = int(args.order)
    if order < 1:

@@ -467,12 +476,12 @@ def pf(prefix):
        raise ValueError("Decoder must be 0, 1, 2, or 3! Was set to {}".format(decoder))

    # only things that affect the language model are training data, temperature, order
    cached_name = "model_{}_t{}_o{}.json".format("".join(fpath.split(".")[:-1]), str(temperature).replace(".", "pt"), order)
    if cache == 1 and os.path.exists(cached_name):
        print("Found cached model at {}, loading...".format(cached_name))
        start_time = time.time()
        with open(cached_name, "r") as f:
            lm = json.load(f)
        stop_time = time.time()
        print("Time to load: {} s".format(stop_time - start_time))
    else:

@@ -483,8 +492,8 @@ def pf(prefix):
        print("Time to train: {} s".format(stop_time - start_time))
        if cache == 1:
            print("Caching model now...")
            with open(cached_name, "w") as f:
                json.dump(lm, f)
            print("Caching complete!")

    # All this logic to handle/match different start keys

@@ -530,10 +539,17 @@ def pf(prefix):
    else:
        raise ValueError("Unknown setting for allbeams {}".format(allbeams))

    if verbose == 0:
        verbose = False
    elif verbose == 1:
        verbose = True
    else:
        raise ValueError("Unknown setting for verbose {}".format(verbose))

    start_time = time.time()
    all_o = decode_fun(lm, order, stochastic, random_seed,
                       history=start_token, end=end_token,
                       beam_width=beam_width, n_letters=maxlength, verbose=verbose)
    stop_time = time.time()
    print(type_tag)

@@ -546,11 +562,11 @@ def pf(prefix):
    for n, oi in enumerate(all_o):
        if len(all_o) > 1:
            if n == 0:
                print("BEAM {} (worst score)".format(n + 1))
            elif n != (len(all_o) - 1):
                print("BEAM {}".format(n + 1))
            else:
                print("BEAM {} (best score)".format(n + 1))
            print("----------")
        if user_start_token:
            print("".join(start_token) + oi)
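After this change `beam_text` drains the beamsearch generator, re-sorts the finished hypotheses by score, and optionally appends score and probability when verbose. A compact sketch of that collection step (the `fake_beams` generator below is a stand-in of mine for the real search):

```python
def collect_beams(beam_generator, verbose=False):
    """Drain a generator of (tokens, score, prob) tuples, then return the
    joined strings ordered by score ascending, so the best beam prints last,
    matching the gist's output order. Packaging is illustrative."""
    finished = [(toks, score, prob) for toks, score, prob in beam_generator]
    finished.sort(key=lambda x: x[1])
    out = []
    for toks, score, prob in finished:
        s = "".join(toks)
        if verbose:
            s += "\nScore: {}\nProbability: {}".format(score, prob)
        out.append(s)
    return out

def fake_beams():
    yield (list("abc"), -2.0, -2.5)
    yield (list("abd"), -1.0, -1.5)

print(collect_beams(fake_beams(), verbose=True))
```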
kastnerkyle revised this gist
May 19, 2017. 1 changed file with 1 addition and 2 deletions.

@@ -237,8 +237,7 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                    break
            else:
                prev_beam = Beam(beam_width - completed_beams, sorted_beam, use_log,
                                 stochastic, temperature, random_state)
        else:
            prev_beam = curr_beam
kastnerkyle revised this gist
May 19, 2017. 1 changed file with 2 additions and 6 deletions.

@@ -307,7 +307,7 @@ def generate_letter(lm, history, order, stochastic, random_state):
def step_text(lm, order, stochastic, random_seed, history=None, end=None,
              beam_width=1, n_letters=1000):
    # beam_width argument is ignored, as is end, and return_count
    if history is None or history == "<START>":
        history = "~" * order

@@ -325,7 +325,7 @@ def step_text(lm, order, stochastic, random_seed, history=None, end=None,
def beam_text(lm, order, stochastic, random_seed, history=None, end=None,
              beam_width=10, n_letters=1000):
    def pf(prefix):
        history = prefix[-order:]
        # lm wants key as a single string

@@ -366,14 +366,10 @@ def pf(prefix):
                   random_state=random_state)
    # it is a generator but do this so that function prototypes are consistent
    # top beam search output
    all_r = []
    for r in b:
        s_r = "".join(r[0])
        all_r.append(s_r)
    # return list of all beams, ordered by probability
    return all_r
kastnerkyle revised this gist
May 19, 2017. 1 changed file with 7 additions and 9 deletions.

@@ -38,8 +38,7 @@ class Beam(object):
    over an incomplete one since (0.5, False, whatever_prefix) < (0.5, True, some_other_prefix)
    """
    def __init__(self, beam_width, init_beam=None, use_log=True,
                 stochastic=False, temperature=1.0, random_state=None):
        if init_beam is None:
            self.heap = list()
        else:

@@ -51,11 +50,10 @@ def __init__(self, beam_width, init_beam=None, use_log=True,
        # use_log currently unused...
        self.use_log = use_log
        self.beam_width = beam_width

    def add(self, score, complete, prob, prefix):
        heapq.heappush(self.heap, (score, complete, prob, prefix))
        while len(self.heap) > self.beam_width:
            if self.stochastic:
                # same whether logspace or no?
                probs = np.array([h[2] for h in self.heap])

@@ -123,8 +121,8 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
        raise ValueError("Must pass np.random.RandomState() object if stochastic=True")

    completed_beams = 0
    prev_beam = Beam(beam_width - completed_beams, None, use_log,
                     stochastic, temperature, random_state)
    try:
        basestring
    except NameError:

@@ -151,8 +149,8 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
    while True:
        curr_beam = Beam(beam_width - completed_beams, None, use_log,
                         stochastic, temperature, random_state)
        if renormalize:
            sorted_prev_beam = sorted(prev_beam)
            # renormalize by the previous minimum value in the beam

@@ -238,7 +236,7 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                if len(sorted_beam) == 0:
                    break
            else:
                prev_beam = Beam(beam_width - completed_beams, sorted_beam, use_log,
                                 stochastic, temperature, random_state, completed_beams)
        else:
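Throughout these revisions the beam itself is a bounded heap of (score, complete, prob, prefix) tuples, so the worst entry is always cheap to drop. A minimal non-stochastic version of that container (a sketch of mine, not the gist's exact class):

```python
import heapq

class TopKBeam(object):
    """Keep at most `width` hypotheses, ordered by score. Tuples compare
    lexicographically, so a complete hypothesis (True) outranks an
    incomplete one (False) at equal score. Illustrative only."""
    def __init__(self, width):
        self.width = width
        self.heap = []

    def add(self, score, complete, prob, prefix):
        heapq.heappush(self.heap, (score, complete, prob, prefix))
        while len(self.heap) > self.width:
            heapq.heappop(self.heap)  # drop the lowest-scoring entry

    def __iter__(self):
        return iter(self.heap)

b = TopKBeam(2)
b.add(-1.0, False, -1.0, ["a"])
b.add(-0.5, True, -0.5, ["b"])
b.add(-2.0, False, -2.0, ["c"])
print(sorted(b))  # the ["c"] hypothesis has been pruned
```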
kastnerkyle revised this gist
May 19, 2017. 1 changed file with 21 additions and 12 deletions.

@@ -38,7 +38,8 @@ class Beam(object):
    over an incomplete one since (0.5, False, whatever_prefix) < (0.5, True, some_other_prefix)
    """
    def __init__(self, beam_width, init_beam=None, use_log=True,
                 stochastic=False, temperature=1.0, random_state=None, completed_beams=0):
        if init_beam is None:
            self.heap = list()
        else:

@@ -50,10 +51,11 @@ def __init__(self, beam_width, init_beam=None, use_log=True,
        # use_log currently unused...
        self.use_log = use_log
        self.beam_width = beam_width
        self.completed_beams = completed_beams

    def add(self, score, complete, prob, prefix):
        heapq.heappush(self.heap, (score, complete, prob, prefix))
        while len(self.heap) > (self.beam_width - self.completed_beams):
            if self.stochastic:
                # same whether logspace or no?
                probs = np.array([h[2] for h in self.heap])

@@ -120,8 +122,9 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
        if random_state is None:
            raise ValueError("Must pass np.random.RandomState() object if stochastic=True")

    completed_beams = 0
    prev_beam = Beam(beam_width, None, use_log, stochastic, temperature,
                     random_state, completed_beams)
    try:
        basestring
    except NameError:

@@ -146,9 +149,10 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
    else:
        prev_beam.add(1.0, False, 1.0, start_token)

    while True:
        curr_beam = Beam(beam_width, None, use_log, stochastic, temperature,
                         random_state, completed_beams)
        if renormalize:
            sorted_prev_beam = sorted(prev_beam)
            # renormalize by the previous minimum value in the beam

@@ -222,6 +226,7 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                yield (best_prefix[skip:stop], best_prob)
                sorted_beam.pop()
                completed_beams += 1
                any_removals = True

                # If there are no more sentences in the beam then stop checking
                if len(sorted_beam) == 0:

@@ -234,7 +239,8 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                    break
            else:
                prev_beam = Beam(beam_width, sorted_beam, use_log, stochastic,
                                 temperature, random_state, completed_beams)
        else:
            prev_beam = curr_beam

@@ -403,7 +409,7 @@ def pf(prefix):
    parser.add_argument("-e", "--endtoken", help="Random seed to initialize randomness. Can be a string such as 'goodbye\\n'.\nDefault {}".format(default_end), default=default_end)
    parser.add_argument("-m", "--maxlength", help="Max generation length.\nDefault {}".format(default_maxlength), default=default_maxlength)
    parser.add_argument("-c", "--cache", help="Whether to cache models for faster use.\nDefault {}".format(default_cache), default=default_cache)
    parser.add_argument("-a", "--allbeams", help="Print all beams for beamsearch, 0 for top only, 1 for all.\nDefault {}".format(default_print), default=default_print)
    args = parser.parse_args()

@@ -524,23 +530,26 @@ def pf(prefix):
    if user_end_token:
        end_token = list(end_token)

    if allbeams == 0:
        return_count = 1
    elif allbeams == 1:
        pass
    else:
        raise ValueError("Unknown setting for allbeams {}".format(allbeams))

    start_time = time.time()
    all_o = decode_fun(lm, order, stochastic, random_seed,
                       history=start_token, end=end_token,
                       beam_width=beam_width, n_letters=maxlength)
    stop_time = time.time()
    print(type_tag)
    print("Time to decode: {} s".format(stop_time - start_time))
    print("----------")
    if allbeams == 0:
        all_o = [all_o[0]]
    for n, oi in enumerate(all_o):
        if len(all_o) > 1:
            if n == 0:
kastnerkyle revised this gist
May 19, 2017. No changes.
kastnerkyle revised this gist
May 19, 2017. 1 changed file with 6 additions and 1 deletion.

@@ -543,7 +543,12 @@ def pf(prefix):
    print("----------")
    for n, oi in enumerate(all_o):
        if len(all_o) > 1:
            if n == 0:
                print("BEAM {} (best score)".format(n + 1))
            elif n != (len(all_o) - 1):
                print("BEAM {}".format(n + 1))
            else:
                print("BEAM {} (worst score)".format(n + 1))
            print("----------")
        if user_start_token:
            print("".join(start_token) + oi)
kastnerkyle revised this gist
May 19, 2017. 1 changed file with 25 additions and 16 deletions.

@@ -53,7 +53,7 @@ def __init__(self, beam_width, init_beam=None, use_log=True,
    def add(self, score, complete, prob, prefix):
        heapq.heappush(self.heap, (score, complete, prob, prefix))
        while len(self.heap) > self.beam_width:
            if self.stochastic:
                # same whether logspace or no?
                probs = np.array([h[2] for h in self.heap])

@@ -303,8 +303,8 @@ def generate_letter(lm, history, order, stochastic, random_state):
def step_text(lm, order, stochastic, random_seed, history=None, end=None,
              beam_width=1, n_letters=1000, return_count=None):
    # beam_width argument is ignored, as is end, and return_count
    if history is None or history == "<START>":
        history = "~" * order
    else:

@@ -321,7 +321,7 @@ def step_text(lm, order, stochastic, random_seed, history=None, end=None,
def beam_text(lm, order, stochastic, random_seed, history=None, end=None,
              beam_width=10, n_letters=1000, return_count=None):
    def pf(prefix):
        history = prefix[-order:]
        # lm wants key as a single string

@@ -362,13 +362,16 @@ def pf(prefix):
                   random_state=random_state)
    # it is a generator but do this so that function prototypes are consistent
    # top beam search output
    if return_count == None:
        return_count = -1
    all_r = []
    for r in b:
        s_r = "".join(r[0])
        all_r.append(s_r)
        if return_count > 0 and len(all_r) >= return_count:
            break
    # return list of all beams, ordered by probability
    return all_r

@@ -384,7 +387,7 @@ def pf(prefix):
    default_maxlength = 1000
    default_cycles = 3
    default_cache = 1
    default_print = 1
    # TODO: Faster cache
    parser = argparse.ArgumentParser(description="A Markov chain character level language model with beamsearch decoding",

@@ -400,7 +403,7 @@ def pf(prefix):
    parser.add_argument("-e", "--endtoken", help="Random seed to initialize randomness. Can be a string such as 'goodbye\\n'.\nDefault {}".format(default_end), default=default_end)
    parser.add_argument("-m", "--maxlength", help="Max generation length.\nDefault {}".format(default_maxlength), default=default_maxlength)
    parser.add_argument("-c", "--cache", help="Whether to cache models for faster use.\nDefault {}".format(default_cache), default=default_cache)
    parser.add_argument("-a", "--allbeams", help="Print all beams for beamsearch, 0 for top only, 1 for top beam_width, 2 for all.\nDefault {}".format(default_print), default=default_print)
    args = parser.parse_args()

@@ -523,17 +526,23 @@ def pf(prefix):
        end_token = list(end_token)

    start_time = time.time()
    if allbeams == 0:
        return_count = 1
    elif allbeams == 1:
        return_count = beam_width
    elif allbeams == 2:
        return_count = -1
    else:
        raise ValueError("Unknown setting for allbeams {}".format(allbeams))
    all_o = decode_fun(lm, order, stochastic, random_seed,
                       history=start_token, end=end_token,
                       beam_width=beam_width, n_letters=maxlength,
                       return_count=return_count)
    stop_time = time.time()
    print(type_tag)
    print("Time to decode: {} s".format(stop_time - start_time))
    print("----------")
    for n, oi in enumerate(all_o):
        if len(all_o) > 1:
            print("BEAM {}".format(n + 1))
            print("----------")
        if user_start_token:
kastnerkyle revised this gist
May 19, 2017. 1 changed file with 47 additions and 15 deletions.

@@ -64,12 +64,13 @@ def add(self, score, complete, prob, prefix):
                is_x = is_x / is_x.sum()
                to_remove = self.random_state.multinomial(1, is_x).argmax()
                completed = [n for n, h in enumerate(self.heap) if h[2] == True]
                # Don't remove completed sentences by randomness
                if to_remove not in completed:
                    # there must be a faster way...
                    self.heap.pop(to_remove)
                    heapq.heapify(self.heap)
                else:
                    heapq.heappop(self.heap)
            else:
                # remove lowest score from heap
                heapq.heappop(self.heap)

@@ -174,13 +175,13 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                    # score is renormalized prob
                    if use_log:
                        if length_score:
                            score = prefix_prob + np.log(n) + np.log(len(prefix)) - min_prob
                        else:
                            score = prefix_prob + np.log(n) - min_prob
                        prob = prefix_prob + np.log(n)
                    else:
                        if length_score:
                            score = (prefix_prob * n * len(prefix)) / min_prob
                        else:
                            score = (prefix_prob * n) / min_prob
                        prob = prefix_prob * n

@@ -205,7 +206,6 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
        while True:
            # Get highest probability prefix - heapq is sorted in ascending order
            (best_score, best_complete, best_prob, best_prefix) = sorted_beam[-1]
            if best_complete == True or len(best_prefix) - 1 == clip_len:
                # If most probable prefix is a complete sentence or has a length that
                # exceeds the clip length (ignoring the start token) then return it

@@ -308,23 +308,38 @@ def step_text(lm, order, stochastic, random_seed, history=None, end=None,
    if history is None or history == "<START>":
        history = "~" * order
    else:
        history = "".join(history).decode("string_escape")
    out = []
    random_state = np.random.RandomState(random_seed)
    for i in range(n_letters):
        c = generate_letter(lm, history, order, stochastic, random_state)
        history = history[-order:] + c
        out.append(c)
    # return list to match beam_text
    return ["".join(out)]

def beam_text(lm, order, stochastic, random_seed, history=None, end=None,
              beam_width=10, n_letters=1000):
    def pf(prefix):
        history = prefix[-order:]
        # lm wants key as a single string
        k = "".join(history).decode("string_escape")
        # sometimes the distribution "dead-ends"...
        try:
            dist = lm[k]
        except KeyError:
            alt_keys = [i for i in lm.keys() if "".join(prefix[-order:-1]) in i and "".join(prefix[-order-1:-1]) != i]
            # if no alternates, start from a random place
            if len(alt_keys) == 0:
                # choose a key at semi-random
                ak = lm.keys()
                dist = lm[ak[random_seed % len(ak)]]
            else:
                dist = lm[alt_keys[0]]
        return dist

    if history is None or history == "<START>":

@@ -347,8 +362,14 @@ def pf(prefix):
                   random_state=random_state)
    # it is a generator but do this so that function prototypes are consistent
    # top beam search output
    all_r = []
    #for _ in range(beam_width):
    #    r = next(b)
    for r in b:
        s_r = "".join(r[0])
        all_r.append(s_r)
    # return list of all beams
    return all_r

if __name__ == "__main__":

@@ -363,6 +384,7 @@ def pf(prefix):
    default_maxlength = 1000
    default_cycles = 3
    default_cache = 1
    default_print = 0
    # TODO: Faster cache
    parser = argparse.ArgumentParser(description="A Markov chain character level language model with beamsearch decoding",

@@ -378,6 +400,7 @@ def pf(prefix):
    parser.add_argument("-e", "--endtoken", help="Random seed to initialize randomness. Can be a string such as 'goodbye\\n'.\nDefault {}".format(default_end), default=default_end)
    parser.add_argument("-m", "--maxlength", help="Max generation length.\nDefault {}".format(default_maxlength), default=default_maxlength)
    parser.add_argument("-c", "--cache", help="Whether to cache models for faster use.\nDefault {}".format(default_cache), default=default_cache)
    parser.add_argument("-a", "--allbeams", help="Print all beams for beamsearch, 0 for top only, 1 for all.\nDefault {}".format(default_print), default=default_print)
    args = parser.parse_args()

@@ -396,6 +419,7 @@ def pf(prefix):
    temperature = float(args.temperature)
    random_seed = int(args.randomseed)
    maxlength = int(args.maxlength)
    allbeams = int(args.allbeams)
    order = int(args.order)
    if order < 1:

@@ -499,13 +523,21 @@ def pf(prefix):
        end_token = list(end_token)

    start_time = time.time()
    all_o = decode_fun(lm, order, stochastic, random_seed,
                       history=start_token, end=end_token,
                       beam_width=beam_width, n_letters=maxlength)
    stop_time = time.time()
    print(type_tag)
    print("Time to decode: {} s".format(stop_time - start_time))
    print("----------")
    if allbeams == 0:
        o = [all_o[0]]
    else:
        o = all_o
    for n, oi in enumerate(o):
        if len(o) > 1:
            print("BEAM {}".format(n + 1))
            print("----------")
        if user_start_token:
            print("".join(start_token) + oi)
        else:
            print(oi)
    print("----------")
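The probabilities function added here has to cope with prefixes whose last `order` characters never occurred in training (the distribution "dead-ends"). A simplified sketch of that lookup-with-fallback idea (the dictionary and helper names are mine, and the fallback rule is deliberately cruder than the gist's):

```python
def next_char_distribution(lm, prefix, order, fallback_key=None):
    """Return the (prob, char) list for the last `order` characters of
    `prefix`; if that key is unseen, fall back to any key sharing the same
    shorter suffix, then to a fixed fallback key. Sketch only; the gist's
    version also handles escape sequences and random fallback keys."""
    key = "".join(prefix[-order:])
    if key in lm:
        return lm[key]
    suffix = key[1:]
    for k in lm:
        if k.endswith(suffix):
            return lm[k]
    return lm[fallback_key if fallback_key is not None else next(iter(lm))]

toy_lm = {"ab": [(1.0, "c")], "bb": [(1.0, "a")]}
print(next_char_distribution(toy_lm, list("xab"), order=2))  # exact hit
print(next_char_distribution(toy_lm, list("xzb"), order=2))  # suffix fallback
```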
kastnerkyle revised this gist
May 19, 2017. 1 changed file with 1 addition and 1 deletion.

@@ -324,7 +324,7 @@ def beam_text(lm, order, stochastic, random_seed, history=None,
    def pf(prefix):
        history = prefix[-order:]
        # lm wants key as a single string
        dist = lm["".join(history).replace("\\n", "\n")]
        return dist

    if history is None or history == "<START>":
kastnerkyle revised this gist
May 19, 2017. 1 changed file with 3 additions and 1 deletion.

@@ -307,6 +307,8 @@ def step_text(lm, order, stochastic, random_seed, history=None, end=None,
    # beam_width argument is ignored, as is end
    if history is None or history == "<START>":
        history = "~" * order
    else:
        history = "".join(history)
    out = []
    random_state = np.random.RandomState(random_seed)

@@ -362,6 +364,7 @@ def pf(prefix):
    default_cycles = 3
    default_cache = 1
    # TODO: Faster cache
    parser = argparse.ArgumentParser(description="A Markov chain character level language model with beamsearch decoding",
                                     epilog="Simple usage:\n python minimal_beamsearch.py shakespeare_input.txt -o 10\nFull usage:\n python minimal_beamsearch.py shakespeare_input.txt -o 10 -d 0 -s 'HOLOFERNES' -e 'crew?\\n' -r 2177",
                                     formatter_class=argparse.RawTextHelpFormatter)

@@ -375,7 +378,6 @@ def pf(prefix):
    parser.add_argument("-e", "--endtoken", help="Random seed to initialize randomness. Can be a string such as 'goodbye\\n'.\nDefault {}".format(default_end), default=default_end)
    parser.add_argument("-m", "--maxlength", help="Max generation length.\nDefault {}".format(default_maxlength), default=default_maxlength)
    parser.add_argument("-c", "--cache", help="Whether to cache models for faster use.\nDefault {}".format(default_cache), default=default_cache)
    args = parser.parse_args()
kastnerkyle revised this gist
May 19, 2017. 1 changed file with 52 additions and 17 deletions.

@@ -363,16 +363,16 @@ def pf(prefix):
    default_cache = 1

    parser = argparse.ArgumentParser(description="A Markov chain character level language model with beamsearch decoding",
                                     epilog="Simple usage:\n python minimal_beamsearch.py shakespeare_input.txt -o 10\nFull usage:\n python minimal_beamsearch.py shakespeare_input.txt -o 10 -d 0 -s 'HOLOFERNES' -e 'crew?\\n' -r 2177",
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("filepath", help="Path to file to use for language modeling. For an example file, try downloading\nhttp://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", default=None)
    parser.add_argument("-o", "--order", help="Markov chain order, higher will make better text but takes longer to process.\nDefault {}".format(default_order), default=default_order)
    parser.add_argument("-t", "--temperature", help="Temperature for Markov chain softmax, higher is more random, lower more static.\nDefault {}".format(default_temperature), default=default_temperature)
    parser.add_argument("-d","--decoder", help="Decoder for Markov chain, 0 is stochastic beamsearch, 1 is argmax beamsearch, 2 is sampled next-step, 3 is argmax next-step.\nDefault {}".format(default_decoder), default=default_decoder)
    parser.add_argument("-b", "--beamwidth", help="Beamwidth to use for beamsearch.\nDefault {}".format(default_beamwidth), default=default_beamwidth)
    parser.add_argument("-r", "--randomseed", help="Random seed to initialize randomness.\nDefault {}".format(default_randomseed), default=default_randomseed)
    parser.add_argument("-s", "--starttoken", help="Start sequence token. Can be a string such as 'hello\\n', extra padding will be inferred from the data.\nDefault {}".format(default_start), default=default_start)
    parser.add_argument("-e", "--endtoken", help="Random seed to initialize randomness. Can be a string such as 'goodbye\\n'.\nDefault {}".format(default_end), default=default_end)
    parser.add_argument("-m", "--maxlength", help="Max generation length.\nDefault {}".format(default_maxlength), default=default_maxlength)
    parser.add_argument("-c", "--cache", help="Whether to cache models for faster use.\nDefault {}".format(default_cache), default=default_cache)
    #parser.add_argument("-c", "--cycles", help="Number of cycles, using the last piece of the previous beam to start a new one. Can be useful for long beamsearches. Default {}".format(default_cycles))

@@ -389,20 +389,6 @@ def pf(prefix):
    decoder_settings = [0, 1, 2, 3]
    decoder = int(args.decoder)

    # TODO: gumbel-max in stochastic beam decoder...?
    beam_width = int(args.beamwidth)
    temperature = float(args.temperature)

@@ -417,6 +403,18 @@ def pf(prefix):
    if cache not in [0, 1]:
        raise ValueError("Cache must be either 0 (no save) or 1 (save)! Was set to {}".format(cache))

    start_token = args.starttoken.decode("string_escape")
    if start_token != default_start:
        user_start_token = True
    else:
        user_start_token = False

    end_token = args.endtoken.decode("string_escape")
    if end_token != default_end:
        user_end_token = True
    else:
        user_end_token = False

    if decoder == 0:
        # stochastic beam
        stochastic = True

@@ -461,6 +459,43 @@ def pf(prefix):
            pickle.dump(lm, f, pickle.HIGHEST_PROTOCOL)
            print("Caching complete!")

    # All this logic to handle/match different start keys
    rs = np.random.RandomState(random_seed)
    if user_start_token:
        if len(start_token) > order:
            start_token = start_token[-order:]
            print("WARNING: specified start token larger than order, truncating to\n{}".format(start_token))
        if len(start_token) <= order:
            matching_keys = [k for k in lm.keys() if k.endswith(start_token)]
            all_keys = [k for k in lm.keys()]
            while True:
                if len(matching_keys) == 0:
                    rs.shuffle(all_keys)
                    print("No matching key for `{}` in language model!".format(start_token))
                    print("Please enter another one (suggestions in backticks)\n`{}`\n`{}`\n`{}`)".format(all_keys[0], all_keys[1], all_keys[2]))
                    line = raw_input('Prompt ("Ctrl-C" to quit): ')
                    line = line.strip()
                    if len(line) == 0:
                        continue
                    else:
                        start_token = line
                        matching_keys = [k for k in lm.keys() if k.endswith(start_token)]
                else:
                    break
            if len(start_token) < order:
                # choose key at random
                matching_keys = [k for k in lm.keys() if k.endswith(start_token)]
                rs.shuffle(matching_keys)
                start_token = matching_keys[0]
                print("WARNING: start key shorter than order, set to\n`{}`".format(start_token))
        start_token = list(start_token)

    if user_end_token:
        end_token = list(end_token)

    start_time = time.time()
    o = decode_fun(lm, order, stochastic, random_seed,
                   history=start_token, end=end_token,
                   beam_width=beam_width, n_letters=maxlength)
    stop_time = time.time()
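The start-token handling added here pads or truncates the user's string so it lines up with an order-length key that actually exists in the model. A condensed, non-interactive sketch of that matching (the function name is mine and the gist's prompt loop is omitted):

```python
import numpy as np

def resolve_start_key(lm, start, order, seed=1999):
    """Truncate a too-long start string to its last `order` characters, or
    extend a too-short one by picking a random LM key that ends with it.
    Returns None when nothing matches; the gist instead prompts the user."""
    rs = np.random.RandomState(seed)
    if len(start) > order:
        start = start[-order:]
    if len(start) == order:
        return start if start in lm else None
    matching = [k for k in lm if k.endswith(start)]
    if not matching:
        return None
    rs.shuffle(matching)
    return matching[0]

toy_lm = {"he": None, "lo": None, "o\n": None}
print(resolve_start_key(toy_lm, "hello", order=2))  # truncated to "lo"
print(resolve_start_key(toy_lm, "e", order=2))       # padded via a matching key
```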
kastnerkyle revised this gist
May 18, 2017. 1 changed file with 6 additions and 2 deletions.

@@ -5,9 +5,13 @@
# markov_lm.py
# https://gist.github.com/yoavg/d76121dfde2618422139
# These datasets can be a lot of fun...
#
# https://github.com/frnsys/texts
# python minimal_beamsearch.py 2600_phrases_for_effective_performance_reviews.txt -o 5 -d 0
#
# Download kjv.txt from http://www.ccel.org/ccel/bible/kjv.txt
# python minimal_beamsearch.py kjv.txt -o 5 -d 2 -r 2145
# Snippet:
# Queen ording found Raguel: I kill.
# THROUGH JESUS OF OUR BRETHREN, AND PEACE,
kastnerkyle revised this gist
May 18, 2017. 1 changed file with 54 additions and 29 deletions.

@@ -59,9 +59,13 @@ def add(self, score, complete, prob, prefix):
                is_x = 1. - s_x
                is_x = is_x / is_x.sum()
                to_remove = self.random_state.multinomial(1, is_x).argmax()
                completed = [n for n, h in enumerate(self.heap) if h[2] == True]
                # Don't remove completed sentences, ever
                if to_remove not in completed:
                    # there must be a faster way...
                    self.heap.pop(to_remove)
                    heapq.heapify(self.heap)
            else:
                # remove lowest score from heap
                heapq.heappop(self.heap)

@@ -117,6 +121,7 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
        basestring
    except NameError:
        basestring = str

    if isinstance(start_token, collections.Sequence) and not isinstance(start_token, basestring):
        start_token_is_seq = True
    else:

@@ -184,8 +189,8 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                        right_cmp = end_token
                    if left_cmp == right_cmp:
                        # If next word is the end token then mark prefix as complete
                        curr_beam.add(score, True, prob, prefix + [next_word])
                    else:
                        curr_beam.add(score, False, prob, prefix + [next_word])

@@ -205,10 +210,12 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                    skip = len(start_token)
                else:
                    skip = 1
                if end_token_is_seq:
                    stop = None
                else:
                    stop = -1
                yield (best_prefix[skip:stop], best_prob)
                sorted_beam.pop()
                any_removals = True

@@ -291,9 +298,9 @@ def generate_letter(lm, history, order, stochastic, random_state):
    return c

def step_text(lm, order, stochastic, random_seed, history=None, end=None,
              beam_width=1, n_letters=1000):
    # beam_width argument is ignored, as is end
    if history is None or history == "<START>":
        history = "~" * order

@@ -307,7 +314,7 @@ def step_text(lm, order, stochastic, random_seed, history=None, stop_token=None,
def beam_text(lm, order, stochastic, random_seed, history=None, end=None,
              beam_width=10, n_letters=1000):
    def pf(prefix):
        history = prefix[-order:]
        # lm wants key as a single string

@@ -316,6 +323,10 @@ def pf(prefix):
    if history is None or history == "<START>":
        start_token = ["~"] * order
    else:
        start_token = history
        if len(start_token) != order:
            raise ValueError("Start length must match order setting of {}! {} is length {}".format(order, history, len(history)))

    if end is None:
        end_token = "<EOS>"

@@ -329,7 +340,7 @@ def pf(prefix):
                   stochastic=stochastic, random_state=random_state)
    # it is a generator but do this so that function prototypes are consistent
    # top beam search output
    r = next(b)
    return "".join(r[0])

@@ -347,18 +358,19 @@ def pf(prefix):
    default_cycles = 3
    default_cache = 1

    parser = argparse.ArgumentParser(description="A Markov chain character level language model with beamsearch decoding",
                                     epilog="Simple usage:\n python minimal_beamsearch.py shakespeare_input.txt -o 10\nFull usage:\n python minimal_beamsearch.py shakespeare_input.txt -o 10 -d 0 -s 'H,O,L,O,F,E,R,N,E,S' -e 'c,r,e,w,?,\\n' -r 2177",
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("filepath", help="Path to file to use for language modeling. For an example file, try downloading\nhttp://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", default=None)
    parser.add_argument("-o", "--order", help="Markov chain order, higher will make better text but takes longer to process.\nDefault {}".format(default_order), default=default_order)
    parser.add_argument("-t", "--temperature", help="Temperature for Markov chain softmax, higher is more random, lower more static.\nDefault {}".format(default_temperature), default=default_temperature)
    parser.add_argument("-d","--decoder", help="Decoder for Markov chain, 0 is stochastic beamsearch, 1 is argmax beamsearch, 2 is sampled next-step, 3 is argmax next-step.\nDefault {}".format(default_decoder), default=default_decoder)
    parser.add_argument("-b", "--beamwidth", help="Beamwidth to use for beamsearch.\nDefault {}".format(default_beamwidth), default=default_beamwidth)
    parser.add_argument("-r", "--randomseed", help="Random seed to initialize randomness.\nDefault {}".format(default_randomseed), default=default_randomseed)
    parser.add_argument("-s", "--starttoken", help="Start sequence token. Can be a comma separated list such as '\\n,\\n', same length as --order argument.\nDefault {}".format(default_start), default=default_start)
    parser.add_argument("-e", "--endtoken", help="Random seed to initialize randomness. Can be a comma separated list such as '\\n,\\n'.\nDefault {}".format(default_end), default=default_end)
    parser.add_argument("-m", "--maxlength", help="Max generation length.\nDefault {}".format(default_maxlength), default=default_maxlength)
    parser.add_argument("-c", "--cache", help="Whether to cache models for faster use.\nDefault {}".format(default_cache), default=default_cache)
    #parser.add_argument("-c", "--cycles", help="Number of cycles, using the last piece of the previous beam to start a new one. Can be useful for long beamsearches. Default {}".format(default_cycles))
    args = parser.parse_args()

@@ -373,11 +385,21 @@ def pf(prefix):
    decoder_settings = [0, 1, 2, 3]
    decoder = int(args.decoder)

    start_token = args.starttoken.decode("string_escape")
    if start_token != default_start:
        user_start_token = True
    else:
        user_start_token = False
    if "," in start_token:
        start_token = [str(si).decode("string_escape") for si in start_token.split(",")]

    end_token = args.endtoken.decode("string_escape")
    if "," in end_token:
        end_token = [str(ei).decode("string_escape") for ei in end_token.split(",")]

    # TODO: gumbel-max in stochastic beam decoder...?
    beam_width = int(args.beamwidth)
    temperature = float(args.temperature)
    random_seed = int(args.randomseed)

@@ -436,10 +458,13 @@ def pf(prefix):
            print("Caching complete!")

    start_time = time.time()
    o = decode_fun(lm, order, stochastic, random_seed,
                   history=start_token, end=end_token,
                   beam_width=beam_width, n_letters=maxlength)
    stop_time = time.time()
    print(type_tag)
    print("Time to decode: {} s".format(stop_time - start_time))
    print("----------")
    if user_start_token:
        print("".join(start_token) + o)
    else:
        print(o)
    print("----------")
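One of the changes in this revision is treating the end token as either a single character or a whole sequence, checked against the tail of the candidate prefix. A tiny sketch of that comparison (the names are mine):

```python
def hits_end_token(prefix, next_word, end_token):
    """True when appending `next_word` makes the prefix end with `end_token`,
    whether end_token is one token or a list of tokens. Illustrative version
    of the left_cmp/right_cmp check used in the gist."""
    if isinstance(end_token, (list, tuple)):
        candidate = list(prefix) + [next_word]
        return candidate[-len(end_token):] == list(end_token)
    return next_word == end_token

print(hits_end_token(list("crew"), "?", ["w", "?"]))  # True
print(hits_end_token(list("crew"), "?", "\n"))         # False
```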
kastnerkyle revised this gist
May 18, 2017. 1 changed file with 170 additions and 79 deletions.

@@ -5,13 +5,26 @@
# markov_lm.py
# https://gist.github.com/yoavg/d76121dfde2618422139
# Fun alternate settings
# Download kjv.txt from http://www.ccel.org/ccel/bible/kjv.txt
# python markov_lm.py kjv.txt 5 1.
# Snippet:
# Queen ording found Raguel: I kill.
# THROUGH JESUS OF OUR BRETHREN, AND PEACE,
#
# NUN.

import numpy as np
import heapq
from collections import defaultdict, Counter
import collections
import os
import sys
import argparse
import cPickle as pickle
import time

class Beam(object):
    """

@@ -38,8 +51,10 @@ def add(self, score, complete, prob, prefix):
        heapq.heappush(self.heap, (score, complete, prob, prefix))
        if len(self.heap) > self.beam_width:
            if self.stochastic:
                # same whether logspace or no?
                probs = np.array([h[2] for h in self.heap])
                probs = probs / temperature
                e_x = np.exp(probs - np.max(probs))
                s_x = e_x / e_x.sum()
                is_x = 1. - s_x
                is_x = is_x / is_x.sum()

@@ -56,7 +71,7 @@ def __iter__(self):
def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
               start_token="<START>", end_token="<EOS>", use_log=True,
               renormalize=True, length_score=True,
               stochastic=False, temperature=1.0, random_state=None, eps=1E-9):

@@ -79,8 +94,18 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
    "end_token" is a single string (token), or a sequence of tokens that signifies end of the sequence
    "use_log, renormalize, length_score" are all related to calculation of beams to keep
        and should improve results when True
    "stochastic" uses a different sampling algorithm for reducing/aggregating beams
        it should result in more diverse and interesting outputs
    "temperature" is the softmax temperature for the underlying stochastic beamsearch - the default of 1.0 is usually fine
    "random_state" is a np.random.RandomState() object, passed when using the stochastic beamsearch in order to control randomness
    "eps" minimum probability for log-space calculations, to avoid numerical issues
    """
    if stochastic:
        if random_state is None:

@@ -202,15 +227,6 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
        else:
            prev_beam = curr_beam

# Reduce memory on python 2
if sys.version_info < (3, 0):
    range = xrange

@@ -255,100 +271,175 @@ def normalize(counter):
    return outlm

def generate_letter(lm, history, order, stochastic, random_state):
    history = history[-order:]
    dist = lm[history]
    if stochastic:
        x = random_state.rand()
        for v, c in dist:
            x = x - v
            if x <= 0:
                return c
        # randomize choice if it all failed
        li = list(range(len(dist)))
        random_state.shuffle(li)
        _, c = dist[li[0]]
    else:
        probs = np.array([d[0] for d in dist])
        ii = np.argmax(probs)
        _, c = dist[ii]
    return c

def step_text(lm, order, stochastic, random_seed, history=None, stop_token=None,
              beam_width=1, n_letters=1000):
    # beam_width argument is ignored, as is stop_token
    if history is None or history == "<START>":
        history = "~" * order
    out = []
    random_state = np.random.RandomState(random_seed)
    for i in range(n_letters):
        c = generate_letter(lm, history, order, stochastic, random_state)
        history = history[-order:] + c
        out.append(c)
    return "".join(out)

def beam_text(lm, order, stochastic, random_seed, history=None, end=None,
              beam_width=10, n_letters=1000):
    def pf(prefix):
        history = prefix[-order:]
        # lm wants key as a single string
        dist = lm["".join(history)]
        return dist

    if history is None or history == "<START>":
        start_token = ["~"] * order
    if end is None:
        end_token = "<EOS>"
    else:
        end_token = end
    random_state = np.random.RandomState(random_seed)
    b = beamsearch(pf, beam_width, start_token=start_token,
                   end_token=end_token, clip_len=n_letters,
                   stochastic=stochastic, random_state=random_state)
    # it is a generator but do this so that function prototypes are consistent
    # top beam search
    r = next(b)
    return "".join(r[0])

if __name__ == "__main__":
    default_order = 6
    default_temperature = 1.0
    default_beamwidth = 10
    default_start = "<START>"
    default_end = "<EOS>"
    default_beamwidth = 10
    default_decoder = 0
    default_randomseed = 1999
    default_maxlength = 1000
    default_cycles = 3
    default_cache = 1

    parser = argparse.ArgumentParser(description="A Markov chain language model with beamsearch decoding",
                                     epilog="Example usage:\n python minimal_beamsearch.py shakespeare_input.txt -o 10 -d 0")
    parser.add_argument("filepath", help="Path to file to use for language modeling. For an example file, try downloading http://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", default=None)
    parser.add_argument("-o", "--order", help="Markov chain order, higher will make better text but takes longer to process. Default {}".format(default_order), default=default_order)
    parser.add_argument("-t", "--temperature", help="Temperature for Markov chain softmax, higher is more random, lower more static. Default {}".format(default_temperature), default=default_temperature)
    parser.add_argument("-d","--decoder", help="Decoder for Markov chain, 0 is stochastic beamsearch, 1 is argmax beamsearch, 2 is sampled next-step, 3 is argmax next-step. Default {}".format(default_decoder), default=default_decoder)
    parser.add_argument("-b", "--beamwidth", help="Beamwidth to use for beamsearch. Default {}".format(default_beamwidth), default=default_beamwidth)
    parser.add_argument("-r", "--randomseed", help="Random seed to initialize randomness. Default {}".format(default_randomseed), default=default_randomseed)
    parser.add_argument("-s", "--starttoken", help="Start sequence token. Can be a comma separated (no spaces) list, in which case it should be the same length as --order argument. Default {}".format(default_start), default=default_start)
    parser.add_argument("-e", "--endtoken", help="Random seed to initialize randomness. Can be a comma separated (no space) list. Default {}".format(default_end), default=default_end)
    parser.add_argument("-m", "--maxlength", help="Max generation length. Default {}".format(default_maxlength), default=default_maxlength)
    parser.add_argument("-c", "--cache", help="Whether to cache models for faster use. Default {}".format(default_cache), default=default_cache)
    #parser.add_argument("-c", "--cycles", help="Number of cycles, using the last piece of the previous beam to start a new one. Can be useful for long beamsearches. Default {}".format(default_cycles))
    args = parser.parse_args()
    if args.filepath is None:
        raise ValueError("No text filepath provided!")
    else:
        fpath = args.filepath
    if not os.path.exists(fpath):
        raise ValueError("Unable to find file at %s" % fpath)

    decoder_settings = [0, 1, 2, 3]
    decoder = int(args.decoder)
    start_token = args.starttoken
    end_token = args.endtoken
    # TODO: list support
    # TODO: gumbel-max in stochastic beam decoder...
    beam_width = int(args.beamwidth)
    temperature = float(args.temperature)
    random_seed = int(args.randomseed)
    maxlength = int(args.maxlength)
    order = int(args.order)
    if order < 1:
        raise ValueError("Order must be greater than 1! Was set to {}".format(order))
    cache = int(args.cache)
    if cache not in [0, 1]:
        raise ValueError("Cache must be either 0 (no save) or 1 (save)! Was set to {}".format(cache))

    if decoder == 0:
        # stochastic beam
        stochastic = True
        decode_fun = beam_text
        type_tag = "Stochastic beam search, beam width {}, Markov order {}, temperature {}, seed {}".format(beam_width, order, temperature, random_seed)
    elif decoder == 1:
        # argmax beam
        stochastic = False
        decode_fun = beam_text
        type_tag = "Argmax beam search, beam width {}, Markov order {}".format(beam_width, order)
    elif decoder == 2:
        # stochastic next-step
        stochastic = True
        decode_fun = step_text
        type_tag = "Stochastic next step, Markov order {}, temperature {}, seed {}".format(order, temperature, random_seed)
    elif decoder == 3:
        # argmax next-step
        stochastic = False
        decode_fun = step_text
        type_tag = "Argmax next step, Markov order {}".format(order)
    else:
        raise ValueError("Decoder must be 0, 1, 2, or 3! Was set to {}".format(decoder))

    # only things that affect the language model are training data, temperature, order
    cached_name = "model_{}_t{}_o{}.pkl".format("".join(fpath.split(".")[:-1]), str(temperature).replace(".", "pt"), order)
    if cache == 1 and os.path.exists(cached_name):
        print("Found cached model at {}, loading...".format(cached_name))
        start_time = time.time()
        with open(cached_name, "rb") as f:
            lm = pickle.load(f)
        stop_time = time.time()
        print("Time to load: {} s".format(stop_time - start_time))
    else:
        start_time = time.time()
        lm = train_char_lm(fpath, order=order, temperature=temperature)
        stop_time = time.time()
        print("Time to train: {} s".format(stop_time - start_time))
        if cache == 1:
            print("Caching model now...")
            with open(cached_name, "wb") as f:
                pickle.dump(lm, f, pickle.HIGHEST_PROTOCOL)
            print("Caching complete!")

    start_time = time.time()
    o = decode_fun(lm, order, stochastic, random_seed,
                   history=start_token, end=end_token,
                   beam_width=beam_width, n_letters=maxlength)
    stop_time = time.time()
    print(type_tag)
    print("Time to decode: {} s".format(stop_time - start_time))
    print("----------")
    print(o)
    print("----------")
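This rewrite wires the Markov model from Yoav Goldberg's markov_lm.py into a CLI. The model itself is just per-history character counts turned into a temperature-controlled distribution; a compact sketch of that training step under my own naming and simplifications (the gist's train_char_lm and normalize may differ in details):

```python
import numpy as np
from collections import defaultdict, Counter

def train_char_lm_sketch(text, order=4, temperature=1.0):
    """Count which character follows each `order`-length history, then turn
    the counts into a softmax distribution per history. Returns a dict of
    {history: [(prob, char), ...]}. Simplified relative to the gist."""
    padded = "~" * order + text
    counts = defaultdict(Counter)
    for i in range(len(text)):
        history, char = padded[i:i + order], padded[i + order]
        counts[history][char] += 1

    lm = {}
    for history, counter in counts.items():
        chars = list(counter.keys())
        logits = np.array([counter[c] for c in chars], dtype=float) / temperature
        probs = np.exp(logits - logits.max())
        probs = probs / probs.sum()
        lm[history] = list(zip(probs, chars))
    return lm

lm = train_char_lm_sketch("hello hello hello world", order=2)
print(lm["he"])  # distribution over the characters that follow "he"
```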
kastnerkyle revised this gist
May 17, 2017. 1 changed file with 113 additions and 41 deletions.

@@ -20,26 +20,46 @@ class Beam(object):
    This is so that if two prefixes have equal probabilities then a complete sentence is preferred
    over an incomplete one since (0.5, False, whatever_prefix) < (0.5, True, some_other_prefix)
    """
    def __init__(self, beam_width, init_beam=None, use_log=True,
                 stochastic=False, temperature=1.0, random_state=None):
        if init_beam is None:
            self.heap = list()
        else:
            self.heap = init_beam
            heapq.heapify(self.heap)
        self.stochastic = stochastic
        self.random_state = random_state
        self.temperature = temperature
        # use_log currently unused...
        self.use_log = use_log
        self.beam_width = beam_width

    def add(self, score, complete, prob, prefix):
        heapq.heappush(self.heap, (score, complete, prob, prefix))
        if len(self.heap) > self.beam_width:
            if self.stochastic:
                log_probs = np.array([h[2] for h in self.heap])
                e_x = np.exp(log_probs - np.max(log_probs))
                s_x = e_x / e_x.sum()
                is_x = 1. - s_x
                is_x = is_x / is_x.sum()
                to_remove = self.random_state.multinomial(1, is_x).argmax()
                # there must be a faster way...
                self.heap.pop(to_remove)
                heapq.heapify(self.heap)
            else:
                # remove lowest score from heap
                heapq.heappop(self.heap)

    def __iter__(self):
        return iter(self.heap)

def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
               start_token="<START>", end_token="<END>", use_log=True,
               renormalize=True, length_score=True,
               stochastic=False, temperature=1.0, random_state=None, eps=1E-9):
    """
    From http://geekyisawesome.blogspot.ca/2017/04/getting-top-n-most-probable-sentences.html

@@ -56,28 +76,44 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
    "start_token" can be a single string (token), or a sequence of tokens
    "end_token" is a single string (token), or a sequence of tokens that signifies end of the sequence
    "use_log, renormalize, length_score" are all related to calculation of beams to keep
    "stochastic"
    """
    if stochastic:
        if random_state is None:
            raise ValueError("Must pass np.random.RandomState() object if stochastic=True")

    prev_beam = Beam(beam_width, None, use_log, stochastic, temperature, random_state)
    try:
        basestring
    except NameError:
        basestring = str
    if isinstance(start_token, collections.Sequence) and not isinstance(start_token, basestring):
        start_token_is_seq = True
    else:
        # make it a list with 1 entry
        start_token = [start_token]
        start_token_is_seq = False
    if isinstance(end_token, collections.Sequence) and not isinstance(end_token, basestring):
        end_token_is_seq = True
    else:
        # make it a list with 1 entry
        end_token = [end_token]
        end_token_is_seq = False
    if use_log:
        prev_beam.add(.0, False, .0, start_token)
    else:
        prev_beam.add(1.0, False, 1.0, start_token)

    while True:
        curr_beam = Beam(beam_width, None, use_log, stochastic, temperature, random_state)
        if renormalize:
            sorted_prev_beam = sorted(prev_beam)
            # renormalize by the previous minimum value in the beam

@@ -89,9 +125,9 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
            min_prob = 1.

        # Add complete sentences that do not yet have the best probability to the current beam,
        # the rest prepare to add more words to them.
        for (prefix_score, complete, prefix_prob, prefix) in prev_beam:
            if complete == True:
                curr_beam.add(prefix_score, True, prefix_prob, prefix)
            else:
                # Get probability of each possible next word for the incomplete prefix
                for (next_prob, next_word) in probabilities_function(prefix):

@@ -101,35 +137,54 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                    else:
                        n = eps
                    # score is renormalized prob
                    if use_log:
                        if length_score:
                            score = prefix_prob + np.log(n) - min_prob + np.log(len(prefix))
                        else:
                            score = prefix_prob + np.log(n) - min_prob
                        prob = prefix_prob + np.log(n)
                    else:
                        if length_score:
                            score = (prefix_prob * n) / min_prob * len(prefix)
                        else:
                            score = (prefix_prob * n) / min_prob
                        prob = prefix_prob * n

                    if end_token_is_seq:
                        left_cmp = prefix[-len(end_token) + 1:] + [next_word]
                        right_cmp = end_token
                    else:
                        left_cmp = next_word
                        right_cmp = end_token
                    if left_cmp == right_cmp:
                        # If next word is the end token then mark prefix as complete and leave out the end token
                        curr_beam.add(score, True, prob, prefix)
                    else:
                        curr_beam.add(score, False, prob, prefix + [next_word])

        # Get all prefixes in beam sorted by probability
        sorted_beam = sorted(curr_beam)
        any_removals = False
        while True:
            # Get highest probability prefix - heapq is sorted in ascending order
            (best_score, best_complete, best_prob, best_prefix) = sorted_beam[-1]
            if best_complete == True or len(best_prefix) - 1 == clip_len:
                # If most probable prefix is a complete sentence or has a length that
                # exceeds the clip length (ignoring the start token) then return it
                # yield best without start token, along with probability
                if start_token_is_seq:
                    skip = len(start_token)
                else:
                    skip = 1
                if end_token_is_seq:
                    stop = -len(end_token) + 1
                else:
                    stop = None
                yield (best_prefix[skip:stop], best_prob)
                sorted_beam.pop()
                any_removals = True

                # If there are no more sentences in the beam then stop checking

@@ -142,7 +197,8 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                if len(sorted_beam) == 0:
                    break
            else:
                prev_beam = Beam(beam_width, sorted_beam, use_log, stochastic,
                                 temperature, random_state)
        else:
            prev_beam = curr_beam

@@ -225,21 +281,25 @@ def generate_text(lm, order, n_letters=1000):
    return "".join(out)

def beam_text(lm, order, stochastic, beam_width=10, n_letters=1000):
    def pf(prefix):
        history = prefix[-order:]
        # lm wants key as a single string
        dist = lm["".join(history)]
        return dist

    random_state = np.random.RandomState(1999)
    b = beamsearch(pf, beam_width, start_token=["~"] * order, clip_len=n_letters,
                   stochastic=stochastic, random_state=random_state)
    return b

if __name__ == "__main__":
    default_order = 6
    default_temperature = 1.0
    default_stochastic = True
    default_fpath = "shakespeare_input.txt"
    if len(sys.argv) > 1:

@@ -262,21 +322,33 @@ def pf(prefix):
    else:
        temperature = default_temperature
    if len(sys.argv) > 4:
        stochastic = int(sys.argv[4])
        if stochastic == 0:
            stochastic = False
        elif stochastic == 1:
            stochastic = True
        else:
            print("Unknown setting for stochastic (argument 4), {}".format(stochastic))
    else:
        stochastic = default_stochastic

    lm = train_char_lm(fpath, order=order, temperature=temperature)
    b = beam_text(lm, order, stochastic)
    # only print top beam
    r = next(b)
    print("".join(r[0]))
    print("-------")
    """
    # for all...
    for r in b:
        print("".join(r[0]))
        print("-------")
    """
    print(generate_text(lm, order))
    print("-------")
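The hunks above are the core of the search: expand every live prefix, score each continuation (optionally length-weighted and in log space), and yield a hypothesis once it hits the end token or the clip length. A stripped-down, deterministic version of that loop, under my own simplifications (no renormalization, no stochastic pruning, no sequence end tokens):

```python
import heapq
import numpy as np

def tiny_beamsearch(next_probs, start, end, beam_width=3, max_len=20):
    """Minimal log-space beam search. `next_probs(prefix)` returns
    (prob, token) pairs; yields (tokens, score) as hypotheses complete."""
    beam = [(0.0, False, [start])]
    while beam:
        candidates = []
        for score, complete, prefix in beam:
            if complete:
                candidates.append((score, True, prefix))
                continue
            for p, tok in next_probs(prefix):
                candidates.append((score + np.log(max(p, 1e-9)),
                                   tok == end, prefix + [tok]))
        beam = heapq.nlargest(beam_width, candidates, key=lambda c: c[0])
        still_open = []
        for score, complete, prefix in sorted(beam, key=lambda c: c[0], reverse=True):
            if complete or len(prefix) - 1 >= max_len:
                yield prefix[1:], score     # drop the start token
            else:
                still_open.append((score, complete, prefix))
        beam = still_open

def bigram(prefix):
    # A toy next-token table standing in for the Markov LM lookup.
    table = {"<s>": [(0.6, "a"), (0.4, "b")],
             "a": [(0.9, "</s>"), (0.1, "a")],
             "b": [(1.0, "</s>")]}
    return table[prefix[-1]]

for toks, score in tiny_beamsearch(bigram, "<s>", "</s>", beam_width=2):
    print(toks, round(score, 3))
```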
kastnerkyle revised this gist
May 17, 2017. 1 changed file with 17 additions and 6 deletions.
@@ -39,7 +39,7 @@ def __iter__(self):
        return iter(self.heap)


def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
               start_token="<START>", end_token="<END>", use_log=True,
               renormalize=True, eps=1E-9):
    """
    From http://geekyisawesome.blogspot.ca/2017/04/getting-top-n-most-probable-sentences.html
@@ -78,6 +78,15 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
    while True:
        curr_beam = Beam(beam_width)
        if renormalize:
            sorted_prev_beam = sorted(prev_beam)
            # renormalize by the previous minimum value in the beam
            min_prob = sorted_prev_beam[0][0]
        else:
            if use_log:
                min_prob = 0.
            else:
                min_prob = 1.

        # Add complete sentences that do not yet have the best probability to the current beam, the rest prepare to add more words to them.
        for (prefix_prob, complete, prefix) in prev_beam:
@@ -95,21 +104,23 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                    if next_word == end_token:
                        # If next word is the end token then mark prefix as complete and leave out the end token
                        if use_log:
                            curr_beam.add(prefix_prob + np.log(n) - min_prob, True, prefix)
                        else:
                            curr_beam.add((prefix_prob * n) / min_prob, True, prefix)
                    else:
                        # If next word is the end token then mark prefix as incomplete
                        if use_log:
                            curr_beam.add(prefix_prob + np.log(n) - min_prob, False, prefix + [next_word])
                        else:
                            curr_beam.add((prefix_prob * n) / min_prob, False, prefix + [next_word])

        # Get all prefixes in beam sorted by probability
        sorted_beam = sorted(curr_beam)
        any_removals = False
        while True:
            # Get highest probability prefix - heapq is sorted in ascending order
            (best_prob, best_complete, best_prefix) = sorted_beam[-1]
            if best_complete == True or len(best_prefix) - 1 == clip_len:
                # If most probable prefix is a complete sentence or has a length that
                # exceeds the clip length (ignoring the start token) then return it
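The renormalize option introduced here subtracts the previous beam's minimum score in log space (or divides by it otherwise) before comparing candidates. Since every candidate is shifted by the same constant, the ranking is unchanged while the running scores stay near zero instead of drifting toward large negative values. A tiny sketch of that invariance, using made-up log probabilities:

import numpy as np

log_probs = np.array([-40.1, -42.7, -39.8])   # running log-probs of three prefixes
min_prob = log_probs.min()
renorm = log_probs - min_prob                  # shift so the worst prefix sits at 0.

# the ordering is identical with or without the shift
assert (np.argsort(log_probs) == np.argsort(renorm)).all()
print(renorm)   # [2.6 0.  2.9] -- small values instead of large negative ones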
kastnerkyle revised this gist
May 17, 2017. 1 changed file with 181 additions and 186 deletions.
@@ -1,9 +1,17 @@
# Author: Kyle Kastner
# License: BSD 3-Clause
# See core implementations here http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html
# Also includes a reduction of the post by Yoav Goldberg to a script
# markov_lm.py
# https://gist.github.com/yoavg/d76121dfde2618422139
import numpy as np
import heapq
from collections import defaultdict, Counter
import collections
import os
import sys


class Beam(object):
    """
@@ -29,66 +37,45 @@ def __iter__(self):
        return iter(self.heap)


def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
               start_token="<START>", end_token="<END>", use_log=True,
               eps=1E-9):
    """
    From http://geekyisawesome.blogspot.ca/2017/04/getting-top-n-most-probable-sentences.html

    returns a generator, which will yield beamsearched sequences in order of their probability

    "probabilities_function" returns a list of (next_prob, next_word) pairs given a prefix.

    "beam_width" is the number of prefixes to keep (so that instead of keeping the top 10 prefixes you can keep the top 100 for example). By making the beam search bigger you can get closer to the actual most probable sentence but it would also take longer to process.

    "clip_len" is a maximum length to tolerate, beyond which the most probable prefix is returned as an incomplete sentence. Without a maximum length, a faulty probabilities function which does not return a highly probable end token will lead to an infinite loop or excessively long garbage sentences.

    "start_token" can be a single string (token), or a sequence of tokens

    "end_token" is a string (token) that signifies end of the sequence
    """
    prev_beam = Beam(beam_width)
    try:
        basestring
    except NameError:
        basestring = str

    if isinstance(start_token, collections.Sequence) and not isinstance(start_token, basestring):
        if use_log:
            prev_beam.add(.0, False, start_token)
        else:
            prev_beam.add(1.0, False, start_token)
        list_token = True
    else:
        if use_log:
            prev_beam.add(.0, False, [start_token])
        else:
            prev_beam.add(1.0, False, [start_token])
        list_token = False

    while True:
        curr_beam = Beam(beam_width)
@@ -99,20 +86,24 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
            else:
                # Get probability of each possible next word for the incomplete prefix
                for (next_prob, next_word) in probabilities_function(prefix):
                    # use eps tolerance to avoid log(0.) issues
                    if next_prob > eps:
                        n = next_prob
                    else:
                        n = eps
                    if next_word == end_token:
                        # If next word is the end token then mark prefix as complete and leave out the end token
                        if use_log:
                            curr_beam.add(prefix_prob + np.log(n), True, prefix)
                        else:
                            curr_beam.add(prefix_prob * n, True, prefix)
                    else:
                        # If next word is the end token then mark prefix as incomplete
                        if use_log:
                            curr_beam.add(prefix_prob + np.log(n), False, prefix + [next_word])
                        else:
                            curr_beam.add(prefix_prob * n, False, prefix + [next_word])

        # Get all prefixes in beam sorted by probability
        sorted_beam = sorted(curr_beam)
        any_removals = False
@@ -123,7 +114,11 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                # If most probable prefix is a complete sentence or has a length that
                # exceeds the clip length (ignoring the start token) then return it
                # yield best without start token, along with probability
                if list_token:
                    skip = len(start_token)
                else:
                    skip = 1
                yield (best_prefix[skip:], best_prob)
                sorted_beam.pop()
                any_removals = True
                # If there are no more sentences in the beam then stop checking
@@ -140,137 +135,137 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
        else:
            prev_beam = curr_beam


# Fun alternate settings
# Download kjv.txt from http://www.ccel.org/ccel/bible/kjv.txt
# python markov_lm.py kjv.txt 5 1.
# Snippet:
# Queen ording found Raguel: I kill.
# THROUGH JESUS OF OUR BRETHREN, AND PEACE,
#
# NUN.

# Reduce memory on python 2
if sys.version_info < (3, 0):
    range = xrange


def train_char_lm(fname, order=4, temperature=1.0):
    data = file(fname).read()
    lm = defaultdict(Counter)
    pad = "~" * order
    data = pad + data
    for i in range(len(data) - order):
        history, char = data[i:i + order], data[i + order]
        lm[history][char] += 1

    def normalize(counter):
        # Use a proper softmax with temperature
        t = temperature
        ck = counter.keys()
        cv = counter.values()
        # Keep it in log space
        s = float(sum([pi for pi in cv]))
        # 0 to 1 to help numerical issues
        p = [pi / s for pi in cv]
        # log_space
        p = [pi / float(t) for pi in p]
        mx = max(p)
        # log sum exp
        s_p = mx + np.log(sum([np.exp(pi - mx) for pi in p]))
        # Calculate softmax in a hopefully more stable way
        # s(xi) = exp ^ (xi / t) / sum exp ^ (xi / t)
        # log s(xi) = log (exp ^ (xi / t) / sum exp ^ (xi / t))
        # log s(xi) = log exp ^ (xi / t) - log sum exp ^ (xi / t)
        # with pi = xi / t
        # with s_p = log sum exp ^ (xi / t)
        # log s(xi) = pi - s_p
        # s(xi) = np.exp(pi - s_p)
        p = [np.exp(pi - s_p) for pi in p]
        return [(pi, ci) for ci, pi in zip(ck, p)]

    outlm = {hist: normalize(chars) for hist, chars in lm.iteritems()}
    return outlm


def generate_letter(lm, history, order, random_state):
    history = history[-order:]
    dist = lm[history]
    x = random_state.rand()
    for v, c in dist:
        x = x - v
        if x <= 0:
            return c
    # randomize choice if it all failed
    li = list(range(len(dist)))
    random_state.shuffle(li)
    _, c = dist[li[0]]
    return c


def generate_text(lm, order, n_letters=1000):
    history = "~" * order
    out = []
    random_state = np.random.RandomState(2145)
    for i in range(n_letters):
        c = generate_letter(lm, history, order, random_state)
        history = history[-order:] + c
        out.append(c)
    return "".join(out)


def beam_text(lm, order, beam_width=10, n_letters=1000):
    def pf(prefix):
        history = prefix[-order:]
        # lm wants key as a single string
        dist = lm["".join(history)]
        return dist
    b = beamsearch(pf, beam_width, start_token=["~"] * order,
                   clip_len=n_letters)
    return b


if __name__ == "__main__":
    default_order = 6
    default_temperature = 1.0
    default_fpath = "shakespeare_input.txt"
    if len(sys.argv) > 1:
        fpath = sys.argv[1]
        if not os.path.exists(fpath):
            raise ValueError("Unable to find file at %s" % fpath)
    else:
        fpath = default_fpath
        if not os.path.exists(fpath):
            raise ValueError("Default shakespeare file not found!"
                             "Get the shakespeare file from http://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt"
                             "Place at %s" % fpath)

    if len(sys.argv) > 2:
        order = int(sys.argv[2])
    else:
        order = default_order

    if len(sys.argv) > 3:
        temperature = float(sys.argv[3])
    else:
        temperature = default_temperature

    lm = train_char_lm(fpath, order=order, temperature=temperature)
    b = beam_text(lm, order)
    # only print top beam
    r = next(b)
    print("".join(r[0]))
    print("-------")

    print(generate_text(lm, order))
    print("-------")

    # for all...
    """
    for r in b:
        print("".join(r[0]))
        print("-------")
    """
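The normalize helper inside train_char_lm above converts per-history character counts into a temperature-scaled distribution using the log-sum-exp steps spelled out in its comments. A compact sketch of the same computation on a toy Counter follows; the function name count_softmax is illustrative, not part of the gist.

import numpy as np
from collections import Counter

def count_softmax(counter, temperature=1.0):
    chars = list(counter.keys())
    counts = np.array(list(counter.values()), dtype=float)
    p = counts / counts.sum()          # scale counts to 0..1 for numerical comfort
    p = p / temperature                # temperature < 1 sharpens, > 1 flattens
    s_p = p.max() + np.log(np.exp(p - p.max()).sum())   # log-sum-exp
    probs = np.exp(p - s_p)            # softmax over the scaled values
    return list(zip(probs, chars))

print(count_softmax(Counter("hello world"), temperature=1.0))

Temperatures below 1.0 push more mass onto the most frequent characters, while larger values flatten the distribution toward uniform.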
kastnerkyle revised this gist
May 17, 2017. 1 changed file with 1 addition and 1 deletion.
@@ -270,7 +270,7 @@ def pf(prefix):
print("Experiment 2: {}".format(sentence2))
run_experiment(sentence2)
print("")

sentence3 = "<START> The {} sometimes very rarely but periodically will {} nightly <END>"
print("Experiment 3: {}".format(sentence3))
run_experiment(sentence3)
print("")
kastnerkyle revised this gist
May 17, 2017. 1 changed file with 26 additions and 9 deletions.
@@ -30,7 +30,7 @@ def __iter__(self):

def single_beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                      start_token="<START>", end_token="<END>", eps=1E-9):
    """
    From http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html

    "probabilities_function" returns a list of (next_prob, next_word) pairs given a prefix.
@@ -41,8 +41,8 @@ def single_beamsearch(probabilities_function, beam_width=10, clip_len=-1,
    will lead to an infinite loop or excessively long garbage sentences.
    """
    prev_beam = Beam(beam_width)
    #prev_beam.add(1.0, False, [start_token])
    prev_beam.add(.0, False, [start_token])
    while True:
        curr_beam = Beam(beam_width)
        # Add complete sentences that do not yet have the best probability to the current beam, the rest prepare to add more words to them.
@@ -54,10 +54,18 @@ def single_beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                for (next_prob, next_word) in probabilities_function(prefix):
                    if next_word == end_token:
                        # If next word is the end token then mark prefix as complete and leave out the end token
                        #curr_beam.add(prefix_prob * next_prob, True, prefix)
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), True, prefix)
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), True, prefix)
                    else:
                        # If next word is the end token then mark prefix as incomplete
                        #curr_beam.add(prefix_prob * next_prob, False, prefix + [next_word])
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), False, prefix + [next_word])
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), False, prefix + [next_word])
        (best_prob, best_complete, best_prefix) = max(curr_beam)
        if best_complete == True or len(best_prefix) - 1 == clip_len:
            # If most probable prefix is a complete sentence or has a length that
@@ -68,7 +76,7 @@ def single_beamsearch(probabilities_function, beam_width=10, clip_len=-1,

def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
               start_token="<START>", end_token="<END>", eps=1E-9):
    """
    From http://geekyisawesome.blogspot.ca/2017/04/getting-top-n-most-probable-sentences.html

    "probabilities_function" returns a list of (next_prob, next_word) pairs given a prefix.
@@ -79,7 +87,8 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
    will lead to an infinite loop or excessively long garbage sentences.
    """
    prev_beam = Beam(beam_width)
    #prev_beam.add(1.0, False, [start_token])
    prev_beam.add(.0, False, [start_token])
    while True:
        curr_beam = Beam(beam_width)
@@ -92,10 +101,18 @@ def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                for (next_prob, next_word) in probabilities_function(prefix):
                    if next_word == end_token:
                        # If next word is the end token then mark prefix as complete and leave out the end token
                        #curr_beam.add(prefix_prob * next_prob, True, prefix)
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), True, prefix)
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), True, prefix)
                    else:
                        # If next word is the end token then mark prefix as incomplete
                        #curr_beam.add(prefix_prob * next_prob, False, prefix + [next_word])
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), False, prefix + [next_word])
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), False, prefix + [next_word])
        # Get all prefixes in beam sorted by probability
        sorted_beam = sorted(curr_beam)
        any_removals = False
kastnerkyle revised this gist
May 17, 2017. 1 changed file with 256 additions and 0 deletions.
@@ -1,3 +1,259 @@
# Author: Kyle Kastner
# License: BSD 3-Clause
# See core implementations here http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html
import numpy as np
import heapq


class Beam(object):
    """
    From http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html

    For comparison of prefixes, the tuple (prefix_probability, complete_sentence) is used.
    This is so that if two prefixes have equal probabilities then a complete sentence
    is preferred over an incomplete one since
    (0.5, False, whatever_prefix) < (0.5, True, some_other_prefix)
    """
    def __init__(self, beam_width, init_beam=None):
        if init_beam is None:
            self.heap = list()
        else:
            self.heap = init_beam
            heapq.heapify(self.heap)
        self.beam_width = beam_width

    def add(self, prob, complete, prefix):
        heapq.heappush(self.heap, (prob, complete, prefix))
        if len(self.heap) > self.beam_width:
            # remove lowest probability from heap
            heapq.heappop(self.heap)

    def __iter__(self):
        return iter(self.heap)


def single_beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                      start_token="<START>", end_token="<END>"):
    """
    From http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html

    "probabilities_function" returns a list of (next_prob, next_word) pairs given a prefix.

    "beam_width" is the number of prefixes to keep (so that instead of keeping the top 10 prefixes you can keep the top 100 for example). By making the beam search bigger you can get closer to the actual most probable sentence but it would also take longer to process.

    "clip_len" is a maximum length to tolerate, beyond which the most probable prefix is returned as an incomplete sentence. Without a maximum length, a faulty probabilities function which does not return a highly probable end token will lead to an infinite loop or excessively long garbage sentences.
    """
    prev_beam = Beam(beam_width)
    # Currently in normal probability space, not logprob
    prev_beam.add(1.0, False, [start_token])
    while True:
        curr_beam = Beam(beam_width)
        # Add complete sentences that do not yet have the best probability to the current beam, the rest prepare to add more words to them.
        for (prefix_prob, complete, prefix) in prev_beam:
            if complete == True:
                curr_beam.add(prefix_prob, True, prefix)
            else:
                # Get probability of each possible next word
                for (next_prob, next_word) in probabilities_function(prefix):
                    if next_word == end_token:
                        # If next word is the end token then mark prefix as complete and leave out the end token
                        curr_beam.add(prefix_prob * next_prob, True, prefix)
                    else:
                        # If next word is the end token then mark prefix as incomplete
                        curr_beam.add(prefix_prob * next_prob, False, prefix + [next_word])

        (best_prob, best_complete, best_prefix) = max(curr_beam)
        if best_complete == True or len(best_prefix) - 1 == clip_len:
            # If most probable prefix is a complete sentence or has a length that
            # exceeds the clip length (ignoring the start token) then return it
            return (best_prefix[1:], best_prob)  # return best sentence without the start token and together with its probability
        prev_beam = curr_beam


def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
               start_token="<START>", end_token="<END>"):
    """
    From http://geekyisawesome.blogspot.ca/2017/04/getting-top-n-most-probable-sentences.html

    "probabilities_function" returns a list of (next_prob, next_word) pairs given a prefix.

    "beam_width" is the number of prefixes to keep (so that instead of keeping the top 10 prefixes you can keep the top 100 for example). By making the beam search bigger you can get closer to the actual most probable sentence but it would also take longer to process.

    "clip_len" is a maximum length to tolerate, beyond which the most probable prefix is returned as an incomplete sentence. Without a maximum length, a faulty probabilities function which does not return a highly probable end token will lead to an infinite loop or excessively long garbage sentences.
    """
    prev_beam = Beam(beam_width)
    prev_beam.add(1.0, False, [start_token])
    while True:
        curr_beam = Beam(beam_width)
        # Add complete sentences that do not yet have the best probability to the current beam, the rest prepare to add more words to them.
        for (prefix_prob, complete, prefix) in prev_beam:
            if complete == True:
                curr_beam.add(prefix_prob, True, prefix)
            else:
                # Get probability of each possible next word for the incomplete prefix
                for (next_prob, next_word) in probabilities_function(prefix):
                    if next_word == end_token:
                        # If next word is the end token then mark prefix as complete and leave out the end token
                        curr_beam.add(prefix_prob * next_prob, True, prefix)
                    else:
                        # If next word is the end token then mark prefix as incomplete
                        curr_beam.add(prefix_prob * next_prob, False, prefix + [next_word])

        # Get all prefixes in beam sorted by probability
        sorted_beam = sorted(curr_beam)
        any_removals = False
        while True:
            # Get highest probability prefix
            (best_prob, best_complete, best_prefix) = sorted_beam[-1]
            if best_complete == True or len(best_prefix) - 1 == clip_len:
                # If most probable prefix is a complete sentence or has a length that
                # exceeds the clip length (ignoring the start token) then return it
                # yield best without start token, along with probability
                yield (best_prefix[1:], best_prob)
                sorted_beam.pop()
                any_removals = True
                # If there are no more sentences in the beam then stop checking
                if len(sorted_beam) == 0:
                    break
            else:
                break

        if any_removals == True:
            if len(sorted_beam) == 0:
                break
            else:
                prev_beam = Beam(beam_width, sorted_beam)
        else:
            prev_beam = curr_beam


def run_experiment(sentence):
    random_state = np.random.RandomState(2147)
    pairs = [("cat", "meow"), ("dog", "bark"), ("cow", "moo"), ("horse", "neigh")]
    sentences = [sentence.format(n, v) for n, v in pairs]
    word_list = []
    for se in sentences:
        word_list.extend(se.split(" "))
    word_list = sorted(list(set(word_list)))
    word2idx = {k:v for v, k in enumerate(word_list)}
    idx2word = {v:k for k, v in word2idx.items()}

    # transition probability from current word to next word
    single_markov_probs = np.zeros((len(word_list), len(word_list)))
    for se in sentences:
        # sentence as list of words
        sw = se.split(" ")
        # current word, next word
        for cw, nw in zip(sw[:-1], sw[1:]):
            ci = word2idx[cw]
            ni = word2idx[nw]
            # add to count
            single_markov_probs[ci, ni] += 1

    # make it a true probability matrix = each row sums to 1
    for n in range(len(single_markov_probs)):
        sn = sum(single_markov_probs[n])
        if sn > 0.:
            single_markov_probs[n] /= sn
        else:
            # skip for things which are never started from
            continue

    # 'sample' a sentence
    sampled = ["<START>"]
    while sampled[-1] != "<END>":
        ci = word2idx[sampled[-1]]
        next_probs = single_markov_probs[ci]
        # Logic to avoid sampling from truly 0. probabilities
        # needed due to floating point math and numpy sampling
        next_probs_lu = np.where(next_probs > 0.)[0]
        next_probs = next_probs[next_probs_lu]
        # deterministic
        ni_lu = next_probs.argmax()
        # random
        #ni_lu = random_state.multinomial(1, next_probs).argmax()
        ni = next_probs_lu[ni_lu]
        nw = idx2word[ni]
        sampled.append(nw)
    print("Greedy with history 1: {}".format(sampled))

    def pf(prefix):
        ci = word2idx[prefix[-1]]
        probs = single_markov_probs[ci]
        words = [idx2word[n] for n in range(len(probs))]
        return [(probs[i], words[i]) for i in range(len(probs))]

    # Now beam search on history 1, returns (seq, prob_of_seq)
    """
    r = single_beamsearch(pf, beam_width=3)
    beamsearched = r[0]
    print("Beamsearched with history 1: {}".format(["<START>"] + beamsearched + ["<END>"]))
    """
    bw = 3
    bs = beamsearch(pf, beam_width=bw)
    beamsearched = next(bs)[0]
    print("Beamsearched with history 1: {}".format(["<START>"] + beamsearched + ["<END>"]))

    # now do the same experiment for history of 2
    token_list = [(w1, w2) for w1 in word_list for w2 in word_list]
    sentences_2w = ["<START> " + se + " <END>" for se in sentences]
    token2idx = {k: v for v, k in enumerate(token_list)}
    idx2token = {v: k for k, v in token2idx.items()}

    # transition probability from current token to next word
    second_markov_probs = np.zeros((len(token_list), len(word_list)))
    for se_2w in sentences_2w:
        # sentence as list of words
        sw = se_2w.split(" ")
        # current word, next word
        for pw, cw, nw in zip(sw[:-2], sw[1:-1], sw[2:]):
            ci = token2idx[(pw, cw)]
            ni = word2idx[nw]
            # add to count
            second_markov_probs[ci, ni] += 1

    # make it a true probability matrix = each row sums to 1
    for n in range(len(second_markov_probs)):
        sn = sum(second_markov_probs[n])
        if sn > 0.:
            second_markov_probs[n] /= sn
        else:
            # skip for things which are never started from
            continue

    sampled_2w = ["<START>", "<START>"]
    while sampled_2w[-1] != "<END>":
        ci = token2idx[(sampled_2w[-2], sampled_2w[-1])]
        next_probs = second_markov_probs[ci]
        # Logic to avoid sampling from truly 0. probabilities
        # needed due to floating point math and numpy sampling
        next_probs_lu = np.where(next_probs > 0.)[0]
        next_probs = next_probs[next_probs_lu]
        # deterministic
        ni_lu = next_probs.argmax()
        # random
        #ni_lu = random_state.multinomial(1, next_probs).argmax()
        ni = next_probs_lu[ni_lu]
        nw = idx2word[ni]
        sampled_2w.append(nw)
    # Skip the extra <START> we appended above
    print("Greedy with history 2: {}".format(sampled_2w[1:]))


# toy examples
# <START> The {cat, dog, cow, horse} may {meow, bark, moo, neigh} nightly <END>
print("")
sentence1 = "<START> The {} may {} nightly <END>"
print("Experiment 1: {}".format(sentence1))
run_experiment(sentence1)
print("")

sentence2 = "<START> The {} may sometimes {} nightly <END>"
print("Experiment 2: {}".format(sentence2))
run_experiment(sentence2)
print("")

sentence3 = "<START> The {} may sometimes very occaisionally rarely but periodically {} nightly <END>"
print("Experiment 3: {}".format(sentence3))
run_experiment(sentence3)
print("")
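run_experiment above builds a row-stochastic word-to-word transition matrix and then compares greedy decoding against the beam search: greedy commits to the argmax at each step and can get stuck on a locally likely word, while the beam keeps several prefixes alive and can recover a globally better sentence. A small self-contained sketch of that effect on a hand-built matrix follows; the words and probabilities are invented for illustration.

import numpy as np

words = ["<START>", "the", "cat", "dog", "meows", "purrs", "barks", "<END>"]
w2i = {w: i for i, w in enumerate(words)}
T = np.zeros((len(words), len(words)))
T[w2i["<START>"], w2i["the"]] = 1.0
T[w2i["the"], w2i["cat"]] = 0.6      # locally tempting...
T[w2i["the"], w2i["dog"]] = 0.4
T[w2i["cat"], w2i["meows"]] = 0.5    # ...but the mass then splits
T[w2i["cat"], w2i["purrs"]] = 0.5
T[w2i["dog"], w2i["barks"]] = 1.0
T[w2i["meows"], w2i["<END>"]] = 1.0
T[w2i["purrs"], w2i["<END>"]] = 1.0
T[w2i["barks"], w2i["<END>"]] = 1.0

# greedy: take the argmax at every step
seq, prob = ["<START>"], 1.0
while seq[-1] != "<END>":
    nxt = T[w2i[seq[-1]]].argmax()
    prob *= T[w2i[seq[-1]], nxt]
    seq.append(words[nxt])
print(seq, prob)   # follows the 'cat' path with whole-sequence probability 0.3

# the 'dog' path has whole-sequence probability 0.4, so a beam of width 2 or more
# would surface it even though greedy never considers it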
kastnerkyle created this gist
May 17, 2017.
@@ -0,0 +1,3 @@
# Author: Kyle Kastner
# License: BSD 3-Clause
# See core implementations here http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html