# gistfile1.txt, created by @sammcj, December 13, 2024
# https://gist.github.com/sammcj/ec38182b10f6be3f7e96f7259a9b37e1

import asyncio
import aiohttp
import os
import logging
from bs4 import BeautifulSoup
from typing import List, Dict
from dataclasses import dataclass
import time
from urllib.parse import urlparse, parse_qs, urlencode

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

@dataclass
class SASConfig:
    """Configuration for an Azure Blob Storage SAS token"""
    container_url: str
    object_id: str
    tenant_id: str
    token_start: str
    token_expiry: str
    start_time: str
    end_time: str
    signature: str

class AzureModelDownloader:
    def __init__(self, model_url: str, sas_config: SASConfig, output_dir: str = "downloads", max_parallel: int = 5):
        self.model_url = model_url
        self.output_dir = output_dir
        self.max_parallel = max_parallel
        self.files_list = []
        self.sas_config = sas_config
        self.semaphore = asyncio.Semaphore(max_parallel)

    @classmethod
    def from_example_url(cls, example_url: str, model_url: str, output_dir: str = "downloads", max_parallel: int = 5):
        """Create an instance by parsing an example download URL"""
        parsed = urlparse(example_url)
        query = parse_qs(parsed.query)

        # Extract the base container URL (strip the last two path segments)
        container_url = f"{parsed.scheme}://{parsed.netloc}{os.path.dirname(os.path.dirname(parsed.path))}"

        sas_config = SASConfig(
            container_url=container_url,
            object_id=query['skoid'][0],
            tenant_id=query['sktid'][0],
            token_start=query['skt'][0],
            token_expiry=query['ske'][0],
            start_time=query['st'][0],
            end_time=query['se'][0],
            signature=query['sig'][0]
        )

        return cls(model_url, sas_config, output_dir, max_parallel)
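    # Illustration only (hypothetical URL, shown for shape): an example download URL like
    #   https://<account>.blob.core.windows.net/<container>/<dir>/<file>?skoid=...&sktid=...&skt=...&ske=...&st=...&se=...&sp=rl&sig=...
    # yields container_url = https://<account>.blob.core.windows.net/<container>, and the SAS
    # query parameters map onto SASConfig fields (skoid -> object_id, sktid -> tenant_id,
    # skt/ske -> token start/expiry, st/se -> validity window, sig -> signature).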

    def parse_file_tree(self, html_content: str) -> List[Dict]:
        """Parse the file tree HTML to extract file paths"""
        soup = BeautifulSoup(html_content, 'html.parser')
        files = []

        nav_links = soup.find_all('div', class_='nav-link')

        for link in nav_links:
            automation_id = link.get('data-automation-id', '')
            if not automation_id:
                continue

            is_directory = bool(link.find('i', class_='folder-icon'))
            if is_directory:
                continue

            file_span = link.find('span', attrs={'data-automation-localized': 'false'})
            if file_span:
                file_name = file_span.text
                files.append({
                    'path': automation_id,
                    'name': file_name,
                    'full_path': os.path.join(*automation_id.split('/'))
                })

        return files
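    # The parser assumes markup roughly of this shape (inferred from the selectors above;
    # the real page may differ):
    #   <div class="nav-link" data-automation-id="path/to/model.bin">
    #       <span data-automation-localized="false">model.bin</span>
    #   </div>
    # Directory rows additionally contain an <i class="folder-icon"> and are skipped.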

    def construct_download_url(self, file_path: str) -> str:
        """Construct the download URL for a file using the SAS config"""
        # Remove any leading slashes to avoid a double slash after the container URL
        clean_path = file_path.lstrip('/')

        # Assemble the SAS query parameters
        params = {
            'skoid': self.sas_config.object_id,
            'sktid': self.sas_config.tenant_id,
            'skt': self.sas_config.token_start,
            'ske': self.sas_config.token_expiry,
            'sks': 'b',
            'skv': '2021-10-04',
            'sv': '2021-10-04',
            'st': self.sas_config.start_time,
            'se': self.sas_config.end_time,
            'sr': 'c',
            'sp': 'rl',
            'sig': self.sas_config.signature
        }

        # urlencode percent-encodes the values; the signature in particular can contain
        # characters ('+', '/', '=') that would otherwise corrupt the query string
        query_string = urlencode(params)
        return f"{self.sas_config.container_url}/{clean_path}?{query_string}"
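    # For reference: the fixed values above follow Azure user-delegation SAS conventions,
    # where 'sr': 'c' scopes the token to the container, 'sp': 'rl' grants read + list,
    # and 'sv'/'skv' pin the storage service and signing-key versions. All parameters are
    # covered by the signature, so changing any of them invalidates the token.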

    async def download_file(self, session: aiohttp.ClientSession, file_info: Dict):
        """Download a single file, retrying with exponential backoff"""
        file_path = os.path.join(self.output_dir, file_info['full_path'])
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        max_retries = 3
        retry_delay = 1

        download_url = self.construct_download_url(file_info['full_path'])

        async with self.semaphore:
            for attempt in range(max_retries):
                try:
                    async with session.get(download_url) as response:
                        if response.status != 200:
                            raise aiohttp.ClientError(f"HTTP {response.status}")

                        total_size = int(response.headers.get('content-length', 0))
                        downloaded = 0

                        # Stream the body to disk in 1 MiB chunks, reporting progress
                        with open(file_path, 'wb') as f:
                            async for chunk in response.content.iter_chunked(1024 * 1024):
                                f.write(chunk)
                                downloaded += len(chunk)
                                if total_size > 0:
                                    progress = (downloaded / total_size) * 100
                                    print(f"\r{file_info['name']}: {progress:.1f}%", end='', flush=True)

                    print(f"\nCompleted: {file_info['name']}")
                    return

                except Exception as e:
                    if attempt == max_retries - 1:
                        logging.error(f"Failed to download {file_info['name']} after {max_retries} attempts: {str(e)}")
                    else:
                        # Exponential backoff: wait 1s, then 2s, between attempts
                        await asyncio.sleep(retry_delay * (2 ** attempt))
                        logging.info(f"Retrying {file_info['name']} (attempt {attempt + 2}/{max_retries})")

    async def download_all_files(self):
        """Download all files in parallel"""
        async with aiohttp.ClientSession() as session:
            tasks = []
            for file_info in self.files_list:
                task = asyncio.create_task(self.download_file(session, file_info))
                tasks.append(task)

            await asyncio.gather(*tasks)
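    # Note: every task is created up front, but the semaphore acquired inside
    # download_file() caps the number of in-flight downloads at max_parallel.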

    def save_file_list(self):
        """Save the list of files to a text file"""
        list_path = os.path.join(self.output_dir, "files_to_download.txt")
        os.makedirs(self.output_dir, exist_ok=True)

        with open(list_path, 'w') as f:
            f.write("# Files to download from Azure AI Model Repository\n")
            f.write(f"# Model URL: {self.model_url}\n")
            f.write(f"# Container URL: {self.sas_config.container_url}\n")
            f.write(f"# Generated on: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write("File path | Target location | Download URL\n")
            f.write("-" * 40 + "|" + "-" * 40 + "|" + "-" * 60 + "\n")

            for file_info in self.files_list:
                download_url = self.construct_download_url(file_info['full_path'])
                f.write(f"{file_info['path']} | {file_info['full_path']} | {download_url}\n")

        return list_path

    async def run(self, html_content: str):
        """Main execution flow"""
        print("Parsing file tree...")
        self.files_list = self.parse_file_tree(html_content)
        print(f"Found {len(self.files_list)} files")

        list_file = self.save_file_list()
        print(f"File list saved to {list_file}")

        print("\nStarting downloads...")
        await self.download_all_files()
        print("\nAll downloads completed!")

async def main():
    import argparse

    parser = argparse.ArgumentParser(description="Download Azure AI model files with directory structure")
    parser.add_argument("--html", required=True, help="Path to the HTML file containing the file tree")
    parser.add_argument("--url", required=True, help="URL of the model page")
    parser.add_argument("--output-dir", default="downloads", help="Output directory for downloaded files")
    parser.add_argument("--parallel", type=int, default=5, help="Maximum parallel downloads")

    # SAS token parameters (the defaults below appear to be redacted placeholders;
    # supply your own values or use --example-url)
    parser.add_argument("--example-url", help="Example download URL to extract SAS parameters from")
    parser.add_argument("--container-url", default="https://amlwlrt4use01.blob.core.windows.net/azureml-0002c54c-3ae6-5726-aa2b-9823dd1236dc",
                        help="Azure Blob Storage container URL")
    parser.add_argument("--object-id", default="ae2sdd35-a062-42a-961d-aasdad1sd294",
                        help="Storage account object ID (skoid)")
    parser.add_argument("--tenant-id", default="3sss921-ss64-4f8c-a055-5bdasdasda3d",
                        help="Tenant ID (sktid)")
    parser.add_argument("--token-start", default="2024-12-13T00:12:36Z",
                        help="Token start time (skt)")
    parser.add_argument("--token-expiry", default="2024-12-13T16:22:36Z",
                        help="Token expiry time (ske)")
    parser.add_argument("--start-time", default="2024-12-13T03:16:01Z",
                        help="Start time (st)")
    parser.add_argument("--end-time", default="2024-12-13T11:26:01Z",
                        help="End time (se)")
    parser.add_argument("--signature", default="MJlFgasdasdasddwefftXqxNTasdasdd=",
                        help="SAS signature (sig)")

    args = parser.parse_args()

    with open(args.html, 'r', encoding='utf-8') as f:
        html_content = f.read()

    if args.example_url:
        downloader = AzureModelDownloader.from_example_url(
            args.example_url, args.url, args.output_dir, args.parallel
        )
    else:
        sas_config = SASConfig(
            container_url=args.container_url,
            object_id=args.object_id,
            tenant_id=args.tenant_id,
            token_start=args.token_start,
            token_expiry=args.token_expiry,
            start_time=args.start_time,
            end_time=args.end_time,
            signature=args.signature
        )
        downloader = AzureModelDownloader(args.url, sas_config, args.output_dir, args.parallel)

    await downloader.run(html_content)

if __name__ == "__main__":
    asyncio.run(main())
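
# Example usage (a sketch; the script name download_model.py and the URLs are
# hypothetical, the flags are the ones defined above):
#
#   # Derive the SAS parameters from a working download URL copied from the browser:
#   python download_model.py --html file_tree.html --url https://ai.azure.com/models/example \
#       --example-url "https://<account>.blob.core.windows.net/<container>/dir/file.bin?skoid=...&sig=..."
#
#   # Or pass the SAS parameters explicitly:
#   python download_model.py --html file_tree.html --url https://ai.azure.com/models/example \
#       --container-url https://<account>.blob.core.windows.net/<container> \
#       --object-id <skoid> --tenant-id <sktid> --token-start <skt> --token-expiry <ske> \
#       --start-time <st> --end-time <se> --signature <sig>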