Not-so-minimal anymore (check the early commits) beam search example
# Author: Kyle Kastner
# License: BSD 3-Clause
# See core implementations here http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html
import numpy as np
import heapq


class Beam(object):
    """
    From http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html
    For comparison of prefixes, the tuple (prefix_log_probability, complete_sentence, prefix) is used.
    This is so that if two prefixes have equal probabilities then a complete sentence is preferred
    over an incomplete one, since (0.5, False, whatever_prefix) < (0.5, True, some_other_prefix)
    """
    def __init__(self, beam_width, init_beam=None):
        if init_beam is None:
            self.heap = list()
        else:
            self.heap = init_beam
            heapq.heapify(self.heap)
        self.beam_width = beam_width

    def add(self, prob, complete, prefix):
        heapq.heappush(self.heap, (prob, complete, prefix))
        if len(self.heap) > self.beam_width:
            # remove the lowest probability entry from the heap
            heapq.heappop(self.heap)

    def __iter__(self):
        return iter(self.heap)
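
# A tiny illustration (not part of the original gist) of the pruning behavior:
# adding a third entry to a width-2 Beam silently drops the least probable one.
def _example_beam_pruning():
    b = Beam(2)
    b.add(np.log(0.5), False, ["<START>", "a"])
    b.add(np.log(0.3), True, ["<START>"])
    b.add(np.log(0.1), False, ["<START>", "b"])  # lowest log probability, pruned
    # two entries remain; sorting works because each item is a comparable tuple
    return sorted(b)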

def single_beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                      start_token="<START>", end_token="<END>", eps=1E-9):
    """
    From http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html
    "probabilities_function" returns a list of (next_prob, next_word) pairs given a prefix.
    "beam_width" is the number of prefixes to keep (so that instead of keeping the top 10 prefixes you can keep the top 100, for example).
    Widening the beam gets you closer to the actual most probable sentence, but also takes longer to process.
    "clip_len" is a maximum length to tolerate, beyond which the most probable prefix is returned as an incomplete sentence.
    Without a maximum length, a faulty probabilities function which does not return a highly probable end token
    will lead to an infinite loop or excessively long garbage sentences.
    """
    prev_beam = Beam(beam_width)
    #prev_beam.add(1.0, False, [start_token])
    prev_beam.add(0.0, False, [start_token])
    while True:
        curr_beam = Beam(beam_width)
        # Carry over complete sentences that are not yet the most probable; extend the incomplete ones with more words.
        for (prefix_prob, complete, prefix) in prev_beam:
            if complete == True:
                curr_beam.add(prefix_prob, True, prefix)
            else:
                # Get probability of each possible next word
                for (next_prob, next_word) in probabilities_function(prefix):
                    if next_word == end_token:
                        # If next word is the end token then mark prefix as complete and leave out the end token
                        #curr_beam.add(prefix_prob * next_prob, True, prefix)
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), True, prefix)
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), True, prefix)
                    else:
                        # If next word is not the end token then mark the extended prefix as incomplete
                        #curr_beam.add(prefix_prob * next_prob, False, prefix + [next_word])
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), False, prefix + [next_word])
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), False, prefix + [next_word])
        (best_prob, best_complete, best_prefix) = max(curr_beam)
        if best_complete == True or len(best_prefix) - 1 == clip_len:
            # If the most probable prefix is a complete sentence or has reached the clip
            # length (ignoring the start token) then return it without the start token,
            # together with its log probability
            return (best_prefix[1:], best_prob)
        prev_beam = curr_beam
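
# A minimal usage sketch (not part of the original gist). "toy_pf" is a hypothetical
# probabilities_function over a one-word vocabulary plus the end token; the bigram-style
# tables built in run_experiment below are the fuller example. single_beamsearch returns
# one (sentence, log_probability) pair, with the start and end tokens stripped.
def _example_single_beamsearch():
    def toy_pf(prefix):
        if prefix[-1] == "hello":
            return [(0.1, "hello"), (0.9, "<END>")]
        else:
            return [(0.9, "hello"), (0.1, "<END>")]
    sentence, log_prob = single_beamsearch(toy_pf, beam_width=2, clip_len=10)
    # here sentence == ["hello"], the single most probable complete sentence
    return sentence, log_prob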

def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
               start_token="<START>", end_token="<END>", eps=1E-9):
    """
    From http://geekyisawesome.blogspot.ca/2017/04/getting-top-n-most-probable-sentences.html
    "probabilities_function" returns a list of (next_prob, next_word) pairs given a prefix.
    "beam_width" is the number of prefixes to keep (so that instead of keeping the top 10 prefixes you can keep the top 100, for example).
    Widening the beam gets you closer to the actual most probable sentence, but also takes longer to process.
    "clip_len" is a maximum length to tolerate, beyond which the most probable prefix is yielded as an incomplete sentence.
    Without a maximum length, a faulty probabilities function which does not return a highly probable end token
    will lead to an infinite loop or excessively long garbage sentences.
    """
    prev_beam = Beam(beam_width)
    #prev_beam.add(1.0, False, [start_token])
    prev_beam.add(0.0, False, [start_token])
    while True:
        curr_beam = Beam(beam_width)
        # Carry over complete sentences that are not yet the most probable; extend the incomplete ones with more words.
        for (prefix_prob, complete, prefix) in prev_beam:
            if complete == True:
                curr_beam.add(prefix_prob, True, prefix)
            else:
                # Get probability of each possible next word for the incomplete prefix
                for (next_prob, next_word) in probabilities_function(prefix):
                    if next_word == end_token:
                        # If next word is the end token then mark prefix as complete and leave out the end token
                        #curr_beam.add(prefix_prob * next_prob, True, prefix)
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), True, prefix)
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), True, prefix)
                    else:
                        # If next word is not the end token then mark the extended prefix as incomplete
                        #curr_beam.add(prefix_prob * next_prob, False, prefix + [next_word])
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), False, prefix + [next_word])
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), False, prefix + [next_word])
        # Get all prefixes in the beam sorted by probability
        sorted_beam = sorted(curr_beam)
        any_removals = False
        while True:
            # Get the highest probability prefix
            (best_prob, best_complete, best_prefix) = sorted_beam[-1]
            if best_complete == True or len(best_prefix) - 1 == clip_len:
                # If the most probable prefix is a complete sentence or has reached the clip
                # length (ignoring the start token) then yield it without the start token,
                # along with its log probability
                yield (best_prefix[1:], best_prob)
                sorted_beam.pop()
                any_removals = True
                # If there are no more sentences in the beam then stop checking
                if len(sorted_beam) == 0:
                    break
            else:
                break
        if any_removals == True:
            if len(sorted_beam) == 0:
                break
            else:
                prev_beam = Beam(beam_width, sorted_beam)
        else:
            prev_beam = curr_beam
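
# A minimal sketch (not part of the original gist) of the generator interface, using the
# same hypothetical toy probabilities_function idea as above: each next() call yields the
# next most probable complete sentence as a (sentence, log_probability) pair, and the
# generator stops once the beam has been exhausted.
def _example_beamsearch_generator():
    def toy_pf(prefix):
        if prefix[-1] == "hello":
            return [(0.1, "hello"), (0.9, "<END>")]
        else:
            return [(0.9, "hello"), (0.1, "<END>")]
    gen = beamsearch(toy_pf, beam_width=2, clip_len=10)
    best = next(gen)     # most probable complete sentence, here (["hello"], log(0.81))
    second = next(gen)   # next best surviving the width-2 beam, here the empty sentence
    return best, second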

def run_experiment(sentence):
    random_state = np.random.RandomState(2147)
    pairs = [("cat", "meow"), ("dog", "bark"), ("cow", "moo"), ("horse", "neigh")]
    sentences = [sentence.format(n, v) for n, v in pairs]
    word_list = []
    for se in sentences:
        word_list.extend(se.split(" "))
    word_list = sorted(list(set(word_list)))
    word2idx = {k: v for v, k in enumerate(word_list)}
    idx2word = {v: k for k, v in word2idx.items()}

    # transition probability from current word to next word
    single_markov_probs = np.zeros((len(word_list), len(word_list)))
    for se in sentences:
        # sentence as a list of words
        sw = se.split(" ")
        # current word, next word
        for cw, nw in zip(sw[:-1], sw[1:]):
            ci = word2idx[cw]
            ni = word2idx[nw]
            # add to count
            single_markov_probs[ci, ni] += 1
    # make it a true probability matrix = each row sums to 1
    for n in range(len(single_markov_probs)):
        sn = sum(single_markov_probs[n])
        if sn > 0.:
            single_markov_probs[n] /= sn
        else:
            # skip rows for words which are never transitioned from
            continue

    # 'sample' a sentence
    sampled = ["<START>"]
    while sampled[-1] != "<END>":
        ci = word2idx[sampled[-1]]
        next_probs = single_markov_probs[ci]
        # Logic to avoid sampling from truly 0. probabilities,
        # needed due to floating point math and numpy sampling
        next_probs_lu = np.where(next_probs > 0.)[0]
        next_probs = next_probs[next_probs_lu]
        # deterministic
        ni_lu = next_probs.argmax()
        # random
        #ni_lu = random_state.multinomial(1, next_probs).argmax()
        ni = next_probs_lu[ni_lu]
        nw = idx2word[ni]
        sampled.append(nw)
    print("Greedy with history 1: {}".format(sampled))

    def pf(prefix):
        ci = word2idx[prefix[-1]]
        probs = single_markov_probs[ci]
        words = [idx2word[n] for n in range(len(probs))]
        return [(probs[i], words[i]) for i in range(len(probs))]

    # Now beam search on history 1, returns (seq, log_prob_of_seq)
    """
    r = single_beamsearch(pf, beam_width=3)
    beamsearched = r[0]
    print("Beamsearched with history 1: {}".format(["<START>"] + beamsearched + ["<END>"]))
    """
    bw = 3
    bs = beamsearch(pf, beam_width=bw)
    beamsearched = next(bs)[0]
    print("Beamsearched with history 1: {}".format(["<START>"] + beamsearched + ["<END>"]))

    # now do the same experiment with a history of 2
    token_list = [(w1, w2) for w1 in word_list for w2 in word_list]
    # pad with an extra start/end token so every word has a length-2 history
    sentences_2w = ["<START> " + se + " <END>" for se in sentences]
    token2idx = {k: v for v, k in enumerate(token_list)}
    idx2token = {v: k for k, v in token2idx.items()}

    # transition probability from current token (word pair) to next word
    second_markov_probs = np.zeros((len(token_list), len(word_list)))
    for se_2w in sentences_2w:
        # sentence as a list of words
        sw = se_2w.split(" ")
        # previous word, current word, next word
        for pw, cw, nw in zip(sw[:-2], sw[1:-1], sw[2:]):
            ci = token2idx[(pw, cw)]
            ni = word2idx[nw]
            # add to count
            second_markov_probs[ci, ni] += 1
    # make it a true probability matrix = each row sums to 1
    for n in range(len(second_markov_probs)):
        sn = sum(second_markov_probs[n])
        if sn > 0.:
            second_markov_probs[n] /= sn
        else:
            # skip rows for tokens which are never transitioned from
            continue

    sampled_2w = ["<START>", "<START>"]
    while sampled_2w[-1] != "<END>":
        ci = token2idx[(sampled_2w[-2], sampled_2w[-1])]
        next_probs = second_markov_probs[ci]
        # Logic to avoid sampling from truly 0. probabilities,
        # needed due to floating point math and numpy sampling
        next_probs_lu = np.where(next_probs > 0.)[0]
        next_probs = next_probs[next_probs_lu]
        # deterministic
        ni_lu = next_probs.argmax()
        # random
        #ni_lu = random_state.multinomial(1, next_probs).argmax()
        ni = next_probs_lu[ni_lu]
        nw = idx2word[ni]
        sampled_2w.append(nw)
    # Skip the extra <START> we appended above
    print("Greedy with history 2: {}".format(sampled_2w[1:]))


# toy examples
# <START> The {cat, dog, cow, horse} may {meow, bark, moo, neigh} nightly <END>
print("")
sentence1 = "<START> The {} may {} nightly <END>"
print("Experiment 1: {}".format(sentence1))
run_experiment(sentence1)

print("")
sentence2 = "<START> The {} may sometimes {} nightly <END>"
print("Experiment 2: {}".format(sentence2))
run_experiment(sentence2)

print("")
sentence3 = "<START> The {} may sometimes very occasionally rarely but periodically {} nightly <END>"
print("Experiment 3: {}".format(sentence3))
run_experiment(sentence3)
print("")