From d92dec84515aafbcbd952b14298cd633ddd6122c Mon Sep 17 00:00:00 2001 From: Massimo Di Pierro Date: Sun, 4 Dec 2011 17:45:28 -0600 Subject: [PATCH] db.define_table(...,common_filter) and db(...,ignore_common_filter=True), thanks Yair Eshel --- VERSION | 2 +- applications/admin/controllers/appadmin.py | 10 +-- applications/examples/controllers/appadmin.py | 10 +-- applications/welcome/controllers/appadmin.py | 10 +-- gluon/dal.py | 74 +++++++++++++------ 5 files changed, 68 insertions(+), 38 deletions(-) diff --git a/VERSION b/VERSION index 70d77739..57745f85 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -Version 1.99.3 (2011-12-04 16:24:53) dev +Version 1.99.3 (2011-12-04 17:45:21) dev diff --git a/applications/admin/controllers/appadmin.py b/applications/admin/controllers/appadmin.py index d20eb874..a177568b 100644 --- a/applications/admin/controllers/appadmin.py +++ b/applications/admin/controllers/appadmin.py @@ -150,7 +150,7 @@ def csv(): return None response.headers['Content-disposition'] = 'attachment; filename=%s_%s.csv'\ % tuple(request.vars.query.split('.')[:2]) - return str(db(query).select()) + return str(db(query).select(ignore_common_filter=True)) def import_csv(table, file): @@ -225,10 +225,10 @@ def select(): response.flash = T('%s rows deleted', nrows) nrows = db(query).count() if orderby: - rows = db(query).select(limitby=(start, stop), + rows = db(query).select(limitby=(start, stop),ignore_common_filter=True, orderby=eval_in_global_env(orderby)) else: - rows = db(query).select(limitby=(start, stop)) + rows = db(query).select(limitby=(start, stop), ignore_common_filter=True) except Exception, e: (rows, nrows) = ([], 0) response.flash = DIV(T('Invalid Query'),PRE(str(e))) @@ -255,9 +255,9 @@ def update(): if keyed: key = [f for f in request.vars if f in db[table]._primarykey] if key: - record = db(db[table][key[0]] == request.vars[key[0]]).select().first() + record = db(db[table][key[0]] == request.vars[key[0]]).select(ignore_common_filter=True).first() else: - record = db(db[table].id == request.args(2)).select().first() + record = db(db[table].id == request.args(2)).select(ignore_common_filter=True).first() if not record: qry = query_by_table_type(table, db) diff --git a/applications/examples/controllers/appadmin.py b/applications/examples/controllers/appadmin.py index d20eb874..a177568b 100644 --- a/applications/examples/controllers/appadmin.py +++ b/applications/examples/controllers/appadmin.py @@ -150,7 +150,7 @@ def csv(): return None response.headers['Content-disposition'] = 'attachment; filename=%s_%s.csv'\ % tuple(request.vars.query.split('.')[:2]) - return str(db(query).select()) + return str(db(query).select(ignore_common_filter=True)) def import_csv(table, file): @@ -225,10 +225,10 @@ def select(): response.flash = T('%s rows deleted', nrows) nrows = db(query).count() if orderby: - rows = db(query).select(limitby=(start, stop), + rows = db(query).select(limitby=(start, stop),ignore_common_filter=True, orderby=eval_in_global_env(orderby)) else: - rows = db(query).select(limitby=(start, stop)) + rows = db(query).select(limitby=(start, stop), ignore_common_filter=True) except Exception, e: (rows, nrows) = ([], 0) response.flash = DIV(T('Invalid Query'),PRE(str(e))) @@ -255,9 +255,9 @@ def update(): if keyed: key = [f for f in request.vars if f in db[table]._primarykey] if key: - record = db(db[table][key[0]] == request.vars[key[0]]).select().first() + record = db(db[table][key[0]] == request.vars[key[0]]).select(ignore_common_filter=True).first() else: - record = db(db[table].id == request.args(2)).select().first() + record = db(db[table].id == request.args(2)).select(ignore_common_filter=True).first() if not record: qry = query_by_table_type(table, db) diff --git a/applications/welcome/controllers/appadmin.py b/applications/welcome/controllers/appadmin.py index d20eb874..a177568b 100644 --- a/applications/welcome/controllers/appadmin.py +++ b/applications/welcome/controllers/appadmin.py @@ -150,7 +150,7 @@ def csv(): return None response.headers['Content-disposition'] = 'attachment; filename=%s_%s.csv'\ % tuple(request.vars.query.split('.')[:2]) - return str(db(query).select()) + return str(db(query).select(ignore_common_filter=True)) def import_csv(table, file): @@ -225,10 +225,10 @@ def select(): response.flash = T('%s rows deleted', nrows) nrows = db(query).count() if orderby: - rows = db(query).select(limitby=(start, stop), + rows = db(query).select(limitby=(start, stop),ignore_common_filter=True, orderby=eval_in_global_env(orderby)) else: - rows = db(query).select(limitby=(start, stop)) + rows = db(query).select(limitby=(start, stop), ignore_common_filter=True) except Exception, e: (rows, nrows) = ([], 0) response.flash = DIV(T('Invalid Query'),PRE(str(e))) @@ -255,9 +255,9 @@ def update(): if keyed: key = [f for f in request.vars if f in db[table]._primarykey] if key: - record = db(db[table][key[0]] == request.vars[key[0]]).select().first() + record = db(db[table][key[0]] == request.vars[key[0]]).select(ignore_common_filter=True).first() else: - record = db(db[table].id == request.args(2)).select().first() + record = db(db[table].id == request.args(2)).select(ignore_common_filter=True).first() if not record: qry = query_by_table_type(table, db) diff --git a/gluon/dal.py b/gluon/dal.py index 026319dc..280fee45 100644 --- a/gluon/dal.py +++ b/gluon/dal.py @@ -1077,13 +1077,15 @@ class BaseAdapter(ConnectionPool): finally: logfile.close() - def _update(self, tablename, query,fields): - query = self.filter_tenant(query, [tablename]) + def _update(self, tablename, query, fields): if query: + if not query.ignore_common_filters: + query = self.common_filter(query, [tablename]) sql_w = ' WHERE ' + self.expand(query) else: sql_w = '' - sql_v = ','.join(['%s=%s' % (field.name, self.expand(value, field.type)) for (field, value) in fields]) + sql_v = ','.join(['%s=%s' % (field.name, self.expand(value, field.type)) \ + for (field, value) in fields]) return 'UPDATE %s SET %s%s;' % (tablename, sql_v, sql_w) def update(self, tablename, query, fields): @@ -1095,8 +1097,9 @@ class BaseAdapter(ConnectionPool): return None def _delete(self, tablename, query): - query = self.filter_tenant(query, [tablename]) if query: + if not query.ignore_common_filters: + query = self.common_filter(query, [tablename]) sql_w = ' WHERE ' + self.expand(query) else: sql_w = '' @@ -1136,7 +1139,8 @@ class BaseAdapter(ConnectionPool): def _select(self, query, fields, attributes): for key in set(attributes.keys())-set(('orderby', 'groupby', 'limitby', 'required', 'cache', 'left', - 'distinct', 'having', 'join', 'for_update')): + 'distinct', 'having', 'join', + 'for_update')): raise SyntaxError, 'invalid select attribute: %s' % key # ## if no fields specified take them all from the requested tables new_fields = [] @@ -1146,8 +1150,11 @@ class BaseAdapter(ConnectionPool): else: new_fields.append(item) fields = new_fields - tablenames = self.tables(query) - query = self.filter_tenant(query,tablenames) + tablenames = self.tables(query) + + if not query.ignore_common_filters: + query = self.common_filter(query,tablenames) + if not fields: for table in tablenames: for field in self.db[table]: @@ -1281,8 +1288,9 @@ class BaseAdapter(ConnectionPool): def _count(self, query, distinct=None): tablenames = self.tables(query) - query = self.filter_tenant(query, tablenames) if query: + if not query.ignore_common_filters: + query = self.common_filter(query, tablenames) sql_w = ' WHERE ' + self.expand(query) else: sql_w = '' @@ -1570,14 +1578,21 @@ class BaseAdapter(ConnectionPool): pass return rowsobj - def filter_tenant(self, query, tablenames): - fieldname = self.db._request_tenant + def common_filter(self, query, tablenames): + tenant_fieldname = self.db._request_tenant + for tablename in tablenames: table = self.db[tablename] - if fieldname in table: - default = table[fieldname].default + + # deal with user provided filters + if table._common_filter != None: + query = query & table._common_filter(query) + + # deal with multi_tenant filters + if tenant_fieldname in table: + default = table[tenant_fieldname].default if not default is None: - newquery = table[fieldname] == default + newquery = table[tenant_fieldname] == default if query is None: query = newquery else: @@ -3396,7 +3411,10 @@ class GoogleDatastoreAdapter(NoSQLAdapter): query = fields[0].table._id>0 else: raise SyntaxError, "Unable to determine a tablename" - query = self.filter_tenant(query,[tablename]) + + if not query.ignore_common_filters: + query = self.common_filter(query,[tablename]) + tableobj = self.db[tablename]._tableobj items = tableobj.all() filters = self.expand(query) @@ -4571,6 +4589,7 @@ def index(): 'plural', 'trigger_name', 'sequence_name', + 'common_filter', 'polymodel', 'table_class']: raise SyntaxError, 'invalid table "%s" attribute: %s' \ @@ -4603,10 +4622,14 @@ def index(): if self._common_fields: fields = [f for f in fields] + [f for f in self._common_fields] + common_filter = args.get('common_filter', None) + t = self[tablename] = table_class(self, tablename, *fields, **dict(primarykey=primarykey, trigger_name=trigger_name, - sequence_name=sequence_name)) + sequence_name=sequence_name, + common_filter=common_filter)) + # db magic if self._uri in (None,'None'): return t @@ -4654,12 +4677,12 @@ def index(): def smart_query(self,fields,text): return Set(self, smart_query(fields,text)) - def __call__(self, query=None): + def __call__(self, query=None, ignore_common_filters=False): if isinstance(query,Table): query = query._id>0 elif isinstance(query,Field): query = query!=None - return Set(self, query) + return Set(self, query, ignore_common_filters=ignore_common_filters) def commit(self): self._adapter.commit() @@ -4848,7 +4871,7 @@ class Table(dict): db and db._adapter.sequence_name(tablename) self._trigger_name = args.get('trigger_name',None) or \ db and db._adapter.trigger_name(tablename) - + self._common_filter = args.get('common_filter', None) primarykey = args.get('primarykey', None) fieldnames,newfields=set(),[] if primarykey: @@ -5754,11 +5777,13 @@ class Query(object): op, first=None, second=None, + ignore_common_filters = False, ): self.db = self._db = db self.op = op self.first = first self.second = second + self.ignore_common_filters = ignore_common_filters def __str__(self): return self.db._adapter.expand(self) @@ -5804,12 +5829,15 @@ class Set(object): subset = set(db.users.id<5) """ - def __init__(self, db, query): + def __init__(self, db, query, ignore_common_filters = False): self.db = db self._db = db # for backward compatibility self.query = query + if query: + ### Some care here we shuld copy query but we are not for speed! + query.ignore_common_filters = ignore_common_filters - def __call__(self, query): + def __call__(self, query, ignore_common_filters=False): if isinstance(query,Table): query = query._id>0 elif isinstance(query,str): @@ -5817,9 +5845,11 @@ class Set(object): elif isinstance(query,Field): query = query!=None if self.query: - return Set(self.db, self.query & query) + return Set(self.db, self.query & query, + ignore_common_filters = ignore_common_filters) else: - return Set(self.db, query) + return Set(self.db, query, + ignore_common_filters = ignore_common_filters) def _count(self,distinct=None): return self.db._adapter._count(self.query,distinct)