chore: bump 0.7.5 (#2587)
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com> Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
This commit is contained in:
@@ -64,8 +64,7 @@ ENV LETTA_ENVIRONMENT=${LETTA_ENVIRONMENT} \
|
||||
POSTGRES_USER=letta \
|
||||
POSTGRES_PASSWORD=letta \
|
||||
POSTGRES_DB=letta \
|
||||
COMPOSIO_DISABLE_VERSION_CHECK=true \
|
||||
LETTA_OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4317"
|
||||
COMPOSIO_DISABLE_VERSION_CHECK=true
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ services:
|
||||
- VLLM_API_BASE=${VLLM_API_BASE}
|
||||
- OPENLLM_AUTH_TYPE=${OPENLLM_AUTH_TYPE}
|
||||
- OPENLLM_API_KEY=${OPENLLM_API_KEY}
|
||||
- LETTA_OTEL_EXPORTER_OTLP_ENDPOINT=${LETTA_OTEL_EXPORTER_OTLP_ENDPOINT}
|
||||
- CLICKHOUSE_ENDPOINT=${CLICKHOUSE_ENDPOINT}
|
||||
- CLICKHOUSE_DATABASE=${CLICKHOUSE_DATABASE}
|
||||
- CLICKHOUSE_USERNAME=${CLICKHOUSE_USERNAME}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.7.4"
|
||||
__version__ = "0.7.5"
|
||||
|
||||
# import clients
|
||||
from letta.client.client import LocalClient, RESTClient, create_client
|
||||
|
||||
@@ -190,16 +190,15 @@ class Agent(BaseAgent):
|
||||
Returns:
|
||||
modified (bool): whether the memory was updated
|
||||
"""
|
||||
if self.agent_state.memory.compile() != new_memory.compile():
|
||||
system_message = self.message_manager.get_message_by_id(message_id=self.agent_state.message_ids[0], actor=self.user)
|
||||
if new_memory.compile() not in system_message.content[0].text:
|
||||
# update the blocks (LRW) in the DB
|
||||
for label in self.agent_state.memory.list_block_labels():
|
||||
updated_value = new_memory.get_block(label).value
|
||||
if updated_value != self.agent_state.memory.get_block(label).value:
|
||||
# update the block if it's changed
|
||||
block_id = self.agent_state.memory.get_block(label).id
|
||||
block = self.block_manager.update_block(
|
||||
block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=self.user
|
||||
)
|
||||
self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=self.user)
|
||||
|
||||
# refresh memory from DB (using block ids)
|
||||
self.agent_state.memory = Memory(
|
||||
|
||||
@@ -1233,7 +1233,10 @@ class AzureProvider(Provider):
|
||||
"""
|
||||
This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model.
|
||||
"""
|
||||
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)
|
||||
context_window = AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, None)
|
||||
if context_window is None:
|
||||
context_window = LLM_MAX_TOKENS.get(model_name, 4096)
|
||||
return context_window
|
||||
|
||||
|
||||
class VLLMChatCompletionsProvider(Provider):
|
||||
|
||||
@@ -104,6 +104,17 @@ def list_agents(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/count", response_model=int, operation_id="count_agents")
|
||||
def count_agents(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Get the count of all agents associated with a given user.
|
||||
"""
|
||||
return server.agent_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
|
||||
|
||||
|
||||
class IndentedORJSONResponse(Response):
|
||||
media_type = "application/json"
|
||||
|
||||
|
||||
@@ -49,6 +49,24 @@ def list_identities(
|
||||
return identities
|
||||
|
||||
|
||||
@router.get("/count", tags=["identities"], response_model=int, operation_id="count_identities")
|
||||
def count_identities(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Get count of all identities for a user
|
||||
"""
|
||||
try:
|
||||
return server.identity_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
|
||||
except NoResultFound:
|
||||
return 0
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
|
||||
@router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity")
|
||||
def retrieve_identity(
|
||||
identity_id: str,
|
||||
|
||||
@@ -67,6 +67,17 @@ def list_sources(
|
||||
return server.list_all_sources(actor=actor)
|
||||
|
||||
|
||||
@router.get("/count", response_model=int, operation_id="count_sources")
|
||||
def count_sources(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Count all data sources created by a user.
|
||||
"""
|
||||
return server.source_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
|
||||
|
||||
|
||||
@router.post("/", response_model=Source, operation_id="create_source")
|
||||
def create_source(
|
||||
source_create: SourceCreate,
|
||||
|
||||
@@ -80,6 +80,21 @@ def list_tools(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/count", response_model=int, operation_id="count_tools")
|
||||
def count_tools(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Get a count of all tools available to agents belonging to the org of the user
|
||||
"""
|
||||
try:
|
||||
return server.tool_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/", response_model=Tool, operation_id="create_tool")
|
||||
def create_tool(
|
||||
request: ToolCreate = Body(...),
|
||||
|
||||
@@ -1308,14 +1308,12 @@ class SyncServer(Server):
|
||||
tool_execution_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
|
||||
agent_state=agent_state, additional_env_vars=tool_env_vars
|
||||
)
|
||||
status = "error" if tool_execution_result.stderr else "success"
|
||||
tool_return = str(tool_execution_result.stderr) if tool_execution_result.stderr else str(tool_execution_result.func_return)
|
||||
return ToolReturnMessage(
|
||||
id="null",
|
||||
tool_call_id="null",
|
||||
date=get_utc_time(),
|
||||
status=status,
|
||||
tool_return=tool_return,
|
||||
status=tool_execution_result.status,
|
||||
tool_return=str(tool_execution_result.func_return),
|
||||
stdout=tool_execution_result.stdout,
|
||||
stderr=tool_execution_result.stderr,
|
||||
)
|
||||
|
||||
@@ -556,6 +556,16 @@ class AgentManager:
|
||||
|
||||
return list(session.execute(query).scalars())
|
||||
|
||||
def size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
) -> int:
|
||||
"""
|
||||
Get the total count of agents for the given user.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
return AgentModel.size(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
@@ -590,15 +600,18 @@ class AgentManager:
|
||||
agents_to_delete = [agent]
|
||||
sleeptime_group_to_delete = None
|
||||
|
||||
# Delete sleeptime agent and group
|
||||
# Delete sleeptime agent and group (TODO this is flimsy pls fix)
|
||||
if agent.multi_agent_group:
|
||||
participant_agent_ids = agent.multi_agent_group.agent_ids
|
||||
if agent.multi_agent_group.manager_type == ManagerType.sleeptime and len(participant_agent_ids) == 1:
|
||||
sleeptime_agent = AgentModel.read(db_session=session, identifier=participant_agent_ids[0], actor=actor)
|
||||
if sleeptime_agent.agent_type == AgentType.sleeptime_agent:
|
||||
sleeptime_agent_group = GroupModel.read(db_session=session, identifier=agent.multi_agent_group.id, actor=actor)
|
||||
sleeptime_group_to_delete = sleeptime_agent_group
|
||||
try:
|
||||
sleeptime_agent = AgentModel.read(db_session=session, identifier=participant_agent_ids[0], actor=actor)
|
||||
agents_to_delete.append(sleeptime_agent)
|
||||
except NoResultFound:
|
||||
pass # agent already deleted
|
||||
sleeptime_agent_group = GroupModel.read(db_session=session, identifier=agent.multi_agent_group.id, actor=actor)
|
||||
sleeptime_group_to_delete = sleeptime_agent_group
|
||||
|
||||
try:
|
||||
if sleeptime_group_to_delete is not None:
|
||||
session.delete(sleeptime_group_to_delete)
|
||||
@@ -931,7 +944,8 @@ class AgentManager:
|
||||
modified (bool): whether the memory was updated
|
||||
"""
|
||||
agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
if agent_state.memory.compile() != new_memory.compile():
|
||||
system_message = self.message_manager.get_message_by_id(message_id=agent_state.message_ids[0], actor=actor)
|
||||
if new_memory.compile() not in system_message.content[0].text:
|
||||
# update the blocks (LRW) in the DB
|
||||
for label in agent_state.memory.list_block_labels():
|
||||
updated_value = new_memory.get_block(label).value
|
||||
|
||||
@@ -190,6 +190,17 @@ class IdentityManager:
|
||||
session.delete(identity)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
) -> int:
|
||||
"""
|
||||
Get the total count of identities for the given user.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
return IdentityModel.size(db_session=session, actor=actor)
|
||||
|
||||
def _process_relationship(
|
||||
self,
|
||||
session: Session,
|
||||
|
||||
0
letta/services/mcp/__init__.py
Normal file
0
letta/services/mcp/__init__.py
Normal file
67
letta/services/mcp/base_client.py
Normal file
67
letta/services/mcp/base_client.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp import Tool as MCPTool
|
||||
from mcp.types import TextContent
|
||||
|
||||
from letta.functions.mcp_client.types import BaseServerConfig
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncBaseMCPClient:
|
||||
def __init__(self, server_config: BaseServerConfig):
|
||||
self.server_config = server_config
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.initialized = False
|
||||
|
||||
async def connect_to_server(self):
|
||||
try:
|
||||
await self._initialize_connection(self.server_config)
|
||||
await self.session.initialize()
|
||||
self.initialized = True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}"
|
||||
)
|
||||
raise e
|
||||
|
||||
async def _initialize_connection(self, exit_stack: AsyncExitStack[bool | None], server_config: BaseServerConfig) -> None:
|
||||
raise NotImplementedError("Subclasses must implement _initialize_connection")
|
||||
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
self._check_initialized()
|
||||
response = await self.session.list_tools()
|
||||
return response.tools
|
||||
|
||||
async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
|
||||
self._check_initialized()
|
||||
result = await self.session.call_tool(tool_name, tool_args)
|
||||
parsed_content = []
|
||||
for content_piece in result.content:
|
||||
if isinstance(content_piece, TextContent):
|
||||
parsed_content.append(content_piece.text)
|
||||
print("parsed_content (text)", parsed_content)
|
||||
else:
|
||||
parsed_content.append(str(content_piece))
|
||||
print("parsed_content (other)", parsed_content)
|
||||
if len(parsed_content) > 0:
|
||||
final_content = " ".join(parsed_content)
|
||||
else:
|
||||
# TODO move hardcoding to constants
|
||||
final_content = "Empty response from tool"
|
||||
|
||||
return final_content, result.isError
|
||||
|
||||
def _check_initialized(self):
|
||||
if not self.initialized:
|
||||
logger.error("MCPClient has not been initialized")
|
||||
raise RuntimeError("MCPClient has not been initialized")
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
await self.exit_stack.aclose()
|
||||
25
letta/services/mcp/sse_client.py
Normal file
25
letta/services/mcp/sse_client.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from letta.functions.mcp_client.types import SSEServerConfig
|
||||
from letta.log import get_logger
|
||||
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
||||
|
||||
# see: https://modelcontextprotocol.io/quickstart/user
|
||||
MCP_CONFIG_TOPLEVEL_KEY = "mcpServers"
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncSSEMCPClient(AsyncBaseMCPClient):
|
||||
async def _initialize_connection(self, exit_stack: AsyncExitStack[bool | None], server_config: SSEServerConfig) -> None:
|
||||
sse_cm = sse_client(url=server_config.server_url)
|
||||
sse_transport = await exit_stack.enter_async_context(sse_cm)
|
||||
self.stdio, self.write = sse_transport
|
||||
|
||||
# Create and enter the ClientSession context manager
|
||||
session_cm = ClientSession(self.stdio, self.write)
|
||||
self.session = await exit_stack.enter_async_context(session_cm)
|
||||
19
letta/services/mcp/stdio_client.py
Normal file
19
letta/services/mcp/stdio_client.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
from letta.functions.mcp_client.types import StdioServerConfig
|
||||
from letta.log import get_logger
|
||||
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncStdioMCPClient(AsyncBaseMCPClient):
|
||||
async def _initialize_connection(self, exit_stack: AsyncExitStack[bool | None], server_config: StdioServerConfig) -> None:
|
||||
server_params = StdioServerParameters(command=server_config.command, args=server_config.args)
|
||||
stdio_transport = await exit_stack.enter_async_context(stdio_client(server_params))
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.session = await exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
|
||||
48
letta/services/mcp/types.py
Normal file
48
letta/services/mcp/types.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from mcp import Tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""A simple wrapper around MCP's tool definition (to avoid conflict with our own)"""
|
||||
|
||||
|
||||
class MCPServerType(str, Enum):
|
||||
SSE = "sse"
|
||||
STDIO = "stdio"
|
||||
|
||||
|
||||
class BaseServerConfig(BaseModel):
|
||||
server_name: str = Field(..., description="The name of the server")
|
||||
type: MCPServerType
|
||||
|
||||
|
||||
class SSEServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.SSE
|
||||
server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "sse",
|
||||
"url": self.server_url,
|
||||
}
|
||||
return values
|
||||
|
||||
|
||||
class StdioServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.STDIO
|
||||
command: str = Field(..., description="The command to run (MCP 'local' client will run this command)")
|
||||
args: List[str] = Field(..., description="The arguments to pass to the command")
|
||||
env: Optional[dict[str, str]] = Field(None, description="Environment variables to set")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "stdio",
|
||||
"command": self.command,
|
||||
"args": self.args,
|
||||
}
|
||||
if self.env is not None:
|
||||
values["env"] = self.env
|
||||
return values
|
||||
@@ -77,6 +77,17 @@ class SourceManager:
|
||||
)
|
||||
return [source.to_pydantic() for source in sources]
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
) -> int:
|
||||
"""
|
||||
Get the total count of sources for the given user.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
return SourceModel.size(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import ast
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -117,98 +119,108 @@ class ToolExecutionSandbox:
|
||||
|
||||
@trace_method
|
||||
def run_local_dir_sandbox(
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
|
||||
) -> ToolExecutionResult:
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(
|
||||
sandbox_type=SandboxType.LOCAL,
|
||||
actor=self.user,
|
||||
)
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
|
||||
local_configs = sbx_config.get_local_config()
|
||||
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
|
||||
venv_path = os.path.join(sandbox_dir, local_configs.venv_name)
|
||||
|
||||
# Aggregate environment variables
|
||||
# Get environment variables for the sandbox
|
||||
env = os.environ.copy()
|
||||
env.update(self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100))
|
||||
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
|
||||
env.update(env_vars)
|
||||
|
||||
# Get environment variables for this agent specifically
|
||||
if agent_state:
|
||||
env.update(agent_state.get_agent_env_vars_as_dict())
|
||||
|
||||
# Finally, get any that are passed explicitly into the `run` function call
|
||||
if additional_env_vars:
|
||||
env.update(additional_env_vars)
|
||||
|
||||
# Ensure sandbox dir exists
|
||||
if not os.path.exists(sandbox_dir):
|
||||
logger.warning(f"Sandbox directory does not exist, creating: {sandbox_dir}")
|
||||
os.makedirs(sandbox_dir)
|
||||
# Safety checks
|
||||
if not os.path.exists(local_configs.sandbox_dir) or not os.path.isdir(local_configs.sandbox_dir):
|
||||
logger.warning(f"Sandbox directory does not exist, creating: {local_configs.sandbox_dir}")
|
||||
os.makedirs(local_configs.sandbox_dir)
|
||||
|
||||
# Write the code to a temp file in the sandbox_dir
|
||||
with tempfile.NamedTemporaryFile(mode="w", dir=local_configs.sandbox_dir, suffix=".py", delete=False) as temp_file:
|
||||
if local_configs.force_create_venv:
|
||||
# If using venv, we need to wrap with special string markers to separate out the output and the stdout (since it is all in stdout)
|
||||
code = self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True)
|
||||
else:
|
||||
code = self.generate_execution_script(agent_state=agent_state)
|
||||
|
||||
# Write the code to a temp file
|
||||
with tempfile.NamedTemporaryFile(mode="w", dir=sandbox_dir, suffix=".py", delete=False) as temp_file:
|
||||
code = self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True)
|
||||
temp_file.write(code)
|
||||
temp_file.flush()
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Decide whether to use venv
|
||||
use_venv = os.path.isdir(venv_path)
|
||||
|
||||
if self.force_recreate_venv or (not use_venv and local_configs.force_create_venv):
|
||||
logger.warning(f"Virtual environment not found at {venv_path}. Creating one...")
|
||||
log_event(name="start create_venv_for_local_sandbox", attributes={"venv_path": venv_path})
|
||||
create_venv_for_local_sandbox(
|
||||
sandbox_dir_path=sandbox_dir,
|
||||
venv_path=venv_path,
|
||||
env=env,
|
||||
force_recreate=self.force_recreate_venv,
|
||||
)
|
||||
log_event(name="finish create_venv_for_local_sandbox")
|
||||
use_venv = True
|
||||
|
||||
if use_venv:
|
||||
log_event(name="start install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
|
||||
install_pip_requirements_for_sandbox(local_configs, env=env)
|
||||
log_event(name="finish install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
|
||||
|
||||
python_executable = find_python_executable(local_configs)
|
||||
if not os.path.isfile(python_executable):
|
||||
logger.warning(
|
||||
f"Python executable not found at expected venv path: {python_executable}. Falling back to system Python."
|
||||
)
|
||||
python_executable = sys.executable
|
||||
else:
|
||||
env = dict(env)
|
||||
env["VIRTUAL_ENV"] = venv_path
|
||||
env["PATH"] = os.path.join(venv_path, "bin") + ":" + env.get("PATH", "")
|
||||
if local_configs.force_create_venv:
|
||||
return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path)
|
||||
else:
|
||||
python_executable = sys.executable
|
||||
return self.run_local_dir_sandbox_directly(sbx_config, env, temp_file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
|
||||
logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}")
|
||||
raise e
|
||||
finally:
|
||||
# Clean up the temp file
|
||||
os.remove(temp_file_path)
|
||||
|
||||
env["PYTHONWARNINGS"] = "ignore"
|
||||
@trace_method
|
||||
def run_local_dir_sandbox_venv(
|
||||
self,
|
||||
sbx_config: SandboxConfig,
|
||||
env: Dict[str, str],
|
||||
temp_file_path: str,
|
||||
) -> ToolExecutionResult:
|
||||
local_configs = sbx_config.get_local_config()
|
||||
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) # Expand tilde
|
||||
venv_path = os.path.join(sandbox_dir, local_configs.venv_name)
|
||||
|
||||
# Recreate venv if required
|
||||
if self.force_recreate_venv or not os.path.isdir(venv_path):
|
||||
logger.warning(f"Virtual environment directory does not exist at: {venv_path}, creating one now...")
|
||||
log_event(name="start create_venv_for_local_sandbox", attributes={"venv_path": venv_path})
|
||||
create_venv_for_local_sandbox(
|
||||
sandbox_dir_path=sandbox_dir, venv_path=venv_path, env=env, force_recreate=self.force_recreate_venv
|
||||
)
|
||||
log_event(name="finish create_venv_for_local_sandbox")
|
||||
|
||||
log_event(name="start install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
|
||||
install_pip_requirements_for_sandbox(local_configs, env=env)
|
||||
log_event(name="finish install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
|
||||
|
||||
# Ensure Python executable exists
|
||||
python_executable = find_python_executable(local_configs)
|
||||
if not os.path.isfile(python_executable):
|
||||
raise FileNotFoundError(f"Python executable not found in virtual environment: {python_executable}")
|
||||
|
||||
# Set up environment variables
|
||||
env["VIRTUAL_ENV"] = venv_path
|
||||
env["PATH"] = os.path.join(venv_path, "bin") + ":" + env["PATH"]
|
||||
env["PYTHONWARNINGS"] = "ignore"
|
||||
|
||||
# Execute the code
|
||||
try:
|
||||
log_event(name="start subprocess")
|
||||
result = subprocess.run(
|
||||
[python_executable, temp_file_path],
|
||||
env=env,
|
||||
cwd=sandbox_dir,
|
||||
timeout=60,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
[python_executable, temp_file_path], env=env, cwd=sandbox_dir, timeout=60, capture_output=True, text=True, check=True
|
||||
)
|
||||
log_event(name="finish subprocess")
|
||||
func_result, stdout = self.parse_out_function_results_markers(result.stdout)
|
||||
func_return, parsed_agent_state = self.parse_best_effort(func_result)
|
||||
func_return, agent_state = self.parse_best_effort(func_result)
|
||||
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=func_return,
|
||||
agent_state=parsed_agent_state,
|
||||
agent_state=agent_state,
|
||||
stdout=[stdout] if stdout else [],
|
||||
stderr=[result.stderr] if result.stderr else [],
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Tool execution failed: {e}")
|
||||
logger.error(f"Executing tool {self.tool_name} has process error: {e}")
|
||||
func_return = get_friendly_error_msg(
|
||||
function_name=self.tool_name,
|
||||
exception_name=type(e).__name__,
|
||||
@@ -228,11 +240,72 @@ class ToolExecutionSandbox:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
|
||||
logger.error(f"Generated script:\n{code}")
|
||||
raise e
|
||||
|
||||
finally:
|
||||
os.remove(temp_file_path)
|
||||
@trace_method
|
||||
def run_local_dir_sandbox_directly(
|
||||
self,
|
||||
sbx_config: SandboxConfig,
|
||||
env: Dict[str, str],
|
||||
temp_file_path: str,
|
||||
) -> ToolExecutionResult:
|
||||
status = "success"
|
||||
func_return, agent_state, stderr = None, None, None
|
||||
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
captured_stdout, captured_stderr = io.StringIO(), io.StringIO()
|
||||
|
||||
sys.stdout = captured_stdout
|
||||
sys.stderr = captured_stderr
|
||||
|
||||
try:
|
||||
with self.temporary_env_vars(env):
|
||||
|
||||
# Read and compile the Python script
|
||||
with open(temp_file_path, "r", encoding="utf-8") as f:
|
||||
source = f.read()
|
||||
code_obj = compile(source, temp_file_path, "exec")
|
||||
|
||||
# Provide a dict for globals.
|
||||
globals_dict = dict(env) # or {}
|
||||
# If you need to mimic `__main__` behavior:
|
||||
globals_dict["__name__"] = "__main__"
|
||||
globals_dict["__file__"] = temp_file_path
|
||||
|
||||
# Execute the compiled code
|
||||
log_event(name="start exec", attributes={"temp_file_path": temp_file_path})
|
||||
exec(code_obj, globals_dict)
|
||||
log_event(name="finish exec", attributes={"temp_file_path": temp_file_path})
|
||||
|
||||
# Get result from the global dict
|
||||
func_result = globals_dict.get(self.LOCAL_SANDBOX_RESULT_VAR_NAME)
|
||||
func_return, agent_state = self.parse_best_effort(func_result)
|
||||
|
||||
except Exception as e:
|
||||
func_return = get_friendly_error_msg(
|
||||
function_name=self.tool_name,
|
||||
exception_name=type(e).__name__,
|
||||
exception_message=str(e),
|
||||
)
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
status = "error"
|
||||
|
||||
# Restore stdout/stderr
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
stdout_output = [captured_stdout.getvalue()] if captured_stdout.getvalue() else []
|
||||
stderr_output = [captured_stderr.getvalue()] if captured_stderr.getvalue() else []
|
||||
|
||||
return ToolExecutionResult(
|
||||
status=status,
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=stdout_output,
|
||||
stderr=stderr_output,
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
def parse_out_function_results_markers(self, text: str):
|
||||
if self.LOCAL_SANDBOX_RESULT_START_MARKER not in text:
|
||||
|
||||
@@ -145,6 +145,17 @@ class ToolManager:
|
||||
|
||||
return results
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
) -> int:
|
||||
"""
|
||||
Get the total count of tools for the given user.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
return ToolModel.size(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> PydanticTool:
|
||||
"""Update a tool by its ID with the given ToolUpdate object."""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "letta"
|
||||
version = "0.7.4"
|
||||
version = "0.7.5"
|
||||
packages = [
|
||||
{include = "letta"},
|
||||
]
|
||||
|
||||
@@ -152,7 +152,7 @@ async def test_sleeptime_group_chat(server, actor):
|
||||
assert len(agent_runs) == len(run_ids)
|
||||
|
||||
# 6. Verify run status after sleep
|
||||
time.sleep(8)
|
||||
time.sleep(10)
|
||||
for run_id in run_ids:
|
||||
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
|
||||
assert job.status == JobStatus.completed
|
||||
|
||||
@@ -505,9 +505,8 @@ def test_function_always_error(client: Letta):
|
||||
|
||||
assert response_message, "ToolReturnMessage message not found in response"
|
||||
assert response_message.status == "error"
|
||||
assert (
|
||||
response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero"
|
||||
), response_message.tool_return
|
||||
assert "Error executing function testing_method" in response_message.tool_return, response_message.tool_return
|
||||
assert "ZeroDivisionError: division by zero" in response_message.stderr[0]
|
||||
|
||||
client.agents.delete(agent_id=agent.id)
|
||||
|
||||
@@ -642,9 +641,9 @@ def test_agent_listing(client: Letta, agent, search_agent_one, search_agent_two)
|
||||
assert len(all_ids) == 2
|
||||
assert all_ids == {search_agent_one.id, search_agent_two.id}
|
||||
|
||||
# Test listing without any filters
|
||||
# Test listing without any filters; make less flakey by checking we have at least 3 agents in case created elsewhere
|
||||
all_agents = client.agents.list()
|
||||
assert len(all_agents) == 3
|
||||
assert len(all_agents) >= 3
|
||||
assert all(agent.id in {a.id for a in all_agents} for agent in [search_agent_one, search_agent_two, agent])
|
||||
|
||||
|
||||
|
||||
@@ -304,7 +304,7 @@ async def test_round_robin(server, actor, participant_agents):
|
||||
input_messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="what is everyone up to for the holidays?",
|
||||
content="when should we plan our next adventure?",
|
||||
),
|
||||
],
|
||||
stream_steps=False,
|
||||
|
||||
@@ -1,521 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
from composio.client.collections import AppModel
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.block import Block, BlockUpdate, CreateBlock
|
||||
from letta.schemas.message import UserMessage
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig, PipRequirement, SandboxConfig
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.server.rest_api.app import app
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from tests.helpers.utils import create_tool_from_func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sync_server():
|
||||
mock_server = Mock()
|
||||
app.dependency_overrides[get_letta_server] = lambda: mock_server
|
||||
return mock_server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def add_integers_tool():
|
||||
def add(x: int, y: int) -> int:
|
||||
"""
|
||||
Simple function that adds two integers.
|
||||
|
||||
Parameters:
|
||||
x (int): The first integer to add.
|
||||
y (int): The second integer to add.
|
||||
|
||||
Returns:
|
||||
int: The result of adding x and y.
|
||||
"""
|
||||
return x + y
|
||||
|
||||
tool = create_tool_from_func(add)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_integers_tool(add_integers_tool):
|
||||
tool_create = ToolCreate(
|
||||
description=add_integers_tool.description,
|
||||
tags=add_integers_tool.tags,
|
||||
source_code=add_integers_tool.source_code,
|
||||
source_type=add_integers_tool.source_type,
|
||||
json_schema=add_integers_tool.json_schema,
|
||||
)
|
||||
yield tool_create
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def update_integers_tool(add_integers_tool):
|
||||
tool_update = ToolUpdate(
|
||||
description=add_integers_tool.description,
|
||||
tags=add_integers_tool.tags,
|
||||
source_code=add_integers_tool.source_code,
|
||||
source_type=add_integers_tool.source_type,
|
||||
json_schema=add_integers_tool.json_schema,
|
||||
)
|
||||
yield tool_update
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def composio_apps():
|
||||
affinity_app = AppModel(
|
||||
name="affinity",
|
||||
key="affinity",
|
||||
appId="3a7d2dc7-c58c-4491-be84-f64b1ff498a8",
|
||||
description="Affinity helps private capital investors to find, manage, and close more deals",
|
||||
categories=["CRM"],
|
||||
meta={
|
||||
"is_custom_app": False,
|
||||
"triggersCount": 0,
|
||||
"actionsCount": 20,
|
||||
"documentation_doc_text": None,
|
||||
"configuration_docs_text": None,
|
||||
},
|
||||
logo="https://cdn.jsdelivr.net/gh/ComposioHQ/open-logos@master/affinity.jpeg",
|
||||
docs=None,
|
||||
group=None,
|
||||
status=None,
|
||||
enabled=False,
|
||||
no_auth=False,
|
||||
auth_schemes=None,
|
||||
testConnectors=None,
|
||||
documentation_doc_text=None,
|
||||
configuration_docs_text=None,
|
||||
)
|
||||
yield [affinity_app]
|
||||
|
||||
|
||||
def configure_mock_sync_server(mock_sync_server):
|
||||
# Mock sandbox config manager to return a valid API key
|
||||
mock_api_key = Mock()
|
||||
mock_api_key.value = "mock_composio_api_key"
|
||||
mock_sync_server.sandbox_config_manager.list_sandbox_env_vars_by_key.return_value = [mock_api_key]
|
||||
|
||||
# Mock user retrieval
|
||||
mock_sync_server.user_manager.get_user_or_default.return_value = Mock() # Provide additional attributes if needed
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Tools Routes Tests
|
||||
# ======================================================================================================================
|
||||
def test_delete_tool(client, mock_sync_server, add_integers_tool):
|
||||
mock_sync_server.tool_manager.delete_tool_by_id = MagicMock()
|
||||
|
||||
response = client.delete(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_sync_server.tool_manager.delete_tool_by_id.assert_called_once_with(
|
||||
tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
def test_get_tool(client, mock_sync_server, add_integers_tool):
|
||||
mock_sync_server.tool_manager.get_tool_by_id.return_value = add_integers_tool
|
||||
|
||||
response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == add_integers_tool.id
|
||||
assert response.json()["source_code"] == add_integers_tool.source_code
|
||||
mock_sync_server.tool_manager.get_tool_by_id.assert_called_once_with(
|
||||
tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
def test_get_tool_404(client, mock_sync_server, add_integers_tool):
|
||||
mock_sync_server.tool_manager.get_tool_by_id.return_value = None
|
||||
|
||||
response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == f"Tool with id {add_integers_tool.id} not found."
|
||||
|
||||
|
||||
def test_list_tools(client, mock_sync_server, add_integers_tool):
|
||||
mock_sync_server.tool_manager.list_tools.return_value = [add_integers_tool]
|
||||
|
||||
response = client.get("/v1/tools", headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
assert response.json()[0]["id"] == add_integers_tool.id
|
||||
mock_sync_server.tool_manager.list_tools.assert_called_once()
|
||||
|
||||
|
||||
def test_create_tool(client, mock_sync_server, create_integers_tool, add_integers_tool):
|
||||
mock_sync_server.tool_manager.create_tool.return_value = add_integers_tool
|
||||
|
||||
response = client.post("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == add_integers_tool.id
|
||||
mock_sync_server.tool_manager.create_tool.assert_called_once()
|
||||
|
||||
|
||||
def test_upsert_tool(client, mock_sync_server, create_integers_tool, add_integers_tool):
|
||||
mock_sync_server.tool_manager.create_or_update_tool.return_value = add_integers_tool
|
||||
|
||||
response = client.put("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == add_integers_tool.id
|
||||
mock_sync_server.tool_manager.create_or_update_tool.assert_called_once()
|
||||
|
||||
|
||||
def test_update_tool(client, mock_sync_server, update_integers_tool, add_integers_tool):
|
||||
mock_sync_server.tool_manager.update_tool_by_id.return_value = add_integers_tool
|
||||
|
||||
response = client.patch(f"/v1/tools/{add_integers_tool.id}", json=update_integers_tool.model_dump(), headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == add_integers_tool.id
|
||||
mock_sync_server.tool_manager.update_tool_by_id.assert_called_once_with(
|
||||
tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
def test_upsert_base_tools(client, mock_sync_server, add_integers_tool):
|
||||
mock_sync_server.tool_manager.upsert_base_tools.return_value = [add_integers_tool]
|
||||
|
||||
response = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
assert response.json()[0]["id"] == add_integers_tool.id
|
||||
mock_sync_server.tool_manager.upsert_base_tools.assert_called_once_with(
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Runs Routes Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_get_run_messages(client, mock_sync_server):
|
||||
"""Test getting messages for a run."""
|
||||
# Create properly formatted mock messages
|
||||
current_time = datetime.now(timezone.utc)
|
||||
mock_messages = [
|
||||
UserMessage(
|
||||
id=f"message-{i:08x}",
|
||||
date=current_time,
|
||||
content=f"Test message {i}",
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
# Configure mock server responses
|
||||
mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
|
||||
mock_sync_server.job_manager.get_run_messages.return_value = mock_messages
|
||||
|
||||
# Test successful retrieval
|
||||
response = client.get(
|
||||
"/v1/runs/run-12345678/messages",
|
||||
headers={"user_id": "user-123"},
|
||||
params={
|
||||
"limit": 10,
|
||||
"before": "message-1234",
|
||||
"after": "message-6789",
|
||||
"role": "user",
|
||||
"order": "desc",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
assert response.json()[0]["id"] == mock_messages[0].id
|
||||
assert response.json()[1]["id"] == mock_messages[1].id
|
||||
|
||||
# Verify mock calls
|
||||
mock_sync_server.user_manager.get_user_or_default.assert_called_once_with(user_id="user-123")
|
||||
mock_sync_server.job_manager.get_run_messages.assert_called_once_with(
|
||||
run_id="run-12345678",
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
|
||||
limit=10,
|
||||
before="message-1234",
|
||||
after="message-6789",
|
||||
ascending=False,
|
||||
role="user",
|
||||
)
|
||||
|
||||
|
||||
def test_get_run_messages_not_found(client, mock_sync_server):
|
||||
"""Test getting messages for a non-existent run."""
|
||||
# Configure mock responses
|
||||
error_message = "Run 'run-nonexistent' not found"
|
||||
mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
|
||||
mock_sync_server.job_manager.get_run_messages.side_effect = NoResultFound(error_message)
|
||||
|
||||
response = client.get("/v1/runs/run-nonexistent/messages", headers={"user_id": "user-123"})
|
||||
|
||||
assert response.status_code == 404
|
||||
assert error_message in response.json()["detail"]
|
||||
|
||||
|
||||
def test_get_run_usage(client, mock_sync_server):
|
||||
"""Test getting usage statistics for a run."""
|
||||
# Configure mock responses
|
||||
mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
|
||||
mock_usage = Mock(
|
||||
completion_tokens=100,
|
||||
prompt_tokens=200,
|
||||
total_tokens=300,
|
||||
)
|
||||
mock_sync_server.job_manager.get_job_usage.return_value = mock_usage
|
||||
|
||||
# Make request
|
||||
response = client.get("/v1/runs/run-12345678/usage", headers={"user_id": "user-123"})
|
||||
|
||||
# Check response
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"completion_tokens": 100,
|
||||
"prompt_tokens": 200,
|
||||
"total_tokens": 300,
|
||||
}
|
||||
|
||||
# Verify mock calls
|
||||
mock_sync_server.user_manager.get_user_or_default.assert_called_once_with(user_id="user-123")
|
||||
mock_sync_server.job_manager.get_job_usage.assert_called_once_with(
|
||||
job_id="run-12345678",
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_get_run_usage_not_found(client, mock_sync_server):
|
||||
"""Test getting usage statistics for a non-existent run."""
|
||||
# Configure mock responses
|
||||
error_message = "Run 'run-nonexistent' not found"
|
||||
mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
|
||||
mock_sync_server.job_manager.get_job_usage.side_effect = NoResultFound(error_message)
|
||||
|
||||
# Make request
|
||||
response = client.get("/v1/runs/run-nonexistent/usage", headers={"user_id": "user-123"})
|
||||
|
||||
assert response.status_code == 404
|
||||
assert error_message in response.json()["detail"]
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Tags Routes Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_get_tags(client, mock_sync_server):
|
||||
"""Test basic tag listing"""
|
||||
mock_sync_server.agent_manager.list_tags.return_value = ["tag1", "tag2"]
|
||||
|
||||
response = client.get("/v1/tags", headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["tag1", "tag2"]
|
||||
mock_sync_server.agent_manager.list_tags.assert_called_once_with(
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value, after=None, limit=50, query_text=None
|
||||
)
|
||||
|
||||
|
||||
def test_get_tags_with_pagination(client, mock_sync_server):
|
||||
"""Test tag listing with pagination parameters"""
|
||||
mock_sync_server.agent_manager.list_tags.return_value = ["tag3", "tag4"]
|
||||
|
||||
response = client.get("/v1/tags", params={"after": "tag2", "limit": 2}, headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["tag3", "tag4"]
|
||||
mock_sync_server.agent_manager.list_tags.assert_called_once_with(
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value, after="tag2", limit=2, query_text=None
|
||||
)
|
||||
|
||||
|
||||
def test_get_tags_with_search(client, mock_sync_server):
|
||||
"""Test tag listing with text search"""
|
||||
mock_sync_server.agent_manager.list_tags.return_value = ["user_tag1", "user_tag2"]
|
||||
|
||||
response = client.get("/v1/tags", params={"query_text": "user"}, headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["user_tag1", "user_tag2"]
|
||||
mock_sync_server.agent_manager.list_tags.assert_called_once_with(
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value, after=None, limit=50, query_text="user"
|
||||
)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Blocks Routes Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_list_blocks(client, mock_sync_server):
|
||||
"""
|
||||
Test the GET /v1/blocks endpoint to list blocks.
|
||||
"""
|
||||
# Arrange: mock return from block_manager
|
||||
mock_block = Block(label="human", value="Hi")
|
||||
mock_sync_server.block_manager.get_blocks.return_value = [mock_block]
|
||||
|
||||
# Act
|
||||
response = client.get("/v1/blocks", headers={"user_id": "test_user"})
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == mock_block.id
|
||||
mock_sync_server.block_manager.get_blocks.assert_called_once_with(
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
|
||||
label=None,
|
||||
is_template=False,
|
||||
template_name=None,
|
||||
identity_id=None,
|
||||
identifier_keys=None,
|
||||
)
|
||||
|
||||
|
||||
def test_create_block(client, mock_sync_server):
|
||||
"""
|
||||
Test the POST /v1/blocks endpoint to create a block.
|
||||
"""
|
||||
new_block = CreateBlock(label="system", value="Some system text")
|
||||
returned_block = Block(**new_block.model_dump())
|
||||
|
||||
mock_sync_server.block_manager.create_or_update_block.return_value = returned_block
|
||||
|
||||
response = client.post("/v1/blocks", json=new_block.model_dump(), headers={"user_id": "test_user"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == returned_block.id
|
||||
|
||||
mock_sync_server.block_manager.create_or_update_block.assert_called_once()
|
||||
|
||||
|
||||
def test_modify_block(client, mock_sync_server):
|
||||
"""
|
||||
Test the PATCH /v1/blocks/{block_id} endpoint to update a block.
|
||||
"""
|
||||
block_update = BlockUpdate(value="Updated text", description="New description")
|
||||
updated_block = Block(label="human", value="Updated text", description="New description")
|
||||
mock_sync_server.block_manager.update_block.return_value = updated_block
|
||||
|
||||
response = client.patch(f"/v1/blocks/{updated_block.id}", json=block_update.model_dump(), headers={"user_id": "test_user"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["value"] == "Updated text"
|
||||
assert data["description"] == "New description"
|
||||
|
||||
mock_sync_server.block_manager.update_block.assert_called_once_with(
|
||||
block_id=updated_block.id,
|
||||
block_update=block_update,
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_delete_block(client, mock_sync_server):
|
||||
"""
|
||||
Test the DELETE /v1/blocks/{block_id} endpoint.
|
||||
"""
|
||||
deleted_block = Block(label="persona", value="Deleted text")
|
||||
mock_sync_server.block_manager.delete_block.return_value = deleted_block
|
||||
|
||||
response = client.delete(f"/v1/blocks/{deleted_block.id}", headers={"user_id": "test_user"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == deleted_block.id
|
||||
|
||||
mock_sync_server.block_manager.delete_block.assert_called_once_with(
|
||||
block_id=deleted_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
def test_retrieve_block(client, mock_sync_server):
|
||||
"""
|
||||
Test the GET /v1/blocks/{block_id} endpoint.
|
||||
"""
|
||||
existing_block = Block(label="human", value="Hello")
|
||||
mock_sync_server.block_manager.get_block_by_id.return_value = existing_block
|
||||
|
||||
response = client.get(f"/v1/blocks/{existing_block.id}", headers={"user_id": "test_user"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == existing_block.id
|
||||
|
||||
mock_sync_server.block_manager.get_block_by_id.assert_called_once_with(
|
||||
block_id=existing_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
def test_retrieve_block_404(client, mock_sync_server):
|
||||
"""
|
||||
Test that retrieving a non-existent block returns 404.
|
||||
"""
|
||||
mock_sync_server.block_manager.get_block_by_id.return_value = None
|
||||
|
||||
response = client.get("/v1/blocks/block-999", headers={"user_id": "test_user"})
|
||||
assert response.status_code == 404
|
||||
assert "Block not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_list_agents_for_block(client, mock_sync_server):
|
||||
"""
|
||||
Test the GET /v1/blocks/{block_id}/agents endpoint.
|
||||
"""
|
||||
mock_sync_server.block_manager.get_agents_for_block.return_value = []
|
||||
|
||||
response = client.get("/v1/blocks/block-abc/agents", headers={"user_id": "test_user"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 0
|
||||
|
||||
mock_sync_server.block_manager.get_agents_for_block.assert_called_once_with(
|
||||
block_id="block-abc",
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
|
||||
)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Sandbox Config Routes Tests
|
||||
# ======================================================================================================================
|
||||
@pytest.fixture
|
||||
def sample_local_sandbox_config():
|
||||
"""Fixture for a sample LocalSandboxConfig object."""
|
||||
return LocalSandboxConfig(
|
||||
sandbox_dir="/custom/path",
|
||||
force_create_venv=True,
|
||||
venv_name="custom_venv_name",
|
||||
pip_requirements=[
|
||||
PipRequirement(name="numpy", version="1.23.0"),
|
||||
PipRequirement(name="pandas"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_create_custom_local_sandbox_config(client, mock_sync_server, sample_local_sandbox_config):
|
||||
"""Test creating or updating a LocalSandboxConfig."""
|
||||
mock_sync_server.sandbox_config_manager.create_or_update_sandbox_config.return_value = SandboxConfig(
|
||||
type="local", organization_id="org-123", config=sample_local_sandbox_config.model_dump()
|
||||
)
|
||||
|
||||
response = client.post("/v1/sandbox-config/local", json=sample_local_sandbox_config.model_dump(), headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["type"] == "local"
|
||||
assert response.json()["config"]["sandbox_dir"] == "/custom/path"
|
||||
assert response.json()["config"]["pip_requirements"] == [
|
||||
{"name": "numpy", "version": "1.23.0"},
|
||||
{"name": "pandas", "version": None},
|
||||
]
|
||||
|
||||
mock_sync_server.sandbox_config_manager.create_or_update_sandbox_config.assert_called_once()
|
||||
Reference in New Issue
Block a user