fix: March 11 fixes (#2479)

Co-authored-by: cthomas <caren@letta.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: Kevin Lin <klin5061@gmail.com>
Co-authored-by: Shubham Naik <shub@letta.com>
Co-authored-by: Shubham Naik <shub@memgpt.ai>
Co-authored-by: Charles Packer <packercharles@gmail.com>
Co-authored-by: Shubham Naik <shubham.naik10@gmail.com>
Co-authored-by: mlong93 <35275280+mlong93@users.noreply.github.com>
Co-authored-by: Mindy Long <mindy@letta.com>
Co-authored-by: Stephan Fitzpatrick <stephan@knowsuchagency.com>
Co-authored-by: dboyliao <qmalliao@gmail.com>
Co-authored-by: Jyotirmaya Mahanta <jyotirmaya.mahanta@gmail.com>
Co-authored-by: Nicholas <102550462+ndisalvio3@users.noreply.github.com>
Co-authored-by: tarunkumark <tkksctwo@gmail.com>
Co-authored-by: Miao <one.lemorage@gmail.com>
Co-authored-by: Krishnakumar R (KK) <65895020+kk-src@users.noreply.github.com>
Co-authored-by: Will Sargent <will.sargent@gmail.com>
This commit is contained in:
Matthew Zhou
2025-03-11 14:50:17 -07:00
committed by GitHub
parent fb092d7fa9
commit 30f3d3d2c7
35 changed files with 1711 additions and 633 deletions

View File

@@ -0,0 +1,31 @@
"""add agent id to steps
Revision ID: d211df879a5f
Revises: 2f4ede6ae33b
Create Date: 2025-03-06 21:42:22.289345
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "d211df879a5f"
down_revision: Union[str, None] = "2f4ede6ae33b"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Apply the migration: add a nullable ``agent_id`` column to ``steps``."""
    # Auto-generated by Alembic; reviewed manually.
    op.add_column("steps", sa.Column("agent_id", sa.String(), nullable=True))
def downgrade() -> None:
    """Revert the migration: drop the ``agent_id`` column from ``steps``."""
    # Auto-generated by Alembic; reviewed manually.
    op.drop_column("steps", "agent_id")

View File

@@ -1,4 +1,4 @@
__version__ = "0.6.37"
__version__ = "0.6.38"
# import clients
from letta.client.client import LocalClient, RESTClient, create_client

View File

@@ -29,6 +29,7 @@ from letta.helpers.json_helpers import json_dumps, json_loads
from letta.interface import AgentInterface
from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
from letta.llm_api.llm_api_tools import create
from letta.llm_api.llm_client import LLMClient
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
from letta.memory import summarize_messages
@@ -356,19 +357,38 @@ class Agent(BaseAgent):
for attempt in range(1, empty_response_retry_limit + 1):
try:
log_telemetry(self.logger, "_get_ai_reply create start")
response = create(
# New LLM client flow
llm_client = LLMClient.create(
agent_id=self.agent_state.id,
llm_config=self.agent_state.llm_config,
messages=message_sequence,
user_id=self.agent_state.created_by_id,
functions=allowed_functions,
# functions_python=self.functions_python, do we need this?
function_call=function_call,
first_message=first_message,
force_tool_call=force_tool_call,
stream=stream,
stream_interface=self.interface,
put_inner_thoughts_first=put_inner_thoughts_first,
actor_id=self.agent_state.created_by_id,
)
if llm_client and not stream:
response = llm_client.send_llm_request(
messages=message_sequence,
tools=allowed_functions,
tool_call=function_call,
stream=stream,
first_message=first_message,
force_tool_call=force_tool_call,
)
else:
# Fallback to existing flow
response = create(
llm_config=self.agent_state.llm_config,
messages=message_sequence,
user_id=self.agent_state.created_by_id,
functions=allowed_functions,
# functions_python=self.functions_python, do we need this?
function_call=function_call,
first_message=first_message,
force_tool_call=force_tool_call,
stream=stream,
stream_interface=self.interface,
put_inner_thoughts_first=put_inner_thoughts_first,
)
log_telemetry(self.logger, "_get_ai_reply create finish")
# These bottom two are retryable
@@ -632,7 +652,7 @@ class Agent(BaseAgent):
function_args,
function_response,
messages,
[tool_return] if tool_return else None,
[tool_return],
include_function_failed_message=True,
)
return messages, False, True # force a heartbeat to allow agent to handle error
@@ -659,7 +679,7 @@ class Agent(BaseAgent):
"content": function_response,
"tool_call_id": tool_call_id,
},
tool_returns=[tool_return] if tool_return else None,
tool_returns=[tool_return] if sandbox_run_result else None,
)
) # extend conversation with function response
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
@@ -909,6 +929,7 @@ class Agent(BaseAgent):
# Log step - this must happen before messages are persisted
step = self.step_manager.log_step(
actor=self.user,
agent_id=self.agent_state.id,
provider_name=self.agent_state.llm_config.model_endpoint_type,
model=self.agent_state.llm_config.model,
model_endpoint=self.agent_state.llm_config.model_endpoint,
@@ -1174,6 +1195,7 @@ class Agent(BaseAgent):
memory_edit_timestamp=get_utc_time(),
previous_message_count=self.message_manager.size(actor=self.user, agent_id=self.agent_state.id),
archival_memory_size=self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id),
recent_passages=self.agent_manager.list_passages(actor=self.user, agent_id=self.agent_state.id, ascending=False, limit=10),
)
num_tokens_external_memory_summary = count_tokens(external_memory_summary)

View File

@@ -4,7 +4,6 @@ import time
from typing import Callable, Dict, Generator, List, Optional, Union
import requests
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
import letta.utils
from letta.constants import ADMIN_PREFIX, BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, FUNCTION_RETURN_CHAR_LIMIT
@@ -29,7 +28,7 @@ from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, ChatMemory, CreateArchivalMemory, Memory, RecallMemorySummary
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
@@ -640,30 +639,6 @@ class RESTClient(AbstractClient):
# refresh and return agent
return self.get_agent(agent_state.id)
def update_message(
self,
agent_id: str,
message_id: str,
role: Optional[MessageRole] = None,
text: Optional[str] = None,
name: Optional[str] = None,
tool_calls: Optional[List[OpenAIToolCall]] = None,
tool_call_id: Optional[str] = None,
) -> Message:
request = MessageUpdate(
role=role,
content=text,
name=name,
tool_calls=tool_calls,
tool_call_id=tool_call_id,
)
response = requests.patch(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/{message_id}", json=request.model_dump(), headers=self.headers
)
if response.status_code != 200:
raise ValueError(f"Failed to update message: {response.text}")
return Message(**response.json())
def update_agent(
self,
agent_id: str,
@@ -2436,30 +2411,6 @@ class LocalClient(AbstractClient):
# TODO: get full agent state
return self.server.agent_manager.get_agent_by_id(agent_state.id, actor=self.user)
def update_message(
self,
agent_id: str,
message_id: str,
role: Optional[MessageRole] = None,
text: Optional[str] = None,
name: Optional[str] = None,
tool_calls: Optional[List[OpenAIToolCall]] = None,
tool_call_id: Optional[str] = None,
) -> Message:
message = self.server.update_agent_message(
agent_id=agent_id,
message_id=message_id,
request=MessageUpdate(
role=role,
content=text,
name=name,
tool_calls=tool_calls,
tool_call_id=tool_call_id,
),
actor=self.user,
)
return message
def update_agent(
self,
agent_id: str,

View File

@@ -50,7 +50,7 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "
# Base memory tools CAN be edited, and are added by default by the server
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
# Multi agent tools
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"]
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_tags", "send_message_to_agent_async"]
# Set of all built-in Letta tools
LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS)

View File

@@ -2,7 +2,7 @@ import asyncio
from typing import TYPE_CHECKING, List
from letta.functions.helpers import (
_send_message_to_agents_matching_all_tags_async,
_send_message_to_agents_matching_tags_async,
execute_send_message_to_agent,
fire_and_forget_send_to_agent,
)
@@ -70,18 +70,19 @@ def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str
return "Successfully sent message"
def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: List[str]) -> List[str]:
def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all: List[str], match_some: List[str]) -> List[str]:
"""
Sends a message to all agents within the same organization that match all of the specified tags. Messages are dispatched in parallel for improved performance, with retries to handle transient issues and timeouts to ensure responsiveness. This function enforces a limit of 100 agents and does not support pagination (cursor-based queries). Each agent must match all specified tags (`match_all_tags=True`) to be included.
Sends a message to all agents within the same organization that match the specified tag criteria. Agents must possess *all* of the tags in `match_all` and *at least one* of the tags in `match_some` to receive the message.
Args:
message (str): The content of the message to be sent to each matching agent.
tags (List[str]): A list of tags that an agent must possess to receive the message.
match_all (List[str]): A list of tags that an agent must possess to receive the message.
match_some (List[str]): A list of tags where an agent must have at least one to qualify.
Returns:
List[str]: A list of responses from the agents that matched all tags. Each
response corresponds to a single agent. Agents that do not respond will not
have an entry in the returned list.
List[str]: A list of responses from the agents that matched the filtering criteria. Each
response corresponds to a single agent. Agents that do not respond will not have an entry
in the returned list.
"""
return asyncio.run(_send_message_to_agents_matching_all_tags_async(self, message, tags))
return asyncio.run(_send_message_to_agents_matching_tags_async(self, message, match_all, match_some))

View File

