Merge branch 'main' into bump-0.6.33

This commit is contained in:
Sarah Wooders
2025-02-25 17:45:04 -08:00
31 changed files with 831 additions and 259 deletions

View File

@@ -245,10 +245,13 @@ class Agent(BaseAgent):
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
)
else:
# Parse the source code to extract function annotations
annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
# Coerce the function arguments to the correct types based on the annotations
function_args = coerce_dict_args_by_annotations(function_args, annotations)
try:
# Parse the source code to extract function annotations
annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
# Coerce the function arguments to the correct types based on the annotations
function_args = coerce_dict_args_by_annotations(function_args, annotations)
except ValueError as e:
self.logger.debug(f"Error coercing function arguments: {e}")
# execute tool in a sandbox
# TODO: allow agent_state to specify which sandbox to execute tools in
@@ -257,7 +260,9 @@ class Agent(BaseAgent):
agent_state_copy.tools = []
agent_state_copy.tool_rules = []
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(agent_state=agent_state_copy)
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run(
agent_state=agent_state_copy
)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
if updated_agent_state is not None:

View File

@@ -52,6 +52,8 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
# Multi agent tools
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"]
# Set of all built-in Letta tools
LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS)
# The name of the tool used to send message to the user
# May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...)
@@ -59,6 +61,11 @@ MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_t
DEFAULT_MESSAGE_TOOL = "send_message"
DEFAULT_MESSAGE_TOOL_KWARG = "message"
PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg"
REQUEST_HEARTBEAT_PARAM = "request_heartbeat"
# Structured output models
STRUCTURED_OUTPUT_MODELS = {"gpt-4o", "gpt-4o-mini"}

View File

@@ -6,10 +6,11 @@ from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.settings import tool_settings
def get_composio_api_key(actor: User, logger: Logger) -> Optional[str]:
def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Optional[str]:
api_keys = SandboxConfigManager().list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor)
if not api_keys:
logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
if logger:
logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
if tool_settings.composio_api_key:
return tool_settings.composio_api_key
else:

View File

