From 669e8c79af559ab28666562add5caf7e8042780d Mon Sep 17 00:00:00 2001 From: mlong93 <35275280+mlong93@users.noreply.github.com> Date: Sun, 26 Jan 2025 17:35:22 -0800 Subject: [PATCH] feat: add anthropic streaming (#716) Co-authored-by: Mindy Long Co-authored-by: Charles Packer --- letta/agent.py | 1 + letta/llm_api/anthropic.py | 492 ++++++++++++++++++++++++++++- letta/llm_api/llm_api_tools.py | 41 ++- letta/server/rest_api/interface.py | 18 +- letta/server/server.py | 8 +- letta/streaming_utils.py | 6 +- tests/test_client_legacy.py | 32 +- 7 files changed, 558 insertions(+), 40 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 4fa9f761..9ff0f437 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -731,6 +731,7 @@ class Agent(BaseAgent): # (if yes) Step 4: call the function # (if yes) Step 5: send the info on the function call and function response to LLM response_message = response.choices[0].message + response_message.model_copy() # TODO why are we copying here? all_response_messages, heartbeat_request, function_failed = self._handle_ai_response( response_message, diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index b562d466..2c35cfdc 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -1,21 +1,41 @@ import json import re -from typing import List, Optional, Tuple, Union +import time +from typing import Generator, List, Optional, Tuple, Union import anthropic from anthropic import PermissionDeniedError +from anthropic.types.beta import ( + BetaRawContentBlockDeltaEvent, + BetaRawContentBlockStartEvent, + BetaRawContentBlockStopEvent, + BetaRawMessageDeltaEvent, + BetaRawMessageStartEvent, + BetaRawMessageStopEvent, + BetaTextBlock, + BetaToolUseBlock, +) from letta.errors import BedrockError, BedrockPermissionError from letta.llm_api.aws_bedrock import get_bedrock_client -from letta.schemas.message import Message +from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages +from 
def strip_xml_tags_streaming(string: str, tag: Optional[str]) -> str:
    """Remove whole or partial ``<tag>`` / ``</tag>`` markup from one streamed text chunk.

    Streaming deltas can split a tag across chunk boundaries, so in addition to the
    complete opening/closing tags this also removes the partial pieces a chunk may
    carry (``<tag``, ``tag>``, ``</tag``, ``/tag>``, ``/tag``) and any leftover bare
    ``<`` or ``>``.

    NOTE: bare angle brackets are removed unconditionally, so legitimate ``<`` / ``>``
    characters in the model's text are lost as well. That lossy behavior is kept
    because it is what the streaming path already relied on.

    Args:
        string: The text of a single streamed chunk.
        tag: The XML tag name to strip (without brackets). ``None`` or an empty
            string disables stripping.

    Returns:
        ``string`` with every listed fragment removed.
    """
    # An empty tag would degrade the fragments to "/", "<" and ">" and silently eat
    # slashes, so treat it the same as "no tag configured".
    if not tag:
        return string

    # Longest fragments first: removing a bare "<" before "</tag" would otherwise
    # make the longer fragments unmatchable and leave debris such as "/" behind.
    fragments = (
        f"</{tag}>",  # complete closing tag
        f"<{tag}>",  # complete opening tag
        f"</{tag}",  # closing tag start
        f"/{tag}>",  # closing tag end
        f"<{tag}",  # opening tag start
        f"{tag}>",  # opening tag end
        f"/{tag}",  # partial closing tag without brackets
        "<",  # leftover start bracket
        ">",  # leftover end bracket
    )

    result = string
    for fragment in fragments:
        result = result.replace(fragment, "")

    return result
def convert_anthropic_stream_event_to_chatcompletion(
    event: Union[
        BetaRawMessageStartEvent,
        BetaRawContentBlockStartEvent,
        BetaRawContentBlockDeltaEvent,
        BetaRawContentBlockStopEvent,
        BetaRawMessageDeltaEvent,
        BetaRawMessageStopEvent,
    ],
    message_id: str,
    model: str,
    inner_thoughts_xml_tag: Optional[str] = "thinking",
) -> ChatCompletionChunkResponse:
    """Translate one raw Anthropic stream event into an OpenAI-style completion chunk.

    Only three event kinds contribute data; every other event produces a chunk whose
    delta has no content and no tool calls:

    * ``BetaRawMessageDeltaEvent``: its ``delta.stop_reason`` is remapped into the
      chunk's ``finish_reason``.
    * ``BetaRawContentBlockDeltaEvent``: a ``text_delta`` becomes ``content`` (after
      stripping ``inner_thoughts_xml_tag`` markup); an ``input_json_delta`` becomes a
      tool-call delta carrying the partial JSON arguments.
    * ``BetaRawContentBlockStartEvent``: a tool-use block becomes a tool-call delta
      carrying the call id and function name with empty arguments; a text block
      contributes its (initial) text as ``content``.

    Args:
        event: The raw Anthropic stream event to convert.
        message_id: ID of the message being streamed. Anthropic does not repeat it on
            every event, so the caller tracks it and passes it in.
        model: Name of the model being streamed, tracked by the caller for the same
            reason as ``message_id``.
        inner_thoughts_xml_tag: Tag name whose markup is stripped from text deltas.

    Returns:
        A ``ChatCompletionChunkResponse`` with exactly one choice at index 0.

    Example of the OpenAI chunk shape being emulated::

        {
            'id': 'MESSAGE_ID',
            'choices': [
                {
                    'finish_reason': None,
                    'index': 0,
                    'delta': {
                        'content': None,
                        'tool_calls': [
                            {
                                'index': 0,
                                'id': None,
                                'type': 'function',
                                'function': {'name': None, 'arguments': '_th'}
                            }
                        ],
                        'function_call': None
                    },
                    'logprobs': None
                }
            ],
            'model': 'gpt-4o-mini-2024-07-18',
            'object': 'chat.completion.chunk'
        }
    """
    # message_delta events look like:
    #   BetaRawMessageDeltaEvent(delta=Delta(stop_reason='tool_use', stop_sequence=None),
    #                            type='message_delta', usage=BetaMessageDeltaUsage(output_tokens=45))
    stop = remap_finish_reason(event.delta.stop_reason) if isinstance(event, BetaRawMessageDeltaEvent) else None

    text_piece = None
    tool_call_deltas = None

    if isinstance(event, BetaRawContentBlockDeltaEvent):
        # content_block_delta carries either
        #   BetaInputJSONDelta(partial_json='lo', type='input_json_delta')
        # or
        #   BetaTextDelta(text='...', type='text_delta')
        delta = event.delta
        if delta.type == "text_delta":
            text_piece = strip_xml_tags_streaming(string=delta.text, tag=inner_thoughts_xml_tag)
        elif delta.type == "input_json_delta":
            json_fragment = FunctionCallDelta(name=None, arguments=delta.partial_json)
            tool_call_deltas = [ToolCallDelta(index=0, function=json_fragment)]

    elif isinstance(event, BetaRawContentBlockStartEvent):
        # content_block_start carries either
        #   BetaToolUseBlock(id='toolu_...', input={}, name='get_weather', type='tool_use')
        # or
        #   BetaTextBlock(text='', type='text')
        block = event.content_block
        if isinstance(block, BetaToolUseBlock):
            call_header = FunctionCallDelta(name=block.name, arguments="")
            tool_call_deltas = [ToolCallDelta(index=0, id=block.id, function=call_header)]
        elif isinstance(block, BetaTextBlock):
            text_piece = block.text

    # Wrap whatever was extracted in the single-choice chunk envelope
    message_delta = MessageDelta(content=text_piece, tool_calls=tool_call_deltas)
    only_choice = ChunkChoice(index=0, finish_reason=stop, delta=message_delta)

    return ChatCompletionChunkResponse(
        id=message_id,
        choices=[only_choice],
        created=get_utc_time(),
        model=model,
    )
def anthropic_chat_completions_request_stream(
    data: ChatCompletionRequest,
    inner_thoughts_xml_tag: Optional[str] = "thinking",
    betas: List[str] = ["tools-2024-04-04"],
) -> Generator[ChatCompletionChunkResponse, None, None]:
    """Stream chat completions from the Anthropic API as OpenAI-style chunks.

    Similar to OpenAI's streaming, but using Anthropic's native streaming support.
    See: https://docs.anthropic.com/claude/reference/messages-streaming

    Args:
        data: The OpenAI-style request; converted with ``_prepare_anthropic_request``.
        inner_thoughts_xml_tag: Tag used for inner thoughts; forwarded to request
            preparation and to the per-event converter.
        betas: Anthropic beta feature flags passed to the SDK. The default list is
            shared between calls and is never mutated here.

    Yields:
        One ``ChatCompletionChunkResponse`` per raw Anthropic stream event.

    Raises:
        ValueError: If neither an override key nor ``model_settings.anthropic_api_key``
            is available. Because this is a generator, the error surfaces on first
            iteration rather than at call time.
    """
    data = _prepare_anthropic_request(data, inner_thoughts_xml_tag)

    anthropic_override_key = ProviderManager().get_anthropic_override_key()
    if anthropic_override_key:
        anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
    elif model_settings.anthropic_api_key:
        # NOTE(review): no key is passed explicitly; presumably the SDK picks it up
        # from the environment - confirm.
        anthropic_client = anthropic.Anthropic()
    else:
        # Previously this fell through and crashed with UnboundLocalError below.
        raise ValueError("No Anthropic API key configured: set an override key or model_settings.anthropic_api_key")

    with anthropic_client.beta.messages.stream(
        **data,
        betas=betas,
    ) as stream:
        # Stream: https://github.com/anthropics/anthropic-sdk-python/blob/d212ec9f6d5e956f13bc0ddc3d86b5888a954383/src/anthropic/lib/streaming/_beta_messages.py#L22
        # The id and model only arrive on the message_start event, so remember them
        # and stamp them onto every later chunk.
        message_id = None
        model = None

        # NOTE(review): `_raw_stream` is a private SDK attribute and may change
        # between SDK releases.
        for chunk in stream._raw_stream:
            time.sleep(0.01)  # Anthropic is really fast, faster than frontend can upload.
            if isinstance(chunk, BetaRawMessageStartEvent):
                # BetaRawMessageStartEvent(
                #     message=BetaMessage(id='MESSAGE ID HERE', content=[], model='claude-3-5-sonnet-20241022',
                #                         role='assistant', stop_reason=None, stop_sequence=None, type='message',
                #                         usage=BetaUsage(input_tokens=30, output_tokens=4, ...)),
                #     type='message_start')
                message_id = chunk.message.id
                model = chunk.message.model
            yield convert_anthropic_stream_event_to_chatcompletion(chunk, message_id, model, inner_thoughts_xml_tag)


def _accumulate_tool_call_delta(accum_message, tool_call_delta) -> None:
    """Fold one streamed tool-call delta into ``accum_message.tool_calls`` in place."""
    # Local import: `warnings` is not among this module's visible top-level imports.
    import warnings

    function_delta = tool_call_delta.function
    has_payload = tool_call_delta.id is not None or (
        function_delta is not None and (function_delta.name is not None or function_delta.arguments is not None)
    )
    if not has_payload:
        return

    if tool_call_delta.index not in range(len(accum_message.tool_calls)):
        # Drop the delta (with a warning) instead of guessing which call it belongs to.
        warnings.warn(
            f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}"
        )
        return

    target = accum_message.tool_calls[tool_call_delta.index]
    if tool_call_delta.id is not None:
        # TODO assert that we're not overwriting?
        target.id = tool_call_delta.id
    if function_delta is not None:
        if function_delta.name is not None:
            # The name arrives whole on the block-start event, so it is assigned, not appended.
            target.function.name = function_delta.name
        if function_delta.arguments is not None:
            # Arguments arrive as partial JSON fragments and are concatenated.
            target.function.arguments += function_delta.arguments


