feat: Create batch request tracking tables (#1604)

This commit is contained in:
Matthew Zhou
2025-04-07 16:27:18 -07:00
committed by GitHub
parent 0af857d3ba
commit 0aeddec547
15 changed files with 804 additions and 33 deletions

View File

@@ -0,0 +1,86 @@
"""Add LLM batch jobs tables
Revision ID: 0ceb975e0063
Revises: 90bb156e71df
Create Date: 2025-04-07 15:57:18.475151
"""
from typing import Sequence, Union
import sqlalchemy as sa
import letta
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0ceb975e0063"
down_revision: Union[str, None] = "90bb156e71df"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the batch-tracking tables.

    Creates ``llm_batch_job`` (one row per provider batch request) and
    ``llm_batch_items`` (one row per agent request inside a batch), plus
    lookup indexes on status/created_at/batch_id/agent_id.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "llm_batch_job",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("status", sa.String(), nullable=False),
        sa.Column("llm_provider", sa.String(), nullable=False),
        # Custom JSON column types round-trip the provider response objects.
        sa.Column("create_batch_response", letta.orm.custom_columns.CreateBatchResponseColumn(), nullable=False),
        sa.Column("latest_polling_response", letta.orm.custom_columns.PollBatchResponseColumn(), nullable=True),
        sa.Column("last_polled_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index("ix_llm_batch_job_created_at", "llm_batch_job", ["created_at"], unique=False)
    op.create_index("ix_llm_batch_job_status", "llm_batch_job", ["status"], unique=False)
    op.create_table(
        "llm_batch_items",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("batch_id", sa.String(), nullable=False),
        sa.Column("llm_config", letta.orm.custom_columns.LLMConfigColumn(), nullable=False),
        sa.Column("request_status", sa.String(), nullable=False),
        sa.Column("step_status", sa.String(), nullable=False),
        sa.Column("step_state", letta.orm.custom_columns.AgentStepStateColumn(), nullable=False),
        sa.Column("batch_request_result", letta.orm.custom_columns.BatchRequestResultColumn(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
        # Deleting an agent or a parent batch cascades to its items.
        sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["batch_id"], ["llm_batch_job.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(
            ["organization_id"],
            ["organizations.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index("ix_llm_batch_items_agent_id", "llm_batch_items", ["agent_id"], unique=False)
    op.create_index("ix_llm_batch_items_batch_id", "llm_batch_items", ["batch_id"], unique=False)
    op.create_index("ix_llm_batch_items_status", "llm_batch_items", ["request_status"], unique=False)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the batch tables: items first (they reference llm_batch_job), then the job table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index("ix_llm_batch_items_status", table_name="llm_batch_items")
    op.drop_index("ix_llm_batch_items_batch_id", table_name="llm_batch_items")
    op.drop_index("ix_llm_batch_items_agent_id", table_name="llm_batch_items")
    op.drop_table("llm_batch_items")
    op.drop_index("ix_llm_batch_job_status", table_name="llm_batch_job")
    op.drop_index("ix_llm_batch_job_created_at", table_name="llm_batch_job")
    op.drop_table("llm_batch_job")
    # ### end Alembic commands ###

View File

@@ -2,12 +2,14 @@ import base64
from typing import Any, Dict, List, Optional, Union
import numpy as np
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from sqlalchemy import Dialect
from letta.schemas.agent import AgentStepState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.enums import ProviderType, ToolRuleType
from letta.schemas.letta_message_content import (
MessageContent,
MessageContentType,
@@ -38,7 +40,7 @@ from letta.schemas.tool_rule import (
def serialize_llm_config(config: Union[Optional[LLMConfig], Dict]) -> Optional[Dict]:
"""Convert an LLMConfig object into a JSON-serializable dictionary."""
if config and isinstance(config, LLMConfig):
return config.model_dump()
return config.model_dump(mode="json")
return config
@@ -55,7 +57,7 @@ def deserialize_llm_config(data: Optional[Dict]) -> Optional[LLMConfig]:
def serialize_embedding_config(config: Union[Optional[EmbeddingConfig], Dict]) -> Optional[Dict]:
"""Convert an EmbeddingConfig object into a JSON-serializable dictionary."""
if config and isinstance(config, EmbeddingConfig):
return config.model_dump()
return config.model_dump(mode="json")
return config
@@ -75,7 +77,9 @@ def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str,
if not tool_rules:
return []
data = [{**rule.model_dump(), "type": rule.type.value} for rule in tool_rules] # Convert Enum to string for JSON compatibility
data = [
{**rule.model_dump(mode="json"), "type": rule.type.value} for rule in tool_rules
] # Convert Enum to string for JSON compatibility
# Validate ToolRule structure
for rule_data in data:
@@ -130,7 +134,7 @@ def serialize_tool_calls(tool_calls: Optional[List[Union[OpenAIToolCall, dict]]]
serialized_calls = []
for call in tool_calls:
if isinstance(call, OpenAIToolCall):
serialized_calls.append(call.model_dump())
serialized_calls.append(call.model_dump(mode="json"))
elif isinstance(call, dict):
serialized_calls.append(call) # Already a dictionary, leave it as-is
else:
@@ -166,7 +170,7 @@ def serialize_tool_returns(tool_returns: Optional[List[Union[ToolReturn, dict]]]
serialized_tool_returns = []
for tool_return in tool_returns:
if isinstance(tool_return, ToolReturn):
serialized_tool_returns.append(tool_return.model_dump())
serialized_tool_returns.append(tool_return.model_dump(mode="json"))
elif isinstance(tool_return, dict):
serialized_tool_returns.append(tool_return) # Already a dictionary, leave it as-is
else:
@@ -201,7 +205,7 @@ def serialize_message_content(message_content: Optional[List[Union[MessageConten
serialized_message_content = []
for content in message_content:
if isinstance(content, MessageContent):
serialized_message_content.append(content.model_dump())
serialized_message_content.append(content.model_dump(mode="json"))
elif isinstance(content, dict):
serialized_message_content.append(content) # Already a dictionary, leave it as-is
else:
@@ -266,3 +270,101 @@ def deserialize_vector(data: Optional[bytes], dialect: Dialect) -> Optional[np.n
data = base64.b64decode(data)
return np.frombuffer(data, dtype=np.float32)
# --------------------------
# Batch Request Serialization
# --------------------------
def serialize_create_batch_response(create_batch_response: BetaMessageBatch) -> Dict[str, Any]:
    """Serialize a provider batch-creation response into a tagged, JSON-safe envelope.

    The envelope stores the provider type next to the dumped payload so
    deserialize_create_batch_response can rebuild the right object later.

    Raises:
        ValueError: If the response object type does not map to a known provider.
    """
    llm_provider_type = None
    if isinstance(create_batch_response, BetaMessageBatch):
        llm_provider_type = ProviderType.anthropic.value
    if not llm_provider_type:
        raise ValueError(f"Could not determine llm provider from create batch response object type: {create_batch_response}")
    return {"data": create_batch_response.model_dump(mode="json"), "type": llm_provider_type}
def deserialize_create_batch_response(data: Dict) -> Union[BetaMessageBatch]:
    """Rebuild a provider batch-creation response from its tagged JSON envelope."""
    provider = ProviderType(data.get("type"))
    if provider == ProviderType.anthropic:
        return BetaMessageBatch(**data.get("data"))
    raise ValueError(f"Unknown ProviderType type: {provider}")
# TODO: Note that this is the same as above for Anthropic, but this is not the case for all providers
# TODO: Some have different types based on the create v.s. poll requests
def serialize_poll_batch_response(poll_batch_response: Optional[BetaMessageBatch]) -> Optional[Dict[str, Any]]:
    """Serialize a provider polling response into a tagged, JSON-safe envelope.

    Returns None when no polling response has been recorded yet.

    Raises:
        ValueError: If the response object type does not map to a known provider.
    """
    if not poll_batch_response:
        return None
    llm_provider_type = None
    if isinstance(poll_batch_response, BetaMessageBatch):
        llm_provider_type = ProviderType.anthropic.value
    if not llm_provider_type:
        raise ValueError(f"Could not determine llm provider from poll batch response object type: {poll_batch_response}")
    return {"data": poll_batch_response.model_dump(mode="json"), "type": llm_provider_type}
def deserialize_poll_batch_response(data: Optional[Dict]) -> Optional[Union[BetaMessageBatch]]:
    """Rebuild a provider polling response from its tagged JSON envelope (None/empty passes through)."""
    if not data:
        return None
    provider = ProviderType(data.get("type"))
    if provider == ProviderType.anthropic:
        return BetaMessageBatch(**data.get("data"))
    raise ValueError(f"Unknown ProviderType type: {provider}")
def serialize_batch_request_result(
    batch_individual_response: Optional[BetaMessageBatchIndividualResponse],
) -> Optional[Dict[str, Any]]:
    """Serialize a per-item batch result into a tagged, JSON-safe envelope.

    Returns None when no result has been received for the item yet.

    Raises:
        ValueError: If the result object type does not map to a known provider.
    """
    if not batch_individual_response:
        return None
    llm_provider_type = None
    if isinstance(batch_individual_response, BetaMessageBatchIndividualResponse):
        llm_provider_type = ProviderType.anthropic.value
    if not llm_provider_type:
        raise ValueError(f"Could not determine llm provider from batch result object type: {batch_individual_response}")
    return {"data": batch_individual_response.model_dump(mode="json"), "type": llm_provider_type}
def deserialize_batch_request_result(data: Optional[Dict]) -> Optional[Union[BetaMessageBatchIndividualResponse]]:
    """Rebuild a per-item batch result from its tagged JSON envelope (None/empty passes through)."""
    if not data:
        return None
    provider = ProviderType(data.get("type"))
    if provider == ProviderType.anthropic:
        return BetaMessageBatchIndividualResponse(**data.get("data"))
    raise ValueError(f"Unknown ProviderType type: {provider}")
def serialize_agent_step_state(agent_step_state: Optional[AgentStepState]) -> Optional[Dict[str, Any]]:
    """Serialize an AgentStepState into a JSON-safe dict.

    Returns None when no step state is set. No provider tag is needed here
    because AgentStepState is a Letta-owned schema, not a provider object.
    """
    if not agent_step_state:
        return None
    return agent_step_state.model_dump(mode="json")
def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepState]:
    """Rebuild an AgentStepState from its stored dict form (None/empty passes through)."""
    return AgentStepState(**data) if data else None

View File

@@ -38,29 +38,46 @@ class ToolRulesSolver(BaseModel):
)
tool_call_history: List[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
def __init__(self, tool_rules: List[BaseToolRule], **kwargs):
super().__init__(**kwargs)
# Separate the provided tool rules into init, standard, and terminal categories
for rule in tool_rules:
if rule.type == ToolRuleType.run_first:
assert isinstance(rule, InitToolRule)
self.init_tool_rules.append(rule)
elif rule.type == ToolRuleType.constrain_child_tools:
assert isinstance(rule, ChildToolRule)
self.child_based_tool_rules.append(rule)
elif rule.type == ToolRuleType.conditional:
assert isinstance(rule, ConditionalToolRule)
self.validate_conditional_tool(rule)
self.child_based_tool_rules.append(rule)
elif rule.type == ToolRuleType.exit_loop:
assert isinstance(rule, TerminalToolRule)
self.terminal_tool_rules.append(rule)
elif rule.type == ToolRuleType.continue_loop:
assert isinstance(rule, ContinueToolRule)
self.continue_tool_rules.append(rule)
elif rule.type == ToolRuleType.max_count_per_step:
assert isinstance(rule, MaxCountPerStepToolRule)
self.child_based_tool_rules.append(rule)
def __init__(
    self,
    tool_rules: Optional[List[BaseToolRule]] = None,
    init_tool_rules: Optional[List[InitToolRule]] = None,
    continue_tool_rules: Optional[List[ContinueToolRule]] = None,
    child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None,
    terminal_tool_rules: Optional[List[TerminalToolRule]] = None,
    tool_call_history: Optional[List[str]] = None,
    **kwargs,
):
    """Initialize the solver either from pre-sorted rule lists or from a flat `tool_rules` list.

    Pre-sorted lists (init/continue/child-based/terminal/history) are passed straight
    to the pydantic constructor, defaulting to empty. Any rules supplied via
    `tool_rules` are additionally categorized by their `type` and appended to the
    corresponding bucket; conditional rules are validated before being accepted.
    """
    super().__init__(
        init_tool_rules=init_tool_rules or [],
        continue_tool_rules=continue_tool_rules or [],
        child_based_tool_rules=child_based_tool_rules or [],
        terminal_tool_rules=terminal_tool_rules or [],
        tool_call_history=tool_call_history or [],
        **kwargs,
    )
    if tool_rules:
        for rule in tool_rules:
            if rule.type == ToolRuleType.run_first:
                assert isinstance(rule, InitToolRule)
                self.init_tool_rules.append(rule)
            elif rule.type == ToolRuleType.constrain_child_tools:
                assert isinstance(rule, ChildToolRule)
                self.child_based_tool_rules.append(rule)
            elif rule.type == ToolRuleType.conditional:
                assert isinstance(rule, ConditionalToolRule)
                # Conditional rules carry a child mapping that must be validated up front.
                self.validate_conditional_tool(rule)
                self.child_based_tool_rules.append(rule)
            elif rule.type == ToolRuleType.exit_loop:
                assert isinstance(rule, TerminalToolRule)
                self.terminal_tool_rules.append(rule)
            elif rule.type == ToolRuleType.continue_loop:
                assert isinstance(rule, ContinueToolRule)
                self.continue_tool_rules.append(rule)
            elif rule.type == ToolRuleType.max_count_per_step:
                assert isinstance(rule, MaxCountPerStepToolRule)
                self.child_based_tool_rules.append(rule)
def register_tool_call(self, tool_name: str):
"""Update the internal state to track tool call history."""

View File

@@ -13,6 +13,8 @@ from letta.orm.identities_blocks import IdentitiesBlocks
from letta.orm.identity import Identity
from letta.orm.job import Job
from letta.orm.job_messages import JobMessage
from letta.orm.llm_batch_items import LLMBatchItem
from letta.orm.llm_batch_job import LLMBatchJob
from letta.orm.message import Message
from letta.orm.organization import Organization
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage

View File

@@ -144,6 +144,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
viewonly=True,
back_populates="manager_agent",
)
batch_items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="agent", lazy="selectin")
def to_pydantic(self, include_relationships: Optional[Set[str]] = None) -> PydanticAgentState:
"""

View File

@@ -2,16 +2,24 @@ from sqlalchemy import JSON
from sqlalchemy.types import BINARY, TypeDecorator
from letta.helpers.converters import (
deserialize_agent_step_state,
deserialize_batch_request_result,
deserialize_create_batch_response,
deserialize_embedding_config,
deserialize_llm_config,
deserialize_message_content,
deserialize_poll_batch_response,
deserialize_tool_calls,
deserialize_tool_returns,
deserialize_tool_rules,
deserialize_vector,
serialize_agent_step_state,
serialize_batch_request_result,
serialize_create_batch_response,
serialize_embedding_config,
serialize_llm_config,
serialize_message_content,
serialize_poll_batch_response,
serialize_tool_calls,
serialize_tool_returns,
serialize_tool_rules,
@@ -108,3 +116,55 @@ class CommonVector(TypeDecorator):
def process_result_value(self, value, dialect):
return deserialize_vector(value, dialect)
class CreateBatchResponseColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing a provider batch-creation response as tagged JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Serialize the provider response object to its JSON envelope on write.
        return serialize_create_batch_response(value)

    def process_result_value(self, value, dialect):
        # Rebuild the provider response object from the stored envelope on read.
        return deserialize_create_batch_response(value)
class PollBatchResponseColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing a provider batch-polling response as tagged JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # None-safe: serializer passes missing polling responses through as NULL.
        return serialize_poll_batch_response(value)

    def process_result_value(self, value, dialect):
        return deserialize_poll_batch_response(value)
class BatchRequestResultColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing a per-item batch request result as tagged JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # None-safe: items that have no result yet are stored as NULL.
        return serialize_batch_request_result(value)

    def process_result_value(self, value, dialect):
        return deserialize_batch_request_result(value)
class AgentStepStateColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing an AgentStepState as JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        return serialize_agent_step_state(value)

    def process_result_value(self, value, dialect):
        return deserialize_agent_step_state(value)

View File

@@ -0,0 +1,55 @@
import uuid
from typing import Optional, Union
from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse
from sqlalchemy import ForeignKey, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.custom_columns import AgentStepStateColumn, BatchRequestResultColumn, LLMConfigColumn
from letta.orm.mixins import AgentMixin, OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.agent import AgentStepState
from letta.schemas.enums import AgentStepStatus, JobStatus
from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
from letta.schemas.llm_config import LLMConfig
class LLMBatchItem(SqlalchemyBase, OrganizationMixin, AgentMixin):
    """Represents a single agent's LLM request within a batch"""

    __tablename__ = "llm_batch_items"
    __pydantic_model__ = PydanticLLMBatchItem
    __table_args__ = (
        Index("ix_llm_batch_items_batch_id", "batch_id"),
        Index("ix_llm_batch_items_agent_id", "agent_id"),
        Index("ix_llm_batch_items_status", "request_status"),
    )

    # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
    # TODO: Some still rely on the Pydantic object to do this
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"batch_item-{uuid.uuid4()}")
    batch_id: Mapped[str] = mapped_column(
        ForeignKey("llm_batch_job.id", ondelete="CASCADE"), doc="Foreign key to the LLM provider batch this item belongs to"
    )
    llm_config: Mapped[LLMConfig] = mapped_column(LLMConfigColumn, nullable=False, doc="LLM configuration specific to this request")
    # Stored as plain strings; the enums define the valid values.
    request_status: Mapped[JobStatus] = mapped_column(
        String, default=JobStatus.created, doc="Status of the LLM request in the batch (PENDING, SUBMITTED, DONE, ERROR)"
    )
    step_status: Mapped[AgentStepStatus] = mapped_column(String, default=AgentStepStatus.paused, doc="Status of the agent's step execution")
    step_state: Mapped[AgentStepState] = mapped_column(
        AgentStepStateColumn, doc="Execution metadata for resuming the agent step (e.g., tool call ID, timestamps)"
    )
    # Nullable: no result exists until the provider finishes processing this item.
    batch_request_result: Mapped[Optional[BetaMessageBatchIndividualResponse]] = mapped_column(
        BatchRequestResultColumn, nullable=True, doc="Raw JSON response from the LLM for this item"
    )

    # relationships
    organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_items")
    batch: Mapped["LLMBatchJob"] = relationship("LLMBatchJob", back_populates="items", lazy="selectin")
    agent: Mapped["Agent"] = relationship("Agent", back_populates="batch_items", lazy="selectin")

View File

@@ -0,0 +1,48 @@
import uuid
from datetime import datetime
from typing import List, Optional, Union
from anthropic.types.beta.messages import BetaMessageBatch
from sqlalchemy import DateTime, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.custom_columns import CreateBatchResponseColumn, PollBatchResponseColumn
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.enums import JobStatus, ProviderType
from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
class LLMBatchJob(SqlalchemyBase, OrganizationMixin):
    """Represents a single LLM batch request made to a provider like Anthropic"""

    __tablename__ = "llm_batch_job"
    __table_args__ = (
        Index("ix_llm_batch_job_created_at", "created_at"),
        Index("ix_llm_batch_job_status", "status"),
    )
    __pydantic_model__ = PydanticLLMBatchJob

    # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
    # TODO: Some still rely on the Pydantic object to do this
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"batch_req-{uuid.uuid4()}")
    status: Mapped[JobStatus] = mapped_column(String, default=JobStatus.created, doc="The current status of the batch.")
    llm_provider: Mapped[ProviderType] = mapped_column(String, doc="LLM provider used (e.g., 'Anthropic')")
    create_batch_response: Mapped[BetaMessageBatch] = mapped_column(
        CreateBatchResponseColumn, doc="Full JSON response from initial batch creation"
    )
    # Optional: no polling response exists until the first poll completes (column is nullable).
    latest_polling_response: Mapped[Optional[BetaMessageBatch]] = mapped_column(
        PollBatchResponseColumn, nullable=True, doc="Last known polling result from LLM provider"
    )
    last_polled_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True, doc="Last time we polled the provider for status"
    )

    # relationships
    organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_jobs")
    items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="batch", lazy="selectin")

View File

@@ -51,6 +51,8 @@ class Organization(SqlalchemyBase):
providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan")
identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan")
groups: Mapped[List["Group"]] = relationship("Group", back_populates="organization", cascade="all, delete-orphan")
# Batch jobs/items owned by this organization (annotations corrected: these are
# LLMBatchJob/LLMBatchItem collections, not Agent collections).
llm_batch_jobs: Mapped[List["LLMBatchJob"]] = relationship("LLMBatchJob", back_populates="organization", cascade="all, delete-orphan")
llm_batch_items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="organization", cascade="all, delete-orphan")
@property
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:

View File

@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.helpers import ToolRulesSolver
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.environment_variables import AgentEnvironmentVariable
@@ -271,3 +272,8 @@ class AgentStepResponse(BaseModel):
..., description="Whether the agent step ended because the in-context memory is near its limit."
)
usage: UsageStatistics = Field(..., description="Usage statistics of the LLM call during the agent's step.")
class AgentStepState(BaseModel):
    """Snapshot of agent-loop state needed to resume a paused step."""

    step_number: int = Field(..., description="The current step number in the agent loop")
    tool_rules_solver: ToolRulesSolver = Field(..., description="The current state of the ToolRulesSolver")

View File

@@ -1,6 +1,10 @@
from enum import Enum
class ProviderType(str, Enum):
    """LLM providers that batch requests can be issued against."""

    anthropic = "anthropic"
class MessageRole(str, Enum):
assistant = "assistant"
user = "user"
@@ -22,6 +26,7 @@ class JobStatus(str, Enum):
Status of the job.
"""
not_started = "not_started"
created = "created"
running = "running"
completed = "completed"
@@ -29,6 +34,15 @@ class JobStatus(str, Enum):
pending = "pending"
class AgentStepStatus(str, Enum):
    """
    Status of the agent step.
    """

    paused = "paused"
    running = "running"
class MessageStreamStatus(str, Enum):
done = "[DONE]"

View File

@@ -0,0 +1,53 @@
from datetime import datetime
from typing import Optional, Union
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
from pydantic import Field
from letta.schemas.agent import AgentStepState
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.llm_config import LLMConfig
class LLMBatchItem(OrmMetadataBase, validate_assignment=True):
    """
    Represents a single agent's LLM request within a batch.

    This object captures the configuration, execution status, and eventual result of one agent's request within a larger LLM batch job.
    """

    __id_prefix__ = "batch_item"

    id: str = Field(..., description="The id of the batch item. Assigned by the database.")
    batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.")
    agent_id: str = Field(..., description="The id of the agent associated with this LLM request.")
    llm_config: LLMConfig = Field(..., description="The LLM configuration used for this request.")
    request_status: JobStatus = Field(..., description="The current status of the batch item request (e.g., PENDING, DONE, ERROR).")
    step_status: AgentStepStatus = Field(..., description="The current execution status of the agent step.")
    step_state: AgentStepState = Field(..., description="The serialized state for resuming execution at a later point.")
    # None until the provider returns a result for this item.
    batch_request_result: Optional[BetaMessageBatchIndividualResponse] = Field(
        None, description="The raw response received from the LLM provider for this item."
    )
class LLMBatchJob(OrmMetadataBase, validate_assignment=True):
    """
    Represents a single LLM batch request made to a provider like Anthropic.

    Each job corresponds to one API call that sends multiple messages to the LLM provider, and aggregates responses across all agent submissions.
    """

    __id_prefix__ = "batch_req"

    id: str = Field(..., description="The id of the batch job. Assigned by the database.")
    status: JobStatus = Field(..., description="The current status of the batch (e.g., created, in_progress, done).")
    llm_provider: ProviderType = Field(..., description="The LLM provider used for the batch (e.g., anthropic, openai).")
    create_batch_response: BetaMessageBatch = Field(..., description="The full JSON response from the initial batch creation.")
    # None until the first poll of the provider completes.
    latest_polling_response: Optional[BetaMessageBatch] = Field(
        None, description="The most recent polling response received from the LLM provider."
    )
    last_polled_at: Optional[datetime] = Field(None, description="The timestamp of the last polling check for the batch status.")

View File

@@ -81,6 +81,7 @@ from letta.services.block_manager import BlockManager
from letta.services.group_manager import GroupManager
from letta.services.identity_manager import IdentityManager
from letta.services.job_manager import JobManager
from letta.services.llm_batch_manager import LLMBatchManager
from letta.services.message_manager import MessageManager
from letta.services.organization_manager import OrganizationManager
from letta.services.passage_manager import PassageManager
@@ -207,6 +208,7 @@ class SyncServer(Server):
self.step_manager = StepManager()
self.identity_manager = IdentityManager()
self.group_manager = GroupManager()
self.batch_manager = LLMBatchManager()
# Make default user and org
if init_with_default_org_and_user:

View File

@@ -0,0 +1,139 @@
import datetime
from typing import Optional
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
from letta.log import get_logger
from letta.orm.llm_batch_items import LLMBatchItem
from letta.orm.llm_batch_job import LLMBatchJob
from letta.schemas.agent import AgentStepState
from letta.schemas.enums import AgentStepStatus, JobStatus
from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
from letta.schemas.llm_config import LLMConfig
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
logger = get_logger(__name__)
class LLMBatchManager:
    """Manager for handling both LLMBatchJob and LLMBatchItem operations."""

    def __init__(self):
        # Imported here (not at module top) — presumably to avoid a circular
        # import with the server package; confirm before hoisting.
        from letta.server.db import db_context

        # Context-manager factory that yields a database session.
        self.session_maker = db_context

    @enforce_types
    def create_batch_request(
        self,
        llm_provider: str,
        create_batch_response: BetaMessageBatch,
        actor: PydanticUser,
        status: JobStatus = JobStatus.created,
    ) -> PydanticLLMBatchJob:
        """Create a new LLM batch job.

        Args:
            llm_provider: Name of the provider the batch was submitted to.
            create_batch_response: Raw provider response from batch creation.
            actor: User performing the operation; supplies the organization id.
            status: Initial job status (defaults to ``created``).

        Returns:
            The persisted batch job as its Pydantic schema.
        """
        with self.session_maker() as session:
            batch = LLMBatchJob(
                status=status,
                llm_provider=llm_provider,
                create_batch_response=create_batch_response,
                organization_id=actor.organization_id,
            )
            batch.create(session, actor=actor)
            return batch.to_pydantic()

    @enforce_types
    def get_batch_request_by_id(self, batch_id: str, actor: PydanticUser) -> PydanticLLMBatchJob:
        """Retrieve a single batch job by ID."""
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            return batch.to_pydantic()

    @enforce_types
    def update_batch_status(
        self,
        batch_id: str,
        status: JobStatus,
        actor: PydanticUser,
        latest_polling_response: Optional[BetaMessageBatch] = None,
    ) -> PydanticLLMBatchJob:
        """Update a batch job's status and optionally its polling response.

        Also stamps ``last_polled_at`` with the current UTC time.
        """
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            batch.status = status
            # NOTE(review): latest_polling_response is overwritten even when the
            # argument is None (clearing any stored response) — confirm intended.
            batch.latest_polling_response = latest_polling_response
            batch.last_polled_at = datetime.datetime.now(datetime.timezone.utc)
            return batch.update(db_session=session, actor=actor).to_pydantic()

    @enforce_types
    def delete_batch_request(self, batch_id: str, actor: PydanticUser) -> None:
        """Hard delete a batch job by ID."""
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            batch.hard_delete(db_session=session, actor=actor)

    @enforce_types
    def create_batch_item(
        self,
        batch_id: str,
        agent_id: str,
        llm_config: LLMConfig,
        actor: PydanticUser,
        request_status: JobStatus = JobStatus.created,
        step_status: AgentStepStatus = AgentStepStatus.paused,
        step_state: Optional[AgentStepState] = None,
    ) -> PydanticLLMBatchItem:
        """Create a new batch item.

        Args:
            batch_id: Parent batch job id.
            agent_id: Agent whose request this item represents.
            llm_config: LLM configuration for this request.
            actor: User performing the operation; supplies the organization id.
            request_status: Initial request status (defaults to ``created``).
            step_status: Initial agent-step status (defaults to ``paused``).
            step_state: Optional serialized state for resuming the step later.

        Returns:
            The persisted batch item as its Pydantic schema.
        """
        with self.session_maker() as session:
            item = LLMBatchItem(
                batch_id=batch_id,
                agent_id=agent_id,
                llm_config=llm_config,
                request_status=request_status,
                step_status=step_status,
                step_state=step_state,
                organization_id=actor.organization_id,
            )
            item.create(session, actor=actor)
            return item.to_pydantic()

    @enforce_types
    def get_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
        """Retrieve a single batch item by ID."""
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
            return item.to_pydantic()

    @enforce_types
    def update_batch_item(
        self,
        item_id: str,
        actor: PydanticUser,
        request_status: Optional[JobStatus] = None,
        step_status: Optional[AgentStepStatus] = None,
        llm_request_response: Optional[BetaMessageBatchIndividualResponse] = None,
        step_state: Optional[AgentStepState] = None,
    ) -> PydanticLLMBatchItem:
        """Update fields on a batch item; only truthy (non-None) arguments are applied."""
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
            if request_status:
                item.request_status = request_status
            if step_status:
                item.step_status = step_status
            if llm_request_response:
                item.batch_request_result = llm_request_response
            if step_state:
                item.step_state = step_state
            return item.update(db_session=session, actor=actor).to_pydantic()

    @enforce_types
    def delete_batch_item(self, item_id: str, actor: PydanticUser) -> None:
        """Hard delete a batch item by ID."""
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
            item.hard_delete(db_session=session, actor=actor)

View File

@@ -2,10 +2,17 @@ import os
import random
import string
import time
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import List
import pytest
from anthropic.types.beta import BetaMessage
from anthropic.types.beta.messages import (
BetaMessageBatch,
BetaMessageBatchIndividualResponse,
BetaMessageBatchRequestCounts,
BetaMessageBatchSucceededResult,
)
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from sqlalchemy.exc import IntegrityError
@@ -23,15 +30,16 @@ from letta.constants import (
from letta.embeddings import embedding_model
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.functions.mcp_client.types import MCPTool
from letta.helpers import ToolRulesSolver
from letta.orm import Base, Block
from letta.orm.block_history import BlockHistory
from letta.orm.enums import ActorType, JobType, ToolType
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.agent import AgentStepState, CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate
@@ -570,6 +578,62 @@ def agent_with_tags(server: SyncServer, default_user):
return [agent1, agent2, agent3]
@pytest.fixture
def dummy_beta_message_batch() -> BetaMessageBatch:
    """An in-progress Anthropic message batch with fixed timestamps and counts."""
    # Every timestamp field uses the same fixed UTC instant.
    ts = datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc)
    counts = BetaMessageBatchRequestCounts(
        canceled=10,
        errored=30,
        expired=10,
        processing=100,
        succeeded=50,
    )
    return BetaMessageBatch(
        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
        archived_at=ts,
        cancel_initiated_at=ts,
        created_at=ts,
        ended_at=ts,
        expires_at=ts,
        processing_status="in_progress",
        request_counts=counts,
        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
        type="message_batch",
    )
@pytest.fixture
def dummy_llm_config() -> LLMConfig:
    """Default GPT-4 LLM configuration used by the batch tests."""
    config = LLMConfig.default_config("gpt-4")
    return config
@pytest.fixture
def dummy_tool_rules_solver() -> ToolRulesSolver:
    """Solver seeded with a single init rule for the ``send_message`` tool."""
    rules = [InitToolRule(tool_name="send_message")]
    return ToolRulesSolver(tool_rules=rules)
@pytest.fixture
def dummy_step_state(dummy_tool_rules_solver: ToolRulesSolver) -> AgentStepState:
    """Agent step state at step 1, wrapping the shared tool-rules solver."""
    return AgentStepState(
        step_number=1,
        tool_rules_solver=dummy_tool_rules_solver,
    )
@pytest.fixture
def dummy_successful_response() -> BetaMessageBatchIndividualResponse:
    """A per-request batch result whose request succeeded with a short reply."""
    reply = BetaMessage(
        id="msg_abc123",
        role="assistant",
        type="message",
        model="claude-3-5-sonnet-20240620",
        content=[{"type": "text", "text": "hi!"}],
        usage={"input_tokens": 5, "output_tokens": 7},
        stop_reason="end_turn",
    )
    return BetaMessageBatchIndividualResponse(
        custom_id="my-second-request",
        result=BetaMessageBatchSucceededResult(type="succeeded", message=reply),
    )
# ======================================================================================================================
# AgentManager Tests - Basic
# ======================================================================================================================
@@ -4614,3 +4678,123 @@ def test_list_tags(server: SyncServer, default_user, default_organization):
# Cleanup
for agent in agents:
server.agent_manager.delete_agent(agent.id, actor=default_user)
# ======================================================================================================================
# LLMBatchManager Tests
# ======================================================================================================================
def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch):
    """A created batch request gets a prefixed ID and round-trips through get."""
    created = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )
    # Batch request IDs carry a recognizable prefix.
    assert created.id.startswith("batch_req-")
    assert created.create_batch_response == dummy_beta_message_batch

    retrieved = server.batch_manager.get_batch_request_by_id(created.id, actor=default_user)
    assert retrieved.id == created.id
def test_update_batch_status(server, default_user, dummy_beta_message_batch):
    """Updating status also records the polling response and a poll timestamp."""
    batch = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )

    poll_marker = datetime.now(timezone.utc)
    server.batch_manager.update_batch_status(
        batch_id=batch.id,
        status=JobStatus.completed,
        latest_polling_response=dummy_beta_message_batch,
        actor=default_user,
    )

    refreshed = server.batch_manager.get_batch_request_by_id(batch.id, actor=default_user)
    assert refreshed.status == JobStatus.completed
    assert refreshed.latest_polling_response == dummy_beta_message_batch
    # last_polled_at is stamped during the update, so it cannot precede our marker.
    assert refreshed.last_polled_at >= poll_marker
def test_create_and_get_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
    """A batch item links back to its batch/agent and round-trips through get."""
    parent_batch = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )

    created_item = server.batch_manager.create_batch_item(
        batch_id=parent_batch.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,
    )
    assert created_item.batch_id == parent_batch.id
    assert created_item.agent_id == sarah_agent.id
    assert created_item.step_state == dummy_step_state

    retrieved = server.batch_manager.get_batch_item_by_id(created_item.id, actor=default_user)
    assert retrieved.id == created_item.id
def test_update_batch_item(
    server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, dummy_successful_response
):
    """update_batch_item should persist every field it accepts.

    The test passes request_status, step_status, the provider response, and a
    new step state, then verifies all four were written back.
    """
    batch = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )
    item = server.batch_manager.create_batch_item(
        batch_id=batch.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,
    )

    updated_step_state = AgentStepState(step_number=2, tool_rules_solver=dummy_step_state.tool_rules_solver)

    server.batch_manager.update_batch_item(
        item_id=item.id,
        request_status=JobStatus.completed,
        step_status=AgentStepStatus.running,
        llm_request_response=dummy_successful_response,
        step_state=updated_step_state,
        actor=default_user,
    )

    updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
    assert updated.request_status == JobStatus.completed
    assert updated.batch_request_result == dummy_successful_response
    # Previously unchecked: the step fields must also be persisted.
    assert updated.step_status == AgentStepStatus.running
    assert updated.step_state == updated_step_state
def test_delete_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
    """Deleting a batch item makes subsequent lookups raise NoResultFound."""
    parent_batch = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )
    doomed_item = server.batch_manager.create_batch_item(
        batch_id=parent_batch.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,
    )

    server.batch_manager.delete_batch_item(item_id=doomed_item.id, actor=default_user)

    # The item is gone for good; reads must fail.
    with pytest.raises(NoResultFound):
        server.batch_manager.get_batch_item_by_id(doomed_item.id, actor=default_user)