feat: message orm migration (#2144)

Co-authored-by: Mindy Long <mindy@letta.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
mlong93
2024-12-06 11:50:15 -08:00
committed by GitHub
parent 370a0e68dd
commit 6c2c7231ab
45 changed files with 984 additions and 1265 deletions

View File

@@ -31,6 +31,11 @@ jobs:
# Check each changed Python file
while IFS= read -r file; do
if [ "$file" == "letta/main.py" ]; then
echo "Skipping $file for print statement checks."
continue
fi
if [ -f "$file" ]; then
echo "Checking $file for new print statements..."

View File

@@ -31,6 +31,7 @@ jobs:
- "test_utils.py"
- "test_tool_schema_parsing.py"
- "test_v1_routes.py"
- "test_offline_memory_agent.py"
services:
qdrant:
image: qdrant/qdrant
@@ -133,4 +134,4 @@ jobs:
LETTA_SERVER_PASS: test_server_token
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
run: |
poetry run pytest -s -vv -k "not test_v1_routes.py and not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not integration_test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests
poetry run pytest -s -vv -k "not test_offline_memory_agent.py and not test_v1_routes.py and not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not integration_test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests

View File

@@ -0,0 +1,63 @@
"""Migrate message to orm
Revision ID: 95badb46fdf9
Revises: 3c683a662c82
Create Date: 2024-12-05 14:02:04.163150
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "95badb46fdf9"
down_revision: Union[str, None] = "08b2f8225812"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Move the ``messages`` table onto the shared ORM mixins.

    Adds the audit columns (``updated_at``, ``is_deleted``,
    ``_created_by_id``, ``_last_updated_by_id``), replaces the per-row
    ``user_id`` with a non-nullable ``organization_id`` (backfilled through
    ``users``), tightens nullability on ``tool_calls`` / ``created_at``, and
    swaps the old user-scoped index for foreign keys to ``agents`` and
    ``organizations``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("messages", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("messages", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
    op.add_column("messages", sa.Column("_created_by_id", sa.String(), nullable=True))
    op.add_column("messages", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
    op.add_column("messages", sa.Column("organization_id", sa.String(), nullable=True))

    # Populate `organization_id` based on `user_id` before making it NOT NULL.
    # NOTE(review): any message whose `user_id` has no matching `users.id`
    # keeps a NULL organization_id and will make the
    # `alter_column(..., nullable=False)` below fail — verify there are no
    # orphaned messages before running this in production.
    op.execute(
        """
        UPDATE messages
        SET organization_id = users.organization_id
        FROM users
        WHERE messages.user_id = users.id
        """
    )

    op.alter_column("messages", "organization_id", nullable=False)
    op.alter_column("messages", "tool_calls", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
    op.alter_column("messages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
    op.drop_index("message_idx_user", table_name="messages")
    # Name the foreign keys explicitly (matching PostgreSQL's default
    # "<table>_<column>_fkey" pattern, so the resulting schema is unchanged)
    # instead of passing None: downgrade() must drop these constraints by
    # name, and op.drop_constraint(None, ...) is rejected by Alembic.
    op.create_foreign_key("messages_agent_id_fkey", "messages", "agents", ["agent_id"], ["id"])
    op.create_foreign_key("messages_organization_id_fkey", "messages", "organizations", ["organization_id"], ["id"])
    op.drop_column("messages", "user_id")
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert the messages ORM migration.

    Restores ``user_id``, the ``message_idx_user`` index, and the looser
    nullability, and drops the ORM audit columns added by :func:`upgrade`.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): re-adding `user_id` as NOT NULL with no backfill will
    # fail on any non-empty table, and the original user ids are lost once
    # upgrade() drops the column — this downgrade is only safe on an empty
    # (or freshly truncated) messages table.
    op.add_column("messages", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
    # Bug fix: the autogenerated code passed None as the constraint name,
    # which op.drop_constraint() rejects outright.  Drop the two foreign
    # keys by PostgreSQL's default "<table>_<column>_fkey" names, which is
    # what the upgrade-created (unnamed) constraints are called.
    op.drop_constraint("messages_organization_id_fkey", "messages", type_="foreignkey")
    op.drop_constraint("messages_agent_id_fkey", "messages", type_="foreignkey")
    op.create_index("message_idx_user", "messages", ["user_id", "agent_id"], unique=False)
    op.alter_column("messages", "created_at", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=True)
    op.alter_column("messages", "tool_calls", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
    op.drop_column("messages", "organization_id")
    op.drop_column("messages", "_last_updated_by_id")
    op.drop_column("messages", "_created_by_id")
    op.drop_column("messages", "is_deleted")
    op.drop_column("messages", "updated_at")
    # ### end Alembic commands ###

View File

@@ -97,7 +97,7 @@ def upgrade() -> None:
sa.Column("text", sa.String(), nullable=True),
sa.Column("model", sa.String(), nullable=True),
sa.Column("name", sa.String(), nullable=True),
sa.Column("tool_calls", letta.metadata.ToolCallColumn(), nullable=True),
sa.Column("tool_calls", letta.orm.message.ToolCallColumn(), nullable=True),
sa.Column("tool_call_id", sa.String(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),

View File

@@ -19,6 +19,7 @@ from letta.constants import (
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_WARNING_FRAC,
O1_BASE_TOOLS,
REQ_HEARTBEAT_MESSAGE,
)
from letta.errors import LLMError
@@ -27,10 +28,9 @@ from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
from letta.llm_api.llm_api_tools import create
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
from letta.memory import ArchivalMemory, EmbeddingArchivalMemory, summarize_messages
from letta.metadata import MetadataStore
from letta.orm import User
from letta.persistence_manager import LocalStateManager
from letta.schemas.agent import AgentState, AgentStepResponse
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
@@ -49,7 +49,9 @@ from letta.schemas.passage import Passage
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.user_manager import UserManager
@@ -80,9 +82,11 @@ from letta.utils import (
def compile_memory_metadata_block(
actor: PydanticUser,
agent_id: str,
memory_edit_timestamp: datetime.datetime,
archival_memory: Optional[ArchivalMemory] = None,
recall_memory: Optional[RecallMemory] = None,
message_manager: Optional[MessageManager] = None,
) -> str:
# Put the timestamp in the local timezone (mimicking get_local_time())
timestamp_str = memory_edit_timestamp.astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip()
@@ -91,7 +95,7 @@ def compile_memory_metadata_block(
memory_metadata_block = "\n".join(
[
f"### Memory [last modified: {timestamp_str}]",
f"{recall_memory.count() if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
f"{archival_memory.count() if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
]
@@ -101,10 +105,12 @@ def compile_memory_metadata_block(
def compile_system_message(
system_prompt: str,
agent_id: str,
in_context_memory: Memory,
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
actor: PydanticUser,
archival_memory: Optional[ArchivalMemory] = None,
recall_memory: Optional[RecallMemory] = None,
message_manager: Optional[MessageManager] = None,
user_defined_variables: Optional[dict] = None,
append_icm_if_missing: bool = True,
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
@@ -129,9 +135,11 @@ def compile_system_message(
else:
# TODO should this all put into the memory.__repr__ function?
memory_metadata_string = compile_memory_metadata_block(
actor=actor,
agent_id=agent_id,
memory_edit_timestamp=in_context_memory_last_edit,
archival_memory=archival_memory,
recall_memory=recall_memory,
message_manager=message_manager,
)
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
@@ -164,9 +172,11 @@ def compile_system_message(
def initialize_message_sequence(
model: str,
system: str,
agent_id: str,
memory: Memory,
actor: PydanticUser,
archival_memory: Optional[ArchivalMemory] = None,
recall_memory: Optional[RecallMemory] = None,
message_manager: Optional[MessageManager] = None,
memory_edit_timestamp: Optional[datetime.datetime] = None,
include_initial_boot_message: bool = True,
) -> List[dict]:
@@ -177,11 +187,13 @@ def initialize_message_sequence(
# system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory
# )
full_system_message = compile_system_message(
agent_id=agent_id,
system_prompt=system,
in_context_memory=memory,
in_context_memory_last_edit=memory_edit_timestamp,
actor=actor,
archival_memory=archival_memory,
recall_memory=recall_memory,
message_manager=message_manager,
user_defined_variables=None,
append_icm_if_missing=True,
)
@@ -282,7 +294,8 @@ class Agent(BaseAgent):
self.interface = interface
# Create the persistence manager object based on the AgentState info
self.persistence_manager = LocalStateManager(agent_state=self.agent_state)
self.archival_memory = EmbeddingArchivalMemory(agent_state)
self.message_manager = MessageManager()
# State needed for heartbeat pausing
self.pause_heartbeats_start = None
@@ -309,9 +322,11 @@ class Agent(BaseAgent):
init_messages = initialize_message_sequence(
model=self.model,
system=self.agent_state.system,
agent_id=self.agent_state.id,
memory=self.agent_state.memory,
actor=self.user,
archival_memory=None,
recall_memory=None,
message_manager=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
)
@@ -333,8 +348,10 @@ class Agent(BaseAgent):
model=self.model,
system=self.agent_state.system,
memory=self.agent_state.memory,
agent_id=self.agent_state.id,
actor=self.user,
archival_memory=None,
recall_memory=None,
message_manager=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
)
@@ -356,7 +373,6 @@ class Agent(BaseAgent):
# Put the messages inside the message buffer
self.messages_total = 0
# self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None])
self._append_to_messages(added_messages=init_messages_objs)
self._validate_message_buffer_is_utc()
@@ -413,7 +429,10 @@ class Agent(BaseAgent):
# TODO: need to have an AgentState object that actually has full access to the block data
# this is because the sandbox tools need to be able to access block.value to edit this data
try:
if function_name in BASE_TOOLS:
# TODO: This is NO BUENO
# TODO: Matching purely by names is extremely problematic, users can create tools with these names and run them in the agent loop
# TODO: We will probably have to match the function strings exactly for safety
if function_name in BASE_TOOLS or function_name in O1_BASE_TOOLS:
# base tools are allowed to access the `Agent` object and run on the database
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = function_to_call(**function_args)
@@ -474,7 +493,7 @@ class Agent(BaseAgent):
# Pull the message objects from the database
message_objs = []
for msg_id in message_ids:
msg_obj = self.persistence_manager.recall_memory.storage.get(msg_id)
msg_obj = self.message_manager.get_message_by_id(msg_id, actor=self.user)
if msg_obj:
if isinstance(msg_obj, Message):
message_objs.append(msg_obj)
@@ -522,16 +541,13 @@ class Agent(BaseAgent):
def _trim_messages(self, num):
"""Trim messages from the front, not including the system message"""
self.persistence_manager.trim_messages(num)
new_messages = [self._messages[0]] + self._messages[num:]
self._messages = new_messages
def _prepend_to_messages(self, added_messages: List[Message]):
"""Wrapper around self.messages.prepend to allow additional calls to a state/persistence manager"""
assert all([isinstance(msg, Message) for msg in added_messages])
self.persistence_manager.prepend_to_messages(added_messages)
self.message_manager.create_many_messages(added_messages, actor=self.user)
new_messages = [self._messages[0]] + added_messages + self._messages[1:] # prepend (no system)
self._messages = new_messages
@@ -540,8 +556,7 @@ class Agent(BaseAgent):
def _append_to_messages(self, added_messages: List[Message]):
"""Wrapper around self.messages.append to allow additional calls to a state/persistence manager"""
assert all([isinstance(msg, Message) for msg in added_messages])
self.persistence_manager.append_to_messages(added_messages)
self.message_manager.create_many_messages(added_messages, actor=self.user)
# strip extra metadata if it exists
# for msg in added_messages:
@@ -885,7 +900,6 @@ class Agent(BaseAgent):
messages=next_input_message,
**kwargs,
)
step_response.messages
heartbeat_request = step_response.heartbeat_request
function_failed = step_response.function_failed
token_warning = step_response.in_context_memory_warning
@@ -1247,7 +1261,7 @@ class Agent(BaseAgent):
assert new_system_message_obj.role == "system", new_system_message_obj
assert self._messages[0].role == "system", self._messages
self.persistence_manager.swap_system_message(new_system_message_obj)
self.message_manager.create_message(new_system_message_obj, actor=self.user)
new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system)
self._messages = new_messages
@@ -1280,11 +1294,13 @@ class Agent(BaseAgent):
# update memory (TODO: potentially update recall/archival stats separately)
new_system_message_str = compile_system_message(
agent_id=self.agent_state.id,
system_prompt=self.agent_state.system,
in_context_memory=self.agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
archival_memory=self.persistence_manager.archival_memory,
recall_memory=self.persistence_manager.recall_memory,
actor=self.user,
archival_memory=self.archival_memory,
message_manager=self.message_manager,
user_defined_variables=None,
append_icm_if_missing=True,
)
@@ -1370,20 +1386,20 @@ class Agent(BaseAgent):
# passage.id = create_uuid_from_string(f"{source_id}_{str(passage.agent_id)}_{passage.text}")
# insert into agent archival memory
self.persistence_manager.archival_memory.storage.insert_many(passages)
self.archival_memory.storage.insert_many(passages)
all_passages += passages
assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}"
# save destination storage
self.persistence_manager.archival_memory.storage.save()
self.archival_memory.storage.save()
# attach to agent
source = SourceManager().get_source_by_id(source_id=source_id, actor=user)
assert source is not None, f"Source {source_id} not found in metadata store"
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
total_agent_passages = self.persistence_manager.archival_memory.storage.size()
total_agent_passages = self.archival_memory.storage.size()
printd(
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
@@ -1392,7 +1408,7 @@ class Agent(BaseAgent):
def update_message(self, request: UpdateMessage) -> Message:
"""Update the details of a message associated with an agent"""
message = self.persistence_manager.recall_memory.storage.get(id=request.id)
message = self.message_manager.get_message_by_id(message_id=request.id, actor=self.user)
if message is None:
raise ValueError(f"Message with id {request.id} not found")
assert isinstance(message, Message), f"Message is not a Message object: {type(message)}"
@@ -1413,10 +1429,10 @@ class Agent(BaseAgent):
message.tool_call_id = request.tool_call_id
# Save the updated message
self.persistence_manager.recall_memory.storage.update(record=message)
self.message_manager.update_message_by_id(message_id=message.id, message=message, actor=self.user)
# Return the updated message
updated_message = self.persistence_manager.recall_memory.storage.get(id=message.id)
updated_message = self.message_manager.get_message_by_id(message_id=message.id, actor=self.user)
if updated_message is None:
raise ValueError(f"Error persisting message - message with id {request.id} not found")
return updated_message
@@ -1496,7 +1512,7 @@ class Agent(BaseAgent):
deleted_message = self._messages.pop()
# then also remove it from recall storage
try:
self.persistence_manager.recall_memory.storage.delete(filters={"id": deleted_message.id})
self.message_manager.delete_message_by_id(deleted_message.id, actor=self.user)
popped_messages.append(deleted_message)
except Exception as e:
warnings.warn(f"Error deleting message {deleted_message.id} from recall memory: {e}")
@@ -1522,7 +1538,6 @@ class Agent(BaseAgent):
def retry_message(self) -> List[Message]:
"""Retry / regenerate the last message"""
self.pop_until_user()
user_message = self.pop_message(count=1)[0]
assert user_message.text is not None, "User message text is None"
@@ -1569,12 +1584,14 @@ class Agent(BaseAgent):
num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
)
num_archival_memory = self.persistence_manager.archival_memory.storage.size()
num_recall_memory = self.persistence_manager.recall_memory.storage.size()
num_archival_memory = self.archival_memory.storage.size()
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
external_memory_summary = compile_memory_metadata_block(
actor=self.user,
agent_id=self.agent_state.id,
memory_edit_timestamp=get_utc_time(), # dummy timestamp
archival_memory=self.persistence_manager.archival_memory,
recall_memory=self.persistence_manager.recall_memory,
archival_memory=self.archival_memory,
message_manager=self.message_manager,
)
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
@@ -1600,7 +1617,7 @@ class Agent(BaseAgent):
# context window breakdown (in messages)
num_messages=len(self._messages),
num_archival_memory=num_archival_memory,
num_recall_memory=num_recall_memory,
num_recall_memory=message_manager_size,
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
# top-level information
context_window_size_max=self.agent_state.llm_config.context_window,

View File

@@ -27,13 +27,11 @@ from tqdm import tqdm
from letta.agent_store.storage import StorageConnector, TableType
from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
from letta.metadata import EmbeddingConfigColumn
from letta.orm.base import Base
from letta.orm.file import FileMetadata as FileMetadataModel
# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
from letta.schemas.message import Message
from letta.schemas.openai.chat_completions import ToolCall
from letta.schemas.passage import Passage
from letta.settings import settings
@@ -69,69 +67,6 @@ class CommonVector(TypeDecorator):
return np.frombuffer(value, dtype=np.float32)
class MessageModel(Base):
    """Defines data model for storing Message objects

    Legacy SQLAlchemy row model for the ``messages`` table; ``to_record``
    converts a row back into the pydantic ``Message`` schema object.
    """

    __tablename__ = "messages"
    __table_args__ = {"extend_existing": True}

    # Assuming message_id is the primary key
    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    agent_id = Column(String, nullable=False)

    # openai info
    role = Column(String, nullable=False)
    text = Column(String)  # optional: can be null if function call
    model = Column(String)  # optional: can be null if LLM backend doesn't require specifying
    name = Column(String)  # optional: multi-agent only

    # tool call request info
    # if role == "assistant", this MAY be specified
    # if role != "assistant", this must be null
    # TODO align with OpenAI spec of multiple tool calls
    # tool_calls = Column(ToolCallColumn)
    tool_calls = Column(ToolCallColumn)

    # tool call response info
    # if role == "tool", then this must be specified
    # if role != "tool", this must be null
    tool_call_id = Column(String)

    # Add a datetime column, with default value as the current time
    created_at = Column(DateTime(timezone=True))

    # NOTE(review): the trailing comma makes this a one-element tuple
    # expression evaluated at class-definition time; the Index object is
    # still created (and bound to the table through its columns), but the
    # tuple itself is discarded — confirm the comma is intentional.
    Index("message_idx_user", user_id, agent_id),

    def __repr__(self):
        return f"<Message(message_id='{self.id}', text='{self.text}')>"

    def to_record(self):
        # Convert this ORM row to a pydantic `Message`; tool_calls are
        # sanity-checked to already be ToolCall objects (deserialized by
        # ToolCallColumn) before being passed through unchanged.
        # calls = (
        #     [ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls]
        #     if self.tool_calls
        #     else None
        # )
        # if calls:
        #     assert isinstance(calls[0], ToolCall)
        if self.tool_calls and len(self.tool_calls) > 0:
            assert isinstance(self.tool_calls[0], ToolCall), type(self.tool_calls[0])
            for tool in self.tool_calls:
                assert isinstance(tool, ToolCall), type(tool)
        return Message(
            user_id=self.user_id,
            agent_id=self.agent_id,
            role=self.role,
            name=self.name,
            text=self.text,
            model=self.model,
            # tool_calls=[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls] if self.tool_calls else None,
            tool_calls=self.tool_calls,
            tool_call_id=self.tool_call_id,
            created_at=self.created_at,
            id=self.id,
        )
class PassageModel(Base):
"""Defines data model for storing Passages (consisting of text, embedding)"""
@@ -367,11 +302,6 @@ class PostgresStorageConnector(SQLStorageConnector):
self.db_model = PassageModel
if self.config.archival_storage_uri is None:
raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
elif table_type == TableType.RECALL_MEMORY:
self.uri = self.config.recall_storage_uri
self.db_model = MessageModel
if self.config.recall_storage_uri is None:
raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
elif table_type == TableType.FILES:
self.uri = self.config.metadata_storage_uri
self.db_model = FileMetadataModel
@@ -490,12 +420,6 @@ class SQLLiteStorageConnector(SQLStorageConnector):
# get storage URI
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
raise ValueError(f"Table type {table_type} not implemented")
elif table_type == TableType.RECALL_MEMORY:
# TODO: eventually implement URI option
self.path = self.config.recall_storage_path
if self.path is None:
raise ValueError(f"Must specify recall_storage_path in config.")
self.db_model = MessageModel
elif table_type == TableType.FILES:
self.path = self.config.metadata_storage_path
if self.path is None:

View File

@@ -1,177 +0,0 @@
# type: ignore
import uuid
from datetime import datetime
from typing import Dict, Iterator, List, Optional
from lancedb.pydantic import LanceModel, Vector
from letta.agent_store.storage import StorageConnector, TableType
from letta.config import AgentConfig, LettaConfig
from letta.schemas.message import Message, Passage, Record
""" Initial implementation - not complete """
def get_db_model(table_name: str, table_type: TableType):
    """Create database model for table_name

    Returns a LanceDB ``LanceModel`` subclass for the requested table type:
    a ``PassageModel`` for archival memory / passages, or a ``MessageModel``
    for recall memory.  Raises ``ValueError`` for unsupported table types.
    """
    config = LettaConfig.load()

    if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
        # create schema for archival memory
        class PassageModel(LanceModel):
            """Defines data model for storing Passages (consisting of text, embedding)"""

            id: uuid.UUID
            user_id: str
            text: str
            file_id: str
            agent_id: str
            data_source: str
            embedding: Vector(config.default_embedding_config.embedding_dim)
            metadata_: Dict

            def __repr__(self):
                return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"

            def to_record(self):
                # Convert the LanceDB row into the `Passage` schema object.
                return Passage(
                    text=self.text,
                    embedding=self.embedding,
                    file_id=self.file_id,
                    user_id=self.user_id,
                    id=self.id,
                    data_source=self.data_source,
                    agent_id=self.agent_id,
                    metadata=self.metadata_,
                )

        return PassageModel
    elif table_type == TableType.RECALL_MEMORY:

        class MessageModel(LanceModel):
            """Defines data model for storing Message objects"""

            __abstract__ = True  # this line is necessary

            # Assuming message_id is the primary key
            id: uuid.UUID
            user_id: str
            agent_id: str

            # openai info
            role: str
            name: str
            text: str
            model: str
            user: str

            # function info
            function_name: str
            function_args: str
            function_response: str

            # NOTE(review): unlike the annotated fields above, `embedding`
            # and `created_at` are plain class-attribute assignments
            # (`created_at = datetime` assigns the *type*, not a value) —
            # pydantic/LanceModel will not treat these as fields; confirm
            # they were meant to be annotations.
            embedding = Vector(config.default_embedding_config.embedding_dim)
            # Add a datetime column, with default value as the current time
            created_at = datetime

            def __repr__(self):
                return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"

            def to_record(self):
                # Convert the LanceDB row into the `Message` schema object.
                return Message(
                    user_id=self.user_id,
                    agent_id=self.agent_id,
                    role=self.role,
                    name=self.name,
                    text=self.text,
                    model=self.model,
                    function_name=self.function_name,
                    function_args=self.function_args,
                    function_response=self.function_response,
                    embedding=self.embedding,
                    created_at=self.created_at,
                    id=self.id,
                )

        # NOTE(review): this bare string is a no-op statement, not a
        # docstring (it sits after the class body) — its content was folded
        # into the function docstring above.
        """Create database model for table_name"""
        return MessageModel
    else:
        raise ValueError(f"Table type {table_type} not implemented")
class LanceDBConnector(StorageConnector):
    """Storage via LanceDB

    Incomplete initial implementation: every data-access entry point is an
    unimplemented stub.  The original code decorated the stubs with
    ``@abstractmethod``, but that name was never imported in this module
    (a NameError at class-definition time) and this is a concrete class, so
    the decorators have been removed while the stub signatures are kept.
    """

    # TODO: this should probably eventually be moved into a parent DB class
    def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
        # TODO
        pass

    def generate_where_filter(self, filters: Dict) -> str:
        """Render ``filters`` as a ``key=value AND key=value ...`` string."""
        where_filters = []
        for key, value in filters.items():
            where_filters.append(f"{key}={value}")
        # Bug fix: lists have no .join(); the separator string is what
        # carries the join method.
        return " AND ".join(where_filters)

    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
        # TODO
        pass

    def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
        # TODO
        pass

    def get(self, id: uuid.UUID) -> Optional[Record]:
        # TODO
        pass

    def size(self, filters: Optional[Dict] = {}) -> int:
        # TODO
        pass

    def insert(self, record: Record):
        # TODO
        pass

    def insert_many(self, records: List[Record], show_progress=False):
        # TODO
        pass

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
        # TODO
        pass

    def query_date(self, start_date, end_date):
        # TODO
        pass

    def query_text(self, query):
        # TODO
        pass

    def delete_table(self):
        # TODO
        pass

    def delete(self, filters: Optional[Dict] = {}):
        # TODO
        pass

    def save(self):
        # TODO
        pass

View File

@@ -110,11 +110,6 @@ class StorageConnector:
from letta.agent_store.qdrant import QdrantStorageConnector
return QdrantStorageConnector(table_type, config, user_id, agent_id)
# TODO: add back
# elif storage_type == "lancedb":
# from letta.agent_store.db import LanceDBConnector
# return LanceDBConnector(agent_config=agent_config, table_type=table_type)
elif storage_type == "sqlite":
from letta.agent_store.db import SQLLiteStorageConnector

View File

@@ -171,7 +171,6 @@ def run(
# printd("State path:", agent_config.save_state_dir())
# printd("Persistent manager path:", agent_config.save_persistence_manager_dir())
# printd("Index path:", agent_config.save_agent_index_dir())
# persistence_manager = LocalStateManager(agent_config).load() # TODO: implement load
# TODO: load prior agent state
# Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)

View File

@@ -3054,16 +3054,13 @@ class LocalClient(AbstractClient):
# recall memory
def get_messages(
self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000
) -> List[Message]:
def get_messages(self, agent_id: str, cursor: Optional[str] = None, limit: Optional[int] = 1000) -> List[Message]:
"""
Get messages from an agent with pagination.
Args:
agent_id (str): ID of the agent
before (str): Get messages before a certain time
after (str): Get messages after a certain time
cursor (str): Get messages after a certain time
limit (int): Limit number of messages
Returns:
@@ -3074,8 +3071,7 @@ class LocalClient(AbstractClient):
return self.server.get_agent_recall_cursor(
user_id=self.user_id,
agent_id=agent_id,
before=before,
after=after,
cursor=cursor,
limit=limit,
reverse=True,
)

View File

@@ -38,6 +38,7 @@ DEFAULT_PRESET = "memgpt_chat"
# Base tools that cannot be edited, as they access agent state directly
BASE_TOOLS = ["send_message", "conversation_search", "conversation_search_date", "archival_memory_insert", "archival_memory_search"]
O1_BASE_TOOLS = ["send_thinking_message", "send_final_message"]
# Base memory tools CAN be edited, and are added by default by the server
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]

