feat: Asyncify insert archival memories (#2430)
Co-authored-by: Caren Thomas <carenthomas@gmail.com>
This commit is contained in:
@@ -26,7 +26,7 @@ class OpenAIStreamingInterface:
|
||||
self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG
|
||||
|
||||
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
|
||||
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=True) # TODO: pass in kward
|
||||
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=True) # TODO: pass in kwarg
|
||||
self.function_name_buffer = None
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
|
||||
@@ -517,18 +517,18 @@ async def list_passages(
|
||||
|
||||
|
||||
@router.post("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="create_passage")
|
||||
def create_passage(
|
||||
async def create_passage(
|
||||
agent_id: str,
|
||||
request: CreateArchivalMemory = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Insert a memory into an agent's archival memory store.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
return server.insert_archival_memory(agent_id=agent_id, memory_contents=request.text, actor=actor)
|
||||
return await server.insert_archival_memory_async(agent_id=agent_id, memory_contents=request.text, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/archival-memory/{memory_id}", response_model=List[Passage], operation_id="modify_passage")
|
||||
|
||||
@@ -1128,6 +1128,20 @@ class SyncServer(Server):
|
||||
|
||||
return passages
|
||||
|
||||
async def insert_archival_memory_async(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
|
||||
# Get the agent object (loaded in memory)
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
# Insert into archival memory
|
||||
# TODO: @mindy look at moving this to agent_manager to avoid above extra call
|
||||
passages = await self.passage_manager.insert_passage_async(
|
||||
agent_state=agent_state, agent_id=agent_id, text=memory_contents, actor=actor
|
||||
)
|
||||
|
||||
# rebuild agent system prompt - force since no archival change
|
||||
await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True)
|
||||
|
||||
return passages
|
||||
|
||||
def modify_archival_memory(self, agent_id: str, memory_id: str, passage: PassageUpdate, actor: User) -> List[Passage]:
|
||||
passage = Passage(**passage.model_dump(exclude_unset=True, exclude_none=True))
|
||||
passages = self.passage_manager.update_passage_by_id(passage_id=memory_id, passage=passage, actor=actor)
|
||||
|
||||
@@ -2,7 +2,8 @@ from datetime import datetime, timezone
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
from async_lru import alru_cache
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.embeddings import embedding_model, parse_and_chunk_text
|
||||
@@ -26,6 +27,16 @@ def get_openai_embedding(text: str, model: str, endpoint: str) -> List[float]:
|
||||
return response.data[0].embedding
|
||||
|
||||
|
||||
# TODO: Add redis-backed caching for backend
|
||||
@alru_cache(maxsize=8192)
|
||||
async def get_openai_embedding_async(text: str, model: str, endpoint: str) -> List[float]:
|
||||
from letta.settings import model_settings
|
||||
|
||||
client = AsyncOpenAI(api_key=model_settings.openai_api_key, base_url=endpoint, max_retries=0)
|
||||
response = await client.embeddings.create(input=text, model=model)
|
||||
return response.data[0].embedding
|
||||
|
||||
|
||||
class PassageManager:
|
||||
"""Manager class to handle business logic related to Passages."""
|
||||
|
||||
@@ -83,6 +94,43 @@ class PassageManager:
|
||||
passage.create(session, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
||||
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
|
||||
# Common fields for both passage types
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
"embedding": data["embedding"],
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
|
||||
if "agent_id" in data and data["agent_id"]:
|
||||
assert not data.get("source_id"), "Passage cannot have both agent_id and source_id"
|
||||
agent_fields = {
|
||||
"agent_id": data["agent_id"],
|
||||
}
|
||||
passage = AgentPassage(**common_fields, **agent_fields)
|
||||
elif "source_id" in data and data["source_id"]:
|
||||
assert not data.get("agent_id"), "Passage cannot have both agent_id and source_id"
|
||||
source_fields = {
|
||||
"source_id": data["source_id"],
|
||||
"file_id": data.get("file_id"),
|
||||
}
|
||||
passage = SourcePassage(**common_fields, **source_fields)
|
||||
else:
|
||||
raise ValueError("Passage must have either agent_id or source_id")
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
passage = await passage.create_async(session, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_many_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
|
||||
@@ -148,6 +196,65 @@ class PassageManager:
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def insert_passage_async(
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
agent_id: str,
|
||||
text: str,
|
||||
actor: PydanticUser,
|
||||
) -> List[PydanticPassage]:
|
||||
"""Insert passage(s) into archival memory"""
|
||||
|
||||
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
|
||||
|
||||
# TODO eventually migrate off of llama-index for embeddings?
|
||||
# Already causing pain for OpenAI proxy endpoints like LM Studio...
|
||||
if agent_state.embedding_config.embedding_endpoint_type != "openai":
|
||||
embed_model = embedding_model(agent_state.embedding_config)
|
||||
|
||||
passages = []
|
||||
|
||||
try:
|
||||
# breakup string into passages
|
||||
for text in parse_and_chunk_text(text, embedding_chunk_size):
|
||||
|
||||
if agent_state.embedding_config.embedding_endpoint_type != "openai":
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
else:
|
||||
# TODO should have the settings passed in via the server call
|
||||
embedding = await get_openai_embedding_async(
|
||||
text,
|
||||
agent_state.embedding_config.embedding_model,
|
||||
agent_state.embedding_config.embedding_endpoint,
|
||||
)
|
||||
|
||||
if isinstance(embedding, dict):
|
||||
try:
|
||||
embedding = embedding["data"][0]["embedding"]
|
||||
except (KeyError, IndexError):
|
||||
# TODO as a fallback, see if we can find any lists in the payload
|
||||
raise TypeError(
|
||||
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
|
||||
)
|
||||
passage = await self.create_passage_async(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
text=text,
|
||||
embedding=embedding,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
passages.append(passage)
|
||||
|
||||
return passages
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]:
|
||||
|
||||
@@ -230,8 +230,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
# TODO: convert this to async
|
||||
PassageManager().insert_passage(
|
||||
await PassageManager().insert_passage_async(
|
||||
agent_state=agent_state,
|
||||
agent_id=agent_state.id,
|
||||
text=content,
|
||||
|
||||
530
poetry.lock
generated
530
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -92,6 +92,7 @@ aiomultiprocess = "^0.9.1"
|
||||
matplotlib = "^3.10.1"
|
||||
asyncpg = {version = "^0.30.0", optional = true}
|
||||
tavily-python = "^0.7.2"
|
||||
async-lru = "^2.0.5"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
||||
@@ -508,7 +508,7 @@ def server():
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.mark.asyncio
|
||||
async def agent_passages_setup(server, default_source, default_user, sarah_agent):
|
||||
async def agent_passages_setup(server, default_source, default_user, sarah_agent, event_loop):
|
||||
"""Setup fixture for agent passages tests"""
|
||||
agent_id = sarah_agent.id
|
||||
actor = default_user
|
||||
@@ -518,7 +518,7 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent
|
||||
# Create some source passages
|
||||
source_passages = []
|
||||
for i in range(3):
|
||||
passage = server.passage_manager.create_passage(
|
||||
passage = await server.passage_manager.create_passage_async(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
source_id=default_source.id,
|
||||
@@ -533,7 +533,7 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent
|
||||
# Create some agent passages
|
||||
agent_passages = []
|
||||
for i in range(2):
|
||||
passage = server.passage_manager.create_passage(
|
||||
passage = await server.passage_manager.create_passage_async(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
@@ -1948,7 +1948,7 @@ async def test_agent_list_passages_vector_search(server, default_user, sarah_age
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding,
|
||||
)
|
||||
created_passage = server.passage_manager.create_passage(passage, default_user)
|
||||
created_passage = await server.passage_manager.create_passage_async(passage, default_user)
|
||||
passages.append(created_passage)
|
||||
|
||||
# Query vector similar to "red" embedding
|
||||
@@ -2097,14 +2097,15 @@ def test_passage_create_source(server: SyncServer, source_passage_fixture, defau
|
||||
assert retrieved.text == source_passage_fixture.text
|
||||
|
||||
|
||||
def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, default_user, event_loop):
|
||||
"""Test creating an agent passage."""
|
||||
assert agent_passage_fixture is not None
|
||||
assert agent_passage_fixture.text == "Hello, I am an agent passage"
|
||||
|
||||
# Try to create an invalid passage (with both agent_id and source_id)
|
||||
with pytest.raises(AssertionError):
|
||||
server.passage_manager.create_passage(
|
||||
await server.passage_manager.create_passage_async(
|
||||
PydanticPassage(
|
||||
text="Invalid passage",
|
||||
agent_id="123",
|
||||
|
||||
Reference in New Issue
Block a user