Fix: Move orm metadata out of write-only Agent pydantic objects (#2249)

This commit is contained in:
Matthew Zhou
2024-12-13 15:03:48 -08:00
committed by GitHub
parent 7908b8a15f
commit 63a63e5183
2 changed files with 15 additions and 14 deletions

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm import FileMetadata
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.embedding_config import EmbeddingConfig
@@ -47,5 +48,5 @@ class Source(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
files: Mapped[List["Source"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")

View File

@@ -16,14 +16,6 @@ from letta.schemas.tool_rule import ToolRule
from letta.utils import create_random_username
class BaseAgent(OrmMetadataBase, validate_assignment=True):
__id_prefix__ = "agent"
description: Optional[str] = Field(None, description="The description of the agent.")
# metadata
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
class AgentType(str, Enum):
"""
Enum to represent the type of agent.
@@ -36,7 +28,7 @@ class AgentType(str, Enum):
chat_only_agent = "chat_only_agent"
class AgentState(BaseAgent):
class AgentState(OrmMetadataBase, validate_assignment=True):
"""
Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent.
@@ -53,8 +45,10 @@ class AgentState(BaseAgent):
"""
__id_prefix__ = "agent"
# NOTE: this is what is returned to the client and also what is used to initialize `Agent`
id: str = BaseAgent.generate_id_field()
id: str = Field(..., description="The id of the agent. Assigned by the database.")
name: str = Field(..., description="The name of the agent.")
# tool rules
tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.")
@@ -76,14 +70,16 @@ class AgentState(BaseAgent):
# Field in this object can be theoretically edited by tools, and will be persisted by the ORM
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the agent.")
description: Optional[str] = Field(None, description="The description of the agent.")
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
memory: Memory = Field(..., description="The in-context memory of the agent.")
tools: List[Tool] = Field(..., description="The tools used by the agent.")
sources: List[Source] = Field(..., description="The sources used by the agent.")
tags: List[str] = Field(..., description="The tags associated with the agent.")
# TODO: add in context message objects
class CreateAgent(BaseAgent): #
class CreateAgent(BaseModel, validate_assignment=True): #
# all optional as server can generate defaults
name: str = Field(default_factory=lambda: create_random_username(), description="The name of the agent.")
@@ -109,6 +105,8 @@ class CreateAgent(BaseAgent): #
None, description="The initial set of messages to put in the agent's in-context memory."
)
include_base_tools: bool = Field(True, description="The LLM configuration used by the agent.")
description: Optional[str] = Field(None, description="The description of the agent.")
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
@field_validator("name")
@classmethod
@@ -136,7 +134,7 @@ class CreateAgent(BaseAgent): #
return name
class UpdateAgent(BaseAgent):
class UpdateAgent(BaseModel):
name: Optional[str] = Field(None, description="The name of the agent.")
tool_ids: Optional[List[str]] = Field(None, description="The ids of the tools used by the agent.")
source_ids: Optional[List[str]] = Field(None, description="The ids of the sources used by the agent.")
@@ -147,6 +145,8 @@ class UpdateAgent(BaseAgent):
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
description: Optional[str] = Field(None, description="The description of the agent.")
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
class Config:
extra = "ignore" # Ignores extra fields