View File

@@ -1,3 +1,4 @@
from datetime import datetime
from typing import Optional
from letta.agent import Agent
@@ -38,7 +39,7 @@ Returns:
"""
def pause_heartbeats(self: Agent, minutes: int) -> Optional[str]:
def pause_heartbeats(self: "Agent", minutes: int) -> Optional[str]:
import datetime
from letta.constants import MAX_PAUSE_HEARTBEATS
@@ -80,7 +81,15 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
except:
raise ValueError(f"'page' argument must be an integer")
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
results, total = self.persistence_manager.recall_memory.text_search(query, count=count, start=page * count)
# TODO: add paging by page number. currently cursor only works with strings.
# original: start=page * count
results = self.message_manager.list_user_messages_for_agent(
agent_id=self.agent_state.id,
actor=self.user,
query_text=query,
limit=count,
)
total = len(results)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
@@ -112,10 +121,29 @@ def conversation_search_date(self: "Agent", start_date: str, end_date: str, page
page = 0
try:
page = int(page)
if page < 0:
raise ValueError
except:
raise ValueError(f"'page' argument must be an integer")
# Convert date strings to datetime objects
try:
start_datetime = datetime.strptime(start_date, "%Y-%m-%d").replace(hour=0, minute=0, second=0, microsecond=0)
end_datetime = datetime.strptime(end_date, "%Y-%m-%d").replace(hour=23, minute=59, second=59, microsecond=999999)
except ValueError:
raise ValueError("Dates must be in the format 'YYYY-MM-DD'")
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
results, total = self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page * count)
results = self.message_manager.list_user_messages_for_agent(
# TODO: add paging by page number. currently cursor only works with strings.
agent_id=self.agent_state.id,
actor=self.user,
start_date=start_datetime,
end_date=end_datetime,
limit=count,
# start_date=start_date, end_date=end_date, limit=count, start=page * count
)
total = len(results)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
@@ -136,7 +164,7 @@ def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
self.persistence_manager.archival_memory.insert(content)
self.archival_memory.insert(content)
return None
@@ -163,7 +191,7 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0) -
except:
raise ValueError(f"'page' argument must be an integer")
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
results, total = self.persistence_manager.archival_memory.search(query, count=count, start=page * count)
results, total = self.archival_memory.search(query, count=count, start=page * count)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."

View File

@@ -190,8 +190,8 @@ def run_agent_loop(
elif user_input.lower() == "/memory":
print(f"\nDumping memory contents:\n")
print(f"{letta_agent.agent_state.memory.compile()}")
print(f"{letta_agent.persistence_manager.archival_memory.compile()}")
print(f"{letta_agent.persistence_manager.recall_memory.compile()}")
print(f"{letta_agent.archival_memory.compile()}")
print(f"{letta_agent.recall_memory.compile()}")
continue
elif user_input.lower() == "/model":

View File

@@ -67,14 +67,12 @@ def summarize_messages(
+ message_sequence_to_summarize[cutoff:]
)
dummy_user_id = agent_state.user_id
agent_state.user_id
dummy_agent_id = agent_state.id
message_sequence = []
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt))
message_sequence.append(
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role=MessageRole.assistant, text=MESSAGE_SUMMARY_REQUEST_ACK)
)
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input))
message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt))
message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.assistant, text=MESSAGE_SUMMARY_REQUEST_ACK))
message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input))
# TODO: We need to eventually have a separate LLM config for the summarizer LLM
llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True)
@@ -252,82 +250,6 @@ class DummyRecallMemory(RecallMemory):
return matches, len(matches)
class BaseRecallMemory(RecallMemory):
    """Recall memory based on base functions implemented by storage connectors"""

    def __init__(self, agent_state, restrict_search_to_summaries=False):
        # If true, the pool of messages that can be queried are the automated summaries only
        # (generated when the conversation window needs to be shortened)
        self.restrict_search_to_summaries = restrict_search_to_summaries

        # Imported lazily here, presumably to avoid a circular import — confirm.
        from letta.agent_store.storage import StorageConnector

        self.agent_state = agent_state

        # create embedding model
        self.embed_model = embedding_model(agent_state.embedding_config)
        self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size

        # create storage backend
        self.storage = StorageConnector.get_recall_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id)
        # TODO: have some mechanism for cleanup otherwise will lead to OOM
        self.cache = {}

    def get_all(self, start=0, count=None):
        """Fetch a page of messages as OpenAI-style dicts plus the page length.

        NOTE(review): `count` defaults to 0 when omitted/None, so callers that
        do not pass `count` get an empty page — confirm this is intended.
        """
        start = 0 if start is None else int(start)
        count = 0 if count is None else int(count)
        results = self.storage.get_all(start, count)
        results_json = [message.to_openai_dict() for message in results]
        # The returned total is the size of this page, not of the whole store.
        return results_json, len(results)

    def text_search(self, query_string, count=None, start=None):
        """Keyword search over stored messages; returns (results_json, page length)."""
        start = 0 if start is None else int(start)
        count = 0 if count is None else int(count)
        results = self.storage.query_text(query_string, count, start)
        results_json = [message.to_openai_dict_search_results() for message in results]
        return results_json, len(results)

    def date_search(self, start_date, end_date, count=None, start=None):
        """Search messages created between start_date and end_date (paged)."""
        start = 0 if start is None else int(start)
        count = 0 if count is None else int(count)
        results = self.storage.query_date(start_date, end_date, count, start)
        results_json = [message.to_openai_dict_search_results() for message in results]
        return results_json, len(results)

    def compile(self) -> str:
        """Render per-role message-count statistics for inclusion in prompts."""
        total = self.storage.size()
        system_count = self.storage.size(filters={"role": "system"})
        user_count = self.storage.size(filters={"role": "user"})
        assistant_count = self.storage.size(filters={"role": "assistant"})
        function_count = self.storage.size(filters={"role": "function"})
        # Anything whose role is not one of the four counted above.
        other_count = total - (system_count + user_count + assistant_count + function_count)
        memory_str = (
            f"Statistics:"
            + f"\n{total} total messages"
            + f"\n{system_count} system"
            + f"\n{user_count} user"
            + f"\n{assistant_count} assistant"
            + f"\n{function_count} function"
            + f"\n{other_count} other"
        )
        return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"

    def insert(self, message: Message):
        """Persist a single message to the recall store."""
        self.storage.insert(message)

    def insert_many(self, messages: List[Message]):
        """Persist a batch of messages to the recall store."""
        self.storage.insert_many(messages)

    def save(self):
        """Flush the underlying storage connector."""
        self.storage.save()

    def __len__(self):
        # Total number of messages currently in the store.
        return self.storage.size()

    def count(self) -> int:
        """Total number of stored messages (delegates to __len__)."""
        return len(self)
class EmbeddingArchivalMemory(ArchivalMemory):
"""Archival memory with embedding based search"""

View File

@@ -14,7 +14,6 @@ from letta.schemas.api_key import APIKey
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.user import User
from letta.services.per_agent_lock_manager import PerAgentLockManager
@@ -66,40 +65,6 @@ class EmbeddingConfigColumn(TypeDecorator):
return value
class ToolCallColumn(TypeDecorator):
impl = JSON
cache_ok = True
def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())
def process_bind_param(self, value, dialect):
if value:
values = []
for v in value:
if isinstance(v, ToolCall):
values.append(v.model_dump())
else:
values.append(v)
return values
return value
def process_result_value(self, value, dialect):
if value:
tools = []
for tool_value in value:
if "function" in tool_value:
tool_call_function = ToolCallFunction(**tool_value["function"])
del tool_value["function"]
else:
tool_call_function = None
tools.append(ToolCall(function=tool_call_function, **tool_value))
return tools
return value
# TODO: eventually store providers?
# class Provider(Base):
# __tablename__ = "providers"

View File

@@ -20,7 +20,7 @@ def send_thinking_message(self: "Agent", message: str) -> Optional[str]:
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
self.interface.internal_monologue(message, msg_obj=self._messages[-1])
self.interface.internal_monologue(message)
return None
@@ -34,7 +34,7 @@ def send_final_message(self: "Agent", message: str) -> Optional[str]:
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
self.interface.internal_monologue(message, msg_obj=self._messages[-1])
self.interface.internal_monologue(message)
return None
@@ -62,10 +62,15 @@ class O1Agent(Agent):
"""Run Agent.inner_step in a loop, terminate when final thinking message is sent or max_thinking_steps is reached"""
# assert ms is not None, "MetadataStore is required"
next_input_message = messages if isinstance(messages, list) else [messages]
counter = 0
total_usage = UsageStatistics()
step_count = 0
while step_count < self.max_thinking_steps:
# This is hacky but we need to do this for now
for m in next_input_message:
m.id = m._generate_id()
kwargs["ms"] = ms
kwargs["first_message"] = False
step_response = self.inner_step(

View File

@@ -18,6 +18,7 @@ def trigger_rethink_memory(agent_state: "AgentState", message: Optional[str]) ->
"""
from letta import create_client
client = create_client()
agents = client.list_agents()
for agent in agents:
@@ -149,6 +150,11 @@ class OfflineMemoryAgent(Agent):
step_count = 0
while counter < self.max_memory_rethinks:
# This is hacky but we need to do this for now
# TODO: REMOVE THIS
for m in next_input_message:
m.id = m._generate_id()
kwargs["ms"] = ms
kwargs["first_message"] = False
step_response = self.inner_step(

View File

@@ -1,8 +1,10 @@
from letta.orm.agents_tags import AgentsTags
from letta.orm.base import Base
from letta.orm.block import Block
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata
from letta.orm.job import Job
from letta.orm.message import Message
from letta.orm.organization import Organization
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, List
from sqlalchemy import Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

66
letta/orm/message.py Normal file
View File

@@ -0,0 +1,66 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import JSON, DateTime, TypeDecorator
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import AgentMixin, OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
class ToolCallColumn(TypeDecorator):
    """Custom SQLAlchemy column type that persists ToolCall objects as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Stored as plain JSON regardless of the database dialect.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Serialize ToolCall models into plain dicts before writing to the DB."""
        if not value:
            # Preserve None / empty values untouched.
            return value
        return [item.model_dump() if isinstance(item, ToolCall) else item for item in value]

    def process_result_value(self, value, dialect):
        """Rehydrate stored JSON dicts back into ToolCall models on read."""
        if not value:
            return value
        restored = []
        for raw in value:
            function = None
            if "function" in raw:
                # Pop the nested function payload so the remaining keys map
                # directly onto ToolCall's constructor.
                function = ToolCallFunction(**raw.pop("function"))
            restored.append(ToolCall(function=function, **raw))
        return restored
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
    """Defines data model for storing Message objects"""

    __tablename__ = "messages"
    # extend_existing avoids "table already defined" errors if this module
    # gets imported/declared more than once against the same metadata.
    __table_args__ = {"extend_existing": True}
    # Pydantic schema this ORM row (de)serializes to/from.
    __pydantic_model__ = PydanticMessage

    id: Mapped[str] = mapped_column(primary_key=True, doc="Unique message identifier")
    role: Mapped[str] = mapped_column(doc="Message role (user/assistant/system/tool)")
    text: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Message content")
    model: Mapped[Optional[str]] = mapped_column(nullable=True, doc="LLM model used")
    name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
    tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information")
    tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
    # NOTE(review): the column is timezone-aware but datetime.utcnow returns a
    # naive datetime — consider datetime.now(timezone.utc); confirm intent.
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)

    # Relationships
    # TODO: Add in after Agent ORM is created
    # agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
    organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")

