From 071642b74fe652f2856d84dea4d4a00ffd81c95c Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Tue, 1 Oct 2024 16:40:28 -0700 Subject: [PATCH] Finish adding groq and test --- letta/llm_api/groq.py | 439 --------------------------------- letta/llm_api/llm_api_tools.py | 15 +- letta/schemas/llm_config.py | 6 +- tests/test_endpoints.py | 20 +- 4 files changed, 35 insertions(+), 445 deletions(-) delete mode 100644 letta/llm_api/groq.py diff --git a/letta/llm_api/groq.py b/letta/llm_api/groq.py deleted file mode 100644 index 95dcddf2..00000000 --- a/letta/llm_api/groq.py +++ /dev/null @@ -1,439 +0,0 @@ -import json -from typing import Generator, Optional, Union - -import httpx -import requests -from httpx_sse import connect_sse -from httpx_sse._exceptions import SSEError - -from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING -from letta.errors import LLMError -from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages -from letta.schemas.message import Message as _Message -from letta.schemas.message import MessageRole as _MessageRole -from letta.schemas.openai.chat_completion_request import ChatCompletionRequest -from letta.schemas.openai.chat_completion_response import ( - ChatCompletionChunkResponse, - ChatCompletionResponse, - Choice, - FunctionCall, - Message, - ToolCall, - UsageStatistics, -) -from letta.streaming_interface import ( - AgentChunkStreamingInterface, - AgentRefreshStreamingInterface, -) -from letta.utils import smart_urljoin - -OPENAI_SSE_DONE = "[DONE]" - - -def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict: - """https://platform.openai.com/docs/api-reference/models/list""" - from letta.utils import printd - - # In some cases we may want to double-check the URL and do basic correction, eg: - # In Letta config the address for vLLM is w/o a /v1 suffix for simplicity - # However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model 
hit - if fix_url: - if not url.endswith("/v1"): - url = smart_urljoin(url, "v1") - - url = smart_urljoin(url, "models") - - headers = {"Content-Type": "application/json"} - if api_key is not None: - headers["Authorization"] = f"Bearer {api_key}" - - printd(f"Sending request to {url}") - try: - response = requests.get(url, headers=headers) - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - response = response.json() # convert to dict from string - printd(f"response = {response}") - return response - except requests.exceptions.HTTPError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - try: - response = response.json() - except: - pass - printd(f"Got HTTPError, exception={http_err}, response={response}") - raise http_err - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) - try: - response = response.json() - except: - pass - printd(f"Got RequestException, exception={req_err}, response={response}") - raise req_err - except Exception as e: - # Handle other potential errors - try: - response = response.json() - except: - pass - printd(f"Got unknown Exception, exception={e}, response={response}") - raise e - - -def openai_chat_completions_process_stream( - url: str, - api_key: str, - chat_completion_request: ChatCompletionRequest, - stream_inferface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None, - create_message_id: bool = True, - create_message_datetime: bool = True, -) -> ChatCompletionResponse: - """Process a streaming completion response, and return a ChatCompletionRequest at the end. - - To "stream" the response in Letta, we want to call a streaming-compatible interface function - on the chunks received from the OpenAI-compatible server POST SSE response. - """ - assert chat_completion_request.stream == True - assert stream_inferface is not None, "Required" - - # Count the prompt tokens - # TODO move to post-request? 
- chat_history = [m.model_dump(exclude_none=True) for m in chat_completion_request.messages] - # print(chat_history) - - prompt_tokens = num_tokens_from_messages( - messages=chat_history, - model=chat_completion_request.model, - ) - # We also need to add the cost of including the functions list to the input prompt - if chat_completion_request.tools is not None: - assert chat_completion_request.functions is None - prompt_tokens += num_tokens_from_functions( - functions=[t.function.model_dump() for t in chat_completion_request.tools], - model=chat_completion_request.model, - ) - elif chat_completion_request.functions is not None: - assert chat_completion_request.tools is None - prompt_tokens += num_tokens_from_functions( - functions=[f.model_dump() for f in chat_completion_request.functions], - model=chat_completion_request.model, - ) - - # Create a dummy Message object to get an ID and date - # TODO(sarah): add message ID generation function - dummy_message = _Message( - role=_MessageRole.assistant, - text="", - user_id="", - agent_id="", - model="", - name=None, - tool_calls=None, - tool_call_id=None, - ) - - TEMP_STREAM_RESPONSE_ID = "temp_id" - TEMP_STREAM_FINISH_REASON = "temp_null" - TEMP_STREAM_TOOL_CALL_ID = "temp_id" - chat_completion_response = ChatCompletionResponse( - id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID, - choices=[], - created=dummy_message.created_at, # NOTE: doesn't matter since both will do get_utc_time() - model=chat_completion_request.model, - usage=UsageStatistics( - completion_tokens=0, - prompt_tokens=prompt_tokens, - total_tokens=prompt_tokens, - ), - ) - - if stream_inferface: - stream_inferface.stream_start() - - n_chunks = 0 # approx == n_tokens - try: - for chunk_idx, chat_completion_chunk in enumerate( - openai_chat_completions_request_stream(url=url, api_key=api_key, chat_completion_request=chat_completion_request) - ): - assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), 
type(chat_completion_chunk) - - if stream_inferface: - if isinstance(stream_inferface, AgentChunkStreamingInterface): - stream_inferface.process_chunk( - chat_completion_chunk, - message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id, - message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created, - ) - elif isinstance(stream_inferface, AgentRefreshStreamingInterface): - stream_inferface.process_refresh(chat_completion_response) - else: - raise TypeError(stream_inferface) - - if chunk_idx == 0: - # initialize the choice objects which we will increment with the deltas - num_choices = len(chat_completion_chunk.choices) - assert num_choices > 0 - chat_completion_response.choices = [ - Choice( - finish_reason=TEMP_STREAM_FINISH_REASON, # NOTE: needs to be ovrerwritten - index=i, - message=Message( - role="assistant", - ), - ) - for i in range(len(chat_completion_chunk.choices)) - ] - - # add the choice delta - assert len(chat_completion_chunk.choices) == len(chat_completion_response.choices), chat_completion_chunk - for chunk_choice in chat_completion_chunk.choices: - if chunk_choice.finish_reason is not None: - chat_completion_response.choices[chunk_choice.index].finish_reason = chunk_choice.finish_reason - - if chunk_choice.logprobs is not None: - chat_completion_response.choices[chunk_choice.index].logprobs = chunk_choice.logprobs - - accum_message = chat_completion_response.choices[chunk_choice.index].message - message_delta = chunk_choice.delta - - if message_delta.content is not None: - content_delta = message_delta.content - if accum_message.content is None: - accum_message.content = content_delta - else: - accum_message.content += content_delta - - if message_delta.tool_calls is not None: - tool_calls_delta = message_delta.tool_calls - - # If this is the first tool call showing up in a chunk, initialize the list with it - if accum_message.tool_calls is None: - 
accum_message.tool_calls = [ - ToolCall(id=TEMP_STREAM_TOOL_CALL_ID, function=FunctionCall(name="", arguments="")) - for _ in range(len(tool_calls_delta)) - ] - - for tool_call_delta in tool_calls_delta: - if tool_call_delta.id is not None: - # TODO assert that we're not overwriting? - # TODO += instead of =? - accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id - if tool_call_delta.function is not None: - if tool_call_delta.function.name is not None: - # TODO assert that we're not overwriting? - # TODO += instead of =? - accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name - if tool_call_delta.function.arguments is not None: - accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments - - if message_delta.function_call is not None: - raise NotImplementedError(f"Old function_call style not support with stream=True") - - # overwrite response fields based on latest chunk - if not create_message_id: - chat_completion_response.id = chat_completion_chunk.id - if not create_message_datetime: - chat_completion_response.created = chat_completion_chunk.created - chat_completion_response.model = chat_completion_chunk.model - chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint - - # increment chunk counter - n_chunks += 1 - - except Exception as e: - if stream_inferface: - stream_inferface.stream_end() - print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}") - raise e - finally: - if stream_inferface: - stream_inferface.stream_end() - - # make sure we didn't leave temp stuff in - assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices]) - assert all( - [ - all([tc != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True - for c in chat_completion_response.choices - ] - ) - if not create_message_id: - assert chat_completion_response.id != 
dummy_message.id - - # compute token usage before returning - # TODO try actually computing the #tokens instead of assuming the chunks is the same - chat_completion_response.usage.completion_tokens = n_chunks - chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks - - # printd(chat_completion_response) - return chat_completion_response - - -def _sse_post(url: str, data: dict, headers: dict) -> Generator[ChatCompletionChunkResponse, None, None]: - - with httpx.Client() as client: - with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source: - - # Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12) - if not event_source.response.is_success: - # handle errors - from letta.utils import printd - - printd("Caught error before iterating SSE request:", vars(event_source.response)) - printd(event_source.response.read()) - - try: - response_bytes = event_source.response.read() - response_dict = json.loads(response_bytes.decode("utf-8")) - error_message = response_dict["error"]["message"] - # e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions. 
- if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message: - raise LLMError(error_message) - except LLMError: - raise - except: - print(f"Failed to parse SSE message, throwing SSE HTTP error up the stack") - event_source.response.raise_for_status() - - try: - for sse in event_source.iter_sse(): - # printd(sse.event, sse.data, sse.id, sse.retry) - if sse.data == OPENAI_SSE_DONE: - # print("finished") - break - else: - chunk_data = json.loads(sse.data) - # print("chunk_data::", chunk_data) - chunk_object = ChatCompletionChunkResponse(**chunk_data) - # print("chunk_object::", chunk_object) - # id=chunk_data["id"], - # choices=[ChunkChoice], - # model=chunk_data["model"], - # system_fingerprint=chunk_data["system_fingerprint"] - # ) - yield chunk_object - - except SSEError as e: - print("Caught an error while iterating the SSE stream:", str(e)) - if "application/json" in str(e): # Check if the error is because of JSON response - # TODO figure out a better way to catch the error other than re-trying with a POST - response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response - if response.headers["Content-Type"].startswith("application/json"): - error_details = response.json() # Parse the JSON to get the error message - print("Request:", vars(response.request)) - print("POST Error:", error_details) - print("Original SSE Error:", str(e)) - else: - print("Failed to retrieve JSON error message via retry.") - else: - print("SSEError not related to 'application/json' content type.") - - # Optionally re-raise the exception if you need to propagate it - raise e - - except Exception as e: - if event_source.response.request is not None: - print("HTTP Request:", vars(event_source.response.request)) - if event_source.response is not None: - print("HTTP Status:", event_source.response.status_code) - print("HTTP Headers:", event_source.response.headers) - # print("HTTP Body:", event_source.response.text) - print("Exception message:", 
str(e)) - raise e - - -def openai_chat_completions_request_stream( - url: str, - api_key: str, - chat_completion_request: ChatCompletionRequest, -) -> Generator[ChatCompletionChunkResponse, None, None]: - from letta.utils import printd - - url = smart_urljoin(url, "chat/completions") - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - data = chat_completion_request.model_dump(exclude_none=True) - - printd("Request:\n", json.dumps(data, indent=2)) - - # If functions == None, strip from the payload - if "functions" in data and data["functions"] is None: - data.pop("functions") - data.pop("function_call", None) # extra safe, should exist always (default="auto") - - if "tools" in data and data["tools"] is None: - data.pop("tools") - data.pop("tool_choice", None) # extra safe, should exist always (default="auto") - - printd(f"Sending request to {url}") - try: - return _sse_post(url=url, data=data, headers=headers) - except requests.exceptions.HTTPError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - printd(f"Got HTTPError, exception={http_err}, payload={data}") - raise http_err - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) - printd(f"Got RequestException, exception={req_err}") - raise req_err - except Exception as e: - # Handle other potential errors - printd(f"Got unknown Exception, exception={e}") - raise e - - -def openai_chat_completions_request( - url: str, - api_key: str, - chat_completion_request: ChatCompletionRequest, -) -> ChatCompletionResponse: - """Send a ChatCompletion request to an OpenAI-compatible server - - If request.stream == True, will yield ChatCompletionChunkResponses - If request.stream == False, will return a ChatCompletionResponse - - https://platform.openai.com/docs/guides/text-generation?lang=curl - """ - from letta.utils import printd - - url = smart_urljoin(url, "chat/completions") - headers = {"Content-Type": 
"application/json", "Authorization": f"Bearer {api_key}"} - data = chat_completion_request.model_dump(exclude_none=True) - - # add check otherwise will cause error: "Invalid value for 'parallel_tool_calls': 'parallel_tool_calls' is only allowed when 'tools' are specified." - if chat_completion_request.tools is not None: - data["parallel_tool_calls"] = False - - printd("Request:\n", json.dumps(data, indent=2)) - - # If functions == None, strip from the payload - if "functions" in data and data["functions"] is None: - data.pop("functions") - data.pop("function_call", None) # extra safe, should exist always (default="auto") - - if "tools" in data and data["tools"] is None: - data.pop("tools") - data.pop("tool_choice", None) # extra safe, should exist always (default="auto") - - printd(f"Sending request to {url}") - try: - response = requests.post(url, headers=headers, json=data) - printd(f"response = {response}, response.text = {response.text}") - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - - response = response.json() # convert to dict from string - printd(f"response.json = {response}") - - response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default - return response - except requests.exceptions.HTTPError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - printd(f"Got HTTPError, exception={http_err}, payload={data}") - raise http_err - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) - printd(f"Got RequestException, exception={req_err}") - raise req_err - except Exception as e: - # Handle other potential errors - printd(f"Got unknown Exception, exception={e}") - raise e diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 00e5f28c..295f05e6 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -464,7 +464,11 @@ def create( elif 
llm_config.model_endpoint_type == "groq": if stream: - raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") + raise NotImplementedError(f"Streaming not yet implemented for Groq.") + + if credentials.groq_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions": + # only is a problem if we are *not* using an openai proxy + raise ValueError(f"Groq key is missing from letta config file") tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None data = ChatCompletionRequest( @@ -475,10 +479,19 @@ def create( user=str(user_id), ) + # https://console.groq.com/docs/openai + # "The following fields are currently not supported and will result in a 400 error (yikes) if they are supplied:" + assert data.top_logprobs is None + assert data.logit_bias is None + assert data.logprobs == False + assert data.n == 1 + # They mention that none of the messages can have names, but it seems to not error out (for now) + data.stream = False if isinstance(stream_inferface, AgentChunkStreamingInterface): stream_inferface.stream_start() try: + # groq uses the openai chat completions API, so this component should be reusable response = openai_chat_completions_request( url=llm_config.model_endpoint, api_key=credentials.groq_key, diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index dfc68882..b5ec436f 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -16,8 +16,10 @@ class LLMConfig(BaseModel): """ # TODO: 🤮 don't default to a vendor! bug city! - model: Literal["openai", "anthropic", "cohere", "google_ai", "azure", "groq"] = Field(..., description="LLM model name. ") - model_endpoint_type: str = Field(..., description="The endpoint type for the model.") + model: str = Field(..., description="LLM model name. 
") + model_endpoint_type: Literal["openai", "anthropic", "cohere", "google_ai", "azure", "groq"] = Field( + ..., description="The endpoint type for the model." + ) model_endpoint: str = Field(..., description="The endpoint for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") context_window: int = Field(..., description="The context window size for the model.") diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index f9985fc2..5def1875 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -2,7 +2,7 @@ import json import os import uuid -from letta import create_client +from letta import LocalClient, RESTClient, create_client from letta.agent import Agent from letta.config import LettaConfig from letta.embeddings import embedding_model @@ -22,6 +22,18 @@ llm_config_path = "configs/llm_model_configs/letta-hosted.json" embedding_config_dir = "configs/embedding_model_configs" llm_config_dir = "configs/llm_model_configs" +# Generate uuid for agent name for this example +namespace = uuid.NAMESPACE_DNS +agent_uuid = str(uuid.uuid5(namespace, "letta-endpoint-tests")) + + +def clean_up_agent(client: LocalClient | RESTClient): + # Clear all agents + for agent_state in client.list_agents(): + if agent_state.name == agent_uuid: + client.delete_agent(agent_id=agent_state.id) + print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") + def run_llm_endpoint(filename): config_data = json.load(open(filename, "r")) @@ -36,13 +48,14 @@ def run_llm_endpoint(filename): config.save() client = create_client() - agent_state = client.create_agent(name="test_agent", llm_config=llm_config, embedding_config=embedding_config) + clean_up_agent(client) + agent_state = client.create_agent(name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config) tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools] agent = Agent( interface=None, tools=tools, 
agent_state=agent_state, - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now + # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now first_message_verify_mono=True, ) @@ -56,6 +69,7 @@ def run_llm_endpoint(filename): ) client.delete_agent(agent_state.id) assert response is not None + print(response) def run_embedding_endpoint(filename):