Edit workflows

This commit is contained in:
Matt Zhou
2024-10-04 16:18:42 -07:00
22 changed files with 539 additions and 170 deletions

View File

@@ -24,6 +24,7 @@ jobs:
- name: Test LLM endpoint
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_llm_endpoint_letta_hosted
continue-on-error: true
- name: Test embedding endpoint
run: |

View File

@@ -35,33 +35,63 @@ jobs:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_returns_valid_first_message
continue-on-error: true
- name: Test model sends message with keyword
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_returns_keyword
continue-on-error: true
- name: Test model uses external tool correctly
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_uses_external_tool
continue-on-error: true
- name: Test model recalls chat memory
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_recall_chat_memory
continue-on-error: true
- name: Test model uses `archival_memory_search` to find secret
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_archival_memory_retrieval
continue-on-error: true
- name: Test model can edit core memories
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_openai_gpt_4_edit_core_memory
continue-on-error: true
- name: Test embedding endpoint
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry run pytest -s -vv tests/test_endpoints.py::test_embedding_endpoint_openai
continue-on-error: true
- name: Summarize test results
run: |
echo "Test Results Summary:"
echo "Test first message: $([[ ${{ steps.test_first_message.outcome }} == 'success' ]] && echo ✅ || echo ❌)"
echo "Test model sends message with keyword: $([[ ${{ steps.test_keyword_message.outcome }} == 'success' ]] && echo ✅ || echo ❌)"
echo "Test model uses external tool: $([[ ${{ steps.test_external_tool.outcome }} == 'success' ]] && echo ✅ || echo ❌)"
echo "Test model recalls chat memory: $([[ ${{ steps.test_chat_memory.outcome }} == 'success' ]] && echo ✅ || echo ❌)"
echo "Test model uses 'archival_memory_search' to find secret: $([[ ${{ steps.test_archival_memory.outcome }} == 'success' ]] && echo ✅ || echo ❌)"
echo "Test model can edit core memories: $([[ ${{ steps.test_core_memory.outcome }} == 'success' ]] && echo ✅ || echo ❌)"
echo "Test embedding endpoint: $([[ ${{ steps.test_embedding_endpoint.outcome }} == 'success' ]] && echo ✅ || echo ❌)"
# Check if any test failed
if [[ ${{ steps.test_first_message.outcome }} != 'success' || ${{ steps.test_keyword_message.outcome }} != 'success' || ${{ steps.test_external_tool.outcome }} != 'success' || ${{ steps.test_chat_memory.outcome }} != 'success' || ${{ steps.test_archival_memory.outcome }} != 'success' || ${{ steps.test_core_memory.outcome }} != 'success' || ${{ steps.test_embedding_endpoint.outcome }} != 'success' ]]; then
echo "Some tests failed, setting neutral status."
exit 78
fi

3
.gitignore vendored
View File

@@ -1,6 +1,3 @@
# Letta config files
configs/
# Below are generated by gitignore.io (toptal)
# Created by https://www.toptal.com/developers/gitignore/api/vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection
# Edit at https://www.toptal.com/developers/gitignore?templates=vim,linux,macos,pydev,python,eclipse,pycharm,windows,netbeans,pycharm+all,pycharm+iml,visualstudio,jupyternotebooks,visualstudiocode,xcode,xcodeinjection

View File

@@ -0,0 +1,7 @@
{
"context_window": 8192,
"model": "llama3-groq-70b-8192-tool-use-preview",
"model_endpoint_type": "groq",
"model_endpoint": "https://api.groq.com/openai/v1",
"model_wrapper": null
}

View File

@@ -1,4 +1,4 @@
__version__ = "0.1.7"
__version__ = "0.4.1"
# import clients
from letta.client.admin import Admin

View File

