fix: March 11 fixes (#2479)
Co-authored-by: cthomas <caren@letta.com> Co-authored-by: Sarah Wooders <sarahwooders@gmail.com> Co-authored-by: Kevin Lin <klin5061@gmail.com> Co-authored-by: Shubham Naik <shub@letta.com> Co-authored-by: Shubham Naik <shub@memgpt.ai> Co-authored-by: Charles Packer <packercharles@gmail.com> Co-authored-by: Shubham Naik <shubham.naik10@gmail.com> Co-authored-by: mlong93 <35275280+mlong93@users.noreply.github.com> Co-authored-by: Mindy Long <mindy@letta.com> Co-authored-by: Stephan Fitzpatrick <stephan@knowsuchagency.com> Co-authored-by: dboyliao <qmalliao@gmail.com> Co-authored-by: Jyotirmaya Mahanta <jyotirmaya.mahanta@gmail.com> Co-authored-by: Nicholas <102550462+ndisalvio3@users.noreply.github.com> Co-authored-by: tarunkumark <tkksctwo@gmail.com> Co-authored-by: Miao <one.lemorage@gmail.com> Co-authored-by: Krishnakumar R (KK) <65895020+kk-src@users.noreply.github.com> Co-authored-by: Will Sargent <will.sargent@gmail.com>
This commit is contained in:
31
alembic/versions/d211df879a5f_add_agent_id_to_steps.py
Normal file
31
alembic/versions/d211df879a5f_add_agent_id_to_steps.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""add agent id to steps
|
||||
|
||||
Revision ID: d211df879a5f
|
||||
Revises: 2f4ede6ae33b
|
||||
Create Date: 2025-03-06 21:42:22.289345
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d211df879a5f"
|
||||
down_revision: Union[str, None] = "2f4ede6ae33b"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("steps", sa.Column("agent_id", sa.String(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("steps", "agent_id")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.6.37"
|
||||
__version__ = "0.6.38"
|
||||
|
||||
# import clients
|
||||
from letta.client.client import LocalClient, RESTClient, create_client
|
||||
|
||||
@@ -29,6 +29,7 @@ from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.log import get_logger
|
||||
from letta.memory import summarize_messages
|
||||
@@ -356,19 +357,38 @@ class Agent(BaseAgent):
|
||||
for attempt in range(1, empty_response_retry_limit + 1):
|
||||
try:
|
||||
log_telemetry(self.logger, "_get_ai_reply create start")
|
||||
response = create(
|
||||
# New LLM client flow
|
||||
llm_client = LLMClient.create(
|
||||
agent_id=self.agent_state.id,
|
||||
llm_config=self.agent_state.llm_config,
|
||||
messages=message_sequence,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
functions=allowed_functions,
|
||||
# functions_python=self.functions_python, do we need this?
|
||||
function_call=function_call,
|
||||
first_message=first_message,
|
||||
force_tool_call=force_tool_call,
|
||||
stream=stream,
|
||||
stream_interface=self.interface,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor_id=self.agent_state.created_by_id,
|
||||
)
|
||||
|
||||
if llm_client and not stream:
|
||||
response = llm_client.send_llm_request(
|
||||
messages=message_sequence,
|
||||
tools=allowed_functions,
|
||||
tool_call=function_call,
|
||||
stream=stream,
|
||||
first_message=first_message,
|
||||
force_tool_call=force_tool_call,
|
||||
)
|
||||
else:
|
||||
# Fallback to existing flow
|
||||
response = create(
|
||||
llm_config=self.agent_state.llm_config,
|
||||
messages=message_sequence,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
functions=allowed_functions,
|
||||
# functions_python=self.functions_python, do we need this?
|
||||
function_call=function_call,
|
||||
first_message=first_message,
|
||||
force_tool_call=force_tool_call,
|
||||
stream=stream,
|
||||
stream_interface=self.interface,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
log_telemetry(self.logger, "_get_ai_reply create finish")
|
||||
|
||||
# These bottom two are retryable
|
||||
@@ -632,7 +652,7 @@ class Agent(BaseAgent):
|
||||
function_args,
|
||||
function_response,
|
||||
messages,
|
||||
[tool_return] if tool_return else None,
|
||||
[tool_return],
|
||||
include_function_failed_message=True,
|
||||
)
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
@@ -659,7 +679,7 @@ class Agent(BaseAgent):
|
||||
"content": function_response,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
tool_returns=[tool_return] if tool_return else None,
|
||||
tool_returns=[tool_return] if sandbox_run_result else None,
|
||||
)
|
||||
) # extend conversation with function response
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
|
||||
@@ -909,6 +929,7 @@ class Agent(BaseAgent):
|
||||
# Log step - this must happen before messages are persisted
|
||||
step = self.step_manager.log_step(
|
||||
actor=self.user,
|
||||
agent_id=self.agent_state.id,
|
||||
provider_name=self.agent_state.llm_config.model_endpoint_type,
|
||||
model=self.agent_state.llm_config.model,
|
||||
model_endpoint=self.agent_state.llm_config.model_endpoint,
|
||||
@@ -1174,6 +1195,7 @@ class Agent(BaseAgent):
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
previous_message_count=self.message_manager.size(actor=self.user, agent_id=self.agent_state.id),
|
||||
archival_memory_size=self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id),
|
||||
recent_passages=self.agent_manager.list_passages(actor=self.user, agent_id=self.agent_state.id, ascending=False, limit=10),
|
||||
)
|
||||
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import time
|
||||
from typing import Callable, Dict, Generator, List, Optional, Union
|
||||
|
||||
import requests
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
|
||||
import letta.utils
|
||||
from letta.constants import ADMIN_PREFIX, BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, FUNCTION_RETURN_CHAR_LIMIT
|
||||
@@ -29,7 +28,7 @@ from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ArchivalMemorySummary, ChatMemory, CreateArchivalMemory, Memory, RecallMemorySummary
|
||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
@@ -640,30 +639,6 @@ class RESTClient(AbstractClient):
|
||||
# refresh and return agent
|
||||
return self.get_agent(agent_state.id)
|
||||
|
||||
def update_message(
|
||||
self,
|
||||
agent_id: str,
|
||||
message_id: str,
|
||||
role: Optional[MessageRole] = None,
|
||||
text: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
tool_calls: Optional[List[OpenAIToolCall]] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
) -> Message:
|
||||
request = MessageUpdate(
|
||||
role=role,
|
||||
content=text,
|
||||
name=name,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
response = requests.patch(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update message: {response.text}")
|
||||
return Message(**response.json())
|
||||
|
||||
def update_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -2436,30 +2411,6 @@ class LocalClient(AbstractClient):
|
||||
# TODO: get full agent state
|
||||
return self.server.agent_manager.get_agent_by_id(agent_state.id, actor=self.user)
|
||||
|
||||
def update_message(
|
||||
self,
|
||||
agent_id: str,
|
||||
message_id: str,
|
||||
role: Optional[MessageRole] = None,
|
||||
text: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
tool_calls: Optional[List[OpenAIToolCall]] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
) -> Message:
|
||||
message = self.server.update_agent_message(
|
||||
agent_id=agent_id,
|
||||
message_id=message_id,
|
||||
request=MessageUpdate(
|
||||
role=role,
|
||||
content=text,
|
||||
name=name,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
actor=self.user,
|
||||
)
|
||||
return message
|
||||
|
||||
def update_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
||||
@@ -50,7 +50,7 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "
|
||||
# Base memory tools CAN be edited, and are added by default by the server
|
||||
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
|
||||
# Multi agent tools
|
||||
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"]
|
||||
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"]
|
||||
# Set of all built-in Letta tools
|
||||
LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from letta.functions.helpers import (
|
||||
_send_message_to_agents_matching_all_tags_async,
|
||||
_send_message_to_agents_matching_tags_async,
|
||||
execute_send_message_to_agent,
|
||||
fire_and_forget_send_to_agent,
|
||||
)
|
||||
@@ -70,18 +70,19 @@ def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str
|
||||
return "Successfully sent message"
|
||||
|
||||
|
||||
def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: List[str]) -> List[str]:
|
||||
def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all: List[str], match_some: List[str]) -> List[str]:
|
||||
"""
|
||||
Sends a message to all agents within the same organization that match all of the specified tags. Messages are dispatched in parallel for improved performance, with retries to handle transient issues and timeouts to ensure responsiveness. This function enforces a limit of 100 agents and does not support pagination (cursor-based queries). Each agent must match all specified tags (`match_all_tags=True`) to be included.
|
||||
Sends a message to all agents within the same organization that match the specified tag criteria. Agents must possess *all* of the tags in `match_all` and *at least one* of the tags in `match_some` to receive the message.
|
||||
|
||||
Args:
|
||||
message (str): The content of the message to be sent to each matching agent.
|
||||
tags (List[str]): A list of tags that an agent must possess to receive the message.
|
||||
match_all (List[str]): A list of tags that an agent must possess to receive the message.
|
||||
match_some (List[str]): A list of tags where an agent must have at least one to qualify.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of responses from the agents that matched all tags. Each
|
||||
response corresponds to a single agent. Agents that do not respond will not
|
||||
have an entry in the returned list.
|
||||
List[str]: A list of responses from the agents that matched the filtering criteria. Each
|
||||
response corresponds to a single agent. Agents that do not respond will not have an entry
|
||||
in the returned list.
|
||||
"""
|
||||
|
||||
return asyncio.run(_send_message_to_agents_matching_all_tags_async(self, message, tags))
|
||||
return asyncio.run(_send_message_to_agents_matching_tags_async(self, message, match_all, match_some))
|
||||
|
||||
@@ -518,8 +518,16 @@ def fire_and_forget_send_to_agent(
|
||||
run_in_background_thread(background_task())
|
||||
|
||||
|
||||
async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", message: str, tags: List[str]) -> List[str]:
|
||||
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async start", message=message, tags=tags)
|
||||
async def _send_message_to_agents_matching_tags_async(
|
||||
sender_agent: "Agent", message: str, match_all: List[str], match_some: List[str]
|
||||
) -> List[str]:
|
||||
log_telemetry(
|
||||
sender_agent.logger,
|
||||
"_send_message_to_agents_matching_tags_async start",
|
||||
message=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
server = get_letta_server()
|
||||
|
||||
augmented_message = (
|
||||
@@ -529,9 +537,22 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent",
|
||||
)
|
||||
|
||||
# Retrieve up to 100 matching agents
|
||||
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents start", message=message, tags=tags)
|
||||
matching_agents = server.agent_manager.list_agents(actor=sender_agent.user, tags=tags, match_all_tags=True, limit=100)
|
||||
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents finish", message=message, tags=tags)
|
||||
log_telemetry(
|
||||
sender_agent.logger,
|
||||
"_send_message_to_agents_matching_tags_async listing agents start",
|
||||
message=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
matching_agents = server.agent_manager.list_agents_matching_tags(actor=sender_agent.user, match_all=match_all, match_some=match_some)
|
||||
|
||||
log_telemetry(
|
||||
sender_agent.logger,
|
||||
"_send_message_to_agents_matching_tags_async listing agents finish",
|
||||
message=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
|
||||
# Create a system message
|
||||
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)]
|
||||
@@ -559,7 +580,13 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent",
|
||||
else:
|
||||
final.append(r)
|
||||
|
||||
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async finish", message=message, tags=tags)
|
||||
log_telemetry(
|
||||
sender_agent.logger,
|
||||
"_send_message_to_agents_matching_tags_async finish",
|
||||
message=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
return final
|
||||
|
||||
|
||||
|
||||
332
letta/llm_api/google_ai_client.py
Normal file
332
letta/llm_api/google_ai_client.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.llm_api.helpers import make_post_request
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import Tool
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
|
||||
from letta.settings import model_settings
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
|
||||
class GoogleAIClient(LLMClientBase):
|
||||
|
||||
def request(self, request_data: dict) -> dict:
|
||||
"""
|
||||
Performs underlying request to llm and returns raw response.
|
||||
"""
|
||||
url, headers = self.get_gemini_endpoint_and_headers(generate_content=True)
|
||||
return make_post_request(url, headers, request_data)
|
||||
|
||||
def build_request_data(
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
tools: List[dict],
|
||||
tool_call: Optional[str],
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for this client.
|
||||
"""
|
||||
if tools:
|
||||
tools = [{"type": "function", "function": f} for f in tools]
|
||||
tools = self.convert_tools_to_google_ai_format(
|
||||
[Tool(**t) for t in tools],
|
||||
)
|
||||
contents = self.add_dummy_model_messages(
|
||||
[m.to_google_ai_dict() for m in messages],
|
||||
)
|
||||
|
||||
return {
|
||||
"contents": contents,
|
||||
"tools": tools,
|
||||
"generation_config": {
|
||||
"temperature": self.llm_config.temperature,
|
||||
"max_output_tokens": self.llm_config.max_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
def convert_response_to_chat_completion(
|
||||
self,
|
||||
response_data: dict,
|
||||
input_messages: List[PydanticMessage],
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Converts custom response format from llm client into an OpenAI
|
||||
ChatCompletionsResponse object.
|
||||
|
||||
Example Input:
|
||||
{
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 9,
|
||||
"candidatesTokenCount": 27,
|
||||
"totalTokenCount": 36
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
choices = []
|
||||
index = 0
|
||||
for candidate in response_data["candidates"]:
|
||||
content = candidate["content"]
|
||||
|
||||
role = content["role"]
|
||||
assert role == "model", f"Unknown role in response: {role}"
|
||||
|
||||
parts = content["parts"]
|
||||
# TODO support parts / multimodal
|
||||
# TODO support parallel tool calling natively
|
||||
# TODO Alternative here is to throw away everything else except for the first part
|
||||
for response_message in parts:
|
||||
# Convert the actual message style to OpenAI style
|
||||
if "functionCall" in response_message and response_message["functionCall"] is not None:
|
||||
function_call = response_message["functionCall"]
|
||||
assert isinstance(function_call, dict), function_call
|
||||
function_name = function_call["name"]
|
||||
assert isinstance(function_name, str), function_name
|
||||
function_args = function_call["args"]
|
||||
assert isinstance(function_args, dict), function_args
|
||||
|
||||
# NOTE: this also involves stripping the inner monologue out of the function
|
||||
if self.llm_config.put_inner_thoughts_in_kwargs:
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
|
||||
assert INNER_THOUGHTS_KWARG in function_args, f"Couldn't find inner thoughts in function args:\n{function_call}"
|
||||
inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG)
|
||||
assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
|
||||
else:
|
||||
inner_thoughts = None
|
||||
|
||||
# Google AI API doesn't generate tool call IDs
|
||||
openai_response_message = Message(
|
||||
role="assistant", # NOTE: "model" -> "assistant"
|
||||
content=inner_thoughts,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id=get_tool_call_id(),
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_name,
|
||||
arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
# Inner thoughts are the content by default
|
||||
inner_thoughts = response_message["text"]
|
||||
|
||||
# Google AI API doesn't generate tool call IDs
|
||||
openai_response_message = Message(
|
||||
role="assistant", # NOTE: "model" -> "assistant"
|
||||
content=inner_thoughts,
|
||||
)
|
||||
|
||||
# Google AI API uses different finish reason strings than OpenAI
|
||||
# OpenAI: 'stop', 'length', 'function_call', 'content_filter', null
|
||||
# see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api
|
||||
# Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
|
||||
# see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
|
||||
finish_reason = candidate["finishReason"]
|
||||
if finish_reason == "STOP":
|
||||
openai_finish_reason = (
|
||||
"function_call"
|
||||
if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0
|
||||
else "stop"
|
||||
)
|
||||
elif finish_reason == "MAX_TOKENS":
|
||||
openai_finish_reason = "length"
|
||||
elif finish_reason == "SAFETY":
|
||||
openai_finish_reason = "content_filter"
|
||||
elif finish_reason == "RECITATION":
|
||||
openai_finish_reason = "content_filter"
|
||||
else:
|
||||
raise ValueError(f"Unrecognized finish reason in Google AI response: {finish_reason}")
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
finish_reason=openai_finish_reason,
|
||||
index=index,
|
||||
message=openai_response_message,
|
||||
)
|
||||
)
|
||||
index += 1
|
||||
|
||||
# if len(choices) > 1:
|
||||
# raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})")
|
||||
|
||||
# NOTE: some of the Google AI APIs show UsageMetadata in the response, but it seems to not exist?
|
||||
# "usageMetadata": {
|
||||
# "promptTokenCount": 9,
|
||||
# "candidatesTokenCount": 27,
|
||||
# "totalTokenCount": 36
|
||||
# }
|
||||
if "usageMetadata" in response_data:
|
||||
usage = UsageStatistics(
|
||||
prompt_tokens=response_data["usageMetadata"]["promptTokenCount"],
|
||||
completion_tokens=response_data["usageMetadata"]["candidatesTokenCount"],
|
||||
total_tokens=response_data["usageMetadata"]["totalTokenCount"],
|
||||
)
|
||||
else:
|
||||
# Count it ourselves
|
||||
assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required"
|
||||
prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation
|
||||
completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
usage = UsageStatistics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
|
||||
response_id = str(uuid.uuid4())
|
||||
return ChatCompletionResponse(
|
||||
id=response_id,
|
||||
choices=choices,
|
||||
model=self.llm_config.model, # NOTE: Google API doesn't pass back model in the response
|
||||
created=get_utc_time(),
|
||||
usage=usage,
|
||||
)
|
||||
except KeyError as e:
|
||||
raise e
|
||||
|
||||
def get_gemini_endpoint_and_headers(
|
||||
self,
|
||||
key_in_header: bool = True,
|
||||
generate_content: bool = False,
|
||||
) -> Tuple[str, dict]:
|
||||
"""
|
||||
Dynamically generate the model endpoint and headers.
|
||||
"""
|
||||
|
||||
url = f"{self.llm_config.model_endpoint}/v1beta/models"
|
||||
|
||||
# Add the model
|
||||
url += f"/{self.llm_config.model}"
|
||||
|
||||
# Add extension for generating content if we're hitting the LM
|
||||
if generate_content:
|
||||
url += ":generateContent"
|
||||
|
||||
# Decide if api key should be in header or not
|
||||
# Two ways to pass the key: https://ai.google.dev/tutorials/setup
|
||||
if key_in_header:
|
||||
headers = {"Content-Type": "application/json", "x-goog-api-key": model_settings.gemini_api_key}
|
||||
else:
|
||||
url += f"?key={model_settings.gemini_api_key}"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
return url, headers
|
||||
|
||||
def convert_tools_to_google_ai_format(self, tools: List[Tool]) -> List[dict]:
|
||||
"""
|
||||
OpenAI style:
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "find_movies",
|
||||
"description": "find ....",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
PARAM: {
|
||||
"type": PARAM_TYPE, # eg "string"
|
||||
"description": PARAM_DESCRIPTION,
|
||||
},
|
||||
...
|
||||
},
|
||||
"required": List[str],
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
Google AI style:
|
||||
"tools": [{
|
||||
"functionDeclarations": [{
|
||||
"name": "find_movies",
|
||||
"description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "STRING",
|
||||
"description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
|
||||
},
|
||||
"description": {
|
||||
"type": "STRING",
|
||||
"description": "Any kind of description including category or genre, title words, attributes, etc."
|
||||
}
|
||||
},
|
||||
"required": ["description"]
|
||||
}
|
||||
}, {
|
||||
"name": "find_theaters",
|
||||
...
|
||||
"""
|
||||
function_list = [
|
||||
dict(
|
||||
name=t.function.name,
|
||||
description=t.function.description,
|
||||
parameters=t.function.parameters, # TODO need to unpack
|
||||
)
|
||||
for t in tools
|
||||
]
|
||||
|
||||
# Correct casing + add inner thoughts if needed
|
||||
for func in function_list:
|
||||
func["parameters"]["type"] = "OBJECT"
|
||||
for param_name, param_fields in func["parameters"]["properties"].items():
|
||||
param_fields["type"] = param_fields["type"].upper()
|
||||
# Add inner thoughts
|
||||
if self.llm_config.put_inner_thoughts_in_kwargs:
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
|
||||
func["parameters"]["properties"][INNER_THOUGHTS_KWARG] = {
|
||||
"type": "STRING",
|
||||
"description": INNER_THOUGHTS_KWARG_DESCRIPTION,
|
||||
}
|
||||
func["parameters"]["required"].append(INNER_THOUGHTS_KWARG)
|
||||
|
||||
return [{"functionDeclarations": function_list}]
|
||||
|
||||
def add_dummy_model_messages(self, messages: List[dict]) -> List[dict]:
|
||||
"""Google AI API requires all function call returns are immediately followed by a 'model' role message.
|
||||
|
||||
In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user,
|
||||
so there is no natural follow-up 'model' role message.
|
||||
|
||||
To satisfy the Google AI API restrictions, we can add a dummy 'yield' message
|
||||
with role == 'model' that is placed in-betweeen and function output
|
||||
(role == 'tool') and user message (role == 'user').
|
||||
"""
|
||||
dummy_yield_message = {
|
||||
"role": "model",
|
||||
"parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}],
|
||||
}
|
||||
messages_with_padding = []
|
||||
for i, message in enumerate(messages):
|
||||
messages_with_padding.append(message)
|
||||
# Check if the current message role is 'tool' and the next message role is 'user'
|
||||
if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"):
|
||||
messages_with_padding.append(dummy_yield_message)
|
||||
|
||||
return messages_with_padding
|
||||
214
letta/llm_api/google_vertex_client.py
Normal file
214
letta/llm_api/google_vertex_client.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from google import genai
|
||||
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, ToolConfig
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.llm_api.google_ai_client import GoogleAIClient
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
|
||||
from letta.settings import model_settings
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
|
||||
class GoogleVertexClient(GoogleAIClient):
|
||||
|
||||
def request(self, request_data: dict) -> dict:
|
||||
"""
|
||||
Performs underlying request to llm and returns raw response.
|
||||
"""
|
||||
client = genai.Client(
|
||||
vertexai=True,
|
||||
project=model_settings.google_cloud_project,
|
||||
location=model_settings.google_cloud_location,
|
||||
http_options={"api_version": "v1"},
|
||||
)
|
||||
response = client.models.generate_content(
|
||||
model=self.llm_config.model,
|
||||
contents=request_data["contents"],
|
||||
config=request_data["config"],
|
||||
)
|
||||
return response.model_dump()
|
||||
|
||||
def build_request_data(
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
tools: List[dict],
|
||||
tool_call: Optional[str],
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for this client.
|
||||
"""
|
||||
request_data = super().build_request_data(messages, tools, tool_call)
|
||||
request_data["config"] = request_data.pop("generation_config")
|
||||
request_data["config"]["tools"] = request_data.pop("tools")
|
||||
|
||||
tool_config = ToolConfig(
|
||||
function_calling_config=FunctionCallingConfig(
|
||||
# ANY mode forces the model to predict only function calls
|
||||
mode=FunctionCallingConfigMode.ANY,
|
||||
)
|
||||
)
|
||||
request_data["config"]["tool_config"] = tool_config.model_dump()
|
||||
|
||||
return request_data
|
||||
|
||||
    def convert_response_to_chat_completion(
        self,
        response_data: dict,
        input_messages: List[PydanticMessage],
    ) -> ChatCompletionResponse:
        """
        Converts custom response format from llm client into an OpenAI
        ChatCompletionsResponse object.

        Example:
        {
            "candidates": [
                {
                    "content": {
                        "parts": [
                            {
                                "text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14."
                            }
                        ]
                    }
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 9,
                "candidatesTokenCount": 27,
                "totalTokenCount": 36
            }
        }
        """
        # Parse/validate the raw provider payload before any field access.
        response = GenerateContentResponse(**response_data)
        try:
            choices = []
            index = 0
            # One OpenAI Choice is produced per Google candidate.
            for candidate in response.candidates:
                content = candidate.content

                role = content.role
                assert role == "model", f"Unknown role in response: {role}"

                parts = content.parts
                # TODO support parts / multimodal
                # TODO support parallel tool calling natively
                # TODO Alternative here is to throw away everything else except for the first part
                # NOTE(review): with multiple parts, each iteration overwrites
                # openai_response_message — only the last part survives into the Choice.
                for response_message in parts:
                    # Convert the actual message style to OpenAI style
                    if response_message.function_call:
                        function_call = response_message.function_call
                        function_name = function_call.name
                        function_args = function_call.args
                        assert isinstance(function_args, dict), function_args

                        # NOTE: this also involves stripping the inner monologue out of the function
                        if self.llm_config.put_inner_thoughts_in_kwargs:
                            from letta.local_llm.constants import INNER_THOUGHTS_KWARG

                            assert INNER_THOUGHTS_KWARG in function_args, f"Couldn't find inner thoughts in function args:\n{function_call}"
                            inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG)
                            assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
                        else:
                            inner_thoughts = None

                        # Google AI API doesn't generate tool call IDs
                        openai_response_message = Message(
                            role="assistant",  # NOTE: "model" -> "assistant"
                            content=inner_thoughts,
                            tool_calls=[
                                ToolCall(
                                    id=get_tool_call_id(),
                                    type="function",
                                    function=FunctionCall(
                                        name=function_name,
                                        arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
                                    ),
                                )
                            ],
                        )

                    else:

                        # Inner thoughts are the content by default
                        inner_thoughts = response_message.text

                        # Google AI API doesn't generate tool call IDs
                        openai_response_message = Message(
                            role="assistant",  # NOTE: "model" -> "assistant"
                            content=inner_thoughts,
                        )

                # Google AI API uses different finish reason strings than OpenAI
                # OpenAI: 'stop', 'length', 'function_call', 'content_filter', null
                # see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api
                # Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
                # see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
                finish_reason = candidate.finish_reason.value
                if finish_reason == "STOP":
                    # A STOP that produced tool calls maps to OpenAI's 'function_call'.
                    openai_finish_reason = (
                        "function_call"
                        if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0
                        else "stop"
                    )
                elif finish_reason == "MAX_TOKENS":
                    openai_finish_reason = "length"
                elif finish_reason == "SAFETY":
                    openai_finish_reason = "content_filter"
                elif finish_reason == "RECITATION":
                    openai_finish_reason = "content_filter"
                else:
                    raise ValueError(f"Unrecognized finish reason in Google AI response: {finish_reason}")

                choices.append(
                    Choice(
                        finish_reason=openai_finish_reason,
                        index=index,
                        message=openai_response_message,
                    )
                )
                index += 1

            # if len(choices) > 1:
            #     raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})")

            # NOTE: some of the Google AI APIs show UsageMetadata in the response, but it seems to not exist?
            #  "usageMetadata": {
            #    "promptTokenCount": 9,
            #    "candidatesTokenCount": 27,
            #    "totalTokenCount": 36
            #  }
            if response.usage_metadata:
                usage = UsageStatistics(
                    prompt_tokens=response.usage_metadata.prompt_token_count,
                    completion_tokens=response.usage_metadata.candidates_token_count,
                    total_tokens=response.usage_metadata.total_token_count,
                )
            else:
                # Count it ourselves
                # NOTE(review): openai_response_message here is whatever was built in the
                # final loop iteration, so only the last candidate's message is counted;
                # raises NameError if response.candidates was empty — TODO confirm intended.
                assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required"
                prompt_tokens = count_tokens(json_dumps(input_messages))  # NOTE: this is a very rough approximation
                completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump()))  # NOTE: this is also approximate
                total_tokens = prompt_tokens + completion_tokens
                usage = UsageStatistics(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=total_tokens,
                )

            # Google doesn't return a response id, so synthesize one locally.
            response_id = str(uuid.uuid4())
            return ChatCompletionResponse(
                id=response_id,
                choices=choices,
                model=self.llm_config.model,  # NOTE: Google API doesn't pass back model in the response
                created=get_utc_time(),
                usage=usage,
            )
        except KeyError as e:
            # NOTE(review): this re-raise is a no-op passthrough (no logging or wrapping);
            # consider wrapping with context or removing the handler.
            raise e
|
||||
48
letta/llm_api/llm_client.py
Normal file
48
letta/llm_api/llm_client.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import Optional
|
||||
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""Factory class for creating LLM clients based on the model endpoint type."""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
agent_id: str,
|
||||
llm_config: LLMConfig,
|
||||
put_inner_thoughts_first: bool = True,
|
||||
actor_id: Optional[str] = None,
|
||||
) -> Optional[LLMClientBase]:
|
||||
"""
|
||||
Create an LLM client based on the model endpoint type.
|
||||
|
||||
Args:
|
||||
agent_id: Unique identifier for the agent
|
||||
llm_config: Configuration for the LLM model
|
||||
put_inner_thoughts_first: Whether to put inner thoughts first in the response
|
||||
use_structured_output: Whether to use structured output
|
||||
use_tool_naming: Whether to use tool naming
|
||||
actor_id: Optional actor identifier
|
||||
|
||||
Returns:
|
||||
An instance of LLMClientBase subclass
|
||||
|
||||
Raises:
|
||||
ValueError: If the model endpoint type is not supported
|
||||
"""
|
||||
match llm_config.model_endpoint_type:
|
||||
case "google_ai":
|
||||
from letta.llm_api.google_ai_client import GoogleAIClient
|
||||
|
||||
return GoogleAIClient(
|
||||
agent_id=agent_id, llm_config=llm_config, put_inner_thoughts_first=put_inner_thoughts_first, actor_id=actor_id
|
||||
)
|
||||
case "google_vertex":
|
||||
from letta.llm_api.google_vertex_client import GoogleVertexClient
|
||||
|
||||
return GoogleVertexClient(
|
||||
agent_id=agent_id, llm_config=llm_config, put_inner_thoughts_first=put_inner_thoughts_first, actor_id=actor_id
|
||||
)
|
||||
case _:
|
||||
return None
|
||||
129
letta/llm_api/llm_client_base.py
Normal file
129
letta/llm_api/llm_client_base.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from openai import AsyncStream, Stream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.tracing import log_event
|
||||
|
||||
|
||||
class LLMClientBase:
    """
    Abstract base class for LLM clients.

    Concrete subclasses supply the provider-specific pieces: building the
    request payload, issuing the sync/async/streaming call, and converting the
    raw response into an OpenAI-style ChatCompletionResponse.
    """

    def __init__(
        self,
        agent_id: str,
        llm_config: LLMConfig,
        put_inner_thoughts_first: Optional[bool] = True,
        use_structured_output: Optional[bool] = True,
        use_tool_naming: bool = True,
        actor_id: Optional[str] = None,
    ):
        # NOTE: use_structured_output / use_tool_naming are accepted for
        # interface parity but are not stored on the instance here.
        self.agent_id = agent_id
        self.llm_config = llm_config
        self.put_inner_thoughts_first = put_inner_thoughts_first
        self.actor_id = actor_id

    def send_llm_request(
        self,
        messages: List[Message],
        tools: Optional[List[dict]] = None,  # TODO: change to Tool object
        tool_call: Optional[str] = None,
        stream: bool = False,
        first_message: bool = False,
        force_tool_call: Optional[str] = None,
    ) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]:
        """
        Issue a synchronous request to the downstream model endpoint and parse
        the response.

        Returns a Stream[ChatCompletionChunk] when stream=True, otherwise a
        ChatCompletionResponse.
        """
        # first_message / force_tool_call are part of the interface but unused here.
        payload = self.build_request_data(messages, tools, tool_call)
        log_event(name="llm_request_sent", attributes=payload)
        if stream:
            return self.stream(payload)
        raw_response = self.request(payload)
        log_event(name="llm_response_received", attributes=raw_response)
        return self.convert_response_to_chat_completion(raw_response, messages)

    async def send_llm_request_async(
        self,
        messages: List[Message],
        tools: Optional[List[dict]] = None,  # TODO: change to Tool object
        tool_call: Optional[str] = None,
        stream: bool = False,
        first_message: bool = False,
        force_tool_call: Optional[str] = None,
    ) -> Union[ChatCompletionResponse, AsyncStream[ChatCompletionChunk]]:
        """
        Issue an asynchronous request to the downstream model endpoint and
        parse the response.

        Returns an AsyncStream[ChatCompletionChunk] when stream=True (can be
        async-iterated), otherwise a ChatCompletionResponse.
        """
        # first_message / force_tool_call are part of the interface but unused here.
        payload = self.build_request_data(messages, tools, tool_call)
        log_event(name="llm_request_sent", attributes=payload)
        if stream:
            return await self.stream_async(payload)
        raw_response = await self.request_async(payload)
        log_event(name="llm_response_received", attributes=raw_response)
        return self.convert_response_to_chat_completion(raw_response, messages)

    @abstractmethod
    def build_request_data(
        self,
        messages: List[Message],
        tools: List[dict],
        tool_call: Optional[str],
    ) -> dict:
        """Construct the provider-specific request payload for this client."""
        raise NotImplementedError

    @abstractmethod
    def request(self, request_data: dict) -> dict:
        """Perform the underlying blocking request and return the raw response."""
        raise NotImplementedError

    @abstractmethod
    async def request_async(self, request_data: dict) -> dict:
        """Perform the underlying async request and return the raw response."""
        raise NotImplementedError

    @abstractmethod
    def convert_response_to_chat_completion(
        self,
        response_data: dict,
        input_messages: List[Message],
    ) -> ChatCompletionResponse:
        """Convert the provider's raw response dict into a ChatCompletionResponse."""
        raise NotImplementedError

    @abstractmethod
    def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
        """Perform the underlying streaming request and return the raw stream."""
        raise NotImplementedError(f"Streaming is not supported for {self.llm_config.model_endpoint_type}")

    @abstractmethod
    async def stream_async(self, request_data: dict) -> AsyncStream[ChatCompletionChunk]:
        """Perform the underlying async streaming request and return the raw stream."""
        raise NotImplementedError(f"Streaming is not supported for {self.llm_config.model_endpoint_type}")
|
||||
@@ -33,6 +33,7 @@ class Step(SqlalchemyBase):
|
||||
job_id: Mapped[Optional[str]] = mapped_column(
|
||||
ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True, doc="The unique identified of the job run that triggered this step"
|
||||
)
|
||||
agent_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.")
|
||||
provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.")
|
||||
model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.")
|
||||
model_endpoint: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The model endpoint url used for this step.")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
|
||||
@@ -37,7 +37,8 @@ class BaseBlock(LettaBase, validate_assignment=True):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def verify_char_limit(self) -> Self:
|
||||
if self.value and len(self.value) > self.limit:
|
||||
# self.limit can be None from
|
||||
if self.limit is not None and self.value and len(self.value) > self.limit:
|
||||
error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self.value)}) - {str(self)}."
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -89,61 +90,16 @@ class Persona(Block):
|
||||
label: str = "persona"
|
||||
|
||||
|
||||
# class CreateBlock(BaseBlock):
|
||||
# """Create a block"""
|
||||
#
|
||||
# is_template: bool = True
|
||||
# label: str = Field(..., description="Label of the block.")
|
||||
|
||||
|
||||
class BlockLabelUpdate(BaseModel):
    """Update the label of a block"""

    # Both fields are required (`...` default): the block is located by its
    # current label and renamed to the new one.
    current_label: str = Field(..., description="Current label of the block.")
    new_label: str = Field(..., description="New label of the block.")
|
||||
|
||||
|
||||
# class CreatePersona(CreateBlock):
|
||||
# """Create a persona block"""
|
||||
#
|
||||
# label: str = "persona"
|
||||
#
|
||||
#
|
||||
# class CreateHuman(CreateBlock):
|
||||
# """Create a human block"""
|
||||
#
|
||||
# label: str = "human"
|
||||
|
||||
|
||||
class BlockUpdate(BaseBlock):
    """Update a block"""

    # Removed a duplicated `limit` field definition: an earlier line defaulted to
    # CORE_MEMORY_BLOCK_CHAR_LIMIT but was dead code, since this later class-body
    # assignment wins. A None default presumably means "leave the existing limit
    # unchanged" on partial updates — confirm against the block manager's update logic.
    limit: Optional[int] = Field(None, description="Character limit of the block.")
    value: Optional[str] = Field(None, description="Value of the block.")

    class Config:
        extra = "ignore"  # Ignores extra fields
|
||||
|
||||
|
||||
class BlockLimitUpdate(BaseModel):
    """Update the limit of a block"""

    # Both fields are required (`...` default): the block is located by label
    # and its character limit replaced.
    label: str = Field(..., description="Label of the block.")
    limit: int = Field(..., description="New limit of the block.")
|
||||
|
||||
|
||||
# class UpdatePersona(BlockUpdate):
|
||||
# """Update a persona block"""
|
||||
#
|
||||
# label: str = "persona"
|
||||
#
|
||||
#
|
||||
# class UpdateHuman(BlockUpdate):
|
||||
# """Update a human block"""
|
||||
#
|
||||
# label: str = "human"
|
||||
|
||||
|
||||
class CreateBlock(BaseBlock):
|
||||
"""Create a block"""
|
||||
|
||||
|
||||
@@ -236,6 +236,32 @@ LettaMessageUnion = Annotated[
|
||||
]
|
||||
|
||||
|
||||
class UpdateSystemMessage(BaseModel):
    """Request body for editing the content of an existing system message."""

    # New content for the message; accepts a plain string or structured content parts.
    content: Union[str, List[MessageContentUnion]]
    # Fixed literal acts as the discriminator identifying this update variant.
    message_type: Literal["system_message"] = "system_message"
|
||||
|
||||
|
||||
class UpdateUserMessage(BaseModel):
    """Request body for editing the content of an existing user message."""

    # New content for the message; accepts a plain string or structured content parts.
    content: Union[str, List[MessageContentUnion]]
    # Fixed literal acts as the discriminator identifying this update variant.
    message_type: Literal["user_message"] = "user_message"
|
||||
|
||||
|
||||
class UpdateReasoningMessage(BaseModel):
    """Request body for editing the reasoning text of an existing reasoning message."""

    # NOTE: field is named `reasoning` (not `content`), unlike the other update variants.
    reasoning: Union[str, List[MessageContentUnion]]
    # Fixed literal acts as the discriminator identifying this update variant.
    message_type: Literal["reasoning_message"] = "reasoning_message"
|
||||
|
||||
|
||||
class UpdateAssistantMessage(BaseModel):
    """Request body for editing the content of an existing assistant message."""

    # New content for the message; accepts a plain string or structured content parts.
    content: Union[str, List[MessageContentUnion]]
    # Fixed literal acts as the discriminator identifying this update variant.
    message_type: Literal["assistant_message"] = "assistant_message"
|
||||
|
||||
|
||||
# Discriminated union over the editable message variants; pydantic selects the
# concrete model by each member's fixed `message_type` literal.
LettaMessageUpdateUnion = Annotated[
    Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage],
    Field(discriminator="message_type"),
]
|
||||
|
||||
|
||||
def create_letta_message_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
|
||||
@@ -74,7 +74,7 @@ class MessageUpdate(BaseModel):
|
||||
"""Request to update a message"""
|
||||
|
||||
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
|
||||
content: Optional[Union[str, List[MessageContentUnion]]] = Field(..., description="The content of the message.")
|
||||
content: Optional[Union[str, List[MessageContentUnion]]] = Field(None, description="The content of the message.")
|
||||
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
|
||||
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
|
||||
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
|
||||
@@ -18,6 +18,7 @@ class Step(StepBase):
|
||||
job_id: Optional[str] = Field(
|
||||
None, description="The unique identifier of the job that this step belongs to. Only included for async calls."
|
||||
)
|
||||
agent_id: Optional[str] = Field(None, description="The ID of the agent that performed the step.")
|
||||
provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.")
|
||||
model: Optional[str] = Field(None, description="The name of the model used for this step.")
|
||||
model_endpoint: Optional[str] = Field(None, description="The model endpoint url used for this step.")
|
||||
|
||||
@@ -70,4 +70,11 @@ class SerializedAgentSchema(BaseSchema):
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Agent
|
||||
# TODO: Serialize these as well...
|
||||
exclude = BaseSchema.Meta.exclude + ("sources", "source_passages", "agent_passages")
|
||||
exclude = BaseSchema.Meta.exclude + (
|
||||
"project_id",
|
||||
"template_id",
|
||||
"base_template_id",
|
||||
"sources",
|
||||
"source_passages",
|
||||
"agent_passages",
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/completions",
|
||||
"/{agent_id}/chat/completions",
|
||||
response_model=None,
|
||||
operation_id="create_chat_completions",
|
||||
responses={
|
||||
@@ -37,6 +37,7 @@ logger = get_logger(__name__)
|
||||
},
|
||||
)
|
||||
async def create_chat_completions(
|
||||
agent_id: str,
|
||||
completion_request: CompletionCreateParams = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@@ -51,12 +52,6 @@ async def create_chat_completions(
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
agent_id = str(completion_request.get("user", None))
|
||||
if agent_id is None:
|
||||
error_msg = "Must pass agent_id in the 'user' field"
|
||||
logger.error(error_msg)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
|
||||
llm_config = letta_agent.agent_state.llm_config
|
||||
if llm_config.model_endpoint_type != "openai" or "inference.memgpt.ai" in llm_config.model_endpoint:
|
||||
|
||||
@@ -13,13 +13,12 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block, BlockUpdate, CreateBlock # , BlockLabelUpdate, BlockLimitUpdate
|
||||
from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
|
||||
from letta.schemas.message import Message, MessageUpdate
|
||||
from letta.schemas.passage import Passage, PassageUpdate
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.source import Source
|
||||
@@ -119,6 +118,7 @@ async def upload_agent_serialized(
|
||||
True,
|
||||
description="If set to True, existing tools can get their source code overwritten by the uploaded tool definitions. Note that Letta core tools can never be updated externally.",
|
||||
),
|
||||
project_id: Optional[str] = Query(None, description="The project ID to associate the uploaded agent with."),
|
||||
):
|
||||
"""
|
||||
Upload a serialized agent JSON file and recreate the agent in the system.
|
||||
@@ -129,7 +129,11 @@ async def upload_agent_serialized(
|
||||
serialized_data = await file.read()
|
||||
agent_json = json.loads(serialized_data)
|
||||
new_agent = server.agent_manager.deserialize(
|
||||
serialized_agent=agent_json, actor=actor, append_copy_suffix=append_copy_suffix, override_existing_tools=override_existing_tools
|
||||
serialized_agent=agent_json,
|
||||
actor=actor,
|
||||
append_copy_suffix=append_copy_suffix,
|
||||
override_existing_tools=override_existing_tools,
|
||||
project_id=project_id,
|
||||
)
|
||||
return new_agent
|
||||
|
||||
@@ -526,20 +530,20 @@ def list_messages(
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/messages/{message_id}", response_model=Message, operation_id="modify_message")
|
||||
@router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUpdateUnion, operation_id="modify_message")
|
||||
def modify_message(
|
||||
agent_id: str,
|
||||
message_id: str,
|
||||
request: MessageUpdate = Body(...),
|
||||
request: LettaMessageUpdateUnion = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the details of a message associated with an agent.
|
||||
"""
|
||||
# TODO: Get rid of agent_id here, it's not really relevant
|
||||
# TODO: support modifying tool calls/returns
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
|
||||
return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
||||
@@ -20,6 +20,7 @@ def list_steps(
|
||||
start_date: Optional[str] = Query(None, description='Return steps after this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
|
||||
end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
|
||||
model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"),
|
||||
agent_id: Optional[str] = Query(None, description="Filter by the ID of the agent that performed the step"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
@@ -42,6 +43,7 @@ def list_steps(
|
||||
limit=limit,
|
||||
order=order,
|
||||
model=model,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
from fastapi import APIRouter, Body, Depends, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||
|
||||
@@ -22,7 +22,7 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/completions",
|
||||
"/{agent_id}/chat/completions",
|
||||
response_model=None,
|
||||
operation_id="create_voice_chat_completions",
|
||||
responses={
|
||||
@@ -35,16 +35,13 @@ logger = get_logger(__name__)
|
||||
},
|
||||
)
|
||||
async def create_voice_chat_completions(
|
||||
agent_id: str,
|
||||
completion_request: CompletionCreateParams = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
agent_id = str(completion_request.get("user", None))
|
||||
if agent_id is None:
|
||||
raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
|
||||
|
||||
# Also parse the user's new input
|
||||
input_message = UserMessage(**get_messages_from_completion_request(completion_request)[-1])
|
||||
|
||||
|
||||
@@ -358,6 +358,49 @@ class AgentManager:
|
||||
|
||||
return [agent.to_pydantic() for agent in agents]
|
||||
|
||||
@enforce_types
|
||||
def list_agents_matching_tags(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
match_all: List[str],
|
||||
match_some: List[str],
|
||||
limit: Optional[int] = 50,
|
||||
) -> List[PydanticAgentState]:
|
||||
"""
|
||||
Retrieves agents in the same organization that match all specified `match_all` tags
|
||||
and at least one tag from `match_some`. The query is optimized for efficiency by
|
||||
leveraging indexed filtering and aggregation.
|
||||
|
||||
Args:
|
||||
actor (PydanticUser): The user requesting the agent list.
|
||||
match_all (List[str]): Agents must have all these tags.
|
||||
match_some (List[str]): Agents must have at least one of these tags.
|
||||
limit (Optional[int]): Maximum number of agents to return.
|
||||
|
||||
Returns:
|
||||
List[PydanticAgentState: The filtered list of matching agents.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id)
|
||||
|
||||
if match_all:
|
||||
# Subquery to find agent IDs that contain all match_all tags
|
||||
subquery = (
|
||||
select(AgentsTags.agent_id)
|
||||
.where(AgentsTags.tag.in_(match_all))
|
||||
.group_by(AgentsTags.agent_id)
|
||||
.having(func.count(AgentsTags.tag) == literal(len(match_all)))
|
||||
)
|
||||
query = query.where(AgentModel.id.in_(subquery))
|
||||
|
||||
if match_some:
|
||||
# Ensures agents match at least one tag in match_some
|
||||
query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some))
|
||||
|
||||
query = query.group_by(AgentModel.id).limit(limit)
|
||||
|
||||
return list(session.execute(query).scalars())
|
||||
|
||||
@enforce_types
|
||||
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
@@ -401,7 +444,12 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
def deserialize(
|
||||
self, serialized_agent: dict, actor: PydanticUser, append_copy_suffix: bool = True, override_existing_tools: bool = True
|
||||
self,
|
||||
serialized_agent: dict,
|
||||
actor: PydanticUser,
|
||||
append_copy_suffix: bool = True,
|
||||
override_existing_tools: bool = True,
|
||||
project_id: Optional[str] = None,
|
||||
) -> PydanticAgentState:
|
||||
tool_data_list = serialized_agent.pop("tools", [])
|
||||
|
||||
@@ -410,7 +458,9 @@ class AgentManager:
|
||||
agent = schema.load(serialized_agent, session=session)
|
||||
if append_copy_suffix:
|
||||
agent.name += "_copy"
|
||||
agent.create(session, actor=actor)
|
||||
if project_id:
|
||||
agent.project_id = project_id
|
||||
agent = agent.create(session, actor=actor)
|
||||
pydantic_agent = agent.to_pydantic()
|
||||
|
||||
# Need to do this separately as there's some fancy upsert logic that SqlAlchemy cannot handle
|
||||
@@ -548,6 +598,7 @@ class AgentManager:
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
recent_passages=self.list_passages(actor=actor, agent_id=agent_id, ascending=False, limit=10),
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
|
||||
@@ -718,7 +769,9 @@ class AgentManager:
|
||||
# Commit the changes
|
||||
agent.update(session, actor=actor)
|
||||
|
||||
# Add system messsage alert to agent
|
||||
# Force rebuild of system prompt so that the agent is updated with passage count
|
||||
# and recent passages and add system message alert to agent
|
||||
self.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True)
|
||||
self.append_system_message(
|
||||
agent_id=agent_id,
|
||||
content=DATA_SOURCE_ATTACH_ALERT,
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate, TextContent
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.schemas.user import User
|
||||
from letta.system import get_initial_boot_messages, get_login_event
|
||||
@@ -99,7 +100,10 @@ def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
|
||||
|
||||
# TODO: This code is kind of wonky and deserves a rewrite
|
||||
def compile_memory_metadata_block(
|
||||
memory_edit_timestamp: datetime.datetime, previous_message_count: int = 0, archival_memory_size: int = 0
|
||||
memory_edit_timestamp: datetime.datetime,
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: int = 0,
|
||||
recent_passages: List[PydanticPassage] = None,
|
||||
) -> str:
|
||||
# Put the timestamp in the local timezone (mimicking get_local_time())
|
||||
timestamp_str = memory_edit_timestamp.astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip()
|
||||
@@ -110,6 +114,11 @@ def compile_memory_metadata_block(
|
||||
f"### Memory [last modified: {timestamp_str}]",
|
||||
f"{previous_message_count} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||
f"{archival_memory_size} total memories you created are stored in archival memory (use functions to access them)",
|
||||
(
|
||||
f"Most recent archival passages {len(recent_passages)} recent passages: {[passage.text for passage in recent_passages]}"
|
||||
if recent_passages is not None
|
||||
else ""
|
||||
),
|
||||
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
||||
]
|
||||
)
|
||||
@@ -146,6 +155,7 @@ def compile_system_message(
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: int = 0,
|
||||
recent_passages: Optional[List[PydanticPassage]] = None,
|
||||
) -> str:
|
||||
"""Prepare the final/full system message that will be fed into the LLM API
|
||||
|
||||
@@ -170,6 +180,7 @@ def compile_system_message(
|
||||
memory_edit_timestamp=in_context_memory_last_edit,
|
||||
previous_message_count=previous_message_count,
|
||||
archival_memory_size=archival_memory_size,
|
||||
recent_passages=recent_passages,
|
||||
)
|
||||
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
|
||||
|
||||
|
||||
@@ -78,7 +78,13 @@ class IdentityManager:
|
||||
if existing_identity is None:
|
||||
return self.create_identity(identity=identity, actor=actor)
|
||||
else:
|
||||
identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids)
|
||||
identity_update = IdentityUpdate(
|
||||
name=identity.name,
|
||||
identifier_key=identity.identifier_key,
|
||||
identity_type=identity.identity_type,
|
||||
agent_ids=identity.agent_ids,
|
||||
properties=identity.properties,
|
||||
)
|
||||
return self._update_identity(
|
||||
session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import and_, or_
|
||||
@@ -7,6 +8,7 @@ from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import LettaMessageUpdateUnion
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@@ -64,6 +66,44 @@ class MessageManager:
|
||||
"""Create multiple messages."""
|
||||
return [self.create_message(m, actor=actor) for m in pydantic_msgs]
|
||||
|
||||
@enforce_types
|
||||
def update_message_by_letta_message(
|
||||
self, message_id: str, letta_message_update: LettaMessageUpdateUnion, actor: PydanticUser
|
||||
) -> PydanticMessage:
|
||||
"""
|
||||
Updated the underlying messages table giving an update specified to the user-facing LettaMessage
|
||||
"""
|
||||
message = self.get_message_by_id(message_id=message_id, actor=actor)
|
||||
if letta_message_update.message_type == "assistant_message":
|
||||
# modify the tool call for send_message
|
||||
# TODO: fix this if we add parallel tool calls
|
||||
# TODO: note this only works if the AssistantMessage is generated by the standard send_message
|
||||
assert (
|
||||
message.tool_calls[0].function.name == "send_message"
|
||||
), f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
|
||||
original_args = json.loads(message.tool_calls[0].function.arguments)
|
||||
original_args["message"] = letta_message_update.content # override the assistant message
|
||||
update_tool_call = message.tool_calls[0].__deepcopy__()
|
||||
update_tool_call.function.arguments = json.dumps(original_args)
|
||||
|
||||
update_message = MessageUpdate(tool_calls=[update_tool_call])
|
||||
elif letta_message_update.message_type == "reasoning_message":
|
||||
update_message = MessageUpdate(content=letta_message_update.reasoning)
|
||||
elif letta_message_update.message_type == "user_message" or letta_message_update.message_type == "system_message":
|
||||
update_message = MessageUpdate(content=letta_message_update.content)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type for modification: {letta_message_update.message_type}")
|
||||
|
||||
message = self.update_message_by_id(message_id=message_id, message_update=update_message, actor=actor)
|
||||
|
||||
# convert back to LettaMessage
|
||||
for letta_msg in message.to_letta_message(use_assistant_message=True):
|
||||
if letta_msg.message_type == letta_message_update.message_type:
|
||||
return letta_msg
|
||||
|
||||
# raise error if message type got modified
|
||||
raise ValueError(f"Message type got modified: {letta_message_update.message_type}")
|
||||
|
||||
@enforce_types
|
||||
def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
|
||||
"""
|
||||
|
||||
@@ -33,10 +33,15 @@ class StepManager:
|
||||
limit: Optional[int] = 50,
|
||||
order: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> List[PydanticStep]:
|
||||
"""List all jobs with optional pagination and status filter."""
|
||||
with self.session_maker() as session:
|
||||
filter_kwargs = {"organization_id": actor.organization_id, "model": model}
|
||||
filter_kwargs = {"organization_id": actor.organization_id}
|
||||
if model:
|
||||
filter_kwargs["model"] = model
|
||||
if agent_id:
|
||||
filter_kwargs["agent_id"] = agent_id
|
||||
|
||||
steps = StepModel.list(
|
||||
db_session=session,
|
||||
@@ -54,6 +59,7 @@ class StepManager:
|
||||
def log_step(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
provider_name: str,
|
||||
model: str,
|
||||
model_endpoint: Optional[str],
|
||||
@@ -65,6 +71,7 @@ class StepManager:
|
||||
step_data = {
|
||||
"origin": None,
|
||||
"organization_id": actor.organization_id,
|
||||
"agent_id": agent_id,
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider_name,
|
||||
"model": model,
|
||||
|
||||
704
poetry.lock
generated
704
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "letta"
|
||||
version = "0.6.37"
|
||||
version = "0.6.38"
|
||||
packages = [
|
||||
{include = "letta"},
|
||||
]
|
||||
@@ -58,8 +58,8 @@ nltk = "^3.8.1"
|
||||
jinja2 = "^3.1.5"
|
||||
locust = {version = "^2.31.5", optional = true}
|
||||
wikipedia = {version = "^1.4.0", optional = true}
|
||||
composio-langchain = "^0.7.2"
|
||||
composio-core = "^0.7.2"
|
||||
composio-langchain = "^0.7.7"
|
||||
composio-core = "^0.7.7"
|
||||
alembic = "^1.13.3"
|
||||
pyhumps = "^3.8.0"
|
||||
psycopg2 = {version = "^2.9.10", optional = true}
|
||||
@@ -98,9 +98,10 @@ qdrant = ["qdrant-client"]
|
||||
cloud-tool-sandbox = ["e2b-code-interpreter"]
|
||||
external-tools = ["docker", "langchain", "wikipedia", "langchain-community"]
|
||||
tests = ["wikipedia"]
|
||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
|
||||
bedrock = ["boto3"]
|
||||
google = ["google-genai"]
|
||||
desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "datasets", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "datamodel-code-generator"]
|
||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "datamodel-code-generator"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^24.4.2"
|
||||
|
||||
@@ -17,6 +17,7 @@ from letta.embeddings import embedding_model
|
||||
from letta.errors import InvalidInnerMonologueError, InvalidToolCallError, MissingInnerMonologueError, MissingToolCallError
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -103,12 +104,23 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
|
||||
messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state.id, actor=client.user)
|
||||
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
|
||||
|
||||
response = create(
|
||||
llm_client = LLMClient.create(
|
||||
agent_id=agent_state.id,
|
||||
llm_config=agent_state.llm_config,
|
||||
user_id=str(uuid.UUID(int=1)), # dummy user_id
|
||||
messages=messages,
|
||||
functions=[t.json_schema for t in agent.agent_state.tools],
|
||||
actor_id=str(uuid.UUID(int=1)),
|
||||
)
|
||||
if llm_client:
|
||||
response = llm_client.send_llm_request(
|
||||
messages=messages,
|
||||
tools=[t.json_schema for t in agent.agent_state.tools],
|
||||
)
|
||||
else:
|
||||
response = create(
|
||||
llm_config=agent_state.llm_config,
|
||||
user_id=str(uuid.UUID(int=1)), # dummy user_id
|
||||
messages=messages,
|
||||
functions=[t.json_schema for t in agent.agent_state.tools],
|
||||
)
|
||||
|
||||
# Basic check
|
||||
assert response is not None, response
|
||||
|
||||
@@ -120,12 +120,11 @@ def agent(client, roll_dice_tool, weather_tool, composio_gmail_get_profile_tool)
|
||||
# --- Helper Functions --- #
|
||||
|
||||
|
||||
def _get_chat_request(agent_id, message, stream=True):
|
||||
def _get_chat_request(message, stream=True):
|
||||
"""Returns a chat completion request with streaming enabled."""
|
||||
return ChatCompletionRequest(
|
||||
model="gpt-4o-mini",
|
||||
messages=[UserMessage(content=message)],
|
||||
user=agent_id,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@@ -157,9 +156,9 @@ def _assert_valid_chunk(chunk, idx, chunks):
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice"])
|
||||
async def test_latency(mock_e2b_api_key_none, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(agent.id, message)
|
||||
request = _get_chat_request(message)
|
||||
|
||||
async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}", max_retries=0)
|
||||
async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}/{agent.id}", max_retries=0)
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
@@ -171,9 +170,9 @@ async def test_latency(mock_e2b_api_key_none, client, agent, message, endpoint):
|
||||
@pytest.mark.parametrize("endpoint", ["openai/v1", "v1/voice"])
|
||||
async def test_chat_completions_streaming_openai_client(mock_e2b_api_key_none, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(agent.id, message)
|
||||
request = _get_chat_request(message)
|
||||
|
||||
async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}", max_retries=0)
|
||||
async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}/{agent.id}", max_retries=0)
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
|
||||
received_chunks = 0
|
||||
|
||||
@@ -127,54 +127,55 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
|
||||
|
||||
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
|
||||
def test_send_message_to_agents_with_tags_simple(client):
|
||||
worker_tags = ["worker", "user-456"]
|
||||
worker_tags_123 = ["worker", "user-123"]
|
||||
worker_tags_456 = ["worker", "user-456"]
|
||||
|
||||
# Clean up first from possibly failed tests
|
||||
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True)
|
||||
prev_worker_agents = client.server.agent_manager.list_agents(
|
||||
client.user, tags=list(set(worker_tags_123 + worker_tags_456)), match_all_tags=True
|
||||
)
|
||||
for agent in prev_worker_agents:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
secret_word = "banana"
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 non-matching worker agents (These should NOT get the message)
|
||||
worker_agents = []
|
||||
worker_tags = ["worker", "user-123"]
|
||||
worker_agents_123 = []
|
||||
for _ in range(3):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_123)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
worker_agents_123.append(worker_agent)
|
||||
|
||||
# Create 3 worker agents that should get the message
|
||||
worker_agents = []
|
||||
worker_tags = ["worker", "user-456"]
|
||||
worker_agents_456 = []
|
||||
for _ in range(3):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_456)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents.append(worker_agent)
|
||||
worker_agents_456.append(worker_agent)
|
||||
|
||||
# Encourage the manager to send a message to the other agent_obj with the secret string
|
||||
response = client.send_message(
|
||||
agent_id=manager_agent.agent_state.id,
|
||||
role="user",
|
||||
message=f"Send a message to all agents with tags {worker_tags} informing them of the secret word: {secret_word}!",
|
||||
message=f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
|
||||
)
|
||||
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolReturnMessage):
|
||||
tool_response = eval(json.loads(m.tool_return)["message"])
|
||||
print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
|
||||
assert len(tool_response) == len(worker_agents)
|
||||
assert len(tool_response) == len(worker_agents_456)
|
||||
|
||||
# We can break after this, the ToolReturnMessage after is not related
|
||||
break
|
||||
|
||||
# Conversation search the worker agents
|
||||
for agent in worker_agents:
|
||||
for agent in worker_agents_456:
|
||||
messages = client.get_messages(agent.agent_state.id)
|
||||
# Check for the presence of system message
|
||||
for m in reversed(messages):
|
||||
@@ -183,13 +184,22 @@ def test_send_message_to_agents_with_tags_simple(client):
|
||||
assert secret_word in m.content
|
||||
break
|
||||
|
||||
# Ensure it's NOT in the non matching worker agents
|
||||
for agent in worker_agents_123:
|
||||
messages = client.get_messages(agent.agent_state.id)
|
||||
# Check for the presence of system message
|
||||
for m in reversed(messages):
|
||||
print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}")
|
||||
if isinstance(m, SystemMessage):
|
||||
assert secret_word not in m.content
|
||||
|
||||
# Test that the agent can still receive messages fine
|
||||
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
|
||||
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
|
||||
|
||||
# Clean up agents
|
||||
client.delete_agent(manager_agent_state.id)
|
||||
for agent in worker_agents:
|
||||
for agent in worker_agents_456 + worker_agents_123:
|
||||
client.delete_agent(agent.agent_state.id)
|
||||
|
||||
|
||||
@@ -203,8 +213,8 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 worker agents
|
||||
@@ -245,8 +255,8 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
|
||||
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
|
||||
def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 2 worker agents
|
||||
@@ -260,7 +270,7 @@ def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
|
||||
worker_agents.append(worker_agent)
|
||||
|
||||
# Encourage the manager to send a message to the other agent_obj with the secret string
|
||||
broadcast_message = f"Using your tool named `send_message_to_agents_matching_all_tags`, instruct all agents with tags {worker_tags} to `core_memory_append` the topic of the day: bananas!"
|
||||
broadcast_message = f"Using your tool named `send_message_to_agents_matching_tags`, instruct all agents with tags {worker_tags} to `core_memory_append` the topic of the day: bananas!"
|
||||
client.send_message(
|
||||
agent_id=manager_agent.agent_state.id,
|
||||
role="user",
|
||||
|
||||
@@ -65,10 +65,8 @@ def test_multi_agent_large(client, roll_dice_tool, num_workers):
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
|
||||
manager_agent_state = client.create_agent(
|
||||
name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id], tags=manager_tags
|
||||
)
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id], tags=manager_tags)
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 worker agents
|
||||
|
||||
@@ -229,7 +229,7 @@ def _compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log:
|
||||
- Datetime fields are ignored.
|
||||
- Order-independent comparison for lists of dicts.
|
||||
"""
|
||||
ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id", "agent_id"}
|
||||
ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id", "agent_id", "project_id"}
|
||||
|
||||
# Remove datetime fields upfront
|
||||
d1 = strip_datetime_fields(d1)
|
||||
@@ -476,8 +476,9 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server,
|
||||
# FastAPI endpoint tests
|
||||
|
||||
|
||||
@pytest.mark.parametrize("append_copy_suffix", [True])
|
||||
def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix):
|
||||
@pytest.mark.parametrize("append_copy_suffix", [True, False])
|
||||
@pytest.mark.parametrize("project_id", ["project-12345", None])
|
||||
def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix, project_id):
|
||||
"""
|
||||
Test the full E2E serialization and deserialization flow using FastAPI endpoints.
|
||||
"""
|
||||
@@ -495,7 +496,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
upload_response = fastapi_client.post(
|
||||
"/v1/agents/upload",
|
||||
headers={"user_id": other_user.id},
|
||||
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False},
|
||||
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id},
|
||||
files=files,
|
||||
)
|
||||
assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}"
|
||||
@@ -504,7 +505,8 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
copied_agent = upload_response.json()
|
||||
copied_agent_id = copied_agent["id"]
|
||||
assert copied_agent_id != agent_id, "Copied agent should have a different ID"
|
||||
assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
|
||||
if append_copy_suffix:
|
||||
assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
|
||||
|
||||
# Step 3: Retrieve the copied agent
|
||||
serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
@@ -26,7 +26,7 @@ from letta.schemas.letta_message import (
|
||||
ToolReturnMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.schemas.letta_response import LettaStreamingResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
@@ -536,21 +536,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
client.delete_source(source.id)
|
||||
|
||||
|
||||
def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test that we can update the details of a message"""
|
||||
|
||||
# create a message
|
||||
message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
|
||||
print("Messages=", message_response)
|
||||
assert isinstance(message_response, LettaResponse)
|
||||
assert isinstance(message_response.messages[-1], AssistantMessage)
|
||||
message = message_response.messages[-1]
|
||||
|
||||
new_text = "this is a secret message"
|
||||
new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id)
|
||||
assert new_message.text == new_text
|
||||
|
||||
|
||||
def test_organization(client: RESTClient):
|
||||
if isinstance(client, LocalClient):
|
||||
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
|
||||
|
||||
@@ -21,8 +21,10 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
@@ -40,6 +42,7 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import UserUpdate
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.identity_manager import IdentityManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.settings import tool_settings
|
||||
from tests.helpers.utils import comprehensive_agent_checks
|
||||
@@ -472,6 +475,45 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
|
||||
server.source_manager.delete_source(default_source.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_with_tags(server: SyncServer, default_user):
|
||||
"""Fixture to create agents with specific tags."""
|
||||
agent1 = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="agent1",
|
||||
tags=["primary_agent", "benefit_1"],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent2 = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="agent2",
|
||||
tags=["primary_agent", "benefit_2"],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent3 = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="agent3",
|
||||
tags=["primary_agent", "benefit_1", "benefit_2"],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
return [agent1, agent2, agent3]
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Basic
|
||||
# ======================================================================================================================
|
||||
@@ -775,6 +817,45 @@ def test_list_attached_agents_nonexistent_source(server: SyncServer, default_use
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_list_agents_matching_all_tags(server: SyncServer, default_user, agent_with_tags):
|
||||
agents = server.agent_manager.list_agents_matching_tags(
|
||||
actor=default_user,
|
||||
match_all=["primary_agent", "benefit_1"],
|
||||
match_some=[],
|
||||
)
|
||||
assert len(agents) == 2 # agent1 and agent3 match
|
||||
assert {a.name for a in agents} == {"agent1", "agent3"}
|
||||
|
||||
|
||||
def test_list_agents_matching_some_tags(server: SyncServer, default_user, agent_with_tags):
|
||||
agents = server.agent_manager.list_agents_matching_tags(
|
||||
actor=default_user,
|
||||
match_all=["primary_agent"],
|
||||
match_some=["benefit_1", "benefit_2"],
|
||||
)
|
||||
assert len(agents) == 3 # All agents match
|
||||
assert {a.name for a in agents} == {"agent1", "agent2", "agent3"}
|
||||
|
||||
|
||||
def test_list_agents_matching_all_and_some_tags(server: SyncServer, default_user, agent_with_tags):
|
||||
agents = server.agent_manager.list_agents_matching_tags(
|
||||
actor=default_user,
|
||||
match_all=["primary_agent", "benefit_1"],
|
||||
match_some=["benefit_2", "nonexistent"],
|
||||
)
|
||||
assert len(agents) == 1 # Only agent3 matches
|
||||
assert agents[0].name == "agent3"
|
||||
|
||||
|
||||
def test_list_agents_matching_no_tags(server: SyncServer, default_user, agent_with_tags):
|
||||
agents = server.agent_manager.list_agents_matching_tags(
|
||||
actor=default_user,
|
||||
match_all=["primary_agent", "nonexistent_tag"],
|
||||
match_some=["benefit_1", "benefit_2"],
|
||||
)
|
||||
assert len(agents) == 0 # No agent should match
|
||||
|
||||
|
||||
def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
"""Test listing agents that have ALL specified tags."""
|
||||
# Create agents with multiple tags
|
||||
@@ -1073,6 +1154,73 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use
|
||||
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1
|
||||
|
||||
|
||||
def test_modify_letta_message(server: SyncServer, sarah_agent, default_user):
|
||||
"""
|
||||
Test updating a message.
|
||||
"""
|
||||
|
||||
messages = server.message_manager.list_messages_for_agent(agent_id=sarah_agent.id, actor=default_user)
|
||||
letta_messages = PydanticMessage.to_letta_messages_from_list(messages=messages)
|
||||
|
||||
system_message = [msg for msg in letta_messages if msg.message_type == "system_message"][0]
|
||||
assistant_message = [msg for msg in letta_messages if msg.message_type == "assistant_message"][0]
|
||||
user_message = [msg for msg in letta_messages if msg.message_type == "user_message"][0]
|
||||
reasoning_message = [msg for msg in letta_messages if msg.message_type == "reasoning_message"][0]
|
||||
|
||||
# user message
|
||||
update_user_message = UpdateUserMessage(content="Hello, Sarah!")
|
||||
original_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user)
|
||||
assert original_user_message.content[0].text != update_user_message.content
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=user_message.id, letta_message_update=update_user_message, actor=default_user
|
||||
)
|
||||
updated_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user)
|
||||
assert updated_user_message.content[0].text == update_user_message.content
|
||||
|
||||
# system message
|
||||
update_system_message = UpdateSystemMessage(content="You are a friendly assistant!")
|
||||
original_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user)
|
||||
assert original_system_message.content[0].text != update_system_message.content
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=system_message.id, letta_message_update=update_system_message, actor=default_user
|
||||
)
|
||||
updated_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user)
|
||||
assert updated_system_message.content[0].text == update_system_message.content
|
||||
|
||||
# reasoning message
|
||||
update_reasoning_message = UpdateReasoningMessage(reasoning="I am thinking")
|
||||
original_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user)
|
||||
assert original_reasoning_message.content[0].text != update_reasoning_message.reasoning
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=reasoning_message.id, letta_message_update=update_reasoning_message, actor=default_user
|
||||
)
|
||||
updated_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user)
|
||||
assert updated_reasoning_message.content[0].text == update_reasoning_message.reasoning
|
||||
|
||||
# assistant message
|
||||
def parse_send_message(tool_call):
|
||||
import json
|
||||
|
||||
function_call = tool_call.function
|
||||
arguments = json.loads(function_call.arguments)
|
||||
return arguments["message"]
|
||||
|
||||
update_assistant_message = UpdateAssistantMessage(content="I am an agent!")
|
||||
original_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user)
|
||||
print("ORIGINAL", original_assistant_message.tool_calls)
|
||||
print("MESSAGE", parse_send_message(original_assistant_message.tool_calls[0]))
|
||||
assert parse_send_message(original_assistant_message.tool_calls[0]) != update_assistant_message.content
|
||||
server.message_manager.update_message_by_letta_message(
|
||||
message_id=assistant_message.id, letta_message_update=update_assistant_message, actor=default_user
|
||||
)
|
||||
updated_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user)
|
||||
print("UPDATED", updated_assistant_message.tool_calls)
|
||||
print("MESSAGE", parse_send_message(updated_assistant_message.tool_calls[0]))
|
||||
assert parse_send_message(updated_assistant_message.tool_calls[0]) == update_assistant_message.content
|
||||
|
||||
# TODO: tool calls/responses
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Blocks Relationship
|
||||
# ======================================================================================================================
|
||||
@@ -2001,28 +2149,42 @@ def test_update_block(server: SyncServer, default_user):
|
||||
|
||||
|
||||
def test_update_block_limit(server: SyncServer, default_user):
    """Updating a block beyond its character limit fails unless the limit is raised in the same update.

    The visible span contained merged old/new diff lines (a duplicate ``update_data``
    assignment and a dead ``try/except`` + ``assert False`` next to the new
    ``pytest.raises``); this is the reconstructed final version.
    """
    block_manager = BlockManager()
    block = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content"), actor=default_user)

    limit = len("Updated Content") * 2000
    update_data = BlockUpdate(value="Updated Content" * 2000, description="Updated description")

    # Check that exceeding the block limit raises an exception
    with pytest.raises(ValueError):
        block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)

    # Ensure the update works when the limit is raised along with the value
    update_data = BlockUpdate(value="Updated Content" * 2000, description="Updated description", limit=limit)
    block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)

    # Retrieve the updated block and validate the update
    updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
    assert updated_block.value == "Updated Content" * 2000
    assert updated_block.description == "Updated description"
|
||||
|
||||
|
||||
def test_update_block_limit_does_not_reset(server: SyncServer, default_user):
    """Updating only a block's value must leave its configured limit untouched."""
    manager = BlockManager()
    revised_value = "Updated Content" * 2000
    block_limit = len(revised_value)
    created = manager.create_or_update_block(
        PydanticBlock(label="persona", value="Original Content", limit=block_limit), actor=default_user
    )

    # Update the value only; the limit set at creation should carry over
    manager.update_block(block_id=created.id, block_update=BlockUpdate(value=revised_value), actor=default_user)

    # Fetch the block back and confirm the new value was persisted
    refreshed = manager.get_blocks(actor=default_user, id=created.id)[0]
    assert refreshed.value == revised_value
|
||||
|
||||
|
||||
def test_delete_block(server: SyncServer, default_user):
|
||||
block_manager = BlockManager()
|
||||
|
||||
@@ -2075,6 +2237,154 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de
|
||||
assert charles_agent.id in agent_state_ids
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Identity Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_create_and_upsert_identity(server: SyncServer, default_user):
    """Create an identity, reject a duplicate identifier_key, then upsert to replace its properties.

    Fix: ``identity.project_id == None`` → ``is None`` (PEP 8 — singletons are
    compared with ``is``; ``== None`` can be subverted by a custom ``__eq__``).
    """
    identity_manager = IdentityManager()
    identity_create = IdentityCreate(
        identifier_key="1234",
        name="caren",
        identity_type=IdentityType.user,
        properties=[
            IdentityProperty(key="email", value="caren@letta.com", type=IdentityPropertyType.string),
            IdentityProperty(key="age", value=28, type=IdentityPropertyType.number),
        ],
    )

    identity = identity_manager.create_identity(identity_create, actor=default_user)

    # Assertions to ensure the created identity matches the expected values
    assert identity.identifier_key == identity_create.identifier_key
    assert identity.name == identity_create.name
    assert identity.identity_type == identity_create.identity_type
    assert identity.properties == identity_create.properties
    assert identity.agent_ids == []
    assert identity.project_id is None

    # A second identity with the same identifier_key must be rejected
    with pytest.raises(UniqueConstraintViolationError):
        identity_manager.create_identity(
            IdentityCreate(identifier_key="1234", name="sarah", identity_type=IdentityType.user),
            actor=default_user,
        )

    # Upserting with a new property list replaces the old properties wholesale
    identity_create.properties = [IdentityProperty(key="age", value=29, type=IdentityPropertyType.number)]

    identity = identity_manager.upsert_identity(identity_create, actor=default_user)

    identity = identity_manager.get_identity(identity_id=identity.id, actor=default_user)
    assert len(identity.properties) == 1
    assert identity.properties[0].key == "age"
    assert identity.properties[0].value == 29

    identity_manager.delete_identity(identity.id, actor=default_user)
|
||||
|
||||
|
||||
def test_get_identities(server, default_user):
    """Listing identities, unfiltered and filtered by identity_type."""
    manager = IdentityManager()

    # Seed one user identity and one org identity to retrieve later
    user_identity = manager.create_identity(
        IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user
    )
    org_identity = manager.create_identity(
        IdentityCreate(name="letta", identifier_key="0001", identity_type=IdentityType.org), actor=default_user
    )

    # No filter returns both
    assert len(manager.list_identities(actor=default_user)) == 2

    # Filter down to user identities only
    only_users = manager.list_identities(actor=default_user, identity_type=IdentityType.user)
    assert len(only_users) == 1
    assert only_users[0].name == user_identity.name

    # Filter down to org identities only
    only_orgs = manager.list_identities(actor=default_user, identity_type=IdentityType.org)
    assert len(only_orgs) == 1
    assert only_orgs[0].name == org_identity.name

    # Clean up so later tests see an empty table
    manager.delete_identity(user_identity.id, actor=default_user)
    manager.delete_identity(org_identity.id, actor=default_user)
|
||||
|
||||
|
||||
def test_update_identity(server: SyncServer, sarah_agent, charles_agent, default_user):
    """Update an identity's agent_ids and properties, then verify the agent back-references.

    Fix: the original asserted ``a.sort() == b.sort()`` — ``list.sort()`` sorts in
    place and returns ``None``, so the comparison was always ``None == None`` and the
    assertion could never fail. ``sorted()`` compares the actual contents.
    """
    identity = server.identity_manager.create_identity(
        IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user
    )

    # Update identity fields
    update_data = IdentityUpdate(
        agent_ids=[sarah_agent.id, charles_agent.id],
        properties=[IdentityProperty(key="email", value="caren@letta.com", type=IdentityPropertyType.string)],
    )
    server.identity_manager.update_identity(identity_id=identity.id, identity=update_data, actor=default_user)

    # Retrieve the updated identity
    updated_identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user)

    # Assertions to verify the update (order-insensitive comparison of agent ids)
    assert sorted(updated_identity.agent_ids) == sorted(update_data.agent_ids)
    assert updated_identity.properties == update_data.properties

    # Both agents should now reference the identity
    agent_state = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user)
    assert identity.id in agent_state.identity_ids
    agent_state = server.agent_manager.get_agent_by_id(agent_id=charles_agent.id, actor=default_user)
    assert identity.id in agent_state.identity_ids

    server.identity_manager.delete_identity(identity.id, actor=default_user)
