from random import randint
from time import sleep

from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults


class AWSGlueOperator(BaseOperator):
    """
    Runs an AWS Glue job with the given parameters. AWS Glue is a serverless
    Spark ETL service for running Spark jobs on the AWS cloud.
    Language support: Python and Scala.

    :param job_name: unique job name per AWS account
    :type job_name: str
    :param job_arguments: ETL script arguments and AWS Glue arguments
    :type job_arguments: dict
    :param aws_conn_id: which of your Airflow connections to AWS to use
    :type aws_conn_id: str
    :param region_name: AWS region name (example: us-east-1)
    :type region_name: str
    :param max_retries: how many times to poll Glue to check on the job
    :type max_retries: int
    """
    MAX_RETRIES = 4200

    template_fields = ("job_arguments",)
    template_ext = ()
    ui_color = "#ededed"

    @apply_defaults
    def __init__(
        self,
        job_name="aws_default_glue_job",
        job_arguments=None,
        aws_conn_id=None,
        region_name=None,
        max_retries=MAX_RETRIES,
        **kwargs
    ):
        super(AWSGlueOperator, self).__init__(**kwargs)
        self.job_name = job_name
        # Avoid a mutable default argument; fall back to an empty dict here.
        self.job_arguments = job_arguments or {}
        self.aws_conn_id = aws_conn_id
        self.job_run_id = None
        self.client = None
        self.hook = self.get_hook()
        self.region_name = region_name
        self.max_retries = max_retries
    def execute(self, context=None):
        self.log.info("Running AWS Glue Job - Job name: %s", self.job_name)
        self.client = self.hook.get_client_type("glue", region_name=self.region_name)
        try:
            response = self.client.start_job_run(
                JobName=self.job_name, Arguments=self.job_arguments
            )
            self.log.info("AWS Glue Job started: %s", response)
            self.job_run_id = response["JobRunId"]
            self._poll_for_task_ended()
            self._check_success_task()
            self.log.info("AWS Glue Job has been successfully executed: %s", response)
        except Exception as e:
            self.log.info("AWS Glue Job has failed")
            raise AirflowException(e)
    def get_hook(self):
        return AwsHook(aws_conn_id=self.aws_conn_id)
    def _poll_for_task_ended(self):
        """
        Poll for the job run status until it reaches a terminal state.
        * docs.aws.amazon.com/general/latest/gr/api-retries.html
        """
        # Allow the Glue job some time to spin up. A random initial interval
        # decreases the chances of exceeding an AWS API throttle
        # limit when there are many concurrent tasks.
        pause = randint(60, 200)
        tries = 0
        while tries < self.max_retries:
            tries += 1
            self.log.info(
                "AWS Glue job (%s) status check (%d of %d) in the next %.2f seconds",
                self.job_run_id,
                tries,
                self.max_retries,
                pause,
            )
            sleep(pause)
            job_run = self._get_job_run()
            status = job_run.get("JobRunState")
            self.log.info("AWS Glue job (%s) status: %s", self.job_run_id, status)
            # Stop polling once the run reaches a terminal state.
            if status in ["SUCCEEDED", "FAILED", "STOPPED", "TIMEOUT"]:
                break
            # Back off quadratically between subsequent checks
            # (roughly 1.1s after try 1, 10s after try 10, 82s after try 30).
            pause = 1 + pow(tries * 0.3, 2)
    def _check_success_task(self):
        """
        Check the final status of the Glue job run; the possible job run states are:
        'STARTING'|'RUNNING'|'STOPPING'|'STOPPED'|'SUCCEEDED'|'FAILED'|'TIMEOUT'
        """
        job_run = self._get_job_run()
        status = job_run.get("JobRunState")
        if status in ["FAILED", "TIMEOUT", "STOPPED"]:
            reason = job_run.get("ErrorMessage", status)
            raise AirflowException(
                "Job run ({}) did not succeed: {}".format(self.job_run_id, reason)
            )
        elif status in ["STARTING", "RUNNING", "STOPPING"]:
            raise AirflowException(
                "Job run ({}) is still in state {}".format(self.job_run_id, status)
            )
    def _get_job_run(self):
        response = self.client.get_job_run(JobName=self.job_name, RunId=self.job_run_id)
        job_run = response.get("JobRun")
        return job_run
    def on_kill(self):
        # The task may be killed before a job run was ever started; only try
        # to stop the Glue run if one exists.
        if self.client is None or self.job_run_id is None:
            return
        response = self.client.batch_stop_job_run(
            JobName=self.job_name, JobRunIds=[self.job_run_id]
        )
        self.log.info(response)
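

# A minimal usage sketch (not part of the original gist) showing how the
# operator might be wired into a DAG. The DAG id, schedule, connection id,
# Glue job name, and job arguments below are hypothetical placeholders.
from datetime import datetime

from airflow import DAG

with DAG(
    dag_id="example_glue_dag",
    start_date=datetime(2020, 2, 1),
    schedule_interval="@daily",
    catchup=False,
) as dag:
    run_glue_job = AWSGlueOperator(
        task_id="run_glue_job",
        job_name="my_glue_job",
        job_arguments={"--run_date": "{{ ds }}"},
        aws_conn_id="aws_default",
        region_name="us-east-1",
    )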