@@ -518,8 +518,16 @@ def fire_and_forget_send_to_agent(
run_in_background_thread(background_task())
async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", message: str, tags: List[str]) -> List[str]:
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async start", message=message, tags=tags)
async def _send_message_to_agents_matching_tags_async(
sender_agent: "Agent", message: str, match_all: List[str], match_some: List[str]
) -> List[str]:
log_telemetry(
sender_agent.logger,
"_send_message_to_agents_matching_tags_async start",
message=message,
match_all=match_all,
match_some=match_some,
)
server = get_letta_server()
augmented_message = (
@@ -529,9 +537,22 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent",
)
# Retrieve up to 100 matching agents
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents start", message=message, tags=tags)
matching_agents = server.agent_manager.list_agents(actor=sender_agent.user, tags=tags, match_all_tags=True, limit=100)
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents finish", message=message, tags=tags)
log_telemetry(
sender_agent.logger,
"_send_message_to_agents_matching_tags_async listing agents start",
message=message,
match_all=match_all,
match_some=match_some,
)
matching_agents = server.agent_manager.list_agents_matching_tags(actor=sender_agent.user, match_all=match_all, match_some=match_some)
log_telemetry(
sender_agent.logger,
"_send_message_to_agents_matching_tags_async listing agents finish",
message=message,
match_all=match_all,
match_some=match_some,
)
# Create a system message
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)]
@@ -559,7 +580,13 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent",
else:
final.append(r)
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async finish", message=message, tags=tags)
log_telemetry(
sender_agent.logger,
"_send_message_to_agents_matching_tags_async finish",
message=message,
match_all=match_all,
match_some=match_some,
)
return final

View File

@@ -0,0 +1,332 @@
import uuid
from typing import List, Optional, Tuple
from letta.constants import NON_USER_MSG_PREFIX
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps
from letta.llm_api.helpers import make_post_request
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.local_llm.utils import count_tokens
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
from letta.settings import model_settings
from letta.utils import get_tool_call_id
class GoogleAIClient(LLMClientBase):
def request(self, request_data: dict) -> dict:
    """Send a ``generateContent`` request to the Gemini API and return the raw JSON response."""
    endpoint, request_headers = self.get_gemini_endpoint_and_headers(generate_content=True)
    return make_post_request(endpoint, request_headers, request_data)
def build_request_data(
    self,
    messages: List[PydanticMessage],
    tools: List[dict],
    tool_call: Optional[str],
) -> dict:
    """Build the Gemini ``generateContent`` request payload for this client.

    Args:
        messages: Conversation history to send.
        tools: Bare OpenAI-style function schemas (or empty/None for no tools).
        tool_call: Accepted for interface parity; not used by this client.

    Returns:
        A dict payload with ``contents``, ``tools`` and ``generation_config``.
    """
    # Wrap the bare function schemas in OpenAI-style tool objects, then
    # translate them into the Google AI "functionDeclarations" format.
    if tools:
        wrapped = [Tool(**{"type": "function", "function": schema}) for schema in tools]
        tools = self.convert_tools_to_google_ai_format(wrapped)

    # Gemini requires a 'model' message after each tool result; pad as needed.
    contents = self.add_dummy_model_messages([m.to_google_ai_dict() for m in messages])

    return {
        "contents": contents,
        "tools": tools,
        "generation_config": {
            "temperature": self.llm_config.temperature,
            "max_output_tokens": self.llm_config.max_tokens,
        },
    }
def convert_response_to_chat_completion(
    self,
    response_data: dict,
    input_messages: List[PydanticMessage],
) -> ChatCompletionResponse:
    """
    Converts custom response format from llm client into an OpenAI
    ChatCompletionsResponse object.

    Example Input:
        {
          "candidates": [
            {
              "content": {
                "parts": [
                  {
                    "text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14."
                  }
                ]
              }
            }
          ],
          "usageMetadata": {
            "promptTokenCount": 9,
            "candidatesTokenCount": 27,
            "totalTokenCount": 36
          }
        }
    """
    try:
        choices = []
        index = 0
        for candidate in response_data["candidates"]:
            content = candidate["content"]

            # Gemini always answers with role "model"; anything else is unexpected.
            role = content["role"]
            assert role == "model", f"Unknown role in response: {role}"

            parts = content["parts"]

            # TODO support parts / multimodal
            # TODO support parallel tool calling natively
            # TODO Alternative here is to throw away everything else except for the first part
            for response_message in parts:
                # Convert the actual message style to OpenAI style
                if "functionCall" in response_message and response_message["functionCall"] is not None:
                    function_call = response_message["functionCall"]
                    assert isinstance(function_call, dict), function_call
                    function_name = function_call["name"]
                    assert isinstance(function_name, str), function_name
                    function_args = function_call["args"]
                    assert isinstance(function_args, dict), function_args

                    # NOTE: this also involves stripping the inner monologue out of the function
                    if self.llm_config.put_inner_thoughts_in_kwargs:
                        from letta.local_llm.constants import INNER_THOUGHTS_KWARG

                        assert INNER_THOUGHTS_KWARG in function_args, f"Couldn't find inner thoughts in function args:\n{function_call}"
                        # pop() both extracts the monologue and removes it from the args
                        # forwarded to the tool call below.
                        inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG)
                        assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
                    else:
                        inner_thoughts = None

                    # Google AI API doesn't generate tool call IDs
                    openai_response_message = Message(
                        role="assistant",  # NOTE: "model" -> "assistant"
                        content=inner_thoughts,
                        tool_calls=[
                            ToolCall(
                                id=get_tool_call_id(),
                                type="function",
                                function=FunctionCall(
                                    name=function_name,
                                    arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
                                ),
                            )
                        ],
                    )

                else:
                    # Inner thoughts are the content by default
                    inner_thoughts = response_message["text"]

                    # Google AI API doesn't generate tool call IDs
                    openai_response_message = Message(
                        role="assistant",  # NOTE: "model" -> "assistant"
                        content=inner_thoughts,
                    )

            # Google AI API uses different finish reason strings than OpenAI
            # OpenAI: 'stop', 'length', 'function_call', 'content_filter', null
            # see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api
            # Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
            # see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
            finish_reason = candidate["finishReason"]
            if finish_reason == "STOP":
                openai_finish_reason = (
                    "function_call"
                    if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0
                    else "stop"
                )
            elif finish_reason == "MAX_TOKENS":
                openai_finish_reason = "length"
            elif finish_reason == "SAFETY":
                openai_finish_reason = "content_filter"
            elif finish_reason == "RECITATION":
                openai_finish_reason = "content_filter"
            else:
                raise ValueError(f"Unrecognized finish reason in Google AI response: {finish_reason}")

            # NOTE(review): only the LAST part of each candidate survives as the
            # Choice message — earlier parts are overwritten in the loop above.
            # Confirm this is intentional (see parallel-tool-calling TODO).
            choices.append(
                Choice(
                    finish_reason=openai_finish_reason,
                    index=index,
                    message=openai_response_message,
                )
            )
            index += 1

        # if len(choices) > 1:
        #     raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})")

        # NOTE: some of the Google AI APIs show UsageMetadata in the response, but it seems to not exist?
        #  "usageMetadata": {
        #     "promptTokenCount": 9,
        #     "candidatesTokenCount": 27,
        #     "totalTokenCount": 36
        #   }
        if "usageMetadata" in response_data:
            usage = UsageStatistics(
                prompt_tokens=response_data["usageMetadata"]["promptTokenCount"],
                completion_tokens=response_data["usageMetadata"]["candidatesTokenCount"],
                total_tokens=response_data["usageMetadata"]["totalTokenCount"],
            )
        else:
            # Count it ourselves
            # NOTE(review): this branch reads `openai_response_message` from the
            # loop above; if "candidates" is empty it would be unbound — confirm
            # the API guarantees at least one candidate.
            assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required"
            prompt_tokens = count_tokens(json_dumps(input_messages))  # NOTE: this is a very rough approximation
            completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump()))  # NOTE: this is also approximate
            total_tokens = prompt_tokens + completion_tokens
            usage = UsageStatistics(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            )

        # Google doesn't return a response id, so synthesize one.
        response_id = str(uuid.uuid4())
        return ChatCompletionResponse(
            id=response_id,
            choices=choices,
            model=self.llm_config.model,  # NOTE: Google API doesn't pass back model in the response
            created=get_utc_time(),
            usage=usage,
        )
    except KeyError as e:
        # No-op re-raise kept as a hook for future error enrichment.
        raise e
def get_gemini_endpoint_and_headers(
    self,
    key_in_header: bool = True,
    generate_content: bool = False,
) -> Tuple[str, dict]:
    """Build the Gemini model endpoint URL and request headers.

    The API key can travel either in the ``x-goog-api-key`` header or as a
    ``?key=`` query parameter — the two auth styles supported by the API
    (https://ai.google.dev/tutorials/setup).

    Args:
        key_in_header: If True, put the key in the request headers.
        generate_content: If True, append the ``:generateContent`` action.

    Returns:
        A ``(url, headers)`` tuple.
    """
    endpoint = f"{self.llm_config.model_endpoint}/v1beta/models/{self.llm_config.model}"
    if generate_content:
        endpoint += ":generateContent"

    if key_in_header:
        return endpoint, {"Content-Type": "application/json", "x-goog-api-key": model_settings.gemini_api_key}
    return f"{endpoint}?key={model_settings.gemini_api_key}", {"Content-Type": "application/json"}
