
@sayakpaul
Created August 4, 2023 08:59
benchmark_distilled_sd.py
    """
    Examples:
    (1) python benchmark_distilled_sd.py --pipeline_id CompVis/stable-diffusion-v1-4
    (2) python benchmark_distilled_sd.py --pipeline_id CompVis/stable-diffusion-v1-4 --vae_path sayakpaul/taesd-diffusers
    (3) python benchmark_distilled_sd.py --pipeline_id nota-ai/bk-sdm-small
    (4) python benchmark_distilled_sd.py --pipeline_id nota-ai/bk-sdm-small --vae_path sayakpaul/taesd-diffusers
    """

import argparse
import time

import torch

from diffusers import AutoencoderTiny, DiffusionPipeline

NUM_ITERS_TO_RUN = 3
NUM_INFERENCE_STEPS = 25
NUM_IMAGES_PER_PROMPT = 4
PROMPT = "a golden vase with different flowers"
SEED = 0


def load_pipeline(pipeline_id, vae_path=None):
    pipe = DiffusionPipeline.from_pretrained(pipeline_id, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    # Optionally replace the default VAE with a tiny autoencoder to speed up decoding.
    if vae_path is not None:
        pipe.vae = AutoencoderTiny.from_pretrained(
            vae_path, torch_dtype=torch.float16
        ).to("cuda")

    return pipe


def run_inference(args):
    torch.cuda.reset_peak_memory_stats()
    pipe = load_pipeline(args.pipeline_id, args.vae_path)

    # Time NUM_ITERS_TO_RUN full generations with a fixed seed for reproducibility.
    start = time.time_ns()
    for _ in range(NUM_ITERS_TO_RUN):
        images = pipe(
            PROMPT,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=torch.manual_seed(SEED),
            num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
        ).images
    end = time.time_ns()
    mem_bytes = torch.cuda.max_memory_allocated()
    mem_MB = int(mem_bytes / (10**6))

    total_time = f"{(end - start) / 1e6:.1f}"  # nanoseconds -> milliseconds
    results = {
        "pipeline_id": args.pipeline_id,
        "total_time (ms)": total_time,
        "memory (mb)": mem_MB,
    }
    if args.vae_path is not None:
        results.update({"vae_path": args.vae_path})
    return results


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pipeline_id",
        type=str,
        default="CompVis/stable-diffusion-v1-4",
    )
    parser.add_argument("--vae_path", type=str, default=None)
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    results = run_inference(args)
    print(results)
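
For reference, a run prints the results dict built in run_inference. The shape below follows directly from the code; the timing and memory values are placeholders, not measured results, since actual numbers depend on the GPU and environment:

$ python benchmark_distilled_sd.py --pipeline_id nota-ai/bk-sdm-small --vae_path sayakpaul/taesd-diffusers
{'pipeline_id': 'nota-ai/bk-sdm-small', 'total_time (ms)': '<elapsed ms>', 'memory (mb)': <peak MB>, 'vae_path': 'sayakpaul/taesd-diffusers'}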