add an optional lock mechanism to redis sessions
This commit is contained in:
@@ -42,8 +42,10 @@ class RedisClient(object):
|
||||
meta_storage = {}
|
||||
MAX_RETRIES = 5
|
||||
RETRIES = 0
|
||||
_release_script = None
|
||||
|
||||
def __init__(self, server='localhost:6379', db=None, debug=False, session_expiry=False):
|
||||
def __init__(self, server='localhost:6379', db=None, debug=False,
|
||||
session_expiry=False, with_lock=False):
|
||||
"""session_expiry can be an integer, in seconds, to set the default expiration
|
||||
of sessions. The corresponding record will be deleted from the redis instance,
|
||||
and there's virtually no need to run sessions2trash.py
|
||||
@@ -58,8 +60,12 @@ class RedisClient(object):
|
||||
else:
|
||||
self.app = ''
|
||||
self.r_server = redis.Redis(host=host, port=port, db=self.db)
|
||||
if with_lock:
|
||||
RedisClient._release_script = \
|
||||
self.r_server.register_script(_LUA_RELEASE_LOCK)
|
||||
self.tablename = None
|
||||
self.session_expiry = session_expiry
|
||||
self.with_lock = with_lock
|
||||
|
||||
def get(self, what, default):
|
||||
return self.tablename
|
||||
@@ -71,7 +77,8 @@ class RedisClient(object):
|
||||
def define_table(self, tablename, *fields, **args):
|
||||
if not self.tablename:
|
||||
self.tablename = MockTable(
|
||||
self, self.r_server, tablename, self.session_expiry)
|
||||
self, self.r_server, tablename, self.session_expiry,
|
||||
self.with_lock)
|
||||
return self.tablename
|
||||
|
||||
def __getitem__(self, key):
|
||||
@@ -88,7 +95,7 @@ class RedisClient(object):
|
||||
|
||||
class MockTable(object):
|
||||
|
||||
def __init__(self, db, r_server, tablename, session_expiry):
|
||||
def __init__(self, db, r_server, tablename, session_expiry, with_lock=False):
|
||||
self.db = db
|
||||
self.r_server = r_server
|
||||
self.tablename = tablename
|
||||
@@ -101,15 +108,14 @@ class MockTable(object):
|
||||
self.id_idx = "%s:id_idx" % self.keyprefix
|
||||
#remember the session_expiry setting
|
||||
self.session_expiry = session_expiry
|
||||
|
||||
def getserial(self):
|
||||
#return an auto-increment id
|
||||
return "%s" % self.r_server.incr(self.serial, 1)
|
||||
self.with_lock = with_lock
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key == 'id':
|
||||
#return a fake query. We need to query it just by id for normal operations
|
||||
self.query = MockQuery(field='id', db=self.r_server, prefix=self.keyprefix, session_expiry=self.session_expiry)
|
||||
self.query = MockQuery(field='id', db=self.r_server,
|
||||
prefix=self.keyprefix, session_expiry=self.session_expiry,
|
||||
with_lock=self.with_lock)
|
||||
return self.query
|
||||
elif key == '_db':
|
||||
#needed because of the calls in sessions2trash.py and globals.py
|
||||
@@ -120,14 +126,21 @@ class MockTable(object):
|
||||
#'locked', 'client_ip','created_datetime','modified_datetime'
|
||||
#'unique_key', 'session_data'
|
||||
#retrieve a new key
|
||||
newid = self.getserial()
|
||||
key = "%s:%s" % (self.keyprefix, newid)
|
||||
#add it to the index
|
||||
self.r_server.sadd(self.id_idx, key)
|
||||
#set a hash key with the Storage
|
||||
self.r_server.hmset(key, kwargs)
|
||||
if self.session_expiry:
|
||||
self.r_server.expire(key, self.session_expiry)
|
||||
newid = str(self.r_server.incr(self.serial))
|
||||
key = self.keyprefix + ':' + newid
|
||||
if self.with_lock:
|
||||
key_lock = key + ':lock'
|
||||
acquire_lock(self.r_server, key_lock, newid)
|
||||
with self.r_server.pipeline() as pipe:
|
||||
#add it to the index
|
||||
pipe.sadd(self.id_idx, key)
|
||||
#set a hash key with the Storage
|
||||
pipe.hmset(key, kwargs)
|
||||
if self.session_expiry:
|
||||
pipe.expire(key, self.session_expiry)
|
||||
pipe.execute()
|
||||
if self.with_lock:
|
||||
release_lock(self.r_server, key_lock, newid)
|
||||
return newid
|
||||
|
||||
|
||||
@@ -135,13 +148,15 @@ class MockQuery(object):
|
||||
"""a fake Query object that supports querying by id
|
||||
and listing all keys. No other operation is supported
|
||||
"""
|
||||
def __init__(self, field=None, db=None, prefix=None, session_expiry=False):
|
||||
def __init__(self, field=None, db=None, prefix=None, session_expiry=False,
|
||||
with_lock=False):
|
||||
self.field = field
|
||||
self.value = None
|
||||
self.db = db
|
||||
self.keyprefix = prefix
|
||||
self.op = None
|
||||
self.session_expiry = session_expiry
|
||||
self.with_lock = with_lock
|
||||
|
||||
def __eq__(self, value, op='eq'):
|
||||
self.value = value
|
||||
@@ -154,7 +169,10 @@ class MockQuery(object):
|
||||
def select(self):
|
||||
if self.op == 'eq' and self.field == 'id' and self.value:
|
||||
#means that someone wants to retrieve the key self.value
|
||||
rtn = self.db.hgetall("%s:%s" % (self.keyprefix, self.value))
|
||||
key = self.keyprefix + ':' + self.value
|
||||
if self.with_lock:
|
||||
acquire_lock(self.db, key + ':lock', self.value)
|
||||
rtn = self.db.hgetall(key)
|
||||
return [Storage(rtn)] if rtn else []
|
||||
elif self.op == 'ge' and self.field == 'id' and self.value == 0:
|
||||
#means that someone wants the complete list
|
||||
@@ -182,9 +200,13 @@ class MockQuery(object):
|
||||
#means that the session has been found and needs an update
|
||||
if self.op == 'eq' and self.field == 'id' and self.value:
|
||||
key = "%s:%s" % (self.keyprefix, self.value)
|
||||
rtn = self.db.hmset(key, kwargs)
|
||||
if self.session_expiry:
|
||||
self.db.expire(key, self.session_expiry)
|
||||
with self.db.pipeline() as pipe:
|
||||
pipe.hmset(key, kwargs)
|
||||
if self.session_expiry:
|
||||
pipe.expire(key, self.session_expiry)
|
||||
rtn = pipe.execute()[0]
|
||||
if self.with_lock:
|
||||
release_lock(self.db, key + ':lock', self.value)
|
||||
return rtn
|
||||
|
||||
|
||||
@@ -200,3 +222,24 @@ class RecordDeleter(object):
|
||||
self.db.srem(id_idx, self.key)
|
||||
#remove the key itself
|
||||
self.db.delete(self.key)
|
||||
|
||||
|
||||
def acquire_lock(conn, lockname, identifier, ltime=10):
|
||||
while True:
|
||||
if conn.set(lockname, identifier, ex=ltime, nx=True):
|
||||
return identifier
|
||||
time.sleep(.01)
|
||||
|
||||
|
||||
_LUA_RELEASE_LOCK = """
|
||||
if redis.call("get", KEYS[1]) == ARGV[1]
|
||||
then
|
||||
return redis.call("del", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
"""
|
||||
|
||||
|
||||
def release_lock(conn, lockname, identifier):
|
||||
return RedisClient._release_script(keys=[lockname], args=[identifier])
|
||||
|
||||
Reference in New Issue
Block a user