Bases: BaseToolCallDetectionStrategy
A strategy for detecting tool calls using vendor-provided metadata in SSE chunks.
This strategy processes tool calls by accumulating function names and arguments
across multiple Server-Sent Events (SSE) chunks. It relies on vendor-specific
metadata in the chunks to identify tool calls and their completion status.
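Attributes:
- found_complete_call (bool): Flag indicating if a complete tool call was found.
- collected_tool_calls (List[ToolCall]): List of fully collected tool calls.
- partial_name (str): Buffer for accumulating function name.
- partial_args (str): Buffer for accumulating function arguments.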
Example
detector = VendorToolCallDetectionStrategy()
async for chunk in stream:
result = await detector.detect_chunk(chunk, context)
if result.state == DetectionState.COMPLETE_MATCH:
tool_calls = result.tool_calls
# Process the complete tool calls
Source code in src/llm/tool_detection/vendor_detection_strategy.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
class VendorToolCallDetectionStrategy(BaseToolCallDetectionStrategy):
    """A strategy for detecting tool calls using vendor-provided metadata in SSE chunks.

    This strategy processes tool calls by accumulating function names and arguments
    across multiple Server-Sent Events (SSE) chunks. It relies on vendor-specific
    metadata in the chunks to identify tool calls and their completion status.

    Attributes:
        found_complete_call (bool): Flag indicating if a complete tool call was found.
        collected_tool_calls (List[ToolCall]): List of fully collected tool calls.
        partial_id (Optional[str]): Most recent tool call id seen in the stream,
            or None if no chunk has carried one yet.
        partial_name (Optional[str]): Buffer for the function name, or None until
            a chunk carries one.
        partial_args (str): Buffer for accumulating function arguments.

    Example:
        ```python
        detector = VendorToolCallDetectionStrategy()
        async for chunk in stream:
            result = await detector.detect_chunk(chunk, context)
            if result.state == DetectionState.COMPLETE_MATCH:
                tool_calls = result.tool_calls
                # Process the complete tool calls
        ```
    """

    def __init__(self):
        """Initialize the vendor tool call detection strategy."""
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.debug("Initializing VendorToolCallDetectionStrategy")
        # Declare every attribute up front; reset() assigns the real
        # initial values.
        self.found_complete_call = None
        self.collected_tool_calls = None
        self.partial_id = None
        self.partial_name = None
        self.partial_args = None
        self.reset()

    def reset(self) -> None:
        """Reset all stored tool call information to initial state.

        This method clears all accumulated data and resets flags, preparing the
        detector for processing a new stream of chunks.
        """
        self.logger.debug("Resetting detector state")
        self.partial_id = None
        self.partial_name = None
        self.partial_args = ""  # Accumulates streamed arguments
        self.collected_tool_calls = []
        self.found_complete_call = False

    async def detect_chunk(
        self,
        sse_chunk: SSEChunk,
        context: StreamContext
    ) -> DetectionResult:
        """Process an SSE chunk for tool call detection.

        Analyzes the chunk for tool call information using vendor-provided metadata.
        Accumulates partial tool calls across multiple chunks until a complete call
        is detected.

        Args:
            sse_chunk (SSEChunk): The chunk of streaming content to process.
            context (StreamContext): Context information for the current stream.
                Not read by this strategy.

        Returns:
            DetectionResult: Result of processing the chunk, including detection state
                and any content or tool calls found. NO_MATCH is returned for chunks
                without choices, PARTIAL_MATCH while a name or arguments are
                buffered, and COMPLETE_MATCH once the finish reason closes a call.

        Note:
            This method maintains state between calls to detect_chunk() to properly
            handle tool calls that span multiple chunks. It relies on the
            finish_reason field to determine when a tool call is complete.
        """
        if not sse_chunk.choices:
            return DetectionResult(state=DetectionState.NO_MATCH)

        choice = sse_chunk.choices[0]
        delta = choice.delta
        finish_reason = choice.finish_reason
        text_content = delta.content if delta.content else None

        # Check if tool call data is present in this chunk
        if delta.tool_calls:
            tool_call_data = delta.tool_calls[0]  # Assuming index 0 for simplicity

            # Remember the id whenever it shows up: the chunk that carries the
            # finish reason is not guaranteed to repeat it.
            chunk_id = getattr(tool_call_data, "id", None)
            if chunk_id:
                self.partial_id = chunk_id

            function = tool_call_data.function
            if function:
                # If this chunk contains a function name, store it
                if function.name:
                    self.partial_name = function.name
                # If arguments are being streamed, accumulate them
                if function.arguments:
                    self.partial_args += function.arguments

        # If finish_reason indicates the tool call is complete, finalize it
        if finish_reason in ("tool_calls", "tool_use") and self.partial_name:
            try:
                parsed_args = json.loads(self.partial_args) if self.partial_args else {}
            except json.JSONDecodeError:
                self.logger.warning(
                    "Failed to parse arguments as JSON: %s", self.partial_args[:50]
                )
                parsed_args = {"_malformed": self.partial_args}

            tool_call = ToolCall(
                id=self.partial_id or "call_generated",
                function=FunctionDetail(
                    name=self.partial_name,
                    # Serialize back to JSON; str() would produce a Python
                    # repr (single quotes, True/None) that is not valid JSON.
                    arguments=json.dumps(parsed_args)
                )
            )
            self.collected_tool_calls.append(tool_call)
            self.found_complete_call = True

            # Drop the per-call buffers so a later tool call in the same
            # stream does not inherit this call's id, name or arguments.
            self.partial_id = None
            self.partial_name = None
            self.partial_args = ""

            return DetectionResult(
                state=DetectionState.COMPLETE_MATCH,
                tool_calls=[tool_call],
                content=text_content
            )

        # If we're still collecting tool call data, return PARTIAL_MATCH
        if self.partial_name or self.partial_args:
            return DetectionResult(
                state=DetectionState.PARTIAL_MATCH,
                content=text_content
            )

        # Otherwise, just return NO_MATCH and pass the text through
        return DetectionResult(
            state=DetectionState.NO_MATCH,
            content=text_content
        )

    async def finalize_detection(self, context: StreamContext) -> DetectionResult:
        """Finalize the detection process and report any collected tool calls.

        This method is called at the end of the SSE stream to return final results.

        Args:
            context (StreamContext): Context information for the current stream.
                Not read by this strategy.

        Returns:
            DetectionResult: COMPLETE_MATCH with every collected tool call when at
                least one call was completed during the stream, otherwise NO_MATCH.

        Note:
            A tool call that was still being accumulated when the stream ended is
            only logged at debug level; it is not returned and the buffers are
            left untouched until reset() is called.
        """
        self.logger.debug("Finalizing detection")

        if self.found_complete_call:
            self.logger.debug(
                "Returning %d collected tool calls", len(self.collected_tool_calls)
            )
            return DetectionResult(
                state=DetectionState.COMPLETE_MATCH,
                tool_calls=self.collected_tool_calls
            )

        if self.partial_name or self.partial_args:
            self.logger.debug("Incomplete tool call data at stream end")
            # Lazy %-args: the message is only built when debug is enabled.
            self.logger.debug(
                "Name: %s, Args: %s", self.partial_name, self.partial_args
            )
        else:
            self.logger.debug("No tool calls to finalize")

        return DetectionResult(state=DetectionState.NO_MATCH)
|
Initialize the vendor tool call detection strategy.
Source code in src/llm/tool_detection/vendor_detection_strategy.py
37
38
39
40
41
42
43
44
def __init__(self):
    """Set up the logger and put the detector into its initial state."""
    self.logger = logging.getLogger(self.__class__.__name__)
    self.logger.debug("Initializing VendorToolCallDetectionStrategy")

    # Placeholders only, so every attribute is declared in the constructor;
    # reset() below assigns the real starting values.
    self.found_complete_call = self.collected_tool_calls = None
    self.partial_name = self.partial_args = None

    self.reset()
|
Process an SSE chunk for tool call detection.
Analyzes the chunk for tool call information using vendor-provided metadata.
Accumulates partial tool calls across multiple chunks until a complete call
is detected.
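Parameters:
- sse_chunk (SSEChunk): The chunk of streaming content to process.
- context (StreamContext): Context information for the current stream.
Returns:
- DetectionResult: Result of processing the chunk, including detection state and any content or tool calls found.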
Note
This method maintains state between calls to detect_chunk() to properly handle tool calls
that span multiple chunks. It relies on the finish_reason field to
determine when a tool call is complete.
Source code in src/llm/tool_detection/vendor_detection_strategy.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
async def detect_chunk(
    self,
    sse_chunk: SSEChunk,
    context: StreamContext
) -> DetectionResult:
    """Process an SSE chunk for tool call detection.

    Analyzes the chunk for tool call information using vendor-provided metadata.
    Accumulates partial tool calls across multiple chunks until a complete call
    is detected.

    Args:
        sse_chunk (SSEChunk): The chunk of streaming content to process.
        context (StreamContext): Context information for the current stream.
            Not read by this strategy.

    Returns:
        DetectionResult: Result of processing the chunk, including detection state
            and any content or tool calls found. NO_MATCH is returned for chunks
            without choices, PARTIAL_MATCH while a name or arguments are
            buffered, and COMPLETE_MATCH once the finish reason closes a call.

    Note:
        This method maintains state between calls to detect_chunk() to properly
        handle tool calls that span multiple chunks. It relies on the
        finish_reason field to determine when a tool call is complete.
    """
    tool_call_data = None

    if not sse_chunk.choices:
        return DetectionResult(state=DetectionState.NO_MATCH)

    choice = sse_chunk.choices[0]
    delta = choice.delta
    finish_reason = choice.finish_reason
    text_content = delta.content if delta.content else None

    # Check if tool call data is present in this chunk
    if delta.tool_calls:
        tool_call_data = delta.tool_calls[0]  # Assuming index 0 for simplicity
        function = tool_call_data.function
        if function:
            # If this chunk contains a function name, store it
            if function.name:
                self.partial_name = function.name
            # If arguments are being streamed, accumulate them
            if function.arguments:
                self.partial_args += function.arguments

    # If finish_reason indicates the tool call is complete, finalize it
    if finish_reason in ("tool_calls", "tool_use") and self.partial_name:
        try:
            parsed_args = json.loads(self.partial_args) if self.partial_args else {}
        except json.JSONDecodeError:
            self.logger.warning(
                "Failed to parse arguments as JSON: %s", self.partial_args[:50]
            )
            parsed_args = {"_malformed": self.partial_args}

        tool_call = ToolCall(
            # NOTE(review): the id is taken from the current chunk only; if the
            # vendor sends it in an earlier chunk it is lost and the fallback
            # is used -- confirm against the vendor's stream format.
            id=(tool_call_data and tool_call_data.id) or "call_generated",
            function=FunctionDetail(
                name=self.partial_name,
                # Serialize back to JSON; str() would produce a Python repr
                # (single quotes, True/None) that is not valid JSON.
                arguments=json.dumps(parsed_args)
            )
        )
        self.collected_tool_calls.append(tool_call)
        self.found_complete_call = True

        # Drop the per-call buffers so a later tool call in the same stream
        # does not inherit this call's name or arguments.
        self.partial_name = None
        self.partial_args = ""

        return DetectionResult(
            state=DetectionState.COMPLETE_MATCH,
            tool_calls=[tool_call],
            content=text_content
        )

    # If we're still collecting tool call data, return PARTIAL_MATCH
    if self.partial_name or self.partial_args:
        return DetectionResult(
            state=DetectionState.PARTIAL_MATCH,
            content=text_content
        )

    # Otherwise, just return NO_MATCH and pass the text through
    return DetectionResult(
        state=DetectionState.NO_MATCH,
        content=text_content
    )
|
Finalize the detection process and handle any accumulated tool calls.
This method is called at the end of the SSE stream to process any remaining
tool call data and return final results.
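Parameters:
- context (StreamContext): Context information for the current stream.
Returns:
- DetectionResult: Final result of the detection process, including any complete tool calls or remaining content.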
Note
This method handles cleanup of partial tool calls that were never
completed due to stream termination.
Source code in src/llm/tool_detection/vendor_detection_strategy.py
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
async def finalize_detection(self, context: StreamContext) -> DetectionResult:
    """Finalize the detection process and report any collected tool calls.

    This method is called at the end of the SSE stream to return final results.

    Args:
        context (StreamContext): Context information for the current stream.
            Not read by this strategy.

    Returns:
        DetectionResult: COMPLETE_MATCH with every collected tool call when at
            least one call was completed during the stream, otherwise NO_MATCH.

    Note:
        A tool call that was still being accumulated when the stream ended is
        only logged at debug level; it is not returned and the buffers are
        left untouched.
    """
    self.logger.debug("Finalizing detection")

    if self.found_complete_call:
        self.logger.debug(
            "Returning %d collected tool calls", len(self.collected_tool_calls)
        )
        return DetectionResult(
            state=DetectionState.COMPLETE_MATCH,
            tool_calls=self.collected_tool_calls
        )

    if self.partial_name or self.partial_args:
        self.logger.debug("Incomplete tool call data at stream end")
        # Lazy %-args: the message is only built when debug is enabled.
        self.logger.debug(
            "Name: %s, Args: %s", self.partial_name, self.partial_args
        )
    else:
        self.logger.debug("No tool calls to finalize")

    return DetectionResult(state=DetectionState.NO_MATCH)
|
Reset all stored tool call information to initial state.
This method clears all accumulated data and resets flags, preparing the
detector for processing a new stream of chunks.
Source code in src/llm/tool_detection/vendor_detection_strategy.py
47
48
49
50
51
52
53
54
55
56
def reset(self) -> None:
    """Return the detector to its initial, empty state.

    Clears the completion flag, the list of collected tool calls and both
    partial buffers so the instance can process a new stream of chunks.
    """
    self.logger.debug("Resetting detector state")

    # Results gathered so far.
    self.found_complete_call = False
    self.collected_tool_calls = []

    # Per-call buffers: the argument buffer starts as an empty string
    # because streamed fragments are appended to it.
    self.partial_args = ""
    self.partial_name = None
|