feat: New openai client (#1460)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
# Avoid circular imports
|
||||
if TYPE_CHECKING:
|
||||
@@ -10,6 +10,10 @@ if TYPE_CHECKING:
|
||||
class ErrorCode(Enum):
|
||||
"""Enum for error codes used by client."""
|
||||
|
||||
NOT_FOUND = "NOT_FOUND"
|
||||
UNAUTHENTICATED = "UNAUTHENTICATED"
|
||||
PERMISSION_DENIED = "PERMISSION_DENIED"
|
||||
INVALID_ARGUMENT = "INVALID_ARGUMENT"
|
||||
INTERNAL_SERVER_ERROR = "INTERNAL_SERVER_ERROR"
|
||||
CONTEXT_WINDOW_EXCEEDED = "CONTEXT_WINDOW_EXCEEDED"
|
||||
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"
|
||||
@@ -18,7 +22,9 @@ class ErrorCode(Enum):
|
||||
class LettaError(Exception):
|
||||
"""Base class for all Letta related errors."""
|
||||
|
||||
def __init__(self, message: str, code: Optional[ErrorCode] = None, details: dict = {}):
|
||||
def __init__(self, message: str, code: Optional[ErrorCode] = None, details: Optional[Union[Dict, str, object]] = None):
|
||||
if details is None:
|
||||
details = {}
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.details = details
|
||||
@@ -91,7 +97,8 @@ class LLMUnprocessableEntityError(LLMError):
|
||||
|
||||
|
||||
class LLMServerError(LLMError):
|
||||
"""Error when LLM service encounters an internal error"""
|
||||
"""Error indicating an internal server error occurred within the LLM service itself
|
||||
while processing the request."""
|
||||
|
||||
|
||||
class BedrockPermissionError(LettaError):
|
||||
|
||||
@@ -49,5 +49,12 @@ class LLMClient:
|
||||
llm_config=llm_config,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
case "openai":
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
|
||||
return OpenAIClient(
|
||||
llm_config=llm_config,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
case _:
|
||||
return None
|
||||
|
||||
254
letta/llm_api/openai_client.py
Normal file
254
letta/llm_api/openai_client.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import openai
|
||||
from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.errors import (
|
||||
ErrorCode,
|
||||
LLMAuthenticationError,
|
||||
LLMBadRequestError,
|
||||
LLMConnectionError,
|
||||
LLMNotFoundError,
|
||||
LLMPermissionDeniedError,
|
||||
LLMRateLimitError,
|
||||
LLMServerError,
|
||||
LLMUnprocessableEntityError,
|
||||
)
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, unpack_all_inner_thoughts_from_kwargs
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall
|
||||
from letta.schemas.openai.chat_completion_request import FunctionSchema
|
||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
||||
from letta.schemas.openai.chat_completion_request import ToolFunctionChoice, cast_message_to_subtype
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.settings import model_settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIClient(LLMClientBase):
|
||||
def _prepare_client_kwargs(self) -> dict:
    """Assemble the constructor kwargs (API key and base URL) for the OpenAI client.

    Resolution order for the key: model settings, then the OPENAI_API_KEY
    environment variable, then a placeholder — the openai python client
    supposedly requires a non-empty api_key even for keyless endpoints.
    """
    resolved_key = (
        model_settings.openai_api_key
        or os.environ.get("OPENAI_API_KEY")
        or "DUMMY_API_KEY"
    )
    return {"api_key": resolved_key, "base_url": self.llm_config.model_endpoint}
|
||||
|
||||
def build_request_data(
    self,
    messages: List[PydanticMessage],
    tools: Optional[List[dict]] = None,  # Keep as dict for now as per base class
    tool_call: Optional[str] = None,  # Note: OpenAI uses tool_choice
    force_tool_call: Optional[str] = None,
) -> dict:
    """
    Constructs a request object in the expected data format for the OpenAI API.

    Args:
        messages: Conversation history to send.
        tools: JSON-schema function definitions to expose as tools, if any.
        tool_call: Desired tool choice — "none"/"auto"/"required" or a
            specific tool name (wrapped into a structured choice object).
        force_tool_call: NOTE(review): accepted but never used in this body —
            confirm against LLMClientBase whether it should drive tool_choice.

    Returns:
        A dict suitable for unpacking into chat.completions.create(**data).
    """
    # Bug fix: model_endpoint may be None (e.g. plain OpenAI cloud config);
    # the original substring checks raised TypeError in that case.
    endpoint = self.llm_config.model_endpoint or ""

    if tools and self.llm_config.put_inner_thoughts_in_kwargs:
        # Special case for LM Studio backend since it needs extra guidance to force out the thoughts first
        # TODO(fix)
        inner_thoughts_desc = (
            INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION
        )
        tools = add_inner_thoughts_to_functions(
            functions=tools,
            inner_thoughts_key=INNER_THOUGHTS_KWARG,
            inner_thoughts_description=inner_thoughts_desc,
            put_inner_thoughts_first=True,
        )

    openai_message_list = [
        cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=self.llm_config.put_inner_thoughts_in_kwargs))
        for m in messages
    ]

    if self.llm_config.model:
        model = self.llm_config.model
    else:
        logger.warning(f"Model type not set in llm_config: {self.llm_config.model_dump_json(indent=4)}")
        model = None

    if tool_call is None and tools is not None and len(tools) > 0:
        # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
        # TODO(matt) move into LLMConfig
        # TODO: This vllm checking is very brittle and is a patch at most
        if endpoint == "https://inference.memgpt.ai" or (self.llm_config.handle and "vllm" in self.llm_config.handle):
            tool_call = "auto"  # TODO change to "required" once proxy supports it
        else:
            tool_call = "required"

    # Bug fix: only wrap a concrete tool *name*. The original wrapped None too
    # (no tools, no explicit choice), constructing
    # ToolFunctionChoiceFunctionCall(name=None) and crashing on validation.
    if tool_call is not None and tool_call not in ["none", "auto", "required"]:
        tool_call = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=tool_call))

    data = ChatCompletionRequest(
        model=model,
        messages=openai_message_list,
        tools=[OpenAITool(type="function", function=f) for f in tools] if tools else None,
        tool_choice=tool_call,
        user=str(),
        max_completion_tokens=self.llm_config.max_tokens,
        temperature=self.llm_config.temperature,
    )

    if "inference.memgpt.ai" in endpoint:
        # override user id for inference.memgpt.ai
        import uuid

        data.user = str(uuid.UUID(int=0))
        data.model = "memgpt-openai"

    if data.tools is not None and len(data.tools) > 0:
        # Convert to structured output style (which has 'strict' and no optionals)
        for tool in data.tools:
            try:
                structured_output_version = convert_to_structured_output(tool.function.model_dump())
                tool.function = FunctionSchema(**structured_output_version)
            except ValueError as e:
                # Best-effort: keep the original schema if conversion fails.
                logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")

    return data.model_dump(exclude_unset=True)
