''' This is a module that defines some helper classes and functions for expiring groups of related keys at the same time. Written July 1-2, 2013 by Josiah Carlson Released into the public domain ''' import time import redis class KeyDiscoveryPipeline(object): ''' This class is used as a wrapper around a Redis pipeline in Python to discover keys that are being used. This will work for commands where the key to be modified is the first argument to the command. It won't work properly with commands that modify multiple keys at the same time like some forms of DEL, Used like:: >>> conn = redis.Redis() >>> pipe = KeyDiscoveryPipeline(conn.pipeline()) >>> pipe.sadd('foo', 'bar') >>> pipe.execute() # will fail, use one of the subclasses ''' def __init__(self, pipeline): self.pipeline = pipeline self.keys = set() def __getattr__(self, attribute): ''' This is a bit of Python magic to discover the keys that are being modified. ''' def call(*args, **kwargs): if args: self.keys.add(args[0]) return getattr(self.pipeline, attribute)(*args, **kwargs) return call def execute(self): raise NotImplementedError TTL = 86400 # one day def get_user(keys): # we will assume that all keys are of the form: # :[:] for skey in keys: return skey.partition(':')[0] return [] class ExpirePipeline(KeyDiscoveryPipeline): ''' Will automatically call EXPIRE on all keys used. ''' def execute(self): user = get_user(self.keys) if not user: return user # add all the keys to the expire SET self.keys.add(user + ':expire') self.pipeline.sadd(user + ':expire', *list(self.keys)) # fetch all known keys from the expire SET self.pipeline.smembers(user + ':expire') # get the results result = self.pipeline.execute() # keep the results to return separate ret = result[:-len(self.keys)-1] # update the expiration time for all known keys for key in result[-1]: self.pipeline.expire(key, TTL) self.pipeline.execute() # clear all known keys and return the result self.keys = set() return ret class SetExpirePipeline(KeyDiscoveryPipeline): ''' Supposed to be used by a Redis-level C change, but won't work with standard Redis. ''' def execute(self): user = get_user(self.keys) if not user: return user # add all of the keys to the expiration SET self.pipeline.sadd(user + ':expire', *list(self.keys)) # this won't work, EXPIRE doesn't take # a 3rd argument - only for show self.pipeline.expire(user + ':expire', TTL, 'keys') try: return self.pipeline.execute()[:-2] finally: self.keys = set() class LuaExpirePipeline(KeyDiscoveryPipeline): ''' This is supposed to be used with the expire_user() function to expire user data. ''' def execute(self): # This first part is the same as SetExpirePipeline user = get_user(self.keys) if not user: return user self.pipeline.sadd(user + ':expire', *list(self.keys)) # Instead of calling EXPIRE, we'll just add it to the # expiration ZSET self.pipeline.zadd(':expire', **{user: time.time()}) try: return self.pipeline.execute()[:-2] finally: self.keys = set() def script_load(script): ''' This function is borrowed from Redis in Action and is MIT licensed. It is provided for convenience. ''' sha = [None] def call(conn, keys=[], args=[], force_eval=False): if not force_eval: if not sha[0]: sha[0] = conn.execute_command( "SCRIPT", "LOAD", script, parse="LOAD") try: return conn.execute_command( "EVALSHA", sha[0], len(keys), *(keys+args)) except redis.exceptions.ResponseError as msg: if not msg.args[0].startswith("NOSCRIPT"): raise return conn.execute_command( "EVAL", script, len(keys), *(keys+args)) return call def expire_user(conn, cutoff=None): ''' Expire a single user that was updated as part of calls to LuaExpirePipeline. ''' # warning: this is not Redis Cluster compatible return expire_user_lua(conn, [], [cutoff or time.time() - TTL]) expire_user_lua = script_load(''' -- fetch the first user with a score before our cutoff local key = redis.call('zrangebyscore', ':expire', 0, ARGV[1], 'LIMIT', 0, 1) if #key == 0 then return 0 end -- fetch the known keys to delete local keys = redis.call('smembers', key[1] .. ':expire') keys[#keys+1] = key[1] .. ':expire' -- delete the keys and remove the entry from the zset redis.call('del', unpack(keys)) redis.call('zrem', ':expire', key[1]) return 1 ''') class LuaExpirePipeline2(KeyDiscoveryPipeline): ''' This is supposed to be used with the expire_user2() function to expire user data, and is modified to somewhat reduce execution time. ''' def execute(self): # This first part is the same as LuaExpirePipeline user = get_user(self.keys) if not user: return user # Instead of adding this to the ZSET, we'll update the user # metadata entry - make sure it's in the expire SET! self.hset(user + ':info', 'updated', time.time()) self.pipeline.sadd(user + ':expire', *list(self.keys)) try: return self.pipeline.execute()[:-2] finally: self.keys = set() def expire_user2(conn, cutoff=None): ''' Expire a single user that was updated as part of calls to LuaExpirePipeline2. ''' # warning: this is also not Redis Cluster compatible return expire_user_lua2(conn, [], [cutoff or time.time() - TTL]) expire_user_lua2 = script_load(''' -- same as before local key = redis.call('zrangebyscore', ':expire', 0, ARGV[1], 'LIMIT', 0, 1) if #key == 0 then return 0 end -- verify that the user data should expire local last = redis.call('hget', key[1] .. ':info', 'updated') if tonumber(last) > tonumber(ARGV[1]) then -- shouldn't expire, so update the ZSET redis.call('zadd', ':expire', last, key[1]) return 1 end local keys = redis.call('smembers', key[1] .. ':expire') keys[#keys+1] = key[1] .. ':expire' redis.call('del', unpack(keys)) redis.call('zrem', ':expire', key[1]) return 1 ''') def crappy_test(): conn = redis.Redis(db=15) conn.flushdb() c1 = ExpirePipeline(conn.pipeline(True)) c1.sadd('12:foo', 'bar') c1.hset('12:goo', 'goo', 'baz') c1.execute() for k in conn.keys('*'): print k, conn.ttl(k) print conn.flushdb() c2 = LuaExpirePipeline(conn.pipeline(True)) c2.sadd('12:foo', 'bar') c2.hset('12:goo', 'goo', 'baz') c2.execute() for k in conn.keys('*'): print k, conn.ttl(k) print ':', conn.smembers('12:expire') print conn.zrange(':expire', 0, -1, withscores=True) print expire_user(conn, time.time() + 1) for k in conn.keys('*'): print k, conn.ttl(k) print ':', conn.smembers('12:expire') print conn.zrange(':expire', 0, -1, withscores=True) print conn.flushdb() # this should be done during user login conn.zadd(':expire', '12', time.time()) c3 = LuaExpirePipeline2(conn.pipeline(True)) c3.sadd('12:foo', 'bar') c3.hset('12:goo', 'goo', 'baz') c3.execute() for k in conn.keys('*'): print k, conn.ttl(k) print ':', conn.smembers('12:expire') print conn.zrange(':expire', 0, -1, withscores=True) print expire_user(conn, time.time() + 1) for k in conn.keys('*'): print k, conn.ttl(k) print ':', conn.smembers('12:expire') print conn.zrange(':expire', 0, -1, withscores=True) if __name__ == '__main__': crappy_test()