|
||||
|
||||
|
||||
def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent, default_user):
    """Deleting an identity also detaches it from any agent it was attached to.

    Fixes: ``assert not identity.id in …`` → idiomatic ``not in``; the trailing
    comment said "block" but this test deals with identities.
    """
    # Create an identity
    identity = server.identity_manager.create_identity(
        IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user
    )
    agent_state = server.agent_manager.update_agent(
        agent_id=sarah_agent.id, agent_update=UpdateAgent(identity_ids=[identity.id]), actor=default_user
    )

    # Check that identity has been attached
    assert identity.id in agent_state.identity_ids

    # Now attempt to delete the identity
    server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)

    # Verify that the identity was deleted
    identities = server.identity_manager.list_identities(actor=default_user)
    assert len(identities) == 0

    # Check that the identity has been detached from the agent too
    agent_state = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user)
    assert identity.id not in agent_state.identity_ids
|
||||
|
||||
|
||||
def test_get_agents_for_identities(server: SyncServer, sarah_agent, charles_agent, default_user):
    """Agents attached to an identity are listable by identity id and by identifier key."""
    identity = server.identity_manager.create_identity(
        IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, agent_ids=[sarah_agent.id, charles_agent.id]),
        actor=default_user,
    )

    expected_ids = {sarah_agent.id, charles_agent.id}

    # Lookup by identity id returns both attached agents
    by_id = server.agent_manager.list_agents(identifier_id=identity.id, actor=default_user)
    assert len(by_id) == 2
    assert {agent.id for agent in by_id} == expected_ids

    # Lookup by identifier key returns the same pair
    by_key = server.agent_manager.list_agents(identifier_keys=[identity.identifier_key], actor=default_user)
    assert len(by_key) == 2
    assert {agent.id for agent in by_key} == expected_ids
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# SourceManager Tests - Sources
|
||||
# ======================================================================================================================
|
||||
@@ -3095,13 +3405,14 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user):
|
||||
def test_job_usage_stats_add_and_get(server: SyncServer, sarah_agent, default_job, default_user):
|
||||
"""Test adding and retrieving job usage statistics."""
|
||||
job_manager = server.job_manager
|
||||
step_manager = server.step_manager
|
||||
|
||||
# Add usage statistics
|
||||
step_manager.log_step(
|
||||
agent_id=sarah_agent.id,
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
@@ -3145,13 +3456,14 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u
|
||||
assert len(steps) == 0
|
||||
|
||||
|
||||
def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user):
|
||||
def test_job_usage_stats_add_multiple(server: SyncServer, sarah_agent, default_job, default_user):
|
||||
"""Test adding multiple usage statistics entries for a job."""
|
||||
job_manager = server.job_manager
|
||||
step_manager = server.step_manager
|
||||
|
||||
# Add first usage statistics entry
|
||||
step_manager.log_step(
|
||||
agent_id=sarah_agent.id,
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
@@ -3167,6 +3479,7 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
|
||||
|
||||
# Add second usage statistics entry
|
||||
step_manager.log_step(
|
||||
agent_id=sarah_agent.id,
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
@@ -3193,6 +3506,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
|
||||
steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user)
|
||||
assert len(steps) == 2
|
||||
|
||||
# get agent steps
|
||||
steps = step_manager.list_steps(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(steps) == 2
|
||||
|
||||
|
||||
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
|
||||
"""Test getting usage statistics for a nonexistent job."""
|
||||
@@ -3202,12 +3519,13 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
|
||||
job_manager.get_job_usage(job_id="nonexistent_job", actor=default_user)
|
||||
|
||||
|
||||
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user):
|
||||
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test adding usage statistics for a nonexistent job."""
|
||||
step_manager = server.step_manager
|
||||
|
||||
with pytest.raises(NoResultFound):
|
||||
step_manager.log_step(
|
||||
agent_id=sarah_agent.id,
|
||||
provider_name="openai",
|
||||
model="gpt-4",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
|
||||
Reference in New Issue
Block a user