@@ -0,0 +1,171 @@
from collections import OrderedDict
from typing import Any, Dict, Optional
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, PRE_EXECUTION_MESSAGE_ARG
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
from letta.helpers.composio_helpers import get_composio_api_key
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxRunResult
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.utils import get_friendly_error_msg
def enable_strict_mode(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
"""Enables strict mode for a tool schema by setting 'strict' to True and
disallowing additional properties in the parameters.
Args:
tool_schema (Dict[str, Any]): The original tool schema.
Returns:
Dict[str, Any]: A new tool schema with strict mode enabled.
"""
schema = tool_schema.copy()
# Enable strict mode
schema["strict"] = True
# Ensure parameters is a valid dictionary
parameters = schema.get("parameters", {})
if isinstance(parameters, dict) and parameters.get("type") == "object":
# Set additionalProperties to False
parameters["additionalProperties"] = False
schema["parameters"] = parameters
return schema
def add_pre_execution_message(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
"""Adds a `pre_execution_message` parameter to a tool schema to prompt a natural, human-like message before executing the tool.
Args:
tool_schema (Dict[str, Any]): The original tool schema.
Returns:
Dict[str, Any]: A new tool schema with the `pre_execution_message` field added at the beginning.
"""
schema = tool_schema.copy()
parameters = schema.get("parameters", {})
if not isinstance(parameters, dict) or parameters.get("type") != "object":
return schema # Do not modify if schema is not valid
properties = parameters.get("properties", {})
required = parameters.get("required", [])
# Define the new `pre_execution_message` field with a refined description
pre_execution_message_field = {
"type": "string",
"description": (
"A concise message to be uttered before executing this tool. "
"This should sound natural, as if a person is casually announcing their next action."
"You MUST also include punctuation at the end of this message."
),
}
# Ensure the pre-execution message is the first field in properties
updated_properties = OrderedDict()
updated_properties[PRE_EXECUTION_MESSAGE_ARG] = pre_execution_message_field
updated_properties.update(properties) # Retain all existing properties
# Ensure pre-execution message is the first required field
if PRE_EXECUTION_MESSAGE_ARG not in required:
required = [PRE_EXECUTION_MESSAGE_ARG] + required
# Update the schema with ordered properties and required list
schema["parameters"] = {
**parameters,
"properties": dict(updated_properties), # Convert OrderedDict back to dict
"required": required,
}
return schema
def remove_request_heartbeat(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
"""Removes the `request_heartbeat` parameter from a tool schema if it exists.
Args:
tool_schema (Dict[str, Any]): The original tool schema.
Returns:
Dict[str, Any]: A new tool schema without `request_heartbeat`.
"""
schema = tool_schema.copy()
parameters = schema.get("parameters", {})
if isinstance(parameters, dict):
properties = parameters.get("properties", {})
required = parameters.get("required", [])
# Remove the `request_heartbeat` property if it exists
if "request_heartbeat" in properties:
properties.pop("request_heartbeat")
# Remove `request_heartbeat` from required fields if present
if "request_heartbeat" in required:
required = [r for r in required if r != "request_heartbeat"]
# Update parameters with modified properties and required list
schema["parameters"] = {**parameters, "properties": properties, "required": required}
return schema
# TODO: Deprecate the `execute_external_tool` function on the agent body
def execute_external_tool(
agent_state: AgentState,
function_name: str,
function_args: dict,
target_letta_tool: Tool,
actor: User,
allow_agent_state_modifications: bool = False,
) -> tuple[Any, Optional[SandboxRunResult]]:
# TODO: need to have an AgentState object that actually has full access to the block data
# this is because the sandbox tools need to be able to access block.value to edit this data
try:
if target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO:
action_name = generate_composio_action_from_func_name(target_letta_tool.name)
# Get entity ID from the agent_state
entity_id = None
for env_var in agent_state.tool_exec_environment_variables:
if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY:
entity_id = env_var.value
# Get composio_api_key
composio_api_key = get_composio_api_key(actor=actor)
function_response = execute_composio_action(
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
)
return function_response, None
elif target_letta_tool.tool_type == ToolType.CUSTOM:
# Parse the source code to extract function annotations
annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
# Coerce the function arguments to the correct types based on the annotations
function_args = coerce_dict_args_by_annotations(function_args, annotations)
# execute tool in a sandbox
# TODO: allow agent_state to specify which sandbox to execute tools in
# TODO: This is only temporary, can remove after we publish a pip package with this object
if allow_agent_state_modifications:
agent_state_copy = agent_state.__deepcopy__()
agent_state_copy.tools = []
agent_state_copy.tool_rules = []
else:
agent_state_copy = None
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
# TODO: Bring this back
# if allow_agent_state_modifications and updated_agent_state is not None:
# self.update_memory_if_changed(updated_agent_state.memory)
return function_response, sandbox_run_result
except Exception as e:
# Need to catch error here, or else trunction wont happen
# TODO: modify to function execution error
function_response = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
return function_response, None

View File

@@ -47,14 +47,39 @@ BASE_URL = "https://api.anthropic.com/v1"
# https://docs.anthropic.com/claude/docs/models-overview
# Sadly hardcoded
MODEL_LIST = [
## Opus
{
"name": "claude-3-opus-20240229",
"context_window": 200000,
},
## Sonnet
# 3.0
{
"name": "claude-3-sonnet-20240229",
"context_window": 200000,
},
# 3.5
{
"name": "claude-3-5-sonnet-20240620",
"context_window": 200000,
},
# 3.5 new
{
"name": "claude-3-5-sonnet-20241022",
"context_window": 200000,
},
# 3.7
{
"name": "claude-3-7-sonnet-20250219",
"context_window": 200000,
},
## Haiku
# 3.0
{
"name": "claude-3-haiku-20240307",
"context_window": 200000,
},
# 3.5
{
"name": "claude-3-5-haiku-20241022",
"context_window": 200000,
@@ -75,7 +100,18 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
"""https://docs.anthropic.com/claude/docs/models-overview"""
# NOTE: currently there is no GET /models, so we need to hardcode
return MODEL_LIST
# return MODEL_LIST
anthropic_override_key = ProviderManager().get_anthropic_override_key()
if anthropic_override_key:
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
elif model_settings.anthropic_api_key:
anthropic_client = anthropic.Anthropic()
models = anthropic_client.models.list()
models_json = models.model_dump()
assert "data" in models_json, f"Anthropic model query response missing 'data' field: {models_json}"
return models_json["data"]
def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:

View File

@@ -42,6 +42,6 @@ class Source(SqlalchemyBase, OrganizationMixin):
secondary="sources_agents",
back_populates="sources",
lazy="selectin",
cascade="all, delete", # Ensures rows in sources_agents are deleted when the source is deleted
passive_deletes=True, # Allows the database to handle deletion of orphaned rows
cascade="save-update", # Only propagate save and update operations
passive_deletes=True, # Let the database handle deletions
)

View File

@@ -99,7 +99,7 @@ class ChatCompletionRequest(BaseModel):
"""https://platform.openai.com/docs/api-reference/chat/create"""
model: str
messages: List[ChatMessage]
messages: List[Union[ChatMessage, Dict]]
frequency_penalty: Optional[float] = 0
logit_bias: Optional[Dict[str, int]] = None
logprobs: Optional[bool] = False

View File

@@ -410,28 +410,67 @@ class AnthropicProvider(Provider):
base_url: str = "https://api.anthropic.com/v1"
def list_llm_models(self) -> List[LLMConfig]:
from letta.llm_api.anthropic import anthropic_get_model_list
from letta.llm_api.anthropic import MODEL_LIST, anthropic_get_model_list
models = anthropic_get_model_list(self.base_url, api_key=self.api_key)
"""
Example response:
{
"data": [
{
"type": "model",
"id": "claude-3-5-sonnet-20241022",
"display_name": "Claude 3.5 Sonnet (New)",
"created_at": "2024-10-22T00:00:00Z"
}
],
"has_more": true,
"first_id": "<string>",
"last_id": "<string>"
}
"""
configs = []
for model in models:
if model["type"] != "model":
continue
if "id" not in model:
continue
# Don't support 2.0 and 2.1
if model["id"].startswith("claude-2"):
continue
# Anthropic doesn't return the context window in their API
if "context_window" not in model:
# Remap list to name: context_window
model_library = {m["name"]: m["context_window"] for m in MODEL_LIST}
# Attempt to look it up in a hardcoded list
if model["id"] in model_library:
model["context_window"] = model_library[model["id"]]
else:
# On fallback, we can set 200k (generally safe), but we should warn the user
warnings.warn(f"Couldn't find context window size for model {model['id']}, defaulting to 200,000")
model["context_window"] = 200000
# We set this to false by default, because Anthropic can
# natively support <thinking> tags inside of content fields
# However, putting COT inside of tool calls can make it more
# reliable for tool calling (no chance of a non-tool call step)
# Since tool_choice_type 'any' doesn't work with in-content COT
# NOTE For Haiku, it can be flaky if we don't enable this by default
inner_thoughts_in_kwargs = True if "haiku" in model["name"] else False
inner_thoughts_in_kwargs = True if "haiku" in model["id"] else False
configs.append(
LLMConfig(
model=model["name"],
model=model["id"],
model_endpoint_type="anthropic",
model_endpoint=self.base_url,
context_window=model["context_window"],
handle=self.get_handle(model["name"]),
handle=self.get_handle(model["id"]),
put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
)
)

View File

@@ -9,7 +9,7 @@ from letta.constants import (
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
)
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
from letta.functions.helpers import generate_composio_action_from_func_name, generate_composio_tool_wrapper, generate_langchain_tool_wrapper
from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper
from letta.functions.schema_generator import generate_schema_from_args_schema_v2, generate_tool_schema_for_composio
from letta.log import get_logger
from letta.orm.enums import ToolType
@@ -77,18 +77,6 @@ class Tool(BaseTool):
elif self.tool_type in {ToolType.LETTA_MULTI_AGENT_CORE}:
# If it's letta multi-agent tool, we also generate the json_schema on the fly here
self.json_schema = get_json_schema_from_module(module_name=LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name=self.name)
elif self.tool_type == ToolType.EXTERNAL_COMPOSIO:
# If it is a composio tool, we generate both the source code and json schema on the fly here
# TODO: Deriving the composio action name is brittle, need to think long term about how to improve this
try:
composio_action = generate_composio_action_from_func_name(self.name)
tool_create = ToolCreate.from_composio(composio_action)
self.source_code = tool_create.source_code
self.json_schema = tool_create.json_schema
self.description = tool_create.description
self.tags = tool_create.tags
except Exception as e:
logger.error(f"Encountered exception while attempting to refresh source_code and json_schema for composio_tool: {e}")
# At this point, we need to validate that at least json_schema is populated
if not self.json_schema:

View File

@@ -1,19 +1,39 @@
import asyncio
import json
import uuid
from typing import TYPE_CHECKING, List, Optional, Union
import httpx
import openai
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from fastapi.responses import StreamingResponse
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
from openai.types.chat.completion_create_params import CompletionCreateParams
from starlette.concurrency import run_in_threadpool
from letta.agent import Agent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_TOOL_SET, NON_USER_MSG_PREFIX, PRE_EXECUTION_MESSAGE_ARG
from letta.helpers.tool_execution_helper import (
add_pre_execution_message,
enable_strict_mode,
execute_external_tool,
remove_request_heartbeat,
)
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_request import (
AssistantMessage,
ChatCompletionRequest,
Tool,
ToolCall,
ToolCallFunction,
ToolMessage,
UserMessage,
)
from letta.schemas.user import User
from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
# TODO this belongs in a controller!
from letta.server.rest_api.utils import (
@@ -52,20 +72,53 @@ async def create_fast_chat_completions(
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
# TODO: This is necessary, we need to factor out CompletionCreateParams due to weird behavior
actor = server.user_manager.get_user_or_default(user_id=user_id)
agent_id = str(completion_request.get("user", None))
if agent_id is None:
error_msg = "Must pass agent_id in the 'user' field"
logger.error(error_msg)
raise HTTPException(status_code=400, detail=error_msg)
model = completion_request.get("model")
raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
actor = server.user_manager.get_user_or_default(user_id=user_id)
agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
if agent_state.llm_config.model_endpoint_type != "openai":
raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.")
# Convert Letta messages to OpenAI messages
in_context_messages = server.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor)
openai_messages = convert_letta_messages_to_openai(in_context_messages)
# Also parse user input from completion_request and append
input_message = get_messages_from_completion_request(completion_request)[-1]
openai_messages.append(input_message)
# Tools we allow this agent to call
tools = [t for t in agent_state.tools if t.name not in LETTA_TOOL_SET and t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}]
# Initial request
openai_request = ChatCompletionRequest(
model=agent_state.llm_config.model,
messages=openai_messages,
# TODO: This nested thing here is so ugly, need to refactor
tools=(
[
Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema))))
for t in tools
]
if tools
else None
),
tool_choice="auto",
user=user_id,
max_completion_tokens=agent_state.llm_config.max_tokens,
temperature=agent_state.llm_config.temperature,
stream=True,
)
# Create the OpenAI async client
client = openai.AsyncClient(
api_key=model_settings.openai_api_key,
max_retries=0,
http_client=httpx.AsyncClient(
timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
timeout=httpx.Timeout(connect=15.0, read=30.0, write=15.0, pool=15.0),
follow_redirects=True,
limits=httpx.Limits(
max_connections=50,
@@ -75,38 +128,175 @@ async def create_fast_chat_completions(
),
)
# Magic message manipulating
input_message = get_messages_from_completion_request(completion_request)[-1]
completion_request.pop("messages")
# Get in context messages
in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor)
openai_dict_in_context_messages = convert_letta_messages_to_openai(in_context_messages)
openai_dict_in_context_messages.append(input_message)
# The messages we want to persist to the Letta agent
user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor)
message_db_queue = [user_message]
async def event_stream():
# TODO: Factor this out into separate interface
response_accumulator = []
"""
A function-calling loop:
- We stream partial tokens.
- If we detect a tool call (finish_reason="tool_calls"), we parse it,
add two messages to the conversation:
(a) assistant message with tool_calls referencing the same ID
(b) a tool message referencing that ID, containing the tool result.
- Re-invoke the OpenAI request with updated conversation, streaming again.
- End when finish_reason="stop" or no more tool calls.
"""
stream = await client.chat.completions.create(**completion_request, messages=openai_dict_in_context_messages)
# We'll keep updating this conversation in a loop
conversation = openai_messages[:]
async with stream:
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.content:
# TODO: This does not support tool calling right now
response_accumulator.append(chunk.choices[0].delta.content)
yield f"data: {chunk.model_dump_json()}\n\n"
while True:
# Make the streaming request to OpenAI
stream = await client.chat.completions.create(**openai_request.model_dump(exclude_unset=True))
# Construct messages
user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor)
assistant_message = create_assistant_message_from_openai_response(
response_text="".join(response_accumulator), agent_id=agent_id, model=str(model), actor=actor
)
content_buffer = []
tool_call_name = None
tool_call_args_str = ""
tool_call_id = None
tool_call_happened = False
finish_reason_stop = False
optimistic_json_parser = OptimisticJSONParser(strict=True)
current_parsed_json_result = {}
async with stream:
async for chunk in stream:
choice = chunk.choices[0]
delta = choice.delta
finish_reason = choice.finish_reason # "tool_calls", "stop", or None
if delta.content:
content_buffer.append(delta.content)
yield f"data: {chunk.model_dump_json()}\n\n"
# CASE B: Partial tool call info
if delta.tool_calls:
# Typically there's only one in delta.tool_calls
tc = delta.tool_calls[0]
if tc.function.name:
tool_call_name = tc.function.name
if tc.function.arguments:
tool_call_args_str += tc.function.arguments
# See if we can stream out the pre-execution message
parsed_args = optimistic_json_parser.parse(tool_call_args_str)
if parsed_args.get(
PRE_EXECUTION_MESSAGE_ARG
) and current_parsed_json_result.get( # Ensure key exists and is not None/empty
PRE_EXECUTION_MESSAGE_ARG
) != parsed_args.get(
PRE_EXECUTION_MESSAGE_ARG
):
# Only stream if there's something new to stream
# We do this way to avoid hanging JSON at the end of the stream, e.g. '}'
if parsed_args != current_parsed_json_result:
current_parsed_json_result = parsed_args
synthetic_chunk = ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created,
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(content=tc.function.arguments, role="assistant"),
finish_reason=None,
)
],
)
yield f"data: {synthetic_chunk.model_dump_json()}\n\n"
# We might generate a unique ID for the tool call
if tc.id:
tool_call_id = tc.id
# Check finish_reason
if finish_reason == "tool_calls":
tool_call_happened = True
break
elif finish_reason == "stop":
finish_reason_stop = True
break
if content_buffer:
# We treat that partial text as an assistant message
content = "".join(content_buffer)
conversation.append({"role": "assistant", "content": content})
# Create an assistant message here to persist later
assistant_message = create_assistant_message_from_openai_response(
response_text=content, agent_id=agent_id, model=agent_state.llm_config.model, actor=actor
)
message_db_queue.append(assistant_message)
if tool_call_happened:
# Parse the tool call arguments
try:
tool_args = json.loads(tool_call_args_str)
except json.JSONDecodeError:
tool_args = {}
if not tool_call_id:
# If no tool_call_id given by the model, generate one
tool_call_id = f"call_{uuid.uuid4().hex[:8]}"
# 1) Insert the "assistant" message with the tool_calls field
# referencing the same tool_call_id
assistant_tool_call_msg = AssistantMessage(
content=None,
tool_calls=[ToolCall(id=tool_call_id, function=ToolCallFunction(name=tool_call_name, arguments=tool_call_args_str))],
)
conversation.append(assistant_tool_call_msg.model_dump())
# 2) Execute the tool
target_tool = next((x for x in tools if x.name == tool_call_name), None)
if not target_tool:
# Tool not found, handle error
yield f"data: {json.dumps({'error': 'Tool not found', 'tool': tool_call_name})}\n\n"
break
try:
tool_result, _ = execute_external_tool(
agent_state=agent_state,
function_name=tool_call_name,
function_args=tool_args,
target_letta_tool=target_tool,
actor=actor,
allow_agent_state_modifications=False,
)
except Exception as e:
tool_result = f"Failed to call tool. Error: {e}"
# 3) Insert the "tool" message referencing the same tool_call_id
tool_message = ToolMessage(content=json.dumps({"result": tool_result}), tool_call_id=tool_call_id)
conversation.append(tool_message.model_dump())
# 4) Add a user message prompting the tool call result summarization
heartbeat_user_message = UserMessage(
content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user.",
)
conversation.append(heartbeat_user_message.model_dump())
# Now, re-invoke OpenAI with the updated conversation
openai_request.messages = conversation
continue # Start the while loop again
if finish_reason_stop:
# Model is done, no more calls
break
# If we reach here, no tool call, no "stop", but we've ended streaming
# Possibly a model error or some other finish reason. We'll just end.
break
# Persist both in one synchronous DB call, done in a threadpool
await run_in_threadpool(
server.agent_manager.append_to_in_context_messages,
[user_message, assistant_message],
message_db_queue,
agent_id=agent_id,
actor=actor,
)