@@ -2,11 +2,11 @@
import time
import uuid
from typing import Annotated
from typing import Annotated, Union
import typer
from letta import create_client
from letta import LocalClient, RESTClient, create_client
from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES
from letta.config import LettaConfig
@@ -17,11 +17,13 @@ from letta.utils import get_human_text, get_persona_text
app = typer.Typer()
def send_message(message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES):
def send_message(
client: Union[LocalClient, RESTClient], message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES
):
try:
print_msg = f"\t-> Now running {fn_type}. Progress: {turn}/{n_tries}"
print(print_msg, end="\r", flush=True)
response = client.user_message(agent_id=agent_id, message=message, return_token_count=True)
response = client.user_message(agent_id=agent_id, message=message)
if turn + 1 == n_tries:
print(" " * len(print_msg), end="\r", flush=True)
@@ -65,7 +67,7 @@ def bench(
agent_id = agent.id
result, msg = send_message(
message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries
client=client, message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries
)
if print_messages:

View File

@@ -126,7 +126,41 @@ def configure_llm_endpoint(config: LettaConfig, credentials: LettaCredentials):
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
if model_endpoint is None:
raise KeyboardInterrupt
provider = "openai"
elif provider == "groq":
groq_user_msg = "Enter your Groq API key (starts with 'gsk-', see https://console.groq.com/keys):"
# check for key
if credentials.groq_key is None:
# allow key to get pulled from env vars
groq_api_key = os.getenv("GROQ_API_KEY", None)
# if we still can't find it, ask for it as input
if groq_api_key is None:
while groq_api_key is None or len(groq_api_key) == 0:
# Ask for API key as input
groq_api_key = questionary.password(groq_user_msg).ask()
if groq_api_key is None:
raise KeyboardInterrupt
credentials.groq_key = groq_api_key
credentials.save()
else:
# Give the user an opportunity to overwrite the key
default_input = shorten_key_middle(credentials.groq_key) if credentials.groq_key.startswith("gsk-") else credentials.groq_key
groq_api_key = questionary.password(
groq_user_msg,
default=default_input,
).ask()
if groq_api_key is None:
raise KeyboardInterrupt
# If the user modified it, use the new one
if groq_api_key != default_input:
credentials.groq_key = groq_api_key
credentials.save()
model_endpoint_type = "groq"
model_endpoint = "https://api.groq.com/openai/v1"
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
if model_endpoint is None:
raise KeyboardInterrupt
elif provider == "azure":
# check for necessary vars
@@ -392,6 +426,12 @@ def get_model_options(
fetched_model_options = cohere_get_model_list(url=model_endpoint, api_key=credentials.cohere_key)
model_options = [obj for obj in fetched_model_options]
elif model_endpoint_type == "groq":
if credentials.groq_key is None:
raise ValueError("Missing Groq API key")
fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=credentials.groq_key, fix_url=True)
model_options = [obj["id"] for obj in fetched_model_options_response["data"]]
else:
# Attempt to do OpenAI endpoint style model fetching
# TODO support local auth with api-key header
@@ -555,10 +595,32 @@ def configure_model(config: LettaConfig, credentials: LettaCredentials, model_en
if model is None:
raise KeyboardInterrupt
# Groq support via /chat/completions + function calling endpoints
elif model_endpoint_type == "groq":
try:
fetched_model_options = get_model_options(
credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
)
except Exception as e:
# NOTE: if this fails, it means the user's key is probably bad
typer.secho(
f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
)
raise e
model = questionary.select(
"Select default model:",
choices=fetched_model_options,
default=fetched_model_options[0],
).ask()
if model is None:
raise KeyboardInterrupt
else: # local models
# ask about local auth
if model_endpoint_type in ["groq"]: # TODO all llm engines under 'local' that will require api keys
if model_endpoint_type in ["groq-chat-compltions"]: # TODO all llm engines under 'local' that will require api keys
use_local_auth = True
local_auth_type = "bearer_token"
local_auth_key = questionary.password(

View File

@@ -46,6 +46,12 @@ BASE_TOOLS = [
"archival_memory_search",
]
# The name of the tool used to send message to the user
# May not be relevant in cases where the agent has multiple ways to message the user (send_imessage, send_discord_message, ...)
# or in cases where the agent has no concept of messaging a user (e.g. a workflow agent)
DEFAULT_MESSAGE_TOOL = "send_message"
DEFAULT_MESSAGE_TOOL_KWARG = "message"
# LOGGER_LOG_LEVELS is used to convert text to a logging level value, mostly for CLI input when setting the level
LOGGER_LOG_LEVELS = {"CRITICAL": CRITICAL, "ERROR": ERROR, "WARN": WARN, "WARNING": WARNING, "INFO": INFO, "DEBUG": DEBUG, "NOTSET": NOTSET}

View File

@@ -31,6 +31,10 @@ class LettaCredentials:
# azure config
azure_auth_type: str = "api_key"
azure_key: Optional[str] = None
# groq config
groq_key: Optional[str] = os.getenv("GROQ_API_KEY")
# base llm / model
azure_version: Optional[str] = None
azure_endpoint: Optional[str] = None
@@ -77,6 +81,8 @@ class LettaCredentials:
"anthropic_key": get_field(config, "anthropic", "key"),
# cohere
"cohere_key": get_field(config, "cohere", "key"),
# groq
"groq_key": get_field(config, "groq", "key"),
# open llm
"openllm_auth_type": get_field(config, "openllm", "auth_type"),
"openllm_key": get_field(config, "openllm", "key"),
@@ -119,6 +125,9 @@ class LettaCredentials:
# cohere
set_field(config, "cohere", "key", self.cohere_key)
# groq
set_field(config, "groq", "key", self.groq_key)
# openllm config
set_field(config, "openllm", "auth_type", self.openllm_auth_type)
set_field(config, "openllm", "key", self.openllm_key)

View File

@@ -44,7 +44,7 @@ from letta.streaming_interface import (
)
from letta.utils import json_dumps
LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local"]
LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq"]
# TODO update to use better types
@@ -335,7 +335,6 @@ def create(
if isinstance(stream_inferface, AgentChunkStreamingInterface):
stream_inferface.stream_start()
try:
response = openai_chat_completions_request(
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
@@ -458,7 +457,7 @@ def create(
chat_completion_request=ChatCompletionRequest(
model="command-r-plus", # TODO
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
tools=tools,
tool_choice=function_call,
# user=str(user_id),
# NOTE: max_tokens is required for Anthropic API
@@ -466,6 +465,60 @@ def create(
),
)
elif llm_config.model_endpoint_type == "groq":
if stream:
raise NotImplementedError(f"Streaming not yet implemented for Groq.")
if credentials.groq_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions":
# only is a problem if we are *not* using an openai proxy
raise ValueError(f"Groq key is missing from letta config file")
# force to true for groq, since they don't support 'content' being non-null
inner_thoughts_in_kwargs = True
if inner_thoughts_in_kwargs:
functions = add_inner_thoughts_to_functions(
functions=functions,
inner_thoughts_key=INNER_THOUGHTS_KWARG,
inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION,
)
tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None
data = ChatCompletionRequest(
model=llm_config.model,
messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs) for m in messages],
tools=tools,
tool_choice=function_call,
user=str(user_id),
)
# https://console.groq.com/docs/openai
# "The following fields are currently not supported and will result in a 400 error (yikes) if they are supplied:"
assert data.top_logprobs is None
assert data.logit_bias is None
assert data.logprobs == False
assert data.n == 1
# They mention that none of the messages can have names, but it seems to not error out (for now)
data.stream = False
if isinstance(stream_inferface, AgentChunkStreamingInterface):
stream_inferface.stream_start()
try:
# groq uses the openai chat completions API, so this component should be reusable
assert credentials.groq_key is not None, "Groq key is missing"
response = openai_chat_completions_request(
url=llm_config.model_endpoint,
api_key=credentials.groq_key,
chat_completion_request=data,
)
finally:
if isinstance(stream_inferface, AgentChunkStreamingInterface):
stream_inferface.stream_end()
if inner_thoughts_in_kwargs:
response = unpack_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
return response
# local model
else:
if stream:

View File

@@ -12,7 +12,6 @@ from letta.local_llm.grammars.gbnf_grammar_generator import (
create_dynamic_model_from_function,
generate_gbnf_grammar_and_documentation,
)
from letta.local_llm.groq.api import get_groq_completion
from letta.local_llm.koboldcpp.api import get_koboldcpp_completion
from letta.local_llm.llamacpp.api import get_llamacpp_completion
from letta.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
@@ -170,8 +169,6 @@ def get_chat_completion(
result, usage = get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_window)
elif endpoint_type == "vllm":
result, usage = get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user)
elif endpoint_type == "groq":
result, usage = get_groq_completion(endpoint, auth_type, auth_key, model, prompt, context_window)
else:
raise LocalLLMError(
f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"

View File

@@ -1,97 +0,0 @@
from typing import Tuple
from urllib.parse import urljoin
from letta.local_llm.settings.settings import get_completions_settings
from letta.local_llm.utils import post_json_auth_request
from letta.utils import count_tokens
API_CHAT_SUFFIX = "/v1/chat/completions"
# LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
def get_groq_completion(endpoint: str, auth_type: str, auth_key: str, model: str, prompt: str, context_window: int) -> Tuple[str, dict]:
    """Run a raw-completion-style prompt through Groq's /chat/completions endpoint.

    Groq has no support for function calling OR raw completions, so the fully
    rendered prompt string is wrapped in a single user message and routed to the
    chat endpoint instead.

    Args:
        endpoint: Base URL of the Groq-compatible server (must be http(s)).
        auth_type: Auth scheme, forwarded to post_json_auth_request.
        auth_key: Credential, forwarded to post_json_auth_request.
        model: Model identifier to request.
        prompt: Fully formatted prompt string.
        context_window: Maximum context length, in tokens.

    Returns:
        Tuple of (completion text, usage dict with prompt/completion/total token counts).

    Raises:
        ValueError: If the endpoint does not begin with http:// or https://.
        Exception: If the prompt exceeds the context window or the API call fails.
    """
    from letta.utils import printd

    prompt_tokens = count_tokens(prompt)
    if prompt_tokens > context_window:
        raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

    # Groq only allows 4 stop tokens (see https://console.groq.com/docs/text-chat),
    # so override the usual local-LLM stop list with a trimmed one.
    payload = get_completions_settings()
    payload.update(
        {
            "stop": [
                "\nUSER",
                "\nASSISTANT",
                "\nFUNCTION",
                # prevent chaining function calls / multi json objects / run-on generations
                # NOTE: this requires the ability to patch the extra '}}' back into the prompt
                " }\n}\n",
            ]
        }
    )
    payload["model"] = model
    payload["max_tokens"] = context_window
    # Hack for chat/completion-only endpoints: the entire completion string rides
    # inside the first (and only) user message.
    payload["messages"] = [{"role": "user", "content": prompt}]

    uri = urljoin(endpoint.strip("/") + "/", API_CHAT_SUFFIX.strip("/"))
    if not endpoint.startswith(("http://", "https://")):
        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

    response = post_json_auth_request(uri=uri, json_payload=payload, auth_type=auth_type, auth_key=auth_key)
    if response.status_code != 200:
        # Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"}
        if "context length" in str(response.text).lower():
            # Raise an exception that matches OpenAI's "maximum context length" wording
            # so upstream handlers recognize it.
            raise Exception(f"Request exceeds maximum context length (code={response.status_code}, msg={response.text}, URI={uri})")
        raise Exception(
            f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {uri}."
            + f" Make sure that the inference server is running and reachable at {uri}."
        )

    result_full = response.json()
    printd(f"JSON API response:\n{result_full}")
    result = result_full["choices"][0]["message"]["content"]
    usage = result_full.get("usage", None)

    # Pass usage statistics back to the main thread; these are used to compute
    # memory warning messages.
    if usage is not None:
        completion_tokens = usage.get("completion_tokens", None)
    else:
        completion_tokens = None
    if completion_tokens is not None:
        total_tokens = prompt_tokens + completion_tokens
    else:
        total_tokens = None
    usage = {
        "prompt_tokens": prompt_tokens,  # local count; the server-reported value is usually wrong (set to 0)
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
    }
    return result, usage

View File

@@ -2,6 +2,7 @@ from typing import List
from pydantic import BaseModel, Field
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.message import MessageCreate
@@ -21,3 +22,19 @@ class LettaRequest(BaseModel):
default=False,
description="Set True to return the raw Message object. Set False to return the Message in the format of the Letta API.",
)
# Flags to support the use of AssistantMessage message types
use_assistant_message: bool = Field(
default=False,
description="[Only applicable if return_message_object is False] If true, returns AssistantMessage objects when the agent calls a designated message tool. If false, return FunctionCallMessage objects for all tool calls.",
)
assistant_message_function_name: str = Field(
default=DEFAULT_MESSAGE_TOOL,
description="[Only applicable if use_assistant_message is True] The name of the designated message tool.",
)
assistant_message_function_kwarg: str = Field(
default=DEFAULT_MESSAGE_TOOL_KWARG,
description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
)

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional
from pydantic import BaseModel, ConfigDict, Field
@@ -17,7 +17,23 @@ class LLMConfig(BaseModel):
# TODO: 🤮 don't default to a vendor! bug city!
model: str = Field(..., description="LLM model name. ")
model_endpoint_type: str = Field(..., description="The endpoint type for the model.")
model_endpoint_type: Literal[
"openai",
"anthropic",
"cohere",
"google_ai",
"azure",
"groq",
"ollama",
"webui",
"webui-legacy",
"lmstudio",
"lmstudio-legacy",
"llamacpp",
"koboldcpp",
"vllm",
"hugging-face",
] = Field(..., description="The endpoint type for the model.")
model_endpoint: str = Field(..., description="The endpoint for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
context_window: int = Field(..., description="The context window size for the model.")

View File

@@ -6,11 +6,16 @@ from typing import List, Optional
from pydantic import Field, field_validator
from letta.constants import TOOL_CALL_ID_MAX_LEN
from letta.constants import (
DEFAULT_MESSAGE_TOOL,
DEFAULT_MESSAGE_TOOL_KWARG,
TOOL_CALL_ID_MAX_LEN,
)
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.enums import MessageRole
from letta.schemas.letta_base import LettaBase
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCall,
FunctionCallMessage,
FunctionReturn,
@@ -122,7 +127,12 @@ class Message(BaseMessage):
json_message["created_at"] = self.created_at.isoformat()
return json_message
def to_letta_message(self) -> List[LettaMessage]:
def to_letta_message(
self,
assistant_message: bool = False,
assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
) -> List[LettaMessage]:
"""Convert message object (in DB format) to the style used by the original Letta API"""
messages = []
@@ -140,16 +150,33 @@ class Message(BaseMessage):
if self.tool_calls is not None:
# This is type FunctionCall
for tool_call in self.tool_calls:
messages.append(
FunctionCallMessage(
id=self.id,
date=self.created_at,
function_call=FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
# If we're supporting using assistant message,
# then we want to treat certain function calls as a special case
if assistant_message and tool_call.function.name == assistant_message_function_name:
# We need to unpack the actual message contents from the function call
try:
func_args = json.loads(tool_call.function.arguments)
message_string = func_args[DEFAULT_MESSAGE_TOOL_KWARG]
except KeyError:
raise ValueError(f"Function call {tool_call.function.name} missing {DEFAULT_MESSAGE_TOOL_KWARG} argument")
messages.append(
AssistantMessage(
id=self.id,
date=self.created_at,
assistant_message=message_string,
)
)
else:
messages.append(
FunctionCallMessage(
id=self.id,
date=self.created_at,
function_call=FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
)
)
elif self.role == MessageRole.tool:
# This is type FunctionReturn
# Try to interpret the function return, recall that this is how we packaged:

View File

@@ -1,10 +1,12 @@
import asyncio
import json
import queue
import warnings
from collections import deque
from datetime import datetime
from typing import AsyncGenerator, Literal, Optional, Union
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.interface import AgentInterface
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
@@ -249,7 +251,7 @@ class QueuingInterface(AgentInterface):
class FunctionArgumentsStreamHandler:
"""State machine that can process a stream of"""
def __init__(self, json_key="message"):
def __init__(self, json_key=DEFAULT_MESSAGE_TOOL_KWARG):
self.json_key = json_key
self.reset()
@@ -311,7 +313,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
should maintain multiple generators and index them with the request ID
"""
def __init__(self, multi_step=True):
def __init__(
self,
multi_step=True,
use_assistant_message=False,
assistant_message_function_name=DEFAULT_MESSAGE_TOOL,
assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
):
# If streaming mode, ignores base interface calls like .assistant_message, etc
self.streaming_mode = False
# NOTE: flag for supporting legacy 'stream' flag where send_message is treated specially
@@ -321,7 +329,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.streaming_chat_completion_mode_function_name = None # NOTE: sadly need to track state during stream
# If chat completion mode, we need a special stream reader to
# turn function argument to send_message into a normal text stream
self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler()
self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_function_kwarg)
self._chunks = deque()
self._event = asyncio.Event() # Use an event to notify when chunks are available
@@ -333,6 +341,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.multi_step_indicator = MessageStreamStatus.done_step
self.multi_step_gen_indicator = MessageStreamStatus.done_generation
# Support for AssistantMessage
self.use_assistant_message = use_assistant_message
self.assistant_message_function_name = assistant_message_function_name
self.assistant_message_function_kwarg = assistant_message_function_kwarg
# extra prints
self.debug = False
self.timeout = 30
@@ -441,7 +454,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def _process_chunk_to_letta_style(
self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime
) -> Optional[Union[InternalMonologue, FunctionCallMessage]]:
) -> Optional[Union[InternalMonologue, FunctionCallMessage, AssistantMessage]]:
"""
Example data from non-streaming response looks like:
@@ -461,23 +474,83 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
date=message_date,
internal_monologue=message_delta.content,
)
# tool calls
elif message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
tool_call = message_delta.tool_calls[0]
tool_call_delta = {}
if tool_call.id:
tool_call_delta["id"] = tool_call.id
if tool_call.function:
if tool_call.function.arguments:
tool_call_delta["arguments"] = tool_call.function.arguments
if tool_call.function.name:
tool_call_delta["name"] = tool_call.function.name
# special case for trapping `send_message`
if self.use_assistant_message and tool_call.function:
# If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode
# Track the function name while streaming
# If we were previously on a 'send_message', we need to 'toggle' into 'content' mode
if tool_call.function.name:
if self.streaming_chat_completion_mode_function_name is None:
self.streaming_chat_completion_mode_function_name = tool_call.function.name
else:
self.streaming_chat_completion_mode_function_name += tool_call.function.name
# If we get a "hit" on the special keyword we're looking for, we want to skip to the next chunk
# TODO I don't think this handles the function name in multi-pieces problem. Instead, we should probably reset the streaming_chat_completion_mode_function_name when we make this hit?
# if self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name:
if tool_call.function.name == self.assistant_message_function_name:
self.streaming_chat_completion_json_reader.reset()
# early exit to turn into content mode
return None
# if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
if (
tool_call.function.arguments
and self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name
):
# Strip out any extras tokens
cleaned_func_args = self.streaming_chat_completion_json_reader.process_json_chunk(tool_call.function.arguments)
# In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
if cleaned_func_args is None:
return None
else:
processed_chunk = AssistantMessage(
id=message_id,
date=message_date,
assistant_message=cleaned_func_args,
)
# otherwise we just do a regular passthrough of a FunctionCallDelta via a FunctionCallMessage
else:
tool_call_delta = {}
if tool_call.id:
tool_call_delta["id"] = tool_call.id
if tool_call.function:
if tool_call.function.arguments:
tool_call_delta["arguments"] = tool_call.function.arguments
if tool_call.function.name:
tool_call_delta["name"] = tool_call.function.name
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
)
else:
tool_call_delta = {}
if tool_call.id:
tool_call_delta["id"] = tool_call.id
if tool_call.function:
if tool_call.function.arguments:
tool_call_delta["arguments"] = tool_call.function.arguments
if tool_call.function.name:
tool_call_delta["name"] = tool_call.function.name
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
)
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
)
elif choice.finish_reason is not None:
# skip if there's a finish
return None
@@ -663,14 +736,32 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
else:
processed_chunk = FunctionCallMessage(
id=msg_obj.id,
date=msg_obj.created_at,
function_call=FunctionCall(
name=function_call.function.name,
arguments=function_call.function.arguments,
),
)
try:
func_args = json.loads(function_call.function.arguments)
except:
warnings.warn(f"Failed to parse function arguments: {function_call.function.arguments}")
func_args = {}
if (
self.use_assistant_message
and function_call.function.name == self.assistant_message_function_name
and self.assistant_message_function_kwarg in func_args
):
processed_chunk = AssistantMessage(
id=msg_obj.id,
date=msg_obj.created_at,
assistant_message=func_args[self.assistant_message_function_kwarg],
)
else:
processed_chunk = FunctionCallMessage(
id=msg_obj.id,
date=msg_obj.created_at,
function_call=FunctionCall(
name=function_call.function.name,
arguments=function_call.function.arguments,
),
)
# processed_chunk = {
# "function_call": {
# "name": function_call.function.name,

View File

@@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.responses import StreamingResponse
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.letta_message import (
@@ -254,6 +255,19 @@ def get_agent_messages(
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
limit: int = Query(10, description="Maximum number of messages to retrieve."),
msg_object: bool = Query(False, description="If true, returns Message objects. If false, return LettaMessage objects."),
# Flags to support the use of AssistantMessage message types
use_assistant_message: bool = Query(
False,
description="[Only applicable if msg_object is False] If true, returns AssistantMessage objects when the agent calls a designated message tool. If false, return FunctionCallMessage objects for all tool calls.",
),
assistant_message_function_name: str = Query(
DEFAULT_MESSAGE_TOOL,
description="[Only applicable if use_assistant_message is True] The name of the designated message tool.",
),
assistant_message_function_kwarg: str = Query(
DEFAULT_MESSAGE_TOOL_KWARG,
description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
),
):
"""
Retrieve message history for an agent.
@@ -267,6 +281,9 @@ def get_agent_messages(
limit=limit,
reverse=True,
return_message_object=msg_object,
use_assistant_message=use_assistant_message,
assistant_message_function_name=assistant_message_function_name,
assistant_message_function_kwarg=assistant_message_function_kwarg,
)
@@ -310,6 +327,10 @@ async def send_message(
stream_steps=request.stream_steps,
stream_tokens=request.stream_tokens,
return_message_object=request.return_message_object,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_function_name=request.assistant_message_function_name,
assistant_message_function_kwarg=request.assistant_message_function_kwarg,
)
@@ -322,12 +343,17 @@ async def send_message_to_agent(
message: str,
stream_steps: bool,
stream_tokens: bool,
return_message_object: bool, # Should be True for Python Client, False for REST API
chat_completion_mode: Optional[bool] = False,
timestamp: Optional[datetime] = None,
# related to whether or not we return `LettaMessage`s or `Message`s
return_message_object: bool, # Should be True for Python Client, False for REST API
chat_completion_mode: bool = False,
timestamp: Optional[datetime] = None,
# Support for AssistantMessage
use_assistant_message: bool = False,
assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
) -> Union[StreamingResponse, LettaResponse]:
"""Split off into a separate function so that it can be imported in the /chat/completion proxy."""
# TODO: @charles is this the correct way to handle?
include_final_message = True
@@ -368,6 +394,11 @@ async def send_message_to_agent(
# streaming_interface.allow_assistant_message = stream
# streaming_interface.function_call_legacy_mode = stream
# Allow AssistantMessage is desired by client
streaming_interface.use_assistant_message = use_assistant_message
streaming_interface.assistant_message_function_name = assistant_message_function_name
streaming_interface.assistant_message_function_kwarg = assistant_message_function_kwarg
# Offload the synchronous message_func to a separate thread
streaming_interface.stream_start()
task = asyncio.create_task(
@@ -408,6 +439,7 @@ async def send_message_to_agent(
message_ids = [m.id for m in filtered_stream]
message_ids = deduplicate(message_ids)
message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
message_objs = [m for m in message_objs if m is not None]
return LettaResponse(messages=message_objs, usage=usage)
else:
return LettaResponse(messages=filtered_stream, usage=usage)

View File

@@ -1262,6 +1262,9 @@ class SyncServer(Server):
order: Optional[str] = "asc",
reverse: Optional[bool] = False,
return_message_object: bool = True,
use_assistant_message: bool = False,
assistant_message_function_name: str = constants.DEFAULT_MESSAGE_TOOL,
assistant_message_function_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
) -> Union[List[Message], List[LettaMessage]]:
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -1281,9 +1284,25 @@ class SyncServer(Server):
if not return_message_object:
# If we're GETing messages in reverse, we need to reverse the inner list (generated by to_letta_message)
if reverse:
records = [msg for m in records for msg in m.to_letta_message()[::-1]]
records = [
msg
for m in records
for msg in m.to_letta_message(
assistant_message=use_assistant_message,
assistant_message_function_name=assistant_message_function_name,
assistant_message_function_kwarg=assistant_message_function_kwarg,
)[::-1]
]
else:
records = [msg for m in records for msg in m.to_letta_message()]
records = [
msg
for m in records
for msg in m.to_letta_message(
assistant_message=use_assistant_message,
assistant_message_function_name=assistant_message_function_name,
assistant_message_function_kwarg=assistant_message_function_kwarg,
)
]
return records

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.1.7"
version = "0.4.1"
packages = [
{include = "letta"}
]

View File

@@ -255,6 +255,34 @@ def check_agent_archival_memory_retrieval(filename: str) -> LettaResponse:
return response
def check_agent_edit_core_memory(filename: str) -> LettaResponse:
    """
    Verify that the agent can modify its own core (human) memory.

    Informs the agent via `user_message` that the human's name changed,
    then asks it to repeat the name back and asserts the reply uses the
    updated value.
    """
    old_name = "AngryAardvark"
    new_name = "BananaBoy"

    # Fresh client, and remove any leftover agent from a prior run
    client = create_client()
    cleanup(client=client, agent_uuid=agent_uuid)

    # Agent starts out believing the human is `old_name`
    agent_state = setup_agent(client, filename, memory_human_str=f"My name is {old_name}")

    # Announce the name change, then ask the agent to repeat the name
    client.user_message(agent_id=agent_state.id, message=f"Actually, my name changed. It is now {new_name}")
    response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.")

    # Basic checks
    assert_sanity_checks(response)
    # The reply must contain the updated name, not the original one
    assert_invoked_send_message_with_keyword(response.messages, new_name)
    # Inner monologue should still be present and well-formed
    assert_inner_monologue_is_present_and_valid(response.messages)
    return response
def run_embedding_endpoint(filename):
# load JSON file
config_data = json.load(open(filename, "r"))

View File

@@ -2,6 +2,7 @@ import os
from tests.helpers.endpoints_helper import (
check_agent_archival_memory_retrieval,
check_agent_edit_core_memory,
check_agent_recall_chat_memory,
check_agent_uses_external_tool,
check_first_response_is_valid_for_llm_endpoint,
@@ -53,6 +54,13 @@ def test_openai_gpt_4_archival_memory_retrieval():
print(f"Got successful response from client: \n\n{response}")
def test_openai_gpt_4_edit_core_memory():
    """GPT-4 should be able to edit its core memory blocks on request."""
    config_path = os.path.join(llm_config_dir, "gpt-4.json")
    response = check_agent_edit_core_memory(config_path)
    # Surface the full response in the test log for debugging
    print(f"Got successful response from client: \n\n{response}")
def test_embedding_endpoint_openai():
    """Smoke-test the OpenAI text-embedding-ada-002 endpoint config."""
    config_path = os.path.join(embedding_config_dir, "text-embedding-ada-002.json")
    run_embedding_endpoint(config_path)
@@ -95,3 +103,12 @@ def test_embedding_endpoint_ollama():
def test_llm_endpoint_anthropic():
    """Anthropic endpoint should return a valid first message."""
    filename = os.path.join(llm_config_dir, "anthropic.json")
    # Run the check exactly once — the duplicated back-to-back call
    # doubled the (slow, networked) test cost with no extra coverage.
    check_first_response_is_valid_for_llm_endpoint(filename)
# ======================================================================================================================
# GROQ TESTS
# ======================================================================================================================
def test_llm_endpoint_groq():
    """Groq endpoint should return a valid first message."""
    config_path = os.path.join(llm_config_dir, "groq.json")
    check_first_response_is_valid_for_llm_endpoint(config_path)

View File

@@ -1,16 +1,18 @@
import json
import uuid
import warnings
import pytest
import letta.utils as utils
from letta.constants import BASE_TOOLS
from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.enums import MessageRole
utils.DEBUG = True
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCallMessage,
FunctionReturn,
InternalMonologue,
@@ -236,7 +238,14 @@ def test_get_archival_memory(server, user_id, agent_id):
assert len(passage_none) == 0
def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False):
def _test_get_messages_letta_format(
server,
user_id,
agent_id,
reverse=False,
# flag that determines whether or not to use AssistantMessage, or just FunctionCallMessage universally
use_assistant_message=False,
):
"""Reverse is off by default, the GET goes in chronological order"""
messages = server.get_agent_recall_cursor(
@@ -244,6 +253,8 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False):
agent_id=agent_id,
limit=1000,
reverse=reverse,
return_message_object=True,
use_assistant_message=use_assistant_message,
)
# messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
assert all(isinstance(m, Message) for m in messages)
@@ -254,6 +265,7 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False):
limit=1000,
reverse=reverse,
return_message_object=False,
use_assistant_message=use_assistant_message,
)
# letta_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False)
assert all(isinstance(m, LettaMessage) for m in letta_messages)
@@ -316,9 +328,30 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False):
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
for tool_call in message.tool_calls:
assert isinstance(letta_message, FunctionCallMessage)
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
# Try to parse the tool call args
try:
func_args = json.loads(tool_call.function.arguments)
except:
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
func_args = {}
# If assistant_message is True, we expect FunctionCallMessage to be AssistantMessage if the tool call is the assistant message tool
if (
use_assistant_message
and tool_call.function.name == DEFAULT_MESSAGE_TOOL
and DEFAULT_MESSAGE_TOOL_KWARG in func_args
):
assert isinstance(letta_message, AssistantMessage)
assert func_args[DEFAULT_MESSAGE_TOOL_KWARG] == letta_message.assistant_message
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
# Otherwise, we expect even a "send_message" tool call to be a FunctionCallMessage
else:
assert isinstance(letta_message, FunctionCallMessage)
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
if message.text is not None:
assert isinstance(letta_message, InternalMonologue)
@@ -341,11 +374,32 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False):
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
for tool_call in message.tool_calls:
assert isinstance(letta_message, FunctionCallMessage)
assert tool_call.function.name == letta_message.function_call.name
assert tool_call.function.arguments == letta_message.function_call.arguments
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
# Try to parse the tool call args
try:
func_args = json.loads(tool_call.function.arguments)
except:
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
func_args = {}
# If assistant_message is True, we expect FunctionCallMessage to be AssistantMessage if the tool call is the assistant message tool
if (
use_assistant_message
and tool_call.function.name == DEFAULT_MESSAGE_TOOL
and DEFAULT_MESSAGE_TOOL_KWARG in func_args
):
assert isinstance(letta_message, AssistantMessage)
assert func_args[DEFAULT_MESSAGE_TOOL_KWARG] == letta_message.assistant_message
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
# Otherwise, we expect even a "send_message" tool call to be a FunctionCallMessage
else:
assert isinstance(letta_message, FunctionCallMessage)
assert tool_call.function.name == letta_message.function_call.name
assert tool_call.function.arguments == letta_message.function_call.arguments
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
elif message.role == MessageRole.user:
print(f"i={i}, M=user, MM={type(letta_message)}")
@@ -374,8 +428,9 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False):
def test_get_messages_letta_format(server, user_id, agent_id):
    """
    Exercise message retrieval in every reverse/assistant-message combination.

    The nested loop supersedes the older explicit reverse=False/True calls;
    keeping both would run the two reverse-only cases twice.
    """
    for reverse in [False, True]:
        for assistant_message in [False, True]:
            _test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse, use_assistant_message=assistant_message)
def test_agent_rethink_rewrite_retry(server, user_id, agent_id):