chore: enable F821, F401, W293 (#9503)
* auto fixes * auto fix pt2 and transitive deps and undefined var checking locals() * manual fixes (ignored or letta-code fixed) * fix circular import
This commit is contained in:
1
.github/scripts/model-sweep/conftest.py
vendored
1
.github/scripts/model-sweep/conftest.py
vendored
@@ -16,7 +16,6 @@ from letta.schemas.agent import AgentState
|
|||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.services.organization_manager import OrganizationManager
|
from letta.services.organization_manager import OrganizationManager
|
||||||
from letta.services.user_manager import UserManager
|
from letta.services.user_manager import UserManager
|
||||||
from letta.settings import tool_settings
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
|
|||||||
4
.github/scripts/model-sweep/model_sweep.py
vendored
4
.github/scripts/model-sweep/model_sweep.py
vendored
@@ -1,16 +1,12 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import socket
|
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from letta_client import Letta, MessageCreate, Run
|
from letta_client import Letta, MessageCreate, Run
|
||||||
from letta_client.core.api_error import ApiError
|
from letta_client.core.api_error import ApiError
|
||||||
from letta_client.types import (
|
from letta_client.types import (
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ Create Date: 2025-10-07 13:01:17.872405
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ Create Date: 2025-09-10 19:16:39.118760
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ Create Date: 2025-12-17 15:46:06.184858
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ Create Date: 2025-10-03 12:10:51.065067
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ Create Date: 2024-12-14 17:23:08.772554
|
|||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from pgvector.sqlalchemy import Vector
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ Create Date: 2025-09-19 10:58:19.658106
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ Create Date: 2025-10-06 13:17:09.918439
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ Create Date: 2025-11-11 19:16:00.000000
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
from letta.settings import settings
|
from letta.settings import settings
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ Create Date: 2025-12-07 15:30:43.407495
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ Create Date: 2025-11-11 21:16:00.000000
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
from letta.settings import settings
|
from letta.settings import settings
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# determine backfill value based on current pinecone settings
|
# determine backfill value based on current pinecone settings
|
||||||
try:
|
try:
|
||||||
from pinecone import IndexEmbed, PineconeAsyncio
|
from pinecone import IndexEmbed, PineconeAsyncio # noqa: F401
|
||||||
|
|
||||||
pinecone_available = True
|
pinecone_available = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@@ -10,8 +10,6 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
# Add the app directory to path to import our crypto utils
|
# Add the app directory to path to import our crypto utils
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ Create Date: 2025-11-07 15:43:59.446292
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
from letta.settings import settings
|
from letta.settings import settings
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ Create Date: 2025-10-04 00:44:06.663817
|
|||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
|||||||
@@ -16,26 +16,32 @@ try:
|
|||||||
from letta.settings import DatabaseChoice, settings
|
from letta.settings import DatabaseChoice, settings
|
||||||
|
|
||||||
if settings.database_engine == DatabaseChoice.SQLITE:
|
if settings.database_engine == DatabaseChoice.SQLITE:
|
||||||
from letta.orm import sqlite_functions
|
from letta.orm import sqlite_functions # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# If sqlite_vec is not installed, it's fine for client usage
|
# If sqlite_vec is not installed, it's fine for client usage
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# # imports for easier access
|
# # imports for easier access
|
||||||
from letta.schemas.agent import AgentState
|
from letta.schemas.agent import AgentState as AgentState
|
||||||
from letta.schemas.block import Block
|
from letta.schemas.block import Block as Block
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig as EmbeddingConfig
|
||||||
from letta.schemas.enums import JobStatus
|
from letta.schemas.enums import JobStatus as JobStatus
|
||||||
from letta.schemas.file import FileMetadata
|
from letta.schemas.file import FileMetadata as FileMetadata
|
||||||
from letta.schemas.job import Job
|
from letta.schemas.job import Job as Job
|
||||||
from letta.schemas.letta_message import LettaMessage, LettaPing
|
from letta.schemas.letta_message import LettaErrorMessage as LettaErrorMessage, LettaMessage as LettaMessage, LettaPing as LettaPing
|
||||||
from letta.schemas.letta_stop_reason import LettaStopReason
|
from letta.schemas.letta_stop_reason import LettaStopReason as LettaStopReason
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig as LLMConfig
|
||||||
from letta.schemas.memory import ArchivalMemorySummary, BasicBlockMemory, ChatMemory, Memory, RecallMemorySummary
|
from letta.schemas.memory import (
|
||||||
from letta.schemas.message import Message
|
ArchivalMemorySummary as ArchivalMemorySummary,
|
||||||
from letta.schemas.organization import Organization
|
BasicBlockMemory as BasicBlockMemory,
|
||||||
from letta.schemas.passage import Passage
|
ChatMemory as ChatMemory,
|
||||||
from letta.schemas.source import Source
|
Memory as Memory,
|
||||||
from letta.schemas.tool import Tool
|
RecallMemorySummary as RecallMemorySummary,
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
)
|
||||||
from letta.schemas.user import User
|
from letta.schemas.message import Message as Message
|
||||||
|
from letta.schemas.organization import Organization as Organization
|
||||||
|
from letta.schemas.passage import Passage as Passage
|
||||||
|
from letta.schemas.source import Source as Source
|
||||||
|
from letta.schemas.tool import Tool as Tool
|
||||||
|
from letta.schemas.usage import LettaUsageStatistics as LettaUsageStatistics
|
||||||
|
from letta.schemas.user import User as User
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import AsyncGenerator, Optional
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from letta.llm_api.llm_client_base import LLMClientBase
|
from letta.llm_api.llm_client_base import LLMClientBase
|
||||||
from letta.schemas.enums import LLMCallType
|
from letta.schemas.enums import LLMCallType
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import AsyncGenerator
|
|||||||
|
|
||||||
from letta.adapters.letta_llm_adapter import LettaLLMAdapter
|
from letta.adapters.letta_llm_adapter import LettaLLMAdapter
|
||||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||||
from letta.otel.tracing import log_attributes, log_event, safe_json_dumps, trace_method
|
from letta.otel.tracing import log_attributes, safe_json_dumps, trace_method
|
||||||
from letta.schemas.letta_message import LettaMessage
|
from letta.schemas.letta_message import LettaMessage
|
||||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent
|
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent
|
||||||
from letta.schemas.provider_trace import ProviderTrace
|
from letta.schemas.provider_trace import ProviderTrace
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from letta.schemas.enums import LLMCallType, ProviderType
|
|||||||
from letta.schemas.letta_message import LettaMessage
|
from letta.schemas.letta_message import LettaMessage
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.provider_trace import ProviderTrace
|
from letta.schemas.provider_trace import ProviderTrace
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.settings import settings
|
from letta.settings import settings
|
||||||
from letta.utils import safe_create_task
|
from letta.utils import safe_create_task
|
||||||
|
|||||||
@@ -19,18 +19,17 @@ from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
|||||||
from letta.llm_api.sglang_native_client import SGLangNativeClient
|
from letta.llm_api.sglang_native_client import SGLangNativeClient
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.schemas.letta_message import LettaMessage
|
from letta.schemas.letta_message import LettaMessage
|
||||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, TextContent
|
from letta.schemas.letta_message_content import TextContent
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import (
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
|
ChatCompletionTokenLogprob,
|
||||||
Choice,
|
Choice,
|
||||||
ChoiceLogprobs,
|
ChoiceLogprobs,
|
||||||
ChatCompletionTokenLogprob,
|
|
||||||
FunctionCall,
|
FunctionCall,
|
||||||
Message as ChoiceMessage,
|
Message as ChoiceMessage,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
UsageStatistics,
|
UsageStatistics,
|
||||||
)
|
)
|
||||||
from letta.schemas.usage import normalize_cache_tokens, normalize_reasoning_tokens
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -41,37 +40,38 @@ _tokenizer_cache: dict[str, Any] = {}
|
|||||||
class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||||
"""
|
"""
|
||||||
Adapter that uses SGLang's native /generate endpoint for multi-turn RL training.
|
Adapter that uses SGLang's native /generate endpoint for multi-turn RL training.
|
||||||
|
|
||||||
Key differences from SimpleLLMRequestAdapter:
|
Key differences from SimpleLLMRequestAdapter:
|
||||||
- Uses /generate instead of /v1/chat/completions
|
- Uses /generate instead of /v1/chat/completions
|
||||||
- Returns output_ids (token IDs) in addition to text
|
- Returns output_ids (token IDs) in addition to text
|
||||||
- Returns output_token_logprobs with [logprob, token_id] pairs
|
- Returns output_token_logprobs with [logprob, token_id] pairs
|
||||||
- Formats tools into prompt and parses tool calls from response
|
- Formats tools into prompt and parses tool calls from response
|
||||||
|
|
||||||
These are essential for building accurate loss masks in multi-turn training.
|
These are essential for building accurate loss masks in multi-turn training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._sglang_client: Optional[SGLangNativeClient] = None
|
self._sglang_client: Optional[SGLangNativeClient] = None
|
||||||
self._tokenizer: Any = None
|
self._tokenizer: Any = None
|
||||||
|
|
||||||
def _get_tokenizer(self) -> Any:
|
def _get_tokenizer(self) -> Any:
|
||||||
"""Get or create tokenizer for the model."""
|
"""Get or create tokenizer for the model."""
|
||||||
global _tokenizer_cache
|
global _tokenizer_cache
|
||||||
|
|
||||||
# Get model name from llm_config
|
# Get model name from llm_config
|
||||||
model_name = self.llm_config.model
|
model_name = self.llm_config.model
|
||||||
if not model_name:
|
if not model_name:
|
||||||
logger.warning("No model name in llm_config, cannot load tokenizer")
|
logger.warning("No model name in llm_config, cannot load tokenizer")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check cache
|
# Check cache
|
||||||
if model_name in _tokenizer_cache:
|
if model_name in _tokenizer_cache:
|
||||||
return _tokenizer_cache[model_name]
|
return _tokenizer_cache[model_name]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
logger.info(f"Loading tokenizer for model: {model_name}")
|
logger.info(f"Loading tokenizer for model: {model_name}")
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||||
_tokenizer_cache[model_name] = tokenizer
|
_tokenizer_cache[model_name] = tokenizer
|
||||||
@@ -82,7 +82,7 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load tokenizer: {e}, falling back to manual formatting")
|
logger.warning(f"Failed to load tokenizer: {e}, falling back to manual formatting")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_sglang_client(self) -> SGLangNativeClient:
|
def _get_sglang_client(self) -> SGLangNativeClient:
|
||||||
"""Get or create SGLang native client."""
|
"""Get or create SGLang native client."""
|
||||||
if self._sglang_client is None:
|
if self._sglang_client is None:
|
||||||
@@ -94,17 +94,17 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
api_key=None,
|
api_key=None,
|
||||||
)
|
)
|
||||||
return self._sglang_client
|
return self._sglang_client
|
||||||
|
|
||||||
def _format_tools_for_prompt(self, tools: list) -> str:
|
def _format_tools_for_prompt(self, tools: list) -> str:
|
||||||
"""
|
"""
|
||||||
Format tools in Qwen3 chat template format for the system prompt.
|
Format tools in Qwen3 chat template format for the system prompt.
|
||||||
|
|
||||||
This matches the exact format produced by Qwen3's tokenizer.apply_chat_template()
|
This matches the exact format produced by Qwen3's tokenizer.apply_chat_template()
|
||||||
with tools parameter.
|
with tools parameter.
|
||||||
"""
|
"""
|
||||||
if not tools:
|
if not tools:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Format each tool as JSON (matching Qwen3 template exactly)
|
# Format each tool as JSON (matching Qwen3 template exactly)
|
||||||
tool_jsons = []
|
tool_jsons = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
@@ -120,84 +120,85 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
"name": getattr(getattr(tool, "function", tool), "name", ""),
|
"name": getattr(getattr(tool, "function", tool), "name", ""),
|
||||||
"description": getattr(getattr(tool, "function", tool), "description", ""),
|
"description": getattr(getattr(tool, "function", tool), "description", ""),
|
||||||
"parameters": getattr(getattr(tool, "function", tool), "parameters", {}),
|
"parameters": getattr(getattr(tool, "function", tool), "parameters", {}),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
tool_jsons.append(json.dumps(tool_dict))
|
tool_jsons.append(json.dumps(tool_dict))
|
||||||
|
|
||||||
# Use exact Qwen3 format
|
# Use exact Qwen3 format
|
||||||
tools_section = (
|
tools_section = (
|
||||||
"\n\n# Tools\n\n"
|
"\n\n# Tools\n\n"
|
||||||
"You may call one or more functions to assist with the user query.\n\n"
|
"You may call one or more functions to assist with the user query.\n\n"
|
||||||
"You are provided with function signatures within <tools></tools> XML tags:\n"
|
"You are provided with function signatures within <tools></tools> XML tags:\n"
|
||||||
"<tools>\n"
|
"<tools>\n" + "\n".join(tool_jsons) + "\n"
|
||||||
+ "\n".join(tool_jsons) + "\n"
|
|
||||||
"</tools>\n\n"
|
"</tools>\n\n"
|
||||||
"For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
|
"For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
|
||||||
"<tool_call>\n"
|
"<tool_call>\n"
|
||||||
'{"name": <function-name>, "arguments": <args-json-object>}\n'
|
'{"name": <function-name>, "arguments": <args-json-object>}\n'
|
||||||
"</tool_call>"
|
"</tool_call>"
|
||||||
)
|
)
|
||||||
|
|
||||||
return tools_section
|
return tools_section
|
||||||
|
|
||||||
def _convert_messages_to_openai_format(self, messages: list) -> list[dict]:
|
def _convert_messages_to_openai_format(self, messages: list) -> list[dict]:
|
||||||
"""Convert Letta Message objects to OpenAI-style message dicts."""
|
"""Convert Letta Message objects to OpenAI-style message dicts."""
|
||||||
openai_messages = []
|
openai_messages = []
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
# Handle both dict and Pydantic Message objects
|
# Handle both dict and Pydantic Message objects
|
||||||
if hasattr(msg, 'role'):
|
if hasattr(msg, "role"):
|
||||||
role = msg.role
|
role = msg.role
|
||||||
content = msg.content if hasattr(msg, 'content') else ""
|
content = msg.content if hasattr(msg, "content") else ""
|
||||||
# Handle content that might be a list of content parts
|
# Handle content that might be a list of content parts
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
content = " ".join([c.text if hasattr(c, 'text') else str(c) for c in content])
|
content = " ".join([c.text if hasattr(c, "text") else str(c) for c in content])
|
||||||
elif content is None:
|
elif content is None:
|
||||||
content = ""
|
content = ""
|
||||||
tool_calls = getattr(msg, 'tool_calls', None)
|
tool_calls = getattr(msg, "tool_calls", None)
|
||||||
tool_call_id = getattr(msg, 'tool_call_id', None)
|
tool_call_id = getattr(msg, "tool_call_id", None)
|
||||||
name = getattr(msg, 'name', None)
|
name = getattr(msg, "name", None)
|
||||||
else:
|
else:
|
||||||
role = msg.get("role", "user")
|
role = msg.get("role", "user")
|
||||||
content = msg.get("content", "")
|
content = msg.get("content", "")
|
||||||
tool_calls = msg.get("tool_calls", None)
|
tool_calls = msg.get("tool_calls", None)
|
||||||
tool_call_id = msg.get("tool_call_id", None)
|
tool_call_id = msg.get("tool_call_id", None)
|
||||||
name = msg.get("name", None)
|
name = msg.get("name", None)
|
||||||
|
|
||||||
openai_msg = {"role": role, "content": content}
|
openai_msg = {"role": role, "content": content}
|
||||||
|
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
# Convert tool calls to OpenAI format
|
# Convert tool calls to OpenAI format
|
||||||
openai_tool_calls = []
|
openai_tool_calls = []
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
if hasattr(tc, 'function'):
|
if hasattr(tc, "function"):
|
||||||
tc_dict = {
|
tc_dict = {
|
||||||
"id": getattr(tc, 'id', f"call_{uuid.uuid4().hex[:8]}"),
|
"id": getattr(tc, "id", f"call_{uuid.uuid4().hex[:8]}"),
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": tc.function.name,
|
"name": tc.function.name,
|
||||||
"arguments": tc.function.arguments if isinstance(tc.function.arguments, str) else json.dumps(tc.function.arguments)
|
"arguments": tc.function.arguments
|
||||||
}
|
if isinstance(tc.function.arguments, str)
|
||||||
|
else json.dumps(tc.function.arguments),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
tc_dict = {
|
tc_dict = {
|
||||||
"id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
|
"id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": tc.get("function", {})
|
"function": tc.get("function", {}),
|
||||||
}
|
}
|
||||||
openai_tool_calls.append(tc_dict)
|
openai_tool_calls.append(tc_dict)
|
||||||
openai_msg["tool_calls"] = openai_tool_calls
|
openai_msg["tool_calls"] = openai_tool_calls
|
||||||
|
|
||||||
if tool_call_id:
|
if tool_call_id:
|
||||||
openai_msg["tool_call_id"] = tool_call_id
|
openai_msg["tool_call_id"] = tool_call_id
|
||||||
|
|
||||||
if name and role == "tool":
|
if name and role == "tool":
|
||||||
openai_msg["name"] = name
|
openai_msg["name"] = name
|
||||||
|
|
||||||
openai_messages.append(openai_msg)
|
openai_messages.append(openai_msg)
|
||||||
|
|
||||||
return openai_messages
|
return openai_messages
|
||||||
|
|
||||||
def _convert_tools_to_openai_format(self, tools: list) -> list[dict]:
|
def _convert_tools_to_openai_format(self, tools: list) -> list[dict]:
|
||||||
"""Convert tools to OpenAI format for tokenizer."""
|
"""Convert tools to OpenAI format for tokenizer."""
|
||||||
openai_tools = []
|
openai_tools = []
|
||||||
@@ -218,24 +219,24 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
"name": getattr(func, "name", ""),
|
"name": getattr(func, "name", ""),
|
||||||
"description": getattr(func, "description", ""),
|
"description": getattr(func, "description", ""),
|
||||||
"parameters": getattr(func, "parameters", {}),
|
"parameters": getattr(func, "parameters", {}),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
openai_tools.append(tool_dict)
|
openai_tools.append(tool_dict)
|
||||||
return openai_tools
|
return openai_tools
|
||||||
|
|
||||||
def _format_messages_to_text(self, messages: list, tools: list) -> str:
|
def _format_messages_to_text(self, messages: list, tools: list) -> str:
|
||||||
"""
|
"""
|
||||||
Format messages to text using tokenizer's apply_chat_template if available.
|
Format messages to text using tokenizer's apply_chat_template if available.
|
||||||
|
|
||||||
Falls back to manual formatting if tokenizer is not available.
|
Falls back to manual formatting if tokenizer is not available.
|
||||||
"""
|
"""
|
||||||
tokenizer = self._get_tokenizer()
|
tokenizer = self._get_tokenizer()
|
||||||
|
|
||||||
if tokenizer is not None:
|
if tokenizer is not None:
|
||||||
# Use tokenizer's apply_chat_template for proper formatting
|
# Use tokenizer's apply_chat_template for proper formatting
|
||||||
openai_messages = self._convert_messages_to_openai_format(messages)
|
openai_messages = self._convert_messages_to_openai_format(messages)
|
||||||
openai_tools = self._convert_tools_to_openai_format(tools) if tools else None
|
openai_tools = self._convert_tools_to_openai_format(tools) if tools else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
formatted = tokenizer.apply_chat_template(
|
formatted = tokenizer.apply_chat_template(
|
||||||
openai_messages,
|
openai_messages,
|
||||||
@@ -247,30 +248,30 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
return formatted
|
return formatted
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"apply_chat_template failed: {e}, falling back to manual formatting")
|
logger.warning(f"apply_chat_template failed: {e}, falling back to manual formatting")
|
||||||
|
|
||||||
# Fallback to manual formatting
|
# Fallback to manual formatting
|
||||||
return self._format_messages_to_text_manual(messages, tools)
|
return self._format_messages_to_text_manual(messages, tools)
|
||||||
|
|
||||||
def _format_messages_to_text_manual(self, messages: list, tools: list) -> str:
|
def _format_messages_to_text_manual(self, messages: list, tools: list) -> str:
|
||||||
"""Manual fallback formatting for when tokenizer is not available."""
|
"""Manual fallback formatting for when tokenizer is not available."""
|
||||||
formatted_parts = []
|
formatted_parts = []
|
||||||
tools_section = self._format_tools_for_prompt(tools)
|
tools_section = self._format_tools_for_prompt(tools)
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
# Handle both dict and Pydantic Message objects
|
# Handle both dict and Pydantic Message objects
|
||||||
if hasattr(msg, 'role'):
|
if hasattr(msg, "role"):
|
||||||
role = msg.role
|
role = msg.role
|
||||||
content = msg.content if hasattr(msg, 'content') else ""
|
content = msg.content if hasattr(msg, "content") else ""
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
content = " ".join([c.text if hasattr(c, 'text') else str(c) for c in content])
|
content = " ".join([c.text if hasattr(c, "text") else str(c) for c in content])
|
||||||
elif content is None:
|
elif content is None:
|
||||||
content = ""
|
content = ""
|
||||||
tool_calls = getattr(msg, 'tool_calls', None)
|
tool_calls = getattr(msg, "tool_calls", None)
|
||||||
else:
|
else:
|
||||||
role = msg.get("role", "user")
|
role = msg.get("role", "user")
|
||||||
content = msg.get("content", "")
|
content = msg.get("content", "")
|
||||||
tool_calls = msg.get("tool_calls", None)
|
tool_calls = msg.get("tool_calls", None)
|
||||||
|
|
||||||
if role == "system":
|
if role == "system":
|
||||||
system_content = content + tools_section if tools_section else content
|
system_content = content + tools_section if tools_section else content
|
||||||
formatted_parts.append(f"<|im_start|>system\n{system_content}<|im_end|>")
|
formatted_parts.append(f"<|im_start|>system\n{system_content}<|im_end|>")
|
||||||
@@ -281,62 +282,55 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
if tool_calls:
|
if tool_calls:
|
||||||
tc_parts = []
|
tc_parts = []
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
if hasattr(tc, 'function'):
|
if hasattr(tc, "function"):
|
||||||
tc_name = tc.function.name
|
tc_name = tc.function.name
|
||||||
tc_args = tc.function.arguments
|
tc_args = tc.function.arguments
|
||||||
else:
|
else:
|
||||||
tc_name = tc.get("function", {}).get("name", "")
|
tc_name = tc.get("function", {}).get("name", "")
|
||||||
tc_args = tc.get("function", {}).get("arguments", "{}")
|
tc_args = tc.get("function", {}).get("arguments", "{}")
|
||||||
|
|
||||||
if isinstance(tc_args, str):
|
if isinstance(tc_args, str):
|
||||||
try:
|
try:
|
||||||
tc_args = json.loads(tc_args)
|
tc_args = json.loads(tc_args)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
tc_parts.append(
|
tc_parts.append(f'<tool_call>\n{{"name": "{tc_name}", "arguments": {json.dumps(tc_args)}}}\n</tool_call>')
|
||||||
f"<tool_call>\n"
|
|
||||||
f'{{"name": "{tc_name}", "arguments": {json.dumps(tc_args)}}}\n'
|
|
||||||
f"</tool_call>"
|
|
||||||
)
|
|
||||||
|
|
||||||
assistant_content = content + "\n" + "\n".join(tc_parts) if content else "\n".join(tc_parts)
|
assistant_content = content + "\n" + "\n".join(tc_parts) if content else "\n".join(tc_parts)
|
||||||
formatted_parts.append(f"<|im_start|>assistant\n{assistant_content}<|im_end|>")
|
formatted_parts.append(f"<|im_start|>assistant\n{assistant_content}<|im_end|>")
|
||||||
elif content:
|
elif content:
|
||||||
formatted_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
|
formatted_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
|
||||||
elif role == "tool":
|
elif role == "tool":
|
||||||
formatted_parts.append(
|
formatted_parts.append(f"<|im_start|>user\n<tool_response>\n{content}\n</tool_response><|im_end|>")
|
||||||
f"<|im_start|>user\n"
|
|
||||||
f"<tool_response>\n{content}\n</tool_response><|im_end|>"
|
|
||||||
)
|
|
||||||
|
|
||||||
formatted_parts.append("<|im_start|>assistant\n")
|
formatted_parts.append("<|im_start|>assistant\n")
|
||||||
return "\n".join(formatted_parts)
|
return "\n".join(formatted_parts)
|
||||||
|
|
||||||
def _parse_tool_calls(self, text: str) -> list[ToolCall]:
|
def _parse_tool_calls(self, text: str) -> list[ToolCall]:
|
||||||
"""
|
"""
|
||||||
Parse tool calls from response text.
|
Parse tool calls from response text.
|
||||||
|
|
||||||
Looks for patterns like:
|
Looks for patterns like:
|
||||||
<tool_call>
|
<tool_call>
|
||||||
{"name": "tool_name", "arguments": {...}}
|
{"name": "tool_name", "arguments": {...}}
|
||||||
</tool_call>
|
</tool_call>
|
||||||
"""
|
"""
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
|
|
||||||
# Find all tool_call blocks
|
# Find all tool_call blocks
|
||||||
pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
|
pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
|
||||||
matches = re.findall(pattern, text, re.DOTALL)
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
|
||||||
for match in matches:
|
for match in matches:
|
||||||
try:
|
try:
|
||||||
tc_data = json.loads(match)
|
tc_data = json.loads(match)
|
||||||
name = tc_data.get("name", "")
|
name = tc_data.get("name", "")
|
||||||
arguments = tc_data.get("arguments", {})
|
arguments = tc_data.get("arguments", {})
|
||||||
|
|
||||||
if isinstance(arguments, dict):
|
if isinstance(arguments, dict):
|
||||||
arguments = json.dumps(arguments)
|
arguments = json.dumps(arguments)
|
||||||
|
|
||||||
tool_call = ToolCall(
|
tool_call = ToolCall(
|
||||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||||
type="function",
|
type="function",
|
||||||
@@ -349,17 +343,17 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning(f"Failed to parse tool call JSON: {e}")
|
logger.warning(f"Failed to parse tool call JSON: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return tool_calls
|
return tool_calls
|
||||||
|
|
||||||
def _extract_content_without_tool_calls(self, text: str) -> str:
|
def _extract_content_without_tool_calls(self, text: str) -> str:
|
||||||
"""Extract content from response, removing tool_call blocks."""
|
"""Extract content from response, removing tool_call blocks."""
|
||||||
# Remove tool_call blocks
|
# Remove tool_call blocks
|
||||||
cleaned = re.sub(r'<tool_call>.*?</tool_call>', '', text, flags=re.DOTALL)
|
cleaned = re.sub(r"<tool_call>.*?</tool_call>", "", text, flags=re.DOTALL)
|
||||||
# Clean up whitespace
|
# Clean up whitespace
|
||||||
cleaned = cleaned.strip()
|
cleaned = cleaned.strip()
|
||||||
return cleaned
|
return cleaned
|
||||||
|
|
||||||
async def invoke_llm(
|
async def invoke_llm(
|
||||||
self,
|
self,
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
@@ -372,7 +366,7 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
) -> AsyncGenerator[LettaMessage | None, None]:
|
) -> AsyncGenerator[LettaMessage | None, None]:
|
||||||
"""
|
"""
|
||||||
Execute LLM request using SGLang native endpoint.
|
Execute LLM request using SGLang native endpoint.
|
||||||
|
|
||||||
This method:
|
This method:
|
||||||
1. Formats messages and tools to text using chat template
|
1. Formats messages and tools to text using chat template
|
||||||
2. Calls SGLang native /generate endpoint
|
2. Calls SGLang native /generate endpoint
|
||||||
@@ -381,20 +375,20 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
5. Converts response to standard format
|
5. Converts response to standard format
|
||||||
"""
|
"""
|
||||||
self.request_data = request_data
|
self.request_data = request_data
|
||||||
|
|
||||||
# Get sampling params from request_data
|
# Get sampling params from request_data
|
||||||
sampling_params = {
|
sampling_params = {
|
||||||
"temperature": request_data.get("temperature", 0.7),
|
"temperature": request_data.get("temperature", 0.7),
|
||||||
"max_new_tokens": request_data.get("max_tokens", 4096),
|
"max_new_tokens": request_data.get("max_tokens", 4096),
|
||||||
"top_p": request_data.get("top_p", 0.9),
|
"top_p": request_data.get("top_p", 0.9),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Format messages to text (includes tools in prompt)
|
# Format messages to text (includes tools in prompt)
|
||||||
text_input = self._format_messages_to_text(messages, tools)
|
text_input = self._format_messages_to_text(messages, tools)
|
||||||
|
|
||||||
# Call SGLang native endpoint
|
# Call SGLang native endpoint
|
||||||
client = self._get_sglang_client()
|
client = self._get_sglang_client()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.generate(
|
response = await client.generate(
|
||||||
text=text_input,
|
text=text_input,
|
||||||
@@ -404,31 +398,31 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"SGLang native endpoint error: {e}")
|
logger.error(f"SGLang native endpoint error: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||||
|
|
||||||
# Store native response data
|
# Store native response data
|
||||||
self.response_data = response
|
self.response_data = response
|
||||||
|
|
||||||
# Extract SGLang native data
|
# Extract SGLang native data
|
||||||
self.output_ids = response.get("output_ids")
|
self.output_ids = response.get("output_ids")
|
||||||
# output_token_logprobs is inside meta_info
|
# output_token_logprobs is inside meta_info
|
||||||
meta_info = response.get("meta_info", {})
|
meta_info = response.get("meta_info", {})
|
||||||
self.output_token_logprobs = meta_info.get("output_token_logprobs")
|
self.output_token_logprobs = meta_info.get("output_token_logprobs")
|
||||||
|
|
||||||
# Extract text response
|
# Extract text response
|
||||||
text_response = response.get("text", "")
|
text_response = response.get("text", "")
|
||||||
|
|
||||||
# Remove trailing end token if present
|
# Remove trailing end token if present
|
||||||
if text_response.endswith("<|im_end|>"):
|
if text_response.endswith("<|im_end|>"):
|
||||||
text_response = text_response[:-10]
|
text_response = text_response[:-10]
|
||||||
|
|
||||||
# Parse tool calls from response
|
# Parse tool calls from response
|
||||||
parsed_tool_calls = self._parse_tool_calls(text_response)
|
parsed_tool_calls = self._parse_tool_calls(text_response)
|
||||||
|
|
||||||
# Extract content (text without tool_call blocks)
|
# Extract content (text without tool_call blocks)
|
||||||
content_text = self._extract_content_without_tool_calls(text_response)
|
content_text = self._extract_content_without_tool_calls(text_response)
|
||||||
|
|
||||||
# Determine finish reason
|
# Determine finish reason
|
||||||
meta_info = response.get("meta_info", {})
|
meta_info = response.get("meta_info", {})
|
||||||
finish_reason_info = meta_info.get("finish_reason", {})
|
finish_reason_info = meta_info.get("finish_reason", {})
|
||||||
@@ -436,11 +430,11 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
finish_reason = finish_reason_info.get("type", "stop")
|
finish_reason = finish_reason_info.get("type", "stop")
|
||||||
else:
|
else:
|
||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
|
|
||||||
# If we have tool calls, set finish_reason to tool_calls
|
# If we have tool calls, set finish_reason to tool_calls
|
||||||
if parsed_tool_calls:
|
if parsed_tool_calls:
|
||||||
finish_reason = "tool_calls"
|
finish_reason = "tool_calls"
|
||||||
|
|
||||||
# Convert to standard ChatCompletionResponse format for compatibility
|
# Convert to standard ChatCompletionResponse format for compatibility
|
||||||
# Build logprobs in OpenAI format from SGLang format
|
# Build logprobs in OpenAI format from SGLang format
|
||||||
logprobs_content = None
|
logprobs_content = None
|
||||||
@@ -458,13 +452,13 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
top_logprobs=[],
|
top_logprobs=[],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
choice_logprobs = ChoiceLogprobs(content=logprobs_content) if logprobs_content else None
|
choice_logprobs = ChoiceLogprobs(content=logprobs_content) if logprobs_content else None
|
||||||
|
|
||||||
# Build chat completion response
|
# Build chat completion response
|
||||||
prompt_tokens = meta_info.get("prompt_tokens", 0)
|
prompt_tokens = meta_info.get("prompt_tokens", 0)
|
||||||
completion_tokens = len(self.output_ids) if self.output_ids else 0
|
completion_tokens = len(self.output_ids) if self.output_ids else 0
|
||||||
|
|
||||||
self.chat_completions_response = ChatCompletionResponse(
|
self.chat_completions_response = ChatCompletionResponse(
|
||||||
id=meta_info.get("id", "sglang-native"),
|
id=meta_info.get("id", "sglang-native"),
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
@@ -486,36 +480,36 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
|||||||
total_tokens=prompt_tokens + completion_tokens,
|
total_tokens=prompt_tokens + completion_tokens,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract content
|
# Extract content
|
||||||
if content_text:
|
if content_text:
|
||||||
self.content = [TextContent(text=content_text)]
|
self.content = [TextContent(text=content_text)]
|
||||||
else:
|
else:
|
||||||
self.content = None
|
self.content = None
|
||||||
|
|
||||||
# No reasoning content from native endpoint
|
# No reasoning content from native endpoint
|
||||||
self.reasoning_content = None
|
self.reasoning_content = None
|
||||||
|
|
||||||
# Set tool calls
|
# Set tool calls
|
||||||
self.tool_calls = parsed_tool_calls
|
self.tool_calls = parsed_tool_calls
|
||||||
self.tool_call = parsed_tool_calls[0] if parsed_tool_calls else None
|
self.tool_call = parsed_tool_calls[0] if parsed_tool_calls else None
|
||||||
|
|
||||||
# Set logprobs
|
# Set logprobs
|
||||||
self.logprobs = choice_logprobs
|
self.logprobs = choice_logprobs
|
||||||
|
|
||||||
# Extract usage statistics
|
# Extract usage statistics
|
||||||
self.usage.step_count = 1
|
self.usage.step_count = 1
|
||||||
self.usage.completion_tokens = completion_tokens
|
self.usage.completion_tokens = completion_tokens
|
||||||
self.usage.prompt_tokens = prompt_tokens
|
self.usage.prompt_tokens = prompt_tokens
|
||||||
self.usage.total_tokens = prompt_tokens + completion_tokens
|
self.usage.total_tokens = prompt_tokens + completion_tokens
|
||||||
|
|
||||||
self.log_provider_trace(step_id=step_id, actor=actor)
|
self.log_provider_trace(step_id=step_id, actor=actor)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"SGLang native response: {len(self.output_ids or [])} tokens, "
|
f"SGLang native response: {len(self.output_ids or [])} tokens, "
|
||||||
f"{len(self.output_token_logprobs or [])} logprobs, "
|
f"{len(self.output_token_logprobs or [])} logprobs, "
|
||||||
f"{len(parsed_tool_calls)} tool calls"
|
f"{len(parsed_tool_calls)} tool calls"
|
||||||
)
|
)
|
||||||
|
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
from typing import AsyncGenerator, List
|
from typing import AsyncGenerator, List
|
||||||
|
|
||||||
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
@@ -15,7 +14,7 @@ from letta.schemas.letta_message import MessageType
|
|||||||
from letta.schemas.letta_message_content import TextContent
|
from letta.schemas.letta_message_content import TextContent
|
||||||
from letta.schemas.letta_response import LettaResponse
|
from letta.schemas.letta_response import LettaResponse
|
||||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||||
from letta.schemas.message import ApprovalCreate, Message, MessageCreate, MessageCreateBase, ToolReturnCreate
|
from letta.schemas.message import ApprovalCreate, Message, MessageCreate, MessageCreateBase
|
||||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
@@ -463,7 +462,7 @@ def _schema_accepts_value(prop_schema: Dict[str, Any], value: Any) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def merge_and_validate_prefilled_args(tool: "Tool", llm_args: Dict[str, Any], prefilled_args: Dict[str, Any]) -> Dict[str, Any]:
|
def merge_and_validate_prefilled_args(tool: "Tool", llm_args: Dict[str, Any], prefilled_args: Dict[str, Any]) -> Dict[str, Any]: # noqa: F821
|
||||||
"""Merge LLM-provided args with prefilled args from tool rules.
|
"""Merge LLM-provided args with prefilled args from tool rules.
|
||||||
|
|
||||||
- Overlapping keys are replaced by prefilled values (prefilled wins).
|
- Overlapping keys are replaced by prefilled values (prefilled wins).
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
|
|||||||
from letta.agents.helpers import (
|
from letta.agents.helpers import (
|
||||||
_build_rule_violation_result,
|
_build_rule_violation_result,
|
||||||
_create_letta_response,
|
_create_letta_response,
|
||||||
_load_last_function_response,
|
|
||||||
_pop_heartbeat,
|
_pop_heartbeat,
|
||||||
_prepare_in_context_messages_no_persist_async,
|
_prepare_in_context_messages_no_persist_async,
|
||||||
_safe_load_tool_call_str,
|
_safe_load_tool_call_str,
|
||||||
@@ -293,6 +292,7 @@ class LettaAgent(BaseAgent):
|
|||||||
agent_step_span.set_attributes({"step_id": step_id})
|
agent_step_span.set_attributes({"step_id": step_id})
|
||||||
|
|
||||||
step_progression = StepProgression.START
|
step_progression = StepProgression.START
|
||||||
|
caught_exception = None
|
||||||
should_continue = False
|
should_continue = False
|
||||||
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
||||||
|
|
||||||
@@ -439,6 +439,7 @@ class LettaAgent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
caught_exception = e
|
||||||
# Handle any unexpected errors during step processing
|
# Handle any unexpected errors during step processing
|
||||||
self.logger.error(f"Error during step processing: {e}")
|
self.logger.error(f"Error during step processing: {e}")
|
||||||
job_update_metadata = {"error": str(e)}
|
job_update_metadata = {"error": str(e)}
|
||||||
@@ -485,8 +486,8 @@ class LettaAgent(BaseAgent):
|
|||||||
await self.step_manager.update_step_error_async(
|
await self.step_manager.update_step_error_async(
|
||||||
actor=self.actor,
|
actor=self.actor,
|
||||||
step_id=step_id, # Use original step_id for telemetry
|
step_id=step_id, # Use original step_id for telemetry
|
||||||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
|
||||||
error_message=str(e) if "e" in locals() else "Unknown error",
|
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
|
||||||
error_traceback=traceback.format_exc(),
|
error_traceback=traceback.format_exc(),
|
||||||
stop_reason=stop_reason,
|
stop_reason=stop_reason,
|
||||||
)
|
)
|
||||||
@@ -632,6 +633,7 @@ class LettaAgent(BaseAgent):
|
|||||||
agent_step_span.set_attributes({"step_id": step_id})
|
agent_step_span.set_attributes({"step_id": step_id})
|
||||||
|
|
||||||
step_progression = StepProgression.START
|
step_progression = StepProgression.START
|
||||||
|
caught_exception = None
|
||||||
should_continue = False
|
should_continue = False
|
||||||
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
||||||
|
|
||||||
@@ -768,6 +770,7 @@ class LettaAgent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
caught_exception = e
|
||||||
# Handle any unexpected errors during step processing
|
# Handle any unexpected errors during step processing
|
||||||
self.logger.error(f"Error during step processing: {e}")
|
self.logger.error(f"Error during step processing: {e}")
|
||||||
job_update_metadata = {"error": str(e)}
|
job_update_metadata = {"error": str(e)}
|
||||||
@@ -810,8 +813,8 @@ class LettaAgent(BaseAgent):
|
|||||||
await self.step_manager.update_step_error_async(
|
await self.step_manager.update_step_error_async(
|
||||||
actor=self.actor,
|
actor=self.actor,
|
||||||
step_id=step_id, # Use original step_id for telemetry
|
step_id=step_id, # Use original step_id for telemetry
|
||||||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
|
||||||
error_message=str(e) if "e" in locals() else "Unknown error",
|
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
|
||||||
error_traceback=traceback.format_exc(),
|
error_traceback=traceback.format_exc(),
|
||||||
stop_reason=stop_reason,
|
stop_reason=stop_reason,
|
||||||
)
|
)
|
||||||
@@ -973,6 +976,7 @@ class LettaAgent(BaseAgent):
|
|||||||
agent_step_span.set_attributes({"step_id": step_id})
|
agent_step_span.set_attributes({"step_id": step_id})
|
||||||
|
|
||||||
step_progression = StepProgression.START
|
step_progression = StepProgression.START
|
||||||
|
caught_exception = None
|
||||||
should_continue = False
|
should_continue = False
|
||||||
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
step_metrics = StepMetrics(id=step_id) # Initialize metrics tracking
|
||||||
|
|
||||||
@@ -1228,6 +1232,7 @@ class LettaAgent(BaseAgent):
|
|||||||
self.logger.warning(f"Failed to record step metrics: {metrics_error}")
|
self.logger.warning(f"Failed to record step metrics: {metrics_error}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
caught_exception = e
|
||||||
# Handle any unexpected errors during step processing
|
# Handle any unexpected errors during step processing
|
||||||
self.logger.error(f"Error during step processing: {e}")
|
self.logger.error(f"Error during step processing: {e}")
|
||||||
job_update_metadata = {"error": str(e)}
|
job_update_metadata = {"error": str(e)}
|
||||||
@@ -1274,8 +1279,8 @@ class LettaAgent(BaseAgent):
|
|||||||
await self.step_manager.update_step_error_async(
|
await self.step_manager.update_step_error_async(
|
||||||
actor=self.actor,
|
actor=self.actor,
|
||||||
step_id=step_id, # Use original step_id for telemetry
|
step_id=step_id, # Use original step_id for telemetry
|
||||||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
|
||||||
error_message=str(e) if "e" in locals() else "Unknown error",
|
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
|
||||||
error_traceback=traceback.format_exc(),
|
error_traceback=traceback.format_exc(),
|
||||||
stop_reason=stop_reason,
|
stop_reason=stop_reason,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from letta.adapters.letta_llm_adapter import LettaLLMAdapter
|
|||||||
from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
|
from letta.adapters.letta_llm_request_adapter import LettaLLMRequestAdapter
|
||||||
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
||||||
from letta.agents.base_agent_v2 import BaseAgentV2
|
from letta.agents.base_agent_v2 import BaseAgentV2
|
||||||
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
|
|
||||||
from letta.agents.helpers import (
|
from letta.agents.helpers import (
|
||||||
_build_rule_violation_result,
|
_build_rule_violation_result,
|
||||||
_load_last_function_response,
|
_load_last_function_response,
|
||||||
@@ -68,7 +67,7 @@ from letta.services.summarizer.enums import SummarizationMode
|
|||||||
from letta.services.summarizer.summarizer import Summarizer
|
from letta.services.summarizer.summarizer import Summarizer
|
||||||
from letta.services.telemetry_manager import TelemetryManager
|
from letta.services.telemetry_manager import TelemetryManager
|
||||||
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
||||||
from letta.settings import model_settings, settings, summarizer_settings
|
from letta.settings import settings, summarizer_settings
|
||||||
from letta.system import package_function_response
|
from letta.system import package_function_response
|
||||||
from letta.types import JsonDict
|
from letta.types import JsonDict
|
||||||
from letta.utils import log_telemetry, safe_create_task, safe_create_task_with_return, united_diff, validate_function_response
|
from letta.utils import log_telemetry, safe_create_task, safe_create_task_with_return, united_diff, validate_function_response
|
||||||
@@ -455,6 +454,7 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
raise AssertionError("run_id is required when enforce_run_id_set is True")
|
raise AssertionError("run_id is required when enforce_run_id_set is True")
|
||||||
|
|
||||||
step_progression = StepProgression.START
|
step_progression = StepProgression.START
|
||||||
|
caught_exception = None
|
||||||
# TODO(@caren): clean this up
|
# TODO(@caren): clean this up
|
||||||
tool_call, reasoning_content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = (
|
tool_call, reasoning_content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = (
|
||||||
None,
|
None,
|
||||||
@@ -615,6 +615,7 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
)
|
)
|
||||||
step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step)
|
step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
caught_exception = e
|
||||||
self.logger.warning(f"Error during step processing: {e}")
|
self.logger.warning(f"Error during step processing: {e}")
|
||||||
self.job_update_metadata = {"error": str(e)}
|
self.job_update_metadata = {"error": str(e)}
|
||||||
|
|
||||||
@@ -650,8 +651,8 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
await self.step_manager.update_step_error_async(
|
await self.step_manager.update_step_error_async(
|
||||||
actor=self.actor,
|
actor=self.actor,
|
||||||
step_id=step_id, # Use original step_id for telemetry
|
step_id=step_id, # Use original step_id for telemetry
|
||||||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
|
||||||
error_message=str(e) if "e" in locals() else "Unknown error",
|
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
|
||||||
error_traceback=traceback.format_exc(),
|
error_traceback=traceback.format_exc(),
|
||||||
stop_reason=self.stop_reason,
|
stop_reason=self.stop_reason,
|
||||||
)
|
)
|
||||||
@@ -705,14 +706,11 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
async def _check_credits(self) -> bool:
|
async def _check_credits(self) -> bool:
|
||||||
"""Check if the organization still has credits. Returns True if OK or not configured."""
|
"""Check if the organization still has credits. Returns True if OK or not configured."""
|
||||||
try:
|
try:
|
||||||
await self.credit_verification_service.verify_credits(
|
await self.credit_verification_service.verify_credits(self.actor.organization_id, self.agent_state.id)
|
||||||
self.actor.organization_id, self.agent_state.id
|
|
||||||
)
|
|
||||||
return True
|
return True
|
||||||
except InsufficientCreditsError:
|
except InsufficientCreditsError:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f"Insufficient credits for organization {self.actor.organization_id}, "
|
f"Insufficient credits for organization {self.actor.organization_id}, agent {self.agent_state.id}, stopping agent loop"
|
||||||
f"agent {self.agent_state.id}, stopping agent loop"
|
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator, Dict, Literal, Optional
|
from typing import Any, AsyncGenerator, Dict, Optional
|
||||||
|
|
||||||
from opentelemetry.trace import Span
|
from opentelemetry.trace import Span
|
||||||
|
|
||||||
@@ -20,16 +20,15 @@ from letta.agents.helpers import (
|
|||||||
merge_and_validate_prefilled_args,
|
merge_and_validate_prefilled_args,
|
||||||
)
|
)
|
||||||
from letta.agents.letta_agent_v2 import LettaAgentV2
|
from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||||
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM, SUMMARIZATION_TRIGGER_MULTIPLIER
|
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM
|
||||||
from letta.errors import ContextWindowExceededError, LLMError, SystemPromptTokenExceededError
|
from letta.errors import ContextWindowExceededError, LLMError, SystemPromptTokenExceededError
|
||||||
from letta.helpers import ToolRulesSolver
|
from letta.helpers import ToolRulesSolver
|
||||||
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
|
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
|
||||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
|
||||||
from letta.helpers.tool_execution_helper import enable_strict_mode
|
from letta.helpers.tool_execution_helper import enable_strict_mode
|
||||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.agent import AgentState
|
from letta.schemas.agent import AgentState
|
||||||
from letta.schemas.enums import LLMCallType, MessageRole
|
from letta.schemas.enums import LLMCallType
|
||||||
from letta.schemas.letta_message import (
|
from letta.schemas.letta_message import (
|
||||||
ApprovalReturn,
|
ApprovalReturn,
|
||||||
CompactionStats,
|
CompactionStats,
|
||||||
@@ -44,13 +43,11 @@ from letta.schemas.letta_message_content import OmittedReasoningContent, Reasoni
|
|||||||
from letta.schemas.letta_request import ClientToolSchema
|
from letta.schemas.letta_request import ClientToolSchema
|
||||||
from letta.schemas.letta_response import LettaResponse, TurnTokenData
|
from letta.schemas.letta_response import LettaResponse, TurnTokenData
|
||||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||||
from letta.schemas.llm_config import LLMConfig
|
|
||||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||||
from letta.schemas.openai.chat_completion_response import ChoiceLogprobs, FunctionCall, ToolCall, ToolCallDenial, UsageStatistics
|
from letta.schemas.openai.chat_completion_response import ChoiceLogprobs, ToolCall, ToolCallDenial, UsageStatistics
|
||||||
from letta.schemas.step import StepProgression
|
from letta.schemas.step import StepProgression
|
||||||
from letta.schemas.step_metrics import StepMetrics
|
from letta.schemas.step_metrics import StepMetrics
|
||||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.server.rest_api.utils import (
|
from letta.server.rest_api.utils import (
|
||||||
create_approval_request_message_from_llm_response,
|
create_approval_request_message_from_llm_response,
|
||||||
@@ -64,8 +61,8 @@ from letta.services.summarizer.compact import compact_messages
|
|||||||
from letta.services.summarizer.summarizer_config import CompactionSettings
|
from letta.services.summarizer.summarizer_config import CompactionSettings
|
||||||
from letta.services.summarizer.summarizer_sliding_window import count_tokens
|
from letta.services.summarizer.summarizer_sliding_window import count_tokens
|
||||||
from letta.settings import settings, summarizer_settings
|
from letta.settings import settings, summarizer_settings
|
||||||
from letta.system import package_function_response, package_summarize_message_no_counts
|
from letta.system import package_function_response
|
||||||
from letta.utils import log_telemetry, safe_create_task_with_return, validate_function_response
|
from letta.utils import safe_create_task_with_return, validate_function_response
|
||||||
|
|
||||||
|
|
||||||
def extract_compaction_stats_from_message(message: Message) -> CompactionStats | None:
|
def extract_compaction_stats_from_message(message: Message) -> CompactionStats | None:
|
||||||
@@ -800,6 +797,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
self.logger.warning("Context token estimate is not set")
|
self.logger.warning("Context token estimate is not set")
|
||||||
|
|
||||||
step_progression = StepProgression.START
|
step_progression = StepProgression.START
|
||||||
|
caught_exception = None
|
||||||
# TODO(@caren): clean this up
|
# TODO(@caren): clean this up
|
||||||
tool_calls, content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = (
|
tool_calls, content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = (
|
||||||
None,
|
None,
|
||||||
@@ -1272,6 +1270,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
caught_exception = e
|
||||||
# NOTE: message persistence does not happen in the case of an exception (rollback to previous state)
|
# NOTE: message persistence does not happen in the case of an exception (rollback to previous state)
|
||||||
# Use repr() if str() is empty (happens with Exception() with no args)
|
# Use repr() if str() is empty (happens with Exception() with no args)
|
||||||
error_detail = str(e) or repr(e)
|
error_detail = str(e) or repr(e)
|
||||||
@@ -1322,8 +1321,8 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
await self.step_manager.update_step_error_async(
|
await self.step_manager.update_step_error_async(
|
||||||
actor=self.actor,
|
actor=self.actor,
|
||||||
step_id=step_id, # Use original step_id for telemetry
|
step_id=step_id, # Use original step_id for telemetry
|
||||||
error_type=type(e).__name__ if "e" in locals() else "Unknown",
|
error_type=type(caught_exception).__name__ if caught_exception is not None else "Unknown",
|
||||||
error_message=str(e) if "e" in locals() else "Unknown error",
|
error_message=str(caught_exception) if caught_exception is not None else "Unknown error",
|
||||||
error_traceback=traceback.format_exc(),
|
error_traceback=traceback.format_exc(),
|
||||||
stop_reason=self.stop_reason,
|
stop_reason=self.stop_reason,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -414,7 +414,7 @@ class VoiceAgent(BaseAgent):
|
|||||||
for t in tools
|
for t in tools
|
||||||
]
|
]
|
||||||
|
|
||||||
async def _execute_tool(self, user_query: str, tool_name: str, tool_args: dict, agent_state: AgentState) -> "ToolExecutionResult":
|
async def _execute_tool(self, user_query: str, tool_name: str, tool_args: dict, agent_state: AgentState) -> "ToolExecutionResult": # noqa: F821
|
||||||
"""
|
"""
|
||||||
Executes a tool and returns the ToolExecutionResult.
|
Executes a tool and returns the ToolExecutionResult.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -110,9 +110,9 @@ class VoiceSleeptimeAgent(LettaAgent):
|
|||||||
tool_name: str,
|
tool_name: str,
|
||||||
tool_args: JsonDict,
|
tool_args: JsonDict,
|
||||||
agent_state: AgentState,
|
agent_state: AgentState,
|
||||||
agent_step_span: Optional["Span"] = None,
|
agent_step_span: Optional["Span"] = None, # noqa: F821
|
||||||
step_id: str | None = None,
|
step_id: str | None = None,
|
||||||
) -> "ToolExecutionResult":
|
) -> "ToolExecutionResult": # noqa: F821
|
||||||
"""
|
"""
|
||||||
Executes a tool and returns the ToolExecutionResult
|
Executes a tool and returns the ToolExecutionResult
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from typing import Annotated, Optional
|
|||||||
import typer
|
import typer
|
||||||
|
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class DataConnector:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def load_data(connector: DataConnector, source: Source, passage_manager: PassageManager, file_manager: FileManager, actor: "User"):
|
async def load_data(connector: DataConnector, source: Source, passage_manager: PassageManager, file_manager: FileManager, actor: "User"): # noqa: F821
|
||||||
from letta.llm_api.llm_client import LLMClient
|
from letta.llm_api.llm_client import LLMClient
|
||||||
|
|
||||||
"""Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id."""
|
"""Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id."""
|
||||||
|
|||||||
@@ -362,16 +362,16 @@ class RateLimitExceededError(LettaError):
|
|||||||
class LettaMessageError(LettaError):
|
class LettaMessageError(LettaError):
|
||||||
"""Base error class for handling message-related errors."""
|
"""Base error class for handling message-related errors."""
|
||||||
|
|
||||||
messages: List[Union["Message", "LettaMessage"]]
|
messages: List[Union["Message", "LettaMessage"]] # noqa: F821
|
||||||
default_error_message: str = "An error occurred with the message."
|
default_error_message: str = "An error occurred with the message."
|
||||||
|
|
||||||
def __init__(self, *, messages: List[Union["Message", "LettaMessage"]], explanation: Optional[str] = None) -> None:
|
def __init__(self, *, messages: List[Union["Message", "LettaMessage"]], explanation: Optional[str] = None) -> None: # noqa: F821
|
||||||
error_msg = self.construct_error_message(messages, self.default_error_message, explanation)
|
error_msg = self.construct_error_message(messages, self.default_error_message, explanation)
|
||||||
super().__init__(error_msg)
|
super().__init__(error_msg)
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def construct_error_message(messages: List[Union["Message", "LettaMessage"]], error_msg: str, explanation: Optional[str] = None) -> str:
|
def construct_error_message(messages: List[Union["Message", "LettaMessage"]], error_msg: str, explanation: Optional[str] = None) -> str: # noqa: F821
|
||||||
"""Helper method to construct a clean and formatted error message."""
|
"""Helper method to construct a clean and formatted error message."""
|
||||||
if explanation:
|
if explanation:
|
||||||
error_msg += f" (Explanation: {explanation})"
|
error_msg += f" (Explanation: {explanation})"
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional
|
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||||
|
|
||||||
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
|
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
|
||||||
|
|
||||||
@@ -67,7 +67,7 @@ def memory(
|
|||||||
raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.")
|
raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.")
|
||||||
|
|
||||||
|
|
||||||
def send_message(self: "Agent", message: str) -> Optional[str]:
|
def send_message(self: "Agent", message: str) -> Optional[str]: # noqa: F821
|
||||||
"""
|
"""
|
||||||
Sends a message to the human user.
|
Sends a message to the human user.
|
||||||
|
|
||||||
@@ -84,7 +84,7 @@ def send_message(self: "Agent", message: str) -> Optional[str]:
|
|||||||
|
|
||||||
|
|
||||||
def conversation_search(
|
def conversation_search(
|
||||||
self: "Agent",
|
self: "Agent", # noqa: F821
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
roles: Optional[List[Literal["assistant", "user", "tool"]]] = None,
|
roles: Optional[List[Literal["assistant", "user", "tool"]]] = None,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
@@ -160,7 +160,7 @@ def conversation_search(
|
|||||||
return results_str
|
return results_str
|
||||||
|
|
||||||
|
|
||||||
async def archival_memory_insert(self: "Agent", content: str, tags: Optional[list[str]] = None) -> Optional[str]:
|
async def archival_memory_insert(self: "Agent", content: str, tags: Optional[list[str]] = None) -> Optional[str]: # noqa: F821
|
||||||
"""
|
"""
|
||||||
Add information to long-term archival memory for later retrieval.
|
Add information to long-term archival memory for later retrieval.
|
||||||
|
|
||||||
@@ -191,7 +191,7 @@ async def archival_memory_insert(self: "Agent", content: str, tags: Optional[lis
|
|||||||
|
|
||||||
|
|
||||||
async def archival_memory_search(
|
async def archival_memory_search(
|
||||||
self: "Agent",
|
self: "Agent", # noqa: F821
|
||||||
query: str,
|
query: str,
|
||||||
tags: Optional[list[str]] = None,
|
tags: Optional[list[str]] = None,
|
||||||
tag_match_mode: Literal["any", "all"] = "any",
|
tag_match_mode: Literal["any", "all"] = "any",
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import List
|
||||||
|
|
||||||
from letta.functions.helpers import (
|
from letta.functions.helpers import (
|
||||||
_send_message_to_agents_matching_tags_async,
|
_send_message_to_agents_matching_tags_async,
|
||||||
@@ -10,9 +10,10 @@ from letta.functions.helpers import (
|
|||||||
from letta.schemas.enums import MessageRole
|
from letta.schemas.enums import MessageRole
|
||||||
from letta.schemas.message import MessageCreate
|
from letta.schemas.message import MessageCreate
|
||||||
from letta.server.rest_api.dependencies import get_letta_server
|
from letta.server.rest_api.dependencies import get_letta_server
|
||||||
|
from letta.settings import settings
|
||||||
|
|
||||||
|
|
||||||
def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_agent_id: str) -> str:
|
def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_agent_id: str) -> str: # noqa: F821
|
||||||
"""
|
"""
|
||||||
Sends a message to a specific Letta agent within the same organization and waits for a response. The sender's identity is automatically included, so no explicit introduction is needed in the message. This function is designed for two-way communication where a reply is expected.
|
Sends a message to a specific Letta agent within the same organization and waits for a response. The sender's identity is automatically included, so no explicit introduction is needed in the message. This function is designed for two-way communication where a reply is expected.
|
||||||
|
|
||||||
@@ -36,7 +37,7 @@ def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all: List[str], match_some: List[str]) -> List[str]:
|
def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all: List[str], match_some: List[str]) -> List[str]: # noqa: F821
|
||||||
"""
|
"""
|
||||||
Sends a message to all agents within the same organization that match the specified tag criteria. Agents must possess *all* of the tags in `match_all` and *at least one* of the tags in `match_some` to receive the message.
|
Sends a message to all agents within the same organization that match the specified tag criteria. Agents must possess *all* of the tags in `match_all` and *at least one* of the tags in `match_some` to receive the message.
|
||||||
|
|
||||||
@@ -65,7 +66,7 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all:
|
|||||||
return asyncio.run(_send_message_to_agents_matching_tags_async(self, server, messages, matching_agents))
|
return asyncio.run(_send_message_to_agents_matching_tags_async(self, server, messages, matching_agents))
|
||||||
|
|
||||||
|
|
||||||
def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]:
|
def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]: # noqa: F821
|
||||||
"""
|
"""
|
||||||
Sends a message to all agents within the same multi-agent group.
|
Sends a message to all agents within the same multi-agent group.
|
||||||
|
|
||||||
@@ -81,7 +82,7 @@ def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str
|
|||||||
return asyncio.run(_send_message_to_all_agents_in_group_async(self, message))
|
return asyncio.run(_send_message_to_all_agents_in_group_async(self, message))
|
||||||
|
|
||||||
|
|
||||||
def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str) -> str:
|
def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str) -> str: # noqa: F821
|
||||||
"""
|
"""
|
||||||
Sends a message to a specific Letta agent within the same organization. The sender's identity is automatically included, so no explicit introduction is required in the message. This function does not expect a response from the target agent, making it suitable for notifications or one-way communication.
|
Sends a message to a specific Letta agent within the same organization. The sender's identity is automatically included, so no explicit introduction is required in the message. This function does not expect a response from the target agent, making it suitable for notifications or one-way communication.
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import List, Optional
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
def rethink_user_memory(agent_state: "AgentState", new_memory: str) -> None:
|
def rethink_user_memory(agent_state: "AgentState", new_memory: str) -> None: # noqa: F821
|
||||||
"""
|
"""
|
||||||
Rewrite memory block for the main agent, new_memory should contain all current information from the block that is not outdated or inconsistent, integrating any new information, resulting in a new memory block that is organized, readable, and comprehensive.
|
Rewrite memory block for the main agent, new_memory should contain all current information from the block that is not outdated or inconsistent, integrating any new information, resulting in a new memory block that is organized, readable, and comprehensive.
|
||||||
|
|
||||||
@@ -18,7 +18,7 @@ def rethink_user_memory(agent_state: "AgentState", new_memory: str) -> None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def finish_rethinking_memory(agent_state: "AgentState") -> None: # type: ignore
|
def finish_rethinking_memory(agent_state: "AgentState") -> None: # type: ignore # noqa: F821
|
||||||
"""
|
"""
|
||||||
This function is called when the agent is done rethinking the memory.
|
This function is called when the agent is done rethinking the memory.
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ class MemoryChunk(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def store_memories(agent_state: "AgentState", chunks: List[MemoryChunk]) -> None:
|
def store_memories(agent_state: "AgentState", chunks: List[MemoryChunk]) -> None: # noqa: F821
|
||||||
"""
|
"""
|
||||||
Persist dialogue that is about to fall out of the agent’s context window.
|
Persist dialogue that is about to fall out of the agent’s context window.
|
||||||
|
|
||||||
@@ -59,7 +59,7 @@ def store_memories(agent_state: "AgentState", chunks: List[MemoryChunk]) -> None
|
|||||||
|
|
||||||
|
|
||||||
def search_memory(
|
def search_memory(
|
||||||
agent_state: "AgentState",
|
agent_state: "AgentState", # noqa: F821
|
||||||
convo_keyword_queries: Optional[List[str]],
|
convo_keyword_queries: Optional[List[str]],
|
||||||
start_minutes_ago: Optional[int],
|
start_minutes_ago: Optional[int],
|
||||||
end_minutes_ago: Optional[int],
|
end_minutes_ago: Optional[int],
|
||||||
|
|||||||
@@ -36,7 +36,8 @@ def {mcp_tool_name}(**kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def generate_langchain_tool_wrapper(
|
def generate_langchain_tool_wrapper(
|
||||||
tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None
|
tool: "LangChainBaseTool", # noqa: F821
|
||||||
|
additional_imports_module_attr_map: dict[str, str] = None,
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
tool_name = tool.__class__.__name__
|
tool_name = tool.__class__.__name__
|
||||||
import_statement = f"from langchain_community.tools import {tool_name}"
|
import_statement = f"from langchain_community.tools import {tool_name}"
|
||||||
@@ -72,7 +73,7 @@ def _assert_code_gen_compilable(code_str):
|
|||||||
print(f"Syntax error in code: {e}")
|
print(f"Syntax error in code: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
|
def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None: # noqa: F821
|
||||||
# Safety check that user has passed in all required imports:
|
# Safety check that user has passed in all required imports:
|
||||||
tool_name = tool.__class__.__name__
|
tool_name = tool.__class__.__name__
|
||||||
current_class_imports = {tool_name}
|
current_class_imports = {tool_name}
|
||||||
@@ -86,7 +87,7 @@ def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additiona
|
|||||||
raise RuntimeError(err_msg)
|
raise RuntimeError(err_msg)
|
||||||
|
|
||||||
|
|
||||||
def _find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]:
|
def _find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]: # noqa: F821
|
||||||
"""
|
"""
|
||||||
Finds all the class names for required imports when instantiating the `obj`.
|
Finds all the class names for required imports when instantiating the `obj`.
|
||||||
NOTE: This does not return the full import path, only the class name.
|
NOTE: This does not return the full import path, only the class name.
|
||||||
@@ -224,7 +225,7 @@ def _parse_letta_response_for_assistant_message(
|
|||||||
|
|
||||||
|
|
||||||
async def async_execute_send_message_to_agent(
|
async def async_execute_send_message_to_agent(
|
||||||
sender_agent: "Agent",
|
sender_agent: "Agent", # noqa: F821
|
||||||
messages: List[MessageCreate],
|
messages: List[MessageCreate],
|
||||||
other_agent_id: str,
|
other_agent_id: str,
|
||||||
log_prefix: str,
|
log_prefix: str,
|
||||||
@@ -255,7 +256,7 @@ async def async_execute_send_message_to_agent(
|
|||||||
|
|
||||||
|
|
||||||
def execute_send_message_to_agent(
|
def execute_send_message_to_agent(
|
||||||
sender_agent: "Agent",
|
sender_agent: "Agent", # noqa: F821
|
||||||
messages: List[MessageCreate],
|
messages: List[MessageCreate],
|
||||||
other_agent_id: str,
|
other_agent_id: str,
|
||||||
log_prefix: str,
|
log_prefix: str,
|
||||||
@@ -268,7 +269,7 @@ def execute_send_message_to_agent(
|
|||||||
|
|
||||||
|
|
||||||
async def _send_message_to_agent_no_stream(
|
async def _send_message_to_agent_no_stream(
|
||||||
server: "SyncServer",
|
server: "SyncServer", # noqa: F821
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
actor: User,
|
actor: User,
|
||||||
messages: List[MessageCreate],
|
messages: List[MessageCreate],
|
||||||
@@ -301,8 +302,8 @@ async def _send_message_to_agent_no_stream(
|
|||||||
|
|
||||||
|
|
||||||
async def _async_send_message_with_retries(
|
async def _async_send_message_with_retries(
|
||||||
server: "SyncServer",
|
server: "SyncServer", # noqa: F821
|
||||||
sender_agent: "Agent",
|
sender_agent: "Agent", # noqa: F821
|
||||||
target_agent_id: str,
|
target_agent_id: str,
|
||||||
messages: List[MessageCreate],
|
messages: List[MessageCreate],
|
||||||
max_retries: int,
|
max_retries: int,
|
||||||
@@ -352,7 +353,7 @@ async def _async_send_message_with_retries(
|
|||||||
|
|
||||||
|
|
||||||
def fire_and_forget_send_to_agent(
|
def fire_and_forget_send_to_agent(
|
||||||
sender_agent: "Agent",
|
sender_agent: "Agent", # noqa: F821
|
||||||
messages: List[MessageCreate],
|
messages: List[MessageCreate],
|
||||||
other_agent_id: str,
|
other_agent_id: str,
|
||||||
log_prefix: str,
|
log_prefix: str,
|
||||||
@@ -436,7 +437,10 @@ def fire_and_forget_send_to_agent(
|
|||||||
|
|
||||||
|
|
||||||
async def _send_message_to_agents_matching_tags_async(
|
async def _send_message_to_agents_matching_tags_async(
|
||||||
sender_agent: "Agent", server: "SyncServer", messages: List[MessageCreate], matching_agents: List["AgentState"]
|
sender_agent: "Agent", # noqa: F821
|
||||||
|
server: "SyncServer", # noqa: F821
|
||||||
|
messages: List[MessageCreate],
|
||||||
|
matching_agents: List["AgentState"], # noqa: F821
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
async def _send_single(agent_state):
|
async def _send_single(agent_state):
|
||||||
return await _async_send_message_with_retries(
|
return await _async_send_message_with_retries(
|
||||||
@@ -460,7 +464,7 @@ async def _send_message_to_agents_matching_tags_async(
|
|||||||
return final
|
return final
|
||||||
|
|
||||||
|
|
||||||
async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", message: str) -> List[str]:
|
async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", message: str) -> List[str]: # noqa: F821
|
||||||
server = get_letta_server()
|
server = get_letta_server()
|
||||||
|
|
||||||
augmented_message = (
|
augmented_message = (
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ class DynamicMultiAgent(BaseAgent):
|
|||||||
|
|
||||||
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
||||||
|
|
||||||
def load_manager_agent(self) -> Agent:
|
def load_manager_agent(self) -> Agent: # noqa: F821
|
||||||
for participant_agent_id in self.agent_ids:
|
for participant_agent_id in self.agent_ids:
|
||||||
participant_agent_state = self.agent_manager.get_agent_by_id(agent_id=participant_agent_id, actor=self.user)
|
participant_agent_state = self.agent_manager.get_agent_by_id(agent_id=participant_agent_id, actor=self.user)
|
||||||
participant_persona_block = participant_agent_state.memory.get_block(label="persona")
|
participant_persona_block = participant_agent_state.memory.get_block(label="persona")
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from letta.otel.tracing import trace_method
|
|||||||
from letta.schemas.agent import AgentState
|
from letta.schemas.agent import AgentState
|
||||||
from letta.schemas.enums import RunStatus
|
from letta.schemas.enums import RunStatus
|
||||||
from letta.schemas.group import Group, ManagerType
|
from letta.schemas.group import Group, ManagerType
|
||||||
from letta.schemas.job import JobUpdate
|
|
||||||
from letta.schemas.letta_message import MessageType
|
from letta.schemas.letta_message import MessageType
|
||||||
from letta.schemas.letta_message_content import TextContent
|
from letta.schemas.letta_message_content import TextContent
|
||||||
from letta.schemas.letta_request import ClientToolSchema
|
from letta.schemas.letta_request import ClientToolSchema
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
@@ -7,9 +6,8 @@ from letta.constants import DEFAULT_MAX_STEPS
|
|||||||
from letta.groups.helpers import stringify_message
|
from letta.groups.helpers import stringify_message
|
||||||
from letta.otel.tracing import trace_method
|
from letta.otel.tracing import trace_method
|
||||||
from letta.schemas.agent import AgentState
|
from letta.schemas.agent import AgentState
|
||||||
from letta.schemas.enums import JobStatus, RunStatus
|
from letta.schemas.enums import RunStatus
|
||||||
from letta.schemas.group import Group, ManagerType
|
from letta.schemas.group import Group, ManagerType
|
||||||
from letta.schemas.job import JobUpdate
|
|
||||||
from letta.schemas.letta_message import MessageType
|
from letta.schemas.letta_message import MessageType
|
||||||
from letta.schemas.letta_message_content import TextContent
|
from letta.schemas.letta_message_content import TextContent
|
||||||
from letta.schemas.letta_request import ClientToolSchema
|
from letta.schemas.letta_request import ClientToolSchema
|
||||||
|
|||||||
@@ -1,19 +1,9 @@
|
|||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
from letta.agents.base_agent import BaseAgent
|
from letta.agents.base_agent import BaseAgent
|
||||||
from letta.constants import DEFAULT_MESSAGE_TOOL
|
|
||||||
from letta.functions.function_sets.multi_agent import send_message_to_all_agents_in_group
|
|
||||||
from letta.functions.functions import parse_source_code
|
|
||||||
from letta.functions.schema_generator import generate_schema
|
|
||||||
from letta.interface import AgentInterface
|
from letta.interface import AgentInterface
|
||||||
from letta.orm import User
|
from letta.orm import User
|
||||||
from letta.schemas.agent import AgentState
|
from letta.schemas.agent import AgentState
|
||||||
from letta.schemas.enums import ToolType
|
|
||||||
from letta.schemas.letta_message_content import TextContent
|
|
||||||
from letta.schemas.message import MessageCreate
|
|
||||||
from letta.schemas.tool import Tool
|
|
||||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
|
||||||
from letta.services.agent_manager import AgentManager
|
from letta.services.agent_manager import AgentManager
|
||||||
from letta.services.tool_manager import ToolManager
|
from letta.services.tool_manager import ToolManager
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from letta.helpers.tool_rule_solver import ToolRulesSolver
|
from letta.helpers.tool_rule_solver import ToolRulesSolver as ToolRulesSolver
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ def deserialize_embedding_config(data: Optional[Dict]) -> Optional[EmbeddingConf
|
|||||||
# --------------------------
|
# --------------------------
|
||||||
|
|
||||||
|
|
||||||
def serialize_compaction_settings(config: Union[Optional["CompactionSettings"], Dict]) -> Optional[Dict]:
|
def serialize_compaction_settings(config: Union[Optional["CompactionSettings"], Dict]) -> Optional[Dict]: # noqa: F821
|
||||||
"""Convert a CompactionSettings object into a JSON-serializable dictionary."""
|
"""Convert a CompactionSettings object into a JSON-serializable dictionary."""
|
||||||
if config:
|
if config:
|
||||||
# Import here to avoid circular dependency
|
# Import here to avoid circular dependency
|
||||||
@@ -124,7 +124,7 @@ def serialize_compaction_settings(config: Union[Optional["CompactionSettings"],
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def deserialize_compaction_settings(data: Optional[Dict]) -> Optional["CompactionSettings"]:
|
def deserialize_compaction_settings(data: Optional[Dict]) -> Optional["CompactionSettings"]: # noqa: F821
|
||||||
"""Convert a dictionary back into a CompactionSettings object."""
|
"""Convert a dictionary back into a CompactionSettings object."""
|
||||||
if data:
|
if data:
|
||||||
# Import here to avoid circular dependency
|
# Import here to avoid circular dependency
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
|
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
|
||||||
from letta.schemas.tool import MCP_TOOL_METADATA_SCHEMA_STATUS, MCP_TOOL_METADATA_SCHEMA_WARNINGS
|
from letta.schemas.tool import MCP_TOOL_METADATA_SCHEMA_STATUS, MCP_TOOL_METADATA_SCHEMA_WARNINGS
|
||||||
|
|||||||
@@ -244,7 +244,7 @@ class TurbopufferClient:
|
|||||||
raise ValueError("Turbopuffer API key not provided")
|
raise ValueError("Turbopuffer API key not provided")
|
||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]:
|
async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]: # noqa: F821
|
||||||
"""Generate embeddings using the default embedding configuration.
|
"""Generate embeddings using the default embedding configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -311,7 +311,7 @@ class TurbopufferClient:
|
|||||||
|
|
||||||
return namespace_name
|
return namespace_name
|
||||||
|
|
||||||
def _extract_tool_text(self, tool: "PydanticTool") -> str:
|
def _extract_tool_text(self, tool: "PydanticTool") -> str: # noqa: F821
|
||||||
"""Extract searchable text from a tool for embedding.
|
"""Extract searchable text from a tool for embedding.
|
||||||
|
|
||||||
Combines name, description, and JSON schema into a structured format
|
Combines name, description, and JSON schema into a structured format
|
||||||
@@ -361,9 +361,9 @@ class TurbopufferClient:
|
|||||||
@async_retry_with_backoff()
|
@async_retry_with_backoff()
|
||||||
async def insert_tools(
|
async def insert_tools(
|
||||||
self,
|
self,
|
||||||
tools: List["PydanticTool"],
|
tools: List["PydanticTool"], # noqa: F821
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
actor: "PydanticUser",
|
actor: "PydanticUser", # noqa: F821
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Insert tools into Turbopuffer.
|
"""Insert tools into Turbopuffer.
|
||||||
|
|
||||||
@@ -375,7 +375,6 @@ class TurbopufferClient:
|
|||||||
Returns:
|
Returns:
|
||||||
True if successful
|
True if successful
|
||||||
"""
|
"""
|
||||||
from turbopuffer import AsyncTurbopuffer
|
|
||||||
|
|
||||||
if not tools:
|
if not tools:
|
||||||
return True
|
return True
|
||||||
@@ -457,7 +456,7 @@ class TurbopufferClient:
|
|||||||
text_chunks: List[str],
|
text_chunks: List[str],
|
||||||
passage_ids: List[str],
|
passage_ids: List[str],
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
actor: "PydanticUser",
|
actor: "PydanticUser", # noqa: F821
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
created_at: Optional[datetime] = None,
|
created_at: Optional[datetime] = None,
|
||||||
embeddings: Optional[List[List[float]]] = None,
|
embeddings: Optional[List[List[float]]] = None,
|
||||||
@@ -477,7 +476,6 @@ class TurbopufferClient:
|
|||||||
Returns:
|
Returns:
|
||||||
List of PydanticPassage objects that were inserted
|
List of PydanticPassage objects that were inserted
|
||||||
"""
|
"""
|
||||||
from turbopuffer import AsyncTurbopuffer
|
|
||||||
|
|
||||||
# filter out empty text chunks
|
# filter out empty text chunks
|
||||||
filtered_chunks = [(i, text) for i, text in enumerate(text_chunks) if text.strip()]
|
filtered_chunks = [(i, text) for i, text in enumerate(text_chunks) if text.strip()]
|
||||||
@@ -609,7 +607,7 @@ class TurbopufferClient:
|
|||||||
message_texts: List[str],
|
message_texts: List[str],
|
||||||
message_ids: List[str],
|
message_ids: List[str],
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
actor: "PydanticUser",
|
actor: "PydanticUser", # noqa: F821
|
||||||
roles: List[MessageRole],
|
roles: List[MessageRole],
|
||||||
created_ats: List[datetime],
|
created_ats: List[datetime],
|
||||||
project_id: Optional[str] = None,
|
project_id: Optional[str] = None,
|
||||||
@@ -633,7 +631,6 @@ class TurbopufferClient:
|
|||||||
Returns:
|
Returns:
|
||||||
True if successful
|
True if successful
|
||||||
"""
|
"""
|
||||||
from turbopuffer import AsyncTurbopuffer
|
|
||||||
|
|
||||||
# filter out empty message texts
|
# filter out empty message texts
|
||||||
filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]
|
filtered_messages = [(i, text) for i, text in enumerate(message_texts) if text.strip()]
|
||||||
@@ -870,7 +867,7 @@ class TurbopufferClient:
|
|||||||
async def query_passages(
|
async def query_passages(
|
||||||
self,
|
self,
|
||||||
archive_id: str,
|
archive_id: str,
|
||||||
actor: "PydanticUser",
|
actor: "PydanticUser", # noqa: F821
|
||||||
query_text: Optional[str] = None,
|
query_text: Optional[str] = None,
|
||||||
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
@@ -1015,7 +1012,7 @@ class TurbopufferClient:
|
|||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
actor: "PydanticUser",
|
actor: "PydanticUser", # noqa: F821
|
||||||
query_text: Optional[str] = None,
|
query_text: Optional[str] = None,
|
||||||
search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp"
|
search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp"
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
@@ -1191,7 +1188,7 @@ class TurbopufferClient:
|
|||||||
async def query_messages_by_org_id(
|
async def query_messages_by_org_id(
|
||||||
self,
|
self,
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
actor: "PydanticUser",
|
actor: "PydanticUser", # noqa: F821
|
||||||
query_text: Optional[str] = None,
|
query_text: Optional[str] = None,
|
||||||
search_mode: str = "hybrid", # "vector", "fts", "hybrid"
|
search_mode: str = "hybrid", # "vector", "fts", "hybrid"
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
@@ -1520,7 +1517,6 @@ class TurbopufferClient:
|
|||||||
@async_retry_with_backoff()
|
@async_retry_with_backoff()
|
||||||
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
|
async def delete_passage(self, archive_id: str, passage_id: str) -> bool:
|
||||||
"""Delete a passage from Turbopuffer."""
|
"""Delete a passage from Turbopuffer."""
|
||||||
from turbopuffer import AsyncTurbopuffer
|
|
||||||
|
|
||||||
namespace_name = await self._get_archive_namespace_name(archive_id)
|
namespace_name = await self._get_archive_namespace_name(archive_id)
|
||||||
|
|
||||||
@@ -1543,7 +1539,6 @@ class TurbopufferClient:
|
|||||||
@async_retry_with_backoff()
|
@async_retry_with_backoff()
|
||||||
async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
|
async def delete_passages(self, archive_id: str, passage_ids: List[str]) -> bool:
|
||||||
"""Delete multiple passages from Turbopuffer."""
|
"""Delete multiple passages from Turbopuffer."""
|
||||||
from turbopuffer import AsyncTurbopuffer
|
|
||||||
|
|
||||||
if not passage_ids:
|
if not passage_ids:
|
||||||
return True
|
return True
|
||||||
@@ -1588,7 +1583,6 @@ class TurbopufferClient:
|
|||||||
@async_retry_with_backoff()
|
@async_retry_with_backoff()
|
||||||
async def delete_messages(self, agent_id: str, organization_id: str, message_ids: List[str]) -> bool:
|
async def delete_messages(self, agent_id: str, organization_id: str, message_ids: List[str]) -> bool:
|
||||||
"""Delete multiple messages from Turbopuffer."""
|
"""Delete multiple messages from Turbopuffer."""
|
||||||
from turbopuffer import AsyncTurbopuffer
|
|
||||||
|
|
||||||
if not message_ids:
|
if not message_ids:
|
||||||
return True
|
return True
|
||||||
@@ -1614,7 +1608,6 @@ class TurbopufferClient:
|
|||||||
@async_retry_with_backoff()
|
@async_retry_with_backoff()
|
||||||
async def delete_all_messages(self, agent_id: str, organization_id: str) -> bool:
|
async def delete_all_messages(self, agent_id: str, organization_id: str) -> bool:
|
||||||
"""Delete all messages for an agent from Turbopuffer."""
|
"""Delete all messages for an agent from Turbopuffer."""
|
||||||
from turbopuffer import AsyncTurbopuffer
|
|
||||||
|
|
||||||
namespace_name = await self._get_message_namespace_name(organization_id)
|
namespace_name = await self._get_message_namespace_name(organization_id)
|
||||||
|
|
||||||
@@ -1661,7 +1654,7 @@ class TurbopufferClient:
|
|||||||
file_id: str,
|
file_id: str,
|
||||||
text_chunks: List[str],
|
text_chunks: List[str],
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
actor: "PydanticUser",
|
actor: "PydanticUser", # noqa: F821
|
||||||
created_at: Optional[datetime] = None,
|
created_at: Optional[datetime] = None,
|
||||||
) -> List[PydanticPassage]:
|
) -> List[PydanticPassage]:
|
||||||
"""Insert file passages into Turbopuffer using org-scoped namespace.
|
"""Insert file passages into Turbopuffer using org-scoped namespace.
|
||||||
@@ -1677,7 +1670,6 @@ class TurbopufferClient:
|
|||||||
Returns:
|
Returns:
|
||||||
List of PydanticPassage objects that were inserted
|
List of PydanticPassage objects that were inserted
|
||||||
"""
|
"""
|
||||||
from turbopuffer import AsyncTurbopuffer
|
|
||||||
|
|
||||||
if not text_chunks:
|
if not text_chunks:
|
||||||
return []
|
return []
|
||||||
@@ -1775,7 +1767,7 @@ class TurbopufferClient:
|
|||||||
self,
|
self,
|
||||||
source_ids: List[str],
|
source_ids: List[str],
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
actor: "PydanticUser",
|
actor: "PydanticUser", # noqa: F821
|
||||||
query_text: Optional[str] = None,
|
query_text: Optional[str] = None,
|
||||||
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
@@ -1914,7 +1906,6 @@ class TurbopufferClient:
|
|||||||
@async_retry_with_backoff()
|
@async_retry_with_backoff()
|
||||||
async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool:
|
async def delete_file_passages(self, source_id: str, file_id: str, organization_id: str) -> bool:
|
||||||
"""Delete all passages for a specific file from Turbopuffer."""
|
"""Delete all passages for a specific file from Turbopuffer."""
|
||||||
from turbopuffer import AsyncTurbopuffer
|
|
||||||
|
|
||||||
namespace_name = await self._get_file_passages_namespace_name(organization_id)
|
namespace_name = await self._get_file_passages_namespace_name(organization_id)
|
||||||
|
|
||||||
@@ -1943,7 +1934,6 @@ class TurbopufferClient:
|
|||||||
@async_retry_with_backoff()
|
@async_retry_with_backoff()
|
||||||
async def delete_source_passages(self, source_id: str, organization_id: str) -> bool:
|
async def delete_source_passages(self, source_id: str, organization_id: str) -> bool:
|
||||||
"""Delete all passages for a source from Turbopuffer."""
|
"""Delete all passages for a source from Turbopuffer."""
|
||||||
from turbopuffer import AsyncTurbopuffer
|
|
||||||
|
|
||||||
namespace_name = await self._get_file_passages_namespace_name(organization_id)
|
namespace_name = await self._get_file_passages_namespace_name(organization_id)
|
||||||
|
|
||||||
@@ -1976,7 +1966,6 @@ class TurbopufferClient:
|
|||||||
Returns:
|
Returns:
|
||||||
True if successful
|
True if successful
|
||||||
"""
|
"""
|
||||||
from turbopuffer import AsyncTurbopuffer
|
|
||||||
|
|
||||||
if not tool_ids:
|
if not tool_ids:
|
||||||
return True
|
return True
|
||||||
@@ -2002,7 +1991,7 @@ class TurbopufferClient:
|
|||||||
async def query_tools(
|
async def query_tools(
|
||||||
self,
|
self,
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
actor: "PydanticUser",
|
actor: "PydanticUser", # noqa: F821
|
||||||
query_text: Optional[str] = None,
|
query_text: Optional[str] = None,
|
||||||
search_mode: str = "hybrid", # "vector", "fts", "hybrid", "timestamp"
|
search_mode: str = "hybrid", # "vector", "fts", "hybrid", "timestamp"
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ class SimpleAnthropicStreamingInterface:
|
|||||||
return tool_calls[0]
|
return tool_calls[0]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||||
"""Extract usage statistics from accumulated streaming data.
|
"""Extract usage statistics from accumulated streaming data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -232,7 +232,7 @@ class SimpleAnthropicStreamingInterface:
|
|||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
stream: AsyncStream[BetaRawMessageStreamEvent],
|
stream: AsyncStream[BetaRawMessageStreamEvent],
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
prev_message_type = None
|
prev_message_type = None
|
||||||
message_index = 0
|
message_index = 0
|
||||||
@@ -287,7 +287,7 @@ class SimpleAnthropicStreamingInterface:
|
|||||||
async def _process_event(
|
async def _process_event(
|
||||||
self,
|
self,
|
||||||
event: BetaRawMessageStreamEvent,
|
event: BetaRawMessageStreamEvent,
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
prev_message_type: Optional[str] = None,
|
prev_message_type: Optional[str] = None,
|
||||||
message_index: int = 0,
|
message_index: int = 0,
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ class AnthropicStreamingInterface:
|
|||||||
arguments = str(json.dumps(tool_input, indent=2))
|
arguments = str(json.dumps(tool_input, indent=2))
|
||||||
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
|
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
|
||||||
|
|
||||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||||
"""Extract usage statistics from accumulated streaming data.
|
"""Extract usage statistics from accumulated streaming data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -222,7 +222,7 @@ class AnthropicStreamingInterface:
|
|||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
stream: AsyncStream[BetaRawMessageStreamEvent],
|
stream: AsyncStream[BetaRawMessageStreamEvent],
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
prev_message_type = None
|
prev_message_type = None
|
||||||
message_index = 0
|
message_index = 0
|
||||||
@@ -276,7 +276,7 @@ class AnthropicStreamingInterface:
|
|||||||
async def _process_event(
|
async def _process_event(
|
||||||
self,
|
self,
|
||||||
event: BetaRawMessageStreamEvent,
|
event: BetaRawMessageStreamEvent,
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
prev_message_type: Optional[str] = None,
|
prev_message_type: Optional[str] = None,
|
||||||
message_index: int = 0,
|
message_index: int = 0,
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
@@ -662,7 +662,7 @@ class SimpleAnthropicStreamingInterface:
|
|||||||
arguments = str(json.dumps(tool_input, indent=2))
|
arguments = str(json.dumps(tool_input, indent=2))
|
||||||
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
|
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
|
||||||
|
|
||||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||||
"""Extract usage statistics from accumulated streaming data.
|
"""Extract usage statistics from accumulated streaming data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -754,7 +754,7 @@ class SimpleAnthropicStreamingInterface:
|
|||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
stream: AsyncStream[BetaRawMessageStreamEvent],
|
stream: AsyncStream[BetaRawMessageStreamEvent],
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
prev_message_type = None
|
prev_message_type = None
|
||||||
message_index = 0
|
message_index = 0
|
||||||
@@ -803,7 +803,7 @@ class SimpleAnthropicStreamingInterface:
|
|||||||
async def _process_event(
|
async def _process_event(
|
||||||
self,
|
self,
|
||||||
event: BetaRawMessageStreamEvent,
|
event: BetaRawMessageStreamEvent,
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
prev_message_type: Optional[str] = None,
|
prev_message_type: Optional[str] = None,
|
||||||
message_index: int = 0,
|
message_index: int = 0,
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ class SimpleGeminiStreamingInterface:
|
|||||||
"""Return all finalized tool calls collected during this message (parallel supported)."""
|
"""Return all finalized tool calls collected during this message (parallel supported)."""
|
||||||
return list(self.collected_tool_calls)
|
return list(self.collected_tool_calls)
|
||||||
|
|
||||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||||
"""Extract usage statistics from accumulated streaming data.
|
"""Extract usage statistics from accumulated streaming data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -148,7 +148,7 @@ class SimpleGeminiStreamingInterface:
|
|||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
stream: AsyncIterator[GenerateContentResponse],
|
stream: AsyncIterator[GenerateContentResponse],
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
"""
|
"""
|
||||||
Iterates over the Gemini stream, yielding SSE events.
|
Iterates over the Gemini stream, yielding SSE events.
|
||||||
@@ -202,7 +202,7 @@ class SimpleGeminiStreamingInterface:
|
|||||||
async def _process_event(
|
async def _process_event(
|
||||||
self,
|
self,
|
||||||
event: GenerateContentResponse,
|
event: GenerateContentResponse,
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
prev_message_type: Optional[str] = None,
|
prev_message_type: Optional[str] = None,
|
||||||
message_index: int = 0,
|
message_index: int = 0,
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ class OpenAIStreamingInterface:
|
|||||||
function=FunctionCall(arguments=self._get_current_function_arguments(), name=function_name),
|
function=FunctionCall(arguments=self._get_current_function_arguments(), name=function_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||||
"""Extract usage statistics from accumulated streaming data.
|
"""Extract usage statistics from accumulated streaming data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -219,7 +219,7 @@ class OpenAIStreamingInterface:
|
|||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
stream: AsyncStream[ChatCompletionChunk],
|
stream: AsyncStream[ChatCompletionChunk],
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
"""
|
"""
|
||||||
Iterates over the OpenAI stream, yielding SSE events.
|
Iterates over the OpenAI stream, yielding SSE events.
|
||||||
@@ -307,7 +307,7 @@ class OpenAIStreamingInterface:
|
|||||||
async def _process_chunk(
|
async def _process_chunk(
|
||||||
self,
|
self,
|
||||||
chunk: ChatCompletionChunk,
|
chunk: ChatCompletionChunk,
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
prev_message_type: Optional[str] = None,
|
prev_message_type: Optional[str] = None,
|
||||||
message_index: int = 0,
|
message_index: int = 0,
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
@@ -694,7 +694,7 @@ class SimpleOpenAIStreamingInterface:
|
|||||||
raise ValueError("No tool calls available")
|
raise ValueError("No tool calls available")
|
||||||
return calls[0]
|
return calls[0]
|
||||||
|
|
||||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||||
"""Extract usage statistics from accumulated streaming data.
|
"""Extract usage statistics from accumulated streaming data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -719,7 +719,7 @@ class SimpleOpenAIStreamingInterface:
|
|||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
stream: AsyncStream[ChatCompletionChunk],
|
stream: AsyncStream[ChatCompletionChunk],
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
"""
|
"""
|
||||||
Iterates over the OpenAI stream, yielding SSE events.
|
Iterates over the OpenAI stream, yielding SSE events.
|
||||||
@@ -833,7 +833,7 @@ class SimpleOpenAIStreamingInterface:
|
|||||||
async def _process_chunk(
|
async def _process_chunk(
|
||||||
self,
|
self,
|
||||||
chunk: ChatCompletionChunk,
|
chunk: ChatCompletionChunk,
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
prev_message_type: Optional[str] = None,
|
prev_message_type: Optional[str] = None,
|
||||||
message_index: int = 0,
|
message_index: int = 0,
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
@@ -1120,7 +1120,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
|||||||
raise ValueError("No tool calls available")
|
raise ValueError("No tool calls available")
|
||||||
return calls[0]
|
return calls[0]
|
||||||
|
|
||||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||||
"""Extract usage statistics from accumulated streaming data.
|
"""Extract usage statistics from accumulated streaming data.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1141,7 +1141,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
|||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
stream: AsyncStream[ResponseStreamEvent],
|
stream: AsyncStream[ResponseStreamEvent],
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
"""
|
"""
|
||||||
Iterates over the OpenAI stream, yielding SSE events.
|
Iterates over the OpenAI stream, yielding SSE events.
|
||||||
@@ -1227,7 +1227,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
|||||||
async def _process_event(
|
async def _process_event(
|
||||||
self,
|
self,
|
||||||
event: ResponseStreamEvent,
|
event: ResponseStreamEvent,
|
||||||
ttft_span: Optional["Span"] = None,
|
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||||
prev_message_type: Optional[str] = None,
|
prev_message_type: Optional[str] = None,
|
||||||
message_index: int = 0,
|
message_index: int = 0,
|
||||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||||
|
|||||||
@@ -49,9 +49,7 @@ from letta.schemas.openai.chat_completion_response import (
|
|||||||
FunctionCall,
|
FunctionCall,
|
||||||
Message as ChoiceMessage,
|
Message as ChoiceMessage,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
UsageStatistics,
|
|
||||||
)
|
)
|
||||||
from letta.schemas.response_format import JsonSchemaResponseFormat
|
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
from letta.settings import model_settings
|
from letta.settings import model_settings
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,9 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|
||||||
from openai.types.responses import (
|
from openai.types.responses import (
|
||||||
Response,
|
Response,
|
||||||
ResponseCompletedEvent,
|
ResponseCompletedEvent,
|
||||||
@@ -50,11 +49,6 @@ from letta.schemas.llm_config import LLMConfig
|
|||||||
from letta.schemas.message import Message as PydanticMessage
|
from letta.schemas.message import Message as PydanticMessage
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import (
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
Choice,
|
|
||||||
FunctionCall,
|
|
||||||
Message as ChoiceMessage,
|
|
||||||
ToolCall,
|
|
||||||
UsageStatistics,
|
|
||||||
)
|
)
|
||||||
from letta.schemas.providers.chatgpt_oauth import ChatGPTOAuthCredentials, ChatGPTOAuthProvider
|
from letta.schemas.providers.chatgpt_oauth import ChatGPTOAuthCredentials, ChatGPTOAuthProvider
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
|
|||||||
@@ -1,23 +1,17 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, List, Optional, Union
|
from typing import List, Optional
|
||||||
|
|
||||||
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
|
||||||
from letta.helpers.json_helpers import json_dumps
|
from letta.helpers.json_helpers import json_dumps
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.schemas.message import Message
|
|
||||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
|
||||||
from letta.schemas.response_format import (
|
from letta.schemas.response_format import (
|
||||||
JsonObjectResponseFormat,
|
JsonObjectResponseFormat,
|
||||||
JsonSchemaResponseFormat,
|
JsonSchemaResponseFormat,
|
||||||
ResponseFormatType,
|
|
||||||
ResponseFormatUnion,
|
ResponseFormatUnion,
|
||||||
TextResponseFormat,
|
TextResponseFormat,
|
||||||
)
|
)
|
||||||
from letta.settings import summarizer_settings
|
|
||||||
from letta.utils import printd
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -48,7 +48,6 @@ from letta.schemas.message import Message as PydanticMessage
|
|||||||
from letta.schemas.openai.chat_completion_request import (
|
from letta.schemas.openai.chat_completion_request import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
FunctionCall as ToolFunctionChoiceFunctionCall,
|
FunctionCall as ToolFunctionChoiceFunctionCall,
|
||||||
FunctionSchema,
|
|
||||||
Tool as OpenAITool,
|
Tool as OpenAITool,
|
||||||
ToolFunctionChoice,
|
ToolFunctionChoice,
|
||||||
cast_message_to_subtype,
|
cast_message_to_subtype,
|
||||||
@@ -59,7 +58,6 @@ from letta.schemas.openai.chat_completion_response import (
|
|||||||
FunctionCall,
|
FunctionCall,
|
||||||
Message as ChoiceMessage,
|
Message as ChoiceMessage,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
UsageStatistics,
|
|
||||||
)
|
)
|
||||||
from letta.schemas.openai.responses_request import ResponsesRequest
|
from letta.schemas.openai.responses_request import ResponsesRequest
|
||||||
from letta.schemas.response_format import JsonSchemaResponseFormat
|
from letta.schemas.response_format import JsonSchemaResponseFormat
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ The OpenAI-compatible endpoint only returns token strings, not IDs, making it
|
|||||||
impossible to accurately reconstruct the token sequence for training.
|
impossible to accurately reconstruct the token sequence for training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -20,18 +20,18 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
class SGLangNativeClient:
|
class SGLangNativeClient:
|
||||||
"""Client for SGLang's native /generate endpoint.
|
"""Client for SGLang's native /generate endpoint.
|
||||||
|
|
||||||
Unlike the OpenAI-compatible endpoint, this returns:
|
Unlike the OpenAI-compatible endpoint, this returns:
|
||||||
- output_ids: List of token IDs
|
- output_ids: List of token IDs
|
||||||
- output_token_logprobs: List of [logprob, token_id, top_logprob] tuples
|
- output_token_logprobs: List of [logprob, token_id, top_logprob] tuples
|
||||||
|
|
||||||
This is essential for RL training where we need exact token IDs, not re-tokenized text.
|
This is essential for RL training where we need exact token IDs, not re-tokenized text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, base_url: str, api_key: Optional[str] = None):
|
def __init__(self, base_url: str, api_key: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Initialize the SGLang native client.
|
Initialize the SGLang native client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
base_url: Base URL for SGLang server (e.g., http://localhost:30000)
|
base_url: Base URL for SGLang server (e.g., http://localhost:30000)
|
||||||
api_key: Optional API key for authentication
|
api_key: Optional API key for authentication
|
||||||
@@ -41,7 +41,7 @@ class SGLangNativeClient:
|
|||||||
if self.base_url.endswith("/v1"):
|
if self.base_url.endswith("/v1"):
|
||||||
self.base_url = self.base_url[:-3]
|
self.base_url = self.base_url[:-3]
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -50,19 +50,19 @@ class SGLangNativeClient:
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Call SGLang's native /generate endpoint.
|
Call SGLang's native /generate endpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: The formatted prompt text (with chat template applied)
|
text: The formatted prompt text (with chat template applied)
|
||||||
sampling_params: Sampling parameters (temperature, max_new_tokens, etc.)
|
sampling_params: Sampling parameters (temperature, max_new_tokens, etc.)
|
||||||
return_logprob: Whether to return logprobs (default True for RL training)
|
return_logprob: Whether to return logprobs (default True for RL training)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Response dict with:
|
Response dict with:
|
||||||
- text: Generated text
|
- text: Generated text
|
||||||
- output_ids: List of token IDs
|
- output_ids: List of token IDs
|
||||||
- output_token_logprobs: List of [logprob, token_id, top_logprob] tuples
|
- output_token_logprobs: List of [logprob, token_id, top_logprob] tuples
|
||||||
- meta_info: Metadata including finish_reason, prompt_tokens, etc.
|
- meta_info: Metadata including finish_reason, prompt_tokens, etc.
|
||||||
|
|
||||||
Example response:
|
Example response:
|
||||||
{
|
{
|
||||||
"text": "Hello! How can I help?",
|
"text": "Hello! How can I help?",
|
||||||
@@ -82,13 +82,13 @@ class SGLangNativeClient:
|
|||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
if self.api_key:
|
if self.api_key:
|
||||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"text": text,
|
"text": text,
|
||||||
"sampling_params": sampling_params or {},
|
"sampling_params": sampling_params or {},
|
||||||
"return_logprob": return_logprob,
|
"return_logprob": return_logprob,
|
||||||
}
|
}
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=300.0) as client:
|
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/generate",
|
f"{self.base_url}/generate",
|
||||||
@@ -97,7 +97,7 @@ class SGLangNativeClient:
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
async def health_check(self) -> bool:
|
async def health_check(self) -> bool:
|
||||||
"""Check if the SGLang server is healthy."""
|
"""Check if the SGLang server is healthy."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
# (settings.py imports from this module indirectly through log.py)
|
# (settings.py imports from this module indirectly through log.py)
|
||||||
# Import this here to avoid circular dependency at module level
|
# Import this here to avoid circular dependency at module level
|
||||||
from letta.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper
|
from letta.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper
|
||||||
from letta.settings import DEFAULT_WRAPPER_NAME, INNER_THOUGHTS_KWARG
|
from letta.settings import INNER_THOUGHTS_KWARG
|
||||||
|
|
||||||
DEFAULT_WRAPPER = ChatMLInnerMonologueWrapper
|
DEFAULT_WRAPPER = ChatMLInnerMonologueWrapper
|
||||||
INNER_THOUGHTS_KWARG_VERTEX = "thinking"
|
INNER_THOUGHTS_KWARG_VERTEX = "thinking"
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def get_completions_settings(defaults="simple") -> dict:
|
|||||||
with open(settings_file, "r", encoding="utf-8") as file:
|
with open(settings_file, "r", encoding="utf-8") as file:
|
||||||
user_settings = json.load(file)
|
user_settings = json.load(file)
|
||||||
if len(user_settings) > 0:
|
if len(user_settings) > 0:
|
||||||
printd(f"Updating base settings with the following user settings:\n{json_dumps(user_settings, indent=2)}")
|
printd(f"Updating base settings with the following user settings:\n{json.dumps(user_settings, indent=2)}")
|
||||||
settings.update(user_settings)
|
settings.update(user_settings)
|
||||||
else:
|
else:
|
||||||
printd(f"'{settings_file}' was empty, ignoring...")
|
printd(f"'{settings_file}' was empty, ignoring...")
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from letta.cli.cli import server
|
from letta.cli.cli import server
|
||||||
|
|||||||
@@ -1,44 +1,48 @@
|
|||||||
from letta.orm.agent import Agent
|
from letta.orm.agent import Agent as Agent
|
||||||
from letta.orm.agents_tags import AgentsTags
|
from letta.orm.agents_tags import AgentsTags as AgentsTags
|
||||||
from letta.orm.archive import Archive
|
from letta.orm.archive import Archive as Archive
|
||||||
from letta.orm.archives_agents import ArchivesAgents
|
from letta.orm.archives_agents import ArchivesAgents as ArchivesAgents
|
||||||
from letta.orm.base import Base
|
from letta.orm.base import Base as Base
|
||||||
from letta.orm.block import Block
|
from letta.orm.block import Block as Block
|
||||||
from letta.orm.block_history import BlockHistory
|
from letta.orm.block_history import BlockHistory as BlockHistory
|
||||||
from letta.orm.blocks_agents import BlocksAgents
|
from letta.orm.blocks_agents import BlocksAgents as BlocksAgents
|
||||||
from letta.orm.blocks_conversations import BlocksConversations
|
from letta.orm.blocks_conversations import BlocksConversations as BlocksConversations
|
||||||
from letta.orm.blocks_tags import BlocksTags
|
from letta.orm.blocks_tags import BlocksTags as BlocksTags
|
||||||
from letta.orm.conversation import Conversation
|
from letta.orm.conversation import Conversation as Conversation
|
||||||
from letta.orm.conversation_messages import ConversationMessage
|
from letta.orm.conversation_messages import ConversationMessage as ConversationMessage
|
||||||
from letta.orm.file import FileMetadata
|
from letta.orm.file import FileMetadata as FileMetadata
|
||||||
from letta.orm.files_agents import FileAgent
|
from letta.orm.files_agents import FileAgent as FileAgent
|
||||||
from letta.orm.group import Group
|
from letta.orm.group import Group as Group
|
||||||
from letta.orm.groups_agents import GroupsAgents
|
from letta.orm.groups_agents import GroupsAgents as GroupsAgents
|
||||||
from letta.orm.groups_blocks import GroupsBlocks
|
from letta.orm.groups_blocks import GroupsBlocks as GroupsBlocks
|
||||||
from letta.orm.identities_agents import IdentitiesAgents
|
from letta.orm.identities_agents import IdentitiesAgents as IdentitiesAgents
|
||||||
from letta.orm.identities_blocks import IdentitiesBlocks
|
from letta.orm.identities_blocks import IdentitiesBlocks as IdentitiesBlocks
|
||||||
from letta.orm.identity import Identity
|
from letta.orm.identity import Identity as Identity
|
||||||
from letta.orm.job import Job
|
from letta.orm.job import Job as Job
|
||||||
from letta.orm.llm_batch_items import LLMBatchItem
|
from letta.orm.llm_batch_items import LLMBatchItem as LLMBatchItem
|
||||||
from letta.orm.llm_batch_job import LLMBatchJob
|
from letta.orm.llm_batch_job import LLMBatchJob as LLMBatchJob
|
||||||
from letta.orm.mcp_oauth import MCPOAuth
|
from letta.orm.mcp_oauth import MCPOAuth as MCPOAuth
|
||||||
from letta.orm.mcp_server import MCPServer
|
from letta.orm.mcp_server import MCPServer as MCPServer
|
||||||
from letta.orm.message import Message
|
from letta.orm.message import Message as Message
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization as Organization
|
||||||
from letta.orm.passage import ArchivalPassage, BasePassage, SourcePassage
|
from letta.orm.passage import ArchivalPassage as ArchivalPassage, BasePassage as BasePassage, SourcePassage as SourcePassage
|
||||||
from letta.orm.passage_tag import PassageTag
|
from letta.orm.passage_tag import PassageTag as PassageTag
|
||||||
from letta.orm.prompt import Prompt
|
from letta.orm.prompt import Prompt as Prompt
|
||||||
from letta.orm.provider import Provider
|
from letta.orm.provider import Provider as Provider
|
||||||
from letta.orm.provider_model import ProviderModel
|
from letta.orm.provider_model import ProviderModel as ProviderModel
|
||||||
from letta.orm.provider_trace import ProviderTrace
|
from letta.orm.provider_trace import ProviderTrace as ProviderTrace
|
||||||
from letta.orm.provider_trace_metadata import ProviderTraceMetadata
|
from letta.orm.provider_trace_metadata import ProviderTraceMetadata as ProviderTraceMetadata
|
||||||
from letta.orm.run import Run
|
from letta.orm.run import Run as Run
|
||||||
from letta.orm.run_metrics import RunMetrics
|
from letta.orm.run_metrics import RunMetrics as RunMetrics
|
||||||
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
from letta.orm.sandbox_config import (
|
||||||
from letta.orm.source import Source
|
AgentEnvironmentVariable as AgentEnvironmentVariable,
|
||||||
from letta.orm.sources_agents import SourcesAgents
|
SandboxConfig as SandboxConfig,
|
||||||
from letta.orm.step import Step
|
SandboxEnvironmentVariable as SandboxEnvironmentVariable,
|
||||||
from letta.orm.step_metrics import StepMetrics
|
)
|
||||||
from letta.orm.tool import Tool
|
from letta.orm.source import Source as Source
|
||||||
from letta.orm.tools_agents import ToolsAgents
|
from letta.orm.sources_agents import SourcesAgents as SourcesAgents
|
||||||
from letta.orm.user import User
|
from letta.orm.step import Step as Step
|
||||||
|
from letta.orm.step_metrics import StepMetrics as StepMetrics
|
||||||
|
from letta.orm.tool import Tool as Tool
|
||||||
|
from letta.orm.tools_agents import ToolsAgents as ToolsAgents
|
||||||
|
from letta.orm.user import User as User
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ if TYPE_CHECKING:
|
|||||||
from letta.orm.run import Run
|
from letta.orm.run import Run
|
||||||
from letta.orm.source import Source
|
from letta.orm.source import Source
|
||||||
from letta.orm.tool import Tool
|
from letta.orm.tool import Tool
|
||||||
from letta.services.summarizer.summarizer_config import CompactionSettings
|
|
||||||
|
|
||||||
|
|
||||||
class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin, AsyncAttrs):
|
class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin, AsyncAttrs):
|
||||||
@@ -123,7 +122,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
|||||||
|
|
||||||
# relationships
|
# relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents", lazy="raise")
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents", lazy="raise")
|
||||||
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
|
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship( # noqa: F821
|
||||||
"AgentEnvironmentVariable",
|
"AgentEnvironmentVariable",
|
||||||
back_populates="agent",
|
back_populates="agent",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
@@ -161,14 +160,14 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
|||||||
back_populates="agents",
|
back_populates="agents",
|
||||||
passive_deletes=True,
|
passive_deletes=True,
|
||||||
)
|
)
|
||||||
groups: Mapped[List["Group"]] = relationship(
|
groups: Mapped[List["Group"]] = relationship( # noqa: F821
|
||||||
"Group",
|
"Group",
|
||||||
secondary="groups_agents",
|
secondary="groups_agents",
|
||||||
lazy="raise",
|
lazy="raise",
|
||||||
back_populates="agents",
|
back_populates="agents",
|
||||||
passive_deletes=True,
|
passive_deletes=True,
|
||||||
)
|
)
|
||||||
multi_agent_group: Mapped["Group"] = relationship(
|
multi_agent_group: Mapped["Group"] = relationship( # noqa: F821
|
||||||
"Group",
|
"Group",
|
||||||
lazy="selectin",
|
lazy="selectin",
|
||||||
viewonly=True,
|
viewonly=True,
|
||||||
@@ -176,7 +175,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
|||||||
foreign_keys="[Group.manager_agent_id]",
|
foreign_keys="[Group.manager_agent_id]",
|
||||||
uselist=False,
|
uselist=False,
|
||||||
)
|
)
|
||||||
batch_items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="agent", lazy="raise")
|
batch_items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="agent", lazy="raise") # noqa: F821
|
||||||
file_agents: Mapped[List["FileAgent"]] = relationship(
|
file_agents: Mapped[List["FileAgent"]] = relationship(
|
||||||
"FileAgent",
|
"FileAgent",
|
||||||
back_populates="agent",
|
back_populates="agent",
|
||||||
|
|||||||
@@ -21,4 +21,4 @@ class AgentsTags(Base):
|
|||||||
tag: Mapped[str] = mapped_column(String, doc="The name of the tag associated with the agent.", primary_key=True)
|
tag: Mapped[str] = mapped_column(String, doc="The name of the tag associated with the agent.", primary_key=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="tags")
|
agent: Mapped["Agent"] = relationship("Agent", back_populates="tags") # noqa: F821
|
||||||
|
|||||||
@@ -23,5 +23,5 @@ class ArchivesAgents(Base):
|
|||||||
is_owner: Mapped[bool] = mapped_column(Boolean, default=False, doc="Whether this agent created/owns the archive")
|
is_owner: Mapped[bool] = mapped_column(Boolean, default=False, doc="Whether this agent created/owns the archive")
|
||||||
|
|
||||||
# relationships
|
# relationships
|
||||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="archives_agents")
|
agent: Mapped["Agent"] = relationship("Agent", back_populates="archives_agents") # noqa: F821
|
||||||
archive: Mapped["Archive"] = relationship("Archive", back_populates="archives_agents")
|
archive: Mapped["Archive"] = relationship("Archive", back_populates="archives_agents") # noqa: F821
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
|
|||||||
|
|
||||||
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
|
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
|
||||||
from letta.orm.block_history import BlockHistory
|
from letta.orm.block_history import BlockHistory
|
||||||
from letta.orm.blocks_agents import BlocksAgents
|
|
||||||
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin
|
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateEntityMixin, TemplateMixin
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
from letta.schemas.block import Block as PydanticBlock, Human, Persona
|
from letta.schemas.block import Block as PydanticBlock, Human, Persona
|
||||||
@@ -61,7 +60,7 @@ class Block(OrganizationMixin, SqlalchemyBase, ProjectMixin, TemplateEntityMixin
|
|||||||
|
|
||||||
# relationships
|
# relationships
|
||||||
organization: Mapped[Optional["Organization"]] = relationship("Organization", lazy="raise")
|
organization: Mapped[Optional["Organization"]] = relationship("Organization", lazy="raise")
|
||||||
agents: Mapped[List["Agent"]] = relationship(
|
agents: Mapped[List["Agent"]] = relationship( # noqa: F821
|
||||||
"Agent",
|
"Agent",
|
||||||
secondary="blocks_agents",
|
secondary="blocks_agents",
|
||||||
lazy="raise",
|
lazy="raise",
|
||||||
@@ -76,7 +75,7 @@ class Block(OrganizationMixin, SqlalchemyBase, ProjectMixin, TemplateEntityMixin
|
|||||||
back_populates="blocks",
|
back_populates="blocks",
|
||||||
passive_deletes=True,
|
passive_deletes=True,
|
||||||
)
|
)
|
||||||
groups: Mapped[List["Group"]] = relationship(
|
groups: Mapped[List["Group"]] = relationship( # noqa: F821
|
||||||
"Group",
|
"Group",
|
||||||
secondary="groups_blocks",
|
secondary="groups_blocks",
|
||||||
lazy="raise",
|
lazy="raise",
|
||||||
|
|||||||
@@ -34,4 +34,4 @@ class BlocksTags(Base):
|
|||||||
_last_updated_by_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
_last_updated_by_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
block: Mapped["Block"] = relationship("Block", back_populates="tags")
|
block: Mapped["Block"] = relationship("Block", back_populates="tags") # noqa: F821
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# relationships
|
# relationships
|
||||||
agent: Mapped["Agent"] = relationship(
|
agent: Mapped["Agent"] = relationship( # noqa: F821
|
||||||
"Agent",
|
"Agent",
|
||||||
back_populates="file_agents",
|
back_populates="file_agents",
|
||||||
lazy="selectin",
|
lazy="selectin",
|
||||||
|
|||||||
@@ -27,12 +27,12 @@ class Group(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
|
|||||||
hidden: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="If set to True, the group will be hidden.")
|
hidden: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="If set to True, the group will be hidden.")
|
||||||
|
|
||||||
# relationships
|
# relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="groups")
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="groups") # noqa: F821
|
||||||
agent_ids: Mapped[List[str]] = mapped_column(JSON, nullable=False, doc="Ordered list of agent IDs in this group")
|
agent_ids: Mapped[List[str]] = mapped_column(JSON, nullable=False, doc="Ordered list of agent IDs in this group")
|
||||||
agents: Mapped[List["Agent"]] = relationship(
|
agents: Mapped[List["Agent"]] = relationship( # noqa: F821
|
||||||
"Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups"
|
"Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups"
|
||||||
)
|
)
|
||||||
shared_blocks: Mapped[List["Block"]] = relationship(
|
shared_blocks: Mapped[List["Block"]] = relationship( # noqa: F821
|
||||||
"Block", secondary="groups_blocks", lazy="selectin", passive_deletes=True, back_populates="groups"
|
"Block", secondary="groups_blocks", lazy="selectin", passive_deletes=True, back_populates="groups"
|
||||||
)
|
)
|
||||||
manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group")
|
manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group") # noqa: F821
|
||||||
|
|||||||
@@ -36,11 +36,11 @@ class Identity(SqlalchemyBase, OrganizationMixin, ProjectMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# relationships
|
# relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="identities")
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="identities") # noqa: F821
|
||||||
agents: Mapped[List["Agent"]] = relationship(
|
agents: Mapped[List["Agent"]] = relationship( # noqa: F821
|
||||||
"Agent", secondary="identities_agents", lazy="selectin", passive_deletes=True, back_populates="identities"
|
"Agent", secondary="identities_agents", lazy="selectin", passive_deletes=True, back_populates="identities"
|
||||||
)
|
)
|
||||||
blocks: Mapped[List["Block"]] = relationship(
|
blocks: Mapped[List["Block"]] = relationship( # noqa: F821
|
||||||
"Block", secondary="identities_blocks", lazy="selectin", passive_deletes=True, back_populates="identities"
|
"Block", secondary="identities_blocks", lazy="selectin", passive_deletes=True, back_populates="identities"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,6 @@ class LLMBatchItem(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# relationships
|
# relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_items")
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_items") # noqa: F821
|
||||||
batch: Mapped["LLMBatchJob"] = relationship("LLMBatchJob", back_populates="items", lazy="selectin")
|
batch: Mapped["LLMBatchJob"] = relationship("LLMBatchJob", back_populates="items", lazy="selectin") # noqa: F821
|
||||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="batch_items", lazy="selectin")
|
agent: Mapped["Agent"] = relationship("Agent", back_populates="batch_items", lazy="selectin") # noqa: F821
|
||||||
|
|||||||
@@ -47,5 +47,5 @@ class LLMBatchJob(SqlalchemyBase, OrganizationMixin):
|
|||||||
String, ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the Letta batch job"
|
String, ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the Letta batch job"
|
||||||
)
|
)
|
||||||
|
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_jobs")
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_jobs") # noqa: F821
|
||||||
items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="batch", lazy="selectin")
|
items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="batch", lazy="selectin") # noqa: F821
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from sqlalchemy import JSON, String, Text, UniqueConstraint
|
from sqlalchemy import JSON, String, Text, UniqueConstraint
|
||||||
|
|||||||
@@ -83,12 +83,12 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="raise")
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="raise") # noqa: F821
|
||||||
step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin")
|
step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin") # noqa: F821
|
||||||
run: Mapped["Run"] = relationship("Run", back_populates="messages", lazy="selectin")
|
run: Mapped["Run"] = relationship("Run", back_populates="messages", lazy="selectin") # noqa: F821
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def job(self) -> Optional["Job"]:
|
def job(self) -> Optional["Job"]: # noqa: F821
|
||||||
"""Get the job associated with this message, if any."""
|
"""Get the job associated with this message, if any."""
|
||||||
return self.job_message.job if self.job_message else None
|
return self.job_message.job if self.job_message else None
|
||||||
|
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ class ArchivalPassage(BasePassage, ArchiveMixin):
|
|||||||
__tablename__ = "archival_passages"
|
__tablename__ = "archival_passages"
|
||||||
|
|
||||||
# junction table for efficient tag queries (complements json column above)
|
# junction table for efficient tag queries (complements json column above)
|
||||||
passage_tags: Mapped[List["PassageTag"]] = relationship(
|
passage_tags: Mapped[List["PassageTag"]] = relationship( # noqa: F821
|
||||||
"PassageTag", back_populates="passage", cascade="all, delete-orphan", lazy="noload"
|
"PassageTag", back_populates="passage", cascade="all, delete-orphan", lazy="noload"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -43,4 +43,4 @@ class ProviderTrace(SqlalchemyBase, OrganizationMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")
|
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin") # noqa: F821
|
||||||
|
|||||||
@@ -42,4 +42,4 @@ class ProviderTraceMetadata(SqlalchemyBase, OrganizationMixin):
|
|||||||
user_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the user who initiated the request")
|
user_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the user who initiated the request")
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")
|
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin") # noqa: F821
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import uuid
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from sqlalchemy import JSON, BigInteger, Boolean, DateTime, ForeignKey, Index, String
|
from sqlalchemy import JSON, BigInteger, Boolean, ForeignKey, Index, String
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateMixin
|
from letta.orm.mixins import OrganizationMixin, ProjectMixin, TemplateMixin
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from sqlalchemy import JSON, BigInteger, ForeignKey, Integer, String
|
from sqlalchemy import JSON, BigInteger, ForeignKey, Integer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship
|
||||||
|
|
||||||
@@ -14,7 +14,6 @@ from letta.settings import DatabaseChoice, settings
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.orm.agent import Agent
|
from letta.orm.agent import Agent
|
||||||
from letta.orm.run import Run
|
from letta.orm.run import Run
|
||||||
from letta.orm.step import Step
|
|
||||||
|
|
||||||
|
|
||||||
class RunMetrics(SqlalchemyBase, ProjectMixin, AgentMixin, OrganizationMixin, TemplateMixin):
|
class RunMetrics(SqlalchemyBase, ProjectMixin, AgentMixin, OrganizationMixin, TemplateMixin):
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
query_text: Optional[str] = None,
|
query_text: Optional[str] = None,
|
||||||
query_embedding: Optional[List[float]] = None,
|
query_embedding: Optional[List[float]] = None,
|
||||||
ascending: bool = True,
|
ascending: bool = True,
|
||||||
actor: Optional["User"] = None,
|
actor: Optional["User"] = None, # noqa: F821
|
||||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||||
access_type: AccessType = AccessType.ORGANIZATION,
|
access_type: AccessType = AccessType.ORGANIZATION,
|
||||||
join_model: Optional[Base] = None,
|
join_model: Optional[Base] = None,
|
||||||
@@ -222,7 +222,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
query_text: Optional[str] = None,
|
query_text: Optional[str] = None,
|
||||||
query_embedding: Optional[List[float]] = None,
|
query_embedding: Optional[List[float]] = None,
|
||||||
ascending: bool = True,
|
ascending: bool = True,
|
||||||
actor: Optional["User"] = None,
|
actor: Optional["User"] = None, # noqa: F821
|
||||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||||
access_type: AccessType = AccessType.ORGANIZATION,
|
access_type: AccessType = AccessType.ORGANIZATION,
|
||||||
join_model: Optional[Base] = None,
|
join_model: Optional[Base] = None,
|
||||||
@@ -415,7 +415,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
cls,
|
cls,
|
||||||
db_session: "AsyncSession",
|
db_session: "AsyncSession",
|
||||||
identifier: Optional[str] = None,
|
identifier: Optional[str] = None,
|
||||||
actor: Optional["User"] = None,
|
actor: Optional["User"] = None, # noqa: F821
|
||||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||||
access_type: AccessType = AccessType.ORGANIZATION,
|
access_type: AccessType = AccessType.ORGANIZATION,
|
||||||
check_is_deleted: bool = False,
|
check_is_deleted: bool = False,
|
||||||
@@ -451,7 +451,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
cls,
|
cls,
|
||||||
db_session: "AsyncSession",
|
db_session: "AsyncSession",
|
||||||
identifiers: List[str] = [],
|
identifiers: List[str] = [],
|
||||||
actor: Optional["User"] = None,
|
actor: Optional["User"] = None, # noqa: F821
|
||||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||||
access_type: AccessType = AccessType.ORGANIZATION,
|
access_type: AccessType = AccessType.ORGANIZATION,
|
||||||
check_is_deleted: bool = False,
|
check_is_deleted: bool = False,
|
||||||
@@ -471,7 +471,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
def _read_multiple_preprocess(
|
def _read_multiple_preprocess(
|
||||||
cls,
|
cls,
|
||||||
identifiers: List[str],
|
identifiers: List[str],
|
||||||
actor: Optional["User"],
|
actor: Optional["User"], # noqa: F821
|
||||||
access: Optional[List[Literal["read", "write", "admin"]]],
|
access: Optional[List[Literal["read", "write", "admin"]]],
|
||||||
access_type: AccessType,
|
access_type: AccessType,
|
||||||
check_is_deleted: bool,
|
check_is_deleted: bool,
|
||||||
@@ -543,7 +543,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
async def create_async(
|
async def create_async(
|
||||||
self,
|
self,
|
||||||
db_session: "AsyncSession",
|
db_session: "AsyncSession",
|
||||||
actor: Optional["User"] = None,
|
actor: Optional["User"] = None, # noqa: F821
|
||||||
no_commit: bool = False,
|
no_commit: bool = False,
|
||||||
no_refresh: bool = False,
|
no_refresh: bool = False,
|
||||||
ignore_conflicts: bool = False,
|
ignore_conflicts: bool = False,
|
||||||
@@ -599,7 +599,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
cls,
|
cls,
|
||||||
items: List["SqlalchemyBase"],
|
items: List["SqlalchemyBase"],
|
||||||
db_session: "AsyncSession",
|
db_session: "AsyncSession",
|
||||||
actor: Optional["User"] = None,
|
actor: Optional["User"] = None, # noqa: F821
|
||||||
no_commit: bool = False,
|
no_commit: bool = False,
|
||||||
no_refresh: bool = False,
|
no_refresh: bool = False,
|
||||||
) -> List["SqlalchemyBase"]:
|
) -> List["SqlalchemyBase"]:
|
||||||
@@ -654,7 +654,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
cls._handle_dbapi_error(e)
|
cls._handle_dbapi_error(e)
|
||||||
|
|
||||||
@handle_db_timeout
|
@handle_db_timeout
|
||||||
async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase": # noqa: F821
|
||||||
"""Soft delete a record asynchronously (mark as deleted)."""
|
"""Soft delete a record asynchronously (mark as deleted)."""
|
||||||
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)")
|
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)")
|
||||||
|
|
||||||
@@ -665,7 +665,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
return await self.update_async(db_session)
|
return await self.update_async(db_session)
|
||||||
|
|
||||||
@handle_db_timeout
|
@handle_db_timeout
|
||||||
async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None:
|
async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None: # noqa: F821
|
||||||
"""Permanently removes the record from the database asynchronously."""
|
"""Permanently removes the record from the database asynchronously."""
|
||||||
obj_id = self.id
|
obj_id = self.id
|
||||||
obj_class = self.__class__.__name__
|
obj_class = self.__class__.__name__
|
||||||
@@ -694,7 +694,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
cls,
|
cls,
|
||||||
db_session: "AsyncSession",
|
db_session: "AsyncSession",
|
||||||
identifiers: List[str],
|
identifiers: List[str],
|
||||||
actor: Optional["User"],
|
actor: Optional["User"], # noqa: F821
|
||||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["write"],
|
access: Optional[List[Literal["read", "write", "admin"]]] = ["write"],
|
||||||
access_type: AccessType = AccessType.ORGANIZATION,
|
access_type: AccessType = AccessType.ORGANIZATION,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -729,7 +729,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
|
|
||||||
@handle_db_timeout
|
@handle_db_timeout
|
||||||
async def update_async(
|
async def update_async(
|
||||||
self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False, no_refresh: bool = False
|
self,
|
||||||
|
db_session: "AsyncSession",
|
||||||
|
actor: Optional["User"] = None, # noqa: F821
|
||||||
|
no_commit: bool = False,
|
||||||
|
no_refresh: bool = False,
|
||||||
) -> "SqlalchemyBase":
|
) -> "SqlalchemyBase":
|
||||||
"""Async version of update function"""
|
"""Async version of update function"""
|
||||||
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||||
@@ -774,7 +778,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
db_session: "Session",
|
db_session: "Session",
|
||||||
actor: Optional["User"] = None,
|
actor: Optional["User"] = None, # noqa: F821
|
||||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||||
access_type: AccessType = AccessType.ORGANIZATION,
|
access_type: AccessType = AccessType.ORGANIZATION,
|
||||||
check_is_deleted: bool = False,
|
check_is_deleted: bool = False,
|
||||||
@@ -814,7 +818,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
db_session: "AsyncSession",
|
db_session: "AsyncSession",
|
||||||
actor: Optional["User"] = None,
|
actor: Optional["User"] = None, # noqa: F821
|
||||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||||
access_type: AccessType = AccessType.ORGANIZATION,
|
access_type: AccessType = AccessType.ORGANIZATION,
|
||||||
check_is_deleted: bool = False,
|
check_is_deleted: bool = False,
|
||||||
@@ -850,11 +854,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def apply_access_predicate(
|
def apply_access_predicate(
|
||||||
cls,
|
cls,
|
||||||
query: "Select",
|
query: "Select", # noqa: F821
|
||||||
actor: "User",
|
actor: "User", # noqa: F821
|
||||||
access: List[Literal["read", "write", "admin"]],
|
access: List[Literal["read", "write", "admin"]],
|
||||||
access_type: AccessType = AccessType.ORGANIZATION,
|
access_type: AccessType = AccessType.ORGANIZATION,
|
||||||
) -> "Select":
|
) -> "Select": # noqa: F821
|
||||||
"""applies a WHERE clause restricting results to the given actor and access level
|
"""applies a WHERE clause restricting results to the given actor and access level
|
||||||
Args:
|
Args:
|
||||||
query: The initial sqlalchemy select statement
|
query: The initial sqlalchemy select statement
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import itertools
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import traceback
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,3 @@
|
|||||||
from letta.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
def is_experimental_enabled(feature_name: str, **kwargs) -> bool:
|
def is_experimental_enabled(feature_name: str, **kwargs) -> bool:
|
||||||
# if feature_name in ("async_agent_loop", "summarize"):
|
# if feature_name in ("async_agent_loop", "summarize"):
|
||||||
# if not (kwargs.get("eligibility", False) and settings.use_experimental):
|
# if not (kwargs.get("eligibility", False) and settings.use_experimental):
|
||||||
|
|||||||
@@ -12,6 +12,13 @@ from letta.otel.tracing import trace_method
|
|||||||
from letta.schemas.memory import Memory
|
from letta.schemas.memory import Memory
|
||||||
|
|
||||||
|
|
||||||
|
class PreserveMapping(dict):
    """Mapping that leaves undefined variables untouched.

    Used to preserve (do not modify) undefined variables in the system
    prompt: looking up a key that is not present does not raise
    ``KeyError`` but yields the key wrapped in braces, i.e. the text of
    the placeholder itself. The missing key is not stored in the mapping.

    NOTE(review): presumably passed to ``str.format_map`` when rendering
    the prompt template -- confirm against the call site.
    """

    def __missing__(self, key):
        """Return the placeholder text ``"{key}"`` for an absent ``key``.

        ``dict.__getitem__`` invokes this hook when ``key`` is not found.
        """
        # An f-string formats any key type; plain ``"{" + key + "}"``
        # concatenation raised TypeError for non-string keys.
        return f"{{{key}}}"
|
||||||
|
|
||||||
|
|
||||||
class PromptGenerator:
|
class PromptGenerator:
|
||||||
# TODO: This code is kind of wonky and deserves a rewrite
|
# TODO: This code is kind of wonky and deserves a rewrite
|
||||||
@trace_method
|
@trace_method
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from typing import Dict, List, Literal, Optional
|
|||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||||
|
|
||||||
from letta.constants import (
|
from letta.constants import (
|
||||||
CORE_MEMORY_LINE_NUMBER_WARNING,
|
|
||||||
DEFAULT_EMBEDDING_CHUNK_SIZE,
|
DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||||
MAX_FILES_OPEN_LIMIT,
|
MAX_FILES_OPEN_LIMIT,
|
||||||
MAX_PER_FILE_VIEW_WINDOW_CHAR_LIMIT,
|
MAX_PER_FILE_VIEW_WINDOW_CHAR_LIMIT,
|
||||||
@@ -15,7 +14,6 @@ from letta.schemas.block import Block, CreateBlock
|
|||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.schemas.enums import PrimitiveType
|
from letta.schemas.enums import PrimitiveType
|
||||||
from letta.schemas.environment_variables import AgentEnvironmentVariable
|
from letta.schemas.environment_variables import AgentEnvironmentVariable
|
||||||
from letta.schemas.file import FileStatus
|
|
||||||
from letta.schemas.group import Group
|
from letta.schemas.group import Group
|
||||||
from letta.schemas.identity import Identity
|
from letta.schemas.identity import Identity
|
||||||
from letta.schemas.letta_base import OrmMetadataBase
|
from letta.schemas.letta_base import OrmMetadataBase
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, List, Literal, Optional, Union
|
from typing import Annotated, List, Literal, Optional, Union
|
||||||
|
|
||||||
from openai.types import Reasoning
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,15 +7,13 @@ from typing import Any, List, Literal, Optional, Union
|
|||||||
from pydantic import BaseModel, Field, RootModel
|
from pydantic import BaseModel, Field, RootModel
|
||||||
|
|
||||||
from letta.helpers.json_helpers import json_dumps
|
from letta.helpers.json_helpers import json_dumps
|
||||||
from letta.schemas.enums import JobStatus, MessageStreamStatus
|
from letta.schemas.enums import JobStatus
|
||||||
from letta.schemas.openai.chat_completion_response import ChoiceLogprobs
|
|
||||||
from letta.schemas.letta_message import (
|
from letta.schemas.letta_message import (
|
||||||
ApprovalRequestMessage,
|
ApprovalRequestMessage,
|
||||||
ApprovalResponseMessage,
|
ApprovalResponseMessage,
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
HiddenReasoningMessage,
|
HiddenReasoningMessage,
|
||||||
LettaErrorMessage,
|
LettaErrorMessage,
|
||||||
LettaMessage,
|
|
||||||
LettaMessageUnion,
|
LettaMessageUnion,
|
||||||
LettaPing,
|
LettaPing,
|
||||||
ReasoningMessage,
|
ReasoningMessage,
|
||||||
@@ -26,6 +24,7 @@ from letta.schemas.letta_message import (
|
|||||||
)
|
)
|
||||||
from letta.schemas.letta_stop_reason import LettaStopReason
|
from letta.schemas.letta_stop_reason import LettaStopReason
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
|
from letta.schemas.openai.chat_completion_response import ChoiceLogprobs
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
|
|
||||||
# TODO: consider moving into own file
|
# TODO: consider moving into own file
|
||||||
@@ -33,31 +32,21 @@ from letta.schemas.usage import LettaUsageStatistics
|
|||||||
|
|
||||||
class TurnTokenData(BaseModel):
|
class TurnTokenData(BaseModel):
|
||||||
"""Token data for a single LLM generation turn in a multi-turn agent interaction.
|
"""Token data for a single LLM generation turn in a multi-turn agent interaction.
|
||||||
|
|
||||||
Used for RL training to track token IDs and logprobs across all LLM calls,
|
Used for RL training to track token IDs and logprobs across all LLM calls,
|
||||||
not just the final one. Tool results are included so the client can tokenize
|
not just the final one. Tool results are included so the client can tokenize
|
||||||
them with loss_mask=0 (non-trainable).
|
them with loss_mask=0 (non-trainable).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
role: Literal["assistant", "tool"] = Field(
|
role: Literal["assistant", "tool"] = Field(
|
||||||
...,
|
..., description="Role of this turn: 'assistant' for LLM generations (trainable), 'tool' for tool results (non-trainable)."
|
||||||
description="Role of this turn: 'assistant' for LLM generations (trainable), 'tool' for tool results (non-trainable)."
|
|
||||||
)
|
|
||||||
output_ids: Optional[List[int]] = Field(
|
|
||||||
None,
|
|
||||||
description="Token IDs from SGLang native endpoint. Only present for assistant turns."
|
|
||||||
)
|
)
|
||||||
|
output_ids: Optional[List[int]] = Field(None, description="Token IDs from SGLang native endpoint. Only present for assistant turns.")
|
||||||
output_token_logprobs: Optional[List[List[Any]]] = Field(
|
output_token_logprobs: Optional[List[List[Any]]] = Field(
|
||||||
None,
|
None, description="Logprobs from SGLang: [[logprob, token_id, top_logprob_or_null], ...]. Only present for assistant turns."
|
||||||
description="Logprobs from SGLang: [[logprob, token_id, top_logprob_or_null], ...]. Only present for assistant turns."
|
|
||||||
)
|
|
||||||
content: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description="Text content. For tool turns, client tokenizes this with loss_mask=0."
|
|
||||||
)
|
|
||||||
tool_name: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description="Name of the tool called. Only present for tool turns."
|
|
||||||
)
|
)
|
||||||
|
content: Optional[str] = Field(None, description="Text content. For tool turns, client tokenizes this with loss_mask=0.")
|
||||||
|
tool_name: Optional[str] = Field(None, description="Name of the tool called. Only present for tool turns.")
|
||||||
|
|
||||||
|
|
||||||
class LettaResponse(BaseModel):
|
class LettaResponse(BaseModel):
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Annotated, Literal, Optional, Union
|
from typing import TYPE_CHECKING, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from letta.orm.mcp_oauth import OAuthSessionStatus
|
|||||||
from letta.schemas.enums import PrimitiveType
|
from letta.schemas.enums import PrimitiveType
|
||||||
from letta.schemas.letta_base import LettaBase
|
from letta.schemas.letta_base import LettaBase
|
||||||
from letta.schemas.secret import Secret
|
from letta.schemas.secret import Secret
|
||||||
from letta.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
class BaseMCPServer(LettaBase):
|
class BaseMCPServer(LettaBase):
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@@ -6,12 +5,8 @@ from urllib.parse import urlparse
|
|||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
from letta.functions.mcp_client.types import (
|
from letta.functions.mcp_client.types import (
|
||||||
MCP_AUTH_HEADER_AUTHORIZATION,
|
|
||||||
MCP_AUTH_TOKEN_BEARER_PREFIX,
|
MCP_AUTH_TOKEN_BEARER_PREFIX,
|
||||||
MCPServerType,
|
MCPServerType,
|
||||||
SSEServerConfig,
|
|
||||||
StdioServerConfig,
|
|
||||||
StreamableHTTPServerConfig,
|
|
||||||
)
|
)
|
||||||
from letta.orm.mcp_oauth import OAuthSessionStatus
|
from letta.orm.mcp_oauth import OAuthSessionStatus
|
||||||
from letta.schemas.enums import PrimitiveType
|
from letta.schemas.enums import PrimitiveType
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
|
|
||||||
@@ -535,7 +535,7 @@ class BasicBlockMemory(Memory):
|
|||||||
"""
|
"""
|
||||||
super().__init__(blocks=blocks)
|
super().__init__(blocks=blocks)
|
||||||
|
|
||||||
def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore
|
def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore # noqa: F821
|
||||||
"""
|
"""
|
||||||
Append to the contents of core memory.
|
Append to the contents of core memory.
|
||||||
|
|
||||||
@@ -551,7 +551,7 @@ class BasicBlockMemory(Memory):
|
|||||||
agent_state.memory.update_block_value(label=label, value=new_value)
|
agent_state.memory.update_block_value(label=label, value=new_value)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def core_memory_replace(agent_state: "AgentState", label: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore
|
def core_memory_replace(agent_state: "AgentState", label: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore # noqa: F821
|
||||||
"""
|
"""
|
||||||
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
||||||
|
|
||||||
|
|||||||
@@ -11,10 +11,9 @@ import uuid
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
|
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
|
||||||
from openai.types.responses import ResponseReasoningItem
|
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REQUEST_HEARTBEAT_PARAM, TOOL_CALL_ID_MAX_LEN
|
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REQUEST_HEARTBEAT_PARAM, TOOL_CALL_ID_MAX_LEN
|
||||||
@@ -30,7 +29,6 @@ from letta.schemas.letta_message import (
|
|||||||
ApprovalReturn,
|
ApprovalReturn,
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
AssistantMessageListResult,
|
AssistantMessageListResult,
|
||||||
CompactionStats,
|
|
||||||
HiddenReasoningMessage,
|
HiddenReasoningMessage,
|
||||||
LettaMessage,
|
LettaMessage,
|
||||||
LettaMessageReturnUnion,
|
LettaMessageReturnUnion,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
|
||||||
class SystemMessage(BaseModel):
|
class SystemMessage(BaseModel):
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|||||||
@@ -90,7 +90,6 @@ class Provider(ProviderBase):
|
|||||||
def list_llm_models(self) -> list[LLMConfig]:
|
def list_llm_models(self) -> list[LLMConfig]:
|
||||||
"""List available LLM models (deprecated: use list_llm_models_async)"""
|
"""List available LLM models (deprecated: use list_llm_models_async)"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import warnings
|
|
||||||
|
|
||||||
logger.warning("list_llm_models is deprecated, use list_llm_models_async instead", stacklevel=2)
|
logger.warning("list_llm_models is deprecated, use list_llm_models_async instead", stacklevel=2)
|
||||||
|
|
||||||
@@ -115,7 +114,6 @@ class Provider(ProviderBase):
|
|||||||
def list_embedding_models(self) -> list[EmbeddingConfig]:
|
def list_embedding_models(self) -> list[EmbeddingConfig]:
|
||||||
"""List available embedding models (deprecated: use list_embedding_models_async)"""
|
"""List available embedding models (deprecated: use list_embedding_models_async)"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import warnings
|
|
||||||
|
|
||||||
logger.warning("list_embedding_models is deprecated, use list_embedding_models_async instead", stacklevel=2)
|
logger.warning("list_embedding_models is deprecated, use list_embedding_models_async instead", stacklevel=2)
|
||||||
|
|
||||||
|
|||||||
@@ -35,8 +35,6 @@ class TogetherProvider(OpenAIProvider):
|
|||||||
return self._list_llm_models(models)
|
return self._list_llm_models(models)
|
||||||
|
|
||||||
async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
|
async def list_embedding_models_async(self) -> list[EmbeddingConfig]:
|
||||||
import warnings
|
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Letta does not currently support listing embedding models for Together. Please "
|
"Letta does not currently support listing embedding models for Together. Please "
|
||||||
"contact support or reach out via GitHub or Discord to get support."
|
"contact support or reach out via GitHub or Discord to get support."
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from letta.helpers.tpuf_client import should_use_tpuf
|
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.schemas.enums import PrimitiveType, VectorDBProvider
|
from letta.schemas.enums import PrimitiveType, VectorDBProvider
|
||||||
from letta.schemas.letta_base import LettaBase
|
from letta.schemas.letta_base import LettaBase
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from letta.functions.functions import get_json_schema_from_module
|
|||||||
from letta.functions.mcp_client.types import MCPTool
|
from letta.functions.mcp_client.types import MCPTool
|
||||||
from letta.functions.schema_generator import generate_tool_schema_for_mcp
|
from letta.functions.schema_generator import generate_tool_schema_for_mcp
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.schemas.enums import ToolSourceType, ToolType
|
from letta.schemas.enums import ToolType
|
||||||
from letta.schemas.letta_base import LettaBase
|
from letta.schemas.letta_base import LettaBase
|
||||||
from letta.schemas.npm_requirement import NpmRequirement
|
from letta.schemas.npm_requirement import NpmRequirement
|
||||||
from letta.schemas.pip_requirement import PipRequirement
|
from letta.schemas.pip_requirement import PipRequirement
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Uni
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from letta.schemas.message import Message
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.schemas.openai.chat_completion_response import (
|
from letta.schemas.openai.chat_completion_response import (
|
||||||
UsageStatisticsCompletionTokenDetails,
|
UsageStatisticsCompletionTokenDetails,
|
||||||
@@ -133,7 +131,7 @@ class LettaUsageStatistics(BaseModel):
|
|||||||
description="Estimate of tokens currently in the context window.",
|
description="Estimate of tokens currently in the context window.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_usage(self, provider_type: Optional["ProviderType"] = None) -> "UsageStatistics":
|
def to_usage(self, provider_type: Optional["ProviderType"] = None) -> "UsageStatistics": # noqa: F821 # noqa: F821
|
||||||
"""Convert to UsageStatistics (OpenAI-compatible format).
|
"""Convert to UsageStatistics (OpenAI-compatible format).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from letta.serialize_schemas.marshmallow_agent import MarshmallowAgentSchema
|
from letta.serialize_schemas.marshmallow_agent import MarshmallowAgentSchema as MarshmallowAgentSchema
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -27,7 +26,7 @@ from starlette.middleware.cors import CORSMiddleware
|
|||||||
|
|
||||||
from letta.__init__ import __version__ as letta_version
|
from letta.__init__ import __version__ as letta_version
|
||||||
from letta.agents.exceptions import IncompatibleAgentType
|
from letta.agents.exceptions import IncompatibleAgentType
|
||||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
from letta.constants import ADMIN_PREFIX, API_PREFIX
|
||||||
from letta.errors import (
|
from letta.errors import (
|
||||||
AgentExportIdMappingError,
|
AgentExportIdMappingError,
|
||||||
AgentExportProcessingError,
|
AgentExportProcessingError,
|
||||||
@@ -108,7 +107,6 @@ class SafeORJSONResponse(ORJSONResponse):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from letta.server.db import db_registry
|
|
||||||
from letta.server.global_exception_handler import setup_global_exception_handlers
|
from letta.server.global_exception_handler import setup_global_exception_handlers
|
||||||
|
|
||||||
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ Shared helper functions for Anthropic-compatible proxy endpoints.
|
|||||||
These helpers are used by both the Anthropic and Z.ai proxy routers to reduce code duplication.
|
These helpers are used by both the Anthropic and Z.ai proxy routers to reduce code duplication.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|||||||
@@ -1,29 +1,21 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import traceback
|
from datetime import datetime
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException, Query, Request, UploadFile, status
|
from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException, Query, Request, UploadFile, status
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from marshmallow import ValidationError
|
|
||||||
from orjson import orjson
|
from orjson import orjson
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
|
||||||
from starlette.responses import Response, StreamingResponse
|
from starlette.responses import Response, StreamingResponse
|
||||||
|
|
||||||
from letta.agents.agent_loop import AgentLoop
|
from letta.agents.agent_loop import AgentLoop
|
||||||
from letta.agents.base_agent_v2 import BaseAgentV2
|
from letta.agents.base_agent_v2 import BaseAgentV2
|
||||||
from letta.agents.letta_agent import LettaAgent
|
from letta.agents.letta_agent import LettaAgent
|
||||||
from letta.agents.letta_agent_v2 import LettaAgentV2
|
|
||||||
from letta.agents.letta_agent_v3 import LettaAgentV3
|
from letta.agents.letta_agent_v3 import LettaAgentV3
|
||||||
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX
|
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX
|
||||||
from letta.data_sources.redis_client import get_redis_client
|
from letta.data_sources.redis_client import get_redis_client
|
||||||
from letta.errors import (
|
from letta.errors import (
|
||||||
AgentExportIdMappingError,
|
|
||||||
AgentExportProcessingError,
|
|
||||||
AgentFileImportError,
|
|
||||||
AgentNotFoundForExportError,
|
|
||||||
HandleNotFoundError,
|
HandleNotFoundError,
|
||||||
LLMError,
|
LLMError,
|
||||||
NoActiveRunsToCancelError,
|
NoActiveRunsToCancelError,
|
||||||
@@ -31,16 +23,15 @@ from letta.errors import (
|
|||||||
)
|
)
|
||||||
from letta.groups.sleeptime_multi_agent_v4 import SleeptimeMultiAgentV4
|
from letta.groups.sleeptime_multi_agent_v4 import SleeptimeMultiAgentV4
|
||||||
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
|
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
|
||||||
from letta.llm_api.llm_client import LLMClient
|
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.otel.context import get_ctx_attributes
|
from letta.otel.context import get_ctx_attributes
|
||||||
from letta.otel.metric_registry import MetricRegistry
|
from letta.otel.metric_registry import MetricRegistry
|
||||||
from letta.schemas.agent import AgentRelationships, AgentState, CreateAgent, UpdateAgent
|
from letta.schemas.agent import AgentRelationships, AgentState, CreateAgent, UpdateAgent
|
||||||
from letta.schemas.agent_file import AgentFileSchema, SkillSchema
|
from letta.schemas.agent_file import AgentFileSchema, SkillSchema
|
||||||
from letta.schemas.block import BaseBlock, Block, BlockResponse, BlockUpdate
|
from letta.schemas.block import BlockResponse, BlockUpdate
|
||||||
from letta.schemas.enums import AgentType, MessageRole, RunStatus
|
from letta.schemas.enums import AgentType, MessageRole, RunStatus
|
||||||
from letta.schemas.file import AgentFileAttachment, FileMetadataBase, PaginatedAgentFiles
|
from letta.schemas.file import AgentFileAttachment, PaginatedAgentFiles
|
||||||
from letta.schemas.group import Group
|
from letta.schemas.group import Group
|
||||||
from letta.schemas.job import LettaRequestConfig
|
from letta.schemas.job import LettaRequestConfig
|
||||||
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType
|
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType
|
||||||
@@ -59,8 +50,8 @@ from letta.schemas.memory import (
|
|||||||
from letta.schemas.message import Message, MessageCreate, MessageCreateType, MessageSearchRequest, MessageSearchResult
|
from letta.schemas.message import Message, MessageCreate, MessageCreateType, MessageSearchRequest, MessageSearchResult
|
||||||
from letta.schemas.passage import Passage
|
from letta.schemas.passage import Passage
|
||||||
from letta.schemas.run import Run as PydanticRun, RunUpdate
|
from letta.schemas.run import Run as PydanticRun, RunUpdate
|
||||||
from letta.schemas.source import BaseSource, Source
|
from letta.schemas.source import Source
|
||||||
from letta.schemas.tool import BaseTool, Tool
|
from letta.schemas.tool import Tool
|
||||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
|||||||
@@ -1,19 +1,17 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Dict, List, Literal, Optional
|
from typing import Dict, List, Literal, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, Query
|
from fastapi import APIRouter, Body, Depends, Query
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from letta import AgentState
|
from letta import AgentState
|
||||||
from letta.errors import LettaInvalidArgumentError
|
|
||||||
from letta.schemas.agent import AgentRelationships
|
from letta.schemas.agent import AgentRelationships
|
||||||
from letta.schemas.archive import Archive as PydanticArchive, ArchiveBase
|
from letta.schemas.archive import Archive as PydanticArchive
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.schemas.passage import Passage
|
from letta.schemas.passage import Passage
|
||||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||||
from letta.server.server import SyncServer
|
from letta.server.server import SyncServer
|
||||||
from letta.settings import settings
|
from letta.settings import settings
|
||||||
from letta.validators import AgentId, ArchiveId, PassageId
|
from letta.validators import ArchiveId, PassageId
|
||||||
|
|
||||||
router = APIRouter(prefix="/archives", tags=["archives"])
|
router = APIRouter(prefix="/archives", tags=["archives"])
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user