feat: Use official OpenAI client (#752)

This commit is contained in:
Matthew Zhou
2025-01-23 13:45:06 -10:00
committed by GitHub
parent 587ff08a52
commit 3ed216673e
4 changed files with 36 additions and 150 deletions

View File

@@ -290,7 +290,6 @@ def create(
# # max_tokens=1024, # TODO make dynamic
# ),
# )
elif llm_config.model_endpoint_type == "groq":
if stream:
raise NotImplementedError(f"Streaming not yet implemented for Groq.")
@@ -329,7 +328,6 @@ def create(
try:
# groq uses the openai chat completions API, so this component should be reusable
response = openai_chat_completions_request(
url=llm_config.model_endpoint,
api_key=model_settings.groq_api_key,
chat_completion_request=data,
)

View File

@@ -1,14 +1,9 @@
import json
import warnings
from typing import Generator, List, Optional, Union
import httpx
import requests
from httpx_sse import connect_sse
from httpx_sse._exceptions import SSEError
from openai import OpenAI
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.errors import LLMError
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
@@ -378,126 +373,21 @@ def openai_chat_completions_process_stream(
return chat_completion_response
def _sse_post(url: str, data: dict, headers: dict) -> Generator[ChatCompletionChunkResponse, None, None]:
with httpx.Client() as client:
with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:
# Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12)
if not event_source.response.is_success:
# handle errors
from letta.utils import printd
printd("Caught error before iterating SSE request:", vars(event_source.response))
printd(event_source.response.read())
try:
response_bytes = event_source.response.read()
response_dict = json.loads(response_bytes.decode("utf-8"))
error_message = response_dict["error"]["message"]
# e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions.
if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message:
raise LLMError(error_message)
except LLMError:
raise
except:
print(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
event_source.response.raise_for_status()
try:
for sse in event_source.iter_sse():
# printd(sse.event, sse.data, sse.id, sse.retry)
if sse.data == OPENAI_SSE_DONE:
# print("finished")
break
else:
chunk_data = json.loads(sse.data)
# print("chunk_data::", chunk_data)
chunk_object = ChatCompletionChunkResponse(**chunk_data)
# print("chunk_object::", chunk_object)
# id=chunk_data["id"],
# choices=[ChunkChoice],
# model=chunk_data["model"],
# system_fingerprint=chunk_data["system_fingerprint"]
# )
yield chunk_object
except SSEError as e:
print("Caught an error while iterating the SSE stream:", str(e))
if "application/json" in str(e): # Check if the error is because of JSON response
# TODO figure out a better way to catch the error other than re-trying with a POST
response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response
if response.headers["Content-Type"].startswith("application/json"):
error_details = response.json() # Parse the JSON to get the error message
print("Request:", vars(response.request))
print("POST Error:", error_details)
print("Original SSE Error:", str(e))
else:
print("Failed to retrieve JSON error message via retry.")
else:
print("SSEError not related to 'application/json' content type.")
# Optionally re-raise the exception if you need to propagate it
raise e
except Exception as e:
if event_source.response.request is not None:
print("HTTP Request:", vars(event_source.response.request))
if event_source.response is not None:
print("HTTP Status:", event_source.response.status_code)
print("HTTP Headers:", event_source.response.headers)
# print("HTTP Body:", event_source.response.text)
print("Exception message:", str(e))
raise e
def openai_chat_completions_request_stream(
url: str,
api_key: str,
chat_completion_request: ChatCompletionRequest,
) -> Generator[ChatCompletionChunkResponse, None, None]:
from letta.utils import printd
url = smart_urljoin(url, "chat/completions")
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
data = chat_completion_request.model_dump(exclude_none=True)
printd("Request:\n", json.dumps(data, indent=2))
# If functions == None, strip from the payload
if "functions" in data and data["functions"] is None:
data.pop("functions")
data.pop("function_call", None) # extra safe, should exist always (default="auto")
if "tools" in data and data["tools"] is None:
data.pop("tools")
data.pop("tool_choice", None) # extra safe, should exist always (default="auto")
if "tools" in data:
for tool in data["tools"]:
# tool["strict"] = True
try:
tool["function"] = convert_to_structured_output(tool["function"])
except ValueError as e:
warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
# print(f"\n\n\n\nData[tools]: {json.dumps(data['tools'], indent=2)}")
printd(f"Sending request to {url}")
try:
return _sse_post(url=url, data=data, headers=headers)
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
printd(f"Got HTTPError, exception={http_err}, payload={data}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
printd(f"Got RequestException, exception={req_err}")
raise req_err
except Exception as e:
# Handle other potential errors
printd(f"Got unknown Exception, exception={e}")
raise e
data = prepare_openai_payload(chat_completion_request)
data["stream"] = True
client = OpenAI(
api_key=api_key,
base_url=url,
)
stream = client.chat.completions.create(**data)
for chunk in stream:
# TODO: Use the native OpenAI objects here?
yield ChatCompletionChunkResponse(**chunk.model_dump(exclude_none=True))
def openai_chat_completions_request(
@@ -512,18 +402,28 @@ def openai_chat_completions_request(
https://platform.openai.com/docs/guides/text-generation?lang=curl
"""
from letta.utils import printd
data = prepare_openai_payload(chat_completion_request)
client = OpenAI(api_key=api_key, base_url=url)
chat_completion = client.chat.completions.create(**data)
return ChatCompletionResponse(**chat_completion.model_dump())
url = smart_urljoin(url, "chat/completions")
def openai_embeddings_request(url: str, api_key: str, data: dict) -> EmbeddingResponse:
"""https://platform.openai.com/docs/api-reference/embeddings/create"""
url = smart_urljoin(url, "embeddings")
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
response_json = make_post_request(url, headers, data)
return EmbeddingResponse(**response_json)
def prepare_openai_payload(chat_completion_request: ChatCompletionRequest):
data = chat_completion_request.model_dump(exclude_none=True)
# add check otherwise will cause error: "Invalid value for 'parallel_tool_calls': 'parallel_tool_calls' is only allowed when 'tools' are specified."
if chat_completion_request.tools is not None:
data["parallel_tool_calls"] = False
printd("Request:\n", json.dumps(data, indent=2))
# If functions == None, strip from the payload
if "functions" in data and data["functions"] is None:
data.pop("functions")
@@ -540,14 +440,4 @@ def openai_chat_completions_request(
except ValueError as e:
warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
response_json = make_post_request(url, headers, data)
return ChatCompletionResponse(**response_json)
def openai_embeddings_request(url: str, api_key: str, data: dict) -> EmbeddingResponse:
"""https://platform.openai.com/docs/api-reference/embeddings/create"""
url = smart_urljoin(url, "embeddings")
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
response_json = make_post_request(url, headers, data)
return EmbeddingResponse(**response_json)
return data

View File

@@ -1018,8 +1018,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# new_message = {"function_return": msg, "status": "success"}
assert msg_obj.tool_call_id is not None
print(f"YYY printing the function call - {msg_obj.tool_call_id} == {self.prev_assistant_message_id} ???")
# Skip this is use_assistant_message is on
if self.use_assistant_message and msg_obj.tool_call_id == self.prev_assistant_message_id:
# Wipe the cache

View File

@@ -332,7 +332,7 @@ def agent_id(server, user_id, base_tools):
name="test_agent",
tool_ids=[t.id for t in base_tools],
memory_blocks=[],
model="openai/gpt-4",
model="openai/gpt-4o",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
@@ -353,7 +353,7 @@ def other_agent_id(server, user_id, base_tools):
name="test_agent_other",
tool_ids=[t.id for t in base_tools],
memory_blocks=[],
model="openai/gpt-4",
model="openai/gpt-4o",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
@@ -428,11 +428,11 @@ def test_save_archival_memory(server, user_id, agent_id):
@pytest.mark.order(4)
def test_user_message(server, user, agent_id):
# add data into recall memory
server.user_message(user_id=user.id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
response = server.user_message(user_id=user.id, agent_id=agent_id, message="What's up?")
assert response.step_count == 1
assert response.completion_tokens > 0
assert response.prompt_tokens > 0
assert response.total_tokens > 0
@pytest.mark.order(5)
@@ -552,7 +552,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
request=CreateAgent(
name="nonexistent_tools_agent",
memory_blocks=[],
model="openai/gpt-4",
model="openai/gpt-4o",
embedding="openai/text-embedding-ada-002",
),
actor=user,
@@ -920,7 +920,7 @@ def test_memory_rebuild_count(server, user, mock_e2b_api_key_none, base_tools, b
CreateBlock(label="human", value="The human's name is Bob."),
CreateBlock(label="persona", value="My name is Alice."),
],
model="openai/gpt-4",
model="openai/gpt-4o",
embedding="openai/text-embedding-ada-002",
),
actor=actor,
@@ -1108,7 +1108,7 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
CreateBlock(label="human", value="The human's name is Bob."),
CreateBlock(label="persona", value="My name is Alice."),
],
model="openai/gpt-4",
model="openai/gpt-4o",
embedding="openai/text-embedding-ada-002",
include_base_tools=False,
),