chore: add back test_server.py (#5783)
This commit is contained in:
committed by
Caren Thomas
parent
3f78c93be5
commit
a566900533
@@ -911,6 +911,10 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
)
|
)
|
||||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||||
|
|
||||||
|
for message in messages_to_persist:
|
||||||
|
message.step_id = step_id
|
||||||
|
message.run_id = run_id
|
||||||
|
|
||||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||||
messages_to_persist,
|
messages_to_persist,
|
||||||
actor=self.actor,
|
actor=self.actor,
|
||||||
@@ -1028,6 +1032,10 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
)
|
)
|
||||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||||
|
|
||||||
|
for message in messages_to_persist:
|
||||||
|
message.step_id = step_id
|
||||||
|
message.run_id = run_id
|
||||||
|
|
||||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||||
messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id
|
messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -673,6 +673,8 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
for message in messages_to_persist:
|
for message in messages_to_persist:
|
||||||
if message.run_id is None:
|
if message.run_id is None:
|
||||||
message.run_id = run_id
|
message.run_id = run_id
|
||||||
|
if message.step_id is None:
|
||||||
|
message.step_id = step_id
|
||||||
|
|
||||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||||
messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id
|
messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id
|
||||||
@@ -699,6 +701,8 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
for message in messages_to_persist:
|
for message in messages_to_persist:
|
||||||
if message.run_id is None:
|
if message.run_id is None:
|
||||||
message.run_id = run_id
|
message.run_id = run_id
|
||||||
|
if message.step_id is None:
|
||||||
|
message.step_id = step_id
|
||||||
|
|
||||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||||
messages_to_persist,
|
messages_to_persist,
|
||||||
@@ -909,10 +913,12 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
|
|
||||||
messages_to_persist: list[Message] = (initial_messages or []) + parallel_messages
|
messages_to_persist: list[Message] = (initial_messages or []) + parallel_messages
|
||||||
|
|
||||||
# Set run_id on all messages before persisting
|
# Set run_id and step_id on all messages before persisting
|
||||||
for message in messages_to_persist:
|
for message in messages_to_persist:
|
||||||
if message.run_id is None:
|
if message.run_id is None:
|
||||||
message.run_id = run_id
|
message.run_id = run_id
|
||||||
|
if message.step_id is None:
|
||||||
|
message.step_id = step_id
|
||||||
|
|
||||||
# Persist all messages
|
# Persist all messages
|
||||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
|
|||||||
from letta.orm.passage import ArchivalPassage, SourcePassage
|
from letta.orm.passage import ArchivalPassage, SourcePassage
|
||||||
from letta.orm.passage_tag import PassageTag
|
from letta.orm.passage_tag import PassageTag
|
||||||
from letta.orm.provider import Provider
|
from letta.orm.provider import Provider
|
||||||
|
from letta.orm.provider_trace import ProviderTrace
|
||||||
from letta.orm.run import Run
|
from letta.orm.run import Run
|
||||||
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
||||||
from letta.orm.tool import Tool
|
from letta.orm.tool import Tool
|
||||||
@@ -70,3 +71,6 @@ class Organization(SqlalchemyBase):
|
|||||||
)
|
)
|
||||||
jobs: Mapped[List["Job"]] = relationship("Job", back_populates="organization", cascade="all, delete-orphan")
|
jobs: Mapped[List["Job"]] = relationship("Job", back_populates="organization", cascade="all, delete-orphan")
|
||||||
runs: Mapped[List["Run"]] = relationship("Run", back_populates="organization", cascade="all, delete-orphan")
|
runs: Mapped[List["Run"]] = relationship("Run", back_populates="organization", cascade="all, delete-orphan")
|
||||||
|
provider_traces: Mapped[List["ProviderTrace"]] = relationship(
|
||||||
|
"ProviderTrace", back_populates="organization", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|||||||
@@ -434,7 +434,6 @@ class SyncServer(object):
|
|||||||
assert request.llm_config.handle == request.model, (
|
assert request.llm_config.handle == request.model, (
|
||||||
f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}"
|
f"LLM config handle {request.llm_config.handle} does not match request handle {request.model}"
|
||||||
)
|
)
|
||||||
print("GOT LLM CONFIG", request.llm_config)
|
|
||||||
|
|
||||||
if request.reasoning is None:
|
if request.reasoning is None:
|
||||||
request.reasoning = request.llm_config.enable_reasoner or request.llm_config.put_inner_thoughts_in_kwargs
|
request.reasoning = request.llm_config.enable_reasoner or request.llm_config.put_inner_thoughts_in_kwargs
|
||||||
@@ -1039,7 +1038,9 @@ class SyncServer(object):
|
|||||||
"""String match the `handle` to the available configs"""
|
"""String match the `handle` to the available configs"""
|
||||||
matched_llm_config = None
|
matched_llm_config = None
|
||||||
available_handles = []
|
available_handles = []
|
||||||
for provider in self._enabled_providers:
|
# Get all enabled providers (including BYOK providers from database)
|
||||||
|
providers = await self.get_enabled_providers_async(actor=actor)
|
||||||
|
for provider in providers:
|
||||||
llm_configs = await provider.list_llm_models_async()
|
llm_configs = await provider.list_llm_models_async()
|
||||||
for llm_config in llm_configs:
|
for llm_config in llm_configs:
|
||||||
available_handles.append(llm_config.handle)
|
available_handles.append(llm_config.handle)
|
||||||
@@ -1081,7 +1082,9 @@ class SyncServer(object):
|
|||||||
) -> EmbeddingConfig:
|
) -> EmbeddingConfig:
|
||||||
matched_embedding_config = None
|
matched_embedding_config = None
|
||||||
available_handles = []
|
available_handles = []
|
||||||
for provider in self._enabled_providers:
|
# Get all enabled providers (including BYOK providers from database)
|
||||||
|
providers = await self.get_enabled_providers_async(actor=actor)
|
||||||
|
for provider in providers:
|
||||||
embedding_configs = await provider.list_embedding_models_async()
|
embedding_configs = await provider.list_embedding_models_async()
|
||||||
for embedding_config in embedding_configs:
|
for embedding_config in embedding_configs:
|
||||||
available_handles.append(embedding_config.handle)
|
available_handles.append(embedding_config.handle)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user