add an optional lock mechanism to redis sessions

This commit is contained in:
Ricardo Pedroso
2013-05-18 20:26:34 +01:00
parent cc3f045efa
commit 693c5cd5e4

View File

@@ -42,8 +42,10 @@ class RedisClient(object):
meta_storage = {}
MAX_RETRIES = 5
RETRIES = 0
_release_script = None
def __init__(self, server='localhost:6379', db=None, debug=False, session_expiry=False):
def __init__(self, server='localhost:6379', db=None, debug=False,
session_expiry=False, with_lock=False):
"""session_expiry can be an integer, in seconds, to set the default expiration
of sessions. The corresponding record will be deleted from the redis instance,
and there's virtually no need to run sessions2trash.py
@@ -58,8 +60,12 @@ class RedisClient(object):
else:
self.app = ''
self.r_server = redis.Redis(host=host, port=port, db=self.db)
if with_lock:
RedisClient._release_script = \
self.r_server.register_script(_LUA_RELEASE_LOCK)
self.tablename = None
self.session_expiry = session_expiry
self.with_lock = with_lock
def get(self, what, default):
return self.tablename
@@ -71,7 +77,8 @@ class RedisClient(object):
def define_table(self, tablename, *fields, **args):
if not self.tablename:
self.tablename = MockTable(
self, self.r_server, tablename, self.session_expiry)
self, self.r_server, tablename, self.session_expiry,
self.with_lock)
return self.tablename
def __getitem__(self, key):
@@ -88,7 +95,7 @@ class RedisClient(object):
class MockTable(object):
def __init__(self, db, r_server, tablename, session_expiry):
def __init__(self, db, r_server, tablename, session_expiry, with_lock=False):
self.db = db
self.r_server = r_server
self.tablename = tablename
@@ -101,15 +108,14 @@ class MockTable(object):
self.id_idx = "%s:id_idx" % self.keyprefix
#remember the session_expiry setting
self.session_expiry = session_expiry
def getserial(self):
#return an auto-increment id
return "%s" % self.r_server.incr(self.serial, 1)
self.with_lock = with_lock
def __getattr__(self, key):
if key == 'id':
#return a fake query. We need to query it just by id for normal operations
self.query = MockQuery(field='id', db=self.r_server, prefix=self.keyprefix, session_expiry=self.session_expiry)
self.query = MockQuery(field='id', db=self.r_server,
prefix=self.keyprefix, session_expiry=self.session_expiry,
with_lock=self.with_lock)
return self.query
elif key == '_db':
#needed because of the calls in sessions2trash.py and globals.py
@@ -120,14 +126,21 @@ class MockTable(object):
#'locked', 'client_ip','created_datetime','modified_datetime'
#'unique_key', 'session_data'
#retrieve a new key
newid = self.getserial()
key = "%s:%s" % (self.keyprefix, newid)
#add it to the index
self.r_server.sadd(self.id_idx, key)
#set a hash key with the Storage
self.r_server.hmset(key, kwargs)
if self.session_expiry:
self.r_server.expire(key, self.session_expiry)
newid = str(self.r_server.incr(self.serial))
key = self.keyprefix + ':' + newid
if self.with_lock:
key_lock = key + ':lock'
acquire_lock(self.r_server, key_lock, newid)
with self.r_server.pipeline() as pipe:
#add it to the index
pipe.sadd(self.id_idx, key)
#set a hash key with the Storage
pipe.hmset(key, kwargs)
if self.session_expiry:
pipe.expire(key, self.session_expiry)
pipe.execute()
if self.with_lock:
release_lock(self.r_server, key_lock, newid)
return newid
@@ -135,13 +148,15 @@ class MockQuery(object):
"""a fake Query object that supports querying by id
and listing all keys. No other operation is supported
"""
def __init__(self, field=None, db=None, prefix=None, session_expiry=False):
def __init__(self, field=None, db=None, prefix=None, session_expiry=False,
with_lock=False):
self.field = field
self.value = None
self.db = db
self.keyprefix = prefix
self.op = None
self.session_expiry = session_expiry
self.with_lock = with_lock
def __eq__(self, value, op='eq'):
self.value = value
@@ -154,7 +169,10 @@ class MockQuery(object):
def select(self):
if self.op == 'eq' and self.field == 'id' and self.value:
#means that someone wants to retrieve the key self.value
rtn = self.db.hgetall("%s:%s" % (self.keyprefix, self.value))
key = self.keyprefix + ':' + self.value
if self.with_lock:
acquire_lock(self.db, key + ':lock', self.value)
rtn = self.db.hgetall(key)
return [Storage(rtn)] if rtn else []
elif self.op == 'ge' and self.field == 'id' and self.value == 0:
#means that someone wants the complete list
@@ -182,9 +200,13 @@ class MockQuery(object):
#means that the session has been found and needs an update
if self.op == 'eq' and self.field == 'id' and self.value:
key = "%s:%s" % (self.keyprefix, self.value)
rtn = self.db.hmset(key, kwargs)
if self.session_expiry:
self.db.expire(key, self.session_expiry)
with self.db.pipeline() as pipe:
pipe.hmset(key, kwargs)
if self.session_expiry:
pipe.expire(key, self.session_expiry)
rtn = pipe.execute()[0]
if self.with_lock:
release_lock(self.db, key + ':lock', self.value)
return rtn
@@ -200,3 +222,24 @@ class RecordDeleter(object):
self.db.srem(id_idx, self.key)
#remove the key itself
self.db.delete(self.key)
def acquire_lock(conn, lockname, identifier, ltime=10):
while True:
if conn.set(lockname, identifier, ex=ltime, nx=True):
return identifier
time.sleep(.01)
_LUA_RELEASE_LOCK = """
if redis.call("get", KEYS[1]) == ARGV[1]
then
return redis.call("del", KEYS[1])
else
return 0
end
"""
def release_lock(conn, lockname, identifier):
return RedisClient._release_script(keys=[lockname], args=[identifier])