Skip to content

Instantly share code, notes, and snippets.

@hitrust
Forked from karpathy/stablediffusionwalk.py
Created August 17, 2022 02:24
Show Gist options
  • Save hitrust/9c057322d1be77a631b84e4c55ebcad7 to your computer and use it in GitHub Desktop.
Save hitrust/9c057322d1be77a631b84e4c55ebcad7 to your computer and use it in GitHub Desktop.

Revisions

  1. @karpathy karpathy revised this gist Aug 16, 2022. 1 changed file with 2 additions and 1 deletion.
    3 changes: 2 additions & 1 deletion stablediffusionwalk.py
    Original file line number Diff line number Diff line change
    @@ -115,7 +115,8 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):

    for i, t in enumerate(np.linspace(0, 1, 200)):
    init = slerp(float(t), init1, init2)
    image = diffuse(text_embeddings, init, guidance_scale=10.0)
    with autocast("cuda"):
    image = diffuse(text_embeddings, init, guidance_scale=10.0)
    im = Image.fromarray((image[0] * 255).astype(np.uint8))
    im.save('/home/ubuntu/out/frame%06d.jpg' % n)
    print('dreaming... ', n)
  2. @karpathy karpathy revised this gist Aug 16, 2022. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion stablediffusionwalk.py
    Original file line number Diff line number Diff line change
    @@ -117,7 +117,7 @@ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    init = slerp(float(t), init1, init2)
    image = diffuse(text_embeddings, init, guidance_scale=10.0)
    im = Image.fromarray((image[0] * 255).astype(np.uint8))
    im.save('/home/ubuntu/out/frame%04d.jpg' % n)
    im.save('/home/ubuntu/out/frame%06d.jpg' % n)
    print('dreaming... ', n)
    n += 1

  3. @karpathy karpathy revised this gist Aug 16, 2022. 1 changed file with 2 additions and 0 deletions.
    2 changes: 2 additions & 0 deletions stablediffusionwalk.py
    Original file line number Diff line number Diff line change
    @@ -10,6 +10,8 @@
    THIS FILE IS HACKY AND NOT CONFIGURABLE READ THE CODE, MAKE EDITS TO PATHS AND SETTINGS YOU LIKE
    nice slerp def from @xsteenbrugge ty
    you have to have access to stablediffusion checkpoints from https://huggingface.co/CompVis
    and install all the other dependencies (e.g. diffusers library)
    """

    from diffusers import StableDiffusionPipeline
  4. @karpathy karpathy created this gist Aug 16, 2022.
    123 changes: 123 additions & 0 deletions stablediffusionwalk.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,123 @@
    """
    draws many samples from a diffusion model by slerp'ing around
    the noise space, and dumps frames to a directory. You can then
    stitch up the frames with e.g.:
    $ ffmpeg -r 10 -f image2 -s 512x512 -i out/frame%04d.jpg -vcodec libx264 -crf 10 -pix_fmt yuv420p test.mp4
    THIS FILE IS HACKY AND NOT CONFIGURABLE READ THE CODE, MAKE EDITS TO PATHS AND SETTINGS YOU LIKE
    THIS FILE IS HACKY AND NOT CONFIGURABLE READ THE CODE, MAKE EDITS TO PATHS AND SETTINGS YOU LIKE
    THIS FILE IS HACKY AND NOT CONFIGURABLE READ THE CODE, MAKE EDITS TO PATHS AND SETTINGS YOU LIKE
    nice slerp def from @xsteenbrugge ty
    """

    from diffusers import StableDiffusionPipeline
    from time import time
    from PIL import Image
    from einops import rearrange
    import numpy as np
    import torch
    from torch import autocast
    from torchvision.utils import make_grid

    torch.manual_seed(42)

    pipe = StableDiffusionPipeline.from_pretrained("/home/ubuntu/stable-diffusion-v1-3-diffusers", use_auth_token=True)

    torch_device = 'cuda:3'
    pipe.unet.to(torch_device)
    pipe.vae.to(torch_device)
    pipe.text_encoder.to(torch_device)
    print('w00t')

    batch_size = 1
    height = 512
    width = 512

    prompt = ["ultrarealistic steam punk neural network machine in the shape of a brain, placed on a pedestal, covered with neurons made of gears. dramatic lighting. #unrealengine"] * 1
    text_input = pipe.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
    text_embeddings = pipe.text_encoder(text_input.input_ids.to(torch_device))[0]


    @torch.no_grad()
    def diffuse(text_embeddings, init, guidance_scale = 7.5):
    # text_embeddings are n,t,d

    max_length = text_embeddings.shape[1]
    uncond_input = pipe.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    latents = init.clone()

    num_inference_steps = 50
    pipe.scheduler.set_timesteps(num_inference_steps)

    for t in pipe.scheduler.timesteps:

    # predict the noise residual
    latent_model_input = torch.cat([latents] * 2) # for cfg
    noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

    # perform guidance
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # compute the previous noisy sample x_t -> x_t-1
    latents = pipe.scheduler.step(noise_pred, t, latents)["prev_sample"]

    # post-process
    latents = 1 / 0.18215 * latents
    image = pipe.vae.decode(latents)
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()

    return image


    def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):

    if not isinstance(v0, np.ndarray):
    inputs_are_torch = True
    input_device = v0.device
    v0 = v0.cpu().numpy()
    v1 = v1.cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
    v2 = (1 - t) * v0 + t * v1
    else:
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
    v2 = torch.from_numpy(v2).to(input_device)

    return v2

    # DREAM

    # sample start
    init1 = torch.randn((batch_size, pipe.unet.in_channels, height // 8, width // 8)).to(torch_device)
    n = 0
    while True:

    # sample destination
    init2 = torch.randn((batch_size, pipe.unet.in_channels, height // 8, width // 8)).to(torch_device)

    for i, t in enumerate(np.linspace(0, 1, 200)):
    init = slerp(float(t), init1, init2)
    image = diffuse(text_embeddings, init, guidance_scale=10.0)
    im = Image.fromarray((image[0] * 255).astype(np.uint8))
    im.save('/home/ubuntu/out/frame%04d.jpg' % n)
    print('dreaming... ', n)
    n += 1

    init1 = init2