-
-
Save josiahcarlson/80584b49da41549a7d5c to your computer and use it in GitHub Desktop.
| ''' | |
| rate_limit2.py | |
| Copyright 2014, Josiah Carlson - [email protected] | |
| Released under the MIT license | |
| This module intends to show how to perform standard and sliding-window rate | |
| limits as a companion to the two articles posted on Binpress entitled | |
| "Introduction to rate limiting with Redis", parts 1 and 2: | |
| http://www.binpress.com/tutorial/introduction-to-rate-limiting-with-redis/155 | |
| http://www.binpress.com/tutorial/introduction-to-rate-limiting-with-redis/166 | |
| ... which will (or have already been) reposted on my personal blog at least 2 | |
| weeks after their original binpress.com posting: | |
| http://www.dr-josiah.com | |
| ''' | |
| import time | |
| from flask import g, request | |
def get_identifiers():
    '''
    Build the list of rate-limit identifiers for the current Flask request.

    The list always starts with the client address prefixed by ``ip:``. When
    ``g.user`` reports itself as authenticated, its id is appended with a
    ``user:`` prefix.
    '''
    identifiers = ['ip:' + request.remote_addr]
    if not g.user.is_authenticated():
        return identifiers
    identifiers.append('user:' + g.user.get_id())
    return identifiers
def over_limit(conn, duration=3600, limit=240):
    '''
    Count the current request against a fixed time window.

    conn is the Redis client used for INCR/EXPIRE, duration is the window
    length in seconds, and limit is the maximum count allowed per window.

    Each identifier from get_identifiers() gets its own counter key, named
    from the identifier, the duration and the index of the current window.
    Returns True as soon as one counter exceeds limit (identifiers after it
    are not touched), otherwise False.
    '''
    window = time.time() // duration
    suffix = ':%i:%i' % (duration, window)
    for identifier in get_identifiers():
        counter_key = identifier + suffix
        hits = conn.incr(counter_key)
        # Refresh the TTL on every hit so the counter key eventually expires.
        conn.expire(counter_key, duration)
        if hits > limit:
            return True
    return False
def over_limit_multi(conn, limits=((1, 10), (60, 120), (3600, 240))):
    '''
    Check the current request against several fixed-window limits.

    conn is the Redis client handed through to over_limit(). limits is an
    iterable of (duration, limit) pairs; the default is an immutable tuple so
    that it cannot be modified between calls.

    The pairs are checked in order and checking stops at the first one that
    is exceeded. Returns True in that case, and False when no pair is
    exceeded (including when limits is empty).
    '''
    return any(
        over_limit(conn, duration, limit) for duration, limit in limits)
