Skip to content

Instantly share code, notes, and snippets.

@cli99
Last active October 17, 2024 03:41
Show Gist options
  • Save cli99/be1b39c1ada619c24242b264cebf048a to your computer and use it in GitHub Desktop.
Save cli99/be1b39c1ada619c24242b264cebf048a to your computer and use it in GitHub Desktop.

Revisions

  1. cli99 created this gist Oct 17, 2024.
    41 changes: 41 additions & 0 deletions one_pass_variance.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,41 @@
    # https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/

    import torch


    def two_pass_variance(data):
        """Return the sample variance of ``data`` using the two-pass method.

        The first pass computes the mean; the second pass sums the squared
        deviations from that mean. The total is divided by ``n - 1``.

        Args:
            data: A sized collection of numbers. ``len`` is called on it and
                it is traversed twice, so a one-shot iterator will not work.

        Returns:
            The sum of squared deviations from the mean divided by
            ``len(data) - 1``.
        """
        count = len(data)

        # Pass 1: arithmetic mean.
        center = sum(data) / count

        # Pass 2: squared deviations around the mean, summed lazily.
        squared_deviations = ((value - center) ** 2 for value in data)
        return sum(squared_deviations) / (count - 1)


    def one_pass_variance(data):
        """Return the sample variance of ``data`` from running sums in one pass.

        Accumulates the sum and the sum of squares of the samples, then
        evaluates ``(sum_sq - n * mean**2) / (n - 1)``.

        NOTE: this formula is numerically unstable. It subtracts two
        quantities that can be large and nearly equal, so precision can be
        lost when the mean is large relative to the spread of the data.

        Args:
            data: Any iterable of numbers. It is traversed exactly once and
                the samples are counted during the traversal, so generators
                and other one-shot iterators are accepted.

        Returns:
            The sample variance (divisor ``n - 1``) computed from the
            accumulated sums.
        """
        n = 0
        total = 0.0  # running sum of the samples (avoids shadowing built-in sum)
        total_sq = 0.0  # running sum of the squared samples
        for x in data:
            n += 1
            total += x
            total_sq += x**2
        mean = total / n
        return (total_sq - n * mean**2) / (n - 1)


    def welford_variance(data):
        """Return the sample variance of ``data`` using Welford's method.

        A single pass keeps a running mean ``m`` and a running sum ``s`` of
        squared deviations. Each sample updates both incrementally, so the
        squared samples are never summed directly.

        Args:
            data: Any iterable of numbers. It is traversed exactly once and
                the samples are counted during the traversal, so generators
                and other one-shot iterators are accepted.

        Returns:
            ``s / (n - 1)``, where ``n`` is the number of samples seen. The
            result is not meaningful for fewer than two samples: an empty
            input yields ``-0.0`` and a single plain-float sample raises
            ``ZeroDivisionError``.
        """
        count = 0  # stays 0 when ``data`` is empty
        m = 0.0
        s = 0.0
        for count, x in enumerate(data, start=1):
            old_m = m
            # Move the running mean towards x by 1/count of the gap.
            m = m + (x - m) / count
            # Uses both the updated and the previous mean.
            s = s + (x - m) * (x - old_m)
        return s / (count - 1)


    # Demo / self-check: draw 10 random samples and compare the three
    # variance implementations against each other.
    x = torch.randn(10)
    y1, y2, y3 = (
        two_pass_variance(x),
        one_pass_variance(x),
        welford_variance(x),
    )
    print(y1, y2, y3)

    # The two-pass result is the baseline the other two are compared with.
    assert torch.allclose(
        y1, y2
    ), "two_pass_variance and one_pass_variance are not equal"
    assert torch.allclose(
        y1, y3
    ), "two_pass_variance and welford_variance are not equal"