test: add more robust multi-agent testing (#1444)

This commit is contained in:
cthomas
2025-03-28 14:21:54 -07:00
committed by GitHub
parent 096e1aed5d
commit 1a5c08c62b
19 changed files with 568 additions and 211 deletions

View File

@@ -0,0 +1,31 @@
"""add ordered agent ids to groups
Revision ID: a66510f83fc2
Revises: bdddd421ec41
Create Date: 2025-03-27 11:11:51.709498
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "a66510f83fc2"
down_revision: Union[str, None] = "bdddd421ec41"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("groups", sa.Column("agent_ids", sa.JSON(), nullable=False))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("groups", "agent_ids")
# ### end Alembic commands ###

View File

@@ -220,6 +220,7 @@ class Agent(BaseAgent):
messages: List[Message],
tool_returns: Optional[List[ToolReturn]] = None,
include_function_failed_message: bool = False,
group_id: Optional[str] = None,
) -> List[Message]:
"""
Handle error from function call response
@@ -240,7 +241,9 @@ class Agent(BaseAgent):
"content": function_response,
"tool_call_id": tool_call_id,
},
name=self.agent_state.name,
tool_returns=tool_returns,
group_id=group_id,
)
messages.append(new_message)
self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message)
@@ -329,6 +332,7 @@ class Agent(BaseAgent):
stream=stream,
stream_interface=self.interface,
put_inner_thoughts_first=put_inner_thoughts_first,
name=self.agent_state.name,
)
log_telemetry(self.logger, "_get_ai_reply create finish")
@@ -372,6 +376,7 @@ class Agent(BaseAgent):
# and now we want to use it in the creation of the Message object
# TODO figure out a cleaner way to do this
response_message_id: Optional[str] = None,
group_id: Optional[str] = None,
) -> Tuple[List[Message], bool, bool]:
"""Handles parsing and function execution"""
log_telemetry(self.logger, "_handle_ai_response start")
@@ -417,6 +422,8 @@ class Agent(BaseAgent):
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict=response_message.model_dump(),
name=self.agent_state.name,
group_id=group_id,
)
) # extend conversation with assistant's reply
self.logger.debug(f"Function call message: {messages[-1]}")
@@ -449,7 +456,7 @@ class Agent(BaseAgent):
error_msg = f"No function named {function_name}"
function_response = "None" # more like "never ran?"
messages = self._handle_function_error_response(
error_msg, tool_call_id, function_name, function_args, function_response, messages
error_msg, tool_call_id, function_name, function_args, function_response, messages, group_id=group_id
)
return messages, False, True # force a heartbeat to allow agent to handle error
@@ -464,7 +471,7 @@ class Agent(BaseAgent):
error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}"
function_response = "None" # more like "never ran?"
messages = self._handle_function_error_response(
error_msg, tool_call_id, function_name, function_args, function_response, messages
error_msg, tool_call_id, function_name, function_args, function_response, messages, group_id=group_id
)
return messages, False, True # force a heartbeat to allow agent to handle error
@@ -535,6 +542,7 @@ class Agent(BaseAgent):
function_response,
messages,
[tool_return],
group_id=group_id,
)
return messages, False, True # force a heartbeat to allow agent to handle error
@@ -571,6 +579,7 @@ class Agent(BaseAgent):
messages,
[ToolReturn(status="error", stderr=[error_msg_user])],
include_function_failed_message=True,
group_id=group_id,
)
return messages, False, True # force a heartbeat to allow agent to handle error
@@ -595,6 +604,7 @@ class Agent(BaseAgent):
messages,
[tool_return],
include_function_failed_message=True,
group_id=group_id,
)
return messages, False, True # force a heartbeat to allow agent to handle error
@@ -620,7 +630,9 @@ class Agent(BaseAgent):
"content": function_response,
"tool_call_id": tool_call_id,
},
name=self.agent_state.name,
tool_returns=[tool_return] if sandbox_run_result else None,
group_id=group_id,
)
) # extend conversation with function response
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
@@ -636,6 +648,8 @@ class Agent(BaseAgent):
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict=response_message.model_dump(),
name=self.agent_state.name,
group_id=group_id,
)
) # extend conversation with assistant's reply
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1])
@@ -799,7 +813,11 @@ class Agent(BaseAgent):
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
input_message_sequence = in_context_messages + messages
if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
if (
len(input_message_sequence) > 1
and input_message_sequence[-1].role != "user"
and input_message_sequence[-1].group_id is None
):
self.logger.warning(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
# Step 2: send the conversation and available functions to the LLM
@@ -832,6 +850,7 @@ class Agent(BaseAgent):
# TODO this is kind of hacky, find a better way to handle this
# the only time we set up message creation ahead of time is when streaming is on
response_message_id=response.id if stream else None,
group_id=input_message_sequence[-1].group_id,
)
# Step 6: extend the message history

View File

