""" Proof-of-concept for NAT traversal and low-latency communication over QUIC between two Modal containers. In theory this could be used to establish a low-latency p2p connection between a service running outside Modal and a Modal GPU container, e.g. for real-time inference on a video stream. Please let us know if you try it! Usage: > modal run modal_quic_hole_punch.py """ from typing import Literal import modal import time app = modal.App("quic-hole-punch") image = ( modal.Image.debian_slim() .pip_install("fastapi", "aioquic", "aiohttp", "six") .pip_install("pynat") ) @app.function(image=image, max_containers=1) @modal.asgi_app() def rendezvous(): """Rendezvous server that hands each peer the other's public tuple.""" from typing import Dict, Optional, Tuple from fastapi import FastAPI from pydantic import BaseModel class RegisterRequest(BaseModel): peer_id: Literal["A", "B"] ip: str port: int api = FastAPI() peers: Dict[str, Tuple[str, int]] = {} @api.post("/register") async def register(req: RegisterRequest): peers[req.peer_id] = (req.ip, req.port) other = "A" if req.peer_id == "B" else "B" info: Optional[Tuple[str, int]] = peers.get(other) return {"peer": info} # Null until the second peer registers return api async def get_ext_addr(sock): from pynat import get_stun_response response = get_stun_response(sock, ("stun.ekiga.net", 3478)) return response["ext_ip"], response["ext_port"] def create_cert(key): """Create a self-signed certificate for the given key.""" import datetime from cryptography import x509 from cryptography.hazmat.primitives import hashes from cryptography.x509.oid import NameOID return ( x509.CertificateBuilder() .subject_name( x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "modal-quic-demo")]) ) .issuer_name( x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "modal-quic-demo")]) ) .public_key(key.public_key()) .serial_number(x509.random_serial_number()) .not_valid_before(datetime.datetime.utcnow()) .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=1)) .sign(key, hashes.SHA256()) ) N_PINGS = 5 @app.function(image=image, region="jp") # Run in 🇯🇵. async def punch_and_quic(my_id: str, rendezvous_url: str, local_port: int = 5555): import asyncio import socket import ssl import aiohttp from aioquic.asyncio import connect, serve from aioquic.quic.configuration import QuicConfiguration from cryptography.hazmat.primitives.asymmetric import ec # 1. Discover public mapping via STUN. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.bind(("0.0.0.0", local_port)) sock.setblocking(False) pub_ip, pub_port = await get_ext_addr(sock) print(f"[{my_id}] Pub IP: {pub_ip}, Pub Port: {pub_port}") # 2. Register & wait for the peer's tuple. async with aiohttp.ClientSession() as s: while True: resp = await s.post( f"{rendezvous_url}/register", json={"peer_id": my_id, "ip": pub_ip, "port": pub_port}, ) if peer := (await resp.json()).get("peer"): peer_ip, peer_port = peer break await asyncio.sleep(1) print(f"[{my_id}] Punching {pub_ip}:{pub_port} -> {peer_ip}:{peer_port}") for _ in range(50): # 5s total. sock.sendto(b"punch", (peer_ip, peer_port)) try: await asyncio.wait_for(asyncio.get_event_loop().sock_recv(sock, 16), 0.1) break except asyncio.TimeoutError: continue else: raise RuntimeError("Hole punching failed – no response from peer") print(f"[{my_id}] Punched {pub_ip}:{pub_port} -> {peer_ip}:{peer_port}") sock.close() # Close socket. Mapping should stay alive. is_client = my_id == "B" cfg = QuicConfiguration( is_client=is_client, alpn_protocols=["hq-29"], verify_mode=ssl.CERT_NONE ) if not is_client: cfg.private_key = ec.generate_private_key(ec.SECP256R1()) cfg.certificate = create_cert(cfg.private_key) async def echo(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): for i in range(N_PINGS): data = await reader.read(100) if not data: break assert data == b"ping" writer.write(b"pong") await writer.drain() writer.close() def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): asyncio.create_task(echo(reader, writer)) await serve( host="0.0.0.0", port=local_port, # Use the punched port. configuration=cfg, stream_handler=handler, ) await asyncio.sleep(1) else: async with connect( peer_ip, peer_port, configuration=cfg, local_port=local_port, ) as quic: reader, writer = await quic.create_stream() total_latency = 0 for i in range(N_PINGS): start_time = time.monotonic() writer.write(b"ping") await writer.drain() print(f"[{my_id}] Sent ping {i + 1}") response = await reader.read(100) assert response == b"pong" end_time = time.monotonic() rtt = end_time - start_time total_latency += rtt print(f"[{my_id}] Received pong {i + 1}") print(f"[{my_id}] Round-trip time: {rtt * 1000:.2f}ms") await asyncio.sleep(0.1) writer.close() print(f"[{my_id}] Average rtt: {(total_latency / N_PINGS) * 1000:.2f}ms") @app.local_entrypoint() def main(): a = punch_and_quic.spawn(my_id="A", rendezvous_url=rendezvous.web_url) b = punch_and_quic.spawn(my_id="B", rendezvous_url=rendezvous.web_url) modal.FunctionCall.gather(a, b)