feat: Serialize agent state simple fields and messages (#1012)

This commit is contained in:
Matthew Zhou
2025-02-18 11:01:10 -08:00
committed by GitHub
parent 3dc1767f46
commit b5e09536ae
28 changed files with 451 additions and 179 deletions

View File

@@ -0,0 +1 @@
from letta.serialize_schemas.agent import SerializedAgentSchema

View File

@@ -0,0 +1,36 @@
from marshmallow import fields
from letta.orm import Agent
from letta.serialize_schemas.base import BaseSchema
from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField
from letta.serialize_schemas.message import SerializedMessageSchema
class SerializedAgentSchema(BaseSchema):
    """
    Marshmallow schema for serializing/deserializing Agent objects.

    Excludes relational fields.
    """

    # Pydantic-backed columns need custom (de)serialization.
    llm_config = LLMConfigField()
    embedding_config = EmbeddingConfigField()
    tool_rules = ToolRulesField()
    # Messages are serialized inline with the agent.
    messages = fields.List(fields.Nested(SerializedMessageSchema))

    def __init__(self, *args, session=None, **kwargs):
        """Accept an optional SQLAlchemy session and share it with nested schemas.

        marshmallow-sqlalchemy needs a session to load model instances;
        nested schemas do not pick up the parent's session automatically,
        so it is pushed down to them here by hand.
        """
        super().__init__(*args, **kwargs)
        if session:
            self.session = session

            # propagate session to nested schemas
            for field_name, field_obj in self.fields.items():
                if isinstance(field_obj, fields.List) and hasattr(field_obj.inner, "schema"):
                    # List-of-nested (e.g. messages): the session lives on the inner schema.
                    field_obj.inner.schema.session = session
                elif hasattr(field_obj, "schema"):
                    # Plain nested field.
                    field_obj.schema.session = session

    class Meta(BaseSchema.Meta):
        model = Agent
        # TODO: Serialize these as well...
        exclude = ("tools", "sources", "core_memory", "tags", "source_passages", "agent_passages", "organization")

View File

@@ -0,0 +1,12 @@
from marshmallow_sqlalchemy import SQLAlchemyAutoSchema
class BaseSchema(SQLAlchemyAutoSchema):
    """
    Base schema for all SQLAlchemy models.

    This ensures all schemas share the same session.
    """

    class Meta:
        # Emit related objects, not just scalar columns.
        include_relationships = True
        # Deserialize into model instances rather than plain dicts.
        load_instance = True

View File

@@ -0,0 +1,69 @@
from marshmallow import fields
from letta.helpers.converters import (
deserialize_embedding_config,
deserialize_llm_config,
deserialize_tool_calls,
deserialize_tool_rules,
serialize_embedding_config,
serialize_llm_config,
serialize_tool_calls,
serialize_tool_rules,
)
class PydanticField(fields.Field):
    """Generic Marshmallow field for handling Pydantic models.

    Serializes a Pydantic model to a plain dict and deserializes a dict
    back into the configured Pydantic class; ``None`` round-trips as ``None``.
    """

    def __init__(self, pydantic_class, **kwargs):
        """
        Args:
            pydantic_class: the Pydantic model class to (de)serialize.
        """
        self.pydantic_class = pydantic_class
        super().__init__(**kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        # Use `is not None` rather than truthiness: a falsy-but-valid value
        # (e.g. a model whose dump is empty) must not be collapsed to None.
        return value.model_dump() if value is not None else None

    def _deserialize(self, value, attr, data, **kwargs):
        # Same reasoning: an empty dict is a legitimate payload, not "absent".
        return self.pydantic_class(**value) if value is not None else None
class LLMConfigField(fields.Field):
    """Marshmallow field for handling LLMConfig serialization."""

    def _serialize(self, value, attr, obj, **kwargs):
        # Delegates to the shared converter so schema serialization stays
        # consistent with the rest of the codebase.
        return serialize_llm_config(value)

    def _deserialize(self, value, attr, data, **kwargs):
        return deserialize_llm_config(value)
class EmbeddingConfigField(fields.Field):
    """Marshmallow field for handling EmbeddingConfig serialization."""

    def _serialize(self, value, attr, obj, **kwargs):
        # Delegates to the shared converter helpers.
        return serialize_embedding_config(value)

    def _deserialize(self, value, attr, data, **kwargs):
        return deserialize_embedding_config(value)
class ToolRulesField(fields.List):
    """Custom Marshmallow field to handle a list of ToolRules."""

    def __init__(self, **kwargs):
        # Declared as a list of dicts for schema purposes; actual
        # (de)serialization is fully delegated to the converters below.
        super().__init__(fields.Dict(), **kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        return serialize_tool_rules(value)

    def _deserialize(self, value, attr, data, **kwargs):
        return deserialize_tool_rules(value)
class ToolCallField(fields.Field):
    """Marshmallow field for handling a list of OpenAI ToolCall objects."""

    def _serialize(self, value, attr, obj, **kwargs):
        # Delegates to the shared converter helpers.
        return serialize_tool_calls(value)

    def _deserialize(self, value, attr, data, **kwargs):
        return deserialize_tool_calls(value)

View File

@@ -0,0 +1,15 @@
from letta.orm.message import Message
from letta.serialize_schemas.base import BaseSchema
from letta.serialize_schemas.custom_fields import ToolCallField
class SerializedMessageSchema(BaseSchema):
    """
    Marshmallow schema for serializing/deserializing Message objects.
    """

    # OpenAI tool calls need custom (de)serialization.
    tool_calls = ToolCallField()

    class Meta(BaseSchema.Meta):
        model = Message
        # Relationships not needed when serializing messages under an agent.
        exclude = ("step", "job_message")

111
letta/server/db.py Normal file
View File

@@ -0,0 +1,111 @@
import os
from contextlib import contextmanager
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from letta.config import LettaConfig
from letta.log import get_logger
from letta.orm import Base
# NOTE: hack to see if single session management works
from letta.settings import settings
# Module-level config and logger, shared by the engine setup below.
config = LettaConfig.load()
logger = get_logger(__name__)
def print_sqlite_schema_error():
    """Print a formatted error message for SQLite schema issues"""
    # (text, rich style) segments, appended in order into one styled panel.
    segments = (
        ("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", "bold red"),
        ("To have migrations supported between Letta versions, please run Letta with Docker (", "white"),
        ("https://docs.letta.com/server/docker", "blue underline"),
        (") or use Postgres by setting ", "white"),
        ("LETTA_PG_URI", "yellow"),
        (".\n\n", "white"),
        ("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", "white"),
        ("rm ~/.letta/sqlite.db", "yellow"),
        (" or downgrade to your previous version of Letta.", "white"),
    )
    error_text = Text()
    for message, style in segments:
        error_text.append(message, style=style)
    Console().print(Panel(error_text, border_style="red"))
@contextmanager
def db_error_handler():
    """Context manager for handling database errors.

    On any exception, prints the error plus the SQLite schema help panel
    and terminates the process with exit code 1.
    """
    try:
        yield
    except Exception as e:
        # Handle other SQLAlchemy errors
        print(e)
        print_sqlite_schema_error()
        # raise ValueError(f"SQLite DB error: {str(e)}")
        # `exit()` is a `site`-module convenience and may be absent (e.g.
        # under `python -S`); raising SystemExit is the reliable equivalent.
        raise SystemExit(1)
# Pick the database backend once at import time: Postgres when LETTA_PG_URI
# is set, otherwise a local SQLite file under the config's storage path.
if settings.letta_pg_uri_no_default:
    print("Creating postgres engine")
    config.recall_storage_type = "postgres"
    config.recall_storage_uri = settings.letta_pg_uri_no_default
    config.archival_storage_type = "postgres"
    config.archival_storage_uri = settings.letta_pg_uri_no_default

    # create engine
    engine = create_engine(
        settings.letta_pg_uri,
        pool_size=settings.pg_pool_size,
        max_overflow=settings.pg_max_overflow,
        pool_timeout=settings.pg_pool_timeout,
        pool_recycle=settings.pg_pool_recycle,
        echo=settings.pg_echo,
    )
else:
    # TODO: don't rely on config storage
    engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
    logger.info("Creating sqlite engine " + engine_path)

    engine = create_engine(engine_path)

    # Store the original connect method
    original_connect = engine.connect

    def wrapped_connect(*args, **kwargs):
        # Wrap connect so schema errors surface the friendly SQLite panel
        # (see db_error_handler) instead of a raw traceback.
        with db_error_handler():
            # Get the connection
            connection = original_connect(*args, **kwargs)

            # Store the original execution method
            original_execute = connection.execute

            # Wrap the execute method of the connection
            def wrapped_execute(*args, **kwargs):
                with db_error_handler():
                    return original_execute(*args, **kwargs)

            # Replace the connection's execute method
            connection.execute = wrapped_execute

            return connection

    # Replace the engine's connect method
    engine.connect = wrapped_connect

# NOTE(review): runs for both backends here — confirm Postgres schema is
# expected to be created via create_all rather than migrations only.
Base.metadata.create_all(bind=engine)
def get_db():
    """Yield a database session, guaranteeing it is closed afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Session factory bound to the module-level engine.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Context-manager form of get_db: `with db_context() as session: ...`
db_context = contextmanager(get_db)

View File

@@ -18,6 +18,7 @@ import letta.server.utils as server_utils
import letta.system as system
from letta.agent import Agent, save_agent
from letta.chat_only_agent import ChatOnlyAgent
from letta.config import LettaConfig
from letta.data_sources.connectors import DataConnector, load_data
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps, json_loads
@@ -27,7 +28,6 @@ from letta.interface import AgentInterface # abstract
from letta.interface import CLIInterface # for printing to terminal
from letta.log import get_logger
from letta.offline_memory_agent import OfflineMemoryAgent
from letta.orm import Base
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState, AgentType, CreateAgent
from letta.schemas.block import BlockUpdate
@@ -82,8 +82,10 @@ from letta.services.step_manager import StepManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings, settings, tool_settings
from letta.utils import get_friendly_error_msg
config = LettaConfig.load()
logger = get_logger(__name__)
@@ -145,118 +147,6 @@ class Server(object):
raise NotImplementedError
from contextlib import contextmanager
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from letta.config import LettaConfig
# NOTE: hack to see if single session management works
from letta.settings import model_settings, settings, tool_settings
config = LettaConfig.load()
def print_sqlite_schema_error():
"""Print a formatted error message for SQLite schema issues"""
console = Console()
error_text = Text()
error_text.append("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", style="bold red")
error_text.append("To have migrations supported between Letta versions, please run Letta with Docker (", style="white")
error_text.append("https://docs.letta.com/server/docker", style="blue underline")
error_text.append(") or use Postgres by setting ", style="white")
error_text.append("LETTA_PG_URI", style="yellow")
error_text.append(".\n\n", style="white")
error_text.append("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", style="white")
error_text.append("rm ~/.letta/sqlite.db", style="yellow")
error_text.append(" or downgrade to your previous version of Letta.", style="white")
console.print(Panel(error_text, border_style="red"))
@contextmanager
def db_error_handler():
"""Context manager for handling database errors"""
try:
yield
except Exception as e:
# Handle other SQLAlchemy errors
print(e)
print_sqlite_schema_error()
# raise ValueError(f"SQLite DB error: {str(e)}")
exit(1)
if settings.letta_pg_uri_no_default:
print("Creating postgres engine")
config.recall_storage_type = "postgres"
config.recall_storage_uri = settings.letta_pg_uri_no_default
config.archival_storage_type = "postgres"
config.archival_storage_uri = settings.letta_pg_uri_no_default
# create engine
engine = create_engine(
settings.letta_pg_uri,
pool_size=settings.pg_pool_size,
max_overflow=settings.pg_max_overflow,
pool_timeout=settings.pg_pool_timeout,
pool_recycle=settings.pg_pool_recycle,
echo=settings.pg_echo,
)
else:
# TODO: don't rely on config storage
engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
logger.info("Creating sqlite engine " + engine_path)
engine = create_engine(engine_path)
# Store the original connect method
original_connect = engine.connect
def wrapped_connect(*args, **kwargs):
with db_error_handler():
# Get the connection
connection = original_connect(*args, **kwargs)
# Store the original execution method
original_execute = connection.execute
# Wrap the execute method of the connection
def wrapped_execute(*args, **kwargs):
with db_error_handler():
return original_execute(*args, **kwargs)
# Replace the connection's execute method
connection.execute = wrapped_execute
return connection
# Replace the engine's connect method
engine.connect = wrapped_connect
Base.metadata.create_all(bind=engine)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Dependency
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
from contextlib import contextmanager
db_context = contextmanager(get_db)
class SyncServer(Server):
"""Simple single-threaded / blocking server process"""

View File

@@ -29,6 +29,7 @@ from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
from letta.schemas.user import User as PydanticUser
from letta.serialize_schemas import SerializedAgentSchema
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
_process_relationship,
@@ -53,7 +54,7 @@ class AgentManager:
"""Manager class to handle business logic related to Agents."""
def __init__(self):
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context
self.block_manager = BlockManager()
@@ -355,6 +356,24 @@ class AgentManager:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
agent.hard_delete(session)
@enforce_types
def serialize(self, agent_id: str, actor: PydanticUser) -> dict:
with self.session_maker() as session:
# Retrieve the agent
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
schema = SerializedAgentSchema(session=session)
return schema.dump(agent)
    @enforce_types
    def deserialize(self, serialized_agent: dict, actor: PydanticUser) -> PydanticAgentState:
        """Recreate an agent from a serialized dict and persist it.

        The actor's organization overrides whatever organization id the
        dict carried, so imported agents land in the caller's org.
        """
        # TODO: Use actor to override fields
        with self.session_maker() as session:
            schema = SerializedAgentSchema(session=session)
            # load_instance=True on the schema yields an ORM Agent, not a dict
            agent = schema.load(serialized_agent, session=session)
            agent.organization_id = actor.organization_id
            agent = agent.create(session, actor=actor)
            return agent.to_pydantic()
# ======================================================================================================================
# Per Agent Environment Variable Management
# ======================================================================================================================

View File

@@ -16,7 +16,7 @@ class BlockManager:
def __init__(self):
# Fetching the db_context similarly as in ToolManager
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context

View File

@@ -29,7 +29,7 @@ class JobManager:
def __init__(self):
# Fetching the db_context similarly as in OrganizationManager
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context

View File

@@ -16,7 +16,7 @@ class MessageManager:
"""Manager class to handle business logic related to Messages."""
def __init__(self):
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context

View File

@@ -16,7 +16,7 @@ class OrganizationManager:
# TODO: Please refactor this out
# I am currently working on a ORM refactor and would like to make a more minimal set of changes
# - Matt
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context

View File

@@ -16,7 +16,7 @@ class PassageManager:
"""Manager class to handle business logic related to Passages."""
def __init__(self):
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context

View File

@@ -10,7 +10,7 @@ from letta.utils import enforce_types
class ProviderManager:
def __init__(self):
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context

View File

@@ -20,7 +20,7 @@ class SandboxConfigManager:
"""Manager class to handle business logic related to SandboxConfig and SandboxEnvironmentVariable."""
def __init__(self):
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context

View File

@@ -15,7 +15,7 @@ class SourceManager:
"""Manager class to handle business logic related to Sources."""
def __init__(self):
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context

View File

@@ -17,7 +17,7 @@ from letta.utils import enforce_types
class StepManager:
def __init__(self):
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context

View File

@@ -31,7 +31,7 @@ class ToolManager:
def __init__(self):
# Fetching the db_context similarly as in OrganizationManager
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context

View File

@@ -17,7 +17,7 @@ class UserManager:
def __init__(self):
# Fetching the db_context similarly as in OrganizationManager
from letta.server.server import db_context
from letta.server.db import db_context
self.session_maker = db_context

22
poetry.lock generated
View File

@@ -3036,6 +3036,26 @@ dev = ["marshmallow[tests]", "pre-commit (>=3.5,<5.0)", "tox"]
docs = ["autodocsumm (==0.2.14)", "furo (==2024.8.6)", "sphinx (==8.1.3)", "sphinx-copybutton (==0.5.2)", "sphinx-issues (==5.0.0)", "sphinxext-opengraph (==0.9.1)"]
tests = ["pytest", "simplejson"]
[[package]]
name = "marshmallow-sqlalchemy"
version = "1.4.1"
description = "SQLAlchemy integration with the marshmallow (de)serialization library"
optional = false
python-versions = ">=3.9"
files = [
{file = "marshmallow_sqlalchemy-1.4.1-py3-none-any.whl", hash = "sha256:9a3dd88a2b24f425fbffb3fea8aeb7f424a932fc97372a9f1338b7a379396191"},
{file = "marshmallow_sqlalchemy-1.4.1.tar.gz", hash = "sha256:b4aa964356d00e178bdb8469a28daa9022b375ff4f5c04f8e2b9aafe1e65c529"},
]
[package.dependencies]
marshmallow = ">=3.18.0"
SQLAlchemy = ">=1.4.40,<3.0"
[package.extras]
dev = ["marshmallow-sqlalchemy[tests]", "pre-commit (>=3.5,<5.0)", "tox"]
docs = ["furo (==2024.8.6)", "sphinx (==8.1.3)", "sphinx-copybutton (==0.5.2)", "sphinx-design (==0.6.1)", "sphinx-issues (==5.0.0)", "sphinxext-opengraph (==0.9.1)"]
tests = ["pytest (<9)", "pytest-lazy-fixtures"]
[[package]]
name = "matplotlib-inline"
version = "0.1.7"
@@ -6550,4 +6570,4 @@ tests = ["wikipedia"]
[metadata]
lock-version = "2.0"
python-versions = "<3.14,>=3.10"
content-hash = "05633c1ae9cf8125ccb07f90bf8887072ee3a452854f35dec2421e574ee202f7"
content-hash = "37167dffe2006e220123bbc64f6bd91ab44e363d6950dd73c9eacc4d056daeb7"

View File

@@ -81,6 +81,7 @@ openai = "^1.60.0"
google-genai = {version = "^1.1.0", optional = true}
faker = "^36.1.0"
colorama = "^0.4.6"
marshmallow-sqlalchemy = "^1.4.1"
[tool.poetry.extras]

View File

@@ -15,8 +15,8 @@ from tests.utils import wait_for_incoming_message
@pytest.fixture(autouse=True)
def truncate_database():
from letta.server.server import db_context
def clear_tables():
from letta.server.db import db_context
with db_context() as session:
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues

View File

@@ -35,7 +35,7 @@ user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user"))
@pytest.fixture(autouse=True)
def clear_tables():
"""Fixture to clear the organization table before each test."""
from letta.server.server import db_context
from letta.server.db import db_context
with db_context() as session:
session.execute(delete(SandboxEnvironmentVariable))

View File

@@ -21,7 +21,7 @@ from tests.integration_test_summarizer import LLM_CONFIG_DIR
@pytest.fixture(autouse=True)
def truncate_database():
from letta.server.server import db_context
from letta.server.db import db_context
with db_context() as session:
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues

View File

@@ -0,0 +1,137 @@
import json
import pytest
from letta import create_client
from letta.config import LettaConfig
from letta.orm import Base
from letta.schemas.agent import CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.server.server import SyncServer
def _clear_tables():
    """Delete all rows from every ORM table so tests start from an empty DB."""
    from letta.server.db import db_context

    with db_context() as session:
        for table in reversed(Base.metadata.sorted_tables):  # Reverse to avoid FK issues
            session.execute(table.delete())  # Truncate table
        session.commit()
@pytest.fixture(autouse=True)
def clear_tables():
    # autouse: wipe the database before every test in this module
    _clear_tables()
@pytest.fixture(scope="module")
def local_client():
    """Module-scoped Letta client with default LLM/embedding configs set."""
    client = create_client()
    client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
    client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
    yield client
@pytest.fixture(scope="module")
def server():
    """Module-scoped SyncServer; tests create their own org/user fixtures."""
    config = LettaConfig.load()
    config.save()

    server = SyncServer(init_with_default_org_and_user=False)
    return server
@pytest.fixture
def default_organization(server: SyncServer):
    """Fixture to create and return the default organization."""
    org = server.organization_manager.create_default_organization()
    # yield (not return) so teardown logic could be added after the test
    yield org
@pytest.fixture
def default_user(server: SyncServer, default_organization):
    """Fixture to create and return the default user within the default organization."""
    user = server.user_manager.create_default_user(org_id=default_organization.id)
    yield user
@pytest.fixture
def sarah_agent(server: SyncServer, default_user, default_organization):
    """Fixture to create and return a sample agent within the default organization."""
    agent_state = server.agent_manager.create_agent(
        agent_create=CreateAgent(
            name="sarah_agent",
            # no memory blocks: keeps the serialized payload minimal
            memory_blocks=[],
            llm_config=LLMConfig.default_config("gpt-4o-mini"),
            embedding_config=EmbeddingConfig.default_config(provider="openai"),
        ),
        actor=default_user,
    )
    yield agent_state
def test_agent_serialization(server, sarah_agent, default_user):
    """Test serializing an Agent instance to JSON."""
    result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user)

    # Assert that the result is a dictionary (JSON object)
    assert isinstance(result, dict), "Expected a dictionary result"

    # Assert that the 'id' field is present and matches the agent's ID
    assert "id" in result, "Agent 'id' is missing in the serialized result"
    assert result["id"] == sarah_agent.id, f"Expected agent 'id' to be {sarah_agent.id}, but got {result['id']}"

    # Assert that the 'llm_config' and 'embedding_config' fields exist
    assert "llm_config" in result, "'llm_config' is missing in the serialized result"
    assert "embedding_config" in result, "'embedding_config' is missing in the serialized result"

    # Assert that 'messages' is a list
    assert isinstance(result.get("messages", []), list), "'messages' should be a list"

    # Assert that the 'tool_exec_environment_variables' field is a list (empty or populated)
    assert isinstance(result.get("tool_exec_environment_variables", []), list), "'tool_exec_environment_variables' should be a list"

    # Assert that the 'agent_type' is a valid string
    assert isinstance(result.get("agent_type"), str), "'agent_type' should be a string"

    # Assert that the 'tool_rules' field is a list (even if empty)
    assert isinstance(result.get("tool_rules", []), list), "'tool_rules' should be a list"

    # Check that all necessary fields are present in each serialized message
    for message in result.get("messages", []):
        assert "id" in message, "Message 'id' is missing"
        assert "text" in message, "Message 'text' is missing"
        assert "role" in message, "Message 'role' is missing"
        assert "created_at" in message, "Message 'created_at' is missing"
        assert "updated_at" in message, "Message 'updated_at' is missing"

    # Timestamps are serialized as strings
    assert isinstance(result["created_at"], str), "Expected 'created_at' to be a string"
    assert isinstance(result["updated_at"], str), "Expected 'updated_at' to be a string"

    # Metadata is present but unset for a freshly created agent
    assert "metadata_" in result, "'metadata_' field is missing"
    assert result["metadata_"] is None, "'metadata_' should be null"

    # Assert that the agent name round-trips
    # (fixed assert message: this checks equality with the fixture's name,
    # not merely non-None; also removed leftover debug print of the payload)
    assert result.get("name") == sarah_agent.name, f"Expected agent 'name' to be {sarah_agent.name!r}, but got {result.get('name')!r}"
def test_agent_deserialization_basic(local_client, server, sarah_agent, default_user):
    """Test deserializing JSON into an Agent instance."""
    # Refresh the agent state, then snapshot it to a serializable dict
    sarah_agent = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user)
    result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user)

    # Delete the agent
    server.agent_manager.delete_agent(sarah_agent.id, actor=default_user)

    # Recreate it from the snapshot and verify key fields survived the round trip
    agent_state = server.agent_manager.deserialize(serialized_agent=result, actor=default_user)
    assert agent_state.name == sarah_agent.name
    assert len(agent_state.message_ids) == len(sarah_agent.message_ids)

View File

@@ -104,7 +104,7 @@ def search_agent_two(client: Union[LocalClient, RESTClient]):
@pytest.fixture(autouse=True)
def clear_tables():
"""Clear the sandbox tables before each test."""
from letta.server.server import db_context
from letta.server.db import db_context
with db_context() as session:
session.execute(delete(SandboxEnvironmentVariable))

View File

@@ -90,7 +90,7 @@ def client(request):
@pytest.fixture(autouse=True)
def clear_tables():
"""Fixture to clear the organization table before each test."""
from letta.server.server import db_context
from letta.server.db import db_context
with db_context() as session:
session.execute(delete(FileMetadata))

View File

@@ -5,35 +5,13 @@ from datetime import datetime, timedelta
import pytest
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from sqlalchemy import delete
from sqlalchemy.exc import IntegrityError
from letta.config import LettaConfig
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS
from letta.embeddings import embedding_model
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.orm import (
Agent,
AgentPassage,
Block,
BlocksAgents,
FileMetadata,
Job,
JobMessage,
Message,
Organization,
Provider,
SandboxConfig,
SandboxEnvironmentVariable,
Source,
SourcePassage,
SourcesAgents,
Step,
Tool,
ToolsAgents,
User,
)
from letta.orm.agents_tags import AgentsTags
from letta.orm import Base
from letta.orm.enums import JobType, ToolType
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
from letta.schemas.agent import CreateAgent, UpdateAgent
@@ -81,30 +59,13 @@ USING_SQLITE = not bool(os.getenv("LETTA_PG_URI"))
@pytest.fixture(autouse=True)
def clear_tables(server: SyncServer):
"""Fixture to clear the organization table before each test."""
with server.organization_manager.session_maker() as session:
session.execute(delete(Message))
session.execute(delete(AgentPassage))
session.execute(delete(SourcePassage))
session.execute(delete(JobMessage)) # Clear JobMessage first
session.execute(delete(Job))
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
session.execute(delete(BlocksAgents))
session.execute(delete(SourcesAgents))
session.execute(delete(AgentsTags))
session.execute(delete(SandboxEnvironmentVariable))
session.execute(delete(SandboxConfig))
session.execute(delete(Block))
session.execute(delete(FileMetadata))
session.execute(delete(Source))
session.execute(delete(Tool)) # Clear all records from the Tool table
session.execute(delete(Agent))
session.execute(delete(User)) # Clear all records from the user table
session.execute(delete(Step))
session.execute(delete(Provider))
session.execute(delete(Organization)) # Clear all records from the organization table
session.commit() # Commit the deletion
def clear_tables():
from letta.server.db import db_context
with db_context() as session:
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
session.execute(table.delete()) # Truncate table
session.commit()
@pytest.fixture