Skip to content

Instantly share code, notes, and snippets.

@fmm
Created December 24, 2015 20:51
Show Gist options
  • Select an option

  • Save fmm/d0a16cfef3c692b4d8fc to your computer and use it in GitHub Desktop.

Select an option

Save fmm/d0a16cfef3c692b4d8fc to your computer and use it in GitHub Desktop.

Revisions

  1. Filipe Martins created this gist Dec 24, 2015.
    75 changes: 75 additions & 0 deletions em.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,75 @@
    import math
    import random

    # DATA
    # 1st: Coin B, {HTTTHHTHTH}, 5H,5T
    # 2nd: Coin A, {HHHHTHHHHH}, 9H,1T
    # 3rd: Coin A, {HTHHHHHTHH}, 8H,2T
    # 4th: Coin B, {HTHTTTHHTT}, 4H,6T
    # 5th: Coin A, {THHHTHHHTH}, 7H,3T

    # Observed trials, one string of ten tosses each. The trailing note on
    # every row records which coin produced it and the head/tail tally.
    sample = [
        'HTTTHHTHTH',  # trial 1 -- coin B: 5 heads, 5 tails
        'HHHHTHHHHH',  # trial 2 -- coin A: 9 heads, 1 tail
        'HTHHHHHTHH',  # trial 3 -- coin A: 8 heads, 2 tails
        'HTHTTTHHTT',  # trial 4 -- coin B: 4 heads, 6 tails
        'THHHTHHHTH',  # trial 5 -- coin A: 7 heads, 3 tails
    ]

    def generate(n, pa, pb, length=10):
        """Simulate n trials of tossing one of two biased coins.

        For every trial a coin is chosen with equal probability (coin A with
        heads-probability pa, otherwise coin B with heads-probability pb) and
        that coin is then tossed `length` times.

        Args:
            n: number of trials to simulate.
            pa: probability of heads for coin A.
            pb: probability of heads for coin B.
            length: number of tosses per trial; defaults to 10, the value
                that used to be hard-coded.

        Returns:
            A list of n strings over the alphabet {'H', 'T'}, one per trial.
        """
        data = []
        # range() rather than xrange() so the function runs unchanged on
        # both Python 2 and Python 3.
        for _ in range(n):
            # The coin is drawn first, then its tosses, so the order of
            # random.random() calls matches the original implementation.
            prob = pa if random.random() <= 0.5 else pb
            data.append("".join(
                "H" if random.random() <= prob else "T"
                for _ in range(length)
            ))
        return data

    def extract(tosses):
        """Turn toss sequences into [heads, tails] count pairs.

        Each element of `tosses` is iterated once; every item equal to 'H'
        counts as a head and every other item counts as a tail.

        Returns:
            A list with one [heads, tails] list per input sequence.
        """
        counts = []
        for sequence in tosses:
            # One boolean per toss: True where the toss is a head.
            is_head = [symbol == 'H' for symbol in sequence]
            heads = is_head.count(True)
            counts.append([heads, len(is_head) - heads])
        return counts

    def _log_likelihood(heads, tails, p):
        """Return log(p**heads * (1-p)**tails), or -inf if that product is 0."""
        total = 0.0
        for count, prob in ((heads, p), (tails, 1 - p)):
            if count == 0:
                # x**0 == 1 for every x (including 0), so nothing to add.
                continue
            if prob == 0:
                return float("-inf")
            total += count * math.log(prob)
        return total

    def e_step(tosses, pa, pb):
        """E-step: posterior weight of coin A and coin B for every trial.

        The likelihoods are evaluated in log-space and shifted by their
        maximum before exponentiating. Multiplying the raw powers underflows
        to 0.0 for long trials, which made the normalisation divide 0 by 0.

        Args:
            tosses: sequence of [heads, tails] counts, one per trial.
            pa: current heads-probability estimate for coin A (in [0, 1]).
            pb: current heads-probability estimate for coin B (in [0, 1]).

        Returns:
            A list with one [weight_a, weight_b] pair per trial; each pair
            sums to 1.

        Raises:
            ZeroDivisionError: if a trial has zero likelihood under both
                coins, so its weights cannot be normalised.
        """
        output = []
        for t in tosses:
            log_a = _log_likelihood(t[0], t[1], pa)
            log_b = _log_likelihood(t[0], t[1], pb)
            peak = max(log_a, log_b)
            if peak == float("-inf"):
                raise ZeroDivisionError(
                    "trial %r has zero likelihood under both coins" % (t,)
                )
            # Subtracting the peak keeps the larger term at exp(0) == 1, so
            # the denominator below can never underflow to zero.
            wa = math.exp(log_a - peak)
            wb = math.exp(log_b - peak)
            norm = wa + wb
            output.append([wa / norm, wb / norm])
        return output

    def m_step(tosses, zs):
        """M-step: re-estimate both heads-probabilities from weighted counts.

        Every trial's head and tail counts are credited to coin A and coin B
        in proportion to that trial's weights, and each coin's new
        heads-probability is its weighted heads divided by its weighted
        total.

        Args:
            tosses: sequence of [heads, tails] counts, one per trial.
            zs: sequence of [weight_a, weight_b] pairs aligned with `tosses`.

        Returns:
            [pa, pb], the updated heads-probabilities for coins A and B.

        Raises:
            ZeroDivisionError: if a coin receives zero total weight (for
                example when `tosses` is empty).
        """
        heads_a, tails_a = 0, 0
        heads_b, tails_b = 0, 0
        for counts, weights in zip(tosses, zs):
            heads_a += counts[0] * weights[0]
            tails_a += counts[1] * weights[0]
            heads_b += counts[0] * weights[1]
            tails_b += counts[1] * weights[1]
        # float() forces true division: with integer weights (hard 0/1
        # assignments) Python 2's "/" would truncate the ratio to 0.
        return [
            float(heads_a) / (heads_a + tails_a),
            float(heads_b) / (heads_b + tails_b),
        ]

    def solve(data, pa, pb, iterations=1000):
        """Estimate both coins' heads-probabilities with EM and print them.

        Args:
            data: sequence of toss strings, as accepted by extract().
            pa: initial guess for coin A's heads-probability.
            pb: initial guess for coin B's heads-probability.
            iterations: maximum number of EM rounds; defaults to 1000, the
                value that used to be hard-coded.

        Returns:
            The final (pa, pb) estimates, which are also printed on one line
            separated by a space.
        """
        tosses = extract(data)
        # range() rather than xrange() so this runs on Python 2 and 3 alike.
        for _ in range(iterations):
            zs = e_step(tosses, pa, pb)
            npa, npb = m_step(tosses, zs)
            # Both steps are deterministic, so once an update reproduces its
            # input exactly every later round would too; stopping here gives
            # the same result as running all remaining rounds.
            converged = (npa, npb) == (pa, pb)
            pa, pb = npa, npb
            if converged:
                break
        # Single pre-formatted string so the output text is the same under
        # the Python 2 print statement and the Python 3 print function.
        print("%s %s" % (pa, pb))
        return pa, pb


    # Run the solver on the five recorded trials in `sample`, starting from
    # the initial guesses pa=0.6 and pb=0.5.
    solve(sample, 0.6, 0.5)