feat: add agent types (#1831)
This commit is contained in:
@@ -30,6 +30,6 @@ services:
|
|||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
command: >
|
command: >
|
||||||
--model ${LETTA_LLM_MODEL} --max_model_len=8000
|
--model ${LETTA_LLM_MODEL} --max_model_len=8000
|
||||||
# Replace with your model
|
# Replace with your model
|
||||||
ipc: host
|
ipc: host
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
|
|||||||
from letta.data_sources.connectors import DataConnector
|
from letta.data_sources.connectors import DataConnector
|
||||||
from letta.functions.functions import parse_source_code
|
from letta.functions.functions import parse_source_code
|
||||||
from letta.memory import get_memory_functions
|
from letta.memory import get_memory_functions
|
||||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
|
||||||
from letta.schemas.block import (
|
from letta.schemas.block import (
|
||||||
Block,
|
Block,
|
||||||
CreateBlock,
|
CreateBlock,
|
||||||
@@ -68,6 +68,7 @@ class AbstractClient(object):
|
|||||||
def create_agent(
|
def create_agent(
|
||||||
self,
|
self,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
|
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
|
||||||
embedding_config: Optional[EmbeddingConfig] = None,
|
embedding_config: Optional[EmbeddingConfig] = None,
|
||||||
llm_config: Optional[LLMConfig] = None,
|
llm_config: Optional[LLMConfig] = None,
|
||||||
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
|
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
|
||||||
@@ -319,6 +320,8 @@ class RESTClient(AbstractClient):
|
|||||||
def create_agent(
|
def create_agent(
|
||||||
self,
|
self,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
|
# agent config
|
||||||
|
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
|
||||||
# model configs
|
# model configs
|
||||||
embedding_config: EmbeddingConfig = None,
|
embedding_config: EmbeddingConfig = None,
|
||||||
llm_config: LLMConfig = None,
|
llm_config: LLMConfig = None,
|
||||||
@@ -381,6 +384,7 @@ class RESTClient(AbstractClient):
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
tools=tool_names,
|
tools=tool_names,
|
||||||
system=system,
|
system=system,
|
||||||
|
agent_type=agent_type,
|
||||||
llm_config=llm_config if llm_config else self._default_llm_config,
|
llm_config=llm_config if llm_config else self._default_llm_config,
|
||||||
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
|
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
|
||||||
)
|
)
|
||||||
@@ -1462,6 +1466,8 @@ class LocalClient(AbstractClient):
|
|||||||
def create_agent(
|
def create_agent(
|
||||||
self,
|
self,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
|
# agent config
|
||||||
|
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
|
||||||
# model configs
|
# model configs
|
||||||
embedding_config: EmbeddingConfig = None,
|
embedding_config: EmbeddingConfig = None,
|
||||||
llm_config: LLMConfig = None,
|
llm_config: LLMConfig = None,
|
||||||
@@ -1524,6 +1530,7 @@ class LocalClient(AbstractClient):
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
tools=tool_names,
|
tools=tool_names,
|
||||||
system=system,
|
system=system,
|
||||||
|
agent_type=agent_type,
|
||||||
llm_config=llm_config if llm_config else self._default_llm_config,
|
llm_config=llm_config if llm_config else self._default_llm_config,
|
||||||
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
|
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -218,6 +218,7 @@ class AgentModel(Base):
|
|||||||
tools = Column(JSON)
|
tools = Column(JSON)
|
||||||
|
|
||||||
# configs
|
# configs
|
||||||
|
agent_type = Column(String)
|
||||||
llm_config = Column(LLMConfigColumn)
|
llm_config = Column(LLMConfigColumn)
|
||||||
embedding_config = Column(EmbeddingConfigColumn)
|
embedding_config = Column(EmbeddingConfigColumn)
|
||||||
|
|
||||||
@@ -243,6 +244,7 @@ class AgentModel(Base):
|
|||||||
memory=Memory.load(self.memory), # load dictionary
|
memory=Memory.load(self.memory), # load dictionary
|
||||||
system=self.system,
|
system=self.system,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
|
agent_type=self.agent_type,
|
||||||
llm_config=self.llm_config,
|
llm_config=self.llm_config,
|
||||||
embedding_config=self.embedding_config,
|
embedding_config=self.embedding_config,
|
||||||
metadata_=self.metadata_,
|
metadata_=self.metadata_,
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
@@ -21,6 +22,15 @@ class BaseAgent(LettaBase, validate_assignment=True):
|
|||||||
user_id: Optional[str] = Field(None, description="The user id of the agent.")
|
user_id: Optional[str] = Field(None, description="The user id of the agent.")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentType(str, Enum):
|
||||||
|
"""
|
||||||
|
Enum to represent the type of agent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
memgpt_agent = "memgpt_agent"
|
||||||
|
split_thread_agent = "split_thread_agent"
|
||||||
|
|
||||||
|
|
||||||
class AgentState(BaseAgent):
|
class AgentState(BaseAgent):
|
||||||
"""
|
"""
|
||||||
Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent.
|
Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent.
|
||||||
@@ -52,6 +62,9 @@ class AgentState(BaseAgent):
|
|||||||
# system prompt
|
# system prompt
|
||||||
system: str = Field(..., description="The system prompt used by the agent.")
|
system: str = Field(..., description="The system prompt used by the agent.")
|
||||||
|
|
||||||
|
# agent configuration
|
||||||
|
agent_type: AgentType = Field(..., description="The type of agent.")
|
||||||
|
|
||||||
# llm information
|
# llm information
|
||||||
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
|
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
|
||||||
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
|
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
|
||||||
@@ -64,6 +77,7 @@ class CreateAgent(BaseAgent):
|
|||||||
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
|
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
|
||||||
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
|
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
|
||||||
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
|
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
|
||||||
|
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
|
||||||
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
|
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
|
||||||
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
|
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ from letta.providers import (
|
|||||||
OpenAIProvider,
|
OpenAIProvider,
|
||||||
VLLMProvider,
|
VLLMProvider,
|
||||||
)
|
)
|
||||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
|
||||||
from letta.schemas.api_key import APIKey, APIKeyCreate
|
from letta.schemas.api_key import APIKey, APIKeyCreate
|
||||||
from letta.schemas.block import (
|
from letta.schemas.block import (
|
||||||
Block,
|
Block,
|
||||||
@@ -335,7 +335,10 @@ class SyncServer(Server):
|
|||||||
# Make sure the memory is a memory object
|
# Make sure the memory is a memory object
|
||||||
assert isinstance(agent_state.memory, Memory)
|
assert isinstance(agent_state.memory, Memory)
|
||||||
|
|
||||||
letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
|
if agent_state.agent_type == AgentType.memgpt_agent:
|
||||||
|
letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only base agents are supported as of right now!")
|
||||||
|
|
||||||
# Add the agent to the in-memory store and return its reference
|
# Add the agent to the in-memory store and return its reference
|
||||||
logger.debug(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}")
|
logger.debug(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}")
|
||||||
@@ -787,6 +790,7 @@ class SyncServer(Server):
|
|||||||
name=request.name,
|
name=request.name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
tools=request.tools if request.tools else [],
|
tools=request.tools if request.tools else [],
|
||||||
|
agent_type=request.agent_type or AgentType.memgpt_agent,
|
||||||
llm_config=llm_config,
|
llm_config=llm_config,
|
||||||
embedding_config=embedding_config,
|
embedding_config=embedding_config,
|
||||||
system=request.system,
|
system=request.system,
|
||||||
|
|||||||
Reference in New Issue
Block a user