fix: Remove in-memory _messages field on Agent (#2295)

This commit is contained in:
Matthew Zhou
2024-12-20 15:52:04 -08:00
committed by GitHub
parent e9239cf1bf
commit 5bb4888cea
18 changed files with 650 additions and 1164 deletions

View File

@@ -33,7 +33,6 @@ jobs:
- "test_memory.py"
- "test_utils.py"
- "test_stream_buffer_readers.py"
- "test_summarize.py"
services:
qdrant:
image: qdrant/qdrant

View File

@@ -1,25 +1,22 @@
import datetime
import inspect
import json
import time
import traceback
import warnings
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from letta.constants import (
BASE_TOOLS,
CLI_WARNING_PREFIX,
FIRST_MESSAGE_ATTEMPTS,
FUNC_FAILED_HEARTBEAT_MESSAGE,
IN_CONTEXT_MEMORY_KEYWORD,
LLM_MAX_TOKENS,
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_WARNING_FRAC,
O1_BASE_TOOLS,
REQ_HEARTBEAT_MESSAGE,
STRUCTURED_OUTPUT_MODELS,
)
from letta.errors import ContextWindowExceededError
from letta.helpers import ToolRulesSolver
@@ -34,7 +31,7 @@ from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.memory import ContextWindowOverview, Memory
from letta.schemas.message import Message, MessageUpdate
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import (
Tool as ChatCompletionRequestTool,
)
@@ -49,6 +46,10 @@ from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
check_supports_structured_output,
compile_memory_metadata_block,
)
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
@@ -56,8 +57,6 @@ from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import (
get_heartbeat,
get_initial_boot_messages,
get_login_event,
get_token_limit_warning,
package_function_response,
package_summarize_message,
@@ -66,166 +65,20 @@ from letta.system import (
from letta.utils import (
count_tokens,
get_friendly_error_msg,
get_local_time,
get_tool_call_id,
get_utc_time,
is_utc_datetime,
json_dumps,
json_loads,
parse_json,
printd,
united_diff,
validate_function_response,
verify_first_message_correctness,
)
def compile_memory_metadata_block(
    actor: PydanticUser,
    agent_id: str,
    memory_edit_timestamp: datetime.datetime,
    agent_manager: Optional[AgentManager] = None,
    message_manager: Optional[MessageManager] = None,
) -> str:
    """Render the metadata header that precedes core memory in the system message.

    The header tells the agent how many out-of-context messages (recall memory)
    and passages (archival memory) exist, plus when core memory was last edited.
    """
    # Present the timestamp in the local timezone (mimicking get_local_time())
    local_timestamp = memory_edit_timestamp.astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip()

    # Counts fall back to 0 when the corresponding manager is not supplied
    recall_count = message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0
    archival_count = agent_manager.passage_size(actor=actor, agent_id=agent_id) if agent_manager else 0

    # Metadata block of info so the agent knows about the metadata of out-of-context memories
    header_lines = [
        f"### Memory [last modified: {local_timestamp}]",
        f"{recall_count} previous messages between you and the user are stored in recall memory (use functions to access them)",
        f"{archival_count} total memories you created are stored in archival memory (use functions to access them)",
        "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
    ]
    return "\n".join(header_lines)
def compile_system_message(
    system_prompt: str,
    agent_id: str,
    in_context_memory: Memory,
    in_context_memory_last_edit: datetime.datetime,  # TODO move this inside of BaseMemory?
    actor: PydanticUser,
    agent_manager: Optional[AgentManager] = None,
    message_manager: Optional[MessageManager] = None,
    user_defined_variables: Optional[dict] = None,
    append_icm_if_missing: bool = True,
    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
) -> str:
    """Prepare the final/full system message that will be fed into the LLM API

    The base system message may be templated, in which case we need to render the variables.
    The following are reserved variables:
      - CORE_MEMORY: the in-context memory of the LLM
    """
    # TODO eventually support the user defining their own variables to inject
    if user_defined_variables is not None:
        raise NotImplementedError
    variables = {}

    # Guard the protected memory variable against user-defined collisions
    if IN_CONTEXT_MEMORY_KEYWORD in variables:
        raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")

    # TODO should this all put into the memory.__repr__ function?
    metadata_block = compile_memory_metadata_block(
        actor=actor,
        agent_id=agent_id,
        memory_edit_timestamp=in_context_memory_last_edit,
        agent_manager=agent_manager,
        message_manager=message_manager,
    )
    # Register the compiled memory (metadata header + core memory) for injection
    variables[IN_CONTEXT_MEMORY_KEYWORD] = metadata_block + "\n" + in_context_memory.compile()

    if template_format != "f-string":
        # TODO support for mustache and jinja2
        raise NotImplementedError(template_format)

    # Catch the special case where the system prompt is unformatted:
    # append the placeholder to the end so memory is still injected
    if append_icm_if_missing:
        placeholder = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
        if placeholder not in system_prompt:
            system_prompt += "\n" + placeholder

    # Render the variables using the built-in templater
    try:
        return system_prompt.format_map(variables)
    except Exception as e:
        raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
def initialize_message_sequence(
    model: str,
    system: str,
    agent_id: str,
    memory: Memory,
    actor: PydanticUser,
    agent_manager: Optional[AgentManager] = None,
    message_manager: Optional[MessageManager] = None,
    memory_edit_timestamp: Optional[datetime.datetime] = None,
    include_initial_boot_message: bool = True,
) -> List[dict]:
    """Build the initial OpenAI-style message list for a brand-new agent.

    Produces: the compiled system message, optional boot messages, and a
    synthetic login event as the first user message.
    """
    if memory_edit_timestamp is None:
        memory_edit_timestamp = get_local_time()

    full_system_message = compile_system_message(
        agent_id=agent_id,
        system_prompt=system,
        in_context_memory=memory,
        in_context_memory_last_edit=memory_edit_timestamp,
        actor=actor,
        agent_manager=agent_manager,
        message_manager=message_manager,
        user_defined_variables=None,
        append_icm_if_missing=True,
    )

    # event letting Letta know the user just logged in
    first_user_message = get_login_event()

    messages = [{"role": "system", "content": full_system_message}]
    if include_initial_boot_message:
        # gpt-3.5 gets its own boot-message variant
        if model is not None and "gpt-3.5" in model:
            boot_key = "startup_with_send_message_gpt35"
        else:
            boot_key = "startup_with_send_message"
        messages.extend(get_initial_boot_messages(boot_key))
    messages.append({"role": "user", "content": first_user_message})

    return messages
class BaseAgent(ABC):
"""
Abstract class for all agents.
Only two interfaces are required: step and update_state.
Only one interface is required: step.
"""
@abstractmethod
@@ -238,10 +91,6 @@ class BaseAgent(ABC):
"""
raise NotImplementedError
@abstractmethod
def update_state(self) -> AgentState:
raise NotImplementedError
class Agent(BaseAgent):
def __init__(
@@ -250,9 +99,7 @@ class Agent(BaseAgent):
agent_state: AgentState, # in-memory representation of the agent state (read from multiple tables)
user: User,
# extras
messages_total: Optional[int] = None, # TODO remove?
first_message_verify_mono: bool = True, # TODO move to config?
initial_message_sequence: Optional[List[Message]] = None,
):
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
# Hold a copy of the state that was used to init the agent
@@ -276,7 +123,7 @@ class Agent(BaseAgent):
# gpt-4, gpt-3.5-turbo, ...
self.model = self.agent_state.llm_config.model
self.check_tool_rules()
self.supports_structured_output = check_supports_structured_output(model=self.model, tool_rules=agent_state.tool_rules)
# state managers
self.block_manager = BlockManager()
@@ -304,99 +151,14 @@ class Agent(BaseAgent):
# When the summarizer is run, set this back to False (to reset)
self.agent_alerted_about_memory_pressure = False
self._messages: List[Message] = []
# Once the memory object is initialized, use it to "bake" the system message
if self.agent_state.message_ids is not None:
self.set_message_buffer(message_ids=self.agent_state.message_ids)
else:
printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}")
assert self.agent_state.id is not None and self.agent_state.created_by_id is not None
# Generate a sequence of initial messages to put in the buffer
init_messages = initialize_message_sequence(
model=self.model,
system=self.agent_state.system,
agent_id=self.agent_state.id,
memory=self.agent_state.memory,
actor=self.user,
agent_manager=None,
message_manager=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
)
if initial_message_sequence is not None:
# We always need the system prompt up front
system_message_obj = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict=init_messages[0],
)
# Don't use anything else in the pregen sequence, instead use the provided sequence
init_messages = [system_message_obj] + initial_message_sequence
else:
# Basic "more human than human" initial message sequence
init_messages = initialize_message_sequence(
model=self.model,
system=self.agent_state.system,
memory=self.agent_state.memory,
agent_id=self.agent_state.id,
actor=self.user,
agent_manager=None,
message_manager=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
)
# Cast to Message objects
init_messages = [
Message.dict_to_message(
agent_id=self.agent_state.id, user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=msg
)
for msg in init_messages
]
# Cast the messages to actual Message objects to be synced to the DB
init_messages_objs = []
for msg in init_messages:
init_messages_objs.append(msg)
for msg in init_messages_objs:
assert isinstance(msg, Message), f"Message object is not of type Message: {type(msg)}"
assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages)
# Put the messages inside the message buffer
self.messages_total = 0
self._append_to_messages(added_messages=init_messages_objs)
self._validate_message_buffer_is_utc()
# Load last function response from message history
self.last_function_response = self.load_last_function_response()
# Keep track of the total number of messages throughout all time
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
self.messages_total_init = len(self._messages) - 1
printd(f"Agent initialized, self.messages_total={self.messages_total}")
# Create the agent in the DB
self.update_state()
def check_tool_rules(self):
if self.model not in STRUCTURED_OUTPUT_MODELS:
if len(self.tool_rules_solver.init_tool_rules) > 1:
raise ValueError(
"Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule."
)
self.supports_structured_output = False
else:
self.supports_structured_output = True
def load_last_function_response(self):
"""Load the last function response from message history"""
for i in range(len(self._messages) - 1, -1, -1):
msg = self._messages[i]
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
for i in range(len(in_context_messages) - 1, -1, -1):
msg = in_context_messages[i]
if msg.role == MessageRole.tool and msg.text:
try:
response_json = json.loads(msg.text)
@@ -435,7 +197,7 @@ class Agent(BaseAgent):
# NOTE: don't do this since re-building the memory is handled at the start of the step
# rebuild memory - this records the last edited timestamp of the memory
# TODO: pass in update timestamp from block edit time
self.rebuild_system_prompt()
self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)
return True
return False
@@ -487,109 +249,6 @@ class Agent(BaseAgent):
return function_response
@property
def messages(self) -> List[dict]:
"""Getter method that converts the internal Message list into OpenAI-style dicts"""
return [msg.to_openai_dict() for msg in self._messages]
@messages.setter
def messages(self, value):
raise Exception("Modifying message list directly not allowed")
def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]:
    """Load a list of messages from recall storage"""
    # Pull each message object from the database; skip (with warnings)
    # anything missing or of an unexpected type
    loaded: List[Message] = []
    for msg_id in message_ids:
        fetched = self.message_manager.get_message_by_id(msg_id, actor=self.user)
        if not fetched:
            printd(f"Warning - message ID {msg_id} not found in recall storage")
            warnings.warn(f"Warning - message ID {msg_id} not found in recall storage")
        elif isinstance(fetched, Message):
            loaded.append(fetched)
        else:
            printd(f"Warning - message ID {msg_id} is not a Message object")
            warnings.warn(f"Warning - message ID {msg_id} is not a Message object")
    return loaded
def _validate_message_buffer_is_utc(self):
    """Iterate over the message buffer and force all messages to be UTC stamped"""
    # TODO eventually do casting via an edit_message function
    for message in self._messages:
        created = message.created_at
        if not created:
            continue
        if is_utc_datetime(created):
            continue
        # Old agents may have naive timestamps; coerce them to UTC in place
        printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{message.text}')")
        message.created_at = created.replace(tzinfo=datetime.timezone.utc)
def set_message_buffer(self, message_ids: List[str], force_utc: bool = True):
    """Set the messages in the buffer to the message IDs list"""
    # Hydrate Message objects from recall storage and install them as the buffer
    self._messages = self._load_messages_from_recall(message_ids=message_ids)

    # bugfix for old agents that may not have had UTC specified in their timestamps
    if force_utc:
        self._validate_message_buffer_is_utc()

    # keep the persisted agent state's message ID list in sync with the buffer
    self.agent_state.message_ids = message_ids
def refresh_message_buffer(self):
    """Refresh the message buffer from the database"""
    ids_to_sync = self.agent_state.message_ids
    # The agent must have a non-empty list of string message IDs to reload from
    assert ids_to_sync and all(isinstance(msg_id, str) for msg_id in ids_to_sync)
    self.set_message_buffer(message_ids=ids_to_sync)
def _trim_messages(self, num):
"""Trim messages from the front, not including the system message"""
new_messages = [self._messages[0]] + self._messages[num:]
self._messages = new_messages
def _prepend_to_messages(self, added_messages: List[Message]):
    """Wrapper around self.messages.prepend to allow additional calls to a state/persistence manager"""
    assert all(isinstance(msg, Message) for msg in added_messages)
    # Persist first, then splice into the in-memory buffer right after the system message
    self.message_manager.create_many_messages(added_messages, actor=self.user)
    self._messages = [self._messages[0], *added_messages, *self._messages[1:]]  # prepend (no system)
    # still should increment the message counter (summaries are additions too)
    self.messages_total += len(added_messages)
def _append_to_messages(self, added_messages: List[Message]):
    """Wrapper around self.messages.append to allow additional calls to a state/persistence manager"""
    assert all(isinstance(msg, Message) for msg in added_messages)
    # Persist to the DB before mutating the in-memory buffer
    self.message_manager.create_many_messages(added_messages, actor=self.user)
    self._messages = self._messages + added_messages  # append
    self.messages_total += len(added_messages)
def append_to_messages(self, added_messages: List[dict]):
    """An external-facing message append, where dict-like messages are first converted to Message objects"""
    converted = []
    for raw in added_messages:
        converted.append(
            Message.dict_to_message(
                agent_id=self.agent_state.id,
                user_id=self.agent_state.created_by_id,
                model=self.model,
                openai_message_dict=raw,
            )
        )
    self._append_to_messages(converted)
def _get_ai_reply(
self,
message_sequence: List[Message],
@@ -898,7 +557,7 @@ class Agent(BaseAgent):
# rebuild memory
# TODO: @charles please check this
self.rebuild_system_prompt()
self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user)
# Update ToolRulesSolver state with last called function
self.tool_rules_solver.update_tool_usage(function_name)
@@ -930,6 +589,7 @@ class Agent(BaseAgent):
messages=next_input_message,
**kwargs,
)
heartbeat_request = step_response.heartbeat_request
function_failed = step_response.function_failed
token_warning = step_response.in_context_memory_warning
@@ -1021,33 +681,19 @@ class Agent(BaseAgent):
if not all(isinstance(m, Message) for m in messages):
raise ValueError(f"messages should be a Message or a list of Message, got {type(messages)}")
input_message_sequence = self._messages + messages
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
input_message_sequence = in_context_messages + messages
if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
# Step 2: send the conversation and available functions to the LLM
if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
printd(f"This is the first message. Running extra verifier on AI response.")
counter = 0
while True:
response = self._get_ai_reply(
message_sequence=input_message_sequence, first_message=True, stream=stream # passed through to the prompt formatter
)
if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
break
counter += 1
if counter > first_message_retry_limit:
raise Exception(f"Hit first message retry limit ({first_message_retry_limit})")
else:
response = self._get_ai_reply(
message_sequence=input_message_sequence,
first_message=first_message,
stream=stream,
step_count=step_count,
)
response = self._get_ai_reply(
message_sequence=input_message_sequence,
first_message=first_message,
stream=stream,
step_count=step_count,
)
# Step 3: check if LLM wanted to call a function
# (if yes) Step 4: call the function
@@ -1095,10 +741,9 @@ class Agent(BaseAgent):
f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
)
self._append_to_messages(all_new_messages)
# update state after each step
self.update_state()
self.agent_state = self.agent_manager.append_to_in_context_messages(
all_new_messages, agent_id=self.agent_state.id, actor=self.user
)
return AgentStepResponse(
messages=all_new_messages,
@@ -1113,7 +758,9 @@ class Agent(BaseAgent):
# If we got a context alert, try trimming the messages length, then try again
if is_context_overflow_error(e):
printd(f"context window exceeded with limit {self.agent_state.llm_config.context_window}, running summarizer to trim messages")
printd(
f"context window exceeded with limit {self.agent_state.llm_config.context_window}, running summarizer to trim messages"
)
# A separate API call to run a summarizer
self.summarize_messages_inplace()
@@ -1165,15 +812,19 @@ class Agent(BaseAgent):
return self.inner_step(messages=[user_message], **kwargs)
def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True):
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
if in_context_messages_openai[0]["role"] != "system":
raise RuntimeError(f"in_context_messages_openai[0] should be system (instead got {in_context_messages_openai[0]})")
# Start at index 1 (past the system message),
# and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
# Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling
token_counts = [count_tokens(str(msg)) for msg in self.messages]
token_counts = [count_tokens(str(msg)) for msg in in_context_messages_openai]
message_buffer_token_count = sum(token_counts[1:]) # no system message
desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC)
candidate_messages_to_summarize = self.messages[1:]
candidate_messages_to_summarize = in_context_messages_openai[1:]
token_counts = token_counts[1:]
if preserve_last_N_messages:
@@ -1193,7 +844,7 @@ class Agent(BaseAgent):
"Not enough messages to compress for summarization",
details={
"num_candidate_messages": len(candidate_messages_to_summarize),
"num_total_messages": len(self.messages),
"num_total_messages": len(in_context_messages_openai),
"preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
},
)
@@ -1212,9 +863,9 @@ class Agent(BaseAgent):
# Try to make an assistant message come after the cutoff
try:
printd(f"Selected cutoff {cutoff} was a 'user', shifting one...")
if self.messages[cutoff]["role"] == "user":
if in_context_messages_openai[cutoff]["role"] == "user":
new_cutoff = cutoff + 1
if self.messages[new_cutoff]["role"] == "user":
if in_context_messages_openai[new_cutoff]["role"] == "user":
printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...")
cutoff = new_cutoff
except IndexError:
@@ -1222,23 +873,23 @@ class Agent(BaseAgent):
# Make sure the cutoff isn't on a 'tool' or 'function'
if disallow_tool_as_first:
while self.messages[cutoff]["role"] in ["tool", "function"] and cutoff < len(self.messages):
while in_context_messages_openai[cutoff]["role"] in ["tool", "function"] and cutoff < len(in_context_messages_openai):
printd(f"Selected cutoff {cutoff} was a 'tool', shifting one...")
cutoff += 1
message_sequence_to_summarize = self._messages[1:cutoff] # do NOT get rid of the system message
message_sequence_to_summarize = in_context_messages[1:cutoff] # do NOT get rid of the system message
if len(message_sequence_to_summarize) <= 1:
# This prevents a potential infinite loop of summarizing the same message over and over
raise ContextWindowExceededError(
"Not enough messages to compress for summarization after determining cutoff",
details={
"num_candidate_messages": len(message_sequence_to_summarize),
"num_total_messages": len(self.messages),
"num_total_messages": len(in_context_messages_openai),
"preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
},
)
else:
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self._messages)}")
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(in_context_messages)}")
# We can't do summarize logic properly if context_window is undefined
if self.agent_state.llm_config.context_window is None:
@@ -1253,118 +904,33 @@ class Agent(BaseAgent):
printd(f"Got summary: {summary}")
# Metadata that's useful for the agent to see
all_time_message_count = self.messages_total
remaining_message_count = len(self.messages[cutoff:])
all_time_message_count = self.message_manager.size(agent_id=self.agent_state.id, actor=self.user)
remaining_message_count = len(in_context_messages_openai[cutoff:])
hidden_message_count = all_time_message_count - remaining_message_count
summary_message_count = len(message_sequence_to_summarize)
summary_message = package_summarize_message(summary, summary_message_count, hidden_message_count, all_time_message_count)
printd(f"Packaged into message: {summary_message}")
prior_len = len(self.messages)
self._trim_messages(cutoff)
prior_len = len(in_context_messages_openai)
self.agent_state = self.agent_manager.trim_older_in_context_messages(cutoff, agent_id=self.agent_state.id, actor=self.user)
packed_summary_message = {"role": "user", "content": summary_message}
self._prepend_to_messages(
[
self.agent_state = self.agent_manager.prepend_to_in_context_messages(
messages=[
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict=packed_summary_message,
)
]
],
agent_id=self.agent_state.id,
actor=self.user,
)
# reset alert
self.agent_alerted_about_memory_pressure = False
printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}")
def _swap_system_message_in_buffer(self, new_system_message: str):
    """Update the system message (NOT prompt) of the Agent (requires updating the internal buffer)"""
    assert isinstance(new_system_message, str)
    # Wrap the raw string into a proper system-role Message object
    replacement = Message.dict_to_message(
        agent_id=self.agent_state.id,
        user_id=self.agent_state.created_by_id,
        model=self.model,
        openai_message_dict={"role": "system", "content": new_system_message},
    )
    assert replacement.role == "system", replacement
    assert self._messages[0].role == "system", self._messages
    # Persist the new system message, then swap it in at index 0
    self.message_manager.create_message(replacement, actor=self.user)
    self._messages = [replacement, *self._messages[1:]]  # swap index 0 (system)
def rebuild_system_prompt(self, force=False, update_timestamp=True):
    """Rebuilds the system message with the latest memory object and any shared memory block updates

    Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
    Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages

    Args:
        force: rebuild even if the compiled core memory is unchanged.
        update_timestamp: refresh the "last modified" stamp; keep False for
            e.g. a pure system-prompt swap.
    """
    curr_system_message = self.messages[0]  # this is the system + memory bank, not just the system prompt

    # note: we only update the system prompt if the core memory is changed
    # this means that the archival/recall memory statistics may be somewhat out of date
    curr_memory_str = self.agent_state.memory.compile()
    if curr_memory_str in curr_system_message["content"] and not force:
        # NOTE: could this cause issues if a block is removed? (substring match would still work)
        printd(f"Memory hasn't changed, skipping system prompt rebuild")
        return

    # If the memory didn't update, we probably don't want to update the timestamp inside
    # For example, if we're doing a system prompt swap, this should probably be False
    if update_timestamp:
        memory_edit_timestamp = get_utc_time()
    else:
        # NOTE: a bit of a hack - we pull the timestamp from the message created_by
        memory_edit_timestamp = self._messages[0].created_at

    # update memory (TODO: potentially update recall/archival stats separately)
    new_system_message_str = compile_system_message(
        agent_id=self.agent_state.id,
        system_prompt=self.agent_state.system,
        in_context_memory=self.agent_state.memory,
        in_context_memory_last_edit=memory_edit_timestamp,
        actor=self.user,
        agent_manager=self.agent_manager,
        message_manager=self.message_manager,
        user_defined_variables=None,
        append_icm_if_missing=True,
    )

    new_system_message = {
        "role": "system",
        "content": new_system_message_str,
    }

    diff = united_diff(curr_system_message["content"], new_system_message["content"])
    if len(diff) > 0:  # there was a diff
        printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")

        # Swap the system message out (only if there is a diff)
        self._swap_system_message_in_buffer(new_system_message=new_system_message_str)
        # Sanity check: buffer index 0 must now hold the freshly compiled content
        assert self.messages[0]["content"] == new_system_message["content"], (
            self.messages[0]["content"],
            new_system_message["content"],
        )
def update_system_prompt(self, new_system_prompt: str):
    """Update the system prompt of the agent (requires rebuilding the memory block if there's a difference)"""
    assert isinstance(new_system_prompt, str)

    # No-op when the prompt is unchanged
    if new_system_prompt == self.agent_state.system:
        return

    self.agent_state.system = new_system_prompt

    # updating the system prompt requires rebuilding the memory block inside the compiled system message
    self.rebuild_system_prompt(force=True, update_timestamp=False)

    # make sure to persist the change
    self.update_state()
printd(f"Ran summarizer, messages length {prior_len} -> {len(in_context_messages_openai)}")
def add_function(self, function_name: str) -> str:
# TODO: refactor
@@ -1374,20 +940,6 @@ class Agent(BaseAgent):
# TODO: refactor
raise NotImplementedError
def update_state(self) -> AgentState:
    """Mirror the message buffer's IDs onto agent_state and return the updated state."""
    # TODO: this should be removed and self._messages should be moved into self.agent_state.in_context_messages
    buffer_ids = [msg.id for msg in self._messages]

    # Drop (and report) any IDs that are not strings
    if any(not isinstance(m_id, str) for m_id in buffer_ids):
        warnings.warn(f"Non-string message IDs found in agent state: {buffer_ids}")
        buffer_ids = [m_id for m_id in buffer_ids if isinstance(m_id, str)]

    # override any fields that may have been updated
    self.agent_state.message_ids = buffer_ids

    return self.agent_state
def migrate_embedding(self, embedding_config: EmbeddingConfig):
"""Migrate the agent to a new embedding"""
# TODO: archival memory
@@ -1421,123 +973,6 @@ class Agent(BaseAgent):
f"Attached data source {source.name} to agent {self.agent_state.name}.",
)
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
    """Update the details of a message associated with an agent"""
    # Delegate persistence to the message manager and hand back the saved copy
    return self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=self.user)
# TODO(sarah): should we be creating a new message here, or just editing a message?
def rethink_message(self, new_thought: str) -> Message:
    """Rethink / update the last message"""
    # Walk backwards until we hit the most recent assistant message
    for idx in range(len(self.messages) - 1, 0, -1):
        candidate = self._messages[idx]
        if candidate.role != MessageRole.assistant:
            continue
        updated = self.update_message(
            message_id=candidate.id,
            request=MessageUpdate(
                text=new_thought,
            ),
        )
        # Reload the buffer so it reflects the persisted edit
        self.refresh_message_buffer()
        return updated
    raise ValueError(f"No assistant message found to update")
# TODO(sarah): should we be creating a new message here, or just editing a message?
def rewrite_message(self, new_text: str) -> Message:
    """Rewrite / update the send_message text on the last message

    Finds the most recent assistant message carrying tool calls, verifies the
    first tool call is `send_message`, replaces its `message` argument with
    *new_text*, persists the edit, and refreshes the buffer.

    Raises:
        ValueError: if no suitable assistant send_message call is found or
            its arguments are malformed.
    """
    # Walk backwards through the messages until we find an assistant message
    for x in range(len(self._messages) - 1, 0, -1):
        if self._messages[x].role == MessageRole.assistant:
            # Get the current message content
            message_obj = self._messages[x]

            # The rewrite target is the output of send_message
            # NOTE: an assistant message *without* tool calls is skipped
            # silently and the walk continues to older messages
            if message_obj.tool_calls is not None and len(message_obj.tool_calls) > 0:

                # Check that we hit an assistant send_message call
                name_string = message_obj.tool_calls[0].function.name
                if name_string is None or name_string != "send_message":
                    raise ValueError("Assistant missing send_message function call")

                args_string = message_obj.tool_calls[0].function.arguments
                if args_string is None:
                    raise ValueError("Assistant missing send_message function arguments")

                args_json = json_loads(args_string)
                if "message" not in args_json:
                    raise ValueError("Assistant missing send_message message argument")

                # Once we found our target, rewrite it
                args_json["message"] = new_text
                new_args_string = json_dumps(args_json)
                message_obj.tool_calls[0].function.arguments = new_args_string

                # Write the update to the DB
                updated_message = self.update_message(
                    message_id=message_obj.id,
                    request=MessageUpdate(
                        tool_calls=message_obj.tool_calls,
                    ),
                )
                self.refresh_message_buffer()
                return updated_message

    raise ValueError("No assistant message found to update")
def pop_message(self, count: int = 1) -> List[Message]:
    """Pop the last N messages from the agent's memory.

    Removes up to ``count`` messages from the tail of the in-memory buffer
    and deletes each one from recall storage. If a recall deletion fails,
    the message is restored to the buffer and popping stops early.
    """
    MIN_MESSAGES = 2  # always keep at least the system message + one more
    total = len(self._messages)
    if total <= MIN_MESSAGES:
        raise ValueError(f"Agent only has {total} messages in stack, none left to pop")
    if total - count < MIN_MESSAGES:
        raise ValueError(f"Agent only has {total} messages in stack, cannot pop more than {total - MIN_MESSAGES}")

    popped_messages = []
    for _ in range(min(count, len(self._messages))):
        # remove the message from the internal state of the agent
        removed = self._messages.pop()
        try:
            # then also remove it from recall storage
            self.message_manager.delete_message_by_id(removed.id, actor=self.user)
            popped_messages.append(removed)
        except Exception as e:
            # Recall deletion failed: put the message back and stop here so
            # the buffer and recall storage stay consistent.
            warnings.warn(f"Error deleting message {removed.id} from recall memory: {e}")
            self._messages.append(removed)
            break

    return popped_messages
def pop_until_user(self) -> List[Message]:
    """Pop messages from the buffer until the most recent user message is on top.

    Returns:
        The popped messages (newest first) as a flat List[Message].

    Raises:
        ValueError: If the buffer contains no user message.
    """
    if MessageRole.user not in [msg.role for msg in self._messages]:
        raise ValueError("No user message found in buffer")

    popped_messages = []
    while len(self._messages) > 0:
        if self._messages[-1].role == MessageRole.user:
            # we want to pop up to the last user message
            return popped_messages
        else:
            # BUG FIX: pop_message returns List[Message]; the original
            # appended that list, yielding List[List[Message]] despite the
            # declared List[Message] return type. Use extend to flatten.
            popped_messages.extend(self.pop_message(count=1))

    raise ValueError("No user message found in buffer")
def retry_message(self) -> List[Message]:
    """Retry / regenerate the last message.

    Rolls the buffer back to the most recent user message, pops it, and
    re-runs the agent step with the same user input.
    """
    # Drop everything after (and then including) the last user message.
    self.pop_until_user()
    last_user_message = self.pop_message(count=1)[0]
    assert last_user_message.text is not None, "User message text is None"

    # Re-run the step with the original user input.
    step_response = self.step_user_message(user_message_str=last_user_message.text)
    new_messages = step_response.messages
    assert new_messages is not None
    assert all(isinstance(m, Message) for m in new_messages), "step() returned non-Message objects"
    return new_messages
def get_context_window(self) -> ContextWindowOverview:
"""Get the context window of the agent"""
@@ -1546,24 +981,28 @@ class Agent(BaseAgent):
core_memory = self.agent_state.memory.compile()
num_tokens_core_memory = count_tokens(core_memory)
# Grab the in-context messages
# conversion of messages to OpenAI dict format, which is passed to the token counter
messages_openai_format = self.messages
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
# Check if there's a summary message in the message queue
if (
len(self._messages) > 1
and self._messages[1].role == MessageRole.user
and isinstance(self._messages[1].text, str)
len(in_context_messages) > 1
and in_context_messages[1].role == MessageRole.user
and isinstance(in_context_messages[1].text, str)
# TODO remove hardcoding
and "The following is a summary of the previous " in self._messages[1].text
and "The following is a summary of the previous " in in_context_messages[1].text
):
# Summary message exists
assert self._messages[1].text is not None
summary_memory = self._messages[1].text
num_tokens_summary_memory = count_tokens(self._messages[1].text)
assert in_context_messages[1].text is not None
summary_memory = in_context_messages[1].text
num_tokens_summary_memory = count_tokens(in_context_messages[1].text)
# with a summary message, the real messages start at index 2
num_tokens_messages = (
num_tokens_from_messages(messages=messages_openai_format[2:], model=self.model) if len(messages_openai_format) > 2 else 0
num_tokens_from_messages(messages=in_context_messages_openai[2:], model=self.model)
if len(in_context_messages_openai) > 2
else 0
)
else:
@@ -1571,17 +1010,17 @@ class Agent(BaseAgent):
num_tokens_summary_memory = 0
# with no summary message, the real messages start at index 1
num_tokens_messages = (
num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
num_tokens_from_messages(messages=in_context_messages_openai[1:], model=self.model)
if len(in_context_messages_openai) > 1
else 0
)
agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id)
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
external_memory_summary = compile_memory_metadata_block(
actor=self.user,
agent_id=self.agent_state.id,
memory_edit_timestamp=get_utc_time(), # dummy timestamp
agent_manager=self.agent_manager,
message_manager=self.message_manager,
memory_edit_timestamp=get_utc_time(),
previous_message_count=self.message_manager.size(actor=self.user, agent_id=self.agent_state.id),
archival_memory_size=self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id),
)
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
@@ -1606,7 +1045,7 @@ class Agent(BaseAgent):
return ContextWindowOverview(
# context window breakdown (in messages)
num_messages=len(self._messages),
num_messages=len(in_context_messages),
num_archival_memory=agent_manager_passage_size,
num_recall_memory=message_manager_size,
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
@@ -1621,7 +1060,7 @@ class Agent(BaseAgent):
num_tokens_summary_memory=num_tokens_summary_memory,
summary_memory=summary_memory,
num_tokens_messages=num_tokens_messages,
messages=self._messages,
messages=in_context_messages,
# related to functions
num_tokens_functions_definitions=num_tokens_available_functions_definitions,
functions_definitions=available_functions_definitions,
@@ -1635,7 +1074,6 @@ class Agent(BaseAgent):
def save_agent(agent: Agent):
"""Save agent to metadata store"""
agent.update_state()
agent_state = agent.agent_state
assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"

View File

@@ -63,8 +63,8 @@ class ChatOnlyAgent(Agent):
conversation_persona_block_new = Block(
name="chat_agent_persona_new", label="chat_agent_persona_new", value=conversation_persona_block.value, limit=2000
)
recent_convo = "".join([str(message) for message in self.messages[3:]])[-self.recent_convo_limit :]
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
recent_convo = "".join([str(message) for message in in_context_messages[3:]])[-self.recent_convo_limit :]
conversation_messages_block = Block(
name="conversation_block", label="conversation_block", value=recent_convo, limit=self.recent_convo_limit
)

View File

@@ -2234,7 +2234,7 @@ class LocalClient(AbstractClient):
"""
# TODO: add the abilitty to reset linked block_ids
self.interface.clear()
agent_state = self.server.update_agent(
agent_state = self.server.agent_manager.update_agent(
agent_id,
UpdateAgent(
name=name,
@@ -2262,7 +2262,7 @@ class LocalClient(AbstractClient):
List[Tool]: A list of Tool objs
"""
self.interface.clear()
return self.server.get_tools_from_agent(agent_id=agent_id, user_id=self.user_id)
return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user).tools
def add_tool_to_agent(self, agent_id: str, tool_id: str):
"""
@@ -2276,7 +2276,7 @@ class LocalClient(AbstractClient):
agent_state (AgentState): State of the updated agent
"""
self.interface.clear()
agent_state = self.server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id)
agent_state = self.server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user)
return agent_state
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
@@ -2291,7 +2291,7 @@ class LocalClient(AbstractClient):
agent_state (AgentState): State of the updated agent
"""
self.interface.clear()
agent_state = self.server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id)
agent_state = self.server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user)
return agent_state
def rename_agent(self, agent_id: str, new_name: str):
@@ -2426,7 +2426,7 @@ class LocalClient(AbstractClient):
Returns:
messages (List[Message]): List of in-context messages
"""
return self.server.get_in_context_messages(agent_id=agent_id, actor=self.user)
return self.server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=self.user)
# agent interactions
@@ -3075,7 +3075,7 @@ class LocalClient(AbstractClient):
agent_id (str): ID of the agent
memory_id (str): ID of the memory
"""
self.server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=self.user)
self.server.delete_archival_memory(memory_id=memory_id, actor=self.user)
def get_archival_memory(
self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000

View File

@@ -194,46 +194,6 @@ def run_agent_loop(
print(f"Current model: {letta_agent.agent_state.llm_config.model}")
continue
elif user_input.lower() == "/pop" or user_input.lower().startswith("/pop "):
# Check if there's an additional argument that's an integer
command = user_input.strip().split()
pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3
try:
popped_messages = letta_agent.pop_message(count=pop_amount)
except ValueError as e:
print(f"Error popping messages: {e}")
continue
elif user_input.lower() == "/retry":
print(f"Retrying for another answer...")
try:
letta_agent.retry_message()
except Exception as e:
print(f"Error retrying message: {e}")
continue
elif user_input.lower() == "/rethink" or user_input.lower().startswith("/rethink "):
if len(user_input) < len("/rethink "):
print("Missing text after the command")
continue
try:
letta_agent.rethink_message(new_thought=user_input[len("/rethink ") :].strip())
except Exception as e:
print(f"Error rethinking message: {e}")
continue
elif user_input.lower() == "/rewrite" or user_input.lower().startswith("/rewrite "):
if len(user_input) < len("/rewrite "):
print("Missing text after the command")
continue
text = user_input[len("/rewrite ") :].strip()
try:
letta_agent.rewrite_message(new_text=text)
except Exception as e:
print(f"Error rewriting message: {e}")
continue
elif user_input.lower() == "/summarize":
try:
letta_agent.summarize_messages_inplace()
@@ -319,42 +279,6 @@ def run_agent_loop(
questionary.print(cmd, "bold")
questionary.print(f" {desc}")
continue
elif user_input.lower().startswith("/systemswap"):
if len(user_input) < len("/systemswap "):
print("Missing new system prompt after the command")
continue
old_system_prompt = letta_agent.system
new_system_prompt = user_input[len("/systemswap ") :].strip()
# Show warning and prompts to user
typer.secho(
"\nWARNING: You are about to change the system prompt.",
# fg=typer.colors.BRIGHT_YELLOW,
bold=True,
)
typer.secho(
f"\nOld system prompt:\n{old_system_prompt}",
fg=typer.colors.RED,
bold=True,
)
typer.secho(
f"\nNew system prompt:\n{new_system_prompt}",
fg=typer.colors.GREEN,
bold=True,
)
# Ask for confirmation
confirm = questionary.confirm("Do you want to proceed with the swap?").ask()
if confirm:
letta_agent.update_system_prompt(new_system_prompt=new_system_prompt)
print("System prompt updated successfully.")
else:
print("System prompt swap cancelled.")
continue
else:
print(f"Unrecognized command: {user_input}")
continue

View File

@@ -129,9 +129,8 @@ class OfflineMemoryAgent(Agent):
# extras
first_message_verify_mono: bool = False,
max_memory_rethinks: int = 10,
initial_message_sequence: Optional[List[Message]] = None,
):
super().__init__(interface, agent_state, user, initial_message_sequence=initial_message_sequence)
super().__init__(interface, agent_state, user)
self.first_message_verify_mono = first_message_verify_mono
self.max_memory_rethinks = max_memory_rethinks

View File

@@ -104,7 +104,7 @@ def get_agent_context_window(
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)
return server.get_agent_context_window(agent_id=agent_id, actor=actor)
class CreateAgentRequest(CreateAgent):
@@ -138,7 +138,7 @@ def update_agent(
):
"""Update an exsiting agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.update_agent(agent_id, update_agent, actor=actor)
return server.agent_manager.update_agent(agent_id=agent_id, agent_update=update_agent, actor=actor)
@router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="get_tools_from_agent")
@@ -149,7 +149,7 @@ def get_tools_from_agent(
):
"""Get tools from an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_tools_from_agent(agent_id=agent_id, user_id=actor.id)
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).tools
@router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent")
@@ -161,7 +161,7 @@ def add_tool_to_agent(
):
"""Add tools to an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, user_id=actor)
@router.patch("/{agent_id}/remove-tool/{tool_id}", response_model=AgentState, operation_id="remove_tool_from_agent")
@@ -173,7 +173,7 @@ def remove_tool_from_agent(
):
"""Add tools to an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
@router.get("/{agent_id}", response_model=AgentState, operation_id="get_agent")
@@ -232,7 +232,7 @@ def get_agent_in_context_messages(
Retrieve the messages in the context of a specific agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_in_context_messages(agent_id=agent_id, actor=actor)
return server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor)
# TODO: remove? can also get with agent blocks
@@ -429,7 +429,7 @@ def delete_agent_archival_memory(
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=actor)
server.delete_archival_memory(memory_id=memory_id, actor=actor)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
@@ -479,8 +479,9 @@ def update_message(
"""
Update the details of a message associated with an agent.
"""
# TODO: Get rid of agent_id here, it's not really relevant
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request, actor=actor)
return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
@router.post(

View File

@@ -40,14 +40,14 @@ from letta.providers import (
VLLMChatCompletionsProvider,
VLLMCompletionsProvider,
)
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
from letta.schemas.agent import AgentState, AgentType, CreateAgent
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
# openai schemas
from letta.schemas.enums import JobStatus
from letta.schemas.job import Job, JobUpdate
from letta.schemas.letta_message import ToolReturnMessage, LettaMessage
from letta.schemas.letta_message import LettaMessage, ToolReturnMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import (
ArchivalMemorySummary,
@@ -376,25 +376,6 @@ class SyncServer(Server):
)
)
def initialize_agent(self, agent_id, actor, interface: Union[AgentInterface, None] = None, initial_message_sequence=None) -> Agent:
"""Initialize an agent from the database"""
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
interface = interface or self.default_interface_factory()
if agent_state.agent_type == AgentType.memgpt_agent:
agent = Agent(agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence)
elif agent_state.agent_type == AgentType.offline_memory_agent:
agent = OfflineMemoryAgent(
agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence
)
else:
assert initial_message_sequence is None, f"Initial message sequence is not supported for O1Agents"
agent = O1Agent(agent_state=agent_state, interface=interface, user=actor)
# Persist to agent
save_agent(agent)
return agent
def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
"""Updated method to load agents from persisted storage"""
agent_lock = self.per_agent_lock_manager.get_lock(agent_id)
@@ -413,11 +394,6 @@ class SyncServer(Server):
else:
raise ValueError(f"Invalid agent type {agent_state.agent_type}")
# Rebuild the system prompt - may be linked to new blocks now
agent.rebuild_system_prompt()
# Persist to agent
save_agent(agent)
return agent
def _step(
@@ -456,7 +432,7 @@ class SyncServer(Server):
)
# save agent after step
save_agent(letta_agent)
# save_agent(letta_agent)
except Exception as e:
logger.error(f"Error in server._step: {e}")
@@ -790,129 +766,23 @@ class SyncServer(Server):
"""Create a new agent using a config"""
# Invoke manager
agent_state = self.agent_manager.create_agent(
return self.agent_manager.create_agent(
agent_create=request,
actor=actor,
)
# create the agent object
if request.initial_message_sequence is not None:
# init_messages = [Message(user_id=user_id, agent_id=agent_state.id, role=message.role, text=message.text) for message in request.initial_message_sequence]
init_messages = []
for message in request.initial_message_sequence:
if message.role == MessageRole.user:
packed_message = system.package_user_message(
user_message=message.text,
)
elif message.role == MessageRole.system:
packed_message = system.package_system_message(
system_message=message.text,
)
else:
raise ValueError(f"Invalid message role: {message.role}")
init_messages.append(Message(role=message.role, text=packed_message, agent_id=agent_state.id))
# init_messages = [Message.dict_to_message(user_id=user_id, agent_id=agent_state.id, openai_message_dict=message.model_dump()) for message in request.initial_message_sequence]
else:
init_messages = None
# initialize the agent (generates initial message list with system prompt)
if interface is None:
interface = self.default_interface_factory()
self.initialize_agent(agent_id=agent_state.id, interface=interface, initial_message_sequence=init_messages, actor=actor)
in_memory_agent_state = self.agent_manager.get_agent_by_id(agent_state.id, actor=actor)
return in_memory_agent_state
# TODO: This is not good!
# TODO: Ideally, this should ALL be handled by the ORM
# TODO: The main blocker here IS the _message updates
def update_agent(
self,
agent_id: str,
request: UpdateAgent,
actor: User,
) -> AgentState:
"""Update the agents core memory block, return the new state"""
# Update agent state in the db first
agent_state = self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
# TODO: Everything below needs to get removed, no updating anything in memory
# update the system prompt
if request.system:
letta_agent.update_system_prompt(request.system)
# update in-context messages
if request.message_ids:
# This means the user is trying to change what messages are in the message buffer
# Internally this requires (1) pulling from recall,
# then (2) setting the attributes ._messages and .state.message_ids
letta_agent.set_message_buffer(message_ids=request.message_ids)
letta_agent.update_state()
return agent_state
def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
"""Get tools from an existing agent"""
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
return letta_agent.agent_state.tools
def add_tool_to_agent(
self,
agent_id: str,
tool_id: str,
user_id: str,
):
"""Add tools from an existing agent"""
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
return agent_state
def remove_tool_from_agent(
self,
agent_id: str,
tool_id: str,
user_id: str,
):
"""Remove tools from an existing agent"""
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
return agent_state
# convert name->id
# TODO: These can be moved to agent_manager
def get_agent_memory(self, agent_id: str, actor: User) -> Memory:
"""Return the memory of an agent (core memory)"""
agent = self.load_agent(agent_id=agent_id, actor=actor)
return agent.agent_state.memory
return self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).memory
def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
agent = self.load_agent(agent_id=agent_id, actor=actor)
return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id))
def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary:
agent = self.load_agent(agent_id=agent_id, actor=actor)
return RecallMemorySummary(size=len(agent.message_manager))
def get_in_context_messages(self, agent_id: str, actor: User) -> List[Message]:
"""Get the in-context messages in the agent's memory"""
# Get the agent object (loaded in memory)
agent = self.load_agent(agent_id=agent_id, actor=actor)
return agent._messages
return RecallMemorySummary(size=self.message_manager.size(actor=actor, agent_id=agent_id))
def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str] = None, limit: int = 50) -> List[Passage]:
"""Paginated query of all messages in agent archival memory"""
@@ -947,24 +817,17 @@ class SyncServer(Server):
def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
# Insert into archival memory
passages = self.passage_manager.insert_passage(
agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor
)
save_agent(letta_agent)
# TODO: @mindy look at moving this to agent_manager to avoid above extra call
passages = self.passage_manager.insert_passage(agent_state=agent_state, agent_id=agent_id, text=memory_contents, actor=actor)
return passages
def delete_archival_memory(self, agent_id: str, memory_id: str, actor: User):
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
# Delete by ID
def delete_archival_memory(self, memory_id: str, actor: User):
# TODO check if it exists first, and throw error if not
letta_agent.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
# TODO: @mindy make this return the deleted passage instead
self.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
# TODO: return archival memory
@@ -1042,9 +905,8 @@ class SyncServer(Server):
# update the block
self.block_manager.update_block(block_id=block.id, block_update=BlockUpdate(value=value), actor=actor)
# load agent
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
return letta_agent.agent_state.memory
# rebuild system prompt for agent, potentially changed
return self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor).memory
def delete_source(self, source_id: str, actor: User):
"""Delete a data source"""
@@ -1214,36 +1076,11 @@ class SyncServer(Server):
return success
def update_agent_message(self, agent_id: str, message_id: str, request: MessageUpdate, actor: User) -> Message:
def update_agent_message(self, message_id: str, request: MessageUpdate, actor: User) -> Message:
"""Update the details of a message associated with an agent"""
# Get the current message
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
response = letta_agent.update_message(message_id=message_id, request=request)
save_agent(letta_agent)
return response
def rewrite_agent_message(self, agent_id: str, new_text: str, actor: User) -> Message:
# Get the current message
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
response = letta_agent.rewrite_message(new_text=new_text)
save_agent(letta_agent)
return response
def rethink_agent_message(self, agent_id: str, new_thought: str, actor: User) -> Message:
# Get the current message
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
response = letta_agent.rethink_message(new_thought=new_thought)
save_agent(letta_agent)
return response
def retry_agent_message(self, agent_id: str, actor: User) -> List[Message]:
# Get the current message
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
response = letta_agent.retry_message()
save_agent(letta_agent)
return response
return self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
def get_organization_or_default(self, org_id: Optional[str]) -> Organization:
"""Get the organization object for org_id if it exists, otherwise return the default organization object"""
@@ -1331,15 +1168,7 @@ class SyncServer(Server):
def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
"""Add a new embedding model"""
def get_agent_context_window(
self,
user_id: str,
agent_id: str,
) -> ContextWindowOverview:
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
# Get the current message
def get_agent_context_window(self, agent_id: str, actor: User) -> ContextWindowOverview:
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
return letta_agent.get_context_window()

