chore: bump v0.7.17 (#2638)

Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
Co-authored-by: Kevin Lin <klin5061@gmail.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: jnjpng <jin@letta.com>
This commit is contained in:
cthomas
2025-05-16 02:02:40 -07:00
committed by GitHub
parent 62c8cbff27
commit e72dc3e93c
19 changed files with 585 additions and 106 deletions

View File

@@ -1,4 +1,4 @@
__version__ = "0.7.16"
__version__ = "0.7.17"
# import clients
from letta.client.client import LocalClient, RESTClient, create_client

View File

@@ -8,10 +8,11 @@ from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async
from letta.helpers import ToolRulesSolver
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
from letta.llm_api.llm_client import LLMClient
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
@@ -61,12 +62,8 @@ class LettaAgent(BaseAgent):
self.last_function_response = None
# Cached archival memory/message size
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id)
# Cached archival memory/message size
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id)
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id)
self.num_messages = 0
self.num_archival_memories = 0
@trace_method
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
@@ -81,7 +78,7 @@ class LettaAgent(BaseAgent):
async def _step(
self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
) -> Tuple[List[Message], List[Message], CompletionUsage]:
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
input_messages, agent_state, self.message_manager, self.actor
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
@@ -129,14 +126,14 @@ class LettaAgent(BaseAgent):
@trace_method
async def step_stream(
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True, stream_tokens: bool = False
) -> AsyncGenerator[str, None]:
"""
Main streaming loop that yields partial tokens.
Whenever we detect a tool call, we yield from _handle_ai_response as well.
"""
agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
input_messages, agent_state, self.message_manager, self.actor
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
@@ -157,9 +154,16 @@ class LettaAgent(BaseAgent):
)
# TODO: THIS IS INCREDIBLY UGLY
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
interface = AnthropicStreamingInterface(
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs
)
if agent_state.llm_config.model_endpoint_type == "anthropic":
interface = AnthropicStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
)
elif agent_state.llm_config.model_endpoint_type == "openai":
interface = OpenAIStreamingInterface(
use_assistant_message=use_assistant_message,
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
)
async for chunk in interface.process(stream):
yield f"data: {chunk.model_dump_json()}\n\n"
@@ -197,8 +201,8 @@ class LettaAgent(BaseAgent):
# TODO: This may be out of sync, if in between steps users add files
# NOTE (cliandy): temporary for now for particular use cases.
self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
self.num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
self.num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
# TODO: Also yield out a letta usage stats SSE
yield f"data: {usage.model_dump_json()}\n\n"
@@ -215,6 +219,10 @@ class LettaAgent(BaseAgent):
stream: bool,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
if settings.experimental_enable_async_db_engine:
self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id))
self.num_archival_memories = self.num_archival_memories or (
await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
)
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)

View File

@@ -0,0 +1,303 @@
from datetime import datetime, timezone
from typing import AsyncGenerator, List, Optional
from openai import AsyncStream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.json_parser import OptimisticJSONParser
from letta.streaming_utils import JSONInnerThoughtsExtractor
class OpenAIStreamingInterface:
    """
    Encapsulates the logic for streaming responses from OpenAI.
    This class handles parsing of partial tokens, pre-execution messages,
    and detection of tool call events.
    """

    def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
        # When True, tool calls named `assistant_message_tool_name` are surfaced as
        # AssistantMessage chunks instead of raw ToolCallMessage chunks.
        self.use_assistant_message = use_assistant_message
        self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL
        self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG

        # Best-effort parser that can parse incomplete (still-streaming) JSON.
        self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
        # Splits streamed function-call arguments into inner thoughts vs. main JSON.
        # TODO: pass in kwarg — NOTE(review): `put_inner_thoughts_in_kwarg` is accepted
        # but currently unused here; confirm whether it should be forwarded.
        self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=True)
        # Accumulation buffers: fragments arrive across many chunks and are held
        # back until a complete name/id/argument piece can be emitted.
        self.function_name_buffer: Optional[str] = None
        self.function_args_buffer: Optional[str] = None
        self.function_id_buffer: Optional[str] = None
        # Name of the most recently emitted (flushed) tool call, if any.
        self.last_flushed_function_name: Optional[str] = None

        # Buffer to hold function arguments until inner thoughts are complete
        self.current_function_arguments = ""
        # Last successful optimistic-parse of `current_function_arguments`; used to
        # diff out only the newly streamed assistant-message content.
        self.current_json_parse_result = {}

        # Premake IDs for database writes
        self.letta_assistant_message_id = Message.generate_id()
        self.letta_tool_message_id = Message.generate_id()

        # token counters
        self.input_tokens = 0
        self.output_tokens = 0

        self.content_buffer: List[str] = []
        self.tool_call_name: Optional[str] = None
        self.tool_call_id: Optional[str] = None
        # Inner-thought fragments collected during streaming, in arrival order.
        self.reasoning_messages: List[str] = []

    def get_reasoning_content(self) -> List[TextContent]:
        # Join all streamed inner-thought fragments into a single text block.
        content = "".join(self.reasoning_messages)
        return [TextContent(text=content)]

    def get_tool_call_object(self) -> ToolCall:
        """Useful for agent loop"""
        # NOTE(review): the ToolCall id reuses the premade Letta message id rather
        # than the provider-issued tool_call id — confirm this is intentional.
        return ToolCall(
            id=self.letta_tool_message_id,
            function=FunctionCall(arguments=self.current_function_arguments, name=self.last_flushed_function_name),
        )
