From 4035a211fb47b4a2ef914f07d6640e6f508df005 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Tue, 1 Oct 2024 15:42:59 -0700 Subject: [PATCH 01/16] wip --- letta/llm_api/llm_api_tools.py | 6 +++++- letta/schemas/llm_config.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 93753a55..35f7c8e2 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -455,7 +455,7 @@ def create( chat_completion_request=ChatCompletionRequest( model="command-r-plus", # TODO messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], - tools=[{"type": "function", "function": f} for f in functions] if functions else None, + tools=tools, tool_choice=function_call, # user=str(user_id), # NOTE: max_tokens is required for Anthropic API @@ -463,6 +463,10 @@ def create( ), ) + elif llm_config.model_endpoint_type == "groq": + if stream: + raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") + # local model else: if stream: diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 134dff02..dfc68882 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional from pydantic import BaseModel, ConfigDict, Field @@ -16,7 +16,7 @@ class LLMConfig(BaseModel): """ # TODO: 🤮 don't default to a vendor! bug city! - model: str = Field(..., description="LLM model name. ") + model: Literal["openai", "anthropic", "cohere", "google_ai", "azure", "groq"] = Field(..., description="LLM model name. ") model_endpoint_type: str = Field(..., description="The endpoint type for the model.") model_endpoint: str = Field(..., description="The endpoint for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") From d61b806fd5cb01aaf50b826f3959bdc1f785588d Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Tue, 1 Oct 2024 16:22:34 -0700 Subject: [PATCH 02/16] wip --- letta/credentials.py | 4 + letta/llm_api/groq.py | 439 +++++++++++++++++++++++++++++++++ letta/llm_api/llm_api_tools.py | 25 +- tests/test_endpoints.py | 5 + 4 files changed, 472 insertions(+), 1 deletion(-) create mode 100644 letta/llm_api/groq.py diff --git a/letta/credentials.py b/letta/credentials.py index b5b58f94..ea92cc29 100644 --- a/letta/credentials.py +++ b/letta/credentials.py @@ -31,6 +31,10 @@ class LettaCredentials: # azure config azure_auth_type: str = "api_key" azure_key: Optional[str] = None + + # groq config + groq_key: Optional[str] = os.getenv("GROQ_API_KEY") + # base llm / model azure_version: Optional[str] = None azure_endpoint: Optional[str] = None diff --git a/letta/llm_api/groq.py b/letta/llm_api/groq.py new file mode 100644 index 00000000..95dcddf2 --- /dev/null +++ b/letta/llm_api/groq.py @@ -0,0 +1,439 @@ +import json +from typing import Generator, Optional, Union + +import httpx +import requests +from httpx_sse import connect_sse +from httpx_sse._exceptions import SSEError + +from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING +from letta.errors import LLMError +from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages +from letta.schemas.message import Message as _Message +from letta.schemas.message import MessageRole as _MessageRole +from letta.schemas.openai.chat_completion_request import ChatCompletionRequest +from letta.schemas.openai.chat_completion_response import ( + ChatCompletionChunkResponse, + ChatCompletionResponse, + Choice, + FunctionCall, + Message, + ToolCall, + UsageStatistics, +) +from letta.streaming_interface import ( + AgentChunkStreamingInterface, + AgentRefreshStreamingInterface, +) +from letta.utils import smart_urljoin + +OPENAI_SSE_DONE = "[DONE]" + + +def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict: + """https://platform.openai.com/docs/api-reference/models/list""" + from letta.utils import printd + + # In some cases we may want to double-check the URL and do basic correction, eg: + # In Letta config the address for vLLM is w/o a /v1 suffix for simplicity + # However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model hit + if fix_url: + if not url.endswith("/v1"): + url = smart_urljoin(url, "v1") + + url = smart_urljoin(url, "models") + + headers = {"Content-Type": "application/json"} + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" + + printd(f"Sending request to {url}") + try: + response = requests.get(url, headers=headers) + response.raise_for_status() # Raises HTTPError for 4XX/5XX status + response = response.json() # convert to dict from string + printd(f"response = {response}") + return response + except requests.exceptions.HTTPError as http_err: + # Handle HTTP errors (e.g., response 4XX, 5XX) + try: + response = response.json() + except: + pass + printd(f"Got HTTPError, exception={http_err}, response={response}") + raise http_err + except requests.exceptions.RequestException as req_err: + # Handle other requests-related errors (e.g., connection error) + try: + response = response.json() + except: + pass + printd(f"Got RequestException, exception={req_err}, response={response}") + raise req_err + except Exception as e: + # Handle other potential errors + try: + response = response.json() + except: + pass + printd(f"Got unknown Exception, exception={e}, response={response}") + raise e + + +def openai_chat_completions_process_stream( + url: str, + api_key: str, + chat_completion_request: ChatCompletionRequest, + stream_inferface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None, + create_message_id: bool = True, + create_message_datetime: bool = True, +) -> ChatCompletionResponse: + """Process a streaming completion response, and return a ChatCompletionRequest at the end. + + To "stream" the response in Letta, we want to call a streaming-compatible interface function + on the chunks received from the OpenAI-compatible server POST SSE response. + """ + assert chat_completion_request.stream == True + assert stream_inferface is not None, "Required" + + # Count the prompt tokens + # TODO move to post-request? + chat_history = [m.model_dump(exclude_none=True) for m in chat_completion_request.messages] + # print(chat_history) + + prompt_tokens = num_tokens_from_messages( + messages=chat_history, + model=chat_completion_request.model, + ) + # We also need to add the cost of including the functions list to the input prompt + if chat_completion_request.tools is not None: + assert chat_completion_request.functions is None + prompt_tokens += num_tokens_from_functions( + functions=[t.function.model_dump() for t in chat_completion_request.tools], + model=chat_completion_request.model, + ) + elif chat_completion_request.functions is not None: + assert chat_completion_request.tools is None + prompt_tokens += num_tokens_from_functions( + functions=[f.model_dump() for f in chat_completion_request.functions], + model=chat_completion_request.model, + ) + + # Create a dummy Message object to get an ID and date + # TODO(sarah): add message ID generation function + dummy_message = _Message( + role=_MessageRole.assistant, + text="", + user_id="", + agent_id="", + model="", + name=None, + tool_calls=None, + tool_call_id=None, + ) + + TEMP_STREAM_RESPONSE_ID = "temp_id" + TEMP_STREAM_FINISH_REASON = "temp_null" + TEMP_STREAM_TOOL_CALL_ID = "temp_id" + chat_completion_response = ChatCompletionResponse( + id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID, + choices=[], + created=dummy_message.created_at, # NOTE: doesn't matter since both will do get_utc_time() + model=chat_completion_request.model, + usage=UsageStatistics( + completion_tokens=0, + prompt_tokens=prompt_tokens, + total_tokens=prompt_tokens, + ), + ) + + if stream_inferface: + stream_inferface.stream_start() + + n_chunks = 0 # approx == n_tokens + try: + for chunk_idx, chat_completion_chunk in enumerate( + openai_chat_completions_request_stream(url=url, api_key=api_key, chat_completion_request=chat_completion_request) + ): + assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk) + + if stream_inferface: + if isinstance(stream_inferface, AgentChunkStreamingInterface): + stream_inferface.process_chunk( + chat_completion_chunk, + message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id, + message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created, + ) + elif isinstance(stream_inferface, AgentRefreshStreamingInterface): + stream_inferface.process_refresh(chat_completion_response) + else: + raise TypeError(stream_inferface) + + if chunk_idx == 0: + # initialize the choice objects which we will increment with the deltas + num_choices = len(chat_completion_chunk.choices) + assert num_choices > 0 + chat_completion_response.choices = [ + Choice( + finish_reason=TEMP_STREAM_FINISH_REASON, # NOTE: needs to be ovrerwritten + index=i, + message=Message( + role="assistant", + ), + ) + for i in range(len(chat_completion_chunk.choices)) + ] + + # add the choice delta + assert len(chat_completion_chunk.choices) == len(chat_completion_response.choices), chat_completion_chunk + for chunk_choice in chat_completion_chunk.choices: + if chunk_choice.finish_reason is not None: + chat_completion_response.choices[chunk_choice.index].finish_reason = chunk_choice.finish_reason + + if chunk_choice.logprobs is not None: + chat_completion_response.choices[chunk_choice.index].logprobs = chunk_choice.logprobs + + accum_message = chat_completion_response.choices[chunk_choice.index].message + message_delta = chunk_choice.delta + + if message_delta.content is not None: + content_delta = message_delta.content + if accum_message.content is None: + accum_message.content = content_delta + else: + accum_message.content += content_delta + + if message_delta.tool_calls is not None: + tool_calls_delta = message_delta.tool_calls + + # If this is the first tool call showing up in a chunk, initialize the list with it + if accum_message.tool_calls is None: + accum_message.tool_calls = [ + ToolCall(id=TEMP_STREAM_TOOL_CALL_ID, function=FunctionCall(name="", arguments="")) + for _ in range(len(tool_calls_delta)) + ] + + for tool_call_delta in tool_calls_delta: + if tool_call_delta.id is not None: + # TODO assert that we're not overwriting? + # TODO += instead of =? + accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id + if tool_call_delta.function is not None: + if tool_call_delta.function.name is not None: + # TODO assert that we're not overwriting? + # TODO += instead of =? + accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name + if tool_call_delta.function.arguments is not None: + accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments + + if message_delta.function_call is not None: + raise NotImplementedError(f"Old function_call style not support with stream=True") + + # overwrite response fields based on latest chunk + if not create_message_id: + chat_completion_response.id = chat_completion_chunk.id + if not create_message_datetime: + chat_completion_response.created = chat_completion_chunk.created + chat_completion_response.model = chat_completion_chunk.model + chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint + + # increment chunk counter + n_chunks += 1 + + except Exception as e: + if stream_inferface: + stream_inferface.stream_end() + print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}") + raise e + finally: + if stream_inferface: + stream_inferface.stream_end() + + # make sure we didn't leave temp stuff in + assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices]) + assert all( + [ + all([tc != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True + for c in chat_completion_response.choices + ] + ) + if not create_message_id: + assert chat_completion_response.id != dummy_message.id + + # compute token usage before returning + # TODO try actually computing the #tokens instead of assuming the chunks is the same + chat_completion_response.usage.completion_tokens = n_chunks + chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks + + # printd(chat_completion_response) + return chat_completion_response + + +def _sse_post(url: str, data: dict, headers: dict) -> Generator[ChatCompletionChunkResponse, None, None]: + + with httpx.Client() as client: + with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source: + + # Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12) + if not event_source.response.is_success: + # handle errors + from letta.utils import printd + + printd("Caught error before iterating SSE request:", vars(event_source.response)) + printd(event_source.response.read()) + + try: + response_bytes = event_source.response.read() + response_dict = json.loads(response_bytes.decode("utf-8")) + error_message = response_dict["error"]["message"] + # e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions. + if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message: + raise LLMError(error_message) + except LLMError: + raise + except: + print(f"Failed to parse SSE message, throwing SSE HTTP error up the stack") + event_source.response.raise_for_status() + + try: + for sse in event_source.iter_sse(): + # printd(sse.event, sse.data, sse.id, sse.retry) + if sse.data == OPENAI_SSE_DONE: + # print("finished") + break + else: + chunk_data = json.loads(sse.data) + # print("chunk_data::", chunk_data) + chunk_object = ChatCompletionChunkResponse(**chunk_data) + # print("chunk_object::", chunk_object) + # id=chunk_data["id"], + # choices=[ChunkChoice], + # model=chunk_data["model"], + # system_fingerprint=chunk_data["system_fingerprint"] + # ) + yield chunk_object + + except SSEError as e: + print("Caught an error while iterating the SSE stream:", str(e)) + if "application/json" in str(e): # Check if the error is because of JSON response + # TODO figure out a better way to catch the error other than re-trying with a POST + response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response + if response.headers["Content-Type"].startswith("application/json"): + error_details = response.json() # Parse the JSON to get the error message + print("Request:", vars(response.request)) + print("POST Error:", error_details) + print("Original SSE Error:", str(e)) + else: + print("Failed to retrieve JSON error message via retry.") + else: + print("SSEError not related to 'application/json' content type.") + + # Optionally re-raise the exception if you need to propagate it + raise e + + except Exception as e: + if event_source.response.request is not None: + print("HTTP Request:", vars(event_source.response.request)) + if event_source.response is not None: + print("HTTP Status:", event_source.response.status_code) + print("HTTP Headers:", event_source.response.headers) + # print("HTTP Body:", event_source.response.text) + print("Exception message:", str(e)) + raise e + + +def openai_chat_completions_request_stream( + url: str, + api_key: str, + chat_completion_request: ChatCompletionRequest, +) -> Generator[ChatCompletionChunkResponse, None, None]: + from letta.utils import printd + + url = smart_urljoin(url, "chat/completions") + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + data = chat_completion_request.model_dump(exclude_none=True) + + printd("Request:\n", json.dumps(data, indent=2)) + + # If functions == None, strip from the payload + if "functions" in data and data["functions"] is None: + data.pop("functions") + data.pop("function_call", None) # extra safe, should exist always (default="auto") + + if "tools" in data and data["tools"] is None: + data.pop("tools") + data.pop("tool_choice", None) # extra safe, should exist always (default="auto") + + printd(f"Sending request to {url}") + try: + return _sse_post(url=url, data=data, headers=headers) + except requests.exceptions.HTTPError as http_err: + # Handle HTTP errors (e.g., response 4XX, 5XX) + printd(f"Got HTTPError, exception={http_err}, payload={data}") + raise http_err + except requests.exceptions.RequestException as req_err: + # Handle other requests-related errors (e.g., connection error) + printd(f"Got RequestException, exception={req_err}") + raise req_err + except Exception as e: + # Handle other potential errors + printd(f"Got unknown Exception, exception={e}") + raise e + + +def openai_chat_completions_request( + url: str, + api_key: str, + chat_completion_request: ChatCompletionRequest, +) -> ChatCompletionResponse: + """Send a ChatCompletion request to an OpenAI-compatible server + + If request.stream == True, will yield ChatCompletionChunkResponses + If request.stream == False, will return a ChatCompletionResponse + + https://platform.openai.com/docs/guides/text-generation?lang=curl + """ + from letta.utils import printd + + url = smart_urljoin(url, "chat/completions") + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + data = chat_completion_request.model_dump(exclude_none=True) + + # add check otherwise will cause error: "Invalid value for 'parallel_tool_calls': 'parallel_tool_calls' is only allowed when 'tools' are specified." + if chat_completion_request.tools is not None: + data["parallel_tool_calls"] = False + + printd("Request:\n", json.dumps(data, indent=2)) + + # If functions == None, strip from the payload + if "functions" in data and data["functions"] is None: + data.pop("functions") + data.pop("function_call", None) # extra safe, should exist always (default="auto") + + if "tools" in data and data["tools"] is None: + data.pop("tools") + data.pop("tool_choice", None) # extra safe, should exist always (default="auto") + + printd(f"Sending request to {url}") + try: + response = requests.post(url, headers=headers, json=data) + printd(f"response = {response}, response.text = {response.text}") + response.raise_for_status() # Raises HTTPError for 4XX/5XX status + + response = response.json() # convert to dict from string + printd(f"response.json = {response}") + + response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default + return response + except requests.exceptions.HTTPError as http_err: + # Handle HTTP errors (e.g., response 4XX, 5XX) + printd(f"Got HTTPError, exception={http_err}, payload={data}") + raise http_err + except requests.exceptions.RequestException as req_err: + # Handle other requests-related errors (e.g., connection error) + printd(f"Got RequestException, exception={req_err}") + raise req_err + except Exception as e: + # Handle other potential errors + printd(f"Got unknown Exception, exception={e}") + raise e diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 35f7c8e2..00e5f28c 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -332,7 +332,6 @@ def create( if isinstance(stream_inferface, AgentChunkStreamingInterface): stream_inferface.stream_start() try: - response = openai_chat_completions_request( url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions api_key=credentials.openai_key, @@ -467,6 +466,30 @@ def create( if stream: raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") + tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None + data = ChatCompletionRequest( + model=llm_config.model, + messages=[m.to_openai_dict() for m in messages], + tools=tools, + tool_choice=function_call, + user=str(user_id), + ) + + data.stream = False + if isinstance(stream_inferface, AgentChunkStreamingInterface): + stream_inferface.stream_start() + try: + response = openai_chat_completions_request( + url=llm_config.model_endpoint, + api_key=credentials.groq_key, + chat_completion_request=data, + ) + finally: + if isinstance(stream_inferface, AgentChunkStreamingInterface): + stream_inferface.stream_end() + + return response + # local model else: if stream: diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index ada66b71..f9985fc2 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -108,3 +108,8 @@ def test_embedding_endpoint_ollama(): def test_llm_endpoint_anthropic(): filename = os.path.join(llm_config_dir, "anthropic.json") run_llm_endpoint(filename) + + +def test_llm_endpoint_groq(): + filename = os.path.join(llm_config_dir, "groq.json") + run_llm_endpoint(filename) From 071642b74fe652f2856d84dea4d4a00ffd81c95c Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Tue, 1 Oct 2024 16:40:28 -0700 Subject: [PATCH 03/16] Finish adding groq and test --- letta/llm_api/groq.py | 439 --------------------------------- letta/llm_api/llm_api_tools.py | 15 +- letta/schemas/llm_config.py | 6 +- tests/test_endpoints.py | 20 +- 4 files changed, 35 insertions(+), 445 deletions(-) delete mode 100644 letta/llm_api/groq.py diff --git a/letta/llm_api/groq.py b/letta/llm_api/groq.py deleted file mode 100644 index 95dcddf2..00000000 --- a/letta/llm_api/groq.py +++ /dev/null @@ -1,439 +0,0 @@ -import json -from typing import Generator, Optional, Union - -import httpx -import requests -from httpx_sse import connect_sse -from httpx_sse._exceptions import SSEError - -from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING -from letta.errors import LLMError -from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages -from letta.schemas.message import Message as _Message -from letta.schemas.message import MessageRole as _MessageRole -from letta.schemas.openai.chat_completion_request import ChatCompletionRequest -from letta.schemas.openai.chat_completion_response import ( - ChatCompletionChunkResponse, - ChatCompletionResponse, - Choice, - FunctionCall, - Message, - ToolCall, - UsageStatistics, -) -from letta.streaming_interface import ( - AgentChunkStreamingInterface, - AgentRefreshStreamingInterface, -) -from letta.utils import smart_urljoin - -OPENAI_SSE_DONE = "[DONE]" - - -def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict: - """https://platform.openai.com/docs/api-reference/models/list""" - from letta.utils import printd - - # In some cases we may want to double-check the URL and do basic correction, eg: - # In Letta config the address for vLLM is w/o a /v1 suffix for simplicity - # However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model hit - if fix_url: - if not url.endswith("/v1"): - url = smart_urljoin(url, "v1") - - url = smart_urljoin(url, "models") - - headers = {"Content-Type": "application/json"} - if api_key is not None: - headers["Authorization"] = f"Bearer {api_key}" - - printd(f"Sending request to {url}") - try: - response = requests.get(url, headers=headers) - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - response = response.json() # convert to dict from string - printd(f"response = {response}") - return response - except requests.exceptions.HTTPError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - try: - response = response.json() - except: - pass - printd(f"Got HTTPError, exception={http_err}, response={response}") - raise http_err - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) - try: - response = response.json() - except: - pass - printd(f"Got RequestException, exception={req_err}, response={response}") - raise req_err - except Exception as e: - # Handle other potential errors - try: - response = response.json() - except: - pass - printd(f"Got unknown Exception, exception={e}, response={response}") - raise e - - -def openai_chat_completions_process_stream( - url: str, - api_key: str, - chat_completion_request: ChatCompletionRequest, - stream_inferface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None, - create_message_id: bool = True, - create_message_datetime: bool = True, -) -> ChatCompletionResponse: - """Process a streaming completion response, and return a ChatCompletionRequest at the end. - - To "stream" the response in Letta, we want to call a streaming-compatible interface function - on the chunks received from the OpenAI-compatible server POST SSE response. - """ - assert chat_completion_request.stream == True - assert stream_inferface is not None, "Required" - - # Count the prompt tokens - # TODO move to post-request? - chat_history = [m.model_dump(exclude_none=True) for m in chat_completion_request.messages] - # print(chat_history) - - prompt_tokens = num_tokens_from_messages( - messages=chat_history, - model=chat_completion_request.model, - ) - # We also need to add the cost of including the functions list to the input prompt - if chat_completion_request.tools is not None: - assert chat_completion_request.functions is None - prompt_tokens += num_tokens_from_functions( - functions=[t.function.model_dump() for t in chat_completion_request.tools], - model=chat_completion_request.model, - ) - elif chat_completion_request.functions is not None: - assert chat_completion_request.tools is None - prompt_tokens += num_tokens_from_functions( - functions=[f.model_dump() for f in chat_completion_request.functions], - model=chat_completion_request.model, - ) - - # Create a dummy Message object to get an ID and date - # TODO(sarah): add message ID generation function - dummy_message = _Message( - role=_MessageRole.assistant, - text="", - user_id="", - agent_id="", - model="", - name=None, - tool_calls=None, - tool_call_id=None, - ) - - TEMP_STREAM_RESPONSE_ID = "temp_id" - TEMP_STREAM_FINISH_REASON = "temp_null" - TEMP_STREAM_TOOL_CALL_ID = "temp_id" - chat_completion_response = ChatCompletionResponse( - id=dummy_message.id if create_message_id else TEMP_STREAM_RESPONSE_ID, - choices=[], - created=dummy_message.created_at, # NOTE: doesn't matter since both will do get_utc_time() - model=chat_completion_request.model, - usage=UsageStatistics( - completion_tokens=0, - prompt_tokens=prompt_tokens, - total_tokens=prompt_tokens, - ), - ) - - if stream_inferface: - stream_inferface.stream_start() - - n_chunks = 0 # approx == n_tokens - try: - for chunk_idx, chat_completion_chunk in enumerate( - openai_chat_completions_request_stream(url=url, api_key=api_key, chat_completion_request=chat_completion_request) - ): - assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk) - - if stream_inferface: - if isinstance(stream_inferface, AgentChunkStreamingInterface): - stream_inferface.process_chunk( - chat_completion_chunk, - message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id, - message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created, - ) - elif isinstance(stream_inferface, AgentRefreshStreamingInterface): - stream_inferface.process_refresh(chat_completion_response) - else: - raise TypeError(stream_inferface) - - if chunk_idx == 0: - # initialize the choice objects which we will increment with the deltas - num_choices = len(chat_completion_chunk.choices) - assert num_choices > 0 - chat_completion_response.choices = [ - Choice( - finish_reason=TEMP_STREAM_FINISH_REASON, # NOTE: needs to be ovrerwritten - index=i, - message=Message( - role="assistant", - ), - ) - for i in range(len(chat_completion_chunk.choices)) - ] - - # add the choice delta - assert len(chat_completion_chunk.choices) == len(chat_completion_response.choices), chat_completion_chunk - for chunk_choice in chat_completion_chunk.choices: - if chunk_choice.finish_reason is not None: - chat_completion_response.choices[chunk_choice.index].finish_reason = chunk_choice.finish_reason - - if chunk_choice.logprobs is not None: - chat_completion_response.choices[chunk_choice.index].logprobs = chunk_choice.logprobs - - accum_message = chat_completion_response.choices[chunk_choice.index].message - message_delta = chunk_choice.delta - - if message_delta.content is not None: - content_delta = message_delta.content - if accum_message.content is None: - accum_message.content = content_delta - else: - accum_message.content += content_delta - - if message_delta.tool_calls is not None: - tool_calls_delta = message_delta.tool_calls - - # If this is the first tool call showing up in a chunk, initialize the list with it - if accum_message.tool_calls is None: - accum_message.tool_calls = [ - ToolCall(id=TEMP_STREAM_TOOL_CALL_ID, function=FunctionCall(name="", arguments="")) - for _ in range(len(tool_calls_delta)) - ] - - for tool_call_delta in tool_calls_delta: - if tool_call_delta.id is not None: - # TODO assert that we're not overwriting? - # TODO += instead of =? - accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id - if tool_call_delta.function is not None: - if tool_call_delta.function.name is not None: - # TODO assert that we're not overwriting? - # TODO += instead of =? - accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name - if tool_call_delta.function.arguments is not None: - accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments - - if message_delta.function_call is not None: - raise NotImplementedError(f"Old function_call style not support with stream=True") - - # overwrite response fields based on latest chunk - if not create_message_id: - chat_completion_response.id = chat_completion_chunk.id - if not create_message_datetime: - chat_completion_response.created = chat_completion_chunk.created - chat_completion_response.model = chat_completion_chunk.model - chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint - - # increment chunk counter - n_chunks += 1 - - except Exception as e: - if stream_inferface: - stream_inferface.stream_end() - print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}") - raise e - finally: - if stream_inferface: - stream_inferface.stream_end() - - # make sure we didn't leave temp stuff in - assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices]) - assert all( - [ - all([tc != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True - for c in chat_completion_response.choices - ] - ) - if not create_message_id: - assert chat_completion_response.id != dummy_message.id - - # compute token usage before returning - # TODO try actually computing the #tokens instead of assuming the chunks is the same - chat_completion_response.usage.completion_tokens = n_chunks - chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks - - # printd(chat_completion_response) - return chat_completion_response - - -def _sse_post(url: str, data: dict, headers: dict) -> Generator[ChatCompletionChunkResponse, None, None]: - - with httpx.Client() as client: - with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source: - - # Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12) - if not event_source.response.is_success: - # handle errors - from letta.utils import printd - - printd("Caught error before iterating SSE request:", vars(event_source.response)) - printd(event_source.response.read()) - - try: - response_bytes = event_source.response.read() - response_dict = json.loads(response_bytes.decode("utf-8")) - error_message = response_dict["error"]["message"] - # e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions. - if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message: - raise LLMError(error_message) - except LLMError: - raise - except: - print(f"Failed to parse SSE message, throwing SSE HTTP error up the stack") - event_source.response.raise_for_status() - - try: - for sse in event_source.iter_sse(): - # printd(sse.event, sse.data, sse.id, sse.retry) - if sse.data == OPENAI_SSE_DONE: - # print("finished") - break - else: - chunk_data = json.loads(sse.data) - # print("chunk_data::", chunk_data) - chunk_object = ChatCompletionChunkResponse(**chunk_data) - # print("chunk_object::", chunk_object) - # id=chunk_data["id"], - # choices=[ChunkChoice], - # model=chunk_data["model"], - # system_fingerprint=chunk_data["system_fingerprint"] - # ) - yield chunk_object - - except SSEError as e: - print("Caught an error while iterating the SSE stream:", str(e)) - if "application/json" in str(e): # Check if the error is because of JSON response - # TODO figure out a better way to catch the error other than re-trying with a POST - response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response - if response.headers["Content-Type"].startswith("application/json"): - error_details = response.json() # Parse the JSON to get the error message - print("Request:", vars(response.request)) - print("POST Error:", error_details) - print("Original SSE Error:", str(e)) - else: - print("Failed to retrieve JSON error message via retry.") - else: - print("SSEError not related to 'application/json' content type.") - - # Optionally re-raise the exception if you need to propagate it - raise e - - except Exception as e: - if event_source.response.request is not None: - print("HTTP Request:", vars(event_source.response.request)) - if event_source.response is not None: - print("HTTP Status:", event_source.response.status_code) - print("HTTP Headers:", event_source.response.headers) - # print("HTTP Body:", event_source.response.text) - print("Exception message:", str(e)) - raise e - - -def openai_chat_completions_request_stream( - url: str, - api_key: str, - chat_completion_request: ChatCompletionRequest, -) -> Generator[ChatCompletionChunkResponse, None, None]: - from letta.utils import printd - - url = smart_urljoin(url, "chat/completions") - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - data = chat_completion_request.model_dump(exclude_none=True) - - printd("Request:\n", json.dumps(data, indent=2)) - - # If functions == None, strip from the payload - if "functions" in data and data["functions"] is None: - data.pop("functions") - data.pop("function_call", None) # extra safe, should exist always (default="auto") - - if "tools" in data and data["tools"] is None: - data.pop("tools") - data.pop("tool_choice", None) # extra safe, should exist always (default="auto") - - printd(f"Sending request to {url}") - try: - return _sse_post(url=url, data=data, headers=headers) - except requests.exceptions.HTTPError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - printd(f"Got HTTPError, exception={http_err}, payload={data}") - raise http_err - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) - printd(f"Got RequestException, exception={req_err}") - raise req_err - except Exception as e: - # Handle other potential errors - printd(f"Got unknown Exception, exception={e}") - raise e - - -def openai_chat_completions_request( - url: str, - api_key: str, - chat_completion_request: ChatCompletionRequest, -) -> ChatCompletionResponse: - """Send a ChatCompletion request to an OpenAI-compatible server - - If request.stream == True, will yield ChatCompletionChunkResponses - If request.stream == False, will return a ChatCompletionResponse - - https://platform.openai.com/docs/guides/text-generation?lang=curl - """ - from letta.utils import printd - - url = smart_urljoin(url, "chat/completions") - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - data = chat_completion_request.model_dump(exclude_none=True) - - # add check otherwise will cause error: "Invalid value for 'parallel_tool_calls': 'parallel_tool_calls' is only allowed when 'tools' are specified." - if chat_completion_request.tools is not None: - data["parallel_tool_calls"] = False - - printd("Request:\n", json.dumps(data, indent=2)) - - # If functions == None, strip from the payload - if "functions" in data and data["functions"] is None: - data.pop("functions") - data.pop("function_call", None) # extra safe, should exist always (default="auto") - - if "tools" in data and data["tools"] is None: - data.pop("tools") - data.pop("tool_choice", None) # extra safe, should exist always (default="auto") - - printd(f"Sending request to {url}") - try: - response = requests.post(url, headers=headers, json=data) - printd(f"response = {response}, response.text = {response.text}") - response.raise_for_status() # Raises HTTPError for 4XX/5XX status - - response = response.json() # convert to dict from string - printd(f"response.json = {response}") - - response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default - return response - except requests.exceptions.HTTPError as http_err: - # Handle HTTP errors (e.g., response 4XX, 5XX) - printd(f"Got HTTPError, exception={http_err}, payload={data}") - raise http_err - except requests.exceptions.RequestException as req_err: - # Handle other requests-related errors (e.g., connection error) - printd(f"Got RequestException, exception={req_err}") - raise req_err - except Exception as e: - # Handle other potential errors - printd(f"Got unknown Exception, exception={e}") - raise e diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 00e5f28c..295f05e6 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -464,7 +464,11 @@ def create( elif llm_config.model_endpoint_type == "groq": if stream: - raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") + raise NotImplementedError(f"Streaming not yet implemented for Groq.") + + if credentials.groq_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions": + # only is a problem if we are *not* using an openai proxy + raise ValueError(f"Groq key is missing from letta config file") tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None data = ChatCompletionRequest( @@ -475,10 +479,19 @@ def create( user=str(user_id), ) + # https://console.groq.com/docs/openai + # "The following fields are currently not supported and will result in a 400 error (yikes) if they are supplied:" + assert data.top_logprobs is None + assert data.logit_bias is None + assert data.logprobs == False + assert data.n == 1 + # They mention that none of the messages can have names, but it seems to not error out (for now) + data.stream = False if isinstance(stream_inferface, AgentChunkStreamingInterface): stream_inferface.stream_start() try: + # groq uses the openai chat completions API, so this component should be reusable response = openai_chat_completions_request( url=llm_config.model_endpoint, api_key=credentials.groq_key, diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index dfc68882..b5ec436f 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -16,8 +16,10 @@ class LLMConfig(BaseModel): """ # TODO: 🤮 don't default to a vendor! bug city! - model: Literal["openai", "anthropic", "cohere", "google_ai", "azure", "groq"] = Field(..., description="LLM model name. ") - model_endpoint_type: str = Field(..., description="The endpoint type for the model.") + model: str = Field(..., description="LLM model name. ") + model_endpoint_type: Literal["openai", "anthropic", "cohere", "google_ai", "azure", "groq"] = Field( + ..., description="The endpoint type for the model." + ) model_endpoint: str = Field(..., description="The endpoint for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") context_window: int = Field(..., description="The context window size for the model.") diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index f9985fc2..5def1875 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -2,7 +2,7 @@ import json import os import uuid -from letta import create_client +from letta import LocalClient, RESTClient, create_client from letta.agent import Agent from letta.config import LettaConfig from letta.embeddings import embedding_model @@ -22,6 +22,18 @@ llm_config_path = "configs/llm_model_configs/letta-hosted.json" embedding_config_dir = "configs/embedding_model_configs" llm_config_dir = "configs/llm_model_configs" +# Generate uuid for agent name for this example +namespace = uuid.NAMESPACE_DNS +agent_uuid = str(uuid.uuid5(namespace, "letta-endpoint-tests")) + + +def clean_up_agent(client: LocalClient | RESTClient): + # Clear all agents + for agent_state in client.list_agents(): + if agent_state.name == agent_uuid: + client.delete_agent(agent_id=agent_state.id) + print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") + def run_llm_endpoint(filename): config_data = json.load(open(filename, "r")) @@ -36,13 +48,14 @@ def run_llm_endpoint(filename): config.save() client = create_client() - agent_state = client.create_agent(name="test_agent", llm_config=llm_config, embedding_config=embedding_config) + clean_up_agent(client) + agent_state = client.create_agent(name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config) tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tools] agent = Agent( interface=None, tools=tools, agent_state=agent_state, - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now + # gpt-3.5-turbo tends to omit inner monologue, relax th is requirement for now first_message_verify_mono=True, ) @@ -56,6 +69,7 @@ def run_llm_endpoint(filename): ) client.delete_agent(agent_state.id) assert response is not None + print(response) def run_embedding_endpoint(filename): From a3b46dc02638f219a2b0d4958ba4f5f234e6dfe8 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Tue, 1 Oct 2024 16:42:47 -0700 Subject: [PATCH 04/16] Fix typo --- tests/test_endpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 5def1875..782cbf55 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -55,7 +55,7 @@ def run_llm_endpoint(filename): interface=None, tools=tools, agent_state=agent_state, - # gpt-3.5-turbo tends to omit inner monologue, relax th is requirement for now + # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now first_message_verify_mono=True, ) From e20c4270e4e703593fa0e8f7674c7474de366931 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Tue, 1 Oct 2024 16:45:52 -0700 Subject: [PATCH 05/16] add ollama --- letta/schemas/llm_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index b5ec436f..dd71cbd7 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -17,7 +17,7 @@ class LLMConfig(BaseModel): # TODO: 🤮 don't default to a vendor! bug city! model: str = Field(..., description="LLM model name. ") - model_endpoint_type: Literal["openai", "anthropic", "cohere", "google_ai", "azure", "groq"] = Field( + model_endpoint_type: Literal["openai", "anthropic", "cohere", "google_ai", "azure", "groq", "ollama"] = Field( ..., description="The endpoint type for the model." ) model_endpoint: str = Field(..., description="The endpoint for the model.") From 79b078eeffaf2d18d373682ea3d50eaa4de05f80 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Tue, 1 Oct 2024 17:10:23 -0700 Subject: [PATCH 06/16] Delete old code --- letta/local_llm/groq/api.py | 97 ------------------------------------- 1 file changed, 97 deletions(-) delete mode 100644 letta/local_llm/groq/api.py diff --git a/letta/local_llm/groq/api.py b/letta/local_llm/groq/api.py deleted file mode 100644 index b46ddf61..00000000 --- a/letta/local_llm/groq/api.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import Tuple -from urllib.parse import urljoin - -from letta.local_llm.settings.settings import get_completions_settings -from letta.local_llm.utils import post_json_auth_request -from letta.utils import count_tokens - -API_CHAT_SUFFIX = "/v1/chat/completions" -# LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions" - - -def get_groq_completion(endpoint: str, auth_type: str, auth_key: str, model: str, prompt: str, context_window: int) -> Tuple[str, dict]: - """TODO no support for function calling OR raw completions, so we need to route the request into /chat/completions instead""" - from letta.utils import printd - - prompt_tokens = count_tokens(prompt) - if prompt_tokens > context_window: - raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") - - settings = get_completions_settings() - settings.update( - { - # see https://console.groq.com/docs/text-chat, supports: - # "temperature": , - # "max_tokens": , - # "top_p", - # "stream", - # "stop", - # Groq only allows 4 stop tokens - "stop": [ - "\nUSER", - "\nASSISTANT", - "\nFUNCTION", - # "\nFUNCTION RETURN", - # "<|im_start|>", - # "<|im_end|>", - # "<|im_sep|>", - # # airoboros specific - # "\n### ", - # # '\n' + - # # '', - # # '<|', - # "\n#", - # # "\n\n\n", - # # prevent chaining function calls / multi json objects / run-on generations - # # NOTE: this requires the ability to patch the extra '}}' back into the prompt - " }\n}\n", - ] - } - ) - - URI = urljoin(endpoint.strip("/") + "/", API_CHAT_SUFFIX.strip("/")) - - # Settings for the generation, includes the prompt + stop tokens, max length, etc - request = settings - request["model"] = model - request["max_tokens"] = context_window - # NOTE: Hack for chat/completion-only endpoints: put the entire completion string inside the first message - message_structure = [{"role": "user", "content": prompt}] - request["messages"] = message_structure - - if not endpoint.startswith(("http://", "https://")): - raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://") - - try: - response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) - if response.status_code == 200: - result_full = response.json() - printd(f"JSON API response:\n{result_full}") - result = result_full["choices"][0]["message"]["content"] - usage = result_full.get("usage", None) - else: - # Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"} - if "context length" in str(response.text).lower(): - # "exceeds context length" is what appears in the LM Studio error message - # raise an alternate exception that matches OpenAI's message, which is "maximum context length" - raise Exception(f"Request exceeds maximum context length (code={response.status_code}, msg={response.text}, URI={URI})") - else: - raise Exception( - f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}." - + f" Make sure that the inference server is running and reachable at {URI}." - ) - except: - # TODO handle gracefully - raise - - # Pass usage statistics back to main thread - # These are used to compute memory warning messages - completion_tokens = usage.get("completion_tokens", None) if usage is not None else None - total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None - usage = { - "prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0) - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - } - - return result, usage From 62241e0d03b268170d8bfef4ea43a8d2ecf9fd04 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Tue, 1 Oct 2024 17:12:09 -0700 Subject: [PATCH 07/16] Remove groq old imports --- letta/local_llm/chat_completion_proxy.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/letta/local_llm/chat_completion_proxy.py b/letta/local_llm/chat_completion_proxy.py index bdec58b6..25b91420 100644 --- a/letta/local_llm/chat_completion_proxy.py +++ b/letta/local_llm/chat_completion_proxy.py @@ -12,7 +12,6 @@ from letta.local_llm.grammars.gbnf_grammar_generator import ( create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation, ) -from letta.local_llm.groq.api import get_groq_completion from letta.local_llm.koboldcpp.api import get_koboldcpp_completion from letta.local_llm.llamacpp.api import get_llamacpp_completion from letta.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper @@ -170,8 +169,6 @@ def get_chat_completion( result, usage = get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_window) elif endpoint_type == "vllm": result, usage = get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user) - elif endpoint_type == "groq": - result, usage = get_groq_completion(endpoint, auth_type, auth_key, model, prompt, context_window) else: raise LocalLLMError( f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)" From b698cb1981d5120345368af213ebd9d4dacfc308 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Wed, 2 Oct 2024 09:39:10 -0700 Subject: [PATCH 08/16] Modify gitignore and add configs --- .gitignore | 3 --- configs/llm_model_configs/azure-gpt-4o-mini.json | 7 +++++++ configs/llm_model_configs/groq.json | 7 +++++++ 3 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 configs/llm_model_configs/azure-gpt-4o-mini.json create mode 100644 configs/llm_model_configs/groq.json diff --git a/.gitignore b/.gitignore index 98285992..ddcdb97a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,3 @@ -# Letta config files -configs/ - # Below are generated by gitignor.io (toptal) # Created by https://www.toptal.com/developers/gitignore/api/vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection # Edit at https://www.toptal.com/developers/gitignore?templates=vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection diff --git a/configs/llm_model_configs/azure-gpt-4o-mini.json b/configs/llm_model_configs/azure-gpt-4o-mini.json new file mode 100644 index 00000000..0bb31245 --- /dev/null +++ b/configs/llm_model_configs/azure-gpt-4o-mini.json @@ -0,0 +1,7 @@ +{ + "context_window": 128000, + "model": "gpt-4o-mini", + "model_endpoint_type": "azure", + "model_endpoint": "https://letta.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview", + "model_wrapper": null +} diff --git a/configs/llm_model_configs/groq.json b/configs/llm_model_configs/groq.json new file mode 100644 index 00000000..a63acbf0 --- /dev/null +++ b/configs/llm_model_configs/groq.json @@ -0,0 +1,7 @@ +{ + "context_window": 8192, + "model": "llama3-groq-70b-8192-tool-use-preview", + "model_endpoint_type": "groq", + "model_endpoint": "https://api.groq.com/openai/v1", + "model_wrapper": null +} From b98ecb4b1ce673a2726c8e5195951c2b65f29e6d Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Wed, 2 Oct 2024 09:45:13 -0700 Subject: [PATCH 09/16] Add literal endpoint types --- letta/schemas/llm_config.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index dd71cbd7..d951c2dd 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -17,9 +17,23 @@ class LLMConfig(BaseModel): # TODO: 🤮 don't default to a vendor! bug city! model: str = Field(..., description="LLM model name. ") - model_endpoint_type: Literal["openai", "anthropic", "cohere", "google_ai", "azure", "groq", "ollama"] = Field( - ..., description="The endpoint type for the model." - ) + model_endpoint_type: Literal[ + "openai", + "anthropic", + "cohere", + "google_ai", + "azure", + "groq", + "ollama", + "webui", + "webui-legacy", + "lmstudio", + "lmstudio-legacy", + "llamacpp", + "koboldcpp", + "vllm", + "hugging-face", + ] = Field(..., description="The endpoint type for the model.") model_endpoint: str = Field(..., description="The endpoint for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") context_window: int = Field(..., description="The context window size for the model.") From 216e69d52cd2e4e2a372929beb1222a3381dd3f0 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 3 Oct 2024 10:04:18 -0700 Subject: [PATCH 10/16] Add groq flow to CLI --- letta/cli/cli_config.py | 36 +++++++++++++++++++++++++++++++++- letta/llm_api/llm_api_tools.py | 2 +- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/letta/cli/cli_config.py b/letta/cli/cli_config.py index c964fb75..fd14c3d9 100644 --- a/letta/cli/cli_config.py +++ b/letta/cli/cli_config.py @@ -126,7 +126,41 @@ def configure_llm_endpoint(config: LettaConfig, credentials: LettaCredentials): model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask() if model_endpoint is None: raise KeyboardInterrupt - provider = "openai" + + elif provider == "groq": + groq_user_msg = "Enter your Groq API key (starts with 'gsk-', see https://console.groq.com/keys):" + # check for key + if credentials.groq_key is None: + # allow key to get pulled from env vars + groq_api_key = os.getenv("GROQ_API_KEY", None) + # if we still can't find it, ask for it as input + if groq_api_key is None: + while groq_api_key is None or len(groq_api_key) == 0: + # Ask for API key as input + groq_api_key = questionary.password(groq_user_msg).ask() + if groq_api_key is None: + raise KeyboardInterrupt + credentials.groq_key = groq_api_key + credentials.save() + else: + # Give the user an opportunity to overwrite the key + default_input = shorten_key_middle(credentials.groq_key) if credentials.groq_key.startswith("gsk-") else credentials.groq_key + groq_api_key = questionary.password( + groq_user_msg, + default=default_input, + ).ask() + if groq_api_key is None: + raise KeyboardInterrupt + # If the user modified it, use the new one + if groq_api_key != default_input: + credentials.groq_key = groq_api_key + credentials.save() + + model_endpoint_type = "groq" + model_endpoint = "https://api.groq.com/openai/v1" + model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask() + if model_endpoint is None: + raise KeyboardInterrupt elif provider == "azure": # check for necessary vars diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 295f05e6..ebf84376 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -44,7 +44,7 @@ from letta.streaming_interface import ( ) from letta.utils import json_dumps -LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local"] +LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq"] # TODO update to use better types From ee46de19bdf3897fb0a68f06cf56b67f53757558 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 3 Oct 2024 10:07:04 -0700 Subject: [PATCH 11/16] Delete azure configs --- configs/llm_model_configs/azure-gpt-4o-mini.json | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 configs/llm_model_configs/azure-gpt-4o-mini.json diff --git a/configs/llm_model_configs/azure-gpt-4o-mini.json b/configs/llm_model_configs/azure-gpt-4o-mini.json deleted file mode 100644 index 0bb31245..00000000 --- a/configs/llm_model_configs/azure-gpt-4o-mini.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "context_window": 128000, - "model": "gpt-4o-mini", - "model_endpoint_type": "azure", - "model_endpoint": "https://letta.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2023-03-15-preview", - "model_wrapper": null -} From 17192d5aa7fcfcadf92e951851129a7482b16ee0 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 3 Oct 2024 12:08:10 -0700 Subject: [PATCH 12/16] fix: Fix small benchmark bugs (#1826) Co-authored-by: Matt Zhou --- letta/benchmark/benchmark.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/letta/benchmark/benchmark.py b/letta/benchmark/benchmark.py index 4031d4a7..7109210e 100644 --- a/letta/benchmark/benchmark.py +++ b/letta/benchmark/benchmark.py @@ -2,11 +2,11 @@ import time import uuid -from typing import Annotated +from typing import Annotated, Union import typer -from letta import create_client +from letta import LocalClient, RESTClient, create_client from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES from letta.config import LettaConfig @@ -17,11 +17,13 @@ from letta.utils import get_human_text, get_persona_text app = typer.Typer() -def send_message(message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES): +def send_message( + client: Union[LocalClient, RESTClient], message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES +): try: print_msg = f"\t-> Now running {fn_type}. Progress: {turn}/{n_tries}" print(print_msg, end="\r", flush=True) - response = client.user_message(agent_id=agent_id, message=message, return_token_count=True) + response = client.user_message(agent_id=agent_id, message=message) if turn + 1 == n_tries: print(" " * len(print_msg), end="\r", flush=True) @@ -65,7 +67,7 @@ def bench( agent_id = agent.id result, msg = send_message( - message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries + client=client, message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries ) if print_messages: From e0442bd6585f4f17fdb9dca207d2ce17ec9e6e92 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Thu, 3 Oct 2024 18:08:46 -0700 Subject: [PATCH 13/16] chore: bump version 0.4.1 (#1809) --- letta/__init__.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/letta/__init__.py b/letta/__init__.py index 93cdfd4b..bc200417 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.1.7" +__version__ = "0.4.1" # import clients from letta.client.admin import Admin diff --git a/pyproject.toml b/pyproject.toml index d3501385..3406e9a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.1.7" +version = "0.4.1" packages = [ {include = "letta"} ] From ad5b07071062fa5afce29d34b0916a57f8c2724a Mon Sep 17 00:00:00 2001 From: cpacker Date: Thu, 3 Oct 2024 19:03:59 -0700 Subject: [PATCH 14/16] fix: various fixes to get groq to work from the CLI --- letta/cli/cli_config.py | 30 +++++++++++++++++++++++++++++- letta/credentials.py | 5 +++++ letta/llm_api/llm_api_tools.py | 1 + 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/letta/cli/cli_config.py b/letta/cli/cli_config.py index fd14c3d9..1d59e8c7 100644 --- a/letta/cli/cli_config.py +++ b/letta/cli/cli_config.py @@ -426,6 +426,12 @@ def get_model_options( fetched_model_options = cohere_get_model_list(url=model_endpoint, api_key=credentials.cohere_key) model_options = [obj for obj in fetched_model_options] + elif model_endpoint_type == "groq": + if credentials.groq_key is None: + raise ValueError("Missing Groq API key") + fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=credentials.groq_key, fix_url=True) + model_options = [obj["id"] for obj in fetched_model_options_response["data"]] + else: # Attempt to do OpenAI endpoint style model fetching # TODO support local auth with api-key header @@ -589,10 +595,32 @@ def configure_model(config: LettaConfig, credentials: LettaCredentials, model_en if model is None: raise KeyboardInterrupt + # Groq support via /chat/completions + function calling endpoints + elif model_endpoint_type == "groq": + try: + fetched_model_options = get_model_options( + credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint + ) + + except Exception as e: + # NOTE: if this fails, it means the user's key is probably bad + typer.secho( + f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED + ) + raise e + + model = questionary.select( + "Select default model:", + choices=fetched_model_options, + default=fetched_model_options[0], + ).ask() + if model is None: + raise KeyboardInterrupt + else: # local models # ask about local auth - if model_endpoint_type in ["groq"]: # TODO all llm engines under 'local' that will require api keys + if model_endpoint_type in ["groq-chat-compltions"]: # TODO all llm engines under 'local' that will require api keys use_local_auth = True local_auth_type = "bearer_token" local_auth_key = questionary.password( diff --git a/letta/credentials.py b/letta/credentials.py index ea92cc29..4d807fc5 100644 --- a/letta/credentials.py +++ b/letta/credentials.py @@ -81,6 +81,8 @@ class LettaCredentials: "anthropic_key": get_field(config, "anthropic", "key"), # cohere "cohere_key": get_field(config, "cohere", "key"), + # groq + "groq_key": get_field(config, "groq", "key"), # open llm "openllm_auth_type": get_field(config, "openllm", "auth_type"), "openllm_key": get_field(config, "openllm", "key"), @@ -123,6 +125,9 @@ class LettaCredentials: # cohere set_field(config, "cohere", "key", self.cohere_key) + # groq + set_field(config, "groq", "key", self.groq_key) + # openllm config set_field(config, "openllm", "auth_type", self.openllm_auth_type) set_field(config, "openllm", "key", self.openllm_key) diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index ebf84376..6e5d47e7 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -492,6 +492,7 @@ def create( stream_inferface.stream_start() try: # groq uses the openai chat completions API, so this component should be reusable + assert credentials.groq_key is not None, "Groq key is missing" response = openai_chat_completions_request( url=llm_config.model_endpoint, api_key=credentials.groq_key, From 402f8bc157757f6be0ed0b0ce65fe9d2b8e4cb41 Mon Sep 17 00:00:00 2001 From: cpacker Date: Thu, 3 Oct 2024 19:08:42 -0700 Subject: [PATCH 15/16] fix: force inner_thoughts_in_kwargs for groq since their function calling wrapper doesn't work without it --- letta/llm_api/llm_api_tools.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 6e5d47e7..57822123 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -470,10 +470,19 @@ def create( # only is a problem if we are *not* using an openai proxy raise ValueError(f"Groq key is missing from letta config file") + # force to true for groq, since they don't support 'content' is non-null + inner_thoughts_in_kwargs = True + if inner_thoughts_in_kwargs: + functions = add_inner_thoughts_to_functions( + functions=functions, + inner_thoughts_key=INNER_THOUGHTS_KWARG, + inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION, + ) + tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None data = ChatCompletionRequest( model=llm_config.model, - messages=[m.to_openai_dict() for m in messages], + messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs) for m in messages], tools=tools, tool_choice=function_call, user=str(user_id), @@ -502,6 +511,9 @@ def create( if isinstance(stream_inferface, AgentChunkStreamingInterface): stream_inferface.stream_end() + if inner_thoughts_in_kwargs: + response = unpack_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) + return response # local model From b17246a3b000bb3d2581dada30f0bde94d3ba1d3 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Fri, 4 Oct 2024 15:36:33 -0700 Subject: [PATCH 16/16] feat: add back support for using `AssistantMessage` subtype of `LettaMessage` (#1812) --- letta/constants.py | 6 + letta/schemas/letta_request.py | 17 +++ letta/schemas/message.py | 49 +++++-- letta/server/rest_api/interface.py | 141 +++++++++++++++++---- letta/server/rest_api/routers/v1/agents.py | 38 +++++- letta/server/server.py | 23 +++- tests/test_server.py | 79 ++++++++++-- 7 files changed, 300 insertions(+), 53 deletions(-) diff --git a/letta/constants.py b/letta/constants.py index dc0a17c0..84fa0a76 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -46,6 +46,12 @@ BASE_TOOLS = [ "archival_memory_search", ] +# The name of the tool used to send message to the user +# May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...) +# or in cases where the agent has no concept of messaging a user (e.g. a workflow agent) +DEFAULT_MESSAGE_TOOL = "send_message" +DEFAULT_MESSAGE_TOOL_KWARG = "message" + # LOGGER_LOG_LEVEL is use to convert Text to Logging level value for logging mostly for Cli input to setting level LOGGER_LOG_LEVELS = {"CRITICAL": CRITICAL, "ERROR": ERROR, "WARN": WARN, "WARNING": WARNING, "INFO": INFO, "DEBUG": DEBUG, "NOTSET": NOTSET} diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index b690b47b..a6e49d8b 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -2,6 +2,7 @@ from typing import List from pydantic import BaseModel, Field +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.message import MessageCreate @@ -21,3 +22,19 @@ class LettaRequest(BaseModel): default=False, description="Set True to return the raw Message object. Set False to return the Message in the format of the Letta API.", ) + + # Flags to support the use of AssistantMessage message types + + use_assistant_message: bool = Field( + default=False, + description="[Only applicable if return_message_object is False] If true, returns AssistantMessage objects when the agent calls a designated message tool. If false, return FunctionCallMessage objects for all tool calls.", + ) + + assistant_message_function_name: str = Field( + default=DEFAULT_MESSAGE_TOOL, + description="[Only applicable if use_assistant_message is True] The name of the designated message tool.", + ) + assistant_message_function_kwarg: str = Field( + default=DEFAULT_MESSAGE_TOOL_KWARG, + description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.", + ) diff --git a/letta/schemas/message.py b/letta/schemas/message.py index d3879c0c..70aa9df9 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -6,11 +6,16 @@ from typing import List, Optional from pydantic import Field, field_validator -from letta.constants import TOOL_CALL_ID_MAX_LEN +from letta.constants import ( + DEFAULT_MESSAGE_TOOL, + DEFAULT_MESSAGE_TOOL_KWARG, + TOOL_CALL_ID_MAX_LEN, +) from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.schemas.enums import MessageRole from letta.schemas.letta_base import LettaBase from letta.schemas.letta_message import ( + AssistantMessage, FunctionCall, FunctionCallMessage, FunctionReturn, @@ -122,7 +127,12 @@ class Message(BaseMessage): json_message["created_at"] = self.created_at.isoformat() return json_message - def to_letta_message(self) -> List[LettaMessage]: + def to_letta_message( + self, + assistant_message: bool = False, + assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, + ) -> List[LettaMessage]: """Convert message object (in DB format) to the style used by the original Letta API""" messages = [] @@ -140,16 +150,33 @@ class Message(BaseMessage): if self.tool_calls is not None: # This is type FunctionCall for tool_call in self.tool_calls: - messages.append( - FunctionCallMessage( - id=self.id, - date=self.created_at, - function_call=FunctionCall( - name=tool_call.function.name, - arguments=tool_call.function.arguments, - ), + # If we're supporting using assistant message, + # then we want to treat certain function calls as a special case + if assistant_message and tool_call.function.name == assistant_message_function_name: + # We need to unpack the actual message contents from the function call + try: + func_args = json.loads(tool_call.function.arguments) + message_string = func_args[DEFAULT_MESSAGE_TOOL_KWARG] + except KeyError: + raise ValueError(f"Function call {tool_call.function.name} missing {DEFAULT_MESSAGE_TOOL_KWARG} argument") + messages.append( + AssistantMessage( + id=self.id, + date=self.created_at, + assistant_message=message_string, + ) + ) + else: + messages.append( + FunctionCallMessage( + id=self.id, + date=self.created_at, + function_call=FunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) ) - ) elif self.role == MessageRole.tool: # This is type FunctionReturn # Try to interpret the function return, recall that this is how we packaged: diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 0715b901..b8b06d78 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -1,10 +1,12 @@ import asyncio import json import queue +import warnings from collections import deque from datetime import datetime from typing import AsyncGenerator, Literal, Optional, Union +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.interface import AgentInterface from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( @@ -249,7 +251,7 @@ class QueuingInterface(AgentInterface): class FunctionArgumentsStreamHandler: """State machine that can process a stream of""" - def __init__(self, json_key="message"): + def __init__(self, json_key=DEFAULT_MESSAGE_TOOL_KWARG): self.json_key = json_key self.reset() @@ -311,7 +313,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface): should maintain multiple generators and index them with the request ID """ - def __init__(self, multi_step=True): + def __init__( + self, + multi_step=True, + use_assistant_message=False, + assistant_message_function_name=DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, + ): # If streaming mode, ignores base interface calls like .assistant_message, etc self.streaming_mode = False # NOTE: flag for supporting legacy 'stream' flag where send_message is treated specially @@ -321,7 +329,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): self.streaming_chat_completion_mode_function_name = None # NOTE: sadly need to track state during stream # If chat completion mode, we need a special stream reader to # turn function argument to send_message into a normal text stream - self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler() + self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_function_kwarg) self._chunks = deque() self._event = asyncio.Event() # Use an event to notify when chunks are available @@ -333,6 +341,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface): self.multi_step_indicator = MessageStreamStatus.done_step self.multi_step_gen_indicator = MessageStreamStatus.done_generation + # Support for AssistantMessage + self.use_assistant_message = use_assistant_message + self.assistant_message_function_name = assistant_message_function_name + self.assistant_message_function_kwarg = assistant_message_function_kwarg + # extra prints self.debug = False self.timeout = 30 @@ -441,7 +454,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): def _process_chunk_to_letta_style( self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime - ) -> Optional[Union[InternalMonologue, FunctionCallMessage]]: + ) -> Optional[Union[InternalMonologue, FunctionCallMessage, AssistantMessage]]: """ Example data from non-streaming response looks like: @@ -461,23 +474,83 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=message_date, internal_monologue=message_delta.content, ) + + # tool calls elif message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0: tool_call = message_delta.tool_calls[0] - tool_call_delta = {} - if tool_call.id: - tool_call_delta["id"] = tool_call.id - if tool_call.function: - if tool_call.function.arguments: - tool_call_delta["arguments"] = tool_call.function.arguments - if tool_call.function.name: - tool_call_delta["name"] = tool_call.function.name + # special case for trapping `send_message` + if self.use_assistant_message and tool_call.function: + + # If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode + + # Track the function name while streaming + # If we were previously on a 'send_message', we need to 'toggle' into 'content' mode + if tool_call.function.name: + if self.streaming_chat_completion_mode_function_name is None: + self.streaming_chat_completion_mode_function_name = tool_call.function.name + else: + self.streaming_chat_completion_mode_function_name += tool_call.function.name + + # If we get a "hit" on the special keyword we're looking for, we want to skip to the next chunk + # TODO I don't think this handles the function name in multi-pieces problem. Instead, we should probably reset the streaming_chat_completion_mode_function_name when we make this hit? + # if self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name: + if tool_call.function.name == self.assistant_message_function_name: + self.streaming_chat_completion_json_reader.reset() + # early exit to turn into content mode + return None + + # if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks + if ( + tool_call.function.arguments + and self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name + ): + # Strip out any extras tokens + cleaned_func_args = self.streaming_chat_completion_json_reader.process_json_chunk(tool_call.function.arguments) + # In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk + if cleaned_func_args is None: + return None + else: + processed_chunk = AssistantMessage( + id=message_id, + date=message_date, + assistant_message=cleaned_func_args, + ) + + # otherwise we just do a regular passthrough of a FunctionCallDelta via a FunctionCallMessage + else: + tool_call_delta = {} + if tool_call.id: + tool_call_delta["id"] = tool_call.id + if tool_call.function: + if tool_call.function.arguments: + tool_call_delta["arguments"] = tool_call.function.arguments + if tool_call.function.name: + tool_call_delta["name"] = tool_call.function.name + + processed_chunk = FunctionCallMessage( + id=message_id, + date=message_date, + function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), + ) + + else: + + tool_call_delta = {} + if tool_call.id: + tool_call_delta["id"] = tool_call.id + if tool_call.function: + if tool_call.function.arguments: + tool_call_delta["arguments"] = tool_call.function.arguments + if tool_call.function.name: + tool_call_delta["name"] = tool_call.function.name + + processed_chunk = FunctionCallMessage( + id=message_id, + date=message_date, + function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), + ) - processed_chunk = FunctionCallMessage( - id=message_id, - date=message_date, - function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), - ) elif choice.finish_reason is not None: # skip if there's a finish return None @@ -663,14 +736,32 @@ class StreamingServerInterface(AgentChunkStreamingInterface): else: - processed_chunk = FunctionCallMessage( - id=msg_obj.id, - date=msg_obj.created_at, - function_call=FunctionCall( - name=function_call.function.name, - arguments=function_call.function.arguments, - ), - ) + try: + func_args = json.loads(function_call.function.arguments) + except: + warnings.warn(f"Failed to parse function arguments: {function_call.function.arguments}") + func_args = {} + + if ( + self.use_assistant_message + and function_call.function.name == self.assistant_message_function_name + and self.assistant_message_function_kwarg in func_args + ): + processed_chunk = AssistantMessage( + id=msg_obj.id, + date=msg_obj.created_at, + assistant_message=func_args[self.assistant_message_function_kwarg], + ) + else: + processed_chunk = FunctionCallMessage( + id=msg_obj.id, + date=msg_obj.created_at, + function_call=FunctionCall( + name=function_call.function.name, + arguments=function_call.function.arguments, + ), + ) + # processed_chunk = { # "function_call": { # "name": function_call.function.name, diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 514db4c0..cf4a8a64 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Query, status from fastapi.responses import JSONResponse, StreamingResponse from starlette.responses import StreamingResponse +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState from letta.schemas.enums import MessageRole, MessageStreamStatus from letta.schemas.letta_message import ( @@ -254,6 +255,19 @@ def get_agent_messages( before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."), limit: int = Query(10, description="Maximum number of messages to retrieve."), msg_object: bool = Query(False, description="If true, returns Message objects. If false, return LettaMessage objects."), + # Flags to support the use of AssistantMessage message types + use_assistant_message: bool = Query( + False, + description="[Only applicable if msg_object is False] If true, returns AssistantMessage objects when the agent calls a designated message tool. If false, return FunctionCallMessage objects for all tool calls.", + ), + assistant_message_function_name: str = Query( + DEFAULT_MESSAGE_TOOL, + description="[Only applicable if use_assistant_message is True] The name of the designated message tool.", + ), + assistant_message_function_kwarg: str = Query( + DEFAULT_MESSAGE_TOOL_KWARG, + description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.", + ), ): """ Retrieve message history for an agent. @@ -267,6 +281,9 @@ def get_agent_messages( limit=limit, reverse=True, return_message_object=msg_object, + use_assistant_message=use_assistant_message, + assistant_message_function_name=assistant_message_function_name, + assistant_message_function_kwarg=assistant_message_function_kwarg, ) @@ -310,6 +327,10 @@ async def send_message( stream_steps=request.stream_steps, stream_tokens=request.stream_tokens, return_message_object=request.return_message_object, + # Support for AssistantMessage + use_assistant_message=request.use_assistant_message, + assistant_message_function_name=request.assistant_message_function_name, + assistant_message_function_kwarg=request.assistant_message_function_kwarg, ) @@ -322,12 +343,17 @@ async def send_message_to_agent( message: str, stream_steps: bool, stream_tokens: bool, - return_message_object: bool, # Should be True for Python Client, False for REST API - chat_completion_mode: Optional[bool] = False, - timestamp: Optional[datetime] = None, # related to whether or not we return `LettaMessage`s or `Message`s + return_message_object: bool, # Should be True for Python Client, False for REST API + chat_completion_mode: bool = False, + timestamp: Optional[datetime] = None, + # Support for AssistantMessage + use_assistant_message: bool = False, + assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, ) -> Union[StreamingResponse, LettaResponse]: """Split off into a separate function so that it can be imported in the /chat/completion proxy.""" + # TODO: @charles is this the correct way to handle? include_final_message = True @@ -368,6 +394,11 @@ async def send_message_to_agent( # streaming_interface.allow_assistant_message = stream # streaming_interface.function_call_legacy_mode = stream + # Allow AssistantMessage is desired by client + streaming_interface.use_assistant_message = use_assistant_message + streaming_interface.assistant_message_function_name = assistant_message_function_name + streaming_interface.assistant_message_function_kwarg = assistant_message_function_kwarg + # Offload the synchronous message_func to a separate thread streaming_interface.stream_start() task = asyncio.create_task( @@ -408,6 +439,7 @@ async def send_message_to_agent( message_ids = [m.id for m in filtered_stream] message_ids = deduplicate(message_ids) message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids] + message_objs = [m for m in message_objs if m is not None] return LettaResponse(messages=message_objs, usage=usage) else: return LettaResponse(messages=filtered_stream, usage=usage) diff --git a/letta/server/server.py b/letta/server/server.py index 80b4c4f1..454f9881 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1262,6 +1262,9 @@ class SyncServer(Server): order: Optional[str] = "asc", reverse: Optional[bool] = False, return_message_object: bool = True, + use_assistant_message: bool = False, + assistant_message_function_name: str = constants.DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, ) -> Union[List[Message], List[LettaMessage]]: if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -1281,9 +1284,25 @@ class SyncServer(Server): if not return_message_object: # If we're GETing messages in reverse, we need to reverse the inner list (generated by to_letta_message) if reverse: - records = [msg for m in records for msg in m.to_letta_message()[::-1]] + records = [ + msg + for m in records + for msg in m.to_letta_message( + assistant_message=use_assistant_message, + assistant_message_function_name=assistant_message_function_name, + assistant_message_function_kwarg=assistant_message_function_kwarg, + )[::-1] + ] else: - records = [msg for m in records for msg in m.to_letta_message()] + records = [ + msg + for m in records + for msg in m.to_letta_message( + assistant_message=use_assistant_message, + assistant_message_function_name=assistant_message_function_name, + assistant_message_function_kwarg=assistant_message_function_kwarg, + ) + ] return records diff --git a/tests/test_server.py b/tests/test_server.py index 67fa58ad..440e9833 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,16 +1,18 @@ import json import uuid +import warnings import pytest import letta.utils as utils -from letta.constants import BASE_TOOLS +from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.enums import MessageRole utils.DEBUG = True from letta.config import LettaConfig from letta.schemas.agent import CreateAgent from letta.schemas.letta_message import ( + AssistantMessage, FunctionCallMessage, FunctionReturn, InternalMonologue, @@ -236,7 +238,14 @@ def test_get_archival_memory(server, user_id, agent_id): assert len(passage_none) == 0 -def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): +def _test_get_messages_letta_format( + server, + user_id, + agent_id, + reverse=False, + # flag that determines whether or not to use AssistantMessage, or just FunctionCallMessage universally + use_assistant_message=False, +): """Reverse is off by default, the GET goes in chronological order""" messages = server.get_agent_recall_cursor( @@ -244,6 +253,8 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): agent_id=agent_id, limit=1000, reverse=reverse, + return_message_object=True, + use_assistant_message=use_assistant_message, ) # messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000) assert all(isinstance(m, Message) for m in messages) @@ -254,6 +265,7 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): limit=1000, reverse=reverse, return_message_object=False, + use_assistant_message=use_assistant_message, ) # letta_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False) assert all(isinstance(m, LettaMessage) for m in letta_messages) @@ -316,9 +328,30 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): # If there are multiple tool calls, we should have multiple back to back FunctionCallMessages if message.tool_calls is not None: for tool_call in message.tool_calls: - assert isinstance(letta_message, FunctionCallMessage) - letta_message_index += 1 - letta_message = letta_messages[letta_message_index] + + # Try to parse the tool call args + try: + func_args = json.loads(tool_call.function.arguments) + except: + warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}") + func_args = {} + + # If assistant_message is True, we expect FunctionCallMessage to be AssistantMessage if the tool call is the assistant message tool + if ( + use_assistant_message + and tool_call.function.name == DEFAULT_MESSAGE_TOOL + and DEFAULT_MESSAGE_TOOL_KWARG in func_args + ): + assert isinstance(letta_message, AssistantMessage) + assert func_args[DEFAULT_MESSAGE_TOOL_KWARG] == letta_message.assistant_message + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] + + # Otherwise, we expect even a "send_message" tool call to be a FunctionCallMessage + else: + assert isinstance(letta_message, FunctionCallMessage) + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] if message.text is not None: assert isinstance(letta_message, InternalMonologue) @@ -341,11 +374,32 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): # If there are multiple tool calls, we should have multiple back to back FunctionCallMessages if message.tool_calls is not None: for tool_call in message.tool_calls: - assert isinstance(letta_message, FunctionCallMessage) - assert tool_call.function.name == letta_message.function_call.name - assert tool_call.function.arguments == letta_message.function_call.arguments - letta_message_index += 1 - letta_message = letta_messages[letta_message_index] + + # Try to parse the tool call args + try: + func_args = json.loads(tool_call.function.arguments) + except: + warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}") + func_args = {} + + # If assistant_message is True, we expect FunctionCallMessage to be AssistantMessage if the tool call is the assistant message tool + if ( + use_assistant_message + and tool_call.function.name == DEFAULT_MESSAGE_TOOL + and DEFAULT_MESSAGE_TOOL_KWARG in func_args + ): + assert isinstance(letta_message, AssistantMessage) + assert func_args[DEFAULT_MESSAGE_TOOL_KWARG] == letta_message.assistant_message + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] + + # Otherwise, we expect even a "send_message" tool call to be a FunctionCallMessage + else: + assert isinstance(letta_message, FunctionCallMessage) + assert tool_call.function.name == letta_message.function_call.name + assert tool_call.function.arguments == letta_message.function_call.arguments + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] elif message.role == MessageRole.user: print(f"i={i}, M=user, MM={type(letta_message)}") @@ -374,8 +428,9 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): def test_get_messages_letta_format(server, user_id, agent_id): - _test_get_messages_letta_format(server, user_id, agent_id, reverse=False) - _test_get_messages_letta_format(server, user_id, agent_id, reverse=True) + for reverse in [False, True]: + for assistant_message in [False, True]: + _test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse, use_assistant_message=assistant_message) def test_agent_rethink_rewrite_retry(server, user_id, agent_id):