chore: bump 0.7.5 (#2587)

Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
This commit is contained in:
cthomas
2025-04-24 17:59:39 -07:00
committed by GitHub
parent 7299142f8e
commit 8cf5784258
25 changed files with 422 additions and 610 deletions

View File

@@ -64,8 +64,7 @@ ENV LETTA_ENVIRONMENT=${LETTA_ENVIRONMENT} \
POSTGRES_USER=letta \
POSTGRES_PASSWORD=letta \
POSTGRES_DB=letta \
COMPOSIO_DISABLE_VERSION_CHECK=true \
LETTA_OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4317"
COMPOSIO_DISABLE_VERSION_CHECK=true
WORKDIR /app

View File

@@ -49,6 +49,7 @@ services:
- VLLM_API_BASE=${VLLM_API_BASE}
- OPENLLM_AUTH_TYPE=${OPENLLM_AUTH_TYPE}
- OPENLLM_API_KEY=${OPENLLM_API_KEY}
- LETTA_OTEL_EXPORTER_OTLP_ENDPOINT=${LETTA_OTEL_EXPORTER_OTLP_ENDPOINT}
- CLICKHOUSE_ENDPOINT=${CLICKHOUSE_ENDPOINT}
- CLICKHOUSE_DATABASE=${CLICKHOUSE_DATABASE}
- CLICKHOUSE_USERNAME=${CLICKHOUSE_USERNAME}

View File

@@ -1,4 +1,4 @@
__version__ = "0.7.4"
__version__ = "0.7.5"
# import clients
from letta.client.client import LocalClient, RESTClient, create_client

View File

@@ -190,16 +190,15 @@ class Agent(BaseAgent):
Returns:
modified (bool): whether the memory was updated
"""
if self.agent_state.memory.compile() != new_memory.compile():
system_message = self.message_manager.get_message_by_id(message_id=self.agent_state.message_ids[0], actor=self.user)
if new_memory.compile() not in system_message.content[0].text:
# update the blocks (LRW) in the DB
for label in self.agent_state.memory.list_block_labels():
updated_value = new_memory.get_block(label).value
if updated_value != self.agent_state.memory.get_block(label).value:
# update the block if it's changed
block_id = self.agent_state.memory.get_block(label).id
block = self.block_manager.update_block(
block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=self.user
)
self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=self.user)
# refresh memory from DB (using block ids)
self.agent_state.memory = Memory(

View File

@@ -1233,7 +1233,10 @@ class AzureProvider(Provider):
"""
This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model.
"""
return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)
context_window = AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, None)
if context_window is None:
context_window = LLM_MAX_TOKENS.get(model_name, 4096)
return context_window
class VLLMChatCompletionsProvider(Provider):

View File

@@ -104,6 +104,17 @@ def list_agents(
)
@router.get("/count", response_model=int, operation_id="count_agents")
def count_agents(
    server: SyncServer = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Get the count of all agents associated with a given user.
    """
    # Resolve the acting user (falls back to the default user when the header is absent).
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    return server.agent_manager.size(actor=actor)
class IndentedORJSONResponse(Response):
media_type = "application/json"

View File

@@ -49,6 +49,24 @@ def list_identities(
return identities
@router.get("/count", tags=["identities"], response_model=int, operation_id="count_identities")
def count_identities(
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Get count of all identities for a user
    """
    try:
        actor = server.user_manager.get_user_or_default(user_id=actor_id)
        return server.identity_manager.size(actor=actor)
    except NoResultFound:
        # No identity rows for this user yet — report zero instead of erroring.
        return 0
    except HTTPException:
        # Re-raise explicit HTTP errors untouched so their status codes survive.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"{e}")
@router.get("/{identity_id}", tags=["identities"], response_model=Identity, operation_id="retrieve_identity")
def retrieve_identity(
identity_id: str,

View File

@@ -67,6 +67,17 @@ def list_sources(
return server.list_all_sources(actor=actor)
@router.get("/count", response_model=int, operation_id="count_sources")
def count_sources(
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Count all data sources created by a user.
    """
    actor = server.user_manager.get_user_or_default(user_id=actor_id)
    return server.source_manager.size(actor=actor)
@router.post("/", response_model=Source, operation_id="create_source")
def create_source(
source_create: SourceCreate,

View File

@@ -80,6 +80,21 @@ def list_tools(
raise HTTPException(status_code=500, detail=str(e))
@router.get("/count", response_model=int, operation_id="count_tools")
def count_tools(
    server: SyncServer = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Get a count of all tools available to agents belonging to the org of the user
    """
    # Local import: no module-level logger is visible in this route file — TODO confirm
    # and hoist to a module-level `logger` if one exists.
    import logging

    try:
        actor = server.user_manager.get_user_or_default(user_id=actor_id)
        return server.tool_manager.size(actor=actor)
    except Exception as e:
        # Fix: use logging with a traceback instead of a bare print(), so failures
        # actually reach the server logs rather than stdout.
        logging.getLogger(__name__).exception("Error counting tools")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/", response_model=Tool, operation_id="create_tool")
def create_tool(
request: ToolCreate = Body(...),

View File

@@ -1308,14 +1308,12 @@ class SyncServer(Server):
tool_execution_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
agent_state=agent_state, additional_env_vars=tool_env_vars
)
status = "error" if tool_execution_result.stderr else "success"
tool_return = str(tool_execution_result.stderr) if tool_execution_result.stderr else str(tool_execution_result.func_return)
return ToolReturnMessage(
id="null",
tool_call_id="null",
date=get_utc_time(),
status=status,
tool_return=tool_return,
status=tool_execution_result.status,
tool_return=str(tool_execution_result.func_return),
stdout=tool_execution_result.stdout,
stderr=tool_execution_result.stderr,
)

View File

@@ -556,6 +556,16 @@ class AgentManager:
return list(session.execute(query).scalars())
def size(
self,
actor: PydanticUser,
) -> int:
"""
Get the total count of agents for the given user.
"""
with self.session_maker() as session:
return AgentModel.size(db_session=session, actor=actor)
@enforce_types
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Fetch an agent by its ID."""
@@ -590,15 +600,18 @@ class AgentManager:
agents_to_delete = [agent]
sleeptime_group_to_delete = None
# Delete sleeptime agent and group
# Delete sleeptime agent and group (TODO this is flimsy pls fix)
if agent.multi_agent_group:
participant_agent_ids = agent.multi_agent_group.agent_ids
if agent.multi_agent_group.manager_type == ManagerType.sleeptime and len(participant_agent_ids) == 1:
sleeptime_agent = AgentModel.read(db_session=session, identifier=participant_agent_ids[0], actor=actor)
if sleeptime_agent.agent_type == AgentType.sleeptime_agent:
sleeptime_agent_group = GroupModel.read(db_session=session, identifier=agent.multi_agent_group.id, actor=actor)
sleeptime_group_to_delete = sleeptime_agent_group
try:
sleeptime_agent = AgentModel.read(db_session=session, identifier=participant_agent_ids[0], actor=actor)
agents_to_delete.append(sleeptime_agent)
except NoResultFound:
pass # agent already deleted
sleeptime_agent_group = GroupModel.read(db_session=session, identifier=agent.multi_agent_group.id, actor=actor)
sleeptime_group_to_delete = sleeptime_agent_group
try:
if sleeptime_group_to_delete is not None:
session.delete(sleeptime_group_to_delete)
@@ -931,7 +944,8 @@ class AgentManager:
modified (bool): whether the memory was updated
"""
agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
if agent_state.memory.compile() != new_memory.compile():
system_message = self.message_manager.get_message_by_id(message_id=agent_state.message_ids[0], actor=actor)
if new_memory.compile() not in system_message.content[0].text:
# update the blocks (LRW) in the DB
for label in agent_state.memory.list_block_labels():
updated_value = new_memory.get_block(label).value

View File

@@ -190,6 +190,17 @@ class IdentityManager:
session.delete(identity)
session.commit()
@enforce_types
def size(
self,
actor: PydanticUser,
) -> int:
"""
Get the total count of identities for the given user.
"""
with self.session_maker() as session:
return IdentityModel.size(db_session=session, actor=actor)
def _process_relationship(
self,
session: Session,

View File

View File

@@ -0,0 +1,67 @@
from contextlib import AsyncExitStack
from typing import Optional, Tuple
from mcp import ClientSession
from mcp import Tool as MCPTool
from mcp.types import TextContent
from letta.functions.mcp_client.types import BaseServerConfig
from letta.log import get_logger
logger = get_logger(__name__)
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
class AsyncBaseMCPClient:
    """Shared lifecycle for async MCP clients: connect, list/execute tools, cleanup.

    Subclasses implement `_initialize_connection` to open their transport
    (SSE, stdio, ...) on the provided exit stack and must set `self.session`
    to a `ClientSession` wrapping that transport.
    """

    def __init__(self, server_config: BaseServerConfig):
        self.server_config = server_config
        # Owns the transport and session context managers until cleanup().
        self.exit_stack = AsyncExitStack()
        self.session: Optional[ClientSession] = None
        self.initialized = False

    async def connect_to_server(self):
        """Open the transport, run the MCP handshake, and mark the client ready."""
        try:
            # Fix: `_initialize_connection` takes (exit_stack, server_config) — see its
            # signature and both subclass overrides. The previous call passed only the
            # server config, which raised TypeError before any connection was made.
            await self._initialize_connection(self.exit_stack, self.server_config)
            await self.session.initialize()
            self.initialized = True
        except Exception as e:
            logger.error(
                f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}"
            )
            raise e

    async def _initialize_connection(self, exit_stack: AsyncExitStack, server_config: BaseServerConfig) -> None:
        raise NotImplementedError("Subclasses must implement _initialize_connection")

    async def list_tools(self) -> list[MCPTool]:
        """Return the tool definitions advertised by the connected server."""
        self._check_initialized()
        response = await self.session.list_tools()
        return response.tools

    async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
        """Invoke `tool_name` with `tool_args`; return (joined text output, isError flag)."""
        self._check_initialized()
        result = await self.session.call_tool(tool_name, tool_args)
        # Flatten the content pieces to plain text; non-text pieces are stringified.
        # (Removed stray debug print() calls that leaked every tool result to stdout.)
        parsed_content = []
        for content_piece in result.content:
            if isinstance(content_piece, TextContent):
                parsed_content.append(content_piece.text)
            else:
                parsed_content.append(str(content_piece))
        if parsed_content:
            final_content = " ".join(parsed_content)
        else:
            # TODO move hardcoding to constants
            final_content = "Empty response from tool"
        return final_content, result.isError

    def _check_initialized(self):
        # Guard against use before connect_to_server() succeeded.
        if not self.initialized:
            logger.error("MCPClient has not been initialized")
            raise RuntimeError("MCPClient has not been initialized")

    async def cleanup(self):
        """Clean up resources"""
        await self.exit_stack.aclose()

View File

@@ -0,0 +1,25 @@
from contextlib import AsyncExitStack
from mcp import ClientSession
from mcp.client.sse import sse_client
from letta.functions.mcp_client.types import SSEServerConfig
from letta.log import get_logger
from letta.services.mcp.base_client import AsyncBaseMCPClient
# see: https://modelcontextprotocol.io/quickstart/user
MCP_CONFIG_TOPLEVEL_KEY = "mcpServers"
logger = get_logger(__name__)
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
class AsyncSSEMCPClient(AsyncBaseMCPClient):
    """MCP client that reaches the server over Server-Sent Events (SSE)."""

    async def _initialize_connection(self, exit_stack: AsyncExitStack[bool | None], server_config: SSEServerConfig) -> None:
        # Open the SSE transport and keep it alive on the shared exit stack.
        transport = await exit_stack.enter_async_context(sse_client(url=server_config.server_url))
        self.stdio, self.write = transport
        # Wrap the transport streams in a ClientSession, also owned by the stack.
        self.session = await exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

View File

@@ -0,0 +1,19 @@
from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from letta.functions.mcp_client.types import StdioServerConfig
from letta.log import get_logger
from letta.services.mcp.base_client import AsyncBaseMCPClient
logger = get_logger(__name__)
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
class AsyncStdioMCPClient(AsyncBaseMCPClient):
    """MCP client that launches a local server process and speaks over its stdio pipes."""

    async def _initialize_connection(self, exit_stack: AsyncExitStack[bool | None], server_config: StdioServerConfig) -> None:
        # Fix: forward `server_config.env` — the config declares it and serializes it in
        # to_dict(), but it was never passed to the subprocess, silently dropping any
        # configured environment variables. `env=None` preserves the old default behavior.
        server_params = StdioServerParameters(command=server_config.command, args=server_config.args, env=server_config.env)
        stdio_transport = await exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

View File

@@ -0,0 +1,48 @@
from enum import Enum
from typing import List, Optional
from mcp import Tool
from pydantic import BaseModel, Field
class MCPTool(Tool):
    """Letta-local alias of MCP's `Tool` schema, renamed to avoid clashing with our own `Tool` class."""
class MCPServerType(str, Enum):
    """Transport flavors an MCP server configuration may declare."""

    SSE = "sse"
    STDIO = "stdio"
class BaseServerConfig(BaseModel):
    """Fields shared by every MCP server configuration variant."""

    server_name: str = Field(..., description="The name of the server")
    type: MCPServerType
class SSEServerConfig(BaseServerConfig):
    """Configuration for an MCP server reached over SSE."""

    type: MCPServerType = MCPServerType.SSE
    server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)")

    def to_dict(self) -> dict:
        """Render this config in the mcp-config dictionary shape."""
        return {
            "transport": "sse",
            "url": self.server_url,
        }
class StdioServerConfig(BaseServerConfig):
    """Configuration for an MCP server launched as a local process over stdio."""

    type: MCPServerType = MCPServerType.STDIO
    command: str = Field(..., description="The command to run (MCP 'local' client will run this command)")
    args: List[str] = Field(..., description="The arguments to pass to the command")
    env: Optional[dict[str, str]] = Field(None, description="Environment variables to set")

    def to_dict(self) -> dict:
        """Render this config in the mcp-config dictionary shape; `env` is omitted when unset."""
        out = {
            "transport": "stdio",
            "command": self.command,
            "args": self.args,
        }
        if self.env is not None:
            out["env"] = self.env
        return out

View File

@@ -77,6 +77,17 @@ class SourceManager:
)
return [source.to_pydantic() for source in sources]
@enforce_types
def size(
self,
actor: PydanticUser,
) -> int:
"""
Get the total count of sources for the given user.
"""
with self.session_maker() as session:
return SourceModel.size(db_session=session, actor=actor)
@enforce_types
def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
"""

View File

@@ -1,10 +1,12 @@
import ast
import base64
import io
import os
import pickle
import subprocess
import sys
import tempfile
import traceback
import uuid
from typing import Any, Dict, Optional
@@ -117,98 +119,108 @@ class ToolExecutionSandbox:
@trace_method
def run_local_dir_sandbox(
self,
agent_state: Optional[AgentState] = None,
additional_env_vars: Optional[Dict] = None,
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
) -> ToolExecutionResult:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(
sandbox_type=SandboxType.LOCAL,
actor=self.user,
)
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
local_configs = sbx_config.get_local_config()
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
venv_path = os.path.join(sandbox_dir, local_configs.venv_name)
# Aggregate environment variables
# Get environment variables for the sandbox
env = os.environ.copy()
env.update(self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100))
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
env.update(env_vars)
# Get environment variables for this agent specifically
if agent_state:
env.update(agent_state.get_agent_env_vars_as_dict())
# Finally, get any that are passed explicitly into the `run` function call
if additional_env_vars:
env.update(additional_env_vars)
# Ensure sandbox dir exists
if not os.path.exists(sandbox_dir):
logger.warning(f"Sandbox directory does not exist, creating: {sandbox_dir}")
os.makedirs(sandbox_dir)
# Safety checks
if not os.path.exists(local_configs.sandbox_dir) or not os.path.isdir(local_configs.sandbox_dir):
logger.warning(f"Sandbox directory does not exist, creating: {local_configs.sandbox_dir}")
os.makedirs(local_configs.sandbox_dir)
# Write the code to a temp file in the sandbox_dir
with tempfile.NamedTemporaryFile(mode="w", dir=local_configs.sandbox_dir, suffix=".py", delete=False) as temp_file:
if local_configs.force_create_venv:
# If using venv, we need to wrap with special string markers to separate out the output and the stdout (since it is all in stdout)
code = self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True)
else:
code = self.generate_execution_script(agent_state=agent_state)
# Write the code to a temp file
with tempfile.NamedTemporaryFile(mode="w", dir=sandbox_dir, suffix=".py", delete=False) as temp_file:
code = self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True)
temp_file.write(code)
temp_file.flush()
temp_file_path = temp_file.name
try:
# Decide whether to use venv
use_venv = os.path.isdir(venv_path)
if self.force_recreate_venv or (not use_venv and local_configs.force_create_venv):
logger.warning(f"Virtual environment not found at {venv_path}. Creating one...")
log_event(name="start create_venv_for_local_sandbox", attributes={"venv_path": venv_path})
create_venv_for_local_sandbox(
sandbox_dir_path=sandbox_dir,
venv_path=venv_path,
env=env,
force_recreate=self.force_recreate_venv,
)
log_event(name="finish create_venv_for_local_sandbox")
use_venv = True
if use_venv:
log_event(name="start install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
install_pip_requirements_for_sandbox(local_configs, env=env)
log_event(name="finish install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
python_executable = find_python_executable(local_configs)
if not os.path.isfile(python_executable):
logger.warning(
f"Python executable not found at expected venv path: {python_executable}. Falling back to system Python."
)
python_executable = sys.executable
else:
env = dict(env)
env["VIRTUAL_ENV"] = venv_path
env["PATH"] = os.path.join(venv_path, "bin") + ":" + env.get("PATH", "")
if local_configs.force_create_venv:
return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path)
else:
python_executable = sys.executable
return self.run_local_dir_sandbox_directly(sbx_config, env, temp_file_path)
except Exception as e:
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}")
raise e
finally:
# Clean up the temp file
os.remove(temp_file_path)
env["PYTHONWARNINGS"] = "ignore"
@trace_method
def run_local_dir_sandbox_venv(
self,
sbx_config: SandboxConfig,
env: Dict[str, str],
temp_file_path: str,
) -> ToolExecutionResult:
local_configs = sbx_config.get_local_config()
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) # Expand tilde
venv_path = os.path.join(sandbox_dir, local_configs.venv_name)
# Recreate venv if required
if self.force_recreate_venv or not os.path.isdir(venv_path):
logger.warning(f"Virtual environment directory does not exist at: {venv_path}, creating one now...")
log_event(name="start create_venv_for_local_sandbox", attributes={"venv_path": venv_path})
create_venv_for_local_sandbox(
sandbox_dir_path=sandbox_dir, venv_path=venv_path, env=env, force_recreate=self.force_recreate_venv
)
log_event(name="finish create_venv_for_local_sandbox")
log_event(name="start install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
install_pip_requirements_for_sandbox(local_configs, env=env)
log_event(name="finish install_pip_requirements_for_sandbox", attributes={"local_configs": local_configs.model_dump_json()})
# Ensure Python executable exists
python_executable = find_python_executable(local_configs)
if not os.path.isfile(python_executable):
raise FileNotFoundError(f"Python executable not found in virtual environment: {python_executable}")
# Set up environment variables
env["VIRTUAL_ENV"] = venv_path
env["PATH"] = os.path.join(venv_path, "bin") + ":" + env["PATH"]
env["PYTHONWARNINGS"] = "ignore"
# Execute the code
try:
log_event(name="start subprocess")
result = subprocess.run(
[python_executable, temp_file_path],
env=env,
cwd=sandbox_dir,
timeout=60,
capture_output=True,
text=True,
[python_executable, temp_file_path], env=env, cwd=sandbox_dir, timeout=60, capture_output=True, text=True, check=True
)
log_event(name="finish subprocess")
func_result, stdout = self.parse_out_function_results_markers(result.stdout)
func_return, parsed_agent_state = self.parse_best_effort(func_result)
func_return, agent_state = self.parse_best_effort(func_result)
return ToolExecutionResult(
status="success",
func_return=func_return,
agent_state=parsed_agent_state,
agent_state=agent_state,
stdout=[stdout] if stdout else [],
stderr=[result.stderr] if result.stderr else [],
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
except subprocess.CalledProcessError as e:
logger.error(f"Tool execution failed: {e}")
logger.error(f"Executing tool {self.tool_name} has process error: {e}")
func_return = get_friendly_error_msg(
function_name=self.tool_name,
exception_name=type(e).__name__,
@@ -228,11 +240,72 @@ class ToolExecutionSandbox:
except Exception as e:
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
logger.error(f"Generated script:\n{code}")
raise e
finally:
os.remove(temp_file_path)
@trace_method
def run_local_dir_sandbox_directly(
self,
sbx_config: SandboxConfig,
env: Dict[str, str],
temp_file_path: str,
) -> ToolExecutionResult:
status = "success"
func_return, agent_state, stderr = None, None, None
old_stdout = sys.stdout
old_stderr = sys.stderr
captured_stdout, captured_stderr = io.StringIO(), io.StringIO()
sys.stdout = captured_stdout
sys.stderr = captured_stderr
try:
with self.temporary_env_vars(env):
# Read and compile the Python script
with open(temp_file_path, "r", encoding="utf-8") as f:
source = f.read()
code_obj = compile(source, temp_file_path, "exec")
# Provide a dict for globals.
globals_dict = dict(env) # or {}
# If you need to mimic `__main__` behavior:
globals_dict["__name__"] = "__main__"
globals_dict["__file__"] = temp_file_path
# Execute the compiled code
log_event(name="start exec", attributes={"temp_file_path": temp_file_path})
exec(code_obj, globals_dict)
log_event(name="finish exec", attributes={"temp_file_path": temp_file_path})
# Get result from the global dict
func_result = globals_dict.get(self.LOCAL_SANDBOX_RESULT_VAR_NAME)
func_return, agent_state = self.parse_best_effort(func_result)
except Exception as e:
func_return = get_friendly_error_msg(
function_name=self.tool_name,
exception_name=type(e).__name__,
exception_message=str(e),
)
traceback.print_exc(file=sys.stderr)
status = "error"
# Restore stdout/stderr
sys.stdout = old_stdout
sys.stderr = old_stderr
stdout_output = [captured_stdout.getvalue()] if captured_stdout.getvalue() else []
stderr_output = [captured_stderr.getvalue()] if captured_stderr.getvalue() else []
return ToolExecutionResult(
status=status,
func_return=func_return,
agent_state=agent_state,
stdout=stdout_output,
stderr=stderr_output,
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
def parse_out_function_results_markers(self, text: str):
if self.LOCAL_SANDBOX_RESULT_START_MARKER not in text:

View File

@@ -145,6 +145,17 @@ class ToolManager:
return results
@enforce_types
def size(
self,
actor: PydanticUser,
) -> int:
"""
Get the total count of tools for the given user.
"""
with self.session_maker() as session:
return ToolModel.size(db_session=session, actor=actor)
@enforce_types
def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser) -> PydanticTool:
"""Update a tool by its ID with the given ToolUpdate object."""

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.7.4"
version = "0.7.5"
packages = [
{include = "letta"},
]

View File

@@ -152,7 +152,7 @@ async def test_sleeptime_group_chat(server, actor):
assert len(agent_runs) == len(run_ids)
# 6. Verify run status after sleep
time.sleep(8)
time.sleep(10)
for run_id in run_ids:
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
assert job.status == JobStatus.completed

View File

@@ -505,9 +505,8 @@ def test_function_always_error(client: Letta):
assert response_message, "ToolReturnMessage message not found in response"
assert response_message.status == "error"
assert (
response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero"
), response_message.tool_return
assert "Error executing function testing_method" in response_message.tool_return, response_message.tool_return
assert "ZeroDivisionError: division by zero" in response_message.stderr[0]
client.agents.delete(agent_id=agent.id)
@@ -642,9 +641,9 @@ def test_agent_listing(client: Letta, agent, search_agent_one, search_agent_two)
assert len(all_ids) == 2
assert all_ids == {search_agent_one.id, search_agent_two.id}
# Test listing without any filters
# Test listing without any filters; make less flakey by checking we have at least 3 agents in case created elsewhere
all_agents = client.agents.list()
assert len(all_agents) == 3
assert len(all_agents) >= 3
assert all(agent.id in {a.id for a in all_agents} for agent in [search_agent_one, search_agent_two, agent])

View File

@@ -304,7 +304,7 @@ async def test_round_robin(server, actor, participant_agents):
input_messages=[
MessageCreate(
role="user",
content="what is everyone up to for the holidays?",
content="when should we plan our next adventure?",
),
],
stream_steps=False,

View File

@@ -1,521 +0,0 @@
from datetime import datetime, timezone
from unittest.mock import MagicMock, Mock
import pytest
from composio.client.collections import AppModel
from fastapi.testclient import TestClient
from letta.orm.errors import NoResultFound
from letta.schemas.block import Block, BlockUpdate, CreateBlock
from letta.schemas.message import UserMessage
from letta.schemas.sandbox_config import LocalSandboxConfig, PipRequirement, SandboxConfig
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.server.rest_api.app import app
from letta.server.rest_api.utils import get_letta_server
from tests.helpers.utils import create_tool_from_func
@pytest.fixture
def client():
return TestClient(app)
@pytest.fixture
def mock_sync_server():
mock_server = Mock()
app.dependency_overrides[get_letta_server] = lambda: mock_server
return mock_server
@pytest.fixture
def add_integers_tool():
def add(x: int, y: int) -> int:
"""
Simple function that adds two integers.
Parameters:
x (int): The first integer to add.
y (int): The second integer to add.
Returns:
int: The result of adding x and y.
"""
return x + y
tool = create_tool_from_func(add)
yield tool
@pytest.fixture
def create_integers_tool(add_integers_tool):
tool_create = ToolCreate(
description=add_integers_tool.description,
tags=add_integers_tool.tags,
source_code=add_integers_tool.source_code,
source_type=add_integers_tool.source_type,
json_schema=add_integers_tool.json_schema,
)
yield tool_create
@pytest.fixture
def update_integers_tool(add_integers_tool):
tool_update = ToolUpdate(
description=add_integers_tool.description,
tags=add_integers_tool.tags,
source_code=add_integers_tool.source_code,
source_type=add_integers_tool.source_type,
json_schema=add_integers_tool.json_schema,
)
yield tool_update
@pytest.fixture
def composio_apps():
affinity_app = AppModel(
name="affinity",
key="affinity",
appId="3a7d2dc7-c58c-4491-be84-f64b1ff498a8",
description="Affinity helps private capital investors to find, manage, and close more deals",
categories=["CRM"],
meta={
"is_custom_app": False,
"triggersCount": 0,
"actionsCount": 20,
"documentation_doc_text": None,
"configuration_docs_text": None,
},
logo="https://cdn.jsdelivr.net/gh/ComposioHQ/open-logos@master/affinity.jpeg",
docs=None,
group=None,
status=None,
enabled=False,
no_auth=False,
auth_schemes=None,
testConnectors=None,
documentation_doc_text=None,
configuration_docs_text=None,
)
yield [affinity_app]
def configure_mock_sync_server(mock_sync_server):
# Mock sandbox config manager to return a valid API key
mock_api_key = Mock()
mock_api_key.value = "mock_composio_api_key"
mock_sync_server.sandbox_config_manager.list_sandbox_env_vars_by_key.return_value = [mock_api_key]
# Mock user retrieval
mock_sync_server.user_manager.get_user_or_default.return_value = Mock() # Provide additional attributes if needed
# ======================================================================================================================
# Tools Routes Tests
# ======================================================================================================================
def test_delete_tool(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.delete_tool_by_id = MagicMock()
response = client.delete(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})
assert response.status_code == 200
mock_sync_server.tool_manager.delete_tool_by_id.assert_called_once_with(
tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
)
def test_get_tool(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.get_tool_by_id.return_value = add_integers_tool
response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})
assert response.status_code == 200
assert response.json()["id"] == add_integers_tool.id
assert response.json()["source_code"] == add_integers_tool.source_code
mock_sync_server.tool_manager.get_tool_by_id.assert_called_once_with(
tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
)
def test_get_tool_404(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.get_tool_by_id.return_value = None
response = client.get(f"/v1/tools/{add_integers_tool.id}", headers={"user_id": "test_user"})
assert response.status_code == 404
assert response.json()["detail"] == f"Tool with id {add_integers_tool.id} not found."
def test_list_tools(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.list_tools.return_value = [add_integers_tool]
response = client.get("/v1/tools", headers={"user_id": "test_user"})
assert response.status_code == 200
assert len(response.json()) == 1
assert response.json()[0]["id"] == add_integers_tool.id
mock_sync_server.tool_manager.list_tools.assert_called_once()
def test_create_tool(client, mock_sync_server, create_integers_tool, add_integers_tool):
mock_sync_server.tool_manager.create_tool.return_value = add_integers_tool
response = client.post("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"})
assert response.status_code == 200
assert response.json()["id"] == add_integers_tool.id
mock_sync_server.tool_manager.create_tool.assert_called_once()
def test_upsert_tool(client, mock_sync_server, create_integers_tool, add_integers_tool):
mock_sync_server.tool_manager.create_or_update_tool.return_value = add_integers_tool
response = client.put("/v1/tools", json=create_integers_tool.model_dump(), headers={"user_id": "test_user"})
assert response.status_code == 200
assert response.json()["id"] == add_integers_tool.id
mock_sync_server.tool_manager.create_or_update_tool.assert_called_once()
def test_update_tool(client, mock_sync_server, update_integers_tool, add_integers_tool):
mock_sync_server.tool_manager.update_tool_by_id.return_value = add_integers_tool
response = client.patch(f"/v1/tools/{add_integers_tool.id}", json=update_integers_tool.model_dump(), headers={"user_id": "test_user"})
assert response.status_code == 200
assert response.json()["id"] == add_integers_tool.id
mock_sync_server.tool_manager.update_tool_by_id.assert_called_once_with(
tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.user_manager.get_user_or_default.return_value
)
def test_upsert_base_tools(client, mock_sync_server, add_integers_tool):
    """POST /v1/tools/add-base-tools upserts base tools for the resolved actor."""
    mock_sync_server.tool_manager.upsert_base_tools.return_value = [add_integers_tool]
    resp = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"})
    assert resp.status_code == 200
    body = resp.json()
    assert len(body) == 1
    assert body[0]["id"] == add_integers_tool.id
    mock_sync_server.tool_manager.upsert_base_tools.assert_called_once_with(
        actor=mock_sync_server.user_manager.get_user_or_default.return_value
    )
# ======================================================================================================================
# Runs Routes Tests
# ======================================================================================================================
def test_get_run_messages(client, mock_sync_server):
    """Test getting messages for a run."""
    # Build two well-formed user messages for the job manager to return.
    now = datetime.now(timezone.utc)
    expected_messages = [
        UserMessage(
            id=f"message-{idx:08x}",
            date=now,
            content=f"Test message {idx}",
        )
        for idx in range(2)
    ]

    # Configure mock server responses
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
    mock_sync_server.job_manager.get_run_messages.return_value = expected_messages

    # Test successful retrieval with the full set of pagination/filter params.
    query = {
        "limit": 10,
        "before": "message-1234",
        "after": "message-6789",
        "role": "user",
        "order": "desc",
    }
    resp = client.get("/v1/runs/run-12345678/messages", headers={"user_id": "user-123"}, params=query)

    assert resp.status_code == 200
    body = resp.json()
    assert len(body) == 2
    assert body[0]["id"] == expected_messages[0].id
    assert body[1]["id"] == expected_messages[1].id

    # Verify mock calls: "order=desc" must be translated to ascending=False.
    mock_sync_server.user_manager.get_user_or_default.assert_called_once_with(user_id="user-123")
    mock_sync_server.job_manager.get_run_messages.assert_called_once_with(
        run_id="run-12345678",
        actor=mock_sync_server.user_manager.get_user_or_default.return_value,
        limit=10,
        before="message-1234",
        after="message-6789",
        ascending=False,
        role="user",
    )
def test_get_run_messages_not_found(client, mock_sync_server):
    """Test getting messages for a non-existent run."""
    # A NoResultFound from the job manager should surface as a 404 with the message.
    error_message = "Run 'run-nonexistent' not found"
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
    mock_sync_server.job_manager.get_run_messages.side_effect = NoResultFound(error_message)

    resp = client.get("/v1/runs/run-nonexistent/messages", headers={"user_id": "user-123"})

    assert resp.status_code == 404
    assert error_message in resp.json()["detail"]
def test_get_run_usage(client, mock_sync_server):
    """Test getting usage statistics for a run."""
    # Configure mock responses
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
    mock_sync_server.job_manager.get_job_usage.return_value = Mock(
        completion_tokens=100,
        prompt_tokens=200,
        total_tokens=300,
    )

    # Make request
    resp = client.get("/v1/runs/run-12345678/usage", headers={"user_id": "user-123"})

    # Check response
    assert resp.status_code == 200
    expected_body = {
        "completion_tokens": 100,
        "prompt_tokens": 200,
        "total_tokens": 300,
    }
    assert resp.json() == expected_body

    # Verify mock calls: the run id is passed through as the job id.
    mock_sync_server.user_manager.get_user_or_default.assert_called_once_with(user_id="user-123")
    mock_sync_server.job_manager.get_job_usage.assert_called_once_with(
        job_id="run-12345678",
        actor=mock_sync_server.user_manager.get_user_or_default.return_value,
    )
def test_get_run_usage_not_found(client, mock_sync_server):
    """Test getting usage statistics for a non-existent run."""
    # A NoResultFound from the job manager should surface as a 404 with the message.
    error_message = "Run 'run-nonexistent' not found"
    mock_sync_server.user_manager.get_user_or_default.return_value = Mock(id="user-123")
    mock_sync_server.job_manager.get_job_usage.side_effect = NoResultFound(error_message)

    resp = client.get("/v1/runs/run-nonexistent/usage", headers={"user_id": "user-123"})

    assert resp.status_code == 404
    assert error_message in resp.json()["detail"]
# ======================================================================================================================
# Tags Routes Tests
# ======================================================================================================================
def test_get_tags(client, mock_sync_server):
    """Test basic tag listing"""
    mock_sync_server.agent_manager.list_tags.return_value = ["tag1", "tag2"]
    resp = client.get("/v1/tags", headers={"user_id": "test_user"})
    assert resp.status_code == 200
    assert resp.json() == ["tag1", "tag2"]
    # Defaults: no cursor, limit 50, no search text.
    mock_sync_server.agent_manager.list_tags.assert_called_once_with(
        actor=mock_sync_server.user_manager.get_user_or_default.return_value, after=None, limit=50, query_text=None
    )
def test_get_tags_with_pagination(client, mock_sync_server):
    """Test tag listing with pagination parameters"""
    mock_sync_server.agent_manager.list_tags.return_value = ["tag3", "tag4"]
    query = {"after": "tag2", "limit": 2}
    resp = client.get("/v1/tags", params=query, headers={"user_id": "test_user"})
    assert resp.status_code == 200
    assert resp.json() == ["tag3", "tag4"]
    # Cursor and limit from the query string must be forwarded verbatim.
    mock_sync_server.agent_manager.list_tags.assert_called_once_with(
        actor=mock_sync_server.user_manager.get_user_or_default.return_value, after="tag2", limit=2, query_text=None
    )
def test_get_tags_with_search(client, mock_sync_server):
    """Test tag listing with text search"""
    mock_sync_server.agent_manager.list_tags.return_value = ["user_tag1", "user_tag2"]
    resp = client.get("/v1/tags", params={"query_text": "user"}, headers={"user_id": "test_user"})
    assert resp.status_code == 200
    assert resp.json() == ["user_tag1", "user_tag2"]
    # The search text rides along with the default pagination values.
    mock_sync_server.agent_manager.list_tags.assert_called_once_with(
        actor=mock_sync_server.user_manager.get_user_or_default.return_value, after=None, limit=50, query_text="user"
    )
# ======================================================================================================================
# Blocks Routes Tests
# ======================================================================================================================
def test_list_blocks(client, mock_sync_server):
    """
    Test the GET /v1/blocks endpoint to list blocks.
    """
    # Arrange: mock return from block_manager
    stub_block = Block(label="human", value="Hi")
    mock_sync_server.block_manager.get_blocks.return_value = [stub_block]

    # Act
    resp = client.get("/v1/blocks", headers={"user_id": "test_user"})

    # Assert
    assert resp.status_code == 200
    body = resp.json()
    assert len(body) == 1
    assert body[0]["id"] == stub_block.id
    # With no query params, every filter should default to None (templates off).
    mock_sync_server.block_manager.get_blocks.assert_called_once_with(
        actor=mock_sync_server.user_manager.get_user_or_default.return_value,
        label=None,
        is_template=False,
        template_name=None,
        identity_id=None,
        identifier_keys=None,
    )
def test_create_block(client, mock_sync_server):
    """
    Test the POST /v1/blocks endpoint to create a block.
    """
    request_block = CreateBlock(label="system", value="Some system text")
    stored_block = Block(**request_block.model_dump())
    mock_sync_server.block_manager.create_or_update_block.return_value = stored_block

    resp = client.post("/v1/blocks", json=request_block.model_dump(), headers={"user_id": "test_user"})

    assert resp.status_code == 200
    assert resp.json()["id"] == stored_block.id
    mock_sync_server.block_manager.create_or_update_block.assert_called_once()
def test_modify_block(client, mock_sync_server):
    """
    Test the PATCH /v1/blocks/{block_id} endpoint to update a block.
    """
    patch_payload = BlockUpdate(value="Updated text", description="New description")
    patched_block = Block(label="human", value="Updated text", description="New description")
    mock_sync_server.block_manager.update_block.return_value = patched_block

    resp = client.patch(f"/v1/blocks/{patched_block.id}", json=patch_payload.model_dump(), headers={"user_id": "test_user"})

    assert resp.status_code == 200
    body = resp.json()
    assert body["value"] == "Updated text"
    assert body["description"] == "New description"
    # The route must re-parse the payload into a BlockUpdate and pass it through.
    mock_sync_server.block_manager.update_block.assert_called_once_with(
        block_id=patched_block.id,
        block_update=patch_payload,
        actor=mock_sync_server.user_manager.get_user_or_default.return_value,
    )
def test_delete_block(client, mock_sync_server):
    """
    Test the DELETE /v1/blocks/{block_id} endpoint.
    """
    removed_block = Block(label="persona", value="Deleted text")
    mock_sync_server.block_manager.delete_block.return_value = removed_block

    resp = client.delete(f"/v1/blocks/{removed_block.id}", headers={"user_id": "test_user"})

    assert resp.status_code == 200
    assert resp.json()["id"] == removed_block.id
    mock_sync_server.block_manager.delete_block.assert_called_once_with(
        block_id=removed_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
    )
def test_retrieve_block(client, mock_sync_server):
    """
    Test the GET /v1/blocks/{block_id} endpoint.
    """
    known_block = Block(label="human", value="Hello")
    mock_sync_server.block_manager.get_block_by_id.return_value = known_block

    resp = client.get(f"/v1/blocks/{known_block.id}", headers={"user_id": "test_user"})

    assert resp.status_code == 200
    assert resp.json()["id"] == known_block.id
    mock_sync_server.block_manager.get_block_by_id.assert_called_once_with(
        block_id=known_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
    )
def test_retrieve_block_404(client, mock_sync_server):
    """
    Test that retrieving a non-existent block returns 404.
    """
    # The manager returning None (rather than raising) must map to a 404.
    mock_sync_server.block_manager.get_block_by_id.return_value = None
    resp = client.get("/v1/blocks/block-999", headers={"user_id": "test_user"})
    assert resp.status_code == 404
    assert "Block not found" in resp.json()["detail"]
def test_list_agents_for_block(client, mock_sync_server):
    """
    Test the GET /v1/blocks/{block_id}/agents endpoint.
    """
    mock_sync_server.block_manager.get_agents_for_block.return_value = []

    resp = client.get("/v1/blocks/block-abc/agents", headers={"user_id": "test_user"})

    assert resp.status_code == 200
    assert resp.json() == []
    mock_sync_server.block_manager.get_agents_for_block.assert_called_once_with(
        block_id="block-abc",
        actor=mock_sync_server.user_manager.get_user_or_default.return_value,
    )
# ======================================================================================================================
# Sandbox Config Routes Tests
# ======================================================================================================================
@pytest.fixture
def sample_local_sandbox_config():
    """Fixture for a sample LocalSandboxConfig object."""
    # One pinned and one unpinned requirement, to exercise both serializations.
    requirements = [
        PipRequirement(name="numpy", version="1.23.0"),
        PipRequirement(name="pandas"),
    ]
    return LocalSandboxConfig(
        sandbox_dir="/custom/path",
        force_create_venv=True,
        venv_name="custom_venv_name",
        pip_requirements=requirements,
    )
def test_create_custom_local_sandbox_config(client, mock_sync_server, sample_local_sandbox_config):
    """Test creating or updating a LocalSandboxConfig."""
    mock_sync_server.sandbox_config_manager.create_or_update_sandbox_config.return_value = SandboxConfig(
        type="local", organization_id="org-123", config=sample_local_sandbox_config.model_dump()
    )

    resp = client.post("/v1/sandbox-config/local", json=sample_local_sandbox_config.model_dump(), headers={"user_id": "test_user"})

    assert resp.status_code == 200
    body = resp.json()
    assert body["type"] == "local"
    assert body["config"]["sandbox_dir"] == "/custom/path"
    # An unpinned requirement serializes with an explicit null version.
    assert body["config"]["pip_requirements"] == [
        {"name": "numpy", "version": "1.23.0"},
        {"name": "pandas", "version": None},
    ]
    mock_sync_server.sandbox_config_manager.create_or_update_sandbox_config.assert_called_once()