feat: add agent types (#1831)

This commit is contained in:
Vivek Verma
2024-10-08 11:18:36 -07:00
committed by GitHub
parent 6b35e87245
commit 5e294158af
5 changed files with 32 additions and 5 deletions

View File

@@ -30,6 +30,6 @@ services:
ports: ports:
- "8000:8000" - "8000:8000"
command: > command: >
--model ${LETTA_LLM_MODEL} --max_model_len=8000 --model ${LETTA_LLM_MODEL} --max_model_len=8000
# Replace with your model # Replace with your model
ipc: host ipc: host

View File

@@ -9,7 +9,7 @@ from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.data_sources.connectors import DataConnector from letta.data_sources.connectors import DataConnector
from letta.functions.functions import parse_source_code from letta.functions.functions import parse_source_code
from letta.memory import get_memory_functions from letta.memory import get_memory_functions
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
from letta.schemas.block import ( from letta.schemas.block import (
Block, Block,
CreateBlock, CreateBlock,
@@ -68,6 +68,7 @@ class AbstractClient(object):
def create_agent( def create_agent(
self, self,
name: Optional[str] = None, name: Optional[str] = None,
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
embedding_config: Optional[EmbeddingConfig] = None, embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None, llm_config: Optional[LLMConfig] = None,
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)), memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
@@ -319,6 +320,8 @@ class RESTClient(AbstractClient):
def create_agent( def create_agent(
self, self,
name: Optional[str] = None, name: Optional[str] = None,
# agent config
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
# model configs # model configs
embedding_config: EmbeddingConfig = None, embedding_config: EmbeddingConfig = None,
llm_config: LLMConfig = None, llm_config: LLMConfig = None,
@@ -381,6 +384,7 @@ class RESTClient(AbstractClient):
memory=memory, memory=memory,
tools=tool_names, tools=tool_names,
system=system, system=system,
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config, llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config, embedding_config=embedding_config if embedding_config else self._default_embedding_config,
) )
@@ -1462,6 +1466,8 @@ class LocalClient(AbstractClient):
def create_agent( def create_agent(
self, self,
name: Optional[str] = None, name: Optional[str] = None,
# agent config
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
# model configs # model configs
embedding_config: EmbeddingConfig = None, embedding_config: EmbeddingConfig = None,
llm_config: LLMConfig = None, llm_config: LLMConfig = None,
@@ -1524,6 +1530,7 @@ class LocalClient(AbstractClient):
memory=memory, memory=memory,
tools=tool_names, tools=tool_names,
system=system, system=system,
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config, llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config, embedding_config=embedding_config if embedding_config else self._default_embedding_config,
), ),

View File

@@ -218,6 +218,7 @@ class AgentModel(Base):
tools = Column(JSON) tools = Column(JSON)
# configs # configs
agent_type = Column(String)
llm_config = Column(LLMConfigColumn) llm_config = Column(LLMConfigColumn)
embedding_config = Column(EmbeddingConfigColumn) embedding_config = Column(EmbeddingConfigColumn)
@@ -243,6 +244,7 @@ class AgentModel(Base):
memory=Memory.load(self.memory), # load dictionary memory=Memory.load(self.memory), # load dictionary
system=self.system, system=self.system,
tools=self.tools, tools=self.tools,
agent_type=self.agent_type,
llm_config=self.llm_config, llm_config=self.llm_config,
embedding_config=self.embedding_config, embedding_config=self.embedding_config,
metadata_=self.metadata_, metadata_=self.metadata_,

View File

@@ -1,5 +1,6 @@
import uuid import uuid
from datetime import datetime from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@@ -21,6 +22,15 @@ class BaseAgent(LettaBase, validate_assignment=True):
user_id: Optional[str] = Field(None, description="The user id of the agent.") user_id: Optional[str] = Field(None, description="The user id of the agent.")
class AgentType(str, Enum):
"""
Enum to represent the type of agent.
"""
memgpt_agent = "memgpt_agent"
split_thread_agent = "split_thread_agent"
class AgentState(BaseAgent): class AgentState(BaseAgent):
""" """
Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent. Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent.
@@ -52,6 +62,9 @@ class AgentState(BaseAgent):
# system prompt # system prompt
system: str = Field(..., description="The system prompt used by the agent.") system: str = Field(..., description="The system prompt used by the agent.")
# agent configuration
agent_type: AgentType = Field(..., description="The type of agent.")
# llm information # llm information
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.") llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.") embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
@@ -64,6 +77,7 @@ class CreateAgent(BaseAgent):
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.") memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.") tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.") system: Optional[str] = Field(None, description="The system prompt used by the agent.")
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.") llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.") embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")

View File

@@ -51,7 +51,7 @@ from letta.providers import (
OpenAIProvider, OpenAIProvider,
VLLMProvider, VLLMProvider,
) )
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
from letta.schemas.api_key import APIKey, APIKeyCreate from letta.schemas.api_key import APIKey, APIKeyCreate
from letta.schemas.block import ( from letta.schemas.block import (
Block, Block,
@@ -335,7 +335,10 @@ class SyncServer(Server):
# Make sure the memory is a memory object # Make sure the memory is a memory object
assert isinstance(agent_state.memory, Memory) assert isinstance(agent_state.memory, Memory)
letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs) if agent_state.agent_type == AgentType.memgpt_agent:
letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
else:
raise NotImplementedError("Only base agents are supported as of right now!")
# Add the agent to the in-memory store and return its reference # Add the agent to the in-memory store and return its reference
logger.debug(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}") logger.debug(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}")
@@ -787,6 +790,7 @@ class SyncServer(Server):
name=request.name, name=request.name,
user_id=user_id, user_id=user_id,
tools=request.tools if request.tools else [], tools=request.tools if request.tools else [],
agent_type=request.agent_type or AgentType.memgpt_agent,
llm_config=llm_config, llm_config=llm_config,
embedding_config=embedding_config, embedding_config=embedding_config,
system=request.system, system=request.system,