Revisions

  1. lukestanley created this gist Jan 17, 2025.

llama_cpp_server_model_swapping_proxy_middleware.py (233 additions)
# Minimalist OpenAI-API-compatible llama.cpp server proxy/middleware with dynamic model switching.
# It manages model loading and auto-shutdown: seamless model hot-swapping and idle shutdown,
# while exposing llama.cpp's advanced features such as speculative decoding.
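# Usage (a sketch; assumes aiohttp is installed and the llama-server binary and GGUF
# paths configured below exist on your machine):
#   python llama_cpp_server_model_swapping_proxy_middleware.py
# Then point any OpenAI-compatible client at http://127.0.0.1:8000 (PROXY_PORT);
# requests are forwarded to llama.cpp, starting or hot-swapping the model as needed.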
import asyncio
import json
from datetime import datetime, timedelta
import subprocess
import aiohttp
from aiohttp import web

# Constants
PROXY_PORT = 8000  # Intended as an external port
SERVER_PORT = 8312  # Intended as an internal port
IDLE_TIMEOUT_SECONDS = 60 * 100  # 100 minutes
SERVER_READY_TIMEOUT_SECONDS = 90
DEFAULT_MODEL = "/fast/Meta-Llama-3.1-8B-Instruct-IQ4_XS.gguf"
DEFAULT_DRAFT_MODEL = "/fast/Llama-3.2-1B-Instruct-IQ4_XS.gguf"
CHUNK_SIZE = 4096
MAX_GPU_LAYERS = 99
DEFAULT_CTX_SIZE = 4096
DEFAULT_THREADS = 8

# -fa enables flash attention; -ngl and --gpu-layers-draft control GPU offload
# for the main and draft (speculative decoding) models.
LLAMA_SERVER_CMD = """
/fast/llama_gpu/bin/llama-server
--model {model}
--model-draft {draft_model}
--ctx-size {ctx_size}
--threads {threads}
--port {port}
-fa
-ngl {gpu_layers}
--gpu-layers-draft {gpu_layers}
"""

KEY_MODEL = "model"
KEY_DRAFT_MODEL = "draft_model"
KEY_PROCESS = "process"
KEY_LAST_REQUEST = "last_request"

# Shared State
state = {
    KEY_PROCESS: None,
    KEY_LAST_REQUEST: None,
    KEY_MODEL: DEFAULT_MODEL,
    KEY_DRAFT_MODEL: DEFAULT_DRAFT_MODEL,
}


async def is_server_ready():
    """Check if the Llama server is ready."""
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"http://127.0.0.1:{SERVER_PORT}/health") as resp:
                return resp.status == 200
    except aiohttp.ClientConnectorError:
        return False


async def wait_for_server_ready():
    """Wait for the Llama server to be ready."""
    start_time = datetime.now()
    while not await is_server_ready():
        if datetime.now() - start_time > timedelta(
            seconds=SERVER_READY_TIMEOUT_SECONDS
        ):
            raise Exception("Server failed to become ready")
        await asyncio.sleep(1)


async def start_server(
    model=None, draft_model=None, ctx_size=DEFAULT_CTX_SIZE, threads=DEFAULT_THREADS
):
    """Start the Llama server."""
    model = model or state[KEY_MODEL]
    draft_model = draft_model or state[KEY_DRAFT_MODEL]

    if state[KEY_PROCESS] and (
        state[KEY_MODEL] != model or state[KEY_DRAFT_MODEL] != draft_model
    ):
        print("Stopping server to switch models...")
        await stop_server()

    if state[KEY_PROCESS] is None:
        cmd = LLAMA_SERVER_CMD.format(
            model=model,
            draft_model=draft_model,
            ctx_size=ctx_size,
            threads=threads,
            port=SERVER_PORT,
            gpu_layers=MAX_GPU_LAYERS,
        ).split()

        state[KEY_PROCESS] = await asyncio.create_subprocess_exec(
            *cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        asyncio.create_task(stream_subprocess_output(state[KEY_PROCESS]))

        state[KEY_MODEL] = model
        state[KEY_DRAFT_MODEL] = draft_model
        state[KEY_LAST_REQUEST] = datetime.now()

        print(
            f"Llama server started on port {SERVER_PORT} with models: {model}, {draft_model}"
        )
        await wait_for_server_ready()


async def stop_server():
    """Stop the Llama server."""
    if state[KEY_PROCESS]:
        print("Terminating Llama server process...")
        state[KEY_PROCESS].terminate()
        await state[KEY_PROCESS].wait()
        print("Llama server process terminated.")
        state[KEY_PROCESS] = None


async def stream_subprocess_output(process):
    """Stream subprocess output."""

    async def stream_pipe(pipe):
        while line := await pipe.readline():
            print(line.decode(), end="")

    await asyncio.gather(stream_pipe(process.stdout), stream_pipe(process.stderr))


async def monitor_idle_timeout():
    """Monitor server idle timeout."""
    while True:
        await asyncio.sleep(5)
        if state[KEY_LAST_REQUEST] and datetime.now() - state[
            KEY_LAST_REQUEST
        ] > timedelta(seconds=IDLE_TIMEOUT_SECONDS):
            print("Stopping server due to inactivity...")
            await stop_server()
            state[KEY_LAST_REQUEST] = None


def adapt_request_for_llama(headers, body):
    """Adapt request for Llama.cpp server."""
    try:
        request_json = json.loads(body)
        model = request_json.pop("model", None)
        draft_model = request_json.pop("draft_model", None)
        body = json.dumps(request_json).encode()
        # Drop any stale Content-Length from the client before setting it for the re-encoded body.
        headers.pop("Content-Length", None)
        headers.pop("content-length", None)
        headers["content-type"] = "application/json"
        headers["content-length"] = str(len(body))
        return headers, body, model, draft_model
    except json.JSONDecodeError:
        return headers, body, None, None


async def proxy_request(request):
    """Handle incoming requests using aiohttp.web."""
    try:
        # Read request body
        body = await request.read()
        headers = dict(request.headers)

        # Adapt request for Llama
        headers, body, model, draft_model = adapt_request_for_llama(headers, body)

        # Start the Llama server (if needed)
        await start_server(model, draft_model)
        state[KEY_LAST_REQUEST] = datetime.now()

        # Proxy to Llama.cpp server (path_qs keeps any query string intact)
        async with aiohttp.ClientSession() as session:
            async with session.request(
                method=request.method,
                url=f"http://127.0.0.1:{SERVER_PORT}{request.path_qs}",
                headers=headers,
                data=body,
            ) as llama_response:
                # Stream response from Llama server
                response = web.StreamResponse(
                    status=llama_response.status,
                    headers=llama_response.headers,
                )
                await response.prepare(request)

                async for chunk in llama_response.content.iter_any():
                    await response.write(chunk)

                await response.write_eof()
                return response

    except Exception as e:
        print(f"An error occurred: {e}")
        return web.Response(status=500, text="Internal Server Error")

async def main():
    """Main function."""
    await start_server()
    asyncio.create_task(monitor_idle_timeout())

    app = web.Application()
    app.router.add_route('*', '/{path:.*}', proxy_request)

    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, '127.0.0.1', PROXY_PORT)

    print(f"Serving on http://127.0.0.1:{PROXY_PORT}")
    await site.start()

    try:
        await asyncio.Future()  # run forever
    finally:
        await runner.cleanup()


asyncio.run(main())

    """
    Why not just use the Llama server directly?
    -That's fine until you want to switch models, or save resources when not in use.
    -It's bleeding edge and supports the latest models with speculative decoding to speed up completions.
    -It also can support JSON schema enforcement and other great features.
    Ollama is a similar project but it is more complicated and lags behind llama.cpp.
    Python-llama-cpp is a Python wrapper around llama.cpp but it is not as feature rich as llama.cpp.
    This is a minimalist server that is easy to understand and modify.
    The standard OpenAI API compatibility is implemented by Llama.cpp Server.
    The OpenAI style API is even offered by Google's Gemini API.
    """

    """
    This works but how could it be more minimalist while preserving the primary functionality?
    httpx may help. FastAPI and Flask would not.
    This is middleware.
    FastAPI is a nice wrapper around Starlette, which is a nice wrapper around anyio IIRC.
    config managment is not a big chunk of the code. Celery would be overkill.
    We don't need logging libraries.
    I want to keep everything in one file.
    """