feat: Adjust ephemeral memory agent to become persisted sleeptime agent (#1943)
This commit is contained in:
@@ -86,7 +86,7 @@ class Swarm:
|
||||
# grab responses
|
||||
messages = []
|
||||
for message in response.messages:
|
||||
messages += message.to_letta_message()
|
||||
messages += message.to_letta_messages()
|
||||
|
||||
# get new agent (see tool call)
|
||||
# print(messages)
|
||||
|
||||
@@ -15,7 +15,7 @@ def _create_letta_response(new_in_context_messages: list[Message], use_assistant
|
||||
"""
|
||||
response_messages = []
|
||||
for msg in new_in_context_messages:
|
||||
response_messages.extend(msg.to_letta_message(use_assistant_message=use_assistant_message))
|
||||
response_messages.extend(msg.to_letta_messages(use_assistant_message=use_assistant_message))
|
||||
return LettaResponse(messages=response_messages, usage=LettaUsageStatistics())
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
import openai
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.ephemeral_memory_agent import EphemeralMemoryAgent
|
||||
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.tool_execution_helper import (
|
||||
@@ -81,26 +81,39 @@ class VoiceAgent(BaseAgent):
|
||||
# TODO: This is not guaranteed to exist!
|
||||
self.summary_block_label = "human"
|
||||
self.message_buffer_limit = message_buffer_limit
|
||||
self.summarizer = Summarizer(
|
||||
mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
|
||||
summarizer_agent=EphemeralMemoryAgent(
|
||||
agent_id=agent_id,
|
||||
openai_client=openai_client,
|
||||
message_manager=message_manager,
|
||||
agent_manager=agent_manager,
|
||||
actor=actor,
|
||||
block_manager=block_manager,
|
||||
target_block_label=self.summary_block_label,
|
||||
message_transcripts=[],
|
||||
),
|
||||
message_buffer_limit=message_buffer_limit,
|
||||
message_buffer_min=message_buffer_min,
|
||||
)
|
||||
self.message_buffer_min = message_buffer_min
|
||||
|
||||
# Cached archival memory/message size
|
||||
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
|
||||
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id)
|
||||
|
||||
def init_summarizer(self, agent_state: AgentState) -> Summarizer:
|
||||
if not agent_state.multi_agent_group:
|
||||
raise ValueError("Low latency voice agent is not part of a multiagent group, missing sleeptime agent.")
|
||||
if len(agent_state.multi_agent_group.agent_ids) != 1:
|
||||
raise ValueError(
|
||||
f"None or multiple participant agents found in voice sleeptime group: {agent_state.multi_agent_group.agent_ids}"
|
||||
)
|
||||
voice_sleeptime_agent_id = agent_state.multi_agent_group.agent_ids[0]
|
||||
summarizer = Summarizer(
|
||||
mode=SummarizationMode.STATIC_MESSAGE_BUFFER,
|
||||
summarizer_agent=VoiceSleeptimeAgent(
|
||||
agent_id=voice_sleeptime_agent_id,
|
||||
convo_agent_state=agent_state,
|
||||
openai_client=self.openai_client,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
actor=self.actor,
|
||||
block_manager=self.block_manager,
|
||||
target_block_label=self.summary_block_label,
|
||||
message_transcripts=[],
|
||||
),
|
||||
message_buffer_limit=self.message_buffer_limit,
|
||||
message_buffer_min=self.message_buffer_min,
|
||||
)
|
||||
|
||||
return summarizer
|
||||
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
|
||||
raise NotImplementedError("VoiceAgent does not have a synchronous step implemented currently.")
|
||||
|
||||
@@ -114,6 +127,8 @@ class VoiceAgent(BaseAgent):
|
||||
user_query = input_messages[0].content[0].text
|
||||
|
||||
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
|
||||
summarizer = self.init_summarizer(agent_state=agent_state)
|
||||
|
||||
in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
in_context_messages[0].content[0].text = compile_system_message(
|
||||
@@ -155,7 +170,7 @@ class VoiceAgent(BaseAgent):
|
||||
break
|
||||
|
||||
# Rebuild context window if desired
|
||||
await self._rebuild_context_window(in_context_messages, letta_message_db_queue)
|
||||
await self._rebuild_context_window(summarizer, in_context_messages, letta_message_db_queue)
|
||||
|
||||
# TODO: This may be out of sync, if in between steps users add files
|
||||
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
@@ -253,11 +268,13 @@ class VoiceAgent(BaseAgent):
|
||||
# If we got here, there's no tool call. If finish_reason_stop => done
|
||||
return not streaming_interface.finish_reason_stop
|
||||
|
||||
async def _rebuild_context_window(self, in_context_messages: List[Message], letta_message_db_queue: List[Message]) -> None:
|
||||
async def _rebuild_context_window(
|
||||
self, summarizer: Summarizer, in_context_messages: List[Message], letta_message_db_queue: List[Message]
|
||||
) -> None:
|
||||
new_letta_messages = self.message_manager.create_many_messages(letta_message_db_queue, actor=self.actor)
|
||||
|
||||
# TODO: Make this more general and configurable, less brittle
|
||||
new_in_context_messages, updated = self.summarizer.summarize(
|
||||
new_in_context_messages, updated = summarizer.summarize(
|
||||
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import AsyncGenerator, Dict, List, Tuple, Union
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import openai
|
||||
|
||||
@@ -11,7 +11,7 @@ from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool, UserMessage
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
@@ -19,9 +19,11 @@ from letta.server.rest_api.utils import convert_in_context_letta_messages_to_ope
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.system import package_function_response
|
||||
|
||||
|
||||
class EphemeralMemoryAgent(BaseAgent):
|
||||
# TODO: Move this to the new Letta Agent loop
|
||||
class VoiceSleeptimeAgent(BaseAgent):
|
||||
"""
|
||||
A stateless agent that helps with offline memory computations.
|
||||
"""
|
||||
@@ -29,6 +31,7 @@ class EphemeralMemoryAgent(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
convo_agent_state: AgentState,
|
||||
openai_client: openai.AsyncClient,
|
||||
message_manager: MessageManager,
|
||||
agent_manager: AgentManager,
|
||||
@@ -45,6 +48,7 @@ class EphemeralMemoryAgent(BaseAgent):
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
self.convo_agent_state = convo_agent_state
|
||||
self.block_manager = block_manager
|
||||
self.target_block_label = target_block_label
|
||||
self.message_transcripts = message_transcripts
|
||||
@@ -75,26 +79,50 @@ class EphemeralMemoryAgent(BaseAgent):
|
||||
if function_name == "store_memories":
|
||||
print("Called store_memories")
|
||||
print(function_args)
|
||||
for chunk_args in function_args.get("chunks"):
|
||||
self.store_memory(agent_state=agent_state, **chunk_args)
|
||||
result = "Successfully stored memories"
|
||||
chunks = function_args.get("chunks", [])
|
||||
results = [self.store_memory(agent_state=self.convo_agent_state, **chunk_args) for chunk_args in chunks]
|
||||
|
||||
aggregated_result = next((res for res, _ in results if res is not None), None)
|
||||
aggregated_success = all(success for _, success in results)
|
||||
|
||||
else:
|
||||
raise ValueError("Error: Unknown tool function '{function_name}'")
|
||||
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": assistant_message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {"name": function_name, "arguments": tool_call.function.arguments},
|
||||
}
|
||||
],
|
||||
}
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": assistant_message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {"name": function_name, "arguments": tool_call.function.arguments},
|
||||
}
|
||||
],
|
||||
}
|
||||
openai_messages.append(assistant_message)
|
||||
in_context_messages.append(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_id,
|
||||
openai_message_dict=assistant_message,
|
||||
model=agent_state.llm_config.model,
|
||||
name=function_name,
|
||||
)
|
||||
)
|
||||
tool_call_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": package_function_response(was_success=aggregated_success, response_string=str(aggregated_result)),
|
||||
}
|
||||
openai_messages.append(tool_call_message)
|
||||
in_context_messages.append(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_id,
|
||||
openai_message_dict=tool_call_message,
|
||||
model=agent_state.llm_config.model,
|
||||
name=function_name,
|
||||
tool_returns=[ToolReturn(status="success" if aggregated_success else "error")],
|
||||
)
|
||||
)
|
||||
openai_messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(result)})
|
||||
|
||||
# 2. Execute rethink block memory loop
|
||||
human_block_content = self.agent_manager.get_block_with_label(
|
||||
@@ -113,7 +141,7 @@ Please refine this block:
|
||||
- Organize related information together (e.g., preferences, background, ongoing goals).
|
||||
- Add any light, supportable inferences that deepen understanding—but do not invent unsupported details.
|
||||
|
||||
Use `rethink_user_memor(new_memory)` as many times as you need to iteratively improve the text. When it’s fully polished and complete, call `finish_rethinking_memory()`.
|
||||
Use `rethink_user_memory(new_memory)` as many times as you need to iteratively improve the text. When it’s fully polished and complete, call `finish_rethinking_memory()`.
|
||||
"""
|
||||
rethink_command = UserMessage(content=rethink_command)
|
||||
openai_messages.append(rethink_command.model_dump())
|
||||
@@ -128,35 +156,59 @@ Use `rethink_user_memor(new_memory)` as many times as you need to iteratively im
|
||||
function_name = tool_call.function.name
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
if function_name == "rethink_user_memor":
|
||||
print("Called rethink_user_memor")
|
||||
if function_name == "rethink_user_memory":
|
||||
print("Called rethink_user_memory")
|
||||
print(function_args)
|
||||
result = self.rethink_user_memory(agent_state=agent_state, **function_args)
|
||||
result, success = self.rethink_user_memory(agent_state=agent_state, **function_args)
|
||||
elif function_name == "finish_rethinking_memory":
|
||||
print("Called finish_rethinking_memory")
|
||||
result, success = None, True
|
||||
break
|
||||
else:
|
||||
result = f"Error: Unknown tool function '{function_name}'"
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": assistant_message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {"name": function_name, "arguments": tool_call.function.arguments},
|
||||
}
|
||||
],
|
||||
}
|
||||
print(f"Error: Unknown tool function '{function_name}'")
|
||||
raise ValueError(f"Error: Unknown tool function '{function_name}'", False)
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": assistant_message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {"name": function_name, "arguments": tool_call.function.arguments},
|
||||
}
|
||||
],
|
||||
}
|
||||
openai_messages.append(assistant_message)
|
||||
in_context_messages.append(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_id,
|
||||
openai_message_dict=assistant_message,
|
||||
model=agent_state.llm_config.model,
|
||||
name=function_name,
|
||||
)
|
||||
)
|
||||
tool_call_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": package_function_response(was_success=success, response_string=str(result)),
|
||||
}
|
||||
openai_messages.append(tool_call_message)
|
||||
in_context_messages.append(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_id,
|
||||
openai_message_dict=tool_call_message,
|
||||
model=agent_state.llm_config.model,
|
||||
name=function_name,
|
||||
tool_returns=[ToolReturn(status="success" if success else "error")],
|
||||
)
|
||||
)
|
||||
openai_messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(result)})
|
||||
|
||||
# Actually save the memory:
|
||||
target_block = agent_state.memory.get_block(self.target_block_label)
|
||||
self.block_manager.update_block(block_id=target_block.id, block_update=BlockUpdate(value=target_block.value), actor=self.actor)
|
||||
|
||||
return LettaResponse(messages=[], usage=LettaUsageStatistics())
|
||||
self.message_manager.create_many_messages(pydantic_msgs=in_context_messages, actor=self.actor)
|
||||
return LettaResponse(messages=[msg for m in in_context_messages for msg in m.to_letta_messages()], usage=LettaUsageStatistics())
|
||||
|
||||
def _format_messages_llm_friendly(self):
|
||||
messages = self.message_manager.list_messages_for_agent(agent_id=self.agent_id, actor=self.actor)
|
||||
@@ -166,7 +218,7 @@ Use `rethink_user_memor(new_memory)` as many times as you need to iteratively im
|
||||
|
||||
def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState, tools: List[Tool]) -> ChatCompletionRequest:
|
||||
openai_request = ChatCompletionRequest(
|
||||
model="gpt-4o", # agent_state.llm_config.model, # TODO: Separate config for summarizer?
|
||||
model=agent_state.llm_config.model, # TODO: Separate config for summarizer?
|
||||
messages=openai_messages,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
@@ -261,14 +313,14 @@ Use `rethink_user_memor(new_memory)` as many times as you need to iteratively im
|
||||
|
||||
return tools
|
||||
|
||||
def rethink_user_memory(self, new_memory: str, agent_state: AgentState) -> str:
|
||||
def rethink_user_memory(self, new_memory: str, agent_state: AgentState) -> Tuple[Optional[str], bool]:
|
||||
if agent_state.memory.get_block(self.target_block_label) is None:
|
||||
agent_state.memory.create_block(label=self.target_block_label, value=new_memory)
|
||||
|
||||
agent_state.memory.update_block_value(label=self.target_block_label, value=new_memory)
|
||||
return "Successfully updated memory"
|
||||
return None, True
|
||||
|
||||
def store_memory(self, start_index: int, end_index: int, context: str, agent_state: AgentState) -> str:
|
||||
def store_memory(self, start_index: int, end_index: int, context: str, agent_state: AgentState) -> Tuple[Optional[str], bool]:
|
||||
"""
|
||||
Store a memory.
|
||||
"""
|
||||
@@ -283,9 +335,9 @@ Use `rethink_user_memor(new_memory)` as many times as you need to iteratively im
|
||||
)
|
||||
self.agent_manager.rebuild_system_prompt(agent_id=agent_state.id, actor=self.actor, force=True)
|
||||
|
||||
return "Sucessfully stored memory"
|
||||
return None, True
|
||||
except Exception as e:
|
||||
return f"Failed to store memory given start_index {start_index} and end_index {end_index}: {e}"
|
||||
return f"Failed to store memory given start_index {start_index} and end_index {end_index}: {e}", False
|
||||
|
||||
def serialize(self, messages: List[str], context: str) -> str:
|
||||
"""
|
||||
@@ -344,4 +396,4 @@ Use `rethink_user_memor(new_memory)` as many times as you need to iteratively im
|
||||
"""
|
||||
This agent is synchronous-only. If called in an async context, raise an error.
|
||||
"""
|
||||
raise NotImplementedError("EphemeralMemoryAgent does not support async step.")
|
||||
raise NotImplementedError("VoiceSleeptimeAgent does not support async step.")
|
||||
@@ -1031,7 +1031,7 @@ class RESTClient(AbstractClient):
|
||||
# messages = []
|
||||
# for m in response.messages:
|
||||
# assert isinstance(m, Message)
|
||||
# messages += m.to_letta_message()
|
||||
# messages += m.to_letta_messages()
|
||||
# response.messages = messages
|
||||
|
||||
return response
|
||||
@@ -2725,14 +2725,14 @@ class LocalClient(AbstractClient):
|
||||
# assert isinstance(m, Message), f"Expected Message object, got {type(m)}"
|
||||
# letta_messages = []
|
||||
# for m in messages:
|
||||
# letta_messages += m.to_letta_message()
|
||||
# letta_messages += m.to_letta_messages()
|
||||
# return LettaResponse(messages=letta_messages, usage=usage)
|
||||
|
||||
# format messages
|
||||
messages = self.interface.to_list()
|
||||
letta_messages = []
|
||||
for m in messages:
|
||||
letta_messages += m.to_letta_message()
|
||||
letta_messages += m.to_letta_messages()
|
||||
|
||||
return LettaResponse(messages=letta_messages, usage=usage)
|
||||
|
||||
|
||||
@@ -219,7 +219,7 @@ class Message(BaseMessage):
|
||||
return [
|
||||
msg
|
||||
for m in messages
|
||||
for msg in m.to_letta_message(
|
||||
for msg in m.to_letta_messages(
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
@@ -227,7 +227,7 @@ class Message(BaseMessage):
|
||||
)
|
||||
]
|
||||
|
||||
def to_letta_message(
|
||||
def to_letta_messages(
|
||||
self,
|
||||
use_assistant_message: bool = False,
|
||||
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||
@@ -447,7 +447,7 @@ class Message(BaseMessage):
|
||||
name: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
tool_returns: Optional[List[ToolReturn]] = None,
|
||||
):
|
||||
) -> Message:
|
||||
"""Convert a ChatCompletion message object into a Message object (synced to DB)"""
|
||||
if not created_at:
|
||||
# timestamp for creation
|
||||
|
||||
@@ -56,8 +56,8 @@ async def create_voice_chat_completions(
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
message_buffer_limit=40,
|
||||
message_buffer_min=15,
|
||||
message_buffer_limit=8,
|
||||
message_buffer_min=4,
|
||||
)
|
||||
|
||||
# Return the streaming generator
|
||||
|
||||
@@ -122,7 +122,7 @@ class MessageManager:
|
||||
message = self.update_message_by_id(message_id=message_id, message_update=update_message, actor=actor)
|
||||
|
||||
# convert back to LettaMessage
|
||||
for letta_msg in message.to_letta_message(use_assistant_message=True):
|
||||
for letta_msg in message.to_letta_messages(use_assistant_message=True):
|
||||
if letta_msg.message_type == letta_message_update.message_type:
|
||||
return letta_msg
|
||||
|
||||
@@ -160,7 +160,7 @@ class MessageManager:
|
||||
message = self.update_message_by_id(message_id=message_id, message_update=update_message, actor=actor)
|
||||
|
||||
# convert back to LettaMessage
|
||||
for letta_msg in message.to_letta_message(use_assistant_message=True):
|
||||
for letta_msg in message.to_letta_messages(use_assistant_message=True):
|
||||
if letta_msg.message_type == letta_message_update.message_type:
|
||||
return letta_msg
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import traceback
|
||||
from typing import List, Tuple
|
||||
|
||||
from letta.agents.ephemeral_memory_agent import EphemeralMemoryAgent
|
||||
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
@@ -21,7 +21,7 @@ class Summarizer:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, mode: SummarizationMode, summarizer_agent: EphemeralMemoryAgent, message_buffer_limit: int = 10, message_buffer_min: int = 3
|
||||
self, mode: SummarizationMode, summarizer_agent: VoiceSleeptimeAgent, message_buffer_limit: int = 10, message_buffer_min: int = 3
|
||||
):
|
||||
self.mode = mode
|
||||
|
||||
@@ -109,17 +109,14 @@ class Summarizer:
|
||||
|
||||
evicted_messages_str = "\n".join(formatted_evicted_messages)
|
||||
in_context_messages_str = "\n".join(formatted_in_context_messages)
|
||||
summary_request_text = f"""You are a specialized memory recall agent assisting another AI agent by asynchronously reorganizing its memory storage. The LLM agent you are helping maintains a limited context window that retains only the most recent {self.message_buffer_min} messages from its conversations. The provided conversation history includes messages that are about to be evicted from its context window, as well as some additional recent messages for extra clarity and context.
|
||||
summary_request_text = f"""You’re a memory-recall helper for an AI that can only keep the last {self.message_buffer_min} messages. Scan the conversation history, focusing on messages about to drop out of that window, and write crisp notes that capture any important facts or insights about the human so they aren’t lost.
|
||||
|
||||
Your task is to carefully review the provided conversation history and proactively generate detailed, relevant memories about the human participant, specifically targeting information contained in messages that are about to be evicted from the context window. Your notes will help preserve critical insights, events, or facts that would otherwise be forgotten.
|
||||
|
||||
(Older) Evicted Messages:
|
||||
(Older) Evicted Messages:\n
|
||||
{evicted_messages_str}
|
||||
|
||||
(Newer) In-Context Messages:
|
||||
(Newer) In-Context Messages:\n
|
||||
{in_context_messages_str}
|
||||
"""
|
||||
|
||||
# Fire-and-forget the summarization task
|
||||
self.fire_and_forget(
|
||||
self.summarizer_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])])
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
@@ -9,20 +8,15 @@ from letta_client import Letta
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.agents.ephemeral_memory_agent import EphemeralMemoryAgent
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole, MessageStreamStatus
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# --- Server Management --- #
|
||||
|
||||
@@ -43,7 +37,7 @@ def server_url():
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5) # Allow server startup time
|
||||
wait_for_server(url) # Allow server startup time
|
||||
|
||||
return url
|
||||
|
||||
@@ -137,7 +131,7 @@ def _get_chat_request(message, stream=True):
|
||||
"""Returns a chat completion request with streaming enabled."""
|
||||
return ChatCompletionRequest(
|
||||
model="gpt-4o-mini",
|
||||
messages=[UserMessage(content=message)],
|
||||
messages=[OpenAIUserMessage(content=message)],
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@@ -164,150 +158,9 @@ def _assert_valid_chunk(chunk, idx, chunks):
|
||||
# --- Test Cases --- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["How are you?"])
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_latency(disable_e2b_api_key, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(message)
|
||||
|
||||
async_client = AsyncOpenAI(base_url=f"http://localhost:8283/{endpoint}/{agent.id}", max_retries=0)
|
||||
import time
|
||||
|
||||
print(f"SENT OFF REQUEST {time.perf_counter()}")
|
||||
first = True
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
print(chunk)
|
||||
if first:
|
||||
print(f"FIRST RECEIVED FROM REQUEST{time.perf_counter()}")
|
||||
first = False
|
||||
continue
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_multiple_messages(disable_e2b_api_key, client, agent, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request("How are you?")
|
||||
async_client = AsyncOpenAI(base_url=f"http://localhost:8283/{endpoint}/{agent.id}", max_retries=0)
|
||||
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
print(chunk.choices[0].delta.content)
|
||||
print("============================================")
|
||||
request = _get_chat_request("What are you up to?")
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
print(chunk.choices[0].delta.content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ephemeral_memory_agent(disable_e2b_api_key, agent):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
async_client = AsyncOpenAI()
|
||||
message_transcripts = [
|
||||
"user: Hey, I’ve been thinking about planning a road trip up the California coast next month.",
|
||||
"assistant: That sounds amazing! Do you have any particular cities or sights in mind?",
|
||||
"user: I definitely want to stop in Big Sur and maybe Santa Barbara. Also, I love craft coffee shops.",
|
||||
"assistant: Great choices. Would you like recommendations for top-rated coffee spots along the way?",
|
||||
"user: Yes, please. Also, I prefer independent cafés over chains, and I’m vegan.",
|
||||
"assistant: Noted—independent, vegan-friendly cafés. Anything else?",
|
||||
"user: I’d also like to listen to something upbeat, maybe a podcast or playlist suggestion.",
|
||||
"assistant: Sure—perhaps an indie rock playlist or a travel podcast like “Zero To Travel.”",
|
||||
"user: Perfect. By the way, my birthday is June 12th, so I’ll be turning 30 on the trip.",
|
||||
"assistant: Happy early birthday! Would you like gift ideas or celebration tips?",
|
||||
"user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.",
|
||||
"assistant: How about Vegan Treats in Santa Barbara? They’re highly rated.",
|
||||
"user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.",
|
||||
"user: I want to make sure my itinerary isn’t too tight—aiming for 3–4 days total.",
|
||||
"assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.",
|
||||
"user: Yes, let’s do that.",
|
||||
"assistant: I’ll put together a day-by-day plan now.",
|
||||
]
|
||||
|
||||
memory_agent = EphemeralMemoryAgent(
|
||||
agent_id=agent.id,
|
||||
openai_client=async_client,
|
||||
message_manager=MessageManager(),
|
||||
agent_manager=AgentManager(),
|
||||
actor=UserManager().get_user_or_default(),
|
||||
block_manager=BlockManager(),
|
||||
target_block_label="human",
|
||||
message_transcripts=message_transcripts,
|
||||
)
|
||||
|
||||
summary_request_text = """
|
||||
Here is the conversation history. Lines marked (Older) are about to be evicted; lines marked (Newer) are still in context for clarity:
|
||||
|
||||
(Older)
|
||||
0. user: Hey, I’ve been thinking about planning a road trip up the California coast next month.
|
||||
1. assistant: That sounds amazing! Do you have any particular cities or sights in mind?
|
||||
2. user: I definitely want to stop in Big Sur and maybe Santa Barbara. Also, I love craft coffee shops.
|
||||
3. assistant: Great choices. Would you like recommendations for top-rated coffee spots along the way?
|
||||
4. user: Yes, please. Also, I prefer independent cafés over chains, and I’m vegan.
|
||||
5. assistant: Noted—independent, vegan-friendly cafés. Anything else?
|
||||
6. user: I’d also like to listen to something upbeat, maybe a podcast or playlist suggestion.
|
||||
7. assistant: Sure—perhaps an indie rock playlist or a travel podcast like “Zero To Travel.”
|
||||
8. user: Perfect. By the way, my birthday is June 12th, so I’ll be turning 30 on the trip.
|
||||
9. assistant: Happy early birthday! Would you like gift ideas or celebration tips?
|
||||
10. user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.
|
||||
11. assistant: How about Vegan Treats in Santa Barbara? They’re highly rated.
|
||||
12. user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.
|
||||
|
||||
(Newer)
|
||||
13. user: I want to make sure my itinerary isn’t too tight—aiming for 3–4 days total.
|
||||
14. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.
|
||||
15. user: Yes, let’s do that.
|
||||
16. assistant: I’ll put together a day-by-day plan now.
|
||||
|
||||
Please segment the (Older) portion into coherent chunks and—using **only** the `store_memory` tool—output a JSON call that lists each chunk’s `start_index`, `end_index`, and a one-sentence `contextual_description`.
|
||||
"""
|
||||
|
||||
results = await memory_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=summary_request_text)])])
|
||||
print(results)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."])
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_voice_recall_memory(disable_e2b_api_key, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(message)
|
||||
|
||||
# Insert some messages about my name
|
||||
client.agents.messages.create(
|
||||
agent.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role=MessageRole.user,
|
||||
content=[
|
||||
TextContent(text="My name is Matt, don't do anything with this information other than call send_message right after.")
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Wipe the in context messages
|
||||
actor = UserManager().get_default_user()
|
||||
AgentManager().set_in_context_messages(agent_id=agent.id, message_ids=[agent.message_ids[0]], actor=actor)
|
||||
|
||||
async_client = AsyncOpenAI(base_url=f"http://localhost:8283/{endpoint}/{agent.id}", max_retries=0)
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
print(chunk.choices[0].delta.content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas.", "What's the weather in SF?"])
|
||||
@pytest.mark.parametrize("endpoint", ["openai/v1"]) # , "v1/voice-beta"])
|
||||
@pytest.mark.parametrize("endpoint", ["openai/v1"])
|
||||
async def test_chat_completions_streaming_openai_client(disable_e2b_api_key, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(message)
|
||||
|
||||
@@ -79,9 +79,13 @@ def agent_state(client: Letta) -> AgentState:
|
||||
Creates and returns an agent state for testing with a pre-configured agent.
|
||||
The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
|
||||
"""
|
||||
client.tools.upsert_base_tools()
|
||||
|
||||
send_message_tool = client.tools.list(name="send_message")[0]
|
||||
agent_state_instance = client.agents.create(
|
||||
name="supervisor",
|
||||
include_base_tools=True,
|
||||
include_base_tools=False,
|
||||
tool_ids=[send_message_tool.id],
|
||||
model="openai/gpt-4o",
|
||||
embedding="letta/letta-free",
|
||||
tags=["supervisor"],
|
||||
|
||||
@@ -1,15 +1,87 @@
|
||||
import os
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
from sqlalchemy import delete
|
||||
|
||||
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm import Provider, Step
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentType, CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole, MessageStreamStatus
|
||||
from letta.schemas.group import ManagerType
|
||||
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage
|
||||
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, UserMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.utils import get_persona_text
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
MESSAGE_TRANSCRIPTS = [
|
||||
"user: Hey, I’ve been thinking about planning a road trip up the California coast next month.",
|
||||
"assistant: That sounds amazing! Do you have any particular cities or sights in mind?",
|
||||
"user: I definitely want to stop in Big Sur and maybe Santa Barbara. Also, I love craft coffee shops.",
|
||||
"assistant: Great choices. Would you like recommendations for top-rated coffee spots along the way?",
|
||||
"user: Yes, please. Also, I prefer independent cafés over chains, and I’m vegan.",
|
||||
"assistant: Noted—independent, vegan-friendly cafés. Anything else?",
|
||||
"user: I’d also like to listen to something upbeat, maybe a podcast or playlist suggestion.",
|
||||
"assistant: Sure—perhaps an indie rock playlist or a travel podcast like “Zero To Travel.”",
|
||||
"user: Perfect. By the way, my birthday is June 12th, so I’ll be turning 30 on the trip.",
|
||||
"assistant: Happy early birthday! Would you like gift ideas or celebration tips?",
|
||||
"user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.",
|
||||
"assistant: How about Vegan Treats in Santa Barbara? They’re highly rated.",
|
||||
"user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.",
|
||||
"user: I want to make sure my itinerary isn’t too tight—aiming for 3–4 days total.",
|
||||
"assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.",
|
||||
"user: Yes, let’s do that.",
|
||||
"assistant: I’ll put together a day-by-day plan now.",
|
||||
]
|
||||
|
||||
SUMMARY_REQ_TEXT = """
|
||||
Here is the conversation history. Lines marked (Older) are about to be evicted; lines marked (Newer) are still in context for clarity:
|
||||
|
||||
(Older)
|
||||
0. user: Hey, I’ve been thinking about planning a road trip up the California coast next month.
|
||||
1. assistant: That sounds amazing! Do you have any particular cities or sights in mind?
|
||||
2. user: I definitely want to stop in Big Sur and maybe Santa Barbara. Also, I love craft coffee shops.
|
||||
3. assistant: Great choices. Would you like recommendations for top-rated coffee spots along the way?
|
||||
4. user: Yes, please. Also, I prefer independent cafés over chains, and I’m vegan.
|
||||
5. assistant: Noted—independent, vegan-friendly cafés. Anything else?
|
||||
6. user: I’d also like to listen to something upbeat, maybe a podcast or playlist suggestion.
|
||||
7. assistant: Sure—perhaps an indie rock playlist or a travel podcast like “Zero To Travel.”
|
||||
8. user: Perfect. By the way, my birthday is June 12th, so I’ll be turning 30 on the trip.
|
||||
9. assistant: Happy early birthday! Would you like gift ideas or celebration tips?
|
||||
10. user: Maybe just a recommendation for a nice vegan bakery to grab a birthday treat.
|
||||
11. assistant: How about Vegan Treats in Santa Barbara? They’re highly rated.
|
||||
12. user: Sounds good. Also, I work remotely as a UX designer, usually on a MacBook Pro.
|
||||
|
||||
(Newer)
|
||||
13. user: I want to make sure my itinerary isn’t too tight—aiming for 3–4 days total.
|
||||
14. assistant: Understood. I can draft a relaxed 4-day schedule with driving and stops.
|
||||
15. user: Yes, let’s do that.
|
||||
16. assistant: I’ll put together a day-by-day plan now.
|
||||
|
||||
Please segment the (Older) portion into coherent chunks and—using **only** the `store_memory` tool—output a JSON call that lists each chunk’s `start_index`, `end_index`, and a one-sentence `contextual_description`.
|
||||
"""
|
||||
|
||||
# --- Server Management --- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -23,6 +95,119 @@ def server():
|
||||
return server
|
||||
|
||||
|
||||
def _run_server():
|
||||
"""Starts the Letta server in a background thread."""
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server_url():
|
||||
"""Ensures a server is running and returns its base URL."""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(url) # Allow server startup time
|
||||
|
||||
return url
|
||||
|
||||
|
||||
# --- Client Setup --- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def client(server_url):
|
||||
"""Creates a REST client for testing."""
|
||||
client = Letta(base_url=server_url)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def roll_dice_tool(client):
|
||||
def roll_dice():
|
||||
"""
|
||||
Rolls a 6 sided die.
|
||||
|
||||
Returns:
|
||||
str: The roll result.
|
||||
"""
|
||||
return "Rolled a 10!"
|
||||
|
||||
tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def weather_tool(client):
|
||||
def get_weather(location: str) -> str:
|
||||
"""
|
||||
Fetches the current weather for a given location.
|
||||
|
||||
Parameters:
|
||||
location (str): The location to get the weather for.
|
||||
|
||||
Returns:
|
||||
str: A formatted string describing the weather in the given location.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the request to fetch weather data fails.
|
||||
"""
|
||||
import requests
|
||||
|
||||
url = f"https://wttr.in/{location}?format=%C+%t"
|
||||
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
weather_data = response.text
|
||||
return f"The weather in {location} is {weather_data}."
|
||||
else:
|
||||
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
|
||||
|
||||
tool = client.tools.upsert_from_function(func=get_weather)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def composio_gmail_get_profile_tool(default_user):
|
||||
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
|
||||
tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def voice_agent(server, actor):
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
|
||||
main_agent = server.create_agent(
|
||||
request=CreateAgent(
|
||||
agent_type=AgentType.voice_convo_agent,
|
||||
name="main_agent",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="You are a personal assistant that helps users with requests.",
|
||||
),
|
||||
CreateBlock(
|
||||
label="human",
|
||||
value="My favorite plant is the fiddle leaf\nMy favorite color is lavender",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
enable_sleeptime=True,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
return main_agent
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def org_id(server):
|
||||
org = server.organization_manager.create_default_organization()
|
||||
@@ -46,35 +231,147 @@ def actor(server, org_id):
|
||||
server.user_manager.delete_user_by_id(user.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_voice_convo_agent(server, actor):
|
||||
# 0. Refresh base tools
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
# --- Helper Functions --- #
|
||||
|
||||
# 1. Create sleeptime agent
|
||||
main_agent = server.create_agent(
|
||||
request=CreateAgent(
|
||||
agent_type=AgentType.voice_convo_agent,
|
||||
name="main_agent",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="You are a personal assistant that helps users with requests.",
|
||||
),
|
||||
CreateBlock(
|
||||
label="human",
|
||||
value="My favorite plant is the fiddle leaf\nMy favorite color is lavender",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
enable_sleeptime=True,
|
||||
),
|
||||
actor=actor,
|
||||
|
||||
def _get_chat_request(message, stream=True):
|
||||
"""Returns a chat completion request with streaming enabled."""
|
||||
return ChatCompletionRequest(
|
||||
model="gpt-4o-mini",
|
||||
messages=[OpenAIUserMessage(content=message)],
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
assert main_agent.enable_sleeptime == True
|
||||
main_agent_tools = [tool.name for tool in main_agent.tools]
|
||||
|
||||
def _assert_valid_chunk(chunk, idx, chunks):
|
||||
"""Validates the structure of each streaming chunk."""
|
||||
if isinstance(chunk, ChatCompletionChunk):
|
||||
assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
|
||||
|
||||
elif isinstance(chunk, LettaUsageStatistics):
|
||||
assert chunk.completion_tokens > 0, "Completion tokens must be > 0."
|
||||
assert chunk.prompt_tokens > 0, "Prompt tokens must be > 0."
|
||||
assert chunk.total_tokens > 0, "Total tokens must be > 0."
|
||||
assert chunk.step_count == 1, "Step count must be 1."
|
||||
|
||||
elif isinstance(chunk, MessageStreamStatus):
|
||||
assert chunk == MessageStreamStatus.done, "Stream should end with 'done' status."
|
||||
assert idx == len(chunks) - 1, "The last chunk must be 'done'."
|
||||
|
||||
else:
|
||||
pytest.fail(f"Unexpected chunk type: {chunk}")
|
||||
|
||||
|
||||
# --- Tests --- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."])
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_voice_recall_memory(disable_e2b_api_key, client, voice_agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(message)
|
||||
|
||||
async_client = AsyncOpenAI(base_url=f"http://localhost:8283/{endpoint}/{voice_agent.id}", max_retries=0)
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
print(chunk.choices[0].delta.content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_multiple_messages(disable_e2b_api_key, client, voice_agent, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request("How are you?")
|
||||
async_client = AsyncOpenAI(base_url=f"http://localhost:8283/{endpoint}/{voice_agent.id}", max_retries=0)
|
||||
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
print(chunk.choices[0].delta.content)
|
||||
print("============================================")
|
||||
request = _get_chat_request("What are you up to?")
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
print(chunk.choices[0].delta.content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
agent_manager = AgentManager()
|
||||
user_manager = UserManager()
|
||||
actor = user_manager.get_default_user()
|
||||
|
||||
request = CreateAgent(
|
||||
name=voice_agent.name + "-sleeptime",
|
||||
agent_type=AgentType.voice_sleeptime_agent,
|
||||
block_ids=[block.id for block in voice_agent.memory.blocks],
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="memory_persona",
|
||||
value=get_persona_text("voice_memory_persona"),
|
||||
),
|
||||
],
|
||||
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
project_id=voice_agent.project_id,
|
||||
)
|
||||
sleeptime_agent = agent_manager.create_agent(request, actor=actor)
|
||||
|
||||
async_client = AsyncOpenAI()
|
||||
|
||||
memory_agent = VoiceSleeptimeAgent(
|
||||
agent_id=sleeptime_agent.id,
|
||||
convo_agent_state=sleeptime_agent, # In reality, this will be the main convo agent
|
||||
openai_client=async_client,
|
||||
message_manager=MessageManager(),
|
||||
agent_manager=agent_manager,
|
||||
actor=actor,
|
||||
block_manager=BlockManager(),
|
||||
target_block_label="human",
|
||||
message_transcripts=MESSAGE_TRANSCRIPTS,
|
||||
)
|
||||
|
||||
results = await memory_agent.step([MessageCreate(role=MessageRole.user, content=[TextContent(text=SUMMARY_REQ_TEXT)])])
|
||||
|
||||
messages = results.messages
|
||||
# --- Basic structural check ---
|
||||
assert isinstance(messages, list)
|
||||
assert len(messages) >= 5, "Expected at least 5 messages in the sequence"
|
||||
|
||||
# --- Message 0: initial UserMessage ---
|
||||
assert isinstance(messages[0], UserMessage), "First message should be a UserMessage"
|
||||
|
||||
# --- Message 1: store_memories ToolCall ---
|
||||
assert isinstance(messages[1], ToolCallMessage), "Second message should be ToolCallMessage"
|
||||
assert messages[1].name == "store_memories", "Expected store_memories tool call"
|
||||
|
||||
# --- Message 2: store_memories ToolReturn ---
|
||||
assert isinstance(messages[2], ToolReturnMessage), "Third message should be ToolReturnMessage"
|
||||
assert messages[2].name == "store_memories", "Expected store_memories tool return"
|
||||
assert messages[2].status == "success", "store_memories tool return should be successful"
|
||||
|
||||
# --- Message 3: rethink_user_memory ToolCall ---
|
||||
assert isinstance(messages[3], ToolCallMessage), "Fourth message should be ToolCallMessage"
|
||||
assert messages[3].name == "rethink_user_memory", "Expected rethink_user_memory tool call"
|
||||
|
||||
# --- Message 4: rethink_user_memory ToolReturn ---
|
||||
assert isinstance(messages[4], ToolReturnMessage), "Fifth message should be ToolReturnMessage"
|
||||
assert messages[4].name == "rethink_user_memory", "Expected rethink_user_memory tool return"
|
||||
assert messages[4].status == "success", "rethink_user_memory tool return should be successful"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_voice_convo_agent(voice_agent, server, actor):
|
||||
|
||||
assert voice_agent.enable_sleeptime == True
|
||||
main_agent_tools = [tool.name for tool in voice_agent.tools]
|
||||
assert len(main_agent_tools) == 2
|
||||
assert "send_message" in main_agent_tools
|
||||
assert "search_memory" in main_agent_tools
|
||||
@@ -84,7 +381,7 @@ async def test_init_voice_convo_agent(server, actor):
|
||||
|
||||
# 2. Check that a group was created
|
||||
group = server.group_manager.retrieve_group(
|
||||
group_id=main_agent.multi_agent_group.id,
|
||||
group_id=voice_agent.multi_agent_group.id,
|
||||
actor=actor,
|
||||
)
|
||||
assert group.manager_type == ManagerType.voice_sleeptime
|
||||
@@ -92,11 +389,11 @@ async def test_init_voice_convo_agent(server, actor):
|
||||
|
||||
# 3. Verify shared blocks
|
||||
sleeptime_agent_id = group.agent_ids[0]
|
||||
shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor)
|
||||
shared_block = server.agent_manager.get_block_with_label(agent_id=voice_agent.id, block_label="human", actor=actor)
|
||||
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
|
||||
assert len(agents) == 2
|
||||
assert sleeptime_agent_id in [agent.id for agent in agents]
|
||||
assert main_agent.id in [agent.id for agent in agents]
|
||||
assert voice_agent.id in [agent.id for agent in agents]
|
||||
|
||||
# 4 Verify sleeptime agent tools
|
||||
sleeptime_agent = server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
|
||||
@@ -107,7 +404,7 @@ async def test_init_voice_convo_agent(server, actor):
|
||||
|
||||
# 5. Send a message as a sanity check
|
||||
response = await server.send_message_to_agent(
|
||||
agent_id=main_agent.id,
|
||||
agent_id=voice_agent.id,
|
||||
actor=actor,
|
||||
input_messages=[
|
||||
MessageCreate(
|
||||
@@ -124,7 +421,7 @@ async def test_init_voice_convo_agent(server, actor):
|
||||
assert AssistantMessage in message_types
|
||||
|
||||
# 6. Delete agent
|
||||
server.agent_manager.delete_agent(agent_id=main_agent.id, actor=actor)
|
||||
server.agent_manager.delete_agent(agent_id=voice_agent.id, actor=actor)
|
||||
|
||||
with pytest.raises(NoResultFound):
|
||||
server.group_manager.retrieve_group(group_id=group.id, actor=actor)
|
||||
|
||||
Reference in New Issue
Block a user