import os import boto3 from fnmatch import fnmatch from botocore.config import Config from traceback import format_exc from multiprocessing import Process, Queue, Pipe from concurrent.futures import ThreadPoolExecutor class DownloadException(Exception): pass class S3Exception(Exception): pass class S3DownloadException(S3Exception, DownloadException): pass class ProcessWithExceptionPiping(Process): """Multiprocessing Process with exception piping""" def __init__( self, *args, raise_at_child: bool = False, raise_at_parent: bool = True, **kwargs, ): """Initialize process with exception piping. Args: raise_at_child (bool, optional): Raise exception at child process. Defaults to False. raise_at_parent (bool, optional): Raise exception at parent process. Defaults to True. """ self.raise_at_child = raise_at_child self.raise_at_parent = raise_at_parent self.exception_traceback = None self.parent_connection, self.child_connection = Pipe() super().__init__(*args, **kwargs) def run(self): try: super().run() self.child_connection.send(None) except BaseException as exception: traceback = format_exc() self.child_connection.send((exception, traceback)) if self.raise_at_child: raise exception def join(self, *args, **kwargs): super().join(*args, **kwargs) if self.raise_at_parent: exception_traceback = self.get_exception_traceback() if exception_traceback: exception, traceback = exception_traceback raise type(exception)(f'{traceback}\n{exception}') def get_exception_traceback(self): if self.parent_connection.poll(): self.exception_traceback = self.parent_connection.recv() return self.exception_traceback class S3Client: """AWS S3 boto3-based client for downloading/syncing folders""" def __init__( self, bucket: str = '', max_attempts: int = 10, max_pool_connections: int = 100, max_download_workers: int = 20, ) -> None: """AWS S3 boto3-based client constructor. Note: Implementation only valid for UNIX. Args: bucket (str): Bucket path. max_attempts (int, optional): Maximum retry attemps. Defaults to 10 max_pool_connections (int, optional) Maximum number of concurrent requests to aws s3. Defaults to 100 max_download_workers (int, optional) Maximum number of workers for downloading files. Defaults to 20 Attributes: success_paths (list): List of tuples of successfully uploaded paths. """ self.success_paths = [] self.bucket = bucket.replace('s3://', '') self.max_attempts = max_attempts config = Config(retries={ 'mode': 'standard', 'max_attempts': max_attempts, }, max_pool_connections=max_pool_connections) self.client = boto3.client('s3', config=config) self.max_download_workers = max_download_workers self.paginator = self.client.get_paginator('list_objects_v2') def download_folder( self, dst_dir: str, prefix: str, pattern: str = None, bucket: str = '', ) -> None: """Download a folder (recursively) from S3. TODO: Benchmark this (Process + ThreadPoolExecutor) vs only ThreadPoolExecutor. Test cases: small/large list (> 10000 files), need/no need to re-download. Args: dst_dir (str): Destination local folder. prefix (str): File/folder S3 key. pattern (str, optional): Including filter pattern. For more information, check fnmatch library. Examples: '*/my_folder/image_???_*.png', '*.json'. Defaults to None. bucket (str, optional): File/folder S3 bucket. Defaults to self.bucket. """ self.downloader_queue = Queue() self.downloader = ProcessWithExceptionPiping(target=self._downloader) self.downloader.start() try: dst_dir = dst_dir.rstrip('/') prefix = prefix.rstrip('/') bucket = bucket or self.bucket self._enqueue_downloads( dst_dir=dst_dir, prefix=prefix, pattern=pattern, bucket=bucket, ) finally: self.downloader_queue.put(None) self.downloader.join() def _enqueue_downloads( self, dst_dir: str, prefix: str, pattern: str = None, bucket: str = '', ) -> None: """Enqueue files to be downloaded by the downloader background process. Args: dst_dir (str): Destination local folder. prefix (str): File/folder S3 key. pattern (str, optional): Including filter pattern. Defaults to None. bucket (str, optional): File/folder S3 bucket. Defaults to ''. """ root_dir = f'{dst_dir}/{prefix}' for page in self.paginator.paginate(Bucket=bucket, Prefix=prefix): files = page.get('Contents', ()) file_data = [ ( file.get('LastModified').timestamp(), file.get('Size'), file.get('Key'), file.get('Key').replace(f'{prefix}/', ''), ) for file in files ] if pattern: file_data = list(filter(lambda file: fnmatch(file[3], pattern), file_data)) sub_dirs = set([os.path.dirname(file[3]) for file in file_data]) for sub_dir in sub_dirs: os.makedirs(f'{root_dir}/{sub_dir}', exist_ok=True) for timestamp, size, key, sub_path in file_data: self.downloader_queue.put((bucket, key, f'{root_dir}/{sub_path}', timestamp, size)) def _downloader(self) -> None: """Downloader background process""" futures = [] with ThreadPoolExecutor(max_workers=self.max_download_workers) as executor: args = self.downloader_queue.get() while args: futures.append(executor.submit(self._download_file, *args)) args = self.downloader_queue.get() for future in futures: future.result() def _download_file( self, bucket: str, key: str, dst_path: str, timestamp: float, size: int, ) -> None: """Download a file and assign the provided timestamp to it. Args: bucket (str): File S3 bucket. key (str): File S3 key. dst_path (str): File destination local path. timestamp (float): File S3 timestamp. size (int): File S3 size. """ if os.path.exists(dst_path): dst_file_stat = os.stat(dst_path) if timestamp == dst_file_stat.st_mtime and size == dst_file_stat.st_size: return self.client.download_file(bucket, key, dst_path) if os.path.exists(dst_path): os.utime(dst_path, (timestamp, timestamp)) else: raise S3DownloadException(f'Error downloading: s3://{bucket}/{key} => {dst_path}')