Created
July 28, 2025 14:17
-
-
Save richardliaw/cff64a6a5551edab2a131fc98acf19d7 to your computer and use it in GitHub Desktop.
Revisions
-
richardliaw created this gist
Jul 28, 2025 .There are no files selected for viewing
#!/usr/bin/env python3
"""
vLLM multi-node deployment script.
Automatically handles Ray cluster setup and vLLM server launch.

This script simplifies vLLM multi-node deployment by automatically handling
the Ray cluster setup and vLLM server launch based on the current node's role,
eliminating the need for multiple terminals and manual Ray cluster management.

Usage Examples:
    Basic multi-node deployment:
        # Head node
        python3 run.py --head-ip 192.168.1.100 --tp 16
        # Worker node
        python3 run.py --head-ip 192.168.1.100 --tp 16

    With environment variables (network interface configuration):
        # Head node
        NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16
        # Worker node
        NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16

    With xpanes for parallel execution:
        xpanes -I {} -c "NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16" 192.168.1.100 192.168.1.101

    With model and additional vLLM arguments:
        python3 run.py --head-ip 192.168.1.100 --tp 16 --model meta-llama/Llama-2-70b-hf --vllm-args --max-model-len 4096 --gpu-memory-utilization 0.9

    Custom GPU count per node:
        GPUS_PER_NODE=4 python3 run.py --head-ip 192.168.1.100 --tp 16

Key Features:
- Automatic role detection based on IP comparison
- Environment variable support for network interface configuration
- Ray cluster size validation before starting vLLM server
- Proper cleanup and signal handling
- Single command per node with consistent arguments
"""
import subprocess
import socket
import sys
import time
import argparse
import os
import signal
import json
from typing import Optional, List


def get_local_ip() -> str:
    """Get the local IP address of this machine.

    Opens a UDP socket toward a public address purely to discover which
    local interface the OS would route through; no traffic is sent.
    Falls back to loopback if the lookup fails (e.g. no network).
    """
    try:
        # Connect to a remote address to determine local IP
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]
    except Exception:
        return "127.0.0.1"


