feat: add anthropic streaming (#716)
Co-authored-by: Mindy Long <mindy@letta.com> Co-authored-by: Charles Packer <packercharles@gmail.com>
This commit is contained in:
@@ -731,6 +731,7 @@ class Agent(BaseAgent):
|
||||
# (if yes) Step 4: call the function
|
||||
# (if yes) Step 5: send the info on the function call and function response to LLM
|
||||
response_message = response.choices[0].message
|
||||
|
||||
response_message.model_copy() # TODO why are we copying here?
|
||||
all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(
|
||||
response_message,
|
||||
|
||||
@@ -1,21 +1,41 @@
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import time
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
|
||||
import anthropic
|
||||
from anthropic import PermissionDeniedError
|
||||
from anthropic.types.beta import (
|
||||
BetaRawContentBlockDeltaEvent,
|
||||
BetaRawContentBlockStartEvent,
|
||||
BetaRawContentBlockStopEvent,
|
||||
BetaRawMessageDeltaEvent,
|
||||
BetaRawMessageStartEvent,
|
||||
BetaRawMessageStopEvent,
|
||||
BetaTextBlock,
|
||||
BetaToolUseBlock,
|
||||
)
|
||||
|
||||
from letta.errors import BedrockError, BedrockPermissionError
|
||||
from letta.llm_api.aws_bedrock import get_bedrock_client
|
||||
from letta.schemas.message import Message
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.schemas.message import Message as _Message
|
||||
from letta.schemas.message import MessageRole as _MessageRole
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
|
||||
from letta.schemas.openai.chat_completion_response import (
|
||||
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
|
||||
ChatCompletionChunkResponse,
|
||||
ChatCompletionResponse,
|
||||
Choice,
|
||||
ChunkChoice,
|
||||
FunctionCall,
|
||||
FunctionCallDelta,
|
||||
)
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
from letta.schemas.openai.chat_completion_response import Message
|
||||
from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
|
||||
from letta.schemas.openai.chat_completion_response import MessageDelta, ToolCall, ToolCallDelta, UsageStatistics
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.settings import model_settings
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
||||
from letta.utils import get_utc_time, smart_urljoin
|
||||
|
||||
BASE_URL = "https://api.anthropic.com/v1"
|
||||
@@ -200,6 +220,28 @@ def strip_xml_tags(string: str, tag: Optional[str]) -> str:
|
||||
return re.sub(tag_pattern, "", string)
|
||||
|
||||
|
||||
def strip_xml_tags_streaming(string: str, tag: Optional[str]) -> str:
|
||||
if tag is None:
|
||||
return string
|
||||
|
||||
# Handle common partial tag cases
|
||||
parts_to_remove = [
|
||||
"<", # Leftover start bracket
|
||||
f"<{tag}", # Opening tag start
|
||||
f"</{tag}", # Closing tag start
|
||||
f"/{tag}>", # Closing tag end
|
||||
f"{tag}>", # Opening tag end
|
||||
f"/{tag}", # Partial closing tag without >
|
||||
">", # Leftover end bracket
|
||||
]
|
||||
|
||||
result = string
|
||||
for part in parts_to_remove:
|
||||
result = result.replace(part, "")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def convert_anthropic_response_to_chatcompletion(
|
||||
response: anthropic.types.Message,
|
||||
inner_thoughts_xml_tag: Optional[str] = None,
|
||||
@@ -307,6 +349,166 @@ def convert_anthropic_response_to_chatcompletion(
|
||||
)
|
||||
|
||||
|
||||
def convert_anthropic_stream_event_to_chatcompletion(
|
||||
event: Union[
|
||||
BetaRawMessageStartEvent,
|
||||
BetaRawContentBlockStartEvent,
|
||||
BetaRawContentBlockDeltaEvent,
|
||||
BetaRawContentBlockStopEvent,
|
||||
BetaRawMessageDeltaEvent,
|
||||
BetaRawMessageStopEvent,
|
||||
],
|
||||
message_id: str,
|
||||
model: str,
|
||||
inner_thoughts_xml_tag: Optional[str] = "thinking",
|
||||
) -> ChatCompletionChunkResponse:
|
||||
"""Convert Anthropic stream events to OpenAI ChatCompletionResponse format.
|
||||
|
||||
Args:
|
||||
event: The event to convert
|
||||
message_id: The ID of the message. Anthropic does not return this on every event, so we need to keep track of it
|
||||
model: The model used. Anthropic does not return this on every event, so we need to keep track of it
|
||||
|
||||
Example response from OpenAI:
|
||||
|
||||
'id': 'MESSAGE_ID',
|
||||
'choices': [
|
||||
{
|
||||
'finish_reason': None,
|
||||
'index': 0,
|
||||
'delta': {
|
||||
'content': None,
|
||||
'tool_calls': [
|
||||
{
|
||||
'index': 0,
|
||||
'id': None,
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': None,
|
||||
'arguments': '_th'
|
||||
}
|
||||
}
|
||||
],
|
||||
'function_call': None
|
||||
},
|
||||
'logprobs': None
|
||||
}
|
||||
],
|
||||
'created': datetime.datetime(2025, 1, 24, 0, 18, 55, tzinfo=TzInfo(UTC)),
|
||||
'model': 'gpt-4o-mini-2024-07-18',
|
||||
'system_fingerprint': 'fp_bd83329f63',
|
||||
'object': 'chat.completion.chunk'
|
||||
}
|
||||
"""
|
||||
# Get finish reason
|
||||
finish_reason = None
|
||||
if isinstance(event, BetaRawMessageDeltaEvent):
|
||||
"""
|
||||
BetaRawMessageDeltaEvent(
|
||||
delta=Delta(
|
||||
stop_reason='tool_use',
|
||||
stop_sequence=None
|
||||
),
|
||||
type='message_delta',
|
||||
usage=BetaMessageDeltaUsage(output_tokens=45)
|
||||
)
|
||||
"""
|
||||
finish_reason = remap_finish_reason(event.delta.stop_reason)
|
||||
|
||||
# Get content and tool calls
|
||||
content = None
|
||||
tool_calls = None
|
||||
if isinstance(event, BetaRawContentBlockDeltaEvent):
|
||||
"""
|
||||
BetaRawContentBlockDeltaEvent(
|
||||
delta=BetaInputJSONDelta(
|
||||
partial_json='lo',
|
||||
type='input_json_delta'
|
||||
),
|
||||
index=0,
|
||||
type='content_block_delta'
|
||||
)
|
||||
|
||||
OR
|
||||
|
||||
BetaRawContentBlockDeltaEvent(
|
||||
delta=BetaTextDelta(
|
||||
text='👋 ',
|
||||
type='text_delta'
|
||||
),
|
||||
index=0,
|
||||
type='content_block_delta'
|
||||
)
|
||||
|
||||
"""
|
||||
if event.delta.type == "text_delta":
|
||||
content = strip_xml_tags_streaming(string=event.delta.text, tag=inner_thoughts_xml_tag)
|
||||
|
||||
elif event.delta.type == "input_json_delta":
|
||||
tool_calls = [
|
||||
ToolCallDelta(
|
||||
index=0,
|
||||
function=FunctionCallDelta(
|
||||
name=None,
|
||||
arguments=event.delta.partial_json,
|
||||
),
|
||||
)
|
||||
]
|
||||
elif isinstance(event, BetaRawContentBlockStartEvent):
|
||||
"""
|
||||
BetaRawContentBlockStartEvent(
|
||||
content_block=BetaToolUseBlock(
|
||||
id='toolu_01LmpZhRhR3WdrRdUrfkKfFw',
|
||||
input={},
|
||||
name='get_weather',
|
||||
type='tool_use'
|
||||
),
|
||||
index=0,
|
||||
type='content_block_start'
|
||||
)
|
||||
|
||||
OR
|
||||
|
||||
BetaRawContentBlockStartEvent(
|
||||
content_block=BetaTextBlock(
|
||||
text='',
|
||||
type='text'
|
||||
),
|
||||
index=0,
|
||||
type='content_block_start'
|
||||
)
|
||||
"""
|
||||
if isinstance(event.content_block, BetaToolUseBlock):
|
||||
tool_calls = [
|
||||
ToolCallDelta(
|
||||
index=0,
|
||||
id=event.content_block.id,
|
||||
function=FunctionCallDelta(
|
||||
name=event.content_block.name,
|
||||
arguments="",
|
||||
),
|
||||
)
|
||||
]
|
||||
elif isinstance(event.content_block, BetaTextBlock):
|
||||
content = event.content_block.text
|
||||
|
||||
# Initialize base response
|
||||
choice = ChunkChoice(
|
||||
index=0,
|
||||
finish_reason=finish_reason,
|
||||
delta=MessageDelta(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
)
|
||||
return ChatCompletionChunkResponse(
|
||||
id=message_id,
|
||||
choices=[choice],
|
||||
created=get_utc_time(),
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
def _prepare_anthropic_request(
|
||||
data: ChatCompletionRequest,
|
||||
inner_thoughts_xml_tag: Optional[str] = "thinking",
|
||||
@@ -345,7 +547,7 @@ def _prepare_anthropic_request(
|
||||
message["content"] = None
|
||||
|
||||
# Convert to Anthropic format
|
||||
msg_objs = [Message.dict_to_message(user_id=None, agent_id=None, openai_message_dict=m) for m in data["messages"]]
|
||||
msg_objs = [_Message.dict_to_message(user_id=None, agent_id=None, openai_message_dict=m) for m in data["messages"]]
|
||||
data["messages"] = [m.to_anthropic_dict(inner_thoughts_xml_tag=inner_thoughts_xml_tag) for m in msg_objs]
|
||||
|
||||
# Ensure first message is user
|
||||
@@ -359,7 +561,7 @@ def _prepare_anthropic_request(
|
||||
assert "max_tokens" in data, data
|
||||
|
||||
# Remove OpenAI-specific fields
|
||||
for field in ["frequency_penalty", "logprobs", "n", "top_p", "presence_penalty", "user"]:
|
||||
for field in ["frequency_penalty", "logprobs", "n", "top_p", "presence_penalty", "user", "stream"]:
|
||||
data.pop(field, None)
|
||||
|
||||
return data
|
||||
@@ -427,3 +629,279 @@ def anthropic_bedrock_chat_completions_request(
|
||||
raise BedrockPermissionError(f"User does not have access to the Bedrock model with the specified ID. {data['model']}")
|
||||
except Exception as e:
|
||||
raise BedrockError(f"Bedrock error: {e}")
|
||||
|
||||
|
||||
def anthropic_chat_completions_request_stream(
|
||||
data: ChatCompletionRequest,
|
||||
inner_thoughts_xml_tag: Optional[str] = "thinking",
|
||||
betas: List[str] = ["tools-2024-04-04"],
|
||||
) -> Generator[ChatCompletionChunkResponse, None, None]:
|
||||
"""Stream chat completions from Anthropic API.
|
||||
|
||||
Similar to OpenAI's streaming, but using Anthropic's native streaming support.
|
||||
See: https://docs.anthropic.com/claude/reference/messages-streaming
|
||||
"""
|
||||
data = _prepare_anthropic_request(data, inner_thoughts_xml_tag)
|
||||
|
||||
anthropic_override_key = ProviderManager().get_anthropic_override_key()
|
||||
if anthropic_override_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.Anthropic()
|
||||
|
||||
with anthropic_client.beta.messages.stream(
|
||||
**data,
|
||||
betas=betas,
|
||||
) as stream:
|
||||
# Stream: https://github.com/anthropics/anthropic-sdk-python/blob/d212ec9f6d5e956f13bc0ddc3d86b5888a954383/src/anthropic/lib/streaming/_beta_messages.py#L22
|
||||
message_id = None
|
||||
model = None
|
||||
|
||||
for chunk in stream._raw_stream:
|
||||
time.sleep(0.01) # Anthropic is really fast, faster than frontend can upload.
|
||||
if isinstance(chunk, BetaRawMessageStartEvent):
|
||||
"""
|
||||
BetaRawMessageStartEvent(
|
||||
message=BetaMessage(
|
||||
id='MESSAGE ID HERE',
|
||||
content=[],
|
||||
model='claude-3-5-sonnet-20241022',
|
||||
role='assistant',
|
||||
stop_reason=None,
|
||||
stop_sequence=None,
|
||||
type='message',
|
||||
usage=BetaUsage(
|
||||
cache_creation_input_tokens=0,
|
||||
cache_read_input_tokens=0,
|
||||
input_tokens=30,
|
||||
output_tokens=4
|
||||
)
|
||||
),
|
||||
type='message_start'
|
||||
),
|
||||
"""
|
||||
message_id = chunk.message.id
|
||||
model = chunk.message.model
|
||||
yield convert_anthropic_stream_event_to_chatcompletion(chunk, message_id, model, inner_thoughts_xml_tag)
|
||||
|
||||
|
||||
def anthropic_chat_completions_process_stream(
|
||||
chat_completion_request: ChatCompletionRequest,
|
||||
stream_interface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None,
|
||||
inner_thoughts_xml_tag: Optional[str] = "thinking",
|
||||
create_message_id: bool = True,
|
||||
create_message_datetime: bool = True,
|
||||
betas: List[str] = ["tools-2024-04-04"],
|
||||
) -> ChatCompletionResponse:
|
||||
"""Process a streaming completion response from Anthropic, similar to OpenAI's streaming.
|
||||
|
||||
Args:
|
||||
api_key: The Anthropic API key
|
||||
chat_completion_request: The chat completion request
|
||||
stream_interface: Interface for handling streaming chunks
|
||||
inner_thoughts_xml_tag: Tag for inner thoughts in the response
|
||||
create_message_id: Whether to create a message ID
|
||||
create_message_datetime: Whether to create message datetime
|
||||
betas: Beta features to enable
|
||||
|
||||
Returns:
|
||||
The final ChatCompletionResponse
|
||||
"""
|
||||
assert chat_completion_request.stream == True
|
||||
assert stream_interface is not None, "Required"
|
||||
|
||||
# Count prompt tokens - we'll get completion tokens from the final response
|
||||
chat_history = [m.model_dump(exclude_none=True) for m in chat_completion_request.messages]
|
||||
prompt_tokens = num_tokens_from_messages(
|
||||
messages=chat_history,
|
||||
model=chat_completion_request.model,
|
||||
)
|
||||
|
||||
# Add tokens for tools if present
|
||||
if chat_completion_request.tools is not None:
|
||||
assert chat_completion_request.functions is None
|
||||
prompt_tokens += num_tokens_from_functions(
|
||||
functions=[t.function.model_dump() for t in chat_completion_request.tools],
|
||||
model=chat_completion_request.model,
|
||||
)
|
||||
elif chat_completion_request.functions is not None:
|
||||
assert chat_completion_request.tools is None
|
||||
prompt_tokens += num_tokens_from_functions(
|
||||
functions=[f.model_dump() for f in chat_completion_request.functions],
|
||||
model=chat_completion_request.model,
|
||||
)
|
||||
|
||||
# Create a dummy message for ID/datetime if needed
|
||||
dummy_message = _Message(
|
||||
role=_MessageRole.assistant,
|
||||
text="",
|
||||
agent_id="",
|
||||
model="",
|
||||
name=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
)
|
||||
|
||||
TEMP_STREAM_RESPONSE_ID = "temp_id"
|
||||
TEMP_STREAM_FINISH_REASON = "temp_null"
|
||||
TEMP_STREAM_TOOL_CALL_ID = "temp_id"
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID,
|
||||
choices=[],
|
||||
created=dummy_message.created_at,
|
||||
model=chat_completion_request.model,
|
||||
usage=UsageStatistics(
|
||||
completion_tokens=0,
|
||||
prompt_tokens=prompt_tokens,
|
||||
total_tokens=prompt_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
if stream_interface:
|
||||
stream_interface.stream_start()
|
||||
|
||||
n_chunks = 0
|
||||
try:
|
||||
for chunk_idx, chat_completion_chunk in enumerate(
|
||||
anthropic_chat_completions_request_stream(
|
||||
data=chat_completion_request,
|
||||
inner_thoughts_xml_tag=inner_thoughts_xml_tag,
|
||||
betas=betas,
|
||||
)
|
||||
):
|
||||
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
|
||||
|
||||
if stream_interface:
|
||||
if isinstance(stream_interface, AgentChunkStreamingInterface):
|
||||
stream_interface.process_chunk(
|
||||
chat_completion_chunk,
|
||||
message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id,
|
||||
message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created,
|
||||
)
|
||||
elif isinstance(stream_interface, AgentRefreshStreamingInterface):
|
||||
stream_interface.process_refresh(chat_completion_response)
|
||||
else:
|
||||
raise TypeError(stream_interface)
|
||||
|
||||
if chunk_idx == 0:
|
||||
# initialize the choice objects which we will increment with the deltas
|
||||
num_choices = len(chat_completion_chunk.choices)
|
||||
assert num_choices > 0
|
||||
chat_completion_response.choices = [
|
||||
Choice(
|
||||
finish_reason=TEMP_STREAM_FINISH_REASON, # NOTE: needs to be ovrerwritten
|
||||
index=i,
|
||||
message=Message(
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
for i in range(len(chat_completion_chunk.choices))
|
||||
]
|
||||
|
||||
# add the choice delta
|
||||
assert len(chat_completion_chunk.choices) == len(chat_completion_response.choices), chat_completion_chunk
|
||||
for chunk_choice in chat_completion_chunk.choices:
|
||||
if chunk_choice.finish_reason is not None:
|
||||
chat_completion_response.choices[chunk_choice.index].finish_reason = chunk_choice.finish_reason
|
||||
|
||||
if chunk_choice.logprobs is not None:
|
||||
chat_completion_response.choices[chunk_choice.index].logprobs = chunk_choice.logprobs
|
||||
|
||||
accum_message = chat_completion_response.choices[chunk_choice.index].message
|
||||
message_delta = chunk_choice.delta
|
||||
|
||||
if message_delta.content is not None:
|
||||
content_delta = message_delta.content
|
||||
if accum_message.content is None:
|
||||
accum_message.content = content_delta
|
||||
else:
|
||||
accum_message.content += content_delta
|
||||
|
||||
# TODO(charles) make sure this works for parallel tool calling?
|
||||
if message_delta.tool_calls is not None:
|
||||
tool_calls_delta = message_delta.tool_calls
|
||||
|
||||
# If this is the first tool call showing up in a chunk, initialize the list with it
|
||||
if accum_message.tool_calls is None:
|
||||
accum_message.tool_calls = [
|
||||
ToolCall(id=TEMP_STREAM_TOOL_CALL_ID, function=FunctionCall(name="", arguments=""))
|
||||
for _ in range(len(tool_calls_delta))
|
||||
]
|
||||
|
||||
# There may be many tool calls in a tool calls delta (e.g. parallel tool calls)
|
||||
for tool_call_delta in tool_calls_delta:
|
||||
if tool_call_delta.id is not None:
|
||||
# TODO assert that we're not overwriting?
|
||||
# TODO += instead of =?
|
||||
if tool_call_delta.index not in range(len(accum_message.tool_calls)):
|
||||
warnings.warn(
|
||||
f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}"
|
||||
)
|
||||
# force index 0
|
||||
# accum_message.tool_calls[0].id = tool_call_delta.id
|
||||
else:
|
||||
accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id
|
||||
if tool_call_delta.function is not None:
|
||||
if tool_call_delta.function.name is not None:
|
||||
# TODO assert that we're not overwriting?
|
||||
# TODO += instead of =?
|
||||
if tool_call_delta.index not in range(len(accum_message.tool_calls)):
|
||||
warnings.warn(
|
||||
f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}"
|
||||
)
|
||||
# force index 0
|
||||
# accum_message.tool_calls[0].function.name = tool_call_delta.function.name
|
||||
else:
|
||||
accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name
|
||||
if tool_call_delta.function.arguments is not None:
|
||||
if tool_call_delta.index not in range(len(accum_message.tool_calls)):
|
||||
warnings.warn(
|
||||
f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}"
|
||||
)
|
||||
# force index 0
|
||||
# accum_message.tool_calls[0].function.arguments += tool_call_delta.function.arguments
|
||||
else:
|
||||
accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments
|
||||
|
||||
if message_delta.function_call is not None:
|
||||
raise NotImplementedError(f"Old function_call style not support with stream=True")
|
||||
|
||||
# overwrite response fields based on latest chunk
|
||||
if not create_message_id:
|
||||
chat_completion_response.id = chat_completion_chunk.id
|
||||
if not create_message_datetime:
|
||||
chat_completion_response.created = chat_completion_chunk.created
|
||||
chat_completion_response.model = chat_completion_chunk.model
|
||||
chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint
|
||||
|
||||
# increment chunk counter
|
||||
n_chunks += 1
|
||||
|
||||
except Exception as e:
|
||||
if stream_interface:
|
||||
stream_interface.stream_end()
|
||||
print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
|
||||
raise e
|
||||
finally:
|
||||
if stream_interface:
|
||||
stream_interface.stream_end()
|
||||
|
||||
# make sure we didn't leave temp stuff in
|
||||
assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices])
|
||||
assert all(
|
||||
[
|
||||
all([tc.id != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True
|
||||
for c in chat_completion_response.choices
|
||||
]
|
||||
)
|
||||
if not create_message_id:
|
||||
assert chat_completion_response.id != dummy_message.id
|
||||
|
||||
# compute token usage before returning
|
||||
# TODO try actually computing the #tokens instead of assuming the chunks is the same
|
||||
chat_completion_response.usage.completion_tokens = n_chunks
|
||||
chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks
|
||||
|
||||
assert len(chat_completion_response.choices) > 0, chat_completion_response
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
@@ -6,7 +6,11 @@ import requests
|
||||
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.errors import LettaConfigurationError, RateLimitExceededError
|
||||
from letta.llm_api.anthropic import anthropic_bedrock_chat_completions_request, anthropic_chat_completions_request
|
||||
from letta.llm_api.anthropic import (
|
||||
anthropic_bedrock_chat_completions_request,
|
||||
anthropic_chat_completions_process_stream,
|
||||
anthropic_chat_completions_request,
|
||||
)
|
||||
from letta.llm_api.aws_bedrock import has_valid_aws_credentials
|
||||
from letta.llm_api.azure_openai import azure_openai_chat_completions_request
|
||||
from letta.llm_api.google_ai import convert_tools_to_google_ai_format, google_ai_chat_completions_request
|
||||
@@ -243,27 +247,38 @@ def create(
|
||||
)
|
||||
|
||||
elif llm_config.model_endpoint_type == "anthropic":
|
||||
if stream:
|
||||
raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}")
|
||||
if not use_tool_naming:
|
||||
raise NotImplementedError("Only tool calling supported on Anthropic API requests")
|
||||
|
||||
# Force tool calling
|
||||
tool_call = None
|
||||
if force_tool_call is not None:
|
||||
tool_call = {"type": "function", "function": {"name": force_tool_call}}
|
||||
assert functions is not None
|
||||
|
||||
chat_completion_request = ChatCompletionRequest(
|
||||
model=llm_config.model,
|
||||
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
|
||||
tools=([{"type": "function", "function": f} for f in functions] if functions else None),
|
||||
tool_choice=tool_call,
|
||||
max_tokens=1024, # TODO make dynamic
|
||||
temperature=llm_config.temperature,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
# Handle streaming
|
||||
if stream: # Client requested token streaming
|
||||
assert isinstance(stream_interface, (AgentChunkStreamingInterface, AgentRefreshStreamingInterface)), type(stream_interface)
|
||||
|
||||
response = anthropic_chat_completions_process_stream(
|
||||
chat_completion_request=chat_completion_request,
|
||||
stream_interface=stream_interface,
|
||||
)
|
||||
return response
|
||||
|
||||
# Client did not request token streaming (expect a blocking backend response)
|
||||
return anthropic_chat_completions_request(
|
||||
data=ChatCompletionRequest(
|
||||
model=llm_config.model,
|
||||
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
|
||||
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
|
||||
tool_choice=tool_call,
|
||||
# user=str(user_id),
|
||||
# NOTE: max_tokens is required for Anthropic API
|
||||
max_tokens=1024, # TODO make dynamic
|
||||
temperature=llm_config.temperature,
|
||||
),
|
||||
data=chat_completion_request,
|
||||
)
|
||||
|
||||
# elif llm_config.model_endpoint_type == "cohere":
|
||||
|
||||
@@ -424,6 +424,16 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
choice = chunk.choices[0]
|
||||
message_delta = choice.delta
|
||||
|
||||
if (
|
||||
message_delta.content is None
|
||||
and message_delta.tool_calls is None
|
||||
and message_delta.function_call is None
|
||||
and choice.finish_reason is None
|
||||
and chunk.model.startswith("claude-")
|
||||
):
|
||||
# First chunk of Anthropic is empty
|
||||
return None
|
||||
|
||||
# inner thoughts
|
||||
if message_delta.content is not None:
|
||||
processed_chunk = ReasoningMessage(
|
||||
@@ -515,7 +525,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
self.function_id_buffer += tool_call.id
|
||||
|
||||
if tool_call.function.arguments:
|
||||
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
if chunk.model.startswith("claude-"):
|
||||
updates_main_json = tool_call.function.arguments
|
||||
updates_inner_thoughts = ""
|
||||
else: # OpenAI
|
||||
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
|
||||
# If we have inner thoughts, we should output them as a chunk
|
||||
if updates_inner_thoughts:
|
||||
@@ -585,7 +599,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
):
|
||||
# do an additional parse on the updates_main_json
|
||||
if self.function_args_buffer:
|
||||
|
||||
updates_main_json = self.function_args_buffer + updates_main_json
|
||||
self.function_args_buffer = None
|
||||
|
||||
@@ -875,7 +888,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
raise NotImplementedError("OpenAI proxy streaming temporarily disabled")
|
||||
else:
|
||||
processed_chunk = self._process_chunk_to_letta_style(chunk=chunk, message_id=message_id, message_date=message_date)
|
||||
|
||||
if processed_chunk is None:
|
||||
return
|
||||
|
||||
|
||||
@@ -1277,12 +1277,14 @@ class SyncServer(Server):
|
||||
# This will be attached to the POST SSE request used under-the-hood
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Disable token streaming if not OpenAI
|
||||
# Disable token streaming if not OpenAI or Anthropic
|
||||
# TODO: cleanup this logic
|
||||
llm_config = letta_agent.agent_state.llm_config
|
||||
if stream_tokens and (llm_config.model_endpoint_type != "openai" or "inference.memgpt.ai" in llm_config.model_endpoint):
|
||||
if stream_tokens and (
|
||||
llm_config.model_endpoint_type not in ["openai", "anthropic"] or "inference.memgpt.ai" in llm_config.model_endpoint
|
||||
):
|
||||
warnings.warn(
|
||||
"Token streaming is only supported for models with type 'openai' or `inference.memgpt.ai` in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
|
||||
"Token streaming is only supported for models with type 'openai', 'anthropic', or `inference.memgpt.ai` in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
|
||||
)
|
||||
stream_tokens = False
|
||||
|
||||
|
||||
@@ -209,6 +209,11 @@ class JSONInnerThoughtsExtractor:
|
||||
|
||||
return updates_main_json, updates_inner_thoughts
|
||||
|
||||
# def process_anthropic_fragment(self, fragment) -> Tuple[str, str]:
|
||||
# # Add to buffer
|
||||
# self.main_buffer += fragment
|
||||
# return fragment, ""
|
||||
|
||||
@property
|
||||
def main_json(self):
|
||||
return self.main_buffer
|
||||
@@ -233,7 +238,6 @@ class FunctionArgumentsStreamHandler:
|
||||
|
||||
def process_json_chunk(self, chunk: str) -> Optional[str]:
|
||||
"""Process a chunk from the function arguments and return the plaintext version"""
|
||||
|
||||
# Use strip to handle only leading and trailing whitespace in control structures
|
||||
if self.accumulating:
|
||||
clean_chunk = chunk.strip()
|
||||
|
||||
@@ -224,12 +224,29 @@ def test_core_memory(mock_e2b_api_key_none, client: Union[LocalClient, RESTClien
|
||||
assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream_tokens", [True, False])
|
||||
def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent: AgentState, stream_tokens):
|
||||
@pytest.mark.parametrize(
|
||||
"stream_tokens,model",
|
||||
[
|
||||
(True, "gpt-4o-mini"),
|
||||
(True, "claude-3-sonnet-20240229"),
|
||||
(False, "gpt-4o-mini"),
|
||||
(False, "claude-3-sonnet-20240229"),
|
||||
],
|
||||
)
|
||||
def test_streaming_send_message(
|
||||
mock_e2b_api_key_none,
|
||||
client: RESTClient,
|
||||
agent: AgentState,
|
||||
stream_tokens: bool,
|
||||
model: str,
|
||||
):
|
||||
if isinstance(client, LocalClient):
|
||||
pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming")
|
||||
assert isinstance(client, RESTClient), client
|
||||
|
||||
# Update agent's model
|
||||
agent.llm_config.model = model
|
||||
|
||||
# First, try streaming just steps
|
||||
|
||||
# Next, try streaming both steps and tokens
|
||||
@@ -249,11 +266,8 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
|
||||
send_message_ran = False
|
||||
# 3. Check that we get all the start/stop/end tokens we want
|
||||
# This includes all of the MessageStreamStatus enums
|
||||
# done_gen = False
|
||||
# done_step = False
|
||||
done = False
|
||||
|
||||
# print(response)
|
||||
assert response, "Sending message failed"
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LettaStreamingResponse)
|
||||
@@ -268,12 +282,6 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
|
||||
if chunk == MessageStreamStatus.done:
|
||||
assert not done, "Message stream already done"
|
||||
done = True
|
||||
# elif chunk == MessageStreamStatus.done_step:
|
||||
# assert not done_step, "Message stream already done step"
|
||||
# done_step = True
|
||||
# elif chunk == MessageStreamStatus.done_generation:
|
||||
# assert not done_gen, "Message stream already done generation"
|
||||
# done_gen = True
|
||||
if isinstance(chunk, LettaUsageStatistics):
|
||||
# Some rough metrics for a reasonable usage pattern
|
||||
assert chunk.step_count == 1
|
||||
@@ -286,8 +294,6 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
|
||||
assert inner_thoughts_exist, "No inner thoughts found"
|
||||
assert send_message_ran, "send_message function call not found"
|
||||
assert done, "Message stream not done"
|
||||
# assert done_step, "Message stream not done step"
|
||||
# assert done_gen, "Message stream not done generation"
|
||||
|
||||
|
||||
def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
||||
Reference in New Issue
Block a user