Skip to content

Instantly share code, notes, and snippets.

@y-lan
Forked from stucchio/beta_bandit.py
Created March 24, 2016 06:42
Show Gist options
  • Select an option

  • Save y-lan/d3aad8630bfbf4b79485 to your computer and use it in GitHub Desktop.

Select an option

Save y-lan/d3aad8630bfbf4b79485 to your computer and use it in GitHub Desktop.

Revisions

  1. Chris Stucchio revised this gist Apr 14, 2013. 1 changed file with 52 additions and 0 deletions.
    52 changes: 52 additions & 0 deletions beta_bandit_test.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,52 @@
    from beta_bandit import *

    from numpy import *
    from scipy.stats import beta
    import random

    theta = (0.25, 0.35)

    def is_conversion(title):
    if random.random() < theta[title]:
    return True
    else:
    return False

    conversions = [0,0]
    trials = [0,0]

    N = 100000
    trials = zeros(shape=(N,2))
    successes = zeros(shape=(N,2))

    bb = BetaBandit()
    for i in range(N):
    choice = bb.get_recommendation()
    trials[choice] = trials[choice]+1
    conv = is_conversion(choice)
    bb.add_result(choice, conv)

    trials[i] = bb.trials
    successes[i] = bb.successes

    from pylab import *
    subplot(211)
    n = arange(N)+1
    loglog(n, trials[:,0], label="title 0")
    loglog(n, trials[:,1], label="title 1")
    legend()
    xlabel("Number of trials")
    ylabel("Number of trials/title")

    subplot(212)
    semilogx(n, (successes[:,0]+successes[:,1])/n, label="CTR")
    semilogx(n, zeros(shape=(N,))+0.35, label="Best CTR")
    semilogx(n, zeros(shape=(N,))+0.30, label="Random chance CTR")
    semilogx(n, zeros(shape=(N,))+0.25, label="Worst CTR")
    axis([0,N,0.15,0.45])
    xlabel("Number of trials")
    ylabel("CTR")


    legend()
    show()
  2. Chris Stucchio renamed this gist Apr 14, 2013. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes.
  3. Chris Stucchio revised this gist Apr 14, 2013. 1 changed file with 6 additions and 3 deletions.
    9 changes: 6 additions & 3 deletions gistfile1.py
    Original file line number Diff line number Diff line change
    @@ -17,7 +17,10 @@ def add_result(self, trial_id, success):
    def get_recommendation(self):
    sampled_theta = []
    for i in range(self.num_options):
    dist = beta(self.prior[0]+self.successes[i], self.prior[1]+self.trials[i]-self.successes[i]) #Construct beta distribution for posterior
    sampled_theta += [ dist.rvs() ] #Draw sample from beta distribution

    #Construct beta distribution for posterior
    dist = beta(self.prior[0]+self.successes[i],
    self.prior[1]+self.trials[i]-self.successes[i])
    #Draw sample from beta distribution
    sampled_theta += [ dist.rvs() ]
    # Return the index of the sample with the largest value
    return sampled_theta.index( max(sampled_theta) )
  4. Chris Stucchio created this gist Apr 14, 2013.
    23 changes: 23 additions & 0 deletions gistfile1.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,23 @@
    from numpy import *
    from scipy.stats import beta


    class BetaBandit(object):
    def __init__(self, num_options=2, prior=(1.0,1.0)):
    self.trials = zeros(shape=(num_options,), dtype=int)
    self.successes = zeros(shape=(num_options,), dtype=int)
    self.num_options = num_options
    self.prior = prior

    def add_result(self, trial_id, success):
    self.trials[trial_id] = self.trials[trial_id] + 1
    if (success):
    self.successes[trial_id] = self.successes[trial_id] + 1

    def get_recommendation(self):
    sampled_theta = []
    for i in range(self.num_options):
    dist = beta(self.prior[0]+self.successes[i], self.prior[1]+self.trials[i]-self.successes[i]) #Construct beta distribution for posterior
    sampled_theta += [ dist.rvs() ] #Draw sample from beta distribution

    return sampled_theta.index( max(sampled_theta) )