|
||||
|
||||
def request(self, request_data: dict) -> dict:
    """
    Performs underlying synchronous request to OpenAI API and returns raw response dict.
    """
    # A fresh client per call picks up the current credentials/base_url.
    sync_client = OpenAI(**self._prepare_client_kwargs())
    completion: ChatCompletion = sync_client.chat.completions.create(**request_data)
    return completion.model_dump()
|
||||
|
||||
async def request_async(self, request_data: dict) -> dict:
    """
    Performs underlying asynchronous request to OpenAI API and returns raw response dict.
    """
    # A fresh async client per call picks up the current credentials/base_url.
    async_client = AsyncOpenAI(**self._prepare_client_kwargs())
    completion: ChatCompletion = await async_client.chat.completions.create(**request_data)
    return completion.model_dump()
|
||||
|
||||
def convert_response_to_chat_completion(
    self,
    response_data: dict,
    input_messages: List[PydanticMessage],  # Included for consistency, maybe used later
) -> ChatCompletionResponse:
    """
    Converts raw OpenAI response dict into the ChatCompletionResponse Pydantic model.

    Handles potential extraction of inner thoughts if they were added via kwargs.
    """
    # OpenAI's payload maps one-to-one onto ChatCompletionResponse, so
    # construction doubles as validation/type safety.
    result = ChatCompletionResponse(**response_data)

    if not self.llm_config.put_inner_thoughts_in_kwargs:
        return result

    # Inner thoughts were embedded in function-call arguments; pull them back out.
    return unpack_all_inner_thoughts_from_kwargs(response=result, inner_thoughts_key=INNER_THOUGHTS_KWARG)
|
||||
|
||||
def stream(self, request_data: dict) -> Stream[ChatCompletionChunk]:
    """
    Performs underlying streaming request to OpenAI and returns the stream iterator.
    """
    sync_client = OpenAI(**self._prepare_client_kwargs())
    return sync_client.chat.completions.create(**request_data, stream=True)
