Merge pull request #96 from rpedroso/redis-session

lock in redis session
This commit is contained in:
mdipierro
2013-05-19 17:12:53 -07:00

View File

@@ -42,8 +42,10 @@ class RedisClient(object):
meta_storage = {}
MAX_RETRIES = 5
RETRIES = 0
_release_script = None
def __init__(self, server='localhost:6379', db=None, debug=False, session_expiry=False):
def __init__(self, server='localhost:6379', db=None, debug=False,
session_expiry=False, with_lock=False):
"""session_expiry can be an integer, in seconds, to set the default expiration
of sessions. The corresponding record will be deleted from the redis instance,
and there's virtually no need to run sessions2trash.py
@@ -58,8 +60,12 @@ class RedisClient(object):
else:
self.app = ''
self.r_server = redis.Redis(host=host, port=port, db=self.db)
if with_lock:
RedisClient._release_script = \
self.r_server.register_script(_LUA_RELEASE_LOCK)
self.tablename = None
self.session_expiry = session_expiry
self.with_lock = with_lock
def get(self, what, default):
return self.tablename
@@ -71,7 +77,8 @@ class RedisClient(object):
def define_table(self, tablename, *fields, **args):
if not self.tablename:
self.tablename = MockTable(
self, self.r_server, tablename, self.session_expiry)
self, self.r_server, tablename, self.session_expiry,
self.with_lock)
return self.tablename
def __getitem__(self, key):
@@ -88,7 +95,7 @@ class RedisClient(object):
class MockTable(object):
def __init__(self, db, r_server, tablename, session_expiry):
def __init__(self, db, r_server, tablename, session_expiry, with_lock=False):
self.db = db
self.r_server = r_server
self.tablename = tablename
@@ -101,15 +108,14 @@ class MockTable(object):
self.id_idx = "%s:id_idx" % self.keyprefix
#remember the session_expiry setting
self.session_expiry = session_expiry
def getserial(self):
#return an auto-increment id
return "%s" % self.r_server.incr(self.serial, 1)
self.with_lock = with_lock
def __getattr__(self, key):
if key == 'id':
#return a fake query. We need to query it just by id for normal operations
self.query = MockQuery(field='id', db=self.r_server, prefix=self.keyprefix, session_expiry=self.session_expiry)
self.query = MockQuery(field='id', db=self.r_server,
prefix=self.keyprefix, session_expiry=self.session_expiry,
with_lock=self.with_lock)
return self.query
elif key == '_db':
#needed because of the calls in sessions2trash.py and globals.py
@@ -120,14 +126,21 @@ class MockTable(object):
#'locked', 'client_ip','created_datetime','modified_datetime'
#'unique_key', 'session_data'
#retrieve a new key
newid = self.getserial()
key = "%s:%s" % (self.keyprefix, newid)
#add it to the index
self.r_server.sadd(self.id_idx, key)
#set a hash key with the Storage
self.r_server.hmset(key, kwargs)
if self.session_expiry:
self.r_server.expire(key, self.session_expiry)
newid = str(self.r_server.incr(self.serial))
key = self.keyprefix + ':' + newid
if self.with_lock:
key_lock = key + ':lock'
acquire_lock(self.r_server, key_lock, newid)
with self.r_server.pipeline() as pipe:
#add it to the index
pipe.sadd(self.id_idx, key)
#set a hash key with the Storage
pipe.hmset(key, kwargs)
if self.session_expiry:
pipe.expire(key, self.session_expiry)
pipe.execute()
if self.with_lock:
release_lock(self.r_server, key_lock, newid)
return newid
@@ -135,13 +148,15 @@ class MockQuery(object):
"""a fake Query object that supports querying by id
and listing all keys. No other operation is supported
"""
def __init__(self, field=None, db=None, prefix=None, session_expiry=False):
def __init__(self, field=None, db=None, prefix=None, session_expiry=False,
with_lock=False):
self.field = field
self.value = None
self.db = db
self.keyprefix = prefix
self.op = None
self.session_expiry = session_expiry
self.with_lock = with_lock
def __eq__(self, value, op='eq'):
self.value = value
@@ -154,12 +169,11 @@ class MockQuery(object):
def select(self):
if self.op == 'eq' and self.field == 'id' and self.value:
#means that someone wants to retrieve the key self.value
rtn = self.db.hgetall("%s:%s" % (self.keyprefix, self.value))
if rtn == dict():
#return an empty resultset for non existing key
return []
else:
return [Storage(rtn)]
key = self.keyprefix + ':' + self.value
if self.with_lock:
acquire_lock(self.db, key + ':lock', self.value)
rtn = self.db.hgetall(key)
return [Storage(rtn)] if rtn else []
elif self.op == 'ge' and self.field == 'id' and self.value == 0:
#means that someone wants the complete list
rtn = []
@@ -168,13 +182,11 @@ class MockQuery(object):
allkeys = self.db.smembers(id_idx)
for sess in allkeys:
val = self.db.hgetall(sess)
if val == dict():
if not val:
if self.session_expiry:
#clean up the idx, because the key expired
self.db.srem(id_idx, sess)
continue
else:
continue
continue
val = Storage(val)
#add a delete_record method (necessary for sessions2trash.py)
val.delete_record = RecordDeleter(
@@ -188,9 +200,13 @@ class MockQuery(object):
#means that the session has been found and needs an update
if self.op == 'eq' and self.field == 'id' and self.value:
key = "%s:%s" % (self.keyprefix, self.value)
rtn = self.db.hmset(key, kwargs)
if self.session_expiry:
self.db.expire(key, self.session_expiry)
with self.db.pipeline() as pipe:
pipe.hmset(key, kwargs)
if self.session_expiry:
pipe.expire(key, self.session_expiry)
rtn = pipe.execute()[0]
if self.with_lock:
release_lock(self.db, key + ':lock', self.value)
return rtn
@@ -206,3 +222,25 @@ class RecordDeleter(object):
self.db.srem(id_idx, self.key)
#remove the key itself
self.db.delete(self.key)
def acquire_lock(conn, lockname, identifier, ltime=10):
while True:
if conn.set(lockname, identifier, ex=ltime, nx=True):
return identifier
time.sleep(.01)
_LUA_RELEASE_LOCK = """
if redis.call("get", KEYS[1]) == ARGV[1]
then
return redis.call("del", KEYS[1])
else
return 0
end
"""
def release_lock(conn, lockname, identifier):
return RedisClient._release_script(keys=[lockname], args=[identifier],
client=conn)