@blepping
Last active August 26, 2025 15:32

Revisions

  1. blepping revised this gist Dec 7, 2024. 1 changed file with 4 additions and 54 deletions.
    58 changes: 4 additions & 54 deletions comfyui_sageattention.py
    @@ -1,54 +1,4 @@
-# Install https://github.com/thu-ml/SageAttention - it requires the latest as of 20241110.
-# Put this file in custom_nodes/ - it will add a SageAttention node.
-# NOTES:
-# 1. Not an actual model patch, to enable or disable you must ensure the node runs.
-# 2. No point in using this for SD15 as none of the dimensions are compatible (it will just delegate).
-# 3. It will delegate to the default attention whenever the head dimensions are incompatible or k, v dimensions don't match q.
-import torch
-from comfy.ldm.modules import attention as comfy_attention
-
-from sageattention import sageattn
-
-orig_attention = comfy_attention.optimized_attention
-
-class SageAttentionNode:
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "go"
-    CATEGORY = "model_patches"
-
-    @classmethod
-    def INPUT_TYPES(cls):
-        return {
-            "required": {
-                "model": ("MODEL",),
-                "enabled": ("BOOLEAN", {"default": True},),
-                "smooth_k": ("BOOLEAN", {"default": True},),
-            }
-        }
-
-    @classmethod
-    def go(cls, *, model: object, enabled: bool, smooth_k: bool):
-        def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-            if skip_reshape:
-                b, _, _, dim_head = q.shape
-            else:
-                b, _, dim_head = q.shape
-                dim_head //= heads
-            if dim_head not in (64, 96, 128):
-                return orig_attention(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
-            if not skip_reshape:
-                q, k, v = map(
-                    lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
-                    (q, k, v),
-                )
-            return (
-                sageattn(q, k, v, is_causal=False, attn_mask=mask, dropout_p=0.0, smooth_k=smooth_k)
-                .transpose(1, 2)
-                .reshape(b, -1, heads * dim_head)
-            )
-        comfy_attention.optimized_attention = orig_attention if not enabled else attention_sage
-        return (model,)
-
-NODE_CLASS_MAPPINGS = {
-    "SageAttention": SageAttentionNode,
-}
+# NO LONGER MAINTAINED
+# There is an improved version in my ComfyUI-bleh node pack:
+# https://github.com/blepping/ComfyUI-bleh#blehsageattentionsampler
+# If you really want to use the gist, see the previous revision.
  2. blepping revised this gist Nov 11, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion comfyui_sageattention.py
    @@ -1,4 +1,4 @@
-# Install https://github.com/thu-ml/SageAttention
+# Install https://github.com/thu-ml/SageAttention - it requires the latest as of 20241110.
 # Put this file in custom_nodes/ - it will add a SageAttention node.
 # NOTES:
 # 1. Not an actual model patch, to enable or disable you must ensure the node runs.
  3. blepping revised this gist Nov 11, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion comfyui_sageattention.py
    @@ -34,7 +34,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             else:
                 b, _, dim_head = q.shape
                 dim_head //= heads
-            if dim_head not in (64, 96, 128) or not (k.shape == q.shape and v.shape == q.shape):
+            if dim_head not in (64, 96, 128):
                 return orig_attention(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
             if not skip_reshape:
                 q, k, v = map(
  4. blepping revised this gist Oct 9, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion comfyui_sageattention.py
    @@ -34,7 +34,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             else:
                 b, _, dim_head = q.shape
                 dim_head //= heads
-            if dim_head not in (64, 96, 120) or not (k.shape == q.shape and v.shape == q.shape):
+            if dim_head not in (64, 96, 128) or not (k.shape == q.shape and v.shape == q.shape):
                 return orig_attention(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
             if not skip_reshape:
                 q, k, v = map(
  5. blepping revised this gist Oct 9, 2024. 1 changed file with 14 additions and 12 deletions.
    26 changes: 14 additions & 12 deletions comfyui_sageattention.py
    @@ -1,15 +1,14 @@
 # Install https://github.com/thu-ml/SageAttention
 # Put this file in custom_nodes/ - it will add a SageAttention node.
-# NOTE: Not an actual model patch, to enable or disable you must ensure the node runs.
-# NOTE: Also it doesn't work for SD15 and NaNs for SDXL currently.
-try:
-    from sageattention import sageattn
-except ImportError:
-    from sageattention import attention as sageattn
-
-
+# NOTES:
+# 1. Not an actual model patch, to enable or disable you must ensure the node runs.
+# 2. No point in using this for SD15 as none of the dimensions are compatible (it will just delegate).
+# 3. It will delegate to the default attention whenever the head dimensions are incompatible or k, v dimensions don't match q.
+import torch
 from comfy.ldm.modules import attention as comfy_attention
 
+from sageattention import sageattn
+
 orig_attention = comfy_attention.optimized_attention
 
 class SageAttentionNode:
    @@ -35,15 +34,18 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             else:
                 b, _, dim_head = q.shape
                 dim_head //= heads
+            if dim_head not in (64, 96, 120) or not (k.shape == q.shape and v.shape == q.shape):
+                return orig_attention(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
+            if not skip_reshape:
                 q, k, v = map(
                     lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
                     (q, k, v),
                 )
-            out = sageattn(q, k, v, is_causal=False, smooth_k=smooth_k)
-            out = (
-                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+            return (
+                sageattn(q, k, v, is_causal=False, attn_mask=mask, dropout_p=0.0, smooth_k=smooth_k)
+                .transpose(1, 2)
+                .reshape(b, -1, heads * dim_head)
             )
-            return out
         comfy_attention.optimized_attention = orig_attention if not enabled else attention_sage
         return (model,)
 
  6. blepping revised this gist Oct 5, 2024. 1 changed file with 6 additions and 1 deletion.
    7 changes: 6 additions & 1 deletion comfyui_sageattention.py
    @@ -2,7 +2,12 @@
 # Put this file in custom_nodes/ - it will add a SageAttention node.
 # NOTE: Not an actual model patch, to enable or disable you must ensure the node runs.
 # NOTE: Also it doesn't work for SD15 and NaNs for SDXL currently.
-from sageattention import attention as sageattn
+try:
+    from sageattention import sageattn
+except ImportError:
+    from sageattention import attention as sageattn
+
+
 from comfy.ldm.modules import attention as comfy_attention
 
 orig_attention = comfy_attention.optimized_attention
  7. blepping created this gist Oct 5, 2024.
    47 changes: 47 additions & 0 deletions comfyui_sageattention.py
    @@ -0,0 +1,47 @@
+# Install https://github.com/thu-ml/SageAttention
+# Put this file in custom_nodes/ - it will add a SageAttention node.
+# NOTE: Not an actual model patch, to enable or disable you must ensure the node runs.
+# NOTE: Also it doesn't work for SD15 and NaNs for SDXL currently.
+from sageattention import attention as sageattn
+from comfy.ldm.modules import attention as comfy_attention
+
+orig_attention = comfy_attention.optimized_attention
+
+class SageAttentionNode:
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "go"
+    CATEGORY = "model_patches"
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "enabled": ("BOOLEAN", {"default": True},),
+                "smooth_k": ("BOOLEAN", {"default": True},),
+            }
+        }
+
+    @classmethod
+    def go(cls, *, model: object, enabled: bool, smooth_k: bool):
+        def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+            if skip_reshape:
+                b, _, _, dim_head = q.shape
+            else:
+                b, _, dim_head = q.shape
+                dim_head //= heads
+                q, k, v = map(
+                    lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+                    (q, k, v),
+                )
+            out = sageattn(q, k, v, is_causal=False, smooth_k=smooth_k)
+            out = (
+                out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+            )
+            return out
+        comfy_attention.optimized_attention = orig_attention if not enabled else attention_sage
+        return (model,)
+
+NODE_CLASS_MAPPINGS = {
+    "SageAttention": SageAttentionNode,
+}
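
For reference, the pattern the node implements across the revisions above boils down to: reshape ComfyUI's (batch, sequence, heads * dim_head) tensors into a per-head layout, call sageattn, reshape back, and fall back to the stock optimized_attention whenever the head dimension is not one SageAttention supports. Below is a minimal standalone sketch, not part of the gist, that exercises that pattern outside ComfyUI. It assumes a CUDA GPU, fp16 tensors and the SageAttention release the gist targets; the sageattn keyword arguments simply mirror the gist's own call and may differ in newer SageAttention versions, and the shapes are arbitrary example values.

# Standalone sketch (not part of the gist): reshape -> sageattn -> reshape back,
# as attention_sage does in the final revision, plus a loose sanity check against
# PyTorch's reference attention. Assumes CUDA, fp16, and the SageAttention
# release current as of the gist (kwargs copied from the gist's own call).
import torch
import torch.nn.functional as F
from sageattention import sageattn

b, seq, heads, dim_head = 2, 1024, 8, 64  # dim_head must be 64, 96 or 128 per the gist's check
q = torch.randn(b, seq, heads * dim_head, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# (b, seq, heads * dim_head) -> (b, heads, seq, dim_head), as in attention_sage
qh, kh, vh = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))

out = sageattn(qh, kh, vh, is_causal=False, attn_mask=None, dropout_p=0.0, smooth_k=True)
out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)  # back to ComfyUI's layout

# SageAttention uses quantized kernels, so expect small numerical differences
# from the fp16 reference rather than an exact match.
ref = F.scaled_dot_product_attention(qh, kh, vh)
ref = ref.transpose(1, 2).reshape(b, -1, heads * dim_head)
print(out.shape, (out - ref).abs().max().item())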