Forked from kohya-ss/forward_of_sdxl_original_unet.py
Created
February 21, 2024 13:02
-
-
Save AmesianX/133831676e24c8812e52da492cd4d074 to your computer and use it in GitHub Desktop.
Revisions
-
kohya-ss created this gist
Nov 14, 2023. There are no files selected for viewing.
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
    """Run the UNet forward pass, temporarily shrinking features at high timesteps.

    While walking ``self.input_blocks``, the hidden state is bicubically
    downsampled by half right after the block at one chosen depth, but only
    when the first timestep lies in a configured range. While walking
    ``self.output_blocks`` the hidden state is resized back at the same depth
    so it can be concatenated with the stored skip connection.

    Args:
        x: Input tensor; ``x.shape[0]`` is used as the batch size.
            (Presumably a 4-D latent, since bicubic interpolation is applied
            to tensors derived from it -- confirm against the caller.)
        timesteps: Tensor that can be expanded to the batch size. Only
            ``timesteps[0]`` decides whether the shrink is applied, so the
            whole batch is treated alike.
        context: Passed as the second argument to every
            ``Transformer2DModel`` layer.
        y: Tensor fed to ``self.label_emb``; must share batch size and dtype
            with ``x``.
        **kwargs: Accepted and ignored.

    Returns:
        The result of applying ``self.out`` to the final hidden state.

    Raises:
        AssertionError: If ``x`` and ``y`` differ in batch size or dtype.
    """
    # Broadcast timesteps to the batch dimension.
    timesteps = timesteps.expand(x.shape[0])

    t_emb = get_timestep_embedding(timesteps, self.model_channels)
    t_emb = t_emb.to(x.dtype)
    emb = self.time_embed(t_emb)

    assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
    assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
    emb = emb + self.label_emb(y)

    def call_module(module, h, emb, context):
        """Apply the layers of ``module`` in order, giving each the extra input it takes."""
        for layer in module:
            if isinstance(layer, ResnetBlock2D):
                h = layer(h, emb)
            elif isinstance(layer, Transformer2DModel):
                h = layer(h, context)
            else:
                h = layer(h)
        return h

    # Downsample depth (author's tuning notes):
    # going deeper stabilizes the overall composition but distorts characters;
    # too shallow makes fine details chaotic.
    ds_depth_1 = 3  # around 2-4 seems good
    ds_depth_2 = 3  # about +0 to +2 relative to depth_1 seems good

    # Downsample timestep thresholds (author's tuning notes):
    # larger values disturb the composition, smaller values disturb details.
    ds_timestep_1 = 900
    ds_timestep_2 = 650  # must be smaller than timestep_1; affects the drawing

    # The gate depends only on the first timestep, so resolve it once instead
    # of re-evaluating the same comparison for every block in both loops.
    # The first range is inclusive so that a timestep exactly equal to
    # ds_timestep_1 no longer falls between the two ranges and skips the shrink.
    first_timestep = timesteps[0]
    if first_timestep >= ds_timestep_1:
        shrink_depth = ds_depth_1
    elif first_timestep > ds_timestep_2:
        shrink_depth = ds_depth_2
    else:
        shrink_depth = None

    h = x
    hs = []
    for depth, module in enumerate(self.input_blocks):
        h = call_module(module, h, emb, context)
        hs.append(h)

        if depth == shrink_depth:
            # Per the author: anything other than bicubic distorts the result,
            # and align_corners seems to have little effect.
            # Interpolate in float32 and cast back (bfloat16 support).
            h = F.interpolate(
                h.float(), scale_factor=0.5, mode="bicubic", align_corners=False
            ).to(h.dtype)

    h = call_module(self.middle_block, h, emb, context)

    # Walk back up; ``depth`` is the index of the skip tensor about to be popped.
    depth = len(hs)
    for module in self.output_blocks:
        depth -= 1

        if depth == shrink_depth:
            # Resize to the exact spatial size of the skip connection rather
            # than by a fixed factor of 2: halving an odd size rounds down, so
            # doubling would not restore it and torch.cat below would fail.
            # Interpolate in float32 and cast back (bfloat16 support).
            h = F.interpolate(
                h.float(), size=hs[-1].shape[-2:], mode="bicubic", align_corners=False
            ).to(h.dtype)

        h = torch.cat([h, hs.pop()], dim=1)
        h = call_module(module, h, emb, context)

    h = h.type(x.dtype)
    h = call_module(self.out, h, emb, context)

    return h