feat: add conversation_id parameter to context endpoint [LET-6989] (#8678)
* feat: add conversation_id parameter to context endpoint [LET-6989] Add optional conversation_id query parameter to retrieve_agent_context_window. When provided, the endpoint uses messages from the specific conversation instead of the agent's default message_ids. 👾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * chore: regenerate SDK after context endpoint update [LET-6989] 👾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> --------- Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
@@ -452,12 +452,15 @@ async def retrieve_agent_context_window(
|
||||
agent_id: AgentId,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
conversation_id: Optional[str] = Query(
|
||||
None, description="Conversation ID to get context window for. If provided, uses messages from this conversation."
|
||||
),
|
||||
):
|
||||
"""
|
||||
Retrieve the context window of a specific agent.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
return await server.agent_manager.get_context_window(agent_id=agent_id, actor=actor)
|
||||
return await server.agent_manager.get_context_window(agent_id=agent_id, actor=actor, conversation_id=conversation_id)
|
||||
|
||||
|
||||
class CreateAgentRequest(CreateAgent):
|
||||
@@ -2216,7 +2219,7 @@ async def capture_messages(
|
||||
messages_to_persist.append(
|
||||
Message(
|
||||
role=MessageRole.user,
|
||||
content=[(TextContent(text=message["content"]))],
|
||||
content=[TextContent(text=message["content"])],
|
||||
agent_id=agent_id,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
@@ -2228,7 +2231,7 @@ async def capture_messages(
|
||||
messages_to_persist.append(
|
||||
Message(
|
||||
role=MessageRole.assistant,
|
||||
content=[(TextContent(text=request.response_dict["content"]))],
|
||||
content=[TextContent(text=request.response_dict["content"])],
|
||||
agent_id=agent_id,
|
||||
model=request.model,
|
||||
tool_calls=None,
|
||||
|
||||
@@ -83,6 +83,7 @@ from letta.services.archive_manager import ArchiveManager
|
||||
from letta.services.block_manager import BlockManager, validate_block_limit_constraint
|
||||
from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator
|
||||
from letta.services.context_window_calculator.token_counter import create_token_counter
|
||||
from letta.services.conversation_manager import ConversationManager
|
||||
from letta.services.file_processor.chunker.line_chunker import LineChunker
|
||||
from letta.services.files_agents_manager import FileAgentManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
@@ -137,6 +138,7 @@ class AgentManager:
|
||||
self.identity_manager = IdentityManager()
|
||||
self.file_agent_manager = FileAgentManager()
|
||||
self.archive_manager = ArchiveManager()
|
||||
self.conversation_manager = ConversationManager()
|
||||
|
||||
@staticmethod
|
||||
def _should_exclude_model_from_base_tool_rules(model: str) -> bool:
|
||||
@@ -3388,7 +3390,7 @@ class AgentManager:
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@trace_method
|
||||
async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
|
||||
async def get_context_window(self, agent_id: str, actor: PydanticUser, conversation_id: Optional[str] = None) -> ContextWindowOverview:
|
||||
agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async(
|
||||
agent_id=agent_id, actor=actor, force=True, dry_run=True
|
||||
)
|
||||
@@ -3402,6 +3404,16 @@ class AgentManager:
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
# If conversation_id is provided, get message_ids from the conversation
|
||||
# Skip the first message ID (system message) since it's passed separately
|
||||
message_ids = None
|
||||
if conversation_id is not None:
|
||||
conversation_message_ids = await self.conversation_manager.get_message_ids_for_conversation(
|
||||
conversation_id=conversation_id, actor=actor
|
||||
)
|
||||
# Skip the system message (first message) as it's handled separately
|
||||
message_ids = conversation_message_ids[1:] if conversation_message_ids else []
|
||||
|
||||
try:
|
||||
result = await calculator.calculate_context_window(
|
||||
agent_state=agent_state,
|
||||
@@ -3411,6 +3423,7 @@ class AgentManager:
|
||||
system_message_compiled=system_message,
|
||||
num_archival_memories=num_archival_memories,
|
||||
num_messages=num_messages,
|
||||
message_ids=message_ids,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -105,9 +105,17 @@ class ContextWindowCalculator:
|
||||
system_message_compiled: Message,
|
||||
num_archival_memories: int,
|
||||
num_messages: int,
|
||||
message_ids: Optional[List[str]] = None,
|
||||
) -> ContextWindowOverview:
|
||||
"""Calculate context window information using the provided token counter"""
|
||||
messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids[1:], actor=actor)
|
||||
"""Calculate context window information using the provided token counter
|
||||
|
||||
Args:
|
||||
message_ids: Optional list of message IDs to use instead of agent_state.message_ids.
|
||||
If provided, should NOT include the system message ID (index 0).
|
||||
"""
|
||||
# Use provided message_ids or fall back to agent_state.message_ids[1:]
|
||||
effective_message_ids = message_ids if message_ids is not None else agent_state.message_ids[1:]
|
||||
messages = await message_manager.get_messages_by_ids_async(message_ids=effective_message_ids, actor=actor)
|
||||
in_context_messages = [system_message_compiled] + messages
|
||||
|
||||
# Filter out None messages (can occur when system message is missing)
|
||||
|
||||
Reference in New Issue
Block a user