feat: add letta-free endpoint constant (#1907)
This commit is contained in:
@@ -4,6 +4,8 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
|
||||
LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
|
||||
LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir")
|
||||
|
||||
LETTA_MODEL_ENDPOINT = "https://inference.memgpt.ai"
|
||||
|
||||
ADMIN_PREFIX = "/v1/admin"
|
||||
API_PREFIX = "/v1"
|
||||
OPENAI_API_PREFIX = "/openai"
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.constants import CLI_WARNING_PREFIX, LETTA_MODEL_ENDPOINT
|
||||
from letta.errors import LettaConfigurationError, RateLimitExceededError
|
||||
from letta.llm_api.anthropic import (
|
||||
anthropic_bedrock_chat_completions_request,
|
||||
@@ -181,7 +181,7 @@ def create(
|
||||
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
# TODO(matt) move into LLMConfig
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in llm_config.handle):
|
||||
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT or (llm_config.handle and "vllm" in llm_config.handle):
|
||||
function_call = "auto" # TODO change to "required" once proxy supports it
|
||||
else:
|
||||
function_call = "required"
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Generator, List, Optional, Union
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
|
||||
from letta.constants import LETTA_MODEL_ENDPOINT
|
||||
from letta.helpers.datetime_helpers import timestamp_to_datetime
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
|
||||
from letta.llm_api.openai_client import supports_parallel_tool_calling, supports_temperature_param
|
||||
@@ -156,7 +157,7 @@ def build_openai_chat_completions_request(
|
||||
# if "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model:
|
||||
# data.response_format = {"type": "json_object"}
|
||||
|
||||
if "inference.memgpt.ai" in llm_config.model_endpoint:
|
||||
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
|
||||
# override user id for inference.memgpt.ai
|
||||
import uuid
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.constants import LETTA_MODEL_ENDPOINT
|
||||
from letta.errors import (
|
||||
ErrorCode,
|
||||
LLMAuthenticationError,
|
||||
@@ -115,7 +116,7 @@ class OpenAIClient(LLMClientBase):
|
||||
# TODO(matt) move into LLMConfig
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
tool_choice = None
|
||||
if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in llm_config.handle):
|
||||
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT or (llm_config.handle and "vllm" in llm_config.handle):
|
||||
tool_choice = "auto" # TODO change to "required" once proxy supports it
|
||||
elif tools:
|
||||
# only set if tools is non-Null
|
||||
@@ -134,7 +135,7 @@ class OpenAIClient(LLMClientBase):
|
||||
temperature=llm_config.temperature if supports_temperature_param(model) else None,
|
||||
)
|
||||
|
||||
if "inference.memgpt.ai" in llm_config.model_endpoint:
|
||||
if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
|
||||
# override user id for inference.memgpt.ai
|
||||
import uuid
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from letta.constants import LETTA_MODEL_ENDPOINT
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -163,7 +164,7 @@ class LLMConfig(BaseModel):
|
||||
return cls(
|
||||
model="memgpt-openai",
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://inference.memgpt.ai",
|
||||
model_endpoint=LETTA_MODEL_ENDPOINT,
|
||||
context_window=8192,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
|
||||
from letta.constants import LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
|
||||
from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint
|
||||
from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -78,7 +78,7 @@ class LettaProvider(Provider):
|
||||
LLMConfig(
|
||||
model="letta-free", # NOTE: renamed
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://inference.memgpt.ai",
|
||||
model_endpoint=LETTA_MODEL_ENDPOINT,
|
||||
context_window=8192,
|
||||
handle=self.get_handle("letta-free"),
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ from fastapi.responses import StreamingResponse
|
||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||
|
||||
from letta.agent import Agent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_MODEL_ENDPOINT
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.user import User
|
||||
@@ -54,7 +54,7 @@ async def create_chat_completions(
|
||||
|
||||
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
|
||||
llm_config = letta_agent.agent_state.llm_config
|
||||
if llm_config.model_endpoint_type != "openai" or "inference.memgpt.ai" in llm_config.model_endpoint:
|
||||
if llm_config.model_endpoint_type != "openai" or llm_config.model_endpoint == LETTA_MODEL_ENDPOINT:
|
||||
error_msg = f"You can only use models with type 'openai' for chat completions. This agent {agent_id} has llm_config: \n{llm_config.model_dump_json(indent=4)}"
|
||||
logger.error(error_msg)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
@@ -1562,7 +1562,8 @@ class SyncServer(Server):
|
||||
# supports_token_streaming = ["openai", "anthropic", "xai", "deepseek"]
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek"] # TODO re-enable xAI once streaming is patched
|
||||
if stream_tokens and (
|
||||
llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint
|
||||
llm_config.model_endpoint_type not in supports_token_streaming
|
||||
or llm_config.model_endpoint == constants.LETTA_MODEL_ENDPOINT
|
||||
):
|
||||
warnings.warn(
|
||||
f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
|
||||
@@ -1685,7 +1686,7 @@ class SyncServer(Server):
|
||||
llm_config = letta_multi_agent.agent_state.llm_config
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek"]
|
||||
if stream_tokens and (
|
||||
llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint
|
||||
llm_config.model_endpoint_type not in supports_token_streaming or llm_config.model_endpoint == constants.LETTA_MODEL_ENDPOINT
|
||||
):
|
||||
warnings.warn(
|
||||
f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."
|
||||
|
||||
Reference in New Issue
Block a user