feat: add otid to new agent loop (#1635)

This commit is contained in:
cthomas
2025-04-09 16:50:41 -07:00
committed by GitHub
parent 74e299a05f
commit 29fcccb3a4
10 changed files with 135 additions and 97 deletions

View File

@@ -1,11 +1,13 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Optional, Union
from typing import Any, AsyncGenerator, List, Optional, Union
import openai
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, UserMessage
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import MessageCreate
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.message_manager import MessageManager
@@ -33,7 +35,7 @@ class BaseAgent(ABC):
self.actor = actor
@abstractmethod
async def step(self, input_message: UserMessage, max_steps: int = 10) -> LettaResponse:
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
"""
Main execution loop for the agent.
"""
@@ -41,15 +43,24 @@ class BaseAgent(ABC):
@abstractmethod
async def step_stream(
self, input_message: UserMessage, max_steps: int = 10
self, input_messages: List[MessageCreate], max_steps: int = 10
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
"""
Main streaming execution loop for the agent.
"""
raise NotImplementedError
def pre_process_input_message(self, input_message: UserMessage) -> Any:
def pre_process_input_message(self, input_messages: List[MessageCreate]) -> Any:
"""
Pre-process function to run on the input messages.
"""
return input_message.model_dump()
def get_content(message: MessageCreate) -> str:
if isinstance(message.content, str):
return message.content
elif message.content and len(message.content) == 1 and isinstance(message.content[0], TextContent):
return message.content[0].text
else:
return ""
return [{"role": input_message.role, "content": get_content(input_message)} for input_message in input_messages]

View File

@@ -5,9 +5,8 @@ import openai
from letta.agents.base_agent import BaseAgent
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import UserMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
@@ -37,15 +36,15 @@ class EphemeralAgent(BaseAgent):
actor=actor,
)
async def step(self, input_message: UserMessage) -> List[Message]:
async def step(self, input_messages: List[MessageCreate]) -> List[Message]:
"""
Synchronous method that takes a user's input text and returns a summary from OpenAI.
Returns a list of ephemeral Message objects containing both the user text and the assistant summary.
"""
agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor)
input_message = self.pre_process_input_message(input_message=input_message)
request = self._build_openai_request([input_message], agent_state)
openai_messages = self.pre_process_input_message(input_messages=input_messages)
request = self._build_openai_request(openai_messages, agent_state)
chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
@@ -66,7 +65,7 @@ class EphemeralAgent(BaseAgent):
)
return openai_request
async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]:
async def step_stream(self, input_messages: List[MessageCreate]) -> AsyncGenerator[str, None]:
"""
This agent is synchronous-only. If called in an async context, raise an error.
"""

View File

@@ -7,9 +7,8 @@ from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import UserMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
@@ -38,15 +37,15 @@ class EphemeralMemoryAgent(BaseAgent):
actor=actor,
)
async def step(self, input_message: UserMessage) -> List[Message]:
async def step(self, input_messages: List[MessageCreate]) -> List[Message]:
"""
Synchronous method that takes a user's input text and returns a summary from OpenAI.
Returns a list of ephemeral Message objects containing both the user text and the assistant summary.
"""
agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor)
input_message = self.pre_process_input_message(input_message=input_message)
request = self._build_openai_request([input_message], agent_state)
openai_messages = self.pre_process_input_message(input_messages=input_messages)
request = self._build_openai_request(openai_messages, agent_state)
chat_completion = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True))
@@ -57,7 +56,8 @@ class EphemeralMemoryAgent(BaseAgent):
)
]
def pre_process_input_message(self, input_message: UserMessage) -> Dict:
def pre_process_input_message(self, input_messages: List[MessageCreate]) -> List[Dict]:
input_message = input_messages[0]
input_prompt_augmented = f"""
You are a memory recall agent whose job is to comb through a large set of messages and write relevant memories in relation to a user query.
Your response will directly populate a "memory block" called "human" that describes the user, that will be used to answer more questions in the future.
@@ -78,9 +78,7 @@ class EphemeralMemoryAgent(BaseAgent):
Your response:
"""
input_message.content = input_prompt_augmented
# print(input_prompt_augmented)
return input_message.model_dump()
return [{"role": "user", "content": input_prompt_augmented}]
def _format_messages_llm_friendly(self):
messages = self.message_manager.list_messages_for_agent(agent_id=self.agent_id, actor=self.actor)
@@ -107,7 +105,7 @@ class EphemeralMemoryAgent(BaseAgent):
return [Tool(type="function", function=enable_strict_mode(t.json_schema)) for t in tools]
async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]:
async def step_stream(self, input_messages: List[MessageCreate]) -> AsyncGenerator[str, None]:
"""
This agent is synchronous-only. If called in an async context, raise an error.
"""

