diff --git a/letta/serialize_schemas/__init__.py b/letta/serialize_schemas/__init__.py new file mode 100644 index 00000000..d0e09d6d --- /dev/null +++ b/letta/serialize_schemas/__init__.py @@ -0,0 +1 @@ +from letta.serialize_schemas.agent import SerializedAgentSchema diff --git a/letta/serialize_schemas/agent.py b/letta/serialize_schemas/agent.py new file mode 100644 index 00000000..036adf44 --- /dev/null +++ b/letta/serialize_schemas/agent.py @@ -0,0 +1,36 @@ +from marshmallow import fields + +from letta.orm import Agent +from letta.serialize_schemas.base import BaseSchema +from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField +from letta.serialize_schemas.message import SerializedMessageSchema + + +class SerializedAgentSchema(BaseSchema): + """ + Marshmallow schema for serializing/deserializing Agent objects. + Excludes relational fields. + """ + + llm_config = LLMConfigField() + embedding_config = EmbeddingConfigField() + tool_rules = ToolRulesField() + + messages = fields.List(fields.Nested(SerializedMessageSchema)) + + def __init__(self, *args, session=None, **kwargs): + super().__init__(*args, **kwargs) + if session: + self.session = session + + # propagate session to nested schemas + for field_name, field_obj in self.fields.items(): + if isinstance(field_obj, fields.List) and hasattr(field_obj.inner, "schema"): + field_obj.inner.schema.session = session + elif hasattr(field_obj, "schema"): + field_obj.schema.session = session + + class Meta(BaseSchema.Meta): + model = Agent + # TODO: Serialize these as well... + exclude = ("tools", "sources", "core_memory", "tags", "source_passages", "agent_passages", "organization") diff --git a/letta/serialize_schemas/base.py b/letta/serialize_schemas/base.py new file mode 100644 index 00000000..b64e76e2 --- /dev/null +++ b/letta/serialize_schemas/base.py @@ -0,0 +1,12 @@ +from marshmallow_sqlalchemy import SQLAlchemyAutoSchema + + +class BaseSchema(SQLAlchemyAutoSchema): + """ + Base schema for all SQLAlchemy models. + This ensures all schemas share the same session. + """ + + class Meta: + include_relationships = True + load_instance = True diff --git a/letta/serialize_schemas/custom_fields.py b/letta/serialize_schemas/custom_fields.py new file mode 100644 index 00000000..4478659e --- /dev/null +++ b/letta/serialize_schemas/custom_fields.py @@ -0,0 +1,69 @@ +from marshmallow import fields + +from letta.helpers.converters import ( + deserialize_embedding_config, + deserialize_llm_config, + deserialize_tool_calls, + deserialize_tool_rules, + serialize_embedding_config, + serialize_llm_config, + serialize_tool_calls, + serialize_tool_rules, +) + + +class PydanticField(fields.Field): + """Generic Marshmallow field for handling Pydantic models.""" + + def __init__(self, pydantic_class, **kwargs): + self.pydantic_class = pydantic_class + super().__init__(**kwargs) + + def _serialize(self, value, attr, obj, **kwargs): + return value.model_dump() if value else None + + def _deserialize(self, value, attr, data, **kwargs): + return self.pydantic_class(**value) if value else None + + +class LLMConfigField(fields.Field): + """Marshmallow field for handling LLMConfig serialization.""" + + def _serialize(self, value, attr, obj, **kwargs): + return serialize_llm_config(value) + + def _deserialize(self, value, attr, data, **kwargs): + return deserialize_llm_config(value) + + +class EmbeddingConfigField(fields.Field): + """Marshmallow field for handling EmbeddingConfig serialization.""" + + def _serialize(self, value, attr, obj, **kwargs): + return serialize_embedding_config(value) + + def _deserialize(self, value, attr, data, **kwargs): + return deserialize_embedding_config(value) + + +class ToolRulesField(fields.List): + """Custom Marshmallow field to handle a list of ToolRules.""" + + def __init__(self, **kwargs): + super().__init__(fields.Dict(), **kwargs) + + def _serialize(self, value, attr, obj, **kwargs): + return serialize_tool_rules(value) + + def _deserialize(self, value, attr, data, **kwargs): + return deserialize_tool_rules(value) + + +class ToolCallField(fields.Field): + """Marshmallow field for handling a list of OpenAI ToolCall objects.""" + + def _serialize(self, value, attr, obj, **kwargs): + return serialize_tool_calls(value) + + def _deserialize(self, value, attr, data, **kwargs): + return deserialize_tool_calls(value) diff --git a/letta/serialize_schemas/message.py b/letta/serialize_schemas/message.py new file mode 100644 index 00000000..58d055d6 --- /dev/null +++ b/letta/serialize_schemas/message.py @@ -0,0 +1,15 @@ +from letta.orm.message import Message +from letta.serialize_schemas.base import BaseSchema +from letta.serialize_schemas.custom_fields import ToolCallField + + +class SerializedMessageSchema(BaseSchema): + """ + Marshmallow schema for serializing/deserializing Message objects. + """ + + tool_calls = ToolCallField() + + class Meta(BaseSchema.Meta): + model = Message + exclude = ("step", "job_message") diff --git a/letta/server/db.py b/letta/server/db.py new file mode 100644 index 00000000..834837f1 --- /dev/null +++ b/letta/server/db.py @@ -0,0 +1,111 @@ +import os +from contextlib import contextmanager + +from rich.console import Console +from rich.panel import Panel +from rich.text import Text +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from letta.config import LettaConfig +from letta.log import get_logger +from letta.orm import Base + +# NOTE: hack to see if single session management works +from letta.settings import settings + +config = LettaConfig.load() + +logger = get_logger(__name__) + + +def print_sqlite_schema_error(): + """Print a formatted error message for SQLite schema issues""" + console = Console() + error_text = Text() + error_text.append("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", style="bold red") + error_text.append("To have migrations supported between Letta versions, please run Letta with Docker (", style="white") + error_text.append("https://docs.letta.com/server/docker", style="blue underline") + error_text.append(") or use Postgres by setting ", style="white") + error_text.append("LETTA_PG_URI", style="yellow") + error_text.append(".\n\n", style="white") + error_text.append("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", style="white") + error_text.append("rm ~/.letta/sqlite.db", style="yellow") + error_text.append(" or downgrade to your previous version of Letta.", style="white") + + console.print(Panel(error_text, border_style="red")) + + +@contextmanager +def db_error_handler(): + """Context manager for handling database errors""" + try: + yield + except Exception as e: + # Handle other SQLAlchemy errors + print(e) + print_sqlite_schema_error() + # raise ValueError(f"SQLite DB error: {str(e)}") + exit(1) + + +if settings.letta_pg_uri_no_default: + print("Creating postgres engine") + config.recall_storage_type = "postgres" + config.recall_storage_uri = settings.letta_pg_uri_no_default + config.archival_storage_type = "postgres" + config.archival_storage_uri = settings.letta_pg_uri_no_default + + # create engine + engine = create_engine( + settings.letta_pg_uri, + pool_size=settings.pg_pool_size, + max_overflow=settings.pg_max_overflow, + pool_timeout=settings.pg_pool_timeout, + pool_recycle=settings.pg_pool_recycle, + echo=settings.pg_echo, + ) +else: + # TODO: don't rely on config storage + engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db") + logger.info("Creating sqlite engine " + engine_path) + + engine = create_engine(engine_path) + + # Store the original connect method + original_connect = engine.connect + + def wrapped_connect(*args, **kwargs): + with db_error_handler(): + # Get the connection + connection = original_connect(*args, **kwargs) + + # Store the original execution method + original_execute = connection.execute + + # Wrap the execute method of the connection + def wrapped_execute(*args, **kwargs): + with db_error_handler(): + return original_execute(*args, **kwargs) + + # Replace the connection's execute method + connection.execute = wrapped_execute + + return connection + + # Replace the engine's connect method + engine.connect = wrapped_connect + + Base.metadata.create_all(bind=engine) + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +db_context = contextmanager(get_db) diff --git a/letta/server/server.py b/letta/server/server.py index fabefc7e..b3590999 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -18,6 +18,7 @@ import letta.server.utils as server_utils import letta.system as system from letta.agent import Agent, save_agent from letta.chat_only_agent import ChatOnlyAgent +from letta.config import LettaConfig from letta.data_sources.connectors import DataConnector, load_data from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps, json_loads @@ -27,7 +28,6 @@ from letta.interface import AgentInterface # abstract from letta.interface import CLIInterface # for printing to terminal from letta.log import get_logger from letta.offline_memory_agent import OfflineMemoryAgent -from letta.orm import Base from letta.orm.errors import NoResultFound from letta.schemas.agent import AgentState, AgentType, CreateAgent from letta.schemas.block import BlockUpdate @@ -82,8 +82,10 @@ from letta.services.step_manager import StepManager from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager +from letta.settings import model_settings, settings, tool_settings from letta.utils import get_friendly_error_msg +config = LettaConfig.load() logger = get_logger(__name__) @@ -145,118 +147,6 @@ class Server(object): raise NotImplementedError -from contextlib import contextmanager - -from rich.console import Console -from rich.panel import Panel -from rich.text import Text -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -from letta.config import LettaConfig - -# NOTE: hack to see if single session management works -from letta.settings import model_settings, settings, tool_settings - -config = LettaConfig.load() - - -def print_sqlite_schema_error(): - """Print a formatted error message for SQLite schema issues""" - console = Console() - error_text = Text() - error_text.append("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", style="bold red") - error_text.append("To have migrations supported between Letta versions, please run Letta with Docker (", style="white") - error_text.append("https://docs.letta.com/server/docker", style="blue underline") - error_text.append(") or use Postgres by setting ", style="white") - error_text.append("LETTA_PG_URI", style="yellow") - error_text.append(".\n\n", style="white") - error_text.append("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", style="white") - error_text.append("rm ~/.letta/sqlite.db", style="yellow") - error_text.append(" or downgrade to your previous version of Letta.", style="white") - - console.print(Panel(error_text, border_style="red")) - - -@contextmanager -def db_error_handler(): - """Context manager for handling database errors""" - try: - yield - except Exception as e: - # Handle other SQLAlchemy errors - print(e) - print_sqlite_schema_error() - # raise ValueError(f"SQLite DB error: {str(e)}") - exit(1) - - -if settings.letta_pg_uri_no_default: - print("Creating postgres engine") - config.recall_storage_type = "postgres" - config.recall_storage_uri = settings.letta_pg_uri_no_default - config.archival_storage_type = "postgres" - config.archival_storage_uri = settings.letta_pg_uri_no_default - - # create engine - engine = create_engine( - settings.letta_pg_uri, - pool_size=settings.pg_pool_size, - max_overflow=settings.pg_max_overflow, - pool_timeout=settings.pg_pool_timeout, - pool_recycle=settings.pg_pool_recycle, - echo=settings.pg_echo, - ) -else: - # TODO: don't rely on config storage - engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db") - logger.info("Creating sqlite engine " + engine_path) - - engine = create_engine(engine_path) - - # Store the original connect method - original_connect = engine.connect - - def wrapped_connect(*args, **kwargs): - with db_error_handler(): - # Get the connection - connection = original_connect(*args, **kwargs) - - # Store the original execution method - original_execute = connection.execute - - # Wrap the execute method of the connection - def wrapped_execute(*args, **kwargs): - with db_error_handler(): - return original_execute(*args, **kwargs) - - # Replace the connection's execute method - connection.execute = wrapped_execute - - return connection - - # Replace the engine's connect method - engine.connect = wrapped_connect - - Base.metadata.create_all(bind=engine) - -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - - -# Dependency -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - -from contextlib import contextmanager - -db_context = contextmanager(get_db) - - class SyncServer(Server): """Simple single-threaded / blocking server process""" diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 4b4d1bc3..e2360240 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -29,6 +29,7 @@ from letta.schemas.source import Source as PydanticSource from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool_rule import ToolRule as PydanticToolRule from letta.schemas.user import User as PydanticUser +from letta.serialize_schemas import SerializedAgentSchema from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import ( _process_relationship, @@ -53,7 +54,7 @@ class AgentManager: """Manager class to handle business logic related to Agents.""" def __init__(self): - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context self.block_manager = BlockManager() @@ -355,6 +356,24 @@ class AgentManager: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) agent.hard_delete(session) + @enforce_types + def serialize(self, agent_id: str, actor: PydanticUser) -> dict: + with self.session_maker() as session: + # Retrieve the agent + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + schema = SerializedAgentSchema(session=session) + return schema.dump(agent) + + @enforce_types + def deserialize(self, serialized_agent: dict, actor: PydanticUser) -> PydanticAgentState: + # TODO: Use actor to override fields + with self.session_maker() as session: + schema = SerializedAgentSchema(session=session) + agent = schema.load(serialized_agent, session=session) + agent.organization_id = actor.organization_id + agent = agent.create(session, actor=actor) + return agent.to_pydantic() + # ====================================================================================================================== # Per Agent Environment Variable Management # ====================================================================================================================== diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 41275e1e..7ae743a7 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -16,7 +16,7 @@ class BlockManager: def __init__(self): # Fetching the db_context similarly as in ToolManager - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 71bb5c0c..5f3f7fd0 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -29,7 +29,7 @@ class JobManager: def __init__(self): # Fetching the db_context similarly as in OrganizationManager - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 01eccb53..ed2881b3 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -16,7 +16,7 @@ class MessageManager: """Manager class to handle business logic related to Messages.""" def __init__(self): - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index b5b3ffd1..3f47f8a3 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -16,7 +16,7 @@ class OrganizationManager: # TODO: Please refactor this out # I am currently working on a ORM refactor and would like to make a more minimal set of changes # - Matt - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 7bcf1bc3..1758287b 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -16,7 +16,7 @@ class PassageManager: """Manager class to handle business logic related to Passages.""" def __init__(self): - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 01d7c701..20f7c2ad 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -10,7 +10,7 @@ from letta.utils import enforce_types class ProviderManager: def __init__(self): - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index 1fea2d2a..e4e01111 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -20,7 +20,7 @@ class SandboxConfigManager: """Manager class to handle business logic related to SandboxConfig and SandboxEnvironmentVariable.""" def __init__(self): - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 41e1bb8a..21a36ded 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -15,7 +15,7 @@ class SourceManager: """Manager class to handle business logic related to Sources.""" def __init__(self): - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index a316eda6..612c8bf2 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -17,7 +17,7 @@ from letta.utils import enforce_types class StepManager: def __init__(self): - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 211b3e33..aa85e17b 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -31,7 +31,7 @@ class ToolManager: def __init__(self): # Fetching the db_context similarly as in OrganizationManager - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index 939adcfe..9dbd15e4 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -17,7 +17,7 @@ class UserManager: def __init__(self): # Fetching the db_context similarly as in OrganizationManager - from letta.server.server import db_context + from letta.server.db import db_context self.session_maker = db_context diff --git a/poetry.lock b/poetry.lock index 92bd8290..55290e12 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3036,6 +3036,26 @@ dev = ["marshmallow[tests]", "pre-commit (>=3.5,<5.0)", "tox"] docs = ["autodocsumm (==0.2.14)", "furo (==2024.8.6)", "sphinx (==8.1.3)", "sphinx-copybutton (==0.5.2)", "sphinx-issues (==5.0.0)", "sphinxext-opengraph (==0.9.1)"] tests = ["pytest", "simplejson"] +[[package]] +name = "marshmallow-sqlalchemy" +version = "1.4.1" +description = "SQLAlchemy integration with the marshmallow (de)serialization library" +optional = false +python-versions = ">=3.9" +files = [ + {file = "marshmallow_sqlalchemy-1.4.1-py3-none-any.whl", hash = "sha256:9a3dd88a2b24f425fbffb3fea8aeb7f424a932fc97372a9f1338b7a379396191"}, + {file = "marshmallow_sqlalchemy-1.4.1.tar.gz", hash = "sha256:b4aa964356d00e178bdb8469a28daa9022b375ff4f5c04f8e2b9aafe1e65c529"}, +] + +[package.dependencies] +marshmallow = ">=3.18.0" +SQLAlchemy = ">=1.4.40,<3.0" + +[package.extras] +dev = ["marshmallow-sqlalchemy[tests]", "pre-commit (>=3.5,<5.0)", "tox"] +docs = ["furo (==2024.8.6)", "sphinx (==8.1.3)", "sphinx-copybutton (==0.5.2)", "sphinx-design (==0.6.1)", "sphinx-issues (==5.0.0)", "sphinxext-opengraph (==0.9.1)"] +tests = ["pytest (<9)", "pytest-lazy-fixtures"] + [[package]] name = "matplotlib-inline" version = "0.1.7" @@ -6550,4 +6570,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.0" python-versions = "<3.14,>=3.10" -content-hash = "05633c1ae9cf8125ccb07f90bf8887072ee3a452854f35dec2421e574ee202f7" +content-hash = "37167dffe2006e220123bbc64f6bd91ab44e363d6950dd73c9eacc4d056daeb7" diff --git a/pyproject.toml b/pyproject.toml index 70d1bcc2..74de01e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ openai = "^1.60.0" google-genai = {version = "^1.1.0", optional = true} faker = "^36.1.0" colorama = "^0.4.6" +marshmallow-sqlalchemy = "^1.4.1" [tool.poetry.extras] diff --git a/tests/integration_test_multi_agent.py b/tests/integration_test_multi_agent.py index d0b0edb3..91df2e24 100644 --- a/tests/integration_test_multi_agent.py +++ b/tests/integration_test_multi_agent.py @@ -15,8 +15,8 @@ from tests.utils import wait_for_incoming_message @pytest.fixture(autouse=True) -def truncate_database(): - from letta.server.server import db_context +def clear_tables(): + from letta.server.db import db_context with db_context() as session: for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 9073e3e2..531d0c2e 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -35,7 +35,7 @@ user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user")) @pytest.fixture(autouse=True) def clear_tables(): """Fixture to clear the organization table before each test.""" - from letta.server.server import db_context + from letta.server.db import db_context with db_context() as session: session.execute(delete(SandboxEnvironmentVariable)) diff --git a/tests/manual_test_many_messages.py b/tests/manual_test_many_messages.py index 0eef1764..c47f32dc 100644 --- a/tests/manual_test_many_messages.py +++ b/tests/manual_test_many_messages.py @@ -21,7 +21,7 @@ from tests.integration_test_summarizer import LLM_CONFIG_DIR @pytest.fixture(autouse=True) def truncate_database(): - from letta.server.server import db_context + from letta.server.db import db_context with db_context() as session: for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py new file mode 100644 index 00000000..ade9f2f4 --- /dev/null +++ b/tests/test_agent_serialization.py @@ -0,0 +1,137 @@ +import json + +import pytest + +from letta import create_client +from letta.config import LettaConfig +from letta.orm import Base +from letta.schemas.agent import CreateAgent +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig +from letta.server.server import SyncServer + + +def _clear_tables(): + from letta.server.db import db_context + + with db_context() as session: + for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues + session.execute(table.delete()) # Truncate table + session.commit() + + +@pytest.fixture(autouse=True) +def clear_tables(): + _clear_tables() + + +@pytest.fixture(scope="module") +def local_client(): + client = create_client() + client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) + client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) + yield client + + +@pytest.fixture(scope="module") +def server(): + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=False) + return server + + +@pytest.fixture +def default_organization(server: SyncServer): + """Fixture to create and return the default organization.""" + org = server.organization_manager.create_default_organization() + yield org + + +@pytest.fixture +def default_user(server: SyncServer, default_organization): + """Fixture to create and return the default user within the default organization.""" + user = server.user_manager.create_default_user(org_id=default_organization.id) + yield user + + +@pytest.fixture +def sarah_agent(server: SyncServer, default_user, default_organization): + """Fixture to create and return a sample agent within the default organization.""" + agent_state = server.agent_manager.create_agent( + agent_create=CreateAgent( + name="sarah_agent", + memory_blocks=[], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + actor=default_user, + ) + yield agent_state + + +def test_agent_serialization(server, sarah_agent, default_user): + """Test serializing an Agent instance to JSON.""" + result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user) + + # Assert that the result is a dictionary (JSON object) + assert isinstance(result, dict), "Expected a dictionary result" + + # Assert that the 'id' field is present and matches the agent's ID + assert "id" in result, "Agent 'id' is missing in the serialized result" + assert result["id"] == sarah_agent.id, f"Expected agent 'id' to be {sarah_agent.id}, but got {result['id']}" + + # Assert that the 'llm_config' and 'embedding_config' fields exist + assert "llm_config" in result, "'llm_config' is missing in the serialized result" + assert "embedding_config" in result, "'embedding_config' is missing in the serialized result" + + # Assert that 'messages' is a list + assert isinstance(result.get("messages", []), list), "'messages' should be a list" + + # Assert that the 'tool_exec_environment_variables' field is a list (empty or populated) + assert isinstance(result.get("tool_exec_environment_variables", []), list), "'tool_exec_environment_variables' should be a list" + + # Assert that the 'agent_type' is a valid string + assert isinstance(result.get("agent_type"), str), "'agent_type' should be a string" + + # Assert that the 'tool_rules' field is a list (even if empty) + assert isinstance(result.get("tool_rules", []), list), "'tool_rules' should be a list" + + # Check that all necessary fields are present in the 'messages' section, focusing on core elements + if "messages" in result: + for message in result["messages"]: + assert "id" in message, "Message 'id' is missing" + assert "text" in message, "Message 'text' is missing" + assert "role" in message, "Message 'role' is missing" + assert "created_at" in message, "Message 'created_at' is missing" + assert "updated_at" in message, "Message 'updated_at' is missing" + + # Optionally check that 'created_at' and 'updated_at' are in ISO 8601 format + assert isinstance(result["created_at"], str), "Expected 'created_at' to be a string" + assert isinstance(result["updated_at"], str), "Expected 'updated_at' to be a string" + + # Optionally check for presence of any required metadata or ensure it is null if expected + assert "metadata_" in result, "'metadata_' field is missing" + assert result["metadata_"] is None, "'metadata_' should be null" + + # Assert that the agent name is as expected (if defined) + assert result.get("name") == sarah_agent.name, "Expected agent 'name' to not be None, but found something else" + + print(json.dumps(result, indent=4)) + + +def test_agent_deserialization_basic(local_client, server, sarah_agent, default_user): + """Test deserializing JSON into an Agent instance.""" + # Send a message first + sarah_agent = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user) + result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user) + + # Delete the agent + server.agent_manager.delete_agent(sarah_agent.id, actor=default_user) + + agent_state = server.agent_manager.deserialize(serialized_agent=result, actor=default_user) + + assert agent_state.name == sarah_agent.name + assert len(agent_state.message_ids) == len(sarah_agent.message_ids) diff --git a/tests/test_client.py b/tests/test_client.py index b727f77a..c53ac781 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -104,7 +104,7 @@ def search_agent_two(client: Union[LocalClient, RESTClient]): @pytest.fixture(autouse=True) def clear_tables(): """Clear the sandbox tables before each test.""" - from letta.server.server import db_context + from letta.server.db import db_context with db_context() as session: session.execute(delete(SandboxEnvironmentVariable)) diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 997766f9..00ee65ba 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -90,7 +90,7 @@ def client(request): @pytest.fixture(autouse=True) def clear_tables(): """Fixture to clear the organization table before each test.""" - from letta.server.server import db_context + from letta.server.db import db_context with db_context() as session: session.execute(delete(FileMetadata)) diff --git a/tests/test_managers.py b/tests/test_managers.py index cce3e449..8d078acf 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -5,35 +5,13 @@ from datetime import datetime, timedelta import pytest from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction -from sqlalchemy import delete from sqlalchemy.exc import IntegrityError from letta.config import LettaConfig from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS from letta.embeddings import embedding_model from letta.functions.functions import derive_openai_json_schema, parse_source_code -from letta.orm import ( - Agent, - AgentPassage, - Block, - BlocksAgents, - FileMetadata, - Job, - JobMessage, - Message, - Organization, - Provider, - SandboxConfig, - SandboxEnvironmentVariable, - Source, - SourcePassage, - SourcesAgents, - Step, - Tool, - ToolsAgents, - User, -) -from letta.orm.agents_tags import AgentsTags +from letta.orm import Base from letta.orm.enums import JobType, ToolType from letta.orm.errors import NoResultFound, UniqueConstraintViolationError from letta.schemas.agent import CreateAgent, UpdateAgent @@ -81,30 +59,13 @@ USING_SQLITE = not bool(os.getenv("LETTA_PG_URI")) @pytest.fixture(autouse=True) -def clear_tables(server: SyncServer): - """Fixture to clear the organization table before each test.""" - with server.organization_manager.session_maker() as session: - session.execute(delete(Message)) - session.execute(delete(AgentPassage)) - session.execute(delete(SourcePassage)) - session.execute(delete(JobMessage)) # Clear JobMessage first - session.execute(delete(Job)) - session.execute(delete(ToolsAgents)) # Clear ToolsAgents first - session.execute(delete(BlocksAgents)) - session.execute(delete(SourcesAgents)) - session.execute(delete(AgentsTags)) - session.execute(delete(SandboxEnvironmentVariable)) - session.execute(delete(SandboxConfig)) - session.execute(delete(Block)) - session.execute(delete(FileMetadata)) - session.execute(delete(Source)) - session.execute(delete(Tool)) # Clear all records from the Tool table - session.execute(delete(Agent)) - session.execute(delete(User)) # Clear all records from the user table - session.execute(delete(Step)) - session.execute(delete(Provider)) - session.execute(delete(Organization)) # Clear all records from the organization table - session.commit() # Commit the deletion +def clear_tables(): + from letta.server.db import db_context + + with db_context() as session: + for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues + session.execute(table.delete()) # Truncate table + session.commit() @pytest.fixture