From 63a63e5183eec21f8962aea3f2bf276dc35bf39f Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 13 Dec 2024 15:03:48 -0800 Subject: [PATCH] Fix: Move orm metadata out of write-only Agent pydantic objects (#2249) --- letta/orm/source.py | 3 ++- letta/schemas/agent.py | 26 +++++++++++++------------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/letta/orm/source.py b/letta/orm/source.py index e849cddb..b933f95f 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List, Optional from sqlalchemy import JSON, TypeDecorator from sqlalchemy.orm import Mapped, mapped_column, relationship +from letta.orm import FileMetadata from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.embedding_config import EmbeddingConfig @@ -47,5 +48,5 @@ class Source(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="sources") - files: Mapped[List["Source"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan") + files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan") agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources") diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 994233ab..ea3afd28 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -16,14 +16,6 @@ from letta.schemas.tool_rule import ToolRule from letta.utils import create_random_username -class BaseAgent(OrmMetadataBase, validate_assignment=True): - __id_prefix__ = "agent" - description: Optional[str] = Field(None, description="The description of the agent.") - - # metadata - metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") - - class AgentType(str, Enum): """ Enum to represent the type of agent. @@ -36,7 +28,7 @@ class AgentType(str, Enum): chat_only_agent = "chat_only_agent" -class AgentState(BaseAgent): +class AgentState(OrmMetadataBase, validate_assignment=True): """ Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent. @@ -53,8 +45,10 @@ class AgentState(BaseAgent): """ + __id_prefix__ = "agent" + # NOTE: this is what is returned to the client and also what is used to initialize `Agent` - id: str = BaseAgent.generate_id_field() + id: str = Field(..., description="The id of the agent. Assigned by the database.") name: str = Field(..., description="The name of the agent.") # tool rules tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.") @@ -76,14 +70,16 @@ class AgentState(BaseAgent): # Field in this object can be theoretically edited by tools, and will be persisted by the ORM organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the agent.") + description: Optional[str] = Field(None, description="The description of the agent.") + metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") + memory: Memory = Field(..., description="The in-context memory of the agent.") tools: List[Tool] = Field(..., description="The tools used by the agent.") sources: List[Source] = Field(..., description="The sources used by the agent.") tags: List[str] = Field(..., description="The tags associated with the agent.") - # TODO: add in context message objects -class CreateAgent(BaseAgent): # +class CreateAgent(BaseModel, validate_assignment=True): # # all optional as server can generate defaults name: str = Field(default_factory=lambda: create_random_username(), description="The name of the agent.") @@ -109,6 +105,8 @@ class CreateAgent(BaseAgent): # None, description="The initial set of messages to put in the agent's in-context memory." ) include_base_tools: bool = Field(True, description="The LLM configuration used by the agent.") + description: Optional[str] = Field(None, description="The description of the agent.") + metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") @field_validator("name") @classmethod @@ -136,7 +134,7 @@ class CreateAgent(BaseAgent): # return name -class UpdateAgent(BaseAgent): +class UpdateAgent(BaseModel): name: Optional[str] = Field(None, description="The name of the agent.") tool_ids: Optional[List[str]] = Field(None, description="The ids of the tools used by the agent.") source_ids: Optional[List[str]] = Field(None, description="The ids of the sources used by the agent.") @@ -147,6 +145,8 @@ class UpdateAgent(BaseAgent): llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.") embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.") message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.") + description: Optional[str] = Field(None, description="The description of the agent.") + metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") class Config: extra = "ignore" # Ignores extra fields