from random import randint
from time import sleep

from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
class AWSGlueOperator(BaseOperator):
    """
    Runs an AWS Glue job with the given parameters. AWS Glue is a serverless
    Spark ETL service for running Spark jobs on the AWS cloud.
    Language support: Python and Scala.
    A usage sketch appears at the bottom of this file.

    :param job_name: unique job name per AWS account
    :type job_name: str
    :param job_arguments: ETL script arguments and AWS Glue arguments
    :type job_arguments: dict
    :param aws_conn_id: which Airflow connection to AWS to use
    :type aws_conn_id: str
    :param region_name: AWS region name (example: us-east-1)
    :type region_name: str
    :param max_retries: how many times to poll Glue for the job status
    :type max_retries: int
    """
    MAX_RETRIES = 4200

    template_fields = ("job_arguments",)
    template_ext = ()
    ui_color = "#ededed"
    @apply_defaults
    def __init__(
        self,
        job_name="aws_default_glue_job",
        job_arguments=None,
        aws_conn_id=None,
        region_name=None,
        max_retries=MAX_RETRIES,
        **kwargs
    ):
        super(AWSGlueOperator, self).__init__(**kwargs)
        self.job_name = job_name
        # Avoid a mutable default argument; fall back to an empty dict here.
        self.job_arguments = job_arguments or {}
        self.aws_conn_id = aws_conn_id
        self.job_run_id = None
        self.client = None
        self.hook = self.get_hook()
        self.region_name = region_name
        self.max_retries = max_retries
    def execute(self, context=None):
        self.log.info("Running AWS Glue Job - Job name: %s", self.job_name)
        self.client = self.hook.get_client_type("glue", region_name=self.region_name)
        try:
            response = self.client.start_job_run(
                JobName=self.job_name, Arguments=self.job_arguments
            )
            self.log.info("AWS Glue Job started: %s", response)
            self.job_run_id = response["JobRunId"]
            self._poll_for_task_ended()
            self._check_success_task()
            self.log.info("AWS Glue Job has been successfully executed: %s", response)
        except Exception as e:
            self.log.error("AWS Glue Job has failed")
            raise AirflowException(e)
    def get_hook(self):
        return AwsHook(aws_conn_id=self.aws_conn_id)
    def _poll_for_task_ended(self):
        """
        Poll for the job run status until it reaches a terminal state.
        * docs.aws.amazon.com/general/latest/gr/api-retries.html
        """
        # Allow the Glue job some time to spin up. A random initial interval
        # decreases the chances of exceeding an AWS API throttle limit when
        # there are many concurrent tasks.
        pause = randint(60, 200)

        tries = 0
        while tries < self.max_retries:
            tries += 1
            self.log.info(
                "AWS Glue job (%s) status check (%d of %d) in the next %.2f seconds",
                self.job_run_id,
                tries,
                self.max_retries,
                pause,
            )
            sleep(pause)

            job_run = self._get_job_run()
            status = job_run.get("JobRunState")
            self.log.info("AWS Glue job (%s) status: %s", self.job_run_id, status)
            # STOPPED and TIMEOUT are also terminal states; without them the
            # loop would poll until max_retries for a stopped or timed-out job.
            if status in ["SUCCEEDED", "FAILED", "STOPPED", "TIMEOUT"]:
                break

            # After the first long pause, back off quadratically:
            # 1 + (tries * 0.3) ** 2 seconds between subsequent checks.
            pause = 1 + pow(tries * 0.3, 2)
    def _check_success_task(self):
        """
        Check the final status of the Glue job run; the possible states are:
        'STARTING'|'RUNNING'|'STOPPING'|'STOPPED'|'SUCCEEDED'|'FAILED'|'TIMEOUT'
        """
        job_run = self._get_job_run()
        status = job_run.get("JobRunState")
        if status in ["FAILED", "STOPPED", "TIMEOUT"]:
            reason = job_run.get("ErrorMessage", status)
            raise AirflowException(
                "Job ({}) failed: {}".format(self.job_run_id, reason)
            )
        elif status in ["STARTING", "RUNNING", "STOPPING"]:
            raise AirflowException(
                "Job ({}) has not finished, status: {}".format(self.job_run_id, status)
            )
    def _get_job_run(self):
        response = self.client.get_job_run(JobName=self.job_name, RunId=self.job_run_id)
        job_run = response.get("JobRun")
        return job_run
    def on_kill(self):
        # The client only exists once execute() has started a run; guard so
        # that killing a task before then does not raise on a None client.
        if self.client and self.job_run_id:
            response = self.client.batch_stop_job_run(
                JobName=self.job_name, JobRunIds=[self.job_run_id]
            )
            self.log.info(response)
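
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original operator): one way this operator
# might be wired into a DAG. In practice this would live in a separate DAG
# file; the DAG id, schedule, connection id, job name, and job arguments
# below are hypothetical placeholders, not values from this gist.
# ---------------------------------------------------------------------------
from datetime import datetime

from airflow import DAG

with DAG(
    dag_id="example_glue_dag",          # hypothetical
    start_date=datetime(2020, 1, 1),
    schedule_interval="@daily",
) as dag:
    run_glue_job = AWSGlueOperator(
        task_id="run_glue_job",
        job_name="my-glue-etl-job",     # assumed to already exist in AWS Glue
        job_arguments={"--input_path": "s3://my-bucket/raw/"},
        aws_conn_id="aws_default",      # Airflow's default AWS connection id
        region_name="us-east-1",
    )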