Skip to content

Instantly share code, notes, and snippets.

@vdt
Forked from sayakpaul/bulk_log_results_wandb.ipynb
Created August 11, 2023 10:17
Show Gist options
  • Save vdt/5fa2d8c689fe9c3eecefd4781959c22e to your computer and use it in GitHub Desktop.
Save vdt/5fa2d8c689fe9c3eecefd4781959c22e to your computer and use it in GitHub Desktop.

Revisions

  1. @sayakpaul sayakpaul created this gist Aug 11, 2023.
    238 changes: 238 additions & 0 deletions bulk_log_results_wandb.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,238 @@
    {
    "cells": [
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "bae975ab-1261-4399-87e6-5d71d457e601",
    "metadata": {},
    "outputs": [],
    "source": [
    "import PIL\n",
    "import requests\n",
    "import torch\n",
    "from diffusers import StableDiffusionXLInstructPix2PixPipeline"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "37c55618-8606-4742-9f2c-d7c333e879ca",
    "metadata": {},
    "outputs": [],
    "source": [
    "MODEL_ID = \"sayakpaul/sdxl-instructpix2pix\"\n",
    "SEED = 0\n",
    "pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(\n",
    " MODEL_ID, torch_dtype=torch.float16\n",
    ").to(\"cuda\")"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "9caa8902-77a0-4b02-921f-c8d17ddcfb4d",
    "metadata": {},
    "outputs": [],
    "source": [
    "!wget -q https://huggingface.co/spaces/timbrooks/instruct-pix2pix/resolve/main/imgs/example.jpg\n",
    "!wget -q https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "2f3ca300-6960-4cfa-a808-aac24df90bc0",
    "metadata": {},
    "outputs": [],
    "source": [
    "import hashlib\n",
    "\n",
    "\n",
    "def infer(\n",
    " prompt: str,\n",
    " image: PIL.Image.Image,\n",
    " guidance_scale=5,\n",
    " image_guidance_scale=2,\n",
    " num_inference_steps=20,\n",
    "):\n",
    " \"\"\"Performs inference with the pipeline.\"\"\"\n",
    " hash_image = hashlib.sha1(image.tobytes()).hexdigest()\n",
    " filename = f\"{str(hash_image)}_gs@{guidance_scale}_igs@{image_guidance_scale}_steps@{num_inference_steps}.png\"\n",
    " edited_image = pipe(\n",
    " prompt=prompt,\n",
    " image=image,\n",
    " guidance_scale=guidance_scale,\n",
    " image_guidance_scale=image_guidance_scale,\n",
    " num_inference_steps=num_inference_steps,\n",
    " generator=torch.manual_seed(SEED),\n",
    " ).images[0]\n",
    " edited_image.save(filename)\n",
    " return hash_image"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "db779600-d9e8-42f6-9fca-57c409f44ef4",
    "metadata": {},
    "outputs": [],
    "source": [
    "from typing import List\n",
    "\n",
    "def run_bulk_experiments(\n",
    " image: PIL.Image.Image,\n",
    " edit_prompt: str,\n",
    " guidance_scales: List[float],\n",
    " image_guidance_scales: List[float],\n",
    " steps: List[int],\n",
    "):\n",
    " \"\"\"Runs bulk experiments with the pipeline.\"\"\"\n",
    " for gs in guidance_scales:\n",
    " for igs in image_guidance_scales:\n",
    " for steps_ in steps:\n",
    " hash_image = infer(\n",
    " edit_prompt,\n",
    " image,\n",
    " guidance_scale=gs,\n",
    " image_guidance_scale=igs,\n",
    " num_inference_steps=steps_,\n",
    " )\n",
    " return hash_image"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "bcc68fa3-88f5-469d-b093-3ed2ccb781c1",
    "metadata": {},
    "outputs": [],
    "source": [
    "import wandb\n",
    "import glob\n",
    "\n",
    "\n",
    "def log_to_wandb(initial_image_path: str, edit_prompt: str, image_hex: str):\n",
    " \"\"\"Bulk logs results to wandb.\"\"\"\n",
    " wandb.init(\n",
    " project=\"instructpix2pix-sdxl-results\",\n",
    " config={\"model_id\": MODEL_ID, \"seed\": SEED},\n",
    " )\n",
    " table = wandb.Table(\n",
    " columns=[\n",
    " \"Initial Image\",\n",
    " \"Prompt\",\n",
    " \"Edited Image\",\n",
    " \"Guidance Scale\",\n",
    " \"Image Guidance Scale\",\n",
    " \"Number of Steps\",\n",
    " ]\n",
    " )\n",
    "\n",
    " edited_images = sorted(glob.glob(f\"{image_hex}_*.png\"))\n",
    " for edited_image in edited_images:\n",
    " gs = float(edited_image.split(\"_\")[1].split(\"@\")[-1])\n",
    " igs = float(edited_image.split(\"_\")[2].split(\"@\")[-1])\n",
    " steps = int(edited_image.split(\"_\")[3].split(\"@\")[-1].split(\".\")[0])\n",
    " table.add_data(\n",
    " wandb.Image(initial_image_path),\n",
    " edit_prompt,\n",
    " wandb.Image(edited_image),\n",
    " gs,\n",
    " igs,\n",
    " steps,\n",
    " )\n",
    " wandb.log({\"results\": table})\n",
    " wandb.finish()"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "0dca3588-4b71-4d38-b40c-27f0aca2c2ea",
    "metadata": {},
    "outputs": [],
    "source": [
    "cyborg = PIL.Image.open(\"example.jpg\")\n",
    "mountain = PIL.Image.open(\"mountain.png\")"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "2f44f8bd-7386-4675-8e98-2a682ddbacac",
    "metadata": {},
    "outputs": [],
    "source": [
    "prompt = \"Turn him into a cyborg!\"\n",
    "guidance_scales = [5, 7, 7.5]\n",
    "image_guidance_scales = [1, 1.5, 2]\n",
    "steps = [20, 25, 40, 50]\n",
    "\n",
    "hash_image = run_bulk_experiments(\n",
    " image=cyborg,\n",
    " edit_prompt=prompt,\n",
    " guidance_scales=guidance_scales,\n",
    " image_guidance_scales=image_guidance_scales,\n",
    " steps=steps,\n",
    ")\n",
    "hash_image"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "7838ed52-3690-46b4-9136-e8fbf26381a8",
    "metadata": {},
    "outputs": [],
    "source": [
    "log_to_wandb(initial_image_path=\"example.jpg\", edit_prompt=prompt, image_hex=hash_image)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "42fad38a-1a0f-46f0-aacc-212fe893c9c8",
    "metadata": {},
    "outputs": [],
    "source": [
    "prompt = \"make the mountains snowy\"\n",
    "guidance_scales = [5, 7, 7.5]\n",
    "image_guidance_scales = [1, 1.5, 2]\n",
    "steps = [20, 25, 40, 50]\n",
    "\n",
    "hash_image = run_bulk_experiments(\n",
    " image=mountain,\n",
    " edit_prompt=prompt,\n",
    " guidance_scales=guidance_scales,\n",
    " image_guidance_scales=image_guidance_scales,\n",
    " steps=steps,\n",
    ")\n",
    "\n",
    "log_to_wandb(\n",
    " initial_image_path=\"mountain.png\", edit_prompt=prompt, image_hex=hash_image\n",
    ")"
    ]
    }
    ],
    "metadata": {
    "kernelspec": {
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.8.2"
    }
    },
    "nbformat": 4,
    "nbformat_minor": 5
    }