diff --git a/compose.tracing.yaml b/compose.tracing.yaml new file mode 100644 index 00000000..80d6a3c1 --- /dev/null +++ b/compose.tracing.yaml @@ -0,0 +1,18 @@ +services: + letta_server: + environment: + - OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4317 + + otel-collector: + image: otel/opentelemetry-collector-contrib:0.92.0 + command: ["--config=/etc/otel-collector-config.yaml"] + volumes: + - ./otel-collector-config.yaml:/etc/otel-collector-config.yaml + environment: + - CLICKHOUSE_ENDPOINT=${CLICKHOUSE_ENDPOINT} + - CLICKHOUSE_DATABASE=${CLICKHOUSE_DATABASE} + - CLICKHOUSE_USER=${CLICKHOUSE_USER} + - CLICKHOUSE_PASSWORD=${CLICKHOUSE_PASSWORD} + ports: + - "4317:4317" + - "4318:4318" diff --git a/compose.yaml b/compose.yaml index 0ecdadb1..f6d13abc 100644 --- a/compose.yaml +++ b/compose.yaml @@ -49,9 +49,12 @@ services: - VLLM_API_BASE=${VLLM_API_BASE} - OPENLLM_AUTH_TYPE=${OPENLLM_AUTH_TYPE} - OPENLLM_API_KEY=${OPENLLM_API_KEY} - #volumes: - #- ./configs/server_config.yaml:/root/.letta/config # config file - #- ~/.letta/credentials:/root/.letta/credentials # credentials file + # volumes: + # - ./configs/server_config.yaml:/root/.letta/config # config file + # - ~/.letta/credentials:/root/.letta/credentials # credentials file + # Uncomment this line to mount a local directory for tool execution, and specify the mount path + # before running docker compose: `export LETTA_SANDBOX_MOUNT_PATH=$PWD/directory` + # - ${LETTA_SANDBOX_MOUNT_PATH:?}:/root/.letta/tool_execution_dir # mounted volume for tool execution letta_nginx: hostname: letta-nginx image: nginx:stable-alpine3.17-slim diff --git a/letta/__init__.py b/letta/__init__.py index e4e33377..e871e65e 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.35" +__version__ = "0.6.36" # import clients from letta.client.client import LocalClient, RESTClient, create_client diff --git a/letta/client/client.py b/letta/client/client.py index 58f680c9..e26ffa0a 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1,4 +1,5 @@ import logging +import sys import time from typing import Callable, Dict, Generator, List, Optional, Union @@ -40,6 +41,16 @@ from letta.schemas.tool_rule import BaseToolRule from letta.server.rest_api.interface import QueuingInterface from letta.utils import get_human_text, get_persona_text +# Print deprecation notice in yellow when module is imported +print( + "\n\n\033[93m" + + "DEPRECATION WARNING: This legacy Python client has been deprecated and will be removed in a future release.\n" + + "Please migrate to the new official python SDK by running: pip install letta-client\n" + + "For further documentation, visit: https://docs.letta.com/api-reference/overview#python-sdk" + + "\033[0m\n\n", + file=sys.stderr, +) + def create_client(base_url: Optional[str] = None, token: Optional[str] = None): if base_url is None: diff --git a/letta/constants.py b/letta/constants.py index 95db4282..e06984a3 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -2,7 +2,7 @@ import os from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta") -LETTA_DIR_TOOL_SANDBOX = os.path.join(LETTA_DIR, "tool_sandbox_dir") +LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir") ADMIN_PREFIX = "/v1/admin" API_PREFIX = "/v1" @@ -146,6 +146,9 @@ MESSAGE_SUMMARY_WARNING_STR = " ".join( # "Remember to pass request_heartbeat = true if you would like to send a message immediately after.", ] ) +DATA_SOURCE_ATTACH_ALERT = ( + "[ALERT] New data was just uploaded to archival memory. You can view this data by calling the archival_memory_search tool." +) # The ackknowledgement message used in the summarize sequence MESSAGE_SUMMARY_REQUEST_ACK = "Understood, I will respond with a summary of the message (and only the summary, nothing else) once I receive the conversation history. I'm ready." diff --git a/letta/log.py b/letta/log.py index 0d4ad8e1..8a2506ac 100644 --- a/letta/log.py +++ b/letta/log.py @@ -54,7 +54,7 @@ DEVELOPMENT_LOGGING = { "propagate": True, # Let logs bubble up to root }, "uvicorn": { - "level": "DEBUG", + "level": "INFO", "handlers": ["console"], "propagate": True, }, diff --git a/letta/schemas/letta_base.py b/letta/schemas/letta_base.py index d6850933..5d2a3da3 100644 --- a/letta/schemas/letta_base.py +++ b/letta/schemas/letta_base.py @@ -38,15 +38,15 @@ class LettaBase(BaseModel): description=cls._id_description(prefix), pattern=cls._id_regex_pattern(prefix), examples=[cls._id_example(prefix)], - default_factory=cls._generate_id, + default_factory=cls.generate_id, ) @classmethod - def _generate_id(cls, prefix: Optional[str] = None) -> str: + def generate_id(cls, prefix: Optional[str] = None) -> str: prefix = prefix or cls.__id_prefix__ return f"{prefix}-{uuid.uuid4()}" - # def _generate_id(self) -> str: + # def generate_id(self) -> str: # return f"{self.__id_prefix__}-{uuid.uuid4()}" @classmethod diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index b7917038..9084c729 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -27,7 +27,7 @@ class Provider(ProviderBase): def resolve_identifier(self): if not self.id: - self.id = ProviderBase._generate_id(prefix=ProviderBase.__id_prefix__) + self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__) def list_llm_models(self) -> List[LLMConfig]: return [] diff --git a/letta/schemas/sandbox_config.py b/letta/schemas/sandbox_config.py index 51f13919..80e93c11 100644 --- a/letta/schemas/sandbox_config.py +++ b/letta/schemas/sandbox_config.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field, model_validator +from letta.constants import LETTA_TOOL_EXECUTION_DIR from letta.schemas.agent import AgentState from letta.schemas.letta_base import LettaBase, OrmMetadataBase from letta.settings import tool_settings @@ -71,7 +72,7 @@ class LocalSandboxConfig(BaseModel): if tool_settings.local_sandbox_dir: data["sandbox_dir"] = tool_settings.local_sandbox_dir else: - data["sandbox_dir"] = "~/.letta" + data["sandbox_dir"] = LETTA_TOOL_EXECUTION_DIR return data diff --git a/letta/serialize_schemas/agent.py b/letta/serialize_schemas/agent.py index 036adf44..dfb3ff0e 100644 --- a/letta/serialize_schemas/agent.py +++ b/letta/serialize_schemas/agent.py @@ -1,9 +1,16 @@ -from marshmallow import fields +from typing import Dict + +from marshmallow import fields, post_dump from letta.orm import Agent +from letta.schemas.agent import AgentState as PydanticAgentState +from letta.schemas.user import User from letta.serialize_schemas.base import BaseSchema +from letta.serialize_schemas.block import SerializedBlockSchema from letta.serialize_schemas.custom_fields import EmbeddingConfigField, LLMConfigField, ToolRulesField from letta.serialize_schemas.message import SerializedMessageSchema +from letta.serialize_schemas.tool import SerializedToolSchema +from letta.server.db import SessionLocal class SerializedAgentSchema(BaseSchema): @@ -12,25 +19,51 @@ class SerializedAgentSchema(BaseSchema): Excludes relational fields. """ + __pydantic_model__ = PydanticAgentState + llm_config = LLMConfigField() embedding_config = EmbeddingConfigField() tool_rules = ToolRulesField() messages = fields.List(fields.Nested(SerializedMessageSchema)) + core_memory = fields.List(fields.Nested(SerializedBlockSchema)) + tools = fields.List(fields.Nested(SerializedToolSchema)) - def __init__(self, *args, session=None, **kwargs): - super().__init__(*args, **kwargs) - if session: - self.session = session + def __init__(self, *args, session: SessionLocal, actor: User, **kwargs): + super().__init__(*args, actor=actor, **kwargs) + self.session = session - # propagate session to nested schemas - for field_name, field_obj in self.fields.items(): - if isinstance(field_obj, fields.List) and hasattr(field_obj.inner, "schema"): - field_obj.inner.schema.session = session - elif hasattr(field_obj, "schema"): - field_obj.schema.session = session + # Propagate session and actor to nested schemas automatically + for field in self.fields.values(): + if isinstance(field, fields.List) and isinstance(field.inner, fields.Nested): + field.inner.schema.session = session + field.inner.schema.actor = actor + elif isinstance(field, fields.Nested): + field.schema.session = session + field.schema.actor = actor + + @post_dump + def sanitize_ids(self, data: Dict, **kwargs): + data = super().sanitize_ids(data, **kwargs) + + # Remap IDs of messages + # Need to do this in post, so we can correctly map the in-context message IDs + # TODO: Remap message_ids to reference objects, not just be a list + id_remapping = dict() + for message in data.get("messages"): + message_id = message.get("id") + if message_id not in id_remapping: + id_remapping[message_id] = SerializedMessageSchema.__pydantic_model__.generate_id() + message["id"] = id_remapping[message_id] + else: + raise ValueError(f"Duplicate message IDs in agent.messages: {message_id}") + + # Remap in context message ids + data["message_ids"] = [id_remapping[message_id] for message_id in data.get("message_ids")] + + return data class Meta(BaseSchema.Meta): model = Agent # TODO: Serialize these as well... - exclude = ("tools", "sources", "core_memory", "tags", "source_passages", "agent_passages", "organization") + exclude = BaseSchema.Meta.exclude + ("sources", "tags", "source_passages", "agent_passages") diff --git a/letta/serialize_schemas/base.py b/letta/serialize_schemas/base.py index b64e76e2..e4051408 100644 --- a/letta/serialize_schemas/base.py +++ b/letta/serialize_schemas/base.py @@ -1,4 +1,10 @@ +from typing import Dict, Optional + +from marshmallow import post_dump, pre_load from marshmallow_sqlalchemy import SQLAlchemyAutoSchema +from sqlalchemy.inspection import inspect + +from letta.schemas.user import User class BaseSchema(SQLAlchemyAutoSchema): @@ -7,6 +13,41 @@ class BaseSchema(SQLAlchemyAutoSchema): This ensures all schemas share the same session. """ + __pydantic_model__ = None + sensitive_ids = {"_created_by_id", "_last_updated_by_id"} + sensitive_relationships = {"organization"} + id_scramble_placeholder = "xxx" + + def __init__(self, *args, actor: Optional[User] = None, **kwargs): + super().__init__(*args, **kwargs) + self.actor = actor + + @post_dump + def sanitize_ids(self, data: Dict, **kwargs): + data["id"] = self.__pydantic_model__.generate_id() + + for sensitive_id in BaseSchema.sensitive_ids.union(BaseSchema.sensitive_relationships): + if sensitive_id in data: + data[sensitive_id] = BaseSchema.id_scramble_placeholder + + return data + + @pre_load + def regenerate_ids(self, data: Dict, **kwargs): + if self.Meta.model: + mapper = inspect(self.Meta.model) + for sensitive_id in BaseSchema.sensitive_ids: + if sensitive_id in mapper.columns: + data[sensitive_id] = self.actor.id + + for relationship in BaseSchema.sensitive_relationships: + if relationship in mapper.relationships: + data[relationship] = self.actor.organization_id + + return data + class Meta: + model = None include_relationships = True load_instance = True + exclude = () diff --git a/letta/serialize_schemas/block.py b/letta/serialize_schemas/block.py new file mode 100644 index 00000000..41139121 --- /dev/null +++ b/letta/serialize_schemas/block.py @@ -0,0 +1,15 @@ +from letta.orm.block import Block +from letta.schemas.block import Block as PydanticBlock +from letta.serialize_schemas.base import BaseSchema + + +class SerializedBlockSchema(BaseSchema): + """ + Marshmallow schema for serializing/deserializing Block objects. + """ + + __pydantic_model__ = PydanticBlock + + class Meta(BaseSchema.Meta): + model = Block + exclude = BaseSchema.Meta.exclude + ("agents",) diff --git a/letta/serialize_schemas/message.py b/letta/serialize_schemas/message.py index 58d055d6..f1300d24 100644 --- a/letta/serialize_schemas/message.py +++ b/letta/serialize_schemas/message.py @@ -1,4 +1,9 @@ +from typing import Dict + +from marshmallow import post_dump + from letta.orm.message import Message +from letta.schemas.message import Message as PydanticMessage from letta.serialize_schemas.base import BaseSchema from letta.serialize_schemas.custom_fields import ToolCallField @@ -8,8 +13,17 @@ class SerializedMessageSchema(BaseSchema): Marshmallow schema for serializing/deserializing Message objects. """ + __pydantic_model__ = PydanticMessage + tool_calls = ToolCallField() + @post_dump + def sanitize_ids(self, data: Dict, **kwargs): + # We don't want to remap here + # Because of the way that message_ids is just a JSON field on agents + # We need to wait for the agent dumps, and then keep track of all the message IDs we remapped + return data + class Meta(BaseSchema.Meta): model = Message - exclude = ("step", "job_message") + exclude = BaseSchema.Meta.exclude + ("step", "job_message", "agent") diff --git a/letta/serialize_schemas/tool.py b/letta/serialize_schemas/tool.py new file mode 100644 index 00000000..fe2debe8 --- /dev/null +++ b/letta/serialize_schemas/tool.py @@ -0,0 +1,15 @@ +from letta.orm import Tool +from letta.schemas.tool import Tool as PydanticTool +from letta.serialize_schemas.base import BaseSchema + + +class SerializedToolSchema(BaseSchema): + """ + Marshmallow schema for serializing/deserializing Tool objects. + """ + + __pydantic_model__ = PydanticTool + + class Meta(BaseSchema.Meta): + model = Tool + exclude = BaseSchema.Meta.exclude diff --git a/letta/server/rest_api/routers/v1/identities.py b/letta/server/rest_api/routers/v1/identities.py index a8a9ad27..b22d355a 100644 --- a/letta/server/rest_api/routers/v1/identities.py +++ b/letta/server/rest_api/routers/v1/identities.py @@ -42,6 +42,8 @@ def list_identities( ) except HTTPException: raise + except NoResultFound as e: + raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") return identities @@ -75,11 +77,11 @@ def create_identity( except UniqueConstraintViolationError: if identity.project_id: raise HTTPException( - status_code=400, + status_code=409, detail=f"An identity with identifier key {identity.identifier_key} already exists for project {identity.project_id}", ) else: - raise HTTPException(status_code=400, detail=f"An identity with identifier key {identity.identifier_key} already exists") + raise HTTPException(status_code=409, detail=f"An identity with identifier key {identity.identifier_key} already exists") except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") @@ -96,6 +98,8 @@ def upsert_identity( return server.identity_manager.upsert_identity(identity=identity, actor=actor) except HTTPException: raise + except NoResultFound as e: + raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") @@ -112,6 +116,8 @@ def modify_identity( return server.identity_manager.update_identity(identity_id=identity_id, identity=identity, actor=actor) except HTTPException: raise + except NoResultFound as e: + raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") @@ -125,5 +131,12 @@ def delete_identity( """ Delete an identity by its identifier key """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - server.identity_manager.delete_identity(identity_id=identity_id, actor=actor) + try: + actor = server.user_manager.get_user_or_default(user_id=actor_id) + server.identity_manager.delete_identity(identity_id=identity_id, actor=actor) + except HTTPException: + raise + except NoResultFound as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") diff --git a/letta/server/startup.sh b/letta/server/startup.sh index d4523cce..44b790f1 100755 --- a/letta/server/startup.sh +++ b/letta/server/startup.sh @@ -38,6 +38,15 @@ if ! alembic upgrade head; then fi echo "Database migration completed successfully." +# Set permissions for tool execution directory if configured +if [ -n "$LETTA_SANDBOX_MOUNT_PATH" ]; then + if ! chmod 777 "$LETTA_SANDBOX_MOUNT_PATH"; then + echo "ERROR: Failed to set permissions for tool execution directory at: $LETTA_SANDBOX_MOUNT_PATH" + echo "Please check that the directory exists and is accessible" + exit 1 + fi +fi + # If ADE is enabled, add the --ade flag to the command CMD="letta server --host $HOST --port $PORT" if [ "${SECURE:-false}" = "true" ]; then diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index d57ab21c..06e0ef22 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional import numpy as np from sqlalchemy import Select, and_, func, literal, or_, select, union_all -from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM, MULTI_AGENT_TOOLS +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DATA_SOURCE_ATTACH_ALERT, MAX_EMBEDDING_DIM, MULTI_AGENT_TOOLS from letta.embeddings import embedding_model from letta.helpers.datetime_helpers import get_utc_time from letta.log import get_logger @@ -35,6 +35,7 @@ from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule from letta.schemas.tool_rule import ToolRule as PydanticToolRule from letta.schemas.user import User as PydanticUser from letta.serialize_schemas import SerializedAgentSchema +from letta.serialize_schemas.tool import SerializedToolSchema from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import ( _process_relationship, @@ -394,18 +395,28 @@ class AgentManager: with self.session_maker() as session: # Retrieve the agent agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - schema = SerializedAgentSchema(session=session) + schema = SerializedAgentSchema(session=session, actor=actor) return schema.dump(agent) @enforce_types - def deserialize(self, serialized_agent: dict, actor: PydanticUser) -> PydanticAgentState: - # TODO: Use actor to override fields + def deserialize(self, serialized_agent: dict, actor: PydanticUser, mark_as_copy: bool = True) -> PydanticAgentState: + tool_data_list = serialized_agent.pop("tools", []) + with self.session_maker() as session: - schema = SerializedAgentSchema(session=session) + schema = SerializedAgentSchema(session=session, actor=actor) agent = schema.load(serialized_agent, session=session) - agent.organization_id = actor.organization_id - agent = agent.create(session, actor=actor) - return agent.to_pydantic() + if mark_as_copy: + agent.name += "_copy" + agent.create(session, actor=actor) + pydantic_agent = agent.to_pydantic() + + # Need to do this separately as there's some fancy upsert logic that SqlAlchemy cannot handle + for tool_data in tool_data_list: + pydantic_tool = SerializedToolSchema(actor=actor).load(tool_data, transient=True).to_pydantic() + pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor) + pydantic_agent = self.attach_tool(agent_id=pydantic_agent.id, tool_id=pydantic_tool.id, actor=actor) + + return pydantic_agent # ====================================================================================================================== # Per Agent Environment Variable Management @@ -670,6 +681,7 @@ class AgentManager: ValueError: If either agent or source doesn't exist IntegrityError: If the source is already attached to the agent """ + with self.session_maker() as session: # Verify both agent and source exist and user has permission to access them agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) @@ -687,7 +699,27 @@ class AgentManager: # Commit the changes agent.update(session, actor=actor) - return agent.to_pydantic() + + # Add system messsage alert to agent + self.append_system_message( + agent_id=agent_id, + content=DATA_SOURCE_ATTACH_ALERT, + actor=actor, + ) + + return agent.to_pydantic() + + @enforce_types + def append_system_message(self, agent_id: str, content: str, actor: PydanticUser): + + # get the agent + agent = self.get_agent_by_id(agent_id=agent_id, actor=actor) + message = PydanticMessage.dict_to_message( + agent_id=agent.id, user_id=actor.id, model=agent.llm_config.model, openai_message_dict={"role": "system", "content": content} + ) + + # update agent in-context message IDs + self.append_to_in_context_messages(messages=[message], agent_id=agent_id, actor=actor) @enforce_types def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]: diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 26f0bee5..26f7c27b 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -46,7 +46,7 @@ class MessageManager: # Sort results directly based on message_ids result_dict = {msg.id: msg.to_pydantic() for msg in results} - return [result_dict[msg_id] for msg_id in message_ids] + return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids])) @enforce_types def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage: diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index e4e01111..9feaf2a0 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional -from letta.constants import LETTA_DIR_TOOL_SANDBOX +from letta.constants import LETTA_TOOL_EXECUTION_DIR from letta.log import get_logger from letta.orm.errors import NoResultFound from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel @@ -35,7 +35,7 @@ class SandboxConfigManager: default_config = {} # Empty else: # TODO: May want to move this to environment variables v.s. persisting in database - default_local_sandbox_path = LETTA_DIR_TOOL_SANDBOX + default_local_sandbox_path = LETTA_TOOL_EXECUTION_DIR default_config = LocalSandboxConfig(sandbox_dir=default_local_sandbox_path).model_dump(exclude_none=True) sandbox_config = self.create_or_update_sandbox_config(SandboxConfigCreate(config=default_config), actor=actor) diff --git a/otel-collector-config.yaml b/otel-collector-config.yaml new file mode 100644 index 00000000..d13164ea --- /dev/null +++ b/otel-collector-config.yaml @@ -0,0 +1,32 @@ +receivers: + otlp: + protocols: + grpc: + endpoint: 0.0.0.0:4317 + http: + endpoint: 0.0.0.0:4318 + +processors: + batch: + timeout: 1s + send_batch_size: 1024 + +exporters: + clickhouse: + endpoint: ${CLICKHOUSE_ENDPOINT} + username: ${CLICKHOUSE_USER} + password: ${CLICKHOUSE_PASSWORD} + database: ${CLICKHOUSE_DATABASE} + timeout: 10s + retry_on_failure: + enabled: true + initial_interval: 5s + max_interval: 30s + max_elapsed_time: 300s + +service: + pipelines: + traces: + receivers: [otlp] + processors: [batch] + exporters: [clickhouse] diff --git a/poetry.lock b/poetry.lock index dbf1d59c..32158702 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,13 +2,13 @@ [[package]] name = "aiohappyeyeballs" -version = "2.4.6" +version = "2.4.8" description = "Happy Eyeballs for asyncio" optional = false python-versions = ">=3.9" files = [ - {file = "aiohappyeyeballs-2.4.6-py3-none-any.whl", hash = "sha256:147ec992cf873d74f5062644332c539fcd42956dc69453fe5204195e560517e1"}, - {file = "aiohappyeyeballs-2.4.6.tar.gz", hash = "sha256:9b05052f9042985d32ecbe4b59a77ae19c006a78f1344d7fdad69d28ded3d0b0"}, + {file = "aiohappyeyeballs-2.4.8-py3-none-any.whl", hash = "sha256:6cac4f5dd6e34a9644e69cf9021ef679e4394f54e58a183056d12009e42ea9e3"}, + {file = "aiohappyeyeballs-2.4.8.tar.gz", hash = "sha256:19728772cb12263077982d2f55453babd8bec6a052a926cd5c0c42796da8bf62"}, ] [[package]] @@ -130,22 +130,22 @@ frozenlist = ">=1.1.0" [[package]] name = "alembic" -version = "1.14.1" +version = "1.15.1" description = "A database migration tool for SQLAlchemy." optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "alembic-1.14.1-py3-none-any.whl", hash = "sha256:1acdd7a3a478e208b0503cd73614d5e4c6efafa4e73518bb60e4f2846a37b1c5"}, - {file = "alembic-1.14.1.tar.gz", hash = "sha256:496e888245a53adf1498fcab31713a469c65836f8de76e01399aa1c3e90dd213"}, + {file = "alembic-1.15.1-py3-none-any.whl", hash = "sha256:197de710da4b3e91cf66a826a5b31b5d59a127ab41bd0fc42863e2902ce2bbbe"}, + {file = "alembic-1.15.1.tar.gz", hash = "sha256:e1a1c738577bca1f27e68728c910cd389b9a92152ff91d902da649c192e30c49"}, ] [package.dependencies] Mako = "*" -SQLAlchemy = ">=1.3.0" -typing-extensions = ">=4" +SQLAlchemy = ">=1.4.0" +typing-extensions = ">=4.12" [package.extras] -tz = ["backports.zoneinfo", "tzdata"] +tz = ["tzdata"] [[package]] name = "annotated-types" @@ -447,17 +447,17 @@ files = [ [[package]] name = "boto3" -version = "1.37.5" +version = "1.37.6" description = "The AWS SDK for Python" optional = true python-versions = ">=3.8" files = [ - {file = "boto3-1.37.5-py3-none-any.whl", hash = "sha256:12166353519aca0cc8d9dcfbbb0d38f8915955a5912b8cb241b2b2314f0dbc14"}, - {file = "boto3-1.37.5.tar.gz", hash = "sha256:ae6e7048beeaa4478368e554a4b290e3928beb0ae8d8767d108d72381a81af30"}, + {file = "boto3-1.37.6-py3-none-any.whl", hash = "sha256:4c661389e68437a3fbc1f63decea24b88f7175e022c68622848d47fdf6e0144f"}, + {file = "boto3-1.37.6.tar.gz", hash = "sha256:e2f4a1edb7e6dbd541c2962117e1c6fea8d5a42788c441a958700a43a3ca7c47"}, ] [package.dependencies] -botocore = ">=1.37.5,<1.38.0" +botocore = ">=1.37.6,<1.38.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.11.0,<0.12.0" @@ -466,13 +466,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.37.5" +version = "1.37.6" description = "Low-level, data-driven core of boto 3." optional = true python-versions = ">=3.8" files = [ - {file = "botocore-1.37.5-py3-none-any.whl", hash = "sha256:e5cfbb8026d5b4fadd9b3a18b61d238a41a8b8f620ab75873dc1467d456150d6"}, - {file = "botocore-1.37.5.tar.gz", hash = "sha256:f8f526d33ae74d242c577e0440b57b9ec7d53edd41db211155ec8087fe7a5a21"}, + {file = "botocore-1.37.6-py3-none-any.whl", hash = "sha256:cd282fe9c8adbb55a08c7290982a98ac6cc4507fa1c493f48bc43fd6c8376a57"}, + {file = "botocore-1.37.6.tar.gz", hash = "sha256:2cb121a403cbec047d76e2401a402a6b2efd3309169037fbac588e8f7125aec4"}, ] [package.dependencies] @@ -2670,30 +2670,24 @@ test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout" [[package]] name = "langchain" -version = "0.3.19" +version = "0.3.20" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "langchain-0.3.19-py3-none-any.whl", hash = "sha256:1e16d97db9106640b7de4c69f8f5ed22eeda56b45b9241279e83f111640eff16"}, - {file = "langchain-0.3.19.tar.gz", hash = "sha256:b96f8a445f01d15d522129ffe77cc89c8468dbd65830d153a676de8f6b899e7b"}, + {file = "langchain-0.3.20-py3-none-any.whl", hash = "sha256:273287f8e61ffdf7e811cf8799e6a71e9381325b8625fd6618900faba79cfdd0"}, + {file = "langchain-0.3.20.tar.gz", hash = "sha256:edcc3241703e1f6557ef5a5c35cd56f9ccc25ff12e38b4829c66d94971737a93"}, ] [package.dependencies] -aiohttp = ">=3.8.3,<4.0.0" async-timeout = {version = ">=4.0.0,<5.0.0", markers = "python_version < \"3.11\""} -langchain-core = ">=0.3.35,<1.0.0" +langchain-core = ">=0.3.41,<1.0.0" langchain-text-splitters = ">=0.3.6,<1.0.0" langsmith = ">=0.1.17,<0.4" -numpy = [ - {version = ">=1.26.4,<2", markers = "python_version < \"3.12\""}, - {version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""}, -] pydantic = ">=2.7.4,<3.0.0" PyYAML = ">=5.3" requests = ">=2,<3" SQLAlchemy = ">=1.4,<3" -tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10" [package.extras] anthropic = ["langchain-anthropic"] @@ -2714,26 +2708,23 @@ xai = ["langchain-xai"] [[package]] name = "langchain-community" -version = "0.3.18" +version = "0.3.19" description = "Community contributed LangChain integrations." optional = true python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_community-0.3.18-py3-none-any.whl", hash = "sha256:0d4a70144a1750045c4f726f9a43379ed2484178f76e4b8295bcef3a7fdf41d5"}, - {file = "langchain_community-0.3.18.tar.gz", hash = "sha256:fa2889a8f0b2d22b5c306fd1b070c0970e1f11b604bf55fad2f4a1d0bf68a077"}, + {file = "langchain_community-0.3.19-py3-none-any.whl", hash = "sha256:268ce7b322c0d1961d7bab1a9419d6ff30c99ad09487dca48d47389b69875b16"}, + {file = "langchain_community-0.3.19.tar.gz", hash = "sha256:fc100b6d4d6523566a957cdc306b0500e4982d5b221b98f67432da18ba5b2bf5"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" dataclasses-json = ">=0.5.7,<0.7" httpx-sse = ">=0.4.0,<1.0.0" -langchain = ">=0.3.19,<1.0.0" -langchain-core = ">=0.3.37,<1.0.0" +langchain = ">=0.3.20,<1.0.0" +langchain-core = ">=0.3.41,<1.0.0" langsmith = ">=0.1.125,<0.4" -numpy = [ - {version = ">=1.26.4,<2", markers = "python_version < \"3.12\""}, - {version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""}, -] +numpy = ">=1.26.2,<3" pydantic-settings = ">=2.4.0,<3.0.0" PyYAML = ">=5.3" requests = ">=2,<3" @@ -2742,13 +2733,13 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10" [[package]] name = "langchain-core" -version = "0.3.40" +version = "0.3.41" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_core-0.3.40-py3-none-any.whl", hash = "sha256:9f31358741f10a13db8531e8288b8a5ae91904018c5c2e6f739d6645a98fca03"}, - {file = "langchain_core-0.3.40.tar.gz", hash = "sha256:893a238b38491967c804662c1ec7c3e6ebaf223d1125331249c3cf3862ff2746"}, + {file = "langchain_core-0.3.41-py3-none-any.whl", hash = "sha256:1a27cca5333bae7597de4004fb634b5f3e71667a3da6493b94ce83bcf15a23bd"}, + {file = "langchain_core-0.3.41.tar.gz", hash = "sha256:d3ee9f3616ebbe7943470ade23d4a04e1729b1512c0ec55a4a07bd2ac64dedb4"}, ] [package.dependencies] @@ -3074,13 +3065,13 @@ llama-index-program-openai = ">=0.3.0,<0.4.0" [[package]] name = "llama-index-readers-file" -version = "0.4.5" +version = "0.4.6" description = "llama-index readers file integration" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "llama_index_readers_file-0.4.5-py3-none-any.whl", hash = "sha256:704ac6b549f0ec59c0bd796007fceced2fff89a44b03d7ee36bce2d26b39e526"}, - {file = "llama_index_readers_file-0.4.5.tar.gz", hash = "sha256:3ce5c8ad7f285bb7ff828c5b2e20088856ac65cf96640287eca770b69a21df88"}, + {file = "llama_index_readers_file-0.4.6-py3-none-any.whl", hash = "sha256:5b5589a528bd3bdf41798406ad0b3ad1a55f28085ff9078a00b61567ff29acba"}, + {file = "llama_index_readers_file-0.4.6.tar.gz", hash = "sha256:50119fdffb7f5aa4638dda2227c79ad6a5f326b9c55a7e46054df99f46a709e0"}, ] [package.dependencies] @@ -3655,13 +3646,13 @@ files = [ [[package]] name = "openai" -version = "1.65.2" +version = "1.65.3" description = "The official Python library for the openai API" optional = false python-versions = ">=3.8" files = [ - {file = "openai-1.65.2-py3-none-any.whl", hash = "sha256:27d9fe8de876e31394c2553c4e6226378b6ed85e480f586ccfe25b7193fb1750"}, - {file = "openai-1.65.2.tar.gz", hash = "sha256:729623efc3fd91c956f35dd387fa5c718edd528c4bed9f00b40ef290200fb2ce"}, + {file = "openai-1.65.3-py3-none-any.whl", hash = "sha256:a155fa5d60eccda516384d3d60d923e083909cc126f383fe4a350f79185c232a"}, + {file = "openai-1.65.3.tar.gz", hash = "sha256:9b7cd8f79140d03d77f4ed8aeec6009be5dcd79bbc02f03b0e8cd83356004f71"}, ] [package.dependencies] @@ -5696,20 +5687,20 @@ pyasn1 = ">=0.1.3" [[package]] name = "s3transfer" -version = "0.11.3" +version = "0.11.4" description = "An Amazon S3 Transfer Manager" optional = true python-versions = ">=3.8" files = [ - {file = "s3transfer-0.11.3-py3-none-any.whl", hash = "sha256:ca855bdeb885174b5ffa95b9913622459d4ad8e331fc98eb01e6d5eb6a30655d"}, - {file = "s3transfer-0.11.3.tar.gz", hash = "sha256:edae4977e3a122445660c7c114bba949f9d191bae3b34a096f18a1c8c354527a"}, + {file = "s3transfer-0.11.4-py3-none-any.whl", hash = "sha256:ac265fa68318763a03bf2dc4f39d5cbd6a9e178d81cc9483ad27da33637e320d"}, + {file = "s3transfer-0.11.4.tar.gz", hash = "sha256:559f161658e1cf0a911f45940552c696735f5c74e64362e515f333ebed87d679"}, ] [package.dependencies] -botocore = ">=1.36.0,<2.0a.0" +botocore = ">=1.37.4,<2.0a.0" [package.extras] -crt = ["botocore[crt] (>=1.36.0,<2.0a.0)"] +crt = ["botocore[crt] (>=1.37.4,<2.0a.0)"] [[package]] name = "scramp" diff --git a/pyproject.toml b/pyproject.toml index 72f2b4b7..29a79b34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.6.35" +version = "0.6.36" packages = [ {include = "letta"}, ] diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index ade9f2f4..2651c08c 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -1,15 +1,27 @@ +import difflib import json +from datetime import datetime, timezone +from typing import Any, Dict, List, Mapping import pytest +from rich.console import Console +from rich.syntax import Syntax from letta import create_client from letta.config import LettaConfig from letta.orm import Base -from letta.schemas.agent import CreateAgent +from letta.schemas.agent import AgentState, CreateAgent +from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import MessageRole from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import MessageCreate +from letta.schemas.organization import Organization +from letta.schemas.user import User from letta.server.server import SyncServer +console = Console() + def _clear_tables(): from letta.server.db import db_context @@ -58,80 +70,291 @@ def default_user(server: SyncServer, default_organization): @pytest.fixture -def sarah_agent(server: SyncServer, default_user, default_organization): +def other_organization(server: SyncServer): + """Fixture to create and return the default organization.""" + org = server.organization_manager.create_organization(pydantic_org=Organization(name="letta")) + yield org + + +@pytest.fixture +def other_user(server: SyncServer, other_organization): + """Fixture to create and return the default user within the default organization.""" + user = server.user_manager.create_user(pydantic_user=User(organization_id=other_organization.id, name="sarah")) + yield user + + +@pytest.fixture +def serialize_test_agent(server: SyncServer, default_user, default_organization): """Fixture to create and return a sample agent within the default organization.""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + agent_name = f"serialize_test_agent_{timestamp}" + + server.tool_manager.upsert_base_tools(actor=default_user) + agent_state = server.agent_manager.create_agent( agent_create=CreateAgent( - name="sarah_agent", - memory_blocks=[], + name=agent_name, + memory_blocks=[ + CreateBlock( + value="Name: Caren", + label="human", + ), + ], llm_config=LLMConfig.default_config("gpt-4o-mini"), embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=True, ), actor=default_user, ) yield agent_state -def test_agent_serialization(server, sarah_agent, default_user): - """Test serializing an Agent instance to JSON.""" - result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user) - - # Assert that the result is a dictionary (JSON object) - assert isinstance(result, dict), "Expected a dictionary result" - - # Assert that the 'id' field is present and matches the agent's ID - assert "id" in result, "Agent 'id' is missing in the serialized result" - assert result["id"] == sarah_agent.id, f"Expected agent 'id' to be {sarah_agent.id}, but got {result['id']}" - - # Assert that the 'llm_config' and 'embedding_config' fields exist - assert "llm_config" in result, "'llm_config' is missing in the serialized result" - assert "embedding_config" in result, "'embedding_config' is missing in the serialized result" - - # Assert that 'messages' is a list - assert isinstance(result.get("messages", []), list), "'messages' should be a list" - - # Assert that the 'tool_exec_environment_variables' field is a list (empty or populated) - assert isinstance(result.get("tool_exec_environment_variables", []), list), "'tool_exec_environment_variables' should be a list" - - # Assert that the 'agent_type' is a valid string - assert isinstance(result.get("agent_type"), str), "'agent_type' should be a string" - - # Assert that the 'tool_rules' field is a list (even if empty) - assert isinstance(result.get("tool_rules", []), list), "'tool_rules' should be a list" - - # Check that all necessary fields are present in the 'messages' section, focusing on core elements - if "messages" in result: - for message in result["messages"]: - assert "id" in message, "Message 'id' is missing" - assert "text" in message, "Message 'text' is missing" - assert "role" in message, "Message 'role' is missing" - assert "created_at" in message, "Message 'created_at' is missing" - assert "updated_at" in message, "Message 'updated_at' is missing" - - # Optionally check that 'created_at' and 'updated_at' are in ISO 8601 format - assert isinstance(result["created_at"], str), "Expected 'created_at' to be a string" - assert isinstance(result["updated_at"], str), "Expected 'updated_at' to be a string" - - # Optionally check for presence of any required metadata or ensure it is null if expected - assert "metadata_" in result, "'metadata_' field is missing" - assert result["metadata_"] is None, "'metadata_' should be null" - - # Assert that the agent name is as expected (if defined) - assert result.get("name") == sarah_agent.name, "Expected agent 'name' to not be None, but found something else" - - print(json.dumps(result, indent=4)) +# Helper functions below -def test_agent_deserialization_basic(local_client, server, sarah_agent, default_user): +def dict_to_pretty_json(d: Dict[str, Any]) -> str: + """Convert a dictionary to a pretty JSON string with sorted keys, handling datetime objects.""" + return json.dumps(d, indent=2, sort_keys=True, default=_json_serializable) + + +def _json_serializable(obj: Any) -> Any: + """Convert non-serializable objects (like datetime) to a JSON-friendly format.""" + if isinstance(obj, datetime): + return obj.isoformat() # Convert datetime to ISO 8601 format + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") + + +def print_dict_diff(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> None: + """Prints a detailed colorized diff between two dictionaries.""" + json1 = dict_to_pretty_json(dict1).splitlines() + json2 = dict_to_pretty_json(dict2).splitlines() + + diff = list(difflib.unified_diff(json1, json2, fromfile="Expected", tofile="Actual", lineterm="")) + + if diff: + console.print("\nšŸ” [bold red]Dictionary Diff:[/bold red]") + diff_text = "\n".join(diff) + syntax = Syntax(diff_text, "diff", theme="monokai", line_numbers=False) + console.print(syntax) + else: + console.print("\nāœ… [bold green]No differences found in dictionaries.[/bold green]") + + +def has_same_prefix(value1: Any, value2: Any) -> bool: + """Check if two string values have the same major prefix (before the second hyphen).""" + if not isinstance(value1, str) or not isinstance(value2, str): + return False + + prefix1 = value1.split("-")[0] + prefix2 = value2.split("-")[0] + + return prefix1 == prefix2 + + +def compare_lists(list1: List[Any], list2: List[Any]) -> bool: + """Compare lists while handling unordered dictionaries inside.""" + if len(list1) != len(list2): + return False + + if all(isinstance(item, Mapping) for item in list1) and all(isinstance(item, Mapping) for item in list2): + return all(any(_compare_agent_state_model_dump(i1, i2, log=False) for i2 in list2) for i1 in list1) + + return sorted(list1) == sorted(list2) + + +def strip_datetime_fields(d: Dict[str, Any]) -> Dict[str, Any]: + """Remove datetime fields from a dictionary before comparison.""" + return {k: v for k, v in d.items() if not isinstance(v, datetime)} + + +def _log_mismatch(key: str, expected: Any, actual: Any, log: bool) -> None: + """Log detailed information about a mismatch.""" + if log: + print(f"\nšŸ”“ Mismatch Found in Key: '{key}'") + print(f"Expected: {expected}") + print(f"Actual: {actual}") + + if isinstance(expected, str) and isinstance(actual, str): + print("\nšŸ” String Diff:") + diff = difflib.ndiff(expected.splitlines(), actual.splitlines()) + print("\n".join(diff)) + + +def _compare_agent_state_model_dump(d1: Dict[str, Any], d2: Dict[str, Any], log: bool = True) -> bool: + """ + Compare two dictionaries with special handling: + - Keys in `ignore_prefix_fields` should match only by prefix. + - 'message_ids' lists should match in length only. + - Datetime fields are ignored. + - Order-independent comparison for lists of dicts. + """ + ignore_prefix_fields = {"id", "last_updated_by_id", "organization_id", "created_by_id"} + + # Remove datetime fields upfront + d1 = strip_datetime_fields(d1) + d2 = strip_datetime_fields(d2) + + if d1.keys() != d2.keys(): + _log_mismatch("dict_keys", set(d1.keys()), set(d2.keys())) + return False + + for key, v1 in d1.items(): + v2 = d2[key] + + if key in ignore_prefix_fields: + if v1 and v2 and not has_same_prefix(v1, v2): + _log_mismatch(key, v1, v2, log) + return False + elif key == "message_ids": + if not isinstance(v1, list) or not isinstance(v2, list) or len(v1) != len(v2): + _log_mismatch(key, v1, v2, log) + return False + elif isinstance(v1, Dict) and isinstance(v2, Dict): + if not _compare_agent_state_model_dump(v1, v2): + _log_mismatch(key, v1, v2, log) + return False + elif isinstance(v1, list) and isinstance(v2, list): + if not compare_lists(v1, v2): + _log_mismatch(key, v1, v2, log) + return False + elif v1 != v2: + _log_mismatch(key, v1, v2, log) + return False + + return True + + +def compare_agent_state(original: AgentState, copy: AgentState, mark_as_copy: bool) -> bool: + """Wrapper function that provides a default set of ignored prefix fields.""" + if not mark_as_copy: + assert original.name == copy.name + + return _compare_agent_state_model_dump(original.model_dump(exclude="name"), copy.model_dump(exclude="name")) + + +# Sanity tests for our agent model_dump verifier helpers + + +def test_sanity_identical_dicts(): + d1 = {"name": "Alice", "age": 30, "details": {"city": "New York"}} + d2 = {"name": "Alice", "age": 30, "details": {"city": "New York"}} + assert _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_different_dicts(): + d1 = {"name": "Alice", "age": 30} + d2 = {"name": "Bob", "age": 30} + assert not _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_ignored_id_fields(): + d1 = {"id": "user-abc123", "name": "Alice"} + d2 = {"id": "user-xyz789", "name": "Alice"} # Different ID, same prefix + assert _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_different_id_prefix_fails(): + d1 = {"id": "user-abc123"} + d2 = {"id": "admin-xyz789"} # Different prefix + assert not _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_nested_dicts(): + d1 = {"user": {"id": "user-123", "name": "Alice"}} + d2 = {"user": {"id": "user-456", "name": "Alice"}} # ID changes, but prefix matches + assert _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_list_handling(): + d1 = {"items": [1, 2, 3]} + d2 = {"items": [1, 2, 3]} + assert _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_list_mismatch(): + d1 = {"items": [1, 2, 3]} + d2 = {"items": [1, 2, 4]} + assert not _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_message_ids_length_check(): + d1 = {"message_ids": ["msg-123", "msg-456", "msg-789"]} + d2 = {"message_ids": ["msg-abc", "msg-def", "msg-ghi"]} # Same length, different values + assert _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_message_ids_different_length(): + d1 = {"message_ids": ["msg-123", "msg-456"]} + d2 = {"message_ids": ["msg-123"]} + assert not _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_datetime_fields(): + d1 = {"created_at": datetime(2025, 3, 4, 18, 25, 37, tzinfo=timezone.utc)} + d2 = {"created_at": datetime(2025, 3, 4, 18, 25, 37, tzinfo=timezone.utc)} + assert _compare_agent_state_model_dump(d1, d2) + + +def test_sanity_datetime_mismatch(): + d1 = {"created_at": datetime(2025, 3, 4, 18, 25, 37, tzinfo=timezone.utc)} + d2 = {"created_at": datetime(2025, 3, 4, 18, 25, 38, tzinfo=timezone.utc)} # One second difference + assert _compare_agent_state_model_dump(d1, d2) # Should ignore + + +# Agent serialize/deserialize tests + + +@pytest.mark.parametrize("mark_as_copy", [True, False]) +def test_mark_as_copy_simple(local_client, server, serialize_test_agent, default_user, other_user, mark_as_copy): """Test deserializing JSON into an Agent instance.""" - # Send a message first - sarah_agent = server.agent_manager.get_agent_by_id(agent_id=sarah_agent.id, actor=default_user) - result = server.agent_manager.serialize(agent_id=sarah_agent.id, actor=default_user) + result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) - # Delete the agent - server.agent_manager.delete_agent(sarah_agent.id, actor=default_user) + # Deserialize the agent + agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user, mark_as_copy=mark_as_copy) - agent_state = server.agent_manager.deserialize(serialized_agent=result, actor=default_user) + # Compare serialized representations to check for exact match + print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json())) + assert compare_agent_state(agent_copy, serialize_test_agent, mark_as_copy=mark_as_copy) - assert agent_state.name == sarah_agent.name - assert len(agent_state.message_ids) == len(sarah_agent.message_ids) + +def test_in_context_message_id_remapping(local_client, server, serialize_test_agent, default_user, other_user): + """Test deserializing JSON into an Agent instance.""" + result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) + + # Check remapping on message_ids and messages is consistent + assert sorted([m["id"] for m in result["messages"]]) == sorted(result["message_ids"]) + + # Deserialize the agent + agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user) + + # Make sure all the messages are able to be retrieved + in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_copy.id, actor=other_user) + assert len(in_context_messages) == len(result["message_ids"]) + assert sorted([m.id for m in in_context_messages]) == sorted(result["message_ids"]) + + +def test_agent_serialize_with_user_messages(local_client, server, serialize_test_agent, default_user, other_user): + """Test deserializing JSON into an Agent instance.""" + mark_as_copy = False + server.send_messages( + actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="hello")] + ) + result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) + + # Deserialize the agent + agent_copy = server.agent_manager.deserialize(serialized_agent=result, actor=other_user, mark_as_copy=mark_as_copy) + + # Get most recent original agent instance + serialize_test_agent = server.agent_manager.get_agent_by_id(agent_id=serialize_test_agent.id, actor=default_user) + + # Compare serialized representations to check for exact match + print_dict_diff(json.loads(serialize_test_agent.model_dump_json()), json.loads(agent_copy.model_dump_json())) + assert compare_agent_state(agent_copy, serialize_test_agent, mark_as_copy=mark_as_copy) + + # Make sure both agents can receive messages after + server.send_messages( + actor=default_user, agent_id=serialize_test_agent.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")] + ) + server.send_messages( + actor=other_user, agent_id=agent_copy.id, messages=[MessageCreate(role=MessageRole.user, content="and hello again")] + ) diff --git a/tests/test_managers.py b/tests/test_managers.py index 52206d72..334e72ed 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -8,7 +8,7 @@ from openai.types.chat.chat_completion_message_tool_call import Function as Open from sqlalchemy.exc import IntegrityError from letta.config import LettaConfig -from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MULTI_AGENT_TOOLS +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_TOOL_EXECUTION_DIR, MULTI_AGENT_TOOLS from letta.embeddings import embedding_model from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.orm import Base @@ -2340,7 +2340,7 @@ def test_create_local_sandbox_config_defaults(server: SyncServer, default_user): # Assertions assert created_config.type == SandboxType.LOCAL assert created_config.get_local_config() == sandbox_config_create.config - assert created_config.get_local_config().sandbox_dir in {"~/.letta", tool_settings.local_sandbox_dir} + assert created_config.get_local_config().sandbox_dir in {LETTA_TOOL_EXECUTION_DIR, tool_settings.local_sandbox_dir} assert created_config.organization_id == default_user.organization_id