chore: release 0.6.50 (#2547)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import { LettaClient } from '@letta-ai/letta-client';
|
||||
import {
|
||||
import type { LettaClient } from '@letta-ai/letta-client';
|
||||
import type {
|
||||
AssistantMessage,
|
||||
ReasoningMessage,
|
||||
ToolCallMessage,
|
||||
@@ -52,17 +52,17 @@ console.log(
|
||||
(response.messages[1] as AssistantMessage).content,
|
||||
);
|
||||
|
||||
const custom_tool_source_code = `
|
||||
const CUSTOM_TOOL_SOURCE_CODE = `
|
||||
def secret_message():
|
||||
"""Return a secret message."""
|
||||
return "Hello world!"
|
||||
`.trim();
|
||||
|
||||
const tool = await client.tools.upsert({
|
||||
sourceCode: custom_tool_source_code,
|
||||
sourceCode: CUSTOM_TOOL_SOURCE_CODE,
|
||||
});
|
||||
|
||||
await client.agents.tools.attach(agent.id, tool.id!);
|
||||
await client.agents.tools.attach(agent.id, tool.id);
|
||||
|
||||
console.log(`Created tool ${tool.name} and attached to agent ${agent.name}`);
|
||||
|
||||
@@ -103,7 +103,7 @@ let agentCopy = await client.agents.create({
|
||||
embedding: 'openai/text-embedding-ada-002',
|
||||
});
|
||||
let block = await client.agents.blocks.retrieve(agent.id, 'human');
|
||||
agentCopy = await client.agents.blocks.attach(agentCopy.id, block.id!);
|
||||
agentCopy = await client.agents.blocks.attach(agentCopy.id, block.id);
|
||||
|
||||
console.log('Created agent copy with shared memory named', agentCopy.name);
|
||||
|
||||
|
||||
4
examples/docs/node/project.json
Normal file
4
examples/docs/node/project.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"name": "node-example",
|
||||
"$schema": "../../node_modules/nx/schemas/project-schema.json"
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.6.49"
|
||||
__version__ = "0.6.50"
|
||||
|
||||
# import clients
|
||||
from letta.client.client import LocalClient, RESTClient, create_client
|
||||
|
||||
@@ -376,7 +376,6 @@ class Agent(BaseAgent):
|
||||
else:
|
||||
raise ValueError(f"Bad finish reason from API: {response.choices[0].finish_reason}")
|
||||
log_telemetry(self.logger, "_handle_ai_response finish")
|
||||
return response
|
||||
|
||||
except ValueError as ve:
|
||||
if attempt >= empty_response_retry_limit:
|
||||
@@ -393,6 +392,14 @@ class Agent(BaseAgent):
|
||||
log_telemetry(self.logger, "_handle_ai_response finish generic Exception")
|
||||
raise e
|
||||
|
||||
# check if we are going over the context window: this allows for artificial constraints
|
||||
if response.usage.total_tokens > self.agent_state.llm_config.context_window:
|
||||
# trigger summarization
|
||||
log_telemetry(self.logger, "_get_ai_reply summarize_messages_inplace")
|
||||
self.summarize_messages_inplace()
|
||||
# return the response
|
||||
return response
|
||||
|
||||
log_telemetry(self.logger, "_handle_ai_response finish catch-all exception")
|
||||
raise Exception("Retries exhausted and no valid response received.")
|
||||
|
||||
|
||||
@@ -225,7 +225,10 @@ def core_memory_insert(agent_state: "AgentState", target_block_label: str, new_m
|
||||
current_value_list = current_value.split("\n")
|
||||
if line_number is None:
|
||||
line_number = len(current_value_list)
|
||||
current_value_list.insert(line_number, new_memory)
|
||||
if replace:
|
||||
current_value_list[line_number] = new_memory
|
||||
else:
|
||||
current_value_list.insert(line_number, new_memory)
|
||||
new_value = "\n".join(current_value_list)
|
||||
agent_state.memory.update_block_value(label=target_block_label, value=new_value)
|
||||
return None
|
||||
|
||||
@@ -629,8 +629,22 @@ def _get_field_type(field_schema: Dict[str, Any], nested_models: Dict[str, Type[
|
||||
if nested_models and ref_type in nested_models:
|
||||
return nested_models[ref_type]
|
||||
elif "additionalProperties" in field_schema:
|
||||
value_type = _get_field_type(field_schema["additionalProperties"], nested_models)
|
||||
return Dict[str, value_type]
|
||||
# TODO: This is totally GPT generated and I'm not sure it works
|
||||
# TODO: This is done to quickly patch some tests, we should nuke this whole pathway asap
|
||||
ap = field_schema["additionalProperties"]
|
||||
|
||||
if ap is True:
|
||||
return dict
|
||||
elif ap is False:
|
||||
raise ValueError("additionalProperties=false is not supported.")
|
||||
else:
|
||||
# Try resolving nested type
|
||||
nested_type = _get_field_type(ap, nested_models)
|
||||
# If nested_type is Any, fall back to `dict`, or raise, depending on how strict you want to be
|
||||
if nested_type == Any:
|
||||
return dict
|
||||
return Dict[str, nested_type]
|
||||
|
||||
return dict
|
||||
elif field_schema.get("$ref") is not None:
|
||||
ref_type = field_schema["$ref"].split("/")[-1]
|
||||
|
||||
0
letta/jobs/__init__.py
Normal file
0
letta/jobs/__init__.py
Normal file
25
letta/jobs/helpers.py
Normal file
25
letta/jobs/helpers.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from anthropic.types.beta.messages import (
|
||||
BetaMessageBatchCanceledResult,
|
||||
BetaMessageBatchIndividualResponse,
|
||||
BetaMessageBatchSucceededResult,
|
||||
)
|
||||
|
||||
from letta.schemas.enums import JobStatus
|
||||
|
||||
|
||||
def map_anthropic_batch_job_status_to_job_status(anthropic_status: str) -> JobStatus:
    """Translate an Anthropic batch-level `processing_status` into our internal JobStatus.

    Unknown provider statuses fall back to `pending` so the poller revisits them.
    """
    if anthropic_status == "in_progress":
        return JobStatus.running
    if anthropic_status == "canceling":
        return JobStatus.cancelled
    if anthropic_status == "ended":
        return JobStatus.completed
    # fallback just in case
    return JobStatus.pending
|
||||
|
||||
|
||||
def map_anthropic_individual_batch_item_status_to_job_status(individual_item: BetaMessageBatchIndividualResponse) -> JobStatus:
    """Translate one batch item's result type into our internal JobStatus.

    Any result that is neither a success nor a cancellation (e.g. an errored
    item) is reported as `failed`.
    """
    result = individual_item.result
    if isinstance(result, BetaMessageBatchSucceededResult):
        return JobStatus.completed
    if isinstance(result, BetaMessageBatchCanceledResult):
        return JobStatus.cancelled
    return JobStatus.failed
||||
204
letta/jobs/llm_batch_job_polling.py
Normal file
204
letta/jobs/llm_batch_job_polling.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import List
|
||||
|
||||
from letta.jobs.helpers import map_anthropic_batch_job_status_to_job_status, map_anthropic_individual_batch_item_status_to_job_status
|
||||
from letta.jobs.types import BatchId, BatchPollingResult, ItemUpdateInfo
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import JobStatus, ProviderType
|
||||
from letta.schemas.llm_batch_job import LLMBatchJob
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BatchPollingMetrics:
    """Accumulates counters and timing for one run of the batch-polling cron job."""

    def __init__(self):
        # Wall-clock start; log_summary() reports elapsed time relative to this.
        self.start_time = datetime.datetime.now()
        # Counters filled in as the polling run progresses.
        self.total_batches = 0
        self.anthropic_batches = 0
        self.running_count = 0
        self.completed_count = 0
        self.updated_items_count = 0

    def log_summary(self):
        """Emit an INFO-level summary of everything counted during this polling run."""
        elapsed = (datetime.datetime.now() - self.start_time).total_seconds()
        for line in (
            f"[Poll BatchJob] Finished poll_running_llm_batches job in {elapsed:.2f}s",
            f"[Poll BatchJob] Found {self.total_batches} running batches total.",
            f"[Poll BatchJob] Found {self.anthropic_batches} Anthropic batch(es) to poll.",
            f"[Poll BatchJob] Final results: {self.completed_count} completed, {self.running_count} still running.",
            f"[Poll BatchJob] Updated {self.updated_items_count} items for newly completed batch(es).",
        ):
            logger.info(line)
|
||||
|
||||
|
||||
async def fetch_batch_status(server: SyncServer, batch_job: LLMBatchJob) -> BatchPollingResult:
    """Query the provider for a single batch job's current processing status.

    Args:
        server: The SyncServer instance
        batch_job: The batch job to check status for

    Returns:
        A tuple containing (batch_id, new_status, polling_response); the
        response element is None when retrieval failed.
    """
    provider_batch_id = batch_job.create_batch_response.id
    try:
        resp = await server.anthropic_async_client.beta.messages.batches.retrieve(provider_batch_id)
        mapped = map_anthropic_batch_job_status_to_job_status(resp.processing_status)
        logger.debug(f"[Poll BatchJob] Batch {batch_job.id}: provider={resp.processing_status} → internal={mapped}")
        return (batch_job.id, mapped, resp)
    except Exception as e:
        logger.warning(f"[Poll BatchJob] Batch {batch_job.id}: failed to retrieve {provider_batch_id}: {e}")
        # Treat retrieval errors as still-running so the next polling cycle retries.
        return (batch_job.id, JobStatus.running, None)
|
||||
|
||||
|
||||
async def fetch_batch_items(server: SyncServer, batch_id: BatchId, batch_resp_id: str) -> List[ItemUpdateInfo]:
    """Stream item-level results for a completed batch from the provider.

    Args:
        server: The SyncServer instance
        batch_id: The internal batch ID
        batch_resp_id: The provider's batch response ID

    Returns:
        A list of item update information tuples; empty when fetching fails.
    """
    collected: List[ItemUpdateInfo] = []
    try:
        stream = server.anthropic_async_client.beta.messages.batches.results(batch_resp_id)
        async for result in stream:
            # The custom_id field is expected to carry the agent_id.
            status = map_anthropic_individual_batch_item_status_to_job_status(result)
            collected.append((batch_id, result.custom_id, status, result))
        logger.info(f"[Poll BatchJob] Fetched {len(collected)} item updates for batch {batch_id}.")
    except Exception as e:
        logger.error(f"[Poll BatchJob] Error fetching item updates for batch {batch_id}: {e}")

    return collected
|
||||
|
||||
|
||||
async def poll_batch_updates(server: SyncServer, batch_jobs: List[LLMBatchJob], metrics: BatchPollingMetrics) -> List[BatchPollingResult]:
    """Concurrently refresh the provider status of each batch job and persist it.

    Args:
        server: The SyncServer instance
        batch_jobs: List of batch jobs to poll
        metrics: Metrics collection object

    Returns:
        List of batch polling results
    """
    if not batch_jobs:
        logger.info("[Poll BatchJob] No Anthropic batches to update; job complete.")
        return []

    # One status fetch per batch job, all in flight at once.
    results: List[BatchPollingResult] = await asyncio.gather(*(fetch_batch_status(server, job) for job in batch_jobs))

    # Persist the refreshed job-level statuses in a single bulk write.
    server.batch_manager.bulk_update_batch_statuses(updates=results)
    logger.info(f"[Poll BatchJob] Bulk-updated {len(results)} LLM batch(es) in the DB at job level.")

    return results
|
||||
|
||||
|
||||
async def process_completed_batches(
    server: SyncServer, batch_results: List[BatchPollingResult], metrics: BatchPollingMetrics
) -> List[ItemUpdateInfo]:
    """Tally polled batch results and gather item-level updates for completed ones.

    Args:
        server: The SyncServer instance
        batch_results: Results from polling batch statuses
        metrics: Metrics collection object

    Returns:
        List of item updates to apply
    """
    fetch_tasks = []

    # Walk each polling result: count it, and queue an item fetch when completed.
    for batch_id, status, batch_resp in batch_results:
        if not batch_resp:
            if status == JobStatus.running:
                metrics.running_count += 1
            logger.warning(f"[Poll BatchJob] Batch {batch_id}: JobStatus was {status} and no batch response was found.")
            continue

        if status == JobStatus.completed:
            metrics.completed_count += 1
            # batch_resp.id is the Anthropic-assigned batch ID.
            fetch_tasks.append(fetch_batch_items(server, batch_id, batch_resp.id))
        elif status == JobStatus.running:
            metrics.running_count += 1

    # Run every item fetch concurrently; capture per-task exceptions instead of raising.
    gathered = await asyncio.gather(*fetch_tasks, return_exceptions=True)

    # Flatten successful results, logging any task that blew up.
    collected: List[ItemUpdateInfo] = []
    for outcome in gathered:
        if isinstance(outcome, Exception):
            logger.error(f"[Poll BatchJob] A fetch_batch_items task failed with: {outcome}")
        elif isinstance(outcome, list):
            collected.extend(outcome)

    logger.info(f"[Poll BatchJob] Collected a total of {len(collected)} item update(s) from completed batches.")

    return collected
|
||||
|
||||
|
||||
async def poll_running_llm_batches(server: "SyncServer") -> None:
    """
    Cron job to poll all running LLM batch jobs and update their polling responses in bulk.

    Steps:
        1. Fetch currently running batch jobs
        2. Filter Anthropic only
        3. Retrieve updated top-level polling info concurrently
        4. Bulk update LLMBatchJob statuses
        5. For each completed batch, call .results(...) to get item-level results
        6. Bulk update all matching LLMBatchItem records by (batch_id, agent_id)
        7. Log telemetry about success/fail
    """
    metrics = BatchPollingMetrics()

    logger.info("[Poll BatchJob] Starting poll_running_llm_batches job")

    try:
        # 1. Retrieve running batch jobs
        running = server.batch_manager.list_running_batches()
        metrics.total_batches = len(running)

        # 2. Restrict to Anthropic-backed jobs (TODO: expand to more providers)
        anthropic_jobs = [job for job in running if job.llm_provider == ProviderType.anthropic]
        metrics.anthropic_batches = len(anthropic_jobs)

        # 3-4. Poll the provider and bulk-write the job-level statuses
        polled = await poll_batch_updates(server, anthropic_jobs, metrics)

        # 5. Fetch item-level results for any batch that just completed
        updates = await process_completed_batches(server, polled, metrics)

        # 6. Bulk-write the item-level updates keyed by (batch_id, agent_id)
        if updates:
            metrics.updated_items_count = len(updates)
            server.batch_manager.bulk_update_batch_items_by_agent(updates)
        else:
            logger.info("[Poll BatchJob] No item-level updates needed.")

    except Exception as e:
        logger.exception("[Poll BatchJob] Unhandled error in poll_running_llm_batches", exc_info=e)
    finally:
        # 7. Always emit the metrics summary, even when the run failed.
        metrics.log_summary()
|
||||
28
letta/jobs/scheduler.py
Normal file
28
letta/jobs/scheduler.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import datetime
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
|
||||
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
|
||||
def start_cron_jobs(server: SyncServer):
    """Register and start the recurring background jobs.

    Currently schedules a single job that polls running LLM batch jobs every
    `settings.poll_running_llm_batches_interval_seconds` seconds, firing once
    immediately (`next_run_time=now`).

    Args:
        server: The SyncServer instance passed through to each job invocation.
    """
    scheduler.add_job(
        poll_running_llm_batches,
        args=[server],
        trigger=IntervalTrigger(seconds=settings.poll_running_llm_batches_interval_seconds),
        # datetime.timezone.utc instead of datetime.UTC: the latter only exists on
        # Python >= 3.11, and they are the same object where both are available.
        next_run_time=datetime.datetime.now(datetime.timezone.utc),
        id="poll_llm_batches",
        name="Poll LLM API batch jobs and update status",
        replace_existing=True,
    )
    scheduler.start()
|
||||
|
||||
|
||||
def shutdown_cron_scheduler():
    # Stop the module-level AsyncIOScheduler; invoked from the app's shutdown hook.
    scheduler.shutdown()
|
||||
10
letta/jobs/types.py
Normal file
10
letta/jobs/types.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
||||
|
||||
from letta.schemas.enums import JobStatus
|
||||
|
||||
# Internal (Letta) batch job identifier.
BatchId = str
# Agent identifier; presumably carried as the per-item custom_id in provider
# batch requests — see the polling code that reads item.custom_id.
AgentId = str
# Result of polling one batch: (batch_id, mapped status, raw provider batch
# response — or None when retrieval failed).
BatchPollingResult = Tuple[BatchId, JobStatus, Optional[BetaMessageBatch]]
# One item-level update: (batch_id, agent_id, mapped status, raw item response).
ItemUpdateInfo = Tuple[BatchId, AgentId, JobStatus, BetaMessageBatchIndividualResponse]
|
||||
@@ -25,6 +25,7 @@ from letta.llm_api.aws_bedrock import get_bedrock_client
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.message import Message as _Message
|
||||
from letta.schemas.message import MessageRole as _MessageRole
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
||||
@@ -44,6 +45,8 @@ from letta.settings import model_settings
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
||||
from letta.tracing import log_event
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
BASE_URL = "https://api.anthropic.com/v1"
|
||||
|
||||
|
||||
@@ -620,9 +623,9 @@ def _prepare_anthropic_request(
|
||||
data: ChatCompletionRequest,
|
||||
inner_thoughts_xml_tag: Optional[str] = "thinking",
|
||||
# if true, prefix fill the generation with the thinking tag
|
||||
prefix_fill: bool = True,
|
||||
prefix_fill: bool = False,
|
||||
# if true, put COT inside the tool calls instead of inside the content
|
||||
put_inner_thoughts_in_kwargs: bool = False,
|
||||
put_inner_thoughts_in_kwargs: bool = True,
|
||||
bedrock: bool = False,
|
||||
# extended thinking related fields
|
||||
# https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking
|
||||
@@ -634,7 +637,9 @@ def _prepare_anthropic_request(
|
||||
assert (
|
||||
max_reasoning_tokens is not None and max_reasoning_tokens < data.max_tokens
|
||||
), "max tokens must be greater than thinking budget"
|
||||
assert not put_inner_thoughts_in_kwargs, "extended thinking not compatible with put_inner_thoughts_in_kwargs"
|
||||
if put_inner_thoughts_in_kwargs:
|
||||
logger.warning("Extended thinking not compatible with put_inner_thoughts_in_kwargs")
|
||||
put_inner_thoughts_in_kwargs = False
|
||||
# assert not prefix_fill, "extended thinking not compatible with prefix_fill"
|
||||
# Silently disable prefix_fill for now
|
||||
prefix_fill = False
|
||||
|
||||
@@ -90,7 +90,7 @@ class AnthropicClient(LLMClientBase):
|
||||
def build_request_data(
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
tools: List[dict],
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
# TODO: This needs to get cleaned up. The logic here is pretty confusing.
|
||||
@@ -146,11 +146,12 @@ class AnthropicClient(LLMClientBase):
|
||||
tools_for_request = [Tool(function=f) for f in tools] if tools is not None else None
|
||||
|
||||
# Add tool choice
|
||||
data["tool_choice"] = tool_choice
|
||||
if tool_choice:
|
||||
data["tool_choice"] = tool_choice
|
||||
|
||||
# Add inner thoughts kwarg
|
||||
# TODO: Can probably make this more efficient
|
||||
if len(tools_for_request) > 0 and self.llm_config.put_inner_thoughts_in_kwargs:
|
||||
if tools_for_request and len(tools_for_request) > 0 and self.llm_config.put_inner_thoughts_in_kwargs:
|
||||
tools_with_inner_thoughts = add_inner_thoughts_to_functions(
|
||||
functions=[t.function.model_dump() for t in tools_for_request],
|
||||
inner_thoughts_key=INNER_THOUGHTS_KWARG,
|
||||
@@ -158,7 +159,7 @@ class AnthropicClient(LLMClientBase):
|
||||
)
|
||||
tools_for_request = [Tool(function=f) for f in tools_with_inner_thoughts]
|
||||
|
||||
if len(tools_for_request) > 0:
|
||||
if tools_for_request and len(tools_for_request) > 0:
|
||||
# TODO eventually enable parallel tool use
|
||||
data["tools"] = convert_tools_to_anthropic_format(tools_for_request)
|
||||
|
||||
|
||||
@@ -322,6 +322,7 @@ def create(
|
||||
|
||||
# Force tool calling
|
||||
tool_call = None
|
||||
llm_config.put_inner_thoughts_in_kwargs = True
|
||||
if functions is None:
|
||||
# Special case for summarization path
|
||||
tools = None
|
||||
@@ -356,6 +357,7 @@ def create(
|
||||
if stream: # Client requested token streaming
|
||||
assert isinstance(stream_interface, (AgentChunkStreamingInterface, AgentRefreshStreamingInterface)), type(stream_interface)
|
||||
|
||||
stream_interface.inner_thoughts_in_kwargs = True
|
||||
response = anthropic_chat_completions_process_stream(
|
||||
chat_completion_request=chat_completion_request,
|
||||
put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
|
||||
|
||||
@@ -78,9 +78,11 @@ class OpenAIClient(LLMClientBase):
|
||||
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
# TODO(matt) move into LLMConfig
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
tool_choice = None
|
||||
if self.llm_config.model_endpoint == "https://inference.memgpt.ai" or (self.llm_config.handle and "vllm" in self.llm_config.handle):
|
||||
tool_choice = "auto" # TODO change to "required" once proxy supports it
|
||||
else:
|
||||
elif tools:
|
||||
# only set if tools is non-Null
|
||||
tool_choice = "required"
|
||||
|
||||
if force_tool_call is not None:
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Callable, Dict, List
|
||||
|
||||
from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
@@ -9,6 +10,7 @@ from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.tracing import trace_method
|
||||
from letta.utils import count_tokens, printd
|
||||
|
||||
|
||||
@@ -45,6 +47,7 @@ def _format_summary_history(message_history: List[Message]):
|
||||
return "\n".join([f"{m.role}: {get_message_text(m.content)}" for m in message_history])
|
||||
|
||||
|
||||
@trace_method
|
||||
def summarize_messages(
|
||||
agent_state: AgentState,
|
||||
message_sequence_to_summarize: List[Message],
|
||||
@@ -74,12 +77,25 @@ def summarize_messages(
|
||||
# TODO: We need to eventually have a separate LLM config for the summarizer LLM
|
||||
llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True)
|
||||
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
|
||||
response = create(
|
||||
|
||||
llm_client = LLMClient.create(
|
||||
llm_config=llm_config_no_inner_thoughts,
|
||||
user_id=agent_state.created_by_id,
|
||||
messages=message_sequence,
|
||||
stream=False,
|
||||
put_inner_thoughts_first=False,
|
||||
)
|
||||
# try to use new client, otherwise fallback to old flow
|
||||
# TODO: we can just directly call the LLM here?
|
||||
if llm_client:
|
||||
response = llm_client.send_llm_request(
|
||||
messages=message_sequence,
|
||||
stream=False,
|
||||
)
|
||||
else:
|
||||
response = create(
|
||||
llm_config=llm_config_no_inner_thoughts,
|
||||
user_id=agent_state.created_by_id,
|
||||
messages=message_sequence,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
printd(f"summarize_messages gpt reply: {response.choices[0]}")
|
||||
reply = response.choices[0].message.content
|
||||
|
||||
@@ -11,6 +11,7 @@ from letta.schemas.letta_message_content import MessageContent
|
||||
from letta.schemas.letta_message_content import TextContent as PydanticTextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import ToolReturn
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
@@ -42,9 +43,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in")
|
||||
|
||||
# Monotonically increasing sequence for efficient/correct listing
|
||||
sequence_id: Mapped[int] = mapped_column(
|
||||
BigInteger, Sequence("message_seq_id"), unique=True, nullable=False, doc="Global monotonically increasing ID"
|
||||
)
|
||||
sequence_id = mapped_column(BigInteger, Sequence("message_seq_id"), unique=True, nullable=False)
|
||||
|
||||
# Relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
||||
@@ -68,15 +67,19 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
model.content = [PydanticTextContent(text=self.text)]
|
||||
return model
|
||||
|
||||
# listener
|
||||
|
||||
|
||||
@event.listens_for(Message, "before_insert")
|
||||
def set_sequence_id_for_sqlite(mapper, connection, target):
|
||||
session = Session.object_session(target)
|
||||
# TODO: Kind of hacky, used to detect if we are using sqlite or not
|
||||
if not settings.pg_uri:
|
||||
session = Session.object_session(target)
|
||||
|
||||
if not hasattr(session, "_sequence_id_counter"):
|
||||
# Initialize counter for this flush
|
||||
max_seq = connection.scalar(text("SELECT MAX(sequence_id) FROM messages"))
|
||||
session._sequence_id_counter = max_seq or 0
|
||||
if not hasattr(session, "_sequence_id_counter"):
|
||||
# Initialize counter for this flush
|
||||
max_seq = connection.scalar(text("SELECT MAX(sequence_id) FROM messages"))
|
||||
session._sequence_id_counter = max_seq or 0
|
||||
|
||||
session._sequence_id_counter += 1
|
||||
target.sequence_id = session._sequence_id_counter
|
||||
session._sequence_id_counter += 1
|
||||
target.sequence_id = session._sequence_id_counter
|
||||
|
||||
@@ -32,6 +32,7 @@ class JobStatus(str, Enum):
|
||||
completed = "completed"
|
||||
failed = "failed"
|
||||
pending = "pending"
|
||||
cancelled = "cancelled"
|
||||
|
||||
|
||||
class AgentStepStatus(str, Enum):
|
||||
|
||||
@@ -2,6 +2,10 @@ from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
"""
|
||||
@@ -88,14 +92,14 @@ class LLMConfig(BaseModel):
|
||||
return values
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_reasoning_constraints(self) -> "LLMConfig":
|
||||
def issue_warning_for_reasoning_constraints(self) -> "LLMConfig":
|
||||
if self.enable_reasoner:
|
||||
if self.max_reasoning_tokens is None:
|
||||
raise ValueError("max_reasoning_tokens must be set when enable_reasoner is True")
|
||||
logger.warning("max_reasoning_tokens must be set when enable_reasoner is True")
|
||||
if self.max_tokens is not None and self.max_reasoning_tokens >= self.max_tokens:
|
||||
raise ValueError("max_tokens must be greater than max_reasoning_tokens (thinking budget)")
|
||||
logger.warning("max_tokens must be greater than max_reasoning_tokens (thinking budget)")
|
||||
if self.put_inner_thoughts_in_kwargs:
|
||||
raise ValueError("Extended thinking is not compatible with put_inner_thoughts_in_kwargs")
|
||||
logger.warning("Extended thinking is not compatible with put_inner_thoughts_in_kwargs")
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -37,6 +37,7 @@ from letta.schemas.letta_message_content import (
|
||||
get_letta_message_content_union_str_json_schema,
|
||||
)
|
||||
from letta.system import unpack_message
|
||||
from letta.utils import parse_json
|
||||
|
||||
|
||||
def add_inner_thoughts_to_tool_call(
|
||||
@@ -47,7 +48,7 @@ def add_inner_thoughts_to_tool_call(
|
||||
"""Add inner thoughts (arg + value) to a tool call"""
|
||||
try:
|
||||
# load the args list
|
||||
func_args = json.loads(tool_call.function.arguments)
|
||||
func_args = parse_json(tool_call.function.arguments)
|
||||
# create new ordered dict with inner thoughts first
|
||||
ordered_args = OrderedDict({inner_thoughts_key: inner_thoughts})
|
||||
# update with remaining args
|
||||
@@ -293,7 +294,7 @@ class Message(BaseMessage):
|
||||
if use_assistant_message and tool_call.function.name == assistant_message_tool_name:
|
||||
# We need to unpack the actual message contents from the function call
|
||||
try:
|
||||
func_args = json.loads(tool_call.function.arguments)
|
||||
func_args = parse_json(tool_call.function.arguments)
|
||||
message_string = func_args[assistant_message_tool_kwarg]
|
||||
except KeyError:
|
||||
raise ValueError(f"Function call {tool_call.function.name} missing {assistant_message_tool_kwarg} argument")
|
||||
@@ -336,7 +337,7 @@ class Message(BaseMessage):
|
||||
raise ValueError(f"Invalid tool return (no text object on message): {self.content}")
|
||||
|
||||
try:
|
||||
function_return = json.loads(text_content)
|
||||
function_return = parse_json(text_content)
|
||||
status = function_return["status"]
|
||||
if status == "OK":
|
||||
status_enum = "success"
|
||||
@@ -760,7 +761,7 @@ class Message(BaseMessage):
|
||||
inner_thoughts_key=INNER_THOUGHTS_KWARG,
|
||||
).model_dump()
|
||||
else:
|
||||
tool_call_input = json.loads(tool_call.function.arguments)
|
||||
tool_call_input = parse_json(tool_call.function.arguments)
|
||||
|
||||
content.append(
|
||||
{
|
||||
@@ -846,7 +847,7 @@ class Message(BaseMessage):
|
||||
function_args = tool_call.function.arguments
|
||||
try:
|
||||
# NOTE: Google AI wants actual JSON objects, not strings
|
||||
function_args = json.loads(function_args)
|
||||
function_args = parse_json(function_args)
|
||||
except:
|
||||
raise UserWarning(f"Failed to parse JSON function args: {function_args}")
|
||||
function_args = {"args": function_args}
|
||||
@@ -881,7 +882,7 @@ class Message(BaseMessage):
|
||||
|
||||
# NOTE: Google AI API wants the function response as JSON only, no string
|
||||
try:
|
||||
function_response = json.loads(text_content)
|
||||
function_response = parse_json(text_content)
|
||||
except:
|
||||
function_response = {"function_response": text_content}
|
||||
|
||||
@@ -970,7 +971,7 @@ class Message(BaseMessage):
|
||||
]
|
||||
for tc in self.tool_calls:
|
||||
function_name = tc.function["name"]
|
||||
function_args = json.loads(tc.function["arguments"])
|
||||
function_args = parse_json(tc.function["arguments"])
|
||||
function_args_str = ",".join([f"{k}={v}" for k, v in function_args.items()])
|
||||
function_call_text = f"{function_name}({function_args_str})"
|
||||
cohere_message.append(
|
||||
|
||||
@@ -16,6 +16,7 @@ from starlette.middleware.cors import CORSMiddleware
|
||||
from letta.__init__ import __version__
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
|
||||
from letta.jobs.scheduler import shutdown_cron_scheduler, start_cron_jobs
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||
from letta.schemas.letta_message import create_letta_message_union_schema
|
||||
@@ -144,6 +145,12 @@ def create_application() -> "FastAPI":
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=settings.event_loop_threadpool_max_workers)
|
||||
loop.set_default_executor(executor)
|
||||
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
global server
|
||||
|
||||
start_cron_jobs(server)
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_mcp_clients():
|
||||
global server
|
||||
@@ -159,6 +166,10 @@ def create_application() -> "FastAPI":
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_scheduler():
|
||||
shutdown_cron_scheduler()
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_error_handler(request: Request, exc: Exception):
|
||||
# Log the actual error for debugging
|
||||
|
||||
@@ -160,6 +160,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
message_id: str,
|
||||
message_date: datetime,
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
@@ -14,7 +14,7 @@ from letta.agents.letta_agent import LettaAgent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion
|
||||
@@ -31,6 +31,7 @@ from letta.schemas.user import User
|
||||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
|
||||
|
||||
@@ -593,7 +594,13 @@ async def send_message(
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
|
||||
|
||||
if agent.llm_config.model_endpoint_type == "anthropic" and not agent.enable_sleeptime and not agent.multi_agent_group:
|
||||
if (
|
||||
agent.llm_config.model_endpoint_type == "anthropic"
|
||||
and not agent.enable_sleeptime
|
||||
and not agent.multi_agent_group
|
||||
and not agent.agent_type == AgentType.sleeptime_agent
|
||||
and settings.use_experimental
|
||||
):
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
@@ -649,7 +656,13 @@ async def send_message_streaming(
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
|
||||
|
||||
if agent.llm_config.model_endpoint_type == "anthropic" and not agent.enable_sleeptime and not agent.multi_agent_group:
|
||||
if (
|
||||
agent.llm_config.model_endpoint_type == "anthropic"
|
||||
and not agent.enable_sleeptime
|
||||
and not agent.multi_agent_group
|
||||
and not agent.agent_type == AgentType.sleeptime_agent
|
||||
and settings.use_experimental
|
||||
):
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
|
||||
@@ -8,6 +8,7 @@ from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
from composio.client import Composio
|
||||
from composio.client.collections import ActionModel, AppModel
|
||||
from fastapi import HTTPException
|
||||
@@ -352,6 +353,9 @@ class SyncServer(Server):
|
||||
self._llm_config_cache = {}
|
||||
self._embedding_config_cache = {}
|
||||
|
||||
# TODO: Replace this with the Anthropic client we have in house
|
||||
self.anthropic_async_client = AsyncAnthropic()
|
||||
|
||||
def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
|
||||
"""Updated method to load agents from persisted storage"""
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
@@ -775,7 +779,7 @@ class SyncServer(Server):
|
||||
|
||||
def create_sleeptime_agent(self, main_agent: AgentState, actor: User) -> AgentState:
|
||||
request = CreateAgent(
|
||||
name=main_agent.name,
|
||||
name=main_agent.name + "-sleeptime",
|
||||
agent_type=AgentType.sleeptime_agent,
|
||||
block_ids=[block.id for block in main_agent.memory.blocks],
|
||||
memory_blocks=[
|
||||
|
||||
@@ -38,7 +38,7 @@ from letta.schemas.group import ManagerType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
@@ -115,9 +115,10 @@ class AgentManager:
|
||||
block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump(to_orm=True)), actor=actor)
|
||||
block_ids.append(block.id)
|
||||
|
||||
# TODO: Remove this block once we deprecate the legacy `tools` field
|
||||
# create passed in `tools`
|
||||
tool_names = []
|
||||
# add passed in `tools`
|
||||
tool_names = agent_create.tools or []
|
||||
|
||||
# add base tools
|
||||
if agent_create.include_base_tools:
|
||||
if agent_create.agent_type == AgentType.sleeptime_agent:
|
||||
tool_names.extend(BASE_SLEEPTIME_TOOLS)
|
||||
@@ -128,42 +129,45 @@ class AgentManager:
|
||||
tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
|
||||
if agent_create.include_multi_agent_tools:
|
||||
tool_names.extend(MULTI_AGENT_TOOLS)
|
||||
if agent_create.tools:
|
||||
tool_names.extend(agent_create.tools)
|
||||
# Remove duplicates
|
||||
|
||||
# remove duplicates
|
||||
tool_names = list(set(tool_names))
|
||||
|
||||
# add default tool rules
|
||||
if agent_create.include_base_tool_rules:
|
||||
if not agent_create.tool_rules:
|
||||
tool_rules = []
|
||||
else:
|
||||
tool_rules = agent_create.tool_rules
|
||||
# convert tool names to ids
|
||||
tool_ids = []
|
||||
for tool_name in tool_names:
|
||||
tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
tool_ids.append(tool.id)
|
||||
|
||||
# add passed in `tool_ids`
|
||||
for tool_id in agent_create.tool_ids or []:
|
||||
if tool_id not in tool_ids:
|
||||
tool = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
if tool:
|
||||
tool_ids.append(tool.id)
|
||||
tool_names.append(tool.name)
|
||||
else:
|
||||
raise ValueError(f"Tool {tool_id} not found")
|
||||
|
||||
# add default tool rules
|
||||
tool_rules = agent_create.tool_rules or []
|
||||
if agent_create.include_base_tool_rules:
|
||||
# apply default tool rules
|
||||
for tool_name in tool_names:
|
||||
if tool_name == "send_message" or tool_name == "send_message_to_agent_async" or tool_name == "finish_rethinking_memory":
|
||||
tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name))
|
||||
elif tool_name in BASE_TOOLS:
|
||||
elif tool_name in BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS:
|
||||
tool_rules.append(PydanticContinueToolRule(tool_name=tool_name))
|
||||
|
||||
if agent_create.agent_type == AgentType.sleeptime_agent:
|
||||
tool_rules.append(PydanticChildToolRule(tool_name="view_core_memory_with_line_numbers", children=["core_memory_insert"]))
|
||||
|
||||
else:
|
||||
tool_rules = agent_create.tool_rules
|
||||
# Check tool rules are valid
|
||||
# if custom rules, check tool rules are valid
|
||||
if agent_create.tool_rules:
|
||||
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)
|
||||
|
||||
tool_ids = agent_create.tool_ids or []
|
||||
for tool_name in tool_names:
|
||||
tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
|
||||
if tool:
|
||||
tool_ids.append(tool.id)
|
||||
# Remove duplicates
|
||||
tool_ids = list(set(tool_ids))
|
||||
|
||||
# Create the agent
|
||||
agent_state = self._create_agent(
|
||||
name=agent_create.name,
|
||||
@@ -714,10 +718,12 @@ class AgentManager:
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict={"role": "system", "content": new_system_message_str},
|
||||
)
|
||||
# TODO: This seems kind of silly, why not just update the message?
|
||||
message = self.message_manager.create_message(message, actor=actor)
|
||||
message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system)
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
message = self.message_manager.update_message_by_id(
|
||||
message_id=curr_system_message.id,
|
||||
message_update=MessageUpdate(**message.model_dump()),
|
||||
actor=actor,
|
||||
)
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor)
|
||||
else:
|
||||
return agent_state
|
||||
|
||||
|
||||
@@ -238,7 +238,9 @@ def initialize_message_sequence(
|
||||
first_user_message = get_login_event() # event letting Letta know the user just logged in
|
||||
|
||||
if include_initial_boot_message:
|
||||
if agent_state.llm_config.model is not None and "gpt-3.5" in agent_state.llm_config.model:
|
||||
if agent_state.agent_type == AgentType.sleeptime_agent:
|
||||
initial_boot_messages = []
|
||||
elif agent_state.llm_config.model is not None and "gpt-3.5" in agent_state.llm_config.model:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35")
|
||||
else:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message")
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import datetime
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
|
||||
from sqlalchemy import tuple_
|
||||
|
||||
from letta.jobs.types import BatchPollingResult, ItemUpdateInfo
|
||||
from letta.log import get_logger
|
||||
from letta.orm.llm_batch_items import LLMBatchItem
|
||||
from letta.orm.llm_batch_job import LLMBatchJob
|
||||
from letta.schemas.agent import AgentStepState
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
|
||||
from letta.schemas.llm_batch_job import LLMBatchItem as PydanticLLMBatchItem
|
||||
from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@@ -28,7 +30,7 @@ class LLMBatchManager:
|
||||
@enforce_types
|
||||
def create_batch_request(
|
||||
self,
|
||||
llm_provider: str,
|
||||
llm_provider: ProviderType,
|
||||
create_batch_response: BetaMessageBatch,
|
||||
actor: PydanticUser,
|
||||
status: JobStatus = JobStatus.created,
|
||||
@@ -45,7 +47,7 @@ class LLMBatchManager:
|
||||
return batch.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_batch_request_by_id(self, batch_id: str, actor: PydanticUser) -> PydanticLLMBatchJob:
|
||||
def get_batch_job_by_id(self, batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob:
|
||||
"""Retrieve a single batch job by ID."""
|
||||
with self.session_maker() as session:
|
||||
batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
|
||||
@@ -56,7 +58,7 @@ class LLMBatchManager:
|
||||
self,
|
||||
batch_id: str,
|
||||
status: JobStatus,
|
||||
actor: PydanticUser,
|
||||
actor: Optional[PydanticUser] = None,
|
||||
latest_polling_response: Optional[BetaMessageBatch] = None,
|
||||
) -> PydanticLLMBatchJob:
|
||||
"""Update a batch job’s status and optionally its polling response."""
|
||||
@@ -65,7 +67,34 @@ class LLMBatchManager:
|
||||
batch.status = status
|
||||
batch.latest_polling_response = latest_polling_response
|
||||
batch.last_polled_at = datetime.datetime.now(datetime.timezone.utc)
|
||||
return batch.update(db_session=session, actor=actor).to_pydantic()
|
||||
batch = batch.update(db_session=session, actor=actor)
|
||||
return batch.to_pydantic()
|
||||
|
||||
def bulk_update_batch_statuses(
|
||||
self,
|
||||
updates: List[BatchPollingResult],
|
||||
) -> None:
|
||||
"""
|
||||
Efficiently update many LLMBatchJob rows. This is used by the cron jobs.
|
||||
|
||||
`updates` = [(batch_id, new_status, polling_response_or_None), …]
|
||||
"""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
with self.session_maker() as session:
|
||||
mappings = []
|
||||
for batch_id, status, response in updates:
|
||||
mappings.append(
|
||||
{
|
||||
"id": batch_id,
|
||||
"status": status,
|
||||
"latest_polling_response": response,
|
||||
"last_polled_at": now,
|
||||
}
|
||||
)
|
||||
|
||||
session.bulk_update_mappings(LLMBatchJob, mappings)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_batch_request(self, batch_id: str, actor: PydanticUser) -> None:
|
||||
@@ -74,6 +103,18 @@ class LLMBatchManager:
|
||||
batch = LLMBatchJob.read(db_session=session, identifier=batch_id, actor=actor)
|
||||
batch.hard_delete(db_session=session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def list_running_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
|
||||
"""Return all running LLM batch jobs, optionally filtered by actor's organization."""
|
||||
with self.session_maker() as session:
|
||||
query = session.query(LLMBatchJob).filter(LLMBatchJob.status == JobStatus.running)
|
||||
|
||||
if actor is not None:
|
||||
query = query.filter(LLMBatchJob.organization_id == actor.organization_id)
|
||||
|
||||
results = query.all()
|
||||
return [batch.to_pydantic() for batch in results]
|
||||
|
||||
@enforce_types
|
||||
def create_batch_item(
|
||||
self,
|
||||
@@ -131,6 +172,56 @@ class LLMBatchManager:
|
||||
|
||||
return item.update(db_session=session, actor=actor).to_pydantic()
|
||||
|
||||
def bulk_update_batch_items_by_agent(
|
||||
self,
|
||||
updates: List[ItemUpdateInfo],
|
||||
) -> None:
|
||||
"""
|
||||
Efficiently update LLMBatchItem rows by (batch_id, agent_id).
|
||||
|
||||
Args:
|
||||
updates: List of tuples:
|
||||
(batch_id, agent_id, new_request_status, batch_request_result)
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# For bulk_update_mappings, we need the primary key of each row
|
||||
# So we must map (batch_id, agent_id) → actual PK (id)
|
||||
# We'll do it in one DB query using the (batch_id, agent_id) sets
|
||||
|
||||
# 1. Gather the pairs
|
||||
key_pairs = [(b_id, a_id) for (b_id, a_id, *_rest) in updates]
|
||||
|
||||
# 2. Query items in a single step
|
||||
items = (
|
||||
session.query(LLMBatchItem.id, LLMBatchItem.batch_id, LLMBatchItem.agent_id)
|
||||
.filter(tuple_(LLMBatchItem.batch_id, LLMBatchItem.agent_id).in_(key_pairs))
|
||||
.all()
|
||||
)
|
||||
|
||||
# Build a map from (batch_id, agent_id) → PK id
|
||||
pair_to_pk = {}
|
||||
for row_id, row_batch_id, row_agent_id in items:
|
||||
pair_to_pk[(row_batch_id, row_agent_id)] = row_id
|
||||
|
||||
# 3. Construct mappings for the PK-based bulk update
|
||||
mappings = []
|
||||
for batch_id, agent_id, new_status, new_result in updates:
|
||||
pk_id = pair_to_pk.get((batch_id, agent_id))
|
||||
if not pk_id:
|
||||
# Nonexistent or mismatch → skip
|
||||
continue
|
||||
mappings.append(
|
||||
{
|
||||
"id": pk_id,
|
||||
"request_status": new_status,
|
||||
"batch_request_result": new_result,
|
||||
}
|
||||
)
|
||||
|
||||
if mappings:
|
||||
session.bulk_update_mappings(LLMBatchItem, mappings)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_batch_item(self, item_id: str, actor: PydanticUser) -> None:
|
||||
"""Hard delete a batch item by ID."""
|
||||
|
||||
@@ -12,6 +12,7 @@ from letta.services.helpers.tool_execution_helper import (
|
||||
install_pip_requirements_for_sandbox,
|
||||
)
|
||||
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
|
||||
from letta.settings import tool_settings
|
||||
from letta.tracing import log_event, trace_method
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
@@ -152,7 +153,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
)
|
||||
|
||||
try:
|
||||
stdout_bytes, stderr_bytes = await asyncio.wait_for(process.communicate(), timeout=60)
|
||||
stdout_bytes, stderr_bytes = await asyncio.wait_for(process.communicate(), timeout=tool_settings.local_sandbox_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
# Terminate the process on timeout
|
||||
if process.returncode is None:
|
||||
|
||||
@@ -17,6 +17,7 @@ class ToolSettings(BaseSettings):
|
||||
|
||||
# Local Sandbox configurations
|
||||
local_sandbox_dir: Optional[str] = None
|
||||
local_sandbox_timeout: float = 180
|
||||
|
||||
# MCP settings
|
||||
mcp_connect_to_server_timeout: float = 30.0
|
||||
@@ -203,6 +204,9 @@ class Settings(BaseSettings):
|
||||
httpx_max_keepalive_connections: int = 500
|
||||
httpx_keepalive_expiry: float = 120.0
|
||||
|
||||
# cron job parameters
|
||||
poll_running_llm_batches_interval_seconds: int = 5 * 60
|
||||
|
||||
@property
|
||||
def letta_pg_uri(self) -> str:
|
||||
if self.pg_uri:
|
||||
|
||||
@@ -54,6 +54,7 @@ class AgentChunkStreamingInterface(ABC):
|
||||
message_id: str,
|
||||
message_date: datetime,
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
):
|
||||
"""Process a streaming chunk from an OpenAI-compatible server"""
|
||||
@@ -105,6 +106,7 @@ class StreamingCLIInterface(AgentChunkStreamingInterface):
|
||||
message_id: str,
|
||||
message_date: datetime,
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
):
|
||||
assert len(chunk.choices) == 1, chunk
|
||||
|
||||
118
poetry.lock
generated
118
poetry.lock
generated
@@ -215,6 +215,33 @@ files = [
|
||||
{file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "apscheduler"
|
||||
version = "3.11.0"
|
||||
description = "In-process task scheduler with Cron-like capabilities"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "APScheduler-3.11.0-py3-none-any.whl", hash = "sha256:fc134ca32e50f5eadcc4938e3a4545ab19131435e851abb40b34d63d5141c6da"},
|
||||
{file = "apscheduler-3.11.0.tar.gz", hash = "sha256:4c622d250b0955a65d5d0eb91c33e6d43fd879834bf541e0a18661ae60460133"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tzlocal = ">=3.0"
|
||||
|
||||
[package.extras]
|
||||
doc = ["packaging", "sphinx", "sphinx-rtd-theme (>=1.3.0)"]
|
||||
etcd = ["etcd3", "protobuf (<=3.21.0)"]
|
||||
gevent = ["gevent"]
|
||||
mongodb = ["pymongo (>=3.0)"]
|
||||
redis = ["redis (>=3.0)"]
|
||||
rethinkdb = ["rethinkdb (>=2.4.0)"]
|
||||
sqlalchemy = ["sqlalchemy (>=1.4)"]
|
||||
test = ["APScheduler[etcd,mongodb,redis,rethinkdb,sqlalchemy,tornado,zookeeper]", "PySide6", "anyio (>=4.5.2)", "gevent", "pytest", "pytz", "twisted"]
|
||||
tornado = ["tornado (>=4.3)"]
|
||||
twisted = ["twisted"]
|
||||
zookeeper = ["kazoo"]
|
||||
|
||||
[[package]]
|
||||
name = "argcomplete"
|
||||
version = "3.6.2"
|
||||
@@ -521,10 +548,6 @@ files = [
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"},
|
||||
{file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"},
|
||||
@@ -537,14 +560,8 @@ files = [
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"},
|
||||
{file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"},
|
||||
@@ -555,24 +572,8 @@ files = [
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"},
|
||||
{file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"},
|
||||
{file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"},
|
||||
@@ -582,10 +583,6 @@ files = [
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"},
|
||||
{file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"},
|
||||
@@ -597,10 +594,6 @@ files = [
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"},
|
||||
{file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"},
|
||||
@@ -613,10 +606,6 @@ files = [
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"},
|
||||
{file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"},
|
||||
@@ -629,10 +618,6 @@ files = [
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"},
|
||||
{file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"},
|
||||
{file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"},
|
||||
@@ -1056,9 +1041,9 @@ isort = ">=4.3.21,<6.0"
|
||||
jinja2 = ">=2.10.1,<4.0"
|
||||
packaging = "*"
|
||||
pydantic = [
|
||||
{version = ">=1.10.0,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.12\" and python_version < \"4.0\""},
|
||||
{version = ">=1.10.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
{version = ">=1.9.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.10.0,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.12\" and python_version < \"4.0\""},
|
||||
]
|
||||
pyyaml = ">=6.0.1"
|
||||
toml = {version = ">=0.10.0,<1.0.0", markers = "python_version < \"3.11\""}
|
||||
@@ -2744,13 +2729,13 @@ types-requests = ">=2.31.0.2,<3.0.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.3.26"
|
||||
version = "0.3.27"
|
||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "langsmith-0.3.26-py3-none-any.whl", hash = "sha256:3ae49e49d6f3c980a524d15ac2fd895896e709ecedc83ac150c38e1ead776e1b"},
|
||||
{file = "langsmith-0.3.26.tar.gz", hash = "sha256:3bd5b952a5fc82d69b0e2c030e502ee081a8ccf20468e96fd3d53e1572aef6fc"},
|
||||
{file = "langsmith-0.3.27-py3-none-any.whl", hash = "sha256:060956aaed5f391a85829daa0c220b5e07b2e7dd5d33be4b92f280672be984f7"},
|
||||
{file = "langsmith-0.3.27.tar.gz", hash = "sha256:0bdeda73cf723cbcde1cab0f3459f7e5d5748db28a33bf9f6bdc0e2f4fe0ee1e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2773,13 +2758,13 @@ pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4,<14.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "letta-client"
|
||||
version = "0.1.99"
|
||||
version = "0.1.100"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
files = [
|
||||
{file = "letta_client-0.1.99-py3-none-any.whl", hash = "sha256:e3cddd9eb14c447fd4808888df5b40316a8dedf48360fb635bdf6282ca16a3fa"},
|
||||
{file = "letta_client-0.1.99.tar.gz", hash = "sha256:9aacf7d7daa5e0829831cdac34efa89d0556de32ad5aff0af998b276915e4938"},
|
||||
{file = "letta_client-0.1.100-py3-none-any.whl", hash = "sha256:711683cfbfa8b134e0a71b31a8692574f490efa96a3cbc04cf103d4921be3383"},
|
||||
{file = "letta_client-0.1.100.tar.gz", hash = "sha256:8d4e04da71528f43e0e9037dfa53287b65bbad7ebd0fac8dc2ab874099d0adcc"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3082,8 +3067,8 @@ psutil = ">=5.9.1"
|
||||
pywin32 = {version = "*", markers = "sys_platform == \"win32\""}
|
||||
pyzmq = ">=25.0.0"
|
||||
requests = [
|
||||
{version = ">=2.26.0", markers = "python_version <= \"3.11\""},
|
||||
{version = ">=2.32.2", markers = "python_version > \"3.11\""},
|
||||
{version = ">=2.26.0", markers = "python_version <= \"3.11\""},
|
||||
]
|
||||
setuptools = ">=70.0.0"
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
@@ -3936,9 +3921,9 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
@@ -4570,13 +4555,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.11.2"
|
||||
version = "2.11.3"
|
||||
description = "Data validation using Python type hints"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "pydantic-2.11.2-py3-none-any.whl", hash = "sha256:7f17d25846bcdf89b670a86cdfe7b29a9f1c9ca23dee154221c9aa81845cfca7"},
|
||||
{file = "pydantic-2.11.2.tar.gz", hash = "sha256:2138628e050bd7a1e70b91d4bf4a91167f4ad76fdb83209b107c8d84b854917e"},
|
||||
{file = "pydantic-2.11.3-py3-none-any.whl", hash = "sha256:a082753436a07f9ba1289c6ffa01cd93db3548776088aa917cc43b63f68fa60f"},
|
||||
{file = "pydantic-2.11.3.tar.gz", hash = "sha256:7471657138c16adad9322fe3070c0116dd6c3ad8d649300e3cbdfe91f4db4ec3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -5237,8 +5222,8 @@ grpcio = ">=1.41.0"
|
||||
grpcio-tools = ">=1.41.0"
|
||||
httpx = {version = ">=0.20.0", extras = ["http2"]}
|
||||
numpy = [
|
||||
{version = ">=1.21", markers = "python_version >= \"3.10\" and python_version < \"3.12\""},
|
||||
{version = ">=1.26", markers = "python_version == \"3.12\""},
|
||||
{version = ">=1.21", markers = "python_version >= \"3.10\" and python_version < \"3.12\""},
|
||||
]
|
||||
portalocker = ">=2.7.0,<3.0.0"
|
||||
pydantic = ">=1.10.8"
|
||||
@@ -6219,6 +6204,23 @@ files = [
|
||||
{file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tzlocal"
|
||||
version = "5.3.1"
|
||||
description = "tzinfo object for the local timezone"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "tzlocal-5.3.1-py3-none-any.whl", hash = "sha256:eb1a66c3ef5847adf7a834f1be0800581b683b5608e74f86ecbcef8ab91bb85d"},
|
||||
{file = "tzlocal-5.3.1.tar.gz", hash = "sha256:cceffc7edecefea1f595541dbd6e990cb1ea3d19bf01b2809f362a03dd7921fd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tzdata = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
|
||||
[package.extras]
|
||||
devenv = ["check-manifest", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.3.0"
|
||||
@@ -6804,10 +6806,10 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
|
||||
cffi = ["cffi (>=1.11)"]
|
||||
|
||||
[extras]
|
||||
all = ["autoflake", "black", "datamodel-code-generator", "docker", "fastapi", "isort", "langchain", "langchain-community", "locust", "pexpect", "pg8000", "pgvector", "pre-commit", "psycopg2", "psycopg2-binary", "pyright", "pytest-asyncio", "pytest-order", "uvicorn", "wikipedia"]
|
||||
all = ["autoflake", "black", "docker", "fastapi", "isort", "langchain", "langchain-community", "locust", "pexpect", "pg8000", "pgvector", "pre-commit", "psycopg2", "psycopg2-binary", "pyright", "pytest-asyncio", "pytest-order", "uvicorn", "wikipedia"]
|
||||
bedrock = ["boto3"]
|
||||
cloud-tool-sandbox = ["e2b-code-interpreter"]
|
||||
desktop = ["datamodel-code-generator", "docker", "fastapi", "langchain", "langchain-community", "letta_client", "locust", "mcp", "pg8000", "pgvector", "psycopg2", "psycopg2-binary", "pyright", "uvicorn", "wikipedia"]
|
||||
desktop = ["docker", "fastapi", "langchain", "langchain-community", "locust", "pg8000", "pgvector", "psycopg2", "psycopg2-binary", "pyright", "uvicorn", "wikipedia"]
|
||||
dev = ["autoflake", "black", "isort", "locust", "pexpect", "pre-commit", "pyright", "pytest-asyncio", "pytest-order"]
|
||||
external-tools = ["docker", "langchain", "langchain-community", "wikipedia"]
|
||||
google = ["google-genai"]
|
||||
@@ -6819,4 +6821,4 @@ tests = ["wikipedia"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.14,>=3.10"
|
||||
content-hash = "697e7c10385b25fbf5842ed430eb3565e50deb591bc2c05a616180460ebf0b28"
|
||||
content-hash = "c7532fe22e86ca8602c0b27be85020e0b139ec521cd5b0dc94b180113ede41c4"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "letta"
|
||||
version = "0.6.49"
|
||||
version = "0.6.50"
|
||||
packages = [
|
||||
{include = "letta"},
|
||||
]
|
||||
@@ -88,6 +88,7 @@ boto3 = {version = "^1.36.24", optional = true}
|
||||
datamodel-code-generator = {extras = ["http"], version = "^0.25.0"}
|
||||
mcp = "^1.3.0"
|
||||
firecrawl-py = "^1.15.0"
|
||||
apscheduler = "^3.11.0"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
@@ -100,8 +101,8 @@ external-tools = ["docker", "langchain", "wikipedia", "langchain-community"]
|
||||
tests = ["wikipedia"]
|
||||
bedrock = ["boto3"]
|
||||
google = ["google-genai"]
|
||||
desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "datamodel-code-generator", "mcp", "letta-client"]
|
||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "datamodel-code-generator"]
|
||||
desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
|
||||
all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"]
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
331
tests/integration_test_batch_api.py
Normal file
331
tests/integration_test_batch_api.py
Normal file
@@ -0,0 +1,331 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from anthropic.types import BetaErrorResponse, BetaRateLimitError
|
||||
from anthropic.types.beta import BetaMessage
|
||||
from anthropic.types.beta.messages import (
|
||||
BetaMessageBatch,
|
||||
BetaMessageBatchErroredResult,
|
||||
BetaMessageBatchIndividualResponse,
|
||||
BetaMessageBatchRequestCounts,
|
||||
BetaMessageBatchSucceededResult,
|
||||
)
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
|
||||
from letta.orm import Base
|
||||
from letta.schemas.agent import AgentStepState
|
||||
from letta.schemas.enums import JobStatus, ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.tool_rule import InitToolRule
|
||||
from letta.server.db import db_context
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
# --- Server and Database Management --- #
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clear_tables():
    """Wipe batch-related tables before every test so each run is isolated."""
    target_tables = {"llm_batch_job", "llm_batch_items"}
    with db_context() as session:
        # Walk the tables child-first (reverse of sorted_tables) so foreign-key
        # constraints are never violated while deleting.
        for table in reversed(Base.metadata.sorted_tables):
            if table.name in target_tables:
                session.execute(table.delete())
        session.commit()
|
||||
|
||||
|
||||
def _run_server():
    """Starts the Letta server in a background thread."""
    load_dotenv()
    # Deferred import: presumably so load_dotenv() runs before the app module
    # reads its settings at import time — TODO confirm.
    from letta.server.rest_api.app import start_server

    # Blocking call; callers run this in a daemon thread.
    start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def server_url():
    """Ensures a server is running and returns its base URL.

    If LETTA_SERVER_URL is set (and non-empty), that external server is used.
    Otherwise a local server is booted in a daemon thread on the default port.
    """
    # Read the env var once; the original read it twice (redundant lookup) and
    # returned an unusable "" URL when the variable was set but empty.
    url = os.getenv("LETTA_SERVER_URL")
    if not url:
        url = "http://localhost:8283"
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()
        time.sleep(5)  # crude wait; allow server startup time

    return url
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Provide an in-process SyncServer backed by a persisted default config."""
    # Load-then-save ensures a config file exists on disk before SyncServer
    # starts reading it.
    config = LettaConfig.load()
    print("CONFIG PATH", config.config_path)
    config.save()
    return SyncServer()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def client(server_url):
    """Creates a REST client for testing."""
    # Session-scoped: a single SDK client is shared by every test in the run.
    return Letta(base_url=server_url)
|
||||
|
||||
|
||||
# --- Dummy Response Factories --- #
|
||||
|
||||
|
||||
def create_batch_response(batch_id: str, processing_status: str = "in_progress") -> BetaMessageBatch:
    """Build a dummy BetaMessageBatch for *batch_id* in the given processing state.

    All timestamps are frozen to one fixed instant and the request counts are
    arbitrary but stable, so tests can compare polled snapshots deterministically.
    """
    frozen_time = datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc)
    counts = BetaMessageBatchRequestCounts(
        canceled=10,
        errored=30,
        expired=10,
        processing=100,
        succeeded=50,
    )
    return BetaMessageBatch(
        id=batch_id,
        archived_at=frozen_time,
        cancel_initiated_at=frozen_time,
        created_at=frozen_time,
        ended_at=frozen_time,
        expires_at=frozen_time,
        processing_status=processing_status,
        request_counts=counts,
        results_url=None,
        type="message_batch",
    )
|
||||
|
||||
|
||||
def create_successful_response(custom_id: str) -> BetaMessageBatchIndividualResponse:
    """Build a dummy per-request batch entry that succeeded with a short reply."""
    reply = BetaMessage(
        id="msg_abc123",
        role="assistant",
        type="message",
        model="claude-3-5-sonnet-20240620",
        content=[{"type": "text", "text": "hi!"}],
        usage={"input_tokens": 5, "output_tokens": 7},
        stop_reason="end_turn",
    )
    return BetaMessageBatchIndividualResponse(
        custom_id=custom_id,
        result=BetaMessageBatchSucceededResult(type="succeeded", message=reply),
    )
|
||||
|
||||
|
||||
def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse:
    """Build a dummy per-request batch entry that errored with a rate-limit failure."""
    rate_limit = BetaErrorResponse(
        type="error",
        error=BetaRateLimitError(type="rate_limit_error", message="Rate limit hit."),
    )
    return BetaMessageBatchIndividualResponse(
        custom_id=custom_id,
        result=BetaMessageBatchErroredResult(type="errored", error=rate_limit),
    )
|
||||
|
||||
|
||||
# --- Test Setup Helpers --- #
|
||||
|
||||
|
||||
def create_test_agent(client, name, model="anthropic/claude-3-5-sonnet-20241022"):
    """Register a throwaway agent with the standard test configuration.

    Only *name* (and optionally *model*) varies between callers; base tools,
    tags, and the embedding handle are pinned so all test agents look alike.
    """
    agent_settings = dict(
        name=name,
        include_base_tools=True,
        model=model,
        tags=["test_agents"],
        embedding="letta/letta-free",
    )
    return client.agents.create(**agent_settings)
|
||||
|
||||
|
||||
def create_test_batch_job(server, batch_response, default_user):
    """Create a test batch job with the given batch response."""
    # Provider is hard-wired to Anthropic and the job starts in `running`
    # so the polling loop under test picks it up on its next pass.
    return server.batch_manager.create_batch_request(
        llm_provider=ProviderType.anthropic,
        create_batch_response=batch_response,
        actor=default_user,
        status=JobStatus.running,
    )
|
||||
|
||||
|
||||
def create_test_batch_item(server, batch_id, agent_id, default_user):
    """Create a batch item tying *agent_id* to *batch_id* with a canned Anthropic config.

    The LLM config is a dummy; tests mock the Anthropic client, so the values
    only need to be well-formed, not reachable.
    """
    # Single source of truth for the model name; the original duplicated it in
    # `handle` via a placeholder-less f-string (ruff F541).
    model_name = "claude-3-7-sonnet-latest"
    dummy_llm_config = LLMConfig(
        model=model_name,
        model_endpoint_type="anthropic",
        model_endpoint="https://api.anthropic.com/v1",
        context_window=32000,
        handle=f"anthropic/{model_name}",
        put_inner_thoughts_in_kwargs=True,
        max_tokens=4096,
    )

    # Minimal step state: first step, with only `send_message` allowed initially.
    common_step_state = AgentStepState(
        step_number=1,
        tool_rules_solver=ToolRulesSolver(tool_rules=[InitToolRule(tool_name="send_message")]),
    )

    return server.batch_manager.create_batch_item(
        batch_id=batch_id,
        agent_id=agent_id,
        llm_config=dummy_llm_config,
        step_state=common_step_state,
        actor=default_user,
    )
|
||||
|
||||
|
||||
def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_c_id):
    """Patch the server's Anthropic client so batch retrieve/results return canned data.

    `retrieve` echoes back the matching prepared batch; `results` streams one
    success (agent B) and one failure (agent C) for batch B only.
    """
    canned_batches = {batch_a_resp.id: batch_a_resp, batch_b_resp.id: batch_b_resp}

    async def fake_retrieve(batch_resp_id: str) -> BetaMessageBatch:
        if batch_resp_id in canned_batches:
            return canned_batches[batch_resp_id]
        raise ValueError(f"Unknown batch response id: {batch_resp_id}")

    server.anthropic_async_client.beta.messages.batches.retrieve = AsyncMock(side_effect=fake_retrieve)

    def fake_results(batch_resp_id: str):
        # Batch A never finishes in these tests, so its results must never be fetched.
        if batch_resp_id != batch_b_resp.id:
            raise RuntimeError("This test should never request the results for batch_a.")

        async def result_stream():
            yield create_successful_response(agent_b_id)
            yield create_failed_response(agent_c_id)

        return result_stream()

    server.anthropic_async_client.beta.messages.batches.results = fake_results
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# End-to-End Test
|
||||
# -----------------------------
|
||||
@pytest.mark.asyncio
async def test_polling_mixed_batch_jobs(client, default_user, server):
    """
    End-to-end test for polling batch jobs with mixed statuses and idempotency.

    Test scenario:
    - Create two batch jobs:
      - Job A: Single agent that remains "in_progress"
      - Job B: Two agents that complete (one succeeds, one fails)
    - Poll jobs and verify:
      - Job A remains in "running" state
      - Job B moves to "completed" state
      - Job B's items reflect appropriate individual success/failure statuses
    - Test idempotency:
      - Run additional polls and verify:
        - Completed job B remains unchanged (no status changes or re-polling)
        - In-progress job A continues to be polled
        - All batch items maintain their final states
    """
    # --- Step 1: Prepare test data ---
    # Create batch responses with different statuses
    batch_a_resp = create_batch_response("msgbatch_A", processing_status="in_progress")
    batch_b_resp = create_batch_response("msgbatch_B", processing_status="ended")

    # Create test agents (one for job A, two for job B)
    agent_a = create_test_agent(client, "agent_a")
    agent_b = create_test_agent(client, "agent_b")
    agent_c = create_test_agent(client, "agent_c")

    # --- Step 2: Create batch jobs ---
    job_a = create_test_batch_job(server, batch_a_resp, default_user)
    job_b = create_test_batch_job(server, batch_b_resp, default_user)

    # --- Step 3: Create batch items ---
    item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user)
    item_b = create_test_batch_item(server, job_b.id, agent_b.id, default_user)
    item_c = create_test_batch_item(server, job_b.id, agent_c.id, default_user)

    # --- Step 4: Mock the Anthropic client ---
    # agent_b succeeds, agent_c fails in batch B's canned results
    mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b.id, agent_c.id)

    # --- Step 5: Run the polling job twice (simulating periodic polling) ---
    await poll_running_llm_batches(server)
    await poll_running_llm_batches(server)

    # --- Step 6: Verify batch job status updates ---
    updated_job_a = server.batch_manager.get_batch_job_by_id(batch_id=job_a.id, actor=default_user)
    updated_job_b = server.batch_manager.get_batch_job_by_id(batch_id=job_b.id, actor=default_user)

    # Job A should remain running since its processing_status is "in_progress"
    assert updated_job_a.status == JobStatus.running
    # Job B should be updated to completed
    assert updated_job_b.status == JobStatus.completed

    # Both jobs should have been polled
    assert updated_job_a.last_polled_at is not None
    assert updated_job_b.last_polled_at is not None
    assert updated_job_b.latest_polling_response is not None

    # --- Step 7: Verify batch item status updates ---
    # Item A should remain unchanged (its job never finished)
    updated_item_a = server.batch_manager.get_batch_item_by_id(item_a.id, actor=default_user)
    assert updated_item_a.request_status == JobStatus.created
    assert updated_item_a.batch_request_result is None

    # Item B should be marked as completed with a successful result
    updated_item_b = server.batch_manager.get_batch_item_by_id(item_b.id, actor=default_user)
    assert updated_item_b.request_status == JobStatus.completed
    assert updated_item_b.batch_request_result is not None

    # Item C should be marked as failed with an error result
    updated_item_c = server.batch_manager.get_batch_item_by_id(item_c.id, actor=default_user)
    assert updated_item_c.request_status == JobStatus.failed
    assert updated_item_c.batch_request_result is not None

    # --- Step 8: Test idempotency by running polls again ---
    # Save timestamps and response objects to compare later
    job_a_polled_at = updated_job_a.last_polled_at
    job_b_polled_at = updated_job_b.last_polled_at
    job_b_response = updated_job_b.latest_polling_response

    # Save detailed item states
    item_a_status = updated_item_a.request_status
    item_b_status = updated_item_b.request_status
    item_c_status = updated_item_c.request_status
    item_b_result = updated_item_b.batch_request_result
    item_c_result = updated_item_c.batch_request_result

    # Run the polling job again multiple times
    await poll_running_llm_batches(server)
    await poll_running_llm_batches(server)
    await poll_running_llm_batches(server)

    # --- Step 9: Verify that nothing changed for completed jobs ---
    # Refresh all objects
    final_job_a = server.batch_manager.get_batch_job_by_id(batch_id=job_a.id, actor=default_user)
    final_job_b = server.batch_manager.get_batch_job_by_id(batch_id=job_b.id, actor=default_user)
    final_item_a = server.batch_manager.get_batch_item_by_id(item_a.id, actor=default_user)
    final_item_b = server.batch_manager.get_batch_item_by_id(item_b.id, actor=default_user)
    final_item_c = server.batch_manager.get_batch_item_by_id(item_c.id, actor=default_user)

    # Job A should still be polling (last_polled_at should update)
    assert final_job_a.status == JobStatus.running
    assert final_job_a.last_polled_at > job_a_polled_at

    # Job B should remain completed with no status changes
    assert final_job_b.status == JobStatus.completed
    # The completed job should not be polled again
    assert final_job_b.last_polled_at == job_b_polled_at
    assert final_job_b.latest_polling_response == job_b_response

    # All items should maintain their final states
    assert final_item_a.request_status == item_a_status
    assert final_item_b.request_status == item_b_status
    assert final_item_c.request_status == item_c_status
    assert final_item_b.batch_request_result == item_b_result
    assert final_item_c.batch_request_result == item_c_result
|
||||
@@ -15,9 +15,7 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
|
||||
# --- Server Management --- #
|
||||
|
||||
@@ -155,44 +153,44 @@ def _assert_valid_chunk(chunk, idx, chunks):
|
||||
# --- Test Cases --- #
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Hi how are you today?"])
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_latency(disable_e2b_api_key, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(message)
|
||||
|
||||
async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}/{agent.id}", max_retries=0)
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
print(chunk)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Use recall memory tool to recall what my name is."])
|
||||
@pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
async def test_voice_recall_memory(disable_e2b_api_key, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(message)
|
||||
|
||||
# Insert some messages about my name
|
||||
client.user_message(agent.id, "My name is Matt")
|
||||
|
||||
# Wipe the in context messages
|
||||
actor = UserManager().get_default_user()
|
||||
AgentManager().set_in_context_messages(agent_id=agent.id, message_ids=[agent.message_ids[0]], actor=actor)
|
||||
|
||||
async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}/{agent.id}", max_retries=0)
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
print(chunk)
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize("message", ["Hi how are you today?"])
|
||||
# @pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
# async def test_latency(disable_e2b_api_key, client, agent, message, endpoint):
|
||||
# """Tests chat completion streaming using the Async OpenAI client."""
|
||||
# request = _get_chat_request(message)
|
||||
#
|
||||
# async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}/{agent.id}", max_retries=0)
|
||||
# stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
# async with stream:
|
||||
# async for chunk in stream:
|
||||
# print(chunk)
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize("message", ["Use recall memory tool to recall what my name is."])
|
||||
# @pytest.mark.parametrize("endpoint", ["v1/voice-beta"])
|
||||
# async def test_voice_recall_memory(disable_e2b_api_key, client, agent, message, endpoint):
|
||||
# """Tests chat completion streaming using the Async OpenAI client."""
|
||||
# request = _get_chat_request(message)
|
||||
#
|
||||
# # Insert some messages about my name
|
||||
# client.user_message(agent.id, "My name is Matt")
|
||||
#
|
||||
# # Wipe the in context messages
|
||||
# actor = UserManager().get_default_user()
|
||||
# AgentManager().set_in_context_messages(agent_id=agent.id, message_ids=[agent.message_ids[0]], actor=actor)
|
||||
#
|
||||
# async_client = AsyncOpenAI(base_url=f"{client.base_url}/{endpoint}/{agent.id}", max_retries=0)
|
||||
# stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
# async with stream:
|
||||
# async for chunk in stream:
|
||||
# print(chunk)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas.", "What's the weather in SF?"])
|
||||
@pytest.mark.parametrize("endpoint", ["openai/v1", "v1/voice-beta"])
|
||||
@pytest.mark.parametrize("endpoint", ["openai/v1"]) # , "v1/voice-beta"])
|
||||
async def test_chat_completions_streaming_openai_client(disable_e2b_api_key, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(message)
|
||||
|
||||
@@ -482,7 +482,7 @@ def test_create_agents_telemetry(client: Letta):
|
||||
print(f"[telemetry] Deleted {len(workers)} existing worker agents in {end_delete - start_delete:.2f}s")
|
||||
|
||||
# create worker agents
|
||||
num_workers = 100
|
||||
num_workers = 1
|
||||
agent_times = []
|
||||
for idx in range(num_workers):
|
||||
start = time.perf_counter()
|
||||
|
||||
@@ -130,21 +130,6 @@ def test_summarize_many_messages_basic(client, disable_e2b_api_key):
|
||||
client.delete_agent(small_agent_state.id)
|
||||
|
||||
|
||||
def test_summarize_large_message_does_not_loop_infinitely(client, disable_e2b_api_key):
|
||||
small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
|
||||
small_context_llm_config.context_window = 2000
|
||||
small_agent_state = client.create_agent(
|
||||
name="super_small_context_agent",
|
||||
llm_config=small_context_llm_config,
|
||||
)
|
||||
with pytest.raises(ContextWindowExceededError, match=f"Ran summarizer {summarizer_settings.max_summarizer_retries}"):
|
||||
client.user_message(
|
||||
agent_id=small_agent_state.id,
|
||||
message="hi " * 1000,
|
||||
)
|
||||
client.delete_agent(small_agent_state.id)
|
||||
|
||||
|
||||
def test_summarize_messages_inplace(client, agent_state, disable_e2b_api_key):
|
||||
"""Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
|
||||
# First send a few messages (5)
|
||||
|
||||
@@ -39,7 +39,7 @@ from letta.schemas.agent import AgentStepState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageRole, ProviderType
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate
|
||||
@@ -437,6 +437,7 @@ def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -452,6 +453,7 @@ def charles_agent(server: SyncServer, default_user, default_organization):
|
||||
memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -476,6 +478,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
|
||||
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
|
||||
message_buffer_autoclear=True,
|
||||
include_base_tools=False,
|
||||
)
|
||||
created_agent = server.agent_manager.create_agent(
|
||||
create_agent_request,
|
||||
@@ -549,6 +552,7 @@ def agent_with_tags(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -560,6 +564,7 @@ def agent_with_tags(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -571,6 +576,7 @@ def agent_with_tags(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -672,6 +678,7 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use
|
||||
tags=["a", "b"],
|
||||
description="test_description",
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, content="hello world")],
|
||||
include_base_tools=False,
|
||||
)
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
create_agent_request,
|
||||
@@ -697,6 +704,7 @@ def test_create_agent_default_initial_message(server: SyncServer, default_user,
|
||||
block_ids=[default_block.id],
|
||||
tags=["a", "b"],
|
||||
description="test_description",
|
||||
include_base_tools=False,
|
||||
)
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
create_agent_request,
|
||||
@@ -841,6 +849,7 @@ def test_list_agents_ascending(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -854,6 +863,7 @@ def test_list_agents_ascending(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -871,6 +881,7 @@ def test_list_agents_descending(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -884,6 +895,7 @@ def test_list_agents_descending(server: SyncServer, default_user):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -905,6 +917,7 @@ def test_list_agents_ordering_and_pagination(server: SyncServer, default_user):
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1266,6 +1279,7 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1281,6 +1295,7 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
memory_blocks=[],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1321,6 +1336,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
|
||||
description="This is a search agent for testing",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1332,6 +1348,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
|
||||
description="Another search agent for testing",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -1343,6 +1360,7 @@ def test_list_agents_query_text_pagination(server: SyncServer, default_user, def
|
||||
description="This is a different agent",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -3351,6 +3369,7 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
identity_ids=[identity.id],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -3359,6 +3378,7 @@ def test_get_set_agents_for_identities(server: SyncServer, sarah_agent, charles_
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
@@ -4643,6 +4663,7 @@ def test_list_tags(server: SyncServer, default_user, default_organization):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
tags=tags[i : i + 3], # Each agent gets 3 consecutive tags
|
||||
include_base_tools=False,
|
||||
),
|
||||
)
|
||||
agents.append(agent)
|
||||
@@ -4687,20 +4708,20 @@ def test_list_tags(server: SyncServer, default_user, default_organization):
|
||||
|
||||
def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch):
|
||||
batch = server.batch_manager.create_batch_request(
|
||||
llm_provider="anthropic",
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
)
|
||||
assert batch.id.startswith("batch_req-")
|
||||
assert batch.create_batch_response == dummy_beta_message_batch
|
||||
fetched = server.batch_manager.get_batch_request_by_id(batch.id, actor=default_user)
|
||||
fetched = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user)
|
||||
assert fetched.id == batch.id
|
||||
|
||||
|
||||
def test_update_batch_status(server, default_user, dummy_beta_message_batch):
|
||||
batch = server.batch_manager.create_batch_request(
|
||||
llm_provider="anthropic",
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
@@ -4714,7 +4735,7 @@ def test_update_batch_status(server, default_user, dummy_beta_message_batch):
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
updated = server.batch_manager.get_batch_request_by_id(batch.id, actor=default_user)
|
||||
updated = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user)
|
||||
assert updated.status == JobStatus.completed
|
||||
assert updated.latest_polling_response == dummy_beta_message_batch
|
||||
assert updated.last_polled_at >= before
|
||||
@@ -4722,7 +4743,7 @@ def test_update_batch_status(server, default_user, dummy_beta_message_batch):
|
||||
|
||||
def test_create_and_get_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
|
||||
batch = server.batch_manager.create_batch_request(
|
||||
llm_provider="anthropic",
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
@@ -4748,7 +4769,7 @@ def test_update_batch_item(
|
||||
server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, dummy_successful_response
|
||||
):
|
||||
batch = server.batch_manager.create_batch_request(
|
||||
llm_provider="anthropic",
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
@@ -4780,7 +4801,7 @@ def test_update_batch_item(
|
||||
|
||||
def test_delete_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
|
||||
batch = server.batch_manager.create_batch_request(
|
||||
llm_provider="anthropic",
|
||||
llm_provider=ProviderType.anthropic,
|
||||
status=JobStatus.created,
|
||||
create_batch_response=dummy_beta_message_batch,
|
||||
actor=default_user,
|
||||
|
||||
@@ -279,8 +279,6 @@ def server():
|
||||
def org_id(server):
|
||||
# create org
|
||||
org = server.organization_manager.create_default_organization()
|
||||
print(f"Created org\n{org.id}")
|
||||
|
||||
yield org.id
|
||||
|
||||
# cleanup
|
||||
@@ -338,7 +336,6 @@ def agent_id(server, user_id, base_tools):
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
yield agent_state.id
|
||||
|
||||
# cleanup
|
||||
@@ -359,7 +356,6 @@ def other_agent_id(server, user_id, base_tools):
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
yield agent_state.id
|
||||
|
||||
# cleanup
|
||||
@@ -953,7 +949,6 @@ def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, bas
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
|
||||
def count_system_messages_in_recall() -> Tuple[int, List[LettaMessage]]:
|
||||
|
||||
@@ -967,10 +962,6 @@ def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, bas
|
||||
)
|
||||
assert all(isinstance(m, LettaMessage) for m in letta_messages)
|
||||
|
||||
print("LETTA_MESSAGES:")
|
||||
for i, m in enumerate(letta_messages):
|
||||
print(f"{i}: {type(m)} ...{str(m)[-50:]}")
|
||||
|
||||
# Collect system messages and their texts
|
||||
system_messages = [m for m in letta_messages if m.message_type == "system_message"]
|
||||
return len(system_messages), letta_messages
|
||||
@@ -1116,7 +1107,47 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
|
||||
|
||||
|
||||
def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools):
|
||||
def test_add_nonexisting_tool(server: SyncServer, user_id: str, base_tools):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
|
||||
# create agent
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="memory_rebuild_test_agent",
|
||||
tools=["fake_nonexisting_tool"],
|
||||
memory_blocks=[
|
||||
CreateBlock(label="human", value="The human's name is Bob."),
|
||||
CreateBlock(label="persona", value="My name is Alice."),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
include_base_tools=True,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
def test_default_tool_rules(server: SyncServer, user_id: str, base_tools, base_memory_tools):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
|
||||
# create agent
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="tool_rules_test_agent",
|
||||
tool_ids=[t.id for t in base_tools + base_memory_tools],
|
||||
memory_blocks=[],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
assert len(agent_state.tool_rules) == len(base_tools + base_memory_tools)
|
||||
|
||||
|
||||
def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools, base_memory_tools):
|
||||
"""Test that the memory rebuild is generating the correct number of role=system messages"""
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user