View File

@@ -20,6 +20,7 @@ from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
@@ -28,12 +29,17 @@ from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
_process_relationship,
_process_tags,
check_supports_structured_output,
compile_system_message,
derive_system_message,
initialize_message_sequence,
package_initial_message_sequence,
)
from letta.services.message_manager import MessageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.settings import settings
from letta.utils import enforce_types
from letta.utils import enforce_types, get_utc_time, united_diff
logger = get_logger(__name__)
@@ -49,6 +55,7 @@ class AgentManager:
self.block_manager = BlockManager()
self.tool_manager = ToolManager()
self.source_manager = SourceManager()
self.message_manager = MessageManager()
# ======================================================================================================================
# Basic CRUD operations
@@ -64,6 +71,10 @@ class AgentManager:
if not agent_create.llm_config or not agent_create.embedding_config:
raise ValueError("llm_config and embedding_config are required")
# Check tool rules are valid
if agent_create.tool_rules:
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)
# create blocks (note: cannot be linked into the agent_id is created)
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
for create_block in agent_create.memory_blocks:
@@ -88,7 +99,8 @@ class AgentManager:
# Remove duplicates
tool_ids = list(set(tool_ids))
return self._create_agent(
# Create the agent
agent_state = self._create_agent(
name=agent_create.name,
system=system,
agent_type=agent_create.agent_type,
@@ -104,6 +116,35 @@ class AgentManager:
actor=actor,
)
# TODO: See if we can merge this into the above SQL create call for performance reasons
# Generate a sequence of initial messages to put in the buffer
init_messages = initialize_message_sequence(
agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
)
if agent_create.initial_message_sequence is not None:
# We always need the system prompt up front
system_message_obj = PydanticMessage.dict_to_message(
agent_id=agent_state.id,
user_id=agent_state.created_by_id,
model=agent_state.llm_config.model,
openai_message_dict=init_messages[0],
)
# Don't use anything else in the pregen sequence, instead use the provided sequence
init_messages = [system_message_obj]
init_messages.extend(
package_initial_message_sequence(agent_state.id, agent_create.initial_message_sequence, agent_state.llm_config.model, actor)
)
else:
init_messages = [
PydanticMessage.dict_to_message(
agent_id=agent_state.id, user_id=agent_state.created_by_id, model=agent_state.llm_config.model, openai_message_dict=msg
)
for msg in init_messages
]
return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor)
@enforce_types
def _create_agent(
self,
@@ -149,6 +190,16 @@ class AgentManager:
@enforce_types
def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
agent_state = self._update_agent(agent_id=agent_id, agent_update=agent_update, actor=actor)
# Rebuild the system prompt if it's different
if agent_update.system and agent_update.system != agent_state.system:
agent_state = self.rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True, update_timestamp=False)
return agent_state
@enforce_types
def _update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
"""
Update an existing agent.
@@ -247,6 +298,105 @@ class AgentManager:
agent.hard_delete(session)
return agent_state
# ======================================================================================================================
# In Context Messages Management
# ======================================================================================================================
# TODO: There are several assumptions here that are not explicitly checked
# TODO: 1) These message ids are valid
# TODO: 2) These messages are ordered from oldest to newest
# TODO: This can be fixed by having an actual relationship in the ORM for message_ids
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
@enforce_types
def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor)
@enforce_types
def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
@enforce_types
def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState:
    """Rebuilds the system message with the latest memory object and any shared memory block updates

    Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
    Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
    """
    agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)

    # The stored system message is the system prompt *plus* the compiled memory bank
    curr_system_message = self.get_system_message(agent_id=agent_id, actor=actor)
    curr_system_text = curr_system_message.to_openai_dict()["content"]

    # We only rebuild when the compiled core memory changed, so the archival/recall
    # statistics embedded in the header may be somewhat out of date.
    # NOTE: substring matching could false-positive if a memory block was removed.
    curr_memory_str = agent_state.memory.compile()
    if curr_memory_str in curr_system_text and not force:
        logger.info(
            f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
        )
        return agent_state

    # If the memory didn't update, we probably don't want to update the timestamp inside
    # For example, if we're doing a system prompt swap, this should probably be False.
    # When not updating, reuse the old system message's created_at (a bit of a hack).
    memory_edit_timestamp = get_utc_time() if update_timestamp else curr_system_message.created_at

    # update memory (TODO: potentially update recall/archival stats separately)
    new_system_text = compile_system_message(
        system_prompt=agent_state.system,
        in_context_memory=agent_state.memory,
        in_context_memory_last_edit=memory_edit_timestamp,
    )

    diff = united_diff(curr_system_text, new_system_text)
    if not diff:
        # Nothing changed — keep the existing system message
        return agent_state

    logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}")

    # Persist a brand-new system message and swap it in at index 0
    new_system_message = self.message_manager.create_message(
        PydanticMessage.dict_to_message(
            agent_id=agent_id,
            user_id=actor.id,
            model=agent_state.llm_config.model,
            openai_message_dict={"role": "system", "content": new_system_text},
        ),
        actor=actor,
    )
    message_ids = [new_system_message.id] + agent_state.message_ids[1:]  # swap index 0 (system)
    return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
@enforce_types
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
    """Overwrite the agent's in-context message id list (index 0 must remain the system message)."""
    update = UpdateAgent(message_ids=message_ids)
    return self.update_agent(agent_id=agent_id, agent_update=update, actor=actor)
