""" Patched version to support PostgreSQL (original version: https://github.com/pydata/pandas/blob/v0.13.1/pandas/io/sql.py) Adapted functions are: - added _write_postgresql - updated table_exist - updated get_sqltype - updated get_schema - fix python3 compatibility Collection of query wrappers / abstractions to both facilitate data retrieval and to reduce dependency on DB-specific API. """ from __future__ import print_function from datetime import datetime, date from pandas.compat import range, lzip, map, zip import pandas.compat as compat import numpy as np import traceback from pandas.core.datetools import format as date_format from pandas.core.api import DataFrame, isnull #------------------------------------------------------------------------------ # Helper execution function def execute(sql, con, retry=True, cur=None, params=None): """ Execute the given SQL query using the provided connection object. Parameters ---------- sql: string Query to be executed con: database connection instance Database connection. Must implement PEP249 (Database API v2.0). retry: bool Not currently implemented cur: database cursor, optional Must implement PEP249 (Datbase API v2.0). If cursor is not provided, one will be obtained from the database connection. params: list or tuple, optional List of parameters to pass to execute method. Returns ------- Cursor object """ try: if cur is None: cur = con.cursor() if params is None: cur.execute(sql) else: cur.execute(sql, params) return cur except Exception: try: con.rollback() except Exception: # pragma: no cover pass print('Error on sql %s' % sql) raise def _safe_fetch(cur): try: result = cur.fetchall() if not isinstance(result, list): result = list(result) return result except Exception as e: # pragma: no cover excName = e.__class__.__name__ if excName == 'OperationalError': return [] def tquery(sql, con=None, cur=None, retry=True): """ Returns list of tuples corresponding to each row in given sql query. If only one column selected, then plain list is returned. Parameters ---------- sql: string SQL query to be executed con: SQLConnection or DB API 2.0-compliant connection cur: DB API 2.0 cursor Provide a specific connection or a specific cursor if you are executing a lot of sequential statements and want to commit outside. """ cur = execute(sql, con, cur=cur) result = _safe_fetch(cur) if con is not None: try: cur.close() con.commit() except Exception as e: excName = e.__class__.__name__ if excName == 'OperationalError': # pragma: no cover print('Failed to commit, may need to restart interpreter') else: raise traceback.print_exc() if retry: return tquery(sql, con=con, retry=False) if result and len(result[0]) == 1: # python 3 compat result = list(lzip(*result)[0]) elif result is None: # pragma: no cover result = [] return result def uquery(sql, con=None, cur=None, retry=True, params=None): """ Does the same thing as tquery, but instead of returning results, it returns the number of rows affected. Good for update queries. """ cur = execute(sql, con, cur=cur, retry=retry, params=params) result = cur.rowcount try: con.commit() except Exception as e: excName = e.__class__.__name__ if excName != 'OperationalError': raise traceback.print_exc() if retry: print('Looks like your connection failed, reconnecting...') return uquery(sql, con, retry=False) return result def read_frame(sql, con, index_col=None, coerce_float=True, params=None): """ Returns a DataFrame corresponding to the result set of the query string. Optionally provide an index_col parameter to use one of the columns as the index. Otherwise will be 0 to len(results) - 1. Parameters ---------- sql: string SQL query to be executed con: DB connection object, optional index_col: string, optional column name to use for the returned DataFrame object. coerce_float : boolean, default True Attempt to convert values to non-string, non-numeric objects (like decimal.Decimal) to floating point, useful for SQL result sets params: list or tuple, optional List of parameters to pass to execute method. """ cur = execute(sql, con, params=params) rows = _safe_fetch(cur) columns = [col_desc[0] for col_desc in cur.description] cur.close() con.commit() result = DataFrame.from_records(rows, columns=columns, coerce_float=coerce_float) if index_col is not None: result = result.set_index(index_col) return result frame_query = read_frame read_sql = read_frame def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): """ Write records stored in a DataFrame to a SQL database. Parameters ---------- frame: DataFrame name: name of SQL table con: an open SQL database connection object flavor: {'sqlite', 'mysql', 'oracle'}, default 'sqlite' if_exists: {'fail', 'replace', 'append'}, default 'fail' fail: If table exists, do nothing. replace: If table exists, drop it, recreate it, and insert data. append: If table exists, insert data. Create if does not exist. """ if 'append' in kwargs: import warnings warnings.warn("append is deprecated, use if_exists instead", FutureWarning) if kwargs['append']: if_exists = 'append' else: if_exists = 'fail' if if_exists not in ('fail', 'replace', 'append'): raise ValueError("'%s' is not valid for if_exists" % if_exists) exists = table_exists(name, con, flavor) if if_exists == 'fail' and exists: raise ValueError("Table '%s' already exists." % name) # creation/replacement dependent on the table existing and if_exist criteria create = None if exists: if if_exists == 'fail': raise ValueError("Table '%s' already exists." % name) elif if_exists == 'replace': cur = con.cursor() cur.execute("DROP TABLE %s;" % name) cur.close() create = get_schema(frame, name, flavor) else: create = get_schema(frame, name, flavor) if create is not None: cur = con.cursor() cur.execute(create) cur.close() cur = con.cursor() # Replace spaces in DataFrame column names with _. safe_names = [s.replace(' ', '_').strip() for s in frame.columns] flavor_picker = {'sqlite' : _write_sqlite, 'mysql' : _write_mysql, 'postgresql' : _write_postgresql} func = flavor_picker.get(flavor, None) if func is None: raise NotImplementedError func(frame, name, safe_names, cur) cur.close() con.commit() def _write_sqlite(frame, table, names, cur): bracketed_names = ['[' + column + ']' for column in names] col_names = ','.join(bracketed_names) wildcards = ','.join(['?'] * len(names)) insert_query = 'INSERT INTO %s (%s) VALUES (%s)' % ( table, col_names, wildcards) # pandas types are badly handled if there is only 1 column ( Issue #3628 ) if not len(frame.columns) == 1: data = [tuple(x) for x in frame.values] else: data = [tuple(x) for x in frame.values.tolist()] cur.executemany(insert_query, data) def _write_mysql(frame, table, names, cur): bracketed_names = ['`' + column + '`' for column in names] col_names = ','.join(bracketed_names) wildcards = ','.join([r'%s'] * len(names)) insert_query = "INSERT INTO %s (%s) VALUES (%s)" % ( table, col_names, wildcards) data = [tuple(x) for x in frame.values] cur.executemany(insert_query, data) def _write_postgresql(frame, table, names, cur): bracketed_names = ['"' + column + '"' for column in names] col_names = ','.join(bracketed_names) wildcards = ','.join([r'%s'] * len(names)) insert_query = 'INSERT INTO public.%s (%s) VALUES (%s)' % ( table, col_names, wildcards) data = [tuple(x) for x in frame.values] print(insert_query) print(data) cur.executemany(insert_query, data) def table_exists(name, con, flavor): flavor_map = { 'sqlite': ("SELECT name FROM sqlite_master " "WHERE type='table' AND name='%s';") % name, 'mysql' : "SHOW TABLES LIKE '%s'" % name, 'postgresql' : "SELECT * FROM pg_catalog.pg_tables where tablename = '%s'" % name} query = flavor_map.get(flavor, None) # if query is None: # raise NotImplementedError return len(tquery(query, con)) > 0 def get_sqltype(pytype, flavor): sqltype = {'mysql': 'VARCHAR (63)', 'sqlite': 'TEXT', 'postgresql': 'VARCHAR (63)'} if issubclass(pytype, np.floating): sqltype['mysql'] = 'FLOAT' sqltype['sqlite'] = 'REAL' sqltype['postgresql'] = 'double precision' if issubclass(pytype, np.integer): #TODO: Refine integer size. sqltype['mysql'] = 'BIGINT' sqltype['sqlite'] = 'INTEGER' sqltype['postgresql'] = 'integer' if issubclass(pytype, np.datetime64) or pytype is datetime: # Caution: np.datetime64 is also a subclass of np.number. sqltype['mysql'] = 'DATETIME' sqltype['sqlite'] = 'TIMESTAMP' sqltype['postgresql'] = 'timestamp' if pytype is datetime.date: sqltype['mysql'] = 'DATE' sqltype['sqlite'] = 'TIMESTAMP' sqltype['postgresql'] = 'date' if issubclass(pytype, np.bool_): sqltype['sqlite'] = 'INTEGER' sqltype['postgresql'] = 'boolean' return sqltype[flavor] def get_schema(frame, name, flavor, keys=None): "Return a CREATE TABLE statement to suit the contents of a DataFrame." lookup_type = lambda dtype: get_sqltype(dtype.type, flavor) # Replace spaces in DataFrame column names with _. safe_columns = [s.replace(' ', '_').strip() for s in frame.dtypes.index] column_types = lzip(safe_columns, map(lookup_type, frame.dtypes)) if flavor == 'sqlite': columns = ',\n '.join('[%s] %s' % x for x in column_types) elif flavor == 'postgresql': columns = ',\n '.join('"%s" %s' % x for x in column_types) else: columns = ',\n '.join('`%s` %s' % x for x in column_types) keystr = '' if keys is not None: if isinstance(keys, compat.string_types): keys = (keys,) keystr = ', PRIMARY KEY (%s)' % ','.join(keys) template = """CREATE TABLE %(name)s ( %(columns)s %(keystr)s );""" create_statement = template % {'name': name, 'columns': columns, 'keystr': keystr} return create_statement def sequence2dict(seq): """Helper function for cx_Oracle. For each element in the sequence, creates a dictionary item equal to the element and keyed by the position of the item in the list. >>> sequence2dict(("Matt", 1)) {'1': 'Matt', '2': 1} Source: http://www.gingerandjohn.com/archives/2004/02/26/cx_oracle-executemany-example/ """ d = {} for k, v in zip(range(1, 1 + len(seq)), seq): d[str(k)] = v return d