diff --git a/examples/composio_tool_usage.py b/examples/composio_tool_usage.py deleted file mode 100644 index 89c662b0..00000000 --- a/examples/composio_tool_usage.py +++ /dev/null @@ -1,92 +0,0 @@ -import json -import os -import uuid - -from letta import create_client -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory -from letta.schemas.sandbox_config import SandboxType -from letta.services.sandbox_config_manager import SandboxConfigManager - -""" -Setup here. -""" -# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) -client = create_client() -client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) -client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - -# Generate uuid for agent name for this example -namespace = uuid.NAMESPACE_DNS -agent_uuid = str(uuid.uuid5(namespace, "letta-composio-tooling-example")) - -# Clear all agents -for agent_state in client.list_agents(): - if agent_state.name == agent_uuid: - client.delete_agent(agent_id=agent_state.id) - print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") - - -# Add sandbox env -manager = SandboxConfigManager() -# Ensure you have e2b key set -sandbox_config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=client.user) -manager.create_sandbox_env_var( - SandboxEnvironmentVariableCreate(key="COMPOSIO_API_KEY", value=os.environ.get("COMPOSIO_API_KEY")), - sandbox_config_id=sandbox_config.id, - actor=client.user, -) - - -""" -This example show how you can add Composio tools . - -First, make sure you have Composio and some of the extras downloaded. -``` -poetry install --extras "external-tools" -``` -then setup letta with `letta configure`. - -Aditionally, this example stars a Github repo on your behalf. You will need to configure Composio in your environment. -``` -composio login -composio add github -``` - -Last updated Oct 2, 2024. Please check `composio` documentation for any composio related issues. -""" - - -def main(): - from composio import Action - - # Add the composio tool - tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - - persona = f""" - My name is Letta. - - I am a personal assistant that helps star repos on Github. It is my job to correctly input the owner and repo to the {tool.name} tool based on the user's request. - - Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts. - """ - - # Create an agent - agent = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[tool.id]) - print(f"Created agent: {agent.name} with ID {str(agent.id)}") - - # Send a message to the agent - send_message_response = client.user_message(agent_id=agent.id, message="Star a repo composio with owner composiohq on GitHub") - for message in send_message_response.messages: - response_json = json.dumps(message.model_dump(), indent=4) - print(f"{response_json}\n") - - # Delete agent - client.delete_agent(agent_id=agent.id) - print(f"Deleted agent: {agent.name} with ID {str(agent.id)}") - - -if __name__ == "__main__": - main() diff --git a/examples/langchain_tool_usage.py b/examples/langchain_tool_usage.py deleted file mode 100644 index 3ce4eb39..00000000 --- a/examples/langchain_tool_usage.py +++ /dev/null @@ -1,87 +0,0 @@ -import json -import uuid - -from letta import create_client -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory - -""" -This example show how you can add LangChain tools . - -First, make sure you have LangChain and some of the extras downloaded. -For this specific example, you will need `wikipedia` installed. -``` -poetry install --extras "external-tools" -``` -then setup letta with `letta configure`. -""" - - -def main(): - from langchain_community.tools import WikipediaQueryRun - from langchain_community.utilities import WikipediaAPIWrapper - - api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=500) - langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper) - - # Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) - client = create_client() - client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - - # create tool - # Note the additional_imports_module_attr_map - # We need to pass in a map of all the additional imports necessary to run this tool - # Because an object of type WikipediaAPIWrapper is passed into WikipediaQueryRun to initialize langchain_tool, - # We need to also import WikipediaAPIWrapper - # The map is a mapping of the module name to the attribute name - # langchain_community.utilities.WikipediaAPIWrapper - wikipedia_query_tool = client.load_langchain_tool( - langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"} - ) - tool_name = wikipedia_query_tool.name - - # Confirm that the tool is in - tools = client.list_tools() - assert wikipedia_query_tool.name in [t.name for t in tools] - - # Generate uuid for agent name for this example - namespace = uuid.NAMESPACE_DNS - agent_uuid = str(uuid.uuid5(namespace, "letta-langchain-tooling-example")) - - # Clear all agents - for agent_state in client.list_agents(): - if agent_state.name == agent_uuid: - client.delete_agent(agent_id=agent_state.id) - print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") - - # google search persona - persona = f""" - - My name is Letta. - - I am a personal assistant who answers a user's questions using wikipedia searches. When a user asks me a question, I will use a tool called {tool_name} which will search Wikipedia and return a Wikipedia page about the topic. It is my job to construct the best query to input into {tool_name} based on the user's question. - - Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts. - """ - - # Create an agent - agent_state = client.create_agent( - name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[wikipedia_query_tool.id] - ) - print(f"Created agent: {agent_state.name} with ID {str(agent_state.id)}") - - # Send a message to the agent - send_message_response = client.user_message(agent_id=agent_state.id, message="Tell me a fun fact about Albert Einstein!") - for message in send_message_response.messages: - response_json = json.dumps(message.model_dump(), indent=4) - print(f"{response_json}\n") - - # Delete agent - client.delete_agent(agent_id=agent_state.id) - print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") - - -if __name__ == "__main__": - main() diff --git a/examples/notebooks/Multi-agent recruiting workflow.ipynb b/examples/notebooks/Multi-agent recruiting workflow.ipynb deleted file mode 100644 index 0b33ca06..00000000 --- a/examples/notebooks/Multi-agent recruiting workflow.ipynb +++ /dev/null @@ -1,884 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "cac06555-9ce8-4f01-bbef-3f8407f4b54d", - "metadata": {}, - "source": [ - "# Multi-agent recruiting workflow \n", - "> Make sure you run the Letta server before running this example using `letta server`\n", - "\n", - "Last tested with letta version `0.5.3`" - ] - }, - { - "cell_type": "markdown", - "id": "aad3a8cc-d17a-4da1-b621-ecc93c9e2106", - "metadata": {}, - "source": [ - "## Section 0: Setup a MemGPT client " - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "7ccd43f2-164b-4d25-8465-894a3bb54c4b", - "metadata": {}, - "outputs": [], - "source": [ - "from letta_client import CreateBlock, Letta, MessageCreate\n", - "\n", - "client = Letta(base_url=\"http://localhost:8283\")" - ] - }, - { - "cell_type": "markdown", - "id": "99a61da5-f069-4538-a548-c7d0f7a70227", - "metadata": {}, - "source": [ - "## Section 1: Shared Memory Block \n", - "Each agent will have both its own memory, and shared memory. The shared memory will contain information about the organization that the agents are all a part of. If one agent updates this memory, the changes will be propaged to the memory of all the other agents. " - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "7770600d-5e83-4498-acf1-05f5bea216c3", - "metadata": {}, - "outputs": [], - "source": [ - "org_description = \"The company is called AgentOS \" \\\n", - "+ \"and is building AI tools to make it easier to create \" \\\n", - "+ \"and deploy LLM agents.\"\n", - "\n", - "org_block = client.blocks.create(\n", - " label=\"company\",\n", - " value=org_description,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "6c3d3a55-870a-4ff0-81c0-4072f783a940", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id=None, id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "org_block" - ] - }, - { - "cell_type": "markdown", - "id": "8448df7b-c321-4d90-ba52-003930a513cb", - "metadata": {}, - "source": [ - "## Section 2: Orchestrating Multiple Agents \n", - "We'll implement a recruiting workflow that involves evaluating an candidate, then if the candidate is a good fit, writing a personalized email on the human's behalf. Since this task involves multiple stages, sometimes breaking the task down to multiple agents can improve performance (though this is not always the case). We will break down the task into: \n", - "\n", - "1. `eval_agent`: This agent is responsible for evaluating candidates based on their resume\n", - "2. `outreach_agent`: This agent is responsible for writing emails to strong candidates\n", - "3. `recruiter_agent`: This agent is responsible for generating leads from a database \n", - "\n", - "Much like humans, these agents will communicate by sending each other messages. We can do this by giving agents that need to communicate with other agents access to a tool that allows them to message other agents. " - ] - }, - { - "cell_type": "markdown", - "id": "a065082a-d865-483c-b721-43c5a4d51afe", - "metadata": {}, - "source": [ - "#### Evaluator Agent\n", - "This agent will have tools to: \n", - "* Read a resume \n", - "* Submit a candidate for outreach (which sends the candidate information to the `outreach_agent`)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "c00232c5-4c37-436c-8ea4-602a31bd84fa", - "metadata": {}, - "outputs": [], - "source": [ - "def read_resume(self, name: str): \n", - " \"\"\"\n", - " Read the resume data for a candidate given the name\n", - "\n", - " Args: \n", - " name (str): Candidate name \n", - "\n", - " Returns: \n", - " resume_data (str): Candidate's resume data \n", - " \"\"\"\n", - " import os\n", - " filepath = os.path.join(\"data\", \"resumes\", name.lower().replace(\" \", \"_\") + \".txt\")\n", - " return open(filepath).read()\n", - "\n", - "def submit_evaluation(self, candidate_name: str, reach_out: bool, resume: str, justification: str): \n", - " \"\"\"\n", - " Submit a candidate for outreach. \n", - "\n", - " Args: \n", - " candidate_name (str): The name of the candidate\n", - " reach_out (bool): Whether to reach out to the candidate\n", - " resume (str): The text representation of the candidate's resume \n", - " justification (str): Justification for reaching out or not\n", - " \"\"\"\n", - " from letta import create_client \n", - " client = create_client()\n", - " message = \"Reach out to the following candidate. \" \\\n", - " + f\"Name: {candidate_name}\\n\" \\\n", - " + f\"Resume Data: {resume}\\n\" \\\n", - " + f\"Justification: {justification}\"\n", - " # NOTE: we will define this agent later \n", - " if reach_out:\n", - " response = client.send_message(\n", - " agent_name=\"outreach_agent\", \n", - " role=\"user\", \n", - " message=message\n", - " ) \n", - " else: \n", - " print(f\"Candidate {candidate_name} is rejected: {justification}\")\n", - "\n", - "# TODO: add an archival andidate tool (provide justification) \n", - "\n", - "read_resume_tool = client.tools.upsert_from_function(func=read_resume) \n", - "submit_evaluation_tool = client.tools.upsert_from_function(func=submit_evaluation)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "12482994-03f4-4dda-8ea2-6492ec28f392", - "metadata": {}, - "outputs": [], - "source": [ - "skills = \"Front-end (React, Typescript), software engineering \" \\\n", - "+ \"(ideally Python), and experience with LLMs.\"\n", - "eval_persona = f\"You are responsible to finding good recruiting \" \\\n", - "+ \"candidates, for the company description. \" \\\n", - "+ f\"Ideal canddiates have skills: {skills}. \" \\\n", - "+ \"Submit your candidate evaluation with the submit_evaluation tool. \"\n", - "\n", - "eval_agent = client.agents.create(\n", - " name=\"eval_agent\", \n", - " memory_blocks=[\n", - " CreateBlock(\n", - " label=\"persona\",\n", - " value=eval_persona,\n", - " ),\n", - " ],\n", - " block_ids=[org_block.id],\n", - " tool_ids=[read_resume_tool.id, submit_evaluation_tool.id]\n", - " model=\"openai/gpt-4\",\n", - " embedding=\"openai/text-embedding-ada-002\",\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "id": "37c2d0be-b980-426f-ab24-1feaa8ed90ef", - "metadata": {}, - "source": [ - "#### Outreach agent \n", - "This agent will email candidates with customized emails. Since sending emails is a bit complicated, we'll just pretend we sent an email by printing it in the tool call. " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "24e8942f-5b0e-4490-ac5f-f9e1f3178627", - "metadata": {}, - "outputs": [], - "source": [ - "def email_candidate(self, content: str): \n", - " \"\"\"\n", - " Send an email\n", - "\n", - " Args: \n", - " content (str): Content of the email \n", - " \"\"\"\n", - " print(\"Pretend to email:\", content)\n", - " return\n", - "\n", - "email_candidate_tool = client.tools.upsert_from_function(func=email_candidate)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "87416e00-c7a0-4420-be71-e2f5a6404428", - "metadata": {}, - "outputs": [], - "source": [ - "outreach_persona = \"You are responsible for sending outbound emails \" \\\n", - "+ \"on behalf of a company with the send_emails tool to \" \\\n", - "+ \"potential candidates. \" \\\n", - "+ \"If possible, make sure to personalize the email by appealing \" \\\n", - "+ \"to the recipient with details about the company. \" \\\n", - "+ \"You position is `Head Recruiter`, and you go by the name Bob, with contact info bob@gmail.com. \" \\\n", - "+ \"\"\"\n", - "Follow this email template: \n", - "\n", - "Hi , \n", - "\n", - " \n", - "\n", - "Best, \n", - " \n", - " \n", - "\"\"\"\n", - " \n", - "outreach_agent = client.agents.create(\n", - " name=\"outreach_agent\", \n", - " memory_blocks=[\n", - " CreateBlock(\n", - " label=\"persona\",\n", - " value=outreach_persona,\n", - " ),\n", - " ],\n", - " block_ids=[org_block.id],\n", - " tool_ids=[email_candidate_tool.id]\n", - " model=\"openai/gpt-4\",\n", - " embedding=\"openai/text-embedding-ada-002\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "f69d38da-807e-4bb1-8adb-f715b24f1c34", - "metadata": {}, - "source": [ - "Next, we'll send a message from the user telling the `leadgen_agent` to evaluate a given candidate: " - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "f09ab5bd-e158-42ee-9cce-43f254c4d2b0", - "metadata": {}, - "outputs": [], - "source": [ - "response = client.agents.messages.send(\n", - " agent_id=eval_agent.id,\n", - " messages=[\n", - " MessageCreate(\n", - " role=\"user\",\n", - " content=\"Candidate: Tony Stark\",\n", - " )\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "cd8f1a1e-21eb-47ae-9eed-b1d3668752ff", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - " \n", - "
\n", - "
INTERNAL MONOLOGUE
\n", - "
Checking the resume for Tony Stark to evaluate if he fits the bill for our needs.
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION CALL
\n", - "
read_resume({
  \"name\": \"Tony Stark\",
  \"request_heartbeat\"
: true
})
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION RETURN
\n", - "
{
  \"status\": \"Failed\",
  \"message\"
: \"Error calling function read_resume: [Errno 2] No such file or directory: 'data/resumes/tony_stark.txt'\",
  \"time\"
: \"2024-11-13 05:51:26 PM PST-0800\"
}
\n", - "
\n", - " \n", - "
\n", - "
INTERNAL MONOLOGUE
\n", - "
I couldn't retrieve Tony's resume. Need to handle this carefully to keep the conversation flowing.
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION CALL
\n", - "
send_message({
  \"message\": \"It looks like I'm having trouble accessing Tony Stark's resume at the moment. Can you provide more details about his qualifications?\"
})
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION RETURN
\n", - "
{
  \"status\": \"OK\",
  \"message\"
: \"None\",
  \"time\"
: \"2024-11-13 05:51:28 PM PST-0800\"
}
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
USAGE STATISTICS
\n", - "
{
  \"completion_tokens\": 103,
  \"prompt_tokens\": 4999,
  \"total_tokens\": 5102,
  \"step_count\": 2
}
\n", - "
\n", - "
\n", - " " - ], - "text/plain": [ - "LettaResponse(messages=[InternalMonologue(id='message-97a1ae82-f8f3-419f-94c4-263112dbc10b', date=datetime.datetime(2024, 11, 14, 1, 51, 26, 799617, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='Checking the resume for Tony Stark to evaluate if he fits the bill for our needs.'), FunctionCallMessage(id='message-97a1ae82-f8f3-419f-94c4-263112dbc10b', date=datetime.datetime(2024, 11, 14, 1, 51, 26, 799617, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='read_resume', arguments='{\\n \"name\": \"Tony Stark\",\\n \"request_heartbeat\": true\\n}', function_call_id='call_wOsiHlU3551JaApHKP7rK4Rt')), FunctionReturn(id='message-97a2b57e-40c6-4f06-a307-a0e3a00717ce', date=datetime.datetime(2024, 11, 14, 1, 51, 26, 803505, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"Failed\",\\n \"message\": \"Error calling function read_resume: [Errno 2] No such file or directory: \\'data/resumes/tony_stark.txt\\'\",\\n \"time\": \"2024-11-13 05:51:26 PM PST-0800\"\\n}', status='error', function_call_id='call_wOsiHlU3551JaApHKP7rK4Rt'), InternalMonologue(id='message-8e249aea-27ce-4788-b3e0-ac4c8401bc93', date=datetime.datetime(2024, 11, 14, 1, 51, 28, 360676, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue=\"I couldn't retrieve Tony's resume. Need to handle this carefully to keep the conversation flowing.\"), FunctionCallMessage(id='message-8e249aea-27ce-4788-b3e0-ac4c8401bc93', date=datetime.datetime(2024, 11, 14, 1, 51, 28, 360676, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='send_message', arguments='{\\n \"message\": \"It looks like I\\'m having trouble accessing Tony Stark\\'s resume at the moment. Can you provide more details about his qualifications?\"\\n}', function_call_id='call_1DoFBhOsP9OCpdPQjUfBcKjw')), FunctionReturn(id='message-5600e8e7-6c6f-482a-8594-a0483ef523a2', date=datetime.datetime(2024, 11, 14, 1, 51, 28, 361921, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:51:28 PM PST-0800\"\\n}', status='success', function_call_id='call_1DoFBhOsP9OCpdPQjUfBcKjw')], usage=LettaUsageStatistics(completion_tokens=103, prompt_tokens=4999, total_tokens=5102, step_count=2))" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "response" - ] - }, - { - "cell_type": "markdown", - "id": "67069247-e603-439c-b2df-9176c4eba957", - "metadata": {}, - "source": [ - "#### Providing feedback to agents \n", - "Since MemGPT agents are persisted, we can provide feedback to agents that is used in future agent executions if we want to modify the future behavior. " - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "19c57d54-a1fe-4244-b765-b996ba9a4788", - "metadata": {}, - "outputs": [], - "source": [ - "feedback = \"Our company pivoted to foundation model training\"\n", - "response = client.agents.messages.send(\n", - " agent_id=eval_agent.id,\n", - " messages=[\n", - " MessageCreate(\n", - " role=\"user\",\n", - " content=feedback,\n", - " )\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "036b973f-209a-4ad9-90e7-fc827b5d92c7", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "feedback = \"The company is also renamed to FoundationAI\"\n", - "response = client.agents.messages.send(\n", - " agent_id=eval_agent.id,\n", - " messages=[\n", - " MessageCreate(\n", - " role=\"user\",\n", - " content=feedback,\n", - " )\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "5d7a7633-35a3-4e41-b44a-be71067dd32a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - " \n", - "
\n", - "
INTERNAL MONOLOGUE
\n", - "
Updating the company name to reflect the rebranding. This is important for future candidate evaluations.
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION CALL
\n", - "
core_memory_replace({
  \"label\": \"company\",
  \"old_content\"
: \"The company has pivoted to foundation model training.\",
  \"new_content\"
: \"The company is called FoundationAI and has pivoted to foundation model training.\",
  \"request_heartbeat\"
: true
})
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION RETURN
\n", - "
{
  \"status\": \"OK\",
  \"message\"
: \"None\",
  \"time\"
: \"2024-11-13 05:51:34 PM PST-0800\"
}
\n", - "
\n", - " \n", - "
\n", - "
INTERNAL MONOLOGUE
\n", - "
Now I have the updated company info, time to check in on Tony.
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION CALL
\n", - "
send_message({
  \"message\": \"Got it, the new name is FoundationAI! What about Tony Stark's background catches your eye for this role? Any particular insights on his skills in front-end development or LLMs?\"
})
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION RETURN
\n", - "
{
  \"status\": \"OK\",
  \"message\"
: \"None\",
  \"time\"
: \"2024-11-13 05:51:35 PM PST-0800\"
}
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
USAGE STATISTICS
\n", - "
{
  \"completion_tokens\": 146,
  \"prompt_tokens\": 6372,
  \"total_tokens\": 6518,
  \"step_count\": 2
}
\n", - "
\n", - "
\n", - " " - ], - "text/plain": [ - "LettaResponse(messages=[InternalMonologue(id='message-0adccea9-4b96-4cbb-b5fc-a9ef0120c646', date=datetime.datetime(2024, 11, 14, 1, 51, 34, 180327, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='Updating the company name to reflect the rebranding. This is important for future candidate evaluations.'), FunctionCallMessage(id='message-0adccea9-4b96-4cbb-b5fc-a9ef0120c646', date=datetime.datetime(2024, 11, 14, 1, 51, 34, 180327, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='core_memory_replace', arguments='{\\n \"label\": \"company\",\\n \"old_content\": \"The company has pivoted to foundation model training.\",\\n \"new_content\": \"The company is called FoundationAI and has pivoted to foundation model training.\",\\n \"request_heartbeat\": true\\n}', function_call_id='call_5s0KTElXdipPidchUu3R9CxI')), FunctionReturn(id='message-a2f278e8-ec23-4e22-a124-c21a0f46f733', date=datetime.datetime(2024, 11, 14, 1, 51, 34, 182291, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:51:34 PM PST-0800\"\\n}', status='success', function_call_id='call_5s0KTElXdipPidchUu3R9CxI'), InternalMonologue(id='message-91f63cb2-b544-4b2e-82b1-b11643df5f93', date=datetime.datetime(2024, 11, 14, 1, 51, 35, 841684, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='Now I have the updated company info, time to check in on Tony.'), FunctionCallMessage(id='message-91f63cb2-b544-4b2e-82b1-b11643df5f93', date=datetime.datetime(2024, 11, 14, 1, 51, 35, 841684, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='send_message', arguments='{\\n \"message\": \"Got it, the new name is FoundationAI! What about Tony Stark\\'s background catches your eye for this role? Any particular insights on his skills in front-end development or LLMs?\"\\n}', function_call_id='call_R4Erx7Pkpr5lepcuaGQU5isS')), FunctionReturn(id='message-813a9306-38fc-4665-9f3b-7c3671fd90e6', date=datetime.datetime(2024, 11, 14, 1, 51, 35, 842423, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:51:35 PM PST-0800\"\\n}', status='success', function_call_id='call_R4Erx7Pkpr5lepcuaGQU5isS')], usage=LettaUsageStatistics(completion_tokens=146, prompt_tokens=6372, total_tokens=6518, step_count=2))" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "response" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "d04d4b3a-6df1-41a9-9a8e-037fbb45836d", - "metadata": {}, - "outputs": [], - "source": [ - "response = client.agents.messages.send(\n", - " agent_id=eval_agent.id,\n", - " messages=[\n", - " MessageCreate(\n", - " role=\"system\",\n", - " content=\"Candidate: Spongebob Squarepants\",\n", - " )\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "c60465f4-7977-4f70-9a75-d2ddebabb0fa", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.\\nThe company is called FoundationAI and has pivoted to foundation model training.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id=None, id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client.agents.core_memory.get_block(agent_id=eval_agent.id, block_label=\"company\")" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "a51c6bb3-225d-47a4-88f1-9a26ff838dd3", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id=None, id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client.agents.core_memory.get_block(agent_id=outreach_agent.id, block_label=\"company\")" - ] - }, - { - "cell_type": "markdown", - "id": "8d181b1e-72da-4ebe-a872-293e3ce3a225", - "metadata": {}, - "source": [ - "## Section 3: Adding an orchestrator agent \n", - "So far, we've been triggering the `eval_agent` manually. We can also create an additional agent that is responsible for orchestrating tasks. " - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "80b23d46-ed4b-4457-810a-a819d724e146", - "metadata": {}, - "outputs": [], - "source": [ - "#re-create agents \n", - "client.agents.delete(eval_agent.id)\n", - "client.agents.delete(outreach_agent.id)\n", - "\n", - "org_block = client.blocks.create(\n", - " label=\"company\",\n", - " value=org_description,\n", - ")\n", - "\n", - "eval_agent = client.agents.create(\n", - " name=\"eval_agent\", \n", - " memory_blocks=[\n", - " CreateBlock(\n", - " label=\"persona\",\n", - " value=eval_persona,\n", - " ),\n", - " ],\n", - " block_ids=[org_block.id],\n", - " tool_ids=[read_resume_tool.id, submit_evaluation_tool.id]\n", - " model=\"openai/gpt-4\",\n", - " embedding=\"openai/text-embedding-ada-002\",\n", - ")\n", - "\n", - "outreach_agent = client.agents.create(\n", - " name=\"outreach_agent\", \n", - " memory_blocks=[\n", - " CreateBlock(\n", - " label=\"persona\",\n", - " value=outreach_persona,\n", - " ),\n", - " ],\n", - " block_ids=[org_block.id],\n", - " tool_ids=[email_candidate_tool.id]\n", - " model=\"openai/gpt-4\",\n", - " embedding=\"openai/text-embedding-ada-002\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "a751d0f1-b52d-493c-bca1-67f88011bded", - "metadata": {}, - "source": [ - "The `recruiter_agent` will be linked to the same `org_block` that we created before - we can look up the current data in `org_block` by looking up its ID: " - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "bf6bd419-1504-4513-bc68-d4c717ea8e2d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.\\nThe company is called FoundationAI and has pivoted to foundation model training.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id='user-00000000-0000-4000-8000-000000000000', id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client.blocks.retrieve(block_id=org_block.id)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "e2730626-1685-46aa-9b44-a59e1099e973", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Optional\n", - "\n", - "def search_candidates_db(self, page: int) -> Optional[str]: \n", - " \"\"\"\n", - " Returns 1 candidates per page. \n", - " Page 0 returns the first 1 candidate, \n", - " Page 1 returns the next 1, etc.\n", - " Returns `None` if no candidates remain. \n", - "\n", - " Args: \n", - " page (int): The page number to return candidates from \n", - "\n", - " Returns: \n", - " candidate_names (List[str]): Names of the candidates\n", - " \"\"\"\n", - " \n", - " names = [\"Tony Stark\", \"Spongebob Squarepants\", \"Gautam Fang\"]\n", - " if page >= len(names): \n", - " return None\n", - " return names[page]\n", - "\n", - "def consider_candidate(self, name: str): \n", - " \"\"\"\n", - " Submit a candidate for consideration. \n", - "\n", - " Args: \n", - " name (str): Candidate name to consider \n", - " \"\"\"\n", - " from letta_client import Letta, MessageCreate\n", - " client = Letta(base_url=\"http://localhost:8283\")\n", - " message = f\"Consider candidate {name}\" \n", - " print(\"Sending message to eval agent: \", message)\n", - " response = client.send_message(\n", - " agent_id=eval_agent.id,\n", - " role=\"user\", \n", - " message=message\n", - " ) \n", - "\n", - "\n", - "# create tools \n", - "search_candidate_tool = client.tools.upsert_from_function(func=search_candidates_db)\n", - "consider_candidate_tool = client.tools.upsert_from_function(func=consider_candidate)\n", - "\n", - "# create recruiter agent\n", - "recruiter_agent = client.agents.create(\n", - " name=\"recruiter_agent\", \n", - " memory_blocks=[\n", - " CreateBlock(\n", - " label=\"persona\",\n", - " value=\"You run a recruiting process for a company. \" \\\n", - " + \"Your job is to continue to pull candidates from the \" \n", - " + \"`search_candidates_db` tool until there are no more \" \\\n", - " + \"candidates left. \" \\\n", - " + \"For each candidate, consider the candidate by calling \"\n", - " + \"the `consider_candidate` tool. \" \\\n", - " + \"You should continue to call `search_candidates_db` \" \\\n", - " + \"followed by `consider_candidate` until there are no more \" \\\n", - " \" candidates. \",\n", - " ),\n", - " ],\n", - " block_ids=[org_block.id],\n", - " tool_ids=[search_candidate_tool.id, consider_candidate_tool.id],\n", - " model=\"openai/gpt-4\",\n", - " embedding=\"openai/text-embedding-ada-002\"\n", - ")\n", - " \n" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "ecfd790c-0018-4fd9-bdaf-5a6b81f70adf", - "metadata": {}, - "outputs": [], - "source": [ - "response = client.agents.messages.send(\n", - " agent_id=recruiter_agent.id,\n", - " messages=[\n", - " MessageCreate(\n", - " role=\"system\",\n", - " content=\"Run generation\",\n", - " )\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "8065c179-cf90-4287-a6e5-8c009807b436", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - " \n", - "
\n", - "
INTERNAL MONOLOGUE
\n", - "
New user logged in. Excited to get started!
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION CALL
\n", - "
send_message({
  \"message\": \"Welcome! I'm thrilled to have you here. Let’s dive into what you need today!\"
})
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION RETURN
\n", - "
{
  \"status\": \"OK\",
  \"message\"
: \"None\",
  \"time\"
: \"2024-11-13 05:52:14 PM PST-0800\"
}
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
USAGE STATISTICS
\n", - "
{
  \"completion_tokens\": 48,
  \"prompt_tokens\": 2398,
  \"total_tokens\": 2446,
  \"step_count\": 1
}
\n", - "
\n", - "
\n", - " " - ], - "text/plain": [ - "LettaResponse(messages=[InternalMonologue(id='message-8c8ab238-a43e-4509-b7ad-699e9a47ed44', date=datetime.datetime(2024, 11, 14, 1, 52, 14, 780419, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='New user logged in. Excited to get started!'), FunctionCallMessage(id='message-8c8ab238-a43e-4509-b7ad-699e9a47ed44', date=datetime.datetime(2024, 11, 14, 1, 52, 14, 780419, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='send_message', arguments='{\\n \"message\": \"Welcome! I\\'m thrilled to have you here. Let’s dive into what you need today!\"\\n}', function_call_id='call_2OIz7t3oiGsUlhtSneeDslkj')), FunctionReturn(id='message-26c3b7a3-51c8-47ae-938d-a3ed26e42357', date=datetime.datetime(2024, 11, 14, 1, 52, 14, 781455, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:52:14 PM PST-0800\"\\n}', status='success', function_call_id='call_2OIz7t3oiGsUlhtSneeDslkj')], usage=LettaUsageStatistics(completion_tokens=48, prompt_tokens=2398, total_tokens=2446, step_count=1))" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "response" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "4639bbca-e0c5-46a9-a509-56d35d26e97f", - "metadata": {}, - "outputs": [], - "source": [ - "client.agents.delete(agent_id=eval_agent.id)\n", - "client.agents.delete(agent_id=outreach_agent.id)\n", - "client.agents.delete(agent_id=recruiter_agent.id)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "letta", - "language": "python", - "name": "letta" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/swarm/simple.py b/examples/swarm/simple.py deleted file mode 100644 index 8e10c486..00000000 --- a/examples/swarm/simple.py +++ /dev/null @@ -1,72 +0,0 @@ -import typer -from swarm import Swarm - -from letta import EmbeddingConfig, LLMConfig - -""" -This is an example of how to implement the basic example provided by OpenAI for tranferring a conversation between two agents: -https://github.com/openai/swarm/tree/main?tab=readme-ov-file#usage - -Before running this example, make sure you have letta>=0.5.0 installed. This example also runs with OpenAI, though you can also change the model by modifying the code: -```bash -export OPENAI_API_KEY=... -pip install letta -```` -Then, instead the `examples/swarm` directory, run: -```bash -python simple.py -``` -You should see a message output from Agent B. - -""" - - -def transfer_agent_b(self): - """ - Transfer conversation to agent B. - - Returns: - str: name of agent to transfer to - """ - return "agentb" - - -def transfer_agent_a(self): - """ - Transfer conversation to agent A. - - Returns: - str: name of agent to transfer to - """ - return "agenta" - - -swarm = Swarm() - -# set client configs -swarm.client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) -swarm.client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4")) - -# create tools -transfer_a = swarm.client.create_or_update_tool(transfer_agent_a) -transfer_b = swarm.client.create_or_update_tool(transfer_agent_b) - -# create agents -if swarm.client.get_agent_id("agentb"): - swarm.client.delete_agent(swarm.client.get_agent_id("agentb")) -if swarm.client.get_agent_id("agenta"): - swarm.client.delete_agent(swarm.client.get_agent_id("agenta")) -agent_a = swarm.create_agent(name="agentb", tools=[transfer_a.name], instructions="Only speak in haikus") -agent_b = swarm.create_agent(name="agenta", tools=[transfer_b.name]) - -response = swarm.run(agent_name="agenta", message="Transfer me to agent b by calling the transfer_agent_b tool") -print("Response:") -typer.secho(f"{response}", fg=typer.colors.GREEN) - -response = swarm.run(agent_name="agenta", message="My name is actually Sarah. Transfer me to agent b to write a haiku about my name") -print("Response:") -typer.secho(f"{response}", fg=typer.colors.GREEN) - -response = swarm.run(agent_name="agenta", message="Transfer me to agent b - I want a haiku with my name in it") -print("Response:") -typer.secho(f"{response}", fg=typer.colors.GREEN) diff --git a/examples/swarm/swarm.py b/examples/swarm/swarm.py deleted file mode 100644 index 6e0958bf..00000000 --- a/examples/swarm/swarm.py +++ /dev/null @@ -1,111 +0,0 @@ -import json -from typing import List, Optional - -import typer - -from letta import AgentState, EmbeddingConfig, LLMConfig, create_client -from letta.schemas.agent import AgentType -from letta.schemas.memory import BasicBlockMemory, Block - - -class Swarm: - - def __init__(self): - self.agents = [] - self.client = create_client() - - # shared memory block (shared section of context window accross agents) - self.shared_memory = Block(label="human", value="") - - def create_agent( - self, - name: Optional[str] = None, - # agent config - agent_type: Optional[AgentType] = AgentType.memgpt_agent, - # model configs - embedding_config: EmbeddingConfig = None, - llm_config: LLMConfig = None, - # system - system: Optional[str] = None, - # tools - tools: Optional[List[str]] = None, - include_base_tools: Optional[bool] = True, - # instructions - instructions: str = "", - ) -> AgentState: - - # todo: process tools for agent handoff - persona_value = ( - f"You are agent with name {name}. You instructions are {instructions}" - if len(instructions) > 0 - else f"You are agent with name {name}" - ) - persona_block = Block(label="persona", value=persona_value) - memory = BasicBlockMemory(blocks=[persona_block, self.shared_memory]) - - agent = self.client.create_agent( - name=name, - agent_type=agent_type, - embedding_config=embedding_config, - llm_config=llm_config, - system=system, - tools=tools, - include_base_tools=include_base_tools, - memory=memory, - ) - self.agents.append(agent) - - return agent - - def reset(self): - # delete all agents - for agent in self.agents: - self.client.delete_agent(agent.id) - for block in self.client.list_blocks(): - self.client.delete_block(block.id) - - def run(self, agent_name: str, message: str): - - history = [] - while True: - # send message to agent - agent_id = self.client.get_agent_id(agent_name) - - print("Messaging agent: ", agent_name) - print("History size: ", len(history)) - # print(self.client.get_agent(agent_id).tools) - # TODO: implement with sending multiple messages - if len(history) == 0: - response = self.client.send_message(agent_id=agent_id, message=message, role="user") - else: - response = self.client.send_messages(agent_id=agent_id, messages=history) - - # update history - history += response.messages - - # grab responses - messages = [] - for message in response.messages: - messages += message.to_letta_messages() - - # get new agent (see tool call) - # print(messages) - - if len(messages) < 2: - continue - - function_call = messages[-2] - function_return = messages[-1] - if function_call.function_call.name == "send_message": - # return message to use - arg_data = json.loads(function_call.function_call.arguments) - # print(arg_data) - return arg_data["message"] - else: - # swap the agent - return_data = json.loads(function_return.function_return) - agent_name = return_data["message"] - typer.secho(f"Transferring to agent: {agent_name}", fg=typer.colors.RED) - # print("Transferring to agent", agent_name) - - print() diff --git a/examples/tool_rule_usage.py b/examples/tool_rule_usage.py deleted file mode 100644 index 8ec061d0..00000000 --- a/examples/tool_rule_usage.py +++ /dev/null @@ -1,129 +0,0 @@ -import os -import uuid - -from letta import create_client -from letta.schemas.letta_message import ToolCallMessage -from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule -from tests.helpers.endpoints_helper import assert_invoked_send_message_with_keyword, setup_agent -from tests.helpers.utils import cleanup -from tests.test_model_letta_performance import llm_config_dir - -""" -This example shows how you can constrain tool calls in your agent. - -Please note that this currently only works reliably for models with Structured Outputs (e.g. gpt-4o). - -Start by downloading the dependencies. -``` -poetry install --all-extras -``` -""" - -# Tools for this example -# Generate uuid for agent name for this example -namespace = uuid.NAMESPACE_DNS -agent_uuid = str(uuid.uuid5(namespace, "agent_tool_graph")) -config_file = os.path.join(llm_config_dir, "openai-gpt-4o.json") - -"""Contrived tools for this test case""" - - -def first_secret_word(): - """ - Call this to retrieve the first secret word, which you will need for the second_secret_word function. - """ - return "v0iq020i0g" - - -def second_secret_word(prev_secret_word: str): - """ - Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error. - - Args: - prev_secret_word (str): The secret word retrieved from calling first_secret_word. - """ - if prev_secret_word != "v0iq020i0g": - raise RuntimeError(f"Expected secret {"v0iq020i0g"}, got {prev_secret_word}") - - return "4rwp2b4gxq" - - -def third_secret_word(prev_secret_word: str): - """ - Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error. - - Args: - prev_secret_word (str): The secret word retrieved from calling second_secret_word. - """ - if prev_secret_word != "4rwp2b4gxq": - raise RuntimeError(f"Expected secret {"4rwp2b4gxq"}, got {prev_secret_word}") - - return "hj2hwibbqm" - - -def fourth_secret_word(prev_secret_word: str): - """ - Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error. - - Args: - prev_secret_word (str): The secret word retrieved from calling third_secret_word. - """ - if prev_secret_word != "hj2hwibbqm": - raise RuntimeError(f"Expected secret {"hj2hwibbqm"}, got {prev_secret_word}") - - return "banana" - - -def auto_error(): - """ - If you call this function, it will throw an error automatically. - """ - raise RuntimeError("This should never be called.") - - -def main(): - # 1. Set up the client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # 2. Add all the tools to the client - functions = [first_secret_word, second_secret_word, third_secret_word, fourth_secret_word, auto_error] - tools = [] - for func in functions: - tool = client.create_or_update_tool(func) - tools.append(tool) - tool_names = [t.name for t in tools[:-1]] - - # 3. Create the tool rules. It must be called in this order, or there will be an error thrown. - tool_rules = [ - InitToolRule(tool_name="first_secret_word"), - ChildToolRule(tool_name="first_secret_word", children=["second_secret_word"]), - ChildToolRule(tool_name="second_secret_word", children=["third_secret_word"]), - ChildToolRule(tool_name="third_secret_word", children=["fourth_secret_word"]), - ChildToolRule(tool_name="fourth_secret_word", children=["send_message"]), - TerminalToolRule(tool_name="send_message"), - ] - - # 4. Create the agent - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - - # 5. Ask for the final secret word - response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?") - - # 6. Here, we thoroughly check the correctness of the response - tool_names += ["send_message"] # Add send message because we expect this to be called at the end - for m in response.messages: - if isinstance(m, ToolCallMessage): - # Check that it's equal to the first one - assert m.tool_call.name == tool_names[0] - # Pop out first one - tool_names = tool_names[1:] - - # Check final send message contains "banana" - assert_invoked_send_message_with_keyword(response.messages, "banana") - print(f"Got successful response from client: \n\n{response}") - cleanup(client=client, agent_uuid=agent_uuid) - - -if __name__ == "__main__": - main() diff --git a/letta/__init__.py b/letta/__init__.py index 4ae12788..e462b907 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -1,7 +1,7 @@ __version__ = "0.7.14" # import clients -from letta.client.client import LocalClient, RESTClient, create_client +from letta.client.client import RESTClient # # imports for easier access from letta.schemas.agent import AgentState diff --git a/letta/__main__.py b/letta/__main__.py deleted file mode 100644 index 89f11424..00000000 --- a/letta/__main__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .main import app - -app() diff --git a/letta/benchmark/benchmark.py b/letta/benchmark/benchmark.py deleted file mode 100644 index 7109210e..00000000 --- a/letta/benchmark/benchmark.py +++ /dev/null @@ -1,98 +0,0 @@ -# type: ignore - -import time -import uuid -from typing import Annotated, Union - -import typer - -from letta import LocalClient, RESTClient, create_client -from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES -from letta.config import LettaConfig - -# from letta.agent import Agent -from letta.errors import LLMJSONParsingError -from letta.utils import get_human_text, get_persona_text - -app = typer.Typer() - - -def send_message( - client: Union[LocalClient, RESTClient], message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES -): - try: - print_msg = f"\t-> Now running {fn_type}. Progress: {turn}/{n_tries}" - print(print_msg, end="\r", flush=True) - response = client.user_message(agent_id=agent_id, message=message) - - if turn + 1 == n_tries: - print(" " * len(print_msg), end="\r", flush=True) - - for r in response: - if "function_call" in r and fn_type in r["function_call"] and any("assistant_message" in re for re in response): - return True, r["function_call"] - - return False, "No function called." - except LLMJSONParsingError as e: - print(f"Error in parsing Letta JSON: {e}") - return False, "Failed to decode valid Letta JSON from LLM output." - except Exception as e: - print(f"An unexpected error occurred: {e}") - return False, "An unexpected error occurred." - - -@app.command() -def bench( - print_messages: Annotated[bool, typer.Option("--messages", help="Print functions calls and messages from the agent.")] = False, - n_tries: Annotated[int, typer.Option("--n-tries", help="Number of benchmark tries to perform for each function.")] = TRIES, -): - client = create_client() - print(f"\nDepending on your hardware, this may take up to 30 minutes. This will also create {n_tries * len(PROMPTS)} new agents.\n") - config = LettaConfig.load() - print(f"version = {config.letta_version}") - - total_score, total_tokens_accumulated, elapsed_time = 0, 0, 0 - - for fn_type, message in PROMPTS.items(): - score = 0 - start_time_run = time.time() - bench_id = uuid.uuid4() - - for i in range(n_tries): - agent = client.create_agent( - name=f"benchmark_{bench_id}_agent_{i}", - persona=get_persona_text(PERSONA), - human=get_human_text(HUMAN), - ) - - agent_id = agent.id - result, msg = send_message( - client=client, message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries - ) - - if print_messages: - print(f"\t{msg}") - - if result: - score += 1 - - # TODO: add back once we start tracking usage via the client - # total_tokens_accumulated += tokens_accumulated - - elapsed_time_run = round(time.time() - start_time_run, 2) - print(f"Score for {fn_type}: {score}/{n_tries}, took {elapsed_time_run} seconds") - - elapsed_time += elapsed_time_run - total_score += score - - print(f"\nMEMGPT VERSION: {config.letta_version}") - print(f"CONTEXT WINDOW: {config.default_llm_config.context_window}") - print(f"MODEL WRAPPER: {config.default_llm_config.model_wrapper}") - print(f"PRESET: {config.preset}") - print(f"PERSONA: {config.persona}") - print(f"HUMAN: {config.human}") - - print( - # f"\n\t-> Total score: {total_score}/{len(PROMPTS) * n_tries}, took {elapsed_time} seconds at average of {round(total_tokens_accumulated/elapsed_time, 2)} t/s\n" - f"\n\t-> Total score: {total_score}/{len(PROMPTS) * n_tries}, took {elapsed_time} seconds\n" - ) diff --git a/letta/benchmark/constants.py b/letta/benchmark/constants.py deleted file mode 100644 index 755fdce5..00000000 --- a/letta/benchmark/constants.py +++ /dev/null @@ -1,14 +0,0 @@ -# Basic -TRIES = 3 -AGENT_NAME = "benchmark" -PERSONA = "sam_pov" -HUMAN = "cs_phd" - -# Prompts -PROMPTS = { - "core_memory_replace": "Hey there, my name is John, what is yours?", - "core_memory_append": "I want you to remember that I like soccers for later.", - "conversation_search": "Do you remember when I talked about bananas?", - "archival_memory_insert": "Can you make sure to remember that I like programming for me so you can look it up later?", - "archival_memory_search": "Can you retrieve information about the war?", -} diff --git a/letta/cli/cli.py b/letta/cli/cli.py index a89d5266..47e86509 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -1,37 +1,15 @@ -import logging import sys from enum import Enum from typing import Annotated, Optional -import questionary import typer -import letta.utils as utils -from letta import create_client -from letta.agent import Agent, save_agent -from letta.config import LettaConfig -from letta.constants import CLI_WARNING_PREFIX, CORE_MEMORY_BLOCK_CHAR_LIMIT, LETTA_DIR, MIN_CONTEXT_WINDOW -from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL from letta.log import get_logger -from letta.schemas.enums import OptionState -from letta.schemas.memory import ChatMemory, Memory - -# from letta.interface import CLIInterface as interface # for printing to terminal from letta.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal -from letta.utils import open_folder_in_explorer, printd logger = get_logger(__name__) -def open_folder(): - """Open a folder viewer of the Letta home directory""" - try: - print(f"Opening home folder: {LETTA_DIR}") - open_folder_in_explorer(LETTA_DIR) - except Exception as e: - print(f"Failed to open folder with system viewer, error:\n{e}") - - class ServerChoice(Enum): rest_api = "rest" ws_api = "websocket" @@ -51,14 +29,6 @@ def server( if type == ServerChoice.rest_api: pass - # if LettaConfig.exists(): - # config = LettaConfig.load() - # MetadataStore(config) - # _ = create_client() # triggers user creation - # else: - # typer.secho(f"No configuration exists. Run letta configure before starting the server.", fg=typer.colors.RED) - # sys.exit(1) - try: from letta.server.rest_api.app import start_server @@ -73,292 +43,6 @@ def server( raise NotImplementedError("WS suppport deprecated") -def run( - persona: Annotated[Optional[str], typer.Option(help="Specify persona")] = None, - agent: Annotated[Optional[str], typer.Option(help="Specify agent name")] = None, - human: Annotated[Optional[str], typer.Option(help="Specify human")] = None, - system: Annotated[Optional[str], typer.Option(help="Specify system prompt (raw text)")] = None, - system_file: Annotated[Optional[str], typer.Option(help="Specify raw text file containing system prompt")] = None, - # model flags - model: Annotated[Optional[str], typer.Option(help="Specify the LLM model")] = None, - model_wrapper: Annotated[Optional[str], typer.Option(help="Specify the LLM model wrapper")] = None, - model_endpoint: Annotated[Optional[str], typer.Option(help="Specify the LLM model endpoint")] = None, - model_endpoint_type: Annotated[Optional[str], typer.Option(help="Specify the LLM model endpoint type")] = None, - context_window: Annotated[ - Optional[int], typer.Option(help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)") - ] = None, - core_memory_limit: Annotated[ - Optional[int], typer.Option(help="The character limit to each core-memory section (human/persona).") - ] = CORE_MEMORY_BLOCK_CHAR_LIMIT, - # other - first: Annotated[bool, typer.Option(help="Use --first to send the first message in the sequence")] = False, - strip_ui: Annotated[bool, typer.Option(help="Remove all the bells and whistles in CLI output (helpful for testing)")] = False, - debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False, - no_verify: Annotated[bool, typer.Option(help="Bypass message verification")] = False, - yes: Annotated[bool, typer.Option("-y", help="Skip confirmation prompt and use defaults")] = False, - # streaming - stream: Annotated[bool, typer.Option(help="Enables message streaming in the CLI (if the backend supports it)")] = False, - # whether or not to put the inner thoughts inside the function args - no_content: Annotated[ - OptionState, typer.Option(help="Set to 'yes' for LLM APIs that omit the `content` field during tool calling") - ] = OptionState.DEFAULT, -): - """Start chatting with an Letta agent - - Example usage: `letta run --agent myagent --data-source mydata --persona mypersona --human myhuman --model gpt-3.5-turbo` - - :param persona: Specify persona - :param agent: Specify agent name (will load existing state if the agent exists, or create a new one with that name) - :param human: Specify human - :param model: Specify the LLM model - - """ - - # setup logger - # TODO: remove Utils Debug after global logging is complete. - utils.DEBUG = debug - # TODO: add logging command line options for runtime log level - - from letta.server.server import logger as server_logger - - if debug: - logger.setLevel(logging.DEBUG) - server_logger.setLevel(logging.DEBUG) - else: - logger.setLevel(logging.CRITICAL) - server_logger.setLevel(logging.CRITICAL) - - # load config file - config = LettaConfig.load() - - # read user id from config - client = create_client() - - # determine agent to use, if not provided - if not yes and not agent: - agents = client.list_agents() - agents = [a.name for a in agents] - - if len(agents) > 0: - print() - select_agent = questionary.confirm("Would you like to select an existing agent?").ask() - if select_agent is None: - raise KeyboardInterrupt - if select_agent: - agent = questionary.select("Select agent:", choices=agents).ask() - - # create agent config - if agent: - agent_id = client.get_agent_id(agent) - agent_state = client.get_agent(agent_id) - else: - agent_state = None - human = human if human else config.human - persona = persona if persona else config.persona - if agent and agent_state: # use existing agent - typer.secho(f"\n🔁 Using existing agent {agent}", fg=typer.colors.GREEN) - printd("Loading agent state:", agent_state.id) - printd("Agent state:", agent_state.name) - # printd("State path:", agent_config.save_state_dir()) - # printd("Persistent manager path:", agent_config.save_persistence_manager_dir()) - # printd("Index path:", agent_config.save_agent_index_dir()) - # TODO: load prior agent state - - # Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window) - if model and model != agent_state.llm_config.model: - typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing model {agent_state.llm_config.model} with {model}", fg=typer.colors.YELLOW - ) - agent_state.llm_config.model = model - if context_window is not None and int(context_window) != agent_state.llm_config.context_window: - typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing context window {agent_state.llm_config.context_window} with {context_window}", - fg=typer.colors.YELLOW, - ) - agent_state.llm_config.context_window = context_window - if model_wrapper and model_wrapper != agent_state.llm_config.model_wrapper: - typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing model wrapper {agent_state.llm_config.model_wrapper} with {model_wrapper}", - fg=typer.colors.YELLOW, - ) - agent_state.llm_config.model_wrapper = model_wrapper - if model_endpoint and model_endpoint != agent_state.llm_config.model_endpoint: - typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing model endpoint {agent_state.llm_config.model_endpoint} with {model_endpoint}", - fg=typer.colors.YELLOW, - ) - agent_state.llm_config.model_endpoint = model_endpoint - if model_endpoint_type and model_endpoint_type != agent_state.llm_config.model_endpoint_type: - typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing model endpoint type {agent_state.llm_config.model_endpoint_type} with {model_endpoint_type}", - fg=typer.colors.YELLOW, - ) - agent_state.llm_config.model_endpoint_type = model_endpoint_type - - # NOTE: commented out because this seems dangerous - instead users should use /systemswap when in the CLI - # # user specified a new system prompt - # if system: - # # NOTE: agent_state.system is the ORIGINAL system prompt, - # # whereas agent_state.state["system"] is the LATEST system prompt - # existing_system_prompt = agent_state.state["system"] if "system" in agent_state.state else None - # if existing_system_prompt != system: - # # override - # agent_state.state["system"] = system - - # Update the agent with any overrides - agent_state = client.update_agent( - agent_id=agent_state.id, - name=agent_state.name, - llm_config=agent_state.llm_config, - embedding_config=agent_state.embedding_config, - ) - - # create agent - letta_agent = Agent(agent_state=agent_state, interface=interface(), user=client.user) - - else: # create new agent - # create new agent config: override defaults with args if provided - typer.secho("\n🧬 Creating new agent...", fg=typer.colors.WHITE) - - agent_name = agent if agent else utils.create_random_username() - - # create agent - client = create_client() - - # choose from list of llm_configs - llm_configs = client.list_llm_configs() - llm_options = [llm_config.model for llm_config in llm_configs] - llm_choices = [questionary.Choice(title=llm_config.pretty_print(), value=llm_config) for llm_config in llm_configs] - - # select model - if len(llm_options) == 0: - raise ValueError("No LLM models found. Please enable a provider.") - elif len(llm_options) == 1: - llm_model_name = llm_options[0] - else: - llm_model_name = questionary.select("Select LLM model:", choices=llm_choices).ask().model - llm_config = [llm_config for llm_config in llm_configs if llm_config.model == llm_model_name][0] - - # option to override context window - if llm_config.context_window is not None: - context_window_validator = lambda x: x.isdigit() and int(x) > MIN_CONTEXT_WINDOW and int(x) <= llm_config.context_window - context_window_input = questionary.text( - "Select LLM context window limit (hit enter for default):", - default=str(llm_config.context_window), - validate=context_window_validator, - ).ask() - if context_window_input is not None: - llm_config.context_window = int(context_window_input) - else: - sys.exit(1) - - # choose form list of embedding configs - embedding_configs = client.list_embedding_configs() - embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs] - - embedding_choices = [ - questionary.Choice(title=embedding_config.pretty_print(), value=embedding_config) for embedding_config in embedding_configs - ] - - # select model - if len(embedding_options) == 0: - raise ValueError("No embedding models found. Please enable a provider.") - elif len(embedding_options) == 1: - embedding_model_name = embedding_options[0] - else: - embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model - embedding_config = [ - embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name - ][0] - - human_obj = client.get_human(client.get_human_id(name=human)) - persona_obj = client.get_persona(client.get_persona_id(name=persona)) - if human_obj is None: - typer.secho(f"Couldn't find human {human} in database, please run `letta add human`", fg=typer.colors.RED) - sys.exit(1) - if persona_obj is None: - typer.secho(f"Couldn't find persona {persona} in database, please run `letta add persona`", fg=typer.colors.RED) - sys.exit(1) - - if system_file: - try: - with open(system_file, "r", encoding="utf-8") as file: - system = file.read().strip() - printd("Loaded system file successfully.") - except FileNotFoundError: - typer.secho(f"System file not found at {system_file}", fg=typer.colors.RED) - system_prompt = system if system else None - - memory = ChatMemory(human=human_obj.value, persona=persona_obj.value, limit=core_memory_limit) - metadata = {"human": human_obj.template_name, "persona": persona_obj.template_name} - - typer.secho(f"-> {ASSISTANT_MESSAGE_CLI_SYMBOL} Using persona profile: '{persona_obj.template_name}'", fg=typer.colors.WHITE) - typer.secho(f"-> 🧑 Using human profile: '{human_obj.template_name}'", fg=typer.colors.WHITE) - - # add tools - agent_state = client.create_agent( - name=agent_name, - system=system_prompt, - embedding_config=embedding_config, - llm_config=llm_config, - memory=memory, - metadata=metadata, - ) - assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}" - typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t.name for t in agent_state.tools])}", fg=typer.colors.WHITE) - - letta_agent = Agent( - interface=interface(), - agent_state=client.get_agent(agent_state.id), - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False, - user=client.user, - ) - save_agent(agent=letta_agent) - typer.secho(f"🎉 Created new agent '{letta_agent.agent_state.name}' (id={letta_agent.agent_state.id})", fg=typer.colors.GREEN) - - # start event loop - from letta.main import run_agent_loop - - print() # extra space - run_agent_loop( - letta_agent=letta_agent, - config=config, - first=first, - no_verify=no_verify, - stream=stream, - ) # TODO: add back no_verify - - -def delete_agent( - agent_name: Annotated[str, typer.Option(help="Specify agent to delete")], -): - """Delete an agent from the database""" - # use client ID is no user_id provided - config = LettaConfig.load() - MetadataStore(config) - client = create_client() - agent = client.get_agent_by_name(agent_name) - if not agent: - typer.secho(f"Couldn't find agent named '{agent_name}' to delete", fg=typer.colors.RED) - sys.exit(1) - - confirm = questionary.confirm(f"Are you sure you want to delete agent '{agent_name}' (id={agent.id})?", default=False).ask() - if confirm is None: - raise KeyboardInterrupt - if not confirm: - typer.secho(f"Cancelled agent deletion '{agent_name}' (id={agent.id})", fg=typer.colors.GREEN) - return - - try: - # delete the agent - client.delete_agent(agent.id) - typer.secho(f"🕊️ Successfully deleted agent '{agent_name}' (id={agent.id})", fg=typer.colors.GREEN) - except Exception: - typer.secho(f"Failed to delete agent '{agent_name}' (id={agent.id})", fg=typer.colors.RED) - sys.exit(1) - - def version() -> str: import letta diff --git a/letta/cli/cli_config.py b/letta/cli/cli_config.py deleted file mode 100644 index a17bf476..00000000 --- a/letta/cli/cli_config.py +++ /dev/null @@ -1,227 +0,0 @@ -import ast -import os -from enum import Enum -from typing import Annotated, List, Optional - -import questionary -import typer -from prettytable.colortable import ColorTable, Themes -from tqdm import tqdm - -import letta.helpers.datetime_helpers - -app = typer.Typer() - - -@app.command() -def configure(): - """Updates default Letta configurations - - This function and quickstart should be the ONLY place where LettaConfig.save() is called - """ - print("`letta configure` has been deprecated. Please see documentation on configuration, and run `letta run` instead.") - - -class ListChoice(str, Enum): - agents = "agents" - humans = "humans" - personas = "personas" - sources = "sources" - - -@app.command() -def list(arg: Annotated[ListChoice, typer.Argument]): - from letta.client.client import create_client - - client = create_client() - table = ColorTable(theme=Themes.OCEAN) - if arg == ListChoice.agents: - """List all agents""" - table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"] - for agent in tqdm(client.list_agents()): - # TODO: add this function - sources = client.list_attached_sources(agent_id=agent.id) - source_names = [source.name for source in sources if source is not None] - table.add_row( - [ - agent.name, - agent.llm_config.model, - agent.embedding_config.embedding_model, - agent.embedding_config.embedding_dim, - agent.memory.get_block("persona").value[:100] + "...", - agent.memory.get_block("human").value[:100] + "...", - ",".join(source_names), - letta.helpers.datetime_helpers.format_datetime(agent.created_at), - ] - ) - print(table) - elif arg == ListChoice.humans: - """List all humans""" - table.field_names = ["Name", "Text"] - for human in client.list_humans(): - table.add_row([human.template_name, human.value.replace("\n", "")[:100]]) - elif arg == ListChoice.personas: - """List all personas""" - table.field_names = ["Name", "Text"] - for persona in client.list_personas(): - table.add_row([persona.template_name, persona.value.replace("\n", "")[:100]]) - print(table) - elif arg == ListChoice.sources: - """List all data sources""" - - # create table - table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At"] - # TODO: eventually look accross all storage connections - # TODO: add data source stats - # TODO: connect to agents - - # get all sources - for source in client.list_sources(): - # get attached agents - table.add_row( - [ - source.name, - source.description, - source.embedding_config.embedding_model, - source.embedding_config.embedding_dim, - letta.helpers.datetime_helpers.format_datetime(source.created_at), - ] - ) - - print(table) - else: - raise ValueError(f"Unknown argument {arg}") - return table - - -@app.command() -def add_tool( - filename: str = typer.Option(..., help="Path to the Python file containing the function"), - name: Optional[str] = typer.Option(None, help="Name of the tool"), - update: bool = typer.Option(True, help="Update the tool if it already exists"), - tags: Optional[List[str]] = typer.Option(None, help="Tags for the tool"), -): - """Add or update a tool from a Python file.""" - from letta.client.client import create_client - - client = create_client() - - # 1. Parse the Python file - with open(filename, "r", encoding="utf-8") as file: - source_code = file.read() - - # 2. Parse the source code to extract the function - # Note: here we assume it is one function only in the file. - module = ast.parse(source_code) - func_def = None - for node in module.body: - if isinstance(node, ast.FunctionDef): - func_def = node - break - - if not func_def: - raise ValueError("No function found in the provided file") - - # 3. Compile the function to make it callable - # Explanation courtesy of GPT-4: - # Compile the AST (Abstract Syntax Tree) node representing the function definition into a code object - # ast.Module creates a module node containing the function definition (func_def) - # compile converts the AST into a code object that can be executed by the Python interpreter - # The exec function executes the compiled code object in the current context, - # effectively defining the function within the current namespace - exec(compile(ast.Module([func_def], []), filename, "exec")) - # Retrieve the function object by evaluating its name in the current namespace - # eval looks up the function name in the current scope and returns the function object - func = eval(func_def.name) - - # 4. Add or update the tool - tool = client.create_or_update_tool(func=func, tags=tags, update=update) - print(f"Tool {tool.name} added successfully") - - -@app.command() -def list_tools(): - """List all available tools.""" - from letta.client.client import create_client - - client = create_client() - - tools = client.list_tools() - for tool in tools: - print(f"Tool: {tool.name}") - - -@app.command() -def add( - option: str, # [human, persona] - name: Annotated[str, typer.Option(help="Name of human/persona")], - text: Annotated[Optional[str], typer.Option(help="Text of human/persona")] = None, - filename: Annotated[Optional[str], typer.Option("-f", help="Specify filename")] = None, -): - """Add a person/human""" - from letta.client.client import create_client - - client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS")) - if filename: # read from file - assert text is None, "Cannot specify both text and filename" - with open(filename, "r", encoding="utf-8") as f: - text = f.read() - else: - assert text is not None, "Must specify either text or filename" - if option == "persona": - persona_id = client.get_persona_id(name) - if persona_id: - client.get_persona(persona_id) - # config if user wants to overwrite - if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask(): - return - client.update_persona(persona_id, text=text) - else: - client.create_persona(name=name, text=text) - - elif option == "human": - human_id = client.get_human_id(name) - if human_id: - human = client.get_human(human_id) - # config if user wants to overwrite - if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask(): - return - client.update_human(human_id, text=text) - else: - human = client.create_human(name=name, text=text) - else: - raise ValueError(f"Unknown kind {option}") - - -@app.command() -def delete(option: str, name: str): - """Delete a source from the archival memory.""" - from letta.client.client import create_client - - client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_API_KEY")) - try: - # delete from metadata - if option == "source": - # delete metadata - source_id = client.get_source_id(name) - assert source_id is not None, f"Source {name} does not exist" - client.delete_source(source_id) - elif option == "agent": - agent_id = client.get_agent_id(name) - assert agent_id is not None, f"Agent {name} does not exist" - client.delete_agent(agent_id=agent_id) - elif option == "human": - human_id = client.get_human_id(name) - assert human_id is not None, f"Human {name} does not exist" - client.delete_human(human_id) - elif option == "persona": - persona_id = client.get_persona_id(name) - assert persona_id is not None, f"Persona {name} does not exist" - client.delete_persona(persona_id) - else: - raise ValueError(f"Option {option} not implemented") - - typer.secho(f"Deleted {option} '{name}'", fg=typer.colors.GREEN) - - except Exception as e: - typer.secho(f"Failed to delete {option}'{name}'\n{e}", fg=typer.colors.RED) diff --git a/letta/cli/cli_load.py b/letta/cli/cli_load.py index 4c420bfa..a50c525e 100644 --- a/letta/cli/cli_load.py +++ b/letta/cli/cli_load.py @@ -8,61 +8,9 @@ letta load --name [ADDITIONAL ARGS] """ -import uuid -from typing import Annotated, List, Optional - -import questionary import typer -from letta import create_client -from letta.data_sources.connectors import DirectoryConnector - app = typer.Typer() default_extensions = "txt,md,pdf" - - -@app.command("directory") -def load_directory( - name: Annotated[str, typer.Option(help="Name of dataset to load.")], - input_dir: Annotated[Optional[str], typer.Option(help="Path to directory containing dataset.")] = None, - input_files: Annotated[List[str], typer.Option(help="List of paths to files containing dataset.")] = [], - recursive: Annotated[bool, typer.Option(help="Recursively search for files in directory.")] = False, - extensions: Annotated[str, typer.Option(help="Comma separated list of file extensions to load")] = default_extensions, - user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, # TODO: remove - description: Annotated[Optional[str], typer.Option(help="Description of the source.")] = None, -): - client = create_client() - - # create connector - connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions) - - # choose form list of embedding configs - embedding_configs = client.list_embedding_configs() - embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs] - - embedding_choices = [ - questionary.Choice(title=embedding_config.pretty_print(), value=embedding_config) for embedding_config in embedding_configs - ] - - # select model - if len(embedding_options) == 0: - raise ValueError("No embedding models found. Please enable a provider.") - elif len(embedding_options) == 1: - embedding_model_name = embedding_options[0] - else: - embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model - embedding_config = [ - embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name - ][0] - - # create source - source = client.create_source(name=name, embedding_config=embedding_config) - - # load data - try: - client.load_data(connector, source_name=name) - except Exception as e: - typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED) - client.delete_source(source.id) diff --git a/letta/client/client.py b/letta/client/client.py index 90e39400..d71aae62 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1,27 +1,19 @@ -import asyncio -import logging import sys import time from typing import Callable, Dict, Generator, List, Optional, Union import requests -import letta.utils from letta.constants import ADMIN_PREFIX, BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, FUNCTION_RETURN_CHAR_LIMIT from letta.data_sources.connectors import DataConnector from letta.functions.functions import parse_source_code -from letta.orm.errors import NoResultFound from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent from letta.schemas.block import Block, BlockUpdate, CreateBlock, Human, Persona from letta.schemas.embedding_config import EmbeddingConfig # new schemas from letta.schemas.enums import JobStatus, MessageRole -from letta.schemas.environment_variables import ( - SandboxEnvironmentVariable, - SandboxEnvironmentVariableCreate, - SandboxEnvironmentVariableUpdate, -) +from letta.schemas.environment_variables import SandboxEnvironmentVariable from letta.schemas.file import FileMetadata from letta.schemas.job import Job from letta.schemas.letta_message import LettaMessage, LettaMessageUnion @@ -35,11 +27,10 @@ from letta.schemas.organization import Organization from letta.schemas.passage import Passage from letta.schemas.response_format import ResponseFormatUnion from letta.schemas.run import Run -from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxConfigUpdate +from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.tool import Tool, ToolCreate, ToolUpdate from letta.schemas.tool_rule import BaseToolRule -from letta.server.rest_api.interface import QueuingInterface from letta.utils import get_human_text, get_persona_text # Print deprecation notice in yellow when module is imported @@ -53,13 +44,6 @@ print( ) -def create_client(base_url: Optional[str] = None, token: Optional[str] = None): - if base_url is None: - return LocalClient() - else: - return RESTClient(base_url, token) - - class AbstractClient(object): def __init__( self, @@ -2229,1539 +2213,3 @@ class RESTClient(AbstractClient): if response.status_code != 200: raise ValueError(f"Failed to get tags: {response.text}") return response.json() - - -class LocalClient(AbstractClient): - """ - A local client for Letta, which corresponds to a single user. - - Attributes: - user_id (str): The user ID. - debug (bool): Whether to print debug information. - interface (QueuingInterface): The interface for the client. - server (SyncServer): The server for the client. - """ - - def __init__( - self, - user_id: Optional[str] = None, - org_id: Optional[str] = None, - debug: bool = False, - default_llm_config: Optional[LLMConfig] = None, - default_embedding_config: Optional[EmbeddingConfig] = None, - ): - """ - Initializes a new instance of Client class. - - Args: - user_id (str): The user ID. - debug (bool): Whether to print debug information. - """ - - from letta.server.server import SyncServer - - # set logging levels - letta.utils.DEBUG = debug - logging.getLogger().setLevel(logging.CRITICAL) - - # save default model config - self._default_llm_config = default_llm_config - self._default_embedding_config = default_embedding_config - - # create server - self.interface = QueuingInterface(debug=debug) - self.server = SyncServer(default_interface_factory=lambda: self.interface) - - # save org_id that `LocalClient` is associated with - if org_id: - self.org_id = org_id - else: - self.org_id = self.server.organization_manager.DEFAULT_ORG_ID - # save user_id that `LocalClient` is associated with - if user_id: - self.user_id = user_id - else: - # get default user - self.user_id = self.server.user_manager.DEFAULT_USER_ID - - self.user = self.server.user_manager.get_user_or_default(self.user_id) - self.organization = self.server.get_organization_or_default(self.org_id) - - # agents - def list_agents( - self, - query_text: Optional[str] = None, - tags: Optional[List[str]] = None, - limit: int = 100, - before: Optional[str] = None, - after: Optional[str] = None, - ) -> List[AgentState]: - self.interface.clear() - - return self.server.agent_manager.list_agents( - actor=self.user, tags=tags, query_text=query_text, limit=limit, before=before, after=after - ) - - def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool: - """ - Check if an agent exists - - Args: - agent_id (str): ID of the agent - agent_name (str): Name of the agent - - Returns: - exists (bool): `True` if the agent exists, `False` otherwise - """ - - if not (agent_id or agent_name): - raise ValueError(f"Either agent_id or agent_name must be provided") - if agent_id and agent_name: - raise ValueError(f"Only one of agent_id or agent_name can be provided") - existing = self.list_agents() - if agent_id: - return str(agent_id) in [str(agent.id) for agent in existing] - else: - return agent_name in [str(agent.name) for agent in existing] - - def create_agent( - self, - name: Optional[str] = None, - # agent config - agent_type: Optional[AgentType] = AgentType.memgpt_agent, - # model configs - embedding_config: EmbeddingConfig = None, - llm_config: LLMConfig = None, - # memory - memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)), - block_ids: Optional[List[str]] = None, - # TODO: change to this when we are ready to migrate all the tests/examples (matches the REST API) - # memory_blocks=[ - # {"label": "human", "value": get_human_text(DEFAULT_HUMAN), "limit": 5000}, - # {"label": "persona", "value": get_persona_text(DEFAULT_PERSONA), "limit": 5000}, - # ], - # system - system: Optional[str] = None, - # tools - tool_ids: Optional[List[str]] = None, - tool_rules: Optional[List[BaseToolRule]] = None, - include_base_tools: Optional[bool] = True, - include_multi_agent_tools: bool = False, - include_base_tool_rules: bool = True, - # metadata - metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, - description: Optional[str] = None, - initial_message_sequence: Optional[List[Message]] = None, - tags: Optional[List[str]] = None, - message_buffer_autoclear: bool = False, - response_format: Optional[ResponseFormatUnion] = None, - ) -> AgentState: - """Create an agent - - Args: - name (str): Name of the agent - embedding_config (EmbeddingConfig): Embedding configuration - llm_config (LLMConfig): LLM configuration - memory_blocks (List[Dict]): List of configurations for the memory blocks (placed in core-memory) - system (str): System configuration - tools (List[str]): List of tools - tool_rules (Optional[List[BaseToolRule]]): List of tool rules - include_base_tools (bool): Include base tools - include_multi_agent_tools (bool): Include multi agent tools - metadata (Dict): Metadata - description (str): Description - tags (List[str]): Tags for filtering agents - - Returns: - agent_state (AgentState): State of the created agent - """ - # construct list of tools - tool_ids = tool_ids or [] - - # check if default configs are provided - assert embedding_config or self._default_embedding_config, f"Embedding config must be provided" - assert llm_config or self._default_llm_config, f"LLM config must be provided" - - # TODO: This should not happen here, we need to have clear separation between create/add blocks - for block in memory.get_blocks(): - self.server.block_manager.create_or_update_block(block, actor=self.user) - - # Also get any existing block_ids passed in - block_ids = block_ids or [] - - # create agent - # Create the base parameters - create_params = { - "description": description, - "metadata": metadata, - "memory_blocks": [], - "block_ids": [b.id for b in memory.get_blocks()] + block_ids, - "tool_ids": tool_ids, - "tool_rules": tool_rules, - "include_base_tools": include_base_tools, - "include_multi_agent_tools": include_multi_agent_tools, - "include_base_tool_rules": include_base_tool_rules, - "system": system, - "agent_type": agent_type, - "llm_config": llm_config if llm_config else self._default_llm_config, - "embedding_config": embedding_config if embedding_config else self._default_embedding_config, - "initial_message_sequence": initial_message_sequence, - "tags": tags, - "message_buffer_autoclear": message_buffer_autoclear, - "response_format": response_format, - } - - # Only add name if it's not None - if name is not None: - create_params["name"] = name - - agent_state = self.server.create_agent( - CreateAgent(**create_params), - actor=self.user, - ) - - # TODO: get full agent state - return self.server.agent_manager.get_agent_by_id(agent_state.id, actor=self.user) - - def update_agent( - self, - agent_id: str, - name: Optional[str] = None, - description: Optional[str] = None, - system: Optional[str] = None, - tool_ids: Optional[List[str]] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict] = None, - llm_config: Optional[LLMConfig] = None, - embedding_config: Optional[EmbeddingConfig] = None, - message_ids: Optional[List[str]] = None, - response_format: Optional[ResponseFormatUnion] = None, - ): - """ - Update an existing agent - - Args: - agent_id (str): ID of the agent - name (str): Name of the agent - description (str): Description of the agent - system (str): System configuration - tools (List[str]): List of tools - metadata (Dict): Metadata - llm_config (LLMConfig): LLM configuration - embedding_config (EmbeddingConfig): Embedding configuration - message_ids (List[str]): List of message IDs - tags (List[str]): Tags for filtering agents - - Returns: - agent_state (AgentState): State of the updated agent - """ - # TODO: add the ability to reset linked block_ids - self.interface.clear() - agent_state = self.server.agent_manager.update_agent( - agent_id, - UpdateAgent( - name=name, - system=system, - tool_ids=tool_ids, - tags=tags, - description=description, - metadata=metadata, - llm_config=llm_config, - embedding_config=embedding_config, - message_ids=message_ids, - response_format=response_format, - ), - actor=self.user, - ) - return agent_state - - def get_tools_from_agent(self, agent_id: str) -> List[Tool]: - """ - Get tools from an existing agent. - - Args: - agent_id (str): ID of the agent - - Returns: - List[Tool]: A list of Tool objs - """ - self.interface.clear() - return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user).tools - - def attach_tool(self, agent_id: str, tool_id: str) -> AgentState: - """ - Add tool to an existing agent - - Args: - agent_id (str): ID of the agent - tool_id (str): A tool id - - Returns: - agent_state (AgentState): State of the updated agent - """ - self.interface.clear() - agent_state = self.server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user) - return agent_state - - def detach_tool(self, agent_id: str, tool_id: str) -> AgentState: - """ - Removes tools from an existing agent - - Args: - agent_id (str): ID of the agent - tool_id (str): The tool id - - Returns: - agent_state (AgentState): State of the updated agent - """ - self.interface.clear() - agent_state = self.server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user) - return agent_state - - def rename_agent(self, agent_id: str, new_name: str) -> AgentState: - """ - Rename an agent - - Args: - agent_id (str): ID of the agent - new_name (str): New name for the agent - - Returns: - agent_state (AgentState): State of the updated agent - """ - return self.update_agent(agent_id, name=new_name) - - def delete_agent(self, agent_id: str) -> None: - """ - Delete an agent - - Args: - agent_id (str): ID of the agent to delete - """ - self.server.agent_manager.delete_agent(agent_id=agent_id, actor=self.user) - - def get_agent_by_name(self, agent_name: str) -> AgentState: - """ - Get an agent by its name - - Args: - agent_name (str): Name of the agent - - Returns: - agent_state (AgentState): State of the agent - """ - self.interface.clear() - return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user) - - def get_agent(self, agent_id: str) -> AgentState: - """ - Get an agent's state by its ID. - - Args: - agent_id (str): ID of the agent - - Returns: - agent_state (AgentState): State representation of the agent - """ - self.interface.clear() - return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user) - - def get_agent_id(self, agent_name: str) -> Optional[str]: - """ - Get the ID of an agent by name (names are unique per user) - - Args: - agent_name (str): Name of the agent - - Returns: - agent_id (str): ID of the agent - """ - - self.interface.clear() - assert agent_name, f"Agent name must be provided" - - # TODO: Refactor this futher to not have downstream users expect Optionals - this should just error - try: - return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user).id - except NoResultFound: - return None - - # memory - def get_in_context_memory(self, agent_id: str) -> Memory: - """ - Get the in-context (i.e. core) memory of an agent - - Args: - agent_id (str): ID of the agent - - Returns: - memory (Memory): In-context memory of the agent - """ - memory = self.server.get_agent_memory(agent_id=agent_id, actor=self.user) - return memory - - def get_core_memory(self, agent_id: str) -> Memory: - return self.get_in_context_memory(agent_id) - - def update_in_context_memory(self, agent_id: str, section: str, value: Union[List[str], str]) -> Memory: - """ - Update the in-context memory of an agent - - Args: - agent_id (str): ID of the agent - - Returns: - memory (Memory): The updated in-context memory of the agent - - """ - # TODO: implement this (not sure what it should look like) - memory = self.server.update_agent_core_memory(agent_id=agent_id, label=section, value=value, actor=self.user) - return memory - - def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary: - """ - Get a summary of the archival memory of an agent - - Args: - agent_id (str): ID of the agent - - Returns: - summary (ArchivalMemorySummary): Summary of the archival memory - - """ - return self.server.get_archival_memory_summary(agent_id=agent_id, actor=self.user) - - def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary: - """ - Get a summary of the recall memory of an agent - - Args: - agent_id (str): ID of the agent - - Returns: - summary (RecallMemorySummary): Summary of the recall memory - """ - return self.server.get_recall_memory_summary(agent_id=agent_id, actor=self.user) - - def get_in_context_messages(self, agent_id: str) -> List[Message]: - """ - Get in-context messages of an agent - - Args: - agent_id (str): ID of the agent - - Returns: - messages (List[Message]): List of in-context messages - """ - return self.server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=self.user) - - # agent interactions - - def send_messages( - self, - agent_id: str, - messages: List[Union[Message | MessageCreate]], - ): - """ - Send pre-packed messages to an agent. - - Args: - agent_id (str): ID of the agent - messages (List[Union[Message | MessageCreate]]): List of messages to send - - Returns: - response (LettaResponse): Response from the agent - """ - self.interface.clear() - usage = self.server.send_messages(actor=self.user, agent_id=agent_id, input_messages=messages) - - # format messages - return LettaResponse(messages=messages, usage=usage) - - def send_message( - self, - message: str, - role: str, - name: Optional[str] = None, - agent_id: Optional[str] = None, - agent_name: Optional[str] = None, - stream_steps: bool = False, - stream_tokens: bool = False, - ) -> LettaResponse: - """ - Send a message to an agent - - Args: - message (str): Message to send - role (str): Role of the message - agent_id (str): ID of the agent - name(str): Name of the sender - stream (bool): Stream the response (default: `False`) - - Returns: - response (LettaResponse): Response from the agent - """ - if not agent_id: - # lookup agent by name - assert agent_name, f"Either agent_id or agent_name must be provided" - agent_id = self.get_agent_id(agent_name=agent_name) - assert agent_id, f"Agent with name {agent_name} not found" - - if stream_steps or stream_tokens: - # TODO: implement streaming with stream=True/False - raise NotImplementedError - self.interface.clear() - - usage = self.server.send_messages( - actor=self.user, - agent_id=agent_id, - input_messages=[MessageCreate(role=MessageRole(role), content=message, name=name)], - ) - - ## TODO: need to make sure date/timestamp is propely passed - ## TODO: update self.interface.to_list() to return actual Message objects - ## here, the message objects will have faulty created_by timestamps - # messages = self.interface.to_list() - # for m in messages: - # assert isinstance(m, Message), f"Expected Message object, got {type(m)}" - # letta_messages = [] - # for m in messages: - # letta_messages += m.to_letta_messages() - # return LettaResponse(messages=letta_messages, usage=usage) - - # format messages - messages = self.interface.to_list() - letta_messages = [] - for m in messages: - letta_messages += m.to_letta_messages() - - return LettaResponse(messages=letta_messages, usage=usage) - - def user_message(self, agent_id: str, message: str) -> LettaResponse: - """ - Send a message to an agent as a user - - Args: - agent_id (str): ID of the agent - message (str): Message to send - - Returns: - response (LettaResponse): Response from the agent - """ - self.interface.clear() - return self.send_message(role="user", agent_id=agent_id, message=message) - - def run_command(self, agent_id: str, command: str) -> LettaResponse: - """ - Run a command on the agent - - Args: - agent_id (str): The agent ID - command (str): The command to run - - Returns: - LettaResponse: The response from the agent - - """ - self.interface.clear() - usage = self.server.run_command(user_id=self.user_id, agent_id=agent_id, command=command) - - # NOTE: messages/usage may be empty, depending on the command - return LettaResponse(messages=self.interface.to_list(), usage=usage) - - # archival memory - - # humans / personas - - def get_block_id(self, name: str, label: str) -> str | None: - return None - - def create_human(self, name: str, text: str): - """ - Create a human block template (saved human string to pre-fill `ChatMemory`) - - Args: - name (str): Name of the human block - text (str): Text of the human block - - Returns: - human (Human): Human block - """ - return self.server.block_manager.create_or_update_block(Human(template_name=name, value=text), actor=self.user) - - def create_persona(self, name: str, text: str): - """ - Create a persona block template (saved persona string to pre-fill `ChatMemory`) - - Args: - name (str): Name of the persona block - text (str): Text of the persona block - - Returns: - persona (Persona): Persona block - """ - return self.server.block_manager.create_or_update_block(Persona(template_name=name, value=text), actor=self.user) - - def list_humans(self): - """ - List available human block templates - - Returns: - humans (List[Human]): List of human blocks - """ - return [] - - def list_personas(self) -> List[Persona]: - """ - List available persona block templates - - Returns: - personas (List[Persona]): List of persona blocks - """ - return [] - - def update_human(self, human_id: str, text: str): - """ - Update a human block template - - Args: - human_id (str): ID of the human block - text (str): Text of the human block - - Returns: - human (Human): Updated human block - """ - return self.server.block_manager.update_block( - block_id=human_id, block_update=UpdateHuman(value=text, is_template=True), actor=self.user - ) - - def update_persona(self, persona_id: str, text: str): - """ - Update a persona block template - - Args: - persona_id (str): ID of the persona block - text (str): Text of the persona block - - Returns: - persona (Persona): Updated persona block - """ - return self.server.block_manager.update_block( - block_id=persona_id, block_update=UpdatePersona(value=text, is_template=True), actor=self.user - ) - - def get_persona(self, id: str) -> Persona: - """ - Get a persona block template - - Args: - id (str): ID of the persona block - - Returns: - persona (Persona): Persona block - """ - assert id, f"Persona ID must be provided" - return Persona(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump()) - - def get_human(self, id: str) -> Human: - """ - Get a human block template - - Args: - id (str): ID of the human block - - Returns: - human (Human): Human block - """ - assert id, f"Human ID must be provided" - return Human(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump()) - - def get_persona_id(self, name: str) -> str | None: - """ - Get the ID of a persona block template - - Args: - name (str): Name of the persona block - - Returns: - id (str): ID of the persona block - """ - return None - - def get_human_id(self, name: str) -> str | None: - """ - Get the ID of a human block template - - Args: - name (str): Name of the human block - - Returns: - id (str): ID of the human block - """ - return None - - def delete_persona(self, id: str): - """ - Delete a persona block template - - Args: - id (str): ID of the persona block - """ - self.delete_block(id) - - def delete_human(self, id: str): - """ - Delete a human block template - - Args: - id (str): ID of the human block - """ - self.delete_block(id) - - # tools - def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool: - tool_create = ToolCreate.from_langchain( - langchain_tool=langchain_tool, - additional_imports_module_attr_map=additional_imports_module_attr_map, - ) - return self.server.tool_manager.create_or_update_langchain_tool(tool_create=tool_create, actor=self.user) - - def load_composio_tool(self, action: "ActionType") -> Tool: - tool_create = ToolCreate.from_composio(action_name=action.name) - return self.server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=self.user) - - def create_tool( - self, - func, - tags: Optional[List[str]] = None, - description: Optional[str] = None, - return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, - ) -> Tool: - """ - Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent. - - Args: - func (callable): The function to create a tool for. - tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. - description (str, optional): The description. - return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. - - Returns: - tool (Tool): The created tool. - """ - # TODO: check if tool already exists - # TODO: how to load modules? - # parse source code/schema - source_code = parse_source_code(func) - source_type = "python" - name = func.__name__ # Initialize name using function's __name__ - if not tags: - tags = [] - - # call server function - return self.server.tool_manager.create_tool( - Tool( - source_type=source_type, - source_code=source_code, - name=name, - tags=tags, - description=description, - return_char_limit=return_char_limit, - ), - actor=self.user, - ) - - def create_or_update_tool( - self, - func, - tags: Optional[List[str]] = None, - description: Optional[str] = None, - return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, - ) -> Tool: - """ - Creates or updates a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent. - - Args: - func (callable): The function to create a tool for. - tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. - description (str, optional): The description. - return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. - - Returns: - tool (Tool): The created tool. - """ - source_code = parse_source_code(func) - source_type = "python" - if not tags: - tags = [] - - # call server function - return self.server.tool_manager.create_or_update_tool( - Tool( - source_type=source_type, - source_code=source_code, - tags=tags, - description=description, - return_char_limit=return_char_limit, - ), - actor=self.user, - ) - - def update_tool( - self, - id: str, - description: Optional[str] = None, - func: Optional[Callable] = None, - tags: Optional[List[str]] = None, - return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, - ) -> Tool: - """ - Update a tool with provided parameters (name, func, tags) - - Args: - id (str): ID of the tool - func (callable): Function to wrap in a tool - tags (List[str]): Tags for the tool - return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. - - Returns: - tool (Tool): Updated tool - """ - update_data = { - "source_type": "python", # Always include source_type - "source_code": parse_source_code(func) if func else None, - "tags": tags, - "description": description, - "return_char_limit": return_char_limit, - } - - # Filter out any None values from the dictionary - update_data = {key: value for key, value in update_data.items() if value is not None} - - return self.server.tool_manager.update_tool_by_id(tool_id=id, tool_update=ToolUpdate(**update_data), actor=self.user) - - def list_tools(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]: - """ - List available tools for the user. - - Returns: - tools (List[Tool]): List of tools - """ - # Get the current event loop or create a new one if there isn't one - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - # We're in an async context but can't await - use a new loop via run_coroutine_threadsafe - concurrent_future = asyncio.run_coroutine_threadsafe( - self.server.tool_manager.list_tools_async(actor=self.user, after=after, limit=limit), loop - ) - return concurrent_future.result() - else: - # We have a loop but it's not running - we can just run the coroutine - return loop.run_until_complete(self.server.tool_manager.list_tools_async(actor=self.user, after=after, limit=limit)) - except RuntimeError: - # No running event loop - create a new one with asyncio.run - return asyncio.run(self.server.tool_manager.list_tools_async(actor=self.user, after=after, limit=limit)) - - def get_tool(self, id: str) -> Optional[Tool]: - """ - Get a tool given its ID. - - Args: - id (str): ID of the tool - - Returns: - tool (Tool): Tool - """ - return self.server.tool_manager.get_tool_by_id(id, actor=self.user) - - def delete_tool(self, id: str): - """ - Delete a tool given the ID. - - Args: - id (str): ID of the tool - """ - return self.server.tool_manager.delete_tool_by_id(id, actor=self.user) - - def get_tool_id(self, name: str) -> Optional[str]: - """ - Get the ID of a tool from its name. The client will use the org_id it is configured with. - - Args: - name (str): Name of the tool - - Returns: - id (str): ID of the tool (`None` if not found) - """ - tool = self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user) - return tool.id if tool else None - - def list_attached_tools(self, agent_id: str) -> List[Tool]: - """ - List all tools attached to an agent. - - Args: - agent_id (str): ID of the agent - - Returns: - List[Tool]: List of tools attached to the agent - """ - return self.server.agent_manager.list_attached_tools(agent_id=agent_id, actor=self.user) - - def load_data(self, connector: DataConnector, source_name: str): - """ - Load data into a source - - Args: - connector (DataConnector): Data connector - source_name (str): Name of the source - """ - self.server.load_data(user_id=self.user_id, connector=connector, source_name=source_name) - - def load_file_to_source(self, filename: str, source_id: str, blocking=True): - """ - Load a file into a source - - Args: - filename (str): Name of the file - source_id (str): ID of the source - blocking (bool): Block until the job is complete - - Returns: - job (Job): Data loading job including job status and metadata - """ - job = Job( - user_id=self.user_id, - status=JobStatus.created, - metadata={"type": "embedding", "filename": filename, "source_id": source_id}, - ) - job = self.server.job_manager.create_job(pydantic_job=job, actor=self.user) - - # TODO: implement blocking vs. non-blocking - self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id, actor=self.user) - return job - - def delete_file_from_source(self, source_id: str, file_id: str) -> None: - self.server.source_manager.delete_file(file_id, actor=self.user) - - def get_job(self, job_id: str): - return self.server.job_manager.get_job_by_id(job_id=job_id, actor=self.user) - - def delete_job(self, job_id: str): - return self.server.job_manager.delete_job_by_id(job_id=job_id, actor=self.user) - - def list_jobs(self): - return self.server.job_manager.list_jobs(actor=self.user) - - def list_active_jobs(self): - return self.server.job_manager.list_jobs(actor=self.user, statuses=[JobStatus.created, JobStatus.running]) - - def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source: - """ - Create a source - - Args: - name (str): Name of the source - - Returns: - source (Source): Created source - """ - assert embedding_config or self._default_embedding_config, f"Must specify embedding_config for source" - source = Source( - name=name, embedding_config=embedding_config or self._default_embedding_config, organization_id=self.user.organization_id - ) - return self.server.source_manager.create_source(source=source, actor=self.user) - - def delete_source(self, source_id: str): - """ - Delete a source - - Args: - source_id (str): ID of the source - """ - - # TODO: delete source data - self.server.delete_source(source_id=source_id, actor=self.user) - - def get_source(self, source_id: str) -> Source: - """ - Get a source given the ID. - - Args: - source_id (str): ID of the source - - Returns: - source (Source): Source - """ - return self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user) - - def get_source_id(self, source_name: str) -> str: - """ - Get the ID of a source - - Args: - source_name (str): Name of the source - - Returns: - source_id (str): ID of the source - """ - return self.server.source_manager.get_source_by_name(source_name=source_name, actor=self.user).id - - def attach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState: - """ - Attach a source to an agent - - Args: - agent_id (str): ID of the agent - source_id (str): ID of the source - source_name (str): Name of the source - """ - if source_name: - source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user) - source_id = source.id - - return self.server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=self.user) - - def detach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState: - """ - Detach a source from an agent by removing all `Passage` objects that were loaded from the source from archival memory. - Args: - agent_id (str): ID of the agent - source_id (str): ID of the source - source_name (str): Name of the source - Returns: - source (Source): Detached source - """ - if source_name: - source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user) - source_id = source.id - return self.server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=self.user) - - def list_sources(self) -> List[Source]: - """ - List available sources - - Returns: - sources (List[Source]): List of sources - """ - - return self.server.list_all_sources(actor=self.user) - - def list_attached_sources(self, agent_id: str) -> List[Source]: - """ - List sources attached to an agent - - Args: - agent_id (str): ID of the agent - - Returns: - sources (List[Source]): List of sources - """ - return self.server.agent_manager.list_attached_sources(agent_id=agent_id, actor=self.user) - - def list_files_from_source(self, source_id: str, limit: int = 1000, after: Optional[str] = None) -> List[FileMetadata]: - """ - List files from source. - - Args: - source_id (str): ID of the source - limit (int): The # of items to return - after (str): The cursor for fetching the next page - - Returns: - files (List[FileMetadata]): List of files - """ - return self.server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=self.user) - - def update_source(self, source_id: str, name: Optional[str] = None) -> Source: - """ - Update a source - - Args: - source_id (str): ID of the source - name (str): Name of the source - - Returns: - source (Source): Updated source - """ - # TODO should the arg here just be "source_update: Source"? - request = SourceUpdate(name=name) - return self.server.source_manager.update_source(source_id=source_id, source_update=request, actor=self.user) - - # archival memory - - def insert_archival_memory(self, agent_id: str, memory: str) -> List[Passage]: - """ - Insert archival memory into an agent - - Args: - agent_id (str): ID of the agent - memory (str): Memory string to insert - - Returns: - passages (List[Passage]): List of inserted passages - """ - return self.server.insert_archival_memory(agent_id=agent_id, memory_contents=memory, actor=self.user) - - def delete_archival_memory(self, agent_id: str, memory_id: str): - """ - Delete archival memory from an agent - - Args: - agent_id (str): ID of the agent - memory_id (str): ID of the memory - """ - self.server.delete_archival_memory(memory_id=memory_id, actor=self.user) - - def get_archival_memory( - self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000 - ) -> List[Passage]: - """ - Get archival memory from an agent with pagination. - - Args: - agent_id (str): ID of the agent - before (str): Get memories before a certain time - after (str): Get memories after a certain time - limit (int): Limit number of memories - - Returns: - passages (List[Passage]): List of passages - """ - - return self.server.get_agent_archival(user_id=self.user_id, agent_id=agent_id, limit=limit) - - # recall memory - - def get_messages( - self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000 - ) -> List[LettaMessage]: - """ - Get messages from an agent with pagination. - - Args: - agent_id (str): ID of the agent - before (str): Get messages before a certain time - after (str): Get messages after a certain time - limit (int): Limit number of messages - - Returns: - messages (List[Message]): List of messages - """ - - self.interface.clear() - return self.server.get_agent_recall( - user_id=self.user_id, - agent_id=agent_id, - before=before, - after=after, - limit=limit, - reverse=True, - return_message_object=False, - ) - - def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]: - """ - List available blocks - - Args: - label (str): Label of the block - templates_only (bool): List only templates - - Returns: - blocks (List[Block]): List of blocks - """ - return [] - - def create_block( - self, label: str, value: str, limit: Optional[int] = None, template_name: Optional[str] = None, is_template: bool = False - ) -> Block: # - """ - Create a block - - Args: - label (str): Label of the block - name (str): Name of the block - text (str): Text of the block - limit (int): Character of the block - - Returns: - block (Block): Created block - """ - block = Block(label=label, template_name=template_name, value=value, is_template=is_template) - if limit: - block.limit = limit - return self.server.block_manager.create_or_update_block(block, actor=self.user) - - def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None, limit: Optional[int] = None) -> Block: - """ - Update a block - - Args: - block_id (str): ID of the block - name (str): Name of the block - text (str): Text of the block - - Returns: - block (Block): Updated block - """ - return self.server.block_manager.update_block( - block_id=block_id, - block_update=BlockUpdate(template_name=name, value=text, limit=limit if limit else self.get_block(block_id).limit), - actor=self.user, - ) - - def get_block(self, block_id: str) -> Block: - """ - Get a block - - Args: - block_id (str): ID of the block - - Returns: - block (Block): Block - """ - return self.server.block_manager.get_block_by_id(block_id, actor=self.user) - - def delete_block(self, id: str) -> Block: - """ - Delete a block - - Args: - id (str): ID of the block - - Returns: - block (Block): Deleted block - """ - return self.server.block_manager.delete_block(id, actor=self.user) - - def set_default_llm_config(self, llm_config: LLMConfig): - """ - Set the default LLM configuration for agents. - - Args: - llm_config (LLMConfig): LLM configuration - """ - self._default_llm_config = llm_config - - def set_default_embedding_config(self, embedding_config: EmbeddingConfig): - """ - Set the default embedding configuration for agents. - - Args: - embedding_config (EmbeddingConfig): Embedding configuration - """ - self._default_embedding_config = embedding_config - - def list_llm_configs(self) -> List[LLMConfig]: - """ - List available LLM configurations - - Returns: - configs (List[LLMConfig]): List of LLM configurations - """ - return self.server.list_llm_models(actor=self.user) - - def list_embedding_configs(self) -> List[EmbeddingConfig]: - """ - List available embedding configurations - - Returns: - configs (List[EmbeddingConfig]): List of embedding configurations - """ - return self.server.list_embedding_models(actor=self.user) - - def create_org(self, name: Optional[str] = None) -> Organization: - return self.server.organization_manager.create_organization(pydantic_org=Organization(name=name)) - - def list_orgs(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]: - return self.server.organization_manager.list_organizations(limit=limit, after=after) - - def delete_org(self, org_id: str) -> Organization: - return self.server.organization_manager.delete_organization_by_id(org_id=org_id) - - def create_sandbox_config(self, config: Union[LocalSandboxConfig, E2BSandboxConfig]) -> SandboxConfig: - """ - Create a new sandbox configuration. - """ - config_create = SandboxConfigCreate(config=config) - return self.server.sandbox_config_manager.create_or_update_sandbox_config(sandbox_config_create=config_create, actor=self.user) - - def update_sandbox_config(self, sandbox_config_id: str, config: Union[LocalSandboxConfig, E2BSandboxConfig]) -> SandboxConfig: - """ - Update an existing sandbox configuration. - """ - sandbox_update = SandboxConfigUpdate(config=config) - return self.server.sandbox_config_manager.update_sandbox_config( - sandbox_config_id=sandbox_config_id, sandbox_update=sandbox_update, actor=self.user - ) - - def delete_sandbox_config(self, sandbox_config_id: str) -> None: - """ - Delete a sandbox configuration. - """ - return self.server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id=sandbox_config_id, actor=self.user) - - def list_sandbox_configs(self, limit: int = 50, after: Optional[str] = None) -> List[SandboxConfig]: - """ - List all sandbox configurations. - """ - return self.server.sandbox_config_manager.list_sandbox_configs(actor=self.user, limit=limit, after=after) - - def create_sandbox_env_var( - self, sandbox_config_id: str, key: str, value: str, description: Optional[str] = None - ) -> SandboxEnvironmentVariable: - """ - Create a new environment variable for a sandbox configuration. - """ - env_var_create = SandboxEnvironmentVariableCreate(key=key, value=value, description=description) - return self.server.sandbox_config_manager.create_sandbox_env_var( - env_var_create=env_var_create, sandbox_config_id=sandbox_config_id, actor=self.user - ) - - def update_sandbox_env_var( - self, env_var_id: str, key: Optional[str] = None, value: Optional[str] = None, description: Optional[str] = None - ) -> SandboxEnvironmentVariable: - """ - Update an existing environment variable. - """ - env_var_update = SandboxEnvironmentVariableUpdate(key=key, value=value, description=description) - return self.server.sandbox_config_manager.update_sandbox_env_var( - env_var_id=env_var_id, env_var_update=env_var_update, actor=self.user - ) - - def delete_sandbox_env_var(self, env_var_id: str) -> None: - """ - Delete an environment variable by its ID. - """ - return self.server.sandbox_config_manager.delete_sandbox_env_var(env_var_id=env_var_id, actor=self.user) - - def list_sandbox_env_vars( - self, sandbox_config_id: str, limit: int = 50, after: Optional[str] = None - ) -> List[SandboxEnvironmentVariable]: - """ - List all environment variables associated with a sandbox configuration. - """ - return self.server.sandbox_config_manager.list_sandbox_env_vars( - sandbox_config_id=sandbox_config_id, actor=self.user, limit=limit, after=after - ) - - def update_agent_memory_block_label(self, agent_id: str, current_label: str, new_label: str) -> Memory: - """Rename a block in the agent's core memory - - Args: - agent_id (str): The agent ID - current_label (str): The current label of the block - new_label (str): The new label of the block - - Returns: - memory (Memory): The updated memory - """ - block = self.get_agent_memory_block(agent_id, current_label) - return self.update_block(block.id, label=new_label) - - def get_agent_memory_blocks(self, agent_id: str) -> List[Block]: - """ - Get all the blocks in the agent's core memory - - Args: - agent_id (str): The agent ID - - Returns: - blocks (List[Block]): The blocks in the agent's core memory - """ - agent = self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user) - return agent.memory.blocks - - def get_agent_memory_block(self, agent_id: str, label: str) -> Block: - """ - Get a block in the agent's core memory by its label - - Args: - agent_id (str): The agent ID - label (str): The label in the agent's core memory - - Returns: - block (Block): The block corresponding to the label - """ - return self.server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=label, actor=self.user) - - def update_agent_memory_block( - self, - agent_id: str, - label: str, - value: Optional[str] = None, - limit: Optional[int] = None, - ): - """ - Update a block in the agent's core memory by specifying its label - - Args: - agent_id (str): The agent ID - label (str): The label of the block - value (str): The new value of the block - limit (int): The new limit of the block - - Returns: - block (Block): The updated block - """ - block = self.get_agent_memory_block(agent_id, label) - data = {} - if value: - data["value"] = value - if limit: - data["limit"] = limit - return self.server.block_manager.update_block(block.id, actor=self.user, block_update=BlockUpdate(**data)) - - def update_block( - self, - block_id: str, - label: Optional[str] = None, - value: Optional[str] = None, - limit: Optional[int] = None, - ): - """ - Update a block given the ID with the provided fields - - Args: - block_id (str): ID of the block - label (str): Label to assign to the block - value (str): Value to assign to the block - limit (int): Token limit to assign to the block - - Returns: - block (Block): Updated block - """ - data = {} - if value: - data["value"] = value - if limit: - data["limit"] = limit - if label: - data["label"] = label - return self.server.block_manager.update_block(block_id, actor=self.user, block_update=BlockUpdate(**data)) - - def attach_block(self, agent_id: str, block_id: str) -> AgentState: - """ - Attach a block to an agent. - - Args: - agent_id (str): ID of the agent - block_id (str): ID of the block to attach - """ - return self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=self.user) - - def detach_block(self, agent_id: str, block_id: str) -> AgentState: - """ - Detach a block from an agent. - - Args: - agent_id (str): ID of the agent - block_id (str): ID of the block to detach - """ - return self.server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=self.user) - - def get_run_messages( - self, - run_id: str, - before: Optional[str] = None, - after: Optional[str] = None, - limit: Optional[int] = 100, - ascending: bool = True, - role: Optional[MessageRole] = None, - ) -> List[LettaMessageUnion]: - """ - Get messages associated with a job with filtering options. - - Args: - run_id: ID of the run - before: Cursor for pagination - after: Cursor for pagination - limit: Maximum number of messages to return - ascending: Sort order by creation time - role: Filter by message role (user/assistant/system/tool) - Returns: - List of messages matching the filter criteria - """ - params = { - "before": before, - "after": after, - "limit": limit, - "ascending": ascending, - "role": role, - } - - return self.server.job_manager.get_run_messages(run_id=run_id, actor=self.user, **params) - - def get_run_usage( - self, - run_id: str, - ) -> List[UsageStatistics]: - """ - Get usage statistics associated with a job. - - Args: - run_id (str): ID of the run - - Returns: - List[UsageStatistics]: List of usage statistics associated with the run - """ - usage = self.server.job_manager.get_job_usage(job_id=run_id, actor=self.user) - return [ - UsageStatistics(completion_tokens=stat.completion_tokens, prompt_tokens=stat.prompt_tokens, total_tokens=stat.total_tokens) - for stat in usage - ] - - def get_run(self, run_id: str) -> Run: - """ - Get a run by ID. - - Args: - run_id (str): ID of the run - - Returns: - run (Run): Run - """ - return self.server.job_manager.get_job_by_id(job_id=run_id, actor=self.user) - - def delete_run(self, run_id: str) -> None: - """ - Delete a run by ID. - - Args: - run_id (str): ID of the run - """ - return self.server.job_manager.delete_job_by_id(job_id=run_id, actor=self.user) - - def list_runs(self) -> List[Run]: - """ - List all runs. - - Returns: - runs (List[Run]): List of runs - """ - return self.server.job_manager.list_jobs(actor=self.user, job_type=JobType.RUN) - - def list_active_runs(self) -> List[Run]: - """ - List all active runs. - - Returns: - runs (List[Run]): List of active runs - """ - return self.server.job_manager.list_jobs(actor=self.user, job_type=JobType.RUN, statuses=[JobStatus.created, JobStatus.running]) - - def get_tags( - self, - after: Optional[str] = None, - limit: Optional[int] = None, - query_text: Optional[str] = None, - ) -> List[str]: - """ - Get all tags. - - Returns: - tags (List[str]): List of tags - """ - return self.server.agent_manager.list_tags(actor=self.user, after=after, limit=limit, query_text=query_text) diff --git a/letta/functions/ast_parsers.py b/letta/functions/ast_parsers.py index acc73bef..3113cd96 100644 --- a/letta/functions/ast_parsers.py +++ b/letta/functions/ast_parsers.py @@ -1,5 +1,7 @@ import ast +import builtins import json +import typing from typing import Dict, Optional, Tuple from letta.errors import LettaToolCreateError @@ -22,7 +24,7 @@ def resolve_type(annotation: str): Resolve a type annotation string into a Python type. Args: - annotation (str): The annotation string (e.g., 'int', 'list', etc.). + annotation (str): The annotation string (e.g., 'int', 'list[int]', 'dict[str, int]'). Returns: type: The corresponding Python type. @@ -34,11 +36,17 @@ def resolve_type(annotation: str): return BUILTIN_TYPES[annotation] try: - parsed = ast.literal_eval(annotation) - if isinstance(parsed, type): - return parsed - raise ValueError(f"Annotation '{annotation}' is not a recognized type.") - except (ValueError, SyntaxError): + # Allow use of typing and builtins in a safe eval context + namespace = { + **vars(typing), + **vars(builtins), + "list": list, + "dict": dict, + "tuple": tuple, + "set": set, + } + return eval(annotation, namespace) + except Exception: raise ValueError(f"Unsupported annotation: {annotation}") @@ -69,41 +77,36 @@ def get_function_annotations_from_source(source_code: str, function_name: str) - def coerce_dict_args_by_annotations(function_args: dict, annotations: Dict[str, str]) -> dict: - """ - Coerce arguments in a dictionary to their annotated types. - - Args: - function_args (dict): The original function arguments. - annotations (Dict[str, str]): Argument annotations as strings. - - Returns: - dict: The updated dictionary with coerced argument types. - - Raises: - ValueError: If type coercion fails for an argument. - """ - coerced_args = dict(function_args) # Shallow copy for mutation safety + coerced_args = dict(function_args) # Shallow copy for arg_name, value in coerced_args.items(): if arg_name in annotations: annotation_str = annotations[arg_name] try: - # Resolve the type from the annotation arg_type = resolve_type(annotation_str) - # Handle JSON-like inputs for dict and list types - if arg_type in {dict, list} and isinstance(value, str): + # Always parse strings using literal_eval or json if possible + if isinstance(value, str): try: - # First, try JSON parsing value = json.loads(value) except json.JSONDecodeError: - # Fall back to literal_eval for Python-specific literals - value = ast.literal_eval(value) + try: + value = ast.literal_eval(value) + except (SyntaxError, ValueError) as e: + if arg_type is not str: + raise ValueError(f"Failed to coerce argument '{arg_name}' to {annotation_str}: {e}") - # Coerce the value to the resolved type - coerced_args[arg_name] = arg_type(value) - except (TypeError, ValueError, json.JSONDecodeError, SyntaxError) as e: + origin = typing.get_origin(arg_type) + if origin in (list, dict, tuple, set): + # Let the origin (e.g., list) handle coercion + coerced_args[arg_name] = origin(value) + else: + # Coerce simple types (e.g., int, float) + coerced_args[arg_name] = arg_type(value) + + except Exception as e: raise ValueError(f"Failed to coerce argument '{arg_name}' to {annotation_str}: {e}") + return coerced_args diff --git a/letta/main.py b/letta/main.py index de1b4028..a64b3637 100644 --- a/letta/main.py +++ b/letta/main.py @@ -1,374 +1,14 @@ import os -import sys -import traceback -import questionary -import requests import typer -from rich.console import Console -import letta.agent as agent -import letta.errors as errors -import letta.system as system - -# import benchmark -from letta import create_client -from letta.benchmark.benchmark import bench -from letta.cli.cli import delete_agent, open_folder, run, server, version -from letta.cli.cli_config import add, add_tool, configure, delete, list, list_tools +from letta.cli.cli import server from letta.cli.cli_load import app as load_app -from letta.config import LettaConfig -from letta.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE - -# from letta.interface import CLIInterface as interface # for printing to terminal -from letta.streaming_interface import AgentRefreshStreamingInterface - -# interface = interface() # disable composio print on exit os.environ["COMPOSIO_DISABLE_VERSION_CHECK"] = "true" app = typer.Typer(pretty_exceptions_enable=False) -app.command(name="run")(run) -app.command(name="version")(version) -app.command(name="configure")(configure) -app.command(name="list")(list) -app.command(name="add")(add) -app.command(name="add-tool")(add_tool) -app.command(name="list-tools")(list_tools) -app.command(name="delete")(delete) app.command(name="server")(server) -app.command(name="folder")(open_folder) -# load data commands + app.add_typer(load_app, name="load") -# benchmark command -app.command(name="benchmark")(bench) -# delete agents -app.command(name="delete-agent")(delete_agent) - - -def clear_line(console, strip_ui=False): - if strip_ui: - return - if os.name == "nt": # for windows - console.print("\033[A\033[K", end="") - else: # for linux - sys.stdout.write("\033[2K\033[G") - sys.stdout.flush() - - -def run_agent_loop( - letta_agent: agent.Agent, - config: LettaConfig, - first: bool, - no_verify: bool = False, - strip_ui: bool = False, - stream: bool = False, -): - if isinstance(letta_agent.interface, AgentRefreshStreamingInterface): - # letta_agent.interface.toggle_streaming(on=stream) - if not stream: - letta_agent.interface = letta_agent.interface.nonstreaming_interface - - if hasattr(letta_agent.interface, "console"): - console = letta_agent.interface.console - else: - console = Console() - - counter = 0 - user_input = None - skip_next_user_input = False - user_message = None - USER_GOES_FIRST = first - - if not USER_GOES_FIRST: - console.input("[bold cyan]Hit enter to begin (will request first Letta message)[/bold cyan]\n") - clear_line(console, strip_ui=strip_ui) - print() - - multiline_input = False - - # create client - client = create_client() - - # run loops - while True: - if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST): - # Ask for user input - if not stream: - print() - user_input = questionary.text( - "Enter your message:", - multiline=multiline_input, - qmark=">", - ).ask() - clear_line(console, strip_ui=strip_ui) - if not stream: - print() - - # Gracefully exit on Ctrl-C/D - if user_input is None: - user_input = "/exit" - - user_input = user_input.rstrip() - - if user_input.startswith("!"): - print(f"Commands for CLI begin with '/' not '!'") - continue - - if user_input == "": - # no empty messages allowed - print("Empty input received. Try again!") - continue - - # Handle CLI commands - # Commands to not get passed as input to Letta - if user_input.startswith("/"): - # updated agent save functions - if user_input.lower() == "/exit": - # letta_agent.save() - agent.save_agent(letta_agent) - break - elif user_input.lower() == "/save" or user_input.lower() == "/savechat": - # letta_agent.save() - agent.save_agent(letta_agent) - continue - elif user_input.lower() == "/attach": - # TODO: check if agent already has it - - # TODO: check to ensure source embedding dimentions/model match agents, and disallow attachment if not - # TODO: alternatively, only list sources with compatible embeddings, and print warning about non-compatible sources - - sources = client.list_sources() - if len(sources) == 0: - typer.secho( - 'No sources available. You must load a souce with "letta load ..." before running /attach.', - fg=typer.colors.RED, - bold=True, - ) - continue - - # determine what sources are valid to be attached to this agent - valid_options = [] - invalid_options = [] - for source in sources: - if source.embedding_config == letta_agent.agent_state.embedding_config: - valid_options.append(source.name) - else: - # print warning about invalid sources - typer.secho( - f"Source {source.name} exists but has embedding dimentions {source.embedding_dim} from model {source.embedding_model}, while the agent uses embedding dimentions {letta_agent.agent_state.embedding_config.embedding_dim} and model {letta_agent.agent_state.embedding_config.embedding_model}", - fg=typer.colors.YELLOW, - ) - invalid_options.append(source.name) - - # prompt user for data source selection - data_source = questionary.select("Select data source", choices=valid_options).ask() - - # attach new data - client.attach_source_to_agent(agent_id=letta_agent.agent_state.id, source_name=data_source) - - continue - - elif user_input.lower() == "/dump" or user_input.lower().startswith("/dump "): - # Check if there's an additional argument that's an integer - command = user_input.strip().split() - amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0 - if amount == 0: - letta_agent.interface.print_messages(letta_agent._messages, dump=True) - else: - letta_agent.interface.print_messages(letta_agent._messages[-min(amount, len(letta_agent.messages)) :], dump=True) - continue - - elif user_input.lower() == "/dumpraw": - letta_agent.interface.print_messages_raw(letta_agent._messages) - continue - - elif user_input.lower() == "/memory": - print(f"\nDumping memory contents:\n") - print(f"{letta_agent.agent_state.memory.compile()}") - print(f"{letta_agent.archival_memory.compile()}") - continue - - elif user_input.lower() == "/model": - print(f"Current model: {letta_agent.agent_state.llm_config.model}") - continue - - elif user_input.lower() == "/summarize": - try: - letta_agent.summarize_messages_inplace() - typer.secho( - f"/summarize succeeded", - fg=typer.colors.GREEN, - bold=True, - ) - except (errors.LLMError, requests.exceptions.HTTPError) as e: - typer.secho( - f"/summarize failed:\n{e}", - fg=typer.colors.RED, - bold=True, - ) - continue - - elif user_input.lower() == "/tokens": - tokens = letta_agent.count_tokens() - typer.secho( - f"{tokens}/{letta_agent.agent_state.llm_config.context_window}", - fg=typer.colors.GREEN, - bold=True, - ) - continue - - elif user_input.lower().startswith("/add_function"): - try: - if len(user_input) < len("/add_function "): - print("Missing function name after the command") - continue - function_name = user_input[len("/add_function ") :].strip() - result = letta_agent.add_function(function_name) - typer.secho( - f"/add_function succeeded: {result}", - fg=typer.colors.GREEN, - bold=True, - ) - except ValueError as e: - typer.secho( - f"/add_function failed:\n{e}", - fg=typer.colors.RED, - bold=True, - ) - continue - elif user_input.lower().startswith("/remove_function"): - try: - if len(user_input) < len("/remove_function "): - print("Missing function name after the command") - continue - function_name = user_input[len("/remove_function ") :].strip() - result = letta_agent.remove_function(function_name) - typer.secho( - f"/remove_function succeeded: {result}", - fg=typer.colors.GREEN, - bold=True, - ) - except ValueError as e: - typer.secho( - f"/remove_function failed:\n{e}", - fg=typer.colors.RED, - bold=True, - ) - continue - - # No skip options - elif user_input.lower() == "/wipe": - letta_agent = agent.Agent(letta_agent.interface) - user_message = None - - elif user_input.lower() == "/heartbeat": - user_message = system.get_heartbeat() - - elif user_input.lower() == "/memorywarning": - user_message = system.get_token_limit_warning() - - elif user_input.lower() == "//": - multiline_input = not multiline_input - continue - - elif user_input.lower() == "/" or user_input.lower() == "/help": - questionary.print("CLI commands", "bold") - for cmd, desc in USER_COMMANDS: - questionary.print(cmd, "bold") - questionary.print(f" {desc}") - continue - else: - print(f"Unrecognized command: {user_input}") - continue - - else: - # If message did not begin with command prefix, pass inputs to Letta - # Handle user message and append to messages - user_message = str(user_input) - - skip_next_user_input = False - - def process_agent_step(user_message, no_verify): - # TODO(charles): update to use agent.step() instead of inner_step() - - if user_message is None: - step_response = letta_agent.inner_step( - messages=[], - first_message=False, - skip_verify=no_verify, - stream=stream, - ) - else: - step_response = letta_agent.step_user_message( - user_message_str=user_message, - first_message=False, - skip_verify=no_verify, - stream=stream, - ) - new_messages = step_response.messages - heartbeat_request = step_response.heartbeat_request - function_failed = step_response.function_failed - token_warning = step_response.in_context_memory_warning - step_response.usage - - agent.save_agent(letta_agent) - skip_next_user_input = False - if token_warning: - user_message = system.get_token_limit_warning() - skip_next_user_input = True - elif function_failed: - user_message = system.get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE) - skip_next_user_input = True - elif heartbeat_request: - user_message = system.get_heartbeat(REQ_HEARTBEAT_MESSAGE) - skip_next_user_input = True - - return new_messages, user_message, skip_next_user_input - - while True: - try: - if strip_ui: - _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) - break - else: - if stream: - # Don't display the "Thinking..." if streaming - _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) - else: - with console.status("[bold cyan]Thinking...") as status: - _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) - break - except KeyboardInterrupt: - print("User interrupt occurred.") - retry = questionary.confirm("Retry agent.step()?").ask() - if not retry: - break - except Exception: - print("An exception occurred when running agent.step(): ") - traceback.print_exc() - retry = questionary.confirm("Retry agent.step()?").ask() - if not retry: - break - - counter += 1 - - print("Finished.") - - -USER_COMMANDS = [ - ("//", "toggle multiline input mode"), - ("/exit", "exit the CLI"), - ("/save", "save a checkpoint of the current agent/conversation state"), - ("/load", "load a saved checkpoint"), - ("/dump ", "view the last messages (all if is omitted)"), - ("/memory", "print the current contents of agent memory"), - ("/pop ", "undo messages in the conversation (default is 3)"), - ("/retry", "pops the last answer and tries to get another one"), - ("/rethink ", "changes the inner thoughts of the last agent message"), - ("/rewrite ", "changes the reply of the last agent message"), - ("/heartbeat", "send a heartbeat system message to the agent"), - ("/memorywarning", "send a memory warning system message to the agent"), - ("/attach", "attach data source to agent"), -] diff --git a/tests/constants.py b/tests/constants.py index e1832cbd..fa60404c 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1 +1,3 @@ TIMEOUT = 30 # seconds +embedding_config_dir = "tests/configs/embedding_model_configs" +llm_config_dir = "tests/configs/llm_model_configs" diff --git a/tests/helpers/client_helper.py b/tests/helpers/client_helper.py index 815102a8..99740d54 100644 --- a/tests/helpers/client_helper.py +++ b/tests/helpers/client_helper.py @@ -1,13 +1,12 @@ import time -from typing import Union -from letta import LocalClient, RESTClient +from letta import RESTClient from letta.schemas.enums import JobStatus from letta.schemas.job import Job from letta.schemas.source import Source -def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Source, filename: str) -> Job: +def upload_file_using_client(client: RESTClient, source: Source, filename: str) -> Job: # load a file into a source (non-blocking job) upload_job = client.load_file_to_source(filename=filename, source_id=source.id, blocking=False) print("Upload job", upload_job, upload_job.status, upload_job.metadata) diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 7774a752..2fa78a48 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -1,33 +1,28 @@ import json import logging import uuid -from typing import Callable, List, Optional, Sequence, Union +from typing import Callable, List, Optional, Sequence from letta.llm_api.helpers import unpack_inner_thoughts_from_kwargs +from letta.schemas.block import CreateBlock from letta.schemas.tool_rule import BaseToolRule +from letta.server.server import SyncServer logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) -from letta import LocalClient, RESTClient, create_client -from letta.agent import Agent from letta.config import LettaConfig from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA from letta.embeddings import embedding_model from letta.errors import InvalidInnerMonologueError, InvalidToolCallError, MissingInnerMonologueError, MissingToolCallError -from letta.helpers.json_helpers import json_dumps -from letta.llm_api.llm_api_tools import create -from letta.llm_api.llm_client import LLMClient from letta.local_llm.constants import INNER_THOUGHTS_KWARG -from letta.schemas.agent import AgentState +from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_message import LettaMessage, ReasoningMessage, ToolCallMessage from letta.schemas.letta_response import LettaResponse from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory -from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message +from letta.schemas.openai.chat_completion_response import Choice, FunctionCall, Message from letta.utils import get_human_text, get_persona_text -from tests.helpers.utils import cleanup # Generate uuid for agent name for this example namespace = uuid.NAMESPACE_DNS @@ -45,7 +40,7 @@ LLM_CONFIG_PATH = "tests/configs/llm_model_configs/letta-hosted.json" def setup_agent( - client: Union[LocalClient, RESTClient], + server: SyncServer, filename: str, memory_human_str: str = get_human_text(DEFAULT_HUMAN), memory_persona_str: str = get_persona_text(DEFAULT_PERSONA), @@ -65,17 +60,27 @@ def setup_agent( config.default_embedding_config = embedding_config config.save() - memory = ChatMemory(human=memory_human_str, persona=memory_persona_str) - agent_state = client.create_agent( + request = CreateAgent( name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, - memory=memory, + memory_blocks=[ + CreateBlock( + label="human", + value=memory_human_str, + ), + CreateBlock( + label="persona", + value=memory_persona_str, + ), + ], tool_ids=tool_ids, tool_rules=tool_rules, include_base_tools=include_base_tools, include_base_tool_rules=include_base_tool_rules, ) + actor = server.user_manager.get_user_or_default() + agent_state = server.create_agent(request=request, actor=actor) return agent_state @@ -86,285 +91,6 @@ def setup_agent( # ====================================================================================================================== -def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner_monologue_contents: bool = True) -> ChatCompletionResponse: - """ - Checks that the first response is valid: - - 1. Contains either send_message or archival_memory_search - 2. Contains valid usage of the function - 3. Contains inner monologue - - Note: This is acting on the raw LLM response, note the usage of `create` - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - agent_state = setup_agent(client, filename) - - full_agent_state = client.get_agent(agent_state.id) - messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state.id, actor=client.user) - agent = Agent(agent_state=full_agent_state, interface=None, user=client.user) - - llm_client = LLMClient.create( - provider_type=agent_state.llm_config.model_endpoint_type, - actor=client.user, - ) - if llm_client: - response = llm_client.send_llm_request( - messages=messages, - llm_config=agent_state.llm_config, - tools=[t.json_schema for t in agent.agent_state.tools], - ) - else: - response = create( - llm_config=agent_state.llm_config, - user_id=str(uuid.UUID(int=1)), # dummy user_id - messages=messages, - functions=[t.json_schema for t in agent.agent_state.tools], - ) - - # Basic check - assert response is not None, response - assert response.choices is not None, response - assert len(response.choices) > 0, response - assert response.choices[0] is not None, response - - # Select first choice - choice = response.choices[0] - - # Ensure that the first message returns a "send_message" - validator_func = ( - lambda function_call: function_call.name == "send_message" - or function_call.name == "archival_memory_search" - or function_call.name == "core_memory_append" - ) - assert_contains_valid_function_call(choice.message, validator_func) - - # Assert that the message has an inner monologue - assert_contains_correct_inner_monologue( - choice, - agent_state.llm_config.put_inner_thoughts_in_kwargs, - validate_inner_monologue_contents=validate_inner_monologue_contents, - ) - - return response - - -def check_response_contains_keyword(filename: str, keyword="banana") -> LettaResponse: - """ - Checks that the prompted response from the LLM contains a chosen keyword - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - agent_state = setup_agent(client, filename) - - keyword_message = f'This is a test to see if you can see my message. If you can see my message, please respond by calling send_message using a message that includes the word "{keyword}"' - response = client.user_message(agent_id=agent_state.id, message=keyword_message) - - # Basic checks - assert_sanity_checks(response) - - # Make sure the message was sent - assert_invoked_send_message_with_keyword(response.messages, keyword) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_uses_external_tool(filename: str) -> LettaResponse: - """ - Checks that the LLM will use external tools if instructed - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - from composio import Action - - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - - # Set up persona for tool usage - persona = f""" - - My name is Letta. - - I am a personal assistant who uses a tool called {tool.name} to star a desired github repo. - - Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts. - """ - - agent_state = setup_agent(client, filename, memory_persona_str=persona, tool_ids=[tool.id]) - - response = client.user_message(agent_id=agent_state.id, message="Please star the repo with owner=letta-ai and repo=letta") - - # Basic checks - assert_sanity_checks(response) - - # Make sure the tool was called - assert_invoked_function_call(response.messages, tool.name) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_recall_chat_memory(filename: str) -> LettaResponse: - """ - Checks that the LLM will recall the chat memory, specifically the human persona. - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - human_name = "BananaBoy" - agent_state = setup_agent(client, filename, memory_human_str=f"My name is {human_name}.") - response = client.user_message( - agent_id=agent_state.id, message="Repeat my name back to me. You should search in your human memory block." - ) - - # Basic checks - assert_sanity_checks(response) - - # Make sure my name was repeated back to me - assert_invoked_send_message_with_keyword(response.messages, human_name) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_archival_memory_insert(filename: str) -> LettaResponse: - """ - Checks that the LLM will execute an archival memory insert. - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - agent_state = setup_agent(client, filename) - secret_word = "banana" - - response = client.user_message( - agent_id=agent_state.id, - message=f"Please insert the secret word '{secret_word}' into archival memory.", - ) - - # Basic checks - assert_sanity_checks(response) - - # Make sure archival_memory_search was called - assert_invoked_function_call(response.messages, "archival_memory_insert") - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_archival_memory_retrieval(filename: str) -> LettaResponse: - """ - Checks that the LLM will execute an archival memory retrieval. - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - agent_state = setup_agent(client, filename) - secret_word = "banana" - client.insert_archival_memory(agent_state.id, f"The secret word is {secret_word}!") - - response = client.user_message( - agent_id=agent_state.id, - message="Search archival memory for the secret word. If you find it successfully, you MUST respond by using the `send_message` function with a message that includes the secret word so I know you found it.", - ) - - # Basic checks - assert_sanity_checks(response) - - # Make sure archival_memory_search was called - assert_invoked_function_call(response.messages, "archival_memory_search") - - # Make sure secret was repeated back to me - assert_invoked_send_message_with_keyword(response.messages, secret_word) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_edit_core_memory(filename: str) -> LettaResponse: - """ - Checks that the LLM is able to edit its core memories - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - human_name_a = "AngryAardvark" - human_name_b = "BananaBoy" - agent_state = setup_agent(client, filename, memory_human_str=f"My name is {human_name_a}") - client.user_message(agent_id=agent_state.id, message=f"Actually, my name changed. It is now {human_name_b}") - response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.") - - # Basic checks - assert_sanity_checks(response) - - # Make sure my name was repeated back to me - assert_invoked_send_message_with_keyword(response.messages, human_name_b) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_summarize_memory_simple(filename: str) -> LettaResponse: - """ - Checks that the LLM is able to summarize its memory - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - agent_state = setup_agent(client, filename) - - # Send a couple messages - friend_name = "Shub" - client.user_message(agent_id=agent_state.id, message="Hey, how's it going? What do you think about this whole shindig") - client.user_message(agent_id=agent_state.id, message=f"By the way, my friend's name is {friend_name}!") - client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?") - - # Summarize - agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user) - agent.summarize_messages_inplace() - print(f"Summarization succeeded: messages[1] = \n\n{json_dumps(agent.messages[1])}\n") - - response = client.user_message(agent_id=agent_state.id, message="What is my friend's name?") - # Basic checks - assert_sanity_checks(response) - - # Make sure my name was repeated back to me - assert_invoked_send_message_with_keyword(response.messages, friend_name) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - def run_embedding_endpoint(filename): # load JSON file config_data = json.load(open(filename, "r")) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 9731ac35..2bb06982 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -2,12 +2,13 @@ import functools import time from typing import Union -from letta import LocalClient, RESTClient from letta.functions.functions import parse_source_code from letta.functions.schema_generator import generate_schema from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent from letta.schemas.tool import Tool +from letta.schemas.user import User from letta.schemas.user import User as PydanticUser +from letta.server.server import SyncServer def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4): @@ -75,12 +76,12 @@ def retry_until_success(max_attempts=10, sleep_time_seconds=4): return decorator_retry -def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str): +def cleanup(server: SyncServer, agent_uuid: str, actor: User): # Clear all agents - for agent_state in client.list_agents(): - if agent_state.name == agent_uuid: - client.delete_agent(agent_id=agent_state.id) - print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") + agent_states = server.agent_manager.list_agents(name=agent_uuid, actor=actor) + + for agent_state in agent_states: + server.agent_manager.delete_agent(agent_id=agent_state.id, actor=actor) # Utility functions diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index bc3aee7a..9647eb1b 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -3,16 +3,20 @@ import uuid import pytest -from letta import create_client +from letta.config import LettaConfig from letta.schemas.letta_message import ToolCallMessage -from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, MaxCountPerStepToolRule, TerminalToolRule +from letta.schemas.letta_response import LettaResponse +from letta.schemas.message import MessageCreate +from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule +from letta.server.server import SyncServer from tests.helpers.endpoints_helper import ( assert_invoked_function_call, assert_invoked_send_message_with_keyword, assert_sanity_checks, setup_agent, ) -from tests.helpers.utils import cleanup, retry_until_success +from tests.helpers.utils import cleanup +from tests.utils import create_tool_from_func # Generate uuid for agent name for this example namespace = uuid.NAMESPACE_DNS @@ -20,106 +24,175 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph")) config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json" -"""Contrived tools for this test case""" +@pytest.fixture() +def server(): + config = LettaConfig.load() + config.save() + + server = SyncServer() + return server -def first_secret_word(): - """ - Call this to retrieve the first secret word, which you will need for the second_secret_word function. - """ - return "v0iq020i0g" +@pytest.fixture(scope="function") +def first_secret_tool(server): + def first_secret_word(): + """ + Retrieves the initial secret word in a multi-step sequence. + + Returns: + str: The first secret word. + """ + return "v0iq020i0g" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=first_secret_word), actor=actor) + yield tool -def second_secret_word(prev_secret_word: str): - """ - Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error. +@pytest.fixture(scope="function") +def second_secret_tool(server): + def second_secret_word(prev_secret_word: str): + """ + Retrieves the second secret word. - Args: - prev_secret_word (str): The secret word retrieved from calling first_secret_word. - """ - if prev_secret_word != "v0iq020i0g": - raise RuntimeError(f"Expected secret {'v0iq020i0g'}, got {prev_secret_word}") + Args: + prev_secret_word (str): The previously retrieved secret word. - return "4rwp2b4gxq" + Returns: + str: The second secret word. + """ + if prev_secret_word != "v0iq020i0g": + raise RuntimeError(f"Expected secret {'v0iq020i0g'}, got {prev_secret_word}") + return "4rwp2b4gxq" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=second_secret_word), actor=actor) + yield tool -def third_secret_word(prev_secret_word: str): - """ - Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error. +@pytest.fixture(scope="function") +def third_secret_tool(server): + def third_secret_word(prev_secret_word: str): + """ + Retrieves the third secret word. - Args: - prev_secret_word (str): The secret word retrieved from calling second_secret_word. - """ - if prev_secret_word != "4rwp2b4gxq": - raise RuntimeError(f'Expected secret "4rwp2b4gxq", got {prev_secret_word}') + Args: + prev_secret_word (str): The previously retrieved secret word. - return "hj2hwibbqm" + Returns: + str: The third secret word. + """ + if prev_secret_word != "4rwp2b4gxq": + raise RuntimeError(f'Expected secret "4rwp2b4gxq", got {prev_secret_word}') + return "hj2hwibbqm" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=third_secret_word), actor=actor) + yield tool -def fourth_secret_word(prev_secret_word: str): - """ - Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error. +@pytest.fixture(scope="function") +def fourth_secret_tool(server): + def fourth_secret_word(prev_secret_word: str): + """ + Retrieves the final secret word. - Args: - prev_secret_word (str): The secret word retrieved from calling third_secret_word. - """ - if prev_secret_word != "hj2hwibbqm": - raise RuntimeError(f"Expected secret {'hj2hwibbqm'}, got {prev_secret_word}") + Args: + prev_secret_word (str): The previously retrieved secret word. - return "banana" + Returns: + str: The final secret word. + """ + if prev_secret_word != "hj2hwibbqm": + raise RuntimeError(f"Expected secret {'hj2hwibbqm'}, got {prev_secret_word}") + return "banana" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=fourth_secret_word), actor=actor) + yield tool -def flip_coin(): - """ - Call this to retrieve the password to the secret word, which you will need to output in a send_message later. - If it returns an empty string, try flipping again! +@pytest.fixture(scope="function") +def flip_coin_tool(server): + def flip_coin(): + """ + Simulates a coin flip with a chance to return a secret word. - Returns: - str: The password or an empty string - """ - import random + Returns: + str: A secret word or an empty string. + """ + import random - # Flip a coin with 50% chance - if random.random() < 0.5: - return "" - return "hj2hwibbqm" + return "" if random.random() < 0.5 else "hj2hwibbqm" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=flip_coin), actor=actor) + yield tool -def can_play_game(): - """ - Call this to start the tool chain. - """ - import random +@pytest.fixture(scope="function") +def can_play_game_tool(server): + def can_play_game(): + """ + Determines whether a game can be played. - return random.random() < 0.5 + Returns: + bool: True if allowed to play, False otherwise. + """ + import random + + return random.random() < 0.5 + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=can_play_game), actor=actor) + yield tool -def return_none(): - """ - Really simple function - """ - return None +@pytest.fixture(scope="function") +def return_none_tool(server): + def return_none(): + """ + Always returns None. + + Returns: + None + """ + return None + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=return_none), actor=actor) + yield tool -def auto_error(): - """ - If you call this function, it will throw an error automatically. - """ - raise RuntimeError("This should never be called.") +@pytest.fixture(scope="function") +def auto_error_tool(server): + def auto_error(): + """ + Always raises an error when called. + + Raises: + RuntimeError: Always triggered. + """ + raise RuntimeError("This should never be called.") + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=auto_error), actor=actor) + yield tool + + +@pytest.fixture +def default_user(server): + yield server.user_manager.get_user_or_default() @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely -def test_single_path_agent_tool_call_graph(disable_e2b_api_key): - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) +def test_single_path_agent_tool_call_graph( + server, disable_e2b_api_key, first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool, auto_error_tool, default_user +): + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) # Add tools - t1 = client.create_or_update_tool(first_secret_word) - t2 = client.create_or_update_tool(second_secret_word) - t3 = client.create_or_update_tool(third_secret_word) - t4 = client.create_or_update_tool(fourth_secret_word) - t_err = client.create_or_update_tool(auto_error) - tools = [t1, t2, t3, t4, t_err] + tools = [first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool, auto_error_tool] # Make tool rules tool_rules = [ @@ -132,8 +205,18 @@ def test_single_path_agent_tool_call_graph(disable_e2b_api_key): ] # Make agent state - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?") + agent_state = setup_agent(server, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + usage_stats = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="What is the fourth secret word?")], + ) + messages = [message for step_messages in usage_stats.steps_messages for message in step_messages] + letta_messages = [] + for m in messages: + letta_messages += m.to_letta_messages() + + response = LettaResponse(messages=letta_messages, usage=usage_stats) # Make checks assert_sanity_checks(response) @@ -145,7 +228,7 @@ def test_single_path_agent_tool_call_graph(disable_e2b_api_key): assert_invoked_function_call(response.messages, "fourth_secret_word") # Check ordering of tool calls - tool_names = [t.name for t in [t1, t2, t3, t4]] + tool_names = [t.name for t in [first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool]] tool_names += ["send_message"] for m in response.messages: if isinstance(m, ToolCallMessage): @@ -159,171 +242,281 @@ def test_single_path_agent_tool_call_graph(disable_e2b_api_key): assert_invoked_send_message_with_keyword(response.messages, "banana") print(f"Got successful response from client: \n\n{response}") - cleanup(client=client, agent_uuid=agent_uuid) + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) -def test_check_tool_rules_with_different_models(disable_e2b_api_key): - """Test that tool rules are properly checked for different model configurations.""" - client = create_client() - - config_files = [ +@pytest.mark.timeout(60) +@pytest.mark.parametrize( + "config_file", + [ "tests/configs/llm_model_configs/claude-3-5-sonnet.json", "tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json", "tests/configs/llm_model_configs/openai-gpt-4o.json", - ] + ], +) +@pytest.mark.parametrize("init_tools_case", ["single", "multiple"]) +def test_check_tool_rules_with_different_models_parametrized( + server, disable_e2b_api_key, first_secret_tool, second_secret_tool, third_secret_tool, default_user, config_file, init_tools_case +): + """Test that tool rules are properly validated across model configurations and init tool scenarios.""" + agent_uuid = str(uuid.uuid4()) - # Create two test tools - t1_name = "first_secret_word" - t2_name = "second_secret_word" - t1 = client.create_or_update_tool(first_secret_word) - t2 = client.create_or_update_tool(second_secret_word) - tool_rules = [InitToolRule(tool_name=t1_name), InitToolRule(tool_name=t2_name)] - tools = [t1, t2] + if init_tools_case == "multiple": + tools = [first_secret_tool, second_secret_tool] + tool_rules = [ + InitToolRule(tool_name=first_secret_tool.name), + InitToolRule(tool_name=second_secret_tool.name), + ] + else: # "single" + tools = [third_secret_tool] + tool_rules = [InitToolRule(tool_name=third_secret_tool.name)] - for config_file in config_files: - # Setup tools - agent_uuid = str(uuid.uuid4()) - - if "gpt-4o" in config_file: - # Structured output model (should work with multiple init tools) - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - assert agent_state is not None - else: - # Non-structured output model (should raise error with multiple init tools) - with pytest.raises(ValueError, match="Multiple initial tools are not supported for non-structured models"): - setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - - # Cleanup - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tool rule with single initial tool - t3_name = "third_secret_word" - t3 = client.create_or_update_tool(third_secret_word) - tool_rules = [InitToolRule(tool_name=t3_name)] - tools = [t3] - for config_file in config_files: - agent_uuid = str(uuid.uuid4()) - - # Structured output model (should work with single init tool) - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + if "gpt-4o" in config_file or init_tools_case == "single": + # Should succeed + agent_state = setup_agent( + server, + config_file, + agent_uuid=agent_uuid, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules, + ) assert agent_state is not None + else: + # Non-structured model with multiple init tools should fail + with pytest.raises(ValueError, match="Multiple initial tools are not supported for non-structured models"): + setup_agent( + server, + config_file, + agent_uuid=agent_uuid, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules, + ) - cleanup(client=client, agent_uuid=agent_uuid) + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) -def test_claude_initial_tool_rule_enforced(disable_e2b_api_key): - """Test that the initial tool rule is enforced for the first message.""" - client = create_client() - - # Create tool rules that require tool_a to be called first - t1_name = "first_secret_word" - t2_name = "second_secret_word" - t1 = client.create_or_update_tool(first_secret_word) - t2 = client.create_or_update_tool(second_secret_word) +@pytest.mark.timeout(180) +def test_claude_initial_tool_rule_enforced( + server, + disable_e2b_api_key, + first_secret_tool, + second_secret_tool, + default_user, +): + """Test that the initial tool rule is enforced for the first message using Claude model.""" tool_rules = [ - InitToolRule(tool_name=t1_name), - ChildToolRule(tool_name=t1_name, children=[t2_name]), - TerminalToolRule(tool_name=t2_name), + InitToolRule(tool_name=first_secret_tool.name), + ChildToolRule(tool_name=first_secret_tool.name, children=[second_secret_tool.name]), + TerminalToolRule(tool_name=second_secret_tool.name), ] - tools = [t1, t2] - - # Make agent state + tools = [first_secret_tool, second_secret_tool] anthropic_config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" + for i in range(3): agent_uuid = str(uuid.uuid4()) agent_state = setup_agent( - client, anthropic_config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules + server, + anthropic_config_file, + agent_uuid=agent_uuid, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules, ) - response = client.user_message(agent_id=agent_state.id, message="What is the second secret word?") + + usage_stats = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="What is the second secret word?")], + ) + messages = [m for step in usage_stats.steps_messages for m in step] + letta_messages = [] + for m in messages: + letta_messages += m.to_letta_messages() + + response = LettaResponse(messages=letta_messages, usage=usage_stats) assert_sanity_checks(response) - messages = response.messages - assert_invoked_function_call(messages, "first_secret_word") - assert_invoked_function_call(messages, "second_secret_word") + # Check that the expected tools were invoked + assert_invoked_function_call(response.messages, "first_secret_word") + assert_invoked_function_call(response.messages, "second_secret_word") - tool_names = [t.name for t in [t1, t2]] - tool_names += ["send_message"] - for m in messages: + tool_names = [t.name for t in [first_secret_tool, second_secret_tool]] + ["send_message"] + for m in response.messages: if isinstance(m, ToolCallMessage): - # Check that it's equal to the first one assert m.tool_call.name == tool_names[0] - - # Pop out first one tool_names = tool_names[1:] print(f"Passed iteration {i}") - cleanup(client=client, agent_uuid=agent_uuid) + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) - # Implement exponential backoff with initial time of 10 seconds + # Exponential backoff if i < 2: backoff_time = 10 * (2**i) time.sleep(backoff_time) -@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely -def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) +@pytest.mark.timeout(60) +@pytest.mark.parametrize( + "config_file", + [ + "tests/configs/llm_model_configs/claude-3-5-sonnet.json", + "tests/configs/llm_model_configs/openai-gpt-4o.json", + ], +) +def test_agent_no_structured_output_with_one_child_tool_parametrized( + server, + disable_e2b_api_key, + default_user, + config_file, +): + """Test that agent correctly calls tool chains with unstructured output under various model configs.""" + send_message = server.tool_manager.get_tool_by_name(tool_name="send_message", actor=default_user) + archival_memory_search = server.tool_manager.get_tool_by_name(tool_name="archival_memory_search", actor=default_user) + archival_memory_insert = server.tool_manager.get_tool_by_name(tool_name="archival_memory_insert", actor=default_user) - send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user) - archival_memory_search = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_search", actor=client.user) - archival_memory_insert = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_insert", actor=client.user) + tools = [send_message, archival_memory_search, archival_memory_insert] - # Make tool rules tool_rules = [ InitToolRule(tool_name="archival_memory_search"), ChildToolRule(tool_name="archival_memory_search", children=["archival_memory_insert"]), ChildToolRule(tool_name="archival_memory_insert", children=["send_message"]), TerminalToolRule(tool_name="send_message"), ] - tools = [send_message, archival_memory_search, archival_memory_insert] - config_files = [ - "tests/configs/llm_model_configs/claude-3-5-sonnet.json", - "tests/configs/llm_model_configs/openai-gpt-4o.json", + max_retries = 3 + last_error = None + agent_uuid = str(uuid.uuid4()) + + for attempt in range(max_retries): + try: + agent_state = setup_agent( + server, + config_file, + agent_uuid=agent_uuid, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules, + ) + + usage_stats = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="hi. run archival memory search")], + ) + messages = [m for step in usage_stats.steps_messages for m in step] + letta_messages = [] + for m in messages: + letta_messages += m.to_letta_messages() + + response = LettaResponse(messages=letta_messages, usage=usage_stats) + + # Run assertions + assert_sanity_checks(response) + assert_invoked_function_call(response.messages, "archival_memory_search") + assert_invoked_function_call(response.messages, "archival_memory_insert") + assert_invoked_function_call(response.messages, "send_message") + + tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]] + for m in response.messages: + if isinstance(m, ToolCallMessage): + assert m.tool_call.name == tool_names[0] + tool_names = tool_names[1:] + + print(f"[{config_file}] Got successful response:\n\n{response}") + break # success + + except AssertionError as e: + last_error = e + print(f"[{config_file}] Attempt {attempt + 1} failed") + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) + + if last_error: + raise last_error + + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) + + +@pytest.mark.timeout(30) +@pytest.mark.parametrize("include_base_tools", [False, True]) +def test_init_tool_rule_always_fails( + server, + disable_e2b_api_key, + auto_error_tool, + default_user, + include_base_tools, +): + """Test behavior when InitToolRule invokes a tool that always fails.""" + config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" + agent_uuid = str(uuid.uuid4()) + + tool_rule = InitToolRule(tool_name=auto_error_tool.name) + agent_state = setup_agent( + server, + config_file, + agent_uuid=agent_uuid, + tool_ids=[auto_error_tool.id], + tool_rules=[tool_rule], + include_base_tools=include_base_tools, + ) + + usage_stats = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="blah blah blah")], + ) + messages = [m for step in usage_stats.steps_messages for m in step] + letta_messages = [msg for m in messages for msg in m.to_letta_messages()] + response = LettaResponse(messages=letta_messages, usage=usage_stats) + + assert_invoked_function_call(response.messages, auto_error_tool.name) + + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) + + +def test_continue_tool_rule(server, default_user): + """Test the continue tool rule by forcing send_message to loop before ending with core_memory_append.""" + config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" + agent_uuid = str(uuid.uuid4()) + + tool_ids = [ + server.tool_manager.get_tool_by_name("send_message", actor=default_user).id, + server.tool_manager.get_tool_by_name("core_memory_append", actor=default_user).id, ] - for config in config_files: - max_retries = 3 - last_error = None + tool_rules = [ + ContinueToolRule(tool_name="send_message"), + TerminalToolRule(tool_name="core_memory_append"), + ] - for attempt in range(max_retries): - try: - agent_state = setup_agent(client, config, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - response = client.user_message(agent_id=agent_state.id, message="hi. run archival memory search") + agent_state = setup_agent( + server, + config_file, + agent_uuid, + tool_ids=tool_ids, + tool_rules=tool_rules, + include_base_tools=False, + include_base_tool_rules=False, + ) - # Make checks - assert_sanity_checks(response) + usage_stats = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="Send me some messages, and then call core_memory_append to end your turn.")], + ) + messages = [m for step in usage_stats.steps_messages for m in step] + letta_messages = [msg for m in messages for msg in m.to_letta_messages()] + response = LettaResponse(messages=letta_messages, usage=usage_stats) - # Assert the tools were called - assert_invoked_function_call(response.messages, "archival_memory_search") - assert_invoked_function_call(response.messages, "archival_memory_insert") - assert_invoked_function_call(response.messages, "send_message") + assert_invoked_function_call(response.messages, "send_message") + assert_invoked_function_call(response.messages, "core_memory_append") - # Check ordering of tool calls - tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]] - for m in response.messages: - if isinstance(m, ToolCallMessage): - # Check that it's equal to the first one - assert m.tool_call.name == tool_names[0] + # Check order + send_idx = next(i for i, m in enumerate(response.messages) if isinstance(m, ToolCallMessage) and m.tool_call.name == "send_message") + append_idx = next( + i for i, m in enumerate(response.messages) if isinstance(m, ToolCallMessage) and m.tool_call.name == "core_memory_append" + ) + assert send_idx < append_idx, "send_message should occur before core_memory_append" - # Pop out first one - tool_names = tool_names[1:] - - print(f"Got successful response from client: \n\n{response}") - break # Test passed, exit retry loop - - except AssertionError as e: - last_error = e - print(f"Attempt {attempt + 1} failed, retrying..." if attempt < max_retries - 1 else f"All {max_retries} attempts failed") - cleanup(client=client, agent_uuid=agent_uuid) - continue - - if last_error and attempt == max_retries - 1: - raise last_error # Re-raise the last error if all retries failed - - cleanup(client=client, agent_uuid=agent_uuid) + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) # @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely @@ -342,7 +535,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): # reveal_secret_word # """ # -# client = create_client() +# # cleanup(client=client, agent_uuid=agent_uuid) # # coin_flip_name = "flip_coin" @@ -406,7 +599,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): # v # any tool... <-- When output doesn't match mapping, agent can call any tool # """ -# client = create_client() +# # cleanup(client=client, agent_uuid=agent_uuid) # # # Create tools - we'll make several available to the agent @@ -467,7 +660,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): # v # fourth_secret_word <-- Should remember coin flip result after reload # """ -# client = create_client() +# # cleanup(client=client, agent_uuid=agent_uuid) # # # Create tools @@ -522,7 +715,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): # v # fourth_secret_word # """ -# client = create_client() +# # cleanup(client=client, agent_uuid=agent_uuid) # # # Create tools @@ -563,165 +756,3 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): # assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin" # # cleanup(client, agent_uuid=agent_state.id) - - -def test_init_tool_rule_always_fails_one_tool(): - """ - Test an init tool rule that always fails when called. The agent has only one tool available. - - Once that tool fails and the agent removes that tool, the agent should have 0 tools available. - - This means that the agent should return from `step` early. - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tools - bad_tool = client.create_or_update_tool(auto_error) - - # Create tool rule: InitToolRule - tool_rule = InitToolRule( - tool_name=bad_tool.name, - ) - - # Set up agent with the tool rule - claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" - agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=False) - - # Start conversation - response = client.user_message(agent_id=agent_state.id, message="blah blah blah") - - # Verify the tool calls - tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] - assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls - assert_invoked_function_call(response.messages, bad_tool.name) - - -def test_init_tool_rule_always_fails_multiple_tools(): - """ - Test an init tool rule that always fails when called. The agent has only 1+ tools available. - Once that tool fails and the agent removes that tool, the agent should have other tools available. - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tools - bad_tool = client.create_or_update_tool(auto_error) - - # Create tool rule: InitToolRule - tool_rule = InitToolRule( - tool_name=bad_tool.name, - ) - - # Set up agent with the tool rule - claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" - agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=True) - - # Start conversation - response = client.user_message(agent_id=agent_state.id, message="blah blah blah") - - # Verify the tool calls - tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] - assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls - assert_invoked_function_call(response.messages, bad_tool.name) - - -def test_continue_tool_rule(): - """Test the continue tool rule by forcing the send_message tool to continue""" - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - continue_tool_rule = ContinueToolRule( - tool_name="send_message", - ) - terminal_tool_rule = TerminalToolRule( - tool_name="core_memory_append", - ) - rules = [continue_tool_rule, terminal_tool_rule] - - core_memory_append_tool = client.get_tool_id("core_memory_append") - send_message_tool = client.get_tool_id("send_message") - - # Set up agent with the tool rule - claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" - agent_state = setup_agent( - client, - claude_config, - agent_uuid, - tool_rules=rules, - tool_ids=[core_memory_append_tool, send_message_tool], - include_base_tools=False, - include_base_tool_rules=False, - ) - - # Start conversation - response = client.user_message(agent_id=agent_state.id, message="blah blah blah") - - # Verify the tool calls - tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] - assert len(tool_calls) >= 1 - assert_invoked_function_call(response.messages, "send_message") - assert_invoked_function_call(response.messages, "core_memory_append") - - # ensure send_message called before core_memory_append - send_message_call_index = None - core_memory_append_call_index = None - for i, call in enumerate(tool_calls): - if call.tool_call.name == "send_message": - send_message_call_index = i - if call.tool_call.name == "core_memory_append": - core_memory_append_call_index = i - assert send_message_call_index < core_memory_append_call_index, "send_message should have been called before core_memory_append" - - -@pytest.mark.timeout(60) -@retry_until_success(max_attempts=3, sleep_time_seconds=2) -def test_max_count_per_step_tool_rule_integration(disable_e2b_api_key): - """ - Test an agent with MaxCountPerStepToolRule to ensure a tool can only be called a limited number of times. - - Tool Flow: - repeatable_tool (max 2 times) - | - v - send_message - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tools - repeatable_tool_name = "first_secret_word" - final_tool_name = "send_message" - - repeatable_tool = client.create_or_update_tool(first_secret_word) - send_message_tool = client.get_tool(client.get_tool_id(final_tool_name)) # Assume send_message is a default tool - - # Define tool rules - tool_rules = [ - InitToolRule(tool_name=repeatable_tool_name), - MaxCountPerStepToolRule(tool_name=repeatable_tool_name, max_count_limit=2), - TerminalToolRule(tool_name=final_tool_name), - ] - - tools = [repeatable_tool, send_message_tool] - - # Setup agent - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - - # Start conversation - response = client.user_message( - agent_id=agent_state.id, message=f"Keep calling {repeatable_tool_name} nonstop without calling ANY other tool." - ) - - # Make checks - assert_sanity_checks(response) - - # Ensure the repeatable tool is only called twice - count = sum(1 for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == repeatable_tool_name) - assert count == 2, f"Expected 'first_secret_word' to be called exactly 2 times, but got {count}" - - # Ensure send_message was eventually called - assert_invoked_function_call(response.messages, final_tool_name) - - print(f"Got successful response from client: \n\n{response}") - cleanup(client=client, agent_uuid=agent_uuid) diff --git a/tests/integration_test_async_tool_sandbox.py b/tests/integration_test_async_tool_sandbox.py index b85728db..eba7b84f 100644 --- a/tests/integration_test_async_tool_sandbox.py +++ b/tests/integration_test_async_tool_sandbox.py @@ -7,17 +7,16 @@ from unittest.mock import patch import pytest from sqlalchemy import delete -from letta import create_client +from letta.config import LettaConfig from letta.functions.function_sets.base import core_memory_append, core_memory_replace from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable -from letta.schemas.agent import AgentState -from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.agent import AgentState, CreateAgent +from letta.schemas.block import CreateBlock from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory from letta.schemas.organization import Organization from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate from letta.schemas.user import User +from letta.server.server import SyncServer from letta.services.organization_manager import OrganizationManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.tool_manager import ToolManager @@ -33,6 +32,21 @@ user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user")) # Fixtures +@pytest.fixture(scope="module") +def server(): + """ + Creates a SyncServer instance for testing. + + Loads and saves config to ensure proper initialization. + """ + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=True) + yield server + + @pytest.fixture(autouse=True) def clear_tables(): """Fixture to clear the organization table before each test.""" @@ -192,12 +206,26 @@ def external_codebase_tool(test_user): @pytest.fixture -def agent_state(): - client = create_client() - agent_state = client.create_agent( - memory=ChatMemory(persona="This is the persona", human="My name is Chad"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), +def agent_state(server): + actor = server.user_manager.get_user_or_default() + agent_state = server.create_agent( + CreateAgent( + memory_blocks=[ + CreateBlock( + label="human", + value="username: sarah", + ), + CreateBlock( + label="persona", + value="This is the persona", + ), + ], + include_base_tools=True, + model="openai/gpt-4o-mini", + tags=["test_agents"], + embedding="letta/letta-free", + ), + actor=actor, ) agent_state.tool_rules = [] yield agent_state diff --git a/tests/integration_test_experimental.py b/tests/integration_test_experimental.py deleted file mode 100644 index 0b9df389..00000000 --- a/tests/integration_test_experimental.py +++ /dev/null @@ -1,579 +0,0 @@ -import os -import threading -import time -import uuid - -import httpx -import openai -import pytest -from dotenv import load_dotenv -from letta_client import CreateBlock, Letta, MessageCreate, TextContent -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk - -from letta.agents.letta_agent import LettaAgent -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageStreamStatus -from letta.schemas.letta_message_content import TextContent as LettaTextContent -from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import MessageCreate as LettaMessageCreate -from letta.schemas.tool import ToolCreate -from letta.schemas.usage import LettaUsageStatistics -from letta.services.agent_manager import AgentManager -from letta.services.block_manager import BlockManager -from letta.services.message_manager import MessageManager -from letta.services.passage_manager import PassageManager -from letta.services.tool_manager import ToolManager -from letta.services.user_manager import UserManager -from letta.settings import model_settings, settings - -# --- Server Management --- # - - -def _run_server(): - """Starts the Letta server in a background thread.""" - load_dotenv() - from letta.server.rest_api.app import start_server - - start_server(debug=True) - - -@pytest.fixture(scope="session") -def server_url(): - """Ensures a server is running and returns its base URL.""" - url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - - if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=_run_server, daemon=True) - thread.start() - time.sleep(5) # Allow server startup time - - return url - - -# --- Client Setup --- # - - -@pytest.fixture(scope="session") -def client(server_url): - """Creates a REST client for testing.""" - client = Letta(base_url=server_url) - # llm_config = LLMConfig( - # model="claude-3-7-sonnet-latest", - # model_endpoint_type="anthropic", - # model_endpoint="https://api.anthropic.com/v1", - # context_window=32000, - # handle=f"anthropic/claude-3-7-sonnet-latest", - # put_inner_thoughts_in_kwargs=True, - # max_tokens=4096, - # ) - # - # client = create_client(base_url=server_url, token=None) - # client.set_default_llm_config(llm_config) - # client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - yield client - - -@pytest.fixture(scope="function") -def roll_dice_tool(client): - def roll_dice(): - """ - Rolls a 6 sided die. - - Returns: - str: The roll result. - """ - import time - - time.sleep(1) - return "Rolled a 10!" - - # tool = client.create_or_update_tool(func=roll_dice) - tool = client.tools.upsert_from_function(func=roll_dice) - # Yield the created tool - yield tool - - -@pytest.fixture(scope="function") -def weather_tool(client): - def get_weather(location: str) -> str: - """ - Fetches the current weather for a given location. - - Parameters: - location (str): The location to get the weather for. - - Returns: - str: A formatted string describing the weather in the given location. - - Raises: - RuntimeError: If the request to fetch weather data fails. - """ - import requests - - url = f"https://wttr.in/{location}?format=%C+%t" - - response = requests.get(url) - if response.status_code == 200: - weather_data = response.text - return f"The weather in {location} is {weather_data}." - else: - raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}") - - # tool = client.create_or_update_tool(func=get_weather) - tool = client.tools.upsert_from_function(func=get_weather) - # Yield the created tool - yield tool - - -@pytest.fixture(scope="function") -def rethink_tool(client): - def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore - """ - Re-evaluate the memory in block_name, integrating new and updated facts. - Replace outdated information with the most likely truths, avoiding redundancy with original memories. - Ensure consistency with other memory blocks. - - Args: - new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block. - target_block_label (str): The name of the block to write to. - Returns: - str: None is always returned as this function does not produce a response. - """ - agent_state.memory.update_block_value(label=target_block_label, value=new_memory) - return None - - tool = client.tools.upsert_from_function(func=rethink_memory) - # Yield the created tool - yield tool - - -@pytest.fixture(scope="function") -def composio_gmail_get_profile_tool(default_user): - tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE") - tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user) - yield tool - - -@pytest.fixture(scope="function") -def agent_state(client, roll_dice_tool, weather_tool, rethink_tool): - """Creates an agent and ensures cleanup after tests.""" - # llm_config = LLMConfig( - # model="claude-3-7-sonnet-latest", - # model_endpoint_type="anthropic", - # model_endpoint="https://api.anthropic.com/v1", - # context_window=32000, - # handle=f"anthropic/claude-3-7-sonnet-latest", - # put_inner_thoughts_in_kwargs=True, - # max_tokens=4096, - # ) - agent_state = client.agents.create( - name=f"test_compl_{str(uuid.uuid4())[5:]}", - tool_ids=[roll_dice_tool.id, weather_tool.id, rethink_tool.id], - include_base_tools=True, - memory_blocks=[ - { - "label": "human", - "value": "Name: Matt", - }, - { - "label": "persona", - "value": "Friendly agent", - }, - ], - llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - ) - yield agent_state - client.agents.delete(agent_state.id) - - -@pytest.fixture(scope="function") -def openai_client(client, roll_dice_tool, weather_tool): - """Creates an agent and ensures cleanup after tests.""" - client = openai.AsyncClient( - api_key=model_settings.anthropic_api_key, - base_url="https://api.anthropic.com/v1/", - max_retries=0, - http_client=httpx.AsyncClient( - timeout=httpx.Timeout(connect=15.0, read=30.0, write=15.0, pool=15.0), - follow_redirects=True, - limits=httpx.Limits( - max_connections=50, - max_keepalive_connections=50, - keepalive_expiry=120, - ), - ), - ) - yield client - - -# --- Helper Functions --- # - - -def _assert_valid_chunk(chunk, idx, chunks): - """Validates the structure of each streaming chunk.""" - if isinstance(chunk, ChatCompletionChunk): - assert chunk.choices, "Each ChatCompletionChunk should have at least one choice." - - elif isinstance(chunk, LettaUsageStatistics): - assert chunk.completion_tokens > 0, "Completion tokens must be > 0." - assert chunk.prompt_tokens > 0, "Prompt tokens must be > 0." - assert chunk.total_tokens > 0, "Total tokens must be > 0." - assert chunk.step_count == 1, "Step count must be 1." - - elif isinstance(chunk, MessageStreamStatus): - assert chunk == MessageStreamStatus.done, "Stream should end with 'done' status." - assert idx == len(chunks) - 1, "The last chunk must be 'done'." - - else: - pytest.fail(f"Unexpected chunk type: {chunk}") - - -# --- Test Cases --- # - - -@pytest.mark.asyncio -@pytest.mark.parametrize("message", ["What is the weather today in SF?"]) -async def test_new_agent_loop(disable_e2b_api_key, openai_client, agent_state, message): - actor = UserManager().get_user_or_default(user_id="asf") - agent = LettaAgent( - agent_id=agent_state.id, - message_manager=MessageManager(), - agent_manager=AgentManager(), - block_manager=BlockManager(), - passage_manager=PassageManager(), - actor=actor, - ) - - response = await agent.step([LettaMessageCreate(role="user", content=[LettaTextContent(text=message)])]) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("message", ["Use your rethink tool to rethink the human memory considering Matt likes chicken."]) -async def test_rethink_tool(disable_e2b_api_key, openai_client, agent_state, message): - actor = UserManager().get_user_or_default(user_id="asf") - agent = LettaAgent( - agent_id=agent_state.id, - message_manager=MessageManager(), - agent_manager=AgentManager(), - block_manager=BlockManager(), - passage_manager=PassageManager(), - actor=actor, - ) - - assert "chicken" not in AgentManager().get_agent_by_id(agent_state.id, actor).memory.get_block("human").value - response = await agent.step([LettaMessageCreate(role="user", content=[LettaTextContent(text=message)])]) - assert "chicken" in AgentManager().get_agent_by_id(agent_state.id, actor).memory.get_block("human").value.lower() - - -@pytest.mark.asyncio -async def test_vertex_send_message_structured_outputs(disable_e2b_api_key, client): - original_experimental_key = settings.use_vertex_structured_outputs_experimental - settings.use_vertex_structured_outputs_experimental = True - try: - actor = UserManager().get_user_or_default(user_id="asf") - - stale_agents = AgentManager().list_agents(actor=actor, limit=300) - for agent in stale_agents: - AgentManager().delete_agent(agent_id=agent.id, actor=actor) - - manager_agent_state = client.agents.create( - name=f"manager", - include_base_tools=False, # change this to True to repro MALFORMED FUNCTION CALL error - tools=["send_message"], - tags=["manager"], - model="google_vertex/gemini-2.5-flash-preview-04-17", - embedding="letta/letta-free", - ) - manager_agent = LettaAgent( - agent_id=manager_agent_state.id, - message_manager=MessageManager(), - agent_manager=AgentManager(), - block_manager=BlockManager(), - passage_manager=PassageManager(), - actor=actor, - ) - - response = await manager_agent.step( - [ - LettaMessageCreate( - role="user", - content=[ - LettaTextContent(text=("Check the weather in Seattle.")), - ], - ), - ] - ) - assert len(response.messages) == 3 - assert response.messages[0].message_type == "user_message" - # Shouldn't this have reasoning message? - # assert response.messages[1].message_type == "reasoning_message" - assert response.messages[1].message_type == "assistant_message" - assert response.messages[2].message_type == "tool_return_message" - finally: - settings.use_vertex_structured_outputs_experimental = original_experimental_key - - -@pytest.mark.asyncio -async def test_multi_agent_broadcast(disable_e2b_api_key, client, openai_client, weather_tool): - actor = UserManager().get_user_or_default(user_id="asf") - - stale_agents = AgentManager().list_agents(actor=actor, limit=300) - for agent in stale_agents: - AgentManager().delete_agent(agent_id=agent.id, actor=actor) - - manager_agent_state = client.agents.create( - name=f"manager", - include_base_tools=True, - include_multi_agent_tools=True, - tags=["manager"], - model="openai/gpt-4o", - embedding="letta/letta-free", - ) - manager_agent = LettaAgent( - agent_id=manager_agent_state.id, - message_manager=MessageManager(), - agent_manager=AgentManager(), - block_manager=BlockManager(), - passage_manager=PassageManager(), - actor=actor, - ) - - tag = "subagent" - workers = [] - for idx in range(30): - workers.append( - client.agents.create( - name=f"worker_{idx}", - include_base_tools=True, - tags=[tag], - tool_ids=[weather_tool.id], - model="openai/gpt-4o", - embedding="letta/letta-free", - ), - ) - response = await manager_agent.step( - [ - LettaMessageCreate( - role="user", - content=[ - LettaTextContent( - text=( - "Use the `send_message_to_agents_matching_tags` tool to send a message to agents with " - "tag 'subagent' asking them to check the weather in Seattle." - ) - ), - ], - ), - ] - ) - - -def test_multi_agent_broadcast_client(client: Letta, weather_tool): - # delete any existing worker agents - workers = client.agents.list(tags=["worker"]) - for worker in workers: - client.agents.delete(agent_id=worker.id) - - # create worker agents - num_workers = 10 - for idx in range(num_workers): - client.agents.create( - name=f"worker_{idx}", - include_base_tools=True, - tags=["worker"], - tool_ids=[weather_tool.id], - model="anthropic/claude-3-5-sonnet-20241022", - embedding="letta/letta-free", - ) - - # create supervisor agent - supervisor = client.agents.create( - name="supervisor", - include_base_tools=True, - include_multi_agent_tools=True, - model="anthropic/claude-3-5-sonnet-20241022", - embedding="letta/letta-free", - tags=["supervisor"], - ) - - # send a message to the supervisor - import time - - start = time.perf_counter() - response = client.agents.messages.create( - agent_id=supervisor.id, - messages=[ - MessageCreate( - role="user", - content=[ - TextContent( - text="Use the `send_message_to_agents_matching_tags` tool to send a message to agents with tag 'worker' asking them to check the weather in Seattle." - ) - ], - ) - ], - ) - end = time.perf_counter() - print("TIME ELAPSED: " + str(end - start)) - for message in response.messages: - print(message) - - -def test_call_weather(client: Letta, weather_tool): - # delete any existing worker agents - workers = client.agents.list(tags=["worker", "supervisor"]) - for worker in workers: - client.agents.delete(agent_id=worker.id) - - # create supervisor agent - supervisor = client.agents.create( - name="supervisor", - include_base_tools=True, - tool_ids=[weather_tool.id], - model="openai/gpt-4o", - embedding="letta/letta-free", - tags=["supervisor"], - ) - - # send a message to the supervisor - import time - - start = time.perf_counter() - response = client.agents.messages.create( - agent_id=supervisor.id, - messages=[ - { - "role": "user", - "content": "What's the weather like in Seattle?", - } - ], - ) - end = time.perf_counter() - print("TIME ELAPSED: " + str(end - start)) - for message in response.messages: - print(message) - - -def run_supervisor_worker_group(client: Letta, weather_tool, group_id: str): - # Delete any existing agents for this group (if rerunning) - existing_workers = client.agents.list(tags=[f"worker-{group_id}"]) - for worker in existing_workers: - client.agents.delete(agent_id=worker.id) - - # Create worker agents - num_workers = 50 - for idx in range(num_workers): - client.agents.create( - name=f"worker_{group_id}_{idx}", - include_base_tools=True, - tags=[f"worker-{group_id}"], - tool_ids=[weather_tool.id], - model="anthropic/claude-3-5-sonnet-20241022", - embedding="letta/letta-free", - ) - - # Create supervisor agent - supervisor = client.agents.create( - name=f"supervisor_{group_id}", - include_base_tools=True, - include_multi_agent_tools=True, - model="anthropic/claude-3-5-sonnet-20241022", - embedding="letta/letta-free", - tags=[f"supervisor-{group_id}"], - ) - - # Send message to supervisor to broadcast to workers - response = client.agents.messages.create( - agent_id=supervisor.id, - messages=[ - { - "role": "user", - "content": "Use the `send_message_to_agents_matching_tags` tool to send a message to agents with tag " - f"'worker-{group_id}' asking them to check the weather in Seattle.", - } - ], - ) - - return response - - -def test_anthropic_streaming(client: Letta): - agent_name = "anthropic_tester" - - existing_agents = client.agents.list(tags=[agent_name]) - for worker in existing_agents: - client.agents.delete(agent_id=worker.id) - - llm_config = LLMConfig( - model="claude-3-7-sonnet-20250219", - model_endpoint_type="anthropic", - model_endpoint="https://api.anthropic.com/v1", - context_window=32000, - handle=f"anthropic/claude-3-5-sonnet-20241022", - put_inner_thoughts_in_kwargs=False, - max_tokens=4096, - enable_reasoner=True, - max_reasoning_tokens=1024, - ) - - agent = client.agents.create( - name=agent_name, - tags=[agent_name], - include_base_tools=True, - embedding="letta/letta-free", - llm_config=llm_config, - memory_blocks=[CreateBlock(label="human", value="")], - # tool_rules=[InitToolRule(tool_name="core_memory_append")] - ) - - response = client.agents.messages.create_stream( - agent_id=agent.id, - messages=[ - MessageCreate( - role="user", - content=[TextContent(text="Use the core memory append tool to append `banana` to the persona core memory.")], - ), - ], - stream_tokens=True, - ) - - print(list(response)) - - -import time - - -def test_create_agents_telemetry(client: Letta): - start_total = time.perf_counter() - - # delete any existing worker agents - start_delete = time.perf_counter() - workers = client.agents.list(tags=["worker"]) - for worker in workers: - client.agents.delete(agent_id=worker.id) - end_delete = time.perf_counter() - print(f"[telemetry] Deleted {len(workers)} existing worker agents in {end_delete - start_delete:.2f}s") - - # create worker agents - num_workers = 1 - agent_times = [] - for idx in range(num_workers): - start = time.perf_counter() - client.agents.create( - name=f"worker_{idx}", - include_base_tools=True, - model="anthropic/claude-3-5-sonnet-20241022", - embedding="letta/letta-free", - ) - end = time.perf_counter() - duration = end - start - agent_times.append(duration) - print(f"[telemetry] Created worker_{idx} in {duration:.2f}s") - - total_duration = time.perf_counter() - start_total - avg_duration = sum(agent_times) / len(agent_times) - - print(f"[telemetry] Total time to create {num_workers} agents: {total_duration:.2f}s") - print(f"[telemetry] Average agent creation time: {avg_duration:.2f}s") - print(f"[telemetry] Fastest agent: {min(agent_times):.2f}s, Slowest agent: {max(agent_times):.2f}s") diff --git a/tests/integration_test_initial_sequence.py b/tests/integration_test_initial_sequence.py deleted file mode 100644 index 71449171..00000000 --- a/tests/integration_test_initial_sequence.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import threading -import time - -import pytest -from dotenv import load_dotenv -from letta_client import Letta, MessageCreate - - -def run_server(): - load_dotenv() - - from letta.server.rest_api.app import start_server - - print("Starting server...") - start_server(debug=True) - - -@pytest.fixture( - scope="module", -) -def client(request): - # Get URL from environment or start server - server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283") - if not os.getenv("LETTA_SERVER_URL"): - print("Starting server thread") - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - time.sleep(5) - print("Running client tests with server:", server_url) - - # create the Letta client - yield Letta(base_url=server_url, token=None) - - -def test_initial_sequence(client: Letta): - # create an agent - agent = client.agents.create( - memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], - model="letta/letta-free", - embedding="letta/letta-free", - initial_message_sequence=[ - MessageCreate( - role="assistant", - content="Hello, how are you?", - ), - MessageCreate(role="user", content="I'm good, and you?"), - ], - ) - - # list messages - messages = client.agents.messages.list(agent_id=agent.id) - response = client.agents.messages.create( - agent_id=agent.id, - messages=[ - MessageCreate( - role="user", - content="hello assistant!", - ) - ], - ) - assert len(messages) == 3 - assert messages[0].message_type == "system_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "user_message" diff --git a/tests/integration_test_send_message_schema.py b/tests/integration_test_send_message_schema.py deleted file mode 100644 index 57773ec8..00000000 --- a/tests/integration_test_send_message_schema.py +++ /dev/null @@ -1,192 +0,0 @@ -# TODO (cliandy): Tested in SDK -# TODO (cliandy): Comment out after merge - -# import os -# import threading -# import time - -# import pytest -# from dotenv import load_dotenv -# from letta_client import AssistantMessage, AsyncLetta, Letta, Tool - -# from letta.schemas.agent import AgentState -# from typing import List, Any, Dict - -# # ------------------------------ -# # Fixtures -# # ------------------------------ - - -# @pytest.fixture(scope="module") -# def server_url() -> str: -# """ -# Provides the URL for the Letta server. -# If the environment variable 'LETTA_SERVER_URL' is not set, this fixture -# will start the Letta server in a background thread and return the default URL. -# """ - -# def _run_server() -> None: -# """Starts the Letta server in a background thread.""" -# load_dotenv() # Load environment variables from .env file -# from letta.server.rest_api.app import start_server - -# start_server(debug=True) - -# # Retrieve server URL from environment, or default to localhost -# url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - -# # If no environment variable is set, start the server in a background thread -# if not os.getenv("LETTA_SERVER_URL"): -# thread = threading.Thread(target=_run_server, daemon=True) -# thread.start() -# time.sleep(5) # Allow time for the server to start - -# return url - - -# @pytest.fixture -# def client(server_url: str) -> Letta: -# """ -# Creates and returns a synchronous Letta REST client for testing. -# """ -# client_instance = Letta(base_url=server_url) -# yield client_instance - - -# @pytest.fixture -# def async_client(server_url: str) -> AsyncLetta: -# """ -# Creates and returns an asynchronous Letta REST client for testing. -# """ -# async_client_instance = AsyncLetta(base_url=server_url) -# yield async_client_instance - - -# @pytest.fixture -# def roll_dice_tool(client: Letta) -> Tool: -# """ -# Registers a simple roll dice tool with the provided client. - -# The tool simulates rolling a six-sided die but returns a fixed result. -# """ - -# def roll_dice() -> str: -# """ -# Simulates rolling a die. - -# Returns: -# str: The roll result. -# """ -# # Note: The result here is intentionally incorrect for demonstration purposes. -# return "Rolled a 10!" - -# tool = client.tools.upsert_from_function(func=roll_dice) -# yield tool - - -# @pytest.fixture -# def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState: -# """ -# Creates and returns an agent state for testing with a pre-configured agent. -# The agent is named 'supervisor' and is configured with base tools and the roll_dice tool. -# """ -# agent_state_instance = client.agents.create( -# name="supervisor", -# include_base_tools=True, -# tool_ids=[roll_dice_tool.id], -# model="openai/gpt-4o", -# embedding="letta/letta-free", -# tags=["supervisor"], -# include_base_tool_rules=True, - -# ) -# yield agent_state_instance - - -# # Goal is to test that when an Agent is created with a `response_format`, that the response -# # of `send_message` is in the correct format. This will be done by modifying the agent's -# # `send_message` tool so that it returns a format based on what is passed in. -# # -# # `response_format` is an optional field -# # if `response_format.type` is `text`, then the schema does not change -# # if `response_format.type` is `json_object`, then the schema is a dict -# # if `response_format.type` is `json_schema`, then the schema is a dict matching that json schema - - -# USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Send me a message."}] - -# # ------------------------------ -# # Test Cases -# # ------------------------------ - -# def test_client_send_message_text_response_format(client: "Letta", agent: "AgentState") -> None: -# """Test client send_message with response_format='json_object'.""" -# client.agents.modify(agent.id, response_format={"type": "text"}) - -# response = client.agents.messages.create_stream( -# agent_id=agent.id, -# messages=USER_MESSAGE, -# ) -# messages = list(response) -# assert isinstance(messages[-1], AssistantMessage) -# assert isinstance(messages[-1].content, str) - - -# def test_client_send_message_json_object_response_format(client: "Letta", agent: "AgentState") -> None: -# """Test client send_message with response_format='json_object'.""" -# client.agents.modify(agent.id, response_format={"type": "json_object"}) - -# response = client.agents.messages.create_stream( -# agent_id=agent.id, -# messages=USER_MESSAGE, -# ) -# messages = list(response) -# assert isinstance(messages[-1], AssistantMessage) -# assert isinstance(messages[-1].content, dict) - - -# def test_client_send_message_json_schema_response_format(client: "Letta", agent: "AgentState") -> None: -# """Test client send_message with response_format='json_schema' and a valid schema.""" -# client.agents.modify(agent.id, response_format={ -# "type": "json_schema", -# "json_schema": { -# "name": "reasoning_schema", -# "schema": { -# "type": "object", -# "properties": { -# "steps": { -# "type": "array", -# "items": { -# "type": "object", -# "properties": { -# "explanation": { "type": "string" }, -# "output": { "type": "string" } -# }, -# "required": ["explanation", "output"], -# "additionalProperties": False -# } -# }, -# "final_answer": { "type": "string" } -# }, -# "required": ["steps", "final_answer"], -# "additionalProperties": True -# }, -# "strict": True -# } -# }) -# response = client.agents.messages.create_stream( -# agent_id=agent.id, -# messages=USER_MESSAGE, -# ) -# messages = list(response) - -# assert isinstance(messages[-1], AssistantMessage) -# assert isinstance(messages[-1].content, dict) - - -# # def test_client_send_message_invalid_json_schema(client: "Letta", agent: "AgentState") -> None: -# # """Test client send_message with an invalid json_schema (should error or fallback).""" -# # invalid_schema: Dict[str, Any] = {"type": "object", "properties": {"foo": {"type": "unknown"}}} -# # client.agents.modify(agent.id, response_format="json_schema") -# # result: Any = client.agents.send_message(agent.id, "Test invalid schema") -# # assert result is None or "error" in str(result).lower() diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index 6c0f74c1..6e0ebd73 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -6,15 +6,16 @@ from typing import List import pytest -from letta import create_client from letta.agent import Agent -from letta.client.client import LocalClient +from letta.config import LettaConfig from letta.llm_api.helpers import calculate_summarizer_cutoff +from letta.schemas.agent import CreateAgent from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import TextContent from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message +from letta.schemas.message import Message, MessageCreate +from letta.server.server import SyncServer from letta.streaming_interface import StreamingRefreshCLIInterface from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH from tests.helpers.utils import cleanup @@ -30,22 +31,34 @@ test_agent_name = f"test_client_{str(uuid.uuid4())}" @pytest.fixture(scope="module") -def client(): - client = create_client() - # client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) +def server(): + config = LettaConfig.load() + config.save() - yield client + server = SyncServer() + return server @pytest.fixture(scope="module") -def agent_state(client): +def default_user(server): + yield server.user_manager.get_user_or_default() + + +@pytest.fixture(scope="module") +def agent_state(server, default_user): # Generate uuid for agent name for this example - agent_state = client.create_agent(name=test_agent_name) + agent_state = server.create_agent( + CreateAgent( + name=test_agent_name, + include_base_tools=True, + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + ), + actor=default_user, + ) yield agent_state - client.delete_agent(agent_state.id) + server.agent_manager.delete_agent(agent_state.id, default_user) # Sample data setup @@ -113,9 +126,9 @@ def test_cutoff_calculation(mocker): assert messages[cutoff - 1].role == MessageRole.user -def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_state): +def test_cutoff_calculation_with_tool_call(mocker, server, agent_state, default_user): """Test that trim_older_in_context_messages properly handles tool responses with _trim_tool_response.""" - agent_state = client.get_agent(agent_id=agent_state.id) + agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_state.id, actor=default_user) # Setup messages = [ @@ -133,18 +146,18 @@ def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_st def mock_get_messages_by_ids(message_ids, actor): return [msg for msg in messages if msg.id in message_ids] - mocker.patch.object(client.server.agent_manager.message_manager, "get_messages_by_ids", side_effect=mock_get_messages_by_ids) + mocker.patch.object(server.agent_manager.message_manager, "get_messages_by_ids", side_effect=mock_get_messages_by_ids) # Mock get_agent_by_id to return an agent with our message IDs mock_agent = mocker.Mock() mock_agent.message_ids = [msg.id for msg in messages] - mocker.patch.object(client.server.agent_manager, "get_agent_by_id", return_value=mock_agent) + mocker.patch.object(server.agent_manager, "get_agent_by_id", return_value=mock_agent) # Mock set_in_context_messages to capture what messages are being set - mock_set_messages = mocker.patch.object(client.server.agent_manager, "set_in_context_messages", return_value=agent_state) + mock_set_messages = mocker.patch.object(server.agent_manager, "set_in_context_messages", return_value=agent_state) # Test Case: Trim to remove orphaned tool response - client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=client.user) + server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=default_user) test1 = mock_set_messages.call_args_list[0][1] assert len(test1["message_ids"]) == 5 @@ -152,104 +165,92 @@ def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_st mock_set_messages.reset_mock() # Test Case: Does not result in trimming the orphaned tool response - client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=2, actor=client.user) + server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=2, actor=default_user) test2 = mock_set_messages.call_args_list[0][1] assert len(test2["message_ids"]) == 6 -def test_summarize_many_messages_basic(client, disable_e2b_api_key): +def test_summarize_many_messages_basic(server, default_user): + """Test that a small-context agent gets enough messages for summarization.""" small_context_llm_config = LLMConfig.default_config("gpt-4o-mini") small_context_llm_config.context_window = 3000 - small_agent_state = client.create_agent( - name="small_context_agent", - llm_config=small_context_llm_config, - ) - for _ in range(10): - client.user_message( - agent_id=small_agent_state.id, - message="hi " * 60, - ) - client.delete_agent(small_agent_state.id) - -def test_summarize_messages_inplace(client, agent_state, disable_e2b_api_key): - """Test summarization via sending the summarize CLI command or via a direct call to the agent object""" - # First send a few messages (5) - response = client.user_message( - agent_id=agent_state.id, - message="Hey, how's it going? What do you think about this whole shindig", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message( - agent_id=agent_state.id, - message="Any thoughts on the meaning of life?", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?").messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message( - agent_id=agent_state.id, - message="Would you be surprised to learn that you're actually conversing with an AI right now?", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - # reload agent object - agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) - - agent_obj.summarize_messages_inplace() - - -def test_auto_summarize(client, disable_e2b_api_key): - """Test that the summarizer triggers by itself""" - small_context_llm_config = LLMConfig.default_config("gpt-4o-mini") - small_context_llm_config.context_window = 4000 - - small_agent_state = client.create_agent( - name="small_context_agent", - llm_config=small_context_llm_config, + agent_state = server.create_agent( + CreateAgent( + name="small_context_agent", + llm_config=small_context_llm_config, + embedding="letta/letta-free", + ), + actor=default_user, ) try: - - def summarize_message_exists(messages: List[Message]) -> bool: - for message in messages: - if message.content[0].text and "The following is a summary of the previous" in message.content[0].text: - print(f"Summarize message found after {message_count} messages: \n {message.content[0].text}") - return True - return False - - MAX_ATTEMPTS = 10 - message_count = 0 - while True: - - # send a message - response = client.user_message( - agent_id=small_agent_state.id, - message="What is the meaning of life?", + for _ in range(10): + server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="hi " * 60)], ) - message_count += 1 - - print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------") - - # check if the summarize message is inside the messages - assert isinstance(client, LocalClient), "Test only works with LocalClient" - in_context_messages = client.server.agent_manager.get_in_context_messages(agent_id=small_agent_state.id, actor=client.user) - print("SUMMARY", summarize_message_exists(in_context_messages)) - if summarize_message_exists(in_context_messages): - break - - if message_count > MAX_ATTEMPTS: - raise Exception(f"Summarize message not found after {message_count} messages") - finally: - client.delete_agent(small_agent_state.id) + server.agent_manager.delete_agent(agent_id=agent_state.id, actor=default_user) + + +def test_summarize_messages_inplace(server, agent_state, default_user): + """Test summarization logic via agent object API.""" + for msg in [ + "Hey, how's it going? What do you think about this whole shindig?", + "Any thoughts on the meaning of life?", + "Does the number 42 ring a bell?", + "Would you be surprised to learn that you're actually conversing with an AI right now?", + ]: + response = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content=msg)], + ) + assert response.steps_messages + + agent = server.load_agent(agent_id=agent_state.id, actor=default_user) + agent.summarize_messages_inplace() + + +def test_auto_summarize(server, default_user): + """Test that summarization is automatically triggered.""" + small_context_llm_config = LLMConfig.default_config("gpt-4o-mini") + small_context_llm_config.context_window = 3000 + + agent_state = server.create_agent( + CreateAgent( + name="small_context_agent", + llm_config=small_context_llm_config, + embedding="letta/letta-free", + ), + actor=default_user, + ) + + def summarize_message_exists(messages: List[Message]) -> bool: + for message in messages: + if message.content[0].text and "The following is a summary of the previous" in message.content[0].text: + return True + return False + + try: + MAX_ATTEMPTS = 10 + for attempt in range(MAX_ATTEMPTS): + server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="What is the meaning of life?")], + ) + + in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user) + + if summarize_message_exists(in_context_messages): + return + + raise AssertionError("Summarization was not triggered after 10 messages") + finally: + server.agent_manager.delete_agent(agent_id=agent_state.id, actor=default_user) @pytest.mark.parametrize( @@ -258,51 +259,53 @@ def test_auto_summarize(client, disable_e2b_api_key): "openai-gpt-4o.json", "azure-gpt-4o-mini.json", "claude-3-5-haiku.json", - # "groq.json", TODO: Support groq, rate limiting currently makes it impossible to test - # "gemini-pro.json", TODO: Gemini is broken + # "groq.json", # rate limits + # "gemini-pro.json", # broken ], ) -def test_summarizer(config_filename, client, agent_state): +def test_summarizer(config_filename, server, default_user): + """Test summarization across different LLM configs.""" namespace = uuid.NAMESPACE_DNS agent_name = str(uuid.uuid5(namespace, f"integration-test-summarizer-{config_filename}")) - # Get the LLM config - filename = os.path.join(LLM_CONFIG_DIR, config_filename) - config_data = json.load(open(filename, "r")) - - # Create client and clean up agents + # Load configs + config_data = json.load(open(os.path.join(LLM_CONFIG_DIR, config_filename))) llm_config = LLMConfig(**config_data) embedding_config = EmbeddingConfig(**json.load(open(EMBEDDING_CONFIG_PATH))) - client = create_client() - client.set_default_llm_config(llm_config) - client.set_default_embedding_config(embedding_config) - cleanup(client=client, agent_uuid=agent_name) + + # Ensure cleanup + cleanup(server=server, agent_uuid=agent_name, actor=default_user) # Create agent - agent_state = client.create_agent(name=agent_name, llm_config=llm_config, embedding_config=embedding_config) - full_agent_state = client.get_agent(agent_id=agent_state.id) + agent_state = server.create_agent( + CreateAgent( + name=agent_name, + llm_config=llm_config, + embedding_config=embedding_config, + ), + actor=default_user, + ) + + full_agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_state.id, actor=default_user) + letta_agent = Agent( interface=StreamingRefreshCLIInterface(), agent_state=full_agent_state, first_message_verify_mono=False, - user=client.user, + user=default_user, ) - # Make conversation - messages = [ + for msg in [ "Did you know that honey never spoils? Archaeologists have found pots of honey in ancient Egyptian tombs that are over 3,000 years old and still perfectly edible.", "Octopuses have three hearts, and two of them stop beating when they swim.", - ] - - for m in messages: + ]: letta_agent.step_user_message( - user_message_str=m, + user_message_str=msg, first_message=False, skip_verify=False, stream=False, ) - # Invoke a summarize letta_agent.summarize_messages_inplace() - in_context_messages = client.get_in_context_messages(agent_state.id) + in_context_messages = server.agent_manager.get_in_context_messages(agent_state.id, actor=default_user) assert SUMMARY_KEY_PHRASE in in_context_messages[1].content[0].text, f"Test failed for config: {config_filename}" diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 720922f2..1a9bd763 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -7,17 +7,16 @@ from unittest.mock import patch import pytest from sqlalchemy import delete -from letta import create_client +from letta.config import LettaConfig from letta.functions.function_sets.base import core_memory_append, core_memory_replace from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable -from letta.schemas.agent import AgentState -from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.agent import AgentState, CreateAgent +from letta.schemas.block import CreateBlock from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory from letta.schemas.organization import Organization from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate, SandboxConfigUpdate from letta.schemas.user import User +from letta.server.server import SyncServer from letta.services.organization_manager import OrganizationManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox @@ -32,6 +31,21 @@ user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user")) # Fixtures +@pytest.fixture(scope="module") +def server(): + """ + Creates a SyncServer instance for testing. + + Loads and saves config to ensure proper initialization. + """ + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=True) + yield server + + @pytest.fixture(autouse=True) def clear_tables(): """Fixture to clear the organization table before each test.""" @@ -191,12 +205,26 @@ def external_codebase_tool(test_user): @pytest.fixture -def agent_state(): - client = create_client() - agent_state = client.create_agent( - memory=ChatMemory(persona="This is the persona", human="My name is Chad"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), +def agent_state(server): + actor = server.user_manager.get_user_or_default() + agent_state = server.create_agent( + CreateAgent( + memory_blocks=[ + CreateBlock( + label="human", + value="username: sarah", + ), + CreateBlock( + label="persona", + value="This is the persona", + ), + ], + include_base_tools=True, + model="openai/gpt-4o-mini", + tags=["test_agents"], + embedding="letta/letta-free", + ), + actor=actor, ) agent_state.tool_rules = [] yield agent_state diff --git a/tests/manual_test_many_messages.py b/tests/manual_test_many_messages.py index 6aaa33bb..df71dd85 100644 --- a/tests/manual_test_many_messages.py +++ b/tests/manual_test_many_messages.py @@ -1,7 +1,6 @@ import datetime import json import math -import os import random import uuid @@ -9,14 +8,11 @@ import pytest from faker import Faker from tqdm import tqdm -from letta import create_client +from letta.config import LettaConfig from letta.orm import Base -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message -from letta.services.agent_manager import AgentManager -from letta.services.message_manager import MessageManager -from tests.integration_test_summarizer import LLM_CONFIG_DIR +from letta.schemas.agent import CreateAgent +from letta.schemas.message import Message, MessageCreate +from letta.server.server import SyncServer @pytest.fixture(autouse=True) @@ -29,16 +25,25 @@ def truncate_database(): session.commit() -@pytest.fixture(scope="function") -def client(): - filename = os.path.join(LLM_CONFIG_DIR, "claude-3-5-sonnet.json") - config_data = json.load(open(filename, "r")) - llm_config = LLMConfig(**config_data) - client = create_client() - client.set_default_llm_config(llm_config) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) +@pytest.fixture(scope="module") +def server(): + """ + Creates a SyncServer instance for testing. - yield client + Loads and saves config to ensure proper initialization. + """ + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=True) + yield server + + +@pytest.fixture +def default_user(server): + actor = server.user_manager.get_user_or_default() + yield actor def generate_tool_call_id(): @@ -129,14 +134,13 @@ def create_tool_message(agent_id, organization_id, tool_call_id, timestamp): @pytest.mark.parametrize("num_messages", [1000]) -def test_many_messages_performance(client, num_messages): - """Main test function to generate messages and insert them into the database.""" - message_manager = MessageManager() - agent_manager = AgentManager() - actor = client.user +def test_many_messages_performance(server, default_user, num_messages): + """Performance test to insert many messages and ensure retrieval works correctly.""" + message_manager = server.agent_manager.message_manager + agent_manager = server.agent_manager start_time = datetime.datetime.now() - last_event_time = start_time # Track last event time + last_event_time = start_time def log_event(event): nonlocal last_event_time @@ -144,11 +148,19 @@ def test_many_messages_performance(client, num_messages): total_elapsed = (now - start_time).total_seconds() step_elapsed = (now - last_event_time).total_seconds() print(f"[+{total_elapsed:.3f}s | Δ{step_elapsed:.3f}s] {event}") - last_event_time = now # Update last event time + last_event_time = now log_event(f"Starting test with {num_messages} messages") - agent_state = client.create_agent(name="manager") + agent_state = server.create_agent( + CreateAgent( + name="manager", + include_base_tools=True, + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + ), + actor=default_user, + ) log_event(f"Created agent with ID {agent_state.id}") message_group_size = 3 @@ -158,37 +170,42 @@ def test_many_messages_performance(client, num_messages): organization_id = "org-00000000-0000-4000-8000-000000000000" all_messages = [] - for _ in tqdm(range(num_groups)): user_text, assistant_text = get_conversation_pair() tool_call_id = generate_tool_call_id() user_time, send_time, tool_time, current_time = generate_timestamps(current_time) - new_messages = [ - Message(**create_user_message(agent_state.id, organization_id, user_text, user_time)), - Message(**create_send_message(agent_state.id, organization_id, assistant_text, tool_call_id, send_time)), - Message(**create_tool_message(agent_state.id, organization_id, tool_call_id, tool_time)), - ] - all_messages.extend(new_messages) + + all_messages.extend( + [ + Message(**create_user_message(agent_state.id, organization_id, user_text, user_time)), + Message(**create_send_message(agent_state.id, organization_id, assistant_text, tool_call_id, send_time)), + Message(**create_tool_message(agent_state.id, organization_id, tool_call_id, tool_time)), + ] + ) log_event(f"Finished generating {len(all_messages)} messages") - message_manager.create_many_messages(all_messages, actor=actor) + message_manager.create_many_messages(all_messages, actor=default_user) log_event("Inserted messages into the database") agent_manager.set_in_context_messages( - agent_id=agent_state.id, message_ids=agent_state.message_ids + [m.id for m in all_messages], actor=client.user + agent_id=agent_state.id, + message_ids=agent_state.message_ids + [m.id for m in all_messages], + actor=default_user, ) log_event("Updated agent context with messages") - messages = message_manager.list_messages_for_agent(agent_id=agent_state.id, actor=client.user, limit=1000000000) + messages = message_manager.list_messages_for_agent( + agent_id=agent_state.id, + actor=default_user, + limit=1000000000, + ) log_event(f"Retrieved {len(messages)} messages from the database") assert len(messages) >= num_groups * message_group_size - response = client.send_message( - agent_id=agent_state.id, - role="user", - message="What have we been talking about?", + response = server.send_messages( + actor=default_user, agent_id=agent_state.id, input_messages=[MessageCreate(role="user", content="What have we been talking about?")] ) log_event("Sent message to agent and received response") diff --git a/tests/manual_test_multi_agent_broadcast_large.py b/tests/manual_test_multi_agent_broadcast_large.py index 70d88f44..3d406d84 100644 --- a/tests/manual_test_multi_agent_broadcast_large.py +++ b/tests/manual_test_multi_agent_broadcast_large.py @@ -1,89 +1,98 @@ -import json -import os - import pytest from tqdm import tqdm -from letta import create_client -from letta.functions.functions import derive_openai_json_schema, parse_source_code -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.tool import Tool -from tests.integration_test_summarizer import LLM_CONFIG_DIR +from letta.config import LettaConfig +from letta.schemas.agent import CreateAgent +from letta.schemas.message import MessageCreate +from letta.server.server import SyncServer +from tests.utils import create_tool_from_func -@pytest.fixture(scope="function") -def client(): - filename = os.path.join(LLM_CONFIG_DIR, "claude-3-5-haiku.json") - config_data = json.load(open(filename, "r")) - llm_config = LLMConfig(**config_data) - client = create_client() - client.set_default_llm_config(llm_config) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) +@pytest.fixture(scope="module") +def server(): + """ + Creates a SyncServer instance for testing. - yield client + Loads and saves config to ensure proper initialization. + """ + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=True) + yield server @pytest.fixture -def roll_dice_tool(client): +def default_user(server): + actor = server.user_manager.get_user_or_default() + yield actor + + +@pytest.fixture +def roll_dice_tool(server, default_user): def roll_dice(): """ - Rolls a 6 sided die. + Rolls a 6-sided die. Returns: - str: The roll result. + str: Result of the die roll. """ return "Rolled a 5!" - # Set up tool details - source_code = parse_source_code(roll_dice) - source_type = "python" - description = "test_description" - tags = ["test"] - - tool = Tool(description=description, tags=tags, source_code=source_code, source_type=source_type) - derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) - - derived_name = derived_json_schema["name"] - tool.json_schema = derived_json_schema - tool.name = derived_name - - tool = client.server.tool_manager.create_or_update_tool(tool, actor=client.user) - - # Yield the created tool - yield tool + tool = create_tool_from_func(func=roll_dice) + created_tool = server.tool_manager.create_or_update_tool(tool, actor=default_user) + yield created_tool @pytest.mark.parametrize("num_workers", [50]) -def test_multi_agent_large(client, roll_dice_tool, num_workers): +def test_multi_agent_large(server, default_user, roll_dice_tool, num_workers): manager_tags = ["manager"] worker_tags = ["helpers"] - # Clean up first from possibly failed tests - prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags + manager_tags, match_all_tags=True) - for agent in prev_worker_agents: - client.delete_agent(agent.id) + # Cleanup any pre-existing agents with both tags + prev_agents = server.agent_manager.list_agents(actor=default_user, tags=worker_tags + manager_tags, match_all_tags=True) + for agent in prev_agents: + server.agent_manager.delete_agent(agent.id, actor=default_user) - # Create "manager" agent - send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags") - manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id], tags=manager_tags) - manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) - - # Create 3 worker agents - worker_agents = [] - for idx in tqdm(range(num_workers)): - worker_agent_state = client.create_agent( - name=f"worker-{idx}", include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id] - ) - worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) - worker_agents.append(worker_agent) - - # Encourage the manager to send a message to the other agent_obj with the secret string - broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!" - client.send_message( - agent_id=manager_agent.agent_state.id, - role="user", - message=broadcast_message, + # Create "manager" agent with multi-agent broadcast tool + send_message_tool_id = server.tool_manager.get_tool_id(tool_name="send_message_to_agents_matching_tags", actor=default_user) + manager_agent_state = server.create_agent( + CreateAgent( + name="manager", + tool_ids=[send_message_tool_id], + include_base_tools=True, + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + tags=manager_tags, + ), + actor=default_user, ) - # Please manually inspect the agent results + manager_agent = server.load_agent(agent_id=manager_agent_state.id, actor=default_user) + + # Create N worker agents + worker_agents = [] + for idx in tqdm(range(num_workers)): + worker_agent_state = server.create_agent( + CreateAgent( + name=f"worker-{idx}", + tool_ids=[roll_dice_tool.id], + include_multi_agent_tools=False, + include_base_tools=True, + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + tags=worker_tags, + ), + actor=default_user, + ) + worker_agent = server.load_agent(agent_id=worker_agent_state.id, actor=default_user) + worker_agents.append(worker_agent) + + # Manager sends broadcast message + broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!" + server.send_messages( + actor=default_user, + agent_id=manager_agent.agent_state.id, + input_messages=[MessageCreate(role="user", content=broadcast_message)], + ) diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index aa02e0df..000db3b9 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -13,7 +13,6 @@ from dotenv import load_dotenv from rich.console import Console from rich.syntax import Syntax -from letta import create_client from letta.config import LettaConfig from letta.orm import Base from letta.orm.enums import ToolType @@ -27,6 +26,7 @@ from letta.schemas.organization import Organization from letta.schemas.user import User from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.server import SyncServer +from tests.utils import create_tool_from_func console = Console() @@ -86,14 +86,6 @@ def clear_tables(): _clear_tables() -@pytest.fixture(scope="module") -def local_client(): - client = create_client() - client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - yield client - - @pytest.fixture def server(): config = LettaConfig.load() @@ -133,14 +125,14 @@ def other_user(server: SyncServer, other_organization): @pytest.fixture -def weather_tool(local_client, weather_tool_func): - weather_tool = local_client.create_or_update_tool(func=weather_tool_func) +def weather_tool(server, weather_tool_func, default_user): + weather_tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=weather_tool_func), actor=default_user) yield weather_tool @pytest.fixture -def print_tool(local_client, print_tool_func): - print_tool = local_client.create_or_update_tool(func=print_tool_func) +def print_tool(server, print_tool_func, default_user): + print_tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=print_tool_func), actor=default_user) yield print_tool @@ -438,7 +430,7 @@ def test_sanity_datetime_mismatch(): # Agent serialize/deserialize tests -def test_deserialize_simple(local_client, server, serialize_test_agent, default_user, other_user): +def test_deserialize_simple(server, serialize_test_agent, default_user, other_user): """Test deserializing JSON into an Agent instance.""" append_copy_suffix = False result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) @@ -452,9 +444,7 @@ def test_deserialize_simple(local_client, server, serialize_test_agent, default_ @pytest.mark.parametrize("override_existing_tools", [True, False]) -def test_deserialize_override_existing_tools( - local_client, server, serialize_test_agent, default_user, weather_tool, print_tool, override_existing_tools -): +def test_deserialize_override_existing_tools(server, serialize_test_agent, default_user, weather_tool, print_tool, override_existing_tools): """ Test deserializing an agent with tools and ensure correct behavior for overriding existing tools. """ @@ -487,7 +477,7 @@ def test_deserialize_override_existing_tools( assert existing_tool.source_code == weather_tool.source_code, f"Tool {tool_name} should NOT be overridden" -def test_agent_serialize_with_user_messages(local_client, server, serialize_test_agent, default_user, other_user): +def test_agent_serialize_with_user_messages(server, serialize_test_agent, default_user, other_user): """Test deserializing JSON into an Agent instance.""" append_copy_suffix = False server.send_messages( @@ -516,7 +506,7 @@ def test_agent_serialize_with_user_messages(local_client, server, serialize_test ) -def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, serialize_test_agent, default_user, other_user): +def test_agent_serialize_tool_calls(disable_e2b_api_key, server, serialize_test_agent, default_user, other_user): """Test deserializing JSON into an Agent instance.""" append_copy_suffix = False server.send_messages( @@ -552,7 +542,7 @@ def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, s assert copy_agent_response.completion_tokens > 0 and copy_agent_response.step_count > 0 -def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server, serialize_test_agent, default_user, other_user): +def test_agent_serialize_update_blocks(disable_e2b_api_key, server, serialize_test_agent, default_user, other_user): """Test deserializing JSON into an Agent instance.""" append_copy_suffix = False server.send_messages( diff --git a/tests/test_ast_parsing.py b/tests/test_ast_parsing.py deleted file mode 100644 index 312e3a0c..00000000 --- a/tests/test_ast_parsing.py +++ /dev/null @@ -1,275 +0,0 @@ -import pytest - -from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source - -# ----------------------------------------------------------------------- -# Example source code for testing multiple scenarios, including: -# 1) A class-based custom type (which we won't handle properly). -# 2) Functions with multiple argument types. -# 3) A function with default arguments. -# 4) A function with no arguments. -# 5) A function that shares the same name as another symbol. -# ----------------------------------------------------------------------- -example_source_code = r""" -class CustomClass: - def __init__(self, x): - self.x = x - -def unrelated_symbol(): - pass - -def no_args_func(): - pass - -def default_args_func(x: int = 5, y: str = "hello"): - return x, y - -def my_function(a: int, b: float, c: str, d: list, e: dict, f: CustomClass = None): - pass - -def my_function_duplicate(): - # This function shares the name "my_function" partially, but isn't an exact match - pass -""" - - -# --------------------- get_function_annotations_from_source TESTS --------------------- # - - -def test_get_function_annotations_found(): - """ - Test that we correctly parse annotations for a function - that includes multiple argument types and a custom class. - """ - annotations = get_function_annotations_from_source(example_source_code, "my_function") - assert annotations == { - "a": "int", - "b": "float", - "c": "str", - "d": "list", - "e": "dict", - "f": "CustomClass", - } - - -def test_get_function_annotations_not_found(): - """ - If the requested function name doesn't exist exactly, - we should raise a ValueError. - """ - with pytest.raises(ValueError, match="Function 'missing_function' not found"): - get_function_annotations_from_source(example_source_code, "missing_function") - - -def test_get_function_annotations_no_args(): - """ - Check that a function without arguments returns an empty annotations dict. - """ - annotations = get_function_annotations_from_source(example_source_code, "no_args_func") - assert annotations == {} - - -def test_get_function_annotations_with_default_values(): - """ - Ensure that a function with default arguments still captures the annotations. - """ - annotations = get_function_annotations_from_source(example_source_code, "default_args_func") - assert annotations == {"x": "int", "y": "str"} - - -def test_get_function_annotations_partial_name_collision(): - """ - Ensure we only match the exact function name, not partial collisions. - """ - # This will match 'my_function' exactly, ignoring 'my_function_duplicate' - annotations = get_function_annotations_from_source(example_source_code, "my_function") - assert "a" in annotations # Means it matched the correct function - # No error expected here, just making sure we didn't accidentally parse "my_function_duplicate". - - -# --------------------- coerce_dict_args_by_annotations TESTS --------------------- # - - -def test_coerce_dict_args_success(): - """ - Basic success scenario with standard types: - int, float, str, list, dict. - """ - annotations = {"a": "int", "b": "float", "c": "str", "d": "list", "e": "dict"} - function_args = {"a": "42", "b": "3.14", "c": 123, "d": "[1, 2, 3]", "e": '{"key": "value"}'} - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == 42 - assert coerced_args["b"] == 3.14 - assert coerced_args["c"] == "123" - assert coerced_args["d"] == [1, 2, 3] - assert coerced_args["e"] == {"key": "value"} - - -def test_coerce_dict_args_invalid_type(): - """ - If the value cannot be coerced into the annotation, - a ValueError should be raised. - """ - annotations = {"a": "int"} - function_args = {"a": "invalid_int"} - - with pytest.raises(ValueError, match="Failed to coerce argument 'a' to int"): - coerce_dict_args_by_annotations(function_args, annotations) - - -def test_coerce_dict_args_no_annotations(): - """ - If there are no annotations, we do no coercion. - """ - annotations = {} - function_args = {"a": 42, "b": "hello"} - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args == function_args # Exactly the same dict back - - -def test_coerce_dict_args_partial_annotations(): - """ - Only coerce annotated arguments; leave unannotated ones unchanged. - """ - annotations = {"a": "int"} - function_args = {"a": "42", "b": "no_annotation"} - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == 42 - assert coerced_args["b"] == "no_annotation" - - -def test_coerce_dict_args_with_missing_args(): - """ - If function_args lacks some keys listed in annotations, - those are simply not coerced. (We do not add them.) - """ - annotations = {"a": "int", "b": "float"} - function_args = {"a": "42"} # Missing 'b' - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == 42 - assert "b" not in coerced_args - - -def test_coerce_dict_args_unexpected_keys(): - """ - If function_args has extra keys not in annotations, - we leave them alone. - """ - annotations = {"a": "int"} - function_args = {"a": "42", "z": 999} - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == 42 - assert coerced_args["z"] == 999 # unchanged - - -def test_coerce_dict_args_unsupported_custom_class(): - """ - If someone tries to pass an annotation that isn't supported (like a custom class), - we should raise a ValueError (or similarly handle the error) rather than silently - accept it. - """ - annotations = {"f": "CustomClass"} # We can't resolve this - function_args = {"f": {"x": 1}} - with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass: Unsupported annotation: CustomClass"): - coerce_dict_args_by_annotations(function_args, annotations) - - -def test_coerce_dict_args_with_complex_types(): - """ - Confirm the ability to parse built-in complex data (lists, dicts, etc.) - when given as strings. - """ - annotations = {"big_list": "list", "nested_dict": "dict"} - function_args = {"big_list": "[1, 2, [3, 4], {'five': 5}]", "nested_dict": '{"alpha": [10, 20], "beta": {"x": 1, "y": 2}}'} - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["big_list"] == [1, 2, [3, 4], {"five": 5}] - assert coerced_args["nested_dict"] == { - "alpha": [10, 20], - "beta": {"x": 1, "y": 2}, - } - - -def test_coerce_dict_args_non_string_keys(): - """ - Validate behavior if `function_args` includes non-string keys. - (We should simply skip annotation checks for them.) - """ - annotations = {"a": "int"} - function_args = {123: "42", "a": "42"} - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - # 'a' is coerced to int - assert coerced_args["a"] == 42 - # 123 remains untouched - assert coerced_args[123] == "42" - - -def test_coerce_dict_args_non_parseable_list_or_dict(): - """ - Test passing incorrectly formatted JSON for a 'list' or 'dict' annotation. - """ - annotations = {"bad_list": "list", "bad_dict": "dict"} - function_args = {"bad_list": "[1, 2, 3", "bad_dict": '{"key": "value"'} # missing brackets - - with pytest.raises(ValueError, match="Failed to coerce argument 'bad_list' to list"): - coerce_dict_args_by_annotations(function_args, annotations) - - -def test_coerce_dict_args_with_complex_list_annotation(): - """ - Test coercion when list with type annotation (e.g., list[int]) is used. - """ - annotations = {"a": "list[int]"} - function_args = {"a": "[1, 2, 3]"} - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == [1, 2, 3] - - -def test_coerce_dict_args_with_complex_dict_annotation(): - """ - Test coercion when dict with type annotation (e.g., dict[str, int]) is used. - """ - annotations = {"a": "dict[str, int]"} - function_args = {"a": '{"x": 1, "y": 2}'} - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == {"x": 1, "y": 2} - - -def test_coerce_dict_args_unsupported_complex_annotation(): - """ - If an unsupported complex annotation is used (e.g., a custom class), - a ValueError should be raised. - """ - annotations = {"f": "CustomClass[int]"} - function_args = {"f": "CustomClass(42)"} - - with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass\[int\]: Unsupported annotation: CustomClass\[int\]"): - coerce_dict_args_by_annotations(function_args, annotations) - - -def test_coerce_dict_args_with_nested_complex_annotation(): - """ - Test coercion with complex nested types like list[dict[str, int]]. - """ - annotations = {"a": "list[dict[str, int]]"} - function_args = {"a": '[{"x": 1}, {"y": 2}]'} - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == [{"x": 1}, {"y": 2}] - - -def test_coerce_dict_args_with_default_arguments(): - """ - Test coercion with default arguments, where some arguments have defaults in the source code. - """ - annotations = {"a": "int", "b": "str"} - function_args = {"a": "42"} - - function_args.setdefault("b", "hello") # Setting the default value for 'b' - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == 42 - assert coerced_args["b"] == "hello" diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index f81211cd..f267c61d 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -6,9 +6,11 @@ from dotenv import load_dotenv from letta_client import Letta import letta.functions.function_sets.base as base_functions -from letta import LocalClient, create_client +from letta.config import LettaConfig from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import MessageCreate +from letta.server.server import SyncServer from tests.test_tool_schema_parsing_files.expected_base_tool_schemas import ( get_finish_rethinking_memory_schema, get_rethink_user_memory_schema, @@ -18,15 +20,6 @@ from tests.test_tool_schema_parsing_files.expected_base_tool_schemas import ( from tests.utils import wait_for_server -@pytest.fixture(scope="function") -def client(): - client = create_client() - client.set_default_llm_config(LLMConfig.default_config("gpt-4o")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - - yield client - - def _run_server(): """Starts the Letta server in a background thread.""" load_dotenv() @@ -35,6 +28,21 @@ def _run_server(): start_server(debug=True) +@pytest.fixture(scope="module") +def server(): + """ + Creates a SyncServer instance for testing. + + Loads and saves config to ensure proper initialization. + """ + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=True) + yield server + + @pytest.fixture(scope="session") def server_url(): """Ensures a server is running and returns its base URL.""" @@ -57,16 +65,29 @@ def letta_client(server_url): @pytest.fixture(scope="function") -def agent_obj(client: LocalClient): +def agent_obj(letta_client, server): """Create a test agent that we can call functions on""" - send_message_to_agent_and_wait_for_reply_tool_id = client.get_tool_id(name="send_message_to_agent_and_wait_for_reply") - agent_state = client.create_agent(tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id]) - - agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) + send_message_to_agent_and_wait_for_reply_tool_id = letta_client.tools.list(name="send_message_to_agent_and_wait_for_reply")[0].id + agent_state = letta_client.agents.create( + tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id], + include_base_tools=True, + memory_blocks=[ + { + "label": "human", + "value": "Name: Matt", + }, + { + "label": "persona", + "value": "Friendly agent", + }, + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ) + actor = server.user_manager.get_user_or_default() + agent_obj = server.load_agent(agent_id=agent_state.id, actor=actor) yield agent_obj - # client.delete_agent(agent_obj.agent_state.id) - def query_in_search_results(search_results, query): for result in search_results: @@ -127,16 +148,19 @@ def test_archival(agent_obj): pass -def test_recall(client, agent_obj): - # keyword +def test_recall(server, agent_obj, default_user): + """Test that an agent can recall messages using a keyword via conversation search.""" keyword = "banana" - # Send messages to agent - client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="hello") - client.send_message(agent_id=agent_obj.agent_state.id, role="user", message=keyword) - client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="tell me a fun fact") + # Send messages + for msg in ["hello", keyword, "tell me a fun fact"]: + server.send_messages( + actor=default_user, + agent_id=agent_obj.agent_state.id, + input_messages=[MessageCreate(role="user", content=msg)], + ) - # Conversation search + # Search memory result = base_functions.conversation_search(agent_obj, "banana") assert keyword in result diff --git a/tests/test_client.py b/tests/test_client.py index 8384f10f..3938671d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -110,58 +110,6 @@ def clear_tables(): session.commit() -# TODO: add back -# def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]): -# """ -# Test sandbox config and environment variable functions for both LocalClient and RESTClient. -# """ -# -# # 1. Create a sandbox config -# local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR) -# sandbox_config = client.create_sandbox_config(config=local_config) -# -# # Assert the created sandbox config -# assert sandbox_config.id is not None -# assert sandbox_config.type == SandboxType.LOCAL -# -# # 2. Update the sandbox config -# updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR) -# sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config) -# assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR -# -# # 3. List all sandbox configs -# sandbox_configs = client.list_sandbox_configs(limit=10) -# assert isinstance(sandbox_configs, List) -# assert len(sandbox_configs) == 1 -# assert sandbox_configs[0].id == sandbox_config.id -# -# # 4. Create an environment variable -# env_var = client.create_sandbox_env_var( -# sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION -# ) -# assert env_var.id is not None -# assert env_var.key == ENV_VAR_KEY -# assert env_var.value == ENV_VAR_VALUE -# assert env_var.description == ENV_VAR_DESCRIPTION -# -# # 5. Update the environment variable -# updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE) -# assert updated_env_var.key == UPDATED_ENV_VAR_KEY -# assert updated_env_var.value == UPDATED_ENV_VAR_VALUE -# -# # 6. List environment variables -# env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id) -# assert isinstance(env_vars, List) -# assert len(env_vars) == 1 -# assert env_vars[0].key == UPDATED_ENV_VAR_KEY -# -# # 7. Delete the environment variable -# client.delete_sandbox_env_var(env_var_id=env_var.id) -# -# # 8. Delete the sandbox config -# client.delete_sandbox_config(sandbox_config_id=sandbox_config.id) - - # -------------------------------------------------------------------------------------------------------------------- # Agent tags # -------------------------------------------------------------------------------------------------------------------- @@ -349,30 +297,6 @@ def test_attach_detach_agent_memory_block(client: Letta, agent: AgentState): assert example_new_label not in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)] -# def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState): -# """Test that the token limit is enforced for the core memory blocks""" - -# # Create an agent -# new_agent = client.create_agent( -# name="test-core-memory-token-limits", -# tools=BASE_TOOLS, -# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=2000), -# ) - -# try: -# # Then intentionally set the limit to be extremely low -# client.update_agent( -# agent_id=new_agent.id, -# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=100), -# ) - -# # TODO we should probably not allow updating the core memory limit if - -# # TODO in which case we should modify this test to actually to a proper token counter check -# finally: -# client.delete_agent(new_agent.id) - - def test_update_agent_memory_limit(client: Letta): """Test that we can update the limit of a block in an agent's memory""" @@ -744,3 +668,38 @@ def test_attach_detach_agent_source(client: Letta, agent: AgentState): assert source.id not in [s.id for s in final_sources] client.sources.delete(source.id) + + +# -------------------------------------------------------------------------------------------------------------------- +# Agent Initial Message Sequence +# -------------------------------------------------------------------------------------------------------------------- +def test_initial_sequence(client: Letta): + # create an agent + agent = client.agents.create( + memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], + model="letta/letta-free", + embedding="letta/letta-free", + initial_message_sequence=[ + MessageCreate( + role="assistant", + content="Hello, how are you?", + ), + MessageCreate(role="user", content="I'm good, and you?"), + ], + ) + + # list messages + messages = client.agents.messages.list(agent_id=agent.id) + response = client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content="hello assistant!", + ) + ], + ) + assert len(messages) == 3 + assert messages[0].message_type == "system_message" + assert messages[1].message_type == "assistant_message" + assert messages[2].message_type == "user_message" diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 68bf2edb..97299bdc 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -9,8 +9,7 @@ import pytest from dotenv import load_dotenv from sqlalchemy import delete -from letta import create_client -from letta.client.client import LocalClient, RESTClient +from letta.client.client import RESTClient from letta.constants import DEFAULT_PRESET from letta.helpers.datetime_helpers import get_utc_time from letta.orm import FileMetadata, Source @@ -33,7 +32,6 @@ from letta.schemas.usage import LettaUsageStatistics from letta.services.helpers.agent_manager_helper import initialize_message_sequence from letta.services.organization_manager import OrganizationManager from letta.services.user_manager import UserManager -from letta.settings import model_settings from tests.helpers.client_helper import upload_file_using_client # from tests.utils import create_config @@ -58,30 +56,22 @@ def run_server(): start_server(debug=True) -# Fixture to create clients with different configurations @pytest.fixture( - # params=[{"server": True}, {"server": False}], # whether to use REST API server - params=[{"server": True}], # whether to use REST API server scope="module", ) -def client(request): - if request.param["server"]: - # get URL from enviornment - server_url = os.getenv("LETTA_SERVER_URL") - if server_url is None: - # run server in thread - server_url = "http://localhost:8283" - print("Starting server thread") - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - time.sleep(5) - print("Running client tests with server:", server_url) - # create user via admin client - client = create_client(base_url=server_url, token=None) # This yields control back to the test function - else: - # use local client (no server) - client = create_client() - +def client(): + # get URL from enviornment + server_url = os.getenv("LETTA_SERVER_URL") + if server_url is None: + # run server in thread + server_url = "http://localhost:8283" + print("Starting server thread") + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + time.sleep(5) + print("Running client tests with server:", server_url) + # create user via admin client + client = RESTClient(server_url) client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) yield client @@ -100,7 +90,7 @@ def clear_tables(): # Fixture for test agent @pytest.fixture(scope="module") -def agent(client: Union[LocalClient, RESTClient]): +def agent(client: Union[RESTClient]): agent_state = client.create_agent(name=test_agent_name) yield agent_state @@ -124,7 +114,7 @@ def default_user(default_organization): yield user -def test_agent(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState): +def test_agent(disable_e2b_api_key, client: RESTClient, agent: AgentState): # test client.rename_agent new_name = "RenamedTestAgent" @@ -143,7 +133,7 @@ def test_agent(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agen assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed" -def test_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState): +def test_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState): # _reset_config() memory_response = client.get_in_context_memory(agent_id=agent.id) @@ -159,7 +149,7 @@ def test_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], age ), "Memory update failed" -def test_agent_interactions(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState): +def test_agent_interactions(disable_e2b_api_key, client: RESTClient, agent: AgentState): # test that it is a LettaMessage message = "Hello again, agent!" print("Sending message", message) @@ -182,7 +172,7 @@ def test_agent_interactions(disable_e2b_api_key, client: Union[LocalClient, REST # TODO: add streaming tests -def test_archival_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState): +def test_archival_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState): # _reset_config() memory_content = "Archival memory content" @@ -216,7 +206,7 @@ def test_archival_memory(disable_e2b_api_key, client: Union[LocalClient, RESTCli client.get_archival_memory(agent.id) -def test_core_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState): +def test_core_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState): response = client.send_message(agent_id=agent.id, message="Update your core memory to remember that my name is Timber!", role="user") print("Response", response) @@ -240,10 +230,6 @@ def test_streaming_send_message( stream_tokens: bool, model: str, ): - if isinstance(client, LocalClient): - pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming") - assert isinstance(client, RESTClient), client - # Update agent's model agent.llm_config.model = model @@ -296,7 +282,7 @@ def test_streaming_send_message( assert done, "Message stream not done" -def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_humans_personas(client: RESTClient, agent: AgentState): # _reset_config() humans_response = client.list_humans() @@ -322,7 +308,7 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta assert human.value == "Human text", "Creating human failed" -def test_list_tools_pagination(client: Union[LocalClient, RESTClient]): +def test_list_tools_pagination(client: RESTClient): tools = client.list_tools() visited_ids = {t.id: False for t in tools} @@ -344,7 +330,7 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]): assert all(visited_ids.values()) -def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_list_files_pagination(client: RESTClient, agent: AgentState): # clear sources for source in client.list_sources(): client.delete_source(source.id) @@ -380,7 +366,7 @@ def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: Ag assert len(files) == 0 # Should be empty -def test_delete_file_from_source(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_delete_file_from_source(client: RESTClient, agent: AgentState): # clear sources for source in client.list_sources(): client.delete_source(source.id) @@ -409,7 +395,7 @@ def test_delete_file_from_source(client: Union[LocalClient, RESTClient], agent: assert len(empty_files) == 0 -def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_load_file(client: RESTClient, agent: AgentState): # _reset_config() # clear sources @@ -440,7 +426,7 @@ def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState): assert file.source_id == source.id -def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_sources(client: RESTClient, agent: AgentState): # _reset_config() # clear sources @@ -530,9 +516,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): def test_organization(client: RESTClient): - if isinstance(client, LocalClient): - pytest.skip("Skipping test_organization because LocalClient does not support organizations") - # create an organization org_name = "test-org" org = client.create_org(org_name) @@ -549,25 +532,6 @@ def test_organization(client: RESTClient): assert not (org.id in [o.id for o in orgs]) -def test_list_llm_models(client: RESTClient): - """Test that if the user's env has the right api keys set, at least one model appears in the model list""" - - def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool: - return any(model.model_endpoint_type == target_type for model in models) - - models = client.list_llm_configs() - if model_settings.groq_api_key: - assert has_model_endpoint_type(models, "groq") - if model_settings.azure_api_key: - assert has_model_endpoint_type(models, "azure") - if model_settings.openai_api_key: - assert has_model_endpoint_type(models, "openai") - if model_settings.gemini_api_key: - assert has_model_endpoint_type(models, "google_ai") - if model_settings.anthropic_api_key: - assert has_model_endpoint_type(models, "anthropic") - - @pytest.fixture def cleanup_agents(client): created_agents = [] @@ -581,7 +545,7 @@ def cleanup_agents(client): # NOTE: we need to add this back once agents can also create blocks during agent creation -def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str], default_user): +def test_initial_message_sequence(client: RESTClient, agent: AgentState, cleanup_agents: List[str], default_user): """Test that we can set an initial message sequence If we pass in None, we should get a "default" message sequence @@ -624,7 +588,7 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: assert custom_sequence[0].content in client.get_in_context_messages(custom_agent_state.id)[1].content[0].text -def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_add_and_manage_tags_for_agent(client: RESTClient, agent: AgentState): """ Comprehensive happy path test for adding, retrieving, and managing tags on an agent. """ diff --git a/tests/test_local_client.py b/tests/test_local_client.py deleted file mode 100644 index a3967e4a..00000000 --- a/tests/test_local_client.py +++ /dev/null @@ -1,411 +0,0 @@ -import uuid - -import pytest - -from letta import create_client -from letta.client.client import LocalClient -from letta.schemas.agent import AgentState -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory - - -@pytest.fixture(scope="module") -def client(): - client = create_client() - # client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - - yield client - - -@pytest.fixture(scope="module") -def agent(client): - # Generate uuid for agent name for this example - namespace = uuid.NAMESPACE_DNS - agent_uuid = str(uuid.uuid5(namespace, "test_new_client_test_agent")) - - agent_state = client.create_agent(name=agent_uuid) - yield agent_state - - client.delete_agent(agent_state.id) - - -def test_agent(client: LocalClient): - # create agent - agent_state_test = client.create_agent( - name="test_agent2", - memory=ChatMemory(human="I am a human", persona="I am an agent"), - description="This is a test agent", - ) - assert isinstance(agent_state_test.memory, Memory) - - # list agents - agents = client.list_agents() - assert agent_state_test.id in [a.id for a in agents] - - # get agent - tools = client.list_tools() - print("TOOLS", [t.name for t in tools]) - agent_state = client.get_agent(agent_state_test.id) - assert agent_state.name == "test_agent2" - for block in agent_state.memory.blocks: - db_block = client.server.block_manager.get_block_by_id(block.id, actor=client.user) - assert db_block is not None, "memory block not persisted on agent create" - assert db_block.value == block.value, "persisted block data does not match in-memory data" - - assert isinstance(agent_state.memory, Memory) - # update agent: name - new_name = "new_agent" - client.update_agent(agent_state_test.id, name=new_name) - assert client.get_agent(agent_state_test.id).name == new_name - - assert isinstance(agent_state.memory, Memory) - # update agent: system prompt - new_system_prompt = agent_state.system + "\nAlways respond with a !" - client.update_agent(agent_state_test.id, system=new_system_prompt) - assert client.get_agent(agent_state_test.id).system == new_system_prompt - - response = client.user_message(agent_id=agent_state_test.id, message="Hello") - agent_state = client.get_agent(agent_state_test.id) - assert isinstance(agent_state.memory, Memory) - # update agent: message_ids - old_message_ids = agent_state.message_ids - new_message_ids = old_message_ids.copy()[:-1] # pop one - assert len(old_message_ids) != len(new_message_ids) - client.update_agent(agent_state_test.id, message_ids=new_message_ids) - assert client.get_agent(agent_state_test.id).message_ids == new_message_ids - - assert isinstance(agent_state.memory, Memory) - # update agent: tools - tool_to_delete = "send_message" - assert tool_to_delete in [t.name for t in agent_state.tools] - new_agent_tool_ids = [t.id for t in agent_state.tools if t.name != tool_to_delete] - client.update_agent(agent_state_test.id, tool_ids=new_agent_tool_ids) - assert sorted([t.id for t in client.get_agent(agent_state_test.id).tools]) == sorted(new_agent_tool_ids) - - assert isinstance(agent_state.memory, Memory) - # update agent: memory - new_human = "My name is Mr Test, 100 percent human." - new_persona = "I am an all-knowing AI." - assert agent_state.memory.get_block("human").value != new_human - assert agent_state.memory.get_block("persona").value != new_persona - - # client.update_agent(agent_state_test.id, memory=new_memory) - # update blocks: - client.update_agent_memory_block(agent_state_test.id, label="human", value=new_human) - client.update_agent_memory_block(agent_state_test.id, label="persona", value=new_persona) - assert client.get_agent(agent_state_test.id).memory.get_block("human").value == new_human - assert client.get_agent(agent_state_test.id).memory.get_block("persona").value == new_persona - - # update agent: llm config - new_llm_config = agent_state.llm_config.model_copy(deep=True) - new_llm_config.model = "fake_new_model" - new_llm_config.context_window = 1e6 - assert agent_state.llm_config != new_llm_config - client.update_agent(agent_state_test.id, llm_config=new_llm_config) - assert client.get_agent(agent_state_test.id).llm_config == new_llm_config - assert client.get_agent(agent_state_test.id).llm_config.model == "fake_new_model" - assert client.get_agent(agent_state_test.id).llm_config.context_window == 1e6 - - # update agent: embedding config - new_embed_config = agent_state.embedding_config.model_copy(deep=True) - new_embed_config.embedding_model = "fake_embed_model" - assert agent_state.embedding_config != new_embed_config - client.update_agent(agent_state_test.id, embedding_config=new_embed_config) - assert client.get_agent(agent_state_test.id).embedding_config == new_embed_config - assert client.get_agent(agent_state_test.id).embedding_config.embedding_model == "fake_embed_model" - - # delete agent - client.delete_agent(agent_state_test.id) - - -def test_agent_add_remove_tools(client: LocalClient, agent): - # Create and add two tools to the client - # tool 1 - from composio import Action - - github_tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - - # assert both got added - tools = client.list_tools() - assert github_tool.id in [t.id for t in tools] - - # Assert that all combinations of tool_names, organization id are unique - combinations = [(t.name, t.organization_id) for t in tools] - assert len(combinations) == len(set(combinations)) - - # create agent - agent_state = agent - curr_num_tools = len(agent_state.tools) - - # add both tools to agent in steps - agent_state = client.attach_tool(agent_id=agent_state.id, tool_id=github_tool.id) - - # confirm that both tools are in the agent state - # we could access it like agent_state.tools, but will use the client function instead - # this is obviously redundant as it requires retrieving the agent again - # but allows us to test the `get_tools_from_agent` pathway as well - curr_tools = client.get_tools_from_agent(agent_state.id) - curr_tool_names = [t.name for t in curr_tools] - assert len(curr_tool_names) == curr_num_tools + 1 - assert github_tool.name in curr_tool_names - - # remove only the github tool - agent_state = client.detach_tool(agent_id=agent_state.id, tool_id=github_tool.id) - - # confirm that only one tool left - curr_tools = client.get_tools_from_agent(agent_state.id) - curr_tool_names = [t.name for t in curr_tools] - assert len(curr_tool_names) == curr_num_tools - assert github_tool.name not in curr_tool_names - - -def test_agent_with_shared_blocks(client: LocalClient): - persona_block = client.create_block(template_name="persona", value="Here to test things!", label="persona") - human_block = client.create_block(template_name="human", value="Me Human, I swear. Beep boop.", label="human") - existing_non_template_blocks = [persona_block, human_block] - - existing_non_template_blocks_no_values = [] - for block in existing_non_template_blocks: - block_copy = block.copy() - block_copy.value = "" - existing_non_template_blocks_no_values.append(block_copy) - - # create agent - first_agent_state_test = None - second_agent_state_test = None - try: - first_agent_state_test = client.create_agent( - name="first_test_agent_shared_memory_blocks", - memory=BasicBlockMemory(blocks=existing_non_template_blocks), - description="This is a test agent using shared memory blocks", - ) - assert isinstance(first_agent_state_test.memory, Memory) - - # when this agent is created with the shared block references this agent's in-memory blocks should - # have this latest value set by the other agent. - second_agent_state_test = client.create_agent( - name="second_test_agent_shared_memory_blocks", - memory=BasicBlockMemory(blocks=existing_non_template_blocks_no_values), - description="This is a test agent using shared memory blocks", - ) - - first_memory = first_agent_state_test.memory - assert persona_block.id == first_memory.get_block("persona").id - assert human_block.id == first_memory.get_block("human").id - client.update_agent_memory_block(first_agent_state_test.id, label="human", value="I'm an analyst therapist.") - print("Updated human block value:", client.get_agent_memory_block(first_agent_state_test.id, label="human").value) - - # refresh agent state - second_agent_state_test = client.get_agent(second_agent_state_test.id) - - assert isinstance(second_agent_state_test.memory, Memory) - second_memory = second_agent_state_test.memory - assert persona_block.id == second_memory.get_block("persona").id - assert human_block.id == second_memory.get_block("human").id - # assert second_blocks_dict.get("human", {}).get("value") == "I'm an analyst therapist." - assert second_memory.get_block("human").value == "I'm an analyst therapist." - - finally: - if first_agent_state_test: - client.delete_agent(first_agent_state_test.id) - if second_agent_state_test: - client.delete_agent(second_agent_state_test.id) - - -def test_memory(client: LocalClient, agent: AgentState): - # get agent memory - original_memory = client.get_in_context_memory(agent.id) - assert original_memory is not None - original_memory_value = str(original_memory.get_block("human").value) - - # update core memory - updated_memory = client.update_in_context_memory(agent.id, section="human", value="I am a human") - - # get memory - assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated - - -def test_archival_memory(client: LocalClient, agent: AgentState): - """Test functions for interacting with archival memory store""" - - # add archival memory - memory_str = "I love chats" - passage = client.insert_archival_memory(agent.id, memory=memory_str)[0] - - # list archival memory - passages = client.get_archival_memory(agent.id) - assert passage.text in [p.text for p in passages], f"Missing passage {passage.text} in {passages}" - - # delete archival memory - client.delete_archival_memory(agent.id, passage.id) - - -def test_recall_memory(client: LocalClient, agent: AgentState): - """Test functions for interacting with recall memory store""" - - # send message to the agent - message_str = "Hello" - client.send_message(message=message_str, role="user", agent_id=agent.id) - - # list messages - messages = client.get_messages(agent.id) - exists = False - for m in messages: - if message_str in str(m): - exists = True - assert exists - - # get in-context messages - in_context_messages = client.get_in_context_messages(agent.id) - exists = False - for m in in_context_messages: - if message_str in m.content[0].text: - exists = True - assert exists - - -def test_tools(client: LocalClient): - def print_tool(message: str): - """ - A tool to print a message - - Args: - message (str): The message to print. - - Returns: - str: The message that was printed. - - """ - print(message) - return message - - def print_tool2(msg: str): - """ - Another tool to print a message - - Args: - msg (str): The message to print. - """ - print(msg) - - # Clean all tools first - for tool in client.list_tools(): - client.delete_tool(tool.id) - - # create tool - tool = client.create_or_update_tool(func=print_tool, tags=["extras"]) - - # list tools - tools = client.list_tools() - assert tool.name in [t.name for t in tools] - - # get tool id - assert tool.id == client.get_tool_id(name="print_tool") - - # update tool: extras - extras2 = ["extras2"] - client.update_tool(tool.id, tags=extras2) - assert client.get_tool(tool.id).tags == extras2 - - # update tool: source code - client.update_tool(tool.id, func=print_tool2) - assert client.get_tool(tool.id).name == "print_tool2" - - -def test_tools_from_composio_basic(client: LocalClient): - from composio import Action - - # Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) - client = create_client() - - # create tool - tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - - # list tools - tools = client.list_tools() - assert tool.name in [t.name for t in tools] - - # We end the test here as composio requires login to use the tools - # The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable - - -# TODO: Langchain seems to have issues with Pydantic -# TODO: Langchain tools are breaking every two weeks bc of changes on their side -# def test_tools_from_langchain(client: LocalClient): -# # create langchain tool -# from langchain_community.tools import WikipediaQueryRun -# from langchain_community.utilities import WikipediaAPIWrapper -# -# langchain_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()) -# -# # Add the tool -# tool = client.load_langchain_tool( -# langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"} -# ) -# -# # list tools -# tools = client.list_tools() -# assert tool.name in [t.name for t in tools] -# -# # get tool -# tool_id = client.get_tool_id(name=tool.name) -# retrieved_tool = client.get_tool(tool_id) -# source_code = retrieved_tool.source_code -# -# # Parse the function and attempt to use it -# local_scope = {} -# exec(source_code, {}, local_scope) -# func = local_scope[tool.name] -# -# expected_content = "Albert Einstein" -# assert expected_content in func(query="Albert Einstein") -# -# -# def test_tool_creation_langchain_missing_imports(client: LocalClient): -# # create langchain tool -# from langchain_community.tools import WikipediaQueryRun -# from langchain_community.utilities import WikipediaAPIWrapper -# -# api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100) -# langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper) -# -# # Translate to memGPT Tool -# # Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"} -# with pytest.raises(RuntimeError): -# ToolCreate.from_langchain(langchain_tool) - - -def test_shared_blocks_without_send_message(client: LocalClient): - from letta import BasicBlockMemory - from letta.client.client import Block, create_client - from letta.schemas.agent import AgentType - from letta.schemas.embedding_config import EmbeddingConfig - from letta.schemas.llm_config import LLMConfig - - client = create_client() - shared_memory_block = Block(name="shared_memory", label="shared_memory", value="[empty]", limit=2000) - memory = BasicBlockMemory(blocks=[shared_memory_block]) - - agent_1 = client.create_agent( - agent_type=AgentType.memgpt_agent, - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), - memory=memory, - ) - - agent_2 = client.create_agent( - agent_type=AgentType.memgpt_agent, - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), - memory=memory, - ) - - block_id = agent_1.memory.get_block("shared_memory").id - client.update_block(block_id, value="I am no longer an [empty] memory") - agent_1 = client.get_agent(agent_1.id) - agent_2 = client.get_agent(agent_2.id) - assert agent_1.memory.get_block("shared_memory").value == "I am no longer an [empty] memory" - assert agent_2.memory.get_block("shared_memory").value == "I am no longer an [empty] memory" diff --git a/tests/test_model_letta_performance.py b/tests/test_model_letta_performance.py deleted file mode 100644 index 41f2da64..00000000 --- a/tests/test_model_letta_performance.py +++ /dev/null @@ -1,439 +0,0 @@ -import os - -import pytest - -from tests.helpers.endpoints_helper import ( - check_agent_archival_memory_insert, - check_agent_archival_memory_retrieval, - check_agent_edit_core_memory, - check_agent_recall_chat_memory, - check_agent_uses_external_tool, - check_first_response_is_valid_for_llm_endpoint, - run_embedding_endpoint, -) -from tests.helpers.utils import retry_until_success, retry_until_threshold - -# directories -embedding_config_dir = "tests/configs/embedding_model_configs" -llm_config_dir = "tests/configs/llm_model_configs" - - -# ====================================================================================================================== -# OPENAI TESTS -# ====================================================================================================================== -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_archival_memory_insert(): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_agent_archival_memory_insert(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_edit_core_memory(): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_embedding_endpoint_openai(): - filename = os.path.join(embedding_config_dir, "openai_embed.json") - run_embedding_endpoint(filename) - - -# ====================================================================================================================== -# AZURE TESTS -# ====================================================================================================================== -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_gpt_4o_mini_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_gpt_4o_mini_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_gpt_4o_mini_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_gpt_4o_mini_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_gpt_4o_mini_edit_core_memory(): - filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_embedding_endpoint(): - filename = os.path.join(embedding_config_dir, "azure_embed.json") - run_embedding_endpoint(filename) - - -# ====================================================================================================================== -# LETTA HOSTED -# ====================================================================================================================== -def test_llm_endpoint_letta_hosted(): - filename = os.path.join(llm_config_dir, "letta-hosted.json") - check_first_response_is_valid_for_llm_endpoint(filename) - - -def test_embedding_endpoint_letta_hosted(): - filename = os.path.join(embedding_config_dir, "letta-hosted.json") - run_embedding_endpoint(filename) - - -# ====================================================================================================================== -# LOCAL MODELS -# ====================================================================================================================== -def test_embedding_endpoint_local(): - filename = os.path.join(embedding_config_dir, "local.json") - run_embedding_endpoint(filename) - - -def test_llm_endpoint_ollama(): - filename = os.path.join(llm_config_dir, "ollama.json") - check_first_response_is_valid_for_llm_endpoint(filename) - - -def test_embedding_endpoint_ollama(): - filename = os.path.join(embedding_config_dir, "ollama.json") - run_embedding_endpoint(filename) - - -# ====================================================================================================================== -# ANTHROPIC TESTS -# ====================================================================================================================== -@pytest.mark.anthropic_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_claude_haiku_3_5_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_claude_haiku_3_5_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_claude_haiku_3_5_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_claude_haiku_3_5_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_claude_haiku_3_5_edit_core_memory(): - filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# GROQ TESTS -# ====================================================================================================================== -def test_groq_llama31_70b_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "groq.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_groq_llama31_70b_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "groq.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_groq_llama31_70b_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "groq.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@retry_until_threshold(threshold=0.75, max_attempts=4) -def test_groq_llama31_70b_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "groq.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_groq_llama31_70b_edit_core_memory(): - filename = os.path.join(llm_config_dir, "groq.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# GEMINI TESTS -# ====================================================================================================================== -@pytest.mark.gemini_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_gemini_pro_15_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "gemini-pro.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.gemini_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_gemini_pro_15_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "gemini-pro.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.gemini_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_gemini_pro_15_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "gemini-pro.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.gemini_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_gemini_pro_15_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "gemini-pro.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.gemini_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_gemini_pro_15_edit_core_memory(): - filename = os.path.join(llm_config_dir, "gemini-pro.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# GOOGLE VERTEX TESTS -# ====================================================================================================================== -@pytest.mark.vertex_basic -@retry_until_success(max_attempts=1, sleep_time_seconds=2) -def test_vertex_gemini_pro_20_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "gemini-vertex.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# DEEPSEEK TESTS -# ====================================================================================================================== -@pytest.mark.deepseek_basic -def test_deepseek_reasoner_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "deepseek-reasoner.json") - # Don't validate that the inner monologue doesn't contain things like "function", since - # for the reasoners it might be quite meta (have analysis about functions etc.) - response = check_first_response_is_valid_for_llm_endpoint(filename, validate_inner_monologue_contents=False) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# xAI TESTS -# ====================================================================================================================== -@pytest.mark.xai_basic -def test_xai_grok2_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "xai-grok-2.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# TOGETHER TESTS -# ====================================================================================================================== -def test_together_llama_3_70b_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "together-llama-3-70b.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_together_llama_3_70b_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "together-llama-3-70b.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_together_llama_3_70b_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "together-llama-3-70b.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_together_llama_3_70b_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "together-llama-3-70b.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_together_llama_3_70b_edit_core_memory(): - filename = os.path.join(llm_config_dir, "together-llama-3-70b.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# ANTHROPIC BEDROCK TESTS -# ====================================================================================================================== -@pytest.mark.anthropic_bedrock_basic -def test_bedrock_claude_sonnet_3_5_valid_config(): - import json - - from letta.schemas.llm_config import LLMConfig - from letta.settings import model_settings - - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - config_data = json.load(open(filename, "r")) - llm_config = LLMConfig(**config_data) - model_region = llm_config.model.split(":")[3] - assert model_settings.aws_region == model_region, "Model region in config file does not match model region in ModelSettings" - - -@pytest.mark.anthropic_bedrock_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_bedrock_claude_sonnet_3_5_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_bedrock_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_bedrock_claude_sonnet_3_5_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_bedrock_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_bedrock_claude_sonnet_3_5_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_bedrock_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_bedrock_claude_sonnet_3_5_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_bedrock_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_bedrock_claude_sonnet_3_5_edit_core_memory(): - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") diff --git a/tests/test_server.py b/tests/test_server.py index 200ff54e..4ad80422 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -270,8 +270,6 @@ start my apprenticeship as old maid.""" @pytest.fixture(scope="module") def server(): config = LettaConfig.load() - print("CONFIG PATH", config.config_path) - config.save() server = SyncServer() diff --git a/tests/test_streaming.py b/tests/test_streaming.py deleted file mode 100644 index d9a7a7f1..00000000 --- a/tests/test_streaming.py +++ /dev/null @@ -1,132 +0,0 @@ -import os -import threading -import time - -import pytest -from dotenv import load_dotenv -from letta_client import AgentState, Letta, LlmConfig, MessageCreate - -from letta.schemas.message import Message - - -def run_server(): - load_dotenv() - - from letta.server.rest_api.app import start_server - - print("Starting server...") - start_server(debug=True) - - -@pytest.fixture( - scope="module", -) -def client(request): - # Get URL from environment or start server - api_url = os.getenv("LETTA_API_URL") - server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283") - if not os.getenv("LETTA_SERVER_URL"): - print("Starting server thread") - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - time.sleep(5) - print("Running client tests with server:", server_url) - - # Overide the base_url if the LETTA_API_URL is set - base_url = api_url if api_url else server_url - # create the Letta client - yield Letta(base_url=base_url, token=None) - - -# Fixture for test agent -@pytest.fixture(scope="module") -def agent(client: Letta): - agent_state = client.agents.create( - name="test_client", - memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], - model="letta/letta-free", - embedding="letta/letta-free", - ) - - yield agent_state - - # delete agent - client.agents.delete(agent_state.id) - - -@pytest.mark.parametrize( - "stream_tokens,model", - [ - (True, "openai/gpt-4o-mini"), - (True, "anthropic/claude-3-sonnet-20240229"), - (False, "openai/gpt-4o-mini"), - (False, "anthropic/claude-3-sonnet-20240229"), - ], -) -def test_streaming_send_message( - disable_e2b_api_key, - client: Letta, - agent: AgentState, - stream_tokens: bool, - model: str, -): - # Update agent's model - config = client.agents.retrieve(agent_id=agent.id).llm_config - config_dump = config.model_dump() - config_dump["model"] = model - config = LlmConfig(**config_dump) - client.agents.modify(agent_id=agent.id, llm_config=config) - - # Send streaming message - user_message_otid = Message.generate_otid() - response = client.agents.messages.create_stream( - agent_id=agent.id, - messages=[ - MessageCreate( - role="user", - content="This is a test. Repeat after me: 'banana'", - otid=user_message_otid, - ), - ], - stream_tokens=stream_tokens, - ) - - # Tracking variables for test validation - inner_thoughts_exist = False - inner_thoughts_count = 0 - send_message_ran = False - done = False - last_message_id = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id - letta_message_otids = [user_message_otid] - - assert response, "Sending message failed" - for chunk in response: - # Check chunk type and content based on the current client API - if hasattr(chunk, "message_type") and chunk.message_type == "reasoning_message": - inner_thoughts_exist = True - inner_thoughts_count += 1 - - if chunk.message_type == "tool_call_message" and hasattr(chunk, "tool_call") and chunk.tool_call.name == "send_message": - send_message_ran = True - if chunk.message_type == "assistant_message": - send_message_ran = True - - if chunk.message_type == "usage_statistics": - # Validate usage statistics - assert chunk.step_count == 1 - assert chunk.completion_tokens > 10 - assert chunk.prompt_tokens > 1000 - assert chunk.total_tokens > 1000 - done = True - else: - letta_message_otids.append(chunk.otid) - print(chunk) - - # If stream tokens, we expect at least one inner thought - assert inner_thoughts_count >= 1, "Expected more than one inner thought" - assert inner_thoughts_exist, "No inner thoughts found" - assert send_message_ran, "send_message function call not found" - assert done, "Message stream not done" - - messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) - assert [message.otid for message in messages] == letta_message_otids diff --git a/tests/test_system_prompt_compiler.py b/tests/test_system_prompt_compiler.py deleted file mode 100644 index d7423603..00000000 --- a/tests/test_system_prompt_compiler.py +++ /dev/null @@ -1,59 +0,0 @@ -from letta.services.helpers.agent_manager_helper import safe_format - -CORE_MEMORY_VAR = "My core memory is that I like to eat bananas" -VARS_DICT = {"CORE_MEMORY": CORE_MEMORY_VAR} - - -def test_formatter(): - - # Example system prompt that has no vars - NO_VARS = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - """ - - assert NO_VARS == safe_format(NO_VARS, VARS_DICT) - - # Example system prompt that has {CORE_MEMORY} - CORE_MEMORY_VAR = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - {CORE_MEMORY} - """ - - CORE_MEMORY_VAR_SOL = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - My core memory is that I like to eat bananas - """ - - assert CORE_MEMORY_VAR_SOL == safe_format(CORE_MEMORY_VAR, VARS_DICT) - - # Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist) - UNUSED_VAR = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - {USER_MEMORY} - {CORE_MEMORY} - """ - - UNUSED_VAR_SOL = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - {USER_MEMORY} - My core memory is that I like to eat bananas - """ - - assert UNUSED_VAR_SOL == safe_format(UNUSED_VAR, VARS_DICT) - - # Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist), AND an empty {} - UNUSED_AND_EMPRY_VAR = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - {} - {USER_MEMORY} - {CORE_MEMORY} - """ - - UNUSED_AND_EMPRY_VAR_SOL = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - {} - {USER_MEMORY} - My core memory is that I like to eat bananas - """ - - assert UNUSED_AND_EMPRY_VAR_SOL == safe_format(UNUSED_AND_EMPRY_VAR, VARS_DICT) diff --git a/tests/test_utils.py b/tests/test_utils.py index 904e903e..214dfcbb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,282 @@ import pytest from letta.constants import MAX_FILENAME_LENGTH +from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source +from letta.services.helpers.agent_manager_helper import safe_format from letta.utils import sanitize_filename +CORE_MEMORY_VAR = "My core memory is that I like to eat bananas" +VARS_DICT = {"CORE_MEMORY": CORE_MEMORY_VAR} + +# ----------------------------------------------------------------------- +# Example source code for testing multiple scenarios, including: +# 1) A class-based custom type (which we won't handle properly). +# 2) Functions with multiple argument types. +# 3) A function with default arguments. +# 4) A function with no arguments. +# 5) A function that shares the same name as another symbol. +# ----------------------------------------------------------------------- +example_source_code = r""" +class CustomClass: + def __init__(self, x): + self.x = x + +def unrelated_symbol(): + pass + +def no_args_func(): + pass + +def default_args_func(x: int = 5, y: str = "hello"): + return x, y + +def my_function(a: int, b: float, c: str, d: list, e: dict, f: CustomClass = None): + pass + +def my_function_duplicate(): + # This function shares the name "my_function" partially, but isn't an exact match + pass +""" + + +def test_get_function_annotations_found(): + """ + Test that we correctly parse annotations for a function + that includes multiple argument types and a custom class. + """ + annotations = get_function_annotations_from_source(example_source_code, "my_function") + assert annotations == { + "a": "int", + "b": "float", + "c": "str", + "d": "list", + "e": "dict", + "f": "CustomClass", + } + + +def test_get_function_annotations_not_found(): + """ + If the requested function name doesn't exist exactly, + we should raise a ValueError. + """ + with pytest.raises(ValueError, match="Function 'missing_function' not found"): + get_function_annotations_from_source(example_source_code, "missing_function") + + +def test_get_function_annotations_no_args(): + """ + Check that a function without arguments returns an empty annotations dict. + """ + annotations = get_function_annotations_from_source(example_source_code, "no_args_func") + assert annotations == {} + + +def test_get_function_annotations_with_default_values(): + """ + Ensure that a function with default arguments still captures the annotations. + """ + annotations = get_function_annotations_from_source(example_source_code, "default_args_func") + assert annotations == {"x": "int", "y": "str"} + + +def test_get_function_annotations_partial_name_collision(): + """ + Ensure we only match the exact function name, not partial collisions. + """ + # This will match 'my_function' exactly, ignoring 'my_function_duplicate' + annotations = get_function_annotations_from_source(example_source_code, "my_function") + assert "a" in annotations # Means it matched the correct function + # No error expected here, just making sure we didn't accidentally parse "my_function_duplicate". + + +# --------------------- coerce_dict_args_by_annotations TESTS --------------------- # + + +def test_coerce_dict_args_success(): + """ + Basic success scenario with standard types: + int, float, str, list, dict. + """ + annotations = {"a": "int", "b": "float", "c": "str", "d": "list", "e": "dict"} + function_args = {"a": "42", "b": "3.14", "c": 123, "d": "[1, 2, 3]", "e": '{"key": "value"}'} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert coerced_args["b"] == 3.14 + assert coerced_args["c"] == "123" + assert coerced_args["d"] == [1, 2, 3] + assert coerced_args["e"] == {"key": "value"} + + +def test_coerce_dict_args_invalid_type(): + """ + If the value cannot be coerced into the annotation, + a ValueError should be raised. + """ + annotations = {"a": "int"} + function_args = {"a": "invalid_int"} + + with pytest.raises(ValueError, match="Failed to coerce argument 'a' to int"): + coerce_dict_args_by_annotations(function_args, annotations) + + +def test_coerce_dict_args_no_annotations(): + """ + If there are no annotations, we do no coercion. + """ + annotations = {} + function_args = {"a": 42, "b": "hello"} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args == function_args # Exactly the same dict back + + +def test_coerce_dict_args_partial_annotations(): + """ + Only coerce annotated arguments; leave unannotated ones unchanged. + """ + annotations = {"a": "int"} + function_args = {"a": "42", "b": "no_annotation"} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert coerced_args["b"] == "no_annotation" + + +def test_coerce_dict_args_with_missing_args(): + """ + If function_args lacks some keys listed in annotations, + those are simply not coerced. (We do not add them.) + """ + annotations = {"a": "int", "b": "float"} + function_args = {"a": "42"} # Missing 'b' + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert "b" not in coerced_args + + +def test_coerce_dict_args_unexpected_keys(): + """ + If function_args has extra keys not in annotations, + we leave them alone. + """ + annotations = {"a": "int"} + function_args = {"a": "42", "z": 999} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert coerced_args["z"] == 999 # unchanged + + +def test_coerce_dict_args_unsupported_custom_class(): + """ + If someone tries to pass an annotation that isn't supported (like a custom class), + we should raise a ValueError (or similarly handle the error) rather than silently + accept it. + """ + annotations = {"f": "CustomClass"} # We can't resolve this + function_args = {"f": {"x": 1}} + with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass: Unsupported annotation: CustomClass"): + coerce_dict_args_by_annotations(function_args, annotations) + + +def test_coerce_dict_args_with_complex_types(): + """ + Confirm the ability to parse built-in complex data (lists, dicts, etc.) + when given as strings. + """ + annotations = {"big_list": "list", "nested_dict": "dict"} + function_args = {"big_list": "[1, 2, [3, 4], {'five': 5}]", "nested_dict": '{"alpha": [10, 20], "beta": {"x": 1, "y": 2}}'} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["big_list"] == [1, 2, [3, 4], {"five": 5}] + assert coerced_args["nested_dict"] == { + "alpha": [10, 20], + "beta": {"x": 1, "y": 2}, + } + + +def test_coerce_dict_args_non_string_keys(): + """ + Validate behavior if `function_args` includes non-string keys. + (We should simply skip annotation checks for them.) + """ + annotations = {"a": "int"} + function_args = {123: "42", "a": "42"} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + # 'a' is coerced to int + assert coerced_args["a"] == 42 + # 123 remains untouched + assert coerced_args[123] == "42" + + +def test_coerce_dict_args_non_parseable_list_or_dict(): + """ + Test passing incorrectly formatted JSON for a 'list' or 'dict' annotation. + """ + annotations = {"bad_list": "list", "bad_dict": "dict"} + function_args = {"bad_list": "[1, 2, 3", "bad_dict": '{"key": "value"'} # missing brackets + + with pytest.raises(ValueError, match="Failed to coerce argument 'bad_list' to list"): + coerce_dict_args_by_annotations(function_args, annotations) + + +def test_coerce_dict_args_with_complex_list_annotation(): + """ + Test coercion when list with type annotation (e.g., list[int]) is used. + """ + annotations = {"a": "list[int]"} + function_args = {"a": "[1, 2, 3]"} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == [1, 2, 3] + + +def test_coerce_dict_args_with_complex_dict_annotation(): + """ + Test coercion when dict with type annotation (e.g., dict[str, int]) is used. + """ + annotations = {"a": "dict[str, int]"} + function_args = {"a": '{"x": 1, "y": 2}'} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == {"x": 1, "y": 2} + + +def test_coerce_dict_args_unsupported_complex_annotation(): + """ + If an unsupported complex annotation is used (e.g., a custom class), + a ValueError should be raised. + """ + annotations = {"f": "CustomClass[int]"} + function_args = {"f": "CustomClass(42)"} + + with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass\[int\]: Unsupported annotation: CustomClass\[int\]"): + coerce_dict_args_by_annotations(function_args, annotations) + + +def test_coerce_dict_args_with_nested_complex_annotation(): + """ + Test coercion with complex nested types like list[dict[str, int]]. + """ + annotations = {"a": "list[dict[str, int]]"} + function_args = {"a": '[{"x": 1}, {"y": 2}]'} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == [{"x": 1}, {"y": 2}] + + +def test_coerce_dict_args_with_default_arguments(): + """ + Test coercion with default arguments, where some arguments have defaults in the source code. + """ + annotations = {"a": "int", "b": "str"} + function_args = {"a": "42"} + + function_args.setdefault("b", "hello") # Setting the default value for 'b' + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert coerced_args["b"] == "hello" + def test_valid_filename(): filename = "valid_filename.txt" @@ -64,3 +338,58 @@ def test_unique_filenames(): assert sanitized2.startswith("duplicate_") assert sanitized1.endswith(".txt") assert sanitized2.endswith(".txt") + + +def test_formatter(): + + # Example system prompt that has no vars + NO_VARS = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + """ + + assert NO_VARS == safe_format(NO_VARS, VARS_DICT) + + # Example system prompt that has {CORE_MEMORY} + CORE_MEMORY_VAR = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + {CORE_MEMORY} + """ + + CORE_MEMORY_VAR_SOL = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + My core memory is that I like to eat bananas + """ + + assert CORE_MEMORY_VAR_SOL == safe_format(CORE_MEMORY_VAR, VARS_DICT) + + # Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist) + UNUSED_VAR = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + {USER_MEMORY} + {CORE_MEMORY} + """ + + UNUSED_VAR_SOL = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + {USER_MEMORY} + My core memory is that I like to eat bananas + """ + + assert UNUSED_VAR_SOL == safe_format(UNUSED_VAR, VARS_DICT) + + # Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist), AND an empty {} + UNUSED_AND_EMPRY_VAR = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + {} + {USER_MEMORY} + {CORE_MEMORY} + """ + + UNUSED_AND_EMPRY_VAR_SOL = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + {} + {USER_MEMORY} + My core memory is that I like to eat bananas + """ + + assert UNUSED_AND_EMPRY_VAR_SOL == safe_format(UNUSED_AND_EMPRY_VAR, VARS_DICT)