@enforce_types
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Evict the oldest in-context messages while always preserving the system message.

    Args:
        num: Drop boundary — messages at indices 1..num-1 are evicted; the system
            message at index 0 is always kept.
        agent_id: ID of the agent whose context is being trimmed.
        actor: User performing the operation.

    Returns:
        The updated agent state.
    """
    message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
    # BUGFIX guard: with num <= 0 the slice `message_ids[num:]` re-includes index 0,
    # which would duplicate the system message at the front. Treat it as a no-op.
    if num <= 0:
        return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
    new_messages = [message_ids[0]] + message_ids[num:]  # 0 is system message
    return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
@enforce_types
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Persist `messages` and insert their ids directly after the system message."""
    current_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
    created = self.message_manager.create_many_messages(messages, actor=actor)
    # Keep the system message at index 0, then the new messages, then the rest
    updated_ids = [current_ids[0]] + [m.id for m in created] + current_ids[1:]
    return self.set_in_context_messages(agent_id=agent_id, message_ids=updated_ids, actor=actor)
@enforce_types
def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
    """Persist `messages` and append their ids to the end of the agent's context window."""
    created = self.message_manager.create_many_messages(messages, actor=actor)
    existing_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or []
    updated_ids = existing_ids + [m.id for m in created]
    return self.set_in_context_messages(agent_id=agent_id, message_ids=updated_ids, actor=actor)
