From 7d59bcab720faee9c32e719a584940713e7864c3 Mon Sep 17 00:00:00 2001 From: niphlod Date: Mon, 13 Jan 2014 23:35:42 +0100 Subject: [PATCH] first steps to a cleaner DAL and rname integration --- gluon/dal.py | 382 +++++++++++++++++++++------------------- gluon/tests/test_dal.py | 84 +++++---- 2 files changed, 242 insertions(+), 224 deletions(-) diff --git a/gluon/dal.py b/gluon/dal.py index 6fdbe98e..4181e596 100644 --- a/gluon/dal.py +++ b/gluon/dal.py @@ -1181,7 +1181,7 @@ class BaseAdapter(ConnectionPool): "'%(table)s', '%(field)s');" % dict(schema=schema, table=tablename, field=key,) ] elif self.dbengine in ('firebird',): - query = ['ALTER TABLE %s DROP %s;' % + query = ['ALTER TABLE %s DROP %s;' % (self.QUOTE_TEMPLATE % tablename, self.QUOTE_TEMPLATE % key)] else: query = ['ALTER TABLE %s DROP COLUMN %s;' % @@ -1929,7 +1929,7 @@ class BaseAdapter(ConnectionPool): def create_sequence_and_triggers(self, query, table, **args): self.execute(query) - + def log_execute(self, *a, **b): if not self.connection: raise ValueError(a[0]) @@ -3292,7 +3292,7 @@ class OracleAdapter(BaseAdapter): class MSSQLAdapter(BaseAdapter): drivers = ('pyodbc',) T_SEP = 'T' - + QUOTE_TEMPLATE = '"%s"' types = { @@ -5159,10 +5159,10 @@ class GoogleDatastoreAdapter(NoSQLAdapter): items.order(tableobj._key) else: items.order('__key__') - items = self.filter(items, tableobj, filter.name, + items = self.filter(items, tableobj, filter.name, filter.op, filter.value) \ if self.use_ndb else \ - items.filter('%s %s' % (filter.name,filter.op), + items.filter('%s %s' % (filter.name,filter.op), filter.value) if not isinstance(items,list): @@ -8611,8 +8611,7 @@ class Table(object): db, tablename, *fields, - **args - ): + **args): """ Initializes the table and performs checking on the provided fields. @@ -8623,20 +8622,27 @@ class Table(object): :raises SyntaxError: when a supplied field is of incorrect type. """ - self._actual = False # set to True by define_table() + self._actual = False # set to True by define_table() self._tablename = tablename - self._ot = None # args.get('rname') + if (not isinstance(tablename, str) or tablename[0] == '_' + or hasattr(DAL, tablename) or '.' in tablename + or REGEX_PYTHON_KEYWORDS.match(tablename) + ): + raise SyntaxError('Field: invalid table name: %s, ' + 'use rname for "funny" names' % tablename) + self._ot = None self._rname = args.get('rname') - self._sequence_name = args.get('sequence_name') or \ - db and db._adapter.sequence_name(self._rname or tablename) - self._trigger_name = args.get('trigger_name') or \ - db and db._adapter.trigger_name(tablename) + self._sequence_name = (args.get('sequence_name') or + db and db._adapter.sequence_name(self._rname + or tablename)) + self._trigger_name = (args.get('trigger_name') or + db and db._adapter.trigger_name(tablename)) self._common_filter = args.get('common_filter') self._format = args.get('format') self._singular = args.get( - 'singular',tablename.replace('_',' ').capitalize()) + 'singular', tablename.replace('_', ' ').capitalize()) self._plural = args.get( - 'plural',pluralize(self._singular.lower()).capitalize()) + 'plural', pluralize(self._singular.lower()).capitalize()) # horrible but for backard compatibility of appamdin: if 'primarykey' in args and args['primarykey'] is not None: self._primarykey = args.get('primarykey') @@ -8650,28 +8656,29 @@ class Table(object): self.add_method = MethodAdder(self) - fieldnames,newfields=set(),[] + fieldnames, newfields=set(), [] _primarykey = getattr(self, '_primarykey', None) if _primarykey is not None: if not isinstance(_primarykey, list): raise SyntaxError( - "primarykey must be a list of fields from table '%s'" \ + "primarykey must be a list of fields from table '%s'" % tablename) - if len(_primarykey)==1: - self._id = [f for f in fields if isinstance(f,Field) \ - and f.name==_primarykey[0]][0] - elif not [f for f in fields if (isinstance(f,Field) and - f.type=='id') or (isinstance(f, dict) and - f.get("type", None)=="id")]: + if len(_primarykey) == 1: + self._id = [f for f in fields if isinstance(f, Field) + and f.name ==_primarykey[0]][0] + elif not [f for f in fields if (isinstance(f, Field) and + f.type == 'id') or (isinstance(f, dict) and + f.get("type", None) == "id")]: field = Field('id', 'id') newfields.append(field) fieldnames.add('id') self._id = field virtual_fields = [] + def include_new(field): newfields.append(field) fieldnames.add(field.name) - if field.type=='id': + if field.type == 'id': self._id = field for field in fields: if isinstance(field, (FieldMethod, FieldVirtual)): @@ -8685,7 +8692,7 @@ class Table(object): elif isinstance(field, Table): table = field for field in table: - if not field.name in fieldnames and not field.type=='id': + if not field.name in fieldnames and not field.type == 'id': t2 = not table._actual and self._tablename include_new(field.clone(point_self_references_to=t2)) elif not isinstance(field, (Field, Table)): @@ -8698,16 +8705,16 @@ class Table(object): self.virtualfields = [] fields = list(fields) - if db and db._adapter.uploads_in_blob==True: - uploadfields = [f.name for f in fields if f.type=='blob'] + if db and db._adapter.uploads_in_blob is True: + uploadfields = [f.name for f in fields if f.type == 'blob'] for field in fields: fn = field.uploadfield if isinstance(field, Field) and field.type == 'upload'\ and fn is True: fn = field.uploadfield = '%s_blob' % field.name - if isinstance(fn,str) and not fn in uploadfields: - fields.append(Field(fn,'blob',default='', - writable=False,readable=False)) + if isinstance(fn, str) and not fn in uploadfields: + fields.append(Field(fn, 'blob', default='', + writable=False, readable=False)) fieldnames_set = set() reserved = dir(Table) + ['fields'] @@ -8715,8 +8722,8 @@ class Table(object): check_reserved = db.check_reserved_keyword else: def check_reserved(field_name): - if field_name in reserved: - raise SyntaxError("field name %s not allowed" % field_name) + if field_name in reserved: + raise SyntaxError("field name %s not allowed" % field_name) for field in fields: field_name = field.name check_reserved(field_name) @@ -8725,8 +8732,8 @@ class Table(object): else: fname_item = field_name if fname_item in fieldnames_set: - raise SyntaxError("duplicate field %s in table %s" \ - % (field_name, tablename)) + raise SyntaxError("duplicate field %s in table %s" % + (field_name, tablename)) else: fieldnames_set.add(fname_item) @@ -8743,7 +8750,8 @@ class Table(object): for k in _primarykey: if k not in self.fields: raise SyntaxError( - "primarykey must be a list of fields from table '%s " % tablename) + "primarykey must be a list of fields from table '%s " % + tablename) else: self[k].notnull = True for field in virtual_fields: @@ -8753,52 +8761,53 @@ class Table(object): def fields(self): return self._fields - def update(self,*args,**kwargs): + def update(self, *args, **kwargs): raise RuntimeError("Syntax Not Supported") def _enable_record_versioning(self, archive_db=None, - archive_name = '%(tablename)s_archive', - is_active = 'is_active', - current_record = 'current_record', - current_record_label = None): + archive_name='%(tablename)s_archive', + is_active='is_active', + current_record='current_record', + current_record_label=None): db = self._db archive_db = archive_db or db archive_name = archive_name % dict(tablename=self._tablename) if archive_name in archive_db.tables(): - return # do not try define the archive if already exists + return # do not try define the archive if already exists fieldnames = self.fields() same_db = archive_db is db field_type = self if same_db else 'bigint' clones = [] for field in self: nfk = same_db or not field.type.startswith('reference') - clones.append(field.clone( - unique=False, type=field.type if nfk else 'bigint')) + clones.append( + field.clone(unique=False, type=field.type if nfk else 'bigint') + ) archive_db.define_table( archive_name, - Field(current_record,field_type,label=current_record_label), - *clones,**dict(format=self._format)) + Field(current_record, field_type, label=current_record_label), + *clones, **dict(format=self._format)) self._before_update.append( - lambda qset,fs,db=archive_db,an=archive_name,cn=current_record: - archive_record(qset,fs,db[an],cn)) + lambda qset, fs, db=archive_db, an=archive_name, cn=current_record: + archive_record(qset, fs, db[an], cn)) if is_active and is_active in fieldnames: self._before_delete.append( lambda qset: qset.update(is_active=False)) newquery = lambda query, t=self, name=self._tablename: \ - reduce(AND,[db[tn].is_active == True + reduce(AND, [db[tn].is_active == True for tn in db._adapter.tables(query) - if tn==name or getattr(db[tn],'_ot',None)==name]) + if tn == name or getattr(db[tn],'_ot',None)==name]) query = self._common_filter if query: newquery = query & newquery self._common_filter = newquery - def _validate(self,**vars): + def _validate(self, **vars): errors = Row() - for key,value in vars.iteritems(): - value,error = self[key].validate(value) + for key, value in vars.iteritems(): + value, error = self[key].validate(value) if error: errors[key] = error return errors @@ -8809,28 +8818,29 @@ class Table(object): self._referenced_by = [] self._references = [] for field in self: - fieldname = field.name + #fieldname = field.name ##FIXME not used ? field_type = field.type - if isinstance(field_type,str) and field_type[:10] == 'reference ': + if isinstance(field_type, str) and field_type[:10] == 'reference ': ref = field_type[10:].strip() if not ref: - SyntaxError('Table: reference to nothing: %s' %ref) + SyntaxError('Table: reference to nothing: %s' % ref) if '.' in ref: - rtablename, throw_it,rfieldname = ref.partition('.') + rtablename, throw_it, rfieldname = ref.partition('.') else: rtablename, rfieldname = ref, None if not rtablename in db: - pr[rtablename] = pr.get(rtablename,[]) + [field] + pr[rtablename] = pr.get(rtablename, []) + [field] continue rtable = db[rtablename] if rfieldname: - if not hasattr(rtable,'_primarykey'): + if not hasattr(rtable, '_primarykey'): raise SyntaxError( 'keyed tables can only reference other keyed tables (for now)') if rfieldname not in rtable.fields: raise SyntaxError( - "invalid field '%s' for referenced table '%s' in table '%s'" \ - % (rfieldname, rtablename, self._tablename)) + "invalid field '%s' for referenced table '%s'" + " in table '%s'" % (rfieldname, rtablename, self._tablename) + ) rfield = rtable[rfieldname] else: rfield = rtable._id @@ -8844,7 +8854,6 @@ class Table(object): for referee in referees: self._referenced_by.append(referee) - def _filter_fields(self, record, id=False): return dict([(k, v) for (k, v) in record.iteritems() if k in self.fields and (self[k].type!='id' or id)]) @@ -8860,8 +8869,9 @@ class Table(object): query = (self[k] == v) else: raise SyntaxError( - 'Field %s is not part of the primary key of %s' % \ - (k,self._tablename)) + 'Field %s is not part of the primary key of %s' % + (k,self._tablename) + ) return query def __getitem__(self, key): @@ -8870,18 +8880,20 @@ class Table(object): elif isinstance(key, dict): """ for keyed table """ query = self._build_query(key) - return self._db(query).select(limitby=(0,1), orderby_on_limitby=False).first() + return self._db(query).select(limitby=(0, 1), orderby_on_limitby=False).first() elif str(key).isdigit() or 'google' in DRIVERS and isinstance(key, Key): - return self._db(self._id == key).select(limitby=(0,1), orderby_on_limitby=False).first() + return self._db(self._id == key).select(limitby=(0, 1), orderby_on_limitby=False).first() elif key: return ogetattr(self, str(key)) def __call__(self, key=DEFAULT, **kwargs): - for_update = kwargs.get('_for_update',False) - if '_for_update' in kwargs: del kwargs['_for_update'] + for_update = kwargs.get('_for_update', False) + if '_for_update' in kwargs: + del kwargs['_for_update'] - orderby = kwargs.get('_orderby',None) - if '_orderby' in kwargs: del kwargs['_orderby'] + orderby = kwargs.get('_orderby', None) + if '_orderby' in kwargs: + del kwargs['_orderby'] if not key is DEFAULT: if isinstance(key, Query): @@ -8915,7 +8927,7 @@ class Table(object): self._db(query).update(**self._filter_fields(value)) else: raise SyntaxError( - 'key must have all fields from primary key: %s'%\ + 'key must have all fields from primary key: %s'% (self._primarykey)) elif str(key).isdigit(): if key == 0: @@ -8946,7 +8958,7 @@ class Table(object): raise SyntaxError('No such record: %s' % key) def __contains__(self,key): - return hasattr(self,key) + return hasattr(self, key) has_key = __contains__ @@ -8960,9 +8972,8 @@ class Table(object): def iteritems(self): return self.__dict__.iteritems() - def __repr__(self): - return '' % (self._tablename,','.join(self.fields())) + return '
' % (self._tablename, ','.join(self.fields())) def __str__(self): if self._ot is not None: @@ -8970,7 +8981,7 @@ class Table(object): if 'Oracle' in str(type(self._db._adapter)): return '%s %s' % (ot, self._tablename) return '%s AS %s' % (ot, self._tablename) - + return self._tablename @property @@ -8987,14 +8998,14 @@ class Table(object): return self._db._adapter.sqlsafe_table(self._tablename, self._ot) - def _drop(self, mode = ''): + def _drop(self, mode=''): return self._db._adapter._drop(self, mode) - def drop(self, mode = ''): + def drop(self, mode=''): return self._db._adapter.drop(self,mode) def _listify(self,fields,update=False): - new_fields = {} # format: new_fields[name] = (field,value) + new_fields = {} # format: new_fields[name] = (field,value) # store all fields passed as input in new_fields for name in fields: @@ -9007,7 +9018,7 @@ class Table(object): value = fields[name] if field.filter_in: value = field.filter_in(value) - new_fields[name] = (field,value) + new_fields[name] = (field, value) # check all fields that should be in the table but are not passed to_compute = [] @@ -9015,18 +9026,18 @@ class Table(object): name = ofield.name if not name in new_fields: # if field is supposed to be computed, compute it! - if ofield.compute: # save those to compute for later - to_compute.append((name,ofield)) + if ofield.compute: # save those to compute for later + to_compute.append((name, ofield)) # if field is required, check its default value elif not update and not ofield.default is None: value = ofield.default fields[name] = value - new_fields[name] = (ofield,value) + new_fields[name] = (ofield, value) # if this is an update, user the update field instead elif update and not ofield.update is None: value = ofield.update fields[name] = value - new_fields[name] = (ofield,value) + new_fields[name] = (ofield, value) # if the field is still not there but it should, error elif not update and ofield.required: raise RuntimeError( @@ -9034,7 +9045,7 @@ class Table(object): # now deal with fields that are supposed to be computed if to_compute: row = Row(fields) - for name,ofield in to_compute: + for name, ofield in to_compute: # try compute it try: row[name] = new_value = ofield.compute(row) @@ -9047,13 +9058,13 @@ class Table(object): def _attempt_upload(self, fields): for field in self: - if field.type=='upload' and field.name in fields: + if field.type == 'upload' and field.name in fields: value = fields[field.name] - if value is not None and not isinstance(value,str): - if hasattr(value,'file') and hasattr(value,'filename'): - new_name = field.store(value.file,filename=value.filename) - elif hasattr(value,'read') and hasattr(value,'name'): - new_name = field.store(value,filename=value.name) + if value is not None and not isinstance(value, str): + if hasattr(value, 'file') and hasattr(value, 'filename'): + new_name = field.store(value.file, filename=value.filename) + elif hasattr(value, 'read') and hasattr(value, 'name'): + new_name = field.store(value, filename=value.name) else: raise RuntimeError("Unable to handle upload") fields[field.name] = new_name @@ -9081,7 +9092,7 @@ class Table(object): [f(fields,ret) for f in self._after_insert] return ret - def validate_and_insert(self,**fields): + def validate_and_insert(self, **fields): response = Row() response.errors = Row() new_fields = copy.copy(fields) @@ -9102,22 +9113,22 @@ class Table(object): response.errors = Row() new_fields = copy.copy(fields) - for key,value in fields.iteritems(): - value,error = self[key].validate(value) + for key, value in fields.iteritems(): + value, error = self[key].validate(value) if error: response.errors[key] = "%s" % error else: new_fields[key] = value if _key is DEFAULT: - record = self(**values) - elif isinstance(_key,dict): + record = self(**values) + elif isinstance(_key, dict): record = self(**_key) else: record = self(_key) if not response.errors and record: - row = self._db(self._id==_key) + row = self._db(self._id ==_key) response.id = row.update(**fields) else: response.id = None @@ -9126,7 +9137,7 @@ class Table(object): def update_or_insert(self, _key=DEFAULT, **values): if _key is DEFAULT: record = self(**values) - elif isinstance(_key,dict): + elif isinstance(_key, dict): record = self(**_key) else: record = self(_key) @@ -9147,10 +9158,10 @@ class Table(object): ret and [[f(item,ret[k]) for k,item in enumerate(items)] for f in self._after_insert] return ret - def _truncate(self, mode = None): + def _truncate(self, mode=None): return self._db._adapter._truncate(self, mode) - def truncate(self, mode = None): + def truncate(self, mode=None): return self._db._adapter.truncate(self, mode) def import_from_csv_file( @@ -9159,7 +9170,7 @@ class Table(object): id_map=None, null='', unique='uuid', - id_offset=None, # id_offset used only when id_map is None + id_offset=None, # id_offset used only when id_map is None *args, **kwargs ): """ @@ -9295,11 +9306,15 @@ class Table(object): id_map_self[long(line[cid])] = new_id def as_dict(self, flat=False, sanitize=True): - table_as_dict = dict(tablename=str(self), fields=[], - sequence_name=self._sequence_name, - trigger_name=self._trigger_name, - common_filter=self._common_filter, format=self._format, - singular=self._singular, plural=self._plural) + table_as_dict = dict( + tablename=str(self), + fields=[], + sequence_name=self._sequence_name, + trigger_name=self._trigger_name, + common_filter=self._common_filter, + format=self._format, + singular=self._singular, + plural=self._plural) for field in self: if (field.readable or field.writable) or (not sanitize): @@ -9326,15 +9341,16 @@ class Table(object): return serializers.yaml(d) def with_alias(self, alias): - return self._db._adapter.alias(self,alias) + return self._db._adapter.alias(self, alias) def on(self, query): - return Expression(self._db,self._db._adapter.ON,self,query) + return Expression(self._db, self._db._adapter.ON, self, query) -def archive_record(qset,fs,archive_table,current_record): + +def archive_record(qset, fs, archive_table, current_record): tablenames = qset.db._adapter.tables(qset.query) - if len(tablenames)!=1: raise RuntimeError("cannot update join") - table = qset.db[tablenames[0]] + if len(tablenames) != 1: + raise RuntimeError("cannot update join") for row in qset.select(): fields = archive_table._filter_fields(row) fields[current_record] = row.id @@ -9342,7 +9358,6 @@ def archive_record(qset,fs,archive_table,current_record): return False - class Expression(object): def __init__( @@ -9399,9 +9414,9 @@ class Expression(object): db = self.db return Expression(db, db._adapter.UPPER, self, None, self.type) - def replace(self,a,b): + def replace(self, a, b): db = self.db - return Expression(db, db._adapter.REPLACE, self, (a,b), self.type) + return Expression(db, db._adapter.REPLACE, self, (a, b), self.type) def year(self): db = self.db @@ -9423,7 +9438,7 @@ class Expression(object): db = self.db return Expression(db, db._adapter.EXTRACT, self, 'minute', 'integer') - def coalesce(self,*others): + def coalesce(self, *others): db = self.db return Expression(db, db._adapter.COALESCE, self, others, self.type) @@ -9452,32 +9467,32 @@ class Expression(object): length = self.len() else: length = '(%s - %s)' % (stop + 1, pos0) - return Expression(db,db._adapter.SUBSTRING, + return Expression(db, db._adapter.SUBSTRING, self, (pos0, length), self.type) def __getitem__(self, i): return self[i:i + 1] def __str__(self): - return self.db._adapter.expand(self,self.type) + return self.db._adapter.expand(self, self.type) def __or__(self, other): # for use in sortby db = self.db - return Expression(db,db._adapter.COMMA,self,other,self.type) + return Expression(db, db._adapter.COMMA, self, other, self.type) def __invert__(self): db = self.db if hasattr(self,'_op') and self.op == db._adapter.INVERT: return self.first - return Expression(db,db._adapter.INVERT,self,type=self.type) + return Expression(db, db._adapter.INVERT, self, type=self.type) def __add__(self, other): db = self.db - return Expression(db,db._adapter.ADD,self,other,self.type) + return Expression(db, db._adapter.ADD, self, other, self.type) def __sub__(self, other): db = self.db - if self.type in ('integer','bigint'): + if self.type in ('integer', 'bigint'): result_type = 'integer' elif self.type in ['date','time','datetime','double','float']: result_type = 'double' @@ -9724,20 +9739,23 @@ class FieldVirtual(object): def __str__(self): return '%s.%s' % (self.tablename, self.name) + class FieldMethod(object): def __init__(self, name, f=None, handler=None): # for backward compatibility (self.name, self.f) = (name, f) if f else ('unknown', name) self.handler = handler + def list_represent(x,r=None): return ', '.join(str(y) for y in x or []) + class Field(Expression): Virtual = FieldVirtual Method = FieldMethod - Lazy = FieldMethod # for backward compatibility + Lazy = FieldMethod # for backward compatibility """ an instance of this class represents a database field @@ -9798,13 +9816,13 @@ class Field(Expression): custom_retrieve=None, custom_retrieve_file_properties=None, custom_delete=None, - filter_in = None, - filter_out = None, - custom_qualifier = None, - map_none = None, - rname = None + filter_in=None, + filter_out=None, + custom_qualifier=None, + map_none=None, + rname=None ): - self._db = self.db = None # both for backward compatibility + self._db = self.db = None # both for backward compatibility self.op = None self.first = None self.second = None @@ -9815,16 +9833,18 @@ class Field(Expression): raise SyntaxError('Field: invalid unicode field name') self.name = fieldname = cleanup(fieldname) if not isinstance(fieldname, str) or hasattr(Table, fieldname) or \ - fieldname[0] == '_' or REGEX_PYTHON_KEYWORDS.match(fieldname): - raise SyntaxError('Field: invalid field name: %s' % fieldname) + fieldname[0] == '_' or '.' in fieldname or \ + REGEX_PYTHON_KEYWORDS.match(fieldname): + raise SyntaxError('Field: invalid field name: %s, ' + 'use rname for "funny" names' % fieldname) - if not isinstance(type, (Table,Field)): + if not isinstance(type, (Table, Field)): self.type = type else: self.type = 'reference %s' % type - self.length = length if not length is None else DEFAULTLENGTH.get(self.type,512) - self.default = default if default!=DEFAULT else (update or None) + self.length = length if not length is None else DEFAULTLENGTH.get(self.type, 512) + self.default = default if default != DEFAULT else (update or None) self.required = required # is this field required self.ondelete = ondelete.upper() # this is for reference fields only self.notnull = notnull @@ -9840,8 +9860,8 @@ class Field(Expression): self.update = update self.authorize = authorize self.autodelete = autodelete - self.represent = list_represent if \ - represent==None and type in ('list:integer','list:string') else represent + self.represent = (list_represent if represent is None and + type in ('list:integer', 'list:string') else represent) self.compute = compute self.isattachment = True self.custom_store = custom_store @@ -9851,15 +9871,16 @@ class Field(Expression): self.filter_in = filter_in self.filter_out = filter_out self.custom_qualifier = custom_qualifier - self.label = label if label!=None else fieldname.replace('_',' ').title() - self.requires = requires if requires!=None else [] + self.label = (label if label is not None else + fieldname.replace('_', ' ').title()) + self.requires = requires if requires is not None else [] self.map_none = map_none self._rname = rname - def set_attributes(self,*args,**attributes): - self.__dict__.update(*args,**attributes) + def set_attributes(self, *args, **attributes): + self.__dict__.update(*args, **attributes) - def clone(self,point_self_references_to=False,**args): + def clone(self, point_self_references_to=False, **args): field = copy.copy(self) if point_self_references_to and \ field.type == 'reference %s'+field._tablename: @@ -9869,14 +9890,13 @@ class Field(Expression): def store(self, file, filename=None, path=None): if self.custom_store: - return self.custom_store(file,filename,path) + return self.custom_store(file, filename, path) if isinstance(file, cgi.FieldStorage): filename = filename or file.filename file = file.file elif not filename: filename = file.name - filename = os.path.basename(filename.replace('/', os.sep)\ - .replace('\\', os.sep)) + filename = os.path.basename(filename.replace('/', os.sep).replace('\\', os.sep)) m = REGEX_STORE_PATTERN.search(filename) extension = m and m.group('e') or 'txt' uuid_key = web2py_uuid().replace('-', '')[-16:] @@ -9885,12 +9905,12 @@ class Field(Expression): (self._tablename, self.name, uuid_key, encoded_filename) newfilename = newfilename[:(self.length - 1 - len(extension))] + '.' + extension self_uploadfield = self.uploadfield - if isinstance(self_uploadfield,Field): + if isinstance(self_uploadfield, Field): blob_uploadfield_name = self_uploadfield.uploadfield - keys={self_uploadfield.name: newfilename, - blob_uploadfield_name: file.read()} + keys = {self_uploadfield.name: newfilename, + blob_uploadfield_name: file.read()} self_uploadfield.table.insert(**keys) - elif self_uploadfield == True: + elif self_uploadfield is True: if path: pass elif self.uploadfolder: @@ -9903,8 +9923,9 @@ class Field(Expression): if self.uploadseparate: if self.uploadfs: raise RuntimeError("not supported") - path = pjoin(path,"%s.%s" %(self._tablename, self.name), - uuid_key[:2]) + path = pjoin(path, "%s.%s" % ( + self._tablename, self.name), uuid_key[:2] + ) if not exists(path): os.makedirs(path) pathfilename = pjoin(path, newfilename) @@ -9916,7 +9937,8 @@ class Field(Expression): shutil.copyfileobj(file, dest_file) except IOError: raise IOError( - 'Unable to store file "%s" because invalid permissions, readonly file system, or filename too long' % pathfilename) + 'Unable to store file "%s" because invalid permissions, ' + 'readonly file system, or filename too long' % pathfilename) dest_file.close() return newfilename @@ -9935,11 +9957,11 @@ class Field(Expression): raise http.HTTP(404) if self.authorize and not self.authorize(row): raise http.HTTP(403) - file_properties = self.retrieve_file_properties(name,path) + file_properties = self.retrieve_file_properties(name, path) filename = file_properties['filename'] if isinstance(self_uploadfield, str): # ## if file is in DB stream = StringIO.StringIO(row[self_uploadfield] or '') - elif isinstance(self_uploadfield,Field): + elif isinstance(self_uploadfield, Field): blob_uploadfield_name = self_uploadfield.uploadfield query = self_uploadfield == name data = self_uploadfield.table(query)[blob_uploadfield_name] @@ -9951,10 +9973,10 @@ class Field(Expression): # ## if file is on regular filesystem # this is intentially a sting with filename and not a stream # this propagates and allows stream_file_or_304_or_206 to be called - fullname = pjoin(file_properties['path'],name) + fullname = pjoin(file_properties['path'], name) if nameonly: return (filename, fullname) - stream = open(fullname,'rb') + stream = open(fullname, 'rb') return (filename, stream) def retrieve_file_properties(self, name, path=None): @@ -9974,7 +9996,7 @@ class Field(Expression): filename = name # ## if file is in DB if isinstance(self_uploadfield, (str, Field)): - return dict(path=None,filename=filename) + return dict(path=None, filename=filename) # ## if file is on filesystem if not path: if self.uploadfolder: @@ -9985,9 +10007,8 @@ class Field(Expression): t = m.group('table') f = m.group('field') u = m.group('uuidkey') - path = pjoin(path,"%s.%s" % (t,f),u[:2]) - return dict(path=path,filename=filename) - + path = pjoin(path, "%s.%s" % (t, f), u[:2]) + return dict(path=path, filename=filename) def formatter(self, value): requires = self.requires @@ -10007,7 +10028,7 @@ class Field(Expression): def validate(self, value): if not self.requires or self.requires == DEFAULT: - return ((value if value!=self.map_none else None), None) + return ((value if value != self.map_none else None), None) requires = self.requires if not isinstance(requires, (list, tuple)): requires = [requires] @@ -10015,27 +10036,27 @@ class Field(Expression): (value, error) = validator(value) if error: return (value, error) - return ((value if value!=self.map_none else None), None) + return ((value if value != self.map_none else None), None) def count(self, distinct=None): return Expression(self.db, self.db._adapter.COUNT, self, distinct, 'integer') def as_dict(self, flat=False, sanitize=True): - attrs = ('name', 'authorize', 'represent', 'ondelete', - 'custom_store', 'autodelete', 'custom_retrieve', - 'filter_out', 'uploadseparate', 'widget', 'uploadfs', - 'update', 'custom_delete', 'uploadfield', 'uploadfolder', - 'custom_qualifier', 'unique', 'writable', 'compute', - 'map_none', 'default', 'type', 'required', 'readable', - 'requires', 'comment', 'label', 'length', 'notnull', - 'custom_retrieve_file_properties', 'filter_in') + attrs = ( + 'name', 'authorize', 'represent', 'ondelete', + 'custom_store', 'autodelete', 'custom_retrieve', + 'filter_out', 'uploadseparate', 'widget', 'uploadfs', + 'update', 'custom_delete', 'uploadfield', 'uploadfolder', + 'custom_qualifier', 'unique', 'writable', 'compute', + 'map_none', 'default', 'type', 'required', 'readable', + 'requires', 'comment', 'label', 'length', 'notnull', + 'custom_retrieve_file_properties', 'filter_in') serializable = (int, long, basestring, float, tuple, bool, type(None)) def flatten(obj): if isinstance(obj, dict): - return dict((flatten(k), flatten(v)) for k, v in - obj.items()) + return dict((flatten(k), flatten(v)) for k, v in obj.items()) elif isinstance(obj, (tuple, list, set)): return [flatten(v) for v in obj] elif isinstance(obj, serializable): @@ -10049,10 +10070,10 @@ class Field(Expression): d = dict() if not (sanitize and not (self.readable or self.writable)): for attr in attrs: - if flat: - d.update({attr: flatten(getattr(self, attr))}) - else: - d.update({attr: getattr(self, attr)}) + if flat: + d.update({attr: flatten(getattr(self, attr))}) + else: + d.update({attr: getattr(self, attr)}) d["fieldname"] = d.pop("name") return d @@ -10098,7 +10119,7 @@ class Field(Expression): @property def sqlsafe_name(self): return self._rname or self._db._adapter.sqlsafe_field(self.name) - + class Query(object): @@ -10120,7 +10141,7 @@ class Query(object): op, first=None, second=None, - ignore_common_filters = False, + ignore_common_filters=False, **optional_args ): self.db = self._db = db @@ -10177,6 +10198,7 @@ class Query(object): SERIALIZABLE_TYPES = (tuple, dict, set, list, int, long, float, basestring, type(None), bool) + def loop(d): newd = dict() for k, v in d.items(): @@ -10199,7 +10221,7 @@ class Query(object): newd[k] = v.__name__ elif isinstance(v, basestring): newd[k] = v - else: pass # not callable or string + else: pass # not callable or string elif isinstance(v, SERIALIZABLE_TYPES): if isinstance(v, dict): newd[k] = loop(v) @@ -10210,7 +10232,6 @@ class Query(object): return loop(self.__dict__) else: return self.__dict__ - def as_xml(self, sanitize=True): if have_serializers: xml = serializers.xml @@ -10227,6 +10248,7 @@ class Query(object): d = self.as_dict(flat=True, sanitize=sanitize) return json(d) + def xorify(orderby): if not orderby: return None @@ -10235,10 +10257,12 @@ def xorify(orderby): orderby2 = orderby2 | item return orderby2 + def use_common_filters(query): return (query and hasattr(query,'ignore_common_filters') and \ not query.ignore_common_filters) + class Set(object): """ @@ -10258,7 +10282,7 @@ class Set(object): def __init__(self, db, query, ignore_common_filters = None): self.db = db - self._db = db # for backward compatibility + self._db = db # for backward compatibility self.dquery = None # if query is a dict, parse it @@ -10850,7 +10874,6 @@ class Rows(object): "represent" attributes will be transformed). """ - if i is None: return (self.render(i, fields=fields) for i in range(len(self))) import sqlhtml @@ -10889,7 +10912,6 @@ class Rows(object): self.compact = compact return items - def as_dict(self, key='id', compact=True, @@ -10929,7 +10951,6 @@ class Rows(object): else: return dict([(key(r),r) for r in rows]) - def as_trees(self, parent_name='parent_id', children_name='children'): roots = [] drows = {} @@ -10964,6 +10985,7 @@ class Rows(object): represent = kwargs.get('represent', False) writer = csv.writer(ofile, delimiter=delimiter, quotechar=quotechar, quoting=quoting) + def unquote_colnames(colnames): unq_colnames = [] for col in colnames: diff --git a/gluon/tests/test_dal.py b/gluon/tests/test_dal.py index 0ca8b585..0aaae6f6 100644 --- a/gluon/tests/test_dal.py +++ b/gluon/tests/test_dal.py @@ -79,9 +79,6 @@ def tearDownModule(): class TestFields(unittest.TestCase): def testFieldName(self): - return - - # Any table name is supported as long as underlying db does. The following code is ignored. # Check that Fields cannot start with underscores self.assertRaises(SyntaxError, Field, '_abc', 'string') @@ -201,6 +198,25 @@ class TestFields(unittest.TestCase): db.tt.drop() +class TestTables(unittest.TestCase): + + def testTableNames(self): + + # Check that Tables cannot start with underscores + self.assertRaises(SyntaxError, Table, None, '_abc') + + # Check that Tables cannot contain punctuation other than underscores + self.assertRaises(SyntaxError, Table, None, 'a.bc') + + # Check that Tables cannot be a name of a method or property of DAL + for x in ['define_table', 'tables', 'as_dict']: + self.assertRaises(SyntaxError, Table, None, x) + + # Check that Table allows underscores in the body of a field name. + self.assert_(Table(None, 'a_bc'), + "Table isn't allowing underscores in tablename. It should.") + + class TestAll(unittest.TestCase): def setUp(self): @@ -1396,45 +1412,7 @@ class TestRNameFields(unittest.TestCase): self.assertEqual(len(db.person._referenced_by),0) db.person.drop() - class TestQuoting(unittest.TestCase): - # tests for complex table names - def testRun(self): - return - db = DAL(DEFAULT_URI, check_reserved=['all']) - - t0 = db.define_table('A.table.with.dots and spaces', - Field('f', 'string')) - t1 = db.define_table('A.table', - Field('f_other', t0), - Field('words', 'text')) - - blather = 'blah blah and so' - t0[0] = {'f': 'content'} - t1[0] = {'f_other': int(t0[1]['id']), - 'words': blather} - - - r = db(t1['f_other']==t0.id).select() - self.assertEqual(r[0][db['A.table']].words, blather) - - db.define_table('t0', Field('f0')) - db.define_table('t1', Field('f1'), Field('t0', db['t0'])) - db.t0[0]=dict(f0=3) - db.t1[0]=dict(f1=3, t0=1) - - rows=db(db.t0.id==db.t1.t0).select() - self.assertEqual(rows[0].t1.t0, rows[0].t0.id) - if DEFAULT_URI.startswith('mssql'): - #there's no drop cascade in mssql - t1.drop() - t0.drop() - else: - t0.drop('cascade') - t1.drop() - - db.t1.drop() - db.t0.drop() # tests for case sensitivity def testCase(self): @@ -1444,8 +1422,8 @@ class TestQuoting(unittest.TestCase): #multiple cascade gotcha for key in ['reference','reference FK']: db._adapter.types[key]=db._adapter.types[key].replace( - '%(on_delete_action)s','NO ACTION') - + '%(on_delete_action)s','NO ACTION') + # test table case t0 = db.define_table('B', Field('f', 'string')) @@ -1493,7 +1471,7 @@ class TestQuoting(unittest.TestCase): t0.drop() def testPKFK(self): - + # test primary keys db = DAL(DEFAULT_URI, check_reserved=['all'], ignore_field_case=False) @@ -1528,6 +1506,24 @@ class TestQuoting(unittest.TestCase): t3.drop() t4.drop() + +class TestTableAndFieldCase(unittest.TestCase): + """ + at the Python level we should not allow db.C and db.c because of .table conflicts on windows + but it should be possible to map two different names into distinct tables "c" and "C" at the Python level + By default Python models names should be mapped into lower case table names and assume case insensitivity. + """ + def testme(self): + return + + +class TestQuotesByDefault(unittest.TestCase): + """ + all default tables names should be quoted unless an explicit mapping has been given for a table. + """ + def testme(self): + return + if __name__ == '__main__': unittest.main() tearDownModule()