import threading
import time
from multiprocessing import Event, Queue
from typing import Sequence, Union, cast

from vllm import LLM, SamplingParams
from vllm.inputs import PromptType
from vllm.outputs import PoolingRequestOutput, RequestOutput


class MyLLM(LLM):
    def keep_running(
        self,
        *,
        stop_event: Event,
        output_queue: Queue,
    ):
        """Drive the engine step loop until stop_event is set, pushing
        finished outputs onto output_queue."""
        while True:
            outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
            if stop_event.is_set():
                break
            # Avoid busy-spinning when there is nothing to process.
            if not self.llm_engine.has_unfinished_requests():
                time.sleep(0.001)
                continue
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
            if len(outputs) > 0:
                output_queue.put(outputs)

    def add_requests(self, prompts: list[str], sampling_params: SamplingParams):
        """Queue prompts on the engine without blocking for their results."""
        parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], prompts)
        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=None,
        )


if __name__ == "__main__":
    llm = MyLLM(model="HuggingFaceTB/SmolLM2-135M", enforce_eager=True)

    input_queue = Queue()
    output_queue = Queue()
    stop_event = Event()

    # Run the engine loop in a background thread so the main thread can keep
    # submitting requests while generation is in flight.
    threading.Thread(
        target=llm.keep_running,
        kwargs={"stop_event": stop_event, "output_queue": output_queue},
        daemon=True,
    ).start()

    prompts = [
        "What is the capital of France?",
        "What is the capital of Germany?",
        "What is the capital of Italy?",
        "What is the capital of Spain?",
        "What is the capital of Portugal?",
    ]
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=100,
    )

    # Submit prompts one at a time; generation proceeds concurrently in the
    # background thread.
    for prompt in prompts:
        llm.add_requests([prompt], sampling_params)
        print(f"len of output queue: {output_queue.qsize()}")
        time.sleep(0.1)

    # Drain the output queue until every prompt has produced a finished result.
    received, total = 0, len(prompts)
    while received < total:
        outputs = output_queue.get()
        for output in outputs:
            truncated_text = output.outputs[0].text[:40].strip()
            print(f"{received=} ==== {truncated_text}")
            received += 1

    stop_event.set()