# ======================================================================================================================
# Source Management
# ======================================================================================================================

View File

@@ -1,10 +1,21 @@
from typing import List, Optional
import datetime
from typing import List, Literal, Optional
from letta import system
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, STRUCTURED_OUTPUT_MODELS
from letta.helpers import ToolRulesSolver
from letta.orm.agent import Agent as AgentModel
from letta.orm.agents_tags import AgentsTags
from letta.orm.errors import NoResultFound
from letta.prompts import gpt_system
from letta.schemas.agent import AgentType
from letta.schemas.agent import AgentState, AgentType
from letta.schemas.enums import MessageRole
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate
from letta.schemas.tool_rule import ToolRule
from letta.schemas.user import User
from letta.system import get_initial_boot_messages, get_login_event
from letta.utils import get_local_time
# Static methods
@@ -88,3 +99,162 @@ def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
raise ValueError(f"Invalid agent type: {agent_type}")
return system
# TODO: This code is kind of wonky and deserves a rewrite
def compile_memory_metadata_block(
    memory_edit_timestamp: datetime.datetime, previous_message_count: int = 0, archival_memory_size: int = 0
) -> str:
    """Render the metadata header that precedes core memory in the system message.

    The header tells the agent when memory was last edited and how many
    out-of-context memories exist in recall/archival storage.
    """
    # Put the timestamp in the local timezone (mimicking get_local_time())
    timestamp_str = memory_edit_timestamp.astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip()

    # Metadata block of info so the agent knows about its out-of-context memories
    header_lines = [
        f"### Memory [last modified: {timestamp_str}]",
        f"{previous_message_count} previous messages between you and the user are stored in recall memory (use functions to access them)",
        f"{archival_memory_size} total memories you created are stored in archival memory (use functions to access them)",
        "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
    ]
    return "\n".join(header_lines)
