
@davideuler
Forked from sayakpaul/run_flux_under_24gbs.py
Created August 18, 2024 15:37

Revisions

  1. @sayakpaul sayakpaul revised this gist Aug 3, 2024. 1 changed file with 4 additions and 6 deletions.
    10 changes: 4 additions & 6 deletions run_flux_under_24gbs.py
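    This revision drops the `revision="refs/pr/1"` pins. That ref pointed at the pull request which added the Diffusers-format weights to the FLUX.1-schnell repository; once those weights were merged into the main branch, pinning to the PR ref was presumably no longer necessary.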
    @@ -22,13 +22,13 @@ def bytes_to_giga_bytes(bytes):
     prompt = "a photo of a dog with cat-like look"

     text_encoder = CLIPTextModel.from_pretrained(
    -    ckpt_id, revision="refs/pr/1", subfolder="text_encoder", torch_dtype=torch.bfloat16
    +    ckpt_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
     )
     text_encoder_2 = T5EncoderModel.from_pretrained(
    -    ckpt_id, revision="refs/pr/1", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
    +    ckpt_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
     )
    -tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer", revision="refs/pr/1")
    -tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_2", revision="refs/pr/1")
    +tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
    +tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_2")

     pipeline = FluxPipeline.from_pretrained(
         ckpt_id,
    @@ -38,7 +38,6 @@ def bytes_to_giga_bytes(bytes):
         tokenizer_2=tokenizer_2,
         transformer=None,
         vae=None,
    -    revision="refs/pr/1",
     ).to("cuda")

     with torch.no_grad():
    @@ -62,7 +61,6 @@ def bytes_to_giga_bytes(bytes):
         tokenizer=None,
         tokenizer_2=None,
         vae=None,
    -    revision="refs/pr/1",
         torch_dtype=torch.bfloat16,
     ).to("cuda")

  2. @sayakpaul sayakpaul created this gist Aug 2, 2024.
    102 changes: 102 additions & 0 deletions run_flux_under_24gbs.py
    from diffusers import FluxPipeline, AutoencoderKL
    from diffusers.image_processor import VaeImageProcessor
    from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
    import torch
    import gc


    def flush():
        # Release Python references and cached CUDA blocks, then reset the
        # allocator's peak-memory counters so each stage can be profiled on its own.
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()


    def bytes_to_giga_bytes(bytes):
        # Convert a raw byte count to gigabytes for readable reporting.
        return bytes / 1024 / 1024 / 1024


    flush()

    ckpt_id = "black-forest-labs/FLUX.1-schnell"
    prompt = "a photo of a dog with cat-like look"

    # Stage 1: prompt encoding. Load only the tokenizers and text encoders (in
    # bfloat16) so they never have to share VRAM with the transformer or the VAE.
    text_encoder = CLIPTextModel.from_pretrained(
        ckpt_id, revision="refs/pr/1", subfolder="text_encoder", torch_dtype=torch.bfloat16
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        ckpt_id, revision="refs/pr/1", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
    )
    tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer", revision="refs/pr/1")
    tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_2", revision="refs/pr/1")

    # Build a text-encoding-only pipeline: `transformer=None` and `vae=None`
    # keep the two largest components from being loaded at all here.
    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        transformer=None,
        vae=None,
        revision="refs/pr/1",
    ).to("cuda")

    with torch.no_grad():
        print("Encoding prompts.")
        # `text_ids` comes back from `encode_prompt` but is not needed again below.
        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
            prompt=prompt, prompt_2=None, max_sequence_length=256
        )

    # Free the entire text-encoding stack before the transformer is loaded.
    del text_encoder
    del text_encoder_2
    del tokenizer
    del tokenizer_2
    del pipeline

    flush()

    # Stage 2: denoising. Reload the pipeline with only the transformer; the
    # prompt embeddings computed above stand in for the text encoders.
    pipeline = FluxPipeline.from_pretrained(
        ckpt_id,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        vae=None,
        revision="refs/pr/1",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    print("Running denoising.")
    height, width = 768, 1360
    # No need to wrap it up under `torch.no_grad()` as pipeline call method
    # is already wrapped under that.
    latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=4,
    guidance_scale=0.0,
    height=height,
    width=width,
    output_type="latent",
    ).images
    print(f"{latents.shape=}")

    # Free the transformer before the VAE is loaded for decoding.
    del pipeline.transformer
    del pipeline

    flush()

    # Stage 3: decoding. Load only the VAE and rebuild the post-processing helper
    # the full pipeline would normally carry.
    vae = AutoencoderKL.from_pretrained(ckpt_id, revision="refs/pr/1", subfolder="vae", torch_dtype=torch.bfloat16).to(
        "cuda"
    )
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

    with torch.no_grad():
        print("Running decoding.")
        # Unpack the token sequence back into a spatial latent map, then reverse the
        # scale/shift normalization applied when the latents were produced.
        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

        image = vae.decode(latents, return_dict=False)[0]
        image = image_processor.postprocess(image, output_type="pil")
        image[0].save("image.png")
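
A note on verification: the script defines `bytes_to_giga_bytes` but never calls it. Below is a minimal sketch of how one might use it together with `flush()` to confirm the under-24 GB claim; the `report_peak_memory` helper and the suggested call sites are illustrative additions, not part of the original gist:

    def report_peak_memory(stage):
        # `torch.cuda.max_memory_allocated()` reports the peak allocation since the
        # counters were last reset, which `flush()` does between stages.
        peak_gb = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
        print(f"[{stage}] peak GPU memory: {peak_gb:.2f} GB")

    # Example placement, one call just before each flush():
    #   report_peak_memory("prompt encoding")
    #   report_peak_memory("denoising")
    #   report_peak_memory("vae decoding")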