Created
December 24, 2015 20:51
-
-
Save fmm/d0a16cfef3c692b4d8fc to your computer and use it in GitHub Desktop.
Revisions
-
Filipe Martins created this gist
Dec 24, 2015 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,75 @@ import math import random # DATA # 1st: Coin B, {HTTTHHTHTH}, 5H,5T # 2nd: Coin A, {HHHHTHHHHH}, 9H,1T # 3rd: Coin A, {HTHHHHHTHH}, 8H,2T # 4th: Coin B, {HTHTTTHHTT}, 4H,6T # 5th: Coin A, {THHHTHHHTH}, 7H,3T sample = [ 'HTTTHHTHTH', # 1st: Coin B, 5H,5T 'HHHHTHHHHH', # 2nd: Coin A, 9H,1T 'HTHHHHHTHH', # 3rd: Coin A, 8H,2T 'HTHTTTHHTT', # 4th: Coin B, 4H,6T 'THHHTHHHTH', # 5th: Coin A, 7H,3T ]; def generate(n, pa, pb): data = [] for i in xrange(n): prob = pa if random.random() <= 0.5 else pb s = "" for j in xrange(10): if random.random() <= prob: s += "H" else: s += "T" data.append(s) return data def extract(tosses): output = [] for t in tosses: cnt = [0,0] for ch in t: if ch == 'H': cnt[0] += 1 else: cnt[1] += 1 output.append(cnt) return output def e_step(tosses, pa, pb): output = [] for t in tosses: ta = (pa)**t[0] * (1-pa)**t[1] tb = (pb)**t[0] * (1-pb)**t[1] z = [ta/(ta+tb), tb/(ta+tb)] #print t, z output.append(z) return output def m_step(tosses, zs): ha, ta = 0, 0 hb, tb = 0, 0 for (t,z) in zip(tosses,zs): ha += t[0]*z[0] ta += t[1]*z[0] hb += t[0]*z[1] tb += t[1]*z[1] #print ha, ta, hb, tb return [ha/(ha+ta), hb/(hb+tb)] def solve(data, pa, pb): tosses = extract(data) for i in xrange(1000): zs = e_step(tosses, pa, pb) npa, npb = m_step(tosses, zs) #print i, npa, npb pa, pb = npa, npb print pa, pb solve(sample, 0.6, 0.5)