chore: bump v0.7.17 (#2638)
Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com> Co-authored-by: Kevin Lin <klin5061@gmail.com> Co-authored-by: Sarah Wooders <sarahwooders@gmail.com> Co-authored-by: jnjpng <jin@letta.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.7.16"
|
||||
__version__ = "0.7.17"
|
||||
|
||||
# import clients
|
||||
from letta.client.client import LocalClient, RESTClient, create_client
|
||||
|
||||
@@ -8,10 +8,11 @@ from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages
|
||||
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.tool_execution_helper import enable_strict_mode
|
||||
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
|
||||
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
@@ -61,12 +62,8 @@ class LettaAgent(BaseAgent):
|
||||
self.last_function_response = None
|
||||
|
||||
# Cached archival memory/message size
|
||||
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
|
||||
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id)
|
||||
|
||||
# Cached archival memory/message size
|
||||
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
|
||||
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id)
|
||||
self.num_messages = 0
|
||||
self.num_archival_memories = 0
|
||||
|
||||
@trace_method
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
|
||||
@@ -81,7 +78,7 @@ class LettaAgent(BaseAgent):
|
||||
async def _step(
|
||||
self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
|
||||
) -> Tuple[List[Message], List[Message], CompletionUsage]:
|
||||
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
@@ -129,14 +126,14 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
@trace_method
|
||||
async def step_stream(
|
||||
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
|
||||
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True, stream_tokens: bool = False
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Main streaming loop that yields partial tokens.
|
||||
Whenever we detect a tool call, we yield from _handle_ai_response as well.
|
||||
"""
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
|
||||
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
@@ -157,9 +154,16 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
# TODO: THIS IS INCREDIBLY UGLY
|
||||
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
||||
interface = AnthropicStreamingInterface(
|
||||
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs
|
||||
)
|
||||
if agent_state.llm_config.model_endpoint_type == "anthropic":
|
||||
interface = AnthropicStreamingInterface(
|
||||
use_assistant_message=use_assistant_message,
|
||||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||||
)
|
||||
elif agent_state.llm_config.model_endpoint_type == "openai":
|
||||
interface = OpenAIStreamingInterface(
|
||||
use_assistant_message=use_assistant_message,
|
||||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||||
)
|
||||
async for chunk in interface.process(stream):
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
@@ -197,8 +201,8 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
# TODO: This may be out of sync, if in between steps users add files
|
||||
# NOTE (cliandy): temporary for now for particular use cases.
|
||||
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
self.num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
self.num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
|
||||
# TODO: Also yield out a letta usage stats SSE
|
||||
yield f"data: {usage.model_dump_json()}\n\n"
|
||||
@@ -215,6 +219,10 @@ class LettaAgent(BaseAgent):
|
||||
stream: bool,
|
||||
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
|
||||
if settings.experimental_enable_async_db_engine:
|
||||
self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
|
||||
self.num_archival_memories = self.num_archival_memories or (
|
||||
await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
)
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
)
|
||||
|
||||
303
letta/interfaces/openai_streaming_interface.py
Normal file
303
letta/interfaces/openai_streaming_interface.py
Normal file
@@ -0,0 +1,303 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
from letta.streaming_utils import JSONInnerThoughtsExtractor
|
||||
|
||||
|
||||
class OpenAIStreamingInterface:
    """
    Encapsulates the logic for streaming responses from OpenAI.

    This class handles parsing of partial tokens, pre-execution messages,
    and detection of tool call events. One instance is meant to consume a
    single stream: it accumulates the tool call name/id/arguments and the
    token usage as attributes so the agent loop can read them afterwards.
    """

    def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
        """
        Args:
            use_assistant_message: When True, calls to the assistant-message tool
                (``DEFAULT_MESSAGE_TOOL``) are emitted as ``AssistantMessage`` chunks
                containing only the message text instead of raw ``ToolCallMessage`` chunks.
            put_inner_thoughts_in_kwarg: Recorded on the instance; the arguments reader
                below does not consume it yet (see TODO).
        """
        self.use_assistant_message = use_assistant_message
        self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
        self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL
        self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG

        self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
        self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=True)  # TODO: pass in kwarg
        self.function_name_buffer = None
        self.function_args_buffer = None
        self.function_id_buffer = None
        self.last_flushed_function_name = None

        # ID of the most recent tool call that was surfaced as an assistant message.
        # Initialised here so it can be read before any chunk has been processed.
        self.prev_assistant_message_id: Optional[str] = None

        # Buffer to hold function arguments until inner thoughts are complete
        self.current_function_arguments = ""
        self.current_json_parse_result = {}

        # Premake IDs for database writes
        self.letta_assistant_message_id = Message.generate_id()
        self.letta_tool_message_id = Message.generate_id()

        # token counters
        self.input_tokens = 0
        self.output_tokens = 0

        self.content_buffer: List[str] = []
        self.tool_call_name: Optional[str] = None
        self.tool_call_id: Optional[str] = None
        # Inner-thought fragments, in the order they were streamed
        self.reasoning_messages: List[str] = []

    def get_reasoning_content(self) -> List[TextContent]:
        """Return all streamed inner-thought fragments joined into a single TextContent."""
        content = "".join(self.reasoning_messages)
        return [TextContent(text=content)]

    def get_tool_call_object(self) -> ToolCall:
        """Useful for agent loop: build a ToolCall from the accumulated name and raw arguments."""
        return ToolCall(
            id=self.letta_tool_message_id,
            function=FunctionCall(arguments=self.current_function_arguments, name=self.last_flushed_function_name),
        )

    async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[LettaMessage, None]:
        """
        Iterates over the OpenAI stream, yielding Letta messages.

        Yields ReasoningMessage, ToolCallMessage and AssistantMessage chunks as the
        tool call arguments stream in. It also accumulates token usage and the tool
        call name/arguments on the instance.

        Only the first choice and the first tool call of each chunk are inspected.
        """
        async with stream:
            prev_message_type = None
            message_index = 0
            async for chunk in stream:
                # track usage
                # NOTE: prompt_tokens / completion_tokens are integer counts, so they
                # are added directly (calling len() on them raises TypeError).
                if chunk.usage:
                    self.input_tokens += chunk.usage.prompt_tokens
                    self.output_tokens += chunk.usage.completion_tokens

                if chunk.choices:
                    choice = chunk.choices[0]
                    message_delta = choice.delta

                    if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
                        tool_call = message_delta.tool_calls[0]

                        if tool_call.function.name:
                            # If we're waiting for the first key, then we should hold back the name
                            # ie add it to a buffer instead of returning it as a chunk
                            if self.function_name_buffer is None:
                                self.function_name_buffer = tool_call.function.name
                            else:
                                self.function_name_buffer += tool_call.function.name

                        if tool_call.id:
                            # Buffer until next time
                            if self.function_id_buffer is None:
                                self.function_id_buffer = tool_call.id
                            else:
                                self.function_id_buffer += tool_call.id

                        if tool_call.function.arguments:
                            self.current_function_arguments += tool_call.function.arguments
                            updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(
                                tool_call.function.arguments
                            )

                            # If we have inner thoughts, we should output them as a chunk
                            if updates_inner_thoughts:
                                if prev_message_type and prev_message_type != "reasoning_message":
                                    message_index += 1
                                self.reasoning_messages.append(updates_inner_thoughts)
                                reasoning_message = ReasoningMessage(
                                    id=self.letta_tool_message_id,
                                    date=datetime.now(timezone.utc),
                                    reasoning=updates_inner_thoughts,
                                    # name=name,
                                    otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
                                )
                                prev_message_type = reasoning_message.message_type
                                yield reasoning_message

                                # Additionally inner thoughts may stream back with a chunk of main JSON
                                # In that case, since we can only return a chunk at a time, we should buffer it
                                if updates_main_json:
                                    if self.function_args_buffer is None:
                                        self.function_args_buffer = updates_main_json
                                    else:
                                        self.function_args_buffer += updates_main_json

                            # If we have main_json, we should output a ToolCallMessage
                            elif updates_main_json:

                                # If there's something in the function_name buffer, we should release it first
                                # NOTE: we could output it as part of a chunk that has both name and args,
                                #       however the frontend may expect name first, then args, so to be
                                #       safe we'll output name first in a separate chunk
                                if self.function_name_buffer:

                                    # use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
                                    if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:

                                        # Store the ID of the tool call so allow skipping the corresponding response
                                        if self.function_id_buffer:
                                            self.prev_assistant_message_id = self.function_id_buffer

                                    else:
                                        if prev_message_type and prev_message_type != "tool_call_message":
                                            message_index += 1
                                        self.tool_call_name = str(self.function_name_buffer)
                                        tool_call_msg = ToolCallMessage(
                                            id=self.letta_tool_message_id,
                                            date=datetime.now(timezone.utc),
                                            tool_call=ToolCallDelta(
                                                name=self.function_name_buffer,
                                                arguments=None,
                                                tool_call_id=self.function_id_buffer,
                                            ),
                                            otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
                                        )
                                        prev_message_type = tool_call_msg.message_type
                                        yield tool_call_msg

                                    # Record what the last function name we flushed was
                                    self.last_flushed_function_name = self.function_name_buffer
                                    # Clear the buffer
                                    self.function_name_buffer = None
                                    self.function_id_buffer = None
                                    # Since we're clearing the name buffer, we should store
                                    # any updates to the arguments inside a separate buffer

                                    # Add any main_json updates to the arguments buffer
                                    if self.function_args_buffer is None:
                                        self.function_args_buffer = updates_main_json
                                    else:
                                        self.function_args_buffer += updates_main_json

                                # If there was nothing in the name buffer, we can proceed to
                                # output the arguments chunk as a ToolCallMessage
                                else:

                                    # use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
                                    if self.use_assistant_message and (
                                        self.last_flushed_function_name is not None
                                        and self.last_flushed_function_name == self.assistant_message_tool_name
                                    ):
                                        # do an additional parse on the updates_main_json
                                        if self.function_args_buffer:
                                            updates_main_json = self.function_args_buffer + updates_main_json
                                            self.function_args_buffer = None

                                            # Pretty gross hardcoding that assumes that if we're toggling into the keywords, we have the full prefix
                                            match_str = '{"' + self.assistant_message_tool_kwarg + '":"'
                                            if updates_main_json == match_str:
                                                updates_main_json = None

                                        else:
                                            # Some hardcoding to strip off the trailing "}"
                                            if updates_main_json in ["}", '"}']:
                                                updates_main_json = None
                                            if updates_main_json and len(updates_main_json) > 0 and updates_main_json[-1:] == '"':
                                                updates_main_json = updates_main_json[:-1]

                                        if not updates_main_json:
                                            # early exit to turn into content mode
                                            continue

                                        # There may be a buffer from a previous chunk, for example
                                        # if the previous chunk had arguments but we needed to flush name
                                        if self.function_args_buffer:
                                            # In this case, we should release the buffer + new data at once
                                            combined_chunk = self.function_args_buffer + updates_main_json

                                            if prev_message_type and prev_message_type != "assistant_message":
                                                message_index += 1
                                            assistant_message = AssistantMessage(
                                                id=self.letta_assistant_message_id,
                                                date=datetime.now(timezone.utc),
                                                content=combined_chunk,
                                                otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
                                            )
                                            prev_message_type = assistant_message.message_type
                                            yield assistant_message
                                            # Store the ID of the tool call so allow skipping the corresponding response
                                            if self.function_id_buffer:
                                                self.prev_assistant_message_id = self.function_id_buffer
                                            # clear buffer
                                            self.function_args_buffer = None
                                            self.function_id_buffer = None

                                        else:
                                            # If there's no buffer to clear, just output a new chunk with new data
                                            # TODO: THIS IS HORRIBLE
                                            # TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
                                            # TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
                                            parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)

                                            if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
                                                self.assistant_message_tool_kwarg
                                            ) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg):
                                                new_content = parsed_args.get(self.assistant_message_tool_kwarg)
                                                prev_content = self.current_json_parse_result.get(self.assistant_message_tool_kwarg, "")
                                                # TODO: Assumes consistent state and that prev_content is subset of new_content
                                                diff = new_content.replace(prev_content, "", 1)
                                                self.current_json_parse_result = parsed_args
                                                if prev_message_type and prev_message_type != "assistant_message":
                                                    message_index += 1
                                                assistant_message = AssistantMessage(
                                                    id=self.letta_assistant_message_id,
                                                    date=datetime.now(timezone.utc),
                                                    content=diff,
                                                    # name=name,
                                                    otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
                                                )
                                                prev_message_type = assistant_message.message_type
                                                yield assistant_message

                                            # Store the ID of the tool call so allow skipping the corresponding response
                                            if self.function_id_buffer:
                                                self.prev_assistant_message_id = self.function_id_buffer
                                            # clear buffers
                                            self.function_id_buffer = None
                                    else:

                                        # There may be a buffer from a previous chunk, for example
                                        # if the previous chunk had arguments but we needed to flush name
                                        if self.function_args_buffer:
                                            # In this case, we should release the buffer + new data at once
                                            combined_chunk = self.function_args_buffer + updates_main_json
                                            if prev_message_type and prev_message_type != "tool_call_message":
                                                message_index += 1
                                            tool_call_msg = ToolCallMessage(
                                                id=self.letta_tool_message_id,
                                                date=datetime.now(timezone.utc),
                                                tool_call=ToolCallDelta(
                                                    name=None,
                                                    arguments=combined_chunk,
                                                    tool_call_id=self.function_id_buffer,
                                                ),
                                                # name=name,
                                                otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
                                            )
                                            prev_message_type = tool_call_msg.message_type
                                            yield tool_call_msg
                                            # clear buffer
                                            self.function_args_buffer = None
                                            self.function_id_buffer = None
                                        else:
                                            # If there's no buffer to clear, just output a new chunk with new data
                                            if prev_message_type and prev_message_type != "tool_call_message":
                                                message_index += 1
                                            tool_call_msg = ToolCallMessage(
                                                id=self.letta_tool_message_id,
                                                date=datetime.now(timezone.utc),
                                                tool_call=ToolCallDelta(
                                                    name=None,
                                                    arguments=updates_main_json,
                                                    tool_call_id=self.function_id_buffer,
                                                ),
                                                # name=name,
                                                otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
                                            )
                                            prev_message_type = tool_call_msg.message_type
                                            yield tool_call_msg
                                            self.function_id_buffer = None
|
||||
@@ -745,6 +745,17 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
self.is_deleted = True
|
||||
return self.update(db_session)
|
||||
|
||||
@handle_db_timeout
async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase":
    """
    Soft delete this record asynchronously.

    The row is kept; it is flagged with ``is_deleted = True`` and persisted
    through ``update_async``.

    Args:
        db_session: Async session handed to ``update_async``.
        actor: Optional user performing the deletion; when given, the
            created/updated-by fields are stamped with the actor's ID.

    Returns:
        Whatever ``update_async`` returns for this record.
    """
    logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)")

    # Stamp the audit fields before the flag flips so both land in the same update.
    if actor:
        self._set_created_and_updated_by_fields(actor.id)
    self.is_deleted = True

    updated_record = await self.update_async(db_session)
    return updated_record
|
||||
|
||||
@handle_db_timeout
|
||||
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
|
||||
"""Permanently removes the record from the database."""
|
||||
@@ -761,6 +772,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
else:
|
||||
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
|
||||
|
||||
@handle_db_timeout
async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None:
    """
    Permanently removes the record from the database asynchronously.

    Args:
        db_session: Async session used as a context manager for the delete.
        actor: Optional user performing the deletion (only used for logging here).

    Raises:
        ValueError: If the delete or commit fails. The session is rolled back and
            the original exception is attached as ``__cause__``.
    """
    # Capture identifiers up front: after a commit or rollback the instance's
    # attributes may be expired, and reading them could trigger a lazy load.
    class_name = self.__class__.__name__
    record_id = self.id
    logger.debug(f"Hard deleting {class_name} with ID: {record_id} with actor={actor} (async)")

    async with db_session as session:
        try:
            await session.delete(self)
            await session.commit()
        except Exception as e:
            await session.rollback()
            logger.exception(f"Failed to hard delete {class_name} with ID {record_id}")
            # Chain the cause so the underlying database error is not lost.
            raise ValueError(f"Failed to hard delete {class_name} with ID {record_id}: {e}") from e
        else:
            logger.debug(f"{class_name} with ID {record_id} successfully hard deleted")
|
||||
|
||||
@handle_db_timeout
|
||||
def update(self, db_session: Session, actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
|
||||
logger.debug(...)
|
||||
@@ -793,6 +818,39 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
await db_session.refresh(self)
|
||||
return self
|
||||
|
||||
@classmethod
def _size_preprocess(
    cls,
    *,
    db_session: "Session",
    actor: Optional["User"] = None,
    access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
    access_type: AccessType = AccessType.ORGANIZATION,
    **kwargs,
):
    """
    Build the ``SELECT count(*)`` query used by the size helpers.

    Args:
        db_session: Accepted for signature parity with the callers; not used
            while building the query.
        actor: When given, the access predicate is applied for this user.
        access: Access levels forwarded to ``apply_access_predicate``.
        access_type: Access scope forwarded to ``apply_access_predicate``.
        **kwargs: Column filters. Falsy values are skipped; list/tuple/set
            values become ``IN`` filters, anything else an equality filter.

    Returns:
        The count query, not yet executed.

    Raises:
        AttributeError: If a filter key is not an attribute of the class.
    """
    logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
    query = select(func.count()).select_from(cls)

    if actor:
        query = cls.apply_access_predicate(query, actor, access, access_type)

    # Apply filtering logic based on kwargs
    for key, value in kwargs.items():
        if not value:
            continue
        column = getattr(cls, key, None)
        # Compare against None explicitly: the truthiness of a SQL expression
        # object is not a reliable "attribute exists" test.
        if column is None:
            raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
        if isinstance(value, (list, tuple, set)):  # Check for iterables
            query = query.where(column.in_(value))
        else:  # Single value for equality filtering
            query = query.where(column == value)

    # Handle soft deletes if the class has the 'is_deleted' attribute
    if hasattr(cls, "is_deleted"):
        # `== False` is intentional: it builds a SQL comparison, not a Python bool.
        query = query.where(cls.is_deleted == False)

    return query
|
||||
|
||||
@classmethod
|
||||
@handle_db_timeout
|
||||
def size(
|
||||
@@ -817,28 +875,8 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
Raises:
|
||||
DBAPIError: If a database error occurs
|
||||
"""
|
||||
logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
|
||||
|
||||
with db_session as session:
|
||||
query = select(func.count()).select_from(cls)
|
||||
|
||||
if actor:
|
||||
query = cls.apply_access_predicate(query, actor, access, access_type)
|
||||
|
||||
# Apply filtering logic based on kwargs
|
||||
for key, value in kwargs.items():
|
||||
if value:
|
||||
column = getattr(cls, key, None)
|
||||
if not column:
|
||||
raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
|
||||
if isinstance(value, (list, tuple, set)): # Check for iterables
|
||||
query = query.where(column.in_(value))
|
||||
else: # Single value for equality filtering
|
||||
query = query.where(column == value)
|
||||
|
||||
# Handle soft deletes if the class has the 'is_deleted' attribute
|
||||
if hasattr(cls, "is_deleted"):
|
||||
query = query.where(cls.is_deleted == False)
|
||||
query = cls._size_preprocess(db_session=session, actor=actor, access=access, access_type=access_type, **kwargs)
|
||||
|
||||
try:
|
||||
count = session.execute(query).scalar()
|
||||
@@ -847,6 +885,37 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
logger.exception(f"Failed to calculate size for {cls.__name__}")
|
||||
raise e
|
||||
|
||||
@classmethod
@handle_db_timeout
async def size_async(
    cls,
    *,
    db_session: "AsyncSession",
    actor: Optional["User"] = None,
    access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
    access_type: AccessType = AccessType.ORGANIZATION,
    **kwargs,
) -> int:
    """
    Get the count of rows that match the provided filters.

    Args:
        db_session: SQLAlchemy async session
        actor: Optional user used to apply the access predicate
        access: Access levels forwarded to the query builder
        access_type: Access scope forwarded to the query builder
        **kwargs: Filters to apply to the query (e.g., column_name=value)

    Returns:
        int: The count of rows that match the filters (0 when the query yields no value)

    Raises:
        DBAPIError: If a database error occurs
    """
    async with db_session as session:
        query = cls._size_preprocess(db_session=session, actor=actor, access=access, access_type=access_type, **kwargs)

        try:
            # The await must complete before reading the scalar: `execute` returns
            # a coroutine, which has no `.scalar()` of its own.
            result = await session.execute(query)
            count = result.scalar()
            return count or 0
        except DBAPIError as e:
            logger.exception(f"Failed to calculate size for {cls.__name__}")
            raise e
|
||||
|
||||
@classmethod
|
||||
def apply_access_predicate(
|
||||
cls,
|
||||
|
||||
@@ -83,7 +83,7 @@ async def list_agents(
|
||||
"""
|
||||
|
||||
# Retrieve the actor (user) details
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# Call list_agents directly without unnecessary dict handling
|
||||
return await server.agent_manager.list_agents_async(
|
||||
@@ -163,7 +163,7 @@ async def import_agent_serialized(
|
||||
"""
|
||||
Import a serialized agent file and recreate the agent in the system.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
serialized_data = await file.read()
|
||||
@@ -233,7 +233,7 @@ async def create_agent(
|
||||
Create a new agent with the specified configuration.
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.create_agent_async(agent, actor=actor)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
@@ -248,7 +248,7 @@ async def modify_agent(
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Update an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.update_agent_async(agent_id=agent_id, request=update_agent, actor=actor)
|
||||
|
||||
|
||||
@@ -333,7 +333,7 @@ def detach_source(
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=AgentState, operation_id="retrieve_agent")
|
||||
def retrieve_agent(
|
||||
async def retrieve_agent(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
@@ -344,7 +344,7 @@ def retrieve_agent(
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -414,7 +414,7 @@ def retrieve_block(
|
||||
|
||||
|
||||
@router.get("/{agent_id}/core-memory/blocks", response_model=List[Block], operation_id="list_core_memory_blocks")
|
||||
def list_blocks(
|
||||
async def list_blocks(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
@@ -424,7 +424,7 @@ def list_blocks(
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
|
||||
return agent.memory.blocks
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -628,9 +628,9 @@ async def send_message(
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
|
||||
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
|
||||
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
|
||||
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
|
||||
@@ -686,13 +686,13 @@ async def send_message_streaming(
|
||||
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
||||
"""
|
||||
request_start_timestamp_ns = get_utc_timestamp_ns()
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
|
||||
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
|
||||
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
|
||||
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
|
||||
model_compatible = agent.llm_config.model_endpoint_type == "anthropic"
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
|
||||
|
||||
if agent_eligible and feature_enabled and model_compatible and request.stream_tokens:
|
||||
experimental_agent = LettaAgent(
|
||||
@@ -705,7 +705,9 @@ async def send_message_streaming(
|
||||
)
|
||||
|
||||
result = StreamingResponse(
|
||||
experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message),
|
||||
experimental_agent.step_stream(
|
||||
request.messages, max_steps=10, use_assistant_message=request.use_assistant_message, stream_tokens=request.stream_tokens
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
@@ -784,7 +786,7 @@ async def send_message_async(
|
||||
Asynchronously process a user message and return a run object.
|
||||
The actual processing happens in the background, and the status can be checked using the run ID.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# Create a new job
|
||||
run = Run(
|
||||
@@ -838,6 +840,6 @@ async def list_agent_groups(
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Lists the groups for an agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
print("in list agents with manager_type", manager_type)
|
||||
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)
|
||||
|
||||
@@ -26,7 +26,7 @@ async def list_blocks(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.block_manager.get_blocks_async(
|
||||
actor=actor,
|
||||
label=label,
|
||||
|
||||
@@ -135,7 +135,7 @@ async def send_group_message(
|
||||
Process a user message and return the group's response.
|
||||
This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
result = await server.send_group_message_to_agent(
|
||||
group_id=group_id,
|
||||
actor=actor,
|
||||
@@ -174,7 +174,7 @@ async def send_group_message_streaming(
|
||||
This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern.
|
||||
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
result = await server.send_group_message_to_agent(
|
||||
group_id=group_id,
|
||||
actor=actor,
|
||||
|
||||
@@ -52,7 +52,7 @@ async def create_messages_batch(
|
||||
detail=f"Server misconfiguration: LETTA_ENABLE_BATCH_JOB_POLLING is set to False.",
|
||||
)
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
batch_job = BatchJob(
|
||||
user_id=actor.id,
|
||||
status=JobStatus.running,
|
||||
@@ -100,7 +100,7 @@ async def retrieve_batch_run(
|
||||
"""
|
||||
Get the status of a batch run.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor)
|
||||
@@ -118,7 +118,7 @@ async def list_batch_runs(
|
||||
List all batch runs.
|
||||
"""
|
||||
# TODO: filter
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
jobs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.BATCH)
|
||||
return [BatchJob.from_job(job) for job in jobs]
|
||||
@@ -150,7 +150,7 @@ async def list_batch_messages(
|
||||
- For subsequent pages, use the ID of the last message from the previous response as the cursor
|
||||
- Results will include messages before/after the cursor based on sort_descending
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# First, verify the batch job exists and the user has access to it
|
||||
try:
|
||||
@@ -177,7 +177,7 @@ async def cancel_batch_run(
|
||||
"""
|
||||
Cancel a batch run.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor)
|
||||
|
||||
@@ -115,7 +115,7 @@ async def list_run_messages(
|
||||
if order not in ["asc", "desc"]:
|
||||
raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
messages = server.job_manager.get_run_messages(
|
||||
@@ -182,7 +182,7 @@ async def list_run_steps(
|
||||
if order not in ["asc", "desc"]:
|
||||
raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
steps = server.job_manager.get_job_steps(
|
||||
|
||||
@@ -87,7 +87,7 @@ async def list_tools(
|
||||
Get a list of all tools available to agents belonging to the org of the user
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
if name is not None:
|
||||
tool = await server.tool_manager.get_tool_by_name_async(tool_name=name, actor=actor)
|
||||
return [tool] if tool else []
|
||||
|
||||
@@ -14,7 +14,7 @@ router = APIRouter(prefix="/users", tags=["users", "admin"])
|
||||
|
||||
|
||||
@router.get("/", tags=["admin"], response_model=List[User], operation_id="list_users")
|
||||
def list_users(
|
||||
async def list_users(
|
||||
after: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@@ -23,7 +23,7 @@ def list_users(
|
||||
Get a list of all users in the database
|
||||
"""
|
||||
try:
|
||||
users = server.user_manager.list_users(after=after, limit=limit)
|
||||
users = await server.user_manager.list_actors_async(after=after, limit=limit)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -32,7 +32,7 @@ def list_users(
|
||||
|
||||
|
||||
@router.post("/", tags=["admin"], response_model=User, operation_id="create_user")
|
||||
def create_user(
|
||||
async def create_user(
|
||||
request: UserCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
@@ -40,33 +40,33 @@ def create_user(
|
||||
Create a new user in the database
|
||||
"""
|
||||
user = User(**request.model_dump())
|
||||
user = server.user_manager.create_user(user)
|
||||
user = await server.user_manager.create_actor_async(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/", tags=["admin"], response_model=User, operation_id="update_user")
|
||||
def update_user(
|
||||
async def update_user(
|
||||
user: UserUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Update a user in the database
|
||||
"""
|
||||
user = server.user_manager.update_user(user)
|
||||
user = await server.user_manager.update_actor_async(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/", tags=["admin"], response_model=User, operation_id="delete_user")
|
||||
def delete_user(
|
||||
async def delete_user(
|
||||
user_id: str = Query(..., description="The user_id key to be deleted."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
# TODO make a soft deletion, instead of a hard deletion
|
||||
try:
|
||||
user = server.user_manager.get_user_by_id(user_id=user_id)
|
||||
user = await server.user_manager.get_actor_by_id_async(actor_id=user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail=f"User does not exist")
|
||||
server.user_manager.delete_user_by_id(user_id=user_id)
|
||||
await server.user_manager.delete_actor_by_id_async(user_id=user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -36,7 +36,7 @@ async def create_voice_chat_completions(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id)
|
||||
|
||||
# Create OpenAI async client
|
||||
client = openai.AsyncClient(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
@@ -905,12 +906,7 @@ class AgentManager:
|
||||
|
||||
result = await session.execute(query)
|
||||
agents = result.scalars().all()
|
||||
pydantic_agents = []
|
||||
for agent in agents:
|
||||
pydantic_agent = await agent.to_pydantic_async(include_relationships=include_relationships)
|
||||
pydantic_agents.append(pydantic_agent)
|
||||
|
||||
return pydantic_agents
|
||||
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
|
||||
|
||||
@enforce_types
|
||||
def list_agents_matching_tags(
|
||||
@@ -1195,8 +1191,8 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
return await self.message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=actor)
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
|
||||
|
||||
@@ -286,6 +286,21 @@ class MessageManager:
|
||||
with db_registry.session() as session:
|
||||
return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id)
|
||||
|
||||
@enforce_types
async def size_async(
    self,
    actor: PydanticUser,
    role: Optional[MessageRole] = None,
    agent_id: Optional[str] = None,
) -> int:
    """Count messages asynchronously, optionally narrowed by role and agent.

    Args:
        actor: The user requesting the count
        role: Optional message role to filter on
        agent_id: Optional agent ID to filter on

    Returns:
        The count reported by ``MessageModel.size_async`` for these filters.
    """
    async with db_registry.async_session() as db_session:
        message_count = await MessageModel.size_async(db_session=db_session, actor=actor, role=role, agent_id=agent_id)
    return message_count
|
||||
|
||||
@enforce_types
|
||||
def list_user_messages_for_agent(
|
||||
self,
|
||||
|
||||
@@ -216,6 +216,20 @@ class PassageManager:
|
||||
with db_registry.session() as session:
|
||||
return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
|
||||
|
||||
@enforce_types
async def size_async(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
) -> int:
    """Count agent passages asynchronously, optionally narrowed to one agent.

    Args:
        actor: The user requesting the count
        agent_id: Optional agent ID whose passages should be counted

    Returns:
        The count reported by ``AgentPassage.size_async`` for these filters.
    """
    async with db_registry.async_session() as db_session:
        passage_count = await AgentPassage.size_async(db_session=db_session, actor=actor, agent_id=agent_id)
    return passage_count
|
||||
|
||||
def estimate_embeddings_size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
|
||||
@@ -44,6 +44,14 @@ class UserManager:
|
||||
new_user.create(session)
|
||||
return new_user.to_pydantic()
|
||||
|
||||
@enforce_types
async def create_actor_async(self, pydantic_user: PydanticUser) -> PydanticUser:
    """Persist a new user row built from ``pydantic_user`` (async version).

    Args:
        pydantic_user: The user to store.

    Returns:
        The stored user converted back via ``to_pydantic()``.
    """
    # Translate the pydantic schema into ORM constructor arguments
    orm_fields = pydantic_user.model_dump(to_orm=True)
    async with db_registry.async_session() as db_session:
        user_row = UserModel(**orm_fields)
        await user_row.create_async(db_session)
        return user_row.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_user(self, user_update: UserUpdate) -> PydanticUser:
|
||||
"""Update user details."""
|
||||
@@ -60,6 +68,22 @@ class UserManager:
|
||||
existing_user.update(session)
|
||||
return existing_user.to_pydantic()
|
||||
|
||||
@enforce_types
async def update_actor_async(self, user_update: UserUpdate) -> PydanticUser:
    """Apply the fields set on ``user_update`` to the stored user (async version).

    Args:
        user_update: Update payload; ``user_update.id`` selects the user to modify.

    Returns:
        The updated user converted via ``to_pydantic()``.
    """
    async with db_registry.async_session() as db_session:
        # Load the row that is being modified
        target = await UserModel.read_async(db_session=db_session, identifier=user_update.id)

        # Unset and None-valued fields are left out of the dump, so they leave the row untouched
        changes = user_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
        for field_name in changes:
            setattr(target, field_name, changes[field_name])

        # Persist the modified row
        await target.update_async(db_session)
        return target.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_user_by_id(self, user_id: str):
|
||||
"""Delete a user and their associated records (agents, sources, mappings)."""
|
||||
@@ -70,6 +94,14 @@ class UserManager:
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
async def delete_actor_by_id_async(self, user_id: str):
    """Hard-delete the user identified by ``user_id`` (async version).

    The row is loaded with ``UserModel.read_async`` and removed with ``hard_delete_async``.
    NOTE(review): whether associated records (agents, sources, mappings) are removed as well
    depends on cascade rules that are not visible here -- confirm against the ORM model.

    Args:
        user_id: ID of the user to delete.
    """
    async with db_registry.async_session() as db_session:
        # Look the row up first, then remove it from the user table
        doomed_user = await UserModel.read_async(db_session=db_session, identifier=user_id)
        await doomed_user.hard_delete_async(db_session)
|
||||
|
||||
@enforce_types
|
||||
def get_user_by_id(self, user_id: str) -> PydanticUser:
|
||||
"""Fetch a user by ID."""
|
||||
@@ -77,6 +109,13 @@ class UserManager:
|
||||
user = UserModel.read(db_session=session, identifier=user_id)
|
||||
return user.to_pydantic()
|
||||
|
||||
@enforce_types
async def get_actor_by_id_async(self, actor_id: str) -> PydanticUser:
    """Look up a single user by ID (async version).

    Args:
        actor_id: ID of the user to fetch.

    Returns:
        The matching user converted via ``to_pydantic()``.
    """
    async with db_registry.async_session() as db_session:
        user_row = await UserModel.read_async(db_session=db_session, identifier=actor_id)
        return user_row.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_default_user(self) -> PydanticUser:
|
||||
"""Fetch the default user. If it doesn't exist, create it."""
|
||||
@@ -96,6 +135,26 @@ class UserManager:
|
||||
except NoResultFound:
|
||||
return self.get_default_user()
|
||||
|
||||
@enforce_types
async def get_default_actor_async(self) -> PydanticUser:
    """Fetch the default user asynchronously, creating it when it does not exist yet.

    Returns:
        The default user: either the existing row for ``DEFAULT_USER_ID`` or the one
        returned by ``create_default_user``.
    """
    # Local import so this method does not rely on a module-level ``import asyncio``
    import asyncio

    try:
        return await self.get_actor_by_id_async(self.DEFAULT_USER_ID)
    except NoResultFound:
        # create_default_user is still synchronous; run it in a worker thread so its
        # blocking database work does not stall the event loop for other requests.
        return await asyncio.to_thread(self.create_default_user, org_id=self.DEFAULT_ORG_ID)
|
||||
|
||||
@enforce_types
async def get_actor_or_default_async(self, actor_id: Optional[str] = None):
    """Resolve ``actor_id`` to a user, falling back to the default user (async version).

    Args:
        actor_id: ID of the requested user. When empty/None, or when no user with
            that ID is found, the default user is returned instead.

    Returns:
        The requested user if it exists, otherwise the default user.
    """
    if actor_id:
        try:
            return await self.get_actor_by_id_async(actor_id=actor_id)
        except NoResultFound:
            # Unknown ID: fall through to the shared default-user path below
            pass

    return await self.get_default_actor_async()
|
||||
|
||||
@enforce_types
|
||||
def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
|
||||
"""List all users with optional pagination."""
|
||||
@@ -106,3 +165,14 @@ class UserManager:
|
||||
limit=limit,
|
||||
)
|
||||
return [user.to_pydantic() for user in users]
|
||||
|
||||
@enforce_types
async def list_actors_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
    """Return a page of users (async version).

    Args:
        after: Optional pagination cursor forwarded to ``UserModel.list_async``.
        limit: Optional maximum page size forwarded to ``UserModel.list_async``.

    Returns:
        The listed users, each converted via ``to_pydantic()``.
    """
    async with db_registry.async_session() as db_session:
        rows = await UserModel.list_async(db_session=db_session, after=after, limit=limit)
        return [row.to_pydantic() for row in rows]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "letta"
|
||||
version = "0.7.16"
|
||||
version = "0.7.17"
|
||||
packages = [
|
||||
{include = "letta"},
|
||||
]
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
from letta_client import AsyncLetta
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
@@ -130,12 +130,12 @@ def server_url():
|
||||
@pytest.fixture(scope="session")
|
||||
def client(server_url):
|
||||
"""Creates a REST client for testing."""
|
||||
client = Letta(base_url=server_url)
|
||||
client = AsyncLetta(base_url=server_url)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def roll_dice_tool(client):
|
||||
async def roll_dice_tool(client):
|
||||
def roll_dice():
|
||||
"""
|
||||
Rolls a 6 sided die.
|
||||
@@ -145,13 +145,13 @@ def roll_dice_tool(client):
|
||||
"""
|
||||
return "Rolled a 10!"
|
||||
|
||||
tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
tool = await client.tools.upsert_from_function(func=roll_dice)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def weather_tool(client):
|
||||
async def weather_tool(client):
|
||||
def get_weather(location: str) -> str:
|
||||
"""
|
||||
Fetches the current weather for a given location.
|
||||
@@ -176,7 +176,7 @@ def weather_tool(client):
|
||||
else:
|
||||
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
|
||||
|
||||
tool = client.tools.upsert_from_function(func=get_weather)
|
||||
tool = await client.tools.upsert_from_function(func=get_weather)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
@@ -270,7 +270,7 @@ def _assert_valid_chunk(chunk, idx, chunks):
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.parametrize("model", ["openai/gpt-4o-mini", "anthropic/claude-3-5-sonnet-20241022"])
|
||||
async def test_model_compatibility(disable_e2b_api_key, client, model, server, group_id, actor):
|
||||
async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, group_id, actor):
|
||||
request = _get_chat_request("How are you?")
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
|
||||
@@ -306,7 +306,7 @@ async def test_model_compatibility(disable_e2b_api_key, client, model, server, g
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."])
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_voice_recall_memory(disable_e2b_api_key, client, voice_agent, message, endpoint):
|
||||
async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(message)
|
||||
|
||||
@@ -320,7 +320,7 @@ async def test_voice_recall_memory(disable_e2b_api_key, client, voice_agent, mes
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_trigger_summarization(disable_e2b_api_key, client, server, voice_agent, group_id, endpoint, actor):
|
||||
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor):
|
||||
server.group_manager.modify_group(
|
||||
group_id=group_id,
|
||||
group_update=GroupUpdate(
|
||||
@@ -423,7 +423,7 @@ async def test_summarization(disable_e2b_api_key, voice_agent):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_voice_sleeptime_agent(disable_e2b_api_key, client, voice_agent):
|
||||
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
agent_manager = AgentManager()
|
||||
tool_manager = ToolManager()
|
||||
|
||||
@@ -124,16 +124,16 @@ def default_user(server: SyncServer, default_organization):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_user(server: SyncServer, default_organization):
|
||||
async def other_user(server: SyncServer, default_organization):
|
||||
"""Fixture to create and return a second ("other") user within the default organization."""
|
||||
user = server.user_manager.create_user(PydanticUser(name="other", organization_id=default_organization.id))
|
||||
user = await server.user_manager.create_actor_async(PydanticUser(name="other", organization_id=default_organization.id))
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_user_different_org(server: SyncServer, other_organization):
|
||||
async def other_user_different_org(server: SyncServer, other_organization):
|
||||
"""Fixture to create and return an "other" user within the other (non-default) organization."""
|
||||
user = server.user_manager.create_user(PydanticUser(name="other", organization_id=other_organization.id))
|
||||
user = await server.user_manager.create_actor_async(PydanticUser(name="other", organization_id=other_organization.id))
|
||||
yield user
|
||||
|
||||
|
||||
@@ -2160,20 +2160,21 @@ def test_passage_cascade_deletion(
|
||||
# ======================================================================================================================
|
||||
# User Manager Tests
|
||||
# ======================================================================================================================
|
||||
def test_list_users(server: SyncServer):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users(server: SyncServer, event_loop):
|
||||
# Create default organization
|
||||
org = server.organization_manager.create_default_organization()
|
||||
|
||||
user_name = "user"
|
||||
user = server.user_manager.create_user(PydanticUser(name=user_name, organization_id=org.id))
|
||||
user = await server.user_manager.create_actor_async(PydanticUser(name=user_name, organization_id=org.id))
|
||||
|
||||
users = server.user_manager.list_users()
|
||||
users = await server.user_manager.list_actors_async()
|
||||
assert len(users) == 1
|
||||
assert users[0].name == user_name
|
||||
|
||||
# Delete it after
|
||||
server.user_manager.delete_user_by_id(user.id)
|
||||
assert len(server.user_manager.list_users()) == 0
|
||||
await server.user_manager.delete_actor_by_id_async(user.id)
|
||||
assert len(await server.user_manager.list_actors_async()) == 0
|
||||
|
||||
|
||||
def test_create_default_user(server: SyncServer):
|
||||
@@ -2183,7 +2184,8 @@ def test_create_default_user(server: SyncServer):
|
||||
assert retrieved.name == server.user_manager.DEFAULT_USER_NAME
|
||||
|
||||
|
||||
def test_update_user(server: SyncServer):
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user(server: SyncServer, event_loop):
|
||||
# Create default organization
|
||||
default_org = server.organization_manager.create_default_organization()
|
||||
test_org = server.organization_manager.create_organization(PydanticOrganization(name="test_org"))
|
||||
@@ -2192,16 +2194,16 @@ def test_update_user(server: SyncServer):
|
||||
user_name_b = "b"
|
||||
|
||||
# Assert it's been created
|
||||
user = server.user_manager.create_user(PydanticUser(name=user_name_a, organization_id=default_org.id))
|
||||
user = await server.user_manager.create_actor_async(PydanticUser(name=user_name_a, organization_id=default_org.id))
|
||||
assert user.name == user_name_a
|
||||
|
||||
# Adjust name
|
||||
user = server.user_manager.update_user(UserUpdate(id=user.id, name=user_name_b))
|
||||
user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, name=user_name_b))
|
||||
assert user.name == user_name_b
|
||||
assert user.organization_id == OrganizationManager.DEFAULT_ORG_ID
|
||||
|
||||
# Adjust org id
|
||||
user = server.user_manager.update_user(UserUpdate(id=user.id, organization_id=test_org.id))
|
||||
user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, organization_id=test_org.id))
|
||||
assert user.name == user_name_b
|
||||
assert user.organization_id == test_org.id
|
||||
|
||||
|
||||
Reference in New Issue
Block a user