feat: async list/prepare messages (#2181)
Co-authored-by: Caren Thomas <carenthomas@gmail.com>
This commit is contained in:
@@ -60,6 +60,46 @@ def _prepare_in_context_messages(
|
||||
return current_in_context_messages, new_in_context_messages
|
||||
|
||||
|
||||
async def _prepare_in_context_messages_async(
    input_messages: List[MessageCreate],
    agent_state: AgentState,
    message_manager: MessageManager,
    actor: User,
) -> Tuple[List[Message], List[Message]]:
    """
    Assemble the message context for an agent from its stored messages plus new input.

    Async counterpart of _prepare_in_context_messages.

    Args:
        input_messages (List[MessageCreate]): New input messages to persist and append.
        agent_state (AgentState): Agent whose `message_ids`, `message_buffer_autoclear`
            flag and `id` are read.
        message_manager (MessageManager): Used to look up the stored messages and to
            create the new ones.
        actor (User): User passed through to every lookup and write.

    Returns:
        Tuple[List[Message], List[Message]]: A pair of
            - the messages already in context (only the first one when autoclear is on),
            - the messages created from `input_messages`.
    """
    stored_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)

    # With autoclear enabled only the first stored message (usually the system message at
    # index 0) is kept; otherwise the full stored list is the context.
    current_in_context_messages = [stored_messages[0]] if agent_state.message_buffer_autoclear else stored_messages

    # Turn the raw input into messages and store them
    # TODO: make this async
    pending_messages = create_input_messages(input_messages=input_messages, agent_id=agent_state.id, actor=actor)
    new_in_context_messages = message_manager.create_many_messages(pending_messages, actor=actor)

    return current_in_context_messages, new_in_context_messages
|
||||
|
||||
|
||||
def serialize_message_history(messages: List[str], context: str) -> str:
|
||||
"""
|
||||
Produce an XML document like:
|
||||
|
||||
@@ -7,7 +7,7 @@ from aiomultiprocess import Pool
|
||||
from anthropic.types.beta.messages import BetaMessageBatchCanceledResult, BetaMessageBatchErroredResult, BetaMessageBatchSucceededResult
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.helpers import _prepare_in_context_messages
|
||||
from letta.agents.helpers import _prepare_in_context_messages_async
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.tool_execution_helper import enable_strict_mode
|
||||
@@ -126,6 +126,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
letta_batch_job_id: str,
|
||||
agent_step_state_mapping: Optional[Dict[str, AgentStepState]] = None,
|
||||
) -> LettaBatchResponse:
|
||||
"""Carry out agent steps until the LLM request is sent."""
|
||||
log_event(name="validate_inputs")
|
||||
if not batch_requests:
|
||||
raise ValueError("Empty list of batch_requests passed in!")
|
||||
@@ -133,15 +134,26 @@ class LettaAgentBatch(BaseAgent):
|
||||
agent_step_state_mapping = {}
|
||||
|
||||
log_event(name="load_and_prepare_agents")
|
||||
agent_messages_mapping: Dict[str, List[Message]] = {}
|
||||
agent_tools_mapping: Dict[str, List[dict]] = {}
|
||||
# prepares (1) agent states, (2) step states, (3) LLMBatchItems (4) message batch_item_ids (5) messages per agent (6) tools per agent
|
||||
|
||||
agent_messages_mapping: dict[str, list[Message]] = {}
|
||||
agent_tools_mapping: dict[str, list[dict]] = {}
|
||||
# TODO: This isn't optimal, moving fast - prone to bugs because we pass around this half formed pydantic object
|
||||
agent_batch_item_mapping: Dict[str, LLMBatchItem] = {}
|
||||
agent_batch_item_mapping: dict[str, LLMBatchItem] = {}
|
||||
|
||||
# fetch agent states in batch
|
||||
agent_mapping = {
|
||||
agent_state.id: agent_state
|
||||
for agent_state in await self.agent_manager.get_agents_by_ids_async(
|
||||
agent_ids=[request.agent_id for request in batch_requests], actor=self.actor
|
||||
)
|
||||
}
|
||||
|
||||
agent_states = []
|
||||
for batch_request in batch_requests:
|
||||
agent_id = batch_request.agent_id
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id, actor=self.actor)
|
||||
agent_states.append(agent_state)
|
||||
agent_state = agent_mapping[agent_id]
|
||||
agent_states.append(agent_state) # keeping this to maintain ordering, but may not be necessary
|
||||
|
||||
if agent_id not in agent_step_state_mapping:
|
||||
agent_step_state_mapping[agent_id] = AgentStepState(
|
||||
@@ -162,7 +174,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
for msg in batch_request.messages:
|
||||
msg.batch_item_id = llm_batch_item.id
|
||||
|
||||
agent_messages_mapping[agent_id] = self._prepare_in_context_messages_per_agent(
|
||||
agent_messages_mapping[agent_id] = await self._prepare_in_context_messages_per_agent_async(
|
||||
agent_state=agent_state, input_messages=batch_request.messages
|
||||
)
|
||||
|
||||
@@ -528,12 +540,14 @@ class LettaAgentBatch(BaseAgent):
|
||||
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
|
||||
return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
|
||||
|
||||
def _prepare_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]:
|
||||
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
|
||||
async def _prepare_in_context_messages_per_agent_async(
    self, agent_state: AgentState, input_messages: List[MessageCreate]
) -> List[Message]:
    """
    Build the full in-context message list for a single agent.

    Args:
        agent_state: The agent whose context is being prepared.
        input_messages: New input messages for this agent.

    Returns:
        The list produced by `_rebuild_memory_async` from the agent's current
        in-context messages followed by the newly created ones.
    """
    current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
        input_messages, agent_state, self.message_manager, self.actor
    )

    # Only the awaited rebuild is needed; a preceding synchronous `_rebuild_memory`
    # call would be overwritten immediately and just adds a blocking round trip.
    return await self._rebuild_memory_async(current_in_context_messages + new_in_context_messages, agent_state)
|
||||
|
||||
# TODO: Make this a bulk function
|
||||
|
||||
@@ -12,7 +12,7 @@ from composio.exceptions import (
|
||||
)
|
||||
|
||||
|
||||
class AsyncComposioToolSet(BaseComposioToolSet, runtime="letta"):
|
||||
class AsyncComposioToolSet(BaseComposioToolSet, runtime="letta", description_char_limit=1024):
|
||||
"""
|
||||
Async version of ComposioToolSet client for interacting with Composio API
|
||||
Used to asynchronously hit the execute action endpoint
|
||||
|
||||
@@ -114,155 +114,324 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
if before_obj and after_obj and before_obj.created_at < after_obj.created_at:
|
||||
raise ValueError("'before' reference must be later than 'after' reference")
|
||||
|
||||
query = select(cls)
|
||||
query = cls._list_preprocess(
|
||||
before_obj=before_obj,
|
||||
after_obj=after_obj,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
query_text=query_text,
|
||||
query_embedding=query_embedding,
|
||||
ascending=ascending,
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
actor=actor,
|
||||
access=access,
|
||||
access_type=access_type,
|
||||
join_model=join_model,
|
||||
join_conditions=join_conditions,
|
||||
identifier_keys=identifier_keys,
|
||||
identity_id=identity_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if join_model and join_conditions:
|
||||
query = query.join(join_model, and_(*join_conditions))
|
||||
# Execute the query
|
||||
results = session.execute(query)
|
||||
|
||||
# Apply access predicate if actor is provided
|
||||
if actor:
|
||||
query = cls.apply_access_predicate(query, actor, access, access_type)
|
||||
|
||||
# Handle tag filtering if the model has tags
|
||||
if tags and hasattr(cls, "tags"):
|
||||
query = select(cls)
|
||||
|
||||
if match_all_tags:
|
||||
# Match ALL tags - use subqueries
|
||||
subquery = (
|
||||
select(cls.tags.property.mapper.class_.agent_id)
|
||||
.where(cls.tags.property.mapper.class_.tag.in_(tags))
|
||||
.group_by(cls.tags.property.mapper.class_.agent_id)
|
||||
.having(func.count() == len(tags))
|
||||
)
|
||||
query = query.filter(cls.id.in_(subquery))
|
||||
else:
|
||||
# Match ANY tag - use join and filter
|
||||
query = (
|
||||
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id)
|
||||
) # Deduplicate results
|
||||
|
||||
# select distinct primary key
|
||||
query = query.distinct(cls.id).order_by(cls.id)
|
||||
|
||||
if identifier_keys and hasattr(cls, "identities"):
|
||||
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
|
||||
|
||||
# given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids
|
||||
if identity_id and hasattr(cls, "identities"):
|
||||
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id)
|
||||
|
||||
# Apply filtering logic from kwargs
|
||||
for key, value in kwargs.items():
|
||||
if "." in key:
|
||||
# Handle joined table columns
|
||||
table_name, column_name = key.split(".")
|
||||
joined_table = locals().get(table_name) or globals().get(table_name)
|
||||
column = getattr(joined_table, column_name)
|
||||
else:
|
||||
# Handle columns from main table
|
||||
column = getattr(cls, key)
|
||||
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
query = query.where(column.in_(value))
|
||||
else:
|
||||
query = query.where(column == value)
|
||||
|
||||
# Date range filtering
|
||||
if start_date:
|
||||
query = query.filter(cls.created_at > start_date)
|
||||
if end_date:
|
||||
query = query.filter(cls.created_at < end_date)
|
||||
|
||||
# Handle pagination based on before/after
|
||||
if before or after:
|
||||
conditions = []
|
||||
|
||||
if before and after:
|
||||
# Window-based query - get records between before and after
|
||||
conditions = [
|
||||
or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id)),
|
||||
or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id)),
|
||||
]
|
||||
else:
|
||||
# Pure pagination query
|
||||
if before:
|
||||
conditions.append(
|
||||
or_(
|
||||
cls.created_at < before_obj.created_at,
|
||||
and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id),
|
||||
)
|
||||
)
|
||||
if after:
|
||||
conditions.append(
|
||||
or_(
|
||||
cls.created_at > after_obj.created_at,
|
||||
and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id),
|
||||
)
|
||||
)
|
||||
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
|
||||
# Text search
|
||||
if query_text:
|
||||
if hasattr(cls, "text"):
|
||||
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
||||
elif hasattr(cls, "name"):
|
||||
# Special case for Agent model - search across name
|
||||
query = query.filter(func.lower(cls.name).contains(func.lower(query_text)))
|
||||
|
||||
# Embedding search (for Passages)
|
||||
is_ordered = False
|
||||
if query_embedding:
|
||||
if not hasattr(cls, "embedding"):
|
||||
raise ValueError(f"Class {cls.__name__} does not have an embedding column")
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
# PostgreSQL with pgvector
|
||||
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
|
||||
else:
|
||||
# SQLite with custom vector type
|
||||
query_embedding_binary = adapt_array(query_embedding)
|
||||
query = query.order_by(
|
||||
func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
|
||||
cls.created_at.asc() if ascending else cls.created_at.desc(),
|
||||
cls.id.asc(),
|
||||
)
|
||||
is_ordered = True
|
||||
|
||||
# Handle soft deletes
|
||||
if hasattr(cls, "is_deleted"):
|
||||
query = query.where(cls.is_deleted == False)
|
||||
|
||||
# Apply ordering
|
||||
if not is_ordered:
|
||||
if ascending:
|
||||
query = query.order_by(cls.created_at.asc(), cls.id.asc())
|
||||
else:
|
||||
query = query.order_by(cls.created_at.desc(), cls.id.desc())
|
||||
|
||||
# Apply limit, adjusting for both bounds if necessary
|
||||
if before and after:
|
||||
# When both bounds are provided, we need to fetch enough records to satisfy
|
||||
# the limit while respecting both bounds. We'll fetch more and then trim.
|
||||
query = query.limit(limit * 2)
|
||||
else:
|
||||
query = query.limit(limit)
|
||||
|
||||
results = list(session.execute(query).scalars())
|
||||
|
||||
# If we have both bounds, take the middle portion
|
||||
if before and after and len(results) > limit:
|
||||
middle = len(results) // 2
|
||||
start = max(0, middle - limit // 2)
|
||||
end = min(len(results), start + limit)
|
||||
results = results[start:end]
|
||||
results = list(results.scalars())
|
||||
results = cls._list_postprocess(
|
||||
before=before,
|
||||
after=after,
|
||||
limit=limit,
|
||||
results=results,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@classmethod
@handle_db_timeout
async def list_async(
    cls,
    *,
    db_session: "AsyncSession",
    before: Optional[str] = None,
    after: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    ascending: bool = True,
    tags: Optional[List[str]] = None,
    match_all_tags: bool = False,
    actor: Optional["User"] = None,
    access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
    access_type: AccessType = AccessType.ORGANIZATION,
    join_model: Optional[Base] = None,
    join_conditions: Optional[Union[Tuple, List]] = None,
    identifier_keys: Optional[List[str]] = None,
    identity_id: Optional[str] = None,
    **kwargs,
) -> List["SqlalchemyBase"]:
    """
    Async version of the `list` method.
    NOTE: Keep in sync.

    Lists records with before/after cursor pagination. Passing both cursors
    fetches a window of records between them.

    Args:
        db_session: Async SQLAlchemy session; entered as an async context manager here.
        before: ID of the item to paginate before (upper bound).
        after: ID of the item to paginate after (lower bound).
        start_date: Only items created after this date.
        end_date: Only items created before this date.
        limit: Maximum number of items to return.
        query_text: Text to search for.
        query_embedding: Vector to search for similar embeddings.
        ascending: Sort direction.
        tags: List of tags to filter by.
        match_all_tags: If True, return items matching all tags. If False, match any tag.
        actor, access, access_type: Forwarded to `_list_preprocess` for access filtering.
        join_model, join_conditions: Forwarded to `_list_preprocess` for an optional join.
        identifier_keys, identity_id: Forwarded to `_list_preprocess` for identity filtering.
        **kwargs: Additional column filters to apply.

    Returns:
        The matching records after `_list_postprocess` has trimmed them.

    Raises:
        ValueError: If `start_date` is later than `end_date`, or the `before` record
            was created earlier than the `after` record.
        NoResultFound: If `before` or `after` does not identify an existing record.
    """
    if start_date and end_date and start_date > end_date:
        raise ValueError("start_date must be earlier than or equal to end_date")

    logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")

    async with db_session as session:
        # Resolve each pagination cursor ID into its record ("before" first, then "after")
        cursor_rows = {"before": None, "after": None}
        for label, cursor_id in (("before", before), ("after", after)):
            if not cursor_id:
                continue
            cursor_row = await session.get(cls, cursor_id)
            if not cursor_row:
                raise NoResultFound(f"No {cls.__name__} found with id {cursor_id}")
            cursor_rows[label] = cursor_row
        before_obj, after_obj = cursor_rows["before"], cursor_rows["after"]

        # The upper bound must not have been created earlier than the lower bound
        if before_obj and after_obj and before_obj.created_at < after_obj.created_at:
            raise ValueError("'before' reference must be later than 'after' reference")

        query = cls._list_preprocess(
            before_obj=before_obj,
            after_obj=after_obj,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
            query_text=query_text,
            query_embedding=query_embedding,
            ascending=ascending,
            tags=tags,
            match_all_tags=match_all_tags,
            actor=actor,
            access=access,
            access_type=access_type,
            join_model=join_model,
            join_conditions=join_conditions,
            identifier_keys=identifier_keys,
            identity_id=identity_id,
            **kwargs,
        )

        # Run the query and materialise the ORM rows
        fetched_rows = list((await session.execute(query)).scalars())

        return cls._list_postprocess(
            before=before,
            after=after,
            limit=limit,
            results=fetched_rows,
        )
|
||||
|
||||
@classmethod
def _list_preprocess(
    cls,
    *,
    before_obj,
    after_obj,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    query_embedding: Optional[List[float]] = None,
    ascending: bool = True,
    tags: Optional[List[str]] = None,
    match_all_tags: bool = False,
    actor: Optional["User"] = None,
    access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
    access_type: AccessType = AccessType.ORGANIZATION,
    join_model: Optional[Base] = None,
    join_conditions: Optional[Union[Tuple, List]] = None,
    identifier_keys: Optional[List[str]] = None,
    identity_id: Optional[str] = None,
    **kwargs,
):
    """
    Constructs the query for listing records.

    Args:
        before_obj: Record acting as the upper pagination bound, or None.
        after_obj: Record acting as the lower pagination bound, or None.
        start_date: Keep only rows with `created_at` strictly after this value.
        end_date: Keep only rows with `created_at` strictly before this value.
        limit: Row limit; doubled when both bounds are given so the caller can
            trim to the middle. None applies no limit.
        query_text: Case-insensitive substring matched against `text`, or against
            `name` when the model has no `text` attribute.
        query_embedding: Vector used to order rows by cosine distance.
        ascending: Direction of the `created_at` ordering.
        tags: Tags to filter by (only when the model has a `tags` relationship).
        match_all_tags: Require every tag instead of any tag.
        actor, access, access_type: Passed to `apply_access_predicate` when `actor` is set.
        join_model, join_conditions: Optional extra join; applied only when both are given.
        identifier_keys, identity_id: Identity filters (only when the model has `identities`).
        **kwargs: Column filters. A list/tuple/set value becomes an IN filter, anything
            else an equality filter. A key of the form "Table.column" targets a joined table.

    Returns:
        The select statement, not yet executed.

    Raises:
        ValueError: If `query_embedding` is given but the model has no `embedding` column.
    """
    query = select(cls)

    if join_model and join_conditions:
        query = query.join(join_model, and_(*join_conditions))

    # Apply access predicate if actor is provided
    if actor:
        query = cls.apply_access_predicate(query, actor, access, access_type)

    # Handle tag filtering if the model has tags.
    # NOTE: the statement built so far is extended rather than restarted from
    # `select(cls)`, so the join and the access predicate above stay in effect.
    if tags and hasattr(cls, "tags"):
        if match_all_tags:
            # Match ALL tags - use subqueries
            subquery = (
                select(cls.tags.property.mapper.class_.agent_id)
                .where(cls.tags.property.mapper.class_.tag.in_(tags))
                .group_by(cls.tags.property.mapper.class_.agent_id)
                .having(func.count() == len(tags))
            )
            query = query.filter(cls.id.in_(subquery))
        else:
            # Match ANY tag - use join and filter
            query = query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags))

        # select distinct primary key (deduplicates rows multiplied by the tag join)
        query = query.distinct(cls.id).order_by(cls.id)

    if identifier_keys and hasattr(cls, "identities"):
        query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))

    # given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids
    if identity_id and hasattr(cls, "identities"):
        query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id)

    # Apply filtering logic from kwargs
    for key, value in kwargs.items():
        if "." in key:
            # Handle joined table columns: the table name is resolved from the names
            # visible in this function or module
            table_name, column_name = key.split(".")
            joined_table = locals().get(table_name) or globals().get(table_name)
            column = getattr(joined_table, column_name)
        else:
            # Handle columns from main table
            column = getattr(cls, key)

        if isinstance(value, (list, tuple, set)):
            query = query.where(column.in_(value))
        else:
            query = query.where(column == value)

    # Date range filtering
    if start_date:
        query = query.filter(cls.created_at > start_date)
    if end_date:
        query = query.filter(cls.created_at < end_date)

    # Handle pagination based on before/after. Ties on created_at are broken by id.
    # With both bounds this yields the window between them.
    conditions = []
    if before_obj:
        conditions.append(
            or_(
                cls.created_at < before_obj.created_at,
                and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id),
            )
        )
    if after_obj:
        conditions.append(
            or_(
                cls.created_at > after_obj.created_at,
                and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id),
            )
        )
    if conditions:
        query = query.where(and_(*conditions))

    # Text search
    if query_text:
        if hasattr(cls, "text"):
            query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
        elif hasattr(cls, "name"):
            # Special case for Agent model - search across name
            query = query.filter(func.lower(cls.name).contains(func.lower(query_text)))

    # Embedding search (for Passages)
    is_ordered = False
    if query_embedding:
        if not hasattr(cls, "embedding"):
            raise ValueError(f"Class {cls.__name__} does not have an embedding column")

        from letta.settings import settings

        if settings.letta_pg_uri_no_default:
            # PostgreSQL with pgvector
            query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
        else:
            # SQLite with custom vector type
            query_embedding_binary = adapt_array(query_embedding)
            query = query.order_by(
                func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
                cls.created_at.asc() if ascending else cls.created_at.desc(),
                cls.id.asc(),
            )
        is_ordered = True

    # Handle soft deletes
    if hasattr(cls, "is_deleted"):
        query = query.where(cls.is_deleted == False)  # noqa: E712 - SQL expression, not a Python comparison

    # Apply ordering
    if not is_ordered:
        if ascending:
            query = query.order_by(cls.created_at.asc(), cls.id.asc())
        else:
            query = query.order_by(cls.created_at.desc(), cls.id.desc())

    # Apply limit, adjusting for both bounds if necessary
    if before_obj and after_obj and limit is not None:
        # When both bounds are provided, we need to fetch enough records to satisfy
        # the limit while respecting both bounds. We'll fetch more and then trim.
        query = query.limit(limit * 2)
    else:
        # A limit of None means "no limit" and must not be multiplied
        query = query.limit(limit)

    return query
