feat: Make pydantic serialized agent object (#1278)
Co-authored-by: Caren Thomas <caren@letta.com>
This commit is contained in:
@@ -1 +1 @@
|
||||
from letta.serialize_schemas.agent import SerializedAgentSchema
|
||||
from letta.serialize_schemas.marshmallow_agent import MarshmallowAgentSchema
|
||||
|
||||
@@ -6,17 +6,17 @@ import letta
|
||||
from letta.orm import Agent
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.agent_environment_variable import SerializedAgentEnvironmentVariableSchema
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.block import SerializedBlockSchema
|
||||
from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField
|
||||
from letta.serialize_schemas.message import SerializedMessageSchema
|
||||
from letta.serialize_schemas.tag import SerializedAgentTagSchema
|
||||
from letta.serialize_schemas.tool import SerializedToolSchema
|
||||
from letta.serialize_schemas.marshmallow_agent_environment_variable import SerializedAgentEnvironmentVariableSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_block import SerializedBlockSchema
|
||||
from letta.serialize_schemas.marshmallow_custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField
|
||||
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
|
||||
from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema
|
||||
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
|
||||
from letta.server.db import SessionLocal
|
||||
|
||||
|
||||
class SerializedAgentSchema(BaseSchema):
|
||||
class MarshmallowAgentSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Agent objects.
|
||||
Excludes relational fields.
|
||||
@@ -98,7 +98,6 @@ class SerializedAgentSchema(BaseSchema):
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Agent
|
||||
# TODO: Serialize these as well...
|
||||
exclude = BaseSchema.Meta.exclude + (
|
||||
"project_id",
|
||||
"template_id",
|
||||
@@ -2,7 +2,7 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedAgentEnvironmentVariableSchema(BaseSchema):
|
||||
@@ -1,6 +1,6 @@
|
||||
from letta.orm.block import Block
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedBlockSchema(BaseSchema):
|
||||
@@ -4,8 +4,8 @@ from marshmallow import post_dump, pre_load
|
||||
|
||||
from letta.orm.message import Message
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.custom_fields import ToolCallField
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_custom_fields import ToolCallField
|
||||
|
||||
|
||||
class SerializedMessageSchema(BaseSchema):
|
||||
@@ -3,7 +3,7 @@ from typing import Dict
|
||||
from marshmallow import fields, post_dump, pre_load
|
||||
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedAgentTagSchema(BaseSchema):
|
||||
@@ -1,6 +1,6 @@
|
||||
from letta.orm import Tool
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.marshmallow_base import BaseSchema
|
||||
|
||||
|
||||
class SerializedToolSchema(BaseSchema):
|
||||
110
letta/serialize_schemas/pydantic_agent_schema.py
Normal file
110
letta/serialize_schemas/pydantic_agent_schema.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
|
||||
class CoreMemoryBlockSchema(BaseModel):
|
||||
created_at: str
|
||||
description: Optional[str]
|
||||
identities: List[Any]
|
||||
is_deleted: bool
|
||||
is_template: bool
|
||||
label: str
|
||||
limit: int
|
||||
metadata_: Dict[str, Any] = Field(default_factory=dict)
|
||||
template_name: Optional[str]
|
||||
updated_at: str
|
||||
value: str
|
||||
|
||||
|
||||
class MessageSchema(BaseModel):
|
||||
created_at: str
|
||||
group_id: Optional[str]
|
||||
in_context: bool
|
||||
model: Optional[str]
|
||||
name: Optional[str]
|
||||
role: str
|
||||
text: str
|
||||
tool_call_id: Optional[str]
|
||||
tool_calls: List[Any]
|
||||
tool_returns: List[Any]
|
||||
updated_at: str
|
||||
|
||||
|
||||
class TagSchema(BaseModel):
|
||||
tag: str
|
||||
|
||||
|
||||
class ToolEnvVarSchema(BaseModel):
|
||||
created_at: str
|
||||
description: Optional[str]
|
||||
is_deleted: bool
|
||||
key: str
|
||||
updated_at: str
|
||||
value: str
|
||||
|
||||
|
||||
class ToolRuleSchema(BaseModel):
|
||||
tool_name: str
|
||||
type: str
|
||||
|
||||
|
||||
class ParameterProperties(BaseModel):
|
||||
type: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ParametersSchema(BaseModel):
|
||||
type: Optional[str] = "object"
|
||||
properties: Dict[str, ParameterProperties]
|
||||
required: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolJSONSchema(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: ParametersSchema # <— nested strong typing
|
||||
type: Optional[str] = None # top-level 'type' if it exists
|
||||
required: Optional[List[str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolSchema(BaseModel):
|
||||
args_json_schema: Optional[Any]
|
||||
created_at: str
|
||||
description: str
|
||||
is_deleted: bool
|
||||
json_schema: ToolJSONSchema
|
||||
name: str
|
||||
return_char_limit: int
|
||||
source_code: Optional[str]
|
||||
source_type: str
|
||||
tags: List[str]
|
||||
tool_type: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class AgentSchema(BaseModel):
|
||||
agent_type: str
|
||||
core_memory: List[CoreMemoryBlockSchema]
|
||||
created_at: str
|
||||
description: str
|
||||
embedding_config: EmbeddingConfig
|
||||
groups: List[Any]
|
||||
identities: List[Any]
|
||||
is_deleted: bool
|
||||
llm_config: LLMConfig
|
||||
message_buffer_autoclear: bool
|
||||
messages: List[MessageSchema]
|
||||
metadata_: Dict
|
||||
multi_agent_group: Optional[Any]
|
||||
name: str
|
||||
system: str
|
||||
tags: List[TagSchema]
|
||||
tool_exec_environment_variables: List[ToolEnvVarSchema]
|
||||
tool_rules: List[ToolRuleSchema]
|
||||
tools: List[ToolSchema]
|
||||
updated_at: str
|
||||
version: str
|
||||
Reference in New Issue
Block a user