chore: add back test_server.py (#5783)

This commit is contained in:
Sarah Wooders
2025-10-28 16:24:47 -07:00
committed by Caren Thomas
parent 3f78c93be5
commit a566900533
5 changed files with 212 additions and 800 deletions

View File

@@ -911,6 +911,10 @@ class LettaAgentV2(BaseAgentV2):
) )
messages_to_persist = (initial_messages or []) + tool_call_messages messages_to_persist = (initial_messages or []) + tool_call_messages
for message in messages_to_persist:
message.step_id = step_id
message.run_id = run_id
persisted_messages = await self.message_manager.create_many_messages_async( persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, messages_to_persist,
actor=self.actor, actor=self.actor,
@@ -1028,6 +1032,10 @@ class LettaAgentV2(BaseAgentV2):
) )
messages_to_persist = (initial_messages or []) + tool_call_messages messages_to_persist = (initial_messages or []) + tool_call_messages
for message in messages_to_persist:
message.step_id = step_id
message.run_id = run_id
persisted_messages = await self.message_manager.create_many_messages_async( persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id
) )

View File

@@ -673,6 +673,8 @@ class LettaAgentV3(LettaAgentV2):
for message in messages_to_persist: for message in messages_to_persist:
if message.run_id is None: if message.run_id is None:
message.run_id = run_id message.run_id = run_id
if message.step_id is None:
message.step_id = step_id
persisted_messages = await self.message_manager.create_many_messages_async( persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id
@@ -699,6 +701,8 @@ class LettaAgentV3(LettaAgentV2):
for message in messages_to_persist: for message in messages_to_persist:
if message.run_id is None: if message.run_id is None:
message.run_id = run_id message.run_id = run_id
if message.step_id is None:
message.step_id = step_id
persisted_messages = await self.message_manager.create_many_messages_async( persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, messages_to_persist,
@@ -909,10 +913,12 @@ class LettaAgentV3(LettaAgentV2):
messages_to_persist: list[Message] = (initial_messages or []) + parallel_messages messages_to_persist: list[Message] = (initial_messages or []) + parallel_messages
# Set run_id on all messages before persisting # Set run_id and step_id on all messages before persisting
for message in messages_to_persist: for message in messages_to_persist:
if message.run_id is None: if message.run_id is None:
message.run_id = run_id message.run_id = run_id
if message.step_id is None:
message.step_id = step_id
# Persist all messages # Persist all messages
persisted_messages = await self.message_manager.create_many_messages_async( persisted_messages = await self.message_manager.create_many_messages_async(

View File

@@ -19,6 +19,7 @@ if TYPE_CHECKING:
from letta.orm.passage import ArchivalPassage, SourcePassage from letta.orm.passage import ArchivalPassage, SourcePassage
from letta.orm.passage_tag import PassageTag from letta.orm.passage_tag import PassageTag
from letta.orm.provider import Provider from letta.orm.provider import Provider
from letta.orm.provider_trace import ProviderTrace
from letta.orm.run import Run from letta.orm.run import Run
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
from letta.orm.tool import Tool from letta.orm.tool import Tool
@@ -70,3 +71,6 @@ class Organization(SqlalchemyBase):
) )
jobs: Mapped[List["Job"]] = relationship("Job", back_populates="organization", cascade="all, delete-orphan") jobs: Mapped[List["Job"]] = relationship("Job", back_populates="organization", cascade="all, delete-orphan")
runs: Mapped[List["Run"]] = relationship("Run", back_populates="organization", cascade="all, delete-orphan") runs: Mapped[List["Run"]] = relationship("Run", back_populates="organization", cascade="all, delete-orphan")
provider_traces: Mapped[List["ProviderTrace"]] = relationship(
"ProviderTrace", back_populates="organization", cascade="all, delete-orphan"
)

View File

@@ -434,7 +434,6 @@ class SyncServer(object):
assert request.llm_config.handle == request.model, ( assert request.llm_config.handle == request.model, (
f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}" f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}"
) )
print("GOT LLM CONFIG", request.llm_config)
if request.reasoning is None: if request.reasoning is None:
request.reasoning = request.llm_config.enable_reasoner or request.llm_config.put_inner_thoughts_in_kwargs request.reasoning = request.llm_config.enable_reasoner or request.llm_config.put_inner_thoughts_in_kwargs
@@ -1039,7 +1038,9 @@ class SyncServer(object):
"""String match the `handle` to the available configs""" """String match the `handle` to the available configs"""
matched_llm_config = None matched_llm_config = None
available_handles = [] available_handles = []
for provider in self._enabled_providers: # Get all enabled providers (including BYOK providers from database)
providers = await self.get_enabled_providers_async(actor=actor)
for provider in providers:
llm_configs = await provider.list_llm_models_async() llm_configs = await provider.list_llm_models_async()
for llm_config in llm_configs: for llm_config in llm_configs:
available_handles.append(llm_config.handle) available_handles.append(llm_config.handle)
@@ -1081,7 +1082,9 @@ class SyncServer(object):
) -> EmbeddingConfig: ) -> EmbeddingConfig:
matched_embedding_config = None matched_embedding_config = None
available_handles = [] available_handles = []
for provider in self._enabled_providers: # Get all enabled providers (including BYOK providers from database)
providers = await self.get_enabled_providers_async(actor=actor)
for provider in providers:
embedding_configs = await provider.list_embedding_models_async() embedding_configs = await provider.list_embedding_models_async()
for embedding_config in embedding_configs: for embedding_config in embedding_configs:
available_handles.append(embedding_config.handle) available_handles.append(embedding_config.handle)

File diff suppressed because it is too large Load Diff