@xenova
Last active July 16, 2025 10:41

Revisions

  1. xenova revised this gist Feb 17, 2025. 1 changed file with 1 addition and 1 deletion.

    2 changes: 1 addition & 1 deletion export-florence-to-onnx.py

@@ -82,7 +82,7 @@ def forward(self, pixel_values):
 
 
 vision_model = VisionEncoder()
-w, h = 224, 224
+w, h = 768, 768
 x = torch.randn(2, 3, h, w, requires_grad=True)
 torch.onnx.export(
     vision_model,

  2. xenova created this gist Feb 15, 2025.
    343 changes: 343 additions & 0 deletions export-florence-to-onnx.py
    @@ -0,0 +1,343 @@
# !pip install --upgrade onnx==1.17.0 onnxruntime==1.20.1 onnxslim==0.1.48 optimum==1.24.0 transformers==4.48.3

import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModelForCausalLM
import os
import onnxslim
from optimum.onnx.graph_transformations import merge_decoders, check_and_save_model

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft", trust_remote_code=True
)

output_dir = "converted"
os.makedirs(output_dir, exist_ok=True)


# 1. Export vision encoder
class VisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = model.vision_tower
        self.image_projection = model.image_projection
        self.image_proj_norm = model.image_proj_norm
        self.image_pos_embed = model.image_pos_embed
        self.visual_temporal_embed = model.visual_temporal_embed
        self.image_feature_source = model.image_feature_source

    def forward(self, pixel_values):
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f"invalid image shape {pixel_values.shape}")

        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = (num_tokens**0.5).to(torch.int64), (num_tokens**0.5).to(
                torch.int64
            )
            assert h * w == num_tokens, "only support square feature maps for now"
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h * w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(
                x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
            )
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(
                1, T, 1, x.shape[-1]
            )

        x_feat_dict = {}

        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x

        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict["last_frame"] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError(
                    "invalid image feature source: {}".format(_image_feature_source)
                )
            new_x.append(x_feat_dict[_image_feature_source])

        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x


vision_model = VisionEncoder()
w, h = 224, 224
x = torch.randn(2, 3, h, w, requires_grad=True)
torch.onnx.export(
    vision_model,
    x,
    f"{output_dir}/vision_encoder.onnx",
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=["pixel_values"],
    output_names=["image_features"],
    dynamic_axes={
        "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        "image_features": {0: "batch_size", 1: "sequence_length"},
    },
)
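

# Optional sanity check (a sketch, not required for the export; it assumes
# onnxruntime and numpy are available, which the pip line above provides):
# run the exported graph on the same dummy input and compare it against the
# eager PyTorch module. Tolerances may need loosening on some backends.
import numpy as np
import onnxruntime as ort

ort_session = ort.InferenceSession(f"{output_dir}/vision_encoder.onnx")
with torch.no_grad():
    torch_image_features = vision_model(x)
(onnx_image_features,) = ort_session.run(None, {"pixel_values": x.detach().numpy()})
np.testing.assert_allclose(
    torch_image_features.detach().numpy(), onnx_image_features, rtol=1e-3, atol=1e-3
)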


# 2. Export input embedding layer
x = torch.randint(0, 100, (2, 16))
torch.onnx.export(
    model.get_input_embeddings(),
    x,
    f"{output_dir}/embed_tokens.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input_ids"],
    output_names=["inputs_embeds"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
    },
)


# 3. Export language model (encoder, decoder w/o past, decoder w/ past, and merged decoder)
text_config = model.config.text_config
num_attention_heads = text_config.decoder_attention_heads
num_layers = text_config.decoder_layers
hidden_size = text_config.d_model
head_dim = hidden_size // num_attention_heads

batch_size = 2
past_decoder_sequence_length = 6
decoder_sequence_length = 13
encoder_sequence_length = 3

encoder_inputs_embeds = torch.randn((batch_size, encoder_sequence_length, hidden_size))
encoder_attention_mask = torch.ones(
    (batch_size, encoder_sequence_length), dtype=torch.int64
)
decoder_inputs_embeds = torch.randn((batch_size, decoder_sequence_length, hidden_size))
dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{module}.{key}": torch.zeros(
        batch_size,
        num_attention_heads,
        past_decoder_sequence_length,
        head_dim,
        dtype=torch.float32,
    )
    for i in range(num_layers)
    for module in ("decoder", "encoder")  # (self, cross_attn)
    for key in ["key", "value"]
}
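# The flattened `past_key_values.{layer}.{decoder|encoder}.{key|value}` names built
# above follow the layout Optimum uses for seq2seq ONNX exports: "decoder" entries
# are the growing self-attention caches, "encoder" entries the fixed
# cross-attention caches derived from the encoder output.
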
encoder_outputs = model.language_model.model.encoder(
    inputs_embeds=encoder_inputs_embeds,
    attention_mask=encoder_attention_mask,
)


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = model.language_model.model.encoder

    def forward(self, *args):
        encoder_inputs_embeds, encoder_attention_mask = args
        encoder_outputs = self.encoder(
            inputs_embeds=encoder_inputs_embeds,
            attention_mask=encoder_attention_mask,
        )
        return encoder_outputs.last_hidden_state


encoder_model = Encoder()
torch.onnx.export(
    encoder_model,
    (encoder_inputs_embeds, encoder_attention_mask),
    f=f"{output_dir}/encoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["inputs_embeds", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "encoder_sequence_length"},
        "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "last_hidden_state": {0: "batch_size", 1: "encoder_sequence_length"},
    },
)
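
# Optional sanity check (same pattern as for the vision encoder above, reusing
# the onnxruntime/numpy imports from that sketch): compare the exported encoder
# against the eager module on the dummy embeddings and mask.
enc_session = ort.InferenceSession(f"{output_dir}/encoder_model.onnx")
(onnx_last_hidden_state,) = enc_session.run(
    None,
    {
        "inputs_embeds": encoder_inputs_embeds.numpy(),
        "attention_mask": encoder_attention_mask.numpy(),
    },
)
np.testing.assert_allclose(
    encoder_outputs.last_hidden_state.detach().numpy(),
    onnx_last_hidden_state,
    rtol=1e-3,
    atol=1e-3,
)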

encoder_outputs = model.language_model.model.encoder.forward(
    inputs_embeds=encoder_inputs_embeds,
    attention_mask=encoder_attention_mask,
)

pkv_input_names = list(dummy_past_key_values_kwargs.keys())
pkv_output_names = list(
    x.replace("past_key_values", "present") for x in dummy_past_key_values_kwargs.keys()
)


# The no-past decoder handles the first generation step: it consumes the encoder
# output and emits logits plus the initial self-attention ("decoder") and
# cross-attention ("encoder") key/value caches.
class PatchedFlorence2DecoderWithoutPast(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = model.language_model

    def forward(self, *args):
        encoder_attention_mask, encoder_hidden_states, inputs_embeds = args

        decoder_outputs = self.language_model.forward(
            encoder_outputs=encoder_outputs,
            decoder_inputs_embeds=inputs_embeds,
        )

        flattened_outputs = {
            "logits": decoder_outputs.logits,
        }
        for i in range(num_layers):
            for j, v in enumerate(
                ("decoder.key", "decoder.value", "encoder.key", "encoder.value")
            ):
                flattened_outputs[f"present.{i}.{v}"] = decoder_outputs.past_key_values[
                    i
                ][j]

        return flattened_outputs


decoder_without_past = PatchedFlorence2DecoderWithoutPast()
torch.onnx.export(
    decoder_without_past,
    args=(
        encoder_attention_mask,
        encoder_outputs.last_hidden_state,
        encoder_inputs_embeds,
    ),
    f=f"{output_dir}/decoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["encoder_attention_mask", "encoder_hidden_states", "inputs_embeds"],
    output_names=["logits"] + pkv_output_names,
    dynamic_axes={
        "encoder_attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length"},
        "logits": {0: "batch_size", 1: "decoder_sequence_length"},
        **{
            k: {
                0: "batch_size",
                2: "past_decoder_sequence_length + decoder_sequence_length"
                if "decoder" in k
                else "encoder_sequence_length",
            }
            for k in pkv_output_names
        },
    },
)


# The with-past decoder handles subsequent generation steps: it consumes the
# cached key/values and only re-emits the self-attention ("decoder") caches,
# since the cross-attention ("encoder") caches never change after the first step.
class PatchedFlorence2DecoderWithPast(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = model.language_model

    def forward(self, *args):
        encoder_attention_mask, inputs_embeds, *past_key_values_args = args

        pkv_iter = iter(past_key_values_args)
        pkv_tuples = tuple(
            tuple(next(pkv_iter) for i in range(4)) for _ in range(num_layers)
        )

        decoder_outputs = self.language_model.forward(
            # NOTE: encoder_outputs isn't defined here, but we will reuse the k/v cross-attentions from the pkv tuples
            encoder_outputs=[torch.zeros(0, past_key_values_args[0].shape[2], 0)],
            decoder_inputs_embeds=inputs_embeds,
            past_key_values=pkv_tuples,
            # No need to pass `decoder_attention_mask`
        )

        flattened_outputs = {
            "logits": decoder_outputs.logits,
        }
        for i in range(num_layers):
            for j, v in enumerate(
                ("decoder.key", "decoder.value", "encoder.key", "encoder.value")
            ):
                if "encoder" in v:
                    continue
                flattened_outputs[f"present.{i}.{v}"] = decoder_outputs.past_key_values[
                    i
                ][j]

        return flattened_outputs


decoder_with_past = PatchedFlorence2DecoderWithPast()
torch.onnx.export(
    decoder_with_past,
    args=(
        encoder_attention_mask,
        encoder_inputs_embeds,
        *dummy_past_key_values_kwargs.values(),
    ),
    f=f"{output_dir}/decoder_with_past_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["encoder_attention_mask", "inputs_embeds"] + pkv_input_names,
    output_names=["logits"] + [x for x in pkv_output_names if "decoder" in x],
    dynamic_axes={
        "encoder_attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length"},
        **{
            k: {
                0: "batch_size",
                2: "past_decoder_sequence_length"
                if "decoder" in k
                else "encoder_sequence_length_out",
            }
            for k in pkv_input_names
        },
        "logits": {0: "batch_size", 1: "decoder_sequence_length"},
        **{
            k: {
                0: "batch_size",
                2: "past_decoder_sequence_length + decoder_sequence_length",
            }
            for k in pkv_output_names
            if "decoder" in k
        },
    },
)

# 4. Post-processing: strip redundant nodes with onnxslim, then merge the two
# decoder graphs into a single model that shares weights and switches between
# the no-past and with-past branches at runtime.
for f in os.listdir(output_dir):
    p = os.path.join(output_dir, f)
    onnxslim.slim(p, p)

merged_decoder = merge_decoders(
    f"{output_dir}/decoder_model.onnx",
    f"{output_dir}/decoder_with_past_model.onnx",
    strict=False,
)
check_and_save_model(merged_decoder, f"{output_dir}/decoder_model_merged.onnx")
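

# Optional follow-up (a sketch, again using onnxruntime): load the exported
# graphs and print their input/output signatures, e.g. to check that the merged
# decoder exposes the boolean branch-selection input added by `merge_decoders`.
import onnxruntime as ort

for name in (
    "vision_encoder",
    "embed_tokens",
    "encoder_model",
    "decoder_model_merged",
):
    session = ort.InferenceSession(f"{output_dir}/{name}.onnx")
    print(name)
    print("  inputs: ", [i.name for i in session.get_inputs()])
    print("  outputs:", [o.name for o in session.get_outputs()])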