# frozen_string_literal: true

# Redis-backed storage for RateLimiter. Each limit is a counter stored under a
# key produced by `global_key` / `route_key` (defined outside this file); the
# counter is created with a TTL equal to the rate-limit window.
class RateLimiter::Redis < RateLimiter
  # Returns the current request count for a global limit as an Integer, or nil
  # when no counter is stored for this name/ip.
  def get_global(name, ip)
    read_count(global_key(name, ip))
  end

  # Returns the current request count for a per-route limit as an Integer, or
  # nil when no counter is stored for this name/domain/path/ip.
  def get_route(name, domain, path, ip)
    read_count(route_key(name, domain, path, ip))
  end

  # Returns the remaining lifetime in milliseconds of the global counter, or
  # nil when Redis reports a negative PTTL (key missing or without expiry).
  def get_global_ttl(name, ip)
    read_ttl(global_key(name, ip))
  end

  # Returns the remaining lifetime in milliseconds of the per-route counter, or
  # nil when Redis reports a negative PTTL (key missing or without expiry).
  def get_route_ttl(name, domain, path, ip)
    read_ttl(route_key(name, domain, path, ip))
  end

  # Counts one request against the global limit and returns the resulting
  # count (see #consume for the value returned once the limit is exceeded).
  def consume_global(name, window, limit, ip)
    consume(global_key(name, ip), window, limit)
  end

  # Counts one request against a per-route limit and returns the resulting
  # count (see #consume for the value returned once the limit is exceeded).
  def consume_route(name, domain, path, window, limit, ip)
    consume(route_key(name, domain, path, ip), window, limit)
  end

  # Repairs a counter that has no expiry. `expiry` is expected to be the
  # key's PTTL; only the value -1 (key exists but never expires) triggers
  # an action.
  #
  # When `window` (milliseconds) is given, the key is set to expire after
  # `window`. Without it, the legacy behaviour is kept: PEXPIRE is sent with
  # the negative `expiry`, which Redis treats as deleting the key.
  #
  # Returns the result of PEXPIRE, or nil when nothing was done.
  def fix_infinite_expiry(key, _val, expiry, window = nil)
    return unless expiry.to_i == -1

    Cache.redis.pexpire(key, window || expiry)
  end

  private

  # Reads a counter, returning it as an Integer or nil when the key is absent.
  def read_count(key)
    Cache.redis.get(key)&.to_i
  end

  # Reads a key's PTTL in milliseconds; negative replies become nil.
  def read_ttl(key)
    ttl = Cache.redis.pttl(key)
    return nil if ttl.nil? || ttl < 0

    ttl
  end

  # Counts one request against `key` and returns the new count. Once the
  # stored count is already above `limit`, the counter is left unchanged and
  # `limit + 1` is returned instead.
  #
  # `window` is the counter lifetime in milliseconds (passed as PX).
  def consume(key, window, limit)
    # Create the counter with its window TTL only if it does not exist yet.
    # SET ... NX is atomic, unlike a separate EXISTS + SET, where concurrent
    # first requests could reset each other's increments. It also avoids
    # depending on the client's EXISTS return type (boolean vs Integer).
    Cache.redis.set(key, "0", px: window, nx: true)

    val = Cache.redis.get(key).to_i

    # If the counter has no TTL (e.g. INCR recreated it after it expired
    # between calls), give it a fresh window instead of deleting it, so it
    # neither blocks forever nor resets on every request.
    fix_infinite_expiry(key, val, Cache.redis.pttl(key), window)

    # don't increase if the request will be rejected
    val > limit ? limit + 1 : Cache.redis.incr(key)
  end
end