Not-so-minimal anymore (check the early commits) beam search example
# Author: Kyle Kastner
# License: BSD 3-Clause
# See core implementations here http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html
import numpy as np
import heapq

class Beam(object):
    """
    From http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html
    For comparison of prefixes, the tuple (prefix_probability, complete_sentence) is used.
    This is so that if two prefixes have equal probabilities then a complete sentence is preferred
    over an incomplete one since (0.5, False, whatever_prefix) < (0.5, True, some_other_prefix)
    """
    def __init__(self, beam_width, init_beam=None):
        if init_beam is None:
            self.heap = list()
        else:
            self.heap = init_beam
        heapq.heapify(self.heap)
        self.beam_width = beam_width

    def add(self, prob, complete, prefix):
        heapq.heappush(self.heap, (prob, complete, prefix))
        if len(self.heap) > self.beam_width:
            # remove lowest probability from heap
            heapq.heappop(self.heap)

    def __iter__(self):
        return iter(self.heap)
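
# (Added illustration, not part of the original gist.) A minimal sketch of how the Beam
# container behaves: it is a bounded min-heap, so adding more than `beam_width` entries
# silently drops the lowest-probability one. The helper below is never called by the
# experiments; its name `_beam_usage_sketch` is made up for this example.
def _beam_usage_sketch():
    b = Beam(2)
    b.add(-0.5, False, ["<START>", "a"])
    b.add(-1.0, False, ["<START>", "the"])
    b.add(-2.0, False, ["<START>", "an"])
    # only the two highest log probabilities survive: -1.0 and -0.5
    return sorted(b)
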
def single_beamsearch(probabilities_function, beam_width=10, clip_len=-1,
                      start_token="<START>", end_token="<END>", eps=1E-9):
    """
    From http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html
    "probabilities_function" returns a list of (next_prob, next_word) pairs given a prefix.
    "beam_width" is the number of prefixes to keep (so instead of keeping the top 10 prefixes you can keep the top 100, for example).
    A wider beam gets closer to the actual most probable sentence, but also takes longer to process.
    "clip_len" is a maximum length to tolerate, beyond which the most probable prefix is returned as an incomplete sentence.
    Without a maximum length, a faulty probabilities function which does not return a highly probable end token
    will lead to an infinite loop or excessively long garbage sentences.
    """
    prev_beam = Beam(beam_width)
    # probabilities are accumulated in log space, so the empty prefix starts at log(1.0) = 0.0
    #prev_beam.add(1.0, False, [start_token])
    prev_beam.add(0.0, False, [start_token])
    while True:
        curr_beam = Beam(beam_width)
        # Add complete sentences that do not yet have the best probability to the current beam,
        # the rest prepare to add more words to them.
        for (prefix_prob, complete, prefix) in prev_beam:
            if complete == True:
                curr_beam.add(prefix_prob, True, prefix)
            else:
                # Get probability of each possible next word
                for (next_prob, next_word) in probabilities_function(prefix):
                    if next_word == end_token:
                        # If next word is the end token then mark prefix as complete and leave out the end token
                        #curr_beam.add(prefix_prob * next_prob, True, prefix)
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), True, prefix)
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), True, prefix)
                    else:
                        # Otherwise extend the prefix with the next word and keep it marked as incomplete
                        #curr_beam.add(prefix_prob * next_prob, False, prefix + [next_word])
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), False, prefix + [next_word])
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), False, prefix + [next_word])
        (best_prob, best_complete, best_prefix) = max(curr_beam)
        if best_complete == True or len(best_prefix) - 1 == clip_len:
            # If the most probable prefix is a complete sentence, or its length (ignoring the
            # start token) has reached the clip length, return it without the start token
            # together with its log probability
            return (best_prefix[1:], best_prob)
        prev_beam = curr_beam
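
# (Added illustration, not part of the original gist.) A sketch of the interface that
# "probabilities_function" is expected to have: given a prefix (a list of words, starting
# with the start token), return (probability, word) pairs for the next word. The toy bigram
# table and the helper name `_single_beamsearch_sketch` are made up for this example.
def _single_beamsearch_sketch():
    toy_table = {
        "<START>": [(0.9, "hello"), (0.1, "<END>")],
        "hello": [(0.6, "world"), (0.4, "<END>")],
        "world": [(1.0, "<END>")],
    }
    def toy_pf(prefix):
        # only the last word matters for this toy first-order model
        return toy_table[prefix[-1]]
    # should return (["hello", "world"], log(0.9 * 0.6)); clip_len bounds the search in case
    # the end token never becomes probable
    return single_beamsearch(toy_pf, beam_width=2, clip_len=10)
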
def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
               start_token="<START>", end_token="<END>", eps=1E-9):
    """
    From http://geekyisawesome.blogspot.ca/2017/04/getting-top-n-most-probable-sentences.html
    "probabilities_function" returns a list of (next_prob, next_word) pairs given a prefix.
    "beam_width" is the number of prefixes to keep (so instead of keeping the top 10 prefixes you can keep the top 100, for example).
    A wider beam gets closer to the actual most probable sentence, but also takes longer to process.
    "clip_len" is a maximum length to tolerate, beyond which the most probable prefix is returned as an incomplete sentence.
    Without a maximum length, a faulty probabilities function which does not return a highly probable end token
    will lead to an infinite loop or excessively long garbage sentences.
    """
    prev_beam = Beam(beam_width)
    # probabilities are accumulated in log space, so the empty prefix starts at log(1.0) = 0.0
    #prev_beam.add(1.0, False, [start_token])
    prev_beam.add(0.0, False, [start_token])
    while True:
        curr_beam = Beam(beam_width)
        # Add complete sentences that do not yet have the best probability to the current beam,
        # the rest prepare to add more words to them.
        for (prefix_prob, complete, prefix) in prev_beam:
            if complete == True:
                curr_beam.add(prefix_prob, True, prefix)
            else:
                # Get probability of each possible next word for the incomplete prefix
                for (next_prob, next_word) in probabilities_function(prefix):
                    if next_word == end_token:
                        # If next word is the end token then mark prefix as complete and leave out the end token
                        #curr_beam.add(prefix_prob * next_prob, True, prefix)
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), True, prefix)
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), True, prefix)
                    else:
                        # Otherwise extend the prefix with the next word and keep it marked as incomplete
                        #curr_beam.add(prefix_prob * next_prob, False, prefix + [next_word])
                        if next_prob > eps:
                            curr_beam.add(prefix_prob + np.log(next_prob), False, prefix + [next_word])
                        else:
                            curr_beam.add(prefix_prob + np.log(eps), False, prefix + [next_word])
        # Get all prefixes in the beam sorted by probability
        sorted_beam = sorted(curr_beam)
        any_removals = False
        while True:
            # Get the highest probability prefix
            (best_prob, best_complete, best_prefix) = sorted_beam[-1]
            if best_complete == True or len(best_prefix) - 1 == clip_len:
                # If the most probable prefix is a complete sentence, or its length (ignoring
                # the start token) has reached the clip length, yield it without the start
                # token along with its log probability
                yield (best_prefix[1:], best_prob)
                sorted_beam.pop()
                any_removals = True
                # If there are no more sentences in the beam then stop checking
                if len(sorted_beam) == 0:
                    break
            else:
                break
        if any_removals == True:
            if len(sorted_beam) == 0:
                break
            else:
                prev_beam = Beam(beam_width, sorted_beam)
        else:
            prev_beam = curr_beam
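
# (Added illustration, not part of the original gist.) Unlike single_beamsearch, beamsearch
# is a generator, so the top-N most probable complete sentences can be pulled off lazily,
# best first. The helper name `_beamsearch_topn_sketch` is made up; `toy_pf` is any
# probabilities function with the interface sketched above.
def _beamsearch_topn_sketch(toy_pf, n=3):
    gen = beamsearch(toy_pf, beam_width=10, clip_len=20)
    results = []
    for _ in range(n):
        try:
            results.append(next(gen))
        except StopIteration:
            # the beam can run dry before n complete sentences are produced
            break
    return results
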
def run_experiment(sentence):
    random_state = np.random.RandomState(2147)
    pairs = [("cat", "meow"), ("dog", "bark"), ("cow", "moo"), ("horse", "neigh")]
    sentences = [sentence.format(n, v) for n, v in pairs]
    word_list = []
    for se in sentences:
        word_list.extend(se.split(" "))
    word_list = sorted(list(set(word_list)))
    word2idx = {k: v for v, k in enumerate(word_list)}
    idx2word = {v: k for k, v in word2idx.items()}
    # transition probability from current word to next word
    single_markov_probs = np.zeros((len(word_list), len(word_list)))
    for se in sentences:
        # sentence as list of words
        sw = se.split(" ")
        # current word, next word
        for cw, nw in zip(sw[:-1], sw[1:]):
            ci = word2idx[cw]
            ni = word2idx[nw]
            # add to count
            single_markov_probs[ci, ni] += 1
    # make it a true probability matrix = each row sums to 1
    for n in range(len(single_markov_probs)):
        sn = sum(single_markov_probs[n])
        if sn > 0.:
            single_markov_probs[n] /= sn
        else:
            # skip for things which are never started from
            continue
    # 'sample' a sentence
    sampled = ["<START>"]
    while sampled[-1] != "<END>":
        ci = word2idx[sampled[-1]]
        next_probs = single_markov_probs[ci]
        # Logic to avoid sampling from truly 0. probabilities
        # needed due to floating point math and numpy sampling
        next_probs_lu = np.where(next_probs > 0.)[0]
        next_probs = next_probs[next_probs_lu]
        # deterministic
        ni_lu = next_probs.argmax()
        # random
        #ni_lu = random_state.multinomial(1, next_probs).argmax()
        ni = next_probs_lu[ni_lu]
        nw = idx2word[ni]
        sampled.append(nw)
    print("Greedy with history 1: {}".format(sampled))

    def pf(prefix):
        ci = word2idx[prefix[-1]]
        probs = single_markov_probs[ci]
        words = [idx2word[n] for n in range(len(probs))]
        return [(probs[i], words[i]) for i in range(len(probs))]

    # Now beam search on history 1, returns (seq, prob_of_seq)
    """
    r = single_beamsearch(pf, beam_width=3)
    beamsearched = r[0]
    print("Beamsearched with history 1: {}".format(["<START>"] + beamsearched + ["<END>"]))
    """
    bw = 3
    bs = beamsearch(pf, beam_width=bw)
    beamsearched = next(bs)[0]
    print("Beamsearched with history 1: {}".format(["<START>"] + beamsearched + ["<END>"]))
    # now do the same experiment for history of 2
    token_list = [(w1, w2) for w1 in word_list for w2 in word_list]
    sentences_2w = ["<START> " + se + " <END>" for se in sentences]
    token2idx = {k: v for v, k in enumerate(token_list)}
    idx2token = {v: k for k, v in token2idx.items()}
    # transition probability from current token (pair of words) to next word
    second_markov_probs = np.zeros((len(token_list), len(word_list)))
    for se_2w in sentences_2w:
        # sentence as list of words
        sw = se_2w.split(" ")
        # previous word, current word, next word
        for pw, cw, nw in zip(sw[:-2], sw[1:-1], sw[2:]):
            ci = token2idx[(pw, cw)]
            ni = word2idx[nw]
            # add to count
            second_markov_probs[ci, ni] += 1
    # make it a true probability matrix = each row sums to 1
    for n in range(len(second_markov_probs)):
        sn = sum(second_markov_probs[n])
        if sn > 0.:
            second_markov_probs[n] /= sn
        else:
            # skip for things which are never started from
            continue
    sampled_2w = ["<START>", "<START>"]
    while sampled_2w[-1] != "<END>":
        ci = token2idx[(sampled_2w[-2], sampled_2w[-1])]
        next_probs = second_markov_probs[ci]
        # Logic to avoid sampling from truly 0. probabilities
        # needed due to floating point math and numpy sampling
        next_probs_lu = np.where(next_probs > 0.)[0]
        next_probs = next_probs[next_probs_lu]
        # deterministic
        ni_lu = next_probs.argmax()
        # random
        #ni_lu = random_state.multinomial(1, next_probs).argmax()
        ni = next_probs_lu[ni_lu]
        nw = idx2word[ni]
        sampled_2w.append(nw)
    # Skip the extra <START> we appended above
    print("Greedy with history 2: {}".format(sampled_2w[1:]))
# toy examples
# <START> The {cat, dog, cow, horse} may {meow, bark, moo, neigh} nightly <END>
print("")
sentence1 = "<START> The {} may {} nightly <END>"
print("Experiment 1: {}".format(sentence1))
run_experiment(sentence1)
print("")
sentence2 = "<START> The {} may sometimes {} nightly <END>"
print("Experiment 2: {}".format(sentence2))
run_experiment(sentence2)
print("")
sentence3 = "<START> The {} may sometimes very occasionally rarely but periodically {} nightly <END>"
print("Experiment 3: {}".format(sentence3))
run_experiment(sentence3)
print("")