View File

@@ -31,6 +31,22 @@ class UserMixin(Base):
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
class AgentMixin(Base):
    """Mixin for models that belong to an agent."""

    # Declarative mixin only — never mapped to a table of its own.
    __abstract__ = True

    # FK to the owning agent row in the `agents` table.
    agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"))
class FileMixin(Base):
    """Mixin for models that belong to a file."""

    # Declarative mixin only — never mapped to a table of its own.
    __abstract__ = True

    # FK to the owning file row in the `files` table.
    file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id"))
class SourceMixin(Base):
"""Mixin for models (e.g. file) that belong to a source."""

View File

@@ -33,6 +33,7 @@ class Organization(SqlalchemyBase):
sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship(
"SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
)
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
# TODO: Map these relationships later when we actually make these models
# below is just a suggestion

View File

@@ -1,8 +1,10 @@
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, List, Literal, Optional, Type
from sqlalchemy import String, select
from sqlalchemy import String, func, select
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, Session, mapped_column
from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
@@ -20,6 +22,11 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
class AccessType(str, Enum):
    """Scope used by access predicates: restrict rows by organization or by user."""

    # Rows are restricted to the actor's organization_id.
    ORGANIZATION = "organization"
    # Rows are restricted to the actor's own user id.
    USER = "user"
class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
__abstract__ = True
@@ -28,46 +35,68 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
id: Mapped[str] = mapped_column(String, primary_key=True)
@classmethod
def list(
cls, *, db_session: "Session", cursor: Optional[str] = None, limit: Optional[int] = 50, **kwargs
) -> List[Type["SqlalchemyBase"]]:
"""
List records with optional cursor (for pagination), limit, and automatic filtering.
def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
"""Get a record by ID.
Args:
db_session: The database session to use.
cursor: Optional ID to start pagination from.
limit: Maximum number of records to return.
**kwargs: Filters passed as equality conditions or iterable for IN filtering.
db_session: SQLAlchemy session
id: Record ID to retrieve
Returns:
A list of model instances matching the filters.
Optional[SqlalchemyBase]: The record if found, None otherwise
"""
logger.debug(f"Listing {cls.__name__} with filters {kwargs}")
try:
return db_session.query(cls).filter(cls.id == id).first()
except DBAPIError:
return None
@classmethod
def list(
cls,
*,
db_session: "Session",
cursor: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: Optional[int] = 50,
query_text: Optional[str] = None,
**kwargs,
) -> List[Type["SqlalchemyBase"]]:
"""List records with advanced filtering and pagination options."""
if start_date and end_date and start_date > end_date:
raise ValueError("start_date must be earlier than or equal to end_date")
logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
with db_session as session:
# Start with a base query
query = select(cls)
# Apply filtering logic
for key, value in kwargs.items():
column = getattr(cls, key)
if isinstance(value, (list, tuple, set)): # Check for iterables
if isinstance(value, (list, tuple, set)):
query = query.where(column.in_(value))
else: # Single value for equality filtering
else:
query = query.where(column == value)
# Apply cursor for pagination
# Date range filtering
if start_date:
query = query.filter(cls.created_at >= start_date)
if end_date:
query = query.filter(cls.created_at <= end_date)
# Cursor-based pagination
if cursor:
query = query.where(cls.id > cursor)
# Handle soft deletes if the class has the 'is_deleted' attribute
# Apply text search
if query_text:
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
# Handle ordering and soft deletes
if hasattr(cls, "is_deleted"):
query = query.where(cls.is_deleted == False)
# Add ordering and limit
query = query.order_by(cls.id).limit(limit)
# Execute the query and return results as model instances
return list(session.execute(query).scalars())
@classmethod
@@ -77,6 +106,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
identifier: Optional[str] = None,
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
access_type: AccessType = AccessType.ORGANIZATION,
**kwargs,
) -> Type["SqlalchemyBase"]:
"""The primary accessor for an ORM record.
@@ -108,7 +138,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_conditions.append(", ".join(f"{key}='{value}'" for key, value in kwargs.items()))
if actor:
query = cls.apply_access_predicate(query, actor, access)
query = cls.apply_access_predicate(query, actor, access, access_type)
query_conditions.append(f"access level in {access} for actor='{actor}'")
if hasattr(cls, "is_deleted"):
@@ -170,12 +200,66 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
session.refresh(self)
return self
@classmethod
def size(
    cls,
    *,
    db_session: "Session",
    actor: Optional["User"] = None,
    access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
    access_type: AccessType = AccessType.ORGANIZATION,
    **kwargs,
) -> int:
    """
    Get the count of rows that match the provided filters.

    Args:
        db_session: SQLAlchemy session
        actor: Optional user whose access predicate restricts the count
        access: Access levels required (entry point for row-level permissions)
        access_type: Whether access is scoped by organization or by user
        **kwargs: Filters to apply to the query (e.g., column_name=value)

    Returns:
        int: The count of rows that match the filters

    Raises:
        AttributeError: If a filter key is not a column on this model
        DBAPIError: If a database error occurs
    """
    logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")

    with db_session as session:
        query = select(func.count()).select_from(cls)

        if actor:
            query = cls.apply_access_predicate(query, actor, access, access_type)

        # Apply filtering logic based on kwargs
        for key, value in kwargs.items():
            # Only None means "no filter"; falsy values such as 0 or False are
            # legitimate filters. (Previously `if value:` silently dropped them.)
            if value is None:
                continue
            column = getattr(cls, key, None)
            # Use an identity check: SQLAlchemy column attributes are clause
            # elements whose truth value is undefined (`not column` would raise).
            if column is None:
                raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
            if isinstance(value, (list, tuple, set)):  # Check for iterables
                query = query.where(column.in_(value))
            else:  # Single value for equality filtering
                query = query.where(column == value)

        # Handle soft deletes if the class has the 'is_deleted' attribute
        if hasattr(cls, "is_deleted"):
            query = query.where(cls.is_deleted == False)

        try:
            count = session.execute(query).scalar()
            return count if count else 0
        except DBAPIError as e:
            logger.exception(f"Failed to calculate size for {cls.__name__}")
            raise e