def convert_tools_to_google_ai_format(self, tools: List[Tool]) -> List[dict]:
    """Convert OpenAI-style tool definitions into Google AI "functionDeclarations".

    OpenAI style:
        "tools": [{
          "type": "function",
          "function": {
            "name": "find_movies",
            "description": "find ....",
            "parameters": {
              "type": "object",
              "properties": {
                 PARAM: {"type": PARAM_TYPE, "description": PARAM_DESCRIPTION},
                 ...
              },
              "required": List[str],
            }
          }
        }]

    Google AI style:
        "tools": [{
          "functionDeclarations": [{
            "name": "find_movies",
            "description": "...",
            "parameters": {
              "type": "OBJECT",
              "properties": {
                "location": {"type": "STRING", "description": "..."},
                ...
              },
              "required": ["description"]
            }
          }, ...]
        }]

    Args:
        tools: OpenAI-style tool objects to convert.

    Returns:
        A single-element list holding the ``functionDeclarations`` dict.
    """
    function_list = [
        dict(
            name=t.function.name,
            description=t.function.description,
            parameters=t.function.parameters,  # TODO need to unpack (nested/array item types are not converted)
        )
        for t in tools
    ]

    # Correct casing + add inner thoughts if needed
    for func in function_list:
        func["parameters"]["type"] = "OBJECT"
        for param_name, param_fields in func["parameters"]["properties"].items():
            # Guard: JSON schemas are not required to declare a "type" for
            # every property; previously an untyped property raised KeyError.
            if "type" in param_fields:
                param_fields["type"] = param_fields["type"].upper()
        # Add inner thoughts
        if self.llm_config.put_inner_thoughts_in_kwargs:
            from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION

            func["parameters"]["properties"][INNER_THOUGHTS_KWARG] = {
                "type": "STRING",
                "description": INNER_THOUGHTS_KWARG_DESCRIPTION,
            }
            # "required" may be absent when the schema has no mandatory params;
            # previously this raised KeyError on such schemas.
            func["parameters"].setdefault("required", []).append(INNER_THOUGHTS_KWARG)

    return [{"functionDeclarations": function_list}]
def add_dummy_model_messages(self, messages: List[dict]) -> List[dict]:
    """Insert placeholder 'model' messages after tool outputs that precede user messages.

    The Google AI API requires every function-call return to be immediately
    followed by a 'model' role message. In Letta, the 'model' will often call
    a function (e.g. send_message) that itself yields to the user, so there is
    no natural follow-up 'model' message. To satisfy the restriction, a dummy
    'yield' message with role == 'model' is placed in between each function
    output (role == 'tool') and the user message (role == 'user') that follows.
    """
    filler = {
        "role": "model",
        "parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}],
    }
    padded: List[dict] = []
    total = len(messages)
    for idx, msg in enumerate(messages):
        padded.append(msg)
        # Pad only when a tool/function result is directly followed by a user turn.
        next_is_user = idx + 1 < total and messages[idx + 1]["role"] == "user"
        if msg["role"] in ("tool", "function") and next_is_user:
            padded.append(filler)
    return padded

View File

@@ -0,0 +1,214 @@
import uuid
from typing import List, Optional
from google import genai
from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, ToolConfig
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps
from letta.llm_api.google_ai_client import GoogleAIClient
from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.local_llm.utils import count_tokens
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
from letta.settings import model_settings
from letta.utils import get_tool_call_id
class GoogleVertexClient(GoogleAIClient):
def request(self, request_data: dict) -> dict:
    """
    Performs underlying request to llm and returns raw response.
    """
    # NOTE(review): a fresh genai.Client is constructed on every request;
    # presumably cheap, but consider caching — confirm SDK construction cost.
    client = genai.Client(
        vertexai=True,
        project=model_settings.google_cloud_project,
        location=model_settings.google_cloud_location,
        http_options={"api_version": "v1"},
    )
    response = client.models.generate_content(
        model=self.llm_config.model,
        contents=request_data["contents"],
        config=request_data["config"],
    )
    # Normalize the SDK's typed response into a plain dict for downstream parsing.
    return response.model_dump()
def build_request_data(
    self,
    messages: List[PydanticMessage],
    tools: List[dict],
    tool_call: Optional[str],
) -> dict:
    """Build the Vertex AI request payload from the base Google AI payload.

    The Vertex SDK expects generation settings and tools nested under a single
    ``config`` key, and function calling is forced via ANY mode.
    """
    payload = super().build_request_data(messages, tools, tool_call)

    # Re-shape: Vertex nests generation settings (and tools) under "config".
    payload["config"] = payload.pop("generation_config")
    payload["config"]["tools"] = payload.pop("tools")

    # ANY mode forces the model to predict only function calls
    forced_function_calling = ToolConfig(
        function_calling_config=FunctionCallingConfig(mode=FunctionCallingConfigMode.ANY)
    )
    payload["config"]["tool_config"] = forced_function_calling.model_dump()

    return payload
def convert_response_to_chat_completion(
    self,
    response_data: dict,
    input_messages: List[PydanticMessage],
) -> ChatCompletionResponse:
    """
    Converts custom response format from llm client into an OpenAI
    ChatCompletionsResponse object.

    Example:
        {
          "candidates": [
            {
              "content": {
                "parts": [
                  {
                    "text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14."
                  }
                ]
              }
            }
          ],
          "usageMetadata": {
            "promptTokenCount": 9,
            "candidatesTokenCount": 27,
            "totalTokenCount": 36
          }
        }
    """
    # Re-hydrate the raw dict into the SDK's typed response model so the
    # conversion below can use attribute access instead of dict indexing.
    response = GenerateContentResponse(**response_data)
    try:
        choices = []
        index = 0
        for candidate in response.candidates:
            content = candidate.content

            # Gemini always answers with role "model"; anything else is unexpected.
            role = content.role
            assert role == "model", f"Unknown role in response: {role}"

            parts = content.parts

            # TODO support parts / multimodal
            # TODO support parallel tool calling natively
            # TODO Alternative here is to throw away everything else except for the first part
            for response_message in parts:
                # Convert the actual message style to OpenAI style
                if response_message.function_call:
                    function_call = response_message.function_call
                    function_name = function_call.name
                    function_args = function_call.args
                    assert isinstance(function_args, dict), function_args

                    # NOTE: this also involves stripping the inner monologue out of the function
                    if self.llm_config.put_inner_thoughts_in_kwargs:
                        from letta.local_llm.constants import INNER_THOUGHTS_KWARG

                        assert INNER_THOUGHTS_KWARG in function_args, f"Couldn't find inner thoughts in function args:\n{function_call}"
                        # pop() both extracts the monologue and removes it from the
                        # args forwarded to the tool call below.
                        inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG)
                        assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
                    else:
                        inner_thoughts = None

                    # Google AI API doesn't generate tool call IDs
                    openai_response_message = Message(
                        role="assistant",  # NOTE: "model" -> "assistant"
                        content=inner_thoughts,
                        tool_calls=[
                            ToolCall(
                                id=get_tool_call_id(),
                                type="function",
                                function=FunctionCall(
                                    name=function_name,
                                    arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
                                ),
                            )
                        ],
                    )

                else:
                    # Inner thoughts are the content by default
                    inner_thoughts = response_message.text

                    # Google AI API doesn't generate tool call IDs
                    openai_response_message = Message(
                        role="assistant",  # NOTE: "model" -> "assistant"
                        content=inner_thoughts,
                    )

            # Google AI API uses different finish reason strings than OpenAI
            # OpenAI: 'stop', 'length', 'function_call', 'content_filter', null
            # see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api
            # Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
            # see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
            finish_reason = candidate.finish_reason.value
            if finish_reason == "STOP":
                openai_finish_reason = (
                    "function_call"
                    if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0
                    else "stop"
                )
            elif finish_reason == "MAX_TOKENS":
                openai_finish_reason = "length"
            elif finish_reason == "SAFETY":
                openai_finish_reason = "content_filter"
            elif finish_reason == "RECITATION":
                openai_finish_reason = "content_filter"
            else:
                raise ValueError(f"Unrecognized finish reason in Google AI response: {finish_reason}")

            # NOTE(review): only the LAST part of each candidate survives as the
            # Choice message — earlier parts are overwritten in the loop above.
            # Confirm this is intentional (see parallel-tool-calling TODO).
            choices.append(
                Choice(
                    finish_reason=openai_finish_reason,
                    index=index,
                    message=openai_response_message,
                )
            )
            index += 1

        # if len(choices) > 1:
        #     raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})")

        # NOTE: some of the Google AI APIs show UsageMetadata in the response, but it seems to not exist?
        #  "usageMetadata": {
        #     "promptTokenCount": 9,
        #     "candidatesTokenCount": 27,
        #     "totalTokenCount": 36
        #   }
        if response.usage_metadata:
            usage = UsageStatistics(
                prompt_tokens=response.usage_metadata.prompt_token_count,
                completion_tokens=response.usage_metadata.candidates_token_count,
                total_tokens=response.usage_metadata.total_token_count,
            )
        else:
            # Count it ourselves
            # NOTE(review): this branch reads `openai_response_message` from the
            # loop above; if `candidates` is empty it would be unbound — confirm
            # the API guarantees at least one candidate.
            assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required"
            prompt_tokens = count_tokens(json_dumps(input_messages))  # NOTE: this is a very rough approximation
            completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump()))  # NOTE: this is also approximate
            total_tokens = prompt_tokens + completion_tokens
            usage = UsageStatistics(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            )

        # Google doesn't return a response id, so synthesize one.
        response_id = str(uuid.uuid4())
        return ChatCompletionResponse(
            id=response_id,
            choices=choices,
            model=self.llm_config.model,  # NOTE: Google API doesn't pass back model in the response
            created=get_utc_time(),
            usage=usage,
        )
    except KeyError as e:
        # No-op re-raise kept as a hook for future error enrichment.
        raise e

View File

