Revisions

  1. @a-canela revised this gist Mar 23, 2023. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion s3_client_sync_download_folder.py
@@ -184,7 +184,7 @@ def _downloader(self) -> None:
                 futures.append(executor.submit(self._download_file, *args))
                 args = self.downloader_queue.get()
         for future in futures:
-            future.result()  # Important for raising exceptions at this level
+            future.result()
 
     def _download_file(
         self,
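The comment dropped in this revision described real behavior: a ThreadPoolExecutor stores a worker's exception on its Future and only re-raises it when result() is called, so the `for future in futures: future.result()` loop is what surfaces download errors inside the downloader process. A minimal standalone sketch (not part of the gist) of that behavior:

    from concurrent.futures import ThreadPoolExecutor

    def boom():
        raise RuntimeError('worker failed')

    with ThreadPoolExecutor(max_workers=2) as executor:
        future = executor.submit(boom)
    # Nothing has been raised yet; the exception only surfaces on result():
    try:
        future.result()
    except RuntimeError as error:
        print(f'caught: {error}')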
  2. @a-canela revised this gist Mar 23, 2023. 1 changed file with 20 additions and 17 deletions.
    37 changes: 20 additions & 17 deletions s3_client_sync_download_folder.py
@@ -109,7 +109,10 @@ def download_folder(
         pattern: str = None,
         bucket: str = '',
     ) -> None:
-        """Download files (either a folder recursively or a single file) from S3.
+        """Download a folder (recursively) from S3.
+        TODO: Benchmark this (Process + ThreadPoolExecutor) vs only ThreadPoolExecutor.
+        Test cases: small/large list (> 10000 files), need/no need to re-download.
         Args:
             dst_dir (str): Destination local folder.
@@ -152,6 +155,7 @@ def _enqueue_downloads(
             pattern (str, optional): Including filter pattern. Defaults to None.
             bucket (str, optional): File/folder S3 bucket. Defaults to ''.
         """
+        root_dir = f'{dst_dir}/{prefix}'
         for page in self.paginator.paginate(Bucket=bucket, Prefix=prefix):
             files = page.get('Contents', ())
             file_data = [
@@ -165,36 +169,30 @@ def _enqueue_downloads(
             ]
             if pattern:
                 file_data = list(filter(lambda file: fnmatch(file[3], pattern), file_data))
-            created_sub_dirs = set()
+            sub_dirs = set([os.path.dirname(file[3]) for file in file_data])
+            for sub_dir in sub_dirs:
+                os.makedirs(f'{root_dir}/{sub_dir}', exist_ok=True)
             for timestamp, size, key, sub_path in file_data:
-                dst_path = f'{dst_dir}/{prefix}/{sub_path}'
-                if os.path.exists(dst_path):
-                    dst_file_stat = os.stat(dst_path)
-                    if timestamp == dst_file_stat.st_mtime and size == dst_file_stat.st_size:
-                        continue
-                sub_dir = os.path.dirname(sub_path)
-                if sub_dir not in created_sub_dirs:
-                    os.makedirs(f'{dst_dir}/{prefix}/{sub_dir}', exist_ok=True)
-                    created_sub_dirs.add(sub_dir)
-                self.downloader_queue.put((bucket, key, dst_path, timestamp))
+                self.downloader_queue.put((bucket, key, f'{root_dir}/{sub_path}', timestamp, size))
 
     def _downloader(self) -> None:
         """Downloader background process"""
-        executor = ThreadPoolExecutor(max_workers=self.max_download_workers)
-        try:
+        futures = []
+        with ThreadPoolExecutor(max_workers=self.max_download_workers) as executor:
             args = self.downloader_queue.get()
             while args:
-                executor.submit(self._download_file, *args)
+                futures.append(executor.submit(self._download_file, *args))
                 args = self.downloader_queue.get()
-        finally:
-            executor.shutdown(wait=True)
+        for future in futures:
+            future.result()  # Important for raising exceptions at this level
 
     def _download_file(
         self,
         bucket: str,
         key: str,
         dst_path: str,
         timestamp: float,
+        size: int,
     ) -> None:
         """Download a file and assign the provided timestamp to it.
@@ -203,7 +201,12 @@ def _download_file(
             key (str): File S3 key.
             dst_path (str): File destination local path.
             timestamp (float): File S3 timestamp.
+            size (int): File S3 size.
         """
+        if os.path.exists(dst_path):
+            dst_file_stat = os.stat(dst_path)
+            if timestamp == dst_file_stat.st_mtime and size == dst_file_stat.st_size:
+                return
         self.client.download_file(bucket, key, dst_path)
         if os.path.exists(dst_path):
             os.utime(dst_path, (timestamp, timestamp))
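This revision also moves the up-to-date check out of the enqueueing loop and into _download_file, which is why the object size now travels with each queue item. The comparison itself is unchanged; a hypothetical helper (is_in_sync is not part of the gist) captures the same rule:

    import os

    def is_in_sync(dst_path: str, s3_timestamp: float, s3_size: int) -> bool:
        """Local copy counts as up to date when mtime and size match the S3 object."""
        if not os.path.exists(dst_path):
            return False
        stat = os.stat(dst_path)
        return stat.st_mtime == s3_timestamp and stat.st_size == s3_size

This works because _download_file stamps the local file with the S3 LastModified timestamp via os.utime after every successful download.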
  3. @a-canela revised this gist Mar 9, 2023. 1 changed file with 87 additions and 29 deletions.
    116 changes: 87 additions & 29 deletions s3_client_sync_download_folder.py
@@ -2,10 +2,69 @@
 import boto3
 from fnmatch import fnmatch
 from botocore.config import Config
-from multiprocessing import Process, Queue
+from traceback import format_exc
+from multiprocessing import Process, Queue, Pipe
 from concurrent.futures import ThreadPoolExecutor
 
 
+class DownloadException(Exception):
+    pass
+
+
+class S3Exception(Exception):
+    pass
+
+
+class S3DownloadException(S3Exception, DownloadException):
+    pass
+
+
+class ProcessWithExceptionPiping(Process):
+    """Multiprocessing Process with exception piping"""
+
+    def __init__(
+        self,
+        *args,
+        raise_at_child: bool = False,
+        raise_at_parent: bool = True,
+        **kwargs,
+    ):
+        """Initialize process with exception piping.
+        Args:
+            raise_at_child (bool, optional): Raise exception at child process. Defaults to False.
+            raise_at_parent (bool, optional): Raise exception at parent process. Defaults to True.
+        """
+        self.raise_at_child = raise_at_child
+        self.raise_at_parent = raise_at_parent
+        self.exception_traceback = None
+        self.parent_connection, self.child_connection = Pipe()
+        super().__init__(*args, **kwargs)
+
+    def run(self):
+        try:
+            super().run()
+            self.child_connection.send(None)
+        except BaseException as exception:
+            traceback = format_exc()
+            self.child_connection.send((exception, traceback))
+            if self.raise_at_child:
+                raise exception
+
+    def join(self, *args, **kwargs):
+        super().join(*args, **kwargs)
+        if self.raise_at_parent:
+            exception_traceback = self.get_exception_traceback()
+            if exception_traceback:
+                exception, traceback = exception_traceback
+                raise type(exception)(f'{traceback}\n{exception}')
+
+    def get_exception_traceback(self):
+        if self.parent_connection.poll():
+            self.exception_traceback = self.parent_connection.recv()
+        return self.exception_traceback
+
+
 class S3Client:
     """AWS S3 boto3-based client for downloading/syncing folders"""
 
@@ -28,7 +87,11 @@ def __init__(
                 Defaults to 100
             max_download_workers (int, optional) Maximum number of workers for downloading files.
                 Defaults to 20
+        Attributes:
+            success_paths (list): List of tuples of successfully uploaded paths.
         """
+        self.success_paths = []
         self.bucket = bucket.replace('s3://', '')
         self.max_attempts = max_attempts
         config = Config(retries={
@@ -39,7 +102,7 @@ def __init__(
         self.max_download_workers = max_download_workers
         self.paginator = self.client.get_paginator('list_objects_v2')
 
-    def download_files(
+    def download_folder(
         self,
         dst_dir: str,
         prefix: str,
@@ -51,11 +114,14 @@ def download_files(
         Args:
             dst_dir (str): Destination local folder.
             prefix (str): File/folder S3 key.
-            pattern (str, optional): Including filter pattern. Defaults to None.
-            bucket (str, optional): File/folder S3 bucket. Defaults to ''.
+            pattern (str, optional): Including filter pattern. For more information, check
+                fnmatch library. Examples: '*/my_folder/image_???_*.png', '*.json'.
+                Defaults to None.
+            bucket (str, optional): File/folder S3 bucket.
+                Defaults to self.bucket.
         """
         self.downloader_queue = Queue()
-        self.downloader = Process(target=self._downloader)
+        self.downloader = ProcessWithExceptionPiping(target=self._downloader)
         self.downloader.start()
         try:
             dst_dir = dst_dir.rstrip('/')
@@ -88,38 +154,27 @@ def _enqueue_downloads(
         """
         for page in self.paginator.paginate(Bucket=bucket, Prefix=prefix):
             files = page.get('Contents', ())
-            if len(files) == 1 and files[0].get('Key') == prefix:
-                file_data = [
-                    (
-                        file.get('LastModified').timestamp(),
-                        file.get('Size'),
-                        file.get('Key'),
-                        os.path.basename(file.get('Key')),
-                    )
-                    for file in files
-                ]
-            else:
-                file_data = [
-                    (
-                        file.get('LastModified').timestamp(),
-                        file.get('Size'),
-                        file.get('Key'),
-                        file.get('Key').replace(f'{prefix}/', ''),
-                    )
-                    for file in files
-                ]
+            file_data = [
+                (
+                    file.get('LastModified').timestamp(),
+                    file.get('Size'),
+                    file.get('Key'),
+                    file.get('Key').replace(f'{prefix}/', ''),
+                )
+                for file in files
+            ]
             if pattern:
-                file_data = filter(lambda file_entry: fnmatch(file_entry[3], pattern), file_data)
+                file_data = list(filter(lambda file: fnmatch(file[3], pattern), file_data))
             created_sub_dirs = set()
             for timestamp, size, key, sub_path in file_data:
-                dst_path = f'{dst_dir}/{sub_path}'
+                dst_path = f'{dst_dir}/{prefix}/{sub_path}'
                 if os.path.exists(dst_path):
                     dst_file_stat = os.stat(dst_path)
                     if timestamp == dst_file_stat.st_mtime and size == dst_file_stat.st_size:
                         continue
                 sub_dir = os.path.dirname(sub_path)
                 if sub_dir not in created_sub_dirs:
-                    os.makedirs(f'{dst_dir}/{sub_dir}', exist_ok=True)
+                    os.makedirs(f'{dst_dir}/{prefix}/{sub_dir}', exist_ok=True)
                     created_sub_dirs.add(sub_dir)
                 self.downloader_queue.put((bucket, key, dst_path, timestamp))
 
@@ -150,4 +205,7 @@ def _download_file(
             timestamp (float): File S3 timestamp.
         """
         self.client.download_file(bucket, key, dst_path)
-        os.utime(dst_path, (timestamp, timestamp))
+        if os.path.exists(dst_path):
+            os.utime(dst_path, (timestamp, timestamp))
+        else:
+            raise S3DownloadException(f'Error downloading: s3://{bucket}/{key} => {dst_path}')
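The ProcessWithExceptionPiping class added in this revision exists so that a failure inside the background downloader process is not silently lost: the child sends the exception and its formatted traceback through a Pipe, and the parent re-raises it in join(). A usage sketch (illustrative only, not part of the gist), assuming the class defined above is available:

    def failing_target():
        raise ValueError('something went wrong in the child')

    if __name__ == '__main__':
        process = ProcessWithExceptionPiping(target=failing_target)
        process.start()
        try:
            process.join()  # the child's ValueError is re-raised here, with its traceback
        except ValueError as error:
            print(f'parent caught: {error}')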
  4. @a-canela revised this gist Feb 28, 2023. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion s3_client_sync_download_folder.py
@@ -109,7 +109,7 @@ def _enqueue_downloads(
                     for file in files
                 ]
             if pattern:
-                file_data = [data for data in file_data if fnmatch(data[3], pattern)]
+                file_data = filter(lambda file_entry: fnmatch(file_entry[3], pattern), file_data)
             created_sub_dirs = set()
             for timestamp, size, key, sub_path in file_data:
                 dst_path = f'{dst_dir}/{sub_path}'
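Either form of the include filter (the list comprehension replaced here, or the lazy filter that replaces it) matches the key's sub-path with plain fnmatch globbing, where '*' also crosses '/' boundaries. A small illustration (the paths are made up):

    from fnmatch import fnmatch

    sub_paths = ['my_folder/image_001_raw.png', 'my_folder/labels.json', 'readme.txt']
    print([p for p in sub_paths if fnmatch(p, '*.png')])               # ['my_folder/image_001_raw.png']
    print([p for p in sub_paths if fnmatch(p, '*/image_???_*.png')])   # ['my_folder/image_001_raw.png']
    print([p for p in sub_paths if fnmatch(p, '*.json')])              # ['my_folder/labels.json']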
  5. @a-canela created this gist Feb 28, 2023. 1 changed file with 153 additions.
    153 changes: 153 additions & 0 deletions s3_client_sync_download_folder.py
@@ -0,0 +1,153 @@
import os
import boto3
from fnmatch import fnmatch
from botocore.config import Config
from multiprocessing import Process, Queue
from concurrent.futures import ThreadPoolExecutor


class S3Client:
    """AWS S3 boto3-based client for downloading/syncing folders"""

    def __init__(
        self,
        bucket: str = '',
        max_attempts: int = 10,
        max_pool_connections: int = 100,
        max_download_workers: int = 20,
    ) -> None:
        """AWS S3 boto3-based client constructor.
        Note: Implementation only valid for UNIX.
        Args:
            bucket (str): Bucket path.
            max_attempts (int, optional): Maximum retry attemps.
                Defaults to 10
            max_pool_connections (int, optional) Maximum number of concurrent requests to aws s3.
                Defaults to 100
            max_download_workers (int, optional) Maximum number of workers for downloading files.
                Defaults to 20
        """
        self.bucket = bucket.replace('s3://', '')
        self.max_attempts = max_attempts
        config = Config(retries={
            'mode': 'standard',
            'max_attempts': max_attempts,
        }, max_pool_connections=max_pool_connections)
        self.client = boto3.client('s3', config=config)
        self.max_download_workers = max_download_workers
        self.paginator = self.client.get_paginator('list_objects_v2')

    def download_files(
        self,
        dst_dir: str,
        prefix: str,
        pattern: str = None,
        bucket: str = '',
    ) -> None:
        """Download files (either a folder recursively or a single file) from S3.
        Args:
            dst_dir (str): Destination local folder.
            prefix (str): File/folder S3 key.
            pattern (str, optional): Including filter pattern. Defaults to None.
            bucket (str, optional): File/folder S3 bucket. Defaults to ''.
        """
        self.downloader_queue = Queue()
        self.downloader = Process(target=self._downloader)
        self.downloader.start()
        try:
            dst_dir = dst_dir.rstrip('/')
            prefix = prefix.rstrip('/')
            bucket = bucket or self.bucket
            self._enqueue_downloads(
                dst_dir=dst_dir,
                prefix=prefix,
                pattern=pattern,
                bucket=bucket,
            )
        finally:
            self.downloader_queue.put(None)
            self.downloader.join()

    def _enqueue_downloads(
        self,
        dst_dir: str,
        prefix: str,
        pattern: str = None,
        bucket: str = '',
    ) -> None:
        """Enqueue files to be downloaded by the downloader background process.
        Args:
            dst_dir (str): Destination local folder.
            prefix (str): File/folder S3 key.
            pattern (str, optional): Including filter pattern. Defaults to None.
            bucket (str, optional): File/folder S3 bucket. Defaults to ''.
        """
        for page in self.paginator.paginate(Bucket=bucket, Prefix=prefix):
            files = page.get('Contents', ())
            if len(files) == 1 and files[0].get('Key') == prefix:
                file_data = [
                    (
                        file.get('LastModified').timestamp(),
                        file.get('Size'),
                        file.get('Key'),
                        os.path.basename(file.get('Key')),
                    )
                    for file in files
                ]
            else:
                file_data = [
                    (
                        file.get('LastModified').timestamp(),
                        file.get('Size'),
                        file.get('Key'),
                        file.get('Key').replace(f'{prefix}/', ''),
                    )
                    for file in files
                ]
            if pattern:
                file_data = [data for data in file_data if fnmatch(data[3], pattern)]
            created_sub_dirs = set()
            for timestamp, size, key, sub_path in file_data:
                dst_path = f'{dst_dir}/{sub_path}'
                if os.path.exists(dst_path):
                    dst_file_stat = os.stat(dst_path)
                    if timestamp == dst_file_stat.st_mtime and size == dst_file_stat.st_size:
                        continue
                sub_dir = os.path.dirname(sub_path)
                if sub_dir not in created_sub_dirs:
                    os.makedirs(f'{dst_dir}/{sub_dir}', exist_ok=True)
                    created_sub_dirs.add(sub_dir)
                self.downloader_queue.put((bucket, key, dst_path, timestamp))

    def _downloader(self) -> None:
        """Downloader background process"""
        executor = ThreadPoolExecutor(max_workers=self.max_download_workers)
        try:
            args = self.downloader_queue.get()
            while args:
                executor.submit(self._download_file, *args)
                args = self.downloader_queue.get()
        finally:
            executor.shutdown(wait=True)

    def _download_file(
        self,
        bucket: str,
        key: str,
        dst_path: str,
        timestamp: float,
    ) -> None:
        """Download a file and assign the provided timestamp to it.
        Args:
            bucket (str): File S3 bucket.
            key (str): File S3 key.
            dst_path (str): File destination local path.
            timestamp (float): File S3 timestamp.
        """
        self.client.download_file(bucket, key, dst_path)
        os.utime(dst_path, (timestamp, timestamp))
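For reference, a call pattern for the class as first created (the bucket name, prefix and destination directory below are made up for illustration); later revisions rename download_files to download_folder and add the exception piping shown above:

    if __name__ == '__main__':
        client = S3Client(bucket='s3://my-example-bucket', max_download_workers=20)
        client.download_files(
            dst_dir='/tmp/local-copy',
            prefix='datasets/2023-02',
            pattern='*.json',
        )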