chore: add ty + pre-commit hook and repeal even more ruff rules (#9504)
* auto fixes * auto fix pt2 and transitive deps and undefined var checking locals() * manual fixes (ignored or letta-code fixed) * fix circular import * remove all ignores, add FastAPI rules and Ruff rules * add ty and precommit * ruff stuff * ty check fixes * ty check fixes pt 2 * error on invalid
This commit is contained in:
@@ -31,7 +31,7 @@ def get_support_status(passed_tests, feature_tests):
|
||||
|
||||
# Filter out error tests when checking for support
|
||||
non_error_tests = [test for test in feature_tests if not test.endswith("_error")]
|
||||
error_tests = [test for test in feature_tests if test.endswith("_error")]
|
||||
[test for test in feature_tests if test.endswith("_error")]
|
||||
|
||||
# Check which non-error tests passed
|
||||
passed_non_error_tests = [test for test in non_error_tests if test in passed_tests]
|
||||
@@ -137,7 +137,7 @@ def get_github_repo_info():
|
||||
else:
|
||||
return None
|
||||
return repo_path
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Default fallback
|
||||
@@ -335,7 +335,7 @@ def process_model_sweep_report(input_file, output_file, config_file=None, debug=
|
||||
# Format timestamp if it's a full ISO string
|
||||
if "T" in str(last_scanned):
|
||||
last_scanned = str(last_scanned).split("T")[0] # Just the date part
|
||||
except:
|
||||
except Exception:
|
||||
last_scanned = "Unknown"
|
||||
|
||||
# Calculate support score for ranking
|
||||
|
||||
2
.github/scripts/model-sweep/model_sweep.py
vendored
2
.github/scripts/model-sweep/model_sweep.py
vendored
@@ -690,7 +690,7 @@ def test_token_streaming_agent_loop_error(
|
||||
stream_tokens=True,
|
||||
)
|
||||
list(response)
|
||||
except:
|
||||
except Exception:
|
||||
pass # only some models throw an error TODO: make this consistent
|
||||
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
|
||||
@@ -23,3 +23,10 @@ repos:
|
||||
- id: ruff-check
|
||||
args: [ --fix ]
|
||||
- id: ruff-format
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: ty
|
||||
name: ty check
|
||||
entry: uv run ty check .
|
||||
language: python
|
||||
|
||||
@@ -143,7 +143,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
# Extract tool call from the interface
|
||||
try:
|
||||
self.tool_call = self.interface.get_tool_call_object()
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
# No tool call, handle upstream
|
||||
self.tool_call = None
|
||||
|
||||
|
||||
@@ -292,7 +292,7 @@ class SGLangNativeAdapter(SimpleLLMRequestAdapter):
|
||||
if isinstance(tc_args, str):
|
||||
try:
|
||||
tc_args = json.loads(tc_args)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
tc_parts.append(f'<tool_call>\n{{"name": "{tc_name}", "arguments": {json.dumps(tc_args)}}}\n</tool_call>')
|
||||
|
||||
@@ -168,7 +168,7 @@ class BaseAgent(ABC):
|
||||
actor=self.actor,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
return [new_system_message] + in_context_messages[1:]
|
||||
return [new_system_message, *in_context_messages[1:]]
|
||||
|
||||
else:
|
||||
return in_context_messages
|
||||
|
||||
@@ -79,7 +79,7 @@ class EphemeralSummaryAgent(BaseAgent):
|
||||
content=[TextContent(text=get_system_text("summary_system_prompt"))],
|
||||
)
|
||||
messages = await convert_message_creates_to_messages(
|
||||
message_creates=[system_message_create] + input_messages,
|
||||
message_creates=[system_message_create, *input_messages],
|
||||
agent_id=self.agent_id,
|
||||
timezone=agent_state.timezone,
|
||||
run_id=None, # TODO: add this
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import json
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
from letta.errors import LettaError, PendingApprovalError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
@@ -462,7 +465,7 @@ def _schema_accepts_value(prop_schema: Dict[str, Any], value: Any) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def merge_and_validate_prefilled_args(tool: "Tool", llm_args: Dict[str, Any], prefilled_args: Dict[str, Any]) -> Dict[str, Any]: # noqa: F821
|
||||
def merge_and_validate_prefilled_args(tool: "Tool", llm_args: Dict[str, Any], prefilled_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Merge LLM-provided args with prefilled args from tool rules.
|
||||
|
||||
- Overlapping keys are replaced by prefilled values (prefilled wins).
|
||||
|
||||
@@ -1574,7 +1574,7 @@ class LettaAgent(BaseAgent):
|
||||
self.logger.warning(
|
||||
f"Total tokens {total_tokens} exceeds configured max tokens {llm_config.context_window}, forcefully clearing message history."
|
||||
)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=new_letta_messages,
|
||||
force=True,
|
||||
@@ -1587,7 +1587,7 @@ class LettaAgent(BaseAgent):
|
||||
self.logger.info(
|
||||
f"Total tokens {total_tokens} does not exceed configured max tokens {llm_config.context_window}, passing summarizing w/o force."
|
||||
)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=new_letta_messages,
|
||||
run_id=run_id,
|
||||
@@ -1607,7 +1607,7 @@ class LettaAgent(BaseAgent):
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
|
||||
message_ids = agent_state.message_ids
|
||||
in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=self.actor)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=in_context_messages, new_letta_messages=[], force=True
|
||||
)
|
||||
return await self.agent_manager.update_message_ids_async(
|
||||
|
||||
@@ -217,7 +217,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
|
||||
if batch_items:
|
||||
log_event(name="bulk_create_batch_items")
|
||||
batch_items_persisted = await self.batch_manager.create_llm_batch_items_bulk_async(batch_items, actor=self.actor)
|
||||
await self.batch_manager.create_llm_batch_items_bulk_async(batch_items, actor=self.actor)
|
||||
|
||||
log_event(name="return_batch_response")
|
||||
return LettaBatchResponse(
|
||||
|
||||
@@ -456,7 +456,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
step_progression = StepProgression.START
|
||||
caught_exception = None
|
||||
# TODO(@caren): clean this up
|
||||
tool_call, reasoning_content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = (
|
||||
tool_call, reasoning_content, agent_step_span, first_chunk, step_id, logged_step, _step_start_ns, step_metrics = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
@@ -752,7 +752,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
num_archival_memories=None,
|
||||
force=True,
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
# Always scrub inner thoughts regardless of system prompt refresh
|
||||
@@ -835,7 +835,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
new_system_message = await self.message_manager.update_message_by_id_async(
|
||||
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
|
||||
)
|
||||
return [new_system_message] + in_context_messages[1:]
|
||||
return [new_system_message, *in_context_messages[1:]]
|
||||
|
||||
else:
|
||||
return in_context_messages
|
||||
@@ -1322,7 +1322,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
self.logger.warning(
|
||||
f"Total tokens {total_tokens} exceeds configured max tokens {self.agent_state.llm_config.context_window}, forcefully clearing message history."
|
||||
)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=new_letta_messages,
|
||||
force=True,
|
||||
@@ -1335,7 +1335,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
self.logger.info(
|
||||
f"Total tokens {total_tokens} does not exceed configured max tokens {self.agent_state.llm_config.context_window}, passing summarizing w/o force."
|
||||
)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=new_letta_messages,
|
||||
run_id=run_id,
|
||||
|
||||
@@ -644,7 +644,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
message.conversation_id = self.conversation_id
|
||||
|
||||
# persist the new message objects - ONLY place where messages are persisted
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
await self.message_manager.create_many_messages_async(
|
||||
new_messages,
|
||||
actor=self.actor,
|
||||
run_id=run_id,
|
||||
@@ -799,7 +799,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
step_progression = StepProgression.START
|
||||
caught_exception = None
|
||||
# TODO(@caren): clean this up
|
||||
tool_calls, content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = (
|
||||
tool_calls, content, agent_step_span, _first_chunk, step_id, logged_step, _step_start_ns, step_metrics = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
@@ -971,7 +971,6 @@ class LettaAgentV3(LettaAgentV2):
|
||||
async for chunk in invocation:
|
||||
if llm_adapter.supports_token_streaming():
|
||||
if include_return_message_types is None or chunk.message_type in include_return_message_types:
|
||||
first_chunk = True
|
||||
yield chunk
|
||||
# If you've reached this point without an error, break out of retry loop
|
||||
break
|
||||
@@ -1659,10 +1658,10 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# Decide continuation for this tool
|
||||
if has_prefill_error:
|
||||
cont = False
|
||||
hb_reason = None
|
||||
_hb_reason = None
|
||||
sr = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
|
||||
else:
|
||||
cont, hb_reason, sr = self._decide_continuation(
|
||||
cont, _hb_reason, sr = self._decide_continuation(
|
||||
agent_state=self.agent_state,
|
||||
tool_call_name=spec["name"],
|
||||
tool_rule_violated=spec["violated"],
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import openai
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.exceptions import IncompatibleAgentType
|
||||
from letta.agents.voice_sleeptime_agent import VoiceSleeptimeAgent
|
||||
@@ -250,7 +253,6 @@ class VoiceAgent(BaseAgent):
|
||||
agent_state=agent_state,
|
||||
)
|
||||
tool_result = tool_execution_result.func_return
|
||||
success_flag = tool_execution_result.success_flag
|
||||
|
||||
# 3. Provide function_call response back into the conversation
|
||||
# TODO: fix this tool format
|
||||
@@ -292,7 +294,7 @@ class VoiceAgent(BaseAgent):
|
||||
new_letta_messages = await self.message_manager.create_many_messages_async(letta_message_db_queue, actor=self.actor)
|
||||
|
||||
# TODO: Make this more general and configurable, less brittle
|
||||
new_in_context_messages, updated = await summarizer.summarize(
|
||||
new_in_context_messages, _updated = await summarizer.summarize(
|
||||
in_context_messages=in_context_messages, new_letta_messages=new_letta_messages
|
||||
)
|
||||
|
||||
@@ -414,7 +416,7 @@ class VoiceAgent(BaseAgent):
|
||||
for t in tools
|
||||
]
|
||||
|
||||
async def _execute_tool(self, user_query: str, tool_name: str, tool_args: dict, agent_state: AgentState) -> "ToolExecutionResult": # noqa: F821
|
||||
async def _execute_tool(self, user_query: str, tool_name: str, tool_args: dict, agent_state: AgentState) -> "ToolExecutionResult":
|
||||
"""
|
||||
Executes a tool and returns the ToolExecutionResult.
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
|
||||
from letta.agents.helpers import _create_letta_response, serialize_message_history
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
@@ -89,7 +94,7 @@ class VoiceSleeptimeAgent(LettaAgent):
|
||||
current_in_context_messages, new_in_context_messages, stop_reason, usage = await super()._step(
|
||||
agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
|
||||
)
|
||||
new_in_context_messages, updated = await self.summarizer.summarize(
|
||||
new_in_context_messages, _updated = await self.summarizer.summarize(
|
||||
in_context_messages=current_in_context_messages, new_letta_messages=new_in_context_messages
|
||||
)
|
||||
self.agent_manager.set_in_context_messages(
|
||||
@@ -110,9 +115,9 @@ class VoiceSleeptimeAgent(LettaAgent):
|
||||
tool_name: str,
|
||||
tool_args: JsonDict,
|
||||
agent_state: AgentState,
|
||||
agent_step_span: Optional["Span"] = None, # noqa: F821
|
||||
agent_step_span: Optional["Span"] = None,
|
||||
step_id: str | None = None,
|
||||
) -> "ToolExecutionResult": # noqa: F821
|
||||
) -> "ToolExecutionResult":
|
||||
"""
|
||||
Executes a tool and returns the ToolExecutionResult
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import Dict, Iterator, List, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Iterator, List, Tuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.user import User
|
||||
|
||||
import typer
|
||||
|
||||
@@ -37,7 +40,7 @@ class DataConnector:
|
||||
"""
|
||||
|
||||
|
||||
async def load_data(connector: DataConnector, source: Source, passage_manager: PassageManager, file_manager: FileManager, actor: "User"): # noqa: F821
|
||||
async def load_data(connector: DataConnector, source: Source, passage_manager: PassageManager, file_manager: FileManager, actor: "User"):
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
"""Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id."""
|
||||
@@ -143,7 +146,13 @@ async def load_data(connector: DataConnector, source: Source, passage_manager: P
|
||||
|
||||
|
||||
class DirectoryConnector(DataConnector):
|
||||
def __init__(self, input_files: List[str] = None, input_directory: str = None, recursive: bool = False, extensions: List[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
input_files: List[str] | None = None,
|
||||
input_directory: str | None = None,
|
||||
recursive: bool = False,
|
||||
extensions: List[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Connector for reading text data from a directory of files.
|
||||
|
||||
|
||||
@@ -149,7 +149,7 @@ class AsyncRedisClient:
|
||||
try:
|
||||
client = await self.get_client()
|
||||
return await client.get(key)
|
||||
except:
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
@with_retry()
|
||||
@@ -320,7 +320,7 @@ class AsyncRedisClient:
|
||||
client = await self.get_client()
|
||||
result = await client.smismember(key, values)
|
||||
return result if isinstance(values, list) else result[0]
|
||||
except:
|
||||
except Exception:
|
||||
return [0] * len(values) if isinstance(values, list) else 0
|
||||
|
||||
async def srem(self, key: str, *members: Union[str, int, float]) -> int:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
# Avoid circular imports
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
@@ -362,16 +363,16 @@ class RateLimitExceededError(LettaError):
|
||||
class LettaMessageError(LettaError):
|
||||
"""Base error class for handling message-related errors."""
|
||||
|
||||
messages: List[Union["Message", "LettaMessage"]] # noqa: F821
|
||||
messages: List[Union["Message", "LettaMessage"]]
|
||||
default_error_message: str = "An error occurred with the message."
|
||||
|
||||
def __init__(self, *, messages: List[Union["Message", "LettaMessage"]], explanation: Optional[str] = None) -> None: # noqa: F821
|
||||
def __init__(self, *, messages: List[Union["Message", "LettaMessage"]], explanation: Optional[str] = None) -> None:
|
||||
error_msg = self.construct_error_message(messages, self.default_error_message, explanation)
|
||||
super().__init__(error_msg)
|
||||
self.messages = messages
|
||||
|
||||
@staticmethod
|
||||
def construct_error_message(messages: List[Union["Message", "LettaMessage"]], error_msg: str, explanation: Optional[str] = None) -> str: # noqa: F821
|
||||
def construct_error_message(messages: List[Union["Message", "LettaMessage"]], error_msg: str, explanation: Optional[str] = None) -> str:
|
||||
"""Helper method to construct a clean and formatted error message."""
|
||||
if explanation:
|
||||
error_msg += f" (Explanation: {explanation})"
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.agents.letta_agent import LettaAgent as Agent
|
||||
from letta.schemas.agent import AgentState
|
||||
|
||||
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
|
||||
|
||||
|
||||
def memory(
|
||||
agent_state: "AgentState",
|
||||
@@ -67,7 +68,7 @@ def memory(
|
||||
raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.")
|
||||
|
||||
|
||||
def send_message(self: "Agent", message: str) -> Optional[str]: # noqa: F821
|
||||
def send_message(self: "Agent", message: str) -> Optional[str]:
|
||||
"""
|
||||
Sends a message to the human user.
|
||||
|
||||
@@ -84,7 +85,7 @@ def send_message(self: "Agent", message: str) -> Optional[str]: # noqa: F821
|
||||
|
||||
|
||||
def conversation_search(
|
||||
self: "Agent", # noqa: F821
|
||||
self: "Agent",
|
||||
query: Optional[str] = None,
|
||||
roles: Optional[List[Literal["assistant", "user", "tool"]]] = None,
|
||||
limit: Optional[int] = None,
|
||||
@@ -160,7 +161,7 @@ def conversation_search(
|
||||
return results_str
|
||||
|
||||
|
||||
async def archival_memory_insert(self: "Agent", content: str, tags: Optional[list[str]] = None) -> Optional[str]: # noqa: F821
|
||||
async def archival_memory_insert(self: "Agent", content: str, tags: Optional[list[str]] = None) -> Optional[str]:
|
||||
"""
|
||||
Add information to long-term archival memory for later retrieval.
|
||||
|
||||
@@ -191,7 +192,7 @@ async def archival_memory_insert(self: "Agent", content: str, tags: Optional[lis
|
||||
|
||||
|
||||
async def archival_memory_search(
|
||||
self: "Agent", # noqa: F821
|
||||
self: "Agent",
|
||||
query: str,
|
||||
tags: Optional[list[str]] = None,
|
||||
tag_match_mode: Literal["any", "all"] = "any",
|
||||
@@ -431,7 +432,7 @@ def memory_insert(agent_state: "AgentState", label: str, new_str: str, insert_li
|
||||
# Insert the new string as a line
|
||||
new_str_lines = new_str.split("\n")
|
||||
new_value_lines = current_value_lines[:insert_line] + new_str_lines + current_value_lines[insert_line:]
|
||||
snippet_lines = (
|
||||
(
|
||||
current_value_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
|
||||
+ new_str_lines
|
||||
+ current_value_lines[insert_line : insert_line + SNIPPET_LINES]
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.agents.letta_agent import LettaAgent as Agent
|
||||
|
||||
from letta.functions.helpers import (
|
||||
_send_message_to_agents_matching_tags_async,
|
||||
@@ -13,7 +16,7 @@ from letta.server.rest_api.dependencies import get_letta_server
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_agent_id: str) -> str: # noqa: F821
|
||||
def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_agent_id: str) -> str:
|
||||
"""
|
||||
Sends a message to a specific Letta agent within the same organization and waits for a response. The sender's identity is automatically included, so no explicit introduction is needed in the message. This function is designed for two-way communication where a reply is expected.
|
||||
|
||||
@@ -37,7 +40,7 @@ def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_
|
||||
)
|
||||
|
||||
|
||||
def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all: List[str], match_some: List[str]) -> List[str]: # noqa: F821
|
||||
def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all: List[str], match_some: List[str]) -> List[str]:
|
||||
"""
|
||||
Sends a message to all agents within the same organization that match the specified tag criteria. Agents must possess *all* of the tags in `match_all` and *at least one* of the tags in `match_some` to receive the message.
|
||||
|
||||
@@ -66,7 +69,7 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all:
|
||||
return asyncio.run(_send_message_to_agents_matching_tags_async(self, server, messages, matching_agents))
|
||||
|
||||
|
||||
def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]: # noqa: F821
|
||||
def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]:
|
||||
"""
|
||||
Sends a message to all agents within the same multi-agent group.
|
||||
|
||||
@@ -82,7 +85,7 @@ def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str
|
||||
return asyncio.run(_send_message_to_all_agents_in_group_async(self, message))
|
||||
|
||||
|
||||
def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str) -> str: # noqa: F821
|
||||
def send_message_to_agent_async(self: "Agent", message: str, other_agent_id: str) -> str:
|
||||
"""
|
||||
Sends a message to a specific Letta agent within the same organization. The sender's identity is automatically included, so no explicit introduction is required in the message. This function does not expect a response from the target agent, making it suitable for notifications or one-way communication.
|
||||
Args:
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
## Voice chat + sleeptime tools
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.agent import AgentState
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def rethink_user_memory(agent_state: "AgentState", new_memory: str) -> None: # noqa: F821
|
||||
def rethink_user_memory(agent_state: "AgentState", new_memory: str) -> None:
|
||||
"""
|
||||
Rewrite memory block for the main agent, new_memory should contain all current information from the block that is not outdated or inconsistent, integrating any new information, resulting in a new memory block that is organized, readable, and comprehensive.
|
||||
|
||||
@@ -18,7 +21,7 @@ def rethink_user_memory(agent_state: "AgentState", new_memory: str) -> None: #
|
||||
return None
|
||||
|
||||
|
||||
def finish_rethinking_memory(agent_state: "AgentState") -> None: # type: ignore # noqa: F821
|
||||
def finish_rethinking_memory(agent_state: "AgentState") -> None: # type: ignore
|
||||
"""
|
||||
This function is called when the agent is done rethinking the memory.
|
||||
|
||||
@@ -43,7 +46,7 @@ class MemoryChunk(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
def store_memories(agent_state: "AgentState", chunks: List[MemoryChunk]) -> None: # noqa: F821
|
||||
def store_memories(agent_state: "AgentState", chunks: List[MemoryChunk]) -> None:
|
||||
"""
|
||||
Persist dialogue that is about to fall out of the agent’s context window.
|
||||
|
||||
@@ -59,7 +62,7 @@ def store_memories(agent_state: "AgentState", chunks: List[MemoryChunk]) -> None
|
||||
|
||||
|
||||
def search_memory(
|
||||
agent_state: "AgentState", # noqa: F821
|
||||
agent_state: "AgentState",
|
||||
convo_keyword_queries: Optional[List[str]],
|
||||
start_minutes_ago: Optional[int],
|
||||
end_minutes_ago: Optional[int],
|
||||
|
||||
@@ -179,7 +179,7 @@ def _extract_pydantic_classes(tree: ast.AST, imports_map: Dict[str, Any]) -> Dic
|
||||
pass # Field is required, no default
|
||||
else:
|
||||
field_kwargs["default"] = default_val
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
fields[field_name] = Field(**field_kwargs)
|
||||
@@ -188,7 +188,7 @@ def _extract_pydantic_classes(tree: ast.AST, imports_map: Dict[str, Any]) -> Dic
|
||||
try:
|
||||
default_val = ast.literal_eval(stmt.value)
|
||||
fields[field_name] = default_val
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Create the dynamic Pydantic model
|
||||
|
||||
@@ -3,7 +3,17 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
from random import uniform
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.agents.letta_agent import LettaAgent as Agent
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
try:
|
||||
from langchain.tools.base import BaseTool as LangChainBaseTool
|
||||
except ImportError:
|
||||
LangChainBaseTool = None
|
||||
|
||||
import humps
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
@@ -21,6 +31,8 @@ from letta.server.rest_api.dependencies import get_letta_server
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
# TODO needed?
|
||||
def generate_mcp_tool_wrapper(mcp_tool_name: str) -> tuple[str, str]:
|
||||
@@ -36,8 +48,8 @@ def {mcp_tool_name}(**kwargs):
|
||||
|
||||
|
||||
def generate_langchain_tool_wrapper(
|
||||
tool: "LangChainBaseTool", # noqa: F821
|
||||
additional_imports_module_attr_map: dict[str, str] = None,
|
||||
tool: "LangChainBaseTool",
|
||||
additional_imports_module_attr_map: dict[str, str] | None = None,
|
||||
) -> tuple[str, str]:
|
||||
tool_name = tool.__class__.__name__
|
||||
import_statement = f"from langchain_community.tools import {tool_name}"
|
||||
@@ -73,7 +85,7 @@ def _assert_code_gen_compilable(code_str):
|
||||
print(f"Syntax error in code: {e}")
|
||||
|
||||
|
||||
def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None: # noqa: F821
|
||||
def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
|
||||
# Safety check that user has passed in all required imports:
|
||||
tool_name = tool.__class__.__name__
|
||||
current_class_imports = {tool_name}
|
||||
@@ -87,7 +99,7 @@ def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additiona
|
||||
raise RuntimeError(err_msg)
|
||||
|
||||
|
||||
def _find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]: # noqa: F821
|
||||
def _find_required_class_names_for_import(obj: Union["LangChainBaseTool", BaseModel]) -> list[str]:
|
||||
"""
|
||||
Finds all the class names for required imports when instantiating the `obj`.
|
||||
NOTE: This does not return the full import path, only the class name.
|
||||
@@ -225,7 +237,7 @@ def _parse_letta_response_for_assistant_message(
|
||||
|
||||
|
||||
async def async_execute_send_message_to_agent(
|
||||
sender_agent: "Agent", # noqa: F821
|
||||
sender_agent: "Agent",
|
||||
messages: List[MessageCreate],
|
||||
other_agent_id: str,
|
||||
log_prefix: str,
|
||||
@@ -256,7 +268,7 @@ async def async_execute_send_message_to_agent(
|
||||
|
||||
|
||||
def execute_send_message_to_agent(
|
||||
sender_agent: "Agent", # noqa: F821
|
||||
sender_agent: "Agent",
|
||||
messages: List[MessageCreate],
|
||||
other_agent_id: str,
|
||||
log_prefix: str,
|
||||
@@ -269,7 +281,7 @@ def execute_send_message_to_agent(
|
||||
|
||||
|
||||
async def _send_message_to_agent_no_stream(
|
||||
server: "SyncServer", # noqa: F821
|
||||
server: "SyncServer",
|
||||
agent_id: str,
|
||||
actor: User,
|
||||
messages: List[MessageCreate],
|
||||
@@ -302,8 +314,8 @@ async def _send_message_to_agent_no_stream(
|
||||
|
||||
|
||||
async def _async_send_message_with_retries(
|
||||
server: "SyncServer", # noqa: F821
|
||||
sender_agent: "Agent", # noqa: F821
|
||||
server: "SyncServer",
|
||||
sender_agent: "Agent",
|
||||
target_agent_id: str,
|
||||
messages: List[MessageCreate],
|
||||
max_retries: int,
|
||||
@@ -353,7 +365,7 @@ async def _async_send_message_with_retries(
|
||||
|
||||
|
||||
def fire_and_forget_send_to_agent(
|
||||
sender_agent: "Agent", # noqa: F821
|
||||
sender_agent: "Agent",
|
||||
messages: List[MessageCreate],
|
||||
other_agent_id: str,
|
||||
log_prefix: str,
|
||||
@@ -429,18 +441,18 @@ def fire_and_forget_send_to_agent(
|
||||
# 4) Try to schedule the coroutine in an existing loop, else spawn a thread
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we get here, a loop is running; schedule the coroutine in background
|
||||
loop.create_task(background_task())
|
||||
task = loop.create_task(background_task())
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
except RuntimeError:
|
||||
# Means no event loop is running in this thread
|
||||
run_in_background_thread(background_task())
|
||||
|
||||
|
||||
async def _send_message_to_agents_matching_tags_async(
|
||||
sender_agent: "Agent", # noqa: F821
|
||||
server: "SyncServer", # noqa: F821
|
||||
sender_agent: "Agent",
|
||||
server: "SyncServer",
|
||||
messages: List[MessageCreate],
|
||||
matching_agents: List["AgentState"], # noqa: F821
|
||||
matching_agents: List["AgentState"],
|
||||
) -> List[str]:
|
||||
async def _send_single(agent_state):
|
||||
return await _async_send_message_with_retries(
|
||||
@@ -464,7 +476,7 @@ async def _send_message_to_agents_matching_tags_async(
|
||||
return final
|
||||
|
||||
|
||||
async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", message: str) -> List[str]: # noqa: F821
|
||||
async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", message: str) -> List[str]:
|
||||
server = get_letta_server()
|
||||
|
||||
augmented_message = (
|
||||
@@ -522,7 +534,9 @@ def generate_model_from_args_json_schema(schema: Dict[str, Any]) -> Type[BaseMod
|
||||
return _create_model_from_schema(schema.get("title", "DynamicModel"), schema, nested_models)
|
||||
|
||||
|
||||
def _create_model_from_schema(name: str, model_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] = None) -> Type[BaseModel]:
|
||||
def _create_model_from_schema(
|
||||
name: str, model_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] | None = None
|
||||
) -> Type[BaseModel]:
|
||||
fields = {}
|
||||
for field_name, field_schema in model_schema["properties"].items():
|
||||
field_type = _get_field_type(field_schema, nested_models)
|
||||
@@ -533,7 +547,7 @@ def _create_model_from_schema(name: str, model_schema: Dict[str, Any], nested_mo
|
||||
return create_model(name, **fields)
|
||||
|
||||
|
||||
def _get_field_type(field_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] = None) -> Any:
|
||||
def _get_field_type(field_schema: Dict[str, Any], nested_models: Dict[str, Type[BaseModel]] | None = None) -> Any:
|
||||
"""Helper to convert JSON schema types to Python types."""
|
||||
if field_schema.get("type") == "string":
|
||||
return str
|
||||
|
||||
@@ -96,7 +96,7 @@ def type_to_json_schema_type(py_type) -> dict:
|
||||
|
||||
# Handle array types
|
||||
origin = get_origin(py_type)
|
||||
if py_type == list or origin in (list, List):
|
||||
if py_type is list or origin in (list, List):
|
||||
args = get_args(py_type)
|
||||
if len(args) == 0:
|
||||
# is this correct
|
||||
@@ -142,7 +142,7 @@ def type_to_json_schema_type(py_type) -> dict:
|
||||
}
|
||||
|
||||
# Handle object types
|
||||
if py_type == dict or origin in (dict, Dict):
|
||||
if py_type is dict or origin in (dict, Dict):
|
||||
args = get_args(py_type)
|
||||
if not args:
|
||||
# Generic dict without type arguments
|
||||
|
||||
@@ -56,7 +56,7 @@ def validate_complete_json_schema(schema: Dict[str, Any]) -> Tuple[SchemaHealth,
|
||||
"""
|
||||
if obj_schema.get("type") != "object":
|
||||
return False
|
||||
props = obj_schema.get("properties", {})
|
||||
obj_schema.get("properties", {})
|
||||
required = obj_schema.get("required", [])
|
||||
additional = obj_schema.get("additionalProperties", True)
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.agents.letta_agent import LettaAgent as Agent
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
@@ -92,7 +95,7 @@ class DynamicMultiAgent(BaseAgent):
|
||||
|
||||
# Parse manager response
|
||||
responses = Message.to_letta_messages_from_list(manager_agent.last_response_messages)
|
||||
assistant_message = [response for response in responses if response.message_type == "assistant_message"][0]
|
||||
assistant_message = next(response for response in responses if response.message_type == "assistant_message")
|
||||
for name, agent_id in [(agents[agent_id].agent_state.name, agent_id) for agent_id in agent_id_options]:
|
||||
if name.lower() in assistant_message.content.lower():
|
||||
speaker_id = agent_id
|
||||
@@ -177,7 +180,7 @@ class DynamicMultiAgent(BaseAgent):
|
||||
|
||||
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
||||
|
||||
def load_manager_agent(self) -> Agent: # noqa: F821
|
||||
def load_manager_agent(self) -> Agent:
|
||||
for participant_agent_id in self.agent_ids:
|
||||
participant_agent_state = self.agent_manager.get_agent_by_id(agent_id=participant_agent_id, actor=self.user)
|
||||
participant_persona_block = participant_agent_state.memory.get_block(label="persona")
|
||||
|
||||
@@ -98,7 +98,7 @@ def stringify_message(message: Message, use_assistant_name: bool = False) -> str
|
||||
elif isinstance(content, ImageContent):
|
||||
messages.append(f"{message.name or 'user'}: [Image Here]")
|
||||
return "\n".join(messages)
|
||||
except:
|
||||
except Exception:
|
||||
if message.content and len(message.content) > 0:
|
||||
return f"{message.name or 'user'}: {message.content[0].text}"
|
||||
return None
|
||||
|
||||
@@ -212,7 +212,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor
|
||||
)
|
||||
for sleeptime_agent_id in self.group.agent_ids:
|
||||
run_id = await self._issue_background_task(
|
||||
await self._issue_background_task(
|
||||
sleeptime_agent_id,
|
||||
last_response_messages,
|
||||
last_processed_message_id,
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.services.summarizer.summarizer_config import CompactionSettings
|
||||
|
||||
import numpy as np
|
||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
||||
@@ -113,7 +116,7 @@ def deserialize_embedding_config(data: Optional[Dict]) -> Optional[EmbeddingConf
|
||||
# --------------------------
|
||||
|
||||
|
||||
def serialize_compaction_settings(config: Union[Optional["CompactionSettings"], Dict]) -> Optional[Dict]: # noqa: F821
|
||||
def serialize_compaction_settings(config: Union[Optional["CompactionSettings"], Dict]) -> Optional[Dict]:
|
||||
"""Convert a CompactionSettings object into a JSON-serializable dictionary."""
|
||||
if config:
|
||||
# Import here to avoid circular dependency
|
||||
@@ -124,7 +127,7 @@ def serialize_compaction_settings(config: Union[Optional["CompactionSettings"],
|
||||
return config
|
||||
|
||||
|
||||
def deserialize_compaction_settings(data: Optional[Dict]) -> Optional["CompactionSettings"]: # noqa: F821
|
||||
def deserialize_compaction_settings(data: Optional[Dict]) -> Optional["CompactionSettings"]:
|
||||
"""Convert a dictionary back into a CompactionSettings object."""
|
||||
if data:
|
||||
# Import here to avoid circular dependency
|
||||
|
||||
@@ -306,7 +306,9 @@ async def search_pinecone_index(query: str, limit: int, filter: Dict[str, Any],
|
||||
|
||||
@pinecone_retry()
|
||||
@trace_method
|
||||
async def list_pinecone_index_for_files(file_id: str, actor: User, limit: int = None, pagination_token: str = None) -> List[str]:
|
||||
async def list_pinecone_index_for_files(
|
||||
file_id: str, actor: User, limit: int | None = None, pagination_token: str | None = None
|
||||
) -> List[str]:
|
||||
if not PINECONE_AVAILABLE:
|
||||
raise ImportError("Pinecone is not available. Please install pinecone to use this feature.")
|
||||
|
||||
|
||||
@@ -201,7 +201,7 @@ def add_pre_execution_message(tool_schema: Dict[str, Any], description: Optional
|
||||
|
||||
# Ensure pre-execution message is the first required field
|
||||
if PRE_EXECUTION_MESSAGE_ARG not in required:
|
||||
required = [PRE_EXECUTION_MESSAGE_ARG] + required
|
||||
required = [PRE_EXECUTION_MESSAGE_ARG, *required]
|
||||
|
||||
# Update the schema with ordered properties and required list
|
||||
schema["parameters"] = {
|
||||
|
||||
@@ -6,7 +6,11 @@ import logging
|
||||
import random
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, List, Optional, Tuple, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -95,7 +99,6 @@ def async_retry_with_backoff(
|
||||
async def wrapper(*args, **kwargs) -> Any:
|
||||
num_retries = 0
|
||||
delay = initial_delay
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -106,7 +109,6 @@ def async_retry_with_backoff(
|
||||
# Not a transient error, re-raise immediately
|
||||
raise
|
||||
|
||||
last_error = e
|
||||
num_retries += 1
|
||||
|
||||
# Log the retry attempt
|
||||
@@ -161,11 +163,11 @@ def _run_turbopuffer_write_in_thread(
|
||||
api_key: str,
|
||||
region: str,
|
||||
namespace_name: str,
|
||||
upsert_columns: dict = None,
|
||||
deletes: list = None,
|
||||
delete_by_filter: tuple = None,
|
||||
upsert_columns: dict | None = None,
|
||||
deletes: list | None = None,
|
||||
delete_by_filter: tuple | None = None,
|
||||
distance_metric: str = "cosine_distance",
|
||||
schema: dict = None,
|
||||
schema: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Sync wrapper to run turbopuffer write in isolated event loop.
|
||||
@@ -229,7 +231,7 @@ class TurbopufferClient:
|
||||
embedding_chunk_size=DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
)
|
||||
|
||||
def __init__(self, api_key: str = None, region: str = None):
|
||||
def __init__(self, api_key: str | None = None, region: str | None = None):
|
||||
"""Initialize Turbopuffer client."""
|
||||
self.api_key = api_key or settings.tpuf_api_key
|
||||
self.region = region or settings.tpuf_region
|
||||
@@ -244,7 +246,7 @@ class TurbopufferClient:
|
||||
raise ValueError("Turbopuffer API key not provided")
|
||||
|
||||
@trace_method
|
||||
async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]: # noqa: F821
|
||||
async def _generate_embeddings(self, texts: List[str], actor: "PydanticUser") -> List[List[float]]:
|
||||
"""Generate embeddings using the default embedding configuration.
|
||||
|
||||
Args:
|
||||
@@ -311,7 +313,7 @@ class TurbopufferClient:
|
||||
|
||||
return namespace_name
|
||||
|
||||
def _extract_tool_text(self, tool: "PydanticTool") -> str: # noqa: F821
|
||||
def _extract_tool_text(self, tool: "PydanticTool") -> str:
|
||||
"""Extract searchable text from a tool for embedding.
|
||||
|
||||
Combines name, description, and JSON schema into a structured format
|
||||
@@ -361,9 +363,9 @@ class TurbopufferClient:
|
||||
@async_retry_with_backoff()
|
||||
async def insert_tools(
|
||||
self,
|
||||
tools: List["PydanticTool"], # noqa: F821
|
||||
tools: List["PydanticTool"],
|
||||
organization_id: str,
|
||||
actor: "PydanticUser", # noqa: F821
|
||||
actor: "PydanticUser",
|
||||
) -> bool:
|
||||
"""Insert tools into Turbopuffer.
|
||||
|
||||
@@ -456,7 +458,7 @@ class TurbopufferClient:
|
||||
text_chunks: List[str],
|
||||
passage_ids: List[str],
|
||||
organization_id: str,
|
||||
actor: "PydanticUser", # noqa: F821
|
||||
actor: "PydanticUser",
|
||||
tags: Optional[List[str]] = None,
|
||||
created_at: Optional[datetime] = None,
|
||||
embeddings: Optional[List[List[float]]] = None,
|
||||
@@ -607,7 +609,7 @@ class TurbopufferClient:
|
||||
message_texts: List[str],
|
||||
message_ids: List[str],
|
||||
organization_id: str,
|
||||
actor: "PydanticUser", # noqa: F821
|
||||
actor: "PydanticUser",
|
||||
roles: List[MessageRole],
|
||||
created_ats: List[datetime],
|
||||
project_id: Optional[str] = None,
|
||||
@@ -867,7 +869,7 @@ class TurbopufferClient:
|
||||
async def query_passages(
|
||||
self,
|
||||
archive_id: str,
|
||||
actor: "PydanticUser", # noqa: F821
|
||||
actor: "PydanticUser",
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
||||
top_k: int = 10,
|
||||
@@ -1012,7 +1014,7 @@ class TurbopufferClient:
|
||||
self,
|
||||
agent_id: str,
|
||||
organization_id: str,
|
||||
actor: "PydanticUser", # noqa: F821
|
||||
actor: "PydanticUser",
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp"
|
||||
top_k: int = 10,
|
||||
@@ -1188,7 +1190,7 @@ class TurbopufferClient:
|
||||
async def query_messages_by_org_id(
|
||||
self,
|
||||
organization_id: str,
|
||||
actor: "PydanticUser", # noqa: F821
|
||||
actor: "PydanticUser",
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "hybrid", # "vector", "fts", "hybrid"
|
||||
top_k: int = 10,
|
||||
@@ -1654,7 +1656,7 @@ class TurbopufferClient:
|
||||
file_id: str,
|
||||
text_chunks: List[str],
|
||||
organization_id: str,
|
||||
actor: "PydanticUser", # noqa: F821
|
||||
actor: "PydanticUser",
|
||||
created_at: Optional[datetime] = None,
|
||||
) -> List[PydanticPassage]:
|
||||
"""Insert file passages into Turbopuffer using org-scoped namespace.
|
||||
@@ -1767,7 +1769,7 @@ class TurbopufferClient:
|
||||
self,
|
||||
source_ids: List[str],
|
||||
organization_id: str,
|
||||
actor: "PydanticUser", # noqa: F821
|
||||
actor: "PydanticUser",
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "vector", # "vector", "fts", "hybrid"
|
||||
top_k: int = 10,
|
||||
@@ -1991,7 +1993,7 @@ class TurbopufferClient:
|
||||
async def query_tools(
|
||||
self,
|
||||
organization_id: str,
|
||||
actor: "PydanticUser", # noqa: F821
|
||||
actor: "PydanticUser",
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "hybrid", # "vector", "fts", "hybrid", "timestamp"
|
||||
top_k: int = 50,
|
||||
|
||||
@@ -136,7 +136,7 @@ class CLIInterface(AgentInterface):
|
||||
else:
|
||||
try:
|
||||
msg_json = json_loads(msg)
|
||||
except:
|
||||
except Exception:
|
||||
printd(f"{CLI_WARNING_PREFIX}failed to parse user message into json")
|
||||
printd_user_message("🧑", msg)
|
||||
return
|
||||
|
||||
@@ -3,7 +3,12 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types.beta import (
|
||||
@@ -146,7 +151,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
return tool_calls[0]
|
||||
return None
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
||||
"""Extract usage statistics from accumulated streaming data.
|
||||
|
||||
Returns:
|
||||
@@ -232,7 +237,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncStream[BetaRawMessageStreamEvent],
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
@@ -287,7 +292,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
async def _process_event(
|
||||
self,
|
||||
event: BetaRawMessageStreamEvent,
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
prev_message_type: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
|
||||
@@ -3,7 +3,12 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types.beta import (
|
||||
@@ -116,7 +121,7 @@ class AnthropicStreamingInterface:
|
||||
# Attempt to use OptimisticJSONParser to handle incomplete/malformed JSON
|
||||
try:
|
||||
tool_input = self.json_parser.parse(args_str)
|
||||
except:
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to decode tool call arguments for tool_call_id={self.tool_call_id}, "
|
||||
f"name={self.tool_call_name}. Raw input: {args_str!r}. Error: {e}"
|
||||
@@ -128,7 +133,7 @@ class AnthropicStreamingInterface:
|
||||
arguments = str(json.dumps(tool_input, indent=2))
|
||||
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
||||
"""Extract usage statistics from accumulated streaming data.
|
||||
|
||||
Returns:
|
||||
@@ -222,7 +227,7 @@ class AnthropicStreamingInterface:
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncStream[BetaRawMessageStreamEvent],
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
@@ -276,7 +281,7 @@ class AnthropicStreamingInterface:
|
||||
async def _process_event(
|
||||
self,
|
||||
event: BetaRawMessageStreamEvent,
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
prev_message_type: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
@@ -650,7 +655,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
# Attempt to use OptimisticJSONParser to handle incomplete/malformed JSON
|
||||
try:
|
||||
tool_input = self.json_parser.parse(args_str)
|
||||
except:
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Failed to decode tool call arguments for tool_call_id={self.tool_call_id}, "
|
||||
f"name={self.tool_call_name}. Raw input: {args_str!r}. Error: {e}"
|
||||
@@ -662,7 +667,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
arguments = str(json.dumps(tool_input, indent=2))
|
||||
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=arguments, name=self.tool_call_name))
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
||||
"""Extract usage statistics from accumulated streaming data.
|
||||
|
||||
Returns:
|
||||
@@ -754,7 +759,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncStream[BetaRawMessageStreamEvent],
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
@@ -803,7 +808,7 @@ class SimpleAnthropicStreamingInterface:
|
||||
async def _process_event(
|
||||
self,
|
||||
event: BetaRawMessageStreamEvent,
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
prev_message_type: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
|
||||
@@ -3,7 +3,12 @@ import base64
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncIterator, List, Optional
|
||||
from typing import TYPE_CHECKING, AsyncIterator, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
from google.genai.types import (
|
||||
GenerateContentResponse,
|
||||
@@ -124,7 +129,7 @@ class SimpleGeminiStreamingInterface:
|
||||
"""Return all finalized tool calls collected during this message (parallel supported)."""
|
||||
return list(self.collected_tool_calls)
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
||||
"""Extract usage statistics from accumulated streaming data.
|
||||
|
||||
Returns:
|
||||
@@ -148,7 +153,7 @@ class SimpleGeminiStreamingInterface:
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncIterator[GenerateContentResponse],
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
"""
|
||||
Iterates over the Gemini stream, yielding SSE events.
|
||||
@@ -202,7 +207,7 @@ class SimpleGeminiStreamingInterface:
|
||||
async def _process_event(
|
||||
self,
|
||||
event: GenerateContentResponse,
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
prev_message_type: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
@@ -194,7 +199,7 @@ class OpenAIStreamingInterface:
|
||||
function=FunctionCall(arguments=self._get_current_function_arguments(), name=function_name),
|
||||
)
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
||||
"""Extract usage statistics from accumulated streaming data.
|
||||
|
||||
Returns:
|
||||
@@ -219,7 +224,7 @@ class OpenAIStreamingInterface:
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncStream[ChatCompletionChunk],
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
"""
|
||||
Iterates over the OpenAI stream, yielding SSE events.
|
||||
@@ -307,7 +312,7 @@ class OpenAIStreamingInterface:
|
||||
async def _process_chunk(
|
||||
self,
|
||||
chunk: ChatCompletionChunk,
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
prev_message_type: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
@@ -471,7 +476,7 @@ class OpenAIStreamingInterface:
|
||||
# Minimal, robust extraction: only emit the value of "message".
|
||||
# If we buffered a prefix while name was streaming, feed it first.
|
||||
if self._function_args_buffer_parts:
|
||||
payload = "".join(self._function_args_buffer_parts + [tool_call.function.arguments])
|
||||
payload = "".join([*self._function_args_buffer_parts, tool_call.function.arguments])
|
||||
self._function_args_buffer_parts = None
|
||||
else:
|
||||
payload = tool_call.function.arguments
|
||||
@@ -498,7 +503,7 @@ class OpenAIStreamingInterface:
|
||||
# if the previous chunk had arguments but we needed to flush name
|
||||
if self._function_args_buffer_parts:
|
||||
# In this case, we should release the buffer + new data at once
|
||||
combined_chunk = "".join(self._function_args_buffer_parts + [updates_main_json])
|
||||
combined_chunk = "".join([*self._function_args_buffer_parts, updates_main_json])
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
if self._get_function_name_buffer() in self.requires_approval_tools:
|
||||
@@ -588,7 +593,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
messages: Optional[list] = None,
|
||||
tools: Optional[list] = None,
|
||||
requires_approval_tools: list = [],
|
||||
model: str = None,
|
||||
model: str | None = None,
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
cancellation_event: Optional["asyncio.Event"] = None,
|
||||
@@ -639,7 +644,6 @@ class SimpleOpenAIStreamingInterface:
|
||||
|
||||
def get_content(self) -> list[TextContent | OmittedReasoningContent | ReasoningContent]:
|
||||
shown_omitted = False
|
||||
concat_content = ""
|
||||
merged_messages = []
|
||||
reasoning_content = []
|
||||
concat_content_parts: list[str] = []
|
||||
@@ -694,7 +698,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
raise ValueError("No tool calls available")
|
||||
return calls[0]
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
||||
"""Extract usage statistics from accumulated streaming data.
|
||||
|
||||
Returns:
|
||||
@@ -719,7 +723,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncStream[ChatCompletionChunk],
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
"""
|
||||
Iterates over the OpenAI stream, yielding SSE events.
|
||||
@@ -833,7 +837,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
async def _process_chunk(
|
||||
self,
|
||||
chunk: ChatCompletionChunk,
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
prev_message_type: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
@@ -984,7 +988,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
messages: Optional[list] = None,
|
||||
tools: Optional[list] = None,
|
||||
requires_approval_tools: list = [],
|
||||
model: str = None,
|
||||
model: str | None = None,
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
cancellation_event: Optional["asyncio.Event"] = None,
|
||||
@@ -1120,7 +1124,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
raise ValueError("No tool calls available")
|
||||
return calls[0]
|
||||
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics": # noqa: F821
|
||||
def get_usage_statistics(self) -> "LettaUsageStatistics":
|
||||
"""Extract usage statistics from accumulated streaming data.
|
||||
|
||||
Returns:
|
||||
@@ -1141,7 +1145,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncStream[ResponseStreamEvent],
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
"""
|
||||
Iterates over the OpenAI stream, yielding SSE events.
|
||||
@@ -1227,7 +1231,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
async def _process_event(
|
||||
self,
|
||||
event: ResponseStreamEvent,
|
||||
ttft_span: Optional["Span"] = None, # noqa: F821
|
||||
ttft_span: Optional["Span"] = None,
|
||||
prev_message_type: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
@@ -1250,8 +1254,6 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
if isinstance(new_event_item, ResponseReasoningItem):
|
||||
# Look for summary delta, or encrypted_content
|
||||
summary = new_event_item.summary
|
||||
content = new_event_item.content # NOTE: always none
|
||||
encrypted_content = new_event_item.encrypted_content
|
||||
# TODO change to summarize reasoning message, but we need to figure out the streaming indices of summary problem
|
||||
concat_summary = "".join([s.text for s in summary])
|
||||
if concat_summary != "":
|
||||
@@ -1390,7 +1392,6 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
# NOTE: is this inclusive of the deltas?
|
||||
# If not, we should add it to the rolling
|
||||
summary_index = event.summary_index
|
||||
text = event.text
|
||||
return
|
||||
|
||||
# Reasoning summary streaming
|
||||
@@ -1432,7 +1433,6 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
# Assistant message streaming
|
||||
elif isinstance(event, ResponseTextDoneEvent):
|
||||
# NOTE: inclusive, can skip
|
||||
text = event.text
|
||||
return
|
||||
|
||||
# Assistant message done
|
||||
@@ -1447,7 +1447,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
delta = event.delta
|
||||
|
||||
# Resolve tool_call_id/name using output_index or item_id
|
||||
resolved_call_id, resolved_name, out_idx, item_id = self._resolve_mapping_for_delta(event)
|
||||
resolved_call_id, resolved_name, _out_idx, _item_id = self._resolve_mapping_for_delta(event)
|
||||
|
||||
# Fallback to last seen tool name for approval routing if mapping name missing
|
||||
if not resolved_name:
|
||||
@@ -1493,7 +1493,6 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
# Function calls
|
||||
elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
|
||||
# NOTE: inclusive
|
||||
full_args = event.arguments
|
||||
return
|
||||
|
||||
# Generic
|
||||
|
||||
@@ -94,7 +94,7 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
|
||||
if scheduler.running:
|
||||
try:
|
||||
scheduler.shutdown(wait=False)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
finally:
|
||||
|
||||
@@ -406,7 +406,7 @@ class AnthropicClient(LLMClientBase):
|
||||
for agent_id in agent_messages_mapping
|
||||
}
|
||||
|
||||
client = await self._get_anthropic_client_async(list(agent_llm_config_mapping.values())[0], async_client=True)
|
||||
client = await self._get_anthropic_client_async(next(iter(agent_llm_config_mapping.values())), async_client=True)
|
||||
|
||||
anthropic_requests = [
|
||||
Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
|
||||
@@ -599,7 +599,7 @@ class AnthropicClient(LLMClientBase):
|
||||
# Special case for summarization path
|
||||
tools_for_request = None
|
||||
tool_choice = None
|
||||
elif self.is_reasoning_model(llm_config) and llm_config.enable_reasoner or agent_type == AgentType.letta_v1_agent:
|
||||
elif (self.is_reasoning_model(llm_config) and llm_config.enable_reasoner) or agent_type == AgentType.letta_v1_agent:
|
||||
# NOTE: reasoning models currently do not allow for `any`
|
||||
# NOTE: react agents should always have at least auto on, since the precense/absense of tool calls controls chaining
|
||||
if agent_type == AgentType.split_thread_agent and force_tool_call is not None:
|
||||
@@ -785,7 +785,9 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
return data
|
||||
|
||||
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
|
||||
async def count_tokens(
|
||||
self, messages: List[dict] | None = None, model: str | None = None, tools: List[OpenAITool] | None = None
|
||||
) -> int:
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
# Use the default client; token counting is lightweight and does not require BYOK overrides
|
||||
client = anthropic.AsyncAnthropic()
|
||||
@@ -1104,7 +1106,7 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
if isinstance(e, anthropic.APIStatusError):
|
||||
logger.warning(f"[Anthropic] API status error: {str(e)}")
|
||||
if hasattr(e, "status_code") and e.status_code == 402 or is_insufficient_credits_message(str(e)):
|
||||
if (hasattr(e, "status_code") and e.status_code == 402) or is_insufficient_credits_message(str(e)):
|
||||
msg = str(e)
|
||||
return LLMInsufficientCreditsError(
|
||||
message=f"Insufficient credits (BYOK): {msg}" if is_byok else f"Insufficient credits: {msg}",
|
||||
@@ -1247,7 +1249,7 @@ class AnthropicClient(LLMClientBase):
|
||||
args_json = json.loads(arguments)
|
||||
if not isinstance(args_json, dict):
|
||||
raise LLMServerError("Expected parseable json object for arguments")
|
||||
except:
|
||||
except Exception:
|
||||
arguments = str(tool_input["function"]["arguments"])
|
||||
else:
|
||||
arguments = json.dumps(tool_input, indent=2)
|
||||
@@ -1539,7 +1541,7 @@ def is_heartbeat(message: dict, is_ping: bool = False) -> bool:
|
||||
|
||||
try:
|
||||
message_json = json.loads(message["content"])
|
||||
except:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Check if message_json is a dict (not int, str, list, etc.)
|
||||
|
||||
@@ -302,7 +302,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
for item_schema in schema_part[key]:
|
||||
self._clean_google_ai_schema_properties(item_schema)
|
||||
|
||||
def _resolve_json_schema_refs(self, schema: dict, defs: dict = None) -> dict:
|
||||
def _resolve_json_schema_refs(self, schema: dict, defs: dict | None = None) -> dict:
|
||||
"""
|
||||
Recursively resolve $ref in JSON schema by inlining definitions.
|
||||
Google GenAI SDK does not support $ref.
|
||||
@@ -1057,7 +1057,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
# Fallback to base implementation for other errors
|
||||
return super().handle_llm_error(e, llm_config=llm_config)
|
||||
|
||||
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
|
||||
async def count_tokens(
|
||||
self, messages: List[dict] | None = None, model: str | None = None, tools: List[OpenAITool] | None = None
|
||||
) -> int:
|
||||
"""
|
||||
Count tokens for the given messages and tools using the Gemini token counting API.
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ async def openai_get_model_list_async(
|
||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
||||
try:
|
||||
error_response = http_err.response.json()
|
||||
except:
|
||||
except Exception:
|
||||
error_response = {"status_code": http_err.response.status_code, "text": http_err.response.text}
|
||||
logger.debug(f"Got HTTPError, exception={http_err}, response={error_response}")
|
||||
raise http_err
|
||||
|
||||
@@ -106,7 +106,7 @@ def accepts_developer_role(model: str) -> bool:
|
||||
|
||||
See: https://community.openai.com/t/developer-role-not-accepted-for-o1-o1-mini-o3-mini/1110750/7
|
||||
"""
|
||||
if is_openai_reasoning_model(model) and "o1-mini" not in model or "o1-preview" in model:
|
||||
if (is_openai_reasoning_model(model) and "o1-mini" not in model) or "o1-preview" in model:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@@ -459,7 +459,7 @@ class OpenAIClient(LLMClientBase):
|
||||
if is_openrouter:
|
||||
try:
|
||||
model = llm_config.handle.split("/", 1)[-1]
|
||||
except:
|
||||
except Exception:
|
||||
# don't raise error since this isn't robust against edge cases
|
||||
pass
|
||||
|
||||
@@ -747,7 +747,6 @@ class OpenAIClient(LLMClientBase):
|
||||
finish_reason = None
|
||||
|
||||
# Optionally capture reasoning presence
|
||||
found_reasoning = False
|
||||
for out in outputs:
|
||||
out_type = (out or {}).get("type")
|
||||
if out_type == "message":
|
||||
@@ -758,7 +757,6 @@ class OpenAIClient(LLMClientBase):
|
||||
if text_val:
|
||||
assistant_text_parts.append(text_val)
|
||||
elif out_type == "reasoning":
|
||||
found_reasoning = True
|
||||
reasoning_summary_parts = [part.get("text") for part in out.get("summary")]
|
||||
reasoning_content_signature = out.get("encrypted_content")
|
||||
elif out_type == "function_call":
|
||||
|
||||
@@ -16,7 +16,7 @@ from letta.local_llm.llamacpp.api import get_llamacpp_completion
|
||||
from letta.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
|
||||
from letta.local_llm.lmstudio.api import get_lmstudio_completion, get_lmstudio_completion_chatcompletions
|
||||
from letta.local_llm.ollama.api import get_ollama_completion
|
||||
from letta.local_llm.utils import count_tokens, get_available_wrappers
|
||||
from letta.local_llm.utils import get_available_wrappers
|
||||
from letta.local_llm.vllm.api import get_vllm_completion
|
||||
from letta.local_llm.webui.api import get_webui_completion
|
||||
from letta.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
|
||||
@@ -177,7 +177,7 @@ def get_chat_completion(
|
||||
raise LocalLLMError(
|
||||
f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
|
||||
)
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise LocalLLMConnectionError(f"Unable to connect to endpoint {endpoint}")
|
||||
|
||||
attributes = usage if isinstance(usage, dict) else {"usage": usage}
|
||||
@@ -207,10 +207,12 @@ def get_chat_completion(
|
||||
|
||||
if usage["prompt_tokens"] is None:
|
||||
printd("usage dict was missing prompt_tokens, computing on-the-fly...")
|
||||
usage["prompt_tokens"] = count_tokens(prompt)
|
||||
# Approximate token count: bytes / 4
|
||||
usage["prompt_tokens"] = len(prompt.encode("utf-8")) // 4
|
||||
|
||||
# NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing
|
||||
usage["completion_tokens"] = count_tokens(json_dumps(chat_completion_result))
|
||||
# Approximate token count: bytes / 4
|
||||
usage["completion_tokens"] = len(json_dumps(chat_completion_result).encode("utf-8")) // 4
|
||||
"""
|
||||
if usage["completion_tokens"] is None:
|
||||
printd(f"usage dict was missing completion_tokens, computing on-the-fly...")
|
||||
|
||||
@@ -5,7 +5,7 @@ from copy import copy
|
||||
from enum import Enum
|
||||
from inspect import getdoc, isclass
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, List, Optional, Tuple, Type, Union, _GenericAlias, get_args, get_origin
|
||||
from typing import Any, Callable, List, Optional, Tuple, Type, Union, _GenericAlias, get_args, get_origin # type: ignore[attr-defined]
|
||||
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel, create_model
|
||||
@@ -58,13 +58,13 @@ def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
|
||||
|
||||
elif isclass(pydantic_type) and issubclass(pydantic_type, BaseModel):
|
||||
return format_model_and_field_name(pydantic_type.__name__)
|
||||
elif get_origin(pydantic_type) == list:
|
||||
elif get_origin(pydantic_type) is list:
|
||||
element_type = get_args(pydantic_type)[0]
|
||||
return f"{map_pydantic_type_to_gbnf(element_type)}-list"
|
||||
elif get_origin(pydantic_type) == set:
|
||||
elif get_origin(pydantic_type) is set:
|
||||
element_type = get_args(pydantic_type)[0]
|
||||
return f"{map_pydantic_type_to_gbnf(element_type)}-set"
|
||||
elif get_origin(pydantic_type) == Union:
|
||||
elif get_origin(pydantic_type) is Union:
|
||||
union_types = get_args(pydantic_type)
|
||||
union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types]
|
||||
return f"union-{'-or-'.join(union_rules)}"
|
||||
@@ -73,7 +73,7 @@ def map_pydantic_type_to_gbnf(pydantic_type: Type[Any]) -> str:
|
||||
return f"optional-{map_pydantic_type_to_gbnf(element_type)}"
|
||||
elif isclass(pydantic_type):
|
||||
return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(pydantic_type.__name__)}"
|
||||
elif get_origin(pydantic_type) == dict:
|
||||
elif get_origin(pydantic_type) is dict:
|
||||
key_type, value_type = get_args(pydantic_type)
|
||||
return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}"
|
||||
else:
|
||||
@@ -299,7 +299,7 @@ def generate_gbnf_rule_for_type(
|
||||
enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}"
|
||||
rules.append(enum_rule)
|
||||
gbnf_type, rules = model_name + "-" + field_name, rules
|
||||
elif get_origin(field_type) == list: # Array
|
||||
elif get_origin(field_type) is list: # Array
|
||||
element_type = get_args(field_type)[0]
|
||||
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
|
||||
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
|
||||
@@ -309,7 +309,7 @@ def generate_gbnf_rule_for_type(
|
||||
rules.append(array_rule)
|
||||
gbnf_type, rules = model_name + "-" + field_name, rules
|
||||
|
||||
elif get_origin(field_type) == set or field_type == set: # Array
|
||||
elif get_origin(field_type) is set or field_type is set: # Array
|
||||
element_type = get_args(field_type)[0]
|
||||
element_rule_name, additional_rules = generate_gbnf_rule_for_type(
|
||||
model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules
|
||||
@@ -320,7 +320,7 @@ def generate_gbnf_rule_for_type(
|
||||
gbnf_type, rules = model_name + "-" + field_name, rules
|
||||
|
||||
elif gbnf_type.startswith("custom-class-"):
|
||||
nested_model_rules, field_types = get_members_structure(field_type, gbnf_type)
|
||||
nested_model_rules, _field_types = get_members_structure(field_type, gbnf_type)
|
||||
rules.append(nested_model_rules)
|
||||
elif gbnf_type.startswith("custom-dict-"):
|
||||
key_type, value_type = get_args(field_type)
|
||||
@@ -502,15 +502,15 @@ def generate_gbnf_grammar(model: Type[BaseModel], processed_models: set, created
|
||||
model_rule += '"\\n" ws "}"'
|
||||
model_rule += '"\\n" markdown-code-block'
|
||||
has_special_string = True
|
||||
all_rules = [model_rule] + nested_rules
|
||||
all_rules = [model_rule, *nested_rules]
|
||||
|
||||
return all_rules, has_special_string
|
||||
|
||||
|
||||
def generate_gbnf_grammar_from_pydantic_models(
|
||||
models: List[Type[BaseModel]],
|
||||
outer_object_name: str = None,
|
||||
outer_object_content: str = None,
|
||||
outer_object_name: str | None = None,
|
||||
outer_object_content: str | None = None,
|
||||
list_of_outputs: bool = False,
|
||||
add_inner_thoughts: bool = False,
|
||||
allow_only_inner_thoughts: bool = False,
|
||||
@@ -704,11 +704,11 @@ def generate_markdown_documentation(
|
||||
# continue
|
||||
if isclass(field_type) and issubclass(field_type, BaseModel):
|
||||
pyd_models.append((field_type, False))
|
||||
if get_origin(field_type) == list:
|
||||
if get_origin(field_type) is list:
|
||||
element_type = get_args(field_type)[0]
|
||||
if isclass(element_type) and issubclass(element_type, BaseModel):
|
||||
pyd_models.append((element_type, False))
|
||||
if get_origin(field_type) == Union:
|
||||
if get_origin(field_type) is Union:
|
||||
element_types = get_args(field_type)
|
||||
for element_type in element_types:
|
||||
if isclass(element_type) and issubclass(element_type, BaseModel):
|
||||
@@ -747,14 +747,14 @@ def generate_field_markdown(
|
||||
field_info = model.model_fields.get(field_name)
|
||||
field_description = field_info.description if field_info and field_info.description else ""
|
||||
|
||||
if get_origin(field_type) == list:
|
||||
if get_origin(field_type) is list:
|
||||
element_type = get_args(field_type)[0]
|
||||
field_text = f"{indent}{field_name} ({field_type.__name__} of {element_type.__name__})"
|
||||
if field_description != "":
|
||||
field_text += ": "
|
||||
else:
|
||||
field_text += "\n"
|
||||
elif get_origin(field_type) == Union:
|
||||
elif get_origin(field_type) is Union:
|
||||
element_types = get_args(field_type)
|
||||
types = []
|
||||
for element_type in element_types:
|
||||
@@ -857,11 +857,11 @@ def generate_text_documentation(
|
||||
for name, field_type in model.__annotations__.items():
|
||||
# if name == "markdown_code_block":
|
||||
# continue
|
||||
if get_origin(field_type) == list:
|
||||
if get_origin(field_type) is list:
|
||||
element_type = get_args(field_type)[0]
|
||||
if isclass(element_type) and issubclass(element_type, BaseModel):
|
||||
pyd_models.append((element_type, False))
|
||||
if get_origin(field_type) == Union:
|
||||
if get_origin(field_type) is Union:
|
||||
element_types = get_args(field_type)
|
||||
for element_type in element_types:
|
||||
if isclass(element_type) and issubclass(element_type, BaseModel):
|
||||
@@ -905,14 +905,14 @@ def generate_field_text(
|
||||
field_info = model.model_fields.get(field_name)
|
||||
field_description = field_info.description if field_info and field_info.description else ""
|
||||
|
||||
if get_origin(field_type) == list:
|
||||
if get_origin(field_type) is list:
|
||||
element_type = get_args(field_type)[0]
|
||||
field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})"
|
||||
if field_description != "":
|
||||
field_text += ":\n"
|
||||
else:
|
||||
field_text += "\n"
|
||||
elif get_origin(field_type) == Union:
|
||||
elif get_origin(field_type) is Union:
|
||||
element_types = get_args(field_type)
|
||||
types = []
|
||||
for element_type in element_types:
|
||||
@@ -1015,8 +1015,8 @@ def generate_and_save_gbnf_grammar_and_documentation(
|
||||
pydantic_model_list,
|
||||
grammar_file_path="./generated_grammar.gbnf",
|
||||
documentation_file_path="./generated_grammar_documentation.md",
|
||||
outer_object_name: str = None,
|
||||
outer_object_content: str = None,
|
||||
outer_object_name: str | None = None,
|
||||
outer_object_content: str | None = None,
|
||||
model_prefix: str = "Output Model",
|
||||
fields_prefix: str = "Output Fields",
|
||||
list_of_outputs: bool = False,
|
||||
@@ -1049,8 +1049,8 @@ def generate_and_save_gbnf_grammar_and_documentation(
|
||||
|
||||
def generate_gbnf_grammar_and_documentation(
|
||||
pydantic_model_list,
|
||||
outer_object_name: str = None,
|
||||
outer_object_content: str = None,
|
||||
outer_object_name: str | None = None,
|
||||
outer_object_content: str | None = None,
|
||||
model_prefix: str = "Output Model",
|
||||
fields_prefix: str = "Output Fields",
|
||||
list_of_outputs: bool = False,
|
||||
@@ -1087,8 +1087,8 @@ def generate_gbnf_grammar_and_documentation(
|
||||
|
||||
def generate_gbnf_grammar_and_documentation_from_dictionaries(
|
||||
dictionaries: List[dict],
|
||||
outer_object_name: str = None,
|
||||
outer_object_content: str = None,
|
||||
outer_object_name: str | None = None,
|
||||
outer_object_content: str | None = None,
|
||||
model_prefix: str = "Output Model",
|
||||
fields_prefix: str = "Output Fields",
|
||||
list_of_outputs: bool = False,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from letta.local_llm.settings.settings import get_completions_settings
|
||||
from letta.local_llm.utils import count_tokens, post_json_auth_request
|
||||
from letta.local_llm.utils import post_json_auth_request
|
||||
|
||||
KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
|
||||
|
||||
@@ -10,7 +10,8 @@ def get_koboldcpp_completion(endpoint, auth_type, auth_key, prompt, context_wind
|
||||
"""See https://lite.koboldai.net/koboldcpp_api for API spec"""
|
||||
from letta.utils import printd
|
||||
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
# Approximate token count: bytes / 4
|
||||
prompt_tokens = len(prompt.encode("utf-8")) // 4
|
||||
if prompt_tokens > context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from letta.local_llm.settings.settings import get_completions_settings
|
||||
from letta.local_llm.utils import count_tokens, post_json_auth_request
|
||||
from letta.local_llm.utils import post_json_auth_request
|
||||
|
||||
LLAMACPP_API_SUFFIX = "/completion"
|
||||
|
||||
@@ -10,7 +10,8 @@ def get_llamacpp_completion(endpoint, auth_type, auth_key, prompt, context_windo
|
||||
"""See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
|
||||
from letta.utils import printd
|
||||
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
# Approximate token count: bytes / 4
|
||||
prompt_tokens = len(prompt.encode("utf-8")) // 4
|
||||
if prompt_tokens > context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
|
||||
|
||||
|
||||
@@ -130,12 +130,12 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
content_json = json_loads(message["content"])
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\nUSER: {content_simple}"
|
||||
except:
|
||||
except Exception:
|
||||
prompt += f"\nUSER: {message['content']}"
|
||||
elif message["role"] == "assistant":
|
||||
prompt += f"\nASSISTANT: {message['content']}"
|
||||
# need to add the function call if there was one
|
||||
if "function_call" in message and message["function_call"]:
|
||||
if message.get("function_call"):
|
||||
prompt += f"\n{create_function_call(message['function_call'])}"
|
||||
elif message["role"] in ["function", "tool"]:
|
||||
# TODO find a good way to add this
|
||||
@@ -348,7 +348,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
|
||||
content_json = json_loads(message["content"])
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\n{user_prefix}: {content_simple}"
|
||||
except:
|
||||
except Exception:
|
||||
prompt += f"\n{user_prefix}: {message['content']}"
|
||||
elif message["role"] == "assistant":
|
||||
# Support for AutoGen naming of agents
|
||||
@@ -360,7 +360,7 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
|
||||
prompt += f"\n{assistant_prefix}:"
|
||||
# need to add the function call if there was one
|
||||
inner_thoughts = message["content"]
|
||||
if "function_call" in message and message["function_call"]:
|
||||
if message.get("function_call"):
|
||||
prompt += f"\n{create_function_call(message['function_call'], inner_thoughts=inner_thoughts)}"
|
||||
elif message["role"] in ["function", "tool"]:
|
||||
# TODO find a good way to add this
|
||||
|
||||
@@ -143,9 +143,9 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
|
||||
# need to add the function call if there was one
|
||||
inner_thoughts = message["content"]
|
||||
if "function_call" in message and message["function_call"]:
|
||||
if message.get("function_call"):
|
||||
prompt += f"\n{self._compile_function_call(message['function_call'], inner_thoughts=inner_thoughts)}"
|
||||
elif "tool_calls" in message and message["tool_calls"]:
|
||||
elif message.get("tool_calls"):
|
||||
for tool_call in message["tool_calls"]:
|
||||
prompt += f"\n{self._compile_function_call(tool_call['function'], inner_thoughts=inner_thoughts)}"
|
||||
else:
|
||||
@@ -163,14 +163,14 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
try:
|
||||
user_msg_json = json_loads(message["content"])
|
||||
user_msg_str = user_msg_json["message"]
|
||||
except:
|
||||
except Exception:
|
||||
user_msg_str = message["content"]
|
||||
else:
|
||||
# Otherwise just dump the full json
|
||||
try:
|
||||
user_msg_json = json_loads(message["content"])
|
||||
user_msg_str = json_dumps(user_msg_json, indent=self.json_indent)
|
||||
except:
|
||||
except Exception:
|
||||
user_msg_str = message["content"]
|
||||
|
||||
prompt += user_msg_str
|
||||
@@ -185,7 +185,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
# indent the function replies
|
||||
function_return_dict = json_loads(message["content"])
|
||||
function_return_str = json_dumps(function_return_dict, indent=0)
|
||||
except:
|
||||
except Exception:
|
||||
function_return_str = message["content"]
|
||||
|
||||
prompt += function_return_str
|
||||
@@ -218,7 +218,7 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
msg_json = json_loads(message["content"])
|
||||
if msg_json["type"] != "user_message":
|
||||
role_str = "system"
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
prompt += f"\n<|im_start|>{role_str}\n{msg_str.strip()}<|im_end|>"
|
||||
|
||||
|
||||
@@ -141,9 +141,9 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
|
||||
|
||||
# need to add the function call if there was one
|
||||
inner_thoughts = message["content"]
|
||||
if "function_call" in message and message["function_call"]:
|
||||
if message.get("function_call"):
|
||||
prompt += f"\n{self._compile_function_call(message['function_call'], inner_thoughts=inner_thoughts)}"
|
||||
elif "tool_calls" in message and message["tool_calls"]:
|
||||
elif message.get("tool_calls"):
|
||||
for tool_call in message["tool_calls"]:
|
||||
prompt += f"\n{self._compile_function_call(tool_call['function'], inner_thoughts=inner_thoughts)}"
|
||||
else:
|
||||
@@ -161,14 +161,14 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
|
||||
try:
|
||||
user_msg_json = json_loads(message["content"])
|
||||
user_msg_str = user_msg_json["message"]
|
||||
except:
|
||||
except Exception:
|
||||
user_msg_str = message["content"]
|
||||
else:
|
||||
# Otherwise just dump the full json
|
||||
try:
|
||||
user_msg_json = json_loads(message["content"])
|
||||
user_msg_str = json_dumps(user_msg_json, indent=self.json_indent)
|
||||
except:
|
||||
except Exception:
|
||||
user_msg_str = message["content"]
|
||||
|
||||
prompt += user_msg_str
|
||||
@@ -183,7 +183,7 @@ class ConfigurableJSONWrapper(LLMChatCompletionWrapper):
|
||||
# indent the function replies
|
||||
function_return_dict = json_loads(message["content"])
|
||||
function_return_str = json_dumps(function_return_dict, indent=0)
|
||||
except:
|
||||
except Exception:
|
||||
function_return_str = message["content"]
|
||||
|
||||
prompt += function_return_str
|
||||
|
||||
@@ -158,7 +158,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\n{IM_START_TOKEN}user\n{content_simple}{IM_END_TOKEN}"
|
||||
# prompt += f"\nUSER: {content_simple}"
|
||||
except:
|
||||
except Exception:
|
||||
prompt += f"\n{IM_START_TOKEN}user\n{message['content']}{IM_END_TOKEN}"
|
||||
# prompt += f"\nUSER: {message['content']}"
|
||||
elif message["role"] == "assistant":
|
||||
@@ -167,7 +167,7 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
|
||||
prompt += f"\n{message['content']}"
|
||||
# prompt += f"\nASSISTANT: {message['content']}"
|
||||
# need to add the function call if there was one
|
||||
if "function_call" in message and message["function_call"]:
|
||||
if message.get("function_call"):
|
||||
prompt += f"\n{create_function_call(message['function_call'])}"
|
||||
prompt += f"{IM_END_TOKEN}"
|
||||
elif message["role"] in ["function", "tool"]:
|
||||
|
||||
@@ -142,9 +142,9 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
|
||||
# need to add the function call if there was one
|
||||
inner_thoughts = message["content"]
|
||||
if "function_call" in message and message["function_call"]:
|
||||
if message.get("function_call"):
|
||||
prompt += f"\n{self._compile_function_call(message['function_call'], inner_thoughts=inner_thoughts)}"
|
||||
elif "tool_calls" in message and message["tool_calls"]:
|
||||
elif message.get("tool_calls"):
|
||||
for tool_call in message["tool_calls"]:
|
||||
prompt += f"\n{self._compile_function_call(tool_call['function'], inner_thoughts=inner_thoughts)}"
|
||||
else:
|
||||
@@ -162,7 +162,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
try:
|
||||
user_msg_json = json_loads(message["content"])
|
||||
user_msg_str = user_msg_json["message"]
|
||||
except:
|
||||
except Exception:
|
||||
user_msg_str = message["content"]
|
||||
else:
|
||||
# Otherwise just dump the full json
|
||||
@@ -172,7 +172,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
user_msg_json,
|
||||
indent=self.json_indent,
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
user_msg_str = message["content"]
|
||||
|
||||
prompt += user_msg_str
|
||||
@@ -190,7 +190,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
function_return_dict,
|
||||
indent=self.json_indent,
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
function_return_str = message["content"]
|
||||
|
||||
prompt += function_return_str
|
||||
@@ -223,7 +223,7 @@ class LLaMA3InnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
msg_json = json_loads(message["content"])
|
||||
if msg_json["type"] != "user_message":
|
||||
role_str = "system"
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
prompt += f"\n<|start_header_id|>{role_str}<|end_header_id|>\n\n{msg_str.strip()}<|eot_id|>"
|
||||
|
||||
|
||||
@@ -101,14 +101,14 @@ class SimpleSummaryWrapper(LLMChatCompletionWrapper):
|
||||
content_json = json_loads(message["content"])
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\nUSER: {content_simple}"
|
||||
except:
|
||||
except Exception:
|
||||
prompt += f"\nUSER: {message['content']}"
|
||||
elif message["role"] == "assistant":
|
||||
prompt += f"\nASSISTANT: {message['content']}"
|
||||
# need to add the function call if there was one
|
||||
if "function_call" in message and message["function_call"]:
|
||||
if message.get("function_call"):
|
||||
prompt += f"\n{create_function_call(message['function_call'])}"
|
||||
elif "tool_calls" in message and message["tool_calls"]:
|
||||
elif message.get("tool_calls"):
|
||||
prompt += f"\n{create_function_call(message['tool_calls'][0]['function'])}"
|
||||
elif message["role"] in ["function", "tool"]:
|
||||
# TODO find a good way to add this
|
||||
|
||||
@@ -88,7 +88,7 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\n<|user|>\n{content_simple}{IM_END_TOKEN}"
|
||||
# prompt += f"\nUSER: {content_simple}"
|
||||
except:
|
||||
except Exception:
|
||||
prompt += f"\n<|user|>\n{message['content']}{IM_END_TOKEN}"
|
||||
# prompt += f"\nUSER: {message['content']}"
|
||||
elif message["role"] == "assistant":
|
||||
@@ -97,7 +97,7 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
|
||||
prompt += f"\n{message['content']}"
|
||||
# prompt += f"\nASSISTANT: {message['content']}"
|
||||
# need to add the function call if there was one
|
||||
if "function_call" in message and message["function_call"]:
|
||||
if message.get("function_call"):
|
||||
prompt += f"\n{create_function_call(message['function_call'])}"
|
||||
prompt += f"{IM_END_TOKEN}"
|
||||
elif message["role"] in ["function", "tool"]:
|
||||
@@ -256,7 +256,7 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
|
||||
content_json = json_loads(message["content"])
|
||||
content_simple = content_json["message"]
|
||||
prompt += f"\n<|user|>\n{content_simple}{IM_END_TOKEN}"
|
||||
except:
|
||||
except Exception:
|
||||
prompt += f"\n<|user|>\n{message['content']}{IM_END_TOKEN}"
|
||||
elif message["role"] == "assistant":
|
||||
prompt += "\n<|assistant|>"
|
||||
|
||||
@@ -3,7 +3,6 @@ from urllib.parse import urljoin
|
||||
|
||||
from letta.local_llm.settings.settings import get_completions_settings
|
||||
from letta.local_llm.utils import post_json_auth_request
|
||||
from letta.utils import count_tokens
|
||||
|
||||
LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions"
|
||||
LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
|
||||
@@ -80,7 +79,8 @@ def get_lmstudio_completion(endpoint, auth_type, auth_key, prompt, context_windo
|
||||
"""Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
|
||||
from letta.utils import printd
|
||||
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
# Approximate token count: bytes / 4
|
||||
prompt_tokens = len(prompt.encode("utf-8")) // 4
|
||||
if prompt_tokens > context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from urllib.parse import urljoin
|
||||
from letta.errors import LocalLLMError
|
||||
from letta.local_llm.settings.settings import get_completions_settings
|
||||
from letta.local_llm.utils import post_json_auth_request
|
||||
from letta.utils import count_tokens
|
||||
|
||||
OLLAMA_API_SUFFIX = "/api/generate"
|
||||
|
||||
@@ -12,7 +11,8 @@ def get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_
|
||||
"""See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
|
||||
from letta.utils import printd
|
||||
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
# Approximate token count: bytes / 4
|
||||
prompt_tokens = len(prompt.encode("utf-8")) // 4
|
||||
if prompt_tokens > context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from letta.local_llm.settings.settings import get_completions_settings
|
||||
from letta.local_llm.utils import count_tokens, post_json_auth_request
|
||||
from letta.local_llm.utils import post_json_auth_request
|
||||
|
||||
WEBUI_API_SUFFIX = "/completions"
|
||||
|
||||
@@ -10,7 +10,8 @@ def get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_wi
|
||||
"""https://github.com/vllm-project/vllm/blob/main/examples/api_client.py"""
|
||||
from letta.utils import printd
|
||||
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
# Approximate token count: bytes / 4
|
||||
prompt_tokens = len(prompt.encode("utf-8")) // 4
|
||||
if prompt_tokens > context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from letta.local_llm.settings.settings import get_completions_settings
|
||||
from letta.local_llm.utils import count_tokens, post_json_auth_request
|
||||
from letta.local_llm.utils import post_json_auth_request
|
||||
|
||||
WEBUI_API_SUFFIX = "/v1/completions"
|
||||
|
||||
@@ -10,7 +10,8 @@ def get_webui_completion(endpoint, auth_type, auth_key, prompt, context_window,
|
||||
"""Compatibility for the new OpenAI API: https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples"""
|
||||
from letta.utils import printd
|
||||
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
# Approximate token count: bytes / 4
|
||||
prompt_tokens = len(prompt.encode("utf-8")) // 4
|
||||
if prompt_tokens > context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from letta.local_llm.settings.settings import get_completions_settings
|
||||
from letta.local_llm.utils import count_tokens, post_json_auth_request
|
||||
from letta.local_llm.utils import post_json_auth_request
|
||||
|
||||
WEBUI_API_SUFFIX = "/api/v1/generate"
|
||||
|
||||
@@ -10,7 +10,8 @@ def get_webui_completion(endpoint, auth_type, auth_key, prompt, context_window,
|
||||
"""See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
|
||||
from letta.utils import printd
|
||||
|
||||
prompt_tokens = count_tokens(prompt)
|
||||
# Approximate token count: bytes / 4
|
||||
prompt_tokens = len(prompt.encode("utf-8")) // 4
|
||||
if prompt_tokens > context_window:
|
||||
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
|
||||
|
||||
|
||||
@@ -32,9 +32,12 @@ if TYPE_CHECKING:
|
||||
from letta.orm.archives_agents import ArchivesAgents
|
||||
from letta.orm.conversation import Conversation
|
||||
from letta.orm.files_agents import FileAgent
|
||||
from letta.orm.group import Group
|
||||
from letta.orm.identity import Identity
|
||||
from letta.orm.llm_batch_items import LLMBatchItem
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.run import Run
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.tool import Tool
|
||||
|
||||
@@ -122,7 +125,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents", lazy="raise")
|
||||
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship( # noqa: F821
|
||||
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
|
||||
"AgentEnvironmentVariable",
|
||||
back_populates="agent",
|
||||
cascade="all, delete-orphan",
|
||||
@@ -160,14 +163,14 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
back_populates="agents",
|
||||
passive_deletes=True,
|
||||
)
|
||||
groups: Mapped[List["Group"]] = relationship( # noqa: F821
|
||||
groups: Mapped[List["Group"]] = relationship(
|
||||
"Group",
|
||||
secondary="groups_agents",
|
||||
lazy="raise",
|
||||
back_populates="agents",
|
||||
passive_deletes=True,
|
||||
)
|
||||
multi_agent_group: Mapped["Group"] = relationship( # noqa: F821
|
||||
multi_agent_group: Mapped["Group"] = relationship(
|
||||
"Group",
|
||||
lazy="selectin",
|
||||
viewonly=True,
|
||||
@@ -175,7 +178,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateEntityMixin
|
||||
foreign_keys="[Group.manager_agent_id]",
|
||||
uselist=False,
|
||||
)
|
||||
batch_items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="agent", lazy="raise") # noqa: F821
|
||||
batch_items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="agent", lazy="raise")
|
||||
file_agents: Mapped[List["FileAgent"]] = relationship(
|
||||
"FileAgent",
|
||||
back_populates="agent",
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
|
||||
from sqlalchemy import ForeignKey, Index, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
@@ -21,4 +26,4 @@ class AgentsTags(Base):
|
||||
tag: Mapped[str] = mapped_column(String, doc="The name of the tag associated with the agent.", primary_key=True)
|
||||
|
||||
# Relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="tags") # noqa: F821
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="tags")
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.archive import Archive
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
@@ -23,5 +28,5 @@ class ArchivesAgents(Base):
|
||||
is_owner: Mapped[bool] = mapped_column(Boolean, default=False, doc="Whether this agent created/owns the archive")
|
||||
|
||||
# relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="archives_agents") # noqa: F821
|
||||
archive: Mapped["Archive"] = relationship("Archive", back_populates="archives_agents") # noqa: F821
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="archives_agents")
|
||||
archive: Mapped["Archive"] = relationship("Archive", back_populates="archives_agents")
|
||||
|
||||
@@ -78,7 +78,7 @@ class CommonSqlalchemyMetaMixins(Base):
|
||||
setattr(self, full_prop, None)
|
||||
return
|
||||
# Safety check
|
||||
prefix, id_ = value.split("-", 1)
|
||||
prefix, _id = value.split("-", 1)
|
||||
assert prefix == "user", f"{prefix} is not a valid id prefix for a user id"
|
||||
|
||||
# Set the full value
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, List, Optional, Type
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Type
|
||||
|
||||
from sqlalchemy import JSON, BigInteger, ForeignKey, Index, Integer, String, UniqueConstraint, event
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
|
||||
@@ -11,7 +11,9 @@ from letta.schemas.block import Block as PydanticBlock, Human, Persona
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm import Organization
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.blocks_tags import BlocksTags
|
||||
from letta.orm.group import Group
|
||||
from letta.orm.identity import Identity
|
||||
|
||||
|
||||
@@ -56,11 +58,11 @@ class Block(OrganizationMixin, SqlalchemyBase, ProjectMixin, TemplateEntityMixin
|
||||
)
|
||||
# NOTE: This takes advantage of built-in optimistic locking functionality by SqlAlchemy
|
||||
# https://docs.sqlalchemy.org/en/20/orm/versioning.html
|
||||
__mapper_args__ = {"version_id_col": version}
|
||||
__mapper_args__: ClassVar[dict] = {"version_id_col": version}
|
||||
|
||||
# relationships
|
||||
organization: Mapped[Optional["Organization"]] = relationship("Organization", lazy="raise")
|
||||
agents: Mapped[List["Agent"]] = relationship( # noqa: F821
|
||||
agents: Mapped[List["Agent"]] = relationship(
|
||||
"Agent",
|
||||
secondary="blocks_agents",
|
||||
lazy="raise",
|
||||
@@ -75,7 +77,7 @@ class Block(OrganizationMixin, SqlalchemyBase, ProjectMixin, TemplateEntityMixin
|
||||
back_populates="blocks",
|
||||
passive_deletes=True,
|
||||
)
|
||||
groups: Mapped[List["Group"]] = relationship( # noqa: F821
|
||||
groups: Mapped[List["Group"]] = relationship(
|
||||
"Group",
|
||||
secondary="groups_blocks",
|
||||
lazy="raise",
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.block import Block
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, UniqueConstraint, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
@@ -34,4 +37,4 @@ class BlocksTags(Base):
|
||||
_last_updated_by_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
block: Mapped["Block"] = relationship("Block", back_populates="tags") # noqa: F821
|
||||
block: Mapped["Block"] = relationship("Block", back_populates="tags")
|
||||
|
||||
@@ -12,7 +12,7 @@ from letta.schemas.file import FileAgent as PydanticFileAgent
|
||||
from letta.utils import truncate_file_visible_content
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from letta.orm.agent import Agent
|
||||
|
||||
|
||||
class FileAgent(SqlalchemyBase, OrganizationMixin):
|
||||
@@ -85,7 +85,7 @@ class FileAgent(SqlalchemyBase, OrganizationMixin):
|
||||
)
|
||||
|
||||
# relationships
|
||||
agent: Mapped["Agent"] = relationship( # noqa: F821
|
||||
agent: Mapped["Agent"] = relationship(
|
||||
"Agent",
|
||||
back_populates="file_agents",
|
||||
lazy="selectin",
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
from sqlalchemy import JSON, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
@@ -27,12 +32,12 @@ class Group(SqlalchemyBase, OrganizationMixin, ProjectMixin, TemplateMixin):
|
||||
hidden: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="If set to True, the group will be hidden.")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="groups") # noqa: F821
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="groups")
|
||||
agent_ids: Mapped[List[str]] = mapped_column(JSON, nullable=False, doc="Ordered list of agent IDs in this group")
|
||||
agents: Mapped[List["Agent"]] = relationship( # noqa: F821
|
||||
agents: Mapped[List["Agent"]] = relationship(
|
||||
"Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups"
|
||||
)
|
||||
shared_blocks: Mapped[List["Block"]] = relationship( # noqa: F821
|
||||
shared_blocks: Mapped[List["Block"]] = relationship(
|
||||
"Block", secondary="groups_blocks", lazy="selectin", passive_deletes=True, back_populates="groups"
|
||||
)
|
||||
manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group") # noqa: F821
|
||||
manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group")
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import uuid
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
from sqlalchemy import String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import JSON
|
||||
@@ -36,11 +41,11 @@ class Identity(SqlalchemyBase, OrganizationMixin, ProjectMixin):
|
||||
)
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="identities") # noqa: F821
|
||||
agents: Mapped[List["Agent"]] = relationship( # noqa: F821
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="identities")
|
||||
agents: Mapped[List["Agent"]] = relationship(
|
||||
"Agent", secondary="identities_agents", lazy="selectin", passive_deletes=True, back_populates="identities"
|
||||
)
|
||||
blocks: Mapped[List["Block"]] = relationship( # noqa: F821
|
||||
blocks: Mapped[List["Block"]] = relationship(
|
||||
"Block", secondary="identities_blocks", lazy="selectin", passive_deletes=True, back_populates="identities"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import uuid
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.llm_batch_job import LLMBatchJob
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse
|
||||
from sqlalchemy import ForeignKey, Index, String
|
||||
@@ -49,6 +54,6 @@ class LLMBatchItem(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
)
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_items") # noqa: F821
|
||||
batch: Mapped["LLMBatchJob"] = relationship("LLMBatchJob", back_populates="items", lazy="selectin") # noqa: F821
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="batch_items", lazy="selectin") # noqa: F821
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_items")
|
||||
batch: Mapped["LLMBatchJob"] = relationship("LLMBatchJob", back_populates="items", lazy="selectin")
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="batch_items", lazy="selectin")
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.llm_batch_items import LLMBatchItem
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
from anthropic.types.beta.messages import BetaMessageBatch
|
||||
from sqlalchemy import DateTime, ForeignKey, Index, String
|
||||
@@ -47,5 +51,5 @@ class LLMBatchJob(SqlalchemyBase, OrganizationMixin):
|
||||
String, ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the Letta batch job"
|
||||
)
|
||||
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_jobs") # noqa: F821
|
||||
items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="batch", lazy="selectin") # noqa: F821
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_jobs")
|
||||
items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="batch", lazy="selectin")
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.job import Job
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.run import Run
|
||||
from letta.orm.step import Step
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from sqlalchemy import BigInteger, FetchedValue, ForeignKey, Index, event, text
|
||||
@@ -83,12 +89,12 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
)
|
||||
|
||||
# Relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="raise") # noqa: F821
|
||||
step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin") # noqa: F821
|
||||
run: Mapped["Run"] = relationship("Run", back_populates="messages", lazy="selectin") # noqa: F821
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="raise")
|
||||
step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin")
|
||||
run: Mapped["Run"] = relationship("Run", back_populates="messages", lazy="selectin")
|
||||
|
||||
@property
|
||||
def job(self) -> Optional["Job"]: # noqa: F821
|
||||
def job(self) -> Optional["Job"]:
|
||||
"""Get the job associated with this message, if any."""
|
||||
return self.job_message.job if self.job_message else None
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ config = LettaConfig()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage_tag import PassageTag
|
||||
|
||||
|
||||
class BasePassage(SqlalchemyBase, OrganizationMixin):
|
||||
@@ -78,7 +79,7 @@ class ArchivalPassage(BasePassage, ArchiveMixin):
|
||||
__tablename__ = "archival_passages"
|
||||
|
||||
# junction table for efficient tag queries (complements json column above)
|
||||
passage_tags: Mapped[List["PassageTag"]] = relationship( # noqa: F821
|
||||
passage_tags: Mapped[List["PassageTag"]] = relationship(
|
||||
"PassageTag", back_populates="passage", cascade="all, delete-orphan", lazy="noload"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
from sqlalchemy import JSON, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
@@ -43,4 +46,4 @@ class ProviderTrace(SqlalchemyBase, OrganizationMixin):
|
||||
)
|
||||
|
||||
# Relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin") # noqa: F821
|
||||
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
from sqlalchemy import JSON, DateTime, Index, String, UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
@@ -42,4 +45,4 @@ class ProviderTraceMetadata(SqlalchemyBase, OrganizationMixin):
|
||||
user_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the user who initiated the request")
|
||||
|
||||
# Relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin") # noqa: F821
|
||||
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")
|
||||
|
||||
@@ -30,6 +30,9 @@ from letta.settings import DatabaseChoice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Select
|
||||
|
||||
from letta.schemas.user import User
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -122,7 +125,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_text: Optional[str] = None,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
ascending: bool = True,
|
||||
actor: Optional["User"] = None, # noqa: F821
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
join_model: Optional[Base] = None,
|
||||
@@ -222,7 +225,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_text: Optional[str] = None,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
ascending: bool = True,
|
||||
actor: Optional["User"] = None, # noqa: F821
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
join_model: Optional[Base] = None,
|
||||
@@ -415,7 +418,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
cls,
|
||||
db_session: "AsyncSession",
|
||||
identifier: Optional[str] = None,
|
||||
actor: Optional["User"] = None, # noqa: F821
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
check_is_deleted: bool = False,
|
||||
@@ -451,7 +454,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
cls,
|
||||
db_session: "AsyncSession",
|
||||
identifiers: List[str] = [],
|
||||
actor: Optional["User"] = None, # noqa: F821
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
check_is_deleted: bool = False,
|
||||
@@ -471,7 +474,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
def _read_multiple_preprocess(
|
||||
cls,
|
||||
identifiers: List[str],
|
||||
actor: Optional["User"], # noqa: F821
|
||||
actor: Optional["User"],
|
||||
access: Optional[List[Literal["read", "write", "admin"]]],
|
||||
access_type: AccessType,
|
||||
check_is_deleted: bool,
|
||||
@@ -543,7 +546,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
async def create_async(
|
||||
self,
|
||||
db_session: "AsyncSession",
|
||||
actor: Optional["User"] = None, # noqa: F821
|
||||
actor: Optional["User"] = None,
|
||||
no_commit: bool = False,
|
||||
no_refresh: bool = False,
|
||||
ignore_conflicts: bool = False,
|
||||
@@ -599,7 +602,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
cls,
|
||||
items: List["SqlalchemyBase"],
|
||||
db_session: "AsyncSession",
|
||||
actor: Optional["User"] = None, # noqa: F821
|
||||
actor: Optional["User"] = None,
|
||||
no_commit: bool = False,
|
||||
no_refresh: bool = False,
|
||||
) -> List["SqlalchemyBase"]:
|
||||
@@ -654,7 +657,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
cls._handle_dbapi_error(e)
|
||||
|
||||
@handle_db_timeout
|
||||
async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase": # noqa: F821
|
||||
async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
"""Soft delete a record asynchronously (mark as deleted)."""
|
||||
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)")
|
||||
|
||||
@@ -665,7 +668,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
return await self.update_async(db_session)
|
||||
|
||||
@handle_db_timeout
|
||||
async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None: # noqa: F821
|
||||
async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None:
|
||||
"""Permanently removes the record from the database asynchronously."""
|
||||
obj_id = self.id
|
||||
obj_class = self.__class__.__name__
|
||||
@@ -694,7 +697,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
cls,
|
||||
db_session: "AsyncSession",
|
||||
identifiers: List[str],
|
||||
actor: Optional["User"], # noqa: F821
|
||||
actor: Optional["User"],
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["write"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
) -> None:
|
||||
@@ -731,7 +734,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
async def update_async(
|
||||
self,
|
||||
db_session: "AsyncSession",
|
||||
actor: Optional["User"] = None, # noqa: F821
|
||||
actor: Optional["User"] = None,
|
||||
no_commit: bool = False,
|
||||
no_refresh: bool = False,
|
||||
) -> "SqlalchemyBase":
|
||||
@@ -778,7 +781,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
cls,
|
||||
*,
|
||||
db_session: "Session",
|
||||
actor: Optional["User"] = None, # noqa: F821
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
check_is_deleted: bool = False,
|
||||
@@ -818,7 +821,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
cls,
|
||||
*,
|
||||
db_session: "AsyncSession",
|
||||
actor: Optional["User"] = None, # noqa: F821
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
check_is_deleted: bool = False,
|
||||
@@ -854,11 +857,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
@classmethod
|
||||
def apply_access_predicate(
|
||||
cls,
|
||||
query: "Select", # noqa: F821
|
||||
actor: "User", # noqa: F821
|
||||
query: "Select",
|
||||
actor: "User",
|
||||
access: List[Literal["read", "write", "admin"]],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
) -> "Select": # noqa: F821
|
||||
) -> "Select":
|
||||
"""applies a WHERE clause restricting results to the given actor and access level
|
||||
Args:
|
||||
query: The initial sqlalchemy select statement
|
||||
|
||||
@@ -339,7 +339,7 @@ def trace_method(func):
|
||||
try:
|
||||
# Test if str() works (some objects have broken __str__)
|
||||
try:
|
||||
test_str = str(value)
|
||||
str(value)
|
||||
# If str() works and is reasonable, use repr
|
||||
str_value = repr(value)
|
||||
except Exception:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
from typing import Annotated, ClassVar, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
@@ -246,7 +246,7 @@ class ToolCallMessage(LettaMessage):
|
||||
return data
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
json_encoders: ClassVar[dict] = {
|
||||
ToolCallDelta: lambda v: v.model_dump(exclude_none=True),
|
||||
ToolCall: lambda v: v.model_dump(exclude_none=True),
|
||||
}
|
||||
|
||||
@@ -1150,7 +1150,7 @@ class Message(BaseMessage):
|
||||
tool_returns = [ToolReturn(**tr) for tr in openai_message_dict["tool_returns"]]
|
||||
|
||||
# TODO(caren) bad assumption here that "reasoning_content" always comes before "redacted_reasoning_content"
|
||||
if "reasoning_content" in openai_message_dict and openai_message_dict["reasoning_content"]:
|
||||
if openai_message_dict.get("reasoning_content"):
|
||||
content.append(
|
||||
ReasoningContent(
|
||||
reasoning=openai_message_dict["reasoning_content"],
|
||||
@@ -1162,13 +1162,13 @@ class Message(BaseMessage):
|
||||
),
|
||||
),
|
||||
)
|
||||
if "redacted_reasoning_content" in openai_message_dict and openai_message_dict["redacted_reasoning_content"]:
|
||||
if openai_message_dict.get("redacted_reasoning_content"):
|
||||
content.append(
|
||||
RedactedReasoningContent(
|
||||
data=str(openai_message_dict["redacted_reasoning_content"]),
|
||||
),
|
||||
)
|
||||
if "omitted_reasoning_content" in openai_message_dict and openai_message_dict["omitted_reasoning_content"]:
|
||||
if openai_message_dict.get("omitted_reasoning_content"):
|
||||
content.append(
|
||||
OmittedReasoningContent(),
|
||||
)
|
||||
@@ -2237,7 +2237,7 @@ class Message(BaseMessage):
|
||||
try:
|
||||
# NOTE: Google AI wants actual JSON objects, not strings
|
||||
function_args = parse_json(function_args)
|
||||
except:
|
||||
except Exception:
|
||||
raise UserWarning(f"Failed to parse JSON function args: {function_args}")
|
||||
function_args = {"args": function_args}
|
||||
|
||||
@@ -2327,7 +2327,7 @@ class Message(BaseMessage):
|
||||
|
||||
try:
|
||||
function_response = parse_json(text_content)
|
||||
except:
|
||||
except Exception:
|
||||
function_response = {"function_response": text_content}
|
||||
|
||||
parts.append(
|
||||
@@ -2360,7 +2360,7 @@ class Message(BaseMessage):
|
||||
# NOTE: Google AI API wants the function response as JSON only, no string
|
||||
try:
|
||||
function_response = parse_json(legacy_content)
|
||||
except:
|
||||
except Exception:
|
||||
function_response = {"function_response": legacy_content}
|
||||
|
||||
google_ai_message = {
|
||||
|
||||
@@ -24,13 +24,6 @@ from .xai import XAIProvider
|
||||
from .zai import ZAIProvider
|
||||
|
||||
__all__ = [
|
||||
# Base classes
|
||||
"Provider",
|
||||
"ProviderBase",
|
||||
"ProviderCreate",
|
||||
"ProviderUpdate",
|
||||
"ProviderCheck",
|
||||
# Provider implementations
|
||||
"AnthropicProvider",
|
||||
"AzureProvider",
|
||||
"BedrockProvider",
|
||||
@@ -40,16 +33,21 @@ __all__ = [
|
||||
"GoogleAIProvider",
|
||||
"GoogleVertexProvider",
|
||||
"GroqProvider",
|
||||
"LettaProvider",
|
||||
"LMStudioOpenAIProvider",
|
||||
"LettaProvider",
|
||||
"MiniMaxProvider",
|
||||
"MistralProvider",
|
||||
"OllamaProvider",
|
||||
"OpenAIProvider",
|
||||
"TogetherProvider",
|
||||
"VLLMProvider", # Replaces ChatCompletions and Completions
|
||||
"OpenRouterProvider",
|
||||
"Provider",
|
||||
"ProviderBase",
|
||||
"ProviderCheck",
|
||||
"ProviderCreate",
|
||||
"ProviderUpdate",
|
||||
"SGLangProvider",
|
||||
"TogetherProvider",
|
||||
"VLLMProvider",
|
||||
"XAIProvider",
|
||||
"ZAIProvider",
|
||||
"OpenRouterProvider",
|
||||
]
|
||||
|
||||
@@ -92,7 +92,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
|
||||
check = self._do_model_checks_for_name_and_context_size(model, length_key="max_context_length")
|
||||
if check is None:
|
||||
continue
|
||||
model_name, context_window_size = check
|
||||
model_name, _context_window_size = check
|
||||
|
||||
configs.append(
|
||||
EmbeddingConfig(
|
||||
|
||||
@@ -93,7 +93,7 @@ class OpenRouterProvider(OpenAIProvider):
|
||||
model_name = model["id"]
|
||||
|
||||
# OpenRouter returns context_length in the model listing
|
||||
if "context_length" in model and model["context_length"]:
|
||||
if model.get("context_length"):
|
||||
context_window_size = model["context_length"]
|
||||
else:
|
||||
context_window_size = self.get_model_context_window_size(model_name)
|
||||
|
||||
@@ -158,7 +158,7 @@ class ToolCreate(LettaBase):
|
||||
description = mcp_tool.description
|
||||
source_type = "python"
|
||||
tags = [f"{MCP_TOOL_TAG_NAME_PREFIX}:{mcp_server_name}"]
|
||||
wrapper_func_name, wrapper_function_str = generate_mcp_tool_wrapper(mcp_tool.name)
|
||||
_wrapper_func_name, wrapper_function_str = generate_mcp_tool_wrapper(mcp_tool.name)
|
||||
|
||||
return cls(
|
||||
description=description,
|
||||
|
||||
@@ -3,7 +3,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Uni
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.openai.chat_completion_response import (
|
||||
UsageStatistics,
|
||||
UsageStatisticsCompletionTokenDetails,
|
||||
UsageStatisticsPromptTokenDetails,
|
||||
)
|
||||
@@ -131,7 +133,7 @@ class LettaUsageStatistics(BaseModel):
|
||||
description="Estimate of tokens currently in the context window.",
|
||||
)
|
||||
|
||||
def to_usage(self, provider_type: Optional["ProviderType"] = None) -> "UsageStatistics": # noqa: F821 # noqa: F821
|
||||
def to_usage(self, provider_type: Optional["ProviderType"] = None) -> "UsageStatistics":
|
||||
"""Convert to UsageStatistics (OpenAI-compatible format).
|
||||
|
||||
Args:
|
||||
|
||||
@@ -112,7 +112,7 @@ class MarshmallowAgentSchema(BaseSchema):
|
||||
.all()
|
||||
)
|
||||
# combine system message with step messages
|
||||
msgs = [system_msg] + step_msgs if system_msg else step_msgs
|
||||
msgs = [system_msg, *step_msgs] if system_msg else step_msgs
|
||||
else:
|
||||
# no user messages, just return system message
|
||||
msgs = [system_msg] if system_msg else []
|
||||
@@ -147,7 +147,7 @@ class MarshmallowAgentSchema(BaseSchema):
|
||||
.all()
|
||||
)
|
||||
# combine system message with step messages
|
||||
msgs = [system_msg] + step_msgs if system_msg else step_msgs
|
||||
msgs = [system_msg, *step_msgs] if system_msg else step_msgs
|
||||
else:
|
||||
# no user messages, just return system message
|
||||
msgs = [system_msg] if system_msg else []
|
||||
@@ -231,7 +231,8 @@ class MarshmallowAgentSchema(BaseSchema):
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Agent
|
||||
exclude = BaseSchema.Meta.exclude + (
|
||||
exclude = (
|
||||
*BaseSchema.Meta.exclude,
|
||||
"project_id",
|
||||
"template_id",
|
||||
"base_template_id",
|
||||
|
||||
@@ -18,4 +18,4 @@ class SerializedAgentEnvironmentVariableSchema(BaseSchema):
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = AgentEnvironmentVariable
|
||||
exclude = BaseSchema.Meta.exclude + ("agent",)
|
||||
exclude = (*BaseSchema.Meta.exclude, "agent")
|
||||
|
||||
@@ -34,4 +34,4 @@ class SerializedBlockSchema(BaseSchema):
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Block
|
||||
exclude = BaseSchema.Meta.exclude + ("agents", "identities", "is_deleted", "groups", "organization")
|
||||
exclude = (*BaseSchema.Meta.exclude, "agents", "identities", "is_deleted", "groups", "organization")
|
||||
|
||||
@@ -37,4 +37,4 @@ class SerializedMessageSchema(BaseSchema):
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Message
|
||||
exclude = BaseSchema.Meta.exclude + ("step", "job_message", "otid", "is_deleted", "organization")
|
||||
exclude = (*BaseSchema.Meta.exclude, "step", "job_message", "otid", "is_deleted", "organization")
|
||||
|
||||
@@ -25,4 +25,4 @@ class SerializedAgentTagSchema(BaseSchema):
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = AgentsTags
|
||||
exclude = BaseSchema.Meta.exclude + ("agent",)
|
||||
exclude = (*BaseSchema.Meta.exclude, "agent")
|
||||
|
||||
@@ -34,4 +34,4 @@ class SerializedToolSchema(BaseSchema):
|
||||
|
||||
class Meta(BaseSchema.Meta):
|
||||
model = Tool
|
||||
exclude = BaseSchema.Meta.exclude + ("is_deleted", "organization")
|
||||
exclude = (*BaseSchema.Meta.exclude, "is_deleted", "organization")
|
||||
|
||||
@@ -891,7 +891,7 @@ def start_server(
|
||||
import uvloop
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if (os.getenv("LOCAL_HTTPS") == "true") or "--localhttps" in sys.argv:
|
||||
|
||||
@@ -22,7 +22,7 @@ class AuthRequest(BaseModel):
|
||||
|
||||
|
||||
def setup_auth_router(server: SyncServer, interface: QueuingInterface, password: str) -> APIRouter:
|
||||
@router.post("/auth", tags=["auth"], response_model=AuthResponse)
|
||||
@router.post("/auth", tags=["auth"])
|
||||
def authenticate_user(request: AuthRequest) -> AuthResponse:
|
||||
"""
|
||||
Authenticates the user and sends response with User related data.
|
||||
|
||||
@@ -1227,7 +1227,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# }
|
||||
try:
|
||||
func_args = parse_json(function_call.function.arguments)
|
||||
except:
|
||||
except Exception:
|
||||
func_args = function_call.function.arguments
|
||||
# processed_chunk = {
|
||||
# "function_call": f"{function_call.function.name}({func_args})",
|
||||
@@ -1262,7 +1262,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
else:
|
||||
try:
|
||||
func_args = parse_json(function_call.function.arguments)
|
||||
except:
|
||||
except Exception:
|
||||
logger.warning(f"Failed to parse function arguments: {function_call.function.arguments}")
|
||||
func_args = {}
|
||||
|
||||
|
||||
@@ -301,7 +301,7 @@ async def inject_memory_context(
|
||||
# Handle both string and list system prompts
|
||||
if isinstance(existing_system, list):
|
||||
# If it's a list, prepend our context as a text block
|
||||
modified_data["system"] = existing_system + [{"type": "text", "text": memory_context.rstrip()}]
|
||||
modified_data["system"] = [*existing_system, {"type": "text", "text": memory_context.rstrip()}]
|
||||
elif existing_system:
|
||||
# If it's a non-empty string, prepend our context
|
||||
modified_data["system"] = memory_context + existing_system
|
||||
@@ -451,8 +451,8 @@ async def backfill_agent_project_id(server, agent, actor, project_id: str):
|
||||
async def get_or_create_claude_code_agent(
|
||||
server,
|
||||
actor,
|
||||
project_id: str = None,
|
||||
agent_id: str = None,
|
||||
project_id: str | None = None,
|
||||
agent_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Get or create a special agent for Claude Code sessions.
|
||||
|
||||
@@ -276,7 +276,7 @@ async def create_background_stream_processor(
|
||||
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
|
||||
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
|
||||
stop_reason = maybe_stop_reason.get("stop_reason")
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Stream ended naturally - check if we got a proper terminal
|
||||
@@ -313,7 +313,7 @@ async def create_background_stream_processor(
|
||||
# Set a default stop_reason so run status can be mapped in finally
|
||||
stop_reason = StopReasonType.error.value
|
||||
|
||||
except RunCancelledException as e:
|
||||
except RunCancelledException:
|
||||
# Handle cancellation gracefully - don't write error chunk, cancellation event was already sent
|
||||
logger.info(f"Stream processing stopped due to cancellation for run {run_id}")
|
||||
# The cancellation event was already yielded by cancellation_aware_stream_wrapper
|
||||
|
||||
@@ -3,9 +3,9 @@ import json
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import orjson
|
||||
from fastapi import APIRouter, Body, Depends, File, Form, Header, HTTPException, Query, Request, UploadFile, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from orjson import orjson
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
@@ -879,7 +879,7 @@ async def detach_source(
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
|
||||
await server.block_manager.delete_block_async(block.id, actor)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return agent_state
|
||||
@@ -911,7 +911,7 @@ async def detach_folder_from_agent(
|
||||
source = await server.source_manager.get_source_by_id(source_id=folder_id, actor=actor)
|
||||
block = await server.agent_manager.get_block_with_label_async(agent_id=agent_state.id, block_label=source.name, actor=actor)
|
||||
await server.block_manager.delete_block_async(block.id, actor)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if is_1_0_sdk_version(headers):
|
||||
@@ -972,7 +972,7 @@ async def open_file_for_agent(
|
||||
visible_content = truncate_file_visible_content(visible_content, True, per_file_view_window_char_limit)
|
||||
|
||||
# Use enforce_max_open_files_and_open for efficient LRU handling
|
||||
closed_files, was_already_open, _ = await server.file_agent_manager.enforce_max_open_files_and_open(
|
||||
closed_files, _was_already_open, _ = await server.file_agent_manager.enforce_max_open_files_and_open(
|
||||
agent_id=agent_id,
|
||||
file_id=file_id,
|
||||
file_name=file_metadata.file_name,
|
||||
@@ -1840,7 +1840,7 @@ async def send_message_streaming(
|
||||
# use the streaming service for unified stream handling
|
||||
streaming_service = StreamingService(server)
|
||||
|
||||
run, result = await streaming_service.create_agent_stream(
|
||||
_run, result = await streaming_service.create_agent_stream(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
request=request,
|
||||
@@ -1921,7 +1921,6 @@ async def cancel_message(
|
||||
|
||||
@router.post(
|
||||
"/{agent_id}/generate",
|
||||
response_model=GenerateResponse,
|
||||
operation_id="generate_completion",
|
||||
responses={
|
||||
200: {"description": "Successful generation"},
|
||||
@@ -2177,7 +2176,7 @@ async def send_message_async(
|
||||
|
||||
try:
|
||||
is_message_input = request.messages[0].type == MessageCreateType.message
|
||||
except:
|
||||
except Exception:
|
||||
is_message_input = True
|
||||
use_lettuce = headers.experimental_params.message_async and is_message_input
|
||||
|
||||
|
||||
@@ -21,6 +21,8 @@ from letta.server.server import SyncServer
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
router = APIRouter(prefix="/anthropic", tags=["anthropic"])
|
||||
|
||||
ANTHROPIC_API_BASE = "https://api.anthropic.com"
|
||||
@@ -172,7 +174,7 @@ async def anthropic_messages_proxy(
|
||||
# This prevents race conditions where multiple requests persist the same message
|
||||
user_messages_to_persist = await check_for_duplicate_message(server, agent, actor, user_messages, PROXY_NAME)
|
||||
|
||||
asyncio.create_task(
|
||||
task = asyncio.create_task(
|
||||
persist_messages_background(
|
||||
server=server,
|
||||
agent=agent,
|
||||
@@ -183,6 +185,8 @@ async def anthropic_messages_proxy(
|
||||
proxy_name=PROXY_NAME,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_response(),
|
||||
@@ -226,7 +230,7 @@ async def anthropic_messages_proxy(
|
||||
# Check for duplicate user messages before creating background task
|
||||
user_messages_to_persist = await check_for_duplicate_message(server, agent, actor, user_messages, PROXY_NAME)
|
||||
|
||||
asyncio.create_task(
|
||||
task = asyncio.create_task(
|
||||
persist_messages_background(
|
||||
server=server,
|
||||
agent=agent,
|
||||
@@ -237,6 +241,8 @@ async def anthropic_messages_proxy(
|
||||
proxy_name=PROXY_NAME,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
except Exception as e:
|
||||
logger.warning(f"[{PROXY_NAME}] Failed to extract assistant response for logging: {e}")
|
||||
|
||||
|
||||
@@ -331,7 +331,7 @@ async def upload_file_to_folder(
|
||||
return response
|
||||
elif duplicate_handling == DuplicateFileHandling.REPLACE:
|
||||
# delete the file
|
||||
deleted_file = await server.file_manager.delete_file(file_id=existing_file.id, actor=actor)
|
||||
await server.file_manager.delete_file(file_id=existing_file.id, actor=actor)
|
||||
unique_filename = original_filename
|
||||
|
||||
if not unique_filename:
|
||||
|
||||
@@ -65,7 +65,8 @@ from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_le
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Routes are proxied to dulwich running on a separate port.
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
router = APIRouter(prefix="/git", tags=["git"], include_in_schema=False)
|
||||
|
||||
# Global storage for the server instance (set during app startup)
|
||||
@@ -718,8 +719,9 @@ async def proxy_git_http(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
# Authorization check: ensure the actor can access this agent.
|
||||
await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=[])
|
||||
# Fire-and-forget; do not block git client response.
|
||||
asyncio.create_task(_sync_after_push(actor.id, agent_id))
|
||||
task = asyncio.create_task(_sync_after_push(actor.id, agent_id))
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
except Exception:
|
||||
logger.exception("Failed to trigger post-push sync (agent_id=%s)", agent_id)
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ async def list_identities(
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
|
||||
identities, next_cursor, has_more = await server.identity_manager.list_identities_async(
|
||||
identities, _next_cursor, _has_more = await server.identity_manager.list_identities_async(
|
||||
name=name,
|
||||
project_id=project_id,
|
||||
identifier_key=identifier_key,
|
||||
|
||||
@@ -309,7 +309,7 @@ async def upload_file_to_source(
|
||||
return response
|
||||
elif duplicate_handling == DuplicateFileHandling.REPLACE:
|
||||
# delete the file
|
||||
deleted_file = await server.file_manager.delete_file(file_id=existing_file.id, actor=actor)
|
||||
await server.file_manager.delete_file(file_id=existing_file.id, actor=actor)
|
||||
unique_filename = original_filename
|
||||
|
||||
if not unique_filename:
|
||||
|
||||
@@ -106,7 +106,7 @@ async def retrieve_trace_for_step(
|
||||
provider_trace = await server.telemetry_manager.get_provider_trace_by_step_id_async(
|
||||
step_id=step_id, actor=await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return provider_trace
|
||||
|
||||
@@ -27,7 +27,7 @@ async def retrieve_provider_trace(
|
||||
provider_trace = await server.telemetry_manager.get_provider_trace_by_step_id_async(
|
||||
step_id=step_id, actor=await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return provider_trace
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user