chore: bump version 0.8.13 (#2718)
Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com> Co-authored-by: Sarah Wooders <sarahwooders@gmail.com> Co-authored-by: Matthew Zhou <mattzh1314@gmail.com> Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com> Co-authored-by: jnjpng <jin@letta.com> Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local> Co-authored-by: cpacker <packercharles@gmail.com> Co-authored-by: Shubham Naik <shub@letta.com> Co-authored-by: Shubham Naik <shub@memgpt.ai>
This commit is contained in:
@@ -5,7 +5,7 @@ try:
|
||||
__version__ = version("letta")
|
||||
except PackageNotFoundError:
|
||||
# Fallback for development installations
|
||||
__version__ = "0.8.12"
|
||||
__version__ = "0.8.13"
|
||||
|
||||
if os.environ.get("LETTA_VERSION"):
|
||||
__version__ = os.environ["LETTA_VERSION"]
|
||||
|
||||
@@ -96,7 +96,7 @@ class BaseAgent(ABC):
|
||||
"""
|
||||
try:
|
||||
# [DB Call] loading blocks (modifies: agent_state.memory.blocks)
|
||||
await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor)
|
||||
agent_state = await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor)
|
||||
|
||||
tool_constraint_block = None
|
||||
if tool_rules_solver is not None:
|
||||
@@ -104,18 +104,37 @@ class BaseAgent(ABC):
|
||||
|
||||
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
|
||||
curr_system_message = in_context_messages[0]
|
||||
curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, sources=agent_state.sources)
|
||||
curr_system_message_text = curr_system_message.content[0].text
|
||||
if curr_memory_str in curr_system_message_text:
|
||||
|
||||
# extract the dynamic section that includes memory blocks, tool rules, and directories
|
||||
# this avoids timestamp comparison issues
|
||||
def extract_dynamic_section(text: str) -> str:
    """Return the span of *text* used to compare dynamic prompt content.

    The span starts at the first ``</base_instructions>`` marker (the marker
    itself is included) and ends just before the first ``<memory_metadata>``
    marker that follows it. Leaving the metadata out keeps its contents out
    of the comparison.

    Args:
        text: Text to extract the section from.

    Returns:
        The extracted span, or *text* unchanged when the two markers are not
        both present in that order.
    """
    start_marker = "</base_instructions>"
    end_marker = "<memory_metadata>"

    start_idx = text.find(start_marker)
    if start_idx == -1:
        return text  # fallback to full text if markers not found

    # Look for the end marker only *after* the start marker. Searching from
    # the beginning of the text could pick up an earlier occurrence, making
    # the slice empty, so that any two such texts would compare as equal.
    end_idx = text.find(end_marker, start_idx + len(start_marker))
    if end_idx == -1:
        return text  # fallback to full text if markers not found

    return text[start_idx:end_idx]
|
||||
|
||||
curr_dynamic_section = extract_dynamic_section(curr_system_message_text)
|
||||
|
||||
# generate just the memory string with current state for comparison
|
||||
curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, sources=agent_state.sources)
|
||||
new_dynamic_section = extract_dynamic_section(curr_memory_str)
|
||||
|
||||
# compare just the dynamic sections (memory blocks, tool rules, directories)
|
||||
if curr_dynamic_section == new_dynamic_section:
|
||||
logger.debug(
|
||||
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
|
||||
f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
|
||||
)
|
||||
return in_context_messages
|
||||
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
|
||||
# [DB Call] size of messages and archival memories
|
||||
# todo: blocking for now
|
||||
# size of messages and archival memories
|
||||
if num_messages is None:
|
||||
num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
if num_archival_memories is None:
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
@@ -165,18 +165,28 @@ class LettaAgent(BaseAgent):
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
) -> LettaResponse:
|
||||
dry_run: bool = False,
|
||||
) -> Union[LettaResponse, dict]:
|
||||
# TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
|
||||
agent_id=self.agent_id,
|
||||
include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
|
||||
actor=self.actor,
|
||||
)
|
||||
_, new_in_context_messages, stop_reason, usage = await self._step(
|
||||
result = await self._step(
|
||||
agent_state=agent_state,
|
||||
input_messages=input_messages,
|
||||
max_steps=max_steps,
|
||||
run_id=run_id,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
# If dry run, return the request payload directly
|
||||
if dry_run:
|
||||
return result
|
||||
|
||||
_, new_in_context_messages, stop_reason, usage = result
|
||||
return _create_letta_response(
|
||||
new_in_context_messages=new_in_context_messages,
|
||||
use_assistant_message=use_assistant_message,
|
||||
@@ -195,7 +205,9 @@ class LettaAgent(BaseAgent):
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
):
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
|
||||
agent_id=self.agent_id,
|
||||
include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
|
||||
actor=self.actor,
|
||||
)
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
@@ -279,6 +291,7 @@ class LettaAgent(BaseAgent):
|
||||
tool_rules_solver,
|
||||
response.usage,
|
||||
reasoning_content=reasoning,
|
||||
step_id=step_id,
|
||||
initial_messages=initial_messages,
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(i == max_steps - 1),
|
||||
@@ -357,7 +370,8 @@ class LettaAgent(BaseAgent):
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: str | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
) -> tuple[list[Message], list[Message], LettaStopReason | None, LettaUsageStatistics]:
|
||||
dry_run: bool = False,
|
||||
) -> Union[tuple[list[Message], list[Message], LettaStopReason | None, LettaUsageStatistics], dict]:
|
||||
"""
|
||||
Carries out an invocation of the agent loop. In each step, the agent
|
||||
1. Rebuilds its memory
|
||||
@@ -394,6 +408,16 @@ class LettaAgent(BaseAgent):
|
||||
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
|
||||
agent_step_span.set_attributes({"step_id": step_id})
|
||||
|
||||
# If dry run, build request data and return it without making LLM call
|
||||
if dry_run:
|
||||
request_data, valid_tool_names = await self._create_llm_request_data_async(
|
||||
llm_client=llm_client,
|
||||
in_context_messages=current_in_context_messages + new_in_context_messages,
|
||||
agent_state=agent_state,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
)
|
||||
return request_data
|
||||
|
||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
||||
await self._build_and_request_from_llm(
|
||||
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span
|
||||
@@ -530,7 +554,9 @@ class LettaAgent(BaseAgent):
|
||||
4. Processes the response
|
||||
"""
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor
|
||||
agent_id=self.agent_id,
|
||||
include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"],
|
||||
actor=self.actor,
|
||||
)
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
@@ -628,7 +654,7 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
|
||||
# log LLM request time
|
||||
llm_request_ms = ns_to_ms(stream_end_time_ns - request_start_timestamp_ns)
|
||||
llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns)
|
||||
agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms})
|
||||
MetricRegistry().llm_execution_time_ms_histogram.record(
|
||||
llm_request_ms,
|
||||
|
||||
@@ -129,7 +129,8 @@ def get_function_name_and_docstring(source_code: str, name: Optional[str] = None
|
||||
raise LettaToolCreateError("Could not determine function name")
|
||||
|
||||
if not docstring:
|
||||
raise LettaToolCreateError("Docstring is missing")
|
||||
# For tools with args_json_schema, the docstring is optional
|
||||
docstring = f"The {function_name} tool"
|
||||
|
||||
return function_name, docstring
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ def json_dumps(data, indent=2):
|
||||
def safe_serializer(obj):
    """Fallback encoder for values the ``json`` module cannot serialize itself.

    ``bytes`` are decoded as UTF-8 text and ``datetime`` objects are rendered
    with ``isoformat()``. Any other type raises ``TypeError``.
    """
    # The two supported types never overlap, so the check order is irrelevant.
    if isinstance(obj, bytes):
        return obj.decode("utf-8")
    if isinstance(obj, datetime):
        return obj.isoformat()
    raise TypeError(f"Type {type(obj)} not serializable")
|
||||
|
||||
return json.dumps(data, indent=indent, default=safe_serializer, ensure_ascii=False)
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pinecone import PineconeAsyncio
|
||||
try:
|
||||
from pinecone import IndexEmbed, PineconeAsyncio
|
||||
from pinecone.exceptions.exceptions import NotFoundException
|
||||
|
||||
PINECONE_AVAILABLE = True
|
||||
except ImportError:
|
||||
PINECONE_AVAILABLE = False
|
||||
|
||||
from letta.constants import (
|
||||
PINECONE_CLOUD,
|
||||
@@ -27,11 +33,20 @@ def should_use_pinecone(verbose: bool = False):
|
||||
bool(settings.pinecone_source_index),
|
||||
)
|
||||
|
||||
return settings.enable_pinecone and settings.pinecone_api_key and settings.pinecone_agent_index and settings.pinecone_source_index
|
||||
return all(
|
||||
(
|
||||
PINECONE_AVAILABLE,
|
||||
settings.enable_pinecone,
|
||||
settings.pinecone_api_key,
|
||||
settings.pinecone_agent_index,
|
||||
settings.pinecone_source_index,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def upsert_pinecone_indices():
|
||||
from pinecone import IndexEmbed, PineconeAsyncio
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
for index_name in get_pinecone_indices():
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
@@ -49,6 +64,9 @@ def get_pinecone_indices() -> List[str]:
|
||||
|
||||
|
||||
async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, chunks: List[str], actor: User):
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
records = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
record = {
|
||||
@@ -63,7 +81,8 @@ async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, ch
|
||||
|
||||
|
||||
async def delete_file_records_from_pinecone_index(file_id: str, actor: User):
|
||||
from pinecone.exceptions.exceptions import NotFoundException
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
namespace = actor.organization_id
|
||||
try:
|
||||
@@ -81,7 +100,8 @@ async def delete_file_records_from_pinecone_index(file_id: str, actor: User):
|
||||
|
||||
|
||||
async def delete_source_records_from_pinecone_index(source_id: str, actor: User):
|
||||
from pinecone.exceptions.exceptions import NotFoundException
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
namespace = actor.organization_id
|
||||
try:
|
||||
@@ -94,6 +114,9 @@ async def delete_source_records_from_pinecone_index(source_id: str, actor: User)
|
||||
|
||||
|
||||
async def upsert_records_to_pinecone_index(records: List[dict], actor: User):
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||
@@ -104,6 +127,9 @@ async def upsert_records_to_pinecone_index(records: List[dict], actor: User):
|
||||
|
||||
|
||||
async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any], actor: User) -> Dict[str, Any]:
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||
@@ -127,7 +153,8 @@ async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any],
|
||||
|
||||
|
||||
async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = None, pagination_token: str = None) -> List[str]:
|
||||
from pinecone.exceptions.exceptions import NotFoundException
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
namespace = actor.organization_id
|
||||
try:
|
||||
|
||||
@@ -29,6 +29,7 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
|
||||
if _is_scheduler_leader:
|
||||
return True # Already leading
|
||||
|
||||
engine_name = None
|
||||
lock_session = None
|
||||
acquired_lock = False
|
||||
try:
|
||||
@@ -36,32 +37,25 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
|
||||
engine = session.get_bind()
|
||||
engine_name = engine.name
|
||||
logger.info(f"Database engine type: {engine_name}")
|
||||
if engine_name != "postgresql":
|
||||
logger.warning(f"Advisory locks not supported for {engine_name} database. Starting scheduler without leader election.")
|
||||
acquired_lock = True
|
||||
else:
|
||||
lock_session = db_registry.get_async_session_factory()()
|
||||
result = await lock_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(CAST(:lock_key AS bigint))"), {"lock_key": ADVISORY_LOCK_KEY}
|
||||
)
|
||||
acquired_lock = result.scalar()
|
||||
await lock_session.commit()
|
||||
|
||||
if not acquired_lock:
|
||||
if lock_session:
|
||||
await lock_session.close()
|
||||
logger.info("Scheduler lock held by another instance.")
|
||||
return False
|
||||
|
||||
if engine_name == "postgresql":
|
||||
logger.info("Acquired PostgreSQL advisory lock.")
|
||||
_advisory_lock_session = lock_session
|
||||
lock_session = None
|
||||
if engine_name != "postgresql":
|
||||
logger.warning(f"Advisory locks not supported for {engine_name} database. Starting scheduler without leader election.")
|
||||
acquired_lock = True
|
||||
else:
|
||||
logger.info("Starting scheduler for non-PostgreSQL database.")
|
||||
if lock_session:
|
||||
lock_session = db_registry.get_async_session_factory()()
|
||||
result = await lock_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(CAST(:lock_key AS bigint))"), {"lock_key": ADVISORY_LOCK_KEY}
|
||||
)
|
||||
acquired_lock = result.scalar()
|
||||
await lock_session.commit()
|
||||
|
||||
if not acquired_lock:
|
||||
await lock_session.close()
|
||||
lock_session = None
|
||||
logger.info("Scheduler lock held by another instance.")
|
||||
return False
|
||||
else:
|
||||
_advisory_lock_session = lock_session
|
||||
lock_session = None
|
||||
|
||||
trigger = IntervalTrigger(
|
||||
seconds=settings.poll_running_llm_batches_interval_seconds,
|
||||
@@ -90,7 +84,6 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
|
||||
if acquired_lock:
|
||||
logger.warning("Attempting to release lock due to error during startup.")
|
||||
try:
|
||||
_advisory_lock_session = lock_session
|
||||
await _release_advisory_lock(lock_session)
|
||||
except Exception as unlock_err:
|
||||
logger.error(f"Failed to release lock during error handling: {unlock_err}", exc_info=True)
|
||||
@@ -108,8 +101,8 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
|
||||
if lock_session:
|
||||
try:
|
||||
await lock_session.close()
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to close session during error handling: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def _background_lock_retry_loop(server: SyncServer):
|
||||
@@ -138,15 +131,13 @@ async def _background_lock_retry_loop(server: SyncServer):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in background lock retry loop: {e}", exc_info=True)
|
||||
await asyncio.sleep(settings.poll_lock_retry_interval_seconds)
|
||||
|
||||
|
||||
async def _release_advisory_lock(lock_session=None):
|
||||
async def _release_advisory_lock(target_lock_session=None):
|
||||
"""Releases the advisory lock using the stored session."""
|
||||
global _advisory_lock_session
|
||||
|
||||
lock_session = _advisory_lock_session or lock_session
|
||||
_advisory_lock_session = None
|
||||
lock_session = target_lock_session or _advisory_lock_session
|
||||
|
||||
if lock_session is not None:
|
||||
logger.info(f"Attempting to release PostgreSQL advisory lock {ADVISORY_LOCK_KEY}")
|
||||
@@ -161,6 +152,8 @@ async def _release_advisory_lock(lock_session=None):
|
||||
if lock_session:
|
||||
await lock_session.close()
|
||||
logger.info("Closed database session that held advisory lock.")
|
||||
if lock_session == _advisory_lock_session:
|
||||
_advisory_lock_session = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing advisory lock session: {e}", exc_info=True)
|
||||
else:
|
||||
|
||||
@@ -58,7 +58,12 @@ class MetricRegistry:
|
||||
def tool_execution_counter(self) -> Counter:
|
||||
return self._get_or_create_metric(
|
||||
"count_tool_execution",
|
||||
partial(self._meter.create_counter, name="count_tool_execution", description="Counts the number of tools executed.", unit="1"),
|
||||
partial(
|
||||
self._meter.create_counter,
|
||||
name="count_tool_execution",
|
||||
description="Counts the number of tools executed.",
|
||||
unit="1",
|
||||
),
|
||||
)
|
||||
|
||||
# project_id + model
|
||||
@@ -66,7 +71,12 @@ class MetricRegistry:
|
||||
def ttft_ms_histogram(self) -> Histogram:
|
||||
return self._get_or_create_metric(
|
||||
"hist_ttft_ms",
|
||||
partial(self._meter.create_histogram, name="hist_ttft_ms", description="Histogram for the Time to First Token (ms)", unit="ms"),
|
||||
partial(
|
||||
self._meter.create_histogram,
|
||||
name="hist_ttft_ms",
|
||||
description="Histogram for the Time to First Token (ms)",
|
||||
unit="ms",
|
||||
),
|
||||
)
|
||||
|
||||
# (includes model name)
|
||||
@@ -158,3 +168,15 @@ class MetricRegistry:
|
||||
unit="1",
|
||||
),
|
||||
)
|
||||
|
||||
@property
def file_process_bytes_histogram(self) -> Histogram:
    """Histogram metric for file processing sizes, recorded with unit ``By``.

    The metric is obtained through ``self._get_or_create_metric`` under the
    key ``hist_file_process_bytes``. The factory handed over is a ``partial``
    of ``self._meter.create_histogram`` using that same name.
    """
    metric_name = "hist_file_process_bytes"
    histogram_factory = partial(
        self._meter.create_histogram,
        name=metric_name,
        description="Histogram for file process in bytes",
        unit="By",
    )
    return self._get_or_create_metric(metric_name, histogram_factory)
|
||||
|
||||
@@ -139,3 +139,11 @@ class MCPServerType(str, Enum):
|
||||
SSE = "sse"
|
||||
STDIO = "stdio"
|
||||
STREAMABLE_HTTP = "streamable_http"
|
||||
|
||||
|
||||
class DuplicateFileHandling(str, Enum):
    """How to handle duplicate filenames when uploading files"""

    # Leave the upload out when a file with the same name already exists.
    SKIP = "skip"
    # Reject the upload with an error when the name is already taken.
    ERROR = "error"
    # Make the name unique by appending a numeric suffix (default behavior).
    SUFFIX = "suffix"
|
||||
|
||||
@@ -77,9 +77,8 @@ class Tool(BaseTool):
|
||||
|
||||
if self.tool_type is ToolType.CUSTOM:
|
||||
if not self.source_code:
|
||||
error_msg = f"Custom tool with id={self.id} is missing source_code field."
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
logger.error("Custom tool with id=%s is missing source_code field", self.id)
|
||||
raise ValueError(f"Custom tool with id={self.id} is missing source_code field.")
|
||||
|
||||
# Always derive json_schema for freshest possible json_schema
|
||||
if self.args_json_schema is not None:
|
||||
@@ -96,8 +95,7 @@ class Tool(BaseTool):
|
||||
try:
|
||||
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to derive json schema for tool with id={self.id} name={self.name}. Error: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.error("Failed to derive json schema for tool with id=%s name=%s: %s", self.id, self.name, e)
|
||||
elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE, ToolType.LETTA_SLEEPTIME_CORE}:
|
||||
# If it's letta core tool, we generate the json_schema on the fly here
|
||||
self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name)
|
||||
@@ -119,9 +117,8 @@ class Tool(BaseTool):
|
||||
|
||||
# At this point, we need to validate that at least json_schema is populated
|
||||
if not self.json_schema:
|
||||
error_msg = f"Tool with id={self.id} name={self.name} tool_type={self.tool_type} is missing a json_schema."
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
logger.error("Tool with id=%s name=%s tool_type=%s is missing a json_schema", self.id, self.name, self.tool_type)
|
||||
raise ValueError(f"Tool with id={self.id} name={self.name} tool_type={self.tool_type} is missing a json_schema.")
|
||||
|
||||
# Derive name from the JSON schema if not provided
|
||||
if not self.name:
|
||||
|
||||
@@ -337,8 +337,11 @@ def create_application() -> "FastAPI":
|
||||
# / static files
|
||||
mount_static_files(app)
|
||||
|
||||
no_generation = "--no-generation" in sys.argv
|
||||
|
||||
# Generate OpenAPI schema after all routes are mounted
|
||||
generate_openapi_schema(app)
|
||||
if not no_generation:
|
||||
generate_openapi_schema(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, List, Optional
|
||||
from typing import Annotated, Any, Dict, List, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, File, Header, HTTPException, Query, Request, UploadFile, status
|
||||
from fastapi.responses import JSONResponse
|
||||
@@ -522,7 +522,7 @@ async def attach_block(
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Attach a core memoryblock to an agent.
|
||||
Attach a core memory block to an agent.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.agent_manager.attach_block_async(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
@@ -1160,6 +1160,69 @@ async def list_agent_groups(
|
||||
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)
|
||||
|
||||
|
||||
@router.post(
    "/{agent_id}/messages/preview-raw-payload",
    response_model=Dict[str, Any],
    operation_id="preview_raw_payload",
)
async def preview_raw_payload(
    agent_id: str,
    request: Union[LettaRequest, LettaStreamingRequest] = Body(...),
    server: SyncServer = Depends(get_letta_server),
    actor_id: str | None = Header(None, alias="user_id"),
):
    """
    Inspect the raw LLM request payload without sending it.

    This endpoint processes the message through the agent loop up until
    the LLM request, then returns the raw request payload that would
    be sent to the LLM provider. Useful for debugging and inspection.
    """
    # NOTE: the docstring above is left verbatim since FastAPI publishes it as
    # the operation description.
    actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
    agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])

    # Eligible: agents outside any multi-agent group, or in a group managed by
    # one of the sleeptime manager types.
    group = agent.multi_agent_group
    agent_eligible = group is None or group.manager_type in ["sleeptime", "voice_sleeptime"]
    model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]

    # Guard 1: unsupported agent/model combination.
    if not (agent_eligible and model_compatible):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Payload inspection is not currently supported for this agent configuration.",
        )

    # Guard 2: sleeptime-enabled agents are rejected.
    if agent.enable_sleeptime:
        # TODO: @caren need to support this for sleeptime
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Payload inspection is not supported for agents with sleeptime enabled.",
        )

    if agent.agent_type == AgentType.voice_convo_agent:
        summarizer_mode = SummarizationMode.STATIC_MESSAGE_BUFFER
    else:
        summarizer_mode = SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER

    telemetry_manager = server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager()

    agent_loop = LettaAgent(
        agent_id=agent_id,
        message_manager=server.message_manager,
        agent_manager=server.agent_manager,
        block_manager=server.block_manager,
        job_manager=server.job_manager,
        passage_manager=server.passage_manager,
        actor=actor,
        step_manager=server.step_manager,
        telemetry_manager=telemetry_manager,
        summarizer_mode=summarizer_mode,
    )

    # TODO: Support step_streaming
    # dry_run=True makes step() return the request payload directly rather
    # than a LettaResponse.
    return await agent_loop.step(
        input_messages=request.messages,
        use_assistant_message=request.use_assistant_message,
        include_return_message_types=request.include_return_message_types,
        dry_run=True,
    )
|
||||
|
||||
|
||||
@router.post("/{agent_id}/summarize", response_model=AgentState, operation_id="summarize_agent_conversation")
|
||||
async def summarize_agent_conversation(
|
||||
agent_id: str,
|
||||
|
||||
@@ -19,7 +19,7 @@ from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import FileProcessingStatus
|
||||
from letta.schemas.enums import DuplicateFileHandling, FileProcessingStatus
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
@@ -208,6 +208,7 @@ async def delete_source(
|
||||
async def upload_file_to_source(
|
||||
file: UploadFile,
|
||||
source_id: str,
|
||||
duplicate_handling: DuplicateFileHandling = Query(DuplicateFileHandling.SUFFIX, description="How to handle duplicate filenames"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
@@ -264,8 +265,31 @@ async def upload_file_to_source(
|
||||
|
||||
content = await file.read()
|
||||
|
||||
# Store original filename and generate unique filename
|
||||
# Store original filename and handle duplicate logic
|
||||
original_filename = sanitize_filename(file.filename) # Basic sanitization only
|
||||
|
||||
# Check if duplicate exists
|
||||
existing_file = await server.file_manager.get_file_by_original_name_and_source(
|
||||
original_filename=original_filename, source_id=source_id, actor=actor
|
||||
)
|
||||
|
||||
if existing_file:
|
||||
# Duplicate found, handle based on strategy
|
||||
if duplicate_handling == DuplicateFileHandling.ERROR:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT, detail=f"File '{original_filename}' already exists in source '{source.name}'"
|
||||
)
|
||||
elif duplicate_handling == DuplicateFileHandling.SKIP:
|
||||
# Return existing file metadata with custom header to indicate it was skipped
|
||||
from fastapi import Response
|
||||
|
||||
response = Response(
|
||||
content=existing_file.model_dump_json(), media_type="application/json", headers={"X-Upload-Result": "skipped"}
|
||||
)
|
||||
return response
|
||||
# For SUFFIX, continue to generate unique filename
|
||||
|
||||
# Generate unique filename (adds suffix if needed)
|
||||
unique_filename = await server.file_manager.generate_unique_filename(
|
||||
original_filename=original_filename, source=source, organization_id=actor.organization_id
|
||||
)
|
||||
@@ -360,6 +384,13 @@ async def get_file_metadata(
|
||||
file_id=file_id, actor=actor, include_content=include_content, strip_directory_prefix=True
|
||||
)
|
||||
|
||||
if not file_metadata:
|
||||
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
|
||||
|
||||
# Verify the file belongs to the specified source
|
||||
if file_metadata.source_id != source_id:
|
||||
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found in source {source_id}.")
|
||||
|
||||
if should_use_pinecone() and not file_metadata.is_processing_terminal():
|
||||
ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor, limit=file_metadata.total_chunks)
|
||||
logger.info(
|
||||
@@ -375,13 +406,6 @@ async def get_file_metadata(
|
||||
file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status
|
||||
)
|
||||
|
||||
if not file_metadata:
|
||||
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
|
||||
|
||||
# Verify the file belongs to the specified source
|
||||
if file_metadata.source_id != source_id:
|
||||
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found in source {source_id}.")
|
||||
|
||||
return file_metadata
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, func, insert, literal, or_, select
|
||||
from sqlalchemy import delete, func, insert, literal, or_, select, tuple_
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from letta.constants import (
|
||||
@@ -224,13 +224,44 @@ class AgentManager:
|
||||
@staticmethod
|
||||
async def _replace_pivot_rows_async(session, table, agent_id: str, rows: list[dict]):
|
||||
"""
|
||||
Replace all pivot rows for an agent with *exactly* the provided list.
|
||||
Uses two bulk statements (DELETE + INSERT ... ON CONFLICT DO NOTHING).
|
||||
Replace all pivot rows for an agent atomically using MERGE pattern.
|
||||
"""
|
||||
# delete all existing rows for this agent
|
||||
await session.execute(delete(table).where(table.c.agent_id == agent_id))
|
||||
if rows:
|
||||
await AgentManager._bulk_insert_pivot_async(session, table, rows)
|
||||
dialect = session.bind.dialect.name
|
||||
|
||||
if dialect == "postgresql":
|
||||
if rows:
|
||||
# separate upsert and delete operations
|
||||
stmt = pg_insert(table).values(rows)
|
||||
stmt = stmt.on_conflict_do_nothing()
|
||||
await session.execute(stmt)
|
||||
|
||||
# delete rows not in new set
|
||||
pk_names = [c.name for c in table.primary_key.columns]
|
||||
new_keys = [tuple(r[c] for c in pk_names) for r in rows]
|
||||
await session.execute(
|
||||
delete(table).where(table.c.agent_id == agent_id, ~tuple_(*[table.c[c] for c in pk_names]).in_(new_keys))
|
||||
)
|
||||
else:
|
||||
# if no rows to insert, just delete all
|
||||
await session.execute(delete(table).where(table.c.agent_id == agent_id))
|
||||
|
||||
elif dialect == "sqlite":
|
||||
if rows:
|
||||
stmt = sa.insert(table).values(rows).prefix_with("OR REPLACE")
|
||||
await session.execute(stmt)
|
||||
|
||||
if rows:
|
||||
primary_key_cols = [table.c[c.name] for c in table.primary_key.columns]
|
||||
new_keys = [tuple(r[c.name] for c in table.primary_key.columns) for r in rows]
|
||||
await session.execute(delete(table).where(table.c.agent_id == agent_id, ~tuple_(*primary_key_cols).in_(new_keys)))
|
||||
else:
|
||||
await session.execute(delete(table).where(table.c.agent_id == agent_id))
|
||||
|
||||
else:
|
||||
# fallback: use original DELETE + INSERT pattern
|
||||
await session.execute(delete(table).where(table.c.agent_id == agent_id))
|
||||
if rows:
|
||||
await AgentManager._bulk_insert_pivot_async(session, table, rows)
|
||||
|
||||
# ======================================================================================================================
|
||||
# Basic CRUD operations
|
||||
|
||||
@@ -22,6 +22,15 @@ from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class DuplicateFileError(Exception):
|
||||
"""Raised when a duplicate file is encountered and error handling is specified"""
|
||||
|
||||
def __init__(self, filename: str, source_name: str):
|
||||
self.filename = filename
|
||||
self.source_name = source_name
|
||||
super().__init__(f"File '{filename}' already exists in source '{source_name}'")
|
||||
|
||||
|
||||
class FileManager:
|
||||
"""Manager class to handle business logic related to files."""
|
||||
|
||||
@@ -237,16 +246,16 @@ class FileManager:
|
||||
@trace_method
|
||||
async def generate_unique_filename(self, original_filename: str, source: PydanticSource, organization_id: str) -> str:
|
||||
"""
|
||||
Generate a unique filename by checking for duplicates and adding a numeric suffix if needed.
|
||||
Similar to how filesystems handle duplicates (e.g., file.txt, file (1).txt, file (2).txt).
|
||||
Generate a unique filename by adding a numeric suffix if duplicates exist.
|
||||
Always returns a unique filename - does not handle duplicate policies.
|
||||
|
||||
Parameters:
|
||||
original_filename (str): The original filename as uploaded.
|
||||
source_id (str): Source ID to check for duplicates within.
|
||||
source (PydanticSource): Source to check for duplicates within.
|
||||
organization_id (str): Organization ID to check for duplicates within.
|
||||
|
||||
Returns:
|
||||
str: A unique filename with numeric suffix if needed.
|
||||
str: A unique filename with source.name prefix and numeric suffix if needed.
|
||||
"""
|
||||
base, ext = os.path.splitext(original_filename)
|
||||
|
||||
@@ -271,9 +280,44 @@ class FileManager:
|
||||
# No duplicates, return original filename with source.name
|
||||
return f"{source.name}/{original_filename}"
|
||||
else:
|
||||
# Add numeric suffix
|
||||
# Add numeric suffix to make unique
|
||||
return f"{source.name}/{base}_({count}){ext}"
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_file_by_original_name_and_source(
|
||||
self, original_filename: str, source_id: str, actor: PydanticUser
|
||||
) -> Optional[PydanticFileMetadata]:
|
||||
"""
|
||||
Get a file by its original filename and source ID.
|
||||
|
||||
Parameters:
|
||||
original_filename (str): The original filename to search for.
|
||||
source_id (str): The source ID to search within.
|
||||
actor (PydanticUser): The actor performing the request.
|
||||
|
||||
Returns:
|
||||
Optional[PydanticFileMetadata]: The file metadata if found, None otherwise.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
query = (
|
||||
select(FileMetadataModel)
|
||||
.where(
|
||||
FileMetadataModel.original_file_name == original_filename,
|
||||
FileMetadataModel.source_id == source_id,
|
||||
FileMetadataModel.organization_id == actor.organization_id,
|
||||
FileMetadataModel.is_deleted == False,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
file_orm = result.scalar_one_or_none()
|
||||
|
||||
if file_orm:
|
||||
return await file_orm.to_pydantic_async()
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_organization_sources_metadata(self, actor: PydanticUser) -> OrganizationSourcesStats:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import List
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.otel.context import get_ctx_attributes
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import FileProcessingStatus
|
||||
@@ -122,6 +123,10 @@ class FileProcessor:
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
|
||||
from letta.otel.metric_registry import MetricRegistry
|
||||
|
||||
MetricRegistry().file_process_bytes_histogram.record(len(content), attributes=get_ctx_attributes())
|
||||
|
||||
if len(content) > self.max_file_size:
|
||||
log_event(
|
||||
"file_processor.size_limit_exceeded",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ast
|
||||
import base64
|
||||
import pickle
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM, SEND_MESSAGE_TOOL_NAME
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -9,7 +9,7 @@ from letta.schemas.response_format import ResponseFormatType, ResponseFormatUnio
|
||||
from letta.types import JsonDict, JsonValue
|
||||
|
||||
|
||||
def parse_stdout_best_effort(text: str | bytes) -> tuple[Any, AgentState | None]:
|
||||
def parse_stdout_best_effort(text: Union[str, bytes]) -> tuple[Any, AgentState | None]:
|
||||
"""
|
||||
Decode and unpickle the result from the function execution if possible.
|
||||
Returns (function_return_value, agent_state).
|
||||
|
||||
@@ -2,6 +2,7 @@ from functools import partial, reduce
|
||||
from operator import add
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from httpx import AsyncClient, post
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -95,6 +96,8 @@ class JobManager:
|
||||
@trace_method
|
||||
async def update_job_by_id_async(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob:
|
||||
"""Update a job by its ID with the given JobUpdate object asynchronously."""
|
||||
callback_func = None
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Fetch the job by ID
|
||||
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
|
||||
@@ -114,11 +117,23 @@ class JobManager:
|
||||
logger.info(f"Current job completed at: {job.completed_at}")
|
||||
job.completed_at = get_utc_time().replace(tzinfo=None)
|
||||
if job.callback_url:
|
||||
await self._dispatch_callback_async(job)
|
||||
callback_func = self._dispatch_callback_async(
|
||||
callback_url=job.callback_url,
|
||||
payload={
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
|
||||
"metadata": job.metadata_,
|
||||
},
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Save the updated job to the database
|
||||
await job.update_async(db_session=session, actor=actor)
|
||||
|
||||
if callback_func:
|
||||
return await callback_func
|
||||
|
||||
return job.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -683,10 +698,8 @@ class JobManager:
|
||||
"metadata": job.metadata_,
|
||||
}
|
||||
try:
|
||||
import httpx
|
||||
|
||||
log_event("POST callback dispatched", payload)
|
||||
resp = httpx.post(job.callback_url, json=payload, timeout=5.0)
|
||||
resp = post(job.callback_url, json=payload, timeout=5.0)
|
||||
log_event("POST callback finished")
|
||||
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
|
||||
job.callback_status_code = resp.status_code
|
||||
@@ -700,31 +713,33 @@ class JobManager:
|
||||
# Continue silently - callback failures should not affect job completion
|
||||
|
||||
@trace_method
|
||||
async def _dispatch_callback_async(self, job: JobModel) -> None:
|
||||
async def _dispatch_callback_async(self, callback_url: str, payload: dict, actor: PydanticUser) -> PydanticJob:
|
||||
"""
|
||||
POST a standard JSON payload to job.callback_url and record timestamp + HTTP status asynchronously.
|
||||
"""
|
||||
payload = {
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
|
||||
"metadata": job.metadata_,
|
||||
}
|
||||
job_id = payload["job_id"]
|
||||
callback_sent_at, callback_status_code, callback_error = None, None, None
|
||||
|
||||
try:
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with AsyncClient() as client:
|
||||
log_event("POST callback dispatched", payload)
|
||||
resp = await client.post(job.callback_url, json=payload, timeout=5.0)
|
||||
resp = await client.post(callback_url, json=payload, timeout=5.0)
|
||||
log_event("POST callback finished")
|
||||
# Ensure timestamp is timezone-naive for DB compatibility
|
||||
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
|
||||
job.callback_status_code = resp.status_code
|
||||
callback_sent_at = get_utc_time().replace(tzinfo=None)
|
||||
callback_status_code = resp.status_code
|
||||
except Exception as e:
|
||||
error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {e!s}"
|
||||
error_message = f"Failed to dispatch callback for job {job_id} to {callback_url}: {e!s}"
|
||||
logger.error(error_message)
|
||||
# Record the failed attempt
|
||||
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
|
||||
job.callback_error = error_message
|
||||
callback_sent_at = get_utc_time().replace(tzinfo=None)
|
||||
callback_error = error_message
|
||||
# Continue silently - callback failures should not affect job completion
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
job = await JobModel.read_async(db_session=session, identifier=job_id, actor=actor, access_type=AccessType.USER)
|
||||
job.callback_sent_at = callback_sent_at
|
||||
job.callback_status_code = callback_status_code
|
||||
job.callback_error = callback_error
|
||||
await job.update_async(db_session=session, actor=actor)
|
||||
return job.to_pydantic()
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic.config import JsonDict
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
|
||||
@@ -23,6 +24,8 @@ from letta.services.tool_sandbox.base import AsyncToolSandboxBase
|
||||
from letta.settings import tool_settings
|
||||
from letta.utils import get_friendly_error_msg, parse_stderr_error_msg
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
METADATA_CONFIG_STATE_KEY = "config_state"
|
||||
@@ -240,9 +243,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
if isinstance(e, TimeoutError):
|
||||
raise e
|
||||
|
||||
print(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}")
|
||||
print(e.__class__.__name__)
|
||||
print(e.__traceback__)
|
||||
logger.error(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}")
|
||||
logger.error(e.__class__.__name__)
|
||||
logger.error(e.__traceback__)
|
||||
func_return = get_friendly_error_msg(
|
||||
function_name=self.tool_name,
|
||||
exception_name=type(e).__name__,
|
||||
|
||||
@@ -24,8 +24,32 @@ agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickl
|
||||
{{ tool_source_code }}
|
||||
|
||||
{# Invoke the function and store the result in a global variable #}
|
||||
_function_result = {{ invoke_function_call }}
|
||||
|
||||
{# Use a temporary Pydantic wrapper to recursively serialize any nested Pydantic objects #}
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
|
||||
class _TempResultWrapper(BaseModel):
|
||||
result: Any
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
_wrapped = _TempResultWrapper(result=_function_result)
|
||||
_serialized_result = _wrapped.model_dump()['result']
|
||||
except ImportError:
|
||||
# Pydantic not available in sandbox, fall back to string conversion
|
||||
print("Pydantic not available in sandbox environment, falling back to string conversion")
|
||||
_serialized_result = str(_function_result)
|
||||
except Exception as e:
|
||||
# If wrapping fails, print the error and stringify the result
|
||||
print(f"Failed to serialize result with Pydantic wrapper: {e}")
|
||||
_serialized_result = str(_function_result)
|
||||
|
||||
{{ local_sandbox_result_var_name }} = {
|
||||
"results": {{ invoke_function_call }},
|
||||
"results": _serialized_result,
|
||||
"agent_state": agent_state
|
||||
}
|
||||
|
||||
|
||||
@@ -26,9 +26,32 @@ agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickl
|
||||
|
||||
{# Async wrapper to handle the function call and store the result #}
|
||||
async def _async_wrapper():
|
||||
result = await {{ invoke_function_call }}
|
||||
_function_result = await {{ invoke_function_call }}
|
||||
|
||||
{# Use a temporary Pydantic wrapper to recursively serialize any nested Pydantic objects #}
|
||||
try:
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
|
||||
class _TempResultWrapper(BaseModel):
|
||||
result: Any
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
_wrapped = _TempResultWrapper(result=_function_result)
|
||||
_serialized_result = _wrapped.model_dump()['result']
|
||||
except ImportError:
|
||||
# Pydantic not available in sandbox, fall back to string conversion
|
||||
print("Pydantic not available in sandbox environment, falling back to string conversion")
|
||||
_serialized_result = str(_function_result)
|
||||
except Exception as e:
|
||||
# If wrapping fails, print the error and stringify the result
|
||||
print(f"Failed to serialize result with Pydantic wrapper: {e}")
|
||||
_serialized_result = str(_function_result)
|
||||
|
||||
return {
|
||||
"results": result,
|
||||
"results": _serialized_result,
|
||||
"agent_state": agent_state
|
||||
}
|
||||
|
||||
|
||||
41
poetry.lock
generated
41
poetry.lock
generated
@@ -186,9 +186,10 @@ speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>
|
||||
name = "aiohttp-retry"
|
||||
version = "2.9.1"
|
||||
description = "Simple retry client for aiohttp"
|
||||
optional = false
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
markers = "extra == \"pinecone\" or extra == \"all\""
|
||||
files = [
|
||||
{file = "aiohttp_retry-2.9.1-py3-none-any.whl", hash = "sha256:66d2759d1921838256a05a3f80ad7e724936f083e35be5abb5e16eed6be6dc54"},
|
||||
{file = "aiohttp_retry-2.9.1.tar.gz", hash = "sha256:8eb75e904ed4ee5c2ec242fefe85bf04240f685391c4879d8f541d6028ff01f1"},
|
||||
@@ -944,7 +945,7 @@ version = "2025.6.15"
|
||||
description = "Python package for providing Mozilla's CA Bundle."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main", "dev"]
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "certifi-2025.6.15-py3-none-any.whl", hash = "sha256:2e0c7ce7cb5d8f8634ca55d2ba7e6ec2689a2fd6537d8dec1296a477a4910057"},
|
||||
{file = "certifi-2025.6.15.tar.gz", hash = "sha256:d747aa5a8b9bbbb1bb8c22bb13e22bd1f18e9796defa16bab421f7f7a317323b"},
|
||||
@@ -1050,7 +1051,7 @@ version = "3.4.2"
|
||||
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main", "dev"]
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "charset_normalizer-3.4.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7c48ed483eb946e6c04ccbe02c6b4d1d48e51944b6db70f697e089c193404941"},
|
||||
{file = "charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2d318c11350e10662026ad0eb71bb51c7812fc8590825304ae0bdd4ac283acd"},
|
||||
@@ -2792,7 +2793,7 @@ version = "3.10"
|
||||
description = "Internationalized Domain Names in Applications (IDNA)"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
groups = ["main", "dev"]
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"},
|
||||
{file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"},
|
||||
@@ -3590,14 +3591,14 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "letta-client"
|
||||
version = "0.1.191"
|
||||
version = "0.1.198"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "letta_client-0.1.191-py3-none-any.whl", hash = "sha256:2cc234668784b022a25aeab4db48b944a0a188e42112870efd8b028ad223347b"},
|
||||
{file = "letta_client-0.1.191.tar.gz", hash = "sha256:95957695e679183ec0d87673c8dd169a83f9807359c4740d42ed84fbb6b05efc"},
|
||||
{file = "letta_client-0.1.198-py3-none-any.whl", hash = "sha256:08bbc238b128da2552b2a6e54feb3294794b5586e0962ce0bb95bb525109f58f"},
|
||||
{file = "letta_client-0.1.198.tar.gz", hash = "sha256:990c9132423e2955d9c7f7549e5064b2366616232d270e5927788cddba4ef9da"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -5207,9 +5208,10 @@ xmp = ["defusedxml"]
|
||||
name = "pinecone"
|
||||
version = "7.3.0"
|
||||
description = "Pinecone client and SDK"
|
||||
optional = false
|
||||
optional = true
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main", "dev"]
|
||||
groups = ["main"]
|
||||
markers = "extra == \"pinecone\" or extra == \"all\""
|
||||
files = [
|
||||
{file = "pinecone-7.3.0-py3-none-any.whl", hash = "sha256:315b8fef20320bef723ecbb695dec0aafa75d8434d86e01e5a0e85933e1009a8"},
|
||||
{file = "pinecone-7.3.0.tar.gz", hash = "sha256:307edc155621d487c20dc71b76c3ad5d6f799569ba42064190d03917954f9a7b"},
|
||||
@@ -5236,9 +5238,10 @@ grpc = ["googleapis-common-protos (>=1.66.0)", "grpcio (>=1.44.0) ; python_versi
|
||||
name = "pinecone-plugin-assistant"
|
||||
version = "1.7.0"
|
||||
description = "Assistant plugin for Pinecone SDK"
|
||||
optional = false
|
||||
optional = true
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main", "dev"]
|
||||
groups = ["main"]
|
||||
markers = "extra == \"pinecone\" or extra == \"all\""
|
||||
files = [
|
||||
{file = "pinecone_plugin_assistant-1.7.0-py3-none-any.whl", hash = "sha256:864cb8e7930588e6c2da97c6d44f0240969195f43fa303c5db76cbc12bf903a5"},
|
||||
{file = "pinecone_plugin_assistant-1.7.0.tar.gz", hash = "sha256:e26e3ba10a8b71c3da0d777cff407668022e82963c4913d0ffeb6c552721e482"},
|
||||
@@ -5252,9 +5255,10 @@ requests = ">=2.32.3,<3.0.0"
|
||||
name = "pinecone-plugin-interface"
|
||||
version = "0.0.7"
|
||||
description = "Plugin interface for the Pinecone python client"
|
||||
optional = false
|
||||
optional = true
|
||||
python-versions = "<4.0,>=3.8"
|
||||
groups = ["main", "dev"]
|
||||
groups = ["main"]
|
||||
markers = "extra == \"pinecone\" or extra == \"all\""
|
||||
files = [
|
||||
{file = "pinecone_plugin_interface-0.0.7-py3-none-any.whl", hash = "sha256:875857ad9c9fc8bbc074dbe780d187a2afd21f5bfe0f3b08601924a61ef1bba8"},
|
||||
{file = "pinecone_plugin_interface-0.0.7.tar.gz", hash = "sha256:b8e6675e41847333aa13923cc44daa3f85676d7157324682dc1640588a982846"},
|
||||
@@ -6578,7 +6582,7 @@ version = "2.32.4"
|
||||
description = "Python HTTP for Humans."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main", "dev"]
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c"},
|
||||
{file = "requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422"},
|
||||
@@ -7473,7 +7477,7 @@ files = [
|
||||
{file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"},
|
||||
{file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"},
|
||||
]
|
||||
markers = {"dev,tests" = "python_version == \"3.10\""}
|
||||
markers = {dev = "python_version < \"3.12\"", "dev,tests" = "python_version == \"3.10\""}
|
||||
|
||||
[[package]]
|
||||
name = "typing-inspect"
|
||||
@@ -7542,7 +7546,7 @@ version = "2.5.0"
|
||||
description = "HTTP library with thread-safe connection pooling, file post, and more."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main", "dev"]
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc"},
|
||||
{file = "urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760"},
|
||||
@@ -8348,7 +8352,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
|
||||
cffi = ["cffi (>=1.11)"]
|
||||
|
||||
[extras]
|
||||
all = ["autoflake", "black", "docker", "fastapi", "granian", "isort", "langchain", "langchain-community", "locust", "pexpect", "pg8000", "pgvector", "pre-commit", "psycopg2", "psycopg2-binary", "pyright", "pytest-asyncio", "pytest-order", "redis", "uvicorn", "uvloop", "wikipedia"]
|
||||
all = ["autoflake", "black", "docker", "fastapi", "granian", "isort", "langchain", "langchain-community", "locust", "pexpect", "pg8000", "pgvector", "pinecone", "pre-commit", "psycopg2", "psycopg2-binary", "pyright", "pytest-asyncio", "pytest-order", "redis", "uvicorn", "uvloop", "wikipedia"]
|
||||
bedrock = ["aioboto3", "boto3"]
|
||||
cloud-tool-sandbox = ["e2b-code-interpreter"]
|
||||
desktop = ["docker", "fastapi", "langchain", "langchain-community", "locust", "pg8000", "pgvector", "psycopg2", "psycopg2-binary", "pyright", "uvicorn", "wikipedia"]
|
||||
@@ -8356,6 +8360,7 @@ dev = ["autoflake", "black", "isort", "locust", "pexpect", "pre-commit", "pyrigh
|
||||
experimental = ["granian", "uvloop"]
|
||||
external-tools = ["docker", "firecrawl-py", "langchain", "langchain-community", "wikipedia"]
|
||||
google = ["google-genai"]
|
||||
pinecone = ["pinecone"]
|
||||
postgres = ["asyncpg", "pg8000", "pgvector", "psycopg2", "psycopg2-binary"]
|
||||
redis = ["redis"]
|
||||
server = ["fastapi", "uvicorn"]
|
||||
@@ -8364,4 +8369,4 @@ tests = ["wikipedia"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "<3.14,>=3.10"
|
||||
content-hash = "b2f23b566c52a2ecb9d656383d00f8ac64293690e0f6fcc0d354ed2d2ad8807b"
|
||||
content-hash = "c4fa225d582dac743e5eb8a1338a71c32bd8c8eb8283dec2331c28599bdf7698"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "letta"
|
||||
version = "0.8.12"
|
||||
version = "0.8.13"
|
||||
packages = [
|
||||
{include = "letta"},
|
||||
]
|
||||
@@ -72,7 +72,7 @@ llama-index = "^0.12.2"
|
||||
llama-index-embeddings-openai = "^0.3.1"
|
||||
e2b-code-interpreter = {version = "^1.0.3", optional = true}
|
||||
anthropic = "^0.49.0"
|
||||
letta_client = "^0.1.183"
|
||||
letta_client = "^0.1.197"
|
||||
openai = "^1.60.0"
|
||||
opentelemetry-api = "1.30.0"
|
||||
opentelemetry-sdk = "1.30.0"
|
||||
@@ -98,13 +98,13 @@ redis = {version = "^6.2.0", optional = true}
|
||||
structlog = "^25.4.0"
|
||||
certifi = "^2025.6.15"
|
||||
aioboto3 = {version = "^14.3.0", optional = true}
|
||||
pinecone = {extras = ["asyncio"], version = "^7.3.0"}
|
||||
aiosqlite = "^0.21.0"
|
||||
pinecone = {extras = ["asyncio"], version = "^7.3.0", optional = true}
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
postgres = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "asyncpg"]
|
||||
redis = ["redis"]
|
||||
pinecone = ["pinecone"]
|
||||
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "locust"]
|
||||
experimental = ["uvloop", "granian"]
|
||||
server = ["websockets", "fastapi", "uvicorn"]
|
||||
@@ -114,7 +114,7 @@ tests = ["wikipedia"]
|
||||
bedrock = ["boto3", "aioboto3"]
|
||||
google = ["google-genai"]
|
||||
desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
|
||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis"]
|
||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis", "pinecone"]
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
@@ -122,8 +122,6 @@ black = "^24.4.2"
|
||||
ipykernel = "^6.29.5"
|
||||
ipdb = "^0.13.13"
|
||||
pytest-mock = "^3.14.0"
|
||||
pinecone = "^7.3.0"
|
||||
|
||||
|
||||
[tool.poetry.group."dev,tests".dependencies]
|
||||
pytest-json-report = "^1.5.0"
|
||||
|
||||
@@ -28,6 +28,32 @@ def disable_e2b_api_key() -> Generator[None, None, None]:
|
||||
tool_settings.e2b_api_key = original_api_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def e2b_sandbox_mode(request) -> Generator[None, None, None]:
|
||||
"""
|
||||
Parametrizable fixture to enable/disable E2B sandbox mode.
|
||||
|
||||
Usage:
|
||||
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
|
||||
def test_function(e2b_sandbox_mode, ...):
|
||||
# Test runs twice - once with E2B enabled, once disabled
|
||||
"""
|
||||
from letta.settings import tool_settings
|
||||
|
||||
enable_e2b = request.param
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
|
||||
if not enable_e2b:
|
||||
# Disable E2B by setting API key to None
|
||||
tool_settings.e2b_api_key = None
|
||||
# If enable_e2b is True, leave the original API key unchanged
|
||||
|
||||
yield
|
||||
|
||||
# Restore original API key
|
||||
tool_settings.e2b_api_key = original_api_key
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_pinecone() -> Generator[None, None, None]:
|
||||
"""
|
||||
|
||||
@@ -752,13 +752,14 @@ def test_step_stream_agent_loop_error(
|
||||
"""
|
||||
last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1)
|
||||
agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config)
|
||||
with pytest.raises(ApiError):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state_no_tools.id,
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
list(response)
|
||||
|
||||
assert type(exc_info.value) in (ApiError, ValueError)
|
||||
print(exc_info.value)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id)
|
||||
assert len(messages_from_db) == 0
|
||||
|
||||
|
||||
@@ -8,8 +8,6 @@ import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List
|
||||
|
||||
import httpx
|
||||
|
||||
# tests/test_file_content_flow.py
|
||||
import pytest
|
||||
from _pytest.python_api import approx
|
||||
@@ -3902,7 +3900,7 @@ async def test_bulk_update_return_hydrated_true(server: SyncServer, default_user
|
||||
mgr = BlockManager()
|
||||
|
||||
# create a block
|
||||
b = mgr.create_or_update_block(
|
||||
b = await mgr.create_or_update_block_async(
|
||||
PydanticBlock(label="persona", value="foo", limit=20),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -5117,6 +5115,137 @@ async def test_delete_cascades_to_content(server, default_user, default_source,
|
||||
assert await _count_file_content_rows(async_session, created.id) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_file_by_original_name_and_source_found(server: SyncServer, default_user, default_source):
|
||||
"""Test retrieving a file by original filename and source when it exists."""
|
||||
original_filename = "test_original_file.txt"
|
||||
file_metadata = PydanticFileMetadata(
|
||||
file_name="some_generated_name.txt",
|
||||
original_file_name=original_filename,
|
||||
file_path="/path/to/test_file.txt",
|
||||
file_type="text/plain",
|
||||
file_size=1024,
|
||||
source_id=default_source.id,
|
||||
)
|
||||
created_file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user)
|
||||
|
||||
# Retrieve the file by original name and source
|
||||
retrieved_file = await server.file_manager.get_file_by_original_name_and_source(
|
||||
original_filename=original_filename, source_id=default_source.id, actor=default_user
|
||||
)
|
||||
|
||||
# Assertions to verify the retrieved file matches the created one
|
||||
assert retrieved_file is not None
|
||||
assert retrieved_file.id == created_file.id
|
||||
assert retrieved_file.original_file_name == original_filename
|
||||
assert retrieved_file.source_id == default_source.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_file_by_original_name_and_source_not_found(server: SyncServer, default_user, default_source):
|
||||
"""Test retrieving a file by original filename and source when it doesn't exist."""
|
||||
non_existent_filename = "does_not_exist.txt"
|
||||
|
||||
# Try to retrieve a non-existent file
|
||||
retrieved_file = await server.file_manager.get_file_by_original_name_and_source(
|
||||
original_filename=non_existent_filename, source_id=default_source.id, actor=default_user
|
||||
)
|
||||
|
||||
# Should return None for non-existent file
|
||||
assert retrieved_file is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_file_by_original_name_and_source_different_sources(server: SyncServer, default_user, default_source):
    """Test that files with same original name in different sources are handled correctly.

    Two files share one ``original_file_name`` but live in different sources;
    each lookup must return only the file belonging to the source it was given.
    """
    from letta.schemas.source import Source as PydanticSource

    original_filename = "shared_filename.txt"

    # A second source, next to the fixture-provided default one
    second_source = await server.source_manager.create_source(
        source=PydanticSource(
            name="second_test_source",
            description="This is a test source.",
            metadata={"type": "test"},
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        actor=default_user,
    )

    async def store_file(file_name, file_path, file_size, source_id):
        """Create a text file carrying the shared original name in the given source."""
        return await server.file_manager.create_file(
            file_metadata=PydanticFileMetadata(
                file_name=file_name,
                original_file_name=original_filename,
                file_path=file_path,
                file_type="text/plain",
                file_size=file_size,
                source_id=source_id,
            ),
            actor=default_user,
        )

    async def lookup(source_id):
        """Fetch the file with the shared original name from the given source."""
        return await server.file_manager.get_file_by_original_name_and_source(
            original_filename=original_filename,
            source_id=source_id,
            actor=default_user,
        )

    # Same original name, one file per source
    created_file_1 = await store_file("file_in_source_1.txt", "/path/to/file1.txt", 1024, default_source.id)
    created_file_2 = await store_file("file_in_source_2.txt", "/path/to/file2.txt", 2048, second_source.id)

    # One lookup per source
    retrieved_file_1 = await lookup(default_source.id)
    retrieved_file_2 = await lookup(second_source.id)

    # Each lookup resolves to the file of its own source, never the other one
    assert retrieved_file_1 is not None
    assert retrieved_file_2 is not None
    assert retrieved_file_1.id == created_file_1.id
    assert retrieved_file_2.id == created_file_2.id
    assert retrieved_file_1.id != retrieved_file_2.id
    assert retrieved_file_1.source_id == default_source.id
    assert retrieved_file_2.source_id == second_source.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_file_by_original_name_and_source_ignores_deleted(server: SyncServer, default_user, default_source):
    """Test that deleted files are ignored when searching by original name and source.

    The same lookup is issued before and after ``delete_file``: it must find
    the file the first time and return ``None`` the second time.
    """
    original_filename = "to_be_deleted.txt"

    async def lookup():
        """Search the default source for the file by its original name."""
        return await server.file_manager.get_file_by_original_name_and_source(
            original_filename=original_filename,
            source_id=default_source.id,
            actor=default_user,
        )

    created_file = await server.file_manager.create_file(
        file_metadata=PydanticFileMetadata(
            file_name="deletable_file.txt",
            original_file_name=original_filename,
            file_path="/path/to/deletable.txt",
            file_type="text/plain",
            file_size=512,
            source_id=default_source.id,
        ),
        actor=default_user,
    )

    # Sanity check: the file is discoverable while it still exists
    found_before_delete = await lookup()
    assert found_before_delete is not None
    assert found_before_delete.id == created_file.id

    # Remove it ...
    await server.file_manager.delete_file(created_file.id, actor=default_user)

    # ... and the identical lookup must now come back empty
    retrieved_file_after_delete = await lookup()
    assert retrieved_file_after_delete is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files(server: SyncServer, default_user, default_source):
|
||||
"""Test listing files with pagination."""
|
||||
@@ -5870,7 +5999,9 @@ async def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user):
|
||||
return await mock_post(url, json, timeout)
|
||||
|
||||
# Patch the AsyncClient
|
||||
monkeypatch.setattr(httpx, "AsyncClient", MockAsyncClient)
|
||||
import letta.services.job_manager as job_manager_module
|
||||
|
||||
monkeypatch.setattr(job_manager_module, "AsyncClient", MockAsyncClient)
|
||||
|
||||
job_in = PydanticJob(status=JobStatus.created, metadata={"foo": "bar"}, callback_url="http://example.test/webhook/jobs")
|
||||
created = await server.job_manager.create_job_async(pydantic_job=job_in, actor=default_user)
|
||||
@@ -7813,3 +7944,231 @@ async def test_attach_files_bulk_oversized_bulk(server, default_user, sarah_agen
|
||||
# All files should be attached (some open, some closed)
|
||||
all_files_after = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=default_user)
|
||||
assert len(all_files_after) == MAX_FILES_OPEN + 3
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Race Condition Tests - Blocks
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_concurrent_block_updates_race_condition(
    server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
):
    """Test that concurrent block updates don't cause race conditions.

    Five ``update_agent_async`` calls with overlapping ``block_ids`` subsets are
    run at once against the same agent. Every call must succeed and the agent
    must end up with at least one block.

    The blocks created here are deleted in a ``finally`` clause so that a failed
    assertion does not leave them behind for later tests.
    """
    agent, _ = comprehensive_test_agent_fixture

    blocks = []
    try:
        # Create multiple blocks to use in concurrent updates
        for i in range(5):
            block = await server.block_manager.create_or_update_block_async(
                PydanticBlock(label=f"test_block_{i}", value=f"Test block content {i}", limit=1000), actor=default_user
            )
            blocks.append(block)

        async def update_agent_blocks(block_subset):
            """Update agent with a specific subset of blocks."""
            update_request = UpdateAgent(block_ids=[b.id for b in block_subset])
            try:
                return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
            except Exception as e:
                # Report the failure as a value so one bad update cannot hide the others
                return {"error": str(e)}

        # Run concurrent updates with different, overlapping block combinations
        tasks = [
            update_agent_blocks(blocks[:2]),  # blocks 0, 1
            update_agent_blocks(blocks[1:3]),  # blocks 1, 2
            update_agent_blocks(blocks[2:4]),  # blocks 2, 3
            update_agent_blocks(blocks[3:5]),  # blocks 3, 4
            update_agent_blocks(blocks[:1]),  # block 0 only
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Verify no exceptions occurred (raised ones or captured error dicts)
        errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
        assert len(errors) == 0, f"Concurrent updates failed with errors: {errors}"

        # Verify all results are valid agent states
        valid_results = [r for r in results if not isinstance(r, Exception) and not (isinstance(r, dict) and "error" in r)]
        assert len(valid_results) == 5, "All concurrent updates should succeed"

        # Verify final state is consistent
        final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
        assert final_agent is not None
        assert len(final_agent.memory.blocks) > 0
    finally:
        # Clean up, even when an assertion above failed
        for block in blocks:
            await server.block_manager.delete_block_async(block.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_concurrent_same_block_updates_race_condition(
    server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
):
    """Test that multiple concurrent updates to the same block configuration don't cause issues.

    Ten identical ``update_agent_async`` calls, each setting ``block_ids`` to the
    same single block, are run at once. Afterwards the agent must hold exactly
    that one block.

    The shared block is deleted in a ``finally`` clause so that a failed
    assertion does not leave it behind for later tests.
    """
    agent, _ = comprehensive_test_agent_fixture

    # Create a single block configuration to use in all updates
    block = await server.block_manager.create_or_update_block_async(
        PydanticBlock(label="shared_block", value="Shared block content", limit=1000), actor=default_user
    )

    try:

        async def update_agent_with_same_blocks():
            """Update agent with the same block configuration."""
            update_request = UpdateAgent(block_ids=[block.id])
            try:
                return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
            except Exception as e:
                # Report the failure as a value so one bad update cannot hide the others
                return {"error": str(e)}

        # Run 10 concurrent identical updates
        tasks = [update_agent_with_same_blocks() for _ in range(10)]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Verify no exceptions occurred (raised ones or captured error dicts)
        errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
        assert len(errors) == 0, f"Concurrent identical updates failed with errors: {errors}"

        # Verify final state is consistent: exactly the shared block, no duplicates
        final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
        assert len(final_agent.memory.blocks) == 1
        assert final_agent.memory.blocks[0].id == block.id
    finally:
        # Clean up, even when an assertion above failed
        await server.block_manager.delete_block_async(block.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_concurrent_empty_block_updates_race_condition(
    server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
):
    """Test concurrent updates that remove all blocks.

    Five updates with an empty ``block_ids`` list race against each other; none
    may fail, and the agent must be left with no memory blocks.
    """
    agent, _ = comprehensive_test_agent_fixture

    async def clear_agent_blocks():
        """Update agent to have no blocks."""
        empty_update = UpdateAgent(block_ids=[])
        try:
            updated_agent = await server.agent_manager.update_agent_async(agent.id, empty_update, actor=default_user)
        except Exception as exc:
            # Hand the failure back as data; it is inspected after the gather
            return {"error": str(exc)}
        return updated_agent

    # Fire all clear operations at once
    pending = [clear_agent_blocks() for _ in range(5)]
    outcomes = await asyncio.gather(*pending, return_exceptions=True)

    # Collect both raised exceptions and captured error dicts
    errors = []
    for outcome in outcomes:
        if isinstance(outcome, Exception) or (isinstance(outcome, dict) and "error" in outcome):
            errors.append(outcome)
    assert len(errors) == 0, f"Concurrent clear operations failed with errors: {errors}"

    # The agent must end up without any blocks
    final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
    assert len(final_agent.memory.blocks) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_concurrent_mixed_block_operations_race_condition(
    server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
):
    """Test mixed concurrent operations: some adding blocks, some removing.

    Updates that set all blocks, a single block, two blocks, or no blocks are
    run at once. Because the winning update is not determined, the test only
    requires that no update fails and that the agent can still be loaded.

    The blocks created here are deleted in a ``finally`` clause so that a failed
    assertion does not leave them behind for later tests.
    """
    agent, _ = comprehensive_test_agent_fixture

    blocks = []
    try:
        # Create test blocks
        for i in range(3):
            block = await server.block_manager.create_or_update_block_async(
                PydanticBlock(label=f"mixed_block_{i}", value=f"Mixed block content {i}", limit=1000), actor=default_user
            )
            blocks.append(block)

        async def mixed_operation(operation_type):
            """Perform different types of block operations."""
            if operation_type == "add_all":
                update_request = UpdateAgent(block_ids=[b.id for b in blocks])
            elif operation_type == "add_subset":
                update_request = UpdateAgent(block_ids=[blocks[0].id])
            elif operation_type == "clear":
                update_request = UpdateAgent(block_ids=[])
            else:
                # Any other operation name (the test uses "add_two") selects the last two blocks
                update_request = UpdateAgent(block_ids=[blocks[1].id, blocks[2].id])

            try:
                return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
            except Exception as e:
                # Report the failure as a value so one bad update cannot hide the others
                return {"error": str(e)}

        # Run mixed concurrent operations
        tasks = [
            mixed_operation("add_all"),
            mixed_operation("add_subset"),
            mixed_operation("clear"),
            mixed_operation("add_two"),
            mixed_operation("add_all"),
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Verify no exceptions occurred (raised ones or captured error dicts)
        errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
        assert len(errors) == 0, f"Mixed concurrent operations failed with errors: {errors}"

        # Verify final state is consistent (any valid state is acceptable)
        final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
        assert final_agent is not None
    finally:
        # Clean up, even when an assertion above failed
        for block in blocks:
            await server.block_manager.delete_block_async(block.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_high_concurrency_stress_test(server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop):
    """Stress test with high concurrency to catch race conditions.

    Twenty updates run at once, each setting the agent's ``block_ids`` to a
    pseudo-random subset of ten blocks. No update may fail and the agent must
    still be loadable afterwards.

    Each task draws its subset from a generator seeded with its own task id, so
    the subsets are the same on every run and a failure can be reproduced. The
    blocks are deleted in a ``finally`` clause so that a failed assertion does
    not leave them behind for later tests.
    """
    import random

    agent, _ = comprehensive_test_agent_fixture

    blocks = []
    try:
        # Create many blocks for stress testing
        for i in range(10):
            block = await server.block_manager.create_or_update_block_async(
                PydanticBlock(label=f"stress_block_{i}", value=f"Stress test content {i}", limit=1000), actor=default_user
            )
            blocks.append(block)

        async def stress_update(task_id):
            """Perform a pseudo-random block update operation."""
            # Private, task-seeded generator: varied subsets across tasks, identical across runs
            rng = random.Random(task_id)

            # Subset size may be anything from 0 (clear) to all blocks
            num_blocks = rng.randint(0, len(blocks))
            selected_blocks = rng.sample(blocks, num_blocks)

            update_request = UpdateAgent(block_ids=[b.id for b in selected_blocks])

            try:
                return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
            except Exception as e:
                # Keep the task id so the failing subset can be regenerated from its seed
                return {"error": str(e), "task_id": task_id}

        # Run 20 concurrent stress updates
        tasks = [stress_update(i) for i in range(20)]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Verify no exceptions occurred (raised ones or captured error dicts)
        errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
        assert len(errors) == 0, f"High concurrency stress test failed with errors: {errors}"

        # Verify final state is consistent
        final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
        assert final_agent is not None
    finally:
        # Clean up, even when an assertion above failed
        for block in blocks:
            await server.block_manager.delete_block_async(block.id, actor=default_user)
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Type
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client import MessageCreate
|
||||
from letta_client import LettaRequest, MessageCreate, TextContent
|
||||
from letta_client.client import BaseTool
|
||||
from letta_client.core import ApiError
|
||||
from letta_client.types import AgentState, ToolReturnMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Constants
|
||||
SERVER_PORT = 8283
|
||||
@@ -762,3 +766,270 @@ def test_base_tools_upsert_on_list(client: LettaSDKClient):
|
||||
final_tool_names = {tool.name for tool in final_tools}
|
||||
for deleted_tool in tools_to_delete:
|
||||
assert deleted_tool.name in final_tool_names, f"Deleted tool {deleted_tool.name} was not properly restored"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKClient):
    """Register a ``BaseTool`` with a nested Pydantic args schema and have an agent call it.

    The test adds the tool, creates a temporary agent that only has this tool,
    sends one user message and then checks that the tool call carries the
    nested ``data``/``item`` arguments and that the tool return reports success.

    The temporary agent and the tool are deleted in a ``finally`` clause so that
    a failed assertion does not leave them on the server.
    """

    class InventoryItem(BaseModel):
        sku: str
        name: str
        price: float
        category: str

    class InventoryEntry(BaseModel):
        timestamp: int
        item: InventoryItem
        transaction_id: str

    class InventoryEntryData(BaseModel):
        data: InventoryEntry
        quantity_change: int

    class ManageInventoryTool(BaseTool):
        name: str = "manage_inventory"
        args_schema: Type[BaseModel] = InventoryEntryData
        description: str = "Update inventory catalogue with a new data entry"
        tags: List[str] = ["inventory", "shop"]

        def run(self, data: InventoryEntry, quantity_change: int) -> bool:
            print(f"Updated inventory for {data.item.name} with a quantity change of {quantity_change}")
            return True

    tool = client.tools.add(
        tool=ManageInventoryTool(),
    )
    # Checked before the try block because the cleanup below dereferences tool.id
    assert tool is not None

    temp_agent = None
    try:
        assert tool.name == "manage_inventory"
        assert "inventory" in tool.tags
        assert "shop" in tool.tags

        temp_agent = client.agents.create(
            memory_blocks=[
                CreateBlock(
                    label="persona",
                    value="You are a helpful inventory management assistant.",
                ),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-3-small",
            tool_ids=[tool.id],
            include_base_tools=False,
        )

        response = client.agents.messages.create(
            agent_id=temp_agent.id,
            messages=[
                MessageCreate(
                    role="user",
                    content="Update the inventory for product 'iPhone 15' with SKU 'IPH15-001', price $999.99, category 'Electronics', transaction ID 'TXN-12345', timestamp 1640995200, with a quantity change of +10",
                ),
            ],
        )

        assert response is not None

        tool_call_messages = [msg for msg in response.messages if msg.message_type == "tool_call_message"]
        assert len(tool_call_messages) > 0, "Expected at least one tool call message"

        first_tool_call = tool_call_messages[0]
        assert first_tool_call.tool_call.name == "manage_inventory"

        # The arguments must mirror the nested schema: data -> item -> fields
        args = json.loads(first_tool_call.tool_call.arguments)
        assert "data" in args
        assert "quantity_change" in args
        assert "item" in args["data"]
        assert "name" in args["data"]["item"]
        assert "sku" in args["data"]["item"]
        assert "price" in args["data"]["item"]
        assert "category" in args["data"]["item"]
        assert "transaction_id" in args["data"]
        assert "timestamp" in args["data"]

        tool_return_messages = [msg for msg in response.messages if msg.message_type == "tool_return_message"]
        assert len(tool_return_messages) > 0, "Expected at least one tool return message"

        first_tool_return = tool_return_messages[0]
        assert first_tool_return.status == "success"
        assert first_tool_return.tool_return == "True"
        # The print() inside run() is expected to show up in the captured stdout
        assert "Updated inventory for iPhone 15 with a quantity change of 10" in "\n".join(first_tool_return.stdout)
    finally:
        # Clean up, even when an assertion above failed
        if temp_agent is not None:
            client.agents.delete(temp_agent.id)
        client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient):
    """Upsert a function tool with a Pydantic ``args_schema`` and have an agent call it.

    The schema contains a list of nested ``Step`` models. The test checks that
    the agent's tool call supplies ``steps`` (each with ``name`` and
    ``description``) plus ``explanation``, and that the tool return reports
    success.

    The temporary agent and the tool are deleted in a ``finally`` clause so that
    a failed assertion does not leave them on the server.
    """

    class Step(BaseModel):
        name: str = Field(..., description="Name of the step.")
        description: str = Field(..., description="An exhaustive description of what this step is trying to achieve.")

    class StepsList(BaseModel):
        steps: List[Step] = Field(..., description="List of steps to add to the task plan.")
        explanation: str = Field(..., description="Explanation for the list of steps.")

    def create_task_plan(steps, explanation):
        """Creates a task plan for the current task."""
        print(f"Created task plan with {len(steps)} steps: {explanation}")
        return steps

    tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"])
    # Checked before the try block because the cleanup below dereferences tool.id
    assert tool is not None

    temp_agent = None
    try:
        assert tool.name == "create_task_plan"
        assert "planning" in tool.tags
        assert "task" in tool.tags

        temp_agent = client.agents.create(
            memory_blocks=[
                CreateBlock(
                    label="persona",
                    value="You are a helpful task planning assistant.",
                ),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-3-small",
            tool_ids=[tool.id],
            include_base_tools=False,
        )

        response = client.agents.messages.create(
            agent_id=temp_agent.id,
            messages=[
                MessageCreate(
                    role="user",
                    content="Create a task plan for organizing a team meeting with 3 steps: 1) Schedule meeting (find available time slots), 2) Send invitations (notify all team members), 3) Prepare agenda (outline discussion topics). Explanation: This plan ensures a well-organized team meeting.",
                ),
            ],
        )

        assert response is not None
        assert hasattr(response, "messages")
        assert len(response.messages) > 0

        tool_call_messages = [msg for msg in response.messages if msg.message_type == "tool_call_message"]
        assert len(tool_call_messages) > 0, "Expected at least one tool call message"

        first_tool_call = tool_call_messages[0]
        assert first_tool_call.tool_call.name == "create_task_plan"

        # The arguments must follow StepsList: a non-empty list of steps plus an explanation
        args = json.loads(first_tool_call.tool_call.arguments)
        assert "steps" in args
        assert "explanation" in args
        assert isinstance(args["steps"], list)
        assert len(args["steps"]) > 0

        for step in args["steps"]:
            assert "name" in step
            assert "description" in step

        tool_return_messages = [msg for msg in response.messages if msg.message_type == "tool_return_message"]
        assert len(tool_return_messages) > 0, "Expected at least one tool return message"

        first_tool_return = tool_return_messages[0]
        assert first_tool_return.status == "success"
    finally:
        # Clean up, even when an assertion above failed
        if temp_agent is not None:
            client.agents.delete(temp_agent.id)
        client.tools.delete(tool.id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True)
def test_create_tool_from_function_with_docstring(e2b_sandbox_mode, client: LettaSDKClient):
    """Test creating a tool from a function with a docstring using create_from_function

    Checks that the created tool takes its name from the function, that its
    description contains text from the function's docstring, that the function
    body appears in ``source_code``, and that the tool shows up in the tool list.

    The tool is deleted in a ``finally`` clause so that a failed assertion does
    not leave a ``roll_dice`` tool on the server.
    """

    def roll_dice() -> str:
        """
        Simulate the roll of a 20-sided die (d20).

        This function generates a random integer between 1 and 20, inclusive,
        which represents the outcome of a single roll of a d20.

        Returns:
            str: The result of the die roll.
        """
        import random

        dice_role_outcome = random.randint(1, 20)
        output_string = f"You rolled a {dice_role_outcome}"
        return output_string

    tool = client.tools.create_from_function(func=roll_dice)
    # Checked before the try block because the cleanup below dereferences tool.id
    assert tool is not None

    try:
        assert tool.name == "roll_dice"
        assert "Simulate the roll of a 20-sided die" in tool.description
        assert tool.source_code is not None
        assert "random.randint(1, 20)" in tool.source_code

        # The new tool must be visible through the listing endpoint as well
        all_tools = client.tools.list()
        tool_names = [t.name for t in all_tools]
        assert "roll_dice" in tool_names
    finally:
        # Clean up, even when an assertion above failed
        client.tools.delete(tool.id)
|
||||
|
||||
|
||||
def test_preview_payload(client: LettaSDKClient, agent):
    """Check the shape and key values of the raw request payload previewed for an agent.

    A single user text message is previewed via ``preview_raw_payload``; the
    returned dict is checked for the expected top-level keys, the system
    message contents, the tool definitions and the sampling settings.
    """
    user_message = MessageCreate(
        role="user",
        content=[
            TextContent(
                text="text",
            )
        ],
    )
    payload = client.agents.messages.preview_raw_payload(
        agent_id=agent.id,
        request=LettaRequest(messages=[user_message]),
    )

    # Top-level structure
    assert isinstance(payload, dict)
    required_keys = (
        "model",
        "messages",
        "tools",
        "frequency_penalty",
        "max_completion_tokens",
        "temperature",
        "user",
        "parallel_tool_calls",
        "tool_choice",
    )
    for required_key in required_keys:
        assert required_key in payload

    assert payload["model"] == "gpt-4o-mini"

    # Messages: at least three entries, the first being the system prompt
    messages = payload["messages"]
    assert isinstance(messages, list)
    assert len(messages) >= 3

    system_message = messages[0]
    assert system_message["role"] == "system"
    for fragment in ("base_instructions", "memory_blocks", "tool_usage_rules", "Letta"):
        assert fragment in system_message["content"]

    # Tools: non-empty list that includes the expected tool names
    tools = payload["tools"]
    assert isinstance(tools, list)
    assert len(tools) > 0

    tool_names = [tool["function"]["name"] for tool in tools]
    expected_tools = ["send_message", "conversation_search", "core_memory_replace", "core_memory_append"]
    for tool_name in expected_tools:
        assert tool_name in tool_names, f"Expected tool {tool_name} not found in tools"

    # Every tool entry is a strict function definition
    for tool in tools:
        assert tool["type"] == "function"
        assert "function" in tool
        function_spec = tool["function"]
        assert "name" in function_spec
        assert "description" in function_spec
        assert "parameters" in function_spec
        assert function_spec["strict"] is True

    # Sampling and tool-calling settings
    assert payload["frequency_penalty"] == 1.0
    assert payload["max_completion_tokens"] == 4096
    assert payload["temperature"] == 0.7
    assert payload["parallel_tool_calls"] is False
    assert payload["tool_choice"] == "required"
    assert payload["user"].startswith("user-")

    print(payload)
|
||||
|
||||
@@ -7,18 +7,37 @@ import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import CreateBlock
|
||||
from letta_client import Letta as LettaSDKClient
|
||||
from letta_client import LettaRequest
|
||||
from letta_client import MessageCreate as ClientMessageCreate
|
||||
from letta_client.types import AgentState
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# Constants
|
||||
SERVER_PORT = 8283
|
||||
|
||||
|
||||
def get_raw_system_message(client: LettaSDKClient, agent_id: str) -> str:
    """Helper function to get the raw system message from an agent's preview payload.

    Previews the payload for a placeholder user message ("Testing") and returns
    the ``content`` of the first entry in the payload's ``messages`` list, which
    this helper takes to be the system message.
    """
    # Placeholder message used only to trigger the preview
    probe_message = ClientMessageCreate(
        role="user",
        content="Testing",
    )
    preview = client.agents.messages.preview_raw_payload(
        agent_id=agent_id,
        request=LettaRequest(messages=[probe_message]),
    )
    first_message = preview["messages"][0]
    return first_message["content"]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_sources(client: LettaSDKClient):
|
||||
# Clear existing sources
|
||||
@@ -172,6 +191,10 @@ def test_file_upload_creates_source_blocks_correctly(
|
||||
expected_value: str,
|
||||
expected_label_regex: str,
|
||||
):
|
||||
# skip pdf tests if mistral api key is missing
|
||||
if file_path.endswith(".pdf") and not settings.mistral_api_key:
|
||||
pytest.skip("mistral api key required for pdf processing")
|
||||
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
assert len(client.sources.list()) == 1
|
||||
@@ -195,6 +218,15 @@ def test_file_upload_creates_source_blocks_correctly(
|
||||
assert any(b.value.startswith("[Viewing file start") for b in blocks)
|
||||
assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
|
||||
|
||||
# verify raw system message contains source information
|
||||
raw_system_message = get_raw_system_message(client, agent_state.id)
|
||||
assert "test_source" in raw_system_message
|
||||
assert "<directories>" in raw_system_message
|
||||
# verify file-specific details in raw system message
|
||||
file_name = files[0].file_name
|
||||
assert f'name="test_source/{file_name}"' in raw_system_message
|
||||
assert 'status="open"' in raw_system_message
|
||||
|
||||
# Remove file from source
|
||||
client.sources.files.delete(source_id=source.id, file_id=files[0].id)
|
||||
|
||||
@@ -205,6 +237,14 @@ def test_file_upload_creates_source_blocks_correctly(
|
||||
assert not any(expected_value in b.value for b in blocks)
|
||||
assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks)
|
||||
|
||||
# verify raw system message no longer contains source information
|
||||
raw_system_message_after_removal = get_raw_system_message(client, agent_state.id)
|
||||
# this should be in, because we didn't delete the source
|
||||
assert "test_source" in raw_system_message_after_removal
|
||||
assert "<directories>" in raw_system_message_after_removal
|
||||
# verify file-specific details are also removed
|
||||
assert f'name="test_source/{file_name}"' not in raw_system_message_after_removal
|
||||
|
||||
|
||||
def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
@@ -224,6 +264,25 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone,
|
||||
|
||||
# Attach after uploading the file
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
raw_system_message = get_raw_system_message(client, agent_state.id)
|
||||
|
||||
# Assert that the expected chunk is in the raw system message
|
||||
expected_chunk = """<directories>
|
||||
<directory name="test_source">
|
||||
<file status="open" name="test_source/test.txt">
|
||||
<metadata>
|
||||
- read_only=true
|
||||
- chars_current=46
|
||||
- chars_limit=50000
|
||||
</metadata>
|
||||
<value>
|
||||
[Viewing file start (out of 1 chunks)]
|
||||
1: test
|
||||
</value>
|
||||
</file>
|
||||
</directory>
|
||||
</directories>"""
|
||||
assert expected_chunk in raw_system_message
|
||||
|
||||
# Get the agent state, check blocks exist
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
@@ -241,20 +300,46 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone,
|
||||
assert len(blocks) == 0
|
||||
assert not any("test" in b.value for b in blocks)
|
||||
|
||||
# Verify no traces of the prompt exist in the raw system message after detaching
|
||||
raw_system_message_after_detach = get_raw_system_message(client, agent_state.id)
|
||||
assert expected_chunk not in raw_system_message_after_detach
|
||||
assert "test_source" not in raw_system_message_after_detach
|
||||
assert "<directories>" not in raw_system_message_after_detach
|
||||
|
||||
|
||||
def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState):
|
||||
# Create a new source
|
||||
source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small")
|
||||
assert len(client.sources.list()) == 1
|
||||
|
||||
# Attach
|
||||
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id)
|
||||
raw_system_message = get_raw_system_message(client, agent_state.id)
|
||||
assert "test_source" in raw_system_message
|
||||
assert "<directories>" in raw_system_message
|
||||
|
||||
# Load files into the source
|
||||
file_path = "tests/data/test.txt"
|
||||
|
||||
# Upload the files
|
||||
upload_file_and_wait(client, source.id, file_path)
|
||||
raw_system_message = get_raw_system_message(client, agent_state.id)
|
||||
# Assert that the expected chunk is in the raw system message
|
||||
expected_chunk = """<directories>
|
||||
<directory name="test_source">
|
||||
<file status="open" name="test_source/test.txt">
|
||||
<metadata>
|
||||
- read_only=true
|
||||
- chars_current=46
|
||||
- chars_limit=50000
|
||||
</metadata>
|
||||
<value>
|
||||
[Viewing file start (out of 1 chunks)]
|
||||
1: test
|
||||
</value>
|
||||
</file>
|
||||
</directory>
|
||||
</directories>"""
|
||||
assert expected_chunk in raw_system_message
|
||||
|
||||
# Get the agent state, check blocks exist
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
@@ -264,6 +349,10 @@ def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client:
|
||||
|
||||
# Remove file from source
|
||||
client.sources.delete(source_id=source.id)
|
||||
raw_system_message_after_detach = get_raw_system_message(client, agent_state.id)
|
||||
assert expected_chunk not in raw_system_message_after_detach
|
||||
assert "test_source" not in raw_system_message_after_detach
|
||||
assert "<directories>" not in raw_system_message_after_detach
|
||||
|
||||
# Get the agent state, check blocks do NOT exist
|
||||
agent_state = client.agents.retrieve(agent_id=agent_state.id)
|
||||
|
||||
Reference in New Issue
Block a user