feat: add support for agent "swarm" (multi-agent) (#1878)

This commit is contained in:
Sarah Wooders
2024-10-15 15:50:47 -07:00
committed by GitHub
parent 4908c0c7b2
commit e5ff06685c
9 changed files with 249 additions and 20 deletions

72
examples/swarm/simple.py Normal file
View File

@@ -0,0 +1,72 @@
import typer
from swarm import Swarm
from letta import EmbeddingConfig, LLMConfig
"""
This is an example of how to implement the basic example provided by OpenAI for tranferring a conversation between two agents:
https://github.com/openai/swarm/tree/main?tab=readme-ov-file#usage
Before running this example, make sure you have letta>=0.5.0 installed. This example also runs with OpenAI, though you can also change the model by modifying the code:
```bash
export OPENAI_API_KEY=...
pip install letta
````
Then, instead the `examples/swarm` directory, run:
```bash
python simple.py
```
You should see a message output from Agent B.
"""
def transfer_agent_b(self):
"""
Transfer conversation to agent B.
Returns:
str: name of agent to transfer to
"""
return "agentb"
def transfer_agent_a(self):
"""
Transfer conversation to agent A.
Returns:
str: name of agent to transfer to
"""
return "agenta"
swarm = Swarm()
# set client configs
swarm.client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
swarm.client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4"))
# create tools
transfer_a = swarm.client.create_tool(transfer_agent_a, terminal=True)
transfer_b = swarm.client.create_tool(transfer_agent_b, terminal=True)
# create agents
if swarm.client.get_agent_id("agentb"):
swarm.client.delete_agent(swarm.client.get_agent_id("agentb"))
if swarm.client.get_agent_id("agenta"):
swarm.client.delete_agent(swarm.client.get_agent_id("agenta"))
agent_a = swarm.create_agent(name="agentb", tools=[transfer_a.name], instructions="Only speak in haikus")
agent_b = swarm.create_agent(name="agenta", tools=[transfer_b.name])
response = swarm.run(agent_name="agenta", message="Transfer me to agent b by calling the transfer_agent_b tool")
print("Response:")
typer.secho(f"{response}", fg=typer.colors.GREEN)
response = swarm.run(agent_name="agenta", message="My name is actually Sarah. Transfer me to agent b to write a haiku about my name")
print("Response:")
typer.secho(f"{response}", fg=typer.colors.GREEN)
response = swarm.run(agent_name="agenta", message="Transfer me to agent b - I want a haiku with my name in it")
print("Response:")
typer.secho(f"{response}", fg=typer.colors.GREEN)

111
examples/swarm/swarm.py Normal file
View File

@@ -0,0 +1,111 @@
import json
from typing import List, Optional
import typer
from letta import AgentState, EmbeddingConfig, LLMConfig, create_client
from letta.schemas.agent import AgentType
from letta.schemas.memory import BasicBlockMemory, Block
class Swarm:
def __init__(self):
self.agents = []
self.client = create_client()
# shared memory block (shared section of context window accross agents)
self.shared_memory = Block(name="human", label="human", value="")
def create_agent(
self,
name: Optional[str] = None,
# agent config
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
# model configs
embedding_config: EmbeddingConfig = None,
llm_config: LLMConfig = None,
# system
system: Optional[str] = None,
# tools
tools: Optional[List[str]] = None,
include_base_tools: Optional[bool] = True,
# instructions
instructions: str = "",
) -> AgentState:
# todo: process tools for agent handoff
persona_value = (
f"You are agent with name {name}. You instructions are {instructions}"
if len(instructions) > 0
else f"You are agent with name {name}"
)
persona_block = Block(name="persona", label="persona", value=persona_value)
memory = BasicBlockMemory(blocks=[persona_block, self.shared_memory])
agent = self.client.create_agent(
name=name,
agent_type=agent_type,
embedding_config=embedding_config,
llm_config=llm_config,
system=system,
tools=tools,
include_base_tools=include_base_tools,
memory=memory,
)
self.agents.append(agent)
return agent
def reset(self):
# delete all agents
for agent in self.agents:
self.client.delete_agent(agent.id)
for block in self.client.list_blocks():
self.client.delete_block(block.id)
def run(self, agent_name: str, message: str):
history = []
while True:
# send message to agent
agent_id = self.client.get_agent_id(agent_name)
print("Messaging agent: ", agent_name)
print("History size: ", len(history))
# print(self.client.get_agent(agent_id).tools)
# TODO: implement with sending multiple messages
if len(history) == 0:
response = self.client.send_message(agent_id=agent_id, message=message, role="user", include_full_message=True)
else:
response = self.client.send_messages(agent_id=agent_id, messages=history, include_full_message=True)
# update history
history += response.messages
# grab responses
messages = []
for message in response.messages:
messages += message.to_letta_message()
# get new agent (see tool call)
# print(messages)
if len(messages) < 2:
continue
function_call = messages[-2]
function_return = messages[-1]
if function_call.function_call.name == "send_message":
# return message to use
arg_data = json.loads(function_call.function_call.arguments)
# print(arg_data)
return arg_data["message"]
else:
# swap the agent
return_data = json.loads(function_return.function_return)
agent_name = return_data["message"]
typer.secho(f"Transferring to agent: {agent_name}", fg=typer.colors.RED)
# print("Transferring to agent", agent_name)
print()

View File

@@ -1772,6 +1772,40 @@ class LocalClient(AbstractClient):
# agent interactions
def send_messages(
self,
agent_id: str,
messages: List[Union[Message | MessageCreate]],
include_full_message: Optional[bool] = False,
):
"""
Send pre-packed messages to an agent.
Args:
agent_id (str): ID of the agent
messages (List[Union[Message | MessageCreate]]): List of messages to send
Returns:
response (LettaResponse): Response from the agent
"""
self.interface.clear()
usage = self.server.send_messages(user_id=self.user_id, agent_id=agent_id, messages=messages)
# auto-save
if self.auto_save:
self.save()
# format messages
messages = self.interface.to_list()
if include_full_message:
letta_messages = messages
else:
letta_messages = []
for m in messages:
letta_messages += m.to_letta_message()
return LettaResponse(messages=letta_messages, usage=usage)
def send_message(
self,
message: str,
@@ -1817,18 +1851,19 @@ class LocalClient(AbstractClient):
if self.auto_save:
self.save()
# TODO: need to make sure date/timestamp is propely passed
# TODO: update self.interface.to_list() to return actual Message objects
# here, the message objects will have faulty created_by timestamps
messages = self.interface.to_list()
for m in messages:
assert isinstance(m, Message), f"Expected Message object, got {type(m)}"
letta_messages = []
for m in messages:
letta_messages += m.to_letta_message()
return LettaResponse(messages=letta_messages, usage=usage)
## TODO: need to make sure date/timestamp is propely passed
## TODO: update self.interface.to_list() to return actual Message objects
## here, the message objects will have faulty created_by timestamps
# messages = self.interface.to_list()
# for m in messages:
# assert isinstance(m, Message), f"Expected Message object, got {type(m)}"
# letta_messages = []
# for m in messages:
# letta_messages += m.to_letta_message()
# return LettaResponse(messages=letta_messages, usage=usage)
# format messages
messages = self.interface.to_list()
if include_full_message:
letta_messages = messages
else:
@@ -1881,6 +1916,13 @@ class LocalClient(AbstractClient):
# humans / personas
def get_block_id(self, name: str, label: str) -> str:
block = self.server.get_blocks(name=name, label=label, user_id=self.user_id, template=True)
if not block:
return None
return block[0].id
def create_human(self, name: str, text: str):
"""
Create a human block template (saved human string to pre-fill `ChatMemory`)
@@ -2071,6 +2113,7 @@ class LocalClient(AbstractClient):
name: Optional[str] = None,
update: Optional[bool] = True, # TODO: actually use this
tags: Optional[List[str]] = None,
terminal: Optional[bool] = False,
) -> Tool:
"""
Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
@@ -2080,6 +2123,7 @@ class LocalClient(AbstractClient):
name: (str): Name of the tool (must be unique per-user.)
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
update (bool, optional): Update the tool if it already exists. Defaults to True.
terminal (bool, optional): Whether the tool is a terminal tool (no more agent steps). Defaults to False.
Returns:
tool (Tool): The created tool.
@@ -2095,7 +2139,7 @@ class LocalClient(AbstractClient):
# call server function
return self.server.create_tool(
# ToolCreate(source_type=source_type, source_code=source_code, name=tool_name, json_schema=json_schema, tags=tags),
ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags),
ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags, terminal=terminal),
user_id=self.user_id,
update=update,
)

