From e5ff06685cd7fdd49ad04827bcb12680b1797552 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 15 Oct 2024 15:50:47 -0700 Subject: [PATCH] feat: add support for agent "swarm" (multi-agent) (#1878) --- examples/swarm/simple.py | 72 ++++++++++++++++++ examples/swarm/swarm.py | 111 ++++++++++++++++++++++++++++ letta/client/client.py | 66 ++++++++++++++--- letta/constants.py | 2 +- letta/functions/functions.py | 2 +- letta/functions/schema_generator.py | 5 +- letta/schemas/tool.py | 1 + letta/server/server.py | 4 +- tests/test_schema_generator.py | 6 +- 9 files changed, 249 insertions(+), 20 deletions(-) create mode 100644 examples/swarm/simple.py create mode 100644 examples/swarm/swarm.py diff --git a/examples/swarm/simple.py b/examples/swarm/simple.py new file mode 100644 index 00000000..e5595dd5 --- /dev/null +++ b/examples/swarm/simple.py @@ -0,0 +1,72 @@ +import typer +from swarm import Swarm + +from letta import EmbeddingConfig, LLMConfig + +""" +This is an example of how to implement the basic example provided by OpenAI for tranferring a conversation between two agents: +https://github.com/openai/swarm/tree/main?tab=readme-ov-file#usage + +Before running this example, make sure you have letta>=0.5.0 installed. This example also runs with OpenAI, though you can also change the model by modifying the code: +```bash +export OPENAI_API_KEY=... +pip install letta +```` +Then, instead the `examples/swarm` directory, run: +```bash +python simple.py +``` +You should see a message output from Agent B. + +""" + + +def transfer_agent_b(self): + """ + Transfer conversation to agent B. + + Returns: + str: name of agent to transfer to + """ + return "agentb" + + +def transfer_agent_a(self): + """ + Transfer conversation to agent A. + + Returns: + str: name of agent to transfer to + """ + return "agenta" + + +swarm = Swarm() + +# set client configs +swarm.client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) +swarm.client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4")) + +# create tools +transfer_a = swarm.client.create_tool(transfer_agent_a, terminal=True) +transfer_b = swarm.client.create_tool(transfer_agent_b, terminal=True) + +# create agents +if swarm.client.get_agent_id("agentb"): + swarm.client.delete_agent(swarm.client.get_agent_id("agentb")) +if swarm.client.get_agent_id("agenta"): + swarm.client.delete_agent(swarm.client.get_agent_id("agenta")) +agent_a = swarm.create_agent(name="agentb", tools=[transfer_a.name], instructions="Only speak in haikus") +agent_b = swarm.create_agent(name="agenta", tools=[transfer_b.name]) + +response = swarm.run(agent_name="agenta", message="Transfer me to agent b by calling the transfer_agent_b tool") +print("Response:") +typer.secho(f"{response}", fg=typer.colors.GREEN) + +response = swarm.run(agent_name="agenta", message="My name is actually Sarah. Transfer me to agent b to write a haiku about my name") +print("Response:") +typer.secho(f"{response}", fg=typer.colors.GREEN) + +response = swarm.run(agent_name="agenta", message="Transfer me to agent b - I want a haiku with my name in it") +print("Response:") +typer.secho(f"{response}", fg=typer.colors.GREEN) diff --git a/examples/swarm/swarm.py b/examples/swarm/swarm.py new file mode 100644 index 00000000..e810fc81 --- /dev/null +++ b/examples/swarm/swarm.py @@ -0,0 +1,111 @@ +import json +from typing import List, Optional + +import typer + +from letta import AgentState, EmbeddingConfig, LLMConfig, create_client +from letta.schemas.agent import AgentType +from letta.schemas.memory import BasicBlockMemory, Block + + +class Swarm: + + def __init__(self): + self.agents = [] + self.client = create_client() + + # shared memory block (shared section of context window accross agents) + self.shared_memory = Block(name="human", label="human", value="") + + def create_agent( + self, + name: Optional[str] = None, + # agent config + agent_type: Optional[AgentType] = AgentType.memgpt_agent, + # model configs + embedding_config: EmbeddingConfig = None, + llm_config: LLMConfig = None, + # system + system: Optional[str] = None, + # tools + tools: Optional[List[str]] = None, + include_base_tools: Optional[bool] = True, + # instructions + instructions: str = "", + ) -> AgentState: + + # todo: process tools for agent handoff + persona_value = ( + f"You are agent with name {name}. You instructions are {instructions}" + if len(instructions) > 0 + else f"You are agent with name {name}" + ) + persona_block = Block(name="persona", label="persona", value=persona_value) + memory = BasicBlockMemory(blocks=[persona_block, self.shared_memory]) + + agent = self.client.create_agent( + name=name, + agent_type=agent_type, + embedding_config=embedding_config, + llm_config=llm_config, + system=system, + tools=tools, + include_base_tools=include_base_tools, + memory=memory, + ) + self.agents.append(agent) + + return agent + + def reset(self): + # delete all agents + for agent in self.agents: + self.client.delete_agent(agent.id) + for block in self.client.list_blocks(): + self.client.delete_block(block.id) + + def run(self, agent_name: str, message: str): + + history = [] + while True: + # send message to agent + agent_id = self.client.get_agent_id(agent_name) + + print("Messaging agent: ", agent_name) + print("History size: ", len(history)) + # print(self.client.get_agent(agent_id).tools) + # TODO: implement with sending multiple messages + if len(history) == 0: + response = self.client.send_message(agent_id=agent_id, message=message, role="user", include_full_message=True) + else: + response = self.client.send_messages(agent_id=agent_id, messages=history, include_full_message=True) + + # update history + history += response.messages + + # grab responses + messages = [] + for message in response.messages: + messages += message.to_letta_message() + + # get new agent (see tool call) + # print(messages) + + if len(messages) < 2: + continue + + function_call = messages[-2] + function_return = messages[-1] + if function_call.function_call.name == "send_message": + # return message to use + arg_data = json.loads(function_call.function_call.arguments) + # print(arg_data) + return arg_data["message"] + else: + # swap the agent + return_data = json.loads(function_return.function_return) + agent_name = return_data["message"] + typer.secho(f"Transferring to agent: {agent_name}", fg=typer.colors.RED) + # print("Transferring to agent", agent_name) + + print() diff --git a/letta/client/client.py b/letta/client/client.py index ee7db335..f39a52fa 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1772,6 +1772,40 @@ class LocalClient(AbstractClient): # agent interactions + def send_messages( + self, + agent_id: str, + messages: List[Union[Message | MessageCreate]], + include_full_message: Optional[bool] = False, + ): + """ + Send pre-packed messages to an agent. + + Args: + agent_id (str): ID of the agent + messages (List[Union[Message | MessageCreate]]): List of messages to send + + Returns: + response (LettaResponse): Response from the agent + """ + self.interface.clear() + usage = self.server.send_messages(user_id=self.user_id, agent_id=agent_id, messages=messages) + + # auto-save + if self.auto_save: + self.save() + + # format messages + messages = self.interface.to_list() + if include_full_message: + letta_messages = messages + else: + letta_messages = [] + for m in messages: + letta_messages += m.to_letta_message() + + return LettaResponse(messages=letta_messages, usage=usage) + def send_message( self, message: str, @@ -1817,18 +1851,19 @@ class LocalClient(AbstractClient): if self.auto_save: self.save() - # TODO: need to make sure date/timestamp is propely passed - # TODO: update self.interface.to_list() to return actual Message objects - # here, the message objects will have faulty created_by timestamps - messages = self.interface.to_list() - for m in messages: - assert isinstance(m, Message), f"Expected Message object, got {type(m)}" - letta_messages = [] - for m in messages: - letta_messages += m.to_letta_message() - return LettaResponse(messages=letta_messages, usage=usage) + ## TODO: need to make sure date/timestamp is propely passed + ## TODO: update self.interface.to_list() to return actual Message objects + ## here, the message objects will have faulty created_by timestamps + # messages = self.interface.to_list() + # for m in messages: + # assert isinstance(m, Message), f"Expected Message object, got {type(m)}" + # letta_messages = [] + # for m in messages: + # letta_messages += m.to_letta_message() + # return LettaResponse(messages=letta_messages, usage=usage) # format messages + messages = self.interface.to_list() if include_full_message: letta_messages = messages else: @@ -1881,6 +1916,13 @@ class LocalClient(AbstractClient): # humans / personas + def get_block_id(self, name: str, label: str) -> str: + + block = self.server.get_blocks(name=name, label=label, user_id=self.user_id, template=True) + if not block: + return None + return block[0].id + def create_human(self, name: str, text: str): """ Create a human block template (saved human string to pre-fill `ChatMemory`) @@ -2071,6 +2113,7 @@ class LocalClient(AbstractClient): name: Optional[str] = None, update: Optional[bool] = True, # TODO: actually use this tags: Optional[List[str]] = None, + terminal: Optional[bool] = False, ) -> Tool: """ Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent. @@ -2080,6 +2123,7 @@ class LocalClient(AbstractClient): name: (str): Name of the tool (must be unique per-user.) tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. update (bool, optional): Update the tool if it already exists. Defaults to True. + terminal (bool, optional): Whether the tool is a terminal tool (no more agent steps). Defaults to False. Returns: tool (Tool): The created tool. @@ -2095,7 +2139,7 @@ class LocalClient(AbstractClient): # call server function return self.server.create_tool( # ToolCreate(source_type=source_type, source_code=source_code, name=tool_name, json_schema=json_schema, tags=tags), - ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags), + ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags, terminal=terminal), user_id=self.user_id, update=update, ) diff --git a/letta/constants.py b/letta/constants.py index e8fac679..d7670eb7 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -39,7 +39,7 @@ DEFAULT_PRESET = "memgpt_chat" # Tools BASE_TOOLS = [ "send_message", - "pause_heartbeats", + # "pause_heartbeats", "conversation_search", "conversation_search_date", "archival_memory_insert", diff --git a/letta/functions/functions.py b/letta/functions/functions.py index 43b7f17e..b06d1402 100644 --- a/letta/functions/functions.py +++ b/letta/functions/functions.py @@ -27,7 +27,7 @@ def load_function_set(module: ModuleType) -> dict: if attr_name in function_dict: raise ValueError(f"Found a duplicate of function name '{attr_name}'") - generated_schema = generate_schema(attr) + generated_schema = generate_schema(attr, terminal=False) function_dict[attr_name] = { "module": inspect.getsource(module), "python_function": attr, diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index d30dfa27..fbd65b97 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -74,7 +74,7 @@ def pydantic_model_to_open_ai(model): } -def generate_schema(function, name: Optional[str] = None, description: Optional[str] = None): +def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None): # Get the signature of the function sig = inspect.signature(function) @@ -127,7 +127,8 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[ schema["parameters"]["required"].append(param.name) # append the heartbeat - if function.__name__ not in ["send_message", "pause_heartbeats"]: + # TODO: don't hard-code + if function.__name__ not in ["send_message", "pause_heartbeats"] and not terminal: schema["parameters"]["properties"]["request_heartbeat"] = { "type": "boolean", "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.", diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index f9a4c92f..10faec4c 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -182,6 +182,7 @@ class ToolCreate(BaseTool): json_schema: Optional[Dict] = Field( None, description="The JSON schema of the function (auto-generated from source_code if not provided)" ) + terminal: Optional[bool] = Field(None, description="Whether the tool is a terminal tool (allow requesting heartbeats).") class ToolUpdate(ToolCreate): diff --git a/letta/server/server.py b/letta/server/server.py index eb32c4c9..90b68911 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -731,7 +731,7 @@ class SyncServer(Server): message_objects.append(message) else: - raise ValueError(f"All messages must be of type Message or MessageCreate, got {type(messages)}") + raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(message) for message in messages]}") # Run the agent state forward return self._step(user_id=user_id, agent_id=agent_id, input_messages=message_objects) @@ -1806,7 +1806,7 @@ class SyncServer(Server): # TODO: not sure if this always works func = env[functions[-1]] - json_schema = generate_schema(func) + json_schema = generate_schema(func, terminal=request.terminal) else: # provided by client json_schema = request.json_schema diff --git a/tests/test_schema_generator.py b/tests/test_schema_generator.py index d4eaec0c..3edcf0d1 100644 --- a/tests/test_schema_generator.py +++ b/tests/test_schema_generator.py @@ -42,21 +42,21 @@ def test_schema_generator(): "required": ["message"], }, } - generated_schema = generate_schema(send_message) + generated_schema = generate_schema(send_message, terminal=True) print(f"\n\nreference_schema={correct_schema}") print(f"\n\ngenerated_schema={generated_schema}") assert correct_schema == generated_schema # Check that missing types results in an error try: - _ = generate_schema(send_message_missing_types) + _ = generate_schema(send_message_missing_types, terminal=True) assert False except: pass # Check that missing docstring results in an error try: - _ = generate_schema(send_message_missing_docstring) + _ = generate_schema(send_message_missing_docstring, terminal=True) assert False except: pass