View File

@@ -44,7 +44,7 @@ def list_agents(
description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed in tags.",
),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
before: Optional[str] = Query(None, description="Cursor for pagination"),
after: Optional[str] = Query(None, description="Cursor for pagination"),
limit: Optional[int] = Query(None, description="Limit for pagination"),
@@ -52,13 +52,14 @@ def list_agents(
project_id: Optional[str] = Query(None, description="Search agents by project id"),
template_id: Optional[str] = Query(None, description="Search agents by template id"),
base_template_id: Optional[str] = Query(None, description="Search agents by base template id"),
identifier_id: Optional[str] = Query(None, description="Search agents by identifier id"),
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
):
"""
List all agents associated with a given user.
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
# Use dictionary comprehension to build kwargs dynamically
kwargs = {
@@ -68,6 +69,7 @@ def list_agents(
"project_id": project_id,
"template_id": template_id,
"base_template_id": base_template_id,
"identifier_id": identifier_id,
}.items()
if value is not None
}
@@ -91,12 +93,12 @@ def list_agents(
def retrieve_agent_context_window(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the context window of a specific agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.get_agent_context_window(agent_id=agent_id, actor=actor)
@@ -107,21 +109,21 @@ class CreateAgentRequest(CreateAgent):
"""
# Override the user_id field to exclude it from the request body validation
user_id: Optional[str] = Field(None, exclude=True)
actor_id: Optional[str] = Field(None, exclude=True)
@router.post("/", response_model=AgentState, operation_id="create_agent")
def create_agent(
agent: CreateAgentRequest = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
):
"""
Create a new agent with the specified configuration.
"""
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.create_agent(agent, actor=actor)
except Exception as e:
traceback.print_exc()
@@ -133,10 +135,10 @@ def modify_agent(
agent_id: str,
update_agent: UpdateAgent = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Update an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.update_agent(agent_id=agent_id, agent_update=update_agent, actor=actor)
@@ -144,10 +146,10 @@ def modify_agent(
def list_agent_tools(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Get tools from an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.list_attached_tools(agent_id=agent_id, actor=actor)
@@ -156,12 +158,12 @@ def attach_tool(
agent_id: str,
tool_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Attach a tool to an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
@@ -170,12 +172,12 @@ def detach_tool(
agent_id: str,
tool_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Detach a tool from an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
@@ -184,12 +186,12 @@ def attach_source(
agent_id: str,
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Attach a source to an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor)
@@ -198,12 +200,12 @@ def detach_source(
agent_id: str,
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Detach a source from an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
@@ -211,12 +213,12 @@ def detach_source(
def retrieve_agent(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get the state of the agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@@ -228,12 +230,12 @@ def retrieve_agent(
def delete_agent(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
server.agent_manager.delete_agent(agent_id=agent_id, actor=actor)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent id={agent_id} successfully deleted"})
@@ -245,12 +247,12 @@ def delete_agent(
def list_agent_sources(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get the sources associated with an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.list_attached_sources(agent_id=agent_id, actor=actor)
@@ -259,13 +261,13 @@ def list_agent_sources(
def retrieve_agent_memory(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the memory state of a specific agent.
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.get_agent_memory(agent_id=agent_id, actor=actor)
@@ -275,12 +277,12 @@ def retrieve_core_memory_block(
agent_id: str,
block_label: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve a memory block from an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
return server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
@@ -292,12 +294,12 @@ def retrieve_core_memory_block(
def list_core_memory_blocks(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the memory blocks of a specific agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
agent = server.agent_manager.get_agent_by_id(agent_id, actor=actor)
return agent.memory.blocks
@@ -311,12 +313,12 @@ def modify_core_memory_block(
block_label: str,
block_update: BlockUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Updates a memory block of an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
block = server.block_manager.update_block(block.id, block_update=block_update, actor=actor)
@@ -332,12 +334,12 @@ def attach_core_memory_block(
agent_id: str,
block_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Attach a block to an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor)
@@ -346,12 +348,12 @@ def detach_core_memory_block(
agent_id: str,
block_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Detach a block from an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor)
@@ -362,12 +364,12 @@ def list_archival_memory(
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
limit: Optional[int] = Query(None, description="How many results to include in the response."),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the memories in an agent's archival memory store (paginated query).
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.get_agent_archival(
user_id=actor.id,
@@ -383,12 +385,12 @@ def create_archival_memory(
agent_id: str,
request: CreateArchivalMemory = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Insert a memory into an agent's archival memory store.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.insert_archival_memory(agent_id=agent_id, memory_contents=request.text, actor=actor)
@@ -401,12 +403,12 @@ def delete_archival_memory(
memory_id: str,
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a memory from an agent's archival memory store.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
server.delete_archival_memory(memory_id=memory_id, actor=actor)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
@@ -427,12 +429,12 @@ def list_messages(
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve message history for an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.get_agent_recall(
user_id=actor.id,
@@ -454,13 +456,13 @@ def modify_message(
message_id: str,
request: MessageUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the details of a message associated with an agent.
"""
# TODO: Get rid of agent_id here, it's not really relevant
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
@@ -474,13 +476,13 @@ async def send_message(
agent_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Process a user message and return the agent's response.
This endpoint accepts a message from a user and processes it through the agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
result = await server.send_message_to_agent(
agent_id=agent_id,
actor=actor,
@@ -513,14 +515,14 @@ async def send_message_streaming(
agent_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaStreamingRequest = Body(...),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Process a user message and return the agent's response.
This endpoint accepts a message from a user and processes it through the agent.
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
result = await server.send_message_to_agent(
agent_id=agent_id,
actor=actor,
@@ -590,13 +592,13 @@ async def send_message_async(
background_tasks: BackgroundTasks,
server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Asynchronously process a user message and return a run object.
The actual processing happens in the background, and the status can be checked using the run ID.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
# Create a new job
run = Run(
@@ -635,8 +637,8 @@ def reset_messages(
agent_id: str,
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Resets the messages for an agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages)

View File

@@ -21,9 +21,9 @@ def list_blocks(
templates_only: bool = Query(True, description="Whether to include only templates"),
name: Optional[str] = Query(None, description="Name of the block"),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.block_manager.get_blocks(actor=actor, label=label, is_template=templates_only, template_name=name)
@@ -31,9 +31,9 @@ def list_blocks(
def create_block(
create_block: CreateBlock = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
block = Block(**create_block.model_dump())
return server.block_manager.create_or_update_block(actor=actor, block=block)
@@ -43,9 +43,9 @@ def modify_block(
block_id: str,
block_update: BlockUpdate = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.block_manager.update_block(block_id=block_id, block_update=block_update, actor=actor)
@@ -53,9 +53,9 @@ def modify_block(
def delete_block(
block_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.block_manager.delete_block(block_id=block_id, actor=actor)
@@ -63,10 +63,10 @@ def delete_block(
def retrieve_block(
block_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
print("call get block", block_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor)
if block is None:
@@ -80,13 +80,13 @@ def retrieve_block(
def list_agents_for_block(
block_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Retrieves all agents associated with the specified block.
Raises a 404 if the block does not exist.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
agents = server.block_manager.get_agents_for_block(block_id=block_id, actor=actor)
return agents

View File

@@ -22,13 +22,13 @@ def list_identities(
after: Optional[str] = Query(None),
limit: Optional[int] = Query(50),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a list of all identities in the database
"""
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
identities = server.identity_manager.list_identities(
name=name,
@@ -51,10 +51,10 @@ def list_identities(
def retrieve_identity(
identity_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.identity_manager.get_identity(identity_id=identity_id, actor=actor)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -64,11 +64,11 @@ def retrieve_identity(
def create_identity(
identity: IdentityCreate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
):
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.identity_manager.create_identity(identity=identity, actor=actor)
except HTTPException:
raise
@@ -80,11 +80,11 @@ def create_identity(
def upsert_identity(
identity: IdentityCreate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
):
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.identity_manager.upsert_identity(identity=identity, actor=actor)
except HTTPException:
raise
@@ -97,10 +97,10 @@ def modify_identity(
identity_id: str,
identity: IdentityUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.identity_manager.update_identity(identity_id=identity_id, identity=identity, actor=actor)
except HTTPException:
raise
@@ -112,10 +112,10 @@ def modify_identity(
def delete_identity(
identity_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete an identity by its identifier key
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
server.identity_manager.delete_identity(identity_id=identity_id, actor=actor)

View File

@@ -15,12 +15,12 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
def list_jobs(
server: "SyncServer" = Depends(get_letta_server),
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all jobs.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
# TODO: add filtering by status
jobs = server.job_manager.list_jobs(actor=actor)
@@ -35,12 +35,12 @@ def list_jobs(
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
def list_active_jobs(
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all active jobs.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running])
@@ -48,13 +48,13 @@ def list_active_jobs(
@router.get("/{job_id}", response_model=Job, operation_id="retrieve_job")
def retrieve_job(
job_id: str,
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Get the status of a job.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
return server.job_manager.get_job_by_id(job_id=job_id, actor=actor)
@@ -65,13 +65,13 @@ def retrieve_job(
@router.delete("/{job_id}", response_model=Job, operation_id="delete_job")
def delete_job(
job_id: str,
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Delete a job by its job_id.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
job = server.job_manager.delete_job_by_id(job_id=job_id, actor=actor)

View File

@@ -15,13 +15,15 @@ router = APIRouter(prefix="/providers", tags=["providers"])
def list_providers(
after: Optional[str] = Query(None),
limit: Optional[int] = Query(50),
actor_id: Optional[str] = Header(None, alias="user_id"),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Get a list of all custom providers in the database
"""
try:
providers = server.provider_manager.list_providers(after=after, limit=limit)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor)
except HTTPException:
raise
except Exception as e:
@@ -32,13 +34,13 @@ def list_providers(
@router.post("/", tags=["providers"], response_model=Provider, operation_id="create_provider")
def create_provider(
request: ProviderCreate = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Create a new custom provider
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
provider = Provider(**request.model_dump())
provider = server.provider_manager.create_provider(provider, actor=actor)
@@ -48,25 +50,29 @@ def create_provider(
@router.patch("/", tags=["providers"], response_model=Provider, operation_id="modify_provider")
def modify_provider(
request: ProviderUpdate = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
server: "SyncServer" = Depends(get_letta_server),
):
"""
Update an existing custom provider
"""
provider = server.provider_manager.update_provider(request)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
provider = server.provider_manager.update_provider(request, actor=actor)
return provider
@router.delete("/", tags=["providers"], response_model=None, operation_id="delete_provider")
def delete_provider(
provider_id: str = Query(..., description="The provider_id key to be deleted."),
actor_id: Optional[str] = Header(None, alias="user_id"),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Delete an existing custom provider
"""
try:
server.provider_manager.delete_provider_by_id(provider_id=provider_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
server.provider_manager.delete_provider_by_id(provider_id=provider_id, actor=actor)
except HTTPException:
raise
except Exception as e:

View File

@@ -18,12 +18,12 @@ router = APIRouter(prefix="/runs", tags=["runs"])
@router.get("/", response_model=List[Run], operation_id="list_runs")
def list_runs(
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all runs.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return [Run.from_job(job) for job in server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)]
@@ -31,12 +31,12 @@ def list_runs(
@router.get("/active", response_model=List[Run], operation_id="list_active_runs")
def list_active_runs(
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all active runs.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
active_runs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.RUN)
@@ -46,13 +46,13 @@ def list_active_runs(
@router.get("/{run_id}", response_model=Run, operation_id="retrieve_run")
def retrieve_run(
run_id: str,
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Get the status of a run.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
@@ -74,7 +74,7 @@ RunMessagesResponse = Annotated[
async def list_run_messages(
run_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
before: Optional[str] = Query(None, description="Cursor for pagination"),
after: Optional[str] = Query(None, description="Cursor for pagination"),
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
@@ -102,7 +102,7 @@ async def list_run_messages(
if order not in ["asc", "desc"]:
raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
messages = server.job_manager.get_run_messages(
@@ -122,13 +122,13 @@ async def list_run_messages(
@router.get("/{run_id}/usage", response_model=UsageStatistics, operation_id="retrieve_run_usage")
def retrieve_run_usage(
run_id: str,
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Get usage statistics for a run.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
usage = server.job_manager.get_job_usage(job_id=run_id, actor=actor)
@@ -140,13 +140,13 @@ def retrieve_run_usage(
@router.delete("/{run_id}", response_model=Run, operation_id="delete_run")
def delete_run(
run_id: str,
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Delete a run by its run_id.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
job = server.job_manager.delete_job_by_id(job_id=run_id, actor=actor)

View File

@@ -25,9 +25,9 @@ logger = get_logger(__name__)
def create_sandbox_config(
config_create: SandboxConfigCreate,
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.create_or_update_sandbox_config(config_create, actor)
@@ -35,18 +35,18 @@ def create_sandbox_config(
@router.post("/e2b/default", response_model=PydanticSandboxConfig)
def create_default_e2b_sandbox_config(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=actor)
@router.post("/local/default", response_model=PydanticSandboxConfig)
def create_default_local_sandbox_config(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor)
@@ -54,7 +54,7 @@ def create_default_local_sandbox_config(
def create_custom_local_sandbox_config(
local_sandbox_config: LocalSandboxConfig,
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
"""
Create or update a custom LocalSandboxConfig, including pip_requirements.
@@ -67,7 +67,7 @@ def create_custom_local_sandbox_config(
)
# Retrieve the user (actor)
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
# Wrap the LocalSandboxConfig into a SandboxConfigCreate
sandbox_config_create = SandboxConfigCreate(config=local_sandbox_config)
@@ -83,9 +83,9 @@ def update_sandbox_config(
sandbox_config_id: str,
config_update: SandboxConfigUpdate,
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.update_sandbox_config(sandbox_config_id, config_update, actor)
@@ -93,9 +93,9 @@ def update_sandbox_config(
def delete_sandbox_config(
sandbox_config_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id, actor)
@@ -105,22 +105,22 @@ def list_sandbox_configs(
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
sandbox_type: Optional[SandboxType] = Query(None, description="Filter for this specific sandbox type"),
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, after=after, sandbox_type=sandbox_type)
@router.post("/local/recreate-venv", response_model=PydanticSandboxConfig)
def force_recreate_local_sandbox_venv(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
"""
Forcefully recreate the virtual environment for the local sandbox.
Deletes and recreates the venv, then reinstalls required dependencies.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
# Retrieve the local sandbox config
sbx_config = server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor)
@@ -162,9 +162,9 @@ def create_sandbox_env_var(
sandbox_config_id: str,
env_var_create: SandboxEnvironmentVariableCreate,
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.create_sandbox_env_var(env_var_create, sandbox_config_id, actor)
@@ -173,9 +173,9 @@ def update_sandbox_env_var(
env_var_id: str,
env_var_update: SandboxEnvironmentVariableUpdate,
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.update_sandbox_env_var(env_var_id, env_var_update, actor)
@@ -183,9 +183,9 @@ def update_sandbox_env_var(
def delete_sandbox_env_var(
env_var_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
server.sandbox_config_manager.delete_sandbox_env_var(env_var_id, actor)
@@ -195,7 +195,7 @@ def list_sandbox_env_vars(
limit: int = Query(1000, description="Number of results to return"),
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id, actor, limit=limit, after=after)

View File

@@ -23,12 +23,12 @@ router = APIRouter(prefix="/sources", tags=["sources"])
def retrieve_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get all sources
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
if not source:
@@ -40,12 +40,12 @@ def retrieve_source(
def get_source_id_by_name(
source_name: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a source by name
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
if not source:
@@ -56,12 +56,12 @@ def get_source_id_by_name(
@router.get("/", response_model=List[Source], operation_id="list_sources")
def list_sources(
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all data sources created by a user.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.list_all_sources(actor=actor)
@@ -70,12 +70,12 @@ def list_sources(
def create_source(
source_create: SourceCreate,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Create a new data source.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = Source(**source_create.model_dump())
return server.source_manager.create_source(source=source, actor=actor)
@@ -86,12 +86,12 @@ def modify_source(
source_id: str,
source: SourceUpdate,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the name or documentation of an existing data source.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor):
raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.")
return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)
@@ -101,12 +101,12 @@ def modify_source(
def delete_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a data source.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
server.delete_source(source_id=source_id, actor=actor)
@@ -117,12 +117,12 @@ def upload_file_to_source(
source_id: str,
background_tasks: BackgroundTasks,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Upload a file to a data source.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
assert source is not None, f"Source with id={source_id} not found."
@@ -151,12 +151,12 @@ def upload_file_to_source(
def list_source_passages(
source_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all passages associated with a data source.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id)
return passages
@@ -167,12 +167,12 @@ def list_source_files(
limit: int = Query(1000, description="Number of files to return"),
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List paginated files associated with a data source.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=actor)
@@ -183,12 +183,12 @@ def delete_file_from_source(
source_id: str,
file_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a data source.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
if deleted_file is None:

View File

@@ -21,13 +21,13 @@ def list_steps(
end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
List steps with optional pagination and date filters.
Dates should be provided in ISO 8601 format (e.g. 2025-01-29T15:01:19-08:00)
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
# Convert ISO strings to datetime objects if provided
start_dt = datetime.fromisoformat(start_date) if start_date else None
@@ -48,14 +48,15 @@ def list_steps(
@router.get("/{step_id}", response_model=Step, operation_id="retrieve_step")
def retrieve_step(
step_id: str,
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
server: SyncServer = Depends(get_letta_server),
):
"""
Get a step by ID.
"""
try:
return server.step_manager.get_step(step_id=step_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.step_manager.get_step(step_id=step_id, actor=actor)
except NoResultFound:
raise HTTPException(status_code=404, detail="Step not found")
@@ -64,15 +65,15 @@ def retrieve_step(
def update_step_transaction_id(
step_id: str,
transaction_id: str,
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
server: SyncServer = Depends(get_letta_server),
):
"""
Update the transaction ID for a step.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
return server.step_manager.update_step_transaction_id(actor, step_id=step_id, transaction_id=transaction_id)
return server.step_manager.update_step_transaction_id(actor=actor, step_id=step_id, transaction_id=transaction_id)
except NoResultFound:
raise HTTPException(status_code=404, detail="Step not found")

View File

@@ -17,11 +17,11 @@ def list_tags(
limit: Optional[int] = Query(50),
server: "SyncServer" = Depends(get_letta_server),
query_text: Optional[str] = Query(None),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Get a list of all tags in the database
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
tags = server.agent_manager.list_tags(actor=actor, after=after, limit=limit, query_text=query_text)
return tags

View File

@@ -29,12 +29,12 @@ logger = get_logger(__name__)
def delete_tool(
tool_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a tool by name
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
server.tool_manager.delete_tool_by_id(tool_id=tool_id, actor=actor)
@@ -42,12 +42,12 @@ def delete_tool(
def retrieve_tool(
tool_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a tool by ID
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
if tool is None:
# return 404 error
@@ -61,13 +61,13 @@ def list_tools(
limit: Optional[int] = 50,
name: Optional[str] = None,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a list of all tools available to agents belonging to the org of the user
"""
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
if name is not None:
tool = server.tool_manager.get_tool_by_name(tool_name=name, actor=actor)
return [tool] if tool else []
@@ -82,13 +82,13 @@ def list_tools(
def create_tool(
request: ToolCreate = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Create a new tool
"""
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
tool = Tool(**request.model_dump())
return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor)
except UniqueConstraintViolationError as e:
@@ -114,13 +114,13 @@ def create_tool(
def upsert_tool(
request: ToolCreate = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Create or update a tool
"""
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor)
return tool
except UniqueConstraintViolationError as e:
@@ -142,13 +142,13 @@ def modify_tool(
tool_id: str,
request: ToolUpdate = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update an existing tool
"""
try:
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.tool_manager.update_tool_by_id(tool_id=tool_id, tool_update=request, actor=actor)
except LettaToolCreateError as e:
# HTTP 400 == Bad Request
@@ -163,12 +163,12 @@ def modify_tool(
@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools")
def upsert_base_tools(
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Upsert base tools
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.tool_manager.upsert_base_tools(actor=actor)
@@ -176,12 +176,12 @@ def upsert_base_tools(
def run_tool_from_source(
server: SyncServer = Depends(get_letta_server),
request: ToolRunFromSource = Body(...),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Attempt to build a tool from source, then run it on the provided arguments
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
return server.run_tool_from_source(
@@ -227,12 +227,12 @@ def list_composio_apps(server: SyncServer = Depends(get_letta_server), user_id:
def list_composio_actions_by_app(
composio_app_name: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Get a list of all Composio actions for a specific app
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
composio_api_key = get_composio_api_key(actor=actor, logger=logger)
if not composio_api_key:
raise HTTPException(
@@ -246,12 +246,12 @@ def list_composio_actions_by_app(
def add_composio_tool(
composio_action_name: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Add a new Composio tool by action name (Composio refers to each tool as an `Action`)
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
tool_create = ToolCreate.from_composio(action_name=composio_action_name)

View File

@@ -7,7 +7,6 @@ from datetime import datetime, timezone
from enum import Enum
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
import pytz
from fastapi import Header, HTTPException
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
@@ -145,7 +144,7 @@ def create_user_message(input_message: dict, agent_id: str, actor: User) -> Mess
Converts a user input message into the internal structured format.
"""
# Generate timestamp in the correct format
now = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
now = datetime.now(timezone.utc).isoformat()
# Format message as structured JSON
structured_message = {"type": "user_message", "message": input_message["content"], "time": now}
@@ -197,7 +196,7 @@ def create_assistant_message_from_openai_response(
agent_id=agent_id,
model=model,
tool_calls=[tool_call],
tool_call_id=None,
tool_call_id=tool_call_id,
created_at=datetime.now(timezone.utc),
)

View File

@@ -21,8 +21,10 @@ from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate
from letta.schemas.passage import Passage as PydanticPassage
@@ -613,6 +615,40 @@ class AgentManager:
)
return self.append_to_in_context_messages([system_message], agent_id=agent_state.id, actor=actor)
# TODO: I moved this from agent.py - replace all mentions of this with the agent_manager version
@enforce_types
def update_memory_if_changed(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
"""
Update internal memory object and system prompt if there have been modifications.
Args:
new_memory (Memory): the new memory object to compare to the current memory object
Returns:
modified (bool): whether the memory was updated
"""
agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
if agent_state.memory.compile() != new_memory.compile():
# update the blocks (LRW) in the DB
for label in agent_state.memory.list_block_labels():
updated_value = new_memory.get_block(label).value
if updated_value != agent_state.memory.get_block(label).value:
# update the block if it's changed
block_id = agent_state.memory.get_block(label).id
block = self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=actor)
# refresh memory from DB (using block ids)
agent_state.memory = Memory(
blocks=[self.block_manager.get_block_by_id(block.id, actor=actor) for block in agent_state.memory.get_blocks()]
)
# NOTE: don't do this since re-buildin the memory is handled at the start of the step
# rebuild memory - this records the last edited timestamp of the memory
# TODO: pass in update timestamp from block edit time
agent_state = self.rebuild_system_prompt(agent_id=agent_id, actor=actor)
return agent_state
# ======================================================================================================================
# Source Management
# ======================================================================================================================

View File

@@ -107,12 +107,14 @@ class BlockManager:
@enforce_types
def add_default_blocks(self, actor: PydanticUser):
for persona_file in list_persona_files():
text = open(persona_file, "r", encoding="utf-8").read()
with open(persona_file, "r", encoding="utf-8") as f:
text = f.read()
name = os.path.basename(persona_file).replace(".txt", "")
self.create_or_update_block(Persona(template_name=name, value=text, is_template=True), actor=actor)
for human_file in list_human_files():
text = open(human_file, "r", encoding="utf-8").read()
with open(human_file, "r", encoding="utf-8") as f:
text = f.read()
name = os.path.basename(human_file).replace(".txt", "")
self.create_or_update_block(Human(template_name=name, value=text, is_template=True), actor=actor)

View File

@@ -111,6 +111,12 @@ class IdentityManager:
existing_identity.name = identity.name
if identity.identity_type is not None:
existing_identity.identity_type = identity.identity_type
if identity.properties is not None:
if replace:
existing_identity.properties = [prop.model_dump() for prop in identity.properties]
else:
new_properties = existing_identity.properties + identity.properties
existing_identity.properties = [prop.model_dump() for prop in new_properties]
self._process_agent_relationship(
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace

View File

@@ -2,6 +2,7 @@ from typing import List, Optional
from sqlalchemy import and_, or_
from letta.log import get_logger
from letta.orm.agent import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
@@ -11,6 +12,8 @@ from letta.schemas.message import MessageUpdate
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
logger = get_logger(__name__)
class MessageManager:
"""Manager class to handle business logic related to Messages."""
@@ -37,7 +40,7 @@ class MessageManager:
results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids))
if len(results) != len(message_ids):
raise NoResultFound(
logger.warning(
f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}"
)

View File

@@ -25,15 +25,15 @@ class ProviderManager:
provider.resolve_identifier()
new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
new_provider.create(session)
new_provider.create(session, actor=actor)
return new_provider.to_pydantic()
@enforce_types
def update_provider(self, provider_update: ProviderUpdate) -> PydanticProvider:
def update_provider(self, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
"""Update provider details."""
with self.session_maker() as session:
# Retrieve the existing provider by ID
existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id)
existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id, actor=actor)
# Update only the fields that are provided in ProviderUpdate
update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
@@ -41,31 +41,32 @@ class ProviderManager:
setattr(existing_provider, key, value)
# Commit the updated provider
existing_provider.update(session)
existing_provider.update(session, actor=actor)
return existing_provider.to_pydantic()
@enforce_types
def delete_provider_by_id(self, provider_id: str):
def delete_provider_by_id(self, provider_id: str, actor: PydanticUser):
"""Delete a provider."""
with self.session_maker() as session:
# Clear api key field
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id)
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor)
existing_provider.api_key = None
existing_provider.update(session)
existing_provider.update(session, actor=actor)
# Soft delete in provider table
existing_provider.delete(session)
existing_provider.delete(session, actor=actor)
session.commit()
@enforce_types
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticProvider]:
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]:
"""List all providers with optional pagination."""
with self.session_maker() as session:
providers = ProviderModel.list(
db_session=session,
after=after,
limit=limit,
actor=actor,
)
return [provider.to_pydantic() for provider in providers]

View File

@@ -84,9 +84,9 @@ class StepManager:
return new_step.to_pydantic()
@enforce_types
def get_step(self, step_id: str) -> PydanticStep:
def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep:
with self.session_maker() as session:
step = StepModel.read(db_session=session, identifier=step_id)
step = StepModel.read(db_session=session, identifier=step_id, actor=actor)
return step.to_pydantic()
@enforce_types

View File

@@ -14,7 +14,9 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage
from letta.schemas.tool import ToolCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.services.tool_manager import ToolManager
# --- Server Management --- #
@@ -69,9 +71,49 @@ def roll_dice_tool(client):
@pytest.fixture(scope="function")
def agent(client, roll_dice_tool):
def weather_tool(client):
def get_weather(location: str) -> str:
"""
Fetches the current weather for a given location.
Parameters:
location (str): The location to get the weather for.
Returns:
str: A formatted string describing the weather in the given location.
Raises:
RuntimeError: If the request to fetch weather data fails.
"""
import requests
url = f"https://wttr.in/{location}?format=%C+%t"
response = requests.get(url)
if response.status_code == 200:
weather_data = response.text
return f"The weather in {location} is {weather_data}."
else:
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
tool = client.create_or_update_tool(func=get_weather)
# Yield the created tool
yield tool
@pytest.fixture(scope="function")
def composio_gmail_get_profile_tool(default_user):
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
yield tool
@pytest.fixture(scope="function")
def agent(client, roll_dice_tool, weather_tool, composio_gmail_get_profile_tool):
"""Creates an agent and ensures cleanup after tests."""
agent_state = client.create_agent(name=f"test_client_{uuid.uuid4()}", tool_ids=[roll_dice_tool.id])
agent_state = client.create_agent(
name=f"test_compl_{str(uuid.uuid4())[5:]}", tool_ids=[roll_dice_tool.id, weather_tool.id, composio_gmail_get_profile_tool.id]
)
yield agent_state
client.delete_agent(agent_state.id)
@@ -111,6 +153,19 @@ def _assert_valid_chunk(chunk, idx, chunks):
# --- Test Cases --- #
@pytest.mark.parametrize("message", ["What's the weather in SF?"])
@pytest.mark.parametrize("endpoint", ["fast/chat/completions"])
def test_tool_usage_fast_chat_completions(mock_e2b_api_key_none, client, agent, message, endpoint):
"""Tests chat completion streaming via SSE."""
request = _get_chat_request(agent.id, message)
response = _sse_post(f"{client.base_url}/openai/{client.api_prefix}/{endpoint}", request.model_dump(exclude_none=True), client.headers)
for chunk in response:
if isinstance(chunk, ChatCompletionChunk) and chunk.choices:
print(chunk.choices[0].delta.content)
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas."])
@pytest.mark.parametrize("endpoint", ["chat/completions", "fast/chat/completions"])
def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message, endpoint):

View File

@@ -2150,6 +2150,30 @@ def test_delete_source(server: SyncServer, default_user):
assert len(sources) == 0
def test_delete_attached_source(server: SyncServer, sarah_agent, default_user):
"""Test deleting a source."""
source_pydantic = PydanticSource(
name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=source.id, actor=default_user)
# Delete the source
deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user)
# Assertions to verify deletion
assert deleted_source.id == source.id
# Verify that the source no longer appears in list_sources
sources = server.source_manager.list_sources(actor=default_user)
assert len(sources) == 0
# Verify that agent is not deleted
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
assert agent is not None
def test_list_sources(server: SyncServer, default_user):
"""Test listing sources with pagination."""
# Create multiple sources

View File

@@ -1194,7 +1194,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
step_ids = set([msg.step_id for msg in get_messages_response])
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
for step_id in step_ids:
step = server.step_manager.get_step(step_id=step_id)
step = server.step_manager.get_step(step_id=step_id, actor=actor)
assert step, "Step was not logged correctly"
assert step.provider_id == provider.id
assert step.provider_name == agent.llm_config.model_endpoint_type
@@ -1208,7 +1208,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
assert prompt_tokens == usage.prompt_tokens
assert total_tokens == usage.total_tokens
server.provider_manager.delete_provider_by_id(provider.id)
server.provider_manager.delete_provider_by_id(provider.id, actor=actor)
existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
@@ -1221,7 +1221,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
step_ids = set([msg.step_id for msg in get_messages_response])
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
for step_id in step_ids:
step = server.step_manager.get_step(step_id=step_id)
step = server.step_manager.get_step(step_id=step_id, actor=actor)
assert step, "Step was not logged correctly"
assert step.provider_id == None
assert step.provider_name == agent.llm_config.model_endpoint_type