diff --git a/letta/constants.py b/letta/constants.py index 64e49d07..63079049 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -4,6 +4,8 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta") LETTA_TOOL_EXECUTION_DIR = os.path.join(LETTA_DIR, "tool_execution_dir") +LETTA_MODEL_ENDPOINT = "https://inference.memgpt.ai" + ADMIN_PREFIX = "/v1/admin" API_PREFIX = "/v1" OPENAI_API_PREFIX = "/openai" diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 243dfecf..ffc88ef6 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union import requests -from letta.constants import CLI_WARNING_PREFIX +from letta.constants import CLI_WARNING_PREFIX, LETTA_MODEL_ENDPOINT from letta.errors import LettaConfigurationError, RateLimitExceededError from letta.llm_api.anthropic import ( anthropic_bedrock_chat_completions_request, @@ -181,7 +181,7 @@ def create( # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice # TODO(matt) move into LLMConfig # TODO: This vllm checking is very brittle and is a patch at most - if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in llm_config.handle): + if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT or (llm_config.handle and "vllm" in llm_config.handle): function_call = "auto" # TODO change to "required" once proxy supports it else: function_call = "required" diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index eda4c9a8..578f2d02 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -4,6 +4,7 @@ from typing import Generator, List, Optional, Union import requests from openai import OpenAI +from letta.constants import LETTA_MODEL_ENDPOINT from letta.helpers.datetime_helpers import timestamp_to_datetime from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request from letta.llm_api.openai_client import supports_parallel_tool_calling, supports_temperature_param @@ -156,7 +157,7 @@ def build_openai_chat_completions_request( # if "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model: # data.response_format = {"type": "json_object"} - if "inference.memgpt.ai" in llm_config.model_endpoint: + if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT: # override user id for inference.memgpt.ai import uuid diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 5639f884..96e473c7 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -6,6 +6,7 @@ from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from letta.constants import LETTA_MODEL_ENDPOINT from letta.errors import ( ErrorCode, LLMAuthenticationError, @@ -115,7 +116,7 @@ class OpenAIClient(LLMClientBase): # TODO(matt) move into LLMConfig # TODO: This vllm checking is very brittle and is a patch at most tool_choice = None - if llm_config.model_endpoint == "https://inference.memgpt.ai" or (llm_config.handle and "vllm" in llm_config.handle): + if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT or (llm_config.handle and "vllm" in llm_config.handle): tool_choice = "auto" # TODO change to "required" once proxy supports it elif tools: # only set if tools is non-Null @@ -134,7 +135,7 @@ class OpenAIClient(LLMClientBase): temperature=llm_config.temperature if supports_temperature_param(model) else None, ) - if "inference.memgpt.ai" in llm_config.model_endpoint: + if llm_config.model_endpoint == LETTA_MODEL_ENDPOINT: # override user id for inference.memgpt.ai import uuid diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index d6027a85..f94e5f35 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -2,6 +2,7 @@ from typing import Literal, Optional from pydantic import BaseModel, ConfigDict, Field, model_validator +from letta.constants import LETTA_MODEL_ENDPOINT from letta.log import get_logger logger = get_logger(__name__) @@ -163,7 +164,7 @@ class LLMConfig(BaseModel): return cls( model="memgpt-openai", model_endpoint_type="openai", - model_endpoint="https://inference.memgpt.ai", + model_endpoint=LETTA_MODEL_ENDPOINT, context_window=8192, ) else: diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index 90a025a9..3a20aea7 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -4,7 +4,7 @@ from typing import List, Optional from pydantic import Field, model_validator -from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW +from letta.constants import LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint from letta.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH from letta.schemas.embedding_config import EmbeddingConfig @@ -78,7 +78,7 @@ class LettaProvider(Provider): LLMConfig( model="letta-free", # NOTE: renamed model_endpoint_type="openai", - model_endpoint="https://inference.memgpt.ai", + model_endpoint=LETTA_MODEL_ENDPOINT, context_window=8192, handle=self.get_handle("letta-free"), ) diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py index 3ebdb3af..089e8048 100644 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -6,7 +6,7 @@ from fastapi.responses import StreamingResponse from openai.types.chat.completion_create_params import CompletionCreateParams from letta.agent import Agent -from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_MODEL_ENDPOINT from letta.log import get_logger from letta.schemas.message import Message, MessageCreate from letta.schemas.user import User @@ -54,7 +54,7 @@ async def create_chat_completions( letta_agent = server.load_agent(agent_id=agent_id, actor=actor) llm_config = letta_agent.agent_state.llm_config - if llm_config.model_endpoint_type != "openai" or "inference.memgpt.ai" in llm_config.model_endpoint: + if llm_config.model_endpoint_type != "openai" or llm_config.model_endpoint == LETTA_MODEL_ENDPOINT: error_msg = f"You can only use models with type 'openai' for chat completions. This agent {agent_id} has llm_config: \n{llm_config.model_dump_json(indent=4)}" logger.error(error_msg) raise HTTPException(status_code=400, detail=error_msg) diff --git a/letta/server/server.py b/letta/server/server.py index 6190773b..f8972255 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1562,7 +1562,8 @@ class SyncServer(Server): # supports_token_streaming = ["openai", "anthropic", "xai", "deepseek"] supports_token_streaming = ["openai", "anthropic", "deepseek"] # TODO re-enable xAI once streaming is patched if stream_tokens and ( - llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint + llm_config.model_endpoint_type not in supports_token_streaming + or llm_config.model_endpoint == constants.LETTA_MODEL_ENDPOINT ): warnings.warn( f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." @@ -1685,7 +1686,7 @@ class SyncServer(Server): llm_config = letta_multi_agent.agent_state.llm_config supports_token_streaming = ["openai", "anthropic", "deepseek"] if stream_tokens and ( - llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint + llm_config.model_endpoint_type not in supports_token_streaming or llm_config.model_endpoint == constants.LETTA_MODEL_ENDPOINT ): warnings.warn( f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False."