View File

@@ -1,11 +1,11 @@
from typing import Dict, List, Tuple
from typing import List, Tuple
from letta.schemas.agent import AgentState
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message
from letta.schemas.message import Message, MessageCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_user_message
from letta.server.rest_api.utils import create_input_messages
from letta.services.message_manager import MessageManager
@@ -20,13 +20,13 @@ def _create_letta_response(new_in_context_messages: list[Message], use_assistant
def _prepare_in_context_messages(
input_message: Dict, agent_state: AgentState, message_manager: MessageManager, actor: User
input_messages: List[MessageCreate], agent_state: AgentState, message_manager: MessageManager, actor: User
) -> Tuple[List[Message], List[Message]]:
"""
Prepares in-context messages for an agent, based on the current state and a new user input.
Args:
input_message (Dict): The new user input message to process.
input_messages (List[MessageCreate]): The new user input messages to process.
agent_state (AgentState): The current state of the agent, including message buffer config.
message_manager (MessageManager): The manager used to retrieve and create messages.
actor (User): The user performing the action, used for access control and attribution.
@@ -46,7 +46,7 @@ def _prepare_in_context_messages(
# Create a new user message from the input and store it
new_in_context_messages = message_manager.create_many_messages(
[create_user_message(input_message=input_message, agent_id=agent_state.id, actor=actor)], actor=actor
create_input_messages(input_messages=input_messages, agent_id=agent_state.id, actor=actor), actor=actor
)
return current_in_context_messages, new_in_context_messages

View File

@@ -18,12 +18,11 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.letta_message import AssistantMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageUpdate
from letta.schemas.openai.chat_completion_request import UserMessage
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_response import ToolCall
from letta.schemas.user import User
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
@@ -60,11 +59,10 @@ class LettaAgent(BaseAgent):
self.use_assistant_message = use_assistant_message
@trace_method
async def step(self, input_message: UserMessage, max_steps: int = 10) -> LettaResponse:
input_message = self.pre_process_input_message(input_message)
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
input_message, agent_state, self.message_manager, self.actor
input_messages, agent_state, self.message_manager, self.actor
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
@@ -96,16 +94,15 @@ class LettaAgent(BaseAgent):
@trace_method
async def step_stream(
self, input_message: UserMessage, max_steps: int = 10, use_assistant_message: bool = False
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = False
) -> AsyncGenerator[str, None]:
"""
Main streaming loop that yields partial tokens.
Whenever we detect a tool call, we yield from _handle_ai_response as well.
"""
input_message = self.pre_process_input_message(input_message)
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
input_message, agent_state, self.message_manager, self.actor
input_messages, agent_state, self.message_manager, self.actor
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
@@ -362,7 +359,9 @@ class LettaAgent(BaseAgent):
f"{message}"
)
letta_response = await letta_agent.step(UserMessage(content=augmented_message))
letta_response = await letta_agent.step(
[MessageCreate(role=MessageRole.system, content=[TextContent(text=augmented_message)])]
)
messages = letta_response.messages
send_message_content = [message.content for message in messages if isinstance(message, AssistantMessage)]

View File

@@ -19,8 +19,9 @@ from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.block import BlockUpdate
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageUpdate
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_request import (
AssistantMessage,
ChatCompletionRequest,
@@ -34,8 +35,8 @@ from letta.schemas.user import User
from letta.server.rest_api.utils import (
convert_letta_messages_to_openai,
create_assistant_messages_from_openai_response,
create_input_messages,
create_letta_messages_from_llm_response,
create_user_message,
)
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
@@ -93,19 +94,18 @@ class VoiceAgent(BaseAgent):
agent_id=agent_id, openai_client=openai_client, message_manager=message_manager, agent_manager=agent_manager, actor=actor
)
async def step(self, input_message: UserMessage, max_steps: int = 10) -> LettaResponse:
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
raise NotImplementedError("LowLatencyAgent does not have a synchronous step implemented currently.")
async def step_stream(self, input_message: UserMessage, max_steps: int = 10) -> AsyncGenerator[str, None]:
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
"""
Main streaming loop that yields partial tokens.
Whenever we detect a tool call, we yield from _handle_ai_response as well.
"""
input_message = self.pre_process_input_message(input_message)
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
letta_message_db_queue = [create_user_message(input_message=input_message, agent_id=agent_state.id, actor=self.actor)]
in_memory_message_history = [input_message]
letta_message_db_queue = [create_input_messages(input_messages=input_messages, agent_id=agent_state.id, actor=self.actor)]
in_memory_message_history = self.pre_process_input_message(input_messages)
# TODO: Define max steps here
for _ in range(max_steps):
@@ -372,7 +372,7 @@ class VoiceAgent(BaseAgent):
return f"Failed to call tool. Error: {e}", False
async def _recall_memory(self, query, agent_state: AgentState) -> None:
results = await self.offline_memory_agent.step(UserMessage(content=query))
results = await self.offline_memory_agent.step([MessageCreate(role="user", content=[TextContent(text=query)])])
target_block = next(b for b in agent_state.memory.blocks if b.label == self.summary_block_label)
self.block_manager.update_block(
block_id=target_block.id, block_update=BlockUpdate(value=results[0].content[0].text), actor=self.actor

View File

@@ -22,7 +22,6 @@ from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.memory import ContextWindowOverview, CreateArchivalMemory, Memory
from letta.schemas.message import MessageCreate
from letta.schemas.openai.chat_completion_request import UserMessage
from letta.schemas.passage import Passage, PassageUpdate
from letta.schemas.run import Run
from letta.schemas.source import Source
@@ -610,9 +609,7 @@ async def send_message(
actor=actor,
)
messages = request.messages
content = messages[0].content[0].text if messages and not isinstance(messages[0].content, str) else messages[0].content
result = await experimental_agent.step(UserMessage(content=content), max_steps=10)
result = await experimental_agent.step(request.messages, max_steps=10)
else:
result = await server.send_message_to_agent(
agent_id=agent_id,
@@ -672,10 +669,8 @@ async def send_message_streaming(
actor=actor,
)
messages = request.messages
content = messages[0].content[0].text if messages and not isinstance(messages[0].content, str) else messages[0].content
result = StreamingResponse(
experimental_agent.step_stream(UserMessage(content=content), max_steps=10, use_assistant_message=request.use_assistant_message),
experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message),
media_type="text/event-stream",
)
else:

View File

@@ -19,7 +19,7 @@ from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.message import Message
from letta.schemas.message import Message, MessageCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.interface import StreamingServerInterface
@@ -140,31 +140,29 @@ def log_error_to_sentry(e):
sentry_sdk.capture_exception(e)
def create_user_message(input_message: dict, agent_id: str, actor: User) -> Message:
def create_input_messages(input_messages: List[MessageCreate], agent_id: str, actor: User) -> List[Message]:
"""
Converts user input messages into the internal structured format.
"""
# Generate timestamp in the correct format
# Skip pytz for performance reasons
now = get_utc_time().isoformat()
new_messages = []
for input_message in input_messages:
# Construct the Message object
new_message = Message(
id=f"message-{uuid.uuid4()}",
role=input_message.role,
content=input_message.content,
name=input_message.name,
otid=input_message.otid,
organization_id=actor.organization_id,
agent_id=agent_id,
model=None,
tool_calls=None,
tool_call_id=None,
created_at=get_utc_time(),
)
new_messages.append(new_message)
# Format message as structured JSON
structured_message = {"type": "user_message", "message": input_message["content"], "time": now}
# Construct the Message object
user_message = Message(
id=f"message-{uuid.uuid4()}",
role=MessageRole.user,
content=[TextContent(text=json.dumps(structured_message, indent=2))], # Store structured JSON
organization_id=actor.organization_id,
agent_id=agent_id,
model=None,
tool_calls=None,
tool_call_id=None,
created_at=get_utc_time(),
)
return user_message
return new_messages
def create_letta_messages_from_llm_response(

View File

@@ -4,8 +4,8 @@ from typing import List, Tuple
from letta.agents.base_agent import BaseAgent
from letta.schemas.enums import MessageRole
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import UserMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.services.summarizer.enums import SummarizationMode
@@ -95,8 +95,15 @@ class Summarizer:
"It should be in note-taking format in natural English. You are to return the new, updated memory only."
)
messages = await self.summarizer_agent.step(UserMessage(content=summary_request_text))
current_summary = "\n".join([m.content[0].text for m in messages])
response = await self.summarizer_agent.step(
input_messages=[
MessageCreate(
role=MessageRole.user,
content=[TextContent(text=summary_request_text)],
),
],
)
current_summary = "\n".join([m.content[0].text for m in response.messages if m.message_type == "assistant_message"])
current_summary = f"{self.summary_prefix}{current_summary}"
return updated_in_context_messages, current_summary, True

View File

@@ -7,14 +7,15 @@ import httpx
import openai
import pytest
from dotenv import load_dotenv
from letta_client import CreateBlock, Letta
from letta_client import CreateBlock, Letta, MessageCreate, TextContent
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.agents.letta_agent import LettaAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message_content import TextContent as LettaTextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_request import UserMessage
from letta.schemas.message import MessageCreate as LettaMessageCreate
from letta.schemas.tool import ToolCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager
@@ -248,7 +249,7 @@ async def test_new_agent_loop(disable_e2b_api_key, openai_client, agent_state, m
actor=actor,
)
response = await agent.step(UserMessage(content=message))
response = await agent.step([LettaMessageCreate(role="user", content=[LettaTextContent(text=message)])])
@pytest.mark.asyncio
@@ -265,7 +266,7 @@ async def test_rethink_tool(disable_e2b_api_key, openai_client, agent_state, mes
)
assert "chicken" not in AgentManager().get_agent_by_id(agent_state.id, actor).memory.get_block("human").value
response = await agent.step(UserMessage(content=message))
response = await agent.step([LettaMessageCreate(role="user", content=[LettaTextContent(text=message)])])
assert "chicken" in AgentManager().get_agent_by_id(agent_state.id, actor).memory.get_block("human").value
@@ -275,9 +276,16 @@ async def test_multi_agent_broadcast(disable_e2b_api_key, client, openai_client,
stale_agents = AgentManager().list_agents(actor=actor, limit=300)
for agent in stale_agents:
client.delete_agent(agent_id=agent.id)
AgentManager().delete_agent(agent_id=agent.id, actor=actor)
manager_agent_state = client.create_agent(name=f"manager", include_base_tools=True, include_multi_agent_tools=True, tags=["manager"])
manager_agent_state = client.agents.create(
name=f"manager",
include_base_tools=True,
include_multi_agent_tools=True,
tags=["manager"],
model="openai/gpt-4o",
embedding="letta/letta-free",
)
manager_agent = LettaAgent(
agent_id=manager_agent_state.id,
message_manager=MessageManager(),
@@ -290,12 +298,31 @@ async def test_multi_agent_broadcast(disable_e2b_api_key, client, openai_client,
tag = "subagent"
workers = []
for idx in range(30):
workers.append(client.create_agent(name=f"worker_{idx}", include_base_tools=True, tags=[tag], tool_ids=[weather_tool.id]))
workers.append(
client.agents.create(
name=f"worker_{idx}",
include_base_tools=True,
tags=[tag],
tool_ids=[weather_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
),
)
response = await manager_agent.step(
UserMessage(
content="Use the `send_message_to_agents_matching_tags` tool to send a message to agents with tag 'subagent' asking them to check the weather in Seattle.",
)
[
LettaMessageCreate(
role="user",
content=[
LettaTextContent(
text=(
"Use the `send_message_to_agents_matching_tags` tool to send a message to agents with "
"tag 'subagent' asking them to check the weather in Seattle."
)
),
],
),
]
)
@@ -334,10 +361,14 @@ def test_multi_agent_broadcast_client(client: Letta, weather_tool):
response = client.agents.messages.create(
agent_id=supervisor.id,
messages=[
{
"role": "user",
"content": "Use the `send_message_to_agents_matching_tags` tool to send a message to agents with tag 'worker' asking them to check the weather in Seattle.",
}
MessageCreate(
role="user",
content=[
TextContent(
text="Use the `send_message_to_agents_matching_tags` tool to send a message to agents with tag 'worker' asking them to check the weather in Seattle."
)
],
)
],
)
end = time.perf_counter()
@@ -456,10 +487,10 @@ def test_anthropic_streaming(client: Letta):
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[
{
"role": "user",
"content": "Use core memory append to append `banana` to the persona core memory.",
}
MessageCreate(
role="user",
content=[TextContent(text="Use the core memory append tool to append `banana` to the persona core memory.")],
),
],
stream_tokens=True,
)