From 4fd04c63fe3b21883f8b6eb3fa04b02433cdfcbe Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Mon, 6 Jan 2025 08:46:53 -1000 Subject: [PATCH] chore: Merge OSS (#512) --- letta/__init__.py | 2 +- letta/agent.py | 7 ++++--- letta/cli/cli_config.py | 1 - letta/llm_api/helpers.py | 10 ++++++++-- letta/llm_api/llm_api_tools.py | 11 ++++++++++- letta/schemas/letta_response.py | 2 +- letta/schemas/memory.py | 3 +++ letta/services/agent_manager.py | 2 +- pyproject.toml | 2 +- tests/helpers/endpoints_helper.py | 6 +++++- tests/integration_test_agent_tool_graph.py | 6 +++--- tests/test_server.py | 1 + 12 files changed, 38 insertions(+), 15 deletions(-) diff --git a/letta/__init__.py b/letta/__init__.py index 46390abe..e67194c6 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.6" +__version__ = "0.6.7" # import clients from letta.client.client import LocalClient, RESTClient, create_client diff --git a/letta/agent.py b/letta/agent.py index f61ace90..483d3cb8 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -224,8 +224,8 @@ class Agent(BaseAgent): ) function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" - - self.update_memory_if_change(updated_agent_state.memory) + if updated_agent_state is not None: + self.update_memory_if_change(updated_agent_state.memory) except Exception as e: # Need to catch error here, or else trunction wont happen # TODO: modify to function execution error @@ -238,7 +238,7 @@ class Agent(BaseAgent): def _get_ai_reply( self, message_sequence: List[Message], - function_call: str = "auto", + function_call: Optional[str] = None, first_message: bool = False, stream: bool = False, # TODO move to config? empty_response_retry_limit: int = 3, @@ -1029,6 +1029,7 @@ class Agent(BaseAgent): num_archival_memory=agent_manager_passage_size, num_recall_memory=message_manager_size, num_tokens_external_memory_summary=num_tokens_external_memory_summary, + external_memory_summary=external_memory_summary, # top-level information context_window_size_max=self.agent_state.llm_config.context_window, context_window_size_current=num_tokens_used_total, diff --git a/letta/cli/cli_config.py b/letta/cli/cli_config.py index 8278d553..87e43567 100644 --- a/letta/cli/cli_config.py +++ b/letta/cli/cli_config.py @@ -60,7 +60,6 @@ def list(arg: Annotated[ListChoice, typer.Argument]): table.field_names = ["Name", "Text"] for human in client.list_humans(): table.add_row([human.template_name, human.value.replace("\n", "")[:100]]) - print(table) elif arg == ListChoice.personas: """List all personas""" table.field_names = ["Name", "Text"] diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py index 1244b6ff..7c99bbcd 100644 --- a/letta/llm_api/helpers.py +++ b/letta/llm_api/helpers.py @@ -250,6 +250,8 @@ def unpack_all_inner_thoughts_from_kwargs( def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) -> Choice: message = choice.message + rewritten_choice = choice # inner thoughts unpacked out of the function + if message.role == "assistant" and message.tool_calls and len(message.tool_calls) >= 1: if len(message.tool_calls) > 1: warnings.warn(f"Unpacking inner thoughts from more than one tool call ({len(message.tool_calls)}) is not supported") @@ -271,14 +273,18 @@ def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) - warnings.warn(f"Overwriting existing inner monologue ({new_choice.message.content}) with kwarg ({inner_thoughts})") new_choice.message.content = inner_thoughts - return new_choice + # update the choice object + rewritten_choice = new_choice else: warnings.warn(f"Did not find inner thoughts in tool call: {str(tool_call)}") - return choice except json.JSONDecodeError as e: warnings.warn(f"Failed to strip inner thoughts from kwargs: {e}") raise e + else: + warnings.warn(f"Did not find tool call in message: {str(message)}") + + return rewritten_choice def is_context_overflow_error(exception: Union[requests.exceptions.RequestException, Exception]) -> bool: diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 030d7375..d83e8699 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -94,7 +94,7 @@ def create( user_id: Optional[str] = None, # option UUID to associate request with functions: Optional[list] = None, functions_python: Optional[dict] = None, - function_call: str = "auto", + function_call: Optional[str] = None, # see: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice # hint first_message: bool = False, force_tool_call: Optional[str] = None, # Force a specific tool to be called @@ -132,10 +132,19 @@ def create( # openai if llm_config.model_endpoint_type == "openai": + if model_settings.openai_api_key is None and llm_config.model_endpoint == "https://api.openai.com/v1": # only is a problem if we are *not* using an openai proxy raise LettaConfigurationError(message="OpenAI key is missing from letta config file", missing_fields=["openai_api_key"]) + if function_call is None and functions is not None and len(functions) > 0: + # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice + # TODO(matt) move into LLMConfig + if llm_config.model_endpoint == "https://inference.memgpt.ai": + function_call = "auto" # TODO change to "required" once proxy supports it + else: + function_call = "required" + data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens) if stream: # Client requested token streaming data.stream = True diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py index d7019280..fc969d66 100644 --- a/letta/schemas/letta_response.py +++ b/letta/schemas/letta_response.py @@ -66,7 +66,7 @@ class LettaResponse(BaseModel): return f'
{html.escape(msg.function_call.name)}({args})
' elif msg.message_type == "tool_call_message": args = format_json(msg.tool_call.arguments) - return f'
{html.escape(msg.function_call.name)}({args})
' + return f'
{html.escape(msg.tool_call.name)}({args})
' elif msg.message_type == "function_return": return_value = format_json(msg.function_return) # return f'
Status: {html.escape(msg.status)}
{return_value}
' diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index 797eac57..ab877949 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -30,6 +30,9 @@ class ContextWindowOverview(BaseModel): num_tokens_external_memory_summary: int = Field( ..., description="The number of tokens in the external memory summary (archival + recall metadata)." ) + external_memory_summary: str = Field( + ..., description="The metadata summary of the external memory sources (archival + recall metadata)." + ) # context window breakdown (in tokens) # this should all add up to context_window_size_current diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 5c92f59e..92044a0c 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -388,7 +388,7 @@ class AgentManager: curr_memory_str = agent_state.memory.compile() if curr_memory_str in curr_system_message_openai["content"] and not force: # NOTE: could this cause issues if a block is removed? (substring match would still work) - logger.info( + logger.debug( f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild" ) return agent_state diff --git a/pyproject.toml b/pyproject.toml index 074bd256..8a487c1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.6.6" +version = "0.6.7" packages = [ {include = "letta"} ] diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 82b2ed1c..80014903 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -117,7 +117,11 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet choice = response.choices[0] # Ensure that the first message returns a "send_message" - validator_func = lambda function_call: function_call.name == "send_message" or function_call.name == "archival_memory_search" + validator_func = ( + lambda function_call: function_call.name == "send_message" + or function_call.name == "archival_memory_search" + or function_call.name == "core_memory_append" + ) assert_contains_valid_function_call(choice.message, validator_func) # Assert that the message has an inner monologue diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 64486ad3..654d4a9e 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -38,7 +38,7 @@ def second_secret_word(prev_secret_word: str): prev_secret_word (str): The secret word retrieved from calling first_secret_word. """ if prev_secret_word != "v0iq020i0g": - raise RuntimeError(f"Expected secret {"v0iq020i0g"}, got {prev_secret_word}") + raise RuntimeError(f"Expected secret {'v0iq020i0g'}, got {prev_secret_word}") return "4rwp2b4gxq" @@ -51,7 +51,7 @@ def third_secret_word(prev_secret_word: str): prev_secret_word (str): The secret word retrieved from calling second_secret_word. """ if prev_secret_word != "4rwp2b4gxq": - raise RuntimeError(f"Expected secret {"4rwp2b4gxq"}, got {prev_secret_word}") + raise RuntimeError(f'Expected secret "4rwp2b4gxq", got {prev_secret_word}') return "hj2hwibbqm" @@ -64,7 +64,7 @@ def fourth_secret_word(prev_secret_word: str): prev_secret_word (str): The secret word retrieved from calling third_secret_word. """ if prev_secret_word != "hj2hwibbqm": - raise RuntimeError(f"Expected secret {"hj2hwibbqm"}, got {prev_secret_word}") + raise RuntimeError(f"Expected secret {'hj2hwibbqm'}, got {prev_secret_word}") return "banana" diff --git a/tests/test_server.py b/tests/test_server.py index 2f205a7e..fe0fcdc4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -508,6 +508,7 @@ def test_get_context_window_overview(server: SyncServer, user, agent_id): assert overview.num_archival_memory is not None assert overview.num_recall_memory is not None assert overview.num_tokens_external_memory_summary is not None + assert overview.external_memory_summary is not None assert overview.num_tokens_system is not None assert overview.system_prompt is not None assert overview.num_tokens_core_memory is not None