# -*- coding: utf-8 -*-
import re

from .._load import psycopg2_adapt
from .._globals import IDENTITY, LOGGER
from ..helpers.methods import varquote_aux
from .base import BaseAdapter


def _version_tuple(version, width=3):
    """Return the leading dotted-numeric part of *version* as an int tuple.

    ``'2.5.4 (dt dec pq3 ext)'`` becomes ``(2, 5, 4)`` and ``'10.1'`` becomes
    ``(10, 1, 0)``. The result is zero-padded to *width* items so that tuples
    of different precision compare sensibly. Anything that does not start
    with a number yields all zeros.
    """
    match = re.match(r'\s*(\d+(?:\.\d+)*)', str(version))
    parts = [int(p) for p in match.group(1).split('.')] if match else []
    parts += [0] * (width - len(parts))
    return tuple(parts)


class PostgreSQLAdapter(BaseAdapter):
    """DAL adapter generating PostgreSQL-flavoured SQL (psycopg2 / pg8000)."""

    drivers = ('psycopg2', 'pg8000')

    QUOTE_TEMPLATE = '"%s"'

    support_distributed_transaction = True
    types = {
        'boolean': 'CHAR(1)',
        'string': 'VARCHAR(%(length)s)',
        'text': 'TEXT',
        'json': 'TEXT',
        'password': 'VARCHAR(%(length)s)',
        'blob': 'BYTEA',
        'upload': 'VARCHAR(%(length)s)',
        'integer': 'INTEGER',
        'bigint': 'BIGINT',
        'float': 'FLOAT',
        'double': 'FLOAT8',
        'decimal': 'NUMERIC(%(precision)s,%(scale)s)',
        'date': 'DATE',
        'time': 'TIME',
        'datetime': 'TIMESTAMP',
        'id': 'SERIAL PRIMARY KEY',
        'reference': 'INTEGER REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s',
        'list:integer': 'TEXT',
        'list:string': 'TEXT',
        'list:reference': 'TEXT',
        'geometry': 'GEOMETRY',
        'geography': 'GEOGRAPHY',
        'big-id': 'BIGSERIAL PRIMARY KEY',
        'big-reference': 'BIGINT REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s',
        'reference FK': ', CONSTRAINT "FK_%(constraint_name)s" FOREIGN KEY (%(field_name)s) REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s',
        'reference TFK': ' CONSTRAINT "FK_%(foreign_table)s_PK" FOREIGN KEY (%(field_name)s) REFERENCES %(foreign_table)s (%(foreign_key)s) ON DELETE %(on_delete_action)s',
    }

    def varquote(self, name):
        """Quote *name* with double quotes via ``varquote_aux``."""
        return varquote_aux(name, '"%s"')

    def adapt(self, obj):
        """Return *obj* as a quoted SQL literal, using the active driver's rules."""
        if self.driver_name == 'psycopg2':
            return psycopg2_adapt(obj).getquoted()
        elif self.driver_name == 'pg8000':
            # '%' is doubled as well as the single quote for this driver.
            return "'%s'" % str(obj).replace("%", "%%").replace("'", "''")
        else:
            return "'%s'" % str(obj).replace("'", "''")

    def sequence_name(self, table):
        """Return the quoted ``<table>_id_seq`` sequence name."""
        return self.QUOTE_TEMPLATE % (table + '_id_seq')

    def RANDOM(self):
        return 'RANDOM()'

    def ADD(self, first, second):
        """Render ``first + second``; text-like field types use ``||``."""
        t = first.type
        if t in ('text', 'string', 'password', 'json', 'upload', 'blob'):
            return '(%s || %s)' % (self.expand(first), self.expand(second, t))
        else:
            return '(%s + %s)' % (self.expand(first), self.expand(second, t))

    def distributed_transaction_begin(self, key):
        """No-op: nothing is executed when a distributed transaction begins."""
        return

    def prepare(self, key):
        self.execute("PREPARE TRANSACTION '%s';" % key)

    def commit_prepared(self, key):
        self.execute("COMMIT PREPARED '%s';" % key)

    def rollback_prepared(self, key):
        self.execute("ROLLBACK PREPARED '%s';" % key)

    def create_sequence_and_triggers(self, query, table, **args):
        """Execute the table-creation *query*; no explicit sequence is created."""
        # following lines should only be executed if table._sequence_name does not exist
        # self.execute('CREATE SEQUENCE %s;' % table._sequence_name)
        # self.execute("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT NEXTVAL('%s');" \
        #              % (table._tablename, table._fieldname, table._sequence_name))
        self.execute(query)

    # user[:password]@host[:port]/db[?sslmode=...]
    # The group names are the ones read back with m.group(...) in __init__.
    REGEX_URI = re.compile(r'^(?P<user>[^:@]+)(\:(?P<password>[^@]*))?@(?P<host>[^\:@]+)(\:(?P<port>[0-9]+))?/(?P<db>[^\?]+)(\?sslmode=(?P<sslmode>.+))?$')

    def __init__(self, db, uri, pool_size=0, folder=None, db_codec='UTF-8',
                 credential_decoder=IDENTITY, driver_args={},
                 adapter_args={}, do_connect=True, srid=4326,
                 after_connection=None):
        """Parse *uri*, build the connection string and optionally connect.

        Raises:
            SyntaxError: if the part of *uri* after ``://`` does not match
                ``REGEX_URI`` or lacks a user, host or database name.
        """
        self.db = db
        self.dbengine = "postgres"
        self.uri = uri
        if do_connect:
            self.find_driver(adapter_args, uri)
        self.pool_size = pool_size
        self.folder = folder
        self.db_codec = db_codec
        self._after_connection = after_connection
        self.srid = srid
        self.find_or_make_work_folder()
        self._last_insert = None  # for INSERT ... RETURNING ID

        ruri = uri.split('://', 1)[1]
        m = self.REGEX_URI.match(ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = credential_decoder(m.group('user'))
        if not user:
            raise SyntaxError('User required')
        password = credential_decoder(m.group('password'))
        if not password:
            password = ''
        host = m.group('host')
        if not host:
            raise SyntaxError('Host name required')
        db = m.group('db')
        if not db:
            raise SyntaxError('Database name required')
        port = m.group('port') or '5432'
        sslmode = m.group('sslmode')
        if sslmode:
            msg = ("dbname='%s' user='%s' host='%s' "
                   "port=%s password='%s' sslmode='%s'") \
                   % (db, user, host, port, password, sslmode)
        else:
            msg = ("dbname='%s' user='%s' host='%s' "
                   "port=%s password='%s'") \
                   % (db, user, host, port, password)
        # choose driver according to uri
        if self.driver:
            self.__version__ = "%s %s" % (self.driver.__name__,
                                          self.driver.__version__)
        else:
            self.__version__ = None

        # msg and driver_args are bound as defaults so the closure keeps
        # the values computed here.
        def connector(msg=msg, driver_args=driver_args):
            return self.driver.connect(msg, **driver_args)
        self.connector = connector
        if do_connect:
            self.reconnect()

    def after_connection(self):
        """Per-connection setup: client encoding, string semantics, JSON probe."""
        self.connection.set_client_encoding('UTF8')
        self.execute("SET standard_conforming_strings=on;")
        self.try_json()

    def _insert(self, table, fields):
        """Build the INSERT statement for *fields* (``(field, value)`` pairs).

        Also records in ``self._last_insert`` whether the statement carries a
        ``RETURNING`` clause, which ``lastrowid`` relies on.
        """
        table_rname = table.sqlsafe
        if fields:
            keys = ','.join(f.sqlsafe_name for f, v in fields)
            values = ','.join(self.expand(v, f.type) for f, v in fields)
            if table._id:
                self._last_insert = (table._id, 1)
                return 'INSERT INTO %s(%s) VALUES (%s) RETURNING %s;' % (
                    table_rname, keys, values, table._id.name)
            else:
                self._last_insert = None
                return 'INSERT INTO %s(%s) VALUES (%s);' % (
                    table_rname, keys, values)
        else:
            # Reset the marker so a value left over from an earlier
            # INSERT ... RETURNING cannot steer lastrowid() down the
            # "read the RETURNING row" path for this statement.
            self._last_insert = None
            return self._insert_empty(table)

    def lastrowid(self, table=None):
        """Return the id generated by the most recent insert as an int."""
        if self._last_insert:
            # The INSERT had a RETURNING clause; its row is on the cursor.
            return int(self.cursor.fetchone()[0])
        else:
            self.execute("select lastval()")
            return int(self.cursor.fetchone()[0])

    def try_json(self):
        """Switch the 'json' field type to native JSON when it is supported.

        Versions are compared as integer tuples rather than strings, so that
        e.g. server "10.1" is correctly seen as newer than "9.2.0".
        """
        # check JSON data type support
        # (to be added to after_connection)
        # until pg8000 supports json, leave this commented
        #if self.driver_name == "pg8000":
        #    supports_json = self.connection.server_version >= "9.2.0"
        if (self.driver_name == "psycopg2" and
                _version_tuple(self.driver.__version__) >= (2, 0, 12)):
            supports_json = self.connection.server_version >= 90200
        elif self.driver_name == "zxJDBC":
            supports_json = \
                _version_tuple(self.connection.dbversion) >= (9, 2, 0)
        else:
            supports_json = None
        if supports_json:
            # NOTE(review): `types` is declared on the class, so unless an
            # instance-level dict is installed elsewhere this assignment is
            # visible to every adapter sharing that dict -- confirm intended.
            self.types["json"] = "JSON"
            if (self.driver_name == "psycopg2" and
                    _version_tuple(self.driver.__version__) >= (2, 5, 0)):
                self.driver_auto_json = ['loads']
        else:
            LOGGER.debug("Your database version does not support the JSON"
                         " data type (using TEXT instead)")

    def LIKE(self, first, second):
        """Render a case-sensitive LIKE; non-text fields are cast to CHAR."""
        args = (self.expand(first), self.expand(second, 'string'))
        if first.type not in ('string', 'text', 'json'):
            return '(%s LIKE %s)' % (
                self.CAST(args[0], 'CHAR(%s)' % first.length), args[1])
        else:
            return '(%s LIKE %s)' % args

    def ILIKE(self, first, second):
        """Render a case-insensitive ILIKE; non-text fields are cast to CHAR."""
        args = (self.expand(first), self.expand(second, 'string'))
        if first.type not in ('string', 'text', 'json'):
            # Must stay ILIKE here as well, otherwise the match on cast
            # fields silently becomes case-sensitive.
            return '(%s ILIKE %s)' % (
                self.CAST(args[0], 'CHAR(%s)' % first.length), args[1])
        else:
            return '(%s ILIKE %s)' % args

    def REGEXP(self, first, second):
        return '(%s ~ %s)' % (self.expand(first),
                              self.expand(second, 'string'))

    # GIS functions

    def ST_ASGEOJSON(self, first, second):
        """
        http://postgis.org/docs/ST_AsGeoJSON.html
        """
        return 'ST_AsGeoJSON(%s,%s,%s,%s)' % (second['version'],
            self.expand(first), second['precision'], second['options'])

    def ST_ASTEXT(self, first):
        """
        http://postgis.org/docs/ST_AsText.html
        """
        return 'ST_AsText(%s)' % (self.expand(first))

    def ST_X(self, first):
        """
        http://postgis.org/docs/ST_X.html
        """
        return 'ST_X(%s)' % (self.expand(first))

    def ST_Y(self, first):
        """
        http://postgis.org/docs/ST_Y.html
        """
        return 'ST_Y(%s)' % (self.expand(first))

    def ST_CONTAINS(self, first, second):
        """
        http://postgis.org/docs/ST_Contains.html
        """
        return 'ST_Contains(%s,%s)' % (self.expand(first),
                                       self.expand(second, first.type))

    def ST_DISTANCE(self, first, second):
        """
        http://postgis.org/docs/ST_Distance.html
        """
        return 'ST_Distance(%s,%s)' % (self.expand(first),
                                       self.expand(second, first.type))

    def ST_EQUALS(self, first, second):
        """
        http://postgis.org/docs/ST_Equals.html
        """
        return 'ST_Equals(%s,%s)' % (self.expand(first),
                                     self.expand(second, first.type))

    def ST_INTERSECTS(self, first, second):
        """
        http://postgis.org/docs/ST_Intersects.html
        """
        return 'ST_Intersects(%s,%s)' % (self.expand(first),
                                         self.expand(second, first.type))

    def ST_OVERLAPS(self, first, second):
        """
        http://postgis.org/docs/ST_Overlaps.html
        """
        return 'ST_Overlaps(%s,%s)' % (self.expand(first),
                                       self.expand(second, first.type))

    def ST_SIMPLIFY(self, first, second):
        """
        http://postgis.org/docs/ST_Simplify.html
        """
        return 'ST_Simplify(%s,%s)' % (self.expand(first),
                                       self.expand(second, 'double'))

    def ST_TOUCHES(self, first, second):
        """
        http://postgis.org/docs/ST_Touches.html
        """
        return 'ST_Touches(%s,%s)' % (self.expand(first),
                                      self.expand(second, first.type))

    def ST_WITHIN(self, first, second):
        """
        http://postgis.org/docs/ST_Within.html
        """
        return 'ST_Within(%s,%s)' % (self.expand(first),
                                     self.expand(second, first.type))

    def ST_DWITHIN(self, first, tup):
        """
        http://postgis.org/docs/ST_DWithin.html

        *tup* is a ``(second, third)`` pair: the other geometry and the
        distance. It is unpacked in the body because tuple parameters in a
        signature are not valid syntax on Python 3.
        """
        second, third = tup
        return 'ST_DWithin(%s,%s,%s)' % (self.expand(first),
                                         self.expand(second, first.type),
                                         self.expand(third, 'double'))

    def represent(self, obj, fieldtype):
        """Render *obj* as SQL; 'geo...' field types become PostGIS calls."""
        field_is_type = fieldtype.startswith
        if field_is_type('geo'):
            srid = 4326  # postGIS default srid for geometry
            geotype, parms = fieldtype[:-1].split('(')
            parms = parms.split(',')
            if len(parms) >= 2:
                schema, srid = parms[:2]
            if field_is_type('geometry'):
                value = "ST_GeomFromText('%s',%s)" % (obj, srid)
            elif field_is_type('geography'):
                value = "ST_GeogFromText('SRID=%s;%s')" % (srid, obj)
            # NOTE(review): a 'geo...' type that is neither geometry nor
            # geography leaves `value` unbound on the return below.
            # else:
            #     raise SyntaxError('Invalid field type %s' %fieldtype)
            return value
        return BaseAdapter.represent(self, obj, fieldtype)

    def _drop(self, table, mode='restrict'):
        """Return the DROP TABLE statement list for *table*.

        Raises:
            ValueError: if *mode* is not 'restrict', 'cascade' or ''.
        """
        if mode not in ['restrict', 'cascade', '']:
            raise ValueError('Invalid mode: %s' % mode)
        return ['DROP TABLE ' + table.sqlsafe + ' ' + str(mode) + ';']


class NewPostgreSQLAdapter(PostgreSQLAdapter):
    """PostgreSQL adapter that maps ``list:*`` fields to native array columns."""

    drivers = ('psycopg2', 'pg8000')

    types = {
        'boolean': 'CHAR(1)',
        'string': 'VARCHAR(%(length)s)',
        'text': 'TEXT',
        'json': 'TEXT',
        'password': 'VARCHAR(%(length)s)',
        'blob': 'BYTEA',
        'upload': 'VARCHAR(%(length)s)',
        'integer': 'INTEGER',
        'bigint': 'BIGINT',
        'float': 'FLOAT',
        'double': 'FLOAT8',
        'decimal': 'NUMERIC(%(precision)s,%(scale)s)',
        'date': 'DATE',
        'time': 'TIME',
        'datetime': 'TIMESTAMP',
        'id': 'SERIAL PRIMARY KEY',
        'reference': 'INTEGER REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s',
        'list:integer': 'BIGINT[]',
        'list:string': 'TEXT[]',
        'list:reference': 'BIGINT[]',
        'geometry': 'GEOMETRY',
        'geography': 'GEOGRAPHY',
        'big-id': 'BIGSERIAL PRIMARY KEY',
        'big-reference': 'BIGINT REFERENCES %(foreign_key)s ON DELETE %(on_delete_action)s',
    }

    def parse_list_integers(self, value, field_type):
        """Return *value* unchanged."""
        return value

    def parse_list_references(self, value, field_type):
        """Parse each item of *value* as a reference.

        ``field_type[5:]`` drops the leading ``'list:'`` prefix.
        """
        return [self.parse_reference(r, field_type[5:]) for r in value]

    def parse_list_strings(self, value, field_type):
        """Return *value* unchanged."""
        return value

    def represent(self, obj, fieldtype):
        """Render ``list:*`` values as an ``ARRAY[...]`` literal."""
        field_is_type = fieldtype.startswith
        if field_is_type('list:'):
            if not obj:
                obj = []
            elif not isinstance(obj, (list, tuple)):
                obj = [obj]
            if field_is_type('list:string'):
                obj = map(str, obj)
            else:
                obj = map(int, obj)
            # NOTE(review): items are quoted with Python's repr(), which
            # switches to double quotes for strings containing a single
            # quote -- confirm this is acceptable for the values stored.
            return 'ARRAY[%s]' % ','.join(repr(item) for item in obj)
        return BaseAdapter.represent(self, obj, fieldtype)


class JDBCPostgreSQLAdapter(PostgreSQLAdapter):
    """PostgreSQL adapter that connects through the zxJDBC driver."""

    drivers = ('zxJDBC',)

    # user[:password]@host[:port]/db
    # The group names are the ones read back with m.group(...) in __init__.
    REGEX_URI = re.compile(r'^(?P<user>[^:@]+)(\:(?P<password>[^@]*))?@(?P<host>[^\:/]+)(\:(?P<port>[0-9]+))?/(?P<db>.+)$')

    def __init__(self, db, uri, pool_size=0, folder=None, db_codec='UTF-8',
                 credential_decoder=IDENTITY, driver_args={},
                 adapter_args={}, do_connect=True, after_connection=None):
        """Parse *uri*, build the JDBC connection arguments and optionally connect.

        Raises:
            SyntaxError: if the part of *uri* after ``://`` does not match
                ``REGEX_URI`` or lacks a user, host or database name.
        """
        self.db = db
        self.dbengine = "postgres"
        self.uri = uri
        if do_connect:
            self.find_driver(adapter_args, uri)
        self.pool_size = pool_size
        self.folder = folder
        self.db_codec = db_codec
        self._after_connection = after_connection
        self.find_or_make_work_folder()
        # The inherited lastrowid() reads this attribute, so initialise it
        # here just as PostgreSQLAdapter.__init__ does.
        self._last_insert = None

        ruri = uri.split('://', 1)[1]
        m = self.REGEX_URI.match(ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = credential_decoder(m.group('user'))
        if not user:
            raise SyntaxError('User required')
        password = credential_decoder(m.group('password'))
        if not password:
            password = ''
        host = m.group('host')
        if not host:
            raise SyntaxError('Host name required')
        db = m.group('db')
        if not db:
            raise SyntaxError('Database name required')
        port = m.group('port') or '5432'
        msg = ('jdbc:postgresql://%s:%s/%s' % (host, port, db), user, password)

        # msg is a (url, user, password) triple passed positionally.
        def connector(msg=msg, driver_args=driver_args):
            return self.driver.connect(*msg, **driver_args)
        self.connector = connector
        if do_connect:
            self.reconnect()

    def after_connection(self):
        """Per-connection setup: client encoding, BEGIN, JSON probe."""
        self.connection.set_client_encoding('UTF8')
        self.execute('BEGIN;')
        self.execute("SET CLIENT_ENCODING TO 'UNICODE';")
        self.try_json()