feat: Support streaming and move endpoint for letta-free (#3780)
This commit is contained in:
@@ -5,7 +5,7 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
|
||||
LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
|
||||
LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir")
|
||||
|
||||
LETTA_MODEL_ENDPOINT = "https://inference.letta.com"
|
||||
LETTA_MODEL_ENDPOINT = "https://inference.letta.com/v1/"
|
||||
DEFAULT_TIMEZONE = "UTC"
|
||||
|
||||
ADMIN_PREFIX = "/v1/admin"
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from letta.constants import CLI_WARNING_PREFIX, LETTA_MODEL_ENDPOINT
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.errors import LettaConfigurationError, RateLimitExceededError
|
||||
from letta.llm_api.anthropic import (
|
||||
anthropic_bedrock_chat_completions_request,
|
||||
@@ -193,8 +193,8 @@ def create(
|
||||
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
# TODO(matt) move into LLMConfig
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT or (llm_config.handle and "vllm" in llm_config.handle):
|
||||
function_call = "auto" # TODO change to "required" once proxy supports it
|
||||
if llm_config.handle and "vllm" in llm_config.handle:
|
||||
function_call = "auto"
|
||||
else:
|
||||
function_call = "required"
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ class LLMConfig(BaseModel):
|
||||
model="memgpt-openai",
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint=LETTA_MODEL_ENDPOINT,
|
||||
context_window=8192,
|
||||
context_window=30000,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model {model_name} not supported.")
|
||||
|
||||
@@ -13,7 +13,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_MODEL_ENDPOINT, REDIS_RUN_ID_PREFIX
|
||||
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX
|
||||
from letta.data_sources.redis_client import get_redis_client
|
||||
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
@@ -1019,7 +1019,6 @@ async def send_message_streaming(
|
||||
"ollama",
|
||||
]
|
||||
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"]
|
||||
not_letta_endpoint = agent.llm_config.model_endpoint != LETTA_MODEL_ENDPOINT
|
||||
|
||||
# Create a new job for execution tracking
|
||||
if settings.track_agent_run:
|
||||
@@ -1087,7 +1086,7 @@ async def send_message_streaming(
|
||||
)
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
|
||||
|
||||
if request.stream_tokens and model_compatible_token_streaming and not_letta_endpoint:
|
||||
if request.stream_tokens and model_compatible_token_streaming:
|
||||
raw_stream = agent_loop.step_stream(
|
||||
input_messages=request.messages,
|
||||
max_steps=request.max_steps,
|
||||
|
||||
@@ -2229,10 +2229,7 @@ class SyncServer(Server):
|
||||
llm_config = letta_agent.agent_state.llm_config
|
||||
# supports_token_streaming = ["openai", "anthropic", "xai", "deepseek"]
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek"] # TODO re-enable xAI once streaming is patched
|
||||
if stream_tokens and (
|
||||
llm_config.model_endpoint_type not in supports_token_streaming
|
||||
or llm_config.model_endpoint == constants.LETTA_MODEL_ENDPOINT
|
||||
):
|
||||
if stream_tokens and (llm_config.model_endpoint_type not in supports_token_streaming):
|
||||
warnings.warn(
|
||||
f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
|
||||
)
|
||||
@@ -2364,9 +2361,7 @@ class SyncServer(Server):
|
||||
|
||||
llm_config = letta_multi_agent.agent_state.llm_config
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek"]
|
||||
if stream_tokens and (
|
||||
llm_config.model_endpoint_type not in supports_token_streaming or llm_config.model_endpoint == constants.LETTA_MODEL_ENDPOINT
|
||||
):
|
||||
if stream_tokens and (llm_config.model_endpoint_type not in supports_token_streaming):
|
||||
warnings.warn(
|
||||
f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
|
||||
)
|
||||
|
||||
@@ -1378,63 +1378,6 @@ def test_async_greeting_with_callback_url(
|
||||
assert headers.get("Content-Type") == "application/json", "Callback should have JSON content type"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
def test_async_callback_failure_scenarios(
    disable_e2b_api_key: Any,
    client: Letta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Verify that an async run still completes when its callback URL is unreachable.

    The run is created with a callback URL pointing at a non-existent domain.
    The test then checks that the run finishes with status "completed", that
    the greeting response is valid, and that the callback failure is recorded
    on the run (timestamp plus an error string naming the run id and the URL).
    """
    # Find the config file whose dumped contents equal the parametrized config;
    # the search stops at the first match and yields None when nothing matches.
    matched_filename = next(
        (name for name in filenames if get_llm_config(name).model_dump() == llm_config.model_dump()),
        None,
    )

    # skip if this is a limited model (or the config could not be matched at all)
    if not matched_filename or matched_filename in limited_configs:
        pytest.skip(f"Skipping test for limited model {llm_config.model}")

    client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    # Kick off the run with an invalid callback URL - the job should still complete
    pending_run = client.agents.messages.create_async(
        agent_id=agent_state.id,
        messages=USER_MESSAGE_FORCE_REPLY,
        callback_url="http://invalid-domain-that-does-not-exist.com/callback",
    )

    # Block until the job finishes - expected to work despite the callback failure
    run = wait_for_run_completion(client, pending_run.id)

    # The run must carry a result payload
    run_result = run.metadata.get("result")
    assert run_result is not None, "Run metadata missing 'result' key"

    response_messages = cast_message_dict_to_messages(run_result["messages"])
    assert_greeting_with_assistant_message_response(response_messages, llm_config=llm_config)

    # Job should be marked as completed even if callback failed
    assert run.status == "completed", f"Expected status 'completed', got {run.status}"

    # The callback failure itself must be recorded on the run
    assert run.callback_sent_at is not None, "callback_sent_at should be set even for failed callbacks"
    assert run.callback_error is not None, "callback_error should be set to error message for failed callbacks"
    assert isinstance(run.callback_error, str), "callback_error should be error message string for failed callbacks"
    assert "Failed to dispatch callback" in run.callback_error, "callback_error should contain error details"
    assert run.id in run.callback_error, "callback_error should contain job ID"
    assert "invalid-domain-that-does-not-exist.com" in run.callback_error, "callback_error should contain failed URL"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
|
||||
Reference in New Issue
Block a user