async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[LettaMessage, None]:
    """
    Iterates over the OpenAI stream, yielding SSE events.
    It also collects tokens and detects if a tool call is triggered.

    Args:
        stream: Async stream of OpenAI chat-completion chunks.

    Yields:
        LettaMessage chunks: ReasoningMessage for inner thoughts,
        AssistantMessage for assistant-tool content (when use_assistant_message),
        and ToolCallMessage for other tool-call name/argument fragments.
    """
    async with stream:
        prev_message_type = None
        message_index = 0
        async for chunk in stream:
            # Track usage. CompletionUsage token counts are plain ints, so add
            # them directly (the previous `len(...)` wrapper raised TypeError
            # on every chunk that carried usage stats).
            if chunk.usage:
                self.input_tokens += chunk.usage.prompt_tokens
                self.output_tokens += chunk.usage.completion_tokens

            if chunk.choices:
                choice = chunk.choices[0]
                message_delta = choice.delta

                if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
                    tool_call = message_delta.tool_calls[0]

                    if tool_call.function.name:
                        # If we're waiting for the first key, then we should hold back the name
                        # ie add it to a buffer instead of returning it as a chunk
                        if self.function_name_buffer is None:
                            self.function_name_buffer = tool_call.function.name
                        else:
                            self.function_name_buffer += tool_call.function.name

                    if tool_call.id:
                        # Buffer until next time
                        if self.function_id_buffer is None:
                            self.function_id_buffer = tool_call.id
                        else:
                            self.function_id_buffer += tool_call.id

                    if tool_call.function.arguments:
                        self.current_function_arguments += tool_call.function.arguments
                        updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(
                            tool_call.function.arguments
                        )

                        # If we have inner thoughts, we should output them as a chunk
                        if updates_inner_thoughts:
                            if prev_message_type and prev_message_type != "reasoning_message":
                                message_index += 1
                            self.reasoning_messages.append(updates_inner_thoughts)
                            reasoning_message = ReasoningMessage(
                                id=self.letta_tool_message_id,
                                date=datetime.now(timezone.utc),
                                reasoning=updates_inner_thoughts,
                                # name=name,
                                otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
                            )
                            prev_message_type = reasoning_message.message_type
                            yield reasoning_message

                            # Additionally inner thoughts may stream back with a chunk of main JSON
                            # In that case, since we can only return a chunk at a time, we should buffer it
                            if updates_main_json:
                                if self.function_args_buffer is None:
                                    self.function_args_buffer = updates_main_json
                                else:
                                    self.function_args_buffer += updates_main_json

                        # If we have main_json, we should output a ToolCallMessage
                        elif updates_main_json:
                            # If there's something in the function_name buffer, we should release it first
                            # NOTE: we could output it as part of a chunk that has both name and args,
                            # however the frontend may expect name first, then args, so to be
                            # safe we'll output name first in a separate chunk
                            if self.function_name_buffer:
                                # use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
                                if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
                                    # Store the ID of the tool call to allow skipping the corresponding response
                                    if self.function_id_buffer:
                                        self.prev_assistant_message_id = self.function_id_buffer
                                else:
                                    if prev_message_type and prev_message_type != "tool_call_message":
                                        message_index += 1
                                    self.tool_call_name = str(self.function_name_buffer)
                                    tool_call_msg = ToolCallMessage(
                                        id=self.letta_tool_message_id,
                                        date=datetime.now(timezone.utc),
                                        tool_call=ToolCallDelta(
                                            name=self.function_name_buffer,
                                            arguments=None,
                                            tool_call_id=self.function_id_buffer,
                                        ),
                                        otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
                                    )
                                    prev_message_type = tool_call_msg.message_type
                                    yield tool_call_msg

                                # Record what the last function name we flushed was
                                self.last_flushed_function_name = self.function_name_buffer
                                # Clear the buffer
                                self.function_name_buffer = None
                                self.function_id_buffer = None
                                # Since we're clearing the name buffer, we should store
                                # any updates to the arguments inside a separate buffer
                                # Add any main_json updates to the arguments buffer
                                if self.function_args_buffer is None:
                                    self.function_args_buffer = updates_main_json
                                else:
                                    self.function_args_buffer += updates_main_json

                            # If there was nothing in the name buffer, we can proceed to
                            # output the arguments chunk as a ToolCallMessage
                            else:
                                # use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
                                if self.use_assistant_message and (
                                    self.last_flushed_function_name is not None
                                    and self.last_flushed_function_name == self.assistant_message_tool_name
                                ):
                                    # do an additional parse on the updates_main_json
                                    if self.function_args_buffer:
                                        updates_main_json = self.function_args_buffer + updates_main_json
                                        self.function_args_buffer = None

                                        # Pretty gross hardcoding that assumes that if we're toggling into the keywords, we have the full prefix
                                        match_str = '{"' + self.assistant_message_tool_kwarg + '":"'
                                        if updates_main_json == match_str:
                                            updates_main_json = None
                                    else:
                                        # Some hardcoding to strip off the trailing "}"
                                        if updates_main_json in ["}", '"}']:
                                            updates_main_json = None
                                        if updates_main_json and len(updates_main_json) > 0 and updates_main_json[-1:] == '"':
                                            updates_main_json = updates_main_json[:-1]

                                    if not updates_main_json:
                                        # early exit to turn into content mode
                                        continue

                                    # There may be a buffer from a previous chunk, for example
                                    # if the previous chunk had arguments but we needed to flush name
                                    if self.function_args_buffer:
                                        # In this case, we should release the buffer + new data at once
                                        combined_chunk = self.function_args_buffer + updates_main_json
                                        if prev_message_type and prev_message_type != "assistant_message":
                                            message_index += 1
                                        assistant_message = AssistantMessage(
                                            id=self.letta_assistant_message_id,
                                            date=datetime.now(timezone.utc),
                                            content=combined_chunk,
                                            otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
                                        )
                                        prev_message_type = assistant_message.message_type
                                        yield assistant_message
                                        # Store the ID of the tool call to allow skipping the corresponding response
                                        if self.function_id_buffer:
                                            self.prev_assistant_message_id = self.function_id_buffer
                                        # clear buffer
                                        self.function_args_buffer = None
                                        self.function_id_buffer = None
                                    else:
                                        # If there's no buffer to clear, just output a new chunk with new data
                                        # TODO: THIS IS HORRIBLE
                                        # TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
                                        # TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
                                        parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)

                                        if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
                                            self.assistant_message_tool_kwarg
                                        ) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg):
                                            new_content = parsed_args.get(self.assistant_message_tool_kwarg)
                                            prev_content = self.current_json_parse_result.get(self.assistant_message_tool_kwarg, "")
                                            # TODO: Assumes consistent state and that prev_content is subset of new_content
                                            diff = new_content.replace(prev_content, "", 1)
                                            self.current_json_parse_result = parsed_args
                                            if prev_message_type and prev_message_type != "assistant_message":
                                                message_index += 1
                                            assistant_message = AssistantMessage(
                                                id=self.letta_assistant_message_id,
                                                date=datetime.now(timezone.utc),
                                                content=diff,
                                                # name=name,
                                                otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
                                            )
                                            prev_message_type = assistant_message.message_type
                                            yield assistant_message

                                        # Store the ID of the tool call to allow skipping the corresponding response
                                        if self.function_id_buffer:
                                            self.prev_assistant_message_id = self.function_id_buffer
                                        # clear buffers
                                        self.function_id_buffer = None
                                else:
                                    # There may be a buffer from a previous chunk, for example
                                    # if the previous chunk had arguments but we needed to flush name
                                    if self.function_args_buffer:
                                        # In this case, we should release the buffer + new data at once
                                        combined_chunk = self.function_args_buffer + updates_main_json
                                        if prev_message_type and prev_message_type != "tool_call_message":
                                            message_index += 1
                                        tool_call_msg = ToolCallMessage(
                                            id=self.letta_tool_message_id,
                                            date=datetime.now(timezone.utc),
                                            tool_call=ToolCallDelta(
                                                name=None,
                                                arguments=combined_chunk,
                                                tool_call_id=self.function_id_buffer,
                                            ),
                                            # name=name,
                                            otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
                                        )
                                        prev_message_type = tool_call_msg.message_type
                                        yield tool_call_msg
                                        # clear buffer
                                        self.function_args_buffer = None
                                        self.function_id_buffer = None
                                    else:
                                        # If there's no buffer to clear, just output a new chunk with new data
                                        if prev_message_type and prev_message_type != "tool_call_message":
                                            message_index += 1
                                        tool_call_msg = ToolCallMessage(
                                            id=self.letta_tool_message_id,
                                            date=datetime.now(timezone.utc),
                                            tool_call=ToolCallDelta(
                                                name=None,
                                                arguments=updates_main_json,
                                                tool_call_id=self.function_id_buffer,
                                            ),
                                            # name=name,
                                            otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
                                        )
                                        prev_message_type = tool_call_msg.message_type
                                        yield tool_call_msg
                                        self.function_id_buffer = None

