test: add more robust multi-agent testing (#1444)
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
"""add ordered agent ids to groups
|
||||
|
||||
Revision ID: a66510f83fc2
|
||||
Revises: bdddd421ec41
|
||||
Create Date: 2025-03-27 11:11:51.709498
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a66510f83fc2"
|
||||
down_revision: Union[str, None] = "bdddd421ec41"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the ``agent_ids`` JSON column to the ``groups`` table.

    The column is declared NOT NULL. Adding a NOT NULL column without a
    default fails as soon as the table already holds rows, so a server-side
    default of an empty JSON array is supplied to make the migration safe on
    populated databases.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "groups",
        sa.Column(
            "agent_ids",
            sa.JSON(),
            nullable=False,
            # Literal SQL default: existing rows are filled with an empty list.
            server_default=sa.text("'[]'"),
        ),
    )
    # NOTE(review): pre-existing groups get an empty list here; if their
    # membership should be carried over from another table, add a backfill
    # step -- confirm against the group manager before relying on this.
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert this revision by dropping ``agent_ids`` from ``groups``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column(table_name="groups", column_name="agent_ids")
    # ### end Alembic commands ###
|
||||
@@ -220,6 +220,7 @@ class Agent(BaseAgent):
|
||||
messages: List[Message],
|
||||
tool_returns: Optional[List[ToolReturn]] = None,
|
||||
include_function_failed_message: bool = False,
|
||||
group_id: Optional[str] = None,
|
||||
) -> List[Message]:
|
||||
"""
|
||||
Handle error from function call response
|
||||
@@ -240,7 +241,9 @@ class Agent(BaseAgent):
|
||||
"content": function_response,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
name=self.agent_state.name,
|
||||
tool_returns=tool_returns,
|
||||
group_id=group_id,
|
||||
)
|
||||
messages.append(new_message)
|
||||
self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message)
|
||||
@@ -329,6 +332,7 @@ class Agent(BaseAgent):
|
||||
stream=stream,
|
||||
stream_interface=self.interface,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
name=self.agent_state.name,
|
||||
)
|
||||
log_telemetry(self.logger, "_get_ai_reply create finish")
|
||||
|
||||
@@ -372,6 +376,7 @@ class Agent(BaseAgent):
|
||||
# and now we want to use it in the creation of the Message object
|
||||
# TODO figure out a cleaner way to do this
|
||||
response_message_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
) -> Tuple[List[Message], bool, bool]:
|
||||
"""Handles parsing and function execution"""
|
||||
log_telemetry(self.logger, "_handle_ai_response start")
|
||||
@@ -417,6 +422,8 @@ class Agent(BaseAgent):
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict=response_message.model_dump(),
|
||||
name=self.agent_state.name,
|
||||
group_id=group_id,
|
||||
)
|
||||
) # extend conversation with assistant's reply
|
||||
self.logger.debug(f"Function call message: {messages[-1]}")
|
||||
@@ -449,7 +456,7 @@ class Agent(BaseAgent):
|
||||
error_msg = f"No function named {function_name}"
|
||||
function_response = "None" # more like "never ran?"
|
||||
messages = self._handle_function_error_response(
|
||||
error_msg, tool_call_id, function_name, function_args, function_response, messages
|
||||
error_msg, tool_call_id, function_name, function_args, function_response, messages, group_id=group_id
|
||||
)
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
@@ -464,7 +471,7 @@ class Agent(BaseAgent):
|
||||
error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}"
|
||||
function_response = "None" # more like "never ran?"
|
||||
messages = self._handle_function_error_response(
|
||||
error_msg, tool_call_id, function_name, function_args, function_response, messages
|
||||
error_msg, tool_call_id, function_name, function_args, function_response, messages, group_id=group_id
|
||||
)
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
@@ -535,6 +542,7 @@ class Agent(BaseAgent):
|
||||
function_response,
|
||||
messages,
|
||||
[tool_return],
|
||||
group_id=group_id,
|
||||
)
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
@@ -571,6 +579,7 @@ class Agent(BaseAgent):
|
||||
messages,
|
||||
[ToolReturn(status="error", stderr=[error_msg_user])],
|
||||
include_function_failed_message=True,
|
||||
group_id=group_id,
|
||||
)
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
@@ -595,6 +604,7 @@ class Agent(BaseAgent):
|
||||
messages,
|
||||
[tool_return],
|
||||
include_function_failed_message=True,
|
||||
group_id=group_id,
|
||||
)
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
@@ -620,7 +630,9 @@ class Agent(BaseAgent):
|
||||
"content": function_response,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
name=self.agent_state.name,
|
||||
tool_returns=[tool_return] if sandbox_run_result else None,
|
||||
group_id=group_id,
|
||||
)
|
||||
) # extend conversation with function response
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
|
||||
@@ -636,6 +648,8 @@ class Agent(BaseAgent):
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict=response_message.model_dump(),
|
||||
name=self.agent_state.name,
|
||||
group_id=group_id,
|
||||
)
|
||||
) # extend conversation with assistant's reply
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1])
|
||||
@@ -799,7 +813,11 @@ class Agent(BaseAgent):
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
input_message_sequence = in_context_messages + messages
|
||||
|
||||
if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
|
||||
if (
|
||||
len(input_message_sequence) > 1
|
||||
and input_message_sequence[-1].role != "user"
|
||||
and input_message_sequence[-1].group_id is None
|
||||
):
|
||||
self.logger.warning(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
|
||||
|
||||
# Step 2: send the conversation and available functions to the LLM
|
||||
@@ -832,6 +850,7 @@ class Agent(BaseAgent):
|
||||
# TODO this is kind of hacky, find a better way to handle this
|
||||
# the only time we set up message creation ahead of time is when streaming is on
|
||||
response_message_id=response.id if stream else None,
|
||||
group_id=input_message_sequence[-1].group_id,
|
||||
)
|
||||
|
||||
# Step 6: extend the message history
|
||||
|
||||
@@ -16,7 +16,7 @@ class DynamicMultiAgent(Agent):
|
||||
self,
|
||||
interface: AgentInterface,
|
||||
agent_state: AgentState,
|
||||
user: User = None,
|
||||
user: User,
|
||||
# custom
|
||||
group_id: str = "",
|
||||
agent_ids: List[str] = [],
|
||||
@@ -128,7 +128,7 @@ class DynamicMultiAgent(Agent):
|
||||
)
|
||||
for message in assistant_messages
|
||||
]
|
||||
message_index[agent_id] = len(chat_history) + len(new_messages)
|
||||
message_index[speaker_id] = len(chat_history) + len(new_messages)
|
||||
|
||||
# sum usage
|
||||
total_usage.prompt_tokens += usage_stats.prompt_tokens
|
||||
@@ -251,10 +251,10 @@ class DynamicMultiAgent(Agent):
|
||||
chat_history: List[Message],
|
||||
agent_id_options: List[str],
|
||||
) -> Message:
|
||||
chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history]
|
||||
text_chat_history = [f"{message.name or 'user'}: {message.content[0].text}" for message in chat_history]
|
||||
for message in new_messages:
|
||||
chat_history.append(f"{message.name or 'user'}: {message.content}")
|
||||
context_messages = "\n".join(chat_history)
|
||||
text_chat_history.append(f"{message.name or 'user'}: {message.content}")
|
||||
context_messages = "\n".join(text_chat_history)
|
||||
|
||||
message_text = (
|
||||
"Choose the most suitable agent to reply to the latest message in the "
|
||||
|
||||
@@ -859,6 +859,7 @@ def anthropic_chat_completions_process_stream(
|
||||
create_message_id: bool = True,
|
||||
create_message_datetime: bool = True,
|
||||
betas: List[str] = ["tools-2024-04-04"],
|
||||
name: Optional[str] = None,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Process a streaming completion response from Anthropic, similar to OpenAI's streaming.
|
||||
|
||||
@@ -951,6 +952,7 @@ def anthropic_chat_completions_process_stream(
|
||||
# if extended_thinking is on, then reasoning_content will be flowing as chunks
|
||||
# TODO handle emitting redacted reasoning content (e.g. as concat?)
|
||||
expect_reasoning_content=extended_thinking,
|
||||
name=name,
|
||||
)
|
||||
elif isinstance(stream_interface, AgentRefreshStreamingInterface):
|
||||
stream_interface.process_refresh(chat_completion_response)
|
||||
|
||||
@@ -140,6 +140,7 @@ def create(
|
||||
stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
|
||||
model_settings: Optional[dict] = None, # TODO: eventually pass from server
|
||||
put_inner_thoughts_first: bool = True,
|
||||
name: Optional[str] = None,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Return response to chat completion with backoff"""
|
||||
from letta.utils import printd
|
||||
@@ -206,6 +207,7 @@ def create(
|
||||
api_key=api_key,
|
||||
chat_completion_request=data,
|
||||
stream_interface=stream_interface,
|
||||
name=name,
|
||||
)
|
||||
else: # Client did not request token streaming (expect a blocking backend response)
|
||||
data.stream = False
|
||||
@@ -255,6 +257,7 @@ def create(
|
||||
api_key=api_key,
|
||||
chat_completion_request=data,
|
||||
stream_interface=stream_interface,
|
||||
name=name,
|
||||
)
|
||||
else: # Client did not request token streaming (expect a blocking backend response)
|
||||
data.stream = False
|
||||
@@ -359,6 +362,7 @@ def create(
|
||||
stream_interface=stream_interface,
|
||||
extended_thinking=llm_config.enable_reasoner,
|
||||
max_reasoning_tokens=llm_config.max_reasoning_tokens,
|
||||
name=name,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -531,6 +535,7 @@ def create(
|
||||
api_key=model_settings.deepseek_api_key,
|
||||
chat_completion_request=data,
|
||||
stream_interface=stream_interface,
|
||||
name=name,
|
||||
)
|
||||
else: # Client did not request token streaming (expect a blocking backend response)
|
||||
data.stream = False
|
||||
|
||||
@@ -185,6 +185,7 @@ def openai_chat_completions_process_stream(
|
||||
# however, we don't necessarily want to put these
|
||||
# expect_reasoning_content: bool = False,
|
||||
expect_reasoning_content: bool = True,
|
||||
name: Optional[str] = None,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Process a streaming completion response, and return a ChatCompletionRequest at the end.
|
||||
|
||||
@@ -272,6 +273,7 @@ def openai_chat_completions_process_stream(
|
||||
message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id,
|
||||
message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created,
|
||||
expect_reasoning_content=expect_reasoning_content,
|
||||
name=name,
|
||||
)
|
||||
elif isinstance(stream_interface, AgentRefreshStreamingInterface):
|
||||
stream_interface.process_refresh(chat_completion_response)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy import JSON, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
@@ -23,11 +23,8 @@ class Group(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="groups")
|
||||
agent_ids: Mapped[List[str]] = mapped_column(JSON, nullable=False, doc="Ordered list of agent IDs in this group")
|
||||
agents: Mapped[List["Agent"]] = relationship(
|
||||
"Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups"
|
||||
)
|
||||
manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group")
|
||||
|
||||
@property
|
||||
def agent_ids(self) -> List[str]:
|
||||
return [agent.id for agent in self.agents]
|
||||
|
||||
@@ -14,7 +14,7 @@ class RoundRobinMultiAgent(Agent):
|
||||
self,
|
||||
interface: AgentInterface,
|
||||
agent_state: AgentState,
|
||||
user: User = None,
|
||||
user: User,
|
||||
# custom
|
||||
group_id: str = "",
|
||||
agent_ids: List[str] = [],
|
||||
@@ -45,7 +45,7 @@ class RoundRobinMultiAgent(Agent):
|
||||
for agent_id in self.agent_ids:
|
||||
agents[agent_id] = self.load_participant_agent(agent_id=agent_id)
|
||||
|
||||
message_index = {}
|
||||
message_index = {agent_id: 0 for agent_id in self.agent_ids}
|
||||
chat_history: List[Message] = []
|
||||
new_messages = messages
|
||||
speaker_id = None
|
||||
@@ -91,7 +91,7 @@ class RoundRobinMultiAgent(Agent):
|
||||
MessageCreate(
|
||||
role="system",
|
||||
content=message.content,
|
||||
name=participant_agent.agent_state.name,
|
||||
name=message.name,
|
||||
)
|
||||
for message in assistant_messages
|
||||
]
|
||||
@@ -138,10 +138,21 @@ class RoundRobinMultiAgent(Agent):
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
|
||||
persona_block = agent_state.memory.get_block(label="persona")
|
||||
group_chat_participant_persona = (
|
||||
"\n\n====Group Chat Contex===="
|
||||
f"\nYou are speaking in a group chat with {len(self.agent_ids) - 1} other "
|
||||
"agents and one user. Respond to new messages in the group chat when prompted. "
|
||||
f"Description of the group: {self.description}"
|
||||
f"%%% GROUP CHAT CONTEXT %%% "
|
||||
f"You are speaking in a group chat with {len(self.agent_ids)} other participants. "
|
||||
f"Group Description: {self.description} "
|
||||
"INTERACTION GUIDELINES:\n"
|
||||
"1. Be aware that others can see your messages - communicate as if in a real group conversation\n"
|
||||
"2. Acknowledge and build upon others' contributions when relevant\n"
|
||||
"3. Stay on topic while adding your unique perspective based on your role and personality\n"
|
||||
"4. Be concise but engaging - give others space to contribute\n"
|
||||
"5. Maintain your character's personality while being collaborative\n"
|
||||
"6. Feel free to ask questions to other participants to encourage discussion\n"
|
||||
"7. If someone addresses you directly, acknowledge their message\n"
|
||||
"8. Share relevant experiences or knowledge that adds value to the conversation\n\n"
|
||||
"Remember: This is a natural group conversation. Interact as you would in a real group setting, "
|
||||
"staying true to your character while fostering meaningful dialogue. "
|
||||
"%%% END GROUP CHAT CONTEXT %%%"
|
||||
)
|
||||
agent_state.memory.update_block_value(label="persona", value=persona_block.value + group_chat_participant_persona)
|
||||
return Agent(
|
||||
|
||||
@@ -62,4 +62,10 @@ ManagerConfigUnion = Annotated[
|
||||
class GroupCreate(BaseModel):
    # Payload model with three fields; every schema description is left empty.

    # Required list of string IDs (presumably the member agents -- confirm).
    agent_ids: List[str] = Field(default=..., description="")

    # Required string describing the group.
    description: str = Field(default=..., description="")

    # Falls back to a RoundRobinManager instance when the caller omits it.
    manager_config: ManagerConfigUnion = Field(default=RoundRobinManager(), description="")
|
||||
|
||||
|
||||
class GroupUpdate(BaseModel):
    # Payload model in which every field is optional and defaults to None.

    # Optional replacement list of string IDs.
    agent_ids: Optional[List[str]] = Field(default=None, description="")

    # Optional replacement description string.
    description: Optional[str] = Field(default=None, description="")

    # Optional replacement manager configuration.
    manager_config: Optional[ManagerConfigUnion] = Field(default=None, description="")
|
||||
|
||||
@@ -226,6 +226,7 @@ class Message(BaseMessage):
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
reasoning=self.content[0].text,
|
||||
name=self.name,
|
||||
)
|
||||
)
|
||||
# Otherwise, we may have a list of multiple types
|
||||
@@ -239,6 +240,7 @@ class Message(BaseMessage):
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
reasoning=content_part.text,
|
||||
name=self.name,
|
||||
)
|
||||
)
|
||||
elif isinstance(content_part, ReasoningContent):
|
||||
@@ -250,6 +252,7 @@ class Message(BaseMessage):
|
||||
reasoning=content_part.reasoning,
|
||||
source="reasoner_model", # TODO do we want to tag like this?
|
||||
signature=content_part.signature,
|
||||
name=self.name,
|
||||
)
|
||||
)
|
||||
elif isinstance(content_part, RedactedReasoningContent):
|
||||
@@ -260,6 +263,7 @@ class Message(BaseMessage):
|
||||
date=self.created_at,
|
||||
state="redacted",
|
||||
hidden_reasoning=content_part.data,
|
||||
name=self.name,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -282,6 +286,7 @@ class Message(BaseMessage):
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
content=message_string,
|
||||
name=self.name,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -294,6 +299,7 @@ class Message(BaseMessage):
|
||||
arguments=tool_call.function.arguments,
|
||||
tool_call_id=tool_call.id,
|
||||
),
|
||||
name=self.name,
|
||||
)
|
||||
)
|
||||
elif self.role == MessageRole.tool:
|
||||
@@ -334,6 +340,7 @@ class Message(BaseMessage):
|
||||
tool_call_id=self.tool_call_id,
|
||||
stdout=self.tool_returns[0].stdout if self.tool_returns else None,
|
||||
stderr=self.tool_returns[0].stderr if self.tool_returns else None,
|
||||
name=self.name,
|
||||
)
|
||||
)
|
||||
elif self.role == MessageRole.user:
|
||||
@@ -349,6 +356,7 @@ class Message(BaseMessage):
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
content=message_str or text_content,
|
||||
name=self.name,
|
||||
)
|
||||
)
|
||||
elif self.role == MessageRole.system:
|
||||
@@ -363,6 +371,7 @@ class Message(BaseMessage):
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
content=text_content,
|
||||
name=self.name,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -379,6 +388,8 @@ class Message(BaseMessage):
|
||||
allow_functions_style: bool = False, # allow deprecated functions style?
|
||||
created_at: Optional[datetime] = None,
|
||||
id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
tool_returns: Optional[List[ToolReturn]] = None,
|
||||
):
|
||||
"""Convert a ChatCompletion message object into a Message object (synced to DB)"""
|
||||
@@ -426,12 +437,13 @@ class Message(BaseMessage):
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
role=MessageRole.tool, # NOTE
|
||||
content=content,
|
||||
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
|
||||
name=name,
|
||||
tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
|
||||
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
|
||||
created_at=created_at,
|
||||
id=str(id),
|
||||
tool_returns=tool_returns,
|
||||
group_id=group_id,
|
||||
)
|
||||
else:
|
||||
return Message(
|
||||
@@ -440,11 +452,12 @@ class Message(BaseMessage):
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
role=MessageRole.tool, # NOTE
|
||||
content=content,
|
||||
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
|
||||
name=name,
|
||||
tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
|
||||
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
|
||||
created_at=created_at,
|
||||
tool_returns=tool_returns,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
elif "function_call" in openai_message_dict and openai_message_dict["function_call"] is not None:
|
||||
@@ -473,12 +486,13 @@ class Message(BaseMessage):
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
role=MessageRole(openai_message_dict["role"]),
|
||||
content=content,
|
||||
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
|
||||
name=name,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool'
|
||||
created_at=created_at,
|
||||
id=str(id),
|
||||
tool_returns=tool_returns,
|
||||
group_id=group_id,
|
||||
)
|
||||
else:
|
||||
return Message(
|
||||
@@ -492,6 +506,7 @@ class Message(BaseMessage):
|
||||
tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool'
|
||||
created_at=created_at,
|
||||
tool_returns=tool_returns,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -520,12 +535,13 @@ class Message(BaseMessage):
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
role=MessageRole(openai_message_dict["role"]),
|
||||
content=content,
|
||||
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
|
||||
name=name,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
|
||||
created_at=created_at,
|
||||
id=str(id),
|
||||
tool_returns=tool_returns,
|
||||
group_id=group_id,
|
||||
)
|
||||
else:
|
||||
return Message(
|
||||
@@ -534,11 +550,12 @@ class Message(BaseMessage):
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
role=MessageRole(openai_message_dict["role"]),
|
||||
content=content,
|
||||
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
|
||||
name=name,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
|
||||
created_at=created_at,
|
||||
tool_returns=tool_returns,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
def to_openai_dict_search_results(self, max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN) -> dict:
|
||||
@@ -579,9 +596,6 @@ class Message(BaseMessage):
|
||||
"content": text_content,
|
||||
"role": self.role,
|
||||
}
|
||||
# Optional field, do not include if null
|
||||
if self.name is not None:
|
||||
openai_message["name"] = self.name
|
||||
|
||||
elif self.role == "user":
|
||||
assert all([v is not None for v in [text_content, self.role]]), vars(self)
|
||||
@@ -589,9 +603,6 @@ class Message(BaseMessage):
|
||||
"content": text_content,
|
||||
"role": self.role,
|
||||
}
|
||||
# Optional field, do not include if null
|
||||
if self.name is not None:
|
||||
openai_message["name"] = self.name
|
||||
|
||||
elif self.role == "assistant":
|
||||
assert self.tool_calls is not None or text_content is not None
|
||||
@@ -599,9 +610,7 @@ class Message(BaseMessage):
|
||||
"content": None if put_inner_thoughts_in_kwargs else text_content,
|
||||
"role": self.role,
|
||||
}
|
||||
# Optional fields, do not include if null
|
||||
if self.name is not None:
|
||||
openai_message["name"] = self.name
|
||||
|
||||
if self.tool_calls is not None:
|
||||
if put_inner_thoughts_in_kwargs:
|
||||
# put the inner thoughts inside the tool call before casting to a dict
|
||||
|
||||
@@ -465,6 +465,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# if we expect `reasoning_content``, then that's what gets mapped to ReasoningMessage
|
||||
# and `content` needs to be handled outside the interface
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
) -> Optional[Union[ReasoningMessage, ToolCallMessage, AssistantMessage]]:
|
||||
"""
|
||||
Example data from non-streaming response looks like:
|
||||
@@ -497,6 +498,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
reasoning=message_delta.reasoning_content,
|
||||
signature=message_delta.reasoning_content_signature,
|
||||
source="reasoner_model" if message_delta.reasoning_content_signature else "non_reasoner_model",
|
||||
name=name,
|
||||
)
|
||||
elif expect_reasoning_content and message_delta.redacted_reasoning_content is not None:
|
||||
processed_chunk = HiddenReasoningMessage(
|
||||
@@ -504,6 +506,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
date=message_date,
|
||||
hidden_reasoning=message_delta.redacted_reasoning_content,
|
||||
state="redacted",
|
||||
name=name,
|
||||
)
|
||||
elif expect_reasoning_content and message_delta.content is not None:
|
||||
# "ignore" content if we expect reasoning content
|
||||
@@ -530,6 +533,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
arguments=json.dumps(json_reasoning_content.get("arguments")),
|
||||
tool_call_id=None,
|
||||
),
|
||||
name=name,
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -559,6 +563,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
reasoning=message_delta.content,
|
||||
name=name,
|
||||
)
|
||||
|
||||
# tool calls
|
||||
@@ -607,7 +612,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# TODO: Assumes consistent state and that prev_content is subset of new_content
|
||||
diff = new_content.replace(prev_content, "", 1)
|
||||
self.current_json_parse_result = parsed_args
|
||||
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff)
|
||||
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -639,6 +644,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
arguments=tool_call_delta.get("arguments"),
|
||||
tool_call_id=tool_call_delta.get("id"),
|
||||
),
|
||||
name=name,
|
||||
)
|
||||
|
||||
elif self.inner_thoughts_in_kwargs and tool_call.function:
|
||||
@@ -674,6 +680,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
reasoning=updates_inner_thoughts,
|
||||
name=name,
|
||||
)
|
||||
# Additionally inner thoughts may stream back with a chunk of main JSON
|
||||
# In that case, since we can only return a chunk at a time, we should buffer it
|
||||
@@ -709,6 +716,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
arguments=None,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
name=name,
|
||||
)
|
||||
|
||||
# Record what the last function name we flushed was
|
||||
@@ -765,6 +773,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
content=combined_chunk,
|
||||
name=name,
|
||||
)
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
@@ -789,7 +798,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# TODO: Assumes consistent state and that prev_content is subset of new_content
|
||||
diff = new_content.replace(prev_content, "", 1)
|
||||
self.current_json_parse_result = parsed_args
|
||||
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff)
|
||||
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -813,6 +822,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
arguments=combined_chunk,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
name=name,
|
||||
)
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
@@ -827,6 +837,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
arguments=updates_main_json,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
name=name,
|
||||
)
|
||||
self.function_id_buffer = None
|
||||
|
||||
@@ -955,6 +966,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
arguments=tool_call_delta.get("arguments"),
|
||||
tool_call_id=tool_call_delta.get("id"),
|
||||
),
|
||||
name=name,
|
||||
)
|
||||
|
||||
elif choice.finish_reason is not None:
|
||||
@@ -1035,6 +1047,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
message_id: str,
|
||||
message_date: datetime,
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Process a streaming chunk from an OpenAI-compatible server.
|
||||
|
||||
@@ -1060,6 +1073,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
message_id=message_id,
|
||||
message_date=message_date,
|
||||
expect_reasoning_content=expect_reasoning_content,
|
||||
name=name,
|
||||
)
|
||||
|
||||
if processed_chunk is None:
|
||||
@@ -1087,6 +1101,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
id=msg_obj.id,
|
||||
date=msg_obj.created_at,
|
||||
reasoning=msg,
|
||||
name=msg_obj.name,
|
||||
)
|
||||
|
||||
self._push_to_buffer(processed_chunk)
|
||||
@@ -1097,6 +1112,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
id=msg_obj.id,
|
||||
date=msg_obj.created_at,
|
||||
reasoning=content.text,
|
||||
name=msg_obj.name,
|
||||
)
|
||||
elif isinstance(content, ReasoningContent):
|
||||
processed_chunk = ReasoningMessage(
|
||||
@@ -1105,6 +1121,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
source="reasoner_model",
|
||||
reasoning=content.reasoning,
|
||||
signature=content.signature,
|
||||
name=msg_obj.name,
|
||||
)
|
||||
elif isinstance(content, RedactedReasoningContent):
|
||||
processed_chunk = HiddenReasoningMessage(
|
||||
@@ -1112,6 +1129,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
date=msg_obj.created_at,
|
||||
state="redacted",
|
||||
hidden_reasoning=content.data,
|
||||
name=msg_obj.name,
|
||||
)
|
||||
|
||||
self._push_to_buffer(processed_chunk)
|
||||
@@ -1172,6 +1190,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
id=msg_obj.id,
|
||||
date=msg_obj.created_at,
|
||||
content=func_args["message"],
|
||||
name=msg_obj.name,
|
||||
)
|
||||
self._push_to_buffer(processed_chunk)
|
||||
except Exception as e:
|
||||
@@ -1194,6 +1213,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
id=msg_obj.id,
|
||||
date=msg_obj.created_at,
|
||||
content=func_args[self.assistant_message_tool_kwarg],
|
||||
name=msg_obj.name,
|
||||
)
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
self.prev_assistant_message_id = function_call.id
|
||||
@@ -1206,6 +1226,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
arguments=function_call.function.arguments,
|
||||
tool_call_id=function_call.id,
|
||||
),
|
||||
name=msg_obj.name,
|
||||
)
|
||||
|
||||
# processed_chunk = {
|
||||
@@ -1245,6 +1266,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=msg_obj.tool_call_id,
|
||||
stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None,
|
||||
stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None,
|
||||
name=msg_obj.name,
|
||||
)
|
||||
|
||||
elif msg.startswith("Error: "):
|
||||
@@ -1259,6 +1281,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=msg_obj.tool_call_id,
|
||||
stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None,
|
||||
stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None,
|
||||
name=msg_obj.name,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -527,6 +527,7 @@ def list_messages(
|
||||
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
|
||||
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
|
||||
limit: int = Query(10, description="Maximum number of messages to retrieve."),
|
||||
group_id: Optional[str] = Query(None, description="Group ID to filter messages by."),
|
||||
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
|
||||
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
|
||||
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
|
||||
@@ -543,6 +544,7 @@ def list_messages(
|
||||
after=after,
|
||||
before=before,
|
||||
limit=limit,
|
||||
group_id=group_id,
|
||||
reverse=True,
|
||||
return_message_object=False,
|
||||
use_assistant_message=use_assistant_message,
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, Query
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.group import Group, GroupCreate, ManagerType
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.group import Group, GroupCreate, GroupUpdate, ManagerType
|
||||
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
@@ -14,21 +16,6 @@ from letta.server.server import SyncServer
|
||||
router = APIRouter(prefix="/groups", tags=["groups"])
|
||||
|
||||
|
||||
@router.post("/", response_model=Group, operation_id="create_group")
async def create_group(
    server: SyncServer = Depends(get_letta_server),
    request: GroupCreate = Body(...),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Taken from the user_id header; None when the header is absent
):
    """
    Create a multi-agent group using the management pattern given in the request.
    If the request does not specify a management config, round robin is used for
    speaker selection.
    """
    # Resolve the acting user first; the header value may be None.
    acting_user = server.user_manager.get_user_or_default(user_id=actor_id)
    created_group = server.group_manager.create_group(request, actor=acting_user)
    return created_group
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Group], operation_id="list_groups")
|
||||
def list_groups(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@@ -53,6 +40,23 @@ def list_groups(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}", response_model=Group, operation_id="retrieve_group")
def retrieve_group(
    group_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Retrieve the group by id.

    Raises:
        HTTPException: 404 when the group manager raises NoResultFound for this id.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)

    try:
        return server.group_manager.retrieve_group(group_id=group_id, actor=actor)
    except NoResultFound as e:
        # Chain the lookup failure explicitly so the original cause stays attached to the 404.
        raise HTTPException(status_code=404, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/", response_model=Group, operation_id="create_group")
|
||||
def create_group(
|
||||
group: GroupCreate = Body(...),
|
||||
@@ -70,9 +74,10 @@ def create_group(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/", response_model=Group, operation_id="upsert_group")
|
||||
def upsert_group(
|
||||
group: GroupCreate = Body(...),
|
||||
@router.put("/{group_id}", response_model=Group, operation_id="modify_group")
|
||||
def modify_group(
|
||||
group_id: str,
|
||||
group: GroupUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
@@ -82,7 +87,7 @@ def upsert_group(
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.group_manager.create_group(group, actor=actor)
|
||||
return server.group_manager.modify_group(group_id=group_id, group_update=group, actor=actor)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -110,7 +115,7 @@ def delete_group(
|
||||
operation_id="send_group_message",
|
||||
)
|
||||
async def send_group_message(
|
||||
agent_id: str,
|
||||
group_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaRequest = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
@@ -178,6 +183,22 @@ GroupMessagesResponse = Annotated[
|
||||
]
|
||||
|
||||
|
||||
@router.patch("/{group_id}/messages/{message_id}", response_model=LettaMessageUnion, operation_id="modify_group_message")
def modify_group_message(
    group_id: str,
    message_id: str,
    request: LettaMessageUpdateUnion = Body(...),
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Taken from the user_id header; None when the header is absent
):
    """
    Update the details of a single message, addressed by its message id.
    """
    # TODO: support modifying tool calls/returns
    # NOTE(review): group_id comes from the path but is not passed on below, so this
    # handler does not itself check that the message belongs to the group - confirm intended.
    acting_user = server.user_manager.get_user_or_default(user_id=actor_id)
    updated_message = server.message_manager.update_message_by_letta_message(
        message_id=message_id,
        letta_message_update=request,
        actor=acting_user,
    )
    return updated_message
|
||||
|
||||
|
||||
@router.get("/{group_id}/messages", response_model=GroupMessagesResponse, operation_id="list_group_messages")
|
||||
def list_group_messages(
|
||||
group_id: str,
|
||||
@@ -194,40 +215,42 @@ def list_group_messages(
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.group_manager.list_group_messages(
|
||||
group_id=group_id,
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
actor=actor,
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
)
|
||||
group = server.group_manager.retrieve_group(group_id=group_id, actor=actor)
|
||||
if group.manager_agent_id:
|
||||
return server.get_agent_recall(
|
||||
user_id=actor.id,
|
||||
agent_id=group.manager_agent_id,
|
||||
after=after,
|
||||
before=before,
|
||||
limit=limit,
|
||||
group_id=group_id,
|
||||
reverse=True,
|
||||
return_message_object=False,
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
)
|
||||
else:
|
||||
return server.group_manager.list_group_messages(
|
||||
group_id=group_id,
|
||||
after=after,
|
||||
before=before,
|
||||
limit=limit,
|
||||
actor=actor,
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
)
|
||||
|
||||
|
||||
'''
|
||||
@router.patch("/{group_id}/reset-messages", response_model=None, operation_id="reset_group_messages")
|
||||
def reset_group_messages(
|
||||
group_id: str,
|
||||
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Resets the messages for all agents that are part of the multi-agent group.
|
||||
TODO: only delete group messages not all messages!
|
||||
Delete the group messages for all agents that are part of the multi-agent group.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
group = server.group_manager.retrieve_group(group_id=group_id, actor=actor)
|
||||
agent_ids = group.agent_ids
|
||||
if group.manager_agent_id:
|
||||
agent_ids.append(group.manager_agent_id)
|
||||
for agent_id in agent_ids:
|
||||
server.agent_manager.reset_messages(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
add_default_initial_messages=add_default_initial_messages,
|
||||
)
|
||||
'''
|
||||
server.group_manager.reset_messages(group_id=group_id, actor=actor)
|
||||
|
||||
@@ -211,7 +211,6 @@ def create_tool_call_messages_from_openai_response(
|
||||
tool_calls=[],
|
||||
tool_call_id=tool_call_id,
|
||||
created_at=get_utc_time(),
|
||||
name=function_name,
|
||||
)
|
||||
messages.append(tool_message)
|
||||
|
||||
|
||||
@@ -367,6 +367,9 @@ class SyncServer(Server):
|
||||
def load_multi_agent(
|
||||
self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None
|
||||
) -> Agent:
|
||||
if len(group.agent_ids) == 0:
|
||||
raise ValueError("Empty group: group must have at least one agent")
|
||||
|
||||
match group.manager_type:
|
||||
case ManagerType.round_robin:
|
||||
agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.agent_ids[0], actor=actor)
|
||||
@@ -862,6 +865,7 @@ class SyncServer(Server):
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
group_id: Optional[str] = None,
|
||||
reverse: Optional[bool] = False,
|
||||
return_message_object: bool = True,
|
||||
use_assistant_message: bool = True,
|
||||
@@ -879,6 +883,7 @@ class SyncServer(Server):
|
||||
before=before,
|
||||
limit=limit,
|
||||
ascending=not reverse,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
if not return_message_object:
|
||||
@@ -1591,88 +1596,76 @@ class SyncServer(Server):
|
||||
) -> Union[StreamingResponse, LettaResponse]:
|
||||
include_final_message = True
|
||||
if not stream_steps and stream_tokens:
|
||||
raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")
|
||||
raise ValueError("stream_steps must be 'true' if stream_tokens is 'true'")
|
||||
|
||||
try:
|
||||
# fetch the group
|
||||
group = self.group_manager.retrieve_group(group_id=group_id, actor=actor)
|
||||
letta_multi_agent = self.load_multi_agent(group=group, actor=actor)
|
||||
group = self.group_manager.retrieve_group(group_id=group_id, actor=actor)
|
||||
letta_multi_agent = self.load_multi_agent(group=group, actor=actor)
|
||||
|
||||
llm_config = letta_multi_agent.agent_state.llm_config
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek"]
|
||||
if stream_tokens and (
|
||||
llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint
|
||||
):
|
||||
warnings.warn(
|
||||
f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
|
||||
)
|
||||
stream_tokens = False
|
||||
llm_config = letta_multi_agent.agent_state.llm_config
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek"]
|
||||
if stream_tokens and (
|
||||
llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint
|
||||
):
|
||||
warnings.warn(
|
||||
f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
|
||||
)
|
||||
stream_tokens = False
|
||||
|
||||
# Create a new interface per request
|
||||
letta_multi_agent.interface = StreamingServerInterface(
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
inner_thoughts_in_kwargs=(
|
||||
llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
|
||||
# Create a new interface per request
|
||||
letta_multi_agent.interface = StreamingServerInterface(
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
inner_thoughts_in_kwargs=(
|
||||
llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
|
||||
),
|
||||
)
|
||||
streaming_interface = letta_multi_agent.interface
|
||||
if not isinstance(streaming_interface, StreamingServerInterface):
|
||||
raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
|
||||
streaming_interface.streaming_mode = stream_tokens
|
||||
streaming_interface.streaming_chat_completion_mode = chat_completion_mode
|
||||
if metadata and hasattr(streaming_interface, "metadata"):
|
||||
streaming_interface.metadata = metadata
|
||||
|
||||
streaming_interface.stream_start()
|
||||
task = asyncio.create_task(
|
||||
asyncio.to_thread(
|
||||
letta_multi_agent.step,
|
||||
messages=messages,
|
||||
chaining=self.chaining,
|
||||
max_chaining_steps=self.max_chaining_steps,
|
||||
)
|
||||
)
|
||||
|
||||
if stream_steps:
|
||||
# return a stream
|
||||
return StreamingResponse(
|
||||
sse_async_generator(
|
||||
streaming_interface.get_generator(),
|
||||
usage_task=task,
|
||||
finish_message=include_final_message,
|
||||
),
|
||||
)
|
||||
streaming_interface = letta_multi_agent.interface
|
||||
if not isinstance(streaming_interface, StreamingServerInterface):
|
||||
raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
|
||||
streaming_interface.streaming_mode = stream_tokens
|
||||
streaming_interface.streaming_chat_completion_mode = chat_completion_mode
|
||||
if metadata and hasattr(streaming_interface, "metadata"):
|
||||
streaming_interface.metadata = metadata
|
||||
|
||||
streaming_interface.stream_start()
|
||||
task = asyncio.create_task(
|
||||
asyncio.to_thread(
|
||||
letta_multi_agent.step,
|
||||
messages=messages,
|
||||
chaining=self.chaining,
|
||||
max_chaining_steps=self.max_chaining_steps,
|
||||
)
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
if stream_steps:
|
||||
# return a stream
|
||||
return StreamingResponse(
|
||||
sse_async_generator(
|
||||
streaming_interface.get_generator(),
|
||||
usage_task=task,
|
||||
finish_message=include_final_message,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
# buffer the stream, then return the list
|
||||
generated_stream = []
|
||||
async for message in streaming_interface.get_generator():
|
||||
assert (
|
||||
isinstance(message, LettaMessage) or isinstance(message, LegacyLettaMessage) or isinstance(message, MessageStreamStatus)
|
||||
), type(message)
|
||||
generated_stream.append(message)
|
||||
if message == MessageStreamStatus.done:
|
||||
break
|
||||
|
||||
else:
|
||||
# buffer the stream, then return the list
|
||||
generated_stream = []
|
||||
async for message in streaming_interface.get_generator():
|
||||
assert (
|
||||
isinstance(message, LettaMessage)
|
||||
or isinstance(message, LegacyLettaMessage)
|
||||
or isinstance(message, MessageStreamStatus)
|
||||
), type(message)
|
||||
generated_stream.append(message)
|
||||
if message == MessageStreamStatus.done:
|
||||
break
|
||||
# Get rid of the stream status messages
|
||||
filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
|
||||
usage = await task
|
||||
|
||||
# Get rid of the stream status messages
|
||||
filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
|
||||
usage = await task
|
||||
|
||||
# By default the stream will be messages of type LettaMessage or LettaLegacyMessage
|
||||
# If we want to convert these to Message, we can use the attached IDs
|
||||
# NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
|
||||
# TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
|
||||
return LettaResponse(messages=filtered_stream, usage=usage)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
# By default the stream will be messages of type LettaMessage or LettaLegacyMessage
|
||||
# If we want to convert these to Message, we can use the attached IDs
|
||||
# NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
|
||||
# TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
|
||||
return LettaResponse(messages=filtered_stream, usage=usage)
|
||||
|
||||
@@ -7,7 +7,9 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.orm.group import Group as GroupModel
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.schemas.group import Group as PydanticGroup
|
||||
from letta.schemas.group import GroupCreate, ManagerType
|
||||
from letta.schemas.group import GroupCreate, GroupUpdate, ManagerType
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
@@ -22,12 +24,12 @@ class GroupManager:
|
||||
@enforce_types
|
||||
def list_groups(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
project_id: Optional[str] = None,
|
||||
manager_type: Optional[ManagerType] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
actor: PydanticUser = None,
|
||||
) -> list[PydanticGroup]:
|
||||
with self.session_maker() as session:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
@@ -56,27 +58,66 @@ class GroupManager:
|
||||
new_group = GroupModel()
|
||||
new_group.organization_id = actor.organization_id
|
||||
new_group.description = group.description
|
||||
|
||||
match group.manager_config.manager_type:
|
||||
case ManagerType.round_robin:
|
||||
new_group.manager_type = ManagerType.round_robin
|
||||
new_group.max_turns = group.manager_config.max_turns
|
||||
case ManagerType.dynamic:
|
||||
new_group.manager_type = ManagerType.dynamic
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
new_group.max_turns = group.manager_config.max_turns
|
||||
new_group.termination_token = group.manager_config.termination_token
|
||||
case ManagerType.supervisor:
|
||||
new_group.manager_type = ManagerType.supervisor
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
case _:
|
||||
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
|
||||
|
||||
self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False)
|
||||
if group.manager_config is None:
|
||||
new_group.manager_type = ManagerType.round_robin
|
||||
else:
|
||||
match group.manager_config.manager_type:
|
||||
case ManagerType.round_robin:
|
||||
new_group.manager_type = ManagerType.round_robin
|
||||
new_group.max_turns = group.manager_config.max_turns
|
||||
case ManagerType.dynamic:
|
||||
new_group.manager_type = ManagerType.dynamic
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
new_group.max_turns = group.manager_config.max_turns
|
||||
new_group.termination_token = group.manager_config.termination_token
|
||||
case ManagerType.supervisor:
|
||||
new_group.manager_type = ManagerType.supervisor
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
case _:
|
||||
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
|
||||
|
||||
new_group.create(session, actor=actor)
|
||||
return new_group.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def modify_group(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
|
||||
with self.session_maker() as session:
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
|
||||
max_turns = None
|
||||
termination_token = None
|
||||
manager_agent_id = None
|
||||
if group_update.manager_config:
|
||||
if group_update.manager_config.manager_type != group.manager_type:
|
||||
raise ValueError(f"Cannot change group pattern after creation")
|
||||
match group_update.manager_config.manager_type:
|
||||
case ManagerType.round_robin:
|
||||
max_turns = group_update.manager_config.max_turns
|
||||
case ManagerType.dynamic:
|
||||
manager_agent_id = group_update.manager_config.manager_agent_id
|
||||
max_turns = group_update.manager_config.max_turns
|
||||
termination_token = group_update.manager_config.termination_token
|
||||
case ManagerType.supervisor:
|
||||
manager_agent_id = group_update.manager_config.manager_agent_id
|
||||
case _:
|
||||
raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}")
|
||||
|
||||
if max_turns:
|
||||
group.max_turns = max_turns
|
||||
if termination_token:
|
||||
group.termination_token = termination_token
|
||||
if manager_agent_id:
|
||||
group.manager_agent_id = manager_agent_id
|
||||
if group_update.description:
|
||||
group.description = group_update.description
|
||||
if group_update.agent_ids:
|
||||
self._process_agent_relationship(
|
||||
session=session, group=group, agent_ids=group_update.agent_ids, allow_partial=False, replace=True
|
||||
)
|
||||
|
||||
group.update(session, actor=actor)
|
||||
return group.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_group(self, group_id: str, actor: PydanticUser) -> None:
|
||||
with self.session_maker() as session:
|
||||
@@ -87,23 +128,19 @@ class GroupManager:
|
||||
@enforce_types
|
||||
def list_group_messages(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
group_id: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
actor: PydanticUser = None,
|
||||
use_assistant_message: bool = True,
|
||||
assistant_message_tool_name: str = "send_message",
|
||||
assistant_message_tool_kwarg: str = "message",
|
||||
) -> list[PydanticGroup]:
|
||||
) -> list[LettaMessage]:
|
||||
with self.session_maker() as session:
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
agent_id = group.manager_agent_id if group.manager_agent_id else group.agent_ids[0]
|
||||
|
||||
filters = {
|
||||
"organization_id": actor.organization_id,
|
||||
"group_id": group_id,
|
||||
"agent_id": agent_id,
|
||||
}
|
||||
messages = MessageModel.list(
|
||||
db_session=session,
|
||||
@@ -114,21 +151,39 @@ class GroupManager:
|
||||
)
|
||||
|
||||
messages = PydanticMessage.to_letta_messages_from_list(
|
||||
messages=messages,
|
||||
messages=[msg.to_pydantic() for msg in messages],
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
)
|
||||
|
||||
# TODO: filter messages to return a clean conversation history
|
||||
|
||||
return messages
|
||||
|
||||
@enforce_types
def reset_messages(self, group_id: str, actor: PydanticUser) -> None:
    """
    Delete all messages tagged with this group id within the actor's organization.
    """
    with self.session_maker() as session:
        # Ensure group is loadable by user; the result itself is not needed.
        GroupModel.read(db_session=session, identifier=group_id, actor=actor)

        # Delete all messages in the group, scoped to the actor's organization.
        in_actor_org = MessageModel.organization_id == actor.organization_id
        in_group = MessageModel.group_id == group_id
        group_messages = session.query(MessageModel).filter(in_actor_org, in_group)
        group_messages.delete(synchronize_session=False)

        session.commit()
|
||||
|
||||
def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):
|
||||
current_relationship = getattr(group, "agents", [])
|
||||
if not agent_ids:
|
||||
if replace:
|
||||
setattr(group, "agents", [])
|
||||
setattr(group, "agent_ids", [])
|
||||
return
|
||||
|
||||
if group.manager_type == ManagerType.dynamic and len(agent_ids) != len(set(agent_ids)):
|
||||
raise ValueError("Duplicate agent ids found in list")
|
||||
|
||||
# Retrieve models for the provided IDs
|
||||
found_items = session.query(AgentModel).filter(AgentModel.id.in_(agent_ids)).all()
|
||||
|
||||
@@ -137,11 +192,14 @@ class GroupManager:
|
||||
missing = set(agent_ids) - {item.id for item in found_items}
|
||||
raise NoResultFound(f"Items not found in agents: {missing}")
|
||||
|
||||
if group.manager_type == ManagerType.dynamic:
|
||||
names = [item.name for item in found_items]
|
||||
if len(names) != len(set(names)):
|
||||
raise ValueError("Duplicate agent names found in the provided agent IDs.")
|
||||
|
||||
if replace:
|
||||
# Replace the relationship
|
||||
setattr(group, "agents", found_items)
|
||||
setattr(group, "agent_ids", agent_ids)
|
||||
else:
|
||||
# Extend the relationship (only add new items)
|
||||
current_ids = {item.id for item in current_relationship}
|
||||
new_items = [item for item in found_items if item.id not in current_ids]
|
||||
current_relationship.extend(new_items)
|
||||
raise ValueError("Extend relationship is not supported for groups.")
|
||||
|
||||
@@ -264,6 +264,7 @@ class MessageManager:
|
||||
roles: Optional[Sequence[MessageRole]] = None,
|
||||
limit: Optional[int] = 50,
|
||||
ascending: bool = True,
|
||||
group_id: Optional[str] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Most performant query to list messages for an agent by directly querying the Message table.
|
||||
@@ -296,6 +297,10 @@ class MessageManager:
|
||||
# Build a query that directly filters the Message table by agent_id.
|
||||
query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id)
|
||||
|
||||
# If group_id is provided, filter messages by group_id.
|
||||
if group_id:
|
||||
query = query.filter(MessageModel.group_id == group_id)
|
||||
|
||||
# If query_text is provided, filter messages using subquery.
|
||||
if query_text:
|
||||
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
|
||||
|
||||
@@ -22,7 +22,7 @@ class SupervisorMultiAgent(Agent):
|
||||
self,
|
||||
interface: AgentInterface,
|
||||
agent_state: AgentState,
|
||||
user: User = None,
|
||||
user: User,
|
||||
# custom
|
||||
group_id: str = "",
|
||||
agent_ids: List[str] = [],
|
||||
@@ -65,6 +65,7 @@ class SupervisorMultiAgent(Agent):
|
||||
self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user)
|
||||
|
||||
# override tool rules
|
||||
old_tool_rules = self.agent_state.tool_rules
|
||||
self.agent_state.tool_rules = [
|
||||
InitToolRule(
|
||||
tool_name="send_message_to_all_agents_in_group",
|
||||
@@ -106,6 +107,7 @@ class SupervisorMultiAgent(Agent):
|
||||
raise e
|
||||
finally:
|
||||
self.interface.step_yield()
|
||||
self.agent_state.tool_rules = old_tool_rules
|
||||
|
||||
self.interface.step_complete()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from letta.config import LettaConfig
|
||||
from letta.orm import Provider, Step
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.group import DynamicManager, GroupCreate, SupervisorManager
|
||||
from letta.schemas.group import DynamicManager, GroupCreate, GroupUpdate, ManagerType, RoundRobinManager, SupervisorManager
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
@@ -45,7 +45,7 @@ def actor(server, org_id):
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def participant_agent_ids(server, actor):
|
||||
def participant_agents(server, actor):
|
||||
agent_fred = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="fred",
|
||||
@@ -102,7 +102,7 @@ def participant_agent_ids(server, actor):
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
yield [agent_fred.id, agent_velma.id, agent_daphne.id, agent_shaggy.id]
|
||||
yield [agent_fred, agent_velma, agent_daphne, agent_shaggy]
|
||||
|
||||
# cleanup
|
||||
server.agent_manager.delete_agent(agent_fred.id, actor=actor)
|
||||
@@ -112,7 +112,7 @@ def participant_agent_ids(server, actor):
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def manager_agent_id(server, actor):
|
||||
def manager_agent(server, actor):
|
||||
agent_scooby = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="scooby",
|
||||
@@ -131,22 +131,84 @@ def manager_agent_id(server, actor):
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
yield agent_scooby.id
|
||||
yield agent_scooby
|
||||
|
||||
# cleanup
|
||||
server.agent_manager.delete_agent(agent_scooby.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_empty_group(server, actor):
    """Sending a message to a group created with no agents must raise ValueError("Empty group...")."""
    group = server.group_manager.create_group(
        group=GroupCreate(
            description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
            agent_ids=[],
        ),
        actor=actor,
    )
    try:
        with pytest.raises(ValueError, match="Empty group"):
            await server.send_group_message_to_agent(
                group_id=group.id,
                actor=actor,
                messages=[
                    MessageCreate(
                        role="user",
                        content="what is everyone up to for the holidays?",
                    ),
                ],
                stream_steps=False,
                stream_tokens=False,
            )
    finally:
        # Always delete the group, even when the expectation above fails,
        # so a failing run does not leave the group behind.
        server.group_manager.delete_group(group_id=group.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_modify_group_pattern(server, actor, participant_agents, manager_agent):
    """Changing a group's manager type after creation must raise ValueError."""
    # Created without a manager_config; the update below then asks for a dynamic manager.
    group = server.group_manager.create_group(
        group=GroupCreate(
            description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
            agent_ids=[agent.id for agent in participant_agents],
        ),
        actor=actor,
    )
    try:
        with pytest.raises(ValueError, match="Cannot change group pattern"):
            server.group_manager.modify_group(
                group_id=group.id,
                group_update=GroupUpdate(
                    manager_config=DynamicManager(
                        manager_type=ManagerType.dynamic,
                        manager_agent_id=manager_agent.id,
                    ),
                ),
                actor=actor,
            )
    finally:
        # Always delete the group, even when the expectation above fails,
        # so a failing run does not leave the group behind.
        server.group_manager.delete_group(group_id=group.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin(server, actor, participant_agents):
|
||||
description = (
|
||||
"This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries."
|
||||
)
|
||||
group = server.group_manager.create_group(
|
||||
group=GroupCreate(
|
||||
description=description,
|
||||
agent_ids=[agent.id for agent in participant_agents],
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# verify group creation
|
||||
assert group.manager_type == ManagerType.round_robin
|
||||
assert group.description == description
|
||||
assert group.agent_ids == [agent.id for agent in participant_agents]
|
||||
assert group.max_turns == None
|
||||
assert group.manager_agent_id is None
|
||||
assert group.termination_token is None
|
||||
|
||||
try:
|
||||
server.group_manager.reset_messages(group_id=group.id, actor=actor)
|
||||
response = await server.send_group_message_to_agent(
|
||||
group_id=group.id,
|
||||
actor=actor,
|
||||
@@ -159,15 +221,85 @@ async def test_round_robin(server, actor, participant_agent_ids):
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
)
|
||||
assert response.usage.step_count == len(participant_agent_ids)
|
||||
assert response.usage.step_count == len(group.agent_ids)
|
||||
assert len(response.messages) == response.usage.step_count * 2
|
||||
for i, message in enumerate(response.messages):
|
||||
assert message.message_type == "reasoning_message" if i % 2 == 0 else "assistant_message"
|
||||
assert message.name == participant_agents[i // 2].name
|
||||
|
||||
for agent_id in group.agent_ids:
|
||||
agent_messages = server.get_agent_recall(
|
||||
user_id=actor.id,
|
||||
agent_id=agent_id,
|
||||
group_id=group.id,
|
||||
reverse=True,
|
||||
return_message_object=False,
|
||||
)
|
||||
assert len(agent_messages) == len(group.agent_ids) + 2 # add one for user message, one for reasoning message
|
||||
|
||||
# TODO: filter this to return a clean conversation history
|
||||
messages = server.group_manager.list_group_messages(
|
||||
group_id=group.id,
|
||||
actor=actor,
|
||||
)
|
||||
assert len(messages) == (len(group.agent_ids) + 2) * len(group.agent_ids)
|
||||
|
||||
max_turns = 3
|
||||
group = server.group_manager.modify_group(
|
||||
group_id=group.id,
|
||||
group_update=GroupUpdate(
|
||||
agent_ids=[agent.id for agent in participant_agents][::-1],
|
||||
manager_config=RoundRobinManager(
|
||||
max_turns=max_turns,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
assert group.manager_type == ManagerType.round_robin
|
||||
assert group.description == description
|
||||
assert group.agent_ids == [agent.id for agent in participant_agents][::-1]
|
||||
assert group.max_turns == max_turns
|
||||
assert group.manager_agent_id is None
|
||||
assert group.termination_token is None
|
||||
|
||||
server.group_manager.reset_messages(group_id=group.id, actor=actor)
|
||||
|
||||
response = await server.send_group_message_to_agent(
|
||||
group_id=group.id,
|
||||
actor=actor,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="what is everyone up to for the holidays?",
|
||||
),
|
||||
],
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
)
|
||||
assert response.usage.step_count == max_turns
|
||||
assert len(response.messages) == max_turns * 2
|
||||
|
||||
for i, message in enumerate(response.messages):
|
||||
assert message.message_type == "reasoning_message" if i % 2 == 0 else "assistant_message"
|
||||
assert message.name == participant_agents[::-1][i // 2].name
|
||||
|
||||
for i in range(len(group.agent_ids)):
|
||||
agent_messages = server.get_agent_recall(
|
||||
user_id=actor.id,
|
||||
agent_id=group.agent_ids[i],
|
||||
group_id=group.id,
|
||||
reverse=True,
|
||||
return_message_object=False,
|
||||
)
|
||||
expected_message_count = max_turns + 1 if i >= max_turns else max_turns + 2
|
||||
assert len(agent_messages) == expected_message_count
|
||||
|
||||
finally:
|
||||
server.group_manager.delete_group(group_id=group.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_supervisor(server, actor, participant_agent_ids):
|
||||
async def test_supervisor(server, actor, participant_agents):
|
||||
agent_scrappy = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="shaggy",
|
||||
@@ -186,10 +318,11 @@ async def test_supervisor(server, actor, participant_agent_ids):
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
group = server.group_manager.create_group(
|
||||
group=GroupCreate(
|
||||
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
|
||||
agent_ids=participant_agent_ids,
|
||||
agent_ids=[agent.id for agent in participant_agents],
|
||||
manager_config=SupervisorManager(
|
||||
manager_agent_id=agent_scrappy.id,
|
||||
),
|
||||
@@ -219,7 +352,7 @@ async def test_supervisor(server, actor, participant_agent_ids):
|
||||
and response.messages[1].tool_call.name == "send_message_to_all_agents_in_group"
|
||||
)
|
||||
assert response.messages[2].message_type == "tool_return_message" and len(eval(response.messages[2].tool_return)) == len(
|
||||
participant_agent_ids
|
||||
participant_agents
|
||||
)
|
||||
assert response.messages[3].message_type == "reasoning_message"
|
||||
assert response.messages[4].message_type == "assistant_message"
|
||||
@@ -230,13 +363,50 @@ async def test_supervisor(server, actor, participant_agent_ids):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_group_chat(server, actor, manager_agent_id, participant_agent_ids):
|
||||
async def test_dynamic_group_chat(server, actor, manager_agent, participant_agents):
|
||||
description = (
|
||||
"This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries."
|
||||
)
|
||||
# error on duplicate agent in participant list
|
||||
with pytest.raises(ValueError, match="Duplicate agent ids"):
|
||||
server.group_manager.create_group(
|
||||
group=GroupCreate(
|
||||
description=description,
|
||||
agent_ids=[agent.id for agent in participant_agents] + [participant_agents[0].id],
|
||||
manager_config=DynamicManager(
|
||||
manager_agent_id=manager_agent.id,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
# error on duplicate agent names
|
||||
duplicate_agent_shaggy = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="shaggy",
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
with pytest.raises(ValueError, match="Duplicate agent names"):
|
||||
server.group_manager.create_group(
|
||||
group=GroupCreate(
|
||||
description=description,
|
||||
agent_ids=[agent.id for agent in participant_agents] + [duplicate_agent_shaggy.id],
|
||||
manager_config=DynamicManager(
|
||||
manager_agent_id=manager_agent.id,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
server.agent_manager.delete_agent(duplicate_agent_shaggy.id, actor=actor)
|
||||
|
||||
group = server.group_manager.create_group(
|
||||
group=GroupCreate(
|
||||
description="This is a group chat between best friends all like to hang out together. In their free time they like to solve mysteries.",
|
||||
agent_ids=participant_agent_ids,
|
||||
description=description,
|
||||
agent_ids=[agent.id for agent in participant_agents],
|
||||
manager_config=DynamicManager(
|
||||
manager_agent_id=manager_agent_id,
|
||||
manager_agent_id=manager_agent.id,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
@@ -251,7 +421,7 @@ async def test_dynamic_group_chat(server, actor, manager_agent_id, participant_a
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
)
|
||||
assert response.usage.step_count == len(participant_agent_ids) * 2
|
||||
assert response.usage.step_count == len(participant_agents) * 2
|
||||
assert len(response.messages) == response.usage.step_count * 2
|
||||
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user