Created
October 20, 2024 09:01
-
-
Save ariG23498/948c263116886b3aafae95e69ac3336a to your computer and use it in GitHub Desktop.
Revisions
-
ariG23498 created this gist
Oct 20, 2024. There are no files selected for viewing.
# Taken from: https://gist.github.com/sayakpaul/23862a2e7f5ab73dfdcc513751289bea
"""Two-stage, low-VRAM FLUX.1-dev inference with NF4-quantized components.

Stage 1 loads only the (4-bit) T5 text encoder to compute prompt embeddings,
then frees it. Stage 2 loads the 4-bit transformer and runs denoising on the
precomputed embeddings. Loading the pieces sequentially keeps peak GPU memory
far below what loading the full pipeline at once would need.
"""
from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch
import gc


def flush():
    """Release cached GPU memory and reset CUDA peak-memory counters.

    Safe to call on CPU-only hosts: the CUDA calls are guarded.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is deprecated; reset_peak_memory_stats()
        # resets all peak counters in a single call.
        torch.cuda.reset_peak_memory_stats()


def bytes_to_giga_bytes(num_bytes):
    """Convert a byte count to gibibytes (1 GiB = 1024**3 bytes).

    Note: parameter renamed from ``bytes`` to avoid shadowing the builtin.
    """
    return num_bytes / 1024 / 1024 / 1024


flush()

ckpt_id = "black-forest-labs/FLUX.1-dev"
ckpt_4bit_id = "sayakpaul/flux.1-dev-nf4-pkg"
prompt = "a cute dog in paris photoshoot"

# --- Stage 1: prompt encoding only ---------------------------------------
# Load just the NF4-quantized T5 encoder; transformer and VAE are skipped
# (passed as None) so this pipeline stays small.
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    ckpt_4bit_id,
    subfolder="text_encoder_2",
)

pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder_2=text_encoder_2_4bit,
    transformer=None,
    vae=None,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=256
    )

# Free the text-encoder pipeline before loading the transformer so both are
# never resident at the same time.
pipeline = pipeline.to("cpu")
del pipeline
flush()

# --- Stage 2: denoising ----------------------------------------------------
# Load the NF4-quantized transformer; text encoders and tokenizers are no
# longer needed because the embeddings are already computed.
transformer_4bit = FluxTransformer2DModel.from_pretrained(
    ckpt_4bit_id, subfolder="transformer"
)
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

print("Running denoising.")
height, width = 512, 768
images = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=5.5,
    height=height,
    width=width,
    output_type="pil",
).images
images[0].save("output.png")