View File

@@ -745,6 +745,17 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
self.is_deleted = True
return self.update(db_session)
@handle_db_timeout
async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase":
    """Soft delete a record asynchronously (mark as deleted).

    The row is not removed from the database; `is_deleted` is flipped and the
    change is persisted via `update_async`. Async counterpart of `delete`.

    Args:
        db_session: Async SQLAlchemy session used to persist the change.
        actor: Optional user performing the delete; stamped into audit fields.

    Returns:
        The updated (soft-deleted) record.
    """
    logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)")

    if actor:
        # Stamp created/updated-by audit fields before persisting the flag flip.
        self._set_created_and_updated_by_fields(actor.id)

    self.is_deleted = True
    return await self.update_async(db_session)
@handle_db_timeout
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
"""Permanently removes the record from the database."""
@@ -761,6 +772,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
else:
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
@handle_db_timeout
async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None:
    """Permanently removes the record from the database asynchronously.

    Async counterpart of `hard_delete`: deletes the row and commits, rolling
    back on any failure.

    Args:
        db_session: Async SQLAlchemy session.
        actor: Optional user performing the delete (logged only).

    Raises:
        ValueError: If the delete or commit fails; the original exception is
            attached as the cause.
    """
    logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)")

    async with db_session as session:
        try:
            await session.delete(self)
            await session.commit()
        except Exception as e:
            # Roll back the failed transaction before surfacing the error.
            await session.rollback()
            logger.exception(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}")
            # Chain the original exception explicitly (`from e`) so the root
            # cause is preserved as __cause__ for callers inspecting the error.
            raise ValueError(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}: {e}") from e