|
||||
|
||||
async def stream_async(self, request_data: dict) -> AsyncStream[ChatCompletionChunk]:
    """
    Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
    """
    async_client = AsyncOpenAI(**self._prepare_client_kwargs())
    return await async_client.chat.completions.create(**request_data, stream=True)
|
||||
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
"""
|
||||
Maps OpenAI-specific errors to common LLMError types.
|
||||
"""
|
||||
if isinstance(e, openai.APIConnectionError):
|
||||
logger.warning(f"[OpenAI] API connection error: {e}")
|
||||
return LLMConnectionError(
|
||||
message=f"Failed to connect to OpenAI: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
||||
details={"cause": str(e.__cause__) if e.__cause__ else None},
|
||||
)
|
||||
|
||||
if isinstance(e, openai.RateLimitError):
|
||||
logger.warning(f"[OpenAI] Rate limited (429). Consider backoff. Error: {e}")
|
||||
return LLMRateLimitError(
|
||||
message=f"Rate limited by OpenAI: {str(e)}",
|
||||
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
details=e.body, # Include body which often has rate limit details
|
||||
)
|
||||
|
||||
if isinstance(e, openai.BadRequestError):
|
||||
logger.warning(f"[OpenAI] Bad request (400): {str(e)}")
|
||||
# BadRequestError can signify different issues (e.g., invalid args, context length)
|
||||
# Check message content if finer-grained errors are needed
|
||||
# Example: if "context_length_exceeded" in str(e): return LLMContextLengthExceededError(...)
|
||||
return LLMBadRequestError(
|
||||
message=f"Bad request to OpenAI: {str(e)}",
|
||||
code=ErrorCode.INVALID_ARGUMENT, # Or more specific if detectable
|
||||
details=e.body,
|
||||
)
|
||||
|
||||
if isinstance(e, openai.AuthenticationError):
|
||||
logger.error(f"[OpenAI] Authentication error (401): {str(e)}") # More severe log level
|
||||
return LLMAuthenticationError(
|
||||
message=f"Authentication failed with OpenAI: {str(e)}", code=ErrorCode.UNAUTHENTICATED, details=e.body
|
||||
)
|
||||
|
||||
if isinstance(e, openai.PermissionDeniedError):
|
||||
logger.error(f"[OpenAI] Permission denied (403): {str(e)}") # More severe log level
|
||||
return LLMPermissionDeniedError(
|
||||
message=f"Permission denied by OpenAI: {str(e)}", code=ErrorCode.PERMISSION_DENIED, details=e.body
|
||||
)
|
||||
|
||||
if isinstance(e, openai.NotFoundError):
|
||||
logger.warning(f"[OpenAI] Resource not found (404): {str(e)}")
|
||||
# Could be invalid model name, etc.
|
||||
return LLMNotFoundError(message=f"Resource not found in OpenAI: {str(e)}", code=ErrorCode.NOT_FOUND, details=e.body)
|
||||
|
||||
if isinstance(e, openai.UnprocessableEntityError):
|
||||
logger.warning(f"[OpenAI] Unprocessable entity (422): {str(e)}")
|
||||
return LLMUnprocessableEntityError(
|
||||
message=f"Invalid request content for OpenAI: {str(e)}",
|
||||
code=ErrorCode.INVALID_ARGUMENT, # Usually validation errors
|
||||
details=e.body,
|
||||
)
|
||||
|
||||
# General API error catch-all
|
||||
if isinstance(e, openai.APIStatusError):
|
||||
logger.warning(f"[OpenAI] API status error ({e.status_code}): {str(e)}")
|
||||
# Map based on status code potentially
|
||||
if e.status_code >= 500:
|
||||
error_cls = LLMServerError
|
||||
error_code = ErrorCode.INTERNAL_SERVER_ERROR
|
||||
else:
|
||||
# Treat other 4xx as bad requests if not caught above
|
||||
error_cls = LLMBadRequestError
|
||||
error_code = ErrorCode.INVALID_ARGUMENT
|
||||
|
||||
return error_cls(
|
||||
message=f"OpenAI API error: {str(e)}",
|
||||
code=error_code,
|
||||
details={
|
||||
"status_code": e.status_code,
|
||||
"response": str(e.response),
|
||||
"body": e.body,
|
||||
},
|
||||
)
|
||||
|
||||
# Fallback for unexpected errors
|
||||
return super().handle_llm_error(e)
|
||||
@@ -1,5 +1,6 @@
|
||||
import concurrent
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
@@ -41,10 +42,10 @@ def server_url():
|
||||
"""Ensures a server is running and returns its base URL."""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
# if not os.getenv("LETTA_SERVER_URL"):
|
||||
# thread = threading.Thread(target=_run_server, daemon=True)
|
||||
# thread.start()
|
||||
# time.sleep(5) # Allow server startup time
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5) # Allow server startup time
|
||||
|
||||
return url
|
||||
|
||||
@@ -160,15 +161,15 @@ def composio_gmail_get_profile_tool(default_user):
|
||||
@pytest.fixture(scope="function")
|
||||
def agent_state(client, roll_dice_tool, weather_tool, rethink_tool):
|
||||
"""Creates an agent and ensures cleanup after tests."""
|
||||
llm_config = LLMConfig(
|
||||
model="claude-3-7-sonnet-latest",
|
||||
model_endpoint_type="anthropic",
|
||||
model_endpoint="https://api.anthropic.com/v1",
|
||||
context_window=32000,
|
||||
handle=f"anthropic/claude-3-7-sonnet-latest",
|
||||
put_inner_thoughts_in_kwargs=True,
|
||||
max_tokens=4096,
|
||||
)
|
||||
# llm_config = LLMConfig(
|
||||
# model="claude-3-7-sonnet-latest",
|
||||
# model_endpoint_type="anthropic",
|
||||
# model_endpoint="https://api.anthropic.com/v1",
|
||||
# context_window=32000,
|
||||
# handle=f"anthropic/claude-3-7-sonnet-latest",
|
||||
# put_inner_thoughts_in_kwargs=True,
|
||||
# max_tokens=4096,
|
||||
# )
|
||||
agent_state = client.agents.create(
|
||||
name=f"test_compl_{str(uuid.uuid4())[5:]}",
|
||||
tool_ids=[roll_dice_tool.id, weather_tool.id, rethink_tool.id],
|
||||
@@ -183,7 +184,7 @@ def agent_state(client, roll_dice_tool, weather_tool, rethink_tool):
|
||||
"value": "Friendly agent",
|
||||
},
|
||||
],
|
||||
llm_config=llm_config,
|
||||
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
)
|
||||
yield agent_state
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -604,41 +604,3 @@ def test_agent_download_upload_flow(fastapi_client, server, serialize_test_agent
|
||||
agent_id=copied_agent_id,
|
||||
messages=[MessageCreate(role=MessageRole.user, content="Hello copied agent!")],
|
||||
)
|
||||
|
||||
|
||||
# TODO: Add this back
|
||||
# @pytest.mark.parametrize("test_af_filename", ["deep_research_agent.af"])
|
||||
# def test_agent_file_upload_flow(fastapi_client, server, default_user, other_user, test_af_filename):
|
||||
# """
|
||||
# Test the full E2E serialization and deserialization flow using FastAPI endpoints.
|
||||
# """
|
||||
# file_path = Path(__file__).parent / "test_agent_files" / test_af_filename
|
||||
# with open(file_path, "r") as f:
|
||||
# data = json.load(f)
|
||||
#
|
||||
# # Ensure response matches expected schema
|
||||
# agent_schema = AgentSchema.model_validate(data) # Validate as Pydantic model
|
||||
# agent_json = agent_schema.model_dump(mode="json") # Convert back to serializable JSON
|
||||
#
|
||||
# import ipdb;ipdb.set_trace()
|
||||
#
|
||||
# # Step 2: Upload the serialized agent as a copy
|
||||
# agent_bytes = BytesIO(json.dumps(agent_json).encode("utf-8"))
|
||||
# files = {"file": ("agent.json", agent_bytes, "application/json")}
|
||||
# upload_response = fastapi_client.post(
|
||||
# "/v1/agents/import",
|
||||
# headers={"user_id": other_user.id},
|
||||
# params={"append_copy_suffix": True, "override_existing_tools": False, "project_id": None},
|
||||
# files=files,
|
||||
# )
|
||||
# assert upload_response.status_code == 200, f"Upload failed: {upload_response.text}"
|
||||
#
|
||||
# copied_agent = upload_response.json()
|
||||
# copied_agent_id = copied_agent["id"]
|
||||
#
|
||||
# # Step 3: Ensure copied agent receives messages correctly
|
||||
# server.send_messages(
|
||||
# actor=other_user,
|
||||
# agent_id=copied_agent_id,
|
||||
# messages=[MessageCreate(role=MessageRole.user, content="Hello copied agent!")],
|
||||
# )
|
||||
|
||||
Reference in New Issue
Block a user