@@ -1,26 +1,29 @@
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.
import argparse
import os
import os .path as osp
import torch
#=================#
# =================#
# UNet Conversion #
#=================#
# =================#
unet_conversion_map = [
# (stable-diffusion, HF Diffusers)
(' time_embed.0.weight' , ' time_embedding.linear_1.weight' ),
(' time_embed.0.bias' , ' time_embedding.linear_1.bias' ),
(' time_embed.2.weight' , ' time_embedding.linear_2.weight' ),
(' time_embed.2.bias' , ' time_embedding.linear_2.bias' ),
(' input_blocks.0.0.weight' , ' conv_in.weight' ),
(' input_blocks.0.0.bias' , ' conv_in.bias' ),
(' out.0.weight' , ' conv_norm_out.weight' ),
(' out.0.bias' , ' conv_norm_out.bias' ),
(' out.2.weight' , ' conv_out.weight' ),
(' out.2.bias' , ' conv_out.bias' )
(" time_embed.0.weight" , " time_embedding.linear_1.weight" ),
(" time_embed.0.bias" , " time_embedding.linear_1.bias" ),
(" time_embed.2.weight" , " time_embedding.linear_2.weight" ),
(" time_embed.2.bias" , " time_embedding.linear_2.bias" ),
(" input_blocks.0.0.weight" , " conv_in.weight" ),
(" input_blocks.0.0.bias" , " conv_in.bias" ),
(" out.0.weight" , " conv_norm_out.weight" ),
(" out.0.bias" , " conv_norm_out.bias" ),
(" out.2.weight" , " conv_out.weight" ),
(" out.2.bias" , " conv_out.bias" ),
]
unet_conversion_map_resnet = [
@@ -30,7 +33,7 @@
("out_layers.0" , "norm2" ),
("out_layers.3" , "conv2" ),
("emb_layers.1" , "time_emb_proj" ),
("skip_connection" , "conv_shortcut" )
("skip_connection" , "conv_shortcut" ),
]
unet_conversion_map_layer = []
@@ -41,150 +44,156 @@
for j in range (2 ):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f' down_blocks.{ i } .resnets.{ j } .'
sd_down_res_prefix = f' input_blocks.{ 3 * i + j + 1 } .0.'
hf_down_res_prefix = f" down_blocks.{ i } .resnets.{ j } ."
sd_down_res_prefix = f" input_blocks.{ 3 * i + j + 1 } .0."
unet_conversion_map_layer .append ((sd_down_res_prefix , hf_down_res_prefix ))
if i < 3 :
# no attention layers in down_blocks.3
hf_down_atn_prefix = f' down_blocks.{ i } .attentions.{ j } .'
sd_down_atn_prefix = f' input_blocks.{ 3 * i + j + 1 } .1.'
hf_down_atn_prefix = f" down_blocks.{ i } .attentions.{ j } ."
sd_down_atn_prefix = f" input_blocks.{ 3 * i + j + 1 } .1."
unet_conversion_map_layer .append ((sd_down_atn_prefix , hf_down_atn_prefix ))
for j in range (3 ):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f' up_blocks.{ i } .resnets.{ j } .'
sd_up_res_prefix = f' output_blocks.{ 3 * i + j } .0.'
hf_up_res_prefix = f" up_blocks.{ i } .resnets.{ j } ."
sd_up_res_prefix = f" output_blocks.{ 3 * i + j } .0."
unet_conversion_map_layer .append ((sd_up_res_prefix , hf_up_res_prefix ))
if i > 0 :
# no attention layers in up_blocks.0
hf_up_atn_prefix = f' up_blocks.{ i } .attentions.{ j } .'
sd_up_atn_prefix = f' output_blocks.{ 3 * i + j } .1.'
hf_up_atn_prefix = f" up_blocks.{ i } .attentions.{ j } ."
sd_up_atn_prefix = f" output_blocks.{ 3 * i + j } .1."
unet_conversion_map_layer .append ((sd_up_atn_prefix , hf_up_atn_prefix ))
if i < 3 :
# no downsample in down_blocks.3
hf_downsample_prefix = f' down_blocks.{ i } .downsamplers.0.conv.'
sd_downsample_prefix = f' input_blocks.{ 3 * (i + 1 )} .0.op.'
hf_downsample_prefix = f" down_blocks.{ i } .downsamplers.0.conv."
sd_downsample_prefix = f" input_blocks.{ 3 * (i + 1 )} .0.op."
unet_conversion_map_layer .append ((sd_downsample_prefix , hf_downsample_prefix ))
# no upsample in up_blocks.3
hf_upsample_prefix = f' up_blocks.{ i } .upsamplers.0.'
sd_upsample_prefix = f' output_blocks.{ 3 * i + 2 } .{ 1 if i == 0 else 2 } .'
hf_upsample_prefix = f" up_blocks.{ i } .upsamplers.0."
sd_upsample_prefix = f" output_blocks.{ 3 * i + 2 } .{ 1 if i == 0 else 2 } ."
unet_conversion_map_layer .append ((sd_upsample_prefix , hf_upsample_prefix ))
hf_mid_atn_prefix = ' mid_block.attentions.0.'
sd_mid_atn_prefix = ' middle_block.1.'
hf_mid_atn_prefix = " mid_block.attentions.0."
sd_mid_atn_prefix = " middle_block.1."
unet_conversion_map_layer .append ((sd_mid_atn_prefix , hf_mid_atn_prefix ))
for j in range (2 ):
hf_mid_res_prefix = f' mid_block.resnets.{ j } .'
sd_mid_res_prefix = f' middle_block.{ 2 * j } .'
hf_mid_res_prefix = f" mid_block.resnets.{ j } ."
sd_mid_res_prefix = f" middle_block.{ 2 * j } ."
unet_conversion_map_layer .append ((sd_mid_res_prefix , hf_mid_res_prefix ))
def convert_unet_state_dict (unet_state_dict ):
# buyer beware: this is a *brittle* function,
# and correct output requires that all of these pieces interact in
# the exact order in which I have arranged them.
mapping = {k :k for k in unet_state_dict .keys ()}
mapping = {k : k for k in unet_state_dict .keys ()}
for sd_name , hf_name in unet_conversion_map :
mapping [hf_name ] = sd_name
for k ,v in mapping .items ():
if ' resnets' in k :
for k , v in mapping .items ():
if " resnets" in k :
for sd_part , hf_part in unet_conversion_map_resnet :
v = v .replace (hf_part , sd_part )
mapping [k ] = v
for k ,v in mapping .items ():
for k , v in mapping .items ():
for sd_part , hf_part in unet_conversion_map_layer :
v = v .replace (hf_part , sd_part )
mapping [k ] = v
new_state_dict = {v :unet_state_dict [k ] for k ,v in mapping .items ()}
new_state_dict = {v : unet_state_dict [k ] for k , v in mapping .items ()}
return new_state_dict
#================#
# ================#
# VAE Conversion #
#================#
# ================#
vae_conversion_map = [
# (stable-diffusion, HF Diffusers)
(' nin_shortcut' , ' conv_shortcut' ),
(' norm_out' , ' conv_norm_out' ),
(' mid.attn_1.' , ' mid_block.attentions.0.' )
(" nin_shortcut" , " conv_shortcut" ),
(" norm_out" , " conv_norm_out" ),
(" mid.attn_1." , " mid_block.attentions.0." ),
]
for i in range (4 ):
# down_blocks have two resnets
for j in range (2 ):
hf_down_prefix = f' encoder.down_blocks.{ i } .resnets.{ j } .'
sd_down_prefix = f' encoder.down.{ i } .block.{ j } .'
hf_down_prefix = f" encoder.down_blocks.{ i } .resnets.{ j } ."
sd_down_prefix = f" encoder.down.{ i } .block.{ j } ."
vae_conversion_map .append ((sd_down_prefix , hf_down_prefix ))
if i < 3 :
hf_downsample_prefix = f' down_blocks.{ i } .downsamplers.0.'
sd_downsample_prefix = f' down.{ i } .downsample.'
hf_downsample_prefix = f" down_blocks.{ i } .downsamplers.0."
sd_downsample_prefix = f" down.{ i } .downsample."
vae_conversion_map .append ((sd_downsample_prefix , hf_downsample_prefix ))
hf_upsample_prefix = f' up_blocks.{ i } .upsamplers.0.'
sd_upsample_prefix = f' up.{ 3 - i } .upsample.'
hf_upsample_prefix = f" up_blocks.{ i } .upsamplers.0."
sd_upsample_prefix = f" up.{ 3 - i } .upsample."
vae_conversion_map .append ((sd_upsample_prefix , hf_upsample_prefix ))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range (3 ):
hf_up_prefix = f' decoder.up_blocks.{ i } .resnets.{ j } .'
sd_up_prefix = f' decoder.up.{ 3 - i } .block.{ j } .'
hf_up_prefix = f" decoder.up_blocks.{ i } .resnets.{ j } ."
sd_up_prefix = f" decoder.up.{ 3 - i } .block.{ j } ."
vae_conversion_map .append ((sd_up_prefix , hf_up_prefix ))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range (2 ):
hf_mid_res_prefix = f' mid_block.resnets.{ i } .'
sd_mid_res_prefix = f' mid.block_{ i + 1 } .'
hf_mid_res_prefix = f" mid_block.resnets.{ i } ."
sd_mid_res_prefix = f" mid.block_{ i + 1 } ."
vae_conversion_map .append ((sd_mid_res_prefix , hf_mid_res_prefix ))
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
(' norm.' , ' group_norm.' ),
('q.' , ' query.' ),
('k.' , ' key.' ),
('v.' , ' value.' ),
(' proj_out.' , ' proj_attn.' )
(" norm." , " group_norm." ),
("q." , " query." ),
("k." , " key." ),
("v." , " value." ),
(" proj_out." , " proj_attn." ),
]
def reshape_weight_for_sd (w ):
# convert HF linear weights to SD conv2d weights
return w .reshape (* w .shape , 1 , 1 )
def convert_vae_state_dict (vae_state_dict ):
mapping = {k :k for k in vae_state_dict .keys ()}
for k ,v in mapping .items ():
mapping = {k : k for k in vae_state_dict .keys ()}
for k , v in mapping .items ():
for sd_part , hf_part in vae_conversion_map :
v = v .replace (hf_part , sd_part )
mapping [k ] = v
for k ,v in mapping .items ():
if ' attentions' in k :
for k , v in mapping .items ():
if " attentions" in k :
for sd_part , hf_part in vae_conversion_map_attn :
v = v .replace (hf_part , sd_part )
mapping [k ] = v
new_state_dict = {v :vae_state_dict [k ] for k ,v in mapping .items ()}
weights_to_convert = ['q' , 'k' , 'v' , ' proj_out' ]
for k ,v in new_state_dict .items ():
new_state_dict = {v : vae_state_dict [k ] for k , v in mapping .items ()}
weights_to_convert = ["q" , "k" , "v" , " proj_out" ]
for k , v in new_state_dict .items ():
for weight_name in weights_to_convert :
if f' mid.attn_1.{ weight_name } .weight' in k :
print (f' Reshaping { k } for SD format' )
if f" mid.attn_1.{ weight_name } .weight" in k :
print (f" Reshaping { k } for SD format" )
new_state_dict [k ] = reshape_weight_for_sd (v )
return new_state_dict
#=========================#
# =========================#
# Text Encoder Conversion #
#=========================#
# =========================#
# pretty much a no-op
def convert_text_enc_state_dict (text_enc_dict ):
return text_enc_dict
if __name__ == "__main__" :
parser = argparse .ArgumentParser ()
@@ -193,32 +202,30 @@ def convert_text_enc_state_dict(text_enc_dict):
args = parser .parse_args ()
assert args .model_path is not None , \
"Must provide a model path!"
assert args .model_path is not None , "Must provide a model path!"
assert args .checkpoint_path is not None , \
"Must provide a checkpoint path!"
assert args .checkpoint_path is not None , "Must provide a checkpoint path!"
unet_path = osp .join (args .model_path , ' unet' , ' diffusion_pytorch_model.bin' )
vae_path = osp .join (args .model_path , ' vae' , ' diffusion_pytorch_model.bin' )
text_enc_path = osp .join (args .model_path , ' text_encoder' , ' pytorch_model.bin' )
unet_path = osp .join (args .model_path , " unet" , " diffusion_pytorch_model.bin" )
vae_path = osp .join (args .model_path , " vae" , " diffusion_pytorch_model.bin" )
text_enc_path = osp .join (args .model_path , " text_encoder" , " pytorch_model.bin" )
# Convert the UNet model
unet_state_dict = torch .load (unet_path )
unet_state_dict = torch .load (unet_path , map_location = 'cpu' )
unet_state_dict = convert_unet_state_dict (unet_state_dict )
unet_state_dict = {"model.diffusion_model." + k : v for k ,v in unet_state_dict .items ()}
unet_state_dict = {"model.diffusion_model." + k : v for k , v in unet_state_dict .items ()}
# Convert the VAE model
vae_state_dict = torch .load (vae_path )
vae_state_dict = torch .load (vae_path , map_location = 'cpu' )
vae_state_dict = convert_vae_state_dict (vae_state_dict )
vae_state_dict = {"first_stage_model." + k : v for k ,v in vae_state_dict .items ()}
vae_state_dict = {"first_stage_model." + k : v for k , v in vae_state_dict .items ()}
# Convert the text encoder model
text_enc_dict = torch .load (text_enc_path )
text_enc_dict = torch .load (text_enc_path , map_location = 'cpu' )
text_enc_dict = convert_text_enc_state_dict (text_enc_dict )
text_enc_dict = {"cond_stage_model.transformer." + k : v for k ,v in text_enc_dict .items ()}
text_enc_dict = {"cond_stage_model.transformer." + k : v for k , v in text_enc_dict .items ()}
# Put together new checkpoint
state_dict = {** unet_state_dict , ** vae_state_dict , ** text_enc_dict }
state_dict = {"state_dict" : state_dict }
torch .save (state_dict , args .checkpoint_path )
torch .save (state_dict , args .checkpoint_path )