diff --git a/alembic/versions/0ceb975e0063_add_llm_batch_jobs_tables.py b/alembic/versions/0ceb975e0063_add_llm_batch_jobs_tables.py new file mode 100644 index 00000000..fee45f75 --- /dev/null +++ b/alembic/versions/0ceb975e0063_add_llm_batch_jobs_tables.py @@ -0,0 +1,86 @@ +"""Add LLM batch jobs tables + +Revision ID: 0ceb975e0063 +Revises: 90bb156e71df +Create Date: 2025-04-07 15:57:18.475151 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +import letta +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0ceb975e0063" +down_revision: Union[str, None] = "90bb156e71df" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "llm_batch_job", + sa.Column("id", sa.String(), nullable=False), + sa.Column("status", sa.String(), nullable=False), + sa.Column("llm_provider", sa.String(), nullable=False), + sa.Column("create_batch_response", letta.orm.custom_columns.CreateBatchResponseColumn(), nullable=False), + sa.Column("latest_polling_response", letta.orm.custom_columns.PollBatchResponseColumn(), nullable=True), + sa.Column("last_polled_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_llm_batch_job_created_at", "llm_batch_job", ["created_at"], 
unique=False) + op.create_index("ix_llm_batch_job_status", "llm_batch_job", ["status"], unique=False) + op.create_table( + "llm_batch_items", + sa.Column("id", sa.String(), nullable=False), + sa.Column("batch_id", sa.String(), nullable=False), + sa.Column("llm_config", letta.orm.custom_columns.LLMConfigColumn(), nullable=False), + sa.Column("request_status", sa.String(), nullable=False), + sa.Column("step_status", sa.String(), nullable=False), + sa.Column("step_state", letta.orm.custom_columns.AgentStepStateColumn(), nullable=False), + sa.Column("batch_request_result", letta.orm.custom_columns.BatchRequestResultColumn(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("agent_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["batch_id"], ["llm_batch_job.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("ix_llm_batch_items_agent_id", "llm_batch_items", ["agent_id"], unique=False) + op.create_index("ix_llm_batch_items_batch_id", "llm_batch_items", ["batch_id"], unique=False) + op.create_index("ix_llm_batch_items_status", "llm_batch_items", ["request_status"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index("ix_llm_batch_items_status", table_name="llm_batch_items") + op.drop_index("ix_llm_batch_items_batch_id", table_name="llm_batch_items") + op.drop_index("ix_llm_batch_items_agent_id", table_name="llm_batch_items") + op.drop_table("llm_batch_items") + op.drop_index("ix_llm_batch_job_status", table_name="llm_batch_job") + op.drop_index("ix_llm_batch_job_created_at", table_name="llm_batch_job") + op.drop_table("llm_batch_job") + # ### end Alembic commands ### diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index 4f0510de..6ffe25fb 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -2,12 +2,14 @@ import base64 from typing import Any, Dict, List, Optional, Union import numpy as np +from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction from sqlalchemy import Dialect +from letta.schemas.agent import AgentStepState from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import ToolRuleType +from letta.schemas.enums import ProviderType, ToolRuleType from letta.schemas.letta_message_content import ( MessageContent, MessageContentType, @@ -38,7 +40,7 @@ from letta.schemas.tool_rule import ( def serialize_llm_config(config: Union[Optional[LLMConfig], Dict]) -> Optional[Dict]: """Convert an LLMConfig object into a JSON-serializable dictionary.""" if config and isinstance(config, LLMConfig): - return config.model_dump() + return config.model_dump(mode="json") return config @@ -55,7 +57,7 @@ def deserialize_llm_config(data: Optional[Dict]) -> Optional[LLMConfig]: def serialize_embedding_config(config: Union[Optional[EmbeddingConfig], Dict]) -> Optional[Dict]: """Convert an EmbeddingConfig object into a JSON-serializable dictionary.""" if config 
and isinstance(config, EmbeddingConfig): - return config.model_dump() + return config.model_dump(mode="json") return config @@ -75,7 +77,9 @@ def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str, if not tool_rules: return [] - data = [{**rule.model_dump(), "type": rule.type.value} for rule in tool_rules] # Convert Enum to string for JSON compatibility + data = [ + {**rule.model_dump(mode="json"), "type": rule.type.value} for rule in tool_rules + ] # Convert Enum to string for JSON compatibility # Validate ToolRule structure for rule_data in data: @@ -130,7 +134,7 @@ def serialize_tool_calls(tool_calls: Optional[List[Union[OpenAIToolCall, dict]]] serialized_calls = [] for call in tool_calls: if isinstance(call, OpenAIToolCall): - serialized_calls.append(call.model_dump()) + serialized_calls.append(call.model_dump(mode="json")) elif isinstance(call, dict): serialized_calls.append(call) # Already a dictionary, leave it as-is else: @@ -166,7 +170,7 @@ def serialize_tool_returns(tool_returns: Optional[List[Union[ToolReturn, dict]]] serialized_tool_returns = [] for tool_return in tool_returns: if isinstance(tool_return, ToolReturn): - serialized_tool_returns.append(tool_return.model_dump()) + serialized_tool_returns.append(tool_return.model_dump(mode="json")) elif isinstance(tool_return, dict): serialized_tool_returns.append(tool_return) # Already a dictionary, leave it as-is else: @@ -201,7 +205,7 @@ def serialize_message_content(message_content: Optional[List[Union[MessageConten serialized_message_content = [] for content in message_content: if isinstance(content, MessageContent): - serialized_message_content.append(content.model_dump()) + serialized_message_content.append(content.model_dump(mode="json")) elif isinstance(content, dict): serialized_message_content.append(content) # Already a dictionary, leave it as-is else: @@ -266,3 +270,101 @@ def deserialize_vector(data: Optional[bytes], dialect: Dialect) -> Optional[np.n data = 
base64.b64decode(data) + return np.frombuffer(data, dtype=np.float32) + + +# -------------------------- +# Batch Request Serialization +# -------------------------- + + +def serialize_create_batch_response(create_batch_response: Union[BetaMessageBatch]) -> Dict[str, Any]: + """Convert a provider's create-batch response into a JSON-serializable dict tagged with the provider type.""" + llm_provider_type = None + if isinstance(create_batch_response, BetaMessageBatch): + llm_provider_type = ProviderType.anthropic.value + + if not llm_provider_type: + raise ValueError(f"Could not determine llm provider from create batch response object type: {create_batch_response}") + + return {"data": create_batch_response.model_dump(mode="json"), "type": llm_provider_type} + + +def deserialize_create_batch_response(data: Dict) -> Union[BetaMessageBatch]: + provider_type = ProviderType(data.get("type")) + + if provider_type == ProviderType.anthropic: + return BetaMessageBatch(**data.get("data")) + + raise ValueError(f"Unknown ProviderType type: {provider_type}") + + +# TODO: Note that this is the same as above for Anthropic, but this is not the case for all providers +# TODO: Some have different types based on the create v.s.
poll requests +def serialize_poll_batch_response(poll_batch_response: Optional[Union[BetaMessageBatch]]) -> Optional[Dict[str, Any]]: + """Convert a provider's poll-batch response into a JSON-serializable dict tagged with the provider type.""" + if not poll_batch_response: + return None + + llm_provider_type = None + if isinstance(poll_batch_response, BetaMessageBatch): + llm_provider_type = ProviderType.anthropic.value + + if not llm_provider_type: + raise ValueError(f"Could not determine llm provider from poll batch response object type: {poll_batch_response}") + + return {"data": poll_batch_response.model_dump(mode="json"), "type": llm_provider_type} + + +def deserialize_poll_batch_response(data: Optional[Dict]) -> Optional[Union[BetaMessageBatch]]: + if not data: + return None + + provider_type = ProviderType(data.get("type")) + + if provider_type == ProviderType.anthropic: + return BetaMessageBatch(**data.get("data")) + + raise ValueError(f"Unknown ProviderType type: {provider_type}") + + +def serialize_batch_request_result( + batch_individual_response: Optional[Union[BetaMessageBatchIndividualResponse]], +) -> Optional[Dict[str, Any]]: + """Convert a provider's per-item batch result into a JSON-serializable dict tagged with the provider type.""" + if not batch_individual_response: + return None + + llm_provider_type = None + if isinstance(batch_individual_response, BetaMessageBatchIndividualResponse): + llm_provider_type = ProviderType.anthropic.value + + if not llm_provider_type: + raise ValueError(f"Could not determine llm provider from batch result object type: {batch_individual_response}") + + return {"data": batch_individual_response.model_dump(mode="json"), "type": llm_provider_type} + + +def deserialize_batch_request_result(data: Optional[Dict]) -> Optional[Union[BetaMessageBatchIndividualResponse]]: + if not data: + return None + provider_type = ProviderType(data.get("type")) + + if provider_type == ProviderType.anthropic: + return BetaMessageBatchIndividualResponse(**data.get("data")) + + raise ValueError(f"Unknown ProviderType
type: {provider_type}") + + +def serialize_agent_step_state(agent_step_state: Optional[AgentStepState]) -> Optional[Dict[str, Any]]: + """Convert an AgentStepState into a JSON-serializable dict.""" + if not agent_step_state: + return None + + return agent_step_state.model_dump(mode="json") + + +def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepState]: + if not data: + return None + + return AgentStepState(**data) diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index a4b6a655..b0ff1d79 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -38,29 +38,46 @@ class ToolRulesSolver(BaseModel): ) tool_call_history: List[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.") - def __init__(self, tool_rules: List[BaseToolRule], **kwargs): - super().__init__(**kwargs) - # Separate the provided tool rules into init, standard, and terminal categories - for rule in tool_rules: - if rule.type == ToolRuleType.run_first: - assert isinstance(rule, InitToolRule) - self.init_tool_rules.append(rule) - elif rule.type == ToolRuleType.constrain_child_tools: - assert isinstance(rule, ChildToolRule) - self.child_based_tool_rules.append(rule) - elif rule.type == ToolRuleType.conditional: - assert isinstance(rule, ConditionalToolRule) - self.validate_conditional_tool(rule) - self.child_based_tool_rules.append(rule) - elif rule.type == ToolRuleType.exit_loop: - assert isinstance(rule, TerminalToolRule) - self.terminal_tool_rules.append(rule) - elif rule.type == ToolRuleType.continue_loop: - assert isinstance(rule, ContinueToolRule) - self.continue_tool_rules.append(rule) - elif rule.type == ToolRuleType.max_count_per_step: - assert isinstance(rule, MaxCountPerStepToolRule) - self.child_based_tool_rules.append(rule) + def __init__( + self, + tool_rules: Optional[List[BaseToolRule]] = None, + init_tool_rules: Optional[List[InitToolRule]] = None, + 
continue_tool_rules: Optional[List[ContinueToolRule]] = None, + child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None, + terminal_tool_rules: Optional[List[TerminalToolRule]] = None, + tool_call_history: Optional[List[str]] = None, + **kwargs, + ): + super().__init__( + init_tool_rules=init_tool_rules or [], + continue_tool_rules=continue_tool_rules or [], + child_based_tool_rules=child_based_tool_rules or [], + terminal_tool_rules=terminal_tool_rules or [], + tool_call_history=tool_call_history or [], + **kwargs, + ) + + if tool_rules: + for rule in tool_rules: + if rule.type == ToolRuleType.run_first: + assert isinstance(rule, InitToolRule) + self.init_tool_rules.append(rule) + elif rule.type == ToolRuleType.constrain_child_tools: + assert isinstance(rule, ChildToolRule) + self.child_based_tool_rules.append(rule) + elif rule.type == ToolRuleType.conditional: + assert isinstance(rule, ConditionalToolRule) + self.validate_conditional_tool(rule) + self.child_based_tool_rules.append(rule) + elif rule.type == ToolRuleType.exit_loop: + assert isinstance(rule, TerminalToolRule) + self.terminal_tool_rules.append(rule) + elif rule.type == ToolRuleType.continue_loop: + assert isinstance(rule, ContinueToolRule) + self.continue_tool_rules.append(rule) + elif rule.type == ToolRuleType.max_count_per_step: + assert isinstance(rule, MaxCountPerStepToolRule) + self.child_based_tool_rules.append(rule) def register_tool_call(self, tool_name: str): """Update the internal state to track tool call history.""" diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 02af8304..348cd19e 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -13,6 +13,8 @@ from letta.orm.identities_blocks import IdentitiesBlocks from letta.orm.identity import Identity from letta.orm.job import Job from letta.orm.job_messages import JobMessage +from letta.orm.llm_batch_items import LLMBatchItem +from letta.orm.llm_batch_job 
import LLMBatchJob from letta.orm.message import Message from letta.orm.organization import Organization from letta.orm.passage import AgentPassage, BasePassage, SourcePassage diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 39a549ac..689ab8af 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -144,6 +144,7 @@ class Agent(SqlalchemyBase, OrganizationMixin): viewonly=True, back_populates="manager_agent", ) + batch_items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="agent", lazy="selectin") def to_pydantic(self, include_relationships: Optional[Set[str]] = None) -> PydanticAgentState: """ diff --git a/letta/orm/custom_columns.py b/letta/orm/custom_columns.py index 2f9150d5..77346406 100644 --- a/letta/orm/custom_columns.py +++ b/letta/orm/custom_columns.py @@ -2,16 +2,24 @@ from sqlalchemy import JSON from sqlalchemy.types import BINARY, TypeDecorator from letta.helpers.converters import ( + deserialize_agent_step_state, + deserialize_batch_request_result, + deserialize_create_batch_response, deserialize_embedding_config, deserialize_llm_config, deserialize_message_content, + deserialize_poll_batch_response, deserialize_tool_calls, deserialize_tool_returns, deserialize_tool_rules, deserialize_vector, + serialize_agent_step_state, + serialize_batch_request_result, + serialize_create_batch_response, serialize_embedding_config, serialize_llm_config, serialize_message_content, + serialize_poll_batch_response, serialize_tool_calls, serialize_tool_returns, serialize_tool_rules, @@ -108,3 +116,55 @@ class CommonVector(TypeDecorator): def process_result_value(self, value, dialect): return deserialize_vector(value, dialect) + + +class CreateBatchResponseColumn(TypeDecorator): + """Custom SQLAlchemy column type for storing a list of ToolRules as JSON.""" + + impl = JSON + cache_ok = True + + def process_bind_param(self, value, dialect): + return serialize_create_batch_response(value) + + def process_result_value(self, value, 
dialect): + return deserialize_create_batch_response(value) + + +class PollBatchResponseColumn(TypeDecorator): + """Custom SQLAlchemy column type for storing a list of ToolRules as JSON.""" + + impl = JSON + cache_ok = True + + def process_bind_param(self, value, dialect): + return serialize_poll_batch_response(value) + + def process_result_value(self, value, dialect): + return deserialize_poll_batch_response(value) + + +class BatchRequestResultColumn(TypeDecorator): + """Custom SQLAlchemy column type for storing a list of ToolRules as JSON.""" + + impl = JSON + cache_ok = True + + def process_bind_param(self, value, dialect): + return serialize_batch_request_result(value) + + def process_result_value(self, value, dialect): + return deserialize_batch_request_result(value) + + +class AgentStepStateColumn(TypeDecorator): + """Custom SQLAlchemy column type for storing a list of ToolRules as JSON.""" + + impl = JSON + cache_ok = True + + def process_bind_param(self, value, dialect): + return serialize_agent_step_state(value) + + def process_result_value(self, value, dialect): + return deserialize_agent_step_state(value) diff --git a/letta/orm/llm_batch_items.py b/letta/orm/llm_batch_items.py new file mode 100644 index 00000000..e11de396 --- /dev/null +++ b/letta/orm/llm_batch_items.py @@ -0,0 +1,55 @@ +import uuid +from typing import Optional, Union + +from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse +from sqlalchemy import ForeignKey, Index, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.custom_columns import AgentStepStateColumn, BatchRequestResultColumn, LLMConfigColumn +from letta.orm.mixins import AgentMixin, OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.agent import AgentStepState +from letta.schemas.enums import AgentStepStatus, JobStatus +from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem +from letta.schemas.llm_config 
import LLMConfig + + +class LLMBatchItem(SqlalchemyBase, OrganizationMixin, AgentMixin): + """Represents a single agent's LLM request within a batch""" + + __tablename__ = "llm_batch_items" + __pydantic_model__ = PydanticLLMBatchItem + __table_args__ = ( + Index("ix_llm_batch_items_batch_id", "batch_id"), + Index("ix_llm_batch_items_agent_id", "agent_id"), + Index("ix_llm_batch_items_status", "request_status"), + ) + + # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase + # TODO: Some still rely on the Pydantic object to do this + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"batch_item-{uuid.uuid4()}") + + batch_id: Mapped[str] = mapped_column( + ForeignKey("llm_batch_job.id", ondelete="CASCADE"), doc="Foreign key to the LLM provider batch this item belongs to" + ) + + llm_config: Mapped[LLMConfig] = mapped_column(LLMConfigColumn, nullable=False, doc="LLM configuration specific to this request") + + request_status: Mapped[JobStatus] = mapped_column( + String, default=JobStatus.created, doc="Status of the LLM request in the batch (PENDING, SUBMITTED, DONE, ERROR)" + ) + + step_status: Mapped[AgentStepStatus] = mapped_column(String, default=AgentStepStatus.paused, doc="Status of the agent's step execution") + + step_state: Mapped[AgentStepState] = mapped_column( + AgentStepStateColumn, doc="Execution metadata for resuming the agent step (e.g., tool call ID, timestamps)" + ) + + batch_request_result: Mapped[Optional[Union[BetaMessageBatchIndividualResponse]]] = mapped_column( + BatchRequestResultColumn, nullable=True, doc="Raw JSON response from the LLM for this item" + ) + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_items") + batch: Mapped["LLMBatchJob"] = relationship("LLMBatchJob", back_populates="items", lazy="selectin") + agent: Mapped["Agent"] = relationship("Agent", back_populates="batch_items", 
lazy="selectin") diff --git a/letta/orm/llm_batch_job.py b/letta/orm/llm_batch_job.py new file mode 100644 index 00000000..d86a7564 --- /dev/null +++ b/letta/orm/llm_batch_job.py @@ -0,0 +1,48 @@ +import uuid +from datetime import datetime +from typing import List, Optional, Union + +from anthropic.types.beta.messages import BetaMessageBatch +from sqlalchemy import DateTime, Index, String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.custom_columns import CreateBatchResponseColumn, PollBatchResponseColumn +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.enums import JobStatus, ProviderType +from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob + + +class LLMBatchJob(SqlalchemyBase, OrganizationMixin): + """Represents a single LLM batch request made to a provider like Anthropic""" + + __tablename__ = "llm_batch_job" + __table_args__ = ( + Index("ix_llm_batch_job_created_at", "created_at"), + Index("ix_llm_batch_job_status", "status"), + ) + + __pydantic_model__ = PydanticLLMBatchJob + + # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase + # TODO: Some still rely on the Pydantic object to do this + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"batch_req-{uuid.uuid4()}") + + status: Mapped[JobStatus] = mapped_column(String, default=JobStatus.created, doc="The current status of the batch.") + + llm_provider: Mapped[ProviderType] = mapped_column(String, doc="LLM provider used (e.g., 'Anthropic')") + + create_batch_response: Mapped[Union[BetaMessageBatch]] = mapped_column( + CreateBatchResponseColumn, doc="Full JSON response from initial batch creation" + ) + latest_polling_response: Mapped[Union[BetaMessageBatch]] = mapped_column( + PollBatchResponseColumn, nullable=True, doc="Last known polling result from LLM provider" + ) + + last_polled_at: 
Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), nullable=True, doc="Last time we polled the provider for status" + ) + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_batch_jobs") + items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="batch", lazy="selectin") diff --git a/letta/orm/organization.py b/letta/orm/organization.py index ebd66f80..df0ea75e 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -51,6 +51,8 @@ class Organization(SqlalchemyBase): providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan") identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan") groups: Mapped[List["Group"]] = relationship("Group", back_populates="organization", cascade="all, delete-orphan") + llm_batch_jobs: Mapped[List["LLMBatchJob"]] = relationship("LLMBatchJob", back_populates="organization", cascade="all, delete-orphan") + llm_batch_items: Mapped[List["LLMBatchItem"]] = relationship("LLMBatchItem", back_populates="organization", cascade="all, delete-orphan") @property def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]: diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index a5476cf6..f0e3aa7d 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional from pydantic import BaseModel, Field, field_validator from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE +from letta.helpers import ToolRulesSolver from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.environment_variables import AgentEnvironmentVariable @@ -271,3 +272,8 @@ class AgentStepResponse(BaseModel): ..., description="Whether the agent step ended because the in-context memory is near its limit."
) + usage: UsageStatistics = Field(..., description="Usage statistics of the LLM call during the agent's step.") + + +class AgentStepState(BaseModel): + step_number: int = Field(..., description="The current step number in the agent loop") + tool_rules_solver: ToolRulesSolver = Field(..., description="The current state of the ToolRulesSolver") diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 022846d5..c02e1438 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -1,6 +1,10 @@ from enum import Enum +class ProviderType(str, Enum): + anthropic = "anthropic" + + class MessageRole(str, Enum): assistant = "assistant" user = "user" @@ -22,6 +26,7 @@ class JobStatus(str, Enum): Status of the job. """ + not_started = "not_started" created = "created" running = "running" completed = "completed" @@ -29,6 +34,15 @@ class JobStatus(str, Enum): pending = "pending" +class AgentStepStatus(str, Enum): + """ + Status of an agent execution step. + """ + + paused = "paused" + running = "running" + + class MessageStreamStatus(str, Enum): done = "[DONE]" diff --git a/letta/schemas/llm_batch_job.py b/letta/schemas/llm_batch_job.py new file mode 100644 index 00000000..7f21fc27 --- /dev/null +++ b/letta/schemas/llm_batch_job.py @@ -0,0 +1,53 @@ +from datetime import datetime +from typing import Optional, Union + +from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse +from pydantic import Field + +from letta.schemas.agent import AgentStepState +from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType +from letta.schemas.letta_base import OrmMetadataBase +from letta.schemas.llm_config import LLMConfig + + +class LLMBatchItem(OrmMetadataBase, validate_assignment=True): + """ + Represents a single agent's LLM request within a batch. + + This object captures the configuration, execution status, and eventual result of one agent's request within a larger LLM batch job.
+ """ + + __id_prefix__ = "batch_item" + + id: str = Field(..., description="The id of the batch item. Assigned by the database.") + batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.") + agent_id: str = Field(..., description="The id of the agent associated with this LLM request.") + + llm_config: LLMConfig = Field(..., description="The LLM configuration used for this request.") + request_status: JobStatus = Field(..., description="The current status of the batch item request (e.g., PENDING, DONE, ERROR).") + step_status: AgentStepStatus = Field(..., description="The current execution status of the agent step.") + step_state: AgentStepState = Field(..., description="The serialized state for resuming execution at a later point.") + + batch_request_result: Optional[Union[BetaMessageBatchIndividualResponse]] = Field( + None, description="The raw response received from the LLM provider for this item." + ) + + +class LLMBatchJob(OrmMetadataBase, validate_assignment=True): + """ + Represents a single LLM batch request made to a provider like Anthropic. + + Each job corresponds to one API call that sends multiple messages to the LLM provider, and aggregates responses across all agent submissions. + """ + + __id_prefix__ = "batch_req" + + id: str = Field(..., description="The id of the batch job. Assigned by the database.") + status: JobStatus = Field(..., description="The current status of the batch (e.g., created, in_progress, done).") + llm_provider: ProviderType = Field(..., description="The LLM provider used for the batch (e.g., anthropic, openai).") + + create_batch_response: Union[BetaMessageBatch] = Field(..., description="The full JSON response from the initial batch creation.") + latest_polling_response: Optional[Union[BetaMessageBatch]] = Field( + None, description="The most recent polling response received from the LLM provider." 
+ ) + last_polled_at: Optional[datetime] = Field(None, description="The timestamp of the last polling check for the batch status.") diff --git a/letta/server/server.py b/letta/server/server.py index 39e3ad25..9793fec7 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -81,6 +81,7 @@ from letta.services.block_manager import BlockManager from letta.services.group_manager import GroupManager from letta.services.identity_manager import IdentityManager from letta.services.job_manager import JobManager +from letta.services.llm_batch_manager import LLMBatchManager from letta.services.message_manager import MessageManager from letta.services.organization_manager import OrganizationManager from letta.services.passage_manager import PassageManager @@ -207,6 +208,7 @@ class SyncServer(Server): self.step_manager = StepManager() self.identity_manager = IdentityManager() self.group_manager = GroupManager() + self.batch_manager = LLMBatchManager() # Make default user and org if init_with_default_org_and_user: diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py new file mode 100644 index 00000000..b538e549 --- /dev/null +++ b/letta/services/llm_batch_manager.py @@ -0,0 +1,139 @@ +import datetime +from typing import Optional + +from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse + +from letta.log import get_logger +from letta.orm.llm_batch_items import LLMBatchItem +from letta.orm.llm_batch_job import LLMBatchJob +from letta.schemas.agent import AgentStepState +from letta.schemas.enums import AgentStepStatus, JobStatus +from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem +from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob +from letta.schemas.llm_config import LLMConfig +from letta.schemas.user import User as PydanticUser +from letta.utils import enforce_types + +logger = get_logger(__name__) + + +class LLMBatchManager: + """Manager for 
handling both LLMBatchJob and LLMBatchItem operations."""

    def __init__(self):
        # Deferred import: letta.server.db pulls in the ORM layer, so importing
        # it at module load time would create a circular import.
        from letta.server.db import db_context

        self.session_maker = db_context

    @enforce_types
    def create_batch_request(
        self,
        llm_provider: str,
        create_batch_response: BetaMessageBatch,
        actor: PydanticUser,
        status: JobStatus = JobStatus.created,
    ) -> PydanticLLMBatchJob:
        """Create a new LLM batch job.

        Args:
            llm_provider: Provider identifier string (e.g. "anthropic").
            create_batch_response: Raw provider response returned when the batch
                was submitted; persisted verbatim on the row.
            actor: User performing the operation; the row is scoped to
                ``actor.organization_id``.
            status: Initial job status (defaults to ``JobStatus.created``).

        Returns:
            The created job converted to its Pydantic schema.
        """
        with self.session_maker() as session:
            batch = LLMBatchJob(
                status=status,
                llm_provider=llm_provider,
                create_batch_response=create_batch_response,
                organization_id=actor.organization_id,
            )
            batch.create(session, actor=actor)
            return batch.to_pydantic()

    @enforce_types
    def get_batch_request_by_id(self, batch_id: str, actor: PydanticUser) -> PydanticLLMBatchJob:
        """Retrieve a single batch job by ID.

        Propagates whatever ``LLMBatchJob.read`` raises when the row is missing
        or not visible to the actor.
        """
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            return batch.to_pydantic()

    @enforce_types
    def update_batch_status(
        self,
        batch_id: str,
        status: JobStatus,
        actor: PydanticUser,
        latest_polling_response: Optional[BetaMessageBatch] = None,
    ) -> PydanticLLMBatchJob:
        """Update a batch job's status and optionally its polling response.

        Also stamps ``last_polled_at`` with the current UTC time on every call.

        NOTE(review): ``latest_polling_response`` is assigned unconditionally,
        so calling this without it clears any previously stored polling
        response — confirm that is intended.
        """
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            batch.status = status
            batch.latest_polling_response = latest_polling_response
            batch.last_polled_at = datetime.datetime.now(datetime.timezone.utc)
            return batch.update(db_session=session, actor=actor).to_pydantic()

    @enforce_types
    def delete_batch_request(self, batch_id: str, actor: PydanticUser) -> None:
        """Hard delete a batch job by ID (permanent removal, not a soft delete)."""
        with self.session_maker() as session:
            batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
            batch.hard_delete(db_session=session, actor=actor)

    @enforce_types
    def create_batch_item(
        self,
        batch_id: str,
        agent_id: str,
        llm_config: LLMConfig,
        actor: PydanticUser,
        request_status: JobStatus = JobStatus.created,
        step_status: AgentStepStatus = AgentStepStatus.paused,
        step_state: Optional[AgentStepState] = None,
    ) -> PydanticLLMBatchItem:
        """Create a new batch item tied to a parent batch job and an agent.

        Args:
            batch_id: ID of the parent LLM batch job.
            agent_id: ID of the agent this item executes on behalf of.
            llm_config: LLM configuration persisted with the item.
            actor: User performing the operation; the row is scoped to
                ``actor.organization_id``.
            request_status: Initial per-item request status.
            step_status: Initial agent step status (defaults to ``paused``).
            step_state: Optional serialized agent step state.

        Returns:
            The created item converted to its Pydantic schema.
        """
        with self.session_maker() as session:
            item = LLMBatchItem(
                batch_id=batch_id,
                agent_id=agent_id,
                llm_config=llm_config,
                request_status=request_status,
                step_status=step_status,
                step_state=step_state,
                organization_id=actor.organization_id,
            )
            item.create(session, actor=actor)
            return item.to_pydantic()

    @enforce_types
    def get_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
        """Retrieve a single batch item by ID."""
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
            return item.to_pydantic()

    @enforce_types
    def update_batch_item(
        self,
        item_id: str,
        actor: PydanticUser,
        request_status: Optional[JobStatus] = None,
        step_status: Optional[AgentStepStatus] = None,
        llm_request_response: Optional[BetaMessageBatchIndividualResponse] = None,
        step_state: Optional[AgentStepState] = None,
    ) -> PydanticLLMBatchItem:
        """Update fields on a batch item.

        Only arguments that are truthy are applied; passing ``None`` (the
        default) leaves the corresponding column untouched.
        """
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)

            if request_status:
                item.request_status = request_status
            if step_status:
                item.step_status = step_status
            if llm_request_response:
                item.batch_request_result = llm_request_response
            if step_state:
                item.step_state = step_state

            return item.update(db_session=session, actor=actor).to_pydantic()

    @enforce_types
    def delete_batch_item(self, item_id: str, actor: PydanticUser) -> None:
        """Hard delete a batch item by ID (permanent removal, not a soft delete)."""
        with self.session_maker() as session:
            item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
            item.hard_delete(db_session=session, actor=actor)
@pytest.fixture
def dummy_beta_message_batch() -> BetaMessageBatch:
    """A fully populated Anthropic BetaMessageBatch with fixed timestamps.

    All datetime fields share one frozen UTC instant so equality comparisons
    in tests are deterministic.
    """
    return BetaMessageBatch(
        id="msgbatch_013Zva2CMHLNnXjNJJKqJ2EF",
        archived_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
        cancel_initiated_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
        created_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
        ended_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
        expires_at=datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc),
        processing_status="in_progress",
        request_counts=BetaMessageBatchRequestCounts(
            canceled=10,
            errored=30,
            expired=10,
            processing=100,
            succeeded=50,
        ),
        results_url="https://api.anthropic.com/v1/messages/batches/msgbatch_013Zva2CMHLNnXjNJJKqJ2EF/results",
        type="message_batch",
    )


@pytest.fixture
def dummy_llm_config() -> LLMConfig:
    """Default gpt-4 LLM config used when creating batch items in tests."""
    return LLMConfig.default_config("gpt-4")


@pytest.fixture
def dummy_tool_rules_solver() -> ToolRulesSolver:
    """Tool-rules solver seeded with a single init rule for send_message."""
    return ToolRulesSolver(tool_rules=[InitToolRule(tool_name="send_message")])


@pytest.fixture
def dummy_step_state(dummy_tool_rules_solver: ToolRulesSolver) -> AgentStepState:
    """Agent step state at step 1 wrapping the dummy tool-rules solver."""
    return AgentStepState(step_number=1, tool_rules_solver=dummy_tool_rules_solver)


@pytest.fixture
def dummy_successful_response() -> BetaMessageBatchIndividualResponse:
    """A succeeded per-request batch result wrapping a minimal BetaMessage."""
    return BetaMessageBatchIndividualResponse(
        custom_id="my-second-request",
        result=BetaMessageBatchSucceededResult(
            type="succeeded",
            message=BetaMessage(
                id="msg_abc123",
                role="assistant",
                type="message",
                model="claude-3-5-sonnet-20240620",
                content=[{"type": "text", "text": "hi!"}],
                usage={"input_tokens": 5, "output_tokens": 7},
                stop_reason="end_turn",
            ),
        ),
    )
# ======================================================================================================================
# LLMBatchManager Tests
# ======================================================================================================================


def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch):
    """A created batch job gets a prefixed ID, persists the provider response, and is fetchable."""
    batch = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )
    assert batch.id.startswith("batch_req-")
    assert batch.create_batch_response == dummy_beta_message_batch
    fetched = server.batch_manager.get_batch_request_by_id(batch.id, actor=default_user)
    assert fetched.id == batch.id


def test_update_batch_status(server, default_user, dummy_beta_message_batch):
    """Updating status also records the polling response and a fresh last_polled_at."""
    batch = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )
    before = datetime.now(timezone.utc)

    server.batch_manager.update_batch_status(
        batch_id=batch.id,
        status=JobStatus.completed,
        latest_polling_response=dummy_beta_message_batch,
        actor=default_user,
    )

    updated = server.batch_manager.get_batch_request_by_id(batch.id, actor=default_user)
    assert updated.status == JobStatus.completed
    assert updated.latest_polling_response == dummy_beta_message_batch
    # last_polled_at is stamped inside update_batch_status, so it must be at or
    # after the timestamp taken just before the call.
    assert updated.last_polled_at >= before


def test_create_and_get_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
    """A created batch item links back to its batch and agent and round-trips step state."""
    batch = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )

    item = server.batch_manager.create_batch_item(
        batch_id=batch.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,
    )

    assert item.batch_id == batch.id
    assert item.agent_id == sarah_agent.id
    assert item.step_state == dummy_step_state

    fetched = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
    assert fetched.id == item.id


def test_update_batch_item(
    server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, dummy_successful_response
):
    """Every field passed to update_batch_item is persisted and readable back."""
    batch = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )

    item = server.batch_manager.create_batch_item(
        batch_id=batch.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,
    )

    updated_step_state = AgentStepState(step_number=2, tool_rules_solver=dummy_step_state.tool_rules_solver)

    server.batch_manager.update_batch_item(
        item_id=item.id,
        request_status=JobStatus.completed,
        step_status=AgentStepStatus.running,
        llm_request_response=dummy_successful_response,
        step_state=updated_step_state,
        actor=default_user,
    )

    updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)
    assert updated.request_status == JobStatus.completed
    assert updated.batch_request_result == dummy_successful_response
    # Previously unasserted: step_status and step_state were updated above but
    # never verified, so regressions in those columns would pass silently.
    assert updated.step_status == AgentStepStatus.running
    assert updated.step_state == updated_step_state


def test_delete_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
    """A hard-deleted batch item is no longer retrievable."""
    batch = server.batch_manager.create_batch_request(
        llm_provider="anthropic",
        status=JobStatus.created,
        create_batch_response=dummy_beta_message_batch,
        actor=default_user,
    )

    item = server.batch_manager.create_batch_item(
        batch_id=batch.id,
        agent_id=sarah_agent.id,
        llm_config=dummy_llm_config,
        step_state=dummy_step_state,
        actor=default_user,
    )

    server.batch_manager.delete_batch_item(item_id=item.id, actor=default_user)

    with pytest.raises(NoResultFound):
        server.batch_manager.get_batch_item_by_id(item.id, actor=default_user)