def compile_system_message(
    system_prompt: str,
    in_context_memory: Memory,
    in_context_memory_last_edit: datetime.datetime,  # TODO move this inside of BaseMemory?
    user_defined_variables: Optional[dict] = None,
    append_icm_if_missing: bool = True,
    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
) -> str:
    """Prepare the final/full system message that will be fed into the LLM API

    The base system message may be templated, in which case we need to render the variables.
    The following are reserved variables:
      - CORE_MEMORY: the in-context memory of the LLM
    """
    # TODO eventually support the user defining their own variables to inject
    if user_defined_variables is not None:
        raise NotImplementedError
    variables = {}

    # The memory variable is reserved; user-supplied variables may never shadow it
    if IN_CONTEXT_MEMORY_KEYWORD in variables:
        raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")

    # TODO should this all put into the memory.__repr__ function?
    metadata_block = compile_memory_metadata_block(
        memory_edit_timestamp=in_context_memory_last_edit,
        previous_message_count=previous_message_count,
        archival_memory_size=archival_memory_size,
    )
    # Inject "metadata header + compiled core memory" as the reserved template variable
    variables[IN_CONTEXT_MEMORY_KEYWORD] = metadata_block + "\n" + in_context_memory.compile()

    if template_format != "f-string":
        # TODO support for mustache and jinja2
        raise NotImplementedError(template_format)

    # Catch the special case where the system prompt is unformatted
    if append_icm_if_missing:
        memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
        if memory_variable_string not in system_prompt:
            # Append the placeholder so memory is still injected even without it
            system_prompt += "\n" + memory_variable_string

    # render the variables using the built-in templater
    try:
        return system_prompt.format_map(variables)
    except Exception as e:
        raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
