diff --git a/.github/workflows/test_memgpt_hosted.yml b/.github/workflows/test_memgpt_hosted.yml index 93f45986..71e9a7f5 100644 --- a/.github/workflows/test_memgpt_hosted.yml +++ b/.github/workflows/test_memgpt_hosted.yml @@ -24,6 +24,7 @@ jobs: - name: Test LLM endpoint run: | poetry run pytest -s -vv tests/test_endpoints.py::test_llm_endpoint_letta_hosted + continue-on-error: true - name: Test embedding endpoint run: | diff --git a/.github/workflows/test_openai.yml b/.github/workflows/test_openai.yml index db15042f..25d27ec8 100644 --- a/.github/workflows/test_openai.yml +++ b/.github/workflows/test_openai.yml @@ -35,33 +35,63 @@ jobs: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_returns_valid_first_message + continue-on-error: true - name: Test model sends message with keyword env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_returns_keyword + continue-on-error: true - name: Test model uses external tool correctly env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_uses_external_tool + continue-on-error: true - name: Test model recalls chat memory env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_recall_chat_memory + continue-on-error: true - name: Test model uses `archival_memory_search` to find secret env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_archival_memory_retrieval + continue-on-error: true + + - name: Test model can edit core memories + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_edit_core_memory + continue-on-error: true - name: Test embedding endpoint env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | poetry run pytest -s -vv tests/test_endpoints.py::test_embedding_endpoint_openai + continue-on-error: true + + - name: Summarize test results + run: | + echo "Test Results Summary:" + echo "Test first message: $([[ ${{ steps.test_first_message.outcome }} == 'success' ]] && echo ✅ || echo ❌)" + echo "Test model sends message with keyword: $([[ ${{ steps.test_keyword_message.outcome }} == 'success' ]] && echo ✅ || echo ❌)" + echo "Test model uses external tool: $([[ ${{ steps.test_external_tool.outcome }} == 'success' ]] && echo ✅ || echo ❌)" + echo "Test model recalls chat memory: $([[ ${{ steps.test_chat_memory.outcome }} == 'success' ]] && echo ✅ || echo ❌)" + echo "Test model uses 'archival_memory_search' to find secret: $([[ ${{ steps.test_archival_memory.outcome }} == 'success' ]] && echo ✅ || echo ❌)" + echo "Test model can edit core memories: $([[ ${{ steps.test_core_memory.outcome }} == 'success' ]] && echo ✅ || echo ❌)" + echo "Test embedding endpoint: $([[ ${{ steps.test_embedding_endpoint.outcome }} == 'success' ]] && echo ✅ || echo ❌)" + + # Check if any test failed + if [[ ${{ steps.test_first_message.outcome }} != 'success' || ${{ steps.test_keyword_message.outcome }} != 'success' || ${{ steps.test_external_tool.outcome }} != 'success' || ${{ steps.test_chat_memory.outcome }} != 'success' || ${{ steps.test_archival_memory.outcome }} != 'success' || ${{ steps.test_core_memory.outcome }} != 'success' || ${{ steps.test_embedding_endpoint.outcome }} != 'success' ]]; then + echo "Some tests failed, setting neutral status." + exit 78 + fi diff --git a/.gitignore b/.gitignore index 98285992..ddcdb97a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,3 @@ -# Letta config files -configs/ - # Below are generated by gitignor.io (toptal) # Created by https://www.toptal.com/developers/gitignore/api/vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection # Edit at https://www.toptal.com/developers/gitignore?templates=vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection diff --git a/configs/llm_model_configs/groq.json b/configs/llm_model_configs/groq.json new file mode 100644 index 00000000..a63acbf0 --- /dev/null +++ b/configs/llm_model_configs/groq.json @@ -0,0 +1,7 @@ +{ + "context_window": 8192, + "model": "llama3-groq-70b-8192-tool-use-preview", + "model_endpoint_type": "groq", + "model_endpoint": "https://api.groq.com/openai/v1", + "model_wrapper": null +} diff --git a/letta/__init__.py b/letta/__init__.py index 93cdfd4b..bc200417 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.1.7" +__version__ = "0.4.1" # import clients from letta.client.admin import Admin diff --git a/letta/benchmark/benchmark.py b/letta/benchmark/benchmark.py index 4031d4a7..7109210e 100644 --- a/letta/benchmark/benchmark.py +++ b/letta/benchmark/benchmark.py @@ -2,11 +2,11 @@ import time import uuid -from typing import Annotated +from typing import Annotated, Union import typer -from letta import create_client +from letta import LocalClient, RESTClient, create_client from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES from letta.config import LettaConfig @@ -17,11 +17,13 @@ from letta.utils import get_human_text, get_persona_text app = typer.Typer() -def send_message(message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES): +def send_message( + client: Union[LocalClient, RESTClient], message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES +): try: print_msg = f"\t-> Now running {fn_type}. Progress: {turn}/{n_tries}" print(print_msg, end="\r", flush=True) - response = client.user_message(agent_id=agent_id, message=message, return_token_count=True) + response = client.user_message(agent_id=agent_id, message=message) if turn + 1 == n_tries: print(" " * len(print_msg), end="\r", flush=True) @@ -65,7 +67,7 @@ def bench( agent_id = agent.id result, msg = send_message( - message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries + client=client, message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries ) if print_messages: diff --git a/letta/cli/cli_config.py b/letta/cli/cli_config.py index c964fb75..1d59e8c7 100644 --- a/letta/cli/cli_config.py +++ b/letta/cli/cli_config.py @@ -126,7 +126,41 @@ def configure_llm_endpoint(config: LettaConfig, credentials: LettaCredentials): model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask() if model_endpoint is None: raise KeyboardInterrupt - provider = "openai" + + elif provider == "groq": + groq_user_msg = "Enter your Groq API key (starts with 'gsk-', see https://console.groq.com/keys):" + # check for key + if credentials.groq_key is None: + # allow key to get pulled from env vars + groq_api_key = os.getenv("GROQ_API_KEY", None) + # if we still can't find it, ask for it as input + if groq_api_key is None: + while groq_api_key is None or len(groq_api_key) == 0: + # Ask for API key as input + groq_api_key = questionary.password(groq_user_msg).ask() + if groq_api_key is None: + raise KeyboardInterrupt + credentials.groq_key = groq_api_key + credentials.save() + else: + # Give the user an opportunity to overwrite the key + default_input = shorten_key_middle(credentials.groq_key) if credentials.groq_key.startswith("gsk-") else credentials.groq_key + groq_api_key = questionary.password( + groq_user_msg, + default=default_input, + ).ask() + if groq_api_key is None: + raise KeyboardInterrupt + # If the user modified it, use the new one + if groq_api_key != default_input: + credentials.groq_key = groq_api_key + credentials.save() + + model_endpoint_type = "groq" + model_endpoint = "https://api.groq.com/openai/v1" + model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask() + if model_endpoint is None: + raise KeyboardInterrupt elif provider == "azure": # check for necessary vars @@ -392,6 +426,12 @@ def get_model_options( fetched_model_options = cohere_get_model_list(url=model_endpoint, api_key=credentials.cohere_key) model_options = [obj for obj in fetched_model_options] + elif model_endpoint_type == "groq": + if credentials.groq_key is None: + raise ValueError("Missing Groq API key") + fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=credentials.groq_key, fix_url=True) + model_options = [obj["id"] for obj in fetched_model_options_response["data"]] + else: # Attempt to do OpenAI endpoint style model fetching # TODO support local auth with api-key header @@ -555,10 +595,32 @@ def configure_model(config: LettaConfig, credentials: LettaCredentials, model_en if model is None: raise KeyboardInterrupt + # Groq support via /chat/completions + function calling endpoints + elif model_endpoint_type == "groq": + try: + fetched_model_options = get_model_options( + credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint + ) + + except Exception as e: + # NOTE: if this fails, it means the user's key is probably bad + typer.secho( + f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED + ) + raise e + + model = questionary.select( + "Select default model:", + choices=fetched_model_options, + default=fetched_model_options[0], + ).ask() + if model is None: + raise KeyboardInterrupt + else: # local models # ask about local auth - if model_endpoint_type in ["groq"]: # TODO all llm engines under 'local' that will require api keys + if model_endpoint_type in ["groq-chat-compltions"]: # TODO all llm engines under 'local' that will require api keys use_local_auth = True local_auth_type = "bearer_token" local_auth_key = questionary.password( diff --git a/letta/constants.py b/letta/constants.py index dc0a17c0..84fa0a76 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -46,6 +46,12 @@ BASE_TOOLS = [ "archival_memory_search", ] +# The name of the tool used to send message to the user +# May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...) +# or in cases where the agent has no concept of messaging a user (e.g. a workflow agent) +DEFAULT_MESSAGE_TOOL = "send_message" +DEFAULT_MESSAGE_TOOL_KWARG = "message" + # LOGGER_LOG_LEVEL is use to convert Text to Logging level value for logging mostly for Cli input to setting level LOGGER_LOG_LEVELS = {"CRITICAL": CRITICAL, "ERROR": ERROR, "WARN": WARN, "WARNING": WARNING, "INFO": INFO, "DEBUG": DEBUG, "NOTSET": NOTSET} diff --git a/letta/credentials.py b/letta/credentials.py index 8052d16b..d662e76e 100644 --- a/letta/credentials.py +++ b/letta/credentials.py @@ -31,6 +31,10 @@ class LettaCredentials: # azure config azure_auth_type: str = "api_key" azure_key: Optional[str] = None + + # groq config + groq_key: Optional[str] = os.getenv("GROQ_API_KEY") + # base llm / model azure_version: Optional[str] = None azure_endpoint: Optional[str] = None @@ -77,6 +81,8 @@ class LettaCredentials: "anthropic_key": get_field(config, "anthropic", "key"), # cohere "cohere_key": get_field(config, "cohere", "key"), + # groq + "groq_key": get_field(config, "groq", "key"), # open llm "openllm_auth_type": get_field(config, "openllm", "auth_type"), "openllm_key": get_field(config, "openllm", "key"), @@ -119,6 +125,9 @@ class LettaCredentials: # cohere set_field(config, "cohere", "key", self.cohere_key) + # groq + set_field(config, "groq", "key", self.groq_key) + # openllm config set_field(config, "openllm", "auth_type", self.openllm_auth_type) set_field(config, "openllm", "key", self.openllm_key) diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 7ff9193d..49703943 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -44,7 +44,7 @@ from letta.streaming_interface import ( ) from letta.utils import json_dumps -LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local"] +LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq"] # TODO update to use better types @@ -335,7 +335,6 @@ def create( if isinstance(stream_inferface, AgentChunkStreamingInterface): stream_inferface.stream_start() try: - response = openai_chat_completions_request( url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions api_key=credentials.openai_key, @@ -458,7 +457,7 @@ def create( chat_completion_request=ChatCompletionRequest( model="command-r-plus", # TODO messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], - tools=[{"type": "function", "function": f} for f in functions] if functions else None, + tools=tools, tool_choice=function_call, # user=str(user_id), # NOTE: max_tokens is required for Anthropic API @@ -466,6 +465,60 @@ def create( ), ) + elif llm_config.model_endpoint_type == "groq": + if stream: + raise NotImplementedError(f"Streaming not yet implemented for Groq.") + + if credentials.groq_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions": + # only is a problem if we are *not* using an openai proxy + raise ValueError(f"Groq key is missing from letta config file") + + # force to true for groq, since they don't support 'content' is non-null + inner_thoughts_in_kwargs = True + if inner_thoughts_in_kwargs: + functions = add_inner_thoughts_to_functions( + functions=functions, + inner_thoughts_key=INNER_THOUGHTS_KWARG, + inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION, + ) + + tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None + data = ChatCompletionRequest( + model=llm_config.model, + messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs) for m in messages], + tools=tools, + tool_choice=function_call, + user=str(user_id), + ) + + # https://console.groq.com/docs/openai + # "The following fields are currently not supported and will result in a 400 error (yikes) if they are supplied:" + assert data.top_logprobs is None + assert data.logit_bias is None + assert data.logprobs == False + assert data.n == 1 + # They mention that none of the messages can have names, but it seems to not error out (for now) + + data.stream = False + if isinstance(stream_inferface, AgentChunkStreamingInterface): + stream_inferface.stream_start() + try: + # groq uses the openai chat completions API, so this component should be reusable + assert credentials.groq_key is not None, "Groq key is missing" + response = openai_chat_completions_request( + url=llm_config.model_endpoint, + api_key=credentials.groq_key, + chat_completion_request=data, + ) + finally: + if isinstance(stream_inferface, AgentChunkStreamingInterface): + stream_inferface.stream_end() + + if inner_thoughts_in_kwargs: + response = unpack_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) + + return response + # local model else: if stream: diff --git a/letta/local_llm/chat_completion_proxy.py b/letta/local_llm/chat_completion_proxy.py index bdec58b6..25b91420 100644 --- a/letta/local_llm/chat_completion_proxy.py +++ b/letta/local_llm/chat_completion_proxy.py @@ -12,7 +12,6 @@ from letta.local_llm.grammars.gbnf_grammar_generator import ( create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation, ) -from letta.local_llm.groq.api import get_groq_completion from letta.local_llm.koboldcpp.api import get_koboldcpp_completion from letta.local_llm.llamacpp.api import get_llamacpp_completion from letta.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper @@ -170,8 +169,6 @@ def get_chat_completion( result, usage = get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_window) elif endpoint_type == "vllm": result, usage = get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user) - elif endpoint_type == "groq": - result, usage = get_groq_completion(endpoint, auth_type, auth_key, model, prompt, context_window) else: raise LocalLLMError( f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)" diff --git a/letta/local_llm/groq/api.py b/letta/local_llm/groq/api.py deleted file mode 100644 index b46ddf61..00000000 --- a/letta/local_llm/groq/api.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import Tuple -from urllib.parse import urljoin - -from letta.local_llm.settings.settings import get_completions_settings -from letta.local_llm.utils import post_json_auth_request -from letta.utils import count_tokens - -API_CHAT_SUFFIX = "/v1/chat/completions" -# LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions" - - -def get_groq_completion(endpoint: str, auth_type: str, auth_key: str, model: str, prompt: str, context_window: int) -> Tuple[str, dict]: - """TODO no support for function calling OR raw completions, so we need to route the request into /chat/completions instead""" - from letta.utils import printd - - prompt_tokens = count_tokens(prompt) - if prompt_tokens > context_window: - raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") - - settings = get_completions_settings() - settings.update( - { - # see https://console.groq.com/docs/text-chat, supports: - # "temperature": , - # "max_tokens": , - # "top_p", - # "stream", - # "stop", - # Groq only allows 4 stop tokens - "stop": [ - "\nUSER", - "\nASSISTANT", - "\nFUNCTION", - # "\nFUNCTION RETURN", - # "<|im_start|>", - # "<|im_end|>", - # "<|im_sep|>", - # # airoboros specific - # "\n### ", - # # '\n' + - # # '', - # # '<|', - # "\n#", - # # "\n\n\n", - # # prevent chaining function calls / multi json objects / run-on generations - # # NOTE: this requires the ability to patch the extra '}}' back into the prompt - " }\n}\n", - ] - } - ) - - URI = urljoin(endpoint.strip("/") + "/", API_CHAT_SUFFIX.strip("/")) - - # Settings for the generation, includes the prompt + stop tokens, max length, etc - request = settings - request["model"] = model - request["max_tokens"] = context_window - # NOTE: Hack for chat/completion-only endpoints: put the entire completion string inside the first message - message_structure = [{"role": "user", "content": prompt}] - request["messages"] = message_structure - - if not endpoint.startswith(("http://", "https://")): - raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://") - - try: - response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) - if response.status_code == 200: - result_full = response.json() - printd(f"JSON API response:\n{result_full}") - result = result_full["choices"][0]["message"]["content"] - usage = result_full.get("usage", None) - else: - # Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"} - if "context length" in str(response.text).lower(): - # "exceeds context length" is what appears in the LM Studio error message - # raise an alternate exception that matches OpenAI's message, which is "maximum context length" - raise Exception(f"Request exceeds maximum context length (code={response.status_code}, msg={response.text}, URI={URI})") - else: - raise Exception( - f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}." - + f" Make sure that the inference server is running and reachable at {URI}." - ) - except: - # TODO handle gracefully - raise - - # Pass usage statistics back to main thread - # These are used to compute memory warning messages - completion_tokens = usage.get("completion_tokens", None) if usage is not None else None - total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None - usage = { - "prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0) - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - } - - return result, usage diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index b690b47b..a6e49d8b 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -2,6 +2,7 @@ from typing import List from pydantic import BaseModel, Field +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.message import MessageCreate @@ -21,3 +22,19 @@ class LettaRequest(BaseModel): default=False, description="Set True to return the raw Message object. Set False to return the Message in the format of the Letta API.", ) + + # Flags to support the use of AssistantMessage message types + + use_assistant_message: bool = Field( + default=False, + description="[Only applicable if return_message_object is False] If true, returns AssistantMessage objects when the agent calls a designated message tool. If false, return FunctionCallMessage objects for all tool calls.", + ) + + assistant_message_function_name: str = Field( + default=DEFAULT_MESSAGE_TOOL, + description="[Only applicable if use_assistant_message is True] The name of the designated message tool.", + ) + assistant_message_function_kwarg: str = Field( + default=DEFAULT_MESSAGE_TOOL_KWARG, + description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.", + ) diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 134dff02..d951c2dd 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional from pydantic import BaseModel, ConfigDict, Field @@ -17,7 +17,23 @@ class LLMConfig(BaseModel): # TODO: 🤮 don't default to a vendor! bug city! model: str = Field(..., description="LLM model name. ") - model_endpoint_type: str = Field(..., description="The endpoint type for the model.") + model_endpoint_type: Literal[ + "openai", + "anthropic", + "cohere", + "google_ai", + "azure", + "groq", + "ollama", + "webui", + "webui-legacy", + "lmstudio", + "lmstudio-legacy", + "llamacpp", + "koboldcpp", + "vllm", + "hugging-face", + ] = Field(..., description="The endpoint type for the model.") model_endpoint: str = Field(..., description="The endpoint for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") context_window: int = Field(..., description="The context window size for the model.") diff --git a/letta/schemas/message.py b/letta/schemas/message.py index d3879c0c..70aa9df9 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -6,11 +6,16 @@ from typing import List, Optional from pydantic import Field, field_validator -from letta.constants import TOOL_CALL_ID_MAX_LEN +from letta.constants import ( + DEFAULT_MESSAGE_TOOL, + DEFAULT_MESSAGE_TOOL_KWARG, + TOOL_CALL_ID_MAX_LEN, +) from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.schemas.enums import MessageRole from letta.schemas.letta_base import LettaBase from letta.schemas.letta_message import ( + AssistantMessage, FunctionCall, FunctionCallMessage, FunctionReturn, @@ -122,7 +127,12 @@ class Message(BaseMessage): json_message["created_at"] = self.created_at.isoformat() return json_message - def to_letta_message(self) -> List[LettaMessage]: + def to_letta_message( + self, + assistant_message: bool = False, + assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, + ) -> List[LettaMessage]: """Convert message object (in DB format) to the style used by the original Letta API""" messages = [] @@ -140,16 +150,33 @@ class Message(BaseMessage): if self.tool_calls is not None: # This is type FunctionCall for tool_call in self.tool_calls: - messages.append( - FunctionCallMessage( - id=self.id, - date=self.created_at, - function_call=FunctionCall( - name=tool_call.function.name, - arguments=tool_call.function.arguments, - ), + # If we're supporting using assistant message, + # then we want to treat certain function calls as a special case + if assistant_message and tool_call.function.name == assistant_message_function_name: + # We need to unpack the actual message contents from the function call + try: + func_args = json.loads(tool_call.function.arguments) + message_string = func_args[DEFAULT_MESSAGE_TOOL_KWARG] + except KeyError: + raise ValueError(f"Function call {tool_call.function.name} missing {DEFAULT_MESSAGE_TOOL_KWARG} argument") + messages.append( + AssistantMessage( + id=self.id, + date=self.created_at, + assistant_message=message_string, + ) + ) + else: + messages.append( + FunctionCallMessage( + id=self.id, + date=self.created_at, + function_call=FunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) ) - ) elif self.role == MessageRole.tool: # This is type FunctionReturn # Try to interpret the function return, recall that this is how we packaged: diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 0715b901..b8b06d78 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -1,10 +1,12 @@ import asyncio import json import queue +import warnings from collections import deque from datetime import datetime from typing import AsyncGenerator, Literal, Optional, Union +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.interface import AgentInterface from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( @@ -249,7 +251,7 @@ class QueuingInterface(AgentInterface): class FunctionArgumentsStreamHandler: """State machine that can process a stream of""" - def __init__(self, json_key="message"): + def __init__(self, json_key=DEFAULT_MESSAGE_TOOL_KWARG): self.json_key = json_key self.reset() @@ -311,7 +313,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface): should maintain multiple generators and index them with the request ID """ - def __init__(self, multi_step=True): + def __init__( + self, + multi_step=True, + use_assistant_message=False, + assistant_message_function_name=DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, + ): # If streaming mode, ignores base interface calls like .assistant_message, etc self.streaming_mode = False # NOTE: flag for supporting legacy 'stream' flag where send_message is treated specially @@ -321,7 +329,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): self.streaming_chat_completion_mode_function_name = None # NOTE: sadly need to track state during stream # If chat completion mode, we need a special stream reader to # turn function argument to send_message into a normal text stream - self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler() + self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_function_kwarg) self._chunks = deque() self._event = asyncio.Event() # Use an event to notify when chunks are available @@ -333,6 +341,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface): self.multi_step_indicator = MessageStreamStatus.done_step self.multi_step_gen_indicator = MessageStreamStatus.done_generation + # Support for AssistantMessage + self.use_assistant_message = use_assistant_message + self.assistant_message_function_name = assistant_message_function_name + self.assistant_message_function_kwarg = assistant_message_function_kwarg + # extra prints self.debug = False self.timeout = 30 @@ -441,7 +454,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): def _process_chunk_to_letta_style( self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime - ) -> Optional[Union[InternalMonologue, FunctionCallMessage]]: + ) -> Optional[Union[InternalMonologue, FunctionCallMessage, AssistantMessage]]: """ Example data from non-streaming response looks like: @@ -461,23 +474,83 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=message_date, internal_monologue=message_delta.content, ) + + # tool calls elif message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0: tool_call = message_delta.tool_calls[0] - tool_call_delta = {} - if tool_call.id: - tool_call_delta["id"] = tool_call.id - if tool_call.function: - if tool_call.function.arguments: - tool_call_delta["arguments"] = tool_call.function.arguments - if tool_call.function.name: - tool_call_delta["name"] = tool_call.function.name + # special case for trapping `send_message` + if self.use_assistant_message and tool_call.function: + + # If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode + + # Track the function name while streaming + # If we were previously on a 'send_message', we need to 'toggle' into 'content' mode + if tool_call.function.name: + if self.streaming_chat_completion_mode_function_name is None: + self.streaming_chat_completion_mode_function_name = tool_call.function.name + else: + self.streaming_chat_completion_mode_function_name += tool_call.function.name + + # If we get a "hit" on the special keyword we're looking for, we want to skip to the next chunk + # TODO I don't think this handles the function name in multi-pieces problem. Instead, we should probably reset the streaming_chat_completion_mode_function_name when we make this hit? + # if self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name: + if tool_call.function.name == self.assistant_message_function_name: + self.streaming_chat_completion_json_reader.reset() + # early exit to turn into content mode + return None + + # if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks + if ( + tool_call.function.arguments + and self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name + ): + # Strip out any extras tokens + cleaned_func_args = self.streaming_chat_completion_json_reader.process_json_chunk(tool_call.function.arguments) + # In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk + if cleaned_func_args is None: + return None + else: + processed_chunk = AssistantMessage( + id=message_id, + date=message_date, + assistant_message=cleaned_func_args, + ) + + # otherwise we just do a regular passthrough of a FunctionCallDelta via a FunctionCallMessage + else: + tool_call_delta = {} + if tool_call.id: + tool_call_delta["id"] = tool_call.id + if tool_call.function: + if tool_call.function.arguments: + tool_call_delta["arguments"] = tool_call.function.arguments + if tool_call.function.name: + tool_call_delta["name"] = tool_call.function.name + + processed_chunk = FunctionCallMessage( + id=message_id, + date=message_date, + function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), + ) + + else: + + tool_call_delta = {} + if tool_call.id: + tool_call_delta["id"] = tool_call.id + if tool_call.function: + if tool_call.function.arguments: + tool_call_delta["arguments"] = tool_call.function.arguments + if tool_call.function.name: + tool_call_delta["name"] = tool_call.function.name + + processed_chunk = FunctionCallMessage( + id=message_id, + date=message_date, + function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), + ) - processed_chunk = FunctionCallMessage( - id=message_id, - date=message_date, - function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), - ) elif choice.finish_reason is not None: # skip if there's a finish return None @@ -663,14 +736,32 @@ class StreamingServerInterface(AgentChunkStreamingInterface): else: - processed_chunk = FunctionCallMessage( - id=msg_obj.id, - date=msg_obj.created_at, - function_call=FunctionCall( - name=function_call.function.name, - arguments=function_call.function.arguments, - ), - ) + try: + func_args = json.loads(function_call.function.arguments) + except: + warnings.warn(f"Failed to parse function arguments: {function_call.function.arguments}") + func_args = {} + + if ( + self.use_assistant_message + and function_call.function.name == self.assistant_message_function_name + and self.assistant_message_function_kwarg in func_args + ): + processed_chunk = AssistantMessage( + id=msg_obj.id, + date=msg_obj.created_at, + assistant_message=func_args[self.assistant_message_function_kwarg], + ) + else: + processed_chunk = FunctionCallMessage( + id=msg_obj.id, + date=msg_obj.created_at, + function_call=FunctionCall( + name=function_call.function.name, + arguments=function_call.function.arguments, + ), + ) + # processed_chunk = { # "function_call": { # "name": function_call.function.name, diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 514db4c0..cf4a8a64 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Query, status from fastapi.responses import JSONResponse, StreamingResponse from starlette.responses import StreamingResponse +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState from letta.schemas.enums import MessageRole, MessageStreamStatus from letta.schemas.letta_message import ( @@ -254,6 +255,19 @@ def get_agent_messages( before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."), limit: int = Query(10, description="Maximum number of messages to retrieve."), msg_object: bool = Query(False, description="If true, returns Message objects. If false, return LettaMessage objects."), + # Flags to support the use of AssistantMessage message types + use_assistant_message: bool = Query( + False, + description="[Only applicable if msg_object is False] If true, returns AssistantMessage objects when the agent calls a designated message tool. If false, return FunctionCallMessage objects for all tool calls.", + ), + assistant_message_function_name: str = Query( + DEFAULT_MESSAGE_TOOL, + description="[Only applicable if use_assistant_message is True] The name of the designated message tool.", + ), + assistant_message_function_kwarg: str = Query( + DEFAULT_MESSAGE_TOOL_KWARG, + description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.", + ), ): """ Retrieve message history for an agent. @@ -267,6 +281,9 @@ def get_agent_messages( limit=limit, reverse=True, return_message_object=msg_object, + use_assistant_message=use_assistant_message, + assistant_message_function_name=assistant_message_function_name, + assistant_message_function_kwarg=assistant_message_function_kwarg, ) @@ -310,6 +327,10 @@ async def send_message( stream_steps=request.stream_steps, stream_tokens=request.stream_tokens, return_message_object=request.return_message_object, + # Support for AssistantMessage + use_assistant_message=request.use_assistant_message, + assistant_message_function_name=request.assistant_message_function_name, + assistant_message_function_kwarg=request.assistant_message_function_kwarg, ) @@ -322,12 +343,17 @@ async def send_message_to_agent( message: str, stream_steps: bool, stream_tokens: bool, - return_message_object: bool, # Should be True for Python Client, False for REST API - chat_completion_mode: Optional[bool] = False, - timestamp: Optional[datetime] = None, # related to whether or not we return `LettaMessage`s or `Message`s + return_message_object: bool, # Should be True for Python Client, False for REST API + chat_completion_mode: bool = False, + timestamp: Optional[datetime] = None, + # Support for AssistantMessage + use_assistant_message: bool = False, + assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, ) -> Union[StreamingResponse, LettaResponse]: """Split off into a separate function so that it can be imported in the /chat/completion proxy.""" + # TODO: @charles is this the correct way to handle? include_final_message = True @@ -368,6 +394,11 @@ async def send_message_to_agent( # streaming_interface.allow_assistant_message = stream # streaming_interface.function_call_legacy_mode = stream + # Allow AssistantMessage is desired by client + streaming_interface.use_assistant_message = use_assistant_message + streaming_interface.assistant_message_function_name = assistant_message_function_name + streaming_interface.assistant_message_function_kwarg = assistant_message_function_kwarg + # Offload the synchronous message_func to a separate thread streaming_interface.stream_start() task = asyncio.create_task( @@ -408,6 +439,7 @@ async def send_message_to_agent( message_ids = [m.id for m in filtered_stream] message_ids = deduplicate(message_ids) message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids] + message_objs = [m for m in message_objs if m is not None] return LettaResponse(messages=message_objs, usage=usage) else: return LettaResponse(messages=filtered_stream, usage=usage) diff --git a/letta/server/server.py b/letta/server/server.py index 2f5c9f03..39547ba6 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1262,6 +1262,9 @@ class SyncServer(Server): order: Optional[str] = "asc", reverse: Optional[bool] = False, return_message_object: bool = True, + use_assistant_message: bool = False, + assistant_message_function_name: str = constants.DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, ) -> Union[List[Message], List[LettaMessage]]: if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -1281,9 +1284,25 @@ class SyncServer(Server): if not return_message_object: # If we're GETing messages in reverse, we need to reverse the inner list (generated by to_letta_message) if reverse: - records = [msg for m in records for msg in m.to_letta_message()[::-1]] + records = [ + msg + for m in records + for msg in m.to_letta_message( + assistant_message=use_assistant_message, + assistant_message_function_name=assistant_message_function_name, + assistant_message_function_kwarg=assistant_message_function_kwarg, + )[::-1] + ] else: - records = [msg for m in records for msg in m.to_letta_message()] + records = [ + msg + for m in records + for msg in m.to_letta_message( + assistant_message=use_assistant_message, + assistant_message_function_name=assistant_message_function_name, + assistant_message_function_kwarg=assistant_message_function_kwarg, + ) + ] return records diff --git a/pyproject.toml b/pyproject.toml index d3501385..3406e9a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.1.7" +version = "0.4.1" packages = [ {include = "letta"} ] diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 0df0a2d9..d32503a3 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -255,6 +255,34 @@ def check_agent_archival_memory_retrieval(filename: str) -> LettaResponse: return response +def check_agent_edit_core_memory(filename: str) -> LettaResponse: + """ + Checks that the LLM is able to edit its core memories + + Note: This is acting on the Letta response, note the usage of `user_message` + """ + # Set up client + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + human_name_a = "AngryAardvark" + human_name_b = "BananaBoy" + agent_state = setup_agent(client, filename, memory_human_str=f"My name is {human_name_a}") + client.user_message(agent_id=agent_state.id, message=f"Actually, my name changed. It is now {human_name_b}") + response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.") + + # Basic checks + assert_sanity_checks(response) + + # Make sure my name was repeated back to me + assert_invoked_send_message_with_keyword(response.messages, human_name_b) + + # Make sure some inner monologue is present + assert_inner_monologue_is_present_and_valid(response.messages) + + return response + + def run_embedding_endpoint(filename): # load JSON file config_data = json.load(open(filename, "r")) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index c1938093..de751096 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -2,6 +2,7 @@ import os from tests.helpers.endpoints_helper import ( check_agent_archival_memory_retrieval, + check_agent_edit_core_memory, check_agent_recall_chat_memory, check_agent_uses_external_tool, check_first_response_is_valid_for_llm_endpoint, @@ -53,6 +54,13 @@ def test_openai_gpt_4_archival_memory_retrieval(): print(f"Got successful response from client: \n\n{response}") +def test_openai_gpt_4_edit_core_memory(): + filename = os.path.join(llm_config_dir, "gpt-4.json") + response = check_agent_edit_core_memory(filename) + # Log out successful response + print(f"Got successful response from client: \n\n{response}") + + def test_embedding_endpoint_openai(): filename = os.path.join(embedding_config_dir, "text-embedding-ada-002.json") run_embedding_endpoint(filename) @@ -95,3 +103,12 @@ def test_embedding_endpoint_ollama(): def test_llm_endpoint_anthropic(): filename = os.path.join(llm_config_dir, "anthropic.json") check_first_response_is_valid_for_llm_endpoint(filename) + check_first_response_is_valid_for_llm_endpoint(filename) + + +# ====================================================================================================================== +# GROQ TESTS +# ====================================================================================================================== +def test_llm_endpoint_groq(): + filename = os.path.join(llm_config_dir, "groq.json") + check_first_response_is_valid_for_llm_endpoint(filename) diff --git a/tests/test_server.py b/tests/test_server.py index 67fa58ad..440e9833 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,16 +1,18 @@ import json import uuid +import warnings import pytest import letta.utils as utils -from letta.constants import BASE_TOOLS +from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.enums import MessageRole utils.DEBUG = True from letta.config import LettaConfig from letta.schemas.agent import CreateAgent from letta.schemas.letta_message import ( + AssistantMessage, FunctionCallMessage, FunctionReturn, InternalMonologue, @@ -236,7 +238,14 @@ def test_get_archival_memory(server, user_id, agent_id): assert len(passage_none) == 0 -def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): +def _test_get_messages_letta_format( + server, + user_id, + agent_id, + reverse=False, + # flag that determines whether or not to use AssistantMessage, or just FunctionCallMessage universally + use_assistant_message=False, +): """Reverse is off by default, the GET goes in chronological order""" messages = server.get_agent_recall_cursor( @@ -244,6 +253,8 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): agent_id=agent_id, limit=1000, reverse=reverse, + return_message_object=True, + use_assistant_message=use_assistant_message, ) # messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000) assert all(isinstance(m, Message) for m in messages) @@ -254,6 +265,7 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): limit=1000, reverse=reverse, return_message_object=False, + use_assistant_message=use_assistant_message, ) # letta_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False) assert all(isinstance(m, LettaMessage) for m in letta_messages) @@ -316,9 +328,30 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): # If there are multiple tool calls, we should have multiple back to back FunctionCallMessages if message.tool_calls is not None: for tool_call in message.tool_calls: - assert isinstance(letta_message, FunctionCallMessage) - letta_message_index += 1 - letta_message = letta_messages[letta_message_index] + + # Try to parse the tool call args + try: + func_args = json.loads(tool_call.function.arguments) + except: + warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}") + func_args = {} + + # If assistant_message is True, we expect FunctionCallMessage to be AssistantMessage if the tool call is the assistant message tool + if ( + use_assistant_message + and tool_call.function.name == DEFAULT_MESSAGE_TOOL + and DEFAULT_MESSAGE_TOOL_KWARG in func_args + ): + assert isinstance(letta_message, AssistantMessage) + assert func_args[DEFAULT_MESSAGE_TOOL_KWARG] == letta_message.assistant_message + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] + + # Otherwise, we expect even a "send_message" tool call to be a FunctionCallMessage + else: + assert isinstance(letta_message, FunctionCallMessage) + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] if message.text is not None: assert isinstance(letta_message, InternalMonologue) @@ -341,11 +374,32 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): # If there are multiple tool calls, we should have multiple back to back FunctionCallMessages if message.tool_calls is not None: for tool_call in message.tool_calls: - assert isinstance(letta_message, FunctionCallMessage) - assert tool_call.function.name == letta_message.function_call.name - assert tool_call.function.arguments == letta_message.function_call.arguments - letta_message_index += 1 - letta_message = letta_messages[letta_message_index] + + # Try to parse the tool call args + try: + func_args = json.loads(tool_call.function.arguments) + except: + warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}") + func_args = {} + + # If assistant_message is True, we expect FunctionCallMessage to be AssistantMessage if the tool call is the assistant message tool + if ( + use_assistant_message + and tool_call.function.name == DEFAULT_MESSAGE_TOOL + and DEFAULT_MESSAGE_TOOL_KWARG in func_args + ): + assert isinstance(letta_message, AssistantMessage) + assert func_args[DEFAULT_MESSAGE_TOOL_KWARG] == letta_message.assistant_message + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] + + # Otherwise, we expect even a "send_message" tool call to be a FunctionCallMessage + else: + assert isinstance(letta_message, FunctionCallMessage) + assert tool_call.function.name == letta_message.function_call.name + assert tool_call.function.arguments == letta_message.function_call.arguments + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] elif message.role == MessageRole.user: print(f"i={i}, M=user, MM={type(letta_message)}") @@ -374,8 +428,9 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): def test_get_messages_letta_format(server, user_id, agent_id): - _test_get_messages_letta_format(server, user_id, agent_id, reverse=False) - _test_get_messages_letta_format(server, user_id, agent_id, reverse=True) + for reverse in [False, True]: + for assistant_message in [False, True]: + _test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse, use_assistant_message=assistant_message) def test_agent_rethink_rewrite_retry(server, user_id, agent_id):