diff --git a/letta/__init__.py b/letta/__init__.py index f7de8edc..3aae91bf 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -5,7 +5,7 @@ try: __version__ = version("letta") except PackageNotFoundError: # Fallback for development installations - __version__ = "0.8.12" + __version__ = "0.8.13" if os.environ.get("LETTA_VERSION"): __version__ = os.environ["LETTA_VERSION"] diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index b8613201..a556ab73 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -96,7 +96,7 @@ class BaseAgent(ABC): """ try: # [DB Call] loading blocks (modifies: agent_state.memory.blocks) - await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor) + agent_state = await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor) tool_constraint_block = None if tool_rules_solver is not None: @@ -104,18 +104,37 @@ class BaseAgent(ABC): # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this curr_system_message = in_context_messages[0] - curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, sources=agent_state.sources) curr_system_message_text = curr_system_message.content[0].text - if curr_memory_str in curr_system_message_text: + + # extract the dynamic section that includes memory blocks, tool rules, and directories + # this avoids timestamp comparison issues + def extract_dynamic_section(text): + start_marker = "" + end_marker = "" + + start_idx = text.find(start_marker) + end_idx = text.find(end_marker) + + if start_idx != -1 and end_idx != -1: + return text[start_idx:end_idx] + return text # fallback to full text if markers not found + + curr_dynamic_section = extract_dynamic_section(curr_system_message_text) + + # generate just the memory string with current state for comparison + curr_memory_str = agent_state.memory.compile(tool_usage_rules=tool_constraint_block, 
sources=agent_state.sources) + new_dynamic_section = extract_dynamic_section(curr_memory_str) + + # compare just the dynamic sections (memory blocks, tool rules, directories) + if curr_dynamic_section == new_dynamic_section: logger.debug( - f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" + f"Memory and sources haven't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" ) return in_context_messages memory_edit_timestamp = get_utc_time() - # [DB Call] size of messages and archival memories - # todo: blocking for now + # size of messages and archival memories if num_messages is None: num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) if num_archival_memories is None: diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 9a6aaadb..7ab25738 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -3,7 +3,7 @@ import json import uuid from collections.abc import AsyncGenerator from datetime import datetime -from typing import Optional +from typing import Optional, Union from openai import AsyncStream from openai.types.chat import ChatCompletionChunk @@ -165,18 +165,28 @@ class LettaAgent(BaseAgent): use_assistant_message: bool = True, request_start_timestamp_ns: int | None = None, include_return_message_types: list[MessageType] | None = None, - ) -> LettaResponse: + dry_run: bool = False, + ) -> Union[LettaResponse, dict]: # TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions agent_state = await self.agent_manager.get_agent_by_id_async( - agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor + agent_id=self.agent_id, + include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"], + actor=self.actor, ) - _, 
new_in_context_messages, stop_reason, usage = await self._step( + result = await self._step( agent_state=agent_state, input_messages=input_messages, max_steps=max_steps, run_id=run_id, request_start_timestamp_ns=request_start_timestamp_ns, + dry_run=dry_run, ) + + # If dry run, return the request payload directly + if dry_run: + return result + + _, new_in_context_messages, stop_reason, usage = result return _create_letta_response( new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, @@ -195,7 +205,9 @@ class LettaAgent(BaseAgent): include_return_message_types: list[MessageType] | None = None, ): agent_state = await self.agent_manager.get_agent_by_id_async( - agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor + agent_id=self.agent_id, + include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"], + actor=self.actor, ) current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async( input_messages, agent_state, self.message_manager, self.actor @@ -279,6 +291,7 @@ class LettaAgent(BaseAgent): tool_rules_solver, response.usage, reasoning_content=reasoning, + step_id=step_id, initial_messages=initial_messages, agent_step_span=agent_step_span, is_final_step=(i == max_steps - 1), @@ -357,7 +370,8 @@ class LettaAgent(BaseAgent): max_steps: int = DEFAULT_MAX_STEPS, run_id: str | None = None, request_start_timestamp_ns: int | None = None, - ) -> tuple[list[Message], list[Message], LettaStopReason | None, LettaUsageStatistics]: + dry_run: bool = False, + ) -> Union[tuple[list[Message], list[Message], LettaStopReason | None, LettaUsageStatistics], dict]: """ Carries out an invocation of the agent loop. In each step, the agent 1. 
Rebuilds its memory @@ -394,6 +408,16 @@ class LettaAgent(BaseAgent): agent_step_span = tracer.start_span("agent_step", start_time=step_start) agent_step_span.set_attributes({"step_id": step_id}) + # If dry run, build request data and return it without making LLM call + if dry_run: + request_data, valid_tool_names = await self._create_llm_request_data_async( + llm_client=llm_client, + in_context_messages=current_in_context_messages + new_in_context_messages, + agent_state=agent_state, + tool_rules_solver=tool_rules_solver, + ) + return request_data + request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = ( await self._build_and_request_from_llm( current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver, agent_step_span @@ -530,7 +554,9 @@ class LettaAgent(BaseAgent): 4. Processes the response """ agent_state = await self.agent_manager.get_agent_by_id_async( - agent_id=self.agent_id, include_relationships=["tools", "memory", "tool_exec_environment_variables"], actor=self.actor + agent_id=self.agent_id, + include_relationships=["tools", "memory", "tool_exec_environment_variables", "sources"], + actor=self.actor, ) current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async( input_messages, agent_state, self.message_manager, self.actor @@ -628,7 +654,7 @@ class LettaAgent(BaseAgent): ) # log LLM request time - llm_request_ms = ns_to_ms(stream_end_time_ns - request_start_timestamp_ns) + llm_request_ms = ns_to_ms(stream_end_time_ns - provider_request_start_timestamp_ns) agent_step_span.add_event(name="llm_request_ms", attributes={"duration_ms": llm_request_ms}) MetricRegistry().llm_execution_time_ms_histogram.record( llm_request_ms, diff --git a/letta/functions/ast_parsers.py b/letta/functions/ast_parsers.py index 57785b46..627b7fdb 100644 --- a/letta/functions/ast_parsers.py +++ b/letta/functions/ast_parsers.py @@ -129,7 +129,8 @@ def 
get_function_name_and_docstring(source_code: str, name: Optional[str] = None raise LettaToolCreateError("Could not determine function name") if not docstring: - raise LettaToolCreateError("Docstring is missing") + # For tools with args_json_schema, the docstring is optional + docstring = f"The {function_name} tool" return function_name, docstring diff --git a/letta/helpers/json_helpers.py b/letta/helpers/json_helpers.py index 3a1af412..45fa3414 100644 --- a/letta/helpers/json_helpers.py +++ b/letta/helpers/json_helpers.py @@ -10,6 +10,8 @@ def json_dumps(data, indent=2): def safe_serializer(obj): if isinstance(obj, datetime): return obj.isoformat() + if isinstance(obj, bytes): + return obj.decode("utf-8") raise TypeError(f"Type {type(obj)} not serializable") return json.dumps(data, indent=indent, default=safe_serializer, ensure_ascii=False) diff --git a/letta/helpers/pinecone_utils.py b/letta/helpers/pinecone_utils.py index d51b071d..f583b933 100644 --- a/letta/helpers/pinecone_utils.py +++ b/letta/helpers/pinecone_utils.py @@ -1,6 +1,12 @@ from typing import Any, Dict, List -from pinecone import PineconeAsyncio +try: + from pinecone import IndexEmbed, PineconeAsyncio + from pinecone.exceptions.exceptions import NotFoundException + + PINECONE_AVAILABLE = True +except ImportError: + PINECONE_AVAILABLE = False from letta.constants import ( PINECONE_CLOUD, @@ -27,11 +33,20 @@ def should_use_pinecone(verbose: bool = False): bool(settings.pinecone_source_index), ) - return settings.enable_pinecone and settings.pinecone_api_key and settings.pinecone_agent_index and settings.pinecone_source_index + return all( + ( + PINECONE_AVAILABLE, + settings.enable_pinecone, + settings.pinecone_api_key, + settings.pinecone_agent_index, + settings.pinecone_source_index, + ) + ) async def upsert_pinecone_indices(): - from pinecone import IndexEmbed, PineconeAsyncio + if not PINECONE_AVAILABLE: + raise ImportError("Pinecone is not available. 
Please install pinecone to use this feature.") for index_name in get_pinecone_indices(): async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc: @@ -49,6 +64,9 @@ def get_pinecone_indices() -> List[str]: async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, chunks: List[str], actor: User): + if not PINECONE_AVAILABLE: + raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") + records = [] for i, chunk in enumerate(chunks): record = { @@ -63,7 +81,8 @@ async def upsert_file_records_to_pinecone_index(file_id: str, source_id: str, ch async def delete_file_records_from_pinecone_index(file_id: str, actor: User): - from pinecone.exceptions.exceptions import NotFoundException + if not PINECONE_AVAILABLE: + raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") namespace = actor.organization_id try: @@ -81,7 +100,8 @@ async def delete_file_records_from_pinecone_index(file_id: str, actor: User): async def delete_source_records_from_pinecone_index(source_id: str, actor: User): - from pinecone.exceptions.exceptions import NotFoundException + if not PINECONE_AVAILABLE: + raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") namespace = actor.organization_id try: @@ -94,6 +114,9 @@ async def delete_source_records_from_pinecone_index(source_id: str, actor: User) async def upsert_records_to_pinecone_index(records: List[dict], actor: User): + if not PINECONE_AVAILABLE: + raise ImportError("Pinecone is not available. 
Please install pinecone to use this feature.") + async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc: description = await pc.describe_index(name=settings.pinecone_source_index) async with pc.IndexAsyncio(host=description.index.host) as dense_index: @@ -104,6 +127,9 @@ async def upsert_records_to_pinecone_index(records: List[dict], actor: User): async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any], actor: User) -> Dict[str, Any]: + if not PINECONE_AVAILABLE: + raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") + async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc: description = await pc.describe_index(name=settings.pinecone_source_index) async with pc.IndexAsyncio(host=description.index.host) as dense_index: @@ -127,7 +153,8 @@ async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any], async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = None, pagination_token: str = None) -> List[str]: - from pinecone.exceptions.exceptions import NotFoundException + if not PINECONE_AVAILABLE: + raise ImportError("Pinecone is not available. Please install pinecone to use this feature.") namespace = actor.organization_id try: diff --git a/letta/jobs/scheduler.py b/letta/jobs/scheduler.py index df220e4b..1e6867b2 100644 --- a/letta/jobs/scheduler.py +++ b/letta/jobs/scheduler.py @@ -29,6 +29,7 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool: if _is_scheduler_leader: return True # Already leading + engine_name = None lock_session = None acquired_lock = False try: @@ -36,32 +37,25 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool: engine = session.get_bind() engine_name = engine.name logger.info(f"Database engine type: {engine_name}") - if engine_name != "postgresql": - logger.warning(f"Advisory locks not supported for {engine_name} database. 
Starting scheduler without leader election.") - acquired_lock = True - else: - lock_session = db_registry.get_async_session_factory()() - result = await lock_session.execute( - text("SELECT pg_try_advisory_lock(CAST(:lock_key AS bigint))"), {"lock_key": ADVISORY_LOCK_KEY} - ) - acquired_lock = result.scalar() - await lock_session.commit() - if not acquired_lock: - if lock_session: - await lock_session.close() - logger.info("Scheduler lock held by another instance.") - return False - - if engine_name == "postgresql": - logger.info("Acquired PostgreSQL advisory lock.") - _advisory_lock_session = lock_session - lock_session = None + if engine_name != "postgresql": + logger.warning(f"Advisory locks not supported for {engine_name} database. Starting scheduler without leader election.") + acquired_lock = True else: - logger.info("Starting scheduler for non-PostgreSQL database.") - if lock_session: + lock_session = db_registry.get_async_session_factory()() + result = await lock_session.execute( + text("SELECT pg_try_advisory_lock(CAST(:lock_key AS bigint))"), {"lock_key": ADVISORY_LOCK_KEY} + ) + acquired_lock = result.scalar() + await lock_session.commit() + + if not acquired_lock: await lock_session.close() - lock_session = None + logger.info("Scheduler lock held by another instance.") + return False + else: + _advisory_lock_session = lock_session + lock_session = None trigger = IntervalTrigger( seconds=settings.poll_running_llm_batches_interval_seconds, @@ -90,7 +84,6 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool: if acquired_lock: logger.warning("Attempting to release lock due to error during startup.") try: - _advisory_lock_session = lock_session await _release_advisory_lock(lock_session) except Exception as unlock_err: logger.error(f"Failed to release lock during error handling: {unlock_err}", exc_info=True) @@ -108,8 +101,8 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool: if lock_session: try: await 
lock_session.close() - except: - pass + except Exception as e: + logger.error(f"Failed to close session during error handling: {e}", exc_info=True) async def _background_lock_retry_loop(server: SyncServer): @@ -138,15 +131,13 @@ async def _background_lock_retry_loop(server: SyncServer): break except Exception as e: logger.error(f"Error in background lock retry loop: {e}", exc_info=True) - await asyncio.sleep(settings.poll_lock_retry_interval_seconds) -async def _release_advisory_lock(lock_session=None): +async def _release_advisory_lock(target_lock_session=None): """Releases the advisory lock using the stored session.""" global _advisory_lock_session - lock_session = _advisory_lock_session or lock_session - _advisory_lock_session = None + lock_session = target_lock_session or _advisory_lock_session if lock_session is not None: logger.info(f"Attempting to release PostgreSQL advisory lock {ADVISORY_LOCK_KEY}") @@ -161,6 +152,8 @@ async def _release_advisory_lock(lock_session=None): if lock_session: await lock_session.close() logger.info("Closed database session that held advisory lock.") + if lock_session == _advisory_lock_session: + _advisory_lock_session = None except Exception as e: logger.error(f"Error closing advisory lock session: {e}", exc_info=True) else: diff --git a/letta/otel/metric_registry.py b/letta/otel/metric_registry.py index c5b242f3..f3069a9e 100644 --- a/letta/otel/metric_registry.py +++ b/letta/otel/metric_registry.py @@ -58,7 +58,12 @@ class MetricRegistry: def tool_execution_counter(self) -> Counter: return self._get_or_create_metric( "count_tool_execution", - partial(self._meter.create_counter, name="count_tool_execution", description="Counts the number of tools executed.", unit="1"), + partial( + self._meter.create_counter, + name="count_tool_execution", + description="Counts the number of tools executed.", + unit="1", + ), ) # project_id + model @@ -66,7 +71,12 @@ class MetricRegistry: def ttft_ms_histogram(self) -> Histogram: return 
self._get_or_create_metric( "hist_ttft_ms", - partial(self._meter.create_histogram, name="hist_ttft_ms", description="Histogram for the Time to First Token (ms)", unit="ms"), + partial( + self._meter.create_histogram, + name="hist_ttft_ms", + description="Histogram for the Time to First Token (ms)", + unit="ms", + ), ) # (includes model name) @@ -158,3 +168,15 @@ class MetricRegistry: unit="1", ), ) + + @property + def file_process_bytes_histogram(self) -> Histogram: + return self._get_or_create_metric( + "hist_file_process_bytes", + partial( + self._meter.create_histogram, + name="hist_file_process_bytes", + description="Histogram for file process in bytes", + unit="By", + ), + ) diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index f4cd35f8..cbe922ca 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -139,3 +139,11 @@ class MCPServerType(str, Enum): SSE = "sse" STDIO = "stdio" STREAMABLE_HTTP = "streamable_http" + + +class DuplicateFileHandling(str, Enum): + """How to handle duplicate filenames when uploading files""" + + SKIP = "skip" # skip files with duplicate names + ERROR = "error" # error when duplicate names are encountered + SUFFIX = "suffix" # add numeric suffix to make names unique (default behavior) diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index a9a41b2b..dde4975f 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -77,9 +77,8 @@ class Tool(BaseTool): if self.tool_type is ToolType.CUSTOM: if not self.source_code: - error_msg = f"Custom tool with id={self.id} is missing source_code field." 
- logger.error(error_msg) - raise ValueError(error_msg) + logger.error("Custom tool with id=%s is missing source_code field", self.id) + raise ValueError(f"Custom tool with id={self.id} is missing source_code field.") # Always derive json_schema for freshest possible json_schema if self.args_json_schema is not None: @@ -96,8 +95,7 @@ class Tool(BaseTool): try: self.json_schema = derive_openai_json_schema(source_code=self.source_code) except Exception as e: - error_msg = f"Failed to derive json schema for tool with id={self.id} name={self.name}. Error: {str(e)}" - logger.error(error_msg) + logger.error("Failed to derive json schema for tool with id=%s name=%s: %s", self.id, self.name, e) elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE, ToolType.LETTA_SLEEPTIME_CORE}: # If it's letta core tool, we generate the json_schema on the fly here self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name) @@ -119,9 +117,8 @@ class Tool(BaseTool): # At this point, we need to validate that at least json_schema is populated if not self.json_schema: - error_msg = f"Tool with id={self.id} name={self.name} tool_type={self.tool_type} is missing a json_schema." 
- logger.error(error_msg) - raise ValueError(error_msg) + logger.error("Tool with id=%s name=%s tool_type=%s is missing a json_schema", self.id, self.name, self.tool_type) + raise ValueError(f"Tool with id={self.id} name={self.name} tool_type={self.tool_type} is missing a json_schema.") # Derive name from the JSON schema if not provided if not self.name: diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 8ddb175f..c630e2f5 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -337,8 +337,11 @@ def create_application() -> "FastAPI": # / static files mount_static_files(app) + no_generation = "--no-generation" in sys.argv + # Generate OpenAPI schema after all routes are mounted - generate_openapi_schema(app) + if not no_generation: + generate_openapi_schema(app) return app diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 6f0df27d..fee2de03 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -2,7 +2,7 @@ import asyncio import json import traceback from datetime import datetime, timezone -from typing import Annotated, Any, List, Optional +from typing import Annotated, Any, Dict, List, Optional, Union from fastapi import APIRouter, Body, Depends, File, Header, HTTPException, Query, Request, UploadFile, status from fastapi.responses import JSONResponse @@ -522,7 +522,7 @@ async def attach_block( actor_id: str | None = Header(None, alias="user_id"), ): """ - Attach a core memoryblock to an agent. + Attach a core memory block to an agent. 
""" actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) return await server.agent_manager.attach_block_async(agent_id=agent_id, block_id=block_id, actor=actor) @@ -1160,6 +1160,69 @@ async def list_agent_groups( return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor) +@router.post( + "/{agent_id}/messages/preview-raw-payload", + response_model=Dict[str, Any], + operation_id="preview_raw_payload", +) +async def preview_raw_payload( + agent_id: str, + request: Union[LettaRequest, LettaStreamingRequest] = Body(...), + server: SyncServer = Depends(get_letta_server), + actor_id: str | None = Header(None, alias="user_id"), +): + """ + Inspect the raw LLM request payload without sending it. + + This endpoint processes the message through the agent loop up until + the LLM request, then returns the raw request payload that would + be sent to the LLM provider. Useful for debugging and inspection. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"]) + agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] + model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"] + + if agent_eligible and model_compatible: + if agent.enable_sleeptime: + # TODO: @caren need to support this for sleeptime + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Payload inspection is not supported for agents with sleeptime enabled.", + ) + else: + agent_loop = LettaAgent( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + job_manager=server.job_manager, + passage_manager=server.passage_manager, + actor=actor, + 
step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), + summarizer_mode=( + SummarizationMode.STATIC_MESSAGE_BUFFER + if agent.agent_type == AgentType.voice_convo_agent + else SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER + ), + ) + + # TODO: Support step_streaming + return await agent_loop.step( + input_messages=request.messages, + use_assistant_message=request.use_assistant_message, + include_return_message_types=request.include_return_message_types, + dry_run=True, + ) + + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Payload inspection is not currently supported for this agent configuration.", + ) + + @router.post("/{agent_id}/summarize", response_model=AgentState, operation_id="summarize_agent_conversation") async def summarize_agent_conversation( agent_id: str, diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 48049d73..e7ab5370 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -19,7 +19,7 @@ from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import FileProcessingStatus +from letta.schemas.enums import DuplicateFileHandling, FileProcessingStatus from letta.schemas.file import FileMetadata from letta.schemas.passage import Passage from letta.schemas.source import Source, SourceCreate, SourceUpdate @@ -208,6 +208,7 @@ async def delete_source( async def upload_file_to_source( file: UploadFile, source_id: str, + duplicate_handling: DuplicateFileHandling = Query(DuplicateFileHandling.SUFFIX, description="How to handle duplicate filenames"), server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): @@ -264,8 +265,31 @@ async def 
upload_file_to_source( content = await file.read() - # Store original filename and generate unique filename + # Store original filename and handle duplicate logic original_filename = sanitize_filename(file.filename) # Basic sanitization only + + # Check if duplicate exists + existing_file = await server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=source_id, actor=actor + ) + + if existing_file: + # Duplicate found, handle based on strategy + if duplicate_handling == DuplicateFileHandling.ERROR: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, detail=f"File '{original_filename}' already exists in source '{source.name}'" + ) + elif duplicate_handling == DuplicateFileHandling.SKIP: + # Return existing file metadata with custom header to indicate it was skipped + from fastapi import Response + + response = Response( + content=existing_file.model_dump_json(), media_type="application/json", headers={"X-Upload-Result": "skipped"} + ) + return response + # For SUFFIX, continue to generate unique filename + + # Generate unique filename (adds suffix if needed) unique_filename = await server.file_manager.generate_unique_filename( original_filename=original_filename, source=source, organization_id=actor.organization_id ) @@ -360,6 +384,13 @@ async def get_file_metadata( file_id=file_id, actor=actor, include_content=include_content, strip_directory_prefix=True ) + if not file_metadata: + raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.") + + # Verify the file belongs to the specified source + if file_metadata.source_id != source_id: + raise HTTPException(status_code=404, detail=f"File with id={file_id} not found in source {source_id}.") + if should_use_pinecone() and not file_metadata.is_processing_terminal(): ids = await list_pinecone_index_for_files(file_id=file_id, actor=actor, limit=file_metadata.total_chunks) logger.info( @@ -375,13 +406,6 @@ async def get_file_metadata( 
file_id=file_metadata.id, actor=actor, chunks_embedded=len(ids), processing_status=file_status ) - if not file_metadata: - raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.") - - # Verify the file belongs to the specified source - if file_metadata.source_id != source_id: - raise HTTPException(status_code=404, detail=f"File with id={file_id} not found in source {source_id}.") - return file_metadata diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 3e17740f..2bc79280 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -4,7 +4,7 @@ from datetime import datetime, timezone from typing import Dict, List, Optional, Set, Tuple import sqlalchemy as sa -from sqlalchemy import delete, func, insert, literal, or_, select +from sqlalchemy import delete, func, insert, literal, or_, select, tuple_ from sqlalchemy.dialects.postgresql import insert as pg_insert from letta.constants import ( @@ -224,13 +224,44 @@ class AgentManager: @staticmethod async def _replace_pivot_rows_async(session, table, agent_id: str, rows: list[dict]): """ - Replace all pivot rows for an agent with *exactly* the provided list. - Uses two bulk statements (DELETE + INSERT ... ON CONFLICT DO NOTHING). + Replace all pivot rows for an agent atomically using MERGE pattern. 
""" - # delete all existing rows for this agent - await session.execute(delete(table).where(table.c.agent_id == agent_id)) - if rows: - await AgentManager._bulk_insert_pivot_async(session, table, rows) + dialect = session.bind.dialect.name + + if dialect == "postgresql": + if rows: + # separate upsert and delete operations + stmt = pg_insert(table).values(rows) + stmt = stmt.on_conflict_do_nothing() + await session.execute(stmt) + + # delete rows not in new set + pk_names = [c.name for c in table.primary_key.columns] + new_keys = [tuple(r[c] for c in pk_names) for r in rows] + await session.execute( + delete(table).where(table.c.agent_id == agent_id, ~tuple_(*[table.c[c] for c in pk_names]).in_(new_keys)) + ) + else: + # if no rows to insert, just delete all + await session.execute(delete(table).where(table.c.agent_id == agent_id)) + + elif dialect == "sqlite": + if rows: + stmt = sa.insert(table).values(rows).prefix_with("OR REPLACE") + await session.execute(stmt) + + if rows: + primary_key_cols = [table.c[c.name] for c in table.primary_key.columns] + new_keys = [tuple(r[c.name] for c in table.primary_key.columns) for r in rows] + await session.execute(delete(table).where(table.c.agent_id == agent_id, ~tuple_(*primary_key_cols).in_(new_keys))) + else: + await session.execute(delete(table).where(table.c.agent_id == agent_id)) + + else: + # fallback: use original DELETE + INSERT pattern + await session.execute(delete(table).where(table.c.agent_id == agent_id)) + if rows: + await AgentManager._bulk_insert_pivot_async(session, table, rows) # ====================================================================================================================== # Basic CRUD operations diff --git a/letta/services/file_manager.py b/letta/services/file_manager.py index f4a84fb3..530fa3e1 100644 --- a/letta/services/file_manager.py +++ b/letta/services/file_manager.py @@ -22,6 +22,15 @@ from letta.server.db import db_registry from letta.utils import enforce_types +class 
DuplicateFileError(Exception): + """Raised when a duplicate file is encountered and error handling is specified""" + + def __init__(self, filename: str, source_name: str): + self.filename = filename + self.source_name = source_name + super().__init__(f"File '{filename}' already exists in source '{source_name}'") + + class FileManager: """Manager class to handle business logic related to files.""" @@ -237,16 +246,16 @@ class FileManager: @trace_method async def generate_unique_filename(self, original_filename: str, source: PydanticSource, organization_id: str) -> str: """ - Generate a unique filename by checking for duplicates and adding a numeric suffix if needed. - Similar to how filesystems handle duplicates (e.g., file.txt, file (1).txt, file (2).txt). + Generate a unique filename by adding a numeric suffix if duplicates exist. + Always returns a unique filename - does not handle duplicate policies. Parameters: original_filename (str): The original filename as uploaded. - source_id (str): Source ID to check for duplicates within. + source (PydanticSource): Source to check for duplicates within. organization_id (str): Organization ID to check for duplicates within. Returns: - str: A unique filename with numeric suffix if needed. + str: A unique filename with source.name prefix and numeric suffix if needed. """ base, ext = os.path.splitext(original_filename) @@ -271,9 +280,44 @@ class FileManager: # No duplicates, return original filename with source.name return f"{source.name}/{original_filename}" else: - # Add numeric suffix + # Add numeric suffix to make unique return f"{source.name}/{base}_({count}){ext}" + @enforce_types + @trace_method + async def get_file_by_original_name_and_source( + self, original_filename: str, source_id: str, actor: PydanticUser + ) -> Optional[PydanticFileMetadata]: + """ + Get a file by its original filename and source ID. + + Parameters: + original_filename (str): The original filename to search for. 
+ source_id (str): The source ID to search within. + actor (PydanticUser): The actor performing the request. + + Returns: + Optional[PydanticFileMetadata]: The file metadata if found, None otherwise. + """ + async with db_registry.async_session() as session: + query = ( + select(FileMetadataModel) + .where( + FileMetadataModel.original_file_name == original_filename, + FileMetadataModel.source_id == source_id, + FileMetadataModel.organization_id == actor.organization_id, + FileMetadataModel.is_deleted == False, + ) + .limit(1) + ) + + result = await session.execute(query) + file_orm = result.scalar_one_or_none() + + if file_orm: + return await file_orm.to_pydantic_async() + return None + @enforce_types @trace_method async def get_organization_sources_metadata(self, actor: PydanticUser) -> OrganizationSourcesStats: diff --git a/letta/services/file_processor/file_processor.py b/letta/services/file_processor/file_processor.py index 07b324c9..4dec0ac2 100644 --- a/letta/services/file_processor/file_processor.py +++ b/letta/services/file_processor/file_processor.py @@ -1,6 +1,7 @@ from typing import List from letta.log import get_logger +from letta.otel.context import get_ctx_attributes from letta.otel.tracing import log_event, trace_method from letta.schemas.agent import AgentState from letta.schemas.enums import FileProcessingStatus @@ -122,6 +123,10 @@ class FileProcessor: if isinstance(content, str): content = content.encode("utf-8") + from letta.otel.metric_registry import MetricRegistry + + MetricRegistry().file_process_bytes_histogram.record(len(content), attributes=get_ctx_attributes()) + if len(content) > self.max_file_size: log_event( "file_processor.size_limit_exceeded", diff --git a/letta/services/helpers/tool_parser_helper.py b/letta/services/helpers/tool_parser_helper.py index f38de929..8bc5333b 100644 --- a/letta/services/helpers/tool_parser_helper.py +++ b/letta/services/helpers/tool_parser_helper.py @@ -1,7 +1,7 @@ import ast import base64 import pickle 
-from typing import Any +from typing import Any, Union from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM, SEND_MESSAGE_TOOL_NAME from letta.schemas.agent import AgentState @@ -9,7 +9,7 @@ from letta.schemas.response_format import ResponseFormatType, ResponseFormatUnio from letta.types import JsonDict, JsonValue -def parse_stdout_best_effort(text: str | bytes) -> tuple[Any, AgentState | None]: +def parse_stdout_best_effort(text: Union[str, bytes]) -> tuple[Any, AgentState | None]: """ Decode and unpickle the result from the function execution if possible. Returns (function_return_value, agent_state). diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 5b8e2693..dbba4a91 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -2,6 +2,7 @@ from functools import partial, reduce from operator import add from typing import List, Literal, Optional, Union +from httpx import AsyncClient, post from sqlalchemy import select from sqlalchemy.orm import Session @@ -95,6 +96,8 @@ class JobManager: @trace_method async def update_job_by_id_async(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob: """Update a job by its ID with the given JobUpdate object asynchronously.""" + callback_func = None + async with db_registry.async_session() as session: # Fetch the job by ID job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"]) @@ -114,11 +117,23 @@ class JobManager: logger.info(f"Current job completed at: {job.completed_at}") job.completed_at = get_utc_time().replace(tzinfo=None) if job.callback_url: - await self._dispatch_callback_async(job) + callback_func = self._dispatch_callback_async( + callback_url=job.callback_url, + payload={ + "job_id": job.id, + "status": job.status, + "completed_at": job.completed_at.isoformat() if job.completed_at else None, + "metadata": job.metadata_, + }, + actor=actor, + ) # Save the updated 
job to the database await job.update_async(db_session=session, actor=actor) + if callback_func: + return await callback_func + return job.to_pydantic() @enforce_types @@ -683,10 +698,8 @@ class JobManager: "metadata": job.metadata_, } try: - import httpx - log_event("POST callback dispatched", payload) - resp = httpx.post(job.callback_url, json=payload, timeout=5.0) + resp = post(job.callback_url, json=payload, timeout=5.0) log_event("POST callback finished") job.callback_sent_at = get_utc_time().replace(tzinfo=None) job.callback_status_code = resp.status_code @@ -700,31 +713,33 @@ class JobManager: # Continue silently - callback failures should not affect job completion @trace_method - async def _dispatch_callback_async(self, job: JobModel) -> None: + async def _dispatch_callback_async(self, callback_url: str, payload: dict, actor: PydanticUser) -> PydanticJob: """ POST a standard JSON payload to job.callback_url and record timestamp + HTTP status asynchronously. """ - payload = { - "job_id": job.id, - "status": job.status, - "completed_at": job.completed_at.isoformat() if job.completed_at else None, - "metadata": job.metadata_, - } + job_id = payload["job_id"] + callback_sent_at, callback_status_code, callback_error = None, None, None try: - import httpx - - async with httpx.AsyncClient() as client: + async with AsyncClient() as client: log_event("POST callback dispatched", payload) - resp = await client.post(job.callback_url, json=payload, timeout=5.0) + resp = await client.post(callback_url, json=payload, timeout=5.0) log_event("POST callback finished") # Ensure timestamp is timezone-naive for DB compatibility - job.callback_sent_at = get_utc_time().replace(tzinfo=None) - job.callback_status_code = resp.status_code + callback_sent_at = get_utc_time().replace(tzinfo=None) + callback_status_code = resp.status_code except Exception as e: - error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {e!s}" + error_message = f"Failed to 
dispatch callback for job {job_id} to {callback_url}: {e!s}" logger.error(error_message) # Record the failed attempt - job.callback_sent_at = get_utc_time().replace(tzinfo=None) - job.callback_error = error_message + callback_sent_at = get_utc_time().replace(tzinfo=None) + callback_error = error_message # Continue silently - callback failures should not affect job completion + + async with db_registry.async_session() as session: + job = await JobModel.read_async(db_session=session, identifier=job_id, actor=actor, access_type=AccessType.USER) + job.callback_sent_at = callback_sent_at + job.callback_status_code = callback_status_code + job.callback_error = callback_error + await job.update_async(db_session=session, actor=actor) + return job.to_pydantic() diff --git a/letta/services/tool_sandbox/local_sandbox.py b/letta/services/tool_sandbox/local_sandbox.py index 3f24fca1..5056adde 100644 --- a/letta/services/tool_sandbox/local_sandbox.py +++ b/letta/services/tool_sandbox/local_sandbox.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Optional from pydantic.config import JsonDict +from letta.log import get_logger from letta.otel.tracing import log_event, trace_method from letta.schemas.agent import AgentState from letta.schemas.sandbox_config import SandboxConfig, SandboxType @@ -23,6 +24,8 @@ from letta.services.tool_sandbox.base import AsyncToolSandboxBase from letta.settings import tool_settings from letta.utils import get_friendly_error_msg, parse_stderr_error_msg +logger = get_logger(__name__) + class AsyncToolSandboxLocal(AsyncToolSandboxBase): METADATA_CONFIG_STATE_KEY = "config_state" @@ -240,9 +243,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): if isinstance(e, TimeoutError): raise e - print(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}") - print(e.__class__.__name__) - print(e.__traceback__) + logger.error(f"Subprocess execution for tool {self.tool_name} encountered an error: {e}") + logger.error(e.__class__.__name__) + 
logger.error(e.__traceback__) func_return = get_friendly_error_msg( function_name=self.tool_name, exception_name=type(e).__name__, diff --git a/letta/templates/sandbox_code_file.py.j2 b/letta/templates/sandbox_code_file.py.j2 index 953b8ae8..3f4c4517 100644 --- a/letta/templates/sandbox_code_file.py.j2 +++ b/letta/templates/sandbox_code_file.py.j2 @@ -24,8 +24,32 @@ agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickl {{ tool_source_code }} {# Invoke the function and store the result in a global variable #} +_function_result = {{ invoke_function_call }} + +{# Use a temporary Pydantic wrapper to recursively serialize any nested Pydantic objects #} +try: + from pydantic import BaseModel + from typing import Any + + class _TempResultWrapper(BaseModel): + result: Any + + class Config: + arbitrary_types_allowed = True + + _wrapped = _TempResultWrapper(result=_function_result) + _serialized_result = _wrapped.model_dump()['result'] +except ImportError: + # Pydantic not available in sandbox, fall back to string conversion + print("Pydantic not available in sandbox environment, falling back to string conversion") + _serialized_result = str(_function_result) +except Exception as e: + # If wrapping fails, print the error and stringify the result + print(f"Failed to serialize result with Pydantic wrapper: {e}") + _serialized_result = str(_function_result) + {{ local_sandbox_result_var_name }} = { - "results": {{ invoke_function_call }}, + "results": _serialized_result, "agent_state": agent_state } diff --git a/letta/templates/sandbox_code_file_async.py.j2 b/letta/templates/sandbox_code_file_async.py.j2 index 6ed9cdbe..33c8971d 100644 --- a/letta/templates/sandbox_code_file_async.py.j2 +++ b/letta/templates/sandbox_code_file_async.py.j2 @@ -26,9 +26,32 @@ agent_state = {{ 'pickle.loads(' ~ agent_state_pickle ~ ')' if agent_state_pickl {# Async wrapper to handle the function call and store the result #} async def _async_wrapper(): - result = await {{ 
invoke_function_call }} + _function_result = await {{ invoke_function_call }} + + {# Use a temporary Pydantic wrapper to recursively serialize any nested Pydantic objects #} + try: + from pydantic import BaseModel + from typing import Any + + class _TempResultWrapper(BaseModel): + result: Any + + class Config: + arbitrary_types_allowed = True + + _wrapped = _TempResultWrapper(result=_function_result) + _serialized_result = _wrapped.model_dump()['result'] + except ImportError: + # Pydantic not available in sandbox, fall back to string conversion + print("Pydantic not available in sandbox environment, falling back to string conversion") + _serialized_result = str(_function_result) + except Exception as e: + # If wrapping fails, print the error and stringify the result + print(f"Failed to serialize result with Pydantic wrapper: {e}") + _serialized_result = str(_function_result) + return { - "results": result, + "results": _serialized_result, "agent_state": agent_state } diff --git a/poetry.lock b/poetry.lock index 763cf6a8..f771114b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -186,9 +186,10 @@ speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (> name = "aiohttp-retry" version = "2.9.1" description = "Simple retry client for aiohttp" -optional = false +optional = true python-versions = ">=3.7" groups = ["main"] +markers = "extra == \"pinecone\" or extra == \"all\"" files = [ {file = "aiohttp_retry-2.9.1-py3-none-any.whl", hash = "sha256:66d2759d1921838256a05a3f80ad7e724936f083e35be5abb5e16eed6be6dc54"}, {file = "aiohttp_retry-2.9.1.tar.gz", hash = "sha256:8eb75e904ed4ee5c2ec242fefe85bf04240f685391c4879d8f541d6028ff01f1"}, @@ -944,7 +945,7 @@ version = "2025.6.15" description = "Python package for providing Mozilla's CA Bundle." 
optional = false python-versions = ">=3.7" -groups = ["main", "dev"] +groups = ["main"] files = [ {file = "certifi-2025.6.15-py3-none-any.whl", hash = "sha256:2e0c7ce7cb5d8f8634ca55d2ba7e6ec2689a2fd6537d8dec1296a477a4910057"}, {file = "certifi-2025.6.15.tar.gz", hash = "sha256:d747aa5a8b9bbbb1bb8c22bb13e22bd1f18e9796defa16bab421f7f7a317323b"}, @@ -1050,7 +1051,7 @@ version = "3.4.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7" -groups = ["main", "dev"] +groups = ["main"] files = [ {file = "charset_normalizer-3.4.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7c48ed483eb946e6c04ccbe02c6b4d1d48e51944b6db70f697e089c193404941"}, {file = "charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2d318c11350e10662026ad0eb71bb51c7812fc8590825304ae0bdd4ac283acd"}, @@ -2792,7 +2793,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" -groups = ["main", "dev"] +groups = ["main"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -3590,14 +3591,14 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"] [[package]] name = "letta-client" -version = "0.1.191" +version = "0.1.198" description = "" optional = false python-versions = "<4.0,>=3.8" groups = ["main"] files = [ - {file = "letta_client-0.1.191-py3-none-any.whl", hash = "sha256:2cc234668784b022a25aeab4db48b944a0a188e42112870efd8b028ad223347b"}, - {file = "letta_client-0.1.191.tar.gz", hash = "sha256:95957695e679183ec0d87673c8dd169a83f9807359c4740d42ed84fbb6b05efc"}, + {file = "letta_client-0.1.198-py3-none-any.whl", hash = 
"sha256:08bbc238b128da2552b2a6e54feb3294794b5586e0962ce0bb95bb525109f58f"}, + {file = "letta_client-0.1.198.tar.gz", hash = "sha256:990c9132423e2955d9c7f7549e5064b2366616232d270e5927788cddba4ef9da"}, ] [package.dependencies] @@ -5207,9 +5208,10 @@ xmp = ["defusedxml"] name = "pinecone" version = "7.3.0" description = "Pinecone client and SDK" -optional = false +optional = true python-versions = "<4.0,>=3.9" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"pinecone\" or extra == \"all\"" files = [ {file = "pinecone-7.3.0-py3-none-any.whl", hash = "sha256:315b8fef20320bef723ecbb695dec0aafa75d8434d86e01e5a0e85933e1009a8"}, {file = "pinecone-7.3.0.tar.gz", hash = "sha256:307edc155621d487c20dc71b76c3ad5d6f799569ba42064190d03917954f9a7b"}, @@ -5236,9 +5238,10 @@ grpc = ["googleapis-common-protos (>=1.66.0)", "grpcio (>=1.44.0) ; python_versi name = "pinecone-plugin-assistant" version = "1.7.0" description = "Assistant plugin for Pinecone SDK" -optional = false +optional = true python-versions = "<4.0,>=3.9" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"pinecone\" or extra == \"all\"" files = [ {file = "pinecone_plugin_assistant-1.7.0-py3-none-any.whl", hash = "sha256:864cb8e7930588e6c2da97c6d44f0240969195f43fa303c5db76cbc12bf903a5"}, {file = "pinecone_plugin_assistant-1.7.0.tar.gz", hash = "sha256:e26e3ba10a8b71c3da0d777cff407668022e82963c4913d0ffeb6c552721e482"}, @@ -5252,9 +5255,10 @@ requests = ">=2.32.3,<3.0.0" name = "pinecone-plugin-interface" version = "0.0.7" description = "Plugin interface for the Pinecone python client" -optional = false +optional = true python-versions = "<4.0,>=3.8" -groups = ["main", "dev"] +groups = ["main"] +markers = "extra == \"pinecone\" or extra == \"all\"" files = [ {file = "pinecone_plugin_interface-0.0.7-py3-none-any.whl", hash = "sha256:875857ad9c9fc8bbc074dbe780d187a2afd21f5bfe0f3b08601924a61ef1bba8"}, {file = "pinecone_plugin_interface-0.0.7.tar.gz", hash = 
"sha256:b8e6675e41847333aa13923cc44daa3f85676d7157324682dc1640588a982846"}, @@ -6578,7 +6582,7 @@ version = "2.32.4" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["main"] files = [ {file = "requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c"}, {file = "requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422"}, @@ -7473,7 +7477,7 @@ files = [ {file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"}, {file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"}, ] -markers = {"dev,tests" = "python_version == \"3.10\""} +markers = {dev = "python_version < \"3.12\"", "dev,tests" = "python_version == \"3.10\""} [[package]] name = "typing-inspect" @@ -7542,7 +7546,7 @@ version = "2.5.0" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.9" -groups = ["main", "dev"] +groups = ["main"] files = [ {file = "urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc"}, {file = "urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760"}, @@ -8348,7 +8352,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["autoflake", "black", "docker", "fastapi", "granian", "isort", "langchain", "langchain-community", "locust", "pexpect", "pg8000", "pgvector", "pre-commit", "psycopg2", "psycopg2-binary", "pyright", "pytest-asyncio", "pytest-order", "redis", "uvicorn", "uvloop", "wikipedia"] +all = ["autoflake", "black", "docker", "fastapi", "granian", "isort", "langchain", "langchain-community", "locust", "pexpect", "pg8000", "pgvector", "pinecone", "pre-commit", "psycopg2", "psycopg2-binary", "pyright", "pytest-asyncio", "pytest-order", "redis", "uvicorn", "uvloop", "wikipedia"] bedrock = ["aioboto3", "boto3"] cloud-tool-sandbox = ["e2b-code-interpreter"] desktop = ["docker", "fastapi", "langchain", "langchain-community", "locust", "pg8000", "pgvector", "psycopg2", "psycopg2-binary", "pyright", "uvicorn", "wikipedia"] @@ -8356,6 +8360,7 @@ dev = ["autoflake", "black", "isort", "locust", "pexpect", "pre-commit", "pyrigh experimental = ["granian", "uvloop"] external-tools = ["docker", "firecrawl-py", "langchain", "langchain-community", "wikipedia"] google = ["google-genai"] +pinecone = ["pinecone"] postgres = ["asyncpg", "pg8000", "pgvector", "psycopg2", "psycopg2-binary"] redis = ["redis"] server = ["fastapi", "uvicorn"] @@ -8364,4 +8369,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.1" python-versions = "<3.14,>=3.10" -content-hash = "b2f23b566c52a2ecb9d656383d00f8ac64293690e0f6fcc0d354ed2d2ad8807b" +content-hash = "c4fa225d582dac743e5eb8a1338a71c32bd8c8eb8283dec2331c28599bdf7698" diff --git 
a/pyproject.toml b/pyproject.toml index 6610bbe0..ac4991af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.8.12" +version = "0.8.13" packages = [ {include = "letta"}, ] @@ -72,7 +72,7 @@ llama-index = "^0.12.2" llama-index-embeddings-openai = "^0.3.1" e2b-code-interpreter = {version = "^1.0.3", optional = true} anthropic = "^0.49.0" -letta_client = "^0.1.183" +letta_client = "^0.1.197" openai = "^1.60.0" opentelemetry-api = "1.30.0" opentelemetry-sdk = "1.30.0" @@ -98,13 +98,13 @@ redis = {version = "^6.2.0", optional = true} structlog = "^25.4.0" certifi = "^2025.6.15" aioboto3 = {version = "^14.3.0", optional = true} -pinecone = {extras = ["asyncio"], version = "^7.3.0"} -aiosqlite = "^0.21.0" +pinecone = {extras = ["asyncio"], version = "^7.3.0", optional = true} [tool.poetry.extras] postgres = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "asyncpg"] redis = ["redis"] +pinecone = ["pinecone"] dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "locust"] experimental = ["uvloop", "granian"] server = ["websockets", "fastapi", "uvicorn"] @@ -114,7 +114,7 @@ tests = ["wikipedia"] bedrock = ["boto3", "aioboto3"] google = ["google-genai"] desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"] -all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis"] +all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", 
"langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis", "pinecone"] [tool.poetry.group.dev.dependencies] @@ -122,8 +122,6 @@ black = "^24.4.2" ipykernel = "^6.29.5" ipdb = "^0.13.13" pytest-mock = "^3.14.0" -pinecone = "^7.3.0" - [tool.poetry.group."dev,tests".dependencies] pytest-json-report = "^1.5.0" diff --git a/tests/conftest.py b/tests/conftest.py index 0abe389d..2d5eb88e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,32 @@ def disable_e2b_api_key() -> Generator[None, None, None]: tool_settings.e2b_api_key = original_api_key +@pytest.fixture +def e2b_sandbox_mode(request) -> Generator[None, None, None]: + """ + Parametrizable fixture to enable/disable E2B sandbox mode. + + Usage: + @pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True) + def test_function(e2b_sandbox_mode, ...): + # Test runs twice - once with E2B enabled, once disabled + """ + from letta.settings import tool_settings + + enable_e2b = request.param + original_api_key = tool_settings.e2b_api_key + + if not enable_e2b: + # Disable E2B by setting API key to None + tool_settings.e2b_api_key = None + # If enable_e2b is True, leave the original API key unchanged + + yield + + # Restore original API key + tool_settings.e2b_api_key = original_api_key + + @pytest.fixture def disable_pinecone() -> Generator[None, None, None]: """ diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 3894981b..9f6a250e 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -752,13 +752,14 @@ def test_step_stream_agent_loop_error( """ last_message = client.agents.messages.list(agent_id=agent_state_no_tools.id, limit=1) agent_state_no_tools = client.agents.modify(agent_id=agent_state_no_tools.id, llm_config=llm_config) - with pytest.raises(ApiError): + with pytest.raises(Exception) as exc_info: response = client.agents.messages.create_stream( 
agent_id=agent_state_no_tools.id, messages=USER_MESSAGE_FORCE_REPLY, ) list(response) - + assert type(exc_info.value) in (ApiError, ValueError) + print(exc_info.value) messages_from_db = client.agents.messages.list(agent_id=agent_state_no_tools.id, after=last_message[0].id) assert len(messages_from_db) == 0 diff --git a/tests/test_managers.py b/tests/test_managers.py index 8b88000a..2048533e 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -8,8 +8,6 @@ import time from datetime import datetime, timedelta, timezone from typing import List -import httpx - # tests/test_file_content_flow.py import pytest from _pytest.python_api import approx @@ -3902,7 +3900,7 @@ async def test_bulk_update_return_hydrated_true(server: SyncServer, default_user mgr = BlockManager() # create a block - b = mgr.create_or_update_block( + b = await mgr.create_or_update_block_async( PydanticBlock(label="persona", value="foo", limit=20), actor=default_user, ) @@ -5117,6 +5115,137 @@ async def test_delete_cascades_to_content(server, default_user, default_source, assert await _count_file_content_rows(async_session, created.id) == 0 +@pytest.mark.asyncio +async def test_get_file_by_original_name_and_source_found(server: SyncServer, default_user, default_source): + """Test retrieving a file by original filename and source when it exists.""" + original_filename = "test_original_file.txt" + file_metadata = PydanticFileMetadata( + file_name="some_generated_name.txt", + original_file_name=original_filename, + file_path="/path/to/test_file.txt", + file_type="text/plain", + file_size=1024, + source_id=default_source.id, + ) + created_file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user) + + # Retrieve the file by original name and source + retrieved_file = await server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=default_source.id, actor=default_user + ) + + # Assertions to verify the 
retrieved file matches the created one + assert retrieved_file is not None + assert retrieved_file.id == created_file.id + assert retrieved_file.original_file_name == original_filename + assert retrieved_file.source_id == default_source.id + + +@pytest.mark.asyncio +async def test_get_file_by_original_name_and_source_not_found(server: SyncServer, default_user, default_source): + """Test retrieving a file by original filename and source when it doesn't exist.""" + non_existent_filename = "does_not_exist.txt" + + # Try to retrieve a non-existent file + retrieved_file = await server.file_manager.get_file_by_original_name_and_source( + original_filename=non_existent_filename, source_id=default_source.id, actor=default_user + ) + + # Should return None for non-existent file + assert retrieved_file is None + + +@pytest.mark.asyncio +async def test_get_file_by_original_name_and_source_different_sources(server: SyncServer, default_user, default_source): + """Test that files with same original name in different sources are handled correctly.""" + from letta.schemas.source import Source as PydanticSource + + # Create a second source + second_source_pydantic = PydanticSource( + name="second_test_source", + description="This is a test source.", + metadata={"type": "test"}, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ) + second_source = await server.source_manager.create_source(source=second_source_pydantic, actor=default_user) + + original_filename = "shared_filename.txt" + + # Create file in first source + file_metadata_1 = PydanticFileMetadata( + file_name="file_in_source_1.txt", + original_file_name=original_filename, + file_path="/path/to/file1.txt", + file_type="text/plain", + file_size=1024, + source_id=default_source.id, + ) + created_file_1 = await server.file_manager.create_file(file_metadata=file_metadata_1, actor=default_user) + + # Create file with same original name in second source + file_metadata_2 = PydanticFileMetadata( + file_name="file_in_source_2.txt", + 
original_file_name=original_filename, + file_path="/path/to/file2.txt", + file_type="text/plain", + file_size=2048, + source_id=second_source.id, + ) + created_file_2 = await server.file_manager.create_file(file_metadata=file_metadata_2, actor=default_user) + + # Retrieve file from first source + retrieved_file_1 = await server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=default_source.id, actor=default_user + ) + + # Retrieve file from second source + retrieved_file_2 = await server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=second_source.id, actor=default_user + ) + + # Should retrieve different files + assert retrieved_file_1 is not None + assert retrieved_file_2 is not None + assert retrieved_file_1.id == created_file_1.id + assert retrieved_file_2.id == created_file_2.id + assert retrieved_file_1.id != retrieved_file_2.id + assert retrieved_file_1.source_id == default_source.id + assert retrieved_file_2.source_id == second_source.id + + +@pytest.mark.asyncio +async def test_get_file_by_original_name_and_source_ignores_deleted(server: SyncServer, default_user, default_source): + """Test that deleted files are ignored when searching by original name and source.""" + original_filename = "to_be_deleted.txt" + file_metadata = PydanticFileMetadata( + file_name="deletable_file.txt", + original_file_name=original_filename, + file_path="/path/to/deletable.txt", + file_type="text/plain", + file_size=512, + source_id=default_source.id, + ) + created_file = await server.file_manager.create_file(file_metadata=file_metadata, actor=default_user) + + # Verify file can be found before deletion + retrieved_file = await server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=default_source.id, actor=default_user + ) + assert retrieved_file is not None + assert retrieved_file.id == created_file.id + + # Delete the file 
+ await server.file_manager.delete_file(created_file.id, actor=default_user) + + # Try to retrieve the deleted file + retrieved_file_after_delete = await server.file_manager.get_file_by_original_name_and_source( + original_filename=original_filename, source_id=default_source.id, actor=default_user + ) + + # Should return None for deleted file + assert retrieved_file_after_delete is None + + @pytest.mark.asyncio async def test_list_files(server: SyncServer, default_user, default_source): """Test listing files with pagination.""" @@ -5870,7 +5999,9 @@ async def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user): return await mock_post(url, json, timeout) # Patch the AsyncClient - monkeypatch.setattr(httpx, "AsyncClient", MockAsyncClient) + import letta.services.job_manager as job_manager_module + + monkeypatch.setattr(job_manager_module, "AsyncClient", MockAsyncClient) job_in = PydanticJob(status=JobStatus.created, metadata={"foo": "bar"}, callback_url="http://example.test/webhook/jobs") created = await server.job_manager.create_job_async(pydantic_job=job_in, actor=default_user) @@ -7813,3 +7944,231 @@ async def test_attach_files_bulk_oversized_bulk(server, default_user, sarah_agen # All files should be attached (some open, some closed) all_files_after = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=default_user) assert len(all_files_after) == MAX_FILES_OPEN + 3 + + +# ====================================================================================================================== +# Race Condition Tests - Blocks +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_concurrent_block_updates_race_condition( + server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop +): + """Test that concurrent block updates don't cause race conditions.""" + agent, _ = 
comprehensive_test_agent_fixture + + # Create multiple blocks to use in concurrent updates + blocks = [] + for i in range(5): + block = await server.block_manager.create_or_update_block_async( + PydanticBlock(label=f"test_block_{i}", value=f"Test block content {i}", limit=1000), actor=default_user + ) + blocks.append(block) + + # Test concurrent updates with different block combinations + async def update_agent_blocks(block_subset): + """Update agent with a specific subset of blocks.""" + update_request = UpdateAgent(block_ids=[b.id for b in block_subset]) + try: + return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user) + except Exception as e: + # Capture any errors that occur during concurrent updates + return {"error": str(e)} + + # Run concurrent updates with different block combinations + tasks = [ + update_agent_blocks(blocks[:2]), # blocks 0, 1 + update_agent_blocks(blocks[1:3]), # blocks 1, 2 + update_agent_blocks(blocks[2:4]), # blocks 2, 3 + update_agent_blocks(blocks[3:5]), # blocks 3, 4 + update_agent_blocks(blocks[:1]), # block 0 only + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Verify no exceptions occurred + errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)] + assert len(errors) == 0, f"Concurrent updates failed with errors: {errors}" + + # Verify all results are valid agent states + valid_results = [r for r in results if not isinstance(r, Exception) and not (isinstance(r, dict) and "error" in r)] + assert len(valid_results) == 5, "All concurrent updates should succeed" + + # Verify final state is consistent + final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user) + assert final_agent is not None + assert len(final_agent.memory.blocks) > 0 + + # Clean up + for block in blocks: + await server.block_manager.delete_block_async(block.id, actor=default_user) + + +@pytest.mark.asyncio +async def 
test_concurrent_same_block_updates_race_condition( + server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop +): + """Test that multiple concurrent updates to the same block configuration don't cause issues.""" + agent, _ = comprehensive_test_agent_fixture + + # Create a single block configuration to use in all updates + block = await server.block_manager.create_or_update_block_async( + PydanticBlock(label="shared_block", value="Shared block content", limit=1000), actor=default_user + ) + + # Test multiple concurrent updates with the same block configuration + async def update_agent_with_same_blocks(): + """Update agent with the same block configuration.""" + update_request = UpdateAgent(block_ids=[block.id]) + try: + return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user) + except Exception as e: + return {"error": str(e)} + + # Run 10 concurrent identical updates + tasks = [update_agent_with_same_blocks() for _ in range(10)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Verify no exceptions occurred + errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)] + assert len(errors) == 0, f"Concurrent identical updates failed with errors: {errors}" + + # Verify final state is consistent + final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user) + assert len(final_agent.memory.blocks) == 1 + assert final_agent.memory.blocks[0].id == block.id + + # Clean up + await server.block_manager.delete_block_async(block.id, actor=default_user) + + +@pytest.mark.asyncio +async def test_concurrent_empty_block_updates_race_condition( + server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop +): + """Test concurrent updates that remove all blocks.""" + agent, _ = comprehensive_test_agent_fixture + + # Test concurrent updates that clear all blocks + async def 
clear_agent_blocks(): + """Update agent to have no blocks.""" + update_request = UpdateAgent(block_ids=[]) + try: + return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user) + except Exception as e: + return {"error": str(e)} + + # Run concurrent clear operations + tasks = [clear_agent_blocks() for _ in range(5)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Verify no exceptions occurred + errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)] + assert len(errors) == 0, f"Concurrent clear operations failed with errors: {errors}" + + # Verify final state is consistent (no blocks) + final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user) + assert len(final_agent.memory.blocks) == 0 + + +@pytest.mark.asyncio +async def test_concurrent_mixed_block_operations_race_condition( + server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop +): + """Test mixed concurrent operations: some adding blocks, some removing.""" + agent, _ = comprehensive_test_agent_fixture + + # Create test blocks + blocks = [] + for i in range(3): + block = await server.block_manager.create_or_update_block_async( + PydanticBlock(label=f"mixed_block_{i}", value=f"Mixed block content {i}", limit=1000), actor=default_user + ) + blocks.append(block) + + # Mix of operations: add blocks, remove blocks, clear all + async def mixed_operation(operation_type): + """Perform different types of block operations.""" + if operation_type == "add_all": + update_request = UpdateAgent(block_ids=[b.id for b in blocks]) + elif operation_type == "add_subset": + update_request = UpdateAgent(block_ids=[blocks[0].id]) + elif operation_type == "clear": + update_request = UpdateAgent(block_ids=[]) + else: + update_request = UpdateAgent(block_ids=[blocks[1].id, blocks[2].id]) + + try: + return await server.agent_manager.update_agent_async(agent.id, 
update_request, actor=default_user) + except Exception as e: + return {"error": str(e)} + + # Run mixed concurrent operations + tasks = [ + mixed_operation("add_all"), + mixed_operation("add_subset"), + mixed_operation("clear"), + mixed_operation("add_two"), + mixed_operation("add_all"), + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Verify no exceptions occurred + errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)] + assert len(errors) == 0, f"Mixed concurrent operations failed with errors: {errors}" + + # Verify final state is consistent (any valid state is acceptable) + final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user) + assert final_agent is not None + + # Clean up + for block in blocks: + await server.block_manager.delete_block_async(block.id, actor=default_user) + + +@pytest.mark.asyncio +async def test_high_concurrency_stress_test(server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop): + """Stress test with high concurrency to catch race conditions.""" + agent, _ = comprehensive_test_agent_fixture + + # Create many blocks for stress testing + blocks = [] + for i in range(10): + block = await server.block_manager.create_or_update_block_async( + PydanticBlock(label=f"stress_block_{i}", value=f"Stress test content {i}", limit=1000), actor=default_user + ) + blocks.append(block) + + # Create many concurrent update tasks + async def stress_update(task_id): + """Perform a random block update operation.""" + import random + + # Random subset of blocks + num_blocks = random.randint(0, len(blocks)) + selected_blocks = random.sample(blocks, num_blocks) + + update_request = UpdateAgent(block_ids=[b.id for b in selected_blocks]) + + try: + return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user) + except Exception as e: + return {"error": str(e), "task_id": task_id} + + # 
Run 20 concurrent stress updates + tasks = [stress_update(i) for i in range(20)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Verify no exceptions occurred + errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)] + assert len(errors) == 0, f"High concurrency stress test failed with errors: {errors}" + + # Verify final state is consistent + final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user) + assert final_agent is not None + + # Clean up + for block in blocks: + await server.block_manager.delete_block_async(block.id, actor=default_user) diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 46ed45e0..7ff37851 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -1,15 +1,19 @@ +import json import os import threading import time import uuid +from typing import List, Type import pytest from dotenv import load_dotenv from letta_client import CreateBlock from letta_client import Letta as LettaSDKClient -from letta_client import MessageCreate +from letta_client import LettaRequest, MessageCreate, TextContent +from letta_client.client import BaseTool from letta_client.core import ApiError from letta_client.types import AgentState, ToolReturnMessage +from pydantic import BaseModel, Field # Constants SERVER_PORT = 8283 @@ -762,3 +766,270 @@ def test_base_tools_upsert_on_list(client: LettaSDKClient): final_tool_names = {tool.name for tool in final_tools} for deleted_tool in tools_to_delete: assert deleted_tool.name in final_tool_names, f"Deleted tool {deleted_tool.name} was not properly restored" + + +@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True) +def test_pydantic_inventory_management_tool(e2b_sandbox_mode, client: LettaSDKClient): + class InventoryItem(BaseModel): + sku: str + name: str + price: float + category: str + + class InventoryEntry(BaseModel): + timestamp: int + item: InventoryItem + 
transaction_id: str + + class InventoryEntryData(BaseModel): + data: InventoryEntry + quantity_change: int + + class ManageInventoryTool(BaseTool): + name: str = "manage_inventory" + args_schema: Type[BaseModel] = InventoryEntryData + description: str = "Update inventory catalogue with a new data entry" + tags: List[str] = ["inventory", "shop"] + + def run(self, data: InventoryEntry, quantity_change: int) -> bool: + print(f"Updated inventory for {data.item.name} with a quantity change of {quantity_change}") + return True + + tool = client.tools.add( + tool=ManageInventoryTool(), + ) + + assert tool is not None + assert tool.name == "manage_inventory" + assert "inventory" in tool.tags + assert "shop" in tool.tags + + temp_agent = client.agents.create( + memory_blocks=[ + CreateBlock( + label="persona", + value="You are a helpful inventory management assistant.", + ), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + tool_ids=[tool.id], + include_base_tools=False, + ) + + response = client.agents.messages.create( + agent_id=temp_agent.id, + messages=[ + MessageCreate( + role="user", + content="Update the inventory for product 'iPhone 15' with SKU 'IPH15-001', price $999.99, category 'Electronics', transaction ID 'TXN-12345', timestamp 1640995200, with a quantity change of +10", + ), + ], + ) + + assert response is not None + + tool_call_messages = [msg for msg in response.messages if msg.message_type == "tool_call_message"] + assert len(tool_call_messages) > 0, "Expected at least one tool call message" + + first_tool_call = tool_call_messages[0] + assert first_tool_call.tool_call.name == "manage_inventory" + + args = json.loads(first_tool_call.tool_call.arguments) + assert "data" in args + assert "quantity_change" in args + assert "item" in args["data"] + assert "name" in args["data"]["item"] + assert "sku" in args["data"]["item"] + assert "price" in args["data"]["item"] + assert "category" in args["data"]["item"] + assert 
"transaction_id" in args["data"] + assert "timestamp" in args["data"] + + tool_return_messages = [msg for msg in response.messages if msg.message_type == "tool_return_message"] + assert len(tool_return_messages) > 0, "Expected at least one tool return message" + + first_tool_return = tool_return_messages[0] + assert first_tool_return.status == "success" + assert first_tool_return.tool_return == "True" + assert "Updated inventory for iPhone 15 with a quantity change of 10" in "\n".join(first_tool_return.stdout) + + client.agents.delete(temp_agent.id) + client.tools.delete(tool.id) + + +@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True) +def test_pydantic_task_planning_tool(e2b_sandbox_mode, client: LettaSDKClient): + + class Step(BaseModel): + name: str = Field(..., description="Name of the step.") + description: str = Field(..., description="An exhaustive description of what this step is trying to achieve.") + + class StepsList(BaseModel): + steps: List[Step] = Field(..., description="List of steps to add to the task plan.") + explanation: str = Field(..., description="Explanation for the list of steps.") + + def create_task_plan(steps, explanation): + """Creates a task plan for the current task.""" + print(f"Created task plan with {len(steps)} steps: {explanation}") + return steps + + tool = client.tools.upsert_from_function(func=create_task_plan, args_schema=StepsList, tags=["planning", "task", "pydantic_test"]) + + assert tool is not None + assert tool.name == "create_task_plan" + assert "planning" in tool.tags + assert "task" in tool.tags + + temp_agent = client.agents.create( + memory_blocks=[ + CreateBlock( + label="persona", + value="You are a helpful task planning assistant.", + ), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-3-small", + tool_ids=[tool.id], + include_base_tools=False, + ) + + response = client.agents.messages.create( + agent_id=temp_agent.id, + messages=[ + MessageCreate( + role="user", + 
content="Create a task plan for organizing a team meeting with 3 steps: 1) Schedule meeting (find available time slots), 2) Send invitations (notify all team members), 3) Prepare agenda (outline discussion topics). Explanation: This plan ensures a well-organized team meeting.", + ), + ], + ) + + assert response is not None + assert hasattr(response, "messages") + assert len(response.messages) > 0 + + tool_call_messages = [msg for msg in response.messages if msg.message_type == "tool_call_message"] + assert len(tool_call_messages) > 0, "Expected at least one tool call message" + + first_tool_call = tool_call_messages[0] + assert first_tool_call.tool_call.name == "create_task_plan" + + args = json.loads(first_tool_call.tool_call.arguments) + assert "steps" in args + assert "explanation" in args + assert isinstance(args["steps"], list) + assert len(args["steps"]) > 0 + + for step in args["steps"]: + assert "name" in step + assert "description" in step + + tool_return_messages = [msg for msg in response.messages if msg.message_type == "tool_return_message"] + assert len(tool_return_messages) > 0, "Expected at least one tool return message" + + first_tool_return = tool_return_messages[0] + assert first_tool_return.status == "success" + + client.agents.delete(temp_agent.id) + client.tools.delete(tool.id) + + +@pytest.mark.parametrize("e2b_sandbox_mode", [True, False], indirect=True) +def test_create_tool_from_function_with_docstring(e2b_sandbox_mode, client: LettaSDKClient): + """Test creating a tool from a function with a docstring using create_from_function""" + + def roll_dice() -> str: + """ + Simulate the roll of a 20-sided die (d20). + + This function generates a random integer between 1 and 20, inclusive, + which represents the outcome of a single roll of a d20. + + Returns: + str: The result of the die roll. 
+ """ + import random + + dice_role_outcome = random.randint(1, 20) + output_string = f"You rolled a {dice_role_outcome}" + return output_string + + tool = client.tools.create_from_function(func=roll_dice) + + assert tool is not None + assert tool.name == "roll_dice" + assert "Simulate the roll of a 20-sided die" in tool.description + assert tool.source_code is not None + assert "random.randint(1, 20)" in tool.source_code + + all_tools = client.tools.list() + tool_names = [t.name for t in all_tools] + assert "roll_dice" in tool_names + + client.tools.delete(tool.id) + + +def test_preview_payload(client: LettaSDKClient, agent): + payload = client.agents.messages.preview_raw_payload( + agent_id=agent.id, + request=LettaRequest( + messages=[ + MessageCreate( + role="user", + content=[ + TextContent( + text="text", + ) + ], + ) + ], + ), + ) + + assert isinstance(payload, dict) + assert "model" in payload + assert "messages" in payload + assert "tools" in payload + assert "frequency_penalty" in payload + assert "max_completion_tokens" in payload + assert "temperature" in payload + assert "user" in payload + assert "parallel_tool_calls" in payload + assert "tool_choice" in payload + + assert payload["model"] == "gpt-4o-mini" + + assert isinstance(payload["messages"], list) + assert len(payload["messages"]) >= 3 + + system_message = payload["messages"][0] + assert system_message["role"] == "system" + assert "base_instructions" in system_message["content"] + assert "memory_blocks" in system_message["content"] + assert "tool_usage_rules" in system_message["content"] + assert "Letta" in system_message["content"] + + assert isinstance(payload["tools"], list) + assert len(payload["tools"]) > 0 + + tool_names = [tool["function"]["name"] for tool in payload["tools"]] + expected_tools = ["send_message", "conversation_search", "core_memory_replace", "core_memory_append"] + for tool_name in expected_tools: + assert tool_name in tool_names, f"Expected tool {tool_name} not found in 
tools" + + for tool in payload["tools"]: + assert tool["type"] == "function" + assert "function" in tool + assert "name" in tool["function"] + assert "description" in tool["function"] + assert "parameters" in tool["function"] + assert tool["function"]["strict"] is True + + assert payload["frequency_penalty"] == 1.0 + assert payload["max_completion_tokens"] == 4096 + assert payload["temperature"] == 0.7 + assert payload["parallel_tool_calls"] is False + assert payload["tool_choice"] == "required" + assert payload["user"].startswith("user-") + + print(payload) diff --git a/tests/test_sources.py b/tests/test_sources.py index 6a1a4af2..758fc368 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -7,18 +7,37 @@ import pytest from dotenv import load_dotenv from letta_client import CreateBlock from letta_client import Letta as LettaSDKClient +from letta_client import LettaRequest +from letta_client import MessageCreate as ClientMessageCreate from letta_client.types import AgentState from letta.constants import DEFAULT_ORG_ID, FILES_TOOLS from letta.orm.enums import ToolType from letta.schemas.message import MessageCreate from letta.schemas.user import User +from letta.settings import settings from tests.utils import wait_for_server # Constants SERVER_PORT = 8283 +def get_raw_system_message(client: LettaSDKClient, agent_id: str) -> str: + """Helper function to get the raw system message from an agent's preview payload.""" + raw_payload = client.agents.messages.preview_raw_payload( + agent_id=agent_id, + request=LettaRequest( + messages=[ + ClientMessageCreate( + role="user", + content="Testing", + ) + ], + ), + ) + return raw_payload["messages"][0]["content"] + + @pytest.fixture(autouse=True) def clear_sources(client: LettaSDKClient): # Clear existing sources @@ -172,6 +191,10 @@ def test_file_upload_creates_source_blocks_correctly( expected_value: str, expected_label_regex: str, ): + # skip pdf tests if mistral api key is missing + if 
file_path.endswith(".pdf") and not settings.mistral_api_key: + pytest.skip("mistral api key required for pdf processing") + # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") assert len(client.sources.list()) == 1 @@ -195,6 +218,15 @@ def test_file_upload_creates_source_blocks_correctly( assert any(b.value.startswith("[Viewing file start") for b in blocks) assert any(re.fullmatch(expected_label_regex, b.label) for b in blocks) + # verify raw system message contains source information + raw_system_message = get_raw_system_message(client, agent_state.id) + assert "test_source" in raw_system_message + assert "" in raw_system_message + # verify file-specific details in raw system message + file_name = files[0].file_name + assert f'name="test_source/{file_name}"' in raw_system_message + assert 'status="open"' in raw_system_message + # Remove file from source client.sources.files.delete(source_id=source.id, file_id=files[0].id) @@ -205,6 +237,14 @@ def test_file_upload_creates_source_blocks_correctly( assert not any(expected_value in b.value for b in blocks) assert not any(re.fullmatch(expected_label_regex, b.label) for b in blocks) + # verify raw system message no longer contains source information + raw_system_message_after_removal = get_raw_system_message(client, agent_state.id) + # this should be in, because we didn't delete the source + assert "test_source" in raw_system_message_after_removal + assert "" in raw_system_message_after_removal + # verify file-specific details are also removed + assert f'name="test_source/{file_name}"' not in raw_system_message_after_removal + def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): # Create a new source @@ -224,6 +264,25 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, # Attach after uploading the file 
client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) + raw_system_message = get_raw_system_message(client, agent_state.id) + + # Assert that the expected chunk is in the raw system message + expected_chunk = """ + + + +- read_only=true +- chars_current=46 +- chars_limit=50000 + + +[Viewing file start (out of 1 chunks)] +1: test + + + +""" + assert expected_chunk in raw_system_message # Get the agent state, check blocks exist agent_state = client.agents.retrieve(agent_id=agent_state.id) @@ -241,20 +300,46 @@ def test_attach_existing_files_creates_source_blocks_correctly(disable_pinecone, assert len(blocks) == 0 assert not any("test" in b.value for b in blocks) + # Verify no traces of the prompt exist in the raw system message after detaching + raw_system_message_after_detach = get_raw_system_message(client, agent_state.id) + assert expected_chunk not in raw_system_message_after_detach + assert "test_source" not in raw_system_message_after_detach + assert "" not in raw_system_message_after_detach + def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: LettaSDKClient, agent_state: AgentState): # Create a new source source = client.sources.create(name="test_source", embedding="openai/text-embedding-3-small") assert len(client.sources.list()) == 1 - # Attach client.agents.sources.attach(source_id=source.id, agent_id=agent_state.id) + raw_system_message = get_raw_system_message(client, agent_state.id) + assert "test_source" in raw_system_message + assert "" in raw_system_message # Load files into the source file_path = "tests/data/test.txt" # Upload the files upload_file_and_wait(client, source.id, file_path) + raw_system_message = get_raw_system_message(client, agent_state.id) + # Assert that the expected chunk is in the raw system message + expected_chunk = """ + + + +- read_only=true +- chars_current=46 +- chars_limit=50000 + + +[Viewing file start (out of 1 chunks)] +1: test + + + +""" + assert expected_chunk in 
raw_system_message # Get the agent state, check blocks exist agent_state = client.agents.retrieve(agent_id=agent_state.id) @@ -264,6 +349,10 @@ def test_delete_source_removes_source_blocks_correctly(disable_pinecone, client: # Remove file from source client.sources.delete(source_id=source.id) + raw_system_message_after_detach = get_raw_system_message(client, agent_state.id) + assert expected_chunk not in raw_system_message_after_detach + assert "test_source" not in raw_system_message_after_detach + assert "" not in raw_system_message_after_detach # Get the agent state, check blocks do NOT exist agent_state = client.agents.retrieve(agent_id=agent_state.id)