Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save AmesianX/133831676e24c8812e52da492cd4d074 to your computer and use it in GitHub Desktop.
Save AmesianX/133831676e24c8812e52da492cd4d074 to your computer and use it in GitHub Desktop.

Revisions

  1. @kohya-ss kohya-ss created this gist Nov 14, 2023.
    68 changes: 68 additions & 0 deletions forward_of_sdxl_original_unet.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,68 @@
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
    # broadcast timesteps to batch dimension
    timesteps = timesteps.expand(x.shape[0])

    hs = []
    t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
    t_emb = t_emb.to(x.dtype)
    emb = self.time_embed(t_emb)

    assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
    assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
    # assert x.dtype == self.dtype
    emb = emb + self.label_emb(y)

    def call_module(module, h, emb, context):
    x = h
    for layer in module:
    # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
    if isinstance(layer, ResnetBlock2D):
    x = layer(x, emb)
    elif isinstance(layer, Transformer2DModel):
    x = layer(x, context)
    else:
    x = layer(x)
    return x

    # h = x.type(self.dtype)
    h = x

    # downsample depth
    # 深くすると全体の構図は安定するが、キャラがゆがむ。浅くしすぎると細部が混沌とする
    ds_depth_1 = 3 # 2~4 くらいがよさそう
    ds_depth_2 = 3 # depth_1より+0~+2くらいがよさそう
    # downsample timestep
    # 大きくすると構図が乱れて、小さくするとディテールが乱れる
    ds_timestep_1 = 900
    ds_timestep_2 = 650 # timestep_1より小さいこと。デッサンに影響する

    depth = 0 # current depth
    for module in self.input_blocks:
    h = call_module(module, h, emb, context)
    hs.append(h)

    # print(depth, h.shape, timesteps)
    if (depth == ds_depth_1 and timesteps[0] > ds_timestep_1) or (
    depth == ds_depth_2 and ds_timestep_1 > timesteps[0] and timesteps[0] > ds_timestep_2
    ):
    # bicubicでないとゆがむ、align_cornersはあまり影響しない模様
    h = F.interpolate(h.float(), scale_factor=0.5, mode="bicubic", align_corners=False).to(h.dtype) # bfloat16対応
    depth += 1

    h = call_module(self.middle_block, h, emb, context)

    for module in self.output_blocks:
    depth -= 1
    # print(depth, h.shape)
    if (depth == ds_depth_1 and timesteps[0] > ds_timestep_1) or (
    depth == ds_depth_2 and ds_timestep_1 > timesteps[0] and timesteps[0] > ds_timestep_2
    ):
    h = F.interpolate(h.float(), scale_factor=2.0, mode="bicubic", align_corners=False).to(h.dtype) # bfloat16対応

    h = torch.cat([h, hs.pop()], dim=1)
    h = call_module(module, h, emb, context)

    h = h.type(x.dtype)
    h = call_module(self.out, h, emb, context)

    return h