@handle_db_timeout
def update(self, db_session: Session, actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
logger.debug(...)
@@ -793,6 +818,39 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
await db_session.refresh(self)
return self
@classmethod
def _size_preprocess(
    cls,
    *,
    db_session: "Session",
    actor: Optional["User"] = None,
    access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],  # noqa: B006 - never mutated; mirrors size()'s signature
    access_type: AccessType = AccessType.ORGANIZATION,
    **kwargs,
):
    """Build the SELECT COUNT(*) query shared by `size` and `size_async`.

    Applies the actor's access predicate, equality/IN filters from **kwargs,
    and a soft-delete filter when the model has `is_deleted`.

    NOTE(review): `db_session` is accepted for signature symmetry but is not
    used while building the query — confirm whether it can be dropped.
    """
    logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
    query = select(func.count()).select_from(cls)

    if actor:
        query = cls.apply_access_predicate(query, actor, access, access_type)

    # Apply filtering logic based on kwargs.
    # NOTE(review): truthiness (`if value`) skips falsy filter values such as
    # 0, "" or False — confirm that is intended rather than `is not None`.
    for key, value in kwargs.items():
        if value:
            column = getattr(cls, key, None)
            # Identity check rather than truthiness: SQLAlchemy instrumented
            # attributes are clause elements, and `bool()` on a clause element
            # is undefined (it can raise TypeError).
            if column is None:
                raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
            if isinstance(value, (list, tuple, set)):  # Check for iterables
                query = query.where(column.in_(value))
            else:  # Single value for equality filtering
                query = query.where(column == value)

    # Handle soft deletes if the class has the 'is_deleted' attribute
    if hasattr(cls, "is_deleted"):
        query = query.where(cls.is_deleted == False)  # noqa: E712 - SQL expression, not a Python bool comparison

    return query