|
||||
|
||||
@classmethod
|
||||
def _list_postprocess(
|
||||
cls,
|
||||
before: str | None,
|
||||
after: str | None,
|
||||
limit: int | None,
|
||||
results: list,
|
||||
):
|
||||
# If we have both bounds, take the middle portion
|
||||
if before and after and len(results) > limit:
|
||||
middle = len(results) // 2
|
||||
start = max(0, middle - limit // 2)
|
||||
end = min(len(results), start + limit)
|
||||
results = results[start:end]
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
@handle_db_timeout
|
||||
def read(
|
||||
@@ -305,7 +474,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
@handle_db_timeout
|
||||
async def read_async(
|
||||
cls,
|
||||
db_session: "Session",
|
||||
db_session: "AsyncSession",
|
||||
identifier: Optional[str] = None,
|
||||
actor: Optional["User"] = None,
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
|
||||
@@ -577,6 +577,13 @@ class AgentManager:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
async def get_agents_by_ids_async(self, agent_ids: list[str], actor: PydanticUser) -> list[PydanticAgentState]:
    """
    Fetch several agents in one call.

    Args:
        agent_ids: IDs of the agents to load.
        actor: User passed to the model-level read.

    Returns:
        The pydantic form of every agent returned by `AgentModel.read_multiple_async`.
    """
    async with db_registry.async_session() as session:
        agent_rows = await AgentModel.read_multiple_async(db_session=session, identifiers=agent_ids, actor=actor)
        # Convert while the session is still open
        pydantic_agents = [row.to_pydantic() for row in agent_rows]
        return pydantic_agents
|
||||
|
||||
@enforce_types
|
||||
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
|
||||
@@ -36,15 +36,29 @@ class MessageManager:
|
||||
"""Fetch messages by ID and return them in the requested order."""
|
||||
with db_registry.session() as session:
|
||||
results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids))
|
||||
return self._get_messages_by_id_postprocess(results, message_ids)
|
||||
|
||||
if len(results) != len(message_ids):
|
||||
logger.warning(
|
||||
f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}"
|
||||
)
|
||||
@enforce_types
async def get_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
    """
    Fetch messages by ID and return them in the requested order (async variant).

    Args:
        message_ids: IDs to load; also determines the order of the result.
        actor: User whose `organization_id` scopes the query.

    Returns:
        The messages that were found, in the order given by `message_ids`.
    """
    column_filters = {"id": message_ids, "organization_id": actor.organization_id}
    async with db_registry.async_session() as session:
        message_rows = await MessageModel.list_async(db_session=session, limit=len(message_ids), **column_filters)
        return self._get_messages_by_id_postprocess(message_rows, message_ids)
