Last active: July 16, 2025, 10:41
Revisions
xenova revised this gist on Feb 17, 2025. 1 changed file with 1 addition and 1 deletion:

@@ -82,7 +82,7 @@ def forward(self, pixel_values):
 vision_model = VisionEncoder()

-w, h = 224, 224
+w, h = 768, 768
 x = torch.randn(2, 3, h, w, requires_grad=True)

 torch.onnx.export(
     vision_model,
xenova created this gist on Feb 15, 2025:
# !pip install --upgrade onnx==1.17.0 onnxruntime==1.20.1 onnxslim==0.1.48 optimum==1.24.0 transformers==4.48.3

import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModelForCausalLM
import os
import onnxslim
from optimum.onnx.graph_transformations import merge_decoders, check_and_save_model

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft", trust_remote_code=True
)

output_dir = "converted"
os.makedirs(output_dir, exist_ok=True)


# 1. Export vision encoder
class VisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = model.vision_tower
        self.image_projection = model.image_projection
        self.image_proj_norm = model.image_proj_norm
        self.image_pos_embed = model.image_pos_embed
        self.visual_temporal_embed = model.visual_temporal_embed
        self.image_feature_source = model.image_feature_source

    def forward(self, pixel_values):
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f"invalid image shape {pixel_values.shape}")

        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = (num_tokens**0.5).to(torch.int64), (num_tokens**0.5).to(
                torch.int64
            )
            assert h * w == num_tokens, "only support square feature maps for now"
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h * w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(
                x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
            )
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(
                1, T, 1, x.shape[-1]
            )

        x_feat_dict = {}

        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x

        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict["last_frame"] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError(
                    "invalid image feature source: {}".format(_image_feature_source)
                )
            new_x.append(x_feat_dict[_image_feature_source])

        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x


vision_model = VisionEncoder()

w, h = 224, 224
x = torch.randn(2, 3, h, w, requires_grad=True)

torch.onnx.export(
    vision_model,
    x,
    f"{output_dir}/vision_encoder.onnx",
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=["pixel_values"],
    output_names=["image_features"],
    dynamic_axes={
        "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        "image_features": {0: "batch_size", 1: "sequence_length"},
    },
)
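
# Optional sanity check (not part of the original gist): validate the exported
# vision encoder graph with the ONNX checker before moving on to the text side.
import onnx

onnx.checker.check_model(f"{output_dir}/vision_encoder.onnx")
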
# 2. Export input embedding layer
x = torch.randint(0, 100, (2, 16))
torch.onnx.export(
    model.get_input_embeddings(),
    x,
    f"{output_dir}/embed_tokens.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input_ids"],
    output_names=["inputs_embeds"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
    },
)

# 3. Export language model (encoder, decoder w/o past, decoder w/ past, and merged decoder)
text_config = model.config.text_config
num_attention_heads = text_config.decoder_attention_heads
num_layers = text_config.decoder_layers
hidden_size = text_config.d_model
head_dim = hidden_size // num_attention_heads

batch_size = 2
past_decoder_sequence_length = 6
decoder_sequence_length = 13
encoder_sequence_length = 3

encoder_inputs_embeds = torch.randn((batch_size, encoder_sequence_length, hidden_size))
encoder_attention_mask = torch.ones(
    (batch_size, encoder_sequence_length), dtype=torch.int64
)
decoder_inputs_embeds = torch.randn((batch_size, decoder_sequence_length, hidden_size))
dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{module}.{key}": torch.zeros(
        batch_size,
        num_attention_heads,
        past_decoder_sequence_length,
        head_dim,
        dtype=torch.float32,
    )
    for i in range(num_layers)
    for module in ("decoder", "encoder")  # (self, cross_attn)
    for key in ["key", "value"]
}

encoder_outputs = model.language_model.model.encoder(
    inputs_embeds=encoder_inputs_embeds,
    attention_mask=encoder_attention_mask,
)


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = model.language_model.model.encoder

    def forward(self, *args):
        encoder_inputs_embeds, encoder_attention_mask = args
        encoder_outputs = self.encoder(
            inputs_embeds=encoder_inputs_embeds,
            attention_mask=encoder_attention_mask,
        )
        return encoder_outputs.last_hidden_state


encoder_model = Encoder()
torch.onnx.export(
    encoder_model,
    (encoder_inputs_embeds, encoder_attention_mask),
    f=f"{output_dir}/encoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["inputs_embeds", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "encoder_sequence_length"},
        "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "last_hidden_state": {0: "batch_size", 1: "encoder_sequence_length"},
    },
)

encoder_outputs = model.language_model.model.encoder.forward(
    inputs_embeds=encoder_inputs_embeds,
    attention_mask=encoder_attention_mask,
)

pkv_input_names = list(dummy_past_key_values_kwargs.keys())
pkv_output_names = list(
    x.replace("past_key_values", "present") for x in dummy_past_key_values_kwargs.keys()
)


class PatchedFlorence2DecoderWithoutPast(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = model.language_model

    def forward(self, *args):
        encoder_attention_mask, encoder_hidden_states, inputs_embeds = args
        decoder_outputs = self.language_model.forward(
            encoder_outputs=encoder_outputs,
            decoder_inputs_embeds=inputs_embeds,
        )
        flattened_outputs = {
            "logits": decoder_outputs.logits,
        }
        for i in range(num_layers):
            for j, v in enumerate(
                ("decoder.key", "decoder.value", "encoder.key", "encoder.value")
            ):
                flattened_outputs[f"present.{i}.{v}"] = decoder_outputs.past_key_values[
                    i
                ][j]
        return flattened_outputs


decoder_without_past = PatchedFlorence2DecoderWithoutPast()
torch.onnx.export(
    decoder_without_past,
    args=(
        encoder_attention_mask,
        encoder_outputs.last_hidden_state,
        encoder_inputs_embeds,
    ),
f=f"{output_dir}/decoder_model.onnx", export_params=True, opset_version=14, do_constant_folding=True, input_names=["encoder_attention_mask", "encoder_hidden_states", "inputs_embeds"], output_names=["logits"] + pkv_output_names, dynamic_axes={ "encoder_attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, "encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"}, "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length"}, "logits": {0: "batch_size", 1: "decoder_sequence_length"}, **{ k: { 0: "batch_size", 2: "past_decoder_sequence_length + decoder_sequence_length" if "decoder" in k else "encoder_sequence_length", } for k in pkv_output_names }, }, ) class PatchedFlorence2DecoderWithPast(nn.Module): def __init__(self): super().__init__() self.language_model = model.language_model def forward(self, *args): encoder_attention_mask, inputs_embeds, *past_key_values_args = args pkv_iter = iter(past_key_values_args) pkv_tuples = tuple( tuple(next(pkv_iter) for i in range(4)) for _ in range(num_layers) ) decoder_outputs = self.language_model.forward( # NOTE: encoder_outputs isn't defined here, but we will reuse k,v, cross attentions from pkv tuples encoder_outputs=[torch.zeros(0, past_key_values_args[0].shape[2], 0)], decoder_inputs_embeds=inputs_embeds, past_key_values=pkv_tuples, # No need to pass `decoder_attention_mask` ) flattened_outputs = { "logits": decoder_outputs.logits, } for i in range(num_layers): for j, v in enumerate( ("decoder.key", "decoder.value", "encoder.key", "encoder.value") ): if "encoder" in v: continue flattened_outputs[f"present.{i}.{v}"] = decoder_outputs.past_key_values[ i ][j] return flattened_outputs decoder_with_past = PatchedFlorence2DecoderWithPast() torch.onnx.export( decoder_with_past, args=( encoder_attention_mask, encoder_inputs_embeds, *dummy_past_key_values_kwargs.values(), ), f=f"{output_dir}/decoder_with_past_model.onnx", export_params=True, opset_version=14, do_constant_folding=True, input_names=["encoder_attention_mask", "inputs_embeds"] + pkv_input_names, output_names=["logits"] + [x for x in pkv_output_names if "decoder" in x], dynamic_axes={ "encoder_attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, "encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"}, "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length"}, **{ k: { 0: "batch_size", 2: "past_decoder_sequence_length" if "decoder" in k else "encoder_sequence_length_out", } for k in pkv_input_names }, "logits": {0: "batch_size", 1: "decoder_sequence_length"}, **{ k: { 0: "batch_size", 2: "past_decoder_sequence_length + decoder_sequence_length", } for k in pkv_output_names if "decoder" in k }, }, ) # 4. Post-processing for f in os.listdir(output_dir): p = os.path.join(output_dir, f) onnxslim.slim(p, p) merged_decoder = merge_decoders( f"{output_dir}/decoder_model.onnx", f"{output_dir}/decoder_with_past_model.onnx", strict=False, ) check_and_save_model(merged_decoder, f"{output_dir}/decoder_model_merged.onnx")