db.define_table(...,common_filter) and db(...,ignore_common_filter=True), thanks Yair Eshel

This commit is contained in:
Massimo Di Pierro
2011-12-04 17:45:28 -06:00
parent cf413e568f
commit d92dec8451
5 changed files with 68 additions and 38 deletions
+1 -1
View File
@@ -1 +1 @@
Version 1.99.3 (2011-12-04 16:24:53) dev
Version 1.99.3 (2011-12-04 17:45:21) dev
+5 -5
View File
@@ -150,7 +150,7 @@ def csv():
return None
response.headers['Content-disposition'] = 'attachment; filename=%s_%s.csv'\
% tuple(request.vars.query.split('.')[:2])
return str(db(query).select())
return str(db(query).select(ignore_common_filter=True))
def import_csv(table, file):
@@ -225,10 +225,10 @@ def select():
response.flash = T('%s rows deleted', nrows)
nrows = db(query).count()
if orderby:
rows = db(query).select(limitby=(start, stop),
rows = db(query).select(limitby=(start, stop),ignore_common_filter=True,
orderby=eval_in_global_env(orderby))
else:
rows = db(query).select(limitby=(start, stop))
rows = db(query).select(limitby=(start, stop), ignore_common_filter=True)
except Exception, e:
(rows, nrows) = ([], 0)
response.flash = DIV(T('Invalid Query'),PRE(str(e)))
@@ -255,9 +255,9 @@ def update():
if keyed:
key = [f for f in request.vars if f in db[table]._primarykey]
if key:
record = db(db[table][key[0]] == request.vars[key[0]]).select().first()
record = db(db[table][key[0]] == request.vars[key[0]]).select(ignore_common_filter=True).first()
else:
record = db(db[table].id == request.args(2)).select().first()
record = db(db[table].id == request.args(2)).select(ignore_common_filter=True).first()
if not record:
qry = query_by_table_type(table, db)
@@ -150,7 +150,7 @@ def csv():
return None
response.headers['Content-disposition'] = 'attachment; filename=%s_%s.csv'\
% tuple(request.vars.query.split('.')[:2])
return str(db(query).select())
return str(db(query).select(ignore_common_filter=True))
def import_csv(table, file):
@@ -225,10 +225,10 @@ def select():
response.flash = T('%s rows deleted', nrows)
nrows = db(query).count()
if orderby:
rows = db(query).select(limitby=(start, stop),
rows = db(query).select(limitby=(start, stop),ignore_common_filter=True,
orderby=eval_in_global_env(orderby))
else:
rows = db(query).select(limitby=(start, stop))
rows = db(query).select(limitby=(start, stop), ignore_common_filter=True)
except Exception, e:
(rows, nrows) = ([], 0)
response.flash = DIV(T('Invalid Query'),PRE(str(e)))
@@ -255,9 +255,9 @@ def update():
if keyed:
key = [f for f in request.vars if f in db[table]._primarykey]
if key:
record = db(db[table][key[0]] == request.vars[key[0]]).select().first()
record = db(db[table][key[0]] == request.vars[key[0]]).select(ignore_common_filter=True).first()
else:
record = db(db[table].id == request.args(2)).select().first()
record = db(db[table].id == request.args(2)).select(ignore_common_filter=True).first()
if not record:
qry = query_by_table_type(table, db)
+5 -5
View File
@@ -150,7 +150,7 @@ def csv():
return None
response.headers['Content-disposition'] = 'attachment; filename=%s_%s.csv'\
% tuple(request.vars.query.split('.')[:2])
return str(db(query).select())
return str(db(query).select(ignore_common_filter=True))
def import_csv(table, file):
@@ -225,10 +225,10 @@ def select():
response.flash = T('%s rows deleted', nrows)
nrows = db(query).count()
if orderby:
rows = db(query).select(limitby=(start, stop),
rows = db(query).select(limitby=(start, stop),ignore_common_filter=True,
orderby=eval_in_global_env(orderby))
else:
rows = db(query).select(limitby=(start, stop))
rows = db(query).select(limitby=(start, stop), ignore_common_filter=True)
except Exception, e:
(rows, nrows) = ([], 0)
response.flash = DIV(T('Invalid Query'),PRE(str(e)))
@@ -255,9 +255,9 @@ def update():
if keyed:
key = [f for f in request.vars if f in db[table]._primarykey]
if key:
record = db(db[table][key[0]] == request.vars[key[0]]).select().first()
record = db(db[table][key[0]] == request.vars[key[0]]).select(ignore_common_filter=True).first()
else:
record = db(db[table].id == request.args(2)).select().first()
record = db(db[table].id == request.args(2)).select(ignore_common_filter=True).first()
if not record:
qry = query_by_table_type(table, db)
+52 -22
View File
@@ -1077,13 +1077,15 @@ class BaseAdapter(ConnectionPool):
finally:
logfile.close()
def _update(self, tablename, query,fields):
query = self.filter_tenant(query, [tablename])
def _update(self, tablename, query, fields):
if query:
if not query.ignore_common_filters:
query = self.common_filter(query, [tablename])
sql_w = ' WHERE ' + self.expand(query)
else:
sql_w = ''
sql_v = ','.join(['%s=%s' % (field.name, self.expand(value, field.type)) for (field, value) in fields])
sql_v = ','.join(['%s=%s' % (field.name, self.expand(value, field.type)) \
for (field, value) in fields])
return 'UPDATE %s SET %s%s;' % (tablename, sql_v, sql_w)
def update(self, tablename, query, fields):
@@ -1095,8 +1097,9 @@ class BaseAdapter(ConnectionPool):
return None
def _delete(self, tablename, query):
query = self.filter_tenant(query, [tablename])
if query:
if not query.ignore_common_filters:
query = self.common_filter(query, [tablename])
sql_w = ' WHERE ' + self.expand(query)
else:
sql_w = ''
@@ -1136,7 +1139,8 @@ class BaseAdapter(ConnectionPool):
def _select(self, query, fields, attributes):
for key in set(attributes.keys())-set(('orderby', 'groupby', 'limitby',
'required', 'cache', 'left',
'distinct', 'having', 'join', 'for_update')):
'distinct', 'having', 'join',
'for_update')):
raise SyntaxError, 'invalid select attribute: %s' % key
# ## if no fields specified take them all from the requested tables
new_fields = []
@@ -1146,8 +1150,11 @@ class BaseAdapter(ConnectionPool):
else:
new_fields.append(item)
fields = new_fields
tablenames = self.tables(query)
query = self.filter_tenant(query,tablenames)
tablenames = self.tables(query)
if not query.ignore_common_filters:
query = self.common_filter(query,tablenames)
if not fields:
for table in tablenames:
for field in self.db[table]:
@@ -1281,8 +1288,9 @@ class BaseAdapter(ConnectionPool):
def _count(self, query, distinct=None):
tablenames = self.tables(query)
query = self.filter_tenant(query, tablenames)
if query:
if not query.ignore_common_filters:
query = self.common_filter(query, tablenames)
sql_w = ' WHERE ' + self.expand(query)
else:
sql_w = ''
@@ -1570,14 +1578,21 @@ class BaseAdapter(ConnectionPool):
pass
return rowsobj
def filter_tenant(self, query, tablenames):
fieldname = self.db._request_tenant
def common_filter(self, query, tablenames):
tenant_fieldname = self.db._request_tenant
for tablename in tablenames:
table = self.db[tablename]
if fieldname in table:
default = table[fieldname].default
# deal with user provided filters
if table._common_filter != None:
query = query & table._common_filter(query)
# deal with multi_tenant filters
if tenant_fieldname in table:
default = table[tenant_fieldname].default
if not default is None:
newquery = table[fieldname] == default
newquery = table[tenant_fieldname] == default
if query is None:
query = newquery
else:
@@ -3396,7 +3411,10 @@ class GoogleDatastoreAdapter(NoSQLAdapter):
query = fields[0].table._id>0
else:
raise SyntaxError, "Unable to determine a tablename"
query = self.filter_tenant(query,[tablename])
if not query.ignore_common_filters:
query = self.common_filter(query,[tablename])
tableobj = self.db[tablename]._tableobj
items = tableobj.all()
filters = self.expand(query)
@@ -4571,6 +4589,7 @@ def index():
'plural',
'trigger_name',
'sequence_name',
'common_filter',
'polymodel',
'table_class']:
raise SyntaxError, 'invalid table "%s" attribute: %s' \
@@ -4603,10 +4622,14 @@ def index():
if self._common_fields:
fields = [f for f in fields] + [f for f in self._common_fields]
common_filter = args.get('common_filter', None)
t = self[tablename] = table_class(self, tablename, *fields,
**dict(primarykey=primarykey,
trigger_name=trigger_name,
sequence_name=sequence_name))
sequence_name=sequence_name,
common_filter=common_filter))
# db magic
if self._uri in (None,'None'):
return t
@@ -4654,12 +4677,12 @@ def index():
def smart_query(self,fields,text):
return Set(self, smart_query(fields,text))
def __call__(self, query=None):
def __call__(self, query=None, ignore_common_filters=False):
if isinstance(query,Table):
query = query._id>0
elif isinstance(query,Field):
query = query!=None
return Set(self, query)
return Set(self, query, ignore_common_filters=ignore_common_filters)
def commit(self):
self._adapter.commit()
@@ -4848,7 +4871,7 @@ class Table(dict):
db and db._adapter.sequence_name(tablename)
self._trigger_name = args.get('trigger_name',None) or \
db and db._adapter.trigger_name(tablename)
self._common_filter = args.get('common_filter', None)
primarykey = args.get('primarykey', None)
fieldnames,newfields=set(),[]
if primarykey:
@@ -5754,11 +5777,13 @@ class Query(object):
op,
first=None,
second=None,
ignore_common_filters = False,
):
self.db = self._db = db
self.op = op
self.first = first
self.second = second
self.ignore_common_filters = ignore_common_filters
def __str__(self):
return self.db._adapter.expand(self)
@@ -5804,12 +5829,15 @@ class Set(object):
subset = set(db.users.id<5)
"""
def __init__(self, db, query):
def __init__(self, db, query, ignore_common_filters = False):
self.db = db
self._db = db # for backward compatibility
self.query = query
if query:
### Some care here we shuld copy query but we are not for speed!
query.ignore_common_filters = ignore_common_filters
def __call__(self, query):
def __call__(self, query, ignore_common_filters=False):
if isinstance(query,Table):
query = query._id>0
elif isinstance(query,str):
@@ -5817,9 +5845,11 @@ class Set(object):
elif isinstance(query,Field):
query = query!=None
if self.query:
return Set(self.db, self.query & query)
return Set(self.db, self.query & query,
ignore_common_filters = ignore_common_filters)
else:
return Set(self.db, query)
return Set(self.db, query,
ignore_common_filters = ignore_common_filters)
def _count(self,distinct=None):
return self.db._adapter._count(self.query,distinct)