Skip to content

Instantly share code, notes, and snippets.

@richardliaw
Created July 28, 2025 14:17
Show Gist options
  • Save richardliaw/cff64a6a5551edab2a131fc98acf19d7 to your computer and use it in GitHub Desktop.
Save richardliaw/cff64a6a5551edab2a131fc98acf19d7 to your computer and use it in GitHub Desktop.

Revisions

  1. richardliaw created this gist Jul 28, 2025.
    304 changes: 304 additions & 0 deletions vllm_multinode_script.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,304 @@
    #!/usr/bin/env python3
    """
    vLLM multi-node deployment script.
    Automatically handles Ray cluster setup and vLLM server launch.
    This script simplifies vLLM multi-node deployment by automatically handling the Ray cluster
    setup and vLLM server launch based on the current node's role, eliminating the need for
    multiple terminals and manual Ray cluster management.
    Usage Examples:
    Basic multi-node deployment:
    # Head node
    python3 run.py --head-ip 192.168.1.100 --tp 16
    # Worker node
    python3 run.py --head-ip 192.168.1.100 --tp 16
    With environment variables (network interface configuration):
    # Head node
    NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16
    # Worker node
    NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16
    With xpanes for parallel execution:
    xpanes -I {} -c "NCCL_SOCKET_IFNAME=bond0 GLOO_SOCKET_IFNAME=bond0 python3 run.py --head-ip 192.168.1.100 --tp 16" 192.168.1.100 192.168.1.101
    With model and additional vLLM arguments:
    python3 run.py --head-ip 192.168.1.100 --tp 16 --model meta-llama/Llama-2-70b-hf --vllm-args --max-model-len 4096 --gpu-memory-utilization 0.9
    Custom GPU count per node:
    GPUS_PER_NODE=4 python3 run.py --head-ip 192.168.1.100 --tp 16
    Key Features:
    - Automatic role detection based on IP comparison
    - Environment variable support for network interface configuration
    - Ray cluster size validation before starting vLLM server
    - Proper cleanup and signal handling
    - Single command per node with consistent arguments
    """

    import subprocess
    import socket
    import sys
    import time
    import argparse
    import os
    import signal
    import json
    from typing import Optional, List


    def get_local_ip() -> str:
    """Get the local IP address of this machine."""
    try:
    # Connect to a remote address to determine local IP
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
    s.connect(("8.8.8.8", 80))
    return s.getsockname()[0]
    except Exception:
    return "127.0.0.1"


    def is_port_open(host: str, port: int, timeout: int = 5) -> bool:
    """Check if a port is open on a given host."""
    try:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.settimeout(timeout)
    result = s.connect_ex((host, port))
    return result == 0
    except Exception:
    return False


    def wait_for_ray_cluster(head_ip: str, port: int = 6379, timeout: int = 60) -> bool:
    """Wait for Ray cluster to be ready."""
    print(f"Waiting for Ray cluster at {head_ip}:{port}...")
    start_time = time.time()

    while time.time() - start_time < timeout:
    if is_port_open(head_ip, port):
    print("Ray cluster is ready!")
    return True
    time.sleep(2)

    print(f"Timeout: Ray cluster at {head_ip}:{port} not ready after {timeout}s")
    return False


    def run_command(cmd: List[str], wait: bool = True) -> subprocess.Popen:
    """Run a command using current environment variables."""
    print(f"Running: {' '.join(cmd)}")

    if wait:
    result = subprocess.run(cmd)
    if result.returncode != 0:
    print(f"Command failed with return code {result.returncode}")
    sys.exit(1)
    return None
    else:
    return subprocess.Popen(cmd)


    def start_ray_head(ray_port: int = 6379) -> subprocess.Popen:
    """Start Ray head node."""
    print("Starting Ray head node...")
    cmd = ["ray", "start", "--head", f"--port={ray_port}", "--block"]
    return run_command(cmd, wait=False)


    def start_ray_worker(head_ip: str, ray_port: int = 6379) -> subprocess.Popen:
    """Start Ray worker node."""
    print(f"Starting Ray worker node, connecting to {head_ip}:{ray_port}...")
    cmd = ["ray", "start", f"--address={head_ip}:{ray_port}", "--block"]
    return run_command(cmd, wait=False)


    def get_ray_cluster_size() -> int:
    """Get the current Ray cluster size (number of nodes)."""
    try:
    # Use ray status to get cluster information
    result = subprocess.run(["ray", "status", "--format", "json"],
    capture_output=True, text=True, timeout=10)
    if result.returncode != 0:
    print("Warning: Could not get Ray cluster status")
    return 0

    status_data = json.loads(result.stdout)
    # Count nodes in the cluster
    node_count = len(status_data.get("cluster_state", {}).get("node_table", {}))
    return node_count

    except (subprocess.TimeoutExpired, json.JSONDecodeError, KeyError) as e:
    print(f"Warning: Error getting Ray cluster size: {e}")
    return 0
    except Exception as e:
    print(f"Warning: Unexpected error getting Ray cluster size: {e}")
    return 0


    def wait_for_cluster_size(expected_nodes: int, timeout: int = 120) -> bool:
    """Wait for Ray cluster to reach expected size."""
    print(f"Waiting for Ray cluster to reach {expected_nodes} nodes...")
    start_time = time.time()

    while time.time() - start_time < timeout:
    current_size = get_ray_cluster_size()
    print(f"Current cluster size: {current_size}/{expected_nodes} nodes")

    if current_size >= expected_nodes:
    print(f"Ray cluster ready with {current_size} nodes!")
    return True

    time.sleep(5)

    print(f"Timeout: Ray cluster did not reach {expected_nodes} nodes after {timeout}s")
    return False
    def start_vllm_server(tp: int, model: Optional[str] = None, additional_args: Optional[List[str]] = None):
    """Start vLLM server."""
    print(f"Starting vLLM server with TP={tp}...")
    cmd = ["vllm", "serve"]

    if model:
    cmd.extend([model])

    cmd.extend(["-tp", str(tp)])

    if additional_args:
    cmd.extend(additional_args)

    run_command(cmd, wait=True)


    def cleanup_ray():
    """Clean up Ray processes."""
    print("Cleaning up Ray...")
    try:
    subprocess.run(["ray", "stop"], check=False)
    except Exception as e:
    print(f"Error during Ray cleanup: {e}")


    def signal_handler(signum, frame):
    """Handle interrupt signals."""
    print("\nReceived interrupt signal, cleaning up...")
    cleanup_ray()
    sys.exit(0)


    def parse_env_vars(env_list: List[str]) -> dict:
    """Parse environment variables from command line arguments."""
    env_vars = {}
    for env_var in env_list:
    if '=' in env_var:
    key, value = env_var.split('=', 1)
    env_vars[key] = value
    else:
    print(f"Warning: Invalid environment variable format: {env_var}")
    return env_vars


    def calculate_expected_nodes(tp: int) -> int:
    """Calculate expected number of nodes based on TP size and available GPUs per node."""
    # Try to get GPU count per node from environment or assume 8 GPUs per node
    gpus_per_node = int(os.environ.get('GPUS_PER_NODE', '8'))
    expected_nodes = (tp + gpus_per_node - 1) // gpus_per_node # Ceiling division
    return max(1, expected_nodes)


