diff --git a/Dockerfile b/Dockerfile index 92b5e340..437e1730 100644 --- a/Dockerfile +++ b/Dockerfile @@ -64,8 +64,7 @@ ENV LETTA_ENVIRONMENT=${LETTA_ENVIRONMENT} \ POSTGRES_USER=letta \ POSTGRES_PASSWORD=letta \ POSTGRES_DB=letta \ - COMPOSIO_DISABLE_VERSION_CHECK=true \ - LETTA_OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4317" + COMPOSIO_DISABLE_VERSION_CHECK=true WORKDIR /app diff --git a/compose.yaml b/compose.yaml index d7ce6e6d..322bdb29 100644 --- a/compose.yaml +++ b/compose.yaml @@ -49,6 +49,7 @@ services: - VLLM_API_BASE=${VLLM_API_BASE} - OPENLLM_AUTH_TYPE=${OPENLLM_AUTH_TYPE} - OPENLLM_API_KEY=${OPENLLM_API_KEY} + - LETTA_OTEL_EXPORTER_OTLP_ENDPOINT=${LETTA_OTEL_EXPORTER_OTLP_ENDPOINT} - CLICKHOUSE_ENDPOINT=${CLICKHOUSE_ENDPOINT} - CLICKHOUSE_DATABASE=${CLICKHOUSE_DATABASE} - CLICKHOUSE_USERNAME=${CLICKHOUSE_USERNAME} diff --git a/letta/__init__.py b/letta/__init__.py index 95abbdfe..9259a2eb 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.7.4" +__version__ = "0.7.5" # import clients from letta.client.client import LocalClient, RESTClient, create_client diff --git a/letta/agent.py b/letta/agent.py index cc035edf..7de5b69c 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -190,16 +190,15 @@ class Agent(BaseAgent): Returns: modified (bool): whether the memory was updated """ - if self.agent_state.memory.compile() != new_memory.compile(): + system_message = self.message_manager.get_message_by_id(message_id=self.agent_state.message_ids[0], actor=self.user) + if new_memory.compile() not in system_message.content[0].text: # update the blocks (LRW) in the DB for label in self.agent_state.memory.list_block_labels(): updated_value = new_memory.get_block(label).value if updated_value != self.agent_state.memory.get_block(label).value: # update the block if it's changed block_id = self.agent_state.memory.get_block(label).id - block = self.block_manager.update_block( - block_id=block_id, 
block_update=BlockUpdate(value=updated_value), actor=self.user - ) + self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=self.user) # refresh memory from DB (using block ids) self.agent_state.memory = Memory( diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index cbe042bc..90a025a9 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -1233,7 +1233,10 @@ class AzureProvider(Provider): """ This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model. """ - return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096) + context_window = AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, None) + if context_window is None: + context_window = LLM_MAX_TOKENS.get(model_name, 4096) + return context_window class VLLMChatCompletionsProvider(Provider): diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index e8571fa5..bd03348e 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -104,6 +104,17 @@ def list_agents( ) +@router.get("/count", response_model=int, operation_id="count_agents") +def count_agents( + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Get the count of all agents associated with a given user. 
+ """ + return server.agent_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) + + class IndentedORJSONResponse(Response): media_type = "application/json" diff --git a/letta/server/rest_api/routers/v1/identities.py b/letta/server/rest_api/routers/v1/identities.py index 5f081600..dd48fd4e 100644 --- a/letta/server/rest_api/routers/v1/identities.py +++ b/letta/server/rest_api/routers/v1/identities.py @@ -49,6 +49,24 @@ def list_identities( return identities +@router.get("/count", tags=["identities"], response_model=int, operation_id="count_identities") +def count_identities( + server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Get count of all identities for a user + """ + try: + return server.identity_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) + except NoResultFound: + return 0 + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + + @router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity") def retrieve_identity( identity_id: str, diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 5f08b3ea..ac91d69b 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -67,6 +67,17 @@ def list_sources( return server.list_all_sources(actor=actor) +@router.get("/count", response_model=int, operation_id="count_sources") +def count_sources( + server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present +): + """ + Count all data sources created by a user. 
+ """ + return server.source_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) + + @router.post("/", response_model=Source, operation_id="create_source") def create_source( source_create: SourceCreate, diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index b1d95386..06482175 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -80,6 +80,21 @@ def list_tools( raise HTTPException(status_code=500, detail=str(e)) +@router.get("/count", response_model=int, operation_id="count_tools") +def count_tools( + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Get a count of all tools available to agents belonging to the org of the user + """ + try: + return server.tool_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) + except Exception as e: + print(f"Error occurred: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/", response_model=Tool, operation_id="create_tool") def create_tool( request: ToolCreate = Body(...), diff --git a/letta/server/server.py b/letta/server/server.py index 7dba65db..6190773b 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1308,14 +1308,12 @@ class SyncServer(Server): tool_execution_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run( agent_state=agent_state, additional_env_vars=tool_env_vars ) - status = "error" if tool_execution_result.stderr else "success" - tool_return = str(tool_execution_result.stderr) if tool_execution_result.stderr else str(tool_execution_result.func_return) return ToolReturnMessage( id="null", tool_call_id="null", date=get_utc_time(), - status=status, - tool_return=tool_return, + status=tool_execution_result.status, + tool_return=str(tool_execution_result.func_return), stdout=tool_execution_result.stdout, 
stderr=tool_execution_result.stderr, ) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index aa94dae6..1eb139fa 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -556,6 +556,16 @@ class AgentManager: return list(session.execute(query).scalars()) + def size( + self, + actor: PydanticUser, + ) -> int: + """ + Get the total count of agents for the given user. + """ + with self.session_maker() as session: + return AgentModel.size(db_session=session, actor=actor) + @enforce_types def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: """Fetch an agent by its ID.""" @@ -590,15 +600,18 @@ class AgentManager: agents_to_delete = [agent] sleeptime_group_to_delete = None - # Delete sleeptime agent and group + # Delete sleeptime agent and group (TODO this is flimsy pls fix) if agent.multi_agent_group: participant_agent_ids = agent.multi_agent_group.agent_ids if agent.multi_agent_group.manager_type == ManagerType.sleeptime and len(participant_agent_ids) == 1: - sleeptime_agent = AgentModel.read(db_session=session, identifier=participant_agent_ids[0], actor=actor) - if sleeptime_agent.agent_type == AgentType.sleeptime_agent: - sleeptime_agent_group = GroupModel.read(db_session=session, identifier=agent.multi_agent_group.id, actor=actor) - sleeptime_group_to_delete = sleeptime_agent_group + try: + sleeptime_agent = AgentModel.read(db_session=session, identifier=participant_agent_ids[0], actor=actor) agents_to_delete.append(sleeptime_agent) + except NoResultFound: + pass # agent already deleted + sleeptime_agent_group = GroupModel.read(db_session=session, identifier=agent.multi_agent_group.id, actor=actor) + sleeptime_group_to_delete = sleeptime_agent_group + try: if sleeptime_group_to_delete is not None: session.delete(sleeptime_group_to_delete) @@ -931,7 +944,8 @@ class AgentManager: modified (bool): whether the memory was updated """ agent_state = 
self.get_agent_by_id(agent_id=agent_id, actor=actor) - if agent_state.memory.compile() != new_memory.compile(): + system_message = self.message_manager.get_message_by_id(message_id=agent_state.message_ids[0], actor=actor) + if new_memory.compile() not in system_message.content[0].text: # update the blocks (LRW) in the DB for label in agent_state.memory.list_block_labels(): updated_value = new_memory.get_block(label).value diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index e6bf881f..798b01a0 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -190,6 +190,17 @@ class IdentityManager: session.delete(identity) session.commit() + @enforce_types + def size( + self, + actor: PydanticUser, + ) -> int: + """ + Get the total count of identities for the given user. + """ + with self.session_maker() as session: + return IdentityModel.size(db_session=session, actor=actor) + def _process_relationship( self, session: Session, diff --git a/letta/services/mcp/__init__.py b/letta/services/mcp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/services/mcp/base_client.py b/letta/services/mcp/base_client.py new file mode 100644 index 00000000..270792a9 --- /dev/null +++ b/letta/services/mcp/base_client.py @@ -0,0 +1,67 @@ +from contextlib import AsyncExitStack +from typing import Optional, Tuple + +from mcp import ClientSession +from mcp import Tool as MCPTool +from mcp.types import TextContent + +from letta.functions.mcp_client.types import BaseServerConfig +from letta.log import get_logger + +logger = get_logger(__name__) + + +# TODO: Get rid of Async prefix on this class name once we deprecate old sync code +class AsyncBaseMCPClient: + def __init__(self, server_config: BaseServerConfig): + self.server_config = server_config + self.exit_stack = AsyncExitStack() + self.session: Optional[ClientSession] = None + self.initialized = False + + async def connect_to_server(self): + try: 
+ await self._initialize_connection(self.exit_stack, self.server_config) + await self.session.initialize() + self.initialized = True + except Exception as e: + logger.error( + f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}" + ) + raise e + + async def _initialize_connection(self, exit_stack: AsyncExitStack[bool | None], server_config: BaseServerConfig) -> None: + raise NotImplementedError("Subclasses must implement _initialize_connection") + + async def list_tools(self) -> list[MCPTool]: + self._check_initialized() + response = await self.session.list_tools() + return response.tools + + async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]: + self._check_initialized() + result = await self.session.call_tool(tool_name, tool_args) + parsed_content = [] + for content_piece in result.content: + if isinstance(content_piece, TextContent): + parsed_content.append(content_piece.text) + print("parsed_content (text)", parsed_content) + else: + parsed_content.append(str(content_piece)) + print("parsed_content (other)", parsed_content) + if len(parsed_content) > 0: + final_content = " ".join(parsed_content) + else: + # TODO move hardcoding to constants + final_content = "Empty response from tool" + + return final_content, result.isError + + def _check_initialized(self): + if not self.initialized: + logger.error("MCPClient has not been initialized") + raise RuntimeError("MCPClient has not been initialized") + + async def cleanup(self): + """Clean up resources""" + await self.exit_stack.aclose() diff --git a/letta/services/mcp/sse_client.py b/letta/services/mcp/sse_client.py new file mode 100644 index 00000000..5bfebf75 --- /dev/null +++ b/letta/services/mcp/sse_client.py @@ -0,0 +1,25 @@ +from contextlib import AsyncExitStack + +from mcp import ClientSession +from mcp.client.sse import sse_client + +from letta.functions.mcp_client.types import SSEServerConfig +from letta.log import get_logger +from 
letta.services.mcp.base_client import AsyncBaseMCPClient + +# see: https://modelcontextprotocol.io/quickstart/user +MCP_CONFIG_TOPLEVEL_KEY = "mcpServers" + +logger = get_logger(__name__) + + +# TODO: Get rid of Async prefix on this class name once we deprecate old sync code +class AsyncSSEMCPClient(AsyncBaseMCPClient): + async def _initialize_connection(self, exit_stack: AsyncExitStack[bool | None], server_config: SSEServerConfig) -> None: + sse_cm = sse_client(url=server_config.server_url) + sse_transport = await exit_stack.enter_async_context(sse_cm) + self.stdio, self.write = sse_transport + + # Create and enter the ClientSession context manager + session_cm = ClientSession(self.stdio, self.write) + self.session = await exit_stack.enter_async_context(session_cm) diff --git a/letta/services/mcp/stdio_client.py b/letta/services/mcp/stdio_client.py new file mode 100644 index 00000000..ca9c4c44 --- /dev/null +++ b/letta/services/mcp/stdio_client.py @@ -0,0 +1,19 @@ +from contextlib import AsyncExitStack + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client + +from letta.functions.mcp_client.types import StdioServerConfig +from letta.log import get_logger +from letta.services.mcp.base_client import AsyncBaseMCPClient + +logger = get_logger(__name__) + + +# TODO: Get rid of Async prefix on this class name once we deprecate old sync code +class AsyncStdioMCPClient(AsyncBaseMCPClient): + async def _initialize_connection(self, exit_stack: AsyncExitStack[bool | None], server_config: StdioServerConfig) -> None: + server_params = StdioServerParameters(command=server_config.command, args=server_config.args) + stdio_transport = await exit_stack.enter_async_context(stdio_client(server_params)) + self.stdio, self.write = stdio_transport + self.session = await exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) diff --git a/letta/services/mcp/types.py b/letta/services/mcp/types.py new file mode 100644 index 
00000000..2d8b7af6 --- /dev/null +++ b/letta/services/mcp/types.py @@ -0,0 +1,48 @@ +from enum import Enum +from typing import List, Optional + +from mcp import Tool +from pydantic import BaseModel, Field + + +class MCPTool(Tool): + """A simple wrapper around MCP's tool definition (to avoid conflict with our own)""" + + +class MCPServerType(str, Enum): + SSE = "sse" + STDIO = "stdio" + + +class BaseServerConfig(BaseModel): + server_name: str = Field(..., description="The name of the server") + type: MCPServerType + + +class SSEServerConfig(BaseServerConfig): + type: MCPServerType = MCPServerType.SSE + server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)") + + def to_dict(self) -> dict: + values = { + "transport": "sse", + "url": self.server_url, + } + return values + + +class StdioServerConfig(BaseServerConfig): + type: MCPServerType = MCPServerType.STDIO + command: str = Field(..., description="The command to run (MCP 'local' client will run this command)") + args: List[str] = Field(..., description="The arguments to pass to the command") + env: Optional[dict[str, str]] = Field(None, description="Environment variables to set") + + def to_dict(self) -> dict: + values = { + "transport": "stdio", + "command": self.command, + "args": self.args, + } + if self.env is not None: + values["env"] = self.env + return values diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 21a36ded..c872f490 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -77,6 +77,17 @@ class SourceManager: ) return [source.to_pydantic() for source in sources] + @enforce_types + def size( + self, + actor: PydanticUser, + ) -> int: + """ + Get the total count of sources for the given user. 
+ """ + with self.session_maker() as session: + return SourceModel.size(db_session=session, actor=actor) + @enforce_types def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]: """ diff --git a/letta/services/tool_executor/tool_execution_sandbox.py b/letta/services/tool_executor/tool_execution_sandbox.py index fa5f36cc..2588caf7 100644 --- a/letta/services/tool_executor/tool_execution_sandbox.py +++ b/letta/services/tool_executor/tool_execution_sandbox.py @@ -1,10 +1,12 @@ import ast import base64 +import io import os import pickle import subprocess import sys import tempfile +import traceback import uuid from typing import Any, Dict, Optional @@ -117,98 +119,108 @@ class ToolExecutionSandbox: @trace_method def run_local_dir_sandbox( - self, - agent_state: Optional[AgentState] = None, - additional_env_vars: Optional[Dict] = None, + self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None ) -> ToolExecutionResult: - sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config( - sandbox_type=SandboxType.LOCAL, - actor=self.user, - ) + sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user) local_configs = sbx_config.get_local_config() - sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) - venv_path = os.path.join(sandbox_dir, local_configs.venv_name) - # Aggregate environment variables + # Get environment variables for the sandbox env = os.environ.copy() - env.update(self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)) + env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100) + env.update(env_vars) + + # Get environment variables for this agent specifically if agent_state: env.update(agent_state.get_agent_env_vars_as_dict()) + + # Finally, get 
any that are passed explicitly into the `run` function call if additional_env_vars: env.update(additional_env_vars) - # Ensure sandbox dir exists - if not os.path.exists(sandbox_dir): - logger.warning(f"Sandbox directory does not exist, creating: {sandbox_dir}") - os.makedirs(sandbox_dir) + # Safety checks + if not os.path.exists(local_configs.sandbox_dir) or not os.path.isdir(local_configs.sandbox_dir): + logger.warning(f"Sandbox directory does not exist, creating: {local_configs.sandbox_dir}") + os.makedirs(local_configs.sandbox_dir) + + # Write the code to a temp file in the sandbox_dir + with tempfile.NamedTemporaryFile(mode="w", dir=local_configs.sandbox_dir, suffix=".py", delete=False) as temp_file: + if local_configs.force_create_venv: + # If using venv, we need to wrap with special string markers to separate out the output and the stdout (since it is all in stdout) + code = self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True) + else: + code = self.generate_execution_script(agent_state=agent_state) - # Write the code to a temp file - with tempfile.NamedTemporaryFile(mode="w", dir=sandbox_dir, suffix=".py", delete=False) as temp_file: - code = self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True) temp_file.write(code) temp_file.flush() temp_file_path = temp_file.name - try: - # Decide whether to use venv - use_venv = os.path.isdir(venv_path) - - if self.force_recreate_venv or (not use_venv and local_configs.force_create_venv): - logger.warning(f"Virtual environment not found at {venv_path}. 
Creating one...") - log_event(name="start create_venv_for_local_sandbox", attributes={"venv_path": venv_path}) - create_venv_for_local_sandbox( - sandbox_dir_path=sandbox_dir, - venv_path=venv_path, - env=env, - force_recreate=self.force_recreate_venv, - ) - log_event(name="finish create_venv_for_local_sandbox") - use_venv = True - - if use_venv: - log_event(name="start install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()}) - install_pip_requirements_for_sandbox(local_configs, env=env) - log_event(name="finish install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()}) - - python_executable = find_python_executable(local_configs) - if not os.path.isfile(python_executable): - logger.warning( - f"Python executable not found at expected venv path: {python_executable}. Falling back to system Python." - ) - python_executable = sys.executable - else: - env = dict(env) - env["VIRTUAL_ENV"] = venv_path - env["PATH"] = os.path.join(venv_path, "bin") + ":" + env.get("PATH", "") + if local_configs.force_create_venv: + return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path) else: - python_executable = sys.executable + return self.run_local_dir_sandbox_directly(sbx_config, env, temp_file_path) + except Exception as e: + logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}") + logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}") + raise e + finally: + # Clean up the temp file + os.remove(temp_file_path) - env["PYTHONWARNINGS"] = "ignore" + @trace_method + def run_local_dir_sandbox_venv( + self, + sbx_config: SandboxConfig, + env: Dict[str, str], + temp_file_path: str, + ) -> ToolExecutionResult: + local_configs = sbx_config.get_local_config() + sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) # Expand tilde + venv_path = os.path.join(sandbox_dir, local_configs.venv_name) + # Recreate venv if 
required + if self.force_recreate_venv or not os.path.isdir(venv_path): + logger.warning(f"Virtual environment directory does not exist at: {venv_path}, creating one now...") + log_event(name="start create_venv_for_local_sandbox", attributes={"venv_path": venv_path}) + create_venv_for_local_sandbox( + sandbox_dir_path=sandbox_dir, venv_path=venv_path, env=env, force_recreate=self.force_recreate_venv + ) + log_event(name="finish create_venv_for_local_sandbox") + + log_event(name="start install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()}) + install_pip_requirements_for_sandbox(local_configs, env=env) + log_event(name="finish install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()}) + + # Ensure Python executable exists + python_executable = find_python_executable(local_configs) + if not os.path.isfile(python_executable): + raise FileNotFoundError(f"Python executable not found in virtual environment: {python_executable}") + + # Set up environment variables + env["VIRTUAL_ENV"] = venv_path + env["PATH"] = os.path.join(venv_path, "bin") + ":" + env["PATH"] + env["PYTHONWARNINGS"] = "ignore" + + # Execute the code + try: log_event(name="start subprocess") result = subprocess.run( - [python_executable, temp_file_path], - env=env, - cwd=sandbox_dir, - timeout=60, - capture_output=True, - text=True, + [python_executable, temp_file_path], env=env, cwd=sandbox_dir, timeout=60, capture_output=True, text=True, check=True ) log_event(name="finish subprocess") func_result, stdout = self.parse_out_function_results_markers(result.stdout) - func_return, parsed_agent_state = self.parse_best_effort(func_result) + func_return, agent_state = self.parse_best_effort(func_result) return ToolExecutionResult( status="success", func_return=func_return, - agent_state=parsed_agent_state, + agent_state=agent_state, stdout=[stdout] if stdout else [], stderr=[result.stderr] if result.stderr else [], 
sandbox_config_fingerprint=sbx_config.fingerprint(), ) except subprocess.CalledProcessError as e: - logger.error(f"Tool execution failed: {e}") + logger.error(f"Executing tool {self.tool_name} has process error: {e}") func_return = get_friendly_error_msg( function_name=self.tool_name, exception_name=type(e).__name__, @@ -228,11 +240,72 @@ class ToolExecutionSandbox: except Exception as e: logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}") - logger.error(f"Generated script:\n{code}") raise e - finally: - os.remove(temp_file_path) + @trace_method + def run_local_dir_sandbox_directly( + self, + sbx_config: SandboxConfig, + env: Dict[str, str], + temp_file_path: str, + ) -> ToolExecutionResult: + status = "success" + func_return, agent_state, stderr = None, None, None + + old_stdout = sys.stdout + old_stderr = sys.stderr + captured_stdout, captured_stderr = io.StringIO(), io.StringIO() + + sys.stdout = captured_stdout + sys.stderr = captured_stderr + + try: + with self.temporary_env_vars(env): + + # Read and compile the Python script + with open(temp_file_path, "r", encoding="utf-8") as f: + source = f.read() + code_obj = compile(source, temp_file_path, "exec") + + # Provide a dict for globals. 
+ globals_dict = dict(env) # or {} + # If you need to mimic `__main__` behavior: + globals_dict["__name__"] = "__main__" + globals_dict["__file__"] = temp_file_path + + # Execute the compiled code + log_event(name="start exec", attributes={"temp_file_path": temp_file_path}) + exec(code_obj, globals_dict) + log_event(name="finish exec", attributes={"temp_file_path": temp_file_path}) + + # Get result from the global dict + func_result = globals_dict.get(self.LOCAL_SANDBOX_RESULT_VAR_NAME) + func_return, agent_state = self.parse_best_effort(func_result) + + except Exception as e: + func_return = get_friendly_error_msg( + function_name=self.tool_name, + exception_name=type(e).__name__, + exception_message=str(e), + ) + traceback.print_exc(file=sys.stderr) + status = "error" + + # Restore stdout/stderr + sys.stdout = old_stdout + sys.stderr = old_stderr + + stdout_output = [captured_stdout.getvalue()] if captured_stdout.getvalue() else [] + stderr_output = [captured_stderr.getvalue()] if captured_stderr.getvalue() else [] + + return ToolExecutionResult( + status=status, + func_return=func_return, + agent_state=agent_state, + stdout=stdout_output, + stderr=stderr_output, + sandbox_config_fingerprint=sbx_config.fingerprint(), + ) def parse_out_function_results_markers(self, text: str): if self.LOCAL_SANDBOX_RESULT_START_MARKER not in text: diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 90dbdcfa..571e67d4 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -145,6 +145,17 @@ class ToolManager: return results + @enforce_types + def size( + self, + actor: PydanticUser, + ) -> int: + """ + Get the total count of tools for the given user. 
+ """ + with self.session_maker() as session: + return ToolModel.size(db_session=session, actor=actor) + @enforce_types def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> PydanticTool: """Update a tool by its ID with the given ToolUpdate object.""" diff --git a/pyproject.toml b/pyproject.toml index 2d66bc09..ed920e0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.7.4" +version = "0.7.5" packages = [ {include = "letta"}, ] diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index 6b373d04..30bc3517 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -152,7 +152,7 @@ async def test_sleeptime_group_chat(server, actor): assert len(agent_runs) == len(run_ids) # 6. Verify run status after sleep - time.sleep(8) + time.sleep(10) for run_id in run_ids: job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor) assert job.status == JobStatus.completed diff --git a/tests/test_client.py b/tests/test_client.py index 280e50c7..4ae4828d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -505,9 +505,8 @@ def test_function_always_error(client: Letta): assert response_message, "ToolReturnMessage message not found in response" assert response_message.status == "error" - assert ( - response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero" - ), response_message.tool_return + assert "Error executing function testing_method" in response_message.tool_return, response_message.tool_return + assert "ZeroDivisionError: division by zero" in response_message.stderr[0] client.agents.delete(agent_id=agent.id) @@ -642,9 +641,9 @@ def test_agent_listing(client: Letta, agent, search_agent_one, search_agent_two) assert len(all_ids) == 2 assert all_ids == {search_agent_one.id, search_agent_two.id} - # Test listing without any 
filters + # Test listing without any filters; make less flakey by checking we have at least 3 agents in case created elsewhere all_agents = client.agents.list() - assert len(all_agents) == 3 + assert len(all_agents) >= 3 assert all(agent.id in {a.id for a in all_agents} for agent in [search_agent_one, search_agent_two, agent]) diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index f0dc5e68..cbaa54dd 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -304,7 +304,7 @@ async def test_round_robin(server, actor, participant_agents): input_messages=[ MessageCreate( role="user", - content="what is everyone up to for the holidays?", + content="when should we plan our next adventure?", ), ], stream_steps=False, diff --git a/tests/test_v1_routes.py b/tests/test_v1_routes.py deleted file mode 100644 index d08ac86d..00000000 --- a/tests/test_v1_routes.py +++ /dev/null @@ -1,521 +0,0 @@ -from datetime import datetime, timezone -from unittest.mock import MagicMock, Mock - -import pytest -from composio.client.collections import AppModel -from fastapi.testclient import TestClient - -from letta.orm.errors import NoResultFound -from letta.schemas.block import Block, BlockUpdate, CreateBlock -from letta.schemas.message import UserMessage -from letta.schemas.sandbox_config import LocalSandboxConfig, PipRequirement, SandboxConfig -from letta.schemas.tool import ToolCreate, ToolUpdate -from letta.server.rest_api.app import app -from letta.server.rest_api.utils import get_letta_server -from tests.helpers.utils import create_tool_from_func - - -@pytest.fixture -def client(): - return TestClient(app) - - -@pytest.fixture -def mock_sync_server(): - mock_server = Mock() - app.dependency_overrides[get_letta_server] = lambda: mock_server - return mock_server - - -@pytest.fixture -def add_integers_tool(): - def add(x: int, y: int) -> int: - """ - Simple function that adds two integers. - - Parameters: - x (int): The first integer to add. 
- y (int): The second integer to add. - - Returns: - int: The result of adding x and y. - """ - return x + y - - tool = create_tool_from_func(add) - yield tool - - -@pytest.fixture -def create_integers_tool(add_integers_tool): - tool_create = ToolCreate( - description=add_integers_tool.description, - tags=add_integers_tool.tags, - source_code=add_integers_tool.source_code, - source_type=add_integers_tool.source_type, - json_schema=add_integers_tool.json_schema, - ) - yield tool_create - - -@pytest.fixture -def update_integers_tool(add_integers_tool): - tool_update = ToolUpdate( - description=add_integers_tool.description, - tags=add_integers_tool.tags, - source_code=add_integers_tool.source_code, - source_type=add_integers_tool.source_type, - json_schema=add_integers_tool.json_schema, - ) - yield tool_update - - -@pytest.fixture -def composio_apps(): - affinity_app = AppModel( - name="affinity", - key="affinity", - appId="3a7d2dc7-c58c-4491-be84-f64b1ff498a8", - description="Affinity helps private capital investors to find, manage, and close more deals", - categories=["CRM"], - meta={ - "is_custom_app": False, - "triggersCount": 0, - "actionsCount": 20, - "documentation_doc_text": None, - "configuration_docs_text": None, - }, - logo="https://cdn.jsdelivr.net/gh/ComposioHQ/open-logos@master/affinity.jpeg", - docs=None, - group=None, - status=None, - enabled=False, - no_auth=False, - auth_schemes=None, - testConnectors=None, - documentation_doc_text=None, - configuration_docs_text=None, - ) - yield [affinity_app] - - -def configure_mock_sync_server(mock_sync_server): - # Mock sandbox config manager to return a valid API key - mock_api_key = Mock() - mock_api_key.value = "mock_composio_api_key" - mock_sync_server.sandbox_config_manager.list_sandbox_env_vars_by_key.return_value = [mock_api_key] - - # Mock user retrieval - mock_sync_server.user_manager.get_user_or_default.return_value = Mock() # Provide additional attributes if needed - - -# 
====================================================================================================================== -# Tools Routes Tests -# ====================================================================================================================== -def test_delete_tool(client, mock_sync_server, add_integers_tool): - mock_sync_server.tool_manager.delete_tool_by_id = MagicMock() - - response = client.delete(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"}) - - assert response.status_code == 200 - mock_sync_server.tool_manager.delete_tool_by_id.assert_called_once_with( - tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value - ) - - -def test_get_tool(client, mock_sync_server, add_integers_tool): - mock_sync_server.tool_manager.get_tool_by_id.return_value = add_integers_tool - - response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"}) - - assert response.status_code == 200 - assert response.json()["id"] == add_integers_tool.id - assert response.json()["source_code"] == add_integers_tool.source_code - mock_sync_server.tool_manager.get_tool_by_id.assert_called_once_with( - tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value - ) - - -def test_get_tool_404(client, mock_sync_server, add_integers_tool): - mock_sync_server.tool_manager.get_tool_by_id.return_value = None - - response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"}) - - assert response.status_code == 404 - assert response.json()["detail"] == f"Tool with id {add_integers_tool.id} not found." 
- - -def test_list_tools(client, mock_sync_server, add_integers_tool): - mock_sync_server.tool_manager.list_tools.return_value = [add_integers_tool] - - response = client.get("/v1/tools", headers={"user_id": "test_user"}) - - assert response.status_code == 200 - assert len(response.json()) == 1 - assert response.json()[0]["id"] == add_integers_tool.id - mock_sync_server.tool_manager.list_tools.assert_called_once() - - -def test_create_tool(client, mock_sync_server, create_integers_tool, add_integers_tool): - mock_sync_server.tool_manager.create_tool.return_value = add_integers_tool - - response = client.post("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"}) - - assert response.status_code == 200 - assert response.json()["id"] == add_integers_tool.id - mock_sync_server.tool_manager.create_tool.assert_called_once() - - -def test_upsert_tool(client, mock_sync_server, create_integers_tool, add_integers_tool): - mock_sync_server.tool_manager.create_or_update_tool.return_value = add_integers_tool - - response = client.put("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"}) - - assert response.status_code == 200 - assert response.json()["id"] == add_integers_tool.id - mock_sync_server.tool_manager.create_or_update_tool.assert_called_once() - - -def test_update_tool(client, mock_sync_server, update_integers_tool, add_integers_tool): - mock_sync_server.tool_manager.update_tool_by_id.return_value = add_integers_tool - - response = client.patch(f"/v1/tools/{add_integers_tool.id}", json=update_integers_tool.model_dump(), headers={"user_id": "test_user"}) - - assert response.status_code == 200 - assert response.json()["id"] == add_integers_tool.id - mock_sync_server.tool_manager.update_tool_by_id.assert_called_once_with( - tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.user_manager.get_user_or_default.return_value - ) - - -def test_upsert_base_tools(client, 
mock_sync_server, add_integers_tool): - mock_sync_server.tool_manager.upsert_base_tools.return_value = [add_integers_tool] - - response = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"}) - - assert response.status_code == 200 - assert len(response.json()) == 1 - assert response.json()[0]["id"] == add_integers_tool.id - mock_sync_server.tool_manager.upsert_base_tools.assert_called_once_with( - actor=mock_sync_server.user_manager.get_user_or_default.return_value - ) - - -# ====================================================================================================================== -# Runs Routes Tests -# ====================================================================================================================== - - -def test_get_run_messages(client, mock_sync_server): - """Test getting messages for a run.""" - # Create properly formatted mock messages - current_time = datetime.now(timezone.utc) - mock_messages = [ - UserMessage( - id=f"message-{i:08x}", - date=current_time, - content=f"Test message {i}", - ) - for i in range(2) - ] - - # Configure mock server responses - mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123") - mock_sync_server.job_manager.get_run_messages.return_value = mock_messages - - # Test successful retrieval - response = client.get( - "/v1/runs/run-12345678/messages", - headers={"user_id": "user-123"}, - params={ - "limit": 10, - "before": "message-1234", - "after": "message-6789", - "role": "user", - "order": "desc", - }, - ) - assert response.status_code == 200 - assert len(response.json()) == 2 - assert response.json()[0]["id"] == mock_messages[0].id - assert response.json()[1]["id"] == mock_messages[1].id - - # Verify mock calls - mock_sync_server.user_manager.get_user_or_default.assert_called_once_with(user_id="user-123") - mock_sync_server.job_manager.get_run_messages.assert_called_once_with( - run_id="run-12345678", - 
actor=mock_sync_server.user_manager.get_user_or_default.return_value, - limit=10, - before="message-1234", - after="message-6789", - ascending=False, - role="user", - ) - - -def test_get_run_messages_not_found(client, mock_sync_server): - """Test getting messages for a non-existent run.""" - # Configure mock responses - error_message = "Run 'run-nonexistent' not found" - mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123") - mock_sync_server.job_manager.get_run_messages.side_effect = NoResultFound(error_message) - - response = client.get("/v1/runs/run-nonexistent/messages", headers={"user_id": "user-123"}) - - assert response.status_code == 404 - assert error_message in response.json()["detail"] - - -def test_get_run_usage(client, mock_sync_server): - """Test getting usage statistics for a run.""" - # Configure mock responses - mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123") - mock_usage = Mock( - completion_tokens=100, - prompt_tokens=200, - total_tokens=300, - ) - mock_sync_server.job_manager.get_job_usage.return_value = mock_usage - - # Make request - response = client.get("/v1/runs/run-12345678/usage", headers={"user_id": "user-123"}) - - # Check response - assert response.status_code == 200 - assert response.json() == { - "completion_tokens": 100, - "prompt_tokens": 200, - "total_tokens": 300, - } - - # Verify mock calls - mock_sync_server.user_manager.get_user_or_default.assert_called_once_with(user_id="user-123") - mock_sync_server.job_manager.get_job_usage.assert_called_once_with( - job_id="run-12345678", - actor=mock_sync_server.user_manager.get_user_or_default.return_value, - ) - - -def test_get_run_usage_not_found(client, mock_sync_server): - """Test getting usage statistics for a non-existent run.""" - # Configure mock responses - error_message = "Run 'run-nonexistent' not found" - mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123") - 
mock_sync_server.job_manager.get_job_usage.side_effect = NoResultFound(error_message) - - # Make request - response = client.get("/v1/runs/run-nonexistent/usage", headers={"user_id": "user-123"}) - - assert response.status_code == 404 - assert error_message in response.json()["detail"] - - -# ====================================================================================================================== -# Tags Routes Tests -# ====================================================================================================================== - - -def test_get_tags(client, mock_sync_server): - """Test basic tag listing""" - mock_sync_server.agent_manager.list_tags.return_value = ["tag1", "tag2"] - - response = client.get("/v1/tags", headers={"user_id": "test_user"}) - - assert response.status_code == 200 - assert response.json() == ["tag1", "tag2"] - mock_sync_server.agent_manager.list_tags.assert_called_once_with( - actor=mock_sync_server.user_manager.get_user_or_default.return_value, after=None, limit=50, query_text=None - ) - - -def test_get_tags_with_pagination(client, mock_sync_server): - """Test tag listing with pagination parameters""" - mock_sync_server.agent_manager.list_tags.return_value = ["tag3", "tag4"] - - response = client.get("/v1/tags", params={"after": "tag2", "limit": 2}, headers={"user_id": "test_user"}) - - assert response.status_code == 200 - assert response.json() == ["tag3", "tag4"] - mock_sync_server.agent_manager.list_tags.assert_called_once_with( - actor=mock_sync_server.user_manager.get_user_or_default.return_value, after="tag2", limit=2, query_text=None - ) - - -def test_get_tags_with_search(client, mock_sync_server): - """Test tag listing with text search""" - mock_sync_server.agent_manager.list_tags.return_value = ["user_tag1", "user_tag2"] - - response = client.get("/v1/tags", params={"query_text": "user"}, headers={"user_id": "test_user"}) - - assert response.status_code == 200 - assert response.json() == ["user_tag1", 
"user_tag2"] - mock_sync_server.agent_manager.list_tags.assert_called_once_with( - actor=mock_sync_server.user_manager.get_user_or_default.return_value, after=None, limit=50, query_text="user" - ) - - -# ====================================================================================================================== -# Blocks Routes Tests -# ====================================================================================================================== - - -def test_list_blocks(client, mock_sync_server): - """ - Test the GET /v1/blocks endpoint to list blocks. - """ - # Arrange: mock return from block_manager - mock_block = Block(label="human", value="Hi") - mock_sync_server.block_manager.get_blocks.return_value = [mock_block] - - # Act - response = client.get("/v1/blocks", headers={"user_id": "test_user"}) - - # Assert - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0]["id"] == mock_block.id - mock_sync_server.block_manager.get_blocks.assert_called_once_with( - actor=mock_sync_server.user_manager.get_user_or_default.return_value, - label=None, - is_template=False, - template_name=None, - identity_id=None, - identifier_keys=None, - ) - - -def test_create_block(client, mock_sync_server): - """ - Test the POST /v1/blocks endpoint to create a block. - """ - new_block = CreateBlock(label="system", value="Some system text") - returned_block = Block(**new_block.model_dump()) - - mock_sync_server.block_manager.create_or_update_block.return_value = returned_block - - response = client.post("/v1/blocks", json=new_block.model_dump(), headers={"user_id": "test_user"}) - assert response.status_code == 200 - data = response.json() - assert data["id"] == returned_block.id - - mock_sync_server.block_manager.create_or_update_block.assert_called_once() - - -def test_modify_block(client, mock_sync_server): - """ - Test the PATCH /v1/blocks/{block_id} endpoint to update a block. 
- """ - block_update = BlockUpdate(value="Updated text", description="New description") - updated_block = Block(label="human", value="Updated text", description="New description") - mock_sync_server.block_manager.update_block.return_value = updated_block - - response = client.patch(f"/v1/blocks/{updated_block.id}", json=block_update.model_dump(), headers={"user_id": "test_user"}) - assert response.status_code == 200 - data = response.json() - assert data["value"] == "Updated text" - assert data["description"] == "New description" - - mock_sync_server.block_manager.update_block.assert_called_once_with( - block_id=updated_block.id, - block_update=block_update, - actor=mock_sync_server.user_manager.get_user_or_default.return_value, - ) - - -def test_delete_block(client, mock_sync_server): - """ - Test the DELETE /v1/blocks/{block_id} endpoint. - """ - deleted_block = Block(label="persona", value="Deleted text") - mock_sync_server.block_manager.delete_block.return_value = deleted_block - - response = client.delete(f"/v1/blocks/{deleted_block.id}", headers={"user_id": "test_user"}) - assert response.status_code == 200 - data = response.json() - assert data["id"] == deleted_block.id - - mock_sync_server.block_manager.delete_block.assert_called_once_with( - block_id=deleted_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value - ) - - -def test_retrieve_block(client, mock_sync_server): - """ - Test the GET /v1/blocks/{block_id} endpoint. 
- """ - existing_block = Block(label="human", value="Hello") - mock_sync_server.block_manager.get_block_by_id.return_value = existing_block - - response = client.get(f"/v1/blocks/{existing_block.id}", headers={"user_id": "test_user"}) - assert response.status_code == 200 - data = response.json() - assert data["id"] == existing_block.id - - mock_sync_server.block_manager.get_block_by_id.assert_called_once_with( - block_id=existing_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value - ) - - -def test_retrieve_block_404(client, mock_sync_server): - """ - Test that retrieving a non-existent block returns 404. - """ - mock_sync_server.block_manager.get_block_by_id.return_value = None - - response = client.get("/v1/blocks/block-999", headers={"user_id": "test_user"}) - assert response.status_code == 404 - assert "Block not found" in response.json()["detail"] - - -def test_list_agents_for_block(client, mock_sync_server): - """ - Test the GET /v1/blocks/{block_id}/agents endpoint. 
- """ - mock_sync_server.block_manager.get_agents_for_block.return_value = [] - - response = client.get("/v1/blocks/block-abc/agents", headers={"user_id": "test_user"}) - assert response.status_code == 200 - data = response.json() - assert len(data) == 0 - - mock_sync_server.block_manager.get_agents_for_block.assert_called_once_with( - block_id="block-abc", - actor=mock_sync_server.user_manager.get_user_or_default.return_value, - ) - - -# ====================================================================================================================== -# Sandbox Config Routes Tests -# ====================================================================================================================== -@pytest.fixture -def sample_local_sandbox_config(): - """Fixture for a sample LocalSandboxConfig object.""" - return LocalSandboxConfig( - sandbox_dir="/custom/path", - force_create_venv=True, - venv_name="custom_venv_name", - pip_requirements=[ - PipRequirement(name="numpy", version="1.23.0"), - PipRequirement(name="pandas"), - ], - ) - - -def test_create_custom_local_sandbox_config(client, mock_sync_server, sample_local_sandbox_config): - """Test creating or updating a LocalSandboxConfig.""" - mock_sync_server.sandbox_config_manager.create_or_update_sandbox_config.return_value = SandboxConfig( - type="local", organization_id="org-123", config=sample_local_sandbox_config.model_dump() - ) - - response = client.post("/v1/sandbox-config/local", json=sample_local_sandbox_config.model_dump(), headers={"user_id": "test_user"}) - - assert response.status_code == 200 - assert response.json()["type"] == "local" - assert response.json()["config"]["sandbox_dir"] == "/custom/path" - assert response.json()["config"]["pip_requirements"] == [ - {"name": "numpy", "version": "1.23.0"}, - {"name": "pandas", "version": None}, - ] - - mock_sync_server.sandbox_config_manager.create_or_update_sandbox_config.assert_called_once()