From 810628acd944998b1347178da0792fc9f8919ffa Mon Sep 17 00:00:00 2001 From: cthomas Date: Sat, 19 Jul 2025 09:28:54 -0700 Subject: [PATCH] feat: remove organization from pydantic message model (#3411) --- letta/orm/message.py | 2 +- letta/schemas/message.py | 1 - letta/serialize_schemas/marshmallow_message.py | 4 +--- letta/server/rest_api/utils.py | 7 ------- letta/services/helpers/agent_manager_helper.py | 4 ---- letta/services/message_manager.py | 4 ++-- tests/test_managers.py | 15 --------------- 7 files changed, 4 insertions(+), 33 deletions(-) diff --git a/letta/orm/message.py b/letta/orm/message.py index 100e8bff..5e96790d 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -62,7 +62,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): ) # Relationships - organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin") + organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="raise") step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin") # Job relationship diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 417540c8..ec690577 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -150,7 +150,6 @@ class Message(BaseMessage): """ id: str = BaseMessage.generate_id_field() - organization_id: Optional[str] = Field(default=None, description="The unique identifier of the organization.") agent_id: Optional[str] = Field(default=None, description="The unique identifier of the agent.") model: Optional[str] = Field(default=None, description="The model used to make the function call.") # Basic OpenAI-style fields diff --git a/letta/serialize_schemas/marshmallow_message.py b/letta/serialize_schemas/marshmallow_message.py index 1a8a8c22..75678bd7 100644 --- a/letta/serialize_schemas/marshmallow_message.py +++ b/letta/serialize_schemas/marshmallow_message.py @@ -23,7 +23,6 @@ class SerializedMessageSchema(BaseSchema): # agent dump will then get rid of message ids del data["_created_by_id"] del data["_last_updated_by_id"] - del data["organization"] return data @@ -33,10 +32,9 @@ class SerializedMessageSchema(BaseSchema): # Skip regenerating ID, as agent dump will do it data["_created_by_id"] = self.actor.id data["_last_updated_by_id"] = self.actor.id - data["organization"] = self.actor.organization_id return data class Meta(BaseSchema.Meta): model = Message - exclude = BaseSchema.Meta.exclude + ("step", "job_message", "otid", "is_deleted") + exclude = BaseSchema.Meta.exclude + ("step", "job_message", "otid", "is_deleted", "organization") diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 1d47647e..17077fe3 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -174,8 +174,6 @@ def create_input_messages(input_messages: List[MessageCreate], agent_id: str, ti """ messages = convert_message_creates_to_messages(input_messages, agent_id, timezone, wrap_user_message=False, wrap_system_message=False) - for message in messages: - message.organization_id = actor.organization_id return messages @@ -214,7 +212,6 @@ def create_letta_messages_from_llm_response( assistant_message = Message( role=MessageRole.assistant, content=reasoning_content if reasoning_content else [], - organization_id=actor.organization_id, agent_id=agent_id, model=model, tool_calls=[tool_call], @@ -231,7 +228,6 @@ def create_letta_messages_from_llm_response( tool_message = Message( role=MessageRole.tool, content=[TextContent(text=package_function_response(function_call_success, function_response, timezone))], - organization_id=actor.organization_id, agent_id=agent_id, model=model, tool_calls=[], @@ -284,7 +280,6 @@ def create_heartbeat_system_message( heartbeat_system_message = Message( role=MessageRole.user, content=[TextContent(text=get_heartbeat(timezone, text_content))], - organization_id=actor.organization_id, agent_id=agent_id, model=model, tool_calls=[], @@ -360,7 +355,6 @@ def convert_in_context_letta_messages_to_openai(in_context_messages: List[Messag id=msg.id, role=msg.role, content=[TextContent(text=extracted_text)], - organization_id=msg.organization_id, agent_id=msg.agent_id, model=msg.model, name=msg.name, @@ -389,7 +383,6 @@ def convert_in_context_letta_messages_to_openai(in_context_messages: List[Messag id=msg.id, role=msg.role, content=[TextContent(text=actual_user_text)], - organization_id=msg.organization_id, agent_id=msg.agent_id, model=msg.model, name=msg.name, diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 3d7a1ef6..8c70f725 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -382,7 +382,6 @@ def package_initial_message_sequence( role=message_create.role, content=[TextContent(text=packed_message)], name=message_create.name, - organization_id=actor.organization_id, agent_id=agent_id, model=model, ) @@ -397,7 +396,6 @@ def package_initial_message_sequence( role=message_create.role, content=[TextContent(text=packed_message)], name=message_create.name, - organization_id=actor.organization_id, agent_id=agent_id, model=model, ) @@ -418,7 +416,6 @@ def package_initial_message_sequence( role=MessageRole.assistant, content=None, name=message_create.name, - organization_id=actor.organization_id, agent_id=agent_id, model=model, tool_calls=[ @@ -438,7 +435,6 @@ def package_initial_message_sequence( role=MessageRole.tool, content=[TextContent(text=function_response)], name=message_create.name, - organization_id=actor.organization_id, agent_id=agent_id, model=model, tool_call_id=tool_call_id, diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 46e04d2c..9d1fa352 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -88,8 +88,8 @@ class MessageManager: """Create a new message.""" with db_registry.session() as session: # Set the organization id of the Pydantic message - pydantic_msg.organization_id = actor.organization_id msg_data = pydantic_msg.model_dump(to_orm=True) + msg_data["organization_id"] = actor.organization_id msg = MessageModel(**msg_data) msg.create(session, actor=actor) # Persist to database return msg.to_pydantic() @@ -99,8 +99,8 @@ class MessageManager: orm_messages = [] for pydantic_msg in pydantic_msgs: # Set the organization id of the Pydantic message - pydantic_msg.organization_id = actor.organization_id msg_data = pydantic_msg.model_dump(to_orm=True) + msg_data["organization_id"] = actor.organization_id orm_messages.append(MessageModel(**msg_data)) return orm_messages diff --git a/tests/test_managers.py b/tests/test_managers.py index b25e5c80..a34338f8 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -378,7 +378,6 @@ def hello_world_message_fixture(server: SyncServer, default_user, sarah_agent): """Fixture to create a tool with default settings and clean up after the test.""" # Set up message message = PydanticMessage( - organization_id=default_user.organization_id, agent_id=sarah_agent.id, role="user", content=[TextContent(text="Hello, world!")], @@ -1815,7 +1814,6 @@ async def test_reset_messages_with_existing_messages(server: SyncServer, sarah_a msg1 = server.message_manager.create_message( PydanticMessage( agent_id=sarah_agent.id, - organization_id=default_user.organization_id, role="user", content=[TextContent(text="Hello, Sarah!")], ), @@ -1824,7 +1822,6 @@ async def test_reset_messages_with_existing_messages(server: SyncServer, sarah_a msg2 = server.message_manager.create_message( PydanticMessage( agent_id=sarah_agent.id, - organization_id=default_user.organization_id, role="assistant", content=[TextContent(text="Hello, user!")], ), @@ -1859,7 +1856,6 @@ async def test_reset_messages_idempotency(server: SyncServer, sarah_agent, defau server.message_manager.create_message( PydanticMessage( agent_id=sarah_agent.id, - organization_id=default_user.organization_id, role="user", content=[TextContent(text="Hello, Sarah!")], ), @@ -1889,7 +1885,6 @@ async def test_reset_messages_preserves_system_message_id(server: SyncServer, sa server.message_manager.create_message( PydanticMessage( agent_id=sarah_agent.id, - organization_id=default_user.organization_id, role="user", content=[TextContent(text="Hello!")], ), @@ -1923,7 +1918,6 @@ async def test_reset_messages_preserves_system_message_content(server: SyncServe server.message_manager.create_message( PydanticMessage( agent_id=sarah_agent.id, - organization_id=default_user.organization_id, role="user", content=[TextContent(text="Hello!")], ), @@ -3526,7 +3520,6 @@ def test_message_size(server: SyncServer, hello_world_message_fixture, default_u # Create additional test messages messages = [ PydanticMessage( - organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, content=[TextContent(text=f"Test message {i}")], @@ -3557,7 +3550,6 @@ def create_test_messages(server: SyncServer, base_message: PydanticMessage, defa """Helper function to create test messages for all tests""" messages = [ PydanticMessage( - organization_id=default_user.organization_id, agent_id=base_message.agent_id, role=base_message.role, content=[TextContent(text=f"Test message {i}")], @@ -6105,7 +6097,6 @@ def test_job_messages_pagination(server: SyncServer, default_run, default_user, message_ids = [] for i in range(5): message = PydanticMessage( - organization_id=default_user.organization_id, agent_id=sarah_agent.id, role=MessageRole.user, content=[TextContent(text=f"Test message {i}")], @@ -6222,7 +6213,6 @@ def test_job_messages_ordering(server: SyncServer, default_run, default_user, sa message = PydanticMessage( role=MessageRole.user, content=[TextContent(text="Test message")], - organization_id=default_user.organization_id, agent_id=sarah_agent.id, created_at=created_at, ) @@ -6291,19 +6281,16 @@ def test_job_messages_filter(server: SyncServer, default_run, default_user, sara PydanticMessage( role=MessageRole.user, content=[TextContent(text="Hello")], - organization_id=default_user.organization_id, agent_id=sarah_agent.id, ), PydanticMessage( role=MessageRole.assistant, content=[TextContent(text="Hi there!")], - organization_id=default_user.organization_id, agent_id=sarah_agent.id, ), PydanticMessage( role=MessageRole.assistant, content=[TextContent(text="Let me help you with that")], - organization_id=default_user.organization_id, agent_id=sarah_agent.id, tool_calls=[ OpenAIToolCall( @@ -6354,7 +6341,6 @@ def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_ # Add some messages messages = [ PydanticMessage( - organization_id=default_user.organization_id, agent_id=sarah_agent.id, role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant, content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')], @@ -6405,7 +6391,6 @@ def test_get_run_messages_with_assistant_message(server: SyncServer, default_use # Add some messages messages = [ PydanticMessage( - organization_id=default_user.organization_id, agent_id=sarah_agent.id, role=MessageRole.tool if i % 2 == 0 else MessageRole.assistant, content=[TextContent(text=f"Test message {i}" if i % 2 == 1 else '{"status": "OK"}')],