@mikelgg93
Last active October 23, 2025 09:55
Track hand pose and ball position in real time on Neon's scene camera using MediaPipe and YOLO11
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "click",
#     "matplotlib",
#     "numpy",
#     "opencv-python",
#     "pupil-labs-realtime-api",
#     "rich",
#     "pandas",
#     "mediapipe",
#     "scipy",
#     "ultralytics",
#     "pupil-labs-neon-recording",
#     "torch",             # imported directly below
#     "universal-pathlib", # provides the `upath` module
# ]
# ///
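
# The block above is PEP 723 inline script metadata: any PEP 723-aware runner
# (for example `uv`) can resolve the listed dependencies on the fly, e.g.
# `uv run gaze_ball_hand.py` (filename assumed here; use whatever name you
# saved the gist under).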
import contextlib
import multiprocessing as mp
import queue
import threading
from collections import deque
from datetime import datetime
from typing import Any
import click
import cv2
import matplotlib.animation as animation
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import mediapipe as mpipe
import numpy as np
import pandas as pd
import torch
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from pupil_labs.realtime_api.simple import (
    Device,
    discover_one_device,
)
from rich.console import Console
from rich.live import Live
from rich.spinner import Spinner
from rich.table import Table
from scipy.spatial.distance import cdist
from ultralytics import RTDETR, YOLO
from ultralytics.utils.downloads import download
from upath import UPath
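
# Module-level figure and artists. Everything FuncAnimation touches is created
# once with animated=True so that blitting only redraws these artists on each
# tick: ax1 holds the distance-over-time traces, ax2 is a 2D canvas mirroring
# the scene camera (gaze dot, ball bounding box, hand-landmark scatters).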
fig, (ax1, ax2) = plt.subplots(
    1, 2, figsize=(12, 5), gridspec_kw={"width_ratios": [3, 2]}
)
(line_gaze_ball,) = ax1.plot(
    [], [], lw=2, color="red", label="Gaze-to-Ball", animated=True
)
(line_gaze_lhand,) = ax1.plot(
    [], [], lw=2, color="magenta", label="Gaze-to-Left-Hand", animated=True
)
(line_gaze_rhand,) = ax1.plot(
    [], [], lw=2, color="cyan", label="Gaze-to-Right-Hand", animated=True
)
(line_lhand_ball,) = ax1.plot(
    [],
    [],
    lw=1,
    linestyle="--",
    color="magenta",
    label="Left-Hand-to-Ball",
    animated=True,
)
(line_rhand_ball,) = ax1.plot(
    [],
    [],
    lw=1,
    linestyle="--",
    color="cyan",
    label="Right-Hand-to-Ball",
    animated=True,
)
consuming_rate_text: plt.Text | None = ax1.text(
    0.75,
    0.95,
    "",
    transform=ax1.transAxes,
    ha="left",
    va="top",
    fontsize=9,
    color="gray",
    animated=True,
)
ball_detected_text: plt.Text | None = ax1.text(
    0.5,
    0.95,
    "",
    transform=ax1.transAxes,
    ha="center",
    va="top",
    fontsize=12,
    color="green",
    animated=True,
)
(gaze_canvas_dot,) = ax2.plot(
    [], [], "o", color="red", markersize=15, label="Gaze", animated=True
)
ball_canvas_bbox = patches.Rectangle(
    (0, 0), 1, 1, linewidth=2, edgecolor="b", facecolor="none", animated=True
)
ax2.add_patch(ball_canvas_bbox)
lhand_canvas_scatter = ax2.scatter(
    [], [], c="magenta", s=50, label="Left Hand", animated=True
)
rhand_canvas_scatter = ax2.scatter(
    [], [], c="cyan", s=50, label="Right Hand", animated=True
)
console = Console()
def initialize_detectors(
    device: str, umodel: str
) -> tuple[vision.HandLandmarker | None, YOLO | RTDETR | None]:
    """Initializes MediaPipe Hand Landmarker and YOLO or RTDETR Object Detector."""
    try:
        console.log("Initializing MediaPipe Hand Landmarker...")
        hand_model_path = UPath("hand_landmarker.task")
        if not hand_model_path.exists():
            console.log(f"Downloading Hand Landmarker model to {hand_model_path}...")
            download(
                "https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/1/hand_landmarker.task",
                dir=UPath("."),
            )
        hand_base_options = python.BaseOptions(
            model_asset_path=str(hand_model_path),
            delegate=mpipe.tasks.BaseOptions.Delegate.GPU,
        )
        hand_options = vision.HandLandmarkerOptions(
            base_options=hand_base_options,
            running_mode=mpipe.tasks.vision.RunningMode.VIDEO,
            num_hands=2,
            min_hand_detection_confidence=0.5,
            min_hand_presence_confidence=0.5,
        )
        hand_detector = vision.HandLandmarker.create_from_options(hand_options)
        console.log("[green]MediaPipe Hand Landmarker initialized.[/green]")
        console.log(
            f"Initializing Ultralytics {umodel} model on device: [bold]{device}[/bold]..."
        )
        if umodel.startswith("yolo"):
            model = YOLO(umodel)
        elif umodel.startswith("rtdet"):
            model = RTDETR(umodel)
        else:
            raise ValueError(f"Unsupported model type: {umodel}")  # noqa: TRY301
        console.log(f"[green]Ultralytics {umodel} model initialized.[/green]")
    except Exception as e:
        console.log(f"[bold red]Error initializing detectors: {e}[/bold red]")
        console.log(
            "[bold yellow]Please check your internet connection and file permissions.[/bold yellow]"
        )
        return None, None
    else:
        console.log("[green]Detectors initialized successfully.[/green]")
        return hand_detector, model
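
# A minimal, optional sketch (not called anywhere in this script): before
# handing --device to Ultralytics, one could fall back to CPU when the
# requested accelerator is unavailable. Both availability checks are standard
# torch APIs.
def resolve_device(device_str: str) -> str:
    """Return device_str if its backend is available, otherwise "cpu"."""
    if device_str == "cuda" and not torch.cuda.is_available():
        return "cpu"
    if device_str == "mps" and not torch.backends.mps.is_available():
        return "cpu"
    return device_str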
def annotate_frame(
    frame: np.ndarray,
    gaze: dict,
    hand_result: vision.HandLandmarkerResult,
    ball_center: tuple[int, int] | None,
    ball_bbox_data: list | None,
    hand_data: dict,
    class_name: str,
) -> np.ndarray:
    """Draws all detections and overlays on the video frame."""
    if hand_result and hand_result.hand_landmarks:
        for hand_landmarks in hand_result.hand_landmarks:
            # MediaPipe's drawing utilities expect a NormalizedLandmarkList proto.
            hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
            hand_landmarks_proto.landmark.extend([
                landmark_pb2.NormalizedLandmark(
                    x=landmark.x, y=landmark.y, z=landmark.z
                )
                for landmark in hand_landmarks
            ])
            solutions.drawing_utils.draw_landmarks(
                frame,
                hand_landmarks_proto,
                solutions.hands.HAND_CONNECTIONS,
                solutions.drawing_styles.get_default_hand_landmarks_style(),
                solutions.drawing_styles.get_default_hand_connections_style(),
            )
    # Colors below are BGR (OpenCV convention): (0, 255, 0) is green, (0, 0, 255) red.
    if ball_bbox_data is not None:
        start_point = (int(ball_bbox_data[0]), int(ball_bbox_data[1]))
        end_point = (int(ball_bbox_data[2]), int(ball_bbox_data[3]))
        cv2.rectangle(frame, start_point, end_point, (0, 255, 0), 3)
        cv2.putText(
            frame,
            class_name,
            (start_point[0], start_point[1] - 10),
            cv2.FONT_HERSHEY_PLAIN,
            2,
            (0, 255, 0),
            2,
        )
    gaze_px = (int(gaze["x"]), int(gaze["y"]))
    cv2.circle(frame, gaze_px, 20, (0, 0, 255), 2)
    if ball_center is not None:
        cv2.circle(frame, ball_center, 10, (255, 165, 0), -1)
        cv2.line(frame, gaze_px, ball_center, (0, 0, 255), 1)
    for hand, data in hand_data.items():
        if data["closest_point_to_gaze"]:
            color = {"left": (255, 0, 255), "right": (255, 255, 0)}.get(
                hand, (0, 255, 0)
            )
            cv2.line(frame, gaze_px, data["closest_point_to_gaze"], color, 2)
            cv2.circle(frame, data["closest_point_to_gaze"], 8, color, -1)
            if ball_center and data["closest_point_to_ball"]:
                cv2.line(frame, data["closest_point_to_ball"], ball_center, color, 2)
    return frame
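
# Hedged convenience sketch (unused by the pipeline): MediaPipe landmarks are
# normalized to [0, 1], so converting to pixel coordinates only requires
# scaling by the frame size. This mirrors the inline conversion inside
# process_and_log_data below.
def landmarks_to_pixels(hand_landmarks, width: int, height: int) -> np.ndarray:
    """Convert a list of NormalizedLandmark objects to an (N, 2) pixel array."""
    return np.array([(lm.x * width, lm.y * height) for lm in hand_landmarks])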
def acquire_data(device: Device, acq_queue: mp.Queue, shutdown_event: mp.Event) -> None:
    """Thread target for acquiring matched gaze and scene video data."""
    console.log("[green]Data acquisition thread started.[/green]")
    while not shutdown_event.is_set():
        try:
            sample = device.receive_matched_scene_video_frame_and_gaze(
                timeout_seconds=1
            )
            if sample:
                item = {
                    "frame": sample.frame.bgr_pixels,
                    "gaze": {
                        "x": sample.gaze.x,
                        "y": sample.gaze.y,
                        "ts": sample.gaze.timestamp_unix_seconds,
                    },
                }
                # Non-blocking put: when inference lags, drop the sample
                # instead of stalling acquisition (handled by the except below).
                acq_queue.put_nowait(item)
        except queue.Full:
            pass
        except Exception as e:
            console.log(f"[bold red]Acquisition error: {e}[/bold red]")
            shutdown_event.set()
            break
    console.log("[yellow]Data acquisition thread finished.[/yellow]")
def process_and_log_data(
    output_path: UPath,
    save_data: bool,
    max_history: int,
    device_str: str,
    conf_threshold: float,
    umodel: str,
    coco_class: int,
    acquisition_queue: mp.Queue,
    visualization_queue: mp.Queue,
    video_display_queue: mp.Queue,
    shutdown_event: mp.Event,
) -> None:
    """Process target for inference and data logging."""
    hand_detector, ball_detector = initialize_detectors(device_str, umodel)
    if not hand_detector or not ball_detector:
        shutdown_event.set()
        return
    console.log("[green]Processing and logging process started.[/green]")
    history = {
        "ts": deque(),
        "gaze_ball": deque(),
        "gaze_lhand": deque(),
        "gaze_rhand": deque(),
        "lhand_ball": deque(),
        "rhand_ball": deque(),
    }
    data_log = []
    start_ts: float | None = None
    frame_dims_sent = False
    INFERENCE_WIDTH = 640  # Resize frames to this width for faster inference.
    while not shutdown_event.is_set():
        try:
            sample = acquisition_queue.get(timeout=1)
            original_frame, gaze = sample["frame"], sample["gaze"]
            ts = gaze["ts"]
            if start_ts is None:
                start_ts = ts
            # Downscale once, run both detectors on the small frame, then map
            # detections back to the original resolution via `scale`.
            original_height, original_width, _ = original_frame.shape
            scale = INFERENCE_WIDTH / original_width
            inference_height = int(original_height * scale)
            resized_frame = cv2.resize(
                original_frame, (INFERENCE_WIDTH, inference_height)
            )
            rgb_frame_resized = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGB)
            rgba_frame_resized = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2RGBA)
            mp_image = mpipe.Image(
                image_format=mpipe.ImageFormat.SRGBA, data=rgba_frame_resized
            )
            hand_result = hand_detector.detect_for_video(mp_image, int(ts * 1000))
            # COCO class 32 is 'sports ball'. Filter results for this class.
            ball_result = ball_detector.predict(
                rgb_frame_resized,
                verbose=False,
                device=device_str,
                classes=[coco_class],
                conf=conf_threshold,
            )[0]
            class_name = (
                ball_result.names[coco_class]
                if hasattr(ball_result, "names") and coco_class in ball_result.names
                else "Unknown"
            )
            gaze_pos = np.array([gaze["x"], gaze["y"]])
            ball_center_np, ball_bbox_data = None, None
            if ball_result and len(ball_result.boxes) > 0:
                best_ball_box_resized = ball_result.boxes.xyxy[0].cpu().numpy()
                ball_bbox_data = best_ball_box_resized / scale
                ball_center_np = np.array([
                    (ball_bbox_data[0] + ball_bbox_data[2]) / 2,
                    (ball_bbox_data[1] + ball_bbox_data[3]) / 2,
                ])
            hand_data = {
                "left": {
                    "dist_to_gaze": np.nan,
                    "dist_to_ball": np.nan,
                    "closest_point_to_gaze": None,
                    "closest_point_to_ball": None,
                    "landmarks": None,
                },
                "right": {
                    "dist_to_gaze": np.nan,
                    "dist_to_ball": np.nan,
                    "closest_point_to_gaze": None,
                    "closest_point_to_ball": None,
                    "landmarks": None,
                },
            }
            if hand_result.hand_landmarks:
                for i, hand_landmarks in enumerate(hand_result.hand_landmarks):
                    handedness = hand_result.handedness[i][0].category_name.lower()
                    landmark_coords = np.array([
                        (lm.x * original_width, lm.y * original_height)
                        for lm in hand_landmarks
                    ])
                    hand_data[handedness]["landmarks"] = landmark_coords
                    distances_gaze = cdist(gaze_pos[None, :], landmark_coords)
                    min_dist_idx = np.argmin(distances_gaze)
                    hand_data[handedness]["dist_to_gaze"] = distances_gaze[
                        0, min_dist_idx
                    ]
                    hand_data[handedness]["closest_point_to_gaze"] = tuple(
                        map(int, landmark_coords[min_dist_idx])
                    )
                    if ball_center_np is not None:
                        distances_ball = cdist(ball_center_np[None, :], landmark_coords)
                        min_ball_dist_idx = np.argmin(distances_ball)
                        hand_data[handedness]["dist_to_ball"] = distances_ball[
                            0, min_ball_dist_idx
                        ]
                        hand_data[handedness]["closest_point_to_ball"] = tuple(
                            map(int, landmark_coords[min_ball_dist_idx])
                        )
            history["ts"].append(ts - start_ts)
            history["gaze_ball"].append(
                np.linalg.norm(gaze_pos - ball_center_np)
                if ball_center_np is not None
                else np.nan
            )
            history["gaze_lhand"].append(hand_data["left"]["dist_to_gaze"])
            history["gaze_rhand"].append(hand_data["right"]["dist_to_gaze"])
            history["lhand_ball"].append(hand_data["left"]["dist_to_ball"])
            history["rhand_ball"].append(hand_data["right"]["dist_to_ball"])
            # Trim history to the plotting window (`max_history` seconds).
            while history["ts"] and history["ts"][-1] - history["ts"][0] > max_history:
                for key in history:
                    history[key].popleft()
            if save_data:
                log_entry = {
                    "timestamp": ts,
                    "gaze_x": gaze["x"],
                    "gaze_y": gaze["y"],
                    "ball_x": ball_center_np[0]
                    if ball_center_np is not None
                    else np.nan,
                    "ball_y": ball_center_np[1]
                    if ball_center_np is not None
                    else np.nan,
                    "dist_gaze_ball": history["gaze_ball"][-1],
                    **{
                        f"dist_gaze_{h}hand": hand_data[h]["dist_to_gaze"]
                        for h in ["left", "right"]
                    },
                    **{
                        f"dist_{h}hand_ball": hand_data[h]["dist_to_ball"]
                        for h in ["left", "right"]
                    },
                }
                data_log.append(log_entry)
            # Effective processing rate over the last 100 samples.
            computed_rate = (
                1.0 / np.mean(np.diff(list(history["ts"])[-100:]))
                if len(history["ts"]) > 1
                else 0.0
            )
            ball_center_for_drawing = (
                tuple(map(int, ball_center_np)) if ball_center_np is not None else None
            )
            viz_data = {
                "history": {k: list(v) for k, v in history.items()},
                "computed_rate": computed_rate,
                "gaze_pos": gaze_pos,
                "ball_bbox_data": ball_bbox_data,
                "lhand_landmarks": hand_data["left"]["landmarks"],
                "rhand_landmarks": hand_data["right"]["landmarks"],
                "ball_detected": ball_center_np is not None,
            }
            if not frame_dims_sent:
                viz_data["frame_dims"] = (original_width, original_height)
                frame_dims_sent = True
            with contextlib.suppress(queue.Full):
                visualization_queue.put_nowait(viz_data)
            annotated_bgr = annotate_frame(
                original_frame.copy(),
                gaze,
                hand_result,
                ball_center_for_drawing,
                ball_bbox_data,
                hand_data,
                class_name=class_name,
            )
            with contextlib.suppress(queue.Full):
                video_display_queue.put_nowait(annotated_bgr)
        except queue.Empty:
            continue
        except Exception:
            console.print_exception()
    if save_data:
        console.log("[cyan]Saving captured data to CSV...[/cyan]")
        if data_log:
            pd.DataFrame(data_log).to_csv(output_path, index=False, float_format="%.3f")
            console.log(f"[green]Saved {len(data_log)} rows to {output_path}[/green]")
        else:
            console.log("[yellow]No data was logged.[/yellow]")
    console.log("[yellow]Processing and logging process finished.[/yellow]")
def update_plot(
    frame_idx: int,
    max_history: int,
    video_window_name: str,
    visualization_queue: mp.Queue,
    video_display_queue: mp.Queue,
    shutdown_event: mp.Event,
) -> list[Any]:
    """Function called by FuncAnimation to update plots and video."""
    try:
        data = visualization_queue.get_nowait()
        history = data["history"]
        elapsed_time = history["ts"]
        line_gaze_ball.set_data(elapsed_time, history["gaze_ball"])
        line_gaze_lhand.set_data(elapsed_time, history["gaze_lhand"])
        line_gaze_rhand.set_data(elapsed_time, history["gaze_rhand"])
        line_lhand_ball.set_data(elapsed_time, history["lhand_ball"])
        line_rhand_ball.set_data(elapsed_time, history["rhand_ball"])
        consuming_rate_text.set_text(f"Proc Rate: {data['computed_rate']:.1f} Hz")
        ball_detected_text.set_text("Object Detected" if data["ball_detected"] else "")
        current_max_elapsed = elapsed_time[-1] if elapsed_time else max_history
        ax1.set_xlim(
            max(0, current_max_elapsed - max_history),
            current_max_elapsed or max_history,
        )
        # Rescale the y-axis to the largest observed distance, excluding the
        # time axis itself from the computation.
        all_distances = np.concatenate([
            d
            for k, d in history.items()
            if k != "ts" and isinstance(d, list) and len(d) > 0
        ])
        max_dist = (
            np.nanmax(all_distances)
            if len(all_distances) > 0 and np.any(np.isfinite(all_distances))
            else 1000
        )
        ax1.set_ylim(0, max_dist * 1.1 if max_dist > 0 else 1000)
        if "frame_dims" in data:
            width, height = data["frame_dims"]
            ax2.set_xlim(0, width)
            ax2.set_ylim(height, 0)  # invert y so the canvas matches image coords
        if data["gaze_pos"] is not None:
            gaze_canvas_dot.set_data([data["gaze_pos"][0]], [data["gaze_pos"][1]])
        if data["ball_bbox_data"] is not None:
            bbox = data["ball_bbox_data"]
            ball_canvas_bbox.set_xy((bbox[0], bbox[1]))
            ball_canvas_bbox.set_width(bbox[2] - bbox[0])
            ball_canvas_bbox.set_height(bbox[3] - bbox[1])
            ball_canvas_bbox.set_visible(True)
        else:
            ball_canvas_bbox.set_visible(False)
        lhand_lm = data["lhand_landmarks"]
        rhand_lm = data["rhand_landmarks"]
        lhand_canvas_scatter.set_offsets(
            lhand_lm if lhand_lm is not None else np.empty((0, 2))
        )
        rhand_canvas_scatter.set_offsets(
            rhand_lm if rhand_lm is not None else np.empty((0, 2))
        )
    except queue.Empty:
        pass
    try:
        video_frame = video_display_queue.get_nowait()
        cv2.imshow(video_window_name, video_frame)
    except queue.Empty:
        pass
    # ESC in the OpenCV window triggers shutdown.
    if cv2.waitKey(1) & 0xFF == 27 and not shutdown_event.is_set():
        console.log("[bold magenta]Shutdown initiated by user...[/bold magenta]")
        shutdown_event.set()
    return [
        line_gaze_ball,
        line_gaze_lhand,
        line_gaze_rhand,
        line_lhand_ball,
        line_rhand_ball,
        consuming_rate_text,
        ball_detected_text,
        gaze_canvas_dot,
        ball_canvas_bbox,
        lhand_canvas_scatter,
        rhand_canvas_scatter,
    ]
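
# Note on blit=True: blitting only redraws the artists returned above, so axis
# limits changed via set_xlim/set_ylim may not refresh tick labels on every
# backend. If the axes look stale, blit=False is the safe (slower) fallback.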
def on_close(event: Any | None, shutdown_event: mp.Event) -> None:
    """Matplotlib close-event handler: signal all workers to shut down."""
    if not shutdown_event.is_set():
        console.log("[bold magenta]Shutdown initiated by user...[/bold magenta]")
        shutdown_event.set()
@click.command()
@click.option("--ip", default=None, help="IP address of Neon.")
@click.option("--port", default=8080, show_default=True, help="Port of Neon's device.")
@click.option(
    "--output-dir",
    default=str(UPath.home() / "gaze_ball_hand_recordings"),
    type=click.Path(file_okay=False),
    show_default=True,
)
@click.option(
    "--max-history",
    default=30,
    show_default=True,
    help="Maximum seconds to show in plot history.",
)
@click.option(
    "--save", is_flag=True, default=False, help="Enable saving data to a CSV file."
)
@click.option(
    "--device",
    "device_str",
    default="cpu",
    type=click.Choice(["cpu", "mps", "cuda"]),
    help="Processing device: 'mps' for Apple Silicon Macs, 'cuda' for NVIDIA GPUs.",
)
@click.option(
    "--conf",
    "conf_threshold",
    default=0.25,
    show_default=True,
    help="Confidence threshold for detection.",
)
@click.option(
    "--model",
    "umodel",
    default="yolo12n.pt",
    help="Ultralytics model to use for detection, e.g. yolo12n.pt, yolo11n.pt, "
    "or rtdetr-l.pt.",
)
@click.option(
    "--class",
    "coco_class",
    default=32,
    type=int,
    help="COCO class ID for the object to detect (default: 32 for sports ball). "
    "See https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco.yaml",
)
def main(
    ip: str | None,
    port: int,
    output_dir: str,
    max_history: int,
    save: bool,
    device_str: str,
    conf_threshold: float,
    umodel: str,
    coco_class: int,
) -> None:
    """Main entry point for the real-time Gaze, Ball, and Hand tracking tool."""
    console.rule("[bold cyan]Real-Time Gaze, Ball & Hand Tracker[/bold cyan]")
    settings_table = Table(title="Application Settings")
    settings_table.add_column("Parameter", justify="right", style="cyan", no_wrap=True)
    settings_table.add_column("Value", style="magenta")
    settings_table.add_row("Device IP", ip or "Auto-discover")
    settings_table.add_row("Device Port", str(port))
    settings_table.add_row("Plot History (s)", str(max_history))
    settings_table.add_row(
        "Save Data", "[green]Enabled[/green]" if save else "[yellow]Disabled[/yellow]"
    )
    if save:
        settings_table.add_row("Output Directory", str(output_dir))
    settings_table.add_section()
    settings_table.add_row("Processing Device", device_str.upper())
    settings_table.add_row("Torch", torch.__version__)
    settings_table.add_row("Model", umodel)
    settings_table.add_row("Confidence", str(conf_threshold))
    settings_table.add_row("Classes", str(coco_class))
    settings_table.add_section()
    settings_table.add_row("MediaPipe Model", "hand_landmarker.task")
    settings_table.add_row("MediaPipe Hand Confidence", "0.5")
    console.print(settings_table)
    # --- Define Queues and Events in main scope ---
    shutdown_event = mp.Event()
    acquisition_queue = mp.Queue(maxsize=10)
    visualization_queue = mp.Queue(maxsize=2)
    video_display_queue = mp.Queue(maxsize=2)
    device: Device | None = None
    acquisition_thread: threading.Thread | None = None
    processing_proc: mp.Process | None = None
    video_window_name = "Scene Camera with Gaze and Detections"
    try:
        with Live(
            Spinner("dots", text="Connecting to device..."),
            console=console,
            transient=True,
        ) as live:
            if ip:
                live.update(f"Attempting to connect to {ip}:{port}...")
                try:
                    device = Device(address=ip, port=port)
                except Exception as e:
                    live.stop()
                    console.log(f"[bold red]Failed to connect: {e}[/bold red]")
                    return
            else:
                live.update("Searching for Pupil Labs device...")
                device = discover_one_device(max_search_duration_seconds=10)
            if device is None:
                live.stop()
                console.log(
                    "[bold red]No device found or connection failed.[/bold red]"
                )
                return
        output_path = UPath(output_dir)
        csv_filename = UPath("")
        if save:
            output_path.mkdir(parents=True, exist_ok=True)
            csv_filename = (
                output_path
                / f"gaze_hand_ball_data_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.csv"
            )
        ax1.set_title("Real-Time Distances: Gaze, Hand, and Ball")
        ax1.set_xlabel("Time (seconds)")
        ax1.set_ylabel("Distance (pixels)")
        ax1.legend(loc="upper left")
        ax1.grid(True, linestyle="--", alpha=0.5)
        ax1.set_xlim(max_history, 0)
        ax1.set_ylim(0, 1250)
        ax2.set_title("2D Scene Canvas")
        ax2.set_xticks([])
        ax2.set_yticks([])
        ax2.set_aspect("equal", adjustable="box")
        ax2.legend(loc="upper left")
        plt.tight_layout(pad=2.0)
        fig.canvas.mpl_connect(
            "close_event", lambda event: on_close(event, shutdown_event)
        )
        cv2.namedWindow(video_window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(video_window_name, 800, 600)
        # Pass queues and events as arguments
        acquisition_thread = threading.Thread(
            target=acquire_data, args=(device, acquisition_queue, shutdown_event)
        )
        processing_proc = mp.Process(
            target=process_and_log_data,
            args=(
                csv_filename,
                save,
                max_history,
                device_str,
                conf_threshold,
                umodel,
                coco_class,
                acquisition_queue,
                visualization_queue,
                video_display_queue,
                shutdown_event,
            ),
        )
        console.log("[cyan]Starting background processes...[/cyan]")
        acquisition_thread.start()
        processing_proc.start()
        console.log(
            "[bold green]Setup complete. Displaying real-time plot and video.[/bold green]"
        )
        # Keep a reference to the animation so it is not garbage-collected.
        anim = animation.FuncAnimation(
            fig,
            update_plot,
            fargs=(
                max_history,
                video_window_name,
                visualization_queue,
                video_display_queue,
                shutdown_event,
            ),
            interval=30,
            blit=True,
        )
        plt.show()
    except Exception:
        console.print_exception(show_locals=False)
    finally:
        if not shutdown_event.is_set():
            shutdown_event.set()
        if acquisition_thread and acquisition_thread.is_alive():
            acquisition_thread.join(timeout=2)
        if processing_proc and processing_proc.is_alive():
            processing_proc.join(timeout=2)
        if device:
            device.close()
        cv2.destroyAllWindows()
        console.rule("[bold cyan]Application Finished[/bold cyan]")


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    main()
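
# Example invocations (assumed filename; adjust to wherever you saved this
# gist):
#   uv run gaze_ball_hand.py                        # auto-discover Neon, CPU
#   uv run gaze_ball_hand.py --device mps --save    # Apple Silicon + CSV log
#   uv run gaze_ball_hand.py --model yolo11n.pt --class 32 --conf 0.4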