Skip to content

Instantly share code, notes, and snippets.

@Leo-Lee15
Forked from vwxyzjn/vllm_forloop.py
Created August 11, 2025 04:05
Show Gist options
  • Save Leo-Lee15/62a72a9c3acc59b7e426bb195a8d4925 to your computer and use it in GitHub Desktop.
import queue
import threading
import time
from multiprocessing import Event, Queue
from typing import Sequence, Union, cast

from vllm import LLM, SamplingParams
from vllm.inputs import PromptType
from vllm.outputs import PoolingRequestOutput, RequestOutput
class MyLLM(LLM):
    """LLM subclass that exposes manual engine stepping.

    Instead of the blocking ``generate()`` API, requests are injected with
    :meth:`add_requests` and a background loop (:meth:`keep_running`) steps
    the engine, pushing finished outputs onto a queue as they complete.
    """

    def keep_running(
        self,
        *,
        stop_event: Event,
        output_queue: Queue,
    ):
        """Step the engine in a loop until *stop_event* is set.

        Each batch of finished request outputs is put on *output_queue* as a
        list. When the engine has no pending work, the loop sleeps briefly
        instead of busy-spinning.

        Args:
            stop_event: Signals the loop to exit.
            output_queue: Receives ``list[RequestOutput | PoolingRequestOutput]``
                batches of finished requests.
        """
        while not stop_event.is_set():
            if not self.llm_engine.has_unfinished_requests():
                time.sleep(0.001)  # idle back-off: avoid a hot spin
                continue
            step_outputs = self.llm_engine.step()
            # Forward only completed requests; in-flight ones stay in the engine.
            finished: list[Union[RequestOutput, PoolingRequestOutput]] = [
                out for out in step_outputs if out.finished
            ]
            if finished:
                output_queue.put(finished)

    def add_requests(self, prompts: list[str], sampling_params: SamplingParams):
        """Enqueue *prompts* on the engine without waiting for generation."""
        parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], prompts)
        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=None,
        )
if __name__ == "__main__":
llm = MyLLM(model="HuggingFaceTB/SmolLM2-135M", enforce_eager=True)
input_queue = Queue()
output_queue = Queue()
stop_event = Event()
threading.Thread(
target=llm.keep_running,
kwargs={"stop_event": stop_event, "output_queue": output_queue},
daemon=True,
).start()
prompts = [
"What is the capital of France?",
"What is the capital of Germany?",
"What is the capital of Italy?",
"What is the capital of Spain?",
"What is the capital of Portugal?",
]
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=100,
)
for prompt in prompts:
llm.add_requests([prompt], sampling_params)
print(f"len of output queue: {output_queue.qsize()}")
time.sleep(0.1)
received, total = 0, len(prompts)
while received < total:
outputs = output_queue.get()
for output in outputs:
truncated_text = output.outputs[0].text[:40].strip()
print(f"{received=} ==== {truncated_text}")
received += 1
stop_event.set()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment