#!/usr/bin/env python3
"""
Script to extract InternVL embeddings from MP4 video files.

Usage:
    python extract_video_embeddings.py <video_path> [--model CHECKPOINT] [--num-frames N] [--output-format {numpy,tensor}]
"""

import argparse
import sys

import cv2
import torch
from PIL import Image
from torch.nn.functional import cosine_similarity
from transformers import AutoModelForImageTextToText, AutoProcessor

# Global query text used for the similarity check
query = "man doing pushups"


def extract_video_embeddings(video_path, processor, model, num_frames=8):
    """
    Extract an embedding from an MP4 video using an InternVL model.

    Args:
        video_path (str): Path to the MP4 video file
        processor: HuggingFace processor for the model
        model: Loaded InternVL model
        num_frames (int): Number of frames to sample from the video

    Returns:
        torch.Tensor: Pooled video embedding, or None on error
    """
    try:
        # Read video frames with OpenCV
        cap = cv2.VideoCapture(video_path)
        frames = []
        frame_count = 0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Sample frames uniformly across the clip
        step = max(1, total_frames // num_frames)
        while cap.isOpened() and len(frames) < num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % step == 0:
                # OpenCV decodes to BGR; convert to RGB for the image processor
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
            frame_count += 1
        cap.release()

        if not frames:
            raise ValueError("No frames extracted from video")

        # Build a chat-style message with one image entry per sampled frame and let
        # apply_chat_template tokenize the prompt and preprocess the frames together,
        # so pixel_values and the expanded image tokens stay aligned.
        messages = [
            {
                "role": "user",
                "content": [{"type": "image", "image": frame} for frame in frames]
                + [{"type": "text", "text": ""}],
            }
        ]
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=False,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )

        # Move inputs to the model device; floating-point tensors (pixel_values)
        # are cast to float16 to match the model weights
        device = next(model.parameters()).device
        inputs = inputs.to(device, dtype=torch.float16)

        # Run a forward pass without generating text
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # Mean-pool the last hidden state across the sequence dimension
        embeddings = outputs.hidden_states[-1].mean(dim=1)

        return embeddings

    except Exception as e:
        print(f"Error processing video: {e}", file=sys.stderr)
        return None


def get_text_embedding(text, processor, model):
    """Get a text embedding from the model using the same chat template and pooling."""
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": text}],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    print(f"Tokenized text input shape: {inputs['input_ids'].shape}")

    # Move inputs to the model device (token ids stay integer; nothing to cast)
    device = next(model.parameters()).device
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Mean-pool the last hidden state across the sequence dimension
    text_embedding = outputs.hidden_states[-1].mean(dim=1)
    return text_embedding


def main():
    parser = argparse.ArgumentParser(
        description="Extract InternVL embeddings from an MP4 video"
    )
parser.add_argument("video_path", help="Path to the MP4 video file") parser.add_argument( "--model", default="OpenGVLab/InternVL3-8B-hf", help="HuggingFace model checkpoint", ) parser.add_argument( "--num-frames", type=int, default=8, help="Number of frames to sample from video", ) parser.add_argument( "--output-format", choices=["numpy", "tensor"], default="numpy", help="Output format for embeddings", ) args = parser.parse_args() # Extract embeddings embeddings = extract_video_embeddings(args.video_path, args.model, args.num_frames) if embeddings is None: sys.exit(1) # Load model and processor for text embedding processor = AutoProcessor.from_pretrained(args.model) model = AutoModelForImageTextToText.from_pretrained( args.model, torch_dtype=torch.float16, device_map="auto" ) # Get text embedding for query print(f"Processing query: '{query}'") text_embedding = get_text_embedding(query, processor, model) # Calculate cosine similarity similarity = cosine_similarity(embeddings, text_embedding, dim=-1) print(f"Query: {query}") print(f"Cosine similarity: {similarity.item():.4f}") # Debug: Print embedding statistics print(f"Video embedding norm: {torch.norm(embeddings):.4f}") print(f"Text embedding norm: {torch.norm(text_embedding):.4f}") print(f"Embeddings are same: {torch.allclose(embeddings, text_embedding, atol=1e-3)}") # Print embeddings if args.output_format == "numpy": embeddings_np = embeddings.cpu().numpy() print("Video embeddings shape:", embeddings_np.shape) print("Video embeddings:") print(embeddings_np) else: print("Video embeddings shape:", embeddings.shape) print("Video embeddings:") print(embeddings) if __name__ == "__main__": main()