@@ -16,7 +16,7 @@ class DynamicMultiAgent(Agent):
self,
interface: AgentInterface,
agent_state: AgentState,
user: User = None,
user: User,
# custom
group_id: str = "",
agent_ids: List[str] = [],
@@ -128,7 +128,7 @@ class DynamicMultiAgent(Agent):
)
for message in assistant_messages
]
message_index[agent_id] = len(chat_history) + len(new_messages)
message_index[speaker_id] = len(chat_history) + len(new_messages)
# sum usage
total_usage.prompt_tokens += usage_stats.prompt_tokens
@@ -251,10 +251,10 @@ class DynamicMultiAgent(Agent):
chat_history: List[Message],
agent_id_options: List[str],
) -> Message:
chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history]
text_chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history]
for message in new_messages:
chat_history.append(f"{message.name or 'user'}: {message.content}")
context_messages = "\n".join(chat_history)
text_chat_history.append(f"{message.name or 'user'}: {message.content}")
context_messages = "\n".join(text_chat_history)
message_text = (
"Choose the most suitable agent to reply to the latest message in the "

View File

@@ -859,6 +859,7 @@ def anthropic_chat_completions_process_stream(
create_message_id: bool = True,
create_message_datetime: bool = True,
betas: List[str] = ["tools-2024-04-04"],
name: Optional[str] = None,
) -> ChatCompletionResponse:
"""Process a streaming completion response from Anthropic, similar to OpenAI's streaming.
@@ -951,6 +952,7 @@ def anthropic_chat_completions_process_stream(
# if extended_thinking is on, then reasoning_content will be flowing as chunks
# TODO handle emitting redacted reasoning content (e.g. as concat?)
expect_reasoning_content=extended_thinking,
name=name,
)
elif isinstance(stream_interface, AgentRefreshStreamingInterface):
stream_interface.process_refresh(chat_completion_response)

View File

@@ -140,6 +140,7 @@ def create(
stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
model_settings: Optional[dict] = None, # TODO: eventually pass from server
put_inner_thoughts_first: bool = True,
name: Optional[str] = None,
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
from letta.utils import printd
@@ -206,6 +207,7 @@ def create(
api_key=api_key,
chat_completion_request=data,
stream_interface=stream_interface,
name=name,
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False
@@ -255,6 +257,7 @@ def create(
api_key=api_key,
chat_completion_request=data,
stream_interface=stream_interface,
name=name,
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False
@@ -359,6 +362,7 @@ def create(
stream_interface=stream_interface,
extended_thinking=llm_config.enable_reasoner,
max_reasoning_tokens=llm_config.max_reasoning_tokens,
name=name,
)
else:
@@ -531,6 +535,7 @@ def create(
api_key=model_settings.deepseek_api_key,
chat_completion_request=data,
stream_interface=stream_interface,
name=name,
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False

View File

@@ -185,6 +185,7 @@ def openai_chat_completions_process_stream(
# however, we don't necessarily want to put these
# expect_reasoning_content: bool = False,
expect_reasoning_content: bool = True,
name: Optional[str] = None,
) -> ChatCompletionResponse:
"""Process a streaming completion response, and return a ChatCompletionRequest at the end.
@@ -272,6 +273,7 @@ def openai_chat_completions_process_stream(
message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id,
message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created,
expect_reasoning_content=expect_reasoning_content,
name=name,
)
elif isinstance(stream_interface, AgentRefreshStreamingInterface):
stream_interface.process_refresh(chat_completion_response)

View File

@@ -1,7 +1,7 @@
import uuid
from typing import List, Optional
from sqlalchemy import ForeignKey, String
from sqlalchemy import JSON, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
@@ -23,11 +23,8 @@ class Group(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="groups")
agent_ids: Mapped[List[str]] = mapped_column(JSON, nullable=False, doc="Ordered list of agent IDs in this group")
agents: Mapped[List["Agent"]] = relationship(
"Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups"
)
manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group")
@property
def agent_ids(self) -> List[str]:
return [agent.id for agent in self.agents]

View File

@@ -14,7 +14,7 @@ class RoundRobinMultiAgent(Agent):
self,
interface: AgentInterface,
agent_state: AgentState,
user: User = None,
user: User,
# custom
group_id: str = "",
agent_ids: List[str] = [],
@@ -45,7 +45,7 @@ class RoundRobinMultiAgent(Agent):
for agent_id in self.agent_ids:
agents[agent_id] = self.load_participant_agent(agent_id=agent_id)
message_index = {}
message_index = {agent_id: 0 for agent_id in self.agent_ids}
chat_history: List[Message] = []
new_messages = messages
speaker_id = None
@@ -91,7 +91,7 @@ class RoundRobinMultiAgent(Agent):
MessageCreate(
role="system",
content=message.content,
name=participant_agent.agent_state.name,
name=message.name,
)
for message in assistant_messages
]
@@ -138,10 +138,21 @@ class RoundRobinMultiAgent(Agent):
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
persona_block = agent_state.memory.get_block(label="persona")
group_chat_participant_persona = (
"\n\n====Group Chat Contex===="
f"\nYou are speaking in a group chat with {len(self.agent_ids) - 1} other "
"agents and one user. Respond to new messages in the group chat when prompted. "
f"Description of the group: {self.description}"
f"%%% GROUP CHAT CONTEXT %%% "
f"You are speaking in a group chat with {len(self.agent_ids)} other participants. "
f"Group Description: {self.description} "
"INTERACTION GUIDELINES:\n"
"1. Be aware that others can see your messages - communicate as if in a real group conversation\n"
"2. Acknowledge and build upon others' contributions when relevant\n"
"3. Stay on topic while adding your unique perspective based on your role and personality\n"
"4. Be concise but engaging - give others space to contribute\n"
"5. Maintain your character's personality while being collaborative\n"
"6. Feel free to ask questions to other participants to encourage discussion\n"
"7. If someone addresses you directly, acknowledge their message\n"
"8. Share relevant experiences or knowledge that adds value to the conversation\n\n"
"Remember: This is a natural group conversation. Interact as you would in a real group setting, "
"staying true to your character while fostering meaningful dialogue. "
"%%% END GROUP CHAT CONTEXT %%%"
)
agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_participant_persona)
return Agent(

View File

@@ -62,4 +62,10 @@ ManagerConfigUnion = Annotated[
class GroupCreate(BaseModel):
agent_ids: List[str] = Field(..., description="")
description: str = Field(..., description="")
manager_config: ManagerConfigUnion = Field(RoundRobinManager(), description="")
class GroupUpdate(BaseModel):
agent_ids: Optional[List[str]] = Field(None, description="")
description: Optional[str] = Field(None, description="")
manager_config: Optional[ManagerConfigUnion] = Field(None, description="")

View File

@@ -226,6 +226,7 @@ class Message(BaseMessage):
id=self.id,
date=self.created_at,
reasoning=self.content[0].text,
name=self.name,
)
)
# Otherwise, we may have a list of multiple types
@@ -239,6 +240,7 @@ class Message(BaseMessage):
id=self.id,
date=self.created_at,
reasoning=content_part.text,
name=self.name,
)
)
elif isinstance(content_part, ReasoningContent):
@@ -250,6 +252,7 @@ class Message(BaseMessage):
reasoning=content_part.reasoning,
source="reasoner_model", # TODO do we want to tag like this?
signature=content_part.signature,
name=self.name,
)
)
elif isinstance(content_part, RedactedReasoningContent):
@@ -260,6 +263,7 @@ class Message(BaseMessage):
date=self.created_at,
state="redacted",
hidden_reasoning=content_part.data,
name=self.name,
)
)
else:
@@ -282,6 +286,7 @@ class Message(BaseMessage):
id=self.id,
date=self.created_at,
content=message_string,
name=self.name,
)
)
else:
@@ -294,6 +299,7 @@ class Message(BaseMessage):
arguments=tool_call.function.arguments,
tool_call_id=tool_call.id,
),
name=self.name,
)
)
elif self.role == MessageRole.tool:
@@ -334,6 +340,7 @@ class Message(BaseMessage):
tool_call_id=self.tool_call_id,
stdout=self.tool_returns[0].stdout if self.tool_returns else None,
stderr=self.tool_returns[0].stderr if self.tool_returns else None,
name=self.name,
)
)
elif self.role == MessageRole.user:
@@ -349,6 +356,7 @@ class Message(BaseMessage):
id=self.id,
date=self.created_at,
content=message_str or text_content,
name=self.name,
)
)
elif self.role == MessageRole.system:
@@ -363,6 +371,7 @@ class Message(BaseMessage):
id=self.id,
date=self.created_at,
content=text_content,
name=self.name,
)
)
else:
@@ -379,6 +388,8 @@ class Message(BaseMessage):
allow_functions_style: bool = False, # allow deprecated functions style?
created_at: Optional[datetime] = None,
id: Optional[str] = None,
name: Optional[str] = None,
group_id: Optional[str] = None,
tool_returns: Optional[List[ToolReturn]] = None,
):
"""Convert a ChatCompletion message object into a Message object (synced to DB)"""
@@ -426,12 +437,13 @@ class Message(BaseMessage):
# standard fields expected in an OpenAI ChatCompletion message object
role=MessageRole.tool, # NOTE
content=content,
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
name=name,
tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
created_at=created_at,
id=str(id),
tool_returns=tool_returns,
group_id=group_id,
)
else:
return Message(
@@ -440,11 +452,12 @@ class Message(BaseMessage):
# standard fields expected in an OpenAI ChatCompletion message object
role=MessageRole.tool, # NOTE
content=content,
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
name=name,
tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
created_at=created_at,
tool_returns=tool_returns,
group_id=group_id,
)
elif "function_call" in openai_message_dict and openai_message_dict["function_call"] is not None:
@@ -473,12 +486,13 @@ class Message(BaseMessage):
# standard fields expected in an OpenAI ChatCompletion message object
role=MessageRole(openai_message_dict["role"]),
content=content,
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
name=name,
tool_calls=tool_calls,
tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool'
created_at=created_at,
id=str(id),
tool_returns=tool_returns,
group_id=group_id,
)
else:
return Message(
@@ -492,6 +506,7 @@ class Message(BaseMessage):
tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool'
created_at=created_at,
tool_returns=tool_returns,
group_id=group_id,
)
else:
@@ -520,12 +535,13 @@ class Message(BaseMessage):
# standard fields expected in an OpenAI ChatCompletion message object
role=MessageRole(openai_message_dict["role"]),
content=content,
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
name=name,
tool_calls=tool_calls,
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
created_at=created_at,
id=str(id),
tool_returns=tool_returns,
group_id=group_id,
)
else:
return Message(
@@ -534,11 +550,12 @@ class Message(BaseMessage):
# standard fields expected in an OpenAI ChatCompletion message object
role=MessageRole(openai_message_dict["role"]),
content=content,
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
name=name,
tool_calls=tool_calls,
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
created_at=created_at,
tool_returns=tool_returns,
group_id=group_id,
)
def to_openai_dict_search_results(self, max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN) -> dict:
@@ -579,9 +596,6 @@ class Message(BaseMessage):
"content": text_content,
"role": self.role,
}
# Optional field, do not include if null
if self.name is not None:
openai_message["name"] = self.name
elif self.role == "user":
assert all([v is not None for v in [text_content, self.role]]), vars(self)
@@ -589,9 +603,6 @@ class Message(BaseMessage):
"content": text_content,
"role": self.role,
}
# Optional field, do not include if null
if self.name is not None:
openai_message["name"] = self.name
elif self.role == "assistant":
assert self.tool_calls is not None or text_content is not None
@@ -599,9 +610,7 @@ class Message(BaseMessage):
"content": None if put_inner_thoughts_in_kwargs else text_content,
"role": self.role,
}
# Optional fields, do not include if null
if self.name is not None:
openai_message["name"] = self.name
if self.tool_calls is not None:
if put_inner_thoughts_in_kwargs:
# put the inner thoughts inside the tool call before casting to a dict

