Skip to content

Instantly share code, notes, and snippets.

@unbracketed
Created February 4, 2025 19:32
Show Gist options
  • Save unbracketed/a898d409b76f51505f3da99db013eb35 to your computer and use it in GitHub Desktop.
Save unbracketed/a898d409b76f51505f3da99db013eb35 to your computer and use it in GitHub Desktop.

Revisions

  1. unbracketed created this gist Feb 4, 2025.
    149 changes: 149 additions & 0 deletions image-to-text-model-comparison.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,149 @@
    # /// script
    # requires-python = ">=3.12"
    # dependencies = [
    # "click",
    # "torch",
    # "transformers",
    # "Pillow",
    # "rich",
    # "pandas",
    # ]
    # ///

    import click
    import time
    from pathlib import Path
    from typing import List, Dict
    import torch
    from transformers import pipeline
    from rich.console import Console
    from rich.table import Table
    from rich import print as rprint
    import pandas as pd
    import warnings
    warnings.filterwarnings('ignore')

    # Default list of captioning models to compare. main() falls back to this
    # list when no --models/-m option is given, and process_image() passes each
    # entry as the `model` argument of transformers' pipeline().
    MODELS = [
    "ydshieh/vit-gpt2-coco-en",
    "Salesforce/blip-image-captioning-large",
    "microsoft/git-base-coco",
    ]

    # Captioning pipelines that have already been initialised, keyed by model
    # name. Loading a model is by far the most expensive step, so each model is
    # loaded once per process instead of once per image.
    # NOTE: every model used during a run therefore stays resident in memory.
    _CAPTIONER_CACHE: Dict[str, object] = {}


    def _get_captioner(model_name: str):
        """Return the pipeline for *model_name*, creating it on first use."""
        captioner = _CAPTIONER_CACHE.get(model_name)
        if captioner is None:
            # Use the first CUDA device when available, otherwise the CPU (-1).
            device = 0 if torch.cuda.is_available() else -1
            captioner = pipeline(model=model_name, device=device)
            # Only successfully created pipelines are cached, so a failed load
            # is retried on the next image rather than remembered.
            _CAPTIONER_CACHE[model_name] = captioner
        return captioner


    def process_image(image_path: Path, models: List[str]) -> Dict:
        """Process a single image through multiple models and return results.

        Args:
            image_path: Path of the image file to caption.
            models: Model names to run, each passed to ``pipeline(model=...)``.

        Returns:
            A dict with three keys:
              * ``'image'``: the file name of ``image_path``.
              * ``'captions'``: model name -> generated caption, or an
                ``"Error: ..."`` string if that model failed.
              * ``'times'``: model name -> seconds spent generating the caption
                (rounded to 2 decimals), or ``-1`` if that model failed.
                Model loading time is not included.

        Any exception raised while loading or running a model is recorded in
        the result for that model instead of being propagated, so one failing
        model does not abort the comparison.
        """
        results = {
            'image': image_path.name,
            'captions': {},
            'times': {},
        }

        for model_name in models:
            try:
                captioner = _get_captioner(model_name)

                # perf_counter() is monotonic, so the measured duration cannot
                # be distorted by system clock adjustments.
                caption_start = time.perf_counter()
                caption = captioner(str(image_path))
                elapsed = time.perf_counter() - caption_start

                # Pipelines normally return a list of dicts; fall back to the
                # raw value for any other return shape.
                results['captions'][model_name] = (
                    caption[0]['generated_text'] if isinstance(caption, list) else caption
                )
                results['times'][model_name] = round(elapsed, 2)

            except Exception as e:
                results['captions'][model_name] = f"Error: {str(e)}"
                results['times'][model_name] = -1

        return results

    def create_summary_tables(all_results: List[Dict]) -> tuple:
        """Turn per-image result dicts into a caption table and a timing table.

        Args:
            all_results: Result dicts, each with an ``'image'`` name plus
                ``'captions'`` and ``'times'`` mappings keyed by model name.

        Returns:
            ``(caption_df, timing_df)``: two DataFrames with one row per image,
            an ``'Image'`` column, and one column per model name.
        """
        def _rows(field: str) -> List[Dict]:
            # One flat record per image: the image name first, then one entry
            # per model taken from the requested sub-mapping.
            return [{'Image': entry['image'], **entry[field]} for entry in all_results]

        return pd.DataFrame(_rows('captions')), pd.DataFrame(_rows('times'))

    def display_results(caption_df: pd.DataFrame, timing_df: pd.DataFrame):
        """Print the caption and timing DataFrames as rich tables.

        Args:
            caption_df: One row per image, one column per model, holding captions.
            timing_df: Same layout as ``caption_df``, holding processing times.
        """
        def _build_table(frame, title, **column_kwargs):
            # One table column per DataFrame column, one table row per
            # DataFrame row; every cell is rendered through str().
            table = Table(title=title, show_header=True, header_style="bold magenta")
            for column_name in frame.columns:
                table.add_column(column_name, style="cyan", **column_kwargs)
            for _, record in frame.iterrows():
                table.add_row(*map(str, record))
            return table

        # Caption columns are created with no_wrap=True; timing columns are
        # right-justified.
        captions = _build_table(caption_df, "Generated Captions", no_wrap=True)
        timings = _build_table(timing_df, "Processing Times (seconds)", justify="right")

        # Emit a newline before each table, captions first.
        console = Console()
        for table in (captions, timings):
            console.print("\n")
            console.print(table)

    # CLI entry point. The docstring of main() is what click prints for --help,
    # so maintainer notes are kept in comments rather than in the docstring.
    #
    # Parameters (filled in by click):
    #   image_dir: existing directory that is scanned (non-recursively) for images.
    #   models:    tuple of model names from repeated --models/-m options; empty
    #              when the option is not given, in which case MODELS is used.
    #   output:    base path for the CSV files, or None when --output/-o is omitted.
    @click.command()
    @click.argument('image_dir', type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
    @click.option('--models', '-m', multiple=True, help='Specific models to use (default: use all built-in models)')
    @click.option('--output', '-o', type=click.Path(path_type=Path), help='Save results to CSV files')
    def main(image_dir: Path, models: tuple, output: Path | None):
        """
        Process images in a directory through multiple image captioning models.
        Compares generated captions and processing times.
        """
        # Use specified models or default list
        model_list = list(models) if models else MODELS

        rprint(f"[bold green]Processing images from: {image_dir}[/bold green]")
        rprint(f"[bold blue]Using models: {', '.join(model_list)}[/bold blue]\n")

        # Collect image files. The suffix is compared case-insensitively so
        # that e.g. "IMG_0001.JPG" and "photo.jpeg" are not silently skipped,
        # directories are excluded, and the result is sorted so rows appear
        # in a stable order regardless of filesystem listing order.
        image_suffixes = {'.jpg', '.jpeg', '.png'}
        image_files = sorted(
            path for path in image_dir.iterdir()
            if path.is_file() and path.suffix.lower() in image_suffixes
        )
        if not image_files:
            rprint("[bold red]No image files (jpg/jpeg/png) found in directory![/bold red]")
            return

        # Process all images
        all_results = []
        with click.progressbar(image_files, label='Processing images') as images:
            for img_path in images:
                all_results.append(process_image(img_path, model_list))

        # Create and display summary tables
        caption_df, timing_df = create_summary_tables(all_results)
        display_results(caption_df, timing_df)

        # Save results if output path specified. Both file names are derived
        # from the same base path: its last suffix (if any) is replaced by
        # ".captions.csv" / ".timing.csv".
        if output is not None:
            captions_path = output.with_suffix('.captions.csv')
            timing_path = output.with_suffix('.timing.csv')

            output.parent.mkdir(parents=True, exist_ok=True)
            caption_df.to_csv(captions_path, index=False)
            timing_df.to_csv(timing_path, index=False)
            rprint(f"\n[bold green]Results saved to:{captions_path}"
                   f" and {timing_path}[/bold green]")

    # Run the click command only when this file is executed as a script,
    # not when it is imported as a module.
    if __name__ == '__main__':
    main()