chore: migrate to ruff (#4305)
* base requirements
* autofix
* Configure ruff for Python linting and formatting
  - Set up minimal ruff configuration with basic checks (E, W, F, I)
  - Add temporary ignores for common issues during migration
  - Configure pre-commit hooks to use ruff with pass_filenames
  - This enables gradual migration from black to ruff
* Delete sdj
* autofixed only
* migrate lint action
* more autofixed
* more fixes
* change precommit
* try changing the hook
* try this stuff
This commit is contained in:
@@ -13,23 +13,13 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: trufflehog
|
- id: trufflehog
|
||||||
name: TruffleHog
|
name: TruffleHog
|
||||||
entry: bash -c 'trufflehog git file://. --since-commit HEAD --results=verified,unknown --fail'
|
entry: bash -c 'trufflehog git file://. --since-commit HEAD --results=verified,unknown --fail --no-update'
|
||||||
language: system
|
language: system
|
||||||
stages: ["pre-commit", "pre-push"]
|
stages: ["pre-commit", "pre-push"]
|
||||||
- id: autoflake
|
|
||||||
name: autoflake
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
entry: bash -c '[ -d "apps/core" ] && cd apps/core; uv run autoflake --remove-all-unused-imports --remove-unused-variables --in-place --recursive --ignore-init-module-imports .'
|
rev: v0.12.11
|
||||||
language: system
|
hooks:
|
||||||
types: [python]
|
- id: ruff-check
|
||||||
- id: isort
|
args: [ --fix ]
|
||||||
name: isort
|
- id: ruff-format
|
||||||
entry: bash -c '[ -d "apps/core" ] && cd apps/core; uv run isort --profile black .'
|
|
||||||
language: system
|
|
||||||
types: [python]
|
|
||||||
exclude: ^docs/
|
|
||||||
- id: black
|
|
||||||
name: black
|
|
||||||
entry: bash -c '[ -d "apps/core" ] && cd apps/core; uv run black --line-length 140 --target-version py310 --target-version py311 .'
|
|
||||||
language: system
|
|
||||||
types: [python]
|
|
||||||
exclude: ^docs/
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ def upgrade() -> None:
|
|||||||
# add default value of `False`
|
# add default value of `False`
|
||||||
op.add_column("block", sa.Column("read_only", sa.Boolean(), nullable=True))
|
op.add_column("block", sa.Column("read_only", sa.Boolean(), nullable=True))
|
||||||
op.execute(
|
op.execute(
|
||||||
f"""
|
"""
|
||||||
UPDATE block
|
UPDATE block
|
||||||
SET read_only = False
|
SET read_only = False
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ def upgrade() -> None:
|
|||||||
op.add_column("jobs", sa.Column("job_type", sa.String(), nullable=True))
|
op.add_column("jobs", sa.Column("job_type", sa.String(), nullable=True))
|
||||||
|
|
||||||
# Set existing rows to have the default value of JobType.JOB
|
# Set existing rows to have the default value of JobType.JOB
|
||||||
op.execute(f"UPDATE jobs SET job_type = 'job' WHERE job_type IS NULL")
|
op.execute("UPDATE jobs SET job_type = 'job' WHERE job_type IS NULL")
|
||||||
|
|
||||||
# Make the column non-nullable after setting default values
|
# Make the column non-nullable after setting default values
|
||||||
op.alter_column("jobs", "job_type", existing_type=sa.String(), nullable=False)
|
op.alter_column("jobs", "job_type", existing_type=sa.String(), nullable=False)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
# fill in column with `False`
|
# fill in column with `False`
|
||||||
op.execute(
|
op.execute(
|
||||||
f"""
|
"""
|
||||||
UPDATE organizations
|
UPDATE organizations
|
||||||
SET privileged_tools = False
|
SET privileged_tools = False
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ def upgrade() -> None:
|
|||||||
f"""
|
f"""
|
||||||
UPDATE tools
|
UPDATE tools
|
||||||
SET tool_type = '{letta_core_value}'
|
SET tool_type = '{letta_core_value}'
|
||||||
WHERE name IN ({','.join(f"'{name}'" for name in BASE_TOOLS)});
|
WHERE name IN ({",".join(f"'{name}'" for name in BASE_TOOLS)});
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ def upgrade() -> None:
|
|||||||
f"""
|
f"""
|
||||||
UPDATE tools
|
UPDATE tools
|
||||||
SET tool_type = '{letta_memory_core_value}'
|
SET tool_type = '{letta_memory_core_value}'
|
||||||
WHERE name IN ({','.join(f"'{name}'" for name in BASE_MEMORY_TOOLS)});
|
WHERE name IN ({",".join(f"'{name}'" for name in BASE_MEMORY_TOOLS)});
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ def upgrade() -> None:
|
|||||||
WITH numbered_rows AS (
|
WITH numbered_rows AS (
|
||||||
SELECT
|
SELECT
|
||||||
id,
|
id,
|
||||||
ROW_NUMBER() OVER (ORDER BY {', '.join(ORDERING_COLUMNS)} ASC) as rn
|
ROW_NUMBER() OVER (ORDER BY {", ".join(ORDERING_COLUMNS)} ASC) as rn
|
||||||
FROM {TABLE_NAME}
|
FROM {TABLE_NAME}
|
||||||
)
|
)
|
||||||
UPDATE {TABLE_NAME}
|
UPDATE {TABLE_NAME}
|
||||||
|
|||||||
@@ -49,9 +49,7 @@ from letta.schemas.enums import MessageRole, ProviderType, StepStatus, ToolType
|
|||||||
from letta.schemas.letta_message_content import ImageContent, TextContent
|
from letta.schemas.letta_message_content import ImageContent, TextContent
|
||||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Message as ChatCompletionMessage, UsageStatistics
|
||||||
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
|
|
||||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
|
||||||
from letta.schemas.response_format import ResponseFormatType
|
from letta.schemas.response_format import ResponseFormatType
|
||||||
from letta.schemas.tool import Tool
|
from letta.schemas.tool import Tool
|
||||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||||
@@ -871,7 +869,6 @@ class Agent(BaseAgent):
|
|||||||
) -> AgentStepResponse:
|
) -> AgentStepResponse:
|
||||||
"""Runs a single step in the agent loop (generates at most one LLM call)"""
|
"""Runs a single step in the agent loop (generates at most one LLM call)"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# Extract job_id from metadata if present
|
# Extract job_id from metadata if present
|
||||||
job_id = metadata.get("job_id") if metadata else None
|
job_id = metadata.get("job_id") if metadata else None
|
||||||
|
|
||||||
@@ -1084,9 +1081,9 @@ class Agent(BaseAgent):
|
|||||||
-> agent.step(messages=[Message(role='user', text=...)])
|
-> agent.step(messages=[Message(role='user', text=...)])
|
||||||
"""
|
"""
|
||||||
# Wrap with metadata, dumps to JSON
|
# Wrap with metadata, dumps to JSON
|
||||||
assert user_message_str and isinstance(
|
assert user_message_str and isinstance(user_message_str, str), (
|
||||||
user_message_str, str
|
f"user_message_str should be a non-empty string, got {type(user_message_str)}"
|
||||||
), f"user_message_str should be a non-empty string, got {type(user_message_str)}"
|
)
|
||||||
user_message_json_str = package_user_message(user_message_str, self.agent_state.timezone)
|
user_message_json_str = package_user_message(user_message_str, self.agent_state.timezone)
|
||||||
|
|
||||||
# Validate JSON via save/load
|
# Validate JSON via save/load
|
||||||
|
|||||||
@@ -269,16 +269,20 @@ class LettaAgent(BaseAgent):
|
|||||||
effective_step_id = step_id if logged_step else None
|
effective_step_id = step_id if logged_step else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
(
|
||||||
await self._build_and_request_from_llm(
|
request_data,
|
||||||
current_in_context_messages,
|
response_data,
|
||||||
new_in_context_messages,
|
current_in_context_messages,
|
||||||
agent_state,
|
new_in_context_messages,
|
||||||
llm_client,
|
valid_tool_names,
|
||||||
tool_rules_solver,
|
) = await self._build_and_request_from_llm(
|
||||||
agent_step_span,
|
current_in_context_messages,
|
||||||
step_metrics,
|
new_in_context_messages,
|
||||||
)
|
agent_state,
|
||||||
|
llm_client,
|
||||||
|
tool_rules_solver,
|
||||||
|
agent_step_span,
|
||||||
|
step_metrics,
|
||||||
)
|
)
|
||||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||||
|
|
||||||
@@ -574,16 +578,20 @@ class LettaAgent(BaseAgent):
|
|||||||
effective_step_id = step_id if logged_step else None
|
effective_step_id = step_id if logged_step else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
|
(
|
||||||
await self._build_and_request_from_llm(
|
request_data,
|
||||||
current_in_context_messages,
|
response_data,
|
||||||
new_in_context_messages,
|
current_in_context_messages,
|
||||||
agent_state,
|
new_in_context_messages,
|
||||||
llm_client,
|
valid_tool_names,
|
||||||
tool_rules_solver,
|
) = await self._build_and_request_from_llm(
|
||||||
agent_step_span,
|
current_in_context_messages,
|
||||||
step_metrics,
|
new_in_context_messages,
|
||||||
)
|
agent_state,
|
||||||
|
llm_client,
|
||||||
|
tool_rules_solver,
|
||||||
|
agent_step_span,
|
||||||
|
step_metrics,
|
||||||
)
|
)
|
||||||
in_context_messages = current_in_context_messages + new_in_context_messages
|
in_context_messages = current_in_context_messages + new_in_context_messages
|
||||||
|
|
||||||
@@ -1626,7 +1634,6 @@ class LettaAgent(BaseAgent):
|
|||||||
tool_rules_solver: ToolRulesSolver,
|
tool_rules_solver: ToolRulesSolver,
|
||||||
is_final_step: bool | None,
|
is_final_step: bool | None,
|
||||||
) -> tuple[bool, str | None, LettaStopReason | None]:
|
) -> tuple[bool, str | None, LettaStopReason | None]:
|
||||||
|
|
||||||
continue_stepping = request_heartbeat
|
continue_stepping = request_heartbeat
|
||||||
heartbeat_reason: str | None = None
|
heartbeat_reason: str | None = None
|
||||||
stop_reason: LettaStopReason | None = None
|
stop_reason: LettaStopReason | None = None
|
||||||
@@ -1658,9 +1665,7 @@ class LettaAgent(BaseAgent):
|
|||||||
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
|
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
|
||||||
if not continue_stepping and uncalled:
|
if not continue_stepping and uncalled:
|
||||||
continue_stepping = True
|
continue_stepping = True
|
||||||
heartbeat_reason = (
|
heartbeat_reason = f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [{', '.join(uncalled)}] to be called still."
|
||||||
f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [" f"{', '.join(uncalled)}] to be called still."
|
|
||||||
)
|
|
||||||
|
|
||||||
stop_reason = None # reset – we’re still going
|
stop_reason = None # reset – we’re still going
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStrea
|
|||||||
# TODO: Please note his is a very generous timeout for e2b reasons
|
# TODO: Please note his is a very generous timeout for e2b reasons
|
||||||
with httpx.Client(timeout=httpx.Timeout(5 * 60.0, read=5 * 60.0)) as client:
|
with httpx.Client(timeout=httpx.Timeout(5 * 60.0, read=5 * 60.0)) as client:
|
||||||
with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:
|
with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:
|
||||||
|
|
||||||
# Check for immediate HTTP errors before processing the SSE stream
|
# Check for immediate HTTP errors before processing the SSE stream
|
||||||
if not event_source.response.is_success:
|
if not event_source.response.is_success:
|
||||||
response_bytes = event_source.response.read()
|
response_bytes = event_source.response.read()
|
||||||
|
|||||||
@@ -593,7 +593,6 @@ def generate_tool_schema_for_mcp(
|
|||||||
append_heartbeat: bool = True,
|
append_heartbeat: bool = True,
|
||||||
strict: bool = False,
|
strict: bool = False,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
|
||||||
# MCP tool.inputSchema is a JSON schema
|
# MCP tool.inputSchema is a JSON schema
|
||||||
# https://github.com/modelcontextprotocol/python-sdk/blob/775f87981300660ee957b63c2a14b448ab9c3675/src/mcp/types.py#L678
|
# https://github.com/modelcontextprotocol/python-sdk/blob/775f87981300660ee957b63c2a14b448ab9c3675/src/mcp/types.py#L678
|
||||||
parameters_schema = mcp_tool.inputSchema
|
parameters_schema = mcp_tool.inputSchema
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
||||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
|
||||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
|
||||||
from sqlalchemy import Dialect
|
from sqlalchemy import Dialect
|
||||||
|
|
||||||
from letta.functions.mcp_client.types import StdioServerConfig
|
from letta.functions.mcp_client.types import StdioServerConfig
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone as dt_timezone
|
||||||
from datetime import timezone as dt_timezone
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
|
|||||||
@@ -317,7 +317,6 @@ async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int =
|
|||||||
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
async with PineconeAsyncio(api_key=settings.pinecone_api_key) as pc:
|
||||||
description = await pc.describe_index(name=settings.pinecone_source_index)
|
description = await pc.describe_index(name=settings.pinecone_source_index)
|
||||||
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
async with pc.IndexAsyncio(host=description.index.host) as dense_index:
|
||||||
|
|
||||||
kwargs = {"namespace": namespace, "prefix": file_id}
|
kwargs = {"namespace": namespace, "prefix": file_id}
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
kwargs["limit"] = limit
|
kwargs["limit"] = limit
|
||||||
|
|||||||
@@ -198,23 +198,23 @@ class CLIInterface(AgentInterface):
|
|||||||
try:
|
try:
|
||||||
msg_dict = eval(function_args)
|
msg_dict = eval(function_args)
|
||||||
if function_name == "archival_memory_search":
|
if function_name == "archival_memory_search":
|
||||||
output = f'\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}'
|
output = f"\tquery: {msg_dict['query']}, page: {msg_dict['page']}"
|
||||||
if STRIP_UI:
|
if STRIP_UI:
|
||||||
print(output)
|
print(output)
|
||||||
else:
|
else:
|
||||||
print(f"{Fore.RED}{output}{Style.RESET_ALL}")
|
print(f"{Fore.RED}{output}{Style.RESET_ALL}")
|
||||||
elif function_name == "archival_memory_insert":
|
elif function_name == "archival_memory_insert":
|
||||||
output = f'\t→ {msg_dict["content"]}'
|
output = f"\t→ {msg_dict['content']}"
|
||||||
if STRIP_UI:
|
if STRIP_UI:
|
||||||
print(output)
|
print(output)
|
||||||
else:
|
else:
|
||||||
print(f"{Style.BRIGHT}{Fore.RED}{output}{Style.RESET_ALL}")
|
print(f"{Style.BRIGHT}{Fore.RED}{output}{Style.RESET_ALL}")
|
||||||
else:
|
else:
|
||||||
if STRIP_UI:
|
if STRIP_UI:
|
||||||
print(f'\t {msg_dict["old_content"]}\n\t→ {msg_dict["new_content"]}')
|
print(f"\t {msg_dict['old_content']}\n\t→ {msg_dict['new_content']}")
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
f'{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}→ {msg_dict["new_content"]}{Style.RESET_ALL}'
|
f"{Style.BRIGHT}\t{Fore.RED} {msg_dict['old_content']}\n\t{Fore.GREEN}→ {msg_dict['new_content']}{Style.RESET_ALL}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
printd(str(e))
|
printd(str(e))
|
||||||
@@ -223,7 +223,7 @@ class CLIInterface(AgentInterface):
|
|||||||
print_function_message("🧠", f"searching memory with {function_name}")
|
print_function_message("🧠", f"searching memory with {function_name}")
|
||||||
try:
|
try:
|
||||||
msg_dict = eval(function_args)
|
msg_dict = eval(function_args)
|
||||||
output = f'\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}'
|
output = f"\tquery: {msg_dict['query']}, page: {msg_dict['page']}"
|
||||||
if STRIP_UI:
|
if STRIP_UI:
|
||||||
print(output)
|
print(output)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -232,16 +232,13 @@ class OpenAIStreamingInterface:
|
|||||||
|
|
||||||
# If we have main_json, we should output a ToolCallMessage
|
# If we have main_json, we should output a ToolCallMessage
|
||||||
elif updates_main_json:
|
elif updates_main_json:
|
||||||
|
|
||||||
# If there's something in the function_name buffer, we should release it first
|
# If there's something in the function_name buffer, we should release it first
|
||||||
# NOTE: we could output it as part of a chunk that has both name and args,
|
# NOTE: we could output it as part of a chunk that has both name and args,
|
||||||
# however the frontend may expect name first, then args, so to be
|
# however the frontend may expect name first, then args, so to be
|
||||||
# safe we'll output name first in a separate chunk
|
# safe we'll output name first in a separate chunk
|
||||||
if self.function_name_buffer:
|
if self.function_name_buffer:
|
||||||
|
|
||||||
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||||
if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
|
if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
|
||||||
|
|
||||||
# Store the ID of the tool call so allow skipping the corresponding response
|
# Store the ID of the tool call so allow skipping the corresponding response
|
||||||
if self.function_id_buffer:
|
if self.function_id_buffer:
|
||||||
self.prev_assistant_message_id = self.function_id_buffer
|
self.prev_assistant_message_id = self.function_id_buffer
|
||||||
@@ -373,7 +370,6 @@ class OpenAIStreamingInterface:
|
|||||||
# clear buffers
|
# clear buffers
|
||||||
self.function_id_buffer = None
|
self.function_id_buffer = None
|
||||||
else:
|
else:
|
||||||
|
|
||||||
# There may be a buffer from a previous chunk, for example
|
# There may be a buffer from a previous chunk, for example
|
||||||
# if the previous chunk had arguments but we needed to flush name
|
# if the previous chunk had arguments but we needed to flush name
|
||||||
if self.function_args_buffer:
|
if self.function_args_buffer:
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ from typing import Dict, List, Optional, Union
|
|||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
from anthropic import AsyncStream
|
from anthropic import AsyncStream
|
||||||
from anthropic.types.beta import BetaMessage as AnthropicMessage
|
from anthropic.types.beta import BetaMessage as AnthropicMessage, BetaRawMessageStreamEvent
|
||||||
from anthropic.types.beta import BetaRawMessageStreamEvent
|
|
||||||
from anthropic.types.beta.message_create_params import MessageCreateParamsNonStreaming
|
from anthropic.types.beta.message_create_params import MessageCreateParamsNonStreaming
|
||||||
from anthropic.types.beta.messages import BetaMessageBatch
|
from anthropic.types.beta.messages import BetaMessageBatch
|
||||||
from anthropic.types.beta.messages.batch_create_params import Request
|
from anthropic.types.beta.messages.batch_create_params import Request
|
||||||
@@ -34,9 +33,14 @@ from letta.otel.tracing import trace_method
|
|||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.message import Message as PydanticMessage
|
from letta.schemas.message import Message as PydanticMessage
|
||||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
||||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall
|
from letta.schemas.openai.chat_completion_response import (
|
||||||
from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
|
ChatCompletionResponse,
|
||||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
Choice,
|
||||||
|
FunctionCall,
|
||||||
|
Message as ChoiceMessage,
|
||||||
|
ToolCall,
|
||||||
|
UsageStatistics,
|
||||||
|
)
|
||||||
from letta.settings import model_settings
|
from letta.settings import model_settings
|
||||||
|
|
||||||
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
|
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
|
||||||
@@ -45,7 +49,6 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class AnthropicClient(LLMClientBase):
|
class AnthropicClient(LLMClientBase):
|
||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
@deprecated("Synchronous version of this is no longer valid. Will result in model_dump of coroutine")
|
@deprecated("Synchronous version of this is no longer valid. Will result in model_dump of coroutine")
|
||||||
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from letta.settings import model_settings
|
|||||||
|
|
||||||
|
|
||||||
class AzureClient(OpenAIClient):
|
class AzureClient(OpenAIClient):
|
||||||
|
|
||||||
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
def get_byok_overrides(self, llm_config: LLMConfig) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||||
if llm_config.provider_category == ProviderCategory.byok:
|
if llm_config.provider_category == ProviderCategory.byok:
|
||||||
from letta.services.provider_manager import ProviderManager
|
from letta.services.provider_manager import ProviderManager
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class BedrockClient(AnthropicClient):
|
class BedrockClient(AnthropicClient):
|
||||||
|
|
||||||
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> tuple[str, str, str]:
|
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> tuple[str, str, str]:
|
||||||
override_access_key_id, override_secret_access_key, override_default_region = None, None, None
|
override_access_key_id, override_secret_access_key, override_default_region = None, None, None
|
||||||
if llm_config.provider_category == ProviderCategory.byok:
|
if llm_config.provider_category == ProviderCategory.byok:
|
||||||
|
|||||||
@@ -11,11 +11,18 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|||||||
from letta.llm_api.openai_client import OpenAIClient
|
from letta.llm_api.openai_client import OpenAIClient
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.message import Message as PydanticMessage
|
from letta.schemas.message import Message as PydanticMessage, Message as _Message
|
||||||
from letta.schemas.message import Message as _Message
|
from letta.schemas.openai.chat_completion_request import (
|
||||||
from letta.schemas.openai.chat_completion_request import AssistantMessage, ChatCompletionRequest, ChatMessage
|
AssistantMessage,
|
||||||
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
|
ChatCompletionRequest,
|
||||||
from letta.schemas.openai.chat_completion_request import Tool, ToolFunctionChoice, ToolMessage, UserMessage, cast_message_to_subtype
|
ChatMessage,
|
||||||
|
FunctionCall as ToolFunctionChoiceFunctionCall,
|
||||||
|
Tool,
|
||||||
|
ToolFunctionChoice,
|
||||||
|
ToolMessage,
|
||||||
|
UserMessage,
|
||||||
|
cast_message_to_subtype,
|
||||||
|
)
|
||||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||||
from letta.schemas.openai.openai import Function, ToolCall
|
from letta.schemas.openai.openai import Function, ToolCall
|
||||||
from letta.settings import model_settings
|
from letta.settings import model_settings
|
||||||
@@ -313,7 +320,6 @@ def convert_deepseek_response_to_chatcompletion(
|
|||||||
|
|
||||||
|
|
||||||
class DeepseekClient(OpenAIClient):
|
class DeepseekClient(OpenAIClient):
|
||||||
|
|
||||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class GoogleVertexClient(LLMClientBase):
|
class GoogleVertexClient(LLMClientBase):
|
||||||
|
|
||||||
def _get_client(self):
|
def _get_client(self):
|
||||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||||
return genai.Client(
|
return genai.Client(
|
||||||
@@ -344,9 +343,9 @@ class GoogleVertexClient(LLMClientBase):
|
|||||||
if llm_config.put_inner_thoughts_in_kwargs:
|
if llm_config.put_inner_thoughts_in_kwargs:
|
||||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX
|
||||||
|
|
||||||
assert (
|
assert INNER_THOUGHTS_KWARG_VERTEX in function_args, (
|
||||||
INNER_THOUGHTS_KWARG_VERTEX in function_args
|
f"Couldn't find inner thoughts in function args:\n{function_call}"
|
||||||
), f"Couldn't find inner thoughts in function args:\n{function_call}"
|
)
|
||||||
inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
|
inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
|
||||||
assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
|
assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
|
||||||
else:
|
else:
|
||||||
@@ -380,9 +379,9 @@ class GoogleVertexClient(LLMClientBase):
|
|||||||
if llm_config.put_inner_thoughts_in_kwargs:
|
if llm_config.put_inner_thoughts_in_kwargs:
|
||||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX
|
||||||
|
|
||||||
assert (
|
assert INNER_THOUGHTS_KWARG_VERTEX in function_args, (
|
||||||
INNER_THOUGHTS_KWARG_VERTEX in function_args
|
f"Couldn't find inner thoughts in function args:\n{function_call}"
|
||||||
), f"Couldn't find inner thoughts in function args:\n{function_call}"
|
)
|
||||||
inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
|
inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
|
||||||
assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
|
assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
|
||||||
else:
|
else:
|
||||||
@@ -406,7 +405,7 @@ class GoogleVertexClient(LLMClientBase):
|
|||||||
|
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
if candidate.finish_reason == "MAX_TOKENS":
|
if candidate.finish_reason == "MAX_TOKENS":
|
||||||
raise ValueError(f"Could not parse response data from LLM: exceeded max token limit")
|
raise ValueError("Could not parse response data from LLM: exceeded max token limit")
|
||||||
# Inner thoughts are the content by default
|
# Inner thoughts are the content by default
|
||||||
inner_thoughts = response_message.text
|
inner_thoughts = response_message.text
|
||||||
|
|
||||||
@@ -463,7 +462,7 @@ class GoogleVertexClient(LLMClientBase):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Count it ourselves
|
# Count it ourselves
|
||||||
assert input_messages is not None, f"Didn't get UsageMetadata from the API response, so input_messages is required"
|
assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required"
|
||||||
prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation
|
prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation
|
||||||
completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate
|
completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate
|
||||||
total_tokens = prompt_tokens + completion_tokens
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from letta.settings import model_settings
|
|||||||
|
|
||||||
|
|
||||||
class GroqClient(OpenAIClient):
|
class GroqClient(OpenAIClient):
|
||||||
|
|
||||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ async def mistral_get_model_list_async(url: str, api_key: str) -> dict:
|
|||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
logger.debug(f"Sending request to %s", url)
|
logger.debug("Sending request to %s", url)
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
# TODO add query param "tool" to be true
|
# TODO add query param "tool" to be true
|
||||||
|
|||||||
@@ -21,11 +21,15 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes
|
|||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.otel.tracing import log_event
|
from letta.otel.tracing import log_event
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.message import Message as _Message
|
from letta.schemas.message import Message as _Message, MessageRole as _MessageRole
|
||||||
from letta.schemas.message import MessageRole as _MessageRole
|
from letta.schemas.openai.chat_completion_request import (
|
||||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
ChatCompletionRequest,
|
||||||
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
|
FunctionCall as ToolFunctionChoiceFunctionCall,
|
||||||
from letta.schemas.openai.chat_completion_request import FunctionSchema, Tool, ToolFunctionChoice, cast_message_to_subtype
|
FunctionSchema,
|
||||||
|
Tool,
|
||||||
|
ToolFunctionChoice,
|
||||||
|
cast_message_to_subtype,
|
||||||
|
)
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import (
|
||||||
ChatCompletionChunkResponse,
|
ChatCompletionChunkResponse,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
|
|||||||
@@ -29,11 +29,14 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
|||||||
from letta.schemas.letta_message_content import MessageContentType
|
from letta.schemas.letta_message_content import MessageContentType
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.message import Message as PydanticMessage
|
from letta.schemas.message import Message as PydanticMessage
|
||||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
from letta.schemas.openai.chat_completion_request import (
|
||||||
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
|
ChatCompletionRequest,
|
||||||
from letta.schemas.openai.chat_completion_request import FunctionSchema
|
FunctionCall as ToolFunctionChoiceFunctionCall,
|
||||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
FunctionSchema,
|
||||||
from letta.schemas.openai.chat_completion_request import ToolFunctionChoice, cast_message_to_subtype
|
Tool as OpenAITool,
|
||||||
|
ToolFunctionChoice,
|
||||||
|
cast_message_to_subtype,
|
||||||
|
)
|
||||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||||
from letta.settings import model_settings
|
from letta.settings import model_settings
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from letta.settings import model_settings
|
|||||||
|
|
||||||
|
|
||||||
class TogetherClient(OpenAIClient):
|
class TogetherClient(OpenAIClient):
|
||||||
|
|
||||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from letta.settings import model_settings
|
|||||||
|
|
||||||
|
|
||||||
class XAIClient(OpenAIClient):
|
class XAIClient(OpenAIClient):
|
||||||
|
|
||||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ def get_chat_completion(
|
|||||||
raise LocalLLMError(f"usage dict in response was missing fields ({usage})")
|
raise LocalLLMError(f"usage dict in response was missing fields ({usage})")
|
||||||
|
|
||||||
if usage["prompt_tokens"] is None:
|
if usage["prompt_tokens"] is None:
|
||||||
printd(f"usage dict was missing prompt_tokens, computing on-the-fly...")
|
printd("usage dict was missing prompt_tokens, computing on-the-fly...")
|
||||||
usage["prompt_tokens"] = count_tokens(prompt)
|
usage["prompt_tokens"] = count_tokens(prompt)
|
||||||
|
|
||||||
# NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing
|
# NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing
|
||||||
@@ -220,7 +220,7 @@ def get_chat_completion(
|
|||||||
|
|
||||||
# NOTE: this is the token count that matters most
|
# NOTE: this is the token count that matters most
|
||||||
if usage["total_tokens"] is None:
|
if usage["total_tokens"] is None:
|
||||||
printd(f"usage dict was missing total_tokens, computing on-the-fly...")
|
printd("usage dict was missing total_tokens, computing on-the-fly...")
|
||||||
usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
|
usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]
|
||||||
|
|
||||||
# unpack with response.choices[0].message.content
|
# unpack with response.choices[0].message.content
|
||||||
@@ -261,9 +261,9 @@ def generate_grammar_and_documentation(
|
|||||||
):
|
):
|
||||||
from letta.utils import printd
|
from letta.utils import printd
|
||||||
|
|
||||||
assert not (
|
assert not (add_inner_thoughts_top_level and add_inner_thoughts_param_level), (
|
||||||
add_inner_thoughts_top_level and add_inner_thoughts_param_level
|
"Can only place inner thoughts in one location in the grammar generator"
|
||||||
), "Can only place inner thoughts in one location in the grammar generator"
|
)
|
||||||
|
|
||||||
grammar_function_models = []
|
grammar_function_models = []
|
||||||
# create_dynamic_model_from_function will add inner thoughts to the function parameters if add_inner_thoughts is True.
|
# create_dynamic_model_from_function will add inner thoughts to the function parameters if add_inner_thoughts is True.
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def get_completions_settings(defaults="simple") -> dict:
|
|||||||
with open(settings_file, "r", encoding="utf-8") as file:
|
with open(settings_file, "r", encoding="utf-8") as file:
|
||||||
user_settings = json.load(file)
|
user_settings = json.load(file)
|
||||||
if len(user_settings) > 0:
|
if len(user_settings) > 0:
|
||||||
printd(f"Updating base settings with the following user settings:\n{json_dumps(user_settings,indent=2)}")
|
printd(f"Updating base settings with the following user settings:\n{json_dumps(user_settings, indent=2)}")
|
||||||
settings.update(user_settings)
|
settings.update(user_settings)
|
||||||
else:
|
else:
|
||||||
printd(f"'{settings_file}' was empty, ignoring...")
|
printd(f"'{settings_file}' was empty, ignoring...")
|
||||||
|
|||||||
@@ -13,8 +13,7 @@ from letta.orm.identity import Identity
|
|||||||
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin
|
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
from letta.schemas.agent import AgentState as PydanticAgentState, AgentType, get_prompt_template_for_agent_type
|
||||||
from letta.schemas.agent import AgentType, get_prompt_template_for_agent_type
|
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.memory import Memory
|
from letta.schemas.memory import Memory
|
||||||
|
|||||||
@@ -8,8 +8,7 @@ from letta.orm.block_history import BlockHistory
|
|||||||
from letta.orm.blocks_agents import BlocksAgents
|
from letta.orm.blocks_agents import BlocksAgents
|
||||||
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin
|
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
from letta.schemas.block import Block as PydanticBlock
|
from letta.schemas.block import Block as PydanticBlock, Human, Persona
|
||||||
from letta.schemas.block import Human, Persona
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.orm import Organization
|
from letta.orm import Organization
|
||||||
|
|||||||
@@ -38,7 +38,9 @@ class BlockHistory(OrganizationMixin, SqlalchemyBase):
|
|||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
block_id: Mapped[str] = mapped_column(
|
block_id: Mapped[str] = mapped_column(
|
||||||
String, ForeignKey("block.id", ondelete="CASCADE"), nullable=False # History deleted if Block is deleted
|
String,
|
||||||
|
ForeignKey("block.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, # History deleted if Block is deleted
|
||||||
)
|
)
|
||||||
|
|
||||||
sequence_number: Mapped[int] = mapped_column(
|
sequence_number: Mapped[int] = mapped_column(
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from letta.schemas.group import Group as PydanticGroup
|
|||||||
|
|
||||||
|
|
||||||
class Group(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
|
class Group(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
|
||||||
|
|
||||||
__tablename__ = "groups"
|
__tablename__ = "groups"
|
||||||
__pydantic_model__ = PydanticGroup
|
__pydantic_model__ = PydanticGroup
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|||||||
|
|
||||||
from letta.orm.mixins import OrganizationMixin, ProjectMixin
|
from letta.orm.mixins import OrganizationMixin, ProjectMixin
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
from letta.schemas.identity import Identity as PydanticIdentity
|
from letta.schemas.identity import Identity as PydanticIdentity, IdentityProperty
|
||||||
from letta.schemas.identity import IdentityProperty
|
|
||||||
|
|
||||||
|
|
||||||
class Identity(SqlalchemyBase, OrganizationMixin, ProjectMixin):
|
class Identity(SqlalchemyBase, OrganizationMixin, ProjectMixin):
|
||||||
|
|||||||
@@ -7,8 +7,7 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|||||||
from letta.orm.mixins import UserMixin
|
from letta.orm.mixins import UserMixin
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
from letta.schemas.enums import JobStatus, JobType
|
from letta.schemas.enums import JobStatus, JobType
|
||||||
from letta.schemas.job import Job as PydanticJob
|
from letta.schemas.job import Job as PydanticJob, LettaRequestConfig
|
||||||
from letta.schemas.job import LettaRequestConfig
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.orm.job_messages import JobMessage
|
from letta.orm.job_messages import JobMessage
|
||||||
|
|||||||
@@ -9,8 +9,7 @@ from letta.orm.custom_columns import AgentStepStateColumn, BatchRequestResultCol
|
|||||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
from letta.schemas.enums import AgentStepStatus, JobStatus
|
from letta.schemas.enums import AgentStepStatus, JobStatus
|
||||||
from letta.schemas.llm_batch_job import AgentStepState
|
from letta.schemas.llm_batch_job import AgentStepState, LLMBatchItem as PydanticLLMBatchItem
|
||||||
from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
|
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,10 +7,8 @@ from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
|
|||||||
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
|
||||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
from letta.schemas.letta_message_content import MessageContent
|
from letta.schemas.letta_message_content import MessageContent, TextContent as PydanticTextContent
|
||||||
from letta.schemas.letta_message_content import TextContent as PydanticTextContent
|
from letta.schemas.message import Message as PydanticMessage, ToolReturn
|
||||||
from letta.schemas.message import Message as PydanticMessage
|
|
||||||
from letta.schemas.message import ToolReturn
|
|
||||||
from letta.settings import DatabaseChoice, settings
|
from letta.settings import DatabaseChoice, settings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
from sqlalchemy import JSON
|
from sqlalchemy import JSON, Enum as SqlEnum, Index, String, UniqueConstraint
|
||||||
from sqlalchemy import Enum as SqlEnum
|
|
||||||
from sqlalchemy import Index, String, UniqueConstraint
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from letta.orm.mixins import AgentMixin, OrganizationMixin, SandboxConfigMixin
|
from letta.orm.mixins import AgentMixin, OrganizationMixin, SandboxConfigMixin
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ def get_plugin(plugin_type: str):
|
|||||||
return plugin
|
return plugin
|
||||||
elif type(plugin).__name__ == "class":
|
elif type(plugin).__name__ == "class":
|
||||||
if plugin_register["protocol"] and not isinstance(plugin, type(plugin_register["protocol"])):
|
if plugin_register["protocol"] and not isinstance(plugin, type(plugin_register["protocol"])):
|
||||||
raise TypeError(f'{plugin} does not implement {type(plugin_register["protocol"]).__name__}')
|
raise TypeError(f"{plugin} does not implement {type(plugin_register['protocol']).__name__}")
|
||||||
return plugin()
|
return plugin()
|
||||||
raise TypeError("Unknown plugin type")
|
raise TypeError("Unknown plugin type")
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from letta.schemas.memory import Memory
|
|||||||
|
|
||||||
|
|
||||||
class PromptGenerator:
|
class PromptGenerator:
|
||||||
|
|
||||||
# TODO: This code is kind of wonky and deserves a rewrite
|
# TODO: This code is kind of wonky and deserves a rewrite
|
||||||
@trace_method
|
@trace_method
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ class EmbeddingConfig(BaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_config(cls, model_name: Optional[str] = None, provider: Optional[str] = None):
|
def default_config(cls, model_name: Optional[str] = None, provider: Optional[str] = None):
|
||||||
|
|
||||||
if model_name == "text-embedding-ada-002" and provider == "openai":
|
if model_name == "text-embedding-ada-002" and provider == "openai":
|
||||||
return cls(
|
return cls(
|
||||||
embedding_model="text-embedding-ada-002",
|
embedding_model="text-embedding-ada-002",
|
||||||
|
|||||||
@@ -9,8 +9,7 @@ from collections import OrderedDict
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
|
||||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
|
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
|
||||||
@@ -880,7 +879,6 @@ class Message(BaseMessage):
|
|||||||
# Tool calling
|
# Tool calling
|
||||||
if self.tool_calls is not None:
|
if self.tool_calls is not None:
|
||||||
for tool_call in self.tool_calls:
|
for tool_call in self.tool_calls:
|
||||||
|
|
||||||
if put_inner_thoughts_in_kwargs:
|
if put_inner_thoughts_in_kwargs:
|
||||||
tool_call_input = add_inner_thoughts_to_tool_call(
|
tool_call_input = add_inner_thoughts_to_tool_call(
|
||||||
tool_call,
|
tool_call,
|
||||||
@@ -1021,7 +1019,7 @@ class Message(BaseMessage):
|
|||||||
assert self.tool_call_id is not None, vars(self)
|
assert self.tool_call_id is not None, vars(self)
|
||||||
|
|
||||||
if self.name is None:
|
if self.name is None:
|
||||||
warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.")
|
warnings.warn("Couldn't find function name on tool call, defaulting to tool ID instead.")
|
||||||
function_name = self.tool_call_id
|
function_name = self.tool_call_id
|
||||||
else:
|
else:
|
||||||
function_name = self.name
|
function_name = self.name
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class BedrockProvider(Provider):
|
|||||||
response = await bedrock.list_inference_profiles()
|
response = await bedrock.list_inference_profiles()
|
||||||
return response["inferenceProfileSummaries"]
|
return response["inferenceProfileSummaries"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting model list for bedrock: %s", e)
|
logger.error("Error getting model list for bedrock: %s", e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def check_api_key(self):
|
async def check_api_key(self):
|
||||||
|
|||||||
@@ -203,7 +203,7 @@ class OpenAIProvider(Provider):
|
|||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Skipping embedding models for %s by default, as we don't assume embeddings are supported."
|
"Skipping embedding models for %s by default, as we don't assume embeddings are supported."
|
||||||
"Please open an issue on GitHub if support is required.",
|
"Please open an issue on GitHub if support is required.",
|
||||||
self.base_url,
|
self.base_url,
|
||||||
)
|
)
|
||||||
@@ -227,7 +227,7 @@ class OpenAIProvider(Provider):
|
|||||||
return LLM_MAX_TOKENS[model_name]
|
return LLM_MAX_TOKENS[model_name]
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Model %s on %s for provider %s not found in LLM_MAX_TOKENS. Using default of {{LLM_MAX_TOKENS['DEFAULT']}}",
|
"Model %s on %s for provider %s not found in LLM_MAX_TOKENS. Using default of {LLM_MAX_TOKENS['DEFAULT']}",
|
||||||
model_name,
|
model_name,
|
||||||
self.base_url,
|
self.base_url,
|
||||||
self.__class__.__name__,
|
self.__class__.__name__,
|
||||||
|
|||||||
@@ -218,9 +218,9 @@ class ToolCreate(LettaBase):
|
|||||||
composio_action_schemas = composio_toolset.get_action_schemas(actions=[action_name], check_connected_accounts=False)
|
composio_action_schemas = composio_toolset.get_action_schemas(actions=[action_name], check_connected_accounts=False)
|
||||||
|
|
||||||
assert len(composio_action_schemas) > 0, "User supplied parameters do not match any Composio tools"
|
assert len(composio_action_schemas) > 0, "User supplied parameters do not match any Composio tools"
|
||||||
assert (
|
assert len(composio_action_schemas) == 1, (
|
||||||
len(composio_action_schemas) == 1
|
f"User supplied parameters match too many Composio tools; {len(composio_action_schemas)} > 1"
|
||||||
), f"User supplied parameters match too many Composio tools; {len(composio_action_schemas)} > 1"
|
)
|
||||||
|
|
||||||
composio_action_schema = composio_action_schemas[0]
|
composio_action_schema = composio_action_schemas[0]
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from letta.schemas.agent import AgentState
|
|||||||
|
|
||||||
|
|
||||||
class ToolExecutionResult(BaseModel):
|
class ToolExecutionResult(BaseModel):
|
||||||
|
|
||||||
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
|
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
|
||||||
func_return: Optional[Any] = Field(None, description="The function return object")
|
func_return: Optional[Any] = Field(None, description="The function return object")
|
||||||
agent_state: Optional[AgentState] = Field(None, description="The agent state")
|
agent_state: Optional[AgentState] = Field(None, description="The agent state")
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ from sqlalchemy import func
|
|||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
import letta
|
import letta
|
||||||
from letta.orm import Agent
|
from letta.orm import Agent, Message as MessageModel
|
||||||
from letta.orm import Message as MessageModel
|
|
||||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.serialize_schemas.marshmallow_agent_environment_variable import SerializedAgentEnvironmentVariableSchema
|
from letta.serialize_schemas.marshmallow_agent_environment_variable import SerializedAgentEnvironmentVariableSchema
|
||||||
|
|||||||
@@ -261,7 +261,7 @@ def create_application() -> "FastAPI":
|
|||||||
|
|
||||||
@app.exception_handler(BedrockPermissionError)
|
@app.exception_handler(BedrockPermissionError)
|
||||||
async def bedrock_permission_error_handler(request, exc: BedrockPermissionError):
|
async def bedrock_permission_error_handler(request, exc: BedrockPermissionError):
|
||||||
logger.error(f"Bedrock permission denied.")
|
logger.error("Bedrock permission denied.")
|
||||||
if SENTRY_ENABLED:
|
if SENTRY_ENABLED:
|
||||||
sentry_sdk.capture_exception(exc)
|
sentry_sdk.capture_exception(exc)
|
||||||
|
|
||||||
@@ -433,10 +433,10 @@ def start_server(
|
|||||||
if IS_WINDOWS:
|
if IS_WINDOWS:
|
||||||
# Windows doesn't those the fancy unicode characters
|
# Windows doesn't those the fancy unicode characters
|
||||||
print(f"Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
|
print(f"Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
|
||||||
print(f"View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
print("View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
||||||
else:
|
else:
|
||||||
print(f"▶ Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
|
print(f"▶ Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}")
|
||||||
print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
print("▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard\n")
|
||||||
|
|
||||||
if importlib.util.find_spec("granian") is not None and settings.use_granian:
|
if importlib.util.find_spec("granian") is not None and settings.use_granian:
|
||||||
# Experimental Granian engine
|
# Experimental Granian engine
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ class AuthRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
def setup_auth_router(server: SyncServer, interface: QueuingInterface, password: str) -> APIRouter:
|
def setup_auth_router(server: SyncServer, interface: QueuingInterface, password: str) -> APIRouter:
|
||||||
|
|
||||||
@router.post("/auth", tags=["auth"], response_model=AuthResponse)
|
@router.post("/auth", tags=["auth"], response_model=AuthResponse)
|
||||||
def authenticate_user(request: AuthRequest) -> AuthResponse:
|
def authenticate_user(request: AuthRequest) -> AuthResponse:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -377,9 +377,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
):
|
):
|
||||||
"""Add an item to the deque"""
|
"""Add an item to the deque"""
|
||||||
assert self._active, "Generator is inactive"
|
assert self._active, "Generator is inactive"
|
||||||
assert (
|
assert isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage) or isinstance(item, MessageStreamStatus), (
|
||||||
isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage) or isinstance(item, MessageStreamStatus)
|
f"Wrong type: {type(item)}"
|
||||||
), f"Wrong type: {type(item)}"
|
)
|
||||||
|
|
||||||
self._chunks.append(item)
|
self._chunks.append(item)
|
||||||
self._event.set() # Signal that new data is available
|
self._event.set() # Signal that new data is available
|
||||||
@@ -731,13 +731,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
|
|
||||||
# If we have main_json, we should output a ToolCallMessage
|
# If we have main_json, we should output a ToolCallMessage
|
||||||
elif updates_main_json:
|
elif updates_main_json:
|
||||||
|
|
||||||
# If there's something in the function_name buffer, we should release it first
|
# If there's something in the function_name buffer, we should release it first
|
||||||
# NOTE: we could output it as part of a chunk that has both name and args,
|
# NOTE: we could output it as part of a chunk that has both name and args,
|
||||||
# however the frontend may expect name first, then args, so to be
|
# however the frontend may expect name first, then args, so to be
|
||||||
# safe we'll output name first in a separate chunk
|
# safe we'll output name first in a separate chunk
|
||||||
if self.function_name_buffer:
|
if self.function_name_buffer:
|
||||||
|
|
||||||
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||||
if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
|
if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
|
||||||
processed_chunk = None
|
processed_chunk = None
|
||||||
@@ -778,7 +776,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
# If there was nothing in the name buffer, we can proceed to
|
# If there was nothing in the name buffer, we can proceed to
|
||||||
# output the arguments chunk as a ToolCallMessage
|
# output the arguments chunk as a ToolCallMessage
|
||||||
else:
|
else:
|
||||||
|
|
||||||
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||||
if self.use_assistant_message and (
|
if self.use_assistant_message and (
|
||||||
self.last_flushed_function_name is not None
|
self.last_flushed_function_name is not None
|
||||||
@@ -860,7 +857,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
# clear buffers
|
# clear buffers
|
||||||
self.function_id_buffer = None
|
self.function_id_buffer = None
|
||||||
else:
|
else:
|
||||||
|
|
||||||
# There may be a buffer from a previous chunk, for example
|
# There may be a buffer from a previous chunk, for example
|
||||||
# if the previous chunk had arguments but we needed to flush name
|
# if the previous chunk had arguments but we needed to flush name
|
||||||
if self.function_args_buffer:
|
if self.function_args_buffer:
|
||||||
@@ -997,7 +993,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
# Otherwise, do simple chunks of ToolCallMessage
|
# Otherwise, do simple chunks of ToolCallMessage
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
tool_call_delta = {}
|
tool_call_delta = {}
|
||||||
if tool_call.id:
|
if tool_call.id:
|
||||||
tool_call_delta["id"] = tool_call.id
|
tool_call_delta["id"] = tool_call.id
|
||||||
@@ -1073,7 +1068,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
tool_call = message_delta.tool_calls[0]
|
tool_call = message_delta.tool_calls[0]
|
||||||
|
|
||||||
if tool_call.function:
|
if tool_call.function:
|
||||||
|
|
||||||
# Track the function name while streaming
|
# Track the function name while streaming
|
||||||
# If we were previously on a 'send_message', we need to 'toggle' into 'content' mode
|
# If we were previously on a 'send_message', we need to 'toggle' into 'content' mode
|
||||||
if tool_call.function.name:
|
if tool_call.function.name:
|
||||||
@@ -1154,7 +1148,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||||
"""Letta generates some internal monologue"""
|
"""Letta generates some internal monologue"""
|
||||||
if not self.streaming_mode:
|
if not self.streaming_mode:
|
||||||
|
|
||||||
# create a fake "chunk" of a stream
|
# create a fake "chunk" of a stream
|
||||||
# processed_chunk = {
|
# processed_chunk = {
|
||||||
# "internal_monologue": msg,
|
# "internal_monologue": msg,
|
||||||
@@ -1268,7 +1261,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
|||||||
print(f"Failed to parse function message: {e}")
|
print(f"Failed to parse function message: {e}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
func_args = parse_json(function_call.function.arguments)
|
func_args = parse_json(function_call.function.arguments)
|
||||||
except:
|
except:
|
||||||
|
|||||||
@@ -140,9 +140,7 @@ class RedisSSEStreamWriter:
|
|||||||
|
|
||||||
self.last_flush[run_id] = time.time()
|
self.last_flush[run_id] = time.time()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(f"Flushed {len(chunks)} chunks to Redis stream {stream_key}, seq_ids {chunks[0]['seq_id']}-{chunks[-1]['seq_id']}")
|
||||||
f"Flushed {len(chunks)} chunks to Redis stream {stream_key}, " f"seq_ids {chunks[0]['seq_id']}-{chunks[-1]['seq_id']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if chunks[-1].get("complete") == "true":
|
if chunks[-1].get("complete") == "true":
|
||||||
self._cleanup_run(run_id)
|
self._cleanup_run(run_id)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ async def list_blocks(
|
|||||||
),
|
),
|
||||||
label_search: Optional[str] = Query(
|
label_search: Optional[str] = Query(
|
||||||
None,
|
None,
|
||||||
description=("Search blocks by label. If provided, returns blocks that match this label. " "This is a full-text search on labels."),
|
description=("Search blocks by label. If provided, returns blocks that match this label. This is a full-text search on labels."),
|
||||||
),
|
),
|
||||||
description_search: Optional[str] = Query(
|
description_search: Optional[str] = Query(
|
||||||
None,
|
None,
|
||||||
|
|||||||
@@ -6,11 +6,17 @@ from fastapi import APIRouter, Depends, HTTPException, Query
|
|||||||
|
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.schemas.enums import SandboxType
|
from letta.schemas.enums import SandboxType
|
||||||
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar
|
from letta.schemas.environment_variables import (
|
||||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
SandboxEnvironmentVariable as PydanticEnvVar,
|
||||||
from letta.schemas.sandbox_config import LocalSandboxConfig
|
SandboxEnvironmentVariableCreate,
|
||||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
SandboxEnvironmentVariableUpdate,
|
||||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
)
|
||||||
|
from letta.schemas.sandbox_config import (
|
||||||
|
LocalSandboxConfig,
|
||||||
|
SandboxConfig as PydanticSandboxConfig,
|
||||||
|
SandboxConfigCreate,
|
||||||
|
SandboxConfigUpdate,
|
||||||
|
)
|
||||||
from letta.server.rest_api.utils import get_letta_server, get_user_id
|
from letta.server.rest_api.utils import get_letta_server, get_user_id
|
||||||
from letta.server.server import SyncServer
|
from letta.server.server import SyncServer
|
||||||
from letta.services.helpers.tool_execution_helper import create_venv_for_local_sandbox, install_pip_requirements_for_sandbox
|
from letta.services.helpers.tool_execution_helper import create_venv_for_local_sandbox, install_pip_requirements_for_sandbox
|
||||||
|
|||||||
@@ -749,8 +749,8 @@ async def connect_mcp_server(
|
|||||||
except ConnectionError:
|
except ConnectionError:
|
||||||
# TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
|
# TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
|
||||||
if isinstance(client, AsyncStdioMCPClient):
|
if isinstance(client, AsyncStdioMCPClient):
|
||||||
logger.warning(f"OAuth not supported for stdio")
|
logger.warning("OAuth not supported for stdio")
|
||||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"OAuth not supported for stdio")
|
yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
|
||||||
return
|
return
|
||||||
# Continue to OAuth flow
|
# Continue to OAuth flow
|
||||||
logger.info(f"Attempting OAuth flow for {request}...")
|
logger.info(f"Attempting OAuth flow for {request}...")
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ class StreamingResponseWithStatusCode(StreamingResponse):
|
|||||||
try:
|
try:
|
||||||
await asyncio.shield(self._protected_stream_response(send))
|
await asyncio.shield(self._protected_stream_response(send))
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(f"Stream response was cancelled, but shielded task should continue")
|
logger.info("Stream response was cancelled, but shielded task should continue")
|
||||||
except anyio.ClosedResourceError:
|
except anyio.ClosedResourceError:
|
||||||
logger.info("Client disconnected, but shielded task should continue")
|
logger.info("Client disconnected, but shielded task should continue")
|
||||||
self._client_connected = False
|
self._client_connected = False
|
||||||
|
|||||||
@@ -7,8 +7,7 @@ from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional
|
|||||||
|
|
||||||
from fastapi import Header, HTTPException
|
from fastapi import Header, HTTPException
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
|
||||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
|
||||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|||||||
@@ -30,8 +30,10 @@ from letta.helpers.datetime_helpers import get_utc_time
|
|||||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||||
|
|
||||||
# TODO use custom interface
|
# TODO use custom interface
|
||||||
from letta.interface import AgentInterface # abstract
|
from letta.interface import (
|
||||||
from letta.interface import CLIInterface # for printing to terminal
|
AgentInterface, # abstract
|
||||||
|
CLIInterface, # for printing to terminal
|
||||||
|
)
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.otel.tracing import log_event, trace_method
|
from letta.otel.tracing import log_event, trace_method
|
||||||
|
|||||||
@@ -26,37 +26,42 @@ from letta.helpers import ToolRulesSolver
|
|||||||
from letta.helpers.datetime_helpers import get_utc_time
|
from letta.helpers.datetime_helpers import get_utc_time
|
||||||
from letta.llm_api.llm_client import LLMClient
|
from letta.llm_api.llm_client import LLMClient
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.orm import Agent as AgentModel
|
from letta.orm import (
|
||||||
from letta.orm import AgentsTags, ArchivalPassage
|
Agent as AgentModel,
|
||||||
from letta.orm import Block as BlockModel
|
AgentsTags,
|
||||||
from letta.orm import BlocksAgents
|
ArchivalPassage,
|
||||||
from letta.orm import Group as GroupModel
|
Block as BlockModel,
|
||||||
from letta.orm import GroupsAgents, IdentitiesAgents
|
BlocksAgents,
|
||||||
from letta.orm import Source as SourceModel
|
Group as GroupModel,
|
||||||
from letta.orm import SourcePassage, SourcesAgents
|
GroupsAgents,
|
||||||
from letta.orm import Tool as ToolModel
|
IdentitiesAgents,
|
||||||
from letta.orm import ToolsAgents
|
Source as SourceModel,
|
||||||
|
SourcePassage,
|
||||||
|
SourcesAgents,
|
||||||
|
Tool as ToolModel,
|
||||||
|
ToolsAgents,
|
||||||
|
)
|
||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.orm.sandbox_config import AgentEnvironmentVariable
|
from letta.orm.sandbox_config import AgentEnvironmentVariable, AgentEnvironmentVariable as AgentEnvironmentVariableModel
|
||||||
from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel
|
|
||||||
from letta.orm.sqlalchemy_base import AccessType
|
from letta.orm.sqlalchemy_base import AccessType
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.prompts.prompt_generator import PromptGenerator
|
from letta.prompts.prompt_generator import PromptGenerator
|
||||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
from letta.schemas.agent import (
|
||||||
from letta.schemas.agent import AgentType, CreateAgent, InternalTemplateAgentCreate, UpdateAgent, get_prompt_template_for_agent_type
|
AgentState as PydanticAgentState,
|
||||||
from letta.schemas.block import DEFAULT_BLOCKS
|
AgentType,
|
||||||
from letta.schemas.block import Block as PydanticBlock
|
CreateAgent,
|
||||||
from letta.schemas.block import BlockUpdate
|
InternalTemplateAgentCreate,
|
||||||
|
UpdateAgent,
|
||||||
|
get_prompt_template_for_agent_type,
|
||||||
|
)
|
||||||
|
from letta.schemas.block import DEFAULT_BLOCKS, Block as PydanticBlock, BlockUpdate
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.schemas.enums import ProviderType, TagMatchMode, ToolType, VectorDBProvider
|
from letta.schemas.enums import ProviderType, TagMatchMode, ToolType, VectorDBProvider
|
||||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||||
from letta.schemas.group import Group as PydanticGroup
|
from letta.schemas.group import Group as PydanticGroup, ManagerType
|
||||||
from letta.schemas.group import ManagerType
|
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message, Message as PydanticMessage, MessageCreate, MessageUpdate
|
||||||
from letta.schemas.message import Message as PydanticMessage
|
|
||||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
|
||||||
from letta.schemas.passage import Passage as PydanticPassage
|
from letta.schemas.passage import Passage as PydanticPassage
|
||||||
from letta.schemas.source import Source as PydanticSource
|
from letta.schemas.source import Source as PydanticSource
|
||||||
from letta.schemas.tool import Tool as PydanticTool
|
from letta.schemas.tool import Tool as PydanticTool
|
||||||
@@ -493,7 +498,6 @@ class AgentManager:
|
|||||||
# blocks
|
# blocks
|
||||||
block_ids = list(agent_create.block_ids or [])
|
block_ids = list(agent_create.block_ids or [])
|
||||||
if agent_create.memory_blocks:
|
if agent_create.memory_blocks:
|
||||||
|
|
||||||
pydantic_blocks = [PydanticBlock(**b.model_dump(to_orm=True)) for b in agent_create.memory_blocks]
|
pydantic_blocks = [PydanticBlock(**b.model_dump(to_orm=True)) for b in agent_create.memory_blocks]
|
||||||
|
|
||||||
# Inject a description for the default blocks if the user didn't specify them
|
# Inject a description for the default blocks if the user didn't specify them
|
||||||
@@ -798,7 +802,6 @@ class AgentManager:
|
|||||||
agent_update: UpdateAgent,
|
agent_update: UpdateAgent,
|
||||||
actor: PydanticUser,
|
actor: PydanticUser,
|
||||||
) -> PydanticAgentState:
|
) -> PydanticAgentState:
|
||||||
|
|
||||||
new_tools = set(agent_update.tool_ids or [])
|
new_tools = set(agent_update.tool_ids or [])
|
||||||
new_sources = set(agent_update.source_ids or [])
|
new_sources = set(agent_update.source_ids or [])
|
||||||
new_blocks = set(agent_update.block_ids or [])
|
new_blocks = set(agent_update.block_ids or [])
|
||||||
@@ -806,7 +809,6 @@ class AgentManager:
|
|||||||
new_tags = set(agent_update.tags or [])
|
new_tags = set(agent_update.tags or [])
|
||||||
|
|
||||||
with db_registry.session() as session, session.begin():
|
with db_registry.session() as session, session.begin():
|
||||||
|
|
||||||
agent: AgentModel = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
agent: AgentModel = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||||
agent.updated_at = datetime.now(timezone.utc)
|
agent.updated_at = datetime.now(timezone.utc)
|
||||||
agent.last_updated_by_id = actor.id
|
agent.last_updated_by_id = actor.id
|
||||||
@@ -923,7 +925,6 @@ class AgentManager:
|
|||||||
agent_update: UpdateAgent,
|
agent_update: UpdateAgent,
|
||||||
actor: PydanticUser,
|
actor: PydanticUser,
|
||||||
) -> PydanticAgentState:
|
) -> PydanticAgentState:
|
||||||
|
|
||||||
new_tools = set(agent_update.tool_ids or [])
|
new_tools = set(agent_update.tool_ids or [])
|
||||||
new_sources = set(agent_update.source_ids or [])
|
new_sources = set(agent_update.source_ids or [])
|
||||||
new_blocks = set(agent_update.block_ids or [])
|
new_blocks = set(agent_update.block_ids or [])
|
||||||
@@ -931,7 +932,6 @@ class AgentManager:
|
|||||||
new_tags = set(agent_update.tags or [])
|
new_tags = set(agent_update.tags or [])
|
||||||
|
|
||||||
async with db_registry.async_session() as session, session.begin():
|
async with db_registry.async_session() as session, session.begin():
|
||||||
|
|
||||||
agent: AgentModel = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
agent: AgentModel = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||||
agent.updated_at = datetime.now(timezone.utc)
|
agent.updated_at = datetime.now(timezone.utc)
|
||||||
agent.last_updated_by_id = actor.id
|
agent.last_updated_by_id = actor.id
|
||||||
|
|||||||
@@ -4,9 +4,7 @@ from sqlalchemy import select
|
|||||||
|
|
||||||
from letta.helpers.tpuf_client import should_use_tpuf
|
from letta.helpers.tpuf_client import should_use_tpuf
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.orm import ArchivalPassage
|
from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents
|
||||||
from letta.orm import Archive as ArchiveModel
|
|
||||||
from letta.orm import ArchivesAgents
|
|
||||||
from letta.schemas.archive import Archive as PydanticArchive
|
from letta.schemas.archive import Archive as PydanticArchive
|
||||||
from letta.schemas.enums import VectorDBProvider
|
from letta.schemas.enums import VectorDBProvider
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
|
|||||||
@@ -13,8 +13,7 @@ from letta.orm.blocks_agents import BlocksAgents
|
|||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||||
from letta.schemas.block import Block as PydanticBlock
|
from letta.schemas.block import Block as PydanticBlock, BlockUpdate
|
||||||
from letta.schemas.block import BlockUpdate
|
|
||||||
from letta.schemas.enums import ActorType
|
from letta.schemas.enums import ActorType
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
|
|||||||
@@ -50,7 +50,8 @@ class AnthropicTokenCounter(TokenCounter):
|
|||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
@async_redis_cache(
|
@async_redis_cache(
|
||||||
key_func=lambda self, messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
key_func=lambda self,
|
||||||
|
messages: f"anthropic_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
||||||
prefix="token_counter",
|
prefix="token_counter",
|
||||||
ttl_s=3600, # cache for 1 hour
|
ttl_s=3600, # cache for 1 hour
|
||||||
)
|
)
|
||||||
@@ -61,7 +62,8 @@ class AnthropicTokenCounter(TokenCounter):
|
|||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
@async_redis_cache(
|
@async_redis_cache(
|
||||||
key_func=lambda self, tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
key_func=lambda self,
|
||||||
|
tools: f"anthropic_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
||||||
prefix="token_counter",
|
prefix="token_counter",
|
||||||
ttl_s=3600, # cache for 1 hour
|
ttl_s=3600, # cache for 1 hour
|
||||||
)
|
)
|
||||||
@@ -93,7 +95,8 @@ class TiktokenCounter(TokenCounter):
|
|||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
@async_redis_cache(
|
@async_redis_cache(
|
||||||
key_func=lambda self, messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
key_func=lambda self,
|
||||||
|
messages: f"tiktoken_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}",
|
||||||
prefix="token_counter",
|
prefix="token_counter",
|
||||||
ttl_s=3600, # cache for 1 hour
|
ttl_s=3600, # cache for 1 hour
|
||||||
)
|
)
|
||||||
@@ -106,7 +109,8 @@ class TiktokenCounter(TokenCounter):
|
|||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
@async_redis_cache(
|
@async_redis_cache(
|
||||||
key_func=lambda self, tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
key_func=lambda self,
|
||||||
|
tools: f"tiktoken_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}",
|
||||||
prefix="token_counter",
|
prefix="token_counter",
|
||||||
ttl_s=3600, # cache for 1 hour
|
ttl_s=3600, # cache for 1 hour
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,8 +12,7 @@ from letta.constants import MAX_FILENAME_LENGTH
|
|||||||
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
|
from letta.helpers.pinecone_utils import list_pinecone_index_for_files, should_use_pinecone
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.orm.file import FileContent as FileContentModel
|
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
|
||||||
from letta.orm.file import FileMetadata as FileMetadataModel
|
|
||||||
from letta.orm.sqlalchemy_base import AccessType
|
from letta.orm.sqlalchemy_base import AccessType
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.enums import FileProcessingStatus
|
from letta.schemas.enums import FileProcessingStatus
|
||||||
@@ -60,7 +59,6 @@ class FileManager:
|
|||||||
*,
|
*,
|
||||||
text: Optional[str] = None,
|
text: Optional[str] = None,
|
||||||
) -> PydanticFileMetadata:
|
) -> PydanticFileMetadata:
|
||||||
|
|
||||||
# short-circuit if it already exists
|
# short-circuit if it already exists
|
||||||
existing = await self.get_file_by_id(file_metadata.id, actor=actor)
|
existing = await self.get_file_by_id(file_metadata.id, actor=actor)
|
||||||
if existing:
|
if existing:
|
||||||
|
|||||||
@@ -7,10 +7,8 @@ from letta.log import get_logger
|
|||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.orm.files_agents import FileAgent as FileAgentModel
|
from letta.orm.files_agents import FileAgent as FileAgentModel
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.block import Block as PydanticBlock
|
from letta.schemas.block import Block as PydanticBlock, FileBlock as PydanticFileBlock
|
||||||
from letta.schemas.block import FileBlock as PydanticFileBlock
|
from letta.schemas.file import FileAgent as PydanticFileAgent, FileMetadata
|
||||||
from letta.schemas.file import FileAgent as PydanticFileAgent
|
|
||||||
from letta.schemas.file import FileMetadata
|
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
from letta.utils import enforce_types
|
from letta.utils import enforce_types
|
||||||
|
|||||||
@@ -8,8 +8,7 @@ from letta.orm.errors import NoResultFound
|
|||||||
from letta.orm.group import Group as GroupModel
|
from letta.orm.group import Group as GroupModel
|
||||||
from letta.orm.message import Message as MessageModel
|
from letta.orm.message import Message as MessageModel
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.group import Group as PydanticGroup
|
from letta.schemas.group import Group as PydanticGroup, GroupCreate, GroupUpdate, InternalTemplateGroupCreate, ManagerType
|
||||||
from letta.schemas.group import GroupCreate, GroupUpdate, InternalTemplateGroupCreate, ManagerType
|
|
||||||
from letta.schemas.letta_message import LettaMessage
|
from letta.schemas.letta_message import LettaMessage
|
||||||
from letta.schemas.message import Message as PydanticMessage
|
from letta.schemas.message import Message as PydanticMessage
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
@@ -18,7 +17,6 @@ from letta.utils import enforce_types
|
|||||||
|
|
||||||
|
|
||||||
class GroupManager:
|
class GroupManager:
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
@trace_method
|
@trace_method
|
||||||
async def list_groups_async(
|
async def list_groups_async(
|
||||||
|
|||||||
@@ -464,7 +464,6 @@ def package_initial_message_sequence(
|
|||||||
# create the agent object
|
# create the agent object
|
||||||
init_messages = []
|
init_messages = []
|
||||||
for message_create in initial_message_sequence:
|
for message_create in initial_message_sequence:
|
||||||
|
|
||||||
if message_create.role == MessageRole.user:
|
if message_create.role == MessageRole.user:
|
||||||
packed_message = system.package_user_message(
|
packed_message = system.package_user_message(
|
||||||
user_message=message_create.content,
|
user_message=message_create.content,
|
||||||
@@ -498,8 +497,10 @@ def package_initial_message_sequence(
|
|||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
from openai.types.chat.chat_completion_message_tool_call import (
|
||||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
ChatCompletionMessageToolCall as OpenAIToolCall,
|
||||||
|
Function as OpenAIFunction,
|
||||||
|
)
|
||||||
|
|
||||||
from letta.constants import DEFAULT_MESSAGE_TOOL
|
from letta.constants import DEFAULT_MESSAGE_TOOL
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,14 @@ from letta.orm.block import Block as BlockModel
|
|||||||
from letta.orm.errors import UniqueConstraintViolationError
|
from letta.orm.errors import UniqueConstraintViolationError
|
||||||
from letta.orm.identity import Identity as IdentityModel
|
from letta.orm.identity import Identity as IdentityModel
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.identity import Identity as PydanticIdentity
|
from letta.schemas.identity import (
|
||||||
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityType, IdentityUpdate, IdentityUpsert
|
Identity as PydanticIdentity,
|
||||||
|
IdentityCreate,
|
||||||
|
IdentityProperty,
|
||||||
|
IdentityType,
|
||||||
|
IdentityUpdate,
|
||||||
|
IdentityUpsert,
|
||||||
|
)
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
from letta.settings import DatabaseChoice, settings
|
from letta.settings import DatabaseChoice, settings
|
||||||
@@ -18,7 +24,6 @@ from letta.utils import enforce_types
|
|||||||
|
|
||||||
|
|
||||||
class IdentityManager:
|
class IdentityManager:
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
@trace_method
|
@trace_method
|
||||||
async def list_identities_async(
|
async def list_identities_async(
|
||||||
|
|||||||
@@ -13,13 +13,10 @@ from letta.orm.job import Job as JobModel
|
|||||||
from letta.orm.job_messages import JobMessage
|
from letta.orm.job_messages import JobMessage
|
||||||
from letta.orm.message import Message as MessageModel
|
from letta.orm.message import Message as MessageModel
|
||||||
from letta.orm.sqlalchemy_base import AccessType
|
from letta.orm.sqlalchemy_base import AccessType
|
||||||
from letta.orm.step import Step
|
from letta.orm.step import Step, Step as StepModel
|
||||||
from letta.orm.step import Step as StepModel
|
|
||||||
from letta.otel.tracing import log_event, trace_method
|
from letta.otel.tracing import log_event, trace_method
|
||||||
from letta.schemas.enums import JobStatus, JobType, MessageRole
|
from letta.schemas.enums import JobStatus, JobType, MessageRole
|
||||||
from letta.schemas.job import BatchJob as PydanticBatchJob
|
from letta.schemas.job import BatchJob as PydanticBatchJob, Job as PydanticJob, JobUpdate, LettaRequestConfig
|
||||||
from letta.schemas.job import Job as PydanticJob
|
|
||||||
from letta.schemas.job import JobUpdate, LettaRequestConfig
|
|
||||||
from letta.schemas.letta_message import LettaMessage
|
from letta.schemas.letta_message import LettaMessage
|
||||||
from letta.schemas.message import Message as PydanticMessage
|
from letta.schemas.message import Message as PydanticMessage
|
||||||
from letta.schemas.run import Run as PydanticRun
|
from letta.schemas.run import Run as PydanticRun
|
||||||
|
|||||||
@@ -11,9 +11,7 @@ from letta.orm.llm_batch_items import LLMBatchItem
|
|||||||
from letta.orm.llm_batch_job import LLMBatchJob
|
from letta.orm.llm_batch_job import LLMBatchJob
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
|
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
|
||||||
from letta.schemas.llm_batch_job import AgentStepState
|
from letta.schemas.llm_batch_job import AgentStepState, LLMBatchItem as PydanticLLMBatchItem, LLMBatchJob as PydanticLLMBatchJob
|
||||||
from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
|
|
||||||
from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
|
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.message import Message as PydanticMessage
|
from letta.schemas.message import Message as PydanticMessage
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession, Tool as MCPTool
|
||||||
from mcp import Tool as MCPTool
|
|
||||||
from mcp.client.auth import OAuthClientProvider
|
from mcp.client.auth import OAuthClientProvider
|
||||||
from mcp.types import TextContent
|
from mcp.types import TextContent
|
||||||
|
|
||||||
|
|||||||
@@ -33,8 +33,7 @@ from letta.schemas.mcp import (
|
|||||||
UpdateStdioMCPServer,
|
UpdateStdioMCPServer,
|
||||||
UpdateStreamableHTTPMCPServer,
|
UpdateStreamableHTTPMCPServer,
|
||||||
)
|
)
|
||||||
from letta.schemas.tool import Tool as PydanticTool
|
from letta.schemas.tool import Tool as PydanticTool, ToolCreate
|
||||||
from letta.schemas.tool import ToolCreate
|
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
|
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
|
||||||
@@ -137,8 +136,7 @@ class MCPManager:
|
|||||||
if mcp_tool.health:
|
if mcp_tool.health:
|
||||||
if mcp_tool.health.status == "INVALID":
|
if mcp_tool.health.status == "INVALID":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid."
|
f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid.Reasons: {', '.join(mcp_tool.health.reasons)}"
|
||||||
f"Reasons: {', '.join(mcp_tool.health.reasons)}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||||
@@ -305,7 +303,9 @@ class MCPManager:
|
|||||||
|
|
||||||
async with db_registry.async_session() as session:
|
async with db_registry.async_session() as session:
|
||||||
mcp_servers = await MCPServerModel.list_async(
|
mcp_servers = await MCPServerModel.list_async(
|
||||||
db_session=session, organization_id=actor.organization_id, id=mcp_server_ids # This will use the IN operator
|
db_session=session,
|
||||||
|
organization_id=actor.organization_id,
|
||||||
|
id=mcp_server_ids, # This will use the IN operator
|
||||||
)
|
)
|
||||||
return [mcp_server.to_pydantic() for mcp_server in mcp_servers]
|
return [mcp_server.to_pydantic() for mcp_server in mcp_servers]
|
||||||
|
|
||||||
@@ -407,7 +407,6 @@ class MCPManager:
|
|||||||
# with the value being the schema from StdioServerParameters
|
# with the value being the schema from StdioServerParameters
|
||||||
if MCP_CONFIG_TOPLEVEL_KEY in mcp_config:
|
if MCP_CONFIG_TOPLEVEL_KEY in mcp_config:
|
||||||
for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items():
|
for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items():
|
||||||
|
|
||||||
# No support for duplicate server names
|
# No support for duplicate server names
|
||||||
if server_name in mcp_server_list:
|
if server_name in mcp_server_list:
|
||||||
# Duplicate server names are configuration issues, not system errors
|
# Duplicate server names are configuration issues, not system errors
|
||||||
|
|||||||
@@ -12,8 +12,7 @@ from letta.otel.tracing import trace_method
|
|||||||
from letta.schemas.enums import MessageRole
|
from letta.schemas.enums import MessageRole
|
||||||
from letta.schemas.letta_message import LettaMessageUpdateUnion
|
from letta.schemas.letta_message import LettaMessageUpdateUnion
|
||||||
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType
|
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType
|
||||||
from letta.schemas.message import Message as PydanticMessage
|
from letta.schemas.message import Message as PydanticMessage, MessageUpdate
|
||||||
from letta.schemas.message import MessageUpdate
|
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
from letta.services.file_manager import FileManager
|
from letta.services.file_manager import FileManager
|
||||||
@@ -185,9 +184,9 @@ class MessageManager:
|
|||||||
# modify the tool call for send_message
|
# modify the tool call for send_message
|
||||||
# TODO: fix this if we add parallel tool calls
|
# TODO: fix this if we add parallel tool calls
|
||||||
# TODO: note this only works if the AssistantMessage is generated by the standard send_message
|
# TODO: note this only works if the AssistantMessage is generated by the standard send_message
|
||||||
assert (
|
assert message.tool_calls[0].function.name == "send_message", (
|
||||||
message.tool_calls[0].function.name == "send_message"
|
f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
|
||||||
), f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
|
)
|
||||||
original_args = json.loads(message.tool_calls[0].function.arguments)
|
original_args = json.loads(message.tool_calls[0].function.arguments)
|
||||||
original_args["message"] = letta_message_update.content # override the assistant message
|
original_args["message"] = letta_message_update.content # override the assistant message
|
||||||
update_tool_call = message.tool_calls[0].__deepcopy__()
|
update_tool_call = message.tool_calls[0].__deepcopy__()
|
||||||
@@ -224,9 +223,9 @@ class MessageManager:
|
|||||||
# modify the tool call for send_message
|
# modify the tool call for send_message
|
||||||
# TODO: fix this if we add parallel tool calls
|
# TODO: fix this if we add parallel tool calls
|
||||||
# TODO: note this only works if the AssistantMessage is generated by the standard send_message
|
# TODO: note this only works if the AssistantMessage is generated by the standard send_message
|
||||||
assert (
|
assert message.tool_calls[0].function.name == "send_message", (
|
||||||
message.tool_calls[0].function.name == "send_message"
|
f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
|
||||||
), f"Expected the first tool call to be send_message, but got {message.tool_calls[0].function.name}"
|
)
|
||||||
original_args = json.loads(message.tool_calls[0].function.arguments)
|
original_args = json.loads(message.tool_calls[0].function.arguments)
|
||||||
original_args["message"] = letta_message_update.content # override the assistant message
|
original_args["message"] = letta_message_update.content # override the assistant message
|
||||||
update_tool_call = message.tool_calls[0].__deepcopy__()
|
update_tool_call = message.tool_calls[0].__deepcopy__()
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ from letta.constants import DEFAULT_ORG_ID, DEFAULT_ORG_NAME
|
|||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.orm.organization import Organization as OrganizationModel
|
from letta.orm.organization import Organization as OrganizationModel
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.organization import Organization as PydanticOrganization
|
from letta.schemas.organization import Organization as PydanticOrganization, OrganizationUpdate
|
||||||
from letta.schemas.organization import OrganizationUpdate
|
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
from letta.utils import enforce_types
|
from letta.utils import enforce_types
|
||||||
|
|
||||||
|
|||||||
@@ -3,15 +3,13 @@ from typing import List, Optional, Tuple, Union
|
|||||||
from letta.orm.provider import Provider as ProviderModel
|
from letta.orm.provider import Provider as ProviderModel
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||||
from letta.schemas.providers import Provider as PydanticProvider
|
from letta.schemas.providers import Provider as PydanticProvider, ProviderCheck, ProviderCreate, ProviderUpdate
|
||||||
from letta.schemas.providers import ProviderCheck, ProviderCreate, ProviderUpdate
|
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
from letta.utils import enforce_types
|
from letta.utils import enforce_types
|
||||||
|
|
||||||
|
|
||||||
class ProviderManager:
|
class ProviderManager:
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
@trace_method
|
@trace_method
|
||||||
def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider:
|
def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider:
|
||||||
|
|||||||
@@ -3,15 +3,20 @@ from typing import Dict, List, Optional
|
|||||||
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel
|
from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel, SandboxEnvironmentVariable as SandboxEnvVarModel
|
||||||
from letta.orm.sandbox_config import SandboxEnvironmentVariable as SandboxEnvVarModel
|
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.enums import SandboxType
|
from letta.schemas.enums import SandboxType
|
||||||
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar
|
from letta.schemas.environment_variables import (
|
||||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
SandboxEnvironmentVariable as PydanticEnvVar,
|
||||||
from letta.schemas.sandbox_config import LocalSandboxConfig
|
SandboxEnvironmentVariableCreate,
|
||||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
SandboxEnvironmentVariableUpdate,
|
||||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
)
|
||||||
|
from letta.schemas.sandbox_config import (
|
||||||
|
LocalSandboxConfig,
|
||||||
|
SandboxConfig as PydanticSandboxConfig,
|
||||||
|
SandboxConfigCreate,
|
||||||
|
SandboxConfigUpdate,
|
||||||
|
)
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
from letta.utils import enforce_types, printd
|
from letta.utils import enforce_types, printd
|
||||||
|
|||||||
@@ -9,8 +9,7 @@ from letta.orm.source import Source as SourceModel
|
|||||||
from letta.orm.sources_agents import SourcesAgents
|
from letta.orm.sources_agents import SourcesAgents
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||||
from letta.schemas.source import Source as PydanticSource
|
from letta.schemas.source import Source as PydanticSource, SourceUpdate
|
||||||
from letta.schemas.source import SourceUpdate
|
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
from letta.utils import enforce_types, printd
|
from letta.utils import enforce_types, printd
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ class FeedbackType(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class StepManager:
|
class StepManager:
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
@trace_method
|
@trace_method
|
||||||
async def list_steps_async(
|
async def list_steps_async(
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ class Summarizer:
|
|||||||
total_message_count = len(all_in_context_messages)
|
total_message_count = len(all_in_context_messages)
|
||||||
assert self.partial_evict_summarizer_percentage >= 0.0 and self.partial_evict_summarizer_percentage <= 1.0
|
assert self.partial_evict_summarizer_percentage >= 0.0 and self.partial_evict_summarizer_percentage <= 1.0
|
||||||
target_message_start = round((1.0 - self.partial_evict_summarizer_percentage) * total_message_count)
|
target_message_start = round((1.0 - self.partial_evict_summarizer_percentage) * total_message_count)
|
||||||
logger.info(f"Target message count: {total_message_count}->{(total_message_count-target_message_start)}")
|
logger.info(f"Target message count: {total_message_count}->{(total_message_count - target_message_start)}")
|
||||||
|
|
||||||
# The summary message we'll insert is role 'user' (vs 'assistant', 'tool', or 'system')
|
# The summary message we'll insert is role 'user' (vs 'assistant', 'tool', or 'system')
|
||||||
# We are going to put it at index 1 (index 0 is the system message)
|
# We are going to put it at index 1 (index 0 is the system message)
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ from letta.helpers.json_helpers import json_dumps, json_loads
|
|||||||
from letta.helpers.singleton import singleton
|
from letta.helpers.singleton import singleton
|
||||||
from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel
|
from letta.orm.provider_trace import ProviderTrace as ProviderTraceModel
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace
|
from letta.schemas.provider_trace import ProviderTrace as PydanticProviderTrace, ProviderTraceCreate
|
||||||
from letta.schemas.provider_trace import ProviderTraceCreate
|
|
||||||
from letta.schemas.step import Step as PydanticStep
|
from letta.schemas.step import Step as PydanticStep
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
@@ -11,7 +10,6 @@ from letta.utils import enforce_types
|
|||||||
|
|
||||||
|
|
||||||
class TelemetryManager:
|
class TelemetryManager:
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
@trace_method
|
@trace_method
|
||||||
async def get_provider_trace_by_step_id_async(
|
async def get_provider_trace_by_step_id_async(
|
||||||
|
|||||||
@@ -309,7 +309,7 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
|||||||
|
|
||||||
# Create numbered markdown for the LLM to reference
|
# Create numbered markdown for the LLM to reference
|
||||||
numbered_lines = markdown_content.split("\n")
|
numbered_lines = markdown_content.split("\n")
|
||||||
numbered_markdown = "\n".join([f"{i+1:4d}: {line}" for i, line in enumerate(numbered_lines)])
|
numbered_markdown = "\n".join([f"{i + 1:4d}: {line}" for i, line in enumerate(numbered_lines)])
|
||||||
|
|
||||||
# Truncate if too long
|
# Truncate if too long
|
||||||
max_content_length = 200000
|
max_content_length = 200000
|
||||||
|
|||||||
@@ -273,14 +273,13 @@ class LettaCoreToolExecutor(ToolExecutor):
|
|||||||
occurences = current_value.count(old_str)
|
occurences = current_value.count(old_str)
|
||||||
if occurences == 0:
|
if occurences == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No replacement was performed, old_str `{old_str}` did not appear " f"verbatim in memory block with label `{label}`."
|
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in memory block with label `{label}`."
|
||||||
)
|
)
|
||||||
elif occurences > 1:
|
elif occurences > 1:
|
||||||
content_value_lines = current_value.split("\n")
|
content_value_lines = current_value.split("\n")
|
||||||
lines = [idx + 1 for idx, line in enumerate(content_value_lines) if old_str in line]
|
lines = [idx + 1 for idx, line in enumerate(content_value_lines) if old_str in line]
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No replacement was performed. Multiple occurrences of "
|
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique."
|
||||||
f"old_str `{old_str}` in lines {lines}. Please ensure it is unique."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replace old_str with new_str
|
# Replace old_str with new_str
|
||||||
|
|||||||
@@ -568,7 +568,7 @@ class LettaFileToolExecutor(ToolExecutor):
|
|||||||
attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor)
|
attached_sources = await self.agent_manager.list_attached_sources_async(agent_id=agent_state.id, actor=self.actor)
|
||||||
source_ids = [source.id for source in attached_sources]
|
source_ids = [source.id for source in attached_sources]
|
||||||
if not source_ids:
|
if not source_ids:
|
||||||
return f"No valid source IDs found for attached files"
|
return "No valid source IDs found for attached files"
|
||||||
|
|
||||||
# Get all attached files for this agent
|
# Get all attached files for this agent
|
||||||
file_agents = await self.files_agents_manager.list_files_for_agent(
|
file_agents = await self.files_agents_manager.list_files_for_agent(
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
|||||||
sandbox_config: Optional[SandboxConfig] = None,
|
sandbox_config: Optional[SandboxConfig] = None,
|
||||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||||
) -> ToolExecutionResult:
|
) -> ToolExecutionResult:
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
mcp_server_tag = [tag for tag in tool.tags if tag.startswith(f"{MCP_TOOL_TAG_NAME_PREFIX}:")]
|
mcp_server_tag = [tag for tag in tool.tags if tag.startswith(f"{MCP_TOOL_TAG_NAME_PREFIX}:")]
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ class SandboxToolExecutor(ToolExecutor):
|
|||||||
sandbox_config: Optional[SandboxConfig] = None,
|
sandbox_config: Optional[SandboxConfig] = None,
|
||||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||||
) -> ToolExecutionResult:
|
) -> ToolExecutionResult:
|
||||||
|
|
||||||
# Store original memory state
|
# Store original memory state
|
||||||
if agent_state:
|
if agent_state:
|
||||||
orig_memory_str = await agent_state.memory.compile_in_thread_async()
|
orig_memory_str = await agent_state.memory.compile_in_thread_async()
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class ToolExecutionSandbox:
|
|||||||
logger.debug(f"Executed tool '{self.tool_name}', logging output from tool run: \n")
|
logger.debug(f"Executed tool '{self.tool_name}', logging output from tool run: \n")
|
||||||
for log_line in (result.stdout or []) + (result.stderr or []):
|
for log_line in (result.stdout or []) + (result.stderr or []):
|
||||||
logger.debug(f"{log_line}")
|
logger.debug(f"{log_line}")
|
||||||
logger.debug(f"Ending output log from tool run.")
|
logger.debug("Ending output log from tool run.")
|
||||||
|
|
||||||
# Return result
|
# Return result
|
||||||
return result
|
return result
|
||||||
@@ -267,7 +267,6 @@ class ToolExecutionSandbox:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with self.temporary_env_vars(env):
|
with self.temporary_env_vars(env):
|
||||||
|
|
||||||
# Read and compile the Python script
|
# Read and compile the Python script
|
||||||
with open(temp_file_path, "r", encoding="utf-8") as f:
|
with open(temp_file_path, "r", encoding="utf-8") as f:
|
||||||
source = f.read()
|
source = f.read()
|
||||||
@@ -475,7 +474,7 @@ class ToolExecutionSandbox:
|
|||||||
return None, None
|
return None, None
|
||||||
result = pickle.loads(base64.b64decode(text))
|
result = pickle.loads(base64.b64decode(text))
|
||||||
agent_state = None
|
agent_state = None
|
||||||
if not result["agent_state"] is None:
|
if result["agent_state"] is not None:
|
||||||
agent_state = result["agent_state"]
|
agent_state = result["agent_state"]
|
||||||
return result["results"], agent_state
|
return result["results"], agent_state
|
||||||
|
|
||||||
|
|||||||
@@ -27,8 +27,7 @@ from letta.orm.errors import NoResultFound
|
|||||||
from letta.orm.tool import Tool as ToolModel
|
from letta.orm.tool import Tool as ToolModel
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.enums import ToolType
|
from letta.schemas.enums import ToolType
|
||||||
from letta.schemas.tool import Tool as PydanticTool
|
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
|
||||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
from letta.services.helpers.agent_manager_helper import calculate_multi_agent_tools
|
from letta.services.helpers.agent_manager_helper import calculate_multi_agent_tools
|
||||||
|
|||||||
@@ -183,9 +183,9 @@ class ModalDeploymentManager:
|
|||||||
existing_app = await self._try_get_existing_app(sbx_config, version_hash, user)
|
existing_app = await self._try_get_existing_app(sbx_config, version_hash, user)
|
||||||
if existing_app:
|
if existing_app:
|
||||||
return existing_app, version_hash
|
return existing_app, version_hash
|
||||||
raise RuntimeError(f"Deployment completed but app not found")
|
raise RuntimeError("Deployment completed but app not found")
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Timeout waiting for deployment")
|
raise RuntimeError("Timeout waiting for deployment")
|
||||||
|
|
||||||
# We're deploying - mark as in progress
|
# We're deploying - mark as in progress
|
||||||
deployment_key = None
|
deployment_key = None
|
||||||
|
|||||||
@@ -10,8 +10,7 @@ from letta.orm.errors import NoResultFound
|
|||||||
from letta.orm.organization import Organization as OrganizationModel
|
from letta.orm.organization import Organization as OrganizationModel
|
||||||
from letta.orm.user import User as UserModel
|
from letta.orm.user import User as UserModel
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser, UserUpdate
|
||||||
from letta.schemas.user import UserUpdate
|
|
||||||
from letta.server.db import db_registry
|
from letta.server.db import db_registry
|
||||||
from letta.utils import enforce_types
|
from letta.utils import enforce_types
|
||||||
|
|
||||||
|
|||||||
@@ -89,7 +89,6 @@ class SummarizerSettings(BaseSettings):
|
|||||||
|
|
||||||
|
|
||||||
class ModelSettings(BaseSettings):
|
class ModelSettings(BaseSettings):
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
|
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
|
||||||
|
|
||||||
global_max_context_window_limit: int = 32000
|
global_max_context_window_limit: int = 32000
|
||||||
|
|||||||
@@ -117,9 +117,9 @@ class StreamingCLIInterface(AgentChunkStreamingInterface):
|
|||||||
|
|
||||||
# Starting a new buffer line
|
# Starting a new buffer line
|
||||||
if not self.streaming_buffer_type:
|
if not self.streaming_buffer_type:
|
||||||
assert not (
|
assert not (message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls)), (
|
||||||
message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls)
|
f"Error: got both content and tool_calls in message stream\n{message_delta}"
|
||||||
), f"Error: got both content and tool_calls in message stream\n{message_delta}"
|
)
|
||||||
|
|
||||||
if message_delta.content is not None:
|
if message_delta.content is not None:
|
||||||
# Write out the prefix for inner thoughts
|
# Write out the prefix for inner thoughts
|
||||||
|
|||||||
@@ -187,7 +187,7 @@ def package_summarize_message(summary, summary_message_count, hidden_message_cou
|
|||||||
|
|
||||||
def package_summarize_message_no_counts(summary, timezone):
|
def package_summarize_message_no_counts(summary, timezone):
|
||||||
context_message = (
|
context_message = (
|
||||||
f"Note: prior messages have been hidden from view due to conversation memory constraints.\n"
|
"Note: prior messages have been hidden from view due to conversation memory constraints.\n"
|
||||||
+ f"The following is a summary of the previous messages:\n {summary}"
|
+ f"The following is a summary of the previous messages:\n {summary}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1149,7 +1149,6 @@ class CancellationSignal:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, job_manager=None, job_id=None, actor=None):
|
def __init__(self, job_manager=None, job_id=None, actor=None):
|
||||||
|
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.services.job_manager import JobManager
|
from letta.services.job_manager import JobManager
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ class CallbackHandler(BaseHTTPRequestHandler):
|
|||||||
<html>
|
<html>
|
||||||
<body>
|
<body>
|
||||||
<h1>Authorization Failed</h1>
|
<h1>Authorization Failed</h1>
|
||||||
<p>Error: {query_params['error'][0]}</p>
|
<p>Error: {query_params["error"][0]}</p>
|
||||||
<p>You can close this window and return to the terminal.</p>
|
<p>You can close this window and return to the terminal.</p>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ def generate_docqa_baseline_response(
|
|||||||
# print(f"Top {num_documents} documents: {documents_search_results_sorted_by_relevance}")
|
# print(f"Top {num_documents} documents: {documents_search_results_sorted_by_relevance}")
|
||||||
|
|
||||||
# compute truncation length
|
# compute truncation length
|
||||||
extra_text = BASELINE_PROMPT + f"Question: {question}" + f"Answer:"
|
extra_text = BASELINE_PROMPT + f"Question: {question}" + "Answer:"
|
||||||
padding = count_tokens(extra_text) + 1000
|
padding = count_tokens(extra_text) + 1000
|
||||||
truncation_length = int((config.default_llm_config.context_window - padding) / num_documents)
|
truncation_length = int((config.default_llm_config.context_window - padding) / num_documents)
|
||||||
print("Token size", config.default_llm_config.context_window)
|
print("Token size", config.default_llm_config.context_window)
|
||||||
@@ -114,7 +114,7 @@ def generate_docqa_baseline_response(
|
|||||||
if i >= num_documents:
|
if i >= num_documents:
|
||||||
break
|
break
|
||||||
|
|
||||||
doc_prompt = f"Document [{i+1}]: {doc} \n"
|
doc_prompt = f"Document [{i + 1}]: {doc} \n"
|
||||||
|
|
||||||
# truncate (that's why the performance goes down as x-axis increases)
|
# truncate (that's why the performance goes down as x-axis increases)
|
||||||
if truncation_length is not None:
|
if truncation_length is not None:
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ if __name__ == "__main__":
|
|||||||
if a in response:
|
if a in response:
|
||||||
found = True
|
found = True
|
||||||
|
|
||||||
if not found and not "INSUFFICIENT INFORMATION" in response:
|
if not found and "INSUFFICIENT INFORMATION" not in response:
|
||||||
# inconclusive: pass to llm judge
|
# inconclusive: pass to llm judge
|
||||||
print(question)
|
print(question)
|
||||||
print(answer)
|
print(answer)
|
||||||
|
|||||||
@@ -61,13 +61,13 @@ def archival_memory_text_search(self, query: str, page: Optional[int] = 0) -> Op
|
|||||||
try:
|
try:
|
||||||
page = int(page)
|
page = int(page)
|
||||||
except:
|
except:
|
||||||
raise ValueError(f"'page' argument must be an integer")
|
raise ValueError("'page' argument must be an integer")
|
||||||
count = 10
|
count = 10
|
||||||
results = self.persistence_manager.archival_memory.storage.query_text(query, limit=count, offset=page * count)
|
results = self.persistence_manager.archival_memory.storage.query_text(query, limit=count, offset=page * count)
|
||||||
total = len(results)
|
total = len(results)
|
||||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||||
if len(results) == 0:
|
if len(results) == 0:
|
||||||
results_str = f"No results found."
|
results_str = "No results found."
|
||||||
else:
|
else:
|
||||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||||
results_formatted = [f"memory: {d.text}" for d in results]
|
results_formatted = [f"memory: {d.text}" for d in results]
|
||||||
@@ -253,8 +253,8 @@ if __name__ == "__main__":
|
|||||||
# overwrite
|
# overwrite
|
||||||
kv_d[current_key] = next_key
|
kv_d[current_key] = next_key
|
||||||
|
|
||||||
print(f"Nested {i+1}")
|
print(f"Nested {i + 1}")
|
||||||
print(f"Done")
|
print("Done")
|
||||||
|
|
||||||
def get_nested_key(original_key, kv_d):
|
def get_nested_key(original_key, kv_d):
|
||||||
key = original_key
|
key = original_key
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ dependencies = [
|
|||||||
"certifi>=2025.6.15",
|
"certifi>=2025.6.15",
|
||||||
"markitdown[docx,pdf,pptx]>=0.1.2",
|
"markitdown[docx,pdf,pptx]>=0.1.2",
|
||||||
"orjson>=3.11.1",
|
"orjson>=3.11.1",
|
||||||
|
"ruff[dev]>=0.12.10",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
@@ -108,11 +109,8 @@ dev = [
|
|||||||
"pytest-mock>=3.14.0",
|
"pytest-mock>=3.14.0",
|
||||||
"pytest-json-report>=1.5.0",
|
"pytest-json-report>=1.5.0",
|
||||||
"pexpect>=4.9.0",
|
"pexpect>=4.9.0",
|
||||||
"black[jupyter]>=24.4.2",
|
|
||||||
"pre-commit>=3.5.0",
|
"pre-commit>=3.5.0",
|
||||||
"pyright>=1.1.347",
|
"pyright>=1.1.347",
|
||||||
"autoflake>=2.3.0",
|
|
||||||
"isort>=5.13.2",
|
|
||||||
"ipykernel>=6.29.5",
|
"ipykernel>=6.29.5",
|
||||||
"ipdb>=0.13.13",
|
"ipdb>=0.13.13",
|
||||||
]
|
]
|
||||||
@@ -149,18 +147,46 @@ build-backend = "hatchling.build"
|
|||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["letta"]
|
packages = ["letta"]
|
||||||
|
|
||||||
[tool.black]
|
|
||||||
line-length = 140
|
|
||||||
target-version = ['py310', 'py311', 'py312', 'py313']
|
|
||||||
extend-exclude = "examples/*"
|
|
||||||
|
|
||||||
[tool.isort]
|
[tool.ruff]
|
||||||
profile = "black"
|
line-length = 140
|
||||||
line_length = 140
|
target-version = "py312"
|
||||||
multi_line_output = 3
|
extend-exclude = [
|
||||||
include_trailing_comma = true
|
"examples/*",
|
||||||
force_grid_wrap = 0
|
"tests/data/*",
|
||||||
use_parentheses = true
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = [
|
||||||
|
"E", # pycodestyle errors
|
||||||
|
"W", # pycodestyle warnings
|
||||||
|
"F", # pyflakes
|
||||||
|
"I", # isort
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"E501", # line too long (handled by formatter)
|
||||||
|
"E402", # module import not at top of file
|
||||||
|
"E711", # none-comparison
|
||||||
|
"E712", # true-false-comparison
|
||||||
|
"E722", # bare except
|
||||||
|
"E721", # type comparison
|
||||||
|
"F401", # unused import
|
||||||
|
"F821", # undefined name
|
||||||
|
"F811", # redefined while unused
|
||||||
|
"F841", # local variable assigned but never used
|
||||||
|
"W293", # blank line contains whitespace
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
force-single-line = false
|
||||||
|
combine-as-imports = true
|
||||||
|
split-on-trailing-comma = true
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
quote-style = "double"
|
||||||
|
indent-style = "space"
|
||||||
|
skip-magic-trailing-comma = false
|
||||||
|
line-ending = "auto"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ def assert_invoked_send_message_with_keyword(messages: Sequence[LettaMessage], k
|
|||||||
# Message field not in send_message
|
# Message field not in send_message
|
||||||
if "message" not in arguments:
|
if "message" not in arguments:
|
||||||
raise InvalidToolCallError(
|
raise InvalidToolCallError(
|
||||||
messages=[target_message], explanation=f"send_message function call does not have required field `message`"
|
messages=[target_message], explanation="send_message function call does not have required field `message`"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that the keyword is in the message arguments
|
# Check that the keyword is in the message arguments
|
||||||
@@ -151,7 +151,7 @@ def assert_invoked_send_message_with_keyword(messages: Sequence[LettaMessage], k
|
|||||||
keyword = keyword.lower()
|
keyword = keyword.lower()
|
||||||
arguments["message"] = arguments["message"].lower()
|
arguments["message"] = arguments["message"].lower()
|
||||||
|
|
||||||
if not keyword in arguments["message"]:
|
if keyword not in arguments["message"]:
|
||||||
raise InvalidToolCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")
|
raise InvalidToolCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,7 @@ from letta.schemas.enums import MessageRole
|
|||||||
from letta.schemas.file import FileAgent
|
from letta.schemas.file import FileAgent
|
||||||
from letta.schemas.memory import ContextWindowOverview
|
from letta.schemas.memory import ContextWindowOverview
|
||||||
from letta.schemas.tool import Tool
|
from letta.schemas.tool import Tool
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User, User as PydanticUser
|
||||||
from letta.schemas.user import User as PydanticUser
|
|
||||||
from letta.server.rest_api.routers.v1.agents import ImportedAgentsResponse
|
from letta.server.rest_api.routers.v1.agents import ImportedAgentsResponse
|
||||||
from letta.server.server import SyncServer
|
from letta.server.server import SyncServer
|
||||||
|
|
||||||
@@ -66,7 +65,6 @@ def retry_until_success(max_attempts=10, sleep_time_seconds=4):
|
|||||||
def decorator_retry(func):
|
def decorator_retry(func):
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
|
|
||||||
for attempt in range(1, max_attempts + 1):
|
for attempt in range(1, max_attempts + 1):
|
||||||
try:
|
try:
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
@@ -124,32 +122,32 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up
|
|||||||
assert agent.llm_config == request.llm_config, f"LLM config mismatch: {agent.llm_config} != {request.llm_config}"
|
assert agent.llm_config == request.llm_config, f"LLM config mismatch: {agent.llm_config} != {request.llm_config}"
|
||||||
|
|
||||||
# Assert embedding configuration
|
# Assert embedding configuration
|
||||||
assert (
|
assert agent.embedding_config == request.embedding_config, (
|
||||||
agent.embedding_config == request.embedding_config
|
f"Embedding config mismatch: {agent.embedding_config} != {request.embedding_config}"
|
||||||
), f"Embedding config mismatch: {agent.embedding_config} != {request.embedding_config}"
|
)
|
||||||
|
|
||||||
# Assert memory blocks
|
# Assert memory blocks
|
||||||
if hasattr(request, "memory_blocks"):
|
if hasattr(request, "memory_blocks"):
|
||||||
assert len(agent.memory.blocks) == len(request.memory_blocks) + len(
|
assert len(agent.memory.blocks) == len(request.memory_blocks) + len(request.block_ids), (
|
||||||
request.block_ids
|
f"Memory blocks count mismatch: {len(agent.memory.blocks)} != {len(request.memory_blocks) + len(request.block_ids)}"
|
||||||
), f"Memory blocks count mismatch: {len(agent.memory.blocks)} != {len(request.memory_blocks) + len(request.block_ids)}"
|
)
|
||||||
memory_block_values = {block.value for block in agent.memory.blocks}
|
memory_block_values = {block.value for block in agent.memory.blocks}
|
||||||
expected_block_values = {block.value for block in request.memory_blocks}
|
expected_block_values = {block.value for block in request.memory_blocks}
|
||||||
assert expected_block_values.issubset(
|
assert expected_block_values.issubset(memory_block_values), (
|
||||||
memory_block_values
|
f"Memory blocks mismatch: {expected_block_values} not in {memory_block_values}"
|
||||||
), f"Memory blocks mismatch: {expected_block_values} not in {memory_block_values}"
|
)
|
||||||
|
|
||||||
# Assert tools
|
# Assert tools
|
||||||
assert len(agent.tools) == len(request.tool_ids), f"Tools count mismatch: {len(agent.tools)} != {len(request.tool_ids)}"
|
assert len(agent.tools) == len(request.tool_ids), f"Tools count mismatch: {len(agent.tools)} != {len(request.tool_ids)}"
|
||||||
assert {tool.id for tool in agent.tools} == set(
|
assert {tool.id for tool in agent.tools} == set(request.tool_ids), (
|
||||||
request.tool_ids
|
f"Tools mismatch: {set(tool.id for tool in agent.tools)} != {set(request.tool_ids)}"
|
||||||
), f"Tools mismatch: {set(tool.id for tool in agent.tools)} != {set(request.tool_ids)}"
|
)
|
||||||
|
|
||||||
# Assert sources
|
# Assert sources
|
||||||
assert len(agent.sources) == len(request.source_ids), f"Sources count mismatch: {len(agent.sources)} != {len(request.source_ids)}"
|
assert len(agent.sources) == len(request.source_ids), f"Sources count mismatch: {len(agent.sources)} != {len(request.source_ids)}"
|
||||||
assert {source.id for source in agent.sources} == set(
|
assert {source.id for source in agent.sources} == set(request.source_ids), (
|
||||||
request.source_ids
|
f"Sources mismatch: {set(source.id for source in agent.sources)} != {set(request.source_ids)}"
|
||||||
), f"Sources mismatch: {set(source.id for source in agent.sources)} != {set(request.source_ids)}"
|
)
|
||||||
|
|
||||||
# Assert tags
|
# Assert tags
|
||||||
assert set(agent.tags) == set(request.tags), f"Tags mismatch: {set(agent.tags)} != {set(request.tags)}"
|
assert set(agent.tags) == set(request.tags), f"Tags mismatch: {set(agent.tags)} != {set(request.tags)}"
|
||||||
@@ -158,15 +156,15 @@ def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, Up
|
|||||||
print("TOOLRULES", request.tool_rules)
|
print("TOOLRULES", request.tool_rules)
|
||||||
print("AGENTTOOLRULES", agent.tool_rules)
|
print("AGENTTOOLRULES", agent.tool_rules)
|
||||||
if request.tool_rules:
|
if request.tool_rules:
|
||||||
assert len(agent.tool_rules) == len(
|
assert len(agent.tool_rules) == len(request.tool_rules), (
|
||||||
request.tool_rules
|
f"Tool rules count mismatch: {len(agent.tool_rules)} != {len(request.tool_rules)}"
|
||||||
), f"Tool rules count mismatch: {len(agent.tool_rules)} != {len(request.tool_rules)}"
|
)
|
||||||
assert all(
|
assert all(any(rule.tool_name == req_rule.tool_name for rule in agent.tool_rules) for req_rule in request.tool_rules), (
|
||||||
any(rule.tool_name == req_rule.tool_name for rule in agent.tool_rules) for req_rule in request.tool_rules
|
f"Tool rules mismatch: {agent.tool_rules} != {request.tool_rules}"
|
||||||
), f"Tool rules mismatch: {agent.tool_rules} != {request.tool_rules}"
|
)
|
||||||
|
|
||||||
# Assert message_buffer_autoclear
|
# Assert message_buffer_autoclear
|
||||||
if not request.message_buffer_autoclear is None:
|
if request.message_buffer_autoclear is not None:
|
||||||
assert agent.message_buffer_autoclear == request.message_buffer_autoclear
|
assert agent.message_buffer_autoclear == request.message_buffer_autoclear
|
||||||
|
|
||||||
|
|
||||||
@@ -176,9 +174,9 @@ def validate_context_window_overview(
|
|||||||
"""Validate common sense assertions for ContextWindowOverview"""
|
"""Validate common sense assertions for ContextWindowOverview"""
|
||||||
|
|
||||||
# 1. Current context size should not exceed maximum
|
# 1. Current context size should not exceed maximum
|
||||||
assert (
|
assert overview.context_window_size_current <= overview.context_window_size_max, (
|
||||||
overview.context_window_size_current <= overview.context_window_size_max
|
f"Current context size ({overview.context_window_size_current}) exceeds maximum ({overview.context_window_size_max})"
|
||||||
), f"Current context size ({overview.context_window_size_current}) exceeds maximum ({overview.context_window_size_max})"
|
)
|
||||||
|
|
||||||
# 2. All token counts should be non-negative
|
# 2. All token counts should be non-negative
|
||||||
assert overview.num_tokens_system >= 0, "System token count cannot be negative"
|
assert overview.num_tokens_system >= 0, "System token count cannot be negative"
|
||||||
@@ -197,14 +195,14 @@ def validate_context_window_overview(
|
|||||||
+ overview.num_tokens_messages
|
+ overview.num_tokens_messages
|
||||||
+ overview.num_tokens_functions_definitions
|
+ overview.num_tokens_functions_definitions
|
||||||
)
|
)
|
||||||
assert (
|
assert overview.context_window_size_current == expected_total, (
|
||||||
overview.context_window_size_current == expected_total
|
f"Token sum ({expected_total}) doesn't match current size ({overview.context_window_size_current})"
|
||||||
), f"Token sum ({expected_total}) doesn't match current size ({overview.context_window_size_current})"
|
)
|
||||||
|
|
||||||
# 4. Message count should match messages list length
|
# 4. Message count should match messages list length
|
||||||
assert (
|
assert len(overview.messages) == overview.num_messages, (
|
||||||
len(overview.messages) == overview.num_messages
|
f"Messages list length ({len(overview.messages)}) doesn't match num_messages ({overview.num_messages})"
|
||||||
), f"Messages list length ({len(overview.messages)}) doesn't match num_messages ({overview.num_messages})"
|
)
|
||||||
|
|
||||||
# 5. If summary_memory is None, its token count should be 0
|
# 5. If summary_memory is None, its token count should be 0
|
||||||
if overview.summary_memory is None:
|
if overview.summary_memory is None:
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ def create_test_agent(name, actor, test_id: Optional[str] = None, model="anthrop
|
|||||||
model_endpoint_type="anthropic",
|
model_endpoint_type="anthropic",
|
||||||
model_endpoint="https://api.anthropic.com/v1",
|
model_endpoint="https://api.anthropic.com/v1",
|
||||||
context_window=32000,
|
context_window=32000,
|
||||||
handle=f"anthropic/claude-3-7-sonnet-latest",
|
handle="anthropic/claude-3-7-sonnet-latest",
|
||||||
put_inner_thoughts_in_kwargs=True,
|
put_inner_thoughts_in_kwargs=True,
|
||||||
max_tokens=4096,
|
max_tokens=4096,
|
||||||
)
|
)
|
||||||
@@ -193,7 +193,7 @@ async def create_test_batch_item(server, batch_id, agent_id, default_user):
|
|||||||
model_endpoint_type="anthropic",
|
model_endpoint_type="anthropic",
|
||||||
model_endpoint="https://api.anthropic.com/v1",
|
model_endpoint="https://api.anthropic.com/v1",
|
||||||
context_window=32000,
|
context_window=32000,
|
||||||
handle=f"anthropic/claude-3-7-sonnet-latest",
|
handle="anthropic/claude-3-7-sonnet-latest",
|
||||||
put_inner_thoughts_in_kwargs=True,
|
put_inner_thoughts_in_kwargs=True,
|
||||||
max_tokens=4096,
|
max_tokens=4096,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -219,7 +219,7 @@ def test_run_code(
|
|||||||
|
|
||||||
returns = [m.tool_return for m in tool_returns]
|
returns = [m.tool_return for m in tool_returns]
|
||||||
assert any(expected in ret for ret in returns), (
|
assert any(expected in ret for ret in returns), (
|
||||||
f"For language={language!r}, expected to find '{expected}' in tool_return, " f"but got {returns!r}"
|
f"For language={language!r}, expected to find '{expected}' in tool_return, but got {returns!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -357,7 +357,6 @@ async def test_web_search_uses_agent_env_var_model():
|
|||||||
patch.dict(os.environ, {WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"}),
|
patch.dict(os.environ, {WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"}),
|
||||||
patch("firecrawl.AsyncFirecrawlApp") as mock_firecrawl_class,
|
patch("firecrawl.AsyncFirecrawlApp") as mock_firecrawl_class,
|
||||||
):
|
):
|
||||||
|
|
||||||
# setup mocks
|
# setup mocks
|
||||||
mock_model_settings.openai_api_key = "test-key"
|
mock_model_settings.openai_api_key = "test-key"
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.schemas.enums import MessageStreamStatus
|
from letta.schemas.enums import MessageStreamStatus
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage as OpenAIUserMessage
|
||||||
from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage
|
|
||||||
from letta.schemas.tool import ToolCreate
|
from letta.schemas.tool import ToolCreate
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
from letta.services.tool_manager import ToolManager
|
from letta.services.tool_manager import ToolManager
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user