Skip to content

Instantly share code, notes, and snippets.

@xtao
Forked from tarekziade/distribution.py
Created June 4, 2016 05:39
Show Gist options
  • Save xtao/b2e918c7d0dc092bbab24637ef7ae6bc to your computer and use it in GitHub Desktop.
Save xtao/b2e918c7d0dc092bbab24637ef7ae6bc to your computer and use it in GitHub Desktop.
Consistent Distribution of users across servers
""" Consistent load-balancing.
We have a few servers and we want a load-balancer to
distribute incoming requests across them in a deterministic
and consistent way - without keeping any counter to make the
decision.
Removing a backend server should not impact users on other
servers.
Adding a backend will cause some users to be redistributed
from the existing servers to the new one.
The goal is to come up with the best algorithm
for 1M users across 5 servers. Speed is a bonus.
Two known techniques:
- RendezVous : https://en.wikipedia.org/wiki/Rendezvous_hashing#Comparison_With_Consistent_Hashing
- Consistent Hashing: https://en.wikipedia.org/wiki/Consistent_hashing
Consistent Hashing implementation inspired by:
http://techspot.zzzeek.org/2012/07/07/the-absolutely-simplest-consistent-hashing-example
XXX Add credits for the murmur hash
Also, good read on hashes:
http://programmers.stackexchange.com/questions/49550/which-hashing-algorithm-is-best-for-uniqueness-and-speed/145633#145633
"""
import hashlib
import bisect
from collections import defaultdict
import binascii
def fnv32a(key):
    """Return the 32-bit FNV-1a hash of *key* as a non-negative int.

    *key* may be a byte string (or anything ``bytearray()`` accepts) or
    text. Text is encoded as UTF-8 first so the hash is always computed
    over bytes and gives the same result on Python 2 and Python 3.
    """
    if isinstance(key, type(u'')):
        key = key.encode('utf-8')
    hval = 0x811c9dc5  # FNV-1a 32-bit offset basis
    fnv_32_prime = 0x01000193
    for byte in bytearray(key):
        # XOR the byte in first, then multiply; the mask keeps the running
        # value within 32 bits.
        hval = ((hval ^ byte) * fnv_32_prime) & 0xffffffff
    return hval
def sha256(key):
    """Return the SHA-256 digest of *key* as a (256-bit) integer.

    Text keys are encoded as UTF-8 before hashing because hashlib only
    accepts bytes-like input on Python 3. ``int`` is used instead of the
    Python 2-only ``long``; it yields the same value on both versions.
    """
    if isinstance(key, type(u'')):
        key = key.encode('utf-8')
    return int(hashlib.sha256(key).hexdigest(), 16)
def md5(key):
    """Return the MD5 digest of *key* as a (128-bit) integer.

    Text keys are encoded as UTF-8 before hashing because hashlib only
    accepts bytes-like input on Python 3. ``int`` is used instead of the
    Python 2-only ``long``; it yields the same value on both versions.
    """
    if isinstance(key, type(u'')):
        key = key.encode('utf-8')
    return int(hashlib.md5(key).hexdigest(), 16)
def bytes_to_long(bytes):
    """Combine exactly eight byte values into one integer.

    The first item becomes the least-significant byte of the result
    (little-endian order). *bytes* must yield integers, e.g. a bytearray.
    """
    assert len(bytes) == 8
    value = 0
    shift = 0
    for octet in bytes:
        value += octet << shift
        shift += 8
    return value
def murmur64(data, seed=19820125):
    """Return a 64-bit MurmurHash2-style hash of *data* as an int.

    *data* may be bytes-like (anything ``bytearray()`` accepts) or text;
    text is encoded as UTF-8 first. *seed* perturbs the initial state so
    different seeds give different hashes for the same input.
    """
    m = 0xc6a4a7935bd1e995
    r = 47
    MASK = 2 ** 64 - 1

    if isinstance(data, type(u'')):
        data = data.encode('utf-8')
    data = bytearray(data)

    h = seed ^ ((m * len(data)) & MASK)

    # Length of the prefix made of whole 8-byte words. Floor division keeps
    # this an int on Python 3, where ``/`` would produce a float and make
    # range() raise.
    off = len(data) // 8 * 8
    for ll in range(0, off, 8):
        k = bytes_to_long(data[ll:ll + 8])
        k = (k * m) & MASK
        k = k ^ ((k >> r) & MASK)
        k = (k * m) & MASK
        h = (h ^ k)
        h = (h * m) & MASK

    # Mix in the 1-7 trailing bytes, least-significant byte first. XOR is
    # order-independent, so one loop is equivalent to a ladder of
    # ``if l >= n`` checks. The extra multiply happens only when a tail
    # exists.
    tail = data[off:]
    if tail:
        for position, byte in enumerate(tail):
            h = h ^ (byte << (8 * position))
        h = (h * m) & MASK

    # Final avalanche.
    h = h ^ ((h >> r) & MASK)
    h = (h * m) & MASK
    h = h ^ ((h >> r) & MASK)
    return h
