sammcj created this gist on 13 December 2024.
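The script below automates pulling a model out of the Azure AI model catalogue. It parses a locally saved copy of the model page's file-tree HTML with BeautifulSoup to enumerate the file paths, then rebuilds each file's download URL from an Azure Blob Storage SAS token and fetches the files in parallel, with concurrency capped by an asyncio semaphore and failed transfers retried with exponential backoff.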
import argparse
import asyncio
import logging
import os
import time
from dataclasses import dataclass
from typing import Dict, List
from urllib.parse import parse_qs, urlencode, urlparse

import aiohttp
from bs4 import BeautifulSoup

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


@dataclass
class SASConfig:
    """Configuration for an Azure Blob Storage SAS token."""
    container_url: str
    object_id: str
    tenant_id: str
    token_start: str
    token_expiry: str
    start_time: str
    end_time: str
    signature: str


class AzureModelDownloader:
    def __init__(self, model_url: str, sas_config: SASConfig,
                 output_dir: str = "downloads", max_parallel: int = 5):
        self.model_url = model_url
        self.output_dir = output_dir
        self.max_parallel = max_parallel
        self.files_list: List[Dict] = []
        self.sas_config = sas_config
        # Semaphore caps the number of concurrent downloads.
        self.semaphore = asyncio.Semaphore(max_parallel)

    @classmethod
    def from_example_url(cls, example_url: str, model_url: str,
                         output_dir: str = "downloads", max_parallel: int = 5):
        """Create an instance by parsing the SAS parameters out of an example download URL."""
        parsed = urlparse(example_url)
        query = parse_qs(parsed.query)

        # The container URL is the example URL minus its last two path segments.
        container_url = f"{parsed.scheme}://{parsed.netloc}{os.path.dirname(os.path.dirname(parsed.path))}"

        sas_config = SASConfig(
            container_url=container_url,
            object_id=query['skoid'][0],
            tenant_id=query['sktid'][0],
            token_start=query['skt'][0],
            token_expiry=query['ske'][0],
            start_time=query['st'][0],
            end_time=query['se'][0],
            signature=query['sig'][0],
        )
        return cls(model_url, sas_config, output_dir, max_parallel)

    def parse_file_tree(self, html_content: str) -> List[Dict]:
        """Parse the file-tree HTML to extract file paths, skipping directories."""
        soup = BeautifulSoup(html_content, 'html.parser')
        files = []
        for link in soup.find_all('div', class_='nav-link'):
            automation_id = link.get('data-automation-id', '')
            if not automation_id:
                continue
            # Entries rendered with a folder icon are directories, not files.
            if link.find('i', class_='folder-icon'):
                continue
            file_span = link.find('span', attrs={'data-automation-localized': 'false'})
            if file_span:
                files.append({
                    'path': automation_id,
                    'name': file_span.text,
                    'full_path': os.path.join(*automation_id.split('/')),
                })
        return files

    def construct_download_url(self, file_path: str) -> str:
        """Construct a download URL for a file using the SAS config."""
        # Remove any leading slashes and ensure proper formatting
        clean_path = file_path.lstrip('/')
        params = {
            'skoid': self.sas_config.object_id,
            'sktid': self.sas_config.tenant_id,
            'skt': self.sas_config.token_start,
            'ske': self.sas_config.token_expiry,
            'sks': 'b',
            'skv': '2021-10-04',
            'sv': '2021-10-04',
            'st': self.sas_config.start_time,
            'se': self.sas_config.end_time,
            'sr': 'c',
            'sp': 'rl',
            'sig': self.sas_config.signature,
        }
        # urlencode percent-encodes the values; the signature in particular may
        # contain '+', '/' or '=' characters that would otherwise corrupt the URL.
        return f"{self.sas_config.container_url}/{clean_path}?{urlencode(params)}"

    async def download_file(self, session: aiohttp.ClientSession, file_info: Dict):
        """Download a single file with retries and exponential backoff."""
        file_path = os.path.join(self.output_dir, file_info['full_path'])
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        max_retries = 3
        retry_delay = 1
        download_url = self.construct_download_url(file_info['full_path'])

        async with self.semaphore:
            for attempt in range(max_retries):
                try:
                    async with session.get(download_url) as response:
                        if response.status != 200:
                            raise aiohttp.ClientError(f"HTTP {response.status}")
                        total_size = int(response.headers.get('content-length', 0))
                        downloaded = 0
                        with open(file_path, 'wb') as f:
                            # Stream the body in 1 MiB chunks, reporting progress.
                            async for chunk in response.content.iter_chunked(1024 * 1024):
                                f.write(chunk)
                                downloaded += len(chunk)
                                if total_size > 0:
                                    progress = (downloaded / total_size) * 100
                                    print(f"\r{file_info['name']}: {progress:.1f}%", end='', flush=True)
                        print(f"\nCompleted: {file_info['name']}")
                        return
                except Exception as e:
                    if attempt == max_retries - 1:
                        logging.error(f"Failed to download {file_info['name']} after {max_retries} attempts: {e}")
                    else:
                        await asyncio.sleep(retry_delay * (2 ** attempt))
                        logging.info(f"Retrying {file_info['name']} (attempt {attempt + 2}/{max_retries})")

    async def download_all_files(self):
        """Download all files in parallel."""
        async with aiohttp.ClientSession() as session:
            tasks = [asyncio.create_task(self.download_file(session, file_info))
                     for file_info in self.files_list]
            await asyncio.gather(*tasks)

    def save_file_list(self):
        """Save the list of files, target locations, and download URLs to a text file."""
        list_path = os.path.join(self.output_dir, "files_to_download.txt")
        os.makedirs(self.output_dir, exist_ok=True)
        with open(list_path, 'w') as f:
            f.write("# Files to download from Azure AI Model Repository\n")
            f.write(f"# Model URL: {self.model_url}\n")
            f.write(f"# Container URL: {self.sas_config.container_url}\n")
            f.write(f"# Generated on: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write("File path | Target location | Download URL\n")
            f.write("-" * 40 + "|" + "-" * 40 + "|" + "-" * 60 + "\n")
            for file_info in self.files_list:
                download_url = self.construct_download_url(file_info['full_path'])
                f.write(f"{file_info['path']} | {file_info['full_path']} | {download_url}\n")
        return list_path

    async def run(self, html_content: str):
        """Main execution flow: parse the tree, save the manifest, download everything."""
        print("Parsing file tree...")
        self.files_list = self.parse_file_tree(html_content)
        print(f"Found {len(self.files_list)} files")

        list_file = self.save_file_list()
        print(f"File list saved to {list_file}")

        print("\nStarting downloads...")
        await self.download_all_files()
        print("\nAll downloads completed!")


async def main():
    parser = argparse.ArgumentParser(description="Download Azure AI model files with directory structure")
    parser.add_argument("--html", required=True, help="Path to the HTML file containing the file tree")
    parser.add_argument("--url", required=True, help="URL of the model page")
    parser.add_argument("--output-dir", default="downloads", help="Output directory for downloaded files")
    parser.add_argument("--parallel", type=int, default=5, help="Maximum parallel downloads")

    # SAS token parameters (the defaults below look like redacted placeholder values)
    parser.add_argument("--example-url", help="Example download URL to extract SAS parameters from")
    parser.add_argument("--container-url",
                        default="https://amlwlrt4use01.blob.core.windows.net/azureml-0002c54c-3ae6-5726-aa2b-9823dd1236dc",
                        help="Azure Blob Storage container URL")
    parser.add_argument("--object-id", default="ae2sdd35-a062-42a-961d-aasdad1sd294",
                        help="Storage account object ID (skoid)")
    parser.add_argument("--tenant-id", default="3sss921-ss64-4f8c-a055-5bdasdasda3d",
                        help="Tenant ID (sktid)")
    parser.add_argument("--token-start", default="2024-12-13T00:12:36Z", help="Token start time (skt)")
    parser.add_argument("--token-expiry", default="2024-12-13T16:22:36Z", help="Token expiry time (ske)")
    parser.add_argument("--start-time", default="2024-12-13T03:16:01Z", help="Start time (st)")
    parser.add_argument("--end-time", default="2024-12-13T11:26:01Z", help="End time (se)")
    parser.add_argument("--signature", default="MJlFgasdasdasddwefftXqxNTasdasdd=", help="SAS signature (sig)")

    args = parser.parse_args()

    with open(args.html, 'r', encoding='utf-8') as f:
        html_content = f.read()

    if args.example_url:
        downloader = AzureModelDownloader.from_example_url(
            args.example_url, args.url, args.output_dir, args.parallel
        )
    else:
        sas_config = SASConfig(
            container_url=args.container_url,
            object_id=args.object_id,
            tenant_id=args.tenant_id,
            token_start=args.token_start,
            token_expiry=args.token_expiry,
            start_time=args.start_time,
            end_time=args.end_time,
            signature=args.signature,
        )
        downloader = AzureModelDownloader(args.url, sas_config, args.output_dir, args.parallel)

    await downloader.run(html_content)


if __name__ == "__main__":
    asyncio.run(main())