def main():
    """Entry point: decide head/worker role by IP and orchestrate Ray + vLLM.

    Head node: start Ray head, optionally wait for the expected cluster size,
    then run `vllm serve` in the foreground. Worker node: wait for the head's
    Ray port, join the cluster, and block until interrupted. Ray is always
    stopped on the way out.
    """
    parser = argparse.ArgumentParser(description="vLLM multi-node deployment")
    parser.add_argument("--head-ip", required=True, help="IP address of the head node")
    parser.add_argument("--tp", type=int, required=True, help="Tensor parallel size")
    parser.add_argument("--ray-port", type=int, default=6379, help="Ray port (default: 6379)")
    parser.add_argument("--model", help="Model name/path for vLLM")
    parser.add_argument("--vllm-args", nargs=argparse.REMAINDER,
                        help="Additional arguments to pass to vLLM serve")
    parser.add_argument("--wait-timeout", type=int, default=60,
                        help="Timeout for waiting for Ray cluster (seconds)")
    parser.add_argument("--expected-nodes", type=int,
                        help="Expected number of nodes (auto-calculated if not provided)")
    parser.add_argument("--skip-cluster-check", action="store_true",
                        help="Skip Ray cluster size validation")

    args = parser.parse_args()

    # Set up signal handlers for cleanup
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Role is chosen purely by comparing the local IP to --head-ip, so every
    # node can be launched with the exact same command line.
    local_ip = get_local_ip()
    is_head_node = local_ip == args.head_ip

    print(f"Local IP: {local_ip}")
    print(f"Head IP: {args.head_ip}")
    print(f"Node role: {'HEAD' if is_head_node else 'WORKER'}")

    # Show environment variables that might affect Ray/NCCL
    env_vars_of_interest = ['NCCL_SOCKET_IFNAME', 'GLOO_SOCKET_IFNAME', 'CUDA_VISIBLE_DEVICES', 'GPUS_PER_NODE']
    active_env_vars = {k: v for k, v in os.environ.items() if k in env_vars_of_interest}
    if active_env_vars:
        print(f"Relevant environment variables: {active_env_vars}")

    ray_process = None

    try:
        if is_head_node:
            # Start Ray head node (background process; --block keeps it alive)
            ray_process = start_ray_head(args.ray_port)

            # Wait a bit for Ray to start
            time.sleep(5)

            # Check cluster size before starting vLLM so tensor parallelism
            # has enough GPUs available.
            if not args.skip_cluster_check:
                expected_nodes = args.expected_nodes or calculate_expected_nodes(args.tp)
                print(f"Expected nodes for TP={args.tp}: {expected_nodes}")

                if not wait_for_cluster_size(expected_nodes, args.wait_timeout):
                    print("Warning: Cluster size check failed. Use --skip-cluster-check to bypass.")
                    sys.exit(1)

            # Start vLLM server on head node (blocks until the server exits)
            start_vllm_server(args.tp, args.model, args.vllm_args)

        else:
            # Wait for head node to be ready before trying to join it
            if not wait_for_ray_cluster(args.head_ip, args.ray_port, args.wait_timeout):
                print("Failed to connect to Ray head node")
                sys.exit(1)

            # Start Ray worker node
            ray_process = start_ray_worker(args.head_ip, args.ray_port)

            # Keep worker running until the Ray process exits or Ctrl+C
            print("Worker node started. Press Ctrl+C to stop.")
            try:
                ray_process.wait()
            except KeyboardInterrupt:
                pass

    except KeyboardInterrupt:
        print("\nInterrupted by user")

    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)

    finally:
        # Cleanup: terminate the blocking `ray start` child (escalating to
        # kill after 10s), then run `ray stop` for good measure.
        if ray_process:
            ray_process.terminate()
            try:
                ray_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                ray_process.kill()

        cleanup_ray()


    if __name__ == "__main__":
    main()