""" Consistent load-balancing. We have a few servers and we want a load-balancer to distribute incoming requests across them in a deterministic and consistent way - without keeping any counter to make the decision. Removing a backend server should not impact users on other servers. Adding a backend will generate the redistribution of users across other servers. The goal is to come up with the best algorithm for 1M users across 5 servers. Speed is a bonus. Two known techniques: - RendezVous : https://en.wikipedia.org/wiki/Rendezvous_hashing#Comparison_With_Consistent_Hashing - Consistent Hashing: https://en.wikipedia.org/wiki/Consistent_hashing Consistent Hashing implementation inspired by: http://techspot.zzzeek.org/2012/07/07/the-absolutely-simplest-consistent-hashing-example Also, good read on hashes: http://programmers.stackexchange.com/questions/49550/which-hashing-algorithm-is-best-for-uniqueness-and-speed/145633#145633 """ import hashlib import bisect from collections import defaultdict import binascii import time from functools import wraps class CollisionError(Exception): pass _collisions = {} def catch_collision(func): @wraps(func) def _catch(key): res = func(key) if res in _collisions and key != _collisions[res]: raise CollisionError('%s and %s with %s' % (key, _collisions[res], func)) _collisions[res] = key return res return _catch @catch_collision def fnv32a(key): hval = 0x811c9dc5 fnv_32_prime = 0x01000193 uint32_max = 2 ** 32 for s in key: hval = hval ^ ord(s) hval = (hval * fnv_32_prime) % uint32_max return hval @catch_collision def sha512(key): return long(hashlib.sha512(key).hexdigest(), 16) @catch_collision def sha256(key): return long(hashlib.sha256(key).hexdigest(), 16) @catch_collision def md5(key): return long(hashlib.md5(key).hexdigest(), 16) class RendezVous(object): def __init__(self, ips=None, hash=md5): if ips is None: ips = [] self.ips = ips self._hash = hash def __str__(self): return '' % self._hash def add(self, ip): self.ips.append(ip) def remove(self, ip): self.ips.remove(ip) def select(self, key): high_score = -1 winner = None for ip in self.ips: score = self._hash("%s-%s" % (str(ip), str(key))) if score > high_score: high_score, winner = score, ip elif score == high_score: high_score, winner = score, max(str(ip), str(winner)) return winner def _repl(name, index): return '%s:%d' % (name, index) class ConsistentHashing(object): def __init__(self, ips=[], replicas=200, hash=md5): self._ips = {} self._hashed_ips = [] self.replicas = replicas self._hash = hash for ip in ips: self.add(ip) def __str__(self): return '' % self._hash def add(self, ip): for i in range(self.replicas): sip = _repl(ip, i) hashed = self._hash(sip) self._ips[hashed] = sip bisect.insort(self._hashed_ips, hashed) def remove(self, ip): for i in range(self.replicas): sip = _repl(ip, i) hashed = self._hash(sip) del self._ips[hashed] index = bisect.bisect_left(self._hashed_ips, hashed) del self._hashed_ips[index] def select(self, username): hashed = self._hash(username) start = bisect.bisect(self._hashed_ips, hashed, hi=len(self._hashed_ips)-1) return self._ips[self._hashed_ips[start]].split(':')[0] NUM_USERS = 1000000 def run_test(servers, users): selection = defaultdict(list) for user in users: user_db = servers.select(user) selection[user_db].append(user) print '====' print('Distribution') smallest = NUM_USERS + 1 biggest = 0 for db in selection: size = len(selection[db]) if size < smallest: smallest = size if size > biggest: biggest = size print('%d users in %s' % (size, db)) print('span: %d' % (biggest - smallest)) # removing server 2 and 4 servers.remove('postgres2') print '====' selection = defaultdict(list) for user in users: user_db = servers.select(user) selection[user_db].append(user) smallest = NUM_USERS + 1 biggest = 0 for i, db in enumerate(selection): size = len(selection[db]) if size < smallest: smallest = size if size > biggest: biggest = size print('%d users in %s' % (size, db)) print('span: %d' % (biggest - smallest)) if __name__ == '__main__': users = ['%06d' % i for i in range(NUM_USERS)] servers = ['postgres5', 'postgres2', 'postgres3', 'postgres4', 'postgres1'] for klass in (ConsistentHashing, RendezVous): for hash in (md5, sha256, binascii.crc32, fnv32a, sha512): try: cluster = klass(list(servers), hash=hash) except CollisionError: print('Collision error with hash %s' % hash) continue print(cluster) start = time.time() try: run_test(cluster, users) except CollisionError: print('Collision error..') print('Took %d seconds' % (time.time() - start)) print print