@dchae · Last active July 7, 2025 03:12
#!/usr/bin/env python3
"""
Script to extract InternVL embeddings from MP4 video files.
Usage: python extract_video_embeddings.py <path_to_video.mp4>
"""
import argparse
import sys
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
import numpy as np
import cv2
from torch.nn.functional import cosine_similarity
# Global query text
query = "man doing pushups"


def extract_video_embeddings(
    video_path, model_checkpoint="OpenGVLab/InternVL3-8B-hf", num_frames=8
):
    """
    Extract embeddings from an MP4 video using an InternVL model.

    Args:
        video_path (str): Path to the MP4 video file
        model_checkpoint (str): HuggingFace model checkpoint
        num_frames (int): Number of frames to sample from the video

    Returns:
        torch.Tensor: Video embedding, or None on failure
    """
    try:
        # Load processor and model
        processor = AutoProcessor.from_pretrained(model_checkpoint)
        model = AutoModelForImageTextToText.from_pretrained(
            model_checkpoint, torch_dtype=torch.float16, device_map="auto"
        )

        # Read video frames with OpenCV
        cap = cv2.VideoCapture(video_path)
        frames = []
        frame_count = 0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Sample frames approximately uniformly across the video
        step = max(1, total_frames // num_frames)
        while cap.isOpened() and len(frames) < num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % step == 0:
                # Convert BGR (OpenCV default) to RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame_rgb)
            frame_count += 1
        cap.release()

        if not frames:
            raise ValueError("No frames extracted from video")

        # Build a chat-formatted prompt with one image placeholder per frame,
        # then let the processor pair the prompt with the pixel data.
        messages = [
            {
                "role": "user",
                "content": [{"type": "image"}] * len(frames)
                + [{"type": "text", "text": ""}],
            }
        ]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=False)
        # Tokenize the prompt together with the frames so pixel values are included;
        # tokenizing the prompt string alone would silently drop the images.
        inputs = processor(images=frames, text=prompt, return_tensors="pt")

        # Extract embeddings from the model without generating text
        with torch.no_grad():
            # Move inputs to the model's device and cast floating-point tensors
            # (pixel values) to the model dtype (float16 here)
            device = next(model.parameters()).device
            inputs = {
                k: v.to(device, dtype=model.dtype)
                if isinstance(v, torch.Tensor) and torch.is_floating_point(v)
                else (v.to(device) if isinstance(v, torch.Tensor) else v)
                for k, v in inputs.items()
            }
            outputs = model(**inputs, output_hidden_states=True)
            # Mean-pool the last hidden state across the sequence dimension
            embeddings = outputs.hidden_states[-1].mean(dim=1)

        return embeddings

    except Exception as e:
        print(f"Error processing video: {e}", file=sys.stderr)
        return None
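
# A possible refinement (a sketch, not wired into the functions above): mean-pool only
# over real tokens using the attention mask, so padding does not dilute the embedding
# when inputs are batched or padded.
def masked_mean_pool(hidden_states, attention_mask):
    """Mean-pool hidden states over positions where attention_mask == 1."""
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq, 1)
    summed = (hidden_states * mask).sum(dim=1)                   # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1)                        # avoid division by zero
    return summed / counts
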
def get_text_embedding(text, processor, model):
    """Get text embedding from the model using multimodal processing."""
    device = next(model.parameters()).device

    # Format the query with the chat template
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": text}],
        }
    ]
    inputs = processor.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=False
    )
    print(f"Text input after apply_chat_template: {inputs}")

    with torch.no_grad():
        # apply_chat_template returns a string here; tokenize it
        if isinstance(inputs, str):
            inputs = processor.tokenizer(inputs, return_tensors="pt")
        print(f"Tokenized text input shape: {inputs['input_ids'].shape}")

        # Move inputs to the same device as the model
        inputs = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
        outputs = model(**inputs, output_hidden_states=True)
        # Mean-pool the last hidden state across the sequence dimension
        text_embedding = outputs.hidden_states[-1].mean(dim=1)

    return text_embedding
def main():
    parser = argparse.ArgumentParser(
        description="Extract InternVL embeddings from MP4 video"
    )
    parser.add_argument("video_path", help="Path to the MP4 video file")
    parser.add_argument(
        "--model",
        default="OpenGVLab/InternVL3-8B-hf",
        help="HuggingFace model checkpoint",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=8,
        help="Number of frames to sample from video",
    )
    parser.add_argument(
        "--output-format",
        choices=["numpy", "tensor"],
        default="numpy",
        help="Output format for embeddings",
    )
    args = parser.parse_args()

    # Extract video embeddings
    embeddings = extract_video_embeddings(args.video_path, args.model, args.num_frames)
    if embeddings is None:
        sys.exit(1)

    # Load the model and processor again for the text embedding
    # (extract_video_embeddings keeps its own copies local)
    processor = AutoProcessor.from_pretrained(args.model)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model, torch_dtype=torch.float16, device_map="auto"
    )

    # Get the text embedding for the query
    print(f"Processing query: '{query}'")
    text_embedding = get_text_embedding(query, processor, model)

    # Calculate cosine similarity between the video and query embeddings
    similarity = cosine_similarity(embeddings, text_embedding, dim=-1)
    print(f"Query: {query}")
    print(f"Cosine similarity: {similarity.item():.4f}")

    # Debug: print embedding statistics
    print(f"Video embedding norm: {torch.norm(embeddings):.4f}")
    print(f"Text embedding norm: {torch.norm(text_embedding):.4f}")
    print(
        f"Embeddings are same: {torch.allclose(embeddings, text_embedding, atol=1e-3)}"
    )

    # Print embeddings in the requested format
    if args.output_format == "numpy":
        embeddings_np = embeddings.cpu().numpy()
        print("Video embeddings shape:", embeddings_np.shape)
        print("Video embeddings:")
        print(embeddings_np)
    else:
        print("Video embeddings shape:", embeddings.shape)
        print("Video embeddings:")
        print(embeddings)


if __name__ == "__main__":
    main()
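
# A minimal sketch of programmatic use (hypothetical captions and clip path; assumes
# the processor and model are already loaded as in main()):
#
#   video_emb = extract_video_embeddings("clips/pushups.mp4")
#   captions = ["man doing pushups", "woman cooking dinner", "dog running in a park"]
#   scores = {
#       c: cosine_similarity(video_emb, get_text_embedding(c, processor, model), dim=-1).item()
#       for c in captions
#   }
#   print(max(scores, key=scores.get))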