@@ -0,0 +1,48 @@
from typing import Optional
from letta.llm_api.llm_client_base import LLMClientBase
from letta.schemas.llm_config import LLMConfig
class LLMClient:
    """Factory class for creating LLM clients based on the model endpoint type."""

    @staticmethod
    def create(
        agent_id: str,
        llm_config: LLMConfig,
        put_inner_thoughts_first: bool = True,
        actor_id: Optional[str] = None,
    ) -> Optional[LLMClientBase]:
        """
        Create an LLM client based on the model endpoint type.

        Args:
            agent_id: Unique identifier for the agent
            llm_config: Configuration for the LLM model
            put_inner_thoughts_first: Whether to put inner thoughts first in the response
            actor_id: Optional actor identifier

        Returns:
            An instance of LLMClientBase subclass, or None for endpoint types
            that are not yet handled by the new client flow (callers fall back
            to the legacy request path).
        """
        endpoint_type = llm_config.model_endpoint_type

        if endpoint_type == "google_ai":
            # Imported lazily so unrelated endpoints don't pay the import cost.
            from letta.llm_api.google_ai_client import GoogleAIClient

            return GoogleAIClient(
                agent_id=agent_id, llm_config=llm_config, put_inner_thoughts_first=put_inner_thoughts_first, actor_id=actor_id
            )

        if endpoint_type == "google_vertex":
            from letta.llm_api.google_vertex_client import GoogleVertexClient

            return GoogleVertexClient(
                agent_id=agent_id, llm_config=llm_config, put_inner_thoughts_first=put_inner_thoughts_first, actor_id=actor_id
            )

        # Unsupported endpoint type: signal "use the legacy flow" rather than raising.
        return None

View File

@@ -0,0 +1,129 @@
from abc import abstractmethod
from typing import List, Optional, Union
from openai import AsyncStream, Stream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.tracing import log_event
class LLMClientBase:
    """
    Abstract base class for LLM clients. Subclasses format the provider-specific
    request object, perform the downstream request, and parse the raw response
    into the OpenAI chat-completions response format.
    """

    def __init__(
        self,
        agent_id: str,
        llm_config: LLMConfig,
        put_inner_thoughts_first: Optional[bool] = True,
        use_structured_output: Optional[bool] = True,
        use_tool_naming: bool = True,
        actor_id: Optional[str] = None,
    ):
        self.agent_id = agent_id
        self.llm_config = llm_config
        self.put_inner_thoughts_first = put_inner_thoughts_first
        # Fix: these two were accepted but previously discarded; store them so
        # subclasses can honor the caller's configuration.
        self.use_structured_output = use_structured_output
        self.use_tool_naming = use_tool_naming
        self.actor_id = actor_id

    def send_llm_request(
        self,
        messages: List[Message],
        tools: Optional[List[dict]] = None,  # TODO: change to Tool object
        tool_call: Optional[str] = None,
        stream: bool = False,
        first_message: bool = False,  # NOTE(review): currently unused here — confirm whether subclasses need it
        force_tool_call: Optional[str] = None,  # NOTE(review): currently unused here — confirm whether subclasses need it
    ) -> Union[ChatCompletionResponse, Stream[ChatCompletionChunk]]:
        """
        Issues a request to the downstream model endpoint and parses the response.

        If stream=True, returns a Stream[ChatCompletionChunk] that can be iterated over.
        Otherwise returns a ChatCompletionResponse.
        """
        request_data = self.build_request_data(messages, tools, tool_call)
        log_event(name="llm_request_sent", attributes=request_data)
        if stream:
            return self.stream(request_data)
        response_data = self.request(request_data)
        log_event(name="llm_response_received", attributes=response_data)
        return self.convert_response_to_chat_completion(response_data, messages)

    async def send_llm_request_async(
        self,
        messages: List[Message],
        tools: Optional[List[dict]] = None,  # TODO: change to Tool object
        tool_call: Optional[str] = None,
        stream: bool = False,
        first_message: bool = False,  # NOTE(review): currently unused here — confirm whether subclasses need it
        force_tool_call: Optional[str] = None,  # NOTE(review): currently unused here — confirm whether subclasses need it
    ) -> Union[ChatCompletionResponse, AsyncStream[ChatCompletionChunk]]:
        """
        Issues a request to the downstream model endpoint (async variant).

        If stream=True, returns an AsyncStream[ChatCompletionChunk] that can be
        async iterated over. Otherwise returns a ChatCompletionResponse.
        """
        request_data = self.build_request_data(messages, tools, tool_call)
        log_event(name="llm_request_sent", attributes=request_data)
        if stream:
            return await self.stream_async(request_data)
        response_data = await self.request_async(request_data)
        log_event(name="llm_response_received", attributes=response_data)
        return self.convert_response_to_chat_completion(response_data, messages)

    @abstractmethod
    def build_request_data(
        self,
        messages: List[Message],
        tools: List[dict],
        tool_call: Optional[str],
    ) -> dict:
        """
        Constructs a request object in the expected data format for this client.
        """
        raise NotImplementedError

    @abstractmethod
    def request(self, request_data: dict) -> dict:
        """
        Performs the underlying request to the llm and returns the raw response.
        """
        raise NotImplementedError

    @abstractmethod
    async def request_async(self, request_data: dict) -> dict:
        """
        Performs the underlying request to the llm and returns the raw response (async).
        """
        raise NotImplementedError

    @abstractmethod
    def convert_response_to_chat_completion(
        self,
        response_data: dict,
        input_messages: List[Message],
    ) -> ChatCompletionResponse:
        """
        Converts the custom response format from the llm client into an OpenAI
        ChatCompletionResponse object.
        """
        raise NotImplementedError

    @abstractmethod
    def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
        """
        Performs the underlying streaming request to the llm and returns the raw response.
        """
        raise NotImplementedError(f"Streaming is not supported for {self.llm_config.model_endpoint_type}")

    @abstractmethod
    async def stream_async(self, request_data: dict) -> AsyncStream[ChatCompletionChunk]:
        """
        Performs the underlying streaming request to the llm and returns the raw response (async).
        """
        raise NotImplementedError(f"Streaming is not supported for {self.llm_config.model_endpoint_type}")

View File

@@ -33,6 +33,7 @@ class Step(SqlalchemyBase):
job_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("jobs.id", ondelete="SET NULL"), nullable=True, doc="The unique identified of the job run that triggered this step"
)
agent_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The ID of the agent that performed this step.")
provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.")
model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.")
model_endpoint: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The model endpoint url used for this step.")

View File

@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel, Field, model_validator
from pydantic import Field, model_validator
from typing_extensions import Self
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
@@ -37,7 +37,8 @@ class BaseBlock(LettaBase, validate_assignment=True):
@model_validator(mode="after")
def verify_char_limit(self) -> Self:
if self.value and len(self.value) > self.limit:
# self.limit can be None from
if self.limit is not None and self.value and len(self.value) > self.limit:
error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self.value)}) - {str(self)}."
raise ValueError(error_msg)
@@ -89,61 +90,16 @@ class Persona(Block):
label: str = "persona"
# class CreateBlock(BaseBlock):
# """Create a block"""
#
# is_template: bool = True
# label: str = Field(..., description="Label of the block.")
class BlockLabelUpdate(BaseModel):
"""Update the label of a block"""
current_label: str = Field(..., description="Current label of the block.")
new_label: str = Field(..., description="New label of the block.")
# class CreatePersona(CreateBlock):
# """Create a persona block"""
#
# label: str = "persona"
#
#
# class CreateHuman(CreateBlock):
# """Create a human block"""
#
# label: str = "human"
class BlockUpdate(BaseBlock):
"""Update a block"""
limit: Optional[int] = Field(CORE_MEMORY_BLOCK_CHAR_LIMIT, description="Character limit of the block.")
limit: Optional[int] = Field(None, description="Character limit of the block.")
value: Optional[str] = Field(None, description="Value of the block.")
class Config:
extra = "ignore" # Ignores extra fields
class BlockLimitUpdate(BaseModel):
"""Update the limit of a block"""
label: str = Field(..., description="Label of the block.")
limit: int = Field(..., description="New limit of the block.")
# class UpdatePersona(BlockUpdate):
# """Update a persona block"""
#
# label: str = "persona"
#
#
# class UpdateHuman(BlockUpdate):
# """Update a human block"""
#
# label: str = "human"
class CreateBlock(BaseBlock):
"""Create a block"""

View File

