Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save phuysmans/8049d8165905e63e22fa2d539b64bfd9 to your computer and use it in GitHub Desktop.
Save phuysmans/8049d8165905e63e22fa2d539b64bfd9 to your computer and use it in GitHub Desktop.

Revisions

  1. @thmavri thmavri revised this gist Dec 13, 2016. No changes.
  2. @thmavri thmavri revised this gist Dec 13, 2016. No changes.
  3. @thmavri thmavri created this gist Dec 13, 2016.
    50 changes: 50 additions & 0 deletions Learning2Search_Named_Entity_Classification_sample.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,50 @@
    #determine the labels
    import pyvw #vw python interface

    DEST = 1
    PROP = 2
    FAC = 3

    ...
    #create the class for the Sequence Labeler
    class SequenceLabeler(pyvw.SearchTask):
    def __init__(self, vw, sch, num_actions):
    # you must must must initialize the parent class
    # this will automatically store self.sch <- sch, self.vw <- vw
    pyvw.SearchTask.__init__(self, vw, sch, num_actions)

    # set whatever options you want
    sch.set_options( sch.AUTO_HAMMING_LOSS | sch.AUTO_CONDITION_FEATURES )

    def _run(self, sentence): # it's called _run to remind you that you shouldn't call it directly!
    output = []
    for n in range(len(sentence)):
    pos,word = sentence[n]
    # use "with...as..." to guarantee that the example is finished properly
    with self.vw.example({'w': [word]}) as ex:
    pred = self.sch.predict(examples=ex, my_tag=n+1, oracle=pos, condition=[(n,'p'), (n-1, 'q')])
    output.append(pred)
    return output
    ...
    #build the training set
    ...
    ...
    ...
    my_dataset #training set
    test_set

    #train commands
    vw = pyvw.vw("--search 3 --search_task hook --ring_size 1024") # 3 is the number of labels
    sequenceLabeler = vw.init_search_task(SequenceLabeler)

    #actual training
    for i in xrange(2):
    sequenceLabeler.learn(my_dataset)

    #predict
    test_example = [ (0,w) for w in "hotel amsterdam wifi".split() ]
    print test_example
    #[(0, 'hotel'), (0, 'amsterdam'), (0, 'wifi')]
    out = sequenceLabeler.predict(test_example)
    print out
    [2, 1, 3]