feat: Serialize agent state simple fields and messages (#1012)
This commit is contained in:
1
letta/serialize_schemas/__init__.py
Normal file
1
letta/serialize_schemas/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from letta.serialize_schemas.agent import SerializedAgentSchema
|
||||
36
letta/serialize_schemas/agent.py
Normal file
36
letta/serialize_schemas/agent.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from marshmallow import fields
|
||||
|
||||
from letta.orm import Agent
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField
|
||||
from letta.serialize_schemas.message import SerializedMessageSchema
|
||||
|
||||
|
||||
class SerializedAgentSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Agent objects.
|
||||
Excludes relational fields.
|
||||
"""
|
||||
|
||||
llm_config = LLMConfigField()
|
||||
embedding_config = EmbeddingConfigField()
|
||||
tool_rules = ToolRulesField()
|
||||
|
||||
messages = fields.List(fields.Nested(SerializedMessageSchema))
|
||||
|
||||
def __init__(self, *args, session=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if session:
|
||||
self.session = session
|
||||
|
||||
# propagate session to nested schemas
|
||||
for field_name, field_obj in self.fields.items():
|
||||
if isinstance(field_obj, fields.List) and hasattr(field_obj.inner, "schema"):
|
||||
field_obj.inner.schema.session = session
|
||||
elif hasattr(field_obj, "schema"):
|
||||
field_obj.schema.session = session
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Agent
|
||||
# TODO: Serialize these as well...
|
||||
exclude = ("tools", "sources", "core_memory", "tags", "source_passages", "agent_passages", "organization")
|
||||
12
letta/serialize_schemas/base.py
Normal file
12
letta/serialize_schemas/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
|
||||
|
||||
|
||||
class BaseSchema(SQLAlchemyAutoSchema):
|
||||
"""
|
||||
Base schema for all SQLAlchemy models.
|
||||
This ensures all schemas share the same session.
|
||||
"""
|
||||
|
||||
class Meta:
|
||||
include_relationships = True
|
||||
load_instance = True
|
||||
69
letta/serialize_schemas/custom_fields.py
Normal file
69
letta/serialize_schemas/custom_fields.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from marshmallow import fields
|
||||
|
||||
from letta.helpers.converters import (
|
||||
deserialize_embedding_config,
|
||||
deserialize_llm_config,
|
||||
deserialize_tool_calls,
|
||||
deserialize_tool_rules,
|
||||
serialize_embedding_config,
|
||||
serialize_llm_config,
|
||||
serialize_tool_calls,
|
||||
serialize_tool_rules,
|
||||
)
|
||||
|
||||
|
||||
class PydanticField(fields.Field):
|
||||
"""Generic Marshmallow field for handling Pydantic models."""
|
||||
|
||||
def __init__(self, pydantic_class, **kwargs):
|
||||
self.pydantic_class = pydantic_class
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return value.model_dump() if value else None
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return self.pydantic_class(**value) if value else None
|
||||
|
||||
|
||||
class LLMConfigField(fields.Field):
|
||||
"""Marshmallow field for handling LLMConfig serialization."""
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return serialize_llm_config(value)
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_llm_config(value)
|
||||
|
||||
|
||||
class EmbeddingConfigField(fields.Field):
|
||||
"""Marshmallow field for handling EmbeddingConfig serialization."""
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return serialize_embedding_config(value)
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_embedding_config(value)
|
||||
|
||||
|
||||
class ToolRulesField(fields.List):
|
||||
"""Custom Marshmallow field to handle a list of ToolRules."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(fields.Dict(), **kwargs)
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return serialize_tool_rules(value)
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_tool_rules(value)
|
||||
|
||||
|
||||
class ToolCallField(fields.Field):
|
||||
"""Marshmallow field for handling a list of OpenAI ToolCall objects."""
|
||||
|
||||
def _serialize(self, value, attr, obj, **kwargs):
|
||||
return serialize_tool_calls(value)
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
return deserialize_tool_calls(value)
|
||||
15
letta/serialize_schemas/message.py
Normal file
15
letta/serialize_schemas/message.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from letta.orm.message import Message
|
||||
from letta.serialize_schemas.base import BaseSchema
|
||||
from letta.serialize_schemas.custom_fields import ToolCallField
|
||||
|
||||
|
||||
class SerializedMessageSchema(BaseSchema):
|
||||
"""
|
||||
Marshmallow schema for serializing/deserializing Message objects.
|
||||
"""
|
||||
|
||||
tool_calls = ToolCallField()
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Message
|
||||
exclude = ("step", "job_message")
|
||||
111
letta/server/db.py
Normal file
111
letta/server/db.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.log import get_logger
|
||||
from letta.orm import Base
|
||||
|
||||
# NOTE: hack to see if single session management works
|
||||
from letta.settings import settings
|
||||
|
||||
config = LettaConfig.load()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def print_sqlite_schema_error():
|
||||
"""Print a formatted error message for SQLite schema issues"""
|
||||
console = Console()
|
||||
error_text = Text()
|
||||
error_text.append("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", style="bold red")
|
||||
error_text.append("To have migrations supported between Letta versions, please run Letta with Docker (", style="white")
|
||||
error_text.append("https://docs.letta.com/server/docker", style="blue underline")
|
||||
error_text.append(") or use Postgres by setting ", style="white")
|
||||
error_text.append("LETTA_PG_URI", style="yellow")
|
||||
error_text.append(".\n\n", style="white")
|
||||
error_text.append("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", style="white")
|
||||
error_text.append("rm ~/.letta/sqlite.db", style="yellow")
|
||||
error_text.append(" or downgrade to your previous version of Letta.", style="white")
|
||||
|
||||
console.print(Panel(error_text, border_style="red"))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def db_error_handler():
|
||||
"""Context manager for handling database errors"""
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
# Handle other SQLAlchemy errors
|
||||
print(e)
|
||||
print_sqlite_schema_error()
|
||||
# raise ValueError(f"SQLite DB error: {str(e)}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
print("Creating postgres engine")
|
||||
config.recall_storage_type = "postgres"
|
||||
config.recall_storage_uri = settings.letta_pg_uri_no_default
|
||||
config.archival_storage_type = "postgres"
|
||||
config.archival_storage_uri = settings.letta_pg_uri_no_default
|
||||
|
||||
# create engine
|
||||
engine = create_engine(
|
||||
settings.letta_pg_uri,
|
||||
pool_size=settings.pg_pool_size,
|
||||
max_overflow=settings.pg_max_overflow,
|
||||
pool_timeout=settings.pg_pool_timeout,
|
||||
pool_recycle=settings.pg_pool_recycle,
|
||||
echo=settings.pg_echo,
|
||||
)
|
||||
else:
|
||||
# TODO: don't rely on config storage
|
||||
engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
|
||||
logger.info("Creating sqlite engine " + engine_path)
|
||||
|
||||
engine = create_engine(engine_path)
|
||||
|
||||
# Store the original connect method
|
||||
original_connect = engine.connect
|
||||
|
||||
def wrapped_connect(*args, **kwargs):
|
||||
with db_error_handler():
|
||||
# Get the connection
|
||||
connection = original_connect(*args, **kwargs)
|
||||
|
||||
# Store the original execution method
|
||||
original_execute = connection.execute
|
||||
|
||||
# Wrap the execute method of the connection
|
||||
def wrapped_execute(*args, **kwargs):
|
||||
with db_error_handler():
|
||||
return original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the connection's execute method
|
||||
connection.execute = wrapped_execute
|
||||
|
||||
return connection
|
||||
|
||||
# Replace the engine's connect method
|
||||
engine.connect = wrapped_connect
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db_context = contextmanager(get_db)
|
||||
@@ -18,6 +18,7 @@ import letta.server.utils as server_utils
|
||||
import letta.system as system
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.chat_only_agent import ChatOnlyAgent
|
||||
from letta.config import LettaConfig
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
@@ -27,7 +28,6 @@ from letta.interface import AgentInterface # abstract
|
||||
from letta.interface import CLIInterface # for printing to terminal
|
||||
from letta.log import get_logger
|
||||
from letta.offline_memory_agent import OfflineMemoryAgent
|
||||
from letta.orm import Base
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent
|
||||
from letta.schemas.block import BlockUpdate
|
||||
@@ -82,8 +82,10 @@ from letta.services.step_manager import StepManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings, settings, tool_settings
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
config = LettaConfig.load()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -145,118 +147,6 @@ class Server(object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from letta.config import LettaConfig
|
||||
|
||||
# NOTE: hack to see if single session management works
|
||||
from letta.settings import model_settings, settings, tool_settings
|
||||
|
||||
config = LettaConfig.load()
|
||||
|
||||
|
||||
def print_sqlite_schema_error():
|
||||
"""Print a formatted error message for SQLite schema issues"""
|
||||
console = Console()
|
||||
error_text = Text()
|
||||
error_text.append("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", style="bold red")
|
||||
error_text.append("To have migrations supported between Letta versions, please run Letta with Docker (", style="white")
|
||||
error_text.append("https://docs.letta.com/server/docker", style="blue underline")
|
||||
error_text.append(") or use Postgres by setting ", style="white")
|
||||
error_text.append("LETTA_PG_URI", style="yellow")
|
||||
error_text.append(".\n\n", style="white")
|
||||
error_text.append("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", style="white")
|
||||
error_text.append("rm ~/.letta/sqlite.db", style="yellow")
|
||||
error_text.append(" or downgrade to your previous version of Letta.", style="white")
|
||||
|
||||
console.print(Panel(error_text, border_style="red"))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def db_error_handler():
|
||||
"""Context manager for handling database errors"""
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
# Handle other SQLAlchemy errors
|
||||
print(e)
|
||||
print_sqlite_schema_error()
|
||||
# raise ValueError(f"SQLite DB error: {str(e)}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
print("Creating postgres engine")
|
||||
config.recall_storage_type = "postgres"
|
||||
config.recall_storage_uri = settings.letta_pg_uri_no_default
|
||||
config.archival_storage_type = "postgres"
|
||||
config.archival_storage_uri = settings.letta_pg_uri_no_default
|
||||
|
||||
# create engine
|
||||
engine = create_engine(
|
||||
settings.letta_pg_uri,
|
||||
pool_size=settings.pg_pool_size,
|
||||
max_overflow=settings.pg_max_overflow,
|
||||
pool_timeout=settings.pg_pool_timeout,
|
||||
pool_recycle=settings.pg_pool_recycle,
|
||||
echo=settings.pg_echo,
|
||||
)
|
||||
else:
|
||||
# TODO: don't rely on config storage
|
||||
engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
|
||||
logger.info("Creating sqlite engine " + engine_path)
|
||||
|
||||
engine = create_engine(engine_path)
|
||||
|
||||
# Store the original connect method
|
||||
original_connect = engine.connect
|
||||
|
||||
def wrapped_connect(*args, **kwargs):
|
||||
with db_error_handler():
|
||||
# Get the connection
|
||||
connection = original_connect(*args, **kwargs)
|
||||
|
||||
# Store the original execution method
|
||||
original_execute = connection.execute
|
||||
|
||||
# Wrap the execute method of the connection
|
||||
def wrapped_execute(*args, **kwargs):
|
||||
with db_error_handler():
|
||||
return original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the connection's execute method
|
||||
connection.execute = wrapped_execute
|
||||
|
||||
return connection
|
||||
|
||||
# Replace the engine's connect method
|
||||
engine.connect = wrapped_connect
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
db_context = contextmanager(get_db)
|
||||
|
||||
|
||||
class SyncServer(Server):
|
||||
"""Simple single-threaded / blocking server process"""
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.serialize_schemas import SerializedAgentSchema
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
_process_relationship,
|
||||
@@ -53,7 +54,7 @@ class AgentManager:
|
||||
"""Manager class to handle business logic related to Agents."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
self.block_manager = BlockManager()
|
||||
@@ -355,6 +356,24 @@ class AgentManager:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
agent.hard_delete(session)
|
||||
|
||||
@enforce_types
|
||||
def serialize(self, agent_id: str, actor: PydanticUser) -> dict:
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
schema = SerializedAgentSchema(session=session)
|
||||
return schema.dump(agent)
|
||||
|
||||
@enforce_types
|
||||
def deserialize(self, serialized_agent: dict, actor: PydanticUser) -> PydanticAgentState:
|
||||
# TODO: Use actor to override fields
|
||||
with self.session_maker() as session:
|
||||
schema = SerializedAgentSchema(session=session)
|
||||
agent = schema.load(serialized_agent, session=session)
|
||||
agent.organization_id = actor.organization_id
|
||||
agent = agent.create(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
# ======================================================================================================================
|
||||
# Per Agent Environment Variable Management
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -16,7 +16,7 @@ class BlockManager:
|
||||
|
||||
def __init__(self):
|
||||
# Fetching the db_context similarly as in ToolManager
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class JobManager:
|
||||
|
||||
def __init__(self):
|
||||
# Fetching the db_context similarly as in OrganizationManager
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class MessageManager:
|
||||
"""Manager class to handle business logic related to Messages."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class OrganizationManager:
|
||||
# TODO: Please refactor this out
|
||||
# I am currently working on a ORM refactor and would like to make a more minimal set of changes
|
||||
# - Matt
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class PassageManager:
|
||||
"""Manager class to handle business logic related to Passages."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from letta.utils import enforce_types
|
||||
class ProviderManager:
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class SandboxConfigManager:
|
||||
"""Manager class to handle business logic related to SandboxConfig and SandboxEnvironmentVariable."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class SourceManager:
|
||||
"""Manager class to handle business logic related to Sources."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from letta.utils import enforce_types
|
||||
class StepManager:
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class ToolManager:
|
||||
|
||||
def __init__(self):
|
||||
# Fetching the db_context similarly as in OrganizationManager
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class UserManager:
|
||||
|
||||
def __init__(self):
|
||||
# Fetching the db_context similarly as in OrganizationManager
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
|
||||
22
poetry.lock
generated
22
poetry.lock
generated
@@ -3036,6 +3036,26 @@ dev = ["marshmallow[tests]", "pre-commit (>=3.5,<5.0)", "tox"]
|
||||
docs = ["autodocsumm (==0.2.14)", "furo (==2024.8.6)", "sphinx (==8.1.3)", "sphinx-copybutton (==0.5.2)", "sphinx-issues (==5.0.0)", "sphinxext-opengraph (==0.9.1)"]
|
||||
tests = ["pytest", "simplejson"]
|
||||
|
||||
[[package]]
|
||||
name = "marshmallow-sqlalchemy"
|
||||
version = "1.4.1"
|
||||
description = "SQLAlchemy integration with the marshmallow (de)serialization library"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "marshmallow_sqlalchemy-1.4.1-py3-none-any.whl", hash = "sha256:9a3dd88a2b24f425fbffb3fea8aeb7f424a932fc97372a9f1338b7a379396191"},
|
||||
{file = "marshmallow_sqlalchemy-1.4.1.tar.gz", hash = "sha256:b4aa964356d00e178bdb8469a28daa9022b375ff4f5c04f8e2b9aafe1e65c529"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
marshmallow = ">=3.18.0"
|
||||
SQLAlchemy = ">=1.4.40,<3.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["marshmallow-sqlalchemy[tests]", "pre-commit (>=3.5,<5.0)", "tox"]
|
||||
docs = ["furo (==2024.8.6)", "sphinx (==8.1.3)", "sphinx-copybutton (==0.5.2)", "sphinx-design (==0.6.1)", "sphinx-issues (==5.0.0)", "sphinxext-opengraph (==0.9.1)"]
|
||||
tests = ["pytest (<9)", "pytest-lazy-fixtures"]
|
||||
|
||||
[[package]]
|
||||
name = "matplotlib-inline"
|
||||
version = "0.1.7"
|
||||
@@ -6550,4 +6570,4 @@ tests = ["wikipedia"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.14,>=3.10"
|
||||
content-hash = "05633c1ae9cf8125ccb07f90bf8887072ee3a452854f35dec2421e574ee202f7"
|
||||
content-hash = "37167dffe2006e220123bbc64f6bd91ab44e363d6950dd73c9eacc4d056daeb7"
|
||||
|
||||
@@ -81,6 +81,7 @@ openai = "^1.60.0"
|
||||
google-genai = {version = "^1.1.0", optional = true}
|
||||
faker = "^36.1.0"
|
||||
colorama = "^0.4.6"
|
||||
marshmallow-sqlalchemy = "^1.4.1"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
||||
@@ -15,8 +15,8 @@ from tests.utils import wait_for_incoming_message
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def truncate_database():
|
||||
from letta.server.server import db_context
|
||||
def clear_tables():
|
||||
from letta.server.db import db_context
|
||||
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
|
||||
|
||||
@@ -35,7 +35,7 @@ user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user"))
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_tables():
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
with db_context() as session:
|
||||
session.execute(delete(SandboxEnvironmentVariable))
|
||||
|
||||
@@ -21,7 +21,7 @@ from tests.integration_test_summarizer import LLM_CONFIG_DIR
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def truncate_database():
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
|
||||
|
||||
137
tests/test_agent_serialization.py
Normal file
137
tests/test_agent_serialization.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from letta import create_client
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm import Base
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
|
||||
def _clear_tables():
|
||||
from letta.server.db import db_context
|
||||
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
|
||||
session.execute(table.delete()) # Truncate table
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_tables():
|
||||
_clear_tables()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def local_client():
|
||||
client = create_client()
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
|
||||
config.save()
|
||||
|
||||
server = SyncServer(init_with_default_org_and_user=False)
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_organization(server: SyncServer):
|
||||
"""Fixture to create and return the default organization."""
|
||||
org = server.organization_manager.create_default_organization()
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_user(server: SyncServer, default_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
user = server.user_manager.create_default_user(org_id=default_organization.id)
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create and return a sample agent within the default organization."""
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="sarah_agent",
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
yield agent_state
|
||||
|
||||
|
||||
def test_agent_serialization(server, sarah_agent, default_user):
|
||||
"""Test serializing an Agent instance to JSON."""
|
||||
result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user)
|
||||
|
||||
# Assert that the result is a dictionary (JSON object)
|
||||
assert isinstance(result, dict), "Expected a dictionary result"
|
||||
|
||||
# Assert that the 'id' field is present and matches the agent's ID
|
||||
assert "id" in result, "Agent 'id' is missing in the serialized result"
|
||||
assert result["id"] == sarah_agent.id, f"Expected agent 'id' to be {sarah_agent.id}, but got {result['id']}"
|
||||
|
||||
# Assert that the 'llm_config' and 'embedding_config' fields exist
|
||||
assert "llm_config" in result, "'llm_config' is missing in the serialized result"
|
||||
assert "embedding_config" in result, "'embedding_config' is missing in the serialized result"
|
||||
|
||||
# Assert that 'messages' is a list
|
||||
assert isinstance(result.get("messages", []), list), "'messages' should be a list"
|
||||
|
||||
# Assert that the 'tool_exec_environment_variables' field is a list (empty or populated)
|
||||
assert isinstance(result.get("tool_exec_environment_variables", []), list), "'tool_exec_environment_variables' should be a list"
|
||||
|
||||
# Assert that the 'agent_type' is a valid string
|
||||
assert isinstance(result.get("agent_type"), str), "'agent_type' should be a string"
|
||||
|
||||
# Assert that the 'tool_rules' field is a list (even if empty)
|
||||
assert isinstance(result.get("tool_rules", []), list), "'tool_rules' should be a list"
|
||||
|
||||
# Check that all necessary fields are present in the 'messages' section, focusing on core elements
|
||||
if "messages" in result:
|
||||
for message in result["messages"]:
|
||||
assert "id" in message, "Message 'id' is missing"
|
||||
assert "text" in message, "Message 'text' is missing"
|
||||
assert "role" in message, "Message 'role' is missing"
|
||||
assert "created_at" in message, "Message 'created_at' is missing"
|
||||
assert "updated_at" in message, "Message 'updated_at' is missing"
|
||||
|
||||
# Optionally check that 'created_at' and 'updated_at' are in ISO 8601 format
|
||||
assert isinstance(result["created_at"], str), "Expected 'created_at' to be a string"
|
||||
assert isinstance(result["updated_at"], str), "Expected 'updated_at' to be a string"
|
||||
|
||||
# Optionally check for presence of any required metadata or ensure it is null if expected
|
||||
assert "metadata_" in result, "'metadata_' field is missing"
|
||||
assert result["metadata_"] is None, "'metadata_' should be null"
|
||||
|
||||
# Assert that the agent name is as expected (if defined)
|
||||
assert result.get("name") == sarah_agent.name, "Expected agent 'name' to not be None, but found something else"
|
||||
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
|
||||
def test_agent_deserialization_basic(local_client, server, sarah_agent, default_user):
|
||||
"""Test deserializing JSON into an Agent instance."""
|
||||
# Send a message first
|
||||
sarah_agent = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user)
|
||||
result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user)
|
||||
|
||||
# Delete the agent
|
||||
server.agent_manager.delete_agent(sarah_agent.id, actor=default_user)
|
||||
|
||||
agent_state = server.agent_manager.deserialize(serialized_agent=result, actor=default_user)
|
||||
|
||||
assert agent_state.name == sarah_agent.name
|
||||
assert len(agent_state.message_ids) == len(sarah_agent.message_ids)
|
||||
@@ -104,7 +104,7 @@ def search_agent_two(client: Union[LocalClient, RESTClient]):
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_tables():
|
||||
"""Clear the sandbox tables before each test."""
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
with db_context() as session:
|
||||
session.execute(delete(SandboxEnvironmentVariable))
|
||||
|
||||
@@ -90,7 +90,7 @@ def client(request):
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_tables():
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
from letta.server.server import db_context
|
||||
from letta.server.db import db_context
|
||||
|
||||
with db_context() as session:
|
||||
session.execute(delete(FileMetadata))
|
||||
|
||||
@@ -5,35 +5,13 @@ from datetime import datetime, timedelta
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.orm import (
|
||||
Agent,
|
||||
AgentPassage,
|
||||
Block,
|
||||
BlocksAgents,
|
||||
FileMetadata,
|
||||
Job,
|
||||
JobMessage,
|
||||
Message,
|
||||
Organization,
|
||||
Provider,
|
||||
SandboxConfig,
|
||||
SandboxEnvironmentVariable,
|
||||
Source,
|
||||
SourcePassage,
|
||||
SourcesAgents,
|
||||
Step,
|
||||
Tool,
|
||||
ToolsAgents,
|
||||
User,
|
||||
)
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm import Base
|
||||
from letta.orm.enums import JobType, ToolType
|
||||
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
|
||||
from letta.schemas.agent import CreateAgent, UpdateAgent
|
||||
@@ -81,30 +59,13 @@ USING_SQLITE = not bool(os.getenv("LETTA_PG_URI"))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_tables(server: SyncServer):
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Message))
|
||||
session.execute(delete(AgentPassage))
|
||||
session.execute(delete(SourcePassage))
|
||||
session.execute(delete(JobMessage)) # Clear JobMessage first
|
||||
session.execute(delete(Job))
|
||||
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
|
||||
session.execute(delete(BlocksAgents))
|
||||
session.execute(delete(SourcesAgents))
|
||||
session.execute(delete(AgentsTags))
|
||||
session.execute(delete(SandboxEnvironmentVariable))
|
||||
session.execute(delete(SandboxConfig))
|
||||
session.execute(delete(Block))
|
||||
session.execute(delete(FileMetadata))
|
||||
session.execute(delete(Source))
|
||||
session.execute(delete(Tool)) # Clear all records from the Tool table
|
||||
session.execute(delete(Agent))
|
||||
session.execute(delete(User)) # Clear all records from the user table
|
||||
session.execute(delete(Step))
|
||||
session.execute(delete(Provider))
|
||||
session.execute(delete(Organization)) # Clear all records from the organization table
|
||||
session.commit() # Commit the deletion
|
||||
def clear_tables():
|
||||
from letta.server.db import db_context
|
||||
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
|
||||
session.execute(table.delete()) # Truncate table
|
||||
session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Reference in New Issue
Block a user