
@artsparkAI
Created April 11, 2023 12:09

Revisions

  1. artsparkAI created this gist Apr 11, 2023.
    677 changes: 677 additions & 0 deletions controlnet.patch
    From 8b4ef05aa89479616bbc28bb7e74ed9dc37a83c2 Mon Sep 17 00:00:00 2001
    From: artspark <[email protected]>
    Date: Tue, 11 Apr 2023 12:08:26 +0000
    Subject: [PATCH] Backend changes

    ---
    backend/ldm/generate.py | 15 +-
    backend/ldm/invoke/generator/base.py | 3 +-
    .../invoke/generator/diffusers_pipeline.py | 226 +++++++++++++++++-
    backend/ldm/invoke/generator/img2img.py | 27 ++-
    backend/ldm/invoke/generator/inpaint.py | 12 +
    backend/ldm/invoke/generator/txt2img.py | 3 +-
    backend/ldm/invoke/model_manager.py | 6 +-
    backend/ldm/models/diffusion/ddpm.py | 2 +-
    8 files changed, 268 insertions(+), 26 deletions(-)
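
    In short, the patch threads a new `control` argument from the `Generate` entry point through the generators down to the diffusers pipeline, rebases `StableDiffusionGeneratorPipeline` onto `StableDiffusionControlNetPipeline`, and installs a per-request `control_forward` hook that turns the init image into per-control hints and sums their ControlNet residuals into extra UNet inputs. Below is a minimal sketch of that hook pattern outside the InvokeAI plumbing; the helper name and the way hints are supplied are assumptions, not part of the patch.

    from typing import Callable, Dict, List

    import torch
    from diffusers import ControlNetModel


    def make_control_forward(
        controlnets: List[ControlNetModel],
        hint_tensors: List[torch.Tensor],
    ) -> Callable:
        """Sketch of the hook this patch stores as `pipeline.control_forward`.

        controlnets:  loaded ControlNetModel instances, one per active hint type
        hint_tensors: matching preprocessed hint images, shape (1, 3, H, W), values in [0, 1]
        """
        def control_forward(latents: torch.Tensor, t: torch.Tensor,
                            encoder_hidden_states: torch.Tensor) -> Dict[str, object]:
            down_res, mid_res = None, None
            for model, cond in zip(controlnets, hint_tensors):
                # With return_dict=False, ControlNetModel returns
                # (down_block_res_samples, mid_block_res_sample).
                dr, mr = model(
                    latents, t,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=cond.to(latents.device, dtype=latents.dtype),
                    return_dict=False,
                )
                # Residuals from multiple active controls are summed element-wise.
                down_res = dr if down_res is None else [a + b for a, b in zip(down_res, dr)]
                mid_res = mr if mid_res is None else mid_res + mr
            # UNet2DConditionModel.forward() accepts these keyword arguments.
            return {
                'down_block_additional_residuals': down_res,
                'mid_block_additional_residual': mid_res,
            }

        return control_forward

    In the patch below, `_unet_forward()` splats the returned dictionary into the `self.unet(...)` call whenever `control_forward` is set, and the generation entry points clear the hook after each run.
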

    diff --git a/backend/ldm/generate.py b/backend/ldm/generate.py
    index a5649a92..d60046fd 100644
    --- a/backend/ldm/generate.py
    +++ b/backend/ldm/generate.py
    @@ -341,6 +341,7 @@ class Generate:
    infill_method = None,
    force_outpaint: bool = False,
    enable_image_debugging = False,
    + control = None,

    **args,
    ): # eat up additional cruft
    @@ -423,9 +424,9 @@ class Generate:

    assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
    assert threshold >= 0.0, '--threshold must be >=0.0'
    - assert (
    - 0.0 < strength < 1.0
    - ), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
    + #assert (
    + # 0.0 < strength < 1.0
    + #), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
    assert (
    0.0 <= variation_amount <= 1.0
    ), '-v --variation_amount must be in [0.0, 1.0]'
    @@ -517,7 +518,10 @@ class Generate:
    )

    # TODO: Hacky selection of operation to perform. Needs to be refactored.
    + print("control", control)
    generator = self.select_generator(init_image, mask_image, embiggen, hires_fix, force_outpaint)
    + #generator.model.control = control
    + #generator.model.control = 'canny'

    generator.set_variation(
    self.seed, variation_amount, with_variations
    @@ -534,6 +538,7 @@ class Generate:
    iterations=iterations,
    seed=self.seed,
    sampler=self.sampler,
    + control=control,
    steps=steps,
    cfg_scale=cfg_scale,
    conditioning=(uc, c, extra_conditioning_info),
    @@ -753,6 +758,7 @@ class Generate:
    ):
    inpainting_model_in_use = self.sampler.uses_inpainting_model()

    +
    if hires_fix:
    return self._make_txt2img2img()

    @@ -817,6 +823,9 @@ class Generate:
    def _make_base(self):
    return self._load_generator('','Generator')

    + def _make_controlnet(self):
    + return self._load_generator('.controlnet','ControlNet')
    +
    def _make_txt2img(self):
    return self._load_generator('.txt2img','Txt2Img')

    diff --git a/backend/ldm/invoke/generator/base.py b/backend/ldm/invoke/generator/base.py
    index 467cbe38..254e65bb 100644
    --- a/backend/ldm/invoke/generator/base.py
    +++ b/backend/ldm/invoke/generator/base.py
    @@ -61,7 +61,7 @@ class Generator:

    def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
    image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
    - safety_checker:dict=None, orig_prompt=None,
    + safety_checker:dict=None, orig_prompt=None, control=None,
    free_gpu_mem: bool=False,
    **kwargs):
    scope = nullcontext
    @@ -79,6 +79,7 @@ class Generator:
    threshold = threshold,
    perlin = perlin,
    attention_maps_callback = attention_maps_callback,
    + control = control,
    **kwargs
    )
    results = []
    diff --git a/backend/ldm/invoke/generator/diffusers_pipeline.py b/backend/ldm/invoke/generator/diffusers_pipeline.py
    index 69412057..82f87ed3 100644
    --- a/backend/ldm/invoke/generator/diffusers_pipeline.py
    +++ b/backend/ldm/invoke/generator/diffusers_pipeline.py
    @@ -6,7 +6,7 @@ import secrets
    import sys
    import warnings
    from dataclasses import dataclass, field
    -from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any
    +from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any, Dict

    if sys.version_info < (3, 10):
    from typing_extensions import ParamSpec
    @@ -14,11 +14,13 @@ else:
    from typing import ParamSpec

    import PIL.Image
    +from PIL import Image
    import einops
    import torch
    import torchvision.transforms as T
    from diffusers.models import attention
    from diffusers.utils.import_utils import is_xformers_available
    +import numpy as np

    from ...models.diffusion import cross_attention_control
    from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
    @@ -28,10 +30,12 @@ from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToE
    # this is to make prompt2prompt and (future) attention maps work
    attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention

    -from diffusers.models import AutoencoderKL, UNet2DConditionModel
    +from diffusers.models import AutoencoderKL, UNet2DConditionModel, ControlNetModel
    from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
    from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
    from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
    +from diffusers import StableDiffusionControlNetPipeline, StableDiffusionInpaintPipeline
    +from diffusers.utils import load_image
    from .safety_checker import StableDiffusionSafetyChecker
    #from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
    from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
    @@ -44,6 +48,21 @@ from ldm.invoke.globals import Globals
    from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
    from ldm.modules.textual_inversion_manager import TextualInversionManager

    +from web_pdb import set_trace
    +
    +
    +import controlnet_hinter
    +#controlnet_hinter.scribble = controlnet_hinter.fake_scribble
    +CONTROLNETS = [
    + 'canny',
    + 'depth',
    + 'scribble',
    + 'hed',
    + 'mlsd',
    + 'normal',
    + 'openpose',
    + 'seg',
    +];



    @@ -233,14 +252,15 @@ class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
    r"""
    Output class for InvokeAI's Stable Diffusion pipeline.

    +
    Args:
    attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
    after generation completes. Optional.
    """
    attention_map_saver: Optional[AttentionMapSaver]

    -
    -class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    +class StableDiffusionGeneratorPipeline(StableDiffusionControlNetPipeline):
    +#class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    @@ -280,13 +300,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    + controlnet: ControlNetModel,
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    safety_checker: Optional[StableDiffusionSafetyChecker],
    feature_extractor: Optional[CLIPFeatureExtractor],
    requires_safety_checker: bool = False,
    precision: str = 'float32',
    ):
    - super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
    + #print(vae)
    + super().__init__(vae, text_encoder, tokenizer, unet, controlnet, scheduler,
    safety_checker, feature_extractor, requires_safety_checker)

    self.register_modules(
    @@ -297,6 +319,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    scheduler=scheduler,
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
    + controlnet=controlnet,
    + #**controlnets,
    )
    self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
    use_full_precision = (precision == 'float32' or precision == 'autocast')
    @@ -316,13 +340,90 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    self.enable_vae_slicing()
    self.enable_attention_slicing()

    + #self.controlnet_cond = None
    + #self.control = None
    + #self.controller = None
    + #self.controlnets = self.controlnet_models_dict()
    + self.control_forward = None
    +
    + def controlnet_model(self, controlnet_name: str) -> ControlNetModel:
    + #print(self.unet.device)
    + return ControlNetModel.from_pretrained(
    + #f'takuma104/control_sd15_{controlnet_name}',
    + f'/data/models/control_models/control_nedream_{controlnet_name}',
    + subfolder='controlnet',
    + torch_dtype=self.unet.dtype
    + )
    +
    + def controlnet_models_dict(self):
    + return {f"controlnet_{cn}": self.controlnet_model(cn) for cn in CONTROLNETS}
    +
    + def controlnet_models(self):
    + return [self.controlnet_model(cn) for cn in CONTROLNETS]
    +
    +
    + def control_to_model(self, control):
    + if 'scribble' in control:
    + key = 'scribble'
    + elif 'hough' in control:
    + key = 'mlsd'
    + elif 'segmentation' in control:
    + key = 'seg'
    + else:
    + key = control
    +
    + return self.controlnet_model(key)
    + #return self.controlnets[key]
    +
    +
    + def controller(self, control_dict, image):
    +
    +
    + controls = [c for c, v in control_dict.items() if v]
    + w, h = trim_to_multiple_of(*image.size)
    + ctrl = {}
    + for control in controls:
    + hinter = getattr(controlnet_hinter, f'hint_{control}')
    + control_image = hinter(image)
    + controlnet_cond = self.preprocess(control_image, w, h)
    + control_model = self.control_to_model(control)
    + ctrl[control] = controlnet_cond, control_model
    +
    +
    + def controller_forward(latents, t, encoder_hidden_states):
    + down_res, mid_res = None, None
    + for control in controls:
    + controlnet_cond, control_model = ctrl[control]
    + model = control_model.to(self._execution_device)
    +
    + dr, mr = model(
    + latents,
    + t,
    + encoder_hidden_states=encoder_hidden_states,
    + controlnet_cond=controlnet_cond.to(self._execution_device, dtype=self.controlnet.dtype),
    + return_dict=False,
    + )
    + down_res = dr if down_res is None else [sum(z) for z in zip(down_res, dr)]
    + mid_res = mr if mid_res is None else sum([mid_res, mr])
    +
    + unet_inputs = {
    + 'down_block_additional_residuals': down_res,
    + 'mid_block_additional_residual': mid_res,
    + }
    +
    + return unet_inputs
    +
    + return controller_forward
    +
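    # Hedged usage sketch (not part of the patch): the generation entry points below
    # accept a UI-style dict of active hint types, call controller() themselves, and
    # clear control_forward again when the run finishes.  `pipeline`, `init_image`
    # and `conditioning_data` are assumed to exist.
    import torch

    def _example_img2img_with_canny(pipeline, init_image, conditioning_data, steps=30):
        return pipeline.img2img_from_embeddings(
            init_image, 0.75, steps, conditioning_data,
            # Generator.get_noise_like is the normal noise_func; randn_like is a stand-in.
            noise_func=lambda latents: torch.randn_like(latents),
            control={'canny': True},   # controller() pairs this with a hint derived from init_image
        )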

    def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
    conditioning_data: ConditioningData,
    *,
    noise: torch.Tensor,
    + noise_func=None,
    callback: Callable[[PipelineIntermediateState], None]=None,
    - run_id=None) -> InvokeAIStableDiffusionPipelineOutput:
    + control=None,
    + run_id=None, init_image=None) -> InvokeAIStableDiffusionPipelineOutput:
    r"""
    Function invoked when calling the pipeline for generation.

    @@ -335,6 +436,25 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    :param callback:
    :param run_id:
    """
    +
    + if init_image and control:
    + # This is for the case where someone is in img2img mode but only wants the init image to be used
    + # as a control hint, not as a base image. In the UI this is represented as turning up "strength" to 1.0
    + self.control_forward = self.controller(control, init_image)
    +
    + if isinstance(init_image, PIL.Image.Image):
    + init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
    +
    + if init_image.dim() == 3:
    + init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
    +
    + device = self.unet.device
    + latents_dtype = self.unet.dtype
    + latents = torch.zeros_like(self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype))
    + noise = noise_func(latents)
    +
    +
    +
    result_latents, result_attention_map_saver = self.latents_from_embeddings(
    latents, num_inference_steps,
    conditioning_data,
    @@ -347,6 +467,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    with torch.inference_mode():
    image = self.decode_latents(result_latents)
    output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver)
    + self.control_forward = None
    return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
    @@ -458,9 +579,29 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

    return step_output

    + def preprocess(self, image, width, height):
    + #height = height or self.unet.config.sample_size * self.vae_scale_factor
    + #width = width or self.unet.config.sample_size * self.vae_scale_factor
    + if isinstance(image, torch.Tensor):
    + return image
    + elif isinstance(image, PIL.Image.Image):
    + image = [image]
    +
    + if isinstance(image[0], PIL.Image.Image):
    + image = [np.array(i.convert('RGB').resize((width, height), resample=PIL.Image.Resampling.LANCZOS))[None, :] for i in image]
    + image = np.concatenate(image, axis=0)
    + image = np.array(image).astype(np.float32) / 255.0
    + image = image[:, :, :, ::-1] # RGB -> BGR
    + image = image.transpose(0, 3, 1, 2)
    + image = torch.from_numpy(image.copy()) # copy: ::-1 workaround
    + elif isinstance(image[0], torch.Tensor):
    + image = torch.cat(image, dim=0)
    + return image
    +
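    # Shape sketch for preprocess() above (illustrative only): an RGB PIL hint image
    # resized to the requested (width, height) comes back as a float32 tensor of
    # shape (1, 3, height, width), scaled to [0, 1], with the channel order flipped
    # to BGR as the conversion above does.
    def _example_preprocess_shape(pipeline):
        from PIL import Image
        hint = Image.new('RGB', (512, 768))                      # placeholder hint image
        cond = pipeline.preprocess(hint, width=512, height=768)
        assert tuple(cond.shape) == (1, 3, 768, 512)             # N, C, H, W
        return cond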
    def _unet_forward(self, latents, t, text_embeddings):
    latents = latents.to(self.unet.device, dtype=self.unet.dtype)
    """predict the noise residual"""
    +
    if is_inpainting_model(self.unet) and latents.size(1) == 4:
    # Pad out normal non-inpainting inputs for an inpainting model.
    # FIXME: There are too many layers of functions and we have too many different ways of
    @@ -472,7 +613,35 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
    ).add_mask_channels(latents)

    - return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
    +
    + if self.control_forward:
    + unet_inputs = self.control_forward(latents, t, text_embeddings)
    + else:
    + unet_inputs = {}
    + #if self.controlnet_cond is not None:
    + # #controlnet_cond = self.preprocess(self.control_image).to(device=self._execution_device, dtype=self.controlnet.dtype)
    + # controlnet_cond = self.controlnet_cond.to(self._execution_device, dtype=self.controlnet.dtype)
    +
    + # down_res, mid_res = self.controller(
    + # latents,
    + # t,
    + # encoder_hidden_states=text_embeddings,
    + # controlnet_cond=controlnet_cond,
    + # return_dict=False,
    + # )
    +
    + # unet_inputs = {
    + # 'down_block_additional_residuals': down_res,
    + # 'mid_block_additional_residual': mid_res,
    + # }
    + #else:
    + # unet_inputs = {}
    +
    + return self.unet(
    + latents, t,
    + encoder_hidden_states=text_embeddings,
    + **unet_inputs,
    + ).sample

    def img2img_from_embeddings(self,
    init_image: Union[torch.FloatTensor, PIL.Image.Image],
    @@ -481,34 +650,57 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    conditioning_data: ConditioningData,
    *, callback: Callable[[PipelineIntermediateState], None] = None,
    run_id=None,
    - noise_func=None
    + noise_func=None,
    + control=None
    ) -> InvokeAIStableDiffusionPipelineOutput:
    +
    +
    + if control:
    + self.control_forward = self.controller(control, init_image)
    +
    +
    if isinstance(init_image, PIL.Image.Image):
    init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))

    if init_image.dim() == 3:
    init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')

    +
    +
    +
    # 6. Prepare latent variables
    device = self.unet.device
    latents_dtype = self.unet.dtype
    initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
    noise = noise_func(initial_latents)

    - return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
    + result = self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
    conditioning_data,
    strength,
    noise, run_id, callback)

    + self.control_forward = None
    + return result
    +
    + def get_timesteps(self, scheduler, num_inference_steps, strength, device):
    + # get the original timestep using init_timestep
    + init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    +
    + t_start = max(num_inference_steps - init_timestep, 0)
    + timesteps = scheduler.timesteps[t_start:]
    +
    + return timesteps, num_inference_steps - t_start
    +
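    # Worked example for get_timesteps() above: with num_inference_steps=30 and
    # strength=0.75 the scheduler keeps its last 22 timesteps.  This reproduces the
    # trimming StableDiffusionImg2ImgPipeline.get_timesteps() performs, presumably
    # added here because the ControlNet pipeline used as the base does not expose one.
    num_inference_steps, strength = 30, 0.75
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)   # 22
    t_start = max(num_inference_steps - init_timestep, 0)                           # 8
    # -> scheduler.timesteps[8:] is used, i.e. 22 denoising steps.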
    def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps,
    conditioning_data: ConditioningData,
    strength,
    noise: torch.Tensor, run_id=None, callback=None
    ) -> InvokeAIStableDiffusionPipelineOutput:
    device = self.unet.device
    - img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
    + #img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
    + img2img_pipeline = StableDiffusionControlNetPipeline(**self.components)
    img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
    - timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
    + timesteps, _ = self.get_timesteps(img2img_pipeline.scheduler, num_inference_steps, strength, device=device)

    result_latents, result_attention_maps = self.latents_from_embeddings(
    initial_latents, num_inference_steps, conditioning_data,
    @@ -535,10 +727,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    *, callback: Callable[[PipelineIntermediateState], None] = None,
    run_id=None,
    noise_func=None,
    + control=None,
    + pil_init_image=None
    ) -> InvokeAIStableDiffusionPipelineOutput:
    device = self.unet.device
    latents_dtype = self.unet.dtype

    + if control:
    + self.control_forward = self.controller(control, pil_init_image)
    +
    +
    if isinstance(init_image, PIL.Image.Image):
    init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))

    @@ -547,9 +745,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    if init_image.dim() == 3:
    init_image = init_image.unsqueeze(0)

    - img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
    + #img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
    + img2img_pipeline = StableDiffusionControlNetPipeline(**self.components)
    img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
    - timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
    + timesteps, _ = self.get_timesteps(img2img_pipeline.scheduler, num_inference_steps, strength, device=device)

    assert img2img_pipeline.scheduler is self.scheduler

    @@ -588,6 +787,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    with torch.inference_mode():
    image = self.decode_latents(result_latents)
    output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
    + self.control_forward = None
    return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def non_noised_latents_from_image(self, init_image, *, device, dtype):
    diff --git a/backend/ldm/invoke/generator/img2img.py b/backend/ldm/invoke/generator/img2img.py
    index fedf6d3a..ee6d1379 100644
    --- a/backend/ldm/invoke/generator/img2img.py
    +++ b/backend/ldm/invoke/generator/img2img.py
    @@ -8,6 +8,7 @@ from diffusers import logging
    from ldm.invoke.generator.base import Generator
    from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
    from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
    +from diffusers.utils import load_image


    class Img2Img(Generator):
    @@ -17,7 +18,7 @@ class Img2Img(Generator):

    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
    conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
    - attention_maps_callback=None,
    + attention_maps_callback=None, control=None,
    **kwargs):
    """
    Returns a function returning an image derived from the prompt and the initial image
    @@ -28,6 +29,7 @@ class Img2Img(Generator):
    # noinspection PyTypeChecker
    pipeline: StableDiffusionGeneratorPipeline = self.model
    pipeline.scheduler = sampler
    + pipeline.control_forward = None

    uc, c, extra_conditioning_info = conditioning
    conditioning_data = (
    @@ -42,11 +44,24 @@ class Img2Img(Generator):
    # We're not at the moment because the pipeline automatically resizes init_image if
    # necessary, which the x_T input might not match.
    logging.set_verbosity_error() # quench safety check warnings
    - pipeline_output = pipeline.img2img_from_embeddings(
    - init_image, strength, steps, conditioning_data,
    - noise_func=self.get_noise_like,
    - callback=step_callback
    - )
    + if strength == 1.0:
    + pipeline_output = pipeline.image_from_embeddings(
    + latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
    + noise=x_T,
    + noise_func=self.get_noise_like,
    + num_inference_steps=steps,
    + conditioning_data=conditioning_data,
    + callback=step_callback,
    + init_image=init_image,
    + control=control,
    + )
    + else:
    + pipeline_output = pipeline.img2img_from_embeddings(
    + init_image, strength, steps, conditioning_data,
    + noise_func=self.get_noise_like,
    + callback=step_callback,
    + control=control
    + )
    if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
    attention_maps_callback(pipeline_output.attention_map_saver)
    return pipeline.numpy_to_pil(pipeline_output.images)[0]
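
    The strength == 1.0 branch above is the UI glue described in the earlier comment: at full strength the init image no longer seeds the latents, and the request is routed through image_from_embeddings() so the image is consumed only as a ControlNet hint. A hedged restatement of that routing (the helper name is illustrative):

    def _route_img2img(strength: float) -> str:
        # Mirrors the branch in Img2Img.get_make_image() above: at strength == 1.0 the
        # latents start from pure noise and the init image only feeds control_forward.
        return 'image_from_embeddings' if strength == 1.0 else 'img2img_from_embeddings'
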
    diff --git a/backend/ldm/invoke/generator/inpaint.py b/backend/ldm/invoke/generator/inpaint.py
    index dd43d3cd..f0db3e26 100644
    --- a/backend/ldm/invoke/generator/inpaint.py
    +++ b/backend/ldm/invoke/generator/inpaint.py
    @@ -11,6 +11,7 @@ import numpy as np
    import torch
    from PIL import Image, ImageFilter, ImageOps, ImageChops

    +from diffusers import StableDiffusionInpaintPipeline
    from ldm.invoke.generator.diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline, \
    ConditioningData
    from ldm.invoke.generator.img2img import Img2Img
    @@ -183,6 +184,7 @@ class Inpaint(Img2Img):
    inpaint_width=None,
    inpaint_height=None,
    attention_maps_callback=None,
    + control=None,
    **kwargs):
    """
    Returns a function returning an image derived from the prompt and
    @@ -244,6 +246,9 @@ class Inpaint(Img2Img):
    # noinspection PyTypeChecker
    pipeline: StableDiffusionGeneratorPipeline = self.model
    pipeline.scheduler = sampler
    + pipeline.control_forward = None
    +
    + #pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=pipeline.unet.dtype)

    # todo: support cross-attention control
    uc, c, _ = conditioning
    @@ -252,14 +257,20 @@ class Inpaint(Img2Img):


    def make_image(x_T):
    + #orig_unet = pipeline.unet
    + #pipeline.unet = pipe_inpaint.unet.to(pipeline.device)
    + #pipeline.unet.in_channels = 4
    +
    pipeline_output = pipeline.inpaint_from_embeddings(
    init_image=init_image,
    + pil_init_image=self.pil_image,
    mask=1 - mask, # expects white means "paint here."
    strength=strength,
    num_inference_steps=steps,
    conditioning_data=conditioning_data,
    noise_func=self.get_noise_like,
    callback=step_callback,
    + control=control,
    )

    if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
    @@ -289,6 +300,7 @@ class Inpaint(Img2Img):
    infill_method = infill_method,
    **kwargs)

    + #pipeline.unet = orig_unet
    return result

    return make_image
    diff --git a/backend/ldm/invoke/generator/txt2img.py b/backend/ldm/invoke/generator/txt2img.py
    index 38b4415c..d3f4045f 100644
    --- a/backend/ldm/invoke/generator/txt2img.py
    +++ b/backend/ldm/invoke/generator/txt2img.py
    @@ -16,7 +16,7 @@ class Txt2Img(Generator):
    @torch.no_grad()
    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
    conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
    - attention_maps_callback=None,
    + attention_maps_callback=None, control=None,
    **kwargs):
    """
    Returns a function returning an image derived from the prompt and the initial image
    @@ -28,6 +28,7 @@ class Txt2Img(Generator):
    # noinspection PyTypeChecker
    pipeline: StableDiffusionGeneratorPipeline = self.model
    pipeline.scheduler = sampler
    + pipeline.control_forward = None

    uc, c, extra_conditioning_info = conditioning
    conditioning_data = (
    diff --git a/backend/ldm/invoke/model_manager.py b/backend/ldm/invoke/model_manager.py
    index 2f68df73..9d324e4a 100644
    --- a/backend/ldm/invoke/model_manager.py
    +++ b/backend/ldm/invoke/model_manager.py
    @@ -422,7 +422,7 @@ class ModelManager(object):

    return model, width, height, model_hash

    - def _load_diffusers_model(self, mconfig):
    + def _load_diffusers_model(self, mconfig, control=True):
    name_or_path = self.model_name_or_path(mconfig)
    using_fp16 = self.precision == 'float16'

    @@ -437,6 +437,10 @@ class ModelManager(object):
    safety_checker=None,
    local_files_only=not Globals.internet_available
    )
    + #if control:
    + # controlnets = StableDiffusionGeneratorPipeline.controlnet_models()
    + # pipeline_args.update(controlnets=controlnets)
    +
    if 'vae' in mconfig and mconfig['vae'] is not None:
    vae = self._load_vae(mconfig['vae'])
    pipeline_args.update(vae=vae)
    diff --git a/backend/ldm/models/diffusion/ddpm.py b/backend/ldm/models/diffusion/ddpm.py
    index 7c7ba9f5..422f61a7 100644
    --- a/backend/ldm/models/diffusion/ddpm.py
    +++ b/backend/ldm/models/diffusion/ddpm.py
    @@ -1013,7 +1013,7 @@ class LatentDiffusion(DDPM):
    xc = x
    if not self.cond_stage_trainable or force_c_encode:
    if isinstance(xc, dict) or isinstance(xc, list):
    - # import pudb; pudb.set_trace()
    + # import pudb; pudb.set_trace() #
    c = self.get_learned_conditioning(xc)
    else:
    c = self.get_learned_conditioning(xc.to(self.device))
    --
    2.39.1
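
Taken together, a request like the following would exercise the new path end to end. This is a hedged sketch: the entry point is assumed to be `Generate.prompt2image()` (the hunk context does not show the method name), `gen` stands for a configured `Generate` instance, `init_img` follows the upstream InvokeAI signature, and every concrete value is a placeholder rather than anything shipped with the patch.

    from PIL import Image

    # gen: a configured Generate instance (construction omitted)
    init = Image.open('sketch.png')          # placeholder input image
    results = gen.prompt2image(
        'a watercolor landscape',
        init_img=init,                       # assumed upstream parameter name for the init image
        strength=1.0,                        # 1.0 => image acts only as a control hint (assert relaxed above)
        control={'canny': True},             # new argument threaded down to the pipeline
        steps=30,
        cfg_scale=7.5,
    )

Note that `controlnet_model()` resolves per-control weights from a local `control_models` directory (the `takuma104/control_sd15_*` Hub IDs are left commented out), so where the ControlNet checkpoints live is deployment-specific.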