Skip to content

Instantly share code, notes, and snippets.

@bellbind
Created December 21, 2011 08:04
Show Gist options
  • Save bellbind/1505153 to your computer and use it in GitHub Desktop.
Save bellbind/1505153 to your computer and use it in GitHub Desktop.

Revisions

  1. bellbind created this gist Dec 21, 2011.
    63 changes: 63 additions & 0 deletions dft.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,63 @@
    """DFT and FFT"""
    import math

    def iexp(n):
    return complex(math.cos(n), math.sin(n))

    def is_pow2(n):
    return False if n == 0 else (n == 1 or is_pow2(n >> 1))

    def dft(xs):
    "naive dft"
    n = len(xs)
    return [sum((xs[k] * iexp(-2 * math.pi * i * k / n) for k in range(n)))
    for i in range(n)]

    def dftinv(xs):
    "naive dft"
    n = len(xs)
    return [sum((xs[k] * iexp(2 * math.pi * i * k / n) for k in range(n))) / n
    for i in range(n)]

    def fft_(xs, n, start=0, stride=1):
    "cooley-turkey fft"
    if n == 1: return [xs[start]]
    hn, sd = n // 2, stride * 2
    rs = fft_(xs, hn, start, sd) + fft_(xs, hn, start + stride, sd)
    for i in range(hn):
    e = iexp(-2 * math.pi * i / n)
    rs[i], rs[i + hn] = rs[i] + e * rs[i + hn], rs[i] - e * rs[i + hn]
    pass
    return rs

    def fft(xs):
    assert is_pow2(len(xs))
    return fft_(xs, len(xs))

    def fftinv_(xs, n, start=0, stride=1):
    "cooley-turkey fft"
    if n == 1: return [xs[start]]
    hn, sd = n // 2, stride * 2
    rs = fftinv_(xs, hn, start, sd) + fftinv_(xs, hn, start + stride, sd)
    for i in range(hn):
    e = iexp(2 * math.pi * i / n)
    rs[i], rs[i + hn] = rs[i] + e * rs[i + hn], rs[i] - e * rs[i + hn]
    pass
    return rs

    def fftinv(xs):
    assert is_pow2(len(xs))
    n = len(xs)
    return [v / n for v in fftinv_(xs, n)]

    if __name__ == "__main__":
    wave = [0, 1, 2, 3, 3, 2, 1, 0]
    dfreq = dft(wave)
    ffreq = fft(wave)
    dwave = dftinv(dfreq)
    fwave= fftinv(ffreq)
    print(dfreq)
    print(ffreq)
    print([v.real for v in dwave])
    print([v.real for v in fwave])
    pass