import asyncio
import json
import os
import re
import typing
import uuid
from contextlib import contextmanager
from typing import Any

import numpy as np
import torch
import torch.distributed
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.tools import BaseTool
from langgraph.prebuilt import ToolNode
from omegaconf import DictConfig
from tensordict import TensorDict
from verl import DataProto
from verl.third_party.vllm import vllm_version
from verl.utils.torch_functional import pad_2d_list_to_length, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from vllm import LLM, SamplingParams
from vllm.distributed import parallel_state as vllm_ps

T = typing.TypeVar("T")


def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[int]:
    """Strip left padding from a single prompt and return the raw token ids."""
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids


def _repeat_interleave(value: torch.Tensor | np.ndarray, repeats: int) -> torch.Tensor | list[Any]:
    if isinstance(value, torch.Tensor):
        return value.repeat_interleave(repeats, dim=0)
    else:
        return np.repeat(value, repeats, axis=0)


def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64):
    """Build a mask that keeps tokens up to and including the first EOS token.

    The end-of-sentence token can be an int or a list, e.g. 1 or [1, 2].

    e.g. response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0],
                                     [78, 0, 76, 2, 1, 0, 0],
                                     [23, 98, 1, 0, 0, 0, 0],
                                     [33, 3, 98, 45, 1, 0, 0]])
    eos_token=1
    response_mask: tensor([[1, 1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 1, 1, 0, 0],
                           [1, 1, 1, 0, 0, 0, 0],
                           [1, 1, 1, 1, 1, 0, 0]])
    eos_token=[1, 2]
    response_mask: tensor([[1, 1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 0, 0, 0, 0],
                           [1, 1, 1, 1, 1, 0, 0]])
    """
    eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int()
    return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype)


def deserialize_tool_call(string: str) -> ToolCall:
    """Parse a JSON tool-call payload into a langchain ToolCall dict."""
    message = json.loads(string)
    assert isinstance(message, dict)
    request = {
        'name': message['name'],
        'args': message.get('arguments', {}),
        'id': str(uuid.uuid4()),
        'type': 'tool_call',
    }
    return request


# NOTE: the <tool_call> / </tool_call> tag literals below were reconstructed; they follow the
# Qwen/Hermes tool-calling format implied by the rest of this module (JSON payloads with
# "name"/"arguments" and <|im_start|> chat markers).
def validate_tool_calls(string: str) -> bool:
    """Return True if <tool_call> ... </tool_call> tags are balanced and not nested."""
    balance = 0
    for match in re.finditer(r'<tool_call>|</tool_call>', string):
        match = match.group()
        if match == '<tool_call>':
            if balance > 0:
                return False
            balance += 1
        else:
            if balance < 1:
                return False
            balance -= 1
    return balance == 0


def search_tool_calls(string: str) -> list[str]:
    """Extract the payloads of all well-formed <tool_call> ... </tool_call> blocks."""
    if not validate_tool_calls(string):
        return []
    try:
        pattern = r'<tool_call>((?:(?!</tool_call>).)*)</tool_call>'
        matches = re.finditer(pattern, string, re.DOTALL)
        return [match.group(1).strip() for match in matches]
    except Exception:
        return []
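
# Illustrative usage sketch (not called anywhere in the rollout path): how a model response with
# Qwen/Hermes-style tool tags would flow through the two helpers above. The sample text and tool
# name are hypothetical.
def _example_search_and_deserialize() -> list[ToolCall]:
    text = (
        "Let me look that up.\n"
        '<tool_call>{"name": "search", "arguments": {"query": "verl rollout"}}</tool_call>'
    )
    payloads = search_tool_calls(text)
    # -> ['{"name": "search", "arguments": {"query": "verl rollout"}}']
    return [deserialize_tool_call(p) for p in payloads]
    # -> [{'name': 'search', 'args': {'query': 'verl rollout'}, 'id': '<uuid4>', 'type': 'tool_call'}]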

async def execute_tool_calls(tool_runner: ToolNode, B_tool_calls: list[list[str]]) -> list[list[str]]:
    """Execute a batch of tool calls in parallel using the tool runner.

    Returns one JSON string of the form {"status": ..., "content": ...} per input call,
    aligned with the shape of `B_tool_calls`.
    """
    B_tool_responses = [[""] * len(_) for _ in B_tool_calls]
    scheduling = []
    tool_calls = []
    for i, strings in enumerate(B_tool_calls):
        for j, string in enumerate(strings):
            try:
                tool_calls.append(deserialize_tool_call(string))
                scheduling.append((i, j))
            except Exception:
                B_tool_responses[i][j] = json.dumps({
                    "status": "error",
                    "content": "tool call must be a JSON object with 'name' and (optional) 'arguments' fields",
                })
    message = AIMessage(content="", tool_calls=tool_calls)
    tool_responses = await tool_runner.ainvoke([message])
    for (i, j), tool_message in zip(scheduling, tool_responses):
        status, content = tool_message.status, tool_message.content
        if status == "error":
            content = content.replace("Error: ", "")
        content = content.strip()
        B_tool_responses[i][j] = json.dumps({"status": status, "content": content})
    return B_tool_responses


def run_coroutine_sync(coroutine: typing.Awaitable[T]) -> T:
    """Run an awaitable from synchronous code, reusing the current event loop when one exists."""
    try:
        eventloop = asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coroutine)
    else:
        if eventloop.is_running():
            future = asyncio.run_coroutine_threadsafe(coroutine, eventloop)
            return future.result()
        else:
            return eventloop.run_until_complete(coroutine)
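
# Illustrative sketch (not used by the rollout itself): wiring a toy tool through ToolNode,
# execute_tool_calls and run_coroutine_sync, mirroring how the tool turn is executed below.
# The `add` tool and its payload are hypothetical.
def _example_execute_tool_calls() -> list[list[str]]:
    from langchain_core.tools import tool

    @tool
    def add(a: int, b: int) -> int:
        """Add two integers."""
        return a + b

    runner = ToolNode([add])
    # one inner list per rollout; each entry is the JSON payload of one <tool_call> block
    batch = [['{"name": "add", "arguments": {"a": 1, "b": 2}}']]
    # returns JSON strings of the form {"status": ..., "content": ...}, aligned with `batch`
    return run_coroutine_sync(execute_tool_calls(runner, batch))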

class vLLMRollout(BaseRollout):

    def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs):
        """A vLLM rollout. It requires the module to be supported by vLLM.

        Args:
            model_path: path to the HuggingFace model served by vLLM
            config: DictConfig
            tokenizer: the task/model tokenizer
            model_hf_config: the huggingface config used to initialize the generating model in vllm
            **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
        """
        super().__init__()
        self.config = config
        assert not (not config.enforce_eager and config.free_cache_engine), \
            "disable CUDA graph (enforce_eager = False) if free cache engine"
        tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
        assert tensor_parallel_size <= torch.distributed.get_world_size(), \
            "tensor parallel size should be less than or equal to the world size"
        max_num_batched_tokens = self.config.get('max_num_batched_tokens', 8192)

        if kwargs.get('train_tp', None) is not None:
            # deployed with megatron
            os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0'
            os.environ['MEGATRON_IMPORT_TIMERS'] = '0'
            if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'):
                train_tp = kwargs.get('train_tp', None)
                num_tp_per_train_tp = train_tp // tensor_parallel_size
                vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size,
                                                  num_tp_per_train_tp=num_tp_per_train_tp)
            else:
                vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size)

        assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
            "model context length should be greater than total sequence length"

        max_model_len = self.config.max_model_len if self.config.max_model_len \
            else config.prompt_length + config.response_length
        max_model_len = int(max_model_len)
        if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill:
            raise ValueError('Chunked prefill is enabled but max_num_batched_tokens is smaller than '
                             'max_model_len; please increase max_num_batched_tokens or disable chunked prefill')

        trust_remote_code = kwargs.get('trust_remote_code', False)
        load_format = 'dummy' if config.load_format.startswith('dummy') else config.load_format
        self.inference_engine = LLM(
            model=model_path,
            enable_sleep_mode=True,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend="external_launcher",
            dtype=config.dtype,
            enforce_eager=config.enforce_eager,
            gpu_memory_utilization=config.gpu_memory_utilization,
            disable_custom_all_reduce=True,
            disable_mm_preprocessor_cache=True,
            skip_tokenizer_init=False,
            max_model_len=max_model_len,
            load_format=load_format,
            disable_log_stats=config.disable_log_stats,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=config.enable_chunked_prefill,
            enable_prefix_caching=True,
            trust_remote_code=trust_remote_code,
            seed=int(os.getenv("RANK", "0")) // tensor_parallel_size,
        )

        # Offload vllm model to reduce peak memory usage
        self.inference_engine.sleep(level=1)

        kwargs = dict(
            n=1,
            logprobs=0,  # can be set to 0 and let the actor recompute
            max_tokens=config.response_length,
        )
        # we may detokenize the result all together later
        if vllm_version != '0.3.1':
            kwargs['detokenize'] = False

        # support adding any sampling params from the config file
        for k in config.keys():
            if hasattr(SamplingParams(), str(k)):
                kwargs[k] = config.get(k)
        print(f"kwargs: {kwargs}")
        self.sampling_params = SamplingParams(**kwargs)
        self.pad_token_id = tokenizer.pad_token_id

    @contextmanager
    def update_sampling_params(self, **kwargs):
        # update sampling params
        old_sampling_params_args = {}
        if kwargs:
            for key, value in kwargs.items():
                if hasattr(self.sampling_params, key):
                    old_value = getattr(self.sampling_params, key)
                    old_sampling_params_args[key] = old_value
                    setattr(self.sampling_params, key, value)
        yield
        # roll back to previous sampling params
        # if len(old_sampling_params_args):
        for key, value in old_sampling_params_args.items():
            setattr(self.sampling_params, key, value)

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        # rebuild vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.init_cache_engine()

        idx = prompts.batch['input_ids']  # (bs, prompt_length)
        # left-padded attention_mask
        attention_mask = prompts.batch['attention_mask']
        position_ids = prompts.batch['position_ids']
        # used to construct attention_mask
        eos_token_id = prompts.meta_info['eos_token_id']
        batch_size = idx.size(0)

        non_tensor_batch = prompts.non_tensor_batch
        if 'raw_prompt_ids' not in non_tensor_batch:
            non_tensor_batch['raw_prompt_ids'] = np.array(
                [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object)
        if batch_size != len(non_tensor_batch['raw_prompt_ids']):
            raise RuntimeError('vllm sharding manager is not working properly.')

        if 'multi_modal_data' in non_tensor_batch:
            vllm_inputs = []
            for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop('raw_prompt_ids'),
                                                        non_tensor_batch.pop('multi_modal_data')):
                vllm_inputs.append({'prompt_token_ids': raw_prompt_ids, 'multi_modal_data': multi_modal_data})
        else:
            vllm_inputs = [{
                'prompt_token_ids': raw_prompt_ids
            } for raw_prompt_ids in non_tensor_batch.pop('raw_prompt_ids')]

        # ensure the type of `prompt_token_ids` passed to vllm is list[int]
        # https://github.com/volcengine/verl/pull/772
        for input_data in vllm_inputs:
            if isinstance(input_data['prompt_token_ids'], np.ndarray):
                input_data['prompt_token_ids'] = input_data['prompt_token_ids'].tolist()
            elif not isinstance(input_data['prompt_token_ids'], list):
                raise TypeError(
                    f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}")

        do_sample = prompts.meta_info.get('do_sample', True)
        is_validate = prompts.meta_info.get('validate', False)
        if not do_sample:
            kwargs = {
                'best_of': 1,
                'top_p': 1.0,
                'top_k': -1,
                'min_p': 0.0,
                'temperature': 0,
                'n': 1  # if greedy, only 1 response
            }
        elif is_validate:
            # TODO: try **
            kwargs = {
                'top_k': self.config.val_kwargs.top_k,
                'top_p': self.config.val_kwargs.top_p,
                'temperature': self.config.val_kwargs.temperature,
                'n': 1,  # if validate, already repeated in ray_trainer
            }

        # users can customize different sampling_params for different runs
        with self.update_sampling_params(**kwargs):
            outputs = self.inference_engine.generate(
                prompts=vllm_inputs,  # already converted to prompt token ids
                sampling_params=self.sampling_params,
                use_tqdm=False)

        # TODO(sgm): disable logprob when recompute_log_prob is enabled
        # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
        response = []
        for output in outputs:
            for sample_id in range(len(output.outputs)):
                response.append(output.outputs[sample_id].token_ids)

        response = pad_2d_list_to_length(response, self.pad_token_id,
                                         max_length=self.config.response_length).to(idx.device)

        if self.sampling_params.n > 1 and do_sample:
            idx = _repeat_interleave(idx, self.sampling_params.n)
            attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
            position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
            batch_size = batch_size * self.sampling_params.n
            if 'multi_modal_inputs' in non_tensor_batch.keys():
                non_tensor_batch['multi_modal_inputs'] = _repeat_interleave(non_tensor_batch['multi_modal_inputs'],
                                                                            self.sampling_params.n)

        seq = torch.cat([idx, response], dim=-1)

        response_length = response.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)
        if position_ids.dim() == 3:  # qwen2vl mrope
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
        response_position_ids = position_ids[:, -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        response_attention_mask = get_response_mask(response_id=response,
                                                    eos_token=eos_token_id,
                                                    dtype=attention_mask.dtype)
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

        # all the tp ranks should contain the same data here. data in all ranks are valid
        batch = TensorDict(
            {
                'prompts': idx,
                'responses': response,
                'input_ids': seq,  # here input_ids become the whole sentences
                # 'old_log_probs': log_probs,  # we will recompute old log prob with actor
                'attention_mask': attention_mask,
                'position_ids': position_ids
            },
            batch_size=batch_size)

        # free vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.free_cache_engine()

        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)


class vLLMRolloutWithTool(vLLMRollout):

    def __init__(
        self,
        model_path: str,
        config: DictConfig,
        tokenizer,
        model_hf_config,
        toolkit: list[BaseTool | typing.Callable[..., Any]],
        **kwargs
    ):
        super().__init__(model_path, config, tokenizer, model_hf_config, **kwargs)
        self.tokenizer = tokenizer
        self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
        self.gen_str = "\n<|im_start|>assistant\n"
        self.gen_ids = self.tokenizer.encode(self.gen_str)
        self.tool_runner = ToolNode(toolkit)

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        # rebuild vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.init_cache_engine()

        ori_input_ids = prompts.batch['input_ids']  # (bs, prompt_length)
        # left-padded attention_mask
        attention_mask = prompts.batch['attention_mask']
        position_ids = prompts.batch['position_ids']
        # used to construct attention_mask
        # eos_token_id = prompts.meta_info['eos_token_id']
        batch_size = ori_input_ids.size(0)

        idx_list = []
        # parse idx from torch.Tensor to list[list[int]]
        for i in range(batch_size):
            idx_list.append(_pre_process_inputs(self.pad_token_id, ori_input_ids[i]))

        do_sample = prompts.meta_info.get('do_sample', True)
        is_validate = prompts.meta_info.get('validate', False)
        if not do_sample:
            kwargs = {
                'best_of': 1,
                'top_p': 1.0,
                'top_k': -1,
                'min_p': 0.0,
                'temperature': 0,
                'n': 1  # if greedy, only 1 response
            }
        elif is_validate:
            # TODO: try **
            kwargs = {
                'top_k': self.config.val_kwargs.top_k,
                'top_p': self.config.val_kwargs.top_p,
                'temperature': self.config.val_kwargs.temperature,
                'n': 1,  # if validate, already repeated in ray_trainer
            }

        with self.update_sampling_params(**kwargs):
            # prepare n copies for each input
            curr_inputs = []
            for input_ids in idx_list:
                for _ in range(self.sampling_params.n):
                    curr_inputs.append(input_ids.copy())
            init_inputs = [ids.copy() for ids in curr_inputs]

            # track the status of each input
            curr_max_tokens = [self.sampling_params.max_tokens] * len(curr_inputs)
            active_indices = list(range(len(curr_inputs)))
            # collect the result mask of each rollout, 1 for non-result, 0 for tool call result or pad
            result_mask_list = [[] for _ in range(len(curr_inputs))]

            # generate until all inputs are completed
            for step in range(self.config.max_turns):
                if len(active_indices) == 0:
                    break

                # only process the active inputs
                active_inputs = [curr_inputs[i] for i in active_indices]
                active_max_tokens = [curr_max_tokens[i] for i in active_indices]

                # at most 512 new tokens per turn; add <|im_start|> as a stop token for the corner
                # case where the model tries to open a new turn on its own
                with self.update_sampling_params(
                    n=1,
                    max_tokens=min(512, max(active_max_tokens)),
                    stop_token_ids=[151644],
                    top_p=0.99,
                ):
                    vllm_inputs = [{
                        'prompt_token_ids': raw_prompt_ids
                    } for raw_prompt_ids in active_inputs]
                    outputs = self.inference_engine.generate(
                        prompts=vllm_inputs,
                        sampling_params=self.sampling_params,
                        use_tqdm=False
                    )

                # collect all tool calls
                tool_calls_list: list[list[str]] = []
                call_indices: list[int] = []

                # process each output
                new_active_indices = []
                for i, idx in enumerate(active_indices):
                    output_ids = outputs[i].outputs[0].token_ids
                    finish_reason = outputs[i].outputs[0].finish_reason
                    stop_reason = outputs[i].outputs[0].stop_reason
                    if finish_reason == 'stop' and (stop_reason is None or stop_reason == self.tokenizer.pad_token_id):
                        curr_inputs[idx] += output_ids
                        result_mask_list[idx] += [1] * len(output_ids)
                        output_str = self.tokenizer.decode(output_ids)
                        tool_calls = search_tool_calls(output_str)
                        if tool_calls:
                            tool_calls_list.append(tool_calls)
                            call_indices.append(idx)
                            new_active_indices.append(idx)
                        else:
                            pass  # no tool calls, this rollout is finished
                    elif finish_reason == 'length':
                        # output exceeded max tokens
                        curr_inputs[idx] += output_ids
                        result_mask_list[idx] += [1] * len(output_ids)
                    elif finish_reason == 'stop' and stop_reason == 151644:
                        # 151644 is the id of <|im_start|>, an illegal stop; we stop here
                        curr_inputs[idx] += output_ids
                        result_mask_list[idx] += [1] * len(output_ids)
                    else:
                        raise ValueError(
                            f"unknown stop reason. finish_reason: {finish_reason}, stop_reason: {stop_reason}")

                # batch process tool calls
                if tool_calls_list:
                    # Only tp_rank 0 executes the tools
                    if self.tp_rank == 0:
                        tool_responses_list = run_coroutine_sync(
                            execute_tool_calls(self.tool_runner, tool_calls_list)
                        )
                        # Prepare data for broadcasting
                        broadcast_data = {
                            'tool_calls_list': tool_calls_list,
                            'call_indices': call_indices,
                            'tool_responses_list': tool_responses_list
                        }
                    else:
                        broadcast_data = None
                    broadcast_data = vllm_ps._TP.broadcast_object(broadcast_data, src=0)

                    # All ranks process the broadcasted data
                    if broadcast_data is not None:
                        tool_calls_list = broadcast_data['tool_calls_list']
                        call_indices = broadcast_data['call_indices']
                        tool_responses_list = broadcast_data['tool_responses_list']
                        for idx, tool_calls, tool_responses in zip(call_indices, tool_calls_list, tool_responses_list):
                            tool_response_str = ''
                            for call, response in zip(tool_calls, tool_responses):
                                tool_response_str += f"{call}\n{response}\n\n"
                            tool_response_str = "\n<|im_start|>user\n" + tool_response_str + "<|im_end|>"
                            output_ids = self.tokenizer.encode(tool_response_str)
                            curr_inputs[idx] += output_ids
                            result_mask_list[idx] += [0] * len(output_ids)
                            curr_inputs[idx] += self.gen_ids
                            result_mask_list[idx] += [0] * len(self.gen_ids)

                # length check: if the response budget is exhausted, truncate and drop from active;
                # otherwise update curr_max_tokens
                length_checked_active_indices = []
                for idx in active_indices:
                    assert len(curr_inputs[idx]) - len(init_inputs[idx]) == len(result_mask_list[idx]), \
                        f"curr_inputs: {len(curr_inputs[idx])}, init_inputs: {len(init_inputs[idx])}, " \
                        f"result_mask_list: {len(result_mask_list[idx])}"
                    if len(curr_inputs[idx]) - len(init_inputs[idx]) >= self.config.response_length:
                        curr_inputs[idx] = init_inputs[idx] \
                            + curr_inputs[idx][len(init_inputs[idx]):len(init_inputs[idx]) + self.config.response_length]
                        result_mask_list[idx] = result_mask_list[idx][:self.config.response_length]
                    else:
                        curr_max_tokens[idx] = self.config.response_length - len(curr_inputs[idx]) + len(init_inputs[idx])
                        if idx in new_active_indices:
                            length_checked_active_indices.append(idx)
                active_indices = length_checked_active_indices

            output_ids_list = []
            # collect all rollouts
            for i, input_ids in enumerate(idx_list):
                for j in range(self.sampling_params.n):
                    idx = i * self.sampling_params.n + j
                    input_len = len(input_ids)
                    output_ids_list.append(curr_inputs[idx][input_len:])

        response_attention_mask_list = []
        response_list = []
        result_mask_list_padded = []
        for output_ids, result_mask in zip(output_ids_list, result_mask_list):
            assert len(output_ids) == len(result_mask), \
                f"output_ids: {len(output_ids)}, result_mask: {len(result_mask)}"
            # to tensor
            response = torch.tensor(output_ids, device=ori_input_ids.device)
            result_mask = torch.tensor(result_mask, device=ori_input_ids.device)
            # response attention mask, 1 for valid, 0 for invalid
            response_attention_mask = torch.ones_like(response, dtype=torch.int64)
            response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0)
            response_attention_mask_list.append(response_attention_mask)
            # response, pad to response_length
            response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
            response_list.append(response)
            # result mask, 1 for non-result, 0 for result or pad
            result_mask = pad_sequence_to_length(result_mask, self.config.response_length, 0)
            result_mask_list_padded.append(result_mask)
        response_attention_mask = torch.stack(response_attention_mask_list, dim=0)
        response = torch.stack(response_list, dim=0)
        result_mask = torch.stack(result_mask_list_padded, dim=0)

        if self.config.n > 1 and do_sample:
            ori_input_ids = ori_input_ids.repeat_interleave(self.config.n, dim=0)
            attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
            position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
            batch_size = batch_size * self.config.n
        seq = torch.cat([ori_input_ids, response], dim=-1)

        response_length = response.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
        response_position_ids = position_ids[:, -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)

        # concatenate attention_mask for input and response
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
        # result mask: result part is 0, other part is 1
        loss_mask = result_mask * response_attention_mask

        # all the tp ranks should contain the same data here. data in all ranks are valid
        batch = TensorDict({
            'prompts': ori_input_ids,
            'responses': response,
            'input_ids': seq,  # here input_ids become the whole sentences
            'attention_mask': attention_mask,
            'loss_mask': loss_mask,
            'position_ids': position_ids
        }, batch_size=batch_size)

        # free vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.free_cache_engine()

        return DataProto(batch=batch)
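

# Illustrative sketch of the config fields read by the rollouts above (key names are taken from
# the attribute accesses in this file; the values are made-up examples, not recommendations).
def _example_rollout_config() -> DictConfig:
    from omegaconf import OmegaConf

    return OmegaConf.create({
        'prompt_length': 1024,
        'response_length': 2048,
        'max_turns': 4,                 # tool-calling rounds in vLLMRolloutWithTool
        'n': 1,                         # rollouts per prompt
        'dtype': 'bfloat16',
        'enforce_eager': True,          # must be True when free_cache_engine is True
        'free_cache_engine': True,
        'gpu_memory_utilization': 0.6,
        'tensor_model_parallel_size': 1,
        'max_num_batched_tokens': 8192,
        'max_model_len': None,          # falls back to prompt_length + response_length
        'enable_chunked_prefill': True,
        'load_format': 'dummy',
        'disable_log_stats': True,
        'temperature': 1.0,             # any key matching a SamplingParams field is forwarded
        'val_kwargs': {'top_k': -1, 'top_p': 1.0, 'temperature': 0.0},
    })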