feat: Add optional llm and embedding handle args to CreateAgent request (#2260)
@@ -23,6 +23,7 @@ MIN_CONTEXT_WINDOW = 4096

# embeddings
MAX_EMBEDDING_DIM = 4096  # maximum supported embedding size - do NOT change or else DBs will need to be reset
DEFAULT_EMBEDDING_CHUNK_SIZE = 300

# tokenizers
EMBEDDING_TO_TOKENIZER_MAP = {

@@ -13,6 +13,7 @@ from letta.schemas.llm_config import LLMConfig

class Provider(BaseModel):
    name: str = Field(..., description="The name of the provider")

    def list_llm_models(self) -> List[LLMConfig]:
        return []

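Handle resolution later in this diff filters against whatever each enabled provider returns from list_llm_models() / list_embedding_models(). As a rough illustration (not part of this change), a hypothetical provider subclass that reuses the default gpt-4 config from the tests below; a provider that keeps the base-class default of [] can never satisfy a handle lookup:

class ExampleProvider(Provider):
    name: str = "example"

    def list_llm_models(self) -> List[LLMConfig]:
        # Advertise one model, so a handle like "example/gpt-4" would resolve.
        return [LLMConfig.default_config("gpt-4")]
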
@@ -465,6 +466,7 @@ class TogetherProvider(OpenAIProvider):

class GoogleAIProvider(Provider):
    # gemini
    name: str = "google_ai"
    api_key: str = Field(..., description="API key for the Google AI API.")
    base_url: str = "https://generativelanguage.googleapis.com"

@@ -3,6 +3,7 @@ from typing import Dict, List, Optional

from pydantic import BaseModel, Field, field_validator

from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_base import OrmMetadataBase

@@ -107,6 +108,16 @@ class CreateAgent(BaseModel, validate_assignment=True): #

    include_base_tools: bool = Field(True, description="Whether to include base tools in the agent.")
    description: Optional[str] = Field(None, description="The description of the agent.")
    metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
    llm: Optional[str] = Field(
        None,
        description="The LLM configuration handle used by the agent, specified in the format "
        "provider/model-name, as an alternative to specifying llm_config.",
    )
    embedding: Optional[str] = Field(
        None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
    )
    context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
    embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")

    @field_validator("name")
    @classmethod

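These handle fields give a request a lighter-weight alternative to attaching full config objects. A minimal sketch of the two equivalent request styles, using the same handles and default configs that the updated tests at the bottom of this diff use:

# New handle style: "provider/model-name" strings, resolved server-side.
request = CreateAgent(
    name="my_agent",
    memory_blocks=[],
    llm="openai/gpt-4",
    embedding="openai/text-embedding-ada-002",
)

# Existing explicit-config style, still supported and taking precedence if set.
request = CreateAgent(
    name="my_agent",
    memory_blocks=[],
    llm_config=LLMConfig.default_config("gpt-4"),
    embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
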
@@ -133,6 +144,30 @@ class CreateAgent(BaseModel, validate_assignment=True): #

        return name

    @field_validator("llm")
    @classmethod
    def validate_llm(cls, llm: Optional[str]) -> Optional[str]:
        if not llm:
            return llm

        provider_name, model_name = llm.split("/", 1)
        if not provider_name or not model_name:
            raise ValueError("The llm config handle should be in the format provider/model-name")

        return llm

    @field_validator("embedding")
    @classmethod
    def validate_embedding(cls, embedding: Optional[str]) -> Optional[str]:
        if not embedding:
            return embedding

        provider_name, model_name = embedding.split("/", 1)
        if not provider_name or not model_name:
            raise ValueError("The embedding config handle should be in the format provider/model-name")

        return embedding


class UpdateAgent(BaseModel):
    name: Optional[str] = Field(None, description="The name of the agent.")

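The validators only check the handle's shape; whether the provider and model actually exist is checked server-side by get_llm_config_from_handle / get_embedding_config_from_handle below. A rough sketch of the expected behavior, in the style of the tests at the bottom of this diff (pydantic surfaces the validator's ValueError as a ValidationError, which is a ValueError subclass):

import pytest

# Well-formed handles: both provider and model parts are present.
CreateAgent(name="ok", memory_blocks=[], llm="openai/gpt-4", embedding="openai/text-embedding-ada-002")

# An empty model part ("openai/") fails the explicit check; a missing "/" ("gpt-4")
# fails because the split cannot unpack into provider and model.
with pytest.raises(ValueError):
    CreateAgent(name="bad", memory_blocks=[], llm="openai/")
with pytest.raises(ValueError):
    CreateAgent(name="bad", memory_blocks=[], embedding="gpt-4")
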
@@ -776,6 +776,18 @@ class SyncServer(Server):

        # interface
        interface: Union[AgentInterface, None] = None,
    ) -> AgentState:
        if request.llm_config is None:
            if request.llm is None:
                raise ValueError("Must specify either llm or llm_config in request")
            request.llm_config = self.get_llm_config_from_handle(handle=request.llm, context_window_limit=request.context_window_limit)

        if request.embedding_config is None:
            if request.embedding is None:
                raise ValueError("Must specify either embedding or embedding_config in request")
            request.embedding_config = self.get_embedding_config_from_handle(
                handle=request.embedding, embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE
            )

        """Create a new agent using a config"""
        # Invoke manager
        agent_state = self.agent_manager.create_agent(

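Handles are resolved into full config objects here, before the agent manager runs, and an explicit llm_config or embedding_config still wins over a handle. A hedged end-to-end sketch, assuming a SyncServer with the OpenAI provider enabled and reusing the server/actor pattern from the test fixtures below; the optional context_window_limit is applied during handle resolution as long as it does not exceed the model's maximum:

agent_state = server.create_agent(
    request=CreateAgent(
        name="handle_agent",
        memory_blocks=[],
        llm="openai/gpt-4",
        embedding="openai/text-embedding-ada-002",
        context_window_limit=4096,  # forwarded to get_llm_config_from_handle
    ),
    actor=server.user_manager.get_user_or_default(user_id),
)
assert agent_state.llm_config.context_window == 4096
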
@@ -1283,6 +1295,57 @@ class SyncServer(Server):

                warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
        return embedding_models

    def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
        provider_name, model_name = handle.split("/", 1)
        provider = self.get_provider_from_name(provider_name)

        llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
        if not llm_configs:
            raise ValueError(f"LLM model {model_name} is not supported by {provider_name}")
        elif len(llm_configs) > 1:
            raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
        else:
            llm_config = llm_configs[0]

        if context_window_limit:
            if context_window_limit > llm_config.context_window:
                raise ValueError(
                    f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})"
                )
            llm_config.context_window = context_window_limit

        return llm_config

    def get_embedding_config_from_handle(
        self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
    ) -> EmbeddingConfig:
        provider_name, model_name = handle.split("/", 1)
        provider = self.get_provider_from_name(provider_name)

        embedding_configs = [config for config in provider.list_embedding_models() if config.embedding_model == model_name]
        if not embedding_configs:
            raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}")
        elif len(embedding_configs) > 1:
            raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}")
        else:
            embedding_config = embedding_configs[0]

        if embedding_chunk_size:
            embedding_config.embedding_chunk_size = embedding_chunk_size

        return embedding_config

    def get_provider_from_name(self, provider_name: str) -> Provider:
        providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
        if not providers:
            raise ValueError(f"Provider {provider_name} is not supported")
        elif len(providers) > 1:
            raise ValueError(f"Multiple providers with name {provider_name} supported")
        else:
            provider = providers[0]

        return provider

    def add_llm_model(self, request: LLMConfig) -> LLMConfig:
        """Add a new LLM model"""

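The new helpers can also be exercised directly; unknown providers or models raise ValueError rather than returning None, and duplicate matches are likewise an error. A hedged sketch of direct calls, assuming a SyncServer instance with the OpenAI provider enabled:

import pytest

llm_config = server.get_llm_config_from_handle("openai/gpt-4", context_window_limit=4096)
assert llm_config.model == "gpt-4" and llm_config.context_window == 4096

embedding_config = server.get_embedding_config_from_handle("openai/text-embedding-ada-002", embedding_chunk_size=300)
assert embedding_config.embedding_model == "text-embedding-ada-002"
assert embedding_config.embedding_chunk_size == 300

with pytest.raises(ValueError):
    server.get_provider_from_name("not-a-provider")  # "Provider not-a-provider is not supported"
with pytest.raises(ValueError):
    server.get_llm_config_from_handle("openai/not-a-model")  # "LLM model not-a-model is not supported by openai"
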
@@ -61,6 +61,9 @@ class AgentManager:

    ) -> PydanticAgentState:
        system = derive_system_message(agent_type=agent_create.agent_type, system=agent_create.system)

        if not agent_create.llm_config or not agent_create.embedding_config:
            raise ValueError("llm_config and embedding_config are required")

        # create blocks (note: cannot be linked until the agent_id is created)
        block_ids = list(agent_create.block_ids or [])  # Create a local copy to avoid modifying the original
        for create_block in agent_create.memory_blocks:

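This guard makes the contract explicit: by the time AgentManager.create_agent runs, any llm/embedding handles must already have been materialized into concrete configs by SyncServer.create_agent. A hypothetical sketch of what it rejects; the direct manager call and its actor argument are illustrative assumptions, not how the tests below create agents:

import pytest

# A request that only carries handles still has llm_config / embedding_config set to None,
# so calling the manager directly (bypassing the server-side resolution) is rejected.
with pytest.raises(ValueError, match="llm_config and embedding_config are required"):
    server.agent_manager.create_agent(
        CreateAgent(name="bare", memory_blocks=[], llm="openai/gpt-4", embedding="openai/text-embedding-ada-002"),
        actor=actor,
    )
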
@@ -24,7 +24,6 @@ from letta.config import LettaConfig

from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.job import Job as PydanticJob
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.source import Source as PydanticSource
from letta.server.server import SyncServer

@@ -329,8 +328,8 @@ def agent_id(server, user_id, base_tools):

            name="test_agent",
            tool_ids=[t.id for t in base_tools],
            memory_blocks=[],
-            llm_config=LLMConfig.default_config("gpt-4"),
-            embedding_config=EmbeddingConfig.default_config(provider="openai"),
+            llm="openai/gpt-4",
+            embedding="openai/text-embedding-ada-002",
        ),
        actor=actor,
    )

@@ -350,8 +349,8 @@ def other_agent_id(server, user_id, base_tools):

            name="test_agent_other",
            tool_ids=[t.id for t in base_tools],
            memory_blocks=[],
-            llm_config=LLMConfig.default_config("gpt-4"),
-            embedding_config=EmbeddingConfig.default_config(provider="openai"),
+            llm="openai/gpt-4",
+            embedding="openai/text-embedding-ada-002",
        ),
        actor=actor,
    )

@@ -618,8 +617,8 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):

        request=CreateAgent(
            name="nonexistent_tools_agent",
            memory_blocks=[],
-            llm_config=LLMConfig.default_config("gpt-4"),
-            embedding_config=EmbeddingConfig.default_config(provider="openai"),
+            llm="openai/gpt-4",
+            embedding="openai/text-embedding-ada-002",
        ),
        actor=server.user_manager.get_user_or_default(user_id),
    )

@@ -904,8 +903,8 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools

                CreateBlock(label="human", value="The human's name is Bob."),
                CreateBlock(label="persona", value="My name is Alice."),
            ],
-            llm_config=LLMConfig.default_config("gpt-4"),
-            embedding_config=EmbeddingConfig.default_config(provider="openai"),
+            llm="openai/gpt-4",
+            embedding="openai/text-embedding-ada-002",
        ),
        actor=actor,
    )

@@ -1091,8 +1090,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to

                CreateBlock(label="human", value="The human's name is Bob."),
                CreateBlock(label="persona", value="My name is Alice."),
            ],
-            llm_config=LLMConfig.default_config("gpt-4"),
-            embedding_config=EmbeddingConfig.default_config(provider="openai"),
+            llm="openai/gpt-4",
+            embedding="openai/text-embedding-ada-002",
            include_base_tools=False,
        ),
        actor=actor,