chore: Clean up upserting base tools (#2274)

This commit is contained in:
Matthew Zhou
2024-12-18 14:33:29 -08:00
committed by GitHub
parent 8644f2016a
commit b1ce8b4e8a
10 changed files with 113 additions and 228 deletions

View File

@@ -1,18 +1,18 @@
from typing import Dict, List, Optional
from datetime import datetime
import numpy as np
from typing import Dict, List, Optional
from sqlalchemy import select, union_all, literal, func, Select
import numpy as np
from sqlalchemy import Select, func, literal, select, union_all
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model
from letta.log import get_logger
from letta.orm import Agent as AgentModel
from letta.orm import AgentPassage
from letta.orm import Block as BlockModel
from letta.orm import Source as SourceModel
from letta.orm import SourcePassage, SourcesAgents
from letta.orm import Tool as ToolModel
from letta.orm import AgentPassage, SourcePassage
from letta.orm import SourcesAgents
from letta.orm.errors import NoResultFound
from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState
@@ -77,6 +77,8 @@ class AgentManager:
tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
if agent_create.tools:
tool_names.extend(agent_create.tools)
# Remove duplicates
tool_names = list(set(tool_names))
tool_ids = agent_create.tool_ids or []
for tool_name in tool_names:
@@ -431,7 +433,7 @@ class AgentManager:
agent_only: bool = False,
) -> Select:
"""Helper function to build the base passage query with all filters applied.
Returns the query before any limit or count operations are applied.
"""
embedded_text = None
@@ -448,21 +450,14 @@ class AgentManager:
if not agent_only: # Include source passages
if agent_id is not None:
source_passages = (
select(
SourcePassage,
literal(None).label('agent_id')
)
select(SourcePassage, literal(None).label("agent_id"))
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
.where(SourcesAgents.agent_id == agent_id)
.where(SourcePassage.organization_id == actor.organization_id)
)
else:
source_passages = (
select(
SourcePassage,
literal(None).label('agent_id')
)
.where(SourcePassage.organization_id == actor.organization_id)
source_passages = select(SourcePassage, literal(None).label("agent_id")).where(
SourcePassage.organization_id == actor.organization_id
)
if source_id:
@@ -486,9 +481,9 @@ class AgentManager:
AgentPassage._created_by_id,
AgentPassage._last_updated_by_id,
AgentPassage.organization_id,
literal(None).label('file_id'),
literal(None).label('source_id'),
AgentPassage.agent_id
literal(None).label("file_id"),
literal(None).label("source_id"),
AgentPassage.agent_id,
)
.where(AgentPassage.agent_id == agent_id)
.where(AgentPassage.organization_id == actor.organization_id)
@@ -496,11 +491,11 @@ class AgentManager:
# Combine queries
if source_passages is not None and agent_passages is not None:
combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
combined_query = union_all(source_passages, agent_passages).cte("combined_passages")
elif agent_passages is not None:
combined_query = agent_passages.cte('combined_passages')
combined_query = agent_passages.cte("combined_passages")
elif source_passages is not None:
combined_query = source_passages.cte('combined_passages')
combined_query = source_passages.cte("combined_passages")
else:
raise ValueError("No passages found")
@@ -521,9 +516,7 @@ class AgentManager:
if embedded_text:
if settings.letta_pg_uri_no_default:
# PostgreSQL with pgvector
main_query = main_query.order_by(
combined_query.c.embedding.cosine_distance(embedded_text).asc()
)
main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc())
else:
# SQLite with custom vector type
query_embedding_binary = adapt_array(embedded_text)
@@ -531,13 +524,13 @@ class AgentManager:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.asc(),
combined_query.c.id.asc()
combined_query.c.id.asc(),
)
else:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.desc(),
combined_query.c.id.asc()
combined_query.c.id.asc(),
)
else:
if query_text:
@@ -545,18 +538,12 @@ class AgentManager:
# Handle cursor-based pagination
if cursor:
cursor_query = select(combined_query.c.created_at).where(
combined_query.c.id == cursor
).scalar_subquery()
cursor_query = select(combined_query.c.created_at).where(combined_query.c.id == cursor).scalar_subquery()
if ascending:
main_query = main_query.where(
combined_query.c.created_at > cursor_query
)
main_query = main_query.where(combined_query.c.created_at > cursor_query)
else:
main_query = main_query.where(
combined_query.c.created_at < cursor_query
)
main_query = main_query.where(combined_query.c.created_at < cursor_query)
# Add ordering if not already ordered by similarity
if not embed_query:
@@ -588,7 +575,7 @@ class AgentManager:
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False
agent_only: bool = False,
) -> List[PydanticPassage]:
"""Lists all passages attached to an agent."""
with self.session_maker() as session:
@@ -617,19 +604,18 @@ class AgentManager:
passages = []
for row in results:
data = dict(row._mapping)
if data['agent_id'] is not None:
if data["agent_id"] is not None:
# This is an AgentPassage - remove source fields
data.pop('source_id', None)
data.pop('file_id', None)
data.pop("source_id", None)
data.pop("file_id", None)
passage = AgentPassage(**data)
else:
# This is a SourcePassage - remove agent field
data.pop('agent_id', None)
data.pop("agent_id", None)
passage = SourcePassage(**data)
passages.append(passage)
return [p.to_pydantic() for p in passages]
return [p.to_pydantic() for p in passages]
@enforce_types
def passage_size(
@@ -645,7 +631,7 @@ class AgentManager:
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False
agent_only: bool = False,
) -> int:
"""Returns the count of passages matching the given criteria."""
with self.session_maker() as session:
@@ -663,7 +649,7 @@ class AgentManager:
embedding_config=embedding_config,
agent_only=agent_only,
)
# Convert to count query
count_query = select(func.count()).select_from(main_query.subquery())
return session.scalar(count_query) or 0