From 3ed216673ec949ccbe352025d0da6a3988e9149f Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 23 Jan 2025 13:45:06 -1000 Subject: [PATCH] feat: Use official OpenAI client (#752) --- letta/llm_api/llm_api_tools.py | 2 - letta/llm_api/openai.py | 162 +++++------------------------ letta/server/rest_api/interface.py | 2 - tests/test_server.py | 20 ++-- 4 files changed, 36 insertions(+), 150 deletions(-) diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 431e0d97..d535ac24 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -290,7 +290,6 @@ def create( # # max_tokens=1024, # TODO make dynamic # ), # ) - elif llm_config.model_endpoint_type == "groq": if stream: raise NotImplementedError(f"Streaming not yet implemented for Groq.") @@ -329,7 +328,6 @@ def create( try: # groq uses the openai chat completions API, so this component should be reusable response = openai_chat_completions_request( - url=llm_config.model_endpoint, api_key=model_settings.groq_api_key, chat_completion_request=data, ) diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index c335c6cb..80180d72 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -1,14 +1,9 @@ -import json import warnings from typing import Generator, List, Optional, Union -import httpx import requests -from httpx_sse import connect_sse -from httpx_sse._exceptions import SSEError +from openai import OpenAI -from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING -from letta.errors import LLMError from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages @@ -378,126 +373,21 @@ def openai_chat_completions_process_stream( return chat_completion_response -def _sse_post(url: str, data: dict, headers: dict) -> Generator[ChatCompletionChunkResponse, None, None]: - - with httpx.Client() as client: - with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source: - - # Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12) - if not event_source.response.is_success: - # handle errors - from letta.utils import printd - - printd("Caught error before iterating SSE request:", vars(event_source.response)) - printd(event_source.response.read()) - - try: - response_bytes = event_source.response.read() - response_dict = json.loads(response_bytes.decode("utf-8")) - error_message = response_dict["error"]["message"] - # e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions. - if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message: - raise LLMError(error_message) - except LLMError: - raise - except: - print(f"Failed to parse SSE message, throwing SSE HTTP error up the stack") - event_source.response.raise_for_status() - - try: - for sse in event_source.iter_sse(): - # printd(sse.event, sse.data, sse.id, sse.retry) - if sse.data == OPENAI_SSE_DONE: - # print("finished") - break - else: - chunk_data = json.loads(sse.data) - # print("chunk_data::", chunk_data) - chunk_object = ChatCompletionChunkResponse(**chunk_data) - # print("chunk_object::", chunk_object) - # id=chunk_data["id"], - # choices=[ChunkChoice], - # model=chunk_data["model"], - # system_fingerprint=chunk_data["system_fingerprint"] - # ) - yield chunk_object - - except SSEError as e: - print("Caught an error while iterating the SSE stream:", str(e)) - if "application/json" in str(e): # Check if the error is because of JSON response - # TODO figure out a better way to catch the error other than re-trying with a POST - response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response - if response.headers["Content-Type"].startswith("application/json"): - error_details = response.json() # Parse the JSON to get the error message - print("Request:", vars(response.request)) - print("POST Error:", error_details) - print("Original SSE Error:", str(e)) - else: - print("Failed to retrieve JSON error message via retry.") - else: - print("SSEError not related to 'application/json' content type.") - - # Optionally re-raise the exception if you need to propagate it - raise e - - except Exception as e: - if event_source.response.request is not None: - print("HTTP Request:", vars(event_source.response.request)) - if event_source.response is not None: - print("HTTP Status:", event_source.response.status_code) - print("HTTP Headers:", event_source.response.headers) - # print("HTTP Body:", event_source.response.text) - print("Exception message:", str(e)) - raise e - - def openai_chat_completions_request_stream( url: str, api_key: str, chat_completion_request: ChatCompletionRequest, ) -> Generator[ChatCompletionChunkResponse, None, None]: - from letta.utils import printd - - url = smart_urljoin(url, "chat/completions") - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - data = chat_completion_request.model_dump(exclude_none=True) - - printd("Request:\n", json.dumps(data, indent=2)) - - # If functions == None, strip from the payload - if "functions" in data and data["functions"] is None: - data.pop("functions") - data.pop("function_call", None) # extra safe, should exist always (default="auto") - - if "tools" in data and data["tools"] is None: - data.pop("tools") - data.pop("tool_choice", None) # extra safe, should exist always (default="auto") - - if "tools" in data: - for tool in data["tools"]: - # tool["strict"] = True - try: - tool["function"] = convert_to_structured_output(tool["function"]) - except ValueError as e: - warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}") - - # print(f"\n\n\n\nData[tools]: {json.dumps(data['tools'], indent=2)}") - - printd(f"Sending request to {url}") - try: - return _sse_post(url=url, data=data, headers=headers) - except requests.exceptions.HTTPError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - printd(f"Got HTTPError, exception={http_err}, payload={data}") - raise http_err - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) - printd(f"Got RequestException, exception={req_err}") - raise req_err - except Exception as e: - # Handle other potential errors - printd(f"Got unknown Exception, exception={e}") - raise e + data = prepare_openai_payload(chat_completion_request) + data["stream"] = True + client = OpenAI( + api_key=api_key, + base_url=url, + ) + stream = client.chat.completions.create(**data) + for chunk in stream: + # TODO: Use the native OpenAI objects here? + yield ChatCompletionChunkResponse(**chunk.model_dump(exclude_none=True)) def openai_chat_completions_request( @@ -512,18 +402,28 @@ def openai_chat_completions_request( https://platform.openai.com/docs/guides/text-generation?lang=curl """ - from letta.utils import printd + data = prepare_openai_payload(chat_completion_request) + client = OpenAI(api_key=api_key, base_url=url) + chat_completion = client.chat.completions.create(**data) + return ChatCompletionResponse(**chat_completion.model_dump()) - url = smart_urljoin(url, "chat/completions") + +def openai_embeddings_request(url: str, api_key: str, data: dict) -> EmbeddingResponse: + """https://platform.openai.com/docs/api-reference/embeddings/create""" + + url = smart_urljoin(url, "embeddings") headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + response_json = make_post_request(url, headers, data) + return EmbeddingResponse(**response_json) + + +def prepare_openai_payload(chat_completion_request: ChatCompletionRequest): data = chat_completion_request.model_dump(exclude_none=True) # add check otherwise will cause error: "Invalid value for 'parallel_tool_calls': 'parallel_tool_calls' is only allowed when 'tools' are specified." if chat_completion_request.tools is not None: data["parallel_tool_calls"] = False - printd("Request:\n", json.dumps(data, indent=2)) - # If functions == None, strip from the payload if "functions" in data and data["functions"] is None: data.pop("functions") @@ -540,14 +440,4 @@ def openai_chat_completions_request( except ValueError as e: warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}") - response_json = make_post_request(url, headers, data) - return ChatCompletionResponse(**response_json) - - -def openai_embeddings_request(url: str, api_key: str, data: dict) -> EmbeddingResponse: - """https://platform.openai.com/docs/api-reference/embeddings/create""" - - url = smart_urljoin(url, "embeddings") - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - response_json = make_post_request(url, headers, data) - return EmbeddingResponse(**response_json) + return data diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 93370330..1bed9a25 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -1018,8 +1018,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # new_message = {"function_return": msg, "status": "success"} assert msg_obj.tool_call_id is not None - print(f"YYY printing the function call - {msg_obj.tool_call_id} == {self.prev_assistant_message_id} ???") - # Skip this is use_assistant_message is on if self.use_assistant_message and msg_obj.tool_call_id == self.prev_assistant_message_id: # Wipe the cache diff --git a/tests/test_server.py b/tests/test_server.py index 8bbe449b..a0d2c663 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -332,7 +332,7 @@ def agent_id(server, user_id, base_tools): name="test_agent", tool_ids=[t.id for t in base_tools], memory_blocks=[], - model="openai/gpt-4", + model="openai/gpt-4o", embedding="openai/text-embedding-ada-002", ), actor=actor, @@ -353,7 +353,7 @@ def other_agent_id(server, user_id, base_tools): name="test_agent_other", tool_ids=[t.id for t in base_tools], memory_blocks=[], - model="openai/gpt-4", + model="openai/gpt-4o", embedding="openai/text-embedding-ada-002", ), actor=actor, @@ -428,11 +428,11 @@ def test_save_archival_memory(server, user_id, agent_id): @pytest.mark.order(4) def test_user_message(server, user, agent_id): # add data into recall memory - server.user_message(user_id=user.id, agent_id=agent_id, message="Hello?") - # server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") - # server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") - # server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") - # server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") + response = server.user_message(user_id=user.id, agent_id=agent_id, message="What's up?") + assert response.step_count == 1 + assert response.completion_tokens > 0 + assert response.prompt_tokens > 0 + assert response.total_tokens > 0 @pytest.mark.order(5) @@ -552,7 +552,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User): request=CreateAgent( name="nonexistent_tools_agent", memory_blocks=[], - model="openai/gpt-4", + model="openai/gpt-4o", embedding="openai/text-embedding-ada-002", ), actor=user, @@ -920,7 +920,7 @@ def test_memory_rebuild_count(server, user, mock_e2b_api_key_none, base_tools, b CreateBlock(label="human", value="The human's name is Bob."), CreateBlock(label="persona", value="My name is Alice."), ], - model="openai/gpt-4", + model="openai/gpt-4o", embedding="openai/text-embedding-ada-002", ), actor=actor, @@ -1108,7 +1108,7 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to CreateBlock(label="human", value="The human's name is Bob."), CreateBlock(label="persona", value="My name is Alice."), ], - model="openai/gpt-4", + model="openai/gpt-4o", embedding="openai/text-embedding-ada-002", include_base_tools=False, ),