@classmethod
def apply_access_predicate(
cls,
query: "Select",
actor: "User",
access: List[Literal["read", "write", "admin"]],
access_type: AccessType = AccessType.ORGANIZATION,
) -> "Select":
"""applies a WHERE clause restricting results to the given actor and access level
Args:
@@ -189,10 +273,18 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
the sqlalchemy select statement restricted to the given access.
"""
del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment
org_id = getattr(actor, "organization_id", None)
if not org_id:
raise ValueError(f"object {actor} has no organization accessor")
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
if access_type == AccessType.ORGANIZATION:
org_id = getattr(actor, "organization_id", None)
if not org_id:
raise ValueError(f"object {actor} has no organization accessor")
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
elif access_type == AccessType.USER:
user_id = getattr(actor, "id", None)
if not user_id:
raise ValueError(f"object {actor} has no user accessor")
return query.where(cls.user_id == user_id, cls.is_deleted == False)
else:
raise ValueError(f"unknown access_type: {access_type}")
@classmethod
def _handle_dbapi_error(cls, e: DBAPIError):

View File

@@ -1,149 +0,0 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import List
from letta.memory import BaseRecallMemory, EmbeddingArchivalMemory
from letta.schemas.agent import AgentState
from letta.schemas.memory import Memory
from letta.schemas.message import Message
from letta.utils import printd
def parse_formatted_time(formatted_time: str) -> datetime:
    """Parse a timestamp string produced by letta.utils.get_formatted_time().

    Tries the timezone-aware format first (e.g. '2023-01-02 03:04:05 PM PST-0800')
    and falls back to the naive format without a timezone suffix.

    Args:
        formatted_time: Timestamp string; surrounding whitespace is ignored.

    Returns:
        datetime: Parsed datetime (timezone-aware when a tz suffix is present).

    Raises:
        ValueError: If the string matches neither supported format.
    """
    stripped = formatted_time.strip()
    try:
        return datetime.strptime(stripped, "%Y-%m-%d %I:%M:%S %p %Z%z")
    except ValueError:
        # No timezone suffix — parse as a naive timestamp. A bare `except:`
        # here previously masked unrelated errors (e.g. TypeError on bad input).
        return datetime.strptime(stripped, "%Y-%m-%d %I:%M:%S %p")
class PersistenceManager(ABC):
    """Interface for mirroring an agent's in-context message changes to storage.

    Implementations are notified whenever the in-context message window
    changes so they can keep persistent memory in sync.
    """

    @abstractmethod
    def trim_messages(self, num):
        # Called when messages are dropped from the in-context window.
        pass

    @abstractmethod
    def prepend_to_messages(self, added_messages):
        # Called when messages are inserted at the front of the window.
        pass

    @abstractmethod
    def append_to_messages(self, added_messages):
        # Called when messages are appended to the window.
        pass

    @abstractmethod
    def swap_system_message(self, new_system_message):
        # Called when the system message is replaced.
        pass

    @abstractmethod
    def update_memory(self, new_memory):
        # Called when the core memory object is replaced.
        pass
class LocalStateManager(PersistenceManager):
    """In-memory state manager has nothing to manage, all agents are held in-memory"""

    recall_memory_cls = BaseRecallMemory
    archival_memory_cls = EmbeddingArchivalMemory

    def __init__(self, agent_state: AgentState):
        """Initialize in-state memory plus archival/recall storage backends.

        Args:
            agent_state: The agent whose memory this manager persists.
        """
        # Memory held in-state useful for debugging stateful versions
        self.memory = agent_state.memory
        self.archival_memory = EmbeddingArchivalMemory(agent_state)
        self.recall_memory = BaseRecallMemory(agent_state)

    def save(self):
        """Ensure storage connectors save data"""
        self.archival_memory.save()
        self.recall_memory.save()

    def trim_messages(self, num):
        # In-context trimming is handled by the agent; nothing to persist here.
        pass

    def prepend_to_messages(self, added_messages: List[Message]):
        """Record messages inserted at the front of the in-context window."""
        printd(f"{self.__class__.__name__}.prepend_to_message")
        # Mirror every message into recall memory so it stays searchable.
        self.recall_memory.insert_many(list(added_messages))

    def append_to_messages(self, added_messages: List[Message]):
        """Record messages appended to the in-context window."""
        printd(f"{self.__class__.__name__}.append_to_messages")
        self.recall_memory.insert_many(list(added_messages))

    def swap_system_message(self, new_system_message: Message):
        """Record a replacement system message."""
        printd(f"{self.__class__.__name__}.swap_system_message")
        self.recall_memory.insert(new_system_message)

    def update_memory(self, new_memory: Memory):
        """Replace the in-state core memory object."""
        printd(f"{self.__class__.__name__}.update_memory")
        assert isinstance(new_memory, Memory), type(new_memory)
        self.memory = new_memory

View File

@@ -33,18 +33,19 @@ class LettaBase(BaseModel):
def generate_id_field(cls, prefix: Optional[str] = None) -> "Field":
prefix = prefix or cls.__id_prefix__
# TODO: generate ID from regex pattern?
def _generate_id() -> str:
return f"{prefix}-{uuid.uuid4()}"
return Field(
...,
description=cls._id_description(prefix),
pattern=cls._id_regex_pattern(prefix),
examples=[cls._id_example(prefix)],
default_factory=_generate_id,
default_factory=cls._generate_id,
)
@classmethod
def _generate_id(cls, prefix: Optional[str] = None) -> str:
    """Build a fresh prefixed identifier, e.g. '<prefix>-<uuid4>'.

    Falls back to the class-level ``__id_prefix__`` when no prefix is given.
    """
    id_prefix = prefix or cls.__id_prefix__
    return f"{id_prefix}-{uuid.uuid4()}"
