| import json |
| import re |
| from typing import Dict, List, Sequence, Union |
| import partial_json_parser |
| from partial_json_parser.core.options import Allow |
|
|
| from vllm.entrypoints.openai.protocol import ( |
| ChatCompletionRequest, DeltaMessage, DeltaToolCall, |
| DeltaFunctionCall, ExtractedToolCallInformation, ToolCall, FunctionCall |
| ) |
| from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager |
| from vllm.utils import random_uuid |
| from vllm.logger import init_logger |
| from transformers import PreTrainedTokenizerBase |
| from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix, |
| is_complete_json, |
| partial_json_loads) |
|
|
| logger = init_logger(__name__) |
|
|
@ToolParserManager.register_module("xlam")
class xLAMToolParser(ToolParser):
    """Tool-call parser registered under the name ``"xlam"``.

    Model output is treated as tool calls only when, ignoring surrounding
    whitespace, it begins with ``[``. It is then read as a JSON array of
    objects carrying ``"name"`` and ``"arguments"`` keys. Any other output is
    passed through unchanged as plain content.

    NOTE(review): ``self.current_tool_id`` is read but never initialised in
    this class; it is presumably set up by ``ToolParser.__init__`` (starting
    at -1, i.e. "no tool yet") -- confirm against the base class.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

        # Tool-call dicts parsed from the previous streaming chunk; used to
        # work out which part of the arguments is stable enough to emit.
        self.prev_tool_calls: List[Dict] = []
        # Per tool index: whether the function name has already been emitted.
        self.current_tools_sent: List[bool] = []
        # Per tool index: the argument JSON text emitted so far.
        self.streamed_args: List[str] = []

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streaming) model output.

        Args:
            model_output: The full text generated by the model.
            request: The originating chat request (not used here).

        Returns:
            ``tools_called=True`` with one ``ToolCall`` per array element when
            the output is a non-empty JSON array of ``{"name", "arguments"}``
            objects. Otherwise ``tools_called=False`` with the original text
            returned as ``content``. Parsing errors are logged and never
            raised.
        """
        try:
            if not model_output.strip().startswith('['):
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )

            tool_calls_data = json.loads(model_output)
            tool_calls: List[ToolCall] = [
                ToolCall(
                    id=f"call_{idx}_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=call["name"],
                        arguments=json.dumps(call["arguments"]),
                    ),
                )
                for idx, call in enumerate(tool_calls_data)
            ]

            # An empty array carries no tool call; hand the text back as
            # content instead of reporting a tool call with nothing in it.
            if not tool_calls:
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=None,
            )

        except Exception:
            logger.exception("Error extracting tool calls")
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def _parse_partial_tool_calls(self, current_text: str, flags: Allow):
        """Parse the text generated so far into a flat list of tool calls.

        Returns a pair ``(tool_calls, is_complete)`` of equal length:
        ``tool_calls`` holds one dict per tool call, and ``is_complete[i]``
        tells whether the JSON for tool ``i`` is finished.

        ``partial_json_parser.core.exceptions.MalformedJSON`` raised by the
        partial parser propagates to the caller.
        """
        tool_calls: List[Dict] = []
        is_complete: List[bool] = []

        start_idx = 0
        text_len = len(current_text)
        while start_idx < text_len:
            # Never start parsing a value on whitespace (leading, trailing,
            # or between values).
            if current_text[start_idx].isspace():
                start_idx += 1
                continue

            obj, end_idx = partial_json_loads(current_text[start_idx:], flags)
            if end_idx <= 0:
                # Defensive: no progress would loop forever.
                break
            segment_complete = is_complete_json(
                current_text[start_idx:start_idx + end_idx])
            start_idx += end_idx

            # Text starting with '[' parses to a list: flatten it so every
            # entry of the result is a single tool-call dict.
            items = obj if isinstance(obj, list) else [obj]
            last_pos = len(items) - 1
            for pos, item in enumerate(items):
                if not isinstance(item, dict):
                    continue
                tool_calls.append(item)
                # Array elements are produced in order, so every element
                # followed by another one is already finished.
                is_complete.append(segment_complete or pos < last_pos)

        return tool_calls, is_complete

    def _arguments_delta(self, argument_diff: str) -> DeltaMessage:
        """Build an arguments delta for the current tool and record it."""
        delta = DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=self.current_tool_id,
                function=DeltaFunctionCall(
                    arguments=argument_diff
                ).model_dump(exclude_none=True),
            )
        ])
        self.streamed_args[self.current_tool_id] += argument_diff
        return delta

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Produce the next streaming delta for the text generated so far.

        Only ``current_text`` and ``delta_text`` are inspected; the other
        parameters are part of the parser interface and are unused here.

        Returns:
            * a content delta when the text does not look like a tool-call
              array (does not start with ``[``);
            * a tool-call delta carrying either the function name (sent once
              per tool) or the next piece of the JSON-encoded arguments;
            * ``None`` when there is nothing new to emit or the chunk could
              not be processed (errors are logged, never raised).
        """
        if not current_text.strip().startswith('['):
            return DeltaMessage(content=delta_text)

        # Partial strings are only allowed once the current tool's name has
        # been emitted, so a half-generated function name is never sent.
        # The per-tool list is used because it is the state this class
        # actually maintains.
        name_sent = (0 <= self.current_tool_id < len(self.current_tools_sent)
                     and self.current_tools_sent[self.current_tool_id])
        flags = Allow.ALL if name_sent else Allow.ALL & ~Allow.STR

        try:
            try:
                tool_call_arr, is_complete = self._parse_partial_tool_calls(
                    current_text, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # Nothing but the opening bracket (or no usable object) so far.
            if not tool_call_arr:
                return None

            delta = None

            if len(tool_call_arr) > self.current_tool_id + 1:
                # A new tool call has started. Before moving on, flush any
                # arguments of the previous tool that were not streamed yet.
                if self.current_tool_id >= 0:
                    finished_call = tool_call_arr[self.current_tool_id]
                    cur_arguments = finished_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(self.streamed_args[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]
                        if argument_diff:
                            delta = self._arguments_delta(argument_diff)

                self.current_tool_id = len(tool_call_arr) - 1
                # Pad the per-tool state up to the new index so it stays
                # valid even when several tools appeared in a single chunk.
                while len(self.current_tools_sent) <= self.current_tool_id:
                    self.current_tools_sent.append(False)
                while len(self.streamed_args) <= self.current_tool_id:
                    self.streamed_args.append("")
                logger.debug("starting new tool %d", self.current_tool_id)

            elif self.current_tool_id >= len(tool_call_arr):
                # The partial parse currently yields fewer tools than were
                # already started; wait for more text.
                return None

            elif not self.current_tools_sent[self.current_tool_id]:
                # The name of the current tool has not been sent yet: send it
                # as soon as it is available.
                function_name = tool_call_arr[self.current_tool_id].get("name")
                if function_name:
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=f"call_{self.current_tool_id}_{random_uuid()}",
                            function=DeltaFunctionCall(
                                name=function_name
                            ).model_dump(exclude_none=True),
                        )
                    ])
                    self.current_tools_sent[self.current_tool_id] = True

            else:
                # Same tool as before: stream its arguments.
                cur_arguments = tool_call_arr[self.current_tool_id].get(
                    "arguments")
                if cur_arguments:
                    sent = len(self.streamed_args[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)

                    prev_arguments = None
                    if self.current_tool_id < len(self.prev_tool_calls):
                        prev_arguments = self.prev_tool_calls[
                            self.current_tool_id].get("arguments")

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # Finished tool: everything not yet sent is final.
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        # Unfinished tool: only the prefix shared with the
                        # previous chunk is stable enough to send.
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            prefix = find_common_prefix(
                                prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = self._arguments_delta(argument_diff)

            # Remember this chunk's parse on every path so the next chunk is
            # always diffed against up-to-date state.
            self.prev_tool_calls = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error in streaming tool calls")
            logger.debug("Skipping chunk due to streaming error")
            return None
|
|
|
|