#!/usr/bin/env python3
"""
Script to extract InternVL embeddings from MP4 video files.

Usage: python extract_video_embeddings.py <path_to_video.mp4>
"""
import argparse
import sys

import cv2
import torch
from torch.nn.functional import cosine_similarity
from transformers import AutoProcessor, AutoModelForImageTextToText

# Global query text used for the cosine-similarity check in main()
query = "man doing pushups"

def extract_video_embeddings(video_path, processor, model, num_frames=8):
    """
    Extract embeddings from an MP4 video using an InternVL model.

    Args:
        video_path (str): Path to the MP4 video file
        processor: HuggingFace processor matching the model
        model: Loaded InternVL model
        num_frames (int): Number of frames to sample from the video

    Returns:
        torch.Tensor or None: Video embeddings, or None on failure
    """
    try:
        # Read the video and sample frames uniformly
        cap = cv2.VideoCapture(video_path)
        frames = []
        frame_count = 0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        step = max(1, total_frames // num_frames)
        while cap.isOpened() and len(frames) < num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % step == 0:
                # OpenCV decodes to BGR; the model expects RGB
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_count += 1
        cap.release()
        if not frames:
            raise ValueError("No frames extracted from video")

        # Build a chat-formatted prompt with one image placeholder per
        # frame, then run the processor on the prompt and the frames
        # together so the pixel values travel with the token ids
        # (tokenizing the template string alone would drop the frames)
        messages = [
            {
                "role": "user",
                "content": [{"type": "image"}] * len(frames)
                + [{"type": "text", "text": ""}],
            }
        ]
        prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        inputs = processor(text=prompt, images=frames, return_tensors="pt")

        # Extract embeddings from the model without generating text
        with torch.no_grad():
            # Move inputs to the model's device; cast floating-point
            # inputs (the pixel values) to the model dtype and keep
            # integer token ids unchanged
            device = next(model.parameters()).device
            dtype = next(model.parameters()).dtype
            inputs = {
                k: (
                    v.to(device=device, dtype=dtype)
                    if v.is_floating_point()
                    else v.to(device)
                )
                if isinstance(v, torch.Tensor)
                else v
                for k, v in inputs.items()
            }
            outputs = model(**inputs, output_hidden_states=True)
            # Mean-pool the last hidden state across the sequence dimension
            embeddings = outputs.hidden_states[-1].mean(dim=1)
        return embeddings
    except Exception as e:
        print(f"Error processing video: {e}", file=sys.stderr)
        return None

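# A hedged alternative to mean pooling (not part of the original script):
# some decoder-style models summarize a sequence better in the final
# token's hidden state. This optional helper is a sketch under that
# assumption; swap it in for the .mean(dim=1) calls to compare.
def pool_last_token(hidden_states, attention_mask):
    """Return the hidden state of the last non-padded token per sequence."""
    # attention_mask is 1 for real tokens, 0 for padding
    last_idx = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
    return hidden_states[batch_idx, last_idx]
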
def get_text_embedding(text, processor, model):
    """Get a text embedding from the model using the same chat formatting."""
    device = next(model.parameters()).device
    # Format the query with the chat template, then tokenize the
    # resulting prompt string
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": text}],
        }
    ]
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    inputs = processor.tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        # Move inputs to the same device as the model
        inputs = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
        outputs = model(**inputs, output_hidden_states=True)
        # Mean-pool the last hidden state, mirroring the video path
        text_embedding = outputs.hidden_states[-1].mean(dim=1)
    return text_embedding

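# Hedged sketch (not in the original script): rank several candidate
# query strings against one video embedding, reusing get_text_embedding
# and the already-loaded processor/model from above.
def rank_queries(queries, video_embedding, processor, model):
    """Return (query, cosine similarity) pairs sorted best-first."""
    scored = []
    for q in queries:
        text_emb = get_text_embedding(q, processor, model)
        sim = cosine_similarity(video_embedding, text_emb, dim=-1).item()
        scored.append((q, sim))
    return sorted(scored, key=lambda pair: pair[1], reverse=True)
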
def main():
    parser = argparse.ArgumentParser(
        description="Extract InternVL embeddings from an MP4 video"
    )
    parser.add_argument("video_path", help="Path to the MP4 video file")
    parser.add_argument(
        "--model",
        default="OpenGVLab/InternVL3-8B-hf",
        help="HuggingFace model checkpoint",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=8,
        help="Number of frames to sample from the video",
    )
    parser.add_argument(
        "--output-format",
        choices=["numpy", "tensor"],
        default="numpy",
        help="Output format for embeddings",
    )
    args = parser.parse_args()

    # Load the processor and model once; the video and text embeddings
    # both use the same instances
    processor = AutoProcessor.from_pretrained(args.model)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model, torch_dtype=torch.float16, device_map="auto"
    )

    # Extract video embeddings
    embeddings = extract_video_embeddings(
        args.video_path, processor, model, args.num_frames
    )
    if embeddings is None:
        sys.exit(1)

    # Get the text embedding for the global query and compare
    print(f"Processing query: '{query}'")
    text_embedding = get_text_embedding(query, processor, model)
    similarity = cosine_similarity(embeddings, text_embedding, dim=-1)
    print(f"Query: {query}")
    print(f"Cosine similarity: {similarity.item():.4f}")

    # Debug: embedding statistics
    print(f"Video embedding norm: {torch.norm(embeddings):.4f}")
    print(f"Text embedding norm: {torch.norm(text_embedding):.4f}")
    print(
        f"Embeddings are identical: "
        f"{torch.allclose(embeddings, text_embedding, atol=1e-3)}"
    )

    # Print the embeddings in the requested format
    if args.output_format == "numpy":
        embeddings_np = embeddings.cpu().numpy()
        print("Video embeddings shape:", embeddings_np.shape)
        print("Video embeddings:")
        print(embeddings_np)
    else:
        print("Video embeddings shape:", embeddings.shape)
        print("Video embeddings:")
        print(embeddings)


if __name__ == "__main__":
    main()
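
# Example invocations (assuming the file is saved as
# extract_video_embeddings.py, per the docstring; "workout.mp4" is a
# hypothetical input file):
#   python extract_video_embeddings.py workout.mp4
#   python extract_video_embeddings.py workout.mp4 --num-frames 16 --output-format tensor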