# !pip install --upgrade onnx==1.17.0 onnxruntime==1.20.1 onnxslim==0.1.48 optimum==1.24.0 transformers==4.48.3
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModelForCausalLM
import os
import onnxslim
from optimum.onnx.graph_transformations import merge_decoders, check_and_save_model

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft", trust_remote_code=True
)

output_dir = "converted"
os.makedirs(output_dir, exist_ok=True)


# 1. Export vision encoder
class VisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = model.vision_tower
        self.image_projection = model.image_projection
        self.image_proj_norm = model.image_proj_norm
        self.image_pos_embed = model.image_pos_embed
        self.visual_temporal_embed = model.visual_temporal_embed
        self.image_feature_source = model.image_feature_source

    def forward(self, pixel_values):
        if len(pixel_values.shape) == 4:
            batch_size, C, H, W = pixel_values.shape
            T = 1
            x = self.vision_tower.forward_features_unpool(pixel_values)
        else:
            raise ValueError(f"invalid image shape {pixel_values.shape}")

        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = (num_tokens**0.5).to(torch.int64), (num_tokens**0.5).to(
                torch.int64
            )
            assert h * w == num_tokens, "only support square feature maps for now"
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h * w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(
                x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]
            )
            x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(
                1, T, 1, x.shape[-1]
            )

        x_feat_dict = {}

        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x

        temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
        x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict["last_frame"] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError(
                    "invalid image feature source: {}".format(_image_feature_source)
                )
            new_x.append(x_feat_dict[_image_feature_source])
        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)
        return x


vision_model = VisionEncoder()

w, h = 768, 768
x = torch.randn(2, 3, h, w, requires_grad=True)

torch.onnx.export(
    vision_model,
    x,
    f"{output_dir}/vision_encoder.onnx",
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=["pixel_values"],
    output_names=["image_features"],
    dynamic_axes={
        "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        "image_features": {0: "batch_size", 1: "sequence_length"},
    },
)


# 2. Export input embedding layer
x = torch.randint(0, 100, (2, 16))

torch.onnx.export(
    model.get_input_embeddings(),
    x,
    f"{output_dir}/embed_tokens.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input_ids"],
    output_names=["inputs_embeds"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
    },
)
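# (Optional sanity check, not part of the export flow itself.) A minimal sketch,
# assuming onnxruntime from the pinned install above: load the freshly exported
# vision encoder, run it on a random 768x768 batch, and confirm the output shape.
import numpy as np
import onnxruntime as ort

vision_session = ort.InferenceSession(
    f"{output_dir}/vision_encoder.onnx", providers=["CPUExecutionProvider"]
)
dummy_pixels = np.random.randn(1, 3, 768, 768).astype(np.float32)
(image_features,) = vision_session.run(None, {"pixel_values": dummy_pixels})
print("vision_encoder.onnx image_features shape:", image_features.shape)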
# 3. Export language model (encoder, decoder w/o past, decoder w/ past, and merged decoder)
text_config = model.config.text_config
num_attention_heads = text_config.decoder_attention_heads
num_layers = text_config.decoder_layers
hidden_size = text_config.d_model
head_dim = hidden_size // num_attention_heads

batch_size = 2
past_decoder_sequence_length = 6
decoder_sequence_length = 13
encoder_sequence_length = 3

encoder_inputs_embeds = torch.randn((batch_size, encoder_sequence_length, hidden_size))
encoder_attention_mask = torch.ones(
    (batch_size, encoder_sequence_length), dtype=torch.int64
)
decoder_inputs_embeds = torch.randn((batch_size, decoder_sequence_length, hidden_size))

dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{module}.{key}": torch.zeros(
        batch_size,
        num_attention_heads,
        past_decoder_sequence_length,
        head_dim,
        dtype=torch.float32,
    )
    for i in range(num_layers)
    for module in ("decoder", "encoder")  # (self, cross_attn)
    for key in ["key", "value"]
}


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = model.language_model.model.encoder

    def forward(self, *args):
        encoder_inputs_embeds, encoder_attention_mask = args
        encoder_outputs = self.encoder(
            inputs_embeds=encoder_inputs_embeds,
            attention_mask=encoder_attention_mask,
        )
        return encoder_outputs.last_hidden_state


encoder_model = Encoder()

torch.onnx.export(
    encoder_model,
    (encoder_inputs_embeds, encoder_attention_mask),
    f=f"{output_dir}/encoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["inputs_embeds", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "encoder_sequence_length"},
        "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "last_hidden_state": {0: "batch_size", 1: "encoder_sequence_length"},
    },
)

# Example encoder output, used only as the dummy `encoder_hidden_states` when tracing the decoders
encoder_outputs = model.language_model.model.encoder.forward(
    inputs_embeds=encoder_inputs_embeds,
    attention_mask=encoder_attention_mask,
)

pkv_input_names = list(dummy_past_key_values_kwargs.keys())
pkv_output_names = [
    x.replace("past_key_values", "present") for x in dummy_past_key_values_kwargs.keys()
]


class PatchedFlorence2DecoderWithoutPast(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = model.language_model

    def forward(self, *args):
        encoder_attention_mask, encoder_hidden_states, inputs_embeds = args
        decoder_outputs = self.language_model.forward(
            # Build encoder_outputs from the traced inputs so the exported graph
            # actually consumes `encoder_hidden_states` and `encoder_attention_mask`
            attention_mask=encoder_attention_mask,
            encoder_outputs=[encoder_hidden_states],
            decoder_inputs_embeds=inputs_embeds,
        )
        flattened_outputs = {
            "logits": decoder_outputs.logits,
        }
        for i in range(num_layers):
            for j, v in enumerate(
                ("decoder.key", "decoder.value", "encoder.key", "encoder.value")
            ):
                flattened_outputs[f"present.{i}.{v}"] = decoder_outputs.past_key_values[
                    i
                ][j]
        return flattened_outputs


decoder_without_past = PatchedFlorence2DecoderWithoutPast()

torch.onnx.export(
    decoder_without_past,
    args=(
        encoder_attention_mask,
        encoder_outputs.last_hidden_state,
        decoder_inputs_embeds,
    ),
    f=f"{output_dir}/decoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["encoder_attention_mask", "encoder_hidden_states", "inputs_embeds"],
    output_names=["logits"] + pkv_output_names,
    dynamic_axes={
        "encoder_attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
        "encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"},
        "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length"},
        "logits": {0: "batch_size", 1: "decoder_sequence_length"},
        **{
            k: {
                0: "batch_size",
                2: "past_decoder_sequence_length + decoder_sequence_length"
                if "decoder" in k
                else "encoder_sequence_length",
            }
            for k in pkv_output_names
        },
    },
)
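# (Optional graph inspection.) A minimal sketch using the onnx package (already a
# dependency of the pinned stack) to confirm the exported first-step decoder exposes
# the expected interface: three inputs, plus logits and one present key/value tensor
# for the self- and cross-attention of every decoder layer (1 + 4 * num_layers outputs).
import onnx

decoder_graph = onnx.load(f"{output_dir}/decoder_model.onnx").graph
print("decoder_model.onnx inputs: ", [i.name for i in decoder_graph.input])
print("decoder_model.onnx outputs:", [o.name for o in decoder_graph.output])
assert len(decoder_graph.output) == 1 + 4 * num_layers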
"decoder_sequence_length"}, "logits": {0: "batch_size", 1: "decoder_sequence_length"}, **{ k: { 0: "batch_size", 2: "past_decoder_sequence_length + decoder_sequence_length" if "decoder" in k else "encoder_sequence_length", } for k in pkv_output_names }, }, ) class PatchedFlorence2DecoderWithPast(nn.Module): def __init__(self): super().__init__() self.language_model = model.language_model def forward(self, *args): encoder_attention_mask, inputs_embeds, *past_key_values_args = args pkv_iter = iter(past_key_values_args) pkv_tuples = tuple( tuple(next(pkv_iter) for i in range(4)) for _ in range(num_layers) ) decoder_outputs = self.language_model.forward( # NOTE: encoder_outputs isn't defined here, but we will reuse k,v, cross attentions from pkv tuples encoder_outputs=[torch.zeros(0, past_key_values_args[0].shape[2], 0)], decoder_inputs_embeds=inputs_embeds, past_key_values=pkv_tuples, # No need to pass `decoder_attention_mask` ) flattened_outputs = { "logits": decoder_outputs.logits, } for i in range(num_layers): for j, v in enumerate( ("decoder.key", "decoder.value", "encoder.key", "encoder.value") ): if "encoder" in v: continue flattened_outputs[f"present.{i}.{v}"] = decoder_outputs.past_key_values[ i ][j] return flattened_outputs decoder_with_past = PatchedFlorence2DecoderWithPast() torch.onnx.export( decoder_with_past, args=( encoder_attention_mask, encoder_inputs_embeds, *dummy_past_key_values_kwargs.values(), ), f=f"{output_dir}/decoder_with_past_model.onnx", export_params=True, opset_version=14, do_constant_folding=True, input_names=["encoder_attention_mask", "inputs_embeds"] + pkv_input_names, output_names=["logits"] + [x for x in pkv_output_names if "decoder" in x], dynamic_axes={ "encoder_attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, "encoder_hidden_states": {0: "batch_size", 1: "encoder_sequence_length"}, "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length"}, **{ k: { 0: "batch_size", 2: "past_decoder_sequence_length" if "decoder" in k else "encoder_sequence_length_out", } for k in pkv_input_names }, "logits": {0: "batch_size", 1: "decoder_sequence_length"}, **{ k: { 0: "batch_size", 2: "past_decoder_sequence_length + decoder_sequence_length", } for k in pkv_output_names if "decoder" in k }, }, ) # 4. Post-processing for f in os.listdir(output_dir): p = os.path.join(output_dir, f) onnxslim.slim(p, p) merged_decoder = merge_decoders( f"{output_dir}/decoder_model.onnx", f"{output_dir}/decoder_with_past_model.onnx", strict=False, ) check_and_save_model(merged_decoder, f"{output_dir}/decoder_model_merged.onnx")