@walidsa3d, forked from praveen-udacity/postgres_to_redshift.py (created September 16, 2017 23:41)
Postgres 2 Redshift Operator - Interview Code Review
import logging
import time
import unicodecsv
from datetime import datetime as dt
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow.exceptions import AirflowException
from airflow.hooks.postgres_hook import PostgresHook
from airflow.plugins_manager import AirflowPlugin
from psycopg2.extras import DictCursor
from helpers.s3hook import S3Hook


class PostgresToRedshift(BaseOperator):
    """Loads data from Postgres into Redshift.

    :param conn_id: connection id of the source Postgres database.
    :param dest_conn_id: connection id of the destination Redshift cluster.
    :param config: dictionary describing the load; valid keys are listed in
        ``config_properties`` below.
    """
    template_fields = ('config',)
    ui_color = "#808080"

    @apply_defaults
    def __init__(
            self,
            conn_id,
            dest_conn_id,
            config,
            autocommit=False,
            parameters=None,
            *args, **kwargs):
        super(PostgresToRedshift, self).__init__(*args, **kwargs)
        self.config = config
        self.conn_id = conn_id
        self.dest_conn_id = dest_conn_id
        self.autocommit = autocommit
        self.parameters = parameters
        self.config_properties = ["dest_schema", "dest_table", "source_sql",
                                  "post_processor", "column_map", "truncate_dest"]

    def return_dict(self, sql):
        """Runs ``sql`` against the source database and returns the rows as a list of dicts."""
        result = list()
        with PostgresHook(postgres_conn_id=self.conn_id).get_conn() as conn:
            logging.info("Connected to the postgres database.")
            with conn.cursor(cursor_factory=DictCursor) as cur:
                cur.execute(sql)
                result.extend(dict(i.items()) for i in cur.fetchall())
        return result

    def dict_to_tsv(self, data):
        """Writes the fetched rows to a pipe-delimited file under /tmp and returns its name."""
        table_schema, table_name = self.config.get('dest_schema'), self.config.get('dest_table')
        file_name = "_".join([table_schema, table_name, dt.utcnow().strftime("%Y-%m-%d-%H-%M-%S")]) + ".tsv"
        # unicodecsv writes encoded bytes, so the file is opened in binary mode
        with open('/tmp/' + file_name, 'wb') as output_file:
            csv_writer = unicodecsv.DictWriter(
                output_file, self.columns, delimiter="|", quoting=unicodecsv.QUOTE_NONE)
            csv_writer.writerows(data)
        logging.info("Wrote temporary tsv file: %s", file_name)
        return file_name

    def copy_to_s3(self, filename):
        """Uploads the temporary file from /tmp to the staging location on S3."""
        hook = S3Hook('airflow_s3')
        bucket_name = "s3://newdacity-aws-data-pipeline"
        key = hook.get_key(bucket_name)
        key.key = "plugin_import/{0}".format(filename)
        key.set_contents_from_filename('/tmp/' + filename)
        logging.info("Dumped {0} to S3.".format(filename))

    def create_redshift_schema_and_table(self):
        """Creates the destination schema and table on Redshift if they do not already exist."""
        schema_name = self.config.get('dest_schema')
        table_name = self.config.get('dest_table')
        column_map = self.config.get('column_map')
        # create the schema
        schema_query = 'CREATE SCHEMA IF NOT EXISTS "{schema}"'.format(schema=schema_name)
        with PostgresHook(postgres_conn_id=self.dest_conn_id).get_conn() as conn:
            cur = conn.cursor()
            cur.execute(schema_query)
        # create the table from the column map; this assumes each column_map entry
        # carries a 'name' and a 'type' key
        columns_ddl = ", ".join(
            '"{0}" {1}'.format(col['name'], col['type']) for col in column_map)
        table_query = 'CREATE TABLE IF NOT EXISTS "{schema}"."{table}" ({columns})'.format(
            schema=schema_name, table=table_name, columns=columns_ddl)
        with PostgresHook(postgres_conn_id=self.dest_conn_id).get_conn() as conn:
            cur = conn.cursor()
            cur.execute(table_query)

    def copy_to_redshift(self):
        """Copies the staged file from S3 into the destination Redshift table."""
        start_time = time.time()
        self.create_redshift_schema_and_table()
        # self.comment_on_table, self.env, self.S3_CONN_ID and self.filename are not
        # defined in this class and are expected to be provided elsewhere (e.g. a subclass)
        self.comment_on_table()
        template = self.env.get_template('s3_csv_to_redshift_operator.sql')
        schema_name, table_name = self.config.get('dest_schema'), self.config.get('dest_table')
        column_map = self.config.get('column_map')
        primary_key = [col.get('name') for col in column_map if "is_primary_key" in col]
        truncate_dest = bool(self.config.get('truncate_dest'))
        query = template.render(
            schema=schema_name,
            table=table_name,
            filename=self.filename,
            truncate=truncate_dest,
            primary_key=[] if truncate_dest else primary_key
        )
        # the credentials are bound as query parameters so they never show up in the
        # rendered SQL that gets logged below
        hook = S3Hook(self.S3_CONN_ID)
        credentials = 'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
        credentials = credentials.format(
            access_key=hook.connection.aws_access_key_id,
            secret_key=hook.connection.aws_secret_access_key)
        # the rendered template contains one %s placeholder per CREDENTIALS clause
        parameters = (credentials, ) * query.count('%s')
        ps_hook = PostgresHook(postgres_conn_id=self.dest_conn_id)
        ps_hook.run(query, autocommit=True, parameters=parameters)
        logging.info("Executed query:\n %s", query)
        self.time_to_copy_to_redshift = time.time() - start_time

    def execute(self, context):
        # validate the config before doing any work
        self.check_config()
        logging.info("Config has no errors for the task.")
        sql = self.config.get('source_sql')
        logging.info("Executing: %s", sql)
        # step one: load the data from the source postgres DB
        result = self.return_dict(sql)
        logging.info("Executed the query. Total data points fetched {0}.".format(len(result)))
        # step two: dump the rows to a local tsv file and stage it on S3
        file_name = self.dict_to_tsv(result)
        logging.info("Wrote to file: {0}".format(file_name))
        self.copy_to_s3(file_name)

    def check_config(self):
        """
        Checks the config dictionary for errors.
        :return:
        """
        config = self.config
        if not isinstance(config, dict):
            raise AirflowException("config must be a dictionary.")
        # reject keys that are not part of the config spec
        unknown_keys = [key for key in self.config if key not in self.config_properties]
        if len(unknown_keys) > 0:
            raise AirflowException(str(unknown_keys) + " are not valid config keys.\n" +
                                   " Only {0} are valid keys.".format(self.config_properties))
        # all remaining keys are known config properties.
        if not isinstance(self.config.get('column_map'), list):
            raise AirflowException("Column Map should be a list of dictionaries.")
        column_map = self.config.get('column_map')
        # the column map must mark at least one column as the primary key
        if not any(("is_primary_key" in col for col in column_map)):
            raise AirflowException("No primary key defined in column map.")
        self.columns = [col['name'] for col in column_map]
        logging.info("Columns found in config: " + str(self.columns))


class PostgresToRedshiftPlugin(AirflowPlugin):
    name = "postgres_utils"
    operators = [PostgresToRedshift, ]
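
A minimal usage sketch, assuming the plugin is installed so the operator is importable; the DAG name, connection ids, table, and the 'type' entries in column_map below are illustrative, not values from the original gist. The operator renders the s3_csv_to_redshift_operator.sql template that follows.

# hypothetical DAG wiring; in Airflow 1.x the plugin exposes the operator under
# airflow.operators.postgres_utils
from datetime import datetime
from airflow import DAG
from airflow.operators.postgres_utils import PostgresToRedshift

dag = DAG('postgres_to_redshift_example', start_date=datetime(2017, 9, 1),
          schedule_interval='@daily')

load_orders = PostgresToRedshift(
    task_id='load_orders',
    conn_id='postgres_source',        # hypothetical source connection id
    dest_conn_id='redshift_dest',     # hypothetical Redshift connection id
    config={
        'dest_schema': 'analytics',
        'dest_table': 'orders',
        'source_sql': 'SELECT id, customer_id, total FROM orders',
        'column_map': [
            {'name': 'id', 'type': 'BIGINT', 'is_primary_key': True},
            {'name': 'customer_id', 'type': 'BIGINT'},
            {'name': 'total', 'type': 'NUMERIC(12,2)'},
        ],
        'truncate_dest': False,
    },
    dag=dag,
)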
{# Jinja2 template for loading a CSV from S3 into Redshift.
After rendering, the %s placeholders still need to be supplied with the AWS credentials;
they are passed as query parameters so the credentials don't get printed in logs.
A rendering sketch follows the template. #}
BEGIN TRANSACTION;
DROP TABLE IF EXISTS {{table}}_staging;
CREATE TEMPORARY TABLE {{table}}_staging (LIKE {{schema}}.{{table}});
COPY {{table}}_staging
FROM 's3://data-ilum/pipeline_dumps/{{filename}}'
CREDENTIALS %s
FORMAT CSV
DELIMITER ','
TRUNCATECOLUMNS;
{% if truncate %}
TRUNCATE TABLE {{schema}}.{{table}};
{% else %}
UNLOAD ('SELECT * FROM {{schema}}.{{table}} JOIN {{table}}_staging ON
{% for key in primary_key %}
{% if loop.index > 1 %} AND{% endif %} {{table}}.{{key}} = {{table}}_staging.{{key}}
{% endfor %}')
TO 's3://data-ilum/redshift_dumps/{{file_prefix}}'
WITH CREDENTIALS %s
MANIFEST
PARALLEL off;
DELETE FROM {{schema}}.{{table}}
USING {{table}}_staging
WHERE
{% for key in primary_key %}
{% if loop.index > 1 %} AND{% endif %} {{table}}.{{key}} = {{table}}_staging.{{key}}
{% endfor %};
{% endif %}
INSERT INTO {{schema}}.{{table}}
SELECT * FROM {{table}}_staging;
END TRANSACTION;
DROP TABLE {{table}}_staging;
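
A minimal sketch of how this template is rendered and executed, mirroring copy_to_redshift above; the template directory, connection id, and credential string are assumptions for illustration. The %s placeholders survive rendering and are bound as query parameters, so the credentials never appear in the logged SQL.

from jinja2 import Environment, FileSystemLoader
from airflow.hooks.postgres_hook import PostgresHook

env = Environment(loader=FileSystemLoader('templates'))  # hypothetical template directory
template = env.get_template('s3_csv_to_redshift_operator.sql')
query = template.render(
    schema='analytics',
    table='orders',
    filename='analytics_orders_2017-09-16-23-41-00.tsv',  # hypothetical staged file
    truncate=False,
    primary_key=['id'],
)
credentials = 'aws_access_key_id=...;aws_secret_access_key=...'  # placeholder secret
# with truncate=False the template renders two CREDENTIALS %s clauses (COPY and UNLOAD)
PostgresHook(postgres_conn_id='redshift_dest').run(
    query, autocommit=True, parameters=(credentials,) * query.count('%s'))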