Created
February 4, 2025 19:32
-
-
Save unbracketed/a898d409b76f51505f3da99db013eb35 to your computer and use it in GitHub Desktop.
Revisions
-
unbracketed created this gist
Feb 4, 2025 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,149 @@ # /// script # requires-python = ">=3.12" # dependencies = [ # "click", # "torch", # "transformers", # "Pillow", # "rich", # "pandas", # ] # /// import click import time from pathlib import Path from typing import List, Dict import torch from transformers import pipeline from rich.console import Console from rich.table import Table from rich import print as rprint import pandas as pd import warnings warnings.filterwarnings('ignore') # List of models to test MODELS = [ "ydshieh/vit-gpt2-coco-en", "Salesforce/blip-image-captioning-large", "microsoft/git-base-coco", ] def process_image(image_path: Path, models: List[str]) -> Dict: """Process a single image through multiple models and return results.""" results = { 'image': image_path.name, 'captions': {}, 'times': {} } for model_name in models: try: # Initialize model captioner = pipeline(model=model_name, device=0 if torch.cuda.is_available() else -1) # Generate caption caption_start = time.time() caption = captioner(str(image_path)) end_time = time.time() # Store results results['captions'][model_name] = caption[0]['generated_text'] if isinstance(caption, list) else caption results['times'][model_name] = round(end_time - caption_start, 2) except Exception as e: results['captions'][model_name] = f"Error: {str(e)}" results['times'][model_name] = -1 return results def create_summary_tables(all_results: List[Dict]) -> tuple: """Create summary tables for captions and timing.""" # Prepare data for pandas caption_data = [] timing_data = [] for result in all_results: caption_row = {'Image': result['image']} timing_row = {'Image': result['image']} caption_row.update(result['captions']) timing_row.update(result['times']) caption_data.append(caption_row) timing_data.append(timing_row) caption_df = pd.DataFrame(caption_data) timing_df = pd.DataFrame(timing_data) return caption_df, timing_df def display_results(caption_df: pd.DataFrame, timing_df: pd.DataFrame): """Display results using rich tables.""" console = Console() # Caption Results Table caption_table = Table(title="Generated Captions", show_header=True, header_style="bold magenta") for col in caption_df.columns: caption_table.add_column(col, style="cyan", no_wrap=True) for _, row in caption_df.iterrows(): caption_table.add_row(*[str(x) for x in row]) # Timing Results Table timing_table = Table(title="Processing Times (seconds)", show_header=True, header_style="bold magenta") for col in timing_df.columns: timing_table.add_column(col, style="cyan", justify="right") for _, row in timing_df.iterrows(): timing_table.add_row(*[str(x) for x in row]) # Display tables console.print("\n") console.print(caption_table) console.print("\n") console.print(timing_table) @click.command() @click.argument('image_dir', type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path)) @click.option('--models', '-m', multiple=True, help='Specific models to use (default: use all built-in models)') @click.option('--output', '-o', type=click.Path(path_type=Path), help='Save results to CSV files') def main(image_dir: Path, models: tuple, output: Path): """ Process images in a directory through multiple image captioning models. Compares generated captions and processing times. """ # Use specified models or default list model_list = list(models) if models else MODELS rprint(f"[bold green]Processing images from: {image_dir}[/bold green]") rprint(f"[bold blue]Using models: {', '.join(model_list)}[/bold blue]\n") # Get all image files image_files = list(image_dir.glob('*.jpg')) + list(image_dir.glob('*.png')) if not image_files: rprint("[bold red]No image files (jpg/png) found in directory![/bold red]") return # Process all images all_results = [] with click.progressbar(image_files, label='Processing images') as images: for img_path in images: results = process_image(img_path, model_list) all_results.append(results) # Create and display summary tables caption_df, timing_df = create_summary_tables(all_results) display_results(caption_df, timing_df) # Save results if output path specified if output: output.parent.mkdir(parents=True, exist_ok=True) caption_df.to_csv(output.with_suffix('.captions.csv'), index=False) timing_df.to_csv(output.with_suffix('.timing.csv'), index=False) rprint(f"\n[bold green]Results saved to:{output.with_suffix('.captions.csv')}" f" and {output.with_suffix('.timing.csv')}[/bold green]") if __name__ == '__main__': main()