def over_limit(conn, duration=3600, limit=240):
    '''
    Pipelined fixed-window limit check.

    Replaces the earlier over_limit() function and reduces round trips by
    sending INCR and EXPIRE for a key through one transactional pipeline.

    conn is the Redis client, duration is the window length in seconds, and
    limit is the maximum count allowed per window. Returns True as soon as
    one identifier's counter exceeds limit, otherwise False.
    '''
    pipe = conn.pipeline(transaction=True)
    suffix = ':%i:%i' % (duration, time.time() // duration)
    for identifier in get_identifiers():
        counter_key = identifier + suffix
        pipe.incr(counter_key)
        pipe.expire(counter_key, duration)
        # The first pipeline reply belongs to INCR: the updated counter.
        hits = pipe.execute()[0]
        if hits > limit:
            return True
    return False
def over_limit_multi_lua(conn, limits=((1, 10), (60, 120), (3600, 240))):
    '''
    Check several fixed-window limits with a single Lua script call.

    conn is the Redis client. The registered script object is stored on it
    as ``over_limit_lua`` the first time through, so later calls reuse it.
    limits is a sequence of (duration, limit) pairs; it is passed to the
    script as JSON together with the current time.time().

    Returns the script's reply: 1 when a limit was exceeded, 0 otherwise.
    '''
    # Imported locally: the module-level imports do not bring in json.
    import json

    if not hasattr(conn, 'over_limit_lua'):
        conn.over_limit_lua = conn.register_script(over_limit_multi_lua_)
    return conn.over_limit_lua(
        keys=get_identifiers(), args=[json.dumps(limits), time.time()])


# Lua source for over_limit_multi_lua(). KEYS holds the identifiers, ARGV[1]
# the JSON-encoded list of limits and ARGV[2] the current time. For each
# limit and each identifier the script INCRs a counter key built from the
# identifier, the duration and the window index, refreshes its EXPIRE, and
# returns 1 as soon as a counter exceeds its limit. It returns 0 otherwise.
over_limit_multi_lua_ = '''
local limits = cjson.decode(ARGV[1])
local now = tonumber(ARGV[2])
for i, limit in ipairs(limits) do
    local duration = limit[1]
    local bucket = ':' .. duration .. ':' .. math.floor(now / duration)
    for j, id in ipairs(KEYS) do
        local key = id .. bucket
        local count = redis.call('INCR', key)
        redis.call('EXPIRE', key, duration)
        if tonumber(count) > limit[2] then
            return 1
        end
    end
end
return 0
'''
def over_limit_sliding_window(
        conn, weight=1, limits=((1, 10), (60, 120), (3600, 240, 60)),
        redis_time=False):
    '''
    Check and record the current request against sliding-window limits.

    conn is the Redis client. The registered script object is stored on it
    as ``over_limit_sliding_window_lua`` the first time through.
    weight is the amount added to every counter when the request is
    accepted.
    limits is a sequence of (duration, limit) or (duration, limit,
    precision) entries. The script uses the duration as the precision when
    the third item is missing, and caps the precision at the duration.
    redis_time selects the clock: when true, the first element of
    conn.time() is used, otherwise the local time.time().

    Returns the script's reply: 1 when the request is rejected, 0 when it
    was counted.
    '''
    # Imported locally: the module-level imports do not bring in json.
    import json

    if not hasattr(conn, 'over_limit_sliding_window_lua'):
        conn.over_limit_sliding_window_lua = conn.register_script(
            over_limit_sliding_window_lua_)

    now = conn.time()[0] if redis_time else time.time()
    return conn.over_limit_sliding_window_lua(
        keys=get_identifiers(), args=[json.dumps(limits), now, weight])


# Lua source for over_limit_sliding_window(). KEYS holds the identifiers
# (each one names a Redis hash), ARGV[1] the JSON-encoded limits, ARGV[2]
# the current time and ARGV[3] the weight (default 1).
#
# First pass, per limit and per key: read the stored oldest block id, sum
# and delete per-block counters that fell out of the window, subtract that
# sum from the window total, and return 1 if total + weight would exceed
# the limit. Nothing is added to any counter during this pass.
# Second pass, reached only when every check passed: store the new oldest
# block id and add weight to the window total and the current block.
# Finally every hash gets an EXPIRE of the longest duration seen.
#
# NOTE(review): old_ts is read from the field that the second pass writes
# with saved.trim_before (a block id), yet the "don't write in the past"
# guard compares it with `now` (a timestamp). For precisions above one
# second that looks like a unit mismatch. Confirm whether saved.block_id
# was intended before changing it.
over_limit_sliding_window_lua_ = '''
local limits = cjson.decode(ARGV[1])
local now = tonumber(ARGV[2])
local weight = tonumber(ARGV[3] or '1')
local longest_duration = limits[1][1] or 0
local saved_keys = {}
-- handle cleanup and limit checks
for i, limit in ipairs(limits) do
    local duration = limit[1]
    longest_duration = math.max(longest_duration, duration)
    local precision = limit[3] or duration
    precision = math.min(precision, duration)
    local blocks = math.ceil(duration / precision)
    local saved = {}
    table.insert(saved_keys, saved)
    saved.block_id = math.floor(now / precision)
    saved.trim_before = saved.block_id - blocks + 1
    saved.count_key = duration .. ':' .. precision .. ':'
    saved.ts_key = saved.count_key .. 'o'
    for j, key in ipairs(KEYS) do
        local old_ts = redis.call('HGET', key, saved.ts_key)
        old_ts = old_ts and tonumber(old_ts) or saved.trim_before
        if old_ts > now then
            -- don't write in the past
            return 1
        end
        -- discover what needs to be cleaned up
        local decr = 0
        local dele = {}
        local trim = math.min(saved.trim_before, old_ts + blocks)
        for old_block = old_ts, trim - 1 do
            local bkey = saved.count_key .. old_block
            local bcount = redis.call('HGET', key, bkey)
            if bcount then
                decr = decr + tonumber(bcount)
                table.insert(dele, bkey)
            end
        end
        -- handle cleanup
        local cur
        if #dele > 0 then
            redis.call('HDEL', key, unpack(dele))
            cur = redis.call('HINCRBY', key, saved.count_key, -decr)
        else
            cur = redis.call('HGET', key, saved.count_key)
        end
        -- check our limits
        if tonumber(cur or '0') + weight > limit[2] then
            return 1
        end
    end
end
-- there is enough resources, update the counts
for i, limit in ipairs(limits) do
    local saved = saved_keys[i]
    for j, key in ipairs(KEYS) do
        -- update the current timestamp, count, and bucket count
        redis.call('HSET', key, saved.ts_key, saved.trim_before)
        redis.call('HINCRBY', key, saved.count_key, weight)
        redis.call('HINCRBY', key, saved.count_key .. saved.block_id, weight)
    end
end
-- We calculated the longest-duration limit so we can EXPIRE
-- the whole HASH for quick and easy idle-time cleanup :)
if longest_duration > 0 then
    for _, key in ipairs(KEYS) do
        redis.call('EXPIRE', key, longest_duration)
    end
end
return 0
'''
How would you return an actual timestamp instead of 1 to be used in a Retry-After header?
The answer for you @ciokan is you need to modify the Lua script to calculate the delay. Right now it just returns whether you need to wait. https://gist.github.com/josiahcarlson/80584b49da41549a7d5c#file-rate_limit2-py-L157 is the line you are looking for.
Hi I have three questions.
Question 1
In over_limit_sliding_window_lua_, should
if old_ts > now then
at here be
if old_ts > saved.block_id then
because old_ts is the oldest block id, not a timestamp?
Question 2
Should
local trim = math.min(saved.trim_before, old_ts + blocks)
at here be
saved.trim_before = math.min(saved.trim_before, old_ts + blocks)
because later when saving the oldest block id the code uses saved.trim_before
redis.call('HSET', key, saved.ts_key, saved.trim_before)
?
Question 3
Is the purpose of the code
local trim = math.min(saved.trim_before, old_ts + blocks)
at here to limit the number of blocks to trim to be at most blocks?
How would you return an actual timestamp instead of 1 to be used in a Retry-After header?
Replace line 157 (return 1) with the code below. It loops through the blocks of the current duration window to find the earliest block in which a request was made, then calculates when that block will become stale and therefore allow a new request.
-- return 1
local last_attempt
for last_block = saved.trim_before, saved.block_id, precision do
local bcount = redis.call('HGET', key, saved.count_key .. last_block)
if (bcount) then
last_attempt = last_block
break
end
end
local next_attempt
if last_attempt then
next_attempt = (last_attempt + blocks) * precision
else
next_attempt = 0
end
return next_attempt
Note: The next_attempt value returned is a UNIX timestamp in seconds, not milliseconds.
@josiahcarlson Please review this code for any improvement or bug
It would be nice if over_limit_sliding_window_lua returned which limit was in effect, which is useful for taking different actions on different limits ("require captcha for this limit, reject on that limit"). For this you can just return i instead of 1.