View File

@@ -39,7 +39,7 @@ DEFAULT_PRESET = "memgpt_chat"
# Tools
BASE_TOOLS = [
"send_message",
"pause_heartbeats",
# "pause_heartbeats",
"conversation_search",
"conversation_search_date",
"archival_memory_insert",

View File

@@ -27,7 +27,7 @@ def load_function_set(module: ModuleType) -> dict:
if attr_name in function_dict:
raise ValueError(f"Found a duplicate of function name '{attr_name}'")
generated_schema = generate_schema(attr)
generated_schema = generate_schema(attr, terminal=False)
function_dict[attr_name] = {
"module": inspect.getsource(module),
"python_function": attr,

View File

@@ -74,7 +74,7 @@ def pydantic_model_to_open_ai(model):
}
def generate_schema(function, name: Optional[str] = None, description: Optional[str] = None):
def generate_schema(function, terminal: Optional[bool], name: Optional[str] = None, description: Optional[str] = None):
# Get the signature of the function
sig = inspect.signature(function)
@@ -127,7 +127,8 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
schema["parameters"]["required"].append(param.name)
# append the heartbeat
if function.__name__ not in ["send_message", "pause_heartbeats"]:
# TODO: don't hard-code
if function.__name__ not in ["send_message", "pause_heartbeats"] and not terminal:
schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",

View File

@@ -182,6 +182,7 @@ class ToolCreate(BaseTool):
json_schema: Optional[Dict] = Field(
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
)
terminal: Optional[bool] = Field(None, description="Whether the tool is a terminal tool (allow requesting heartbeats).")
class ToolUpdate(ToolCreate):

View File

@@ -731,7 +731,7 @@ class SyncServer(Server):
message_objects.append(message)
else:
raise ValueError(f"All messages must be of type Message or MessageCreate, got {type(messages)}")
raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(message) for message in messages]}")
# Run the agent state forward
return self._step(user_id=user_id, agent_id=agent_id, input_messages=message_objects)
@@ -1806,7 +1806,7 @@ class SyncServer(Server):
# TODO: not sure if this always works
func = env[functions[-1]]
json_schema = generate_schema(func)
json_schema = generate_schema(func, terminal=request.terminal)
else:
# provided by client
json_schema = request.json_schema

View File

@@ -42,21 +42,21 @@ def test_schema_generator():
"required": ["message"],
},
}
generated_schema = generate_schema(send_message)
generated_schema = generate_schema(send_message, terminal=True)
print(f"\n\nreference_schema={correct_schema}")
print(f"\n\ngenerated_schema={generated_schema}")
assert correct_schema == generated_schema
# Check that missing types results in an error
try:
_ = generate_schema(send_message_missing_types)
_ = generate_schema(send_message_missing_types, terminal=True)
assert False
except:
pass
# Check that missing docstring results in an error
try:
_ = generate_schema(send_message_missing_docstring)
_ = generate_schema(send_message_missing_docstring, terminal=True)
assert False
except:
pass