@@ -236,6 +236,32 @@ LettaMessageUnion = Annotated[
]
class UpdateSystemMessage(BaseModel):
content: Union[str, List[MessageContentUnion]]
message_type: Literal["system_message"] = "system_message"
class UpdateUserMessage(BaseModel):
content: Union[str, List[MessageContentUnion]]
message_type: Literal["user_message"] = "user_message"
class UpdateReasoningMessage(BaseModel):
reasoning: Union[str, List[MessageContentUnion]]
message_type: Literal["reasoning_message"] = "reasoning_message"
class UpdateAssistantMessage(BaseModel):
content: Union[str, List[MessageContentUnion]]
message_type: Literal["assistant_message"] = "assistant_message"
LettaMessageUpdateUnion = Annotated[
Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage],
Field(discriminator="message_type"),
]
def create_letta_message_union_schema():
return {
"oneOf": [

View File

@@ -74,7 +74,7 @@ class MessageUpdate(BaseModel):
"""Request to update a message"""
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
content: Optional[Union[str, List[MessageContentUnion]]] = Field(..., description="The content of the message.")
content: Optional[Union[str, List[MessageContentUnion]]] = Field(None, description="The content of the message.")
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")

View File

@@ -18,6 +18,7 @@ class Step(StepBase):
job_id: Optional[str] = Field(
None, description="The unique identifier of the job that this step belongs to. Only included for async calls."
)
agent_id: Optional[str] = Field(None, description="The ID of the agent that performed the step.")
provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.")
model: Optional[str] = Field(None, description="The name of the model used for this step.")
model_endpoint: Optional[str] = Field(None, description="The model endpoint url used for this step.")

View File

@@ -70,4 +70,11 @@ class SerializedAgentSchema(BaseSchema):
class Meta(BaseSchema.Meta):
model = Agent
# TODO: Serialize these as well...
exclude = BaseSchema.Meta.exclude + ("sources", "source_passages", "agent_passages")
exclude = BaseSchema.Meta.exclude + (
"project_id",
"template_id",
"base_template_id",
"sources",
"source_passages",
"agent_passages",
)

View File

@@ -24,7 +24,7 @@ logger = get_logger(__name__)
@router.post(
"/chat/completions",
"/{agent_id}/chat/completions",
response_model=None,
operation_id="create_chat_completions",
responses={
@@ -37,6 +37,7 @@ logger = get_logger(__name__)
},
)
async def create_chat_completions(
agent_id: str,
completion_request: CompletionCreateParams = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
@@ -51,12 +52,6 @@ async def create_chat_completions(
actor = server.user_manager.get_user_or_default(user_id=user_id)
agent_id = str(completion_request.get("user", None))
if agent_id is None:
error_msg = "Must pass agent_id in the 'user' field"
logger.error(error_msg)
raise HTTPException(status_code=400, detail=error_msg)
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
llm_config = letta_agent.agent_state.llm_config
if llm_config.model_endpoint_type != "openai" or "inference.memgpt.ai" in llm_config.model_endpoint:

View File

@@ -13,13 +13,12 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.block import Block, BlockUpdate, CreateBlock # , BlockLabelUpdate, BlockLimitUpdate
from letta.schemas.block import Block, BlockUpdate
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessageUnion
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
from letta.schemas.message import Message, MessageUpdate
from letta.schemas.passage import Passage, PassageUpdate
from letta.schemas.run import Run
from letta.schemas.source import Source
@@ -119,6 +118,7 @@ async def upload_agent_serialized(
True,
description="If set to True, existing tools can get their source code overwritten by the uploaded tool definitions. Note that Letta core tools can never be updated externally.",
),
project_id: Optional[str] = Query(None, description="The project ID to associate the uploaded agent with."),
):
"""
Upload a serialized agent JSON file and recreate the agent in the system.
@@ -129,7 +129,11 @@ async def upload_agent_serialized(
serialized_data = await file.read()
agent_json = json.loads(serialized_data)
new_agent = server.agent_manager.deserialize(
serialized_agent=agent_json, actor=actor, append_copy_suffix=append_copy_suffix, override_existing_tools=override_existing_tools
serialized_agent=agent_json,
actor=actor,
append_copy_suffix=append_copy_suffix,
override_existing_tools=override_existing_tools,
project_id=project_id,
)
return new_agent
@@ -526,20 +530,20 @@ def list_messages(
)
@router.patch("/{agent_id}/messages/{message_id}", response_model=Message, operation_id="modify_message")
@router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUpdateUnion, operation_id="modify_message")
def modify_message(
agent_id: str,
message_id: str,
request: MessageUpdate = Body(...),
request: LettaMessageUpdateUnion = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the details of a message associated with an agent.
"""
# TODO: Get rid of agent_id here, it's not really relevant
# TODO: support modifying tool calls/returns
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor)
@router.post(

View File

@@ -20,6 +20,7 @@ def list_steps(
start_date: Optional[str] = Query(None, description='Return steps after this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"),
agent_id: Optional[str] = Query(None, description="Filter by the ID of the agent that performed the step"),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
@@ -42,6 +43,7 @@ def list_steps(
limit=limit,
order=order,
model=model,
agent_id=agent_id,
)

View File

@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Optional
import httpx
import openai
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from fastapi import APIRouter, Body, Depends, Header
from fastapi.responses import StreamingResponse
from openai.types.chat.completion_create_params import CompletionCreateParams
@@ -22,7 +22,7 @@ logger = get_logger(__name__)
@router.post(
"/chat/completions",
"/{agent_id}/chat/completions",
response_model=None,
operation_id="create_voice_chat_completions",
responses={
@@ -35,16 +35,13 @@ logger = get_logger(__name__)
},
)
async def create_voice_chat_completions(
agent_id: str,
completion_request: CompletionCreateParams = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
agent_id = str(completion_request.get("user", None))
if agent_id is None:
raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
# Also parse the user's new input
input_message = UserMessage(**get_messages_from_completion_request(completion_request)[-1])

View File

@@ -358,6 +358,49 @@ class AgentManager:
return [agent.to_pydantic() for agent in agents]
@enforce_types
def list_agents_matching_tags(
    self,
    actor: PydanticUser,
    match_all: List[str],
    match_some: List[str],
    limit: Optional[int] = 50,
) -> List[PydanticAgentState]:
    """
    Retrieves agents in the same organization that match all specified `match_all` tags
    and at least one tag from `match_some`. The query is optimized for efficiency by
    leveraging indexed filtering and aggregation.

    Args:
        actor (PydanticUser): The user requesting the agent list.
        match_all (List[str]): Agents must have all these tags.
        match_some (List[str]): Agents must have at least one of these tags.
        limit (Optional[int]): Maximum number of agents to return.

    Returns:
        List[PydanticAgentState]: The filtered list of matching agents.
    """
    with self.session_maker() as session:
        query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id)

        if match_all:
            # Subquery: agent IDs carrying ALL of the match_all tags.
            # NOTE(review): assumes (agent_id, tag) rows are unique in AgentsTags;
            # otherwise this count should be over distinct tags — confirm schema.
            subquery = (
                select(AgentsTags.agent_id)
                .where(AgentsTags.tag.in_(match_all))
                .group_by(AgentsTags.agent_id)
                .having(func.count(AgentsTags.tag) == literal(len(match_all)))
            )
            query = query.where(AgentModel.id.in_(subquery))

        if match_some:
            # Agents must additionally carry at least one tag from match_some.
            query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some))

        # group_by collapses duplicate rows introduced by the join above.
        query = query.group_by(AgentModel.id).limit(limit)

        # Fix: convert ORM rows to the declared Pydantic return type, matching
        # the convention used by list_agents.
        return [agent.to_pydantic() for agent in session.execute(query).scalars()]
@enforce_types
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Fetch an agent by its ID."""
@@ -401,7 +444,12 @@ class AgentManager:
@enforce_types
def deserialize(
self, serialized_agent: dict, actor: PydanticUser, append_copy_suffix: bool = True, override_existing_tools: bool = True
self,
serialized_agent: dict,
actor: PydanticUser,
append_copy_suffix: bool = True,
override_existing_tools: bool = True,
project_id: Optional[str] = None,
) -> PydanticAgentState:
tool_data_list = serialized_agent.pop("tools", [])
@@ -410,7 +458,9 @@ class AgentManager:
agent = schema.load(serialized_agent, session=session)
if append_copy_suffix:
agent.name += "_copy"
agent.create(session, actor=actor)
if project_id:
agent.project_id = project_id
agent = agent.create(session, actor=actor)
pydantic_agent = agent.to_pydantic()
# Need to do this separately as there's some fancy upsert logic that SqlAlchemy cannot handle
@@ -548,6 +598,7 @@ class AgentManager:
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
recent_passages=self.list_passages(actor=actor, agent_id=agent_id, ascending=False, limit=10),
)
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
@@ -718,7 +769,9 @@ class AgentManager:
# Commit the changes
agent.update(session, actor=actor)
# Add system messsage alert to agent
# Force rebuild of system prompt so that the agent is updated with passage count
# and recent passages and add system message alert to agent
self.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True)
self.append_system_message(
agent_id=agent_id,
content=DATA_SOURCE_ATTACH_ALERT,

View File

@@ -13,6 +13,7 @@ from letta.schemas.agent import AgentState, AgentType
from letta.schemas.enums import MessageRole
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate, TextContent
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.tool_rule import ToolRule
from letta.schemas.user import User
from letta.system import get_initial_boot_messages, get_login_event
@@ -99,7 +100,10 @@ def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
# TODO: This code is kind of wonky and deserves a rewrite
def compile_memory_metadata_block(
memory_edit_timestamp: datetime.datetime, previous_message_count: int = 0, archival_memory_size: int = 0
memory_edit_timestamp: datetime.datetime,
previous_message_count: int = 0,
archival_memory_size: int = 0,
recent_passages: List[PydanticPassage] = None,
) -> str:
# Put the timestamp in the local timezone (mimicking get_local_time())
timestamp_str = memory_edit_timestamp.astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip()
@@ -110,6 +114,11 @@ def compile_memory_metadata_block(
f"### Memory [last modified: {timestamp_str}]",
f"{previous_message_count} previous messages between you and the user are stored in recall memory (use functions to access them)",
f"{archival_memory_size} total memories you created are stored in archival memory (use functions to access them)",
(
f"Most recent archival passages {len(recent_passages)} recent passages: {[passage.text for passage in recent_passages]}"
if recent_passages is not None
else ""
),
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
]
)
@@ -146,6 +155,7 @@ def compile_system_message(
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
previous_message_count: int = 0,
archival_memory_size: int = 0,
recent_passages: Optional[List[PydanticPassage]] = None,
) -> str:
"""Prepare the final/full system message that will be fed into the LLM API
@@ -170,6 +180,7 @@ def compile_system_message(
memory_edit_timestamp=in_context_memory_last_edit,
previous_message_count=previous_message_count,
archival_memory_size=archival_memory_size,
recent_passages=recent_passages,
)
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()

View File

@@ -78,7 +78,13 @@ class IdentityManager:
if existing_identity is None:
return self.create_identity(identity=identity, actor=actor)
else:
identity_update = IdentityUpdate(name=identity.name, identity_type=identity.identity_type, agent_ids=identity.agent_ids)
identity_update = IdentityUpdate(
name=identity.name,
identifier_key=identity.identifier_key,
identity_type=identity.identity_type,
agent_ids=identity.agent_ids,
properties=identity.properties,
)
return self._update_identity(
session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
)

View File

@@ -1,3 +1,4 @@
import json
from typing import List, Optional
from sqlalchemy import and_, or_
@@ -7,6 +8,7 @@ from letta.orm.agent import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import LettaMessageUpdateUnion
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageUpdate
from letta.schemas.user import User as PydanticUser
@@ -64,6 +66,44 @@ class MessageManager:
"""Create multiple messages."""
return [self.create_message(m, actor=actor) for m in pydantic_msgs]
@enforce_types
def update_message_by_letta_message(
    self, message_id: str, letta_message_update: LettaMessageUpdateUnion, actor: PydanticUser
) -> PydanticMessage:
    """
    Update the underlying messages table from an update expressed against the
    user-facing LettaMessage representation, and return the updated message
    converted back to the same LettaMessage type.

    Raises:
        ValueError: If the update's message type is unsupported, or if the
            updated message no longer converts to the original message type.
    """
    message = self.get_message_by_id(message_id=message_id, actor=actor)
    if letta_message_update.message_type == "assistant_message":
        # An AssistantMessage maps onto the arguments of a send_message tool call,
        # so modify the tool call rather than the message content.
        # TODO: fix this if we add parallel tool calls
        # TODO: note this only works if the AssistantMessage is generated by the standard send_message
        assert (
            message.tool_calls[0].function.name == "send_message"
        ), f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
        original_args = json.loads(message.tool_calls[0].function.arguments)
        original_args["message"] = letta_message_update.content  # override the assistant message
        # NOTE(review): direct __deepcopy__() call — presumably returns an
        # independent copy; copy.deepcopy() would be the idiomatic form. Confirm.
        update_tool_call = message.tool_calls[0].__deepcopy__()
        update_tool_call.function.arguments = json.dumps(original_args)
        update_message = MessageUpdate(tool_calls=[update_tool_call])
    elif letta_message_update.message_type == "reasoning_message":
        # Reasoning updates carry their text in `reasoning`, not `content`.
        update_message = MessageUpdate(content=letta_message_update.reasoning)
    elif letta_message_update.message_type == "user_message" or letta_message_update.message_type == "system_message":
        update_message = MessageUpdate(content=letta_message_update.content)
    else:
        raise ValueError(f"Unsupported message type for modification: {letta_message_update.message_type}")

    message = self.update_message_by_id(message_id=message_id, message_update=update_message, actor=actor)

    # convert back to LettaMessage: one stored message can expand to several
    # LettaMessages, so return the one matching the requested type
    for letta_msg in message.to_letta_message(use_assistant_message=True):
        if letta_msg.message_type == letta_message_update.message_type:
            return letta_msg

    # raise error if message type got modified
    raise ValueError(f"Message type got modified: {letta_message_update.message_type}")
@enforce_types
def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
"""

View File

@@ -33,10 +33,15 @@ class StepManager:
limit: Optional[int] = 50,
order: Optional[str] = None,
model: Optional[str] = None,
agent_id: Optional[str] = None,
) -> List[PydanticStep]:
"""List all jobs with optional pagination and status filter."""
with self.session_maker() as session:
filter_kwargs = {"organization_id": actor.organization_id, "model": model}
filter_kwargs = {"organization_id": actor.organization_id}
if model:
filter_kwargs["model"] = model
if agent_id:
filter_kwargs["agent_id"] = agent_id
steps = StepModel.list(
db_session=session,
@@ -54,6 +59,7 @@ class StepManager:
def log_step(
self,
actor: PydanticUser,
agent_id: str,
provider_name: str,
model: str,
model_endpoint: Optional[str],
@@ -65,6 +71,7 @@ class StepManager:
step_data = {
"origin": None,
"organization_id": actor.organization_id,
"agent_id": agent_id,
"provider_id": provider_id,
"provider_name": provider_name,
"model": model,

704
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.6.37"
version = "0.6.38"
packages = [
{include = "letta"},
]
@@ -58,8 +58,8 @@ nltk = "^3.8.1"
jinja2 = "^3.1.5"
locust = {version = "^2.31.5", optional = true}
wikipedia = {version = "^1.4.0", optional = true}
composio-langchain = "^0.7.2"
composio-core = "^0.7.2"
composio-langchain = "^0.7.7"
composio-core = "^0.7.7"
alembic = "^1.13.3"
pyhumps = "^3.8.0"
psycopg2 = {version = "^2.9.10", optional = true}
@@ -98,9 +98,10 @@ qdrant = ["qdrant-client"]
cloud-tool-sandbox = ["e2b-code-interpreter"]
external-tools = ["docker", "langchain", "wikipedia", "langchain-community"]
tests = ["wikipedia"]
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
bedrock = ["boto3"]
google = ["google-genai"]
desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "datasets", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "datamodel-code-generator"]
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "datamodel-code-generator"]
[tool.poetry.group.dev.dependencies]
black = "^24.4.2"

View File

@@ -17,6 +17,7 @@ from letta.embeddings import embedding_model
from letta.errors import InvalidInnerMonologueError, InvalidToolCallError, MissingInnerMonologueError, MissingToolCallError
from letta.helpers.json_helpers import json_dumps
from letta.llm_api.llm_api_tools import create
from letta.llm_api.llm_client import LLMClient
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
@@ -103,12 +104,23 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner
messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state.id, actor=client.user)
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
response = create(
llm_client = LLMClient.create(
agent_id=agent_state.id,
llm_config=agent_state.llm_config,
user_id=str(uuid.UUID(int=1)), # dummy user_id
messages=messages,
functions=[t.json_schema for t in agent.agent_state.tools],
actor_id=str(uuid.UUID(int=1)),
)
if llm_client:
response = llm_client.send_llm_request(
messages=messages,
tools=[t.json_schema for t in agent.agent_state.tools],
)
else:
response = create(
llm_config=agent_state.llm_config,
user_id=str(uuid.UUID(int=1)), # dummy user_id
messages=messages,
functions=[t.json_schema for t in agent.agent_state.tools],
)
# Basic check
assert response is not None, response

View File

@@ -120,12 +120,11 @@ def agent(client, roll_dice_tool, weather_tool, composio_gmail_get_profile_tool)
# --- Helper Functions --- #
def _get_chat_request(agent_id, message, stream=True):
def _get_chat_request(message, stream=True):
"""Returns a chat completion request with streaming enabled."""
return ChatCompletionRequest(
model="gpt-4o-mini",
messages=[UserMessage(content=message)],
user=agent_id,
stream=stream,
)
@@ -157,9 +156,9 @@ def _assert_valid_chunk(chunk, idx, chunks):
@pytest.mark.parametrize("endpoint", ["v1/voice"])
async def test_latency(mock_e2b_api_key_none, client, agent, message, endpoint):
"""Tests chat completion streaming using the Async OpenAI client."""
request = _get_chat_request(agent.id, message)
request = _get_chat_request(message)
async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}", max_retries=0)
async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}/{agent.id}", max_retries=0)
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
async with stream:
async for chunk in stream:
@@ -171,9 +170,9 @@ async def test_latency(mock_e2b_api_key_none, client, agent, message, endpoint):
@pytest.mark.parametrize("endpoint", ["openai/v1", "v1/voice"])
async def test_chat_completions_streaming_openai_client(mock_e2b_api_key_none, client, agent, message, endpoint):
"""Tests chat completion streaming using the Async OpenAI client."""
request = _get_chat_request(agent.id, message)
request = _get_chat_request(message)
async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}", max_retries=0)
async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}/{agent.id}", max_retries=0)
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
received_chunks = 0

View File

@@ -127,54 +127,55 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
def test_send_message_to_agents_with_tags_simple(client):
worker_tags = ["worker", "user-456"]
worker_tags_123 = ["worker", "user-123"]
worker_tags_456 = ["worker", "user-456"]
# Clean up first from possibly failed tests
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags, match_all_tags=True)
prev_worker_agents = client.server.agent_manager.list_agents(
client.user, tags=list(set(worker_tags_123 + worker_tags_456)), match_all_tags=True
)
for agent in prev_worker_agents:
client.delete_agent(agent.id)
secret_word = "banana"
# Create "manager" agent
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 3 non-matching worker agents (These should NOT get the message)
worker_agents = []
worker_tags = ["worker", "user-123"]
worker_agents_123 = []
for _ in range(3):
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_123)
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents.append(worker_agent)
worker_agents_123.append(worker_agent)
# Create 3 worker agents that should get the message
worker_agents = []
worker_tags = ["worker", "user-456"]
worker_agents_456 = []
for _ in range(3):
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags)
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_456)
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents.append(worker_agent)
worker_agents_456.append(worker_agent)
# Encourage the manager to send a message to the other agent_obj with the secret string
response = client.send_message(
agent_id=manager_agent.agent_state.id,
role="user",
message=f"Send a message to all agents with tags {worker_tags} informing them of the secret word: {secret_word}!",
message=f"Send a message to all agents with tags {worker_tags_456} informing them of the secret word: {secret_word}!",
)
for m in response.messages:
if isinstance(m, ToolReturnMessage):
tool_response = eval(json.loads(m.tool_return)["message"])
print(f"\n\nManager agent tool response: \n{tool_response}\n\n")
assert len(tool_response) == len(worker_agents)
assert len(tool_response) == len(worker_agents_456)
# We can break after this, the ToolReturnMessage after is not related
break
# Conversation search the worker agents
for agent in worker_agents:
for agent in worker_agents_456:
messages = client.get_messages(agent.agent_state.id)
# Check for the presence of system message
for m in reversed(messages):
@@ -183,13 +184,22 @@ def test_send_message_to_agents_with_tags_simple(client):
assert secret_word in m.content
break
# Ensure it's NOT in the non matching worker agents
for agent in worker_agents_123:
messages = client.get_messages(agent.agent_state.id)
# Check for the presence of system message
for m in reversed(messages):
print(f"\n\n {agent.agent_state.id} -> {m.model_dump_json(indent=4)}")
if isinstance(m, SystemMessage):
assert secret_word not in m.content
# Test that the agent can still receive messages fine
response = client.send_message(agent_id=manager_agent.agent_state.id, role="user", message="So what did the other agents say?")
print("Manager agent followup message: \n\n" + "\n".join([str(m) for m in response.messages]))
# Clean up agents
client.delete_agent(manager_agent_state.id)
for agent in worker_agents:
for agent in worker_agents_456 + worker_agents_123:
client.delete_agent(agent.agent_state.id)
@@ -203,8 +213,8 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
client.delete_agent(agent.id)
# Create "manager" agent
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 3 worker agents
@@ -245,8 +255,8 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
# Create "manager" agent
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id])
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id])
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 2 worker agents
@@ -260,7 +270,7 @@ def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
worker_agents.append(worker_agent)
# Encourage the manager to send a message to the other agent_obj with the secret string
broadcast_message = f"Using your tool named `send_message_to_agents_matching_all_tags`, instruct all agents with tags {worker_tags} to `core_memory_append` the topic of the day: bananas!"
broadcast_message = f"Using your tool named `send_message_to_agents_matching_tags`, instruct all agents with tags {worker_tags} to `core_memory_append` the topic of the day: bananas!"
client.send_message(
agent_id=manager_agent.agent_state.id,
role="user",

View File

@@ -65,10 +65,8 @@ def test_multi_agent_large(client, roll_dice_tool, num_workers):
client.delete_agent(agent.id)
# Create "manager" agent
send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags")
manager_agent_state = client.create_agent(
name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id], tags=manager_tags
)
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id], tags=manager_tags)
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 3 worker agents

View File

@@ -229,7 +229,7 @@ def _compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log:
- Datetime fields are ignored.
- Order-independent comparison for lists of dicts.
"""
ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id", "agent_id"}
ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id", "agent_id", "project_id"}
# Remove datetime fields upfront
d1 = strip_datetime_fields(d1)
@@ -476,8 +476,9 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server,
# FastAPI endpoint tests
@pytest.mark.parametrize("append_copy_suffix", [True])
def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix):
@pytest.mark.parametrize("append_copy_suffix", [True, False])
@pytest.mark.parametrize("project_id", ["project-12345", None])
def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent, default_user, other_user, append_copy_suffix, project_id):
"""
Test the full E2E serialization and deserialization flow using FastAPI endpoints.
"""
@@ -495,7 +496,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
upload_response = fastapi_client.post(
"/v1/agents/upload",
headers={"user_id": other_user.id},
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False},
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id},
files=files,
)
assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}"
@@ -504,7 +505,8 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
copied_agent = upload_response.json()
copied_agent_id = copied_agent["id"]
assert copied_agent_id != agent_id, "Copied agent should have a different ID"
assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
if append_copy_suffix:
assert copied_agent["name"] == serialize_test_agent.name + "_copy", "Copied agent name should have '_copy' suffix"
# Step 3: Retrieve the copied agent
serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user)

