chore: Clean up upserting base tools (#2274)
This commit is contained in:
@@ -1,18 +1,18 @@
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy import select, union_all, literal, func, Select
|
||||
import numpy as np
|
||||
from sqlalchemy import Select, func, literal, select, union_all
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.log import get_logger
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm import AgentPassage
|
||||
from letta.orm import Block as BlockModel
|
||||
from letta.orm import Source as SourceModel
|
||||
from letta.orm import SourcePassage, SourcesAgents
|
||||
from letta.orm import Tool as ToolModel
|
||||
from letta.orm import AgentPassage, SourcePassage
|
||||
from letta.orm import SourcesAgents
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
@@ -77,6 +77,8 @@ class AgentManager:
|
||||
tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
|
||||
if agent_create.tools:
|
||||
tool_names.extend(agent_create.tools)
|
||||
# Remove duplicates
|
||||
tool_names = list(set(tool_names))
|
||||
|
||||
tool_ids = agent_create.tool_ids or []
|
||||
for tool_name in tool_names:
|
||||
@@ -431,7 +433,7 @@ class AgentManager:
|
||||
agent_only: bool = False,
|
||||
) -> Select:
|
||||
"""Helper function to build the base passage query with all filters applied.
|
||||
|
||||
|
||||
Returns the query before any limit or count operations are applied.
|
||||
"""
|
||||
embedded_text = None
|
||||
@@ -448,21 +450,14 @@ class AgentManager:
|
||||
if not agent_only: # Include source passages
|
||||
if agent_id is not None:
|
||||
source_passages = (
|
||||
select(
|
||||
SourcePassage,
|
||||
literal(None).label('agent_id')
|
||||
)
|
||||
select(SourcePassage, literal(None).label("agent_id"))
|
||||
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
|
||||
.where(SourcesAgents.agent_id == agent_id)
|
||||
.where(SourcePassage.organization_id == actor.organization_id)
|
||||
)
|
||||
else:
|
||||
source_passages = (
|
||||
select(
|
||||
SourcePassage,
|
||||
literal(None).label('agent_id')
|
||||
)
|
||||
.where(SourcePassage.organization_id == actor.organization_id)
|
||||
source_passages = select(SourcePassage, literal(None).label("agent_id")).where(
|
||||
SourcePassage.organization_id == actor.organization_id
|
||||
)
|
||||
|
||||
if source_id:
|
||||
@@ -486,9 +481,9 @@ class AgentManager:
|
||||
AgentPassage._created_by_id,
|
||||
AgentPassage._last_updated_by_id,
|
||||
AgentPassage.organization_id,
|
||||
literal(None).label('file_id'),
|
||||
literal(None).label('source_id'),
|
||||
AgentPassage.agent_id
|
||||
literal(None).label("file_id"),
|
||||
literal(None).label("source_id"),
|
||||
AgentPassage.agent_id,
|
||||
)
|
||||
.where(AgentPassage.agent_id == agent_id)
|
||||
.where(AgentPassage.organization_id == actor.organization_id)
|
||||
@@ -496,11 +491,11 @@ class AgentManager:
|
||||
|
||||
# Combine queries
|
||||
if source_passages is not None and agent_passages is not None:
|
||||
combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
|
||||
combined_query = union_all(source_passages, agent_passages).cte("combined_passages")
|
||||
elif agent_passages is not None:
|
||||
combined_query = agent_passages.cte('combined_passages')
|
||||
combined_query = agent_passages.cte("combined_passages")
|
||||
elif source_passages is not None:
|
||||
combined_query = source_passages.cte('combined_passages')
|
||||
combined_query = source_passages.cte("combined_passages")
|
||||
else:
|
||||
raise ValueError("No passages found")
|
||||
|
||||
@@ -521,9 +516,7 @@ class AgentManager:
|
||||
if embedded_text:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
# PostgreSQL with pgvector
|
||||
main_query = main_query.order_by(
|
||||
combined_query.c.embedding.cosine_distance(embedded_text).asc()
|
||||
)
|
||||
main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc())
|
||||
else:
|
||||
# SQLite with custom vector type
|
||||
query_embedding_binary = adapt_array(embedded_text)
|
||||
@@ -531,13 +524,13 @@ class AgentManager:
|
||||
main_query = main_query.order_by(
|
||||
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
|
||||
combined_query.c.created_at.asc(),
|
||||
combined_query.c.id.asc()
|
||||
combined_query.c.id.asc(),
|
||||
)
|
||||
else:
|
||||
main_query = main_query.order_by(
|
||||
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
|
||||
combined_query.c.created_at.desc(),
|
||||
combined_query.c.id.asc()
|
||||
combined_query.c.id.asc(),
|
||||
)
|
||||
else:
|
||||
if query_text:
|
||||
@@ -545,18 +538,12 @@ class AgentManager:
|
||||
|
||||
# Handle cursor-based pagination
|
||||
if cursor:
|
||||
cursor_query = select(combined_query.c.created_at).where(
|
||||
combined_query.c.id == cursor
|
||||
).scalar_subquery()
|
||||
|
||||
cursor_query = select(combined_query.c.created_at).where(combined_query.c.id == cursor).scalar_subquery()
|
||||
|
||||
if ascending:
|
||||
main_query = main_query.where(
|
||||
combined_query.c.created_at > cursor_query
|
||||
)
|
||||
main_query = main_query.where(combined_query.c.created_at > cursor_query)
|
||||
else:
|
||||
main_query = main_query.where(
|
||||
combined_query.c.created_at < cursor_query
|
||||
)
|
||||
main_query = main_query.where(combined_query.c.created_at < cursor_query)
|
||||
|
||||
# Add ordering if not already ordered by similarity
|
||||
if not embed_query:
|
||||
@@ -588,7 +575,7 @@ class AgentManager:
|
||||
embed_query: bool = False,
|
||||
ascending: bool = True,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
agent_only: bool = False
|
||||
agent_only: bool = False,
|
||||
) -> List[PydanticPassage]:
|
||||
"""Lists all passages attached to an agent."""
|
||||
with self.session_maker() as session:
|
||||
@@ -617,19 +604,18 @@ class AgentManager:
|
||||
passages = []
|
||||
for row in results:
|
||||
data = dict(row._mapping)
|
||||
if data['agent_id'] is not None:
|
||||
if data["agent_id"] is not None:
|
||||
# This is an AgentPassage - remove source fields
|
||||
data.pop('source_id', None)
|
||||
data.pop('file_id', None)
|
||||
data.pop("source_id", None)
|
||||
data.pop("file_id", None)
|
||||
passage = AgentPassage(**data)
|
||||
else:
|
||||
# This is a SourcePassage - remove agent field
|
||||
data.pop('agent_id', None)
|
||||
data.pop("agent_id", None)
|
||||
passage = SourcePassage(**data)
|
||||
passages.append(passage)
|
||||
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
@enforce_types
|
||||
def passage_size(
|
||||
@@ -645,7 +631,7 @@ class AgentManager:
|
||||
embed_query: bool = False,
|
||||
ascending: bool = True,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
agent_only: bool = False
|
||||
agent_only: bool = False,
|
||||
) -> int:
|
||||
"""Returns the count of passages matching the given criteria."""
|
||||
with self.session_maker() as session:
|
||||
@@ -663,7 +649,7 @@ class AgentManager:
|
||||
embedding_config=embedding_config,
|
||||
agent_only=agent_only,
|
||||
)
|
||||
|
||||
|
||||
# Convert to count query
|
||||
count_query = select(func.count()).select_from(main_query.subquery())
|
||||
return session.scalar(count_query) or 0
|
||||
|
||||
Reference in New Issue
Block a user