-
-
Save nthon/1ee5f115e5580e2f2ed57fdd094ab3c3 to your computer and use it in GitHub Desktop.
Revisions
-
udibr revised this gist
Mar 23, 2016. 1 changed file with 1 addition and 1 deletion. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -26,7 +26,7 @@ def beamsearch(predict=keras_rnn_predict, # for every possible live sample calc prob for every possible label probs = predict(live_samples, empty=empty) # total score for every sample is sum of -log of word prb cand_scores = np.array(live_scores)[:,None] - np.log(probs) if not use_unk and oov is not None: cand_scores[:,oov] = 1e20 -
udibr revised this gist
Mar 23, 2016. 1 changed file with 31 additions and 1 deletion. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,4 +1,34 @@ # variation to https://github.com/ryankiros/skip-thoughts/blob/master/decoding/search.py def keras_rnn_predict(samples, empty=empty, rnn_model=model, maxlen=maxlen): """for every sample, calculate probability for every possible label you need to supply your RNN model and maxlen - the length of sequences it can handle """ data = sequence.pad_sequences(samples, maxlen=maxlen, value=empty) return rnn_model.predict(data, verbose=0) def beamsearch(predict=keras_rnn_predict, k=1, maxsample=400, use_unk=False, oov=oov, empty=empty, eos=eos): """return k samples (beams) and their NLL scores, each sample is a sequence of labels, all samples starts with an `empty` label and end with `eos` or truncated to length of `maxsample`. You need to supply `predict` which returns the label probability of each sample. `use_unk` allow usage of `oov` (out-of-vocabulary) label in samples """ dead_k = 0 # samples that reached eos dead_samples = [] dead_scores = [] live_k = 1 # samples that did not yet reached eos live_samples = [[empty]] live_scores = [0] while live_k and dead_k < k: # for every possible live sample calc prob for every possible label probs = predict(live_samples, empty=empty) # total score for every hyp is sum of -log of word prb cand_scores = np.array(live_scores)[:,None] - np.log(probs) if not use_unk and oov is not None: cand_scores[:,oov] = 1e20 cand_flat = cand_scores.flatten() -
udibr created this gist
Mar 23, 2016. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,25 @@ if not use_unk and oov is not None: cand_scores[:,oov] = 1e20 cand_flat = cand_scores.flatten() # find the best (lowest) scores we have from all possible samples and new words ranks_flat = cand_flat.argsort()[:(k-dead_k)] live_scores = cand_flat[ranks_flat] # append the new words to their appropriate live sample voc_size = probs.shape[1] live_samples = [live_samples[r//voc_size]+[r%voc_size] for r in ranks_flat] # live samples that should be dead are... zombie = [s[-1] == eos or len(s) >= maxsample for s in live_samples] # add zombies to the dead dead_samples += [s for s,z in zip(live_samples,zombie) if z] # remove first label == empty dead_scores += [s for s,z in zip(live_scores,zombie) if z] dead_k = len(dead_samples) # remove zombies from the living live_samples = [s for s,z in zip(live_samples,zombie) if not z] live_scores = [s for s,z in zip(live_scores,zombie) if not z] live_k = len(live_samples) return dead_samples + live_samples, dead_scores + live_scores