class RendezVous(object):
    """Rendezvous (highest-random-weight) selection of a server for a key.

    Every server is scored by hashing ``"<ip>-<key>"``; the server with
    the highest score wins, so the choice only depends on the key and on
    the set of known servers.
    """

    def __init__(self, ips=None, hash=murmur64):
        """Create a selector.

        ips: optional iterable of server names. It is copied, so later
            add()/remove() calls never mutate the caller's list.
        hash: callable taking the ``"<ip>-<key>"`` string and returning a
            comparable score.
        """
        self.ips = [] if ips is None else list(ips)
        self._hash = hash

    def __str__(self):
        return '<RendezVous with %s hash>' % self._hash

    def add(self, ip):
        """Register a new server."""
        self.ips.append(ip)

    def remove(self, ip):
        """Forget a server; raises ValueError if it is not registered."""
        self.ips.remove(ip)

    def select(self, key):
        """Return the server that owns *key*, or None if there are no servers.

        Ties between equal scores go to the server whose ``str()`` is
        greatest. The registered object itself is returned, not its string
        form.
        """
        winner = None
        high_score = None
        for ip in self.ips:
            score = self._hash("%s-%s" % (str(ip), str(key)))
            # No numeric sentinel for the starting score: a hash may return
            # negative values (e.g. a signed CRC), and those servers must
            # still be able to win.
            if high_score is None or score > high_score:
                high_score, winner = score, ip
            elif score == high_score and str(ip) > str(winner):
                winner = ip
        return winner
def _repl(name, index):
return '%s:%d' % (name, index)
class ConsistentHashing(object):
    """Consistent hashing over a sorted list of hashed replica points.

    Each server is inserted ``replicas`` times (as ``"<ip>:<n>"`` labels).
    A key is served by the point found by bisecting the sorted point list
    with the key's hash.
    """

    def __init__(self, ips=None, replicas=100, hash=md5):
        """Create the ring.

        ips: optional iterable of server names to add immediately.
        replicas: number of points inserted per server.
        hash: callable mapping a string to a sortable integer.
        """
        self._ips = {}          # point hash -> replica label that owns it
        self._hashed_ips = []   # sorted point hashes
        self.replicas = replicas
        self._hash = hash
        for ip in ips or ():
            self.add(ip)

    def __str__(self):
        return '<ConsistentHashing with %s hash>' % self._hash

    def add(self, ip):
        """Insert the replica points of server *ip*.

        A point whose hash is already taken (the same server added twice,
        or a hash collision) is skipped, so the dict and the sorted list
        always describe the same set of points.
        """
        for i in range(self.replicas):
            sip = _repl(ip, i)
            hashed = self._hash(sip)
            if hashed in self._ips:
                continue
            self._ips[hashed] = sip
            bisect.insort(self._hashed_ips, hashed)

    def remove(self, ip):
        """Remove the replica points of server *ip*.

        Raises KeyError if none of the server's points is in the ring.
        """
        removed = False
        for i in range(self.replicas):
            sip = _repl(ip, i)
            hashed = self._hash(sip)
            # Only drop points this replica actually owns; a point it lost
            # to a collision belongs to another server and must stay.
            if self._ips.get(hashed) != sip:
                continue
            del self._ips[hashed]
            index = bisect.bisect_left(self._hashed_ips, hashed)
            del self._hashed_ips[index]
            removed = True
        if self.replicas > 0 and not removed:
            raise KeyError(ip)

    def select(self, username):
        """Return the server name that owns *username*, or None if empty.

        NOTE: the search is capped at the last point (``hi=len-1``), so
        hashes beyond the last point are served by that last point rather
        than wrapping around to the first one. This is kept as is, because
        changing it would remap existing users.
        """
        if not self._hashed_ips:
            return None
        hashed = self._hash(username)
        start = bisect.bisect(self._hashed_ips, hashed,
                              hi=len(self._hashed_ips) - 1)
        # Strip only the trailing ":<replica>" suffix, so server names that
        # themselves contain ':' (e.g. "host:5432") are returned intact.
        return self._ips[self._hashed_ips[start]].rsplit(':', 1)[0]
# Number of synthetic user ids generated for the benchmark run.
NUM_USERS = 1000000
def _print_distribution(servers, users):
    """Print how many of *users* each server receives, then the span."""
    counts = defaultdict(int)
    for user in users:
        counts[servers.select(user)] += 1
    for db in counts:
        print('%d users in %s' % (counts[db], db))
    sizes = counts.values()
    # Span = gap between the fullest and the emptiest server that received
    # at least one user; 0 when nobody was assigned at all.
    span = max(sizes) - min(sizes) if counts else 0
    print('span: %d' % span)


def run_test(servers, users, removed='postgres2'):
    """Print the user distribution before and after removing one server.

    servers: object exposing ``select(user)`` and ``remove(name)``.
    users: iterable of user keys; it is iterated twice.
    removed: name of the server taken out between the two measurements.

    Side effect: *removed* is removed from *servers*.
    """
    print('====')
    print('Distribution')
    _print_distribution(servers, users)

    # Take one server out and measure again to see how users get remapped.
    servers.remove(removed)
    print('====')
    _print_distribution(servers, users)
if __name__ == '__main__':
    def crc32(key):
        """CRC-32 of *key* as an unsigned int, accepting text or bytes.

        binascii.crc32 only takes bytes-like input on Python 3 and returns
        a signed value on Python 2; encoding and masking make the result
        identical on both.
        """
        if isinstance(key, type(u'')):
            key = key.encode('utf-8')
        return binascii.crc32(key) & 0xffffffff

    # Benchmark every hash function with both balancing strategies.
    users = ['%06d' % i for i in range(NUM_USERS)]
    servers = ['postgres5', 'postgres2', 'postgres3', 'postgres4',
               'postgres1']
    for klass in (ConsistentHashing, RendezVous):
        for hash_func in (md5, murmur64, sha256, crc32, fnv32a):
            # Fresh copy of the server list: run_test removes a server.
            cluster = klass(list(servers), hash=hash_func)
            print(cluster)
            run_test(cluster, users)
            # Two blank separator lines. print('') works as a statement on
            # Python 2 and as a call on Python 3; a bare ``print`` prints
            # nothing on Python 3.
            print('')
            print('')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment