chore: bump v0.11.5 (#2777)
This commit is contained in:
@@ -41,6 +41,15 @@ jobs:
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
redis:
|
||||
image: redis:7
|
||||
ports:
|
||||
- 6379:6379
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 5s
|
||||
--health-timeout 5s
|
||||
--health-retries 10
|
||||
|
||||
steps:
|
||||
# Ensure secrets don't leak
|
||||
@@ -138,6 +147,8 @@ jobs:
|
||||
LETTA_PG_PASSWORD: postgres
|
||||
LETTA_PG_DB: postgres
|
||||
LETTA_PG_HOST: localhost
|
||||
LETTA_REDIS_HOST: localhost
|
||||
LETTA_REDIS_PORT: 6379
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
|
||||
@@ -5,7 +5,7 @@ try:
|
||||
__version__ = version("letta")
|
||||
except PackageNotFoundError:
|
||||
# Fallback for development installations
|
||||
__version__ = "0.11.4"
|
||||
__version__ = "0.11.5"
|
||||
|
||||
if os.environ.get("LETTA_VERSION"):
|
||||
__version__ = os.environ["LETTA_VERSION"]
|
||||
|
||||
@@ -42,6 +42,7 @@ from letta.log import get_logger
|
||||
from letta.memory import summarize_messages
|
||||
from letta.orm import User
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.prompts.prompt_generator import PromptGenerator
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent, get_prompt_template_for_agent_type
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -59,7 +60,7 @@ from letta.schemas.tool_rule import TerminalToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import check_supports_structured_output, compile_memory_metadata_block
|
||||
from letta.services.helpers.agent_manager_helper import check_supports_structured_output
|
||||
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
||||
@@ -330,8 +331,13 @@ class Agent(BaseAgent):
|
||||
return None
|
||||
|
||||
allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
|
||||
# Extract terminal tool names from tool rules
|
||||
terminal_tool_names = {rule.tool_name for rule in self.tool_rules_solver.terminal_tool_rules}
|
||||
allowed_functions = runtime_override_tool_json_schema(
|
||||
tool_list=allowed_functions, response_format=self.agent_state.response_format, request_heartbeat=True
|
||||
tool_list=allowed_functions,
|
||||
response_format=self.agent_state.response_format,
|
||||
request_heartbeat=True,
|
||||
terminal_tools=terminal_tool_names,
|
||||
)
|
||||
|
||||
# For the first message, force the initial tool if one is specified
|
||||
@@ -1246,7 +1252,7 @@ class Agent(BaseAgent):
|
||||
|
||||
agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id)
|
||||
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
||||
external_memory_summary = compile_memory_metadata_block(
|
||||
external_memory_summary = PromptGenerator.compile_memory_metadata_block(
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
timezone=self.agent_state.timezone,
|
||||
previous_message_count=self.message_manager.size(actor=self.user, agent_id=self.agent_state.id),
|
||||
|
||||
@@ -7,6 +7,7 @@ from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
from letta.prompts.prompt_generator import PromptGenerator
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage
|
||||
@@ -17,7 +18,6 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.helpers.agent_manager_helper import get_system_message_from_compiled_memory
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.utils import united_diff
|
||||
@@ -142,7 +142,7 @@ class BaseAgent(ABC):
|
||||
if num_archival_memories is None:
|
||||
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
|
||||
new_system_message_str = get_system_message_from_compiled_memory(
|
||||
new_system_message_str = PromptGenerator.get_system_message_from_compiled_memory(
|
||||
system_prompt=agent_state.system,
|
||||
memory_with_sources=curr_memory_str,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
|
||||
@@ -137,6 +137,10 @@ class LettaAgent(BaseAgent):
|
||||
message_buffer_limit=message_buffer_limit,
|
||||
message_buffer_min=message_buffer_min,
|
||||
partial_evict_summarizer_percentage=partial_evict_summarizer_percentage,
|
||||
agent_manager=self.agent_manager,
|
||||
message_manager=self.message_manager,
|
||||
actor=self.actor,
|
||||
agent_id=self.agent_id,
|
||||
)
|
||||
|
||||
async def _check_run_cancellation(self) -> bool:
|
||||
@@ -345,16 +349,17 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span.end()
|
||||
|
||||
# Log LLM Trace
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
if settings.track_provider_trace:
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
|
||||
# stream step
|
||||
# TODO: improve TTFT
|
||||
@@ -642,17 +647,18 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span.end()
|
||||
|
||||
# Log LLM Trace
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
if settings.track_provider_trace:
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
|
||||
step_progression = StepProgression.FINISHED
|
||||
|
||||
@@ -1003,31 +1009,32 @@ class LettaAgent(BaseAgent):
|
||||
# Log LLM Trace
|
||||
# We are piecing together the streamed response here.
|
||||
# Content here does not match the actual response schema as streams come in chunks.
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json={
|
||||
"content": {
|
||||
"tool_call": tool_call.model_dump_json(),
|
||||
"reasoning": [content.model_dump_json() for content in reasoning_content],
|
||||
if settings.track_provider_trace:
|
||||
await self.telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json={
|
||||
"content": {
|
||||
"tool_call": tool_call.model_dump_json(),
|
||||
"reasoning": [content.model_dump_json() for content in reasoning_content],
|
||||
},
|
||||
"id": interface.message_id,
|
||||
"model": interface.model,
|
||||
"role": "assistant",
|
||||
# "stop_reason": "",
|
||||
# "stop_sequence": None,
|
||||
"type": "message",
|
||||
"usage": {
|
||||
"input_tokens": usage.prompt_tokens,
|
||||
"output_tokens": usage.completion_tokens,
|
||||
},
|
||||
},
|
||||
"id": interface.message_id,
|
||||
"model": interface.model,
|
||||
"role": "assistant",
|
||||
# "stop_reason": "",
|
||||
# "stop_sequence": None,
|
||||
"type": "message",
|
||||
"usage": {
|
||||
"input_tokens": usage.prompt_tokens,
|
||||
"output_tokens": usage.completion_tokens,
|
||||
},
|
||||
},
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
|
||||
# yields tool response as this is handled from Letta and not the response from the LLM provider
|
||||
tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
|
||||
@@ -1352,6 +1359,7 @@ class LettaAgent(BaseAgent):
|
||||
) -> list[Message]:
|
||||
# If total tokens is reached, we truncate down
|
||||
# TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc.
|
||||
# TODO: `force` and `clear` seem to no longer be used, we should remove
|
||||
if force or (total_tokens and total_tokens > llm_config.context_window):
|
||||
self.logger.warning(
|
||||
f"Total tokens {total_tokens} exceeds configured max tokens {llm_config.context_window}, forcefully clearing message history."
|
||||
@@ -1363,6 +1371,7 @@ class LettaAgent(BaseAgent):
|
||||
clear=True,
|
||||
)
|
||||
else:
|
||||
# NOTE (Sarah): Seems like this is doing nothing?
|
||||
self.logger.info(
|
||||
f"Total tokens {total_tokens} does not exceed configured max tokens {llm_config.context_window}, passing summarizing w/o force."
|
||||
)
|
||||
@@ -1453,8 +1462,10 @@ class LettaAgent(BaseAgent):
|
||||
force_tool_call = valid_tool_names[0]
|
||||
|
||||
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
|
||||
# Extract terminal tool names from tool rules
|
||||
terminal_tool_names = {rule.tool_name for rule in tool_rules_solver.terminal_tool_rules}
|
||||
allowed_tools = runtime_override_tool_json_schema(
|
||||
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True
|
||||
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True, terminal_tools=terminal_tool_names
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.tool_execution_helper import add_pre_execution_message, enable_strict_mode, remove_request_heartbeat
|
||||
from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface
|
||||
from letta.log import get_logger
|
||||
from letta.prompts.prompt_generator import PromptGenerator
|
||||
from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.enums import MessageRole, ToolType
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
@@ -35,7 +36,6 @@ from letta.server.rest_api.utils import (
|
||||
)
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message_async
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
@@ -144,7 +144,7 @@ class VoiceAgent(BaseAgent):
|
||||
|
||||
in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=self.actor)
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
in_context_messages[0].content[0].text = await compile_system_message_async(
|
||||
in_context_messages[0].content[0].text = await PromptGenerator.compile_system_message_async(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import Any, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from letta.constants import REDIS_EXCLUDE, REDIS_INCLUDE, REDIS_SET_DEFAULT_VAL
|
||||
from letta.log import get_logger
|
||||
@@ -218,6 +218,126 @@ class AsyncRedisClient:
|
||||
client = await self.get_client()
|
||||
return await client.decr(key)
|
||||
|
||||
# Stream operations
|
||||
@with_retry()
|
||||
async def xadd(self, stream: str, fields: Dict[str, Any], id: str = "*", maxlen: Optional[int] = None, approximate: bool = True) -> str:
|
||||
"""Add entry to a stream.
|
||||
|
||||
Args:
|
||||
stream: Stream name
|
||||
fields: Dict of field-value pairs to add
|
||||
id: Entry ID ('*' for auto-generation)
|
||||
maxlen: Maximum length of the stream
|
||||
approximate: Whether maxlen is approximate
|
||||
|
||||
Returns:
|
||||
The ID of the added entry
|
||||
"""
|
||||
client = await self.get_client()
|
||||
return await client.xadd(stream, fields, id=id, maxlen=maxlen, approximate=approximate)
|
||||
|
||||
@with_retry()
|
||||
async def xread(self, streams: Dict[str, str], count: Optional[int] = None, block: Optional[int] = None) -> List[Dict]:
|
||||
"""Read from streams.
|
||||
|
||||
Args:
|
||||
streams: Dict mapping stream names to IDs
|
||||
count: Maximum number of entries to return
|
||||
block: Milliseconds to block waiting for data (None = no blocking)
|
||||
|
||||
Returns:
|
||||
List of entries from the streams
|
||||
"""
|
||||
client = await self.get_client()
|
||||
return await client.xread(streams, count=count, block=block)
|
||||
|
||||
@with_retry()
|
||||
async def xrange(self, stream: str, start: str = "-", end: str = "+", count: Optional[int] = None) -> List[Dict]:
|
||||
"""Read range of entries from a stream.
|
||||
|
||||
Args:
|
||||
stream: Stream name
|
||||
start: Start ID (inclusive)
|
||||
end: End ID (inclusive)
|
||||
count: Maximum number of entries to return
|
||||
|
||||
Returns:
|
||||
List of entries in the specified range
|
||||
"""
|
||||
client = await self.get_client()
|
||||
return await client.xrange(stream, start, end, count=count)
|
||||
|
||||
@with_retry()
|
||||
async def xrevrange(self, stream: str, start: str = "+", end: str = "-", count: Optional[int] = None) -> List[Dict]:
|
||||
"""Read range of entries from a stream in reverse order.
|
||||
|
||||
Args:
|
||||
stream: Stream name
|
||||
start: Start ID (inclusive)
|
||||
end: End ID (inclusive)
|
||||
count: Maximum number of entries to return
|
||||
|
||||
Returns:
|
||||
List of entries in the specified range in reverse order
|
||||
"""
|
||||
client = await self.get_client()
|
||||
return await client.xrevrange(stream, start, end, count=count)
|
||||
|
||||
@with_retry()
|
||||
async def xlen(self, stream: str) -> int:
|
||||
"""Get the length of a stream.
|
||||
|
||||
Args:
|
||||
stream: Stream name
|
||||
|
||||
Returns:
|
||||
Number of entries in the stream
|
||||
"""
|
||||
client = await self.get_client()
|
||||
return await client.xlen(stream)
|
||||
|
||||
@with_retry()
|
||||
async def xdel(self, stream: str, *ids: str) -> int:
|
||||
"""Delete entries from a stream.
|
||||
|
||||
Args:
|
||||
stream: Stream name
|
||||
ids: IDs of entries to delete
|
||||
|
||||
Returns:
|
||||
Number of entries deleted
|
||||
"""
|
||||
client = await self.get_client()
|
||||
return await client.xdel(stream, *ids)
|
||||
|
||||
@with_retry()
|
||||
async def xinfo_stream(self, stream: str) -> Dict:
|
||||
"""Get information about a stream.
|
||||
|
||||
Args:
|
||||
stream: Stream name
|
||||
|
||||
Returns:
|
||||
Dict with stream information
|
||||
"""
|
||||
client = await self.get_client()
|
||||
return await client.xinfo_stream(stream)
|
||||
|
||||
@with_retry()
|
||||
async def xtrim(self, stream: str, maxlen: int, approximate: bool = True) -> int:
|
||||
"""Trim a stream to a maximum length.
|
||||
|
||||
Args:
|
||||
stream: Stream name
|
||||
maxlen: Maximum length
|
||||
approximate: Whether maxlen is approximate
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
client = await self.get_client()
|
||||
return await client.xtrim(stream, maxlen=maxlen, approximate=approximate)
|
||||
|
||||
async def check_inclusion_and_exclusion(self, member: str, group: str) -> bool:
|
||||
exclude_key = self._get_group_exclusion_key(group)
|
||||
include_key = self._get_group_inclusion_key(group)
|
||||
@@ -290,6 +410,31 @@ class NoopAsyncRedisClient(AsyncRedisClient):
|
||||
async def srem(self, key: str, *members: Union[str, int, float]) -> int:
|
||||
return 0
|
||||
|
||||
# Stream operations
|
||||
async def xadd(self, stream: str, fields: Dict[str, Any], id: str = "*", maxlen: Optional[int] = None, approximate: bool = True) -> str:
|
||||
return ""
|
||||
|
||||
async def xread(self, streams: Dict[str, str], count: Optional[int] = None, block: Optional[int] = None) -> List[Dict]:
|
||||
return []
|
||||
|
||||
async def xrange(self, stream: str, start: str = "-", end: str = "+", count: Optional[int] = None) -> List[Dict]:
|
||||
return []
|
||||
|
||||
async def xrevrange(self, stream: str, start: str = "+", end: str = "-", count: Optional[int] = None) -> List[Dict]:
|
||||
return []
|
||||
|
||||
async def xlen(self, stream: str) -> int:
|
||||
return 0
|
||||
|
||||
async def xdel(self, stream: str, *ids: str) -> int:
|
||||
return 0
|
||||
|
||||
async def xinfo_stream(self, stream: str) -> Dict:
|
||||
return {}
|
||||
|
||||
async def xtrim(self, stream: str, maxlen: int, approximate: bool = True) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
async def get_redis_client() -> AsyncRedisClient:
|
||||
global _client_instance
|
||||
|
||||
@@ -76,6 +76,10 @@ class LettaUserNotFoundError(LettaError):
|
||||
"""Error raised when a user is not found."""
|
||||
|
||||
|
||||
class LettaUnexpectedStreamCancellationError(LettaError):
|
||||
"""Error raised when a streaming request is terminated unexpectedly."""
|
||||
|
||||
|
||||
class LLMError(LettaError):
|
||||
pass
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ async def open_files(agent_state: "AgentState", file_requests: List[FileOpenRequ
|
||||
|
||||
Open multiple files with different view ranges:
|
||||
file_requests = [
|
||||
FileOpenRequest(file_name="project_utils/config.py", offset=1, length=50), # Lines 1-50
|
||||
FileOpenRequest(file_name="project_utils/main.py", offset=100, length=100), # Lines 100-199
|
||||
FileOpenRequest(file_name="project_utils/config.py", offset=0, length=50), # Lines 1-50
|
||||
FileOpenRequest(file_name="project_utils/main.py", offset=100, length=100), # Lines 101-200
|
||||
FileOpenRequest(file_name="project_utils/utils.py") # Entire file
|
||||
]
|
||||
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# class BaseMCPClient:
|
||||
# def __init__(self, server_config: BaseServerConfig):
|
||||
# self.server_config = server_config
|
||||
# self.session: Optional[ClientSession] = None
|
||||
# self.stdio = None
|
||||
# self.write = None
|
||||
# self.initialized = False
|
||||
# self.loop = asyncio.new_event_loop()
|
||||
# self.cleanup_funcs = []
|
||||
#
|
||||
# def connect_to_server(self):
|
||||
# asyncio.set_event_loop(self.loop)
|
||||
# success = self._initialize_connection(self.server_config, timeout=tool_settings.mcp_connect_to_server_timeout)
|
||||
#
|
||||
# if success:
|
||||
# try:
|
||||
# self.loop.run_until_complete(
|
||||
# asyncio.wait_for(self.session.initialize(), timeout=tool_settings.mcp_connect_to_server_timeout)
|
||||
# )
|
||||
# self.initialized = True
|
||||
# except asyncio.TimeoutError:
|
||||
# raise MCPTimeoutError("initializing session", self.server_config.server_name, tool_settings.mcp_connect_to_server_timeout)
|
||||
# else:
|
||||
# raise RuntimeError(
|
||||
# f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}"
|
||||
# )
|
||||
#
|
||||
# def _initialize_connection(self, server_config: BaseServerConfig, timeout: float) -> bool:
|
||||
# raise NotImplementedError("Subclasses must implement _initialize_connection")
|
||||
#
|
||||
# def list_tools(self) -> List[MCPTool]:
|
||||
# self._check_initialized()
|
||||
# try:
|
||||
# response = self.loop.run_until_complete(
|
||||
# asyncio.wait_for(self.session.list_tools(), timeout=tool_settings.mcp_list_tools_timeout)
|
||||
# )
|
||||
# return response.tools
|
||||
# except asyncio.TimeoutError:
|
||||
# logger.error(
|
||||
# f"Timed out while listing tools for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_list_tools_timeout}s)."
|
||||
# )
|
||||
# raise MCPTimeoutError("listing tools", self.server_config.server_name, tool_settings.mcp_list_tools_timeout)
|
||||
#
|
||||
# def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
|
||||
# self._check_initialized()
|
||||
# try:
|
||||
# result = self.loop.run_until_complete(
|
||||
# asyncio.wait_for(self.session.call_tool(tool_name, tool_args), timeout=tool_settings.mcp_execute_tool_timeout)
|
||||
# )
|
||||
#
|
||||
# parsed_content = []
|
||||
# for content_piece in result.content:
|
||||
# if isinstance(content_piece, TextContent):
|
||||
# parsed_content.append(content_piece.text)
|
||||
# print("parsed_content (text)", parsed_content)
|
||||
# else:
|
||||
# parsed_content.append(str(content_piece))
|
||||
# print("parsed_content (other)", parsed_content)
|
||||
#
|
||||
# if len(parsed_content) > 0:
|
||||
# final_content = " ".join(parsed_content)
|
||||
# else:
|
||||
# # TODO move hardcoding to constants
|
||||
# final_content = "Empty response from tool"
|
||||
#
|
||||
# return final_content, result.isError
|
||||
# except asyncio.TimeoutError:
|
||||
# logger.error(
|
||||
# f"Timed out while executing tool '{tool_name}' for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_execute_tool_timeout}s)."
|
||||
# )
|
||||
# raise MCPTimeoutError(f"executing tool '{tool_name}'", self.server_config.server_name, tool_settings.mcp_execute_tool_timeout)
|
||||
#
|
||||
# def _check_initialized(self):
|
||||
# if not self.initialized:
|
||||
# logger.error("MCPClient has not been initialized")
|
||||
# raise RuntimeError("MCPClient has not been initialized")
|
||||
#
|
||||
# def cleanup(self):
|
||||
# try:
|
||||
# for cleanup_func in self.cleanup_funcs:
|
||||
# cleanup_func()
|
||||
# self.initialized = False
|
||||
# if not self.loop.is_closed():
|
||||
# self.loop.close()
|
||||
# except Exception as e:
|
||||
# logger.warning(e)
|
||||
# finally:
|
||||
# logger.info("Cleaned up MCP clients on shutdown.")
|
||||
#
|
||||
#
|
||||
# class BaseAsyncMCPClient:
|
||||
# def __init__(self, server_config: BaseServerConfig):
|
||||
# self.server_config = server_config
|
||||
# self.session: Optional[ClientSession] = None
|
||||
# self.stdio = None
|
||||
# self.write = None
|
||||
# self.initialized = False
|
||||
# self.cleanup_funcs = []
|
||||
#
|
||||
# async def connect_to_server(self):
|
||||
#
|
||||
# success = await self._initialize_connection(self.server_config, timeout=tool_settings.mcp_connect_to_server_timeout)
|
||||
#
|
||||
# if success:
|
||||
# self.initialized = True
|
||||
# else:
|
||||
# raise RuntimeError(
|
||||
# f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}"
|
||||
# )
|
||||
#
|
||||
# async def list_tools(self) -> List[MCPTool]:
|
||||
# self._check_initialized()
|
||||
# response = await self.session.list_tools()
|
||||
# return response.tools
|
||||
#
|
||||
# async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
|
||||
# self._check_initialized()
|
||||
# result = await self.session.call_tool(tool_name, tool_args)
|
||||
#
|
||||
# parsed_content = []
|
||||
# for content_piece in result.content:
|
||||
# if isinstance(content_piece, TextContent):
|
||||
# parsed_content.append(content_piece.text)
|
||||
# else:
|
||||
# parsed_content.append(str(content_piece))
|
||||
#
|
||||
# if len(parsed_content) > 0:
|
||||
# final_content = " ".join(parsed_content)
|
||||
# else:
|
||||
# # TODO move hardcoding to constants
|
||||
# final_content = "Empty response from tool"
|
||||
#
|
||||
# return final_content, result.isError
|
||||
#
|
||||
# def _check_initialized(self):
|
||||
# if not self.initialized:
|
||||
# logger.error("MCPClient has not been initialized")
|
||||
# raise RuntimeError("MCPClient has not been initialized")
|
||||
#
|
||||
# async def cleanup(self):
|
||||
# try:
|
||||
# for cleanup_func in self.cleanup_funcs:
|
||||
# cleanup_func()
|
||||
# self.initialized = False
|
||||
# if not self.loop.is_closed():
|
||||
# self.loop.close()
|
||||
# except Exception as e:
|
||||
# logger.warning(e)
|
||||
# finally:
|
||||
# logger.info("Cleaned up MCP clients on shutdown.")
|
||||
#
|
||||
@@ -1,51 +0,0 @@
|
||||
# import asyncio
|
||||
#
|
||||
# from mcp import ClientSession
|
||||
# from mcp.client.sse import sse_client
|
||||
#
|
||||
# from letta.functions.mcp_client.base_client import BaseAsyncMCPClient, BaseMCPClient
|
||||
# from letta.functions.mcp_client.types import SSEServerConfig
|
||||
# from letta.log import get_logger
|
||||
#
|
||||
## see: https://modelcontextprotocol.io/quickstart/user
|
||||
#
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
|
||||
# class SSEMCPClient(BaseMCPClient):
|
||||
# def _initialize_connection(self, server_config: SSEServerConfig, timeout: float) -> bool:
|
||||
# try:
|
||||
# sse_cm = sse_client(url=server_config.server_url)
|
||||
# sse_transport = self.loop.run_until_complete(asyncio.wait_for(sse_cm.__aenter__(), timeout=timeout))
|
||||
# self.stdio, self.write = sse_transport
|
||||
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(sse_cm.__aexit__(None, None, None)))
|
||||
#
|
||||
# session_cm = ClientSession(self.stdio, self.write)
|
||||
# self.session = self.loop.run_until_complete(asyncio.wait_for(session_cm.__aenter__(), timeout=timeout))
|
||||
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
|
||||
# return True
|
||||
# except asyncio.TimeoutError:
|
||||
# logger.error(f"Timed out while establishing SSE connection (timeout={timeout}s).")
|
||||
# return False
|
||||
# except Exception:
|
||||
# logger.exception("Exception occurred while initializing SSE client session.")
|
||||
# return False
|
||||
#
|
||||
#
|
||||
# class AsyncSSEMCPClient(BaseAsyncMCPClient):
|
||||
#
|
||||
# async def _initialize_connection(self, server_config: SSEServerConfig, timeout: float) -> bool:
|
||||
# try:
|
||||
# sse_cm = sse_client(url=server_config.server_url)
|
||||
# sse_transport = await sse_cm.__aenter__()
|
||||
# self.stdio, self.write = sse_transport
|
||||
# self.cleanup_funcs.append(lambda: sse_cm.__aexit__(None, None, None))
|
||||
#
|
||||
# session_cm = ClientSession(self.stdio, self.write)
|
||||
# self.session = await session_cm.__aenter__()
|
||||
# self.cleanup_funcs.append(lambda: session_cm.__aexit__(None, None, None))
|
||||
# return True
|
||||
# except Exception:
|
||||
# logger.exception("Exception occurred while initializing SSE client session.")
|
||||
# return False
|
||||
#
|
||||
@@ -1,109 +0,0 @@
|
||||
# import asyncio
|
||||
# import sys
|
||||
# from contextlib import asynccontextmanager
|
||||
#
|
||||
# import anyio
|
||||
# import anyio.lowlevel
|
||||
# import mcp.types as types
|
||||
# from anyio.streams.text import TextReceiveStream
|
||||
# from mcp import ClientSession, StdioServerParameters
|
||||
# from mcp.client.stdio import get_default_environment
|
||||
#
|
||||
# from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
# from letta.functions.mcp_client.types import StdioServerConfig
|
||||
# from letta.log import get_logger
|
||||
#
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
|
||||
# class StdioMCPClient(BaseMCPClient):
|
||||
# def _initialize_connection(self, server_config: StdioServerConfig, timeout: float) -> bool:
|
||||
# try:
|
||||
# server_params = StdioServerParameters(command=server_config.command, args=server_config.args, env=server_config.env)
|
||||
# stdio_cm = forked_stdio_client(server_params)
|
||||
# stdio_transport = self.loop.run_until_complete(asyncio.wait_for(stdio_cm.__aenter__(), timeout=timeout))
|
||||
# self.stdio, self.write = stdio_transport
|
||||
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(stdio_cm.__aexit__(None, None, None)))
|
||||
#
|
||||
# session_cm = ClientSession(self.stdio, self.write)
|
||||
# self.session = self.loop.run_until_complete(asyncio.wait_for(session_cm.__aenter__(), timeout=timeout))
|
||||
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
|
||||
# return True
|
||||
# except asyncio.TimeoutError:
|
||||
# logger.error(f"Timed out while establishing stdio connection (timeout={timeout}s).")
|
||||
# return False
|
||||
# except Exception:
|
||||
# logger.exception("Exception occurred while initializing stdio client session.")
|
||||
# return False
|
||||
#
|
||||
#
|
||||
# @asynccontextmanager
|
||||
# async def forked_stdio_client(server: StdioServerParameters):
|
||||
# """
|
||||
# Client transport for stdio: this will connect to a server by spawning a
|
||||
# process and communicating with it over stdin/stdout.
|
||||
# """
|
||||
# read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
# write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
#
|
||||
# try:
|
||||
# process = await anyio.open_process(
|
||||
# [server.command, *server.args],
|
||||
# env=server.env or get_default_environment(),
|
||||
# stderr=sys.stderr, # Consider logging stderr somewhere instead of silencing it
|
||||
# )
|
||||
# except OSError as exc:
|
||||
# raise RuntimeError(f"Failed to spawn process: {server.command} {server.args}") from exc
|
||||
#
|
||||
# async def stdout_reader():
|
||||
# assert process.stdout, "Opened process is missing stdout"
|
||||
# buffer = ""
|
||||
# try:
|
||||
# async with read_stream_writer:
|
||||
# async for chunk in TextReceiveStream(
|
||||
# process.stdout,
|
||||
# encoding=server.encoding,
|
||||
# errors=server.encoding_error_handler,
|
||||
# ):
|
||||
# lines = (buffer + chunk).split("\n")
|
||||
# buffer = lines.pop()
|
||||
# for line in lines:
|
||||
# try:
|
||||
# message = types.JSONRPCMessage.model_validate_json(line)
|
||||
# except Exception as exc:
|
||||
# await read_stream_writer.send(exc)
|
||||
# continue
|
||||
# await read_stream_writer.send(message)
|
||||
# except anyio.ClosedResourceError:
|
||||
# await anyio.lowlevel.checkpoint()
|
||||
#
|
||||
# async def stdin_writer():
|
||||
# assert process.stdin, "Opened process is missing stdin"
|
||||
# try:
|
||||
# async with write_stream_reader:
|
||||
# async for message in write_stream_reader:
|
||||
# json = message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
# await process.stdin.send(
|
||||
# (json + "\n").encode(
|
||||
# encoding=server.encoding,
|
||||
# errors=server.encoding_error_handler,
|
||||
# )
|
||||
# )
|
||||
# except anyio.ClosedResourceError:
|
||||
# await anyio.lowlevel.checkpoint()
|
||||
#
|
||||
# async def watch_process_exit():
|
||||
# returncode = await process.wait()
|
||||
# if returncode != 0:
|
||||
# raise RuntimeError(f"Subprocess exited with code {returncode}. Command: {server.command} {server.args}")
|
||||
#
|
||||
# async with anyio.create_task_group() as tg, process:
|
||||
# tg.start_soon(stdout_reader)
|
||||
# tg.start_soon(stdin_writer)
|
||||
# tg.start_soon(watch_process_exit)
|
||||
#
|
||||
# with anyio.move_on_after(0.2):
|
||||
# await anyio.sleep_forever()
|
||||
#
|
||||
# yield read_stream, write_stream
|
||||
#
|
||||
@@ -148,9 +148,21 @@ class SSEServerConfig(BaseServerConfig):
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with SSE requests")
|
||||
|
||||
def resolve_token(self) -> Optional[str]:
|
||||
if self.auth_token and self.auth_token.startswith(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} "):
|
||||
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
|
||||
return self.auth_token
|
||||
"""
|
||||
Extract token for storage if auth_header/auth_token are provided
|
||||
and not already in custom_headers.
|
||||
|
||||
Returns:
|
||||
The resolved token (without Bearer prefix) if it should be stored separately, None otherwise
|
||||
"""
|
||||
if self.auth_token and self.auth_header:
|
||||
# Check if custom_headers already has the auth header
|
||||
if not self.custom_headers or self.auth_header not in self.custom_headers:
|
||||
# Strip Bearer prefix if present
|
||||
if self.auth_token.startswith(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} "):
|
||||
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
|
||||
return self.auth_token
|
||||
return None
|
||||
|
||||
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
|
||||
if self.auth_token and super().is_templated_tool_variable(self.auth_token):
|
||||
@@ -217,9 +229,21 @@ class StreamableHTTPServerConfig(BaseServerConfig):
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with streamable HTTP requests")
|
||||
|
||||
def resolve_token(self) -> Optional[str]:
|
||||
if self.auth_token and self.auth_token.startswith(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} "):
|
||||
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
|
||||
return self.auth_token
|
||||
"""
|
||||
Extract token for storage if auth_header/auth_token are provided
|
||||
and not already in custom_headers.
|
||||
|
||||
Returns:
|
||||
The resolved token (without Bearer prefix) if it should be stored separately, None otherwise
|
||||
"""
|
||||
if self.auth_token and self.auth_header:
|
||||
# Check if custom_headers already has the auth header
|
||||
if not self.custom_headers or self.auth_header not in self.custom_headers:
|
||||
# Strip Bearer prefix if present
|
||||
if self.auth_token.startswith(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} "):
|
||||
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
|
||||
return self.auth_token
|
||||
return None
|
||||
|
||||
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
|
||||
if self.auth_token and super().is_templated_tool_variable(self.auth_token):
|
||||
|
||||
@@ -608,13 +608,58 @@ def generate_tool_schema_for_mcp(
|
||||
# Normalise so downstream code can treat it consistently.
|
||||
parameters_schema.setdefault("required", [])
|
||||
|
||||
# Process properties to handle anyOf types and make optional fields strict-compatible
|
||||
if "properties" in parameters_schema:
|
||||
for field_name, field_props in parameters_schema["properties"].items():
|
||||
# Handle anyOf types by flattening to type array
|
||||
if "anyOf" in field_props and "type" not in field_props:
|
||||
types = []
|
||||
format_value = None
|
||||
for option in field_props["anyOf"]:
|
||||
if "type" in option:
|
||||
types.append(option["type"])
|
||||
# Capture format if present (e.g., uuid format for strings)
|
||||
if "format" in option and not format_value:
|
||||
format_value = option["format"]
|
||||
if types:
|
||||
# Deduplicate types using set
|
||||
field_props["type"] = list(set(types))
|
||||
# Only add format if the field is not optional (doesn't have null type)
|
||||
if format_value and len(field_props["type"]) == 1 and "null" not in field_props["type"]:
|
||||
field_props["format"] = format_value
|
||||
# Remove the anyOf since we've flattened it
|
||||
del field_props["anyOf"]
|
||||
|
||||
# For strict mode: heal optional fields by making them required with null type
|
||||
if strict and field_name not in parameters_schema["required"]:
|
||||
# Field is optional - add it to required array
|
||||
parameters_schema["required"].append(field_name)
|
||||
|
||||
# Ensure the field can accept null to maintain optionality
|
||||
if "type" in field_props:
|
||||
if isinstance(field_props["type"], list):
|
||||
# Already an array of types - add null if not present
|
||||
if "null" not in field_props["type"]:
|
||||
field_props["type"].append("null")
|
||||
# Deduplicate
|
||||
field_props["type"] = list(set(field_props["type"]))
|
||||
elif field_props["type"] != "null":
|
||||
# Single type - convert to array with null
|
||||
field_props["type"] = list(set([field_props["type"], "null"]))
|
||||
elif "anyOf" in field_props:
|
||||
# If there's still an anyOf, ensure null is one of the options
|
||||
has_null = any(opt.get("type") == "null" for opt in field_props["anyOf"])
|
||||
if not has_null:
|
||||
field_props["anyOf"].append({"type": "null"})
|
||||
|
||||
# Add the optional heartbeat parameter
|
||||
if append_heartbeat:
|
||||
parameters_schema["properties"][REQUEST_HEARTBEAT_PARAM] = {
|
||||
"type": "boolean",
|
||||
"description": REQUEST_HEARTBEAT_DESCRIPTION,
|
||||
}
|
||||
parameters_schema["required"].append(REQUEST_HEARTBEAT_PARAM)
|
||||
if REQUEST_HEARTBEAT_PARAM not in parameters_schema["required"]:
|
||||
parameters_schema["required"].append(REQUEST_HEARTBEAT_PARAM)
|
||||
|
||||
# Return the final schema
|
||||
if strict:
|
||||
|
||||
@@ -116,15 +116,21 @@ def validate_complete_json_schema(schema: Dict[str, Any]) -> Tuple[SchemaHealth,
|
||||
|
||||
required = node.get("required")
|
||||
if required is None:
|
||||
# TODO: @jnjpng skip this check for now, seems like OpenAI strict mode doesn't enforce this
|
||||
# Only mark as non-strict for nested objects, not root
|
||||
if not is_root:
|
||||
mark_non_strict(f"{path}: 'required' not specified for object")
|
||||
# if not is_root:
|
||||
# mark_non_strict(f"{path}: 'required' not specified for object")
|
||||
required = []
|
||||
elif not isinstance(required, list):
|
||||
mark_invalid(f"{path}: 'required' must be a list if present")
|
||||
required = []
|
||||
|
||||
# OpenAI strict-mode extra checks:
|
||||
# NOTE: We no longer flag properties not in required array as non-strict
|
||||
# because we can heal these schemas by adding null to the type union
|
||||
# This allows MCP tools with optional fields to be used with strict mode
|
||||
# The healing happens in generate_tool_schema_for_mcp() when strict=True
|
||||
|
||||
for req_key in required:
|
||||
if props and req_key not in props:
|
||||
mark_invalid(f"{path}: required contains '{req_key}' not found in properties")
|
||||
@@ -161,6 +167,15 @@ def validate_complete_json_schema(schema: Dict[str, Any]) -> Tuple[SchemaHealth,
|
||||
# These are generally fine, but check for specific constraints
|
||||
pass
|
||||
|
||||
# TYPE ARRAYS (e.g., ["string", "null"] for optional fields)
|
||||
elif isinstance(node_type, list):
|
||||
# Type arrays are allowed in OpenAI strict mode
|
||||
# They represent union types (e.g., string | null)
|
||||
for t in node_type:
|
||||
# TODO: @jnjpng handle enum types?
|
||||
if t not in ["string", "number", "integer", "boolean", "null", "array", "object"]:
|
||||
mark_invalid(f"{path}: Invalid type '{t}' in type array")
|
||||
|
||||
# UNION TYPES
|
||||
for kw in ("anyOf", "oneOf", "allOf"):
|
||||
if kw in node:
|
||||
|
||||
@@ -11,7 +11,7 @@ class SearchTask(BaseModel):
|
||||
class FileOpenRequest(BaseModel):
|
||||
file_name: str = Field(description="Name of the file to open")
|
||||
offset: Optional[int] = Field(
|
||||
default=None, description="Optional starting line number (1-indexed). If not specified, starts from beginning of file."
|
||||
default=None, description="Optional offset for starting line number (0-indexed). If not specified, starts from beginning of file."
|
||||
)
|
||||
length: Optional[int] = Field(
|
||||
default=None, description="Optional number of lines to view from offset (inclusive). If not specified, views to end of file."
|
||||
|
||||
@@ -39,12 +39,10 @@ def enable_strict_mode(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
# Ensure parameters is a valid dictionary
|
||||
parameters = schema.get("parameters", {})
|
||||
|
||||
if isinstance(parameters, dict) and parameters.get("type") == "object":
|
||||
# Set additionalProperties to False
|
||||
parameters["additionalProperties"] = False
|
||||
schema["parameters"] = parameters
|
||||
|
||||
# Remove the metadata fields from the schema
|
||||
schema.pop(MCP_TOOL_METADATA_SCHEMA_STATUS, None)
|
||||
schema.pop(MCP_TOOL_METADATA_SCHEMA_WARNINGS, None)
|
||||
|
||||
@@ -287,12 +287,34 @@ class AnthropicClient(LLMClientBase):
|
||||
else:
|
||||
anthropic_tools = None
|
||||
|
||||
thinking_enabled = False
|
||||
if messages and len(messages) > 0:
|
||||
# Check if the last assistant message starts with a thinking block
|
||||
# Find the last assistant message
|
||||
last_assistant_message = None
|
||||
for message in reversed(messages):
|
||||
if message.get("role") == "assistant":
|
||||
last_assistant_message = message
|
||||
break
|
||||
|
||||
if (
|
||||
last_assistant_message
|
||||
and isinstance(last_assistant_message.get("content"), list)
|
||||
and len(last_assistant_message["content"]) > 0
|
||||
and last_assistant_message["content"][0].get("type") == "thinking"
|
||||
):
|
||||
thinking_enabled = True
|
||||
|
||||
try:
|
||||
result = await client.beta.messages.count_tokens(
|
||||
model=model or "claude-3-7-sonnet-20250219",
|
||||
messages=messages or [{"role": "user", "content": "hi"}],
|
||||
tools=anthropic_tools or [],
|
||||
)
|
||||
count_params = {
|
||||
"model": model or "claude-3-7-sonnet-20250219",
|
||||
"messages": messages or [{"role": "user", "content": "hi"}],
|
||||
"tools": anthropic_tools or [],
|
||||
}
|
||||
|
||||
if thinking_enabled:
|
||||
count_params["thinking"] = {"type": "enabled", "budget_tokens": 16000}
|
||||
result = await client.beta.messages.count_tokens(**count_params)
|
||||
except:
|
||||
raise
|
||||
|
||||
|
||||
97
letta/llm_api/deepseek_client.py
Normal file
97
letta/llm_api/deepseek_client.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.llm_api.deepseek import convert_deepseek_response_to_chatcompletion, map_messages_to_deepseek_format
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
class DeepseekClient(OpenAIClient):
|
||||
|
||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||
return False
|
||||
|
||||
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
|
||||
return False
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
# Override put_inner_thoughts_in_kwargs to False for DeepSeek
|
||||
llm_config.put_inner_thoughts_in_kwargs = False
|
||||
|
||||
data = super().build_request_data(messages, llm_config, tools, force_tool_call)
|
||||
|
||||
def add_functions_to_system_message(system_message: ChatMessage):
|
||||
system_message.content += f"<available functions> {''.join(json.dumps(f) for f in functions)} </available functions>"
|
||||
system_message.content += 'Select best function to call simply respond with a single json block with the fields "name" and "arguments". Use double quotes around the arguments.'
|
||||
|
||||
if llm_config.model == "deepseek-reasoner": # R1 currently doesn't support function calling natively
|
||||
add_functions_to_system_message(
|
||||
data["messages"][0]
|
||||
) # Inject additional instructions to the system prompt with the available functions
|
||||
|
||||
data["messages"] = map_messages_to_deepseek_format(data["messages"])
|
||||
|
||||
return data
|
||||
|
||||
@trace_method
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying synchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
api_key = model_settings.deepseek_api_key or os.environ.get("DEEPSEEK_API_KEY")
|
||||
client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
response: ChatCompletion = client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
@trace_method
|
||||
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
api_key = model_settings.deepseek_api_key or os.environ.get("DEEPSEEK_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
@trace_method
|
||||
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
|
||||
"""
|
||||
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
|
||||
"""
|
||||
api_key = model_settings.deepseek_api_key or os.environ.get("DEEPSEEK_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
**request_data, stream=True, stream_options={"include_usage": True}
|
||||
)
|
||||
return response_stream
|
||||
|
||||
@trace_method
|
||||
def convert_response_to_chat_completion(
|
||||
self,
|
||||
response_data: dict,
|
||||
input_messages: List[PydanticMessage], # Included for consistency, maybe used later
|
||||
llm_config: LLMConfig,
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model.
|
||||
Handles potential extraction of inner thoughts if they were added via kwargs.
|
||||
"""
|
||||
response = ChatCompletionResponse(**response_data)
|
||||
return convert_deepseek_response_to_chatcompletion(response)
|
||||
79
letta/llm_api/groq_client.py
Normal file
79
letta/llm_api/groq_client.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
class GroqClient(OpenAIClient):
|
||||
|
||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||
return False
|
||||
|
||||
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
|
||||
return True
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
data = super().build_request_data(messages, llm_config, tools, force_tool_call)
|
||||
|
||||
# Groq validation - these fields are not supported and will cause 400 errors
|
||||
# https://console.groq.com/docs/openai
|
||||
if "top_logprobs" in data:
|
||||
del data["top_logprobs"]
|
||||
if "logit_bias" in data:
|
||||
del data["logit_bias"]
|
||||
data["logprobs"] = False
|
||||
data["n"] = 1
|
||||
|
||||
return data
|
||||
|
||||
@trace_method
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying synchronous request to Groq API and returns raw response dict.
|
||||
"""
|
||||
api_key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
|
||||
client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
response: ChatCompletion = client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
@trace_method
|
||||
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying asynchronous request to Groq API and returns raw response dict.
|
||||
"""
|
||||
api_key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
@trace_method
|
||||
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
|
||||
"""Request embeddings given texts and embedding config"""
|
||||
api_key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=embedding_config.embedding_endpoint)
|
||||
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
|
||||
|
||||
# TODO: add total usage
|
||||
return [r.embedding for r in response.data]
|
||||
|
||||
@trace_method
|
||||
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
|
||||
raise NotImplementedError("Streaming not supported for Groq.")
|
||||
@@ -133,7 +133,6 @@ def convert_to_structured_output(openai_function: dict, allow_optional: bool = F
|
||||
structured_output["parameters"]["required"] = list(structured_output["parameters"]["properties"].keys())
|
||||
else:
|
||||
raise NotImplementedError("Optional parameter handling is not implemented.")
|
||||
|
||||
return structured_output
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import requests
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.errors import LettaConfigurationError, RateLimitExceededError
|
||||
from letta.llm_api.deepseek import build_deepseek_chat_completions_request, convert_deepseek_response_to_chatcompletion
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs
|
||||
from letta.llm_api.helpers import unpack_all_inner_thoughts_from_kwargs
|
||||
from letta.llm_api.openai import (
|
||||
build_openai_chat_completions_request,
|
||||
openai_chat_completions_process_stream,
|
||||
@@ -16,14 +16,13 @@ from letta.llm_api.openai import (
|
||||
prepare_openai_payload,
|
||||
)
|
||||
from letta.local_llm.chat_completion_proxy import get_chat_completion
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.orm.user import User
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.provider_trace import ProviderTraceCreate
|
||||
from letta.services.telemetry_manager import TelemetryManager
|
||||
@@ -246,116 +245,6 @@ def create(
|
||||
|
||||
return response
|
||||
|
||||
elif llm_config.model_endpoint_type == "xai":
|
||||
api_key = model_settings.xai_api_key
|
||||
|
||||
if function_call is None and functions is not None and len(functions) > 0:
|
||||
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
function_call = "required"
|
||||
|
||||
data = build_openai_chat_completions_request(
|
||||
llm_config,
|
||||
messages,
|
||||
user_id,
|
||||
functions,
|
||||
function_call,
|
||||
use_tool_naming,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
use_structured_output=False, # NOTE: not supported atm for xAI
|
||||
)
|
||||
|
||||
# Specific bug for the mini models (as of Apr 14, 2025)
|
||||
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
|
||||
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'}
|
||||
if "grok-3-mini-" in llm_config.model:
|
||||
data.presence_penalty = None
|
||||
data.frequency_penalty = None
|
||||
|
||||
if stream: # Client requested token streaming
|
||||
data.stream = True
|
||||
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
|
||||
stream_interface, AgentRefreshStreamingInterface
|
||||
), type(stream_interface)
|
||||
response = openai_chat_completions_process_stream(
|
||||
url=llm_config.model_endpoint,
|
||||
api_key=api_key,
|
||||
chat_completion_request=data,
|
||||
stream_interface=stream_interface,
|
||||
name=name,
|
||||
# TODO turn on to support reasoning content from xAI reasoners:
|
||||
# https://docs.x.ai/docs/guides/reasoning#reasoning
|
||||
expect_reasoning_content=False,
|
||||
)
|
||||
else: # Client did not request token streaming (expect a blocking backend response)
|
||||
data.stream = False
|
||||
if isinstance(stream_interface, AgentChunkStreamingInterface):
|
||||
stream_interface.stream_start()
|
||||
try:
|
||||
response = openai_chat_completions_request(
|
||||
url=llm_config.model_endpoint,
|
||||
api_key=api_key,
|
||||
chat_completion_request=data,
|
||||
)
|
||||
finally:
|
||||
if isinstance(stream_interface, AgentChunkStreamingInterface):
|
||||
stream_interface.stream_end()
|
||||
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
|
||||
|
||||
return response
|
||||
|
||||
elif llm_config.model_endpoint_type == "groq":
|
||||
if stream:
|
||||
raise NotImplementedError("Streaming not yet implemented for Groq.")
|
||||
|
||||
if model_settings.groq_api_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions":
|
||||
raise LettaConfigurationError(message="Groq key is missing from letta config file", missing_fields=["groq_api_key"])
|
||||
|
||||
# force to true for groq, since they don't support 'content' is non-null
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
functions = add_inner_thoughts_to_functions(
|
||||
functions=functions,
|
||||
inner_thoughts_key=INNER_THOUGHTS_KWARG,
|
||||
inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION,
|
||||
)
|
||||
|
||||
tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None
|
||||
data = ChatCompletionRequest(
|
||||
model=llm_config.model,
|
||||
messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs) for m in messages],
|
||||
tools=tools,
|
||||
tool_choice=function_call,
|
||||
user=str(user_id),
|
||||
)
|
||||
|
||||
# https://console.groq.com/docs/openai
|
||||
# "The following fields are currently not supported and will result in a 400 error (yikes) if they are supplied:"
|
||||
assert data.top_logprobs is None
|
||||
assert data.logit_bias is None
|
||||
assert data.logprobs == False
|
||||
assert data.n == 1
|
||||
# They mention that none of the messages can have names, but it seems to not error out (for now)
|
||||
|
||||
data.stream = False
|
||||
if isinstance(stream_interface, AgentChunkStreamingInterface):
|
||||
stream_interface.stream_start()
|
||||
try:
|
||||
# groq uses the openai chat completions API, so this component should be reusable
|
||||
response = openai_chat_completions_request(
|
||||
url=llm_config.model_endpoint,
|
||||
api_key=model_settings.groq_api_key,
|
||||
chat_completion_request=data,
|
||||
)
|
||||
finally:
|
||||
if isinstance(stream_interface, AgentChunkStreamingInterface):
|
||||
stream_interface.stream_end()
|
||||
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
|
||||
|
||||
return response
|
||||
|
||||
elif llm_config.model_endpoint_type == "deepseek":
|
||||
if model_settings.deepseek_api_key is None and llm_config.model_endpoint == "":
|
||||
# only is a problem if we are *not* using an openai proxy
|
||||
|
||||
@@ -79,5 +79,26 @@ class LLMClient:
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.xai:
|
||||
from letta.llm_api.xai_client import XAIClient
|
||||
|
||||
return XAIClient(
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.groq:
|
||||
from letta.llm_api.groq_client import GroqClient
|
||||
|
||||
return GroqClient(
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.deepseek:
|
||||
from letta.llm_api.deepseek_client import DeepseekClient
|
||||
|
||||
return DeepseekClient(
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case _:
|
||||
return None
|
||||
|
||||
@@ -15,6 +15,7 @@ from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.provider_trace import ProviderTraceCreate
|
||||
from letta.services.telemetry_manager import TelemetryManager
|
||||
from letta.settings import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm import User
|
||||
@@ -90,15 +91,16 @@ class LLMClientBase:
|
||||
try:
|
||||
log_event(name="llm_request_sent", attributes=request_data)
|
||||
response_data = await self.request_async(request_data, llm_config)
|
||||
await telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id,
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
if settings.track_provider_trace and telemetry_manager:
|
||||
await telemetry_manager.create_provider_trace_async(
|
||||
actor=self.actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
request_json=request_data,
|
||||
response_json=response_data,
|
||||
step_id=step_id,
|
||||
organization_id=self.actor.organization_id,
|
||||
),
|
||||
)
|
||||
|
||||
log_event(name="llm_response_received", attributes=response_data)
|
||||
except Exception as e:
|
||||
|
||||
@@ -146,6 +146,9 @@ class OpenAIClient(LLMClientBase):
|
||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||
return requires_auto_tool_choice(llm_config)
|
||||
|
||||
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
|
||||
return supports_structured_output(llm_config)
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
self,
|
||||
|
||||
85
letta/llm_api/xai_client.py
Normal file
85
letta/llm_api/xai_client.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
class XAIClient(OpenAIClient):
|
||||
|
||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||
return False
|
||||
|
||||
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
|
||||
return False
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
data = super().build_request_data(messages, llm_config, tools, force_tool_call)
|
||||
|
||||
# Specific bug for the mini models (as of Apr 14, 2025)
|
||||
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
|
||||
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'}
|
||||
if "grok-3-mini-" in llm_config.model:
|
||||
data.pop("presence_penalty", None)
|
||||
data.pop("frequency_penalty", None)
|
||||
|
||||
return data
|
||||
|
||||
@trace_method
|
||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying synchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
|
||||
client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
response: ChatCompletion = client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
@trace_method
|
||||
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""
|
||||
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
@trace_method
|
||||
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
|
||||
"""
|
||||
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
|
||||
"""
|
||||
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
|
||||
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
**request_data, stream=True, stream_options={"include_usage": True}
|
||||
)
|
||||
return response_stream
|
||||
|
||||
@trace_method
|
||||
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
|
||||
"""Request embeddings given texts and embedding config"""
|
||||
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=embedding_config.embedding_endpoint)
|
||||
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
|
||||
|
||||
# TODO: add total usage
|
||||
return [r.embedding for r in response.data]
|
||||
190
letta/prompts/prompt_generator.py
Normal file
190
letta/prompts/prompt_generator.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import format_datetime, get_local_time_fast
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.memory import Memory
|
||||
|
||||
|
||||
class PromptGenerator:
    """Builds the full system message fed to the LLM: compiled memory, memory metadata, and the rendered system prompt template."""

    # TODO: This code is kind of wonky and deserves a rewrite
    # NOTE: @staticmethod must be the OUTERMOST decorator; otherwise @trace_method
    # would wrap the staticmethod descriptor instead of the function itself.
    @staticmethod
    @trace_method
    def compile_memory_metadata_block(
        memory_edit_timestamp: datetime,
        timezone: str,
        previous_message_count: int = 0,
        archival_memory_size: Optional[int] = 0,
    ) -> str:
        """
        Generate a memory metadata block for the agent's system prompt.

        This creates a structured metadata section that informs the agent about
        the current state of its memory systems, including timing information
        and memory counts. This helps the agent understand what information
        is available through its tools.

        Args:
            memory_edit_timestamp: When memory blocks were last modified
            timezone: The timezone to use for formatting timestamps (e.g., 'America/Los_Angeles')
            previous_message_count: Number of messages in recall memory (conversation history)
            archival_memory_size: Number of items in archival memory (long-term storage)

        Returns:
            A formatted string containing the memory metadata block with XML-style tags

        Example Output:
            <memory_metadata>
            - The current time is: 2024-01-15 10:30 AM PST
            - Memory blocks were last modified: 2024-01-15 09:00 AM PST
            - 42 previous messages between you and the user are stored in recall memory (use tools to access them)
            - 156 total memories you created are stored in archival memory (use tools to access them)
            </memory_metadata>
        """
        # Put the timestamp in the local timezone (mimicking get_local_time())
        timestamp_str = format_datetime(memory_edit_timestamp, timezone)

        # Create a metadata block of info so the agent knows about the metadata of out-of-context memories
        metadata_lines = [
            "<memory_metadata>",
            f"- The current time is: {get_local_time_fast(timezone)}",
            f"- Memory blocks were last modified: {timestamp_str}",
            f"- {previous_message_count} previous messages between you and the user are stored in recall memory (use tools to access them)",
        ]

        # Only include archival memory line if there are archival memories
        if archival_memory_size is not None and archival_memory_size > 0:
            metadata_lines.append(
                f"- {archival_memory_size} total memories you created are stored in archival memory (use tools to access them)"
            )

        metadata_lines.append("</memory_metadata>")
        memory_metadata_block = "\n".join(metadata_lines)
        return memory_metadata_block

    @staticmethod
    def safe_format(template: str, variables: dict) -> str:
        """
        Safely formats a template string, preserving empty {} and {unknown_vars}
        while substituting known variables.

        If we simply use {} in format_map, it'll be treated as a positional field
        """

        # Mapping that substitutes known keys and leaves unknown {placeholders}
        # intact instead of raising KeyError. (The original referenced an
        # undefined `PreserveMapping`; this local class supplies it.)
        class _PreserveMapping(dict):
            def __missing__(self, key: str) -> str:
                return "{" + key + "}"

        # First escape any empty {} by doubling them
        escaped = template.replace("{}", "{{}}")

        # Now use format_map with our custom mapping
        return escaped.format_map(_PreserveMapping(variables))

    @staticmethod
    @trace_method
    def get_system_message_from_compiled_memory(
        system_prompt: str,
        memory_with_sources: str,
        in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
        timezone: str,
        user_defined_variables: Optional[dict] = None,
        append_icm_if_missing: bool = True,
        template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
        previous_message_count: int = 0,
        archival_memory_size: int = 0,
    ) -> str:
        """Prepare the final/full system message that will be fed into the LLM API

        The base system message may be templated, in which case we need to render the variables.

        The following are reserved variables:
          - CORE_MEMORY: the in-context memory of the LLM
        """
        if user_defined_variables is not None:
            # TODO eventually support the user defining their own variables to inject
            raise NotImplementedError
        else:
            variables = {}

        # Add the protected memory variable
        if IN_CONTEXT_MEMORY_KEYWORD in variables:
            raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
        else:
            # TODO should this all put into the memory.__repr__ function?
            memory_metadata_string = PromptGenerator.compile_memory_metadata_block(
                memory_edit_timestamp=in_context_memory_last_edit,
                previous_message_count=previous_message_count,
                archival_memory_size=archival_memory_size,
                timezone=timezone,
            )

            full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string

            # Add to the variables list to inject
            variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string

        if template_format == "f-string":
            memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"

            # Catch the special case where the system prompt is unformatted
            if append_icm_if_missing:
                if memory_variable_string not in system_prompt:
                    # In this case, append it to the end to make sure memory is still injected
                    # warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
                    system_prompt += "\n\n" + memory_variable_string

            # render the variables using the built-in templater
            try:
                # NOTE: the safe_format branch is currently unreachable (user_defined_variables
                # raises NotImplementedError above); kept for when that TODO lands.
                if user_defined_variables:
                    formatted_prompt = PromptGenerator.safe_format(system_prompt, variables)
                else:
                    formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
            except Exception as e:
                raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")

        else:
            # TODO support for mustache and jinja2
            raise NotImplementedError(template_format)

        return formatted_prompt

    @staticmethod
    @trace_method
    async def compile_system_message_async(
        system_prompt: str,
        in_context_memory: Memory,
        in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
        timezone: str,
        user_defined_variables: Optional[dict] = None,
        append_icm_if_missing: bool = True,
        template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
        previous_message_count: int = 0,
        archival_memory_size: int = 0,
        tool_rules_solver: Optional[ToolRulesSolver] = None,
        sources: Optional[List] = None,
        max_files_open: Optional[int] = None,
    ) -> str:
        """Compile in-context memory (with sources and tool-rule prompts) and render the system message."""
        tool_constraint_block = None
        if tool_rules_solver is not None:
            tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()

        if user_defined_variables is not None:
            # TODO eventually support the user defining their own variables to inject
            raise NotImplementedError
        else:
            pass

        memory_with_sources = await in_context_memory.compile_in_thread_async(
            tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open
        )

        return PromptGenerator.get_system_message_from_compiled_memory(
            system_prompt=system_prompt,
            memory_with_sources=memory_with_sources,
            in_context_memory_last_edit=in_context_memory_last_edit,
            timezone=timezone,
            user_defined_variables=user_defined_variables,
            append_icm_if_missing=append_icm_if_missing,
            template_format=template_format,
            previous_message_count=previous_message_count,
            archival_memory_size=archival_memory_size,
        )
|
||||
@@ -1,15 +1,17 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.agent import AgentState, CreateAgent
|
||||
from letta.schemas.block import Block, CreateBlock
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase
|
||||
from letta.schemas.group import Group, GroupCreate
|
||||
from letta.schemas.mcp import MCPServer
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||
from letta.schemas.source import Source, SourceCreate
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
@@ -46,6 +48,15 @@ class MessageSchema(MessageCreate):
|
||||
role: MessageRole = Field(..., description="The role of the participant.")
|
||||
model: Optional[str] = Field(None, description="The model used to make the function call")
|
||||
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent")
|
||||
tool_calls: Optional[List[OpenAIToolCall]] = Field(
|
||||
default=None, description="The list of tool calls requested. Only applicable for role assistant."
|
||||
)
|
||||
tool_call_id: Optional[str] = Field(default=None, description="The ID of the tool call. Only applicable for role tool.")
|
||||
tool_returns: Optional[List[ToolReturn]] = Field(default=None, description="Tool execution return information for prior tool calls")
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||
|
||||
# TODO: Should we also duplicate the steps here?
|
||||
# TODO: What about tool_return?
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message: Message) -> "MessageSchema":
|
||||
@@ -64,6 +75,10 @@ class MessageSchema(MessageCreate):
|
||||
group_id=message.group_id,
|
||||
model=message.model,
|
||||
agent_id=message.agent_id,
|
||||
tool_calls=message.tool_calls,
|
||||
tool_call_id=message.tool_call_id,
|
||||
tool_returns=message.tool_returns,
|
||||
created_at=message.created_at,
|
||||
)
|
||||
|
||||
|
||||
@@ -114,7 +129,7 @@ class AgentSchema(CreateAgent):
|
||||
memory_blocks=[], # TODO: Convert from agent_state.memory if needed
|
||||
tools=[],
|
||||
tool_ids=[tool.id for tool in agent_state.tools] if agent_state.tools else [],
|
||||
source_ids=[], # [source.id for source in agent_state.sources] if agent_state.sources else [],
|
||||
source_ids=[source.id for source in agent_state.sources] if agent_state.sources else [],
|
||||
block_ids=[block.id for block in agent_state.memory.blocks],
|
||||
tool_rules=agent_state.tool_rules,
|
||||
tags=agent_state.tags,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -108,3 +108,26 @@ class FileAgent(FileAgentBase):
|
||||
default_factory=datetime.utcnow,
|
||||
description="Row last-update timestamp (UTC).",
|
||||
)
|
||||
|
||||
|
||||
class AgentFileAttachment(LettaBase):
|
||||
"""Response model for agent file attachments showing file status in agent context"""
|
||||
|
||||
id: str = Field(..., description="Unique identifier of the file-agent relationship")
|
||||
file_id: str = Field(..., description="Unique identifier of the file")
|
||||
file_name: str = Field(..., description="Name of the file")
|
||||
folder_id: str = Field(..., description="Unique identifier of the folder/source")
|
||||
folder_name: str = Field(..., description="Name of the folder/source")
|
||||
is_open: bool = Field(..., description="Whether the file is currently open in the agent's context")
|
||||
last_accessed_at: Optional[datetime] = Field(None, description="Timestamp of last access by the agent")
|
||||
visible_content: Optional[str] = Field(None, description="Portion of the file visible to the agent if open")
|
||||
start_line: Optional[int] = Field(None, description="Starting line number if file was opened with line range")
|
||||
end_line: Optional[int] = Field(None, description="Ending line number if file was opened with line range")
|
||||
|
||||
|
||||
class PaginatedAgentFiles(LettaBase):
|
||||
"""Paginated response for agent files"""
|
||||
|
||||
files: List[AgentFileAttachment] = Field(..., description="List of file attachments for the agent")
|
||||
next_cursor: Optional[str] = Field(None, description="Cursor for fetching the next page (file-agent relationship ID)")
|
||||
has_more: bool = Field(..., description="Whether more results exist after this page")
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List, Optional
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.enums import JobStatus, JobType
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import MessageType
|
||||
@@ -12,6 +13,7 @@ from letta.schemas.letta_message import MessageType
|
||||
class JobBase(OrmMetadataBase):
|
||||
__id_prefix__ = "job"
|
||||
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
|
||||
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
|
||||
metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="The metadata of the job.")
|
||||
job_type: JobType = Field(default=JobType.JOB, description="The type of the job.")
|
||||
|
||||
@@ -52,6 +52,8 @@ class LettaMessage(BaseModel):
|
||||
sender_id: str | None = None
|
||||
step_id: str | None = None
|
||||
is_err: bool | None = None
|
||||
seq_id: int | None = None
|
||||
run_id: str | None = None
|
||||
|
||||
@field_serializer("date")
|
||||
def serialize_datetime(self, dt: datetime, _info):
|
||||
|
||||
@@ -46,6 +46,10 @@ class LettaStreamingRequest(LettaRequest):
|
||||
default=False,
|
||||
description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts.",
|
||||
)
|
||||
background: bool = Field(
|
||||
default=False,
|
||||
description="Whether to process the request in the background.",
|
||||
)
|
||||
|
||||
|
||||
class LettaAsyncRequest(LettaRequest):
|
||||
@@ -66,3 +70,21 @@ class CreateBatch(BaseModel):
|
||||
"'status' is the final batch status (e.g., 'completed', 'failed'), and "
|
||||
"'completed_at' is an ISO 8601 timestamp indicating when the batch job completed.",
|
||||
)
|
||||
|
||||
|
||||
class RetrieveStreamRequest(BaseModel):
|
||||
starting_after: int = Field(
|
||||
0, description="Sequence id to use as a cursor for pagination. Response will start streaming after this chunk sequence id"
|
||||
)
|
||||
include_pings: Optional[bool] = Field(
|
||||
default=False,
|
||||
description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts.",
|
||||
)
|
||||
poll_interval: Optional[float] = Field(
|
||||
default=0.1,
|
||||
description="Seconds to wait between polls when no new data.",
|
||||
)
|
||||
batch_size: Optional[int] = Field(
|
||||
default=100,
|
||||
description="Number of entries to read per batch.",
|
||||
)
|
||||
|
||||
@@ -414,6 +414,8 @@ class Message(BaseMessage):
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode function return: {text_content}")
|
||||
|
||||
# if self.tool_call_id is None:
|
||||
# import pdb;pdb.set_trace()
|
||||
assert self.tool_call_id is not None
|
||||
|
||||
return ToolReturnMessage(
|
||||
@@ -844,7 +846,7 @@ class Message(BaseMessage):
|
||||
}
|
||||
content = []
|
||||
# COT / reasoning / thinking
|
||||
if self.content is not None and len(self.content) > 1:
|
||||
if self.content is not None and len(self.content) >= 1:
|
||||
for content_part in self.content:
|
||||
if isinstance(content_part, ReasoningContent):
|
||||
content.append(
|
||||
@@ -861,6 +863,13 @@ class Message(BaseMessage):
|
||||
"data": content_part.data,
|
||||
}
|
||||
)
|
||||
if isinstance(content_part, TextContent):
|
||||
content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": content_part.text,
|
||||
}
|
||||
)
|
||||
elif text_content is not None:
|
||||
content.append(
|
||||
{
|
||||
|
||||
@@ -18,6 +18,7 @@ logger = get_logger(__name__)
|
||||
class BedrockProvider(Provider):
|
||||
provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
|
||||
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
|
||||
access_key: str = Field(..., description="AWS secret access key for Bedrock.")
|
||||
region: str = Field(..., description="AWS region for Bedrock")
|
||||
|
||||
async def bedrock_get_model_list_async(self) -> list[dict]:
|
||||
|
||||
300
letta/server/rest_api/redis_stream_manager.py
Normal file
300
letta/server/rest_api/redis_stream_manager.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Redis stream manager for reading and writing SSE chunks with batching and TTL."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import AsyncIterator, Dict, List, Optional
|
||||
|
||||
from letta.data_sources.redis_client import AsyncRedisClient
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RedisSSEStreamWriter:
    """
    Efficiently writes SSE chunks to Redis streams with batching and TTL management.

    Features:
    - Batches writes using Redis pipelines for performance
    - Automatically sets/refreshes TTL on streams
    - Tracks sequential IDs for cursor-based recovery
    - Handles flush on size or time thresholds
    """

    def __init__(
        self,
        redis_client: AsyncRedisClient,
        flush_interval: float = 0.5,
        flush_size: int = 50,
        stream_ttl_seconds: int = 10800,  # 3 hours default
        max_stream_length: int = 10000,  # Max entries per stream
    ):
        """
        Initialize the Redis SSE stream writer.

        Args:
            redis_client: Redis client instance
            flush_interval: Seconds between automatic flushes
            flush_size: Number of chunks to buffer before flushing
            stream_ttl_seconds: TTL for streams in seconds (default: 3 hours)
            max_stream_length: Maximum entries per stream before trimming
        """
        self.redis = redis_client
        self.flush_interval = flush_interval
        self.flush_size = flush_size
        self.stream_ttl = stream_ttl_seconds
        self.max_stream_length = max_stream_length

        # Buffer for batching: run_id -> list of chunks
        self.buffer: Dict[str, List[Dict]] = defaultdict(list)
        # Track sequence IDs per run (1-based so a cursor of 0 means "from the start")
        self.seq_counters: Dict[str, int] = defaultdict(lambda: 1)
        # Track last flush time per run
        self.last_flush: Dict[str, float] = defaultdict(float)

        # Background flush task
        self._flush_task = None
        self._running = False

    async def start(self):
        """Start the background flush task."""
        if not self._running:
            self._running = True
            self._flush_task = asyncio.create_task(self._periodic_flush())

    async def stop(self):
        """Stop the background flush task and flush remaining data."""
        self._running = False
        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        # Drain anything still buffered so no chunks are lost on shutdown.
        for run_id in list(self.buffer.keys()):
            if self.buffer[run_id]:
                await self._flush_run(run_id)

    async def write_chunk(
        self,
        run_id: str,
        data: str,
        is_complete: bool = False,
    ) -> int:
        """
        Write an SSE chunk to the buffer for a specific run.

        Args:
            run_id: The run ID to write to
            data: SSE-formatted chunk data
            is_complete: Whether this is the final chunk

        Returns:
            The sequence ID assigned to this chunk
        """
        seq_id = self.seq_counters[run_id]
        self.seq_counters[run_id] += 1

        chunk = {
            "seq_id": seq_id,
            "data": data,
            "timestamp": int(time.time() * 1000),
        }

        if is_complete:
            # Stored as the string "true" because Redis stream fields are flat strings.
            chunk["complete"] = "true"

        self.buffer[run_id].append(chunk)

        # Flush on size threshold, completion, or elapsed time since last flush.
        should_flush = (
            len(self.buffer[run_id]) >= self.flush_size or is_complete or (time.time() - self.last_flush[run_id]) > self.flush_interval
        )

        if should_flush:
            await self._flush_run(run_id)

        return seq_id

    async def _flush_run(self, run_id: str):
        """Flush buffered chunks for a specific run to Redis."""
        if not self.buffer[run_id]:
            return

        # Swap the buffer out first so concurrent writers append to a fresh list.
        chunks = self.buffer[run_id]
        self.buffer[run_id] = []
        stream_key = f"sse:run:{run_id}"

        try:
            client = await self.redis.get_client()

            # Single pipeline: XADD each chunk (trimmed approximately) and refresh the TTL.
            async with client.pipeline(transaction=False) as pipe:
                for chunk in chunks:
                    pipe.xadd(stream_key, chunk, maxlen=self.max_stream_length, approximate=True)

                pipe.expire(stream_key, self.stream_ttl)

                await pipe.execute()

            self.last_flush[run_id] = time.time()

            logger.debug(
                f"Flushed {len(chunks)} chunks to Redis stream {stream_key}, " f"seq_ids {chunks[0]['seq_id']}-{chunks[-1]['seq_id']}"
            )

            if chunks[-1].get("complete") == "true":
                self._cleanup_run(run_id)

        except Exception as e:
            logger.error(f"Failed to flush chunks for run {run_id}: {e}")
            # Put chunks back in buffer to retry
            self.buffer[run_id] = chunks + self.buffer[run_id]
            raise

    async def _periodic_flush(self):
        """Background task to periodically flush buffers."""
        while self._running:
            try:
                await asyncio.sleep(self.flush_interval)

                # Check each run for time-based flush
                current_time = time.time()
                runs_to_flush = [
                    run_id
                    for run_id, last_flush in self.last_flush.items()
                    if (current_time - last_flush) > self.flush_interval and self.buffer[run_id]
                ]

                for run_id in runs_to_flush:
                    await self._flush_run(run_id)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in periodic flush: {e}")

    def _cleanup_run(self, run_id: str):
        """Clean up tracking data for a completed run."""
        self.buffer.pop(run_id, None)
        self.seq_counters.pop(run_id, None)
        self.last_flush.pop(run_id, None)

    async def mark_complete(self, run_id: str):
        """Mark a stream as complete and flush."""
        # Add a [DONE] marker
        await self.write_chunk(run_id, "data: [DONE]\n\n", is_complete=True)
|
||||
|
||||
|
||||
async def create_background_stream_processor(
    stream_generator,
    redis_client: AsyncRedisClient,
    run_id: str,
    writer: Optional[RedisSSEStreamWriter] = None,
) -> None:
    """
    Process a stream in the background and store chunks to Redis.

    This function consumes the stream generator and writes all chunks
    to Redis for later retrieval.

    Args:
        stream_generator: The async generator yielding SSE chunks
        redis_client: Redis client instance
        run_id: The run ID to store chunks under
        writer: Optional pre-configured writer (creates new if not provided)
    """
    # Only stop the writer on exit if we created (and therefore own) it.
    owns_writer = writer is None
    if owns_writer:
        writer = RedisSSEStreamWriter(redis_client)
        await writer.start()

    try:
        async for item in stream_generator:
            # Some generators yield (chunk, metadata) tuples; keep only the chunk.
            chunk = item[0] if isinstance(item, tuple) else item

            # Terminal markers: explicit [DONE] sentinel or an SSE error event.
            terminal = isinstance(chunk, str) and ("data: [DONE]" in chunk or "event: error" in chunk)
            await writer.write_chunk(run_id=run_id, data=chunk, is_complete=terminal)

            if terminal:
                break

    except Exception as e:
        logger.error(f"Error processing stream for run {run_id}: {e}")
        # Write error chunk
        error_chunk = {"error": {"message": str(e)}}
        await writer.write_chunk(run_id=run_id, data=f"event: error\ndata: {json.dumps(error_chunk)}\n\n", is_complete=True)
    finally:
        if owns_writer:
            await writer.stop()
|
||||
|
||||
|
||||
async def redis_sse_stream_generator(
    redis_client: AsyncRedisClient,
    run_id: str,
    starting_after: Optional[int] = None,
    poll_interval: float = 0.1,
    batch_size: int = 100,
) -> AsyncIterator[str]:
    """
    Generate SSE events from Redis stream chunks.

    This generator reads chunks stored in Redis streams and yields them as SSE events.
    It supports cursor-based recovery by allowing you to start from a specific seq_id.

    Args:
        redis_client: Redis client instance
        run_id: The run ID to read chunks for
        starting_after: Sequential ID (integer) to start reading from (default: None for beginning)
        poll_interval: Seconds to wait between polls when no new data (default: 0.1)
        batch_size: Number of entries to read per batch (default: 100)

    Yields:
        SSE-formatted chunks from the Redis stream
    """
    stream_key = f"sse:run:{run_id}"
    # Redis stream cursor: "-" means start from the very beginning of the stream.
    last_redis_id = "-"
    # Application-level cursor: chunks with seq_id <= this value are skipped
    # (used to resume a client after disconnect). It stays fixed; forward
    # progress is tracked by last_redis_id instead.
    cursor_seq_id = starting_after or 0

    logger.debug(f"Starting redis_sse_stream_generator for run_id={run_id}, stream_key={stream_key}")

    while True:
        # XRANGE is inclusive of the start ID, so the previous last entry is re-read
        # each iteration and filtered out below.
        entries = await redis_client.xrange(stream_key, start=last_redis_id, count=batch_size)

        if entries:
            yielded_any = False
            for entry_id, fields in entries:
                # Skip the inclusive re-read of the entry we already processed.
                if entry_id == last_redis_id:
                    continue

                chunk_seq_id = int(fields.get("seq_id", 0))
                if chunk_seq_id > cursor_seq_id:
                    data = fields.get("data", "")
                    if not data:
                        logger.debug(f"No data found for chunk {chunk_seq_id} in run {run_id}")
                        continue

                    # Backfill run_id/seq_id into the serialized JSON payload if the
                    # producer left them null (string patching avoids a JSON re-parse).
                    if '"run_id":null' in data:
                        data = data.replace('"run_id":null', f'"run_id":"{run_id}"')

                    if '"seq_id":null' in data:
                        data = data.replace('"seq_id":null', f'"seq_id":{chunk_seq_id}')

                    yield data
                    yielded_any = True

                    # "complete" marks the final chunk of the run; stop the generator.
                    if fields.get("complete") == "true":
                        return

                # Advance the Redis cursor even for skipped entries.
                last_redis_id = entry_id

            # A full batch of only-skipped entries: re-poll immediately without sleeping.
            if not yielded_any and len(entries) > 1:
                continue

        # Nothing new (empty read, or only the already-seen entry): back off briefly.
        if not entries or (len(entries) == 1 and entries[0][0] == last_redis_id):
            await asyncio.sleep(poll_interval)
|
||||
@@ -14,7 +14,7 @@ from starlette.responses import Response, StreamingResponse
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.constants import AGENT_ID_PATTERN, DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX
|
||||
from letta.data_sources.redis_client import get_redis_client
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.errors import AgentExportIdMappingError, AgentExportProcessingError, AgentFileImportError, AgentNotFoundForExportError
|
||||
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
@@ -26,6 +26,7 @@ from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.agent_file import AgentFileSchema
|
||||
from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.enums import JobType
|
||||
from letta.schemas.file import AgentFileAttachment, PaginatedAgentFiles
|
||||
from letta.schemas.group import Group
|
||||
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType
|
||||
@@ -39,6 +40,7 @@ from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
@@ -249,6 +251,7 @@ async def import_agent(
|
||||
override_existing_tools: bool = True,
|
||||
project_id: str | None = None,
|
||||
strip_messages: bool = False,
|
||||
env_vars: Optional[dict[str, Any]] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Import an agent using the new AgentFileSchema format.
|
||||
@@ -259,7 +262,13 @@ async def import_agent(
|
||||
raise HTTPException(status_code=422, detail=f"Invalid agent file schema: {e!s}")
|
||||
|
||||
try:
|
||||
import_result = await server.agent_serialization_manager.import_file(schema=agent_schema, actor=actor)
|
||||
import_result = await server.agent_serialization_manager.import_file(
|
||||
schema=agent_schema,
|
||||
actor=actor,
|
||||
append_copy_suffix=append_copy_suffix,
|
||||
override_existing_tools=override_existing_tools,
|
||||
env_vars=env_vars,
|
||||
)
|
||||
|
||||
if not import_result.success:
|
||||
raise HTTPException(
|
||||
@@ -297,7 +306,9 @@ async def import_agent_serialized(
|
||||
False,
|
||||
description="If set to True, strips all messages from the agent before importing.",
|
||||
),
|
||||
env_vars: Optional[Dict[str, Any]] = Form(None, description="Environment variables to pass to the agent for tool execution."),
|
||||
env_vars_json: Optional[str] = Form(
|
||||
None, description="Environment variables as a JSON string to pass to the agent for tool execution."
|
||||
),
|
||||
):
|
||||
"""
|
||||
Import a serialized agent file and recreate the agent(s) in the system.
|
||||
@@ -311,6 +322,17 @@ async def import_agent_serialized(
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Corrupted agent file format.")
|
||||
|
||||
# Parse env_vars_json if provided
|
||||
env_vars = None
|
||||
if env_vars_json:
|
||||
try:
|
||||
env_vars = json.loads(env_vars_json)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="env_vars_json must be a valid JSON string")
|
||||
|
||||
if not isinstance(env_vars, dict):
|
||||
raise HTTPException(status_code=400, detail="env_vars_json must be a valid JSON string")
|
||||
|
||||
# Check if the JSON is AgentFileSchema or AgentSchema
|
||||
# TODO: This is kind of hacky, but should work as long as dont' change the schema
|
||||
if "agents" in agent_json and isinstance(agent_json.get("agents"), list):
|
||||
@@ -323,6 +345,7 @@ async def import_agent_serialized(
|
||||
override_existing_tools=override_existing_tools,
|
||||
project_id=project_id,
|
||||
strip_messages=strip_messages,
|
||||
env_vars=env_vars,
|
||||
)
|
||||
else:
|
||||
# This is a legacy AgentSchema
|
||||
@@ -728,6 +751,49 @@ async def list_agent_folders(
|
||||
return await server.agent_manager.list_attached_sources_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/files", response_model=PaginatedAgentFiles, operation_id="list_agent_files")
|
||||
async def list_agent_files(
|
||||
agent_id: str,
|
||||
cursor: Optional[str] = Query(None, description="Pagination cursor from previous response"),
|
||||
limit: int = Query(20, ge=1, le=100, description="Number of items to return (1-100)"),
|
||||
is_open: Optional[bool] = Query(None, description="Filter by open status (true for open files, false for closed files)"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get the files attached to an agent with their open/closed status (paginated).
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# get paginated file-agent relationships for this agent
|
||||
file_agents, next_cursor, has_more = await server.file_agent_manager.list_files_for_agent_paginated(
|
||||
agent_id=agent_id, actor=actor, cursor=cursor, limit=limit, is_open=is_open
|
||||
)
|
||||
|
||||
# enrich with file and source metadata
|
||||
enriched_files = []
|
||||
for fa in file_agents:
|
||||
# get source/folder metadata
|
||||
source = await server.source_manager.get_source_by_id(source_id=fa.source_id, actor=actor)
|
||||
|
||||
# build response object
|
||||
attachment = AgentFileAttachment(
|
||||
id=fa.id,
|
||||
file_id=fa.file_id,
|
||||
file_name=fa.file_name,
|
||||
folder_id=fa.source_id,
|
||||
folder_name=source.name if source else "Unknown",
|
||||
is_open=fa.is_open,
|
||||
last_accessed_at=fa.last_accessed_at,
|
||||
visible_content=fa.visible_content,
|
||||
start_line=fa.start_line,
|
||||
end_line=fa.end_line,
|
||||
)
|
||||
enriched_files.append(attachment)
|
||||
|
||||
return PaginatedAgentFiles(files=enriched_files, next_cursor=next_cursor, has_more=has_more)
|
||||
|
||||
|
||||
# TODO: remove? can also get with agent blocks
|
||||
@router.get("/{agent_id}/core-memory", response_model=Memory, operation_id="retrieve_agent_memory")
|
||||
async def retrieve_agent_memory(
|
||||
@@ -999,7 +1065,8 @@ async def send_message(
|
||||
"bedrock",
|
||||
"ollama",
|
||||
"azure",
|
||||
"together",
|
||||
"xai",
|
||||
"groq",
|
||||
]
|
||||
|
||||
# Create a new run for execution tracking
|
||||
@@ -1143,7 +1210,8 @@ async def send_message_streaming(
|
||||
"bedrock",
|
||||
"ollama",
|
||||
"azure",
|
||||
"together",
|
||||
"xai",
|
||||
"groq",
|
||||
]
|
||||
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"]
|
||||
|
||||
@@ -1157,6 +1225,7 @@ async def send_message_streaming(
|
||||
metadata={
|
||||
"job_type": "send_message_streaming",
|
||||
"agent_id": agent_id,
|
||||
"background": request.background or False,
|
||||
},
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
@@ -1211,8 +1280,58 @@ async def send_message_streaming(
|
||||
else SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER
|
||||
),
|
||||
)
|
||||
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
|
||||
|
||||
if request.background and settings.track_agent_run:
|
||||
if isinstance(redis_client, NoopAsyncRedisClient):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Background streaming requires Redis to be running. "
|
||||
"Please ensure Redis is properly configured. "
|
||||
f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}"
|
||||
),
|
||||
)
|
||||
|
||||
if request.stream_tokens and model_compatible_token_streaming:
|
||||
raw_stream = agent_loop.step_stream(
|
||||
input_messages=request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
)
|
||||
else:
|
||||
raw_stream = agent_loop.step_stream_no_tokens(
|
||||
request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
create_background_stream_processor(
|
||||
stream_generator=raw_stream,
|
||||
redis_client=redis_client,
|
||||
run_id=run.id,
|
||||
)
|
||||
)
|
||||
|
||||
stream = redis_sse_stream_generator(
|
||||
redis_client=redis_client,
|
||||
run_id=run.id,
|
||||
)
|
||||
|
||||
if request.include_pings and settings.enable_keepalive:
|
||||
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval)
|
||||
|
||||
return StreamingResponseWithStatusCode(
|
||||
stream,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
if request.stream_tokens and model_compatible_token_streaming:
|
||||
raw_stream = agent_loop.step_stream(
|
||||
input_messages=request.messages,
|
||||
@@ -1350,6 +1469,7 @@ async def _process_message_background(
|
||||
"google_vertex",
|
||||
"bedrock",
|
||||
"ollama",
|
||||
"groq",
|
||||
]
|
||||
if agent_eligible and model_compatible:
|
||||
if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
|
||||
@@ -1538,7 +1658,8 @@ async def preview_raw_payload(
|
||||
"bedrock",
|
||||
"ollama",
|
||||
"azure",
|
||||
"together",
|
||||
"xai",
|
||||
"groq",
|
||||
]
|
||||
|
||||
if agent_eligible and model_compatible:
|
||||
@@ -1608,7 +1729,8 @@ async def summarize_agent_conversation(
|
||||
"bedrock",
|
||||
"ollama",
|
||||
"azure",
|
||||
"together",
|
||||
"xai",
|
||||
"groq",
|
||||
]
|
||||
|
||||
if agent_eligible and model_compatible:
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, UploadFile
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
import letta.constants as constants
|
||||
from letta.helpers.pinecone_utils import (
|
||||
@@ -34,7 +35,7 @@ from letta.services.file_processor.file_types import get_allowed_media_types, ge
|
||||
from letta.services.file_processor.parser.markitdown_parser import MarkitdownFileParser
|
||||
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task, sanitize_filename
|
||||
from letta.utils import safe_create_file_processing_task, safe_create_task, sanitize_filename
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -138,8 +139,11 @@ async def create_folder(
|
||||
# TODO: need to asyncify this
|
||||
if not folder_create.embedding_config:
|
||||
if not folder_create.embedding:
|
||||
# TODO: modify error type
|
||||
raise ValueError("Must specify either embedding or embedding_config in request")
|
||||
if settings.default_embedding_handle is None:
|
||||
# TODO: modify error type
|
||||
raise ValueError("Must specify either embedding or embedding_config in request")
|
||||
else:
|
||||
folder_create.embedding = settings.default_embedding_handle
|
||||
folder_create.embedding_config = await server.get_embedding_config_from_handle_async(
|
||||
handle=folder_create.embedding,
|
||||
embedding_chunk_size=folder_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
@@ -257,13 +261,16 @@ async def upload_file_to_folder(
|
||||
|
||||
# Store original filename and handle duplicate logic
|
||||
# Use custom name if provided, otherwise use the uploaded file's name
|
||||
original_filename = sanitize_filename(name if name else file.filename) # Basic sanitization only
|
||||
# If custom name is provided, use it directly (it's just metadata, not a filesystem path)
|
||||
# Otherwise, sanitize the uploaded filename for security
|
||||
original_filename = name if name else sanitize_filename(file.filename) # Basic sanitization only
|
||||
|
||||
# Check if duplicate exists
|
||||
existing_file = await server.file_manager.get_file_by_original_name_and_source(
|
||||
original_filename=original_filename, source_id=folder_id, actor=actor
|
||||
)
|
||||
|
||||
unique_filename = None
|
||||
if existing_file:
|
||||
# Duplicate found, handle based on strategy
|
||||
if duplicate_handling == DuplicateFileHandling.ERROR:
|
||||
@@ -305,8 +312,11 @@ async def upload_file_to_folder(
|
||||
|
||||
# Use cloud processing for all files (simple files always, complex files with Mistral key)
|
||||
logger.info("Running experimental cloud based file processing...")
|
||||
safe_create_task(
|
||||
safe_create_file_processing_task(
|
||||
load_file_to_source_cloud(server, agent_states, content, folder_id, actor, folder.embedding_config, file_metadata),
|
||||
file_metadata=file_metadata,
|
||||
server=server,
|
||||
actor=actor,
|
||||
logger=logger,
|
||||
label="file_processor.process",
|
||||
)
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
from datetime import timedelta
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
||||
from pydantic import Field
|
||||
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.enums import JobStatus, JobType, MessageRole
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.letta_request import RetrieveStreamRequest
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.step import Step
|
||||
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
router = APIRouter(prefix="/runs", tags=["runs"])
|
||||
|
||||
@@ -19,6 +26,14 @@ router = APIRouter(prefix="/runs", tags=["runs"])
|
||||
def list_runs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."),
|
||||
background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
before: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(50, description="Maximum number of runs to return"),
|
||||
ascending: bool = Query(
|
||||
False,
|
||||
description="Whether to sort agents oldest to newest (True) or newest to oldest (False, default)",
|
||||
),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
@@ -26,18 +41,29 @@ def list_runs(
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
runs = [Run.from_job(job) for job in server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)]
|
||||
|
||||
if not agent_ids:
|
||||
return runs
|
||||
|
||||
return [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
|
||||
runs = [
|
||||
Run.from_job(job)
|
||||
for job in server.job_manager.list_jobs(
|
||||
actor=actor,
|
||||
job_type=JobType.RUN,
|
||||
limit=limit,
|
||||
before=before,
|
||||
after=after,
|
||||
ascending=False,
|
||||
)
|
||||
]
|
||||
if agent_ids:
|
||||
runs = [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
|
||||
if background is not None:
|
||||
runs = [run for run in runs if "background" in run.metadata and run.metadata["background"] == background]
|
||||
return runs
|
||||
|
||||
|
||||
@router.get("/active", response_model=List[Run], operation_id="list_active_runs")
|
||||
def list_active_runs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."),
|
||||
background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
@@ -46,13 +72,15 @@ def list_active_runs(
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
active_runs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.RUN)
|
||||
|
||||
active_runs = [Run.from_job(job) for job in active_runs]
|
||||
|
||||
if not agent_ids:
|
||||
return active_runs
|
||||
if agent_ids:
|
||||
active_runs = [run for run in active_runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
|
||||
|
||||
return [run for run in active_runs if "agent_id" in run.metadata and run.metadata["agent_id"] in agent_ids]
|
||||
if background is not None:
|
||||
active_runs = [run for run in active_runs if "background" in run.metadata and run.metadata["background"] == background]
|
||||
|
||||
return active_runs
|
||||
|
||||
|
||||
@router.get("/{run_id}", response_model=Run, operation_id="retrieve_run")
|
||||
@@ -213,3 +241,65 @@ async def delete_run(
|
||||
return Run.from_job(job)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{run_id}/stream",
|
||||
response_model=None,
|
||||
operation_id="retrieve_stream",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"text/event-stream": {"description": "Server-Sent Events stream"},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async def retrieve_stream(
|
||||
run_id: str,
|
||||
request: RetrieveStreamRequest = Body(None),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
try:
|
||||
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
run = Run.from_job(job)
|
||||
|
||||
if "background" not in run.metadata or not run.metadata["background"]:
|
||||
raise HTTPException(status_code=400, detail="Run was not created in background mode, so it cannot be retrieved.")
|
||||
|
||||
if run.created_at < get_utc_time() - timedelta(hours=3):
|
||||
raise HTTPException(status_code=410, detail="Run was created more than 3 hours ago, and is now expired.")
|
||||
|
||||
redis_client = await get_redis_client()
|
||||
|
||||
if isinstance(redis_client, NoopAsyncRedisClient):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Background streaming requires Redis to be running. "
|
||||
"Please ensure Redis is properly configured. "
|
||||
f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}"
|
||||
),
|
||||
)
|
||||
|
||||
stream = redis_sse_stream_generator(
|
||||
redis_client=redis_client,
|
||||
run_id=run_id,
|
||||
starting_after=request.starting_after,
|
||||
poll_interval=request.poll_interval,
|
||||
batch_size=request.batch_size,
|
||||
)
|
||||
|
||||
if request.include_pings and settings.enable_keepalive:
|
||||
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval)
|
||||
|
||||
return StreamingResponseWithStatusCode(
|
||||
stream,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -2,18 +2,17 @@ import asyncio
|
||||
import mimetypes
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, UploadFile
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
import letta.constants as constants
|
||||
from letta.helpers.pinecone_utils import (
|
||||
delete_file_records_from_pinecone_index,
|
||||
delete_source_records_from_pinecone_index,
|
||||
list_pinecone_index_for_files,
|
||||
should_use_pinecone,
|
||||
)
|
||||
from letta.log import get_logger
|
||||
@@ -35,14 +34,13 @@ from letta.services.file_processor.file_types import get_allowed_media_types, ge
|
||||
from letta.services.file_processor.parser.markitdown_parser import MarkitdownFileParser
|
||||
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task, sanitize_filename
|
||||
from letta.utils import safe_create_file_processing_task, safe_create_task, sanitize_filename
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Register all supported file types with Python's mimetypes module
|
||||
register_mime_types()
|
||||
|
||||
|
||||
router = APIRouter(prefix="/sources", tags=["sources"])
|
||||
|
||||
|
||||
@@ -139,8 +137,11 @@ async def create_source(
|
||||
# TODO: need to asyncify this
|
||||
if not source_create.embedding_config:
|
||||
if not source_create.embedding:
|
||||
# TODO: modify error type
|
||||
raise ValueError("Must specify either embedding or embedding_config in request")
|
||||
if settings.default_embedding_handle is None:
|
||||
# TODO: modify error type
|
||||
raise ValueError("Must specify either embedding or embedding_config in request")
|
||||
else:
|
||||
source_create.embedding = settings.default_embedding_handle
|
||||
source_create.embedding_config = await server.get_embedding_config_from_handle_async(
|
||||
handle=source_create.embedding,
|
||||
embedding_chunk_size=source_create.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
@@ -258,7 +259,9 @@ async def upload_file_to_source(
|
||||
|
||||
# Store original filename and handle duplicate logic
|
||||
# Use custom name if provided, otherwise use the uploaded file's name
|
||||
original_filename = sanitize_filename(name if name else file.filename) # Basic sanitization only
|
||||
# If custom name is provided, use it directly (it's just metadata, not a filesystem path)
|
||||
# Otherwise, sanitize the uploaded filename for security
|
||||
original_filename = name if name else sanitize_filename(file.filename) # Basic sanitization only
|
||||
|
||||
# Check if duplicate exists
|
||||
existing_file = await server.file_manager.get_file_by_original_name_and_source(
|
||||
@@ -307,8 +310,11 @@ async def upload_file_to_source(
|
||||
|
||||
# Use cloud processing for all files (simple files always, complex files with Mistral key)
|
||||
logger.info("Running experimental cloud based file processing...")
|
||||
safe_create_task(
|
||||
safe_create_file_processing_task(
|
||||
load_file_to_source_cloud(server, agent_states, content, source_id, actor, source.embedding_config, file_metadata),
|
||||
file_metadata=file_metadata,
|
||||
server=server,
|
||||
actor=actor,
|
||||
logger=logger,
|
||||
label="file_processor.process",
|
||||
)
|
||||
@@ -358,6 +364,10 @@ async def list_source_files(
|
||||
limit: int = Query(1000, description="Number of files to return"),
|
||||
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
|
||||
include_content: bool = Query(False, description="Whether to include full file content"),
|
||||
check_status_updates: bool = Query(
|
||||
True,
|
||||
description="Whether to check and update file processing status (from the vector db service). If False, will not fetch and update the status, which may lead to performance gains.",
|
||||
),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
@@ -372,6 +382,7 @@ async def list_source_files(
|
||||
actor=actor,
|
||||
include_content=include_content,
|
||||
strip_directory_prefix=True, # TODO: Reconsider this. This is purely for aesthetics.
|
||||
check_status_updates=check_status_updates,
|
||||
)
|
||||
|
||||
|
||||
@@ -400,51 +411,8 @@ async def get_file_metadata(
|
||||
if file_metadata.source_id != source_id:
|
||||
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found in source {source_id}.")
|
||||
|
||||
# Check for timeout if status is not terminal
|
||||
if not file_metadata.processing_status.is_terminal_state():
|
||||
if file_metadata.created_at:
|
||||
# Handle timezone differences between PostgreSQL (timezone-aware) and SQLite (timezone-naive)
|
||||
if settings.letta_pg_uri_no_default:
|
||||
# PostgreSQL: both datetimes are timezone-aware
|
||||
timeout_threshold = datetime.now(timezone.utc) - timedelta(minutes=settings.file_processing_timeout_minutes)
|
||||
file_created_at = file_metadata.created_at
|
||||
else:
|
||||
# SQLite: both datetimes should be timezone-naive
|
||||
timeout_threshold = datetime.utcnow() - timedelta(minutes=settings.file_processing_timeout_minutes)
|
||||
file_created_at = file_metadata.created_at
|
||||
|
||||
if file_created_at < timeout_threshold:
|
||||
# Move file to error status with timeout message
|
||||
timeout_message = settings.file_processing_timeout_error_message.format(settings.file_processing_timeout_minutes)
|
||||
try:
|
||||
file_metadata = await server.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=actor, processing_status=FileProcessingStatus.ERROR, error_message=timeout_message
|
||||
)
|
||||
except ValueError as e:
|
||||
# state transition was blocked - log it but don't fail the request
|
||||
logger.warning(f"Could not update file to timeout error state: {str(e)}")
|
||||
# continue with existing file_metadata
|
||||
|
||||
if should_use_pinecone() and file_metadata.processing_status == FileProcessingStatus.EMBEDDING:
|
||||
ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor)
|
||||
logger.info(
|
||||
f"Embedded chunks {len(ids)}/{file_metadata.total_chunks} for {file_id} ({file_metadata.file_name}) in organization {actor.organization_id}"
|
||||
)
|
||||
|
||||
if len(ids) != file_metadata.chunks_embedded or len(ids) == file_metadata.total_chunks:
|
||||
if len(ids) != file_metadata.total_chunks:
|
||||
file_status = file_metadata.processing_status
|
||||
else:
|
||||
file_status = FileProcessingStatus.COMPLETED
|
||||
try:
|
||||
file_metadata = await server.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
|
||||
)
|
||||
except ValueError as e:
|
||||
# state transition was blocked - this is a race condition
|
||||
# log it but don't fail the request since we're just reading metadata
|
||||
logger.warning(f"Race condition detected in get_file_metadata: {str(e)}")
|
||||
# return the current file state without updating
|
||||
# Check and update file status (timeout check and pinecone embedding sync)
|
||||
file_metadata = await server.file_manager.check_and_update_file_status(file_metadata, actor)
|
||||
|
||||
return file_metadata
|
||||
|
||||
|
||||
@@ -1,18 +1,28 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from letta.schemas.provider_trace import ProviderTrace
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
router = APIRouter(prefix="/telemetry", tags=["telemetry"])
|
||||
|
||||
|
||||
@router.get("/{step_id}", response_model=ProviderTrace, operation_id="retrieve_provider_trace")
|
||||
@router.get("/{step_id}", response_model=Optional[ProviderTrace], operation_id="retrieve_provider_trace")
|
||||
async def retrieve_provider_trace_by_step_id(
|
||||
step_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: str | None = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
return await server.telemetry_manager.get_provider_trace_by_step_id_async(
|
||||
step_id=step_id, actor=await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
)
|
||||
provider_trace = None
|
||||
if settings.track_provider_trace:
|
||||
try:
|
||||
provider_trace = await server.telemetry_manager.get_provider_trace_by_step_id_async(
|
||||
step_id=step_id, actor=await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
return provider_trace
|
||||
|
||||
@@ -547,7 +547,7 @@ async def add_mcp_server_to_config(
|
||||
server_name=request.server_name,
|
||||
server_type=request.type,
|
||||
server_url=request.server_url,
|
||||
token=request.resolve_token() if not request.custom_headers else None,
|
||||
token=request.resolve_token(),
|
||||
custom_headers=request.custom_headers,
|
||||
)
|
||||
elif isinstance(request, StreamableHTTPServerConfig):
|
||||
@@ -555,7 +555,7 @@ async def add_mcp_server_to_config(
|
||||
server_name=request.server_name,
|
||||
server_type=request.type,
|
||||
server_url=request.server_url,
|
||||
token=request.resolve_token() if not request.custom_headers else None,
|
||||
token=request.resolve_token(),
|
||||
custom_headers=request.custom_headers,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import anyio
|
||||
from fastapi.responses import StreamingResponse
|
||||
from starlette.types import Send
|
||||
|
||||
from letta.errors import LettaUnexpectedStreamCancellationError
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.letta_ping import LettaPing
|
||||
@@ -288,33 +289,11 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
||||
|
||||
# Handle client timeouts (should throw error to inform user)
|
||||
except asyncio.CancelledError as exc:
|
||||
logger.warning("Stream was cancelled due to client timeout or unexpected disconnection")
|
||||
logger.warning("Stream was terminated due to unexpected cancellation from server")
|
||||
# Handle unexpected cancellation with error
|
||||
more_body = False
|
||||
error_resp = {"error": {"message": "Request was unexpectedly cancelled (likely due to client timeout or disconnection)"}}
|
||||
error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
|
||||
if not self.response_started:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 408, # Request Timeout
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
raise
|
||||
if self._client_connected:
|
||||
try:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": error_event,
|
||||
"more_body": more_body,
|
||||
}
|
||||
)
|
||||
except anyio.ClosedResourceError:
|
||||
self._client_connected = False
|
||||
capture_sentry_exception(exc)
|
||||
return
|
||||
raise LettaUnexpectedStreamCancellationError("Stream was terminated due to unexpected cancellation from server")
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception("Unhandled Streaming Error")
|
||||
|
||||
@@ -2068,7 +2068,6 @@ class SyncServer(Server):
|
||||
raise ValueError(f"No client was created for MCP server: {mcp_server_name}")
|
||||
|
||||
tools = await self.mcp_clients[mcp_server_name].list_tools()
|
||||
|
||||
# Add health information to each tool
|
||||
for tool in tools:
|
||||
if tool.inputSchema:
|
||||
|
||||
@@ -42,6 +42,7 @@ from letta.orm.sandbox_config import AgentEnvironmentVariable
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.prompts.prompt_generator import PromptGenerator
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent, get_prompt_template_for_agent_type
|
||||
from letta.schemas.block import DEFAULT_BLOCKS
|
||||
@@ -89,7 +90,6 @@ from letta.services.helpers.agent_manager_helper import (
|
||||
check_supports_structured_output,
|
||||
compile_system_message,
|
||||
derive_system_message,
|
||||
get_system_message_from_compiled_memory,
|
||||
initialize_message_sequence,
|
||||
initialize_message_sequence_async,
|
||||
package_initial_message_sequence,
|
||||
@@ -1783,7 +1783,7 @@ class AgentManager:
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats separately)
|
||||
|
||||
new_system_message_str = get_system_message_from_compiled_memory(
|
||||
new_system_message_str = PromptGenerator.get_system_message_from_compiled_memory(
|
||||
system_prompt=agent_state.system,
|
||||
memory_with_sources=curr_memory_str,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
|
||||
from letta.errors import AgentExportIdMappingError, AgentExportProcessingError, AgentFileImportError, AgentNotFoundForExportError
|
||||
from letta.errors import (
|
||||
AgentExportIdMappingError,
|
||||
AgentExportProcessingError,
|
||||
AgentFileExportError,
|
||||
AgentFileImportError,
|
||||
AgentNotFoundForExportError,
|
||||
)
|
||||
from letta.helpers.pinecone_utils import should_use_pinecone
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState, CreateAgent
|
||||
@@ -420,6 +428,8 @@ class AgentSerializationManager:
|
||||
self,
|
||||
schema: AgentFileSchema,
|
||||
actor: User,
|
||||
append_copy_suffix: bool = False,
|
||||
override_existing_tools: bool = True,
|
||||
dry_run: bool = False,
|
||||
env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ImportResult:
|
||||
@@ -481,7 +491,9 @@ class AgentSerializationManager:
|
||||
pydantic_tools.append(Tool(**tool_schema.model_dump(exclude={"id"})))
|
||||
|
||||
# bulk upsert all tools at once
|
||||
created_tools = await self.tool_manager.bulk_upsert_tools_async(pydantic_tools, actor)
|
||||
created_tools = await self.tool_manager.bulk_upsert_tools_async(
|
||||
pydantic_tools, actor, override_existing_tools=override_existing_tools
|
||||
)
|
||||
|
||||
# map file ids to database ids
|
||||
# note: tools are matched by name during upsert, so we need to match by name here too
|
||||
@@ -513,8 +525,20 @@ class AgentSerializationManager:
|
||||
if schema.sources:
|
||||
# convert source schemas to pydantic sources
|
||||
pydantic_sources = []
|
||||
|
||||
# First, do a fast batch check for existing source names to avoid conflicts
|
||||
source_names_to_check = [s.name for s in schema.sources]
|
||||
existing_source_names = await self.source_manager.get_existing_source_names(source_names_to_check, actor)
|
||||
|
||||
for source_schema in schema.sources:
|
||||
source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"})
|
||||
|
||||
# Check if source name already exists, if so add unique suffix
|
||||
original_name = source_data["name"]
|
||||
if original_name in existing_source_names:
|
||||
unique_suffix = uuid.uuid4().hex[:8]
|
||||
source_data["name"] = f"{original_name}_{unique_suffix}"
|
||||
|
||||
pydantic_sources.append(Source(**source_data))
|
||||
|
||||
# bulk upsert all sources at once
|
||||
@@ -523,13 +547,15 @@ class AgentSerializationManager:
|
||||
# map file ids to database ids
|
||||
# note: sources are matched by name during upsert, so we need to match by name here too
|
||||
created_sources_by_name = {source.name: source for source in created_sources}
|
||||
for source_schema in schema.sources:
|
||||
created_source = created_sources_by_name.get(source_schema.name)
|
||||
for i, source_schema in enumerate(schema.sources):
|
||||
# Use the pydantic source name (which may have been modified for uniqueness)
|
||||
source_name = pydantic_sources[i].name
|
||||
created_source = created_sources_by_name.get(source_name)
|
||||
if created_source:
|
||||
file_to_db_ids[source_schema.id] = created_source.id
|
||||
imported_count += 1
|
||||
else:
|
||||
logger.warning(f"Source {source_schema.name} was not created during bulk upsert")
|
||||
logger.warning(f"Source {source_name} was not created during bulk upsert")
|
||||
|
||||
# 4. Create files (depends on sources)
|
||||
for file_schema in schema.files:
|
||||
@@ -548,38 +574,49 @@ class AgentSerializationManager:
|
||||
imported_count += 1
|
||||
|
||||
# 5. Process files for chunking/embedding (depends on files and sources)
|
||||
if should_use_pinecone():
|
||||
embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config)
|
||||
else:
|
||||
embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config)
|
||||
file_processor = FileProcessor(
|
||||
file_parser=self.file_parser,
|
||||
embedder=embedder,
|
||||
actor=actor,
|
||||
using_pinecone=self.using_pinecone,
|
||||
)
|
||||
# Start background tasks for file processing
|
||||
background_tasks = []
|
||||
if schema.files and any(f.content for f in schema.files):
|
||||
if should_use_pinecone():
|
||||
embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config)
|
||||
else:
|
||||
embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config)
|
||||
file_processor = FileProcessor(
|
||||
file_parser=self.file_parser,
|
||||
embedder=embedder,
|
||||
actor=actor,
|
||||
using_pinecone=self.using_pinecone,
|
||||
)
|
||||
|
||||
for file_schema in schema.files:
|
||||
if file_schema.content: # Only process files with content
|
||||
file_db_id = file_to_db_ids[file_schema.id]
|
||||
source_db_id = file_to_db_ids[file_schema.source_id]
|
||||
for file_schema in schema.files:
|
||||
if file_schema.content: # Only process files with content
|
||||
file_db_id = file_to_db_ids[file_schema.id]
|
||||
source_db_id = file_to_db_ids[file_schema.source_id]
|
||||
|
||||
# Get the created file metadata (with caching)
|
||||
if file_db_id not in file_metadata_cache:
|
||||
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
|
||||
file_metadata = file_metadata_cache[file_db_id]
|
||||
# Get the created file metadata (with caching)
|
||||
if file_db_id not in file_metadata_cache:
|
||||
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
|
||||
file_metadata = file_metadata_cache[file_db_id]
|
||||
|
||||
# Save the db call of fetching content again
|
||||
file_metadata.content = file_schema.content
|
||||
# Save the db call of fetching content again
|
||||
file_metadata.content = file_schema.content
|
||||
|
||||
# Process the file for chunking/embedding
|
||||
passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_db_id)
|
||||
imported_count += len(passages)
|
||||
# Create background task for file processing
|
||||
# TODO: This can be moved to celery or RQ or something
|
||||
task = asyncio.create_task(
|
||||
self._process_file_async(
|
||||
file_metadata=file_metadata, source_id=source_db_id, file_processor=file_processor, actor=actor
|
||||
)
|
||||
)
|
||||
background_tasks.append(task)
|
||||
logger.info(f"Started background processing for file {file_metadata.file_name} (ID: {file_db_id})")
|
||||
|
||||
# 6. Create agents with empty message history
|
||||
for agent_schema in schema.agents:
|
||||
# Convert AgentSchema back to CreateAgent, remapping tool/block IDs
|
||||
agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"})
|
||||
if append_copy_suffix:
|
||||
agent_data["name"] = agent_data.get("name") + "_copy"
|
||||
|
||||
# Remap tool_ids from file IDs to database IDs
|
||||
if agent_data.get("tool_ids"):
|
||||
@@ -589,6 +626,10 @@ class AgentSerializationManager:
|
||||
if agent_data.get("block_ids"):
|
||||
agent_data["block_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["block_ids"]]
|
||||
|
||||
# Remap source_ids from file IDs to database IDs
|
||||
if agent_data.get("source_ids"):
|
||||
agent_data["source_ids"] = [file_to_db_ids[file_id] for file_id in agent_data["source_ids"]]
|
||||
|
||||
if env_vars:
|
||||
for var in agent_data["tool_exec_environment_variables"]:
|
||||
var["value"] = env_vars.get(var["key"], "")
|
||||
@@ -635,14 +676,16 @@ class AgentSerializationManager:
|
||||
for file_agent_schema in agent_schema.files_agents:
|
||||
file_db_id = file_to_db_ids[file_agent_schema.file_id]
|
||||
|
||||
# Use cached file metadata if available
|
||||
# Use cached file metadata if available (with content)
|
||||
if file_db_id not in file_metadata_cache:
|
||||
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
|
||||
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(
|
||||
file_db_id, actor, include_content=True
|
||||
)
|
||||
file_metadata = file_metadata_cache[file_db_id]
|
||||
files_for_agent.append(file_metadata)
|
||||
|
||||
if file_agent_schema.visible_content:
|
||||
visible_content_map[file_db_id] = file_agent_schema.visible_content
|
||||
visible_content_map[file_metadata.file_name] = file_agent_schema.visible_content
|
||||
|
||||
# Bulk attach files to agent
|
||||
await self.file_agent_manager.attach_files_bulk(
|
||||
@@ -669,9 +712,19 @@ class AgentSerializationManager:
|
||||
file_to_db_ids[group.id] = created_group.id
|
||||
imported_count += 1
|
||||
|
||||
# prepare result message
|
||||
num_background_tasks = len(background_tasks)
|
||||
if num_background_tasks > 0:
|
||||
message = (
|
||||
f"Import completed successfully. Imported {imported_count} entities. "
|
||||
f"{num_background_tasks} file(s) are being processed in the background for embeddings."
|
||||
)
|
||||
else:
|
||||
message = f"Import completed successfully. Imported {imported_count} entities."
|
||||
|
||||
return ImportResult(
|
||||
success=True,
|
||||
message=f"Import completed successfully. Imported {imported_count} entities.",
|
||||
message=message,
|
||||
imported_count=imported_count,
|
||||
imported_agent_ids=imported_agent_ids,
|
||||
id_mappings=file_to_db_ids,
|
||||
@@ -849,3 +902,47 @@ class AgentSerializationManager:
|
||||
except AttributeError:
|
||||
allowed = model_cls.__fields__.keys() # Pydantic v1
|
||||
return {k: v for k, v in data.items() if k in allowed}
|
||||
|
||||
async def _process_file_async(self, file_metadata: FileMetadata, source_id: str, file_processor: FileProcessor, actor: User):
|
||||
"""
|
||||
Process a file asynchronously in the background.
|
||||
|
||||
This method handles chunking and embedding of file content without blocking
|
||||
the main import process.
|
||||
|
||||
Args:
|
||||
file_metadata: The file metadata with content
|
||||
source_id: The database ID of the source
|
||||
file_processor: The file processor instance to use
|
||||
actor: The user performing the action
|
||||
"""
|
||||
file_id = file_metadata.id
|
||||
file_name = file_metadata.file_name
|
||||
|
||||
try:
|
||||
logger.info(f"Starting background processing for file {file_name} (ID: {file_id})")
|
||||
|
||||
# process the file for chunking/embedding
|
||||
passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_id)
|
||||
|
||||
logger.info(f"Successfully processed file {file_name} with {len(passages)} passages")
|
||||
|
||||
# file status is automatically updated to COMPLETED by process_imported_file
|
||||
return passages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process file {file_name} (ID: {file_id}) in background: {e}")
|
||||
|
||||
# update file status to ERROR
|
||||
try:
|
||||
await self.file_manager.update_file_status(
|
||||
file_id=file_id,
|
||||
actor=actor,
|
||||
processing_status=FileProcessingStatus.ERROR,
|
||||
error_message=str(e) if str(e) else f"Agent serialization failed: {type(e).__name__}",
|
||||
)
|
||||
except Exception as update_error:
|
||||
logger.error(f"Failed to update file status to ERROR for {file_id}: {update_error}")
|
||||
|
||||
# we don't re-raise here since this is a background task
|
||||
# the file will be marked as ERROR and the import can continue
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import func, select, update
|
||||
@@ -9,6 +9,8 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from letta.constants import MAX_FILENAME_LENGTH
|
||||
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.file import FileContent as FileContentModel
|
||||
from letta.orm.file import FileMetadata as FileMetadataModel
|
||||
@@ -20,8 +22,11 @@ from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.source_metadata import FileStats, OrganizationSourcesStats, SourceStats
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DuplicateFileError(Exception):
|
||||
"""Raised when a duplicate file is encountered and error handling is specified"""
|
||||
@@ -174,6 +179,10 @@ class FileManager:
|
||||
if processing_status is None and error_message is None and total_chunks is None and chunks_embedded is None:
|
||||
raise ValueError("Nothing to update")
|
||||
|
||||
# validate that ERROR status must have an error message
|
||||
if processing_status == FileProcessingStatus.ERROR and not error_message:
|
||||
raise ValueError("Error message is required when setting processing status to ERROR")
|
||||
|
||||
values: dict[str, object] = {"updated_at": datetime.utcnow()}
|
||||
if processing_status is not None:
|
||||
values["processing_status"] = processing_status
|
||||
@@ -273,6 +282,79 @@ class FileManager:
|
||||
)
|
||||
return await file_orm.to_pydantic_async()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def check_and_update_file_status(
|
||||
self,
|
||||
file_metadata: PydanticFileMetadata,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticFileMetadata:
|
||||
"""
|
||||
Check and update file status for timeout and embedding completion.
|
||||
|
||||
This method consolidates logic for:
|
||||
1. Checking if a file has timed out during processing
|
||||
2. Checking Pinecone embedding status and updating counts
|
||||
|
||||
Args:
|
||||
file_metadata: The file metadata to check
|
||||
actor: User performing the check
|
||||
|
||||
Returns:
|
||||
Updated file metadata with current status
|
||||
"""
|
||||
# check for timeout if status is not terminal
|
||||
if not file_metadata.processing_status.is_terminal_state():
|
||||
if file_metadata.created_at:
|
||||
# handle timezone differences between PostgreSQL (timezone-aware) and SQLite (timezone-naive)
|
||||
if settings.letta_pg_uri_no_default:
|
||||
# postgresql: both datetimes are timezone-aware
|
||||
timeout_threshold = datetime.now(timezone.utc) - timedelta(minutes=settings.file_processing_timeout_minutes)
|
||||
file_created_at = file_metadata.created_at
|
||||
else:
|
||||
# sqlite: both datetimes should be timezone-naive
|
||||
timeout_threshold = datetime.utcnow() - timedelta(minutes=settings.file_processing_timeout_minutes)
|
||||
file_created_at = file_metadata.created_at
|
||||
|
||||
if file_created_at < timeout_threshold:
|
||||
# move file to error status with timeout message
|
||||
timeout_message = settings.file_processing_timeout_error_message.format(settings.file_processing_timeout_minutes)
|
||||
try:
|
||||
file_metadata = await self.update_file_status(
|
||||
file_id=file_metadata.id,
|
||||
actor=actor,
|
||||
processing_status=FileProcessingStatus.ERROR,
|
||||
error_message=timeout_message,
|
||||
)
|
||||
except ValueError as e:
|
||||
# state transition was blocked - log it but don't fail
|
||||
logger.warning(f"Could not update file to timeout error state: {str(e)}")
|
||||
# continue with existing file_metadata
|
||||
|
||||
# check pinecone embedding status
|
||||
if should_use_pinecone() and file_metadata.processing_status == FileProcessingStatus.EMBEDDING:
|
||||
ids = await list_pinecone_index_for_files(file_id=file_metadata.id, actor=actor)
|
||||
logger.info(
|
||||
f"Embedded chunks {len(ids)}/{file_metadata.total_chunks} for {file_metadata.id} ({file_metadata.file_name}) in organization {actor.organization_id}"
|
||||
)
|
||||
|
||||
if len(ids) != file_metadata.chunks_embedded or len(ids) == file_metadata.total_chunks:
|
||||
if len(ids) != file_metadata.total_chunks:
|
||||
file_status = file_metadata.processing_status
|
||||
else:
|
||||
file_status = FileProcessingStatus.COMPLETED
|
||||
try:
|
||||
file_metadata = await self.update_file_status(
|
||||
file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
|
||||
)
|
||||
except ValueError as e:
|
||||
# state transition was blocked - this is a race condition
|
||||
# log it but don't fail since we're just checking status
|
||||
logger.warning(f"Race condition detected in check_and_update_file_status: {str(e)}")
|
||||
# return the current file state without updating
|
||||
|
||||
return file_metadata
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def upsert_file_content(
|
||||
@@ -328,8 +410,22 @@ class FileManager:
|
||||
limit: Optional[int] = 50,
|
||||
include_content: bool = False,
|
||||
strip_directory_prefix: bool = False,
|
||||
check_status_updates: bool = False,
|
||||
) -> List[PydanticFileMetadata]:
|
||||
"""List all files with optional pagination."""
|
||||
"""List all files with optional pagination and status checking.
|
||||
|
||||
Args:
|
||||
source_id: Source to list files from
|
||||
actor: User performing the request
|
||||
after: Pagination cursor
|
||||
limit: Maximum number of files to return
|
||||
include_content: Whether to include file content
|
||||
strip_directory_prefix: Whether to strip directory prefix from filenames
|
||||
check_status_updates: Whether to check and update status for timeout and embedding completion
|
||||
|
||||
Returns:
|
||||
List of file metadata
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
options = [selectinload(FileMetadataModel.content)] if include_content else None
|
||||
|
||||
@@ -341,10 +437,19 @@ class FileManager:
|
||||
source_id=source_id,
|
||||
query_options=options,
|
||||
)
|
||||
return [
|
||||
await file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix)
|
||||
for file in files
|
||||
]
|
||||
|
||||
# convert all files to pydantic models
|
||||
file_metadatas = await asyncio.gather(
|
||||
*[file.to_pydantic_async(include_content=include_content, strip_directory_prefix=strip_directory_prefix) for file in files]
|
||||
)
|
||||
|
||||
# if status checking is enabled, check all files concurrently
|
||||
if check_status_updates:
|
||||
file_metadatas = await asyncio.gather(
|
||||
*[self.check_and_update_file_status(file_metadata, actor) for file_metadata in file_metadatas]
|
||||
)
|
||||
|
||||
return file_metadatas
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -264,7 +264,10 @@ class FileProcessor:
|
||||
},
|
||||
)
|
||||
await self.file_manager.update_file_status(
|
||||
file_id=file_metadata.id, actor=self.actor, processing_status=FileProcessingStatus.ERROR, error_message=str(e)
|
||||
file_id=file_metadata.id,
|
||||
actor=self.actor,
|
||||
processing_status=FileProcessingStatus.ERROR,
|
||||
error_message=str(e) if str(e) else f"File processing failed: {type(e).__name__}",
|
||||
)
|
||||
|
||||
return []
|
||||
@@ -361,7 +364,7 @@ class FileProcessor:
|
||||
file_id=file_metadata.id,
|
||||
actor=self.actor,
|
||||
processing_status=FileProcessingStatus.ERROR,
|
||||
error_message=str(e),
|
||||
error_message=str(e) if str(e) else f"Import file processing failed: {type(e).__name__}",
|
||||
)
|
||||
|
||||
return []
|
||||
|
||||
@@ -293,6 +293,66 @@ class FileAgentManager:
|
||||
else:
|
||||
return [r.to_pydantic() for r in rows]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_files_for_agent_paginated(
|
||||
self,
|
||||
agent_id: str,
|
||||
actor: PydanticUser,
|
||||
cursor: Optional[str] = None,
|
||||
limit: int = 20,
|
||||
is_open: Optional[bool] = None,
|
||||
) -> tuple[List[PydanticFileAgent], Optional[str], bool]:
|
||||
"""
|
||||
Return paginated file associations for an agent.
|
||||
|
||||
Args:
|
||||
agent_id: The agent ID to get files for
|
||||
actor: User performing the action
|
||||
cursor: Pagination cursor (file-agent ID to start after)
|
||||
limit: Maximum number of results to return
|
||||
is_open: Optional filter for open/closed status (None = all, True = open only, False = closed only)
|
||||
|
||||
Returns:
|
||||
Tuple of (file_agents, next_cursor, has_more)
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
conditions = [
|
||||
FileAgentModel.agent_id == agent_id,
|
||||
FileAgentModel.organization_id == actor.organization_id,
|
||||
FileAgentModel.is_deleted == False,
|
||||
]
|
||||
|
||||
# apply is_open filter if specified
|
||||
if is_open is not None:
|
||||
conditions.append(FileAgentModel.is_open == is_open)
|
||||
|
||||
# apply cursor if provided (get records after this ID)
|
||||
if cursor:
|
||||
conditions.append(FileAgentModel.id > cursor)
|
||||
|
||||
query = select(FileAgentModel).where(and_(*conditions))
|
||||
|
||||
# order by ID for stable pagination
|
||||
query = query.order_by(FileAgentModel.id)
|
||||
|
||||
# fetch limit + 1 to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await session.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
# check if we got more records than requested (meaning there are more pages)
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
# trim back to the requested limit
|
||||
rows = rows[:limit]
|
||||
|
||||
# get cursor for next page (ID of last item in current page)
|
||||
next_cursor = rows[-1].id if rows else None
|
||||
|
||||
return [r.to_pydantic() for r in rows], next_cursor, has_more
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_agents_for_file(
|
||||
|
||||
@@ -21,7 +21,7 @@ from letta.constants import (
|
||||
STRUCTURED_OUTPUT_MODELS,
|
||||
)
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import format_datetime, get_local_time, get_local_time_fast
|
||||
from letta.helpers.datetime_helpers import get_local_time
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
@@ -33,6 +33,7 @@ from letta.orm.sources_agents import SourcesAgents
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.prompts import gpt_system
|
||||
from letta.prompts.prompt_generator import PromptGenerator
|
||||
from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
@@ -217,60 +218,6 @@ def derive_system_message(agent_type: AgentType, enable_sleeptime: Optional[bool
|
||||
return system
|
||||
|
||||
|
||||
# TODO: This code is kind of wonky and deserves a rewrite
|
||||
def compile_memory_metadata_block(
|
||||
memory_edit_timestamp: datetime,
|
||||
timezone: str,
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: Optional[int] = 0,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a memory metadata block for the agent's system prompt.
|
||||
|
||||
This creates a structured metadata section that informs the agent about
|
||||
the current state of its memory systems, including timing information
|
||||
and memory counts. This helps the agent understand what information
|
||||
is available through its tools.
|
||||
|
||||
Args:
|
||||
memory_edit_timestamp: When memory blocks were last modified
|
||||
timezone: The timezone to use for formatting timestamps (e.g., 'America/Los_Angeles')
|
||||
previous_message_count: Number of messages in recall memory (conversation history)
|
||||
archival_memory_size: Number of items in archival memory (long-term storage)
|
||||
|
||||
Returns:
|
||||
A formatted string containing the memory metadata block with XML-style tags
|
||||
|
||||
Example Output:
|
||||
<memory_metadata>
|
||||
- The current time is: 2024-01-15 10:30 AM PST
|
||||
- Memory blocks were last modified: 2024-01-15 09:00 AM PST
|
||||
- 42 previous messages between you and the user are stored in recall memory (use tools to access them)
|
||||
- 156 total memories you created are stored in archival memory (use tools to access them)
|
||||
</memory_metadata>
|
||||
"""
|
||||
# Put the timestamp in the local timezone (mimicking get_local_time())
|
||||
timestamp_str = format_datetime(memory_edit_timestamp, timezone)
|
||||
|
||||
# Create a metadata block of info so the agent knows about the metadata of out-of-context memories
|
||||
metadata_lines = [
|
||||
"<memory_metadata>",
|
||||
f"- The current time is: {get_local_time_fast(timezone)}",
|
||||
f"- Memory blocks were last modified: {timestamp_str}",
|
||||
f"- {previous_message_count} previous messages between you and the user are stored in recall memory (use tools to access them)",
|
||||
]
|
||||
|
||||
# Only include archival memory line if there are archival memories
|
||||
if archival_memory_size is not None and archival_memory_size > 0:
|
||||
metadata_lines.append(
|
||||
f"- {archival_memory_size} total memories you created are stored in archival memory (use tools to access them)"
|
||||
)
|
||||
|
||||
metadata_lines.append("</memory_metadata>")
|
||||
memory_metadata_block = "\n".join(metadata_lines)
|
||||
return memory_metadata_block
|
||||
|
||||
|
||||
class PreserveMapping(dict):
|
||||
"""Used to preserve (do not modify) undefined variables in the system prompt"""
|
||||
|
||||
@@ -331,7 +278,7 @@ def compile_system_message(
|
||||
raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
|
||||
else:
|
||||
# TODO should this all put into the memory.__repr__ function?
|
||||
memory_metadata_string = compile_memory_metadata_block(
|
||||
memory_metadata_string = PromptGenerator.compile_memory_metadata_block(
|
||||
memory_edit_timestamp=in_context_memory_last_edit,
|
||||
previous_message_count=previous_message_count,
|
||||
archival_memory_size=archival_memory_size,
|
||||
@@ -372,154 +319,6 @@ def compile_system_message(
|
||||
return formatted_prompt
|
||||
|
||||
|
||||
@trace_method
|
||||
def get_system_message_from_compiled_memory(
|
||||
system_prompt: str,
|
||||
memory_with_sources: str,
|
||||
in_context_memory_last_edit: datetime, # TODO move this inside of BaseMemory?
|
||||
timezone: str,
|
||||
user_defined_variables: Optional[dict] = None,
|
||||
append_icm_if_missing: bool = True,
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: int = 0,
|
||||
) -> str:
|
||||
"""Prepare the final/full system message that will be fed into the LLM API
|
||||
|
||||
The base system message may be templated, in which case we need to render the variables.
|
||||
|
||||
The following are reserved variables:
|
||||
- CORE_MEMORY: the in-context memory of the LLM
|
||||
"""
|
||||
if user_defined_variables is not None:
|
||||
# TODO eventually support the user defining their own variables to inject
|
||||
raise NotImplementedError
|
||||
else:
|
||||
variables = {}
|
||||
|
||||
# Add the protected memory variable
|
||||
if IN_CONTEXT_MEMORY_KEYWORD in variables:
|
||||
raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
|
||||
else:
|
||||
# TODO should this all put into the memory.__repr__ function?
|
||||
memory_metadata_string = compile_memory_metadata_block(
|
||||
memory_edit_timestamp=in_context_memory_last_edit,
|
||||
previous_message_count=previous_message_count,
|
||||
archival_memory_size=archival_memory_size,
|
||||
timezone=timezone,
|
||||
)
|
||||
|
||||
full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string
|
||||
|
||||
# Add to the variables list to inject
|
||||
variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string
|
||||
|
||||
if template_format == "f-string":
|
||||
memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
|
||||
|
||||
# Catch the special case where the system prompt is unformatted
|
||||
if append_icm_if_missing:
|
||||
if memory_variable_string not in system_prompt:
|
||||
# In this case, append it to the end to make sure memory is still injected
|
||||
# warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
|
||||
system_prompt += "\n\n" + memory_variable_string
|
||||
|
||||
# render the variables using the built-in templater
|
||||
try:
|
||||
if user_defined_variables:
|
||||
formatted_prompt = safe_format(system_prompt, variables)
|
||||
else:
|
||||
formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
|
||||
|
||||
else:
|
||||
# TODO support for mustache and jinja2
|
||||
raise NotImplementedError(template_format)
|
||||
|
||||
return formatted_prompt
|
||||
|
||||
|
||||
@trace_method
|
||||
async def compile_system_message_async(
|
||||
system_prompt: str,
|
||||
in_context_memory: Memory,
|
||||
in_context_memory_last_edit: datetime, # TODO move this inside of BaseMemory?
|
||||
timezone: str,
|
||||
user_defined_variables: Optional[dict] = None,
|
||||
append_icm_if_missing: bool = True,
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: int = 0,
|
||||
tool_rules_solver: Optional[ToolRulesSolver] = None,
|
||||
sources: Optional[List] = None,
|
||||
max_files_open: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Prepare the final/full system message that will be fed into the LLM API
|
||||
|
||||
The base system message may be templated, in which case we need to render the variables.
|
||||
|
||||
The following are reserved variables:
|
||||
- CORE_MEMORY: the in-context memory of the LLM
|
||||
"""
|
||||
|
||||
# Add tool rule constraints if available
|
||||
tool_constraint_block = None
|
||||
if tool_rules_solver is not None:
|
||||
tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()
|
||||
|
||||
if user_defined_variables is not None:
|
||||
# TODO eventually support the user defining their own variables to inject
|
||||
raise NotImplementedError
|
||||
else:
|
||||
variables = {}
|
||||
|
||||
# Add the protected memory variable
|
||||
if IN_CONTEXT_MEMORY_KEYWORD in variables:
|
||||
raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
|
||||
else:
|
||||
# TODO should this all put into the memory.__repr__ function?
|
||||
memory_metadata_string = compile_memory_metadata_block(
|
||||
memory_edit_timestamp=in_context_memory_last_edit,
|
||||
previous_message_count=previous_message_count,
|
||||
archival_memory_size=archival_memory_size,
|
||||
timezone=timezone,
|
||||
)
|
||||
|
||||
memory_with_sources = await in_context_memory.compile_in_thread_async(
|
||||
tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open
|
||||
)
|
||||
full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string
|
||||
|
||||
# Add to the variables list to inject
|
||||
variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string
|
||||
|
||||
if template_format == "f-string":
|
||||
memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
|
||||
|
||||
# Catch the special case where the system prompt is unformatted
|
||||
if append_icm_if_missing:
|
||||
if memory_variable_string not in system_prompt:
|
||||
# In this case, append it to the end to make sure memory is still injected
|
||||
# warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
|
||||
system_prompt += "\n\n" + memory_variable_string
|
||||
|
||||
# render the variables using the built-in templater
|
||||
try:
|
||||
if user_defined_variables:
|
||||
formatted_prompt = safe_format(system_prompt, variables)
|
||||
else:
|
||||
formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
|
||||
|
||||
else:
|
||||
# TODO support for mustache and jinja2
|
||||
raise NotImplementedError(template_format)
|
||||
|
||||
return formatted_prompt
|
||||
|
||||
|
||||
@trace_method
|
||||
def initialize_message_sequence(
|
||||
agent_state: AgentState,
|
||||
@@ -601,7 +400,7 @@ async def initialize_message_sequence_async(
|
||||
if memory_edit_timestamp is None:
|
||||
memory_edit_timestamp = get_local_time()
|
||||
|
||||
full_system_message = await compile_system_message_async(
|
||||
full_system_message = await PromptGenerator.compile_system_message_async(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
|
||||
@@ -70,13 +70,16 @@ def runtime_override_tool_json_schema(
|
||||
tool_list: list[JsonDict],
|
||||
response_format: ResponseFormatUnion | None,
|
||||
request_heartbeat: bool = True,
|
||||
terminal_tools: set[str] | None = None,
|
||||
) -> list[JsonDict]:
|
||||
"""Override the tool JSON schemas at runtime if certain conditions are met.
|
||||
|
||||
Cases:
|
||||
1. We will inject `send_message` tool calls with `response_format` if provided
|
||||
2. Tools will have an additional `request_heartbeat` parameter added.
|
||||
2. Tools will have an additional `request_heartbeat` parameter added (except for terminal tools).
|
||||
"""
|
||||
if terminal_tools is None:
|
||||
terminal_tools = set()
|
||||
for tool_json in tool_list:
|
||||
if tool_json["name"] == SEND_MESSAGE_TOOL_NAME and response_format and response_format.type != ResponseFormatType.text:
|
||||
if response_format.type == ResponseFormatType.json_schema:
|
||||
@@ -89,8 +92,8 @@ def runtime_override_tool_json_schema(
|
||||
"properties": {},
|
||||
}
|
||||
if request_heartbeat:
|
||||
# TODO (cliandy): see support for tool control loop parameters
|
||||
if tool_json["name"] != SEND_MESSAGE_TOOL_NAME:
|
||||
# Only add request_heartbeat to non-terminal tools
|
||||
if tool_json["name"] not in terminal_tools:
|
||||
tool_json["parameters"]["properties"][REQUEST_HEARTBEAT_PARAM] = {
|
||||
"type": "boolean",
|
||||
"description": REQUEST_HEARTBEAT_DESCRIPTION,
|
||||
|
||||
@@ -14,9 +14,15 @@ logger = get_logger(__name__)
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncBaseMCPClient:
|
||||
def __init__(self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
|
||||
# HTTP headers
|
||||
AGENT_ID_HEADER = "X-Agent-Id"
|
||||
|
||||
def __init__(
|
||||
self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None, agent_id: Optional[str] = None
|
||||
):
|
||||
self.server_config = server_config
|
||||
self.oauth_provider = oauth_provider
|
||||
self.agent_id = agent_id
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.initialized = False
|
||||
|
||||
@@ -16,8 +16,10 @@ logger = get_logger(__name__)
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncSSEMCPClient(AsyncBaseMCPClient):
|
||||
def __init__(self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
|
||||
super().__init__(server_config, oauth_provider)
|
||||
def __init__(
|
||||
self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None, agent_id: Optional[str] = None
|
||||
):
|
||||
super().__init__(server_config, oauth_provider, agent_id)
|
||||
|
||||
async def _initialize_connection(self, server_config: SSEServerConfig) -> None:
|
||||
headers = {}
|
||||
@@ -27,6 +29,9 @@ class AsyncSSEMCPClient(AsyncBaseMCPClient):
|
||||
if server_config.auth_header and server_config.auth_token:
|
||||
headers[server_config.auth_header] = server_config.auth_token
|
||||
|
||||
if self.agent_id:
|
||||
headers[self.AGENT_ID_HEADER] = self.agent_id
|
||||
|
||||
# Use OAuth provider if available, otherwise use regular headers
|
||||
if self.oauth_provider:
|
||||
sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
@@ -10,6 +12,9 @@ logger = get_logger(__name__)
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncStdioMCPClient(AsyncBaseMCPClient):
|
||||
def __init__(self, server_config: StdioServerConfig, oauth_provider=None, agent_id: Optional[str] = None):
|
||||
super().__init__(server_config, oauth_provider, agent_id)
|
||||
|
||||
async def _initialize_connection(self, server_config: StdioServerConfig) -> None:
|
||||
args = [arg.split() for arg in server_config.args]
|
||||
# flatten
|
||||
|
||||
@@ -12,8 +12,13 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
|
||||
def __init__(self, server_config: StreamableHTTPServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
|
||||
super().__init__(server_config, oauth_provider)
|
||||
def __init__(
|
||||
self,
|
||||
server_config: StreamableHTTPServerConfig,
|
||||
oauth_provider: Optional[OAuthClientProvider] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
):
|
||||
super().__init__(server_config, oauth_provider, agent_id)
|
||||
|
||||
async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
|
||||
if not isinstance(server_config, StreamableHTTPServerConfig):
|
||||
@@ -28,6 +33,10 @@ class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
|
||||
if server_config.auth_header and server_config.auth_token:
|
||||
headers[server_config.auth_header] = server_config.auth_token
|
||||
|
||||
# Add agent ID header if provided
|
||||
if self.agent_id:
|
||||
headers[self.AGENT_ID_HEADER] = self.agent_id
|
||||
|
||||
# Use OAuth provider if available, otherwise use regular headers
|
||||
if self.oauth_provider:
|
||||
streamable_http_cm = streamablehttp_client(
|
||||
|
||||
@@ -41,6 +41,7 @@ from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPCl
|
||||
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
||||
from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import tool_settings
|
||||
from letta.utils import enforce_types, printd
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -55,19 +56,18 @@ class MCPManager:
|
||||
self.cached_mcp_servers = {} # maps id -> async connection
|
||||
|
||||
@enforce_types
|
||||
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser) -> List[MCPTool]:
|
||||
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]:
|
||||
"""Get a list of all tools for a specific MCP server."""
|
||||
mcp_client = None
|
||||
try:
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config()
|
||||
mcp_client = await self.get_mcp_client(server_config, actor)
|
||||
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
# list tools
|
||||
tools = await mcp_client.list_tools()
|
||||
|
||||
# Add health information to each tool
|
||||
for tool in tools:
|
||||
if tool.inputSchema:
|
||||
@@ -92,33 +92,34 @@ class MCPManager:
|
||||
tool_args: Optional[Dict[str, Any]],
|
||||
environment_variables: Dict[str, str],
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> Tuple[str, bool]:
|
||||
"""Call a specific tool from a specific MCP server."""
|
||||
from letta.settings import tool_settings
|
||||
mcp_client = None
|
||||
try:
|
||||
if not tool_settings.mcp_read_from_config:
|
||||
# read from DB
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config(environment_variables)
|
||||
else:
|
||||
# read from config file
|
||||
mcp_config = self.read_mcp_config()
|
||||
if mcp_server_name not in mcp_config:
|
||||
raise ValueError(f"MCP server {mcp_server_name} not found in config.")
|
||||
server_config = mcp_config[mcp_server_name]
|
||||
|
||||
if not tool_settings.mcp_read_from_config:
|
||||
# read from DB
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config(environment_variables)
|
||||
else:
|
||||
# read from config file
|
||||
mcp_config = self.read_mcp_config()
|
||||
if mcp_server_name not in mcp_config:
|
||||
raise ValueError(f"MCP server {mcp_server_name} not found in config.")
|
||||
server_config = mcp_config[mcp_server_name]
|
||||
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
mcp_client = await self.get_mcp_client(server_config, actor)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
# call tool
|
||||
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
||||
logger.info(f"MCP Result: {result}, Success: {success}")
|
||||
# TODO: change to pydantic tool
|
||||
|
||||
await mcp_client.cleanup()
|
||||
|
||||
return result, success
|
||||
# call tool
|
||||
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
||||
logger.info(f"MCP Result: {result}, Success: {success}")
|
||||
# TODO: change to pydantic tool
|
||||
return result, success
|
||||
finally:
|
||||
if mcp_client:
|
||||
await mcp_client.cleanup()
|
||||
|
||||
@enforce_types
|
||||
async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
@@ -129,7 +130,6 @@ class MCPManager:
|
||||
raise ValueError(f"MCP server '{mcp_server_name}' not found")
|
||||
|
||||
mcp_tools = await self.list_mcp_server_tools(mcp_server_name, actor=actor)
|
||||
|
||||
for mcp_tool in mcp_tools:
|
||||
# TODO: @jnjpng move health check to tool class
|
||||
if mcp_tool.name == mcp_tool_name:
|
||||
@@ -450,6 +450,7 @@ class MCPManager:
|
||||
server_config: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig],
|
||||
actor: PydanticUser,
|
||||
oauth_provider: Optional[Any] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> Union[AsyncSSEMCPClient, AsyncStdioMCPClient, AsyncStreamableHTTPMCPClient]:
|
||||
"""
|
||||
Helper function to create the appropriate MCP client based on server configuration.
|
||||
@@ -482,13 +483,13 @@ class MCPManager:
|
||||
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
server_config = SSEServerConfig(**server_config.model_dump())
|
||||
return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider)
|
||||
return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
server_config = StdioServerConfig(**server_config.model_dump())
|
||||
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider)
|
||||
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
|
||||
elif server_config.type == MCPServerType.STREAMABLE_HTTP:
|
||||
server_config = StreamableHTTPServerConfig(**server_config.model_dump())
|
||||
return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider)
|
||||
return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
|
||||
else:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
|
||||
|
||||
@@ -143,7 +143,6 @@ class SourceManager:
|
||||
update_dict[col.name] = excluded[col.name]
|
||||
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
|
||||
|
||||
await session.execute(upsert_stmt)
|
||||
await session.commit()
|
||||
|
||||
@@ -397,3 +396,29 @@ class SourceManager:
|
||||
sources_orm = result.scalars().all()
|
||||
|
||||
return [source.to_pydantic() for source in sources_orm]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_existing_source_names(self, source_names: List[str], actor: PydanticUser) -> set[str]:
|
||||
"""
|
||||
Fast batch check to see which source names already exist for the organization.
|
||||
|
||||
Args:
|
||||
source_names: List of source names to check
|
||||
actor: User performing the action
|
||||
|
||||
Returns:
|
||||
Set of source names that already exist
|
||||
"""
|
||||
if not source_names:
|
||||
return set()
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(SourceModel.name).where(
|
||||
SourceModel.name.in_(source_names), SourceModel.organization_id == actor.organization_id, SourceModel.is_deleted == False
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
existing_names = result.scalars().all()
|
||||
|
||||
return set(existing_names)
|
||||
|
||||
@@ -15,6 +15,8 @@ from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.system import package_summarize_message_no_counts
|
||||
from letta.templates.template_helper import render_template
|
||||
@@ -36,6 +38,10 @@ class Summarizer:
|
||||
message_buffer_limit: int = 10,
|
||||
message_buffer_min: int = 3,
|
||||
partial_evict_summarizer_percentage: float = 0.30,
|
||||
agent_manager: Optional[AgentManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
actor: Optional[User] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
):
|
||||
self.mode = mode
|
||||
|
||||
@@ -46,6 +52,12 @@ class Summarizer:
|
||||
self.summarizer_agent = summarizer_agent
|
||||
self.partial_evict_summarizer_percentage = partial_evict_summarizer_percentage
|
||||
|
||||
# for partial buffer only
|
||||
self.agent_manager = agent_manager
|
||||
self.message_manager = message_manager
|
||||
self.actor = actor
|
||||
self.agent_id = agent_id
|
||||
|
||||
@trace_method
|
||||
async def summarize(
|
||||
self,
|
||||
@@ -121,9 +133,6 @@ class Summarizer:
|
||||
logger.debug("Not forcing summarization, returning in-context messages as is.")
|
||||
return all_in_context_messages, False
|
||||
|
||||
# Very ugly code to pull LLMConfig etc from the SummarizerAgent if we're not using it for anything else
|
||||
assert self.summarizer_agent is not None
|
||||
|
||||
# First step: determine how many messages to retain
|
||||
total_message_count = len(all_in_context_messages)
|
||||
assert self.partial_evict_summarizer_percentage >= 0.0 and self.partial_evict_summarizer_percentage <= 1.0
|
||||
@@ -147,15 +156,13 @@ class Summarizer:
|
||||
|
||||
# Dynamically get the LLMConfig from the summarizer agent
|
||||
# Pretty cringe code here that we need the agent for this but we don't use it
|
||||
agent_state = await self.summarizer_agent.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.summarizer_agent.agent_id, actor=self.summarizer_agent.actor
|
||||
)
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
|
||||
|
||||
# TODO if we do this via the "agent", then we can more easily allow toggling on the memory block version
|
||||
summary_message_str = await simple_summary(
|
||||
messages=messages_to_summarize,
|
||||
llm_config=agent_state.llm_config,
|
||||
actor=self.summarizer_agent.actor,
|
||||
actor=self.actor,
|
||||
include_ack=True,
|
||||
)
|
||||
|
||||
@@ -185,9 +192,9 @@ class Summarizer:
|
||||
)[0]
|
||||
|
||||
# Create the message in the DB
|
||||
await self.summarizer_agent.message_manager.create_many_messages_async(
|
||||
await self.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[summary_message_obj],
|
||||
actor=self.summarizer_agent.actor,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
updated_in_context_messages = all_in_context_messages[assistant_message_index:]
|
||||
@@ -354,7 +361,11 @@ async def simple_summary(messages: List[Message], llm_config: LLMConfig, actor:
|
||||
# NOTE: we should disable the inner_thoughts_in_kwargs here, because we don't use it
|
||||
# I'm leaving it commented it out for now for safety but is fine assuming the var here is a copy not a reference
|
||||
# llm_config.put_inner_thoughts_in_kwargs = False
|
||||
response_data = await llm_client.request_async(request_data, llm_config)
|
||||
try:
|
||||
response_data = await llm_client.request_async(request_data, llm_config)
|
||||
except Exception as e:
|
||||
# handle LLM error (likely a context window exceeded error)
|
||||
raise llm_client.handle_llm_error(e)
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, llm_config)
|
||||
if response.choices[0].message.content is None:
|
||||
logger.warning("No content returned from summarizer")
|
||||
|
||||
@@ -151,16 +151,16 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
offset = file_request.offset
|
||||
length = file_request.length
|
||||
|
||||
# Convert 1-indexed offset/length to 0-indexed start/end for LineChunker
|
||||
# Use 0-indexed offset/length directly for LineChunker
|
||||
start, end = None, None
|
||||
if offset is not None or length is not None:
|
||||
if offset is not None and offset < 1:
|
||||
raise ValueError(f"Offset for file {file_name} must be >= 1 (1-indexed), got {offset}")
|
||||
if offset is not None and offset < 0:
|
||||
raise ValueError(f"Offset for file {file_name} must be >= 0 (0-indexed), got {offset}")
|
||||
if length is not None and length < 1:
|
||||
raise ValueError(f"Length for file {file_name} must be >= 1, got {length}")
|
||||
|
||||
# Convert to 0-indexed for LineChunker
|
||||
start = (offset - 1) if offset is not None else None
|
||||
# Use offset directly as it's already 0-indexed
|
||||
start = offset if offset is not None else None
|
||||
if start is not None and length is not None:
|
||||
end = start + length
|
||||
else:
|
||||
@@ -193,7 +193,7 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
visible_content=visible_content,
|
||||
max_files_open=agent_state.max_files_open,
|
||||
start_line=start + 1 if start is not None else None, # convert to 1-indexed for user display
|
||||
end_line=end if end is not None else None, # end is already exclusive in slicing, so this is correct
|
||||
end_line=end if end is not None else None, # end is already exclusive, shows as 1-indexed inclusive
|
||||
)
|
||||
|
||||
opened_files.append(file_name)
|
||||
@@ -220,10 +220,14 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
for req in file_requests:
|
||||
previous_info = format_previous_range(req.file_name)
|
||||
if req.offset is not None and req.length is not None:
|
||||
end_line = req.offset + req.length - 1
|
||||
file_summaries.append(f"{req.file_name} (lines {req.offset}-{end_line}){previous_info}")
|
||||
# Display as 1-indexed for user readability: (offset+1) to (offset+length)
|
||||
start_line = req.offset + 1
|
||||
end_line = req.offset + req.length
|
||||
file_summaries.append(f"{req.file_name} (lines {start_line}-{end_line}){previous_info}")
|
||||
elif req.offset is not None:
|
||||
file_summaries.append(f"{req.file_name} (lines {req.offset}-end){previous_info}")
|
||||
# Display as 1-indexed
|
||||
start_line = req.offset + 1
|
||||
file_summaries.append(f"{req.file_name} (lines {start_line}-end){previous_info}")
|
||||
else:
|
||||
file_summaries.append(f"{req.file_name}{previous_info}")
|
||||
|
||||
|
||||
@@ -37,8 +37,10 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
# TODO: may need to have better client connection management
|
||||
|
||||
environment_variables = {}
|
||||
agent_id = None
|
||||
if agent_state:
|
||||
environment_variables = agent_state.get_agent_env_vars_as_dict()
|
||||
agent_id = agent_state.id
|
||||
|
||||
function_response, success = await mcp_manager.execute_mcp_server_tool(
|
||||
mcp_server_name=mcp_server_name,
|
||||
@@ -46,6 +48,7 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
tool_args=function_args,
|
||||
environment_variables=environment_variables,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
return ToolExecutionResult(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
@@ -129,6 +130,18 @@ class ToolExecutionManager:
|
||||
result.func_return = FUNCTION_RETURN_VALUE_TRUNCATED(return_str, len(return_str), tool.return_char_limit)
|
||||
return result
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
self.logger.error(f"Aysnc cancellation error executing tool {function_name}: {str(e)}")
|
||||
error_message = get_friendly_error_msg(
|
||||
function_name=function_name,
|
||||
exception_name=type(e).__name__,
|
||||
exception_message=str(e),
|
||||
)
|
||||
return ToolExecutionResult(
|
||||
status="error",
|
||||
func_return=error_message,
|
||||
stderr=[traceback.format_exc()],
|
||||
)
|
||||
except Exception as e:
|
||||
status = "error"
|
||||
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
|
||||
|
||||
@@ -184,7 +184,9 @@ class ToolManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def bulk_upsert_tools_async(self, pydantic_tools: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
|
||||
async def bulk_upsert_tools_async(
|
||||
self, pydantic_tools: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
|
||||
) -> List[PydanticTool]:
|
||||
"""
|
||||
Bulk create or update multiple tools in a single database transaction.
|
||||
|
||||
@@ -227,10 +229,10 @@ class ToolManager:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
# use optimized postgresql bulk upsert
|
||||
async with db_registry.async_session() as session:
|
||||
return await self._bulk_upsert_postgresql(session, pydantic_tools, actor)
|
||||
return await self._bulk_upsert_postgresql(session, pydantic_tools, actor, override_existing_tools)
|
||||
else:
|
||||
# fallback to individual upserts for sqlite
|
||||
return await self._upsert_tools_individually(pydantic_tools, actor)
|
||||
return await self._upsert_tools_individually(pydantic_tools, actor, override_existing_tools)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -784,8 +786,10 @@ class ToolManager:
|
||||
return await self._upsert_tools_individually(tool_data_list, actor)
|
||||
|
||||
@trace_method
|
||||
async def _bulk_upsert_postgresql(self, session, tool_data_list: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
|
||||
"""hyper-optimized postgresql bulk upsert using on_conflict_do_update."""
|
||||
async def _bulk_upsert_postgresql(
|
||||
self, session, tool_data_list: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
|
||||
) -> List[PydanticTool]:
|
||||
"""hyper-optimized postgresql bulk upsert using on_conflict_do_update or on_conflict_do_nothing."""
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
@@ -809,32 +813,51 @@ class ToolManager:
|
||||
# use postgresql's native bulk upsert
|
||||
stmt = insert(table).values(insert_data)
|
||||
|
||||
# on conflict, update all columns except id, created_at, and _created_by_id
|
||||
excluded = stmt.excluded
|
||||
update_dict = {}
|
||||
for col in table.columns:
|
||||
if col.name not in ("id", "created_at", "_created_by_id"):
|
||||
if col.name == "updated_at":
|
||||
update_dict[col.name] = func.now()
|
||||
else:
|
||||
update_dict[col.name] = excluded[col.name]
|
||||
if override_existing_tools:
|
||||
# on conflict, update all columns except id, created_at, and _created_by_id
|
||||
excluded = stmt.excluded
|
||||
update_dict = {}
|
||||
for col in table.columns:
|
||||
if col.name not in ("id", "created_at", "_created_by_id"):
|
||||
if col.name == "updated_at":
|
||||
update_dict[col.name] = func.now()
|
||||
else:
|
||||
update_dict[col.name] = excluded[col.name]
|
||||
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
|
||||
upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
|
||||
else:
|
||||
# on conflict, do nothing (skip existing tools)
|
||||
upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id"])
|
||||
|
||||
await session.execute(upsert_stmt)
|
||||
await session.commit()
|
||||
|
||||
# fetch results
|
||||
# fetch results (includes both inserted and skipped tools)
|
||||
tool_names = [tool.name for tool in tool_data_list]
|
||||
result_query = select(ToolModel).where(ToolModel.name.in_(tool_names), ToolModel.organization_id == actor.organization_id)
|
||||
result = await session.execute(result_query)
|
||||
return [tool.to_pydantic() for tool in result.scalars()]
|
||||
|
||||
@trace_method
|
||||
async def _upsert_tools_individually(self, tool_data_list: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]:
|
||||
async def _upsert_tools_individually(
|
||||
self, tool_data_list: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True
|
||||
) -> List[PydanticTool]:
|
||||
"""fallback to individual upserts for sqlite (original approach)."""
|
||||
tools = []
|
||||
for tool in tool_data_list:
|
||||
upserted_tool = await self.create_or_update_tool_async(tool, actor)
|
||||
tools.append(upserted_tool)
|
||||
if override_existing_tools:
|
||||
# update existing tools if they exist
|
||||
upserted_tool = await self.create_or_update_tool_async(tool, actor)
|
||||
tools.append(upserted_tool)
|
||||
else:
|
||||
# skip existing tools, only create new ones
|
||||
existing_tool_id = await self.get_tool_id_by_name_async(tool_name=tool.name, actor=actor)
|
||||
if existing_tool_id:
|
||||
# tool exists, fetch and return it without updating
|
||||
existing_tool = await self.get_tool_by_id_async(existing_tool_id, actor=actor)
|
||||
tools.append(existing_tool)
|
||||
else:
|
||||
# tool doesn't exist, create it
|
||||
created_tool = await self.create_tool_async(tool, actor=actor)
|
||||
tools.append(created_tool)
|
||||
return tools
|
||||
|
||||
@@ -252,6 +252,7 @@ class Settings(BaseSettings):
|
||||
track_errored_messages: bool = Field(default=True, description="Enable tracking for errored messages")
|
||||
track_stop_reason: bool = Field(default=True, description="Enable tracking stop reason on steps.")
|
||||
track_agent_run: bool = Field(default=True, description="Enable tracking agent run with cancellation support")
|
||||
track_provider_trace: bool = Field(default=True, description="Enable tracking raw llm request and response at each step")
|
||||
|
||||
# FastAPI Application Settings
|
||||
uvicorn_workers: int = 1
|
||||
|
||||
@@ -1103,6 +1103,43 @@ def safe_create_task(coro, logger: Logger, label: str = "background task"):
|
||||
return asyncio.create_task(wrapper())
|
||||
|
||||
|
||||
def safe_create_file_processing_task(coro, file_metadata, server, actor, logger: Logger, label: str = "file processing task"):
|
||||
"""
|
||||
Create a task for file processing that updates file status on failure.
|
||||
|
||||
This is a specialized version of safe_create_task that ensures file
|
||||
status is properly updated to ERROR with a meaningful message if the
|
||||
task fails.
|
||||
|
||||
Args:
|
||||
coro: The coroutine to execute
|
||||
file_metadata: FileMetadata object being processed
|
||||
server: Server instance with file_manager
|
||||
actor: User performing the operation
|
||||
logger: Logger instance for error logging
|
||||
label: Description of the task for logging
|
||||
"""
|
||||
from letta.schemas.enums import FileProcessingStatus
|
||||
|
||||
async def wrapper():
|
||||
try:
|
||||
await coro
|
||||
except Exception as e:
|
||||
logger.exception(f"{label} failed for file {file_metadata.file_name} with {type(e).__name__}: {e}")
|
||||
# update file status to ERROR with a meaningful message
|
||||
try:
|
||||
await server.file_manager.update_file_status(
|
||||
file_id=file_metadata.id,
|
||||
actor=actor,
|
||||
processing_status=FileProcessingStatus.ERROR,
|
||||
error_message=f"Processing failed: {str(e)}" if str(e) else f"Processing failed: {type(e).__name__}",
|
||||
)
|
||||
except Exception as update_error:
|
||||
logger.error(f"Failed to update file status to ERROR for {file_metadata.id}: {update_error}")
|
||||
|
||||
return asyncio.create_task(wrapper())
|
||||
|
||||
|
||||
class CancellationSignal:
|
||||
"""
|
||||
A signal that can be checked for cancellation during streaming operations.
|
||||
|
||||
2763
poetry.lock
generated
2763
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
17
project.json
17
project.json
@@ -7,7 +7,7 @@
|
||||
"lock": {
|
||||
"executor": "@nxlv/python:run-commands",
|
||||
"options": {
|
||||
"command": "poetry lock --no-update",
|
||||
"command": "uv lock --no-update",
|
||||
"cwd": "apps/core"
|
||||
}
|
||||
},
|
||||
@@ -26,10 +26,7 @@
|
||||
"dev": {
|
||||
"executor": "@nxlv/python:run-commands",
|
||||
"options": {
|
||||
"commands": [
|
||||
"./otel/start-otel-collector.sh",
|
||||
"poetry run letta server"
|
||||
],
|
||||
"commands": ["./otel/start-otel-collector.sh", "uv run letta server"],
|
||||
"parallel": true,
|
||||
"cwd": "apps/core"
|
||||
}
|
||||
@@ -39,7 +36,7 @@
|
||||
"options": {
|
||||
"commands": [
|
||||
"./otel/start-otel-collector.sh",
|
||||
"poetry run letta server --debug --reload"
|
||||
"uv run letta server --debug --reload"
|
||||
],
|
||||
"parallel": true,
|
||||
"cwd": "apps/core"
|
||||
@@ -58,21 +55,21 @@
|
||||
"install": {
|
||||
"executor": "@nxlv/python:run-commands",
|
||||
"options": {
|
||||
"command": "poetry install --all-extras",
|
||||
"command": "uv sync --all-extras",
|
||||
"cwd": "apps/core"
|
||||
}
|
||||
},
|
||||
"lint": {
|
||||
"executor": "@nxlv/python:run-commands",
|
||||
"options": {
|
||||
"command": "poetry run isort --profile black . && poetry run black . && poetry run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .",
|
||||
"command": "uv run isort --profile black . && uv run black . && uv run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .",
|
||||
"cwd": "apps/core"
|
||||
}
|
||||
},
|
||||
"database:migrate": {
|
||||
"executor": "@nxlv/python:run-commands",
|
||||
"options": {
|
||||
"command": "poetry run alembic upgrade head",
|
||||
"command": "uv run alembic upgrade head",
|
||||
"cwd": "apps/core"
|
||||
}
|
||||
},
|
||||
@@ -83,7 +80,7 @@
|
||||
"{workspaceRoot}/coverage/apps/core"
|
||||
],
|
||||
"options": {
|
||||
"command": "poetry run pytest tests/",
|
||||
"command": "uv run pytest tests/",
|
||||
"cwd": "apps/core"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,96 @@
|
||||
[project]
|
||||
name = "letta"
|
||||
version = "0.10.0"
|
||||
description = "Create LLM agents with long-term memory and custom tools"
|
||||
authors = [
|
||||
{name = "Letta Team", email = "contact@letta.com"},
|
||||
]
|
||||
license = {text = "Apache License"}
|
||||
readme = "README.md"
|
||||
requires-python = "<3.14,>=3.11"
|
||||
dependencies = [
|
||||
"typer>=0.15.2",
|
||||
"questionary>=2.0.1",
|
||||
"pytz>=2023.3.post1",
|
||||
"tqdm>=4.66.1",
|
||||
"black[jupyter]>=24.2.0",
|
||||
"setuptools>=70",
|
||||
"prettytable>=3.9.0",
|
||||
"docstring-parser>=0.16,<0.17",
|
||||
"httpx>=0.28.0",
|
||||
"numpy>=2.1.0",
|
||||
"demjson3>=3.0.6",
|
||||
"pyyaml>=6.0.1",
|
||||
"sqlalchemy-json>=0.7.0",
|
||||
"pydantic>=2.10.6",
|
||||
"html2text>=2020.1.16",
|
||||
"sqlalchemy[asyncio]>=2.0.41",
|
||||
"python-box>=7.1.1",
|
||||
"sqlmodel>=0.0.16",
|
||||
"python-multipart>=0.0.19",
|
||||
"sqlalchemy-utils>=0.41.2",
|
||||
"pydantic-settings>=2.2.1",
|
||||
"httpx-sse>=0.4.0",
|
||||
"nltk>=3.8.1",
|
||||
"jinja2>=3.1.5",
|
||||
"composio-core>=0.7.7",
|
||||
"alembic>=1.13.3",
|
||||
"pyhumps>=3.8.0",
|
||||
"pathvalidate>=3.2.1",
|
||||
"sentry-sdk[fastapi]==2.19.1",
|
||||
"rich>=13.9.4",
|
||||
"brotli>=1.1.0",
|
||||
"grpcio>=1.68.1",
|
||||
"grpcio-tools>=1.68.1",
|
||||
"llama-index>=0.12.2",
|
||||
"llama-index-embeddings-openai>=0.3.1",
|
||||
"anthropic>=0.49.0",
|
||||
"letta_client>=0.1.276",
|
||||
"openai>=1.99.9",
|
||||
"opentelemetry-api==1.30.0",
|
||||
"opentelemetry-sdk==1.30.0",
|
||||
"opentelemetry-instrumentation-requests==0.51b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.51b0",
|
||||
"opentelemetry-exporter-otlp==1.30.0",
|
||||
"faker>=36.1.0",
|
||||
"colorama>=0.4.6",
|
||||
"marshmallow-sqlalchemy>=1.4.1",
|
||||
"datamodel-code-generator[http]>=0.25.0",
|
||||
"mcp[cli]>=1.9.4",
|
||||
"firecrawl-py==2.16.5",
|
||||
"apscheduler>=3.11.0",
|
||||
"aiomultiprocess>=0.9.1",
|
||||
"matplotlib>=3.10.1",
|
||||
"tavily-python>=0.7.2",
|
||||
"mistralai>=1.8.1",
|
||||
"structlog>=25.4.0",
|
||||
"certifi>=2025.6.15",
|
||||
"markitdown[docx,pdf,pptx]>=0.1.2",
|
||||
"orjson>=3.11.1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
postgres = ["pgvector>=0.2.3", "pg8000>=1.30.3", "psycopg2-binary>=2.9.10", "psycopg2>=2.9.10", "asyncpg>=0.30.0"]
|
||||
redis = ["redis>=6.2.0"]
|
||||
pinecone = ["pinecone[asyncio]>=7.3.0"]
|
||||
dev = ["pytest>=8.0.0", "pytest-asyncio>=0.24.0", "pexpect>=4.9.0", "black>=24.2.0", "pre-commit>=3.5.0", "pyright>=1.1.347", "pytest-order>=1.2.0", "autoflake>=2.3.0", "isort>=5.13.2", "locust>=2.31.5"]
|
||||
experimental = ["uvloop>=0.21.0; sys_platform != 'win32'", "granian[reload]>=2.3.2", "google-cloud-profiler>=4.1.0"]
|
||||
server = ["websockets>=12.0", "fastapi>=0.115.6", "uvicorn>=0.24.0.post1"]
|
||||
cloud-tool-sandbox = ["e2b-code-interpreter==1.5.2", "modal>=1.1.0"]
|
||||
external-tools = ["docker>=7.1.0", "langchain>=0.3.7", "wikipedia>=1.4.0", "langchain-community>=0.3.7", "firecrawl-py==2.16.5"]
|
||||
tests = ["wikipedia>=1.4.0", "pytest-asyncio>=0.24.0"]
|
||||
sqlite = ["aiosqlite>=0.21.0", "sqlite-vec>=0.1.7a2"]
|
||||
bedrock = ["boto3>=1.36.24", "aioboto3>=14.3.0"]
|
||||
google = ["google-genai>=1.15.0"]
|
||||
desktop = ["pyright>=1.1.347", "fastapi>=0.115.6", "uvicorn>=0.24.0.post1", "docker>=7.1.0", "langchain>=0.3.7", "wikipedia>=1.4.0", "langchain-community>=0.3.7", "locust>=2.31.5", "sqlite-vec>=0.1.7a2", "pgvector>=0.2.3"]
|
||||
all = ["pgvector>=0.2.3", "turbopuffer>=0.5.17", "pg8000>=1.30.3", "psycopg2-binary>=2.9.10", "psycopg2>=2.9.10", "pytest", "pytest-asyncio>=0.24.0", "pexpect>=4.9.0", "black>=24.2.0", "pre-commit>=3.5.0", "pyright>=1.1.347", "pytest-order>=1.2.0", "autoflake>=2.3.0", "isort>=5.13.2", "fastapi>=0.115.6", "uvicorn>=0.24.0.post1", "docker>=7.1.0", "langchain>=0.3.7", "wikipedia>=1.4.0", "langchain-community>=0.3.7", "locust>=2.31.5", "uvloop>=0.21.0; sys_platform != 'win32'", "granian[reload]>=2.3.2", "redis>=6.2.0", "pinecone[asyncio]>=7.3.0", "google-cloud-profiler>=4.1.0"]
|
||||
|
||||
[project.scripts]
|
||||
letta = "letta.main:app"
|
||||
|
||||
[tool.poetry]
|
||||
name = "letta"
|
||||
version = "0.11.4"
|
||||
version = "0.11.5"
|
||||
packages = [
|
||||
{include = "letta"},
|
||||
]
|
||||
@@ -72,7 +162,7 @@ llama-index = "^0.12.2"
|
||||
llama-index-embeddings-openai = "^0.3.1"
|
||||
e2b-code-interpreter = {version = "^1.0.3", optional = true}
|
||||
anthropic = "^0.49.0"
|
||||
letta_client = "^0.1.220"
|
||||
letta_client = "^0.1.277"
|
||||
openai = "^1.99.9"
|
||||
opentelemetry-api = "1.30.0"
|
||||
opentelemetry-sdk = "1.30.0"
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
{
|
||||
"context_window": 8192,
|
||||
"model": "llama-3.1-70b-versatile",
|
||||
"model_endpoint_type": "groq",
|
||||
"model_endpoint": "https://api.groq.com/openai/v1",
|
||||
"model_wrapper": null,
|
||||
"put_inner_thoughts_in_kwargs": true
|
||||
"context_window": 8192,
|
||||
"model": "qwen/qwen3-32b",
|
||||
"model_endpoint_type": "groq",
|
||||
"model_endpoint": "https://api.groq.com/openai/v1",
|
||||
"model_wrapper": null,
|
||||
"put_inner_thoughts_in_kwargs": true
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts
|
||||
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import tool_settings
|
||||
@@ -14,6 +16,36 @@ def pytest_configure(config):
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def disable_db_pooling_for_tests():
|
||||
"""Disable database connection pooling for the entire test session."""
|
||||
os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"] = "true"
|
||||
yield
|
||||
if "LETTA_DISABLE_SQLALCHEMY_POOLING" in os.environ:
|
||||
del os.environ["LETTA_DISABLE_SQLALCHEMY_POOLING"]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def cleanup_db_connections():
|
||||
"""Cleanup database connections after each test."""
|
||||
yield
|
||||
try:
|
||||
if hasattr(db_registry, "_async_engines"):
|
||||
for engine in db_registry._async_engines.values():
|
||||
if engine:
|
||||
try:
|
||||
await engine.dispose()
|
||||
except Exception:
|
||||
# Suppress common teardown errors that don't affect test validity
|
||||
pass
|
||||
db_registry._initialized["async"] = False
|
||||
db_registry._async_engines.clear()
|
||||
db_registry._async_session_factories.clear()
|
||||
except Exception:
|
||||
# Suppress all cleanup errors to avoid confusing test failures
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_e2b_api_key() -> Generator[None, None, None]:
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
from letta_client import AsyncLetta, Letta
|
||||
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
@@ -254,7 +254,8 @@ def validate_context_window_overview(
|
||||
assert len(overview.functions_definitions) > 0
|
||||
|
||||
|
||||
def upload_test_agentfile_from_disk(server_url: str, filename: str) -> ImportedAgentsResponse:
|
||||
# Changed this from server_url to client since client may be authenticated or not
|
||||
def upload_test_agentfile_from_disk(client: Letta, filename: str) -> ImportedAgentsResponse:
|
||||
"""
|
||||
Upload a given .af file to live FastAPI server.
|
||||
"""
|
||||
@@ -263,18 +264,87 @@ def upload_test_agentfile_from_disk(server_url: str, filename: str) -> ImportedA
|
||||
file_path = os.path.join(path_to_test_agent_files, filename)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"file": (filename, f, "application/json")}
|
||||
return client.agents.import_file(file=f, append_copy_suffix=True, override_existing_tools=False)
|
||||
|
||||
# Send parameters as form data instead of query parameters
|
||||
form_data = {
|
||||
"append_copy_suffix": "true",
|
||||
"override_existing_tools": "false",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{server_url}/v1/agents/import",
|
||||
headers={"user_id": ""},
|
||||
files=files,
|
||||
data=form_data, # Send as form data
|
||||
)
|
||||
return ImportedAgentsResponse(**response.json())
|
||||
async def upload_test_agentfile_from_disk_async(client: AsyncLetta, filename: str) -> ImportedAgentsResponse:
|
||||
"""
|
||||
Upload a given .af file to live FastAPI server.
|
||||
"""
|
||||
path_to_current_file = os.path.dirname(__file__)
|
||||
path_to_test_agent_files = path_to_current_file.removesuffix("/helpers") + "/test_agent_files"
|
||||
file_path = os.path.join(path_to_test_agent_files, filename)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
uploaded = await client.agents.import_file(file=f, append_copy_suffix=True, override_existing_tools=False)
|
||||
return uploaded
|
||||
|
||||
|
||||
def upload_file_and_wait(
|
||||
client: Letta,
|
||||
source_id: str,
|
||||
file_path: str,
|
||||
name: Optional[str] = None,
|
||||
max_wait: int = 60,
|
||||
duplicate_handling: Optional[str] = None,
|
||||
):
|
||||
"""Helper function to upload a file and wait for processing to complete"""
|
||||
with open(file_path, "rb") as f:
|
||||
if duplicate_handling:
|
||||
file_metadata = client.sources.files.upload(source_id=source_id, file=f, duplicate_handling=duplicate_handling, name=name)
|
||||
else:
|
||||
file_metadata = client.sources.files.upload(source_id=source_id, file=f, name=name)
|
||||
|
||||
# wait for the file to be processed
|
||||
start_time = time.time()
|
||||
while file_metadata.processing_status != "completed" and file_metadata.processing_status != "error":
|
||||
if time.time() - start_time > max_wait:
|
||||
raise TimeoutError(f"File processing timed out after {max_wait} seconds")
|
||||
time.sleep(1)
|
||||
file_metadata = client.sources.get_file_metadata(source_id=source_id, file_id=file_metadata.id)
|
||||
print("Waiting for file processing to complete...", file_metadata.processing_status)
|
||||
|
||||
if file_metadata.processing_status == "error":
|
||||
raise RuntimeError(f"File processing failed: {file_metadata.error_message}")
|
||||
|
||||
return file_metadata
|
||||
|
||||
|
||||
def upload_file_and_wait_list_files(
|
||||
client: Letta,
|
||||
source_id: str,
|
||||
file_path: str,
|
||||
name: Optional[str] = None,
|
||||
max_wait: int = 60,
|
||||
duplicate_handling: Optional[str] = None,
|
||||
):
|
||||
"""Helper function to upload a file and wait for processing using list_files instead of get_file_metadata"""
|
||||
with open(file_path, "rb") as f:
|
||||
if duplicate_handling:
|
||||
file_metadata = client.sources.files.upload(source_id=source_id, file=f, duplicate_handling=duplicate_handling, name=name)
|
||||
else:
|
||||
file_metadata = client.sources.files.upload(source_id=source_id, file=f, name=name)
|
||||
|
||||
# wait for the file to be processed using list_files
|
||||
start_time = time.time()
|
||||
while file_metadata.processing_status != "completed" and file_metadata.processing_status != "error":
|
||||
if time.time() - start_time > max_wait:
|
||||
raise TimeoutError(f"File processing timed out after {max_wait} seconds")
|
||||
time.sleep(1)
|
||||
|
||||
# use list_files to get all files and find our specific file
|
||||
files = client.sources.files.list(source_id=source_id, limit=100)
|
||||
# find the file with matching id
|
||||
for file in files:
|
||||
if file.id == file_metadata.id:
|
||||
file_metadata = file
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(f"File {file_metadata.id} not found in source files list")
|
||||
|
||||
print("Waiting for file processing to complete (via list_files)...", file_metadata.processing_status)
|
||||
|
||||
if file_metadata.processing_status == "error":
|
||||
raise RuntimeError(f"File processing failed: {file_metadata.error_message}")
|
||||
|
||||
return file_metadata
|
||||
|
||||
@@ -26,13 +26,6 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
|
||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
|
||||
@@ -321,7 +321,7 @@ def tool_with_pip_requirements(test_user):
|
||||
import requests
|
||||
|
||||
# Simple usage to verify packages work
|
||||
response = requests.get("https://httpbin.org/json", timeout=5)
|
||||
response = requests.get("https://httpbin.org/json", timeout=30)
|
||||
arr = np.array([1, 2, 3])
|
||||
return f"Success! Status: {response.status_code}, Array sum: {np.sum(arr)}"
|
||||
except ImportError as e:
|
||||
|
||||
@@ -70,7 +70,7 @@ def client(server_url: str) -> Letta:
|
||||
yield client_instance
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture(scope="function")
|
||||
def agent_state(client: Letta) -> AgentState:
|
||||
"""
|
||||
Creates and returns an agent state for testing with a pre-configured agent.
|
||||
@@ -333,7 +333,7 @@ def test_web_search(
|
||||
], f"Invalid api_key_source: {response_json['api_key_source']}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio(scope="function")
|
||||
async def test_web_search_uses_agent_env_var_model():
|
||||
"""Test that web search uses the model specified in agent tool exec env vars."""
|
||||
|
||||
|
||||
@@ -32,20 +32,6 @@ from letta.services.tool_sandbox.modal_sandbox_v2 import AsyncToolSandboxModalV2
|
||||
from letta.services.tool_sandbox.modal_version_manager import ModalVersionManager, get_version_manager
|
||||
from letta.services.user_manager import UserManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for the test session."""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
# Cleanup tasks before closing loop
|
||||
pending = asyncio.all_tasks(loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
loop.close()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SHARED FIXTURES
|
||||
# ============================================================================
|
||||
@@ -90,12 +76,12 @@ def basic_tool(test_user):
|
||||
source_code="""
|
||||
def calculate(operation: str, a: float, b: float) -> float:
|
||||
'''Perform a calculation on two numbers.
|
||||
|
||||
|
||||
Args:
|
||||
operation: The operation to perform (add, subtract, multiply, divide)
|
||||
a: The first number
|
||||
b: The second number
|
||||
|
||||
|
||||
Returns:
|
||||
float: The result of the calculation
|
||||
'''
|
||||
@@ -145,11 +131,11 @@ import asyncio
|
||||
|
||||
async def fetch_data(url: str, delay: float = 0.1) -> Dict:
|
||||
'''Simulate fetching data from a URL.
|
||||
|
||||
|
||||
Args:
|
||||
url: The URL to fetch data from
|
||||
delay: The delay in seconds before returning
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: A dictionary containing the fetched data
|
||||
'''
|
||||
@@ -194,17 +180,17 @@ import hashlib
|
||||
|
||||
def process_json(data: str) -> Dict:
|
||||
'''Process JSON data and return metadata.
|
||||
|
||||
|
||||
Args:
|
||||
data: The JSON string to process
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: Metadata about the JSON data
|
||||
'''
|
||||
try:
|
||||
parsed = json.loads(data)
|
||||
data_hash = hashlib.md5(data.encode()).hexdigest()
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"keys": list(parsed.keys()) if isinstance(parsed, dict) else None,
|
||||
|
||||
@@ -10,7 +10,7 @@ from dotenv import load_dotenv
|
||||
from letta_client import AsyncLetta, MessageCreate, ReasoningMessage, ToolCallMessage
|
||||
from letta_client.core import RequestOptions
|
||||
|
||||
from tests.helpers.utils import upload_test_agentfile_from_disk
|
||||
from tests.helpers.utils import upload_test_agentfile_from_disk_async
|
||||
|
||||
REASONING_THROTTLE_MS = 100
|
||||
TEST_USER_MESSAGE = "What products or services does 11x AI sell?"
|
||||
@@ -66,7 +66,7 @@ async def test_pinecone_tool(client: AsyncLetta, server_url: str) -> None:
|
||||
"""
|
||||
Test the Pinecone tool integration with the Letta client.
|
||||
"""
|
||||
response = upload_test_agentfile_from_disk(server_url, "knowledge-base.af")
|
||||
response = await upload_test_agentfile_from_disk_async(client, "knowledge-base.af")
|
||||
|
||||
agent_id = response.agent_ids[0]
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [
|
||||
]
|
||||
|
||||
# configs for models that are to dumb to do much other than messaging
|
||||
limited_configs = ["ollama.json", "together-qwen-2.5-72b-instruct.json", "vllm.json", "lmstudio.json"]
|
||||
limited_configs = ["ollama.json", "together-qwen-2.5-72b-instruct.json", "vllm.json", "lmstudio.json", "groq.json"]
|
||||
|
||||
all_configs = [
|
||||
"openai-gpt-4o-mini.json",
|
||||
@@ -161,6 +161,7 @@ all_configs = [
|
||||
"gemini-2.5-pro-vertex.json",
|
||||
"ollama.json",
|
||||
"together-qwen-2.5-72b-instruct.json",
|
||||
"groq.json",
|
||||
]
|
||||
|
||||
reasoning_configs = [
|
||||
@@ -398,7 +399,7 @@ def validate_openai_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
|
||||
assert content is None
|
||||
|
||||
|
||||
def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
|
||||
def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]], reasoning_enabled: bool) -> None:
|
||||
"""
|
||||
Validate that Anthropic/Claude format assistant messages with tool_use have no <thinking> tags.
|
||||
Args:
|
||||
@@ -432,10 +433,12 @@ def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]]) -> None:
|
||||
# Verify that the message only contains tool_use items
|
||||
tool_use_items = [item for item in content_list if item.get("type") == "tool_use"]
|
||||
assert len(tool_use_items) > 0, "Assistant message should have at least one tool_use item"
|
||||
assert len(content_list) == len(tool_use_items), (
|
||||
f"Assistant message should ONLY contain tool_use items when reasoning is disabled. "
|
||||
f"Found {len(content_list)} total items but only {len(tool_use_items)} are tool_use items."
|
||||
)
|
||||
|
||||
if not reasoning_enabled:
|
||||
assert len(content_list) == len(tool_use_items), (
|
||||
f"Assistant message should ONLY contain tool_use items when reasoning is disabled. "
|
||||
f"Found {len(content_list)} total items but only {len(tool_use_items)} are tool_use items."
|
||||
)
|
||||
|
||||
|
||||
def validate_google_format_scrubbing(contents: List[Dict[str, Any]]) -> None:
|
||||
@@ -1131,6 +1134,146 @@ def test_token_streaming_agent_loop_error(
|
||||
assert len(messages_from_db) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
def test_background_token_streaming_greeting_with_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
# Use longer message for Anthropic models to test if they stream in chunks
|
||||
if llm_config.model_endpoint_type == "anthropic":
|
||||
messages_to_send = USER_MESSAGE_FORCE_LONG_REPLY
|
||||
else:
|
||||
messages_to_send = USER_MESSAGE_FORCE_REPLY
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=messages_to_send,
|
||||
stream_tokens=True,
|
||||
background=True,
|
||||
)
|
||||
messages = accumulate_chunks(
|
||||
list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
|
||||
)
|
||||
assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
|
||||
|
||||
run_id = messages[0].run_id
|
||||
assert run_id is not None
|
||||
|
||||
response = client.runs.stream(run_id=run_id, starting_after=0)
|
||||
messages = accumulate_chunks(
|
||||
list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
|
||||
)
|
||||
assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config)
|
||||
|
||||
last_message_cursor = messages[-3].seq_id - 1
|
||||
response = client.runs.stream(run_id=run_id, starting_after=last_message_cursor)
|
||||
messages = accumulate_chunks(
|
||||
list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
|
||||
)
|
||||
assert len(messages) == 3
|
||||
assert messages[0].message_type == "assistant_message" and messages[0].seq_id == last_message_cursor + 1
|
||||
assert messages[1].message_type == "stop_reason"
|
||||
assert messages[2].message_type == "usage_statistics"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
def test_background_token_streaming_greeting_without_assistant_message(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
# Use longer message for Anthropic models to force chunking
|
||||
if llm_config.model_endpoint_type == "anthropic":
|
||||
messages_to_send = USER_MESSAGE_FORCE_LONG_REPLY
|
||||
else:
|
||||
messages_to_send = USER_MESSAGE_FORCE_REPLY
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=messages_to_send,
|
||||
use_assistant_message=False,
|
||||
stream_tokens=True,
|
||||
background=True,
|
||||
)
|
||||
messages = accumulate_chunks(
|
||||
list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
|
||||
)
|
||||
assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id, use_assistant_message=False)
|
||||
assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
def test_background_token_streaming_tool_call(
|
||||
disable_e2b_api_key: Any,
|
||||
client: Letta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Tests sending a streaming message with a synchronous client.
|
||||
Checks that each chunk in the stream has the correct message types.
|
||||
"""
|
||||
# get the config filename
|
||||
config_filename = None
|
||||
for filename in filenames:
|
||||
config = get_llm_config(filename)
|
||||
if config.model_dump() == llm_config.model_dump():
|
||||
config_filename = filename
|
||||
break
|
||||
|
||||
# skip if this is a limited model
|
||||
if not config_filename or config_filename in limited_configs:
|
||||
pytest.skip(f"Skipping test for limited model {llm_config.model}")
|
||||
|
||||
last_message = client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
# Use longer message for Anthropic models to force chunking
|
||||
if llm_config.model_endpoint_type == "anthropic":
|
||||
messages_to_send = USER_MESSAGE_ROLL_DICE_LONG
|
||||
else:
|
||||
messages_to_send = USER_MESSAGE_ROLL_DICE
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=messages_to_send,
|
||||
stream_tokens=True,
|
||||
background=True,
|
||||
)
|
||||
messages = accumulate_chunks(
|
||||
list(response), verify_token_streaming=(llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"])
|
||||
)
|
||||
assert_tool_call_response(messages, streaming=True, llm_config=llm_config)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
|
||||
|
||||
|
||||
def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
|
||||
start = time.time()
|
||||
while True:
|
||||
@@ -1842,7 +1985,7 @@ def test_inner_thoughts_toggle_interleaved(
|
||||
validate_openai_format_scrubbing(messages)
|
||||
elif llm_config.model_endpoint_type == "anthropic":
|
||||
messages = response["messages"]
|
||||
validate_anthropic_format_scrubbing(messages)
|
||||
validate_anthropic_format_scrubbing(messages, llm_config.enable_reasoner)
|
||||
elif llm_config.model_endpoint_type in ["google_ai", "google_vertex"]:
|
||||
# Google uses 'contents' instead of 'messages'
|
||||
contents = response.get("contents", response.get("messages", []))
|
||||
|
||||
@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from letta.functions.mcp_client.types import MCPTool, MCPToolHealth
|
||||
from letta.functions.schema_generator import generate_tool_schema_for_mcp
|
||||
from letta.functions.schema_validator import SchemaHealth, validate_complete_json_schema
|
||||
|
||||
|
||||
@@ -113,10 +114,8 @@ def test_empty_object_in_required_marked_invalid():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_mcp_tool_rejects_non_strict_schemas():
|
||||
"""Test that adding MCP tools with non-strict schemas is rejected."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
async def test_add_mcp_tool_accepts_non_strict_schemas():
|
||||
"""Test that adding MCP tools with non-strict schemas is allowed."""
|
||||
from letta.server.rest_api.routers.v1.tools import add_mcp_tool
|
||||
from letta.settings import tool_settings
|
||||
|
||||
@@ -138,15 +137,19 @@ async def test_add_mcp_tool_rejects_non_strict_schemas():
|
||||
mock_server = AsyncMock()
|
||||
mock_server.get_tools_from_mcp_server = AsyncMock(return_value=[non_strict_tool])
|
||||
mock_server.user_manager.get_user_or_default = MagicMock()
|
||||
mock_server.tool_manager.create_mcp_tool_async = AsyncMock(return_value=non_strict_tool)
|
||||
mock_get_server.return_value = mock_server
|
||||
|
||||
# Should raise HTTPException for non-strict schema
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await add_mcp_tool(mcp_server_name="test_server", mcp_tool_name="test_tool", server=mock_server, actor_id=None)
|
||||
# Should accept non-strict schema without raising an exception
|
||||
result = await add_mcp_tool(mcp_server_name="test_server", mcp_tool_name="test_tool", server=mock_server, actor_id=None)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "non-strict schema" in exc_info.value.detail["message"].lower()
|
||||
assert exc_info.value.detail["health_status"] == SchemaHealth.NON_STRICT_ONLY.value
|
||||
# Verify the tool was added successfully
|
||||
assert result is not None
|
||||
|
||||
# Verify create_mcp_tool_async was called with the right parameters
|
||||
mock_server.tool_manager.create_mcp_tool_async.assert_called_once()
|
||||
call_args = mock_server.tool_manager.create_mcp_tool_async.call_args
|
||||
assert call_args.kwargs["mcp_server_name"] == "test_server"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -183,3 +186,271 @@ async def test_add_mcp_tool_rejects_invalid_schemas():
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "invalid schema" in exc_info.value.detail["message"].lower()
|
||||
assert exc_info.value.detail["health_status"] == SchemaHealth.INVALID.value
|
||||
|
||||
|
||||
def test_mcp_schema_healing_for_optional_fields():
|
||||
"""Test that optional fields in MCP schemas are healed only in strict mode."""
|
||||
# Create an MCP tool with optional field 'b'
|
||||
mcp_tool = MCPTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "integer", "description": "Required field"},
|
||||
"b": {"type": "integer", "description": "Optional field"},
|
||||
},
|
||||
"required": ["a"], # Only 'a' is required
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Generate schema without strict mode - should NOT heal optional fields
|
||||
non_strict_schema = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=False)
|
||||
assert "a" in non_strict_schema["parameters"]["required"]
|
||||
assert "b" not in non_strict_schema["parameters"]["required"] # Should remain optional
|
||||
assert non_strict_schema["parameters"]["properties"]["b"]["type"] == "integer" # No null added
|
||||
|
||||
# Validate non-strict schema - should still be STRICT_COMPLIANT because validator is relaxed
|
||||
status, _ = validate_complete_json_schema(non_strict_schema["parameters"])
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
|
||||
# Generate schema with strict mode - should heal optional fields
|
||||
strict_schema = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=True)
|
||||
assert strict_schema["strict"] is True
|
||||
assert "a" in strict_schema["parameters"]["required"]
|
||||
assert "b" in strict_schema["parameters"]["required"] # Now required
|
||||
assert set(strict_schema["parameters"]["properties"]["b"]["type"]) == {"integer", "null"} # Now accepts null
|
||||
|
||||
# Validate strict schema
|
||||
status, _ = validate_complete_json_schema(strict_schema["parameters"])
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT # Should pass strict mode
|
||||
|
||||
|
||||
def test_mcp_schema_healing_with_anyof():
|
||||
"""Test schema healing for fields with anyOf that include optional types."""
|
||||
mcp_tool = MCPTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "string", "description": "Required field"},
|
||||
"b": {
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
"description": "Optional field with anyOf",
|
||||
},
|
||||
},
|
||||
"required": ["a"], # Only 'a' is required
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Generate strict schema
|
||||
strict_schema = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=True)
|
||||
assert strict_schema["strict"] is True
|
||||
assert "a" in strict_schema["parameters"]["required"]
|
||||
assert "b" in strict_schema["parameters"]["required"] # Now required
|
||||
# Type should be flattened array with deduplication
|
||||
assert set(strict_schema["parameters"]["properties"]["b"]["type"]) == {"integer", "null"}
|
||||
|
||||
# Validate strict schema
|
||||
status, _ = validate_complete_json_schema(strict_schema["parameters"])
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
|
||||
|
||||
def test_mcp_schema_type_deduplication():
|
||||
"""Test that duplicate types are deduplicated in schema generation."""
|
||||
mcp_tool = MCPTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "string"}, # Duplicate
|
||||
{"type": "null"},
|
||||
],
|
||||
"description": "Field with duplicate types",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Generate strict schema
|
||||
strict_schema = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=True)
|
||||
|
||||
# Check that duplicates were removed
|
||||
field_types = strict_schema["parameters"]["properties"]["field"]["type"]
|
||||
assert len(field_types) == len(set(field_types)) # No duplicates
|
||||
assert set(field_types) == {"string", "null"}
|
||||
|
||||
|
||||
def test_mcp_schema_healing_preserves_existing_null():
|
||||
"""Test that schema healing doesn't add duplicate null when it already exists."""
|
||||
mcp_tool = MCPTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"type": ["string", "null"], # Already has null
|
||||
"description": "Field that already accepts null",
|
||||
},
|
||||
},
|
||||
"required": [], # Optional
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Generate strict schema
|
||||
strict_schema = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=True)
|
||||
|
||||
# Check that null wasn't duplicated
|
||||
field_types = strict_schema["parameters"]["properties"]["field"]["type"]
|
||||
null_count = field_types.count("null")
|
||||
assert null_count == 1 # Should only have one null
|
||||
|
||||
|
||||
def test_mcp_schema_healing_all_fields_already_required():
|
||||
"""Test that schema healing works correctly when all fields are already required."""
|
||||
mcp_tool = MCPTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "string", "description": "Field A"},
|
||||
"b": {"type": "integer", "description": "Field B"},
|
||||
},
|
||||
"required": ["a", "b"], # All fields already required
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Generate strict schema
|
||||
strict_schema = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=True)
|
||||
|
||||
# Check that fields remain as-is
|
||||
assert set(strict_schema["parameters"]["required"]) == {"a", "b"}
|
||||
assert strict_schema["parameters"]["properties"]["a"]["type"] == "string"
|
||||
assert strict_schema["parameters"]["properties"]["b"]["type"] == "integer"
|
||||
|
||||
# Should be strict compliant
|
||||
status, _ = validate_complete_json_schema(strict_schema["parameters"])
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
|
||||
|
||||
def test_mcp_schema_with_uuid_format():
|
||||
"""Test handling of UUID format in anyOf schemas (root cause of duplicate string types)."""
|
||||
mcp_tool = MCPTool(
|
||||
name="test_tool",
|
||||
description="A test tool with UUID formatted field",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"session_id": {
|
||||
"anyOf": [{"type": "string"}, {"format": "uuid", "type": "string"}, {"type": "null"}],
|
||||
"description": "Session ID that can be a string, UUID, or null",
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Generate strict schema
|
||||
strict_schema = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=True)
|
||||
|
||||
# Check that string type is not duplicated
|
||||
session_props = strict_schema["parameters"]["properties"]["session_id"]
|
||||
assert set(session_props["type"]) == {"string", "null"} # No duplicate strings
|
||||
# Format should NOT be preserved because field is optional (has null type)
|
||||
assert "format" not in session_props
|
||||
|
||||
# Should be in required array (healed)
|
||||
assert "session_id" in strict_schema["parameters"]["required"]
|
||||
|
||||
# Should be strict compliant
|
||||
status, _ = validate_complete_json_schema(strict_schema["parameters"])
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
|
||||
|
||||
def test_mcp_schema_healing_only_in_strict_mode():
|
||||
"""Test that schema healing only happens in strict mode."""
|
||||
mcp_tool = MCPTool(
|
||||
name="test_tool",
|
||||
description="Test that healing only happens in strict mode",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"required_field": {"type": "string", "description": "Already required"},
|
||||
"optional_field1": {"type": "integer", "description": "Optional 1"},
|
||||
"optional_field2": {"type": "boolean", "description": "Optional 2"},
|
||||
},
|
||||
"required": ["required_field"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Test with strict=False - no healing
|
||||
non_strict = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=False)
|
||||
assert "strict" not in non_strict # strict flag not set
|
||||
assert non_strict["parameters"]["required"] == ["required_field"] # Only originally required field
|
||||
assert non_strict["parameters"]["properties"]["required_field"]["type"] == "string"
|
||||
assert non_strict["parameters"]["properties"]["optional_field1"]["type"] == "integer" # No null
|
||||
assert non_strict["parameters"]["properties"]["optional_field2"]["type"] == "boolean" # No null
|
||||
|
||||
# Test with strict=True - healing happens
|
||||
strict = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=True)
|
||||
assert strict["strict"] is True # strict flag is set
|
||||
assert set(strict["parameters"]["required"]) == {"required_field", "optional_field1", "optional_field2"}
|
||||
assert strict["parameters"]["properties"]["required_field"]["type"] == "string"
|
||||
assert set(strict["parameters"]["properties"]["optional_field1"]["type"]) == {"integer", "null"}
|
||||
assert set(strict["parameters"]["properties"]["optional_field2"]["type"]) == {"boolean", "null"}
|
||||
|
||||
# Both should be strict compliant (validator is relaxed)
|
||||
status1, _ = validate_complete_json_schema(non_strict["parameters"])
|
||||
status2, _ = validate_complete_json_schema(strict["parameters"])
|
||||
assert status1 == SchemaHealth.STRICT_COMPLIANT
|
||||
assert status2 == SchemaHealth.STRICT_COMPLIANT
|
||||
|
||||
|
||||
def test_mcp_schema_with_uuid_format_required_field():
|
||||
"""Test that UUID format is preserved for required fields that don't have null type."""
|
||||
mcp_tool = MCPTool(
|
||||
name="test_tool",
|
||||
description="A test tool with required UUID formatted field",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"session_id": {
|
||||
"anyOf": [{"type": "string"}, {"format": "uuid", "type": "string"}],
|
||||
"description": "Session ID that must be a string with UUID format",
|
||||
},
|
||||
},
|
||||
"required": ["session_id"], # Required field
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Generate strict schema
|
||||
strict_schema = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=True)
|
||||
|
||||
# Check that string type is not duplicated and format IS preserved
|
||||
session_props = strict_schema["parameters"]["properties"]["session_id"]
|
||||
assert session_props["type"] == ["string"] # No null, no duplicates
|
||||
assert "format" in session_props
|
||||
assert session_props["format"] == "uuid" # Format should be preserved for non-optional field
|
||||
|
||||
# Should be in required array
|
||||
assert "session_id" in strict_schema["parameters"]["required"]
|
||||
|
||||
# Should be strict compliant
|
||||
status, _ = validate_complete_json_schema(strict_schema["parameters"])
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
|
||||
@@ -22,7 +22,7 @@ class TestSchemaValidator:
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
"required": ["name", "age"],
|
||||
"required": ["name", "age", "address"], # All properties must be required for strict mode
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
@@ -235,22 +235,22 @@ class TestSchemaValidator:
|
||||
assert status == SchemaHealth.INVALID
|
||||
assert any("dict" in reason for reason in reasons)
|
||||
|
||||
def test_schema_with_defaults_strict_compliant(self):
|
||||
"""Test that root-level schemas without required field are STRICT_COMPLIANT."""
|
||||
def test_schema_with_defaults_non_strict(self):
|
||||
"""Test that root-level schemas without required field are STRICT_COMPLIANT (validator is relaxed)."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "optional": {"type": "string"}},
|
||||
# Missing "required" field at root level is OK
|
||||
# Missing "required" field at root level - validator now accepts this
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
# After fix, root level without required should be STRICT_COMPLIANT
|
||||
# Validator is relaxed - schemas with optional fields are now STRICT_COMPLIANT
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
def test_composio_schema_with_optional_root_properties_strict_compliant(self):
|
||||
"""Test that Composio-like schemas with optional root properties are STRICT_COMPLIANT."""
|
||||
def test_composio_schema_with_optional_root_properties_non_strict(self):
|
||||
"""Test that Composio-like schemas with optional root properties are STRICT_COMPLIANT (validator is relaxed)."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -267,25 +267,25 @@ class TestSchemaValidator:
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
def test_root_level_without_required_strict_compliant(self):
|
||||
"""Test that root-level objects without 'required' field are STRICT_COMPLIANT."""
|
||||
def test_root_level_without_required_non_strict(self):
|
||||
"""Test that root-level objects without 'required' field are STRICT_COMPLIANT (validator is relaxed)."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
},
|
||||
# No "required" field at root level
|
||||
# No "required" field at root level - validator now accepts this
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
# Root level without required should be STRICT_COMPLIANT
|
||||
# Validator is relaxed - accepts schemas without required field
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
def test_nested_object_without_required_non_strict(self):
|
||||
"""Test that nested objects without 'required' remain NON_STRICT_ONLY."""
|
||||
"""Test that nested objects without 'required' are STRICT_COMPLIANT (validator is relaxed)."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -309,6 +309,37 @@ class TestSchemaValidator:
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
assert status == SchemaHealth.NON_STRICT_ONLY
|
||||
# Should have warning about nested preferences object missing 'required'
|
||||
assert any("required" in reason and "preferences" in reason for reason in reasons)
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
def test_user_example_schema_non_strict(self):
|
||||
"""Test the user's example schema with optional properties - now STRICT_COMPLIANT (validator is relaxed)."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None, "title": "B"},
|
||||
},
|
||||
"required": ["a"], # Only 'a' is required, 'b' is not
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
def test_all_properties_required_strict_compliant(self):
|
||||
"""Test that schemas with all properties required are STRICT_COMPLIANT."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None, "title": "B"},
|
||||
},
|
||||
"required": ["a", "b"], # All properties are required
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
@@ -36,14 +35,6 @@ from tests.utils import create_tool_from_func
|
||||
# ------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
"""Use a single asyncio loop for the entire test session."""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
def _clear_tables():
|
||||
from letta.server.db import db_context
|
||||
|
||||
|
||||
52
tests/test_embeddings.py
Normal file
52
tests/test_embeddings.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
included_files = [
|
||||
# "ollama.json",
|
||||
"letta-hosted.json",
|
||||
"openai_embed.json",
|
||||
]
|
||||
config_dir = "tests/configs/embedding_model_configs"
|
||||
config_files = glob.glob(os.path.join(config_dir, "*.json"))
|
||||
embedding_configs = [
|
||||
EmbeddingConfig(**json.load(open(config_file))) for config_file in config_files if config_file.split("/")[-1] in included_files
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def default_organization(server: SyncServer):
|
||||
"""Fixture to create and return the default organization."""
|
||||
org = server.organization_manager.create_default_organization()
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_user(server: SyncServer, default_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
user = server.user_manager.create_default_user(org_id=default_organization.id)
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"embedding_config",
|
||||
embedding_configs,
|
||||
ids=[c.embedding_model for c in embedding_configs],
|
||||
)
|
||||
async def test_embeddings(embedding_config: EmbeddingConfig, default_user):
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=embedding_config.embedding_endpoint_type,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
test_input = "This is a test input."
|
||||
embeddings = await embedding_client.request_embeddings([test_input], embedding_config)
|
||||
assert len(embeddings) == 1
|
||||
assert len(embeddings[0]) == embedding_config.embedding_dim
|
||||
@@ -49,14 +49,6 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"]
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
"""Use a single asyncio loop for the entire test session."""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def weather_tool(server):
|
||||
def get_weather(location: str) -> str:
|
||||
|
||||
@@ -28,6 +28,10 @@ def server_url() -> str:
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
api_url = os.getenv("LETTA_API_URL")
|
||||
if api_url:
|
||||
return api_url
|
||||
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
@@ -56,12 +60,18 @@ def client(server_url: str) -> Letta:
|
||||
"""
|
||||
Creates and returns a synchronous Letta REST client for testing.
|
||||
"""
|
||||
client_instance = Letta(base_url=server_url)
|
||||
api_url = os.getenv("LETTA_API_URL")
|
||||
api_key = os.getenv("LETTA_API_KEY")
|
||||
|
||||
if api_url and not api_key:
|
||||
raise ValueError("LETTA_API_KEY is required when passing LETTA_API_URL")
|
||||
|
||||
client_instance = Letta(token=api_key, base_url=api_url if api_url else server_url)
|
||||
return client_instance
|
||||
|
||||
|
||||
async def test_deep_research_agent(client, server_url, disable_e2b_api_key):
|
||||
imported_af = upload_test_agentfile_from_disk(server_url, "deep-thought.af")
|
||||
async def test_deep_research_agent(client: Letta, server_url, disable_e2b_api_key):
|
||||
imported_af = upload_test_agentfile_from_disk(client, "deep-thought.af")
|
||||
|
||||
agent_id = imported_af.agent_ids[0]
|
||||
|
||||
@@ -69,6 +79,7 @@ async def test_deep_research_agent(client, server_url, disable_e2b_api_key):
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_id,
|
||||
stream_tokens=True,
|
||||
include_pings=True,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
@@ -90,8 +101,8 @@ async def test_deep_research_agent(client, server_url, disable_e2b_api_key):
|
||||
client.agents.delete(agent_id=agent_id)
|
||||
|
||||
|
||||
async def test_11x_agent(client, server_url, disable_e2b_api_key):
|
||||
imported_af = upload_test_agentfile_from_disk(server_url, "mock_alice.af")
|
||||
async def test_11x_agent(client: Letta, server_url, disable_e2b_api_key):
|
||||
imported_af = upload_test_agentfile_from_disk(client, "mock_alice.af")
|
||||
|
||||
agent_id = imported_af.agent_ids[0]
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,7 +6,7 @@ from letta.settings import settings
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_experimental_decorator(event_loop):
|
||||
async def test_default_experimental_decorator():
|
||||
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
|
||||
|
||||
@experimental("test_just_pass", fallback_function=lambda: False, kwarg1=3)
|
||||
@@ -18,7 +18,7 @@ async def test_default_experimental_decorator(event_loop):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overwrite_arg_success(event_loop):
|
||||
async def test_overwrite_arg_success():
|
||||
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
|
||||
|
||||
@experimental("test_override_kwarg", fallback_function=lambda *args, **kwargs: False, bool_val=True)
|
||||
@@ -31,7 +31,7 @@ async def test_overwrite_arg_success(event_loop):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_overwrite_arg_fail(event_loop):
|
||||
async def test_overwrite_arg_fail():
|
||||
# Should fallback to lambda
|
||||
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
|
||||
|
||||
@@ -61,7 +61,7 @@ async def test_overwrite_arg_fail(event_loop):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_flag(event_loop):
|
||||
async def test_redis_flag():
|
||||
settings.plugin_register = "experimental_check=tests.helpers.plugins_helper:is_experimental_okay"
|
||||
|
||||
@experimental("test_redis_flag", fallback_function=lambda *args, **kwargs: _raise())
|
||||
|
||||
@@ -130,7 +130,7 @@ async def test_provider_trace_experimental_step(message, agent_state, default_us
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
|
||||
async def test_provider_trace_experimental_step_stream(message, agent_state, default_user, event_loop):
|
||||
async def test_provider_trace_experimental_step_stream(message, agent_state, default_user):
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_state.id,
|
||||
message_manager=MessageManager(),
|
||||
@@ -169,7 +169,7 @@ async def test_provider_trace_experimental_step_stream(message, agent_state, def
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
|
||||
async def test_provider_trace_step(client, agent_state, default_user, message, event_loop):
|
||||
async def test_provider_trace_step(client, agent_state, default_user, message):
|
||||
client.agents.messages.create(agent_id=agent_state.id, messages=[])
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
@@ -186,7 +186,7 @@ async def test_provider_trace_step(client, agent_state, default_user, message, e
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Get the weather in San Francisco."])
|
||||
async def test_noop_provider_trace(message, agent_state, default_user, event_loop):
|
||||
async def test_noop_provider_trace(message, agent_state, default_user):
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_state.id,
|
||||
message_manager=MessageManager(),
|
||||
|
||||
@@ -4,7 +4,7 @@ from letta.data_sources.redis_client import get_redis_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_client(event_loop):
|
||||
async def test_redis_client():
|
||||
test_values = {"LETTA_TEST_0": [1, 2, 3], "LETTA_TEST_1": ["apple", "pear", "banana"], "LETTA_TEST_2": ["{}", 3.2, "cat"]}
|
||||
redis_client = await get_redis_client()
|
||||
|
||||
|
||||
278
tests/test_schema_validator.py
Normal file
278
tests/test_schema_validator.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
Test schema validation for OpenAI strict mode compliance.
|
||||
"""
|
||||
|
||||
from letta.functions.schema_validator import SchemaHealth, validate_complete_json_schema
|
||||
|
||||
|
||||
def test_user_example_schema_now_strict():
|
||||
"""Test that schemas with optional fields are now considered STRICT_COMPLIANT (will be healed)."""
|
||||
schema = {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
"default": None,
|
||||
"title": "B",
|
||||
},
|
||||
},
|
||||
"required": ["a"], # Only 'a' is required, 'b' is not
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Should now be STRICT_COMPLIANT because we can heal optional fields
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
|
||||
def test_all_properties_required_is_strict():
|
||||
"""Test that schemas with all properties required are STRICT_COMPLIANT."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "integer"},
|
||||
"b": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, # Optional via null type
|
||||
},
|
||||
"required": ["a", "b"], # All properties are required
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Should be STRICT_COMPLIANT since all properties are required
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
|
||||
def test_nested_object_missing_required_now_strict():
|
||||
"""Test that nested objects with optional fields are now STRICT_COMPLIANT (will be healed)."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"host": {"type": "string"},
|
||||
"port": {"type": "integer"},
|
||||
"optional_field": {"anyOf": [{"type": "string"}, {"type": "null"}]},
|
||||
},
|
||||
"required": ["host", "port"], # optional_field not required
|
||||
"additionalProperties": False,
|
||||
}
|
||||
},
|
||||
"required": ["config"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Should now be STRICT_COMPLIANT because we can heal optional fields
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
|
||||
def test_nested_object_all_required_is_strict():
|
||||
"""Test that nested objects with all properties required are STRICT_COMPLIANT."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"host": {"type": "string"},
|
||||
"port": {"type": "integer"},
|
||||
"timeout": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
|
||||
},
|
||||
"required": ["host", "port", "timeout"], # All properties required
|
||||
"additionalProperties": False,
|
||||
}
|
||||
},
|
||||
"required": ["config"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Should be STRICT_COMPLIANT since all properties at all levels are required
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
|
||||
def test_empty_object_no_properties_is_strict():
|
||||
"""Test that objects with no properties are STRICT_COMPLIANT."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Empty objects with no properties should be STRICT_COMPLIANT
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
|
||||
def test_missing_additionalproperties_not_strict():
|
||||
"""Test that missing additionalProperties makes schema NON_STRICT_ONLY."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {"type": "string"},
|
||||
},
|
||||
"required": ["field"],
|
||||
# Missing additionalProperties
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Should be NON_STRICT_ONLY due to missing additionalProperties
|
||||
assert status == SchemaHealth.NON_STRICT_ONLY
|
||||
assert any("additionalProperties" in reason and "not explicitly set" in reason for reason in reasons)
|
||||
|
||||
|
||||
def test_additionalproperties_true_not_strict():
|
||||
"""Test that additionalProperties: true makes schema NON_STRICT_ONLY."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {"type": "string"},
|
||||
},
|
||||
"required": ["field"],
|
||||
"additionalProperties": True, # Allows additional properties
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Should be NON_STRICT_ONLY due to additionalProperties not being false
|
||||
assert status == SchemaHealth.NON_STRICT_ONLY
|
||||
assert any("additionalProperties" in reason and "not false" in reason for reason in reasons)
|
||||
|
||||
|
||||
def test_complex_schema_with_arrays():
|
||||
"""Test a complex schema with arrays and nested objects."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "integer"},
|
||||
"name": {"type": "string"},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["id", "name", "tags"], # All properties required
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
"total": {"type": "integer"},
|
||||
},
|
||||
"required": ["items", "total"], # All properties required
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Should be STRICT_COMPLIANT since all properties at all levels are required
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
|
||||
def test_fastmcp_tool_schema_now_strict():
|
||||
"""Test that a schema from FastMCP with optional field 'b' is now STRICT_COMPLIANT."""
|
||||
# This is the exact schema format provided by the user
|
||||
schema = {
|
||||
"properties": {
|
||||
"a": {"title": "A", "type": "integer"},
|
||||
"b": {"anyOf": [{"type": "integer"}, {"type": "null"}], "default": None, "title": "B"},
|
||||
},
|
||||
"required": ["a"], # Only 'a' is required, but we can heal this
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Should now be STRICT_COMPLIANT because we can heal optional fields
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
|
||||
def test_union_types_with_anyof():
|
||||
"""Test that anyOf unions are handled correctly."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "number"},
|
||||
{"type": "null"},
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["value"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Should be STRICT_COMPLIANT - anyOf is allowed and all properties are required
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
|
||||
def test_healed_schema_with_type_array():
|
||||
"""Test that healed schemas with type arrays including null are STRICT_COMPLIANT."""
|
||||
# This represents a schema that has been healed by adding null to optional fields
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"required_field": {"type": "string"},
|
||||
"optional_field": {"type": ["integer", "null"]}, # Healed: was optional, now required with null
|
||||
},
|
||||
"required": ["required_field", "optional_field"], # All fields now required
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Should be STRICT_COMPLIANT since all properties are required
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
|
||||
|
||||
def test_healed_nested_schema():
|
||||
"""Test that healed nested schemas are STRICT_COMPLIANT."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"host": {"type": "string"},
|
||||
"port": {"type": ["integer", "null"]}, # Healed optional field
|
||||
"timeout": {"type": ["number", "null"]}, # Healed optional field
|
||||
},
|
||||
"required": ["host", "port", "timeout"], # All fields required after healing
|
||||
"additionalProperties": False,
|
||||
}
|
||||
},
|
||||
"required": ["config"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
status, reasons = validate_complete_json_schema(schema)
|
||||
|
||||
# Should be STRICT_COMPLIANT after healing
|
||||
assert status == SchemaHealth.STRICT_COMPLIANT
|
||||
assert reasons == []
|
||||
@@ -1,5 +1,7 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import textwrap
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
@@ -15,6 +17,8 @@ from letta_client.core import ApiError
|
||||
from letta_client.types import AgentState, ToolReturnMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from tests.helpers.utils import upload_file_and_wait
|
||||
|
||||
# Constants
|
||||
SERVER_PORT = 8283
|
||||
|
||||
@@ -60,6 +64,117 @@ def agent(client: LettaSDKClient):
|
||||
client.agents.delete(agent_id=agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def fibonacci_tool(client: LettaSDKClient):
|
||||
"""Fixture providing Fibonacci calculation tool."""
|
||||
|
||||
def calculate_fibonacci(n: int) -> int:
|
||||
"""Calculate the nth Fibonacci number.
|
||||
|
||||
Args:
|
||||
n: The position in the Fibonacci sequence to calculate.
|
||||
|
||||
Returns:
|
||||
The nth Fibonacci number.
|
||||
"""
|
||||
if n <= 0:
|
||||
return 0
|
||||
elif n == 1:
|
||||
return 1
|
||||
else:
|
||||
a, b = 0, 1
|
||||
for _ in range(2, n + 1):
|
||||
a, b = b, a + b
|
||||
return b
|
||||
|
||||
tool = client.tools.upsert_from_function(func=calculate_fibonacci, tags=["math", "utility"])
|
||||
yield tool
|
||||
client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def preferences_tool(client: LettaSDKClient):
|
||||
"""Fixture providing user preferences tool."""
|
||||
|
||||
def get_user_preferences(category: str) -> str:
|
||||
"""Get user preferences for a specific category.
|
||||
|
||||
Args:
|
||||
category: The preference category to retrieve (notification, theme, language).
|
||||
|
||||
Returns:
|
||||
The user's preference for the specified category, or "not specified" if unknown.
|
||||
"""
|
||||
preferences = {"notification": "email only", "theme": "dark mode", "language": "english"}
|
||||
return preferences.get(category, "not specified")
|
||||
|
||||
tool = client.tools.upsert_from_function(func=get_user_preferences, tags=["user", "preferences"])
|
||||
yield tool
|
||||
client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def data_analysis_tool(client: LettaSDKClient):
|
||||
"""Fixture providing data analysis tool."""
|
||||
|
||||
def analyze_data(data_type: str, values: List[float]) -> str:
|
||||
"""Analyze data and provide insights.
|
||||
|
||||
Args:
|
||||
data_type: Type of data to analyze.
|
||||
values: Numerical values to analyze.
|
||||
|
||||
Returns:
|
||||
Analysis results including average, max, and min values.
|
||||
"""
|
||||
if not values:
|
||||
return "No data provided"
|
||||
avg = sum(values) / len(values)
|
||||
max_val = max(values)
|
||||
min_val = min(values)
|
||||
return f"Analysis of {data_type}: avg={avg:.2f}, max={max_val}, min={min_val}"
|
||||
|
||||
tool = client.tools.upsert_from_function(func=analyze_data, tags=["analysis", "data"])
|
||||
yield tool
|
||||
client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def persona_block(client: LettaSDKClient):
|
||||
"""Fixture providing persona memory block."""
|
||||
block = client.blocks.create(
|
||||
label="persona",
|
||||
value="You are Alex, a data analyst and mathematician who helps users with calculations and insights. You have extensive experience in statistical analysis and prefer to provide clear, accurate results.",
|
||||
limit=8000,
|
||||
)
|
||||
yield block
|
||||
client.blocks.delete(block.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def human_block(client: LettaSDKClient):
|
||||
"""Fixture providing human memory block."""
|
||||
block = client.blocks.create(
|
||||
label="human",
|
||||
value="username: sarah_researcher\noccupation: data scientist\ninterests: machine learning, statistics, fibonacci sequences\npreferred_communication: detailed explanations with examples",
|
||||
limit=4000,
|
||||
)
|
||||
yield block
|
||||
client.blocks.delete(block.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def context_block(client: LettaSDKClient):
|
||||
"""Fixture providing project context memory block."""
|
||||
block = client.blocks.create(
|
||||
label="project_context",
|
||||
value="Current project: Building predictive models for financial markets. Sarah is working on sequence analysis and pattern recognition. Recently interested in mathematical sequences like Fibonacci for trend analysis.",
|
||||
limit=6000,
|
||||
)
|
||||
yield block
|
||||
client.blocks.delete(block.id)
|
||||
|
||||
|
||||
def test_shared_blocks(client: LettaSDKClient):
|
||||
# create a block
|
||||
block = client.blocks.create(
|
||||
@@ -1465,7 +1580,6 @@ def test_tool_name_auto_update_with_multiple_functions(client: LettaSDKClient):
|
||||
|
||||
def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
|
||||
"""Test that passing both new JSON schema AND source code still renames the tool based on source code"""
|
||||
import textwrap
|
||||
|
||||
# Create initial tool
|
||||
def initial_tool(x: int) -> int:
|
||||
@@ -1543,3 +1657,346 @@ def test_tool_rename_with_json_schema_and_source_code(client: LettaSDKClient):
|
||||
finally:
|
||||
# Clean up
|
||||
client.tools.delete(tool_id=tool.id)
|
||||
|
||||
|
||||
def test_import_agent_file_from_disk(
|
||||
client: LettaSDKClient, fibonacci_tool, preferences_tool, data_analysis_tool, persona_block, human_block, context_block
|
||||
):
|
||||
"""Test exporting an agent to file and importing it back from disk."""
|
||||
# Create a comprehensive agent (similar to test_agent_serialization_v2)
|
||||
name = f"test_export_import_{str(uuid.uuid4())}"
|
||||
temp_agent = client.agents.create(
|
||||
name=name,
|
||||
memory_blocks=[persona_block, human_block, context_block],
|
||||
model="openai/gpt-4.1-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
tool_ids=[fibonacci_tool.id, preferences_tool.id, data_analysis_tool.id],
|
||||
include_base_tools=True,
|
||||
tags=["test", "export", "import"],
|
||||
system="You are a helpful assistant specializing in data analysis and mathematical computations.",
|
||||
)
|
||||
|
||||
# Add archival memory
|
||||
archival_passages = ["Test archival passage for export/import testing.", "Another passage with data about testing procedures."]
|
||||
|
||||
for passage_text in archival_passages:
|
||||
client.agents.passages.create(agent_id=temp_agent.id, text=passage_text)
|
||||
|
||||
# Send a test message
|
||||
client.agents.messages.create(
|
||||
agent_id=temp_agent.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="Test message for export",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Export the agent
|
||||
serialized_v2 = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)
|
||||
|
||||
# Save to file
|
||||
file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_basic_agent_with_blocks_tools_messages_v2.af")
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(serialized_v2, f, indent=2)
|
||||
|
||||
# Now import from the file
|
||||
with open(file_path, "rb") as f:
|
||||
import_result = client.agents.import_file(
|
||||
file=f, append_copy_suffix=True, override_existing_tools=True # Use suffix to avoid name conflict
|
||||
)
|
||||
|
||||
# Basic verification
|
||||
assert import_result is not None, "Import result should not be None"
|
||||
assert len(import_result.agent_ids) > 0, "Should have imported at least one agent"
|
||||
|
||||
# Get the imported agent
|
||||
imported_agent_id = import_result.agent_ids[0]
|
||||
imported_agent = client.agents.retrieve(agent_id=imported_agent_id)
|
||||
|
||||
# Basic checks
|
||||
assert imported_agent is not None, "Should be able to retrieve imported agent"
|
||||
assert imported_agent.name is not None, "Imported agent should have a name"
|
||||
assert imported_agent.memory is not None, "Agent should have memory"
|
||||
assert len(imported_agent.tools) > 0, "Agent should have tools"
|
||||
assert imported_agent.system is not None, "Agent should have a system prompt"
|
||||
|
||||
|
||||
def test_agent_serialization_v2(
|
||||
client: LettaSDKClient, fibonacci_tool, preferences_tool, data_analysis_tool, persona_block, human_block, context_block
|
||||
):
|
||||
"""Test agent serialization with comprehensive setup including custom tools, blocks, messages, and archival memory."""
|
||||
name = f"comprehensive_test_agent_{str(uuid.uuid4())}"
|
||||
temp_agent = client.agents.create(
|
||||
name=name,
|
||||
memory_blocks=[persona_block, human_block, context_block],
|
||||
model="openai/gpt-4.1-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
tool_ids=[fibonacci_tool.id, preferences_tool.id, data_analysis_tool.id],
|
||||
include_base_tools=True,
|
||||
tags=["test", "comprehensive", "serialization"],
|
||||
system="You are a helpful assistant specializing in data analysis and mathematical computations.",
|
||||
)
|
||||
|
||||
# Add archival memory
|
||||
archival_passages = [
|
||||
"Project background: Sarah is working on a financial prediction model that uses Fibonacci retracements for technical analysis.",
|
||||
"Research notes: Golden ratio (1.618) derived from Fibonacci sequence is often used in financial markets for support/resistance levels.",
|
||||
]
|
||||
|
||||
for passage_text in archival_passages:
|
||||
client.agents.passages.create(agent_id=temp_agent.id, text=passage_text)
|
||||
|
||||
# Send some messages
|
||||
client.agents.messages.create(
|
||||
agent_id=temp_agent.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="Test message",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Serialize using v2
|
||||
serialized_v2 = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)
|
||||
# Convert dict to JSON bytes for import
|
||||
json_str = json.dumps(serialized_v2)
|
||||
file_obj = io.BytesIO(json_str.encode("utf-8"))
|
||||
|
||||
# Import again
|
||||
import_result = client.agents.import_file(file=file_obj, append_copy_suffix=False, override_existing_tools=True)
|
||||
|
||||
# Verify import was successful
|
||||
assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
|
||||
imported_agent_id = import_result.agent_ids[0]
|
||||
imported_agent = client.agents.retrieve(agent_id=imported_agent_id)
|
||||
|
||||
# ========== BASIC AGENT PROPERTIES ==========
|
||||
# Name should be the same (if append_copy_suffix=False) or have suffix
|
||||
assert imported_agent.name == name, f"Agent name mismatch: {imported_agent.name} != {name}"
|
||||
|
||||
# LLM and embedding configs should be preserved
|
||||
assert (
|
||||
imported_agent.llm_config.model == temp_agent.llm_config.model
|
||||
), f"LLM model mismatch: {imported_agent.llm_config.model} != {temp_agent.llm_config.model}"
|
||||
assert imported_agent.embedding_config.embedding_model == temp_agent.embedding_config.embedding_model, "Embedding model mismatch"
|
||||
|
||||
# System prompt should be preserved
|
||||
assert imported_agent.system == temp_agent.system, "System prompt was not preserved"
|
||||
|
||||
# Tags should be preserved
|
||||
assert set(imported_agent.tags) == set(temp_agent.tags), f"Tags mismatch: {imported_agent.tags} != {temp_agent.tags}"
|
||||
|
||||
# Agent type should be preserved
|
||||
assert (
|
||||
imported_agent.agent_type == temp_agent.agent_type
|
||||
), f"Agent type mismatch: {imported_agent.agent_type} != {temp_agent.agent_type}"
|
||||
|
||||
# ========== MEMORY BLOCKS ==========
|
||||
# Compare memory blocks directly from AgentState objects
|
||||
original_blocks = temp_agent.memory.blocks
|
||||
imported_blocks = imported_agent.memory.blocks
|
||||
|
||||
# Should have same number of blocks
|
||||
assert len(imported_blocks) == len(original_blocks), f"Block count mismatch: {len(imported_blocks)} != {len(original_blocks)}"
|
||||
|
||||
# Verify each block by label
|
||||
original_blocks_by_label = {block.label: block for block in original_blocks}
|
||||
imported_blocks_by_label = {block.label: block for block in imported_blocks}
|
||||
|
||||
# Check persona block
|
||||
assert "persona" in imported_blocks_by_label, "Persona block missing in imported agent"
|
||||
assert "Alex" in imported_blocks_by_label["persona"].value, "Persona block content not preserved"
|
||||
assert imported_blocks_by_label["persona"].limit == original_blocks_by_label["persona"].limit, "Persona block limit mismatch"
|
||||
|
||||
# Check human block
|
||||
assert "human" in imported_blocks_by_label, "Human block missing in imported agent"
|
||||
assert "sarah_researcher" in imported_blocks_by_label["human"].value, "Human block content not preserved"
|
||||
assert imported_blocks_by_label["human"].limit == original_blocks_by_label["human"].limit, "Human block limit mismatch"
|
||||
|
||||
# Check context block
|
||||
assert "project_context" in imported_blocks_by_label, "Context block missing in imported agent"
|
||||
assert "financial markets" in imported_blocks_by_label["project_context"].value, "Context block content not preserved"
|
||||
assert (
|
||||
imported_blocks_by_label["project_context"].limit == original_blocks_by_label["project_context"].limit
|
||||
), "Context block limit mismatch"
|
||||
|
||||
# ========== TOOLS ==========
|
||||
# Compare tools directly from AgentState objects
|
||||
original_tools = temp_agent.tools
|
||||
imported_tools = imported_agent.tools
|
||||
|
||||
# Should have same number of tools
|
||||
assert len(imported_tools) == len(original_tools), f"Tool count mismatch: {len(imported_tools)} != {len(original_tools)}"
|
||||
|
||||
original_tool_names = {tool.name for tool in original_tools}
|
||||
imported_tool_names = {tool.name for tool in imported_tools}
|
||||
|
||||
# Check custom tools are present
|
||||
assert "calculate_fibonacci" in imported_tool_names, "Fibonacci tool missing in imported agent"
|
||||
assert "get_user_preferences" in imported_tool_names, "Preferences tool missing in imported agent"
|
||||
assert "analyze_data" in imported_tool_names, "Data analysis tool missing in imported agent"
|
||||
|
||||
# Check for base tools (since we set include_base_tools=True when creating the agent)
|
||||
# Base tools should also be present (at least some core ones)
|
||||
base_tool_names = {"send_message", "conversation_search"}
|
||||
missing_base_tools = base_tool_names - imported_tool_names
|
||||
assert len(missing_base_tools) == 0, f"Missing base tools: {missing_base_tools}"
|
||||
|
||||
# Verify tool names match exactly
|
||||
assert original_tool_names == imported_tool_names, f"Tool names don't match: {original_tool_names} != {imported_tool_names}"
|
||||
|
||||
# ========== MESSAGE HISTORY ==========
|
||||
# Get messages for both agents
|
||||
original_messages = client.agents.messages.list(agent_id=temp_agent.id, limit=100)
|
||||
imported_messages = client.agents.messages.list(agent_id=imported_agent_id, limit=100)
|
||||
|
||||
# Should have same number of messages
|
||||
assert len(imported_messages) >= 1, "Imported agent should have messages"
|
||||
|
||||
# Filter for user messages (excluding system-generated login messages)
|
||||
original_user_msgs = [msg for msg in original_messages if msg.message_type == "user_message" and "Test message" in msg.content]
|
||||
imported_user_msgs = [msg for msg in imported_messages if msg.message_type == "user_message" and "Test message" in msg.content]
|
||||
|
||||
# Should have the same number of test messages
|
||||
assert len(imported_user_msgs) == len(
|
||||
original_user_msgs
|
||||
), f"User message count mismatch: {len(imported_user_msgs)} != {len(original_user_msgs)}"
|
||||
|
||||
# Verify test message content is preserved
|
||||
if len(original_user_msgs) > 0 and len(imported_user_msgs) > 0:
|
||||
assert imported_user_msgs[0].content == original_user_msgs[0].content, "User message content not preserved"
|
||||
assert "Test message" in imported_user_msgs[0].content, "Test message content not found"
|
||||
|
||||
|
||||
def test_export_import_agent_with_files(client: LettaSDKClient):
|
||||
"""Test exporting and importing an agent with files attached."""
|
||||
|
||||
# Clean up any existing source with the same name from previous runs
|
||||
existing_sources = client.sources.list()
|
||||
for existing_source in existing_sources:
|
||||
client.sources.delete(source_id=existing_source.id)
|
||||
|
||||
# Create a source and upload test files
|
||||
source = client.sources.create(name="test_export_source", embedding="openai/text-embedding-3-small")
|
||||
|
||||
# Upload test files to the source
|
||||
test_files = ["tests/data/test.txt", "tests/data/test.md"]
|
||||
|
||||
for file_path in test_files:
|
||||
upload_file_and_wait(client, source.id, file_path)
|
||||
|
||||
# Verify files were uploaded successfully
|
||||
files_in_source = client.sources.files.list(source_id=source.id, limit=10)
|
||||
assert len(files_in_source) == len(test_files), f"Expected {len(test_files)} files, got {len(files_in_source)}"
|
||||
|
||||
# Create a simple agent with the source attached
|
||||
temp_agent = client.agents.create(
|
||||
memory_blocks=[
|
||||
CreateBlock(label="human", value="username: sarah"),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
source_ids=[source.id], # Attach the source with files
|
||||
)
|
||||
|
||||
# Verify the agent has the source and file blocks
|
||||
agent_state = client.agents.retrieve(agent_id=temp_agent.id)
|
||||
assert len(agent_state.sources) == 1, "Agent should have one source attached"
|
||||
assert agent_state.sources[0].id == source.id, "Agent should have the correct source attached"
|
||||
|
||||
# Verify file blocks are present
|
||||
file_blocks = agent_state.memory.file_blocks
|
||||
assert len(file_blocks) == len(test_files), f"Expected {len(test_files)} file blocks, got {len(file_blocks)}"
|
||||
|
||||
# Export the agent
|
||||
serialized_agent = client.agents.export_file(agent_id=temp_agent.id, use_legacy_format=False)
|
||||
|
||||
# Convert to JSON bytes for import
|
||||
json_str = json.dumps(serialized_agent)
|
||||
file_obj = io.BytesIO(json_str.encode("utf-8"))
|
||||
|
||||
# Import the agent
|
||||
import_result = client.agents.import_file(file=file_obj, append_copy_suffix=True, override_existing_tools=True)
|
||||
|
||||
# Verify import was successful
|
||||
assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
|
||||
imported_agent_id = import_result.agent_ids[0]
|
||||
imported_agent = client.agents.retrieve(agent_id=imported_agent_id)
|
||||
|
||||
# Verify the source is attached to the imported agent
|
||||
assert len(imported_agent.sources) == 1, "Imported agent should have one source attached"
|
||||
imported_source = imported_agent.sources[0]
|
||||
|
||||
# Check that imported source has the same files
|
||||
imported_files = client.sources.files.list(source_id=imported_source.id, limit=10)
|
||||
assert len(imported_files) == len(test_files), f"Imported source should have {len(test_files)} files"
|
||||
|
||||
# Verify file blocks are preserved in imported agent
|
||||
imported_file_blocks = imported_agent.memory.file_blocks
|
||||
assert len(imported_file_blocks) == len(test_files), f"Imported agent should have {len(test_files)} file blocks"
|
||||
|
||||
# Verify file block content
|
||||
for file_block in imported_file_blocks:
|
||||
assert file_block.value is not None and len(file_block.value) > 0, "Imported file block should have content"
|
||||
assert "[Viewing file start" in file_block.value, "Imported file block should show file viewing header"
|
||||
|
||||
# Test that files can be opened on the imported agent
|
||||
if len(imported_files) > 0:
|
||||
test_file = imported_files[0]
|
||||
client.agents.files.open(agent_id=imported_agent_id, file_id=test_file.id)
|
||||
|
||||
# Clean up
|
||||
client.agents.delete(agent_id=temp_agent.id)
|
||||
client.agents.delete(agent_id=imported_agent_id)
|
||||
client.sources.delete(source_id=source.id)
|
||||
|
||||
|
||||
def test_import_agent_with_files_from_disk(client: LettaSDKClient):
|
||||
"""Test exporting an agent with files to disk and importing it back."""
|
||||
# Upload test files to the source
|
||||
test_files = ["tests/data/test.txt", "tests/data/test.md"]
|
||||
|
||||
# Save to file
|
||||
file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_agent_with_files_and_sources.af")
|
||||
|
||||
# Now import from the file
|
||||
with open(file_path, "rb") as f:
|
||||
import_result = client.agents.import_file(
|
||||
file=f, append_copy_suffix=True, override_existing_tools=True # Use suffix to avoid name conflict
|
||||
)
|
||||
|
||||
# Verify import was successful
|
||||
assert len(import_result.agent_ids) == 1, "Should have imported exactly one agent"
|
||||
imported_agent_id = import_result.agent_ids[0]
|
||||
imported_agent = client.agents.retrieve(agent_id=imported_agent_id)
|
||||
|
||||
# Verify the source is attached to the imported agent
|
||||
assert len(imported_agent.sources) == 1, "Imported agent should have one source attached"
|
||||
imported_source = imported_agent.sources[0]
|
||||
|
||||
# Check that imported source has the same files
|
||||
imported_files = client.sources.files.list(source_id=imported_source.id, limit=10)
|
||||
assert len(imported_files) == len(test_files), f"Imported source should have {len(test_files)} files"
|
||||
|
||||
# Verify file blocks are preserved in imported agent
|
||||
imported_file_blocks = imported_agent.memory.file_blocks
|
||||
assert len(imported_file_blocks) == len(test_files), f"Imported agent should have {len(test_files)} file blocks"
|
||||
|
||||
# Verify file block content
|
||||
for file_block in imported_file_blocks:
|
||||
assert file_block.value is not None and len(file_block.value) > 0, "Imported file block should have content"
|
||||
assert "[Viewing file start" in file_block.value, "Imported file block should show file viewing header"
|
||||
|
||||
# Test that files can be opened on the imported agent
|
||||
if len(imported_files) > 0:
|
||||
test_file = imported_files[0]
|
||||
client.agents.files.open(agent_id=imported_agent_id, file_id=test_file.id)
|
||||
|
||||
# Clean up agents and sources
|
||||
client.agents.delete(agent_id=imported_agent_id)
|
||||
client.sources.delete(source_id=imported_source.id)
|
||||
|
||||
@@ -485,7 +485,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_local_llm_configs(server: SyncServer, user: User, event_loop):
|
||||
async def test_read_local_llm_configs(server: SyncServer, user: User):
|
||||
configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
|
||||
clean_up_dir = False
|
||||
if not os.path.exists(configs_base_dir):
|
||||
@@ -1016,7 +1016,7 @@ async def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, b
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_with_provider_override(server: SyncServer, user_id: str, event_loop):
|
||||
async def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id)
|
||||
provider = server.provider_manager.create_provider(
|
||||
request=ProviderCreate(
|
||||
@@ -1096,7 +1096,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User, event_loop):
|
||||
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
|
||||
models = await server.list_llm_models_async(actor=user)
|
||||
model_handles = [model.handle for model in models]
|
||||
assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
@@ -7,17 +8,19 @@ from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import CreateBlock, DuplicateFileHandling
|
||||
from letta_client import CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client import LettaRequest
|
||||
from letta_client import MessageCreate as ClientMessageCreate
|
||||
from letta_client.types import AgentState
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS
|
||||
from letta.helpers.pinecone_utils import should_use_pinecone
|
||||
from letta.schemas.enums import FileProcessingStatus, ToolType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
from tests.helpers.utils import upload_file_and_wait, upload_file_and_wait_list_files
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# Constants
|
||||
@@ -71,31 +74,6 @@ def client() -> LettaSDKClient:
|
||||
yield client
|
||||
|
||||
|
||||
def upload_file_and_wait(
|
||||
client: LettaSDKClient, source_id: str, file_path: str, max_wait: int = 60, duplicate_handling: DuplicateFileHandling = None
|
||||
):
|
||||
"""Helper function to upload a file and wait for processing to complete"""
|
||||
with open(file_path, "rb") as f:
|
||||
if duplicate_handling:
|
||||
file_metadata = client.sources.files.upload(source_id=source_id, file=f, duplicate_handling=duplicate_handling)
|
||||
else:
|
||||
file_metadata = client.sources.files.upload(source_id=source_id, file=f)
|
||||
|
||||
# Wait for the file to be processed
|
||||
start_time = time.time()
|
||||
while file_metadata.processing_status != "completed" and file_metadata.processing_status != "error":
|
||||
if time.time() - start_time > max_wait:
|
||||
pytest.fail(f"File processing timed out after {max_wait} seconds")
|
||||
time.sleep(1)
|
||||
file_metadata = client.sources.get_file_metadata(source_id=source_id, file_id=file_metadata.id)
|
||||
print("Waiting for file processing to complete...", file_metadata.processing_status)
|
||||
|
||||
if file_metadata.processing_status == "error":
|
||||
pytest.fail(f"File processing failed: {file_metadata.error_message}")
|
||||
|
||||
return file_metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_state(disable_pinecone, client: LettaSDKClient):
|
||||
open_file_tool = client.tools.list(name="open_files")[0]
|
||||
@@ -418,7 +396,7 @@ def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDK
|
||||
assert initial_content_length > 10, f"Expected file content > 10 chars, got {initial_content_length}"
|
||||
|
||||
# Ask agent to open the file for a specific range using offset/length
|
||||
offset, length = 1, 5 # 1-indexed offset, 5 lines
|
||||
offset, length = 0, 5 # 0-indexed offset, 5 lines
|
||||
print(f"Requesting agent to open file with offset={offset}, length={length}")
|
||||
open_response1 = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
@@ -447,7 +425,7 @@ def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDK
|
||||
assert "5: " in old_value, f"Expected line 5 to be present, got: {old_value}"
|
||||
|
||||
# Ask agent to open the file for a different range
|
||||
offset, length = 6, 5 # Different offset, same length
|
||||
offset, length = 5, 5 # Different offset, same length
|
||||
open_response2 = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[
|
||||
@@ -476,8 +454,8 @@ def test_agent_uses_open_close_file_correctly(disable_pinecone, client: LettaSDK
|
||||
assert "10: " in new_value, f"Expected line 10 to be present, got: {new_value}"
|
||||
|
||||
print(f"Comparing content ranges:")
|
||||
print(f" First range (offset=1, length=5): '{old_value}'")
|
||||
print(f" Second range (offset=6, length=5): '{new_value}'")
|
||||
print(f" First range (offset=0, length=5): '{old_value}'")
|
||||
print(f" Second range (offset=5, length=5): '{new_value}'")
|
||||
|
||||
assert new_value != old_value, f"Different view ranges should have different content. New: '{new_value}', Old: '{old_value}'"
|
||||
|
||||
@@ -697,7 +675,7 @@ def test_view_ranges_have_metadata(disable_pinecone, client: LettaSDKClient, age
|
||||
assert block.value.startswith("[Viewing file start (out of 100 lines)]")
|
||||
|
||||
# Open a specific range using offset/length
|
||||
offset = 50 # 1-indexed line 50
|
||||
offset = 49 # 0-indexed for line 50
|
||||
length = 5 # 5 lines (50-54)
|
||||
open_response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
@@ -851,6 +829,87 @@ def test_duplicate_file_handling_replace(disable_pinecone, client: LettaSDKClien
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
|
||||
def test_upload_file_with_custom_name(disable_pinecone, client: LettaSDKClient):
|
||||
"""Test that uploading a file with a custom name overrides the original filename"""
|
||||
# Create agent
|
||||
agent_state = client.agents.create(
|
||||
name="test_agent_custom_name",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="I am a helpful assistant",
|
||||
),
|
||||
CreateBlock(
|
||||
label="human",
|
||||
value="The user is a developer",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Create source
|
||||
source = client.sources.create(name="test_source_custom_name", embedding="openai/text-embedding-3-small")
|
||||
|
||||
# Attach source to agent
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
|
||||
# Create a temporary file with specific content
|
||||
import tempfile
|
||||
|
||||
temp_file_path = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
|
||||
f.write("This is a test file for custom naming")
|
||||
temp_file_path = f.name
|
||||
|
||||
# Upload file with custom name
|
||||
custom_name = "my_custom_file_name.txt"
|
||||
file_metadata = upload_file_and_wait(client, source.id, temp_file_path, name=custom_name)
|
||||
|
||||
# Verify the file uses the custom name
|
||||
assert file_metadata.file_name == custom_name
|
||||
assert file_metadata.original_file_name == custom_name
|
||||
|
||||
# Verify file appears in source files list with custom name
|
||||
files = client.sources.files.list(source_id=source.id, limit=1)
|
||||
assert len(files) == 1
|
||||
assert files[0].file_name == custom_name
|
||||
assert files[0].original_file_name == custom_name
|
||||
|
||||
# Verify the custom name is used in file blocks
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
file_blocks = agent_state.memory.file_blocks
|
||||
assert len(file_blocks) == 1
|
||||
# Check that the custom name appears in the block label
|
||||
assert custom_name.replace(".txt", "") in file_blocks[0].label
|
||||
|
||||
# Test duplicate handling with custom name - upload same file with same custom name
|
||||
from letta.schemas.enums import DuplicateFileHandling
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
upload_file_and_wait(client, source.id, temp_file_path, name=custom_name, duplicate_handling=DuplicateFileHandling.ERROR)
|
||||
assert "already exists" in str(exc_info.value).lower()
|
||||
|
||||
# Upload same file with different custom name should succeed
|
||||
different_custom_name = "folder_a/folder_b/another_custom_name.txt"
|
||||
file_metadata2 = upload_file_and_wait(client, source.id, temp_file_path, name=different_custom_name)
|
||||
assert file_metadata2.file_name == different_custom_name
|
||||
assert file_metadata2.original_file_name == different_custom_name
|
||||
|
||||
# Verify both files exist
|
||||
files = client.sources.files.list(source_id=source.id, limit=10)
|
||||
assert len(files) == 2
|
||||
file_names = {f.file_name for f in files}
|
||||
assert custom_name in file_names
|
||||
assert different_custom_name in file_names
|
||||
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
if temp_file_path and os.path.exists(temp_file_path):
|
||||
os.unlink(temp_file_path)
|
||||
|
||||
|
||||
def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient):
|
||||
"""Test that open_files tool schema contains correct descriptions from docstring"""
|
||||
|
||||
@@ -873,9 +932,9 @@ def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient
|
||||
# Check that examples are included
|
||||
assert "Examples:" in description
|
||||
assert 'FileOpenRequest(file_name="project_utils/config.py")' in description
|
||||
assert 'FileOpenRequest(file_name="project_utils/config.py", offset=1, length=50)' in description
|
||||
assert 'FileOpenRequest(file_name="project_utils/config.py", offset=0, length=50)' in description
|
||||
assert "# Lines 1-50" in description
|
||||
assert "# Lines 100-199" in description
|
||||
assert "# Lines 101-200" in description
|
||||
assert "# Entire file" in description
|
||||
assert "close_all_others=True" in description
|
||||
assert "View specific portions of large files (e.g. functions or definitions)" in description
|
||||
@@ -922,7 +981,7 @@ def test_open_files_schema_descriptions(disable_pinecone, client: LettaSDKClient
|
||||
# Check offset field
|
||||
assert "offset" in file_request_properties
|
||||
offset_prop = file_request_properties["offset"]
|
||||
expected_offset_desc = "Optional starting line number (1-indexed). If not specified, starts from beginning of file."
|
||||
expected_offset_desc = "Optional offset for starting line number (0-indexed). If not specified, starts from beginning of file."
|
||||
assert offset_prop["description"] == expected_offset_desc
|
||||
assert offset_prop["type"] == "integer"
|
||||
|
||||
@@ -1074,10 +1133,43 @@ def test_pinecone_search_files_tool(client: LettaSDKClient):
|
||||
), f"Search results should contain relevant content: {search_results}"
|
||||
|
||||
|
||||
def test_pinecone_list_files_status(client: LettaSDKClient):
|
||||
"""Test that list_source_files properly syncs embedding status with Pinecone"""
|
||||
if not should_use_pinecone():
|
||||
pytest.skip("Pinecone not configured (missing API key or disabled), skipping Pinecone-specific tests")
|
||||
|
||||
# create source
|
||||
source = client.sources.create(name="test_list_files_status", embedding="openai/text-embedding-3-small")
|
||||
|
||||
file_paths = ["tests/data/long_test.txt"]
|
||||
uploaded_files = []
|
||||
for file_path in file_paths:
|
||||
# use the new helper that polls via list_files
|
||||
file_metadata = upload_file_and_wait_list_files(client, source.id, file_path)
|
||||
uploaded_files.append(file_metadata)
|
||||
assert file_metadata.processing_status == "completed", f"File {file_path} should be completed"
|
||||
|
||||
# now get files using list_source_files to verify status checking works
|
||||
files_list = client.sources.files.list(source_id=source.id, limit=100)
|
||||
|
||||
# verify all files show completed status and have proper embedding counts
|
||||
assert len(files_list) == len(uploaded_files), f"Expected {len(uploaded_files)} files, got {len(files_list)}"
|
||||
|
||||
for file_metadata in files_list:
|
||||
assert file_metadata.processing_status == "completed", f"File {file_metadata.file_name} should show completed status"
|
||||
|
||||
# verify embedding counts for files that have chunks
|
||||
if file_metadata.total_chunks and file_metadata.total_chunks > 0:
|
||||
assert (
|
||||
file_metadata.chunks_embedded == file_metadata.total_chunks
|
||||
), f"File {file_metadata.file_name} should have all chunks embedded: {file_metadata.chunks_embedded}/{file_metadata.total_chunks}"
|
||||
|
||||
# cleanup
|
||||
client.sources.delete(source_id=source.id)
|
||||
|
||||
|
||||
def test_pinecone_lifecycle_file_and_source_deletion(client: LettaSDKClient):
|
||||
"""Test that file and source deletion removes records from Pinecone"""
|
||||
import asyncio
|
||||
|
||||
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
|
||||
|
||||
if not should_use_pinecone():
|
||||
@@ -1146,8 +1238,6 @@ def test_pinecone_lifecycle_file_and_source_deletion(client: LettaSDKClient):
|
||||
len(records_after) == 0
|
||||
), f"All source records should be removed from Pinecone after source deletion, but found {len(records_after)}"
|
||||
|
||||
print("✓ Pinecone lifecycle verified - namespace is clean after source deletion")
|
||||
|
||||
|
||||
def test_agent_open_file(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
"""Test client.agents.open_file() function"""
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.constants import MAX_FILENAME_LENGTH
|
||||
@@ -522,19 +520,8 @@ def test_line_chunker_only_start_parameter():
|
||||
# ---------------------- Alembic Revision TESTS ---------------------- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
"""
|
||||
Create an event loop for the entire test session.
|
||||
Ensures all async tasks use the same loop, avoiding cross-loop errors.
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_alembic_revision(event_loop):
|
||||
async def test_get_latest_alembic_revision():
|
||||
"""Test that get_latest_alembic_revision returns a valid revision ID from the database."""
|
||||
from letta.utils import get_latest_alembic_revision
|
||||
|
||||
@@ -553,7 +540,7 @@ async def test_get_latest_alembic_revision(event_loop):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_alembic_revision_consistency(event_loop):
|
||||
async def test_get_latest_alembic_revision_consistency():
|
||||
"""Test that get_latest_alembic_revision returns the same value on multiple calls."""
|
||||
from letta.utils import get_latest_alembic_revision
|
||||
|
||||
|
||||
Reference in New Issue
Block a user