# def _generate_id(self) -> str:
# return f"{self.__id_prefix__}-{uuid.uuid4()}"
@@ -78,7 +79,7 @@ class LettaBase(BaseModel):
"""
_ = values # for SCA
if isinstance(v, UUID):
logger.warning(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!")
logger.debug(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!")
return f"{cls.__id_prefix__}-{v}"
return v

View File

@@ -105,7 +105,7 @@ class Message(BaseMessage):
id: str = BaseMessage.generate_id_field()
role: MessageRole = Field(..., description="The role of the participant.")
text: Optional[str] = Field(None, description="The text of the message.")
user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
model: Optional[str] = Field(None, description="The model used to make the function call.")
name: Optional[str] = Field(None, description="The name of the participant.")
@@ -281,7 +281,6 @@ class Message(BaseMessage):
)
if id is not None:
return Message(
user_id=user_id,
agent_id=agent_id,
model=model,
# standard fields expected in an OpenAI ChatCompletion message object
@@ -295,7 +294,6 @@ class Message(BaseMessage):
)
else:
return Message(
user_id=user_id,
agent_id=agent_id,
model=model,
# standard fields expected in an OpenAI ChatCompletion message object
@@ -328,7 +326,6 @@ class Message(BaseMessage):
if id is not None:
return Message(
user_id=user_id,
agent_id=agent_id,
model=model,
# standard fields expected in an OpenAI ChatCompletion message object
@@ -342,7 +339,6 @@ class Message(BaseMessage):
)
else:
return Message(
user_id=user_id,
agent_id=agent_id,
model=model,
# standard fields expected in an OpenAI ChatCompletion message object
@@ -375,7 +371,6 @@ class Message(BaseMessage):
# If we're going from tool-call style
if id is not None:
return Message(
user_id=user_id,
agent_id=agent_id,
model=model,
# standard fields expected in an OpenAI ChatCompletion message object
@@ -389,7 +384,6 @@ class Message(BaseMessage):
)
else:
return Message(
user_id=user_id,
agent_id=agent_id,
model=model,
# standard fields expected in an OpenAI ChatCompletion message object

View File

@@ -409,7 +409,7 @@ def get_agent_messages(
return server.get_agent_recall_cursor(
user_id=actor.id,
agent_id=agent_id,
before=before,
cursor=before,
limit=limit,
reverse=True,
return_message_object=msg_object,

View File

@@ -77,14 +77,15 @@ from letta.schemas.user import User
from letta.services.agents_tags_manager import AgentsTagsManager
from letta.services.block_manager import BlockManager
from letta.services.blocks_agents_manager import BlocksAgentsManager
from letta.services.tools_agents_manager import ToolsAgentsManager
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.organization_manager import OrganizationManager
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.tools_agents_manager import ToolsAgentsManager
from letta.services.user_manager import UserManager
from letta.utils import create_random_username, get_utc_time, json_dumps, json_loads
@@ -260,6 +261,7 @@ class SyncServer(Server):
self.agents_tags_manager = AgentsTagsManager()
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
self.blocks_agents_manager = BlocksAgentsManager()
self.message_manager = MessageManager()
self.tools_agents_manager = ToolsAgentsManager()
self.job_manager = JobManager()
@@ -414,7 +416,7 @@ class SyncServer(Server):
agent = OfflineMemoryAgent(agent_state=agent_state, interface=interface, user=actor)
elif agent_state.agent_type == AgentType.chat_only_agent:
agent = ChatOnlyAgent(agent_state=agent_state, interface=interface, user=actor)
else:
else:
raise ValueError(f"Invalid agent type {agent_state.agent_type}")
# Rebuild the system prompt - may be linked to new blocks now
@@ -422,7 +424,7 @@ class SyncServer(Server):
# Persist to agent
save_agent(agent, self.ms)
return agent
return agent
def _step(
self,
@@ -518,12 +520,7 @@ class SyncServer(Server):
letta_agent.interface.print_messages_raw(letta_agent.messages)
elif command.lower() == "memory":
ret_str = (
f"\nDumping memory contents:\n"
+ f"\n{str(letta_agent.agent_state.memory)}"
+ f"\n{str(letta_agent.persistence_manager.archival_memory)}"
+ f"\n{str(letta_agent.persistence_manager.recall_memory)}"
)
ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.archival_memory)}"
return ret_str
elif command.lower() == "pop" or command.lower().startswith("pop "):
@@ -625,7 +622,6 @@ class SyncServer(Server):
# Convert to a Message object
if timestamp:
message = Message(
user_id=user_id,
agent_id=agent_id,
role="user",
text=packaged_user_message,
@@ -633,7 +629,6 @@ class SyncServer(Server):
)
else:
message = Message(
user_id=user_id,
agent_id=agent_id,
role="user",
text=packaged_user_message,
@@ -672,7 +667,6 @@ class SyncServer(Server):
if timestamp:
message = Message(
user_id=user_id,
agent_id=agent_id,
role="system",
text=packaged_system_message,
@@ -680,7 +674,6 @@ class SyncServer(Server):
)
else:
message = Message(
user_id=user_id,
agent_id=agent_id,
role="system",
text=packaged_system_message,
@@ -743,7 +736,6 @@ class SyncServer(Server):
# Create the Message object
message_objects.append(
Message(
user_id=user_id,
agent_id=agent_id,
role=message.role,
text=message.text,
@@ -876,7 +868,7 @@ class SyncServer(Server):
else:
raise ValueError(f"Invalid message role: {message.role}")
init_messages.append(Message(role=message.role, text=packed_message, user_id=user_id, agent_id=agent_state.id))
init_messages.append(Message(role=message.role, text=packed_message, agent_id=agent_state.id))
# init_messages = [Message.dict_to_message(user_id=user_id, agent_id=agent_state.id, openai_message_dict=message.model_dump()) for message in request.initial_message_sequence]
else:
init_messages = None
@@ -1160,11 +1152,11 @@ class SyncServer(Server):
def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary:
agent = self.load_agent(agent_id=agent_id)
return ArchivalMemorySummary(size=len(agent.persistence_manager.archival_memory))
return ArchivalMemorySummary(size=len(agent.archival_memory))
def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary:
agent = self.load_agent(agent_id=agent_id)
return RecallMemorySummary(size=len(agent.persistence_manager.recall_memory))
return RecallMemorySummary(size=len(agent.message_manager))
def get_in_context_message_ids(self, agent_id: str) -> List[str]:
"""Get the message ids of the in-context messages in the agent's memory"""
@@ -1182,7 +1174,7 @@ class SyncServer(Server):
"""Get a single message from the agent's memory"""
# Get the agent object (loaded in memory)
agent = self.load_agent(agent_id=agent_id)
message = agent.persistence_manager.recall_memory.storage.get(id=message_id)
message = agent.message_manager.get_message_by_id(id=message_id, actor=self.default_user)
return message
def get_agent_messages(
@@ -1213,14 +1205,16 @@ class SyncServer(Server):
else:
# need to access persistence manager for additional messages
db_iterator = letta_agent.persistence_manager.recall_memory.storage.get_all_paginated(page_size=count, offset=start)
# get a single page of messages
# TODO: handle stop iteration
page = next(db_iterator, [])
# get messages using message manager
page = letta_agent.message_manager.list_user_messages_for_agent(
agent_id=agent_id,
actor=self.default_user,
cursor=start,
limit=count,
)
# return messages in reverse chronological order
messages = sorted(page, key=lambda x: x.created_at, reverse=True)
messages = page
assert all(isinstance(m, Message) for m in messages)
## Convert to json
@@ -1243,7 +1237,7 @@ class SyncServer(Server):
letta_agent = self.load_agent(agent_id=agent_id)
# iterate over records
db_iterator = letta_agent.persistence_manager.archival_memory.storage.get_all_paginated(page_size=count, offset=start)
db_iterator = letta_agent.archival_memory.storage.get_all_paginated(page_size=count, offset=start)
# get a single page of messages
page = next(db_iterator, [])
@@ -1268,7 +1262,7 @@ class SyncServer(Server):
letta_agent = self.load_agent(agent_id=agent_id)
# iterate over records
cursor, records = letta_agent.persistence_manager.archival_memory.storage.get_all_cursor(
cursor, records = letta_agent.archival_memory.storage.get_all_cursor(
after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
)
return records
@@ -1283,14 +1277,14 @@ class SyncServer(Server):
letta_agent = self.load_agent(agent_id=agent_id)
# Insert into archival memory
passage_ids = letta_agent.persistence_manager.archival_memory.insert(memory_string=memory_contents, return_ids=True)
passage_ids = letta_agent.archival_memory.insert(memory_string=memory_contents, return_ids=True)
# Update the agent
# TODO: should this update the system prompt?
save_agent(letta_agent, self.ms)
# TODO: this is gross, fix
return [letta_agent.persistence_manager.archival_memory.storage.get(id=passage_id) for passage_id in passage_ids]
return [letta_agent.archival_memory.storage.get(id=passage_id) for passage_id in passage_ids]
def delete_archival_memory(self, user_id: str, agent_id: str, memory_id: str):
if self.user_manager.get_user_by_id(user_id=user_id) is None:
@@ -1305,7 +1299,7 @@ class SyncServer(Server):
# Delete by ID
# TODO check if it exists first, and throw error if not
letta_agent.persistence_manager.archival_memory.storage.delete({"id": memory_id})
letta_agent.archival_memory.storage.delete({"id": memory_id})
# TODO: return archival memory
@@ -1313,17 +1307,15 @@ class SyncServer(Server):
self,
user_id: str,
agent_id: str,
after: Optional[str] = None,
before: Optional[str] = None,
cursor: Optional[str] = None,
limit: Optional[int] = 100,
order_by: Optional[str] = "created_at",
order: Optional[str] = "asc",
reverse: Optional[bool] = False,
return_message_object: bool = True,
assistant_message_tool_name: str = constants.DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
) -> Union[List[Message], List[LettaMessage]]:
if self.user_manager.get_user_by_id(user_id=user_id) is None:
actor = self.user_manager.get_user_by_id(user_id=user_id)
if actor is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
@@ -1332,8 +1324,12 @@ class SyncServer(Server):
letta_agent = self.load_agent(agent_id=agent_id)
# iterate over records
cursor, records = letta_agent.persistence_manager.recall_memory.storage.get_all_cursor(
after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
# TODO: Check "order_by", "order"
records = letta_agent.message_manager.list_messages_for_agent(
agent_id=agent_id,
actor=actor,
cursor=cursor,
limit=limit,
)
assert all(isinstance(m, Message) for m in records)
@@ -1353,7 +1349,7 @@ class SyncServer(Server):
records = records[::-1]
return records
def get_server_config(self, include_defaults: bool = False) -> dict:
"""Return the base config"""
@@ -1425,19 +1421,25 @@ class SyncServer(Server):
self.agents_tags_manager.delete_all_tags_from_agent(agent_id=agent_id, actor=actor)
self.blocks_agents_manager.remove_all_agent_blocks(agent_id=agent_id)
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
# Verify that the agent exists and belongs to the org of the user
agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
if not agent_state:
if agent_state is None:
raise ValueError(f"Could not find agent_id={agent_id} under user_id={user_id}")
agent_state_user = self.user_manager.get_user_by_id(user_id=agent_state.user_id)
if agent_state_user.organization_id != actor.organization_id:
raise ValueError(
f"Could not authorize agent_id={agent_id} with user_id={user_id} because of differing organizations; agent_id was created in {agent_state_user.organization_id} while user belongs to {actor.organization_id}. How did they get the agent id?"
)
# TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM MODEL
messages = self.message_manager.list_messages_for_agent(agent_id=agent_state.id)
for message in messages:
self.message_manager.delete_message_by_id(message.id, actor=actor)
# TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM
try:
agent_state_user = self.user_manager.get_user_by_id(user_id=agent_state.user_id)
if agent_state_user.organization_id != actor.organization_id:
raise ValueError(
f"Could not authorize agent_id={agent_id} with user_id={user_id} because of differing organizations; agent_id was created in {agent_state_user.organization_id} while user belongs to {actor.organization_id}. How did they get the agent id?"
)
except NoResultFound:
logger.error(f"Agent with id {agent_state.id} has nonexistent user {agent_state.user_id}")
# First, if the agent is in the in-memory cache we should remove it
# List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
@@ -1582,7 +1584,7 @@ class SyncServer(Server):
# delete all Passage objects with source_id==source_id from agent's archival memory
agent = self.load_agent(agent_id=agent_id)
archival_memory = agent.persistence_manager.archival_memory
archival_memory = agent.archival_memory
archival_memory.storage.delete({"source_id": source_id})
# delete agent-source mapping
@@ -1661,7 +1663,7 @@ class SyncServer(Server):
"""Get a single message from the agent's memory"""
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id)
message = letta_agent.persistence_manager.recall_memory.storage.get(id=message_id)
message = letta_agent.message_manager.get_message_by_id(id=message_id)
save_agent(letta_agent, self.ms)
return message
@@ -1705,7 +1707,7 @@ class SyncServer(Server):
try:
return self.user_manager.get_user_by_id(user_id=user_id)
except ValueError:
except NoResultFound:
raise HTTPException(status_code=404, detail=f"User with id {user_id} not found")
def get_organization_or_default(self, org_id: Optional[str]) -> Organization:
@@ -1715,7 +1717,7 @@ class SyncServer(Server):
try:
return self.organization_manager.get_organization_by_id(org_id=org_id)
except ValueError:
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
def list_llm_models(self) -> List[LLMConfig]:

View File

@@ -102,7 +102,7 @@ class BlockManager:
return [block.to_pydantic() for block in blocks]
@enforce_types
def get_block_by_id(self, block_id, actor: PydanticUser) -> Optional[PydanticBlock]:
def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
"""Retrieve a block by its name."""
with self.session_maker() as session:
try:

View File

@@ -0,0 +1,182 @@
from datetime import datetime
from typing import Dict, List, Optional
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.schemas.enums import MessageRole
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
class MessageManager:
    """Manager class to handle business logic related to Messages."""

    def __init__(self):
        # Deferred import to avoid a circular dependency with the server module.
        from letta.server.server import db_context

        self.session_maker = db_context

    @enforce_types
    def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
        """Fetch a message by ID; return None if it does not exist (or is not visible to the actor)."""
        with self.session_maker() as session:
            try:
                message = MessageModel.read(db_session=session, identifier=message_id, actor=actor)
                return message.to_pydantic()
            except NoResultFound:
                return None

    @enforce_types
    def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
        """Create a new message owned by the actor's organization.

        NOTE: mutates ``pydantic_msg`` in place by stamping ``organization_id``.
        """
        with self.session_maker() as session:
            # Set the organization id of the Pydantic message before persisting
            pydantic_msg.organization_id = actor.organization_id
            msg_data = pydantic_msg.model_dump()
            msg = MessageModel(**msg_data)
            msg.create(session, actor=actor)  # Persist to database
            return msg.to_pydantic()

    @enforce_types
    def create_many_messages(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]:
        """Create multiple messages (one transaction per message) and return the persisted copies."""
        return [self.create_message(m, actor=actor) for m in pydantic_msgs]

    @enforce_types
    def update_message_by_id(self, message_id: str, message: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
        """
        Update an existing record in the database with values from the provided record object.

        The primary key is never overwritten: all other columns present on both the
        ORM model and the provided message are copied over.

        Raises:
            NoResultFound: if no message with ``message_id`` is visible to ``actor``.
        """
        with self.session_maker() as session:
            # Fetch existing message from database
            msg = MessageModel.read(
                db_session=session,
                identifier=message_id,
                actor=actor,
            )

            # Update the database record with values from the provided record.
            # Skip the primary key so a mismatched `message.id` cannot re-key the row.
            for column in MessageModel.__table__.columns:
                column_name = column.name
                if column_name == "id":
                    continue
                if hasattr(message, column_name):
                    new_value = getattr(message, column_name)
                    setattr(msg, column_name, new_value)

            # Commit changes and return the refreshed record
            return msg.update(db_session=session, actor=actor).to_pydantic()

    @enforce_types
    def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool:
        """Permanently delete a message.

        Returns:
            True on success.

        Raises:
            ValueError: if no message with ``message_id`` exists.
        """
        with self.session_maker() as session:
            try:
                msg = MessageModel.read(
                    db_session=session,
                    identifier=message_id,
                    actor=actor,
                )
                msg.hard_delete(session, actor=actor)
                # Previously fell off the end and returned None despite the declared bool return.
                return True
            except NoResultFound:
                raise ValueError(f"Message with id {message_id} not found.")

    @enforce_types
    def size(
        self,
        actor: PydanticUser,
        role: Optional[MessageRole] = None,
        agent_id: Optional[str] = None,
    ) -> int:
        """Get the total count of messages with optional filters.

        Args:
            actor: The user requesting the count
            role: Optionally restrict the count to messages with this role
            agent_id: Optionally restrict the count to messages for this agent
        """
        with self.session_maker() as session:
            return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id)

    @enforce_types
    def list_user_messages_for_agent(
        self,
        agent_id: str,
        actor: Optional[PydanticUser] = None,
        cursor: Optional[str] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: Optional[int] = 50,
        filters: Optional[Dict] = None,
        query_text: Optional[str] = None,
    ) -> List[PydanticMessage]:
        """List user-role messages with flexible filtering and pagination options.

        Args:
            agent_id: The agent whose messages to list
            actor: Optionally scope results to this user's organization
            cursor: Cursor-based pagination - return records after this ID (exclusive)
            start_date: Filter records created after this date
            end_date: Filter records created before this date
            limit: Maximum number of records to return
            filters: Additional filters to apply (merged over the forced role filter)
            query_text: Optional text to search for in message content

        Returns:
            List[PydanticMessage] - List of messages matching the criteria
        """
        # Force the user-role filter; caller-supplied filters may add further constraints.
        message_filters = {"role": "user"}
        if filters:
            message_filters.update(filters)

        return self.list_messages_for_agent(
            agent_id=agent_id,
            actor=actor,
            cursor=cursor,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
            filters=message_filters,
            query_text=query_text,
        )

    @enforce_types
    def list_messages_for_agent(
        self,
        agent_id: str,
        actor: Optional[PydanticUser] = None,
        cursor: Optional[str] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        limit: Optional[int] = 50,
        filters: Optional[Dict] = None,
        query_text: Optional[str] = None,
    ) -> List[PydanticMessage]:
        """List messages with flexible filtering and pagination options.

        Args:
            agent_id: The agent whose messages to list
            actor: Optionally scope results to this user's organization
            cursor: Cursor-based pagination - return records after this ID (exclusive)
            start_date: Filter records created after this date
            end_date: Filter records created before this date
            limit: Maximum number of records to return
            filters: Additional filters to apply
            query_text: Optional text to search for in message content

        Returns:
            List[PydanticMessage] - List of messages matching the criteria
        """
        with self.session_maker() as session:
            # Start with base filters
            message_filters = {"agent_id": agent_id}
            if actor:
                message_filters.update({"organization_id": actor.organization_id})
            if filters:
                message_filters.update(filters)

            results = MessageModel.list(
                db_session=session,
                cursor=cursor,
                start_date=start_date,
                end_date=end_date,
                limit=limit,
                query_text=query_text,
                **message_filters,
            )

            return [msg.to_pydantic() for msg in results]

View File

@@ -30,19 +30,16 @@ class OrganizationManager:
def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]:
"""Fetch an organization by ID."""
with self.session_maker() as session:
try:
organization = OrganizationModel.read(db_session=session, identifier=org_id)
return organization.to_pydantic()
except NoResultFound:
return None
organization = OrganizationModel.read(db_session=session, identifier=org_id)
return organization.to_pydantic()
@enforce_types
def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
"""Create a new organization. If a name is provided, it is used, otherwise, a random one is generated."""
org = self.get_organization_by_id(pydantic_org.id)
if org:
"""Create a new organization."""
try:
org = self.get_organization_by_id(pydantic_org.id)
return org
else:
except NoResultFound:
return self._create_organization(pydantic_org=pydantic_org)
@enforce_types

View File

@@ -141,5 +141,5 @@ class SourceManager:
"""Delete a file by its ID."""
with self.session_maker() as session:
file = FileMetadataModel.read(db_session=session, identifier=file_id)
file.delete(db_session=session, actor=actor)
file.hard_delete(db_session=session, actor=actor)
return file.to_pydantic()

View File

@@ -122,7 +122,7 @@ class ToolManager:
tool.json_schema = new_schema
# Save the updated tool to the database
return tool.update(db_session=session, actor=actor)
return tool.update(db_session=session, actor=actor).to_pydantic()
@enforce_types
def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None:

View File

@@ -71,7 +71,7 @@ class UserManager:
with self.session_maker() as session:
# Delete from user table
user = UserModel.read(db_session=session, identifier=user_id)
user.delete(session)
user.hard_delete(session)
# TODO: Integrate this via the ORM models for the Agent, Source, and AgentSourceMapping
# Cascade delete for related models: Agent, Source, AgentSourceMapping