def initialize_message_sequence(
    agent_state: AgentState,
    memory_edit_timestamp: Optional[datetime.datetime] = None,
    include_initial_boot_message: bool = True,
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
) -> List[dict]:
    """Build the initial openai-style message list for a fresh agent.

    Args:
        agent_state: Agent whose system prompt, memory, and llm_config are used.
        memory_edit_timestamp: Timestamp shown as "last modified" in the memory header;
            defaults to the current UTC time.
        include_initial_boot_message: Whether to inject the canned boot message sequence.
        previous_message_count: Recall-memory size reported in the memory header.
        archival_memory_size: Archival-memory size reported in the memory header.

    Returns:
        A list of {"role": ..., "content": ...} dicts: the system message, optional
        boot messages, then the login event as the first user message.
    """
    if memory_edit_timestamp is None:
        # BUGFIX: was `get_local_time()`, which returns a formatted *string*, while
        # compile_system_message -> compile_memory_metadata_block is annotated to take a
        # datetime.datetime and calls `.astimezone()` on it. Use an aware datetime.
        memory_edit_timestamp = datetime.datetime.now(datetime.timezone.utc)

    full_system_message = compile_system_message(
        system_prompt=agent_state.system,
        in_context_memory=agent_state.memory,
        in_context_memory_last_edit=memory_edit_timestamp,
        user_defined_variables=None,
        append_icm_if_missing=True,
        previous_message_count=previous_message_count,
        archival_memory_size=archival_memory_size,
    )
    first_user_message = get_login_event()  # event letting Letta know the user just logged in

    if not include_initial_boot_message:
        return [
            {"role": "system", "content": full_system_message},
            {"role": "user", "content": first_user_message},
        ]

    # gpt-3.5 models get a slightly different canned boot sequence
    if agent_state.llm_config.model is not None and "gpt-3.5" in agent_state.llm_config.model:
        initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
    else:
        initial_boot_messages = get_initial_boot_messages("startup_with_send_message")

    return (
        [{"role": "system", "content": full_system_message}]
        + initial_boot_messages
        + [{"role": "user", "content": first_user_message}]
    )
def package_initial_message_sequence(
    agent_id: str, initial_message_sequence: List[MessageCreate], model: str, actor: User
) -> List[Message]:
    """Wrap raw initial messages in Letta's packaged format and return persisted-ready Message objects."""
    packaged_messages = []
    for message_create in initial_message_sequence:
        # Only user and system roles are supported in the initial sequence
        if message_create.role == MessageRole.user:
            packed_message = system.package_user_message(user_message=message_create.text)
        elif message_create.role == MessageRole.system:
            packed_message = system.package_system_message(system_message=message_create.text)
        else:
            raise ValueError(f"Invalid message role: {message_create.role}")

        packaged_messages.append(
            Message(
                role=message_create.role,
                text=packed_message,
                organization_id=actor.organization_id,
                agent_id=agent_id,
                model=model,
            )
        )
    return packaged_messages
def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) -> bool:
    """Return whether `model` supports structured output.

    Raises:
        ValueError: if the model lacks structured output support but more than one
            initial tool rule is configured (an unsupported combination).
    """
    if model in STRUCTURED_OUTPUT_MODELS:
        return True
    if len(ToolRulesSolver(tool_rules=tool_rules).init_tool_rules) > 1:
        raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.")
    return False

View File

@@ -28,6 +28,21 @@ class MessageManager:
except NoResultFound:
return None
@enforce_types
def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
    """Fetch messages by ID and return them in the requested order.

    Args:
        message_ids: IDs to fetch; the returned list preserves this order.
        actor: User whose organization scopes the lookup.

    Raises:
        NoResultFound: if any requested ID is missing or not visible to the actor's org.
    """
    with self.session_maker() as session:
        results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id)

        if len(results) != len(message_ids):
            # Idiom fix: set comprehension instead of set([...]) (ruff C403)
            missing_ids = set(message_ids) - {r.id for r in results}
            raise NoResultFound(f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={missing_ids}")

        # Re-order the DB results to match the order of `message_ids`
        result_dict = {msg.id: msg.to_pydantic() for msg in results}
        return [result_dict[msg_id] for msg_id in message_ids]
@enforce_types
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
"""Create a new message."""

View File

@@ -15,20 +15,16 @@ from letta.config import LettaConfig
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.embeddings import embedding_model
from letta.errors import (
InvalidToolCallError,
InvalidInnerMonologueError,
MissingToolCallError,
InvalidToolCallError,
MissingInnerMonologueError,
MissingToolCallError,
)
from letta.llm_api.llm_api_tools import create
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message import (
ToolCallMessage,
ReasoningMessage,
LettaMessage,
)
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, ToolCallMessage
from letta.schemas.letta_response import LettaResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
@@ -78,7 +74,13 @@ def setup_agent(
memory = ChatMemory(human=memory_human_str, persona=memory_persona_str)
agent_state = client.create_agent(
name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules, include_base_tools=include_base_tools,
name=agent_uuid,
llm_config=llm_config,
embedding_config=embedding_config,
memory=memory,
tool_ids=tool_ids,
tool_rules=tool_rules,
include_base_tools=include_base_tools,
)
return agent_state
@@ -105,12 +107,13 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet
agent_state = setup_agent(client, filename)
full_agent_state = client.get_agent(agent_state.id)
messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state.id, actor=client.user)
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
response = create(
llm_config=agent_state.llm_config,
user_id=str(uuid.UUID(int=1)), # dummy user_id
messages=agent._messages,
messages=messages,
functions=[t.json_schema for t in agent.agent_state.tools],
)
@@ -412,9 +415,7 @@ def assert_invoked_function_call(messages: Sequence[LettaMessage], function_name
# Found it, do nothing
return
raise MissingToolCallError(
messages=messages, explanation=f"No messages were found invoking function call with name: {function_name}"
)
raise MissingToolCallError(messages=messages, explanation=f"No messages were found invoking function call with name: {function_name}")
def assert_inner_monologue_is_present_and_valid(messages: List[LettaMessage]) -> None:

View File

@@ -1,13 +1,16 @@
import json
import os
import uuid
from typing import List
import pytest
from letta import create_client
from letta.agent import Agent
from letta.client.client import LocalClient
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.streaming_interface import StreamingRefreshCLIInterface
from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH
from tests.helpers.utils import cleanup
@@ -16,6 +19,110 @@ from tests.helpers.utils import cleanup
LLM_CONFIG_DIR = "tests/configs/llm_model_configs"
SUMMARY_KEY_PHRASE = "The following is a summary"
test_agent_name = f"test_client_{str(uuid.uuid4())}"
# TODO: these tests should include looping through LLM providers, since behavior may vary across providers
# TODO: these tests should add function calls into the summarized message sequence
@pytest.fixture(scope="module")
def client():
    """Module-scoped Letta client configured with gpt-4o-mini and OpenAI embeddings."""
    client = create_client()
    # Removed a commented-out duplicate of the next line (dead code)
    client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
    client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))

    yield client
@pytest.fixture(scope="module")
def agent_state(client):
    # Module-scoped agent shared by the summarize tests; deleted on teardown.
    # Generate uuid for agent name for this example
    agent_state = client.create_agent(name=test_agent_name)
    yield agent_state
    client.delete_agent(agent_state.id)
def test_summarize_messages_inplace(client, agent_state, mock_e2b_api_key_none):
    """Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
    # Send a few conversation turns to give the summarizer something to work with
    conversation_turns = [
        "Hey, how's it going? What do you think about this whole shindig",
        "Any thoughts on the meaning of life?",
        "Does the number 42 ring a bell?",
        "Would you be surprised to learn that you're actually conversing with an AI right now?",
    ]
    for turn in conversation_turns:
        response = client.user_message(agent_id=agent_state.id, message=turn).messages
        assert response is not None and len(response) > 0
        print(f"test_summarize: response={response}")

    # reload agent object and invoke summarization directly
    agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
    agent_obj.summarize_messages_inplace()
def test_auto_summarize(client, mock_e2b_api_key_none):
    """Test that the summarizer triggers by itself"""
    # Shrink the context window so summarization triggers after only a few messages
    small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
    small_context_llm_config.context_window = 4000

    small_agent_state = client.create_agent(
        name="small_context_agent",
        llm_config=small_context_llm_config,
    )

    try:

        def summarize_message_exists(messages: List[Message]) -> bool:
            # NOTE: reads `message_count` from the enclosing scope (defined below the
            # def, but bound by the time this closure is first called)
            for message in messages:
                if message.text and "The following is a summary of the previous" in message.text:
                    print(f"Summarize message found after {message_count} messages: \n {message.text}")
                    return True
            return False

        MAX_ATTEMPTS = 10
        message_count = 0
        while True:
            # send a message
            response = client.user_message(
                agent_id=small_agent_state.id,
                message="What is the meaning of life?",
            )
            message_count += 1

            print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------")

            # check if the summarize message is inside the messages
            assert isinstance(client, LocalClient), "Test only works with LocalClient"
            in_context_messages = client.server.agent_manager.get_in_context_messages(agent_id=small_agent_state.id, actor=client.user)
            print("SUMMARY", summarize_message_exists(in_context_messages))
            if summarize_message_exists(in_context_messages):
                break

            if message_count > MAX_ATTEMPTS:
                raise Exception(f"Summarize message not found after {message_count} messages")

    finally:
        # Always clean up the throwaway agent, even if the loop raised
        client.delete_agent(small_agent_state.id)
@pytest.mark.parametrize(
"config_filename",
@@ -69,4 +176,5 @@ def test_summarizer(config_filename):
# Invoke a summarize
letta_agent.summarize_messages_inplace(preserve_last_N_messages=False)
assert SUMMARY_KEY_PHRASE in letta_agent.messages[1]["content"], f"Test failed for config: {config_filename}"
in_context_messages = client.get_in_context_messages(agent_state.id)
assert SUMMARY_KEY_PHRASE in in_context_messages[1].text, f"Test failed for config: {config_filename}"

View File

@@ -18,20 +18,22 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.letta_message import (
AssistantMessage,
LettaMessage,
ReasoningMessage,
SystemMessage,
ToolCallMessage,
ToolReturnMessage,
ReasoningMessage,
LettaMessage,
SystemMessage,
UserMessage,
)
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.services.helpers.agent_manager_helper import initialize_message_sequence
from letta.services.organization_manager import OrganizationManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings
from letta.utils import get_utc_time
from tests.helpers.client_helper import upload_file_using_client
# from tests.utils import create_config
@@ -602,18 +604,11 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
If we pass in a non-empty list, we should get that sequence
If we pass in an empty list, we should get an empty sequence
"""
from letta.agent import initialize_message_sequence
from letta.utils import get_utc_time
# The reference initial message sequence:
reference_init_messages = initialize_message_sequence(
model=agent.llm_config.model,
system=agent.system,
agent_id=agent.id,
memory=agent.memory,
agent_state=agent,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
actor=default_user,
)
# system, login message, send_message test, send_message receipt

View File

@@ -259,10 +259,10 @@ def test_recall_memory(client: LocalClient, agent: AgentState):
assert exists
# get in-context messages
messages = client.get_in_context_messages(agent.id)
in_context_messages = client.get_in_context_messages(agent.id)
exists = False
for m in messages:
if message_str in str(m):
for m in in_context_messages:
if message_str in m.text:
exists = True
assert exists

View File

@@ -370,8 +370,8 @@ def other_tool(server: SyncServer, default_user, default_organization):
@pytest.fixture
def sarah_agent(server: SyncServer, default_user, default_organization):
"""Fixture to create and return a sample agent within the default organization."""
agent_state = server.create_agent(
request=CreateAgent(
agent_state = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="sarah_agent",
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4"),
@@ -385,8 +385,8 @@ def sarah_agent(server: SyncServer, default_user, default_organization):
@pytest.fixture
def charles_agent(server: SyncServer, default_user, default_organization):
"""Fixture to create and return a sample agent within the default organization."""
agent_state = server.create_agent(
request=CreateAgent(
agent_state = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="charles_agent",
memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")],
llm_config=LLMConfig.default_config("gpt-4"),
@@ -503,6 +503,54 @@ def test_create_get_list_agent(server: SyncServer, comprehensive_test_agent_fixt
assert len(list_agents) == 0
def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block):
    # Build a create request that supplies an explicit one-message initial sequence
    memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")]
    create_agent_request = CreateAgent(
        system="test system",
        memory_blocks=memory_blocks,
        llm_config=LLMConfig.default_config("gpt-4"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
        block_ids=[default_block.id],
        tags=["a", "b"],
        description="test_description",
        initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
    )
    agent_state = server.agent_manager.create_agent(
        create_agent_request,
        actor=default_user,
    )
    # Expect exactly 2 persisted messages: the system message + the supplied user message
    assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 2
    init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user)
    # Check that the system appears in the first initial message
    assert create_agent_request.system in init_messages[0].text
    assert create_agent_request.memory_blocks[0].value in init_messages[0].text
    # Check that the second message is the passed in initial message seq
    assert create_agent_request.initial_message_sequence[0].role == init_messages[1].role
    assert create_agent_request.initial_message_sequence[0].text in init_messages[1].text
def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block):
    # No initial_message_sequence supplied — the default boot sequence should be used
    memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")]
    create_agent_request = CreateAgent(
        system="test system",
        memory_blocks=memory_blocks,
        llm_config=LLMConfig.default_config("gpt-4"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
        block_ids=[default_block.id],
        tags=["a", "b"],
        description="test_description",
    )
    agent_state = server.agent_manager.create_agent(
        create_agent_request,
        actor=default_user,
    )
    # Default sequence persists 4 messages for this config
    assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 4
    init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user)
    # Check that the system appears in the first initial message
    assert create_agent_request.system in init_messages[0].text
    assert create_agent_request.memory_blocks[0].value in init_messages[0].text
def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, other_tool, other_source, other_block, default_user):
agent, _ = comprehensive_test_agent_fixture
update_agent_request = UpdateAgent(
@@ -794,8 +842,8 @@ def test_list_agents_by_tags_with_other_filters(server: SyncServer, sarah_agent,
def test_list_agents_by_tags_pagination(server: SyncServer, default_user, default_organization):
"""Test pagination when listing agents by tags."""
# Create first agent
agent1 = server.create_agent(
request=CreateAgent(
agent1 = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="agent1",
tags=["pagination_test", "tag1"],
llm_config=LLMConfig.default_config("gpt-4"),
@@ -809,8 +857,8 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
time.sleep(CREATE_DELAY_SQLITE) # Ensure distinct created_at timestamps
# Create second agent
agent2 = server.create_agent(
request=CreateAgent(
agent2 = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="agent2",
tags=["pagination_test", "tag2"],
llm_config=LLMConfig.default_config("gpt-4"),
@@ -1565,6 +1613,15 @@ def create_test_messages(server: SyncServer, base_message: PydanticMessage, defa
return messages
def test_get_messages_by_ids(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
    """Test fetching a batch of messages by their IDs (round-trips all created ids)."""
    messages = create_test_messages(server, hello_world_message_fixture, default_user)
    message_ids = [m.id for m in messages]

    results = server.message_manager.get_messages_by_ids(message_ids=message_ids, actor=default_user)
    assert sorted(message_ids) == sorted([r.id for r in results])
def test_message_listing_basic(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
"""Test basic message listing with limit"""
create_test_messages(server, hello_world_message_fixture, default_user)

View File

@@ -10,11 +10,11 @@ from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import (
LettaMessage,
ReasoningMessage,
SystemMessage,
ToolCallMessage,
ToolReturnMessage,
ReasoningMessage,
LettaMessage,
SystemMessage,
UserMessage,
)
from letta.schemas.user import User
@@ -507,76 +507,9 @@ def test_get_archival_memory(server, user_id, agent_id):
assert len(passage_none) == 0
def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
"""Test the /rethink, /rewrite, and /retry commands in the CLI
- "rethink" replaces the inner thoughts of the last assistant message
- "rewrite" replaces the text of the last assistant message
- "retry" retries the last assistant message
"""
actor = server.user_manager.get_user_or_default(user_id)
# Send an initial message
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# Grab the raw Agent object
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
assert letta_agent._messages[-1].role == MessageRole.tool
assert letta_agent._messages[-2].role == MessageRole.assistant
last_agent_message = letta_agent._messages[-2]
# Try "rethink"
new_thought = "I am thinking about the meaning of life, the universe, and everything. Bananas?"
assert last_agent_message.text is not None and last_agent_message.text != new_thought
server.rethink_agent_message(agent_id=agent_id, new_thought=new_thought, actor=actor)
# Grab the agent object again (make sure it's live)
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
assert letta_agent._messages[-1].role == MessageRole.tool
assert letta_agent._messages[-2].role == MessageRole.assistant
last_agent_message = letta_agent._messages[-2]
assert last_agent_message.text == new_thought
# Try "rewrite"
assert last_agent_message.tool_calls is not None
assert last_agent_message.tool_calls[0].function.name == "send_message"
assert last_agent_message.tool_calls[0].function.arguments is not None
args_json = json.loads(last_agent_message.tool_calls[0].function.arguments)
assert "message" in args_json and args_json["message"] is not None and args_json["message"] != ""
new_text = "Why hello there my good friend! Is 42 what you're looking for? Bananas?"
server.rewrite_agent_message(agent_id=agent_id, new_text=new_text, actor=actor)
# Grab the agent object again (make sure it's live)
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
assert letta_agent._messages[-1].role == MessageRole.tool
assert letta_agent._messages[-2].role == MessageRole.assistant
last_agent_message = letta_agent._messages[-2]
args_json = json.loads(last_agent_message.tool_calls[0].function.arguments)
assert "message" in args_json and args_json["message"] is not None and args_json["message"] == new_text
# Try retry
server.retry_agent_message(agent_id=agent_id, actor=actor)
# Grab the agent object again (make sure it's live)
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
assert letta_agent._messages[-1].role == MessageRole.tool
assert letta_agent._messages[-2].role == MessageRole.assistant
last_agent_message = letta_agent._messages[-2]
# Make sure the inner thoughts changed
assert last_agent_message.text is not None and last_agent_message.text != new_thought
# Make sure the message changed
args_json = json.loads(last_agent_message.tool_calls[0].function.arguments)
print(args_json)
assert "message" in args_json and args_json["message"] is not None and args_json["message"] != new_text
def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str):
"""Test that the context window overview fetch works"""
overview = server.get_agent_context_window(user_id=user_id, agent_id=agent_id)
overview = server.get_agent_context_window(agent_id=agent_id, actor=server.user_manager.get_user_or_default(user_id))
assert overview is not None
# Run some basic checks
@@ -1142,10 +1075,10 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
# Add all the base tools
request.tool_ids = [b.id for b in base_tools]
agent_state = server.update_agent(agent_state.id, request=request, actor=actor)
agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor)
assert len(agent_state.tools) == len(base_tools)
# Remove one base tool
request.tool_ids = [b.id for b in base_tools[:-2]]
agent_state = server.update_agent(agent_state.id, request=request, actor=actor)
agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor)
assert len(agent_state.tools) == len(base_tools) - 2

View File

@@ -1,133 +0,0 @@
import uuid
from typing import List
from letta import create_client
from letta.client.client import LocalClient
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from .utils import wipe_config
# test_agent_id = "test_agent"
test_agent_name = f"test_client_{str(uuid.uuid4())}"
client = None
agent_obj = None
# TODO: these tests should include looping through LLM providers, since behavior may vary across providers
# TODO: these tests should add function calls into the summarized message sequence
def create_test_agent():
"""Create a test agent that we can call functions on"""
wipe_config()
global client
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
agent_state = client.create_agent(
name=test_agent_name,
)
global agent_obj
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
def test_summarize_messages_inplace(mock_e2b_api_key_none):
"""Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
global client
global agent_obj
if agent_obj is None:
create_test_agent()
assert agent_obj is not None, "Run create_agent test first"
assert client is not None, "Run create_agent test first"
# First send a few messages (5)
response = client.user_message(
agent_id=agent_obj.agent_state.id,
message="Hey, how's it going? What do you think about this whole shindig",
).messages
assert response is not None and len(response) > 0
print(f"test_summarize: response={response}")
response = client.user_message(
agent_id=agent_obj.agent_state.id,
message="Any thoughts on the meaning of life?",
).messages
assert response is not None and len(response) > 0
print(f"test_summarize: response={response}")
response = client.user_message(agent_id=agent_obj.agent_state.id, message="Does the number 42 ring a bell?").messages
assert response is not None and len(response) > 0
print(f"test_summarize: response={response}")
response = client.user_message(
agent_id=agent_obj.agent_state.id,
message="Would you be surprised to learn that you're actually conversing with an AI right now?",
).messages
assert response is not None and len(response) > 0
print(f"test_summarize: response={response}")
# reload agent object
agent_obj = client.server.load_agent(agent_id=agent_obj.agent_state.id, actor=client.user)
agent_obj.summarize_messages_inplace()
print(f"Summarization succeeded: messages[1] = \n{agent_obj.messages[1]}")
# response = client.run_command(agent_id=agent_obj.agent_state.id, command="summarize")
def test_auto_summarize(mock_e2b_api_key_none):
"""Test that the summarizer triggers by itself"""
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
small_context_llm_config = LLMConfig.default_config("gpt-4")
# default system prompt + funcs lead to ~2300 tokens, after one message it's at 2523 tokens
SMALL_CONTEXT_WINDOW = 4000
small_context_llm_config.context_window = SMALL_CONTEXT_WINDOW
agent_state = client.create_agent(
name="small_context_agent",
llm_config=small_context_llm_config,
)
try:
def summarize_message_exists(messages: List[Message]) -> bool:
for message in messages:
if message.text and "The following is a summary of the previous" in message.text:
print(f"Summarize message found after {message_count} messages: \n {message.text}")
return True
return False
MAX_ATTEMPTS = 5
message_count = 0
while True:
# send a message
response = client.user_message(
agent_id=agent_state.id,
message="What is the meaning of life?",
)
message_count += 1
print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------")
# check if the summarize message is inside the messages
assert isinstance(client, LocalClient), "Test only works with LocalClient"
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
print("SUMMARY", summarize_message_exists(agent_obj._messages))
if summarize_message_exists(agent_obj._messages):
break
if message_count > MAX_ATTEMPTS:
raise Exception(f"Summarize message not found after {message_count} messages")
finally:
client.delete_agent(agent_state.id)