|
||||
|
||||
# Sort results directly based on message_ids
|
||||
result_dict = {msg.id: msg.to_pydantic() for msg in results}
|
||||
return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids]))
|
||||
def _get_messages_by_id_postprocess(
    self,
    results: List[MessageModel],
    message_ids: List[str],
) -> List[PydanticMessage]:
    """
    Convert fetched message rows to pydantic messages ordered like `message_ids`.

    Logs a warning when the number of rows differs from the number of requested IDs.
    IDs without a matching row are omitted from the result.

    Args:
        results: Message rows returned by the query.
        message_ids: Requested IDs, in the order the caller wants them back.

    Returns:
        Pydantic messages in `message_ids` order.
    """
    if len(message_ids) != len(results):
        logger.warning(
            f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}"
        )

    # Index the converted rows by ID, then walk the requested IDs to restore their order
    converted_by_id = {row.id: row.to_pydantic() for row in results}
    return [found for wanted_id in message_ids if (found := converted_by_id.get(wanted_id, None)) is not None]
|
||||
|
||||
@enforce_types
|
||||
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
|
||||
|
||||
0
letta/types/__init__.py
Normal file
0
letta/types/__init__.py
Normal file
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
@@ -164,6 +165,14 @@ def step_state_map(agents):
|
||||
return {agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for agent in agents}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def event_loop(request):
    """Provide a single event loop shared by the whole test session.

    The fixture is session-scoped, so the loop is created once, yielded to
    every test that requests it, and closed when the session tears down.
    """
    # asyncio.new_event_loop() asks the current policy for a new loop, so this is
    # equivalent to going through get_event_loop_policy() explicitly.
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
|
||||
|
||||
|
||||
def create_batch_response(batch_id: str, processing_status: str = "in_progress") -> BetaMessageBatch:
|
||||
"""Create a dummy BetaMessageBatch with the specified ID and status."""
|
||||
now = datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc)
|
||||
@@ -452,7 +461,7 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_error_from_anthropic_batch(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop
|
||||
):
|
||||
anthropic_batch_id = "msgbatch_test_12345"
|
||||
dummy_batch_response = create_batch_response(
|
||||
@@ -594,7 +603,7 @@ async def test_partial_error_from_anthropic_batch(
|
||||
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
|
||||
)
|
||||
assert len(messages) == (len(agents) - 1) * 4 + 1
|
||||
assert_descending_order(messages)
|
||||
_assert_descending_order(messages)
|
||||
# Check that each agent is represented
|
||||
for agent in agents_continue:
|
||||
agent_messages = [m for m in messages if m.agent_id == agent.id]
|
||||
@@ -612,7 +621,7 @@ async def test_partial_error_from_anthropic_batch(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_step_some_stop(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop
|
||||
):
|
||||
anthropic_batch_id = "msgbatch_test_12345"
|
||||
dummy_batch_response = create_batch_response(
|
||||
@@ -743,7 +752,7 @@ async def test_resume_step_some_stop(
|
||||
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
|
||||
)
|
||||
assert len(messages) == len(agents) * 3 + 1
|
||||
assert_descending_order(messages)
|
||||
_assert_descending_order(messages)
|
||||
# Check that each agent is represented
|
||||
for agent in agents_continue:
|
||||
agent_messages = [m for m in messages if m.agent_id == agent.id]
|
||||
@@ -761,23 +770,21 @@ async def test_resume_step_some_stop(
|
||||
assert agent_messages[-3].role == MessageRole.tool, "Expected tool response after assistant tool call"
|
||||
|
||||
|
||||
def assert_descending_order(messages):
|
||||
"""Assert messages are in descending order by created_at timestamps."""
|
||||
def _assert_descending_order(messages):
    """Assert that messages are sorted by non-increasing created_at timestamps.

    Args:
        messages: Sequence of objects exposing `id` and `created_at`.

    Returns:
        True when no adjacent pair is out of order (including for zero or one message).

    Raises:
        AssertionError: If a message is followed by one with a later created_at.
    """
    # Compare each message with its successor; equal timestamps are allowed.
    # An empty or single-element sequence yields no pairs, so no special case is needed.
    for earlier, later in zip(messages, messages[1:]):
        assert (
            earlier.created_at >= later.created_at
        ), f"Order violation: {earlier.id} ({earlier.created_at}) followed by {later.id} ({later.created_at})"

    return True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_step_after_request_all_continue(
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job
|
||||
disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop
|
||||
):
|
||||
anthropic_batch_id = "msgbatch_test_12345"
|
||||
dummy_batch_response = create_batch_response(
|
||||
@@ -902,7 +909,7 @@ async def test_resume_step_after_request_all_continue(
|
||||
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
|
||||
)
|
||||
assert len(messages) == len(agents) * 4
|
||||
assert_descending_order(messages)
|
||||
_assert_descending_order(messages)
|
||||
# Check that each agent is represented
|
||||
for agent in agents:
|
||||
agent_messages = [m for m in messages if m.agent_id == agent.id]
|
||||
@@ -915,7 +922,7 @@ async def test_resume_step_after_request_all_continue(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_until_request_prepares_and_submits_batch_correctly(
|
||||
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job
|
||||
disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job, event_loop
|
||||
):
|
||||
"""
|
||||
Test that step_until_request correctly:
|
||||
|
||||
Reference in New Issue
Block a user