fix: Remove in-memory _messages field on Agent (#2295)
This commit is contained in:
1
.github/workflows/tests.yml
vendored
1
.github/workflows/tests.yml
vendored
@@ -33,7 +33,6 @@ jobs:
|
||||
- "test_memory.py"
|
||||
- "test_utils.py"
|
||||
- "test_stream_buffer_readers.py"
|
||||
- "test_summarize.py"
|
||||
services:
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
|
||||
708
letta/agent.py
708
letta/agent.py
@@ -1,25 +1,22 @@
|
||||
import datetime
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from letta.constants import (
|
||||
BASE_TOOLS,
|
||||
CLI_WARNING_PREFIX,
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE,
|
||||
IN_CONTEXT_MEMORY_KEYWORD,
|
||||
LLM_MAX_TOKENS,
|
||||
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
|
||||
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
|
||||
MESSAGE_SUMMARY_WARNING_FRAC,
|
||||
O1_BASE_TOOLS,
|
||||
REQ_HEARTBEAT_MESSAGE,
|
||||
STRUCTURED_OUTPUT_MODELS,
|
||||
)
|
||||
from letta.errors import ContextWindowExceededError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
@@ -34,7 +31,7 @@ from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||
from letta.schemas.message import Message, MessageUpdate
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
Tool as ChatCompletionRequestTool,
|
||||
)
|
||||
@@ -49,6 +46,10 @@ from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
check_supports_structured_output,
|
||||
compile_memory_metadata_block,
|
||||
)
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
@@ -56,8 +57,6 @@ from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
from letta.system import (
|
||||
get_heartbeat,
|
||||
get_initial_boot_messages,
|
||||
get_login_event,
|
||||
get_token_limit_warning,
|
||||
package_function_response,
|
||||
package_summarize_message,
|
||||
@@ -66,166 +65,20 @@ from letta.system import (
|
||||
from letta.utils import (
|
||||
count_tokens,
|
||||
get_friendly_error_msg,
|
||||
get_local_time,
|
||||
get_tool_call_id,
|
||||
get_utc_time,
|
||||
is_utc_datetime,
|
||||
json_dumps,
|
||||
json_loads,
|
||||
parse_json,
|
||||
printd,
|
||||
united_diff,
|
||||
validate_function_response,
|
||||
verify_first_message_correctness,
|
||||
)
|
||||
|
||||
|
||||
def compile_memory_metadata_block(
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
memory_edit_timestamp: datetime.datetime,
|
||||
agent_manager: Optional[AgentManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
) -> str:
|
||||
# Put the timestamp in the local timezone (mimicking get_local_time())
|
||||
timestamp_str = memory_edit_timestamp.astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip()
|
||||
|
||||
# Create a metadata block of info so the agent knows about the metadata of out-of-context memories
|
||||
memory_metadata_block = "\n".join(
|
||||
[
|
||||
f"### Memory [last modified: {timestamp_str}]",
|
||||
f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||
f"{agent_manager.passage_size(actor=actor, agent_id=agent_id) if agent_manager else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
||||
]
|
||||
)
|
||||
return memory_metadata_block
|
||||
|
||||
|
||||
def compile_system_message(
|
||||
system_prompt: str,
|
||||
agent_id: str,
|
||||
in_context_memory: Memory,
|
||||
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
|
||||
actor: PydanticUser,
|
||||
agent_manager: Optional[AgentManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
user_defined_variables: Optional[dict] = None,
|
||||
append_icm_if_missing: bool = True,
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
) -> str:
|
||||
"""Prepare the final/full system message that will be fed into the LLM API
|
||||
|
||||
The base system message may be templated, in which case we need to render the variables.
|
||||
|
||||
The following are reserved variables:
|
||||
- CORE_MEMORY: the in-context memory of the LLM
|
||||
"""
|
||||
|
||||
if user_defined_variables is not None:
|
||||
# TODO eventually support the user defining their own variables to inject
|
||||
raise NotImplementedError
|
||||
else:
|
||||
variables = {}
|
||||
|
||||
# Add the protected memory variable
|
||||
if IN_CONTEXT_MEMORY_KEYWORD in variables:
|
||||
raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
|
||||
else:
|
||||
# TODO should this all put into the memory.__repr__ function?
|
||||
memory_metadata_string = compile_memory_metadata_block(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
memory_edit_timestamp=in_context_memory_last_edit,
|
||||
agent_manager=agent_manager,
|
||||
message_manager=message_manager,
|
||||
)
|
||||
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
|
||||
|
||||
# Add to the variables list to inject
|
||||
variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string
|
||||
|
||||
if template_format == "f-string":
|
||||
|
||||
# Catch the special case where the system prompt is unformatted
|
||||
if append_icm_if_missing:
|
||||
memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
|
||||
if memory_variable_string not in system_prompt:
|
||||
# In this case, append it to the end to make sure memory is still injected
|
||||
# warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
|
||||
system_prompt += "\n" + memory_variable_string
|
||||
|
||||
# render the variables using the built-in templater
|
||||
try:
|
||||
formatted_prompt = system_prompt.format_map(variables)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
|
||||
|
||||
else:
|
||||
# TODO support for mustache and jinja2
|
||||
raise NotImplementedError(template_format)
|
||||
|
||||
return formatted_prompt
|
||||
|
||||
|
||||
def initialize_message_sequence(
|
||||
model: str,
|
||||
system: str,
|
||||
agent_id: str,
|
||||
memory: Memory,
|
||||
actor: PydanticUser,
|
||||
agent_manager: Optional[AgentManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
memory_edit_timestamp: Optional[datetime.datetime] = None,
|
||||
include_initial_boot_message: bool = True,
|
||||
) -> List[dict]:
|
||||
if memory_edit_timestamp is None:
|
||||
memory_edit_timestamp = get_local_time()
|
||||
|
||||
# full_system_message = construct_system_with_memory(
|
||||
# system, memory, memory_edit_timestamp, agent_manager=agent_manager, recall_memory=recall_memory
|
||||
# )
|
||||
full_system_message = compile_system_message(
|
||||
agent_id=agent_id,
|
||||
system_prompt=system,
|
||||
in_context_memory=memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
actor=actor,
|
||||
agent_manager=agent_manager,
|
||||
message_manager=message_manager,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
)
|
||||
first_user_message = get_login_event() # event letting Letta know the user just logged in
|
||||
|
||||
if include_initial_boot_message:
|
||||
if model is not None and "gpt-3.5" in model:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
|
||||
else:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
|
||||
messages = (
|
||||
[
|
||||
{"role": "system", "content": full_system_message},
|
||||
]
|
||||
+ initial_boot_messages
|
||||
+ [
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
messages = [
|
||||
{"role": "system", "content": full_system_message},
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""
|
||||
Abstract class for all agents.
|
||||
Only two interfaces are required: step and update_state.
|
||||
Only one interface is required: step.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -238,10 +91,6 @@ class BaseAgent(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update_state(self) -> AgentState:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Agent(BaseAgent):
|
||||
def __init__(
|
||||
@@ -250,9 +99,7 @@ class Agent(BaseAgent):
|
||||
agent_state: AgentState, # in-memory representation of the agent state (read from multiple tables)
|
||||
user: User,
|
||||
# extras
|
||||
messages_total: Optional[int] = None, # TODO remove?
|
||||
first_message_verify_mono: bool = True, # TODO move to config?
|
||||
initial_message_sequence: Optional[List[Message]] = None,
|
||||
):
|
||||
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
|
||||
# Hold a copy of the state that was used to init the agent
|
||||
@@ -276,7 +123,7 @@ class Agent(BaseAgent):
|
||||
|
||||
# gpt-4, gpt-3.5-turbo, ...
|
||||
self.model = self.agent_state.llm_config.model
|
||||
self.check_tool_rules()
|
||||
self.supports_structured_output = check_supports_structured_output(model=self.model, tool_rules=agent_state.tool_rules)
|
||||
|
||||
# state managers
|
||||
self.block_manager = BlockManager()
|
||||
@@ -304,99 +151,14 @@ class Agent(BaseAgent):
|
||||
# When the summarizer is run, set this back to False (to reset)
|
||||
self.agent_alerted_about_memory_pressure = False
|
||||
|
||||
self._messages: List[Message] = []
|
||||
|
||||
# Once the memory object is initialized, use it to "bake" the system message
|
||||
if self.agent_state.message_ids is not None:
|
||||
self.set_message_buffer(message_ids=self.agent_state.message_ids)
|
||||
|
||||
else:
|
||||
printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}")
|
||||
assert self.agent_state.id is not None and self.agent_state.created_by_id is not None
|
||||
|
||||
# Generate a sequence of initial messages to put in the buffer
|
||||
init_messages = initialize_message_sequence(
|
||||
model=self.model,
|
||||
system=self.agent_state.system,
|
||||
agent_id=self.agent_state.id,
|
||||
memory=self.agent_state.memory,
|
||||
actor=self.user,
|
||||
agent_manager=None,
|
||||
message_manager=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
)
|
||||
|
||||
if initial_message_sequence is not None:
|
||||
# We always need the system prompt up front
|
||||
system_message_obj = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict=init_messages[0],
|
||||
)
|
||||
# Don't use anything else in the pregen sequence, instead use the provided sequence
|
||||
init_messages = [system_message_obj] + initial_message_sequence
|
||||
|
||||
else:
|
||||
# Basic "more human than human" initial message sequence
|
||||
init_messages = initialize_message_sequence(
|
||||
model=self.model,
|
||||
system=self.agent_state.system,
|
||||
memory=self.agent_state.memory,
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
agent_manager=None,
|
||||
message_manager=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
)
|
||||
# Cast to Message objects
|
||||
init_messages = [
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id, user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=msg
|
||||
)
|
||||
for msg in init_messages
|
||||
]
|
||||
|
||||
# Cast the messages to actual Message objects to be synced to the DB
|
||||
init_messages_objs = []
|
||||
for msg in init_messages:
|
||||
init_messages_objs.append(msg)
|
||||
for msg in init_messages_objs:
|
||||
assert isinstance(msg, Message), f"Message object is not of type Message: {type(msg)}"
|
||||
assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages)
|
||||
|
||||
# Put the messages inside the message buffer
|
||||
self.messages_total = 0
|
||||
self._append_to_messages(added_messages=init_messages_objs)
|
||||
self._validate_message_buffer_is_utc()
|
||||
|
||||
# Load last function response from message history
|
||||
self.last_function_response = self.load_last_function_response()
|
||||
|
||||
# Keep track of the total number of messages throughout all time
|
||||
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
|
||||
self.messages_total_init = len(self._messages) - 1
|
||||
printd(f"Agent initialized, self.messages_total={self.messages_total}")
|
||||
|
||||
# Create the agent in the DB
|
||||
self.update_state()
|
||||
|
||||
def check_tool_rules(self):
|
||||
if self.model not in STRUCTURED_OUTPUT_MODELS:
|
||||
if len(self.tool_rules_solver.init_tool_rules) > 1:
|
||||
raise ValueError(
|
||||
"Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule."
|
||||
)
|
||||
self.supports_structured_output = False
|
||||
else:
|
||||
self.supports_structured_output = True
|
||||
|
||||
def load_last_function_response(self):
|
||||
"""Load the last function response from message history"""
|
||||
for i in range(len(self._messages) - 1, -1, -1):
|
||||
msg = self._messages[i]
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
for i in range(len(in_context_messages) - 1, -1, -1):
|
||||
msg = in_context_messages[i]
|
||||
if msg.role == MessageRole.tool and msg.text:
|
||||
try:
|
||||
response_json = json.loads(msg.text)
|
||||
@@ -435,7 +197,7 @@ class Agent(BaseAgent):
|
||||
# NOTE: don't do this since re-buildin the memory is handled at the start of the step
|
||||
# rebuild memory - this records the last edited timestamp of the memory
|
||||
# TODO: pass in update timestamp from block edit time
|
||||
self.rebuild_system_prompt()
|
||||
self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)
|
||||
|
||||
return True
|
||||
return False
|
||||
@@ -487,109 +249,6 @@ class Agent(BaseAgent):
|
||||
|
||||
return function_response
|
||||
|
||||
@property
|
||||
def messages(self) -> List[dict]:
|
||||
"""Getter method that converts the internal Message list into OpenAI-style dicts"""
|
||||
return [msg.to_openai_dict() for msg in self._messages]
|
||||
|
||||
@messages.setter
|
||||
def messages(self, value):
|
||||
raise Exception("Modifying message list directly not allowed")
|
||||
|
||||
def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]:
|
||||
"""Load a list of messages from recall storage"""
|
||||
|
||||
# Pull the message objects from the database
|
||||
message_objs = []
|
||||
for msg_id in message_ids:
|
||||
msg_obj = self.message_manager.get_message_by_id(msg_id, actor=self.user)
|
||||
if msg_obj:
|
||||
if isinstance(msg_obj, Message):
|
||||
message_objs.append(msg_obj)
|
||||
else:
|
||||
printd(f"Warning - message ID {msg_id} is not a Message object")
|
||||
warnings.warn(f"Warning - message ID {msg_id} is not a Message object")
|
||||
else:
|
||||
printd(f"Warning - message ID {msg_id} not found in recall storage")
|
||||
warnings.warn(f"Warning - message ID {msg_id} not found in recall storage")
|
||||
|
||||
return message_objs
|
||||
|
||||
def _validate_message_buffer_is_utc(self):
|
||||
"""Iterate over the message buffer and force all messages to be UTC stamped"""
|
||||
|
||||
for m in self._messages:
|
||||
# assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
|
||||
# TODO eventually do casting via an edit_message function
|
||||
if m.created_at:
|
||||
if not is_utc_datetime(m.created_at):
|
||||
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
|
||||
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
|
||||
|
||||
def set_message_buffer(self, message_ids: List[str], force_utc: bool = True):
|
||||
"""Set the messages in the buffer to the message IDs list"""
|
||||
|
||||
message_objs = self._load_messages_from_recall(message_ids=message_ids)
|
||||
|
||||
# set the objects in the buffer
|
||||
self._messages = message_objs
|
||||
|
||||
# bugfix for old agents that may not have had UTC specified in their timestamps
|
||||
if force_utc:
|
||||
self._validate_message_buffer_is_utc()
|
||||
|
||||
# also sync the message IDs attribute
|
||||
self.agent_state.message_ids = message_ids
|
||||
|
||||
def refresh_message_buffer(self):
|
||||
"""Refresh the message buffer from the database"""
|
||||
|
||||
messages_to_sync = self.agent_state.message_ids
|
||||
assert messages_to_sync and all([isinstance(msg_id, str) for msg_id in messages_to_sync])
|
||||
|
||||
self.set_message_buffer(message_ids=messages_to_sync)
|
||||
|
||||
def _trim_messages(self, num):
|
||||
"""Trim messages from the front, not including the system message"""
|
||||
new_messages = [self._messages[0]] + self._messages[num:]
|
||||
self._messages = new_messages
|
||||
|
||||
def _prepend_to_messages(self, added_messages: List[Message]):
|
||||
"""Wrapper around self.messages.prepend to allow additional calls to a state/persistence manager"""
|
||||
assert all([isinstance(msg, Message) for msg in added_messages])
|
||||
self.message_manager.create_many_messages(added_messages, actor=self.user)
|
||||
|
||||
new_messages = [self._messages[0]] + added_messages + self._messages[1:] # prepend (no system)
|
||||
self._messages = new_messages
|
||||
self.messages_total += len(added_messages) # still should increment the message counter (summaries are additions too)
|
||||
|
||||
def _append_to_messages(self, added_messages: List[Message]):
|
||||
"""Wrapper around self.messages.append to allow additional calls to a state/persistence manager"""
|
||||
assert all([isinstance(msg, Message) for msg in added_messages])
|
||||
self.message_manager.create_many_messages(added_messages, actor=self.user)
|
||||
|
||||
# strip extra metadata if it exists
|
||||
# for msg in added_messages:
|
||||
# msg.pop("api_response", None)
|
||||
# msg.pop("api_args", None)
|
||||
new_messages = self._messages + added_messages # append
|
||||
|
||||
self._messages = new_messages
|
||||
self.messages_total += len(added_messages)
|
||||
|
||||
def append_to_messages(self, added_messages: List[dict]):
|
||||
"""An external-facing message append, where dict-like messages are first converted to Message objects"""
|
||||
added_messages_objs = [
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict=msg,
|
||||
)
|
||||
for msg in added_messages
|
||||
]
|
||||
self._append_to_messages(added_messages_objs)
|
||||
|
||||
def _get_ai_reply(
|
||||
self,
|
||||
message_sequence: List[Message],
|
||||
@@ -898,7 +557,7 @@ class Agent(BaseAgent):
|
||||
|
||||
# rebuild memory
|
||||
# TODO: @charles please check this
|
||||
self.rebuild_system_prompt()
|
||||
self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)
|
||||
|
||||
# Update ToolRulesSolver state with last called function
|
||||
self.tool_rules_solver.update_tool_usage(function_name)
|
||||
@@ -930,6 +589,7 @@ class Agent(BaseAgent):
|
||||
messages=next_input_message,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
heartbeat_request = step_response.heartbeat_request
|
||||
function_failed = step_response.function_failed
|
||||
token_warning = step_response.in_context_memory_warning
|
||||
@@ -1021,33 +681,19 @@ class Agent(BaseAgent):
|
||||
if not all(isinstance(m, Message) for m in messages):
|
||||
raise ValueError(f"messages should be a Message or a list of Message, got {type(messages)}")
|
||||
|
||||
input_message_sequence = self._messages + messages
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
input_message_sequence = in_context_messages + messages
|
||||
|
||||
if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
|
||||
printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
|
||||
|
||||
# Step 2: send the conversation and available functions to the LLM
|
||||
if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
|
||||
printd(f"This is the first message. Running extra verifier on AI response.")
|
||||
counter = 0
|
||||
while True:
|
||||
response = self._get_ai_reply(
|
||||
message_sequence=input_message_sequence, first_message=True, stream=stream # passed through to the prompt formatter
|
||||
)
|
||||
if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
|
||||
break
|
||||
|
||||
counter += 1
|
||||
if counter > first_message_retry_limit:
|
||||
raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")
|
||||
|
||||
else:
|
||||
response = self._get_ai_reply(
|
||||
message_sequence=input_message_sequence,
|
||||
first_message=first_message,
|
||||
stream=stream,
|
||||
step_count=step_count,
|
||||
)
|
||||
response = self._get_ai_reply(
|
||||
message_sequence=input_message_sequence,
|
||||
first_message=first_message,
|
||||
stream=stream,
|
||||
step_count=step_count,
|
||||
)
|
||||
|
||||
# Step 3: check if LLM wanted to call a function
|
||||
# (if yes) Step 4: call the function
|
||||
@@ -1095,10 +741,9 @@ class Agent(BaseAgent):
|
||||
f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
|
||||
)
|
||||
|
||||
self._append_to_messages(all_new_messages)
|
||||
|
||||
# update state after each step
|
||||
self.update_state()
|
||||
self.agent_state = self.agent_manager.append_to_in_context_messages(
|
||||
all_new_messages, agent_id=self.agent_state.id, actor=self.user
|
||||
)
|
||||
|
||||
return AgentStepResponse(
|
||||
messages=all_new_messages,
|
||||
@@ -1113,7 +758,9 @@ class Agent(BaseAgent):
|
||||
|
||||
# If we got a context alert, try trimming the messages length, then try again
|
||||
if is_context_overflow_error(e):
|
||||
printd(f"context window exceeded with limit {self.agent_state.llm_config.context_window}, running summarizer to trim messages")
|
||||
printd(
|
||||
f"context window exceeded with limit {self.agent_state.llm_config.context_window}, running summarizer to trim messages"
|
||||
)
|
||||
# A separate API call to run a summarizer
|
||||
self.summarize_messages_inplace()
|
||||
|
||||
@@ -1165,15 +812,19 @@ class Agent(BaseAgent):
|
||||
return self.inner_step(messages=[user_message], **kwargs)
|
||||
|
||||
def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True):
|
||||
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
||||
|
||||
if in_context_messages_openai[0]["role"] != "system":
|
||||
raise RuntimeError(f"in_context_messages_openai[0] should be system (instead got {in_context_messages_openai[0]})")
|
||||
|
||||
# Start at index 1 (past the system message),
|
||||
# and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
|
||||
# Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling
|
||||
token_counts = [count_tokens(str(msg)) for msg in self.messages]
|
||||
token_counts = [count_tokens(str(msg)) for msg in in_context_messages_openai]
|
||||
message_buffer_token_count = sum(token_counts[1:]) # no system message
|
||||
desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC)
|
||||
candidate_messages_to_summarize = self.messages[1:]
|
||||
candidate_messages_to_summarize = in_context_messages_openai[1:]
|
||||
token_counts = token_counts[1:]
|
||||
|
||||
if preserve_last_N_messages:
|
||||
@@ -1193,7 +844,7 @@ class Agent(BaseAgent):
|
||||
"Not enough messages to compress for summarization",
|
||||
details={
|
||||
"num_candidate_messages": len(candidate_messages_to_summarize),
|
||||
"num_total_messages": len(self.messages),
|
||||
"num_total_messages": len(in_context_messages_openai),
|
||||
"preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
|
||||
},
|
||||
)
|
||||
@@ -1212,9 +863,9 @@ class Agent(BaseAgent):
|
||||
# Try to make an assistant message come after the cutoff
|
||||
try:
|
||||
printd(f"Selected cutoff {cutoff} was a 'user', shifting one...")
|
||||
if self.messages[cutoff]["role"] == "user":
|
||||
if in_context_messages_openai[cutoff]["role"] == "user":
|
||||
new_cutoff = cutoff + 1
|
||||
if self.messages[new_cutoff]["role"] == "user":
|
||||
if in_context_messages_openai[new_cutoff]["role"] == "user":
|
||||
printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...")
|
||||
cutoff = new_cutoff
|
||||
except IndexError:
|
||||
@@ -1222,23 +873,23 @@ class Agent(BaseAgent):
|
||||
|
||||
# Make sure the cutoff isn't on a 'tool' or 'function'
|
||||
if disallow_tool_as_first:
|
||||
while self.messages[cutoff]["role"] in ["tool", "function"] and cutoff < len(self.messages):
|
||||
while in_context_messages_openai[cutoff]["role"] in ["tool", "function"] and cutoff < len(in_context_messages_openai):
|
||||
printd(f"Selected cutoff {cutoff} was a 'tool', shifting one...")
|
||||
cutoff += 1
|
||||
|
||||
message_sequence_to_summarize = self._messages[1:cutoff] # do NOT get rid of the system message
|
||||
message_sequence_to_summarize = in_context_messages[1:cutoff] # do NOT get rid of the system message
|
||||
if len(message_sequence_to_summarize) <= 1:
|
||||
# This prevents a potential infinite loop of summarizing the same message over and over
|
||||
raise ContextWindowExceededError(
|
||||
"Not enough messages to compress for summarization after determining cutoff",
|
||||
details={
|
||||
"num_candidate_messages": len(message_sequence_to_summarize),
|
||||
"num_total_messages": len(self.messages),
|
||||
"num_total_messages": len(in_context_messages_openai),
|
||||
"preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
|
||||
},
|
||||
)
|
||||
else:
|
||||
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self._messages)}")
|
||||
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(in_context_messages)}")
|
||||
|
||||
# We can't do summarize logic properly if context_window is undefined
|
||||
if self.agent_state.llm_config.context_window is None:
|
||||
@@ -1253,118 +904,33 @@ class Agent(BaseAgent):
|
||||
printd(f"Got summary: {summary}")
|
||||
|
||||
# Metadata that's useful for the agent to see
|
||||
all_time_message_count = self.messages_total
|
||||
remaining_message_count = len(self.messages[cutoff:])
|
||||
all_time_message_count = self.message_manager.size(agent_id=self.agent_state.id, actor=self.user)
|
||||
remaining_message_count = len(in_context_messages_openai[cutoff:])
|
||||
hidden_message_count = all_time_message_count - remaining_message_count
|
||||
summary_message_count = len(message_sequence_to_summarize)
|
||||
summary_message = package_summarize_message(summary, summary_message_count, hidden_message_count, all_time_message_count)
|
||||
printd(f"Packaged into message: {summary_message}")
|
||||
|
||||
prior_len = len(self.messages)
|
||||
self._trim_messages(cutoff)
|
||||
prior_len = len(in_context_messages_openai)
|
||||
self.agent_state = self.agent_manager.trim_older_in_context_messages(cutoff, agent_id=self.agent_state.id, actor=self.user)
|
||||
packed_summary_message = {"role": "user", "content": summary_message}
|
||||
self._prepend_to_messages(
|
||||
[
|
||||
self.agent_state = self.agent_manager.prepend_to_in_context_messages(
|
||||
messages=[
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict=packed_summary_message,
|
||||
)
|
||||
]
|
||||
],
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
)
|
||||
|
||||
# reset alert
|
||||
self.agent_alerted_about_memory_pressure = False
|
||||
|
||||
printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}")
|
||||
|
||||
def _swap_system_message_in_buffer(self, new_system_message: str):
|
||||
"""Update the system message (NOT prompt) of the Agent (requires updating the internal buffer)"""
|
||||
assert isinstance(new_system_message, str)
|
||||
new_system_message_obj = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict={"role": "system", "content": new_system_message},
|
||||
)
|
||||
|
||||
assert new_system_message_obj.role == "system", new_system_message_obj
|
||||
assert self._messages[0].role == "system", self._messages
|
||||
|
||||
self.message_manager.create_message(new_system_message_obj, actor=self.user)
|
||||
|
||||
new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system)
|
||||
self._messages = new_messages
|
||||
|
||||
def rebuild_system_prompt(self, force=False, update_timestamp=True):
|
||||
"""Rebuilds the system message with the latest memory object and any shared memory block updates
|
||||
|
||||
Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
|
||||
|
||||
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
|
||||
"""
|
||||
|
||||
curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt
|
||||
|
||||
# note: we only update the system prompt if the core memory is changed
|
||||
# this means that the archival/recall memory statistics may be someout out of date
|
||||
curr_memory_str = self.agent_state.memory.compile()
|
||||
if curr_memory_str in curr_system_message["content"] and not force:
|
||||
# NOTE: could this cause issues if a block is removed? (substring match would still work)
|
||||
printd(f"Memory hasn't changed, skipping system prompt rebuild")
|
||||
return
|
||||
|
||||
# If the memory didn't update, we probably don't want to update the timestamp inside
|
||||
# For example, if we're doing a system prompt swap, this should probably be False
|
||||
if update_timestamp:
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
else:
|
||||
# NOTE: a bit of a hack - we pull the timestamp from the message created_by
|
||||
memory_edit_timestamp = self._messages[0].created_at
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats separately)
|
||||
new_system_message_str = compile_system_message(
|
||||
agent_id=self.agent_state.id,
|
||||
system_prompt=self.agent_state.system,
|
||||
in_context_memory=self.agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
actor=self.user,
|
||||
agent_manager=self.agent_manager,
|
||||
message_manager=self.message_manager,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
)
|
||||
new_system_message = {
|
||||
"role": "system",
|
||||
"content": new_system_message_str,
|
||||
}
|
||||
|
||||
diff = united_diff(curr_system_message["content"], new_system_message["content"])
|
||||
if len(diff) > 0: # there was a diff
|
||||
printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
|
||||
# Swap the system message out (only if there is a diff)
|
||||
self._swap_system_message_in_buffer(new_system_message=new_system_message_str)
|
||||
assert self.messages[0]["content"] == new_system_message["content"], (
|
||||
self.messages[0]["content"],
|
||||
new_system_message["content"],
|
||||
)
|
||||
|
||||
def update_system_prompt(self, new_system_prompt: str):
|
||||
"""Update the system prompt of the agent (requires rebuilding the memory block if there's a difference)"""
|
||||
assert isinstance(new_system_prompt, str)
|
||||
|
||||
if new_system_prompt == self.agent_state.system:
|
||||
return
|
||||
|
||||
self.agent_state.system = new_system_prompt
|
||||
|
||||
# updating the system prompt requires rebuilding the memory block inside the compiled system message
|
||||
self.rebuild_system_prompt(force=True, update_timestamp=False)
|
||||
|
||||
# make sure to persist the change
|
||||
_ = self.update_state()
|
||||
printd(f"Ran summarizer, messages length {prior_len} -> {len(in_context_messages_openai)}")
|
||||
|
||||
def add_function(self, function_name: str) -> str:
|
||||
# TODO: refactor
|
||||
@@ -1374,20 +940,6 @@ class Agent(BaseAgent):
|
||||
# TODO: refactor
|
||||
raise NotImplementedError
|
||||
|
||||
def update_state(self) -> AgentState:
|
||||
# TODO: this should be removed and self._messages should be moved into self.agent_state.in_context_messages
|
||||
message_ids = [msg.id for msg in self._messages]
|
||||
|
||||
# Assert that these are all strings
|
||||
if any(not isinstance(m_id, str) for m_id in message_ids):
|
||||
warnings.warn(f"Non-string message IDs found in agent state: {message_ids}")
|
||||
message_ids = [m_id for m_id in message_ids if isinstance(m_id, str)]
|
||||
|
||||
# override any fields that may have been updated
|
||||
self.agent_state.message_ids = message_ids
|
||||
|
||||
return self.agent_state
|
||||
|
||||
def migrate_embedding(self, embedding_config: EmbeddingConfig):
|
||||
"""Migrate the agent to a new embedding"""
|
||||
# TODO: archival memory
|
||||
@@ -1421,123 +973,6 @@ class Agent(BaseAgent):
|
||||
f"Attached data source {source.name} to agent {self.agent_state.name}.",
|
||||
)
|
||||
|
||||
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
|
||||
"""Update the details of a message associated with an agent"""
|
||||
# Save the updated message
|
||||
updated_message = self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=self.user)
|
||||
return updated_message
|
||||
|
||||
# TODO(sarah): should we be creating a new message here, or just editing a message?
|
||||
def rethink_message(self, new_thought: str) -> Message:
|
||||
"""Rethink / update the last message"""
|
||||
for x in range(len(self.messages) - 1, 0, -1):
|
||||
msg_obj = self._messages[x]
|
||||
if msg_obj.role == MessageRole.assistant:
|
||||
updated_message = self.update_message(
|
||||
message_id=msg_obj.id,
|
||||
request=MessageUpdate(
|
||||
text=new_thought,
|
||||
),
|
||||
)
|
||||
self.refresh_message_buffer()
|
||||
return updated_message
|
||||
raise ValueError(f"No assistant message found to update")
|
||||
|
||||
# TODO(sarah): should we be creating a new message here, or just editing a message?
|
||||
def rewrite_message(self, new_text: str) -> Message:
|
||||
"""Rewrite / update the send_message text on the last message"""
|
||||
|
||||
# Walk backwards through the messages until we find an assistant message
|
||||
for x in range(len(self._messages) - 1, 0, -1):
|
||||
if self._messages[x].role == MessageRole.assistant:
|
||||
# Get the current message content
|
||||
message_obj = self._messages[x]
|
||||
|
||||
# The rewrite target is the output of send_message
|
||||
if message_obj.tool_calls is not None and len(message_obj.tool_calls) > 0:
|
||||
|
||||
# Check that we hit an assistant send_message call
|
||||
name_string = message_obj.tool_calls[0].function.name
|
||||
if name_string is None or name_string != "send_message":
|
||||
raise ValueError("Assistant missing send_message function call")
|
||||
|
||||
args_string = message_obj.tool_calls[0].function.arguments
|
||||
if args_string is None:
|
||||
raise ValueError("Assistant missing send_message function arguments")
|
||||
|
||||
args_json = json_loads(args_string)
|
||||
if "message" not in args_json:
|
||||
raise ValueError("Assistant missing send_message message argument")
|
||||
|
||||
# Once we found our target, rewrite it
|
||||
args_json["message"] = new_text
|
||||
new_args_string = json_dumps(args_json)
|
||||
message_obj.tool_calls[0].function.arguments = new_args_string
|
||||
|
||||
# Write the update to the DB
|
||||
updated_message = self.update_message(
|
||||
message_id=message_obj.id,
|
||||
request=MessageUpdate(
|
||||
tool_calls=message_obj.tool_calls,
|
||||
),
|
||||
)
|
||||
self.refresh_message_buffer()
|
||||
return updated_message
|
||||
|
||||
raise ValueError("No assistant message found to update")
|
||||
|
||||
def pop_message(self, count: int = 1) -> List[Message]:
|
||||
"""Pop the last N messages from the agent's memory"""
|
||||
n_messages = len(self._messages)
|
||||
popped_messages = []
|
||||
MIN_MESSAGES = 2
|
||||
if n_messages <= MIN_MESSAGES:
|
||||
raise ValueError(f"Agent only has {n_messages} messages in stack, none left to pop")
|
||||
elif n_messages - count < MIN_MESSAGES:
|
||||
raise ValueError(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}")
|
||||
else:
|
||||
# print(f"Popping last {count} messages from stack")
|
||||
for _ in range(min(count, len(self._messages))):
|
||||
# remove the message from the internal state of the agent
|
||||
deleted_message = self._messages.pop()
|
||||
# then also remove it from recall storage
|
||||
try:
|
||||
self.message_manager.delete_message_by_id(deleted_message.id, actor=self.user)
|
||||
popped_messages.append(deleted_message)
|
||||
except Exception as e:
|
||||
warnings.warn(f"Error deleting message {deleted_message.id} from recall memory: {e}")
|
||||
self._messages.append(deleted_message)
|
||||
break
|
||||
|
||||
return popped_messages
|
||||
|
||||
def pop_until_user(self) -> List[Message]:
|
||||
"""Pop all messages until the last user message"""
|
||||
if MessageRole.user not in [msg.role for msg in self._messages]:
|
||||
raise ValueError("No user message found in buffer")
|
||||
|
||||
popped_messages = []
|
||||
while len(self._messages) > 0:
|
||||
if self._messages[-1].role == MessageRole.user:
|
||||
# we want to pop up to the last user message
|
||||
return popped_messages
|
||||
else:
|
||||
popped_messages.append(self.pop_message(count=1))
|
||||
|
||||
raise ValueError("No user message found in buffer")
|
||||
|
||||
def retry_message(self) -> List[Message]:
|
||||
"""Retry / regenerate the last message"""
|
||||
self.pop_until_user()
|
||||
user_message = self.pop_message(count=1)[0]
|
||||
assert user_message.text is not None, "User message text is None"
|
||||
step_response = self.step_user_message(user_message_str=user_message.text)
|
||||
messages = step_response.messages
|
||||
|
||||
assert messages is not None
|
||||
assert all(isinstance(msg, Message) for msg in messages), "step() returned non-Message objects"
|
||||
return messages
|
||||
|
||||
def get_context_window(self) -> ContextWindowOverview:
|
||||
"""Get the context window of the agent"""
|
||||
|
||||
@@ -1546,24 +981,28 @@ class Agent(BaseAgent):
|
||||
core_memory = self.agent_state.memory.compile()
|
||||
num_tokens_core_memory = count_tokens(core_memory)
|
||||
|
||||
# Grab the in-context messages
|
||||
# conversion of messages to OpenAI dict format, which is passed to the token counter
|
||||
messages_openai_format = self.messages
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
||||
|
||||
# Check if there's a summary message in the message queue
|
||||
if (
|
||||
len(self._messages) > 1
|
||||
and self._messages[1].role == MessageRole.user
|
||||
and isinstance(self._messages[1].text, str)
|
||||
len(in_context_messages) > 1
|
||||
and in_context_messages[1].role == MessageRole.user
|
||||
and isinstance(in_context_messages[1].text, str)
|
||||
# TODO remove hardcoding
|
||||
and "The following is a summary of the previous " in self._messages[1].text
|
||||
and "The following is a summary of the previous " in in_context_messages[1].text
|
||||
):
|
||||
# Summary message exists
|
||||
assert self._messages[1].text is not None
|
||||
summary_memory = self._messages[1].text
|
||||
num_tokens_summary_memory = count_tokens(self._messages[1].text)
|
||||
assert in_context_messages[1].text is not None
|
||||
summary_memory = in_context_messages[1].text
|
||||
num_tokens_summary_memory = count_tokens(in_context_messages[1].text)
|
||||
# with a summary message, the real messages start at index 2
|
||||
num_tokens_messages = (
|
||||
num_tokens_from_messages(messages=messages_openai_format[2:], model=self.model) if len(messages_openai_format) > 2 else 0
|
||||
num_tokens_from_messages(messages=in_context_messages_openai[2:], model=self.model)
|
||||
if len(in_context_messages_openai) > 2
|
||||
else 0
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -1571,17 +1010,17 @@ class Agent(BaseAgent):
|
||||
num_tokens_summary_memory = 0
|
||||
# with no summary message, the real messages start at index 1
|
||||
num_tokens_messages = (
|
||||
num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
|
||||
num_tokens_from_messages(messages=in_context_messages_openai[1:], model=self.model)
|
||||
if len(in_context_messages_openai) > 1
|
||||
else 0
|
||||
)
|
||||
|
||||
agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id)
|
||||
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
||||
external_memory_summary = compile_memory_metadata_block(
|
||||
actor=self.user,
|
||||
agent_id=self.agent_state.id,
|
||||
memory_edit_timestamp=get_utc_time(), # dummy timestamp
|
||||
agent_manager=self.agent_manager,
|
||||
message_manager=self.message_manager,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
previous_message_count=self.message_manager.size(actor=self.user, agent_id=self.agent_state.id),
|
||||
archival_memory_size=self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id),
|
||||
)
|
||||
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
|
||||
|
||||
@@ -1606,7 +1045,7 @@ class Agent(BaseAgent):
|
||||
|
||||
return ContextWindowOverview(
|
||||
# context window breakdown (in messages)
|
||||
num_messages=len(self._messages),
|
||||
num_messages=len(in_context_messages),
|
||||
num_archival_memory=agent_manager_passage_size,
|
||||
num_recall_memory=message_manager_size,
|
||||
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
||||
@@ -1621,7 +1060,7 @@ class Agent(BaseAgent):
|
||||
num_tokens_summary_memory=num_tokens_summary_memory,
|
||||
summary_memory=summary_memory,
|
||||
num_tokens_messages=num_tokens_messages,
|
||||
messages=self._messages,
|
||||
messages=in_context_messages,
|
||||
# related to functions
|
||||
num_tokens_functions_definitions=num_tokens_available_functions_definitions,
|
||||
functions_definitions=available_functions_definitions,
|
||||
@@ -1635,7 +1074,6 @@ class Agent(BaseAgent):
|
||||
|
||||
def save_agent(agent: Agent):
|
||||
"""Save agent to metadata store"""
|
||||
agent.update_state()
|
||||
agent_state = agent.agent_state
|
||||
assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
|
||||
|
||||
|
||||
@@ -63,8 +63,8 @@ class ChatOnlyAgent(Agent):
|
||||
conversation_persona_block_new = Block(
|
||||
name="chat_agent_persona_new", label="chat_agent_persona_new", value=conversation_persona_block.value, limit=2000
|
||||
)
|
||||
|
||||
recent_convo = "".join([str(message) for message in self.messages[3:]])[-self.recent_convo_limit :]
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
recent_convo = "".join([str(message) for message in in_context_messages[3:]])[-self.recent_convo_limit :]
|
||||
conversation_messages_block = Block(
|
||||
name="conversation_block", label="conversation_block", value=recent_convo, limit=self.recent_convo_limit
|
||||
)
|
||||
|
||||
@@ -2234,7 +2234,7 @@ class LocalClient(AbstractClient):
|
||||
"""
|
||||
# TODO: add the abilitty to reset linked block_ids
|
||||
self.interface.clear()
|
||||
agent_state = self.server.update_agent(
|
||||
agent_state = self.server.agent_manager.update_agent(
|
||||
agent_id,
|
||||
UpdateAgent(
|
||||
name=name,
|
||||
@@ -2262,7 +2262,7 @@ class LocalClient(AbstractClient):
|
||||
List[Tool]: A list of Tool objs
|
||||
"""
|
||||
self.interface.clear()
|
||||
return self.server.get_tools_from_agent(agent_id=agent_id, user_id=self.user_id)
|
||||
return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user).tools
|
||||
|
||||
def add_tool_to_agent(self, agent_id: str, tool_id: str):
|
||||
"""
|
||||
@@ -2276,7 +2276,7 @@ class LocalClient(AbstractClient):
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
self.interface.clear()
|
||||
agent_state = self.server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id)
|
||||
agent_state = self.server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user)
|
||||
return agent_state
|
||||
|
||||
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
|
||||
@@ -2291,7 +2291,7 @@ class LocalClient(AbstractClient):
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
self.interface.clear()
|
||||
agent_state = self.server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id)
|
||||
agent_state = self.server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user)
|
||||
return agent_state
|
||||
|
||||
def rename_agent(self, agent_id: str, new_name: str):
|
||||
@@ -2426,7 +2426,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
messages (List[Message]): List of in-context messages
|
||||
"""
|
||||
return self.server.get_in_context_messages(agent_id=agent_id, actor=self.user)
|
||||
return self.server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=self.user)
|
||||
|
||||
# agent interactions
|
||||
|
||||
@@ -3075,7 +3075,7 @@ class LocalClient(AbstractClient):
|
||||
agent_id (str): ID of the agent
|
||||
memory_id (str): ID of the memory
|
||||
"""
|
||||
self.server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=self.user)
|
||||
self.server.delete_archival_memory(memory_id=memory_id, actor=self.user)
|
||||
|
||||
def get_archival_memory(
|
||||
self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000
|
||||
|
||||
@@ -194,46 +194,6 @@ def run_agent_loop(
|
||||
print(f"Current model: {letta_agent.agent_state.llm_config.model}")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/pop" or user_input.lower().startswith("/pop "):
|
||||
# Check if there's an additional argument that's an integer
|
||||
command = user_input.strip().split()
|
||||
pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3
|
||||
try:
|
||||
popped_messages = letta_agent.pop_message(count=pop_amount)
|
||||
except ValueError as e:
|
||||
print(f"Error popping messages: {e}")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/retry":
|
||||
print(f"Retrying for another answer...")
|
||||
try:
|
||||
letta_agent.retry_message()
|
||||
except Exception as e:
|
||||
print(f"Error retrying message: {e}")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/rethink" or user_input.lower().startswith("/rethink "):
|
||||
if len(user_input) < len("/rethink "):
|
||||
print("Missing text after the command")
|
||||
continue
|
||||
try:
|
||||
letta_agent.rethink_message(new_thought=user_input[len("/rethink ") :].strip())
|
||||
except Exception as e:
|
||||
print(f"Error rethinking message: {e}")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/rewrite" or user_input.lower().startswith("/rewrite "):
|
||||
if len(user_input) < len("/rewrite "):
|
||||
print("Missing text after the command")
|
||||
continue
|
||||
|
||||
text = user_input[len("/rewrite ") :].strip()
|
||||
try:
|
||||
letta_agent.rewrite_message(new_text=text)
|
||||
except Exception as e:
|
||||
print(f"Error rewriting message: {e}")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/summarize":
|
||||
try:
|
||||
letta_agent.summarize_messages_inplace()
|
||||
@@ -319,42 +279,6 @@ def run_agent_loop(
|
||||
questionary.print(cmd, "bold")
|
||||
questionary.print(f" {desc}")
|
||||
continue
|
||||
|
||||
elif user_input.lower().startswith("/systemswap"):
|
||||
if len(user_input) < len("/systemswap "):
|
||||
print("Missing new system prompt after the command")
|
||||
continue
|
||||
old_system_prompt = letta_agent.system
|
||||
new_system_prompt = user_input[len("/systemswap ") :].strip()
|
||||
|
||||
# Show warning and prompts to user
|
||||
typer.secho(
|
||||
"\nWARNING: You are about to change the system prompt.",
|
||||
# fg=typer.colors.BRIGHT_YELLOW,
|
||||
bold=True,
|
||||
)
|
||||
typer.secho(
|
||||
f"\nOld system prompt:\n{old_system_prompt}",
|
||||
fg=typer.colors.RED,
|
||||
bold=True,
|
||||
)
|
||||
typer.secho(
|
||||
f"\nNew system prompt:\n{new_system_prompt}",
|
||||
fg=typer.colors.GREEN,
|
||||
bold=True,
|
||||
)
|
||||
|
||||
# Ask for confirmation
|
||||
confirm = questionary.confirm("Do you want to proceed with the swap?").ask()
|
||||
|
||||
if confirm:
|
||||
letta_agent.update_system_prompt(new_system_prompt=new_system_prompt)
|
||||
print("System prompt updated successfully.")
|
||||
else:
|
||||
print("System prompt swap cancelled.")
|
||||
|
||||
continue
|
||||
|
||||
else:
|
||||
print(f"Unrecognized command: {user_input}")
|
||||
continue
|
||||
|
||||
@@ -129,9 +129,8 @@ class OfflineMemoryAgent(Agent):
|
||||
# extras
|
||||
first_message_verify_mono: bool = False,
|
||||
max_memory_rethinks: int = 10,
|
||||
initial_message_sequence: Optional[List[Message]] = None,
|
||||
):
|
||||
super().__init__(interface, agent_state, user, initial_message_sequence=initial_message_sequence)
|
||||
super().__init__(interface, agent_state, user)
|
||||
self.first_message_verify_mono = first_message_verify_mono
|
||||
self.max_memory_rethinks = max_memory_rethinks
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ def get_agent_context_window(
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)
|
||||
return server.get_agent_context_window(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
class CreateAgentRequest(CreateAgent):
|
||||
@@ -138,7 +138,7 @@ def update_agent(
|
||||
):
|
||||
"""Update an exsiting agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.update_agent(agent_id, update_agent, actor=actor)
|
||||
return server.agent_manager.update_agent(agent_id=agent_id, agent_update=update_agent, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="get_tools_from_agent")
|
||||
@@ -149,7 +149,7 @@ def get_tools_from_agent(
|
||||
):
|
||||
"""Get tools from an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.get_tools_from_agent(agent_id=agent_id, user_id=actor.id)
|
||||
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).tools
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent")
|
||||
@@ -161,7 +161,7 @@ def add_tool_to_agent(
|
||||
):
|
||||
"""Add tools to an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
|
||||
return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, user_id=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/remove-tool/{tool_id}", response_model=AgentState, operation_id="remove_tool_from_agent")
|
||||
@@ -173,7 +173,7 @@ def remove_tool_from_agent(
|
||||
):
|
||||
"""Add tools to an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
|
||||
return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=AgentState, operation_id="get_agent")
|
||||
@@ -232,7 +232,7 @@ def get_agent_in_context_messages(
|
||||
Retrieve the messages in the context of a specific agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.get_in_context_messages(agent_id=agent_id, actor=actor)
|
||||
return server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
# TODO: remove? can also get with agent blocks
|
||||
@@ -429,7 +429,7 @@ def delete_agent_archival_memory(
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=actor)
|
||||
server.delete_archival_memory(memory_id=memory_id, actor=actor)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
|
||||
|
||||
|
||||
@@ -479,8 +479,9 @@ def update_message(
|
||||
"""
|
||||
Update the details of a message associated with an agent.
|
||||
"""
|
||||
# TODO: Get rid of agent_id here, it's not really relevant
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request, actor=actor)
|
||||
return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
|
||||
|
||||
|
||||
@router.post(
|
||||
|
||||
@@ -40,14 +40,14 @@ from letta.providers import (
|
||||
VLLMChatCompletionsProvider,
|
||||
VLLMCompletionsProvider,
|
||||
)
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
# openai schemas
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.job import Job, JobUpdate
|
||||
from letta.schemas.letta_message import ToolReturnMessage, LettaMessage
|
||||
from letta.schemas.letta_message import LettaMessage, ToolReturnMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import (
|
||||
ArchivalMemorySummary,
|
||||
@@ -376,25 +376,6 @@ class SyncServer(Server):
|
||||
)
|
||||
)
|
||||
|
||||
def initialize_agent(self, agent_id, actor, interface: Union[AgentInterface, None] = None, initial_message_sequence=None) -> Agent:
|
||||
"""Initialize an agent from the database"""
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
|
||||
interface = interface or self.default_interface_factory()
|
||||
if agent_state.agent_type == AgentType.memgpt_agent:
|
||||
agent = Agent(agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence)
|
||||
elif agent_state.agent_type == AgentType.offline_memory_agent:
|
||||
agent = OfflineMemoryAgent(
|
||||
agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence
|
||||
)
|
||||
else:
|
||||
assert initial_message_sequence is None, f"Initial message sequence is not supported for O1Agents"
|
||||
agent = O1Agent(agent_state=agent_state, interface=interface, user=actor)
|
||||
|
||||
# Persist to agent
|
||||
save_agent(agent)
|
||||
return agent
|
||||
|
||||
def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
|
||||
"""Updated method to load agents from persisted storage"""
|
||||
agent_lock = self.per_agent_lock_manager.get_lock(agent_id)
|
||||
@@ -413,11 +394,6 @@ class SyncServer(Server):
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type {agent_state.agent_type}")
|
||||
|
||||
# Rebuild the system prompt - may be linked to new blocks now
|
||||
agent.rebuild_system_prompt()
|
||||
|
||||
# Persist to agent
|
||||
save_agent(agent)
|
||||
return agent
|
||||
|
||||
def _step(
|
||||
@@ -456,7 +432,7 @@ class SyncServer(Server):
|
||||
)
|
||||
|
||||
# save agent after step
|
||||
save_agent(letta_agent)
|
||||
# save_agent(letta_agent)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in server._step: {e}")
|
||||
@@ -790,129 +766,23 @@ class SyncServer(Server):
|
||||
|
||||
"""Create a new agent using a config"""
|
||||
# Invoke manager
|
||||
agent_state = self.agent_manager.create_agent(
|
||||
return self.agent_manager.create_agent(
|
||||
agent_create=request,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# create the agent object
|
||||
if request.initial_message_sequence is not None:
|
||||
# init_messages = [Message(user_id=user_id, agent_id=agent_state.id, role=message.role, text=message.text) for message in request.initial_message_sequence]
|
||||
init_messages = []
|
||||
for message in request.initial_message_sequence:
|
||||
|
||||
if message.role == MessageRole.user:
|
||||
packed_message = system.package_user_message(
|
||||
user_message=message.text,
|
||||
)
|
||||
elif message.role == MessageRole.system:
|
||||
packed_message = system.package_system_message(
|
||||
system_message=message.text,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
init_messages.append(Message(role=message.role, text=packed_message, agent_id=agent_state.id))
|
||||
# init_messages = [Message.dict_to_message(user_id=user_id, agent_id=agent_state.id, openai_message_dict=message.model_dump()) for message in request.initial_message_sequence]
|
||||
else:
|
||||
init_messages = None
|
||||
|
||||
# initialize the agent (generates initial message list with system prompt)
|
||||
if interface is None:
|
||||
interface = self.default_interface_factory()
|
||||
self.initialize_agent(agent_id=agent_state.id, interface=interface, initial_message_sequence=init_messages, actor=actor)
|
||||
|
||||
in_memory_agent_state = self.agent_manager.get_agent_by_id(agent_state.id, actor=actor)
|
||||
return in_memory_agent_state
|
||||
|
||||
# TODO: This is not good!
|
||||
# TODO: Ideally, this should ALL be handled by the ORM
|
||||
# TODO: The main blocker here IS the _message updates
|
||||
def update_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
request: UpdateAgent,
|
||||
actor: User,
|
||||
) -> AgentState:
|
||||
"""Update the agents core memory block, return the new state"""
|
||||
# Update agent state in the db first
|
||||
agent_state = self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# TODO: Everything below needs to get removed, no updating anything in memory
|
||||
# update the system prompt
|
||||
if request.system:
|
||||
letta_agent.update_system_prompt(request.system)
|
||||
|
||||
# update in-context messages
|
||||
if request.message_ids:
|
||||
# This means the user is trying to change what messages are in the message buffer
|
||||
# Internally this requires (1) pulling from recall,
|
||||
# then (2) setting the attributes ._messages and .state.message_ids
|
||||
letta_agent.set_message_buffer(message_ids=request.message_ids)
|
||||
|
||||
letta_agent.update_state()
|
||||
|
||||
return agent_state
|
||||
|
||||
def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
|
||||
"""Get tools from an existing agent"""
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return letta_agent.agent_state.tools
|
||||
|
||||
def add_tool_to_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
user_id: str,
|
||||
):
|
||||
"""Add tools from an existing agent"""
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
return agent_state
|
||||
|
||||
def remove_tool_from_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
user_id: str,
|
||||
):
|
||||
"""Remove tools from an existing agent"""
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
return agent_state
|
||||
|
||||
# convert name->id
|
||||
|
||||
# TODO: These can be moved to agent_manager
|
||||
def get_agent_memory(self, agent_id: str, actor: User) -> Memory:
|
||||
"""Return the memory of an agent (core memory)"""
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return agent.agent_state.memory
|
||||
return self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).memory
|
||||
|
||||
def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id))
|
||||
|
||||
def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return RecallMemorySummary(size=len(agent.message_manager))
|
||||
|
||||
def get_in_context_messages(self, agent_id: str, actor: User) -> List[Message]:
|
||||
"""Get the in-context messages in the agent's memory"""
|
||||
# Get the agent object (loaded in memory)
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return agent._messages
|
||||
return RecallMemorySummary(size=self.message_manager.size(actor=actor, agent_id=agent_id))
|
||||
|
||||
def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str] = None, limit: int = 50) -> List[Passage]:
|
||||
"""Paginated query of all messages in agent archival memory"""
|
||||
@@ -947,24 +817,17 @@ class SyncServer(Server):
|
||||
|
||||
def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
# Insert into archival memory
|
||||
passages = self.passage_manager.insert_passage(
|
||||
agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor
|
||||
)
|
||||
|
||||
save_agent(letta_agent)
|
||||
# TODO: @mindy look at moving this to agent_manager to avoid above extra call
|
||||
passages = self.passage_manager.insert_passage(agent_state=agent_state, agent_id=agent_id, text=memory_contents, actor=actor)
|
||||
|
||||
return passages
|
||||
|
||||
def delete_archival_memory(self, agent_id: str, memory_id: str, actor: User):
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Delete by ID
|
||||
def delete_archival_memory(self, memory_id: str, actor: User):
|
||||
# TODO check if it exists first, and throw error if not
|
||||
letta_agent.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
|
||||
# TODO: @mindy make this return the deleted passage instead
|
||||
self.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
|
||||
|
||||
# TODO: return archival memory
|
||||
|
||||
@@ -1042,9 +905,8 @@ class SyncServer(Server):
|
||||
# update the block
|
||||
self.block_manager.update_block(block_id=block.id, block_update=BlockUpdate(value=value), actor=actor)
|
||||
|
||||
# load agent
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return letta_agent.agent_state.memory
|
||||
# rebuild system prompt for agent, potentially changed
|
||||
return self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor).memory
|
||||
|
||||
def delete_source(self, source_id: str, actor: User):
|
||||
"""Delete a data source"""
|
||||
@@ -1214,36 +1076,11 @@ class SyncServer(Server):
|
||||
|
||||
return success
|
||||
|
||||
def update_agent_message(self, agent_id: str, message_id: str, request: MessageUpdate, actor: User) -> Message:
|
||||
def update_agent_message(self, message_id: str, request: MessageUpdate, actor: User) -> Message:
|
||||
"""Update the details of a message associated with an agent"""
|
||||
|
||||
# Get the current message
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
response = letta_agent.update_message(message_id=message_id, request=request)
|
||||
save_agent(letta_agent)
|
||||
return response
|
||||
|
||||
def rewrite_agent_message(self, agent_id: str, new_text: str, actor: User) -> Message:
|
||||
|
||||
# Get the current message
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
response = letta_agent.rewrite_message(new_text=new_text)
|
||||
save_agent(letta_agent)
|
||||
return response
|
||||
|
||||
def rethink_agent_message(self, agent_id: str, new_thought: str, actor: User) -> Message:
|
||||
# Get the current message
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
response = letta_agent.rethink_message(new_thought=new_thought)
|
||||
save_agent(letta_agent)
|
||||
return response
|
||||
|
||||
def retry_agent_message(self, agent_id: str, actor: User) -> List[Message]:
|
||||
# Get the current message
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
response = letta_agent.retry_message()
|
||||
save_agent(letta_agent)
|
||||
return response
|
||||
return self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
|
||||
|
||||
def get_organization_or_default(self, org_id: Optional[str]) -> Organization:
|
||||
"""Get the organization object for org_id if it exists, otherwise return the default organization object"""
|
||||
@@ -1331,15 +1168,7 @@ class SyncServer(Server):
|
||||
def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
|
||||
"""Add a new embedding model"""
|
||||
|
||||
def get_agent_context_window(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
) -> ContextWindowOverview:
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Get the current message
|
||||
def get_agent_context_window(self, agent_id: str, actor: User) -> ContextWindowOverview:
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return letta_agent.get_context_window()
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
@@ -28,12 +29,17 @@ from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
_process_relationship,
|
||||
_process_tags,
|
||||
check_supports_structured_output,
|
||||
compile_system_message,
|
||||
derive_system_message,
|
||||
initialize_message_sequence,
|
||||
package_initial_message_sequence,
|
||||
)
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.utils import enforce_types, get_utc_time, united_diff
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -49,6 +55,7 @@ class AgentManager:
|
||||
self.block_manager = BlockManager()
|
||||
self.tool_manager = ToolManager()
|
||||
self.source_manager = SourceManager()
|
||||
self.message_manager = MessageManager()
|
||||
|
||||
# ======================================================================================================================
|
||||
# Basic CRUD operations
|
||||
@@ -64,6 +71,10 @@ class AgentManager:
|
||||
if not agent_create.llm_config or not agent_create.embedding_config:
|
||||
raise ValueError("llm_config and embedding_config are required")
|
||||
|
||||
# Check tool rules are valid
|
||||
if agent_create.tool_rules:
|
||||
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)
|
||||
|
||||
# create blocks (note: cannot be linked into the agent_id is created)
|
||||
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
|
||||
for create_block in agent_create.memory_blocks:
|
||||
@@ -88,7 +99,8 @@ class AgentManager:
|
||||
# Remove duplicates
|
||||
tool_ids = list(set(tool_ids))
|
||||
|
||||
return self._create_agent(
|
||||
# Create the agent
|
||||
agent_state = self._create_agent(
|
||||
name=agent_create.name,
|
||||
system=system,
|
||||
agent_type=agent_create.agent_type,
|
||||
@@ -104,6 +116,35 @@ class AgentManager:
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# TODO: See if we can merge this into the above SQL create call for performance reasons
|
||||
# Generate a sequence of initial messages to put in the buffer
|
||||
init_messages = initialize_message_sequence(
|
||||
agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
|
||||
)
|
||||
|
||||
if agent_create.initial_message_sequence is not None:
|
||||
# We always need the system prompt up front
|
||||
system_message_obj = PydanticMessage.dict_to_message(
|
||||
agent_id=agent_state.id,
|
||||
user_id=agent_state.created_by_id,
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict=init_messages[0],
|
||||
)
|
||||
# Don't use anything else in the pregen sequence, instead use the provided sequence
|
||||
init_messages = [system_message_obj]
|
||||
init_messages.extend(
|
||||
package_initial_message_sequence(agent_state.id, agent_create.initial_message_sequence, agent_state.llm_config.model, actor)
|
||||
)
|
||||
else:
|
||||
init_messages = [
|
||||
PydanticMessage.dict_to_message(
|
||||
agent_id=agent_state.id, user_id=agent_state.created_by_id, model=agent_state.llm_config.model, openai_message_dict=msg
|
||||
)
|
||||
for msg in init_messages
|
||||
]
|
||||
|
||||
return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def _create_agent(
|
||||
self,
|
||||
@@ -149,6 +190,16 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
|
||||
agent_state = self._update_agent(agent_id=agent_id, agent_update=agent_update, actor=actor)
|
||||
|
||||
# Rebuild the system prompt if it's different
|
||||
if agent_update.system and agent_update.system != agent_state.system:
|
||||
agent_state = self.rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True, update_timestamp=False)
|
||||
|
||||
return agent_state
|
||||
|
||||
@enforce_types
|
||||
def _update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Update an existing agent.
|
||||
|
||||
@@ -247,6 +298,105 @@ class AgentManager:
|
||||
agent.hard_delete(session)
|
||||
return agent_state
|
||||
|
||||
# ======================================================================================================================
|
||||
# In Context Messages Management
|
||||
# ======================================================================================================================
|
||||
# TODO: There are several assumptions here that are not explicitly checked
|
||||
# TODO: 1) These message ids are valid
|
||||
# TODO: 2) These messages are ordered from oldest to newest
|
||||
# TODO: This can be fixed by having an actual relationship in the ORM for message_ids
|
||||
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
|
||||
@enforce_types
|
||||
def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState:
|
||||
"""Rebuilds the system message with the latest memory object and any shared memory block updates
|
||||
|
||||
Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
|
||||
|
||||
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
|
||||
"""
|
||||
agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
|
||||
curr_system_message = self.get_system_message(
|
||||
agent_id=agent_id, actor=actor
|
||||
) # this is the system + memory bank, not just the system prompt
|
||||
curr_system_message_openai = curr_system_message.to_openai_dict()
|
||||
|
||||
# note: we only update the system prompt if the core memory is changed
|
||||
# this means that the archival/recall memory statistics may be someout out of date
|
||||
curr_memory_str = agent_state.memory.compile()
|
||||
if curr_memory_str in curr_system_message_openai["content"] and not force:
|
||||
# NOTE: could this cause issues if a block is removed? (substring match would still work)
|
||||
logger.info(
|
||||
f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
|
||||
)
|
||||
return agent_state
|
||||
|
||||
# If the memory didn't update, we probably don't want to update the timestamp inside
|
||||
# For example, if we're doing a system prompt swap, this should probably be False
|
||||
if update_timestamp:
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
else:
|
||||
# NOTE: a bit of a hack - we pull the timestamp from the message created_by
|
||||
memory_edit_timestamp = curr_system_message.created_at
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats separately)
|
||||
new_system_message_str = compile_system_message(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
|
||||
if len(diff) > 0: # there was a diff
|
||||
logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
|
||||
# Swap the system message out (only if there is a diff)
|
||||
message = PydanticMessage.dict_to_message(
|
||||
agent_id=agent_id,
|
||||
user_id=actor.id,
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict={"role": "system", "content": new_system_message_str},
|
||||
)
|
||||
message = self.message_manager.create_message(message, actor=actor)
|
||||
message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system)
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
else:
|
||||
return agent_state
|
||||
|
||||
@enforce_types
|
||||
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||||
return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
new_messages = self.message_manager.create_many_messages(messages, actor=actor)
|
||||
message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:]
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
messages = self.message_manager.create_many_messages(messages, actor=actor)
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or []
|
||||
message_ids += [m.id for m in messages]
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
# ======================================================================================================================
|
||||
# Source Management
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
from typing import List, Optional
|
||||
import datetime
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from letta import system
|
||||
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, STRUCTURED_OUTPUT_MODELS
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.prompts import gpt_system
|
||||
from letta.schemas.agent import AgentType
|
||||
from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.schemas.user import User
|
||||
from letta.system import get_initial_boot_messages, get_login_event
|
||||
from letta.utils import get_local_time
|
||||
|
||||
|
||||
# Static methods
|
||||
@@ -88,3 +99,162 @@ def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
|
||||
raise ValueError(f"Invalid agent type: {agent_type}")
|
||||
|
||||
return system
|
||||
|
||||
|
||||
# TODO: This code is kind of wonky and deserves a rewrite
|
||||
def compile_memory_metadata_block(
|
||||
memory_edit_timestamp: datetime.datetime, previous_message_count: int = 0, archival_memory_size: int = 0
|
||||
) -> str:
|
||||
# Put the timestamp in the local timezone (mimicking get_local_time())
|
||||
timestamp_str = memory_edit_timestamp.astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip()
|
||||
|
||||
# Create a metadata block of info so the agent knows about the metadata of out-of-context memories
|
||||
memory_metadata_block = "\n".join(
|
||||
[
|
||||
f"### Memory [last modified: {timestamp_str}]",
|
||||
f"{previous_message_count} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||
f"{archival_memory_size} total memories you created are stored in archival memory (use functions to access them)",
|
||||
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
||||
]
|
||||
)
|
||||
return memory_metadata_block
|
||||
|
||||
|
||||
def compile_system_message(
|
||||
system_prompt: str,
|
||||
in_context_memory: Memory,
|
||||
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
|
||||
user_defined_variables: Optional[dict] = None,
|
||||
append_icm_if_missing: bool = True,
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: int = 0,
|
||||
) -> str:
|
||||
"""Prepare the final/full system message that will be fed into the LLM API
|
||||
|
||||
The base system message may be templated, in which case we need to render the variables.
|
||||
|
||||
The following are reserved variables:
|
||||
- CORE_MEMORY: the in-context memory of the LLM
|
||||
"""
|
||||
|
||||
if user_defined_variables is not None:
|
||||
# TODO eventually support the user defining their own variables to inject
|
||||
raise NotImplementedError
|
||||
else:
|
||||
variables = {}
|
||||
|
||||
# Add the protected memory variable
|
||||
if IN_CONTEXT_MEMORY_KEYWORD in variables:
|
||||
raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
|
||||
else:
|
||||
# TODO should this all put into the memory.__repr__ function?
|
||||
memory_metadata_string = compile_memory_metadata_block(
|
||||
memory_edit_timestamp=in_context_memory_last_edit,
|
||||
previous_message_count=previous_message_count,
|
||||
archival_memory_size=archival_memory_size,
|
||||
)
|
||||
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
|
||||
|
||||
# Add to the variables list to inject
|
||||
variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string
|
||||
|
||||
if template_format == "f-string":
|
||||
|
||||
# Catch the special case where the system prompt is unformatted
|
||||
if append_icm_if_missing:
|
||||
memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
|
||||
if memory_variable_string not in system_prompt:
|
||||
# In this case, append it to the end to make sure memory is still injected
|
||||
# warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
|
||||
system_prompt += "\n" + memory_variable_string
|
||||
|
||||
# render the variables using the built-in templater
|
||||
try:
|
||||
formatted_prompt = system_prompt.format_map(variables)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
|
||||
|
||||
else:
|
||||
# TODO support for mustache and jinja2
|
||||
raise NotImplementedError(template_format)
|
||||
|
||||
return formatted_prompt
|
||||
|
||||
|
||||
def initialize_message_sequence(
|
||||
agent_state: AgentState,
|
||||
memory_edit_timestamp: Optional[datetime.datetime] = None,
|
||||
include_initial_boot_message: bool = True,
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: int = 0,
|
||||
) -> List[dict]:
|
||||
if memory_edit_timestamp is None:
|
||||
memory_edit_timestamp = get_local_time()
|
||||
|
||||
full_system_message = compile_system_message(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
previous_message_count=previous_message_count,
|
||||
archival_memory_size=archival_memory_size,
|
||||
)
|
||||
first_user_message = get_login_event() # event letting Letta know the user just logged in
|
||||
|
||||
if include_initial_boot_message:
|
||||
if agent_state.llm_config.model is not None and "gpt-3.5" in agent_state.llm_config.model:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
|
||||
else:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
|
||||
messages = (
|
||||
[
|
||||
{"role": "system", "content": full_system_message},
|
||||
]
|
||||
+ initial_boot_messages
|
||||
+ [
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
messages = [
|
||||
{"role": "system", "content": full_system_message},
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def package_initial_message_sequence(
|
||||
agent_id: str, initial_message_sequence: List[MessageCreate], model: str, actor: User
|
||||
) -> List[Message]:
|
||||
# create the agent object
|
||||
init_messages = []
|
||||
for message_create in initial_message_sequence:
|
||||
|
||||
if message_create.role == MessageRole.user:
|
||||
packed_message = system.package_user_message(
|
||||
user_message=message_create.text,
|
||||
)
|
||||
elif message_create.role == MessageRole.system:
|
||||
packed_message = system.package_system_message(
|
||||
system_message=message_create.text,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid message role: {message_create.role}")
|
||||
|
||||
init_messages.append(
|
||||
Message(role=message_create.role, text=packed_message, organization_id=actor.organization_id, agent_id=agent_id, model=model)
|
||||
)
|
||||
return init_messages
|
||||
|
||||
|
||||
def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) -> bool:
|
||||
if model not in STRUCTURED_OUTPUT_MODELS:
|
||||
if len(ToolRulesSolver(tool_rules=tool_rules).init_tool_rules) > 1:
|
||||
raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.")
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@@ -28,6 +28,21 @@ class MessageManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
|
||||
"""Fetch messages by ID and return them in the requested order."""
|
||||
with self.session_maker() as session:
|
||||
results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id)
|
||||
|
||||
if len(results) != len(message_ids):
|
||||
raise NoResultFound(
|
||||
f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}"
|
||||
)
|
||||
|
||||
# Sort results directly based on message_ids
|
||||
result_dict = {msg.id: msg.to_pydantic() for msg in results}
|
||||
return [result_dict[msg_id] for msg_id in message_ids]
|
||||
|
||||
@enforce_types
|
||||
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
|
||||
"""Create a new message."""
|
||||
|
||||
@@ -15,20 +15,16 @@ from letta.config import LettaConfig
|
||||
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.errors import (
|
||||
InvalidToolCallError,
|
||||
InvalidInnerMonologueError,
|
||||
MissingToolCallError,
|
||||
InvalidToolCallError,
|
||||
MissingInnerMonologueError,
|
||||
MissingToolCallError,
|
||||
)
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_message import (
|
||||
ToolCallMessage,
|
||||
ReasoningMessage,
|
||||
LettaMessage,
|
||||
)
|
||||
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, ToolCallMessage
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
@@ -78,7 +74,13 @@ def setup_agent(
|
||||
|
||||
memory = ChatMemory(human=memory_human_str, persona=memory_persona_str)
|
||||
agent_state = client.create_agent(
|
||||
name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules, include_base_tools=include_base_tools,
|
||||
name=agent_uuid,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
memory=memory,
|
||||
tool_ids=tool_ids,
|
||||
tool_rules=tool_rules,
|
||||
include_base_tools=include_base_tools,
|
||||
)
|
||||
|
||||
return agent_state
|
||||
@@ -105,12 +107,13 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet
|
||||
agent_state = setup_agent(client, filename)
|
||||
|
||||
full_agent_state = client.get_agent(agent_state.id)
|
||||
messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state.id, actor=client.user)
|
||||
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
|
||||
|
||||
response = create(
|
||||
llm_config=agent_state.llm_config,
|
||||
user_id=str(uuid.UUID(int=1)), # dummy user_id
|
||||
messages=agent._messages,
|
||||
messages=messages,
|
||||
functions=[t.json_schema for t in agent.agent_state.tools],
|
||||
)
|
||||
|
||||
@@ -412,9 +415,7 @@ def assert_invoked_function_call(messages: Sequence[LettaMessage], function_name
|
||||
# Found it, do nothing
|
||||
return
|
||||
|
||||
raise MissingToolCallError(
|
||||
messages=messages, explanation=f"No messages were found invoking function call with name: {function_name}"
|
||||
)
|
||||
raise MissingToolCallError(messages=messages, explanation=f"No messages were found invoking function call with name: {function_name}")
|
||||
|
||||
|
||||
def assert_inner_monologue_is_present_and_valid(messages: List[LettaMessage]) -> None:
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from letta import create_client
|
||||
from letta.agent import Agent
|
||||
from letta.client.client import LocalClient
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH
|
||||
from tests.helpers.utils import cleanup
|
||||
@@ -16,6 +19,110 @@ from tests.helpers.utils import cleanup
|
||||
LLM_CONFIG_DIR = "tests/configs/llm_model_configs"
|
||||
SUMMARY_KEY_PHRASE = "The following is a summary"
|
||||
|
||||
test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
||||
|
||||
# TODO: these tests should include looping through LLM providers, since behavior may vary across providers
|
||||
# TODO: these tests should add function calls into the summarized message sequence:W
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
client = create_client()
|
||||
# client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def agent_state(client):
|
||||
# Generate uuid for agent name for this example
|
||||
agent_state = client.create_agent(name=test_agent_name)
|
||||
yield agent_state
|
||||
|
||||
client.delete_agent(agent_state.id)
|
||||
|
||||
|
||||
def test_summarize_messages_inplace(client, agent_state, mock_e2b_api_key_none):
|
||||
"""Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
|
||||
# First send a few messages (5)
|
||||
response = client.user_message(
|
||||
agent_id=agent_state.id,
|
||||
message="Hey, how's it going? What do you think about this whole shindig",
|
||||
).messages
|
||||
assert response is not None and len(response) > 0
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
response = client.user_message(
|
||||
agent_id=agent_state.id,
|
||||
message="Any thoughts on the meaning of life?",
|
||||
).messages
|
||||
assert response is not None and len(response) > 0
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
response = client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?").messages
|
||||
assert response is not None and len(response) > 0
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
response = client.user_message(
|
||||
agent_id=agent_state.id,
|
||||
message="Would you be surprised to learn that you're actually conversing with an AI right now?",
|
||||
).messages
|
||||
assert response is not None and len(response) > 0
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
# reload agent object
|
||||
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
|
||||
agent_obj.summarize_messages_inplace()
|
||||
|
||||
|
||||
def test_auto_summarize(client, mock_e2b_api_key_none):
|
||||
"""Test that the summarizer triggers by itself"""
|
||||
small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
|
||||
small_context_llm_config.context_window = 4000
|
||||
|
||||
small_agent_state = client.create_agent(
|
||||
name="small_context_agent",
|
||||
llm_config=small_context_llm_config,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
def summarize_message_exists(messages: List[Message]) -> bool:
|
||||
for message in messages:
|
||||
if message.text and "The following is a summary of the previous" in message.text:
|
||||
print(f"Summarize message found after {message_count} messages: \n {message.text}")
|
||||
return True
|
||||
return False
|
||||
|
||||
MAX_ATTEMPTS = 10
|
||||
message_count = 0
|
||||
while True:
|
||||
|
||||
# send a message
|
||||
response = client.user_message(
|
||||
agent_id=small_agent_state.id,
|
||||
message="What is the meaning of life?",
|
||||
)
|
||||
message_count += 1
|
||||
|
||||
print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------")
|
||||
|
||||
# check if the summarize message is inside the messages
|
||||
assert isinstance(client, LocalClient), "Test only works with LocalClient"
|
||||
in_context_messages = client.server.agent_manager.get_in_context_messages(agent_id=small_agent_state.id, actor=client.user)
|
||||
print("SUMMARY", summarize_message_exists(in_context_messages))
|
||||
if summarize_message_exists(in_context_messages):
|
||||
break
|
||||
|
||||
if message_count > MAX_ATTEMPTS:
|
||||
raise Exception(f"Summarize message not found after {message_count} messages")
|
||||
|
||||
finally:
|
||||
client.delete_agent(small_agent_state.id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_filename",
|
||||
@@ -69,4 +176,5 @@ def test_summarizer(config_filename):
|
||||
|
||||
# Invoke a summarize
|
||||
letta_agent.summarize_messages_inplace(preserve_last_N_messages=False)
|
||||
assert SUMMARY_KEY_PHRASE in letta_agent.messages[1]["content"], f"Test failed for config: {config_filename}"
|
||||
in_context_messages = client.get_in_context_messages(agent_state.id)
|
||||
assert SUMMARY_KEY_PHRASE in in_context_messages[1].text, f"Test failed for config: {config_filename}"
|
||||
|
||||
@@ -18,20 +18,22 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole, MessageStreamStatus
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
LettaMessage,
|
||||
ReasoningMessage,
|
||||
SystemMessage,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
ReasoningMessage,
|
||||
LettaMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.helpers.agent_manager_helper import initialize_message_sequence
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings
|
||||
from letta.utils import get_utc_time
|
||||
from tests.helpers.client_helper import upload_file_using_client
|
||||
|
||||
# from tests.utils import create_config
|
||||
@@ -602,18 +604,11 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
|
||||
If we pass in a non-empty list, we should get that sequence
|
||||
If we pass in an empty list, we should get an empty sequence
|
||||
"""
|
||||
from letta.agent import initialize_message_sequence
|
||||
from letta.utils import get_utc_time
|
||||
|
||||
# The reference initial message sequence:
|
||||
reference_init_messages = initialize_message_sequence(
|
||||
model=agent.llm_config.model,
|
||||
system=agent.system,
|
||||
agent_id=agent.id,
|
||||
memory=agent.memory,
|
||||
agent_state=agent,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# system, login message, send_message test, send_message receipt
|
||||
|
||||
@@ -259,10 +259,10 @@ def test_recall_memory(client: LocalClient, agent: AgentState):
|
||||
assert exists
|
||||
|
||||
# get in-context messages
|
||||
messages = client.get_in_context_messages(agent.id)
|
||||
in_context_messages = client.get_in_context_messages(agent.id)
|
||||
exists = False
|
||||
for m in messages:
|
||||
if message_str in str(m):
|
||||
for m in in_context_messages:
|
||||
if message_str in m.text:
|
||||
exists = True
|
||||
assert exists
|
||||
|
||||
|
||||
@@ -370,8 +370,8 @@ def other_tool(server: SyncServer, default_user, default_organization):
|
||||
@pytest.fixture
|
||||
def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create and return a sample agent within the default organization."""
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="sarah_agent",
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
@@ -385,8 +385,8 @@ def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
@pytest.fixture
|
||||
def charles_agent(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create and return a sample agent within the default organization."""
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="charles_agent",
|
||||
memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
@@ -503,6 +503,54 @@ def test_create_get_list_agent(server: SyncServer, comprehensive_test_agent_fixt
|
||||
assert len(list_agents) == 0
|
||||
|
||||
|
||||
def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block):
|
||||
memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")]
|
||||
create_agent_request = CreateAgent(
|
||||
system="test system",
|
||||
memory_blocks=memory_blocks,
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
block_ids=[default_block.id],
|
||||
tags=["a", "b"],
|
||||
description="test_description",
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
|
||||
)
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
create_agent_request,
|
||||
actor=default_user,
|
||||
)
|
||||
assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 2
|
||||
init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user)
|
||||
# Check that the system appears in the first initial message
|
||||
assert create_agent_request.system in init_messages[0].text
|
||||
assert create_agent_request.memory_blocks[0].value in init_messages[0].text
|
||||
# Check that the second message is the passed in initial message seq
|
||||
assert create_agent_request.initial_message_sequence[0].role == init_messages[1].role
|
||||
assert create_agent_request.initial_message_sequence[0].text in init_messages[1].text
|
||||
|
||||
|
||||
def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block):
|
||||
memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")]
|
||||
create_agent_request = CreateAgent(
|
||||
system="test system",
|
||||
memory_blocks=memory_blocks,
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
block_ids=[default_block.id],
|
||||
tags=["a", "b"],
|
||||
description="test_description",
|
||||
)
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
create_agent_request,
|
||||
actor=default_user,
|
||||
)
|
||||
assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 4
|
||||
init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user)
|
||||
# Check that the system appears in the first initial message
|
||||
assert create_agent_request.system in init_messages[0].text
|
||||
assert create_agent_request.memory_blocks[0].value in init_messages[0].text
|
||||
|
||||
|
||||
def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, other_tool, other_source, other_block, default_user):
|
||||
agent, _ = comprehensive_test_agent_fixture
|
||||
update_agent_request = UpdateAgent(
|
||||
@@ -794,8 +842,8 @@ def test_list_agents_by_tags_with_other_filters(server: SyncServer, sarah_agent,
|
||||
def test_list_agents_by_tags_pagination(server: SyncServer, default_user, default_organization):
|
||||
"""Test pagination when listing agents by tags."""
|
||||
# Create first agent
|
||||
agent1 = server.create_agent(
|
||||
request=CreateAgent(
|
||||
agent1 = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="agent1",
|
||||
tags=["pagination_test", "tag1"],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
@@ -809,8 +857,8 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
|
||||
time.sleep(CREATE_DELAY_SQLITE) # Ensure distinct created_at timestamps
|
||||
|
||||
# Create second agent
|
||||
agent2 = server.create_agent(
|
||||
request=CreateAgent(
|
||||
agent2 = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="agent2",
|
||||
tags=["pagination_test", "tag2"],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
@@ -1565,6 +1613,15 @@ def create_test_messages(server: SyncServer, base_message: PydanticMessage, defa
|
||||
return messages
|
||||
|
||||
|
||||
def test_get_messages_by_ids(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
|
||||
"""Test basic message listing with limit"""
|
||||
messages = create_test_messages(server, hello_world_message_fixture, default_user)
|
||||
message_ids = [m.id for m in messages]
|
||||
|
||||
results = server.message_manager.get_messages_by_ids(message_ids=message_ids, actor=default_user)
|
||||
assert sorted(message_ids) == sorted([r.id for r in results])
|
||||
|
||||
|
||||
def test_message_listing_basic(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
|
||||
"""Test basic message listing with limit"""
|
||||
create_test_messages(server, hello_world_message_fixture, default_user)
|
||||
|
||||
@@ -10,11 +10,11 @@ from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import (
|
||||
LettaMessage,
|
||||
ReasoningMessage,
|
||||
SystemMessage,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
ReasoningMessage,
|
||||
LettaMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.user import User
|
||||
@@ -507,76 +507,9 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
assert len(passage_none) == 0
|
||||
|
||||
|
||||
def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
|
||||
"""Test the /rethink, /rewrite, and /retry commands in the CLI
|
||||
|
||||
- "rethink" replaces the inner thoughts of the last assistant message
|
||||
- "rewrite" replaces the text of the last assistant message
|
||||
- "retry" retries the last assistant message
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
|
||||
# Send an initial message
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
|
||||
# Grab the raw Agent object
|
||||
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
|
||||
assert letta_agent._messages[-1].role == MessageRole.tool
|
||||
assert letta_agent._messages[-2].role == MessageRole.assistant
|
||||
last_agent_message = letta_agent._messages[-2]
|
||||
|
||||
# Try "rethink"
|
||||
new_thought = "I am thinking about the meaning of life, the universe, and everything. Bananas?"
|
||||
assert last_agent_message.text is not None and last_agent_message.text != new_thought
|
||||
server.rethink_agent_message(agent_id=agent_id, new_thought=new_thought, actor=actor)
|
||||
|
||||
# Grab the agent object again (make sure it's live)
|
||||
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
|
||||
assert letta_agent._messages[-1].role == MessageRole.tool
|
||||
assert letta_agent._messages[-2].role == MessageRole.assistant
|
||||
last_agent_message = letta_agent._messages[-2]
|
||||
assert last_agent_message.text == new_thought
|
||||
|
||||
# Try "rewrite"
|
||||
assert last_agent_message.tool_calls is not None
|
||||
assert last_agent_message.tool_calls[0].function.name == "send_message"
|
||||
assert last_agent_message.tool_calls[0].function.arguments is not None
|
||||
args_json = json.loads(last_agent_message.tool_calls[0].function.arguments)
|
||||
assert "message" in args_json and args_json["message"] is not None and args_json["message"] != ""
|
||||
|
||||
new_text = "Why hello there my good friend! Is 42 what you're looking for? Bananas?"
|
||||
server.rewrite_agent_message(agent_id=agent_id, new_text=new_text, actor=actor)
|
||||
|
||||
# Grab the agent object again (make sure it's live)
|
||||
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
|
||||
assert letta_agent._messages[-1].role == MessageRole.tool
|
||||
assert letta_agent._messages[-2].role == MessageRole.assistant
|
||||
last_agent_message = letta_agent._messages[-2]
|
||||
args_json = json.loads(last_agent_message.tool_calls[0].function.arguments)
|
||||
assert "message" in args_json and args_json["message"] is not None and args_json["message"] == new_text
|
||||
|
||||
# Try retry
|
||||
server.retry_agent_message(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Grab the agent object again (make sure it's live)
|
||||
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
|
||||
assert letta_agent._messages[-1].role == MessageRole.tool
|
||||
assert letta_agent._messages[-2].role == MessageRole.assistant
|
||||
last_agent_message = letta_agent._messages[-2]
|
||||
|
||||
# Make sure the inner thoughts changed
|
||||
assert last_agent_message.text is not None and last_agent_message.text != new_thought
|
||||
|
||||
# Make sure the message changed
|
||||
args_json = json.loads(last_agent_message.tool_calls[0].function.arguments)
|
||||
print(args_json)
|
||||
assert "message" in args_json and args_json["message"] is not None and args_json["message"] != new_text
|
||||
|
||||
|
||||
def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str):
|
||||
"""Test that the context window overview fetch works"""
|
||||
|
||||
overview = server.get_agent_context_window(user_id=user_id, agent_id=agent_id)
|
||||
overview = server.get_agent_context_window(agent_id=agent_id, actor=server.user_manager.get_user_or_default(user_id))
|
||||
assert overview is not None
|
||||
|
||||
# Run some basic checks
|
||||
@@ -1142,10 +1075,10 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
|
||||
|
||||
# Add all the base tools
|
||||
request.tool_ids = [b.id for b in base_tools]
|
||||
agent_state = server.update_agent(agent_state.id, request=request, actor=actor)
|
||||
agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor)
|
||||
assert len(agent_state.tools) == len(base_tools)
|
||||
|
||||
# Remove one base tool
|
||||
request.tool_ids = [b.id for b in base_tools[:-2]]
|
||||
agent_state = server.update_agent(agent_state.id, request=request, actor=actor)
|
||||
agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor)
|
||||
assert len(agent_state.tools) == len(base_tools) - 2
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from letta import create_client
|
||||
from letta.client.client import LocalClient
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
|
||||
from .utils import wipe_config
|
||||
|
||||
# test_agent_id = "test_agent"
|
||||
test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
||||
client = None
|
||||
agent_obj = None
|
||||
|
||||
# TODO: these tests should include looping through LLM providers, since behavior may vary across providers
|
||||
# TODO: these tests should add function calls into the summarized message sequence:W
|
||||
|
||||
|
||||
def create_test_agent():
|
||||
"""Create a test agent that we can call functions on"""
|
||||
wipe_config()
|
||||
|
||||
global client
|
||||
client = create_client()
|
||||
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
|
||||
agent_state = client.create_agent(
|
||||
name=test_agent_name,
|
||||
)
|
||||
|
||||
global agent_obj
|
||||
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
|
||||
|
||||
def test_summarize_messages_inplace(mock_e2b_api_key_none):
|
||||
"""Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
|
||||
global client
|
||||
global agent_obj
|
||||
|
||||
if agent_obj is None:
|
||||
create_test_agent()
|
||||
|
||||
assert agent_obj is not None, "Run create_agent test first"
|
||||
assert client is not None, "Run create_agent test first"
|
||||
|
||||
# First send a few messages (5)
|
||||
response = client.user_message(
|
||||
agent_id=agent_obj.agent_state.id,
|
||||
message="Hey, how's it going? What do you think about this whole shindig",
|
||||
).messages
|
||||
assert response is not None and len(response) > 0
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
response = client.user_message(
|
||||
agent_id=agent_obj.agent_state.id,
|
||||
message="Any thoughts on the meaning of life?",
|
||||
).messages
|
||||
assert response is not None and len(response) > 0
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
response = client.user_message(agent_id=agent_obj.agent_state.id, message="Does the number 42 ring a bell?").messages
|
||||
assert response is not None and len(response) > 0
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
response = client.user_message(
|
||||
agent_id=agent_obj.agent_state.id,
|
||||
message="Would you be surprised to learn that you're actually conversing with an AI right now?",
|
||||
).messages
|
||||
assert response is not None and len(response) > 0
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
# reload agent object
|
||||
agent_obj = client.server.load_agent(agent_id=agent_obj.agent_state.id, actor=client.user)
|
||||
|
||||
agent_obj.summarize_messages_inplace()
|
||||
print(f"Summarization succeeded: messages[1] = \n{agent_obj.messages[1]}")
|
||||
# response = client.run_command(agent_id=agent_obj.agent_state.id, command="summarize")
|
||||
|
||||
|
||||
def test_auto_summarize(mock_e2b_api_key_none):
|
||||
"""Test that the summarizer triggers by itself"""
|
||||
client = create_client()
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
|
||||
small_context_llm_config = LLMConfig.default_config("gpt-4")
|
||||
# default system prompt + funcs lead to ~2300 tokens, after one message it's at 2523 tokens
|
||||
SMALL_CONTEXT_WINDOW = 4000
|
||||
small_context_llm_config.context_window = SMALL_CONTEXT_WINDOW
|
||||
|
||||
agent_state = client.create_agent(
|
||||
name="small_context_agent",
|
||||
llm_config=small_context_llm_config,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
def summarize_message_exists(messages: List[Message]) -> bool:
|
||||
for message in messages:
|
||||
if message.text and "The following is a summary of the previous" in message.text:
|
||||
print(f"Summarize message found after {message_count} messages: \n {message.text}")
|
||||
return True
|
||||
return False
|
||||
|
||||
MAX_ATTEMPTS = 5
|
||||
message_count = 0
|
||||
while True:
|
||||
|
||||
# send a message
|
||||
response = client.user_message(
|
||||
agent_id=agent_state.id,
|
||||
message="What is the meaning of life?",
|
||||
)
|
||||
message_count += 1
|
||||
|
||||
print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------")
|
||||
|
||||
# check if the summarize message is inside the messages
|
||||
assert isinstance(client, LocalClient), "Test only works with LocalClient"
|
||||
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
print("SUMMARY", summarize_message_exists(agent_obj._messages))
|
||||
if summarize_message_exists(agent_obj._messages):
|
||||
break
|
||||
|
||||
if message_count > MAX_ATTEMPTS:
|
||||
raise Exception(f"Summarize message not found after {message_count} messages")
|
||||
|
||||
finally:
|
||||
client.delete_agent(agent_state.id)
|
||||
Reference in New Issue
Block a user