# https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/
import torch


def two_pass_variance(data):
    # First pass computes the mean, second pass the sum of squared deviations.
    n = len(data)
    mean = sum(data) / n
    var = sum((x - mean) ** 2 for x in data) / (n - 1)
    return var


def one_pass_variance(data):
    # Numerically unstable: subtracts two large, nearly equal quantities.
    n = len(data)
    total = 0.0
    total_sq = 0.0
    for x in data:
        total += x
        total_sq += x**2
    mean = total / n
    return (total_sq - n * mean**2) / (n - 1)


def welford_variance(data):
    # Welford's online update: single pass and numerically stable.
    m = 0.0
    s = 0.0
    for idx, x in enumerate(data):
        old_m = m
        m = m + (x - m) / (idx + 1)
        s = s + (x - m) * (x - old_m)
    return s / (len(data) - 1)


x = torch.randn(10)
y1 = two_pass_variance(x)
y2 = one_pass_variance(x)
y3 = welford_variance(x)
print(y1, y2, y3)
assert torch.allclose(y1, y2), "two_pass_variance and one_pass_variance are not equal"
assert torch.allclose(y1, y3), "two_pass_variance and welford_variance are not equal"
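
# A quick illustrative check (not part of the original script, a sketch only):
# shifting the data by a large constant leaves the true variance unchanged, but
# the one-pass formula then subtracts two large, nearly equal numbers
# (sum of squares vs. n * mean**2) and loses precision in float32, while the
# two-pass and Welford versions should stay close to the true value.
shifted = torch.randn(10, dtype=torch.float32) + 1e4
print(two_pass_variance(shifted), one_pass_variance(shifted), welford_variance(shifted))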