View File

@@ -26,7 +26,7 @@ from letta.schemas.letta_message import (
ToolReturnMessage,
UserMessage,
)
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.letta_response import LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.usage import LettaUsageStatistics
@@ -536,21 +536,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
client.delete_source(source.id)
def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test that we can update the details of a message"""
# create a message
message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
print("Messages=", message_response)
assert isinstance(message_response, LettaResponse)
assert isinstance(message_response.messages[-1], AssistantMessage)
message = message_response.messages[-1]
new_text = "this is a secret message"
new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id)
assert new_message.text == new_text
def test_organization(client: RESTClient):
if isinstance(client, LocalClient):
pytest.skip("Skipping test_organization because LocalClient does not support organizations")

View File

@@ -21,8 +21,10 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate
from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate, MessageUpdate
@@ -40,6 +42,7 @@ from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate
from letta.server.server import SyncServer
from letta.services.block_manager import BlockManager
from letta.services.identity_manager import IdentityManager
from letta.services.organization_manager import OrganizationManager
from letta.settings import tool_settings
from tests.helpers.utils import comprehensive_agent_checks
@@ -472,6 +475,45 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
server.source_manager.delete_source(default_source.id, actor=actor)
@pytest.fixture
def agent_with_tags(server: SyncServer, default_user):
"""Fixture to create agents with specific tags."""
agent1 = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="agent1",
tags=["primary_agent", "benefit_1"],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
),
actor=default_user,
)
agent2 = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="agent2",
tags=["primary_agent", "benefit_2"],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
),
actor=default_user,
)
agent3 = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="agent3",
tags=["primary_agent", "benefit_1", "benefit_2"],
llm_config=LLMConfig.default_config("gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
memory_blocks=[],
),
actor=default_user,
)
return [agent1, agent2, agent3]
# ======================================================================================================================
# AgentManager Tests - Basic
# ======================================================================================================================
@@ -775,6 +817,45 @@ def test_list_attached_agents_nonexistent_source(server: SyncServer, default_use
# ======================================================================================================================
def test_list_agents_matching_all_tags(server: SyncServer, default_user, agent_with_tags):
agents = server.agent_manager.list_agents_matching_tags(
actor=default_user,
match_all=["primary_agent", "benefit_1"],
match_some=[],
)
assert len(agents) == 2 # agent1 and agent3 match
assert {a.name for a in agents} == {"agent1", "agent3"}
def test_list_agents_matching_some_tags(server: SyncServer, default_user, agent_with_tags):
agents = server.agent_manager.list_agents_matching_tags(
actor=default_user,
match_all=["primary_agent"],
match_some=["benefit_1", "benefit_2"],
)
assert len(agents) == 3 # All agents match
assert {a.name for a in agents} == {"agent1", "agent2", "agent3"}
def test_list_agents_matching_all_and_some_tags(server: SyncServer, default_user, agent_with_tags):
agents = server.agent_manager.list_agents_matching_tags(
actor=default_user,
match_all=["primary_agent", "benefit_1"],
match_some=["benefit_2", "nonexistent"],
)
assert len(agents) == 1 # Only agent3 matches
assert agents[0].name == "agent3"
def test_list_agents_matching_no_tags(server: SyncServer, default_user, agent_with_tags):
agents = server.agent_manager.list_agents_matching_tags(
actor=default_user,
match_all=["primary_agent", "nonexistent_tag"],
match_some=["benefit_1", "benefit_2"],
)
assert len(agents) == 0 # No agent should match
def test_list_agents_by_tags_match_all(server: SyncServer, sarah_agent, charles_agent, default_user):
"""Test listing agents that have ALL specified tags."""
# Create agents with multiple tags
@@ -1073,6 +1154,73 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use
assert server.message_manager.size(agent_id=sarah_agent.id, actor=default_user) == 1
def test_modify_letta_message(server: SyncServer, sarah_agent, default_user):
"""
Test updating a message.
"""
messages = server.message_manager.list_messages_for_agent(agent_id=sarah_agent.id, actor=default_user)
letta_messages = PydanticMessage.to_letta_messages_from_list(messages=messages)
system_message = [msg for msg in letta_messages if msg.message_type == "system_message"][0]
assistant_message = [msg for msg in letta_messages if msg.message_type == "assistant_message"][0]
user_message = [msg for msg in letta_messages if msg.message_type == "user_message"][0]
reasoning_message = [msg for msg in letta_messages if msg.message_type == "reasoning_message"][0]
# user message
update_user_message = UpdateUserMessage(content="Hello, Sarah!")
original_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user)
assert original_user_message.content[0].text != update_user_message.content
server.message_manager.update_message_by_letta_message(
message_id=user_message.id, letta_message_update=update_user_message, actor=default_user
)
updated_user_message = server.message_manager.get_message_by_id(message_id=user_message.id, actor=default_user)
assert updated_user_message.content[0].text == update_user_message.content
# system message
update_system_message = UpdateSystemMessage(content="You are a friendly assistant!")
original_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user)
assert original_system_message.content[0].text != update_system_message.content
server.message_manager.update_message_by_letta_message(
message_id=system_message.id, letta_message_update=update_system_message, actor=default_user
)
updated_system_message = server.message_manager.get_message_by_id(message_id=system_message.id, actor=default_user)
assert updated_system_message.content[0].text == update_system_message.content
# reasoning message
update_reasoning_message = UpdateReasoningMessage(reasoning="I am thinking")
original_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user)
assert original_reasoning_message.content[0].text != update_reasoning_message.reasoning
server.message_manager.update_message_by_letta_message(
message_id=reasoning_message.id, letta_message_update=update_reasoning_message, actor=default_user
)
updated_reasoning_message = server.message_manager.get_message_by_id(message_id=reasoning_message.id, actor=default_user)
assert updated_reasoning_message.content[0].text == update_reasoning_message.reasoning
# assistant message
def parse_send_message(tool_call):
import json
function_call = tool_call.function
arguments = json.loads(function_call.arguments)
return arguments["message"]
update_assistant_message = UpdateAssistantMessage(content="I am an agent!")
original_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user)
print("ORIGINAL", original_assistant_message.tool_calls)
print("MESSAGE", parse_send_message(original_assistant_message.tool_calls[0]))
assert parse_send_message(original_assistant_message.tool_calls[0]) != update_assistant_message.content
server.message_manager.update_message_by_letta_message(
message_id=assistant_message.id, letta_message_update=update_assistant_message, actor=default_user
)
updated_assistant_message = server.message_manager.get_message_by_id(message_id=assistant_message.id, actor=default_user)
print("UPDATED", updated_assistant_message.tool_calls)
print("MESSAGE", parse_send_message(updated_assistant_message.tool_calls[0]))
assert parse_send_message(updated_assistant_message.tool_calls[0]) == update_assistant_message.content
# TODO: tool calls/responses
# ======================================================================================================================
# AgentManager Tests - Blocks Relationship
# ======================================================================================================================
@@ -2001,28 +2149,42 @@ def test_update_block(server: SyncServer, default_user):
def test_update_block_limit(server: SyncServer, default_user):
block_manager = BlockManager()
block = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content"), actor=default_user)
limit = len("Updated Content") * 2000
update_data = BlockUpdate(value="Updated Content" * 2000, description="Updated description", limit=limit)
update_data = BlockUpdate(value="Updated Content" * 2000, description="Updated description")
# Check that a large block fails
try:
# Check that exceeding the block limit raises an exception
with pytest.raises(ValueError):
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
assert False
except Exception:
pass
# Ensure the update works when within limits
update_data = BlockUpdate(value="Updated Content" * 2000, description="Updated description", limit=limit)
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
# Retrieve the updated block
# Retrieve the updated block and validate the update
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
# Assertions to verify the update
assert updated_block.value == "Updated Content" * 2000
assert updated_block.description == "Updated description"
def test_update_block_limit_does_not_reset(server: SyncServer, default_user):
block_manager = BlockManager()
new_content = "Updated Content" * 2000
limit = len(new_content)
block = block_manager.create_or_update_block(PydanticBlock(label="persona", value="Original Content", limit=limit), actor=default_user)
# Ensure the update works
update_data = BlockUpdate(value=new_content)
block_manager.update_block(block_id=block.id, block_update=update_data, actor=default_user)
# Retrieve the updated block and validate the update
updated_block = block_manager.get_blocks(actor=default_user, id=block.id)[0]
assert updated_block.value == new_content
def test_delete_block(server: SyncServer, default_user):
block_manager = BlockManager()
@@ -2075,6 +2237,154 @@ def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, de
assert charles_agent.id in agent_state_ids
# ======================================================================================================================
# Identity Manager Tests
# ======================================================================================================================
def test_create_and_upsert_identity(server: SyncServer, default_user):
identity_manager = IdentityManager()
identity_create = IdentityCreate(
identifier_key="1234",
name="caren",
identity_type=IdentityType.user,
properties=[
IdentityProperty(key="email", value="caren@letta.com", type=IdentityPropertyType.string),
IdentityProperty(key="age", value=28, type=IdentityPropertyType.number),
],
)
identity = identity_manager.create_identity(identity_create, actor=default_user)
# Assertions to ensure the created identity matches the expected values
assert identity.identifier_key == identity_create.identifier_key
assert identity.name == identity_create.name
assert identity.identity_type == identity_create.identity_type
assert identity.properties == identity_create.properties
assert identity.agent_ids == []
assert identity.project_id == None
with pytest.raises(UniqueConstraintViolationError):
identity_manager.create_identity(
IdentityCreate(identifier_key="1234", name="sarah", identity_type=IdentityType.user),
actor=default_user,
)
identity_create.properties = [(IdentityProperty(key="age", value=29, type=IdentityPropertyType.number))]
identity = identity_manager.upsert_identity(identity_create, actor=default_user)
identity = identity_manager.get_identity(identity_id=identity.id, actor=default_user)
assert len(identity.properties) == 1
assert identity.properties[0].key == "age"
assert identity.properties[0].value == 29
identity_manager.delete_identity(identity.id, actor=default_user)
def test_get_identities(server, default_user):
identity_manager = IdentityManager()
# Create identities to retrieve later
user = identity_manager.create_identity(
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user
)
org = identity_manager.create_identity(
IdentityCreate(name="letta", identifier_key="0001", identity_type=IdentityType.org), actor=default_user
)
# Retrieve identities by different filters
all_identities = identity_manager.list_identities(actor=default_user)
assert len(all_identities) == 2
user_identities = identity_manager.list_identities(actor=default_user, identity_type=IdentityType.user)
assert len(user_identities) == 1
assert user_identities[0].name == user.name
org_identities = identity_manager.list_identities(actor=default_user, identity_type=IdentityType.org)
assert len(org_identities) == 1
assert org_identities[0].name == org.name
identity_manager.delete_identity(user.id, actor=default_user)
identity_manager.delete_identity(org.id, actor=default_user)
def test_update_identity(server: SyncServer, sarah_agent, charles_agent, default_user):
identity = server.identity_manager.create_identity(
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user
)
# Update identity fields
update_data = IdentityUpdate(
agent_ids=[sarah_agent.id, charles_agent.id],
properties=[IdentityProperty(key="email", value="caren@letta.com", type=IdentityPropertyType.string)],
)
server.identity_manager.update_identity(identity_id=identity.id, identity=update_data, actor=default_user)
# Retrieve the updated identity
updated_identity = server.identity_manager.get_identity(identity_id=identity.id, actor=default_user)
# Assertions to verify the update
assert updated_identity.agent_ids.sort() == update_data.agent_ids.sort()
assert updated_identity.properties == update_data.properties
agent_state = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user)
assert identity.id in agent_state.identity_ids
agent_state = server.agent_manager.get_agent_by_id(agent_id=charles_agent.id, actor=default_user)
assert identity.id in agent_state.identity_ids
server.identity_manager.delete_identity(identity.id, actor=default_user)
def test_attach_detach_identity_from_agent(server: SyncServer, sarah_agent, default_user):
# Create an identity
identity = server.identity_manager.create_identity(
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user), actor=default_user
)
agent_state = server.agent_manager.update_agent(
agent_id=sarah_agent.id, agent_update=UpdateAgent(identity_ids=[identity.id]), actor=default_user
)
# Check that identity has been attached
assert identity.id in agent_state.identity_ids
# Now attempt to delete the identity
server.identity_manager.delete_identity(identity_id=identity.id, actor=default_user)
# Verify that the identity was deleted
identities = server.identity_manager.list_identities(actor=default_user)
assert len(identities) == 0
# Check that block has been detached too
agent_state = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user)
assert not identity.id in agent_state.identity_ids
def test_get_agents_for_identities(server: SyncServer, sarah_agent, charles_agent, default_user):
identity = server.identity_manager.create_identity(
IdentityCreate(name="caren", identifier_key="1234", identity_type=IdentityType.user, agent_ids=[sarah_agent.id, charles_agent.id]),
actor=default_user,
)
# Get the agents for identity id
agent_states = server.agent_manager.list_agents(identifier_id=identity.id, actor=default_user)
assert len(agent_states) == 2
# Check both agents are in the list
agent_state_ids = [a.id for a in agent_states]
assert sarah_agent.id in agent_state_ids
assert charles_agent.id in agent_state_ids
# Get the agents for identifier key
agent_states = server.agent_manager.list_agents(identifier_keys=[identity.identifier_key], actor=default_user)
assert len(agent_states) == 2
# Check both agents are in the list
agent_state_ids = [a.id for a in agent_states]
assert sarah_agent.id in agent_state_ids
assert charles_agent.id in agent_state_ids
# ======================================================================================================================
# SourceManager Tests - Sources
# ======================================================================================================================
@@ -3095,13 +3405,14 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
# ======================================================================================================================
def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_user):
def test_job_usage_stats_add_and_get(server: SyncServer, sarah_agent, default_job, default_user):
"""Test adding and retrieving job usage statistics."""
job_manager = server.job_manager
step_manager = server.step_manager
# Add usage statistics
step_manager.log_step(
agent_id=sarah_agent.id,
provider_name="openai",
model="gpt-4",
model_endpoint="https://api.openai.com/v1",
@@ -3145,13 +3456,14 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u
assert len(steps) == 0
def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user):
def test_job_usage_stats_add_multiple(server: SyncServer, sarah_agent, default_job, default_user):
"""Test adding multiple usage statistics entries for a job."""
job_manager = server.job_manager
step_manager = server.step_manager
# Add first usage statistics entry
step_manager.log_step(
agent_id=sarah_agent.id,
provider_name="openai",
model="gpt-4",
model_endpoint="https://api.openai.com/v1",
@@ -3167,6 +3479,7 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
# Add second usage statistics entry
step_manager.log_step(
agent_id=sarah_agent.id,
provider_name="openai",
model="gpt-4",
model_endpoint="https://api.openai.com/v1",
@@ -3193,6 +3506,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user)
assert len(steps) == 2
# get agent steps
steps = step_manager.list_steps(agent_id=sarah_agent.id, actor=default_user)
assert len(steps) == 2
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
"""Test getting usage statistics for a nonexistent job."""
@@ -3202,12 +3519,13 @@ def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
job_manager.get_job_usage(job_id="nonexistent_job", actor=default_user)
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user):
def test_job_usage_stats_add_nonexistent_job(server: SyncServer, sarah_agent, default_user):
"""Test adding usage statistics for a nonexistent job."""
step_manager = server.step_manager
with pytest.raises(NoResultFound):
step_manager.log_step(
agent_id=sarah_agent.id,
provider_name="openai",
model="gpt-4",
model_endpoint="https://api.openai.com/v1",