feat: Asyncify insert archival memories (#2430)

Co-authored-by: Caren Thomas <carenthomas@gmail.com>
This commit is contained in:
Matthew Zhou
2025-05-25 22:28:35 -07:00
committed by GitHub
parent 790d00ac75
commit 8e9307c289
8 changed files with 205 additions and 475 deletions

View File

@@ -26,7 +26,7 @@ class OpenAIStreamingInterface:
self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=True) # TODO: pass in kward
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=True) # TODO: pass in kwarg
self.function_name_buffer = None
self.function_args_buffer = None
self.function_id_buffer = None

View File

@@ -517,18 +517,18 @@ async def list_passages(
@router.post("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="create_passage")
def create_passage(
async def create_passage(
agent_id: str,
request: CreateArchivalMemory = Body(...),
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Insert a memory into an agent's archival memory store.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return server.insert_archival_memory(agent_id=agent_id, memory_contents=request.text, actor=actor)
return await server.insert_archival_memory_async(agent_id=agent_id, memory_contents=request.text, actor=actor)
@router.patch("/{agent_id}/archival-memory/{memory_id}", response_model=List[Passage], operation_id="modify_passage")

View File

@@ -1128,6 +1128,20 @@ class SyncServer(Server):
return passages
async def insert_archival_memory_async(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
    """Insert `memory_contents` into the agent's archival memory and refresh its system prompt.

    Args:
        agent_id: ID of the agent receiving the archival memory.
        memory_contents: Raw text to store (may be chunked into several passages downstream).
        actor: User performing the insert.

    Returns:
        The list of passages created from `memory_contents`.
    """
    # Load the agent state (passage manager needs its embedding config).
    agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)

    # TODO: @mindy look at moving this to agent_manager to avoid above extra call
    inserted_passages = await self.passage_manager.insert_passage_async(
        agent_state=agent_state,
        agent_id=agent_id,
        text=memory_contents,
        actor=actor,
    )

    # Rebuild the agent's system prompt with force=True so the refresh is unconditional.
    await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
    return inserted_passages
def modify_archival_memory(self, agent_id: str, memory_id: str, passage: PassageUpdate, actor: User) -> List[Passage]:
passage = Passage(**passage.model_dump(exclude_unset=True, exclude_none=True))
passages = self.passage_manager.update_passage_by_id(passage_id=memory_id, passage=passage, actor=actor)

View File

@@ -2,7 +2,8 @@ from datetime import datetime, timezone
from functools import lru_cache
from typing import List, Optional
from openai import OpenAI
from async_lru import alru_cache
from openai import AsyncOpenAI, OpenAI
from letta.constants import MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model, parse_and_chunk_text
@@ -26,6 +27,16 @@ def get_openai_embedding(text: str, model: str, endpoint: str) -> List[float]:
return response.data[0].embedding
# TODO: Add redis-backed caching for backend
@alru_cache(maxsize=8192)
async def get_openai_embedding_async(text: str, model: str, endpoint: str) -> List[float]:
    """Embed `text` with OpenAI's async client; results are LRU-cached per (text, model, endpoint)."""
    from letta.settings import model_settings

    # max_retries=0: no client-side retries; errors surface immediately to the caller
    openai_client = AsyncOpenAI(api_key=model_settings.openai_api_key, base_url=endpoint, max_retries=0)
    result = await openai_client.embeddings.create(input=text, model=model)
    return result.data[0].embedding
class PassageManager:
"""Manager class to handle business logic related to Passages."""
@@ -83,6 +94,43 @@ class PassageManager:
passage.create(session, actor=actor)
return passage.to_pydantic()
@enforce_types
@trace_method
async def create_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
    """Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
    data = pydantic_passage.model_dump(to_orm=True)

    # Fields shared by both AgentPassage and SourcePassage rows.
    base_kwargs = dict(
        id=data.get("id"),
        text=data["text"],
        embedding=data["embedding"],
        embedding_config=data["embedding_config"],
        organization_id=data["organization_id"],
        metadata_=data.get("metadata", {}),
        is_deleted=data.get("is_deleted", False),
        created_at=data.get("created_at", datetime.now(timezone.utc)),
    )

    agent_id = data.get("agent_id")
    source_id = data.get("source_id")
    if agent_id:
        # A passage is owned by exactly one of agent/source.
        assert not source_id, "Passage cannot have both agent_id and source_id"
        orm_passage = AgentPassage(agent_id=agent_id, **base_kwargs)
    elif source_id:
        assert not agent_id, "Passage cannot have both agent_id and source_id"
        orm_passage = SourcePassage(source_id=source_id, file_id=data.get("file_id"), **base_kwargs)
    else:
        raise ValueError("Passage must have either agent_id or source_id")

    async with db_registry.async_session() as session:
        created = await orm_passage.create_async(session, actor=actor)
        return created.to_pydantic()
@enforce_types
@trace_method
def create_many_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
@@ -148,6 +196,65 @@ class PassageManager:
except Exception as e:
raise e
@enforce_types
@trace_method
async def insert_passage_async(
    self,
    agent_state: AgentState,
    agent_id: str,
    text: str,
    actor: PydanticUser,
) -> List[PydanticPassage]:
    """Insert passage(s) into archival memory.

    Splits `text` into chunks per the agent's embedding config, embeds each
    chunk (async OpenAI path or a local embedding model), and persists one
    passage per chunk.

    Args:
        agent_state: Agent whose embedding_config drives chunking and embedding.
        agent_id: ID of the agent the passages belong to.
        text: Raw text to split into passage chunks.
        actor: User performing the insert (supplies organization_id).

    Returns:
        The created passages, one per chunk.

    Raises:
        TypeError: If the embedding backend returns a dict without the
            expected `data[0]["embedding"]` payload.
    """
    embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size

    # TODO eventually migrate off of llama-index for embeddings?
    # Already causing pain for OpenAI proxy endpoints like LM Studio...
    use_openai = agent_state.embedding_config.embedding_endpoint_type == "openai"
    if not use_openai:
        embed_model = embedding_model(agent_state.embedding_config)

    passages = []
    # Loop variable renamed from `text` to `chunk`: the original shadowed the
    # `text` parameter, which made the code easy to misread.
    for chunk in parse_and_chunk_text(text, embedding_chunk_size):
        if use_openai:
            # TODO should have the settings passed in via the server call
            embedding = await get_openai_embedding_async(
                chunk,
                agent_state.embedding_config.embedding_model,
                agent_state.embedding_config.embedding_endpoint,
            )
        else:
            embedding = embed_model.get_text_embedding(chunk)

        # Some backends return the raw response dict instead of a vector; unwrap it.
        if isinstance(embedding, dict):
            try:
                embedding = embedding["data"][0]["embedding"]
            except (KeyError, IndexError):
                # TODO as a fallback, see if we can find any lists in the payload
                raise TypeError(
                    f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
                )

        passage = await self.create_passage_async(
            PydanticPassage(
                organization_id=actor.organization_id,
                agent_id=agent_id,
                text=chunk,
                embedding=embedding,
                embedding_config=agent_state.embedding_config,
            ),
            actor=actor,
        )
        passages.append(passage)

    # NOTE: dropped the original no-op `try: ... except Exception as e: raise e`
    # wrapper — it added nothing and re-raising reset the traceback origin.
    return passages
@enforce_types
@trace_method
def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]:

View File

@@ -230,8 +230,7 @@ class LettaCoreToolExecutor(ToolExecutor):
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
# TODO: convert this to async
PassageManager().insert_passage(
await PassageManager().insert_passage_async(
agent_state=agent_state,
agent_id=agent_state.id,
text=content,

530
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -92,6 +92,7 @@ aiomultiprocess = "^0.9.1"
matplotlib = "^3.10.1"
asyncpg = {version = "^0.30.0", optional = true}
tavily-python = "^0.7.2"
async-lru = "^2.0.5"
[tool.poetry.extras]

View File

@@ -508,7 +508,7 @@ def server():
@pytest.fixture
@pytest.mark.asyncio
async def agent_passages_setup(server, default_source, default_user, sarah_agent):
async def agent_passages_setup(server, default_source, default_user, sarah_agent, event_loop):
"""Setup fixture for agent passages tests"""
agent_id = sarah_agent.id
actor = default_user
@@ -518,7 +518,7 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent
# Create some source passages
source_passages = []
for i in range(3):
passage = server.passage_manager.create_passage(
passage = await server.passage_manager.create_passage_async(
PydanticPassage(
organization_id=actor.organization_id,
source_id=default_source.id,
@@ -533,7 +533,7 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent
# Create some agent passages
agent_passages = []
for i in range(2):
passage = server.passage_manager.create_passage(
passage = await server.passage_manager.create_passage_async(
PydanticPassage(
organization_id=actor.organization_id,
agent_id=agent_id,
@@ -1948,7 +1948,7 @@ async def test_agent_list_passages_vector_search(server, default_user, sarah_age
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embedding=embedding,
)
created_passage = server.passage_manager.create_passage(passage, default_user)
created_passage = await server.passage_manager.create_passage_async(passage, default_user)
passages.append(created_passage)
# Query vector similar to "red" embedding
@@ -2097,14 +2097,15 @@ def test_passage_create_source(server: SyncServer, source_passage_fixture, defau
assert retrieved.text == source_passage_fixture.text
def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, default_user):
@pytest.mark.asyncio
async def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, default_user, event_loop):
"""Test creating an agent passage."""
assert agent_passage_fixture is not None
assert agent_passage_fixture.text == "Hello, I am an agent passage"
# Try to create an invalid passage (with both agent_id and source_id)
with pytest.raises(AssertionError):
server.passage_manager.create_passage(
await server.passage_manager.create_passage_async(
PydanticPassage(
text="Invalid passage",
agent_id="123",