diff --git a/fern/openapi.json b/fern/openapi.json index e876c8f9..6e6f53ff 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -16459,17 +16459,6 @@ "title": "AssistantMessage", "description": "A message sent by the LLM in response to user input. Used in the LLM context.\n\nArgs:\n id (str): The ID of the message\n date (datetime): The date the message was created in ISO format\n name (Optional[str]): The name of the sender of the message\n content (Union[str, List[LettaAssistantMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts)" }, - "Audio": { - "properties": { - "id": { - "type": "string", - "title": "Id" - } - }, - "type": "object", - "required": ["id"], - "title": "Audio" - }, "AuthRequest": { "properties": { "password": { @@ -17459,346 +17448,14 @@ "type": "object", "title": "CancelAgentRunRequest" }, - "ChatCompletionAllowedToolChoiceParam": { - "properties": { - "allowed_tools": { - "$ref": "#/components/schemas/ChatCompletionAllowedToolsParam" - }, - "type": { - "type": "string", - "const": "allowed_tools", - "title": "Type" - } - }, - "type": "object", - "required": ["allowed_tools", "type"], - "title": "ChatCompletionAllowedToolChoiceParam" - }, - "ChatCompletionAllowedToolsParam": { - "properties": { - "mode": { - "type": "string", - "enum": ["auto", "required"], - "title": "Mode" - }, - "tools": { - "items": { - "additionalProperties": true, - "type": "object" - }, - "type": "array", - "title": "Tools" - } - }, - "type": "object", - "required": ["mode", "tools"], - "title": "ChatCompletionAllowedToolsParam" - }, - "ChatCompletionAssistantMessageParam": { - "properties": { - "role": { - "type": "string", - "const": "assistant", - "title": "Role" - }, - "audio": { - "anyOf": [ - { - "$ref": "#/components/schemas/Audio" - }, - { - "type": "null" - } - ] - }, - "content": { - "anyOf": [ - { - "type": "string" - }, - { - "items": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionContentPartTextParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionContentPartRefusalParam" - } - ] - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Content" - }, - "function_call": { - "anyOf": [ - { - "$ref": "#/components/schemas/FunctionCall" - }, - { - "type": "null" - } - ] - }, - "name": { - "type": "string", - "title": "Name" - }, - "refusal": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Refusal" - }, - "tool_calls": { - "items": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionMessageFunctionToolCallParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionMessageCustomToolCallParam" - } - ] - }, - "type": "array", - "title": "Tool Calls" - } - }, - "type": "object", - "required": ["role"], - "title": "ChatCompletionAssistantMessageParam" - }, - "ChatCompletionAudioParam": { - "properties": { - "format": { - "type": "string", - "enum": ["wav", "aac", "mp3", "flac", "opus", "pcm16"], - "title": "Format" - }, - "voice": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "string", - "enum": [ - "alloy", - "ash", - "ballad", - "coral", - "echo", - "sage", - "shimmer", - "verse" - ] - } - ], - "title": "Voice" - } - }, - "type": "object", - "required": ["format", "voice"], - "title": "ChatCompletionAudioParam" - }, - "ChatCompletionContentPartImageParam": { - "properties": { - "image_url": { - "$ref": "#/components/schemas/ImageURL" - }, - "type": { - "type": "string", - "const": "image_url", - "title": "Type" - } - }, - "type": "object", - "required": ["image_url", "type"], - "title": "ChatCompletionContentPartImageParam" - }, - "ChatCompletionContentPartInputAudioParam": { - "properties": { - "input_audio": { - "$ref": "#/components/schemas/InputAudio" - }, - "type": { - "type": "string", - "const": "input_audio", - "title": "Type" - } - }, - "type": "object", - "required": ["input_audio", "type"], - "title": "ChatCompletionContentPartInputAudioParam" - }, - "ChatCompletionContentPartRefusalParam": { - "properties": { - "refusal": { - "type": "string", - "title": "Refusal" - }, - "type": { - "type": "string", - "const": "refusal", - "title": "Type" - } - }, - "type": "object", - "required": ["refusal", "type"], - "title": "ChatCompletionContentPartRefusalParam" - }, - "ChatCompletionContentPartTextParam": { - "properties": { - "text": { - "type": "string", - "title": "Text" - }, - "type": { - "type": "string", - "const": "text", - "title": "Type" - } - }, - "type": "object", - "required": ["text", "type"], - "title": "ChatCompletionContentPartTextParam" - }, - "ChatCompletionCustomToolParam": { - "properties": { - "custom": { - "$ref": "#/components/schemas/openai__types__chat__chat_completion_custom_tool_param__Custom" - }, - "type": { - "type": "string", - "const": "custom", - "title": "Type" - } - }, - "type": "object", - "required": ["custom", "type"], - "title": "ChatCompletionCustomToolParam" - }, - "ChatCompletionDeveloperMessageParam": { - "properties": { - "content": { - "anyOf": [ - { - "type": "string" - }, - { - "items": { - "$ref": "#/components/schemas/ChatCompletionContentPartTextParam" - }, - "type": "array" - } - ], - "title": "Content" - }, - "role": { - "type": "string", - "const": "developer", - "title": "Role" - }, - "name": { - "type": "string", - "title": "Name" - } - }, - "type": "object", - "required": ["content", "role"], - "title": "ChatCompletionDeveloperMessageParam" - }, - "ChatCompletionFunctionCallOptionParam": { - "properties": { - "name": { - "type": "string", - "title": "Name" - } - }, - "type": "object", - "required": ["name"], - "title": "ChatCompletionFunctionCallOptionParam" - }, - "ChatCompletionFunctionMessageParam": { - "properties": { - "content": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Content" - }, - "name": { - "type": "string", - "title": "Name" - }, - "role": { - "type": "string", - "const": "function", - "title": "Role" - } - }, - "type": "object", - "required": ["content", "name", "role"], - "title": "ChatCompletionFunctionMessageParam" - }, - "ChatCompletionFunctionToolParam": { - "properties": { - "function": { - "$ref": "#/components/schemas/FunctionDefinition-Input" - }, - "type": { - "type": "string", - "const": "function", - "title": "Type" - } - }, - "type": "object", - "required": ["function", "type"], - "title": "ChatCompletionFunctionToolParam" - }, - "ChatCompletionMessageCustomToolCallParam": { - "properties": { - "id": { - "type": "string", - "title": "Id" - }, - "custom": { - "$ref": "#/components/schemas/openai__types__chat__chat_completion_message_custom_tool_call_param__Custom" - }, - "type": { - "type": "string", - "const": "custom", - "title": "Type" - } - }, - "type": "object", - "required": ["id", "custom", "type"], - "title": "ChatCompletionMessageCustomToolCallParam" - }, - "ChatCompletionMessageFunctionToolCall-Input": { + "ChatCompletionMessageFunctionToolCall": { "properties": { "id": { "type": "string", "title": "Id" }, "function": { - "$ref": "#/components/schemas/openai__types__chat__chat_completion_message_function_tool_call__Function" + "$ref": "#/components/schemas/Function" }, "type": { "type": "string", @@ -17811,218 +17468,6 @@ "required": ["id", "function", "type"], "title": "ChatCompletionMessageFunctionToolCall" }, - "ChatCompletionMessageFunctionToolCall-Output": { - "properties": { - "id": { - "type": "string", - "title": "Id" - }, - "function": { - "$ref": "#/components/schemas/Function-Output" - }, - "type": { - "type": "string", - "const": "function", - "title": "Type" - } - }, - "additionalProperties": true, - "type": "object", - "required": ["id", "function", "type"], - "title": "ChatCompletionMessageFunctionToolCall" - }, - "ChatCompletionMessageFunctionToolCallParam": { - "properties": { - "id": { - "type": "string", - "title": "Id" - }, - "function": { - "$ref": "#/components/schemas/openai__types__chat__chat_completion_message_function_tool_call_param__Function" - }, - "type": { - "type": "string", - "const": "function", - "title": "Type" - } - }, - "type": "object", - "required": ["id", "function", "type"], - "title": "ChatCompletionMessageFunctionToolCallParam" - }, - "ChatCompletionNamedToolChoiceCustomParam": { - "properties": { - "custom": { - "$ref": "#/components/schemas/openai__types__chat__chat_completion_named_tool_choice_custom_param__Custom" - }, - "type": { - "type": "string", - "const": "custom", - "title": "Type" - } - }, - "type": "object", - "required": ["custom", "type"], - "title": "ChatCompletionNamedToolChoiceCustomParam" - }, - "ChatCompletionNamedToolChoiceParam": { - "properties": { - "function": { - "$ref": "#/components/schemas/openai__types__chat__chat_completion_named_tool_choice_param__Function" - }, - "type": { - "type": "string", - "const": "function", - "title": "Type" - } - }, - "type": "object", - "required": ["function", "type"], - "title": "ChatCompletionNamedToolChoiceParam" - }, - "ChatCompletionPredictionContentParam": { - "properties": { - "content": { - "anyOf": [ - { - "type": "string" - }, - { - "items": { - "$ref": "#/components/schemas/ChatCompletionContentPartTextParam" - }, - "type": "array" - } - ], - "title": "Content" - }, - "type": { - "type": "string", - "const": "content", - "title": "Type" - } - }, - "type": "object", - "required": ["content", "type"], - "title": "ChatCompletionPredictionContentParam" - }, - "ChatCompletionStreamOptionsParam": { - "properties": { - "include_obfuscation": { - "type": "boolean", - "title": "Include Obfuscation" - }, - "include_usage": { - "type": "boolean", - "title": "Include Usage" - } - }, - "type": "object", - "title": "ChatCompletionStreamOptionsParam" - }, - "ChatCompletionSystemMessageParam": { - "properties": { - "content": { - "anyOf": [ - { - "type": "string" - }, - { - "items": { - "$ref": "#/components/schemas/ChatCompletionContentPartTextParam" - }, - "type": "array" - } - ], - "title": "Content" - }, - "role": { - "type": "string", - "const": "system", - "title": "Role" - }, - "name": { - "type": "string", - "title": "Name" - } - }, - "type": "object", - "required": ["content", "role"], - "title": "ChatCompletionSystemMessageParam" - }, - "ChatCompletionToolMessageParam": { - "properties": { - "content": { - "anyOf": [ - { - "type": "string" - }, - { - "items": { - "$ref": "#/components/schemas/ChatCompletionContentPartTextParam" - }, - "type": "array" - } - ], - "title": "Content" - }, - "role": { - "type": "string", - "const": "tool", - "title": "Role" - }, - "tool_call_id": { - "type": "string", - "title": "Tool Call Id" - } - }, - "type": "object", - "required": ["content", "role", "tool_call_id"], - "title": "ChatCompletionToolMessageParam" - }, - "ChatCompletionUserMessageParam": { - "properties": { - "content": { - "anyOf": [ - { - "type": "string" - }, - { - "items": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionContentPartTextParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionContentPartImageParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionContentPartInputAudioParam" - }, - { - "$ref": "#/components/schemas/File" - } - ] - }, - "type": "array" - } - ], - "title": "Content" - }, - "role": { - "type": "string", - "const": "user", - "title": "Role" - }, - "name": { - "type": "string", - "title": "Name" - } - }, - "type": "object", - "required": ["content", "role"], - "title": "ChatCompletionUserMessageParam" - }, "ChildToolRule": { "properties": { "tool_name": { @@ -18110,901 +17555,6 @@ "required": ["code"], "title": "CodeInput" }, - "CompletionCreateParamsNonStreaming": { - "properties": { - "messages": { - "items": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionDeveloperMessageParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionSystemMessageParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionUserMessageParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionAssistantMessageParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionToolMessageParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionFunctionMessageParam" - } - ] - }, - "type": "array", - "title": "Messages" - }, - "model": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "string", - "enum": [ - "gpt-5", - "gpt-5-mini", - "gpt-5-nano", - "gpt-5-2025-08-07", - "gpt-5-mini-2025-08-07", - "gpt-5-nano-2025-08-07", - "gpt-5-chat-latest", - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4.1-nano", - "gpt-4.1-2025-04-14", - "gpt-4.1-mini-2025-04-14", - "gpt-4.1-nano-2025-04-14", - "o4-mini", - "o4-mini-2025-04-16", - "o3", - "o3-2025-04-16", - "o3-mini", - "o3-mini-2025-01-31", - "o1", - "o1-2024-12-17", - "o1-preview", - "o1-preview-2024-09-12", - "o1-mini", - "o1-mini-2024-09-12", - "gpt-4o", - "gpt-4o-2024-11-20", - "gpt-4o-2024-08-06", - "gpt-4o-2024-05-13", - "gpt-4o-audio-preview", - "gpt-4o-audio-preview-2024-10-01", - "gpt-4o-audio-preview-2024-12-17", - "gpt-4o-audio-preview-2025-06-03", - "gpt-4o-mini-audio-preview", - "gpt-4o-mini-audio-preview-2024-12-17", - "gpt-4o-search-preview", - "gpt-4o-mini-search-preview", - "gpt-4o-search-preview-2025-03-11", - "gpt-4o-mini-search-preview-2025-03-11", - "chatgpt-4o-latest", - "codex-mini-latest", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-4-0125-preview", - "gpt-4-turbo-preview", - "gpt-4-1106-preview", - "gpt-4-vision-preview", - "gpt-4", - "gpt-4-0314", - "gpt-4-0613", - "gpt-4-32k", - "gpt-4-32k-0314", - "gpt-4-32k-0613", - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-16k-0613" - ] - } - ], - "title": "Model" - }, - "audio": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionAudioParam" - }, - { - "type": "null" - } - ] - }, - "frequency_penalty": { - "anyOf": [ - { - "type": "number" - }, - { - "type": "null" - } - ], - "title": "Frequency Penalty" - }, - "function_call": { - "anyOf": [ - { - "type": "string", - "enum": ["none", "auto"] - }, - { - "$ref": "#/components/schemas/ChatCompletionFunctionCallOptionParam" - } - ], - "title": "Function Call" - }, - "functions": { - "items": { - "$ref": "#/components/schemas/openai__types__chat__completion_create_params__Function" - }, - "type": "array", - "title": "Functions" - }, - "logit_bias": { - "anyOf": [ - { - "additionalProperties": { - "type": "integer" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Logit Bias" - }, - "logprobs": { - "anyOf": [ - { - "type": "boolean" - }, - { - "type": "null" - } - ], - "title": "Logprobs" - }, - "max_completion_tokens": { - "anyOf": [ - { - "type": "integer" - }, - { - "type": "null" - } - ], - "title": "Max Completion Tokens" - }, - "max_tokens": { - "anyOf": [ - { - "type": "integer" - }, - { - "type": "null" - } - ], - "title": "Max Tokens" - }, - "metadata": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Metadata" - }, - "modalities": { - "anyOf": [ - { - "items": { - "type": "string", - "enum": ["text", "audio"] - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Modalities" - }, - "n": { - "anyOf": [ - { - "type": "integer" - }, - { - "type": "null" - } - ], - "title": "N" - }, - "parallel_tool_calls": { - "type": "boolean", - "title": "Parallel Tool Calls" - }, - "prediction": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionPredictionContentParam" - }, - { - "type": "null" - } - ] - }, - "presence_penalty": { - "anyOf": [ - { - "type": "number" - }, - { - "type": "null" - } - ], - "title": "Presence Penalty" - }, - "prompt_cache_key": { - "type": "string", - "title": "Prompt Cache Key" - }, - "reasoning_effort": { - "anyOf": [ - { - "type": "string", - "enum": ["minimal", "low", "medium", "high"] - }, - { - "type": "null" - } - ], - "title": "Reasoning Effort" - }, - "response_format": { - "anyOf": [ - { - "$ref": "#/components/schemas/ResponseFormatText" - }, - { - "$ref": "#/components/schemas/ResponseFormatJSONSchema" - }, - { - "$ref": "#/components/schemas/ResponseFormatJSONObject" - } - ], - "title": "Response Format" - }, - "safety_identifier": { - "type": "string", - "title": "Safety Identifier" - }, - "seed": { - "anyOf": [ - { - "type": "integer" - }, - { - "type": "null" - } - ], - "title": "Seed" - }, - "service_tier": { - "anyOf": [ - { - "type": "string", - "enum": ["auto", "default", "flex", "scale", "priority"] - }, - { - "type": "null" - } - ], - "title": "Service Tier" - }, - "stop": { - "anyOf": [ - { - "type": "string" - }, - { - "items": { - "type": "string" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Stop" - }, - "store": { - "anyOf": [ - { - "type": "boolean" - }, - { - "type": "null" - } - ], - "title": "Store" - }, - "stream_options": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionStreamOptionsParam" - }, - { - "type": "null" - } - ] - }, - "temperature": { - "anyOf": [ - { - "type": "number" - }, - { - "type": "null" - } - ], - "title": "Temperature" - }, - "tool_choice": { - "anyOf": [ - { - "type": "string", - "enum": ["none", "auto", "required"] - }, - { - "$ref": "#/components/schemas/ChatCompletionAllowedToolChoiceParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionNamedToolChoiceParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionNamedToolChoiceCustomParam" - } - ], - "title": "Tool Choice" - }, - "tools": { - "items": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionFunctionToolParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionCustomToolParam" - } - ] - }, - "type": "array", - "title": "Tools" - }, - "top_logprobs": { - "anyOf": [ - { - "type": "integer" - }, - { - "type": "null" - } - ], - "title": "Top Logprobs" - }, - "top_p": { - "anyOf": [ - { - "type": "number" - }, - { - "type": "null" - } - ], - "title": "Top P" - }, - "user": { - "type": "string", - "title": "User" - }, - "verbosity": { - "anyOf": [ - { - "type": "string", - "enum": ["low", "medium", "high"] - }, - { - "type": "null" - } - ], - "title": "Verbosity" - }, - "web_search_options": { - "$ref": "#/components/schemas/WebSearchOptions" - }, - "stream": { - "anyOf": [ - { - "type": "boolean", - "const": false - }, - { - "type": "null" - } - ], - "title": "Stream" - } - }, - "type": "object", - "required": ["messages", "model"], - "title": "CompletionCreateParamsNonStreaming" - }, - "CompletionCreateParamsStreaming": { - "properties": { - "messages": { - "items": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionDeveloperMessageParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionSystemMessageParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionUserMessageParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionAssistantMessageParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionToolMessageParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionFunctionMessageParam" - } - ] - }, - "type": "array", - "title": "Messages" - }, - "model": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "string", - "enum": [ - "gpt-5", - "gpt-5-mini", - "gpt-5-nano", - "gpt-5-2025-08-07", - "gpt-5-mini-2025-08-07", - "gpt-5-nano-2025-08-07", - "gpt-5-chat-latest", - "gpt-4.1", - "gpt-4.1-mini", - "gpt-4.1-nano", - "gpt-4.1-2025-04-14", - "gpt-4.1-mini-2025-04-14", - "gpt-4.1-nano-2025-04-14", - "o4-mini", - "o4-mini-2025-04-16", - "o3", - "o3-2025-04-16", - "o3-mini", - "o3-mini-2025-01-31", - "o1", - "o1-2024-12-17", - "o1-preview", - "o1-preview-2024-09-12", - "o1-mini", - "o1-mini-2024-09-12", - "gpt-4o", - "gpt-4o-2024-11-20", - "gpt-4o-2024-08-06", - "gpt-4o-2024-05-13", - "gpt-4o-audio-preview", - "gpt-4o-audio-preview-2024-10-01", - "gpt-4o-audio-preview-2024-12-17", - "gpt-4o-audio-preview-2025-06-03", - "gpt-4o-mini-audio-preview", - "gpt-4o-mini-audio-preview-2024-12-17", - "gpt-4o-search-preview", - "gpt-4o-mini-search-preview", - "gpt-4o-search-preview-2025-03-11", - "gpt-4o-mini-search-preview-2025-03-11", - "chatgpt-4o-latest", - "codex-mini-latest", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-4-0125-preview", - "gpt-4-turbo-preview", - "gpt-4-1106-preview", - "gpt-4-vision-preview", - "gpt-4", - "gpt-4-0314", - "gpt-4-0613", - "gpt-4-32k", - "gpt-4-32k-0314", - "gpt-4-32k-0613", - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-0301", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-16k-0613" - ] - } - ], - "title": "Model" - }, - "audio": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionAudioParam" - }, - { - "type": "null" - } - ] - }, - "frequency_penalty": { - "anyOf": [ - { - "type": "number" - }, - { - "type": "null" - } - ], - "title": "Frequency Penalty" - }, - "function_call": { - "anyOf": [ - { - "type": "string", - "enum": ["none", "auto"] - }, - { - "$ref": "#/components/schemas/ChatCompletionFunctionCallOptionParam" - } - ], - "title": "Function Call" - }, - "functions": { - "items": { - "$ref": "#/components/schemas/openai__types__chat__completion_create_params__Function" - }, - "type": "array", - "title": "Functions" - }, - "logit_bias": { - "anyOf": [ - { - "additionalProperties": { - "type": "integer" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Logit Bias" - }, - "logprobs": { - "anyOf": [ - { - "type": "boolean" - }, - { - "type": "null" - } - ], - "title": "Logprobs" - }, - "max_completion_tokens": { - "anyOf": [ - { - "type": "integer" - }, - { - "type": "null" - } - ], - "title": "Max Completion Tokens" - }, - "max_tokens": { - "anyOf": [ - { - "type": "integer" - }, - { - "type": "null" - } - ], - "title": "Max Tokens" - }, - "metadata": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Metadata" - }, - "modalities": { - "anyOf": [ - { - "items": { - "type": "string", - "enum": ["text", "audio"] - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Modalities" - }, - "n": { - "anyOf": [ - { - "type": "integer" - }, - { - "type": "null" - } - ], - "title": "N" - }, - "parallel_tool_calls": { - "type": "boolean", - "title": "Parallel Tool Calls" - }, - "prediction": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionPredictionContentParam" - }, - { - "type": "null" - } - ] - }, - "presence_penalty": { - "anyOf": [ - { - "type": "number" - }, - { - "type": "null" - } - ], - "title": "Presence Penalty" - }, - "prompt_cache_key": { - "type": "string", - "title": "Prompt Cache Key" - }, - "reasoning_effort": { - "anyOf": [ - { - "type": "string", - "enum": ["minimal", "low", "medium", "high"] - }, - { - "type": "null" - } - ], - "title": "Reasoning Effort" - }, - "response_format": { - "anyOf": [ - { - "$ref": "#/components/schemas/ResponseFormatText" - }, - { - "$ref": "#/components/schemas/ResponseFormatJSONSchema" - }, - { - "$ref": "#/components/schemas/ResponseFormatJSONObject" - } - ], - "title": "Response Format" - }, - "safety_identifier": { - "type": "string", - "title": "Safety Identifier" - }, - "seed": { - "anyOf": [ - { - "type": "integer" - }, - { - "type": "null" - } - ], - "title": "Seed" - }, - "service_tier": { - "anyOf": [ - { - "type": "string", - "enum": ["auto", "default", "flex", "scale", "priority"] - }, - { - "type": "null" - } - ], - "title": "Service Tier" - }, - "stop": { - "anyOf": [ - { - "type": "string" - }, - { - "items": { - "type": "string" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Stop" - }, - "store": { - "anyOf": [ - { - "type": "boolean" - }, - { - "type": "null" - } - ], - "title": "Store" - }, - "stream_options": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionStreamOptionsParam" - }, - { - "type": "null" - } - ] - }, - "temperature": { - "anyOf": [ - { - "type": "number" - }, - { - "type": "null" - } - ], - "title": "Temperature" - }, - "tool_choice": { - "anyOf": [ - { - "type": "string", - "enum": ["none", "auto", "required"] - }, - { - "$ref": "#/components/schemas/ChatCompletionAllowedToolChoiceParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionNamedToolChoiceParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionNamedToolChoiceCustomParam" - } - ], - "title": "Tool Choice" - }, - "tools": { - "items": { - "anyOf": [ - { - "$ref": "#/components/schemas/ChatCompletionFunctionToolParam" - }, - { - "$ref": "#/components/schemas/ChatCompletionCustomToolParam" - } - ] - }, - "type": "array", - "title": "Tools" - }, - "top_logprobs": { - "anyOf": [ - { - "type": "integer" - }, - { - "type": "null" - } - ], - "title": "Top Logprobs" - }, - "top_p": { - "anyOf": [ - { - "type": "number" - }, - { - "type": "null" - } - ], - "title": "Top P" - }, - "user": { - "type": "string", - "title": "User" - }, - "verbosity": { - "anyOf": [ - { - "type": "string", - "enum": ["low", "medium", "high"] - }, - { - "type": "null" - } - ], - "title": "Verbosity" - }, - "web_search_options": { - "$ref": "#/components/schemas/WebSearchOptions" - }, - "stream": { - "type": "boolean", - "const": true, - "title": "Stream" - } - }, - "type": "object", - "required": ["messages", "model", "stream"], - "title": "CompletionCreateParamsStreaming" - }, "ConditionalToolRule": { "properties": { "tool_name": { @@ -20146,49 +18696,6 @@ "title": "CreateBlock", "description": "Create a block" }, - "CustomFormatGrammar": { - "properties": { - "grammar": { - "$ref": "#/components/schemas/CustomFormatGrammarGrammar" - }, - "type": { - "type": "string", - "const": "grammar", - "title": "Type" - } - }, - "type": "object", - "required": ["grammar", "type"], - "title": "CustomFormatGrammar" - }, - "CustomFormatGrammarGrammar": { - "properties": { - "definition": { - "type": "string", - "title": "Definition" - }, - "syntax": { - "type": "string", - "enum": ["lark", "regex"], - "title": "Syntax" - } - }, - "type": "object", - "required": ["definition", "syntax"], - "title": "CustomFormatGrammarGrammar" - }, - "CustomFormatText": { - "properties": { - "type": { - "type": "string", - "const": "text", - "title": "Type" - } - }, - "type": "object", - "required": ["type"], - "title": "CustomFormatText" - }, "DeleteDeploymentResponse": { "properties": { "deleted_blocks": { @@ -20555,21 +19062,6 @@ "enum": ["positive", "negative"], "title": "FeedbackType" }, - "File": { - "properties": { - "file": { - "$ref": "#/components/schemas/FileFile" - }, - "type": { - "type": "string", - "const": "file", - "title": "Type" - } - }, - "type": "object", - "required": ["file", "type"], - "title": "File" - }, "FileAgentSchema": { "properties": { "agent_id": { @@ -20871,24 +19363,6 @@ "required": ["value", "file_id", "source_id", "is_open"], "title": "FileBlock" }, - "FileFile": { - "properties": { - "file_data": { - "type": "string", - "title": "File Data" - }, - "file_id": { - "type": "string", - "title": "File Id" - }, - "filename": { - "type": "string", - "title": "Filename" - } - }, - "type": "object", - "title": "FileFile" - }, "FileMetadata": { "properties": { "source_id": { @@ -21377,7 +19851,7 @@ "title": "Folder", "description": "Representation of a folder, which is a collection of files and passages.\n\nParameters:\n id (str): The ID of the folder\n name (str): The name of the folder.\n embedding_config (EmbeddingConfig): The embedding configuration used by the folder.\n user_id (str): The ID of the user that created the folder.\n metadata (dict): Metadata associated with the folder.\n description (str): The description of the folder." }, - "Function-Output": { + "Function": { "properties": { "arguments": { "type": "string", @@ -21393,53 +19867,7 @@ "required": ["arguments", "name"], "title": "Function" }, - "FunctionCall": { - "properties": { - "arguments": { - "type": "string", - "title": "Arguments" - }, - "name": { - "type": "string", - "title": "Name" - } - }, - "type": "object", - "required": ["arguments", "name"], - "title": "FunctionCall" - }, - "FunctionDefinition-Input": { - "properties": { - "name": { - "type": "string", - "title": "Name" - }, - "description": { - "type": "string", - "title": "Description" - }, - "parameters": { - "additionalProperties": true, - "type": "object", - "title": "Parameters" - }, - "strict": { - "anyOf": [ - { - "type": "boolean" - }, - { - "type": "null" - } - ], - "title": "Strict" - } - }, - "type": "object", - "required": ["name"], - "title": "FunctionDefinition" - }, - "FunctionDefinition-Output": { + "FunctionDefinition": { "properties": { "name": { "type": "string", @@ -21488,7 +19916,7 @@ "FunctionTool": { "properties": { "function": { - "$ref": "#/components/schemas/FunctionDefinition-Output" + "$ref": "#/components/schemas/FunctionDefinition" }, "type": { "type": "string", @@ -22596,22 +21024,6 @@ "required": ["source"], "title": "ImageContent" }, - "ImageURL": { - "properties": { - "url": { - "type": "string", - "title": "Url" - }, - "detail": { - "type": "string", - "enum": ["auto", "low", "high"], - "title": "Detail" - } - }, - "type": "object", - "required": ["url"], - "title": "ImageURL" - }, "ImportedAgentsResponse": { "properties": { "agent_ids": { @@ -22660,22 +21072,6 @@ "title": "InitToolRule", "description": "Represents the initial tool rule configuration." }, - "InputAudio": { - "properties": { - "data": { - "type": "string", - "title": "Data" - }, - "format": { - "type": "string", - "enum": ["wav", "mp3"], - "title": "Format" - } - }, - "type": "object", - "required": ["data", "format"], - "title": "InputAudio" - }, "InternalTemplateAgentCreate": { "properties": { "name": { @@ -23512,37 +21908,6 @@ "title": "InternalTemplateGroupCreate", "description": "Used for Letta Cloud" }, - "JSONSchema": { - "properties": { - "name": { - "type": "string", - "title": "Name" - }, - "description": { - "type": "string", - "title": "Description" - }, - "schema": { - "additionalProperties": true, - "type": "object", - "title": "Schema" - }, - "strict": { - "anyOf": [ - { - "type": "boolean" - }, - { - "type": "null" - } - ], - "title": "Strict" - } - }, - "type": "object", - "required": ["name"], - "title": "JSONSchema" - }, "Job": { "properties": { "created_by_id": { @@ -25036,7 +23401,7 @@ "anyOf": [ { "items": { - "$ref": "#/components/schemas/ChatCompletionMessageFunctionToolCall-Output" + "$ref": "#/components/schemas/ChatCompletionMessageFunctionToolCall" }, "type": "array" }, @@ -26692,45 +25057,6 @@ "title": "RequiresApprovalToolRule", "description": "Represents a tool rule configuration which requires approval before the tool can be invoked." }, - "ResponseFormatJSONObject": { - "properties": { - "type": { - "type": "string", - "const": "json_object", - "title": "Type" - } - }, - "type": "object", - "required": ["type"], - "title": "ResponseFormatJSONObject" - }, - "ResponseFormatJSONSchema": { - "properties": { - "json_schema": { - "$ref": "#/components/schemas/JSONSchema" - }, - "type": { - "type": "string", - "const": "json_schema", - "title": "Type" - } - }, - "type": "object", - "required": ["json_schema", "type"], - "title": "ResponseFormatJSONSchema" - }, - "ResponseFormatText": { - "properties": { - "type": { - "type": "string", - "const": "text", - "title": "Type" - } - }, - "type": "object", - "required": ["type"], - "title": "ResponseFormatText" - }, "RetrieveStreamRequest": { "properties": { "starting_after": { @@ -30870,64 +29196,6 @@ "type": "object", "title": "VoiceSleeptimeManagerUpdate" }, - "WebSearchOptions": { - "properties": { - "search_context_size": { - "type": "string", - "enum": ["low", "medium", "high"], - "title": "Search Context Size" - }, - "user_location": { - "anyOf": [ - { - "$ref": "#/components/schemas/WebSearchOptionsUserLocation" - }, - { - "type": "null" - } - ] - } - }, - "type": "object", - "title": "WebSearchOptions" - }, - "WebSearchOptionsUserLocation": { - "properties": { - "approximate": { - "$ref": "#/components/schemas/WebSearchOptionsUserLocationApproximate" - }, - "type": { - "type": "string", - "const": "approximate", - "title": "Type" - } - }, - "type": "object", - "required": ["approximate", "type"], - "title": "WebSearchOptionsUserLocation" - }, - "WebSearchOptionsUserLocationApproximate": { - "properties": { - "city": { - "type": "string", - "title": "City" - }, - "country": { - "type": "string", - "title": "Country" - }, - "region": { - "type": "string", - "title": "Region" - }, - "timezone": { - "type": "string", - "title": "Timezone" - } - }, - "type": "object", - "title": "WebSearchOptionsUserLocationApproximate" - }, "letta__schemas__agent_file__AgentSchema": { "properties": { "name": { @@ -31682,7 +29950,7 @@ "anyOf": [ { "items": { - "$ref": "#/components/schemas/ChatCompletionMessageFunctionToolCall-Input" + "$ref": "#/components/schemas/ChatCompletionMessageFunctionToolCall" }, "type": "array" }, @@ -32253,120 +30521,6 @@ ], "title": "ToolSchema" }, - "openai__types__chat__chat_completion_custom_tool_param__Custom": { - "properties": { - "name": { - "type": "string", - "title": "Name" - }, - "description": { - "type": "string", - "title": "Description" - }, - "format": { - "anyOf": [ - { - "$ref": "#/components/schemas/CustomFormatText" - }, - { - "$ref": "#/components/schemas/CustomFormatGrammar" - } - ], - "title": "Format" - } - }, - "type": "object", - "required": ["name"], - "title": "Custom" - }, - "openai__types__chat__chat_completion_message_custom_tool_call_param__Custom": { - "properties": { - "input": { - "type": "string", - "title": "Input" - }, - "name": { - "type": "string", - "title": "Name" - } - }, - "type": "object", - "required": ["input", "name"], - "title": "Custom" - }, - "openai__types__chat__chat_completion_message_function_tool_call__Function": { - "properties": { - "arguments": { - "type": "string", - "title": "Arguments" - }, - "name": { - "type": "string", - "title": "Name" - } - }, - "additionalProperties": true, - "type": "object", - "required": ["arguments", "name"], - "title": "Function" - }, - "openai__types__chat__chat_completion_message_function_tool_call_param__Function": { - "properties": { - "arguments": { - "type": "string", - "title": "Arguments" - }, - "name": { - "type": "string", - "title": "Name" - } - }, - "type": "object", - "required": ["arguments", "name"], - "title": "Function" - }, - "openai__types__chat__chat_completion_named_tool_choice_custom_param__Custom": { - "properties": { - "name": { - "type": "string", - "title": "Name" - } - }, - "type": "object", - "required": ["name"], - "title": "Custom" - }, - "openai__types__chat__chat_completion_named_tool_choice_param__Function": { - "properties": { - "name": { - "type": "string", - "title": "Name" - } - }, - "type": "object", - "required": ["name"], - "title": "Function" - }, - "openai__types__chat__completion_create_params__Function": { - "properties": { - "name": { - "type": "string", - "title": "Name" - }, - "description": { - "type": "string", - "title": "Description" - }, - "parameters": { - "additionalProperties": true, - "type": "object", - "title": "Parameters" - } - }, - "type": "object", - "required": ["name"], - "title": "Function" - }, "LettaMessageUnion": { "oneOf": [ { diff --git a/letta/agent.py b/letta/agent.py deleted file mode 100644 index 0a039428..00000000 --- a/letta/agent.py +++ /dev/null @@ -1,1758 +0,0 @@ -import asyncio -import json -import time -import traceback -import warnings -from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, Union - -from openai.types.beta.function_tool import FunctionTool as OpenAITool - -from letta.agents.helpers import generate_step_id -from letta.constants import ( - CLI_WARNING_PREFIX, - COMPOSIO_ENTITY_ENV_VAR_KEY, - ERROR_MESSAGE_PREFIX, - FIRST_MESSAGE_ATTEMPTS, - FUNC_FAILED_HEARTBEAT_MESSAGE, - LETTA_CORE_TOOL_MODULE_NAME, - LETTA_MULTI_AGENT_TOOL_MODULE_NAME, - LLM_MAX_TOKENS, - READ_ONLY_BLOCK_EDIT_ERROR, - REQ_HEARTBEAT_MESSAGE, - SEND_MESSAGE_TOOL_NAME, -) -from letta.errors import ContextWindowExceededError -from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source -from letta.functions.composio_helpers import execute_composio_action, generate_composio_action_from_func_name -from letta.functions.functions import get_function_from_module -from letta.helpers import ToolRulesSolver -from letta.helpers.composio_helpers import get_composio_api_key -from letta.helpers.datetime_helpers import get_utc_time -from letta.helpers.json_helpers import json_dumps, json_loads -from letta.helpers.message_helper import convert_message_creates_to_messages -from letta.interface import AgentInterface -from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error -from letta.llm_api.llm_api_tools import create -from letta.llm_api.llm_client import LLMClient -from letta.local_llm.constants import INNER_THOUGHTS_KWARG -from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages -from letta.log import get_logger -from letta.memory import summarize_messages -from letta.orm import User -from letta.otel.tracing import log_event, trace_method -from letta.prompts.prompt_generator import PromptGenerator -from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent -from letta.schemas.block import BlockUpdate -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageRole, ProviderType, StepStatus, ToolType -from letta.schemas.letta_message_content import ImageContent, TextContent -from letta.schemas.memory import ContextWindowOverview, Memory -from letta.schemas.message import Message, MessageCreate, ToolReturn -from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Message as ChatCompletionMessage, UsageStatistics -from letta.schemas.response_format import ResponseFormatType -from letta.schemas.tool import Tool -from letta.schemas.tool_execution_result import ToolExecutionResult -from letta.schemas.tool_rule import TerminalToolRule -from letta.schemas.usage import LettaUsageStatistics -from letta.services.agent_manager import AgentManager -from letta.services.block_manager import BlockManager -from letta.services.helpers.agent_manager_helper import check_supports_structured_output -from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema -from letta.services.job_manager import JobManager -from letta.services.mcp.base_client import AsyncBaseMCPClient -from letta.services.message_manager import MessageManager -from letta.services.passage_manager import PassageManager -from letta.services.provider_manager import ProviderManager -from letta.services.step_manager import StepManager -from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager -from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox -from letta.services.tool_manager import ToolManager -from letta.settings import model_settings, settings, summarizer_settings -from letta.streaming_interface import StreamingRefreshCLIInterface -from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message -from letta.utils import count_tokens, get_friendly_error_msg, get_tool_call_id, log_telemetry, parse_json, validate_function_response - -logger = get_logger(__name__) - - -class BaseAgent(ABC): - """ - Abstract class for all agents. - Only one interface is required: step. - """ - - @abstractmethod - def step( - self, - input_messages: List[MessageCreate], - ) -> LettaUsageStatistics: - """ - Top-level event message handler for the agent. - """ - raise NotImplementedError - - -class Agent(BaseAgent): - def __init__( - self, - interface: Optional[Union[AgentInterface, StreamingRefreshCLIInterface]], - agent_state: AgentState, # in-memory representation of the agent state (read from multiple tables) - user: User, - # extras - first_message_verify_mono: bool = True, # TODO move to config? - # MCP sessions, state held in-memory in the server - mcp_clients: Optional[Dict[str, AsyncBaseMCPClient]] = None, - save_last_response: bool = False, - ): - assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}" - # Hold a copy of the state that was used to init the agent - self.agent_state = agent_state - assert isinstance(self.agent_state.memory, Memory), f"Memory object is not of type Memory: {type(self.agent_state.memory)}" - - self.user = user - - # initialize a tool rules solver - self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules) - - # gpt-4, gpt-3.5-turbo, ... - self.model = self.agent_state.llm_config.model - self.supports_structured_output = check_supports_structured_output(model=self.model, tool_rules=agent_state.tool_rules) - - # if there are tool rules, print out a warning - if not self.supports_structured_output and agent_state.tool_rules: - for rule in agent_state.tool_rules: - if not isinstance(rule, TerminalToolRule): - warnings.warn("Tool rules only work reliably for model backends that support structured outputs (e.g. OpenAI gpt-4o).") - break - - # state managers - self.block_manager = BlockManager() - - # Interface must implement: - # - internal_monologue - # - assistant_message - # - function_message - # ... - # Different interfaces can handle events differently - # e.g., print in CLI vs send a discord message with a discord bot - self.interface = interface - - # Create the persistence manager object based on the AgentState info - self.message_manager = MessageManager() - self.passage_manager = PassageManager() - self.provider_manager = ProviderManager() - self.agent_manager = AgentManager() - self.job_manager = JobManager() - self.step_manager = StepManager() - self.telemetry_manager = TelemetryManager() if settings.llm_api_logging else NoopTelemetryManager() - - # State needed for heartbeat pausing - - self.first_message_verify_mono = first_message_verify_mono - - # Controls if the convo memory pressure warning is triggered - # When an alert is sent in the message queue, set this to True (to avoid repeat alerts) - # When the summarizer is run, set this back to False (to reset) - self.agent_alerted_about_memory_pressure = False - - # Load last function response from message history - self.last_function_response = self.load_last_function_response() - - # Save last responses in memory - self.save_last_response = save_last_response - self.last_response_messages = [] - - # Logger that the Agent specifically can use, will also report the agent_state ID with the logs - self.logger = get_logger(agent_state.id) - - # MCPClient, state/sessions managed by the server - # TODO: This is temporary, as a bridge - self.mcp_clients = None - # TODO: no longer supported - # if mcp_clients: - # self.mcp_clients = {client_id: client.to_sync_client() for client_id, client in mcp_clients.items()} - - def load_last_function_response(self): - """Load the last function response from message history""" - in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) - for i in range(len(in_context_messages) - 1, -1, -1): - msg = in_context_messages[i] - if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent): - text_content = msg.content[0].text - try: - response_json = json.loads(text_content) - if response_json.get("message"): - return response_json["message"] - except (json.JSONDecodeError, KeyError): - raise ValueError(f"Invalid JSON format in message: {text_content}") - return None - - def ensure_read_only_block_not_modified(self, new_memory: Memory) -> None: - """ - Throw an error if a read-only block has been modified - """ - for label in self.agent_state.memory.list_block_labels(): - if self.agent_state.memory.get_block(label).read_only: - if new_memory.get_block(label).value != self.agent_state.memory.get_block(label).value: - raise ValueError(READ_ONLY_BLOCK_EDIT_ERROR) - - def update_memory_if_changed(self, new_memory: Memory) -> bool: - """ - Update internal memory object and system prompt if there have been modifications. - - Args: - new_memory (Memory): the new memory object to compare to the current memory object - - Returns: - modified (bool): whether the memory was updated - """ - system_message = self.message_manager.get_message_by_id(message_id=self.agent_state.message_ids[0], actor=self.user) - if new_memory.compile() not in system_message.content[0].text: - # update the blocks (LRW) in the DB - for label in self.agent_state.memory.list_block_labels(): - updated_value = new_memory.get_block(label).value - if updated_value != self.agent_state.memory.get_block(label).value: - # update the block if it's changed - block_id = self.agent_state.memory.get_block(label).id - self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=self.user) - - # refresh memory from DB (using block ids) - self.agent_state.memory = Memory( - blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()], - file_blocks=self.agent_state.memory.file_blocks, - agent_type=self.agent_state.agent_type, - ) - - # NOTE: don't do this since re-buildin the memory is handled at the start of the step - # rebuild memory - this records the last edited timestamp of the memory - # TODO: pass in update timestamp from block edit time - self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user) - - return True - - return False - - def _handle_function_error_response( - self, - error_msg: str, - tool_call_id: str, - function_name: str, - function_args: dict, - function_response: str, - messages: List[Message], - tool_returns: Optional[List[ToolReturn]] = None, - include_function_failed_message: bool = False, - group_id: Optional[str] = None, - ) -> List[Message]: - """ - Handle error from function call response - """ - # Update tool rules - self.last_function_response = function_response - self.tool_rules_solver.register_tool_call(function_name) - - # Extend conversation with function response - function_response = package_function_response(False, error_msg, self.agent_state.timezone) - new_message = Message( - agent_id=self.agent_state.id, - # Base info OpenAI-style - model=self.model, - role="tool", - name=function_name, # NOTE: when role is 'tool', the 'name' is the function name, not agent name - content=[TextContent(text=function_response)], - tool_call_id=tool_call_id, - # Letta extras - tool_returns=tool_returns, - group_id=group_id, - ) - messages.append(new_message) - self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message, chunk_index=0) - if include_function_failed_message: - self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message) - - # Return updated messages - return messages - - def _runtime_override_tool_json_schema( - self, - functions_list: List[Dict | None], - ) -> List[Dict | None]: - """Override the tool JSON schema at runtime for a particular tool if conditions are met.""" - - # Currently just injects `send_message` with a `response_format` if provided to the agent. - if self.agent_state.response_format and self.agent_state.response_format.type != ResponseFormatType.text: - for func in functions_list: - if func["name"] == SEND_MESSAGE_TOOL_NAME: - if self.agent_state.response_format.type == ResponseFormatType.json_schema: - func["parameters"]["properties"]["message"] = self.agent_state.response_format.json_schema["schema"] - if self.agent_state.response_format.type == ResponseFormatType.json_object: - func["parameters"]["properties"]["message"] = { - "type": "object", - "description": "Message contents. All unicode (including emojis) are supported.", - "additionalProperties": True, - "properties": {}, - } - break - return functions_list - - @trace_method - def _get_ai_reply( - self, - message_sequence: List[Message], - function_call: Optional[str] = None, - first_message: bool = False, - stream: bool = False, # TODO move to config? - empty_response_retry_limit: int = 3, - backoff_factor: float = 0.5, # delay multiplier for exponential backoff - max_delay: float = 10.0, # max delay between retries - step_count: Optional[int] = None, - last_function_failed: bool = False, - put_inner_thoughts_first: bool = True, - step_id: Optional[str] = None, - ) -> ChatCompletionResponse | None: - """Get response from LLM API with robust retry mechanism.""" - log_telemetry(self.logger, "_get_ai_reply start") - available_tools = set([t.name for t in self.agent_state.tools]) - agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools] - - # Get allowed tools or allow all if none are allowed - allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names( - available_tools=available_tools, last_function_response=self.last_function_response - ) or list(available_tools) - - # Don't allow a tool to be called if it failed last time - if last_function_failed and self.tool_rules_solver.tool_call_history: - allowed_tool_names = [f for f in allowed_tool_names if f != self.tool_rules_solver.tool_call_history[-1]] - if not allowed_tool_names: - return None - - allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names] - # Extract terminal tool names from tool rules - terminal_tool_names = {rule.tool_name for rule in self.tool_rules_solver.terminal_tool_rules} - allowed_functions = runtime_override_tool_json_schema( - tool_list=allowed_functions, - response_format=self.agent_state.response_format, - request_heartbeat=True, - terminal_tools=terminal_tool_names, - ) - - # For the first message, force the initial tool if one is specified - force_tool_call = None - if ( - step_count is not None - and step_count == 0 - and not self.supports_structured_output - and len(self.tool_rules_solver.init_tool_rules) > 0 - ): - # TODO: This just seems wrong? What if there are more than 1 init tool rules? - force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name - # Force a tool call if exactly one tool is specified - elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1: - force_tool_call = allowed_tool_names[0] - - for attempt in range(1, empty_response_retry_limit + 1): - try: - log_telemetry(self.logger, "_get_ai_reply create start") - # New LLM client flow - llm_client = LLMClient.create( - provider_type=self.agent_state.llm_config.model_endpoint_type, - put_inner_thoughts_first=put_inner_thoughts_first, - actor=self.user, - ) - - if llm_client and not stream: - response = llm_client.send_llm_request( - messages=message_sequence, - llm_config=self.agent_state.llm_config, - tools=allowed_functions, - force_tool_call=force_tool_call, - telemetry_manager=self.telemetry_manager, - step_id=step_id, - ) - else: - # Fallback to existing flow - for message in message_sequence: - if isinstance(message.content, list): - - def get_fallback_text_content(content): - if isinstance(content, ImageContent): - return TextContent(text="[Image Here]") - return content - - message.content = [get_fallback_text_content(content) for content in message.content] - - response = create( - llm_config=self.agent_state.llm_config, - messages=message_sequence, - user_id=self.agent_state.created_by_id, - functions=allowed_functions, - # functions_python=self.functions_python, do we need this? - function_call=function_call, - first_message=first_message, - force_tool_call=force_tool_call, - stream=stream, - stream_interface=self.interface, - put_inner_thoughts_first=put_inner_thoughts_first, - name=self.agent_state.name, - telemetry_manager=self.telemetry_manager, - step_id=step_id, - actor=self.user, - ) - log_telemetry(self.logger, "_get_ai_reply create finish") - - # These bottom two are retryable - if len(response.choices) == 0 or response.choices[0] is None: - raise ValueError(f"API call returned an empty message: {response}") - - if response.choices[0].finish_reason not in ["stop", "function_call", "tool_calls"]: - if response.choices[0].finish_reason == "length": - # This is not retryable, hence RuntimeError v.s. ValueError - raise RuntimeError("Finish reason was length (maximum context length)") - else: - raise ValueError(f"Bad finish reason from API: {response.choices[0].finish_reason}") - log_telemetry(self.logger, "_handle_ai_response finish") - - except ValueError as ve: - if attempt >= empty_response_retry_limit: - warnings.warn(f"Retry limit reached. Final error: {ve}") - log_telemetry(self.logger, "_handle_ai_response finish ValueError") - raise Exception(f"Retries exhausted and no valid response received. Final error: {ve}") - else: - delay = min(backoff_factor * (2 ** (attempt - 1)), max_delay) - warnings.warn(f"Attempt {attempt} failed: {ve}. Retrying in {delay} seconds...") - time.sleep(delay) - continue - - except Exception as e: - # For non-retryable errors, exit immediately - log_telemetry(self.logger, "_handle_ai_response finish generic Exception") - raise e - - # check if we are going over the context window: this allows for articifial constraints - if response.usage.total_tokens > self.agent_state.llm_config.context_window: - # trigger summarization - log_telemetry(self.logger, "_get_ai_reply summarize_messages_inplace") - self.summarize_messages_inplace() - - # return the response - return response - - log_telemetry(self.logger, "_handle_ai_response finish catch-all exception") - raise Exception("Retries exhausted and no valid response received.") - - @trace_method - def _handle_ai_response( - self, - response_message: ChatCompletionMessage, # TODO should we eventually move the Message creation outside of this function? - override_tool_call_id: bool = False, - # If we are streaming, we needed to create a Message ID ahead of time, - # and now we want to use it in the creation of the Message object - # TODO figure out a cleaner way to do this - response_message_id: Optional[str] = None, - group_id: Optional[str] = None, - ) -> Tuple[List[Message], bool, bool]: - """Handles parsing and function execution""" - log_telemetry(self.logger, "_handle_ai_response start") - # Hacky failsafe for now to make sure we didn't implement the streaming Message ID creation incorrectly - if response_message_id is not None: - assert response_message_id.startswith("message-"), response_message_id - - messages = [] # append these to the history when done - function_name = None - function_args = {} - chunk_index = 0 - - # Step 2: check if LLM wanted to call a function - if response_message.function_call or (response_message.tool_calls is not None and len(response_message.tool_calls) > 0): - if response_message.function_call: - raise DeprecationWarning(response_message) - if response_message.tool_calls is not None and len(response_message.tool_calls) > 1: - # raise NotImplementedError(f">1 tool call not supported") - # TODO eventually support sequential tool calling - self.logger.warning(f">1 tool call not supported, using index=0 only\n{response_message.tool_calls}") - response_message.tool_calls = [response_message.tool_calls[0]] - assert response_message.tool_calls is not None and len(response_message.tool_calls) > 0 - - # generate UUID for tool call - if override_tool_call_id or response_message.function_call: - warnings.warn("Overriding the tool call can result in inconsistent tool call IDs during streaming") - tool_call_id = get_tool_call_id() # needs to be a string for JSON - response_message.tool_calls[0].id = tool_call_id - else: - tool_call_id = response_message.tool_calls[0].id - assert tool_call_id is not None # should be defined - - # only necessary to add the tool_call_id to a function call (antipattern) - # response_message_dict = response_message.model_dump() - # response_message_dict["tool_call_id"] = tool_call_id - - # role: assistant (requesting tool call, set tool call ID) - messages.append( - # NOTE: we're recreating the message here - # TODO should probably just overwrite the fields? - Message.dict_to_message( - id=response_message_id, - agent_id=self.agent_state.id, - model=self.model, - openai_message_dict=response_message.model_dump(), - name=self.agent_state.name, - group_id=group_id, - ) - ) # extend conversation with assistant's reply - self.logger.debug(f"Function call message: {messages[-1]}") - - nonnull_content = False - if response_message.content or response_message.reasoning_content or response_message.redacted_reasoning_content: - # The content if then internal monologue, not chat - self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index) - chunk_index += 1 - # Flag to avoid printing a duplicate if inner thoughts get popped from the function call - nonnull_content = True - - # Step 3: call the function - # Note: the JSON response may not always be valid; be sure to handle errors - function_call = ( - response_message.function_call if response_message.function_call is not None else response_message.tool_calls[0].function - ) - function_name = function_call.name - self.logger.info(f"Request to call function {function_name} with tool_call_id: {tool_call_id}") - - # Failure case 1: function name is wrong (not in agent_state.tools) - target_letta_tool = None - for t in self.agent_state.tools: - if t.name == function_name: - # This force refreshes the target_letta_tool from the database - # We only do this on name match to confirm that the agent state contains a specific tool with the right name - target_letta_tool = ToolManager().get_tool_by_name(tool_name=function_name, actor=self.user) - break - - if not target_letta_tool: - error_msg = f"No function named {function_name}" - function_response = "None" # more like "never ran?" - messages = self._handle_function_error_response( - error_msg, tool_call_id, function_name, function_args, function_response, messages, group_id=group_id - ) - return messages, False, True # force a heartbeat to allow agent to handle error - - # Failure case 2: function name is OK, but function args are bad JSON - try: - raw_function_args = function_call.arguments - function_args = parse_json(raw_function_args) - if not isinstance(function_args, dict): - raise ValueError(f"Function arguments are not a dictionary: {function_args} (raw={raw_function_args})") - except Exception as e: - print(e) - error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}" - function_response = "None" # more like "never ran?" - messages = self._handle_function_error_response( - error_msg, tool_call_id, function_name, function_args, function_response, messages, group_id=group_id - ) - return messages, False, True # force a heartbeat to allow agent to handle error - - # Check if inner thoughts is in the function call arguments (possible apparently if you are using Azure) - if INNER_THOUGHTS_KWARG in function_args: - response_message.content = function_args.pop(INNER_THOUGHTS_KWARG) - # The content if then internal monologue, not chat - if response_message.content and not nonnull_content: - self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index) - chunk_index += 1 - - # (Still parsing function args) - # Handle requests for immediate heartbeat - heartbeat_request = function_args.pop("request_heartbeat", None) - - # Edge case: heartbeat_request is returned as a stringified boolean, we will attempt to parse: - if isinstance(heartbeat_request, str) and heartbeat_request.lower().strip() == "true": - heartbeat_request = True - - if heartbeat_request is None: - heartbeat_request = False - - if not isinstance(heartbeat_request, bool): - self.logger.warning( - f"{CLI_WARNING_PREFIX}'request_heartbeat' arg parsed was not a bool or None, type={type(heartbeat_request)}, value={heartbeat_request}" - ) - heartbeat_request = False - - # Failure case 3: function failed during execution - # NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message - # this is because the function/tool role message is only created once the function/tool has executed/returned - - # handle cases where we return a json message - if "message" in function_args: - function_args["message"] = str(function_args.get("message", "")) - self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=chunk_index) - chunk_index = 0 # reset chunk index after assistant message - try: - # handle tool execution (sandbox) and state updates - log_telemetry( - self.logger, "_handle_ai_response execute tool start", function_name=function_name, function_args=function_args - ) - log_event( - "tool_call_initiated", - attributes={ - "function_name": function_name, - "target_letta_tool": target_letta_tool.model_dump(), - **{f"function_args.{k}": v for k, v in function_args.items()}, - }, - ) - - tool_execution_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool) - function_response = tool_execution_result.func_return - - log_event( - "tool_call_ended", - attributes={ - "function_response": function_response, - "tool_execution_result": tool_execution_result.model_dump(), - }, - ) - log_telemetry( - self.logger, "_handle_ai_response execute tool finish", function_name=function_name, function_args=function_args - ) - - if tool_execution_result and tool_execution_result.status == "error": - tool_return = ToolReturn( - status=tool_execution_result.status, stdout=tool_execution_result.stdout, stderr=tool_execution_result.stderr - ) - messages = self._handle_function_error_response( - function_response, - tool_call_id, - function_name, - function_args, - function_response, - messages, - [tool_return], - group_id=group_id, - ) - return messages, False, True # force a heartbeat to allow agent to handle error - - # handle trunction - if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]: - # with certain functions we rely on the paging mechanism to handle overflow - truncate = False - else: - # but by default, we add a truncation safeguard to prevent bad functions from - # overflow the agent context window - truncate = True - - # get the function response limit - return_char_limit = target_letta_tool.return_char_limit - function_response_string = validate_function_response( - function_response, return_char_limit=return_char_limit, truncate=truncate - ) - function_args.pop("self", None) - function_response = package_function_response(True, function_response_string, self.agent_state.timezone) - function_failed = False - except Exception as e: - function_args.pop("self", None) - # error_msg = f"Error calling function {function_name} with args {function_args}: {str(e)}" - # Less detailed - don't provide full args, idea is that it should be in recent context so no need (just adds noise) - error_msg = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)) - error_msg_user = f"{error_msg}\n{traceback.format_exc()}" - self.logger.error(error_msg_user) - messages = self._handle_function_error_response( - error_msg, - tool_call_id, - function_name, - function_args, - function_response, - messages, - [ToolReturn(status="error", stderr=[error_msg_user])], - include_function_failed_message=True, - group_id=group_id, - ) - return messages, False, True # force a heartbeat to allow agent to handle error - - # Step 4: check if function response is an error - if function_response_string.startswith(ERROR_MESSAGE_PREFIX): - error_msg = function_response_string - tool_return = ToolReturn( - status=tool_execution_result.status, - stdout=tool_execution_result.stdout, - stderr=tool_execution_result.stderr, - ) - messages = self._handle_function_error_response( - error_msg, - tool_call_id, - function_name, - function_args, - function_response, - messages, - [tool_return], - include_function_failed_message=True, - group_id=group_id, - ) - return messages, False, True # force a heartbeat to allow agent to handle error - - # If no failures happened along the way: ... - # Step 5: send the info on the function call and function response to GPT - tool_return = ToolReturn( - status=tool_execution_result.status, - stdout=tool_execution_result.stdout, - stderr=tool_execution_result.stderr, - ) - messages.append( - Message( - agent_id=self.agent_state.id, - # Base info OpenAI-style - model=self.model, - role="tool", - name=function_name, # NOTE: when role is 'tool', the 'name' is the function name, not agent name - content=[TextContent(text=function_response)], - tool_call_id=tool_call_id, - # Letta extras - tool_returns=[tool_return], - group_id=group_id, - ) - ) # extend conversation with function response - self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1], chunk_index=chunk_index) - self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1], chunk_index=chunk_index) - chunk_index += 1 - self.last_function_response = function_response - - else: - # Standard non-function reply - messages.append( - Message.dict_to_message( - id=response_message_id, - agent_id=self.agent_state.id, - model=self.model, - openai_message_dict=response_message.model_dump(), - name=self.agent_state.name, - group_id=group_id, - ) - ) # extend conversation with assistant's reply - self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=chunk_index) - chunk_index += 1 - heartbeat_request = False - function_failed = False - - # rebuild memory - # TODO: @charles please check this - self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user) - - # Update ToolRulesSolver state with last called function - self.tool_rules_solver.register_tool_call(function_name) - # Update heartbeat request according to provided tool rules - if self.tool_rules_solver.has_children_tools(function_name): - heartbeat_request = True - elif self.tool_rules_solver.is_terminal_tool(function_name): - heartbeat_request = False - - # if continue tool rule, then must request a heartbeat - # TODO: dont even include heartbeats in the args - if self.tool_rules_solver.is_continue_tool(function_name): - heartbeat_request = True - - log_telemetry(self.logger, "_handle_ai_response finish") - return messages, heartbeat_request, function_failed - - @trace_method - def step( - self, - input_messages: List[MessageCreate], - # additional args - chaining: bool = True, - max_chaining_steps: Optional[int] = None, - put_inner_thoughts_first: bool = True, - **kwargs, - ) -> LettaUsageStatistics: - """Run Agent.step in a loop, handling chaining via heartbeat requests and function failures""" - # Defensively clear the tool rules solver history - # Usually this would be extraneous as Agent loop is re-loaded on every message send - # But just to be safe - self.tool_rules_solver.clear_tool_history() - - # Convert MessageCreate objects to Message objects - next_input_messages = convert_message_creates_to_messages(input_messages, self.agent_state.id, self.agent_state.timezone) - counter = 0 - total_usage = UsageStatistics() - step_count = 0 - function_failed = False - steps_messages = [] - while True: - kwargs["first_message"] = False - kwargs["step_count"] = step_count - kwargs["last_function_failed"] = function_failed - step_response = self.inner_step( - messages=next_input_messages, - put_inner_thoughts_first=put_inner_thoughts_first, - **kwargs, - ) - - heartbeat_request = step_response.heartbeat_request - function_failed = step_response.function_failed - token_warning = step_response.in_context_memory_warning - usage = step_response.usage - steps_messages.append(step_response.messages) - - step_count += 1 - total_usage += usage - counter += 1 - self.interface.step_complete() - - # logger.debug("Saving agent state") - # save updated state - save_agent(self) - - # Chain stops - if not chaining: - self.logger.info("No chaining, stopping after one step") - break - elif max_chaining_steps is not None and counter > max_chaining_steps: - self.logger.info(f"Hit max chaining steps, stopping after {counter} steps") - break - # Chain handlers - elif token_warning and summarizer_settings.send_memory_warning_message: - assert self.agent_state.created_by_id is not None - next_input_messages = [ - Message.dict_to_message( - agent_id=self.agent_state.id, - model=self.model, - openai_message_dict={ - "role": "user", # TODO: change to system? - "content": get_token_limit_warning(), - }, - ), - ] - continue # always chain - elif function_failed: - assert self.agent_state.created_by_id is not None - next_input_messages = [ - Message.dict_to_message( - agent_id=self.agent_state.id, - model=self.model, - openai_message_dict={ - "role": "user", # TODO: change to system? - "content": get_heartbeat(self.agent_state.timezone, FUNC_FAILED_HEARTBEAT_MESSAGE), - }, - ) - ] - continue # always chain - elif heartbeat_request: - assert self.agent_state.created_by_id is not None - next_input_messages = [ - Message.dict_to_message( - agent_id=self.agent_state.id, - model=self.model, - openai_message_dict={ - "role": "user", # TODO: change to system? - "content": get_heartbeat(self.agent_state.timezone, REQ_HEARTBEAT_MESSAGE), - }, - ) - ] - continue # always chain - # Letta no-op / yield - else: - break - - if self.agent_state.message_buffer_autoclear: - self.logger.info("Autoclearing message buffer") - self.agent_state = self.agent_manager.trim_all_in_context_messages_except_system(self.agent_state.id, actor=self.user) - - return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count, steps_messages=steps_messages) - - def inner_step( - self, - messages: List[Message], - first_message: bool = False, - first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS, - skip_verify: bool = False, - stream: bool = False, # TODO move to config? - step_count: Optional[int] = None, - metadata: Optional[dict] = None, - summarize_attempt_count: int = 0, - last_function_failed: bool = False, - put_inner_thoughts_first: bool = True, - ) -> AgentStepResponse: - """Runs a single step in the agent loop (generates at most one LLM call)""" - try: - # Extract job_id from metadata if present - job_id = metadata.get("job_id") if metadata else None - - # Declare step_id for the given step to be used as the step is processing. - step_id = generate_step_id() - - # Step 0: update core memory - # only pulling latest block data if shared memory is being used - current_persisted_memory = Memory( - blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()], - file_blocks=self.agent_state.memory.file_blocks, - agent_type=self.agent_state.agent_type, - ) # read blocks from DB - self.update_memory_if_changed(current_persisted_memory) - - # Step 1: add user message - if not all(isinstance(m, Message) for m in messages): - raise ValueError(f"messages should be a list of Message, got {[type(m) for m in messages]}") - - in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) - input_message_sequence = in_context_messages + messages - - if ( - len(input_message_sequence) > 1 - and input_message_sequence[-1].role != "user" - and input_message_sequence[-1].group_id is None - ): - self.logger.warning(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue") - - # Step 2: send the conversation and available functions to the LLM - response = self._get_ai_reply( - message_sequence=input_message_sequence, - first_message=first_message, - stream=stream, - step_count=step_count, - last_function_failed=last_function_failed, - put_inner_thoughts_first=put_inner_thoughts_first, - step_id=step_id, - ) - if not response: - # EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early - return AgentStepResponse( - messages=input_message_sequence, - heartbeat_request=False, - function_failed=False, # NOTE: this is different from other function fails. We force to return early - in_context_memory_warning=False, - usage=UsageStatistics(), - ) - - # Step 3: check if LLM wanted to call a function - # (if yes) Step 4: call the function - # (if yes) Step 5: send the info on the function call and function response to LLM - response_message = response.choices[0].message - - response_message.model_copy() # TODO why are we copying here? - all_response_messages, heartbeat_request, function_failed = self._handle_ai_response( - response_message, - # TODO this is kind of hacky, find a better way to handle this - # the only time we set up message creation ahead of time is when streaming is on - response_message_id=response.id if stream else None, - group_id=input_message_sequence[-1].group_id, - ) - - # Step 6: extend the message history - if len(messages) > 0: - all_new_messages = messages + all_response_messages - else: - all_new_messages = all_response_messages - - if self.save_last_response: - self.last_response_messages = all_response_messages - - # Check the memory pressure and potentially issue a memory pressure warning - current_total_tokens = response.usage.total_tokens - active_memory_warning = False - - # We can't do summarize logic properly if context_window is undefined - if self.agent_state.llm_config.context_window is None: - # Fallback if for some reason context_window is missing, just set to the default - print(f"{CLI_WARNING_PREFIX}could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}") - print(f"{self.agent_state}") - self.agent_state.llm_config.context_window = ( - LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"] - ) - - if current_total_tokens > summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window): - logger.warning( - f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window)}" - ) - - log_event( - name="memory_pressure_warning", - attributes={ - "current_total_tokens": current_total_tokens, - "context_window_limit": self.agent_state.llm_config.context_window, - }, - ) - # Only deliver the alert if we haven't already (this period) - if not self.agent_alerted_about_memory_pressure: - active_memory_warning = True - self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this - - else: - logger.info( - f"last response total_tokens ({current_total_tokens}) < {summarizer_settings.memory_warning_threshold * int(self.agent_state.llm_config.context_window)}" - ) - - # Log step - this must happen before messages are persisted - step = self.step_manager.log_step( - actor=self.user, - agent_id=self.agent_state.id, - provider_name=self.agent_state.llm_config.model_endpoint_type, - provider_category=self.agent_state.llm_config.provider_category or "base", - model=self.agent_state.llm_config.model, - model_endpoint=self.agent_state.llm_config.model_endpoint, - context_window_limit=self.agent_state.llm_config.context_window, - usage=response.usage, - provider_id=self.provider_manager.get_provider_id_from_name( - self.agent_state.llm_config.provider_name, - actor=self.user, - ), - job_id=job_id, - step_id=step_id, - project_id=self.agent_state.project_id, - status=StepStatus.SUCCESS, # Set to SUCCESS since we're logging after successful completion - ) - for message in all_new_messages: - message.step_id = step.id - - # Persisting into Messages - self.agent_state = self.agent_manager.append_to_in_context_messages( - all_new_messages, agent_id=self.agent_state.id, actor=self.user - ) - if job_id: - for message in all_new_messages: - if message.role != "user": - self.job_manager.add_message_to_job( - job_id=job_id, - message_id=message.id, - actor=self.user, - ) - - return AgentStepResponse( - messages=all_new_messages, - heartbeat_request=heartbeat_request, - function_failed=function_failed, - in_context_memory_warning=active_memory_warning, - usage=response.usage, - ) - - except Exception as e: - logger.error(f"step() failed\nmessages = {messages}\nerror = {e}") - - # If we got a context alert, try trimming the messages length, then try again - if is_context_overflow_error(e): - in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) - - # TODO: this is a patch to resolve immediate issues, should be removed once the summarizer is fixes - if self.agent_state.message_buffer_autoclear: - # no calling the summarizer in this case - logger.error( - f"step() failed with an exception that looks like a context window overflow, but message buffer is set to autoclear, so skipping: '{str(e)}'" - ) - raise e - - if summarize_attempt_count <= summarizer_settings.max_summarizer_retries: - logger.warning( - f"context window exceeded with limit {self.agent_state.llm_config.context_window}, attempting to summarize ({summarize_attempt_count}/{summarizer_settings.max_summarizer_retries}" - ) - # A separate API call to run a summarizer - self.summarize_messages_inplace() - - # Try step again - return self.inner_step( - messages=messages, - first_message=first_message, - first_message_retry_limit=first_message_retry_limit, - skip_verify=skip_verify, - stream=stream, - metadata=metadata, - summarize_attempt_count=summarize_attempt_count + 1, - ) - else: - err_msg = f"Ran summarizer {summarize_attempt_count - 1} times for agent id={self.agent_state.id}, but messages are still overflowing the context window." - token_counts = (get_token_counts_for_messages(in_context_messages),) - logger.error(err_msg) - logger.error(f"num_in_context_messages: {len(self.agent_state.message_ids)}") - logger.error(f"token_counts: {token_counts}") - raise ContextWindowExceededError( - err_msg, - details={ - "num_in_context_messages": len(self.agent_state.message_ids), - "in_context_messages_text": [m.content for m in in_context_messages], - "token_counts": token_counts, - }, - ) - - else: - logger.error(f"step() failed with an unrecognized exception: '{str(e)}'") - traceback.print_exc() - raise e - - def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepResponse: - """Takes a basic user message string, turns it into a stringified JSON with extra metadata, then sends it to the agent - - Example: - -> user_message_str = 'hi' - -> {'message': 'hi', 'type': 'user_message', ...} - -> json.dumps(...) - -> agent.step(messages=[Message(role='user', text=...)]) - """ - # Wrap with metadata, dumps to JSON - assert user_message_str and isinstance(user_message_str, str), ( - f"user_message_str should be a non-empty string, got {type(user_message_str)}" - ) - user_message_json_str = package_user_message(user_message_str, self.agent_state.timezone) - - # Validate JSON via save/load - user_message = validate_json(user_message_json_str) - cleaned_user_message_text, name = strip_name_field_from_user_message(user_message) - - # Turn into a dict - openai_message_dict = {"role": "user", "content": cleaned_user_message_text, "name": name} - - # Create the associated Message object (in the database) - assert self.agent_state.created_by_id is not None, "User ID is not set" - user_message = Message.dict_to_message( - agent_id=self.agent_state.id, - model=self.model, - openai_message_dict=openai_message_dict, - # created_at=timestamp, - ) - - return self.inner_step(messages=[user_message], **kwargs) - - def summarize_messages_inplace(self): - in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) - in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages) - in_context_messages_openai_no_system = in_context_messages_openai[1:] - token_counts = get_token_counts_for_messages(in_context_messages) - logger.info(f"System message token count={token_counts[0]}") - logger.info(f"token_counts_no_system={token_counts[1:]}") - - if in_context_messages_openai[0]["role"] != "system": - raise RuntimeError(f"in_context_messages_openai[0] should be system (instead got {in_context_messages_openai[0]})") - - # If at this point there's nothing to summarize, throw an error - if len(in_context_messages_openai_no_system) == 0: - raise ContextWindowExceededError( - "Not enough messages to compress for summarization", - details={ - "num_candidate_messages": len(in_context_messages_openai_no_system), - "num_total_messages": len(in_context_messages_openai), - }, - ) - - cutoff = calculate_summarizer_cutoff(in_context_messages=in_context_messages, token_counts=token_counts, logger=logger) - message_sequence_to_summarize = in_context_messages[1:cutoff] # do NOT get rid of the system message - logger.info(f"Attempting to summarize {len(message_sequence_to_summarize)} messages of {len(in_context_messages)}") - - # We can't do summarize logic properly if context_window is undefined - if self.agent_state.llm_config.context_window is None: - # Fallback if for some reason context_window is missing, just set to the default - logger.warning(f"{CLI_WARNING_PREFIX}could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}") - self.agent_state.llm_config.context_window = ( - LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"] - ) - - summary = summarize_messages( - agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize, actor=self.user - ) - logger.info(f"Got summary: {summary}") - - # Metadata that's useful for the agent to see - all_time_message_count = self.message_manager.size(agent_id=self.agent_state.id, actor=self.user) - remaining_message_count = 1 + len(in_context_messages) - cutoff # System + remaining - hidden_message_count = all_time_message_count - remaining_message_count - summary_message_count = len(message_sequence_to_summarize) - summary_message = package_summarize_message( - summary, summary_message_count, hidden_message_count, all_time_message_count, self.agent_state.timezone - ) - logger.info(f"Packaged into message: {summary_message}") - - prior_len = len(in_context_messages_openai) - self.agent_state = self.agent_manager.trim_older_in_context_messages(num=cutoff, agent_id=self.agent_state.id, actor=self.user) - packed_summary_message = {"role": "user", "content": summary_message} - # Prepend the summary - self.agent_state = self.agent_manager.prepend_to_in_context_messages( - messages=[ - Message.dict_to_message( - agent_id=self.agent_state.id, - model=self.model, - openai_message_dict=packed_summary_message, - ) - ], - agent_id=self.agent_state.id, - actor=self.user, - ) - - # reset alert - self.agent_alerted_about_memory_pressure = False - curr_in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) - - current_token_count = sum(get_token_counts_for_messages(curr_in_context_messages)) - logger.info(f"Ran summarizer, messages length {prior_len} -> {len(curr_in_context_messages)}") - logger.info(f"Summarizer brought down total token count from {sum(token_counts)} -> {current_token_count}") - log_event( - name="summarization", - attributes={ - "prior_length": prior_len, - "current_length": len(curr_in_context_messages), - "prior_token_count": sum(token_counts), - "current_token_count": current_token_count, - "context_window_limit": self.agent_state.llm_config.context_window, - }, - ) - - def add_function(self, function_name: str) -> str: - # TODO: refactor - raise NotImplementedError - - def remove_function(self, function_name: str) -> str: - # TODO: refactor - raise NotImplementedError - - def migrate_embedding(self, embedding_config: EmbeddingConfig): - """Migrate the agent to a new embedding""" - # TODO: archival memory - - # TODO: recall memory - raise NotImplementedError() - - def get_context_window(self) -> ContextWindowOverview: - """Get the context window of the agent""" - - system_prompt = self.agent_state.system # TODO is this the current system or the initial system? - num_tokens_system = count_tokens(system_prompt) - core_memory = self.agent_state.memory.compile() - num_tokens_core_memory = count_tokens(core_memory) - - # Grab the in-context messages - # conversion of messages to OpenAI dict format, which is passed to the token counter - in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) - in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages) - - # Check if there's a summary message in the message queue - if ( - len(in_context_messages) > 1 - and in_context_messages[1].role == MessageRole.user - and in_context_messages[1].content - and len(in_context_messages[1].content) == 1 - and isinstance(in_context_messages[1].content[0], TextContent) - # TODO remove hardcoding - and "The following is a summary of the previous " in in_context_messages[1].content[0].text - ): - # Summary message exists - text_content = in_context_messages[1].content[0].text - assert text_content is not None - summary_memory = text_content - num_tokens_summary_memory = count_tokens(text_content) - # with a summary message, the real messages start at index 2 - num_tokens_messages = ( - num_tokens_from_messages(messages=in_context_messages_openai[2:], model=self.model) - if len(in_context_messages_openai) > 2 - else 0 - ) - - else: - summary_memory = None - num_tokens_summary_memory = 0 - # with no summary message, the real messages start at index 1 - num_tokens_messages = ( - num_tokens_from_messages(messages=in_context_messages_openai[1:], model=self.model) - if len(in_context_messages_openai) > 1 - else 0 - ) - - agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id) - message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id) - external_memory_summary = PromptGenerator.compile_memory_metadata_block( - memory_edit_timestamp=get_utc_time(), - timezone=self.agent_state.timezone, - previous_message_count=self.message_manager.size(actor=self.user, agent_id=self.agent_state.id), - archival_memory_size=self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id), - ) - num_tokens_external_memory_summary = count_tokens(external_memory_summary) - - # tokens taken up by function definitions - agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools] - if agent_state_tool_jsons: - available_functions_definitions = [OpenAITool(type="function", function=f) for f in agent_state_tool_jsons] - num_tokens_available_functions_definitions = num_tokens_from_functions(functions=agent_state_tool_jsons, model=self.model) - else: - available_functions_definitions = [] - num_tokens_available_functions_definitions = 0 - - num_tokens_used_total = ( - num_tokens_system # system prompt - + num_tokens_available_functions_definitions # function definitions - + num_tokens_core_memory # core memory - + num_tokens_external_memory_summary # metadata (statistics) about recall/archival - + num_tokens_summary_memory # summary of ongoing conversation - + num_tokens_messages # tokens taken by messages - ) - assert isinstance(num_tokens_used_total, int) - - return ContextWindowOverview( - # context window breakdown (in messages) - num_messages=len(in_context_messages), - num_archival_memory=agent_manager_passage_size, - num_recall_memory=message_manager_size, - num_tokens_external_memory_summary=num_tokens_external_memory_summary, - external_memory_summary=external_memory_summary, - # top-level information - context_window_size_max=self.agent_state.llm_config.context_window, - context_window_size_current=num_tokens_used_total, - # context window breakdown (in tokens) - num_tokens_system=num_tokens_system, - system_prompt=system_prompt, - num_tokens_core_memory=num_tokens_core_memory, - core_memory=core_memory, - num_tokens_summary_memory=num_tokens_summary_memory, - summary_memory=summary_memory, - num_tokens_messages=num_tokens_messages, - messages=in_context_messages, - # related to functions - num_tokens_functions_definitions=num_tokens_available_functions_definitions, - functions_definitions=available_functions_definitions, - ) - - async def get_context_window_async(self) -> ContextWindowOverview: - if settings.environment == "PRODUCTION" and model_settings.anthropic_api_key: - return await self.get_context_window_from_anthropic_async() - return await self.get_context_window_from_tiktoken_async() - - async def get_context_window_from_tiktoken_async(self) -> ContextWindowOverview: - """Get the context window of the agent""" - # Grab the in-context messages - in_context_messages = await self.message_manager.get_messages_by_ids_async( - message_ids=self.agent_state.message_ids, actor=self.user - ) - - # conversion of messages to OpenAI dict format, which is passed to the token counter - in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages) - - # Extract system, memory and external summary - if ( - len(in_context_messages) > 0 - and in_context_messages[0].role == MessageRole.system - and in_context_messages[0].content - and len(in_context_messages[0].content) == 1 - and isinstance(in_context_messages[0].content[0], TextContent) - ): - system_message = in_context_messages[0].content[0].text - - external_memory_marker_pos = system_message.find("###") - core_memory_marker_pos = system_message.find("<", external_memory_marker_pos) - if external_memory_marker_pos != -1 and core_memory_marker_pos != -1: - system_prompt = system_message[:external_memory_marker_pos].strip() - external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip() - core_memory = system_message[core_memory_marker_pos:].strip() - else: - # if no markers found, put everything in system message - self.logger.info("No markers found in system message, core_memory and external_memory_summary will not be loaded") - system_prompt = system_message - external_memory_summary = "" - core_memory = "" - else: - # if no system message, fall back on agent's system prompt - self.logger.info("No system message found in history, core_memory and external_memory_summary will not be loaded") - system_prompt = self.agent_state.system - external_memory_summary = "" - core_memory = "" - - num_tokens_system = count_tokens(system_prompt) - num_tokens_core_memory = count_tokens(core_memory) - num_tokens_external_memory_summary = count_tokens(external_memory_summary) - - # Check if there's a summary message in the message queue - if ( - len(in_context_messages) > 1 - and in_context_messages[1].role == MessageRole.user - and in_context_messages[1].content - and len(in_context_messages[1].content) == 1 - and isinstance(in_context_messages[1].content[0], TextContent) - # TODO remove hardcoding - and "The following is a summary of the previous " in in_context_messages[1].content[0].text - ): - # Summary message exists - text_content = in_context_messages[1].content[0].text - assert text_content is not None - summary_memory = text_content - num_tokens_summary_memory = count_tokens(text_content) - # with a summary message, the real messages start at index 2 - num_tokens_messages = ( - num_tokens_from_messages(messages=in_context_messages_openai[2:], model=self.model) - if len(in_context_messages_openai) > 2 - else 0 - ) - - else: - summary_memory = None - num_tokens_summary_memory = 0 - # with no summary message, the real messages start at index 1 - num_tokens_messages = ( - num_tokens_from_messages(messages=in_context_messages_openai[1:], model=self.model) - if len(in_context_messages_openai) > 1 - else 0 - ) - - # tokens taken up by function definitions - agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools] - if agent_state_tool_jsons: - available_functions_definitions = [OpenAITool(type="function", function=f) for f in agent_state_tool_jsons] - num_tokens_available_functions_definitions = num_tokens_from_functions(functions=agent_state_tool_jsons, model=self.model) - else: - available_functions_definitions = [] - num_tokens_available_functions_definitions = 0 - - num_tokens_used_total = ( - num_tokens_system # system prompt - + num_tokens_available_functions_definitions # function definitions - + num_tokens_core_memory # core memory - + num_tokens_external_memory_summary # metadata (statistics) about recall/archival - + num_tokens_summary_memory # summary of ongoing conversation - + num_tokens_messages # tokens taken by messages - ) - assert isinstance(num_tokens_used_total, int) - - passage_manager_size = await self.passage_manager.agent_passage_size_async( - agent_id=self.agent_state.id, - actor=self.user, - ) - message_manager_size = await self.message_manager.size_async( - agent_id=self.agent_state.id, - actor=self.user, - ) - - return ContextWindowOverview( - # context window breakdown (in messages) - num_messages=len(in_context_messages), - num_archival_memory=passage_manager_size, - num_recall_memory=message_manager_size, - num_tokens_external_memory_summary=num_tokens_external_memory_summary, - external_memory_summary=external_memory_summary, - # top-level information - context_window_size_max=self.agent_state.llm_config.context_window, - context_window_size_current=num_tokens_used_total, - # context window breakdown (in tokens) - num_tokens_system=num_tokens_system, - system_prompt=system_prompt, - num_tokens_core_memory=num_tokens_core_memory, - core_memory=core_memory, - num_tokens_summary_memory=num_tokens_summary_memory, - summary_memory=summary_memory, - num_tokens_messages=num_tokens_messages, - messages=in_context_messages, - # related to functions - num_tokens_functions_definitions=num_tokens_available_functions_definitions, - functions_definitions=available_functions_definitions, - ) - - async def get_context_window_from_anthropic_async(self) -> ContextWindowOverview: - """Get the context window of the agent""" - anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=self.user) - model = self.agent_state.llm_config.model if self.agent_state.llm_config.model_endpoint_type == "anthropic" else None - - # Grab the in-context messages - in_context_messages = await self.message_manager.get_messages_by_ids_async( - message_ids=self.agent_state.message_ids, actor=self.user - ) - - # conversion of messages to anthropic dict format, which is passed to the token counter - in_context_messages_anthropic = Message.to_anthropic_dicts_from_list(in_context_messages) - - # Extract system, memory and external summary - if ( - len(in_context_messages) > 0 - and in_context_messages[0].role == MessageRole.system - and in_context_messages[0].content - and len(in_context_messages[0].content) == 1 - and isinstance(in_context_messages[0].content[0], TextContent) - ): - system_message = in_context_messages[0].content[0].text - - external_memory_marker_pos = system_message.find("###") - core_memory_marker_pos = system_message.find("<", external_memory_marker_pos) - if external_memory_marker_pos != -1 and core_memory_marker_pos != -1: - system_prompt = system_message[:external_memory_marker_pos].strip() - external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip() - core_memory = system_message[core_memory_marker_pos:].strip() - else: - # if no markers found, put everything in system message - self.logger.info("No markers found in system message, core_memory and external_memory_summary will not be loaded") - system_prompt = system_message - external_memory_summary = "" - core_memory = "" - else: - # if no system message, fall back on agent's system prompt - self.logger.info("No system message found in history, core_memory and external_memory_summary will not be loaded") - system_prompt = self.agent_state.system - external_memory_summary = "" - core_memory = "" - - num_tokens_system_coroutine = anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": system_prompt}]) - num_tokens_core_memory_coroutine = ( - anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": core_memory}]) - if core_memory - else asyncio.sleep(0, result=0) - ) - num_tokens_external_memory_summary_coroutine = ( - anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": external_memory_summary}]) - if external_memory_summary - else asyncio.sleep(0, result=0) - ) - - # Check if there's a summary message in the message queue - if ( - len(in_context_messages) > 1 - and in_context_messages[1].role == MessageRole.user - and in_context_messages[1].content - and len(in_context_messages[1].content) == 1 - and isinstance(in_context_messages[1].content[0], TextContent) - # TODO remove hardcoding - and "The following is a summary of the previous " in in_context_messages[1].content[0].text - ): - # Summary message exists - text_content = in_context_messages[1].content[0].text - assert text_content is not None - summary_memory = text_content - num_tokens_summary_memory_coroutine = anthropic_client.count_tokens( - model=model, messages=[{"role": "user", "content": summary_memory}] - ) - # with a summary message, the real messages start at index 2 - num_tokens_messages_coroutine = ( - anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[2:]) - if len(in_context_messages_anthropic) > 2 - else asyncio.sleep(0, result=0) - ) - - else: - summary_memory = None - num_tokens_summary_memory_coroutine = asyncio.sleep(0, result=0) - # with no summary message, the real messages start at index 1 - num_tokens_messages_coroutine = ( - anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[1:]) - if len(in_context_messages_anthropic) > 1 - else asyncio.sleep(0, result=0) - ) - - # tokens taken up by function definitions - if self.agent_state.tools and len(self.agent_state.tools) > 0: - available_functions_definitions = [OpenAITool(type="function", function=f.json_schema) for f in self.agent_state.tools] - num_tokens_available_functions_definitions_coroutine = anthropic_client.count_tokens( - model=model, - tools=available_functions_definitions, - ) - else: - available_functions_definitions = [] - num_tokens_available_functions_definitions_coroutine = asyncio.sleep(0, result=0) - - ( - num_tokens_system, - num_tokens_core_memory, - num_tokens_external_memory_summary, - num_tokens_summary_memory, - num_tokens_messages, - num_tokens_available_functions_definitions, - ) = await asyncio.gather( - num_tokens_system_coroutine, - num_tokens_core_memory_coroutine, - num_tokens_external_memory_summary_coroutine, - num_tokens_summary_memory_coroutine, - num_tokens_messages_coroutine, - num_tokens_available_functions_definitions_coroutine, - ) - - num_tokens_used_total = ( - num_tokens_system # system prompt - + num_tokens_available_functions_definitions # function definitions - + num_tokens_core_memory # core memory - + num_tokens_external_memory_summary # metadata (statistics) about recall/archival - + num_tokens_summary_memory # summary of ongoing conversation - + num_tokens_messages # tokens taken by messages - ) - assert isinstance(num_tokens_used_total, int) - - passage_manager_size = await self.passage_manager.agent_passage_size_async( - agent_id=self.agent_state.id, - actor=self.user, - ) - message_manager_size = await self.message_manager.size_async( - agent_id=self.agent_state.id, - actor=self.user, - ) - - return ContextWindowOverview( - # context window breakdown (in messages) - num_messages=len(in_context_messages), - num_archival_memory=passage_manager_size, - num_recall_memory=message_manager_size, - num_tokens_external_memory_summary=num_tokens_external_memory_summary, - external_memory_summary=external_memory_summary, - # top-level information - context_window_size_max=self.agent_state.llm_config.context_window, - context_window_size_current=num_tokens_used_total, - # context window breakdown (in tokens) - num_tokens_system=num_tokens_system, - system_prompt=system_prompt, - num_tokens_core_memory=num_tokens_core_memory, - core_memory=core_memory, - num_tokens_summary_memory=num_tokens_summary_memory, - summary_memory=summary_memory, - num_tokens_messages=num_tokens_messages, - messages=in_context_messages, - # related to functions - num_tokens_functions_definitions=num_tokens_available_functions_definitions, - functions_definitions=available_functions_definitions, - ) - - def count_tokens(self) -> int: - """Count the tokens in the current context window""" - context_window_breakdown = self.get_context_window() - return context_window_breakdown.context_window_size_current - - # TODO: Refactor into separate class v.s. large if/elses here - def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool) -> ToolExecutionResult: - """ - Execute tool modifications and persist the state of the agent. - Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data - """ - # TODO: add agent manager here - orig_memory_str = self.agent_state.memory.compile() - - # TODO: need to have an AgentState object that actually has full access to the block data - # this is because the sandbox tools need to be able to access block.value to edit this data - try: - if target_letta_tool.tool_type == ToolType.LETTA_CORE: - # base tools are allowed to access the `Agent` object and run on the database - callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) - function_args["self"] = self # need to attach self to arg since it's dynamically linked - function_response = callable_func(**function_args) - elif target_letta_tool.tool_type == ToolType.LETTA_MULTI_AGENT_CORE: - callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name) - function_args["self"] = self # need to attach self to arg since it's dynamically linked - function_response = callable_func(**function_args) - elif target_letta_tool.tool_type == ToolType.LETTA_MEMORY_CORE or target_letta_tool.tool_type == ToolType.LETTA_SLEEPTIME_CORE: - callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name) - agent_state_copy = self.agent_state.__deepcopy__() - function_args["agent_state"] = agent_state_copy # need to attach self to arg since it's dynamically linked - function_response = callable_func(**function_args) - self.ensure_read_only_block_not_modified( - new_memory=agent_state_copy.memory - ) # memory editing tools cannot edit read-only blocks - self.update_memory_if_changed(agent_state_copy.memory) - elif target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO: - action_name = generate_composio_action_from_func_name(target_letta_tool.name) - # Get entity ID from the agent_state - entity_id = None - for env_var in self.agent_state.secrets: - if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY: - entity_id = env_var.value - # Get composio_api_key - composio_api_key = get_composio_api_key(actor=self.user, logger=self.logger) - function_response = execute_composio_action( - action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id - ) - elif target_letta_tool.tool_type == ToolType.EXTERNAL_MCP: - # Get the server name from the tool tag - # TODO make a property instead? - server_name = target_letta_tool.tags[0].split(":")[1] - - # Get the MCPClient from the server's handle - # TODO these don't get raised properly - if not self.mcp_clients: - raise ValueError("No MCP client available to use") - if server_name not in self.mcp_clients: - raise ValueError(f"Unknown MCP server name: {server_name}") - mcp_client = self.mcp_clients[server_name] - - # Check that tool exists - available_tools = mcp_client.list_tools() - available_tool_names = [t.name for t in available_tools] - if function_name not in available_tool_names: - raise ValueError( - f"{function_name} is not available in MCP server {server_name}. Please check your `~/.letta/mcp_config.json` file." - ) - - function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args) - return ToolExecutionResult( - status="error" if is_error else "success", - func_return=function_response, - ) - else: - try: - # Parse the source code to extract function annotations - annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name) - # Coerce the function arguments to the correct types based on the annotations - function_args = coerce_dict_args_by_annotations(function_args, annotations) - except ValueError as e: - self.logger.debug(f"Error coercing function arguments: {e}") - - # execute tool in a sandbox - # TODO: allow agent_state to specify which sandbox to execute tools in - # TODO: This is only temporary, can remove after we publish a pip package with this object - agent_state_copy = self.agent_state.__deepcopy__() - agent_state_copy.tools = [] - agent_state_copy.tool_rules = [] - - tool_execution_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run( - agent_state=agent_state_copy - ) - assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" - if tool_execution_result.agent_state is not None: - self.update_memory_if_changed(tool_execution_result.agent_state.memory) - return tool_execution_result - except Exception as e: - # Need to catch error here, or else trunction wont happen - # TODO: modify to function execution error - function_response = get_friendly_error_msg( - function_name=function_name, exception_name=type(e).__name__, exception_message=str(e) - ) - return ToolExecutionResult( - status="error", - func_return=function_response, - stderr=[traceback.format_exc()], - ) - - return ToolExecutionResult( - status="success", - func_return=function_response, - ) - - -def save_agent(agent: Agent): - """Save agent to metadata store""" - agent_state = agent.agent_state - assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}" - - # TODO: move this to agent manager - # TODO: Completely strip out metadata - # convert to persisted model - agent_manager = AgentManager() - update_agent = UpdateAgent( - name=agent_state.name, - tool_ids=[t.id for t in agent_state.tools], - source_ids=[s.id for s in agent_state.sources], - block_ids=[b.id for b in agent_state.memory.blocks], - tags=agent_state.tags, - system=agent_state.system, - tool_rules=agent_state.tool_rules, - llm_config=agent_state.llm_config, - embedding_config=agent_state.embedding_config, - message_ids=agent_state.message_ids, - description=agent_state.description, - metadata=agent_state.metadata, - # TODO: Add this back in later - # tool_exec_environment_variables=agent_state.get_agent_env_vars_as_dict(), - ) - agent_manager.update_agent(agent_id=agent_state.id, agent_update=update_agent, actor=agent.user) - - -def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]: - """If 'name' exists in the JSON string, remove it and return the cleaned text + name value""" - try: - user_message_json = dict(json_loads(user_message_text)) - # Special handling for AutoGen messages with 'name' field - # Treat 'name' as a special field - # If it exists in the input message, elevate it to the 'message' level - name = user_message_json.pop("name", None) - clean_message = json_dumps(user_message_json) - return clean_message, name - - except Exception as e: - print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}") - raise e - - -def validate_json(user_message_text: str) -> str: - """Make sure that the user input message is valid JSON""" - try: - user_message_json = dict(json_loads(user_message_text)) - user_message_json_val = json_dumps(user_message_json) - return user_message_json_val - except Exception as e: - print(f"{CLI_WARNING_PREFIX}couldn't parse user input message as JSON: {e}") - raise e diff --git a/letta/cli/cli_load.py b/letta/cli/cli_load.py deleted file mode 100644 index a50c525e..00000000 --- a/letta/cli/cli_load.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -This file contains functions for loading data into Letta's archival storage. - -Data can be loaded with the following command, once a load function is defined: -``` -letta load --name [ADDITIONAL ARGS] -``` - -""" - -import typer - -app = typer.Typer() - - -default_extensions = "txt,md,pdf" diff --git a/letta/client/__init__.py b/letta/client/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/letta/client/streaming.py b/letta/client/streaming.py deleted file mode 100644 index 9154051a..00000000 --- a/letta/client/streaming.py +++ /dev/null @@ -1,95 +0,0 @@ -import json -from typing import Generator, Union, get_args - -import httpx -from httpx_sse import SSEError, connect_sse -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk - -from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING -from letta.errors import LLMError -from letta.log import get_logger -from letta.schemas.enums import MessageStreamStatus -from letta.schemas.letta_message import AssistantMessage, HiddenReasoningMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage -from letta.schemas.letta_response import LettaStreamingResponse -from letta.schemas.usage import LettaUsageStatistics - -logger = get_logger(__name__) - - -def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStreamingResponse, ChatCompletionChunk], None, None]: - """ - Sends an SSE POST request and yields parsed response chunks. - """ - # TODO: Please note his is a very generous timeout for e2b reasons - with httpx.Client(timeout=httpx.Timeout(5 * 60.0, read=5 * 60.0)) as client: - with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source: - # Check for immediate HTTP errors before processing the SSE stream - if not event_source.response.is_success: - response_bytes = event_source.response.read() - logger.warning(f"SSE request error: {vars(event_source.response)}") - logger.warning(response_bytes.decode("utf-8")) - - try: - response_dict = json.loads(response_bytes.decode("utf-8")) - error_message = response_dict.get("error", {}).get("message", "") - - if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message: - logger.error(error_message) - raise LLMError(error_message) - except LLMError: - raise - except Exception: - logger.error("Failed to parse SSE message, raising HTTP error") - event_source.response.raise_for_status() - - try: - for sse in event_source.iter_sse(): - if sse.data in {status.value for status in MessageStreamStatus}: - yield MessageStreamStatus(sse.data) - if sse.data == MessageStreamStatus.done.value: - # We received the [DONE], so stop reading the stream. - break - else: - chunk_data = json.loads(sse.data) - - if "reasoning" in chunk_data: - yield ReasoningMessage(**chunk_data) - elif chunk_data.get("message_type") == "assistant_message": - yield AssistantMessage(**chunk_data) - elif "hidden_reasoning" in chunk_data: - yield HiddenReasoningMessage(**chunk_data) - elif "tool_call" in chunk_data: - yield ToolCallMessage(**chunk_data) - elif "tool_return" in chunk_data: - yield ToolReturnMessage(**chunk_data) - elif "step_count" in chunk_data: - yield LettaUsageStatistics(**chunk_data) - elif chunk_data.get("object") == get_args(ChatCompletionChunk.__annotations__["object"])[0]: - yield ChatCompletionChunk(**chunk_data) - else: - raise ValueError(f"Unknown message type in chunk_data: {chunk_data}") - - except SSEError as e: - logger.error(f"SSE stream error: {e}") - - if "application/json" in str(e): - response = client.post(url=url, json=data, headers=headers) - - if response.headers.get("Content-Type", "").startswith("application/json"): - error_details = response.json() - logger.error(f"POST Error: {error_details}") - else: - logger.error("Failed to retrieve JSON error message via retry.") - - raise e - - except Exception as e: - logger.error(f"Unexpected exception: {e}") - - if event_source.response.request: - logger.error(f"HTTP Request: {vars(event_source.response.request)}") - if event_source.response: - logger.error(f"HTTP Status: {event_source.response.status_code}") - logger.error(f"HTTP Headers: {event_source.response.headers}") - - raise e diff --git a/letta/client/utils.py b/letta/client/utils.py deleted file mode 100644 index f823ee87..00000000 --- a/letta/client/utils.py +++ /dev/null @@ -1,78 +0,0 @@ -import re -from datetime import datetime -from typing import Optional - -from IPython.display import HTML, display -from sqlalchemy.testing.plugin.plugin_base import warnings - -from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL - - -def pprint(messages): - """Utility function for pretty-printing the output of client.send_message in notebooks""" - - css_styles = """ - - """ - - html_content = css_styles + "
" - for message in messages: - date_str = message["date"] - date_formatted = datetime.fromisoformat(date_str.replace("Z", "+00:00")).strftime("%Y-%m-%d %H:%M:%S") - - if "function_return" in message: - return_string = message["function_return"] - return_status = message["status"] - html_content += f"

🛠️ [{date_formatted}] Function Return ({return_status}):

" - html_content += f"

{return_string}

" - elif "internal_monologue" in message: - html_content += f"

{INNER_THOUGHTS_CLI_SYMBOL} [{date_formatted}] Internal Monologue:

" - html_content += f"

{message['internal_monologue']}

" - elif "function_call" in message: - html_content += f"

🛠️ [[{date_formatted}] Function Call:

" - html_content += f"

{message['function_call']}

" - elif "assistant_message" in message: - html_content += f"

{ASSISTANT_MESSAGE_CLI_SYMBOL} [{date_formatted}] Assistant Message:

" - html_content += f"

{message['assistant_message']}

" - html_content += "
" - html_content += "
" - - display(HTML(html_content)) - - -def derive_function_name_regex(function_string: str) -> Optional[str]: - # Regular expression to match the function name - match = re.search(r"def\s+([a-zA-Z_]\w*)\s*\(", function_string) - - if match: - function_name = match.group(1) - return function_name - else: - warnings.warn("No function name found.") - return None diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index 623663fb..345d42dd 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -1,6 +1,5 @@ from typing import List, Literal, Optional -from letta.agent import Agent from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING diff --git a/letta/functions/function_sets/multi_agent.py b/letta/functions/function_sets/multi_agent.py index fced0832..2bfdff03 100644 --- a/letta/functions/function_sets/multi_agent.py +++ b/letta/functions/function_sets/multi_agent.py @@ -14,9 +14,6 @@ from letta.schemas.message import MessageCreate from letta.server.rest_api.dependencies import get_letta_server from letta.settings import settings -if TYPE_CHECKING: - from letta.agent import Agent - def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_agent_id: str) -> str: """ diff --git a/letta/groups/dynamic_multi_agent.py b/letta/groups/dynamic_multi_agent.py index 500d923d..8ec119c6 100644 --- a/letta/groups/dynamic_multi_agent.py +++ b/letta/groups/dynamic_multi_agent.py @@ -1,8 +1,9 @@ from typing import List, Optional -from letta.agent import Agent, AgentState +from letta.agents.base_agent import BaseAgent from letta.interface import AgentInterface from letta.orm import User +from letta.schemas.agent import AgentState from letta.schemas.block import Block from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message, MessageCreate @@ -11,7 +12,7 @@ from letta.schemas.usage import LettaUsageStatistics from letta.services.tool_manager import ToolManager -class DynamicMultiAgent(Agent): +class DynamicMultiAgent(BaseAgent): def __init__( self, interface: AgentInterface, diff --git a/letta/groups/helpers.py b/letta/groups/helpers.py index 69507c0f..5192fb36 100644 --- a/letta/groups/helpers.py +++ b/letta/groups/helpers.py @@ -1,7 +1,6 @@ import json from typing import Dict, Optional, Union -from letta.agent import Agent from letta.interface import AgentInterface from letta.orm.group import Group from letta.orm.user import User @@ -18,7 +17,7 @@ def load_multi_agent( actor: User, interface: Union[AgentInterface, None] = None, mcp_clients: Optional[Dict[str, AsyncBaseMCPClient]] = None, -) -> Agent: +) -> "Agent": if len(group.agent_ids) == 0: raise ValueError("Empty group: group must have at least one agent") diff --git a/letta/groups/round_robin_multi_agent.py b/letta/groups/round_robin_multi_agent.py index 9c7b319d..06a5fdcd 100644 --- a/letta/groups/round_robin_multi_agent.py +++ b/letta/groups/round_robin_multi_agent.py @@ -1,15 +1,16 @@ from typing import List, Optional -from letta.agent import Agent, AgentState +from letta.agents.base_agent import BaseAgent from letta.interface import AgentInterface from letta.orm import User +from letta.schemas.agent import AgentState from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.usage import LettaUsageStatistics -class RoundRobinMultiAgent(Agent): +class RoundRobinMultiAgent(BaseAgent): def __init__( self, interface: AgentInterface, diff --git a/letta/groups/sleeptime_multi_agent.py b/letta/groups/sleeptime_multi_agent.py index b207219c..afc49fac 100644 --- a/letta/groups/sleeptime_multi_agent.py +++ b/letta/groups/sleeptime_multi_agent.py @@ -3,10 +3,11 @@ import threading from datetime import datetime, timezone from typing import List, Optional -from letta.agent import Agent, AgentState +from letta.agents.base_agent import BaseAgent from letta.groups.helpers import stringify_message from letta.interface import AgentInterface from letta.orm import User +from letta.schemas.agent import AgentState from letta.schemas.enums import JobStatus from letta.schemas.job import JobUpdate from letta.schemas.letta_message_content import TextContent @@ -19,7 +20,7 @@ from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager -class SleeptimeMultiAgent(Agent): +class SleeptimeMultiAgent(BaseAgent): def __init__( self, interface: AgentInterface, diff --git a/letta/groups/supervisor_multi_agent.py b/letta/groups/supervisor_multi_agent.py index 35b5bf98..1a87aa67 100644 --- a/letta/groups/supervisor_multi_agent.py +++ b/letta/groups/supervisor_multi_agent.py @@ -1,12 +1,13 @@ from typing import List, Optional -from letta.agent import Agent, AgentState +from letta.agents.base_agent import BaseAgent from letta.constants import DEFAULT_MESSAGE_TOOL from letta.functions.function_sets.multi_agent import send_message_to_all_agents_in_group from letta.functions.functions import parse_source_code from letta.functions.schema_generator import generate_schema from letta.interface import AgentInterface from letta.orm import User +from letta.schemas.agent import AgentState from letta.schemas.enums import ToolType from letta.schemas.letta_message_content import TextContent from letta.schemas.message import MessageCreate @@ -17,7 +18,7 @@ from letta.services.agent_manager import AgentManager from letta.services.tool_manager import ToolManager -class SupervisorMultiAgent(Agent): +class SupervisorMultiAgent(BaseAgent): def __init__( self, interface: AgentInterface, @@ -35,82 +36,85 @@ class SupervisorMultiAgent(Agent): self.agent_manager = AgentManager() self.tool_manager = ToolManager() - def step( - self, - input_messages: List[MessageCreate], - chaining: bool = True, - max_chaining_steps: Optional[int] = None, - put_inner_thoughts_first: bool = True, - assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, - **kwargs, - ) -> LettaUsageStatistics: - # Load settings - token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False - metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None - # Prepare supervisor agent - if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None: - multi_agent_tool = Tool( - name=send_message_to_all_agents_in_group.__name__, - description="", - source_type="python", - tags=[], - source_code=parse_source_code(send_message_to_all_agents_in_group), - json_schema=generate_schema(send_message_to_all_agents_in_group, None), - ) - multi_agent_tool.tool_type = ToolType.LETTA_MULTI_AGENT_CORE - multi_agent_tool = self.tool_manager.create_or_update_tool( - pydantic_tool=multi_agent_tool, - actor=self.user, - ) - self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user) - - old_tool_rules = self.agent_state.tool_rules - self.agent_state.tool_rules = [ - InitToolRule( - tool_name="send_message_to_all_agents_in_group", - ), - TerminalToolRule( - tool_name=assistant_message_tool_name, - ), - ChildToolRule( - tool_name="send_message_to_all_agents_in_group", - children=[assistant_message_tool_name], - ), - ] - - # Prepare new messages - new_messages = [] - for message in input_messages: - if isinstance(message.content, str): - message.content = [TextContent(text=message.content)] - message.group_id = self.group_id - new_messages.append(message) - - try: - # Load supervisor agent - supervisor_agent = Agent( - agent_state=self.agent_state, - interface=self.interface, - user=self.user, - ) - - # Perform supervisor step - usage_stats = supervisor_agent.step( - input_messages=new_messages, - chaining=chaining, - max_chaining_steps=max_chaining_steps, - stream=token_streaming, - skip_verify=True, - metadata=metadata, - put_inner_thoughts_first=put_inner_thoughts_first, - ) - except Exception as e: - raise e - finally: - self.interface.step_yield() - self.agent_state.tool_rules = old_tool_rules - - self.interface.step_complete() - - return usage_stats +# +# def step( +# self, +# input_messages: List[MessageCreate], +# chaining: bool = True, +# max_chaining_steps: Optional[int] = None, +# put_inner_thoughts_first: bool = True, +# assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, +# **kwargs, +# ) -> LettaUsageStatistics: +# # Load settings +# token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False +# metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None +# +# # Prepare supervisor agent +# if self.tool_manager.get_tool_by_name(tool_name="send_message_to_all_agents_in_group", actor=self.user) is None: +# multi_agent_tool = Tool( +# name=send_message_to_all_agents_in_group.__name__, +# description="", +# source_type="python", +# tags=[], +# source_code=parse_source_code(send_message_to_all_agents_in_group), +# json_schema=generate_schema(send_message_to_all_agents_in_group, None), +# ) +# multi_agent_tool.tool_type = ToolType.LETTA_MULTI_AGENT_CORE +# multi_agent_tool = self.tool_manager.create_or_update_tool( +# pydantic_tool=multi_agent_tool, +# actor=self.user, +# ) +# self.agent_state = self.agent_manager.attach_tool(agent_id=self.agent_state.id, tool_id=multi_agent_tool.id, actor=self.user) +# +# old_tool_rules = self.agent_state.tool_rules +# self.agent_state.tool_rules = [ +# InitToolRule( +# tool_name="send_message_to_all_agents_in_group", +# ), +# TerminalToolRule( +# tool_name=assistant_message_tool_name, +# ), +# ChildToolRule( +# tool_name="send_message_to_all_agents_in_group", +# children=[assistant_message_tool_name], +# ), +# ] +# +# # Prepare new messages +# new_messages = [] +# for message in input_messages: +# if isinstance(message.content, str): +# message.content = [TextContent(text=message.content)] +# message.group_id = self.group_id +# new_messages.append(message) +# +# try: +# # Load supervisor agent +# supervisor_agent = Agent( +# agent_state=self.agent_state, +# interface=self.interface, +# user=self.user, +# ) +# +# # Perform supervisor step +# usage_stats = supervisor_agent.step( +# input_messages=new_messages, +# chaining=chaining, +# max_chaining_steps=max_chaining_steps, +# stream=token_streaming, +# skip_verify=True, +# metadata=metadata, +# put_inner_thoughts_first=put_inner_thoughts_first, +# ) +# except Exception as e: +# raise e +# finally: +# self.interface.step_yield() +# self.agent_state.tool_rules = old_tool_rules +# +# self.interface.step_complete() +# +# return usage_stats +# diff --git a/letta/main.py b/letta/main.py index a64b3637..9fd6c794 100644 --- a/letta/main.py +++ b/letta/main.py @@ -3,12 +3,19 @@ import os import typer from letta.cli.cli import server -from letta.cli.cli_load import app as load_app # disable composio print on exit os.environ["COMPOSIO_DISABLE_VERSION_CHECK"] = "true" app = typer.Typer(pretty_exceptions_enable=False) + +# Register server as both the default command and as a subcommand app.command(name="server")(server) -app.add_typer(load_app, name="load") + +# Also make server the default when no command is specified +@app.callback(invoke_without_command=True) +def main(ctx: typer.Context): + if ctx.invoked_subcommand is None: + # If no subcommand is specified, run the server + server() diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index c047839f..009df1e1 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -66,101 +66,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): id: Mapped[str] = mapped_column(String, primary_key=True) - @classmethod - @handle_db_timeout - def list( - cls, - *, - db_session: "Session", - before: Optional[str] = None, - after: Optional[str] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - limit: Optional[int] = 50, - query_text: Optional[str] = None, - query_embedding: Optional[List[float]] = None, - ascending: bool = True, - actor: Optional["User"] = None, - access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], - access_type: AccessType = AccessType.ORGANIZATION, - join_model: Optional[Base] = None, - join_conditions: Optional[Union[Tuple, List]] = None, - identifier_keys: Optional[List[str]] = None, - identity_id: Optional[str] = None, - **kwargs, - ) -> List["SqlalchemyBase"]: - """ - List records with before/after pagination, ordering by created_at. - Can use both before and after to fetch a window of records. - - Args: - db_session: SQLAlchemy session - before: ID of item to paginate before (upper bound) - after: ID of item to paginate after (lower bound) - start_date: Filter items after this date - end_date: Filter items before this date - limit: Maximum number of items to return - query_text: Text to search for - query_embedding: Vector to search for similar embeddings - ascending: Sort direction - **kwargs: Additional filters to apply - """ - if start_date and end_date and start_date > end_date: - raise ValueError("start_date must be earlier than or equal to end_date") - - logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}") - - with db_session as session: - # Get the reference objects for pagination - before_obj = None - after_obj = None - - if before: - before_obj = session.get(cls, before) - if not before_obj: - raise NoResultFound(f"No {cls.__name__} found with id {before}") - - if after: - after_obj = session.get(cls, after) - if not after_obj: - raise NoResultFound(f"No {cls.__name__} found with id {after}") - - # Validate that before comes after the after object if both are provided - if before_obj and after_obj and before_obj.created_at < after_obj.created_at: - raise ValueError("'before' reference must be later than 'after' reference") - - query = cls._list_preprocess( - before_obj=before_obj, - after_obj=after_obj, - start_date=start_date, - end_date=end_date, - limit=limit, - query_text=query_text, - query_embedding=query_embedding, - ascending=ascending, - actor=actor, - access=access, - access_type=access_type, - join_model=join_model, - join_conditions=join_conditions, - identifier_keys=identifier_keys, - identity_id=identity_id, - **kwargs, - ) - - # Execute the query - results = session.execute(query) - - results = list(results.scalars()) - results = cls._list_postprocess( - before=before, - after=after, - limit=limit, - results=results, - ) - - return results - @classmethod @handle_db_timeout async def list_async( @@ -446,45 +351,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): results = results[start:end] return results - @classmethod - @handle_db_timeout - def read( - cls, - db_session: "Session", - identifier: Optional[str] = None, - actor: Optional["User"] = None, - access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], - access_type: AccessType = AccessType.ORGANIZATION, - check_is_deleted: bool = False, - **kwargs, - ) -> "SqlalchemyBase": - """The primary accessor for an ORM record. - Args: - db_session: the database session to use when retrieving the record - identifier: the identifier of the record to read, can be the id string or the UUID object for backwards compatibility - actor: if specified, results will be scoped only to records the user is able to access - access: if actor is specified, records will be filtered to the minimum permission level for the actor - kwargs: additional arguments to pass to the read, used for more complex objects - Returns: - The matching object - Raises: - NoResultFound: if the object is not found - """ - # this is ok because read_multiple will check if the - identifiers = [] if identifier is None else [identifier] - found = cls.read_multiple(db_session, identifiers, actor, access, access_type, check_is_deleted, **kwargs) - if len(found) == 0: - # for backwards compatibility. - conditions = [] - if identifier: - conditions.append(f"id={identifier}") - if actor: - conditions.append(f"access level in {access} for {actor}") - if check_is_deleted and hasattr(cls, "is_deleted"): - conditions.append("is_deleted=False") - raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}") - return found[0] - @classmethod @handle_db_timeout async def read_async( @@ -521,36 +387,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): raise NoResultFound(f"{cls.__name__} not found with {', '.join(query_conditions if query_conditions else ['no conditions'])}") return item - @classmethod - @handle_db_timeout - def read_multiple( - cls, - db_session: "Session", - identifiers: List[str] = [], - actor: Optional["User"] = None, - access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], - access_type: AccessType = AccessType.ORGANIZATION, - check_is_deleted: bool = False, - **kwargs, - ) -> List["SqlalchemyBase"]: - """The primary accessor for ORM record(s) - Args: - db_session: the database session to use when retrieving the record - identifiers: a list of identifiers of the records to read, can be the id string or the UUID object for backwards compatibility - actor: if specified, results will be scoped only to records the user is able to access - access: if actor is specified, records will be filtered to the minimum permission level for the actor - kwargs: additional arguments to pass to the read, used for more complex objects - Returns: - The matching object - Raises: - NoResultFound: if the object is not found - """ - query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, check_is_deleted, **kwargs) - if query is None: - return [] - results = db_session.execute(query).scalars().all() - return cls._read_multiple_postprocess(results, identifiers, query_conditions) - @classmethod @handle_db_timeout async def read_multiple_async( @@ -637,23 +473,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): logger.debug(f"{cls.__name__} not found with {conditions_str}") return [] - @handle_db_timeout - def create(self, db_session: "Session", actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase": - logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}") - - if actor: - self._set_created_and_updated_by_fields(actor.id) - try: - db_session.add(self) - if no_commit: - db_session.flush() # no commit, just flush to get PK - else: - db_session.commit() - db_session.refresh(self) - return self - except (DBAPIError, IntegrityError) as e: - self._handle_dbapi_error(e) - @handle_db_timeout async def create_async( self, @@ -680,47 +499,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): except (DBAPIError, IntegrityError) as e: self._handle_dbapi_error(e) - @classmethod - @handle_db_timeout - def batch_create(cls, items: List["SqlalchemyBase"], db_session: "Session", actor: Optional["User"] = None) -> List["SqlalchemyBase"]: - """ - Create multiple records in a single transaction for better performance. - Args: - items: List of model instances to create - db_session: SQLAlchemy session - actor: Optional user performing the action - Returns: - List of created model instances - """ - logger.debug(f"Batch creating {len(items)} {cls.__name__} items with actor={actor}") - if not items: - return [] - - # Set created/updated by fields if actor is provided - if actor: - for item in items: - item._set_created_and_updated_by_fields(actor.id) - - try: - with db_session as session: - session.add_all(items) - session.flush() # Flush to generate IDs but don't commit yet - - # Collect IDs to fetch the complete objects after commit - item_ids = [item.id for item in items] - - session.commit() - - # Re-query the objects to get them with relationships loaded - query = select(cls).where(cls.id.in_(item_ids)) - if hasattr(cls, "created_at"): - query = query.order_by(cls.created_at) - - return list(session.execute(query).scalars()) - - except (DBAPIError, IntegrityError) as e: - cls._handle_dbapi_error(e) - @classmethod @handle_db_timeout async def batch_create_async( @@ -774,16 +552,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): except (DBAPIError, IntegrityError) as e: cls._handle_dbapi_error(e) - @handle_db_timeout - def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase": - logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}") - - if actor: - self._set_created_and_updated_by_fields(actor.id) - - self.is_deleted = True - return self.update(db_session) - @handle_db_timeout async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase": """Soft delete a record asynchronously (mark as deleted).""" @@ -795,22 +563,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): self.is_deleted = True return await self.update_async(db_session) - @handle_db_timeout - def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None: - """Permanently removes the record from the database.""" - logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}") - - with db_session as session: - try: - session.delete(self) - session.commit() - except Exception as e: - session.rollback() - logger.exception(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}") - raise ValueError(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}: {e}") - else: - logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted") - @handle_db_timeout async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None: """Permanently removes the record from the database asynchronously.""" @@ -853,22 +605,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): logger.exception(f"Failed to hard delete {cls.__name__} with identifiers {identifiers}") raise ValueError(f"Failed to hard delete {cls.__name__} with identifiers {identifiers}: {e}") - @handle_db_timeout - def update(self, db_session: Session, actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase": - logger.debug(...) - if actor: - self._set_created_and_updated_by_fields(actor.id) - self.set_updated_at() - - # remove the context manager: - db_session.add(self) - if no_commit: - db_session.flush() # no commit, just flush to get PK - else: - db_session.commit() - db_session.refresh(self) - return self - @handle_db_timeout async def update_async( self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False, no_refresh: bool = False @@ -925,48 +661,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): return query - @classmethod - @handle_db_timeout - def size( - cls, - *, - db_session: "Session", - actor: Optional["User"] = None, - access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], - access_type: AccessType = AccessType.ORGANIZATION, - check_is_deleted: bool = False, - **kwargs, - ) -> int: - """ - Get the count of rows that match the provided filters. - - Args: - db_session: SQLAlchemy session - **kwargs: Filters to apply to the query (e.g., column_name=value) - - Returns: - int: The count of rows that match the filters - - Raises: - DBAPIError: If a database error occurs - """ - with db_session as session: - query = cls._size_preprocess( - db_session=session, - actor=actor, - access=access, - access_type=access_type, - check_is_deleted=check_is_deleted, - **kwargs, - ) - - try: - count = session.execute(query).scalar() - return count if count else 0 - except DBAPIError as e: - logger.exception(f"Failed to calculate size for {cls.__name__}") - raise e - @classmethod @handle_db_timeout async def size_async( diff --git a/letta/server/db.py b/letta/server/db.py index 5414e5a4..8a98b139 100644 --- a/letta/server/db.py +++ b/letta/server/db.py @@ -108,9 +108,9 @@ class DatabaseRegistry: """ def __init__(self): - self._engines: dict[str, Engine] = {} + # self._engines: dict[str, Engine] = {} self._async_engines: dict[str, AsyncEngine] = {} - self._session_factories: dict[str, sessionmaker] = {} + # self._session_factories: dict[str, sessionmaker] = {} self._async_session_factories: dict[str, async_sessionmaker] = {} self._initialized: dict[str, bool] = {"sync": False, "async": False} self._lock = threading.Lock() @@ -124,51 +124,51 @@ class DatabaseRegistry: self.logger.info("Database throttling is disabled") self._db_semaphore = None - def initialize_sync(self, force: bool = False) -> None: - """Initialize the synchronous database engine if not already initialized.""" - with self._lock: - if self._initialized.get("sync") and not force: - return + # def initialize_sync(self, force: bool = False) -> None: + # """Initialize the synchronous database engine if not already initialized.""" + # with self._lock: + # if self._initialized.get("sync") and not force: + # return - # Postgres engine - if settings.database_engine is DatabaseChoice.POSTGRES: - self.logger.info("Creating postgres engine") - self.config.recall_storage_type = "postgres" - self.config.recall_storage_uri = settings.letta_pg_uri_no_default - self.config.archival_storage_type = "postgres" - self.config.archival_storage_uri = settings.letta_pg_uri_no_default + # # Postgres engine + # if settings.database_engine is DatabaseChoice.POSTGRES: + # self.logger.info("Creating postgres engine") + # self.config.recall_storage_type = "postgres" + # self.config.recall_storage_uri = settings.letta_pg_uri_no_default + # self.config.archival_storage_type = "postgres" + # self.config.archival_storage_uri = settings.letta_pg_uri_no_default - engine = create_engine(settings.letta_pg_uri, **self._build_sqlalchemy_engine_args(is_async=False)) + # engine = create_engine(settings.letta_pg_uri, **self._build_sqlalchemy_engine_args(is_async=False)) - self._engines["default"] = engine - # SQLite engine - else: - from letta.orm import Base + # self._engines["default"] = engine + # # SQLite engine + # else: + # from letta.orm import Base - # TODO: don't rely on config storage - engine_path = "sqlite:///" + os.path.join(self.config.recall_storage_path, "sqlite.db") - self.logger.info("Creating sqlite engine " + engine_path) + # # TODO: don't rely on config storage + # engine_path = "sqlite:///" + os.path.join(self.config.recall_storage_path, "sqlite.db") + # self.logger.info("Creating sqlite engine " + engine_path) - engine = create_engine(engine_path) + # engine = create_engine(engine_path) - # Wrap the engine with error handling - self._wrap_sqlite_engine(engine) + # # Wrap the engine with error handling + # self._wrap_sqlite_engine(engine) - Base.metadata.create_all(bind=engine) - self._engines["default"] = engine + # Base.metadata.create_all(bind=engine) + # self._engines["default"] = engine - # Set up connection monitoring - if settings.sqlalchemy_tracing and settings.database_engine is DatabaseChoice.POSTGRES: - event.listen(engine, "connect", on_connect) - event.listen(engine, "close", on_close) - event.listen(engine, "checkout", on_checkout) - event.listen(engine, "checkin", on_checkin) + # # Set up connection monitoring + # if settings.sqlalchemy_tracing and settings.database_engine is DatabaseChoice.POSTGRES: + # event.listen(engine, "connect", on_connect) + # event.listen(engine, "close", on_close) + # event.listen(engine, "checkout", on_checkout) + # event.listen(engine, "checkin", on_checkin) - self._setup_pool_monitoring(engine, "default") + # self._setup_pool_monitoring(engine, "default") - # Create session factory - self._session_factories["default"] = sessionmaker(autocommit=False, autoflush=False, bind=self._engines["default"]) - self._initialized["sync"] = True + # # Create session factory + # self._session_factories["default"] = sessionmaker(autocommit=False, autoflush=False, bind=self._engines["default"]) + # self._initialized["sync"] = True def initialize_async(self, force: bool = False) -> None: """Initialize the asynchronous database engine if not already initialized.""" @@ -315,65 +315,65 @@ class DatabaseRegistry: except Exception as e: self.logger.warning(f"Failed to setup pool monitoring for {engine_name}: {e}") - def get_engine(self, name: str = "default") -> Engine: - """Get a database engine by name.""" - self.initialize_sync() - return self._engines.get(name) + # def get_engine(self, name: str = "default") -> Engine: + # """Get a database engine by name.""" + # self.initialize_sync() + # return self._engines.get(name) def get_async_engine(self, name: str = "default") -> Engine: """Get a database engine by name.""" self.initialize_async() return self._async_engines.get(name) - def get_session_factory(self, name: str = "default") -> sessionmaker: - """Get a session factory by name.""" - self.initialize_sync() - return self._session_factories.get(name) + # def get_session_factory(self, name: str = "default") -> sessionmaker: + # """Get a session factory by name.""" + # self.initialize_sync() + # return self._session_factories.get(name) def get_async_session_factory(self, name: str = "default") -> async_sessionmaker: """Get an async session factory by name.""" self.initialize_async() return self._async_session_factories.get(name) - @trace_method - @contextmanager - def session(self, name: str = "default") -> Generator[Any, None, None]: - """Context manager for database sessions.""" - caller_info = "unknown caller" - try: - import inspect + # @trace_method + # @contextmanager + # def session(self, name: str = "default") -> Generator[Any, None, None]: + # """Context manager for database sessions.""" + # caller_info = "unknown caller" + # try: + # import inspect - frame = inspect.currentframe() - stack = inspect.getouterframes(frame) + # frame = inspect.currentframe() + # stack = inspect.getouterframes(frame) - for i, frame_info in enumerate(stack): - module = inspect.getmodule(frame_info.frame) - module_name = module.__name__ if module else "unknown" + # for i, frame_info in enumerate(stack): + # module = inspect.getmodule(frame_info.frame) + # module_name = module.__name__ if module else "unknown" - if module_name != "contextlib" and "db.py" not in frame_info.filename: - caller_module = module_name - caller_function = frame_info.function - caller_lineno = frame_info.lineno - caller_file = frame_info.filename.split("/")[-1] + # if module_name != "contextlib" and "db.py" not in frame_info.filename: + # caller_module = module_name + # caller_function = frame_info.function + # caller_lineno = frame_info.lineno + # caller_file = frame_info.filename.split("/")[-1] - caller_info = f"{caller_module}.{caller_function}:{caller_lineno} ({caller_file})" - break - except: - pass - finally: - del frame + # caller_info = f"{caller_module}.{caller_function}:{caller_lineno} ({caller_file})" + # break + # except: + # pass + # finally: + # del frame - self.session_caller_trace(caller_info) + # self.session_caller_trace(caller_info) - session_factory = self.get_session_factory(name) - if not session_factory: - raise ValueError(f"No session factory found for '{name}'") + # session_factory = self.get_session_factory(name) + # if not session_factory: + # raise ValueError(f"No session factory found for '{name}'") - session = session_factory() - try: - yield session - finally: - session.close() + # session = session_factory() + # try: + # yield session + # finally: + # session.close() @trace_method @asynccontextmanager @@ -416,10 +416,10 @@ def get_db_registry() -> DatabaseRegistry: return db_registry -def get_db(): - """Get a database session.""" - with db_registry.session() as session: - yield session +# def get_db(): +# """Get a database session.""" +# with db_registry.session() as session: +# yield session async def get_db_async(): @@ -430,4 +430,4 @@ async def get_db_async(): # Prefer calling db_registry.session() or db_registry.async_session() directly # This is for backwards compatibility -db_context = contextmanager(get_db) +# db_context = contextmanager(get_db) diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 25f6a886..f354b614 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -44,7 +44,6 @@ from letta.server.db import db_registry from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right? from letta.server.rest_api.interface import StreamingServerInterface from letta.server.rest_api.middleware import CheckPasswordMiddleware, ProfilerContextMiddleware -from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes from letta.server.rest_api.routers.v1.organizations import router as organizations_router from letta.server.rest_api.routers.v1.users import router as users_router # TODO: decide on admin @@ -124,7 +123,6 @@ async def lifespan(app_: FastAPI): logger.info(f"[Worker {worker_id}] Starting lifespan initialization") logger.info(f"[Worker {worker_id}] Initializing database connections") - db_registry.initialize_sync() db_registry.initialize_async() logger.info(f"[Worker {worker_id}] Database connections initialized") @@ -140,6 +138,7 @@ async def lifespan(app_: FastAPI): logger.info(f"[Worker {worker_id}] Starting scheduler with leader election") global server + await server.init_async() try: await start_scheduler_with_leader_election(server) logger.info(f"[Worker {worker_id}] Scheduler initialization completed") @@ -400,9 +399,6 @@ def create_application() -> "FastAPI": app.include_router(users_router, prefix=ADMIN_PREFIX) app.include_router(organizations_router, prefix=ADMIN_PREFIX) - # openai - app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX) - # /api/auth endpoints app.include_router(setup_auth_router(server, interface, random_password), prefix=API_PREFIX) diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py deleted file mode 100644 index 92cbf04e..00000000 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ /dev/null @@ -1,132 +0,0 @@ -import asyncio -from typing import TYPE_CHECKING, List, Optional, Union - -from fastapi import APIRouter, Body, Depends, Header, HTTPException -from fastapi.responses import StreamingResponse -from openai.types.chat.completion_create_params import CompletionCreateParams - -from letta.agent import Agent -from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_MODEL_ENDPOINT -from letta.log import get_logger -from letta.schemas.message import Message, MessageCreate -from letta.schemas.user import User -from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface -from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server - -# TODO this belongs in a controller! -from letta.server.rest_api.utils import get_user_message_from_chat_completions_request, sse_async_generator -from letta.utils import safe_create_task - -if TYPE_CHECKING: - from letta.server.server import SyncServer - -router = APIRouter(prefix="/v1", tags=["chat_completions"]) - -logger = get_logger(__name__) - - -@router.post( - "/{agent_id}/chat/completions", - response_model=None, - operation_id="create_chat_completions", - responses={ - 200: { - "description": "Successful response", - "content": {"text/event-stream": {}}, - } - }, -) -async def create_chat_completions( - agent_id: str, - completion_request: CompletionCreateParams = Body(...), - server: "SyncServer" = Depends(get_letta_server), - headers: HeaderParams = Depends(get_headers), -): - # Validate and process fields - if not completion_request["stream"]: - raise HTTPException(status_code=400, detail="Must be streaming request: `stream` was set to `False` in the request.") - - actor = server.user_manager.get_user_or_default(user_id=headers.actor_id) - - letta_agent = server.load_agent(agent_id=agent_id, actor=actor) - llm_config = letta_agent.agent_state.llm_config - if llm_config.model_endpoint_type != "openai" or llm_config.model_endpoint == LETTA_MODEL_ENDPOINT: - error_msg = f"You can only use models with type 'openai' for chat completions. This agent {agent_id} has llm_config: \n{llm_config.model_dump_json(indent=4)}" - logger.error(error_msg) - raise HTTPException(status_code=400, detail=error_msg) - - model = completion_request.get("model") - if model != llm_config.model: - warning_msg = f"The requested model {model} is different from the model specified in this agent's ({agent_id}) llm_config: \n{llm_config.model_dump_json(indent=4)}" - logger.warning(f"Defaulting to {llm_config.model}...") - logger.warning(warning_msg) - - return await send_message_to_agent_chat_completions( - server=server, - letta_agent=letta_agent, - actor=actor, - messages=get_user_message_from_chat_completions_request(completion_request), - ) - - -async def send_message_to_agent_chat_completions( - server: "SyncServer", - letta_agent: Agent, - actor: User, - messages: Union[List[Message], List[MessageCreate]], - assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, - assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, -) -> StreamingResponse: - """Split off into a separate function so that it can be imported in the /chat/completion proxy.""" - # For streaming response - try: - # TODO: cleanup this logic - llm_config = letta_agent.agent_state.llm_config - - # Create a new interface per request - letta_agent.interface = ChatCompletionsStreamingInterface() - streaming_interface = letta_agent.interface - if not isinstance(streaming_interface, ChatCompletionsStreamingInterface): - raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}") - - # Allow AssistantMessage is desired by client - streaming_interface.assistant_message_tool_name = assistant_message_tool_name - streaming_interface.assistant_message_tool_kwarg = assistant_message_tool_kwarg - - # Related to JSON buffer reader - streaming_interface.inner_thoughts_in_kwargs = ( - llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False - ) - - # Offload the synchronous message_func to a separate thread - streaming_interface.stream_start() - safe_create_task( - asyncio.to_thread( - server.send_messages, - actor=actor, - agent_id=letta_agent.agent_state.id, - input_messages=messages, - interface=streaming_interface, - put_inner_thoughts_first=False, - ), - label="openai_send_messages", - ) - - # return a stream - return StreamingResponse( - sse_async_generator( - streaming_interface.get_generator(), - usage_task=None, - finish_message=True, - ), - media_type="text/event-stream", - ) - - except HTTPException: - raise - except Exception as e: - print(e) - import traceback - - traceback.print_exc() - raise HTTPException(status_code=500, detail=f"{e}") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 55188113..3ae2c1f6 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -190,7 +190,7 @@ async def export_agent( - Legacy format (use_legacy_format=true): Single agent with inline tools/blocks - New format (default): Multi-entity format with separate agents, tools, blocks, files, etc. """ - actor = server.user_manager.get_user_or_default(user_id=headers.actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) if use_legacy_format: # Use the legacy serialization method @@ -347,7 +347,7 @@ async def import_agent( Import a serialized agent file and recreate the agent(s) in the system. Returns the IDs of all imported agents. """ - actor = server.user_manager.get_user_or_default(user_id=headers.actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) try: serialized_data = file.file.read() @@ -1109,7 +1109,7 @@ async def list_messages( @router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUnion, operation_id="modify_message") -def modify_message( +async def modify_message( agent_id: str, message_id: str, request: LettaMessageUpdateUnion = Body(...), @@ -1120,8 +1120,12 @@ def modify_message( Update the details of a message associated with an agent. """ # TODO: support modifying tool calls/returns - actor = server.user_manager.get_user_or_default(user_id=headers.actor_id) - return server.message_manager.update_message_by_letta_message(message_id=message_id, letta_message_update=request, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + + # TODO: implement + return await server.message_manager.update_message_by_letta_message_async( + message_id=message_id, letta_message_update=request, actor=actor + ) # noinspection PyInconsistentReturns diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index 85a51b8f..a2a2464d 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -178,7 +178,7 @@ async def list_run_messages( actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) try: - messages = server.job_manager.get_run_messages( + messages = await server.job_manager.get_run_messages( run_id=run_id, actor=actor, limit=limit, @@ -244,7 +244,7 @@ async def list_run_steps( actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) try: - steps = server.job_manager.get_job_steps( + steps = await server.job_manager.get_job_steps( job_id=run_id, actor=actor, limit=limit, @@ -315,7 +315,7 @@ async def retrieve_stream( ): actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) try: - job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor) + job = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor) except NoResultFound: raise HTTPException(status_code=404, detail="Run not found") diff --git a/letta/server/server.py b/letta/server/server.py index fbee101d..e7634329 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -18,7 +18,6 @@ from fastapi.responses import StreamingResponse import letta.constants as constants import letta.server.utils as server_utils import letta.system as system -from letta.agent import Agent, save_agent from letta.config import LettaConfig from letta.constants import LETTA_TOOL_EXECUTION_DIR from letta.data_sources.connectors import DataConnector, load_data @@ -246,25 +245,41 @@ class SyncServer(Server): limits = httpx.Limits(max_connections=100, max_keepalive_connections=80, keepalive_expiry=300) self.httpx_client = httpx.AsyncClient(timeout=timeout, follow_redirects=True, limits=limits) + # For MCP + # TODO: remove this + """Initialize the MCP clients (there may be multiple)""" + self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {} + + # TODO: Remove these in memory caches + self._llm_config_cache = {} + self._embedding_config_cache = {} + + # TODO: Replace this with the Anthropic client we have in house + self.anthropic_async_client = AsyncAnthropic() + + async def init_async(self, init_with_default_org_and_user: bool = True): # Make default user and org if init_with_default_org_and_user: - self.default_org = self.organization_manager.create_default_organization() - self.default_user = self.user_manager.create_default_user() - self.tool_manager.upsert_base_tools(actor=self.default_user) + self.default_org = await self.organization_manager.create_default_organization_async() + self.default_user = await self.user_manager.create_default_actor_async() + print(f"Default user: {self.default_user} and org: {self.default_org}") + await self.tool_manager.upsert_base_tools_async(actor=self.default_user) # Add composio keys to the tool sandbox env vars of the org if tool_settings.composio_api_key: manager = SandboxConfigManager() - sandbox_config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.default_user) + sandbox_config = await manager.get_or_create_default_sandbox_config_async( + sandbox_type=SandboxType.LOCAL, actor=self.default_user + ) - manager.create_sandbox_env_var( + await manager.create_sandbox_env_var_async( SandboxEnvironmentVariableCreate(key="COMPOSIO_API_KEY", value=tool_settings.composio_api_key), sandbox_config_id=sandbox_config.id, actor=self.default_user, ) # For OSS users, create a local sandbox config - oss_default_user = self.user_manager.get_default_user() + oss_default_user = await self.user_manager.get_default_actor_async() use_venv = False if not tool_settings.tool_exec_venv_name else True venv_name = tool_settings.tool_exec_venv_name or "venv" tool_dir = tool_settings.tool_exec_dir or LETTA_TOOL_EXECUTION_DIR @@ -287,10 +302,10 @@ class SyncServer(Server): sandbox_config_create = SandboxConfigCreate( config=LocalSandboxConfig(sandbox_dir=tool_settings.tool_exec_dir, use_venv=use_venv, venv_name=venv_name) ) - sandbox_config = self.sandbox_config_manager.create_or_update_sandbox_config( + sandbox_config = await self.sandbox_config_manager.create_or_update_sandbox_config_async( sandbox_config_create=sandbox_config_create, actor=oss_default_user ) - logger.info(f"Successfully created default local sandbox config:\n{sandbox_config.get_local_config().model_dump()}") + logger.debug(f"Successfully created default local sandbox config:\n{sandbox_config.get_local_config().model_dump()}") if use_venv and tool_settings.tool_exec_autoreload_venv: prepare_local_sandbox( @@ -399,18 +414,6 @@ class SyncServer(Server): if model_settings.xai_api_key: self._enabled_providers.append(XAIProvider(name="xai", api_key=model_settings.xai_api_key)) - # For MCP - # TODO: remove this - """Initialize the MCP clients (there may be multiple)""" - self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {} - - # TODO: Remove these in memory caches - self._llm_config_cache = {} - self._embedding_config_cache = {} - - # TODO: Replace this with the Anthropic client we have in house - self.anthropic_async_client = AsyncAnthropic() - async def init_mcp_clients(self): # TODO: remove this mcp_server_configs = self.get_mcp_servers() @@ -436,329 +439,6 @@ class SyncServer(Server): logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}") logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}") - def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent: - """Updated method to load agents from persisted storage""" - agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) - # TODO: Think about how to integrate voice sleeptime into sleeptime - # TODO: Voice sleeptime agents turn into normal agents when being messaged - if agent_state.multi_agent_group and agent_state.multi_agent_group.manager_type != ManagerType.voice_sleeptime: - return load_multi_agent( - group=agent_state.multi_agent_group, agent_state=agent_state, actor=actor, interface=interface, mcp_clients=self.mcp_clients - ) - - interface = interface or self.default_interface_factory() - return Agent(agent_state=agent_state, interface=interface, user=actor, mcp_clients=self.mcp_clients) - - def _step( - self, - actor: User, - agent_id: str, - input_messages: List[MessageCreate], - interface: Union[AgentInterface, None] = None, # needed to getting responses - put_inner_thoughts_first: bool = True, - # timestamp: Optional[datetime], - ) -> LettaUsageStatistics: - """Send the input message through the agent""" - # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user - logger.debug(f"Got input messages: {input_messages}") - letta_agent = None - try: - letta_agent = self.load_agent(agent_id=agent_id, interface=interface, actor=actor) - if letta_agent is None: - raise KeyError(f"Agent (user={actor.id}, agent={agent_id}) is not loaded") - - # Determine whether or not to token stream based on the capability of the interface - token_streaming = letta_agent.interface.streaming_mode if hasattr(letta_agent.interface, "streaming_mode") else False - - logger.debug("Starting agent step") - if interface: - metadata = interface.metadata if hasattr(interface, "metadata") else None - else: - metadata = None - - usage_stats = letta_agent.step( - input_messages=input_messages, - chaining=self.chaining, - max_chaining_steps=self.max_chaining_steps, - stream=token_streaming, - skip_verify=True, - metadata=metadata, - put_inner_thoughts_first=put_inner_thoughts_first, - ) - - except Exception as e: - logger.error(f"Error in server._step: {e}") - print(traceback.print_exc()) - raise - finally: - logger.debug("Calling step_yield()") - if letta_agent: - letta_agent.interface.step_yield() - - return usage_stats - - def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics: - """Process a CLI command""" - # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user - actor = self.user_manager.get_user_or_default(user_id=user_id) - - logger.debug(f"Got command: {command}") - - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - usage = None - - if command.lower() == "exit": - # exit not supported on server.py - raise ValueError(command) - - elif command.lower() == "save" or command.lower() == "savechat": - save_agent(letta_agent) - - elif command.lower() == "attach": - # Different from CLI, we extract the data source name from the command - command = command.strip().split() - try: - data_source = int(command[1]) - except: - raise ValueError(command) - - # attach data to agent from source - letta_agent.attach_source( - user=self.user_manager.get_user_by_id(user_id=user_id), - source_id=data_source, - source_manager=self.source_manager, - agent_manager=self.agent_manager, - ) - - elif command.lower() == "dump" or command.lower().startswith("dump "): - # Check if there's an additional argument that's an integer - command = command.strip().split() - amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0 - if amount == 0: - letta_agent.interface.print_messages(letta_agent.messages, dump=True) - else: - letta_agent.interface.print_messages(letta_agent.messages[-min(amount, len(letta_agent.messages)) :], dump=True) - - elif command.lower() == "dumpraw": - letta_agent.interface.print_messages_raw(letta_agent.messages) - - elif command.lower() == "memory": - ret_str = "\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.passage_manager)}" - return ret_str - - elif command.lower() == "pop" or command.lower().startswith("pop "): - # Check if there's an additional argument that's an integer - command = command.strip().split() - pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3 - n_messages = len(letta_agent.messages) - MIN_MESSAGES = 2 - if n_messages <= MIN_MESSAGES: - logger.debug(f"Agent only has {n_messages} messages in stack, none left to pop") - elif n_messages - pop_amount < MIN_MESSAGES: - logger.debug(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}") - else: - logger.debug(f"Popping last {pop_amount} messages from stack") - for _ in range(min(pop_amount, len(letta_agent.messages))): - letta_agent.messages.pop() - - elif command.lower() == "retry": - # TODO this needs to also modify the persistence manager - logger.debug("Retrying for another answer") - while len(letta_agent.messages) > 0: - if letta_agent.messages[-1].get("role") == "user": - # we want to pop up to the last user message and send it again - letta_agent.messages[-1].get("content") - letta_agent.messages.pop() - break - letta_agent.messages.pop() - - elif command.lower() == "rethink" or command.lower().startswith("rethink "): - # TODO this needs to also modify the persistence manager - if len(command) < len("rethink "): - logger.warning("Missing text after the command") - else: - for x in range(len(letta_agent.messages) - 1, 0, -1): - if letta_agent.messages[x].get("role") == "assistant": - text = command[len("rethink ") :].strip() - letta_agent.messages[x].update({"content": text}) - break - - elif command.lower() == "rewrite" or command.lower().startswith("rewrite "): - # TODO this needs to also modify the persistence manager - if len(command) < len("rewrite "): - logger.warning("Missing text after the command") - else: - for x in range(len(letta_agent.messages) - 1, 0, -1): - if letta_agent.messages[x].get("role") == "assistant": - text = command[len("rewrite ") :].strip() - args = json_loads(letta_agent.messages[x].get("function_call").get("arguments")) - args["message"] = text - letta_agent.messages[x].get("function_call").update({"arguments": json_dumps(args)}) - break - - # No skip options - elif command.lower() == "wipe": - # exit not supported on server.py - raise ValueError(command) - - elif command.lower() == "heartbeat": - input_message = system.get_heartbeat() - usage = self._step(actor=actor, agent_id=agent_id, input_message=input_message) - - elif command.lower() == "memorywarning": - input_message = system.get_token_limit_warning() - usage = self._step(actor=actor, agent_id=agent_id, input_message=input_message) - - if not usage: - usage = LettaUsageStatistics() - - return usage - - def user_message( - self, - user_id: str, - agent_id: str, - message: Union[str, Message], - timestamp: Optional[datetime] = None, - ) -> LettaUsageStatistics: - """Process an incoming user message and feed it through the Letta agent""" - try: - actor = self.user_manager.get_user_by_id(user_id=user_id) - except NoResultFound: - raise ValueError(f"User user_id={user_id} does not exist") - - try: - agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) - except NoResultFound: - raise ValueError(f"Agent agent_id={agent_id} does not exist") - - # Basic input sanitization - if isinstance(message, str): - if len(message) == 0: - raise ValueError(f"Invalid input: '{message}'") - - # If the input begins with a command prefix, reject - elif message.startswith("/"): - raise ValueError(f"Invalid input: '{message}'") - - packaged_user_message = system.package_user_message( - user_message=message, - timezone=agent.timezone, - ) - - # NOTE: eventually deprecate and only allow passing Message types - message = MessageCreate( - agent_id=agent_id, - role="user", - content=[TextContent(text=packaged_user_message)], - ) - - # Run the agent state forward - usage = self._step(actor=actor, agent_id=agent_id, input_messages=[message]) - return usage - - def system_message( - self, - user_id: str, - agent_id: str, - message: Union[str, Message], - timestamp: Optional[datetime] = None, - ) -> LettaUsageStatistics: - """Process an incoming system message and feed it through the Letta agent""" - try: - actor = self.user_manager.get_user_by_id(user_id=user_id) - except NoResultFound: - raise ValueError(f"User user_id={user_id} does not exist") - - try: - agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) - except NoResultFound: - raise ValueError(f"Agent agent_id={agent_id} does not exist") - - # Basic input sanitization - if isinstance(message, str): - if len(message) == 0: - raise ValueError(f"Invalid input: '{message}'") - - # If the input begins with a command prefix, reject - elif message.startswith("/"): - raise ValueError(f"Invalid input: '{message}'") - - packaged_system_message = system.package_system_message(system_message=message) - - # NOTE: eventually deprecate and only allow passing Message types - # Convert to a Message object - - if timestamp: - message = Message( - agent_id=agent_id, - role="system", - content=[TextContent(text=packaged_system_message)], - created_at=timestamp, - ) - else: - message = Message( - agent_id=agent_id, - role="system", - content=[TextContent(text=packaged_system_message)], - ) - - if isinstance(message, Message): - # Can't have a null text field - message_text = message.content[0].text - if message_text is None or len(message_text) == 0: - raise ValueError(f"Invalid input: '{message_text}'") - # If the input begins with a command prefix, reject - elif message_text.startswith("/"): - raise ValueError(f"Invalid input: '{message_text}'") - - else: - raise TypeError(f"Invalid input: '{message}' - type {type(message)}") - - if timestamp: - # Override the timestamp with what the caller provided - message.created_at = timestamp - - # Run the agent state forward - return self._step(actor=actor, agent_id=agent_id, input_messages=message) - - # TODO: Deprecate this - def send_messages( - self, - actor: User, - agent_id: str, - input_messages: List[MessageCreate], - wrap_user_message: bool = True, - wrap_system_message: bool = True, - interface: Union[AgentInterface, ChatCompletionsStreamingInterface, None] = None, # needed for responses - metadata: Optional[dict] = None, # Pass through metadata to interface - put_inner_thoughts_first: bool = True, - ) -> LettaUsageStatistics: - """Send a list of messages to the agent.""" - - # Store metadata in interface if provided - if metadata and hasattr(interface, "metadata"): - interface.metadata = metadata - - # Run the agent state forward - return self._step( - actor=actor, - agent_id=agent_id, - input_messages=input_messages, - interface=interface, - put_inner_thoughts_first=put_inner_thoughts_first, - ) - - # @LockingServer.agent_lock_decorator - def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics: - """Run a command on the agent""" - # If the input begins with a command prefix, attempt to process it as a command - if command.startswith("/"): - if len(command) > 1: - command = command[1:] # strip the prefix - return self._command(user_id=user_id, agent_id=agent_id, command=command) - @trace_method def get_cached_llm_config(self, actor: User, **kwargs): key = make_key(**kwargs) @@ -788,54 +468,6 @@ class SyncServer(Server): self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs) return self._embedding_config_cache[key] - @trace_method - def create_agent( - self, - request: CreateAgent, - actor: User, - interface: AgentInterface | None = None, - ) -> AgentState: - warnings.warn("This method is deprecated, use create_agent_async where possible.", DeprecationWarning, stacklevel=2) - if request.llm_config is None: - if request.model is None: - raise ValueError("Must specify either model or llm_config in request") - config_params = { - "handle": request.model, - "context_window_limit": request.context_window_limit, - "max_tokens": request.max_tokens, - "max_reasoning_tokens": request.max_reasoning_tokens, - "enable_reasoner": request.enable_reasoner, - } - log_event(name="start get_cached_llm_config", attributes=config_params) - request.llm_config = self.get_cached_llm_config(actor=actor, **config_params) - log_event(name="end get_cached_llm_config", attributes=config_params) - - if request.embedding_config is None: - if request.embedding is None: - raise ValueError("Must specify either embedding or embedding_config in request") - embedding_config_params = { - "handle": request.embedding, - "embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE, - } - log_event(name="start get_cached_embedding_config", attributes=embedding_config_params) - request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params) - log_event(name="end get_cached_embedding_config", attributes=embedding_config_params) - - log_event(name="start create_agent db") - main_agent = self.agent_manager.create_agent( - agent_create=request, - actor=actor, - ) - log_event(name="end create_agent db") - - if request.enable_sleeptime: - if request.agent_type == AgentType.voice_convo_agent: - main_agent = self.create_voice_sleeptime_agent(main_agent=main_agent, actor=actor) - else: - main_agent = self.create_sleeptime_agent(main_agent=main_agent, actor=actor) - - return main_agent - @trace_method async def create_agent_async( self, @@ -903,32 +535,6 @@ class SyncServer(Server): return main_agent - def update_agent( - self, - agent_id: str, - request: UpdateAgent, - actor: User, - ) -> AgentState: - if request.model is not None: - request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor) - - if request.embedding is not None: - request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor) - - if request.enable_sleeptime: - agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) - if agent.multi_agent_group is None: - if agent.agent_type == AgentType.voice_convo_agent: - self.create_voice_sleeptime_agent(main_agent=agent, actor=actor) - else: - self.create_sleeptime_agent(main_agent=agent, actor=actor) - - return self.agent_manager.update_agent( - agent_id=agent_id, - agent_update=request, - actor=actor, - ) - async def update_agent_async( self, agent_id: str, @@ -955,38 +561,6 @@ class SyncServer(Server): actor=actor, ) - def create_sleeptime_agent(self, main_agent: AgentState, actor: User) -> AgentState: - request = CreateAgent( - name=main_agent.name + "-sleeptime", - agent_type=AgentType.sleeptime_agent, - block_ids=[block.id for block in main_agent.memory.blocks], - memory_blocks=[ - CreateBlock( - label="memory_persona", - value=get_persona_text("sleeptime_memory_persona"), - ), - ], - llm_config=main_agent.llm_config, - embedding_config=main_agent.embedding_config, - project_id=main_agent.project_id, - ) - sleeptime_agent = self.agent_manager.create_agent( - agent_create=request, - actor=actor, - ) - self.group_manager.create_group( - group=GroupCreate( - description="", - agent_ids=[sleeptime_agent.id], - manager_config=SleeptimeManager( - manager_agent_id=main_agent.id, - sleeptime_agent_frequency=5, - ), - ), - actor=actor, - ) - return self.agent_manager.get_agent_by_id(agent_id=main_agent.id, actor=actor) - async def create_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState: request = CreateAgent( name=main_agent.name + "-sleeptime", @@ -1019,40 +593,6 @@ class SyncServer(Server): ) return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor) - def create_voice_sleeptime_agent(self, main_agent: AgentState, actor: User) -> AgentState: - # TODO: Inject system - request = CreateAgent( - name=main_agent.name + "-sleeptime", - agent_type=AgentType.voice_sleeptime_agent, - block_ids=[block.id for block in main_agent.memory.blocks], - memory_blocks=[ - CreateBlock( - label="memory_persona", - value=get_persona_text("voice_memory_persona"), - ), - ], - llm_config=LLMConfig.default_config("gpt-4.1"), - embedding_config=main_agent.embedding_config, - project_id=main_agent.project_id, - ) - voice_sleeptime_agent = self.agent_manager.create_agent( - agent_create=request, - actor=actor, - ) - self.group_manager.create_group( - group=GroupCreate( - description="Low latency voice chat with async memory management.", - agent_ids=[voice_sleeptime_agent.id], - manager_config=VoiceSleeptimeManager( - manager_agent_id=main_agent.id, - max_message_buffer_length=constants.DEFAULT_MAX_MESSAGE_BUFFER_LENGTH, - min_message_buffer_length=constants.DEFAULT_MIN_MESSAGE_BUFFER_LENGTH, - ), - ), - actor=actor, - ) - return self.agent_manager.get_agent_by_id(agent_id=main_agent.id, actor=actor) - async def create_voice_sleeptime_agent_async(self, main_agent: AgentState, actor: User) -> AgentState: # TODO: Inject system request = CreateAgent( @@ -1087,24 +627,11 @@ class SyncServer(Server): ) return await self.agent_manager.get_agent_by_id_async(agent_id=main_agent.id, actor=actor) - # convert name->id - - # TODO: These can be moved to agent_manager - def get_agent_memory(self, agent_id: str, actor: User) -> Memory: - """Return the memory of an agent (core memory)""" - return self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).memory - async def get_agent_memory_async(self, agent_id: str, actor: User) -> Memory: """Return the memory of an agent (core memory)""" agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) return agent.memory - def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary: - return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id)) - - def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary: - return RecallMemorySummary(size=self.message_manager.size(actor=actor, agent_id=agent_id)) - async def get_agent_archival_async( self, agent_id: str, @@ -1429,72 +956,6 @@ class SyncServer(Server): passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor) return passage_count, document_count - def list_all_sources(self, actor: User) -> List[Source]: - # TODO: legacy: remove - """List all sources (w/ extra metadata) belonging to a user""" - - sources = self.source_manager.list_sources(actor=actor) - - # Add extra metadata to the sources - sources_with_metadata = [] - for source in sources: - # count number of passages - num_passages = self.agent_manager.passage_size(actor=actor, source_id=source.id) - - # TODO: add when files table implemented - ## count number of files - # document_conn = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id) - # num_documents = document_conn.size({"data_source": source.name}) - num_documents = 0 - - agents = self.source_manager.list_attached_agents(source_id=source.id, actor=actor) - # add the agent name information - attached_agents = [{"id": agent.id, "name": agent.name} for agent in agents] - - # Overwrite metadata field, should be empty anyways - source.metadata = dict( - num_documents=num_documents, - num_passages=num_passages, - attached_agents=attached_agents, - ) - - sources_with_metadata.append(source) - - return sources_with_metadata - - def update_agent_message(self, message_id: str, request: MessageUpdate, actor: User) -> Message: - """Update the details of a message associated with an agent""" - - # Get the current message - return self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor) - - def list_llm_models( - self, - actor: User, - provider_category: Optional[List[ProviderCategory]] = None, - provider_name: Optional[str] = None, - provider_type: Optional[ProviderType] = None, - ) -> List[LLMConfig]: - """List available models""" - llm_models = [] - for provider in self.get_enabled_providers( - provider_category=provider_category, - provider_name=provider_name, - provider_type=provider_type, - actor=actor, - ): - try: - llm_models.extend(provider.list_llm_models()) - except Exception as e: - import traceback - - traceback.print_exc() - warnings.warn(f"An error occurred while listing LLM models for provider {provider}: {e}") - - llm_models.extend(self.get_local_llm_configs()) - - return llm_models - @trace_method async def list_llm_models_async( self, @@ -1548,16 +1009,6 @@ class SyncServer(Server): return unique_models - def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]: - """List available embedding models""" - embedding_models = [] - for provider in self.get_enabled_providers(actor): - try: - embedding_models.extend(provider.list_embedding_models()) - except Exception as e: - warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}") - return embedding_models - async def list_embedding_models_async(self, actor: User) -> List[EmbeddingConfig]: """Asynchronously list available embedding models with maximum concurrency""" import asyncio @@ -1587,35 +1038,6 @@ class SyncServer(Server): return embedding_models - def get_enabled_providers( - self, - actor: User, - provider_category: Optional[List[ProviderCategory]] = None, - provider_name: Optional[str] = None, - provider_type: Optional[ProviderType] = None, - ) -> List[Provider]: - providers = [] - if not provider_category or ProviderCategory.base in provider_category: - providers_from_env = [p for p in self._enabled_providers] - providers.extend(providers_from_env) - - if not provider_category or ProviderCategory.byok in provider_category: - providers_from_db = self.provider_manager.list_providers( - name=provider_name, - provider_type=provider_type, - actor=actor, - ) - providers_from_db = [p.cast_to_subtype() for p in providers_from_db] - providers.extend(providers_from_db) - - if provider_name is not None: - providers = [p for p in providers if p.name == provider_name] - - if provider_type is not None: - providers = [p for p in providers if p.provider_type == provider_type] - - return providers - async def get_enabled_providers_async( self, actor: User, @@ -1645,60 +1067,6 @@ class SyncServer(Server): return providers - @trace_method - def get_llm_config_from_handle( - self, - actor: User, - handle: str, - context_window_limit: Optional[int] = None, - max_tokens: Optional[int] = None, - max_reasoning_tokens: Optional[int] = None, - enable_reasoner: Optional[bool] = None, - ) -> LLMConfig: - try: - provider_name, model_name = handle.split("/", 1) - provider = self.get_provider_from_name(provider_name, actor) - - llm_configs = [config for config in provider.list_llm_models() if config.handle == handle] - if not llm_configs: - llm_configs = [config for config in provider.list_llm_models() if config.model == model_name] - if not llm_configs: - available_handles = [config.handle for config in provider.list_llm_models()] - raise HandleNotFoundError(handle, available_handles) - except ValueError as e: - llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle] - if not llm_configs: - llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name] - if not llm_configs: - raise e - - if len(llm_configs) == 1: - llm_config = llm_configs[0] - elif len(llm_configs) > 1: - raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}") - else: - llm_config = llm_configs[0] - - if context_window_limit is not None: - if context_window_limit > llm_config.context_window: - raise ValueError(f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})") - llm_config.context_window = context_window_limit - else: - llm_config.context_window = min(llm_config.context_window, model_settings.global_max_context_window_limit) - - if max_tokens is not None: - llm_config.max_tokens = max_tokens - if max_reasoning_tokens is not None: - if not max_tokens or max_reasoning_tokens > max_tokens: - raise ValueError(f"Max reasoning tokens ({max_reasoning_tokens}) must be less than max tokens ({max_tokens})") - llm_config.max_reasoning_tokens = max_reasoning_tokens - if enable_reasoner is not None: - llm_config.enable_reasoner = enable_reasoner - if enable_reasoner and llm_config.model_endpoint_type == "anthropic": - llm_config.put_inner_thoughts_in_kwargs = False - - return llm_config - @trace_method async def get_llm_config_from_handle_async( self, @@ -1754,35 +1122,6 @@ class SyncServer(Server): return llm_config - @trace_method - def get_embedding_config_from_handle( - self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE - ) -> EmbeddingConfig: - try: - provider_name, model_name = handle.split("/", 1) - provider = self.get_provider_from_name(provider_name, actor) - - embedding_configs = [config for config in provider.list_embedding_models() if config.handle == handle] - if not embedding_configs: - raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}") - except ValueError as e: - # search local configs - embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle] - if not embedding_configs: - raise e - - if len(embedding_configs) == 1: - embedding_config = embedding_configs[0] - elif len(embedding_configs) > 1: - raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}") - else: - embedding_config = embedding_configs[0] - - if embedding_chunk_size: - embedding_config.embedding_chunk_size = embedding_chunk_size - - return embedding_config - @trace_method async def get_embedding_config_from_handle_async( self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE @@ -1813,19 +1152,6 @@ class SyncServer(Server): return embedding_config - def get_provider_from_name(self, provider_name: str, actor: User) -> Provider: - providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name] - if not providers: - raise ValueError( - f"Provider {provider_name} is not supported (supported providers: {', '.join([provider.name for provider in self._enabled_providers])})" - ) - elif len(providers) > 1: - raise ValueError(f"Multiple providers with name {provider_name} supported") - else: - provider = providers[0] - - return provider - async def get_provider_from_name_async(self, provider_name: str, actor: User) -> Provider: all_providers = await self.get_enabled_providers_async(actor) providers = [provider for provider in all_providers if provider.name == provider_name] @@ -1842,40 +1168,42 @@ class SyncServer(Server): def get_local_llm_configs(self): llm_models = [] - try: - llm_configs_dir = os.path.expanduser("~/.letta/llm_configs") - if os.path.exists(llm_configs_dir): - for filename in os.listdir(llm_configs_dir): - if filename.endswith(".json"): - filepath = os.path.join(llm_configs_dir, filename) - try: - with open(filepath, "r") as f: - config_data = json.load(f) - llm_config = LLMConfig(**config_data) - llm_models.append(llm_config) - except (json.JSONDecodeError, ValueError) as e: - warnings.warn(f"Error parsing LLM config file {filename}: {e}") - except Exception as e: - warnings.warn(f"Error reading LLM configs directory: {e}") + # NOTE: deprecated + # try: + # llm_configs_dir = os.path.expanduser("~/.letta/llm_configs") + # if os.path.exists(llm_configs_dir): + # for filename in os.listdir(llm_configs_dir): + # if filename.endswith(".json"): + # filepath = os.path.join(llm_configs_dir, filename) + # try: + # with open(filepath, "r") as f: + # config_data = json.load(f) + # llm_config = LLMConfig(**config_data) + # llm_models.append(llm_config) + # except (json.JSONDecodeError, ValueError) as e: + # warnings.warn(f"Error parsing LLM config file {filename}: {e}") + # except Exception as e: + # warnings.warn(f"Error reading LLM configs directory: {e}") return llm_models def get_local_embedding_configs(self): embedding_models = [] - try: - embedding_configs_dir = os.path.expanduser("~/.letta/embedding_configs") - if os.path.exists(embedding_configs_dir): - for filename in os.listdir(embedding_configs_dir): - if filename.endswith(".json"): - filepath = os.path.join(embedding_configs_dir, filename) - try: - with open(filepath, "r") as f: - config_data = json.load(f) - embedding_config = EmbeddingConfig(**config_data) - embedding_models.append(embedding_config) - except (json.JSONDecodeError, ValueError) as e: - warnings.warn(f"Error parsing embedding config file {filename}: {e}") - except Exception as e: - warnings.warn(f"Error reading embedding configs directory: {e}") + # NOTE: deprecated + # try: + # embedding_configs_dir = os.path.expanduser("~/.letta/embedding_configs") + # if os.path.exists(embedding_configs_dir): + # for filename in os.listdir(embedding_configs_dir): + # if filename.endswith(".json"): + # filepath = os.path.join(embedding_configs_dir, filename) + # try: + # with open(filepath, "r") as f: + # config_data = json.load(f) + # embedding_config = EmbeddingConfig(**config_data) + # embedding_models.append(embedding_config) + # except (json.JSONDecodeError, ValueError) as e: + # warnings.warn(f"Error parsing embedding config file {filename}: {e}") + # except Exception as e: + # warnings.warn(f"Error reading embedding configs directory: {e}") return embedding_models def add_llm_model(self, request: LLMConfig) -> LLMConfig: diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 7369be75..4c40131e 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -301,185 +301,6 @@ class AgentManager: # ====================================================================================================================== # Basic CRUD operations # ====================================================================================================================== - @trace_method - def create_agent(self, agent_create: CreateAgent, actor: PydanticUser, _test_only_force_id: Optional[str] = None) -> PydanticAgentState: - # validate required configs - if not agent_create.llm_config or not agent_create.embedding_config: - raise ValueError("llm_config and embedding_config are required") - - # blocks - block_ids = list(agent_create.block_ids or []) - if agent_create.memory_blocks: - pydantic_blocks = [PydanticBlock(**b.model_dump(to_orm=True)) for b in agent_create.memory_blocks] - created_blocks = self.block_manager.batch_create_blocks( - pydantic_blocks, - actor=actor, - ) - block_ids.extend([blk.id for blk in created_blocks]) - - # tools - tool_names = set(agent_create.tools or []) - if agent_create.include_base_tools: - if agent_create.agent_type == AgentType.voice_sleeptime_agent: - tool_names |= set(BASE_VOICE_SLEEPTIME_TOOLS) - elif agent_create.agent_type == AgentType.voice_convo_agent: - tool_names |= set(BASE_VOICE_SLEEPTIME_CHAT_TOOLS) - elif agent_create.agent_type == AgentType.sleeptime_agent: - tool_names |= set(BASE_SLEEPTIME_TOOLS) - elif agent_create.enable_sleeptime: - tool_names |= set(BASE_SLEEPTIME_CHAT_TOOLS) - elif agent_create.agent_type == AgentType.memgpt_v2_agent: - tool_names |= calculate_base_tools(is_v2=True) - elif agent_create.agent_type == AgentType.react_agent: - pass # no default tools - elif agent_create.agent_type == AgentType.workflow_agent: - pass # no default tools - else: - tool_names |= calculate_base_tools(is_v2=False) - if agent_create.include_multi_agent_tools: - tool_names |= calculate_multi_agent_tools() - - supplied_ids = set(agent_create.tool_ids or []) - - source_ids = agent_create.source_ids or [] - identity_ids = agent_create.identity_ids or [] - tag_values = agent_create.tags or [] - - with db_registry.session() as session: - with session.begin(): - name_to_id, id_to_name = self._resolve_tools( - session, - tool_names, - supplied_ids, - actor.organization_id, - ) - - tool_ids = set(name_to_id.values()) | set(id_to_name.keys()) - tool_names = set(name_to_id.keys()) # now canonical - - tool_rules = list(agent_create.tool_rules or []) - - # Override include_base_tool_rules to False if model matches exclusion keywords and include_base_tool_rules is not explicitly set to True - if ( - ( - self._should_exclude_model_from_base_tool_rules(agent_create.llm_config.model) - and agent_create.include_base_tool_rules is None - ) - and agent_create.agent_type != AgentType.sleeptime_agent - ) or agent_create.include_base_tool_rules is False: - agent_create.include_base_tool_rules = False - logger.info(f"Overriding include_base_tool_rules to False for model: {agent_create.llm_config.model}") - else: - agent_create.include_base_tool_rules = True - - should_add_base_tool_rules = agent_create.include_base_tool_rules - if should_add_base_tool_rules: - for tn in tool_names: - if tn in {"send_message", "send_message_to_agent_async", "memory_finish_edits"}: - tool_rules.append(TerminalToolRule(tool_name=tn)) - elif tn in (BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_MEMORY_TOOLS_V2 + BASE_SLEEPTIME_TOOLS): - tool_rules.append(ContinueToolRule(tool_name=tn)) - - if tool_rules: - check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=tool_rules) - - new_agent = AgentModel( - name=agent_create.name, - system=derive_system_message( - agent_type=agent_create.agent_type, - enable_sleeptime=agent_create.enable_sleeptime, - system=agent_create.system, - ), - hidden=agent_create.hidden, - agent_type=agent_create.agent_type, - llm_config=agent_create.llm_config, - embedding_config=agent_create.embedding_config, - organization_id=actor.organization_id, - description=agent_create.description, - metadata_=agent_create.metadata, - tool_rules=tool_rules, - project_id=agent_create.project_id, - template_id=agent_create.template_id, - base_template_id=agent_create.base_template_id, - message_buffer_autoclear=agent_create.message_buffer_autoclear, - enable_sleeptime=agent_create.enable_sleeptime, - response_format=agent_create.response_format, - created_by_id=actor.id, - last_updated_by_id=actor.id, - timezone=agent_create.timezone, - max_files_open=agent_create.max_files_open, - per_file_view_window_char_limit=agent_create.per_file_view_window_char_limit, - ) - - # Set template fields for InternalTemplateAgentCreate (similar to group creation) - if isinstance(agent_create, InternalTemplateAgentCreate): - new_agent.base_template_id = agent_create.base_template_id - new_agent.template_id = agent_create.template_id - new_agent.deployment_id = agent_create.deployment_id - new_agent.entity_id = agent_create.entity_id - - if _test_only_force_id: - new_agent.id = _test_only_force_id - - session.add(new_agent) - session.flush() - aid = new_agent.id - - # Note: These methods may need async versions if they perform database operations - self._bulk_insert_pivot( - session, - ToolsAgents.__table__, - [{"agent_id": aid, "tool_id": tid} for tid in tool_ids], - ) - - if block_ids: - result = session.execute(select(BlockModel.id, BlockModel.label).where(BlockModel.id.in_(block_ids))) - rows = [{"agent_id": aid, "block_id": bid, "block_label": lbl} for bid, lbl in result.all()] - self._bulk_insert_pivot(session, BlocksAgents.__table__, rows) - - self._bulk_insert_pivot( - session, - SourcesAgents.__table__, - [{"agent_id": aid, "source_id": sid} for sid in source_ids], - ) - self._bulk_insert_pivot( - session, - AgentsTags.__table__, - [{"agent_id": aid, "tag": tag} for tag in tag_values], - ) - self._bulk_insert_pivot( - session, - IdentitiesAgents.__table__, - [{"agent_id": aid, "identity_id": iid} for iid in identity_ids], - ) - - agent_secrets = agent_create.secrets or agent_create.tool_exec_environment_variables - if agent_secrets: - env_rows = [ - { - "agent_id": aid, - "key": key, - "value": val, - "organization_id": actor.organization_id, - } - for key, val in agent_secrets.items() - ] - session.execute(insert(AgentEnvironmentVariable).values(env_rows)) - - # initial message sequence - init_messages = self._generate_initial_message_sequence( - actor, - agent_state=new_agent.to_pydantic(include_relationships={"memory"}), - supplied_initial_message_sequence=agent_create.initial_message_sequence, - ) - new_agent.message_ids = [msg.id for msg in init_messages] - - session.refresh(new_agent) - - # Using the synchronous version since we don't have an async version yet - # If you implement an async version of create_many_messages, you can switch to that - self.message_manager.create_many_messages(pydantic_msgs=init_messages, actor=actor) - return new_agent.to_pydantic() @trace_method async def create_agent_async( @@ -783,14 +604,6 @@ class AgentManager: return init_messages - @enforce_types - @trace_method - def append_initial_message_sequence_to_in_context_messages( - self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None - ) -> PydanticAgentState: - init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence) - return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor) - @enforce_types @trace_method async def append_initial_message_sequence_to_in_context_messages_async( @@ -799,130 +612,6 @@ class AgentManager: init_messages = await self._generate_initial_message_sequence_async(actor, agent_state, initial_message_sequence) return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor) - @enforce_types - @trace_method - def update_agent( - self, - agent_id: str, - agent_update: UpdateAgent, - actor: PydanticUser, - ) -> PydanticAgentState: - new_tools = set(agent_update.tool_ids or []) - new_sources = set(agent_update.source_ids or []) - new_blocks = set(agent_update.block_ids or []) - new_idents = set(agent_update.identity_ids or []) - new_tags = set(agent_update.tags or []) - - with db_registry.session() as session, session.begin(): - agent: AgentModel = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - agent.updated_at = datetime.now(timezone.utc) - agent.last_updated_by_id = actor.id - - scalar_updates = { - "name": agent_update.name, - "system": agent_update.system, - "llm_config": agent_update.llm_config, - "embedding_config": agent_update.embedding_config, - "message_ids": agent_update.message_ids, - "tool_rules": agent_update.tool_rules, - "description": agent_update.description, - "project_id": agent_update.project_id, - "template_id": agent_update.template_id, - "base_template_id": agent_update.base_template_id, - "message_buffer_autoclear": agent_update.message_buffer_autoclear, - "enable_sleeptime": agent_update.enable_sleeptime, - "response_format": agent_update.response_format, - "last_run_completion": agent_update.last_run_completion, - "last_run_duration_ms": agent_update.last_run_duration_ms, - "max_files_open": agent_update.max_files_open, - "per_file_view_window_char_limit": agent_update.per_file_view_window_char_limit, - "timezone": agent_update.timezone, - } - for col, val in scalar_updates.items(): - if val is not None: - setattr(agent, col, val) - - if agent_update.metadata is not None: - agent.metadata_ = agent_update.metadata - - aid = agent.id - - if agent_update.tool_ids is not None: - self._replace_pivot_rows( - session, - ToolsAgents.__table__, - aid, - [{"agent_id": aid, "tool_id": tid} for tid in new_tools], - ) - session.expire(agent, ["tools"]) - - if agent_update.source_ids is not None: - self._replace_pivot_rows( - session, - SourcesAgents.__table__, - aid, - [{"agent_id": aid, "source_id": sid} for sid in new_sources], - ) - session.expire(agent, ["sources"]) - - if agent_update.block_ids is not None: - rows = [] - if new_blocks: - label_map = { - bid: lbl - for bid, lbl in session.execute(select(BlockModel.id, BlockModel.label).where(BlockModel.id.in_(new_blocks))) - } - rows = [{"agent_id": aid, "block_id": bid, "block_label": label_map[bid]} for bid in new_blocks] - - self._replace_pivot_rows(session, BlocksAgents.__table__, aid, rows) - session.expire(agent, ["core_memory"]) - - if agent_update.identity_ids is not None: - self._replace_pivot_rows( - session, - IdentitiesAgents.__table__, - aid, - [{"agent_id": aid, "identity_id": iid} for iid in new_idents], - ) - session.expire(agent, ["identities"]) - - if agent_update.tags is not None: - self._replace_pivot_rows( - session, - AgentsTags.__table__, - aid, - [{"agent_id": aid, "tag": tag} for tag in new_tags], - ) - session.expire(agent, ["tags"]) - - agent_secrets = agent_update.secrets or agent_update.tool_exec_environment_variables - if agent_secrets is not None: - session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid)) - env_rows = [ - { - "agent_id": aid, - "key": k, - "value": v, - "organization_id": agent.organization_id, - } - for k, v in agent_secrets.items() - ] - if env_rows: - self._bulk_insert_pivot(session, AgentEnvironmentVariable.__table__, env_rows) - session.expire(agent, ["tool_exec_environment_variables"]) - - if agent_update.enable_sleeptime and agent_update.system is None: - agent.system = derive_system_message( - agent_type=agent.agent_type, - enable_sleeptime=agent_update.enable_sleeptime, - system=agent.system, - ) - - session.flush() - session.refresh(agent) - - return agent.to_pydantic() - @enforce_types @trace_method async def update_agent_async( @@ -1073,67 +762,6 @@ class AgentManager: await agent.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True) await session.commit() - # TODO: Make this general and think about how to roll this into sqlalchemybase - @trace_method - def list_agents( - self, - actor: PydanticUser, - name: Optional[str] = None, - tags: Optional[List[str]] = None, - match_all_tags: bool = False, - before: Optional[str] = None, - after: Optional[str] = None, - limit: Optional[int] = 50, - query_text: Optional[str] = None, - project_id: Optional[str] = None, - template_id: Optional[str] = None, - base_template_id: Optional[str] = None, - identity_id: Optional[str] = None, - identifier_keys: Optional[List[str]] = None, - include_relationships: Optional[List[str]] = None, - ascending: bool = True, - sort_by: Optional[str] = "created_at", - ) -> List[PydanticAgentState]: - """ - Retrieves agents with optimized filtering and optional field selection. - - Args: - actor: The User requesting the list - name (Optional[str]): Filter by agent name. - tags (Optional[List[str]]): Filter agents by tags. - match_all_tags (bool): If True, only return agents that match ALL given tags. - before (Optional[str]): Cursor for pagination. - after (Optional[str]): Cursor for pagination. - limit (Optional[int]): Maximum number of agents to return. - query_text (Optional[str]): Search agents by name. - project_id (Optional[str]): Filter by project ID. - template_id (Optional[str]): Filter by template ID. - base_template_id (Optional[str]): Filter by base template ID. - identity_id (Optional[str]): Filter by identifier ID. - identifier_keys (Optional[List[str]]): Search agents by identifier keys. - include_relationships (Optional[List[str]]): List of fields to load for performance optimization. - ascending - - Returns: - List[PydanticAgentState]: The filtered list of matching agents. - """ - with db_registry.session() as session: - query = select(AgentModel).distinct(AgentModel.created_at, AgentModel.id) - query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) - - # Apply filters - query = _apply_filters(query, name, query_text, project_id, template_id, base_template_id) - query = _apply_identity_filters(query, identity_id, identifier_keys) - query = _apply_tag_filter(query, tags, match_all_tags) - query = _apply_pagination(query, before, after, session, ascending=ascending, sort_by=sort_by) - - if limit: - query = query.limit(limit) - - result = session.execute(query) - agents = result.scalars().all() - return [agent.to_pydantic(include_relationships=include_relationships) for agent in agents] - @trace_method async def list_agents_async( self, @@ -1201,50 +829,6 @@ class AgentManager: agents = result.scalars().all() return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents]) - @enforce_types - @trace_method - def list_agents_matching_tags( - self, - actor: PydanticUser, - match_all: List[str], - match_some: List[str], - limit: Optional[int] = 50, - ) -> List[PydanticAgentState]: - """ - Retrieves agents in the same organization that match all specified `match_all` tags - and at least one tag from `match_some`. The query is optimized for efficiency by - leveraging indexed filtering and aggregation. - - Args: - actor (PydanticUser): The user requesting the agent list. - match_all (List[str]): Agents must have all these tags. - match_some (List[str]): Agents must have at least one of these tags. - limit (Optional[int]): Maximum number of agents to return. - - Returns: - List[PydanticAgentState: The filtered list of matching agents. - """ - with db_registry.session() as session: - query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id) - - if match_all: - # Subquery to find agent IDs that contain all match_all tags - subquery = ( - select(AgentsTags.agent_id) - .where(AgentsTags.tag.in_(match_all)) - .group_by(AgentsTags.agent_id) - .having(func.count(AgentsTags.tag) == literal(len(match_all))) - ) - query = query.where(AgentModel.id.in_(subquery)) - - if match_some: - # Ensures agents match at least one tag in match_some - query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some)) - - query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit) - - return list(session.execute(query).scalars()) - @enforce_types @trace_method async def list_agents_matching_tags_async( @@ -1289,17 +873,6 @@ class AgentManager: result = await session.execute(query) return await asyncio.gather(*[agent.to_pydantic_async() for agent in result.scalars()]) - @trace_method - def size( - self, - actor: PydanticUser, - ) -> int: - """ - Get the total count of agents for the given user. - """ - with db_registry.session() as session: - return AgentModel.size(db_session=session, actor=actor) - @trace_method async def size_async( self, @@ -1311,14 +884,6 @@ class AgentManager: async with db_registry.async_session() as session: return await AgentModel.size_async(db_session=session, actor=actor) - @enforce_types - @trace_method - def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: - """Fetch an agent by its ID.""" - with db_registry.session() as session: - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - return agent.to_pydantic() - @enforce_types @trace_method async def get_agent_by_id_async( @@ -1374,14 +939,6 @@ class AgentManager: logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}") raise - @enforce_types - @trace_method - def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState: - """Fetch an agent by its ID.""" - with db_registry.session() as session: - agent = AgentModel.read(db_session=session, name=agent_name, actor=actor) - return agent.to_pydantic() - @enforce_types @trace_method async def get_agent_archive_ids_async(self, agent_id: str, actor: PydanticUser) -> List[str]: @@ -1395,54 +952,6 @@ class AgentManager: archive_ids = [row[0] for row in result.fetchall()] return archive_ids - @enforce_types - @trace_method - def delete_agent(self, agent_id: str, actor: PydanticUser) -> None: - """ - Deletes an agent and its associated relationships. - Ensures proper permission checks and cascades where applicable. - - Args: - agent_id: ID of the agent to be deleted. - actor: User performing the action. - - Raises: - NoResultFound: If agent doesn't exist - """ - with db_registry.session() as session: - # Retrieve the agent - logger.debug(f"Hard deleting Agent with ID: {agent_id} with actor={actor}") - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - agents_to_delete = [agent] - sleeptime_group_to_delete = None - - # Delete sleeptime agent and group (TODO this is flimsy pls fix) - if agent.multi_agent_group: - participant_agent_ids = agent.multi_agent_group.agent_ids - if agent.multi_agent_group.manager_type in {ManagerType.sleeptime, ManagerType.voice_sleeptime} and participant_agent_ids: - for participant_agent_id in participant_agent_ids: - try: - sleeptime_agent = AgentModel.read(db_session=session, identifier=participant_agent_id, actor=actor) - agents_to_delete.append(sleeptime_agent) - except NoResultFound: - pass # agent already deleted - sleeptime_agent_group = GroupModel.read(db_session=session, identifier=agent.multi_agent_group.id, actor=actor) - sleeptime_group_to_delete = sleeptime_agent_group - - try: - if sleeptime_group_to_delete is not None: - session.delete(sleeptime_group_to_delete) - session.commit() - for agent in agents_to_delete: - session.delete(agent) - session.commit() - except Exception as e: - session.rollback() - logger.exception(f"Failed to hard delete Agent with ID {agent_id}") - raise ValueError(f"Failed to hard delete Agent with ID {agent_id}: {e}") - else: - logger.debug(f"Agent with ID {agent_id} successfully hard deleted") - @enforce_types @trace_method async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None: @@ -1493,168 +1002,9 @@ class AgentManager: else: logger.debug(f"Agent with ID {agent_id} successfully hard deleted") - @enforce_types - @trace_method - def serialize(self, agent_id: str, actor: PydanticUser, max_steps: Optional[int] = None) -> AgentSchema: - with db_registry.session() as session: - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - schema = MarshmallowAgentSchema(session=session, actor=actor, max_steps=max_steps) - data = schema.dump(agent) - return AgentSchema(**data) - - @enforce_types - @trace_method - def deserialize( - self, - serialized_agent: AgentSchema, - actor: PydanticUser, - append_copy_suffix: bool = True, - override_existing_tools: bool = True, - project_id: Optional[str] = None, - strip_messages: Optional[bool] = False, - env_vars: Optional[dict[str, Any]] = None, - ) -> PydanticAgentState: - serialized_agent_dict = serialized_agent.model_dump() - tool_data_list = serialized_agent_dict.pop("tools", []) - messages = serialized_agent_dict.pop(MarshmallowAgentSchema.FIELD_MESSAGES, []) - - for msg in messages: - msg[MarshmallowAgentSchema.FIELD_ID] = SerializedMessageSchema.generate_id() # Generate new ID - - message_ids = [] - in_context_message_indices = serialized_agent_dict.pop(MarshmallowAgentSchema.FIELD_IN_CONTEXT_INDICES) - for idx in in_context_message_indices: - message_ids.append(messages[idx][MarshmallowAgentSchema.FIELD_ID]) - - serialized_agent_dict[MarshmallowAgentSchema.FIELD_MESSAGE_IDS] = message_ids - - with db_registry.session() as session: - schema = MarshmallowAgentSchema(session=session, actor=actor) - agent = schema.load(serialized_agent_dict, session=session) - - agent.organization_id = actor.organization_id - for block in agent.core_memory: - block.organization_id = actor.organization_id - if append_copy_suffix: - agent.name += "_copy" - if project_id: - agent.project_id = project_id - - if strip_messages: - # we want to strip all but the first (system) message - agent.message_ids = [agent.message_ids[0]] - - if env_vars: - for var in agent.tool_exec_environment_variables: - var.value = env_vars.get(var.key, "") - for var in agent.secrets: - var.value = env_vars.get(var.key, "") - - agent = agent.create(session, actor=actor) - - pydantic_agent = agent.to_pydantic() - - pyd_msgs = [] - message_schema = SerializedMessageSchema(session=session, actor=actor) - - for serialized_message in messages: - pydantic_message = message_schema.load(serialized_message, session=session).to_pydantic() - pydantic_message.agent_id = agent.id - pyd_msgs.append(pydantic_message) - self.message_manager.create_many_messages(pyd_msgs, actor=actor) - - # Need to do this separately as there's some fancy upsert logic that SqlAlchemy cannot handle - for tool_data in tool_data_list: - pydantic_tool = SerializedToolSchema(actor=actor).load(tool_data, transient=True).to_pydantic() - - existing_pydantic_tool = self.tool_manager.get_tool_by_name(pydantic_tool.name, actor=actor) - if existing_pydantic_tool and ( - existing_pydantic_tool.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MULTI_AGENT_CORE, ToolType.LETTA_MEMORY_CORE} - or not override_existing_tools - ): - pydantic_tool = existing_pydantic_tool - else: - pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor, bypass_name_check=True) - - pydantic_agent = self.attach_tool(agent_id=pydantic_agent.id, tool_id=pydantic_tool.id, actor=actor) - - return pydantic_agent - # ====================================================================================================================== # Per Agent Environment Variable Management # ====================================================================================================================== - @enforce_types - @trace_method - def _set_environment_variables( - self, - agent_id: str, - env_vars: Dict[str, str], - actor: PydanticUser, - ) -> PydanticAgentState: - """ - Adds or replaces the environment variables for the specified agent. - - Args: - agent_id: The agent id. - env_vars: A dictionary of environment variable key-value pairs. - actor: The user performing the action. - - Returns: - PydanticAgentState: The updated agent as a Pydantic model. - """ - with db_registry.session() as session: - # Retrieve the agent - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - - # Fetch existing environment variables as a dictionary - existing_vars = {var.key: var for var in agent.tool_exec_environment_variables} - - # Update or create environment variables - updated_vars = [] - for key, value in env_vars.items(): - if key in existing_vars: - # Update existing variable - existing_vars[key].value = value - updated_vars.append(existing_vars[key]) - else: - # Create new variable - updated_vars.append( - AgentEnvironmentVariableModel( - key=key, - value=value, - agent_id=agent_id, - organization_id=actor.organization_id, - created_by_id=actor.id, - last_updated_by_id=actor.id, - ) - ) - - # Remove stale variables - stale_keys = set(existing_vars) - set(env_vars) - agent.tool_exec_environment_variables = [var for var in updated_vars if var.key not in stale_keys] - agent.secrets = [var for var in updated_vars if var.key not in stale_keys] - - # Update the agent in the database - agent.update(session, actor=actor) - - # Return the updated agent state - return agent.to_pydantic() - - @enforce_types - @trace_method - def list_groups(self, agent_id: str, actor: PydanticUser, manager_type: Optional[str] = None) -> List[PydanticGroup]: - with db_registry.session() as session: - query = ( - select(GroupModel) - .join(GroupsAgents, GroupModel.id == GroupsAgents.group_id) - .where(GroupsAgents.agent_id == agent_id, GroupModel.organization_id == actor.organization_id) - ) - - if manager_type: - query = query.where(GroupModel.manager_type == manager_type) - - result = session.execute(query) - return [group.to_pydantic() for group in result.scalars()] # ====================================================================================================================== # In Context Messages Management @@ -2240,22 +1590,6 @@ class AgentManager: # ====================================================================================================================== # Block management # ====================================================================================================================== - @enforce_types - @trace_method - def get_block_with_label( - self, - agent_id: str, - block_label: str, - actor: PydanticUser, - ) -> PydanticBlock: - """Gets a block attached to an agent by its label.""" - with db_registry.session() as session: - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - for block in agent.core_memory: - if block.label == block_label: - return block.to_pydantic() - raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'") - @enforce_types @trace_method async def get_block_with_label_async( @@ -2300,61 +1634,6 @@ class AgentManager: await block.update_async(session, actor=actor) return block.to_pydantic() - @enforce_types - @trace_method - def update_block_with_label( - self, - agent_id: str, - block_label: str, - new_block_id: str, - actor: PydanticUser, - ) -> PydanticAgentState: - """Updates which block is assigned to a specific label for an agent.""" - with db_registry.session() as session: - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - new_block = BlockModel.read(db_session=session, identifier=new_block_id, actor=actor) - - if new_block.label != block_label: - raise ValueError(f"New block label '{new_block.label}' doesn't match required label '{block_label}'") - - # Remove old block with this label if it exists - agent.core_memory = [b for b in agent.core_memory if b.label != block_label] - - # Add new block - agent.core_memory.append(new_block) - agent.update(session, actor=actor) - return agent.to_pydantic() - - @enforce_types - @trace_method - def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState: - """Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group.""" - with db_registry.session() as session: - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) - - # Attach block to the main agent - agent.core_memory.append(block) - agent.update(session, actor=actor, no_commit=True) - - # If agent is part of a sleeptime group, attach block to the sleeptime_agent - if agent.multi_agent_group and agent.multi_agent_group.manager_type == ManagerType.sleeptime: - group = agent.multi_agent_group - # Find the sleeptime_agent in the group - for other_agent_id in group.agent_ids or []: - if other_agent_id != agent_id: - try: - other_agent = AgentModel.read(db_session=session, identifier=other_agent_id, actor=actor) - if other_agent.agent_type == AgentType.sleeptime_agent and block not in other_agent.core_memory: - other_agent.core_memory.append(block) - other_agent.update(session, actor=actor, no_commit=True) - except NoResultFound: - # Agent might not exist anymore, skip - continue - session.commit() - - return agent.to_pydantic() - @enforce_types @trace_method async def attach_block_async(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState: @@ -2391,27 +1670,6 @@ class AgentManager: return await agent.to_pydantic_async() - @enforce_types - @trace_method - def detach_block( - self, - agent_id: str, - block_id: str, - actor: PydanticUser, - ) -> PydanticAgentState: - """Detaches a block from an agent.""" - with db_registry.session() as session: - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - original_length = len(agent.core_memory) - - agent.core_memory = [b for b in agent.core_memory if b.id != block_id] - - if len(agent.core_memory) == original_length: - raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'") - - agent.update(session, actor=actor) - return agent.to_pydantic() - @enforce_types @trace_method async def detach_block_async( @@ -2433,27 +1691,6 @@ class AgentManager: await agent.update_async(session, actor=actor) return await agent.to_pydantic_async() - @enforce_types - @trace_method - def detach_block_with_label( - self, - agent_id: str, - block_label: str, - actor: PydanticUser, - ) -> PydanticAgentState: - """Detaches a block with the specified label from an agent.""" - with db_registry.session() as session: - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - original_length = len(agent.core_memory) - - agent.core_memory = [b for b in agent.core_memory if b.label != block_label] - - if len(agent.core_memory) == original_length: - raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}' with actor id: '{actor.id}'") - - agent.update(session, actor=actor) - return agent.to_pydantic() - # ====================================================================================================================== # Passage Management # ====================================================================================================================== @@ -2985,41 +2222,6 @@ class AgentManager: # Tool Management # ====================================================================================================================== @enforce_types - @trace_method - def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: - """ - Attaches a tool to an agent. - - Args: - agent_id: ID of the agent to attach the tool to. - tool_id: ID of the tool to attach. - actor: User performing the action. - - Raises: - NoResultFound: If the agent or tool is not found. - - Returns: - PydanticAgentState: The updated agent state. - """ - with db_registry.session() as session: - # Verify the agent exists and user has permission to access it - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - - # Use the _process_relationship helper to attach the tool - _process_relationship( - session=session, - agent=agent, - relationship_name="tools", - model_class=ToolModel, - item_ids=[tool_id], - allow_partial=False, # Ensure the tool exists - replace=False, # Extend the existing tools - ) - - # Commit and refresh the agent - agent.update(session, actor=actor) - return agent.to_pydantic() - @enforce_types @trace_method async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None: @@ -3253,40 +2455,6 @@ class AgentManager: return PydanticAgentState(**agent_state_dict) - @enforce_types - @trace_method - def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: - """ - Detaches a tool from an agent. - - Args: - agent_id: ID of the agent to detach the tool from. - tool_id: ID of the tool to detach. - actor: User performing the action. - - Raises: - NoResultFound: If the agent or tool is not found. - - Returns: - PydanticAgentState: The updated agent state. - """ - with db_registry.session() as session: - # Verify the agent exists and user has permission to access it - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - - # Filter out the tool to be detached - remaining_tools = [tool for tool in agent.tools if tool.id != tool_id] - - if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship - logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}") - - # Update the tools relationship - agent.tools = remaining_tools - - # Commit and refresh the agent - agent.update(session, actor=actor) - return agent.to_pydantic() - @enforce_types @trace_method async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None: @@ -3379,23 +2547,6 @@ class AgentManager: session.add(agent) await session.commit() - @enforce_types - @trace_method - def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]: - """ - List all tools attached to an agent. - - Args: - agent_id: ID of the agent to list tools for. - actor: User performing the action. - - Returns: - List[PydanticTool]: List of tools attached to the agent. - """ - with db_registry.session() as session: - agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - return [tool.to_pydantic() for tool in agent.tools] - @enforce_types @trace_method async def list_attached_tools_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]: @@ -3505,45 +2656,6 @@ class AgentManager: # ====================================================================================================================== # Tag Management # ====================================================================================================================== - @enforce_types - @trace_method - def list_tags( - self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None - ) -> List[str]: - """ - Get all tags a user has created, ordered alphabetically. - - Args: - actor: User performing the action. - after: Cursor for forward pagination. - limit: Maximum number of tags to return. - query_text: Query text to filter tags by. - - Returns: - List[str]: List of all tags. - """ - with db_registry.session() as session: - query = ( - session.query(AgentsTags.tag) - .join(AgentModel, AgentModel.id == AgentsTags.agent_id) - .filter(AgentModel.organization_id == actor.organization_id) - .distinct() - ) - - if query_text: - if settings.database_engine is DatabaseChoice.POSTGRES: - # PostgreSQL: Use ILIKE for case-insensitive search - query = query.filter(AgentsTags.tag.ilike(f"%{query_text}%")) - else: - # SQLite: Use LIKE with LOWER for case-insensitive search - query = query.filter(func.lower(AgentsTags.tag).like(func.lower(f"%{query_text}%"))) - - if after: - query = query.filter(AgentsTags.tag > after) - - query = query.order_by(AgentsTags.tag).limit(limit) - results = [tag[0] for tag in query.all()] - return results @enforce_types @trace_method diff --git a/letta/services/archive_manager.py b/letta/services/archive_manager.py index 0e54e7cd..203a1ffb 100644 --- a/letta/services/archive_manager.py +++ b/letta/services/archive_manager.py @@ -19,32 +19,6 @@ logger = get_logger(__name__) class ArchiveManager: """Manager class to handle business logic related to Archives.""" - @enforce_types - @trace_method - def create_archive( - self, - name: str, - description: Optional[str] = None, - actor: PydanticUser = None, - ) -> PydanticArchive: - """Create a new archive.""" - try: - with db_registry.session() as session: - # determine vector db provider based on settings - vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE - - archive = ArchiveModel( - name=name, - description=description, - organization_id=actor.organization_id, - vector_db_provider=vector_db_provider, - ) - archive.create(session, actor=actor) - return archive.to_pydantic() - except Exception as e: - logger.exception(f"Failed to create archive {name}. error={e}") - raise - @enforce_types @trace_method async def create_archive_async( @@ -160,36 +134,6 @@ class ArchiveManager: ) return [a.to_pydantic() for a in archives] - @enforce_types - @trace_method - def attach_agent_to_archive( - self, - agent_id: str, - archive_id: str, - is_owner: bool, - actor: PydanticUser, - ) -> None: - """Attach an agent to an archive.""" - with db_registry.session() as session: - # Check if already attached - existing = session.query(ArchivesAgents).filter_by(agent_id=agent_id, archive_id=archive_id).first() - - if existing: - # Update ownership if needed - if existing.is_owner != is_owner: - existing.is_owner = is_owner - session.commit() - return - - # Create new relationship - archives_agents = ArchivesAgents( - agent_id=agent_id, - archive_id=archive_id, - is_owner=is_owner, - ) - session.add(archives_agents) - session.commit() - @enforce_types @trace_method async def attach_agent_to_archive_async( @@ -345,50 +289,6 @@ class ArchiveManager: # this shouldn't happen, but if it does, re-raise raise - @enforce_types - @trace_method - def get_or_create_default_archive_for_agent( - self, - agent_id: str, - agent_name: Optional[str] = None, - actor: PydanticUser = None, - ) -> PydanticArchive: - """Get the agent's default archive, creating one if it doesn't exist.""" - with db_registry.session() as session: - # First check if agent has any archives - query = select(ArchivesAgents.archive_id).where(ArchivesAgents.agent_id == agent_id) - result = session.execute(query) - archive_ids = [row[0] for row in result.fetchall()] - - if archive_ids: - # TODO: Remove this check once we support multiple archives per agent - if len(archive_ids) > 1: - raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported") - # Get the archive - archive = ArchiveModel.read(db_session=session, identifier=archive_ids[0], actor=actor) - return archive.to_pydantic() - - # Create a default archive for this agent - archive_name = f"{agent_name or f'Agent {agent_id}'}'s Archive" - - # Create the archive - archive_model = ArchiveModel( - name=archive_name, - description="Default archive created automatically", - organization_id=actor.organization_id, - ) - archive_model.create(session, actor=actor) - - # Attach the agent to the archive as owner - self.attach_agent_to_archive( - agent_id=agent_id, - archive_id=archive_model.id, - is_owner=True, - actor=actor, - ) - - return archive_model.to_pydantic() - @enforce_types @trace_method async def get_agents_for_archive_async( diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 0e0b4447..3717d1cd 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -26,21 +26,6 @@ logger = get_logger(__name__) class BlockManager: """Manager class to handle business logic related to Blocks.""" - @enforce_types - @trace_method - def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock: - """Create a new block based on the Block schema.""" - db_block = self.get_block_by_id(block.id, actor) - if db_block: - update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True)) - self.update_block(block.id, update_data, actor) - else: - with db_registry.session() as session: - data = block.model_dump(to_orm=True, exclude_none=True) - block = BlockModel(**data, organization_id=actor.organization_id) - block.create(session, actor=actor) - return block.to_pydantic() - @enforce_types @trace_method async def create_or_update_block_async(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock: @@ -58,30 +43,6 @@ class BlockManager: await session.commit() return pydantic_block - @enforce_types - @trace_method - def batch_create_blocks(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]: - """ - Batch-create multiple Blocks in one transaction for better performance. - Args: - blocks: List of PydanticBlock schemas to create - actor: The user performing the operation - Returns: - List of created PydanticBlock instances (with IDs, timestamps, etc.) - """ - if not blocks: - return [] - - with db_registry.session() as session: - block_models = [ - BlockModel(**block.model_dump(to_orm=True, exclude_none=True), organization_id=actor.organization_id) for block in blocks - ] - - created_models = BlockModel.batch_create(items=block_models, db_session=session, actor=actor) - - # Convert back to Pydantic - return [m.to_pydantic() for m in created_models] - @enforce_types @trace_method async def batch_create_blocks_async(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]: @@ -107,22 +68,6 @@ class BlockManager: await session.commit() return result - @enforce_types - @trace_method - def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: - """Update a block by its ID with the given BlockUpdate object.""" - # Safety check for block - - with db_registry.session() as session: - block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) - update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) - - for key, value in update_data.items(): - setattr(block, key, value) - - block.update(db_session=session, actor=actor) - return block.to_pydantic() - @enforce_types @trace_method async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: @@ -141,19 +86,6 @@ class BlockManager: await session.commit() return pydantic_block - @enforce_types - @trace_method - def delete_block(self, block_id: str, actor: PydanticUser) -> None: - """Delete a block by its ID.""" - with db_registry.session() as session: - # First, delete all references in blocks_agents table - session.execute(delete(BlocksAgents).where(BlocksAgents.block_id == block_id)) - session.flush() - - # Then delete the block itself - block = BlockModel.read(db_session=session, identifier=block_id) - block.hard_delete(db_session=session, actor=actor) - @enforce_types @trace_method async def delete_block_async(self, block_id: str, actor: PydanticUser) -> None: @@ -352,17 +284,6 @@ class BlockManager: return [block.to_pydantic() for block in blocks] - @enforce_types - @trace_method - def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]: - """Retrieve a block by its name.""" - with db_registry.session() as session: - try: - block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) - return block.to_pydantic() - except NoResultFound: - return None - @enforce_types @trace_method async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]: @@ -524,72 +445,6 @@ class BlockManager: # Block History Functions - @enforce_types - @trace_method - def checkpoint_block( - self, - block_id: str, - actor: PydanticUser, - agent_id: Optional[str] = None, - use_preloaded_block: Optional[BlockModel] = None, # For concurrency tests - ) -> PydanticBlock: - """ - Create a new checkpoint for the given Block by copying its - current state into BlockHistory, using SQLAlchemy's built-in - version_id_col for concurrency checks. - - - If the block was undone to an earlier checkpoint, we remove - any "future" checkpoints beyond the current state to keep a - strictly linear history. - - A single commit at the end ensures atomicity. - """ - with db_registry.session() as session: - # 1) Load the Block - if use_preloaded_block is not None: - block = session.merge(use_preloaded_block) - else: - block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) - - # 2) Identify the block's current checkpoint (if any) - current_entry = None - if block.current_history_entry_id: - current_entry = session.get(BlockHistory, block.current_history_entry_id) - - # The current sequence, or 0 if no checkpoints exist - current_seq = current_entry.sequence_number if current_entry else 0 - - # 3) Truncate any future checkpoints - # If we are at seq=2, but there's a seq=3 or higher from a prior "redo chain", - # remove those, so we maintain a strictly linear undo/redo stack. - session.query(BlockHistory).filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number > current_seq).delete() - - # 4) Determine the next sequence number - next_seq = current_seq + 1 - - # 5) Create a new BlockHistory row reflecting the block's current state - history_entry = BlockHistory( - organization_id=actor.organization_id, - block_id=block.id, - sequence_number=next_seq, - description=block.description, - label=block.label, - value=block.value, - limit=block.limit, - metadata_=block.metadata_, - actor_type=ActorType.LETTA_AGENT if agent_id else ActorType.LETTA_USER, - actor_id=agent_id if agent_id else actor.id, - ) - history_entry.create(session, actor=actor, no_commit=True) - - # 6) Update the block’s pointer to the new checkpoint - block.current_history_entry_id = history_entry.id - - # 7) Flush changes, then commit once - block = block.update(db_session=session, actor=actor, no_commit=True) - session.commit() - - return block.to_pydantic() - @enforce_types def _move_block_to_sequence(self, session: Session, block: BlockModel, target_seq: int, actor: PydanticUser) -> BlockModel: """ @@ -628,88 +483,6 @@ class BlockManager: updated_block = block.update(db_session=session, actor=actor, no_commit=True) return updated_block - @enforce_types - @trace_method - def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock: - """ - Move the block to the immediately previous checkpoint in BlockHistory. - If older sequences have been pruned, we jump to the largest sequence - number that is still < current_seq. - """ - with db_registry.session() as session: - # 1) Load the current block - block = ( - session.merge(use_preloaded_block) - if use_preloaded_block - else BlockModel.read(db_session=session, identifier=block_id, actor=actor) - ) - - if not block.current_history_entry_id: - raise ValueError(f"Block {block_id} has no history entry - cannot undo.") - - current_entry = session.get(BlockHistory, block.current_history_entry_id) - if not current_entry: - raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}") - - current_seq = current_entry.sequence_number - - # 2) Find the largest sequence < current_seq - previous_entry = ( - session.query(BlockHistory) - .filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number < current_seq) - .order_by(BlockHistory.sequence_number.desc()) - .first() - ) - if not previous_entry: - # No earlier checkpoint available - raise ValueError(f"Block {block_id} is already at the earliest checkpoint (seq={current_seq}). Cannot undo further.") - - # 3) Move to that sequence - block = self._move_block_to_sequence(session, block, previous_entry.sequence_number, actor) - - # 4) Commit - session.commit() - return block.to_pydantic() - - @enforce_types - @trace_method - def redo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock: - """ - Move the block to the next checkpoint if it exists. - If some middle checkpoints have been pruned, we jump to the smallest - sequence > current_seq that remains. - """ - with db_registry.session() as session: - block = ( - session.merge(use_preloaded_block) - if use_preloaded_block - else BlockModel.read(db_session=session, identifier=block_id, actor=actor) - ) - - if not block.current_history_entry_id: - raise ValueError(f"Block {block_id} has no history entry - cannot redo.") - - current_entry = session.get(BlockHistory, block.current_history_entry_id) - if not current_entry: - raise NoResultFound(f"BlockHistory row not found for id={block.current_history_entry_id}") - - current_seq = current_entry.sequence_number - - # Find the smallest sequence that is > current_seq - next_entry = ( - session.query(BlockHistory) - .filter(BlockHistory.block_id == block.id, BlockHistory.sequence_number > current_seq) - .order_by(BlockHistory.sequence_number.asc()) - .first() - ) - if not next_entry: - raise ValueError(f"Block {block_id} is at the highest checkpoint (seq={current_seq}). Cannot redo further.") - - block = self._move_block_to_sequence(session, block, next_entry.sequence_number, actor) - - session.commit() - return block.to_pydantic() - @enforce_types @trace_method async def bulk_update_block_values_async( diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 609a2a89..3f67e003 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -60,13 +60,6 @@ class GroupManager: groups = result.scalars().all() return [group.to_pydantic() for group in groups] - @enforce_types - @trace_method - def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup: - with db_registry.session() as session: - group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) - return group.to_pydantic() - @enforce_types @trace_method async def retrieve_group_async(self, group_id: str, actor: PydanticUser) -> PydanticGroup: @@ -74,57 +67,6 @@ class GroupManager: group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor) return group.to_pydantic() - @enforce_types - @trace_method - def create_group(self, group: Union[GroupCreate, InternalTemplateGroupCreate], actor: PydanticUser) -> PydanticGroup: - with db_registry.session() as session: - new_group = GroupModel() - new_group.organization_id = actor.organization_id - new_group.description = group.description - - match group.manager_config.manager_type: - case ManagerType.round_robin: - new_group.manager_type = ManagerType.round_robin - new_group.max_turns = group.manager_config.max_turns - case ManagerType.dynamic: - new_group.manager_type = ManagerType.dynamic - new_group.manager_agent_id = group.manager_config.manager_agent_id - new_group.max_turns = group.manager_config.max_turns - new_group.termination_token = group.manager_config.termination_token - case ManagerType.supervisor: - new_group.manager_type = ManagerType.supervisor - new_group.manager_agent_id = group.manager_config.manager_agent_id - case ManagerType.sleeptime: - new_group.manager_type = ManagerType.sleeptime - new_group.manager_agent_id = group.manager_config.manager_agent_id - new_group.sleeptime_agent_frequency = group.manager_config.sleeptime_agent_frequency - if new_group.sleeptime_agent_frequency: - new_group.turns_counter = -1 - case ManagerType.voice_sleeptime: - new_group.manager_type = ManagerType.voice_sleeptime - new_group.manager_agent_id = group.manager_config.manager_agent_id - max_message_buffer_length = group.manager_config.max_message_buffer_length - min_message_buffer_length = group.manager_config.min_message_buffer_length - # Safety check for buffer length range - self.ensure_buffer_length_range_valid(max_value=max_message_buffer_length, min_value=min_message_buffer_length) - new_group.max_message_buffer_length = max_message_buffer_length - new_group.min_message_buffer_length = min_message_buffer_length - case _: - raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}") - - if isinstance(group, InternalTemplateGroupCreate): - new_group.base_template_id = group.base_template_id - new_group.template_id = group.template_id - new_group.deployment_id = group.deployment_id - - self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False) - - if group.shared_block_ids: - self._process_shared_block_relationship(session=session, group=new_group, block_ids=group.shared_block_ids) - - new_group.create(session, actor=actor) - return new_group.to_pydantic() - @enforce_types async def create_group_async(self, group: Union[GroupCreate, InternalTemplateGroupCreate], actor: PydanticUser) -> PydanticGroup: async with db_registry.async_session() as session: @@ -238,14 +180,6 @@ class GroupManager: await group.update_async(session, actor=actor) return group.to_pydantic() - @enforce_types - @trace_method - def delete_group(self, group_id: str, actor: PydanticUser) -> None: - with db_registry.session() as session: - # Retrieve the agent - group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) - group.hard_delete(session) - @enforce_types @trace_method async def delete_group_async(self, group_id: str, actor: PydanticUser) -> None: @@ -253,43 +187,6 @@ class GroupManager: group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor) await group.hard_delete_async(session) - @enforce_types - @trace_method - def list_group_messages( - self, - actor: PydanticUser, - group_id: Optional[str] = None, - before: Optional[str] = None, - after: Optional[str] = None, - limit: Optional[int] = 50, - use_assistant_message: bool = True, - assistant_message_tool_name: str = "send_message", - assistant_message_tool_kwarg: str = "message", - ) -> list[LettaMessage]: - with db_registry.session() as session: - filters = { - "organization_id": actor.organization_id, - "group_id": group_id, - } - messages = MessageModel.list( - db_session=session, - before=before, - after=after, - limit=limit, - **filters, - ) - - messages = PydanticMessage.to_letta_messages_from_list( - messages=[msg.to_pydantic() for msg in messages], - use_assistant_message=use_assistant_message, - assistant_message_tool_name=assistant_message_tool_name, - assistant_message_tool_kwarg=assistant_message_tool_kwarg, - ) - - # TODO: filter messages to return a clean conversation history - - return messages - @enforce_types @trace_method async def list_group_messages_async( @@ -327,20 +224,6 @@ class GroupManager: return messages - @enforce_types - @trace_method - def reset_messages(self, group_id: str, actor: PydanticUser) -> None: - with db_registry.session() as session: - # Ensure group is loadable by user - group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) - - # Delete all messages in the group - session.query(MessageModel).filter( - MessageModel.organization_id == actor.organization_id, MessageModel.group_id == group_id - ).delete(synchronize_session=False) - - session.commit() - @enforce_types @trace_method async def reset_messages_async(self, group_id: str, actor: PydanticUser) -> None: @@ -356,18 +239,6 @@ class GroupManager: await session.commit() - @enforce_types - @trace_method - def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int: - with db_registry.session() as session: - # Ensure group is loadable by user - group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) - - # Update turns counter - group.turns_counter = (group.turns_counter + 1) % group.sleeptime_agent_frequency - group.update(session, actor=actor) - return group.turns_counter - @enforce_types @trace_method async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int: @@ -380,19 +251,6 @@ class GroupManager: await group.update_async(session, actor=actor) return group.turns_counter - @enforce_types - def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str: - with db_registry.session() as session: - # Ensure group is loadable by user - group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) - - # Update last processed message id - prev_last_processed_message_id = group.last_processed_message_id - group.last_processed_message_id = last_processed_message_id - group.update(session, actor=actor) - - return prev_last_processed_message_id - @enforce_types @trace_method async def get_last_processed_message_id_and_update_async( diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 9b4e81a8..1d35de9f 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -34,40 +34,6 @@ logger = get_logger(__name__) class JobManager: """Manager class to handle business logic related to Jobs.""" - @enforce_types - @trace_method - def create_job( - self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser - ) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]: - """Create a new job based on the JobCreate schema.""" - from letta.orm.agents_runs import AgentsRuns - - with db_registry.session() as session: - # Associate the job with the user - pydantic_job.user_id = actor.id - - # Get agent_id if present - agent_id = getattr(pydantic_job, "agent_id", None) - - # Verify agent exists before creating the job - # await validate_agent_exists_async(session, agent_id, actor) - - job_data = pydantic_job.model_dump(to_orm=True) - # Remove agent_id from job_data as it's not a field in the Job ORM model - # The relationship is handled through the AgentsRuns association table - job_data.pop("agent_id", None) - job = JobModel(**job_data) - job.organization_id = actor.organization_id - job.create(session, actor=actor) # Save job in the database - - # If this is a Run with an agent_id, create the agents_runs association - if agent_id and isinstance(pydantic_job, PydanticRun): - agents_run = AgentsRuns(agent_id=agent_id, run_id=job.id) - session.add(agents_run) - session.commit() - - return job.to_pydantic() - @enforce_types @trace_method async def create_job_async( @@ -111,69 +77,6 @@ class JobManager: return result - @enforce_types - @trace_method - def update_job_by_id(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob: - """Update a job by its ID with the given JobUpdate object.""" - # First check if we need to dispatch a callback - needs_callback = False - callback_url = None - with db_registry.session() as session: - job = self._verify_job_access(session=session, job_id=job_id, actor=actor, access=["write"]) - not_completed_before = not bool(job.completed_at) - - # Check if we'll need to dispatch callback - if job_update.status in {JobStatus.completed, JobStatus.failed} and not_completed_before and job.callback_url: - needs_callback = True - callback_url = job.callback_url - - # Update the job first to get the final metadata - with db_registry.session() as session: - job = self._verify_job_access(session=session, job_id=job_id, actor=actor, access=["write"]) - not_completed_before = not bool(job.completed_at) - - # Update job attributes with only the fields that were explicitly set - update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) - - # Automatically update the completion timestamp if status is set to 'completed' - for key, value in update_data.items(): - # Ensure completed_at is timezone-naive for database compatibility - if key == "completed_at" and value is not None and hasattr(value, "replace"): - value = value.replace(tzinfo=None) - setattr(job, key, value) - - if job_update.status in {JobStatus.completed, JobStatus.failed} and not_completed_before: - job.completed_at = get_utc_time().replace(tzinfo=None) - - # Save the updated job to the database first - job = job.update(db_session=session, actor=actor) - - # Get the updated metadata for callback - final_metadata = job.metadata_ - result = job.to_pydantic() - - # Dispatch callback outside of database session if needed - if needs_callback: - callback_info = { - "job_id": job_id, - "callback_url": callback_url, - "status": job_update.status, - "completed_at": get_utc_time().replace(tzinfo=None), - "metadata": final_metadata, - } - callback_result = self._dispatch_callback_sync(callback_info) - - # Update callback status in a separate transaction - with db_registry.session() as session: - job = self._verify_job_access(session=session, job_id=job_id, actor=actor, access=["write"]) - job.callback_sent_at = callback_result["callback_sent_at"] - job.callback_status_code = callback_result.get("callback_status_code") - job.callback_error = callback_result.get("callback_error") - job.update(db_session=session, actor=actor) - result = job.to_pydantic() - - return result - @enforce_types @trace_method async def update_job_by_id_async( @@ -291,15 +194,6 @@ class JobManager: logger.error(f"Failed to safely update job status for job {job_id}: {e}") return False - @enforce_types - @trace_method - def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob: - """Fetch a job by its ID.""" - with db_registry.session() as session: - # Retrieve job by ID using the Job model's read method - job = JobModel.read(db_session=session, identifier=job_id, actor=actor, access_type=AccessType.USER) - return job.to_pydantic() - @enforce_types @trace_method async def get_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob: @@ -309,100 +203,6 @@ class JobManager: job = await JobModel.read_async(db_session=session, identifier=job_id, actor=actor, access_type=AccessType.USER) return job.to_pydantic() - @enforce_types - @trace_method - def list_jobs( - self, - actor: PydanticUser, - before: Optional[str] = None, - after: Optional[str] = None, - limit: Optional[int] = 50, - statuses: Optional[List[JobStatus]] = None, - job_type: JobType = JobType.JOB, - ascending: bool = True, - stop_reason: Optional[StopReasonType] = None, - # agent_id: Optional[str] = None, - agent_ids: Optional[List[str]] = None, - background: Optional[bool] = None, - ) -> List[PydanticJob]: - """List all jobs with optional pagination and status filter.""" - from sqlalchemy import and_, select - - from letta.orm.agents_runs import AgentsRuns - - with db_registry.session() as session: - filter_kwargs = {"user_id": actor.id, "job_type": job_type} - - # Add status filter if provided - if statuses: - filter_kwargs["status"] = statuses - - # Add stop_reason filter if provided - if stop_reason is not None: - filter_kwargs["stop_reason"] = stop_reason - - # Add background filter if provided - if background is not None: - filter_kwargs["background"] = background - - # Build query - query = select(JobModel) - - # Apply basic filters - for key, value in filter_kwargs.items(): - if isinstance(value, list): - query = query.where(getattr(JobModel, key).in_(value)) - else: - query = query.where(getattr(JobModel, key) == value) - - # If agent_id filter is provided, join with agents_runs table - if agent_ids: - query = query.join(AgentsRuns, JobModel.id == AgentsRuns.run_id) - query = query.where(AgentsRuns.agent_id.in_(agent_ids)) - - # Apply pagination and ordering - if ascending: - query = query.order_by(JobModel.created_at.asc(), JobModel.id.asc()) - else: - query = query.order_by(JobModel.created_at.desc(), JobModel.id.desc()) - - # Apply cursor-based pagination - if before: - before_job = session.get(JobModel, before) - if before_job: - if ascending: - query = query.where( - (JobModel.created_at < before_job.created_at) - | ((JobModel.created_at == before_job.created_at) & (JobModel.id < before_job.id)) - ) - else: - query = query.where( - (JobModel.created_at > before_job.created_at) - | ((JobModel.created_at == before_job.created_at) & (JobModel.id > before_job.id)) - ) - - if after: - after_job = session.get(JobModel, after) - if after_job: - if ascending: - query = query.where( - (JobModel.created_at > after_job.created_at) - | ((JobModel.created_at == after_job.created_at) & (JobModel.id > after_job.id)) - ) - else: - query = query.where( - (JobModel.created_at < after_job.created_at) - | ((JobModel.created_at == after_job.created_at) & (JobModel.id < after_job.id)) - ) - - # Apply limit - if limit: - query = query.limit(limit) - - # Execute query - jobs = session.execute(query).scalars().all() - return [job.to_pydantic() for job in jobs] - @enforce_types @trace_method async def list_jobs_async( @@ -515,15 +315,6 @@ class JobManager: return [job.to_pydantic() for job in jobs] - @enforce_types - @trace_method - def delete_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob: - """Delete a job by its ID.""" - with db_registry.session() as session: - job = self._verify_job_access(session=session, job_id=job_id, actor=actor) - job.hard_delete(db_session=session, actor=actor) - return job.to_pydantic() - @enforce_types @trace_method async def delete_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob: @@ -533,127 +324,6 @@ class JobManager: await job.hard_delete_async(db_session=session, actor=actor) return job.to_pydantic() - @enforce_types - @trace_method - def get_job_messages( - self, - job_id: str, - actor: PydanticUser, - before: Optional[str] = None, - after: Optional[str] = None, - limit: Optional[int] = 100, - role: Optional[MessageRole] = None, - ascending: bool = True, - ) -> List[PydanticMessage]: - """ - Get all messages associated with a job. - - Args: - job_id: The ID of the job to get messages for - actor: The user making the request - before: Cursor for pagination - after: Cursor for pagination - limit: Maximum number of messages to return - role: Optional filter for message role - ascending: Optional flag to sort in ascending order - - Returns: - List of messages associated with the job - - Raises: - NoResultFound: If the job does not exist or user does not have access - """ - with db_registry.session() as session: - # Build filters - filters = {} - if role is not None: - filters["role"] = role - - # Get messages - messages = MessageModel.list( - db_session=session, - before=before, - after=after, - ascending=ascending, - limit=limit, - actor=actor, - join_model=JobMessage, - join_conditions=[MessageModel.id == JobMessage.message_id, JobMessage.job_id == job_id], - **filters, - ) - - return [message.to_pydantic() for message in messages] - - @enforce_types - @trace_method - def get_job_steps( - self, - job_id: str, - actor: PydanticUser, - before: Optional[str] = None, - after: Optional[str] = None, - limit: Optional[int] = 100, - ascending: bool = True, - ) -> List[PydanticStep]: - """ - Get all steps associated with a job. - - Args: - job_id: The ID of the job to get steps for - actor: The user making the request - before: Cursor for pagination - after: Cursor for pagination - limit: Maximum number of steps to return - ascending: Optional flag to sort in ascending order - - Returns: - List of steps associated with the job - - Raises: - NoResultFound: If the job does not exist or user does not have access - """ - with db_registry.session() as session: - # Build filters - filters = {} - filters["job_id"] = job_id - - # Get steps - steps = StepModel.list( - db_session=session, - before=before, - after=after, - ascending=ascending, - limit=limit, - actor=actor, - **filters, - ) - - return [step.to_pydantic() for step in steps] - - @enforce_types - @trace_method - def add_message_to_job(self, job_id: str, message_id: str, actor: PydanticUser) -> None: - """ - Associate a message with a job by creating a JobMessage record. - Each message can only be associated with one job. - - Args: - job_id: The ID of the job - message_id: The ID of the message to associate - actor: The user making the request - - Raises: - NoResultFound: If the job does not exist or user does not have access - """ - with db_registry.session() as session: - # First verify job exists and user has access - self._verify_job_access(session, job_id, actor, access=["write"]) - - # Create new JobMessage association - job_message = JobMessage(job_id=job_id, message_id=message_id) - session.add(job_message) - session.commit() - @enforce_types @trace_method async def add_messages_to_job_async(self, job_id: str, message_ids: List[str], actor: PydanticUser) -> None: @@ -683,86 +353,7 @@ class JobManager: @enforce_types @trace_method - def get_job_usage(self, job_id: str, actor: PydanticUser) -> LettaUsageStatistics: - """ - Get usage statistics for a job. - - Args: - job_id: The ID of the job - actor: The user making the request - - Returns: - Usage statistics for the job - - Raises: - NoResultFound: If the job does not exist or user does not have access - """ - with db_registry.session() as session: - # First verify job exists and user has access - self._verify_job_access(session, job_id, actor) - - # Get the latest usage statistics for the job - latest_stats = session.query(Step).filter(Step.job_id == job_id).order_by(Step.created_at.desc()).all() - - if not latest_stats: - return LettaUsageStatistics( - completion_tokens=0, - prompt_tokens=0, - total_tokens=0, - step_count=0, - ) - - return LettaUsageStatistics( - completion_tokens=reduce(add, (step.completion_tokens or 0 for step in latest_stats), 0), - prompt_tokens=reduce(add, (step.prompt_tokens or 0 for step in latest_stats), 0), - total_tokens=reduce(add, (step.total_tokens or 0 for step in latest_stats), 0), - step_count=len(latest_stats), - ) - - @enforce_types - @trace_method - def add_job_usage( - self, - job_id: str, - usage: LettaUsageStatistics, - step_id: Optional[str] = None, - actor: PydanticUser = None, - ) -> None: - """ - Add usage statistics for a job. - - Args: - job_id: The ID of the job - usage: Usage statistics for the job - step_id: Optional ID of the specific step within the job - actor: The user making the request - - Raises: - NoResultFound: If the job does not exist or user does not have access - """ - with db_registry.session() as session: - # First verify job exists and user has access - self._verify_job_access(session, job_id, actor, access=["write"]) - - # Manually log step with usage data - # TODO(@caren): log step under the hood and remove this - usage_stats = Step( - job_id=job_id, - completion_tokens=usage.completion_tokens, - prompt_tokens=usage.prompt_tokens, - total_tokens=usage.total_tokens, - step_count=usage.step_count, - step_id=step_id, - ) - if actor: - usage_stats._set_created_and_updated_by_fields(actor.id) - - session.add(usage_stats) - session.commit() - - @enforce_types - @trace_method - def get_run_messages( + async def get_run_messages( self, run_id: str, actor: PydanticUser, @@ -791,7 +382,7 @@ class JobManager: Raises: NoResultFound: If the job does not exist or user does not have access """ - messages = self.get_job_messages( + messages = await self.get_job_messages( job_id=run_id, actor=actor, before=before, @@ -801,7 +392,7 @@ class JobManager: ascending=ascending, ) - request_config = self._get_run_request_config(run_id) + request_config = await self._get_run_request_config(run_id) print("request_config", request_config) messages = PydanticMessage.to_letta_messages_from_list( @@ -819,7 +410,7 @@ class JobManager: @enforce_types @trace_method - def get_step_messages( + async def get_step_messages( self, run_id: str, actor: PydanticUser, @@ -848,7 +439,7 @@ class JobManager: Raises: NoResultFound: If the job does not exist or user does not have access """ - messages = self.get_job_messages( + messages = await self.get_job_messages( job_id=run_id, actor=actor, before=before, @@ -858,7 +449,7 @@ class JobManager: ascending=ascending, ) - request_config = self._get_run_request_config(run_id) + request_config = await self._get_run_request_config(run_id) messages = PydanticMessage.to_letta_messages_from_list( messages=messages, @@ -869,34 +460,6 @@ class JobManager: return messages - def _verify_job_access( - self, - session: Session, - job_id: str, - actor: PydanticUser, - access: List[Literal["read", "write", "admin"]] = ["read"], - ) -> JobModel: - """ - Verify that a job exists and the user has the required access. - - Args: - session: The database session - job_id: The ID of the job to verify - actor: The user making the request - - Returns: - The job if it exists and the user has access - - Raises: - NoResultFound: If the job does not exist or user does not have access - """ - job_query = select(JobModel).where(JobModel.id == job_id) - job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER) - job = session.execute(job_query).scalar_one_or_none() - if not job: - raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access") - return job - async def _verify_job_access_async( self, session: Session, @@ -926,21 +489,6 @@ class JobManager: raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access") return job - def _get_run_request_config(self, run_id: str) -> LettaRequestConfig: - """ - Get the request config for a job. - - Args: - job_id: The ID of the job to get messages for - - Returns: - The request config for the job - """ - with db_registry.session() as session: - job = session.query(JobModel).filter(JobModel.id == run_id).first() - request_config = job.request_config or LettaRequestConfig() - return request_config - @enforce_types async def record_ttft(self, job_id: str, ttft_ns: int, actor: PydanticUser) -> None: """Record time to first token for a run""" @@ -1021,3 +569,115 @@ class JobManager: # Continue silently - callback failures should not affect job completion finally: return result + + @enforce_types + @trace_method + async def get_job_messages( + self, + job_id: str, + actor: PydanticUser, + before: Optional[str] = None, + after: Optional[str] = None, + limit: Optional[int] = 100, + role: Optional[MessageRole] = None, + ascending: bool = True, + ) -> List[PydanticMessage]: + """ + Get all messages associated with a job. + + Args: + job_id: The ID of the job to get messages for + actor: The user making the request + before: Cursor for pagination + after: Cursor for pagination + limit: Maximum number of messages to return + role: Optional filter for message role + ascending: Optional flag to sort in ascending order + + Returns: + List of messages associated with the job + + Raises: + NoResultFound: If the job does not exist or user does not have access + """ + async with db_registry.async_session() as session: + # Build filters + filters = {} + if role is not None: + filters["role"] = role + + # Get messages + messages = await MessageModel.list_async( + db_session=session, + before=before, + after=after, + ascending=ascending, + limit=limit, + actor=actor, + join_model=JobMessage, + join_conditions=[MessageModel.id == JobMessage.message_id, JobMessage.job_id == job_id], + **filters, + ) + + return [message.to_pydantic() for message in messages] + + @enforce_types + @trace_method + async def get_job_steps( + self, + job_id: str, + actor: PydanticUser, + before: Optional[str] = None, + after: Optional[str] = None, + limit: Optional[int] = 100, + ascending: bool = True, + ) -> List[PydanticStep]: + """ + Get all steps associated with a job. + + Args: + job_id: The ID of the job to get steps for + actor: The user making the request + before: Cursor for pagination + after: Cursor for pagination + limit: Maximum number of steps to return + ascending: Optional flag to sort in ascending order + + Returns: + List of steps associated with the job + + Raises: + NoResultFound: If the job does not exist or user does not have access + """ + async with db_registry.async_session() as session: + # Build filters + filters = {} + filters["job_id"] = job_id + + # Get steps + steps = StepModel.list_async( + db_session=session, + before=before, + after=after, + ascending=ascending, + limit=limit, + actor=actor, + **filters, + ) + + return [step.to_pydantic() for step in steps] + + async def _get_run_request_config(self, run_id: str) -> LettaRequestConfig: + """ + Get the request config for a job. + + Args: + job_id: The ID of the job to get messages for + + Returns: + The request config for the job + """ + async with db_registry.async_session() as session: + job = await JobModel.read_async(db_session=session, identifier=run_id) + request_config = job.request_config or LettaRequestConfig() + return request_config diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 9c6bfa1d..6afd8e68 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -214,17 +214,6 @@ class MessageManager: return combined_messages - @enforce_types - @trace_method - def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]: - """Fetch a message by ID.""" - with db_registry.session() as session: - try: - message = MessageModel.read(db_session=session, identifier=message_id, actor=actor) - return message.to_pydantic() - except NoResultFound: - return None - @enforce_types @trace_method async def get_message_by_id_async(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]: @@ -236,14 +225,6 @@ class MessageManager: except NoResultFound: return None - @enforce_types - @trace_method - def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]: - """Fetch messages by ID and return them in the requested order.""" - with db_registry.session() as session: - results = MessageModel.read_multiple(db_session=session, identifiers=message_ids, actor=actor) - return self._get_messages_by_id_postprocess(results, message_ids) - @enforce_types @trace_method async def get_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]: @@ -265,18 +246,6 @@ class MessageManager: result_dict = {msg.id: msg.to_pydantic() for msg in results} return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids])) - @enforce_types - @trace_method - def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage: - """Create a new message.""" - with db_registry.session() as session: - # Set the organization id of the Pydantic message - msg_data = pydantic_msg.model_dump(to_orm=True) - msg_data["organization_id"] = actor.organization_id - msg = MessageModel(**msg_data) - msg.create(session, actor=actor) # Persist to database - return msg.to_pydantic() - def _create_many_preprocess(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[MessageModel]: # Create ORM model instances for all messages orm_messages = [] @@ -287,26 +256,6 @@ class MessageManager: orm_messages.append(MessageModel(**msg_data)) return orm_messages - @enforce_types - @trace_method - def create_many_messages(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]: - """ - Create multiple messages in a single database transaction. - Args: - pydantic_msgs: List of Pydantic message models to create - actor: User performing the action - - Returns: - List of created Pydantic message models - """ - if not pydantic_msgs: - return [] - - orm_messages = self._create_many_preprocess(pydantic_msgs, actor) - with db_registry.session() as session: - created_messages = MessageModel.batch_create(orm_messages, session, actor=actor) - return [msg.to_pydantic() for msg in created_messages] - @enforce_types @trace_method async def check_existing_message_ids(self, message_ids: List[str], actor: PydanticUser) -> Set[str]: @@ -518,13 +467,13 @@ class MessageManager: @enforce_types @trace_method - def update_message_by_letta_message( + async def update_message_by_letta_message_async( self, message_id: str, letta_message_update: LettaMessageUpdateUnion, actor: PydanticUser ) -> PydanticMessage: """ Updated the underlying messages table giving an update specified to the user-facing LettaMessage """ - message = self.get_message_by_id(message_id=message_id, actor=actor) + message = await self.get_message_by_id_async(message_id=message_id, actor=actor) if letta_message_update.message_type == "assistant_message": # modify the tool call for send_message # TODO: fix this if we add parallel tool calls @@ -545,7 +494,7 @@ class MessageManager: else: raise ValueError(f"Unsupported message type for modification: {letta_message_update.message_type}") - message = self.update_message_by_id(message_id=message_id, message_update=update_message, actor=actor) + message = await self.update_message_by_id_async(message_id=message_id, message_update=update_message, actor=actor) # convert back to LettaMessage for letta_msg in message.to_letta_messages(use_assistant_message=True): @@ -594,24 +543,6 @@ class MessageManager: # raise error if message type got modified raise ValueError(f"Message type got modified: {letta_message_update.message_type}") - @enforce_types - @trace_method - def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage: - """ - Updates an existing record in the database with values from the provided record object. - """ - with db_registry.session() as session: - # Fetch existing message from database - message = MessageModel.read( - db_session=session, - identifier=message_id, - actor=actor, - ) - - message = self._update_message_by_id_impl(message_id, message_update, actor, message) - message.update(db_session=session, actor=actor) - return message.to_pydantic() - @enforce_types @trace_method async def update_message_by_id_async( @@ -731,22 +662,6 @@ class MessageManager: setattr(message, key, value) return message - @enforce_types - @trace_method - def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool: - """Delete a message.""" - with db_registry.session() as session: - try: - msg = MessageModel.read( - db_session=session, - identifier=message_id, - actor=actor, - ) - msg.hard_delete(session, actor=actor) - # Note: Turbopuffer deletion requires async, use delete_message_by_id_async for full deletion - except NoResultFound: - raise ValueError(f"Message with id {message_id} not found.") - @enforce_types @trace_method async def delete_message_by_id_async(self, message_id: str, actor: PydanticUser, strict_mode: bool = False) -> bool: @@ -781,23 +696,6 @@ class MessageManager: except NoResultFound: raise ValueError(f"Message with id {message_id} not found.") - @enforce_types - @trace_method - def size( - self, - actor: PydanticUser, - role: Optional[MessageRole] = None, - agent_id: Optional[str] = None, - ) -> int: - """Get the total count of messages with optional filters. - - Args: - actor: The user requesting the count - role: The role of the message - """ - with db_registry.session() as session: - return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id) - @enforce_types @trace_method async def size_async( @@ -814,29 +712,6 @@ class MessageManager: async with db_registry.async_session() as session: return await MessageModel.size_async(db_session=session, actor=actor, role=role, agent_id=agent_id) - @enforce_types - @trace_method - def list_user_messages_for_agent( - self, - agent_id: str, - actor: PydanticUser, - after: Optional[str] = None, - before: Optional[str] = None, - query_text: Optional[str] = None, - limit: Optional[int] = 50, - ascending: bool = True, - ) -> List[PydanticMessage]: - return self.list_messages_for_agent( - agent_id=agent_id, - actor=actor, - after=after, - before=before, - query_text=query_text, - roles=[MessageRole.user], - limit=limit, - ascending=ascending, - ) - @enforce_types @trace_method async def list_user_messages_for_agent_async( @@ -860,109 +735,6 @@ class MessageManager: ascending=ascending, ) - @enforce_types - @trace_method - def list_messages_for_agent( - self, - agent_id: str, - actor: PydanticUser, - after: Optional[str] = None, - before: Optional[str] = None, - query_text: Optional[str] = None, - roles: Optional[Sequence[MessageRole]] = None, - limit: Optional[int] = 50, - ascending: bool = True, - group_id: Optional[str] = None, - ) -> List[PydanticMessage]: - """ - Most performant query to list messages for an agent by directly querying the Message table. - - This function filters by the agent_id (leveraging the index on messages.agent_id) - and applies pagination using sequence_id as the cursor. - If query_text is provided, it will filter messages whose text content partially matches the query. - If role is provided, it will filter messages by the specified role. - - Args: - agent_id: The ID of the agent whose messages are queried. - actor: The user performing the action (used for permission checks). - after: A message ID; if provided, only messages *after* this message (by sequence_id) are returned. - before: A message ID; if provided, only messages *before* this message (by sequence_id) are returned. - query_text: Optional string to partially match the message text content. - roles: Optional MessageRole to filter messages by role. - limit: Maximum number of messages to return. - ascending: If True, sort by sequence_id ascending; if False, sort descending. - group_id: Optional group ID to filter messages by group_id. - - Returns: - List[PydanticMessage]: A list of messages (converted via .to_pydantic()). - - Raises: - NoResultFound: If the provided after/before message IDs do not exist. - """ - - with db_registry.session() as session: - # Permission check: raise if the agent doesn't exist or actor is not allowed. - AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - - # Build a query that directly filters the Message table by agent_id. - query = session.query(MessageModel).filter(MessageModel.agent_id == agent_id) - - # If group_id is provided, filter messages by group_id. - if group_id: - query = query.filter(MessageModel.group_id == group_id) - - # If query_text is provided, filter messages using database-specific JSON search. - if query_text: - if settings.database_engine is DatabaseChoice.POSTGRES: - # PostgreSQL: Use json_array_elements and ILIKE - content_element = func.json_array_elements(MessageModel.content).alias("content_element") - query = query.filter( - exists( - select(1) - .select_from(content_element) - .where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text")) - .params(query_text=f"%{query_text}%") - ) - ) - else: - # SQLite: Use JSON_EXTRACT with individual array indices for case-insensitive search - # Since SQLite doesn't support $[*] syntax, we'll use a different approach - query = query.filter(text("JSON_EXTRACT(content, '$') LIKE :query_text")).params(query_text=f"%{query_text}%") - - # If role(s) are provided, filter messages by those roles. - if roles: - role_values = [r.value for r in roles] - query = query.filter(MessageModel.role.in_(role_values)) - - # Apply 'after' pagination if specified. - if after: - after_ref = session.query(MessageModel.sequence_id).filter(MessageModel.id == after).one_or_none() - if not after_ref: - raise NoResultFound(f"No message found with id '{after}' for agent '{agent_id}'.") - # Filter out any messages with a sequence_id <= after_ref.sequence_id - query = query.filter(MessageModel.sequence_id > after_ref.sequence_id) - - # Apply 'before' pagination if specified. - if before: - before_ref = session.query(MessageModel.sequence_id).filter(MessageModel.id == before).one_or_none() - if not before_ref: - raise NoResultFound(f"No message found with id '{before}' for agent '{agent_id}'.") - # Filter out any messages with a sequence_id >= before_ref.sequence_id - query = query.filter(MessageModel.sequence_id < before_ref.sequence_id) - - # Apply ordering based on the ascending flag. - if ascending: - query = query.order_by(MessageModel.sequence_id.asc()) - else: - query = query.order_by(MessageModel.sequence_id.desc()) - - # Limit the number of results. - query = query.limit(limit) - - # Execute and convert each Message to its Pydantic representation. - results = query.all() - return [msg.to_pydantic() for msg in results] - @enforce_types @trace_method async def list_messages_for_agent_async( diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index e72defd3..993a3eb1 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -18,14 +18,6 @@ class OrganizationManager: """Fetch the default organization.""" return await self.get_organization_by_id_async(DEFAULT_ORG_ID) - @enforce_types - @trace_method - def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]: - """Fetch an organization by ID.""" - with db_registry.session() as session: - organization = OrganizationModel.read(db_session=session, identifier=org_id) - return organization.to_pydantic() - @enforce_types @trace_method async def get_organization_by_id_async(self, org_id: str) -> Optional[PydanticOrganization]: @@ -34,19 +26,6 @@ class OrganizationManager: organization = await OrganizationModel.read_async(db_session=session, identifier=org_id) return organization.to_pydantic() - @enforce_types - @trace_method - def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: - """Create the default organization.""" - with db_registry.session() as session: - try: - organization = OrganizationModel.read(db_session=session, identifier=pydantic_org.id) - return organization.to_pydantic() - except: - organization = OrganizationModel(**pydantic_org.model_dump(to_orm=True)) - organization = organization.create(session) - return organization.to_pydantic() - @enforce_types @trace_method async def create_organization_async(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: @@ -102,14 +81,6 @@ class OrganizationManager: await org.update_async(session) return org.to_pydantic() - @enforce_types - @trace_method - def delete_organization_by_id(self, org_id: str): - """Delete an organization by marking it as deleted.""" - with db_registry.session() as session: - organization = OrganizationModel.read(db_session=session, identifier=org_id) - organization.hard_delete(session) - @enforce_types @trace_method async def delete_organization_by_id_async(self, org_id: str): diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index a5201554..5ddcff8a 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -93,17 +93,6 @@ class PassageManager: return created_tags # AGENT PASSAGE METHODS - @enforce_types - @trace_method - def get_agent_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: - """Fetch an agent passage by ID.""" - with db_registry.session() as session: - try: - passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor) - return passage.to_pydantic() - except NoResultFound: - raise NoResultFound(f"Agent passage with id {passage_id} not found in database.") - @enforce_types @trace_method async def get_agent_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: @@ -116,17 +105,6 @@ class PassageManager: raise NoResultFound(f"Agent passage with id {passage_id} not found in database.") # SOURCE PASSAGE METHODS - @enforce_types - @trace_method - def get_source_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: - """Fetch a source passage by ID.""" - with db_registry.session() as session: - try: - passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor) - return passage.to_pydantic() - except NoResultFound: - raise NoResultFound(f"Source passage with id {passage_id} not found in database.") - @enforce_types @trace_method async def get_source_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: @@ -138,32 +116,6 @@ class PassageManager: except NoResultFound: raise NoResultFound(f"Source passage with id {passage_id} not found in database.") - # DEPRECATED - Use specific methods above - @enforce_types - @trace_method - def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: - """DEPRECATED: Use get_agent_passage_by_id() or get_source_passage_by_id() instead.""" - import warnings - - warnings.warn( - "get_passage_by_id is deprecated. Use get_agent_passage_by_id() or get_source_passage_by_id() instead.", - DeprecationWarning, - stacklevel=2, - ) - - with db_registry.session() as session: - # Try source passages first - try: - passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor) - return passage.to_pydantic() - except NoResultFound: - # Try archival passages - try: - passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor) - return passage.to_pydantic() - except NoResultFound: - raise NoResultFound(f"Passage with id {passage_id} not found in database.") - @enforce_types @trace_method async def get_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: @@ -189,40 +141,6 @@ class PassageManager: except NoResultFound: raise NoResultFound(f"Passage with id {passage_id} not found in database.") - @enforce_types - @trace_method - def create_agent_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: - """Create a new agent passage.""" - if not pydantic_passage.archive_id: - raise ValueError("Agent passage must have archive_id") - if pydantic_passage.source_id: - raise ValueError("Agent passage cannot have source_id") - - data = pydantic_passage.model_dump(to_orm=True) - - # Deduplicate tags if provided (for dual storage consistency) - tags = data.get("tags") - if tags: - tags = list(set(tags)) - - common_fields = { - "id": data.get("id"), - "text": data["text"], - "embedding": data["embedding"], - "embedding_config": data["embedding_config"], - "organization_id": data["organization_id"], - "metadata_": data.get("metadata", {}), - "tags": tags, - "is_deleted": data.get("is_deleted", False), - "created_at": data.get("created_at", datetime.now(timezone.utc)), - } - agent_fields = {"archive_id": data["archive_id"]} - passage = ArchivalPassage(**common_fields, **agent_fields) - - with db_registry.session() as session: - passage.create(session, actor=actor) - return passage.to_pydantic() - @enforce_types @trace_method async def create_agent_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: @@ -269,46 +187,6 @@ class PassageManager: return passage.to_pydantic() - @enforce_types - @trace_method - def create_source_passage( - self, pydantic_passage: PydanticPassage, file_metadata: PydanticFileMetadata, actor: PydanticUser - ) -> PydanticPassage: - """Create a new source passage.""" - if not pydantic_passage.source_id: - raise ValueError("Source passage must have source_id") - if pydantic_passage.archive_id: - raise ValueError("Source passage cannot have archive_id") - - data = pydantic_passage.model_dump(to_orm=True) - - # Deduplicate tags if provided (for dual storage consistency) - tags = data.get("tags") - if tags: - tags = list(set(tags)) - - common_fields = { - "id": data.get("id"), - "text": data["text"], - "embedding": data["embedding"], - "embedding_config": data["embedding_config"], - "organization_id": data["organization_id"], - "metadata_": data.get("metadata", {}), - "tags": tags, - "is_deleted": data.get("is_deleted", False), - "created_at": data.get("created_at", datetime.now(timezone.utc)), - } - source_fields = { - "source_id": data["source_id"], - "file_id": data.get("file_id"), - "file_name": file_metadata.file_name, - } - passage = SourcePassage(**common_fields, **source_fields) - - with db_registry.session() as session: - passage.create(session, actor=actor) - return passage.to_pydantic() - @enforce_types @trace_method async def create_source_passage_async( @@ -349,23 +227,6 @@ class PassageManager: passage = await passage.create_async(session, actor=actor) return passage.to_pydantic() - # DEPRECATED - Use specific methods above - @enforce_types - @trace_method - def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: - """DEPRECATED: Use create_agent_passage() or create_source_passage() instead.""" - import warnings - - warnings.warn( - "create_passage is deprecated. Use create_agent_passage() or create_source_passage() instead.", DeprecationWarning, stacklevel=2 - ) - - passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage) - - with db_registry.session() as session: - passage.create(session, actor=actor) - return passage.to_pydantic() - @enforce_types @trace_method async def create_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: @@ -654,34 +515,6 @@ class PassageManager: embeddings = await embedding_client.request_embeddings(text_chunks, embedding_config) return embeddings - @enforce_types - @trace_method - def update_agent_passage_by_id( - self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs - ) -> Optional[PydanticPassage]: - """Update an agent passage.""" - if not passage_id: - raise ValueError("Passage ID must be provided.") - - with db_registry.session() as session: - try: - curr_passage = ArchivalPassage.read( - db_session=session, - identifier=passage_id, - actor=actor, - ) - except NoResultFound: - raise ValueError(f"Agent passage with id {passage_id} does not exist.") - - # Update the database record with values from the provided record - update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) - for key, value in update_data.items(): - setattr(curr_passage, key, value) - - # Commit changes - curr_passage.update(session, actor=actor) - return curr_passage.to_pydantic() - @enforce_types @trace_method async def update_agent_passage_by_id_async( @@ -738,34 +571,6 @@ class PassageManager: await curr_passage.update_async(session, actor=actor) return curr_passage.to_pydantic() - @enforce_types - @trace_method - def update_source_passage_by_id( - self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs - ) -> Optional[PydanticPassage]: - """Update a source passage.""" - if not passage_id: - raise ValueError("Passage ID must be provided.") - - with db_registry.session() as session: - try: - curr_passage = SourcePassage.read( - db_session=session, - identifier=passage_id, - actor=actor, - ) - except NoResultFound: - raise ValueError(f"Source passage with id {passage_id} does not exist.") - - # Update the database record with values from the provided record - update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) - for key, value in update_data.items(): - setattr(curr_passage, key, value) - - # Commit changes - curr_passage.update(session, actor=actor) - return curr_passage.to_pydantic() - @enforce_types @trace_method async def update_source_passage_by_id_async( @@ -794,21 +599,6 @@ class PassageManager: await curr_passage.update_async(session, actor=actor) return curr_passage.to_pydantic() - @enforce_types - @trace_method - def delete_agent_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool: - """Delete an agent passage.""" - if not passage_id: - raise ValueError("Passage ID must be provided.") - - with db_registry.session() as session: - try: - passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor) - passage.hard_delete(session, actor=actor) - return True - except NoResultFound: - raise NoResultFound(f"Agent passage with id {passage_id} not found.") - @enforce_types @trace_method async def delete_agent_passage_by_id_async(self, passage_id: str, actor: PydanticUser, strict_mode: bool = False) -> bool: @@ -842,21 +632,6 @@ class PassageManager: except NoResultFound: raise NoResultFound(f"Agent passage with id {passage_id} not found.") - @enforce_types - @trace_method - def delete_source_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool: - """Delete a source passage.""" - if not passage_id: - raise ValueError("Passage ID must be provided.") - - with db_registry.session() as session: - try: - passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor) - passage.hard_delete(session, actor=actor) - return True - except NoResultFound: - raise NoResultFound(f"Source passage with id {passage_id} not found.") - @enforce_types @trace_method async def delete_source_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool: @@ -872,80 +647,6 @@ class PassageManager: except NoResultFound: raise NoResultFound(f"Source passage with id {passage_id} not found.") - # DEPRECATED - Use specific methods above - @enforce_types - @trace_method - def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]: - """DEPRECATED: Use update_agent_passage_by_id() or update_source_passage_by_id() instead.""" - import warnings - - warnings.warn( - "update_passage_by_id is deprecated. Use update_agent_passage_by_id() or update_source_passage_by_id() instead.", - DeprecationWarning, - stacklevel=2, - ) - - if not passage_id: - raise ValueError("Passage ID must be provided.") - - with db_registry.session() as session: - # Try source passages first - try: - curr_passage = SourcePassage.read( - db_session=session, - identifier=passage_id, - actor=actor, - ) - except NoResultFound: - # Try agent passages - try: - curr_passage = ArchivalPassage.read( - db_session=session, - identifier=passage_id, - actor=actor, - ) - except NoResultFound: - raise ValueError(f"Passage with id {passage_id} does not exist.") - - # Update the database record with values from the provided record - update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) - for key, value in update_data.items(): - setattr(curr_passage, key, value) - - # Commit changes - curr_passage.update(session, actor=actor) - return curr_passage.to_pydantic() - - @enforce_types - @trace_method - def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool: - """DEPRECATED: Use delete_agent_passage_by_id() or delete_source_passage_by_id() instead.""" - import warnings - - warnings.warn( - "delete_passage_by_id is deprecated. Use delete_agent_passage_by_id() or delete_source_passage_by_id() instead.", - DeprecationWarning, - stacklevel=2, - ) - - if not passage_id: - raise ValueError("Passage ID must be provided.") - - with db_registry.session() as session: - # Try source passages first - try: - passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor) - passage.hard_delete(session, actor=actor) - return True - except NoResultFound: - # Try archival passages - try: - passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor) - passage.hard_delete(session, actor=actor) - return True - except NoResultFound: - raise NoResultFound(f"Passage with id {passage_id} not found.") - @enforce_types @trace_method async def delete_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool: @@ -1077,36 +778,6 @@ class PassageManager: self.delete_passage_by_id(passage_id=passage.id, actor=actor) return True - @enforce_types - @trace_method - def agent_passage_size( - self, - actor: PydanticUser, - agent_id: Optional[str] = None, - ) -> int: - """Get the total count of agent passages with optional filters. - - Args: - actor: The user requesting the count - agent_id: The agent ID of the messages - """ - with db_registry.session() as session: - if agent_id: - # Count passages through the archives relationship - return ( - session.query(ArchivalPassage) - .join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id) - .filter( - ArchivesAgents.agent_id == agent_id, - ArchivalPassage.organization_id == actor.organization_id, - ArchivalPassage.is_deleted == False, - ) - .count() - ) - else: - # Count all archival passages in the organization - return ArchivalPassage.size(db_session=session, actor=actor) - # DEPRECATED - Use agent_passage_size() instead since this only counted agent passages anyway @enforce_types @trace_method @@ -1152,22 +823,6 @@ class PassageManager: # Count all archival passages in the organization return await ArchivalPassage.size_async(db_session=session, actor=actor) - @enforce_types - @trace_method - def source_passage_size( - self, - actor: PydanticUser, - source_id: Optional[str] = None, - ) -> int: - """Get the total count of source passages with optional filters. - - Args: - actor: The user requesting the count - source_id: The source ID of the passages - """ - with db_registry.session() as session: - return SourcePassage.size(db_session=session, actor=actor, source_id=source_id) - @enforce_types @trace_method async def source_passage_size_async( diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 6c908b14..ff4e0406 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -10,27 +10,6 @@ from letta.utils import enforce_types class ProviderManager: - @enforce_types - @trace_method - def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider: - """Create a new provider if it doesn't already exist.""" - with db_registry.session() as session: - provider_create_args = {**request.model_dump(), "provider_category": ProviderCategory.byok} - provider = PydanticProvider(**provider_create_args) - - if provider.name == provider.provider_type.value: - raise ValueError("Provider name must be unique and different from provider type") - - # Assign the organization id based on the actor - provider.organization_id = actor.organization_id - - # Lazily create the provider id prior to persistence - provider.resolve_identifier() - - new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True)) - new_provider.create(session, actor=actor) - return new_provider.to_pydantic() - @enforce_types @trace_method async def create_provider_async(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider: @@ -52,23 +31,6 @@ class ProviderManager: await new_provider.create_async(session, actor=actor) return new_provider.to_pydantic() - @enforce_types - @trace_method - def update_provider(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider: - """Update provider details.""" - with db_registry.session() as session: - # Retrieve the existing provider by ID - existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True) - - # Update only the fields that are provided in ProviderUpdate - update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) - for key, value in update_data.items(): - setattr(existing_provider, key, value) - - # Commit the updated provider - existing_provider.update(session, actor=actor) - return existing_provider.to_pydantic() - @enforce_types @trace_method async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider: @@ -88,21 +50,6 @@ class ProviderManager: await existing_provider.update_async(session, actor=actor) return existing_provider.to_pydantic() - @enforce_types - @trace_method - def delete_provider_by_id(self, provider_id: str, actor: PydanticUser): - """Delete a provider.""" - with db_registry.session() as session: - # Clear api key field - existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True) - existing_provider.api_key = None - existing_provider.update(session, actor=actor) - - # Soft delete in provider table - existing_provider.delete(session, actor=actor) - - session.commit() - @enforce_types @trace_method async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser): @@ -120,33 +67,6 @@ class ProviderManager: await session.commit() - @enforce_types - @trace_method - def list_providers( - self, - actor: PydanticUser, - name: Optional[str] = None, - provider_type: Optional[ProviderType] = None, - after: Optional[str] = None, - limit: Optional[int] = 50, - ) -> List[PydanticProvider]: - """List all providers with optional pagination.""" - filter_kwargs = {} - if name: - filter_kwargs["name"] = name - if provider_type: - filter_kwargs["provider_type"] = provider_type - with db_registry.session() as session: - providers = ProviderModel.list( - db_session=session, - after=after, - limit=limit, - actor=actor, - check_is_deleted=True, - **filter_kwargs, - ) - return [provider.to_pydantic() for provider in providers] - @enforce_types @trace_method async def list_providers_async( diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index bb069982..8bc67824 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -45,40 +45,6 @@ class SandboxConfigManager: sandbox_config = self.create_or_update_sandbox_config(SandboxConfigCreate(config=default_config), actor=actor) return sandbox_config - @enforce_types - @trace_method - def create_or_update_sandbox_config(self, sandbox_config_create: SandboxConfigCreate, actor: PydanticUser) -> PydanticSandboxConfig: - """Create or update a sandbox configuration based on the PydanticSandboxConfig schema.""" - config = sandbox_config_create.config - sandbox_type = config.type - sandbox_config = PydanticSandboxConfig( - type=sandbox_type, config=config.model_dump(exclude_none=True), organization_id=actor.organization_id - ) - - # Attempt to retrieve the existing sandbox configuration by type within the organization - db_sandbox = self.get_sandbox_config_by_type(sandbox_config.type, actor=actor) - if db_sandbox: - # Prepare the update data, excluding fields that should not be reset - update_data = sandbox_config.model_dump(exclude_unset=True, exclude_none=True) - update_data = {key: value for key, value in update_data.items() if getattr(db_sandbox, key) != value} - - # If there are changes, update the sandbox configuration - if update_data: - db_sandbox = self.update_sandbox_config(db_sandbox.id, SandboxConfigUpdate(**update_data), actor) - else: - printd( - f"`create_or_update_sandbox_config` was called with user_id={actor.id}, organization_id={actor.organization_id}, " - f"type={sandbox_config.type}, but found existing configuration with nothing to update." - ) - - return db_sandbox - else: - # If the sandbox configuration doesn't exist, create a new one - with db_registry.session() as session: - db_sandbox = SandboxConfigModel(**sandbox_config.model_dump(exclude_none=True)) - db_sandbox.create(session, actor=actor) - return db_sandbox.to_pydantic() - @enforce_types @trace_method async def get_or_create_default_sandbox_config_async(self, sandbox_type: SandboxType, actor: PydanticUser) -> PydanticSandboxConfig: @@ -133,34 +99,6 @@ class SandboxConfigManager: await db_sandbox.create_async(session, actor=actor) return db_sandbox.to_pydantic() - @enforce_types - @trace_method - def update_sandbox_config( - self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser - ) -> PydanticSandboxConfig: - """Update an existing sandbox configuration.""" - with db_registry.session() as session: - sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor) - # We need to check that the sandbox_update provided is the same type as the original sandbox - if sandbox.type != sandbox_update.config.type: - raise ValueError( - f"Mismatched type for sandbox config update: tried to update sandbox_config of type {sandbox.type} with config of type {sandbox_update.config.type}" - ) - - update_data = sandbox_update.model_dump(exclude_unset=True, exclude_none=True) - update_data = {key: value for key, value in update_data.items() if getattr(sandbox, key) != value} - - if update_data: - for key, value in update_data.items(): - setattr(sandbox, key, value) - sandbox.update(db_session=session, actor=actor) - else: - printd( - f"`update_sandbox_config` called with user_id={actor.id}, organization_id={actor.organization_id}, " - f"name={sandbox.type}, but nothing to update." - ) - return sandbox.to_pydantic() - @enforce_types @trace_method async def update_sandbox_config_async( @@ -189,15 +127,6 @@ class SandboxConfigManager: ) return sandbox.to_pydantic() - @enforce_types - @trace_method - def delete_sandbox_config(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig: - """Delete a sandbox configuration by its ID.""" - with db_registry.session() as session: - sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor) - sandbox.hard_delete(db_session=session, actor=actor) - return sandbox.to_pydantic() - @enforce_types @trace_method async def delete_sandbox_config_async(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig: @@ -207,24 +136,6 @@ class SandboxConfigManager: await sandbox.hard_delete_async(db_session=session, actor=actor) return sandbox.to_pydantic() - @enforce_types - @trace_method - def list_sandbox_configs( - self, - actor: PydanticUser, - after: Optional[str] = None, - limit: Optional[int] = 50, - sandbox_type: Optional[SandboxType] = None, - ) -> List[PydanticSandboxConfig]: - """List all sandbox configurations with optional pagination.""" - kwargs = {"organization_id": actor.organization_id} - if sandbox_type: - kwargs.update({"type": sandbox_type}) - - with db_registry.session() as session: - sandboxes = SandboxConfigModel.list(db_session=session, after=after, limit=limit, **kwargs) - return [sandbox.to_pydantic() for sandbox in sandboxes] - @enforce_types @trace_method async def list_sandbox_configs_async( @@ -243,35 +154,6 @@ class SandboxConfigManager: sandboxes = await SandboxConfigModel.list_async(db_session=session, after=after, limit=limit, **kwargs) return [sandbox.to_pydantic() for sandbox in sandboxes] - @enforce_types - @trace_method - def get_sandbox_config_by_id(self, sandbox_config_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]: - """Retrieve a sandbox configuration by its ID.""" - with db_registry.session() as session: - try: - sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor) - return sandbox.to_pydantic() - except NoResultFound: - return None - - @enforce_types - @trace_method - def get_sandbox_config_by_type(self, type: SandboxType, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]: - """Retrieve a sandbox config by its type.""" - with db_registry.session() as session: - try: - sandboxes = SandboxConfigModel.list( - db_session=session, - type=type, - organization_id=actor.organization_id, - limit=1, - ) - if sandboxes: - return sandboxes[0].to_pydantic() - return None - except NoResultFound: - return None - @enforce_types @trace_method async def get_sandbox_config_by_type_async( @@ -292,34 +174,6 @@ class SandboxConfigManager: except NoResultFound: return None - @enforce_types - @trace_method - def create_sandbox_env_var( - self, env_var_create: SandboxEnvironmentVariableCreate, sandbox_config_id: str, actor: PydanticUser - ) -> PydanticEnvVar: - """Create a new sandbox environment variable.""" - env_var = PydanticEnvVar(**env_var_create.model_dump(), sandbox_config_id=sandbox_config_id, organization_id=actor.organization_id) - - db_env_var = self.get_sandbox_env_var_by_key_and_sandbox_config_id(env_var.key, env_var.sandbox_config_id, actor=actor) - if db_env_var: - update_data = env_var.model_dump(exclude_unset=True, exclude_none=True) - update_data = {key: value for key, value in update_data.items() if getattr(db_env_var, key) != value} - # If there are changes, update the environment variable - if update_data: - db_env_var = self.update_sandbox_env_var(db_env_var.id, SandboxEnvironmentVariableUpdate(**update_data), actor) - else: - printd( - f"`create_or_update_sandbox_env_var` was called with user_id={actor.id}, organization_id={actor.organization_id}, " - f"key={env_var.key}, but found existing variable with nothing to update." - ) - - return db_env_var - else: - with db_registry.session() as session: - env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True)) - env_var.create(session, actor=actor) - return env_var.to_pydantic() - @enforce_types @trace_method async def create_sandbox_env_var_async( @@ -348,28 +202,6 @@ class SandboxConfigManager: await env_var.create_async(session, actor=actor) return env_var.to_pydantic() - @enforce_types - @trace_method - def update_sandbox_env_var( - self, env_var_id: str, env_var_update: SandboxEnvironmentVariableUpdate, actor: PydanticUser - ) -> PydanticEnvVar: - """Update an existing sandbox environment variable.""" - with db_registry.session() as session: - env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor) - update_data = env_var_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) - update_data = {key: value for key, value in update_data.items() if getattr(env_var, key) != value} - - if update_data: - for key, value in update_data.items(): - setattr(env_var, key, value) - env_var.update(db_session=session, actor=actor) - else: - printd( - f"`update_sandbox_env_var` called with user_id={actor.id}, organization_id={actor.organization_id}, " - f"key={env_var.key}, but nothing to update." - ) - return env_var.to_pydantic() - @enforce_types @trace_method async def update_sandbox_env_var_async( @@ -392,15 +224,6 @@ class SandboxConfigManager: ) return env_var.to_pydantic() - @enforce_types - @trace_method - def delete_sandbox_env_var(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar: - """Delete a sandbox environment variable by its ID.""" - with db_registry.session() as session: - env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor) - env_var.hard_delete(db_session=session, actor=actor) - return env_var.to_pydantic() - @enforce_types @trace_method async def delete_sandbox_env_var_async(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar: @@ -410,26 +233,6 @@ class SandboxConfigManager: await env_var.hard_delete_async(db_session=session, actor=actor) return env_var.to_pydantic() - @enforce_types - @trace_method - def list_sandbox_env_vars( - self, - sandbox_config_id: str, - actor: PydanticUser, - after: Optional[str] = None, - limit: Optional[int] = 50, - ) -> List[PydanticEnvVar]: - """List all sandbox environment variables with optional pagination.""" - with db_registry.session() as session: - env_vars = SandboxEnvVarModel.list( - db_session=session, - after=after, - limit=limit, - organization_id=actor.organization_id, - sandbox_config_id=sandbox_config_id, - ) - return [env_var.to_pydantic() for env_var in env_vars] - @enforce_types @trace_method async def list_sandbox_env_vars_async( @@ -450,22 +253,6 @@ class SandboxConfigManager: ) return [env_var.to_pydantic() for env_var in env_vars] - @enforce_types - @trace_method - def list_sandbox_env_vars_by_key( - self, key: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 - ) -> List[PydanticEnvVar]: - """List all sandbox environment variables with optional pagination.""" - with db_registry.session() as session: - env_vars = SandboxEnvVarModel.list( - db_session=session, - after=after, - limit=limit, - organization_id=actor.organization_id, - key=key, - ) - return [env_var.to_pydantic() for env_var in env_vars] - @enforce_types @trace_method async def list_sandbox_env_vars_by_key_async( @@ -501,27 +288,6 @@ class SandboxConfigManager: env_vars = await self.list_sandbox_env_vars_async(sandbox_config_id, actor, after, limit) return {env_var.key: env_var.value for env_var in env_vars} - @enforce_types - @trace_method - def get_sandbox_env_var_by_key_and_sandbox_config_id( - self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None - ) -> Optional[PydanticEnvVar]: - """Retrieve a sandbox environment variable by its key and sandbox_config_id.""" - with db_registry.session() as session: - try: - env_var = SandboxEnvVarModel.list( - db_session=session, - key=key, - sandbox_config_id=sandbox_config_id, - organization_id=actor.organization_id, - limit=1, - ) - if env_var: - return env_var[0].to_pydantic() - return None - except NoResultFound: - return None - @enforce_types @trace_method async def get_sandbox_env_var_by_key_and_sandbox_config_id_async( diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 33a60d5a..0c50285c 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -75,60 +75,6 @@ class StepManager: ) return [step.to_pydantic() for step in steps] - @enforce_types - @trace_method - def log_step( - self, - actor: PydanticUser, - agent_id: str, - provider_name: str, - provider_category: str, - model: str, - model_endpoint: Optional[str], - context_window_limit: int, - usage: UsageStatistics, - provider_id: Optional[str] = None, - job_id: Optional[str] = None, - step_id: Optional[str] = None, - project_id: Optional[str] = None, - stop_reason: Optional[LettaStopReason] = None, - status: Optional[StepStatus] = None, - error_type: Optional[str] = None, - error_data: Optional[Dict] = None, - ) -> PydanticStep: - step_data = { - "origin": None, - "organization_id": actor.organization_id, - "agent_id": agent_id, - "provider_id": provider_id, - "provider_name": provider_name, - "provider_category": provider_category, - "model": model, - "model_endpoint": model_endpoint, - "context_window_limit": context_window_limit, - "completion_tokens": usage.completion_tokens, - "prompt_tokens": usage.prompt_tokens, - "total_tokens": usage.total_tokens, - "job_id": job_id, - "tags": [], - "tid": None, - "trace_id": get_trace_id(), # Get the current trace ID - "project_id": project_id, - "status": status if status else StepStatus.PENDING, - "error_type": error_type, - "error_data": error_data, - } - if step_id: - step_data["id"] = step_id - if stop_reason: - step_data["stop_reason"] = stop_reason.stop_reason - with db_registry.session() as session: - if job_id: - self._verify_job_access(session, job_id, actor, access=["write"]) - new_step = StepModel(**step_data) - new_step.create(session) - return new_step.to_pydantic() - @enforce_types @trace_method async def log_step_async( diff --git a/letta/services/telemetry_manager.py b/letta/services/telemetry_manager.py index 804343f9..c74a9cd6 100644 --- a/letta/services/telemetry_manager.py +++ b/letta/services/telemetry_manager.py @@ -39,22 +39,6 @@ class TelemetryManager: await session.commit() return pydantic_provider_trace - @enforce_types - @trace_method - def create_provider_trace(self, actor: PydanticUser, provider_trace_create: ProviderTraceCreate) -> PydanticProviderTrace: - with db_registry.session() as session: - provider_trace = ProviderTraceModel(**provider_trace_create.model_dump()) - provider_trace.organization_id = actor.organization_id - if provider_trace_create.request_json: - request_json_str = json_dumps(provider_trace_create.request_json) - provider_trace.request_json = json_loads(request_json_str) - - if provider_trace_create.response_json: - response_json_str = json_dumps(provider_trace_create.response_json) - provider_trace.response_json = json_loads(response_json_str) - provider_trace.create(session, actor=actor) - return provider_trace.to_pydantic() - @singleton class NoopTelemetryManager(TelemetryManager): diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index ac8d8ac5..a573a9b1 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -148,22 +148,6 @@ class ToolManager: PydanticTool(tool_type=ToolType.EXTERNAL_COMPOSIO, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor ) - @enforce_types - @trace_method - def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: - """Create a new tool based on the ToolCreate schema.""" - with db_registry.session() as session: - # Auto-generate description if not provided - if pydantic_tool.description is None: - pydantic_tool.description = pydantic_tool.json_schema.get("description", None) - tool_data = pydantic_tool.model_dump(to_orm=True) - # Set the organization id at the ORM layer - tool_data["organization_id"] = actor.organization_id - - tool = ToolModel(**tool_data) - tool.create(session, actor=actor) # Re-raise other database-related errors - return tool.to_pydantic() - @enforce_types @trace_method async def create_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: @@ -232,16 +216,6 @@ class ToolManager: # fallback to individual upserts for sqlite return await self._upsert_tools_individually(pydantic_tools, actor, override_existing_tools) - @enforce_types - @trace_method - def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool: - """Fetch a tool by its ID.""" - with db_registry.session() as session: - # Retrieve tool by id using the Tool model's read method - tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) - # Convert the SQLAlchemy Tool object to PydanticTool - return tool.to_pydantic() - @enforce_types @trace_method async def get_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> PydanticTool: @@ -252,17 +226,6 @@ class ToolManager: # Convert the SQLAlchemy Tool object to PydanticTool return tool.to_pydantic() - @enforce_types - @trace_method - def get_tool_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]: - """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" - try: - with db_registry.session() as session: - tool = ToolModel.read(db_session=session, name=tool_name, actor=actor) - return tool.to_pydantic() - except NoResultFound: - return None - @enforce_types @trace_method async def get_tool_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]: @@ -274,17 +237,6 @@ class ToolManager: except NoResultFound: return None - @enforce_types - @trace_method - def get_tool_id_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[str]: - """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" - try: - with db_registry.session() as session: - tool = ToolModel.read(db_session=session, name=tool_name, actor=actor) - return tool.id - except NoResultFound: - return None - @enforce_types @trace_method async def get_tool_id_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[str]: @@ -568,114 +520,6 @@ class ToolManager: return await ToolModel.size_async(db_session=session, actor=actor) return await ToolModel.size_async(db_session=session, actor=actor, name=LETTA_TOOL_SET) - @enforce_types - @trace_method - def update_tool_by_id( - self, - tool_id: str, - tool_update: ToolUpdate, - actor: PydanticUser, - updated_tool_type: Optional[ToolType] = None, - bypass_name_check: bool = False, - ) -> PydanticTool: - # TODO: remove this (legacy non-async) - """ - Update a tool with complex validation and schema derivation logic. - - This method handles updates differently based on tool type: - - MCP tools: JSON schema is trusted, no Python source derivation - - Python/TypeScript tools: Schema derived from source code if provided - - Name conflicts are checked unless bypassed - - Args: - tool_id: The UUID of the tool to update - tool_update: Partial update data (only changed fields) - actor: User performing the update (for permissions) - updated_tool_type: Optional new tool type (e.g., converting custom to builtin) - bypass_name_check: Skip name conflict validation (use with caution) - - Returns: - Updated tool as Pydantic model - - Raises: - LettaToolNameConflictError: If new name conflicts with existing tool - NoResultFound: If tool doesn't exist or user lacks access - - Side Effects: - - Updates tool in database - - May change tool name if source code is modified - - Recomputes JSON schema from source for non-MCP tools - - Important: - When source_code is provided for Python/TypeScript tools, the name - MUST match the function name in the code, overriding any name in json_schema - """ - # First, check if source code update would cause a name conflict - update_data = tool_update.model_dump(to_orm=True, exclude_none=True) - new_name = None - new_schema = None - - # Fetch current tool to allow conditional logic based on tool type - current_tool = self.get_tool_by_id(tool_id=tool_id, actor=actor) - - # For MCP tools, do NOT derive schema from Python source. Trust provided JSON schema. - if current_tool.tool_type == ToolType.EXTERNAL_MCP: - if "json_schema" in update_data: - new_schema = update_data["json_schema"].copy() - new_name = new_schema.get("name", current_tool.name) - else: - new_schema = current_tool.json_schema - new_name = current_tool.name - update_data.pop("source_code", None) - if new_name != current_tool.name: - existing_tool = self.get_tool_by_name(tool_name=new_name, actor=actor) - if existing_tool: - raise LettaToolNameConflictError(tool_name=new_name) - else: - # For non-MCP tools, preserve existing behavior - if "source_code" in update_data.keys() and not bypass_name_check: - # Check source type to use appropriate parser - source_type = update_data.get("source_type", current_tool.source_type) - if source_type == "typescript": - from letta.functions.typescript_parser import derive_typescript_json_schema - - derived_schema = derive_typescript_json_schema(source_code=update_data["source_code"]) - else: - # Default to Python for backwards compatibility - derived_schema = derive_openai_json_schema(source_code=update_data["source_code"]) - - new_name = derived_schema["name"] - if "json_schema" not in update_data.keys(): - new_schema = derived_schema - else: - new_schema = update_data["json_schema"].copy() - new_schema["name"] = new_name - update_data["json_schema"] = new_schema - if new_name != current_tool.name: - existing_tool = self.get_tool_by_name(tool_name=new_name, actor=actor) - if existing_tool: - raise LettaToolNameConflictError(tool_name=new_name) - - # Now perform the update within the session - with db_registry.session() as session: - # Fetch the tool by ID - tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) - - # Update tool attributes with only the fields that were explicitly set - for key, value in update_data.items(): - setattr(tool, key, value) - - # If we already computed the new schema, apply it - if new_schema is not None: - tool.json_schema = new_schema - tool.name = new_name - - if updated_tool_type: - tool.tool_type = updated_tool_type - - # Save the updated tool to the database - return tool.update(db_session=session, actor=actor).to_pydantic() - @enforce_types @trace_method async def update_tool_by_id_async( @@ -747,17 +591,6 @@ class ToolManager: tool = await tool.update_async(db_session=session, actor=actor) return tool.to_pydantic() - @enforce_types - @trace_method - def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None: - """Delete a tool by its ID.""" - with db_registry.session() as session: - try: - tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) - tool.hard_delete(db_session=session, actor=actor) - except NoResultFound: - raise ValueError(f"Tool with id {tool_id} not found.") - @enforce_types @trace_method async def delete_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> None: diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index bfa73ab0..31f6bbff 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -23,27 +23,6 @@ class UserManager: DEFAULT_USER_NAME = "default_user" DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000" - @enforce_types - @trace_method - def create_default_user(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser: - """Create the default user.""" - with db_registry.session() as session: - # Make sure the org id exists - try: - OrganizationModel.read(db_session=session, identifier=org_id) - except NoResultFound: - raise ValueError(f"No organization with {org_id} exists in the organization table.") - - # Try to retrieve the user - try: - user = UserModel.read(db_session=session, identifier=self.DEFAULT_USER_ID) - except NoResultFound: - # If it doesn't exist, make it - user = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id) - user.create(session) - - return user.to_pydantic() - @enforce_types @trace_method async def create_default_actor_async(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser: @@ -66,15 +45,6 @@ class UserManager: return actor.to_pydantic() - @enforce_types - @trace_method - def create_user(self, pydantic_user: PydanticUser) -> PydanticUser: - """Create a new user if it doesn't already exist.""" - with db_registry.session() as session: - new_user = UserModel(**pydantic_user.model_dump(to_orm=True)) - new_user.create(session) - return new_user.to_pydantic() - @enforce_types @trace_method async def create_actor_async(self, pydantic_user: PydanticUser) -> PydanticUser: @@ -85,23 +55,6 @@ class UserManager: await self._invalidate_actor_cache(new_user.id) return new_user.to_pydantic() - @enforce_types - @trace_method - def update_user(self, user_update: UserUpdate) -> PydanticUser: - """Update user details.""" - with db_registry.session() as session: - # Retrieve the existing user by ID - existing_user = UserModel.read(db_session=session, identifier=user_update.id) - - # Update only the fields that are provided in UserUpdate - update_data = user_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) - for key, value in update_data.items(): - setattr(existing_user, key, value) - - # Commit the updated user - existing_user.update(session) - return existing_user.to_pydantic() - @enforce_types @trace_method async def update_actor_async(self, user_update: UserUpdate) -> PydanticUser: @@ -120,17 +73,6 @@ class UserManager: await self._invalidate_actor_cache(user_update.id) return existing_user.to_pydantic() - @enforce_types - @trace_method - def delete_user_by_id(self, user_id: str): - """Delete a user and their associated records (agents, sources, mappings).""" - with db_registry.session() as session: - # Delete from user table - user = UserModel.read(db_session=session, identifier=user_id) - user.hard_delete(session) - - session.commit() - @enforce_types @trace_method async def delete_actor_by_id_async(self, user_id: str): @@ -141,14 +83,6 @@ class UserManager: await user.hard_delete_async(session) await self._invalidate_actor_cache(user_id) - @enforce_types - @trace_method - def get_user_by_id(self, user_id: str) -> PydanticUser: - """Fetch a user by ID.""" - with db_registry.session() as session: - user = UserModel.read(db_session=session, identifier=user_id) - return user.to_pydantic() - @enforce_types @trace_method @async_redis_cache(key_func=lambda self, actor_id: f"actor_id:{actor_id}", model_class=PydanticUser) @@ -164,27 +98,6 @@ class UserManager: return user.to_pydantic() - @enforce_types - @trace_method - def get_default_user(self) -> PydanticUser: - """Fetch the default user. If it doesn't exist, create it.""" - try: - return self.get_user_by_id(self.DEFAULT_USER_ID) - except NoResultFound: - return self.create_default_user() - - @enforce_types - @trace_method - def get_user_or_default(self, user_id: Optional[str] = None): - """Fetch the user or default user.""" - if not user_id: - return self.get_default_user() - - try: - return self.get_user_by_id(user_id=user_id) - except NoResultFound: - return self.get_default_user() - @enforce_types @trace_method async def get_default_actor_async(self) -> PydanticUser: @@ -206,18 +119,6 @@ class UserManager: user = await self.create_default_actor_async(org_id=DEFAULT_ORG_ID) return user - @enforce_types - @trace_method - def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]: - """List all users with optional pagination.""" - with db_registry.session() as session: - users = UserModel.list( - db_session=session, - after=after, - limit=limit, - ) - return [user.to_pydantic() for user in users] - @enforce_types @trace_method async def list_actors_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]: diff --git a/letta/settings.py b/letta/settings.py index 44ef942d..c7acab0e 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -232,7 +232,7 @@ class Settings(BaseSettings): pg_echo: bool = False # Logging pool_pre_ping: bool = True # Pre ping to check for dead connections pool_use_lifo: bool = True - disable_sqlalchemy_pooling: bool = False + disable_sqlalchemy_pooling: bool = True db_max_concurrent_sessions: Optional[int] = 48 redis_host: Optional[str] = Field(default=None, description="Host for Redis instance") diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index 276985f9..98f8fe9e 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -6,7 +6,6 @@ from typing import List import pytest -from letta.agent import Agent from letta.config import LettaConfig from letta.llm_api.helpers import calculate_summarizer_cutoff from letta.schemas.agent import CreateAgent diff --git a/tests/test_agent_serialization_v2.py b/tests/test_agent_serialization_v2.py index 768e4545..1785e3e3 100644 --- a/tests/test_agent_serialization_v2.py +++ b/tests/test_agent_serialization_v2.py @@ -2,6 +2,7 @@ from typing import List, Optional import pytest +from letta.agents.agent_loop import AgentLoop from letta.config import LettaConfig from letta.errors import AgentFileExportError, AgentFileImportError from letta.orm import Base @@ -34,55 +35,57 @@ from tests.utils import create_tool_from_func # ------------------------------ -def _clear_tables(): - from letta.server.db import db_context +# +async def _clear_tables(): + from letta.server.db import db_registry - with db_context() as session: + async with db_registry.async_session() as session: for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues - session.execute(table.delete()) # Truncate table - session.commit() + await session.execute(table.delete()) # Truncate table + await session.commit() @pytest.fixture(autouse=True) -def clear_tables(): - _clear_tables() +async def clear_tables(): + await _clear_tables() @pytest.fixture -def server(): +async def server(): config = LettaConfig.load() config.save() server = SyncServer(init_with_default_org_and_user=True) - server.tool_manager.upsert_base_tools(actor=server.default_user) + await server.init_async() + await server.tool_manager.upsert_base_tools_async(actor=server.default_user) yield server @pytest.fixture -def default_organization(server: SyncServer): +async def default_organization(server: SyncServer): """Fixture to create and return the default organization.""" - org = server.organization_manager.create_default_organization() + org = await server.organization_manager.create_default_organization_async() yield org @pytest.fixture -def default_user(server: SyncServer, default_organization): +async def default_user(server: SyncServer, default_organization): """Fixture to create and return the default user within the default organization.""" - user = server.user_manager.create_default_user(org_id=default_organization.id) + user = await server.user_manager.create_default_actor_async(org_id=default_organization.id) yield user @pytest.fixture -def other_organization(server: SyncServer): +async def other_organization(server: SyncServer): """Fixture to create and return another organization.""" - org = server.organization_manager.create_organization(pydantic_org=Organization(name="test_org")) + org = await server.organization_manager.create_organization_async(pydantic_org=Organization(name="test_org")) yield org @pytest.fixture -def other_user(server: SyncServer, other_organization): +async def other_user(server: SyncServer, other_organization): """Fixture to create and return another user within the other organization.""" - user = server.user_manager.create_user(pydantic_user=User(organization_id=other_organization.id, name="test_user")) + user = await server.user_manager.create_actor_async(pydantic_user=User(organization_id=other_organization.id, name="test_user")) yield user @@ -120,19 +123,19 @@ def print_tool_func(): @pytest.fixture -def weather_tool(server, weather_tool_func, default_user): - weather_tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=weather_tool_func), actor=default_user) +async def weather_tool(server, weather_tool_func, default_user): + weather_tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=weather_tool_func), actor=default_user) yield weather_tool @pytest.fixture -def print_tool(server, print_tool_func, default_user): - print_tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=print_tool_func), actor=default_user) +async def print_tool(server, print_tool_func, default_user): + print_tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=print_tool_func), actor=default_user) yield print_tool @pytest.fixture -def test_block(server: SyncServer, default_user): +async def test_block(server: SyncServer, default_user): """Fixture to create and return a test block.""" block_data = Block( label="test_block", @@ -141,7 +144,7 @@ def test_block(server: SyncServer, default_user): limit=1000, metadata={"type": "test", "category": "demo"}, ) - block = server.block_manager.create_or_update_block(block_data, actor=default_user) + block = await server.block_manager.create_or_update_block_async(block_data, actor=default_user) yield block @@ -162,8 +165,16 @@ def agent_serialization_manager(server, default_user): yield manager +async def send_message_to_agent(server: SyncServer, agent_state, actor: User, messages: list[MessageCreate]): + agent_loop = AgentLoop.load(agent_state=agent_state, actor=actor) + result = await agent_loop.step( + input_messages=messages, + ) + return result + + @pytest.fixture -def test_agent(server: SyncServer, default_user, default_organization, test_block, weather_tool): +async def test_agent(server: SyncServer, default_user, default_organization, test_block, weather_tool): """Fixture to create and return a test agent with messages.""" memory_blocks = [ CreateBlock(label="human", value="User is a test user"), @@ -190,18 +201,16 @@ def test_agent(server: SyncServer, default_user, default_organization, test_bloc message_buffer_autoclear=False, ) - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( agent_create=create_agent_request, actor=default_user, ) - server.send_messages( - actor=default_user, - agent_id=agent_state.id, - input_messages=[MessageCreate(role=MessageRole.user, content="What's the weather like?")], + await send_message_to_agent( + server, agent_state, default_user, [MessageCreate(role=MessageRole.user, content="What's the weather like?")] ) - agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_state.id, actor=default_user) + agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=default_user) yield agent_state @@ -938,7 +947,7 @@ class TestAgentFileExport: ], ) - second_agent = server.agent_manager.create_agent( + second_agent = await server.agent_manager.create_agent_async( agent_create=create_agent_request, actor=default_user, ) @@ -1140,7 +1149,7 @@ class TestAgentFileImport: # check embedding handle imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0") - imported_agent = server.agent_manager.get_agent_by_id(imported_agent_id, other_user) + imported_agent = await server.agent_manager.get_agent_by_id_async(imported_agent_id, other_user) assert imported_agent.embedding_config.handle == embedding_handle_override async def test_import_preserves_data(self, server, agent_serialization_manager, test_agent, default_user, other_user): @@ -1150,7 +1159,7 @@ class TestAgentFileImport: result = await agent_serialization_manager.import_file(agent_file, other_user) imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0") - imported_agent = server.agent_manager.get_agent_by_id(imported_agent_id, other_user) + imported_agent = await server.agent_manager.get_agent_by_id_async(imported_agent_id, other_user) assert imported_agent.name == test_agent.name assert imported_agent.system == test_agent.system @@ -1161,8 +1170,8 @@ class TestAgentFileImport: assert len(imported_agent.tools) == len(test_agent.tools) assert len(imported_agent.memory.blocks) == len(test_agent.memory.blocks) - original_messages = server.message_manager.list_messages_for_agent(test_agent.id, default_user) - imported_messages = server.message_manager.list_messages_for_agent(imported_agent_id, other_user) + original_messages = await server.message_manager.list_messages_for_agent_async(test_agent.id, default_user) + imported_messages = await server.message_manager.list_messages_for_agent_async(imported_agent_id, other_user) assert len(imported_messages) == len(original_messages) @@ -1178,11 +1187,11 @@ class TestAgentFileImport: result = await agent_serialization_manager.import_file(agent_file, other_user) imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0") - imported_agent = server.agent_manager.get_agent_by_id(imported_agent_id, other_user) + imported_agent = await server.agent_manager.get_agent_by_id_async(imported_agent_id, other_user) assert len(imported_agent.message_ids) == len(test_agent.message_ids) - imported_messages = server.message_manager.list_messages_for_agent(imported_agent_id, other_user) + imported_messages = await server.message_manager.list_messages_for_agent_async(imported_agent_id, other_user) imported_message_ids = {msg.id for msg in imported_messages} for in_context_id in imported_agent.message_ids: @@ -1425,7 +1434,7 @@ class TestAgentFileRoundTrip: current_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0") current_user = target_user - imported_agent = server.agent_manager.get_agent_by_id(current_agent_id, current_user) + imported_agent = await server.agent_manager.get_agent_by_id_async(current_agent_id, current_user) assert imported_agent.name == test_agent.name @@ -1458,7 +1467,7 @@ class TestAgentFileEdgeCases: # Verify assert result.success imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0") - imported_agent = server.agent_manager.get_agent_by_id(imported_agent_id, other_user) + imported_agent = await server.agent_manager.get_agent_by_id_async(imported_agent_id, other_user) assert len(imported_agent.message_ids) == 0 @@ -1473,18 +1482,14 @@ class TestAgentFileEdgeCases: tool_ids=[weather_tool.id], ) - agent_state = server.agent_manager.create_agent( + agent_state = await server.agent_manager.create_agent_async( agent_create=create_agent_request, actor=default_user, ) # Add many messages for i in range(10): - server.send_messages( - actor=default_user, - agent_id=agent_state.id, - input_messages=[MessageCreate(role=MessageRole.user, content=f"Message {i}")], - ) + await send_message_to_agent(server, agent_state, default_user, [MessageCreate(role=MessageRole.user, content=f"Message {i}")]) # Export agent_file = await agent_serialization_manager.export([agent_state.id], default_user) @@ -1499,7 +1504,7 @@ class TestAgentFileEdgeCases: # Verify all messages imported correctly assert result.success imported_agent_id = next(db_id for file_id, db_id in result.id_mappings.items() if file_id == "agent-0") - imported_messages = server.message_manager.list_messages_for_agent(imported_agent_id, other_user) + imported_messages = await server.message_manager.list_messages_for_agent_async(imported_agent_id, other_user) assert len(imported_messages) >= 10 @@ -1528,10 +1533,11 @@ class TestAgentFileValidation: assert valid_schema.agents[0].id == "agent-0" assert valid_schema.metadata.get("revision_id") == current_revision - def test_message_schema_conversion(self, test_agent, server, default_user): + @pytest.mark.asyncio + async def test_message_schema_conversion(self, test_agent, server, default_user): """Test MessageSchema.from_message conversion.""" # Get a message from the test agent - messages = server.message_manager.list_messages_for_agent(test_agent.id, default_user) + messages = await server.message_manager.list_messages_for_agent_async(test_agent.id, default_user) if messages: original_message = messages[0] diff --git a/tests/test_client.py b/tests/test_client.py index 9549c3f7..f7db7bf2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -100,14 +100,15 @@ def search_agent_two(client: Letta): @pytest.fixture(autouse=True) -def clear_tables(): +async def clear_tables(): """Clear the sandbox tables before each test.""" - from letta.server.db import db_context - with db_context() as session: - session.execute(delete(SandboxEnvironmentVariable)) - session.execute(delete(SandboxConfig)) - session.commit() + from letta.server.db import db_registry + + async with db_registry.async_session() as session: + await session.execute(delete(SandboxEnvironmentVariable)) + await session.execute(delete(SandboxConfig)) + await session.commit() # -------------------------------------------------------------------------------------------------------------------- @@ -157,7 +158,7 @@ def test_add_and_manage_tags_for_agent(client: Letta): client.agents.delete(agent.id) -def test_agent_tags(client: Letta): +def test_agent_tags(client: Letta, clear_tables): """Test creating agents with tags and retrieving tags via the API.""" # Create multiple agents with different tags @@ -185,6 +186,8 @@ def test_agent_tags(client: Letta): # Test getting all tags all_tags = client.tags.list() expected_tags = ["agent1", "agent2", "agent3", "development", "production", "test"] + print("ALL TAGS", all_tags) + print("EXPECTED TAGS", expected_tags) assert sorted(all_tags) == expected_tags # Test pagination