@classmethod
@handle_db_timeout
def size(
@@ -817,28 +875,8 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
Raises:
DBAPIError: If a database error occurs
"""
logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
with db_session as session:
query = select(func.count()).select_from(cls)
if actor:
query = cls.apply_access_predicate(query, actor, access, access_type)
# Apply filtering logic based on kwargs
for key, value in kwargs.items():
if value:
column = getattr(cls, key, None)
if not column:
raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
if isinstance(value, (list, tuple, set)): # Check for iterables
query = query.where(column.in_(value))
else: # Single value for equality filtering
query = query.where(column == value)
# Handle soft deletes if the class has the 'is_deleted' attribute
if hasattr(cls, "is_deleted"):
query = query.where(cls.is_deleted == False)
query = cls._size_preprocess(db_session=session, actor=actor, access=access, access_type=access_type, **kwargs)
try:
count = session.execute(query).scalar()
@@ -847,6 +885,37 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
logger.exception(f"Failed to calculate size for {cls.__name__}")
raise e
@classmethod
@handle_db_timeout
async def size_async(
    cls,
    *,
    db_session: "AsyncSession",
    actor: Optional["User"] = None,
    access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],  # noqa: B006 - never mutated; mirrors size()'s signature
    access_type: AccessType = AccessType.ORGANIZATION,
    **kwargs,
) -> int:
    """
    Get the count of rows that match the provided filters.

    Async counterpart of `size`; shares query construction via `_size_preprocess`.

    Args:
        db_session: Async SQLAlchemy session
        **kwargs: Filters to apply to the query (e.g., column_name=value)

    Returns:
        int: The count of rows that match the filters

    Raises:
        DBAPIError: If a database error occurs
    """
    async with db_session as session:
        query = cls._size_preprocess(db_session=session, actor=actor, access=access, access_type=access_type, **kwargs)

        try:
            # `session.execute` is a coroutine: it must be awaited before
            # `.scalar()` can be called on the Result. The previous
            # `await session.execute(query).scalar()` invoked `.scalar()` on
            # the un-awaited coroutine and raised AttributeError.
            result = await session.execute(query)
            count = result.scalar()
            return count if count else 0
        except DBAPIError as e:
            logger.exception(f"Failed to calculate size for {cls.__name__}")
            raise e
@classmethod
def apply_access_predicate(
cls,

View File

@@ -83,7 +83,7 @@ async def list_agents(
"""
# Retrieve the actor (user) details
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# Call list_agents directly without unnecessary dict handling
return await server.agent_manager.list_agents_async(
@@ -163,7 +163,7 @@ async def import_agent_serialized(
"""
Import a serialized agent file and recreate the agent in the system.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
serialized_data = await file.read()
@@ -233,7 +233,7 @@ async def create_agent(
Create a new agent with the specified configuration.
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.create_agent_async(agent, actor=actor)
except Exception as e:
traceback.print_exc()
@@ -248,7 +248,7 @@ async def modify_agent(
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Update an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.update_agent_async(agent_id=agent_id, request=update_agent, actor=actor)
@@ -333,7 +333,7 @@ def detach_source(
@router.get("/{agent_id}", response_model=AgentState, operation_id="retrieve_agent")
def retrieve_agent(
async def retrieve_agent(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@@ -344,7 +344,7 @@ def retrieve_agent(
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -414,7 +414,7 @@ def retrieve_block(
@router.get("/{agent_id}/core-memory/blocks", response_model=List[Block], operation_id="list_core_memory_blocks")
def list_blocks(
async def list_blocks(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@@ -424,7 +424,7 @@ def list_blocks(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
return agent.memory.blocks
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -628,9 +628,9 @@ async def send_message(
Process a user message and return the agent's response.
This endpoint accepts a message from a user and processes it through the agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# TODO: This is redundant, remove soon
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
@@ -686,13 +686,13 @@ async def send_message_streaming(
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
"""
request_start_timestamp_ns = get_utc_timestamp_ns()
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# TODO: This is redundant, remove soon
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
model_compatible = agent.llm_config.model_endpoint_type == "anthropic"
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
if agent_eligible and feature_enabled and model_compatible and request.stream_tokens:
experimental_agent = LettaAgent(
@@ -705,7 +705,9 @@ async def send_message_streaming(
)
result = StreamingResponse(
experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message),
experimental_agent.step_stream(
request.messages, max_steps=10, use_assistant_message=request.use_assistant_message, stream_tokens=request.stream_tokens
),
media_type="text/event-stream",
)
else:
@@ -784,7 +786,7 @@ async def send_message_async(
Asynchronously process a user message and return a run object.
The actual processing happens in the background, and the status can be checked using the run ID.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# Create a new job
run = Run(
@@ -838,6 +840,6 @@ async def list_agent_groups(
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Lists the groups for an agent"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
print("in list agents with manager_type", manager_type)
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)

View File

@@ -26,7 +26,7 @@ async def list_blocks(
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.block_manager.get_blocks_async(
actor=actor,
label=label,

View File

@@ -135,7 +135,7 @@ async def send_group_message(
Process a user message and return the group's response.
This endpoint accepts a message from a user and processes it through through agents in the group based on the specified pattern
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
result = await server.send_group_message_to_agent(
group_id=group_id,
actor=actor,
@@ -174,7 +174,7 @@ async def send_group_message_streaming(
This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern.
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
result = await server.send_group_message_to_agent(
group_id=group_id,
actor=actor,

View File

@@ -52,7 +52,7 @@ async def create_messages_batch(
detail=f"Server misconfiguration: LETTA_ENABLE_BATCH_JOB_POLLING is set to False.",
)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
batch_job = BatchJob(
user_id=actor.id,
status=JobStatus.running,
@@ -100,7 +100,7 @@ async def retrieve_batch_run(
"""
Get the status of a batch run.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor)
@@ -118,7 +118,7 @@ async def list_batch_runs(
List all batch runs.
"""
# TODO: filter
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
jobs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.BATCH)
return [BatchJob.from_job(job) for job in jobs]
@@ -150,7 +150,7 @@ async def list_batch_messages(
- For subsequent pages, use the ID of the last message from the previous response as the cursor
- Results will include messages before/after the cursor based on sort_descending
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# First, verify the batch job exists and the user has access to it
try:
@@ -177,7 +177,7 @@ async def cancel_batch_run(
"""
Cancel a batch run.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor)

View File

@@ -115,7 +115,7 @@ async def list_run_messages(
if order not in ["asc", "desc"]:
raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
messages = server.job_manager.get_run_messages(
@@ -182,7 +182,7 @@ async def list_run_steps(
if order not in ["asc", "desc"]:
raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
steps = server.job_manager.get_job_steps(

View File

@@ -87,7 +87,7 @@ async def list_tools(
Get a list of all tools available to agents belonging to the org of the user
"""
try:
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
if name is not None:
tool = await server.tool_manager.get_tool_by_name_async(tool_name=name, actor=actor)
return [tool] if tool else []

View File

@@ -14,7 +14,7 @@ router = APIRouter(prefix="/users", tags=["users", "admin"])
@router.get("/", tags=["admin"], response_model=List[User], operation_id="list_users")
def list_users(
async def list_users(
after: Optional[str] = Query(None),
limit: Optional[int] = Query(50),
server: "SyncServer" = Depends(get_letta_server),
@@ -23,7 +23,7 @@ def list_users(
Get a list of all users in the database
"""
try:
users = server.user_manager.list_users(after=after, limit=limit)
users = await server.user_manager.list_actors_async(after=after, limit=limit)
except HTTPException:
raise
except Exception as e:
@@ -32,7 +32,7 @@ def list_users(
@router.post("/", tags=["admin"], response_model=User, operation_id="create_user")
def create_user(
async def create_user(
request: UserCreate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
):
@@ -40,33 +40,33 @@ def create_user(
Create a new user in the database
"""
user = User(**request.model_dump())
user = server.user_manager.create_user(user)
user = await server.user_manager.create_actor_async(user)
return user
@router.put("/", tags=["admin"], response_model=User, operation_id="update_user")
def update_user(
async def update_user(
user: UserUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
):
"""
Update a user in the database
"""
user = server.user_manager.update_user(user)
user = await server.user_manager.update_actor_async(user)
return user
@router.delete("/", tags=["admin"], response_model=User, operation_id="delete_user")
def delete_user(
async def delete_user(
user_id: str = Query(..., description="The user_id key to be deleted."),
server: "SyncServer" = Depends(get_letta_server),
):
# TODO make a soft deletion, instead of a hard deletion
try:
user = server.user_manager.get_user_by_id(user_id=user_id)
user = await server.user_manager.get_actor_by_id_async(actor_id=user_id)
if user is None:
raise HTTPException(status_code=404, detail=f"User does not exist")
server.user_manager.delete_user_by_id(user_id=user_id)
await server.user_manager.delete_actor_by_id_async(user_id=user_id)
except HTTPException:
raise
except Exception as e:

View File

@@ -36,7 +36,7 @@ async def create_voice_chat_completions(
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id)
# Create OpenAI async client
client = openai.AsyncClient(

View File

@@ -1,3 +1,4 @@
import asyncio
from datetime import datetime, timezone
from typing import Dict, List, Optional, Set, Tuple
@@ -905,12 +906,7 @@ class AgentManager:
result = await session.execute(query)
agents = result.scalars().all()
pydantic_agents = []
for agent in agents:
pydantic_agent = await agent.to_pydantic_async(include_relationships=include_relationships)
pydantic_agents.append(pydantic_agent)
return pydantic_agents
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
@enforce_types
def list_agents_matching_tags(
@@ -1195,8 +1191,8 @@ class AgentManager:
@enforce_types
async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
return await self.message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=actor)
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor)
@enforce_types
def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:

View File

@@ -286,6 +286,21 @@ class MessageManager:
with db_registry.session() as session:
return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id)
@enforce_types
async def size_async(
    self,
    actor: PydanticUser,
    role: Optional[MessageRole] = None,
    agent_id: Optional[str] = None,
) -> int:
    """Get the total count of messages, with optional filters (async version).

    Args:
        actor: The user requesting the count.
        role: If provided, only count messages with this role.
        agent_id: If provided, only count messages belonging to this agent.

    Returns:
        The number of matching messages.
    """
    async with db_registry.async_session() as session:
        return await MessageModel.size_async(db_session=session, actor=actor, role=role, agent_id=agent_id)
@enforce_types
def list_user_messages_for_agent(
self,

View File

@@ -216,6 +216,20 @@ class PassageManager:
with db_registry.session() as session:
return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
@enforce_types
async def size_async(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
) -> int:
    """Get the total count of agent passages, with an optional agent filter (async version).

    Note: despite living next to message counters, this counts ``AgentPassage``
    rows, not messages (the original docstring said "messages" incorrectly).

    Args:
        actor: The user requesting the count.
        agent_id: If provided, only count passages belonging to this agent.

    Returns:
        The number of matching agent passages.
    """
    async with db_registry.async_session() as session:
        return await AgentPassage.size_async(db_session=session, actor=actor, agent_id=agent_id)
def estimate_embeddings_size(
self,
actor: PydanticUser,

View File

@@ -44,6 +44,14 @@ class UserManager:
new_user.create(session)
return new_user.to_pydantic()
@enforce_types
async def create_actor_async(self, pydantic_user: PydanticUser) -> PydanticUser:
    """Persist a new user record asynchronously and return it in Pydantic form."""
    # Materialize the ORM field dict before opening the session.
    orm_fields = pydantic_user.model_dump(to_orm=True)
    async with db_registry.async_session() as session:
        orm_user = UserModel(**orm_fields)
        await orm_user.create_async(session)
        return orm_user.to_pydantic()
@enforce_types
def update_user(self, user_update: UserUpdate) -> PydanticUser:
"""Update user details."""
@@ -60,6 +68,22 @@ class UserManager:
existing_user.update(session)
return existing_user.to_pydantic()
@enforce_types
async def update_actor_async(self, user_update: UserUpdate) -> PydanticUser:
    """Apply the fields set on ``user_update`` to an existing user (async version).

    Only fields explicitly provided (and non-None) are written; all other
    stored fields are left untouched.
    """
    # Compute the patch up front; exclude_unset/exclude_none keeps this a partial update.
    changes = user_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
    async with db_registry.async_session() as session:
        target = await UserModel.read_async(db_session=session, identifier=user_update.id)
        for field_name, field_value in changes.items():
            setattr(target, field_name, field_value)
        await target.update_async(session)
        return target.to_pydantic()
@enforce_types
def delete_user_by_id(self, user_id: str):
"""Delete a user and their associated records (agents, sources, mappings)."""
@@ -70,6 +94,14 @@ class UserManager:
session.commit()
@enforce_types
async def delete_actor_by_id_async(self, user_id: str):
    """Hard-delete a user row by ID (async version).

    NOTE(review): the original docstring claimed associated records (agents,
    sources, mappings) are deleted too, but this method only hard-deletes the
    user row itself — presumably DB-level cascades handle the rest; confirm.

    Args:
        user_id: ID of the user to delete.

    Raises:
        NoResultFound: If no user with ``user_id`` exists.
    """
    async with db_registry.async_session() as session:
        user = await UserModel.read_async(db_session=session, identifier=user_id)
        await user.hard_delete_async(session)
@enforce_types
def get_user_by_id(self, user_id: str) -> PydanticUser:
"""Fetch a user by ID."""
@@ -77,6 +109,13 @@ class UserManager:
user = UserModel.read(db_session=session, identifier=user_id)
return user.to_pydantic()
@enforce_types
async def get_actor_by_id_async(self, actor_id: str) -> PydanticUser:
    """Look up a single user record by its ID (async version)."""
    async with db_registry.async_session() as session:
        record = await UserModel.read_async(db_session=session, identifier=actor_id)
        return record.to_pydantic()
@enforce_types
def get_default_user(self) -> PydanticUser:
"""Fetch the default user. If it doesn't exist, create it."""
@@ -96,6 +135,26 @@ class UserManager:
except NoResultFound:
return self.get_default_user()
@enforce_types
async def get_default_actor_async(self) -> PydanticUser:
    """Return the default user, creating it on first use (async version)."""
    try:
        default_user = await self.get_actor_by_id_async(self.DEFAULT_USER_ID)
    except NoResultFound:
        # create_default_user has no async variant yet, so fall back to the sync path.
        default_user = self.create_default_user(org_id=self.DEFAULT_ORG_ID)
    return default_user
@enforce_types
async def get_actor_or_default_async(self, actor_id: Optional[str] = None) -> PydanticUser:
    """Resolve ``actor_id`` to a user, falling back to the default user (async version).

    The default user is returned both when no ID is supplied and when the
    supplied ID does not match an existing user.

    Args:
        actor_id: Optional ID of the user to fetch.

    Returns:
        The matching user, or the default user.
    """
    # Fix: sibling methods annotate their return type; this one was missing it.
    if not actor_id:
        return await self.get_default_actor_async()
    try:
        return await self.get_actor_by_id_async(actor_id=actor_id)
    except NoResultFound:
        return await self.get_default_actor_async()
@enforce_types
def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
"""List all users with optional pagination."""
@@ -106,3 +165,14 @@ class UserManager:
limit=limit,
)
return [user.to_pydantic() for user in users]
@enforce_types
async def list_actors_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
    """Page through users: up to ``limit`` records after cursor ``after`` (async version)."""
    async with db_registry.async_session() as session:
        records = await UserModel.list_async(db_session=session, after=after, limit=limit)
        return [record.to_pydantic() for record in records]

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.7.16"
version = "0.7.17"
packages = [
{include = "letta"},
]

View File

@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
from dotenv import load_dotenv
from letta_client import Letta
from letta_client import AsyncLetta
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionChunk
@@ -130,12 +130,12 @@ def server_url():
@pytest.fixture(scope="session")
def client(server_url):
"""Creates a REST client for testing."""
client = Letta(base_url=server_url)
client = AsyncLetta(base_url=server_url)
yield client
@pytest.fixture(scope="function")
def roll_dice_tool(client):
async def roll_dice_tool(client):
def roll_dice():
"""
Rolls a 6 sided die.
@@ -145,13 +145,13 @@ def roll_dice_tool(client):
"""
return "Rolled a 10!"
tool = client.tools.upsert_from_function(func=roll_dice)
tool = await client.tools.upsert_from_function(func=roll_dice)
# Yield the created tool
yield tool
@pytest.fixture(scope="function")
def weather_tool(client):
async def weather_tool(client):
def get_weather(location: str) -> str:
"""
Fetches the current weather for a given location.
@@ -176,7 +176,7 @@ def weather_tool(client):
else:
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
tool = client.tools.upsert_from_function(func=get_weather)
tool = await client.tools.upsert_from_function(func=get_weather)
# Yield the created tool
yield tool
@@ -270,7 +270,7 @@ def _assert_valid_chunk(chunk, idx, chunks):
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize("model", ["openai/gpt-4o-mini", "anthropic/claude-3-5-sonnet-20241022"])
async def test_model_compatibility(disable_e2b_api_key, client, model, server, group_id, actor):
async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, group_id, actor):
request = _get_chat_request("How are you?")
server.tool_manager.upsert_base_tools(actor=actor)
@@ -306,7 +306,7 @@ async def test_model_compatibility(disable_e2b_api_key, client, model, server, g
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."])
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
async def test_voice_recall_memory(disable_e2b_api_key, client, voice_agent, message, endpoint):
async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint):
"""Tests chat completion streaming using the Async OpenAI client."""
request = _get_chat_request(message)
@@ -320,7 +320,7 @@ async def test_voice_recall_memory(disable_e2b_api_key, client, voice_agent, mes
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
async def test_trigger_summarization(disable_e2b_api_key, client, server, voice_agent, group_id, endpoint, actor):
async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor):
server.group_manager.modify_group(
group_id=group_id,
group_update=GroupUpdate(
@@ -423,7 +423,7 @@ async def test_summarization(disable_e2b_api_key, voice_agent):
@pytest.mark.asyncio(loop_scope="session")
async def test_voice_sleeptime_agent(disable_e2b_api_key, client, voice_agent):
async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent):
"""Tests chat completion streaming using the Async OpenAI client."""
agent_manager = AgentManager()
tool_manager = ToolManager()

View File

@@ -124,16 +124,16 @@ def default_user(server: SyncServer, default_organization):
@pytest.fixture
def other_user(server: SyncServer, default_organization):
async def other_user(server: SyncServer, default_organization):
"""Fixture to create and return the default user within the default organization."""
user = server.user_manager.create_user(PydanticUser(name="other", organization_id=default_organization.id))
user = await server.user_manager.create_actor_async(PydanticUser(name="other", organization_id=default_organization.id))
yield user
@pytest.fixture
def other_user_different_org(server: SyncServer, other_organization):
async def other_user_different_org(server: SyncServer, other_organization):
"""Fixture to create and return the default user within the default organization."""
user = server.user_manager.create_user(PydanticUser(name="other", organization_id=other_organization.id))
user = await server.user_manager.create_actor_async(PydanticUser(name="other", organization_id=other_organization.id))
yield user
@@ -2160,20 +2160,21 @@ def test_passage_cascade_deletion(
# ======================================================================================================================
# User Manager Tests
# ======================================================================================================================
def test_list_users(server: SyncServer):
@pytest.mark.asyncio
async def test_list_users(server: SyncServer, event_loop):
# Create default organization
org = server.organization_manager.create_default_organization()
user_name = "user"
user = server.user_manager.create_user(PydanticUser(name=user_name, organization_id=org.id))
user = await server.user_manager.create_actor_async(PydanticUser(name=user_name, organization_id=org.id))
users = server.user_manager.list_users()
users = await server.user_manager.list_actors_async()
assert len(users) == 1
assert users[0].name == user_name
# Delete it after
server.user_manager.delete_user_by_id(user.id)
assert len(server.user_manager.list_users()) == 0
await server.user_manager.delete_actor_by_id_async(user.id)
assert len(await server.user_manager.list_actors_async()) == 0
def test_create_default_user(server: SyncServer):
@@ -2183,7 +2184,8 @@ def test_create_default_user(server: SyncServer):
assert retrieved.name == server.user_manager.DEFAULT_USER_NAME
def test_update_user(server: SyncServer):
@pytest.mark.asyncio
async def test_update_user(server: SyncServer, event_loop):
# Create default organization
default_org = server.organization_manager.create_default_organization()
test_org = server.organization_manager.create_organization(PydanticOrganization(name="test_org"))
@@ -2192,16 +2194,16 @@ def test_update_user(server: SyncServer):
user_name_b = "b"
# Assert it's been created
user = server.user_manager.create_user(PydanticUser(name=user_name_a, organization_id=default_org.id))
user = await server.user_manager.create_actor_async(PydanticUser(name=user_name_a, organization_id=default_org.id))
assert user.name == user_name_a
# Adjust name
user = server.user_manager.update_user(UserUpdate(id=user.id, name=user_name_b))
user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, name=user_name_b))
assert user.name == user_name_b
assert user.organization_id == OrganizationManager.DEFAULT_ORG_ID
# Adjust org id
user = server.user_manager.update_user(UserUpdate(id=user.id, organization_id=test_org.id))
user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, organization_id=test_org.id))
assert user.name == user_name_b
assert user.organization_id == test_org.id