Skip to content

Instantly share code, notes, and snippets.

@cli99
Last active October 17, 2024 03:41
Show Gist options
  • Save cli99/be1b39c1ada619c24242b264cebf048a to your computer and use it in GitHub Desktop.
Save cli99/be1b39c1ada619c24242b264cebf048a to your computer and use it in GitHub Desktop.
Welford's variance
# https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/
import torch
def two_pass_variance(data):
    """Return the sample variance of ``data`` using two passes.

    The first pass computes the mean; the second pass sums the squared
    deviations from that mean. The sum is divided by ``len(data) - 1``.

    Args:
        data: Sized iterable of numeric values (it is traversed twice and
            ``len`` is called on it).

    Returns:
        The sum of squared deviations from the mean divided by ``n - 1``.
    """
    count = len(data)
    # Pass 1: arithmetic mean of the samples.
    center = sum(data) / count
    # Pass 2: squared distance of every sample from that mean.
    squared_deviations = ((value - center) ** 2 for value in data)
    return sum(squared_deviations) / (count - 1)
def one_pass_variance(data):
    """Return the sample variance of ``data`` from running sums.

    A single traversal accumulates the sum of the values and the sum of
    their squares; the variance is then ``(sum_sq - n * mean**2) / (n - 1)``.
    The element count is tracked inside the loop, so any iterable (including
    iterators and generators) is accepted.

    Args:
        data: Iterable of numeric values.

    Returns:
        ``(sum of squares - n * mean**2) / (n - 1)``.
    """
    # numerically unstable: the final subtraction takes the difference of
    # two accumulated quantities that can be nearly equal.
    count = 0
    total = 0.0  # running sum of values (named to avoid shadowing builtin sum)
    total_sq = 0.0  # running sum of squared values
    for x in data:
        count += 1
        total += x
        total_sq += x**2
    mean = total / count
    return (total_sq - count * mean**2) / (count - 1)
def welford_variance(data):
    """Return the sample variance of ``data`` using Welford's online update.

    The running mean and the running sum of squared deviations are updated
    once per element, so the input is traversed a single time. The element
    count comes from the loop itself rather than ``len(data)``, so any
    iterable (including iterators and generators) is accepted.

    Args:
        data: Iterable of numeric values.

    Returns:
        The accumulated sum of squared deviations divided by ``count - 1``.

    Note:
        No guard is applied for fewer than two elements; the final divisor
        is then ``0`` (one element) or ``-1`` (no elements).
    """
    count = 0
    m = 0.0  # running mean
    s = 0.0  # running sum of squared deviations
    for count, x in enumerate(data, start=1):
        old_m = m
        m = m + (x - m) / count
        # Multiply the deviation from the updated mean by the deviation
        # from the previous mean.
        s = s + (x - m) * (x - old_m)
    return s / (count - 1)
# Draw ten random samples to use as the shared input for all estimators.
x = torch.randn(10)
# Evaluate the three estimators on the same samples, in definition order.
y1, y2, y3 = (
    estimate(x)
    for estimate in (two_pass_variance, one_pass_variance, welford_variance)
)
print(y1, y2, y3)
# The two-pass result is the reference for both single-pass variants.
assert torch.allclose(y1, y2), "two_pass_variance and one_pass_variance are not equal"
assert torch.allclose(y1, y3), "two_pass_variance and welford_variance are not equal"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment