
@nthon
Forked from udibr/beamsearch.py
Created November 2, 2020 13:01

Revisions

  1. @udibr revised this gist Mar 23, 2016. 1 changed file with 1 addition and 1 deletion.
    beamsearch.py: 1 addition & 1 deletion
    @@ -26,7 +26,7 @@ def beamsearch(predict=keras_rnn_predict,
             # for every possible live sample calc prob for every possible label
             probs = predict(live_samples, empty=empty)

    -        # total score for every hyp is sum of -log of word prob
    +        # total score for every sample is sum of -log of word prob
             cand_scores = np.array(live_scores)[:,None] - np.log(probs)
             if not use_unk and oov is not None:
                 cand_scores[:,oov] = 1e20
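
    The comment reworded in this revision describes the scoring step: each candidate's score is its parent beam's cumulative negative log-likelihood plus -log of the next label's probability, so lower is better. A minimal sketch of that arithmetic (the probabilities below are made up for illustration):

    import numpy as np

    live_scores = [0.5, 1.2]                      # cumulative NLL of two live beams
    probs = np.array([[0.7, 0.2, 0.1],            # next-label distribution per beam
                      [0.1, 0.3, 0.6]])
    cand_scores = np.array(live_scores)[:, None] - np.log(probs)
    print(cand_scores)                            # shape (2, 3): one score per (beam, label) pair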
  2. @udibr revised this gist Mar 23, 2016. 1 changed file with 31 additions and 1 deletion.
    beamsearch.py: 31 additions & 1 deletion
    @@ -1,4 +1,34 @@
    -if not use_unk and oov is not None:
    +# variation to https://github.com/ryankiros/skip-thoughts/blob/master/decoding/search.py
    +
    +def keras_rnn_predict(samples, empty=empty, rnn_model=model, maxlen=maxlen):
    +    """for every sample, calculate the probability of every possible label;
    +    you need to supply your RNN model and maxlen - the length of sequences it can handle
    +    """
    +    data = sequence.pad_sequences(samples, maxlen=maxlen, value=empty)
    +    return rnn_model.predict(data, verbose=0)
    +
    +def beamsearch(predict=keras_rnn_predict,
    +               k=1, maxsample=400, use_unk=False, oov=oov, empty=empty, eos=eos):
    +    """return k samples (beams) and their NLL scores; each sample is a sequence of labels.
    +    All samples start with an `empty` label and end with `eos` or are truncated to a length of `maxsample`.
    +    You need to supply `predict`, which returns the label probability of each sample.
    +    `use_unk` allows usage of the `oov` (out-of-vocabulary) label in samples
    +    """
    +
    +    dead_k = 0  # samples that reached eos
    +    dead_samples = []
    +    dead_scores = []
    +    live_k = 1  # samples that did not yet reach eos
    +    live_samples = [[empty]]
    +    live_scores = [0]
    +
    +    while live_k and dead_k < k:
    +        # for every possible live sample calc prob for every possible label
    +        probs = predict(live_samples, empty=empty)
    +
    +        # total score for every hyp is sum of -log of word prob
    +        cand_scores = np.array(live_scores)[:,None] - np.log(probs)
    +        if not use_unk and oov is not None:
                 cand_scores[:,oov] = 1e20
             cand_flat = cand_scores.flatten()

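    This revision adds the full decoding driver and a Keras-backed `predict`. A minimal usage sketch under stated assumptions: the final `beamsearch` from this gist is in scope, and a made-up `dummy_predict` stands in for `keras_rnn_predict` so nothing depends on a trained model; the `empty`, `eos`, and `oov` label ids below are arbitrary:

    import numpy as np

    empty, eos, oov = 0, 1, 2        # made-up special label ids
    voc_size = 5

    def dummy_predict(samples, empty=empty):
        # same contract as keras_rnn_predict: one row of label probabilities per live sample
        probs = np.full((len(samples), voc_size), 0.1)
        probs[:, eos] = 0.6          # bias toward eos so the search terminates quickly
        return probs / probs.sum(axis=1, keepdims=True)

    samples, scores = beamsearch(predict=dummy_predict, k=3,
                                 oov=oov, empty=empty, eos=eos)
    for s, sc in sorted(zip(samples, scores), key=lambda t: t[1]):
        print(sc, s)                 # lowest NLL first; every sample starts with `empty`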
  3. @udibr created this gist Mar 23, 2016.
    beamsearch.py: 25 additions & 0 deletions
    @@ -0,0 +1,25 @@
    +        if not use_unk and oov is not None:
    +            cand_scores[:,oov] = 1e20
    +        cand_flat = cand_scores.flatten()
    +
    +        # find the best (lowest) scores we have from all possible samples and new words
    +        ranks_flat = cand_flat.argsort()[:(k-dead_k)]
    +        live_scores = cand_flat[ranks_flat]
    +
    +        # append the new words to their appropriate live sample
    +        voc_size = probs.shape[1]
    +        live_samples = [live_samples[r//voc_size]+[r%voc_size] for r in ranks_flat]
    +
    +        # live samples that should be dead are...
    +        zombie = [s[-1] == eos or len(s) >= maxsample for s in live_samples]
    +
    +        # add zombies to the dead
    +        dead_samples += [s for s,z in zip(live_samples,zombie) if z]  # remove first label == empty
    +        dead_scores += [s for s,z in zip(live_scores,zombie) if z]
    +        dead_k = len(dead_samples)
    +        # remove zombies from the living
    +        live_samples = [s for s,z in zip(live_samples,zombie) if not z]
    +        live_scores = [s for s,z in zip(live_scores,zombie) if not z]
    +        live_k = len(live_samples)
    +
    +    return dead_samples + live_samples, dead_scores + live_scores
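
    The heart of this first version is the flatten-and-argsort trick: scores for all (beam, label) pairs are laid out in one flat array, the k best flat indices are taken, and integer division and modulo by the vocabulary size recover which beam to extend and with which label. A standalone sketch (array values made up for illustration):

    import numpy as np

    # two live beams, vocabulary of 3 labels
    cand_scores = np.array([[2.1, 0.3, 5.0],
                            [0.9, 4.2, 0.7]])
    voc_size = cand_scores.shape[1]
    ranks_flat = cand_scores.flatten().argsort()[:2]   # flat indices of the 2 best (lowest) scores
    for r in ranks_flat:
        print("extend beam", r // voc_size, "with label", r % voc_size)
    # extend beam 0 with label 1   (score 0.3)
    # extend beam 1 with label 2   (score 0.7)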