def anthropic_chat_completions_process_stream(
    chat_completion_request: ChatCompletionRequest,
    stream_interface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None,
    inner_thoughts_xml_tag: Optional[str] = "thinking",
    create_message_id: bool = True,
    create_message_datetime: bool = True,
    betas: List[str] = ["tools-2024-04-04"],
) -> ChatCompletionResponse:
    """Consume an Anthropic completion stream and assemble the final response.

    Each chunk is forwarded to ``stream_interface`` as it arrives, and its deltas
    (content, tool calls, finish reason) are accumulated into a single
    ``ChatCompletionResponse``, similar to OpenAI's streaming path.

    Args:
        chat_completion_request: The chat completion request; must have ``stream`` set.
        stream_interface: Interface for handling streaming chunks (required).
        inner_thoughts_xml_tag: Tag for inner thoughts in the response.
        create_message_id: If true, use a locally generated message ID instead of the
            ID carried by the chunks.
        create_message_datetime: If true, use a locally generated timestamp instead of
            the chunk timestamps.
        betas: Beta features to enable.

    Returns:
        The final ChatCompletionResponse. ``usage.completion_tokens`` is approximated
        by the number of chunks received.

    Raises:
        TypeError: If ``stream_interface`` is not a supported interface type.
        NotImplementedError: If a chunk carries a legacy ``function_call`` delta.
        AssertionError: If preconditions fail or placeholder values survive the stream.
    """
    assert chat_completion_request.stream
    assert stream_interface is not None, "Required"

    # Count prompt tokens - we'll get completion tokens from the final response
    chat_history = [m.model_dump(exclude_none=True) for m in chat_completion_request.messages]
    prompt_tokens = num_tokens_from_messages(
        messages=chat_history,
        model=chat_completion_request.model,
    )

    # Add tokens for tools if present
    if chat_completion_request.tools is not None:
        assert chat_completion_request.functions is None
        prompt_tokens += num_tokens_from_functions(
            functions=[t.function.model_dump() for t in chat_completion_request.tools],
            model=chat_completion_request.model,
        )
    elif chat_completion_request.functions is not None:
        assert chat_completion_request.tools is None
        prompt_tokens += num_tokens_from_functions(
            functions=[f.model_dump() for f in chat_completion_request.functions],
            model=chat_completion_request.model,
        )

    # Create a dummy message for ID/datetime if needed
    dummy_message = _Message(
        role=_MessageRole.assistant,
        text="",
        agent_id="",
        model="",
        name=None,
        tool_calls=None,
        tool_call_id=None,
    )

    # Placeholders that must be overwritten by real values before returning
    TEMP_STREAM_RESPONSE_ID = "temp_id"
    TEMP_STREAM_FINISH_REASON = "temp_null"
    TEMP_STREAM_TOOL_CALL_ID = "temp_id"
    chat_completion_response = ChatCompletionResponse(
        id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID,
        choices=[],
        created=dummy_message.created_at,
        model=chat_completion_request.model,
        usage=UsageStatistics(
            completion_tokens=0,
            prompt_tokens=prompt_tokens,
            total_tokens=prompt_tokens,
        ),
    )

    if stream_interface:
        stream_interface.stream_start()

    n_chunks = 0
    try:
        for chunk_idx, chat_completion_chunk in enumerate(
            anthropic_chat_completions_request_stream(
                data=chat_completion_request,
                inner_thoughts_xml_tag=inner_thoughts_xml_tag,
                betas=betas,
            )
        ):
            assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)

            if stream_interface:
                if isinstance(stream_interface, AgentChunkStreamingInterface):
                    stream_interface.process_chunk(
                        chat_completion_chunk,
                        message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id,
                        message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created,
                    )
                elif isinstance(stream_interface, AgentRefreshStreamingInterface):
                    stream_interface.process_refresh(chat_completion_response)
                else:
                    raise TypeError(stream_interface)

            if chunk_idx == 0:
                # initialize the choice objects which we will increment with the deltas
                num_choices = len(chat_completion_chunk.choices)
                assert num_choices > 0
                chat_completion_response.choices = [
                    Choice(
                        finish_reason=TEMP_STREAM_FINISH_REASON,  # NOTE: needs to be overwritten
                        index=i,
                        message=Message(
                            role="assistant",
                        ),
                    )
                    for i in range(num_choices)
                ]

            # add the choice delta
            assert len(chat_completion_chunk.choices) == len(chat_completion_response.choices), chat_completion_chunk
            for chunk_choice in chat_completion_chunk.choices:
                accum_choice = chat_completion_response.choices[chunk_choice.index]

                if chunk_choice.finish_reason is not None:
                    accum_choice.finish_reason = chunk_choice.finish_reason

                if chunk_choice.logprobs is not None:
                    accum_choice.logprobs = chunk_choice.logprobs

                accum_message = accum_choice.message
                message_delta = chunk_choice.delta

                if message_delta.content is not None:
                    if accum_message.content is None:
                        accum_message.content = message_delta.content
                    else:
                        accum_message.content += message_delta.content

                # TODO(charles) make sure this works for parallel tool calling?
                if message_delta.tool_calls is not None:
                    tool_calls_delta = message_delta.tool_calls

                    # If this is the first tool call showing up in a chunk, initialize the list with it
                    if accum_message.tool_calls is None:
                        accum_message.tool_calls = [
                            ToolCall(id=TEMP_STREAM_TOOL_CALL_ID, function=FunctionCall(name="", arguments=""))
                            for _ in tool_calls_delta
                        ]

                    # There may be many tool calls in a tool calls delta (e.g. parallel tool calls)
                    for tool_call_delta in tool_calls_delta:
                        _accumulate_tool_call_delta(accum_message, tool_call_delta)

                if message_delta.function_call is not None:
                    raise NotImplementedError(f"Old function_call style not support with stream=True")

            # overwrite response fields based on latest chunk
            if not create_message_id:
                chat_completion_response.id = chat_completion_chunk.id
            if not create_message_datetime:
                chat_completion_response.created = chat_completion_chunk.created
            chat_completion_response.model = chat_completion_chunk.model
            chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint

            # increment chunk counter
            n_chunks += 1

    except Exception as e:
        # stream_end() is intentionally NOT called here: the `finally` clause below
        # already runs on failure, and calling it in both places ended the stream twice.
        print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
        raise
    finally:
        if stream_interface:
            stream_interface.stream_end()

    # make sure we didn't leave temp stuff in
    assert all(c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices)
    assert all(
        all(tc.id != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls) if c.message.tool_calls else True
        for c in chat_completion_response.choices
    )
    if not create_message_id:
        assert chat_completion_response.id != dummy_message.id

    # compute token usage before returning
    # TODO try actually computing the #tokens instead of assuming the chunks is the same
    chat_completion_response.usage.completion_tokens = n_chunks
    chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks

    assert len(chat_completion_response.choices) > 0, chat_completion_response

    return chat_completion_response
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], + tools=([{"type": "function", "function": f} for f in functions] if functions else None), + tool_choice=tool_call, + max_tokens=1024, # TODO make dynamic + temperature=llm_config.temperature, + stream=stream, + ) + + # Handle streaming + if stream: # Client requested token streaming + assert isinstance(stream_interface, (AgentChunkStreamingInterface, AgentRefreshStreamingInterface)), type(stream_interface) + + response = anthropic_chat_completions_process_stream( + chat_completion_request=chat_completion_request, + stream_interface=stream_interface, + ) + return response + + # Client did not request token streaming (expect a blocking backend response) return anthropic_chat_completions_request( - data=ChatCompletionRequest( - model=llm_config.model, - messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], - tools=[{"type": "function", "function": f} for f in functions] if functions else None, - tool_choice=tool_call, - # user=str(user_id), - # NOTE: max_tokens is required for Anthropic API - max_tokens=1024, # TODO make dynamic - temperature=llm_config.temperature, - ), + data=chat_completion_request, ) # elif llm_config.model_endpoint_type == "cohere": diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 227d8827..ded9d749 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -424,6 +424,16 @@ class StreamingServerInterface(AgentChunkStreamingInterface): choice = chunk.choices[0] message_delta = choice.delta + if ( + message_delta.content is None + and message_delta.tool_calls is None + and message_delta.function_call is None + and choice.finish_reason is None + and chunk.model.startswith("claude-") + ): + # First chunk of Anthropic is empty + return None + # inner thoughts if message_delta.content is not None: processed_chunk = ReasoningMessage( @@ -515,7 +525,11 @@ class 
StreamingServerInterface(AgentChunkStreamingInterface): self.function_id_buffer += tool_call.id if tool_call.function.arguments: - updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) + if chunk.model.startswith("claude-"): + updates_main_json = tool_call.function.arguments + updates_inner_thoughts = "" + else: # OpenAI + updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) # If we have inner thoughts, we should output them as a chunk if updates_inner_thoughts: @@ -585,7 +599,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface): ): # do an additional parse on the updates_main_json if self.function_args_buffer: - updates_main_json = self.function_args_buffer + updates_main_json self.function_args_buffer = None @@ -875,7 +888,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface): raise NotImplementedError("OpenAI proxy streaming temporarily disabled") else: processed_chunk = self._process_chunk_to_letta_style(chunk=chunk, message_id=message_id, message_date=message_date) - if processed_chunk is None: return diff --git a/letta/server/server.py b/letta/server/server.py index 21143a03..1ef4c407 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1277,12 +1277,14 @@ class SyncServer(Server): # This will be attached to the POST SSE request used under-the-hood letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - # Disable token streaming if not OpenAI + # Disable token streaming if not OpenAI or Anthropic # TODO: cleanup this logic llm_config = letta_agent.agent_state.llm_config - if stream_tokens and (llm_config.model_endpoint_type != "openai" or "inference.memgpt.ai" in llm_config.model_endpoint): + if stream_tokens and ( + llm_config.model_endpoint_type not in ["openai", "anthropic"] or "inference.memgpt.ai" in llm_config.model_endpoint + ): warnings.warn( - "Token streaming is only supported 
for models with type 'openai' or `inference.memgpt.ai` in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." + "Token streaming is only supported for models with type 'openai', 'anthropic', or `inference.memgpt.ai` in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." ) stream_tokens = False diff --git a/letta/streaming_utils.py b/letta/streaming_utils.py index 650e6643..485c2a7a 100644 --- a/letta/streaming_utils.py +++ b/letta/streaming_utils.py @@ -209,6 +209,11 @@ class JSONInnerThoughtsExtractor: return updates_main_json, updates_inner_thoughts + # def process_anthropic_fragment(self, fragment) -> Tuple[str, str]: + # # Add to buffer + # self.main_buffer += fragment + # return fragment, "" + @property def main_json(self): return self.main_buffer @@ -233,7 +238,6 @@ class FunctionArgumentsStreamHandler: def process_json_chunk(self, chunk: str) -> Optional[str]: """Process a chunk from the function arguments and return the plaintext version""" - # Use strip to handle only leading and trailing whitespace in control structures if self.accumulating: clean_chunk = chunk.strip() diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index ddaedfd7..d1784da7 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -224,12 +224,29 @@ def test_core_memory(mock_e2b_api_key_none, client: Union[LocalClient, RESTClien assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}" -@pytest.mark.parametrize("stream_tokens", [True, False]) -def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent: AgentState, stream_tokens): +@pytest.mark.parametrize( + "stream_tokens,model", + [ + (True, "gpt-4o-mini"), + (True, "claude-3-sonnet-20240229"), + (False, "gpt-4o-mini"), + (False, 
"claude-3-sonnet-20240229"), + ], +) +def test_streaming_send_message( + mock_e2b_api_key_none, + client: RESTClient, + agent: AgentState, + stream_tokens: bool, + model: str, +): if isinstance(client, LocalClient): pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming") assert isinstance(client, RESTClient), client + # Update agent's model + agent.llm_config.model = model + # First, try streaming just steps # Next, try streaming both steps and tokens @@ -249,11 +266,8 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent send_message_ran = False # 3. Check that we get all the start/stop/end tokens we want # This includes all of the MessageStreamStatus enums - # done_gen = False - # done_step = False done = False - # print(response) assert response, "Sending message failed" for chunk in response: assert isinstance(chunk, LettaStreamingResponse) @@ -268,12 +282,6 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent if chunk == MessageStreamStatus.done: assert not done, "Message stream already done" done = True - # elif chunk == MessageStreamStatus.done_step: - # assert not done_step, "Message stream already done step" - # done_step = True - # elif chunk == MessageStreamStatus.done_generation: - # assert not done_gen, "Message stream already done generation" - # done_gen = True if isinstance(chunk, LettaUsageStatistics): # Some rough metrics for a reasonable usage pattern assert chunk.step_count == 1 @@ -286,8 +294,6 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent assert inner_thoughts_exist, "No inner thoughts found" assert send_message_ran, "send_message function call not found" assert done, "Message stream not done" - # assert done_step, "Message stream not done step" - # assert done_gen, "Message stream not done generation" def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState):