feat: Create batch request tracking tables (#1604)
This commit is contained in:
86
alembic/versions/0ceb975e0063_add_llm_batch_jobs_tables.py
Normal file
86
alembic/versions/0ceb975e0063_add_llm_batch_jobs_tables.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Add LLM batch jobs tables
|
||||
|
||||
Revision ID: 0ceb975e0063
|
||||
Revises: 90bb156e71df
|
||||
Create Date: 2025-04-07 15:57:18.475151
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
import letta
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0ceb975e0063"
|
||||
down_revision: Union[str, None] = "90bb156e71df"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the batch-tracking tables: one parent row per provider batch job, one child row per agent request."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Parent table: one row per batch request submitted to an LLM provider.
    op.create_table(
        "llm_batch_job",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("status", sa.String(), nullable=False),
        sa.Column("llm_provider", sa.String(), nullable=False),
        # Raw provider responses, persisted through custom JSON column types.
        sa.Column("create_batch_response", letta.orm.custom_columns.CreateBatchResponseColumn(), nullable=False),
        sa.Column("latest_polling_response", letta.orm.custom_columns.PollBatchResponseColumn(), nullable=True),
        sa.Column("last_polled_at", sa.DateTime(timezone=True), nullable=True),
        # Common audit columns, mirroring the other letta tables.
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index("ix_llm_batch_job_created_at", "llm_batch_job", ["created_at"], unique=False)
    op.create_index("ix_llm_batch_job_status", "llm_batch_job", ["status"], unique=False)
    # Child table: one row per agent's request inside a batch; cascades on agent/batch deletion.
    op.create_table(
        "llm_batch_items",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("batch_id", sa.String(), nullable=False),
        sa.Column("llm_config", letta.orm.custom_columns.LLMConfigColumn(), nullable=False),
        sa.Column("request_status", sa.String(), nullable=False),
        sa.Column("step_status", sa.String(), nullable=False),
        sa.Column("step_state", letta.orm.custom_columns.AgentStepStateColumn(), nullable=False),
        sa.Column("batch_request_result", letta.orm.custom_columns.BatchRequestResultColumn(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.Column("agent_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["batch_id"], ["llm_batch_job.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index("ix_llm_batch_items_agent_id", "llm_batch_items", ["agent_id"], unique=False)
    op.create_index("ix_llm_batch_items_batch_id", "llm_batch_items", ["batch_id"], unique=False)
    op.create_index("ix_llm_batch_items_status", "llm_batch_items", ["request_status"], unique=False)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the batch tables in reverse dependency order (items before the parent job table)."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index("ix_llm_batch_items_status", table_name="llm_batch_items")
    op.drop_index("ix_llm_batch_items_batch_id", table_name="llm_batch_items")
    op.drop_index("ix_llm_batch_items_agent_id", table_name="llm_batch_items")
    op.drop_table("llm_batch_items")
    op.drop_index("ix_llm_batch_job_status", table_name="llm_batch_job")
    op.drop_index("ix_llm_batch_job_created_at", table_name="llm_batch_job")
    op.drop_table("llm_batch_job")
    # ### end Alembic commands ###
|
||||
@@ -2,12 +2,14 @@ import base64
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from sqlalchemy import Dialect
|
||||
|
||||
from letta.schemas.agent import AgentStepState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.enums import ProviderType, ToolRuleType
|
||||
from letta.schemas.letta_message_content import (
|
||||
MessageContent,
|
||||
MessageContentType,
|
||||
@@ -38,7 +40,7 @@ from letta.schemas.tool_rule import (
|
||||
def serialize_llm_config(config: Union[Optional[LLMConfig], Dict]) -> Optional[Dict]:
    """Convert an LLMConfig object into a JSON-serializable dictionary.

    Args:
        config: An LLMConfig instance, an already-serialized dict, or None.

    Returns:
        A JSON-safe dict, or the input unchanged when it is not an LLMConfig.
    """
    if config and isinstance(config, LLMConfig):
        # mode="json" converts nested values (enums, datetimes, ...) to JSON-safe primitives.
        return config.model_dump(mode="json")
    return config
|
||||
|
||||
|
||||
@@ -55,7 +57,7 @@ def deserialize_llm_config(data: Optional[Dict]) -> Optional[LLMConfig]:
|
||||
def serialize_embedding_config(config: Union[Optional[EmbeddingConfig], Dict]) -> Optional[Dict]:
    """Convert an EmbeddingConfig object into a JSON-serializable dictionary.

    Args:
        config: An EmbeddingConfig instance, an already-serialized dict, or None.

    Returns:
        A JSON-safe dict, or the input unchanged when it is not an EmbeddingConfig.
    """
    if config and isinstance(config, EmbeddingConfig):
        # mode="json" converts nested values (enums, datetimes, ...) to JSON-safe primitives.
        return config.model_dump(mode="json")
    return config
|
||||
|
||||
|
||||
@@ -75,7 +77,9 @@ def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str,
|
||||
if not tool_rules:
|
||||
return []
|
||||
|
||||
data = [{**rule.model_dump(), "type": rule.type.value} for rule in tool_rules] # Convert Enum to string for JSON compatibility
|
||||
data = [
|
||||
{**rule.model_dump(mode="json"), "type": rule.type.value} for rule in tool_rules
|
||||
] # Convert Enum to string for JSON compatibility
|
||||
|
||||
# Validate ToolRule structure
|
||||
for rule_data in data:
|
||||
@@ -130,7 +134,7 @@ def serialize_tool_calls(tool_calls: Optional[List[Union[OpenAIToolCall, dict]]]
|
||||
serialized_calls = []
|
||||
for call in tool_calls:
|
||||
if isinstance(call, OpenAIToolCall):
|
||||
serialized_calls.append(call.model_dump())
|
||||
serialized_calls.append(call.model_dump(mode="json"))
|
||||
elif isinstance(call, dict):
|
||||
serialized_calls.append(call) # Already a dictionary, leave it as-is
|
||||
else:
|
||||
@@ -166,7 +170,7 @@ def serialize_tool_returns(tool_returns: Optional[List[Union[ToolReturn, dict]]]
|
||||
serialized_tool_returns = []
|
||||
for tool_return in tool_returns:
|
||||
if isinstance(tool_return, ToolReturn):
|
||||
serialized_tool_returns.append(tool_return.model_dump())
|
||||
serialized_tool_returns.append(tool_return.model_dump(mode="json"))
|
||||
elif isinstance(tool_return, dict):
|
||||
serialized_tool_returns.append(tool_return) # Already a dictionary, leave it as-is
|
||||
else:
|
||||
@@ -201,7 +205,7 @@ def serialize_message_content(message_content: Optional[List[Union[MessageConten
|
||||
serialized_message_content = []
|
||||
for content in message_content:
|
||||
if isinstance(content, MessageContent):
|
||||
serialized_message_content.append(content.model_dump())
|
||||
serialized_message_content.append(content.model_dump(mode="json"))
|
||||
elif isinstance(content, dict):
|
||||
serialized_message_content.append(content) # Already a dictionary, leave it as-is
|
||||
else:
|
||||
@@ -266,3 +270,101 @@ def deserialize_vector(data: Optional[bytes], dialect: Dialect) -> Optional[np.n
|
||||
data = base64.b64decode(data)
|
||||
|
||||
return np.frombuffer(data, dtype=np.float32)
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Batch Request Serialization
|
||||
# --------------------------
|
||||
|
||||
|
||||
def serialize_create_batch_response(create_batch_response: BetaMessageBatch) -> Dict[str, Any]:
    """Serialize a provider batch-creation response into a provider-tagged, JSON-safe dict.

    The provider is inferred from the response object's concrete type and stored
    under "type" so deserialize_create_batch_response can round-trip the payload.

    Raises:
        ValueError: If the response type does not match any known provider.
    """
    llm_provider_type = None
    if isinstance(create_batch_response, BetaMessageBatch):
        llm_provider_type = ProviderType.anthropic.value

    if not llm_provider_type:
        raise ValueError(f"Could not determine llm provider from create batch response object type: {create_batch_response}")

    return {"data": create_batch_response.model_dump(mode="json"), "type": llm_provider_type}
|
||||
|
||||
|
||||
def deserialize_create_batch_response(data: Dict) -> BetaMessageBatch:
    """Rebuild a provider batch-creation response from its provider-tagged dict form.

    Raises:
        ValueError: If the stored provider tag is not recognized.
    """
    provider_type = ProviderType(data.get("type"))

    if provider_type == ProviderType.anthropic:
        return BetaMessageBatch(**data.get("data"))

    raise ValueError(f"Unknown ProviderType type: {provider_type}")
|
||||
|
||||
|
||||
# TODO: Note that this is the same as above for Anthropic, but this is not the case for all providers
|
||||
# TODO: Some have different types based on the create v.s. poll requests
|
||||
# TODO: Note that this is the same as above for Anthropic, but this is not the case for all providers
# TODO: Some have different types based on the create v.s. poll requests
def serialize_poll_batch_response(poll_batch_response: Optional[BetaMessageBatch]) -> Optional[Dict[str, Any]]:
    """Serialize a provider batch-polling response into a provider-tagged, JSON-safe dict.

    Returns None for a missing response; raises ValueError when the response
    type does not match any known provider.
    """
    if not poll_batch_response:
        return None

    llm_provider_type = None
    if isinstance(poll_batch_response, BetaMessageBatch):
        llm_provider_type = ProviderType.anthropic.value

    if not llm_provider_type:
        raise ValueError(f"Could not determine llm provider from poll batch response object type: {poll_batch_response}")

    return {"data": poll_batch_response.model_dump(mode="json"), "type": llm_provider_type}
|
||||
|
||||
|
||||
def deserialize_poll_batch_response(data: Optional[Dict]) -> Optional[BetaMessageBatch]:
    """Rebuild a provider batch-polling response from its provider-tagged dict form.

    Returns None for missing data; raises ValueError for an unrecognized provider tag.
    """
    if not data:
        return None

    provider_type = ProviderType(data.get("type"))

    if provider_type == ProviderType.anthropic:
        return BetaMessageBatch(**data.get("data"))

    raise ValueError(f"Unknown ProviderType type: {provider_type}")
|
||||
|
||||
|
||||
def serialize_batch_request_result(
    batch_individual_response: Optional[BetaMessageBatchIndividualResponse],
) -> Optional[Dict[str, Any]]:
    """Serialize a single per-item batch result into a provider-tagged, JSON-safe dict.

    Returns None for a missing result; raises ValueError when the result type
    does not match any known provider.
    """
    if not batch_individual_response:
        return None

    llm_provider_type = None
    if isinstance(batch_individual_response, BetaMessageBatchIndividualResponse):
        llm_provider_type = ProviderType.anthropic.value

    if not llm_provider_type:
        raise ValueError(f"Could not determine llm provider from batch result object type: {batch_individual_response}")

    return {"data": batch_individual_response.model_dump(mode="json"), "type": llm_provider_type}
|
||||
|
||||
|
||||
def deserialize_batch_request_result(data: Optional[Dict]) -> Optional[BetaMessageBatchIndividualResponse]:
    """Rebuild a per-item batch result from its provider-tagged dict form.

    Returns None for missing data; raises ValueError for an unrecognized provider tag.
    """
    if not data:
        return None

    provider_type = ProviderType(data.get("type"))

    if provider_type == ProviderType.anthropic:
        return BetaMessageBatchIndividualResponse(**data.get("data"))

    raise ValueError(f"Unknown ProviderType type: {provider_type}")
|
||||
|
||||
|
||||
def serialize_agent_step_state(agent_step_state: Optional[AgentStepState]) -> Optional[Dict[str, Any]]:
    """Convert an AgentStepState into a JSON-serializable dictionary (None passes through)."""
    if not agent_step_state:
        return None

    return agent_step_state.model_dump(mode="json")
|
||||
|
||||
|
||||
def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepState]:
    """Rebuild an AgentStepState from its serialized dict form (None/empty passes through)."""
    return AgentStepState(**data) if data else None
|
||||
|
||||
@@ -38,29 +38,46 @@ class ToolRulesSolver(BaseModel):
|
||||
)
|
||||
tool_call_history: List[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
|
||||
|
||||
def __init__(self, tool_rules: List[BaseToolRule], **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# Separate the provided tool rules into init, standard, and terminal categories
|
||||
for rule in tool_rules:
|
||||
if rule.type == ToolRuleType.run_first:
|
||||
assert isinstance(rule, InitToolRule)
|
||||
self.init_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.constrain_child_tools:
|
||||
assert isinstance(rule, ChildToolRule)
|
||||
self.child_based_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.conditional:
|
||||
assert isinstance(rule, ConditionalToolRule)
|
||||
self.validate_conditional_tool(rule)
|
||||
self.child_based_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.exit_loop:
|
||||
assert isinstance(rule, TerminalToolRule)
|
||||
self.terminal_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.continue_loop:
|
||||
assert isinstance(rule, ContinueToolRule)
|
||||
self.continue_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.max_count_per_step:
|
||||
assert isinstance(rule, MaxCountPerStepToolRule)
|
||||
self.child_based_tool_rules.append(rule)
|
||||
def __init__(
|
||||
self,
|
||||
tool_rules: Optional[List[BaseToolRule]] = None,
|
||||
init_tool_rules: Optional[List[InitToolRule]] = None,
|
||||
continue_tool_rules: Optional[List[ContinueToolRule]] = None,
|
||||
child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None,
|
||||
terminal_tool_rules: Optional[List[TerminalToolRule]] = None,
|
||||
tool_call_history: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
init_tool_rules=init_tool_rules or [],
|
||||
continue_tool_rules=continue_tool_rules or [],
|
||||
child_based_tool_rules=child_based_tool_rules or [],
|
||||
terminal_tool_rules=terminal_tool_rules or [],
|
||||
tool_call_history=tool_call_history or [],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if tool_rules:
|
||||
for rule in tool_rules:
|
||||
if rule.type == ToolRuleType.run_first:
|
||||
assert isinstance(rule, InitToolRule)
|
||||
self.init_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.constrain_child_tools:
|
||||
assert isinstance(rule, ChildToolRule)
|
||||
self.child_based_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.conditional:
|
||||
assert isinstance(rule, ConditionalToolRule)
|
||||
self.validate_conditional_tool(rule)
|
||||
self.child_based_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.exit_loop:
|
||||
assert isinstance(rule, TerminalToolRule)
|
||||
self.terminal_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.continue_loop:
|
||||
assert isinstance(rule, ContinueToolRule)
|
||||
self.continue_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.max_count_per_step:
|
||||
assert isinstance(rule, MaxCountPerStepToolRule)
|
||||
self.child_based_tool_rules.append(rule)
|
||||
|
||||
def register_tool_call(self, tool_name: str):
|
||||
"""Update the internal state to track tool call history."""
|
||||
|
||||
@@ -13,6 +13,8 @@ from letta.orm.identities_blocks import IdentitiesBlocks
|
||||
from letta.orm.identity import Identity
|
||||
from letta.orm.job import Job
|
||||
from letta.orm.job_messages import JobMessage
|
||||
from letta.orm.llm_batch_items import LLMBatchItem
|
||||
from letta.orm.llm_batch_job import LLMBatchJob
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
|
||||
|
||||
@@ -144,6 +144,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
viewonly=True,
|
||||
back_populates="manager_agent",
|
||||
)
|
||||
batch_items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="agent", lazy="selectin")
|
||||
|
||||
def to_pydantic(self, include_relationships: Optional[Set[str]] = None) -> PydanticAgentState:
|
||||
"""
|
||||
|
||||
@@ -2,16 +2,24 @@ from sqlalchemy import JSON
|
||||
from sqlalchemy.types import BINARY, TypeDecorator
|
||||
|
||||
from letta.helpers.converters import (
|
||||
deserialize_agent_step_state,
|
||||
deserialize_batch_request_result,
|
||||
deserialize_create_batch_response,
|
||||
deserialize_embedding_config,
|
||||
deserialize_llm_config,
|
||||
deserialize_message_content,
|
||||
deserialize_poll_batch_response,
|
||||
deserialize_tool_calls,
|
||||
deserialize_tool_returns,
|
||||
deserialize_tool_rules,
|
||||
deserialize_vector,
|
||||
serialize_agent_step_state,
|
||||
serialize_batch_request_result,
|
||||
serialize_create_batch_response,
|
||||
serialize_embedding_config,
|
||||
serialize_llm_config,
|
||||
serialize_message_content,
|
||||
serialize_poll_batch_response,
|
||||
serialize_tool_calls,
|
||||
serialize_tool_returns,
|
||||
serialize_tool_rules,
|
||||
@@ -108,3 +116,55 @@ class CommonVector(TypeDecorator):
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return deserialize_vector(value, dialect)
|
||||
|
||||
|
||||
class CreateBatchResponseColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing a provider batch-creation response as JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Serialize the provider response object into a provider-tagged dict on the way into the DB.
        return serialize_create_batch_response(value)

    def process_result_value(self, value, dialect):
        # Rebuild the provider response object on the way out of the DB.
        return deserialize_create_batch_response(value)
|
||||
|
||||
|
||||
class PollBatchResponseColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing a provider batch-polling response as JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Serialize the polling response (None allowed) into a provider-tagged dict.
        return serialize_poll_batch_response(value)

    def process_result_value(self, value, dialect):
        # Rebuild the polling response object (or None) on the way out of the DB.
        return deserialize_poll_batch_response(value)
|
||||
|
||||
|
||||
class BatchRequestResultColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing a per-item batch request result as JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Serialize the per-item result (None allowed) into a provider-tagged dict.
        return serialize_batch_request_result(value)

    def process_result_value(self, value, dialect):
        # Rebuild the per-item result object (or None) on the way out of the DB.
        return deserialize_batch_request_result(value)
|
||||
|
||||
|
||||
class AgentStepStateColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing an AgentStepState as JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Dump the step state (None allowed) to a JSON-safe dict.
        return serialize_agent_step_state(value)

    def process_result_value(self, value, dialect):
        # Rebuild the AgentStepState model (or None) on the way out of the DB.
        return deserialize_agent_step_state(value)
|
||||
|
||||
55
letta/orm/llm_batch_items.py
Normal file
55
letta/orm/llm_batch_items.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import uuid
|
||||
from typing import Optional, Union
|
||||
|
||||
from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse
|
||||
from sqlalchemy import ForeignKey, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import AgentStepStateColumn, BatchRequestResultColumn, LLMConfigColumn
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.agent import AgentStepState
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus
|
||||
from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
|
||||
class LLMBatchItem(SqlalchemyBase, OrganizationMixin, AgentMixin):
    """Represents a single agent's LLM request within a batch."""

    __tablename__ = "llm_batch_items"
    __pydantic_model__ = PydanticLLMBatchItem
    # Indexes match those created in the corresponding alembic migration.
    __table_args__ = (
        Index("ix_llm_batch_items_batch_id", "batch_id"),
        Index("ix_llm_batch_items_agent_id", "agent_id"),
        Index("ix_llm_batch_items_status", "request_status"),
    )

    # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
    # TODO: Some still rely on the Pydantic object to do this
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"batch_item-{uuid.uuid4()}")

    # Deleting the parent batch cascades to its items at the DB level.
    batch_id: Mapped[str] = mapped_column(
        ForeignKey("llm_batch_job.id", ondelete="CASCADE"), doc="Foreign key to the LLM provider batch this item belongs to"
    )

    llm_config: Mapped[LLMConfig] = mapped_column(LLMConfigColumn, nullable=False, doc="LLM configuration specific to this request")

    request_status: Mapped[JobStatus] = mapped_column(
        String, default=JobStatus.created, doc="Status of the LLM request in the batch (PENDING, SUBMITTED, DONE, ERROR)"
    )

    step_status: Mapped[AgentStepStatus] = mapped_column(String, default=AgentStepStatus.paused, doc="Status of the agent's step execution")

    step_state: Mapped[AgentStepState] = mapped_column(
        AgentStepStateColumn, doc="Execution metadata for resuming the agent step (e.g., tool call ID, timestamps)"
    )

    batch_request_result: Mapped[Optional[Union[BetaMessageBatchIndividualResponse]]] = mapped_column(
        BatchRequestResultColumn, nullable=True, doc="Raw JSON response from the LLM for this item"
    )

    # relationships
    organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_items")
    batch: Mapped["LLMBatchJob"] = relationship("LLMBatchJob", back_populates="items", lazy="selectin")
    agent: Mapped["Agent"] = relationship("Agent", back_populates="batch_items", lazy="selectin")
|
||||
48
letta/orm/llm_batch_job.py
Normal file
48
letta/orm/llm_batch_job.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from anthropic.types.beta.messages import BetaMessageBatch
|
||||
from sqlalchemy import DateTime, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.custom_columns import CreateBatchResponseColumn, PollBatchResponseColumn
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.enums import JobStatus, ProviderType
|
||||
from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
|
||||
|
||||
|
||||
class LLMBatchJob(SqlalchemyBase, OrganizationMixin):
    """Represents a single LLM batch request made to a provider like Anthropic."""

    __tablename__ = "llm_batch_job"
    # Indexes match those created in the corresponding alembic migration.
    __table_args__ = (
        Index("ix_llm_batch_job_created_at", "created_at"),
        Index("ix_llm_batch_job_status", "status"),
    )

    __pydantic_model__ = PydanticLLMBatchJob

    # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
    # TODO: Some still rely on the Pydantic object to do this
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"batch_req-{uuid.uuid4()}")

    status: Mapped[JobStatus] = mapped_column(String, default=JobStatus.created, doc="The current status of the batch.")

    llm_provider: Mapped[ProviderType] = mapped_column(String, doc="LLM provider used (e.g., 'Anthropic')")

    # Raw provider payloads, persisted through the custom JSON column types.
    create_batch_response: Mapped[Union[BetaMessageBatch]] = mapped_column(
        CreateBatchResponseColumn, doc="Full JSON response from initial batch creation"
    )
    latest_polling_response: Mapped[Union[BetaMessageBatch]] = mapped_column(
        PollBatchResponseColumn, nullable=True, doc="Last known polling result from LLM provider"
    )

    last_polled_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True, doc="Last time we polled the provider for status"
    )

    # relationships
    organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_jobs")
    items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="batch", lazy="selectin")
|
||||
@@ -51,6 +51,8 @@ class Organization(SqlalchemyBase):
|
||||
providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan")
|
||||
identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan")
|
||||
groups: Mapped[List["Group"]] = relationship("Group", back_populates="organization", cascade="all, delete-orphan")
|
||||
llm_batch_jobs: Mapped[List["Agent"]] = relationship("LLMBatchJob", back_populates="organization", cascade="all, delete-orphan")
|
||||
llm_batch_items: Mapped[List["Agent"]] = relationship("LLMBatchItem", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
@property
|
||||
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable
|
||||
@@ -271,3 +272,8 @@ class AgentStepResponse(BaseModel):
|
||||
..., description="Whether the agent step ended because the in-context memory is near its limit."
|
||||
)
|
||||
usage: UsageStatistics = Field(..., description="Usage statistics of the LLM call during the agent's step.")
|
||||
|
||||
|
||||
class AgentStepState(BaseModel):
|
||||
step_number: int = Field(..., description="The current step number in the agent loop")
|
||||
tool_rules_solver: ToolRulesSolver = Field(..., description="The current state of the ToolRulesSolver")
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ProviderType(str, Enum):
    """LLM providers that batch requests/responses can be attributed to."""

    anthropic = "anthropic"
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
assistant = "assistant"
|
||||
user = "user"
|
||||
@@ -22,6 +26,7 @@ class JobStatus(str, Enum):
|
||||
Status of the job.
|
||||
"""
|
||||
|
||||
not_started = "not_started"
|
||||
created = "created"
|
||||
running = "running"
|
||||
completed = "completed"
|
||||
@@ -29,6 +34,15 @@ class JobStatus(str, Enum):
|
||||
pending = "pending"
|
||||
|
||||
|
||||
class AgentStepStatus(str, Enum):
    """
    Status of an agent's step execution.
    """

    paused = "paused"
    running = "running"
|
||||
|
||||
|
||||
class MessageStreamStatus(str, Enum):
|
||||
done = "[DONE]"
|
||||
|
||||
|
||||
53
letta/schemas/llm_batch_job.py
Normal file
53
letta/schemas/llm_batch_job.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.agent import AgentStepState
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
|
||||
class LLMBatchItem(OrmMetadataBase, validate_assignment=True):
    """
    Represents a single agent's LLM request within a batch.

    This object captures the configuration, execution status, and eventual result of one agent's request within a larger LLM batch job.
    """

    __id_prefix__ = "batch_item"

    id: str = Field(..., description="The id of the batch item. Assigned by the database.")
    batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.")
    agent_id: str = Field(..., description="The id of the agent associated with this LLM request.")

    llm_config: LLMConfig = Field(..., description="The LLM configuration used for this request.")
    request_status: JobStatus = Field(..., description="The current status of the batch item request (e.g., PENDING, DONE, ERROR).")
    step_status: AgentStepStatus = Field(..., description="The current execution status of the agent step.")
    step_state: AgentStepState = Field(..., description="The serialized state for resuming execution at a later point.")

    # Anthropic-only today; widen to a real Union once more providers are supported.
    batch_request_result: Optional[BetaMessageBatchIndividualResponse] = Field(
        None, description="The raw response received from the LLM provider for this item."
    )
|
||||
|
||||
|
||||
class LLMBatchJob(OrmMetadataBase, validate_assignment=True):
    """
    Represents a single LLM batch request made to a provider like Anthropic.

    Each job corresponds to one API call that sends multiple messages to the LLM provider, and aggregates responses across all agent submissions.
    """

    __id_prefix__ = "batch_req"

    id: str = Field(..., description="The id of the batch job. Assigned by the database.")
    status: JobStatus = Field(..., description="The current status of the batch (e.g., created, in_progress, done).")
    llm_provider: ProviderType = Field(..., description="The LLM provider used for the batch (e.g., anthropic, openai).")

    # Anthropic-only today; widen to a real Union once more providers are supported.
    create_batch_response: BetaMessageBatch = Field(..., description="The full JSON response from the initial batch creation.")
    latest_polling_response: Optional[BetaMessageBatch] = Field(
        None, description="The most recent polling response received from the LLM provider."
    )
    last_polled_at: Optional[datetime] = Field(None, description="The timestamp of the last polling check for the batch status.")
|
||||
@@ -81,6 +81,7 @@ from letta.services.block_manager import BlockManager
|
||||
from letta.services.group_manager import GroupManager
|
||||
from letta.services.identity_manager import IdentityManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.llm_batch_manager import LLMBatchManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
@@ -207,6 +208,7 @@ class SyncServer(Server):
|
||||
self.step_manager = StepManager()
|
||||
self.identity_manager = IdentityManager()
|
||||
self.group_manager = GroupManager()
|
||||
self.batch_manager = LLMBatchManager()
|
||||
|
||||
# Make default user and org
|
||||
if init_with_default_org_and_user:
|
||||
|
||||
139
letta/services/llm_batch_manager.py
Normal file
139
letta/services/llm_batch_manager.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.llm_batch_items import LLMBatchItem
|
||||
from letta.orm.llm_batch_job import LLMBatchJob
|
||||
from letta.schemas.agent import AgentStepState
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus
|
||||
from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
|
||||
from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LLMBatchManager:
    """Manager for handling both LLMBatchJob and LLMBatchItem operations."""

    def __init__(self):
        # Imported lazily to avoid a circular import with the server module.
        from letta.server.db import db_context

        self.session_maker = db_context

    @enforce_types
    def create_batch_request(
        self,
        llm_provider: str,
        create_batch_response: BetaMessageBatch,
        actor: PydanticUser,
        status: JobStatus = JobStatus.created,
    ) -> PydanticLLMBatchJob:
        """Create a new LLM batch job scoped to the actor's organization.

        Args:
            llm_provider: Provider identifier (e.g. "anthropic").
            create_batch_response: Full provider response from batch creation.
            actor: User performing the operation (supplies organization_id).
            status: Initial job status; defaults to JobStatus.created.

        Returns:
            The persisted job as a Pydantic model.
        """
        with self.session_maker() as session:
            batch = LLMBatchJob(
                status=status,
                llm_provider=llm_provider,
                create_batch_response=create_batch_response,
                organization_id=actor.organization_id,
            )
            batch.create(session, actor=actor)
            return batch.to_pydantic()

    @enforce_types
    def get_batch_request_by_id(self, batch_id: str, actor: PydanticUser) -> PydanticLLMBatchJob:
        """Retrieve a single batch job by ID.

        Propagates whatever LLMBatchJob.read raises when the job does not
        exist or is not visible to the actor.
        """
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            return batch.to_pydantic()

    @enforce_types
    def update_batch_status(
        self,
        batch_id: str,
        status: JobStatus,
        actor: PydanticUser,
        latest_polling_response: Optional[BetaMessageBatch] = None,
    ) -> PydanticLLMBatchJob:
        """Update a batch job's status and optionally its polling response.

        Always stamps `last_polled_at` with the current UTC time.

        Args:
            batch_id: Id of the job to update.
            status: New job status.
            actor: User performing the operation.
            latest_polling_response: New provider polling response, if any.

        Returns:
            The updated job as a Pydantic model.
        """
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            batch.status = status
            # Only overwrite the stored polling response when a new one is
            # supplied; previously a None argument clobbered the last known
            # response, contradicting the "optionally" contract above.
            if latest_polling_response is not None:
                batch.latest_polling_response = latest_polling_response
            batch.last_polled_at = datetime.datetime.now(datetime.timezone.utc)
            return batch.update(db_session=session, actor=actor).to_pydantic()

    @enforce_types
    def delete_batch_request(self, batch_id: str, actor: PydanticUser) -> None:
        """Hard delete a batch job by ID (the row is removed, not soft-deleted)."""
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            batch.hard_delete(db_session=session, actor=actor)

    @enforce_types
    def create_batch_item(
        self,
        batch_id: str,
        agent_id: str,
        llm_config: LLMConfig,
        actor: PydanticUser,
        request_status: JobStatus = JobStatus.created,
        step_status: AgentStepStatus = AgentStepStatus.paused,
        step_state: Optional[AgentStepState] = None,
    ) -> PydanticLLMBatchItem:
        """Create a new batch item tying one agent's request to a batch job.

        Args:
            batch_id: Parent LLMBatchJob id.
            agent_id: Agent this item belongs to.
            llm_config: LLM configuration used for this request.
            actor: User performing the operation (supplies organization_id).
            request_status: Initial per-request status; defaults to created.
            step_status: Initial agent-step status; defaults to paused.
            step_state: Optional agent step state to resume from later.

        Returns:
            The persisted item as a Pydantic model.
        """
        with self.session_maker() as session:
            item = LLMBatchItem(
                batch_id=batch_id,
                agent_id=agent_id,
                llm_config=llm_config,
                request_status=request_status,
                step_status=step_status,
                step_state=step_state,
                organization_id=actor.organization_id,
            )
            item.create(session, actor=actor)
            return item.to_pydantic()

    @enforce_types
    def get_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
        """Retrieve a single batch item by ID."""
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
            return item.to_pydantic()

    @enforce_types
    def update_batch_item(
        self,
        item_id: str,
        actor: PydanticUser,
        request_status: Optional[JobStatus] = None,
        step_status: Optional[AgentStepStatus] = None,
        llm_request_response: Optional[BetaMessageBatchIndividualResponse] = None,
        step_state: Optional[AgentStepState] = None,
    ) -> PydanticLLMBatchItem:
        """Partially update fields on a batch item; None arguments are skipped.

        Returns:
            The updated item as a Pydantic model.
        """
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)

            # Explicit None checks (instead of truthiness) so that any
            # falsy-but-valid value would still be applied.
            if request_status is not None:
                item.request_status = request_status
            if step_status is not None:
                item.step_status = step_status
            if llm_request_response is not None:
                item.batch_request_result = llm_request_response
            if step_state is not None:
                item.step_state = step_state

            return item.update(db_session=session, actor=actor).to_pydantic()

    @enforce_types
    def delete_batch_item(self, item_id: str, actor: PydanticUser) -> None:
        """Hard delete a batch item by ID."""
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
            item.hard_delete(db_session=session, actor=actor)
||||
@@ -2,10 +2,17 @@ import os
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from anthropic.types.beta import BetaMessage
|
||||
from anthropic.types.beta.messages import (
|
||||
BetaMessageBatch,
|
||||
BetaMessageBatchIndividualResponse,
|
||||
BetaMessageBatchRequestCounts,
|
||||
BetaMessageBatchSucceededResult,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -23,15 +30,16 @@ from letta.constants import (
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.orm import Base, Block
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.enums import ActorType, JobType, ToolType
|
||||
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
|
||||
from letta.schemas.agent import CreateAgent, UpdateAgent
|
||||
from letta.schemas.agent import AgentStepState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate
|
||||
@@ -570,6 +578,62 @@ def agent_with_tags(server: SyncServer, default_user):
|
||||
return [agent1, agent2, agent3]
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_beta_message_batch() -> BetaMessageBatch:
    """A canned Anthropic batch object for exercising batch-job persistence."""
    # All lifecycle timestamps share the same instant; only the values matter.
    ts = datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc)
    counts = BetaMessageBatchRequestCounts(
        canceled=10,
        errored=30,
        expired=10,
        processing=100,
        succeeded=50,
    )
    return BetaMessageBatch(
        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
        archived_at=ts,
        cancel_initiated_at=ts,
        created_at=ts,
        ended_at=ts,
        expires_at=ts,
        processing_status="in_progress",
        request_counts=counts,
        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
        type="message_batch",
    )
||||
|
||||
|
||||
@pytest.fixture
def dummy_llm_config() -> LLMConfig:
    """Default GPT-4 LLM configuration used by the batch-item fixtures."""
    config = LLMConfig.default_config("gpt-4")
    return config
||||
|
||||
|
||||
@pytest.fixture
def dummy_tool_rules_solver() -> ToolRulesSolver:
    """Solver seeded with a single init rule forcing send_message first."""
    rules = [InitToolRule(tool_name="send_message")]
    return ToolRulesSolver(tool_rules=rules)
||||
|
||||
|
||||
@pytest.fixture
def dummy_step_state(dummy_tool_rules_solver: ToolRulesSolver) -> AgentStepState:
    """Agent step state at step 1, carrying the dummy tool-rules solver."""
    return AgentStepState(
        step_number=1,
        tool_rules_solver=dummy_tool_rules_solver,
    )
||||
|
||||
|
||||
@pytest.fixture
def dummy_successful_response() -> BetaMessageBatchIndividualResponse:
    """A single succeeded per-request result, as returned when polling a batch."""
    message = BetaMessage(
        id="msg_abc123",
        role="assistant",
        type="message",
        model="claude-3-5-sonnet-20240620",
        content=[{"type": "text", "text": "hi!"}],
        usage={"input_tokens": 5, "output_tokens": 7},
        stop_reason="end_turn",
    )
    return BetaMessageBatchIndividualResponse(
        custom_id="my-second-request",
        result=BetaMessageBatchSucceededResult(type="succeeded", message=message),
    )
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Basic
|
||||
# ======================================================================================================================
|
||||
@@ -4614,3 +4678,123 @@ def test_list_tags(server: SyncServer, default_user, default_organization):
|
||||
# Cleanup
|
||||
for agent in agents:
|
||||
server.agent_manager.delete_agent(agent.id, actor=default_user)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# LLMBatchManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch):
    """Round-trip: create a batch job, then fetch it back by id."""
    created = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )

    # Ids are generated with the schema's "batch_req" prefix.
    assert created.id.startswith("batch_req-")
    assert created.create_batch_response == dummy_beta_message_batch

    looked_up = server.batch_manager.get_batch_request_by_id(created.id, actor=default_user)
    assert looked_up.id == created.id
|
||||
|
||||
|
||||
def test_update_batch_status(server, default_user, dummy_beta_message_batch):
    """Updating a batch's status also records the polling response and poll time."""
    job = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )
    poll_started = datetime.now(timezone.utc)

    server.batch_manager.update_batch_status(
        batch_id=job.id,
        status=JobStatus.completed,
        latest_polling_response=dummy_beta_message_batch,
        actor=default_user,
    )

    refreshed = server.batch_manager.get_batch_request_by_id(job.id, actor=default_user)
    assert refreshed.status == JobStatus.completed
    assert refreshed.latest_polling_response == dummy_beta_message_batch
    assert refreshed.last_polled_at >= poll_started
|
||||
|
||||
|
||||
def test_create_and_get_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
    """A batch item can be created under a job and fetched back intact."""
    job = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )

    new_item = server.batch_manager.create_batch_item(
        batch_id=job.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,
    )

    assert new_item.batch_id == job.id
    assert new_item.agent_id == sarah_agent.id
    assert new_item.step_state == dummy_step_state

    looked_up = server.batch_manager.get_batch_item_by_id(new_item.id, actor=default_user)
    assert looked_up.id == new_item.id
|
||||
|
||||
|
||||
def test_update_batch_item(
    server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, dummy_successful_response
):
    """All updatable fields of a batch item can be changed in a single call."""
    job = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )

    created_item = server.batch_manager.create_batch_item(
        batch_id=job.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,
    )

    # Advance the step state while keeping the same solver instance.
    next_step_state = AgentStepState(step_number=2, tool_rules_solver=dummy_step_state.tool_rules_solver)

    server.batch_manager.update_batch_item(
        item_id=created_item.id,
        request_status=JobStatus.completed,
        step_status=AgentStepStatus.running,
        llm_request_response=dummy_successful_response,
        step_state=next_step_state,
        actor=default_user,
    )

    updated_item = server.batch_manager.get_batch_item_by_id(created_item.id, actor=default_user)
    assert updated_item.request_status == JobStatus.completed
    assert updated_item.batch_request_result == dummy_successful_response
|
||||
|
||||
|
||||
def test_delete_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
    """Hard-deleting a batch item makes subsequent lookups raise NoResultFound."""
    job = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )

    doomed = server.batch_manager.create_batch_item(
        batch_id=job.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,
    )

    server.batch_manager.delete_batch_item(item_id=doomed.id, actor=default_user)

    with pytest.raises(NoResultFound):
        server.batch_manager.get_batch_item_by_id(doomed.id, actor=default_user)
|
||||
|
||||
Reference in New Issue
Block a user