feat: fix archival stats and uvicorn env vars (#2486)
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
"""add content parts to message
|
||||
|
||||
Revision ID: 2cceb07c2384
|
||||
Revises: 77de976590ae
|
||||
Create Date: 2025-03-13 14:30:53.177061
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from letta.orm.custom_columns import MessageContentColumn
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "2cceb07c2384"
|
||||
down_revision: Union[str, None] = "77de976590ae"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("messages", sa.Column("content", MessageContentColumn(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("messages", "content")
|
||||
# ### end Alembic commands ###
|
||||
62
alembic/versions/77de976590ae_add_groups_for_multi_agent.py
Normal file
62
alembic/versions/77de976590ae_add_groups_for_multi_agent.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""add groups for multi agent
|
||||
|
||||
Revision ID: 77de976590ae
|
||||
Revises: 167491cfb7a8
|
||||
Create Date: 2025-03-12 14:01:58.034385
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "77de976590ae"
|
||||
down_revision: Union[str, None] = "167491cfb7a8"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"groups",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=False),
|
||||
sa.Column("manager_type", sa.String(), nullable=False),
|
||||
sa.Column("manager_agent_id", sa.String(), nullable=True),
|
||||
sa.Column("termination_token", sa.String(), nullable=True),
|
||||
sa.Column("max_turns", sa.Integer(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["manager_agent_id"], ["agents.id"], ondelete="RESTRICT"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"groups_agents",
|
||||
sa.Column("group_id", sa.String(), nullable=False),
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["group_id"], ["groups.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("group_id", "agent_id"),
|
||||
)
|
||||
op.add_column("messages", sa.Column("group_id", sa.String(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("messages", "group_id")
|
||||
op.drop_table("groups_agents")
|
||||
op.drop_table("groups")
|
||||
# ### end Alembic commands ###
|
||||
@@ -39,7 +39,8 @@ from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageContentType, MessageRole
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||
from letta.schemas.message import Message, ToolReturn
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
@@ -95,6 +96,7 @@ class Agent(BaseAgent):
|
||||
first_message_verify_mono: bool = True, # TODO move to config?
|
||||
# MCP sessions, state held in-memory in the server
|
||||
mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
|
||||
save_last_response: bool = False,
|
||||
):
|
||||
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
|
||||
# Hold a copy of the state that was used to init the agent
|
||||
@@ -149,6 +151,10 @@ class Agent(BaseAgent):
|
||||
# Load last function response from message history
|
||||
self.last_function_response = self.load_last_function_response()
|
||||
|
||||
# Save last responses in memory
|
||||
self.save_last_response = save_last_response
|
||||
self.last_response_messages = []
|
||||
|
||||
# Logger that the Agent specifically can use, will also report the agent_state ID with the logs
|
||||
self.logger = get_logger(agent_state.id)
|
||||
|
||||
@@ -160,7 +166,7 @@ class Agent(BaseAgent):
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
for i in range(len(in_context_messages) - 1, -1, -1):
|
||||
msg = in_context_messages[i]
|
||||
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and msg.content[0].type == MessageContentType.text:
|
||||
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent):
|
||||
text_content = msg.content[0].text
|
||||
try:
|
||||
response_json = json.loads(text_content)
|
||||
@@ -926,6 +932,9 @@ class Agent(BaseAgent):
|
||||
else:
|
||||
all_new_messages = all_response_messages
|
||||
|
||||
if self.save_last_response:
|
||||
self.last_response_messages = all_response_messages
|
||||
|
||||
# Check the memory pressure and potentially issue a memory pressure warning
|
||||
current_total_tokens = response.usage.total_tokens
|
||||
active_memory_warning = False
|
||||
@@ -1052,6 +1061,7 @@ class Agent(BaseAgent):
|
||||
|
||||
else:
|
||||
logger.error(f"step() failed with an unrecognized exception: '{str(e)}'")
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse:
|
||||
@@ -1201,7 +1211,7 @@ class Agent(BaseAgent):
|
||||
and in_context_messages[1].role == MessageRole.user
|
||||
and in_context_messages[1].content
|
||||
and len(in_context_messages[1].content) == 1
|
||||
and in_context_messages[1].content[0].type == MessageContentType.text
|
||||
and isinstance(in_context_messages[1].content[0], TextContent)
|
||||
# TODO remove hardcoding
|
||||
and "The following is a summary of the previous " in in_context_messages[1].content[0].text
|
||||
):
|
||||
|
||||
@@ -5,7 +5,8 @@ import openai
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import TextContent, UserMessage
|
||||
from letta.schemas.letta_message import UserMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.user import User
|
||||
|
||||
@@ -40,6 +40,7 @@ from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.services.summarizer.summarizer import Summarizer
|
||||
from letta.utils import united_diff
|
||||
@@ -75,6 +76,7 @@ class LowLatencyAgent(BaseAgent):
|
||||
# TODO: Make this more general, factorable
|
||||
# Summarizer settings
|
||||
self.block_manager = block_manager
|
||||
self.passage_manager = PassageManager() # TODO: pass this in
|
||||
# TODO: This is not guaranteed to exist!
|
||||
self.summary_block_label = "human"
|
||||
self.summarizer = Summarizer(
|
||||
@@ -246,10 +248,16 @@ class LowLatencyAgent(BaseAgent):
|
||||
return in_context_messages
|
||||
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
|
||||
num_messages = self.message_manager.size(actor=actor, agent_id=agent_id)
|
||||
num_archival_memories = self.passage_manager.size(actor=actor, agent_id=agent_id)
|
||||
|
||||
new_system_message_str = compile_system_message(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
previous_message_count=num_messages,
|
||||
archival_memory_size=num_archival_memories,
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_text, new_system_message_str)
|
||||
|
||||
274
letta/dynamic_multi_agent.py
Normal file
274
letta/dynamic_multi_agent.py
Normal file
@@ -0,0 +1,274 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.agent import Agent, AgentState
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.tool_manager import ToolManager
|
||||
|
||||
|
||||
class DynamicMultiAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
interface: AgentInterface,
|
||||
agent_state: AgentState,
|
||||
user: User = None,
|
||||
# custom
|
||||
group_id: str = "",
|
||||
agent_ids: List[str] = [],
|
||||
description: str = "",
|
||||
max_turns: Optional[int] = None,
|
||||
termination_token: str = "DONE!",
|
||||
):
|
||||
super().__init__(interface, agent_state, user)
|
||||
self.group_id = group_id
|
||||
self.agent_ids = agent_ids
|
||||
self.description = description
|
||||
self.max_turns = max_turns or len(agent_ids)
|
||||
self.termination_token = termination_token
|
||||
|
||||
self.tool_manager = ToolManager()
|
||||
|
||||
def step(
|
||||
self,
|
||||
messages: List[MessageCreate],
|
||||
chaining: bool = True,
|
||||
max_chaining_steps: Optional[int] = None,
|
||||
put_inner_thoughts_first: bool = True,
|
||||
**kwargs,
|
||||
) -> LettaUsageStatistics:
|
||||
total_usage = UsageStatistics()
|
||||
step_count = 0
|
||||
|
||||
token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
|
||||
metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
|
||||
|
||||
agents = {}
|
||||
message_index = {self.agent_state.id: 0}
|
||||
agents[self.agent_state.id] = self.load_manager_agent()
|
||||
for agent_id in self.agent_ids:
|
||||
agents[agent_id] = self.load_participant_agent(agent_id=agent_id)
|
||||
message_index[agent_id] = 0
|
||||
|
||||
chat_history: List[Message] = []
|
||||
new_messages = messages
|
||||
speaker_id = None
|
||||
try:
|
||||
for _ in range(self.max_turns):
|
||||
agent_id_options = [agent_id for agent_id in self.agent_ids if agent_id != speaker_id]
|
||||
manager_message = self.ask_manager_to_choose_participant_message(new_messages, chat_history, agent_id_options)
|
||||
manager_agent = agents[self.agent_state.id]
|
||||
usage_stats = manager_agent.step(
|
||||
messages=[manager_message],
|
||||
chaining=chaining,
|
||||
max_chaining_steps=max_chaining_steps,
|
||||
stream=token_streaming,
|
||||
skip_verify=True,
|
||||
metadata=metadata,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
responses = Message.to_letta_messages_from_list(manager_agent.last_response_messages)
|
||||
assistant_message = [response for response in responses if response.message_type == "assistant_message"][0]
|
||||
for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]:
|
||||
if name.lower() in assistant_message.content.lower():
|
||||
speaker_id = agent_id
|
||||
|
||||
# sum usage
|
||||
total_usage.prompt_tokens += usage_stats.prompt_tokens
|
||||
total_usage.completion_tokens += usage_stats.completion_tokens
|
||||
total_usage.total_tokens += usage_stats.total_tokens
|
||||
step_count += 1
|
||||
|
||||
# initialize input messages
|
||||
for message in chat_history[message_index[speaker_id] :]:
|
||||
message.id = Message.generate_id()
|
||||
message.agent_id = speaker_id
|
||||
|
||||
for message in new_messages:
|
||||
chat_history.append(
|
||||
Message(
|
||||
agent_id=speaker_id,
|
||||
role=message.role,
|
||||
content=[TextContent(text=message.content)],
|
||||
name=message.name,
|
||||
model=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
)
|
||||
)
|
||||
|
||||
# load agent and perform step
|
||||
participant_agent = agents[speaker_id]
|
||||
usage_stats = participant_agent.step(
|
||||
messages=chat_history[message_index[speaker_id] :],
|
||||
chaining=chaining,
|
||||
max_chaining_steps=max_chaining_steps,
|
||||
stream=token_streaming,
|
||||
skip_verify=True,
|
||||
metadata=metadata,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
|
||||
# parse new messages for next step
|
||||
responses = Message.to_letta_messages_from_list(
|
||||
participant_agent.last_response_messages,
|
||||
)
|
||||
|
||||
assistant_messages = [response for response in responses if response.message_type == "assistant_message"]
|
||||
new_messages = [
|
||||
MessageCreate(
|
||||
role="system",
|
||||
content=message.content,
|
||||
name=participant_agent.agent_state.name,
|
||||
)
|
||||
for message in assistant_messages
|
||||
]
|
||||
message_index[agent_id] = len(chat_history) + len(new_messages)
|
||||
|
||||
# sum usage
|
||||
total_usage.prompt_tokens += usage_stats.prompt_tokens
|
||||
total_usage.completion_tokens += usage_stats.completion_tokens
|
||||
total_usage.total_tokens += usage_stats.total_tokens
|
||||
step_count += 1
|
||||
|
||||
# check for termination token
|
||||
if any(self.termination_token in message.content for message in new_messages):
|
||||
break
|
||||
|
||||
# persist remaining chat history
|
||||
for message in new_messages:
|
||||
chat_history.append(
|
||||
Message(
|
||||
agent_id=agent_id,
|
||||
role=message.role,
|
||||
content=[TextContent(text=message.content)],
|
||||
name=message.name,
|
||||
model=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
)
|
||||
)
|
||||
for agent_id, index in message_index.items():
|
||||
if agent_id == speaker_id:
|
||||
continue
|
||||
for message in chat_history[index:]:
|
||||
message.id = Message.generate_id()
|
||||
message.agent_id = agent_id
|
||||
self.message_manager.create_many_messages(chat_history[index:], actor=self.user)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
self.interface.step_yield()
|
||||
|
||||
self.interface.step_complete()
|
||||
|
||||
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
||||
|
||||
def load_manager_agent(self) -> Agent:
|
||||
for participant_agent_id in self.agent_ids:
|
||||
participant_agent_state = self.agent_manager.get_agent_by_id(agent_id=participant_agent_id, actor=self.user)
|
||||
participant_persona_block = participant_agent_state.memory.get_block(label="persona")
|
||||
new_block = self.block_manager.create_or_update_block(
|
||||
block=Block(
|
||||
label=participant_agent_id,
|
||||
value=participant_persona_block.value,
|
||||
),
|
||||
actor=self.user,
|
||||
)
|
||||
self.agent_state = self.agent_manager.update_block_with_label(
|
||||
agent_id=self.agent_state.id,
|
||||
block_label=participant_agent_id,
|
||||
new_block_id=new_block.id,
|
||||
actor=self.user,
|
||||
)
|
||||
|
||||
persona_block = self.agent_state.memory.get_block(label="persona")
|
||||
group_chat_manager_persona = (
|
||||
f"You are overseeing a group chat with {len(self.agent_ids) - 1} agents and "
|
||||
f"one user. Description of the group: {self.description}\n"
|
||||
"On each turn, you will be provided with the chat history and latest message. "
|
||||
"Your task is to decide which participant should speak next in the chat based "
|
||||
"on the chat history. Each agent has a memory block labeled with their ID which "
|
||||
"holds info about them, and you should use this context to inform your decision."
|
||||
)
|
||||
self.agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_manager_persona)
|
||||
return Agent(
|
||||
agent_state=self.agent_state,
|
||||
interface=self.interface,
|
||||
user=self.user,
|
||||
save_last_response=True,
|
||||
)
|
||||
|
||||
def load_participant_agent(self, agent_id: str) -> Agent:
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
|
||||
persona_block = agent_state.memory.get_block(label="persona")
|
||||
group_chat_participant_persona = (
|
||||
f"You are a participant in a group chat with {len(self.agent_ids) - 1} other "
|
||||
"agents and one user. Respond to new messages in the group chat when prompted. "
|
||||
f"Description of the group: {self.description}. About you: "
|
||||
)
|
||||
agent_state.memory.update_block_value(label="persona", value=group_chat_participant_persona + persona_block.value)
|
||||
return Agent(
|
||||
agent_state=agent_state,
|
||||
interface=self.interface,
|
||||
user=self.user,
|
||||
save_last_response=True,
|
||||
)
|
||||
|
||||
'''
|
||||
def attach_choose_next_participant_tool(self) -> AgentState:
|
||||
def choose_next_participant(next_speaker_agent_id: str) -> str:
|
||||
"""
|
||||
Returns ID of the agent in the group chat that should reply to the latest message in the conversation. The agent ID will always be in the format: `agent-{UUID}`.
|
||||
Args:
|
||||
next_speaker_agent_id (str): The ID of the agent that is most suitable to be the next speaker.
|
||||
Returns:
|
||||
str: The ID of the agent that should be the next speaker.
|
||||
"""
|
||||
return next_speaker_agent_id
|
||||
source_code = parse_source_code(choose_next_participant)
|
||||
tool = self.tool_manager.create_or_update_tool(
|
||||
Tool(
|
||||
source_type="python",
|
||||
source_code=source_code,
|
||||
name="choose_next_participant",
|
||||
),
|
||||
actor=self.user,
|
||||
)
|
||||
return self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=tool.id, actor=self.user)
|
||||
'''
|
||||
|
||||
def ask_manager_to_choose_participant_message(
|
||||
self,
|
||||
new_messages: List[MessageCreate],
|
||||
chat_history: List[Message],
|
||||
agent_id_options: List[str],
|
||||
) -> Message:
|
||||
chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history]
|
||||
for message in new_messages:
|
||||
chat_history.append(f"{message.name or 'user'}: {message.content}")
|
||||
context_messages = "\n".join(chat_history)
|
||||
|
||||
message_text = (
|
||||
"Choose the most suitable agent to reply to the latest message in the "
|
||||
f"group chat from the following options: {agent_id_options}. Do not "
|
||||
"respond to the messages yourself, your task is only to decide the "
|
||||
f"next speaker, not to participate. \nChat history:\n{context_messages}"
|
||||
)
|
||||
return Message(
|
||||
agent_id=self.agent_state.id,
|
||||
role="user",
|
||||
content=[TextContent(text=message_text)],
|
||||
name=None,
|
||||
model=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
)
|
||||
@@ -77,6 +77,7 @@ def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
|
||||
text=content,
|
||||
actor=self.user,
|
||||
)
|
||||
self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user, force=True)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ import requests
|
||||
from letta.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.schemas.message import Message, TextContent
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
def message_chatgpt(self, message: str):
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List
|
||||
|
||||
from letta.functions.helpers import (
|
||||
_send_message_to_agents_matching_tags_async,
|
||||
_send_message_to_all_agents_in_group_async,
|
||||
execute_send_message_to_agent,
|
||||
fire_and_forget_send_to_agent,
|
||||
)
|
||||
@@ -86,3 +87,19 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all:
|
||||
"""
|
||||
|
||||
return asyncio.run(_send_message_to_agents_matching_tags_async(self, message, match_all, match_some))
|
||||
|
||||
|
||||
def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]:
|
||||
"""
|
||||
Sends a message to all agents within the same multi-agent group.
|
||||
|
||||
Args:
|
||||
message (str): The content of the message to be sent to each matching agent.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of responses from the agents that matched the filtering criteria. Each
|
||||
response corresponds to a single agent. Agents that do not respond will not have an entry
|
||||
in the returned list.
|
||||
"""
|
||||
|
||||
return asyncio.run(_send_message_to_all_agents_in_group_async(self, message))
|
||||
|
||||
@@ -604,6 +604,47 @@ async def _send_message_to_agents_matching_tags_async(
|
||||
return final
|
||||
|
||||
|
||||
async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", message: str) -> List[str]:
|
||||
server = get_letta_server()
|
||||
|
||||
augmented_message = (
|
||||
f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - to reply to this message, "
|
||||
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
|
||||
f"{message}"
|
||||
)
|
||||
|
||||
worker_agents_ids = sender_agent.agent_state.multi_agent_group.agent_ids
|
||||
worker_agents = [server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=sender_agent.user) for agent_id in worker_agents_ids]
|
||||
|
||||
# Create a system message
|
||||
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)]
|
||||
|
||||
# Possibly limit concurrency to avoid meltdown:
|
||||
sem = asyncio.Semaphore(settings.multi_agent_concurrent_sends)
|
||||
|
||||
async def _send_single(agent_state):
|
||||
async with sem:
|
||||
return await async_send_message_with_retries(
|
||||
server=server,
|
||||
sender_agent=sender_agent,
|
||||
target_agent_id=agent_state.id,
|
||||
messages=messages,
|
||||
max_retries=3,
|
||||
timeout=settings.multi_agent_send_message_timeout,
|
||||
)
|
||||
|
||||
tasks = [asyncio.create_task(_send_single(agent_state)) for agent_state in worker_agents]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
final = []
|
||||
for r in results:
|
||||
if isinstance(r, Exception):
|
||||
final.append(str(r))
|
||||
else:
|
||||
final.append(r)
|
||||
|
||||
return final
|
||||
|
||||
|
||||
def generate_model_from_args_json_schema(schema: Dict[str, Any]) -> Type[BaseModel]:
|
||||
"""Creates a Pydantic model from a JSON schema.
|
||||
|
||||
|
||||
@@ -8,6 +8,16 @@ from sqlalchemy import Dialect
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.letta_message_content import (
|
||||
MessageContent,
|
||||
MessageContentType,
|
||||
OmittedReasoningContent,
|
||||
ReasoningContent,
|
||||
RedactedReasoningContent,
|
||||
TextContent,
|
||||
ToolCallContent,
|
||||
ToolReturnContent,
|
||||
)
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import ToolReturn
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule, ToolRule
|
||||
@@ -80,10 +90,13 @@ def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, Term
|
||||
rule_type = ToolRuleType(data.get("type"))
|
||||
|
||||
if rule_type == ToolRuleType.run_first or rule_type == ToolRuleType.InitToolRule:
|
||||
data["type"] = ToolRuleType.run_first
|
||||
return InitToolRule(**data)
|
||||
elif rule_type == ToolRuleType.exit_loop or rule_type == ToolRuleType.TerminalToolRule:
|
||||
data["type"] = ToolRuleType.exit_loop
|
||||
return TerminalToolRule(**data)
|
||||
elif rule_type == ToolRuleType.constrain_child_tools or rule_type == ToolRuleType.ToolRule:
|
||||
data["type"] = ToolRuleType.constrain_child_tools
|
||||
return ChildToolRule(**data)
|
||||
elif rule_type == ToolRuleType.conditional:
|
||||
return ConditionalToolRule(**data)
|
||||
@@ -163,6 +176,60 @@ def deserialize_tool_returns(data: Optional[List[Dict]]) -> List[ToolReturn]:
|
||||
return tool_returns
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# MessageContent Serialization
|
||||
# ----------------------------
|
||||
|
||||
|
||||
def serialize_message_content(message_content: Optional[List[Union[MessageContent, dict]]]) -> List[Dict]:
|
||||
"""Convert a list of MessageContent objects into JSON-serializable format."""
|
||||
if not message_content:
|
||||
return []
|
||||
|
||||
serialized_message_content = []
|
||||
for content in message_content:
|
||||
if isinstance(content, MessageContent):
|
||||
serialized_message_content.append(content.model_dump())
|
||||
elif isinstance(content, dict):
|
||||
serialized_message_content.append(content) # Already a dictionary, leave it as-is
|
||||
else:
|
||||
raise TypeError(f"Unexpected message content type: {type(content)}")
|
||||
|
||||
return serialized_message_content
|
||||
|
||||
|
||||
def deserialize_message_content(data: Optional[List[Dict]]) -> List[MessageContent]:
|
||||
"""Convert a JSON list back into MessageContent objects."""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
message_content = []
|
||||
for item in data:
|
||||
if not item:
|
||||
continue
|
||||
|
||||
content_type = item.get("type")
|
||||
if content_type == MessageContentType.text:
|
||||
content = TextContent(**item)
|
||||
elif content_type == MessageContentType.tool_call:
|
||||
content = ToolCallContent(**item)
|
||||
elif content_type == MessageContentType.tool_return:
|
||||
content = ToolReturnContent(**item)
|
||||
elif content_type == MessageContentType.reasoning:
|
||||
content = ReasoningContent(**item)
|
||||
elif content_type == MessageContentType.redacted_reasoning:
|
||||
content = RedactedReasoningContent(**item)
|
||||
elif content_type == MessageContentType.omitted_reasoning:
|
||||
content = OmittedReasoningContent(**item)
|
||||
else:
|
||||
# Skip invalid content
|
||||
continue
|
||||
|
||||
message_content.append(content)
|
||||
|
||||
return message_content
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Vector Serialization
|
||||
# --------------------------
|
||||
|
||||
@@ -11,6 +11,9 @@ from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# see: https://modelcontextprotocol.io/quickstart/user
|
||||
MCP_CONFIG_TOPLEVEL_KEY = "mcpServers"
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""A simple wrapper around MCP's tool definition (to avoid conflict with our own)"""
|
||||
@@ -18,7 +21,7 @@ class MCPTool(Tool):
|
||||
|
||||
class MCPServerType(str, Enum):
|
||||
SSE = "sse"
|
||||
LOCAL = "local"
|
||||
STDIO = "stdio"
|
||||
|
||||
|
||||
class BaseServerConfig(BaseModel):
|
||||
@@ -30,11 +33,29 @@ class SSEServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.SSE
|
||||
server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "sse",
|
||||
"url": self.server_url,
|
||||
}
|
||||
return values
|
||||
|
||||
class LocalServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.LOCAL
|
||||
|
||||
class StdioServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.STDIO
|
||||
command: str = Field(..., description="The command to run (MCP 'local' client will run this command)")
|
||||
args: List[str] = Field(..., description="The arguments to pass to the command")
|
||||
env: Optional[dict[str, str]] = Field(None, description="Environment variables to set")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "stdio",
|
||||
"command": self.command,
|
||||
"args": self.args,
|
||||
}
|
||||
if self.env is not None:
|
||||
values["env"] = self.env
|
||||
return values
|
||||
|
||||
|
||||
class BaseMCPClient:
|
||||
@@ -83,8 +104,8 @@ class BaseMCPClient:
|
||||
logger.info("Cleaned up MCP clients on shutdown.")
|
||||
|
||||
|
||||
class LocalMCPClient(BaseMCPClient):
|
||||
def _initialize_connection(self, server_config: LocalServerConfig):
|
||||
class StdioMCPClient(BaseMCPClient):
|
||||
def _initialize_connection(self, server_config: StdioServerConfig):
|
||||
server_params = StdioServerParameters(command=server_config.command, args=server_config.args)
|
||||
stdio_cm = stdio_client(server_params)
|
||||
stdio_transport = self.loop.run_until_complete(stdio_cm.__aenter__())
|
||||
|
||||
@@ -221,7 +221,7 @@ def openai_chat_completions_process_stream(
|
||||
# TODO(sarah): add message ID generation function
|
||||
dummy_message = _Message(
|
||||
role=_MessageRole.assistant,
|
||||
text="",
|
||||
content=[],
|
||||
agent_id="",
|
||||
model="",
|
||||
name=None,
|
||||
|
||||
@@ -5,8 +5,9 @@ from letta.llm_api.llm_api_tools import create
|
||||
from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.utils import count_tokens, printd
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from letta.orm.base import Base
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.group import Group
|
||||
from letta.orm.groups_agents import GroupsAgents
|
||||
from letta.orm.identities_agents import IdentitiesAgents
|
||||
from letta.orm.identities_blocks import IdentitiesBlocks
|
||||
from letta.orm.identity import Identity
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional, Set
|
||||
|
||||
from sqlalchemy import JSON, Boolean, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
@@ -128,37 +128,86 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
back_populates="agents",
|
||||
passive_deletes=True,
|
||||
)
|
||||
groups: Mapped[List["Group"]] = relationship(
|
||||
"Group",
|
||||
secondary="groups_agents",
|
||||
lazy="selectin",
|
||||
back_populates="agents",
|
||||
passive_deletes=True,
|
||||
)
|
||||
multi_agent_group: Mapped["Group"] = relationship(
|
||||
"Group",
|
||||
lazy="joined",
|
||||
viewonly=True,
|
||||
back_populates="manager_agent",
|
||||
)
|
||||
|
||||
def to_pydantic(self) -> PydanticAgentState:
|
||||
"""converts to the basic pydantic model counterpart"""
|
||||
# add default rule for having send_message be a terminal tool
|
||||
tool_rules = self.tool_rules
|
||||
def to_pydantic(self, include_relationships: Optional[Set[str]] = None) -> PydanticAgentState:
|
||||
"""
|
||||
Converts the SQLAlchemy Agent model into its Pydantic counterpart.
|
||||
|
||||
The following base fields are always included:
|
||||
- id, agent_type, name, description, system, message_ids, metadata_,
|
||||
llm_config, embedding_config, project_id, template_id, base_template_id,
|
||||
tool_rules, message_buffer_autoclear, tags
|
||||
|
||||
Everything else (e.g., tools, sources, memory, etc.) is optional and only
|
||||
included if specified in `include_fields`.
|
||||
|
||||
Args:
|
||||
include_relationships (Optional[Set[str]]):
|
||||
A set of additional field names to include in the output. If None or empty,
|
||||
no extra fields are loaded beyond the base fields.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The Pydantic representation of the agent.
|
||||
"""
|
||||
# Base fields: always included
|
||||
state = {
|
||||
"id": self.id,
|
||||
"organization_id": self.organization_id,
|
||||
"agent_type": self.agent_type,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"message_ids": self.message_ids,
|
||||
"tools": self.tools,
|
||||
"sources": [source.to_pydantic() for source in self.sources],
|
||||
"tags": [t.tag for t in self.tags],
|
||||
"tool_rules": tool_rules,
|
||||
"system": self.system,
|
||||
"agent_type": self.agent_type,
|
||||
"message_ids": self.message_ids,
|
||||
"metadata": self.metadata_, # Exposed as 'metadata' to Pydantic
|
||||
"llm_config": self.llm_config,
|
||||
"embedding_config": self.embedding_config,
|
||||
"metadata": self.metadata_,
|
||||
"memory": Memory(blocks=[b.to_pydantic() for b in self.core_memory]),
|
||||
"project_id": self.project_id,
|
||||
"template_id": self.template_id,
|
||||
"base_template_id": self.base_template_id,
|
||||
"tool_rules": self.tool_rules,
|
||||
"message_buffer_autoclear": self.message_buffer_autoclear,
|
||||
"created_by_id": self.created_by_id,
|
||||
"last_updated_by_id": self.last_updated_by_id,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
"tool_exec_environment_variables": self.tool_exec_environment_variables,
|
||||
"project_id": self.project_id,
|
||||
"template_id": self.template_id,
|
||||
"base_template_id": self.base_template_id,
|
||||
"identity_ids": [identity.id for identity in self.identities],
|
||||
"message_buffer_autoclear": self.message_buffer_autoclear,
|
||||
# optional field defaults
|
||||
"tags": [],
|
||||
"tools": [],
|
||||
"sources": [],
|
||||
"memory": Memory(blocks=[]),
|
||||
"identity_ids": [],
|
||||
"multi_agent_group": None,
|
||||
"tool_exec_environment_variables": [],
|
||||
}
|
||||
|
||||
# Optional fields: only included if requested
|
||||
optional_fields = {
|
||||
"tags": lambda: [t.tag for t in self.tags],
|
||||
"tools": lambda: self.tools,
|
||||
"sources": lambda: [s.to_pydantic() for s in self.sources],
|
||||
"memory": lambda: Memory(blocks=[b.to_pydantic() for b in self.core_memory]),
|
||||
"identity_ids": lambda: [i.id for i in self.identities],
|
||||
"multi_agent_group": lambda: self.multi_agent_group,
|
||||
"tool_exec_environment_variables": lambda: self.tool_exec_environment_variables,
|
||||
}
|
||||
|
||||
include_relationships = set(optional_fields.keys() if include_relationships is None else include_relationships)
|
||||
|
||||
for field_name in include_relationships:
|
||||
resolver = optional_fields.get(field_name)
|
||||
if resolver:
|
||||
state[field_name] = resolver()
|
||||
|
||||
return self.__pydantic_model__(**state)
|
||||
|
||||
@@ -4,12 +4,14 @@ from sqlalchemy.types import BINARY, TypeDecorator
|
||||
from letta.helpers.converters import (
|
||||
deserialize_embedding_config,
|
||||
deserialize_llm_config,
|
||||
deserialize_message_content,
|
||||
deserialize_tool_calls,
|
||||
deserialize_tool_returns,
|
||||
deserialize_tool_rules,
|
||||
deserialize_vector,
|
||||
serialize_embedding_config,
|
||||
serialize_llm_config,
|
||||
serialize_message_content,
|
||||
serialize_tool_calls,
|
||||
serialize_tool_returns,
|
||||
serialize_tool_rules,
|
||||
@@ -82,6 +84,19 @@ class ToolReturnColumn(TypeDecorator):
|
||||
return deserialize_tool_returns(value)
|
||||
|
||||
|
||||
class MessageContentColumn(TypeDecorator):
|
||||
"""Custom SQLAlchemy column type for storing the content parts of a message as JSON."""
|
||||
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return serialize_message_content(value)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return deserialize_message_content(value)
|
||||
|
||||
|
||||
class CommonVector(TypeDecorator):
|
||||
"""Custom SQLAlchemy column type for storing vectors in SQLite."""
|
||||
|
||||
|
||||
33
letta/orm/group.py
Normal file
33
letta/orm/group.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.group import Group as PydanticGroup
|
||||
|
||||
|
||||
class Group(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
__tablename__ = "groups"
|
||||
__pydantic_model__ = PydanticGroup
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"group-{uuid.uuid4()}")
|
||||
description: Mapped[str] = mapped_column(nullable=False, doc="")
|
||||
manager_type: Mapped[str] = mapped_column(nullable=False, doc="")
|
||||
manager_agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id", ondelete="RESTRICT"), nullable=True, doc="")
|
||||
termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
|
||||
max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="groups")
|
||||
agents: Mapped[List["Agent"]] = relationship(
|
||||
"Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups"
|
||||
)
|
||||
manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group")
|
||||
|
||||
@property
|
||||
def agent_ids(self) -> List[str]:
|
||||
return [agent.id for agent in self.agents]
|
||||
13
letta/orm/groups_agents.py
Normal file
13
letta/orm/groups_agents.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class GroupsAgents(Base):
|
||||
"""Agents may have one or many groups associated with them."""
|
||||
|
||||
__tablename__ = "groups_agents"
|
||||
|
||||
group_id: Mapped[str] = mapped_column(String, ForeignKey("groups.id", ondelete="CASCADE"), primary_key=True)
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
|
||||
@@ -4,11 +4,12 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe
|
||||
from sqlalchemy import ForeignKey, Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import ToolCallColumn, ToolReturnColumn
|
||||
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.letta_message_content import MessageContent
|
||||
from letta.schemas.letta_message_content import TextContent as PydanticTextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import TextContent as PydanticTextContent
|
||||
from letta.schemas.message import ToolReturn
|
||||
|
||||
|
||||
@@ -25,6 +26,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique message identifier")
|
||||
role: Mapped[str] = mapped_column(doc="Message role (user/assistant/system/tool)")
|
||||
text: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Message content")
|
||||
content: Mapped[List[MessageContent]] = mapped_column(MessageContentColumn, nullable=True, doc="Message content parts")
|
||||
model: Mapped[Optional[str]] = mapped_column(nullable=True, doc="LLM model used")
|
||||
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
|
||||
tool_calls: Mapped[List[OpenAIToolCall]] = mapped_column(ToolCallColumn, doc="Tool call information")
|
||||
@@ -36,6 +38,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
tool_returns: Mapped[List[ToolReturn]] = mapped_column(
|
||||
ToolReturnColumn, nullable=True, doc="Tool execution return information for prior tool calls"
|
||||
)
|
||||
group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in")
|
||||
|
||||
# Relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
||||
@@ -53,8 +56,8 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
return self.job_message.job if self.job_message else None
|
||||
|
||||
def to_pydantic(self) -> PydanticMessage:
|
||||
"""custom pydantic conversion for message content mapping"""
|
||||
"""Custom pydantic conversion to handle data using legacy text field"""
|
||||
model = self.__pydantic_model__.model_validate(self)
|
||||
if self.text:
|
||||
if self.text and not model.content:
|
||||
model.content = [PydanticTextContent(text=self.text)]
|
||||
return model
|
||||
|
||||
@@ -49,6 +49,7 @@ class Organization(SqlalchemyBase):
|
||||
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan")
|
||||
providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan")
|
||||
identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan")
|
||||
groups: Mapped[List["Group"]] = relationship("Group", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
@property
|
||||
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
|
||||
|
||||
@@ -139,11 +139,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
else:
|
||||
# Match ANY tag - use join and filter
|
||||
query = (
|
||||
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).group_by(cls.id)
|
||||
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id)
|
||||
) # Deduplicate results
|
||||
|
||||
# Group by primary key and all necessary columns to avoid JSON comparison
|
||||
query = query.group_by(cls.id)
|
||||
# select distinct primary key
|
||||
query = query.distinct(cls.id).order_by(cls.id)
|
||||
|
||||
if identifier_keys and hasattr(cls, "identities"):
|
||||
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
|
||||
|
||||
152
letta/round_robin_multi_agent.py
Normal file
152
letta/round_robin_multi_agent.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.agent import Agent, AgentState
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
|
||||
class RoundRobinMultiAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
interface: AgentInterface,
|
||||
agent_state: AgentState,
|
||||
user: User = None,
|
||||
# custom
|
||||
group_id: str = "",
|
||||
agent_ids: List[str] = [],
|
||||
description: str = "",
|
||||
max_turns: Optional[int] = None,
|
||||
):
|
||||
super().__init__(interface, agent_state, user)
|
||||
self.group_id = group_id
|
||||
self.agent_ids = agent_ids
|
||||
self.description = description
|
||||
self.max_turns = max_turns or len(agent_ids)
|
||||
|
||||
def step(
|
||||
self,
|
||||
messages: List[MessageCreate],
|
||||
chaining: bool = True,
|
||||
max_chaining_steps: Optional[int] = None,
|
||||
put_inner_thoughts_first: bool = True,
|
||||
**kwargs,
|
||||
) -> LettaUsageStatistics:
|
||||
total_usage = UsageStatistics()
|
||||
step_count = 0
|
||||
|
||||
token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
|
||||
metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
|
||||
|
||||
agents = {}
|
||||
for agent_id in self.agent_ids:
|
||||
agents[agent_id] = self.load_participant_agent(agent_id=agent_id)
|
||||
|
||||
message_index = {}
|
||||
chat_history: List[Message] = []
|
||||
new_messages = messages
|
||||
speaker_id = None
|
||||
try:
|
||||
for i in range(self.max_turns):
|
||||
speaker_id = self.agent_ids[i % len(self.agent_ids)]
|
||||
# initialize input messages
|
||||
start_index = message_index[speaker_id] if speaker_id in message_index else 0
|
||||
for message in chat_history[start_index:]:
|
||||
message.id = Message.generate_id()
|
||||
message.agent_id = speaker_id
|
||||
|
||||
for message in new_messages:
|
||||
chat_history.append(
|
||||
Message(
|
||||
agent_id=speaker_id,
|
||||
role=message.role,
|
||||
content=[TextContent(text=message.content)],
|
||||
name=message.name,
|
||||
model=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
)
|
||||
)
|
||||
|
||||
# load agent and perform step
|
||||
participant_agent = agents[speaker_id]
|
||||
usage_stats = participant_agent.step(
|
||||
messages=chat_history[start_index:],
|
||||
chaining=chaining,
|
||||
max_chaining_steps=max_chaining_steps,
|
||||
stream=token_streaming,
|
||||
skip_verify=True,
|
||||
metadata=metadata,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
|
||||
# parse new messages for next step
|
||||
responses = Message.to_letta_messages_from_list(participant_agent.last_response_messages)
|
||||
assistant_messages = [response for response in responses if response.message_type == "assistant_message"]
|
||||
new_messages = [
|
||||
MessageCreate(
|
||||
role="system",
|
||||
content=message.content,
|
||||
name=participant_agent.agent_state.name,
|
||||
)
|
||||
for message in assistant_messages
|
||||
]
|
||||
message_index[speaker_id] = len(chat_history) + len(new_messages)
|
||||
|
||||
# sum usage
|
||||
total_usage.prompt_tokens += usage_stats.prompt_tokens
|
||||
total_usage.completion_tokens += usage_stats.completion_tokens
|
||||
total_usage.total_tokens += usage_stats.total_tokens
|
||||
step_count += 1
|
||||
|
||||
# persist remaining chat history
|
||||
for message in new_messages:
|
||||
chat_history.append(
|
||||
Message(
|
||||
agent_id=agent_id,
|
||||
role=message.role,
|
||||
content=[TextContent(text=message.content)],
|
||||
name=message.name,
|
||||
model=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
)
|
||||
)
|
||||
for agent_id, index in message_index.items():
|
||||
if agent_id == speaker_id:
|
||||
continue
|
||||
for message in chat_history[index:]:
|
||||
message.id = Message.generate_id()
|
||||
message.agent_id = agent_id
|
||||
self.message_manager.create_many_messages(chat_history[index:], actor=self.user)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
self.interface.step_yield()
|
||||
|
||||
self.interface.step_complete()
|
||||
|
||||
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
||||
|
||||
def load_participant_agent(self, agent_id: str) -> Agent:
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
|
||||
persona_block = agent_state.memory.get_block(label="persona")
|
||||
group_chat_participant_persona = (
|
||||
"\n\n====Group Chat Contex===="
|
||||
f"\nYou are speaking in a group chat with {len(self.agent_ids) - 1} other "
|
||||
"agents and one user. Respond to new messages in the group chat when prompted. "
|
||||
f"Description of the group: {self.description}"
|
||||
)
|
||||
agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_participant_persona)
|
||||
return Agent(
|
||||
agent_state=agent_state,
|
||||
interface=self.interface,
|
||||
user=self.user,
|
||||
save_last_response=True,
|
||||
)
|
||||
@@ -7,6 +7,7 @@ from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable
|
||||
from letta.schemas.group import Group
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
@@ -90,6 +91,8 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
|
||||
)
|
||||
|
||||
multi_agent_group: Optional[Group] = Field(None, description="The multi-agent group that this agent manages")
|
||||
|
||||
def get_agent_env_vars_as_dict(self) -> Dict[str, str]:
|
||||
# Get environment variables for this agent specifically
|
||||
per_agent_env_vars = {}
|
||||
|
||||
@@ -9,10 +9,6 @@ class MessageRole(str, Enum):
|
||||
system = "system"
|
||||
|
||||
|
||||
class MessageContentType(str, Enum):
|
||||
text = "text"
|
||||
|
||||
|
||||
class OptionState(str, Enum):
|
||||
"""Useful for kwargs that are bool + default option"""
|
||||
|
||||
|
||||
65
letta/schemas/group.py
Normal file
65
letta/schemas/group.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class ManagerType(str, Enum):
|
||||
round_robin = "round_robin"
|
||||
supervisor = "supervisor"
|
||||
dynamic = "dynamic"
|
||||
swarm = "swarm"
|
||||
|
||||
|
||||
class GroupBase(LettaBase):
|
||||
__id_prefix__ = "group"
|
||||
|
||||
|
||||
class Group(GroupBase):
|
||||
id: str = Field(..., description="The id of the group. Assigned by the database.")
|
||||
manager_type: ManagerType = Field(..., description="")
|
||||
agent_ids: List[str] = Field(..., description="")
|
||||
description: str = Field(..., description="")
|
||||
# Pattern fields
|
||||
manager_agent_id: Optional[str] = Field(None, description="")
|
||||
termination_token: Optional[str] = Field(None, description="")
|
||||
max_turns: Optional[int] = Field(None, description="")
|
||||
|
||||
|
||||
class ManagerConfig(BaseModel):
|
||||
manager_type: ManagerType = Field(..., description="")
|
||||
|
||||
|
||||
class RoundRobinManager(ManagerConfig):
|
||||
manager_type: Literal[ManagerType.round_robin] = Field(ManagerType.round_robin, description="")
|
||||
max_turns: Optional[int] = Field(None, description="")
|
||||
|
||||
|
||||
class SupervisorManager(ManagerConfig):
|
||||
manager_type: Literal[ManagerType.supervisor] = Field(ManagerType.supervisor, description="")
|
||||
manager_agent_id: str = Field(..., description="")
|
||||
|
||||
|
||||
class DynamicManager(ManagerConfig):
|
||||
manager_type: Literal[ManagerType.dynamic] = Field(ManagerType.dynamic, description="")
|
||||
manager_agent_id: str = Field(..., description="")
|
||||
termination_token: Optional[str] = Field("DONE!", description="")
|
||||
max_turns: Optional[int] = Field(None, description="")
|
||||
|
||||
|
||||
# class SwarmGroup(ManagerConfig):
|
||||
# manager_type: Literal[ManagerType.swarm] = Field(ManagerType.swarm, description="")
|
||||
|
||||
|
||||
ManagerConfigUnion = Annotated[
|
||||
Union[RoundRobinManager, SupervisorManager, DynamicManager],
|
||||
Field(discriminator="manager_type"),
|
||||
]
|
||||
|
||||
|
||||
class GroupCreate(BaseModel):
|
||||
agent_ids: List[str] = Field(..., description="")
|
||||
description: str = Field(..., description="")
|
||||
manager_config: Optional[ManagerConfigUnion] = Field(None, description="")
|
||||
@@ -4,109 +4,132 @@ from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
from letta.schemas.enums import MessageContentType
|
||||
from letta.schemas.letta_message_content import (
|
||||
LettaAssistantMessageContentUnion,
|
||||
LettaUserMessageContentUnion,
|
||||
get_letta_assistant_message_content_union_str_json_schema,
|
||||
get_letta_user_message_content_union_str_json_schema,
|
||||
)
|
||||
|
||||
# Letta API style responses (intended to be easier to use vs getting true Message types)
|
||||
# ---------------------------
|
||||
# Letta API Messaging Schemas
|
||||
# ---------------------------
|
||||
|
||||
|
||||
class LettaMessage(BaseModel):
|
||||
"""
|
||||
Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, tool calls, and tool returns in a simplified format that does not include additional information other than the content and timestamp.
|
||||
Base class for simplified Letta message response type. This is intended to be used for developers
|
||||
who want the internal monologue, tool calls, and tool returns in a simplified format that does not
|
||||
include additional information other than the content and timestamp.
|
||||
|
||||
Attributes:
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
"""
|
||||
|
||||
# NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
|
||||
# see `message_type` attribute
|
||||
|
||||
id: str
|
||||
date: datetime
|
||||
name: Optional[str] = None
|
||||
|
||||
@field_serializer("date")
|
||||
def serialize_datetime(self, dt: datetime, _info):
|
||||
"""
|
||||
Remove microseconds since it seems like we're inconsistent with getting them
|
||||
TODO: figure out why we don't always get microseconds (get_utc_time() does)
|
||||
"""
|
||||
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
# Remove microseconds since it seems like we're inconsistent with getting them
|
||||
# TODO figure out why we don't always get microseconds (get_utc_time() does)
|
||||
return dt.isoformat(timespec="seconds")
|
||||
|
||||
|
||||
class MessageContent(BaseModel):
|
||||
type: MessageContentType = Field(..., description="The type of the message.")
|
||||
|
||||
|
||||
class TextContent(MessageContent):
|
||||
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
|
||||
text: str = Field(..., description="The text content of the message.")
|
||||
|
||||
|
||||
MessageContentUnion = Annotated[
|
||||
Union[TextContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class SystemMessage(LettaMessage):
|
||||
"""
|
||||
A message generated by the system. Never streamed back on a response, only used for cursor pagination.
|
||||
|
||||
Attributes:
|
||||
content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
content (str): The message content sent by the system
|
||||
"""
|
||||
|
||||
message_type: Literal["system_message"] = "system_message"
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
content: str = Field(..., description="The message content sent by the system")
|
||||
|
||||
|
||||
class UserMessage(LettaMessage):
|
||||
"""
|
||||
A message sent by the user. Never streamed back on a response, only used for cursor pagination.
|
||||
|
||||
Attributes:
|
||||
content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
content (Union[str, List[LettaUserMessageContentUnion]]): The message content sent by the user (can be a string or an array of multi-modal content parts)
|
||||
"""
|
||||
|
||||
message_type: Literal["user_message"] = "user_message"
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
content: Union[str, List[LettaUserMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
|
||||
json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
class ReasoningMessage(LettaMessage):
|
||||
"""
|
||||
Representation of an agent's internal reasoning.
|
||||
|
||||
Attributes:
|
||||
reasoning (str): The internal reasoning of the agent
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
source (Literal["reasoner_model", "non_reasoner_model"]): Whether the reasoning
|
||||
content was generated natively by a reasoner model or derived via prompting
|
||||
reasoning (str): The internal reasoning of the agent
|
||||
"""
|
||||
|
||||
message_type: Literal["reasoning_message"] = "reasoning_message"
|
||||
source: Literal["reasoner_model", "non_reasoner_model"] = "non_reasoner_model"
|
||||
reasoning: str
|
||||
|
||||
|
||||
class HiddenReasoningMessage(LettaMessage):
|
||||
"""
|
||||
Representation of an agent's internal reasoning where reasoning content
|
||||
has been hidden from the response.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
state (Literal["redacted", "omitted"]): Whether the reasoning
|
||||
content was redacted by the provider or simply omitted by the API
|
||||
reasoning (str): The internal reasoning of the agent
|
||||
"""
|
||||
|
||||
message_type: Literal["reasoning_message"] = "reasoning_message"
|
||||
state: Literal["redacted", "omitted"]
|
||||
reasoning: str
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
|
||||
name: str
|
||||
arguments: str
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class ToolCallDelta(BaseModel):
|
||||
|
||||
name: Optional[str]
|
||||
arguments: Optional[str]
|
||||
tool_call_id: Optional[str]
|
||||
|
||||
# NOTE: this is a workaround to exclude None values from the JSON dump,
|
||||
# since the OpenAI style of returning chunks doesn't include keys with null values
|
||||
def model_dump(self, *args, **kwargs):
|
||||
"""
|
||||
This is a workaround to exclude None values from the JSON dump since the
|
||||
OpenAI style of returning chunks doesn't include keys with null values.
|
||||
"""
|
||||
kwargs["exclude_none"] = True
|
||||
return super().model_dump(*args, **kwargs)
|
||||
|
||||
@@ -118,17 +141,20 @@ class ToolCallMessage(LettaMessage):
|
||||
"""
|
||||
A message representing a request to call a tool (generated by the LLM to trigger tool execution).
|
||||
|
||||
Attributes:
|
||||
tool_call (Union[ToolCall, ToolCallDelta]): The tool call
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
tool_call (Union[ToolCall, ToolCallDelta]): The tool call
|
||||
"""
|
||||
|
||||
message_type: Literal["tool_call_message"] = "tool_call_message"
|
||||
tool_call: Union[ToolCall, ToolCallDelta]
|
||||
|
||||
# NOTE: this is required for the ToolCallDelta exclude_none to work correctly
|
||||
def model_dump(self, *args, **kwargs):
|
||||
"""
|
||||
Handling for the ToolCallDelta exclude_none to work correctly
|
||||
"""
|
||||
kwargs["exclude_none"] = True
|
||||
data = super().model_dump(*args, **kwargs)
|
||||
if isinstance(data["tool_call"], dict):
|
||||
@@ -141,12 +167,14 @@ class ToolCallMessage(LettaMessage):
|
||||
ToolCall: lambda v: v.model_dump(exclude_none=True),
|
||||
}
|
||||
|
||||
# NOTE: this is required to cast dicts into ToolCallMessage objects
|
||||
# Without this extra validator, Pydantic will throw an error if 'name' or 'arguments' are None
|
||||
# (instead of properly casting to ToolCallDelta instead of ToolCall)
|
||||
@field_validator("tool_call", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_call(cls, v):
|
||||
"""
|
||||
Casts dicts into ToolCallMessage objects. Without this extra validator, Pydantic will throw
|
||||
an error if 'name' or 'arguments' are None instead of properly casting to ToolCallDelta
|
||||
instead of ToolCall.
|
||||
"""
|
||||
if isinstance(v, dict):
|
||||
if "name" in v and "arguments" in v and "tool_call_id" in v:
|
||||
return ToolCall(name=v["name"], arguments=v["arguments"], tool_call_id=v["tool_call_id"])
|
||||
@@ -161,11 +189,12 @@ class ToolReturnMessage(LettaMessage):
|
||||
"""
|
||||
A message representing the return value of a tool call (generated by Letta executing the requested tool).
|
||||
|
||||
Attributes:
|
||||
tool_return (str): The return value of the tool
|
||||
status (Literal["success", "error"]): The status of the tool call
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
tool_return (str): The return value of the tool
|
||||
status (Literal["success", "error"]): The status of the tool call
|
||||
tool_call_id (str): A unique identifier for the tool call that generated this message
|
||||
stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the tool invocation
|
||||
stderr (Optional[List(str)]): Captured stderr from the tool invocation
|
||||
@@ -179,89 +208,32 @@ class ToolReturnMessage(LettaMessage):
|
||||
stderr: Optional[List[str]] = None
|
||||
|
||||
|
||||
# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string
|
||||
|
||||
|
||||
class AssistantMessage(LettaMessage):
|
||||
"""
|
||||
A message sent by the LLM in response to user input. Used in the LLM context.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
content (Union[str, List[LettaAssistantMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts)
|
||||
"""
|
||||
|
||||
message_type: Literal["assistant_message"] = "assistant_message"
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
|
||||
|
||||
class LegacyFunctionCallMessage(LettaMessage):
|
||||
function_call: str
|
||||
|
||||
|
||||
class LegacyFunctionReturn(LettaMessage):
|
||||
"""
|
||||
A message representing the return value of a function call (generated by Letta executing the requested function).
|
||||
|
||||
Attributes:
|
||||
function_return (str): The return value of the function
|
||||
status (Literal["success", "error"]): The status of the function call
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
function_call_id (str): A unique identifier for the function call that generated this message
|
||||
stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation
|
||||
stderr (Optional[List(str)]): Captured stderr from the function invocation
|
||||
"""
|
||||
|
||||
message_type: Literal["function_return"] = "function_return"
|
||||
function_return: str
|
||||
status: Literal["success", "error"]
|
||||
function_call_id: str
|
||||
stdout: Optional[List[str]] = None
|
||||
stderr: Optional[List[str]] = None
|
||||
|
||||
|
||||
class LegacyInternalMonologue(LettaMessage):
|
||||
"""
|
||||
Representation of an agent's internal monologue.
|
||||
|
||||
Attributes:
|
||||
internal_monologue (str): The internal monologue of the agent
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
"""
|
||||
|
||||
message_type: Literal["internal_monologue"] = "internal_monologue"
|
||||
internal_monologue: str
|
||||
|
||||
|
||||
LegacyLettaMessage = Union[LegacyInternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn]
|
||||
content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the agent (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
# NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
|
||||
LettaMessageUnion = Annotated[
|
||||
Union[SystemMessage, UserMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, AssistantMessage],
|
||||
Field(discriminator="message_type"),
|
||||
]
|
||||
|
||||
|
||||
class UpdateSystemMessage(BaseModel):
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["system_message"] = "system_message"
|
||||
|
||||
|
||||
class UpdateUserMessage(BaseModel):
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["user_message"] = "user_message"
|
||||
|
||||
|
||||
class UpdateReasoningMessage(BaseModel):
|
||||
reasoning: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["reasoning_message"] = "reasoning_message"
|
||||
|
||||
|
||||
class UpdateAssistantMessage(BaseModel):
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["assistant_message"] = "assistant_message"
|
||||
|
||||
|
||||
LettaMessageUpdateUnion = Annotated[
|
||||
Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage],
|
||||
Field(discriminator="message_type"),
|
||||
]
|
||||
|
||||
|
||||
def create_letta_message_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
@@ -284,3 +256,92 @@ def create_letta_message_union_schema():
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Message Update API Schemas
|
||||
# --------------------------
|
||||
|
||||
|
||||
class UpdateSystemMessage(BaseModel):
|
||||
message_type: Literal["system_message"] = "system_message"
|
||||
content: str = Field(
|
||||
..., description="The message content sent by the system (can be a string or an array of multi-modal content parts)"
|
||||
)
|
||||
|
||||
|
||||
class UpdateUserMessage(BaseModel):
|
||||
message_type: Literal["user_message"] = "user_message"
|
||||
content: Union[str, List[LettaUserMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the user (can be a string or an array of multi-modal content parts)",
|
||||
json_schema_extra=get_letta_user_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
class UpdateReasoningMessage(BaseModel):
|
||||
reasoning: str
|
||||
message_type: Literal["reasoning_message"] = "reasoning_message"
|
||||
|
||||
|
||||
class UpdateAssistantMessage(BaseModel):
|
||||
message_type: Literal["assistant_message"] = "assistant_message"
|
||||
content: Union[str, List[LettaAssistantMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the assistant (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_assistant_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
LettaMessageUpdateUnion = Annotated[
|
||||
Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage],
|
||||
Field(discriminator="message_type"),
|
||||
]
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Deprecated Message Schemas
|
||||
# --------------------------
|
||||
|
||||
|
||||
class LegacyFunctionCallMessage(LettaMessage):
|
||||
function_call: str
|
||||
|
||||
|
||||
class LegacyFunctionReturn(LettaMessage):
|
||||
"""
|
||||
A message representing the return value of a function call (generated by Letta executing the requested function).
|
||||
|
||||
Args:
|
||||
function_return (str): The return value of the function
|
||||
status (Literal["success", "error"]): The status of the function call
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
function_call_id (str): A unique identifier for the function call that generated this message
|
||||
stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation
|
||||
stderr (Optional[List(str)]): Captured stderr from the function invocation
|
||||
"""
|
||||
|
||||
message_type: Literal["function_return"] = "function_return"
|
||||
function_return: str
|
||||
status: Literal["success", "error"]
|
||||
function_call_id: str
|
||||
stdout: Optional[List[str]] = None
|
||||
stderr: Optional[List[str]] = None
|
||||
|
||||
|
||||
class LegacyInternalMonologue(LettaMessage):
|
||||
"""
|
||||
Representation of an agent's internal monologue.
|
||||
|
||||
Args:
|
||||
internal_monologue (str): The internal monologue of the agent
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
"""
|
||||
|
||||
message_type: Literal["internal_monologue"] = "internal_monologue"
|
||||
internal_monologue: str
|
||||
|
||||
|
||||
LegacyLettaMessage = Union[LegacyInternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn]
|
||||
|
||||
192
letta/schemas/letta_message_content.py
Normal file
192
letta/schemas/letta_message_content.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MessageContentType(str, Enum):
|
||||
text = "text"
|
||||
tool_call = "tool_call"
|
||||
tool_return = "tool_return"
|
||||
reasoning = "reasoning"
|
||||
redacted_reasoning = "redacted_reasoning"
|
||||
omitted_reasoning = "omitted_reasoning"
|
||||
|
||||
|
||||
class MessageContent(BaseModel):
|
||||
type: MessageContentType = Field(..., description="The type of the message.")
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# User Content Types
|
||||
# -------------------------------
|
||||
|
||||
|
||||
class TextContent(MessageContent):
|
||||
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
|
||||
text: str = Field(..., description="The text content of the message.")
|
||||
|
||||
|
||||
LettaUserMessageContentUnion = Annotated[
|
||||
Union[TextContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def create_letta_user_message_content_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
{"$ref": "#/components/schemas/TextContent"},
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"text": "#/components/schemas/TextContent",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_letta_user_message_content_union_str_json_schema():
|
||||
return {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/LettaUserMessageContentUnion",
|
||||
},
|
||||
},
|
||||
{"type": "string"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Assistant Content Types
|
||||
# -------------------------------
|
||||
|
||||
|
||||
LettaAssistantMessageContentUnion = Annotated[
|
||||
Union[TextContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def create_letta_assistant_message_content_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
{"$ref": "#/components/schemas/TextContent"},
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"text": "#/components/schemas/TextContent",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_letta_assistant_message_content_union_str_json_schema():
|
||||
return {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/LettaAssistantMessageContentUnion",
|
||||
},
|
||||
},
|
||||
{"type": "string"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Intermediate Step Content Types
|
||||
# -------------------------------
|
||||
|
||||
|
||||
class ToolCallContent(MessageContent):
|
||||
type: Literal[MessageContentType.tool_call] = Field(
|
||||
MessageContentType.tool_call, description="Indicates this content represents a tool call event."
|
||||
)
|
||||
id: str = Field(..., description="A unique identifier for this specific tool call instance.")
|
||||
name: str = Field(..., description="The name of the tool being called.")
|
||||
input: dict = Field(
|
||||
..., description="The parameters being passed to the tool, structured as a dictionary of parameter names to values."
|
||||
)
|
||||
|
||||
|
||||
class ToolReturnContent(MessageContent):
|
||||
type: Literal[MessageContentType.tool_return] = Field(
|
||||
MessageContentType.tool_return, description="Indicates this content represents a tool return event."
|
||||
)
|
||||
tool_call_id: str = Field(..., description="References the ID of the ToolCallContent that initiated this tool call.")
|
||||
content: str = Field(..., description="The content returned by the tool execution.")
|
||||
is_error: bool = Field(..., description="Indicates whether the tool execution resulted in an error.")
|
||||
|
||||
|
||||
class ReasoningContent(MessageContent):
|
||||
type: Literal[MessageContentType.reasoning] = Field(
|
||||
MessageContentType.reasoning, description="Indicates this is a reasoning/intermediate step."
|
||||
)
|
||||
is_native: bool = Field(..., description="Whether the reasoning content was generated by a reasoner model that processed this step.")
|
||||
reasoning: str = Field(..., description="The intermediate reasoning or thought process content.")
|
||||
signature: Optional[str] = Field(None, description="A unique identifier for this reasoning step.")
|
||||
|
||||
|
||||
class RedactedReasoningContent(MessageContent):
|
||||
type: Literal[MessageContentType.redacted_reasoning] = Field(
|
||||
MessageContentType.redacted_reasoning, description="Indicates this is a redacted thinking step."
|
||||
)
|
||||
data: str = Field(..., description="The redacted or filtered intermediate reasoning content.")
|
||||
|
||||
|
||||
class OmittedReasoningContent(MessageContent):
|
||||
type: Literal[MessageContentType.omitted_reasoning] = Field(
|
||||
MessageContentType.omitted_reasoning, description="Indicates this is an omitted reasoning step."
|
||||
)
|
||||
tokens: int = Field(..., description="The reasoning token count for intermediate reasoning content.")
|
||||
|
||||
|
||||
LettaMessageContentUnion = Annotated[
|
||||
Union[TextContent, ToolCallContent, ToolReturnContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def create_letta_message_content_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
{"$ref": "#/components/schemas/TextContent"},
|
||||
{"$ref": "#/components/schemas/ToolCallContent"},
|
||||
{"$ref": "#/components/schemas/ToolReturnContent"},
|
||||
{"$ref": "#/components/schemas/ReasoningContent"},
|
||||
{"$ref": "#/components/schemas/RedactedReasoningContent"},
|
||||
{"$ref": "#/components/schemas/OmittedReasoningContent"},
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"text": "#/components/schemas/TextContent",
|
||||
"tool_call": "#/components/schemas/ToolCallContent",
|
||||
"tool_return": "#/components/schemas/ToolCallContent",
|
||||
"reasoning": "#/components/schemas/ReasoningContent",
|
||||
"redacted_reasoning": "#/components/schemas/RedactedReasoningContent",
|
||||
"omitted_reasoning": "#/components/schemas/OmittedReasoningContent",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_letta_message_content_union_str_json_schema():
|
||||
return {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/LettaMessageContentUnion",
|
||||
},
|
||||
},
|
||||
{"type": "string"},
|
||||
],
|
||||
}
|
||||
@@ -9,26 +9,25 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
|
||||
from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.schemas.enums import MessageContentType, MessageRole
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
LettaMessage,
|
||||
MessageContentUnion,
|
||||
ReasoningMessage,
|
||||
SystemMessage,
|
||||
TextContent,
|
||||
ToolCall,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.letta_message_content import LettaMessageContentUnion, TextContent, get_letta_message_content_union_str_json_schema
|
||||
from letta.system import unpack_message
|
||||
|
||||
|
||||
@@ -66,15 +65,30 @@ class MessageCreate(BaseModel):
|
||||
MessageRole.user,
|
||||
MessageRole.system,
|
||||
] = Field(..., description="The role of the participant.")
|
||||
content: Union[str, List[MessageContentUnion]] = Field(..., description="The content of the message.")
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The content of the message.",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||
|
||||
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
|
||||
data = super().model_dump(**kwargs)
|
||||
if to_orm and "content" in data:
|
||||
if isinstance(data["content"], str):
|
||||
data["content"] = [TextContent(text=data["content"])]
|
||||
return data
|
||||
|
||||
|
||||
class MessageUpdate(BaseModel):
|
||||
"""Request to update a message"""
|
||||
|
||||
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
|
||||
content: Optional[Union[str, List[MessageContentUnion]]] = Field(None, description="The content of the message.")
|
||||
content: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(
|
||||
None,
|
||||
description="The content of the message.",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
|
||||
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
|
||||
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
@@ -90,12 +104,7 @@ class MessageUpdate(BaseModel):
|
||||
data = super().model_dump(**kwargs)
|
||||
if to_orm and "content" in data:
|
||||
if isinstance(data["content"], str):
|
||||
data["text"] = data["content"]
|
||||
else:
|
||||
for content in data["content"]:
|
||||
if content["type"] == "text":
|
||||
data["text"] = content["text"]
|
||||
del data["content"]
|
||||
data["content"] = [TextContent(text=data["content"])]
|
||||
return data
|
||||
|
||||
|
||||
@@ -119,7 +128,7 @@ class Message(BaseMessage):
|
||||
|
||||
id: str = BaseMessage.generate_id_field()
|
||||
role: MessageRole = Field(..., description="The role of the participant.")
|
||||
content: Optional[List[MessageContentUnion]] = Field(None, description="The content of the message.")
|
||||
content: Optional[List[LettaMessageContentUnion]] = Field(None, description="The content of the message.")
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
|
||||
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
model: Optional[str] = Field(None, description="The model used to make the function call.")
|
||||
@@ -129,6 +138,7 @@ class Message(BaseMessage):
|
||||
step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")
|
||||
otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
|
||||
tool_returns: Optional[List[ToolReturn]] = Field(None, description="Tool execution return information for prior tool calls")
|
||||
group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")
|
||||
|
||||
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||
@@ -140,24 +150,6 @@ class Message(BaseMessage):
|
||||
assert v in roles, f"Role must be one of {roles}"
|
||||
return v
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def convert_from_orm(cls, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if isinstance(data, dict):
|
||||
if "text" in data and "content" not in data:
|
||||
data["content"] = [TextContent(text=data["text"])]
|
||||
del data["text"]
|
||||
return data
|
||||
|
||||
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
|
||||
data = super().model_dump(**kwargs)
|
||||
if to_orm:
|
||||
for content in data["content"]:
|
||||
if content["type"] == "text":
|
||||
data["text"] = content["text"]
|
||||
del data["content"]
|
||||
return data
|
||||
|
||||
def to_json(self):
|
||||
json_message = vars(self)
|
||||
if json_message["tool_calls"] is not None:
|
||||
@@ -214,7 +206,7 @@ class Message(BaseMessage):
|
||||
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
) -> List[LettaMessage]:
|
||||
"""Convert message object (in DB format) to the style used by the original Letta API"""
|
||||
if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text:
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
text_content = None
|
||||
@@ -485,7 +477,7 @@ class Message(BaseMessage):
|
||||
"""Go from Message class to ChatCompletion message object"""
|
||||
|
||||
# TODO change to pydantic casting, eg `return SystemMessageModel(self)`
|
||||
if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text:
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
text_content = None
|
||||
@@ -560,7 +552,7 @@ class Message(BaseMessage):
|
||||
Args:
|
||||
inner_thoughts_xml_tag (str): The XML tag to wrap around inner thoughts
|
||||
"""
|
||||
if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text:
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
text_content = None
|
||||
@@ -655,7 +647,7 @@ class Message(BaseMessage):
|
||||
# type Content: https://ai.google.dev/api/rest/v1/Content / https://ai.google.dev/api/rest/v1beta/Content
|
||||
# parts[]: Part
|
||||
# role: str ('user' or 'model')
|
||||
if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text:
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
text_content = None
|
||||
@@ -781,7 +773,7 @@ class Message(BaseMessage):
|
||||
|
||||
# TODO: update this prompt style once guidance from Cohere on
|
||||
# embedded function calls in multi-turn conversation become more clear
|
||||
if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text:
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
text_content = None
|
||||
|
||||
@@ -1 +1 @@
|
||||
from letta.serialize_schemas.agent import SerializedAgentSchema
|
||||
from letta.serialize_schemas.marshmallow_agent import MarshmallowAgentSchema
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
from typing import Dict
|
||||
|
||||
from marshmallow import fields, post_dump
|
||||
|
||||
from letta.orm import Agent
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.agent_environment_variable import SerializedAgentEnvironmentVariableSchema
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.block import SerializedBlockSchema
|
||||
from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField
|
||||
from letta.serialize_schemas.message import SerializedMessageSchema
|
||||
from letta.serialize_schemas.tag import SerializedAgentTagSchema
|
||||
from letta.serialize_schemas.tool import SerializedToolSchema
|
||||
from letta.server.db import SessionLocal
|
||||
|
||||
|
||||
class SerializedAgentSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Agent objects.
|
||||
Excludes relational fields.
|
||||
"""
|
||||
|
||||
__pydantic_model__ = PydanticAgentState
|
||||
|
||||
llm_config = LLMConfigField()
|
||||
embedding_config = EmbeddingConfigField()
|
||||
tool_rules = ToolRulesField()
|
||||
|
||||
messages = fields.List(fields.Nested(SerializedMessageSchema))
|
||||
core_memory = fields.List(fields.Nested(SerializedBlockSchema))
|
||||
tools = fields.List(fields.Nested(SerializedToolSchema))
|
||||
tool_exec_environment_variables = fields.List(fields.Nested(SerializedAgentEnvironmentVariableSchema))
|
||||
tags = fields.List(fields.Nested(SerializedAgentTagSchema))
|
||||
|
||||
def __init__(self, *args, session: SessionLocal, actor: User, **kwargs):
|
||||
super().__init__(*args, actor=actor, **kwargs)
|
||||
self.session = session
|
||||
|
||||
# Propagate session and actor to nested schemas automatically
|
||||
for field in self.fields.values():
|
||||
if isinstance(field, fields.List) and isinstance(field.inner, fields.Nested):
|
||||
field.inner.schema.session = session
|
||||
field.inner.schema.actor = actor
|
||||
elif isinstance(field, fields.Nested):
|
||||
field.schema.session = session
|
||||
field.schema.actor = actor
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs):
|
||||
data = super().sanitize_ids(data, **kwargs)
|
||||
|
||||
# Remap IDs of messages
|
||||
# Need to do this in post, so we can correctly map the in-context message IDs
|
||||
# TODO: Remap message_ids to reference objects, not just be a list
|
||||
id_remapping = dict()
|
||||
for message in data.get("messages"):
|
||||
message_id = message.get("id")
|
||||
if message_id not in id_remapping:
|
||||
id_remapping[message_id] = SerializedMessageSchema.__pydantic_model__.generate_id()
|
||||
message["id"] = id_remapping[message_id]
|
||||
else:
|
||||
raise ValueError(f"Duplicate message IDs in agent.messages: {message_id}")
|
||||
|
||||
# Remap in context message ids
|
||||
data["message_ids"] = [id_remapping[message_id] for message_id in data.get("message_ids")]
|
||||
|
||||
return data
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Agent
|
||||
# TODO: Serialize these as well...
|
||||
exclude = BaseSchema.Meta.exclude + (
|
||||
"project_id",
|
||||
"template_id",
|
||||
"base_template_id",
|
||||
"sources",
|
||||
"source_passages",
|
||||
"agent_passages",
|
||||
)
|
||||
@@ -1,64 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from marshmallow import post_dump, pre_load
|
||||
from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
|
||||
from sqlalchemy.inspection import inspect
|
||||
|
||||
from letta.schemas.user import User
|
||||
|
||||
|
||||
class BaseSchema(SQLAlchemyAutoSchema):
|
||||
"""
|
||||
Base schema for all SQLAlchemy models.
|
||||
This ensures all schemas share the same session.
|
||||
"""
|
||||
|
||||
__pydantic_model__ = None
|
||||
sensitive_ids = {"_created_by_id", "_last_updated_by_id"}
|
||||
sensitive_relationships = {"organization"}
|
||||
id_scramble_placeholder = "xxx"
|
||||
|
||||
def __init__(self, *args, actor: Optional[User] = None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.actor = actor
|
||||
|
||||
def generate_id(self) -> Optional[str]:
|
||||
if self.__pydantic_model__:
|
||||
return self.__pydantic_model__.generate_id()
|
||||
|
||||
return None
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
if self.Meta.model:
|
||||
mapper = inspect(self.Meta.model)
|
||||
if "id" in mapper.columns:
|
||||
generated_id = self.generate_id()
|
||||
if generated_id:
|
||||
data["id"] = generated_id
|
||||
|
||||
for sensitive_id in BaseSchema.sensitive_ids.union(BaseSchema.sensitive_relationships):
|
||||
if sensitive_id in data:
|
||||
data[sensitive_id] = BaseSchema.id_scramble_placeholder
|
||||
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
if self.Meta.model:
|
||||
mapper = inspect(self.Meta.model)
|
||||
for sensitive_id in BaseSchema.sensitive_ids:
|
||||
if sensitive_id in mapper.columns:
|
||||
data[sensitive_id] = self.actor.id
|
||||
|
||||
for relationship in BaseSchema.sensitive_relationships:
|
||||
if relationship in mapper.relationships:
|
||||
data[relationship] = self.actor.organization_id
|
||||
|
||||
return data
|
||||
|
||||
class Meta:
|
||||
model = None
|
||||
include_relationships = True
|
||||
load_instance = True
|
||||
exclude = ()
|
||||
108
letta/serialize_schemas/marshmallow_agent.py
Normal file
108
letta/serialize_schemas/marshmallow_agent.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from typing import Dict
|
||||
|
||||
from marshmallow import fields, post_dump, pre_load
|
||||
|
||||
import letta
|
||||
from letta.orm import Agent
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.marshmallow_agent_environment_variable import SerializedAgentEnvironmentVariableSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_block import SerializedBlockSchema
|
||||
from letta.serialize_schemas.marshmallow_custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField
|
||||
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
|
||||
from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema
|
||||
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
|
||||
from letta.server.db import SessionLocal
|
||||
|
||||
|
||||
class MarshmallowAgentSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Agent objects.
|
||||
Excludes relational fields.
|
||||
"""
|
||||
|
||||
__pydantic_model__ = PydanticAgentState
|
||||
|
||||
FIELD_VERSION = "version"
|
||||
FIELD_MESSAGES = "messages"
|
||||
FIELD_MESSAGE_IDS = "message_ids"
|
||||
FIELD_IN_CONTEXT = "in_context"
|
||||
FIELD_ID = "id"
|
||||
|
||||
llm_config = LLMConfigField()
|
||||
embedding_config = EmbeddingConfigField()
|
||||
tool_rules = ToolRulesField()
|
||||
|
||||
messages = fields.List(fields.Nested(SerializedMessageSchema))
|
||||
core_memory = fields.List(fields.Nested(SerializedBlockSchema))
|
||||
tools = fields.List(fields.Nested(SerializedToolSchema))
|
||||
tool_exec_environment_variables = fields.List(fields.Nested(SerializedAgentEnvironmentVariableSchema))
|
||||
tags = fields.List(fields.Nested(SerializedAgentTagSchema))
|
||||
|
||||
def __init__(self, *args, session: SessionLocal, actor: User, **kwargs):
|
||||
super().__init__(*args, actor=actor, **kwargs)
|
||||
self.session = session
|
||||
|
||||
# Propagate session and actor to nested schemas automatically
|
||||
for field in self.fields.values():
|
||||
if isinstance(field, fields.List) and isinstance(field.inner, fields.Nested):
|
||||
field.inner.schema.session = session
|
||||
field.inner.schema.actor = actor
|
||||
elif isinstance(field, fields.Nested):
|
||||
field.schema.session = session
|
||||
field.schema.actor = actor
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs):
|
||||
"""
|
||||
- Removes `message_ids`
|
||||
- Adds versioning
|
||||
- Marks messages as in-context
|
||||
- Removes individual message `id` fields
|
||||
"""
|
||||
data = super().sanitize_ids(data, **kwargs)
|
||||
data[self.FIELD_VERSION] = letta.__version__
|
||||
|
||||
message_ids = set(data.pop(self.FIELD_MESSAGE_IDS, [])) # Store and remove message_ids
|
||||
|
||||
for message in data.get(self.FIELD_MESSAGES, []):
|
||||
message[self.FIELD_IN_CONTEXT] = message[self.FIELD_ID] in message_ids # Mark messages as in-context
|
||||
message.pop(self.FIELD_ID, None) # Remove the id field
|
||||
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def check_version(self, data, **kwargs):
|
||||
"""Check version and remove it from the schema"""
|
||||
version = data[self.FIELD_VERSION]
|
||||
if version != letta.__version__:
|
||||
print(f"Version mismatch: expected {letta.__version__}, got {version}")
|
||||
del data[self.FIELD_VERSION]
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def remap_in_context_messages(self, data, **kwargs):
|
||||
"""
|
||||
Restores `message_ids` by collecting message IDs where `in_context` is True,
|
||||
generates new IDs for all messages, and removes `in_context` from all messages.
|
||||
"""
|
||||
message_ids = []
|
||||
for msg in data.get(self.FIELD_MESSAGES, []):
|
||||
msg[self.FIELD_ID] = SerializedMessageSchema.generate_id() # Generate new ID
|
||||
if msg.pop(self.FIELD_IN_CONTEXT, False): # If it was in-context, track its new ID
|
||||
message_ids.append(msg[self.FIELD_ID])
|
||||
|
||||
data[self.FIELD_MESSAGE_IDS] = message_ids
|
||||
return data
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Agent
|
||||
exclude = BaseSchema.Meta.exclude + (
|
||||
"project_id",
|
||||
"template_id",
|
||||
"base_template_id",
|
||||
"sources",
|
||||
"source_passages",
|
||||
"agent_passages",
|
||||
)
|
||||
@@ -2,7 +2,7 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedAgentEnvironmentVariableSchema(BaseSchema):
|
||||
52
letta/serialize_schemas/marshmallow_base.py
Normal file
52
letta/serialize_schemas/marshmallow_base.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from marshmallow import post_dump, pre_load
|
||||
from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
|
||||
|
||||
from letta.schemas.user import User
|
||||
|
||||
|
||||
class BaseSchema(SQLAlchemyAutoSchema):
|
||||
"""
|
||||
Base schema for all SQLAlchemy models.
|
||||
This ensures all schemas share the same session.
|
||||
"""
|
||||
|
||||
__pydantic_model__ = None
|
||||
|
||||
def __init__(self, *args, actor: Optional[User] = None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.actor = actor
|
||||
|
||||
@classmethod
|
||||
def generate_id(cls) -> Optional[str]:
|
||||
if cls.__pydantic_model__:
|
||||
return cls.__pydantic_model__.generate_id()
|
||||
|
||||
return None
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
# delete id
|
||||
del data["id"]
|
||||
del data["_created_by_id"]
|
||||
del data["_last_updated_by_id"]
|
||||
del data["organization"]
|
||||
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
if self.Meta.model:
|
||||
data["id"] = self.generate_id()
|
||||
data["_created_by_id"] = self.actor.id
|
||||
data["_last_updated_by_id"] = self.actor.id
|
||||
data["organization"] = self.actor.organization_id
|
||||
|
||||
return data
|
||||
|
||||
class Meta:
|
||||
model = None
|
||||
include_relationships = True
|
||||
load_instance = True
|
||||
exclude = ()
|
||||
@@ -1,6 +1,6 @@
|
||||
from letta.orm.block import Block
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedBlockSchema(BaseSchema):
|
||||
@@ -3,10 +3,12 @@ from marshmallow import fields
|
||||
from letta.helpers.converters import (
|
||||
deserialize_embedding_config,
|
||||
deserialize_llm_config,
|
||||
deserialize_message_content,
|
||||
deserialize_tool_calls,
|
||||
deserialize_tool_rules,
|
||||
serialize_embedding_config,
|
||||
serialize_llm_config,
|
||||
serialize_message_content,
|
||||
serialize_tool_calls,
|
||||
serialize_tool_rules,
|
||||
)
|
||||
@@ -67,3 +69,13 @@ class ToolCallField(fields.Field):
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_tool_calls(value)
|
||||
|
||||
|
||||
class MessageContentField(fields.Field):
|
||||
"""Marshmallow field for handling a list of Message Content Part objects."""
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return serialize_message_content(value)
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_message_content(value)
|
||||
42
letta/serialize_schemas/marshmallow_message.py
Normal file
42
letta/serialize_schemas/marshmallow_message.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Dict
|
||||
|
||||
from marshmallow import post_dump, pre_load
|
||||
|
||||
from letta.orm.message import Message
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_custom_fields import ToolCallField
|
||||
|
||||
|
||||
class SerializedMessageSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Message objects.
|
||||
"""
|
||||
|
||||
__pydantic_model__ = PydanticMessage
|
||||
|
||||
tool_calls = ToolCallField()
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
# keep id for remapping later on agent dump
|
||||
# agent dump will then get rid of message ids
|
||||
del data["_created_by_id"]
|
||||
del data["_last_updated_by_id"]
|
||||
del data["organization"]
|
||||
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
if self.Meta.model:
|
||||
# Skip regenerating ID, as agent dump will do it
|
||||
data["_created_by_id"] = self.actor.id
|
||||
data["_last_updated_by_id"] = self.actor.id
|
||||
data["organization"] = self.actor.organization_id
|
||||
|
||||
return data
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Message
|
||||
exclude = BaseSchema.Meta.exclude + ("step", "job_message", "agent", "otid", "is_deleted")
|
||||
@@ -1,7 +1,9 @@
|
||||
from marshmallow import fields
|
||||
from typing import Dict
|
||||
|
||||
from marshmallow import fields, post_dump, pre_load
|
||||
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedAgentTagSchema(BaseSchema):
|
||||
@@ -13,6 +15,14 @@ class SerializedAgentTagSchema(BaseSchema):
|
||||
|
||||
tag = fields.String(required=True)
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs):
|
||||
return data
|
||||
|
||||
@pre_load
|
||||
def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
|
||||
return data
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = AgentsTags
|
||||
exclude = BaseSchema.Meta.exclude + ("agent",)
|
||||
@@ -1,6 +1,6 @@
|
||||
from letta.orm import Tool
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedToolSchema(BaseSchema):
|
||||
@@ -1,29 +0,0 @@
|
||||
from typing import Dict
|
||||
|
||||
from marshmallow import post_dump
|
||||
|
||||
from letta.orm.message import Message
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.custom_fields import ToolCallField
|
||||
|
||||
|
||||
class SerializedMessageSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Message objects.
|
||||
"""
|
||||
|
||||
__pydantic_model__ = PydanticMessage
|
||||
|
||||
tool_calls = ToolCallField()
|
||||
|
||||
@post_dump
|
||||
def sanitize_ids(self, data: Dict, **kwargs):
|
||||
# We don't want to remap here
|
||||
# Because of the way that message_ids is just a JSON field on agents
|
||||
# We need to wait for the agent dumps, and then keep track of all the message IDs we remapped
|
||||
return data
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Message
|
||||
exclude = BaseSchema.Meta.exclude + ("step", "job_message", "agent")
|
||||
111
letta/serialize_schemas/pydantic_agent_schema.py
Normal file
111
letta/serialize_schemas/pydantic_agent_schema.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
|
||||
class CoreMemoryBlockSchema(BaseModel):
|
||||
created_at: str
|
||||
description: Optional[str]
|
||||
identities: List[Any]
|
||||
is_deleted: bool
|
||||
is_template: bool
|
||||
label: str
|
||||
limit: int
|
||||
metadata_: Dict[str, Any] = Field(default_factory=dict)
|
||||
template_name: Optional[str]
|
||||
updated_at: str
|
||||
value: str
|
||||
|
||||
|
||||
class MessageSchema(BaseModel):
|
||||
created_at: str
|
||||
group_id: Optional[str]
|
||||
in_context: bool
|
||||
model: Optional[str]
|
||||
name: Optional[str]
|
||||
role: str
|
||||
content: List[TextContent] # TODO: Expand to more in the future
|
||||
tool_call_id: Optional[str]
|
||||
tool_calls: List[Any]
|
||||
tool_returns: List[Any]
|
||||
updated_at: str
|
||||
|
||||
|
||||
class TagSchema(BaseModel):
|
||||
tag: str
|
||||
|
||||
|
||||
class ToolEnvVarSchema(BaseModel):
|
||||
created_at: str
|
||||
description: Optional[str]
|
||||
is_deleted: bool
|
||||
key: str
|
||||
updated_at: str
|
||||
value: str
|
||||
|
||||
|
||||
class ToolRuleSchema(BaseModel):
|
||||
tool_name: str
|
||||
type: str
|
||||
|
||||
|
||||
class ParameterProperties(BaseModel):
|
||||
type: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ParametersSchema(BaseModel):
|
||||
type: Optional[str] = "object"
|
||||
properties: Dict[str, ParameterProperties]
|
||||
required: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolJSONSchema(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: ParametersSchema # <— nested strong typing
|
||||
type: Optional[str] = None # top-level 'type' if it exists
|
||||
required: Optional[List[str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolSchema(BaseModel):
|
||||
args_json_schema: Optional[Any]
|
||||
created_at: str
|
||||
description: str
|
||||
is_deleted: bool
|
||||
json_schema: ToolJSONSchema
|
||||
name: str
|
||||
return_char_limit: int
|
||||
source_code: Optional[str]
|
||||
source_type: str
|
||||
tags: List[str]
|
||||
tool_type: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class AgentSchema(BaseModel):
|
||||
agent_type: str
|
||||
core_memory: List[CoreMemoryBlockSchema]
|
||||
created_at: str
|
||||
description: str
|
||||
embedding_config: EmbeddingConfig
|
||||
groups: List[Any]
|
||||
identities: List[Any]
|
||||
is_deleted: bool
|
||||
llm_config: LLMConfig
|
||||
message_buffer_autoclear: bool
|
||||
messages: List[MessageSchema]
|
||||
metadata_: Dict
|
||||
multi_agent_group: Optional[Any]
|
||||
name: str
|
||||
system: str
|
||||
tags: List[TagSchema]
|
||||
tool_exec_environment_variables: List[ToolEnvVarSchema]
|
||||
tool_rules: List[ToolRuleSchema]
|
||||
tools: List[ToolSchema]
|
||||
updated_at: str
|
||||
version: str
|
||||
@@ -17,6 +17,11 @@ from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaU
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||
from letta.schemas.letta_message import create_letta_message_union_schema
|
||||
from letta.schemas.letta_message_content import (
|
||||
create_letta_assistant_message_content_union_schema,
|
||||
create_letta_message_content_union_schema,
|
||||
create_letta_user_message_content_union_schema,
|
||||
)
|
||||
from letta.server.constants import REST_DEFAULT_PORT
|
||||
|
||||
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
||||
@@ -68,6 +73,10 @@ def generate_openapi_schema(app: FastAPI):
|
||||
letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")}
|
||||
letta_docs["info"]["title"] = "Letta API"
|
||||
letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema()
|
||||
letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema()
|
||||
letta_docs["components"]["schemas"]["LettaAssistantMessageContentUnion"] = create_letta_assistant_message_content_union_schema()
|
||||
letta_docs["components"]["schemas"]["LettaUserMessageContentUnion"] = create_letta_user_message_content_union_schema()
|
||||
|
||||
for name, docs in [
|
||||
(
|
||||
"letta",
|
||||
@@ -320,6 +329,9 @@ def start_server(
|
||||
app,
|
||||
host=host or "localhost",
|
||||
port=port or REST_DEFAULT_PORT,
|
||||
workers=settings.uvicorn_workers,
|
||||
reload=settings.uvicorn_reload,
|
||||
timeout_keep_alive=settings.uvicorn_timeout_keep_alive,
|
||||
ssl_keyfile="certs/localhost-key.pem",
|
||||
ssl_certfile="certs/localhost.pem",
|
||||
)
|
||||
@@ -336,4 +348,7 @@ def start_server(
|
||||
app,
|
||||
host=host or "localhost",
|
||||
port=port or REST_DEFAULT_PORT,
|
||||
workers=settings.uvicorn_workers,
|
||||
reload=settings.uvicorn_reload,
|
||||
timeout_keep_alive=settings.uvicorn_timeout_keep_alive,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from letta.server.rest_api.routers.v1.agents import router as agents_router
|
||||
from letta.server.rest_api.routers.v1.blocks import router as blocks_router
|
||||
from letta.server.rest_api.routers.v1.groups import router as groups_router
|
||||
from letta.server.rest_api.routers.v1.health import router as health_router
|
||||
from letta.server.rest_api.routers.v1.identities import router as identities_router
|
||||
from letta.server.rest_api.routers.v1.jobs import router as jobs_router
|
||||
@@ -17,6 +18,7 @@ ROUTERS = [
|
||||
tools_router,
|
||||
sources_router,
|
||||
agents_router,
|
||||
groups_router,
|
||||
identities_router,
|
||||
llm_router,
|
||||
blocks_router,
|
||||
|
||||
@@ -24,6 +24,7 @@ from letta.schemas.run import Run
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
@@ -35,81 +36,82 @@ router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This should be paginated
|
||||
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
|
||||
def list_agents(
|
||||
name: Optional[str] = Query(None, description="Name of the agent"),
|
||||
tags: Optional[List[str]] = Query(None, description="List of tags to filter agents by"),
|
||||
match_all_tags: bool = Query(
|
||||
False,
|
||||
description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed in tags.",
|
||||
description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed-in tags.",
|
||||
),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
before: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(None, description="Limit for pagination"),
|
||||
limit: Optional[int] = Query(50, description="Limit for pagination"),
|
||||
query_text: Optional[str] = Query(None, description="Search agents by name"),
|
||||
project_id: Optional[str] = Query(None, description="Search agents by project id"),
|
||||
template_id: Optional[str] = Query(None, description="Search agents by template id"),
|
||||
base_template_id: Optional[str] = Query(None, description="Search agents by base template id"),
|
||||
identity_id: Optional[str] = Query(None, description="Search agents by identifier id"),
|
||||
project_id: Optional[str] = Query(None, description="Search agents by project ID"),
|
||||
template_id: Optional[str] = Query(None, description="Search agents by template ID"),
|
||||
base_template_id: Optional[str] = Query(None, description="Search agents by base template ID"),
|
||||
identity_id: Optional[str] = Query(None, description="Search agents by identity ID"),
|
||||
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
|
||||
include_relationships: Optional[List[str]] = Query(
|
||||
None,
|
||||
description=(
|
||||
"Specify which relational fields (e.g., 'tools', 'sources', 'memory') to include in the response. "
|
||||
"If not provided, all relationships are loaded by default. "
|
||||
"Using this can optimize performance by reducing unnecessary joins."
|
||||
),
|
||||
),
|
||||
):
|
||||
"""
|
||||
List all agents associated with a given user.
|
||||
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
|
||||
|
||||
This endpoint retrieves a list of all agents and their configurations
|
||||
associated with the specified user ID.
|
||||
"""
|
||||
|
||||
# Retrieve the actor (user) details
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# Use dictionary comprehension to build kwargs dynamically
|
||||
kwargs = {
|
||||
key: value
|
||||
for key, value in {
|
||||
"name": name,
|
||||
"project_id": project_id,
|
||||
"template_id": template_id,
|
||||
"base_template_id": base_template_id,
|
||||
}.items()
|
||||
if value is not None
|
||||
}
|
||||
|
||||
# Call list_agents with the dynamic kwargs
|
||||
agents = server.agent_manager.list_agents(
|
||||
# Call list_agents directly without unnecessary dict handling
|
||||
return server.agent_manager.list_agents(
|
||||
actor=actor,
|
||||
name=name,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
query_text=query_text,
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
identifier_keys=identifier_keys,
|
||||
project_id=project_id,
|
||||
template_id=template_id,
|
||||
base_template_id=base_template_id,
|
||||
identity_id=identity_id,
|
||||
**kwargs,
|
||||
identifier_keys=identifier_keys,
|
||||
include_relationships=include_relationships,
|
||||
)
|
||||
return agents
|
||||
|
||||
|
||||
@router.get("/{agent_id}/download", operation_id="download_agent_serialized")
|
||||
def download_agent_serialized(
|
||||
@router.get("/{agent_id}/export", operation_id="export_agent_serialized", response_model=AgentSchema)
|
||||
def export_agent_serialized(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
) -> AgentSchema:
|
||||
"""
|
||||
Download the serialized JSON representation of an agent.
|
||||
Export the serialized JSON representation of an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
serialized_agent = server.agent_manager.serialize(agent_id=agent_id, actor=actor)
|
||||
return JSONResponse(content=serialized_agent, media_type="application/json")
|
||||
return server.agent_manager.serialize(agent_id=agent_id, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
|
||||
|
||||
|
||||
@router.post("/upload", response_model=AgentState, operation_id="upload_agent_serialized")
|
||||
async def upload_agent_serialized(
|
||||
@router.post("/import", response_model=AgentState, operation_id="import_agent_serialized")
|
||||
async def import_agent_serialized(
|
||||
file: UploadFile = File(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@@ -121,15 +123,19 @@ async def upload_agent_serialized(
|
||||
project_id: Optional[str] = Query(None, description="The project ID to associate the uploaded agent with."),
|
||||
):
|
||||
"""
|
||||
Upload a serialized agent JSON file and recreate the agent in the system.
|
||||
Import a serialized agent file and recreate the agent in the system.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
serialized_data = await file.read()
|
||||
agent_json = json.loads(serialized_data)
|
||||
|
||||
# Validate the JSON against AgentSchema before passing it to deserialize
|
||||
agent_schema = AgentSchema.model_validate(agent_json)
|
||||
|
||||
new_agent = server.agent_manager.deserialize(
|
||||
serialized_agent=agent_json,
|
||||
serialized_agent=agent_schema, # Ensure we're passing a validated AgentSchema
|
||||
actor=actor,
|
||||
append_copy_suffix=append_copy_suffix,
|
||||
override_existing_tools=override_existing_tools,
|
||||
@@ -141,7 +147,7 @@ async def upload_agent_serialized(
|
||||
raise HTTPException(status_code=400, detail="Corrupted agent file format.")
|
||||
|
||||
except ValidationError as e:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {str(e)}")
|
||||
raise HTTPException(status_code=422, detail=f"Invalid agent schema: {e.errors()}")
|
||||
|
||||
except IntegrityError as e:
|
||||
raise HTTPException(status_code=409, detail=f"Database integrity error: {str(e)}")
|
||||
@@ -149,9 +155,9 @@ async def upload_agent_serialized(
|
||||
except OperationalError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Database connection error. Please try again later: {str(e)}")
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail="An unexpected error occurred while uploading the agent.")
|
||||
raise HTTPException(status_code=500, detail=f"An unexpected error occurred while uploading the agent: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{agent_id}/context", response_model=ContextWindowOverview, operation_id="retrieve_agent_context_window")
|
||||
@@ -530,7 +536,7 @@ def list_messages(
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUpdateUnion, operation_id="modify_message")
|
||||
@router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUnion, operation_id="modify_message")
|
||||
def modify_message(
|
||||
agent_id: str,
|
||||
message_id: str,
|
||||
|
||||
233
letta/server/rest_api/routers/v1/groups.py
Normal file
233
letta/server/rest_api/routers/v1/groups.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, Query
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.group import Group, GroupCreate, ManagerType
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
router = APIRouter(prefix="/groups", tags=["groups"])
|
||||
|
||||
|
||||
@router.post("/", response_model=Group, operation_id="create_group")
|
||||
async def create_group(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: GroupCreate = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a multi-agent group with a specified management pattern. When no
|
||||
management config is specified, this endpoint will use round robin for
|
||||
speaker selection.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.group_manager.create_group(request, actor=actor)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Group], operation_id="list_groups")
|
||||
def list_groups(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
manager_type: Optional[ManagerType] = Query(None, description="Search groups by manager type"),
|
||||
before: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(None, description="Limit for pagination"),
|
||||
project_id: Optional[str] = Query(None, description="Search groups by project id"),
|
||||
):
|
||||
"""
|
||||
Fetch all multi-agent groups matching query.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.group_manager.list_groups(
|
||||
project_id=project_id,
|
||||
manager_type=manager_type,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/", response_model=Group, operation_id="create_group")
|
||||
def create_group(
|
||||
group: GroupCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
"""
|
||||
Create a new multi-agent group with the specified configuration.
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.group_manager.create_group(group, actor=actor)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/", response_model=Group, operation_id="upsert_group")
|
||||
def upsert_group(
|
||||
group: GroupCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
"""
|
||||
Create a new multi-agent group with the specified configuration.
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.group_manager.create_group(group, actor=actor)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{group_id}", response_model=None, operation_id="delete_group")
|
||||
def delete_group(
|
||||
group_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a multi-agent group.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
server.group_manager.delete_group(group_id=group_id, actor=actor)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Group id={group_id} successfully deleted"})
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Group id={group_id} not found for user_id={actor.id}.")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{group_id}/messages",
|
||||
response_model=LettaResponse,
|
||||
operation_id="send_group_message",
|
||||
)
|
||||
async def send_group_message(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaRequest = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Process a user message and return the group's response.
|
||||
This endpoint accepts a message from a user and processes it through through agents in the group based on the specified pattern
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
result = await server.send_group_message_to_agent(
|
||||
group_id=group_id,
|
||||
actor=actor,
|
||||
messages=request.messages,
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{group_id}/messages/stream",
|
||||
response_model=None,
|
||||
operation_id="send_group_message_streaming",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"text/event-stream": {"description": "Server-Sent Events stream"},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async def send_group_message_streaming(
|
||||
group_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaStreamingRequest = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Process a user message and return the group's responses.
|
||||
This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern.
|
||||
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
result = await server.send_group_message_to_agent(
|
||||
group_id=group_id,
|
||||
actor=actor,
|
||||
messages=request.messages,
|
||||
stream_steps=True,
|
||||
stream_tokens=request.stream_tokens,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
GroupMessagesResponse = Annotated[
|
||||
List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{group_id}/messages", response_model=GroupMessagesResponse, operation_id="list_group_messages")
|
||||
def list_group_messages(
|
||||
group_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
|
||||
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
|
||||
limit: int = Query(10, description="Maximum number of messages to retrieve."),
|
||||
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
|
||||
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
|
||||
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.group_manager.list_group_messages(
|
||||
group_id=group_id,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
actor=actor,
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
)
|
||||
|
||||
|
||||
'''
|
||||
@router.patch("/{group_id}/reset-messages", response_model=None, operation_id="reset_group_messages")
|
||||
def reset_group_messages(
|
||||
group_id: str,
|
||||
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Resets the messages for all agents that are part of the multi-agent group.
|
||||
TODO: only delete group messages not all messages!
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
group = server.group_manager.retrieve_group(group_id=group_id, actor=actor)
|
||||
agent_ids = group.agent_ids
|
||||
if group.manager_agent_id:
|
||||
agent_ids.append(group.manager_agent_id)
|
||||
for agent_id in agent_ids:
|
||||
server.agent_manager.reset_messages(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
add_default_initial_messages=add_default_initial_messages,
|
||||
)
|
||||
'''
|
||||
@@ -13,7 +13,7 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
|
||||
from letta.errors import LettaToolCreateError
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.mcp_helpers import LocalServerConfig, MCPTool, SSEServerConfig
|
||||
from letta.helpers.mcp_helpers import MCPTool, SSEServerConfig, StdioServerConfig
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import UniqueConstraintViolationError
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
@@ -333,7 +333,7 @@ def add_composio_tool(
|
||||
|
||||
|
||||
# Specific routes for MCP
|
||||
@router.get("/mcp/servers", response_model=dict[str, Union[SSEServerConfig, LocalServerConfig]], operation_id="list_mcp_servers")
|
||||
@router.get("/mcp/servers", response_model=dict[str, Union[SSEServerConfig, StdioServerConfig]], operation_id="list_mcp_servers")
|
||||
def list_mcp_servers(server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id")):
|
||||
"""
|
||||
Get a list of all configured MCP servers
|
||||
@@ -376,7 +376,7 @@ def add_mcp_tool(
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Add a new MCP tool by server + tool name
|
||||
Register a new MCP tool as a Letta server by MCP server + tool name
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
@@ -399,3 +399,31 @@ def add_mcp_tool(
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
return server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, actor=actor)
|
||||
|
||||
|
||||
@router.put("/mcp/servers", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="add_mcp_server")
|
||||
def add_mcp_server_to_config(
|
||||
request: Union[StdioServerConfig, SSEServerConfig] = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Add a new MCP server to the Letta MCP server config
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.add_mcp_server_to_config(server_config=request, allow_upsert=True)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/mcp/servers/{mcp_server_name}", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="delete_mcp_server"
|
||||
)
|
||||
def delete_mcp_server_from_config(
|
||||
mcp_server_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Add a new MCP server to the Letta MCP server config
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.delete_mcp_server_from_config(server_name=mcp_server_name)
|
||||
|
||||
@@ -18,7 +18,7 @@ from letta.errors import ContextWindowExceededError, RateLimitExceededError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
|
||||
@@ -19,16 +19,18 @@ import letta.system as system
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.config import LettaConfig
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.dynamic_multi_agent import DynamicMultiAgent
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.mcp_helpers import (
|
||||
MCP_CONFIG_TOPLEVEL_KEY,
|
||||
BaseMCPClient,
|
||||
LocalMCPClient,
|
||||
LocalServerConfig,
|
||||
MCPServerType,
|
||||
MCPTool,
|
||||
SSEMCPClient,
|
||||
SSEServerConfig,
|
||||
StdioMCPClient,
|
||||
StdioServerConfig,
|
||||
)
|
||||
|
||||
# TODO use custom interface
|
||||
@@ -37,6 +39,7 @@ from letta.interface import CLIInterface # for printing to terminal
|
||||
from letta.log import get_logger
|
||||
from letta.offline_memory_agent import OfflineMemoryAgent
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.round_robin_multi_agent import RoundRobinMultiAgent
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -44,12 +47,14 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
# openai schemas
|
||||
from letta.schemas.enums import JobStatus, MessageStreamStatus
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.group import Group, ManagerType
|
||||
from letta.schemas.job import Job, JobUpdate
|
||||
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolReturnMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate, TextContent
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage, PassageUpdate
|
||||
from letta.schemas.providers import (
|
||||
@@ -80,6 +85,7 @@ from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.server.rest_api.utils import sse_async_generator
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.group_manager import GroupManager
|
||||
from letta.services.identity_manager import IdentityManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
@@ -94,6 +100,7 @@ from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings, settings, tool_settings
|
||||
from letta.supervisor_multi_agent import SupervisorMultiAgent
|
||||
from letta.tracing import trace_method
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
@@ -207,6 +214,7 @@ class SyncServer(Server):
|
||||
self.provider_manager = ProviderManager()
|
||||
self.step_manager = StepManager()
|
||||
self.identity_manager = IdentityManager()
|
||||
self.group_manager = GroupManager()
|
||||
|
||||
# Managers that interface with parallelism
|
||||
self.per_agent_lock_manager = PerAgentLockManager()
|
||||
@@ -331,8 +339,8 @@ class SyncServer(Server):
|
||||
for server_name, server_config in mcp_server_configs.items():
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
self.mcp_clients[server_name] = SSEMCPClient()
|
||||
elif server_config.type == MCPServerType.LOCAL:
|
||||
self.mcp_clients[server_name] = LocalMCPClient()
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
self.mcp_clients[server_name] = StdioMCPClient()
|
||||
else:
|
||||
raise ValueError(f"Invalid MCP server config: {server_config}")
|
||||
try:
|
||||
@@ -353,6 +361,8 @@ class SyncServer(Server):
|
||||
agent_lock = self.per_agent_lock_manager.get_lock(agent_id)
|
||||
with agent_lock:
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
if agent_state.multi_agent_group:
|
||||
return self.load_multi_agent(agent_state.multi_agent_group, actor, interface, agent_state)
|
||||
|
||||
interface = interface or self.default_interface_factory()
|
||||
if agent_state.agent_type == AgentType.memgpt_agent:
|
||||
@@ -364,6 +374,46 @@ class SyncServer(Server):
|
||||
|
||||
return agent
|
||||
|
||||
def load_multi_agent(
|
||||
self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None
|
||||
) -> Agent:
|
||||
match group.manager_type:
|
||||
case ManagerType.round_robin:
|
||||
agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.agent_ids[0], actor=actor)
|
||||
return RoundRobinMultiAgent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
max_turns=group.max_turns,
|
||||
)
|
||||
case ManagerType.dynamic:
|
||||
agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor)
|
||||
return DynamicMultiAgent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
max_turns=group.max_turns,
|
||||
termination_token=group.termination_token,
|
||||
)
|
||||
case ManagerType.supervisor:
|
||||
agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor)
|
||||
return SupervisorMultiAgent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Type {group.manager_type} is not supported.")
|
||||
|
||||
def _step(
|
||||
self,
|
||||
actor: User,
|
||||
@@ -690,7 +740,7 @@ class SyncServer(Server):
|
||||
Message(
|
||||
agent_id=agent_id,
|
||||
role=message.role,
|
||||
content=[TextContent(text=message.content)],
|
||||
content=[TextContent(text=message.content)] if message.content else [],
|
||||
name=message.name,
|
||||
# assigned later?
|
||||
model=None,
|
||||
@@ -800,6 +850,9 @@ class SyncServer(Server):
|
||||
# TODO: @mindy look at moving this to agent_manager to avoid above extra call
|
||||
passages = self.passage_manager.insert_passage(agent_state=agent_state, agent_id=agent_id, text=memory_contents, actor=actor)
|
||||
|
||||
# rebuild agent system prompt - force since no archival change
|
||||
self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True)
|
||||
|
||||
return passages
|
||||
|
||||
def modify_archival_memory(self, agent_id: str, memory_id: str, passage: PassageUpdate, actor: User) -> List[Passage]:
|
||||
@@ -809,10 +862,14 @@ class SyncServer(Server):
|
||||
|
||||
def delete_archival_memory(self, memory_id: str, actor: User):
|
||||
# TODO check if it exists first, and throw error if not
|
||||
# TODO: @mindy make this return the deleted passage instead
|
||||
# TODO: need to also rebuild the prompt here
|
||||
passage = self.passage_manager.get_passage_by_id(passage_id=memory_id, actor=actor)
|
||||
|
||||
# delete the passage
|
||||
self.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
|
||||
|
||||
# TODO: return archival memory
|
||||
# rebuild system prompt and force
|
||||
self.agent_manager.rebuild_system_prompt(agent_id=passage.agent_id, actor=actor, force=True)
|
||||
|
||||
def get_agent_recall(
|
||||
self,
|
||||
@@ -931,6 +988,9 @@ class SyncServer(Server):
|
||||
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
|
||||
assert new_passage_size >= curr_passage_size # in case empty files are added
|
||||
|
||||
# rebuild system prompt and force
|
||||
self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True)
|
||||
|
||||
return job
|
||||
|
||||
def load_data(
|
||||
@@ -1209,7 +1269,7 @@ class SyncServer(Server):
|
||||
|
||||
# MCP wrappers
|
||||
# TODO support both command + SSE servers (via config)
|
||||
def get_mcp_servers(self) -> dict[str, Union[SSEServerConfig, LocalServerConfig]]:
|
||||
def get_mcp_servers(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]:
|
||||
"""List the MCP servers in the config (doesn't test that they are actually working)"""
|
||||
mcp_server_list = {}
|
||||
|
||||
@@ -1227,8 +1287,8 @@ class SyncServer(Server):
|
||||
# Proper formatting is "mcpServers" key at the top level,
|
||||
# then a dict with the MCP server name as the key,
|
||||
# with the value being the schema from StdioServerParameters
|
||||
if "mcpServers" in mcp_config:
|
||||
for server_name, server_params_raw in mcp_config["mcpServers"].items():
|
||||
if MCP_CONFIG_TOPLEVEL_KEY in mcp_config:
|
||||
for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items():
|
||||
|
||||
# No support for duplicate server names
|
||||
if server_name in mcp_server_list:
|
||||
@@ -1249,7 +1309,7 @@ class SyncServer(Server):
|
||||
else:
|
||||
# Attempt to parse the server params as a StdioServerParameters
|
||||
try:
|
||||
server_params = LocalServerConfig(
|
||||
server_params = StdioServerConfig(
|
||||
server_name=server_name,
|
||||
command=server_params_raw["command"],
|
||||
args=server_params_raw.get("args", []),
|
||||
@@ -1269,6 +1329,98 @@ class SyncServer(Server):
|
||||
|
||||
return self.mcp_clients[mcp_server_name].list_tools()
|
||||
|
||||
def add_mcp_server_to_config(
|
||||
self, server_config: Union[SSEServerConfig, StdioServerConfig], allow_upsert: bool = True
|
||||
) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]:
|
||||
"""Add a new server config to the MCP config file"""
|
||||
|
||||
# If the config file doesn't exist, throw an error.
|
||||
mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
|
||||
if not os.path.exists(mcp_config_path):
|
||||
raise FileNotFoundError(f"MCP config file not found: {mcp_config_path}")
|
||||
|
||||
# If the file does exist, attempt to parse it get calling get_mcp_servers
|
||||
try:
|
||||
current_mcp_servers = self.get_mcp_servers()
|
||||
except Exception as e:
|
||||
# Raise an error telling the user to fix the config file
|
||||
logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}")
|
||||
raise ValueError(f"Failed to parse MCP config file {mcp_config_path}")
|
||||
|
||||
# Check if the server name is already in the config
|
||||
if server_config.server_name in current_mcp_servers and not allow_upsert:
|
||||
raise ValueError(f"Server name {server_config.server_name} is already in the config file")
|
||||
|
||||
# Attempt to initialize the connection to the server
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
new_mcp_client = SSEMCPClient()
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
new_mcp_client = StdioMCPClient()
|
||||
else:
|
||||
raise ValueError(f"Invalid MCP server config: {server_config}")
|
||||
try:
|
||||
new_mcp_client.connect_to_server(server_config)
|
||||
except:
|
||||
logger.exception(f"Failed to connect to MCP server: {server_config.server_name}")
|
||||
raise RuntimeError(f"Failed to connect to MCP server: {server_config.server_name}")
|
||||
# Print out the tools that are connected
|
||||
logger.info(f"Attempting to fetch tools from MCP server: {server_config.server_name}")
|
||||
new_mcp_tools = new_mcp_client.list_tools()
|
||||
logger.info(f"MCP tools connected: {', '.join([t.name for t in new_mcp_tools])}")
|
||||
logger.debug(f"MCP tools: {', '.join([str(t) for t in new_mcp_tools])}")
|
||||
|
||||
# Now that we've confirmed the config is working, let's add it to the client list
|
||||
self.mcp_clients[server_config.server_name] = new_mcp_client
|
||||
|
||||
# Add to the server file
|
||||
current_mcp_servers[server_config.server_name] = server_config
|
||||
|
||||
# Write out the file, and make sure to in include the top-level mcpConfig
|
||||
try:
|
||||
new_mcp_file = {MCP_CONFIG_TOPLEVEL_KEY: {k: v.to_dict() for k, v in current_mcp_servers.items()}}
|
||||
with open(mcp_config_path, "w") as f:
|
||||
json.dump(new_mcp_file, f, indent=4)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}")
|
||||
raise ValueError(f"Failed to write MCP config file {mcp_config_path}")
|
||||
|
||||
return list(current_mcp_servers.values())
|
||||
|
||||
def delete_mcp_server_from_config(self, server_name: str) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]:
|
||||
"""Delete a server config from the MCP config file"""
|
||||
|
||||
# If the config file doesn't exist, throw an error.
|
||||
mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
|
||||
if not os.path.exists(mcp_config_path):
|
||||
raise FileNotFoundError(f"MCP config file not found: {mcp_config_path}")
|
||||
|
||||
# If the file does exist, attempt to parse it get calling get_mcp_servers
|
||||
try:
|
||||
current_mcp_servers = self.get_mcp_servers()
|
||||
except Exception as e:
|
||||
# Raise an error telling the user to fix the config file
|
||||
logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}")
|
||||
raise ValueError(f"Failed to parse MCP config file {mcp_config_path}")
|
||||
|
||||
# Check if the server name is already in the config
|
||||
# If it's not, throw an error
|
||||
if server_name not in current_mcp_servers:
|
||||
raise ValueError(f"Server name {server_name} not found in MCP config file")
|
||||
|
||||
# Remove from the server file
|
||||
del current_mcp_servers[server_name]
|
||||
|
||||
# Write out the file, and make sure to in include the top-level mcpConfig
|
||||
try:
|
||||
new_mcp_file = {MCP_CONFIG_TOPLEVEL_KEY: {k: v.to_dict() for k, v in current_mcp_servers.items()}}
|
||||
with open(mcp_config_path, "w") as f:
|
||||
json.dump(new_mcp_file, f, indent=4)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}")
|
||||
raise ValueError(f"Failed to write MCP config file {mcp_config_path}")
|
||||
|
||||
return list(current_mcp_servers.values())
|
||||
|
||||
@trace_method
|
||||
async def send_message_to_agent(
|
||||
self,
|
||||
@@ -1403,3 +1555,106 @@ class SyncServer(Server):
|
||||
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
@trace_method
|
||||
async def send_group_message_to_agent(
|
||||
self,
|
||||
group_id: str,
|
||||
actor: User,
|
||||
messages: Union[List[Message], List[MessageCreate]],
|
||||
stream_steps: bool,
|
||||
stream_tokens: bool,
|
||||
chat_completion_mode: bool = False,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message: bool = True,
|
||||
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> Union[StreamingResponse, LettaResponse]:
|
||||
include_final_message = True
|
||||
if not stream_steps and stream_tokens:
|
||||
raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")
|
||||
|
||||
try:
|
||||
# fetch the group
|
||||
group = self.group_manager.retrieve_group(group_id=group_id, actor=actor)
|
||||
letta_multi_agent = self.load_multi_agent(group=group, actor=actor)
|
||||
|
||||
llm_config = letta_multi_agent.agent_state.llm_config
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek"]
|
||||
if stream_tokens and (
|
||||
llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint
|
||||
):
|
||||
warnings.warn(
|
||||
f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
|
||||
)
|
||||
stream_tokens = False
|
||||
|
||||
# Create a new interface per request
|
||||
letta_multi_agent.interface = StreamingServerInterface(
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
inner_thoughts_in_kwargs=(
|
||||
llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
|
||||
),
|
||||
)
|
||||
streaming_interface = letta_multi_agent.interface
|
||||
if not isinstance(streaming_interface, StreamingServerInterface):
|
||||
raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
|
||||
streaming_interface.streaming_mode = stream_tokens
|
||||
streaming_interface.streaming_chat_completion_mode = chat_completion_mode
|
||||
if metadata and hasattr(streaming_interface, "metadata"):
|
||||
streaming_interface.metadata = metadata
|
||||
|
||||
streaming_interface.stream_start()
|
||||
task = asyncio.create_task(
|
||||
asyncio.to_thread(
|
||||
letta_multi_agent.step,
|
||||
messages=messages,
|
||||
chaining=self.chaining,
|
||||
max_chaining_steps=self.max_chaining_steps,
|
||||
)
|
||||
)
|
||||
|
||||
if stream_steps:
|
||||
# return a stream
|
||||
return StreamingResponse(
|
||||
sse_async_generator(
|
||||
streaming_interface.get_generator(),
|
||||
usage_task=task,
|
||||
finish_message=include_final_message,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
else:
|
||||
# buffer the stream, then return the list
|
||||
generated_stream = []
|
||||
async for message in streaming_interface.get_generator():
|
||||
assert (
|
||||
isinstance(message, LettaMessage)
|
||||
or isinstance(message, LegacyLettaMessage)
|
||||
or isinstance(message, MessageStreamStatus)
|
||||
), type(message)
|
||||
generated_stream.append(message)
|
||||
if message == MessageStreamStatus.done:
|
||||
break
|
||||
|
||||
# Get rid of the stream status messages
|
||||
filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
|
||||
usage = await task
|
||||
|
||||
# By default the stream will be messages of type LettaMessage or LettaLegacyMessage
|
||||
# If we want to convert these to Message, we can use the attached IDs
|
||||
# NOTE: we will need to de-duplicate the Messsage IDs though (since Assistant->Inner+Func_Call)
|
||||
# TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
|
||||
return LettaResponse(messages=filtered_stream, usage=usage)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
@@ -18,6 +18,7 @@ from letta.orm import Tool as ToolModel
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
||||
@@ -35,10 +36,15 @@ from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule
|
||||
from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.serialize_schemas import SerializedAgentSchema
|
||||
from letta.serialize_schemas.tool import SerializedToolSchema
|
||||
from letta.serialize_schemas import MarshmallowAgentSchema
|
||||
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
|
||||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
_apply_filters,
|
||||
_apply_identity_filters,
|
||||
_apply_pagination,
|
||||
_apply_tag_filter,
|
||||
_process_relationship,
|
||||
_process_tags,
|
||||
check_supports_structured_output,
|
||||
@@ -49,6 +55,7 @@ from letta.services.helpers.agent_manager_helper import (
|
||||
)
|
||||
from letta.services.identity_manager import IdentityManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import settings
|
||||
@@ -70,6 +77,7 @@ class AgentManager:
|
||||
self.tool_manager = ToolManager()
|
||||
self.source_manager = SourceManager()
|
||||
self.message_manager = MessageManager()
|
||||
self.passage_manager = PassageManager()
|
||||
self.identity_manager = IdentityManager()
|
||||
|
||||
# ======================================================================================================================
|
||||
@@ -326,39 +334,60 @@ class AgentManager:
|
||||
# Convert to PydanticAgentState and return
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
# TODO: Make this general and think about how to roll this into sqlalchemybase
|
||||
def list_agents(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
query_text: Optional[str] = None,
|
||||
identifier_keys: Optional[List[str]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
base_template_id: Optional[str] = None,
|
||||
identity_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
identifier_keys: Optional[List[str]] = None,
|
||||
include_relationships: Optional[List[str]] = None,
|
||||
) -> List[PydanticAgentState]:
|
||||
"""
|
||||
List agents that have the specified tags.
|
||||
Retrieves agents with optimized filtering and optional field selection.
|
||||
|
||||
Args:
|
||||
actor: The User requesting the list
|
||||
name (Optional[str]): Filter by agent name.
|
||||
tags (Optional[List[str]]): Filter agents by tags.
|
||||
match_all_tags (bool): If True, only return agents that match ALL given tags.
|
||||
before (Optional[str]): Cursor for pagination.
|
||||
after (Optional[str]): Cursor for pagination.
|
||||
limit (Optional[int]): Maximum number of agents to return.
|
||||
query_text (Optional[str]): Search agents by name.
|
||||
project_id (Optional[str]): Filter by project ID.
|
||||
template_id (Optional[str]): Filter by template ID.
|
||||
base_template_id (Optional[str]): Filter by base template ID.
|
||||
identity_id (Optional[str]): Filter by identifier ID.
|
||||
identifier_keys (Optional[List[str]]): Search agents by identifier keys.
|
||||
include_relationships (Optional[List[str]]): List of fields to load for performance optimization.
|
||||
|
||||
Returns:
|
||||
List[PydanticAgentState]: The filtered list of matching agents.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
agents = AgentModel.list(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
organization_id=actor.organization_id if actor else None,
|
||||
query_text=query_text,
|
||||
identifier_keys=identifier_keys,
|
||||
identity_id=identity_id,
|
||||
**kwargs,
|
||||
)
|
||||
query = select(AgentModel).distinct(AgentModel.created_at, AgentModel.id)
|
||||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
|
||||
return [agent.to_pydantic() for agent in agents]
|
||||
# Apply filters
|
||||
query = _apply_filters(query, name, query_text, project_id, template_id, base_template_id)
|
||||
query = _apply_identity_filters(query, identity_id, identifier_keys)
|
||||
query = _apply_tag_filter(query, tags, match_all_tags)
|
||||
query = _apply_pagination(query, before, after, session)
|
||||
|
||||
query = query.limit(limit)
|
||||
|
||||
agents = session.execute(query).scalars().all()
|
||||
return [agent.to_pydantic(include_relationships=include_relationships) for agent in agents]
|
||||
|
||||
@enforce_types
|
||||
def list_agents_matching_tags(
|
||||
@@ -399,7 +428,7 @@ class AgentManager:
|
||||
# Ensures agents match at least one tag in match_some
|
||||
query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some))
|
||||
|
||||
query = query.group_by(AgentModel.id).limit(limit)
|
||||
query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit)
|
||||
|
||||
return list(session.execute(query).scalars())
|
||||
|
||||
@@ -434,29 +463,32 @@ class AgentManager:
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
# TODO check if it is managing a group
|
||||
agent.hard_delete(session)
|
||||
|
||||
@enforce_types
|
||||
def serialize(self, agent_id: str, actor: PydanticUser) -> dict:
|
||||
def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema:
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
schema = SerializedAgentSchema(session=session, actor=actor)
|
||||
return schema.dump(agent)
|
||||
schema = MarshmallowAgentSchema(session=session, actor=actor)
|
||||
data = schema.dump(agent)
|
||||
return AgentSchema(**data)
|
||||
|
||||
@enforce_types
|
||||
def deserialize(
|
||||
self,
|
||||
serialized_agent: dict,
|
||||
serialized_agent: AgentSchema,
|
||||
actor: PydanticUser,
|
||||
append_copy_suffix: bool = True,
|
||||
override_existing_tools: bool = True,
|
||||
project_id: Optional[str] = None,
|
||||
) -> PydanticAgentState:
|
||||
serialized_agent = serialized_agent.model_dump()
|
||||
tool_data_list = serialized_agent.pop("tools", [])
|
||||
|
||||
with self.session_maker() as session:
|
||||
schema = SerializedAgentSchema(session=session, actor=actor)
|
||||
schema = MarshmallowAgentSchema(session=session, actor=actor)
|
||||
agent = schema.load(serialized_agent, session=session)
|
||||
if append_copy_suffix:
|
||||
agent.name += "_copy"
|
||||
@@ -595,12 +627,17 @@ class AgentManager:
|
||||
# NOTE: a bit of a hack - we pull the timestamp from the message created_by
|
||||
memory_edit_timestamp = curr_system_message.created_at
|
||||
|
||||
num_messages = self.message_manager.size(actor=actor, agent_id=agent_id)
|
||||
num_archival_memories = self.passage_manager.size(actor=actor, agent_id=agent_id)
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats separately)
|
||||
new_system_message_str = compile_system_message(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
recent_passages=self.list_passages(actor=actor, agent_id=agent_id, ascending=False, limit=10),
|
||||
previous_message_count=num_messages,
|
||||
archival_memory_size=num_archival_memories,
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
|
||||
|
||||
147
letta/services/group_manager.py
Normal file
147
letta/services/group_manager.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.group import Group as GroupModel
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.schemas.group import Group as PydanticGroup
|
||||
from letta.schemas.group import GroupCreate, ManagerType
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class GroupManager:
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def list_groups(
|
||||
self,
|
||||
project_id: Optional[str] = None,
|
||||
manager_type: Optional[ManagerType] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
actor: PydanticUser = None,
|
||||
) -> list[PydanticGroup]:
|
||||
with self.session_maker() as session:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
if project_id:
|
||||
filters["project_id"] = project_id
|
||||
if manager_type:
|
||||
filters["manager_type"] = manager_type
|
||||
groups = GroupModel.list(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
**filters,
|
||||
)
|
||||
return [group.to_pydantic() for group in groups]
|
||||
|
||||
@enforce_types
|
||||
def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
|
||||
with self.session_maker() as session:
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
return group.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup:
|
||||
with self.session_maker() as session:
|
||||
new_group = GroupModel()
|
||||
new_group.organization_id = actor.organization_id
|
||||
new_group.description = group.description
|
||||
self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False)
|
||||
if group.manager_config is None:
|
||||
new_group.manager_type = ManagerType.round_robin
|
||||
else:
|
||||
match group.manager_config.manager_type:
|
||||
case ManagerType.round_robin:
|
||||
new_group.manager_type = ManagerType.round_robin
|
||||
new_group.max_turns = group.manager_config.max_turns
|
||||
case ManagerType.dynamic:
|
||||
new_group.manager_type = ManagerType.dynamic
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
new_group.max_turns = group.manager_config.max_turns
|
||||
new_group.termination_token = group.manager_config.termination_token
|
||||
case ManagerType.supervisor:
|
||||
new_group.manager_type = ManagerType.supervisor
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
case _:
|
||||
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
|
||||
new_group.create(session, actor=actor)
|
||||
return new_group.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_group(self, group_id: str, actor: PydanticUser) -> None:
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
group.hard_delete(session)
|
||||
|
||||
@enforce_types
|
||||
def list_group_messages(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
actor: PydanticUser = None,
|
||||
use_assistant_message: bool = True,
|
||||
assistant_message_tool_name: str = "send_message",
|
||||
assistant_message_tool_kwarg: str = "message",
|
||||
) -> list[PydanticGroup]:
|
||||
with self.session_maker() as session:
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
agent_id = group.manager_agent_id if group.manager_agent_id else group.agent_ids[0]
|
||||
|
||||
filters = {
|
||||
"organization_id": actor.organization_id,
|
||||
"group_id": group_id,
|
||||
"agent_id": agent_id,
|
||||
}
|
||||
messages = MessageModel.list(
|
||||
db_session=session,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
**filters,
|
||||
)
|
||||
|
||||
messages = PydanticMessage.to_letta_messages_from_list(
|
||||
messages=messages,
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):
|
||||
current_relationship = getattr(group, "agents", [])
|
||||
if not agent_ids:
|
||||
if replace:
|
||||
setattr(group, "agents", [])
|
||||
return
|
||||
|
||||
# Retrieve models for the provided IDs
|
||||
found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all()
|
||||
|
||||
# Validate all items are found if allow_partial is False
|
||||
if not allow_partial and len(found_items) != len(agent_ids):
|
||||
missing = set(agent_ids) - {item.id for item in found_items}
|
||||
raise NoResultFound(f"Items not found in agents: {missing}")
|
||||
|
||||
if replace:
|
||||
# Replace the relationship
|
||||
setattr(group, "agents", found_items)
|
||||
else:
|
||||
# Extend the relationship (only add new items)
|
||||
current_ids = {item.id for item in current_relationship}
|
||||
new_items = [item for item in found_items if item.id not in current_ids]
|
||||
current_relationship.extend(new_items)
|
||||
@@ -1,6 +1,8 @@
|
||||
import datetime
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
|
||||
from letta import system
|
||||
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, STRUCTURED_OUTPUT_MODELS
|
||||
from letta.helpers import ToolRulesSolver
|
||||
@@ -8,11 +10,13 @@ from letta.helpers.datetime_helpers import get_local_time
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.identity import Identity
|
||||
from letta.prompts import gpt_system
|
||||
from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate, TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.schemas.user import User
|
||||
@@ -293,3 +297,149 @@ def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) ->
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def _apply_pagination(query, before: Optional[str], after: Optional[str], session) -> any:
|
||||
"""
|
||||
Apply cursor-based pagination filters using the agent's created_at timestamp with id as a tie-breaker.
|
||||
|
||||
Instead of relying on the UUID ordering, this function uses the agent's creation time
|
||||
(and id for tie-breaking) to paginate the results. It performs a minimal lookup to fetch
|
||||
only the created_at and id for the agent corresponding to the provided cursor.
|
||||
|
||||
Args:
|
||||
query: The SQLAlchemy query object to modify.
|
||||
before (Optional[str]): Cursor (agent id) to return agents created before this agent.
|
||||
after (Optional[str]): Cursor (agent id) to return agents created after this agent.
|
||||
session: The active database session used to execute the minimal lookup.
|
||||
|
||||
Returns:
|
||||
The modified query with pagination filters applied and ordered by created_at and id.
|
||||
"""
|
||||
if after:
|
||||
# Retrieve only the created_at and id for the agent corresponding to the 'after' cursor.
|
||||
result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after)).first()
|
||||
if result:
|
||||
after_created_at, after_id = result
|
||||
# Filter: include agents created after the reference, or at the same time but with a greater id.
|
||||
query = query.where(
|
||||
or_(
|
||||
AgentModel.created_at > after_created_at,
|
||||
and_(
|
||||
AgentModel.created_at == after_created_at,
|
||||
AgentModel.id > after_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
if before:
|
||||
# Retrieve only the created_at and id for the agent corresponding to the 'before' cursor.
|
||||
result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before)).first()
|
||||
if result:
|
||||
before_created_at, before_id = result
|
||||
# Filter: include agents created before the reference, or at the same time but with a smaller id.
|
||||
query = query.where(
|
||||
or_(
|
||||
AgentModel.created_at < before_created_at,
|
||||
and_(
|
||||
AgentModel.created_at == before_created_at,
|
||||
AgentModel.id < before_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
# Enforce a deterministic ordering: first by created_at, then by id.
|
||||
query = query.order_by(AgentModel.created_at.asc(), AgentModel.id.asc())
|
||||
return query
|
||||
|
||||
|
||||
def _apply_tag_filter(query, tags: Optional[List[str]], match_all_tags: bool):
|
||||
"""
|
||||
Apply tag-based filtering to the agent query.
|
||||
|
||||
This helper function creates a subquery that groups agent IDs by their tags.
|
||||
If `match_all_tags` is True, it filters agents that have all of the specified tags.
|
||||
Otherwise, it filters agents that have any of the tags.
|
||||
|
||||
Args:
|
||||
query: The SQLAlchemy query object to be modified.
|
||||
tags (Optional[List[str]]): A list of tags to filter agents.
|
||||
match_all_tags (bool): If True, only return agents that match all provided tags.
|
||||
|
||||
Returns:
|
||||
The modified query with tag filters applied.
|
||||
"""
|
||||
if tags:
|
||||
# Build a subquery to select agent IDs that have the specified tags.
|
||||
subquery = select(AgentsTags.agent_id).where(AgentsTags.tag.in_(tags)).group_by(AgentsTags.agent_id)
|
||||
# If all tags must match, add a HAVING clause to ensure the count of tags equals the number provided.
|
||||
if match_all_tags:
|
||||
subquery = subquery.having(func.count(AgentsTags.tag) == literal(len(tags)))
|
||||
# Filter the main query to include only agents present in the subquery.
|
||||
query = query.where(AgentModel.id.in_(subquery))
|
||||
return query
|
||||
|
||||
|
||||
def _apply_identity_filters(query, identity_id: Optional[str], identifier_keys: Optional[List[str]]):
|
||||
"""
|
||||
Apply identity-related filters to the agent query.
|
||||
|
||||
This helper function joins the identities relationship and filters the agents based on
|
||||
a specific identity ID and/or a list of identifier keys.
|
||||
|
||||
Args:
|
||||
query: The SQLAlchemy query object to be modified.
|
||||
identity_id (Optional[str]): The identity ID to filter by.
|
||||
identifier_keys (Optional[List[str]]): A list of identifier keys to filter agents.
|
||||
|
||||
Returns:
|
||||
The modified query with identity filters applied.
|
||||
"""
|
||||
# Join the identities relationship and filter by a specific identity ID.
|
||||
if identity_id:
|
||||
query = query.join(AgentModel.identities).where(Identity.id == identity_id)
|
||||
# Join the identities relationship and filter by a set of identifier keys.
|
||||
if identifier_keys:
|
||||
query = query.join(AgentModel.identities).where(Identity.identifier_key.in_(identifier_keys))
|
||||
return query
|
||||
|
||||
|
||||
def _apply_filters(
|
||||
query,
|
||||
name: Optional[str],
|
||||
query_text: Optional[str],
|
||||
project_id: Optional[str],
|
||||
template_id: Optional[str],
|
||||
base_template_id: Optional[str],
|
||||
):
|
||||
"""
|
||||
Apply basic filtering criteria to the agent query.
|
||||
|
||||
This helper function adds WHERE clauses based on provided parameters such as
|
||||
exact name, partial name match (using ILIKE), project ID, template ID, and base template ID.
|
||||
|
||||
Args:
|
||||
query: The SQLAlchemy query object to be modified.
|
||||
name (Optional[str]): Exact name to filter by.
|
||||
query_text (Optional[str]): Partial text to search in the agent's name (case-insensitive).
|
||||
project_id (Optional[str]): Filter for agents belonging to a specific project.
|
||||
template_id (Optional[str]): Filter for agents using a specific template.
|
||||
base_template_id (Optional[str]): Filter for agents using a specific base template.
|
||||
|
||||
Returns:
|
||||
The modified query with the applied filters.
|
||||
"""
|
||||
# Filter by exact agent name if provided.
|
||||
if name:
|
||||
query = query.where(AgentModel.name == name)
|
||||
# Apply a case-insensitive partial match for the agent's name.
|
||||
if query_text:
|
||||
query = query.where(AgentModel.name.ilike(f"%{query_text}%"))
|
||||
# Filter agents by project ID.
|
||||
if project_id:
|
||||
query = query.where(AgentModel.project_id == project_id)
|
||||
# Filter agents by template ID.
|
||||
if template_id:
|
||||
query = query.where(AgentModel.template_id == template_id)
|
||||
# Filter agents by base template ID.
|
||||
if base_template_id:
|
||||
query = query.where(AgentModel.base_template_id == base_template_id)
|
||||
return query
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import and_, or_
|
||||
from sqlalchemy import and_, exists, func, or_, select, text
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
@@ -233,9 +233,17 @@ class MessageManager:
|
||||
# Build a query that directly filters the Message table by agent_id.
|
||||
query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id)
|
||||
|
||||
# If query_text is provided, filter messages by partial match on text.
|
||||
# If query_text is provided, filter messages using subquery.
|
||||
if query_text:
|
||||
query = query.filter(MessageModel.text.ilike(f"%{query_text}%"))
|
||||
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
|
||||
query = query.filter(
|
||||
exists(
|
||||
select(1)
|
||||
.select_from(content_element)
|
||||
.where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
|
||||
.params(query_text=f"%{query_text}%")
|
||||
)
|
||||
)
|
||||
|
||||
# If role is provided, filter messages by role.
|
||||
if role:
|
||||
|
||||
@@ -203,3 +203,18 @@ class PassageManager:
|
||||
for passage in passages:
|
||||
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
|
||||
return True
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Get the total count of messages with optional filters.
|
||||
|
||||
Args:
|
||||
actor: The user requesting the count
|
||||
agent_id: The agent ID of the messages
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
|
||||
|
||||
@@ -174,6 +174,11 @@ class Settings(BaseSettings):
|
||||
# telemetry logging
|
||||
verbose_telemetry_logging: bool = False
|
||||
|
||||
# uvicorn settings
|
||||
uvicorn_workers: int = 1
|
||||
uvicorn_reload: bool = False
|
||||
uvicorn_timeout_keep_alive: int = 5
|
||||
|
||||
@property
|
||||
def letta_pg_uri(self) -> str:
|
||||
if self.pg_uri:
|
||||
|
||||
103
letta/supervisor_multi_agent.py
Normal file
103
letta/supervisor_multi_agent.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.agent import Agent, AgentState
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL
|
||||
from letta.functions.function_sets.multi_agent import send_message_to_all_agents_in_group
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from tests.helpers.utils import create_tool_from_func
|
||||
|
||||
|
||||
class SupervisorMultiAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
interface: AgentInterface,
|
||||
agent_state: AgentState,
|
||||
user: User = None,
|
||||
# custom
|
||||
group_id: str = "",
|
||||
agent_ids: List[str] = [],
|
||||
description: str = "",
|
||||
):
|
||||
super().__init__(interface, agent_state, user)
|
||||
self.group_id = group_id
|
||||
self.agent_ids = agent_ids
|
||||
self.description = description
|
||||
self.agent_manager = AgentManager()
|
||||
self.tool_manager = ToolManager()
|
||||
|
||||
def step(
|
||||
self,
|
||||
messages: List[MessageCreate],
|
||||
chaining: bool = True,
|
||||
max_chaining_steps: Optional[int] = None,
|
||||
put_inner_thoughts_first: bool = True,
|
||||
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||
**kwargs,
|
||||
) -> LettaUsageStatistics:
|
||||
token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
|
||||
metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
|
||||
|
||||
# add multi agent tool
|
||||
if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None:
|
||||
multi_agent_tool = create_tool_from_func(send_message_to_all_agents_in_group)
|
||||
multi_agent_tool.tool_type = ToolType.LETTA_MULTI_AGENT_CORE
|
||||
multi_agent_tool = self.tool_manager.create_or_update_tool(
|
||||
pydantic_tool=multi_agent_tool,
|
||||
actor=self.user,
|
||||
)
|
||||
self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)
|
||||
|
||||
# override tool rules
|
||||
self.agent_state.tool_rules = [
|
||||
InitToolRule(
|
||||
tool_name="send_message_to_all_agents_in_group",
|
||||
),
|
||||
TerminalToolRule(
|
||||
tool_name=assistant_message_tool_name,
|
||||
),
|
||||
ChildToolRule(
|
||||
tool_name="send_message_to_all_agents_in_group",
|
||||
children=[assistant_message_tool_name],
|
||||
),
|
||||
]
|
||||
|
||||
supervisor_messages = [
|
||||
Message(
|
||||
agent_id=self.agent_state.id,
|
||||
role="user",
|
||||
content=[TextContent(text=message.content)],
|
||||
name=None,
|
||||
model=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
try:
|
||||
supervisor_agent = Agent(agent_state=self.agent_state, interface=self.interface, user=self.user)
|
||||
usage_stats = supervisor_agent.step(
|
||||
messages=supervisor_messages,
|
||||
chaining=chaining,
|
||||
max_chaining_steps=max_chaining_steps,
|
||||
stream=token_streaming,
|
||||
skip_verify=True,
|
||||
metadata=metadata,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
self.interface.step_yield()
|
||||
|
||||
self.interface.step_complete()
|
||||
|
||||
return usage_stats
|
||||
361
poetry.lock
generated
361
poetry.lock
generated
@@ -268,13 +268,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "attrs"
|
||||
version = "25.2.0"
|
||||
version = "25.3.0"
|
||||
description = "Classes Without Boilerplate"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "attrs-25.2.0-py3-none-any.whl", hash = "sha256:611344ff0a5fed735d86d7784610c84f8126b95e549bcad9ff61b4242f2d386b"},
|
||||
{file = "attrs-25.2.0.tar.gz", hash = "sha256:18a06db706db43ac232cce80443fcd9f2500702059ecf53489e3c5a3f417acaf"},
|
||||
{file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"},
|
||||
{file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@@ -447,17 +447,17 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.37.11"
|
||||
version = "1.37.12"
|
||||
description = "The AWS SDK for Python"
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "boto3-1.37.11-py3-none-any.whl", hash = "sha256:da6c22fc8a7e9bca5d7fc465a877ac3d45b6b086d776bd1a6c55bdde60523741"},
|
||||
{file = "boto3-1.37.11.tar.gz", hash = "sha256:8eec08363ef5db05c2fbf58e89f0c0de6276cda2fdce01e76b3b5f423cd5c0f4"},
|
||||
{file = "boto3-1.37.12-py3-none-any.whl", hash = "sha256:516feaa0d2afaeda1515216fd09291368a1215754bbccb0f28414c0a91a830a2"},
|
||||
{file = "boto3-1.37.12.tar.gz", hash = "sha256:9412d404f103ad6d14f033eb29cd5e0cdca2b9b08cbfa9d4dabd1d7be2de2625"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
botocore = ">=1.37.11,<1.38.0"
|
||||
botocore = ">=1.37.12,<1.38.0"
|
||||
jmespath = ">=0.7.1,<2.0.0"
|
||||
s3transfer = ">=0.11.0,<0.12.0"
|
||||
|
||||
@@ -466,13 +466,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
|
||||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.37.11"
|
||||
version = "1.37.12"
|
||||
description = "Low-level, data-driven core of boto 3."
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "botocore-1.37.11-py3-none-any.whl", hash = "sha256:02505309b1235f9f15a6da79103ca224b3f3dc5f6a62f8630fbb2c6ed05e2da8"},
|
||||
{file = "botocore-1.37.11.tar.gz", hash = "sha256:72eb3a9a58b064be26ba154e5e56373633b58f951941c340ace0d379590d98b5"},
|
||||
{file = "botocore-1.37.12-py3-none-any.whl", hash = "sha256:ba1948c883bbabe20d95ff62c3e36954c9269686f7db9361857835677ca3e676"},
|
||||
{file = "botocore-1.37.12.tar.gz", hash = "sha256:ae2d5328ce6ad02eb615270507235a6e90fd3eeed615a6c0732b5a68b12f2017"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1035,8 +1035,8 @@ isort = ">=4.3.21,<6.0"
|
||||
jinja2 = ">=2.10.1,<4.0"
|
||||
packaging = "*"
|
||||
pydantic = [
|
||||
{version = ">=1.9.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.10.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
{version = ">=1.9.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.10.0,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.12\" and python_version < \"4.0\""},
|
||||
]
|
||||
pyyaml = ">=6.0.1"
|
||||
@@ -1048,50 +1048,6 @@ graphql = ["graphql-core (>=3.2.3,<4.0.0)"]
|
||||
http = ["httpx"]
|
||||
validation = ["openapi-spec-validator (>=0.2.8,<0.7.0)", "prance (>=0.18.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "2.21.0"
|
||||
description = "HuggingFace community-driven open-source library of datasets"
|
||||
optional = true
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "datasets-2.21.0-py3-none-any.whl", hash = "sha256:25e4e097110ce28824b746a107727ada94024cba11db8bc588d468414692b65a"},
|
||||
{file = "datasets-2.21.0.tar.gz", hash = "sha256:998f85a8460f1bd982e5bd058f8a0808eef424249e3df1e8cdd594ccd0dc8ba2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = "*"
|
||||
dill = ">=0.3.0,<0.3.9"
|
||||
filelock = "*"
|
||||
fsspec = {version = ">=2023.1.0,<=2024.6.1", extras = ["http"]}
|
||||
huggingface-hub = ">=0.21.2"
|
||||
multiprocess = "*"
|
||||
numpy = ">=1.17"
|
||||
packaging = "*"
|
||||
pandas = "*"
|
||||
pyarrow = ">=15.0.0"
|
||||
pyyaml = ">=5.1"
|
||||
requests = ">=2.32.2"
|
||||
tqdm = ">=4.66.3"
|
||||
xxhash = "*"
|
||||
|
||||
[package.extras]
|
||||
apache-beam = ["apache-beam (>=2.26.0)"]
|
||||
audio = ["librosa", "soundfile (>=0.12.1)", "soxr (>=0.4.0)"]
|
||||
benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
|
||||
dev = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.8.0.post1)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tensorflow (>=2.16.0)", "tensorflow (>=2.6.0)", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "transformers (>=4.42.0)", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
|
||||
jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
|
||||
metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk (<3.8.2)", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
|
||||
quality = ["ruff (>=0.3.0)"]
|
||||
s3 = ["s3fs"]
|
||||
tensorflow = ["tensorflow (>=2.6.0)"]
|
||||
tensorflow-gpu = ["tensorflow (>=2.6.0)"]
|
||||
tests = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.8.0.post1)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tensorflow (>=2.16.0)", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers (>=4.42.0)", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
tests-numpy2 = ["Pillow (>=9.4.0)", "absl-py", "decorator", "elasticsearch (<8.0.0)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tiktoken", "torch (>=2.0.0)", "typing-extensions (>=4.6.1)", "zstandard"]
|
||||
torch = ["torch"]
|
||||
vision = ["Pillow (>=9.4.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "debugpy"
|
||||
version = "1.8.13"
|
||||
@@ -1165,21 +1121,6 @@ wrapt = ">=1.10,<2"
|
||||
[package.extras]
|
||||
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "dill"
|
||||
version = "0.3.8"
|
||||
description = "serialize all of Python"
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"},
|
||||
{file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
graph = ["objgraph (>=1.7.2)"]
|
||||
profile = ["gprof2dot (>=2022.7.29)"]
|
||||
|
||||
[[package]]
|
||||
name = "dirtyjson"
|
||||
version = "1.0.8"
|
||||
@@ -1298,13 +1239,13 @@ typing-extensions = ">=4.1.0"
|
||||
|
||||
[[package]]
|
||||
name = "e2b-code-interpreter"
|
||||
version = "1.0.5"
|
||||
version = "1.1.0"
|
||||
description = "E2B Code Interpreter - Stateful code execution"
|
||||
optional = true
|
||||
python-versions = "<4.0,>=3.8"
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "e2b_code_interpreter-1.0.5-py3-none-any.whl", hash = "sha256:4c7814e9eabba58097bf5e4019d327b3a82fab0813eafca4311b29ca6ea0639d"},
|
||||
{file = "e2b_code_interpreter-1.0.5.tar.gz", hash = "sha256:e7f70b039e6a70f8e592f90f806d696dc1056919414daabeb89e86c9b650a987"},
|
||||
{file = "e2b_code_interpreter-1.1.0-py3-none-any.whl", hash = "sha256:292f8ddbb820475d5ffb1f3f2e67a42001a921d1c8fef40bd97a7f16f13adc64"},
|
||||
{file = "e2b_code_interpreter-1.1.0.tar.gz", hash = "sha256:4554eb002f9489965c2e7dd7fc967e62128db69b18dbb64975d4abbc0572e3ed"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1571,18 +1512,15 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "fsspec"
|
||||
version = "2024.6.1"
|
||||
version = "2025.3.0"
|
||||
description = "File-system specification"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "fsspec-2024.6.1-py3-none-any.whl", hash = "sha256:3cb443f8bcd2efb31295a5b9fdb02aee81d8452c80d28f97a6d0959e6cee101e"},
|
||||
{file = "fsspec-2024.6.1.tar.gz", hash = "sha256:fad7d7e209dd4c1208e3bbfda706620e0da5142bebbd9c384afb95b07e798e49"},
|
||||
{file = "fsspec-2025.3.0-py3-none-any.whl", hash = "sha256:efb87af3efa9103f94ca91a7f8cb7a4df91af9f74fc106c9c7ea0efd7277c1b3"},
|
||||
{file = "fsspec-2025.3.0.tar.gz", hash = "sha256:a935fd1ea872591f2b5148907d103488fc523295e6c64b835cfad8c3eca44972"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""}
|
||||
|
||||
[package.extras]
|
||||
abfs = ["adlfs"]
|
||||
adl = ["adlfs"]
|
||||
@@ -1607,7 +1545,7 @@ sftp = ["paramiko"]
|
||||
smb = ["smbprotocol"]
|
||||
ssh = ["paramiko"]
|
||||
test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"]
|
||||
test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"]
|
||||
test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"]
|
||||
test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"]
|
||||
tqdm = ["tqdm"]
|
||||
|
||||
@@ -2160,40 +2098,6 @@ files = [
|
||||
{file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.29.3"
|
||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||
optional = true
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "huggingface_hub-0.29.3-py3-none-any.whl", hash = "sha256:0b25710932ac649c08cdbefa6c6ccb8e88eef82927cacdb048efb726429453aa"},
|
||||
{file = "huggingface_hub-0.29.3.tar.gz", hash = "sha256:64519a25716e0ba382ba2d3fb3ca082e7c7eb4a2fc634d200e8380006e0760e5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
filelock = "*"
|
||||
fsspec = ">=2023.5.0"
|
||||
packaging = ">=20.9"
|
||||
pyyaml = ">=5.1"
|
||||
requests = "*"
|
||||
tqdm = ">=4.42.1"
|
||||
typing-extensions = ">=3.7.4.3"
|
||||
|
||||
[package.extras]
|
||||
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
cli = ["InquirerPy (==0.3.4)"]
|
||||
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
|
||||
hf-transfer = ["hf-transfer (>=0.1.4)"]
|
||||
inference = ["aiohttp"]
|
||||
quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"]
|
||||
tensorflow = ["graphviz", "pydot", "tensorflow"]
|
||||
tensorflow-testing = ["keras (<3.0)", "tensorflow"]
|
||||
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
|
||||
torch = ["safetensors[torch]", "torch"]
|
||||
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "hyperframe"
|
||||
version = "6.1.0"
|
||||
@@ -2822,13 +2726,13 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "letta-client"
|
||||
version = "0.1.66"
|
||||
version = "0.1.68"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
files = [
|
||||
{file = "letta_client-0.1.66-py3-none-any.whl", hash = "sha256:9f18f0161f5eec83ad4c7f02fd91dea31e97e3b688c29ae6116df1b252b892c6"},
|
||||
{file = "letta_client-0.1.66.tar.gz", hash = "sha256:589f2fc88776e60bbeeecf14de9dd9a938216b2225c1984f486aec9015b41896"},
|
||||
{file = "letta_client-0.1.68-py3-none-any.whl", hash = "sha256:2b027f79281560abc88a7033b8ff4b3ecdd3be7ba2fa4f17f844ec0ba8d7dfe1"},
|
||||
{file = "letta_client-0.1.68.tar.gz", hash = "sha256:c956498c6e0d726ec3f205a1dbaa0552d7945147a30d45ca7ea8ee7b77ac81aa"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3305,13 +3209,13 @@ traitlets = "*"
|
||||
|
||||
[[package]]
|
||||
name = "mcp"
|
||||
version = "1.3.0"
|
||||
version = "1.4.0"
|
||||
description = "Model Context Protocol SDK"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
files = [
|
||||
{file = "mcp-1.3.0-py3-none-any.whl", hash = "sha256:2829d67ce339a249f803f22eba5e90385eafcac45c94b00cab6cef7e8f217211"},
|
||||
{file = "mcp-1.3.0.tar.gz", hash = "sha256:f409ae4482ce9d53e7ac03f3f7808bcab735bdfc0fba937453782efb43882d45"},
|
||||
{file = "mcp-1.4.0-py3-none-any.whl", hash = "sha256:d2760e1ea7635b1e70da516698620a016cde214976416dd894f228600b08984c"},
|
||||
{file = "mcp-1.4.0.tar.gz", hash = "sha256:5b750b14ca178eeb7b2addbd94adb21785d7b4de5d5f3577ae193d787869e2dd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3327,6 +3231,7 @@ uvicorn = ">=0.23.1"
|
||||
[package.extras]
|
||||
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.12.4)"]
|
||||
rich = ["rich (>=13.9.4)"]
|
||||
ws = ["websockets (>=15.0.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
@@ -3516,30 +3421,6 @@ files = [
|
||||
[package.dependencies]
|
||||
typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "multiprocess"
|
||||
version = "0.70.16"
|
||||
description = "better multiprocessing and multithreading in Python"
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"},
|
||||
{file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"},
|
||||
{file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"},
|
||||
{file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"},
|
||||
{file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"},
|
||||
{file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"},
|
||||
{file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"},
|
||||
{file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"},
|
||||
{file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"},
|
||||
{file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"},
|
||||
{file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"},
|
||||
{file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
dill = ">=0.3.8"
|
||||
|
||||
[[package]]
|
||||
name = "mypy-extensions"
|
||||
version = "1.0.0"
|
||||
@@ -4016,8 +3897,8 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
@@ -4612,60 +4493,6 @@ files = [
|
||||
[package.extras]
|
||||
tests = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "19.0.1"
|
||||
description = "Python library for Apache Arrow"
|
||||
optional = true
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:fc28912a2dc924dddc2087679cc8b7263accc71b9ff025a1362b004711661a69"},
|
||||
{file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fca15aabbe9b8355800d923cc2e82c8ef514af321e18b437c3d782aa884eaeec"},
|
||||
{file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad76aef7f5f7e4a757fddcdcf010a8290958f09e3470ea458c80d26f4316ae89"},
|
||||
{file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d03c9d6f2a3dffbd62671ca070f13fc527bb1867b4ec2b98c7eeed381d4f389a"},
|
||||
{file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:65cf9feebab489b19cdfcfe4aa82f62147218558d8d3f0fc1e9dea0ab8e7905a"},
|
||||
{file = "pyarrow-19.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:41f9706fbe505e0abc10e84bf3a906a1338905cbbcf1177b71486b03e6ea6608"},
|
||||
{file = "pyarrow-19.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:c6cb2335a411b713fdf1e82a752162f72d4a7b5dbc588e32aa18383318b05866"},
|
||||
{file = "pyarrow-19.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc55d71898ea30dc95900297d191377caba257612f384207fe9f8293b5850f90"},
|
||||
{file = "pyarrow-19.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:7a544ec12de66769612b2d6988c36adc96fb9767ecc8ee0a4d270b10b1c51e00"},
|
||||
{file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0148bb4fc158bfbc3d6dfe5001d93ebeed253793fff4435167f6ce1dc4bddeae"},
|
||||
{file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f24faab6ed18f216a37870d8c5623f9c044566d75ec586ef884e13a02a9d62c5"},
|
||||
{file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:4982f8e2b7afd6dae8608d70ba5bd91699077323f812a0448d8b7abdff6cb5d3"},
|
||||
{file = "pyarrow-19.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:49a3aecb62c1be1d822f8bf629226d4a96418228a42f5b40835c1f10d42e4db6"},
|
||||
{file = "pyarrow-19.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:008a4009efdb4ea3d2e18f05cd31f9d43c388aad29c636112c2966605ba33466"},
|
||||
{file = "pyarrow-19.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:80b2ad2b193e7d19e81008a96e313fbd53157945c7be9ac65f44f8937a55427b"},
|
||||
{file = "pyarrow-19.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:ee8dec072569f43835932a3b10c55973593abc00936c202707a4ad06af7cb294"},
|
||||
{file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d5d1ec7ec5324b98887bdc006f4d2ce534e10e60f7ad995e7875ffa0ff9cb14"},
|
||||
{file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3ad4c0eb4e2a9aeb990af6c09e6fa0b195c8c0e7b272ecc8d4d2b6574809d34"},
|
||||
{file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:d383591f3dcbe545f6cc62daaef9c7cdfe0dff0fb9e1c8121101cabe9098cfa6"},
|
||||
{file = "pyarrow-19.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b4c4156a625f1e35d6c0b2132635a237708944eb41df5fbe7d50f20d20c17832"},
|
||||
{file = "pyarrow-19.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:5bd1618ae5e5476b7654c7b55a6364ae87686d4724538c24185bbb2952679960"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e45274b20e524ae5c39d7fc1ca2aa923aab494776d2d4b316b49ec7572ca324c"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:d9dedeaf19097a143ed6da37f04f4051aba353c95ef507764d344229b2b740ae"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ebfb5171bb5f4a52319344ebbbecc731af3f021e49318c74f33d520d31ae0c4"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a21d39fbdb948857f67eacb5bbaaf36802de044ec36fbef7a1c8f0dd3a4ab2"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:99bc1bec6d234359743b01e70d4310d0ab240c3d6b0da7e2a93663b0158616f6"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:1b93ef2c93e77c442c979b0d596af45e4665d8b96da598db145b0fec014b9136"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:d9d46e06846a41ba906ab25302cf0fd522f81aa2a85a71021826f34639ad31ef"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:c0fe3dbbf054a00d1f162fda94ce236a899ca01123a798c561ba307ca38af5f0"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:96606c3ba57944d128e8a8399da4812f56c7f61de8c647e3470b417f795d0ef9"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f04d49a6b64cf24719c080b3c2029a3a5b16417fd5fd7c4041f94233af732f3"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a9137cf7e1640dce4c190551ee69d478f7121b5c6f323553b319cac936395f6"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:7c1bca1897c28013db5e4c83944a2ab53231f541b9e0c3f4791206d0c0de389a"},
|
||||
{file = "pyarrow-19.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:58d9397b2e273ef76264b45531e9d552d8ec8a6688b7390b5be44c02a37aade8"},
|
||||
{file = "pyarrow-19.0.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b9766a47a9cb56fefe95cb27f535038b5a195707a08bf61b180e642324963b46"},
|
||||
{file = "pyarrow-19.0.1-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:6c5941c1aac89a6c2f2b16cd64fe76bcdb94b2b1e99ca6459de4e6f07638d755"},
|
||||
{file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd44d66093a239358d07c42a91eebf5015aa54fccba959db899f932218ac9cc8"},
|
||||
{file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:335d170e050bcc7da867a1ed8ffb8b44c57aaa6e0843b156a501298657b1e972"},
|
||||
{file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:1c7556165bd38cf0cd992df2636f8bcdd2d4b26916c6b7e646101aff3c16f76f"},
|
||||
{file = "pyarrow-19.0.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:699799f9c80bebcf1da0983ba86d7f289c5a2a5c04b945e2f2bcf7e874a91911"},
|
||||
{file = "pyarrow-19.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:8464c9fbe6d94a7fe1599e7e8965f350fd233532868232ab2596a71586c5a429"},
|
||||
{file = "pyarrow-19.0.1.tar.gz", hash = "sha256:3bf266b485df66a400f282ac0b6d1b500b9d2ae73314a153dbe97d6d5cc8a99e"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.1"
|
||||
@@ -6600,138 +6427,6 @@ files = [
|
||||
{file = "wrapt-1.17.2.tar.gz", hash = "sha256:41388e9d4d1522446fe79d3213196bd9e3b301a336965b9e27ca2788ebd122f3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xxhash"
|
||||
version = "3.5.0"
|
||||
description = "Python binding for xxHash"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "xxhash-3.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ece616532c499ee9afbb83078b1b952beffef121d989841f7f4b3dc5ac0fd212"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3171f693dbc2cef6477054a665dc255d996646b4023fe56cb4db80e26f4cc520"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c5d3e570ef46adaf93fc81b44aca6002b5a4d8ca11bd0580c07eac537f36680"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7cb29a034301e2982df8b1fe6328a84f4b676106a13e9135a0d7e0c3e9f806da"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d0d307d27099bb0cbeea7260eb39ed4fdb99c5542e21e94bb6fd29e49c57a23"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0342aafd421795d740e514bc9858ebddfc705a75a8c5046ac56d85fe97bf196"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3dbbd9892c5ebffeca1ed620cf0ade13eb55a0d8c84e0751a6653adc6ac40d0c"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4cc2d67fdb4d057730c75a64c5923abfa17775ae234a71b0200346bfb0a7f482"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:ec28adb204b759306a3d64358a5e5c07d7b1dd0ccbce04aa76cb9377b7b70296"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1328f6d8cca2b86acb14104e381225a3d7b42c92c4b86ceae814e5c400dbb415"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8d47ebd9f5d9607fd039c1fbf4994e3b071ea23eff42f4ecef246ab2b7334198"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b96d559e0fcddd3343c510a0fe2b127fbff16bf346dd76280b82292567523442"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-win32.whl", hash = "sha256:61c722ed8d49ac9bc26c7071eeaa1f6ff24053d553146d5df031802deffd03da"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:9bed5144c6923cc902cd14bb8963f2d5e034def4486ab0bbe1f58f03f042f9a9"},
|
||||
{file = "xxhash-3.5.0-cp310-cp310-win_arm64.whl", hash = "sha256:893074d651cf25c1cc14e3bea4fceefd67f2921b1bb8e40fcfeba56820de80c6"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:02c2e816896dc6f85922ced60097bcf6f008dedfc5073dcba32f9c8dd786f3c1"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6027dcd885e21581e46d3c7f682cfb2b870942feeed58a21c29583512c3f09f8"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1308fa542bbdbf2fa85e9e66b1077eea3a88bef38ee8a06270b4298a7a62a166"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c28b2fdcee797e1c1961cd3bcd3d545cab22ad202c846235197935e1df2f8ef7"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:924361811732ddad75ff23e90efd9ccfda4f664132feecb90895bade6a1b4623"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89997aa1c4b6a5b1e5b588979d1da048a3c6f15e55c11d117a56b75c84531f5a"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:685c4f4e8c59837de103344eb1c8a3851f670309eb5c361f746805c5471b8c88"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbd2ecfbfee70bc1a4acb7461fa6af7748ec2ab08ac0fa298f281c51518f982c"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:25b5a51dc3dfb20a10833c8eee25903fd2e14059e9afcd329c9da20609a307b2"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a8fb786fb754ef6ff8c120cb96629fb518f8eb5a61a16aac3a979a9dbd40a084"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a905ad00ad1e1c34fe4e9d7c1d949ab09c6fa90c919860c1534ff479f40fd12d"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:963be41bcd49f53af6d795f65c0da9b4cc518c0dd9c47145c98f61cb464f4839"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-win32.whl", hash = "sha256:109b436096d0a2dd039c355fa3414160ec4d843dfecc64a14077332a00aeb7da"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:b702f806693201ad6c0a05ddbbe4c8f359626d0b3305f766077d51388a6bac58"},
|
||||
{file = "xxhash-3.5.0-cp311-cp311-win_arm64.whl", hash = "sha256:c4dcb4120d0cc3cc448624147dba64e9021b278c63e34a38789b688fd0da9bf3"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:14470ace8bd3b5d51318782cd94e6f94431974f16cb3b8dc15d52f3b69df8e00"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:59aa1203de1cb96dbeab595ded0ad0c0056bb2245ae11fac11c0ceea861382b9"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08424f6648526076e28fae6ea2806c0a7d504b9ef05ae61d196d571e5c879c84"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61a1ff00674879725b194695e17f23d3248998b843eb5e933007ca743310f793"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f2f2c61bee5844d41c3eb015ac652a0229e901074951ae48581d58bfb2ba01be"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d32a592cac88d18cc09a89172e1c32d7f2a6e516c3dfde1b9adb90ab5df54a6"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:70dabf941dede727cca579e8c205e61121afc9b28516752fd65724be1355cc90"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e5d0ddaca65ecca9c10dcf01730165fd858533d0be84c75c327487c37a906a27"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e5b5e16c5a480fe5f59f56c30abdeba09ffd75da8d13f6b9b6fd224d0b4d0a2"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149b7914451eb154b3dfaa721315117ea1dac2cc55a01bfbd4df7c68c5dd683d"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:eade977f5c96c677035ff39c56ac74d851b1cca7d607ab3d8f23c6b859379cab"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fa9f547bd98f5553d03160967866a71056a60960be00356a15ecc44efb40ba8e"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-win32.whl", hash = "sha256:f7b58d1fd3551b8c80a971199543379be1cee3d0d409e1f6d8b01c1a2eebf1f8"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:fa0cafd3a2af231b4e113fba24a65d7922af91aeb23774a8b78228e6cd785e3e"},
|
||||
{file = "xxhash-3.5.0-cp312-cp312-win_arm64.whl", hash = "sha256:586886c7e89cb9828bcd8a5686b12e161368e0064d040e225e72607b43858ba2"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:37889a0d13b0b7d739cfc128b1c902f04e32de17b33d74b637ad42f1c55101f6"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:97a662338797c660178e682f3bc180277b9569a59abfb5925e8620fba00b9fc5"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f85e0108d51092bdda90672476c7d909c04ada6923c14ff9d913c4f7dc8a3bc"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2fd827b0ba763ac919440042302315c564fdb797294d86e8cdd4578e3bc7f3"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82085c2abec437abebf457c1d12fccb30cc8b3774a0814872511f0f0562c768c"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07fda5de378626e502b42b311b049848c2ef38784d0d67b6f30bb5008642f8eb"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c279f0d2b34ef15f922b77966640ade58b4ccdfef1c4d94b20f2a364617a493f"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:89e66ceed67b213dec5a773e2f7a9e8c58f64daeb38c7859d8815d2c89f39ad7"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bcd51708a633410737111e998ceb3b45d3dbc98c0931f743d9bb0a209033a326"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3ff2c0a34eae7df88c868be53a8dd56fbdf592109e21d4bfa092a27b0bf4a7bf"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4e28503dccc7d32e0b9817aa0cbfc1f45f563b2c995b7a66c4c8a0d232e840c7"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a6c50017518329ed65a9e4829154626f008916d36295b6a3ba336e2458824c8c"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-win32.whl", hash = "sha256:53a068fe70301ec30d868ece566ac90d873e3bb059cf83c32e76012c889b8637"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:80babcc30e7a1a484eab952d76a4f4673ff601f54d5142c26826502740e70b43"},
|
||||
{file = "xxhash-3.5.0-cp313-cp313-win_arm64.whl", hash = "sha256:4811336f1ce11cac89dcbd18f3a25c527c16311709a89313c3acaf771def2d4b"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6e5f70f6dca1d3b09bccb7daf4e087075ff776e3da9ac870f86ca316736bb4aa"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e76e83efc7b443052dd1e585a76201e40b3411fe3da7af4fe434ec51b2f163b"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:33eac61d0796ca0591f94548dcfe37bb193671e0c9bcf065789b5792f2eda644"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ec70a89be933ea49222fafc3999987d7899fc676f688dd12252509434636622"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd86b8e7f703ec6ff4f351cfdb9f428955859537125904aa8c963604f2e9d3e7"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0adfbd36003d9f86c8c97110039f7539b379f28656a04097e7434d3eaf9aa131"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:63107013578c8a730419adc05608756c3fa640bdc6abe806c3123a49fb829f43"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:683b94dbd1ca67557850b86423318a2e323511648f9f3f7b1840408a02b9a48c"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5d2a01dcce81789cf4b12d478b5464632204f4c834dc2d064902ee27d2d1f0ee"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:a9d360a792cbcce2fe7b66b8d51274ec297c53cbc423401480e53b26161a290d"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:f0b48edbebea1b7421a9c687c304f7b44d0677c46498a046079d445454504737"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-win32.whl", hash = "sha256:7ccb800c9418e438b44b060a32adeb8393764da7441eb52aa2aa195448935306"},
|
||||
{file = "xxhash-3.5.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c3bc7bf8cb8806f8d1c9bf149c18708cb1c406520097d6b0a73977460ea03602"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:74752ecaa544657d88b1d1c94ae68031e364a4d47005a90288f3bab3da3c970f"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:dee1316133c9b463aa81aca676bc506d3f80d8f65aeb0bba2b78d0b30c51d7bd"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:602d339548d35a8579c6b013339fb34aee2df9b4e105f985443d2860e4d7ffaa"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:695735deeddfb35da1677dbc16a083445360e37ff46d8ac5c6fcd64917ff9ade"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1030a39ba01b0c519b1a82f80e8802630d16ab95dc3f2b2386a0b5c8ed5cbb10"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5bc08f33c4966f4eb6590d6ff3ceae76151ad744576b5fc6c4ba8edd459fdec"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:160e0c19ee500482ddfb5d5570a0415f565d8ae2b3fd69c5dcfce8a58107b1c3"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f1abffa122452481a61c3551ab3c89d72238e279e517705b8b03847b1d93d738"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:d5e9db7ef3ecbfc0b4733579cea45713a76852b002cf605420b12ef3ef1ec148"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:23241ff6423378a731d84864bf923a41649dc67b144debd1077f02e6249a0d54"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:82b833d5563fefd6fceafb1aed2f3f3ebe19f84760fdd289f8b926731c2e6e91"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0a80ad0ffd78bef9509eee27b4a29e56f5414b87fb01a888353e3d5bda7038bd"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-win32.whl", hash = "sha256:50ac2184ffb1b999e11e27c7e3e70cc1139047e7ebc1aa95ed12f4269abe98d4"},
|
||||
{file = "xxhash-3.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:392f52ebbb932db566973693de48f15ce787cabd15cf6334e855ed22ea0be5b3"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bfc8cdd7f33d57f0468b0614ae634cc38ab9202c6957a60e31d285a71ebe0301"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e0c48b6300cd0b0106bf49169c3e0536408dfbeb1ccb53180068a18b03c662ab"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe1a92cfbaa0a1253e339ccec42dbe6db262615e52df591b68726ab10338003f"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:33513d6cc3ed3b559134fb307aae9bdd94d7e7c02907b37896a6c45ff9ce51bd"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eefc37f6138f522e771ac6db71a6d4838ec7933939676f3753eafd7d3f4c40bc"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a606c8070ada8aa2a88e181773fa1ef17ba65ce5dd168b9d08038e2a61b33754"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:42eca420c8fa072cc1dd62597635d140e78e384a79bb4944f825fbef8bfeeef6"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:604253b2143e13218ff1ef0b59ce67f18b8bd1c4205d2ffda22b09b426386898"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:6e93a5ad22f434d7876665444a97e713a8f60b5b1a3521e8df11b98309bff833"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:7a46e1d6d2817ba8024de44c4fd79913a90e5f7265434cef97026215b7d30df6"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:30eb2efe6503c379b7ab99c81ba4a779748e3830241f032ab46bd182bf5873af"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c8aa771ff2c13dd9cda8166d685d7333d389fae30a4d2bb39d63ab5775de8606"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-win32.whl", hash = "sha256:5ed9ebc46f24cf91034544b26b131241b699edbfc99ec5e7f8f3d02d6eb7fba4"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:220f3f896c6b8d0316f63f16c077d52c412619e475f9372333474ee15133a558"},
|
||||
{file = "xxhash-3.5.0-cp39-cp39-win_arm64.whl", hash = "sha256:a7b1d8315d9b5e9f89eb2933b73afae6ec9597a258d52190944437158b49d38e"},
|
||||
{file = "xxhash-3.5.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:2014c5b3ff15e64feecb6b713af12093f75b7926049e26a580e94dcad3c73d8c"},
|
||||
{file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab81ef75003eda96239a23eda4e4543cedc22e34c373edcaf744e721a163986"},
|
||||
{file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e2febf914ace002132aa09169cc572e0d8959d0f305f93d5828c4836f9bc5a6"},
|
||||
{file = "xxhash-3.5.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5d3a10609c51da2a1c0ea0293fc3968ca0a18bd73838455b5bca3069d7f8e32b"},
|
||||
{file = "xxhash-3.5.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5a74f23335b9689b66eb6dbe2a931a88fcd7a4c2cc4b1cb0edba8ce381c7a1da"},
|
||||
{file = "xxhash-3.5.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2b4154c00eb22e4d543f472cfca430e7962a0f1d0f3778334f2e08a7ba59363c"},
|
||||
{file = "xxhash-3.5.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d30bbc1644f726b825b3278764240f449d75f1a8bdda892e641d4a688b1494ae"},
|
||||
{file = "xxhash-3.5.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fa0b72f2423e2aa53077e54a61c28e181d23effeaafd73fcb9c494e60930c8e"},
|
||||
{file = "xxhash-3.5.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:13de2b76c1835399b2e419a296d5b38dc4855385d9e96916299170085ef72f57"},
|
||||
{file = "xxhash-3.5.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:0691bfcc4f9c656bcb96cc5db94b4d75980b9d5589f2e59de790091028580837"},
|
||||
{file = "xxhash-3.5.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:297595fe6138d4da2c8ce9e72a04d73e58725bb60f3a19048bc96ab2ff31c692"},
|
||||
{file = "xxhash-3.5.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc1276d369452040cbb943300dc8abeedab14245ea44056a2943183822513a18"},
|
||||
{file = "xxhash-3.5.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2061188a1ba352fc699c82bff722f4baacb4b4b8b2f0c745d2001e56d0dfb514"},
|
||||
{file = "xxhash-3.5.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:38c384c434021e4f62b8d9ba0bc9467e14d394893077e2c66d826243025e1f81"},
|
||||
{file = "xxhash-3.5.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e6a4dd644d72ab316b580a1c120b375890e4c52ec392d4aef3c63361ec4d77d1"},
|
||||
{file = "xxhash-3.5.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:531af8845aaadcadf951b7e0c1345c6b9c68a990eeb74ff9acd8501a0ad6a1c9"},
|
||||
{file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ce379bcaa9fcc00f19affa7773084dd09f5b59947b3fb47a1ceb0179f91aaa1"},
|
||||
{file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd1b2281d01723f076df3c8188f43f2472248a6b63118b036e641243656b1b0f"},
|
||||
{file = "xxhash-3.5.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c770750cc80e8694492244bca7251385188bc5597b6a39d98a9f30e8da984e0"},
|
||||
{file = "xxhash-3.5.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b150b8467852e1bd844387459aa6fbe11d7f38b56e901f9f3b3e6aba0d660240"},
|
||||
{file = "xxhash-3.5.0.tar.gz", hash = "sha256:84f2caddf951c9cbf8dc2e22a89d4ccf5d86391ac6418fe81e3c67d0cf60b45f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yarl"
|
||||
version = "1.18.3"
|
||||
@@ -7047,4 +6742,4 @@ tests = ["wikipedia"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.14,>=3.10"
|
||||
content-hash = "21a4534904aa25a7879ba34eff7d3a10d2601ac9af821649217b60df0d8a404d"
|
||||
content-hash = "a41c6ec00f4b96db9d586b7142436d5f7cd1733cab5c9eaf734d1866782e2f94"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "letta"
|
||||
version = "0.6.39"
|
||||
version = "0.6.40"
|
||||
packages = [
|
||||
{include = "letta"},
|
||||
]
|
||||
@@ -22,7 +22,6 @@ pytz = "^2023.3.post1"
|
||||
tqdm = "^4.66.1"
|
||||
black = {extras = ["jupyter"], version = "^24.2.0"}
|
||||
setuptools = "^70"
|
||||
datasets = { version = "^2.14.6", optional = true}
|
||||
prettytable = "^3.9.0"
|
||||
pgvector = { version = "^0.2.3", optional = true }
|
||||
pre-commit = {version = "^3.5.0", optional = true }
|
||||
|
||||
@@ -13,8 +13,9 @@ from letta.errors import ContextWindowExceededError
|
||||
from letta.llm_api.helpers import calculate_summarizer_cutoff
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH
|
||||
|
||||
@@ -21,6 +21,7 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.server.rest_api.app import app
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
@@ -369,12 +370,12 @@ def test_deserialize_override_existing_tools(
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# Extract tools before upload
|
||||
tool_data_list = result.get("tools", [])
|
||||
tool_names = {tool["name"]: tool for tool in tool_data_list}
|
||||
tool_data_list = result.tools
|
||||
tool_names = {tool.name: tool for tool in tool_data_list}
|
||||
|
||||
# Rewrite all the tool source code to the print_tool source code
|
||||
for tool in result["tools"]:
|
||||
tool["source_code"] = print_tool.source_code
|
||||
for tool in result.tools:
|
||||
tool.source_code = print_tool.source_code
|
||||
|
||||
# Deserialize the agent with different override settings
|
||||
server.agent_manager.deserialize(
|
||||
@@ -394,22 +395,6 @@ def test_deserialize_override_existing_tools(
|
||||
assert existing_tool.source_code == weather_tool.source_code, f"Tool {tool_name} should NOT be overridden"
|
||||
|
||||
|
||||
def test_in_context_message_id_remapping(local_client, server, serialize_test_agent, default_user, other_user):
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# Check remapping on message_ids and messages is consistent
|
||||
assert sorted([m["id"] for m in result["messages"]]) == sorted(result["message_ids"])
|
||||
|
||||
# Deserialize the agent
|
||||
agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user)
|
||||
|
||||
# Make sure all the messages are able to be retrieved
|
||||
in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user)
|
||||
assert len(in_context_messages) == len(result["message_ids"])
|
||||
assert sorted([m.id for m in in_context_messages]) == sorted(result["message_ids"])
|
||||
|
||||
|
||||
def test_agent_serialize_with_user_messages(local_client, server, serialize_test_agent, default_user, other_user):
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
append_copy_suffix = False
|
||||
@@ -473,6 +458,18 @@ def test_agent_serialize_tool_calls(mock_e2b_api_key_none, local_client, server,
|
||||
assert copy_agent_response.completion_tokens > 0 and copy_agent_response.step_count > 0
|
||||
|
||||
|
||||
def test_in_context_message_id_remapping(local_client, server, serialize_test_agent, default_user, other_user):
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
|
||||
# Deserialize the agent
|
||||
agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user)
|
||||
|
||||
# Make sure all the messages are able to be retrieved
|
||||
in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user)
|
||||
assert len(in_context_messages) == len(serialize_test_agent.message_ids)
|
||||
|
||||
|
||||
# FastAPI endpoint tests
|
||||
|
||||
|
||||
@@ -485,16 +482,18 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
agent_id = serialize_test_agent.id
|
||||
|
||||
# Step 1: Download the serialized agent
|
||||
response = fastapi_client.get(f"/v1/agents/{agent_id}/download", headers={"user_id": default_user.id})
|
||||
response = fastapi_client.get(f"/v1/agents/{agent_id}/export", headers={"user_id": default_user.id})
|
||||
assert response.status_code == 200, f"Download failed: {response.text}"
|
||||
|
||||
agent_json = response.json()
|
||||
# Ensure response matches expected schema
|
||||
agent_schema = AgentSchema.model_validate(response.json()) # Validate as Pydantic model
|
||||
agent_json = agent_schema.model_dump(mode="json") # Convert back to serializable JSON
|
||||
|
||||
# Step 2: Upload the serialized agent as a copy
|
||||
agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8"))
|
||||
files = {"file": ("agent.json", agent_bytes, "application/json")}
|
||||
upload_response = fastapi_client.post(
|
||||
"/v1/agents/upload",
|
||||
"/v1/agents/import",
|
||||
headers={"user_id": other_user.id},
|
||||
params={"append_copy_suffix": append_copy_suffix, "override_existing_tools": False, "project_id": project_id},
|
||||
files=files,
|
||||
@@ -511,6 +510,7 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
# Step 3: Retrieve the copied agent
|
||||
serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user)
|
||||
agent_copy = server.agent_manager.get_agent_by_id(agent_id=copied_agent_id, actor=other_user)
|
||||
|
||||
print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json()))
|
||||
assert compare_agent_state(agent_copy, serialize_test_agent, append_copy_suffix=append_copy_suffix)
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPro
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
@@ -272,7 +273,7 @@ def hello_world_message_fixture(server: SyncServer, default_user, sarah_agent):
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
text="Hello, world!",
|
||||
content=[TextContent(text="Hello, world!")],
|
||||
)
|
||||
|
||||
msg = server.message_manager.create_message(message, actor=default_user)
|
||||
@@ -614,6 +615,104 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe
|
||||
assert updated_agent.updated_at > last_updated_timestamp
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Listing
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_list_agents_select_fields_empty(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
# Create an agent using the comprehensive fixture.
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# List agents using an empty list for select_fields.
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=[])
|
||||
# Assert that the agent is returned and basic fields are present.
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
assert agent.id is not None
|
||||
assert agent.name is not None
|
||||
|
||||
# Assert no relationships were loaded
|
||||
assert len(agent.tools) == 0
|
||||
assert len(agent.tags) == 0
|
||||
|
||||
|
||||
def test_list_agents_select_fields_none(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
# Create an agent using the comprehensive fixture.
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# List agents using an empty list for select_fields.
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=None)
|
||||
# Assert that the agent is returned and basic fields are present.
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
assert agent.id is not None
|
||||
assert agent.name is not None
|
||||
|
||||
# Assert no relationships were loaded
|
||||
assert len(agent.tools) > 0
|
||||
assert len(agent.tags) > 0
|
||||
|
||||
|
||||
def test_list_agents_select_fields_specific(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Choose a subset of valid relationship fields.
|
||||
valid_fields = ["tools", "tags"]
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=valid_fields)
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
# Depending on your to_pydantic() implementation,
|
||||
# verify that the fields exist in the returned pydantic model.
|
||||
# (Note: These assertions may require that your CreateAgent fixture sets up these relationships.)
|
||||
assert agent.tools
|
||||
assert sorted(agent.tags) == ["a", "b"]
|
||||
assert not agent.memory.blocks
|
||||
|
||||
|
||||
def test_list_agents_select_fields_invalid(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Provide field names that are not recognized.
|
||||
invalid_fields = ["foobar", "nonexistent_field"]
|
||||
# The expectation is that these fields are simply ignored.
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=invalid_fields)
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
# Verify that standard fields are still present.c
|
||||
assert agent.id is not None
|
||||
assert agent.name is not None
|
||||
|
||||
|
||||
def test_list_agents_select_fields_duplicates(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Provide duplicate valid field names.
|
||||
duplicate_fields = ["tools", "tools", "tags", "tags"]
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=duplicate_fields)
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
# Verify that the agent pydantic representation includes the relationships.
|
||||
# Even if duplicates were provided, the query should not break.
|
||||
assert isinstance(agent.tools, list)
|
||||
assert isinstance(agent.tags, list)
|
||||
|
||||
|
||||
def test_list_agents_select_fields_mixed(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
|
||||
# Mix valid fields with an invalid one.
|
||||
mixed_fields = ["tools", "invalid_field"]
|
||||
agents = server.agent_manager.list_agents(actor=default_user, include_relationships=mixed_fields)
|
||||
assert len(agents) >= 1
|
||||
agent = agents[0]
|
||||
# Valid fields should be loaded and accessible.
|
||||
assert agent.tools
|
||||
# Since "invalid_field" is not recognized, it should have no adverse effect.
|
||||
# You might optionally check that no extra attribute is created on the pydantic model.
|
||||
assert not hasattr(agent, "invalid_field")
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Tools Relationship
|
||||
# ======================================================================================================================
|
||||
@@ -1098,7 +1197,7 @@ def test_reset_messages_with_existing_messages(server: SyncServer, sarah_agent,
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
role="user",
|
||||
text="Hello, Sarah!",
|
||||
content=[TextContent(text="Hello, Sarah!")],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1107,7 +1206,7 @@ def test_reset_messages_with_existing_messages(server: SyncServer, sarah_agent,
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
role="assistant",
|
||||
text="Hello, user!",
|
||||
content=[TextContent(text="Hello, user!")],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1138,7 +1237,7 @@ def test_reset_messages_idempotency(server: SyncServer, sarah_agent, default_use
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
role="user",
|
||||
text="Hello, Sarah!",
|
||||
content=[TextContent(text="Hello, Sarah!")],
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1964,7 +2063,10 @@ def test_message_size(server: SyncServer, hello_world_message_fixture, default_u
|
||||
# Create additional test messages
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, text=f"Test message {i}"
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=base_message.agent_id,
|
||||
role=base_message.role,
|
||||
content=[TextContent(text=f"Test message {i}")],
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
@@ -1992,7 +2094,10 @@ def create_test_messages(server: SyncServer, base_message: PydanticMessage, defa
|
||||
"""Helper function to create test messages for all tests"""
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, text=f"Test message {i}"
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=base_message.agent_id,
|
||||
role=base_message.role,
|
||||
content=[TextContent(text=f"Test message {i}")],
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
@@ -3172,7 +3277,7 @@ def test_job_messages_pagination(server: SyncServer, default_run, default_user,
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
text=f"Test message {i}",
|
||||
content=[TextContent(text=f"Test message {i}")],
|
||||
)
|
||||
msg = server.message_manager.create_message(message, actor=default_user)
|
||||
message_ids.append(msg.id)
|
||||
@@ -3285,7 +3390,7 @@ def test_job_messages_ordering(server: SyncServer, default_run, default_user, sa
|
||||
for i, created_at in enumerate(message_times):
|
||||
message = PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
text="Test message",
|
||||
content=[TextContent(text="Test message")],
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
created_at=created_at,
|
||||
@@ -3354,19 +3459,19 @@ def test_job_messages_filter(server: SyncServer, default_run, default_user, sara
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
text="Hello",
|
||||
content=[TextContent(text="Hello")],
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
),
|
||||
PydanticMessage(
|
||||
role=MessageRole.assistant,
|
||||
text="Hi there!",
|
||||
content=[TextContent(text="Hi there!")],
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
),
|
||||
PydanticMessage(
|
||||
role=MessageRole.assistant,
|
||||
text="Let me help you with that",
|
||||
content=[TextContent(text="Let me help you with that")],
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
tool_calls=[
|
||||
@@ -3421,7 +3526,7 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant,
|
||||
text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}',
|
||||
content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')],
|
||||
tool_calls=(
|
||||
[{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
|
||||
if i % 2 == 1
|
||||
@@ -3472,7 +3577,7 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant,
|
||||
text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}',
|
||||
content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')],
|
||||
tool_calls=(
|
||||
[{"type": "function", "id": f"call_{i//2}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}]
|
||||
if i % 2 == 1
|
||||
|
||||
258
tests/test_multi_agent.py
Normal file
258
tests/test_multi_agent.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm import Provider, Step
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.group import DynamicManager, GroupCreate, SupervisorManager
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
print("CONFIG PATH", config.config_path)
|
||||
|
||||
config.save()
|
||||
|
||||
server = SyncServer()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def org_id(server):
|
||||
org = server.organization_manager.create_default_organization()
|
||||
|
||||
yield org.id
|
||||
|
||||
# cleanup
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Step))
|
||||
session.execute(delete(Provider))
|
||||
session.commit()
|
||||
server.organization_manager.delete_organization_by_id(org.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def actor(server, org_id):
|
||||
user = server.user_manager.create_default_user()
|
||||
yield user
|
||||
|
||||
# cleanup
|
||||
server.user_manager.delete_user_by_id(user.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def participant_agent_ids(server, actor):
|
||||
agent_fred = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="fred",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="Your name is fred and you like to ski and have been wanting to go on a ski trip soon. You are speaking in a group chat with other agent pals where you participate in friendly banter.",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
agent_velma = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="velma",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="Your name is velma and you like tropical locations. You are speaking in a group chat with other agent friends and you love to include everyone.",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
agent_daphne = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="daphne",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="Your name is daphne and you love traveling abroad. You are speaking in a group chat with other agent friends and you love to keep in touch with them.",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
agent_shaggy = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="shaggy",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="Your name is shaggy and your best friend is your dog, scooby. You are speaking in a group chat with other agent friends and you like to solve mysteries with them.",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
yield [agent_fred.id, agent_velma.id, agent_daphne.id, agent_shaggy.id]
|
||||
|
||||
# cleanup
|
||||
server.agent_manager.delete_agent(agent_fred.id, actor=actor)
|
||||
server.agent_manager.delete_agent(agent_velma.id, actor=actor)
|
||||
server.agent_manager.delete_agent(agent_daphne.id, actor=actor)
|
||||
server.agent_manager.delete_agent(agent_shaggy.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def manager_agent_id(server, actor):
|
||||
agent_scooby = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="scooby",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="You are a puppy operations agent for Letta and you help run multi-agent group chats. Your job is to get to know the agents in your group and pick who is best suited to speak next in the conversation.",
|
||||
),
|
||||
CreateBlock(
|
||||
label="human",
|
||||
value="",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
yield agent_scooby.id
|
||||
|
||||
# cleanup
|
||||
server.agent_manager.delete_agent(agent_scooby.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin(server, actor, participant_agent_ids):
|
||||
group = server.group_manager.create_group(
|
||||
group=GroupCreate(
|
||||
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
|
||||
agent_ids=participant_agent_ids,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
try:
|
||||
response = await server.send_group_message_to_agent(
|
||||
group_id=group.id,
|
||||
actor=actor,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="what is everyone up to for the holidays?",
|
||||
),
|
||||
],
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
)
|
||||
assert response.usage.step_count == len(participant_agent_ids)
|
||||
assert len(response.messages) == response.usage.step_count * 2
|
||||
|
||||
finally:
|
||||
server.group_manager.delete_group(group_id=group.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supervisor(server, actor, participant_agent_ids):
|
||||
agent_scrappy = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="shaggy",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="You are a puppy operations agent for Letta and you help run multi-agent group chats. Your role is to supervise the group, sending messages and aggregating the responses.",
|
||||
),
|
||||
CreateBlock(
|
||||
label="human",
|
||||
value="",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
group = server.group_manager.create_group(
|
||||
group=GroupCreate(
|
||||
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
|
||||
agent_ids=participant_agent_ids,
|
||||
manager_config=SupervisorManager(
|
||||
manager_agent_id=agent_scrappy.id,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
try:
|
||||
response = await server.send_group_message_to_agent(
|
||||
group_id=group.id,
|
||||
actor=actor,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="ask everyone what they like to do for fun and then come up with an activity for everyone to do together.",
|
||||
),
|
||||
],
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
)
|
||||
assert response.usage.step_count == 2
|
||||
assert len(response.messages) == 5
|
||||
|
||||
# verify tool call
|
||||
assert response.messages[0].message_type == "reasoning_message"
|
||||
assert (
|
||||
response.messages[1].message_type == "tool_call_message"
|
||||
and response.messages[1].tool_call.name == "send_message_to_all_agents_in_group"
|
||||
)
|
||||
assert response.messages[2].message_type == "tool_return_message" and len(eval(response.messages[2].tool_return)) == len(
|
||||
participant_agent_ids
|
||||
)
|
||||
assert response.messages[3].message_type == "reasoning_message"
|
||||
assert response.messages[4].message_type == "assistant_message"
|
||||
|
||||
finally:
|
||||
server.group_manager.delete_group(group_id=group.id, actor=actor)
|
||||
server.agent_manager.delete_agent(agent_id=agent_scrappy.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_group_chat(server, actor, manager_agent_id, participant_agent_ids):
|
||||
group = server.group_manager.create_group(
|
||||
group=GroupCreate(
|
||||
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
|
||||
agent_ids=participant_agent_ids,
|
||||
manager_config=DynamicManager(
|
||||
manager_agent_id=manager_agent_id,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
try:
|
||||
response = await server.send_group_message_to_agent(
|
||||
group_id=group.id,
|
||||
actor=actor,
|
||||
messages=[
|
||||
MessageCreate(role="user", content="what is everyone up to for the holidays?"),
|
||||
],
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
)
|
||||
assert response.usage.step_count == len(participant_agent_ids) * 2
|
||||
assert len(response.messages) == response.usage.step_count * 2
|
||||
|
||||
finally:
|
||||
server.group_manager.delete_group(group_id=group.id, actor=actor)
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.services.summarizer.summarizer import Summarizer
|
||||
|
||||
Reference in New Issue
Block a user