feat: New openai client (#1460)

Matthew Zhou authored 2025-03-31 13:08:59 -07:00 · committed by GitHub
parent eb6269697e · commit 4fe496f3f3
6 changed files with 286 additions and 56 deletions

View File: letta/errors.py

@@ -1,6 +1,6 @@
import json
from enum import Enum
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
# Avoid circular imports
if TYPE_CHECKING:
@@ -10,6 +10,10 @@ if TYPE_CHECKING:
class ErrorCode(Enum):
"""Enum for error codes used by client."""
NOT_FOUND = "NOT_FOUND"
UNAUTHENTICATED = "UNAUTHENTICATED"
PERMISSION_DENIED = "PERMISSION_DENIED"
INVALID_ARGUMENT = "INVALID_ARGUMENT"
INTERNAL_SERVER_ERROR = "INTERNAL_SERVER_ERROR"
CONTEXT_WINDOW_EXCEEDED = "CONTEXT_WINDOW_EXCEEDED"
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"
@@ -18,7 +22,9 @@ class ErrorCode(Enum):
class LettaError(Exception):
"""Base class for all Letta related errors."""
-def __init__(self, message: str, code: Optional[ErrorCode] = None, details: dict = {}):
+def __init__(self, message: str, code: Optional[ErrorCode] = None, details: Optional[Union[Dict, str, object]] = None):
+    if details is None:
+        details = {}
self.message = message
self.code = code
self.details = details
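
Side note on the details change: Python evaluates default argument values once, at function definition time, so the old "details: dict = {}" default was a single dict shared by every call. A minimal sketch of the pitfall the new None default avoids:

def leaky(details: dict = {}):  # one dict object, created at definition time
    details["seen"] = True
    return details

leaky()  # {'seen': True}
leaky()  # {'seen': True} -- same dict as the first call; state leaks across calls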
@@ -91,7 +97,8 @@ class LLMUnprocessableEntityError(LLMError):
class LLMServerError(LLMError):
"""Error when LLM service encounters an internal error"""
"""Error indicating an internal server error occurred within the LLM service itself
while processing the request."""
class BedrockPermissionError(LettaError):

View File

@@ -49,5 +49,12 @@ class LLMClient:
llm_config=llm_config,
put_inner_thoughts_first=put_inner_thoughts_first,
)
case "openai":
from letta.llm_api.openai_client import OpenAIClient
return OpenAIClient(
llm_config=llm_config,
put_inner_thoughts_first=put_inner_thoughts_first,
)
case _:
return None
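
A hedged usage sketch of this factory (the enclosing method's name sits outside the hunk, so "create" and the LLMConfig import path below are assumptions):

from letta.llm_api.llm_client import LLMClient   # module path assumed
from letta.schemas.llm_config import LLMConfig   # module path assumed

client = LLMClient.create(  # factory method name assumed
    llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
    put_inner_thoughts_first=True,
)
# An OpenAI-type config now yields an OpenAIClient; unmatched endpoint
# types still fall through to None.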

View File: letta/llm_api/openai_client.py

@@ -0,0 +1,254 @@
import os
from typing import List, Optional
import openai
from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.errors import (
ErrorCode,
LLMAuthenticationError,
LLMBadRequestError,
LLMConnectionError,
LLMNotFoundError,
LLMPermissionDeniedError,
LLMRateLimitError,
LLMServerError,
LLMUnprocessableEntityError,
)
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, unpack_all_inner_thoughts_from_kwargs
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.log import get_logger
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
from letta.schemas.openai.chat_completion_request import FunctionSchema
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
from letta.schemas.openai.chat_completion_request import ToolFunctionChoice, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.settings import model_settings
logger = get_logger(__name__)
class OpenAIClient(LLMClientBase):
def _prepare_client_kwargs(self) -> dict:
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
# The OpenAI Python client requires a non-None API key, so fall back to a dummy value
api_key = api_key or "DUMMY_API_KEY"
kwargs = {"api_key": api_key, "base_url": self.llm_config.model_endpoint}
return kwargs
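# Note: when neither settings nor the environment provide a key, the client is
# constructed as OpenAI(api_key="DUMMY_API_KEY", base_url=self.llm_config.model_endpoint);
# local OpenAI-compatible servers typically ignore the key, so requests still work.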
def build_request_data(
self,
messages: List[PydanticMessage],
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
tool_call: Optional[str] = None, # Note: OpenAI uses tool_choice
force_tool_call: Optional[str] = None,
) -> dict:
"""
Constructs a request object in the expected data format for the OpenAI API.
"""
if tools and self.llm_config.put_inner_thoughts_in_kwargs:
# Special case for the LM Studio backend, which needs extra guidance to emit inner thoughts first
# TODO(fix)
inner_thoughts_desc = (
INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in self.llm_config.model_endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION
)
tools = add_inner_thoughts_to_functions(
functions=tools,
inner_thoughts_key=INNER_THOUGHTS_KWARG,
inner_thoughts_description=inner_thoughts_desc,
put_inner_thoughts_first=True,
)
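# The call above rewrites each tool's JSON schema to include an extra string
# parameter (named by INNER_THOUGHTS_KWARG) placed first, so the model emits its
# reasoning inside the function-call arguments rather than as separate content.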
openai_message_list = [
cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=self.llm_config.put_inner_thoughts_in_kwargs))
for m in messages
]
if self.llm_config.model:
model = self.llm_config.model
else:
logger.warning(f"Model type not set in llm_config: {self.llm_config.model_dump_json(indent=4)}")
model = None
if tool_call is None and tools is not None and len(tools) > 0:
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# TODO(matt) move into LLMConfig
# TODO: This vLLM check is brittle and is a stopgap at best
if self.llm_config.model_endpoint == "https://inference.memgpt.ai" or (
self.llm_config.handle and "vllm" in self.llm_config.handle
):
tool_call = "auto" # TODO change to "required" once proxy supports it
else:
tool_call = "required"
if tool_call not in ["none", "auto", "required"]:
tool_call = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=tool_call))
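# At this point tool_call is either None (no tools), a plain mode string
# ("none" / "auto" / "required"), or the structured form that forces one specific
# function: {"type": "function", "function": {"name": "<tool_name>"}}.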
data = ChatCompletionRequest(
model=model,
messages=openai_message_list,
tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None,
tool_choice=tool_call,
user=str(),
max_completion_tokens=self.llm_config.max_tokens,
temperature=self.llm_config.temperature,
)
if "inference.memgpt.ai" in self.llm_config.model_endpoint:
# override user id for inference.memgpt.ai
import uuid
data.user = str(uuid.UUID(int=0))
data.model = "memgpt-openai"
if data.tools is not None and len(data.tools) > 0:
# Convert to structured output style (which has 'strict' and no optionals)
for tool in data.tools:
try:
structured_output_version = convert_to_structured_output(tool.function.model_dump())
tool.function = FunctionSchema(**structured_output_version)
except ValueError as e:
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
return data.model_dump(exclude_unset=True)
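# Illustrative payload (values assumed, not produced by this commit): since the dump
# uses exclude_unset=True, the returned dict stays minimal and feeds straight into
# chat.completions.create(**request_data). Roughly:
#   {
#     "model": "gpt-4o-mini",
#     "messages": [...],
#     "tools": [{"type": "function", "function": {...}}],
#     "tool_choice": "required",
#     "max_completion_tokens": 4096,
#     "temperature": 0.7,
#   }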
def request(self, request_data: dict) -> dict:
"""
Performs the underlying synchronous request to the OpenAI API and returns the raw response dict.
"""
client = OpenAI(**self._prepare_client_kwargs())
response: ChatCompletion = client.chat.completions.create(**request_data)
return response.model_dump()
async def request_async(self, request_data: dict) -> dict:
"""
Performs the underlying asynchronous request to the OpenAI API and returns the raw response dict.
"""
client = AsyncOpenAI(**self._prepare_client_kwargs())
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()
def convert_response_to_chat_completion(
self,
response_data: dict,
input_messages: List[PydanticMessage],  # Included for interface consistency; may be used later
) -> ChatCompletionResponse:
"""
Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model.
Handles potential extraction of inner thoughts if they were added via kwargs.
"""
# OpenAI's response structure directly maps to ChatCompletionResponse
# We just need to instantiate the Pydantic model for validation and type safety.
chat_completion_response = ChatCompletionResponse(**response_data)
# Unpack inner thoughts if they were embedded in function arguments
if self.llm_config.put_inner_thoughts_in_kwargs:
chat_completion_response = unpack_all_inner_thoughts_from_kwargs(
response=chat_completion_response, inner_thoughts_key=INNER_THOUGHTS_KWARG
)
return chat_completion_response
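# Sketch of the full non-streaming round trip (assumes a configured client and a
# list of PydanticMessage objects; not part of this commit):
#
#   openai_client = OpenAIClient(llm_config=llm_config, put_inner_thoughts_first=True)
#   request_data = openai_client.build_request_data(messages=messages, tools=tools)
#   raw = openai_client.request(request_data)
#   chat_response = openai_client.convert_response_to_chat_completion(raw, input_messages=messages)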
def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
"""
Performs the underlying streaming request to OpenAI and returns the stream iterator.
"""
client = OpenAI(**self._prepare_client_kwargs())
response_stream: Stream[ChatCompletionChunk] = client.chat.completions.create(**request_data, stream=True)
return response_stream
async def stream_async(self, request_data: dict) -> AsyncStream[ChatCompletionChunk]:
"""
Performs the underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
"""
client = AsyncOpenAI(**self._prepare_client_kwargs())
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(**request_data, stream=True)
return response_stream
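# Sketch of consuming these streams (not part of this commit); chunk fields follow
# openai-python's ChatCompletionChunk type:
#
#   for chunk in openai_client.stream(request_data):  # synchronous
#       if chunk.choices and chunk.choices[0].delta.content:
#           print(chunk.choices[0].delta.content, end="", flush=True)
#
#   stream = await openai_client.stream_async(request_data)  # asynchronous
#   async for chunk in stream:
#       ...  # same chunk handling as above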
def handle_llm_error(self, e: Exception) -> Exception:
"""
Maps OpenAI-specific errors to common LLMError types.
"""
if isinstance(e, openai.APIConnectionError):
logger.warning(f"[OpenAI] API connection error: {e}")
return LLMConnectionError(
message=f"Failed to connect to OpenAI: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None},
)
if isinstance(e, openai.RateLimitError):
logger.warning(f"[OpenAI] Rate limited (429). Consider backoff. Error: {e}")
return LLMRateLimitError(
message=f"Rate limited by OpenAI: {str(e)}",
code=ErrorCode.RATE_LIMIT_EXCEEDED,
details=e.body, # Include body which often has rate limit details
)
if isinstance(e, openai.BadRequestError):
logger.warning(f"[OpenAI] Bad request (400): {str(e)}")
# BadRequestError can signify different issues (e.g., invalid args, context length)
# Check message content if finer-grained errors are needed
# Example: if "context_length_exceeded" in str(e): return LLMContextLengthExceededError(...)
return LLMBadRequestError(
message=f"Bad request to OpenAI: {str(e)}",
code=ErrorCode.INVALID_ARGUMENT, # Or more specific if detectable
details=e.body,
)
if isinstance(e, openai.AuthenticationError):
logger.error(f"[OpenAI] Authentication error (401): {str(e)}") # More severe log level
return LLMAuthenticationError(
message=f"Authentication failed with OpenAI: {str(e)}", code=ErrorCode.UNAUTHENTICATED, details=e.body
)
if isinstance(e, openai.PermissionDeniedError):
logger.error(f"[OpenAI] Permission denied (403): {str(e)}") # More severe log level
return LLMPermissionDeniedError(
message=f"Permission denied by OpenAI: {str(e)}", code=ErrorCode.PERMISSION_DENIED, details=e.body
)
if isinstance(e, openai.NotFoundError):
logger.warning(f"[OpenAI] Resource not found (404): {str(e)}")
# Could be invalid model name, etc.
return LLMNotFoundError(message=f"Resource not found in OpenAI: {str(e)}", code=ErrorCode.NOT_FOUND, details=e.body)
if isinstance(e, openai.UnprocessableEntityError):
logger.warning(f"[OpenAI] Unprocessable entity (422): {str(e)}")
return LLMUnprocessableEntityError(
message=f"Invalid request content for OpenAI: {str(e)}",
code=ErrorCode.INVALID_ARGUMENT, # Usually validation errors
details=e.body,
)
# General API error catch-all
if isinstance(e, openai.APIStatusError):
logger.warning(f"[OpenAI] API status error ({e.status_code}): {str(e)}")
# Map based on status code potentially
if e.status_code >= 500:
error_cls = LLMServerError
error_code = ErrorCode.INTERNAL_SERVER_ERROR
else:
# Treat other 4xx as bad requests if not caught above
error_cls = LLMBadRequestError
error_code = ErrorCode.INVALID_ARGUMENT
return error_cls(
message=f"OpenAI API error: {str(e)}",
code=error_code,
details={
"status_code": e.status_code,
"response": str(e.response),
"body": e.body,
},
)
# Fallback for unexpected errors
return super().handle_llm_error(e)
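# Sketch of how a caller applies this mapping (the base class presumably wires this
# up; shown only to illustrate the contract):
#
#   try:
#       raw = openai_client.request(request_data)
#   except Exception as e:
#       raise openai_client.handle_llm_error(e) from e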

View File

@@ -1,5 +1,6 @@
import concurrent
import os
import threading
import time
import uuid
@@ -41,10 +42,10 @@ def server_url():
"""Ensures a server is running and returns its base URL."""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
-# if not os.getenv("LETTA_SERVER_URL"):
-#     thread = threading.Thread(target=_run_server, daemon=True)
-#     thread.start()
-#     time.sleep(5)  # Allow server startup time
+if not os.getenv("LETTA_SERVER_URL"):
+    thread = threading.Thread(target=_run_server, daemon=True)
+    thread.start()
+    time.sleep(5)  # Allow server startup time
return url
@@ -160,15 +161,15 @@ def composio_gmail_get_profile_tool(default_user):
@pytest.fixture(scope="function")
def agent_state(client, roll_dice_tool, weather_tool, rethink_tool):
"""Creates an agent and ensures cleanup after tests."""
-llm_config = LLMConfig(
-    model="claude-3-7-sonnet-latest",
-    model_endpoint_type="anthropic",
-    model_endpoint="https://api.anthropic.com/v1",
-    context_window=32000,
-    handle=f"anthropic/claude-3-7-sonnet-latest",
-    put_inner_thoughts_in_kwargs=True,
-    max_tokens=4096,
-)
+# llm_config = LLMConfig(
+#     model="claude-3-7-sonnet-latest",
+#     model_endpoint_type="anthropic",
+#     model_endpoint="https://api.anthropic.com/v1",
+#     context_window=32000,
+#     handle=f"anthropic/claude-3-7-sonnet-latest",
+#     put_inner_thoughts_in_kwargs=True,
+#     max_tokens=4096,
+# )
agent_state = client.agents.create(
name=f"test_compl_{str(uuid.uuid4())[5:]}",
tool_ids=[roll_dice_tool.id, weather_tool.id, rethink_tool.id],
@@ -183,7 +184,7 @@ def agent_state(client, roll_dice_tool, weather_tool, rethink_tool):
"value": "Friendly agent",
},
],
-llm_config=llm_config,
+llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
yield agent_state

File diff suppressed because one or more lines are too long

View File

@@ -604,41 +604,3 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
agent_id=copied_agent_id,
messages=[MessageCreate(role=MessageRole.user, content="Hello copied agent!")],
)
-# TODO: Add this back
-# @pytest.mark.parametrize("test_af_filename", ["deep_research_agent.af"])
-# def test_agent_file_upload_flow(fastapi_client, server, default_user, other_user, test_af_filename):
-#     """
-#     Test the full E2E serialization and deserialization flow using FastAPI endpoints.
-#     """
-#     file_path = Path(__file__).parent / "test_agent_files" / test_af_filename
-#     with open(file_path, "r") as f:
-#         data = json.load(f)
-#
-#     # Ensure response matches expected schema
-#     agent_schema = AgentSchema.model_validate(data)  # Validate as Pydantic model
-#     agent_json = agent_schema.model_dump(mode="json")  # Convert back to serializable JSON
-#
-#     import ipdb;ipdb.set_trace()
-#
-#     # Step 2: Upload the serialized agent as a copy
-#     agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8"))
-#     files = {"file": ("agent.json", agent_bytes, "application/json")}
-#     upload_response = fastapi_client.post(
-#         "/v1/agents/import",
-#         headers={"user_id": other_user.id},
-#         params={"append_copy_suffix": True, "override_existing_tools": False, "project_id": None},
-#         files=files,
-#     )
-#     assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}"
-#
-#     copied_agent = upload_response.json()
-#     copied_agent_id = copied_agent["id"]
-#
-#     # Step 3: Ensure copied agent receives messages correctly
-#     server.send_messages(
-#         actor=other_user,
-#         agent_id=copied_agent_id,
-#         messages=[MessageCreate(role=MessageRole.user, content="Hello copied agent!")],
-#     )