fix: updated send_message_async request params, get_run_messages returns LettaMessage (#638)
Co-authored-by: Mindy Long <mindy@letta.com> Co-authored-by: cthomas <caren@letta.com> Co-authored-by: Shubham Naik <shubham.naik10@gmail.com> Co-authored-by: Shubham Naik <shub@memgpt.ai>
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
"""adding request_config to Job table
|
||||
|
||||
Revision ID: f595e0e8013e
|
||||
Revises: 7f652fdd3dba
|
||||
Create Date: 2025-01-14 14:34:34.203363
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f595e0e8013e"
|
||||
down_revision: Union[str, None] = "7f652fdd3dba"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("jobs", sa.Column("request_config", sa.JSON, nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("jobs", "request_config")
|
||||
# ### end Alembic commands ###
|
||||
@@ -6,7 +6,7 @@ from letta.schemas.letta_message import ToolCallMessage
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from tests.helpers.endpoints_helper import assert_invoked_send_message_with_keyword, setup_agent
|
||||
from tests.helpers.utils import cleanup
|
||||
from tests.test_model_letta_perfomance import llm_config_dir
|
||||
from tests.test_model_letta_performance import llm_config_dir
|
||||
|
||||
"""
|
||||
This example shows how you can constrain tool calls in your agent.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Callable, Dict, Generator, List, Optional, Union
|
||||
|
||||
import requests
|
||||
@@ -23,6 +22,7 @@ from letta.schemas.environment_variables import (
|
||||
)
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@@ -1999,46 +1999,27 @@ class RESTClient(AbstractClient):
|
||||
self,
|
||||
run_id: str,
|
||||
cursor: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 100,
|
||||
query_text: Optional[str] = None,
|
||||
ascending: bool = True,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
role: Optional[MessageRole] = None,
|
||||
tool_name: Optional[str] = None,
|
||||
) -> List[Message]:
|
||||
) -> List[LettaMessageUnion]:
|
||||
"""
|
||||
Get messages associated with a job with filtering options.
|
||||
|
||||
Args:
|
||||
job_id: ID of the job
|
||||
cursor: Cursor for pagination
|
||||
start_date: Filter messages after this date
|
||||
end_date: Filter messages before this date
|
||||
limit: Maximum number of messages to return
|
||||
query_text: Search text in message content
|
||||
ascending: Sort order by creation time
|
||||
tags: Filter by message tags
|
||||
match_all_tags: If true, match all tags. If false, match any tag
|
||||
role: Filter by message role (user/assistant/system/tool)
|
||||
tool_name: Filter by tool call name
|
||||
|
||||
Returns:
|
||||
List of messages matching the filter criteria
|
||||
"""
|
||||
params = {
|
||||
"cursor": cursor,
|
||||
"start_date": start_date.isoformat() if start_date else None,
|
||||
"end_date": end_date.isoformat() if end_date else None,
|
||||
"limit": limit,
|
||||
"query_text": query_text,
|
||||
"ascending": ascending,
|
||||
"tags": tags,
|
||||
"match_all_tags": match_all_tags,
|
||||
"role": role,
|
||||
"tool_name": tool_name,
|
||||
}
|
||||
# Remove None values
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
@@ -2046,7 +2027,7 @@ class RESTClient(AbstractClient):
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/runs/{run_id}/messages", params=params)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get run messages: {response.text}")
|
||||
return [Message(**message) for message in response.json()]
|
||||
return [LettaMessage(**message) for message in response.json()]
|
||||
|
||||
def get_run_usage(
|
||||
self,
|
||||
@@ -3621,48 +3602,30 @@ class LocalClient(AbstractClient):
|
||||
self,
|
||||
run_id: str,
|
||||
cursor: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 100,
|
||||
query_text: Optional[str] = None,
|
||||
ascending: bool = True,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
role: Optional[MessageRole] = None,
|
||||
tool_name: Optional[str] = None,
|
||||
) -> List[Message]:
|
||||
) -> List[LettaMessageUnion]:
|
||||
"""
|
||||
Get messages associated with a job with filtering options.
|
||||
|
||||
Args:
|
||||
run_id: ID of the run
|
||||
cursor: Cursor for pagination
|
||||
start_date: Filter messages after this date
|
||||
end_date: Filter messages before this date
|
||||
limit: Maximum number of messages to return
|
||||
query_text: Search text in message content
|
||||
ascending: Sort order by creation time
|
||||
tags: Filter by message tags
|
||||
match_all_tags: If true, match all tags. If false, match any tag
|
||||
role: Filter by message role (user/assistant/system/tool)
|
||||
tool_name: Filter by tool call name
|
||||
|
||||
Returns:
|
||||
List of messages matching the filter criteria
|
||||
"""
|
||||
params = {
|
||||
"cursor": cursor,
|
||||
"start_date": start_date.isoformat() if start_date else None,
|
||||
"end_date": end_date.isoformat() if end_date else None,
|
||||
"limit": limit,
|
||||
"query_text": query_text,
|
||||
"ascending": ascending,
|
||||
"tags": tags,
|
||||
"match_all_tags": match_all_tags,
|
||||
"role": role,
|
||||
"tool_name": tool_name,
|
||||
}
|
||||
return self.server.job_manager.get_job_messages(job_id=run_id, actor=self.user, job_type=JobType.RUN, **params)
|
||||
return self.server.job_manager.get_run_messages_cursor(run_id=run_id, actor=self.user, **params)
|
||||
|
||||
def get_run_usage(
|
||||
self,
|
||||
|
||||
@@ -9,6 +9,7 @@ from letta.orm.mixins import UserMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.letta_request import LettaRequestConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.job_messages import JobMessage
|
||||
@@ -33,6 +34,9 @@ class Job(SqlalchemyBase, UserMixin):
|
||||
default=JobType.JOB,
|
||||
doc="The type of job. This affects whether or not we generate json_schema and source_code on the fly.",
|
||||
)
|
||||
request_config: Mapped[Optional[LettaRequestConfig]] = mapped_column(
|
||||
JSON, nullable=True, doc="The request configuration for the job, stored as JSON."
|
||||
)
|
||||
|
||||
# relationships
|
||||
user: Mapped["User"] = relationship("User", back_populates="jobs")
|
||||
|
||||
@@ -6,11 +6,8 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.message import MessageCreate
|
||||
|
||||
|
||||
class LettaRequest(BaseModel):
|
||||
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
|
||||
|
||||
class LettaRequestConfig(BaseModel):
|
||||
# Flags to support the use of AssistantMessage message types
|
||||
|
||||
use_assistant_message: bool = Field(
|
||||
default=True,
|
||||
description="Whether the server should parse specific tool call arguments (default `send_message`) as `AssistantMessage` objects.",
|
||||
@@ -25,6 +22,11 @@ class LettaRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class LettaRequest(BaseModel):
|
||||
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
|
||||
config: LettaRequestConfig = Field(default=LettaRequestConfig(), description="Configuration options for the LettaRequest.")
|
||||
|
||||
|
||||
class LettaStreamingRequest(LettaRequest):
|
||||
stream_tokens: bool = Field(
|
||||
default=False,
|
||||
|
||||
@@ -149,9 +149,9 @@ class Message(BaseMessage):
|
||||
# We need to unpack the actual message contents from the function call
|
||||
try:
|
||||
func_args = json.loads(tool_call.function.arguments)
|
||||
message_string = func_args[DEFAULT_MESSAGE_TOOL_KWARG]
|
||||
message_string = func_args[assistant_message_tool_kwarg]
|
||||
except KeyError:
|
||||
raise ValueError(f"Function call {tool_call.function.name} missing {DEFAULT_MESSAGE_TOOL_KWARG} argument")
|
||||
raise ValueError(f"Function call {tool_call.function.name} missing {assistant_message_tool_kwarg} argument")
|
||||
messages.append(
|
||||
AssistantMessage(
|
||||
id=self.id,
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import Field
|
||||
|
||||
from letta.orm.enums import JobType
|
||||
from letta.schemas.job import Job, JobBase
|
||||
from letta.schemas.letta_request import LettaRequestConfig
|
||||
|
||||
|
||||
class RunBase(JobBase):
|
||||
@@ -28,6 +29,7 @@ class Run(RunBase):
|
||||
|
||||
id: str = RunBase.generate_id_field()
|
||||
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the run.")
|
||||
request_config: Optional[LettaRequestConfig] = Field(None, description="The request configuration for the run.")
|
||||
|
||||
@classmethod
|
||||
def from_job(cls, job: Job) -> "Run":
|
||||
|
||||
@@ -502,9 +502,9 @@ async def send_message(
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
use_assistant_message=request.config.use_assistant_message,
|
||||
assistant_message_tool_name=request.config.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.config.assistant_message_tool_kwarg,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -542,9 +542,9 @@ async def send_message_streaming(
|
||||
stream_steps=True,
|
||||
stream_tokens=request.stream_tokens,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
use_assistant_message=request.config.use_assistant_message,
|
||||
assistant_message_tool_name=request.config.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.config.assistant_message_tool_kwarg,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -622,6 +622,7 @@ async def send_message_async(
|
||||
"job_type": "send_message_async",
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
request_config=request.config,
|
||||
)
|
||||
run = server.job_manager.create_job(pydantic_job=run, actor=actor)
|
||||
|
||||
@@ -633,9 +634,9 @@ async def send_message_async(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
messages=request.messages,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
use_assistant_message=request.config.use_assistant_message,
|
||||
assistant_message_tool_name=request.config.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.config.assistant_message_tool_kwarg,
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query
|
||||
@@ -6,7 +5,7 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query
|
||||
from letta.orm.enums import JobType
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.run import Run
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
@@ -61,21 +60,15 @@ def get_run(
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
|
||||
@router.get("/{run_id}/messages", response_model=List[Message], operation_id="get_run_messages")
|
||||
def get_run_messages(
|
||||
@router.get("/{run_id}/messages", response_model=List[LettaMessageUnion], operation_id="get_run_messages")
|
||||
async def get_run_messages(
|
||||
run_id: str,
|
||||
cursor: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
start_date: Optional[datetime] = Query(None, description="Filter messages after this date"),
|
||||
end_date: Optional[datetime] = Query(None, description="Filter messages before this date"),
|
||||
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
|
||||
query_text: Optional[str] = Query(None, description="Search text in message content"),
|
||||
ascending: bool = Query(True, description="Sort order by creation time"),
|
||||
tags: Optional[List[str]] = Query(None, description="Filter by message tags"),
|
||||
match_all_tags: bool = Query(False, description="If true, match all tags. If false, match any tag"),
|
||||
role: Optional[MessageRole] = Query(None, description="Filter by message role"),
|
||||
tool_name: Optional[str] = Query(None, description="Filter by tool call name"),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
cursor: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
|
||||
ascending: bool = Query(True, description="Sort order by creation time"),
|
||||
role: Optional[MessageRole] = Query(None, description="Filter by role"),
|
||||
):
|
||||
"""
|
||||
Get messages associated with a run with filtering options.
|
||||
@@ -83,33 +76,25 @@ def get_run_messages(
|
||||
Args:
|
||||
run_id: ID of the run
|
||||
cursor: Cursor for pagination
|
||||
start_date: Filter messages after this date
|
||||
end_date: Filter messages before this date
|
||||
limit: Maximum number of messages to return
|
||||
query_text: Search text in message content
|
||||
ascending: Sort order by creation time
|
||||
tags: Filter by message tags
|
||||
match_all_tags: If true, match all tags. If false, match any tag
|
||||
role: Filter by message role (user/assistant/system/tool)
|
||||
tool_name: Filter by tool call name
|
||||
role: Filter by role (user/assistant/system/tool)
|
||||
return_message_object: Whether to return Message objects or LettaMessage objects
|
||||
user_id: ID of the user making the request
|
||||
|
||||
Returns:
|
||||
A list of messages associated with the run. Default is List[LettaMessage].
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
try:
|
||||
messages = server.job_manager.get_job_messages(
|
||||
job_id=run_id,
|
||||
messages = server.job_manager.get_run_messages_cursor(
|
||||
run_id=run_id,
|
||||
actor=actor,
|
||||
cursor=cursor,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
query_text=query_text,
|
||||
cursor=cursor,
|
||||
ascending=ascending,
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
return messages
|
||||
except NoResultFound as e:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -14,6 +13,8 @@ from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.letta_request import LettaRequestConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.run import Run as PydanticRun
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
@@ -108,15 +109,9 @@ class JobManager:
|
||||
job_id: str,
|
||||
actor: PydanticUser,
|
||||
cursor: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 100,
|
||||
query_text: Optional[str] = None,
|
||||
ascending: bool = True,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
role: Optional[MessageRole] = None,
|
||||
tool_name: Optional[str] = None,
|
||||
ascending: bool = True,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Get all messages associated with a job.
|
||||
@@ -127,7 +122,7 @@ class JobManager:
|
||||
cursor: Cursor for pagination
|
||||
limit: Maximum number of messages to return
|
||||
role: Optional filter for message role
|
||||
tool_name: Optional filter for tool call name
|
||||
ascending: Optional flag to sort in ascending order
|
||||
|
||||
Returns:
|
||||
List of messages associated with the job
|
||||
@@ -145,24 +140,15 @@ class JobManager:
|
||||
messages = MessageModel.list(
|
||||
db_session=session,
|
||||
cursor=cursor,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
query_text=query_text,
|
||||
ascending=ascending,
|
||||
limit=limit,
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
actor=actor,
|
||||
join_model=JobMessage,
|
||||
join_conditions=[MessageModel.id == JobMessage.message_id, JobMessage.job_id == job_id],
|
||||
**filters,
|
||||
)
|
||||
|
||||
# Filter by tool name if specified
|
||||
if tool_name is not None:
|
||||
messages = [msg for msg in messages if msg.tool_calls and any(call.function.name == tool_name for call in msg.tool_calls)]
|
||||
|
||||
return [message.to_pydantic() for message in messages]
|
||||
return [message.to_pydantic() for message in messages]
|
||||
|
||||
@enforce_types
|
||||
def add_message_to_job(self, job_id: str, message_id: str, actor: PydanticUser) -> None:
|
||||
@@ -268,6 +254,58 @@ class JobManager:
|
||||
session.add(usage_stats)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def get_run_messages_cursor(
|
||||
self,
|
||||
run_id: str,
|
||||
actor: PydanticUser,
|
||||
cursor: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
role: Optional[MessageRole] = None,
|
||||
ascending: bool = True,
|
||||
) -> List[LettaMessage]:
|
||||
"""
|
||||
Get messages associated with a job using cursor-based pagination.
|
||||
This is a wrapper around get_job_messages that provides cursor-based pagination.
|
||||
|
||||
Args:
|
||||
job_id: The ID of the job to get messages for
|
||||
actor: The user making the request
|
||||
cursor: Message ID to get messages after or before
|
||||
limit: Maximum number of messages to return
|
||||
ascending: Whether to return messages in ascending order
|
||||
role: Optional role filter
|
||||
|
||||
Returns:
|
||||
List of LettaMessages associated with the job
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the job does not exist or user does not have access
|
||||
"""
|
||||
messages = self.get_job_messages(
|
||||
job_id=run_id,
|
||||
actor=actor,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
role=role,
|
||||
ascending=ascending,
|
||||
)
|
||||
|
||||
request_config = self._get_run_request_config(run_id)
|
||||
|
||||
# Convert messages to LettaMessages
|
||||
messages = [
|
||||
msg
|
||||
for m in messages
|
||||
for msg in m.to_letta_message(
|
||||
assistant_message=request_config["use_assistant_message"],
|
||||
assistant_message_tool_name=request_config["assistant_message_tool_name"],
|
||||
assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"],
|
||||
)
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
def _verify_job_access(
|
||||
self,
|
||||
session: Session,
|
||||
@@ -295,3 +333,18 @@ class JobManager:
|
||||
if not job:
|
||||
raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
|
||||
return job
|
||||
|
||||
def _get_run_request_config(self, run_id: str) -> LettaRequestConfig:
|
||||
"""
|
||||
Get the request config for a job.
|
||||
|
||||
Args:
|
||||
job_id: The ID of the job to get messages for
|
||||
|
||||
Returns:
|
||||
The request config for the job
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
job = session.query(JobModel).filter(JobModel.id == run_id).first()
|
||||
request_config = job.request_config or LettaRequestConfig()
|
||||
return request_config
|
||||
|
||||
@@ -571,8 +571,6 @@ def test_send_message_async(client: Union[LocalClient, RESTClient], agent: Agent
|
||||
assert len(assistant_messages) > 0
|
||||
tool_messages = client.get_run_messages(run_id=run.id, role=MessageRole.tool)
|
||||
assert len(tool_messages) > 0
|
||||
specific_tool_messages = client.get_run_messages(run_id=run.id, tool_name="send_message")
|
||||
assert len(specific_tool_messages) > 0
|
||||
|
||||
# Get and verify usage statistics
|
||||
usage = client.get_run_usage(run_id=run.id)[0]
|
||||
|
||||
@@ -42,6 +42,7 @@ from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate
|
||||
from letta.schemas.letta_request import LettaRequestConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
@@ -2769,27 +2770,59 @@ def test_job_messages_filter(server: SyncServer, default_run, default_user, sara
|
||||
assert len(user_messages) == 1
|
||||
assert user_messages[0].role == MessageRole.user
|
||||
|
||||
# Test filtering by tool name
|
||||
tool_messages = server.job_manager.get_job_messages(job_id=default_run.id, actor=default_user, tool_name="test_tool")
|
||||
assert len(tool_messages) == 1
|
||||
assert tool_messages[0].tool_calls[0].function.name == "test_tool"
|
||||
|
||||
# Test filtering by role and tool name
|
||||
assistant_tool_messages = server.job_manager.get_job_messages(
|
||||
job_id=default_run.id,
|
||||
actor=default_user,
|
||||
role=MessageRole.assistant,
|
||||
tool_name="test_tool",
|
||||
)
|
||||
assert len(assistant_tool_messages) == 1
|
||||
assert assistant_tool_messages[0].role == MessageRole.assistant
|
||||
assert assistant_tool_messages[0].tool_calls[0].function.name == "test_tool"
|
||||
|
||||
# Test limit
|
||||
limited_messages = server.job_manager.get_job_messages(job_id=default_run.id, actor=default_user, limit=2)
|
||||
assert len(limited_messages) == 2
|
||||
|
||||
|
||||
def test_get_run_messages_cursor(server: SyncServer, default_user: PydanticUser, sarah_agent):
|
||||
"""Test getting messages for a run with request config."""
|
||||
# Create a run with custom request config
|
||||
run = server.job_manager.create_job(
|
||||
pydantic_job=PydanticRun(
|
||||
user_id=default_user.id,
|
||||
status=JobStatus.created,
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=False, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg"
|
||||
),
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Add some messages
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user if i % 2 == 0 else MessageRole.assistant,
|
||||
text=f"Test message {i}",
|
||||
tool_calls=(
|
||||
[{"id": f"call_{i}", "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'}}] if i % 2 == 1 else None
|
||||
),
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
|
||||
for msg in messages:
|
||||
created_msg = server.message_manager.create_message(msg, actor=default_user)
|
||||
server.job_manager.add_message_to_job(job_id=run.id, message_id=created_msg.id, actor=default_user)
|
||||
|
||||
# Get messages and verify they're converted correctly
|
||||
result = server.job_manager.get_run_messages_cursor(run_id=run.id, actor=default_user)
|
||||
|
||||
# Verify correct number of messages. Assistant messages should be parsed
|
||||
assert len(result) == 6
|
||||
|
||||
# Verify assistant messages are parsed according to request config
|
||||
tool_call_messages = [msg for msg in result if msg.message_type == "tool_call_message"]
|
||||
reasoning_messages = [msg for msg in result if msg.message_type == "reasoning_message"]
|
||||
assert len(tool_call_messages) == 2
|
||||
assert len(reasoning_messages) == 2
|
||||
for msg in tool_call_messages:
|
||||
assert msg.tool_call is not None
|
||||
assert msg.tool_call.name == "custom_tool"
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# JobManager Tests - Usage Statistics
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -6,6 +6,7 @@ from composio.client.collections import ActionModel, ActionParametersModel, Acti
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.message import UserMessage
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.server.rest_api.app import app
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
@@ -339,62 +340,44 @@ def test_get_run_messages(client, mock_sync_server):
|
||||
"""Test getting messages for a run."""
|
||||
# Create properly formatted mock messages
|
||||
current_time = datetime.utcnow()
|
||||
messages_data = [
|
||||
{
|
||||
"id": f"message-{i:08x}", # Matches pattern '^message-[a-fA-F0-9]{8}'
|
||||
"text": f"Test message {i}",
|
||||
"role": "user",
|
||||
"organization_id": "org-123",
|
||||
"agent_id": "agent-123",
|
||||
"model": "gpt-4",
|
||||
"name": "test-user",
|
||||
"tool_calls": [],
|
||||
"tool_call_id": None,
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
"created_by_id": "user-123",
|
||||
"last_updated_by_id": "user-123",
|
||||
}
|
||||
mock_messages = [
|
||||
UserMessage(
|
||||
id=f"message-{i:08x}",
|
||||
date=current_time,
|
||||
message=f"Test message {i}",
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
mock_messages = []
|
||||
for msg_data in messages_data:
|
||||
mock_msg = Mock()
|
||||
for key, value in msg_data.items():
|
||||
setattr(mock_msg, key, value)
|
||||
mock_messages.append(mock_msg)
|
||||
|
||||
# Configure mock server responses
|
||||
mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
|
||||
mock_sync_server.job_manager.get_job_messages.return_value = mock_messages
|
||||
mock_sync_server.job_manager.get_run_messages_cursor.return_value = mock_messages
|
||||
|
||||
# Test successful retrieval
|
||||
response = client.get(
|
||||
"/v1/runs/run-12345678/messages",
|
||||
headers={"user_id": "user-123"},
|
||||
params={"limit": 10, "cursor": messages_data[1]["id"], "role": "user"},
|
||||
params={
|
||||
"limit": 10,
|
||||
"cursor": mock_messages[0].id,
|
||||
"role": "user",
|
||||
"ascending": True,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
assert response.json()[0]["id"] == messages_data[0]["id"]
|
||||
assert response.json()[1]["id"] == messages_data[1]["id"]
|
||||
assert response.json()[0]["id"] == mock_messages[0].id
|
||||
assert response.json()[1]["id"] == mock_messages[1].id
|
||||
|
||||
# Verify mock calls
|
||||
mock_sync_server.user_manager.get_user_or_default.assert_called_once_with(user_id="user-123")
|
||||
mock_sync_server.job_manager.get_job_messages.assert_called_once_with(
|
||||
job_id="run-12345678",
|
||||
mock_sync_server.job_manager.get_run_messages_cursor.assert_called_once_with(
|
||||
run_id="run-12345678",
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
|
||||
limit=10,
|
||||
cursor=messages_data[1]["id"],
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
query_text=None,
|
||||
cursor=mock_messages[0].id,
|
||||
ascending=True,
|
||||
tags=None,
|
||||
match_all_tags=False,
|
||||
role="user",
|
||||
tool_name=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -403,7 +386,7 @@ def test_get_run_messages_not_found(client, mock_sync_server):
|
||||
# Configure mock responses
|
||||
error_message = "Run 'run-nonexistent' not found"
|
||||
mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
|
||||
mock_sync_server.job_manager.get_job_messages.side_effect = NoResultFound(error_message)
|
||||
mock_sync_server.job_manager.get_run_messages_cursor.side_effect = NoResultFound(error_message)
|
||||
|
||||
response = client.get("/v1/runs/run-nonexistent/messages", headers={"user_id": "user-123"})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user