flux-dev-under-8gbs.py
Gist by @ariG23498, created October 20, 2024.
# Taken from: https://gist.github.com/sayakpaul/23862a2e7f5ab73dfdcc513751289bea
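#
# Two-stage, under-8GB FLUX.1-dev inference: first encode the prompt with an
# NF4-quantized T5 text encoder, then free it and load an NF4-quantized
# transformer for denoising. Model CPU offload keeps only the active
# component on the GPU at any time.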

from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch
import gc


# Free Python references and the CUDA cache, then reset the peak-memory
# counters so each stage can be measured independently.
def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


# Convert a raw byte count into gigabytes for readable memory reports.
def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024
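
# Example use (a sketch, not part of the original script):
#   print(f"peak: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated()):.2f} GB")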


flush()

ckpt_id = "black-forest-labs/FLUX.1-dev"
ckpt_4bit_id = "sayakpaul/flux.1-dev-nf4-pkg"
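# The NF4 checkpoint packages pre-quantized (bitsandbytes NF4) weights for the
# text_encoder_2 and transformer subfolders loaded below.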
prompt = "a cute dog in paris photoshoot"


# Load the T5 text encoder from the pre-quantized NF4 checkpoint.
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    ckpt_4bit_id,
    subfolder="text_encoder_2",
)

# Stage 1: a text-encoding-only pipeline. Passing transformer=None and
# vae=None keeps the two largest components out of memory while the prompt
# is encoded. CPU offload moves each component to the GPU only when it runs.
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder_2=text_encoder_2_4bit,
    transformer=None,
    vae=None,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()


with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=256
    )


# Stage 1 is done: drop the text-encoding pipeline to free VRAM before the
# transformer is loaded.
pipeline = pipeline.to("cpu")
del pipeline

flush()
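# Peak-memory counters are reset here, so a reading taken after denoising
# (see the report at the end of the script) covers stage 2 only.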


# Stage 2: rebuild the pipeline with the NF4 transformer. The text encoders
# and tokenizers are omitted because the prompt embeddings computed above are
# passed to the pipeline directly.
transformer_4bit = FluxTransformer2DModel.from_pretrained(ckpt_4bit_id, subfolder="transformer")
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

    print("Running denoising.")
    height, width = 512, 768
    images = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=5.5,
    height=height,
    width=width,
    output_type="pil",
    ).images
    images[0].save("output.png")
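
# Optional addition, not in the original gist: report the stage-2 peak VRAM
# using the helper defined above.
print(f"Max GPU memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated()):.2f} GB")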