From e7a56328cf6f06eeb881ad75bb65850ef30db128 Mon Sep 17 00:00:00 2001 From: mlong93 <35275280+mlong93@users.noreply.github.com> Date: Thu, 12 Dec 2024 14:29:22 -0800 Subject: [PATCH] chore: raising errors for passages, adding new passage test (#2234) Co-authored-by: Mindy Long --- letta/orm/passage.py | 7 ++- letta/server/server.py | 11 +---- letta/services/passage_manager.py | 47 ++++++++----------- tests/test_managers.py | 5 +-- tests/test_server.py | 75 ++++++++++++++++--------------- 5 files changed, 66 insertions(+), 79 deletions(-) diff --git a/letta/orm/passage.py b/letta/orm/passage.py index 847c8ddd..bfa3e153 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -1,6 +1,6 @@ from datetime import datetime -from typing import List, Optional, TYPE_CHECKING -from sqlalchemy import Column, String, DateTime, Index, JSON, UniqueConstraint, ForeignKey +from typing import Optional, TYPE_CHECKING +from sqlalchemy import Column, String, DateTime, JSON, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.types import TypeDecorator, BINARY @@ -9,7 +9,7 @@ import base64 from letta.orm.source import EmbeddingConfigColumn from letta.orm.sqlalchemy_base import SqlalchemyBase -from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin +from letta.orm.mixins import FileMixin, OrganizationMixin from letta.schemas.passage import Passage as PydanticPassage from letta.config import LettaConfig @@ -19,7 +19,6 @@ from letta.settings import settings config = LettaConfig() if TYPE_CHECKING: - from letta.orm.file import File from letta.orm.organization import Organization class CommonVector(TypeDecorator): diff --git a/letta/server/server.py b/letta/server/server.py index 80a1335a..2765b269 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1345,17 +1345,10 @@ class SyncServer(Server): letta_agent = self.load_agent(agent_id=agent_id) # Insert into archival memory - passage_ids = self.passage_manager.insert_passage( - agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor, return_ids=True + return self.passage_manager.insert_passage( + agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor ) - # Update the agent - # TODO: should this update the system prompt? - save_agent(letta_agent, self.ms) - - # TODO: this is gross, fix - return [self.passage_manager.get_passage_by_id(passage_id=passage_id, actor=actor) for passage_id in passage_ids] - def delete_archival_memory(self, user_id: str, agent_id: str, memory_id: str): actor = self.user_manager.get_user_by_id(user_id=user_id) if actor is None: diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index c1933b39..ef93b732 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -26,11 +26,8 @@ class PassageManager: def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: """Fetch a passage by ID.""" with self.session_maker() as session: - try: - passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor) - return passage.to_pydantic() - except NoResultFound: - return None + passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor) + return passage.to_pydantic() @enforce_types def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: @@ -83,11 +80,6 @@ class PassageManager: actor=actor ) passages.append(passage) - - ids = [str(p.id) for p in passages] - - if return_ids: - return ids return passages @@ -101,26 +93,23 @@ class PassageManager: raise ValueError("Passage ID must be provided.") with self.session_maker() as session: - try: - # Fetch existing message from database - curr_passage = PassageModel.read( - db_session=session, - identifier=passage_id, - actor=actor, - ) - if not curr_passage: - raise ValueError(f"Passage with id {passage_id} does not exist.") + # Fetch existing message from database + curr_passage = PassageModel.read( + db_session=session, + identifier=passage_id, + actor=actor, + ) + if not curr_passage: + raise ValueError(f"Passage with id {passage_id} does not exist.") - # Update the database record with values from the provided record - update_data = passage.model_dump(exclude_unset=True, exclude_none=True) - for key, value in update_data.items(): - setattr(curr_passage, key, value) + # Update the database record with values from the provided record + update_data = passage.model_dump(exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(curr_passage, key, value) - # Commit changes - curr_passage.update(session, actor=actor) - return curr_passage.to_pydantic() - except NoResultFound: - return None + # Commit changes + curr_passage.update(session, actor=actor) + return curr_passage.to_pydantic() @enforce_types def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool: @@ -145,6 +134,7 @@ class PassageManager: query_text : Optional[str] = None, start_date : Optional[datetime] = None, end_date : Optional[datetime] = None, + ascending : bool = True, source_id : Optional[str] = None, embed_query : bool = False, embedding_config: Optional[EmbeddingConfig] = None @@ -176,6 +166,7 @@ class PassageManager: start_date=start_date, end_date=end_date, limit=limit, + ascending=ascending, query_text=query_text if not embedded_text else None, query_embedding=embedded_text, **filters diff --git a/tests/test_managers.py b/tests/test_managers.py index f6e54366..745b17d7 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -9,7 +9,6 @@ from letta.embeddings import embedding_model import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.metadata import AgentModel -from letta.orm.sqlite_functions import verify_embedding_dimension, convert_array from letta.orm import ( Block, BlocksAgents, @@ -445,8 +444,8 @@ def test_passage_update(server: SyncServer, hello_world_passage_fixture, default def test_passage_delete(server: SyncServer, hello_world_passage_fixture, default_user): """Test deleting a passage""" server.passage_manager.delete_passage_by_id(hello_world_passage_fixture.id, actor=default_user) - retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user) - assert retrieved is None + with pytest.raises(NoResultFound): + server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user) def test_passage_size(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user): diff --git a/tests/test_server.py b/tests/test_server.py index ee745c5c..38404afc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -433,42 +433,47 @@ def test_get_recall_memory(server, org_id, user_id, agent_id): assert message_id in message_ids, f"{message_id} not in {message_ids}" -# TODO: Out-of-date test. pagination commands are off -# @pytest.mark.order(6) -# def test_get_archival_memory(server, user_id, agent_id): -# # test archival memory cursor pagination -# passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text") -# assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2" -# cursor1 = passages_1[-1].id -# passages_2 = server.get_agent_archival_cursor( -# user_id=user_id, -# agent_id=agent_id, -# reverse=False, -# after=cursor1, -# order_by="text", -# ) -# cursor2 = passages_2[-1].id -# passages_3 = server.get_agent_archival_cursor( -# user_id=user_id, -# agent_id=agent_id, -# reverse=False, -# before=cursor2, -# limit=1000, -# order_by="text", -# ) -# passages_3[-1].id -# # assert passages_1[0].text == "Cinderella wore a blue dress" -# assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test -# assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test +@pytest.mark.order(6) +def test_get_archival_memory(server, user_id, agent_id): + # test archival memory cursor pagination + user = server.user_manager.get_user_by_id(user_id=user_id) + + # List latest 2 passages + passages_1 = server.passage_manager.list_passages( + actor=user, agent_id=agent_id, ascending=False, limit=2, + ) + assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2" -# # test archival memory -# passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=1) -# assert len(passage_1) == 1 -# passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1, count=1000) -# assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test -# # test safe empty return -# passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000) -# assert len(passage_none) == 0 + # List next 3 passages (earliest 3) + cursor1 = passages_1[-1].id + passages_2 = server.passage_manager.list_passages( + actor=user, + agent_id=agent_id, + ascending=False, + cursor=cursor1, + ) + + # List all 5 + cursor2 = passages_1[0].created_at + passages_3 = server.passage_manager.list_passages( + actor=user, + agent_id=agent_id, + ascending=False, + end_date=cursor2, + limit=1000, + ) + # assert passages_1[0].text == "Cinderella wore a blue dress" + assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test + assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test + + # test archival memory + passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, limit=1) + assert len(passage_1) == 1 + passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passage_1[-1].id, limit=1000) + assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test + # test safe empty return + passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passages_1[0].id, limit=1000) + assert len(passage_none) == 0 def test_agent_rethink_rewrite_retry(server, user_id, agent_id):