View File

@@ -1,5 +1,24 @@
import logging
import pytest
from letta.settings import tool_settings
def pytest_configure(config):
    """Pytest hook: enable DEBUG-level logging for the entire test session."""
    logging.basicConfig(level=logging.DEBUG)
@pytest.fixture
def mock_e2b_api_key_none():
    """Blank out ``tool_settings.e2b_api_key`` for the duration of a test, then restore it."""
    saved_key = tool_settings.e2b_api_key
    tool_settings.e2b_api_key = None

    # Run the test with the key cleared
    yield

    # Teardown: put the original key back
    tool_settings.e2b_api_key = saved_key

View File

@@ -58,21 +58,6 @@ def clear_tables():
Sandbox.connect(sandbox.sandbox_id).kill()
@pytest.fixture
def mock_e2b_api_key_none():
# Store the original value of e2b_api_key
original_api_key = tool_settings.e2b_api_key
# Set e2b_api_key to None
tool_settings.e2b_api_key = None
# Yield control to the test
yield
# Restore the original value of e2b_api_key
tool_settings.e2b_api_key = original_api_key
@pytest.fixture
def check_e2b_key_is_set():
original_api_key = tool_settings.e2b_api_key

View File

@@ -5,7 +5,6 @@ import pytest
from letta import create_client
from letta.schemas.letta_message import FunctionCallMessage
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.settings import tool_settings
from tests.helpers.endpoints_helper import (
assert_invoked_function_call,
assert_invoked_send_message_with_keyword,
@@ -20,21 +19,6 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
@pytest.fixture
def mock_e2b_api_key_none():
# Store the original value of e2b_api_key
original_api_key = tool_settings.e2b_api_key
# Set e2b_api_key to None
tool_settings.e2b_api_key = None
# Yield control to the test
yield
# Restore the original value of e2b_api_key
tool_settings.e2b_api_key = original_api_key
"""Contrived tools for this test case"""

View File

@@ -16,7 +16,6 @@ from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType
from letta.settings import tool_settings
from letta.utils import create_random_username
# Constants
@@ -40,7 +39,7 @@ def run_server():
@pytest.fixture(
params=[{"server": False}], # whether to use REST API server
params=[{"server": False}, {"server": True}], # whether to use REST API server
scope="module",
)
def client(request):
@@ -83,21 +82,6 @@ def clear_tables():
session.commit()
@pytest.fixture
def mock_e2b_api_key_none():
# Store the original value of e2b_api_key
original_api_key = tool_settings.e2b_api_key
# Set e2b_api_key to None
tool_settings.e2b_api_key = None
# Yield control to the test
yield
# Restore the original value of e2b_api_key
tool_settings.e2b_api_key = original_api_key
def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]):
"""
Test sandbox config and environment variable functions for both LocalClient and RESTClient.

View File

@@ -27,9 +27,12 @@ from letta.schemas.letta_message import (
)
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
from letta.settings import model_settings, tool_settings
from letta.services.user_manager import UserManager
from letta.settings import model_settings
from tests.helpers.client_helper import upload_file_using_client
# from tests.utils import create_config
@@ -54,21 +57,6 @@ def run_server():
start_server(debug=True)
@pytest.fixture
def mock_e2b_api_key_none():
# Store the original value of e2b_api_key
original_api_key = tool_settings.e2b_api_key
# Set e2b_api_key to None
tool_settings.e2b_api_key = None
# Yield control to the test
yield
# Restore the original value of e2b_api_key
tool_settings.e2b_api_key = original_api_key
# Fixture to create clients with different configurations
@pytest.fixture(
# params=[{"server": True}, {"server": False}], # whether to use REST API server
@@ -119,6 +107,22 @@ def agent(client: Union[LocalClient, RESTClient]):
client.delete_agent(agent_state.id)
@pytest.fixture
def default_organization():
    """Create and yield the default organization used by these tests."""
    org = OrganizationManager().create_default_organization()
    yield org
@pytest.fixture
def default_user(default_organization):
    """Create and yield the default user inside the default organization."""
    user = UserManager().create_default_user(org_id=default_organization.id)
    yield user
def test_agent(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState):
# test client.rename_agent
@@ -612,7 +616,7 @@ def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTCli
@pytest.fixture
def cleanup_agents():
def cleanup_agents(client):
created_agents = []
yield created_agents
# Cleanup will run even if test fails
@@ -624,7 +628,7 @@ def cleanup_agents():
# NOTE: we need to add this back once agents can also create blocks during agent creation
def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]):
def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str], default_user):
"""Test that we can set an initial message sequence
If we pass in None, we should get a "default" message sequence
@@ -638,11 +642,12 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
reference_init_messages = initialize_message_sequence(
model=agent.llm_config.model,
system=agent.system,
agent_id=agent.id,
memory=agent.memory,
archival_memory=None,
recall_memory=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
actor=default_user,
)
# system, login message, send_message test, send_message receipt
@@ -661,24 +666,8 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
# Test with empty sequence
empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[])
cleanup_agents.append(empty_agent_state.id)
# NOTE: allowed to be None initially
# assert empty_agent_state.message_ids is not None
# assert len(empty_agent_state.message_ids) == 1, f"Expected 0 messages, got {len(empty_agent_state.message_ids)}"
# Test with custom sequence
# custom_sequence = [
# Message(
# role=MessageRole.user,
# text="Hello, how are you?",
# user_id=agent.user_id,
# agent_id=agent.id,
# model=agent.llm_config.model,
# name=None,
# tool_calls=None,
# tool_call_id=None,
# ),
# ]
custom_sequence = [{"text": "Hello, how are you?", "role": "user"}]
custom_sequence = [Message(**{"text": "Hello, how are you?", "role": "user"})]
custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence)
cleanup_agents.append(custom_agent_state.id)
assert custom_agent_state.message_ids is not None
@@ -687,7 +676,7 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
# assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]
# shoule be contained in second message (after system message)
assert custom_sequence[0]["text"] in client.get_in_context_messages(custom_agent_state.id)[1].text
assert custom_sequence[0].text in client.get_in_context_messages(custom_agent_state.id)[1].text
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):

View File

@@ -8,7 +8,6 @@ from letta.schemas.agent import PersistedAgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory
from letta.schemas.tool import ToolCreate
@pytest.fixture(scope="module")
@@ -293,6 +292,10 @@ def test_tools(client: LocalClient):
"""
print(msg)
# Clean all tools first
for tool in client.list_tools():
client.delete_tool(tool.id)
# create tool
tool = client.create_or_update_tool(func=print_tool, tags=["extras"])
@@ -330,49 +333,50 @@ def test_tools_from_composio_basic(client: LocalClient):
# The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable
def test_tools_from_langchain(client: LocalClient):
# create langchain tool
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
# Add the tool
tool = client.load_langchain_tool(
langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
)
# list tools
tools = client.list_tools()
assert tool.name in [t.name for t in tools]
# get tool
tool_id = client.get_tool_id(name=tool.name)
retrieved_tool = client.get_tool(tool_id)
source_code = retrieved_tool.source_code
# Parse the function and attempt to use it
local_scope = {}
exec(source_code, {}, local_scope)
func = local_scope[tool.name]
expected_content = "Albert Einstein"
assert expected_content in func(query="Albert Einstein")
def test_tool_creation_langchain_missing_imports(client: LocalClient):
# create langchain tool
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
# Translate to memGPT Tool
# Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
with pytest.raises(RuntimeError):
ToolCreate.from_langchain(langchain_tool)
# TODO: Langchain seems to have issues with Pydantic
# TODO: Langchain tools are breaking every two weeks bc of changes on their side
# def test_tools_from_langchain(client: LocalClient):
# # create langchain tool
# from langchain_community.tools import WikipediaQueryRun
# from langchain_community.utilities import WikipediaAPIWrapper
#
# langchain_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
#
# # Add the tool
# tool = client.load_langchain_tool(
# langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
# )
#
# # list tools
# tools = client.list_tools()
# assert tool.name in [t.name for t in tools]
#
# # get tool
# tool_id = client.get_tool_id(name=tool.name)
# retrieved_tool = client.get_tool(tool_id)
# source_code = retrieved_tool.source_code
#
# # Parse the function and attempt to use it
# local_scope = {}
# exec(source_code, {}, local_scope)
# func = local_scope[tool.name]
#
# expected_content = "Albert Einstein"
# assert expected_content in func(query="Albert Einstein")
#
#
# def test_tool_creation_langchain_missing_imports(client: LocalClient):
# # create langchain tool
# from langchain_community.tools import WikipediaQueryRun
# from langchain_community.utilities import WikipediaAPIWrapper
#
# api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
# langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
#
# # Translate to memGPT Tool
# # Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
# with pytest.raises(RuntimeError):
# ToolCreate.from_langchain(langchain_tool)
def test_shared_blocks_without_send_message(client: LocalClient):

View File

@@ -1,4 +1,5 @@
import os
from datetime import datetime, timedelta
import pytest
from sqlalchemy import delete
@@ -11,6 +12,7 @@ from letta.orm import (
BlocksAgents,
FileMetadata,
Job,
Message,
Organization,
SandboxConfig,
SandboxEnvironmentVariable,
@@ -29,11 +31,12 @@ from letta.schemas.agent import CreateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus
from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.sandbox_config import (
E2BSandboxConfig,
@@ -47,7 +50,7 @@ from letta.schemas.sandbox_config import (
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source import SourceUpdate
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.schemas.tool import ToolUpdate
from letta.services.block_manager import BlockManager
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
@@ -77,6 +80,7 @@ using_sqlite = not bool(os.getenv("LETTA_PG_URI"))
def clear_tables(server: SyncServer):
"""Fixture to clear the organization table before each test."""
with server.organization_manager.session_maker() as session:
session.execute(delete(Message))
session.execute(delete(Job))
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
session.execute(delete(BlocksAgents))
@@ -191,6 +195,21 @@ def print_tool(server: SyncServer, default_user, default_organization):
yield tool
@pytest.fixture
def hello_world_message_fixture(server: SyncServer, default_user, sarah_agent):
"""Fixture to create a tool with default settings and clean up after the test."""
# Set up message
message = PydanticMessage(
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
role="user",
text="Hello, world!",
)
msg = server.message_manager.create_message(message, actor=default_user)
yield msg
@pytest.fixture
def sandbox_config_fixture(server: SyncServer, default_user):
sandbox_config_create = SandboxConfigCreate(
@@ -274,6 +293,7 @@ def other_tool(server: SyncServer, default_user, default_organization):
# Yield the created tool
yield tool
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
@@ -548,6 +568,163 @@ def test_delete_tool_by_id(server: SyncServer, print_tool, default_user):
assert len(tools) == 0
# ======================================================================================================================
# Message Manager Tests
# ======================================================================================================================
def test_message_create(server: SyncServer, hello_world_message_fixture, default_user):
    """The fixture-created message exists and round-trips through the manager."""
    created = hello_world_message_fixture
    assert created.id is not None
    assert created.text == "Hello, world!"
    assert created.role == "user"

    # Fetch it back and compare field-by-field
    fetched = server.message_manager.get_message_by_id(created.id, actor=default_user)
    assert fetched is not None
    assert fetched.id == created.id
    assert fetched.text == created.text
    assert fetched.role == created.role
def test_message_get_by_id(server: SyncServer, hello_world_message_fixture, default_user):
    """Fetching by id returns the same message the fixture created."""
    fetched = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user)
    assert fetched is not None
    assert fetched.id == hello_world_message_fixture.id
    assert fetched.text == hello_world_message_fixture.text
def test_message_update(server: SyncServer, hello_world_message_fixture, default_user):
    """An updated text field is persisted and visible on re-read."""
    replacement_text = "Updated text"
    hello_world_message_fixture.text = replacement_text

    updated = server.message_manager.update_message_by_id(
        hello_world_message_fixture.id, hello_world_message_fixture, actor=default_user
    )
    assert updated is not None
    assert updated.text == replacement_text

    # Confirm the change survived a fresh read
    reread = server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user)
    assert reread.text == replacement_text
def test_message_delete(server: SyncServer, hello_world_message_fixture, default_user):
    """After deletion, lookups for the message return None."""
    server.message_manager.delete_message_by_id(hello_world_message_fixture.id, actor=default_user)
    assert server.message_manager.get_message_by_id(hello_world_message_fixture.id, actor=default_user) is None
def test_message_size(server: SyncServer, hello_world_message_fixture, default_user):
    """Test counting messages with role/agent filters."""
    base_message = hello_world_message_fixture

    # Create 4 additional test messages on the same agent, mirroring the base message
    messages = [
        PydanticMessage(
            organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, text=f"Test message {i}"
        )
        for i in range(4)
    ]
    server.message_manager.create_many_messages(messages, actor=default_user)

    # Test total count: 1 agent-creation "login" message + 1 base message + 4 test messages = 6
    total = server.message_manager.size(actor=default_user, role=MessageRole.user)
    assert total == 6  # login message + base message + 4 test messages
    # TODO: change login message to be a system not user message

    # Test count with agent filter — all 6 user messages belong to the same agent
    agent_count = server.message_manager.size(actor=default_user, agent_id=base_message.agent_id, role=MessageRole.user)
    assert agent_count == 6

    # Test count with role filter (base_message.role is "user", so same 6)
    role_count = server.message_manager.size(actor=default_user, role=base_message.role)
    assert role_count == 6

    # Test count with non-existent agent filter — no matches expected
    empty_count = server.message_manager.size(actor=default_user, agent_id="non-existent", role=MessageRole.user)
    assert empty_count == 0
def create_test_messages(server: SyncServer, base_message: PydanticMessage, default_user) -> list[PydanticMessage]:
    """Create and persist four messages modeled on *base_message*; return them."""
    created = []
    for i in range(4):
        created.append(
            PydanticMessage(
                organization_id=default_user.organization_id,
                agent_id=base_message.agent_id,
                role=base_message.role,
                text=f"Test message {i}",
            )
        )
    server.message_manager.create_many_messages(created, actor=default_user)
    return created
def test_message_listing_basic(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
    """Listing with a limit should return exactly that many messages."""
    create_test_messages(server, hello_world_message_fixture, default_user)
    page = server.message_manager.list_user_messages_for_agent(agent_id=sarah_agent.id, limit=3, actor=default_user)
    assert len(page) == 3
def test_message_listing_cursor(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
    """Paginate through messages with a cursor and check pages do not overlap."""
    create_test_messages(server, hello_world_message_fixture, default_user)

    # Six user messages total: login + fixture message + the 4 created above.
    assert server.message_manager.size(actor=default_user, role=MessageRole.user) == 6

    # First page of three.
    page_one = server.message_manager.list_user_messages_for_agent(agent_id=sarah_agent.id, actor=default_user, limit=3)
    assert len(page_one) == 3
    cursor = page_one[-1].id

    # Second page resumes after the cursor and holds the remaining three.
    page_two = server.message_manager.list_user_messages_for_agent(
        agent_id=sarah_agent.id, actor=default_user, cursor=cursor, limit=3
    )
    assert len(page_two) == 3

    # No message may appear on both pages.
    assert all(a.id != b.id for a in page_one for b in page_two)
def test_message_listing_filtering(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
    """Listing scoped to an agent returns only that agent's messages."""
    create_test_messages(server, hello_world_message_fixture, default_user)

    filtered = server.message_manager.list_user_messages_for_agent(agent_id=sarah_agent.id, actor=default_user, limit=10)
    # login message + base message + 4 test messages
    assert len(filtered) == 6
    # Every returned message belongs to the fixture's agent.
    assert all(msg.agent_id == hello_world_message_fixture.agent_id for msg in filtered)
