chore: migrate to ruff (#4305)
* base requirements * autofix * Configure ruff for Python linting and formatting - Set up minimal ruff configuration with basic checks (E, W, F, I) - Add temporary ignores for common issues during migration - Configure pre-commit hooks to use ruff with pass_filenames - This enables gradual migration from black to ruff * Delete sdj * autofixed only * migrate lint action * more autofixed * more fixes * change precommit * try changing the hook * try this stuff
This commit is contained in:
@@ -13,23 +13,13 @@ repos:
|
||||
hooks:
|
||||
- id: trufflehog
|
||||
name: TruffleHog
|
||||
entry: bash -c 'trufflehog git file://. --since-commit HEAD --results=verified,unknown --fail'
|
||||
entry: bash -c 'trufflehog git file://. --since-commit HEAD --results=verified,unknown --fail --no-update'
|
||||
language: system
|
||||
stages: ["pre-commit", "pre-push"]
|
||||
- id: autoflake
|
||||
name: autoflake
|
||||
entry: bash -c '[ -d "apps/core" ] && cd apps/core; uv run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .'
|
||||
language: system
|
||||
types: [python]
|
||||
- id: isort
|
||||
name: isort
|
||||
entry: bash -c '[ -d "apps/core" ] && cd apps/core; uv run isort --profile black .'
|
||||
language: system
|
||||
types: [python]
|
||||
exclude: ^docs/
|
||||
- id: black
|
||||
name: black
|
||||
entry: bash -c '[ -d "apps/core" ] && cd apps/core; uv run black --line-length 140 --target-version py310 --target-version py311 .'
|
||||
language: system
|
||||
types: [python]
|
||||
exclude: ^docs/
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.11
|
||||
hooks:
|
||||
- id: ruff-check
|
||||
args: [ --fix ]
|
||||
- id: ruff-format
|
||||
|
||||
@@ -28,7 +28,7 @@ def upgrade() -> None:
|
||||
# add default value of `False`
|
||||
op.add_column("block", sa.Column("read_only", sa.Boolean(), nullable=True))
|
||||
op.execute(
|
||||
f"""
|
||||
"""
|
||||
UPDATE block
|
||||
SET read_only = False
|
||||
"""
|
||||
|
||||
@@ -29,7 +29,7 @@ def upgrade() -> None:
|
||||
op.add_column("jobs", sa.Column("job_type", sa.String(), nullable=True))
|
||||
|
||||
# Set existing rows to have the default value of JobType.JOB
|
||||
op.execute(f"UPDATE jobs SET job_type = 'job' WHERE job_type IS NULL")
|
||||
op.execute("UPDATE jobs SET job_type = 'job' WHERE job_type IS NULL")
|
||||
|
||||
# Make the column non-nullable after setting default values
|
||||
op.alter_column("jobs", "job_type", existing_type=sa.String(), nullable=False)
|
||||
|
||||
@@ -30,7 +30,7 @@ def upgrade() -> None:
|
||||
|
||||
# fill in column with `False`
|
||||
op.execute(
|
||||
f"""
|
||||
"""
|
||||
UPDATE organizations
|
||||
SET privileged_tools = False
|
||||
"""
|
||||
|
||||
@@ -42,7 +42,7 @@ def upgrade() -> None:
|
||||
f"""
|
||||
UPDATE tools
|
||||
SET tool_type = '{letta_core_value}'
|
||||
WHERE name IN ({','.join(f"'{name}'" for name in BASE_TOOLS)});
|
||||
WHERE name IN ({",".join(f"'{name}'" for name in BASE_TOOLS)});
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -50,7 +50,7 @@ def upgrade() -> None:
|
||||
f"""
|
||||
UPDATE tools
|
||||
SET tool_type = '{letta_memory_core_value}'
|
||||
WHERE name IN ({','.join(f"'{name}'" for name in BASE_MEMORY_TOOLS)});
|
||||
WHERE name IN ({",".join(f"'{name}'" for name in BASE_MEMORY_TOOLS)});
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ def upgrade() -> None:
|
||||
WITH numbered_rows AS (
|
||||
SELECT
|
||||
id,
|
||||
ROW_NUMBER() OVER (ORDER BY {', '.join(ORDERING_COLUMNS)} ASC) as rn
|
||||
ROW_NUMBER() OVER (ORDER BY {", ".join(ORDERING_COLUMNS)} ASC) as rn
|
||||
FROM {TABLE_NAME}
|
||||
)
|
||||
UPDATE {TABLE_NAME}
|
||||
|
||||
@@ -49,9 +49,7 @@ from letta.schemas.enums import MessageRole, ProviderType, StepStatus, ToolType
|
||||
from letta.schemas.letta_message_content import ImageContent, TextContent
|
||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Message as ChatCompletionMessage, UsageStatistics
|
||||
from letta.schemas.response_format import ResponseFormatType
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
@@ -871,7 +869,6 @@ class Agent(BaseAgent):
|
||||
) -> AgentStepResponse:
|
||||
"""Runs a single step in the agent loop (generates at most one LLM call)"""
|
||||
try:
|
||||
|
||||
# Extract job_id from metadata if present
|
||||
job_id = metadata.get("job_id") if metadata else None
|
||||
|
||||
@@ -1084,9 +1081,9 @@ class Agent(BaseAgent):
|
||||
-> agent.step(messages=[Message(role='user', text=...)])
|
||||
"""
|
||||
# Wrap with metadata, dumps to JSON
|
||||
assert user_message_str and isinstance(
|
||||
user_message_str, str
|
||||
), f"user_message_str should be a non-empty string, got {type(user_message_str)}"
|
||||
assert user_message_str and isinstance(user_message_str, str), (
|
||||
f"user_message_str should be a non-empty string, got {type(user_message_str)}"
|
||||
)
|
||||
user_message_json_str = package_user_message(user_message_str, self.agent_state.timezone)
|
||||
|
||||
# Validate JSON via save/load
|
||||
|
||||
@@ -269,16 +269,20 @@ class LettaAgent(BaseAgent):
|
||||
effective_step_id = step_id if logged_step else None
|
||||
|
||||
try:
|
||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||
await self._build_and_request_from_llm(
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
agent_step_span,
|
||||
step_metrics,
|
||||
)
|
||||
(
|
||||
request_data,
|
||||
response_data,
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
valid_tool_names,
|
||||
) = await self._build_and_request_from_llm(
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
agent_step_span,
|
||||
step_metrics,
|
||||
)
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
|
||||
@@ -574,16 +578,20 @@ class LettaAgent(BaseAgent):
|
||||
effective_step_id = step_id if logged_step else None
|
||||
|
||||
try:
|
||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||
await self._build_and_request_from_llm(
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
agent_step_span,
|
||||
step_metrics,
|
||||
)
|
||||
(
|
||||
request_data,
|
||||
response_data,
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
valid_tool_names,
|
||||
) = await self._build_and_request_from_llm(
|
||||
current_in_context_messages,
|
||||
new_in_context_messages,
|
||||
agent_state,
|
||||
llm_client,
|
||||
tool_rules_solver,
|
||||
agent_step_span,
|
||||
step_metrics,
|
||||
)
|
||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||
|
||||
@@ -1626,7 +1634,6 @@ class LettaAgent(BaseAgent):
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
is_final_step: bool | None,
|
||||
) -> tuple[bool, str | None, LettaStopReason | None]:
|
||||
|
||||
continue_stepping = request_heartbeat
|
||||
heartbeat_reason: str | None = None
|
||||
stop_reason: LettaStopReason | None = None
|
||||
@@ -1658,9 +1665,7 @@ class LettaAgent(BaseAgent):
|
||||
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
|
||||
if not continue_stepping and uncalled:
|
||||
continue_stepping = True
|
||||
heartbeat_reason = (
|
||||
f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [" f"{', '.join(uncalled)}] to be called still."
|
||||
)
|
||||
heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [{', '.join(uncalled)}] to be called still."
|
||||
|
||||
stop_reason = None # reset – we’re still going
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStrea
|
||||
# TODO: Please note this is a very generous timeout for e2b reasons
|
||||
with httpx.Client(timeout=httpx.Timeout(5 * 60.0, read=5 * 60.0)) as client:
|
||||
with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:
|
||||
|
||||
# Check for immediate HTTP errors before processing the SSE stream
|
||||
if not event_source.response.is_success:
|
||||
response_bytes = event_source.response.read()
|
||||
|
||||
@@ -593,7 +593,6 @@ def generate_tool_schema_for_mcp(
|
||||
append_heartbeat: bool = True,
|
||||
strict: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
# MCP tool.inputSchema is a JSON schema
|
||||
# https://github.com/modelcontextprotocol/python-sdk/blob/775f87981300660ee957b63c2a14b448ab9c3675/src/mcp/types.py#L678
|
||||
parameters_schema = mcp_tool.inputSchema
|
||||
|
||||
@@ -2,8 +2,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
|
||||
from sqlalchemy import Dialect
|
||||
|
||||
from letta.functions.mcp_client.types import StdioServerConfig
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import timezone as dt_timezone
|
||||
from datetime import datetime, timedelta, timezone as dt_timezone
|
||||
from typing import Callable
|
||||
|
||||
import pytz
|
||||
|
||||
@@ -317,7 +317,6 @@ async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int =
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||
|
||||
kwargs = {"namespace": namespace, "prefix": file_id}
|
||||
if limit is not None:
|
||||
kwargs["limit"] = limit
|
||||
|
||||
@@ -198,23 +198,23 @@ class CLIInterface(AgentInterface):
|
||||
try:
|
||||
msg_dict = eval(function_args)
|
||||
if function_name == "archival_memory_search":
|
||||
output = f'\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}'
|
||||
output = f"\tquery: {msg_dict['query']}, page: {msg_dict['page']}"
|
||||
if STRIP_UI:
|
||||
print(output)
|
||||
else:
|
||||
print(f"{Fore.RED}{output}{Style.RESET_ALL}")
|
||||
elif function_name == "archival_memory_insert":
|
||||
output = f'\t→ {msg_dict["content"]}'
|
||||
output = f"\t→ {msg_dict['content']}"
|
||||
if STRIP_UI:
|
||||
print(output)
|
||||
else:
|
||||
print(f"{Style.BRIGHT}{Fore.RED}{output}{Style.RESET_ALL}")
|
||||
else:
|
||||
if STRIP_UI:
|
||||
print(f'\t {msg_dict["old_content"]}\n\t→ {msg_dict["new_content"]}')
|
||||
print(f"\t {msg_dict['old_content']}\n\t→ {msg_dict['new_content']}")
|
||||
else:
|
||||
print(
|
||||
f'{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}→ {msg_dict["new_content"]}{Style.RESET_ALL}'
|
||||
f"{Style.BRIGHT}\t{Fore.RED} {msg_dict['old_content']}\n\t{Fore.GREEN}→ {msg_dict['new_content']}{Style.RESET_ALL}"
|
||||
)
|
||||
except Exception as e:
|
||||
printd(str(e))
|
||||
@@ -223,7 +223,7 @@ class CLIInterface(AgentInterface):
|
||||
print_function_message("🧠", f"searching memory with {function_name}")
|
||||
try:
|
||||
msg_dict = eval(function_args)
|
||||
output = f'\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}'
|
||||
output = f"\tquery: {msg_dict['query']}, page: {msg_dict['page']}"
|
||||
if STRIP_UI:
|
||||
print(output)
|
||||
else:
|
||||
|
||||
@@ -232,16 +232,13 @@ class OpenAIStreamingInterface:
|
||||
|
||||
# If we have main_json, we should output a ToolCallMessage
|
||||
elif updates_main_json:
|
||||
|
||||
# If there's something in the function_name buffer, we should release it first
|
||||
# NOTE: we could output it as part of a chunk that has both name and args,
|
||||
# however the frontend may expect name first, then args, so to be
|
||||
# safe we'll output name first in a separate chunk
|
||||
if self.function_name_buffer:
|
||||
|
||||
# use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
|
||||
|
||||
# Store the ID of the tool call to allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
@@ -373,7 +370,6 @@ class OpenAIStreamingInterface:
|
||||
# clear buffers
|
||||
self.function_id_buffer = None
|
||||
else:
|
||||
|
||||
# There may be a buffer from a previous chunk, for example
|
||||
# if the previous chunk had arguments but we needed to flush name
|
||||
if self.function_args_buffer:
|
||||
|
||||
@@ -5,8 +5,7 @@ from typing import Dict, List, Optional, Union
|
||||
|
||||
import anthropic
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types.beta import BetaMessage as AnthropicMessage
|
||||
from anthropic.types.beta import BetaRawMessageStreamEvent
|
||||
from anthropic.types.beta import BetaMessage as AnthropicMessage, BetaRawMessageStreamEvent
|
||||
from anthropic.types.beta.message_create_params import MessageCreateParamsNonStreaming
|
||||
from anthropic.types.beta.messages import BetaMessageBatch
|
||||
from anthropic.types.beta.messages.batch_create_params import Request
|
||||
@@ -34,9 +33,14 @@ from letta.otel.tracing import trace_method
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
|
||||
from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
from letta.schemas.openai.chat_completion_response import (
|
||||
ChatCompletionResponse,
|
||||
Choice,
|
||||
FunctionCall,
|
||||
Message as ChoiceMessage,
|
||||
ToolCall,
|
||||
UsageStatistics,
|
||||
)
|
||||
from letta.settings import model_settings
|
||||
|
||||
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
|
||||
@@ -45,7 +49,6 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnthropicClient(LLMClientBase):
|
||||
|
||||
@trace_method
|
||||
@deprecated("Synchronous version of this is no longer valid. Will result in model_dump of coroutine")
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
|
||||
@@ -13,7 +13,6 @@ from letta.settings import model_settings
|
||||
|
||||
|
||||
class AzureClient(OpenAIClient):
|
||||
|
||||
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
if llm_config.provider_category == ProviderCategory.byok:
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
|
||||
@@ -16,7 +16,6 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BedrockClient(AnthropicClient):
|
||||
|
||||
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> tuple[str, str, str]:
|
||||
override_access_key_id, override_secret_access_key, override_default_region = None, None, None
|
||||
if llm_config.provider_category == ProviderCategory.byok:
|
||||
|
||||
@@ -11,11 +11,18 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import Message as _Message
|
||||
from letta.schemas.openai.chat_completion_request import AssistantMessage, ChatCompletionRequest, ChatMessage
|
||||
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
|
||||
from letta.schemas.openai.chat_completion_request import Tool, ToolFunctionChoice, ToolMessage, UserMessage, cast_message_to_subtype
|
||||
from letta.schemas.message import Message as PydanticMessage, Message as _Message
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
AssistantMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
FunctionCall as ToolFunctionChoiceFunctionCall,
|
||||
Tool,
|
||||
ToolFunctionChoice,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
cast_message_to_subtype,
|
||||
)
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.openai.openai import Function, ToolCall
|
||||
from letta.settings import model_settings
|
||||
@@ -313,7 +320,6 @@ def convert_deepseek_response_to_chatcompletion(
|
||||
|
||||
|
||||
class DeepseekClient(OpenAIClient):
|
||||
|
||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GoogleVertexClient(LLMClientBase):
|
||||
|
||||
def _get_client(self):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
return genai.Client(
|
||||
@@ -344,9 +343,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX
|
||||
|
||||
assert (
|
||||
INNER_THOUGHTS_KWARG_VERTEX in function_args
|
||||
), f"Couldn't find inner thoughts in function args:\n{function_call}"
|
||||
assert INNER_THOUGHTS_KWARG_VERTEX in function_args, (
|
||||
f"Couldn't find inner thoughts in function args:\n{function_call}"
|
||||
)
|
||||
inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
|
||||
assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
|
||||
else:
|
||||
@@ -380,9 +379,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX
|
||||
|
||||
assert (
|
||||
INNER_THOUGHTS_KWARG_VERTEX in function_args
|
||||
), f"Couldn't find inner thoughts in function args:\n{function_call}"
|
||||
assert INNER_THOUGHTS_KWARG_VERTEX in function_args, (
|
||||
f"Couldn't find inner thoughts in function args:\n{function_call}"
|
||||
)
|
||||
inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
|
||||
assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
|
||||
else:
|
||||
@@ -406,7 +405,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
|
||||
except json.decoder.JSONDecodeError:
|
||||
if candidate.finish_reason == "MAX_TOKENS":
|
||||
raise ValueError(f"Could not parse response data from LLM: exceeded max token limit")
|
||||
raise ValueError("Could not parse response data from LLM: exceeded max token limit")
|
||||
# Inner thoughts are the content by default
|
||||
inner_thoughts = response_message.text
|
||||
|
||||
@@ -463,7 +462,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
)
|
||||
else:
|
||||
# Count it ourselves
|
||||
assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required"
|
||||
assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required"
|
||||
prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation
|
||||
completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
@@ -14,7 +14,6 @@ from letta.settings import model_settings
|
||||
|
||||
|
||||
class GroqClient(OpenAIClient):
|
||||
|
||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ async def mistral_get_model_list_async(url: str, api_key: str) -> dict:
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
logger.debug(f"Sending request to %s", url)
|
||||
logger.debug("Sending request to %s", url)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# TODO add query param "tool" to be true
|
||||
|
||||
@@ -21,11 +21,15 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import log_event
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as _Message
|
||||
from letta.schemas.message import MessageRole as _MessageRole
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
|
||||
from letta.schemas.openai.chat_completion_request import FunctionSchema, Tool, ToolFunctionChoice, cast_message_to_subtype
|
||||
from letta.schemas.message import Message as _Message, MessageRole as _MessageRole
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
ChatCompletionRequest,
|
||||
FunctionCall as ToolFunctionChoiceFunctionCall,
|
||||
FunctionSchema,
|
||||
Tool,
|
||||
ToolFunctionChoice,
|
||||
cast_message_to_subtype,
|
||||
)
|
||||
from letta.schemas.openai.chat_completion_response import (
|
||||
ChatCompletionChunkResponse,
|
||||
ChatCompletionResponse,
|
||||
|
||||
@@ -29,11 +29,14 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_message_content import MessageContentType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
|
||||
from letta.schemas.openai.chat_completion_request import FunctionSchema
|
||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
||||
from letta.schemas.openai.chat_completion_request import ToolFunctionChoice, cast_message_to_subtype
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
ChatCompletionRequest,
|
||||
FunctionCall as ToolFunctionChoiceFunctionCall,
|
||||
FunctionSchema,
|
||||
Tool as OpenAITool,
|
||||
ToolFunctionChoice,
|
||||
cast_message_to_subtype,
|
||||
)
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from letta.settings import model_settings
|
||||
|
||||
|
||||
class TogetherClient(OpenAIClient):
|
||||
|
||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from letta.settings import model_settings
|
||||
|
||||
|
||||
class XAIClient(OpenAIClient):
|
||||
|
||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@@ -205,7 +205,7 @@ def get_chat_completion(
|
||||
raise LocalLLMError(f"usage dict in response was missing fields ({usage})")
|
||||
|
||||
if usage["prompt_tokens"] is None:
|
||||
printd(f"usage dict was missing prompt_tokens, computing on-the-fly...")
|
||||
printd("usage dict was missing prompt_tokens, computing on-the-fly...")
|
||||
usage["prompt_tokens"] = count_tokens(prompt)
|
||||
|
||||
# NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing
|
||||
@@ -220,7 +220,7 @@ def get_chat_completion(
|
||||
|
||||
# NOTE: this is the token count that matters most
|
||||
if usage["total_tokens"] is None:
|
||||
printd(f"usage dict was missing total_tokens, computing on-the-fly...")
|
||||
printd("usage dict was missing total_tokens, computing on-the-fly...")
|
||||
usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
|
||||
|
||||
# unpack with response.choices[0].message.content
|
||||
@@ -261,9 +261,9 @@ def generate_grammar_and_documentation(
|
||||
):
|
||||
from letta.utils import printd
|
||||
|
||||
assert not (
|
||||
add_inner_thoughts_top_level and add_inner_thoughts_param_level
|
||||
), "Can only place inner thoughts in one location in the grammar generator"
|
||||
assert not (add_inner_thoughts_top_level and add_inner_thoughts_param_level), (
|
||||
"Can only place inner thoughts in one location in the grammar generator"
|
||||
)
|
||||
|
||||
grammar_function_models = []
|
||||
# create_dynamic_model_from_function will add inner thoughts to the function parameters if add_inner_thoughts is True.
|
||||
|
||||
@@ -46,7 +46,7 @@ def get_completions_settings(defaults="simple") -> dict:
|
||||
with open(settings_file, "r", encoding="utf-8") as file:
|
||||
user_settings = json.load(file)
|
||||
if len(user_settings) > 0:
|
||||
printd(f"Updating base settings with the following user settings:\n{json_dumps(user_settings,indent=2)}")
|
||||
printd(f"Updating base settings with the following user settings:\n{json_dumps(user_settings, indent=2)}")
|
||||
settings.update(user_settings)
|
||||
else:
|
||||
printd(f"'{settings_file}' was empty, ignoring...")
|
||||
|
||||
@@ -13,8 +13,7 @@ from letta.orm.identity import Identity
|
||||
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.agent import AgentType, get_prompt_template_for_agent_type
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState, AgentType, get_prompt_template_for_agent_type
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
|
||||
@@ -8,8 +8,7 @@ from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import Human, Persona
|
||||
from letta.schemas.block import Block as PydanticBlock, Human, Persona
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm import Organization
|
||||
|
||||
@@ -38,7 +38,9 @@ class BlockHistory(OrganizationMixin, SqlalchemyBase):
|
||||
|
||||
# Relationships
|
||||
block_id: Mapped[str] = mapped_column(
|
||||
String, ForeignKey("block.id", ondelete="CASCADE"), nullable=False # History deleted if Block is deleted
|
||||
String,
|
||||
ForeignKey("block.id", ondelete="CASCADE"),
|
||||
nullable=False, # History deleted if Block is deleted
|
||||
)
|
||||
|
||||
sequence_number: Mapped[int] = mapped_column(
|
||||
|
||||
@@ -10,7 +10,6 @@ from letta.schemas.group import Group as PydanticGroup
|
||||
|
||||
|
||||
class Group(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
|
||||
|
||||
__tablename__ = "groups"
|
||||
__pydantic_model__ = PydanticGroup
|
||||
|
||||
|
||||
@@ -7,8 +7,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin, ProjectMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.identity import Identity as PydanticIdentity
|
||||
from letta.schemas.identity import IdentityProperty
|
||||
from letta.schemas.identity import Identity as PydanticIdentity, IdentityProperty
|
||||
|
||||
|
||||
class Identity(SqlalchemyBase, OrganizationMixin, ProjectMixin):
|
||||
|
||||
@@ -7,8 +7,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from letta.orm.mixins import UserMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.enums import JobStatus, JobType
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import LettaRequestConfig
|
||||
from letta.schemas.job import Job as PydanticJob, LettaRequestConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.job_messages import JobMessage
|
||||
|
||||
@@ -9,8 +9,7 @@ from letta.orm.custom_columns import AgentStepStateColumn, BatchRequestResultCol
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus
|
||||
from letta.schemas.llm_batch_job import AgentStepState
|
||||
from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
|
||||
from letta.schemas.llm_batch_job import AgentStepState, LLMBatchItem as PydanticLLMBatchItem
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
|
||||
|
||||
@@ -7,10 +7,8 @@ from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
|
||||
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.letta_message_content import MessageContent
|
||||
from letta.schemas.letta_message_content import TextContent as PydanticTextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import ToolReturn
|
||||
from letta.schemas.letta_message_content import MessageContent, TextContent as PydanticTextContent
|
||||
from letta.schemas.message import Message as PydanticMessage, ToolReturn
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import JSON
|
||||
from sqlalchemy import Enum as SqlEnum
|
||||
from sqlalchemy import Index, String, UniqueConstraint
|
||||
from sqlalchemy import JSON, Enum as SqlEnum, Index, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin, SandboxConfigMixin
|
||||
|
||||
@@ -37,7 +37,7 @@ def get_plugin(plugin_type: str):
|
||||
return plugin
|
||||
elif type(plugin).__name__ == "class":
|
||||
if plugin_register["protocol"] and not isinstance(plugin, type(plugin_register["protocol"])):
|
||||
raise TypeError(f'{plugin} does not implement {type(plugin_register["protocol"]).__name__}')
|
||||
raise TypeError(f"{plugin} does not implement {type(plugin_register['protocol']).__name__}")
|
||||
return plugin()
|
||||
raise TypeError("Unknown plugin type")
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from letta.schemas.memory import Memory
|
||||
|
||||
|
||||
class PromptGenerator:
|
||||
|
||||
# TODO: This code is kind of wonky and deserves a rewrite
|
||||
@trace_method
|
||||
@staticmethod
|
||||
|
||||
@@ -43,7 +43,6 @@ class EmbeddingConfig(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def default_config(cls, model_name: Optional[str] = None, provider: Optional[str] = None):
|
||||
|
||||
if model_name == "text-embedding-ada-002" and provider == "openai":
|
||||
return cls(
|
||||
embedding_model="text-embedding-ada-002",
|
||||
|
||||
@@ -9,8 +9,7 @@ from collections import OrderedDict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
|
||||
@@ -880,7 +879,6 @@ class Message(BaseMessage):
|
||||
# Tool calling
|
||||
if self.tool_calls is not None:
|
||||
for tool_call in self.tool_calls:
|
||||
|
||||
if put_inner_thoughts_in_kwargs:
|
||||
tool_call_input = add_inner_thoughts_to_tool_call(
|
||||
tool_call,
|
||||
@@ -1021,7 +1019,7 @@ class Message(BaseMessage):
|
||||
assert self.tool_call_id is not None, vars(self)
|
||||
|
||||
if self.name is None:
|
||||
warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.")
|
||||
warnings.warn("Couldn't find function name on tool call, defaulting to tool ID instead.")
|
||||
function_name = self.tool_call_id
|
||||
else:
|
||||
function_name = self.name
|
||||
|
||||
@@ -35,7 +35,7 @@ class BedrockProvider(Provider):
|
||||
response = await bedrock.list_inference_profiles()
|
||||
return response["inferenceProfileSummaries"]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model list for bedrock: %s", e)
|
||||
logger.error("Error getting model list for bedrock: %s", e)
|
||||
raise e
|
||||
|
||||
async def check_api_key(self):
|
||||
|
||||
@@ -203,7 +203,7 @@ class OpenAIProvider(Provider):
|
||||
continue
|
||||
else:
|
||||
logger.debug(
|
||||
f"Skipping embedding models for %s by default, as we don't assume embeddings are supported."
|
||||
"Skipping embedding models for %s by default, as we don't assume embeddings are supported."
|
||||
"Please open an issue on GitHub if support is required.",
|
||||
self.base_url,
|
||||
)
|
||||
@@ -227,7 +227,7 @@ class OpenAIProvider(Provider):
|
||||
return LLM_MAX_TOKENS[model_name]
|
||||
else:
|
||||
logger.debug(
|
||||
f"Model %s on %s for provider %s not found in LLM_MAX_TOKENS. Using default of {{LLM_MAX_TOKENS['DEFAULT']}}",
|
||||
"Model %s on %s for provider %s not found in LLM_MAX_TOKENS. Using default of {LLM_MAX_TOKENS['DEFAULT']}",
|
||||
model_name,
|
||||
self.base_url,
|
||||
self.__class__.__name__,
|
||||
|
||||
@@ -218,9 +218,9 @@ class ToolCreate(LettaBase):
|
||||
composio_action_schemas = composio_toolset.get_action_schemas(actions=[action_name], check_connected_accounts=False)
|
||||
|
||||
assert len(composio_action_schemas) > 0, "User supplied parameters do not match any Composio tools"
|
||||
assert (
|
||||
len(composio_action_schemas) == 1
|
||||
), f"User supplied parameters match too many Composio tools; {len(composio_action_schemas)} > 1"
|
||||
assert len(composio_action_schemas) == 1, (
|
||||
f"User supplied parameters match too many Composio tools; {len(composio_action_schemas)} > 1"
|
||||
)
|
||||
|
||||
composio_action_schema = composio_action_schemas[0]
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from letta.schemas.agent import AgentState
|
||||
|
||||
|
||||
class ToolExecutionResult(BaseModel):
|
||||
|
||||
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
|
||||
func_return: Optional[Any] = Field(None, description="The function return object")
|
||||
agent_state: Optional[AgentState] = Field(None, description="The agent state")
|
||||
|
||||
@@ -5,8 +5,7 @@ from sqlalchemy import func
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
import letta
|
||||
from letta.orm import Agent
|
||||
from letta.orm import Message as MessageModel
|
||||
from letta.orm import Agent, Message as MessageModel
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.marshmallow_agent_environment_variable import SerializedAgentEnvironmentVariableSchema
|
||||
|
||||
@@ -261,7 +261,7 @@ def create_application() -> "FastAPI":
|
||||
|
||||
@app.exception_handler(BedrockPermissionError)
|
||||
async def bedrock_permission_error_handler(request, exc: BedrockPermissionError):
|
||||
logger.error(f"Bedrock permission denied.")
|
||||
logger.error("Bedrock permission denied.")
|
||||
if SENTRY_ENABLED:
|
||||
sentry_sdk.capture_exception(exc)
|
||||
|
||||
@@ -433,10 +433,10 @@ def start_server(
|
||||
if IS_WINDOWS:
|
||||
# Windows doesn't support the fancy unicode characters
|
||||
print(f"Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
|
||||
print(f"View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
||||
print("View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
||||
else:
|
||||
print(f"▶ Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
|
||||
print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
||||
print("▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
||||
|
||||
if importlib.util.find_spec("granian") is not None and settings.use_granian:
|
||||
# Experimental Granian engine
|
||||
|
||||
@@ -22,7 +22,6 @@ class AuthRequest(BaseModel):
|
||||
|
||||
|
||||
def setup_auth_router(server: SyncServer, interface: QueuingInterface, password: str) -> APIRouter:
|
||||
|
||||
@router.post("/auth", tags=["auth"], response_model=AuthResponse)
|
||||
def authenticate_user(request: AuthRequest) -> AuthResponse:
|
||||
"""
|
||||
|
||||
@@ -377,9 +377,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
):
|
||||
"""Add an item to the deque"""
|
||||
assert self._active, "Generator is inactive"
|
||||
assert (
|
||||
isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage) or isinstance(item, MessageStreamStatus)
|
||||
), f"Wrong type: {type(item)}"
|
||||
assert isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage) or isinstance(item, MessageStreamStatus), (
|
||||
f"Wrong type: {type(item)}"
|
||||
)
|
||||
|
||||
self._chunks.append(item)
|
||||
self._event.set() # Signal that new data is available
|
||||
@@ -731,13 +731,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
|
||||
# If we have main_json, we should output a ToolCallMessage
|
||||
elif updates_main_json:
|
||||
|
||||
# If there's something in the function_name buffer, we should release it first
|
||||
# NOTE: we could output it as part of a chunk that has both name and args,
|
||||
# however the frontend may expect name first, then args, so to be
|
||||
# safe we'll output name first in a separate chunk
|
||||
if self.function_name_buffer:
|
||||
|
||||
# use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
|
||||
processed_chunk = None
|
||||
@@ -778,7 +776,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# If there was nothing in the name buffer, we can proceed to
|
||||
# output the arguments chunk as a ToolCallMessage
|
||||
else:
|
||||
|
||||
# use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and (
|
||||
self.last_flushed_function_name is not None
|
||||
@@ -860,7 +857,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# clear buffers
|
||||
self.function_id_buffer = None
|
||||
else:
|
||||
|
||||
# There may be a buffer from a previous chunk, for example
|
||||
# if the previous chunk had arguments but we needed to flush name
|
||||
if self.function_args_buffer:
|
||||
@@ -997,7 +993,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# Otherwise, do simple chunks of ToolCallMessage
|
||||
|
||||
else:
|
||||
|
||||
tool_call_delta = {}
|
||||
if tool_call.id:
|
||||
tool_call_delta["id"] = tool_call.id
|
||||
@@ -1073,7 +1068,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call = message_delta.tool_calls[0]
|
||||
|
||||
if tool_call.function:
|
||||
|
||||
# Track the function name while streaming
|
||||
# If we were previously on a 'send_message', we need to 'toggle' into 'content' mode
|
||||
if tool_call.function.name:
|
||||
@@ -1154,7 +1148,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
"""Letta generates some internal monologue"""
|
||||
if not self.streaming_mode:
|
||||
|
||||
# create a fake "chunk" of a stream
|
||||
# processed_chunk = {
|
||||
# "internal_monologue": msg,
|
||||
@@ -1268,7 +1261,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
print(f"Failed to parse function message: {e}")
|
||||
|
||||
else:
|
||||
|
||||
try:
|
||||
func_args = parse_json(function_call.function.arguments)
|
||||
except:
|
||||
|
||||
@@ -140,9 +140,7 @@ class RedisSSEStreamWriter:
|
||||
|
||||
self.last_flush[run_id] = time.time()
|
||||
|
||||
logger.debug(
|
||||
f"Flushed {len(chunks)} chunks to Redis stream {stream_key}, " f"seq_ids {chunks[0]['seq_id']}-{chunks[-1]['seq_id']}"
|
||||
)
|
||||
logger.debug(f"Flushed {len(chunks)} chunks to Redis stream {stream_key}, seq_ids {chunks[0]['seq_id']}-{chunks[-1]['seq_id']}")
|
||||
|
||||
if chunks[-1].get("complete") == "true":
|
||||
self._cleanup_run(run_id)
|
||||
|
||||
@@ -34,7 +34,7 @@ async def list_blocks(
|
||||
),
|
||||
label_search: Optional[str] = Query(
|
||||
None,
|
||||
description=("Search blocks by label. If provided, returns blocks that match this label. " "This is a full-text search on labels."),
|
||||
description=("Search blocks by label. If provided, returns blocks that match this label. This is a full-text search on labels."),
|
||||
),
|
||||
description_search: Optional[str] = Query(
|
||||
None,
|
||||
|
||||
@@ -6,11 +6,17 @@ from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import SandboxType
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.environment_variables import (
|
||||
SandboxEnvironmentVariable as PydanticEnvVar,
|
||||
SandboxEnvironmentVariableCreate,
|
||||
SandboxEnvironmentVariableUpdate,
|
||||
)
|
||||
from letta.schemas.sandbox_config import (
|
||||
LocalSandboxConfig,
|
||||
SandboxConfig as PydanticSandboxConfig,
|
||||
SandboxConfigCreate,
|
||||
SandboxConfigUpdate,
|
||||
)
|
||||
from letta.server.rest_api.utils import get_letta_server, get_user_id
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.helpers.tool_execution_helper import create_venv_for_local_sandbox, install_pip_requirements_for_sandbox
|
||||
|
||||
@@ -749,8 +749,8 @@ async def connect_mcp_server(
|
||||
except ConnectionError:
|
||||
# TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
|
||||
if isinstance(client, AsyncStdioMCPClient):
|
||||
logger.warning(f"OAuth not supported for stdio")
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"OAuth not supported for stdio")
|
||||
logger.warning("OAuth not supported for stdio")
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
|
||||
return
|
||||
# Continue to OAuth flow
|
||||
logger.info(f"Attempting OAuth flow for {request}...")
|
||||
|
||||
@@ -185,7 +185,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
try:
|
||||
await asyncio.shield(self._protected_stream_response(send))
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Stream response was cancelled, but shielded task should continue")
|
||||
logger.info("Stream response was cancelled, but shielded task should continue")
|
||||
except anyio.ClosedResourceError:
|
||||
logger.info("Client disconnected, but shielded task should continue")
|
||||
self._client_connected = False
|
||||
|
||||
@@ -7,8 +7,7 @@ from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional
|
||||
|
||||
from fastapi import Header, HTTPException
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
|
||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -30,8 +30,10 @@ from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
|
||||
# TODO use custom interface
|
||||
from letta.interface import AgentInterface # abstract
|
||||
from letta.interface import CLIInterface # for printing to terminal
|
||||
from letta.interface import (
|
||||
AgentInterface, # abstract
|
||||
CLIInterface, # for printing to terminal
|
||||
)
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
|
||||
@@ -26,37 +26,42 @@ from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm import AgentsTags, ArchivalPassage
|
||||
from letta.orm import Block as BlockModel
|
||||
from letta.orm import BlocksAgents
|
||||
from letta.orm import Group as GroupModel
|
||||
from letta.orm import GroupsAgents, IdentitiesAgents
|
||||
from letta.orm import Source as SourceModel
|
||||
from letta.orm import SourcePassage, SourcesAgents
|
||||
from letta.orm import Tool as ToolModel
|
||||
from letta.orm import ToolsAgents
|
||||
from letta.orm import (
|
||||
Agent as AgentModel,
|
||||
AgentsTags,
|
||||
ArchivalPassage,
|
||||
Block as BlockModel,
|
||||
BlocksAgents,
|
||||
Group as GroupModel,
|
||||
GroupsAgents,
|
||||
IdentitiesAgents,
|
||||
Source as SourceModel,
|
||||
SourcePassage,
|
||||
SourcesAgents,
|
||||
Tool as ToolModel,
|
||||
ToolsAgents,
|
||||
)
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable, AgentEnvironmentVariable as AgentEnvironmentVariableModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.prompts.prompt_generator import PromptGenerator
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.agent import AgentType, CreateAgent, InternalTemplateAgentCreate, UpdateAgent, get_prompt_template_for_agent_type
|
||||
from letta.schemas.block import DEFAULT_BLOCKS
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.agent import (
|
||||
AgentState as PydanticAgentState,
|
||||
AgentType,
|
||||
CreateAgent,
|
||||
InternalTemplateAgentCreate,
|
||||
UpdateAgent,
|
||||
get_prompt_template_for_agent_type,
|
||||
)
|
||||
from letta.schemas.block import DEFAULT_BLOCKS, Block as PydanticBlock, BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderType, TagMatchMode, ToolType, VectorDBProvider
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.group import Group as PydanticGroup
|
||||
from letta.schemas.group import ManagerType
|
||||
from letta.schemas.group import Group as PydanticGroup, ManagerType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
from letta.schemas.message import Message, Message as PydanticMessage, MessageCreate, MessageUpdate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
@@ -493,7 +498,6 @@ class AgentManager:
|
||||
# blocks
|
||||
block_ids = list(agent_create.block_ids or [])
|
||||
if agent_create.memory_blocks:
|
||||
|
||||
pydantic_blocks = [PydanticBlock(**b.model_dump(to_orm=True)) for b in agent_create.memory_blocks]
|
||||
|
||||
# Inject a description for the default blocks if the user didn't specify them
|
||||
@@ -798,7 +802,6 @@ class AgentManager:
|
||||
agent_update: UpdateAgent,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticAgentState:
|
||||
|
||||
new_tools = set(agent_update.tool_ids or [])
|
||||
new_sources = set(agent_update.source_ids or [])
|
||||
new_blocks = set(agent_update.block_ids or [])
|
||||
@@ -806,7 +809,6 @@ class AgentManager:
|
||||
new_tags = set(agent_update.tags or [])
|
||||
|
||||
with db_registry.session() as session, session.begin():
|
||||
|
||||
agent: AgentModel = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
agent.updated_at = datetime.now(timezone.utc)
|
||||
agent.last_updated_by_id = actor.id
|
||||
@@ -923,7 +925,6 @@ class AgentManager:
|
||||
agent_update: UpdateAgent,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticAgentState:
|
||||
|
||||
new_tools = set(agent_update.tool_ids or [])
|
||||
new_sources = set(agent_update.source_ids or [])
|
||||
new_blocks = set(agent_update.block_ids or [])
|
||||
@@ -931,7 +932,6 @@ class AgentManager:
|
||||
new_tags = set(agent_update.tags or [])
|
||||
|
||||
async with db_registry.async_session() as session, session.begin():
|
||||
|
||||
agent: AgentModel = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
agent.updated_at = datetime.now(timezone.utc)
|
||||
agent.last_updated_by_id = actor.id
|
||||
|
||||
@@ -4,9 +4,7 @@ from sqlalchemy import select
|
||||
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
from letta.log import get_logger
|
||||
from letta.orm import ArchivalPassage
|
||||
from letta.orm import Archive as ArchiveModel
|
||||
from letta.orm import ArchivesAgents
|
||||
from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents
|
||||
from letta.schemas.archive import Archive as PydanticArchive
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
|
||||
@@ -13,8 +13,7 @@ from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.block import Block as PydanticBlock, BlockUpdate
|
||||
from letta.schemas.enums import ActorType
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
|
||||
@@ -50,7 +50,8 @@ class AnthropicTokenCounter(TokenCounter):
|
||||
|
||||
@trace_method
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
key_func=lambda self,
|
||||
messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
@@ -61,7 +62,8 @@ class AnthropicTokenCounter(TokenCounter):
|
||||
|
||||
@trace_method
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
key_func=lambda self,
|
||||
tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
@@ -93,7 +95,8 @@ class TiktokenCounter(TokenCounter):
|
||||
|
||||
@trace_method
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
key_func=lambda self,
|
||||
messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
@@ -106,7 +109,8 @@ class TiktokenCounter(TokenCounter):
|
||||
|
||||
@trace_method
|
||||
@async_redis_cache(
|
||||
key_func=lambda self, tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
key_func=lambda self,
|
||||
tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
||||
prefix="token_counter",
|
||||
ttl_s=3600, # cache for 1 hour
|
||||
)
|
||||
|
||||
@@ -12,8 +12,7 @@ from letta.constants import MAX_FILENAME_LENGTH
|
||||
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.file import FileContent as FileContentModel
|
||||
from letta.orm.file import FileMetadata as FileMetadataModel
|
||||
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import FileProcessingStatus
|
||||
@@ -60,7 +59,6 @@ class FileManager:
|
||||
*,
|
||||
text: Optional[str] = None,
|
||||
) -> PydanticFileMetadata:
|
||||
|
||||
# short-circuit if it already exists
|
||||
existing = await self.get_file_by_id(file_metadata.id, actor=actor)
|
||||
if existing:
|
||||
|
||||
@@ -7,10 +7,8 @@ from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.files_agents import FileAgent as FileAgentModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import FileBlock as PydanticFileBlock
|
||||
from letta.schemas.file import FileAgent as PydanticFileAgent
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.block import Block as PydanticBlock, FileBlock as PydanticFileBlock
|
||||
from letta.schemas.file import FileAgent as PydanticFileAgent, FileMetadata
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types
|
||||
|
||||
@@ -8,8 +8,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.orm.group import Group as GroupModel
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.group import Group as PydanticGroup
|
||||
from letta.schemas.group import GroupCreate, GroupUpdate, InternalTemplateGroupCreate, ManagerType
|
||||
from letta.schemas.group import Group as PydanticGroup, GroupCreate, GroupUpdate, InternalTemplateGroupCreate, ManagerType
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@@ -18,7 +17,6 @@ from letta.utils import enforce_types
|
||||
|
||||
|
||||
class GroupManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_groups_async(
|
||||
|
||||
@@ -464,7 +464,6 @@ def package_initial_message_sequence(
|
||||
# create the agent object
|
||||
init_messages = []
|
||||
for message_create in initial_message_sequence:
|
||||
|
||||
if message_create.role == MessageRole.user:
|
||||
packed_message = system.package_user_message(
|
||||
user_message=message_create.content,
|
||||
@@ -498,8 +497,10 @@ def package_initial_message_sequence(
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall as OpenAIToolCall,
|
||||
Function as OpenAIFunction,
|
||||
)
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL
|
||||
|
||||
|
||||
@@ -9,8 +9,14 @@ from letta.orm.block import Block as BlockModel
|
||||
from letta.orm.errors import UniqueConstraintViolationError
|
||||
from letta.orm.identity import Identity as IdentityModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.identity import Identity as PydanticIdentity
|
||||
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityType, IdentityUpdate, IdentityUpsert
|
||||
from letta.schemas.identity import (
|
||||
Identity as PydanticIdentity,
|
||||
IdentityCreate,
|
||||
IdentityProperty,
|
||||
IdentityType,
|
||||
IdentityUpdate,
|
||||
IdentityUpsert,
|
||||
)
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
@@ -18,7 +24,6 @@ from letta.utils import enforce_types
|
||||
|
||||
|
||||
class IdentityManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_identities_async(
|
||||
|
||||
@@ -13,13 +13,10 @@ from letta.orm.job import Job as JobModel
|
||||
from letta.orm.job_messages import JobMessage
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.step import Step
|
||||
from letta.orm.step import Step as StepModel
|
||||
from letta.orm.step import Step, Step as StepModel
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.enums import JobStatus, JobType, MessageRole
|
||||
from letta.schemas.job import BatchJob as PydanticBatchJob
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.job import BatchJob as PydanticBatchJob, Job as PydanticJob, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.run import Run as PydanticRun
|
||||
|
||||
@@ -11,9 +11,7 @@ from letta.orm.llm_batch_items import LLMBatchItem
|
||||
from letta.orm.llm_batch_job import LLMBatchJob
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
|
||||
from letta.schemas.llm_batch_job import AgentStepState
|
||||
from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
|
||||
from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
|
||||
from letta.schemas.llm_batch_job import AgentStepState, LLMBatchItem as PydanticLLMBatchItem, LLMBatchJob as PydanticLLMBatchJob
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp import Tool as MCPTool
|
||||
from mcp import ClientSession, Tool as MCPTool
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
from mcp.types import TextContent
|
||||
|
||||
|
||||
@@ -33,8 +33,7 @@ from letta.schemas.mcp import (
|
||||
UpdateStdioMCPServer,
|
||||
UpdateStreamableHTTPMCPServer,
|
||||
)
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.schemas.tool import Tool as PydanticTool, ToolCreate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
|
||||
@@ -137,8 +136,7 @@ class MCPManager:
|
||||
if mcp_tool.health:
|
||||
if mcp_tool.health.status == "INVALID":
|
||||
raise ValueError(
|
||||
f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid."
|
||||
f"Reasons: {', '.join(mcp_tool.health.reasons)}"
|
||||
f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid.Reasons: {', '.join(mcp_tool.health.reasons)}"
|
||||
)
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
@@ -305,7 +303,9 @@ class MCPManager:
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
mcp_servers = await MCPServerModel.list_async(
|
||||
db_session=session, organization_id=actor.organization_id, id=mcp_server_ids # This will use the IN operator
|
||||
db_session=session,
|
||||
organization_id=actor.organization_id,
|
||||
id=mcp_server_ids, # This will use the IN operator
|
||||
)
|
||||
return [mcp_server.to_pydantic() for mcp_server in mcp_servers]
|
||||
|
||||
@@ -407,7 +407,6 @@ class MCPManager:
|
||||
# with the value being the schema from StdioServerParameters
|
||||
if MCP_CONFIG_TOPLEVEL_KEY in mcp_config:
|
||||
for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items():
|
||||
|
||||
# No support for duplicate server names
|
||||
if server_name in mcp_server_list:
|
||||
# Duplicate server names are configuration issues, not system errors
|
||||
|
||||
@@ -12,8 +12,7 @@ from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import LettaMessageUpdateUnion
|
||||
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageUpdate
|
||||
from letta.schemas.message import Message as PydanticMessage, MessageUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.file_manager import FileManager
|
||||
@@ -185,9 +184,9 @@ class MessageManager:
|
||||
# modify the tool call for send_message
|
||||
# TODO: fix this if we add parallel tool calls
|
||||
# TODO: note this only works if the AssistantMessage is generated by the standard send_message
|
||||
assert (
|
||||
message.tool_calls[0].function.name == "send_message"
|
||||
), f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
|
||||
assert message.tool_calls[0].function.name == "send_message", (
|
||||
f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
|
||||
)
|
||||
original_args = json.loads(message.tool_calls[0].function.arguments)
|
||||
original_args["message"] = letta_message_update.content # override the assistant message
|
||||
update_tool_call = message.tool_calls[0].__deepcopy__()
|
||||
@@ -224,9 +223,9 @@ class MessageManager:
|
||||
# modify the tool call for send_message
|
||||
# TODO: fix this if we add parallel tool calls
|
||||
# TODO: note this only works if the AssistantMessage is generated by the standard send_message
|
||||
assert (
|
||||
message.tool_calls[0].function.name == "send_message"
|
||||
), f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
|
||||
assert message.tool_calls[0].function.name == "send_message", (
|
||||
f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
|
||||
)
|
||||
original_args = json.loads(message.tool_calls[0].function.arguments)
|
||||
original_args["message"] = letta_message_update.content # override the assistant message
|
||||
update_tool_call = message.tool_calls[0].__deepcopy__()
|
||||
|
||||
@@ -4,8 +4,7 @@ from letta.constants import DEFAULT_ORG_ID, DEFAULT_ORG_NAME
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.organization import OrganizationUpdate
|
||||
from letta.schemas.organization import Organization as PydanticOrganization, OrganizationUpdate
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
@@ -3,15 +3,13 @@ from typing import List, Optional, Tuple, Union
|
||||
from letta.orm.provider import Provider as ProviderModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
from letta.schemas.providers import Provider as PydanticProvider
|
||||
from letta.schemas.providers import ProviderCheck, ProviderCreate, ProviderUpdate
|
||||
from letta.schemas.providers import Provider as PydanticProvider, ProviderCheck, ProviderCreate, ProviderUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider:
|
||||
|
||||
@@ -3,15 +3,20 @@ from typing import Dict, List, Optional
|
||||
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel
|
||||
from letta.orm.sandbox_config import SandboxEnvironmentVariable as SandboxEnvVarModel
|
||||
from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel, SandboxEnvironmentVariable as SandboxEnvVarModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import SandboxType
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.environment_variables import (
|
||||
SandboxEnvironmentVariable as PydanticEnvVar,
|
||||
SandboxEnvironmentVariableCreate,
|
||||
SandboxEnvironmentVariableUpdate,
|
||||
)
|
||||
from letta.schemas.sandbox_config import (
|
||||
LocalSandboxConfig,
|
||||
SandboxConfig as PydanticSandboxConfig,
|
||||
SandboxConfigCreate,
|
||||
SandboxConfigUpdate,
|
||||
)
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types, printd
|
||||
|
||||
@@ -9,8 +9,7 @@ from letta.orm.source import Source as SourceModel
|
||||
from letta.orm.sources_agents import SourcesAgents
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.source import SourceUpdate
|
||||
from letta.schemas.source import Source as PydanticSource, SourceUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types, printd
|
||||
|
||||
@@ -29,7 +29,6 @@ class FeedbackType(str, Enum):
|
||||
|
||||
|
||||
class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_steps_async(
|
||||
|
||||
@@ -137,7 +137,7 @@ class Summarizer:
|
||||
total_message_count = len(all_in_context_messages)
|
||||
assert self.partial_evict_summarizer_percentage >= 0.0 and self.partial_evict_summarizer_percentage <= 1.0
|
||||
target_message_start = round((1.0 - self.partial_evict_summarizer_percentage) * total_message_count)
|
||||
logger.info(f"Target message count: {total_message_count}->{(total_message_count-target_message_start)}")
|
||||
logger.info(f"Target message count: {total_message_count}->{(total_message_count - target_message_start)}")
|
||||
|
||||
# The summary message we'll insert is role 'user' (vs 'assistant', 'tool', or 'system')
|
||||
# We are going to put it at index 1 (index 0 is the system message)
|
||||
|
||||
@@ -2,8 +2,7 @@ from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.singleton import singleton
|
||||
from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace
|
||||
from letta.schemas.provider_trace import ProviderTraceCreate
|
||||
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace, ProviderTraceCreate
|
||||
from letta.schemas.step import Step as PydanticStep
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
@@ -11,7 +10,6 @@ from letta.utils import enforce_types
|
||||
|
||||
|
||||
class TelemetryManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_provider_trace_by_step_id_async(
|
||||
|
||||
@@ -309,7 +309,7 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
|
||||
# Create numbered markdown for the LLM to reference
|
||||
numbered_lines = markdown_content.split("\n")
|
||||
numbered_markdown = "\n".join([f"{i+1:4d}: {line}" for i, line in enumerate(numbered_lines)])
|
||||
numbered_markdown = "\n".join([f"{i + 1:4d}: {line}" for i, line in enumerate(numbered_lines)])
|
||||
|
||||
# Truncate if too long
|
||||
max_content_length = 200000
|
||||
|
||||
@@ -273,14 +273,13 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
occurences = current_value.count(old_str)
|
||||
if occurences == 0:
|
||||
raise ValueError(
|
||||
f"No replacement was performed, old_str `{old_str}` did not appear " f"verbatim in memory block with label `{label}`."
|
||||
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in memory block with label `{label}`."
|
||||
)
|
||||
elif occurences > 1:
|
||||
content_value_lines = current_value.split("\n")
|
||||
lines = [idx + 1 for idx, line in enumerate(content_value_lines) if old_str in line]
|
||||
raise ValueError(
|
||||
f"No replacement was performed. Multiple occurrences of "
|
||||
f"old_str `{old_str}` in lines {lines}. Please ensure it is unique."
|
||||
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique."
|
||||
)
|
||||
|
||||
# Replace old_str with new_str
|
||||
|
||||
@@ -568,7 +568,7 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor)
|
||||
source_ids = [source.id for source in attached_sources]
|
||||
if not source_ids:
|
||||
return f"No valid source IDs found for attached files"
|
||||
return "No valid source IDs found for attached files"
|
||||
|
||||
# Get all attached files for this agent
|
||||
file_agents = await self.files_agents_manager.list_files_for_agent(
|
||||
|
||||
@@ -25,7 +25,6 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
|
||||
pass
|
||||
|
||||
mcp_server_tag = [tag for tag in tool.tags if tag.startswith(f"{MCP_TOOL_TAG_NAME_PREFIX}:")]
|
||||
|
||||
@@ -34,7 +34,6 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
|
||||
# Store original memory state
|
||||
if agent_state:
|
||||
orig_memory_str = await agent_state.memory.compile_in_thread_async()
|
||||
|
||||
@@ -100,7 +100,7 @@ class ToolExecutionSandbox:
|
||||
logger.debug(f"Executed tool '{self.tool_name}', logging output from tool run: \n")
|
||||
for log_line in (result.stdout or []) + (result.stderr or []):
|
||||
logger.debug(f"{log_line}")
|
||||
logger.debug(f"Ending output log from tool run.")
|
||||
logger.debug("Ending output log from tool run.")
|
||||
|
||||
# Return result
|
||||
return result
|
||||
@@ -267,7 +267,6 @@ class ToolExecutionSandbox:
|
||||
|
||||
try:
|
||||
with self.temporary_env_vars(env):
|
||||
|
||||
# Read and compile the Python script
|
||||
with open(temp_file_path, "r", encoding="utf-8") as f:
|
||||
source = f.read()
|
||||
@@ -475,7 +474,7 @@ class ToolExecutionSandbox:
|
||||
return None, None
|
||||
result = pickle.loads(base64.b64decode(text))
|
||||
agent_state = None
|
||||
if not result["agent_state"] is None:
|
||||
if result["agent_state"] is not None:
|
||||
agent_state = result["agent_state"]
|
||||
return result["results"], agent_state
|
||||
|
||||
|
||||
@@ -27,8 +27,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.orm.tool import Tool as ToolModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import ToolType
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.helpers.agent_manager_helper import calculate_multi_agent_tools
|
||||
|
||||
@@ -183,9 +183,9 @@ class ModalDeploymentManager:
|
||||
existing_app = await self._try_get_existing_app(sbx_config, version_hash, user)
|
||||
if existing_app:
|
||||
return existing_app, version_hash
|
||||
raise RuntimeError(f"Deployment completed but app not found")
|
||||
raise RuntimeError("Deployment completed but app not found")
|
||||
else:
|
||||
raise RuntimeError(f"Timeout waiting for deployment")
|
||||
raise RuntimeError("Timeout waiting for deployment")
|
||||
|
||||
# We're deploying - mark as in progress
|
||||
deployment_key = None
|
||||
|
||||
@@ -10,8 +10,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.orm.user import User as UserModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import UserUpdate
|
||||
from letta.schemas.user import User as PydanticUser, UserUpdate
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
@@ -89,7 +89,6 @@ class SummarizerSettings(BaseSettings):
|
||||
|
||||
|
||||
class ModelSettings(BaseSettings):
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
|
||||
|
||||
global_max_context_window_limit: int = 32000
|
||||
|
||||
@@ -117,9 +117,9 @@ class StreamingCLIInterface(AgentChunkStreamingInterface):
|
||||
|
||||
# Starting a new buffer line
|
||||
if not self.streaming_buffer_type:
|
||||
assert not (
|
||||
message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls)
|
||||
), f"Error: got both content and tool_calls in message stream\n{message_delta}"
|
||||
assert not (message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls)), (
|
||||
f"Error: got both content and tool_calls in message stream\n{message_delta}"
|
||||
)
|
||||
|
||||
if message_delta.content is not None:
|
||||
# Write out the prefix for inner thoughts
|
||||
|
||||
@@ -187,7 +187,7 @@ def package_summarize_message(summary, summary_message_count, hidden_message_cou
|
||||
|
||||
def package_summarize_message_no_counts(summary, timezone):
|
||||
context_message = (
|
||||
f"Note: prior messages have been hidden from view due to conversation memory constraints.\n"
|
||||
"Note: prior messages have been hidden from view due to conversation memory constraints.\n"
|
||||
+ f"The following is a summary of the previous messages:\n {summary}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1149,7 +1149,6 @@ class CancellationSignal:
|
||||
"""
|
||||
|
||||
def __init__(self, job_manager=None, job_id=None, actor=None):
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.user import User
|
||||
from letta.services.job_manager import JobManager
|
||||
|
||||
@@ -83,7 +83,7 @@ class CallbackHandler(BaseHTTPRequestHandler):
|
||||
<html>
|
||||
<body>
|
||||
<h1>Authorization Failed</h1>
|
||||
<p>Error: {query_params['error'][0]}</p>
|
||||
<p>Error: {query_params["error"][0]}</p>
|
||||
<p>You can close this window and return to the terminal.</p>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -100,7 +100,7 @@ def generate_docqa_baseline_response(
|
||||
# print(f"Top {num_documents} documents: {documents_search_results_sorted_by_relevance}")
|
||||
|
||||
# compute truncation length
|
||||
extra_text = BASELINE_PROMPT + f"Question: {question}" + f"Answer:"
|
||||
extra_text = BASELINE_PROMPT + f"Question: {question}" + "Answer:"
|
||||
padding = count_tokens(extra_text) + 1000
|
||||
truncation_length = int((config.default_llm_config.context_window - padding) / num_documents)
|
||||
print("Token size", config.default_llm_config.context_window)
|
||||
@@ -114,7 +114,7 @@ def generate_docqa_baseline_response(
|
||||
if i >= num_documents:
|
||||
break
|
||||
|
||||
doc_prompt = f"Document [{i+1}]: {doc} \n"
|
||||
doc_prompt = f"Document [{i + 1}]: {doc} \n"
|
||||
|
||||
# truncate (that's why the performance goes down as x-axis increases)
|
||||
if truncation_length is not None:
|
||||
|
||||
@@ -124,7 +124,7 @@ if __name__ == "__main__":
|
||||
if a in response:
|
||||
found = True
|
||||
|
||||
if not found and not "INSUFFICIENT INFORMATION" in response:
|
||||
if not found and "INSUFFICIENT INFORMATION" not in response:
|
||||
# inconclusive: pass to llm judge
|
||||
print(question)
|
||||
print(answer)
|
||||
|
||||
@@ -61,13 +61,13 @@ def archival_memory_text_search(self, query: str, page: Optional[int] = 0) -> Op
|
||||
try:
|
||||
page = int(page)
|
||||
except:
|
||||
raise ValueError(f"'page' argument must be an integer")
|
||||
raise ValueError("'page' argument must be an integer")
|
||||
count = 10
|
||||
results = self.persistence_manager.archival_memory.storage.query_text(query, limit=count, offset=page * count)
|
||||
total = len(results)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
results_str = "No results found."
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"memory: {d.text}" for d in results]
|
||||
@@ -253,8 +253,8 @@ if __name__ == "__main__":
|
||||
# overwrite
|
||||
kv_d[current_key] = next_key
|
||||
|
||||
print(f"Nested {i+1}")
|
||||
print(f"Done")
|
||||
print(f"Nested {i + 1}")
|
||||
print("Done")
|
||||
|
||||
def get_nested_key(original_key, kv_d):
|
||||
key = original_key
|
||||
|
||||
@@ -67,6 +67,7 @@ dependencies = [
|
||||
"certifi>=2025.6.15",
|
||||
"markitdown[docx,pdf,pptx]>=0.1.2",
|
||||
"orjson>=3.11.1",
|
||||
"ruff[dev]>=0.12.10",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -108,11 +109,8 @@ dev = [
|
||||
"pytest-mock>=3.14.0",
|
||||
"pytest-json-report>=1.5.0",
|
||||
"pexpect>=4.9.0",
|
||||
"black[jupyter]>=24.4.2",
|
||||
"pre-commit>=3.5.0",
|
||||
"pyright>=1.1.347",
|
||||
"autoflake>=2.3.0",
|
||||
"isort>=5.13.2",
|
||||
"ipykernel>=6.29.5",
|
||||
"ipdb>=0.13.13",
|
||||
]
|
||||
@@ -149,18 +147,46 @@ build-backend = "hatchling.build"
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["letta"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 140
|
||||
target-version = ['py310', 'py311', 'py312', 'py313']
|
||||
extend-exclude = "examples/*"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 140
|
||||
multi_line_output = 3
|
||||
include_trailing_comma = true
|
||||
force_grid_wrap = 0
|
||||
use_parentheses = true
|
||||
[tool.ruff]
|
||||
line-length = 140
|
||||
target-version = "py312"
|
||||
extend-exclude = [
|
||||
"examples/*",
|
||||
"tests/data/*",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
"E402", # module import not at top of file
|
||||
"E711", # none-comparison
|
||||
"E712", # true-false-comparison
|
||||
"E722", # bare except
|
||||
"E721", # type comparison
|
||||
"F401", # unused import
|
||||
"F821", # undefined name
|
||||
"F811", # redefined while unused
|
||||
"F841", # local variable assigned but never used
|
||||
"W293", # blank line contains whitespace
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
force-single-line = false
|
||||
combine-as-imports = true
|
||||
split-on-trailing-comma = true
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
indent-style = "space"
|
||||
skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
@@ -143,7 +143,7 @@ def assert_invoked_send_message_with_keyword(messages: Sequence[LettaMessage], k
|
||||
# Message field not in send_message
|
||||
if "message" not in arguments:
|
||||
raise InvalidToolCallError(
|
||||
messages=[target_message], explanation=f"send_message function call does not have required field `message`"
|
||||
messages=[target_message], explanation="send_message function call does not have required field `message`"
|
||||
)
|
||||
|
||||
# Check that the keyword is in the message arguments
|
||||
@@ -151,7 +151,7 @@ def assert_invoked_send_message_with_keyword(messages: Sequence[LettaMessage], k
|
||||
keyword = keyword.lower()
|
||||
arguments["message"] = arguments["message"].lower()
|
||||
|
||||
if not keyword in arguments["message"]:
|
||||
if keyword not in arguments["message"]:
|
||||
raise InvalidToolCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")
|
||||
|
||||
|
||||
|
||||
@@ -12,8 +12,7 @@ from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.file import FileAgent
|
||||
from letta.schemas.memory import ContextWindowOverview
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import User, User as PydanticUser
|
||||
from letta.server.rest_api.routers.v1.agents import ImportedAgentsResponse
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
@@ -66,7 +65,6 @@ def retry_until_success(max_attempts=10, sleep_time_seconds=4):
|
||||
def decorator_retry(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
@@ -124,32 +122,32 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up
|
||||
assert agent.llm_config == request.llm_config, f"LLM config mismatch: {agent.llm_config} != {request.llm_config}"
|
||||
|
||||
# Assert embedding configuration
|
||||
assert (
|
||||
agent.embedding_config == request.embedding_config
|
||||
), f"Embedding config mismatch: {agent.embedding_config} != {request.embedding_config}"
|
||||
assert agent.embedding_config == request.embedding_config, (
|
||||
f"Embedding config mismatch: {agent.embedding_config} != {request.embedding_config}"
|
||||
)
|
||||
|
||||
# Assert memory blocks
|
||||
if hasattr(request, "memory_blocks"):
|
||||
assert len(agent.memory.blocks) == len(request.memory_blocks) + len(
|
||||
request.block_ids
|
||||
), f"Memory blocks count mismatch: {len(agent.memory.blocks)} != {len(request.memory_blocks) + len(request.block_ids)}"
|
||||
assert len(agent.memory.blocks) == len(request.memory_blocks) + len(request.block_ids), (
|
||||
f"Memory blocks count mismatch: {len(agent.memory.blocks)} != {len(request.memory_blocks) + len(request.block_ids)}"
|
||||
)
|
||||
memory_block_values = {block.value for block in agent.memory.blocks}
|
||||
expected_block_values = {block.value for block in request.memory_blocks}
|
||||
assert expected_block_values.issubset(
|
||||
memory_block_values
|
||||
), f"Memory blocks mismatch: {expected_block_values} not in {memory_block_values}"
|
||||
assert expected_block_values.issubset(memory_block_values), (
|
||||
f"Memory blocks mismatch: {expected_block_values} not in {memory_block_values}"
|
||||
)
|
||||
|
||||
# Assert tools
|
||||
assert len(agent.tools) == len(request.tool_ids), f"Tools count mismatch: {len(agent.tools)} != {len(request.tool_ids)}"
|
||||
assert {tool.id for tool in agent.tools} == set(
|
||||
request.tool_ids
|
||||
), f"Tools mismatch: {set(tool.id for tool in agent.tools)} != {set(request.tool_ids)}"
|
||||
assert {tool.id for tool in agent.tools} == set(request.tool_ids), (
|
||||
f"Tools mismatch: {set(tool.id for tool in agent.tools)} != {set(request.tool_ids)}"
|
||||
)
|
||||
|
||||
# Assert sources
|
||||
assert len(agent.sources) == len(request.source_ids), f"Sources count mismatch: {len(agent.sources)} != {len(request.source_ids)}"
|
||||
assert {source.id for source in agent.sources} == set(
|
||||
request.source_ids
|
||||
), f"Sources mismatch: {set(source.id for source in agent.sources)} != {set(request.source_ids)}"
|
||||
assert {source.id for source in agent.sources} == set(request.source_ids), (
|
||||
f"Sources mismatch: {set(source.id for source in agent.sources)} != {set(request.source_ids)}"
|
||||
)
|
||||
|
||||
# Assert tags
|
||||
assert set(agent.tags) == set(request.tags), f"Tags mismatch: {set(agent.tags)} != {set(request.tags)}"
|
||||
@@ -158,15 +156,15 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up
|
||||
print("TOOLRULES", request.tool_rules)
|
||||
print("AGENTTOOLRULES", agent.tool_rules)
|
||||
if request.tool_rules:
|
||||
assert len(agent.tool_rules) == len(
|
||||
request.tool_rules
|
||||
), f"Tool rules count mismatch: {len(agent.tool_rules)} != {len(request.tool_rules)}"
|
||||
assert all(
|
||||
any(rule.tool_name == req_rule.tool_name for rule in agent.tool_rules) for req_rule in request.tool_rules
|
||||
), f"Tool rules mismatch: {agent.tool_rules} != {request.tool_rules}"
|
||||
assert len(agent.tool_rules) == len(request.tool_rules), (
|
||||
f"Tool rules count mismatch: {len(agent.tool_rules)} != {len(request.tool_rules)}"
|
||||
)
|
||||
assert all(any(rule.tool_name == req_rule.tool_name for rule in agent.tool_rules) for req_rule in request.tool_rules), (
|
||||
f"Tool rules mismatch: {agent.tool_rules} != {request.tool_rules}"
|
||||
)
|
||||
|
||||
# Assert message_buffer_autoclear
|
||||
if not request.message_buffer_autoclear is None:
|
||||
if request.message_buffer_autoclear is not None:
|
||||
assert agent.message_buffer_autoclear == request.message_buffer_autoclear
|
||||
|
||||
|
||||
@@ -176,9 +174,9 @@ def validate_context_window_overview(
|
||||
"""Validate common sense assertions for ContextWindowOverview"""
|
||||
|
||||
# 1. Current context size should not exceed maximum
|
||||
assert (
|
||||
overview.context_window_size_current <= overview.context_window_size_max
|
||||
), f"Current context size ({overview.context_window_size_current}) exceeds maximum ({overview.context_window_size_max})"
|
||||
assert overview.context_window_size_current <= overview.context_window_size_max, (
|
||||
f"Current context size ({overview.context_window_size_current}) exceeds maximum ({overview.context_window_size_max})"
|
||||
)
|
||||
|
||||
# 2. All token counts should be non-negative
|
||||
assert overview.num_tokens_system >= 0, "System token count cannot be negative"
|
||||
@@ -197,14 +195,14 @@ def validate_context_window_overview(
|
||||
+ overview.num_tokens_messages
|
||||
+ overview.num_tokens_functions_definitions
|
||||
)
|
||||
assert (
|
||||
overview.context_window_size_current == expected_total
|
||||
), f"Token sum ({expected_total}) doesn't match current size ({overview.context_window_size_current})"
|
||||
assert overview.context_window_size_current == expected_total, (
|
||||
f"Token sum ({expected_total}) doesn't match current size ({overview.context_window_size_current})"
|
||||
)
|
||||
|
||||
# 4. Message count should match messages list length
|
||||
assert (
|
||||
len(overview.messages) == overview.num_messages
|
||||
), f"Messages list length ({len(overview.messages)}) doesn't match num_messages ({overview.num_messages})"
|
||||
assert len(overview.messages) == overview.num_messages, (
|
||||
f"Messages list length ({len(overview.messages)}) doesn't match num_messages ({overview.num_messages})"
|
||||
)
|
||||
|
||||
# 5. If summary_memory is None, its token count should be 0
|
||||
if overview.summary_memory is None:
|
||||
|
||||
@@ -141,7 +141,7 @@ def create_test_agent(name, actor, test_id: Optional[str] = None, model="anthrop
|
||||
model_endpoint_type="anthropic",
|
||||
model_endpoint="https://api.anthropic.com/v1",
|
||||
context_window=32000,
|
||||
handle=f"anthropic/claude-3-7-sonnet-latest",
|
||||
handle="anthropic/claude-3-7-sonnet-latest",
|
||||
put_inner_thoughts_in_kwargs=True,
|
||||
max_tokens=4096,
|
||||
)
|
||||
@@ -193,7 +193,7 @@ async def create_test_batch_item(server, batch_id, agent_id, default_user):
|
||||
model_endpoint_type="anthropic",
|
||||
model_endpoint="https://api.anthropic.com/v1",
|
||||
context_window=32000,
|
||||
handle=f"anthropic/claude-3-7-sonnet-latest",
|
||||
handle="anthropic/claude-3-7-sonnet-latest",
|
||||
put_inner_thoughts_in_kwargs=True,
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
@@ -219,7 +219,7 @@ def test_run_code(
|
||||
|
||||
returns = [m.tool_return for m in tool_returns]
|
||||
assert any(expected in ret for ret in returns), (
|
||||
f"For language={language!r}, expected to find '{expected}' in tool_return, " f"but got {returns!r}"
|
||||
f"For language={language!r}, expected to find '{expected}' in tool_return, but got {returns!r}"
|
||||
)
|
||||
|
||||
|
||||
@@ -357,7 +357,6 @@ async def test_web_search_uses_agent_env_var_model():
|
||||
patch.dict(os.environ, {WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"}),
|
||||
patch("firecrawl.AsyncFirecrawlApp") as mock_firecrawl_class,
|
||||
):
|
||||
|
||||
# setup mocks
|
||||
mock_model_settings.openai_api_key = "test-key"
|
||||
|
||||
|
||||
@@ -11,8 +11,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage as OpenAIUserMessage
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.tool_manager import ToolManager
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user