View File

@@ -465,6 +465,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# if we expect `reasoning_content``, then that's what gets mapped to ReasoningMessage
# and `content` needs to be handled outside the interface
expect_reasoning_content: bool = False,
name: Optional[str] = None,
) -> Optional[Union[ReasoningMessage, ToolCallMessage, AssistantMessage]]:
"""
Example data from non-streaming response looks like:
@@ -497,6 +498,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
reasoning=message_delta.reasoning_content,
signature=message_delta.reasoning_content_signature,
source="reasoner_model" if message_delta.reasoning_content_signature else "non_reasoner_model",
name=name,
)
elif expect_reasoning_content and message_delta.redacted_reasoning_content is not None:
processed_chunk = HiddenReasoningMessage(
@@ -504,6 +506,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
date=message_date,
hidden_reasoning=message_delta.redacted_reasoning_content,
state="redacted",
name=name,
)
elif expect_reasoning_content and message_delta.content is not None:
# "ignore" content if we expect reasoning content
@@ -530,6 +533,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
arguments=json.dumps(json_reasoning_content.get("arguments")),
tool_call_id=None,
),
name=name,
)
except json.JSONDecodeError as e:
@@ -559,6 +563,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
id=message_id,
date=message_date,
reasoning=message_delta.content,
name=name,
)
# tool calls
@@ -607,7 +612,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# TODO: Assumes consistent state and that prev_content is subset of new_content
diff = new_content.replace(prev_content, "", 1)
self.current_json_parse_result = parsed_args
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff)
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name)
else:
return None
@@ -639,6 +644,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
arguments=tool_call_delta.get("arguments"),
tool_call_id=tool_call_delta.get("id"),
),
name=name,
)
elif self.inner_thoughts_in_kwargs and tool_call.function:
@@ -674,6 +680,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
id=message_id,
date=message_date,
reasoning=updates_inner_thoughts,
name=name,
)
# Additionally inner thoughts may stream back with a chunk of main JSON
# In that case, since we can only return a chunk at a time, we should buffer it
@@ -709,6 +716,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
arguments=None,
tool_call_id=self.function_id_buffer,
),
name=name,
)
# Record what the last function name we flushed was
@@ -765,6 +773,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
id=message_id,
date=message_date,
content=combined_chunk,
name=name,
)
# Store the ID of the tool call so allow skipping the corresponding response
if self.function_id_buffer:
@@ -789,7 +798,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# TODO: Assumes consistent state and that prev_content is subset of new_content
diff = new_content.replace(prev_content, "", 1)
self.current_json_parse_result = parsed_args
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff)
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name)
else:
return None
@@ -813,6 +822,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
arguments=combined_chunk,
tool_call_id=self.function_id_buffer,
),
name=name,
)
# clear buffer
self.function_args_buffer = None
@@ -827,6 +837,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
arguments=updates_main_json,
tool_call_id=self.function_id_buffer,
),
name=name,
)
self.function_id_buffer = None
@@ -955,6 +966,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
arguments=tool_call_delta.get("arguments"),
tool_call_id=tool_call_delta.get("id"),
),
name=name,
)
elif choice.finish_reason is not None:
@@ -1035,6 +1047,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
message_id: str,
message_date: datetime,
expect_reasoning_content: bool = False,
name: Optional[str] = None,
):
"""Process a streaming chunk from an OpenAI-compatible server.
@@ -1060,6 +1073,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
message_id=message_id,
message_date=message_date,
expect_reasoning_content=expect_reasoning_content,
name=name,
)
if processed_chunk is None:
@@ -1087,6 +1101,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
id=msg_obj.id,
date=msg_obj.created_at,
reasoning=msg,
name=msg_obj.name,
)
self._push_to_buffer(processed_chunk)
@@ -1097,6 +1112,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
id=msg_obj.id,
date=msg_obj.created_at,
reasoning=content.text,
name=msg_obj.name,
)
elif isinstance(content, ReasoningContent):
processed_chunk = ReasoningMessage(
@@ -1105,6 +1121,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
source="reasoner_model",
reasoning=content.reasoning,
signature=content.signature,
name=msg_obj.name,
)
elif isinstance(content, RedactedReasoningContent):
processed_chunk = HiddenReasoningMessage(
@@ -1112,6 +1129,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
date=msg_obj.created_at,
state="redacted",
hidden_reasoning=content.data,
name=msg_obj.name,
)
self._push_to_buffer(processed_chunk)
@@ -1172,6 +1190,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
id=msg_obj.id,
date=msg_obj.created_at,
content=func_args["message"],
name=msg_obj.name,
)
self._push_to_buffer(processed_chunk)
except Exception as e:
@@ -1194,6 +1213,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
id=msg_obj.id,
date=msg_obj.created_at,
content=func_args[self.assistant_message_tool_kwarg],
name=msg_obj.name,
)
# Store the ID of the tool call so allow skipping the corresponding response
self.prev_assistant_message_id = function_call.id
@@ -1206,6 +1226,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
arguments=function_call.function.arguments,
tool_call_id=function_call.id,
),
name=msg_obj.name,
)
# processed_chunk = {
@@ -1245,6 +1266,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
tool_call_id=msg_obj.tool_call_id,
stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None,
stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None,
name=msg_obj.name,
)
elif msg.startswith("Error: "):
@@ -1259,6 +1281,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
tool_call_id=msg_obj.tool_call_id,
stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None,
stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None,
name=msg_obj.name,
)
else:

