Merge branch 'main' into bump-0.6.33
This commit is contained in:
@@ -245,10 +245,13 @@ class Agent(BaseAgent):
|
||||
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
|
||||
)
|
||||
else:
|
||||
# Parse the source code to extract function annotations
|
||||
annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
|
||||
# Coerce the function arguments to the correct types based on the annotations
|
||||
function_args = coerce_dict_args_by_annotations(function_args, annotations)
|
||||
try:
|
||||
# Parse the source code to extract function annotations
|
||||
annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
|
||||
# Coerce the function arguments to the correct types based on the annotations
|
||||
function_args = coerce_dict_args_by_annotations(function_args, annotations)
|
||||
except ValueError as e:
|
||||
self.logger.debug(f"Error coercing function arguments: {e}")
|
||||
|
||||
# execute tool in a sandbox
|
||||
# TODO: allow agent_state to specify which sandbox to execute tools in
|
||||
@@ -257,7 +260,9 @@ class Agent(BaseAgent):
|
||||
agent_state_copy.tools = []
|
||||
agent_state_copy.tool_rules = []
|
||||
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(agent_state=agent_state_copy)
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run(
|
||||
agent_state=agent_state_copy
|
||||
)
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
if updated_agent_state is not None:
|
||||
|
||||
@@ -52,6 +52,8 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "
|
||||
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
|
||||
# Multi agent tools
|
||||
MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"]
|
||||
# Set of all built-in Letta tools
|
||||
LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS)
|
||||
|
||||
# The name of the tool used to send message to the user
|
||||
# May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...)
|
||||
@@ -59,6 +61,11 @@ MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_t
|
||||
DEFAULT_MESSAGE_TOOL = "send_message"
|
||||
DEFAULT_MESSAGE_TOOL_KWARG = "message"
|
||||
|
||||
PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg"
|
||||
|
||||
REQUEST_HEARTBEAT_PARAM = "request_heartbeat"
|
||||
|
||||
|
||||
# Structured output models
|
||||
STRUCTURED_OUTPUT_MODELS = {"gpt-4o", "gpt-4o-mini"}
|
||||
|
||||
|
||||
@@ -6,10 +6,11 @@ from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.settings import tool_settings
|
||||
|
||||
|
||||
def get_composio_api_key(actor: User, logger: Logger) -> Optional[str]:
|
||||
def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Optional[str]:
|
||||
api_keys = SandboxConfigManager().list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor)
|
||||
if not api_keys:
|
||||
logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
|
||||
if logger:
|
||||
logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
|
||||
if tool_settings.composio_api_key:
|
||||
return tool_settings.composio_api_key
|
||||
else:
|
||||
|
||||
171
letta/helpers/tool_execution_helper.py
Normal file
171
letta/helpers/tool_execution_helper.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxRunResult
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
|
||||
def enable_strict_mode(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Enables strict mode for a tool schema by setting 'strict' to True and
|
||||
disallowing additional properties in the parameters.
|
||||
|
||||
Args:
|
||||
tool_schema (Dict[str, Any]): The original tool schema.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A new tool schema with strict mode enabled.
|
||||
"""
|
||||
schema = tool_schema.copy()
|
||||
|
||||
# Enable strict mode
|
||||
schema["strict"] = True
|
||||
|
||||
# Ensure parameters is a valid dictionary
|
||||
parameters = schema.get("parameters", {})
|
||||
|
||||
if isinstance(parameters, dict) and parameters.get("type") == "object":
|
||||
# Set additionalProperties to False
|
||||
parameters["additionalProperties"] = False
|
||||
schema["parameters"] = parameters
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def add_pre_execution_message(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Adds a `pre_execution_message` parameter to a tool schema to prompt a natural, human-like message before executing the tool.
|
||||
|
||||
Args:
|
||||
tool_schema (Dict[str, Any]): The original tool schema.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A new tool schema with the `pre_execution_message` field added at the beginning.
|
||||
"""
|
||||
schema = tool_schema.copy()
|
||||
parameters = schema.get("parameters", {})
|
||||
|
||||
if not isinstance(parameters, dict) or parameters.get("type") != "object":
|
||||
return schema # Do not modify if schema is not valid
|
||||
|
||||
properties = parameters.get("properties", {})
|
||||
required = parameters.get("required", [])
|
||||
|
||||
# Define the new `pre_execution_message` field with a refined description
|
||||
pre_execution_message_field = {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"A concise message to be uttered before executing this tool. "
|
||||
"This should sound natural, as if a person is casually announcing their next action."
|
||||
"You MUST also include punctuation at the end of this message."
|
||||
),
|
||||
}
|
||||
|
||||
# Ensure the pre-execution message is the first field in properties
|
||||
updated_properties = OrderedDict()
|
||||
updated_properties[PRE_EXECUTION_MESSAGE_ARG] = pre_execution_message_field
|
||||
updated_properties.update(properties) # Retain all existing properties
|
||||
|
||||
# Ensure pre-execution message is the first required field
|
||||
if PRE_EXECUTION_MESSAGE_ARG not in required:
|
||||
required = [PRE_EXECUTION_MESSAGE_ARG] + required
|
||||
|
||||
# Update the schema with ordered properties and required list
|
||||
schema["parameters"] = {
|
||||
**parameters,
|
||||
"properties": dict(updated_properties), # Convert OrderedDict back to dict
|
||||
"required": required,
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def remove_request_heartbeat(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Removes the `request_heartbeat` parameter from a tool schema if it exists.
|
||||
|
||||
Args:
|
||||
tool_schema (Dict[str, Any]): The original tool schema.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A new tool schema without `request_heartbeat`.
|
||||
"""
|
||||
schema = tool_schema.copy()
|
||||
parameters = schema.get("parameters", {})
|
||||
|
||||
if isinstance(parameters, dict):
|
||||
properties = parameters.get("properties", {})
|
||||
required = parameters.get("required", [])
|
||||
|
||||
# Remove the `request_heartbeat` property if it exists
|
||||
if "request_heartbeat" in properties:
|
||||
properties.pop("request_heartbeat")
|
||||
|
||||
# Remove `request_heartbeat` from required fields if present
|
||||
if "request_heartbeat" in required:
|
||||
required = [r for r in required if r != "request_heartbeat"]
|
||||
|
||||
# Update parameters with modified properties and required list
|
||||
schema["parameters"] = {**parameters, "properties": properties, "required": required}
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
# TODO: Deprecate the `execute_external_tool` function on the agent body
|
||||
def execute_external_tool(
|
||||
agent_state: AgentState,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
target_letta_tool: Tool,
|
||||
actor: User,
|
||||
allow_agent_state_modifications: bool = False,
|
||||
) -> tuple[Any, Optional[SandboxRunResult]]:
|
||||
# TODO: need to have an AgentState object that actually has full access to the block data
|
||||
# this is because the sandbox tools need to be able to access block.value to edit this data
|
||||
try:
|
||||
if target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO:
|
||||
action_name = generate_composio_action_from_func_name(target_letta_tool.name)
|
||||
# Get entity ID from the agent_state
|
||||
entity_id = None
|
||||
for env_var in agent_state.tool_exec_environment_variables:
|
||||
if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY:
|
||||
entity_id = env_var.value
|
||||
# Get composio_api_key
|
||||
composio_api_key = get_composio_api_key(actor=actor)
|
||||
function_response = execute_composio_action(
|
||||
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
|
||||
)
|
||||
return function_response, None
|
||||
elif target_letta_tool.tool_type == ToolType.CUSTOM:
|
||||
# Parse the source code to extract function annotations
|
||||
annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
|
||||
# Coerce the function arguments to the correct types based on the annotations
|
||||
function_args = coerce_dict_args_by_annotations(function_args, annotations)
|
||||
|
||||
# execute tool in a sandbox
|
||||
# TODO: allow agent_state to specify which sandbox to execute tools in
|
||||
# TODO: This is only temporary, can remove after we publish a pip package with this object
|
||||
if allow_agent_state_modifications:
|
||||
agent_state_copy = agent_state.__deepcopy__()
|
||||
agent_state_copy.tools = []
|
||||
agent_state_copy.tool_rules = []
|
||||
else:
|
||||
agent_state_copy = None
|
||||
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
# TODO: Bring this back
|
||||
# if allow_agent_state_modifications and updated_agent_state is not None:
|
||||
# self.update_memory_if_changed(updated_agent_state.memory)
|
||||
return function_response, sandbox_run_result
|
||||
except Exception as e:
|
||||
# Need to catch error here, or else trunction wont happen
|
||||
# TODO: modify to function execution error
|
||||
function_response = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
|
||||
return function_response, None
|
||||
@@ -47,14 +47,39 @@ BASE_URL = "https://api.anthropic.com/v1"
|
||||
# https://docs.anthropic.com/claude/docs/models-overview
|
||||
# Sadly hardcoded
|
||||
MODEL_LIST = [
|
||||
## Opus
|
||||
{
|
||||
"name": "claude-3-opus-20240229",
|
||||
"context_window": 200000,
|
||||
},
|
||||
## Sonnet
|
||||
# 3.0
|
||||
{
|
||||
"name": "claude-3-sonnet-20240229",
|
||||
"context_window": 200000,
|
||||
},
|
||||
# 3.5
|
||||
{
|
||||
"name": "claude-3-5-sonnet-20240620",
|
||||
"context_window": 200000,
|
||||
},
|
||||
# 3.5 new
|
||||
{
|
||||
"name": "claude-3-5-sonnet-20241022",
|
||||
"context_window": 200000,
|
||||
},
|
||||
# 3.7
|
||||
{
|
||||
"name": "claude-3-7-sonnet-20250219",
|
||||
"context_window": 200000,
|
||||
},
|
||||
## Haiku
|
||||
# 3.0
|
||||
{
|
||||
"name": "claude-3-haiku-20240307",
|
||||
"context_window": 200000,
|
||||
},
|
||||
# 3.5
|
||||
{
|
||||
"name": "claude-3-5-haiku-20241022",
|
||||
"context_window": 200000,
|
||||
@@ -75,7 +100,18 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict:
|
||||
"""https://docs.anthropic.com/claude/docs/models-overview"""
|
||||
|
||||
# NOTE: currently there is no GET /models, so we need to hardcode
|
||||
return MODEL_LIST
|
||||
# return MODEL_LIST
|
||||
|
||||
anthropic_override_key = ProviderManager().get_anthropic_override_key()
|
||||
if anthropic_override_key:
|
||||
anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key)
|
||||
elif model_settings.anthropic_api_key:
|
||||
anthropic_client = anthropic.Anthropic()
|
||||
|
||||
models = anthropic_client.models.list()
|
||||
models_json = models.model_dump()
|
||||
assert "data" in models_json, f"Anthropic model query response missing 'data' field: {models_json}"
|
||||
return models_json["data"]
|
||||
|
||||
|
||||
def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
|
||||
|
||||
@@ -42,6 +42,6 @@ class Source(SqlalchemyBase, OrganizationMixin):
|
||||
secondary="sources_agents",
|
||||
back_populates="sources",
|
||||
lazy="selectin",
|
||||
cascade="all, delete", # Ensures rows in sources_agents are deleted when the source is deleted
|
||||
passive_deletes=True, # Allows the database to handle deletion of orphaned rows
|
||||
cascade="save-update", # Only propagate save and update operations
|
||||
passive_deletes=True, # Let the database handle deletions
|
||||
)
|
||||
|
||||
@@ -99,7 +99,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
"""https://platform.openai.com/docs/api-reference/chat/create"""
|
||||
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
messages: List[Union[ChatMessage, Dict]]
|
||||
frequency_penalty: Optional[float] = 0
|
||||
logit_bias: Optional[Dict[str, int]] = None
|
||||
logprobs: Optional[bool] = False
|
||||
|
||||
@@ -410,28 +410,67 @@ class AnthropicProvider(Provider):
|
||||
base_url: str = "https://api.anthropic.com/v1"
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
from letta.llm_api.anthropic import anthropic_get_model_list
|
||||
from letta.llm_api.anthropic import MODEL_LIST, anthropic_get_model_list
|
||||
|
||||
models = anthropic_get_model_list(self.base_url, api_key=self.api_key)
|
||||
|
||||
"""
|
||||
Example response:
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"type": "model",
|
||||
"id": "claude-3-5-sonnet-20241022",
|
||||
"display_name": "Claude 3.5 Sonnet (New)",
|
||||
"created_at": "2024-10-22T00:00:00Z"
|
||||
}
|
||||
],
|
||||
"has_more": true,
|
||||
"first_id": "<string>",
|
||||
"last_id": "<string>"
|
||||
}
|
||||
"""
|
||||
|
||||
configs = []
|
||||
for model in models:
|
||||
|
||||
if model["type"] != "model":
|
||||
continue
|
||||
|
||||
if "id" not in model:
|
||||
continue
|
||||
|
||||
# Don't support 2.0 and 2.1
|
||||
if model["id"].startswith("claude-2"):
|
||||
continue
|
||||
|
||||
# Anthropic doesn't return the context window in their API
|
||||
if "context_window" not in model:
|
||||
# Remap list to name: context_window
|
||||
model_library = {m["name"]: m["context_window"] for m in MODEL_LIST}
|
||||
# Attempt to look it up in a hardcoded list
|
||||
if model["id"] in model_library:
|
||||
model["context_window"] = model_library[model["id"]]
|
||||
else:
|
||||
# On fallback, we can set 200k (generally safe), but we should warn the user
|
||||
warnings.warn(f"Couldn't find context window size for model {model['id']}, defaulting to 200,000")
|
||||
model["context_window"] = 200000
|
||||
|
||||
# We set this to false by default, because Anthropic can
|
||||
# natively support <thinking> tags inside of content fields
|
||||
# However, putting COT inside of tool calls can make it more
|
||||
# reliable for tool calling (no chance of a non-tool call step)
|
||||
# Since tool_choice_type 'any' doesn't work with in-content COT
|
||||
# NOTE For Haiku, it can be flaky if we don't enable this by default
|
||||
inner_thoughts_in_kwargs = True if "haiku" in model["name"] else False
|
||||
inner_thoughts_in_kwargs = True if "haiku" in model["id"] else False
|
||||
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model["name"],
|
||||
model=model["id"],
|
||||
model_endpoint_type="anthropic",
|
||||
model_endpoint=self.base_url,
|
||||
context_window=model["context_window"],
|
||||
handle=self.get_handle(model["name"]),
|
||||
handle=self.get_handle(model["id"]),
|
||||
put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from letta.constants import (
|
||||
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
|
||||
)
|
||||
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
|
||||
from letta.functions.helpers import generate_composio_action_from_func_name, generate_composio_tool_wrapper, generate_langchain_tool_wrapper
|
||||
from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper
|
||||
from letta.functions.schema_generator import generate_schema_from_args_schema_v2, generate_tool_schema_for_composio
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
@@ -77,18 +77,6 @@ class Tool(BaseTool):
|
||||
elif self.tool_type in {ToolType.LETTA_MULTI_AGENT_CORE}:
|
||||
# If it's letta multi-agent tool, we also generate the json_schema on the fly here
|
||||
self.json_schema = get_json_schema_from_module(module_name=LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name=self.name)
|
||||
elif self.tool_type == ToolType.EXTERNAL_COMPOSIO:
|
||||
# If it is a composio tool, we generate both the source code and json schema on the fly here
|
||||
# TODO: Deriving the composio action name is brittle, need to think long term about how to improve this
|
||||
try:
|
||||
composio_action = generate_composio_action_from_func_name(self.name)
|
||||
tool_create = ToolCreate.from_composio(composio_action)
|
||||
self.source_code = tool_create.source_code
|
||||
self.json_schema = tool_create.json_schema
|
||||
self.description = tool_create.description
|
||||
self.tags = tool_create.tags
|
||||
except Exception as e:
|
||||
logger.error(f"Encountered exception while attempting to refresh source_code and json_schema for composio_tool: {e}")
|
||||
|
||||
# At this point, we need to validate that at least json_schema is populated
|
||||
if not self.json_schema:
|
||||
|
||||
@@ -1,19 +1,39 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
|
||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from letta.agent import Agent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_TOOL_SET, NON_USER_MSG_PREFIX, PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.helpers.tool_execution_helper import (
|
||||
add_pre_execution_message,
|
||||
enable_strict_mode,
|
||||
execute_external_tool,
|
||||
remove_request_heartbeat,
|
||||
)
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_request import (
|
||||
AssistantMessage,
|
||||
ChatCompletionRequest,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolCallFunction,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
|
||||
# TODO this belongs in a controller!
|
||||
from letta.server.rest_api.utils import (
|
||||
@@ -52,20 +72,53 @@ async def create_fast_chat_completions(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
# TODO: This is necessary, we need to factor out CompletionCreateParams due to weird behavior
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
agent_id = str(completion_request.get("user", None))
|
||||
if agent_id is None:
|
||||
error_msg = "Must pass agent_id in the 'user' field"
|
||||
logger.error(error_msg)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
model = completion_request.get("model")
|
||||
raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
if agent_state.llm_config.model_endpoint_type != "openai":
|
||||
raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.")
|
||||
|
||||
# Convert Letta messages to OpenAI messages
|
||||
in_context_messages = server.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor)
|
||||
openai_messages = convert_letta_messages_to_openai(in_context_messages)
|
||||
|
||||
# Also parse user input from completion_request and append
|
||||
input_message = get_messages_from_completion_request(completion_request)[-1]
|
||||
openai_messages.append(input_message)
|
||||
|
||||
# Tools we allow this agent to call
|
||||
tools = [t for t in agent_state.tools if t.name not in LETTA_TOOL_SET and t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}]
|
||||
|
||||
# Initial request
|
||||
openai_request = ChatCompletionRequest(
|
||||
model=agent_state.llm_config.model,
|
||||
messages=openai_messages,
|
||||
# TODO: This nested thing here is so ugly, need to refactor
|
||||
tools=(
|
||||
[
|
||||
Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema))))
|
||||
for t in tools
|
||||
]
|
||||
if tools
|
||||
else None
|
||||
),
|
||||
tool_choice="auto",
|
||||
user=user_id,
|
||||
max_completion_tokens=agent_state.llm_config.max_tokens,
|
||||
temperature=agent_state.llm_config.temperature,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Create the OpenAI async client
|
||||
client = openai.AsyncClient(
|
||||
api_key=model_settings.openai_api_key,
|
||||
max_retries=0,
|
||||
http_client=httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
|
||||
timeout=httpx.Timeout(connect=15.0, read=30.0, write=15.0, pool=15.0),
|
||||
follow_redirects=True,
|
||||
limits=httpx.Limits(
|
||||
max_connections=50,
|
||||
@@ -75,38 +128,175 @@ async def create_fast_chat_completions(
|
||||
),
|
||||
)
|
||||
|
||||
# Magic message manipulating
|
||||
input_message = get_messages_from_completion_request(completion_request)[-1]
|
||||
completion_request.pop("messages")
|
||||
|
||||
# Get in context messages
|
||||
in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor)
|
||||
openai_dict_in_context_messages = convert_letta_messages_to_openai(in_context_messages)
|
||||
openai_dict_in_context_messages.append(input_message)
|
||||
# The messages we want to persist to the Letta agent
|
||||
user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor)
|
||||
message_db_queue = [user_message]
|
||||
|
||||
async def event_stream():
|
||||
# TODO: Factor this out into separate interface
|
||||
response_accumulator = []
|
||||
"""
|
||||
A function-calling loop:
|
||||
- We stream partial tokens.
|
||||
- If we detect a tool call (finish_reason="tool_calls"), we parse it,
|
||||
add two messages to the conversation:
|
||||
(a) assistant message with tool_calls referencing the same ID
|
||||
(b) a tool message referencing that ID, containing the tool result.
|
||||
- Re-invoke the OpenAI request with updated conversation, streaming again.
|
||||
- End when finish_reason="stop" or no more tool calls.
|
||||
"""
|
||||
|
||||
stream = await client.chat.completions.create(**completion_request, messages=openai_dict_in_context_messages)
|
||||
# We'll keep updating this conversation in a loop
|
||||
conversation = openai_messages[:]
|
||||
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
# TODO: This does not support tool calling right now
|
||||
response_accumulator.append(chunk.choices[0].delta.content)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
while True:
|
||||
# Make the streaming request to OpenAI
|
||||
stream = await client.chat.completions.create(**openai_request.model_dump(exclude_unset=True))
|
||||
|
||||
# Construct messages
|
||||
user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor)
|
||||
assistant_message = create_assistant_message_from_openai_response(
|
||||
response_text="".join(response_accumulator), agent_id=agent_id, model=str(model), actor=actor
|
||||
)
|
||||
content_buffer = []
|
||||
tool_call_name = None
|
||||
tool_call_args_str = ""
|
||||
tool_call_id = None
|
||||
tool_call_happened = False
|
||||
finish_reason_stop = False
|
||||
optimistic_json_parser = OptimisticJSONParser(strict=True)
|
||||
current_parsed_json_result = {}
|
||||
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
finish_reason = choice.finish_reason # "tool_calls", "stop", or None
|
||||
|
||||
if delta.content:
|
||||
content_buffer.append(delta.content)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# CASE B: Partial tool call info
|
||||
if delta.tool_calls:
|
||||
# Typically there's only one in delta.tool_calls
|
||||
tc = delta.tool_calls[0]
|
||||
if tc.function.name:
|
||||
tool_call_name = tc.function.name
|
||||
if tc.function.arguments:
|
||||
tool_call_args_str += tc.function.arguments
|
||||
|
||||
# See if we can stream out the pre-execution message
|
||||
parsed_args = optimistic_json_parser.parse(tool_call_args_str)
|
||||
if parsed_args.get(
|
||||
PRE_EXECUTION_MESSAGE_ARG
|
||||
) and current_parsed_json_result.get( # Ensure key exists and is not None/empty
|
||||
PRE_EXECUTION_MESSAGE_ARG
|
||||
) != parsed_args.get(
|
||||
PRE_EXECUTION_MESSAGE_ARG
|
||||
):
|
||||
# Only stream if there's something new to stream
|
||||
# We do this way to avoid hanging JSON at the end of the stream, e.g. '}'
|
||||
if parsed_args != current_parsed_json_result:
|
||||
current_parsed_json_result = parsed_args
|
||||
synthetic_chunk = ChatCompletionChunk(
|
||||
id=chunk.id,
|
||||
object=chunk.object,
|
||||
created=chunk.created,
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
Choice(
|
||||
index=choice.index,
|
||||
delta=ChoiceDelta(content=tc.function.arguments, role="assistant"),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
yield f"data: {synthetic_chunk.model_dump_json()}\n\n"
|
||||
|
||||
# We might generate a unique ID for the tool call
|
||||
if tc.id:
|
||||
tool_call_id = tc.id
|
||||
|
||||
# Check finish_reason
|
||||
if finish_reason == "tool_calls":
|
||||
tool_call_happened = True
|
||||
break
|
||||
elif finish_reason == "stop":
|
||||
finish_reason_stop = True
|
||||
break
|
||||
|
||||
if content_buffer:
|
||||
# We treat that partial text as an assistant message
|
||||
content = "".join(content_buffer)
|
||||
conversation.append({"role": "assistant", "content": content})
|
||||
|
||||
# Create an assistant message here to persist later
|
||||
assistant_message = create_assistant_message_from_openai_response(
|
||||
response_text=content, agent_id=agent_id, model=agent_state.llm_config.model, actor=actor
|
||||
)
|
||||
message_db_queue.append(assistant_message)
|
||||
|
||||
if tool_call_happened:
|
||||
# Parse the tool call arguments
|
||||
try:
|
||||
tool_args = json.loads(tool_call_args_str)
|
||||
except json.JSONDecodeError:
|
||||
tool_args = {}
|
||||
|
||||
if not tool_call_id:
|
||||
# If no tool_call_id given by the model, generate one
|
||||
tool_call_id = f"call_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 1) Insert the "assistant" message with the tool_calls field
|
||||
# referencing the same tool_call_id
|
||||
assistant_tool_call_msg = AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[ToolCall(id=tool_call_id, function=ToolCallFunction(name=tool_call_name, arguments=tool_call_args_str))],
|
||||
)
|
||||
|
||||
conversation.append(assistant_tool_call_msg.model_dump())
|
||||
|
||||
# 2) Execute the tool
|
||||
target_tool = next((x for x in tools if x.name == tool_call_name), None)
|
||||
if not target_tool:
|
||||
# Tool not found, handle error
|
||||
yield f"data: {json.dumps({'error': 'Tool not found', 'tool': tool_call_name})}\n\n"
|
||||
break
|
||||
|
||||
try:
|
||||
tool_result, _ = execute_external_tool(
|
||||
agent_state=agent_state,
|
||||
function_name=tool_call_name,
|
||||
function_args=tool_args,
|
||||
target_letta_tool=target_tool,
|
||||
actor=actor,
|
||||
allow_agent_state_modifications=False,
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = f"Failed to call tool. Error: {e}"
|
||||
|
||||
# 3) Insert the "tool" message referencing the same tool_call_id
|
||||
tool_message = ToolMessage(content=json.dumps({"result": tool_result}), tool_call_id=tool_call_id)
|
||||
|
||||
conversation.append(tool_message.model_dump())
|
||||
|
||||
# 4) Add a user message prompting the tool call result summarization
|
||||
heartbeat_user_message = UserMessage(
|
||||
content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user.",
|
||||
)
|
||||
conversation.append(heartbeat_user_message.model_dump())
|
||||
|
||||
# Now, re-invoke OpenAI with the updated conversation
|
||||
openai_request.messages = conversation
|
||||
|
||||
continue # Start the while loop again
|
||||
|
||||
if finish_reason_stop:
|
||||
# Model is done, no more calls
|
||||
break
|
||||
|
||||
# If we reach here, no tool call, no "stop", but we've ended streaming
|
||||
# Possibly a model error or some other finish reason. We'll just end.
|
||||
break
|
||||
|
||||
# Persist both in one synchronous DB call, done in a threadpool
|
||||
await run_in_threadpool(
|
||||
server.agent_manager.append_to_in_context_messages,
|
||||
[user_message, assistant_message],
|
||||
message_db_queue,
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
@@ -44,7 +44,7 @@ def list_agents(
|
||||
description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed in tags.",
|
||||
),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
before: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(None, description="Limit for pagination"),
|
||||
@@ -52,13 +52,14 @@ def list_agents(
|
||||
project_id: Optional[str] = Query(None, description="Search agents by project id"),
|
||||
template_id: Optional[str] = Query(None, description="Search agents by template id"),
|
||||
base_template_id: Optional[str] = Query(None, description="Search agents by base template id"),
|
||||
identifier_id: Optional[str] = Query(None, description="Search agents by identifier id"),
|
||||
identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"),
|
||||
):
|
||||
"""
|
||||
List all agents associated with a given user.
|
||||
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# Use dictionary comprehension to build kwargs dynamically
|
||||
kwargs = {
|
||||
@@ -68,6 +69,7 @@ def list_agents(
|
||||
"project_id": project_id,
|
||||
"template_id": template_id,
|
||||
"base_template_id": base_template_id,
|
||||
"identifier_id": identifier_id,
|
||||
}.items()
|
||||
if value is not None
|
||||
}
|
||||
@@ -91,12 +93,12 @@ def list_agents(
|
||||
def retrieve_agent_context_window(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the context window of a specific agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.get_agent_context_window(agent_id=agent_id, actor=actor)
|
||||
|
||||
@@ -107,21 +109,21 @@ class CreateAgentRequest(CreateAgent):
|
||||
"""
|
||||
|
||||
# Override the user_id field to exclude it from the request body validation
|
||||
user_id: Optional[str] = Field(None, exclude=True)
|
||||
actor_id: Optional[str] = Field(None, exclude=True)
|
||||
|
||||
|
||||
@router.post("/", response_model=AgentState, operation_id="create_agent")
|
||||
def create_agent(
|
||||
agent: CreateAgentRequest = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
"""
|
||||
Create a new agent with the specified configuration.
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.create_agent(agent, actor=actor)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
@@ -133,10 +135,10 @@ def modify_agent(
|
||||
agent_id: str,
|
||||
update_agent: UpdateAgent = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Update an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.update_agent(agent_id=agent_id, agent_update=update_agent, actor=actor)
|
||||
|
||||
|
||||
@@ -144,10 +146,10 @@ def modify_agent(
|
||||
def list_agent_tools(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Get tools from an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.list_attached_tools(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@@ -156,12 +158,12 @@ def attach_tool(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Attach a tool to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@@ -170,12 +172,12 @@ def detach_tool(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Detach a tool from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@@ -184,12 +186,12 @@ def attach_source(
|
||||
agent_id: str,
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Attach a source to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@@ -198,12 +200,12 @@ def detach_source(
|
||||
agent_id: str,
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Detach a source from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@@ -211,12 +213,12 @@ def detach_source(
|
||||
def retrieve_agent(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get the state of the agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
@@ -228,12 +230,12 @@ def retrieve_agent(
|
||||
def delete_agent(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
server.agent_manager.delete_agent(agent_id=agent_id, actor=actor)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent id={agent_id} successfully deleted"})
|
||||
@@ -245,12 +247,12 @@ def delete_agent(
|
||||
def list_agent_sources(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get the sources associated with an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.list_attached_sources(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@@ -259,13 +261,13 @@ def list_agent_sources(
|
||||
def retrieve_agent_memory(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the memory state of a specific agent.
|
||||
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.get_agent_memory(agent_id=agent_id, actor=actor)
|
||||
|
||||
@@ -275,12 +277,12 @@ def retrieve_core_memory_block(
|
||||
agent_id: str,
|
||||
block_label: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve a memory block from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
|
||||
@@ -292,12 +294,12 @@ def retrieve_core_memory_block(
|
||||
def list_core_memory_blocks(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the memory blocks of a specific agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor=actor)
|
||||
return agent.memory.blocks
|
||||
@@ -311,12 +313,12 @@ def modify_core_memory_block(
|
||||
block_label: str,
|
||||
block_update: BlockUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Updates a memory block of an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
|
||||
block = server.block_manager.update_block(block.id, block_update=block_update, actor=actor)
|
||||
@@ -332,12 +334,12 @@ def attach_core_memory_block(
|
||||
agent_id: str,
|
||||
block_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Attach a block to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
|
||||
|
||||
@@ -346,12 +348,12 @@ def detach_core_memory_block(
|
||||
agent_id: str,
|
||||
block_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Detach a block from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
|
||||
|
||||
@@ -362,12 +364,12 @@ def list_archival_memory(
|
||||
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
|
||||
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
|
||||
limit: Optional[int] = Query(None, description="How many results to include in the response."),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the memories in an agent's archival memory store (paginated query).
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.get_agent_archival(
|
||||
user_id=actor.id,
|
||||
@@ -383,12 +385,12 @@ def create_archival_memory(
|
||||
agent_id: str,
|
||||
request: CreateArchivalMemory = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Insert a memory into an agent's archival memory store.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.insert_archival_memory(agent_id=agent_id, memory_contents=request.text, actor=actor)
|
||||
|
||||
@@ -401,12 +403,12 @@ def delete_archival_memory(
|
||||
memory_id: str,
|
||||
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a memory from an agent's archival memory store.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
server.delete_archival_memory(memory_id=memory_id, actor=actor)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
|
||||
@@ -427,12 +429,12 @@ def list_messages(
|
||||
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
|
||||
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
|
||||
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.get_agent_recall(
|
||||
user_id=actor.id,
|
||||
@@ -454,13 +456,13 @@ def modify_message(
|
||||
message_id: str,
|
||||
request: MessageUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the details of a message associated with an agent.
|
||||
"""
|
||||
# TODO: Get rid of agent_id here, it's not really relevant
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
|
||||
|
||||
|
||||
@@ -474,13 +476,13 @@ async def send_message(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaRequest = Body(...),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
@@ -513,14 +515,14 @@ async def send_message_streaming(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaStreamingRequest = Body(...),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
@@ -590,13 +592,13 @@ async def send_message_async(
|
||||
background_tasks: BackgroundTasks,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaRequest = Body(...),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Asynchronously process a user message and return a run object.
|
||||
The actual processing happens in the background, and the status can be checked using the run ID.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# Create a new job
|
||||
run = Run(
|
||||
@@ -635,8 +637,8 @@ def reset_messages(
|
||||
agent_id: str,
|
||||
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Resets the messages for an agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages)
|
||||
|
||||
@@ -21,9 +21,9 @@ def list_blocks(
|
||||
templates_only: bool = Query(True, description="Whether to include only templates"),
|
||||
name: Optional[str] = Query(None, description="Name of the block"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.block_manager.get_blocks(actor=actor, label=label, is_template=templates_only, template_name=name)
|
||||
|
||||
|
||||
@@ -31,9 +31,9 @@ def list_blocks(
|
||||
def create_block(
|
||||
create_block: CreateBlock = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
block = Block(**create_block.model_dump())
|
||||
return server.block_manager.create_or_update_block(actor=actor, block=block)
|
||||
|
||||
@@ -43,9 +43,9 @@ def modify_block(
|
||||
block_id: str,
|
||||
block_update: BlockUpdate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.block_manager.update_block(block_id=block_id, block_update=block_update, actor=actor)
|
||||
|
||||
|
||||
@@ -53,9 +53,9 @@ def modify_block(
|
||||
def delete_block(
|
||||
block_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.block_manager.delete_block(block_id=block_id, actor=actor)
|
||||
|
||||
|
||||
@@ -63,10 +63,10 @@ def delete_block(
|
||||
def retrieve_block(
|
||||
block_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
print("call get block", block_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor)
|
||||
if block is None:
|
||||
@@ -80,13 +80,13 @@ def retrieve_block(
|
||||
def list_agents_for_block(
|
||||
block_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Retrieves all agents associated with the specified block.
|
||||
Raises a 404 if the block does not exist.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
agents = server.block_manager.get_agents_for_block(block_id=block_id, actor=actor)
|
||||
return agents
|
||||
|
||||
@@ -22,13 +22,13 @@ def list_identities(
|
||||
after: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a list of all identities in the database
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
identities = server.identity_manager.list_identities(
|
||||
name=name,
|
||||
@@ -51,10 +51,10 @@ def list_identities(
|
||||
def retrieve_identity(
|
||||
identity_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.get_identity(identity_id=identity_id, actor=actor)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -64,11 +64,11 @@ def retrieve_identity(
|
||||
def create_identity(
|
||||
identity: IdentityCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.create_identity(identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -80,11 +80,11 @@ def create_identity(
|
||||
def upsert_identity(
|
||||
identity: IdentityCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.upsert_identity(identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -97,10 +97,10 @@ def modify_identity(
|
||||
identity_id: str,
|
||||
identity: IdentityUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.update_identity(identity_id=identity_id, identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -112,10 +112,10 @@ def modify_identity(
|
||||
def delete_identity(
|
||||
identity_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete an identity by its identifier key
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.identity_manager.delete_identity(identity_id=identity_id, actor=actor)
|
||||
|
||||
@@ -15,12 +15,12 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
def list_jobs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all jobs.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# TODO: add filtering by status
|
||||
jobs = server.job_manager.list_jobs(actor=actor)
|
||||
@@ -35,12 +35,12 @@ def list_jobs(
|
||||
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
|
||||
def list_active_jobs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all active jobs.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running])
|
||||
|
||||
@@ -48,13 +48,13 @@ def list_active_jobs(
|
||||
@router.get("/{job_id}", response_model=Job, operation_id="retrieve_job")
|
||||
def retrieve_job(
|
||||
job_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get the status of a job.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.job_manager.get_job_by_id(job_id=job_id, actor=actor)
|
||||
@@ -65,13 +65,13 @@ def retrieve_job(
|
||||
@router.delete("/{job_id}", response_model=Job, operation_id="delete_job")
|
||||
def delete_job(
|
||||
job_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Delete a job by its job_id.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
job = server.job_manager.delete_job_by_id(job_id=job_id, actor=actor)
|
||||
|
||||
@@ -15,13 +15,15 @@ router = APIRouter(prefix="/providers", tags=["providers"])
|
||||
def list_providers(
|
||||
after: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get a list of all custom providers in the database
|
||||
"""
|
||||
try:
|
||||
providers = server.provider_manager.list_providers(after=after, limit=limit)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -32,13 +34,13 @@ def list_providers(
|
||||
@router.post("/", tags=["providers"], response_model=Provider, operation_id="create_provider")
|
||||
def create_provider(
|
||||
request: ProviderCreate = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a new custom provider
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
provider = Provider(**request.model_dump())
|
||||
provider = server.provider_manager.create_provider(provider, actor=actor)
|
||||
@@ -48,25 +50,29 @@ def create_provider(
|
||||
@router.patch("/", tags=["providers"], response_model=Provider, operation_id="modify_provider")
|
||||
def modify_provider(
|
||||
request: ProviderUpdate = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Update an existing custom provider
|
||||
"""
|
||||
provider = server.provider_manager.update_provider(request)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
provider = server.provider_manager.update_provider(request, actor=actor)
|
||||
return provider
|
||||
|
||||
|
||||
@router.delete("/", tags=["providers"], response_model=None, operation_id="delete_provider")
|
||||
def delete_provider(
|
||||
provider_id: str = Query(..., description="The provider_id key to be deleted."),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Delete an existing custom provider
|
||||
"""
|
||||
try:
|
||||
server.provider_manager.delete_provider_by_id(provider_id=provider_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.provider_manager.delete_provider_by_id(provider_id=provider_id, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -18,12 +18,12 @@ router = APIRouter(prefix="/runs", tags=["runs"])
|
||||
@router.get("/", response_model=List[Run], operation_id="list_runs")
|
||||
def list_runs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all runs.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return [Run.from_job(job) for job in server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)]
|
||||
|
||||
@@ -31,12 +31,12 @@ def list_runs(
|
||||
@router.get("/active", response_model=List[Run], operation_id="list_active_runs")
|
||||
def list_active_runs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all active runs.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
active_runs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.RUN)
|
||||
|
||||
@@ -46,13 +46,13 @@ def list_active_runs(
|
||||
@router.get("/{run_id}", response_model=Run, operation_id="retrieve_run")
|
||||
def retrieve_run(
|
||||
run_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get the status of a run.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
|
||||
@@ -74,7 +74,7 @@ RunMessagesResponse = Annotated[
|
||||
async def list_run_messages(
|
||||
run_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
before: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
|
||||
@@ -102,7 +102,7 @@ async def list_run_messages(
|
||||
if order not in ["asc", "desc"]:
|
||||
raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
messages = server.job_manager.get_run_messages(
|
||||
@@ -122,13 +122,13 @@ async def list_run_messages(
|
||||
@router.get("/{run_id}/usage", response_model=UsageStatistics, operation_id="retrieve_run_usage")
|
||||
def retrieve_run_usage(
|
||||
run_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get usage statistics for a run.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
usage = server.job_manager.get_job_usage(job_id=run_id, actor=actor)
|
||||
@@ -140,13 +140,13 @@ def retrieve_run_usage(
|
||||
@router.delete("/{run_id}", response_model=Run, operation_id="delete_run")
|
||||
def delete_run(
|
||||
run_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Delete a run by its run_id.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
job = server.job_manager.delete_job_by_id(job_id=run_id, actor=actor)
|
||||
|
||||
@@ -25,9 +25,9 @@ logger = get_logger(__name__)
|
||||
def create_sandbox_config(
|
||||
config_create: SandboxConfigCreate,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.sandbox_config_manager.create_or_update_sandbox_config(config_create, actor)
|
||||
|
||||
@@ -35,18 +35,18 @@ def create_sandbox_config(
|
||||
@router.post("/e2b/default", response_model=PydanticSandboxConfig)
|
||||
def create_default_e2b_sandbox_config(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=actor)
|
||||
|
||||
|
||||
@router.post("/local/default", response_model=PydanticSandboxConfig)
|
||||
def create_default_local_sandbox_config(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ def create_default_local_sandbox_config(
|
||||
def create_custom_local_sandbox_config(
|
||||
local_sandbox_config: LocalSandboxConfig,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""
|
||||
Create or update a custom LocalSandboxConfig, including pip_requirements.
|
||||
@@ -67,7 +67,7 @@ def create_custom_local_sandbox_config(
|
||||
)
|
||||
|
||||
# Retrieve the user (actor)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# Wrap the LocalSandboxConfig into a SandboxConfigCreate
|
||||
sandbox_config_create = SandboxConfigCreate(config=local_sandbox_config)
|
||||
@@ -83,9 +83,9 @@ def update_sandbox_config(
|
||||
sandbox_config_id: str,
|
||||
config_update: SandboxConfigUpdate,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.update_sandbox_config(sandbox_config_id, config_update, actor)
|
||||
|
||||
|
||||
@@ -93,9 +93,9 @@ def update_sandbox_config(
|
||||
def delete_sandbox_config(
|
||||
sandbox_config_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id, actor)
|
||||
|
||||
|
||||
@@ -105,22 +105,22 @@ def list_sandbox_configs(
|
||||
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
|
||||
sandbox_type: Optional[SandboxType] = Query(None, description="Filter for this specific sandbox type"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, after=after, sandbox_type=sandbox_type)
|
||||
|
||||
|
||||
@router.post("/local/recreate-venv", response_model=PydanticSandboxConfig)
|
||||
def force_recreate_local_sandbox_venv(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""
|
||||
Forcefully recreate the virtual environment for the local sandbox.
|
||||
Deletes and recreates the venv, then reinstalls required dependencies.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# Retrieve the local sandbox config
|
||||
sbx_config = server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor)
|
||||
@@ -162,9 +162,9 @@ def create_sandbox_env_var(
|
||||
sandbox_config_id: str,
|
||||
env_var_create: SandboxEnvironmentVariableCreate,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.create_sandbox_env_var(env_var_create, sandbox_config_id, actor)
|
||||
|
||||
|
||||
@@ -173,9 +173,9 @@ def update_sandbox_env_var(
|
||||
env_var_id: str,
|
||||
env_var_update: SandboxEnvironmentVariableUpdate,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.update_sandbox_env_var(env_var_id, env_var_update, actor)
|
||||
|
||||
|
||||
@@ -183,9 +183,9 @@ def update_sandbox_env_var(
|
||||
def delete_sandbox_env_var(
|
||||
env_var_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.sandbox_config_manager.delete_sandbox_env_var(env_var_id, actor)
|
||||
|
||||
|
||||
@@ -195,7 +195,7 @@ def list_sandbox_env_vars(
|
||||
limit: int = Query(1000, description="Number of results to return"),
|
||||
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id, actor, limit=limit, after=after)
|
||||
|
||||
@@ -23,12 +23,12 @@ router = APIRouter(prefix="/sources", tags=["sources"])
|
||||
def retrieve_source(
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get all sources
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
if not source:
|
||||
@@ -40,12 +40,12 @@ def retrieve_source(
|
||||
def get_source_id_by_name(
|
||||
source_name: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a source by name
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
|
||||
if not source:
|
||||
@@ -56,12 +56,12 @@ def get_source_id_by_name(
|
||||
@router.get("/", response_model=List[Source], operation_id="list_sources")
|
||||
def list_sources(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all data sources created by a user.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.list_all_sources(actor=actor)
|
||||
|
||||
@@ -70,12 +70,12 @@ def list_sources(
|
||||
def create_source(
|
||||
source_create: SourceCreate,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a new data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
source = Source(**source_create.model_dump())
|
||||
|
||||
return server.source_manager.create_source(source=source, actor=actor)
|
||||
@@ -86,12 +86,12 @@ def modify_source(
|
||||
source_id: str,
|
||||
source: SourceUpdate,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the name or documentation of an existing data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor):
|
||||
raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.")
|
||||
return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)
|
||||
@@ -101,12 +101,12 @@ def modify_source(
|
||||
def delete_source(
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
server.delete_source(source_id=source_id, actor=actor)
|
||||
|
||||
@@ -117,12 +117,12 @@ def upload_file_to_source(
|
||||
source_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Upload a file to a data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
assert source is not None, f"Source with id={source_id} not found."
|
||||
@@ -151,12 +151,12 @@ def upload_file_to_source(
|
||||
def list_source_passages(
|
||||
source_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all passages associated with a data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id)
|
||||
return passages
|
||||
|
||||
@@ -167,12 +167,12 @@ def list_source_files(
|
||||
limit: int = Query(1000, description="Number of files to return"),
|
||||
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List paginated files associated with a data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=actor)
|
||||
|
||||
|
||||
@@ -183,12 +183,12 @@ def delete_file_from_source(
|
||||
source_id: str,
|
||||
file_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
|
||||
if deleted_file is None:
|
||||
|
||||
@@ -21,13 +21,13 @@ def list_steps(
|
||||
end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
|
||||
model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
List steps with optional pagination and date filters.
|
||||
Dates should be provided in ISO 8601 format (e.g. 2025-01-29T15:01:19-08:00)
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# Convert ISO strings to datetime objects if provided
|
||||
start_dt = datetime.fromisoformat(start_date) if start_date else None
|
||||
@@ -48,14 +48,15 @@ def list_steps(
|
||||
@router.get("/{step_id}", response_model=Step, operation_id="retrieve_step")
|
||||
def retrieve_step(
|
||||
step_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get a step by ID.
|
||||
"""
|
||||
try:
|
||||
return server.step_manager.get_step(step_id=step_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.step_manager.get_step(step_id=step_id, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
@@ -64,15 +65,15 @@ def retrieve_step(
|
||||
def update_step_transaction_id(
|
||||
step_id: str,
|
||||
transaction_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Update the transaction ID for a step.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.step_manager.update_step_transaction_id(actor, step_id=step_id, transaction_id=transaction_id)
|
||||
return server.step_manager.update_step_transaction_id(actor=actor, step_id=step_id, transaction_id=transaction_id)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
@@ -17,11 +17,11 @@ def list_tags(
|
||||
limit: Optional[int] = Query(50),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
query_text: Optional[str] = Query(None),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Get a list of all tags in the database
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
tags = server.agent_manager.list_tags(actor=actor, after=after, limit=limit, query_text=query_text)
|
||||
return tags
|
||||
|
||||
@@ -29,12 +29,12 @@ logger = get_logger(__name__)
|
||||
def delete_tool(
|
||||
tool_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a tool by name
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.tool_manager.delete_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@@ -42,12 +42,12 @@ def delete_tool(
|
||||
def retrieve_tool(
|
||||
tool_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a tool by ID
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
if tool is None:
|
||||
# return 404 error
|
||||
@@ -61,13 +61,13 @@ def list_tools(
|
||||
limit: Optional[int] = 50,
|
||||
name: Optional[str] = None,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a list of all tools available to agents belonging to the org of the user
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
if name is not None:
|
||||
tool = server.tool_manager.get_tool_by_name(tool_name=name, actor=actor)
|
||||
return [tool] if tool else []
|
||||
@@ -82,13 +82,13 @@ def list_tools(
|
||||
def create_tool(
|
||||
request: ToolCreate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a new tool
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
tool = Tool(**request.model_dump())
|
||||
return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor)
|
||||
except UniqueConstraintViolationError as e:
|
||||
@@ -114,13 +114,13 @@ def create_tool(
|
||||
def upsert_tool(
|
||||
request: ToolCreate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Create or update a tool
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor)
|
||||
return tool
|
||||
except UniqueConstraintViolationError as e:
|
||||
@@ -142,13 +142,13 @@ def modify_tool(
|
||||
tool_id: str,
|
||||
request: ToolUpdate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update an existing tool
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.tool_manager.update_tool_by_id(tool_id=tool_id, tool_update=request, actor=actor)
|
||||
except LettaToolCreateError as e:
|
||||
# HTTP 400 == Bad Request
|
||||
@@ -163,12 +163,12 @@ def modify_tool(
|
||||
@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools")
|
||||
def upsert_base_tools(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Upsert base tools
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.tool_manager.upsert_base_tools(actor=actor)
|
||||
|
||||
|
||||
@@ -176,12 +176,12 @@ def upsert_base_tools(
|
||||
def run_tool_from_source(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: ToolRunFromSource = Body(...),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Attempt to build a tool from source, then run it on the provided arguments
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.run_tool_from_source(
|
||||
@@ -227,12 +227,12 @@ def list_composio_apps(server: SyncServer = Depends(get_letta_server), user_id:
|
||||
def list_composio_actions_by_app(
|
||||
composio_app_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Get a list of all Composio actions for a specific app
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
composio_api_key = get_composio_api_key(actor=actor, logger=logger)
|
||||
if not composio_api_key:
|
||||
raise HTTPException(
|
||||
@@ -246,12 +246,12 @@ def list_composio_actions_by_app(
|
||||
def add_composio_tool(
|
||||
composio_action_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Add a new Composio tool by action name (Composio refers to each tool as an `Action`)
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
tool_create = ToolCreate.from_composio(action_name=composio_action_name)
|
||||
|
||||
@@ -7,7 +7,6 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
|
||||
|
||||
import pytz
|
||||
from fastapi import Header, HTTPException
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
@@ -145,7 +144,7 @@ def create_user_message(input_message: dict, agent_id: str, actor: User) -> Mess
|
||||
Converts a user input message into the internal structured format.
|
||||
"""
|
||||
# Generate timestamp in the correct format
|
||||
now = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Format message as structured JSON
|
||||
structured_message = {"type": "user_message", "message": input_message["content"], "time": now}
|
||||
@@ -197,7 +196,7 @@ def create_assistant_message_from_openai_response(
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
tool_calls=[tool_call],
|
||||
tool_call_id=None,
|
||||
tool_call_id=tool_call_id,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@@ -21,8 +21,10 @@ from letta.orm.sqlite_functions import adapt_array
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
@@ -613,6 +615,40 @@ class AgentManager:
|
||||
)
|
||||
return self.append_to_in_context_messages([system_message], agent_id=agent_state.id, actor=actor)
|
||||
|
||||
# TODO: I moved this from agent.py - replace all mentions of this with the agent_manager version
|
||||
@enforce_types
|
||||
def update_memory_if_changed(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Update internal memory object and system prompt if there have been modifications.
|
||||
|
||||
Args:
|
||||
new_memory (Memory): the new memory object to compare to the current memory object
|
||||
|
||||
Returns:
|
||||
modified (bool): whether the memory was updated
|
||||
"""
|
||||
agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
if agent_state.memory.compile() != new_memory.compile():
|
||||
# update the blocks (LRW) in the DB
|
||||
for label in agent_state.memory.list_block_labels():
|
||||
updated_value = new_memory.get_block(label).value
|
||||
if updated_value != agent_state.memory.get_block(label).value:
|
||||
# update the block if it's changed
|
||||
block_id = agent_state.memory.get_block(label).id
|
||||
block = self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=actor)
|
||||
|
||||
# refresh memory from DB (using block ids)
|
||||
agent_state.memory = Memory(
|
||||
blocks=[self.block_manager.get_block_by_id(block.id, actor=actor) for block in agent_state.memory.get_blocks()]
|
||||
)
|
||||
|
||||
# NOTE: don't do this since re-buildin the memory is handled at the start of the step
|
||||
# rebuild memory - this records the last edited timestamp of the memory
|
||||
# TODO: pass in update timestamp from block edit time
|
||||
agent_state = self.rebuild_system_prompt(agent_id=agent_id, actor=actor)
|
||||
|
||||
return agent_state
|
||||
|
||||
# ======================================================================================================================
|
||||
# Source Management
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -107,12 +107,14 @@ class BlockManager:
|
||||
@enforce_types
|
||||
def add_default_blocks(self, actor: PydanticUser):
|
||||
for persona_file in list_persona_files():
|
||||
text = open(persona_file, "r", encoding="utf-8").read()
|
||||
with open(persona_file, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
name = os.path.basename(persona_file).replace(".txt", "")
|
||||
self.create_or_update_block(Persona(template_name=name, value=text, is_template=True), actor=actor)
|
||||
|
||||
for human_file in list_human_files():
|
||||
text = open(human_file, "r", encoding="utf-8").read()
|
||||
with open(human_file, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
name = os.path.basename(human_file).replace(".txt", "")
|
||||
self.create_or_update_block(Human(template_name=name, value=text, is_template=True), actor=actor)
|
||||
|
||||
|
||||
@@ -111,6 +111,12 @@ class IdentityManager:
|
||||
existing_identity.name = identity.name
|
||||
if identity.identity_type is not None:
|
||||
existing_identity.identity_type = identity.identity_type
|
||||
if identity.properties is not None:
|
||||
if replace:
|
||||
existing_identity.properties = [prop.model_dump() for prop in identity.properties]
|
||||
else:
|
||||
new_properties = existing_identity.properties + identity.properties
|
||||
existing_identity.properties = [prop.model_dump() for prop in new_properties]
|
||||
|
||||
self._process_agent_relationship(
|
||||
session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import List, Optional
|
||||
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
@@ -11,6 +12,8 @@ from letta.schemas.message import MessageUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MessageManager:
|
||||
"""Manager class to handle business logic related to Messages."""
|
||||
@@ -37,7 +40,7 @@ class MessageManager:
|
||||
results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids))
|
||||
|
||||
if len(results) != len(message_ids):
|
||||
raise NoResultFound(
|
||||
logger.warning(
|
||||
f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}"
|
||||
)
|
||||
|
||||
|
||||
@@ -25,15 +25,15 @@ class ProviderManager:
|
||||
provider.resolve_identifier()
|
||||
|
||||
new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
|
||||
new_provider.create(session)
|
||||
new_provider.create(session, actor=actor)
|
||||
return new_provider.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_provider(self, provider_update: ProviderUpdate) -> PydanticProvider:
|
||||
def update_provider(self, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
|
||||
"""Update provider details."""
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the existing provider by ID
|
||||
existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id)
|
||||
existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id, actor=actor)
|
||||
|
||||
# Update only the fields that are provided in ProviderUpdate
|
||||
update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
@@ -41,31 +41,32 @@ class ProviderManager:
|
||||
setattr(existing_provider, key, value)
|
||||
|
||||
# Commit the updated provider
|
||||
existing_provider.update(session)
|
||||
existing_provider.update(session, actor=actor)
|
||||
return existing_provider.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_provider_by_id(self, provider_id: str):
|
||||
def delete_provider_by_id(self, provider_id: str, actor: PydanticUser):
|
||||
"""Delete a provider."""
|
||||
with self.session_maker() as session:
|
||||
# Clear api key field
|
||||
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id)
|
||||
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor)
|
||||
existing_provider.api_key = None
|
||||
existing_provider.update(session)
|
||||
existing_provider.update(session, actor=actor)
|
||||
|
||||
# Soft delete in provider table
|
||||
existing_provider.delete(session)
|
||||
existing_provider.delete(session, actor=actor)
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticProvider]:
|
||||
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]:
|
||||
"""List all providers with optional pagination."""
|
||||
with self.session_maker() as session:
|
||||
providers = ProviderModel.list(
|
||||
db_session=session,
|
||||
after=after,
|
||||
limit=limit,
|
||||
actor=actor,
|
||||
)
|
||||
return [provider.to_pydantic() for provider in providers]
|
||||
|
||||
|
||||
@@ -84,9 +84,9 @@ class StepManager:
|
||||
return new_step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_step(self, step_id: str) -> PydanticStep:
|
||||
def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep:
|
||||
with self.session_maker() as session:
|
||||
step = StepModel.read(db_session=session, identifier=step_id)
|
||||
step = StepModel.read(db_session=session, identifier=step_id, actor=actor)
|
||||
return step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
|
||||
@@ -14,7 +14,9 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.tool_manager import ToolManager
|
||||
|
||||
# --- Server Management --- #
|
||||
|
||||
@@ -69,9 +71,49 @@ def roll_dice_tool(client):
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def agent(client, roll_dice_tool):
|
||||
def weather_tool(client):
|
||||
def get_weather(location: str) -> str:
|
||||
"""
|
||||
Fetches the current weather for a given location.
|
||||
|
||||
Parameters:
|
||||
location (str): The location to get the weather for.
|
||||
|
||||
Returns:
|
||||
str: A formatted string describing the weather in the given location.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the request to fetch weather data fails.
|
||||
"""
|
||||
import requests
|
||||
|
||||
url = f"https://wttr.in/{location}?format=%C+%t"
|
||||
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
weather_data = response.text
|
||||
return f"The weather in {location} is {weather_data}."
|
||||
else:
|
||||
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
|
||||
|
||||
tool = client.create_or_update_tool(func=get_weather)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def composio_gmail_get_profile_tool(default_user):
|
||||
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
|
||||
tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def agent(client, roll_dice_tool, weather_tool, composio_gmail_get_profile_tool):
|
||||
"""Creates an agent and ensures cleanup after tests."""
|
||||
agent_state = client.create_agent(name=f"test_client_{uuid.uuid4()}", tool_ids=[roll_dice_tool.id])
|
||||
agent_state = client.create_agent(
|
||||
name=f"test_compl_{str(uuid.uuid4())[5:]}", tool_ids=[roll_dice_tool.id, weather_tool.id, composio_gmail_get_profile_tool.id]
|
||||
)
|
||||
yield agent_state
|
||||
client.delete_agent(agent_state.id)
|
||||
|
||||
@@ -111,6 +153,19 @@ def _assert_valid_chunk(chunk, idx, chunks):
|
||||
# --- Test Cases --- #
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", ["What's the weather in SF?"])
|
||||
@pytest.mark.parametrize("endpoint", ["fast/chat/completions"])
|
||||
def test_tool_usage_fast_chat_completions(mock_e2b_api_key_none, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming via SSE."""
|
||||
request = _get_chat_request(agent.id, message)
|
||||
|
||||
response = _sse_post(f"{client.base_url}/openai/{client.api_prefix}/{endpoint}", request.model_dump(exclude_none=True), client.headers)
|
||||
|
||||
for chunk in response:
|
||||
if isinstance(chunk, ChatCompletionChunk) and chunk.choices:
|
||||
print(chunk.choices[0].delta.content)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas."])
|
||||
@pytest.mark.parametrize("endpoint", ["chat/completions", "fast/chat/completions"])
|
||||
def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message, endpoint):
|
||||
|
||||
@@ -2150,6 +2150,30 @@ def test_delete_source(server: SyncServer, default_user):
|
||||
assert len(sources) == 0
|
||||
|
||||
|
||||
def test_delete_attached_source(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test deleting a source."""
|
||||
source_pydantic = PydanticSource(
|
||||
name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG
|
||||
)
|
||||
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
|
||||
|
||||
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=source.id, actor=default_user)
|
||||
|
||||
# Delete the source
|
||||
deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user)
|
||||
|
||||
# Assertions to verify deletion
|
||||
assert deleted_source.id == source.id
|
||||
|
||||
# Verify that the source no longer appears in list_sources
|
||||
sources = server.source_manager.list_sources(actor=default_user)
|
||||
assert len(sources) == 0
|
||||
|
||||
# Verify that agent is not deleted
|
||||
agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user)
|
||||
assert agent is not None
|
||||
|
||||
|
||||
def test_list_sources(server: SyncServer, default_user):
|
||||
"""Test listing sources with pagination."""
|
||||
# Create multiple sources
|
||||
|
||||
@@ -1194,7 +1194,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
step_ids = set([msg.step_id for msg in get_messages_response])
|
||||
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
|
||||
for step_id in step_ids:
|
||||
step = server.step_manager.get_step(step_id=step_id)
|
||||
step = server.step_manager.get_step(step_id=step_id, actor=actor)
|
||||
assert step, "Step was not logged correctly"
|
||||
assert step.provider_id == provider.id
|
||||
assert step.provider_name == agent.llm_config.model_endpoint_type
|
||||
@@ -1208,7 +1208,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
assert prompt_tokens == usage.prompt_tokens
|
||||
assert total_tokens == usage.total_tokens
|
||||
|
||||
server.provider_manager.delete_provider_by_id(provider.id)
|
||||
server.provider_manager.delete_provider_by_id(provider.id, actor=actor)
|
||||
|
||||
existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
|
||||
|
||||
@@ -1221,7 +1221,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
step_ids = set([msg.step_id for msg in get_messages_response])
|
||||
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
|
||||
for step_id in step_ids:
|
||||
step = server.step_manager.get_step(step_id=step_id)
|
||||
step = server.step_manager.get_step(step_id=step_id, actor=actor)
|
||||
assert step, "Step was not logged correctly"
|
||||
assert step.provider_id == None
|
||||
assert step.provider_name == agent.llm_config.model_endpoint_type
|
||||
|
||||
Reference in New Issue
Block a user