-
-
Save y-lan/d3aad8630bfbf4b79485 to your computer and use it in GitHub Desktop.
Revisions
-
Chris Stucchio revised this gist
Apr 14, 2013 . 1 changed file with 52 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,52 @@ from beta_bandit import * from numpy import * from scipy.stats import beta import random theta = (0.25, 0.35) def is_conversion(title): if random.random() < theta[title]: return True else: return False conversions = [0,0] trials = [0,0] N = 100000 trials = zeros(shape=(N,2)) successes = zeros(shape=(N,2)) bb = BetaBandit() for i in range(N): choice = bb.get_recommendation() trials[choice] = trials[choice]+1 conv = is_conversion(choice) bb.add_result(choice, conv) trials[i] = bb.trials successes[i] = bb.successes from pylab import * subplot(211) n = arange(N)+1 loglog(n, trials[:,0], label="title 0") loglog(n, trials[:,1], label="title 1") legend() xlabel("Number of trials") ylabel("Number of trials/title") subplot(212) semilogx(n, (successes[:,0]+successes[:,1])/n, label="CTR") semilogx(n, zeros(shape=(N,))+0.35, label="Best CTR") semilogx(n, zeros(shape=(N,))+0.30, label="Random chance CTR") semilogx(n, zeros(shape=(N,))+0.25, label="Worst CTR") axis([0,N,0.15,0.45]) xlabel("Number of trials") ylabel("CTR") legend() show() -
Chris Stucchio renamed this gist
Apr 14, 2013 . 1 changed file with 0 additions and 0 deletions.There are no files selected for viewing
File renamed without changes. -
Chris Stucchio revised this gist
Apr 14, 2013 . 1 changed file with 6 additions and 3 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -17,7 +17,10 @@ def add_result(self, trial_id, success): def get_recommendation(self): sampled_theta = [] for i in range(self.num_options): #Construct beta distribution for posterior dist = beta(self.prior[0]+self.successes[i], self.prior[1]+self.trials[i]-self.successes[i]) #Draw sample from beta distribution sampled_theta += [ dist.rvs() ] # Return the index of the sample with the largest value return sampled_theta.index( max(sampled_theta) ) -
Chris Stucchio created this gist
Apr 14, 2013 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,23 @@ from numpy import * from scipy.stats import beta class BetaBandit(object): def __init__(self, num_options=2, prior=(1.0,1.0)): self.trials = zeros(shape=(num_options,), dtype=int) self.successes = zeros(shape=(num_options,), dtype=int) self.num_options = num_options self.prior = prior def add_result(self, trial_id, success): self.trials[trial_id] = self.trials[trial_id] + 1 if (success): self.successes[trial_id] = self.successes[trial_id] + 1 def get_recommendation(self): sampled_theta = [] for i in range(self.num_options): dist = beta(self.prior[0]+self.successes[i], self.prior[1]+self.trials[i]-self.successes[i]) #Construct beta distribution for posterior sampled_theta += [ dist.rvs() ] #Draw sample from beta distribution return sampled_theta.index( max(sampled_theta) )