Last active: May 30, 2025 13:51
Revisions
- DiTo97 revised this gist on May 30, 2025. No changes.
- DiTo97 revised this gist on May 21, 2025. 1 changed file with 0 additions and 3 deletions.
@@ -1,6 +1,3 @@
import asyncio
import json
import os
- DiTo97 revised this gist on May 20, 2025. 1 changed file with 100 additions and 100 deletions.
@@ -29,6 +29,19 @@
T = typing.TypeVar("T")


def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[int]:
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids


def _repeat_interleave(value: torch.Tensor | np.ndarray, repeats: int) -> torch.Tensor | list[Any]:
    if isinstance(value, torch.Tensor):
        return value.repeat_interleave(repeats, dim=0)
    else:
        return np.repeat(value, repeats, axis=0)


def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64):
    """
    end of sentence token can be int or list: 1 or [1, 2]

@@ -52,17 +65,96 @@ def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2,
    return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype)


def deserialize_tool_call(string: str) -> ToolCall:
    message = json.loads(string)
    assert isinstance(message, dict)

    request = {
        'name': message['name'],
        'args': message.get('arguments', {}),
        'id': str(uuid.uuid4()),
        'type': 'tool_call',
    }

    return request


def validate_tool_calls(string: str) -> bool:
    balance = 0

    for match in re.finditer(r'</?tool_call>', string):
        match = match.group()

        if match == '<tool_call>':
            if balance > 0:
                return False
            balance += 1
        else:
            if balance < 1:
                return False
            balance -= 1

    return balance == 0


def search_tool_calls(string: str) -> list[str]:
    if not validate_tool_calls(string):
        return []

    try:
        pattern = r'<tool_call>((?:(?!</tool_call>).)*)</tool_call>'
        matches = re.finditer(pattern, string, re.DOTALL)
        return [match.group(1).strip() for match in matches]
    except Exception:
        return []


async def execute_tool_calls(tool_runner: ToolNode, B_tool_calls: list[list[str]]) -> list[list[str]]:
    """executes a batch of tool calls in parallel using the tool runner."""
    B_tool_responses = [[""] * len(_) for _ in B_tool_calls]

    scheduling = []
    tool_calls = []

    for i, strings in enumerate(B_tool_calls):
        for j, string in enumerate(strings):
            try:
                tool_calls.append(deserialize_tool_call(string))
                scheduling.append((i, j))
            except Exception as e:
                B_tool_responses[i][j] = json.dumps({"status": "error", "content": "tool call must be a JSON object with 'name' and (optional) 'arguments' fields"})

    message = AIMessage(content="", tool_calls=tool_calls)

    tool_responses = await tool_runner.ainvoke([message])

    for (i, j), tool_message in zip(scheduling, tool_responses):
        status, content = tool_message.status, tool_message.content

        if status == "error":
            content = content.replace("Error: ", "")

        content = content.strip()

        B_tool_responses[i][j] = json.dumps({"status": status, "content": content})

    return B_tool_responses


def run_coroutine_sync(coroutine: typing.Awaitable[T]) -> T:
    try:
        eventloop = asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coroutine)
    else:
        if eventloop.is_running():
            future = asyncio.run_coroutine_threadsafe(coroutine, eventloop)
            return future.result()
        else:
            return eventloop.run_until_complete(coroutine)


class vLLMRollout(BaseRollout):

@@ -300,98 +392,6 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)


class vLLMRolloutWithTool(vLLMRollout):
    def __init__(
        self,
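A minimal usage sketch for the tag-parsing helpers introduced in this revision, assuming validate_tool_calls, search_tool_calls, and deserialize_tool_call are importable from the module above; the completion string is made up:

# a made-up model completion with one well-formed tool call
completion = (
    "<think>I should look up the weather first.</think>\n"
    "<tool_call>\n"
    '{"name": "get_weather", "arguments": {"city": "Beijing"}}\n'
    "</tool_call>"
)

assert validate_tool_calls(completion)       # tags are balanced and not nested
calls = search_tool_calls(completion)        # ['{"name": "get_weather", "arguments": {"city": "Beijing"}}']
request = deserialize_tool_call(calls[0])    # {'name': 'get_weather', 'args': {'city': 'Beijing'}, 'id': '...', 'type': 'tool_call'}

# unbalanced or nested tags are rejected outright, so nothing is extracted
assert search_tool_calls("<tool_call>missing closing tag") == []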
- DiTo97 revised this gist on May 20, 2025. 2 changed files with 101 additions and 173 deletions.
@@ -1,17 +1,21 @@
"""
https://github.com/Agent-RL/ReCall/blob/3d976d26ade4950bc491335bb80da1659424b3cb/src/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
"""
import asyncio
import json
import os
import re
import typing
import uuid
from contextlib import contextmanager
from typing import Any

import numpy as np
import torch
import torch.distributed
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.tools import BaseTool
from langgraph.prebuilt import ToolNode
from omegaconf import DictConfig
from tensordict import TensorDict
from verl import DataProto

@@ -22,6 +26,9 @@
from vllm.distributed import parallel_state as vllm_ps


T = typing.TypeVar("T")


def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64):
    """
    end of sentence token can be int or list: 1 or [1, 2]

@@ -293,113 +300,117 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)


def deserialize_tool_call(string: str) -> ToolCall:
    message = json.loads(string)
    assert isinstance(message, dict)

    request = {
        'name': message['name'],
        'args': message.get('arguments', {}),
        'id': str(uuid.uuid4()),
        'type': 'tool_call',
    }

    return request


def validate_tool_calls(string: str) -> bool:
    balance = 0

    for match in re.finditer(r'</?tool_call>', string):
        match = match.group()

        if match == '<tool_call>':
            if balance > 0:
                return False
            balance += 1
        else:
            if balance < 1:
                return False
            balance -= 1

    return balance == 0


def search_tool_calls(string: str) -> list[str]:
    if not validate_tool_calls(string):
        return []

    try:
        pattern = r'<tool_call>((?:(?!</tool_call>).)*)</tool_call>'
        matches = re.finditer(pattern, string, re.DOTALL)
        return [match.group(1).strip() for match in matches]
    except Exception:
        return []


async def execute_tool_calls(tool_runner: ToolNode, B_tool_calls: list[list[str]]) -> list[list[str]]:
    """executes a batch of tool calls in parallel using the tool runner."""
    B_tool_responses = [[""] * len(_) for _ in B_tool_calls]

    scheduling = []
    tool_calls = []

    for i, strings in enumerate(B_tool_calls):
        for j, string in enumerate(strings):
            try:
                tool_calls.append(deserialize_tool_call(string))
                scheduling.append((i, j))
            except Exception as e:
                B_tool_responses[i][j] = json.dumps({"status": "error", "content": "tool call must be a JSON object with 'name' and (optional) 'arguments' fields"})

    message = AIMessage(content="", tool_calls=tool_calls)

    tool_responses = await tool_runner.ainvoke([message])

    for (i, j), tool_message in zip(scheduling, tool_responses):
        status, content = tool_message.status, tool_message.content

        if status == "error":
            content = content.replace("Error: ", "")

        content = content.strip()

        B_tool_responses[i][j] = json.dumps({"status": status, "content": content})

    return B_tool_responses


def run_coroutine_sync(coroutine: typing.Awaitable[T]) -> T:
    try:
        eventloop = asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coroutine)
    else:
        if eventloop.is_running():
            future = asyncio.run_coroutine_threadsafe(coroutine, eventloop)
            return future.result()
        else:
            return eventloop.run_until_complete(coroutine)


class vLLMRolloutWithTool(vLLMRollout):
    def __init__(
        self,
        model_path: str,
        config: DictConfig,
        tokenizer,
        model_hf_config,
        toolkit: list[BaseTool | typing.Callable[..., Any]],
        **kwargs
    ):
        super().__init__(model_path, config, tokenizer, model_hf_config, **kwargs)

        self.tokenizer = tokenizer
        self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
        self.gen_str = "\n<|im_start|>assistant\n<think>"
        self.gen_ids = self.tokenizer.encode(self.gen_str)
        self.tool_runner = ToolNode(toolkit)

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        # rebuild vllm cache engine

@@ -449,14 +460,6 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
                    curr_inputs.append(input_ids.copy())
            init_inputs = [ids.copy() for ids in curr_inputs]

            # track the status of each input
            curr_max_tokens = [self.sampling_params.max_tokens] * len(curr_inputs)
            active_indices = list(range(len(curr_inputs)))

@@ -504,7 +507,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
                        result_mask_list[idx] += [1] * len(output_ids)
                        output_str = self.tokenizer.decode(output_ids)
                        tool_calls = search_tool_calls(output_str)
                        if tool_calls:
                            tool_calls_list.append(tool_calls)
                            call_indices.append(idx)

@@ -525,8 +528,9 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
                if tool_calls_list:
                    # Only tp_rank 0 executes the tools
                    if self.tp_rank == 0:
                        tool_responses_list = run_coroutine_sync(
                            execute_tool_calls(self.tool_runner, tool_calls_list)
                        )

                        # Prepare data for broadcasting
                        broadcast_data = {

@@ -1,76 +0,0 @@
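A minimal wiring sketch for the langgraph-based execution path added in this revision, assuming execute_tool_calls and run_coroutine_sync are importable from the module above; the get_weather tool is a made-up example:

from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode


@tool
def get_weather(city: str) -> str:
    """Returns a canned weather report for a city."""
    return f"sunny in {city}"


tool_runner = ToolNode([get_weather])

# one rollout containing a single serialized tool call (made-up payload)
B_tool_calls = [['{"name": "get_weather", "arguments": {"city": "Beijing"}}']]

# drive the coroutine from synchronous rollout code, as generate_sequences does
B_tool_responses = run_coroutine_sync(execute_tool_calls(tool_runner, B_tool_calls))
# B_tool_responses[0][0] == '{"status": "success", "content": "sunny in Beijing"}'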
- DiTo97 revised this gist on May 19, 2025. 2 changed files with 139 additions and 122 deletions.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,47 +1,64 @@ """ https://github.com/Agent-RL/ReCall/blob/3d976d26ade4950bc491335bb80da1659424b3cb/src/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py """ import json import os import re from concurrent.futures import ThreadPoolExecutor, as_completed from contextlib import contextmanager from typing import Any import numpy as np import requests import torch import torch.distributed from omegaconf import DictConfig from tensordict import TensorDict from verl import DataProto from verl.third_party.vllm import vllm_version from verl.utils.torch_functional import pad_2d_list_to_length, pad_sequence_to_length from verl.workers.rollout.base import BaseRollout from vllm import LLM, SamplingParams from vllm.distributed import parallel_state as vllm_ps def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64): """ end of sentence token can be int or list: 1 or [1, 2] e.g. response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0], [78, 0, 76, 2, 1, 0, 0], [23, 98, 1, 0, 0, 0, 0], [33, 3, 98, 45, 1, 0, 0]]) #eos_token=1 response_mask: tensor([[1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0]]) #eos_token=[1,2] response_mask: tensor([[1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0]]) """ eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int() return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype) def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[int]: non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] token_ids = prompt_token_ids[non_pad_index:].tolist() return token_ids def _repeat_interleave(value: torch.Tensor | np.ndarray, repeats: int) -> torch.Tensor | list[Any]: if isinstance(value, torch.Tensor): return value.repeat_interleave(repeats, dim=0) else: return np.repeat(value, repeats, axis=0) class vLLMRollout(BaseRollout): def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): """A vLLM rollout. It requires the module is supported by the vllm. 
@@ -275,114 +292,113 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) def format_tool_call(tool_call_str: str): """Convert JSON function call description to Python executable code string.""" try: call_json = json.loads(tool_call_str) func_name = call_json['name'] arguments = call_json.get('arguments', {}) args_str = ', '.join(f"{k}={repr(v)}" for k, v in arguments.items()) return f"{func_name}({args_str})" except Exception as e: return f"Parse tool call failed: {e}" def validate_tool_calls(self, output_str): start_tags = re.findall(r'<tool_call>', output_str) end_tags = re.findall(r'</tool_call>', output_str) if len(start_tags) != len(end_tags): return False start_positions = [m.start() for m in re.finditer(r'<tool_call>', output_str)] end_positions = [m.start() for m in re.finditer(r'</tool_call>', output_str)] for start, end in zip(start_positions, end_positions): if start >= end: return False return True def extract_tool_calls(self, output_str): if not validate_tool_calls(output_str): return [] try: pattern = r'<tool_call>((?:(?!</tool_call>).)*)</tool_call>' matches = re.finditer(pattern, output_str, re.DOTALL) return [match.group(1).strip() for match in matches] except Exception as e: return [] def batch_execute(sandbox_url: str, env_list: list[str], tool_calls_list: list[list[str]]): def exe_tool_call(env, call): url = f'{sandbox_url}/execute' call_str = format_tool_call(call) if call_str.startswith("Parse tool call failed"): return call_str try: data = { 'env': env, 'call': call_str } response = requests.post(url, json=data, timeout=10) if response.status_code != 200: return f"error: {response.status_code}" response = response.json() ret_str = '' if response['result']: ret_str += f'result: \n{response["result"]}\n' if response['output']: ret_str += f'output: \n{response["output"]}\n' if response['error']: ret_str += f'error: \n{response["error"]}\n' return ret_str.strip() except requests.exceptions.Timeout: return "error: execution timed out" except Exception as e: return str(e) # flatten all tasks all_tasks = [] task_indices = [] for env_idx, (env, tool_calls) in enumerate(zip(env_list, tool_calls_list)): for call_idx, tool_call in enumerate(tool_calls): all_tasks.append((env, tool_call)) task_indices.append((env_idx, call_idx)) # parallel execute all tasks all_results = [None] * len(all_tasks) with ThreadPoolExecutor(max_workers=8) as executor: future_to_index = {executor.submit(exe_tool_call, env, call): i for i, (env, call) in enumerate(all_tasks)} for future in as_completed(future_to_index): index = future_to_index[future] all_results[index] = future.result() # reorganize results to original structure results_list = [[None for _ in range(len(tool_calls_list[i]))] for i, _ in enumerate(env_list)] for (env_idx, call_idx), result in zip(task_indices, all_results): results_list[env_idx][call_idx] = result return results_list class vLLMRolloutWithTool(vLLMRollout): def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): super().__init__(model_path, config, tokenizer, model_hf_config, **kwargs) self.tokenizer = tokenizer self.tp_rank = vllm_ps.get_tensor_model_parallel_rank() self.gen_str = "\n<|im_start|>assistant\n<think>" self.gen_ids = self.tokenizer.encode(self.gen_str) @torch.no_grad() def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: @@ -395,13 +411,13 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> 
DataProto: attention_mask = prompts.batch['attention_mask'] position_ids = prompts.batch['position_ids'] # # used to construct attention_mask # eos_token_id = prompts.meta_info['eos_token_id'] batch_size = ori_input_ids.size(0) idx_list = [] # parse idx from torch.Tensor to list[list[str]] for i in range(batch_size): idx_list.append(_pre_process_inputs(self.pad_token_id, ori_input_ids[i])) @@ -473,8 +489,8 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: ) # collect all tool calls tool_calls_list: list[list[str]] = [] call_indices: list[int] = [] # process each output new_active_indices = [] @@ -488,7 +504,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: result_mask_list[idx] += [1] * len(output_ids) output_str = self.tokenizer.decode(output_ids) tool_calls: list[str] = extract_tool_calls(output_str) if tool_calls: tool_calls_list.append(tool_calls) call_indices.append(idx) @@ -510,7 +526,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # Only tp_rank 0 executes the tools if self.tp_rank == 0: active_env_list = [env_list[i] for i in call_indices] tool_responses_list = batch_execute(self.config.sandbox_url, active_env_list, tool_calls_list) # Prepare data for broadcasting broadcast_data = { This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1 +1,2 @@ langgraph>0.4,<1 verl[vllm]==0.3.0.post1 -
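In the revision above, tool calls are executed through an external sandbox: format_tool_call converts each JSON call into a Python call string, which is then POSTed to {sandbox_url}/execute. A small sketch of that conversion, assuming format_tool_call as defined in the diff above:

call = '{"name": "search", "arguments": {"query": "vLLM rollout", "top_k": 3}}'
print(format_tool_call(call))        # search(query='vLLM rollout', top_k=3)
print(format_tool_call("not json"))  # Parse tool call failed: <JSONDecodeError message>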
- DiTo97 revised this gist on May 15, 2025. 1 changed file with 76 additions and 0 deletions.
@@ -0,0 +1,76 @@
import asyncio
import json
from typing import Any

from langchain_core.messages import AIMessage
from langchain_core.tools import BaseTool
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, ValidationError


class ToolCallResponse(BaseModel):
    id: str
    name: str
    content: Any


class ToolCallRequest(BaseModel):
    id: str
    name: str
    arguments: dict[str, Any] = {}


class ToolkitRunner:
    """A runner for efficient tool call execution with a toolkit."""

    def __init__(self, toolkit: list[BaseTool]):
        self.runner = ToolNode(toolkit)

    def available(self, request: ToolCallRequest) -> bool:
        return request.name in self.runner.tools_by_name

    async def execute(self, requests: list[str]) -> list[ToolCallResponse | None]:
        """executes tool call requests (JSON strings) asynchronously, preserving order."""
        responses = [None] * len(requests)

        mappings = {}
        payloads = []

        for i, string in enumerate(requests):
            try:
                payload = json.loads(string)
                payload = ToolCallRequest.parse_obj(payload)
            except (
                json.JSONDecodeError,
                ValidationError
            ):
                continue

            if self.available(payload):
                mappings[payload.id] = i
                payloads.append({
                    "id": payload.id,
                    "name": payload.name,
                    "args": payload.arguments,
                    "type": "tool_call"
                })
            else:
                responses[i] = ToolCallResponse(
                    id=payload.id,
                    name=payload.name,
                    content=f"tool '{payload.name}' not found in toolkit"
                )

        if payloads:
            # ToolNode consumes messages, so wrap the pending tool calls in a single assistant message
            messages = await self.runner.ainvoke([AIMessage(content="", tool_calls=payloads)])

            for message in messages:
                i = mappings.get(message.tool_call_id)

                if i is None:
                    continue

                responses[i] = ToolCallResponse(
                    id=message.tool_call_id,
                    name=message.name,
                    content=message.content
                )

        return responses
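A minimal usage sketch for ToolkitRunner above, assuming the class is importable; the multiply tool and the request ids are made up:

import asyncio

from langchain_core.tools import tool


@tool
def multiply(a: int, b: int) -> int:
    """Multiplies two integers."""
    return a * b


runner = ToolkitRunner([multiply])

requests = [
    '{"id": "call-1", "name": "multiply", "arguments": {"a": 6, "b": 7}}',  # executed
    '{"id": "call-2", "name": "divide", "arguments": {"a": 6, "b": 7}}',    # not in the toolkit
    "not even JSON",                                                        # unparseable, left as None
]

responses = asyncio.run(runner.execute(requests))
# responses[0].content is the tool output (the string "42"),
# responses[1].content reports the missing tool, and responses[2] is None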
DiTo97 revised this gist
May 15, 2025 . No changes.There are no files selected for viewing
-
DiTo97 revised this gist
May 14, 2025 . 1 changed file with 0 additions and 166 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,166 +0,0 @@ -
- DiTo97 created this gist on May 14, 2025.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,626 @@ """ https://github.com/Agent-RL/ReCall/blob/3d976d26ade4950bc491335bb80da1659424b3cb/src/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py """ import os import numpy as np from typing import List from contextlib import contextmanager from omegaconf import DictConfig import torch import torch.distributed from tensordict import TensorDict from torch import nn from typing import Any, Union from verl import DataProto from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length from verl.workers.rollout.base import BaseRollout from vllm.distributed import parallel_state as vllm_ps from vllm import LLM, SamplingParams from verl.third_party.vllm import vllm_version # TODO # 1. support pp in vllm # 2. passing tokenizer is not necessary? no encoding/decoding is happending here # 3. simplify init logics # NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding. def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]: # remove the left padding in the prompt token_id # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] token_ids = prompt_token_ids[non_pad_index:].tolist() return token_ids def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]: if isinstance(value, torch.Tensor): return value.repeat_interleave(repeats, dim=0) else: return np.repeat(value, repeats, axis=0) class vLLMRollout(BaseRollout): def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): """A vLLM rollout. It requires the module is supported by the vllm. 
Args: module: module here follows huggingface APIs config: DictConfig tokenizer: the task/model tokenizer model_hf_config: the huggingface config to initiallize the generating model in vllm **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group """ super().__init__() self.config = config assert not (not config.enforce_eager and config.free_cache_engine), \ "disable CUDA graph (enforce_eager = False) if free cache engine" tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1) assert tensor_parallel_size <= torch.distributed.get_world_size(), \ "tensor parallel size should be less than or equal to the world size" max_num_batched_tokens = self.config.get('max_num_batched_tokens', 8192) if kwargs.get('train_tp', None) is not None: # deployed with megatron os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0' os.environ['MEGATRON_IMPORT_TIMERS'] = '0' if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'): train_tp = kwargs.get('train_tp', None) num_tp_per_train_tp = train_tp // tensor_parallel_size vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp) else: vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size) assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \ "model context length should be greater than total sequence length" max_model_len = self.config.max_model_len if self.config.max_model_len \ else config.prompt_length + config.response_length max_model_len = int(max_model_len) if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill: raise ValueError('Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ please increase max_num_batched_tokens or disable chunked prefill') trust_remote_code = kwargs.get('trust_remote_code', False) load_format = 'dummy' if config.load_format.startswith('dummy') else config.load_format self.inference_engine = LLM( model=model_path, enable_sleep_mode=True, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend="external_launcher", dtype=config.dtype, enforce_eager=config.enforce_eager, gpu_memory_utilization=config.gpu_memory_utilization, disable_custom_all_reduce=True, disable_mm_preprocessor_cache=True, skip_tokenizer_init=False, max_model_len=max_model_len, load_format=load_format, disable_log_stats=config.disable_log_stats, max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=config.enable_chunked_prefill, enable_prefix_caching=True, trust_remote_code=trust_remote_code, seed=int(os.getenv("RANK", "0")) // tensor_parallel_size, ) # Offload vllm model to reduce peak memory usage self.inference_engine.sleep(level=1) kwargs = dict( n=1, logprobs=0, # can be set to 0 and let actor to recompute max_tokens=config.response_length, ) # # we may detokenize the result all together later if vllm_version != '0.3.1': kwargs['detokenize'] = False # supporting adding any sampling params from the config file for k in config.keys(): if hasattr(SamplingParams(), str(k)): kwargs[k] = config.get(k) print(f"kwargs: {kwargs}") self.sampling_params = SamplingParams(**kwargs) self.pad_token_id = tokenizer.pad_token_id @contextmanager def update_sampling_params(self, **kwargs): # update sampling params old_sampling_params_args = {} if kwargs: for key, value in kwargs.items(): if hasattr(self.sampling_params, key): old_value = getattr(self.sampling_params, key) old_sampling_params_args[key] = 
old_value setattr(self.sampling_params, key, value) yield # roll back to previous sampling params # if len(old_sampling_params_args): for key, value in old_sampling_params_args.items(): setattr(self.sampling_params, key, value) @torch.no_grad() def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # rebuild vllm cache engine if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine: self.inference_engine.init_cache_engine() idx = prompts.batch['input_ids'] # (bs, prompt_length) # left-padded attention_mask attention_mask = prompts.batch['attention_mask'] position_ids = prompts.batch['position_ids'] # used to construct attention_mask eos_token_id = prompts.meta_info['eos_token_id'] batch_size = idx.size(0) non_tensor_batch = prompts.non_tensor_batch if 'raw_prompt_ids' not in non_tensor_batch: non_tensor_batch['raw_prompt_ids'] = np.array( [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object) if batch_size != len(non_tensor_batch['raw_prompt_ids']): raise RuntimeError('vllm sharding manager is not work properly.') if 'multi_modal_data' in non_tensor_batch: vllm_inputs = [] for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop('raw_prompt_ids'), non_tensor_batch.pop('multi_modal_data')): vllm_inputs.append({'prompt_token_ids': raw_prompt_ids, 'multi_modal_data': multi_modal_data}) else: vllm_inputs = [{ 'prompt_token_ids': raw_prompt_ids } for raw_prompt_ids in non_tensor_batch.pop('raw_prompt_ids')] # ensure the type of `prompt_token_ids` passed to vllm is list[int] # https://github.com/volcengine/verl/pull/772 for input_data in vllm_inputs: if isinstance(input_data['prompt_token_ids'], np.ndarray): input_data['prompt_token_ids'] = input_data['prompt_token_ids'].tolist() elif not isinstance(input_data['prompt_token_ids'], list): raise TypeError( f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}") do_sample = prompts.meta_info.get('do_sample', True) is_validate = prompts.meta_info.get('validate', False) if not do_sample: kwargs = { 'best_of': 1, 'top_p': 1.0, 'top_k': -1, 'min_p': 0.0, 'temperature': 0, 'n': 1 # if greedy, only 1 response } elif is_validate: # TODO: try ** kwargs = { 'top_k': self.config.val_kwargs.top_k, 'top_p': self.config.val_kwargs.top_p, 'temperature': self.config.val_kwargs.temperature, 'n': 1, # if validate, already repeat in ray_trainer } # users can customize different sampling_params at different run with self.update_sampling_params(**kwargs): outputs = self.inference_engine.generate( prompts=vllm_inputs, # because we have already convert it to prompt token id sampling_params=self.sampling_params, use_tqdm=False) # TODO(sgm): disable logprob when recompute_log_prob is enable # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) response = [] for output in outputs: for sample_id in range(len(output.outputs)): response.append(output.outputs[sample_id].token_ids) response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device) if self.sampling_params.n > 1 and do_sample: idx = _repeat_interleave(idx, self.sampling_params.n) attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n) position_ids = _repeat_interleave(position_ids, self.sampling_params.n) batch_size = batch_size * self.sampling_params.n if 'multi_modal_inputs' in non_tensor_batch.keys(): non_tensor_batch['multi_modal_inputs'] = 
_repeat_interleave(non_tensor_batch['multi_modal_inputs'], self.sampling_params.n) seq = torch.cat([idx, response], dim=-1) response_length = response.size(1) delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1) if position_ids.dim() == 3: # qwen2vl mrope delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1) # TODO(sgm): fix position_ids on right_pad # prompt: left pad + response: right pad # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] response_position_ids = position_ids[:, -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) # all the tp ranks should contain the same data here. data in all ranks are valid batch = TensorDict( { 'prompts': idx, 'responses': response, 'input_ids': seq, # here input_ids become the whole sentences # 'old_log_probs': log_probs, # we will recompute old log prob with actor 'attention_mask': attention_mask, 'position_ids': position_ids }, batch_size=batch_size) # free vllm cache engine if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine: self.inference_engine.free_cache_engine() return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) import re import json import requests from concurrent.futures import ThreadPoolExecutor, as_completed from verl.utils.torch_functional import pad_sequence_to_length class vLLMRolloutWithTool(vLLMRollout): def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): super().__init__(model_path, config, tokenizer, model_hf_config, **kwargs) self.tokenizer = tokenizer self.tp_rank = vllm_ps.get_tensor_model_parallel_rank() self.gen_str = "\n<|im_start|>assistant\n<think>" self.gen_ids = self.tokenizer.encode(self.gen_str) def format_tool_call(self, tool_call_str: str): """Convert JSON function call description to Python executable code string.""" try: call_json = json.loads(tool_call_str) func_name = call_json['name'] arguments = call_json.get('arguments', {}) args_str = ', '.join(f"{k}={repr(v)}" for k, v in arguments.items()) return f"{func_name}({args_str})" except Exception as e: return f"Parse tool call failed: {e}" def validate_tool_calls(self, output_str): start_tags = re.findall(r'<tool_call>', output_str) end_tags = re.findall(r'</tool_call>', output_str) if len(start_tags) != len(end_tags): return False start_positions = [m.start() for m in re.finditer(r'<tool_call>', output_str)] end_positions = [m.start() for m in re.finditer(r'</tool_call>', output_str)] for start, end in zip(start_positions, end_positions): if start >= end: return False return True def extract_tool_calls(self, output_str): if not self.validate_tool_calls(output_str): return [] try: pattern = r'<tool_call>((?:(?!</tool_call>).)*)</tool_call>' matches = re.finditer(pattern, output_str, re.DOTALL) return [match.group(1).strip() for match in matches] except Exception as e: return [] def batch_execute(self, env_list: List[str], tool_calls_list: List[List[str]]): def exe_tool_call(env, call): url = f'{self.config.sandbox_url}/execute' call_str = self.format_tool_call(call) if call_str.startswith("Parse tool call failed"): return call_str try: data 
= { 'env': env, 'call': call_str } response = requests.post(url, json=data, timeout=10) if response.status_code != 200: return f"error: {response.status_code}" response = response.json() ret_str = '' if response['result']: ret_str += f'result: \n{response["result"]}\n' if response['output']: ret_str += f'output: \n{response["output"]}\n' if response['error']: ret_str += f'error: \n{response["error"]}\n' return ret_str.strip() except requests.exceptions.Timeout: return "error: execution timed out" except Exception as e: return str(e) # flatten all tasks all_tasks = [] task_indices = [] for env_idx, (env, tool_calls) in enumerate(zip(env_list, tool_calls_list)): for call_idx, tool_call in enumerate(tool_calls): all_tasks.append((env, tool_call)) task_indices.append((env_idx, call_idx)) # parallel execute all tasks all_results = [None] * len(all_tasks) with ThreadPoolExecutor(max_workers=8) as executor: future_to_index = {executor.submit(exe_tool_call, env, call): i for i, (env, call) in enumerate(all_tasks)} for future in as_completed(future_to_index): index = future_to_index[future] all_results[index] = future.result() # reorganize results to original structure results_list = [[None for _ in range(len(tool_calls_list[i]))] for i, _ in enumerate(env_list)] for (env_idx, call_idx), result in zip(task_indices, all_results): results_list[env_idx][call_idx] = result return results_list @torch.no_grad() def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # rebuild vllm cache engine if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine: self.inference_engine.init_cache_engine() ori_input_ids = prompts.batch['input_ids'] # (bs, prompt_length) # left-padded attention_mask attention_mask = prompts.batch['attention_mask'] position_ids = prompts.batch['position_ids'] # used to construct attention_mask eos_token_id = prompts.meta_info['eos_token_id'] batch_size = ori_input_ids.size(0) idx_list = [] # parse idx from torch.Tensor to List[List[str]] for i in range(batch_size): idx_list.append(_pre_process_inputs(self.pad_token_id, ori_input_ids[i])) do_sample = prompts.meta_info.get('do_sample', True) is_validate = prompts.meta_info.get('validate', False) if not do_sample: kwargs = { 'best_of': 1, 'top_p': 1.0, 'top_k': -1, 'min_p': 0.0, 'temperature': 0, 'n': 1 # if greedy, only 1 response } elif is_validate: # TODO: try ** kwargs = { 'top_k': self.config.val_kwargs.top_k, 'top_p': self.config.val_kwargs.top_p, 'temperature': self.config.val_kwargs.temperature, 'n': 1, # if validate, already repeat in ray_trainer } with self.update_sampling_params(**kwargs): # prepare n copies for each input curr_inputs = [] for input_ids in idx_list: for _ in range(self.sampling_params.n): curr_inputs.append(input_ids.copy()) init_inputs = [ids.copy() for ids in curr_inputs] # if there are envs, prepare n copies for each env env_list = None if 'env' in prompts.non_tensor_batch: env_list = [] for env in prompts.non_tensor_batch['env']: for _ in range(self.sampling_params.n): env_list.append(env) # track the status of each input curr_max_tokens = [self.sampling_params.max_tokens] * len(curr_inputs) active_indices = list(range(len(curr_inputs))) # collect the result mask of each rollout, 1 for non-result, 0 for tool call result or pad result_mask_list = [[] for _ in range(len(curr_inputs))] # generate until all inputs are completed for step in range(self.config.max_turns): if len(active_indices) == 0: break # only process the active inputs active_inputs = 
[curr_inputs[i] for i in active_indices] active_max_tokens = [curr_max_tokens[i] for i in active_indices] with self.update_sampling_params( n=1, max_tokens=min(512, max(active_max_tokens)), stop_token_ids=[151644], top_p=0.99, ): # 512 at most, and add <|im_start|> as stop for corner case vllm_inputs = [{ 'prompt_token_ids': raw_prompt_ids } for raw_prompt_ids in active_inputs] outputs = self.inference_engine.generate( prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=False ) # collect all tool calls tool_calls_list: List[List[str]] = [] call_indices: List[int] = [] # process each output new_active_indices = [] for i, idx in enumerate(active_indices): output_ids = outputs[i].outputs[0].token_ids finish_reason = outputs[i].outputs[0].finish_reason stop_reason = outputs[i].outputs[0].stop_reason if finish_reason == 'stop' and (stop_reason == None or stop_reason == self.tokenizer.pad_token_id): curr_inputs[idx] += output_ids result_mask_list[idx] += [1] * len(output_ids) output_str = self.tokenizer.decode(output_ids) tool_calls: List[str] = self.extract_tool_calls(output_str) if tool_calls: tool_calls_list.append(tool_calls) call_indices.append(idx) new_active_indices.append(idx) else: pass # no tool calls elif finish_reason == 'length': # output over max tokens curr_inputs[idx] += output_ids result_mask_list[idx] += [1] * len(output_ids) elif finish_reason == 'stop' and stop_reason == 151644: # 151644 is the id of <|im_start|>, is a illigal stop, we stop here curr_inputs[idx] += output_ids result_mask_list[idx] += [1] * len(output_ids) else: raise ValueError(f"unknown stop reason. finish_reason: {finish_reason}, stop_reason: {stop_reason}") # batch process tool calls if tool_calls_list: # Only tp_rank 0 executes the tools if self.tp_rank == 0: active_env_list = [env_list[i] for i in call_indices] tool_responses_list = self.batch_execute(active_env_list, tool_calls_list) # Prepare data for broadcasting broadcast_data = { 'tool_calls_list': tool_calls_list, 'call_indices': call_indices, 'tool_responses_list': tool_responses_list } else: broadcast_data = None broadcast_data = vllm_ps._TP.broadcast_object(broadcast_data, src=0) # All ranks process the broadcasted data if broadcast_data is not None: tool_calls_list = broadcast_data['tool_calls_list'] call_indices = broadcast_data['call_indices'] tool_responses_list = broadcast_data['tool_responses_list'] for idx, tool_calls, tool_responses in zip(call_indices, tool_calls_list, tool_responses_list): tool_response_str = '' for call, response in zip(tool_calls, tool_responses): tool_response_str += f"<tool_response>{call}\n{response}\n</tool_response>\n" tool_response_str = "\n<|im_start|>user\n" + tool_response_str + "<|im_end|>" output_ids = self.tokenizer.encode(tool_response_str) curr_inputs[idx] += output_ids result_mask_list[idx] += [0] * len(output_ids) curr_inputs[idx] += self.gen_ids result_mask_list[idx] += [0] * len(self.gen_ids) # check if need to truncate, if yes, truncate, and remove from active; if no, update curr_max_tokens length_checked_active_indices = [] for idx in active_indices: assert len(curr_inputs[idx]) - len(init_inputs[idx]) == len(result_mask_list[idx]), f"curr_inputs: {len(curr_inputs[idx])}, init_inputs: {len(init_inputs[idx])}, result_mask_list: {len(result_mask_list[idx])}" if len(curr_inputs[idx]) - len(init_inputs[idx]) >= self.config.response_length: curr_inputs[idx] = init_inputs[idx] \ + curr_inputs[idx][len(init_inputs[idx]):len(init_inputs[idx])+self.config.response_length] 
result_mask_list[idx] = result_mask_list[idx][:self.config.response_length] else: curr_max_tokens[idx] = self.config.response_length - len(curr_inputs[idx]) + len(init_inputs[idx]) if idx in new_active_indices: length_checked_active_indices.append(idx) active_indices = length_checked_active_indices output_ids_list = [] # collect the all rollouts for i, input_ids in enumerate(idx_list): for j in range(self.sampling_params.n): idx = i * self.sampling_params.n + j input_len = len(input_ids) output_ids_list.append(curr_inputs[idx][input_len:]) response_attention_mask_list = [] response_list = [] result_mask_list_padded = [] for output_ids, result_mask in zip(output_ids_list, result_mask_list): assert len(output_ids) == len(result_mask), f"output_ids: {len(output_ids)}, result_mask: {len(result_mask)}" # to tensor response = torch.tensor(output_ids, device=ori_input_ids.device) result_mask = torch.tensor(result_mask, device=ori_input_ids.device) # response attention mask, 1 for valid, 0 for invalid response_attention_mask = torch.ones_like(response, dtype=torch.int64) response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0) response_attention_mask_list.append(response_attention_mask) # response, pad to response_length response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) response_list.append(response) # result mask, 1 for non-result, 0 for result or pad result_mask = pad_sequence_to_length(result_mask, self.config.response_length, 0) result_mask_list_padded.append(result_mask) response_attention_mask = torch.stack(response_attention_mask_list, dim=0) response = torch.stack(response_list, dim=0) result_mask = torch.stack(result_mask_list_padded, dim=0) if self.config.n > 1 and do_sample: ori_input_ids = ori_input_ids.repeat_interleave(self.config.n, dim=0) attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0) position_ids = position_ids.repeat_interleave(self.config.n, dim=0) batch_size = batch_size * self.config.n seq = torch.cat([ori_input_ids, response], dim=-1) response_length = response.size(1) delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) # TODO(sgm): fix position_ids on right_pad # prompt: left pad + response: right pad # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] response_position_ids = position_ids[:, -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) # concat attenion_mask for input and response attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) # result mask: result part is 0, other part is 1 loss_mask = result_mask * response_attention_mask # all the tp ranks should contain the same data here. data in all ranks are valid batch = TensorDict({ 'prompts': ori_input_ids, 'responses': response, 'input_ids': seq, # here input_ids become the whole sentences 'attention_mask': attention_mask, 'loss_mask': loss_mask, 'position_ids': position_ids }, batch_size=batch_size) # free vllm cache engine if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine: self.inference_engine.free_cache_engine() return DataProto(batch=batch) This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. 
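In the multi-turn rollout above, model-generated tokens receive a result mask of 1 while injected tool responses and padding receive 0, and the final loss mask is the elementwise product with the response attention mask, so only model-generated tokens contribute to the RL loss. A minimal sketch of that convention with made-up masks:

import torch

# response layout: [assistant tokens][tool response tokens][assistant tokens][pad]
response_attention_mask = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0])
result_mask = torch.tensor([1, 1, 0, 0, 1, 1, 1, 0])

loss_mask = result_mask * response_attention_mask
# tensor([1, 1, 0, 0, 1, 1, 1, 0]): tool-response and pad positions are excluded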
@@ -0,0 +1,166 @@
"""
https://github.com/Agent-RL/ReCall/blob/3d976d26ade4950bc491335bb80da1659424b3cb/src/re_call/inference/re_call.py
"""
import re
import json
import requests
import time
from typing import List
from functools import wraps


def retry(max: int=10, sleep: int=1):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for i in range(max):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    print(f"[retry] try {i} times")
                    if i == max - 1:
                        raise Exception("Retry {} failed after {} times".format(func.__name__, max))
                    elif sleep:
                        time.sleep(sleep)
        return wrapper
    return decorator


class ReCall():
    system_prompt = """In this environment you have access to a set of tools you can use to assist with the user query. \
You may perform multiple rounds of function calls. \
In each round, you can call one or more functions. \
Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```

In your response, you need to first think about the reasoning process in the mind and then conduct function calling to get the information or perform the actions if needed. \
The reasoning process and function calling are enclosed within <think> </think> and <tool_call> </tool_call> tags. \
The results of the function calls will be given back to you after execution, \
and you can continue to call functions until you get the final answer for the user's question. \
Finally, if you have got the answer, enclose it within \\boxed{{}} with latex format and do not continue to call functions, \
i.e., <think> Based on the response from the function call, I get the weather information. </think> The weather in Beijing on 2025-04-01 is \\[ \\boxed{{20C}} \\].

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name>, "arguments": <args-json-object>}}
</tool_call>"""

    def __init__(self, model_url, executor_url):
        self.model_url = model_url
        self.executor_url = executor_url

    def init_prompt(self, func_schemas, question):
        system_prompt = f"<|im_start|>system\n{self.system_prompt.format(func_schemas=func_schemas)}<|im_end|>"
        user_prompt = f"<|im_start|>user\n{question}<|im_end|>"
        assistant_prefix = f"<|im_start|>assistant\n<think>"
        return system_prompt + "\n" + user_prompt + "\n" + assistant_prefix

    def cat_assistant_response(self, curr_prompt, assistant_response):
        return curr_prompt + assistant_response + "<|im_end|>"

    def cat_tool_results(self, curr_prompt, tool_calls, results):
        tool_response_str = ""
        for tool_call, result in zip(tool_calls, results):
            tool_response_str += f"<tool_response>{tool_call}\n{result}\n</tool_response>\n"
        tool_response_str = f"<|im_start|>user\n{tool_response_str}<|im_end|>"
        assistant_prefix = f"<|im_start|>assistant\n<think>"
        return curr_prompt + "\n" + tool_response_str + "\n" + assistant_prefix

    def format_tool_call(self, tool_call_str: str):
        """Convert JSON function call description to Python executable code string."""
        try:
            call_json = json.loads(tool_call_str)
            func_name = call_json['name']
            arguments = call_json.get('arguments', {})
            args_str = ', '.join(f"{k}={repr(v)}" for k, v in arguments.items())
            return f"{func_name}({args_str})"
        except Exception as e:
            return f"Parse tool call failed: {e}"

    def execute_tool_calls(self, env: str, tool_calls: List[str]) -> List[str]:
        def exe_tool_call(env, call):
            url = self.executor_url + '/execute'

            call_str = self.format_tool_call(call)
            if call_str.startswith("error: parse tool call failed"):
                return call_str

            try:
                data = {
                    'env': env,
                    'call': call_str
                }
                response = requests.post(url, json=data, timeout=3)
                if response.status_code != 200:
                    return f"error: {response.status_code}"
                response = response.json()
                ret_str = ''
                if response['result']:
                    ret_str += f'result: \n{response["result"]}\n'
                if response['output']:
                    ret_str += f'output: \n{response["output"]}\n'
                if response['error']:
                    ret_str += f'error: \n{response["error"]}\n'
                return ret_str.strip()
            except requests.exceptions.Timeout:
                return "error: execution timed out"
            except Exception as e:
                return str(e)

        results = []
        for tool_call in tool_calls:
            result = exe_tool_call(env, tool_call)
            results.append(result)
        return results

    def validate_tool_calls(self, output_str):
        start_tags = re.findall(r'<tool_call>', output_str)
        end_tags = re.findall(r'</tool_call>', output_str)

        if len(start_tags) != len(end_tags):
            return False

        start_positions = [m.start() for m in re.finditer(r'<tool_call>', output_str)]
        end_positions = [m.start() for m in re.finditer(r'</tool_call>', output_str)]

        for start, end in zip(start_positions, end_positions):
            if start >= end:
                return False

        return True

    def extract_tool_calls(self, output_str):
        if not self.validate_tool_calls(output_str):
            return []

        try:
            pattern = r'<tool_call>((?:(?!</tool_call>).)*)</tool_call>'
            matches = re.finditer(pattern, output_str, re.DOTALL)
            return [match.group(1).strip() for match in matches]
        except Exception as e:
            return []

    @retry(max=5, sleep=1)
    def run(self, env, func_schemas, question):
        curr_prompt = self.init_prompt(func_schemas, question)

        for _ in range(5):
            response = requests.post(
                f'{self.model_url}/generate',
                json={
                    "text": curr_prompt,
                    "sampling_params": {
                        "temperature": 0.0,
                        "max_new_tokens": 512
                    }
                }
            ).json()

            curr_prompt = self.cat_assistant_response(curr_prompt, response['text'])

            tool_calls: List[str] = self.extract_tool_calls(response['text'])
            if len(tool_calls) == 0:
                break

            results: List[str] = self.execute_tool_calls(env, tool_calls)
            curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)

        return curr_prompt

@@ -0,0 +1 @@
langgraph>0.4,<1
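A minimal usage sketch for the ReCall client above; the server URLs are placeholders, the model endpoint is assumed to expose the /generate API used in run, and the function schema is a made-up example:

import json

func_schemas = json.dumps([{
    "name": "get_weather",
    "description": "Get the weather for a city on a given date.",
    "parameters": {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "date": {"type": "string"},
        },
        "required": ["city"],
    },
}])

recall = ReCall(
    model_url="http://localhost:30000",    # hypothetical model server exposing /generate
    executor_url="http://localhost:8000",  # hypothetical sandbox executor exposing /execute
)

transcript = recall.run(
    env="",  # sandbox environment code; empty for this sketch
    func_schemas=func_schemas,
    question="What is the weather in Beijing on 2025-04-01?",
)
print(transcript)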