def test_message_listing_text_search(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
    """Full-text query should match the seeded messages and nothing else."""
    create_test_messages(server, hello_world_message_fixture, default_user)

    # A query matching the seeded messages returns all four of them.
    hits = server.message_manager.list_user_messages_for_agent(
        agent_id=sarah_agent.id, actor=default_user, query_text="Test message", limit=10
    )
    assert len(hits) == 4
    assert all("Test message" in msg.text for msg in hits)

    # A query matching no message text returns an empty result.
    misses = server.message_manager.list_user_messages_for_agent(
        agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10
    )
    assert len(misses) == 0
def test_message_listing_date_range_filtering(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent):
    """A date window around now should capture the just-created messages."""
    create_test_messages(server, hello_world_message_fixture, default_user)

    # NOTE(review): utcnow() is naive (and deprecated in 3.12+); assumes the
    # manager stores naive UTC timestamps — confirm before switching to aware datetimes.
    now = datetime.utcnow()
    window = timedelta(minutes=1)
    in_range = server.message_manager.list_user_messages_for_agent(
        agent_id=sarah_agent.id, actor=default_user, start_date=now - window, end_date=now + window, limit=10
    )
    assert len(in_range) > 0
# ======================================================================================================================
# Block Manager Tests
# ======================================================================================================================
@@ -1211,9 +1388,7 @@ def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent,
# Change the tool name
new_name = "banana"
tool = server.tool_manager.update_tool_by_id(
tool_id=print_tool.id, tool_update=ToolUpdate(name=new_name), actor=default_user
)
tool = server.tool_manager.update_tool_by_id(tool_id=print_tool.id, tool_update=ToolUpdate(name=new_name), actor=default_user)
assert tool.name == new_name
# Get the association
@@ -1225,9 +1400,7 @@ def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent,
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user):
with pytest.raises(ForeignKeyConstraintViolationError):
server.tools_agents_manager.add_tool_to_agent(
agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name"
)
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name")
def test_add_tool_to_agent_duplicate_name(server, sarah_agent, default_user, print_tool, other_tool):
@@ -1240,9 +1413,7 @@ def test_add_tool_to_agent_duplicate_name(server, sarah_agent, default_user, pri
def test_remove_tool_with_name_from_agent(server, sarah_agent, default_user, print_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
removed_tool = server.tools_agents_manager.remove_tool_with_name_from_agent(
agent_id=sarah_agent.id, tool_name=print_tool.name
)
removed_tool = server.tools_agents_manager.remove_tool_with_name_from_agent(agent_id=sarah_agent.id, tool_name=print_tool.name)
assert removed_tool.tool_name == print_tool.name
assert removed_tool.tool_id == print_tool.id
@@ -1280,6 +1451,7 @@ def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user,
with pytest.raises(ForeignKeyConstraintViolationError):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
def test_remove_all_agent_tools(server, sarah_agent, default_user, print_tool, other_tool):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
@@ -1290,6 +1462,7 @@ def test_remove_all_agent_tools(server, sarah_agent, default_user, print_tool, o
assert not retrieved_tool_ids
# ======================================================================================================================
# JobManager Tests
# ======================================================================================================================

View File

@@ -1,30 +1,40 @@
import json
import pytest
from letta import BasicBlockMemory
from letta import offline_memory_agent
from letta.client.client import Block, create_client
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.offline_memory_agent import (
rethink_memory,
finish_rethinking_memory,
rethink_memory_convo,
finish_rethinking_memory_convo,
rethink_memory,
rethink_memory_convo,
trigger_rethink_memory,
trigger_rethink_memory_convo,
)
from letta.prompts import gpt_system
from letta.schemas.agent import AgentType
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.tool_rule import TerminalToolRule
from letta.utils import get_human_text, get_persona_text
def test_ripple_edit():
@pytest.fixture(scope="module")
def client():
client = create_client()
assert client is not None
trigger_rethink_memory_tool = client.create_tool(trigger_rethink_memory)
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
@pytest.fixture(autouse=True)
def clear_agents(client):
for agent in client.list_agents():
client.delete_agent(agent.id)
def test_ripple_edit(client, mock_e2b_api_key_none):
trigger_rethink_memory_tool = client.create_or_update_tool(trigger_rethink_memory)
conversation_human_block = Block(name="human", label="human", value=get_human_text(DEFAULT_HUMAN), limit=2000)
conversation_persona_block = Block(name="persona", label="persona", value=get_persona_text(DEFAULT_PERSONA), limit=2000)
@@ -61,12 +71,7 @@ def test_ripple_edit():
)
assert conversation_agent is not None
assert set(conversation_agent.memory.list_block_labels()) == set([
"persona",
"human",
"fact_block",
"rethink_memory_block",
])
assert set(conversation_agent.memory.list_block_labels()) == {"persona", "human", "fact_block", "rethink_memory_block"}
rethink_memory_tool = client.create_tool(rethink_memory)
finish_rethinking_memory_tool = client.create_tool(finish_rethinking_memory)
@@ -82,7 +87,7 @@ def test_ripple_edit():
include_base_tools=False,
)
assert offline_memory_agent is not None
assert set(offline_memory_agent.memory.list_block_labels())== set(["persona", "human", "fact_block", "rethink_memory_block"])
assert set(offline_memory_agent.memory.list_block_labels()) == {"persona", "human", "fact_block", "rethink_memory_block"}
response = client.user_message(
agent_id=conversation_agent.id, message="[trigger_rethink_memory]: Messi has now moved to playing for Inter Miami"
)
@@ -92,12 +97,14 @@ def test_ripple_edit():
conversation_agent = client.get_agent(agent_id=conversation_agent.id)
assert conversation_agent.memory.get_block("rethink_memory_block").value != "[empty]"
# Clean up agent
client.create_agent(conversation_agent.id)
client.delete_agent(offline_memory_agent.id)
def test_chat_only_agent():
client = create_client()
rethink_memory = client.create_tool(rethink_memory_convo)
finish_rethinking_memory = client.create_tool(finish_rethinking_memory_convo)
def test_chat_only_agent(client, mock_e2b_api_key_none):
rethink_memory = client.create_or_update_tool(rethink_memory_convo)
finish_rethinking_memory = client.create_or_update_tool(finish_rethinking_memory_convo)
conversation_human_block = Block(name="chat_agent_human", label="chat_agent_human", value=get_human_text(DEFAULT_HUMAN), limit=2000)
conversation_persona_block = Block(
@@ -114,10 +121,10 @@ def test_chat_only_agent():
tools=["send_message"],
memory=conversation_memory,
include_base_tools=False,
metadata = {"offline_memory_tools": [rethink_memory.name, finish_rethinking_memory.name]}
metadata={"offline_memory_tools": [rethink_memory.name, finish_rethinking_memory.name]},
)
assert chat_only_agent is not None
assert set(chat_only_agent.memory.list_block_labels()) == set(["chat_agent_persona", "chat_agent_human"])
assert set(chat_only_agent.memory.list_block_labels()) == {"chat_agent_persona", "chat_agent_human"}
for message in ["hello", "my name is not chad, my name is swoodily"]:
client.send_message(agent_id=chat_only_agent.id, message=message, role="user")
@@ -125,3 +132,6 @@ def test_chat_only_agent():
chat_only_agent = client.get_agent(agent_id=chat_only_agent.id)
assert chat_only_agent.memory.get_block("chat_agent_human").value != get_human_text(DEFAULT_HUMAN)
# Clean up agent
client.delete_agent(chat_only_agent.id)

View File

@@ -17,7 +17,6 @@ from letta.schemas.letta_message import (
SystemMessage,
UserMessage,
)
from letta.schemas.message import Message
from letta.schemas.user import User
from .test_managers import DEFAULT_EMBEDDING_CONFIG
@@ -91,7 +90,7 @@ def agent_id(server, user_id):
def test_error_on_nonexistent_agent(server, user_id, agent_id):
try:
fake_agent_id = uuid.uuid4()
fake_agent_id = str(uuid.uuid4())
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
raise Exception("user_message call should have failed")
except (KeyError, ValueError) as e:
@@ -388,7 +387,7 @@ def _test_get_messages_letta_format(
agent_id,
reverse=False,
):
"""Reverse is off by default, the GET goes in chronological order"""
"""Test mapping between messages and letta_messages with reverse=False."""
messages = server.get_agent_recall_cursor(
user_id=user_id,
@@ -397,7 +396,6 @@ def _test_get_messages_letta_format(
reverse=reverse,
return_message_object=True,
)
# messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
assert all(isinstance(m, Message) for m in messages)
letta_messages = server.get_agent_recall_cursor(
@@ -407,140 +405,96 @@ def _test_get_messages_letta_format(
reverse=reverse,
return_message_object=False,
)
# letta_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False)
assert all(isinstance(m, LettaMessage) for m in letta_messages)
# Loop through `messages` while also looping through `letta_messages`
# Each message in `messages` should have 1+ corresponding messages in `letta_messages`
# If role of message (in `messages`) is `assistant`,
# then there should be two messages in `letta_messages`, one which is type InternalMonologue and one which is type FunctionCallMessage.
# If role of message (in `messages`) is `user`, then there should be one message in `letta_messages` which is type UserMessage.
# If role of message (in `messages`) is `system`, then there should be one message in `letta_messages` which is type SystemMessage.
# If role of message (in `messages`) is `tool`, then there should be one message in `letta_messages` which is type FunctionReturn.
print("MESSAGES (obj):")
for i, m in enumerate(messages):
# print(m)
print(f"{i}: {m.role}, {m.text[:50]}...")
# print(m.role)
print("MEMGPT_MESSAGES:")
for i, m in enumerate(letta_messages):
print(f"{i}: {type(m)} ...{str(m)[-50:]}")
# Collect system messages and their texts
system_messages = [m for m in messages if m.role == MessageRole.system]
system_texts = [m.text for m in system_messages]
# If there are multiple system messages, print the diff
if len(system_messages) > 1:
print("Differences between system messages:")
for i in range(len(system_texts) - 1):
for j in range(i + 1, len(system_texts)):
import difflib
diff = difflib.unified_diff(
system_texts[i].splitlines(),
system_texts[j].splitlines(),
fromfile=f"System Message {i+1}",
tofile=f"System Message {j+1}",
lineterm="",
)
print("\n".join(diff))
else:
print("There is only one or no system message.")
print(f"Messages: {len(messages)}, LettaMessages: {len(letta_messages)}")
letta_message_index = 0
for i, message in enumerate(messages):
assert isinstance(message, Message)
print(f"\n\nmessage {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
# Defensive bounds check for letta_messages
if letta_message_index >= len(letta_messages):
print(f"Error: letta_message_index out of range. Expected more letta_messages for message {i}: {message.role}")
raise ValueError(f"Mismatch in letta_messages length. Index: {letta_message_index}, Length: {len(letta_messages)}")
print(f"Processing message {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
while letta_message_index < len(letta_messages):
letta_message = letta_messages[letta_message_index]
print(f"letta_message {letta_message_index}: {str(letta_message)[:50]}")
# Validate mappings for assistant role
if message.role == MessageRole.assistant:
print(f"i={i}, M=assistant, MM={type(letta_message)}")
print(f"Assistant Message at {i}: {type(letta_message)}")
# If reverse, function call will come first
if reverse:
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
# Reverse handling: FunctionCallMessages come first
if message.tool_calls:
for tool_call in message.tool_calls:
# Try to parse the tool call args
try:
json.loads(tool_call.function.arguments)
except:
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
except json.JSONDecodeError:
warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}")
assert isinstance(letta_message, FunctionCallMessage)
letta_message_index += 1
if letta_message_index >= len(letta_messages):
break
letta_message = letta_messages[letta_message_index]
if message.text is not None:
if message.text:
assert isinstance(letta_message, InternalMonologue)
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
else:
# If there's no inner thoughts then there needs to be a tool call
assert message.tool_calls is not None
else:
if message.text is not None:
else: # Non-reverse handling
if message.text:
assert isinstance(letta_message, InternalMonologue)
letta_message_index += 1
if letta_message_index >= len(letta_messages):
break
letta_message = letta_messages[letta_message_index]
else:
# If there's no inner thoughts then there needs to be a tool call
assert message.tool_calls is not None
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
if message.tool_calls:
for tool_call in message.tool_calls:
# Try to parse the tool call args
try:
json.loads(tool_call.function.arguments)
except:
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
except json.JSONDecodeError:
warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}")
assert isinstance(letta_message, FunctionCallMessage)
assert tool_call.function.name == letta_message.function_call.name
assert tool_call.function.arguments == letta_message.function_call.arguments
letta_message_index += 1
if letta_message_index >= len(letta_messages):
break
letta_message = letta_messages[letta_message_index]
elif message.role == MessageRole.user:
print(f"i={i}, M=user, MM={type(letta_message)}")
assert isinstance(letta_message, UserMessage)
assert message.text == letta_message.message
letta_message_index += 1
elif message.role == MessageRole.system:
print(f"i={i}, M=system, MM={type(letta_message)}")
assert isinstance(letta_message, SystemMessage)
assert message.text == letta_message.message
letta_message_index += 1
elif message.role == MessageRole.tool:
print(f"i={i}, M=tool, MM={type(letta_message)}")
assert isinstance(letta_message, FunctionReturn)
            # Check that the value in `text` is the same
assert message.text == letta_message.function_return
letta_message_index += 1
else:
raise ValueError(f"Unexpected message role: {message.role}")
# Move to the next message in the original messages list
break
break # Exit the letta_messages loop after processing one mapping
if letta_message_index < len(letta_messages):
warnings.warn(f"Extra letta_messages found: {len(letta_messages) - letta_message_index}")
def test_get_messages_letta_format(server, user_id, agent_id):
for reverse in [False, True]:
# for reverse in [False, True]:
for reverse in [False]:
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
@@ -586,7 +540,7 @@ def ingest(message: str):
'''
def test_tool_run(server, user_id, agent_id):
def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id):
"""Test that the server can run tools"""
result = server.run_tool_from_source(
@@ -672,7 +626,7 @@ def test_composio_client_simple(server):
assert len(actions) > 0
def test_memory_rebuild_count(server, user_id):
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
# create agent
@@ -712,7 +666,6 @@ def test_memory_rebuild_count(server, user_id):
return len(system_messages), letta_messages
try:
# At this stage, there should only be 1 system message inside of recall storage
num_system_messages, all_messages = count_system_messages_in_recall()
# assert num_system_messages == 1, (num_system_messages, all_messages)

View File

@@ -1,318 +0,0 @@
# TODO: add back post DB refactor
# import os
# import uuid
# from datetime import datetime, timedelta
#
# import pytest
# from sqlalchemy.ext.declarative import declarative_base
#
# from letta.agent_store.storage import StorageConnector, TableType
# from letta.config import LettaConfig
# from letta.constants import BASE_TOOLS, MAX_EMBEDDING_DIM
# from letta.credentials import LettaCredentials
# from letta.embeddings import embedding_model, query_embedding
# from letta.metadata import MetadataStore
# from letta.settings import settings
# from tests import TEST_MEMGPT_CONFIG
# from tests.utils import create_config, wipe_config
#
# from .utils import with_qdrant_storage
#
# from letta.schemas.agent import AgentState
# from letta.schemas.message import Message
# from letta.schemas.passage import Passage
# from letta.schemas.user import User
#
#
## Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
# texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
# start_date = datetime(2009, 10, 5, 18, 00)
# dates = [start_date, start_date - timedelta(weeks=1), start_date + timedelta(weeks=1)]
# roles = ["user", "assistant", "assistant"]
# agent_1_id = uuid.uuid4()
# agent_2_id = uuid.uuid4()
# agent_ids = [agent_1_id, agent_2_id, agent_1_id]
# ids = [uuid.uuid4(), uuid.uuid4(), uuid.uuid4()]
# user_id = uuid.uuid4()
#
#
## Data generation functions: Passages
# def generate_passages(embed_model):
# """Generate list of 3 Passage objects"""
# # embeddings: use openai if env is set, otherwise local
# passages = []
# for text, _, _, agent_id, id in zip(texts, dates, roles, agent_ids, ids):
# embedding, embedding_model, embedding_dim = None, None, None
# if embed_model:
# embedding = embed_model.get_text_embedding(text)
# embedding_model = "gpt-4"
# embedding_dim = len(embedding)
# passages.append(
# Passage(
# user_id=user_id,
# text=text,
# agent_id=agent_id,
# embedding=embedding,
# data_source="test_source",
# id=id,
# embedding_dim=embedding_dim,
# embedding_model=embedding_model,
# )
# )
# return passages
#
#
## Data generation functions: Messages
# def generate_messages(embed_model):
# """Generate list of 3 Message objects"""
# messages = []
# for text, date, role, agent_id, id in zip(texts, dates, roles, agent_ids, ids):
# embedding, embedding_model, embedding_dim = None, None, None
# if embed_model:
# embedding = embed_model.get_text_embedding(text)
# embedding_model = "gpt-4"
# embedding_dim = len(embedding)
# messages.append(
# Message(
# user_id=user_id,
# text=text,
# agent_id=agent_id,
# role=role,
# created_at=date,
# id=id,
# model="gpt-4",
# embedding=embedding,
# embedding_model=embedding_model,
# embedding_dim=embedding_dim,
# )
# )
# print(messages[-1].text)
# return messages
#
#
# @pytest.fixture(autouse=True)
# def clear_dynamically_created_models():
# """Wipe globals for SQLAlchemy"""
# yield
# for key in list(globals().keys()):
# if key.endswith("Model"):
# del globals()[key]
#
#
# @pytest.fixture(autouse=True)
# def recreate_declarative_base():
# """Recreate the declarative base before each test"""
# global Base
# Base = declarative_base()
# yield
# Base.metadata.clear()
#
#
# @pytest.mark.parametrize("storage_connector", with_qdrant_storage(["postgres", "chroma", "sqlite", "milvus"]))
## @pytest.mark.parametrize("storage_connector", ["sqlite", "chroma"])
## @pytest.mark.parametrize("storage_connector", ["postgres"])
# @pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
# def test_storage(
# storage_connector,
# table_type,
# clear_dynamically_created_models,
# recreate_declarative_base,
# ):
# # setup letta config
# # TODO: set env for different config path
#
# # hacky way to clean up globals that screw up tests
# # for table_name in ['Message']:
# # if 'Message' in globals():
# # print("Removing messages", globals()['Message'])
# # del globals()['Message']
#
# wipe_config()
# if os.getenv("OPENAI_API_KEY"):
# create_config("openai")
# credentials = LettaCredentials(
# openai_key=os.getenv("OPENAI_API_KEY"),
# )
# else: # hosted
# create_config("letta_hosted")
# LettaCredentials()
#
# config = LettaConfig.load()
# TEST_MEMGPT_CONFIG.default_embedding_config = config.default_embedding_config
# TEST_MEMGPT_CONFIG.default_llm_config = config.default_llm_config
#
# if storage_connector == "postgres":
# TEST_MEMGPT_CONFIG.archival_storage_uri = settings.letta_pg_uri
# TEST_MEMGPT_CONFIG.recall_storage_uri = settings.letta_pg_uri
# TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
# TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
# if storage_connector == "lancedb":
# # TODO: complete lancedb implementation
# if not os.getenv("LANCEDB_TEST_URL"):
# print("Skipping test, missing LanceDB URI")
# return
# TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["LANCEDB_TEST_URL"]
# TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["LANCEDB_TEST_URL"]
# TEST_MEMGPT_CONFIG.archival_storage_type = "lancedb"
# TEST_MEMGPT_CONFIG.recall_storage_type = "lancedb"
# if storage_connector == "chroma":
# if table_type == TableType.RECALL_MEMORY:
# print("Skipping test, chroma only supported for archival memory")
# return
# TEST_MEMGPT_CONFIG.archival_storage_type = "chroma"
# TEST_MEMGPT_CONFIG.archival_storage_path = "./test_chroma"
# if storage_connector == "sqlite":
# if table_type == TableType.ARCHIVAL_MEMORY:
# print("Skipping test, sqlite only supported for recall memory")
# return
# TEST_MEMGPT_CONFIG.recall_storage_type = "sqlite"
# if storage_connector == "qdrant":
# if table_type == TableType.RECALL_MEMORY:
# print("Skipping test, Qdrant only supports archival memory")
# return
# TEST_MEMGPT_CONFIG.archival_storage_type = "qdrant"
# TEST_MEMGPT_CONFIG.archival_storage_uri = "localhost:6333"
# if storage_connector == "milvus":
# if table_type == TableType.RECALL_MEMORY:
# print("Skipping test, Milvus only supports archival memory")
# return
# TEST_MEMGPT_CONFIG.archival_storage_type = "milvus"
# TEST_MEMGPT_CONFIG.archival_storage_uri = "./milvus.db"
# # get embedding model
# embedding_config = TEST_MEMGPT_CONFIG.default_embedding_config
# embed_model = embedding_model(TEST_MEMGPT_CONFIG.default_embedding_config)
#
# # create user
# ms = MetadataStore(TEST_MEMGPT_CONFIG)
# ms.delete_user(user_id)
# user = User(id=user_id)
# agent = AgentState(
# user_id=user_id,
# name="agent_1",
# id=agent_1_id,
# llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
# embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
# system="",
# tools=BASE_TOOLS,
# state={
# "persona": "",
# "human": "",
# "messages": None,
# },
# )
# ms.create_user(user)
# ms.create_agent(agent)
#
# # create storage connector
# conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id)
# # conn.client.delete_collection(conn.collection.name) # clear out data
# conn.delete_table()
# conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id)
#
# # generate data
# if table_type == TableType.ARCHIVAL_MEMORY:
# records = generate_passages(embed_model)
# elif table_type == TableType.RECALL_MEMORY:
# records = generate_messages(embed_model)
# else:
# raise NotImplementedError(f"Table type {table_type} not implemented")
#
# # check record dimensions
# print("TABLE TYPE", table_type, type(records[0]), len(records[0].embedding))
# if embed_model:
# assert len(records[0].embedding) == MAX_EMBEDDING_DIM, f"Expected {MAX_EMBEDDING_DIM}, got {len(records[0].embedding)}"
# assert (
# records[0].embedding_dim == embedding_config.embedding_dim
# ), f"Expected {embedding_config.embedding_dim}, got {records[0].embedding_dim}"
#
# # test: insert
# conn.insert(records[0])
# assert conn.size() == 1, f"Expected 1 record, got {conn.size()}: {conn.get_all()}"
#
# # test: insert_many
# conn.insert_many(records[1:])
# assert (
# conn.size() == 2
# ), f"Expected 2 records, got {conn.size()}: {conn.get_all()}" # expect 2, since storage connector filters for agent1
#
# # test: update
# # NOTE: only testing with messages
# if table_type == TableType.RECALL_MEMORY:
# TEST_STRING = "hello world"
#
# updated_record = records[1]
# updated_record.text = TEST_STRING
#
# current_record = conn.get(id=updated_record.id)
# assert current_record is not None, f"Couldn't find {updated_record.id}"
# assert current_record.text != TEST_STRING, (current_record.text, TEST_STRING)
#
# conn.update(updated_record)
# new_record = conn.get(id=updated_record.id)
# assert new_record is not None, f"Couldn't find {updated_record.id}"
# assert new_record.text == TEST_STRING, (new_record.text, TEST_STRING)
#
# # test: list_loaded_data
# # TODO: add back
# # if table_type == TableType.ARCHIVAL_MEMORY:
# # sources = StorageConnector.list_loaded_data(storage_type=storage_connector)
# # assert len(sources) == 1, f"Expected 1 source, got {len(sources)}"
# # assert sources[0] == "test_source", f"Expected 'test_source', got {sources[0]}"
#
# # test: get_all_paginated
# paginated_total = 0
# for page in conn.get_all_paginated(page_size=1):
# paginated_total += len(page)
# assert paginated_total == 2, f"Expected 2 records, got {paginated_total}"
#
# # test: get_all
# all_records = conn.get_all()
# assert len(all_records) == 2, f"Expected 2 records, got {len(all_records)}"
# all_records = conn.get_all(limit=1)
# assert len(all_records) == 1, f"Expected 1 records, got {len(all_records)}"
#
# # test: get
# print("GET ID", ids[0], records)
# res = conn.get(id=ids[0])
# assert res.text == texts[0], f"Expected {texts[0]}, got {res.text}"
#
# # test: size
# assert conn.size() == 2, f"Expected 2 records, got {conn.size()}"
# assert conn.size(filters={"agent_id": agent.id}) == 2, f"Expected 2 records, got {conn.size(filters={'agent_id', agent.id})}"
# if table_type == TableType.RECALL_MEMORY:
# assert conn.size(filters={"role": "user"}) == 1, f"Expected 1 record, got {conn.size(filters={'role': 'user'})}"
#
# # test: query (vector)
# if table_type == TableType.ARCHIVAL_MEMORY:
# query = "why was she crying"
# query_vec = query_embedding(embed_model, query)
# res = conn.query(None, query_vec, top_k=2)
# assert len(res) == 2, f"Expected 2 results, got {len(res)}"
# print("Archival memory results", res)
# assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
#
# # test optional query functions for recall memory
# if table_type == TableType.RECALL_MEMORY:
# # test: query_text
# query = "CindereLLa"
# res = conn.query_text(query)
# assert len(res) == 1, f"Expected 1 result, got {len(res)}"
# assert "Cinderella" in res[0].text, f"Expected 'Cinderella' in results, but got {res[0].text}"
#
# # test: query_date (recall memory only)
# print("Testing recall memory date search")
# start_date = datetime(2009, 10, 5, 18, 00)
# start_date = start_date - timedelta(days=1)
# end_date = start_date + timedelta(days=1)
# res = conn.query_date(start_date=start_date, end_date=end_date)
# print("DATE", res)
# assert len(res) == 1, f"Expected 1 result, got {len(res)}: {res}"
#
# # test: delete
# conn.delete({"id": ids[0]})
# assert conn.size() == 1, f"Expected 2 records, got {conn.size()}"
#
# # cleanup
# ms.delete_user(user_id)
#

View File

@@ -1,14 +1,11 @@
import uuid
from typing import List
import pytest
from letta import create_client
from letta.client.client import LocalClient
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.settings import tool_settings
from .utils import wipe_config
@@ -21,21 +18,6 @@ agent_obj = None
# TODO: these tests should add function calls into the summarized message sequence:W
@pytest.fixture
def mock_e2b_api_key_none():
# Store the original value of e2b_api_key
original_api_key = tool_settings.e2b_api_key
# Set e2b_api_key to None
tool_settings.e2b_api_key = None
# Yield control to the test
yield
# Restore the original value of e2b_api_key
tool_settings.e2b_api_key = original_api_key
def create_test_agent():
"""Create a test agent that we can call functions on"""
wipe_config()