From 5e294158afb0f7f148f3bd927b4b64eff9bf8e41 Mon Sep 17 00:00:00 2001 From: Vivek Verma Date: Tue, 8 Oct 2024 11:18:36 -0700 Subject: [PATCH] feat: add agent types (#1831) --- docker-compose-vllm.yaml | 4 ++-- letta/client/client.py | 9 ++++++++- letta/metadata.py | 2 ++ letta/schemas/agent.py | 14 ++++++++++++++ letta/server/server.py | 8 ++++++-- 5 files changed, 32 insertions(+), 5 deletions(-) diff --git a/docker-compose-vllm.yaml b/docker-compose-vllm.yaml index f6ab57e6..b7643f02 100644 --- a/docker-compose-vllm.yaml +++ b/docker-compose-vllm.yaml @@ -30,6 +30,6 @@ services: ports: - "8000:8000" command: > - --model ${LETTA_LLM_MODEL} --max_model_len=8000 + --model ${LETTA_LLM_MODEL} --max_model_len=8000 # Replace with your model - ipc: host \ No newline at end of file + ipc: host diff --git a/letta/client/client.py b/letta/client/client.py index 3ee29fd5..6e43601a 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -9,7 +9,7 @@ from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA from letta.data_sources.connectors import DataConnector from letta.functions.functions import parse_source_code from letta.memory import get_memory_functions -from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState +from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState from letta.schemas.block import ( Block, CreateBlock, @@ -68,6 +68,7 @@ class AbstractClient(object): def create_agent( self, name: Optional[str] = None, + agent_type: Optional[AgentType] = AgentType.memgpt_agent, embedding_config: Optional[EmbeddingConfig] = None, llm_config: Optional[LLMConfig] = None, memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)), @@ -319,6 +320,8 @@ class RESTClient(AbstractClient): def create_agent( self, name: Optional[str] = None, + # agent config + agent_type: Optional[AgentType] = AgentType.memgpt_agent, # model configs embedding_config: EmbeddingConfig = None, llm_config: LLMConfig = None, @@ -381,6 +384,7 @@ class RESTClient(AbstractClient): memory=memory, tools=tool_names, system=system, + agent_type=agent_type, llm_config=llm_config if llm_config else self._default_llm_config, embedding_config=embedding_config if embedding_config else self._default_embedding_config, ) @@ -1462,6 +1466,8 @@ class LocalClient(AbstractClient): def create_agent( self, name: Optional[str] = None, + # agent config + agent_type: Optional[AgentType] = AgentType.memgpt_agent, # model configs embedding_config: EmbeddingConfig = None, llm_config: LLMConfig = None, @@ -1524,6 +1530,7 @@ class LocalClient(AbstractClient): memory=memory, tools=tool_names, system=system, + agent_type=agent_type, llm_config=llm_config if llm_config else self._default_llm_config, embedding_config=embedding_config if embedding_config else self._default_embedding_config, ), diff --git a/letta/metadata.py b/letta/metadata.py index 7bed6078..3e56fddb 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -218,6 +218,7 @@ class AgentModel(Base): tools = Column(JSON) # configs + agent_type = Column(String) llm_config = Column(LLMConfigColumn) embedding_config = Column(EmbeddingConfigColumn) @@ -243,6 +244,7 @@ class AgentModel(Base): memory=Memory.load(self.memory), # load dictionary system=self.system, tools=self.tools, + agent_type=self.agent_type, llm_config=self.llm_config, embedding_config=self.embedding_config, metadata_=self.metadata_, diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 8277c0f4..8c40d31f 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -1,5 +1,6 @@ import uuid from datetime import datetime +from enum import Enum from typing import Dict, List, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -21,6 +22,15 @@ class BaseAgent(LettaBase, validate_assignment=True): user_id: Optional[str] = Field(None, description="The user id of the agent.") +class AgentType(str, Enum): + """ + Enum to represent the type of agent. + """ + + memgpt_agent = "memgpt_agent" + split_thread_agent = "split_thread_agent" + + class AgentState(BaseAgent): """ Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent. @@ -52,6 +62,9 @@ class AgentState(BaseAgent): # system prompt system: str = Field(..., description="The system prompt used by the agent.") + # agent configuration + agent_type: AgentType = Field(..., description="The type of agent.") + # llm information llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.") embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.") @@ -64,6 +77,7 @@ class CreateAgent(BaseAgent): memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.") tools: Optional[List[str]] = Field(None, description="The tools used by the agent.") system: Optional[str] = Field(None, description="The system prompt used by the agent.") + agent_type: Optional[AgentType] = Field(None, description="The type of agent.") llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.") embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.") diff --git a/letta/server/server.py b/letta/server/server.py index fae55f30..3ebb63fd 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -51,7 +51,7 @@ from letta.providers import ( OpenAIProvider, VLLMProvider, ) -from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState +from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState from letta.schemas.api_key import APIKey, APIKeyCreate from letta.schemas.block import ( Block, @@ -335,7 +335,10 @@ class SyncServer(Server): # Make sure the memory is a memory object assert isinstance(agent_state.memory, Memory) - letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs) + if agent_state.agent_type == AgentType.memgpt_agent: + letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs) + else: + raise NotImplementedError("Only base agents are supported as of right now!") # Add the agent to the in-memory store and return its reference logger.debug(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}") @@ -787,6 +790,7 @@ class SyncServer(Server): name=request.name, user_id=user_id, tools=request.tools if request.tools else [], + agent_type=request.agent_type or AgentType.memgpt_agent, llm_config=llm_config, embedding_config=embedding_config, system=request.system,