From e40e60945a1f2ca3a678c711a22623be9ab0c76f Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 12 Nov 2024 09:57:40 -0800 Subject: [PATCH] feat: Move Source to ORM model (#1979) --- .../cda66b6cb0d6_move_sources_to_orm.py | 64 +++++++ letta/agent.py | 6 +- letta/client/client.py | 29 +-- letta/data_sources/connectors.py | 6 +- letta/llm_api/google_ai.py | 2 - letta/metadata.py | 95 +--------- letta/orm/__init__.py | 1 + letta/orm/organization.py | 4 +- letta/orm/source.py | 50 +++++ letta/providers.py | 1 - letta/schemas/source.py | 58 +++--- letta/server/rest_api/routers/v1/sources.py | 24 ++- letta/server/server.py | 83 +++----- letta/services/organization_manager.py | 25 ++- letta/services/source_manager.py | 100 ++++++++++ tests/test_client.py | 12 ++ tests/test_managers.py | 177 ++++++++++++++++-- tests/test_server.py | 8 +- 18 files changed, 509 insertions(+), 236 deletions(-) create mode 100644 alembic/versions/cda66b6cb0d6_move_sources_to_orm.py create mode 100644 letta/orm/source.py create mode 100644 letta/services/source_manager.py diff --git a/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py b/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py new file mode 100644 index 00000000..f46bef6b --- /dev/null +++ b/alembic/versions/cda66b6cb0d6_move_sources_to_orm.py @@ -0,0 +1,64 @@ +"""Move sources to orm + +Revision ID: cda66b6cb0d6 +Revises: b6d7ca024aa9 +Create Date: 2024-11-07 13:29:57.186107 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "cda66b6cb0d6" +down_revision: Union[str, None] = "b6d7ca024aa9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("sources", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True)) + op.add_column("sources", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False)) + op.add_column("sources", sa.Column("_created_by_id", sa.String(), nullable=True)) + op.add_column("sources", sa.Column("_last_updated_by_id", sa.String(), nullable=True)) + + # Data migration step: + op.add_column("sources", sa.Column("organization_id", sa.String(), nullable=True)) + # Populate `organization_id` based on `user_id` + # Use a raw SQL query to update the organization_id + op.execute( + """ + UPDATE sources + SET organization_id = users.organization_id + FROM users + WHERE sources.user_id = users.id + """ + ) + + # Set `organization_id` as non-nullable after population + op.alter_column("sources", "organization_id", nullable=False) + + op.alter_column("sources", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False) + op.drop_index("sources_idx_user", table_name="sources") + op.create_foreign_key(None, "sources", "organizations", ["organization_id"], ["id"]) + op.drop_column("sources", "user_id") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("sources", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False)) + op.drop_constraint(None, "sources", type_="foreignkey") + op.create_index("sources_idx_user", "sources", ["user_id"], unique=False) + op.alter_column("sources", "embedding_config", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True) + op.drop_column("sources", "organization_id") + op.drop_column("sources", "_last_updated_by_id") + op.drop_column("sources", "_created_by_id") + op.drop_column("sources", "is_deleted") + op.drop_column("sources", "updated_at") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index e1939ffb..a3ddfb0f 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -46,6 +46,8 @@ from letta.schemas.passage import Passage from letta.schemas.tool import Tool from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.usage import LettaUsageStatistics +from letta.services.source_manager import SourceManager +from letta.services.user_manager import UserManager from letta.system import ( get_heartbeat, get_initial_boot_messages, @@ -1311,7 +1313,7 @@ class Agent(BaseAgent): def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore): """Attach data with name `source_name` to the agent from source_connector.""" # TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory - + user = UserManager().get_user_by_id(self.agent_state.user_id) filters = {"user_id": self.agent_state.user_id, "source_id": source_id} size = source_connector.size(filters) page_size = 100 @@ -1339,7 +1341,7 @@ class Agent(BaseAgent): self.persistence_manager.archival_memory.storage.save() # attach to agent - source = ms.get_source(source_id=source_id) + source = SourceManager().get_source_by_id(source_id=source_id, actor=user) assert source is not None, f"Source {source_id} not found in metadata store" 
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id) diff --git a/letta/client/client.py b/letta/client/client.py index c69c5478..519902ba 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -238,7 +238,7 @@ class AbstractClient(object): def delete_file_from_source(self, source_id: str, file_id: str) -> None: raise NotImplementedError - def create_source(self, name: str) -> Source: + def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source: raise NotImplementedError def delete_source(self, source_id: str): @@ -1188,7 +1188,7 @@ class RESTClient(AbstractClient): if response.status_code not in [200, 204]: raise ValueError(f"Failed to delete tool: {response.text}") - def create_source(self, name: str) -> Source: + def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source: """ Create a source @@ -1198,7 +1198,8 @@ class RESTClient(AbstractClient): Returns: source (Source): Created source """ - payload = {"name": name} + source_create = SourceCreate(name=name, embedding_config=embedding_config or self._default_embedding_config) + payload = source_create.model_dump() response = requests.post(f"{self.base_url}/{self.api_prefix}/sources", json=payload, headers=self.headers) response_json = response.json() return Source(**response_json) @@ -1253,7 +1254,7 @@ class RESTClient(AbstractClient): Returns: source (Source): Updated source """ - request = SourceUpdate(id=source_id, name=name) + request = SourceUpdate(name=name) response = requests.patch(f"{self.base_url}/{self.api_prefix}/sources/{source_id}", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to update source: {response.text}") @@ -2453,7 +2454,7 @@ class LocalClient(AbstractClient): def list_active_jobs(self): return self.server.list_active_jobs(user_id=self.user_id) - def create_source(self, name: str) -> Source: + 
def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source: """ Create a source @@ -2463,8 +2464,10 @@ class LocalClient(AbstractClient): Returns: source (Source): Created source """ - request = SourceCreate(name=name) - return self.server.create_source(request=request, user_id=self.user_id) + source = Source( + name=name, embedding_config=embedding_config or self._default_embedding_config, organization_id=self.user.organization_id + ) + return self.server.source_manager.create_source(source=source, actor=self.user) def delete_source(self, source_id: str): """ @@ -2475,7 +2478,7 @@ class LocalClient(AbstractClient): """ # TODO: delete source data - self.server.delete_source(source_id=source_id, user_id=self.user_id) + self.server.delete_source(source_id=source_id, actor=self.user) def get_source(self, source_id: str) -> Source: """ @@ -2487,7 +2490,7 @@ class LocalClient(AbstractClient): Returns: source (Source): Source """ - return self.server.get_source(source_id=source_id, user_id=self.user_id) + return self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user) def get_source_id(self, source_name: str) -> str: """ @@ -2499,7 +2502,7 @@ class LocalClient(AbstractClient): Returns: source_id (str): ID of the source """ - return self.server.get_source_id(source_name=source_name, user_id=self.user_id) + return self.server.source_manager.get_source_by_name(source_name=source_name, actor=self.user).id def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None): """ @@ -2532,7 +2535,7 @@ class LocalClient(AbstractClient): sources (List[Source]): List of sources """ - return self.server.list_all_sources(user_id=self.user_id) + return self.server.list_all_sources(actor=self.user) def list_attached_sources(self, agent_id: str) -> List[Source]: """ @@ -2572,8 +2575,8 @@ class LocalClient(AbstractClient): source (Source): Updated source """ # TODO should the 
arg here just be "source_update: Source"? - request = SourceUpdate(id=source_id, name=name) - return self.server.update_source(request=request, user_id=self.user_id) + request = SourceUpdate(name=name) + return self.server.source_manager.update_source(source_id=source_id, source_update=request, actor=self.user) # archival memory diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index f729c8ad..f9fb3d2a 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -47,7 +47,7 @@ def load_data( passage_store: StorageConnector, file_metadata_store: StorageConnector, ): - """Load data from a connector (generates file and passages) into a specified source_id, associatedw with a user_id.""" + """Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id.""" embedding_config = source.embedding_config # embedding model @@ -88,7 +88,7 @@ def load_data( file_id=file_metadata.id, source_id=source.id, metadata_=passage_metadata, - user_id=source.user_id, + user_id=source.created_by_id, embedding_config=source.embedding_config, embedding=embedding, ) @@ -155,7 +155,7 @@ class DirectoryConnector(DataConnector): for metadata in extract_metadata_from_files(files): yield FileMetadata( - user_id=source.user_id, + user_id=source.created_by_id, source_id=source.id, file_name=metadata.get("file_name"), file_path=metadata.get("file_path"), diff --git a/letta/llm_api/google_ai.py b/letta/llm_api/google_ai.py index 5d4e1798..57071a23 100644 --- a/letta/llm_api/google_ai.py +++ b/letta/llm_api/google_ai.py @@ -95,10 +95,8 @@ def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = try: response = requests.get(url, headers=headers) - printd(f"response = {response}") response.raise_for_status() # Raises HTTPError for 4XX/5XX status response = response.json() # convert to dict from string - printd(f"response.json = {response}") # Grab the models out model_list = 
response["models"] diff --git a/letta/metadata.py b/letta/metadata.py index dddced11..9ffc81c6 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -29,7 +29,6 @@ from letta.schemas.job import Job from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction -from letta.schemas.source import Source from letta.schemas.tool_rule import ( BaseToolRule, InitToolRule, @@ -292,40 +291,6 @@ class AgentModel(Base): return agent_state -class SourceModel(Base): - """Defines data model for storing Passages (consisting of text, embedding)""" - - __tablename__ = "sources" - __table_args__ = {"extend_existing": True} - - # Assuming passage_id is the primary key - # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - id = Column(String, primary_key=True) - user_id = Column(String, nullable=False) - name = Column(String, nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - embedding_config = Column(EmbeddingConfigColumn) - description = Column(String) - metadata_ = Column(JSON) - Index(__tablename__ + "_idx_user", user_id), - - # TODO: add num passages - - def __repr__(self) -> str: - return f"" - - def to_record(self) -> Source: - return Source( - id=self.id, - user_id=self.user_id, - name=self.name, - created_at=self.created_at, - embedding_config=self.embedding_config, - description=self.description, - metadata_=self.metadata_, - ) - - class AgentSourceMappingModel(Base): """Stores mapping between agent -> source""" @@ -497,14 +462,6 @@ class MetadataStore: session.add(AgentModel(**fields)) session.commit() - @enforce_types - def create_source(self, source: Source): - with self.session_maker() as session: - if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0: - raise ValueError(f"Source with name {source.name} already exists for user 
{source.user_id}") - session.add(SourceModel(**vars(source))) - session.commit() - @enforce_types def create_block(self, block: Block): with self.session_maker() as session: @@ -522,6 +479,7 @@ class MetadataStore: ): raise ValueError(f"Block with name {block.template_name} already exists") + session.add(BlockModel(**vars(block))) session.commit() @@ -536,12 +494,6 @@ class MetadataStore: session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields) session.commit() - @enforce_types - def update_source(self, source: Source): - with self.session_maker() as session: - session.query(SourceModel).filter(SourceModel.id == source.id).update(vars(source)) - session.commit() - @enforce_types def update_block(self, block: Block): with self.session_maker() as session: @@ -591,29 +543,12 @@ class MetadataStore: session.commit() - @enforce_types - def delete_source(self, source_id: str): - with self.session_maker() as session: - # delete from sources table - session.query(SourceModel).filter(SourceModel.id == source_id).delete() - - # delete any mappings - session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).delete() - - session.commit() - @enforce_types def list_agents(self, user_id: str) -> List[AgentState]: with self.session_maker() as session: results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all() return [r.to_record() for r in results] - @enforce_types - def list_sources(self, user_id: str) -> List[Source]: - with self.session_maker() as session: - results = session.query(SourceModel).filter(SourceModel.user_id == user_id).all() - return [r.to_record() for r in results] - @enforce_types def get_agent( self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None @@ -630,21 +565,6 @@ class MetadataStore: assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result return results[0].to_record() - @enforce_types - def 
get_source( - self, source_id: Optional[str] = None, user_id: Optional[str] = None, source_name: Optional[str] = None - ) -> Optional[Source]: - with self.session_maker() as session: - if source_id: - results = session.query(SourceModel).filter(SourceModel.id == source_id).all() - else: - assert user_id is not None and source_name is not None - results = session.query(SourceModel).filter(SourceModel.name == source_name).filter(SourceModel.user_id == user_id).all() - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" - return results[0].to_record() - @enforce_types def get_block(self, block_id: str) -> Optional[Block]: with self.session_maker() as session: @@ -699,19 +619,10 @@ class MetadataStore: session.commit() @enforce_types - def list_attached_sources(self, agent_id: str) -> List[Source]: + def list_attached_source_ids(self, agent_id: str) -> List[str]: with self.session_maker() as session: results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all() - - sources = [] - # make sure source exists - for r in results: - source = self.get_source(source_id=r.source_id) - if source: - sources.append(source) - else: - printd(f"Warning: source {r.source_id} does not exist but exists in mapping database. 
This should never happen.") - return sources + return [r.source_id for r in results] @enforce_types def list_attached_agents(self, source_id: str) -> List[str]: diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index c95da85b..b69737ac 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -1,4 +1,5 @@ from letta.orm.base import Base from letta.orm.organization import Organization +from letta.orm.source import Source from letta.orm.tool import Tool from letta.orm.user import User diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 4b641607..a6b05ee6 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -21,13 +21,13 @@ class Organization(SqlalchemyBase): id: Mapped[str] = mapped_column(String, primary_key=True) name: Mapped[str] = mapped_column(doc="The display name of the organization.") + # relationships users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan") tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") + sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan") - # TODO: Map these relationships later when we actually make these models # below is just a suggestion # agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") - # sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan") # tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") # documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan") diff --git a/letta/orm/source.py b/letta/orm/source.py new file 
mode 100644 index 00000000..e8a7ed47 --- /dev/null +++ b/letta/orm/source.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING, Optional + +from sqlalchemy import JSON, TypeDecorator +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.source import Source as PydanticSource + +if TYPE_CHECKING: + from letta.orm.organization import Organization + + +class EmbeddingConfigColumn(TypeDecorator): + """Custom type for storing EmbeddingConfig as JSON""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + if value: + # return vars(value) + if isinstance(value, EmbeddingConfig): + return value.model_dump() + return value + + def process_result_value(self, value, dialect): + if value: + return EmbeddingConfig(**value) + return value + + +class Source(SqlalchemyBase, OrganizationMixin): + """A source represents an embedded text passage""" + + __tablename__ = "sources" + __pydantic_model__ = PydanticSource + + name: Mapped[str] = mapped_column(doc="the name of the source, must be unique within the org", nullable=False) + description: Mapped[str] = mapped_column(nullable=True, doc="a human-readable description of the source") + embedding_config: Mapped[EmbeddingConfig] = mapped_column(EmbeddingConfigColumn, doc="Configuration settings for embedding.") + metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the source.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="sources") + # agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources") diff --git a/letta/providers.py b/letta/providers.py index 6fa98327..63bbe475 100644 --- 
a/letta/providers.py +++ b/letta/providers.py @@ -462,7 +462,6 @@ class VLLMChatCompletionsProvider(Provider): response = openai_get_model_list(self.base_url, api_key=None) configs = [] - print(response) for model in response["data"]: configs.append( LLMConfig( diff --git a/letta/schemas/source.py b/letta/schemas/source.py index 8f816ad7..0a458dfd 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -1,12 +1,10 @@ from datetime import datetime from typing import Optional -from fastapi import UploadFile -from pydantic import BaseModel, Field +from pydantic import Field from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_base import LettaBase -from letta.utils import get_utc_time class BaseSource(LettaBase): @@ -15,15 +13,6 @@ class BaseSource(LettaBase): """ __id_prefix__ = "source" - description: Optional[str] = Field(None, description="The description of the source.") - embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.") - # NOTE: .metadata is a reserved attribute on SQLModel - metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") - - -class SourceCreate(BaseSource): - name: str = Field(..., description="The name of the source.") - description: Optional[str] = Field(None, description="The description of the source.") class Source(BaseSource): @@ -34,7 +23,6 @@ class Source(BaseSource): id (str): The ID of the source name (str): The name of the source. embedding_config (EmbeddingConfig): The embedding configuration used by the source. - created_at (datetime): The creation date of the source. user_id (str): The ID of the user that created the source. metadata_ (dict): Metadata associated with the source. description (str): The description of the source. 
@@ -42,21 +30,39 @@ class Source(BaseSource): id: str = BaseSource.generate_id_field() name: str = Field(..., description="The name of the source.") + description: Optional[str] = Field(None, description="The description of the source.") embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.") - created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the source.") - user_id: str = Field(..., description="The ID of the user that created the source.") + organization_id: Optional[str] = Field(None, description="The ID of the organization that created the source.") + metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") + + # metadata fields + created_by_id: Optional[str] = Field(None, description="The id of the user that made this Source.") + last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Source.") + created_at: Optional[datetime] = Field(None, description="The timestamp when the source was created.") + updated_at: Optional[datetime] = Field(None, description="The timestamp when the source was last updated.") + + +class SourceCreate(BaseSource): + """ + Schema for creating a new Source. + """ + + # required + name: str = Field(..., description="The name of the source.") + # TODO: @matt, make this required after shub makes the FE changes + embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.") + + # optional + description: Optional[str] = Field(None, description="The description of the source.") + metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") class SourceUpdate(BaseSource): - id: str = Field(..., description="The ID of the source.") + """ + Schema for updating an existing Source. 
+ """ + name: Optional[str] = Field(None, description="The name of the source.") - - -class UploadFileToSourceRequest(BaseModel): - file: UploadFile = Field(..., description="The file to upload.") - - -class UploadFileToSourceResponse(BaseModel): - source: Source = Field(..., description="The source the file was uploaded to.") - added_passages: int = Field(..., description="The number of passages added to the source.") - added_documents: int = Field(..., description="The number of files added to the source.") + description: Optional[str] = Field(None, description="The description of the source.") + metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.") + embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.") diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 388fa3e0..58047f12 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -36,7 +36,7 @@ def get_source( """ actor = server.get_user_or_default(user_id=user_id) - return server.get_source(source_id=source_id, user_id=actor.id) + return server.source_manager.get_source_by_id(source_id=source_id, actor=actor) @router.get("/name/{source_name}", response_model=str, operation_id="get_source_id_by_name") @@ -50,8 +50,8 @@ def get_source_id_by_name( """ actor = server.get_user_or_default(user_id=user_id) - source_id = server.get_source_id(source_name=source_name, user_id=actor.id) - return source_id + source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor) + return source.id @router.get("/", response_model=List[Source], operation_id="list_sources") @@ -64,12 +64,12 @@ def list_sources( """ actor = server.get_user_or_default(user_id=user_id) - return server.list_all_sources(user_id=actor.id) + return server.list_all_sources(actor=actor) @router.post("/", 
response_model=Source, operation_id="create_source") def create_source( - source: SourceCreate, + source_create: SourceCreate, server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): @@ -77,8 +77,9 @@ def create_source( Create a new data source. """ actor = server.get_user_or_default(user_id=user_id) + source = Source(**source_create.model_dump()) - return server.create_source(request=source, user_id=actor.id) + return server.source_manager.create_source(source=source, actor=actor) @router.patch("/{source_id}", response_model=Source, operation_id="update_source") @@ -92,10 +93,7 @@ def update_source( Update the name or documentation of an existing data source. """ actor = server.get_user_or_default(user_id=user_id) - - assert source.id == source_id, "Source ID in path must match ID in request body" - - return server.update_source(request=source, user_id=actor.id) + return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor) @router.delete("/{source_id}", response_model=None, operation_id="delete_source") @@ -109,7 +107,7 @@ def delete_source( """ actor = server.get_user_or_default(user_id=user_id) - server.delete_source(source_id=source_id, user_id=actor.id) + server.delete_source(source_id=source_id, actor=actor) @router.post("/{source_id}/attach", response_model=Source, operation_id="attach_agent_to_source") @@ -124,7 +122,7 @@ def attach_source_to_agent( """ actor = server.get_user_or_default(user_id=user_id) - source = server.ms.get_source(source_id=source_id, user_id=actor.id) + source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) assert source is not None, f"Source with id={source_id} not found." 
source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=actor.id) return source @@ -158,7 +156,7 @@ def upload_file_to_source( """ actor = server.get_user_or_default(user_id=user_id) - source = server.ms.get_source(source_id=source_id, user_id=actor.id) + source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) assert source is not None, f"Source with id={source_id} not found." bytes = file.file.read() diff --git a/letta/server/server.py b/letta/server/server.py index a5c86d38..023fffad 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -78,12 +78,13 @@ from letta.schemas.memory import ( from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage from letta.schemas.organization import Organization from letta.schemas.passage import Passage -from letta.schemas.source import Source, SourceCreate, SourceUpdate +from letta.schemas.source import Source from letta.schemas.tool import Tool, ToolCreate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.services.agents_tags_manager import AgentsTagsManager from letta.services.organization_manager import OrganizationManager +from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.utils import create_random_username, json_dumps, json_loads @@ -249,6 +250,7 @@ class SyncServer(Server): self.organization_manager = OrganizationManager() self.user_manager = UserManager() self.tool_manager = ToolManager() + self.source_manager = SourceManager() self.agents_tags_manager = AgentsTagsManager() # Make default user and org @@ -1560,44 +1562,12 @@ class SyncServer(Server): self.ms.delete_api_key(api_key=api_key) return api_key_obj - def create_source(self, request: SourceCreate, user_id: str) -> Source: # TODO: add other fields - """Create a new data source""" - source = Source( - 
name=request.name, - user_id=user_id, - embedding_config=self.list_embedding_models()[0], # TODO: require providing this - ) - self.ms.create_source(source) - assert self.ms.get_source(source_name=request.name, user_id=user_id) is not None, f"Failed to create source {request.name}" - return source - - def update_source(self, request: SourceUpdate, user_id: str) -> Source: - """Update an existing data source""" - if not request.id: - existing_source = self.ms.get_source(source_name=request.name, user_id=user_id) - else: - existing_source = self.ms.get_source(source_id=request.id) - if not existing_source: - raise ValueError("Source does not exist") - - # override updated fields - if request.name: - existing_source.name = request.name - if request.metadata_: - existing_source.metadata_ = request.metadata_ - if request.description: - existing_source.description = request.description - - self.ms.update_source(existing_source) - return existing_source - - def delete_source(self, source_id: str, user_id: str): + def delete_source(self, source_id: str, actor: User): """Delete a data source""" - source = self.ms.get_source(source_id=source_id, user_id=user_id) - self.ms.delete_source(source_id) + self.source_manager.delete_source(source_id=source_id, actor=actor) # delete data from passage store - passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) + passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=actor.id) passage_store.delete({"source_id": source_id}) # TODO: delete data from agent passage stores (?) 
@@ -1639,9 +1609,9 @@ class SyncServer(Server): # try: from letta.data_sources.connectors import DirectoryConnector - source = self.ms.get_source(source_id=source_id) + source = self.source_manager.get_source_by_id(source_id=source_id) connector = DirectoryConnector(input_files=[file_path]) - num_passages, num_documents = self.load_data(user_id=source.user_id, source_name=source.name, connector=connector) + num_passages, num_documents = self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector) # except Exception as e: # # job failed with error # error = str(e) @@ -1675,7 +1645,8 @@ class SyncServer(Server): # TODO: this should be implemented as a batch job or at least async, since it may take a long time # load data from a data source into the document store - source = self.ms.get_source(source_name=source_name, user_id=user_id) + user = self.user_manager.get_user_by_id(user_id=user_id) + source = self.source_manager.get_source_by_name(source_name=source_name, actor=user) if source is None: raise ValueError(f"Data source {source_name} does not exist for user {user_id}") @@ -1696,9 +1667,13 @@ class SyncServer(Server): source_name: Optional[str] = None, ) -> Source: # attach a data source to an agent - data_source = self.ms.get_source(source_id=source_id, user_id=user_id, source_name=source_name) - if data_source is None: - raise ValueError(f"Data source id={source_id} name={source_name} does not exist for user_id {user_id}") + user = self.user_manager.get_user_by_id(user_id=user_id) + if source_id: + data_source = self.source_manager.get_source_by_id(source_id=source_id, actor=user) + elif source_name: + data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=user) + else: + raise ValueError(f"Need to provide at least source_id or source_name to find the source.") # get connection to data source storage source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, 
user_id=user_id) @@ -1719,12 +1694,14 @@ class SyncServer(Server): source_id: Optional[str] = None, source_name: Optional[str] = None, ) -> Source: - if not source_id: - assert source_name is not None, "source_name must be provided if source_id is not" - source = self.ms.get_source(source_name=source_name, user_id=user_id) - source_id = source.id + user = self.user_manager.get_user_by_id(user_id=user_id) + if source_id: + source = self.source_manager.get_source_by_id(source_id=source_id, actor=user) + elif source_name: + source = self.source_manager.get_source_by_name(source_name=source_name, actor=user) else: - source = self.ms.get_source(source_id=source_id) + raise ValueError(f"Need to provide at least source_id or source_name to find the source.") + source_id = source.id # delete all Passage objects with source_id==source_id from agent's archival memory agent = self._get_or_load_agent(agent_id=agent_id) @@ -1739,7 +1716,9 @@ class SyncServer(Server): def list_attached_sources(self, agent_id: str) -> List[Source]: # list all attached sources to an agent - return self.ms.list_attached_sources(agent_id) + source_ids = self.ms.list_attached_source_ids(agent_id) + + return [self.source_manager.get_source_by_id(source_id=id) for id in source_ids] def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]: # list all attached sources to an agent @@ -1749,17 +1728,17 @@ class SyncServer(Server): warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning) return [] - def list_all_sources(self, user_id: str) -> List[Source]: + def list_all_sources(self, actor: User) -> List[Source]: """List all sources (w/ extra metadata) belonging to a user""" - sources = self.ms.list_sources(user_id=user_id) + sources = self.source_manager.list_sources(actor=actor) # Add extra metadata to the sources sources_with_metadata = [] for source in sources: # count number of passages 
- passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) + passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=actor.id) num_passages = passage_conn.size({"source_id": source.id}) # TODO: add when files table implemented @@ -1773,7 +1752,7 @@ class SyncServer(Server): attached_agents = [ { "id": str(a_id), - "name": self.ms.get_agent(user_id=user_id, agent_id=a_id).name, + "name": self.ms.get_agent(user_id=actor.id, agent_id=a_id).name, } for a_id in agent_ids ] diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index 1832c580..1b7f18b6 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -27,18 +27,26 @@ class OrganizationManager: return self.get_organization_by_id(self.DEFAULT_ORG_ID) @enforce_types - def get_organization_by_id(self, org_id: str) -> PydanticOrganization: + def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]: """Fetch an organization by ID.""" with self.session_maker() as session: try: organization = OrganizationModel.read(db_session=session, identifier=org_id) return organization.to_pydantic() except NoResultFound: - raise ValueError(f"Organization with id {org_id} not found.") + return None @enforce_types def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: """Create a new organization. 
If a name is provided, it is used, otherwise, a random one is generated.""" + org = self.get_organization_by_id(pydantic_org.id) + if org: + return org + else: + return self._create_organization(pydantic_org=pydantic_org) + + @enforce_types + def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: with self.session_maker() as session: org = OrganizationModel(**pydantic_org.model_dump()) org.create(session) @@ -47,16 +55,7 @@ class OrganizationManager: @enforce_types def create_default_organization(self) -> PydanticOrganization: """Create the default organization.""" - with self.session_maker() as session: - # Try to get it first - try: - org = OrganizationModel.read(db_session=session, identifier=self.DEFAULT_ORG_ID) - # If it doesn't exist, make it - except NoResultFound: - org = OrganizationModel(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID) - org.create(session) - - return org.to_pydantic() + return self.create_organization(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)) @enforce_types def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization: @@ -73,7 +72,7 @@ class OrganizationManager: """Delete an organization by marking it as deleted.""" with self.session_maker() as session: organization = OrganizationModel.read(db_session=session, identifier=org_id) - organization.delete(session) + organization.hard_delete(session) @enforce_types def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]: diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py new file mode 100644 index 00000000..e09bddd9 --- /dev/null +++ b/letta/services/source_manager.py @@ -0,0 +1,100 @@ +from typing import List, Optional + +from letta.orm.errors import NoResultFound +from letta.orm.source import Source as SourceModel +from letta.schemas.source import Source as PydanticSource +from 
letta.schemas.source import SourceUpdate +from letta.schemas.user import User as PydanticUser +from letta.utils import enforce_types, printd + + +class SourceManager: + """Manager class to handle business logic related to Sources.""" + + def __init__(self): + from letta.server.server import db_context + + self.session_maker = db_context + + @enforce_types + def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource: + """Create a new source based on the PydanticSource schema.""" + # Try getting the source first by id + db_source = self.get_source_by_id(source.id, actor=actor) + if db_source: + return db_source + else: + with self.session_maker() as session: + # Provide default embedding config if not given + source.organization_id = actor.organization_id + source = SourceModel(**source.model_dump(exclude_none=True)) + source.create(session, actor=actor) + return source.to_pydantic() + + @enforce_types + def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource: + """Update a source by its ID with the given SourceUpdate object.""" + with self.session_maker() as session: + source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) + + # get update dictionary + update_data = source_update.model_dump(exclude_unset=True, exclude_none=True) + # Remove redundant update fields + update_data = {key: value for key, value in update_data.items() if getattr(source, key) != value} + + if update_data: + for key, value in update_data.items(): + setattr(source, key, value) + source.update(db_session=session, actor=actor) + else: + printd( + f"`update_source` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={source.name}, but found existing source with nothing to update." 
+ ) + + return source.to_pydantic() + + @enforce_types + def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource: + """Delete a source by its ID.""" + with self.session_maker() as session: + source = SourceModel.read(db_session=session, identifier=source_id) + source.delete(db_session=session, actor=actor) + return source.to_pydantic() + + @enforce_types + def list_sources(self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticSource]: + """List all sources with optional pagination.""" + with self.session_maker() as session: + sources = SourceModel.list( + db_session=session, + cursor=cursor, + limit=limit, + organization_id=actor.organization_id, + ) + return [source.to_pydantic() for source in sources] + + # TODO: We make actor optional for now, but should most likely be enforced due to security reasons + @enforce_types + def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]: + """Retrieve a source by its ID.""" + with self.session_maker() as session: + try: + source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) + return source.to_pydantic() + except NoResultFound: + return None + + @enforce_types + def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]: + """Retrieve a source by its name.""" + with self.session_maker() as session: + sources = SourceModel.list( + db_session=session, + name=source_name, + organization_id=actor.organization_id, + limit=1, + ) + if not sources: + return None + else: + return sources[0].to_pydantic() diff --git a/tests/test_client.py b/tests/test_client.py index 3c1647b7..ffdd27bf 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,11 +6,13 @@ from typing import List, Union import pytest from dotenv import load_dotenv +from sqlalchemy import delete from letta import create_client from letta.agent import initialize_message_sequence 
 from letta.client.client import LocalClient, RESTClient
 from letta.constants import DEFAULT_PRESET
+from letta.orm import Source
 from letta.schemas.agent import AgentState
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import MessageRole, MessageStreamStatus
@@ -83,6 +85,16 @@ def client(request):
     yield client
 
 
+@pytest.fixture(autouse=True)
+def clear_tables():
+    """Fixture to clear the sources table before each test."""
+    from letta.server.server import db_context
+
+    with db_context() as session:
+        session.execute(delete(Source))
+        session.commit()
+
+
 # Fixture for test agent
 @pytest.fixture(scope="module")
 def agent(client: Union[LocalClient, RESTClient]):
diff --git a/tests/test_managers.py b/tests/test_managers.py
index 64c6b3be..8436d7ba 100644
--- a/tests/test_managers.py
+++ b/tests/test_managers.py
@@ -3,14 +3,14 @@ from sqlalchemy import delete
 
 import letta.utils as utils
 from letta.functions.functions import derive_openai_json_schema, parse_source_code
-from letta.orm.organization import Organization
-from letta.orm.tool import Tool
-from letta.orm.user import User
+from letta.orm import Organization, Source, Tool, User
 from letta.schemas.agent import CreateAgent
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import ChatMemory
 from letta.schemas.organization import Organization as PydanticOrganization
+from letta.schemas.source import Source as PydanticSource
+from letta.schemas.source import SourceUpdate
 from letta.schemas.tool import Tool as PydanticTool
 from letta.schemas.tool import ToolUpdate
 from letta.services.organization_manager import OrganizationManager
@@ -21,11 +21,23 @@ from letta.schemas.user import User as PydanticUser
 from letta.schemas.user import UserUpdate
 from letta.server.server import SyncServer
 
+DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
+    embedding_endpoint_type="hugging-face",
+    
embedding_endpoint="https://embeddings.memgpt.ai", + embedding_model="letta-free", + embedding_dim=1024, + embedding_chunk_size=300, + azure_endpoint=None, + azure_version=None, + azure_deployment=None, +) + @pytest.fixture(autouse=True) def clear_tables(server: SyncServer): """Fixture to clear the organization table before each test.""" with server.organization_manager.session_maker() as session: + session.execute(delete(Source)) session.execute(delete(Tool)) # Clear all records from the Tool table session.execute(delete(User)) # Clear all records from the user table session.execute(delete(Organization)) # Clear all records from the organization table @@ -114,8 +126,6 @@ def tool_fixture(server: SyncServer, default_user, default_organization): description = "test_description" tags = ["test"] - org = server.organization_manager.create_default_organization() - user = server.user_manager.create_default_user() tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type) derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) @@ -123,10 +133,10 @@ def tool_fixture(server: SyncServer, default_user, default_organization): tool.json_schema = derived_json_schema tool.name = derived_name - tool = server.tool_manager.create_tool(tool, actor=user) + tool = server.tool_manager.create_tool(tool, actor=default_user) # Yield the created tool, organization, and user for use in tests - yield {"tool": tool, "organization": org, "user": user, "tool_create": tool} + yield {"tool": tool} @pytest.fixture(scope="module") @@ -240,16 +250,10 @@ def test_update_user(server: SyncServer): # ====================================================================================================================== def test_create_tool(server: SyncServer, tool_fixture, default_user, default_organization): tool = tool_fixture["tool"] - tool_create = tool_fixture["tool_create"] # Assertions to ensure the created tool 
matches the expected values assert tool.created_by_id == default_user.id assert tool.organization_id == default_organization.id - assert tool.description == tool_create.description - assert tool.tags == tool_create.tags - assert tool.source_code == tool_create.source_code - assert tool.source_type == tool_create.source_type - assert tool.json_schema == derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name) def test_get_tool_by_id(server: SyncServer, tool_fixture, default_user): @@ -327,7 +331,7 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t # Test begins tool = tool_fixture["tool"] - og_json_schema = tool_fixture["tool_create"].json_schema + og_json_schema = tool.json_schema source_code = parse_source_code(counter_tool) @@ -364,7 +368,7 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_ # Test begins tool = tool_fixture["tool"] - og_json_schema = tool_fixture["tool_create"].json_schema + og_json_schema = tool.json_schema source_code = parse_source_code(counter_tool) name = "counter_tool" @@ -415,6 +419,149 @@ def test_delete_tool_by_id(server: SyncServer, tool_fixture, default_user): assert len(tools) == 0 +# ====================================================================================================================== +# Source Manager Tests +# ====================================================================================================================== + + +def test_create_source(server: SyncServer, default_user): + """Test creating a new source.""" + source_pydantic = PydanticSource( + name="Test Source", + description="This is a test source.", + metadata_={"type": "test"}, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + + # Assertions to check the created source + assert source.name == source_pydantic.name + assert source.description == 
source_pydantic.description
+    assert source.metadata_ == source_pydantic.metadata_
+    assert source.organization_id == default_user.organization_id
+
+
+def test_create_sources_with_same_name_does_not_error(server: SyncServer, default_user):
+    """Test that creating two sources with the same name succeeds and yields distinct sources."""
+    name = "Test Source"
+    source_pydantic = PydanticSource(
+        name=name,
+        description="This is a test source.",
+        metadata_={"type": "medical"},
+        embedding_config=DEFAULT_EMBEDDING_CONFIG,
+    )
+    source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
+    source_pydantic = PydanticSource(
+        name=name,
+        description="This is a different test source.",
+        metadata_={"type": "legal"},
+        embedding_config=DEFAULT_EMBEDDING_CONFIG,
+    )
+    same_source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
+
+    assert source.name == same_source.name
+    assert source.id != same_source.id
+
+
+def test_update_source(server: SyncServer, default_user):
+    """Test updating an existing source."""
+    source_pydantic = PydanticSource(name="Original Source", description="Original description", embedding_config=DEFAULT_EMBEDDING_CONFIG)
+    source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
+
+    # Update the source
+    update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata_={"type": "updated"})
+    updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user)
+
+    # Assertions to verify update
+    assert updated_source.name == update_data.name
+    assert updated_source.description == update_data.description
+    assert updated_source.metadata_ == update_data.metadata_
+
+
+def test_delete_source(server: SyncServer, default_user):
+    """Test deleting a source."""
+    source_pydantic = PydanticSource(
+        name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG
+    )
+    source = 
server.source_manager.create_source(source=source_pydantic, actor=default_user) + + # Delete the source + deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user) + + # Assertions to verify deletion + assert deleted_source.id == source.id + + # Verify that the source no longer appears in list_sources + sources = server.source_manager.list_sources(actor=default_user) + assert len(sources) == 0 + + +def test_list_sources(server: SyncServer, default_user): + """Test listing sources with pagination.""" + # Create multiple sources + server.source_manager.create_source(PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) + server.source_manager.create_source(PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) + + # List sources without pagination + sources = server.source_manager.list_sources(actor=default_user) + assert len(sources) == 2 + + # List sources with pagination + paginated_sources = server.source_manager.list_sources(actor=default_user, limit=1) + assert len(paginated_sources) == 1 + + # Ensure cursor-based pagination works + next_page = server.source_manager.list_sources(actor=default_user, cursor=paginated_sources[-1].id, limit=1) + assert len(next_page) == 1 + assert next_page[0].name != paginated_sources[0].name + + +def test_get_source_by_id(server: SyncServer, default_user): + """Test retrieving a source by ID.""" + source_pydantic = PydanticSource( + name="Retrieve by ID", description="Test source for ID retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG + ) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + + # Retrieve the source by ID + retrieved_source = server.source_manager.get_source_by_id(source_id=source.id, actor=default_user) + + # Assertions to verify the retrieved source matches the created one + assert retrieved_source.id == source.id + assert retrieved_source.name == 
source.name + assert retrieved_source.description == source.description + + +def test_get_source_by_name(server: SyncServer, default_user): + """Test retrieving a source by name.""" + source_pydantic = PydanticSource( + name="Unique Source", description="Test source for name retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG + ) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + + # Retrieve the source by name + retrieved_source = server.source_manager.get_source_by_name(source_name=source.name, actor=default_user) + + # Assertions to verify the retrieved source matches the created one + assert retrieved_source.name == source.name + assert retrieved_source.description == source.description + + +def test_update_source_no_changes(server: SyncServer, default_user): + """Test update_source with no actual changes to verify logging and response.""" + source_pydantic = PydanticSource(name="No Change Source", description="No changes", embedding_config=DEFAULT_EMBEDDING_CONFIG) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + + # Attempt to update the source with identical data + update_data = SourceUpdate(name="No Change Source", description="No changes") + updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user) + + # Assertions to ensure the update returned the source but made no modifications + assert updated_source.id == source.id + assert updated_source.name == source.name + assert updated_source.description == source.description + + # ====================================================================================================================== # AgentsTagsManager Tests # ====================================================================================================================== diff --git a/tests/test_server.py b/tests/test_server.py index 1144cc18..bd68e45a 100644 --- a/tests/test_server.py +++ 
b/tests/test_server.py @@ -8,6 +8,8 @@ import letta.utils as utils from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.enums import MessageRole +from .test_managers import DEFAULT_EMBEDDING_CONFIG + utils.DEBUG = True from letta.config import LettaConfig from letta.schemas.agent import CreateAgent @@ -24,7 +26,7 @@ from letta.schemas.letta_message import ( from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory from letta.schemas.message import Message -from letta.schemas.source import SourceCreate +from letta.schemas.source import Source from letta.server.server import SyncServer from .utils import DummyDataConnector @@ -117,7 +119,9 @@ def test_user_message_memory(server, user_id, agent_id): @pytest.mark.order(3) def test_load_data(server, user_id, agent_id): # create source - source = server.create_source(SourceCreate(name="test_source"), user_id=user_id) + source = server.source_manager.create_source( + Source(name="test_source", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=server.default_user + ) # load data archival_memories = [