def is_port_open(host: str, port: int, timeout: int = 5) -> bool:
    """Check if a port is open on a given host.

    Returns True only when a TCP connect succeeds within *timeout* seconds.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.settimeout(timeout)
            result = s.connect_ex((host, port))
            return result == 0
    except Exception:
        return False


def wait_for_ray_cluster(head_ip: str, port: int = 6379, timeout: int = 60) -> bool:
    """Wait for Ray cluster to be ready.

    Polls the head node's Ray port every 2 seconds until it is reachable
    or *timeout* seconds elapse. Returns True on success, False on timeout.
    """
    print(f"Waiting for Ray cluster at {head_ip}:{port}...")
    start_time = time.time()
    while time.time() - start_time < timeout:
        if is_port_open(head_ip, port):
            print("Ray cluster is ready!")
            return True
        time.sleep(2)
    print(f"Timeout: Ray cluster at {head_ip}:{port} not ready after {timeout}s")
    return False


def run_command(cmd: List[str], wait: bool = True) -> Optional[subprocess.Popen]:
    """Run a command using current environment variables.

    When *wait* is True the command runs synchronously and a non-zero exit
    status terminates this script (returns None on success). When *wait* is
    False the command is launched in the background and the Popen handle is
    returned so the caller can manage its lifetime.

    NOTE: the original annotation claimed ``-> subprocess.Popen`` but the
    wait=True path returns None; the annotation is corrected here.
    """
    print(f"Running: {' '.join(cmd)}")
    if wait:
        result = subprocess.run(cmd)
        if result.returncode != 0:
            print(f"Command failed with return code {result.returncode}")
            sys.exit(1)
        return None
    else:
        return subprocess.Popen(cmd)


def start_ray_head(ray_port: int = 6379) -> subprocess.Popen:
    """Start Ray head node.

    Launches ``ray start --head --block`` in the background and returns the
    process handle; --block keeps the process alive for lifecycle management.
    """
    print("Starting Ray head node...")
    cmd = ["ray", "start", "--head", f"--port={ray_port}", "--block"]
    return run_command(cmd, wait=False)


def start_ray_worker(head_ip: str, ray_port: int = 6379) -> subprocess.Popen:
    """Start Ray worker node.

    Launches ``ray start --address=<head>:<port> --block`` in the background
    and returns the process handle.
    """
    print(f"Starting Ray worker node, connecting to {head_ip}:{ray_port}...")
    cmd = ["ray", "start", f"--address={head_ip}:{ray_port}", "--block"]
    return run_command(cmd, wait=False)


def get_ray_cluster_size() -> int:
    """Get the current Ray cluster size (number of nodes).

    Returns 0 on any failure so callers can treat "unknown" as "not ready".
    """
    try:
        # Use ray status to get cluster information
        # NOTE(review): ``ray status --format json`` is not accepted by all
        # Ray CLI versions — verify against the installed Ray release; on
        # failure this degrades gracefully to returning 0.
        result = subprocess.run(["ray", "status", "--format", "json"],
                                capture_output=True, text=True, timeout=10)
        if result.returncode != 0:
            print("Warning: Could not get Ray cluster status")
            return 0
        status_data = json.loads(result.stdout)
        # Count nodes in the cluster
        node_count = len(status_data.get("cluster_state", {}).get("node_table", {}))
        return node_count
    except (subprocess.TimeoutExpired, json.JSONDecodeError, KeyError) as e:
        print(f"Warning: Error getting Ray cluster size: {e}")
        return 0
    except Exception as e:
        print(f"Warning: Unexpected error getting Ray cluster size: {e}")
        return 0


def wait_for_cluster_size(expected_nodes: int, timeout: int = 120) -> bool:
    """Wait for Ray cluster to reach expected size.

    Polls every 5 seconds; succeeds as soon as the node count reaches (or
    exceeds) *expected_nodes*. Returns False if *timeout* seconds elapse.
    """
    print(f"Waiting for Ray cluster to reach {expected_nodes} nodes...")
    start_time = time.time()
    while time.time() - start_time < timeout:
        current_size = get_ray_cluster_size()
        print(f"Current cluster size: {current_size}/{expected_nodes} nodes")
        if current_size >= expected_nodes:
            print(f"Ray cluster ready with {current_size} nodes!")
            return True
        time.sleep(5)
    print(f"Timeout: Ray cluster did not reach {expected_nodes} nodes after {timeout}s")
    return False


def start_vllm_server(tp: int, model: Optional[str] = None,
                      additional_args: Optional[List[str]] = None):
    """Start vLLM server.

    Runs ``vllm serve [model] -tp <tp> [additional_args...]`` synchronously;
    a non-zero exit terminates this script (via run_command).
    """
    print(f"Starting vLLM server with TP={tp}...")
    cmd = ["vllm", "serve"]
    if model:
        cmd.append(model)
    cmd.extend(["-tp", str(tp)])
    if additional_args:
        cmd.extend(additional_args)
    run_command(cmd, wait=True)


def cleanup_ray():
    """Clean up Ray processes.

    Best-effort ``ray stop``; failures are reported but never raised.
    """
    print("Cleaning up Ray...")
    try:
        subprocess.run(["ray", "stop"], check=False)
    except Exception as e:
        print(f"Error during Ray cleanup: {e}")


def signal_handler(signum, frame):
    """Handle interrupt signals by stopping Ray and exiting cleanly."""
    print("\nReceived interrupt signal, cleaning up...")
    cleanup_ray()
    sys.exit(0)


def parse_env_vars(env_list: List[str]) -> dict:
    """Parse environment variables from command line arguments.

    Accepts ``KEY=VALUE`` strings; only the first ``=`` splits, so values may
    themselves contain ``=``. Malformed entries are warned about and skipped.
    """
    env_vars = {}
    for env_var in env_list:
        if '=' in env_var:
            key, value = env_var.split('=', 1)
            env_vars[key] = value
        else:
            print(f"Warning: Invalid environment variable format: {env_var}")
    return env_vars


def calculate_expected_nodes(tp: int) -> int:
    """Calculate expected number of nodes based on TP size and available GPUs per node."""
    # Try to get GPU count per node from environment or assume 8 GPUs per node
    gpus_per_node = int(os.environ.get('GPUS_PER_NODE', '8'))
    expected_nodes = (tp + gpus_per_node - 1) // gpus_per_node  # Ceiling division
    return max(1, expected_nodes)


def main():
    parser = argparse.ArgumentParser(description="vLLM multi-node deployment")
    parser.add_argument("--head-ip", required=True,
                        help="IP address of the head node")
    parser.add_argument("--tp", type=int, required=True,
                        help="Tensor parallel size")
    parser.add_argument("--ray-port", type=int, default=6379,
                        help="Ray port (default: 6379)")
    parser.add_argument("--model", help="Model name/path for vLLM")
    # REMAINDER: everything after --vllm-args is forwarded verbatim to vllm serve
    parser.add_argument("--vllm-args", nargs=argparse.REMAINDER,
                        help="Additional arguments to pass to vLLM serve")
    parser.add_argument("--wait-timeout", type=int, default=60,
                        help="Timeout for waiting for Ray cluster (seconds)")
    parser.add_argument("--expected-nodes", type=int,
                        help="Expected number of nodes (auto-calculated if not provided)")
    parser.add_argument("--skip-cluster-check", action="store_true",
                        help="Skip Ray cluster size validation")
    args = parser.parse_args()

    # Set up signal handlers for cleanup
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Role detection: this node is the head iff its local IP equals --head-ip
    local_ip = get_local_ip()
    is_head_node = local_ip == args.head_ip
    print(f"Local IP: {local_ip}")
    print(f"Head IP: {args.head_ip}")
    print(f"Node role: {'HEAD' if is_head_node else 'WORKER'}")

    # Show environment variables that might affect Ray/NCCL
    env_vars_of_interest = ['NCCL_SOCKET_IFNAME', 'GLOO_SOCKET_IFNAME',
                            'CUDA_VISIBLE_DEVICES', 'GPUS_PER_NODE']
    active_env_vars = {k: v for k, v in os.environ.items() if k in env_vars_of_interest}
    if active_env_vars:
        print(f"Relevant environment variables: {active_env_vars}")

    ray_process = None
    try:
        if is_head_node:
            # Start Ray head node
            ray_process = start_ray_head(args.ray_port)
            # Wait a bit for Ray to start
            time.sleep(5)
            # Check cluster size before starting vLLM
            if not args.skip_cluster_check:
                expected_nodes = args.expected_nodes or calculate_expected_nodes(args.tp)
                print(f"Expected nodes for TP={args.tp}: {expected_nodes}")
                if not wait_for_cluster_size(expected_nodes, args.wait_timeout):
                    print("Warning: Cluster size check failed. Use --skip-cluster-check to bypass.")
                    sys.exit(1)
            # Start vLLM server on head node (blocks until the server exits)
            start_vllm_server(args.tp, args.model, args.vllm_args)
        else:
            # Wait for head node to be ready
            if not wait_for_ray_cluster(args.head_ip, args.ray_port, args.wait_timeout):
                print("Failed to connect to Ray head node")
                sys.exit(1)
            # Start Ray worker node
            ray_process = start_ray_worker(args.head_ip, args.ray_port)
            # Keep worker running
            print("Worker node started. Press Ctrl+C to stop.")
            try:
                ray_process.wait()
            except KeyboardInterrupt:
                pass
    except KeyboardInterrupt:
        print("\nInterrupted by user")
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)
    finally:
        # Cleanup: terminate the background Ray process, escalating to kill
        # if it does not exit within 10 seconds, then run `ray stop`.
        if ray_process:
            ray_process.terminate()
            try:
                ray_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                ray_process.kill()
        cleanup_ray()


if __name__ == "__main__":
    main()