Skip to content

Instantly share code, notes, and snippets.

@phelrine
Created November 20, 2011 12:27
Show Gist options
  • Select an option

  • Save phelrine/1380219 to your computer and use it in GitHub Desktop.

Select an option

Save phelrine/1380219 to your computer and use it in GitHub Desktop.

Revisions

  1. phelrine created this gist Nov 20, 2011.
    37 changes: 37 additions & 0 deletions gmm.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,37 @@
    import numpy as np
    import numpy.random as nprand
    import matplotlib.pyplot as plt

    def dnorm(x, m, s):
        """Evaluate the normal probability density at ``x``.

        Args:
            x: Point (or numpy array of points) at which to evaluate.
            m: Mean of the distribution.
            s: Variance of the distribution (not the standard deviation).

        Returns:
            The density value(s), with the same shape numpy broadcasting
            produces for ``x``, ``m`` and ``s``.
        """
        squared_deviation = (x - m) ** 2
        exponent = -squared_deviation / (2 * s)
        normalizer = np.sqrt(2 * np.pi * s)
        return np.exp(exponent) / normalizer

    def EM(data, init, iter):
        """Fit a one-dimensional Gaussian mixture with expectation-maximisation.

        Args:
            data: One-dimensional sequence (list or array) of observations.
            init: Initial parameters, one ``[mean, variance, weight]`` row per
                mixture component.  The input is copied, never modified.
            iter: Number of EM iterations to run.  The name shadows the
                builtin but is kept so existing callers keep working.

        Returns:
            A float numpy array with one ``[mean, variance, weight]`` row per
            component, holding the estimates after ``iter`` iterations.

        Raises:
            ValueError: If at least one iteration is requested on empty data.
        """
        # Force a float dtype: an all-integer ``init`` would otherwise yield an
        # integer array and silently truncate every updated estimate.
        params = np.array(init, dtype=float)
        samples = np.asarray(data, dtype=float)

        if iter > 0 and samples.size == 0:
            raise ValueError("EM requires at least one data point")

        for _ in range(iter):
            means, variances, weights = params.T

            # E-step: responsibilities, shape (n_components, n_samples).
            # Each column is normalised so that it sums to one.
            joint = dnorm(
                samples[np.newaxis, :],
                means[:, np.newaxis],
                variances[:, np.newaxis],
            ) * weights[:, np.newaxis]
            resp = joint / joint.sum(axis=0)

            # M-step: responsibility-weighted moments per component.  The
            # variance uses the freshly updated mean, as the original did.
            totals = resp.sum(axis=1)
            new_means = resp.dot(samples) / totals
            deviations = samples[np.newaxis, :] - new_means[:, np.newaxis]
            new_variances = (resp * deviations ** 2).sum(axis=1) / totals
            new_weights = totals / samples.size

            params = np.array([new_means, new_variances, new_weights]).T

        return params

    def _mixture_density(xs, params):
        """Return the mixture density at ``xs`` for (mean, variance, weight) rows."""
        return sum(dnorm(xs, m, s) * c for m, s, c in params)


    def main():
        """Plot a known 3-component mixture next to EM estimates of it.

        Samples 500 points from the reference mixture, runs EM for 1, 2, 4, 8
        and 16 iterations from a fixed starting guess, and plots the resulting
        densities together with the true one.
        """
        params = [[-7, 2, 0.3], [-2, 1, 0.5], [4, 3, 0.2]]
        xs = np.linspace(-15, 15, 1000)
        plt.plot(xs, _mixture_density(xs, params), label="base")

        # Each component contributes samples in proportion to its weight.  The
        # count must be an int: numpy rejects a float ``size``.  ``s`` is a
        # variance, hence the square root for the scale argument.
        data = np.concatenate([
            nprand.normal(m, np.sqrt(s), int(round(c * 500)))
            for m, s, c in params
        ])

        for i in (2 ** x for x in range(5)):
            est = EM(data, [[-6, 2, 0.3], [0, 2, 0.4], [6, 2, 0.3]], i)
            plt.plot(xs, _mixture_density(xs, est), label="iter-%d" % i)

        plt.legend(loc="best")
        plt.show()

    # Run the demo only when executed as a script, not when imported.
    if __name__ == "__main__":
        main()