Skip to content

Instantly share code, notes, and snippets.

@jeromeku
Forked from vwxyzjn/vllm_forloop.py
Created August 11, 2025 13:19
Show Gist options
  • Save jeromeku/b2723e2bf42f5aad0c93392edaabbbc5 to your computer and use it in GitHub Desktop.
import time
from vllm import LLM, SamplingParams
from vllm.inputs import PromptType
from vllm.outputs import PoolingRequestOutput, RequestOutput
from typing import Union, cast, Sequence
from multiprocessing import Queue, Event
import threading
class MyLLM(LLM):
    """vLLM wrapper that decouples request submission from engine stepping.

    A background thread runs :meth:`keep_running` to drive the engine loop,
    while the owning thread feeds prompts in via :meth:`add_requests`.
    Finished request outputs are handed back through a queue.
    """

    def keep_running(
        self,
        *,
        stop_event: Event,
        output_queue: Queue,
    ):
        """Step the engine until *stop_event* is set.

        Batches of finished outputs from each engine step are pushed onto
        *output_queue*; when the engine is idle we sleep briefly to avoid
        a hot spin.
        """
        while not stop_event.is_set():
            if not self.llm_engine.has_unfinished_requests():
                # Nothing in flight — back off for 1 ms before re-checking.
                time.sleep(0.001)
                continue
            finished: list[Union[RequestOutput, PoolingRequestOutput]] = [
                step_output
                for step_output in self.llm_engine.step()
                if step_output.finished
            ]
            if finished:
                output_queue.put(finished)

    def add_requests(self, prompts: list[str], sampling_params: SamplingParams):
        """Register *prompts* with the engine without waiting for results."""
        request_batch = cast(Union[PromptType, Sequence[PromptType]], prompts)
        self._validate_and_add_requests(
            prompts=request_batch,
            params=sampling_params,
            lora_request=None,
        )
if __name__ == "__main__":
    # Small model with eager execution so the demo starts quickly.
    llm = MyLLM(model="HuggingFaceTB/SmolLM2-135M", enforce_eager=True)
    input_queue = Queue()
    output_queue = Queue()
    stop_event = Event()

    # Daemon thread drives the engine loop; it dies with the main process.
    engine_thread = threading.Thread(
        target=llm.keep_running,
        kwargs={"stop_event": stop_event, "output_queue": output_queue},
        daemon=True,
    )
    engine_thread.start()

    prompts = [
        "What is the capital of France?",
        "What is the capital of Germany?",
        "What is the capital of Italy?",
        "What is the capital of Spain?",
        "What is the capital of Portugal?",
    ]
    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)

    # Feed prompts one at a time, peeking at how many output batches queued up.
    for prompt in prompts:
        llm.add_requests([prompt], sampling_params)
        print(f"len of output queue: {output_queue.qsize()}")
        time.sleep(0.1)

    # Drain the queue until every prompt has produced a finished output.
    received, total = 0, len(prompts)
    while received < total:
        for output in output_queue.get():
            truncated_text = output.outputs[0].text[:40].strip()
            print(f"{received=} ==== {truncated_text}")
            received += 1
    stop_event.set()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment