feat: message orm migration (#2144)
Co-authored-by: Mindy Long <mindy@letta.com> Co-authored-by: Sarah Wooders <sarahwooders@gmail.com> Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
5
.github/workflows/check_for_new_prints.yml
vendored
5
.github/workflows/check_for_new_prints.yml
vendored
@@ -31,6 +31,11 @@ jobs:
|
||||
|
||||
# Check each changed Python file
|
||||
while IFS= read -r file; do
|
||||
if [ "$file" == "letta/main.py" ]; then
|
||||
echo "Skipping $file for print statement checks."
|
||||
continue
|
||||
fi
|
||||
|
||||
if [ -f "$file" ]; then
|
||||
echo "Checking $file for new print statements..."
|
||||
|
||||
|
||||
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
@@ -31,6 +31,7 @@ jobs:
|
||||
- "test_utils.py"
|
||||
- "test_tool_schema_parsing.py"
|
||||
- "test_v1_routes.py"
|
||||
- "test_offline_memory_agent.py"
|
||||
services:
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
@@ -133,4 +134,4 @@ jobs:
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
|
||||
run: |
|
||||
poetry run pytest -s -vv -k "not test_v1_routes.py and not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not integration_test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests
|
||||
poetry run pytest -s -vv -k "not test_offline_memory_agent.py and not test_v1_routes.py and not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not integration_test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests
|
||||
|
||||
63
alembic/versions/95badb46fdf9_migrate_message_to_orm.py
Normal file
63
alembic/versions/95badb46fdf9_migrate_message_to_orm.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Migrate message to orm
|
||||
|
||||
Revision ID: 95badb46fdf9
|
||||
Revises: 3c683a662c82
|
||||
Create Date: 2024-12-05 14:02:04.163150
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "95badb46fdf9"
|
||||
down_revision: Union[str, None] = "08b2f8225812"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("messages", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("messages", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
|
||||
op.add_column("messages", sa.Column("_created_by_id", sa.String(), nullable=True))
|
||||
op.add_column("messages", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
|
||||
op.add_column("messages", sa.Column("organization_id", sa.String(), nullable=True))
|
||||
# Populate `organization_id` based on `user_id`
|
||||
# Use a raw SQL query to update the organization_id
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE messages
|
||||
SET organization_id = users.organization_id
|
||||
FROM users
|
||||
WHERE messages.user_id = users.id
|
||||
"""
|
||||
)
|
||||
op.alter_column("messages", "organization_id", nullable=False)
|
||||
op.alter_column("messages", "tool_calls", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
op.alter_column("messages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
|
||||
op.drop_index("message_idx_user", table_name="messages")
|
||||
op.create_foreign_key(None, "messages", "agents", ["agent_id"], ["id"])
|
||||
op.create_foreign_key(None, "messages", "organizations", ["organization_id"], ["id"])
|
||||
op.drop_column("messages", "user_id")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("messages", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
|
||||
op.drop_constraint(None, "messages", type_="foreignkey")
|
||||
op.drop_constraint(None, "messages", type_="foreignkey")
|
||||
op.create_index("message_idx_user", "messages", ["user_id", "agent_id"], unique=False)
|
||||
op.alter_column("messages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=True)
|
||||
op.alter_column("messages", "tool_calls", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
op.drop_column("messages", "organization_id")
|
||||
op.drop_column("messages", "_last_updated_by_id")
|
||||
op.drop_column("messages", "_created_by_id")
|
||||
op.drop_column("messages", "is_deleted")
|
||||
op.drop_column("messages", "updated_at")
|
||||
# ### end Alembic commands ###
|
||||
@@ -97,7 +97,7 @@ def upgrade() -> None:
|
||||
sa.Column("text", sa.String(), nullable=True),
|
||||
sa.Column("model", sa.String(), nullable=True),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("tool_calls", letta.metadata.ToolCallColumn(), nullable=True),
|
||||
sa.Column("tool_calls", letta.orm.message.ToolCallColumn(), nullable=True),
|
||||
sa.Column("tool_call_id", sa.String(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
|
||||
@@ -19,6 +19,7 @@ from letta.constants import (
|
||||
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
|
||||
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
|
||||
MESSAGE_SUMMARY_WARNING_FRAC,
|
||||
O1_BASE_TOOLS,
|
||||
REQ_HEARTBEAT_MESSAGE,
|
||||
)
|
||||
from letta.errors import LLMError
|
||||
@@ -27,10 +28,9 @@ from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import is_context_overflow_error
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
|
||||
from letta.memory import ArchivalMemory, EmbeddingArchivalMemory, summarize_messages
|
||||
from letta.metadata import MetadataStore
|
||||
from letta.orm import User
|
||||
from letta.persistence_manager import LocalStateManager
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -49,7 +49,9 @@ from letta.schemas.passage import Passage
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import TerminalToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.user_manager import UserManager
|
||||
@@ -80,9 +82,11 @@ from letta.utils import (
|
||||
|
||||
|
||||
def compile_memory_metadata_block(
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
memory_edit_timestamp: datetime.datetime,
|
||||
archival_memory: Optional[ArchivalMemory] = None,
|
||||
recall_memory: Optional[RecallMemory] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
) -> str:
|
||||
# Put the timestamp in the local timezone (mimicking get_local_time())
|
||||
timestamp_str = memory_edit_timestamp.astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip()
|
||||
@@ -91,7 +95,7 @@ def compile_memory_metadata_block(
|
||||
memory_metadata_block = "\n".join(
|
||||
[
|
||||
f"### Memory [last modified: {timestamp_str}]",
|
||||
f"{recall_memory.count() if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||
f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||
f"{archival_memory.count() if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
||||
]
|
||||
@@ -101,10 +105,12 @@ def compile_memory_metadata_block(
|
||||
|
||||
def compile_system_message(
|
||||
system_prompt: str,
|
||||
agent_id: str,
|
||||
in_context_memory: Memory,
|
||||
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
|
||||
actor: PydanticUser,
|
||||
archival_memory: Optional[ArchivalMemory] = None,
|
||||
recall_memory: Optional[RecallMemory] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
user_defined_variables: Optional[dict] = None,
|
||||
append_icm_if_missing: bool = True,
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
@@ -129,9 +135,11 @@ def compile_system_message(
|
||||
else:
|
||||
# TODO should this all put into the memory.__repr__ function?
|
||||
memory_metadata_string = compile_memory_metadata_block(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
memory_edit_timestamp=in_context_memory_last_edit,
|
||||
archival_memory=archival_memory,
|
||||
recall_memory=recall_memory,
|
||||
message_manager=message_manager,
|
||||
)
|
||||
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
|
||||
|
||||
@@ -164,9 +172,11 @@ def compile_system_message(
|
||||
def initialize_message_sequence(
|
||||
model: str,
|
||||
system: str,
|
||||
agent_id: str,
|
||||
memory: Memory,
|
||||
actor: PydanticUser,
|
||||
archival_memory: Optional[ArchivalMemory] = None,
|
||||
recall_memory: Optional[RecallMemory] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
memory_edit_timestamp: Optional[datetime.datetime] = None,
|
||||
include_initial_boot_message: bool = True,
|
||||
) -> List[dict]:
|
||||
@@ -177,11 +187,13 @@ def initialize_message_sequence(
|
||||
# system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory
|
||||
# )
|
||||
full_system_message = compile_system_message(
|
||||
agent_id=agent_id,
|
||||
system_prompt=system,
|
||||
in_context_memory=memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
actor=actor,
|
||||
archival_memory=archival_memory,
|
||||
recall_memory=recall_memory,
|
||||
message_manager=message_manager,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
)
|
||||
@@ -282,7 +294,8 @@ class Agent(BaseAgent):
|
||||
self.interface = interface
|
||||
|
||||
# Create the persistence manager object based on the AgentState info
|
||||
self.persistence_manager = LocalStateManager(agent_state=self.agent_state)
|
||||
self.archival_memory = EmbeddingArchivalMemory(agent_state)
|
||||
self.message_manager = MessageManager()
|
||||
|
||||
# State needed for heartbeat pausing
|
||||
self.pause_heartbeats_start = None
|
||||
@@ -309,9 +322,11 @@ class Agent(BaseAgent):
|
||||
init_messages = initialize_message_sequence(
|
||||
model=self.model,
|
||||
system=self.agent_state.system,
|
||||
agent_id=self.agent_state.id,
|
||||
memory=self.agent_state.memory,
|
||||
actor=self.user,
|
||||
archival_memory=None,
|
||||
recall_memory=None,
|
||||
message_manager=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
)
|
||||
@@ -333,8 +348,10 @@ class Agent(BaseAgent):
|
||||
model=self.model,
|
||||
system=self.agent_state.system,
|
||||
memory=self.agent_state.memory,
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
archival_memory=None,
|
||||
recall_memory=None,
|
||||
message_manager=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
)
|
||||
@@ -356,7 +373,6 @@ class Agent(BaseAgent):
|
||||
|
||||
# Put the messages inside the message buffer
|
||||
self.messages_total = 0
|
||||
# self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None])
|
||||
self._append_to_messages(added_messages=init_messages_objs)
|
||||
self._validate_message_buffer_is_utc()
|
||||
|
||||
@@ -413,7 +429,10 @@ class Agent(BaseAgent):
|
||||
# TODO: need to have an AgentState object that actually has full access to the block data
|
||||
# this is because the sandbox tools need to be able to access block.value to edit this data
|
||||
try:
|
||||
if function_name in BASE_TOOLS:
|
||||
# TODO: This is NO BUENO
|
||||
# TODO: Matching purely by names is extremely problematic, users can create tools with these names and run them in the agent loop
|
||||
# TODO: We will have probably have to match the function strings exactly for safety
|
||||
if function_name in BASE_TOOLS or function_name in O1_BASE_TOOLS:
|
||||
# base tools are allowed to access the `Agent` object and run on the database
|
||||
function_args["self"] = self # need to attach self to arg since it's dynamically linked
|
||||
function_response = function_to_call(**function_args)
|
||||
@@ -474,7 +493,7 @@ class Agent(BaseAgent):
|
||||
# Pull the message objects from the database
|
||||
message_objs = []
|
||||
for msg_id in message_ids:
|
||||
msg_obj = self.persistence_manager.recall_memory.storage.get(msg_id)
|
||||
msg_obj = self.message_manager.get_message_by_id(msg_id, actor=self.user)
|
||||
if msg_obj:
|
||||
if isinstance(msg_obj, Message):
|
||||
message_objs.append(msg_obj)
|
||||
@@ -522,16 +541,13 @@ class Agent(BaseAgent):
|
||||
|
||||
def _trim_messages(self, num):
|
||||
"""Trim messages from the front, not including the system message"""
|
||||
self.persistence_manager.trim_messages(num)
|
||||
|
||||
new_messages = [self._messages[0]] + self._messages[num:]
|
||||
self._messages = new_messages
|
||||
|
||||
def _prepend_to_messages(self, added_messages: List[Message]):
|
||||
"""Wrapper around self.messages.prepend to allow additional calls to a state/persistence manager"""
|
||||
assert all([isinstance(msg, Message) for msg in added_messages])
|
||||
|
||||
self.persistence_manager.prepend_to_messages(added_messages)
|
||||
self.message_manager.create_many_messages(added_messages, actor=self.user)
|
||||
|
||||
new_messages = [self._messages[0]] + added_messages + self._messages[1:] # prepend (no system)
|
||||
self._messages = new_messages
|
||||
@@ -540,8 +556,7 @@ class Agent(BaseAgent):
|
||||
def _append_to_messages(self, added_messages: List[Message]):
|
||||
"""Wrapper around self.messages.append to allow additional calls to a state/persistence manager"""
|
||||
assert all([isinstance(msg, Message) for msg in added_messages])
|
||||
|
||||
self.persistence_manager.append_to_messages(added_messages)
|
||||
self.message_manager.create_many_messages(added_messages, actor=self.user)
|
||||
|
||||
# strip extra metadata if it exists
|
||||
# for msg in added_messages:
|
||||
@@ -885,7 +900,6 @@ class Agent(BaseAgent):
|
||||
messages=next_input_message,
|
||||
**kwargs,
|
||||
)
|
||||
step_response.messages
|
||||
heartbeat_request = step_response.heartbeat_request
|
||||
function_failed = step_response.function_failed
|
||||
token_warning = step_response.in_context_memory_warning
|
||||
@@ -1247,7 +1261,7 @@ class Agent(BaseAgent):
|
||||
assert new_system_message_obj.role == "system", new_system_message_obj
|
||||
assert self._messages[0].role == "system", self._messages
|
||||
|
||||
self.persistence_manager.swap_system_message(new_system_message_obj)
|
||||
self.message_manager.create_message(new_system_message_obj, actor=self.user)
|
||||
|
||||
new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system)
|
||||
self._messages = new_messages
|
||||
@@ -1280,11 +1294,13 @@ class Agent(BaseAgent):
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats seperately)
|
||||
new_system_message_str = compile_system_message(
|
||||
agent_id=self.agent_state.id,
|
||||
system_prompt=self.agent_state.system,
|
||||
in_context_memory=self.agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
archival_memory=self.persistence_manager.archival_memory,
|
||||
recall_memory=self.persistence_manager.recall_memory,
|
||||
actor=self.user,
|
||||
archival_memory=self.archival_memory,
|
||||
message_manager=self.message_manager,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
)
|
||||
@@ -1370,20 +1386,20 @@ class Agent(BaseAgent):
|
||||
# passage.id = create_uuid_from_string(f"{source_id}_{str(passage.agent_id)}_{passage.text}")
|
||||
|
||||
# insert into agent archival memory
|
||||
self.persistence_manager.archival_memory.storage.insert_many(passages)
|
||||
self.archival_memory.storage.insert_many(passages)
|
||||
all_passages += passages
|
||||
|
||||
assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}"
|
||||
|
||||
# save destination storage
|
||||
self.persistence_manager.archival_memory.storage.save()
|
||||
self.archival_memory.storage.save()
|
||||
|
||||
# attach to agent
|
||||
source = SourceManager().get_source_by_id(source_id=source_id, actor=user)
|
||||
assert source is not None, f"Source {source_id} not found in metadata store"
|
||||
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
|
||||
|
||||
total_agent_passages = self.persistence_manager.archival_memory.storage.size()
|
||||
total_agent_passages = self.archival_memory.storage.size()
|
||||
|
||||
printd(
|
||||
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
|
||||
@@ -1392,7 +1408,7 @@ class Agent(BaseAgent):
|
||||
def update_message(self, request: UpdateMessage) -> Message:
|
||||
"""Update the details of a message associated with an agent"""
|
||||
|
||||
message = self.persistence_manager.recall_memory.storage.get(id=request.id)
|
||||
message = self.message_manager.get_message_by_id(message_id=request.id, actor=self.user)
|
||||
if message is None:
|
||||
raise ValueError(f"Message with id {request.id} not found")
|
||||
assert isinstance(message, Message), f"Message is not a Message object: {type(message)}"
|
||||
@@ -1413,10 +1429,10 @@ class Agent(BaseAgent):
|
||||
message.tool_call_id = request.tool_call_id
|
||||
|
||||
# Save the updated message
|
||||
self.persistence_manager.recall_memory.storage.update(record=message)
|
||||
self.message_manager.update_message_by_id(message_id=message.id, message=message, actor=self.user)
|
||||
|
||||
# Return the updated message
|
||||
updated_message = self.persistence_manager.recall_memory.storage.get(id=message.id)
|
||||
updated_message = self.message_manager.get_message_by_id(message_id=message.id, actor=self.user)
|
||||
if updated_message is None:
|
||||
raise ValueError(f"Error persisting message - message with id {request.id} not found")
|
||||
return updated_message
|
||||
@@ -1496,7 +1512,7 @@ class Agent(BaseAgent):
|
||||
deleted_message = self._messages.pop()
|
||||
# then also remove it from recall storage
|
||||
try:
|
||||
self.persistence_manager.recall_memory.storage.delete(filters={"id": deleted_message.id})
|
||||
self.message_manager.delete_message_by_id(deleted_message.id, actor=self.user)
|
||||
popped_messages.append(deleted_message)
|
||||
except Exception as e:
|
||||
warnings.warn(f"Error deleting message {deleted_message.id} from recall memory: {e}")
|
||||
@@ -1522,7 +1538,6 @@ class Agent(BaseAgent):
|
||||
|
||||
def retry_message(self) -> List[Message]:
|
||||
"""Retry / regenerate the last message"""
|
||||
|
||||
self.pop_until_user()
|
||||
user_message = self.pop_message(count=1)[0]
|
||||
assert user_message.text is not None, "User message text is None"
|
||||
@@ -1569,12 +1584,14 @@ class Agent(BaseAgent):
|
||||
num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
|
||||
)
|
||||
|
||||
num_archival_memory = self.persistence_manager.archival_memory.storage.size()
|
||||
num_recall_memory = self.persistence_manager.recall_memory.storage.size()
|
||||
num_archival_memory = self.archival_memory.storage.size()
|
||||
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
||||
external_memory_summary = compile_memory_metadata_block(
|
||||
actor=self.user,
|
||||
agent_id=self.agent_state.id,
|
||||
memory_edit_timestamp=get_utc_time(), # dummy timestamp
|
||||
archival_memory=self.persistence_manager.archival_memory,
|
||||
recall_memory=self.persistence_manager.recall_memory,
|
||||
archival_memory=self.archival_memory,
|
||||
message_manager=self.message_manager,
|
||||
)
|
||||
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
|
||||
|
||||
@@ -1600,7 +1617,7 @@ class Agent(BaseAgent):
|
||||
# context window breakdown (in messages)
|
||||
num_messages=len(self._messages),
|
||||
num_archival_memory=num_archival_memory,
|
||||
num_recall_memory=num_recall_memory,
|
||||
num_recall_memory=message_manager_size,
|
||||
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
||||
# top-level information
|
||||
context_window_size_max=self.agent_state.llm_config.context_window,
|
||||
|
||||
@@ -27,13 +27,11 @@ from tqdm import tqdm
|
||||
from letta.agent_store.storage import StorageConnector, TableType
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
|
||||
from letta.metadata import EmbeddingConfigColumn
|
||||
from letta.orm.base import Base
|
||||
from letta.orm.file import FileMetadata as FileMetadataModel
|
||||
|
||||
# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.settings import settings
|
||||
|
||||
@@ -69,69 +67,6 @@ class CommonVector(TypeDecorator):
|
||||
return np.frombuffer(value, dtype=np.float32)
|
||||
|
||||
|
||||
class MessageModel(Base):
|
||||
"""Defines data model for storing Message objects"""
|
||||
|
||||
__tablename__ = "messages"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
# Assuming message_id is the primary key
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
agent_id = Column(String, nullable=False)
|
||||
|
||||
# openai info
|
||||
role = Column(String, nullable=False)
|
||||
text = Column(String) # optional: can be null if function call
|
||||
model = Column(String) # optional: can be null if LLM backend doesn't require specifying
|
||||
name = Column(String) # optional: multi-agent only
|
||||
|
||||
# tool call request info
|
||||
# if role == "assistant", this MAY be specified
|
||||
# if role != "assistant", this must be null
|
||||
# TODO align with OpenAI spec of multiple tool calls
|
||||
# tool_calls = Column(ToolCallColumn)
|
||||
tool_calls = Column(ToolCallColumn)
|
||||
|
||||
# tool call response info
|
||||
# if role == "tool", then this must be specified
|
||||
# if role != "tool", this must be null
|
||||
tool_call_id = Column(String)
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
Index("message_idx_user", user_id, agent_id),
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Message(message_id='{self.id}', text='{self.text}')>"
|
||||
|
||||
def to_record(self):
|
||||
# calls = (
|
||||
# [ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls]
|
||||
# if self.tool_calls
|
||||
# else None
|
||||
# )
|
||||
# if calls:
|
||||
# assert isinstance(calls[0], ToolCall)
|
||||
if self.tool_calls and len(self.tool_calls) > 0:
|
||||
assert isinstance(self.tool_calls[0], ToolCall), type(self.tool_calls[0])
|
||||
for tool in self.tool_calls:
|
||||
assert isinstance(tool, ToolCall), type(tool)
|
||||
return Message(
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
role=self.role,
|
||||
name=self.name,
|
||||
text=self.text,
|
||||
model=self.model,
|
||||
# tool_calls=[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls] if self.tool_calls else None,
|
||||
tool_calls=self.tool_calls,
|
||||
tool_call_id=self.tool_call_id,
|
||||
created_at=self.created_at,
|
||||
id=self.id,
|
||||
)
|
||||
|
||||
|
||||
class PassageModel(Base):
|
||||
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
||||
|
||||
@@ -367,11 +302,6 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
self.db_model = PassageModel
|
||||
if self.config.archival_storage_uri is None:
|
||||
raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
self.uri = self.config.recall_storage_uri
|
||||
self.db_model = MessageModel
|
||||
if self.config.recall_storage_uri is None:
|
||||
raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
|
||||
elif table_type == TableType.FILES:
|
||||
self.uri = self.config.metadata_storage_uri
|
||||
self.db_model = FileMetadataModel
|
||||
@@ -490,12 +420,6 @@ class SQLLiteStorageConnector(SQLStorageConnector):
|
||||
# get storage URI
|
||||
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
# TODO: eventually implement URI option
|
||||
self.path = self.config.recall_storage_path
|
||||
if self.path is None:
|
||||
raise ValueError(f"Must specify recall_storage_path in config.")
|
||||
self.db_model = MessageModel
|
||||
elif table_type == TableType.FILES:
|
||||
self.path = self.config.metadata_storage_path
|
||||
if self.path is None:
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
# type: ignore
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, Iterator, List, Optional
|
||||
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
from letta.agent_store.storage import StorageConnector, TableType
|
||||
from letta.config import AgentConfig, LettaConfig
|
||||
from letta.schemas.message import Message, Passage, Record
|
||||
|
||||
""" Initial implementation - not complete """
|
||||
|
||||
|
||||
def get_db_model(table_name: str, table_type: TableType):
|
||||
config = LettaConfig.load()
|
||||
|
||||
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||
# create schema for archival memory
|
||||
class PassageModel(LanceModel):
|
||||
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
||||
|
||||
id: uuid.UUID
|
||||
user_id: str
|
||||
text: str
|
||||
file_id: str
|
||||
agent_id: str
|
||||
data_source: str
|
||||
embedding: Vector(config.default_embedding_config.embedding_dim)
|
||||
metadata_: Dict
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||
|
||||
def to_record(self):
|
||||
return Passage(
|
||||
text=self.text,
|
||||
embedding=self.embedding,
|
||||
file_id=self.file_id,
|
||||
user_id=self.user_id,
|
||||
id=self.id,
|
||||
data_source=self.data_source,
|
||||
agent_id=self.agent_id,
|
||||
metadata=self.metadata_,
|
||||
)
|
||||
|
||||
return PassageModel
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
|
||||
class MessageModel(LanceModel):
|
||||
"""Defines data model for storing Message objects"""
|
||||
|
||||
__abstract__ = True # this line is necessary
|
||||
|
||||
# Assuming message_id is the primary key
|
||||
id: uuid.UUID
|
||||
user_id: str
|
||||
agent_id: str
|
||||
|
||||
# openai info
|
||||
role: str
|
||||
name: str
|
||||
text: str
|
||||
model: str
|
||||
user: str
|
||||
|
||||
# function info
|
||||
function_name: str
|
||||
function_args: str
|
||||
function_response: str
|
||||
|
||||
embedding = Vector(config.default_embedding_config.embedding_dim)
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = datetime
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||
|
||||
def to_record(self):
|
||||
return Message(
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
role=self.role,
|
||||
name=self.name,
|
||||
text=self.text,
|
||||
model=self.model,
|
||||
function_name=self.function_name,
|
||||
function_args=self.function_args,
|
||||
function_response=self.function_response,
|
||||
embedding=self.embedding,
|
||||
created_at=self.created_at,
|
||||
id=self.id,
|
||||
)
|
||||
|
||||
"""Create database model for table_name"""
|
||||
return MessageModel
|
||||
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
|
||||
class LanceDBConnector(StorageConnector):
|
||||
"""Storage via LanceDB"""
|
||||
|
||||
# TODO: this should probably eventually be moved into a parent DB class
|
||||
|
||||
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
def generate_where_filter(self, filters: Dict) -> str:
|
||||
where_filters = []
|
||||
for key, value in filters.items():
|
||||
where_filters.append(f"{key}={value}")
|
||||
return where_filters.join(" AND ")
|
||||
|
||||
@abstractmethod
|
||||
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id: uuid.UUID) -> Optional[Record]:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, record: Record):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def insert_many(self, records: List[Record], show_progress=False):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_date(self, start_date, end_date):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query_text(self, query):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_table(self):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, filters: Optional[Dict] = {}):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self):
|
||||
# TODO
|
||||
pass
|
||||
@@ -110,11 +110,6 @@ class StorageConnector:
|
||||
from letta.agent_store.qdrant import QdrantStorageConnector
|
||||
|
||||
return QdrantStorageConnector(table_type, config, user_id, agent_id)
|
||||
# TODO: add back
|
||||
# elif storage_type == "lancedb":
|
||||
# from letta.agent_store.db import LanceDBConnector
|
||||
|
||||
# return LanceDBConnector(agent_config=agent_config, table_type=table_type)
|
||||
|
||||
elif storage_type == "sqlite":
|
||||
from letta.agent_store.db import SQLLiteStorageConnector
|
||||
|
||||
@@ -171,7 +171,6 @@ def run(
|
||||
# printd("State path:", agent_config.save_state_dir())
|
||||
# printd("Persistent manager path:", agent_config.save_persistence_manager_dir())
|
||||
# printd("Index path:", agent_config.save_agent_index_dir())
|
||||
# persistence_manager = LocalStateManager(agent_config).load() # TODO: implement load
|
||||
# TODO: load prior agent state
|
||||
|
||||
# Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
|
||||
|
||||
@@ -3054,16 +3054,13 @@ class LocalClient(AbstractClient):
|
||||
|
||||
# recall memory
|
||||
|
||||
def get_messages(
|
||||
self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000
|
||||
) -> List[Message]:
|
||||
def get_messages(self, agent_id: str, cursor: Optional[str] = None, limit: Optional[int] = 1000) -> List[Message]:
|
||||
"""
|
||||
Get messages from an agent with pagination.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
before (str): Get messages before a certain time
|
||||
after (str): Get messages after a certain time
|
||||
cursor (str): Get messages after a certain time
|
||||
limit (int): Limit number of messages
|
||||
|
||||
Returns:
|
||||
@@ -3074,8 +3071,7 @@ class LocalClient(AbstractClient):
|
||||
return self.server.get_agent_recall_cursor(
|
||||
user_id=self.user_id,
|
||||
agent_id=agent_id,
|
||||
before=before,
|
||||
after=after,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
@@ -38,6 +38,7 @@ DEFAULT_PRESET = "memgpt_chat"
|
||||
|
||||
# Base tools that cannot be edited, as they access agent state directly
|
||||
BASE_TOOLS = ["send_message", "conversation_search", "conversation_search_date", "archival_memory_insert", "archival_memory_search"]
|
||||
O1_BASE_TOOLS = ["send_thinking_message", "send_final_message"]
|
||||
# Base memory tools CAN be edited, and are added by default by the server
|
||||
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from letta.agent import Agent
|
||||
@@ -38,7 +39,7 @@ Returns:
|
||||
"""
|
||||
|
||||
|
||||
def pause_heartbeats(self: Agent, minutes: int) -> Optional[str]:
|
||||
def pause_heartbeats(self: "Agent", minutes: int) -> Optional[str]:
|
||||
import datetime
|
||||
|
||||
from letta.constants import MAX_PAUSE_HEARTBEATS
|
||||
@@ -80,7 +81,15 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
|
||||
except:
|
||||
raise ValueError(f"'page' argument must be an integer")
|
||||
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
results, total = self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count)
|
||||
# TODO: add paging by page number. currently cursor only works with strings.
|
||||
# original: start=page * count
|
||||
results = self.message_manager.list_user_messages_for_agent(
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
query_text=query,
|
||||
limit=count,
|
||||
)
|
||||
total = len(results)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
@@ -112,10 +121,29 @@ def conversation_search_date(self: "Agent", start_date: str, end_date: str, page
|
||||
page = 0
|
||||
try:
|
||||
page = int(page)
|
||||
if page < 0:
|
||||
raise ValueError
|
||||
except:
|
||||
raise ValueError(f"'page' argument must be an integer")
|
||||
|
||||
# Convert date strings to datetime objects
|
||||
try:
|
||||
start_datetime = datetime.strptime(start_date, "%Y-%m-%d").replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_datetime = datetime.strptime(end_date, "%Y-%m-%d").replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
except ValueError:
|
||||
raise ValueError("Dates must be in the format 'YYYY-MM-DD'")
|
||||
|
||||
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
results, total = self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count)
|
||||
results = self.message_manager.list_user_messages_for_agent(
|
||||
# TODO: add paging by page number. currently cursor only works with strings.
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
start_date=start_datetime,
|
||||
end_date=end_datetime,
|
||||
limit=count,
|
||||
# start_date=start_date, end_date=end_date, limit=count, start=page * count
|
||||
)
|
||||
total = len(results)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
@@ -136,7 +164,7 @@ def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
self.persistence_manager.archival_memory.insert(content)
|
||||
self.archival_memory.insert(content)
|
||||
return None
|
||||
|
||||
|
||||
@@ -163,7 +191,7 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0) -
|
||||
except:
|
||||
raise ValueError(f"'page' argument must be an integer")
|
||||
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
results, total = self.persistence_manager.archival_memory.search(query, count=count, start=page * count)
|
||||
results, total = self.archival_memory.search(query, count=count, start=page * count)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
|
||||
@@ -190,8 +190,8 @@ def run_agent_loop(
|
||||
elif user_input.lower() == "/memory":
|
||||
print(f"\nDumping memory contents:\n")
|
||||
print(f"{letta_agent.agent_state.memory.compile()}")
|
||||
print(f"{letta_agent.persistence_manager.archival_memory.compile()}")
|
||||
print(f"{letta_agent.persistence_manager.recall_memory.compile()}")
|
||||
print(f"{letta_agent.archival_memory.compile()}")
|
||||
print(f"{letta_agent.recall_memory.compile()}")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/model":
|
||||
|
||||
@@ -67,14 +67,12 @@ def summarize_messages(
|
||||
+ message_sequence_to_summarize[cutoff:]
|
||||
)
|
||||
|
||||
dummy_user_id = agent_state.user_id
|
||||
agent_state.user_id
|
||||
dummy_agent_id = agent_state.id
|
||||
message_sequence = []
|
||||
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt))
|
||||
message_sequence.append(
|
||||
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role=MessageRole.assistant, text=MESSAGE_SUMMARY_REQUEST_ACK)
|
||||
)
|
||||
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input))
|
||||
message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt))
|
||||
message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.assistant, text=MESSAGE_SUMMARY_REQUEST_ACK))
|
||||
message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input))
|
||||
|
||||
# TODO: We need to eventually have a separate LLM config for the summarizer LLM
|
||||
llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True)
|
||||
@@ -252,82 +250,6 @@ class DummyRecallMemory(RecallMemory):
|
||||
return matches, len(matches)
|
||||
|
||||
|
||||
class BaseRecallMemory(RecallMemory):
|
||||
"""Recall memory based on base functions implemented by storage connectors"""
|
||||
|
||||
def __init__(self, agent_state, restrict_search_to_summaries=False):
|
||||
# If true, the pool of messages that can be queried are the automated summaries only
|
||||
# (generated when the conversation window needs to be shortened)
|
||||
self.restrict_search_to_summaries = restrict_search_to_summaries
|
||||
from letta.agent_store.storage import StorageConnector
|
||||
|
||||
self.agent_state = agent_state
|
||||
|
||||
# create embedding model
|
||||
self.embed_model = embedding_model(agent_state.embedding_config)
|
||||
self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
|
||||
|
||||
# create storage backend
|
||||
self.storage = StorageConnector.get_recall_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id)
|
||||
# TODO: have some mechanism for cleanup otherwise will lead to OOM
|
||||
self.cache = {}
|
||||
|
||||
def get_all(self, start=0, count=None):
|
||||
start = 0 if start is None else int(start)
|
||||
count = 0 if count is None else int(count)
|
||||
results = self.storage.get_all(start, count)
|
||||
results_json = [message.to_openai_dict() for message in results]
|
||||
return results_json, len(results)
|
||||
|
||||
def text_search(self, query_string, count=None, start=None):
|
||||
start = 0 if start is None else int(start)
|
||||
count = 0 if count is None else int(count)
|
||||
results = self.storage.query_text(query_string, count, start)
|
||||
results_json = [message.to_openai_dict_search_results() for message in results]
|
||||
return results_json, len(results)
|
||||
|
||||
def date_search(self, start_date, end_date, count=None, start=None):
|
||||
start = 0 if start is None else int(start)
|
||||
count = 0 if count is None else int(count)
|
||||
results = self.storage.query_date(start_date, end_date, count, start)
|
||||
results_json = [message.to_openai_dict_search_results() for message in results]
|
||||
return results_json, len(results)
|
||||
|
||||
def compile(self) -> str:
|
||||
total = self.storage.size()
|
||||
system_count = self.storage.size(filters={"role": "system"})
|
||||
user_count = self.storage.size(filters={"role": "user"})
|
||||
assistant_count = self.storage.size(filters={"role": "assistant"})
|
||||
function_count = self.storage.size(filters={"role": "function"})
|
||||
other_count = total - (system_count + user_count + assistant_count + function_count)
|
||||
|
||||
memory_str = (
|
||||
f"Statistics:"
|
||||
+ f"\n{total} total messages"
|
||||
+ f"\n{system_count} system"
|
||||
+ f"\n{user_count} user"
|
||||
+ f"\n{assistant_count} assistant"
|
||||
+ f"\n{function_count} function"
|
||||
+ f"\n{other_count} other"
|
||||
)
|
||||
return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"
|
||||
|
||||
def insert(self, message: Message):
|
||||
self.storage.insert(message)
|
||||
|
||||
def insert_many(self, messages: List[Message]):
|
||||
self.storage.insert_many(messages)
|
||||
|
||||
def save(self):
|
||||
self.storage.save()
|
||||
|
||||
def __len__(self):
|
||||
return self.storage.size()
|
||||
|
||||
def count(self) -> int:
|
||||
return len(self)
|
||||
|
||||
|
||||
class EmbeddingArchivalMemory(ArchivalMemory):
|
||||
"""Archival memory with embedding based search"""
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from letta.schemas.api_key import APIKey
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.user import User
|
||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||
@@ -66,40 +65,6 @@ class EmbeddingConfigColumn(TypeDecorator):
|
||||
return value
|
||||
|
||||
|
||||
class ToolCallColumn(TypeDecorator):
|
||||
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
return dialect.type_descriptor(JSON())
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value:
|
||||
values = []
|
||||
for v in value:
|
||||
if isinstance(v, ToolCall):
|
||||
values.append(v.model_dump())
|
||||
else:
|
||||
values.append(v)
|
||||
return values
|
||||
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value:
|
||||
tools = []
|
||||
for tool_value in value:
|
||||
if "function" in tool_value:
|
||||
tool_call_function = ToolCallFunction(**tool_value["function"])
|
||||
del tool_value["function"]
|
||||
else:
|
||||
tool_call_function = None
|
||||
tools.append(ToolCall(function=tool_call_function, **tool_value))
|
||||
return tools
|
||||
return value
|
||||
|
||||
|
||||
# TODO: eventually store providers?
|
||||
# class Provider(Base):
|
||||
# __tablename__ = "providers"
|
||||
|
||||
@@ -20,7 +20,7 @@ def send_thinking_message(self: "Agent", message: str) -> Optional[str]:
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
self.interface.internal_monologue(message, msg_obj=self._messages[-1])
|
||||
self.interface.internal_monologue(message)
|
||||
return None
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def send_final_message(self: "Agent", message: str) -> Optional[str]:
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
self.interface.internal_monologue(message, msg_obj=self._messages[-1])
|
||||
self.interface.internal_monologue(message)
|
||||
return None
|
||||
|
||||
|
||||
@@ -62,10 +62,15 @@ class O1Agent(Agent):
|
||||
"""Run Agent.inner_step in a loop, terminate when final thinking message is sent or max_thinking_steps is reached"""
|
||||
# assert ms is not None, "MetadataStore is required"
|
||||
next_input_message = messages if isinstance(messages, list) else [messages]
|
||||
|
||||
counter = 0
|
||||
total_usage = UsageStatistics()
|
||||
step_count = 0
|
||||
while step_count < self.max_thinking_steps:
|
||||
# This is hacky but we need to do this for now
|
||||
for m in next_input_message:
|
||||
m.id = m._generate_id()
|
||||
|
||||
kwargs["ms"] = ms
|
||||
kwargs["first_message"] = False
|
||||
step_response = self.inner_step(
|
||||
|
||||
@@ -18,6 +18,7 @@ def trigger_rethink_memory(agent_state: "AgentState", message: Optional[str]) ->
|
||||
|
||||
"""
|
||||
from letta import create_client
|
||||
|
||||
client = create_client()
|
||||
agents = client.list_agents()
|
||||
for agent in agents:
|
||||
@@ -149,6 +150,11 @@ class OfflineMemoryAgent(Agent):
|
||||
step_count = 0
|
||||
|
||||
while counter < self.max_memory_rethinks:
|
||||
# This is hacky but we need to do this for now
|
||||
# TODO: REMOVE THIS
|
||||
for m in next_input_message:
|
||||
m.id = m._generate_id()
|
||||
|
||||
kwargs["ms"] = ms
|
||||
kwargs["first_message"] = False
|
||||
step_response = self.inner_step(
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.base import Base
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.job import Job
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.source import Source
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
|
||||
from sqlalchemy import Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
66
letta/orm/message.py
Normal file
66
letta/orm/message.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import JSON, DateTime, TypeDecorator
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||
|
||||
|
||||
class ToolCallColumn(TypeDecorator):
|
||||
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
return dialect.type_descriptor(JSON())
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value:
|
||||
values = []
|
||||
for v in value:
|
||||
if isinstance(v, ToolCall):
|
||||
values.append(v.model_dump())
|
||||
else:
|
||||
values.append(v)
|
||||
return values
|
||||
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value:
|
||||
tools = []
|
||||
for tool_value in value:
|
||||
if "function" in tool_value:
|
||||
tool_call_function = ToolCallFunction(**tool_value["function"])
|
||||
del tool_value["function"]
|
||||
else:
|
||||
tool_call_function = None
|
||||
tools.append(ToolCall(function=tool_call_function, **tool_value))
|
||||
return tools
|
||||
return value
|
||||
|
||||
|
||||
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
"""Defines data model for storing Message objects"""
|
||||
|
||||
__tablename__ = "messages"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
__pydantic_model__ = PydanticMessage
|
||||
|
||||
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique message identifier")
|
||||
role: Mapped[str] = mapped_column(doc="Message role (user/assistant/system/tool)")
|
||||
text: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Message content")
|
||||
model: Mapped[Optional[str]] = mapped_column(nullable=True, doc="LLM model used")
|
||||
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
|
||||
tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information")
|
||||
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
# TODO: Add in after Agent ORM is created
|
||||
# agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
|
||||
@@ -31,6 +31,22 @@ class UserMixin(Base):
|
||||
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
|
||||
|
||||
|
||||
class AgentMixin(Base):
|
||||
"""Mixin for models that belong to an agent."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"))
|
||||
|
||||
|
||||
class FileMixin(Base):
|
||||
"""Mixin for models that belong to a file."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id"))
|
||||
|
||||
|
||||
class SourceMixin(Base):
|
||||
"""Mixin for models (e.g. file) that belong to a source."""
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ class Organization(SqlalchemyBase):
|
||||
sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship(
|
||||
"SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
# TODO: Map these relationships later when we actually make these models
|
||||
# below is just a suggestion
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Type
|
||||
|
||||
from sqlalchemy import String, select
|
||||
from sqlalchemy import String, func, select
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
@@ -20,6 +22,11 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AccessType(str, Enum):
|
||||
ORGANIZATION = "organization"
|
||||
USER = "user"
|
||||
|
||||
|
||||
class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
__abstract__ = True
|
||||
|
||||
@@ -28,46 +35,68 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
|
||||
@classmethod
|
||||
def list(
|
||||
cls, *, db_session: "Session", cursor: Optional[str] = None, limit: Optional[int] = 50, **kwargs
|
||||
) -> List[Type["SqlalchemyBase"]]:
|
||||
"""
|
||||
List records with optional cursor (for pagination), limit, and automatic filtering.
|
||||
def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
|
||||
"""Get a record by ID.
|
||||
|
||||
Args:
|
||||
db_session: The database session to use.
|
||||
cursor: Optional ID to start pagination from.
|
||||
limit: Maximum number of records to return.
|
||||
**kwargs: Filters passed as equality conditions or iterable for IN filtering.
|
||||
db_session: SQLAlchemy session
|
||||
id: Record ID to retrieve
|
||||
|
||||
Returns:
|
||||
A list of model instances matching the filters.
|
||||
Optional[SqlalchemyBase]: The record if found, None otherwise
|
||||
"""
|
||||
logger.debug(f"Listing {cls.__name__} with filters {kwargs}")
|
||||
try:
|
||||
return db_session.query(cls).filter(cls.id == id).first()
|
||||
except DBAPIError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def list(
|
||||
cls,
|
||||
*,
|
||||
db_session: "Session",
|
||||
cursor: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 50,
|
||||
query_text: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> List[Type["SqlalchemyBase"]]:
|
||||
"""List records with advanced filtering and pagination options."""
|
||||
if start_date and end_date and start_date > end_date:
|
||||
raise ValueError("start_date must be earlier than or equal to end_date")
|
||||
|
||||
logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
|
||||
with db_session as session:
|
||||
# Start with a base query
|
||||
query = select(cls)
|
||||
|
||||
# Apply filtering logic
|
||||
for key, value in kwargs.items():
|
||||
column = getattr(cls, key)
|
||||
if isinstance(value, (list, tuple, set)): # Check for iterables
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
query = query.where(column.in_(value))
|
||||
else: # Single value for equality filtering
|
||||
else:
|
||||
query = query.where(column == value)
|
||||
|
||||
# Apply cursor for pagination
|
||||
# Date range filtering
|
||||
if start_date:
|
||||
query = query.filter(cls.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(cls.created_at <= end_date)
|
||||
|
||||
# Cursor-based pagination
|
||||
if cursor:
|
||||
query = query.where(cls.id > cursor)
|
||||
|
||||
# Handle soft deletes if the class has the 'is_deleted' attribute
|
||||
# Apply text search
|
||||
if query_text:
|
||||
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
||||
|
||||
# Handle ordering and soft deletes
|
||||
if hasattr(cls, "is_deleted"):
|
||||
query = query.where(cls.is_deleted == False)
|
||||
|
||||
# Add ordering and limit
|
||||
query = query.order_by(cls.id).limit(limit)
|
||||
|
||||
# Execute the query and return results as model instances
|
||||
return list(session.execute(query).scalars())
|
||||
|
||||
@classmethod
|
||||
@@ -77,6 +106,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
identifier: Optional[str] = None,
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
**kwargs,
|
||||
) -> Type["SqlalchemyBase"]:
|
||||
"""The primary accessor for an ORM record.
|
||||
@@ -108,7 +138,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_conditions.append(", ".join(f"{key}='{value}'" for key, value in kwargs.items()))
|
||||
|
||||
if actor:
|
||||
query = cls.apply_access_predicate(query, actor, access)
|
||||
query = cls.apply_access_predicate(query, actor, access, access_type)
|
||||
query_conditions.append(f"access level in {access} for actor='{actor}'")
|
||||
|
||||
if hasattr(cls, "is_deleted"):
|
||||
@@ -170,12 +200,66 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
session.refresh(self)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def size(
|
||||
cls,
|
||||
*,
|
||||
db_session: "Session",
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
**kwargs,
|
||||
) -> int:
|
||||
"""
|
||||
Get the count of rows that match the provided filters.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
**kwargs: Filters to apply to the query (e.g., column_name=value)
|
||||
|
||||
Returns:
|
||||
int: The count of rows that match the filters
|
||||
|
||||
Raises:
|
||||
DBAPIError: If a database error occurs
|
||||
"""
|
||||
logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
|
||||
|
||||
with db_session as session:
|
||||
query = select(func.count()).select_from(cls)
|
||||
|
||||
if actor:
|
||||
query = cls.apply_access_predicate(query, actor, access, access_type)
|
||||
|
||||
# Apply filtering logic based on kwargs
|
||||
for key, value in kwargs.items():
|
||||
if value:
|
||||
column = getattr(cls, key, None)
|
||||
if not column:
|
||||
raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
|
||||
if isinstance(value, (list, tuple, set)): # Check for iterables
|
||||
query = query.where(column.in_(value))
|
||||
else: # Single value for equality filtering
|
||||
query = query.where(column == value)
|
||||
|
||||
# Handle soft deletes if the class has the 'is_deleted' attribute
|
||||
if hasattr(cls, "is_deleted"):
|
||||
query = query.where(cls.is_deleted == False)
|
||||
|
||||
try:
|
||||
count = session.execute(query).scalar()
|
||||
return count if count else 0
|
||||
except DBAPIError as e:
|
||||
logger.exception(f"Failed to calculate size for {cls.__name__}")
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def apply_access_predicate(
|
||||
cls,
|
||||
query: "Select",
|
||||
actor: "User",
|
||||
access: List[Literal["read", "write", "admin"]],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
) -> "Select":
|
||||
"""applies a WHERE clause restricting results to the given actor and access level
|
||||
Args:
|
||||
@@ -189,10 +273,18 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
the sqlalchemy select statement restricted to the given access.
|
||||
"""
|
||||
del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment
|
||||
org_id = getattr(actor, "organization_id", None)
|
||||
if not org_id:
|
||||
raise ValueError(f"object {actor} has no organization accessor")
|
||||
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
|
||||
if access_type == AccessType.ORGANIZATION:
|
||||
org_id = getattr(actor, "organization_id", None)
|
||||
if not org_id:
|
||||
raise ValueError(f"object {actor} has no organization accessor")
|
||||
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
|
||||
elif access_type == AccessType.USER:
|
||||
user_id = getattr(actor, "id", None)
|
||||
if not user_id:
|
||||
raise ValueError(f"object {actor} has no user accessor")
|
||||
return query.where(cls.user_id == user_id, cls.is_deleted == False)
|
||||
else:
|
||||
raise ValueError(f"unknown access_type: {access_type}")
|
||||
|
||||
@classmethod
|
||||
def _handle_dbapi_error(cls, e: DBAPIError):
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from letta.memory import BaseRecallMemory, EmbeddingArchivalMemory
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message
|
||||
from letta.utils import printd
|
||||
|
||||
|
||||
def parse_formatted_time(formatted_time: str):
|
||||
# parse times returned by letta.utils.get_formatted_time()
|
||||
try:
|
||||
return datetime.strptime(formatted_time.strip(), "%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
except:
|
||||
return datetime.strptime(formatted_time.strip(), "%Y-%m-%d %I:%M:%S %p")
|
||||
|
||||
|
||||
class PersistenceManager(ABC):
|
||||
@abstractmethod
|
||||
def trim_messages(self, num):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prepend_to_messages(self, added_messages):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def append_to_messages(self, added_messages):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def swap_system_message(self, new_system_message):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_memory(self, new_memory):
|
||||
pass
|
||||
|
||||
|
||||
class LocalStateManager(PersistenceManager):
|
||||
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
|
||||
|
||||
recall_memory_cls = BaseRecallMemory
|
||||
archival_memory_cls = EmbeddingArchivalMemory
|
||||
|
||||
def __init__(self, agent_state: AgentState):
|
||||
# Memory held in-state useful for debugging stateful versions
|
||||
self.memory = agent_state.memory
|
||||
# self.messages = [] # current in-context messages
|
||||
# self.all_messages = [] # all messages seen in current session (needed if lazily synchronizing state with DB)
|
||||
self.archival_memory = EmbeddingArchivalMemory(agent_state)
|
||||
self.recall_memory = BaseRecallMemory(agent_state)
|
||||
# self.agent_state = agent_state
|
||||
|
||||
def save(self):
|
||||
"""Ensure storage connectors save data"""
|
||||
self.archival_memory.save()
|
||||
self.recall_memory.save()
|
||||
|
||||
'''
|
||||
def json_to_message(self, message_json) -> Message:
|
||||
"""Convert agent message JSON into Message object"""
|
||||
|
||||
# get message
|
||||
if "message" in message_json:
|
||||
message = message_json["message"]
|
||||
else:
|
||||
message = message_json
|
||||
|
||||
# get timestamp
|
||||
if "timestamp" in message_json:
|
||||
timestamp = parse_formatted_time(message_json["timestamp"])
|
||||
else:
|
||||
timestamp = get_local_time()
|
||||
|
||||
# TODO: change this when we fully migrate to tool calls API
|
||||
if "function_call" in message:
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=message["tool_call_id"],
|
||||
tool_call_type="function",
|
||||
function={
|
||||
"name": message["function_call"]["name"],
|
||||
"arguments": message["function_call"]["arguments"],
|
||||
},
|
||||
)
|
||||
]
|
||||
printd(f"Saving tool calls {[vars(tc) for tc in tool_calls]}")
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
# if message["role"] == "function":
|
||||
# message["role"] = "tool"
|
||||
|
||||
return Message(
|
||||
user_id=self.agent_state.user_id,
|
||||
agent_id=self.agent_state.id,
|
||||
role=message["role"],
|
||||
text=message["content"],
|
||||
name=message["name"] if "name" in message else None,
|
||||
model=self.agent_state.llm_config.model,
|
||||
created_at=timestamp,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=message["tool_call_id"] if "tool_call_id" in message else None,
|
||||
id=message["id"] if "id" in message else None,
|
||||
)
|
||||
'''
|
||||
|
||||
def trim_messages(self, num):
|
||||
# printd(f"InMemoryStateManager.trim_messages")
|
||||
# self.messages = [self.messages[0]] + self.messages[num:]
|
||||
pass
|
||||
|
||||
def prepend_to_messages(self, added_messages: List[Message]):
|
||||
# first tag with timestamps
|
||||
# added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
|
||||
|
||||
printd(f"{self.__class__.__name__}.prepend_to_message")
|
||||
# self.messages = [self.messages[0]] + added_messages + self.messages[1:]
|
||||
|
||||
# add to recall memory
|
||||
self.recall_memory.insert_many([m for m in added_messages])
|
||||
|
||||
def append_to_messages(self, added_messages: List[Message]):
|
||||
# first tag with timestamps
|
||||
# added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
|
||||
|
||||
printd(f"{self.__class__.__name__}.append_to_messages")
|
||||
# self.messages = self.messages + added_messages
|
||||
|
||||
# add to recall memory
|
||||
self.recall_memory.insert_many([m for m in added_messages])
|
||||
|
||||
def swap_system_message(self, new_system_message: Message):
|
||||
# first tag with timestamps
|
||||
# new_system_message = {"timestamp": get_local_time(), "message": new_system_message}
|
||||
|
||||
printd(f"{self.__class__.__name__}.swap_system_message")
|
||||
# self.messages[0] = new_system_message
|
||||
|
||||
# add to recall memory
|
||||
self.recall_memory.insert(new_system_message)
|
||||
|
||||
def update_memory(self, new_memory: Memory):
|
||||
printd(f"{self.__class__.__name__}.update_memory")
|
||||
assert isinstance(new_memory, Memory), type(new_memory)
|
||||
self.memory = new_memory
|
||||
@@ -33,18 +33,19 @@ class LettaBase(BaseModel):
|
||||
def generate_id_field(cls, prefix: Optional[str] = None) -> "Field":
|
||||
prefix = prefix or cls.__id_prefix__
|
||||
|
||||
# TODO: generate ID from regex pattern?
|
||||
def _generate_id() -> str:
|
||||
return f"{prefix}-{uuid.uuid4()}"
|
||||
|
||||
return Field(
|
||||
...,
|
||||
description=cls._id_description(prefix),
|
||||
pattern=cls._id_regex_pattern(prefix),
|
||||
examples=[cls._id_example(prefix)],
|
||||
default_factory=_generate_id,
|
||||
default_factory=cls._generate_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _generate_id(cls, prefix: Optional[str] = None) -> str:
|
||||
prefix = prefix or cls.__id_prefix__
|
||||
return f"{prefix}-{uuid.uuid4()}"
|
||||
|
||||
# def _generate_id(self) -> str:
|
||||
# return f"{self.__id_prefix__}-{uuid.uuid4()}"
|
||||
|
||||
@@ -78,7 +79,7 @@ class LettaBase(BaseModel):
|
||||
"""
|
||||
_ = values # for SCA
|
||||
if isinstance(v, UUID):
|
||||
logger.warning(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!")
|
||||
logger.debug(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!")
|
||||
return f"{cls.__id_prefix__}-{v}"
|
||||
return v
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ class Message(BaseMessage):
|
||||
id: str = BaseMessage.generate_id_field()
|
||||
role: MessageRole = Field(..., description="The role of the participant.")
|
||||
text: Optional[str] = Field(None, description="The text of the message.")
|
||||
user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
|
||||
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
model: Optional[str] = Field(None, description="The model used to make the function call.")
|
||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||
@@ -281,7 +281,6 @@ class Message(BaseMessage):
|
||||
)
|
||||
if id is not None:
|
||||
return Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
@@ -295,7 +294,6 @@ class Message(BaseMessage):
|
||||
)
|
||||
else:
|
||||
return Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
@@ -328,7 +326,6 @@ class Message(BaseMessage):
|
||||
|
||||
if id is not None:
|
||||
return Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
@@ -342,7 +339,6 @@ class Message(BaseMessage):
|
||||
)
|
||||
else:
|
||||
return Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
@@ -375,7 +371,6 @@ class Message(BaseMessage):
|
||||
# If we're going from tool-call style
|
||||
if id is not None:
|
||||
return Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
@@ -389,7 +384,6 @@ class Message(BaseMessage):
|
||||
)
|
||||
else:
|
||||
return Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
# standard fields expected in an OpenAI ChatCompletion message object
|
||||
|
||||
@@ -409,7 +409,7 @@ def get_agent_messages(
|
||||
return server.get_agent_recall_cursor(
|
||||
user_id=actor.id,
|
||||
agent_id=agent_id,
|
||||
before=before,
|
||||
cursor=before,
|
||||
limit=limit,
|
||||
reverse=True,
|
||||
return_message_object=msg_object,
|
||||
|
||||
@@ -77,14 +77,15 @@ from letta.schemas.user import User
|
||||
from letta.services.agents_tags_manager import AgentsTagsManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.blocks_agents_manager import BlocksAgentsManager
|
||||
from letta.services.tools_agents_manager import ToolsAgentsManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.tools_agents_manager import ToolsAgentsManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.utils import create_random_username, get_utc_time, json_dumps, json_loads
|
||||
|
||||
@@ -260,6 +261,7 @@ class SyncServer(Server):
|
||||
self.agents_tags_manager = AgentsTagsManager()
|
||||
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
|
||||
self.blocks_agents_manager = BlocksAgentsManager()
|
||||
self.message_manager = MessageManager()
|
||||
self.tools_agents_manager = ToolsAgentsManager()
|
||||
self.job_manager = JobManager()
|
||||
|
||||
@@ -414,7 +416,7 @@ class SyncServer(Server):
|
||||
agent = OfflineMemoryAgent(agent_state=agent_state, interface=interface, user=actor)
|
||||
elif agent_state.agent_type == AgentType.chat_only_agent:
|
||||
agent = ChatOnlyAgent(agent_state=agent_state, interface=interface, user=actor)
|
||||
else:
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type {agent_state.agent_type}")
|
||||
|
||||
# Rebuild the system prompt - may be linked to new blocks now
|
||||
@@ -422,7 +424,7 @@ class SyncServer(Server):
|
||||
|
||||
# Persist to agent
|
||||
save_agent(agent, self.ms)
|
||||
return agent
|
||||
return agent
|
||||
|
||||
def _step(
|
||||
self,
|
||||
@@ -518,12 +520,7 @@ class SyncServer(Server):
|
||||
letta_agent.interface.print_messages_raw(letta_agent.messages)
|
||||
|
||||
elif command.lower() == "memory":
|
||||
ret_str = (
|
||||
f"\nDumping memory contents:\n"
|
||||
+ f"\n{str(letta_agent.agent_state.memory)}"
|
||||
+ f"\n{str(letta_agent.persistence_manager.archival_memory)}"
|
||||
+ f"\n{str(letta_agent.persistence_manager.recall_memory)}"
|
||||
)
|
||||
ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.archival_memory)}"
|
||||
return ret_str
|
||||
|
||||
elif command.lower() == "pop" or command.lower().startswith("pop "):
|
||||
@@ -625,7 +622,6 @@ class SyncServer(Server):
|
||||
# Convert to a Message object
|
||||
if timestamp:
|
||||
message = Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
role="user",
|
||||
text=packaged_user_message,
|
||||
@@ -633,7 +629,6 @@ class SyncServer(Server):
|
||||
)
|
||||
else:
|
||||
message = Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
role="user",
|
||||
text=packaged_user_message,
|
||||
@@ -672,7 +667,6 @@ class SyncServer(Server):
|
||||
|
||||
if timestamp:
|
||||
message = Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
role="system",
|
||||
text=packaged_system_message,
|
||||
@@ -680,7 +674,6 @@ class SyncServer(Server):
|
||||
)
|
||||
else:
|
||||
message = Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
role="system",
|
||||
text=packaged_system_message,
|
||||
@@ -743,7 +736,6 @@ class SyncServer(Server):
|
||||
# Create the Message object
|
||||
message_objects.append(
|
||||
Message(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
role=message.role,
|
||||
text=message.text,
|
||||
@@ -876,7 +868,7 @@ class SyncServer(Server):
|
||||
else:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
init_messages.append(Message(role=message.role, text=packed_message, user_id=user_id, agent_id=agent_state.id))
|
||||
init_messages.append(Message(role=message.role, text=packed_message, agent_id=agent_state.id))
|
||||
# init_messages = [Message.dict_to_message(user_id=user_id, agent_id=agent_state.id, openai_message_dict=message.model_dump()) for message in request.initial_message_sequence]
|
||||
else:
|
||||
init_messages = None
|
||||
@@ -1160,11 +1152,11 @@ class SyncServer(Server):
|
||||
|
||||
def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
return ArchivalMemorySummary(size=len(agent.persistence_manager.archival_memory))
|
||||
return ArchivalMemorySummary(size=len(agent.archival_memory))
|
||||
|
||||
def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
return RecallMemorySummary(size=len(agent.persistence_manager.recall_memory))
|
||||
return RecallMemorySummary(size=len(agent.message_manager))
|
||||
|
||||
def get_in_context_message_ids(self, agent_id: str) -> List[str]:
|
||||
"""Get the message ids of the in-context messages in the agent's memory"""
|
||||
@@ -1182,7 +1174,7 @@ class SyncServer(Server):
|
||||
"""Get a single message from the agent's memory"""
|
||||
# Get the agent object (loaded in memory)
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
message = agent.persistence_manager.recall_memory.storage.get(id=message_id)
|
||||
message = agent.message_manager.get_message_by_id(id=message_id, actor=self.default_user)
|
||||
return message
|
||||
|
||||
def get_agent_messages(
|
||||
@@ -1213,14 +1205,16 @@ class SyncServer(Server):
|
||||
|
||||
else:
|
||||
# need to access persistence manager for additional messages
|
||||
db_iterator = letta_agent.persistence_manager.recall_memory.storage.get_all_paginated(page_size=count, offset=start)
|
||||
|
||||
# get a single page of messages
|
||||
# TODO: handle stop iteration
|
||||
page = next(db_iterator, [])
|
||||
# get messages using message manager
|
||||
page = letta_agent.message_manager.list_user_messages_for_agent(
|
||||
agent_id=agent_id,
|
||||
actor=self.default_user,
|
||||
cursor=start,
|
||||
limit=count,
|
||||
)
|
||||
|
||||
# return messages in reverse chronological order
|
||||
messages = sorted(page, key=lambda x: x.created_at, reverse=True)
|
||||
messages = page
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
|
||||
## Convert to json
|
||||
@@ -1243,7 +1237,7 @@ class SyncServer(Server):
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# iterate over records
|
||||
db_iterator = letta_agent.persistence_manager.archival_memory.storage.get_all_paginated(page_size=count, offset=start)
|
||||
db_iterator = letta_agent.archival_memory.storage.get_all_paginated(page_size=count, offset=start)
|
||||
|
||||
# get a single page of messages
|
||||
page = next(db_iterator, [])
|
||||
@@ -1268,7 +1262,7 @@ class SyncServer(Server):
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# iterate over recorde
|
||||
cursor, records = letta_agent.persistence_manager.archival_memory.storage.get_all_cursor(
|
||||
cursor, records = letta_agent.archival_memory.storage.get_all_cursor(
|
||||
after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
|
||||
)
|
||||
return records
|
||||
@@ -1283,14 +1277,14 @@ class SyncServer(Server):
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# Insert into archival memory
|
||||
passage_ids = letta_agent.persistence_manager.archival_memory.insert(memory_string=memory_contents, return_ids=True)
|
||||
passage_ids = letta_agent.archival_memory.insert(memory_string=memory_contents, return_ids=True)
|
||||
|
||||
# Update the agent
|
||||
# TODO: should this update the system prompt?
|
||||
save_agent(letta_agent, self.ms)
|
||||
|
||||
# TODO: this is gross, fix
|
||||
return [letta_agent.persistence_manager.archival_memory.storage.get(id=passage_id) for passage_id in passage_ids]
|
||||
return [letta_agent.archival_memory.storage.get(id=passage_id) for passage_id in passage_ids]
|
||||
|
||||
def delete_archival_memory(self, user_id: str, agent_id: str, memory_id: str):
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
@@ -1305,7 +1299,7 @@ class SyncServer(Server):
|
||||
|
||||
# Delete by ID
|
||||
# TODO check if it exists first, and throw error if not
|
||||
letta_agent.persistence_manager.archival_memory.storage.delete({"id": memory_id})
|
||||
letta_agent.archival_memory.storage.delete({"id": memory_id})
|
||||
|
||||
# TODO: return archival memory
|
||||
|
||||
@@ -1313,17 +1307,15 @@ class SyncServer(Server):
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
cursor: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
order_by: Optional[str] = "created_at",
|
||||
order: Optional[str] = "asc",
|
||||
reverse: Optional[bool] = False,
|
||||
return_message_object: bool = True,
|
||||
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
) -> Union[List[Message], List[LettaMessage]]:
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
actor = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
if actor is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
@@ -1332,8 +1324,12 @@ class SyncServer(Server):
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# iterate over records
|
||||
cursor, records = letta_agent.persistence_manager.recall_memory.storage.get_all_cursor(
|
||||
after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
|
||||
# TODO: Check "order_by", "order"
|
||||
records = letta_agent.message_manager.list_messages_for_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
assert all(isinstance(m, Message) for m in records)
|
||||
@@ -1353,7 +1349,7 @@ class SyncServer(Server):
|
||||
records = records[::-1]
|
||||
|
||||
return records
|
||||
|
||||
|
||||
def get_server_config(self, include_defaults: bool = False) -> dict:
|
||||
"""Return the base config"""
|
||||
|
||||
@@ -1425,19 +1421,25 @@ class SyncServer(Server):
|
||||
self.agents_tags_manager.delete_all_tags_from_agent(agent_id=agent_id, actor=actor)
|
||||
self.blocks_agents_manager.remove_all_agent_blocks(agent_id=agent_id)
|
||||
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
|
||||
# Verify that the agent exists and belongs to the org of the user
|
||||
agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
|
||||
if not agent_state:
|
||||
if agent_state is None:
|
||||
raise ValueError(f"Could not find agent_id={agent_id} under user_id={user_id}")
|
||||
|
||||
agent_state_user = self.user_manager.get_user_by_id(user_id=agent_state.user_id)
|
||||
if agent_state_user.organization_id != actor.organization_id:
|
||||
raise ValueError(
|
||||
f"Could not authorize agent_id={agent_id} with user_id={user_id} because of differing organizations; agent_id was created in {agent_state_user.organization_id} while user belongs to {actor.organization_id}. How did they get the agent id?"
|
||||
)
|
||||
# TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM MODEL
|
||||
messages = self.message_manager.list_messages_for_agent(agent_id=agent_state.id)
|
||||
for message in messages:
|
||||
self.message_manager.delete_message_by_id(message.id, actor=actor)
|
||||
|
||||
# TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM
|
||||
try:
|
||||
agent_state_user = self.user_manager.get_user_by_id(user_id=agent_state.user_id)
|
||||
if agent_state_user.organization_id != actor.organization_id:
|
||||
raise ValueError(
|
||||
f"Could not authorize agent_id={agent_id} with user_id={user_id} because of differing organizations; agent_id was created in {agent_state_user.organization_id} while user belongs to {actor.organization_id}. How did they get the agent id?"
|
||||
)
|
||||
except NoResultFound:
|
||||
logger.error(f"Agent with id {agent_state.id} has nonexistent user {agent_state.user_id}")
|
||||
|
||||
# First, if the agent is in the in-memory cache we should remove it
|
||||
# List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
|
||||
@@ -1582,7 +1584,7 @@ class SyncServer(Server):
|
||||
|
||||
# delete all Passage objects with source_id==source_id from agent's archival memory
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
archival_memory = agent.persistence_manager.archival_memory
|
||||
archival_memory = agent.archival_memory
|
||||
archival_memory.storage.delete({"source_id": source_id})
|
||||
|
||||
# delete agent-source mapping
|
||||
@@ -1661,7 +1663,7 @@ class SyncServer(Server):
|
||||
"""Get a single message from the agent's memory"""
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
message = letta_agent.persistence_manager.recall_memory.storage.get(id=message_id)
|
||||
message = letta_agent.message_manager.get_message_by_id(id=message_id)
|
||||
save_agent(letta_agent, self.ms)
|
||||
return message
|
||||
|
||||
@@ -1705,7 +1707,7 @@ class SyncServer(Server):
|
||||
|
||||
try:
|
||||
return self.user_manager.get_user_by_id(user_id=user_id)
|
||||
except ValueError:
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"User with id {user_id} not found")
|
||||
|
||||
def get_organization_or_default(self, org_id: Optional[str]) -> Organization:
|
||||
@@ -1715,7 +1717,7 @@ class SyncServer(Server):
|
||||
|
||||
try:
|
||||
return self.organization_manager.get_organization_by_id(org_id=org_id)
|
||||
except ValueError:
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
|
||||
@@ -102,7 +102,7 @@ class BlockManager:
|
||||
return [block.to_pydantic() for block in blocks]
|
||||
|
||||
@enforce_types
|
||||
def get_block_by_id(self, block_id, actor: PydanticUser) -> Optional[PydanticBlock]:
|
||||
def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
|
||||
"""Retrieve a block by its name."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
|
||||
182
letta/services/message_manager.py
Normal file
182
letta/services/message_manager.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class MessageManager:
|
||||
"""Manager class to handle business logic related to Messages."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
|
||||
"""Fetch a message by ID."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
message = MessageModel.read(db_session=session, identifier=message_id, actor=actor)
|
||||
return message.to_pydantic()
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
|
||||
"""Create a new message."""
|
||||
with self.session_maker() as session:
|
||||
# Set the organization id of the Pydantic message
|
||||
pydantic_msg.organization_id = actor.organization_id
|
||||
msg_data = pydantic_msg.model_dump()
|
||||
msg = MessageModel(**msg_data)
|
||||
msg.create(session, actor=actor) # Persist to database
|
||||
return msg.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def create_many_messages(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]:
|
||||
"""Create multiple messages."""
|
||||
return [self.create_message(m, actor=actor) for m in pydantic_msgs]
|
||||
|
||||
@enforce_types
|
||||
def update_message_by_id(self, message_id: str, message: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
|
||||
"""
|
||||
Updates an existing record in the database with values from the provided record object.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Fetch existing message from database
|
||||
msg = MessageModel.read(
|
||||
db_session=session,
|
||||
identifier=message_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Update the database record with values from the provided record
|
||||
for column in MessageModel.__table__.columns:
|
||||
column_name = column.name
|
||||
if hasattr(message, column_name):
|
||||
new_value = getattr(message, column_name)
|
||||
setattr(msg, column_name, new_value)
|
||||
|
||||
# Commit changes
|
||||
return msg.update(db_session=session, actor=actor).to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool:
|
||||
"""Delete a message."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
msg = MessageModel.read(
|
||||
db_session=session,
|
||||
identifier=message_id,
|
||||
actor=actor,
|
||||
)
|
||||
msg.hard_delete(session, actor=actor)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Message with id {message_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
role: Optional[MessageRole] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Get the total count of messages with optional filters.
|
||||
|
||||
Args:
|
||||
actor: The user requesting the count
|
||||
role: The role of the message
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id)
|
||||
|
||||
@enforce_types
|
||||
def list_user_messages_for_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
cursor: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 50,
|
||||
filters: Optional[Dict] = None,
|
||||
query_text: Optional[str] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
"""List user messages with flexible filtering and pagination options.
|
||||
|
||||
Args:
|
||||
cursor: Cursor-based pagination - return records after this ID (exclusive)
|
||||
start_date: Filter records created after this date
|
||||
end_date: Filter records created before this date
|
||||
limit: Maximum number of records to return
|
||||
filters: Additional filters to apply
|
||||
query_text: Optional text to search for in message content
|
||||
|
||||
Returns:
|
||||
List[PydanticMessage] - List of messages matching the criteria
|
||||
"""
|
||||
message_filters = {"role": "user"}
|
||||
if filters:
|
||||
message_filters.update(filters)
|
||||
|
||||
return self.list_messages_for_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
cursor=cursor,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
filters=message_filters,
|
||||
query_text=query_text,
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
def list_messages_for_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
cursor: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 50,
|
||||
filters: Optional[Dict] = None,
|
||||
query_text: Optional[str] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
"""List messages with flexible filtering and pagination options.
|
||||
|
||||
Args:
|
||||
cursor: Cursor-based pagination - return records after this ID (exclusive)
|
||||
start_date: Filter records created after this date
|
||||
end_date: Filter records created before this date
|
||||
limit: Maximum number of records to return
|
||||
filters: Additional filters to apply
|
||||
query_text: Optional text to search for in message content
|
||||
|
||||
Returns:
|
||||
List[PydanticMessage] - List of messages matching the criteria
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Start with base filters
|
||||
message_filters = {"agent_id": agent_id}
|
||||
if actor:
|
||||
message_filters.update({"organization_id": actor.organization_id})
|
||||
if filters:
|
||||
message_filters.update(filters)
|
||||
|
||||
results = MessageModel.list(
|
||||
db_session=session,
|
||||
cursor=cursor,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
query_text=query_text,
|
||||
**message_filters,
|
||||
)
|
||||
|
||||
return [msg.to_pydantic() for msg in results]
|
||||
@@ -30,19 +30,16 @@ class OrganizationManager:
|
||||
def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]:
|
||||
"""Fetch an organization by ID."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
organization = OrganizationModel.read(db_session=session, identifier=org_id)
|
||||
return organization.to_pydantic()
|
||||
except NoResultFound:
|
||||
return None
|
||||
organization = OrganizationModel.read(db_session=session, identifier=org_id)
|
||||
return organization.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
|
||||
"""Create a new organization. If a name is provided, it is used, otherwise, a random one is generated."""
|
||||
org = self.get_organization_by_id(pydantic_org.id)
|
||||
if org:
|
||||
"""Create a new organization."""
|
||||
try:
|
||||
org = self.get_organization_by_id(pydantic_org.id)
|
||||
return org
|
||||
else:
|
||||
except NoResultFound:
|
||||
return self._create_organization(pydantic_org=pydantic_org)
|
||||
|
||||
@enforce_types
|
||||
|
||||
@@ -141,5 +141,5 @@ class SourceManager:
|
||||
"""Delete a file by its ID."""
|
||||
with self.session_maker() as session:
|
||||
file = FileMetadataModel.read(db_session=session, identifier=file_id)
|
||||
file.delete(db_session=session, actor=actor)
|
||||
file.hard_delete(db_session=session, actor=actor)
|
||||
return file.to_pydantic()
|
||||
|
||||
@@ -122,7 +122,7 @@ class ToolManager:
|
||||
tool.json_schema = new_schema
|
||||
|
||||
# Save the updated tool to the database
|
||||
return tool.update(db_session=session, actor=actor)
|
||||
return tool.update(db_session=session, actor=actor).to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None:
|
||||
|
||||
@@ -71,7 +71,7 @@ class UserManager:
|
||||
with self.session_maker() as session:
|
||||
# Delete from user table
|
||||
user = UserModel.read(db_session=session, identifier=user_id)
|
||||
user.delete(session)
|
||||
user.hard_delete(session)
|
||||
|
||||
# TODO: Integrate this via the ORM models for the Agent, Source, and AgentSourceMapping
|
||||
# Cascade delete for related models: Agent, Source, AgentSourceMapping
|
||||
|
||||
@@ -1,5 +1,24 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.settings import tool_settings
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_e2b_api_key_none():
|
||||
# Store the original value of e2b_api_key
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
|
||||
# Set e2b_api_key to None
|
||||
tool_settings.e2b_api_key = None
|
||||
|
||||
# Yield control to the test
|
||||
yield
|
||||
|
||||
# Restore the original value of e2b_api_key
|
||||
tool_settings.e2b_api_key = original_api_key
|
||||
|
||||
@@ -58,21 +58,6 @@ def clear_tables():
|
||||
Sandbox.connect(sandbox.sandbox_id).kill()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_e2b_api_key_none():
|
||||
# Store the original value of e2b_api_key
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
|
||||
# Set e2b_api_key to None
|
||||
tool_settings.e2b_api_key = None
|
||||
|
||||
# Yield control to the test
|
||||
yield
|
||||
|
||||
# Restore the original value of e2b_api_key
|
||||
tool_settings.e2b_api_key = original_api_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def check_e2b_key_is_set():
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
|
||||
@@ -5,7 +5,6 @@ import pytest
|
||||
from letta import create_client
|
||||
from letta.schemas.letta_message import FunctionCallMessage
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.settings import tool_settings
|
||||
from tests.helpers.endpoints_helper import (
|
||||
assert_invoked_function_call,
|
||||
assert_invoked_send_message_with_keyword,
|
||||
@@ -20,21 +19,6 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
|
||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_e2b_api_key_none():
|
||||
# Store the original value of e2b_api_key
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
|
||||
# Set e2b_api_key to None
|
||||
tool_settings.e2b_api_key = None
|
||||
|
||||
# Yield control to the test
|
||||
yield
|
||||
|
||||
# Restore the original value of e2b_api_key
|
||||
tool_settings.e2b_api_key = original_api_key
|
||||
|
||||
|
||||
"""Contrived tools for this test case"""
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType
|
||||
from letta.settings import tool_settings
|
||||
from letta.utils import create_random_username
|
||||
|
||||
# Constants
|
||||
@@ -40,7 +39,7 @@ def run_server():
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[{"server": False}], # whether to use REST API server
|
||||
params=[{"server": False}, {"server": True}], # whether to use REST API server
|
||||
scope="module",
|
||||
)
|
||||
def client(request):
|
||||
@@ -83,21 +82,6 @@ def clear_tables():
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_e2b_api_key_none():
|
||||
# Store the original value of e2b_api_key
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
|
||||
# Set e2b_api_key to None
|
||||
tool_settings.e2b_api_key = None
|
||||
|
||||
# Yield control to the test
|
||||
yield
|
||||
|
||||
# Restore the original value of e2b_api_key
|
||||
tool_settings.e2b_api_key = original_api_key
|
||||
|
||||
|
||||
def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]):
|
||||
"""
|
||||
Test sandbox config and environment variable functions for both LocalClient and RESTClient.
|
||||
|
||||
@@ -27,9 +27,12 @@ from letta.schemas.letta_message import (
|
||||
)
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import model_settings, tool_settings
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings
|
||||
from tests.helpers.client_helper import upload_file_using_client
|
||||
|
||||
# from tests.utils import create_config
|
||||
@@ -54,21 +57,6 @@ def run_server():
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_e2b_api_key_none():
|
||||
# Store the original value of e2b_api_key
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
|
||||
# Set e2b_api_key to None
|
||||
tool_settings.e2b_api_key = None
|
||||
|
||||
# Yield control to the test
|
||||
yield
|
||||
|
||||
# Restore the original value of e2b_api_key
|
||||
tool_settings.e2b_api_key = original_api_key
|
||||
|
||||
|
||||
# Fixture to create clients with different configurations
|
||||
@pytest.fixture(
|
||||
# params=[{"server": True}, {"server": False}], # whether to use REST API server
|
||||
@@ -119,6 +107,22 @@ def agent(client: Union[LocalClient, RESTClient]):
|
||||
client.delete_agent(agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_organization():
|
||||
"""Fixture to create and return the default organization."""
|
||||
manager = OrganizationManager()
|
||||
org = manager.create_default_organization()
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_user(default_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
manager = UserManager()
|
||||
user = manager.create_default_user(org_id=default_organization.id)
|
||||
yield user
|
||||
|
||||
|
||||
def test_agent(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
||||
# test client.rename_agent
|
||||
@@ -612,7 +616,7 @@ def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTCli
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_agents():
|
||||
def cleanup_agents(client):
|
||||
created_agents = []
|
||||
yield created_agents
|
||||
# Cleanup will run even if test fails
|
||||
@@ -624,7 +628,7 @@ def cleanup_agents():
|
||||
|
||||
|
||||
# NOTE: we need to add this back once agents can also create blocks during agent creation
|
||||
def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]):
|
||||
def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str], default_user):
|
||||
"""Test that we can set an initial message sequence
|
||||
|
||||
If we pass in None, we should get a "default" message sequence
|
||||
@@ -638,11 +642,12 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
|
||||
reference_init_messages = initialize_message_sequence(
|
||||
model=agent.llm_config.model,
|
||||
system=agent.system,
|
||||
agent_id=agent.id,
|
||||
memory=agent.memory,
|
||||
archival_memory=None,
|
||||
recall_memory=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# system, login message, send_message test, send_message receipt
|
||||
@@ -661,24 +666,8 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
|
||||
# Test with empty sequence
|
||||
empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[])
|
||||
cleanup_agents.append(empty_agent_state.id)
|
||||
# NOTE: allowed to be None initially
|
||||
# assert empty_agent_state.message_ids is not None
|
||||
# assert len(empty_agent_state.message_ids) == 1, f"Expected 0 messages, got {len(empty_agent_state.message_ids)}"
|
||||
|
||||
# Test with custom sequence
|
||||
# custom_sequence = [
|
||||
# Message(
|
||||
# role=MessageRole.user,
|
||||
# text="Hello, how are you?",
|
||||
# user_id=agent.user_id,
|
||||
# agent_id=agent.id,
|
||||
# model=agent.llm_config.model,
|
||||
# name=None,
|
||||
# tool_calls=None,
|
||||
# tool_call_id=None,
|
||||
# ),
|
||||
# ]
|
||||
custom_sequence = [{"text": "Hello, how are you?", "role": "user"}]
|
||||
custom_sequence = [Message(**{"text": "Hello, how are you?", "role": "user"})]
|
||||
custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence)
|
||||
cleanup_agents.append(custom_agent_state.id)
|
||||
assert custom_agent_state.message_ids is not None
|
||||
@@ -687,7 +676,7 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
|
||||
), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
|
||||
# assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]
|
||||
# shoule be contained in second message (after system message)
|
||||
assert custom_sequence[0]["text"] in client.get_in_context_messages(custom_agent_state.id)[1].text
|
||||
assert custom_sequence[0].text in client.get_in_context_messages(custom_agent_state.id)[1].text
|
||||
|
||||
|
||||
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
||||
@@ -8,7 +8,6 @@ from letta.schemas.agent import PersistedAgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory
|
||||
from letta.schemas.tool import ToolCreate
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -293,6 +292,10 @@ def test_tools(client: LocalClient):
|
||||
"""
|
||||
print(msg)
|
||||
|
||||
# Clean all tools first
|
||||
for tool in client.list_tools():
|
||||
client.delete_tool(tool.id)
|
||||
|
||||
# create tool
|
||||
tool = client.create_or_update_tool(func=print_tool, tags=["extras"])
|
||||
|
||||
@@ -330,49 +333,50 @@ def test_tools_from_composio_basic(client: LocalClient):
|
||||
# The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable
|
||||
|
||||
|
||||
def test_tools_from_langchain(client: LocalClient):
|
||||
# create langchain tool
|
||||
from langchain_community.tools import WikipediaQueryRun
|
||||
from langchain_community.utilities import WikipediaAPIWrapper
|
||||
|
||||
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
|
||||
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
|
||||
|
||||
# Add the tool
|
||||
tool = client.load_langchain_tool(
|
||||
langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
|
||||
)
|
||||
|
||||
# list tools
|
||||
tools = client.list_tools()
|
||||
assert tool.name in [t.name for t in tools]
|
||||
|
||||
# get tool
|
||||
tool_id = client.get_tool_id(name=tool.name)
|
||||
retrieved_tool = client.get_tool(tool_id)
|
||||
source_code = retrieved_tool.source_code
|
||||
|
||||
# Parse the function and attempt to use it
|
||||
local_scope = {}
|
||||
exec(source_code, {}, local_scope)
|
||||
func = local_scope[tool.name]
|
||||
|
||||
expected_content = "Albert Einstein"
|
||||
assert expected_content in func(query="Albert Einstein")
|
||||
|
||||
|
||||
def test_tool_creation_langchain_missing_imports(client: LocalClient):
|
||||
# create langchain tool
|
||||
from langchain_community.tools import WikipediaQueryRun
|
||||
from langchain_community.utilities import WikipediaAPIWrapper
|
||||
|
||||
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
|
||||
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
|
||||
|
||||
# Translate to memGPT Tool
|
||||
# Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
|
||||
with pytest.raises(RuntimeError):
|
||||
ToolCreate.from_langchain(langchain_tool)
|
||||
# TODO: Langchain seems to have issues with Pydantic
|
||||
# TODO: Langchain tools are breaking every two weeks bc of changes on their side
|
||||
# def test_tools_from_langchain(client: LocalClient):
|
||||
# # create langchain tool
|
||||
# from langchain_community.tools import WikipediaQueryRun
|
||||
# from langchain_community.utilities import WikipediaAPIWrapper
|
||||
#
|
||||
# langchain_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
|
||||
#
|
||||
# # Add the tool
|
||||
# tool = client.load_langchain_tool(
|
||||
# langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
|
||||
# )
|
||||
#
|
||||
# # list tools
|
||||
# tools = client.list_tools()
|
||||
# assert tool.name in [t.name for t in tools]
|
||||
#
|
||||
# # get tool
|
||||
# tool_id = client.get_tool_id(name=tool.name)
|
||||
# retrieved_tool = client.get_tool(tool_id)
|
||||
# source_code = retrieved_tool.source_code
|
||||
#
|
||||
# # Parse the function and attempt to use it
|
||||
# local_scope = {}
|
||||
# exec(source_code, {}, local_scope)
|
||||
# func = local_scope[tool.name]
|
||||
#
|
||||
# expected_content = "Albert Einstein"
|
||||
# assert expected_content in func(query="Albert Einstein")
|
||||
#
|
||||
#
|
||||
# def test_tool_creation_langchain_missing_imports(client: LocalClient):
|
||||
# # create langchain tool
|
||||
# from langchain_community.tools import WikipediaQueryRun
|
||||
# from langchain_community.utilities import WikipediaAPIWrapper
|
||||
#
|
||||
# api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
|
||||
# langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
|
||||
#
|
||||
# # Translate to memGPT Tool
|
||||
# # Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
|
||||
# with pytest.raises(RuntimeError):
|
||||
# ToolCreate.from_langchain(langchain_tool)
|
||||
|
||||
|
||||
def test_shared_blocks_without_send_message(client: LocalClient):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
@@ -11,6 +12,7 @@ from letta.orm import (
|
||||
BlocksAgents,
|
||||
FileMetadata,
|
||||
Job,
|
||||
Message,
|
||||
Organization,
|
||||
SandboxConfig,
|
||||
SandboxEnvironmentVariable,
|
||||
@@ -29,11 +31,12 @@ from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.sandbox_config import (
|
||||
E2BSandboxConfig,
|
||||
@@ -47,7 +50,7 @@ from letta.schemas.sandbox_config import (
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.source import SourceUpdate
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.schemas.tool import ToolUpdate
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
@@ -77,6 +80,7 @@ using_sqlite = not bool(os.getenv("LETTA_PG_URI"))
|
||||
def clear_tables(server: SyncServer):
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Message))
|
||||
session.execute(delete(Job))
|
||||
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
|
||||
session.execute(delete(BlocksAgents))
|
||||
@@ -191,6 +195,21 @@ def print_tool(server: SyncServer, default_user, default_organization):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hello_world_message_fixture(server: SyncServer, default_user, sarah_agent):
|
||||
"""Fixture to create a tool with default settings and clean up after the test."""
|
||||
# Set up message
|
||||
message = PydanticMessage(
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role="user",
|
||||
text="Hello, world!",
|
||||
)
|
||||
|
||||
msg = server.message_manager.create_message(message, actor=default_user)
|
||||
yield msg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sandbox_config_fixture(server: SyncServer, default_user):
|
||||
sandbox_config_create = SandboxConfigCreate(
|
||||
@@ -274,6 +293,7 @@ def other_tool(server: SyncServer, default_user, default_organization):
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
@@ -548,6 +568,163 @@ def test_delete_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
assert len(tools) == 0
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Message Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_message_create(server: SyncServer, hello_world_message_fixture, default_user):
|
||||
"""Test creating a message using hello_world_message_fixture fixture"""
|
||||
assert hello_world_message_fixture.id is not None
|
||||
assert hello_world_message_fixture.text == "Hello, world!"
|
||||
assert hello_world_message_fixture.role == "user"
|
||||
|
||||
# Verify we can retrieve it
|
||||
retrieved = server.message_manager.get_message_by_id(
|
||||
hello_world_message_fixture.id,
|
||||
actor=default_user,
|
||||
)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == hello_world_message_fixture.id
|
||||
assert retrieved.text == hello_world_message_fixture.text
|
||||
assert retrieved.role == hello_world_message_fixture.role
|
||||
|
||||
|
||||
def test_message_get_by_id(server: SyncServer, hello_world_message_fixture, default_user):
|
||||
"""Test retrieving a message by ID"""
|
||||
retrieved = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == hello_world_message_fixture.id
|
||||
assert retrieved.text == hello_world_message_fixture.text
|
||||
|
||||
|
||||
def test_message_update(server: SyncServer, hello_world_message_fixture, default_user):
|
||||
"""Test updating a message"""
|
||||
new_text = "Updated text"
|
||||
hello_world_message_fixture.text = new_text
|
||||
updated = server.message_manager.update_message_by_id(hello_world_message_fixture.id, hello_world_message_fixture, actor=default_user)
|
||||
assert updated is not None
|
||||
assert updated.text == new_text
|
||||
retrieved = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user)
|
||||
assert retrieved.text == new_text
|
||||
|
||||
|
||||
def test_message_delete(server: SyncServer, hello_world_message_fixture, default_user):
|
||||
"""Test deleting a message"""
|
||||
server.message_manager.delete_message_by_id(hello_world_message_fixture.id, actor=default_user)
|
||||
retrieved = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user)
|
||||
assert retrieved is None
|
||||
|
||||
|
||||
def test_message_size(server: SyncServer, hello_world_message_fixture, default_user):
|
||||
"""Test counting messages with filters"""
|
||||
base_message = hello_world_message_fixture
|
||||
|
||||
# Create additional test messages
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, text=f"Test message {i}"
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
server.message_manager.create_many_messages(messages, actor=default_user)
|
||||
|
||||
# Test total count
|
||||
total = server.message_manager.size(actor=default_user, role=MessageRole.user)
|
||||
assert total == 6 # login message + base message + 4 test messages
|
||||
# TODO: change login message to be a system not user message
|
||||
|
||||
# Test count with agent filter
|
||||
agent_count = server.message_manager.size(actor=default_user, agent_id=base_message.agent_id, role=MessageRole.user)
|
||||
assert agent_count == 6
|
||||
|
||||
# Test count with role filter
|
||||
role_count = server.message_manager.size(actor=default_user, role=base_message.role)
|
||||
assert role_count == 6
|
||||
|
||||
# Test count with non-existent filter
|
||||
empty_count = server.message_manager.size(actor=default_user, agent_id="non-existent", role=MessageRole.user)
|
||||
assert empty_count == 0
|
||||
|
||||
|
||||
def create_test_messages(server: SyncServer, base_message: PydanticMessage, default_user) -> list[PydanticMessage]:
|
||||
"""Helper function to create test messages for all tests"""
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, text=f"Test message {i}"
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
server.message_manager.create_many_messages(messages, actor=default_user)
|
||||
return messages
|
||||
|
||||
|
||||
def test_message_listing_basic(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
|
||||
"""Test basic message listing with limit"""
|
||||
create_test_messages(server, hello_world_message_fixture, default_user)
|
||||
|
||||
results = server.message_manager.list_user_messages_for_agent(agent_id=sarah_agent.id, limit=3, actor=default_user)
|
||||
assert len(results) == 3
|
||||
|
||||
|
||||
def test_message_listing_cursor(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
|
||||
"""Test cursor-based pagination functionality"""
|
||||
create_test_messages(server, hello_world_message_fixture, default_user)
|
||||
|
||||
# Make sure there are 5 messages
|
||||
assert server.message_manager.size(actor=default_user, role=MessageRole.user) == 6
|
||||
|
||||
# Get first page
|
||||
first_page = server.message_manager.list_user_messages_for_agent(agent_id=sarah_agent.id, actor=default_user, limit=3)
|
||||
assert len(first_page) == 3
|
||||
|
||||
last_id_on_first_page = first_page[-1].id
|
||||
|
||||
# Get second page
|
||||
second_page = server.message_manager.list_user_messages_for_agent(
|
||||
agent_id=sarah_agent.id, actor=default_user, cursor=last_id_on_first_page, limit=3
|
||||
)
|
||||
assert len(second_page) == 3 # Should have 2 remaining messages
|
||||
assert all(r1.id != r2.id for r1 in first_page for r2 in second_page)
|
||||
|
||||
|
||||
def test_message_listing_filtering(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
|
||||
"""Test filtering messages by agent ID"""
|
||||
create_test_messages(server, hello_world_message_fixture, default_user)
|
||||
|
||||
agent_results = server.message_manager.list_user_messages_for_agent(agent_id=sarah_agent.id, actor=default_user, limit=10)
|
||||
assert len(agent_results) == 6 # login message + base message + 4 test messages
|
||||
assert all(msg.agent_id == hello_world_message_fixture.agent_id for msg in agent_results)
|
||||
|
||||
|
||||
def test_message_listing_text_search(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
|
||||
"""Test searching messages by text content"""
|
||||
create_test_messages(server, hello_world_message_fixture, default_user)
|
||||
|
||||
search_results = server.message_manager.list_user_messages_for_agent(
|
||||
agent_id=sarah_agent.id, actor=default_user, query_text="Test message", limit=10
|
||||
)
|
||||
assert len(search_results) == 4
|
||||
assert all("Test message" in msg.text for msg in search_results)
|
||||
|
||||
# Test no results
|
||||
search_results = server.message_manager.list_user_messages_for_agent(
|
||||
agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10
|
||||
)
|
||||
assert len(search_results) == 0
|
||||
|
||||
|
||||
def test_message_listing_date_range_filtering(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
|
||||
"""Test filtering messages by date range"""
|
||||
create_test_messages(server, hello_world_message_fixture, default_user)
|
||||
now = datetime.utcnow()
|
||||
|
||||
date_results = server.message_manager.list_user_messages_for_agent(
|
||||
agent_id=sarah_agent.id, actor=default_user, start_date=now - timedelta(minutes=1), end_date=now + timedelta(minutes=1), limit=10
|
||||
)
|
||||
assert len(date_results) > 0
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block Manager Tests
|
||||
# ======================================================================================================================
|
||||
@@ -1211,9 +1388,7 @@ def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent,
|
||||
|
||||
# Change the tool name
|
||||
new_name = "banana"
|
||||
tool = server.tool_manager.update_tool_by_id(
|
||||
tool_id=print_tool.id, tool_update=ToolUpdate(name=new_name), actor=default_user
|
||||
)
|
||||
tool = server.tool_manager.update_tool_by_id(tool_id=print_tool.id, tool_update=ToolUpdate(name=new_name), actor=default_user)
|
||||
assert tool.name == new_name
|
||||
|
||||
# Get the association
|
||||
@@ -1225,9 +1400,7 @@ def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent,
|
||||
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
|
||||
def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user):
|
||||
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||
server.tools_agents_manager.add_tool_to_agent(
|
||||
agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name"
|
||||
)
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name")
|
||||
|
||||
|
||||
def test_add_tool_to_agent_duplicate_name(server, sarah_agent, default_user, print_tool, other_tool):
|
||||
@@ -1240,9 +1413,7 @@ def test_add_tool_to_agent_duplicate_name(server, sarah_agent, default_user, pri
|
||||
def test_remove_tool_with_name_from_agent(server, sarah_agent, default_user, print_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
|
||||
removed_tool = server.tools_agents_manager.remove_tool_with_name_from_agent(
|
||||
agent_id=sarah_agent.id, tool_name=print_tool.name
|
||||
)
|
||||
removed_tool = server.tools_agents_manager.remove_tool_with_name_from_agent(agent_id=sarah_agent.id, tool_name=print_tool.name)
|
||||
|
||||
assert removed_tool.tool_name == print_tool.name
|
||||
assert removed_tool.tool_id == print_tool.id
|
||||
@@ -1280,6 +1451,7 @@ def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user,
|
||||
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
|
||||
|
||||
def test_remove_all_agent_tools(server, sarah_agent, default_user, print_tool, other_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
|
||||
@@ -1290,6 +1462,7 @@ def test_remove_all_agent_tools(server, sarah_agent, default_user, print_tool, o
|
||||
|
||||
assert not retrieved_tool_ids
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# JobManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -1,30 +1,40 @@
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from letta import BasicBlockMemory
|
||||
from letta import offline_memory_agent
|
||||
from letta.client.client import Block, create_client
|
||||
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from letta.offline_memory_agent import (
|
||||
rethink_memory,
|
||||
finish_rethinking_memory,
|
||||
rethink_memory_convo,
|
||||
finish_rethinking_memory_convo,
|
||||
rethink_memory,
|
||||
rethink_memory_convo,
|
||||
trigger_rethink_memory,
|
||||
trigger_rethink_memory_convo,
|
||||
)
|
||||
from letta.prompts import gpt_system
|
||||
from letta.schemas.agent import AgentType
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.tool_rule import TerminalToolRule
|
||||
from letta.utils import get_human_text, get_persona_text
|
||||
|
||||
|
||||
def test_ripple_edit():
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
client = create_client()
|
||||
assert client is not None
|
||||
trigger_rethink_memory_tool = client.create_tool(trigger_rethink_memory)
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_agents(client):
|
||||
for agent in client.list_agents():
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
|
||||
def test_ripple_edit(client, mock_e2b_api_key_none):
|
||||
trigger_rethink_memory_tool = client.create_or_update_tool(trigger_rethink_memory)
|
||||
|
||||
conversation_human_block = Block(name="human", label="human", value=get_human_text(DEFAULT_HUMAN), limit=2000)
|
||||
conversation_persona_block = Block(name="persona", label="persona", value=get_persona_text(DEFAULT_PERSONA), limit=2000)
|
||||
@@ -61,12 +71,7 @@ def test_ripple_edit():
|
||||
)
|
||||
assert conversation_agent is not None
|
||||
|
||||
assert set(conversation_agent.memory.list_block_labels()) == set([
|
||||
"persona",
|
||||
"human",
|
||||
"fact_block",
|
||||
"rethink_memory_block",
|
||||
])
|
||||
assert set(conversation_agent.memory.list_block_labels()) == {"persona", "human", "fact_block", "rethink_memory_block"}
|
||||
|
||||
rethink_memory_tool = client.create_tool(rethink_memory)
|
||||
finish_rethinking_memory_tool = client.create_tool(finish_rethinking_memory)
|
||||
@@ -82,7 +87,7 @@ def test_ripple_edit():
|
||||
include_base_tools=False,
|
||||
)
|
||||
assert offline_memory_agent is not None
|
||||
assert set(offline_memory_agent.memory.list_block_labels())== set(["persona", "human", "fact_block", "rethink_memory_block"])
|
||||
assert set(offline_memory_agent.memory.list_block_labels()) == {"persona", "human", "fact_block", "rethink_memory_block"}
|
||||
response = client.user_message(
|
||||
agent_id=conversation_agent.id, message="[trigger_rethink_memory]: Messi has now moved to playing for Inter Miami"
|
||||
)
|
||||
@@ -92,12 +97,14 @@ def test_ripple_edit():
|
||||
conversation_agent = client.get_agent(agent_id=conversation_agent.id)
|
||||
assert conversation_agent.memory.get_block("rethink_memory_block").value != "[empty]"
|
||||
|
||||
# Clean up agent
|
||||
client.create_agent(conversation_agent.id)
|
||||
client.delete_agent(offline_memory_agent.id)
|
||||
|
||||
def test_chat_only_agent():
|
||||
client = create_client()
|
||||
|
||||
rethink_memory = client.create_tool(rethink_memory_convo)
|
||||
finish_rethinking_memory = client.create_tool(finish_rethinking_memory_convo)
|
||||
def test_chat_only_agent(client, mock_e2b_api_key_none):
|
||||
rethink_memory = client.create_or_update_tool(rethink_memory_convo)
|
||||
finish_rethinking_memory = client.create_or_update_tool(finish_rethinking_memory_convo)
|
||||
|
||||
conversation_human_block = Block(name="chat_agent_human", label="chat_agent_human", value=get_human_text(DEFAULT_HUMAN), limit=2000)
|
||||
conversation_persona_block = Block(
|
||||
@@ -114,10 +121,10 @@ def test_chat_only_agent():
|
||||
tools=["send_message"],
|
||||
memory=conversation_memory,
|
||||
include_base_tools=False,
|
||||
metadata = {"offline_memory_tools": [rethink_memory.name, finish_rethinking_memory.name]}
|
||||
metadata={"offline_memory_tools": [rethink_memory.name, finish_rethinking_memory.name]},
|
||||
)
|
||||
assert chat_only_agent is not None
|
||||
assert set(chat_only_agent.memory.list_block_labels()) == set(["chat_agent_persona", "chat_agent_human"])
|
||||
assert set(chat_only_agent.memory.list_block_labels()) == {"chat_agent_persona", "chat_agent_human"}
|
||||
|
||||
for message in ["hello", "my name is not chad, my name is swoodily"]:
|
||||
client.send_message(agent_id=chat_only_agent.id, message=message, role="user")
|
||||
@@ -125,3 +132,6 @@ def test_chat_only_agent():
|
||||
|
||||
chat_only_agent = client.get_agent(agent_id=chat_only_agent.id)
|
||||
assert chat_only_agent.memory.get_block("chat_agent_human").value != get_human_text(DEFAULT_HUMAN)
|
||||
|
||||
# Clean up agent
|
||||
client.delete_agent(chat_only_agent.id)
|
||||
|
||||
@@ -17,7 +17,6 @@ from letta.schemas.letta_message import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.user import User
|
||||
|
||||
from .test_managers import DEFAULT_EMBEDDING_CONFIG
|
||||
@@ -91,7 +90,7 @@ def agent_id(server, user_id):
|
||||
|
||||
def test_error_on_nonexistent_agent(server, user_id, agent_id):
|
||||
try:
|
||||
fake_agent_id = uuid.uuid4()
|
||||
fake_agent_id = str(uuid.uuid4())
|
||||
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
|
||||
raise Exception("user_message call should have failed")
|
||||
except (KeyError, ValueError) as e:
|
||||
@@ -388,7 +387,7 @@ def _test_get_messages_letta_format(
|
||||
agent_id,
|
||||
reverse=False,
|
||||
):
|
||||
"""Reverse is off by default, the GET goes in chronological order"""
|
||||
"""Test mapping between messages and letta_messages with reverse=False."""
|
||||
|
||||
messages = server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
@@ -397,7 +396,6 @@ def _test_get_messages_letta_format(
|
||||
reverse=reverse,
|
||||
return_message_object=True,
|
||||
)
|
||||
# messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
|
||||
letta_messages = server.get_agent_recall_cursor(
|
||||
@@ -407,140 +405,96 @@ def _test_get_messages_letta_format(
|
||||
reverse=reverse,
|
||||
return_message_object=False,
|
||||
)
|
||||
# letta_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False)
|
||||
assert all(isinstance(m, LettaMessage) for m in letta_messages)
|
||||
|
||||
# Loop through `messages` while also looping through `letta_messages`
|
||||
# Each message in `messages` should have 1+ corresponding messages in `letta_messages`
|
||||
# If role of message (in `messages`) is `assistant`,
|
||||
# then there should be two messages in `letta_messages`, one which is type InternalMonologue and one which is type FunctionCallMessage.
|
||||
# If role of message (in `messages`) is `user`, then there should be one message in `letta_messages` which is type UserMessage.
|
||||
# If role of message (in `messages`) is `system`, then there should be one message in `letta_messages` which is type SystemMessage.
|
||||
# If role of message (in `messages`) is `tool`, then there should be one message in `letta_messages` which is type FunctionReturn.
|
||||
|
||||
print("MESSAGES (obj):")
|
||||
for i, m in enumerate(messages):
|
||||
# print(m)
|
||||
print(f"{i}: {m.role}, {m.text[:50]}...")
|
||||
# print(m.role)
|
||||
|
||||
print("MEMGPT_MESSAGES:")
|
||||
for i, m in enumerate(letta_messages):
|
||||
print(f"{i}: {type(m)} ...{str(m)[-50:]}")
|
||||
|
||||
# Collect system messages and their texts
|
||||
system_messages = [m for m in messages if m.role == MessageRole.system]
|
||||
system_texts = [m.text for m in system_messages]
|
||||
|
||||
# If there are multiple system messages, print the diff
|
||||
if len(system_messages) > 1:
|
||||
print("Differences between system messages:")
|
||||
for i in range(len(system_texts) - 1):
|
||||
for j in range(i + 1, len(system_texts)):
|
||||
import difflib
|
||||
|
||||
diff = difflib.unified_diff(
|
||||
system_texts[i].splitlines(),
|
||||
system_texts[j].splitlines(),
|
||||
fromfile=f"System Message {i+1}",
|
||||
tofile=f"System Message {j+1}",
|
||||
lineterm="",
|
||||
)
|
||||
print("\n".join(diff))
|
||||
else:
|
||||
print("There is only one or no system message.")
|
||||
print(f"Messages: {len(messages)}, LettaMessages: {len(letta_messages)}")
|
||||
|
||||
letta_message_index = 0
|
||||
for i, message in enumerate(messages):
|
||||
assert isinstance(message, Message)
|
||||
|
||||
print(f"\n\nmessage {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
|
||||
# Defensive bounds check for letta_messages
|
||||
if letta_message_index >= len(letta_messages):
|
||||
print(f"Error: letta_message_index out of range. Expected more letta_messages for message {i}: {message.role}")
|
||||
raise ValueError(f"Mismatch in letta_messages length. Index: {letta_message_index}, Length: {len(letta_messages)}")
|
||||
|
||||
print(f"Processing message {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
|
||||
while letta_message_index < len(letta_messages):
|
||||
letta_message = letta_messages[letta_message_index]
|
||||
print(f"letta_message {letta_message_index}: {str(letta_message)[:50]}")
|
||||
|
||||
# Validate mappings for assistant role
|
||||
if message.role == MessageRole.assistant:
|
||||
print(f"i={i}, M=assistant, MM={type(letta_message)}")
|
||||
print(f"Assistant Message at {i}: {type(letta_message)}")
|
||||
|
||||
# If reverse, function call will come first
|
||||
if reverse:
|
||||
|
||||
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
|
||||
if message.tool_calls is not None:
|
||||
# Reverse handling: FunctionCallMessages come first
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
|
||||
# Try to parse the tool call args
|
||||
try:
|
||||
json.loads(tool_call.function.arguments)
|
||||
except:
|
||||
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}")
|
||||
assert isinstance(letta_message, FunctionCallMessage)
|
||||
letta_message_index += 1
|
||||
if letta_message_index >= len(letta_messages):
|
||||
break
|
||||
letta_message = letta_messages[letta_message_index]
|
||||
|
||||
if message.text is not None:
|
||||
if message.text:
|
||||
assert isinstance(letta_message, InternalMonologue)
|
||||
letta_message_index += 1
|
||||
letta_message = letta_messages[letta_message_index]
|
||||
else:
|
||||
# If there's no inner thoughts then there needs to be a tool call
|
||||
assert message.tool_calls is not None
|
||||
|
||||
else:
|
||||
|
||||
if message.text is not None:
|
||||
else: # Non-reverse handling
|
||||
if message.text:
|
||||
assert isinstance(letta_message, InternalMonologue)
|
||||
letta_message_index += 1
|
||||
if letta_message_index >= len(letta_messages):
|
||||
break
|
||||
letta_message = letta_messages[letta_message_index]
|
||||
else:
|
||||
# If there's no inner thoughts then there needs to be a tool call
|
||||
assert message.tool_calls is not None
|
||||
|
||||
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
|
||||
if message.tool_calls is not None:
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
|
||||
# Try to parse the tool call args
|
||||
try:
|
||||
json.loads(tool_call.function.arguments)
|
||||
except:
|
||||
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}")
|
||||
assert isinstance(letta_message, FunctionCallMessage)
|
||||
assert tool_call.function.name == letta_message.function_call.name
|
||||
assert tool_call.function.arguments == letta_message.function_call.arguments
|
||||
letta_message_index += 1
|
||||
if letta_message_index >= len(letta_messages):
|
||||
break
|
||||
letta_message = letta_messages[letta_message_index]
|
||||
|
||||
elif message.role == MessageRole.user:
|
||||
print(f"i={i}, M=user, MM={type(letta_message)}")
|
||||
assert isinstance(letta_message, UserMessage)
|
||||
assert message.text == letta_message.message
|
||||
letta_message_index += 1
|
||||
|
||||
elif message.role == MessageRole.system:
|
||||
print(f"i={i}, M=system, MM={type(letta_message)}")
|
||||
assert isinstance(letta_message, SystemMessage)
|
||||
assert message.text == letta_message.message
|
||||
letta_message_index += 1
|
||||
|
||||
elif message.role == MessageRole.tool:
|
||||
print(f"i={i}, M=tool, MM={type(letta_message)}")
|
||||
assert isinstance(letta_message, FunctionReturn)
|
||||
# Check the the value in `text` is the same
|
||||
assert message.text == letta_message.function_return
|
||||
letta_message_index += 1
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unexpected message role: {message.role}")
|
||||
|
||||
# Move to the next message in the original messages list
|
||||
break
|
||||
break # Exit the letta_messages loop after processing one mapping
|
||||
|
||||
if letta_message_index < len(letta_messages):
|
||||
warnings.warn(f"Extra letta_messages found: {len(letta_messages) - letta_message_index}")
|
||||
|
||||
|
||||
def test_get_messages_letta_format(server, user_id, agent_id):
|
||||
for reverse in [False, True]:
|
||||
# for reverse in [False, True]:
|
||||
for reverse in [False]:
|
||||
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
|
||||
|
||||
|
||||
@@ -586,7 +540,7 @@ def ingest(message: str):
|
||||
'''
|
||||
|
||||
|
||||
def test_tool_run(server, user_id, agent_id):
|
||||
def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id):
|
||||
"""Test that the server can run tools"""
|
||||
|
||||
result = server.run_tool_from_source(
|
||||
@@ -672,7 +626,7 @@ def test_composio_client_simple(server):
|
||||
assert len(actions) > 0
|
||||
|
||||
|
||||
def test_memory_rebuild_count(server, user_id):
|
||||
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none):
|
||||
"""Test that the memory rebuild is generating the correct number of role=system messages"""
|
||||
|
||||
# create agent
|
||||
@@ -712,7 +666,6 @@ def test_memory_rebuild_count(server, user_id):
|
||||
return len(system_messages), letta_messages
|
||||
|
||||
try:
|
||||
|
||||
# At this stage, there should only be 1 system message inside of recall storage
|
||||
num_system_messages, all_messages = count_system_messages_in_recall()
|
||||
# assert num_system_messages == 1, (num_system_messages, all_messages)
|
||||
|
||||
@@ -1,318 +0,0 @@
|
||||
# TODO: add back post DB refactor
|
||||
|
||||
# import os
|
||||
# import uuid
|
||||
# from datetime import datetime, timedelta
|
||||
#
|
||||
# import pytest
|
||||
# from sqlalchemy.ext.declarative import declarative_base
|
||||
#
|
||||
# from letta.agent_store.storage import StorageConnector, TableType
|
||||
# from letta.config import LettaConfig
|
||||
# from letta.constants import BASE_TOOLS, MAX_EMBEDDING_DIM
|
||||
# from letta.credentials import LettaCredentials
|
||||
# from letta.embeddings import embedding_model, query_embedding
|
||||
# from letta.metadata import MetadataStore
|
||||
# from letta.settings import settings
|
||||
# from tests import TEST_MEMGPT_CONFIG
|
||||
# from tests.utils import create_config, wipe_config
|
||||
#
|
||||
# from .utils import with_qdrant_storage
|
||||
#
|
||||
# from letta.schemas.agent import AgentState
|
||||
# from letta.schemas.message import Message
|
||||
# from letta.schemas.passage import Passage
|
||||
# from letta.schemas.user import User
|
||||
#
|
||||
#
|
||||
## Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
|
||||
# texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
# start_date = datetime(2009, 10, 5, 18, 00)
|
||||
# dates = [start_date, start_date - timedelta(weeks=1), start_date + timedelta(weeks=1)]
|
||||
# roles = ["user", "assistant", "assistant"]
|
||||
# agent_1_id = uuid.uuid4()
|
||||
# agent_2_id = uuid.uuid4()
|
||||
# agent_ids = [agent_1_id, agent_2_id, agent_1_id]
|
||||
# ids = [uuid.uuid4(), uuid.uuid4(), uuid.uuid4()]
|
||||
# user_id = uuid.uuid4()
|
||||
#
|
||||
#
|
||||
## Data generation functions: Passages
|
||||
# def generate_passages(embed_model):
|
||||
# """Generate list of 3 Passage objects"""
|
||||
# # embeddings: use openai if env is set, otherwise local
|
||||
# passages = []
|
||||
# for text, _, _, agent_id, id in zip(texts, dates, roles, agent_ids, ids):
|
||||
# embedding, embedding_model, embedding_dim = None, None, None
|
||||
# if embed_model:
|
||||
# embedding = embed_model.get_text_embedding(text)
|
||||
# embedding_model = "gpt-4"
|
||||
# embedding_dim = len(embedding)
|
||||
# passages.append(
|
||||
# Passage(
|
||||
# user_id=user_id,
|
||||
# text=text,
|
||||
# agent_id=agent_id,
|
||||
# embedding=embedding,
|
||||
# data_source="test_source",
|
||||
# id=id,
|
||||
# embedding_dim=embedding_dim,
|
||||
# embedding_model=embedding_model,
|
||||
# )
|
||||
# )
|
||||
# return passages
|
||||
#
|
||||
#
|
||||
## Data generation functions: Messages
|
||||
# def generate_messages(embed_model):
|
||||
# """Generate list of 3 Message objects"""
|
||||
# messages = []
|
||||
# for text, date, role, agent_id, id in zip(texts, dates, roles, agent_ids, ids):
|
||||
# embedding, embedding_model, embedding_dim = None, None, None
|
||||
# if embed_model:
|
||||
# embedding = embed_model.get_text_embedding(text)
|
||||
# embedding_model = "gpt-4"
|
||||
# embedding_dim = len(embedding)
|
||||
# messages.append(
|
||||
# Message(
|
||||
# user_id=user_id,
|
||||
# text=text,
|
||||
# agent_id=agent_id,
|
||||
# role=role,
|
||||
# created_at=date,
|
||||
# id=id,
|
||||
# model="gpt-4",
|
||||
# embedding=embedding,
|
||||
# embedding_model=embedding_model,
|
||||
# embedding_dim=embedding_dim,
|
||||
# )
|
||||
# )
|
||||
# print(messages[-1].text)
|
||||
# return messages
|
||||
#
|
||||
#
|
||||
# @pytest.fixture(autouse=True)
|
||||
# def clear_dynamically_created_models():
|
||||
# """Wipe globals for SQLAlchemy"""
|
||||
# yield
|
||||
# for key in list(globals().keys()):
|
||||
# if key.endswith("Model"):
|
||||
# del globals()[key]
|
||||
#
|
||||
#
|
||||
# @pytest.fixture(autouse=True)
|
||||
# def recreate_declarative_base():
|
||||
# """Recreate the declarative base before each test"""
|
||||
# global Base
|
||||
# Base = declarative_base()
|
||||
# yield
|
||||
# Base.metadata.clear()
|
||||
#
|
||||
#
|
||||
# @pytest.mark.parametrize("storage_connector", with_qdrant_storage(["postgres", "chroma", "sqlite", "milvus"]))
|
||||
## @pytest.mark.parametrize("storage_connector", ["sqlite", "chroma"])
|
||||
## @pytest.mark.parametrize("storage_connector", ["postgres"])
|
||||
# @pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
|
||||
# def test_storage(
|
||||
# storage_connector,
|
||||
# table_type,
|
||||
# clear_dynamically_created_models,
|
||||
# recreate_declarative_base,
|
||||
# ):
|
||||
# # setup letta config
|
||||
# # TODO: set env for different config path
|
||||
#
|
||||
# # hacky way to cleanup globals that scruw up tests
|
||||
# # for table_name in ['Message']:
|
||||
# # if 'Message' in globals():
|
||||
# # print("Removing messages", globals()['Message'])
|
||||
# # del globals()['Message']
|
||||
#
|
||||
# wipe_config()
|
||||
# if os.getenv("OPENAI_API_KEY"):
|
||||
# create_config("openai")
|
||||
# credentials = LettaCredentials(
|
||||
# openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
# )
|
||||
# else: # hosted
|
||||
# create_config("letta_hosted")
|
||||
# LettaCredentials()
|
||||
#
|
||||
# config = LettaConfig.load()
|
||||
# TEST_MEMGPT_CONFIG.default_embedding_config = config.default_embedding_config
|
||||
# TEST_MEMGPT_CONFIG.default_llm_config = config.default_llm_config
|
||||
#
|
||||
# if storage_connector == "postgres":
|
||||
# TEST_MEMGPT_CONFIG.archival_storage_uri = settings.letta_pg_uri
|
||||
# TEST_MEMGPT_CONFIG.recall_storage_uri = settings.letta_pg_uri
|
||||
# TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
|
||||
# TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
|
||||
# if storage_connector == "lancedb":
|
||||
# # TODO: complete lancedb implementation
|
||||
# if not os.getenv("LANCEDB_TEST_URL"):
|
||||
# print("Skipping test, missing LanceDB URI")
|
||||
# return
|
||||
# TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["LANCEDB_TEST_URL"]
|
||||
# TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["LANCEDB_TEST_URL"]
|
||||
# TEST_MEMGPT_CONFIG.archival_storage_type = "lancedb"
|
||||
# TEST_MEMGPT_CONFIG.recall_storage_type = "lancedb"
|
||||
# if storage_connector == "chroma":
|
||||
# if table_type == TableType.RECALL_MEMORY:
|
||||
# print("Skipping test, chroma only supported for archival memory")
|
||||
# return
|
||||
# TEST_MEMGPT_CONFIG.archival_storage_type = "chroma"
|
||||
# TEST_MEMGPT_CONFIG.archival_storage_path = "./test_chroma"
|
||||
# if storage_connector == "sqlite":
|
||||
# if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
# print("Skipping test, sqlite only supported for recall memory")
|
||||
# return
|
||||
# TEST_MEMGPT_CONFIG.recall_storage_type = "sqlite"
|
||||
# if storage_connector == "qdrant":
|
||||
# if table_type == TableType.RECALL_MEMORY:
|
||||
# print("Skipping test, Qdrant only supports archival memory")
|
||||
# return
|
||||
# TEST_MEMGPT_CONFIG.archival_storage_type = "qdrant"
|
||||
# TEST_MEMGPT_CONFIG.archival_storage_uri = "localhost:6333"
|
||||
# if storage_connector == "milvus":
|
||||
# if table_type == TableType.RECALL_MEMORY:
|
||||
# print("Skipping test, Milvus only supports archival memory")
|
||||
# return
|
||||
# TEST_MEMGPT_CONFIG.archival_storage_type = "milvus"
|
||||
# TEST_MEMGPT_CONFIG.archival_storage_uri = "./milvus.db"
|
||||
# # get embedding model
|
||||
# embedding_config = TEST_MEMGPT_CONFIG.default_embedding_config
|
||||
# embed_model = embedding_model(TEST_MEMGPT_CONFIG.default_embedding_config)
|
||||
#
|
||||
# # create user
|
||||
# ms = MetadataStore(TEST_MEMGPT_CONFIG)
|
||||
# ms.delete_user(user_id)
|
||||
# user = User(id=user_id)
|
||||
# agent = AgentState(
|
||||
# user_id=user_id,
|
||||
# name="agent_1",
|
||||
# id=agent_1_id,
|
||||
# llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
# embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
# system="",
|
||||
# tools=BASE_TOOLS,
|
||||
# state={
|
||||
# "persona": "",
|
||||
# "human": "",
|
||||
# "messages": None,
|
||||
# },
|
||||
# )
|
||||
# ms.create_user(user)
|
||||
# ms.create_agent(agent)
|
||||
#
|
||||
# # create storage connector
|
||||
# conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id)
|
||||
# # conn.client.delete_collection(conn.collection.name) # clear out data
|
||||
# conn.delete_table()
|
||||
# conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id)
|
||||
#
|
||||
# # generate data
|
||||
# if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
# records = generate_passages(embed_model)
|
||||
# elif table_type == TableType.RECALL_MEMORY:
|
||||
# records = generate_messages(embed_model)
|
||||
# else:
|
||||
# raise NotImplementedError(f"Table type {table_type} not implemented")
|
||||
#
|
||||
# # check record dimentions
|
||||
# print("TABLE TYPE", table_type, type(records[0]), len(records[0].embedding))
|
||||
# if embed_model:
|
||||
# assert len(records[0].embedding) == MAX_EMBEDDING_DIM, f"Expected {MAX_EMBEDDING_DIM}, got {len(records[0].embedding)}"
|
||||
# assert (
|
||||
# records[0].embedding_dim == embedding_config.embedding_dim
|
||||
# ), f"Expected {embedding_config.embedding_dim}, got {records[0].embedding_dim}"
|
||||
#
|
||||
# # test: insert
|
||||
# conn.insert(records[0])
|
||||
# assert conn.size() == 1, f"Expected 1 record, got {conn.size()}: {conn.get_all()}"
|
||||
#
|
||||
# # test: insert_many
|
||||
# conn.insert_many(records[1:])
|
||||
# assert (
|
||||
# conn.size() == 2
|
||||
# ), f"Expected 2 records, got {conn.size()}: {conn.get_all()}" # expect 2, since storage connector filters for agent1
|
||||
#
|
||||
# # test: update
|
||||
# # NOTE: only testing with messages
|
||||
# if table_type == TableType.RECALL_MEMORY:
|
||||
# TEST_STRING = "hello world"
|
||||
#
|
||||
# updated_record = records[1]
|
||||
# updated_record.text = TEST_STRING
|
||||
#
|
||||
# current_record = conn.get(id=updated_record.id)
|
||||
# assert current_record is not None, f"Couldn't find {updated_record.id}"
|
||||
# assert current_record.text != TEST_STRING, (current_record.text, TEST_STRING)
|
||||
#
|
||||
# conn.update(updated_record)
|
||||
# new_record = conn.get(id=updated_record.id)
|
||||
# assert new_record is not None, f"Couldn't find {updated_record.id}"
|
||||
# assert new_record.text == TEST_STRING, (new_record.text, TEST_STRING)
|
||||
#
|
||||
# # test: list_loaded_data
|
||||
# # TODO: add back
|
||||
# # if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
# # sources = StorageConnector.list_loaded_data(storage_type=storage_connector)
|
||||
# # assert len(sources) == 1, f"Expected 1 source, got {len(sources)}"
|
||||
# # assert sources[0] == "test_source", f"Expected 'test_source', got {sources[0]}"
|
||||
#
|
||||
# # test: get_all_paginated
|
||||
# paginated_total = 0
|
||||
# for page in conn.get_all_paginated(page_size=1):
|
||||
# paginated_total += len(page)
|
||||
# assert paginated_total == 2, f"Expected 2 records, got {paginated_total}"
|
||||
#
|
||||
# # test: get_all
|
||||
# all_records = conn.get_all()
|
||||
# assert len(all_records) == 2, f"Expected 2 records, got {len(all_records)}"
|
||||
# all_records = conn.get_all(limit=1)
|
||||
# assert len(all_records) == 1, f"Expected 1 records, got {len(all_records)}"
|
||||
#
|
||||
# # test: get
|
||||
# print("GET ID", ids[0], records)
|
||||
# res = conn.get(id=ids[0])
|
||||
# assert res.text == texts[0], f"Expected {texts[0]}, got {res.text}"
|
||||
#
|
||||
# # test: size
|
||||
# assert conn.size() == 2, f"Expected 2 records, got {conn.size()}"
|
||||
# assert conn.size(filters={"agent_id": agent.id}) == 2, f"Expected 2 records, got {conn.size(filters={'agent_id', agent.id})}"
|
||||
# if table_type == TableType.RECALL_MEMORY:
|
||||
# assert conn.size(filters={"role": "user"}) == 1, f"Expected 1 record, got {conn.size(filters={'role': 'user'})}"
|
||||
#
|
||||
# # test: query (vector)
|
||||
# if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
# query = "why was she crying"
|
||||
# query_vec = query_embedding(embed_model, query)
|
||||
# res = conn.query(None, query_vec, top_k=2)
|
||||
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
# print("Archival memory results", res)
|
||||
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
#
|
||||
# # test optional query functions for recall memory
|
||||
# if table_type == TableType.RECALL_MEMORY:
|
||||
# # test: query_text
|
||||
# query = "CindereLLa"
|
||||
# res = conn.query_text(query)
|
||||
# assert len(res) == 1, f"Expected 1 result, got {len(res)}"
|
||||
# assert "Cinderella" in res[0].text, f"Expected 'Cinderella' in results, but got {res[0].text}"
|
||||
#
|
||||
# # test: query_date (recall memory only)
|
||||
# print("Testing recall memory date search")
|
||||
# start_date = datetime(2009, 10, 5, 18, 00)
|
||||
# start_date = start_date - timedelta(days=1)
|
||||
# end_date = start_date + timedelta(days=1)
|
||||
# res = conn.query_date(start_date=start_date, end_date=end_date)
|
||||
# print("DATE", res)
|
||||
# assert len(res) == 1, f"Expected 1 result, got {len(res)}: {res}"
|
||||
#
|
||||
# # test: delete
|
||||
# conn.delete({"id": ids[0]})
|
||||
# assert conn.size() == 1, f"Expected 2 records, got {conn.size()}"
|
||||
#
|
||||
# # cleanup
|
||||
# ms.delete_user(user_id)
|
||||
#
|
||||
@@ -1,14 +1,11 @@
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from letta import create_client
|
||||
from letta.client.client import LocalClient
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.settings import tool_settings
|
||||
|
||||
from .utils import wipe_config
|
||||
|
||||
@@ -21,21 +18,6 @@ agent_obj = None
|
||||
# TODO: these tests should add function calls into the summarized message sequence:W
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_e2b_api_key_none():
|
||||
# Store the original value of e2b_api_key
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
|
||||
# Set e2b_api_key to None
|
||||
tool_settings.e2b_api_key = None
|
||||
|
||||
# Yield control to the test
|
||||
yield
|
||||
|
||||
# Restore the original value of e2b_api_key
|
||||
tool_settings.e2b_api_key = original_api_key
|
||||
|
||||
|
||||
def create_test_agent():
|
||||
"""Create a test agent that we can call functions on"""
|
||||
wipe_config()
|
||||
|
||||
Reference in New Issue
Block a user