View File

@@ -527,6 +527,7 @@ def list_messages(
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
limit: int = Query(10, description="Maximum number of messages to retrieve."),
group_id: Optional[str] = Query(None, description="Group ID to filter messages by."),
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
@@ -543,6 +544,7 @@ def list_messages(
after=after,
before=before,
limit=limit,
group_id=group_id,
reverse=True,
return_message_object=False,
use_assistant_message=use_assistant_message,

View File

@@ -1,11 +1,13 @@
from typing import Annotated, List, Optional
from fastapi import APIRouter, Body, Depends, Header, Query
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
from fastapi.responses import JSONResponse
from pydantic import Field
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.group import Group, GroupCreate, ManagerType
from letta.schemas.letta_message import LettaMessageUnion
from letta.orm.errors import NoResultFound
from letta.schemas.group import Group, GroupCreate, GroupUpdate, ManagerType
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.server.rest_api.utils import get_letta_server
@@ -14,21 +16,6 @@ from letta.server.server import SyncServer
router = APIRouter(prefix="/groups", tags=["groups"])
@router.post("/", response_model=Group, operation_id="create_group")
async def create_group(
server: SyncServer = Depends(get_letta_server),
request: GroupCreate = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Create a multi-agent group with a specified management pattern. When no
management config is specified, this endpoint will use round robin for
speaker selection.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.group_manager.create_group(request, actor=actor)
@router.get("/", response_model=List[Group], operation_id="list_groups")
def list_groups(
server: "SyncServer" = Depends(get_letta_server),
@@ -53,6 +40,23 @@ def list_groups(
)
@router.get("/{group_id}", response_model=Group, operation_id="retrieve_group")
def retrieve_group(
group_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the group by id.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
return server.group_manager.retrieve_group(group_id=group_id, actor=actor)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@router.post("/", response_model=Group, operation_id="create_group")
def create_group(
group: GroupCreate = Body(...),
@@ -70,9 +74,10 @@ def create_group(
raise HTTPException(status_code=500, detail=str(e))
@router.put("/", response_model=Group, operation_id="upsert_group")
def upsert_group(
group: GroupCreate = Body(...),
@router.put("/{group_id}", response_model=Group, operation_id="modify_group")
def modify_group(
group_id: str,
group: GroupUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
@@ -82,7 +87,7 @@ def upsert_group(
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.group_manager.create_group(group, actor=actor)
return server.group_manager.modify_group(group_id=group_id, group_update=group, actor=actor)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -110,7 +115,7 @@ def delete_group(
operation_id="send_group_message",
)
async def send_group_message(
agent_id: str,
group_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@@ -178,6 +183,22 @@ GroupMessagesResponse = Annotated[
]
@router.patch("/{group_id}/messages/{message_id}", response_model=LettaMessageUnion, operation_id="modify_group_message")
def modify_group_message(
group_id: str,
message_id: str,
request: LettaMessageUpdateUnion = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the details of a message associated with a group.
"""
# TODO: support modifying tool calls/returns
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor)
@router.get("/{group_id}/messages", response_model=GroupMessagesResponse, operation_id="list_group_messages")
def list_group_messages(
group_id: str,
@@ -194,40 +215,42 @@ def list_group_messages(
Retrieve message history for a group.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.group_manager.list_group_messages(
group_id=group_id,
before=before,
after=after,
limit=limit,
actor=actor,
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
)
group = server.group_manager.retrieve_group(group_id=group_id, actor=actor)
if group.manager_agent_id:
return server.get_agent_recall(
user_id=actor.id,
agent_id=group.manager_agent_id,
after=after,
before=before,
limit=limit,
group_id=group_id,
reverse=True,
return_message_object=False,
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
)
else:
return server.group_manager.list_group_messages(
group_id=group_id,
after=after,
before=before,
limit=limit,
actor=actor,
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
)
'''
@router.patch("/{group_id}/reset-messages", response_model=None, operation_id="reset_group_messages")
def reset_group_messages(
group_id: str,
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Resets the messages for all agents that are part of the multi-agent group.
TODO: only delete group messages not all messages!
Delete the group messages for all agents that are part of the multi-agent group.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
group = server.group_manager.retrieve_group(group_id=group_id, actor=actor)
agent_ids = group.agent_ids
if group.manager_agent_id:
agent_ids.append(group.manager_agent_id)
for agent_id in agent_ids:
server.agent_manager.reset_messages(
agent_id=agent_id,
actor=actor,
add_default_initial_messages=add_default_initial_messages,
)
'''
server.group_manager.reset_messages(group_id=group_id, actor=actor)

View File

@@ -211,7 +211,6 @@ def create_tool_call_messages_from_openai_response(
tool_calls=[],
tool_call_id=tool_call_id,
created_at=get_utc_time(),
name=function_name,
)
messages.append(tool_message)

View File

@@ -367,6 +367,9 @@ class SyncServer(Server):
def load_multi_agent(
self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None
) -> Agent:
if len(group.agent_ids) == 0:
raise ValueError("Empty group: group must have at least one agent")
match group.manager_type:
case ManagerType.round_robin:
agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.agent_ids[0], actor=actor)
@@ -862,6 +865,7 @@ class SyncServer(Server):
after: Optional[str] = None,
before: Optional[str] = None,
limit: Optional[int] = 100,
group_id: Optional[str] = None,
reverse: Optional[bool] = False,
return_message_object: bool = True,
use_assistant_message: bool = True,
@@ -879,6 +883,7 @@ class SyncServer(Server):
before=before,
limit=limit,
ascending=not reverse,
group_id=group_id,
)
if not return_message_object:
@@ -1591,88 +1596,76 @@ class SyncServer(Server):
) -> Union[StreamingResponse, LettaResponse]:
include_final_message = True
if not stream_steps and stream_tokens:
raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")
raise ValueError("stream_steps must be 'true' if stream_tokens is 'true'")
try:
# fetch the group
group = self.group_manager.retrieve_group(group_id=group_id, actor=actor)
letta_multi_agent = self.load_multi_agent(group=group, actor=actor)
group = self.group_manager.retrieve_group(group_id=group_id, actor=actor)
letta_multi_agent = self.load_multi_agent(group=group, actor=actor)
llm_config = letta_multi_agent.agent_state.llm_config
supports_token_streaming = ["openai", "anthropic", "deepseek"]
if stream_tokens and (
llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint
):
warnings.warn(
f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
)
stream_tokens = False
llm_config = letta_multi_agent.agent_state.llm_config
supports_token_streaming = ["openai", "anthropic", "deepseek"]
if stream_tokens and (
llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint
):
warnings.warn(
f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
)
stream_tokens = False
# Create a new interface per request
letta_multi_agent.interface = StreamingServerInterface(
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
inner_thoughts_in_kwargs=(
llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
# Create a new interface per request
letta_multi_agent.interface = StreamingServerInterface(
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
inner_thoughts_in_kwargs=(
llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
),
)
streaming_interface = letta_multi_agent.interface
if not isinstance(streaming_interface, StreamingServerInterface):
raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
streaming_interface.streaming_mode = stream_tokens
streaming_interface.streaming_chat_completion_mode = chat_completion_mode
if metadata and hasattr(streaming_interface, "metadata"):
streaming_interface.metadata = metadata
streaming_interface.stream_start()
task = asyncio.create_task(
asyncio.to_thread(
letta_multi_agent.step,
messages=messages,
chaining=self.chaining,
max_chaining_steps=self.max_chaining_steps,
)
)
if stream_steps:
# return a stream
return StreamingResponse(
sse_async_generator(
streaming_interface.get_generator(),
usage_task=task,
finish_message=include_final_message,
),
)
streaming_interface = letta_multi_agent.interface
if not isinstance(streaming_interface, StreamingServerInterface):
raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
streaming_interface.streaming_mode = stream_tokens
streaming_interface.streaming_chat_completion_mode = chat_completion_mode
if metadata and hasattr(streaming_interface, "metadata"):
streaming_interface.metadata = metadata
streaming_interface.stream_start()
task = asyncio.create_task(
asyncio.to_thread(
letta_multi_agent.step,
messages=messages,
chaining=self.chaining,
max_chaining_steps=self.max_chaining_steps,
)
media_type="text/event-stream",
)
if stream_steps:
# return a stream
return StreamingResponse(
sse_async_generator(
streaming_interface.get_generator(),
usage_task=task,
finish_message=include_final_message,
),
media_type="text/event-stream",
)
else:
# buffer the stream, then return the list
generated_stream = []
async for message in streaming_interface.get_generator():
assert (
isinstance(message, LettaMessage) or isinstance(message, LegacyLettaMessage) or isinstance(message, MessageStreamStatus)
), type(message)
generated_stream.append(message)
if message == MessageStreamStatus.done:
break
else:
# buffer the stream, then return the list
generated_stream = []
async for message in streaming_interface.get_generator():
assert (
isinstance(message, LettaMessage)
or isinstance(message, LegacyLettaMessage)
or isinstance(message, MessageStreamStatus)
), type(message)
generated_stream.append(message)
if message == MessageStreamStatus.done:
break
# Get rid of the stream status messages
filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
usage = await task
# Get rid of the stream status messages
filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
usage = await task
# By default the stream will be messages of type LettaMessage or LettaLegacyMessage
# If we want to convert these to Message, we can use the attached IDs
# NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
# TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
return LettaResponse(messages=filtered_stream, usage=usage)
except HTTPException:
raise
except Exception as e:
print(e)
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"{e}")
# By default the stream will be messages of type LettaMessage or LettaLegacyMessage
# If we want to convert these to Message, we can use the attached IDs
# NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
# TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
return LettaResponse(messages=filtered_stream, usage=usage)

View File

@@ -7,7 +7,9 @@ from letta.orm.errors import NoResultFound
from letta.orm.group import Group as GroupModel
from letta.orm.message import Message as MessageModel
from letta.schemas.group import Group as PydanticGroup
from letta.schemas.group import GroupCreate, ManagerType
from letta.schemas.group import GroupCreate, GroupUpdate, ManagerType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
@@ -22,12 +24,12 @@ class GroupManager:
@enforce_types
def list_groups(
self,
actor: PydanticUser,
project_id: Optional[str] = None,
manager_type: Optional[ManagerType] = None,
before: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
actor: PydanticUser = None,
) -> list[PydanticGroup]:
with self.session_maker() as session:
filters = {"organization_id": actor.organization_id}
@@ -56,27 +58,66 @@ class GroupManager:
new_group = GroupModel()
new_group.organization_id = actor.organization_id
new_group.description = group.description
match group.manager_config.manager_type:
case ManagerType.round_robin:
new_group.manager_type = ManagerType.round_robin
new_group.max_turns = group.manager_config.max_turns
case ManagerType.dynamic:
new_group.manager_type = ManagerType.dynamic
new_group.manager_agent_id = group.manager_config.manager_agent_id
new_group.max_turns = group.manager_config.max_turns
new_group.termination_token = group.manager_config.termination_token
case ManagerType.supervisor:
new_group.manager_type = ManagerType.supervisor
new_group.manager_agent_id = group.manager_config.manager_agent_id
case _:
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False)
if group.manager_config is None:
new_group.manager_type = ManagerType.round_robin
else:
match group.manager_config.manager_type:
case ManagerType.round_robin:
new_group.manager_type = ManagerType.round_robin
new_group.max_turns = group.manager_config.max_turns
case ManagerType.dynamic:
new_group.manager_type = ManagerType.dynamic
new_group.manager_agent_id = group.manager_config.manager_agent_id
new_group.max_turns = group.manager_config.max_turns
new_group.termination_token = group.manager_config.termination_token
case ManagerType.supervisor:
new_group.manager_type = ManagerType.supervisor
new_group.manager_agent_id = group.manager_config.manager_agent_id
case _:
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
new_group.create(session, actor=actor)
return new_group.to_pydantic()
@enforce_types
def modify_group(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
with self.session_maker() as session:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
max_turns = None
termination_token = None
manager_agent_id = None
if group_update.manager_config:
if group_update.manager_config.manager_type != group.manager_type:
raise ValueError(f"Cannot change group pattern after creation")
match group_update.manager_config.manager_type:
case ManagerType.round_robin:
max_turns = group_update.manager_config.max_turns
case ManagerType.dynamic:
manager_agent_id = group_update.manager_config.manager_agent_id
max_turns = group_update.manager_config.max_turns
termination_token = group_update.manager_config.termination_token
case ManagerType.supervisor:
manager_agent_id = group_update.manager_config.manager_agent_id
case _:
raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}")
if max_turns:
group.max_turns = max_turns
if termination_token:
group.termination_token = termination_token
if manager_agent_id:
group.manager_agent_id = manager_agent_id
if group_update.description:
group.description = group_update.description
if group_update.agent_ids:
self._process_agent_relationship(
session=session, group=group, agent_ids=group_update.agent_ids, allow_partial=False, replace=True
)
group.update(session, actor=actor)
return group.to_pydantic()
@enforce_types
def delete_group(self, group_id: str, actor: PydanticUser) -> None:
with self.session_maker() as session:
@@ -87,23 +128,19 @@ class GroupManager:
@enforce_types
def list_group_messages(
self,
actor: PydanticUser,
group_id: Optional[str] = None,
before: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
actor: PydanticUser = None,
use_assistant_message: bool = True,
assistant_message_tool_name: str = "send_message",
assistant_message_tool_kwarg: str = "message",
) -> list[PydanticGroup]:
) -> list[LettaMessage]:
with self.session_maker() as session:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
agent_id = group.manager_agent_id if group.manager_agent_id else group.agent_ids[0]
filters = {
"organization_id": actor.organization_id,
"group_id": group_id,
"agent_id": agent_id,
}
messages = MessageModel.list(
db_session=session,
@@ -114,21 +151,39 @@ class GroupManager:
)
messages = PydanticMessage.to_letta_messages_from_list(
messages=messages,
messages=[msg.to_pydantic() for msg in messages],
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
)
# TODO: filter messages to return a clean conversation history
return messages
@enforce_types
def reset_messages(self, group_id: str, actor: PydanticUser) -> None:
    """Delete every message tagged with this group_id, across all member agents.

    The group itself (and its agent memberships) is left intact; only the
    shared conversation history is removed.

    Args:
        group_id: Identifier of the group whose messages are deleted.
        actor: User performing the deletion; scopes both the access check
            and the delete to the actor's organization.

    Raises:
        NoResultFound: If the group does not exist or is not visible to the actor.
    """
    with self.session_maker() as session:
        # Ensure group is loadable by user
        # (the returned object is intentionally unused — this read is the access check;
        # it raises before any rows are deleted if the actor cannot see the group)
        group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)

        # Delete all messages in the group
        # Bulk delete in SQL; synchronize_session=False skips reconciling in-memory
        # objects, which is safe because nothing from this session is reused afterwards.
        session.query(MessageModel).filter(
            MessageModel.organization_id == actor.organization_id, MessageModel.group_id == group_id
        ).delete(synchronize_session=False)
        session.commit()
def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):
current_relationship = getattr(group, "agents", [])
if not agent_ids:
if replace:
setattr(group, "agents", [])
setattr(group, "agent_ids", [])
return
if group.manager_type == ManagerType.dynamic and len(agent_ids) != len(set(agent_ids)):
raise ValueError("Duplicate agent ids found in list")
# Retrieve models for the provided IDs
found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all()
@@ -137,11 +192,14 @@ class GroupManager:
missing = set(agent_ids) - {item.id for item in found_items}
raise NoResultFound(f"Items not found in agents: {missing}")
if group.manager_type == ManagerType.dynamic:
names = [item.name for item in found_items]
if len(names) != len(set(names)):
raise ValueError("Duplicate agent names found in the provided agent IDs.")
if replace:
# Replace the relationship
setattr(group, "agents", found_items)
setattr(group, "agent_ids", agent_ids)
else:
# Extend the relationship (only add new items)
current_ids = {item.id for item in current_relationship}
new_items = [item for item in found_items if item.id not in current_ids]
current_relationship.extend(new_items)
raise ValueError("Extend relationship is not supported for groups.")

View File

@@ -264,6 +264,7 @@ class MessageManager:
roles: Optional[Sequence[MessageRole]] = None,
limit: Optional[int] = 50,
ascending: bool = True,
group_id: Optional[str] = None,
) -> List[PydanticMessage]:
"""
Most performant query to list messages for an agent by directly querying the Message table.
@@ -296,6 +297,10 @@ class MessageManager:
# Build a query that directly filters the Message table by agent_id.
query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id)
# If group_id is provided, filter messages by group_id.
if group_id:
query = query.filter(MessageModel.group_id == group_id)
# If query_text is provided, filter messages using subquery.
if query_text:
content_element = func.json_array_elements(MessageModel.content).alias("content_element")

View File

@@ -22,7 +22,7 @@ class SupervisorMultiAgent(Agent):
self,
interface: AgentInterface,
agent_state: AgentState,
user: User = None,
user: User,
# custom
group_id: str = "",
agent_ids: List[str] = [],
@@ -65,6 +65,7 @@ class SupervisorMultiAgent(Agent):
self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)
# override tool rules
old_tool_rules = self.agent_state.tool_rules
self.agent_state.tool_rules = [
InitToolRule(
tool_name="send_message_to_all_agents_in_group",
@@ -106,6 +107,7 @@ class SupervisorMultiAgent(Agent):
raise e
finally:
self.interface.step_yield()
self.agent_state.tool_rules = old_tool_rules
self.interface.step_complete()

View File

@@ -5,7 +5,7 @@ from letta.config import LettaConfig
from letta.orm import Provider, Step
from letta.schemas.agent import CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.group import DynamicManager, GroupCreate, SupervisorManager
from letta.schemas.group import DynamicManager, GroupCreate, GroupUpdate, ManagerType, RoundRobinManager, SupervisorManager
from letta.schemas.message import MessageCreate
from letta.server.server import SyncServer
@@ -45,7 +45,7 @@ def actor(server, org_id):
@pytest.fixture(scope="module")
def participant_agent_ids(server, actor):
def participant_agents(server, actor):
agent_fred = server.create_agent(
request=CreateAgent(
name="fred",
@@ -102,7 +102,7 @@ def participant_agent_ids(server, actor):
),
actor=actor,
)
yield [agent_fred.id, agent_velma.id, agent_daphne.id, agent_shaggy.id]
yield [agent_fred, agent_velma, agent_daphne, agent_shaggy]
# cleanup
server.agent_manager.delete_agent(agent_fred.id, actor=actor)
@@ -112,7 +112,7 @@ def participant_agent_ids(server, actor):
@pytest.fixture(scope="module")
def manager_agent_id(server, actor):
def manager_agent(server, actor):
agent_scooby = server.create_agent(
request=CreateAgent(
name="scooby",
@@ -131,22 +131,84 @@ def manager_agent_id(server, actor):
),
actor=actor,
)
yield agent_scooby.id
yield agent_scooby
# cleanup
server.agent_manager.delete_agent(agent_scooby.id, actor=actor)
@pytest.mark.asyncio
async def test_round_robin(server, actor, participant_agent_ids):
async def test_empty_group(server, actor):
group = server.group_manager.create_group(
group=GroupCreate(
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
agent_ids=participant_agent_ids,
agent_ids=[],
),
actor=actor,
)
with pytest.raises(ValueError, match="Empty group"):
await server.send_group_message_to_agent(
group_id=group.id,
actor=actor,
messages=[
MessageCreate(
role="user",
content="what is everyone up to for the holidays?",
),
],
stream_steps=False,
stream_tokens=False,
)
server.group_manager.delete_group(group_id=group.id, actor=actor)
@pytest.mark.asyncio
async def test_modify_group_pattern(server, actor, participant_agents, manager_agent):
    """Changing a group's manager pattern after creation must be rejected."""
    # Created without a manager_config, so the group gets the default round_robin pattern.
    group = server.group_manager.create_group(
        group=GroupCreate(
            description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
            agent_ids=[agent.id for agent in participant_agents],
        ),
        actor=actor,
    )
    try:
        # Switching round_robin -> dynamic is a pattern change and must raise.
        with pytest.raises(ValueError, match="Cannot change group pattern"):
            server.group_manager.modify_group(
                group_id=group.id,
                group_update=GroupUpdate(
                    manager_config=DynamicManager(
                        manager_type=ManagerType.dynamic,
                        manager_agent_id=manager_agent.id,
                    ),
                ),
                actor=actor,
            )
    finally:
        # Clean up even if an assertion above fails (consistent with test_round_robin).
        server.group_manager.delete_group(group_id=group.id, actor=actor)
@pytest.mark.asyncio
async def test_round_robin(server, actor, participant_agents):
description = (
"This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries."
)
group = server.group_manager.create_group(
group=GroupCreate(
description=description,
agent_ids=[agent.id for agent in participant_agents],
),
actor=actor,
)
# verify group creation
assert group.manager_type == ManagerType.round_robin
assert group.description == description
assert group.agent_ids == [agent.id for agent in participant_agents]
assert group.max_turns == None
assert group.manager_agent_id is None
assert group.termination_token is None
try:
server.group_manager.reset_messages(group_id=group.id, actor=actor)
response = await server.send_group_message_to_agent(
group_id=group.id,
actor=actor,
@@ -159,15 +221,85 @@ async def test_round_robin(server, actor, participant_agent_ids):
stream_steps=False,
stream_tokens=False,
)
assert response.usage.step_count == len(participant_agent_ids)
assert response.usage.step_count == len(group.agent_ids)
assert len(response.messages) == response.usage.step_count * 2
for i, message in enumerate(response.messages):
assert message.message_type == "reasoning_message" if i % 2 == 0 else "assistant_message"
assert message.name == participant_agents[i // 2].name
for agent_id in group.agent_ids:
agent_messages = server.get_agent_recall(
user_id=actor.id,
agent_id=agent_id,
group_id=group.id,
reverse=True,
return_message_object=False,
)
assert len(agent_messages) == len(group.agent_ids) + 2 # add one for user message, one for reasoning message
# TODO: filter this to return a clean conversation history
messages = server.group_manager.list_group_messages(
group_id=group.id,
actor=actor,
)
assert len(messages) == (len(group.agent_ids) + 2) * len(group.agent_ids)
max_turns = 3
group = server.group_manager.modify_group(
group_id=group.id,
group_update=GroupUpdate(
agent_ids=[agent.id for agent in participant_agents][::-1],
manager_config=RoundRobinManager(
max_turns=max_turns,
),
),
actor=actor,
)
assert group.manager_type == ManagerType.round_robin
assert group.description == description
assert group.agent_ids == [agent.id for agent in participant_agents][::-1]
assert group.max_turns == max_turns
assert group.manager_agent_id is None
assert group.termination_token is None
server.group_manager.reset_messages(group_id=group.id, actor=actor)
response = await server.send_group_message_to_agent(
group_id=group.id,
actor=actor,
messages=[
MessageCreate(
role="user",
content="what is everyone up to for the holidays?",
),
],
stream_steps=False,
stream_tokens=False,
)
assert response.usage.step_count == max_turns
assert len(response.messages) == max_turns * 2
for i, message in enumerate(response.messages):
assert message.message_type == "reasoning_message" if i % 2 == 0 else "assistant_message"
assert message.name == participant_agents[::-1][i // 2].name
for i in range(len(group.agent_ids)):
agent_messages = server.get_agent_recall(
user_id=actor.id,
agent_id=group.agent_ids[i],
group_id=group.id,
reverse=True,
return_message_object=False,
)
expected_message_count = max_turns + 1 if i >= max_turns else max_turns + 2
assert len(agent_messages) == expected_message_count
finally:
server.group_manager.delete_group(group_id=group.id, actor=actor)
@pytest.mark.asyncio
async def test_supervisor(server, actor, participant_agent_ids):
async def test_supervisor(server, actor, participant_agents):
agent_scrappy = server.create_agent(
request=CreateAgent(
name="shaggy",
@@ -186,10 +318,11 @@ async def test_supervisor(server, actor, participant_agent_ids):
),
actor=actor,
)
group = server.group_manager.create_group(
group=GroupCreate(
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
agent_ids=participant_agent_ids,
agent_ids=[agent.id for agent in participant_agents],
manager_config=SupervisorManager(
manager_agent_id=agent_scrappy.id,
),
@@ -219,7 +352,7 @@ async def test_supervisor(server, actor, participant_agent_ids):
and response.messages[1].tool_call.name == "send_message_to_all_agents_in_group"
)
assert response.messages[2].message_type == "tool_return_message" and len(eval(response.messages[2].tool_return)) == len(
participant_agent_ids
participant_agents
)
assert response.messages[3].message_type == "reasoning_message"
assert response.messages[4].message_type == "assistant_message"
@@ -230,13 +363,50 @@ async def test_supervisor(server, actor, participant_agent_ids):
@pytest.mark.asyncio
async def test_dynamic_group_chat(server, actor, manager_agent_id, participant_agent_ids):
async def test_dynamic_group_chat(server, actor, manager_agent, participant_agents):
description = (
"This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries."
)
# error on duplicate agent in participant list
with pytest.raises(ValueError, match="Duplicate agent ids"):
server.group_manager.create_group(
group=GroupCreate(
description=description,
agent_ids=[agent.id for agent in participant_agents] + [participant_agents[0].id],
manager_config=DynamicManager(
manager_agent_id=manager_agent.id,
),
),
actor=actor,
)
# error on duplicate agent names
duplicate_agent_shaggy = server.create_agent(
request=CreateAgent(
name="shaggy",
model="openai/gpt-4o-mini",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
)
with pytest.raises(ValueError, match="Duplicate agent names"):
server.group_manager.create_group(
group=GroupCreate(
description=description,
agent_ids=[agent.id for agent in participant_agents] + [duplicate_agent_shaggy.id],
manager_config=DynamicManager(
manager_agent_id=manager_agent.id,
),
),
actor=actor,
)
server.agent_manager.delete_agent(duplicate_agent_shaggy.id, actor=actor)
group = server.group_manager.create_group(
group=GroupCreate(
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
agent_ids=participant_agent_ids,
description=description,
agent_ids=[agent.id for agent in participant_agents],
manager_config=DynamicManager(
manager_agent_id=manager_agent_id,
manager_agent_id=manager_agent.id,
),
),
actor=actor,
@@ -251,7 +421,7 @@ async def test_dynamic_group_chat(server, actor, manager_agent_id, participant_a
stream_steps=False,
stream_tokens=False,
)
assert response.usage.step_count == len(participant_agent_ids) * 2
assert response.usage.step_count == len(participant_agents) * 2
assert len(response.messages) == response.usage.step_count * 2
finally: