From e4700fdfe1636a6be0c1eefe85a5cecf3c962862 Mon Sep 17 00:00:00 2001 From: mlong93 <35275280+mlong93@users.noreply.github.com> Date: Wed, 15 Jan 2025 14:47:20 -0800 Subject: [PATCH] fix: updated `send_message_async` request params, `get_run_messages` returns `LettaMessage` (#638) Co-authored-by: Mindy Long Co-authored-by: cthomas Co-authored-by: Shubham Naik Co-authored-by: Shubham Naik --- ...013e_adding_request_config_to_job_table.py | 31 +++++++ examples/tool_rule_usage.py | 2 +- letta/client/client.py | 47 +--------- letta/orm/job.py | 4 + letta/schemas/letta_request.py | 10 +- letta/schemas/message.py | 4 +- letta/schemas/run.py | 2 + letta/server/rest_api/routers/v1/agents.py | 19 ++-- letta/server/rest_api/routers/v1/runs.py | 47 ++++------ letta/services/job_manager.py | 91 +++++++++++++++---- tests/test_client.py | 2 - tests/test_managers.py | 65 +++++++++---- ...nce.py => test_model_letta_performance.py} | 0 tests/test_v1_routes.py | 57 ++++-------- 14 files changed, 218 insertions(+), 163 deletions(-) create mode 100644 alembic/versions/f595e0e8013e_adding_request_config_to_job_table.py rename tests/{test_model_letta_perfomance.py => test_model_letta_performance.py} (100%) diff --git a/alembic/versions/f595e0e8013e_adding_request_config_to_job_table.py b/alembic/versions/f595e0e8013e_adding_request_config_to_job_table.py new file mode 100644 index 00000000..d53a30a2 --- /dev/null +++ b/alembic/versions/f595e0e8013e_adding_request_config_to_job_table.py @@ -0,0 +1,31 @@ +"""adding request_config to Job table + +Revision ID: f595e0e8013e +Revises: 7f652fdd3dba +Create Date: 2025-01-14 14:34:34.203363 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "f595e0e8013e" +down_revision: Union[str, None] = "7f652fdd3dba" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("jobs", sa.Column("request_config", sa.JSON, nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("jobs", "request_config") + # ### end Alembic commands ### diff --git a/examples/tool_rule_usage.py b/examples/tool_rule_usage.py index 54e051e2..8ec061d0 100644 --- a/examples/tool_rule_usage.py +++ b/examples/tool_rule_usage.py @@ -6,7 +6,7 @@ from letta.schemas.letta_message import ToolCallMessage from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule from tests.helpers.endpoints_helper import assert_invoked_send_message_with_keyword, setup_agent from tests.helpers.utils import cleanup -from tests.test_model_letta_perfomance import llm_config_dir +from tests.test_model_letta_performance import llm_config_dir """ This example shows how you can constrain tool calls in your agent. 
diff --git a/letta/client/client.py b/letta/client/client.py index ec62bcf9..c4e1497f 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1,6 +1,5 @@ import logging import time -from datetime import datetime from typing import Callable, Dict, Generator, List, Optional, Union import requests @@ -23,6 +22,7 @@ from letta.schemas.environment_variables import ( ) from letta.schemas.file import FileMetadata from letta.schemas.job import Job +from letta.schemas.letta_message import LettaMessage, LettaMessageUnion from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse from letta.schemas.llm_config import LLMConfig @@ -1999,46 +1999,27 @@ class RESTClient(AbstractClient): self, run_id: str, cursor: Optional[str] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, limit: Optional[int] = 100, - query_text: Optional[str] = None, ascending: bool = True, - tags: Optional[List[str]] = None, - match_all_tags: bool = False, role: Optional[MessageRole] = None, - tool_name: Optional[str] = None, - ) -> List[Message]: + ) -> List[LettaMessageUnion]: """ Get messages associated with a job with filtering options. Args: job_id: ID of the job cursor: Cursor for pagination - start_date: Filter messages after this date - end_date: Filter messages before this date limit: Maximum number of messages to return - query_text: Search text in message content ascending: Sort order by creation time - tags: Filter by message tags - match_all_tags: If true, match all tags. 
If false, match any tag role: Filter by message role (user/assistant/system/tool) - tool_name: Filter by tool call name - Returns: List of messages matching the filter criteria """ params = { "cursor": cursor, - "start_date": start_date.isoformat() if start_date else None, - "end_date": end_date.isoformat() if end_date else None, "limit": limit, - "query_text": query_text, "ascending": ascending, - "tags": tags, - "match_all_tags": match_all_tags, "role": role, - "tool_name": tool_name, } # Remove None values params = {k: v for k, v in params.items() if v is not None} @@ -2046,7 +2027,7 @@ class RESTClient(AbstractClient): response = requests.get(f"{self.base_url}/{self.api_prefix}/runs/{run_id}/messages", params=params) if response.status_code != 200: raise ValueError(f"Failed to get run messages: {response.text}") - return [Message(**message) for message in response.json()] + return [LettaMessage(**message) for message in response.json()] def get_run_usage( self, @@ -3621,48 +3602,30 @@ class LocalClient(AbstractClient): self, run_id: str, cursor: Optional[str] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, limit: Optional[int] = 100, - query_text: Optional[str] = None, ascending: bool = True, - tags: Optional[List[str]] = None, - match_all_tags: bool = False, role: Optional[MessageRole] = None, - tool_name: Optional[str] = None, - ) -> List[Message]: + ) -> List[LettaMessageUnion]: """ Get messages associated with a job with filtering options. Args: run_id: ID of the run cursor: Cursor for pagination - start_date: Filter messages after this date - end_date: Filter messages before this date limit: Maximum number of messages to return - query_text: Search text in message content ascending: Sort order by creation time - tags: Filter by message tags - match_all_tags: If true, match all tags. 
If false, match any tag role: Filter by message role (user/assistant/system/tool) - tool_name: Filter by tool call name Returns: List of messages matching the filter criteria """ params = { "cursor": cursor, - "start_date": start_date.isoformat() if start_date else None, - "end_date": end_date.isoformat() if end_date else None, "limit": limit, - "query_text": query_text, "ascending": ascending, - "tags": tags, - "match_all_tags": match_all_tags, "role": role, - "tool_name": tool_name, } - return self.server.job_manager.get_job_messages(job_id=run_id, actor=self.user, job_type=JobType.RUN, **params) + return self.server.job_manager.get_run_messages_cursor(run_id=run_id, actor=self.user, **params) def get_run_usage( self, diff --git a/letta/orm/job.py b/letta/orm/job.py index de4663e6..95e67006 100644 --- a/letta/orm/job.py +++ b/letta/orm/job.py @@ -9,6 +9,7 @@ from letta.orm.mixins import UserMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.enums import JobStatus from letta.schemas.job import Job as PydanticJob +from letta.schemas.letta_request import LettaRequestConfig if TYPE_CHECKING: from letta.orm.job_messages import JobMessage @@ -33,6 +34,9 @@ class Job(SqlalchemyBase, UserMixin): default=JobType.JOB, doc="The type of job. This affects whether or not we generate json_schema and source_code on the fly.", ) + request_config: Mapped[Optional[LettaRequestConfig]] = mapped_column( + JSON, nullable=True, doc="The request configuration for the job, stored as JSON." 
+ ) # relationships user: Mapped["User"] = relationship("User", back_populates="jobs") diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index f1f8f450..663dba14 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -6,11 +6,8 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.message import MessageCreate -class LettaRequest(BaseModel): - messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.") - +class LettaRequestConfig(BaseModel): # Flags to support the use of AssistantMessage message types - use_assistant_message: bool = Field( default=True, description="Whether the server should parse specific tool call arguments (default `send_message`) as `AssistantMessage` objects.", @@ -25,6 +22,11 @@ class LettaRequest(BaseModel): ) +class LettaRequest(BaseModel): + messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.") + config: LettaRequestConfig = Field(default=LettaRequestConfig(), description="Configuration options for the LettaRequest.") + + class LettaStreamingRequest(LettaRequest): stream_tokens: bool = Field( default=False, diff --git a/letta/schemas/message.py b/letta/schemas/message.py index ea04ec10..df09aa25 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -149,9 +149,9 @@ class Message(BaseMessage): # We need to unpack the actual message contents from the function call try: func_args = json.loads(tool_call.function.arguments) - message_string = func_args[DEFAULT_MESSAGE_TOOL_KWARG] + message_string = func_args[assistant_message_tool_kwarg] except KeyError: - raise ValueError(f"Function call {tool_call.function.name} missing {DEFAULT_MESSAGE_TOOL_KWARG} argument") + raise ValueError(f"Function call {tool_call.function.name} missing {assistant_message_tool_kwarg} argument") messages.append( AssistantMessage( id=self.id, diff --git a/letta/schemas/run.py 
b/letta/schemas/run.py index 9d348147..b455a211 100644 --- a/letta/schemas/run.py +++ b/letta/schemas/run.py @@ -4,6 +4,7 @@ from pydantic import Field from letta.orm.enums import JobType from letta.schemas.job import Job, JobBase +from letta.schemas.letta_request import LettaRequestConfig class RunBase(JobBase): @@ -28,6 +29,7 @@ class Run(RunBase): id: str = RunBase.generate_id_field() user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the run.") + request_config: Optional[LettaRequestConfig] = Field(None, description="The request configuration for the run.") @classmethod def from_job(cls, job: Job) -> "Run": diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 390b0eee..53b0c290 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -502,9 +502,9 @@ async def send_message( stream_steps=False, stream_tokens=False, # Support for AssistantMessage - use_assistant_message=request.use_assistant_message, - assistant_message_tool_name=request.assistant_message_tool_name, - assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + use_assistant_message=request.config.use_assistant_message, + assistant_message_tool_name=request.config.assistant_message_tool_name, + assistant_message_tool_kwarg=request.config.assistant_message_tool_kwarg, ) return result @@ -542,9 +542,9 @@ async def send_message_streaming( stream_steps=True, stream_tokens=request.stream_tokens, # Support for AssistantMessage - use_assistant_message=request.use_assistant_message, - assistant_message_tool_name=request.assistant_message_tool_name, - assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + use_assistant_message=request.config.use_assistant_message, + assistant_message_tool_name=request.config.assistant_message_tool_name, + assistant_message_tool_kwarg=request.config.assistant_message_tool_kwarg, ) return result @@ 
-622,6 +622,7 @@ async def send_message_async( "job_type": "send_message_async", "agent_id": agent_id, }, + request_config=request.config, ) run = server.job_manager.create_job(pydantic_job=run, actor=actor) @@ -633,9 +634,9 @@ async def send_message_async( actor=actor, agent_id=agent_id, messages=request.messages, - use_assistant_message=request.use_assistant_message, - assistant_message_tool_name=request.assistant_message_tool_name, - assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + use_assistant_message=request.config.use_assistant_message, + assistant_message_tool_name=request.config.assistant_message_tool_name, + assistant_message_tool_kwarg=request.config.assistant_message_tool_kwarg, ) return run diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index a659b193..34cbb889 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -1,4 +1,3 @@ -from datetime import datetime from typing import List, Optional from fastapi import APIRouter, Depends, Header, HTTPException, Query @@ -6,7 +5,7 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query from letta.orm.enums import JobType from letta.orm.errors import NoResultFound from letta.schemas.enums import JobStatus, MessageRole -from letta.schemas.message import Message +from letta.schemas.letta_message import LettaMessageUnion from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.run import Run from letta.server.rest_api.utils import get_letta_server @@ -61,21 +60,15 @@ def get_run( raise HTTPException(status_code=404, detail="Run not found") -@router.get("/{run_id}/messages", response_model=List[Message], operation_id="get_run_messages") -def get_run_messages( +@router.get("/{run_id}/messages", response_model=List[LettaMessageUnion], operation_id="get_run_messages") +async def get_run_messages( run_id: str, - cursor: Optional[str] = Query(None, 
description="Cursor for pagination"), - start_date: Optional[datetime] = Query(None, description="Filter messages after this date"), - end_date: Optional[datetime] = Query(None, description="Filter messages before this date"), - limit: Optional[int] = Query(100, description="Maximum number of messages to return"), - query_text: Optional[str] = Query(None, description="Search text in message content"), - ascending: bool = Query(True, description="Sort order by creation time"), - tags: Optional[List[str]] = Query(None, description="Filter by message tags"), - match_all_tags: bool = Query(False, description="If true, match all tags. If false, match any tag"), - role: Optional[MessageRole] = Query(None, description="Filter by message role"), - tool_name: Optional[str] = Query(None, description="Filter by tool call name"), - user_id: Optional[str] = Header(None, alias="user_id"), server: "SyncServer" = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), + cursor: Optional[str] = Query(None, description="Cursor for pagination"), + limit: Optional[int] = Query(100, description="Maximum number of messages to return"), + ascending: bool = Query(True, description="Sort order by creation time"), + role: Optional[MessageRole] = Query(None, description="Filter by role"), ): """ Get messages associated with a run with filtering options. @@ -83,33 +76,25 @@ def get_run_messages( Args: run_id: ID of the run cursor: Cursor for pagination - start_date: Filter messages after this date - end_date: Filter messages before this date limit: Maximum number of messages to return - query_text: Search text in message content ascending: Sort order by creation time - tags: Filter by message tags - match_all_tags: If true, match all tags. 
If false, match any tag - role: Filter by message role (user/assistant/system/tool) - tool_name: Filter by tool call name + role: Filter by role (user/assistant/system/tool) + return_message_object: Whether to return Message objects or LettaMessage objects user_id: ID of the user making the request + + Returns: + A list of messages associated with the run. Default is List[LettaMessage]. """ actor = server.user_manager.get_user_or_default(user_id=user_id) try: - messages = server.job_manager.get_job_messages( - job_id=run_id, + messages = server.job_manager.get_run_messages_cursor( + run_id=run_id, actor=actor, - cursor=cursor, - start_date=start_date, - end_date=end_date, limit=limit, - query_text=query_text, + cursor=cursor, ascending=ascending, - tags=tags, - match_all_tags=match_all_tags, role=role, - tool_name=tool_name, ) return messages except NoResultFound as e: diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index e5753730..b8ea803b 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -1,4 +1,3 @@ -from datetime import datetime from typing import List, Literal, Optional, Union from sqlalchemy import select @@ -14,6 +13,8 @@ from letta.orm.sqlalchemy_base import AccessType from letta.schemas.enums import JobStatus, MessageRole from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate +from letta.schemas.letta_message import LettaMessage +from letta.schemas.letta_request import LettaRequestConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.run import Run as PydanticRun from letta.schemas.usage import LettaUsageStatistics @@ -108,15 +109,9 @@ class JobManager: job_id: str, actor: PydanticUser, cursor: Optional[str] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, limit: Optional[int] = 100, - query_text: Optional[str] = None, - ascending: bool = True, - tags: Optional[List[str]] = None, - 
match_all_tags: bool = False, role: Optional[MessageRole] = None, - tool_name: Optional[str] = None, + ascending: bool = True, ) -> List[PydanticMessage]: """ Get all messages associated with a job. @@ -127,7 +122,7 @@ class JobManager: cursor: Cursor for pagination limit: Maximum number of messages to return role: Optional filter for message role - tool_name: Optional filter for tool call name + ascending: Optional flag to sort in ascending order Returns: List of messages associated with the job @@ -145,24 +140,15 @@ class JobManager: messages = MessageModel.list( db_session=session, cursor=cursor, - start_date=start_date, - end_date=end_date, - query_text=query_text, ascending=ascending, limit=limit, - tags=tags, - match_all_tags=match_all_tags, actor=actor, join_model=JobMessage, join_conditions=[MessageModel.id == JobMessage.message_id, JobMessage.job_id == job_id], **filters, ) - # Filter by tool name if specified - if tool_name is not None: - messages = [msg for msg in messages if msg.tool_calls and any(call.function.name == tool_name for call in msg.tool_calls)] - - return [message.to_pydantic() for message in messages] + return [message.to_pydantic() for message in messages] @enforce_types def add_message_to_job(self, job_id: str, message_id: str, actor: PydanticUser) -> None: @@ -268,6 +254,58 @@ class JobManager: session.add(usage_stats) session.commit() + @enforce_types + def get_run_messages_cursor( + self, + run_id: str, + actor: PydanticUser, + cursor: Optional[str] = None, + limit: Optional[int] = 100, + role: Optional[MessageRole] = None, + ascending: bool = True, + ) -> List[LettaMessage]: + """ + Get messages associated with a job using cursor-based pagination. + This is a wrapper around get_job_messages that provides cursor-based pagination. 
+ + Args: + run_id: The ID of the run to get messages for + actor: The user making the request + cursor: Message ID to get messages after or before + limit: Maximum number of messages to return + ascending: Whether to return messages in ascending order + role: Optional role filter + + Returns: + List of LettaMessages associated with the job + + Raises: + NoResultFound: If the job does not exist or user does not have access + """ + messages = self.get_job_messages( + job_id=run_id, + actor=actor, + cursor=cursor, + limit=limit, + role=role, + ascending=ascending, + ) + + request_config = self._get_run_request_config(run_id) + + # Convert messages to LettaMessages + messages = [ + msg + for m in messages + for msg in m.to_letta_message( + assistant_message=request_config["use_assistant_message"], + assistant_message_tool_name=request_config["assistant_message_tool_name"], + assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"], + ) + ] + + return messages + def _verify_job_access( self, session: Session, @@ -295,3 +333,18 @@ class JobManager: if not job: raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access") return job + + def _get_run_request_config(self, run_id: str) -> LettaRequestConfig: + """ + Get the request config for a job. 
+ + Args: + run_id: The ID of the run to get the request config for + + Returns: + The request config for the job + """ + with self.session_maker() as session: + job = session.query(JobModel).filter(JobModel.id == run_id).first() + request_config = job.request_config or LettaRequestConfig() + return request_config diff --git a/tests/test_client.py b/tests/test_client.py index 8aeef23a..7ecc89e4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -571,8 +571,6 @@ def test_send_message_async(client: Union[LocalClient, RESTClient], agent: Agent assert len(assistant_messages) > 0 tool_messages = client.get_run_messages(run_id=run.id, role=MessageRole.tool) assert len(tool_messages) > 0 - specific_tool_messages = client.get_run_messages(run_id=run.id, tool_name="send_message") - assert len(specific_tool_messages) > 0 # Get and verify usage statistics usage = client.get_run_usage(run_id=run.id)[0] diff --git a/tests/test_managers.py b/tests/test_managers.py index 08e2b1e3..5d9f9aa2 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -42,6 +42,7 @@ from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate +from letta.schemas.letta_request import LettaRequestConfig from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate, MessageUpdate @@ -2769,27 +2770,59 @@ def test_job_messages_filter(server: SyncServer, default_run, default_user, sara assert len(user_messages) == 1 assert user_messages[0].role == MessageRole.user - # Test filtering by tool name - tool_messages = server.job_manager.get_job_messages(job_id=default_run.id, actor=default_user, tool_name="test_tool") - assert len(tool_messages) == 1 - assert tool_messages[0].tool_calls[0].function.name == "test_tool" - - # Test 
filtering by role and tool name - assistant_tool_messages = server.job_manager.get_job_messages( - job_id=default_run.id, - actor=default_user, - role=MessageRole.assistant, - tool_name="test_tool", - ) - assert len(assistant_tool_messages) == 1 - assert assistant_tool_messages[0].role == MessageRole.assistant - assert assistant_tool_messages[0].tool_calls[0].function.name == "test_tool" - # Test limit limited_messages = server.job_manager.get_job_messages(job_id=default_run.id, actor=default_user, limit=2) assert len(limited_messages) == 2 +def test_get_run_messages_cursor(server: SyncServer, default_user: PydanticUser, sarah_agent): + """Test getting messages for a run with request config.""" + # Create a run with custom request config + run = server.job_manager.create_job( + pydantic_job=PydanticRun( + user_id=default_user.id, + status=JobStatus.created, + request_config=LettaRequestConfig( + use_assistant_message=False, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg" + ), + ), + actor=default_user, + ) + + # Add some messages + messages = [ + PydanticMessage( + organization_id=default_user.organization_id, + agent_id=sarah_agent.id, + role=MessageRole.user if i % 2 == 0 else MessageRole.assistant, + text=f"Test message {i}", + tool_calls=( + [{"id": f"call_{i}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}] if i % 2 == 1 else None + ), + ) + for i in range(4) + ] + + for msg in messages: + created_msg = server.message_manager.create_message(msg, actor=default_user) + server.job_manager.add_message_to_job(job_id=run.id, message_id=created_msg.id, actor=default_user) + + # Get messages and verify they're converted correctly + result = server.job_manager.get_run_messages_cursor(run_id=run.id, actor=default_user) + + # Verify correct number of messages. 
Assistant messages should be parsed + assert len(result) == 6 + + # Verify assistant messages are parsed according to request config + tool_call_messages = [msg for msg in result if msg.message_type == "tool_call_message"] + reasoning_messages = [msg for msg in result if msg.message_type == "reasoning_message"] + assert len(tool_call_messages) == 2 + assert len(reasoning_messages) == 2 + for msg in tool_call_messages: + assert msg.tool_call is not None + assert msg.tool_call.name == "custom_tool" + + # ====================================================================================================================== # JobManager Tests - Usage Statistics # ====================================================================================================================== diff --git a/tests/test_model_letta_perfomance.py b/tests/test_model_letta_performance.py similarity index 100% rename from tests/test_model_letta_perfomance.py rename to tests/test_model_letta_performance.py diff --git a/tests/test_v1_routes.py b/tests/test_v1_routes.py index a7e39c7b..8394e61e 100644 --- a/tests/test_v1_routes.py +++ b/tests/test_v1_routes.py @@ -6,6 +6,7 @@ from composio.client.collections import ActionModel, ActionParametersModel, Acti from fastapi.testclient import TestClient from letta.orm.errors import NoResultFound +from letta.schemas.message import UserMessage from letta.schemas.tool import ToolCreate, ToolUpdate from letta.server.rest_api.app import app from letta.server.rest_api.utils import get_letta_server @@ -339,62 +340,44 @@ def test_get_run_messages(client, mock_sync_server): """Test getting messages for a run.""" # Create properly formatted mock messages current_time = datetime.utcnow() - messages_data = [ - { - "id": f"message-{i:08x}", # Matches pattern '^message-[a-fA-F0-9]{8}' - "text": f"Test message {i}", - "role": "user", - "organization_id": "org-123", - "agent_id": "agent-123", - "model": "gpt-4", - "name": "test-user", - "tool_calls": [], - 
"tool_call_id": None, - "created_at": current_time, - "updated_at": current_time, - "created_by_id": "user-123", - "last_updated_by_id": "user-123", - } + mock_messages = [ + UserMessage( + id=f"message-{i:08x}", + date=current_time, + message=f"Test message {i}", + ) for i in range(2) ] - mock_messages = [] - for msg_data in messages_data: - mock_msg = Mock() - for key, value in msg_data.items(): - setattr(mock_msg, key, value) - mock_messages.append(mock_msg) - # Configure mock server responses mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123") - mock_sync_server.job_manager.get_job_messages.return_value = mock_messages + mock_sync_server.job_manager.get_run_messages_cursor.return_value = mock_messages # Test successful retrieval response = client.get( "/v1/runs/run-12345678/messages", headers={"user_id": "user-123"}, - params={"limit": 10, "cursor": messages_data[1]["id"], "role": "user"}, + params={ + "limit": 10, + "cursor": mock_messages[0].id, + "role": "user", + "ascending": True, + }, ) assert response.status_code == 200 assert len(response.json()) == 2 - assert response.json()[0]["id"] == messages_data[0]["id"] - assert response.json()[1]["id"] == messages_data[1]["id"] + assert response.json()[0]["id"] == mock_messages[0].id + assert response.json()[1]["id"] == mock_messages[1].id # Verify mock calls mock_sync_server.user_manager.get_user_or_default.assert_called_once_with(user_id="user-123") - mock_sync_server.job_manager.get_job_messages.assert_called_once_with( - job_id="run-12345678", + mock_sync_server.job_manager.get_run_messages_cursor.assert_called_once_with( + run_id="run-12345678", actor=mock_sync_server.user_manager.get_user_or_default.return_value, limit=10, - cursor=messages_data[1]["id"], - start_date=None, - end_date=None, - query_text=None, + cursor=mock_messages[0].id, ascending=True, - tags=None, - match_all_tags=False, role="user", - tool_name=None, ) @@ -403,7 +386,7 @@ def 
test_get_run_messages_not_found(client, mock_sync_server): # Configure mock responses error_message = "Run 'run-nonexistent' not found" mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123") - mock_sync_server.job_manager.get_job_messages.side_effect = NoResultFound(error_message) + mock_sync_server.job_manager.get_run_messages_cursor.side_effect = NoResultFound(error_message) response = client.get("/v1/runs/run-nonexistent/messages", headers={"user_id": "user-123"})